From de6ed9ba9fd59f753bbfd4c9a88ff15e03a89de7 Mon Sep 17 00:00:00 2001
From: Sergii Tkachenko
Date: Fri, 9 Jun 2023 18:08:55 -0400
Subject: [PATCH] [Python] Migrate from yapf to black (#33138)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Switched from yapf to black
- Reconfigured isort for black
- Resolved black/pylint idiosyncrasies

Note: I used `--experimental-string-processing` because black was producing
"implicit string concatenation", similar to what is described here:
https://github.com/psf/black/issues/1837. While this feature is currently
experimental, it will eventually be enabled by default:
https://github.com/psf/black/issues/2188.

After running black with the new string processing, so that the generated code
merged these `"hello" " world"` string concatenations, I removed
`--experimental-string-processing` for stability and regenerated the code.

To the reviewer: don't even try to open the "Files Changed" tab 😄 It's better
to review commit-by-commit and to ignore the `run black and isort` commit.
---
 .gitignore | 2 +- black.toml | 51 + examples/python/async_streaming/client.py | 40 +- examples/python/async_streaming/server.py | 15 +- examples/python/auth/_credentials.py | 8 +- .../auth/async_customized_auth_client.py | 43 +- .../auth/async_customized_auth_server.py | 58 +- .../python/auth/customized_auth_client.py | 36 +- .../python/auth/customized_auth_server.py | 51 +- .../python/auth/test/_auth_example_test.py | 20 +- examples/python/cancellation/client.py | 66 +- examples/python/cancellation/search.py | 35 +- examples/python/cancellation/server.py | 45 +- .../test/_cancellation_example_test.py | 53 +- examples/python/compression/client.py | 48 +- examples/python/compression/server.py | 68 +- .../test/compression_example_test.py | 34 +- .../python/data_transmission/alts_client.py | 6 +- .../python/data_transmission/alts_server.py | 9 +- examples/python/data_transmission/client.py | 51 +- examples/python/data_transmission/server.py | 43 +- examples/python/debug/asyncio_debug_server.py | 39 +- examples/python/debug/asyncio_get_stats.py | 19 +- examples/python/debug/asyncio_send_message.py | 40 +- examples/python/debug/debug_server.py | 33 +- examples/python/debug/get_stats.py | 16 +- examples/python/debug/send_message.py | 35 +- .../python/debug/test/_debug_example_test.py | 18 +- examples/python/errors/client.py | 14 +- examples/python/errors/server.py | 27 +- .../test/_error_handling_example_test.py | 8 +- .../python/health_checking/greeter_client.py | 11 +- .../python/health_checking/greeter_server.py | 19 +- .../async_greeter_client.py | 17 +- .../async_greeter_server.py | 6 +- .../python/helloworld/async_greeter_client.py | 6 +- .../async_greeter_client_with_options.py | 20 +- .../python/helloworld/async_greeter_server.py | 13 +- ...c_greeter_server_with_graceful_shutdown.py | 17 +- .../async_greeter_server_with_reflection.py | 15 +- examples/python/helloworld/greeter_client.py | 6 +- .../helloworld/greeter_client_reflection.py | 12 +- .../helloworld/greeter_client_with_options.py | 20 +- examples/python/helloworld/greeter_server.py | 9 +- .../greeter_server_with_reflection.py | 9 +- .../async/async_greeter_client.py | 17 +- .../async_greeter_server_with_interceptor.py | 37 +- .../default_value_client_interceptor.py | 22 +- .../default_value/greeter_client.py | 21 +- .../headers/generic_client_interceptor.py | 38 +- .../interceptors/headers/greeter_client.py | 18 +- .../interceptors/headers/greeter_server.py | 25 +- 
.../header_manipulator_client_interceptor.py | 35 +- .../request_header_validator_interceptor.py | 8 +- examples/python/keep_alive/greeter_client.py | 26 +- examples/python/keep_alive/greeter_server.py | 31 +- examples/python/lb_policies/greeter_client.py | 6 +- examples/python/lb_policies/greeter_server.py | 9 +- examples/python/metadata/metadata_client.py | 25 +- examples/python/metadata/metadata_server.py | 19 +- examples/python/multiplex/multiplex_client.py | 26 +- examples/python/multiplex/multiplex_server.py | 44 +- .../python/multiplex/route_guide_resources.py | 12 +- examples/python/multiplex/run_codegen.py | 32 +- examples/python/multiprocessing/client.py | 35 +- examples/python/multiprocessing/server.py | 31 +- .../test/_multiprocessing_example_test.py | 30 +- examples/python/no_codegen/greeter_client.py | 6 +- examples/python/no_codegen/greeter_server.py | 7 +- examples/python/retry/async_retry_client.py | 47 +- examples/python/retry/flaky_server.py | 30 +- examples/python/retry/retry_client.py | 44 +- .../route_guide/asyncio_route_guide_client.py | 36 +- .../route_guide/asyncio_route_guide_server.py | 67 +- .../python/route_guide/route_guide_client.py | 18 +- .../route_guide/route_guide_resources.py | 12 +- .../python/route_guide/route_guide_server.py | 35 +- examples/python/route_guide/run_codegen.py | 16 +- examples/python/timeout/greeter_client.py | 18 +- examples/python/timeout/greeter_server.py | 7 +- examples/python/uds/async_greeter_client.py | 9 +- examples/python/uds/async_greeter_server.py | 15 +- examples/python/uds/greeter_client.py | 8 +- examples/python/uds/greeter_server.py | 9 +- .../asyncio_wait_for_ready_example.py | 41 +- .../test/_wait_for_ready_example_test.py | 6 +- .../wait_for_ready/wait_for_ready_example.py | 41 +- ...eady_with_client_timeout_example_client.py | 25 +- ...eady_with_client_timeout_example_server.py | 15 +- examples/python/xds/client.py | 13 +- examples/python/xds/server.py | 74 +- setup.cfg | 4 - setup.py | 489 +- src/abseil-cpp/gen_build_yaml.py | 15 +- .../preprocessed_builds.yaml.gen.py | 290 +- src/benchmark/gen_build_yaml.py | 34 +- src/boringssl/gen_build_yaml.py | 160 +- src/c-ares/gen_build_yaml.py | 231 +- .../Grpc.Tools.Tests/scripts/fakeprotoc.py | 104 +- src/objective-c/change-comments.py | 19 +- src/php/bin/xds_manager.py | 72 +- src/proto/gen_build_yaml.py | 30 +- src/python/grpcio/_parallel_compile_patch.py | 31 +- src/python/grpcio/_spawn_patch.py | 19 +- src/python/grpcio/commands.py | 211 +- src/python/grpcio/grpc/__init__.py | 652 +-- src/python/grpcio/grpc/_auth.py | 38 +- src/python/grpcio/grpc/_channel.py | 1030 ++-- src/python/grpcio/grpc/_common.py | 64 +- src/python/grpcio/grpc/_compression.py | 30 +- src/python/grpcio/grpc/_interceptor.py | 567 ++- src/python/grpcio/grpc/_observability.py | 33 +- src/python/grpcio/grpc/_plugin_wrapping.py | 57 +- src/python/grpcio/grpc/_runtime_protos.py | 12 +- src/python/grpcio/grpc/_server.py | 804 ++-- src/python/grpcio/grpc/_simple_stubs.py | 215 +- src/python/grpcio/grpc/_typing.py | 87 +- src/python/grpcio/grpc/_utilities.py | 43 +- src/python/grpcio/grpc/aio/__init__.py | 66 +- src/python/grpcio/grpc/aio/_base_call.py | 29 +- src/python/grpcio/grpc/aio/_base_channel.py | 21 +- src/python/grpcio/grpc/aio/_base_server.py | 24 +- src/python/grpcio/grpc/aio/_call.py | 290 +- src/python/grpcio/grpc/aio/_channel.py | 361 +- src/python/grpcio/grpc/aio/_interceptor.py | 580 ++- src/python/grpcio/grpc/aio/_metadata.py | 2 +- src/python/grpcio/grpc/aio/_server.py | 87 +- 
src/python/grpcio/grpc/aio/_typing.py | 16 +- .../grpcio/grpc/beta/_client_adaptations.py | 1017 ++-- src/python/grpcio/grpc/beta/_metadata.py | 24 +- .../grpcio/grpc/beta/_server_adaptations.py | 236 +- .../grpcio/grpc/beta/implementations.py | 324 +- src/python/grpcio/grpc/beta/interfaces.py | 124 +- src/python/grpcio/grpc/beta/utilities.py | 36 +- .../grpcio/grpc/experimental/__init__.py | 34 +- .../grpc/framework/common/cardinality.py | 8 +- .../grpcio/grpc/framework/common/style.py | 4 +- .../grpc/framework/foundation/abandonment.py | 6 +- .../framework/foundation/callable_util.py | 64 +- .../grpc/framework/foundation/future.py | 164 +- .../grpc/framework/foundation/logging_pool.py | 23 +- .../grpc/framework/foundation/stream.py | 12 +- .../grpc/framework/interfaces/base/base.py | 313 +- .../framework/interfaces/base/utilities.py | 78 +- .../grpc/framework/interfaces/face/face.py | 1333 +++--- .../framework/interfaces/face/utilities.py | 283 +- src/python/grpcio/support.py | 63 +- .../grpcio_admin/grpc_admin/__init__.py | 4 +- src/python/grpcio_admin/setup.py | 46 +- .../grpcio_channelz/channelz_commands.py | 19 +- .../grpc_channelz/v1/_async.py | 31 +- .../grpc_channelz/v1/_servicer.py | 8 +- .../grpc_channelz/v1/channelz.py | 10 +- src/python/grpcio_channelz/setup.py | 68 +- src/python/grpcio_csds/grpc_csds/__init__.py | 14 +- src/python/grpcio_csds/setup.py | 48 +- .../grpc_health/v1/_async.py | 37 +- .../grpc_health/v1/health.py | 46 +- .../grpcio_health_checking/health_commands.py | 16 +- src/python/grpcio_health_checking/setup.py | 79 +- .../grpc_observability/__init__.py | 2 +- .../grpc_observability/_gcp_observability.py | 66 +- .../grpc_observability/_observability.py | 2 + .../_open_census_exporter.py | 9 +- .../grpc_reflection/v1alpha/_async.py | 34 +- .../grpc_reflection/v1alpha/_base.py | 51 +- .../proto_reflection_descriptor_database.py | 48 +- .../grpc_reflection/v1alpha/reflection.py | 43 +- .../grpcio_reflection/reflection_commands.py | 20 +- src/python/grpcio_reflection/setup.py | 79 +- .../grpcio_status/grpc_status/_async.py | 12 +- .../grpcio_status/grpc_status/_common.py | 4 +- .../grpcio_status/grpc_status/rpc_status.py | 34 +- src/python/grpcio_status/setup.py | 74 +- src/python/grpcio_status/status_commands.py | 17 +- .../grpcio_testing/grpc_testing/__init__.py | 24 +- .../grpc_testing/_channel/_channel.py | 29 +- .../grpc_testing/_channel/_channel_rpc.py | 26 +- .../grpc_testing/_channel/_channel_state.py | 21 +- .../grpc_testing/_channel/_invocation.py | 16 +- .../grpc_testing/_channel/_multi_callable.py | 100 +- .../grpc_testing/_channel/_rpc_state.py | 51 +- .../grpcio_testing/grpc_testing/_common.py | 79 +- .../grpc_testing/_server/__init__.py | 5 +- .../grpc_testing/_server/_handler.py | 13 +- .../grpc_testing/_server/_rpc.py | 16 +- .../grpc_testing/_server/_server.py | 147 +- .../grpc_testing/_server/_server_rpc.py | 4 - .../grpc_testing/_server/_service.py | 11 +- .../grpc_testing/_server/_servicer_context.py | 10 +- .../grpcio_testing/grpc_testing/_time.py | 36 +- src/python/grpcio_testing/setup.py | 40 +- src/python/grpcio_testing/testing_commands.py | 6 +- src/python/grpcio_tests/commands.py | 155 +- src/python/grpcio_tests/setup.py | 87 +- src/python/grpcio_tests/tests/_loader.py | 66 +- src/python/grpcio_tests/tests/_result.py | 451 +- src/python/grpcio_tests/tests/_runner.py | 113 +- .../tests/_sanity/_sanity_test.py | 23 +- .../grpcio_tests/tests/admin/test_admin.py | 10 +- .../tests/channelz/_channelz_servicer_test.py | 288 +- 
.../grpcio_tests/tests/csds/test_csds.py | 51 +- .../tests/fork/_fork_interop_test.py | 56 +- src/python/grpcio_tests/tests/fork/client.py | 53 +- src/python/grpcio_tests/tests/fork/methods.py | 184 +- .../health_check/_health_servicer_test.py | 223 +- .../tests/http2/negative_http2_client.py | 118 +- .../tests/interop/_insecure_intraop_test.py | 23 +- .../tests/interop/_intraop_test_case.py | 15 +- .../tests/interop/_secure_intraop_test.py | 37 +- .../grpcio_tests/tests/interop/client.py | 178 +- .../grpcio_tests/tests/interop/methods.py | 250 +- .../grpcio_tests/tests/interop/resources.py | 12 +- .../grpcio_tests/tests/interop/server.py | 42 +- .../grpcio_tests/tests/interop/service.py | 46 +- .../observability/_observability_test.py | 151 +- .../protoc_plugin/_python_plugin_test.py | 335 +- .../protoc_plugin/_split_definitions_test.py | 250 +- .../protoc_plugin/beta_python_plugin_test.py | 393 +- .../tests/qps/benchmark_client.py | 77 +- .../tests/qps/benchmark_server.py | 11 +- .../grpcio_tests/tests/qps/client_runner.py | 7 +- .../grpcio_tests/tests/qps/histogram.py | 4 +- .../grpcio_tests/tests/qps/qps_worker.py | 25 +- .../grpcio_tests/tests/qps/worker_server.py | 78 +- .../reflection/_reflection_client_test.py | 15 +- .../reflection/_reflection_servicer_test.py | 140 +- .../tests/status/_grpc_status_test.py | 64 +- .../grpcio_tests/tests/stress/client.py | 122 +- .../tests/stress/metrics_server.py | 5 +- .../grpcio_tests/tests/stress/test_runner.py | 7 +- .../tests/stress/unary_stream_benchmark.py | 41 +- .../tests/testing/_application_common.py | 10 +- .../testing/_application_testing_common.py | 22 +- .../tests/testing/_client_application.py | 68 +- .../tests/testing/_client_test.py | 344 +- .../tests/testing/_server_application.py | 43 +- .../tests/testing/_server_test.py | 115 +- .../grpcio_tests/tests/testing/_time_test.py | 23 +- .../grpcio_tests/tests/unit/_abort_test.py | 41 +- .../grpcio_tests/tests/unit/_api_test.py | 157 +- .../tests/unit/_auth_context_test.py | 202 +- .../grpcio_tests/tests/unit/_auth_test.py | 19 +- .../tests/unit/_channel_args_test.py | 39 +- .../tests/unit/_channel_close_test.py | 54 +- .../tests/unit/_channel_connectivity_test.py | 94 +- .../tests/unit/_channel_ready_future_test.py | 21 +- .../tests/unit/_compression_test.py | 328 +- .../unit/_contextvars_propagation_test.py | 37 +- .../tests/unit/_credentials_test.py | 39 +- .../unit/_cython/_cancel_many_calls_test.py | 130 +- .../tests/unit/_cython/_channel_test.py | 14 +- .../tests/unit/_cython/_common.py | 65 +- .../tests/unit/_cython/_fork_test.py | 11 +- ...s_server_completion_queue_per_call_test.py | 180 +- ...ges_single_server_completion_queue_test.py | 163 +- .../_read_some_but_not_all_responses_test.py | 191 +- .../tests/unit/_cython/_server_test.py | 18 +- .../tests/unit/_cython/cygrpc_test.py | 501 +- .../tests/unit/_cython/test_utilities.py | 11 +- .../tests/unit/_dns_resolver_test.py | 19 +- .../tests/unit/_dynamic_stubs_test.py | 24 +- .../tests/unit/_empty_message_test.py | 41 +- .../unit/_error_message_encoding_test.py | 28 +- .../tests/unit/_exit_scenarios.py | 98 +- .../grpcio_tests/tests/unit/_exit_test.py | 153 +- .../tests/unit/_from_grpc_import_star.py | 6 +- .../tests/unit/_grpc_shutdown_test.py | 9 +- .../tests/unit/_interceptor_test.py | 663 ++- .../tests/unit/_invalid_metadata_test.py | 71 +- .../tests/unit/_invocation_defects_test.py | 241 +- .../tests/unit/_local_credentials_test.py | 54 +- .../grpcio_tests/tests/unit/_logging_test.py | 21 +- 
.../tests/unit/_metadata_code_details_test.py | 523 +- .../tests/unit/_metadata_flags_test.py | 87 +- .../grpcio_tests/tests/unit/_metadata_test.py | 146 +- .../tests/unit/_reconnect_test.py | 25 +- .../tests/unit/_resource_exhausted_test.py | 68 +- .../tests/unit/_rpc_part_1_test.py | 140 +- .../tests/unit/_rpc_part_2_test.py | 257 +- .../tests/unit/_rpc_test_helpers.py | 360 +- .../tests/unit/_server_shutdown_scenarios.py | 23 +- .../tests/unit/_server_shutdown_test.py | 21 +- .../unit/_server_ssl_cert_config_test.py | 310 +- .../grpcio_tests/tests/unit/_server_test.py | 46 +- .../unit/_server_wait_for_termination_test.py | 41 +- .../tests/unit/_session_cache_test.py | 139 +- .../grpcio_tests/tests/unit/_signal_client.py | 44 +- .../tests/unit/_signal_handling_test.py | 84 +- .../grpcio_tests/tests/unit/_tcp_proxy.py | 20 +- .../grpcio_tests/tests/unit/_version_test.py | 3 +- .../tests/unit/_xds_credentials_test.py | 40 +- .../tests/unit/beta/_beta_features_test.py | 259 +- .../unit/beta/_connectivity_channel_test.py | 3 +- .../tests/unit/beta/_implementations_test.py | 30 +- .../tests/unit/beta/_not_found_test.py | 39 +- .../tests/unit/beta/_utilities_test.py | 15 +- .../tests/unit/beta/test_utilities.py | 41 +- .../tests/unit/framework/common/__init__.py | 47 +- .../unit/framework/common/test_control.py | 30 +- .../foundation/_logging_pool_test.py | 9 +- .../framework/foundation/stream_testing.py | 8 +- .../grpcio_tests/tests/unit/resources.py | 32 +- .../grpcio_tests/tests/unit/test_common.py | 100 +- .../tests_aio/_sanity/_sanity_test.py | 7 +- .../tests_aio/benchmark/benchmark_client.py | 99 +- .../tests_aio/benchmark/benchmark_servicer.py | 12 +- .../tests_aio/benchmark/server.py | 9 +- .../tests_aio/benchmark/worker.py | 27 +- .../tests_aio/benchmark/worker_servicer.py | 182 +- .../channelz/channelz_servicer_test.py | 244 +- .../health_check/health_servicer_test.py | 166 +- .../grpcio_tests/tests_aio/interop/client.py | 17 +- .../tests_aio/interop/local_interop_test.py | 92 +- .../grpcio_tests/tests_aio/interop/methods.py | 284 +- .../grpcio_tests/tests_aio/interop/server.py | 9 +- .../reflection/reflection_servicer_test.py | 135 +- .../tests_aio/status/grpc_status_test.py | 60 +- .../grpcio_tests/tests_aio/unit/_common.py | 23 +- .../grpcio_tests/tests_aio/unit/_constants.py | 2 +- .../tests_aio/unit/_metadata_test.py | 19 +- .../grpcio_tests/tests_aio/unit/_test_base.py | 7 +- .../tests_aio/unit/_test_server.py | 107 +- .../grpcio_tests/tests_aio/unit/abort_test.py | 31 +- .../tests_aio/unit/aio_rpc_error_test.py | 44 +- .../tests_aio/unit/auth_context_test.py | 207 +- .../grpcio_tests/tests_aio/unit/call_test.py | 186 +- .../tests_aio/unit/channel_argument_test.py | 101 +- .../tests_aio/unit/channel_ready_test.py | 17 +- .../tests_aio/unit/channel_test.py | 99 +- .../client_stream_stream_interceptor_test.py | 111 +- .../client_stream_unary_interceptor_test.py | 286 +- .../client_unary_stream_interceptor_test.py | 233 +- .../client_unary_unary_interceptor_test.py | 461 +- .../tests_aio/unit/close_channel_test.py | 5 +- .../tests_aio/unit/compatibility_test.py | 191 +- .../tests_aio/unit/compression_test.py | 69 +- .../tests_aio/unit/connectivity_test.py | 60 +- .../tests_aio/unit/context_peer_test.py | 22 +- .../tests_aio/unit/done_callback_test.py | 64 +- .../grpcio_tests/tests_aio/unit/init_test.py | 9 +- .../tests_aio/unit/metadata_test.py | 214 +- .../tests_aio/unit/outside_init_test.py | 5 +- .../tests_aio/unit/secure_call_test.py | 62 +- 
.../tests_aio/unit/server_interceptor_test.py | 269 +- .../tests_aio/unit/server_test.py | 250 +- .../unit/server_time_remaining_test.py | 19 +- .../tests_aio/unit/timeout_test.py | 50 +- .../unit/wait_for_connection_test.py | 30 +- .../tests_aio/unit/wait_for_ready_test.py | 49 +- .../tests_gevent/unit/_test_server.py | 26 +- .../tests_gevent/unit/close_channel_test.py | 5 +- .../interop/xds_interop_client.py | 234 +- .../interop/xds_interop_client_test.py | 89 +- .../interop/xds_interop_server.py | 94 +- .../tests_py3_only/unit/_leak_test.py | 24 +- .../tests_py3_only/unit/_simple_stubs_test.py | 168 +- src/re2/gen_build_yaml.py | 44 +- src/upb/gen_build_yaml.py | 336 +- src/zlib/gen_build_yaml.py | 45 +- ..._client_examples_of_bad_closing_streams.py | 20 +- test/core/http/test_server.py | 50 +- test/cpp/naming/gen_build_yaml.py | 87 +- .../manual_run_resolver_component_test.py | 38 +- test/cpp/naming/utils/dns_resolver.py | 52 +- test/cpp/naming/utils/dns_server.py | 127 +- .../run_dns_server_for_lb_interop_tests.py | 133 +- test/cpp/naming/utils/tcp_connect.py | 46 +- .../qps/json_run_localhost_scenario_gen.py | 5 +- test/cpp/qps/qps_json_driver_scenario_gen.py | 5 +- test/cpp/qps/scenario_generator_helper.py | 49 +- test/distrib/bazel/python/helloworld.py | 29 +- test/distrib/bazel/python/helloworld_moved.py | 30 +- .../upper/example/import_no_strip_test.py | 14 +- .../upper/example/import_strip_test.py | 4 +- .../upper/example/no_import_no_strip_test.py | 14 +- .../upper/example/no_import_strip_test.py | 4 +- test/distrib/gcf/python/main.py | 4 +- test/distrib/python/distribtest.py | 4 +- test/http2_test/http2_base_server.py | 196 +- test/http2_test/http2_server_health_check.py | 12 +- test/http2_test/http2_test_server.py | 67 +- test/http2_test/test_data_frame_padding.py | 52 +- test/http2_test/test_goaway.py | 28 +- test/http2_test/test_max_streams.py | 19 +- test/http2_test/test_ping.py | 22 +- test/http2_test/test_rst_after_data.py | 9 +- test/http2_test/test_rst_after_header.py | 5 +- test/http2_test/test_rst_during_data.py | 16 +- tools/buildgen/_mako_renderer.py | 81 +- tools/buildgen/_utils.py | 9 +- tools/buildgen/build_cleaner.py | 56 +- .../extract_metadata_from_bazel_xml.py | 981 ++-- tools/buildgen/generate_projects.py | 109 +- tools/buildgen/plugins/check_attrs.py | 173 +- tools/buildgen/plugins/expand_bin_attrs.py | 36 +- tools/buildgen/plugins/expand_version.py | 110 +- tools/buildgen/plugins/list_api.py | 42 +- tools/buildgen/plugins/list_protos.py | 22 +- .../plugins/transitive_dependencies.py | 17 +- .../plugins/verify_duplicate_sources.py | 18 +- tools/codegen/core/experiments_compiler.py | 283 +- tools/codegen/core/gen_config_vars.py | 328 +- tools/codegen/core/gen_experiments.py | 71 +- .../core/gen_grpc_tls_credentials_options.py | 441 +- tools/codegen/core/gen_header_frame.py | 107 +- tools/codegen/core/gen_if_list.py | 39 +- ..._registered_method_bad_client_test_body.py | 45 +- tools/codegen/core/gen_settings_ids.py | 229 +- tools/codegen/core/gen_stats_data.py | 406 +- tools/codegen/core/gen_switch.py | 43 +- .../core/gen_upb_api_from_bazel_xml.py | 150 +- .../codegen/core/optimize_arena_pool_sizes.py | 24 +- tools/debug/core/chttp2_ref_leak.py | 9 +- tools/debug/core/error_ref_leak.py | 10 +- tools/distrib/add-iwyu.py | 67 +- tools/distrib/{yapf_code.sh => black_code.sh} | 10 +- tools/distrib/c-ish/check_documentation.py | 26 +- tools/distrib/check_copyright.py | 340 +- tools/distrib/check_include_guards.py | 179 +- tools/distrib/check_naked_includes.py | 
38 +- .../distrib/check_namespace_qualification.py | 52 +- .../check_redundant_namespace_qualifiers.py | 77 +- tools/distrib/fix_build_deps.py | 729 ++- tools/distrib/gen_compilation_database.py | 49 +- tools/distrib/isort_code.sh | 20 +- tools/distrib/pylint_code.sh | 5 +- tools/distrib/python/check_grpcio_tools.py | 7 +- tools/distrib/python/docgen.py | 118 +- .../distrib/python/grpc_prefixed/generate.py | 121 +- .../grpcio_tools/_parallel_compile_patch.py | 31 +- .../python/grpcio_tools/grpc_tools/command.py | 42 +- .../python/grpcio_tools/grpc_tools/protoc.py | 76 +- .../grpc_tools/test/protoc_test.py | 41 +- tools/distrib/python/grpcio_tools/setup.py | 196 +- tools/distrib/python/make_grpcio_tools.py | 107 +- tools/distrib/python/xds_protos/build.py | 130 +- tools/distrib/python/xds_protos/setup.py | 39 +- tools/distrib/run_buildozer.py | 6 +- tools/distrib/run_clang_tidy.py | 40 +- tools/distrib/sanitize.sh | 2 +- tools/distrib/update_flakes.py | 48 +- tools/gcp/utils/big_query_utils.py | 221 +- tools/interop_matrix/client_matrix.py | 1434 +++--- tools/interop_matrix/create_matrix_images.py | 382 +- .../run_interop_matrix_tests.py | 288 +- tools/mkowners/mkowners.py | 103 +- tools/profiling/bloat/bloat_diff.py | 136 +- tools/profiling/ios_bin/binary_size.py | 127 +- tools/profiling/ios_bin/parse_link_map.py | 55 +- tools/profiling/memory/memory_diff.py | 120 +- tools/profiling/microbenchmarks/bm2bq.py | 27 +- .../microbenchmarks/bm_diff/bm_build.py | 67 +- .../microbenchmarks/bm_diff/bm_constants.py | 41 +- .../microbenchmarks/bm_diff/bm_diff.py | 174 +- .../microbenchmarks/bm_diff/bm_main.py | 158 +- .../microbenchmarks/bm_diff/bm_run.py | 138 +- tools/profiling/microbenchmarks/bm_json.py | 267 +- tools/profiling/qps/qps_diff.py | 112 +- tools/profiling/qps/qps_scenarios.py | 30 +- tools/release/release_notes.py | 150 +- tools/release/verify_python_release.py | 27 +- tools/run_tests/artifacts/artifact_targets.py | 509 +- .../artifacts/distribtest_targets.py | 509 +- tools/run_tests/artifacts/package_targets.py | 144 +- .../lb_interop_tests/gen_build_yaml.py | 380 +- .../run_tests/performance/bq_upload_result.py | 425 +- .../performance/loadtest_concat_yaml.py | 33 +- .../run_tests/performance/loadtest_config.py | 551 ++- .../performance/loadtest_template.py | 261 +- .../patch_scenario_results_schema.py | 36 +- tools/run_tests/performance/prometheus.py | 170 +- .../run_tests/performance/scenario_config.py | 2004 ++++---- .../performance/scenario_config_exporter.py | 204 +- .../python_utils/bazel_report_helper.py | 269 +- tools/run_tests/python_utils/check_on_pr.py | 165 +- tools/run_tests/python_utils/dockerjob.py | 102 +- .../python_utils/download_and_unzip.py | 8 +- .../python_utils/filter_pull_request_tests.py | 223 +- tools/run_tests/python_utils/jobset.py | 457 +- tools/run_tests/python_utils/port_server.py | 146 +- tools/run_tests/python_utils/report_utils.py | 181 +- .../python_utils/start_port_server.py | 107 +- .../python_utils/upload_rbe_results.py | 347 +- .../python_utils/upload_test_results.py | 192 +- tools/run_tests/python_utils/watch_dirs.py | 7 +- tools/run_tests/run_grpclb_interop_tests.py | 634 +-- tools/run_tests/run_interop_tests.py | 1436 +++--- tools/run_tests/run_microbenchmark.py | 112 +- tools/run_tests/run_performance_tests.py | 801 ++-- tools/run_tests/run_tests.py | 1648 ++++--- tools/run_tests/run_tests_matrix.py | 785 ++-- tools/run_tests/run_xds_tests.py | 4185 ++++++++++------- .../sanity/check_banned_filenames.py | 4 +- 
.../run_tests/sanity/check_bazel_workspace.py | 162 +- .../sanity/check_deprecated_grpc++.py | 84 +- tools/run_tests/sanity/check_include_style.py | 35 +- tools/run_tests/sanity/check_package_name.py | 55 +- tools/run_tests/sanity/check_port_platform.py | 64 +- .../sanity/check_qps_scenario_changes.py | 22 +- .../run_tests/sanity/check_test_filtering.py | 217 +- tools/run_tests/sanity/check_tracer_sanity.py | 17 +- tools/run_tests/sanity/check_version.py | 55 +- .../run_tests/sanity/core_banned_functions.py | 91 +- tools/run_tests/sanity/sanity_tests.yaml | 2 +- tools/run_tests/task_runner.py | 108 +- tools/run_tests/xds_k8s_test_driver/README.md | 2 +- .../bin/{yapf.sh => black.sh} | 14 +- .../bin/cleanup/cleanup.py | 419 +- .../bin/cleanup/namespace.py | 9 +- .../xds_k8s_test_driver/bin/isort.sh | 3 +- .../xds_k8s_test_driver/bin/lib/common.py | 93 +- .../xds_k8s_test_driver/bin/run_channelz.py | 134 +- .../xds_k8s_test_driver/bin/run_ping_pong.py | 83 +- .../xds_k8s_test_driver/bin/run_td_setup.py | 294 +- .../bin/run_test_client.py | 84 +- .../bin/run_test_server.py | 68 +- .../framework/bootstrap_generator_testcase.py | 90 +- .../framework/helpers/datetime.py | 6 +- .../framework/helpers/grpc.py | 13 +- .../framework/helpers/highlighter.py | 47 +- .../framework/helpers/logs.py | 6 +- .../framework/helpers/rand.py | 12 +- .../framework/helpers/retryers.py | 192 +- .../framework/helpers/skips.py | 27 +- .../framework/infrastructure/gcp/api.py | 312 +- .../framework/infrastructure/gcp/compute.py | 566 ++- .../framework/infrastructure/gcp/iam.py | 206 +- .../infrastructure/gcp/network_security.py | 118 +- .../infrastructure/gcp/network_services.py | 155 +- .../framework/infrastructure/k8s.py | 422 +- .../k8s_internal/k8s_log_collector.py | 68 +- .../k8s_internal/k8s_port_forwarder.py | 70 +- .../infrastructure/traffic_director.py | 609 ++- .../xds_k8s_test_driver/framework/rpc/grpc.py | 49 +- .../framework/rpc/grpc_channelz.py | 117 +- .../framework/rpc/grpc_csds.py | 32 +- .../framework/rpc/grpc_testing.py | 120 +- .../framework/test_app/client_app.py | 187 +- .../framework/test_app/runners/base_runner.py | 32 +- .../test_app/runners/k8s/k8s_base_runner.py | 456 +- .../runners/k8s/k8s_xds_client_runner.py | 138 +- .../runners/k8s/k8s_xds_server_runner.py | 190 +- .../framework/test_app/server_app.py | 95 +- .../framework/xds_flags.py | 139 +- .../framework/xds_k8s_flags.py | 56 +- .../framework/xds_k8s_testcase.py | 755 +-- .../framework/xds_url_map_test_resources.py | 155 +- .../framework/xds_url_map_testcase.py | 339 +- .../xds_k8s_test_driver/requirements-dev.txt | 4 +- .../tests/affinity_test.py | 117 +- .../tests/api_listener_test.py | 75 +- .../xds_k8s_test_driver/tests/app_net_test.py | 31 +- .../xds_k8s_test_driver/tests/authz_test.py | 290 +- .../tests/baseline_test.py | 23 +- .../tests/bootstrap_generator_test.py | 93 +- .../tests/change_backend_service_test.py | 64 +- .../tests/custom_lb_test.py | 73 +- .../tests/failover_test.py | 81 +- .../tests/outlier_detection_test.py | 65 +- .../tests/remove_neg_test.py | 54 +- .../tests/round_robin_test.py | 48 +- .../tests/security_test.py | 113 +- .../tests/subsetting_test.py | 71 +- .../tests/url_map/__main__.py | 9 +- .../tests/url_map/affinity_test.py | 140 +- .../tests/url_map/csds_test.py | 23 +- .../tests/url_map/fault_injection_test.py | 395 +- .../tests/url_map/header_matching_test.py | 491 +- .../tests/url_map/metadata_filter_test.py | 402 +- .../tests/url_map/path_matching_test.py | 217 +- .../tests/url_map/retry_test.py | 
148 +- .../tests/url_map/timeout_test.py | 119 +- 573 files changed, 42016 insertions(+), 29841 deletions(-) create mode 100644 black.toml rename tools/distrib/{yapf_code.sh => black_code.sh} (75%) rename tools/run_tests/xds_k8s_test_driver/bin/{yapf.sh => black.sh} (82%) diff --git a/.gitignore b/.gitignore index b81937ccfb448..d9a8d08ad412e 100644 --- a/.gitignore +++ b/.gitignore @@ -23,7 +23,7 @@ src/python/grpcio_*/=* src/python/grpcio_*/build/ src/python/grpcio_*/LICENSE src/python/grpcio_status/grpc_status/google/rpc/status.proto -yapf_virtual_environment/ +black_virtual_environment/ isort_virtual_environment/ # Node installation output diff --git a/black.toml b/black.toml new file mode 100644 index 0000000000000..55087061282b1 --- /dev/null +++ b/black.toml @@ -0,0 +1,51 @@ +[tool.black] +line-length = 80 +target-version = [ + "py37", + "py38", + "py39", + "py310", + "py311", +] +extend-exclude = ''' +# A regex preceded with ^/ will apply only to files and directories +# in the root of the project. +( + site-packages + | test/cpp/naming/resolver_component_tests_runner.py # AUTO-GENERATED + # AUTO-GENERATED from a template: + | grpc_version.py + | src/python/grpcio/grpc_core_dependencies.py + | src/python/grpcio/grpc/_grpcio_metadata.py + # AUTO-GENERATED BY make_grpcio_tools.py + | tools/distrib/python/grpcio_tools/protoc_lib_deps.py + | .*_pb2.py # autogenerated Protocol Buffer files + | .*_pb2_grpc.py # autogenerated Protocol Buffer gRPC files +) +''' + +[tool.isort] +profile = "black" +line_length = 80 +src_paths = [ + "examples/python/data_transmission", + "examples/python/async_streaming", + "tools/run_tests/xds_k8s_test_driver", + "src/python/grpcio_tests", + "tools/run_tests", +] +known_first_party = [ + "examples", + "src", +] +known_third_party = ["grpc"] +skip_glob = [ + "third_party/*", + "*/env/*", + "*pb2*.py", + "*pb2*.pyi", + "**/site-packages/**/*", +] +single_line_exclusions = ["typing"] +force_single_line = true +force_sort_within_sections = true diff --git a/examples/python/async_streaming/client.py b/examples/python/async_streaming/client.py index cb6eaf4876714..e1346c28feaef 100644 --- a/examples/python/async_streaming/client.py +++ b/examples/python/async_streaming/client.py @@ -24,9 +24,12 @@ class CallMaker: - - def __init__(self, executor: ThreadPoolExecutor, channel: grpc.Channel, - phone_number: str) -> None: + def __init__( + self, + executor: ThreadPoolExecutor, + channel: grpc.Channel, + phone_number: str, + ) -> None: self._executor = executor self._channel = channel self._stub = phone_pb2_grpc.PhoneStub(self._channel) @@ -39,8 +42,8 @@ def __init__(self, executor: ThreadPoolExecutor, channel: grpc.Channel, self._consumer_future = None def _response_watcher( - self, - response_iterator: Iterator[phone_pb2.StreamCallResponse]) -> None: + self, response_iterator: Iterator[phone_pb2.StreamCallResponse] + ) -> None: try: for response in response_iterator: # NOTE: All fields in Proto3 are optional. 
This is the recommended way @@ -52,7 +55,8 @@ def _response_watcher( self._on_call_state(response.call_state.state) else: raise RuntimeError( - "Received StreamCallResponse without call_info and call_state" + "Received StreamCallResponse without call_info and" + " call_state" ) except Exception as e: self._peer_responded.set() @@ -63,8 +67,11 @@ def _on_call_info(self, call_info: phone_pb2.CallInfo) -> None: self._audio_session_link = call_info.media def _on_call_state(self, call_state: phone_pb2.CallState.State) -> None: - logging.info("Call toward [%s] enters [%s] state", self._phone_number, - phone_pb2.CallState.State.Name(call_state)) + logging.info( + "Call toward [%s] enters [%s] state", + self._phone_number, + phone_pb2.CallState.State.Name(call_state), + ) self._call_state = call_state if call_state == phone_pb2.CallState.State.ACTIVE: self._peer_responded.set() @@ -77,8 +84,9 @@ def call(self) -> None: request.phone_number = self._phone_number response_iterator = self._stub.StreamCall(iter((request,))) # Instead of consuming the response on current thread, spawn a consumption thread. - self._consumer_future = self._executor.submit(self._response_watcher, - response_iterator) + self._consumer_future = self._executor.submit( + self._response_watcher, response_iterator + ) def wait_peer(self) -> bool: logging.info("Waiting for peer to connect [%s]...", self._phone_number) @@ -95,8 +103,9 @@ def audio_session(self) -> None: logging.info("Audio session finished [%s]", self._audio_session_link) -def process_call(executor: ThreadPoolExecutor, channel: grpc.Channel, - phone_number: str) -> None: +def process_call( + executor: ThreadPoolExecutor, channel: grpc.Channel, phone_number: str +) -> None: call_maker = CallMaker(executor, channel, phone_number) call_maker.call() if call_maker.wait_peer(): @@ -109,11 +118,12 @@ def process_call(executor: ThreadPoolExecutor, channel: grpc.Channel, def run(): executor = ThreadPoolExecutor() with grpc.insecure_channel("localhost:50051") as channel: - future = executor.submit(process_call, executor, channel, - "555-0100-XXXX") + future = executor.submit( + process_call, executor, channel, "555-0100-XXXX" + ) future.result() -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.INFO) run() diff --git a/examples/python/async_streaming/server.py b/examples/python/async_streaming/server.py index ae2a10dcd9441..75b36430ee667 100644 --- a/examples/python/async_streaming/server.py +++ b/examples/python/async_streaming/server.py @@ -26,14 +26,14 @@ def create_state_response( - call_state: phone_pb2.CallState.State) -> phone_pb2.StreamCallResponse: + call_state: phone_pb2.CallState.State, +) -> phone_pb2.StreamCallResponse: response = phone_pb2.StreamCallResponse() response.call_state.state = call_state return response class Phone(phone_pb2_grpc.PhoneServicer): - def __init__(self): self._id_counter = 0 self._lock = threading.RLock() @@ -51,13 +51,16 @@ def _clean_call_session(self, call_info: phone_pb2.CallInfo) -> None: logging.info("Call session cleaned [%s]", MessageToJson(call_info)) def StreamCall( - self, request_iterator: Iterable[phone_pb2.StreamCallRequest], - context: grpc.ServicerContext + self, + request_iterator: Iterable[phone_pb2.StreamCallRequest], + context: grpc.ServicerContext, ) -> Iterable[phone_pb2.StreamCallResponse]: try: request = next(request_iterator) - logging.info("Received a phone call request for number [%s]", - request.phone_number) + logging.info( + "Received a phone call request for number 
[%s]", + request.phone_number, + ) except StopIteration: raise RuntimeError("Failed to receive call request") # Simulate the acceptance of call request diff --git a/examples/python/auth/_credentials.py b/examples/python/auth/_credentials.py index ebd0ef6d74889..8a8d3fd56333e 100644 --- a/examples/python/auth/_credentials.py +++ b/examples/python/auth/_credentials.py @@ -18,10 +18,10 @@ def _load_credential_from_file(filepath): real_path = os.path.join(os.path.dirname(__file__), filepath) - with open(real_path, 'rb') as f: + with open(real_path, "rb") as f: return f.read() -SERVER_CERTIFICATE = _load_credential_from_file('credentials/localhost.crt') -SERVER_CERTIFICATE_KEY = _load_credential_from_file('credentials/localhost.key') -ROOT_CERTIFICATE = _load_credential_from_file('credentials/root.crt') +SERVER_CERTIFICATE = _load_credential_from_file("credentials/localhost.crt") +SERVER_CERTIFICATE_KEY = _load_credential_from_file("credentials/localhost.key") +ROOT_CERTIFICATE = _load_credential_from_file("credentials/root.crt") diff --git a/examples/python/auth/async_customized_auth_client.py b/examples/python/auth/async_customized_auth_client.py index d0c6f8fa9f6cc..39191f08d265c 100644 --- a/examples/python/auth/async_customized_auth_client.py +++ b/examples/python/auth/async_customized_auth_client.py @@ -21,19 +21,22 @@ import grpc helloworld_pb2, helloworld_pb2_grpc = grpc.protos_and_services( - "helloworld.proto") + "helloworld.proto" +) _LOGGER = logging.getLogger(__name__) _LOGGER.setLevel(logging.INFO) -_SERVER_ADDR_TEMPLATE = 'localhost:%d' -_SIGNATURE_HEADER_KEY = 'x-signature' +_SERVER_ADDR_TEMPLATE = "localhost:%d" +_SIGNATURE_HEADER_KEY = "x-signature" class AuthGateway(grpc.AuthMetadataPlugin): - - def __call__(self, context: grpc.AuthMetadataContext, - callback: grpc.AuthMetadataPluginCallback) -> None: + def __call__( + self, + context: grpc.AuthMetadataContext, + callback: grpc.AuthMetadataPluginCallback, + ) -> None: """Implements authentication by passing metadata to a callback. Implementations of this method must not block. 
@@ -54,11 +57,13 @@ def __call__(self, context: grpc.AuthMetadataContext, def create_client_channel(addr: str) -> grpc.aio.Channel: # Call credential object will be invoked for every single RPC - call_credentials = grpc.metadata_call_credentials(AuthGateway(), - name='auth gateway') + call_credentials = grpc.metadata_call_credentials( + AuthGateway(), name="auth gateway" + ) # Channel credential will be valid for the entire channel channel_credential = grpc.ssl_channel_credentials( - _credentials.ROOT_CERTIFICATE) + _credentials.ROOT_CERTIFICATE + ) # Combining channel credentials and call credentials together composite_credentials = grpc.composite_channel_credentials( channel_credential, @@ -70,24 +75,26 @@ def create_client_channel(addr: str) -> grpc.aio.Channel: async def send_rpc(channel: grpc.aio.Channel) -> helloworld_pb2.HelloReply: stub = helloworld_pb2_grpc.GreeterStub(channel) - request = helloworld_pb2.HelloRequest(name='you') + request = helloworld_pb2.HelloRequest(name="you") try: response = await stub.SayHello(request) except grpc.RpcError as rpc_error: - _LOGGER.error('Received error: %s', rpc_error) + _LOGGER.error("Received error: %s", rpc_error) return rpc_error else: - _LOGGER.info('Received message: %s', response) + _LOGGER.info("Received message: %s", response) return response async def main() -> None: parser = argparse.ArgumentParser() - parser.add_argument('--port', - nargs='?', - type=int, - default=50051, - help='the address of server') + parser.add_argument( + "--port", + nargs="?", + type=int, + default=50051, + help="the address of server", + ) args = parser.parse_args() channel = create_client_channel(_SERVER_ADDR_TEMPLATE % args.port) @@ -95,6 +102,6 @@ async def main() -> None: await channel.close() -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.INFO) asyncio.run(main()) diff --git a/examples/python/auth/async_customized_auth_server.py b/examples/python/auth/async_customized_auth_server.py index e1571a517b9bf..71dbe54484139 100644 --- a/examples/python/auth/async_customized_auth_server.py +++ b/examples/python/auth/async_customized_auth_server.py @@ -22,34 +22,35 @@ import grpc helloworld_pb2, helloworld_pb2_grpc = grpc.protos_and_services( - "helloworld.proto") + "helloworld.proto" +) _LOGGER = logging.getLogger(__name__) _LOGGER.setLevel(logging.INFO) -_LISTEN_ADDRESS_TEMPLATE = 'localhost:%d' -_SIGNATURE_HEADER_KEY = 'x-signature' +_LISTEN_ADDRESS_TEMPLATE = "localhost:%d" +_SIGNATURE_HEADER_KEY = "x-signature" class SignatureValidationInterceptor(grpc.aio.ServerInterceptor): - def __init__(self): - def abort(ignored_request, context: grpc.aio.ServicerContext) -> None: - context.abort(grpc.StatusCode.UNAUTHENTICATED, 'Invalid signature') + context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid signature") self._abort_handler = grpc.unary_unary_rpc_method_handler(abort) async def intercept_service( - self, continuation: Callable[[grpc.HandlerCallDetails], - Awaitable[grpc.RpcMethodHandler]], - handler_call_details: grpc.HandlerCallDetails + self, + continuation: Callable[ + [grpc.HandlerCallDetails], Awaitable[grpc.RpcMethodHandler] + ], + handler_call_details: grpc.HandlerCallDetails, ) -> grpc.RpcMethodHandler: # Example HandlerCallDetails object: # _HandlerCallDetails( # method=u'/helloworld.Greeter/SayHello', # invocation_metadata=...) 
- method_name = handler_call_details.method.split('/')[-1] + method_name = handler_call_details.method.split("/")[-1] expected_metadata = (_SIGNATURE_HEADER_KEY, method_name[::-1]) if expected_metadata in handler_call_details.invocation_metadata: return await continuation(handler_call_details) @@ -58,10 +59,10 @@ async def intercept_service( class SimpleGreeter(helloworld_pb2_grpc.GreeterServicer): - - async def SayHello(self, request: helloworld_pb2.HelloRequest, - unused_context) -> helloworld_pb2.HelloReply: - return helloworld_pb2.HelloReply(message='Hello, %s!' % request.name) + async def SayHello( + self, request: helloworld_pb2.HelloRequest, unused_context + ) -> helloworld_pb2.HelloReply: + return helloworld_pb2.HelloReply(message="Hello, %s!" % request.name) async def run_server(port: int) -> Tuple[grpc.aio.Server, int]: @@ -70,14 +71,19 @@ async def run_server(port: int) -> Tuple[grpc.aio.Server, int]: helloworld_pb2_grpc.add_GreeterServicer_to_server(SimpleGreeter(), server) # Loading credentials - server_credentials = grpc.ssl_server_credentials((( - _credentials.SERVER_CERTIFICATE_KEY, - _credentials.SERVER_CERTIFICATE, - ),)) + server_credentials = grpc.ssl_server_credentials( + ( + ( + _credentials.SERVER_CERTIFICATE_KEY, + _credentials.SERVER_CERTIFICATE, + ), + ) + ) # Pass down credentials - port = server.add_secure_port(_LISTEN_ADDRESS_TEMPLATE % port, - server_credentials) + port = server.add_secure_port( + _LISTEN_ADDRESS_TEMPLATE % port, server_credentials + ) await server.start() return server, port @@ -85,18 +91,16 @@ async def run_server(port: int) -> Tuple[grpc.aio.Server, int]: async def main() -> None: parser = argparse.ArgumentParser() - parser.add_argument('--port', - nargs='?', - type=int, - default=50051, - help='the listening port') + parser.add_argument( + "--port", nargs="?", type=int, default=50051, help="the listening port" + ) args = parser.parse_args() server, port = await run_server(args.port) - logging.info('Server is listening at port :%d', port) + logging.info("Server is listening at port :%d", port) await server.wait_for_termination() -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.INFO) asyncio.run(main()) diff --git a/examples/python/auth/customized_auth_client.py b/examples/python/auth/customized_auth_client.py index 2f2f59a123212..65ded079f7ece 100644 --- a/examples/python/auth/customized_auth_client.py +++ b/examples/python/auth/customized_auth_client.py @@ -21,17 +21,17 @@ import grpc helloworld_pb2, helloworld_pb2_grpc = grpc.protos_and_services( - "helloworld.proto") + "helloworld.proto" +) _LOGGER = logging.getLogger(__name__) _LOGGER.setLevel(logging.INFO) -_SERVER_ADDR_TEMPLATE = 'localhost:%d' -_SIGNATURE_HEADER_KEY = 'x-signature' +_SERVER_ADDR_TEMPLATE = "localhost:%d" +_SIGNATURE_HEADER_KEY = "x-signature" class AuthGateway(grpc.AuthMetadataPlugin): - def __call__(self, context, callback): """Implements authentication by passing metadata to a callback. 
@@ -55,11 +55,13 @@ def __call__(self, context, callback): @contextlib.contextmanager def create_client_channel(addr): # Call credential object will be invoked for every single RPC - call_credentials = grpc.metadata_call_credentials(AuthGateway(), - name='auth gateway') + call_credentials = grpc.metadata_call_credentials( + AuthGateway(), name="auth gateway" + ) # Channel credential will be valid for the entire channel channel_credential = grpc.ssl_channel_credentials( - _credentials.ROOT_CERTIFICATE) + _credentials.ROOT_CERTIFICATE + ) # Combining channel credentials and call credentials together composite_credentials = grpc.composite_channel_credentials( channel_credential, @@ -71,30 +73,32 @@ def create_client_channel(addr): def send_rpc(channel): stub = helloworld_pb2_grpc.GreeterStub(channel) - request = helloworld_pb2.HelloRequest(name='you') + request = helloworld_pb2.HelloRequest(name="you") try: response = stub.SayHello(request) except grpc.RpcError as rpc_error: - _LOGGER.error('Received error: %s', rpc_error) + _LOGGER.error("Received error: %s", rpc_error) return rpc_error else: - _LOGGER.info('Received message: %s', response) + _LOGGER.info("Received message: %s", response) return response def main(): parser = argparse.ArgumentParser() - parser.add_argument('--port', - nargs='?', - type=int, - default=50051, - help='the address of server') + parser.add_argument( + "--port", + nargs="?", + type=int, + default=50051, + help="the address of server", + ) args = parser.parse_args() with create_client_channel(_SERVER_ADDR_TEMPLATE % args.port) as channel: send_rpc(channel) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.INFO) main() diff --git a/examples/python/auth/customized_auth_server.py b/examples/python/auth/customized_auth_server.py index dafcbb0be8ec9..2357365677f2d 100644 --- a/examples/python/auth/customized_auth_server.py +++ b/examples/python/auth/customized_auth_server.py @@ -22,21 +22,20 @@ import grpc helloworld_pb2, helloworld_pb2_grpc = grpc.protos_and_services( - "helloworld.proto") + "helloworld.proto" +) _LOGGER = logging.getLogger(__name__) _LOGGER.setLevel(logging.INFO) -_LISTEN_ADDRESS_TEMPLATE = 'localhost:%d' -_SIGNATURE_HEADER_KEY = 'x-signature' +_LISTEN_ADDRESS_TEMPLATE = "localhost:%d" +_SIGNATURE_HEADER_KEY = "x-signature" class SignatureValidationInterceptor(grpc.ServerInterceptor): - def __init__(self): - def abort(ignored_request, context): - context.abort(grpc.StatusCode.UNAUTHENTICATED, 'Invalid signature') + context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid signature") self._abortion = grpc.unary_unary_rpc_method_handler(abort) @@ -45,7 +44,7 @@ def intercept_service(self, continuation, handler_call_details): # _HandlerCallDetails( # method=u'/helloworld.Greeter/SayHello', # invocation_metadata=...) - method_name = handler_call_details.method.split('/')[-1] + method_name = handler_call_details.method.split("/")[-1] expected_metadata = (_SIGNATURE_HEADER_KEY, method_name[::-1]) if expected_metadata in handler_call_details.invocation_metadata: return continuation(handler_call_details) @@ -54,27 +53,33 @@ def intercept_service(self, continuation, handler_call_details): class SimpleGreeter(helloworld_pb2_grpc.GreeterServicer): - def SayHello(self, request, unused_context): - return helloworld_pb2.HelloReply(message='Hello, %s!' % request.name) + return helloworld_pb2.HelloReply(message="Hello, %s!" 
% request.name) @contextlib.contextmanager def run_server(port): # Bind interceptor to server - server = grpc.server(futures.ThreadPoolExecutor(), - interceptors=(SignatureValidationInterceptor(),)) + server = grpc.server( + futures.ThreadPoolExecutor(), + interceptors=(SignatureValidationInterceptor(),), + ) helloworld_pb2_grpc.add_GreeterServicer_to_server(SimpleGreeter(), server) # Loading credentials - server_credentials = grpc.ssl_server_credentials((( - _credentials.SERVER_CERTIFICATE_KEY, - _credentials.SERVER_CERTIFICATE, - ),)) + server_credentials = grpc.ssl_server_credentials( + ( + ( + _credentials.SERVER_CERTIFICATE_KEY, + _credentials.SERVER_CERTIFICATE, + ), + ) + ) # Pass down credentials - port = server.add_secure_port(_LISTEN_ADDRESS_TEMPLATE % port, - server_credentials) + port = server.add_secure_port( + _LISTEN_ADDRESS_TEMPLATE % port, server_credentials + ) server.start() try: @@ -85,18 +90,16 @@ def run_server(port): def main(): parser = argparse.ArgumentParser() - parser.add_argument('--port', - nargs='?', - type=int, - default=50051, - help='the listening port') + parser.add_argument( + "--port", nargs="?", type=int, default=50051, help="the listening port" + ) args = parser.parse_args() with run_server(args.port) as (server, port): - logging.info('Server is listening at port :%d', port) + logging.info("Server is listening at port :%d", port) server.wait_for_termination() -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.INFO) main() diff --git a/examples/python/auth/test/_auth_example_test.py b/examples/python/auth/test/_auth_example_test.py index 45f207de4060f..377fa003832bf 100644 --- a/examples/python/auth/test/_auth_example_test.py +++ b/examples/python/auth/test/_auth_example_test.py @@ -24,15 +24,15 @@ from examples.python.auth import customized_auth_client from examples.python.auth import customized_auth_server -_SERVER_ADDR_TEMPLATE = 'localhost:%d' +_SERVER_ADDR_TEMPLATE = "localhost:%d" class AuthExampleTest(unittest.TestCase): - def test_successful_call(self): with customized_auth_server.run_server(0) as (_, port): with customized_auth_client.create_client_channel( - _SERVER_ADDR_TEMPLATE % port) as channel: + _SERVER_ADDR_TEMPLATE % port + ) as channel: customized_auth_client.send_rpc(channel) # No unhandled exception raised, test passed! 
@@ -45,18 +45,20 @@ def test_no_channel_credential(self): def test_no_call_credential(self): with customized_auth_server.run_server(0) as (_, port): channel_credential = grpc.ssl_channel_credentials( - _credentials.ROOT_CERTIFICATE) - with grpc.secure_channel(_SERVER_ADDR_TEMPLATE % port, - channel_credential) as channel: + _credentials.ROOT_CERTIFICATE + ) + with grpc.secure_channel( + _SERVER_ADDR_TEMPLATE % port, channel_credential + ) as channel: resp = customized_auth_client.send_rpc(channel) self.assertEqual(resp.code(), grpc.StatusCode.UNAUTHENTICATED) def test_successful_call_asyncio(self): - async def test_body(): server, port = await async_customized_auth_server.run_server(0) channel = async_customized_auth_client.create_client_channel( - _SERVER_ADDR_TEMPLATE % port) + _SERVER_ADDR_TEMPLATE % port + ) await async_customized_auth_client.send_rpc(channel) await channel.close() await server.stop(0) @@ -65,5 +67,5 @@ async def test_body(): asyncio.get_event_loop().run_until_complete(test_body()) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/examples/python/cancellation/client.py b/examples/python/cancellation/client.py index c61b3c3a6dae0..ce9b94b55c17b 100644 --- a/examples/python/cancellation/client.py +++ b/examples/python/cancellation/client.py @@ -34,9 +34,12 @@ def run_unary_client(server_target, name, ideal_distance): with grpc.insecure_channel(server_target) as channel: stub = hash_name_pb2_grpc.HashFinderStub(channel) - future = stub.Find.future(hash_name_pb2.HashNameRequest( - desired_name=name, ideal_hamming_distance=ideal_distance), - wait_for_ready=True) + future = stub.Find.future( + hash_name_pb2.HashNameRequest( + desired_name=name, ideal_hamming_distance=ideal_distance + ), + wait_for_ready=True, + ) def cancel_request(unused_signum, unused_frame): future.cancel() @@ -47,15 +50,19 @@ def cancel_request(unused_signum, unused_frame): print(result) -def run_streaming_client(server_target, name, ideal_distance, - interesting_distance): +def run_streaming_client( + server_target, name, ideal_distance, interesting_distance +): with grpc.insecure_channel(server_target) as channel: stub = hash_name_pb2_grpc.HashFinderStub(channel) - result_generator = stub.FindRange(hash_name_pb2.HashNameRequest( - desired_name=name, - ideal_hamming_distance=ideal_distance, - interesting_hamming_distance=interesting_distance), - wait_for_ready=True) + result_generator = stub.FindRange( + hash_name_pb2.HashNameRequest( + desired_name=name, + ideal_hamming_distance=ideal_distance, + interesting_hamming_distance=interesting_distance, + ), + wait_for_ready=True, + ) def cancel_request(unused_signum, unused_frame): result_generator.cancel() @@ -68,29 +75,36 @@ def cancel_request(unused_signum, unused_frame): def main(): parser = argparse.ArgumentParser(description=_DESCRIPTION) - parser.add_argument("name", type=str, help='The desired name.') - parser.add_argument("--ideal-distance", - default=0, - nargs='?', - type=int, - help="The desired Hamming distance.") - parser.add_argument('--server', - default='localhost:50051', - type=str, - nargs='?', - help='The host-port pair at which to reach the server.') + parser.add_argument("name", type=str, help="The desired name.") parser.add_argument( - '--show-inferior', + "--ideal-distance", + default=0, + nargs="?", + type=int, + help="The desired Hamming distance.", + ) + parser.add_argument( + "--server", + default="localhost:50051", + type=str, + nargs="?", + help="The host-port pair at which to reach 
the server.", + ) + parser.add_argument( + "--show-inferior", default=None, type=int, - nargs='?', - help='Also show candidates with a Hamming distance less than this value.' + nargs="?", + help=( + "Also show candidates with a Hamming distance less than this value." + ), ) args = parser.parse_args() if args.show_inferior is not None: - run_streaming_client(args.server, args.name, args.ideal_distance, - args.show_inferior) + run_streaming_client( + args.server, args.name, args.ideal_distance, args.show_inferior + ) else: run_unary_client(args.server, args.name, args.ideal_distance) diff --git a/examples/python/cancellation/search.py b/examples/python/cancellation/search.py index 28ac3664fc885..731b40e25d2aa 100644 --- a/examples/python/cancellation/search.py +++ b/examples/python/cancellation/search.py @@ -53,8 +53,9 @@ def _get_substring_hamming_distance(candidate, target): if len(target) > len(candidate): raise ValueError("Candidate must be at least as long as target.") for i in range(len(candidate) - len(target) + 1): - distance = _get_hamming_distance(candidate[i:i + len(target)].lower(), - target.lower()) + distance = _get_hamming_distance( + candidate[i : i + len(target)].lower(), target.lower() + ) if min_distance is None or distance < min_distance: min_distance = distance return min_distance @@ -63,7 +64,7 @@ def _get_substring_hamming_distance(candidate, target): def _get_hash(secret): hasher = hashlib.sha1() hasher.update(secret) - return base64.b64encode(hasher.digest()).decode('ascii') + return base64.b64encode(hasher.digest()).decode("ascii") class ResourceLimitExceededError(Exception): @@ -80,7 +81,7 @@ def _bytestrings_of_length(length): All bytestrings of length `length`. """ for digits in itertools.product(range(_BYTE_MAX), repeat=length): - yield b''.join(struct.pack('B', i) for i in digits) + yield b"".join(struct.pack("B", i) for i in digits) def _all_bytestrings(): @@ -92,15 +93,18 @@ def _all_bytestrings(): All bytestrings in ascending order of length. """ for bytestring in itertools.chain.from_iterable( - _bytestrings_of_length(length) for length in itertools.count()): + _bytestrings_of_length(length) for length in itertools.count() + ): yield bytestring -def search(target, - ideal_distance, - stop_event, - maximum_hashes, - interesting_hamming_distance=None): +def search( + target, + ideal_distance, + stop_event, + maximum_hashes, + interesting_hamming_distance=None, +): """Find candidate strings. Search through the space of all bytestrings, in order of increasing length, @@ -130,18 +134,23 @@ def search(target, return candidate_hash = _get_hash(secret) distance = _get_substring_hamming_distance(candidate_hash, target) - if interesting_hamming_distance is not None and distance <= interesting_hamming_distance: + if ( + interesting_hamming_distance is not None + and distance <= interesting_hamming_distance + ): # Surface interesting candidates, but don't stop. yield hash_name_pb2.HashNameResponse( secret=base64.b64encode(secret), hashed_name=candidate_hash, - hamming_distance=distance) + hamming_distance=distance, + ) elif distance <= ideal_distance: # Yield ideal candidate and end the stream. 
yield hash_name_pb2.HashNameResponse( secret=base64.b64encode(secret), hashed_name=candidate_hash, - hamming_distance=distance) + hamming_distance=distance, + ) return hashes_computed += 1 if hashes_computed == maximum_hashes: diff --git a/examples/python/cancellation/server.py b/examples/python/cancellation/server.py index cbfd49605f31b..d2db0b24227d8 100644 --- a/examples/python/cancellation/server.py +++ b/examples/python/cancellation/server.py @@ -29,13 +29,12 @@ from examples.python.cancellation import hash_name_pb2_grpc _LOGGER = logging.getLogger(__name__) -_SERVER_HOST = 'localhost' +_SERVER_HOST = "localhost" _DESCRIPTION = "A server for finding hashes similar to names." class HashFinder(hash_name_pb2_grpc.HashFinderServicer): - def __init__(self, maximum_hashes): super(HashFinder, self).__init__() self._maximum_hashes = maximum_hashes @@ -51,9 +50,13 @@ def on_rpc_done(): candidates = [] try: candidates = list( - search.search(request.desired_name, - request.ideal_hamming_distance, stop_event, - self._maximum_hashes)) + search.search( + request.desired_name, + request.ideal_hamming_distance, + stop_event, + self._maximum_hashes, + ) + ) except search.ResourceLimitExceededError: _LOGGER.info("Cancelling RPC due to exhausted resources.") context.cancel() @@ -75,7 +78,8 @@ def on_rpc_done(): request.ideal_hamming_distance, stop_event, self._maximum_hashes, - interesting_hamming_distance=request.interesting_hamming_distance) + interesting_hamming_distance=request.interesting_hamming_distance, + ) try: for candidate in secret_generator: yield candidate @@ -89,11 +93,13 @@ def _running_server(port, maximum_hashes): # We use only a single servicer thread here to demonstrate that, if managed # carefully, cancelled RPCs can need not continue occupying servicers # threads. 
- server = grpc.server(futures.ThreadPoolExecutor(max_workers=1), - maximum_concurrent_rpcs=1) + server = grpc.server( + futures.ThreadPoolExecutor(max_workers=1), maximum_concurrent_rpcs=1 + ) hash_name_pb2_grpc.add_HashFinderServicer_to_server( - HashFinder(maximum_hashes), server) - address = '{}:{}'.format(_SERVER_HOST, port) + HashFinder(maximum_hashes), server + ) + address = "{}:{}".format(_SERVER_HOST, port) actual_port = server.add_insecure_port(address) server.start() print("Server listening at '{}'".format(address)) @@ -102,17 +108,20 @@ def _running_server(port, maximum_hashes): def main(): parser = argparse.ArgumentParser(description=_DESCRIPTION) - parser.add_argument('--port', - type=int, - default=50051, - nargs='?', - help='The port on which the server will listen.') parser.add_argument( - '--maximum-hashes', + "--port", + type=int, + default=50051, + nargs="?", + help="The port on which the server will listen.", + ) + parser.add_argument( + "--maximum-hashes", type=int, default=1000000, - nargs='?', - help='The maximum number of hashes to search before cancelling.') + nargs="?", + help="The maximum number of hashes to search before cancelling.", + ) args = parser.parse_args() server = _running_server(args.port, args.maximum_hashes) server.wait_for_termination() diff --git a/examples/python/cancellation/test/_cancellation_example_test.py b/examples/python/cancellation/test/_cancellation_example_test.py index 2b936d2147ab1..c8a9097961407 100644 --- a/examples/python/cancellation/test/_cancellation_example_test.py +++ b/examples/python/cancellation/test/_cancellation_example_test.py @@ -21,9 +21,10 @@ import unittest _BINARY_DIR = os.path.realpath( - os.path.join(os.path.dirname(os.path.abspath(__file__)), '..')) -_SERVER_PATH = os.path.join(_BINARY_DIR, 'server') -_CLIENT_PATH = os.path.join(_BINARY_DIR, 'client') + os.path.join(os.path.dirname(os.path.abspath(__file__)), "..") +) +_SERVER_PATH = os.path.join(_BINARY_DIR, "server") +_CLIENT_PATH = os.path.join(_BINARY_DIR, "client") @contextlib.contextmanager @@ -32,33 +33,42 @@ def _get_port(): sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 0: raise RuntimeError("Failed to set SO_REUSEPORT.") - sock.bind(('', 0)) + sock.bind(("", 0)) try: yield sock.getsockname()[1] finally: sock.close() -def _start_client(server_port, - desired_string, - ideal_distance, - interesting_distance=None): - interesting_distance_args = () if interesting_distance is None else ( - '--show-inferior', interesting_distance) - return subprocess.Popen((_CLIENT_PATH, desired_string, '--server', - 'localhost:{}'.format(server_port), - '--ideal-distance', str(ideal_distance)) + - interesting_distance_args) +def _start_client( + server_port, desired_string, ideal_distance, interesting_distance=None +): + interesting_distance_args = ( + () + if interesting_distance is None + else ("--show-inferior", interesting_distance) + ) + return subprocess.Popen( + ( + _CLIENT_PATH, + desired_string, + "--server", + "localhost:{}".format(server_port), + "--ideal-distance", + str(ideal_distance), + ) + + interesting_distance_args + ) class CancellationExampleTest(unittest.TestCase): - def test_successful_run(self): with _get_port() as test_port: server_process = subprocess.Popen( - (_SERVER_PATH, '--port', str(test_port))) + (_SERVER_PATH, "--port", str(test_port)) + ) try: - client_process = _start_client(test_port, 'aa', 0) + client_process = _start_client(test_port, "aa", 0) client_return_code 
= client_process.wait() self.assertEqual(0, client_return_code) self.assertIsNone(server_process.poll()) @@ -69,12 +79,13 @@ def test_successful_run(self): def test_graceful_sigint(self): with _get_port() as test_port: server_process = subprocess.Popen( - (_SERVER_PATH, '--port', str(test_port))) + (_SERVER_PATH, "--port", str(test_port)) + ) try: - client_process1 = _start_client(test_port, 'aaaaaaaaaa', 0) + client_process1 = _start_client(test_port, "aaaaaaaaaa", 0) client_process1.send_signal(signal.SIGINT) client_process1.wait() - client_process2 = _start_client(test_port, 'aa', 0) + client_process2 = _start_client(test_port, "aa", 0) client_return_code = client_process2.wait() self.assertEqual(0, client_return_code) self.assertIsNone(server_process.poll()) @@ -83,5 +94,5 @@ def test_graceful_sigint(self): server_process.wait() -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/examples/python/compression/client.py b/examples/python/compression/client.py index e1822f708d2d3..89ea855bdc0a2 100644 --- a/examples/python/compression/client.py +++ b/examples/python/compression/client.py @@ -25,7 +25,7 @@ from examples.protos import helloworld_pb2 from examples.protos import helloworld_pb2_grpc -_DESCRIPTION = 'A client capable of compression.' +_DESCRIPTION = "A client capable of compression." _COMPRESSION_OPTIONS = { "none": grpc.Compression.NoCompression, "deflate": grpc.Compression.Deflate, @@ -36,33 +36,41 @@ def run_client(channel_compression, call_compression, target): - with grpc.insecure_channel(target, - compression=channel_compression) as channel: + with grpc.insecure_channel( + target, compression=channel_compression + ) as channel: stub = helloworld_pb2_grpc.GreeterStub(channel) - response = stub.SayHello(helloworld_pb2.HelloRequest(name='you'), - compression=call_compression, - wait_for_ready=True) + response = stub.SayHello( + helloworld_pb2.HelloRequest(name="you"), + compression=call_compression, + wait_for_ready=True, + ) print("Response: {}".format(response)) def main(): parser = argparse.ArgumentParser(description=_DESCRIPTION) - parser.add_argument('--channel_compression', - default='none', - nargs='?', - choices=_COMPRESSION_OPTIONS.keys(), - help='The compression method to use for the channel.') parser.add_argument( - '--call_compression', - default='none', - nargs='?', + "--channel_compression", + default="none", + nargs="?", choices=_COMPRESSION_OPTIONS.keys(), - help='The compression method to use for an individual call.') - parser.add_argument('--server', - default='localhost:50051', - type=str, - nargs='?', - help='The host-port pair at which to reach the server.') + help="The compression method to use for the channel.", + ) + parser.add_argument( + "--call_compression", + default="none", + nargs="?", + choices=_COMPRESSION_OPTIONS.keys(), + help="The compression method to use for an individual call.", + ) + parser.add_argument( + "--server", + default="localhost:50051", + type=str, + nargs="?", + help="The host-port pair at which to reach the server.", + ) args = parser.parse_args() channel_compression = _COMPRESSION_OPTIONS[args.channel_compression] call_compression = _COMPRESSION_OPTIONS[args.call_compression] diff --git a/examples/python/compression/server.py b/examples/python/compression/server.py index d0a77bbe1ac86..d7da982e2a099 100644 --- a/examples/python/compression/server.py +++ b/examples/python/compression/server.py @@ -27,7 +27,7 @@ from examples.protos import helloworld_pb2 from examples.protos import 
helloworld_pb2_grpc -_DESCRIPTION = 'A server capable of compression.' +_DESCRIPTION = "A server capable of compression." _COMPRESSION_OPTIONS = { "none": grpc.Compression.NoCompression, "deflate": grpc.Compression.Deflate, @@ -35,11 +35,10 @@ } _LOGGER = logging.getLogger(__name__) -_SERVER_HOST = 'localhost' +_SERVER_HOST = "localhost" class Greeter(helloworld_pb2_grpc.GreeterServicer): - def __init__(self, no_compress_every_n): super(Greeter, self).__init__() self._no_compress_every_n = 0 @@ -49,7 +48,10 @@ def __init__(self, no_compress_every_n): def _should_suppress_compression(self): suppress_compression = False with self._counter_lock: - if self._no_compress_every_n and self._request_counter % self._no_compress_every_n == 0: + if ( + self._no_compress_every_n + and self._request_counter % self._no_compress_every_n == 0 + ): suppress_compression = True self._request_counter += 1 return suppress_compression @@ -57,16 +59,19 @@ def _should_suppress_compression(self): def SayHello(self, request, context): if self._should_suppress_compression(): context.set_response_compression(grpc.Compression.NoCompression) - return helloworld_pb2.HelloReply(message='Hello, %s!' % request.name) + return helloworld_pb2.HelloReply(message="Hello, %s!" % request.name) def run_server(server_compression, no_compress_every_n, port): - server = grpc.server(futures.ThreadPoolExecutor(), - compression=server_compression, - options=(('grpc.so_reuseport', 1),)) + server = grpc.server( + futures.ThreadPoolExecutor(), + compression=server_compression, + options=(("grpc.so_reuseport", 1),), + ) helloworld_pb2_grpc.add_GreeterServicer_to_server( - Greeter(no_compress_every_n), server) - address = '{}:{}'.format(_SERVER_HOST, port) + Greeter(no_compress_every_n), server + ) + address = "{}:{}".format(_SERVER_HOST, port) server.add_insecure_port(address) server.start() print("Server listening at '{}'".format(address)) @@ -75,24 +80,33 @@ def run_server(server_compression, no_compress_every_n, port): def main(): parser = argparse.ArgumentParser(description=_DESCRIPTION) - parser.add_argument('--server_compression', - default='none', - nargs='?', - choices=_COMPRESSION_OPTIONS.keys(), - help='The default compression method for the server.') - parser.add_argument('--no_compress_every_n', - type=int, - default=0, - nargs='?', - help='If set, every nth reply will be uncompressed.') - parser.add_argument('--port', - type=int, - default=50051, - nargs='?', - help='The port on which the server will listen.') + parser.add_argument( + "--server_compression", + default="none", + nargs="?", + choices=_COMPRESSION_OPTIONS.keys(), + help="The default compression method for the server.", + ) + parser.add_argument( + "--no_compress_every_n", + type=int, + default=0, + nargs="?", + help="If set, every nth reply will be uncompressed.", + ) + parser.add_argument( + "--port", + type=int, + default=50051, + nargs="?", + help="The port on which the server will listen.", + ) args = parser.parse_args() - run_server(_COMPRESSION_OPTIONS[args.server_compression], - args.no_compress_every_n, args.port) + run_server( + _COMPRESSION_OPTIONS[args.server_compression], + args.no_compress_every_n, + args.port, + ) if __name__ == "__main__": diff --git a/examples/python/compression/test/compression_example_test.py b/examples/python/compression/test/compression_example_test.py index 2a8f632a38a88..27af6ef929320 100644 --- a/examples/python/compression/test/compression_example_test.py +++ b/examples/python/compression/test/compression_example_test.py @@ 
-20,9 +20,10 @@ import unittest _BINARY_DIR = os.path.realpath( - os.path.join(os.path.dirname(os.path.abspath(__file__)), '..')) -_SERVER_PATH = os.path.join(_BINARY_DIR, 'server') -_CLIENT_PATH = os.path.join(_BINARY_DIR, 'client') + os.path.join(os.path.dirname(os.path.abspath(__file__)), "..") +) +_SERVER_PATH = os.path.join(_BINARY_DIR, "server") +_CLIENT_PATH = os.path.join(_BINARY_DIR, "client") @contextlib.contextmanager @@ -31,7 +32,7 @@ def _get_port(): sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 0: raise RuntimeError("Failed to set SO_REUSEPORT.") - sock.bind(('', 0)) + sock.bind(("", 0)) try: yield sock.getsockname()[1] finally: @@ -39,17 +40,28 @@ def _get_port(): class CompressionExampleTest(unittest.TestCase): - def test_compression_example(self): with _get_port() as test_port: server_process = subprocess.Popen( - (_SERVER_PATH, '--port', str(test_port), '--server_compression', - 'gzip')) + ( + _SERVER_PATH, + "--port", + str(test_port), + "--server_compression", + "gzip", + ) + ) try: - server_target = 'localhost:{}'.format(test_port) + server_target = "localhost:{}".format(test_port) client_process = subprocess.Popen( - (_CLIENT_PATH, '--server', server_target, - '--channel_compression', 'gzip')) + ( + _CLIENT_PATH, + "--server", + server_target, + "--channel_compression", + "gzip", + ) + ) client_return_code = client_process.wait() self.assertEqual(0, client_return_code) self.assertIsNone(server_process.poll()) @@ -58,5 +70,5 @@ def test_compression_example(self): server_process.wait() -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/examples/python/data_transmission/alts_client.py b/examples/python/data_transmission/alts_client.py index 33d79e24e0a35..8cfe6fdc1248d 100644 --- a/examples/python/data_transmission/alts_client.py +++ b/examples/python/data_transmission/alts_client.py @@ -28,8 +28,8 @@ def main(): with grpc.secure_channel( - SERVER_ADDRESS, - credentials=grpc.alts_channel_credentials()) as channel: + SERVER_ADDRESS, credentials=grpc.alts_channel_credentials() + ) as channel: stub = demo_pb2_grpc.GRPCDemoStub(channel) simple_method(stub) client_streaming_method(stub) @@ -37,5 +37,5 @@ def main(): bidirectional_streaming_method(stub) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/python/data_transmission/alts_server.py b/examples/python/data_transmission/alts_server.py index 4e8747ce7a99b..7d7b6e355e8eb 100644 --- a/examples/python/data_transmission/alts_server.py +++ b/examples/python/data_transmission/alts_server.py @@ -22,18 +22,19 @@ import demo_pb2_grpc from server import DemoServer -SERVER_ADDRESS = 'localhost:23333' +SERVER_ADDRESS = "localhost:23333" def main(): svr = grpc.server(futures.ThreadPoolExecutor()) demo_pb2_grpc.add_GRPCDemoServicer_to_server(DemoServer(), svr) - svr.add_secure_port(SERVER_ADDRESS, - server_credentials=grpc.alts_server_credentials()) + svr.add_secure_port( + SERVER_ADDRESS, server_credentials=grpc.alts_server_credentials() + ) print("------------------start Python GRPC server with ALTS encryption") svr.start() svr.wait_for_termination() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/python/data_transmission/client.py b/examples/python/data_transmission/client.py index bfea00fb8b541..1dbfba57aaf80 100644 --- a/examples/python/data_transmission/client.py +++ b/examples/python/data_transmission/client.py @@ -21,8 +21,10 @@ import 
demo_pb2_grpc __all__ = [ - 'simple_method', 'client_streaming_method', 'server_streaming_method', - 'bidirectional_streaming_method' + "simple_method", + "client_streaming_method", + "server_streaming_method", + "bidirectional_streaming_method", ] SERVER_ADDRESS = "localhost:23333" @@ -38,11 +40,14 @@ # only respond once.) def simple_method(stub): print("--------------Call SimpleMethod Begin--------------") - request = demo_pb2.Request(client_id=CLIENT_ID, - request_data="called by Python client") + request = demo_pb2.Request( + client_id=CLIENT_ID, request_data="called by Python client" + ) response = stub.SimpleMethod(request) - print("resp from server(%d), the message=%s" % - (response.server_id, response.response_data)) + print( + "resp from server(%d), the message=%s" + % (response.server_id, response.response_data) + ) print("--------------Call SimpleMethod Over---------------") @@ -58,12 +63,15 @@ def request_messages(): for i in range(5): request = demo_pb2.Request( client_id=CLIENT_ID, - request_data=("called by Python client, message:%d" % i)) + request_data="called by Python client, message:%d" % i, + ) yield request response = stub.ClientStreamingMethod(request_messages()) - print("resp from server(%d), the message=%s" % - (response.server_id, response.response_data)) + print( + "resp from server(%d), the message=%s" + % (response.server_id, response.response_data) + ) print("--------------Call ClientStreamingMethod Over---------------") @@ -72,12 +80,15 @@ def request_messages(): # but the server can return the response many times.) def server_streaming_method(stub): print("--------------Call ServerStreamingMethod Begin--------------") - request = demo_pb2.Request(client_id=CLIENT_ID, - request_data="called by Python client") + request = demo_pb2.Request( + client_id=CLIENT_ID, request_data="called by Python client" + ) response_iterator = stub.ServerStreamingMethod(request) for response in response_iterator: - print("recv from server(%d), message=%s" % - (response.server_id, response.response_data)) + print( + "recv from server(%d), message=%s" + % (response.server_id, response.response_data) + ) print("--------------Call ServerStreamingMethod Over---------------") @@ -87,7 +98,8 @@ def server_streaming_method(stub): # to each other multiple times.) 
def bidirectional_streaming_method(stub): print( - "--------------Call BidirectionalStreamingMethod Begin---------------") + "--------------Call BidirectionalStreamingMethod Begin---------------" + ) # 创建一个生成器 # create a generator @@ -95,14 +107,17 @@ def request_messages(): for i in range(5): request = demo_pb2.Request( client_id=CLIENT_ID, - request_data=("called by Python client, message: %d" % i)) + request_data="called by Python client, message: %d" % i, + ) yield request time.sleep(1) response_iterator = stub.BidirectionalStreamingMethod(request_messages()) for response in response_iterator: - print("recv from server(%d), message=%s" % - (response.server_id, response.response_data)) + print( + "recv from server(%d), message=%s" + % (response.server_id, response.response_data) + ) print("--------------Call BidirectionalStreamingMethod Over---------------") @@ -120,5 +135,5 @@ def main(): bidirectional_streaming_method(stub) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/python/data_transmission/server.py b/examples/python/data_transmission/server.py index ab655f27f4caa..2223f3e02487c 100644 --- a/examples/python/data_transmission/server.py +++ b/examples/python/data_transmission/server.py @@ -21,22 +21,24 @@ import demo_pb2 import demo_pb2_grpc -__all__ = 'DemoServer' -SERVER_ADDRESS = 'localhost:23333' +__all__ = "DemoServer" +SERVER_ADDRESS = "localhost:23333" SERVER_ID = 1 class DemoServer(demo_pb2_grpc.GRPCDemoServicer): - # 一元模式(在一次调用中, 客户端只能向服务器传输一次请求数据, 服务器也只能返回一次响应) # unary-unary(In a single call, the client can only send request once, and the server can # only respond once.) def SimpleMethod(self, request, context): - print("SimpleMethod called by client(%d) the message: %s" % - (request.client_id, request.request_data)) + print( + "SimpleMethod called by client(%d) the message: %s" + % (request.client_id, request.request_data) + ) response = demo_pb2.Response( server_id=SERVER_ID, - response_data="Python server SimpleMethod Ok!!!!") + response_data="Python server SimpleMethod Ok!!!!", + ) return response # 客户端流模式(在一次调用中, 客户端可以多次向服务器传输数据, 但是服务器只能返回一次响应) @@ -45,19 +47,24 @@ def SimpleMethod(self, request, context): def ClientStreamingMethod(self, request_iterator, context): print("ClientStreamingMethod called by client...") for request in request_iterator: - print("recv from client(%d), message= %s" % - (request.client_id, request.request_data)) + print( + "recv from client(%d), message= %s" + % (request.client_id, request.request_data) + ) response = demo_pb2.Response( server_id=SERVER_ID, - response_data="Python server ClientStreamingMethod ok") + response_data="Python server ClientStreamingMethod ok", + ) return response # 服务端流模式(在一次调用中, 客户端只能一次向服务器传输数据, 但是服务器可以多次返回响应) # unary-stream (In a single call, the client can only transmit data to the server at one time, # but the server can return the response many times.) 
def ServerStreamingMethod(self, request, context): - print("ServerStreamingMethod called by client(%d), message= %s" % - (request.client_id, request.request_data)) + print( + "ServerStreamingMethod called by client(%d), message= %s" + % (request.client_id, request.request_data) + ) # 创建一个生成器 # create a generator @@ -65,7 +72,8 @@ def response_messages(): for i in range(5): response = demo_pb2.Response( server_id=SERVER_ID, - response_data=("send by Python server, message=%d" % i)) + response_data="send by Python server, message=%d" % i, + ) yield response return response_messages() @@ -80,8 +88,10 @@ def BidirectionalStreamingMethod(self, request_iterator, context): # Open a sub thread to receive data def parse_request(): for request in request_iterator: - print("recv from client(%d), message= %s" % - (request.client_id, request.request_data)) + print( + "recv from client(%d), message= %s" + % (request.client_id, request.request_data) + ) t = Thread(target=parse_request) t.start() @@ -89,7 +99,8 @@ def parse_request(): for i in range(5): yield demo_pb2.Response( server_id=SERVER_ID, - response_data=("send by Python server, message= %d" % i)) + response_data="send by Python server, message= %d" % i, + ) t.join() @@ -112,5 +123,5 @@ def main(): # time.sleep(10) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/python/debug/asyncio_debug_server.py b/examples/python/debug/asyncio_debug_server.py index de59fe1012399..4d5a1a3c811c2 100644 --- a/examples/python/debug/asyncio_debug_server.py +++ b/examples/python/debug/asyncio_debug_server.py @@ -21,7 +21,8 @@ import grpc helloworld_pb2, helloworld_pb2_grpc = grpc.protos_and_services( - "helloworld.proto") + "helloworld.proto" +) # TODO: Suppress until the macOS segfault fix rolled out from grpc_channelz.v1 import channelz # pylint: disable=wrong-import-position @@ -33,23 +34,26 @@ class FaultInjectGreeter(helloworld_pb2_grpc.GreeterServicer): - def __init__(self, failure_rate): self._failure_rate = failure_rate async def SayHello( - self, request: helloworld_pb2.HelloRequest, - context: grpc.aio.ServicerContext) -> helloworld_pb2.HelloReply: + self, + request: helloworld_pb2.HelloRequest, + context: grpc.aio.ServicerContext, + ) -> helloworld_pb2.HelloReply: if random.random() < self._failure_rate: - context.abort(grpc.StatusCode.UNAVAILABLE, - 'Randomly injected failure.') - return helloworld_pb2.HelloReply(message=f'Hello, {request.name}!') + context.abort( + grpc.StatusCode.UNAVAILABLE, "Randomly injected failure." 
+ ) + return helloworld_pb2.HelloReply(message=f"Hello, {request.name}!") def create_server(addr: str, failure_rate: float) -> grpc.aio.Server: server = grpc.aio.server() helloworld_pb2_grpc.add_GreeterServicer_to_server( - FaultInjectGreeter(failure_rate), server) + FaultInjectGreeter(failure_rate), server + ) # Add Channelz Servicer to the gRPC server channelz.add_channelz_servicer(server) @@ -60,17 +64,20 @@ def create_server(addr: str, failure_rate: float) -> grpc.aio.Server: async def main() -> None: parser = argparse.ArgumentParser() - parser.add_argument('--addr', - nargs=1, - type=str, - default='[::]:50051', - help='the address to listen on') parser.add_argument( - '--failure_rate', + "--addr", + nargs=1, + type=str, + default="[::]:50051", + help="the address to listen on", + ) + parser.add_argument( + "--failure_rate", nargs=1, type=float, default=0.3, - help='a float indicates the percentage of failed message injections') + help="a float indicates the percentage of failed message injections", + ) args = parser.parse_args() server = create_server(addr=args.addr, failure_rate=args.failure_rate) @@ -78,6 +85,6 @@ async def main() -> None: await server.wait_for_termination() -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.INFO) asyncio.get_event_loop().run_until_complete(main()) diff --git a/examples/python/debug/asyncio_get_stats.py b/examples/python/debug/asyncio_get_stats.py index 3a835ebec5a18..467d3dc75d5ff 100644 --- a/examples/python/debug/asyncio_get_stats.py +++ b/examples/python/debug/asyncio_get_stats.py @@ -26,21 +26,24 @@ async def run(addr: str) -> None: async with grpc.aio.insecure_channel(addr) as channel: channelz_stub = channelz_pb2_grpc.ChannelzStub(channel) response = await channelz_stub.GetServers( - channelz_pb2.GetServersRequest(start_server_id=0)) - print('Info for all servers: %s' % response) + channelz_pb2.GetServersRequest(start_server_id=0) + ) + print("Info for all servers: %s" % response) async def main() -> None: parser = argparse.ArgumentParser() - parser.add_argument('--addr', - nargs=1, - type=str, - default='[::]:50051', - help='the address to request') + parser.add_argument( + "--addr", + nargs=1, + type=str, + default="[::]:50051", + help="the address to request", + ) args = parser.parse_args() run(addr=args.addr) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() asyncio.get_event_loop().run_until_complete(main()) diff --git a/examples/python/debug/asyncio_send_message.py b/examples/python/debug/asyncio_send_message.py index 1937ebf7a88b4..127d91a895bbc 100644 --- a/examples/python/debug/asyncio_send_message.py +++ b/examples/python/debug/asyncio_send_message.py @@ -20,43 +20,49 @@ import grpc helloworld_pb2, helloworld_pb2_grpc = grpc.protos_and_services( - "helloworld.proto") + "helloworld.proto" +) -async def process(stub: helloworld_pb2_grpc.GreeterStub, - request: helloworld_pb2.HelloRequest) -> None: +async def process( + stub: helloworld_pb2_grpc.GreeterStub, request: helloworld_pb2.HelloRequest +) -> None: try: response = await stub.SayHello(request) except grpc.aio.AioRpcError as rpc_error: - print(f'Received error: {rpc_error}') + print(f"Received error: {rpc_error}") else: - print(f'Received message: {response}') + print(f"Received message: {response}") async def run(addr: str, n: int) -> None: async with grpc.aio.insecure_channel(addr) as channel: stub = helloworld_pb2_grpc.GreeterStub(channel) - request = helloworld_pb2.HelloRequest(name='you') + request = 
helloworld_pb2.HelloRequest(name="you") for _ in range(n): await process(stub, request) async def main() -> None: parser = argparse.ArgumentParser() - parser.add_argument('--addr', - nargs=1, - type=str, - default='[::]:50051', - help='the address to request') - parser.add_argument('-n', - nargs=1, - type=int, - default=10, - help='an integer for number of messages to sent') + parser.add_argument( + "--addr", + nargs=1, + type=str, + default="[::]:50051", + help="the address to request", + ) + parser.add_argument( + "-n", + nargs=1, + type=int, + default=10, + help="an integer for number of messages to sent", + ) args = parser.parse_args() await run(addr=args.addr, n=args.n) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.INFO) asyncio.get_event_loop().run_until_complete(main()) diff --git a/examples/python/debug/debug_server.py b/examples/python/debug/debug_server.py index dd98c2219d7e9..0fd31e303a1f1 100644 --- a/examples/python/debug/debug_server.py +++ b/examples/python/debug/debug_server.py @@ -25,7 +25,8 @@ import grpc helloworld_pb2, helloworld_pb2_grpc = grpc.protos_and_services( - "helloworld.proto") + "helloworld.proto" +) # TODO: Suppress until the macOS segfault fix rolled out from grpc_channelz.v1 import channelz # pylint: disable=wrong-import-position @@ -37,21 +38,22 @@ class FaultInjectGreeter(helloworld_pb2_grpc.GreeterServicer): - def __init__(self, failure_rate): self._failure_rate = failure_rate def SayHello(self, request, context): if random.random() < self._failure_rate: - context.abort(grpc.StatusCode.UNAVAILABLE, - 'Randomly injected failure.') - return helloworld_pb2.HelloReply(message='Hello, %s!' % request.name) + context.abort( + grpc.StatusCode.UNAVAILABLE, "Randomly injected failure." + ) + return helloworld_pb2.HelloReply(message="Hello, %s!" % request.name) def create_server(addr, failure_rate): server = grpc.server(futures.ThreadPoolExecutor()) helloworld_pb2_grpc.add_GreeterServicer_to_server( - FaultInjectGreeter(failure_rate), server) + FaultInjectGreeter(failure_rate), server + ) # Add Channelz Servicer to the gRPC server channelz.add_channelz_servicer(server) @@ -62,17 +64,20 @@ def create_server(addr, failure_rate): def main(): parser = argparse.ArgumentParser() - parser.add_argument('--addr', - nargs=1, - type=str, - default='[::]:50051', - help='the address to listen on') parser.add_argument( - '--failure_rate', + "--addr", + nargs=1, + type=str, + default="[::]:50051", + help="the address to listen on", + ) + parser.add_argument( + "--failure_rate", nargs=1, type=float, default=0.3, - help='a float indicates the percentage of failed message injections') + help="a float indicates the percentage of failed message injections", + ) args = parser.parse_args() server = create_server(addr=args.addr, failure_rate=args.failure_rate) @@ -80,6 +85,6 @@ def main(): server.wait_for_termination() -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.INFO) main() diff --git a/examples/python/debug/get_stats.py b/examples/python/debug/get_stats.py index a7add5fd7902d..ce2ac7db110ca 100644 --- a/examples/python/debug/get_stats.py +++ b/examples/python/debug/get_stats.py @@ -32,20 +32,22 @@ def run(addr): # succeeded/failed RPCs. 
For more info see: # https://github.com/grpc/grpc/blob/master/src/proto/grpc/channelz/channelz.proto response = channelz_stub.GetServers(channelz_pb2.GetServersRequest()) - print(f'Info for all servers: {response}') + print(f"Info for all servers: {response}") def main(): parser = argparse.ArgumentParser() - parser.add_argument('--addr', - nargs=1, - type=str, - default='[::]:50051', - help='the address to request') + parser.add_argument( + "--addr", + nargs=1, + type=str, + default="[::]:50051", + help="the address to request", + ) args = parser.parse_args() run(addr=args.addr) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() main() diff --git a/examples/python/debug/send_message.py b/examples/python/debug/send_message.py index 9467a60a3cf0f..056c1b8d6fb41 100644 --- a/examples/python/debug/send_message.py +++ b/examples/python/debug/send_message.py @@ -23,42 +23,47 @@ import grpc helloworld_pb2, helloworld_pb2_grpc = grpc.protos_and_services( - "helloworld.proto") + "helloworld.proto" +) def process(stub, request): try: response = stub.SayHello(request) except grpc.RpcError as rpc_error: - print('Received error: %s' % rpc_error) + print("Received error: %s" % rpc_error) else: - print('Received message: %s' % response) + print("Received message: %s" % response) def run(addr, n): with grpc.insecure_channel(addr) as channel: stub = helloworld_pb2_grpc.GreeterStub(channel) - request = helloworld_pb2.HelloRequest(name='you') + request = helloworld_pb2.HelloRequest(name="you") for _ in range(n): process(stub, request) def main(): parser = argparse.ArgumentParser() - parser.add_argument('--addr', - nargs=1, - type=str, - default='[::]:50051', - help='the address to request') - parser.add_argument('-n', - nargs=1, - type=int, - default=10, - help='an integer for number of messages to sent') + parser.add_argument( + "--addr", + nargs=1, + type=str, + default="[::]:50051", + help="the address to request", + ) + parser.add_argument( + "-n", + nargs=1, + type=int, + default=10, + help="an integer for number of messages to sent", + ) args = parser.parse_args() run(addr=args.addr, n=args.n) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() main() diff --git a/examples/python/debug/test/_debug_example_test.py b/examples/python/debug/test/_debug_example_test.py index 57f206f586606..d149a6ec08086 100644 --- a/examples/python/debug/test/_debug_example_test.py +++ b/examples/python/debug/test/_debug_example_test.py @@ -30,15 +30,15 @@ _FAILURE_RATE = 0.5 _NUMBER_OF_MESSAGES = 100 -_ADDR_TEMPLATE = 'localhost:%d' +_ADDR_TEMPLATE = "localhost:%d" class DebugExampleTest(unittest.TestCase): - def test_channelz_example(self): - server = debug_server.create_server(addr='[::]:0', - failure_rate=_FAILURE_RATE) - port = server.add_insecure_port('[::]:0') + server = debug_server.create_server( + addr="[::]:0", failure_rate=_FAILURE_RATE + ) + port = server.add_insecure_port("[::]:0") server.start() address = _ADDR_TEMPLATE % port @@ -48,11 +48,11 @@ def test_channelz_example(self): # No unhandled exception raised, test passed! 
def test_asyncio_channelz_example(self): - async def body(): server = asyncio_debug_server.create_server( - addr='[::]:0', failure_rate=_FAILURE_RATE) - port = server.add_insecure_port('[::]:0') + addr="[::]:0", failure_rate=_FAILURE_RATE + ) + port = server.add_insecure_port("[::]:0") await server.start() address = _ADDR_TEMPLATE % port @@ -64,6 +64,6 @@ async def body(): asyncio.get_event_loop().run_until_complete(body()) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main(verbosity=2) diff --git a/examples/python/errors/client.py b/examples/python/errors/client.py index c37d62e1a7bf7..94c5a5ba731c5 100644 --- a/examples/python/errors/client.py +++ b/examples/python/errors/client.py @@ -29,29 +29,29 @@ def process(stub): try: - response = stub.SayHello(helloworld_pb2.HelloRequest(name='Alice')) - _LOGGER.info('Call success: %s', response.message) + response = stub.SayHello(helloworld_pb2.HelloRequest(name="Alice")) + _LOGGER.info("Call success: %s", response.message) except grpc.RpcError as rpc_error: - _LOGGER.error('Call failure: %s', rpc_error) + _LOGGER.error("Call failure: %s", rpc_error) status = rpc_status.from_call(rpc_error) for detail in status.details: if detail.Is(error_details_pb2.QuotaFailure.DESCRIPTOR): info = error_details_pb2.QuotaFailure() detail.Unpack(info) - _LOGGER.error('Quota failure: %s', info) + _LOGGER.error("Quota failure: %s", info) else: - raise RuntimeError('Unexpected failure: %s' % detail) + raise RuntimeError("Unexpected failure: %s" % detail) def main(): # NOTE(gRPC Python Team): .close() is possible on a channel and should be # used in circumstances in which the with statement does not fit the needs # of the code. - with grpc.insecure_channel('localhost:50051') as channel: + with grpc.insecure_channel("localhost:50051") as channel: stub = helloworld_pb2_grpc.GreeterStub(channel) process(stub) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() main() diff --git a/examples/python/errors/server.py b/examples/python/errors/server.py index 585ff0344e7ce..371847bac31ae 100644 --- a/examples/python/errors/server.py +++ b/examples/python/errors/server.py @@ -31,21 +31,23 @@ def create_greet_limit_exceed_error_status(name): detail = any_pb2.Any() detail.Pack( - error_details_pb2.QuotaFailure(violations=[ - error_details_pb2.QuotaFailure.Violation( - subject="name: %s" % name, - description="Limit one greeting per person", - ) - ],)) + error_details_pb2.QuotaFailure( + violations=[ + error_details_pb2.QuotaFailure.Violation( + subject="name: %s" % name, + description="Limit one greeting per person", + ) + ], + ) + ) return status_pb2.Status( code=code_pb2.RESOURCE_EXHAUSTED, - message='Request limit exceeded.', + message="Request limit exceeded.", details=[detail], ) class LimitedGreeter(helloworld_pb2_grpc.GreeterServicer): - def __init__(self): self._lock = threading.RLock() self._greeted = set() @@ -54,11 +56,12 @@ def SayHello(self, request, context): with self._lock: if request.name in self._greeted: rich_status = create_greet_limit_exceed_error_status( - request.name) + request.name + ) context.abort_with_status(rpc_status.to_status(rich_status)) else: self._greeted.add(request.name) - return helloworld_pb2.HelloReply(message='Hello, %s!' % request.name) + return helloworld_pb2.HelloReply(message="Hello, %s!" 
% request.name) def create_server(server_address): @@ -74,10 +77,10 @@ def serve(server): def main(): - server, unused_port = create_server('[::]:50051') + server, unused_port = create_server("[::]:50051") serve(server) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() main() diff --git a/examples/python/errors/test/_error_handling_example_test.py b/examples/python/errors/test/_error_handling_example_test.py index c5427f804ea02..539877737d077 100644 --- a/examples/python/errors/test/_error_handling_example_test.py +++ b/examples/python/errors/test/_error_handling_example_test.py @@ -17,6 +17,7 @@ # please refer to comments in the "bazel_namespace_package_hack" module. try: from tests import bazel_namespace_package_hack + bazel_namespace_package_hack.sys_path_to_site_dir_hack() except ImportError: pass @@ -32,11 +33,10 @@ class ErrorHandlingExampleTest(unittest.TestCase): - def setUp(self): - self._server, port = error_handling_server.create_server('[::]:0') + self._server, port = error_handling_server.create_server("[::]:0") self._server.start() - self._channel = grpc.insecure_channel('localhost:%d' % port) + self._channel = grpc.insecure_channel("localhost:%d" % port) def tearDown(self): self._channel.close() @@ -49,6 +49,6 @@ def test_error_handling_example(self): # No unhandled exception raised, test passed! -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main(verbosity=2) diff --git a/examples/python/health_checking/greeter_client.py b/examples/python/health_checking/greeter_client.py index 9405c1ef5bf81..6e11369d3ea3b 100644 --- a/examples/python/health_checking/greeter_client.py +++ b/examples/python/health_checking/greeter_client.py @@ -24,8 +24,9 @@ def unary_call(stub: helloworld_pb2_grpc.GreeterStub, message: str): - response = stub.SayHello(helloworld_pb2.HelloRequest(name=message), - timeout=3) + response = stub.SayHello( + helloworld_pb2.HelloRequest(name=message), timeout=3 + ) print(f"Greeter client received: {response.message}") @@ -39,11 +40,11 @@ def health_check_call(stub: health_pb2_grpc.HealthStub): def run(): - with grpc.insecure_channel('localhost:50051') as channel: + with grpc.insecure_channel("localhost:50051") as channel: stub = helloworld_pb2_grpc.GreeterStub(channel) health_stub = health_pb2_grpc.HealthStub(channel) # Should succeed - unary_call(stub, 'you') + unary_call(stub, "you") # Check health status every 1 second for 30 seconds for _ in range(30): @@ -51,6 +52,6 @@ def run(): sleep(1) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() run() diff --git a/examples/python/health_checking/greeter_server.py b/examples/python/health_checking/greeter_server.py index d10fd87adcb3f..ed5efd522d7f3 100644 --- a/examples/python/health_checking/greeter_server.py +++ b/examples/python/health_checking/greeter_server.py @@ -27,7 +27,6 @@ class Greeter(helloworld_pb2_grpc.GreeterServicer): - def SayHello(self, request, context): return helloworld_pb2.HelloReply(message=request.name) @@ -47,28 +46,30 @@ def _toggle_health(health_servicer: health.HealthServicer, service: str): def _configure_health_server(server: grpc.Server): health_servicer = health.HealthServicer( experimental_non_blocking=True, - experimental_thread_pool=futures.ThreadPoolExecutor(max_workers=10)) + experimental_thread_pool=futures.ThreadPoolExecutor(max_workers=10), + ) health_pb2_grpc.add_HealthServicer_to_server(health_servicer, server) # Use a daemon thread to toggle health status - 
toggle_health_status_thread = threading.Thread(target=_toggle_health, - args=(health_servicer, - "helloworld.Greeter"), - daemon=True) + toggle_health_status_thread = threading.Thread( + target=_toggle_health, + args=(health_servicer, "helloworld.Greeter"), + daemon=True, + ) toggle_health_status_thread.start() def serve(): - port = '50051' + port = "50051" server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) helloworld_pb2_grpc.add_GreeterServicer_to_server(Greeter(), server) - server.add_insecure_port('[::]:' + port) + server.add_insecure_port("[::]:" + port) _configure_health_server(server) server.start() print("Server started, listening on " + port) server.wait_for_termination() -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() serve() diff --git a/examples/python/hellostreamingworld/async_greeter_client.py b/examples/python/hellostreamingworld/async_greeter_client.py index 754e5e3f4a48a..a2bc78370647c 100644 --- a/examples/python/hellostreamingworld/async_greeter_client.py +++ b/examples/python/hellostreamingworld/async_greeter_client.py @@ -27,19 +27,24 @@ async def run() -> None: # Read from an async generator async for response in stub.sayHello( - hellostreamingworld_pb2.HelloRequest(name="you")): - print("Greeter client received from async generator: " + - response.message) + hellostreamingworld_pb2.HelloRequest(name="you") + ): + print( + "Greeter client received from async generator: " + + response.message + ) # Direct read from the stub hello_stream = stub.sayHello( - hellostreamingworld_pb2.HelloRequest(name="you")) + hellostreamingworld_pb2.HelloRequest(name="you") + ) while True: response = await hello_stream.read() if response == grpc.aio.EOF: break - print("Greeter client received from direct read: " + - response.message) + print( + "Greeter client received from direct read: " + response.message + ) if __name__ == "__main__": diff --git a/examples/python/hellostreamingworld/async_greeter_server.py b/examples/python/hellostreamingworld/async_greeter_server.py index d87621c1db0b7..ecf1b9eae44f8 100644 --- a/examples/python/hellostreamingworld/async_greeter_server.py +++ b/examples/python/hellostreamingworld/async_greeter_server.py @@ -26,9 +26,9 @@ class Greeter(MultiGreeterServicer): - - async def sayHello(self, request: HelloRequest, - context: grpc.aio.ServicerContext) -> HelloReply: + async def sayHello( + self, request: HelloRequest, context: grpc.aio.ServicerContext + ) -> HelloReply: logging.info("Serving sayHello request %s", request) for i in range(NUMBER_OF_REPLY): yield HelloReply(message=f"Hello number {i}, {request.name}!") diff --git a/examples/python/helloworld/async_greeter_client.py b/examples/python/helloworld/async_greeter_client.py index b779a8819550e..78d64a58a8dc6 100644 --- a/examples/python/helloworld/async_greeter_client.py +++ b/examples/python/helloworld/async_greeter_client.py @@ -22,12 +22,12 @@ async def run() -> None: - async with grpc.aio.insecure_channel('localhost:50051') as channel: + async with grpc.aio.insecure_channel("localhost:50051") as channel: stub = helloworld_pb2_grpc.GreeterStub(channel) - response = await stub.SayHello(helloworld_pb2.HelloRequest(name='you')) + response = await stub.SayHello(helloworld_pb2.HelloRequest(name="you")) print("Greeter client received: " + response.message) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() asyncio.run(run()) diff --git a/examples/python/helloworld/async_greeter_client_with_options.py 
b/examples/python/helloworld/async_greeter_client_with_options.py index b625ec8923637..5fdb8f78f1878 100644 --- a/examples/python/helloworld/async_greeter_client_with_options.py +++ b/examples/python/helloworld/async_greeter_client_with_options.py @@ -21,22 +21,26 @@ import helloworld_pb2_grpc # For more channel options, please see https://grpc.io/grpc/core/group__grpc__arg__keys.html -CHANNEL_OPTIONS = [('grpc.lb_policy_name', 'pick_first'), - ('grpc.enable_retries', 0), - ('grpc.keepalive_timeout_ms', 10000)] +CHANNEL_OPTIONS = [ + ("grpc.lb_policy_name", "pick_first"), + ("grpc.enable_retries", 0), + ("grpc.keepalive_timeout_ms", 10000), +] async def run() -> None: - async with grpc.aio.insecure_channel(target='localhost:50051', - options=CHANNEL_OPTIONS) as channel: + async with grpc.aio.insecure_channel( + target="localhost:50051", options=CHANNEL_OPTIONS + ) as channel: stub = helloworld_pb2_grpc.GreeterStub(channel) # Timeout in seconds. # Please refer gRPC Python documents for more detail. https://grpc.io/grpc/python/grpc.html - response = await stub.SayHello(helloworld_pb2.HelloRequest(name='you'), - timeout=10) + response = await stub.SayHello( + helloworld_pb2.HelloRequest(name="you"), timeout=10 + ) print("Greeter client received: " + response.message) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() asyncio.run(run()) diff --git a/examples/python/helloworld/async_greeter_server.py b/examples/python/helloworld/async_greeter_server.py index e0912a3454ea9..84070b2ffc812 100644 --- a/examples/python/helloworld/async_greeter_server.py +++ b/examples/python/helloworld/async_greeter_server.py @@ -22,23 +22,24 @@ class Greeter(helloworld_pb2_grpc.GreeterServicer): - async def SayHello( - self, request: helloworld_pb2.HelloRequest, - context: grpc.aio.ServicerContext) -> helloworld_pb2.HelloReply: - return helloworld_pb2.HelloReply(message='Hello, %s!' % request.name) + self, + request: helloworld_pb2.HelloRequest, + context: grpc.aio.ServicerContext, + ) -> helloworld_pb2.HelloReply: + return helloworld_pb2.HelloReply(message="Hello, %s!" % request.name) async def serve() -> None: server = grpc.aio.server() helloworld_pb2_grpc.add_GreeterServicer_to_server(Greeter(), server) - listen_addr = '[::]:50051' + listen_addr = "[::]:50051" server.add_insecure_port(listen_addr) logging.info("Starting server on %s", listen_addr) await server.start() await server.wait_for_termination() -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.INFO) asyncio.run(serve()) diff --git a/examples/python/helloworld/async_greeter_server_with_graceful_shutdown.py b/examples/python/helloworld/async_greeter_server_with_graceful_shutdown.py index 737f42b7a65cb..2c4dd2da1bd5e 100644 --- a/examples/python/helloworld/async_greeter_server_with_graceful_shutdown.py +++ b/examples/python/helloworld/async_greeter_server_with_graceful_shutdown.py @@ -25,20 +25,21 @@ class Greeter(helloworld_pb2_grpc.GreeterServicer): - async def SayHello( - self, request: helloworld_pb2.HelloRequest, - context: grpc.aio.ServicerContext) -> helloworld_pb2.HelloReply: - logging.info('Received request, sleeping for 4 seconds...') + self, + request: helloworld_pb2.HelloRequest, + context: grpc.aio.ServicerContext, + ) -> helloworld_pb2.HelloReply: + logging.info("Received request, sleeping for 4 seconds...") await asyncio.sleep(4) - logging.info('Sleep completed, responding') - return helloworld_pb2.HelloReply(message='Hello, %s!' 
% request.name) + logging.info("Sleep completed, responding") + return helloworld_pb2.HelloReply(message="Hello, %s!" % request.name) async def serve() -> None: server = grpc.aio.server() helloworld_pb2_grpc.add_GreeterServicer_to_server(Greeter(), server) - listen_addr = '[::]:50051' + listen_addr = "[::]:50051" server.add_insecure_port(listen_addr) logging.info("Starting server on %s", listen_addr) await server.start() @@ -54,7 +55,7 @@ async def server_graceful_shutdown(): await server.wait_for_termination() -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.INFO) loop = asyncio.get_event_loop() try: diff --git a/examples/python/helloworld/async_greeter_server_with_reflection.py b/examples/python/helloworld/async_greeter_server_with_reflection.py index 5ce2abf9c06c9..8ef6ac07a9ae8 100644 --- a/examples/python/helloworld/async_greeter_server_with_reflection.py +++ b/examples/python/helloworld/async_greeter_server_with_reflection.py @@ -23,26 +23,27 @@ class Greeter(helloworld_pb2_grpc.GreeterServicer): - async def SayHello( - self, request: helloworld_pb2.HelloRequest, - context: grpc.aio.ServicerContext) -> helloworld_pb2.HelloReply: - return helloworld_pb2.HelloReply(message='Hello, %s!' % request.name) + self, + request: helloworld_pb2.HelloRequest, + context: grpc.aio.ServicerContext, + ) -> helloworld_pb2.HelloReply: + return helloworld_pb2.HelloReply(message="Hello, %s!" % request.name) async def serve() -> None: server = grpc.aio.server() helloworld_pb2_grpc.add_GreeterServicer_to_server(Greeter(), server) SERVICE_NAMES = ( - helloworld_pb2.DESCRIPTOR.services_by_name['Greeter'].full_name, + helloworld_pb2.DESCRIPTOR.services_by_name["Greeter"].full_name, reflection.SERVICE_NAME, ) reflection.enable_server_reflection(SERVICE_NAMES, server) - server.add_insecure_port('[::]:50051') + server.add_insecure_port("[::]:50051") await server.start() await server.wait_for_termination() -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() asyncio.run(serve()) diff --git a/examples/python/helloworld/greeter_client.py b/examples/python/helloworld/greeter_client.py index fc8f1b0905349..33fc52be36103 100644 --- a/examples/python/helloworld/greeter_client.py +++ b/examples/python/helloworld/greeter_client.py @@ -27,12 +27,12 @@ def run(): # used in circumstances in which the with statement does not fit the needs # of the code. 
print("Will try to greet world ...") - with grpc.insecure_channel('localhost:50051') as channel: + with grpc.insecure_channel("localhost:50051") as channel: stub = helloworld_pb2_grpc.GreeterStub(channel) - response = stub.SayHello(helloworld_pb2.HelloRequest(name='you')) + response = stub.SayHello(helloworld_pb2.HelloRequest(name="you")) print("Greeter client received: " + response.message) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() run() diff --git a/examples/python/helloworld/greeter_client_reflection.py b/examples/python/helloworld/greeter_client_reflection.py index 77e32930b7550..586f2a400b690 100644 --- a/examples/python/helloworld/greeter_client_reflection.py +++ b/examples/python/helloworld/greeter_client_reflection.py @@ -17,13 +17,14 @@ from google.protobuf.descriptor_pool import DescriptorPool import grpc -from grpc_reflection.v1alpha.proto_reflection_descriptor_database import \ - ProtoReflectionDescriptorDatabase +from grpc_reflection.v1alpha.proto_reflection_descriptor_database import ( + ProtoReflectionDescriptorDatabase, +) def run(): print("Will try to greet world ...") - with grpc.insecure_channel('localhost:50051') as channel: + with grpc.insecure_channel("localhost:50051") as channel: reflection_db = ProtoReflectionDescriptorDatabase(channel) services = reflection_db.get_services() print(f"found services: {services}") @@ -37,10 +38,11 @@ def run(): print(f"input type for this method: {input_type.full_name}") request_desc = desc_pool.FindMessageTypeByName( - "helloworld.HelloRequest") + "helloworld.HelloRequest" + ) print(f"found request name: {request_desc.full_name}") -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() run() diff --git a/examples/python/helloworld/greeter_client_with_options.py b/examples/python/helloworld/greeter_client_with_options.py index efb5fd58bd8eb..c40fea9d2b2b5 100644 --- a/examples/python/helloworld/greeter_client_with_options.py +++ b/examples/python/helloworld/greeter_client_with_options.py @@ -28,19 +28,23 @@ def run(): # of the code. # # For more channel options, please see https://grpc.io/grpc/core/group__grpc__arg__keys.html - with grpc.insecure_channel(target='localhost:50051', - options=[('grpc.lb_policy_name', 'pick_first'), - ('grpc.enable_retries', 0), - ('grpc.keepalive_timeout_ms', 10000) - ]) as channel: + with grpc.insecure_channel( + target="localhost:50051", + options=[ + ("grpc.lb_policy_name", "pick_first"), + ("grpc.enable_retries", 0), + ("grpc.keepalive_timeout_ms", 10000), + ], + ) as channel: stub = helloworld_pb2_grpc.GreeterStub(channel) # Timeout in seconds. # Please refer gRPC Python documents for more detail. https://grpc.io/grpc/python/grpc.html - response = stub.SayHello(helloworld_pb2.HelloRequest(name='you'), - timeout=10) + response = stub.SayHello( + helloworld_pb2.HelloRequest(name="you"), timeout=10 + ) print("Greeter client received: " + response.message) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() run() diff --git a/examples/python/helloworld/greeter_server.py b/examples/python/helloworld/greeter_server.py index 04a0e50efa26d..611067131eab4 100644 --- a/examples/python/helloworld/greeter_server.py +++ b/examples/python/helloworld/greeter_server.py @@ -22,21 +22,20 @@ class Greeter(helloworld_pb2_grpc.GreeterServicer): - def SayHello(self, request, context): - return helloworld_pb2.HelloReply(message='Hello, %s!' % request.name) + return helloworld_pb2.HelloReply(message="Hello, %s!" 
% request.name) def serve(): - port = '50051' + port = "50051" server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) helloworld_pb2_grpc.add_GreeterServicer_to_server(Greeter(), server) - server.add_insecure_port('[::]:' + port) + server.add_insecure_port("[::]:" + port) server.start() print("Server started, listening on " + port) server.wait_for_termination() -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() serve() diff --git a/examples/python/helloworld/greeter_server_with_reflection.py b/examples/python/helloworld/greeter_server_with_reflection.py index 8ce3b1d49aead..a5eb72e748237 100644 --- a/examples/python/helloworld/greeter_server_with_reflection.py +++ b/examples/python/helloworld/greeter_server_with_reflection.py @@ -23,24 +23,23 @@ class Greeter(helloworld_pb2_grpc.GreeterServicer): - def SayHello(self, request, context): - return helloworld_pb2.HelloReply(message='Hello, %s!' % request.name) + return helloworld_pb2.HelloReply(message="Hello, %s!" % request.name) def serve(): server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) helloworld_pb2_grpc.add_GreeterServicer_to_server(Greeter(), server) SERVICE_NAMES = ( - helloworld_pb2.DESCRIPTOR.services_by_name['Greeter'].full_name, + helloworld_pb2.DESCRIPTOR.services_by_name["Greeter"].full_name, reflection.SERVICE_NAME, ) reflection.enable_server_reflection(SERVICE_NAMES, server) - server.add_insecure_port('[::]:50051') + server.add_insecure_port("[::]:50051") server.start() server.wait_for_termination() -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() serve() diff --git a/examples/python/interceptors/async/async_greeter_client.py b/examples/python/interceptors/async/async_greeter_client.py index 9101725da0798..3590c88ae7776 100644 --- a/examples/python/interceptors/async/async_greeter_client.py +++ b/examples/python/interceptors/async/async_greeter_client.py @@ -22,20 +22,23 @@ import helloworld_pb2 import helloworld_pb2_grpc -test_var = contextvars.ContextVar('test', default='test') +test_var = contextvars.ContextVar("test", default="test") async def run() -> None: - async with grpc.aio.insecure_channel('localhost:50051') as channel: + async with grpc.aio.insecure_channel("localhost:50051") as channel: stub = helloworld_pb2_grpc.GreeterStub(channel) - rpc_id = '{:032x}'.format(random.getrandbits(128)) - metadata = grpc.aio.Metadata(('client-rpc-id', rpc_id),) + rpc_id = "{:032x}".format(random.getrandbits(128)) + metadata = grpc.aio.Metadata( + ("client-rpc-id", rpc_id), + ) print(f"Sending request with rpc id: {rpc_id}") - response = await stub.SayHello(helloworld_pb2.HelloRequest(name='you'), - metadata=metadata) + response = await stub.SayHello( + helloworld_pb2.HelloRequest(name="you"), metadata=metadata + ) print("Greeter client received: " + response.message) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() asyncio.run(run()) diff --git a/examples/python/interceptors/async/async_greeter_server_with_interceptor.py b/examples/python/interceptors/async/async_greeter_server_with_interceptor.py index 277718078f624..677355d40c8d4 100644 --- a/examples/python/interceptors/async/async_greeter_server_with_interceptor.py +++ b/examples/python/interceptors/async/async_greeter_server_with_interceptor.py @@ -22,19 +22,20 @@ import helloworld_pb2 import helloworld_pb2_grpc -rpc_id_var = contextvars.ContextVar('rpc_id', default='default') +rpc_id_var = contextvars.ContextVar("rpc_id", default="default") class 
RPCIdInterceptor(grpc.aio.ServerInterceptor): - def __init__(self, tag: str, rpc_id: Optional[str] = None) -> None: self.tag = tag self.rpc_id = rpc_id async def intercept_service( - self, continuation: Callable[[grpc.HandlerCallDetails], - Awaitable[grpc.RpcMethodHandler]], - handler_call_details: grpc.HandlerCallDetails + self, + continuation: Callable[ + [grpc.HandlerCallDetails], Awaitable[grpc.RpcMethodHandler] + ], + handler_call_details: grpc.HandlerCallDetails, ) -> grpc.RpcMethodHandler: """ This interceptor prepends its tag to the rpc_id. @@ -42,9 +43,9 @@ async def intercept_service( will be something like this: Interceptor2-Interceptor1-RPC_ID. """ logging.info("%s called with rpc_id: %s", self.tag, rpc_id_var.get()) - if rpc_id_var.get() == 'default': + if rpc_id_var.get() == "default": _metadata = dict(handler_call_details.invocation_metadata) - rpc_id_var.set(self.decorate(_metadata['client-rpc-id'])) + rpc_id_var.set(self.decorate(_metadata["client-rpc-id"])) else: rpc_id_var.set(self.decorate(rpc_id_var.get())) return await continuation(handler_call_details) @@ -54,30 +55,32 @@ def decorate(self, rpc_id: str): class Greeter(helloworld_pb2_grpc.GreeterServicer): - async def SayHello( - self, request: helloworld_pb2.HelloRequest, - context: grpc.aio.ServicerContext) -> helloworld_pb2.HelloReply: - logging.info("Handle rpc with id %s in server handler.", - rpc_id_var.get()) - return helloworld_pb2.HelloReply(message='Hello, %s!' % request.name) + self, + request: helloworld_pb2.HelloRequest, + context: grpc.aio.ServicerContext, + ) -> helloworld_pb2.HelloReply: + logging.info( + "Handle rpc with id %s in server handler.", rpc_id_var.get() + ) + return helloworld_pb2.HelloReply(message="Hello, %s!" % request.name) async def serve() -> None: interceptors = [ - RPCIdInterceptor('Interceptor1'), - RPCIdInterceptor('Interceptor2') + RPCIdInterceptor("Interceptor1"), + RPCIdInterceptor("Interceptor2"), ] server = grpc.aio.server(interceptors=interceptors) helloworld_pb2_grpc.add_GreeterServicer_to_server(Greeter(), server) - listen_addr = '[::]:50051' + listen_addr = "[::]:50051" server.add_insecure_port(listen_addr) logging.info("Starting server on %s", listen_addr) await server.start() await server.wait_for_termination() -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.INFO) asyncio.run(serve()) diff --git a/examples/python/interceptors/default_value/default_value_client_interceptor.py b/examples/python/interceptors/default_value/default_value_client_interceptor.py index c935b9549187d..3d0a7ca3412a5 100644 --- a/examples/python/interceptors/default_value/default_value_client_interceptor.py +++ b/examples/python/interceptors/default_value/default_value_client_interceptor.py @@ -17,7 +17,6 @@ class _ConcreteValue(grpc.Future): - def __init__(self, result): self._result = result @@ -46,21 +45,24 @@ def add_done_callback(self, fn): fn(self._result) -class DefaultValueClientInterceptor(grpc.UnaryUnaryClientInterceptor, - grpc.StreamUnaryClientInterceptor): - +class DefaultValueClientInterceptor( + grpc.UnaryUnaryClientInterceptor, grpc.StreamUnaryClientInterceptor +): def __init__(self, value): self._default = _ConcreteValue(value) - def _intercept_call(self, continuation, client_call_details, - request_or_iterator): + def _intercept_call( + self, continuation, client_call_details, request_or_iterator + ): response = continuation(client_call_details, request_or_iterator) return self._default if response.exception() else response def 
intercept_unary_unary(self, continuation, client_call_details, request): return self._intercept_call(continuation, client_call_details, request) - def intercept_stream_unary(self, continuation, client_call_details, - request_iterator): - return self._intercept_call(continuation, client_call_details, - request_iterator) + def intercept_stream_unary( + self, continuation, client_call_details, request_iterator + ): + return self._intercept_call( + continuation, client_call_details, request_iterator + ) diff --git a/examples/python/interceptors/default_value/greeter_client.py b/examples/python/interceptors/default_value/greeter_client.py index 7eaf556bc0009..d3fc5baf0edad 100644 --- a/examples/python/interceptors/default_value/greeter_client.py +++ b/examples/python/interceptors/default_value/greeter_client.py @@ -25,20 +25,25 @@ def run(): default_value = helloworld_pb2.HelloReply( - message='Hello from your local interceptor!') - default_value_interceptor = default_value_client_interceptor.DefaultValueClientInterceptor( - default_value) + message="Hello from your local interceptor!" + ) + default_value_interceptor = ( + default_value_client_interceptor.DefaultValueClientInterceptor( + default_value + ) + ) # NOTE(gRPC Python Team): .close() is possible on a channel and should be # used in circumstances in which the with statement does not fit the needs # of the code. - with grpc.insecure_channel('localhost:50051') as channel: - intercept_channel = grpc.intercept_channel(channel, - default_value_interceptor) + with grpc.insecure_channel("localhost:50051") as channel: + intercept_channel = grpc.intercept_channel( + channel, default_value_interceptor + ) stub = helloworld_pb2_grpc.GreeterStub(intercept_channel) - response = stub.SayHello(helloworld_pb2.HelloRequest(name='you')) + response = stub.SayHello(helloworld_pb2.HelloRequest(name="you")) print("Greeter client received: " + response.message) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() run() diff --git a/examples/python/interceptors/headers/generic_client_interceptor.py b/examples/python/interceptors/headers/generic_client_interceptor.py index 3fed93d52df54..4cce31640f8ae 100644 --- a/examples/python/interceptors/headers/generic_client_interceptor.py +++ b/examples/python/interceptors/headers/generic_client_interceptor.py @@ -16,38 +16,46 @@ import grpc -class _GenericClientInterceptor(grpc.UnaryUnaryClientInterceptor, - grpc.UnaryStreamClientInterceptor, - grpc.StreamUnaryClientInterceptor, - grpc.StreamStreamClientInterceptor): - +class _GenericClientInterceptor( + grpc.UnaryUnaryClientInterceptor, + grpc.UnaryStreamClientInterceptor, + grpc.StreamUnaryClientInterceptor, + grpc.StreamStreamClientInterceptor, +): def __init__(self, interceptor_function): self._fn = interceptor_function def intercept_unary_unary(self, continuation, client_call_details, request): new_details, new_request_iterator, postprocess = self._fn( - client_call_details, iter((request,)), False, False) + client_call_details, iter((request,)), False, False + ) response = continuation(new_details, next(new_request_iterator)) return postprocess(response) if postprocess else response - def intercept_unary_stream(self, continuation, client_call_details, - request): + def intercept_unary_stream( + self, continuation, client_call_details, request + ): new_details, new_request_iterator, postprocess = self._fn( - client_call_details, iter((request,)), False, True) + client_call_details, iter((request,)), False, True + ) response_it = 
continuation(new_details, next(new_request_iterator)) return postprocess(response_it) if postprocess else response_it - def intercept_stream_unary(self, continuation, client_call_details, - request_iterator): + def intercept_stream_unary( + self, continuation, client_call_details, request_iterator + ): new_details, new_request_iterator, postprocess = self._fn( - client_call_details, request_iterator, True, False) + client_call_details, request_iterator, True, False + ) response = continuation(new_details, new_request_iterator) return postprocess(response) if postprocess else response - def intercept_stream_stream(self, continuation, client_call_details, - request_iterator): + def intercept_stream_stream( + self, continuation, client_call_details, request_iterator + ): new_details, new_request_iterator, postprocess = self._fn( - client_call_details, request_iterator, True, True) + client_call_details, request_iterator, True, True + ) response_it = continuation(new_details, new_request_iterator) return postprocess(response_it) if postprocess else response_it diff --git a/examples/python/interceptors/headers/greeter_client.py b/examples/python/interceptors/headers/greeter_client.py index 0714aa7d02712..351e3dcb092f3 100644 --- a/examples/python/interceptors/headers/greeter_client.py +++ b/examples/python/interceptors/headers/greeter_client.py @@ -24,19 +24,23 @@ def run(): - header_adder_interceptor = header_manipulator_client_interceptor.header_adder_interceptor( - 'one-time-password', '42') + header_adder_interceptor = ( + header_manipulator_client_interceptor.header_adder_interceptor( + "one-time-password", "42" + ) + ) # NOTE(gRPC Python Team): .close() is possible on a channel and should be # used in circumstances in which the with statement does not fit the needs # of the code. - with grpc.insecure_channel('localhost:50051') as channel: - intercept_channel = grpc.intercept_channel(channel, - header_adder_interceptor) + with grpc.insecure_channel("localhost:50051") as channel: + intercept_channel = grpc.intercept_channel( + channel, header_adder_interceptor + ) stub = helloworld_pb2_grpc.GreeterStub(intercept_channel) - response = stub.SayHello(helloworld_pb2.HelloRequest(name='you')) + response = stub.SayHello(helloworld_pb2.HelloRequest(name="you")) print("Greeter client received: " + response.message) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() run() diff --git a/examples/python/interceptors/headers/greeter_server.py b/examples/python/interceptors/headers/greeter_server.py index 95f3b2a22bcc0..ea0aa87a4b273 100644 --- a/examples/python/interceptors/headers/greeter_server.py +++ b/examples/python/interceptors/headers/greeter_server.py @@ -19,28 +19,33 @@ import grpc import helloworld_pb2 import helloworld_pb2_grpc -from request_header_validator_interceptor import \ - RequestHeaderValidatorInterceptor +from request_header_validator_interceptor import ( + RequestHeaderValidatorInterceptor, +) class Greeter(helloworld_pb2_grpc.GreeterServicer): - def SayHello(self, request, context): - return helloworld_pb2.HelloReply(message='Hello, %s!' % request.name) + return helloworld_pb2.HelloReply(message="Hello, %s!" 
% request.name) def serve(): header_validator = RequestHeaderValidatorInterceptor( - 'one-time-password', '42', grpc.StatusCode.UNAUTHENTICATED, - 'Access denied!') - server = grpc.server(futures.ThreadPoolExecutor(max_workers=10), - interceptors=(header_validator,)) + "one-time-password", + "42", + grpc.StatusCode.UNAUTHENTICATED, + "Access denied!", + ) + server = grpc.server( + futures.ThreadPoolExecutor(max_workers=10), + interceptors=(header_validator,), + ) helloworld_pb2_grpc.add_GreeterServicer_to_server(Greeter(), server) - server.add_insecure_port('[::]:50051') + server.add_insecure_port("[::]:50051") server.start() server.wait_for_termination() -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() serve() diff --git a/examples/python/interceptors/headers/header_manipulator_client_interceptor.py b/examples/python/interceptors/headers/header_manipulator_client_interceptor.py index b2b06b8624bf5..28fffbd4fef52 100644 --- a/examples/python/interceptors/headers/header_manipulator_client_interceptor.py +++ b/examples/python/interceptors/headers/header_manipulator_client_interceptor.py @@ -20,27 +20,36 @@ class _ClientCallDetails( - collections.namedtuple( - '_ClientCallDetails', - ('method', 'timeout', 'metadata', 'credentials')), - grpc.ClientCallDetails): + collections.namedtuple( + "_ClientCallDetails", ("method", "timeout", "metadata", "credentials") + ), + grpc.ClientCallDetails, +): pass def header_adder_interceptor(header, value): - - def intercept_call(client_call_details, request_iterator, request_streaming, - response_streaming): + def intercept_call( + client_call_details, + request_iterator, + request_streaming, + response_streaming, + ): metadata = [] if client_call_details.metadata is not None: metadata = list(client_call_details.metadata) - metadata.append(( - header, - value, - )) + metadata.append( + ( + header, + value, + ) + ) client_call_details = _ClientCallDetails( - client_call_details.method, client_call_details.timeout, metadata, - client_call_details.credentials) + client_call_details.method, + client_call_details.timeout, + metadata, + client_call_details.credentials, + ) return client_call_details, request_iterator, None return generic_client_interceptor.create(intercept_call) diff --git a/examples/python/interceptors/headers/request_header_validator_interceptor.py b/examples/python/interceptors/headers/request_header_validator_interceptor.py index 95af4177baf97..5379743470ac3 100644 --- a/examples/python/interceptors/headers/request_header_validator_interceptor.py +++ b/examples/python/interceptors/headers/request_header_validator_interceptor.py @@ -17,7 +17,6 @@ def _unary_unary_rpc_terminator(code, details): - def terminate(ignored_request, context): context.abort(code, details) @@ -25,15 +24,16 @@ def terminate(ignored_request, context): class RequestHeaderValidatorInterceptor(grpc.ServerInterceptor): - def __init__(self, header, value, code, details): self._header = header self._value = value self._terminator = _unary_unary_rpc_terminator(code, details) def intercept_service(self, continuation, handler_call_details): - if (self._header, - self._value) in handler_call_details.invocation_metadata: + if ( + self._header, + self._value, + ) in handler_call_details.invocation_metadata: return continuation(handler_call_details) else: return self._terminator diff --git a/examples/python/keep_alive/greeter_client.py b/examples/python/keep_alive/greeter_client.py index c5d03cb0d7212..387acd8e8b1db 100644 --- 
a/examples/python/keep_alive/greeter_client.py +++ b/examples/python/keep_alive/greeter_client.py @@ -21,14 +21,15 @@ import helloworld_pb2_grpc -def unary_call(stub: helloworld_pb2_grpc.GreeterStub, request_id: int, - message: str): +def unary_call( + stub: helloworld_pb2_grpc.GreeterStub, request_id: int, message: str +): print("call:", request_id) try: response = stub.SayHello(helloworld_pb2.HelloRequest(name=message)) print(f"Greeter client received: {response.message}") except grpc.RpcError as rpc_error: - print('Call failed with code: ', rpc_error.code()) + print("Call failed with code: ", rpc_error.code()) def run(): @@ -44,16 +45,19 @@ def run(): send a data/header frame. For more details, check: https://github.com/grpc/grpc/blob/master/doc/keepalive.md """ - channel_options = [('grpc.keepalive_time_ms', 8000), - ('grpc.keepalive_timeout_ms', 5000), - ('grpc.http2.max_pings_without_data', 5), - ('grpc.keepalive_permit_without_calls', 1)] + channel_options = [ + ("grpc.keepalive_time_ms", 8000), + ("grpc.keepalive_timeout_ms", 5000), + ("grpc.http2.max_pings_without_data", 5), + ("grpc.keepalive_permit_without_calls", 1), + ] - with grpc.insecure_channel(target='localhost:50051', - options=channel_options) as channel: + with grpc.insecure_channel( + target="localhost:50051", options=channel_options + ) as channel: stub = helloworld_pb2_grpc.GreeterStub(channel) # Should succeed - unary_call(stub, 1, 'you') + unary_call(stub, 1, "you") # Run 30s, run this with GRPC_VERBOSITY=DEBUG GRPC_TRACE=http_keepalive to observe logs. # Client will be closed after receveing GOAWAY from server. @@ -62,6 +66,6 @@ def run(): sleep(1) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() run() diff --git a/examples/python/keep_alive/greeter_server.py b/examples/python/keep_alive/greeter_server.py index db01ab3f5e6d8..edfa306a90ca3 100644 --- a/examples/python/keep_alive/greeter_server.py +++ b/examples/python/keep_alive/greeter_server.py @@ -23,7 +23,6 @@ class Greeter(helloworld_pb2_grpc.GreeterServicer): - def SayHello(self, request, context): message = request.name if message.startswith("[delay]"): @@ -51,24 +50,28 @@ def serve(): pings to be sent even if there are no calls in flight. 
For more details, check: https://github.com/grpc/grpc/blob/master/doc/keepalive.md """ - server_options = [('grpc.keepalive_time_ms', 20000), - ('grpc.keepalive_timeout_ms', 10000), - ('grpc.http2.min_ping_interval_without_data_ms', 5000), - ('grpc.max_connection_idle_ms', 10000), - ('grpc.max_connection_age_ms', 30000), - ('grpc.max_connection_age_grace_ms', 5000), - ('grpc.http2.max_pings_without_data', 5), - ('grpc.keepalive_permit_without_calls', 1)] - port = '50051' - server = grpc.server(thread_pool=futures.ThreadPoolExecutor(max_workers=10), - options=server_options) + server_options = [ + ("grpc.keepalive_time_ms", 20000), + ("grpc.keepalive_timeout_ms", 10000), + ("grpc.http2.min_ping_interval_without_data_ms", 5000), + ("grpc.max_connection_idle_ms", 10000), + ("grpc.max_connection_age_ms", 30000), + ("grpc.max_connection_age_grace_ms", 5000), + ("grpc.http2.max_pings_without_data", 5), + ("grpc.keepalive_permit_without_calls", 1), + ] + port = "50051" + server = grpc.server( + thread_pool=futures.ThreadPoolExecutor(max_workers=10), + options=server_options, + ) helloworld_pb2_grpc.add_GreeterServicer_to_server(Greeter(), server) - server.add_insecure_port('[::]:' + port) + server.add_insecure_port("[::]:" + port) server.start() print("Server started, listening on " + port) server.wait_for_termination() -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() serve() diff --git a/examples/python/lb_policies/greeter_client.py b/examples/python/lb_policies/greeter_client.py index 4e785dbcae146..5bcf666a1cbf6 100644 --- a/examples/python/lb_policies/greeter_client.py +++ b/examples/python/lb_policies/greeter_client.py @@ -25,12 +25,12 @@ def run(): options = (("grpc.lb_policy_name", "round_robin"),) # Load balancing takes effect when the DNS server returns multiple IPs for the DNS hostname. # Replace "localhost" with such hostname to see the round robin LB policy take effect. - with grpc.insecure_channel('localhost:50051', options=options) as channel: + with grpc.insecure_channel("localhost:50051", options=options) as channel: stub = helloworld_pb2_grpc.GreeterStub(channel) - response = stub.SayHello(helloworld_pb2.HelloRequest(name='you')) + response = stub.SayHello(helloworld_pb2.HelloRequest(name="you")) print("Greeter client received: " + response.message) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() run() diff --git a/examples/python/lb_policies/greeter_server.py b/examples/python/lb_policies/greeter_server.py index 1b86e11761be2..b7356a8ee104c 100644 --- a/examples/python/lb_policies/greeter_server.py +++ b/examples/python/lb_policies/greeter_server.py @@ -22,21 +22,20 @@ class Greeter(helloworld_pb2_grpc.GreeterServicer): - def SayHello(self, request, context): - return helloworld_pb2.HelloReply(message='Hello, %s!' % request.name) + return helloworld_pb2.HelloReply(message="Hello, %s!" 
% request.name) def serve(): - port = '50051' + port = "50051" server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) helloworld_pb2_grpc.add_GreeterServicer_to_server(Greeter(), server) - server.add_insecure_port('[::]:' + port) + server.add_insecure_port("[::]:" + port) server.start() print("Server started, listening on " + port) server.wait_for_termination() -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() serve() diff --git a/examples/python/metadata/metadata_client.py b/examples/python/metadata/metadata_client.py index 147dcedc61000..ced46ee1a2ebf 100644 --- a/examples/python/metadata/metadata_client.py +++ b/examples/python/metadata/metadata_client.py @@ -26,23 +26,28 @@ def run(): # NOTE(gRPC Python Team): .close() is possible on a channel and should be # used in circumstances in which the with statement does not fit the needs # of the code. - with grpc.insecure_channel('localhost:50051') as channel: + with grpc.insecure_channel("localhost:50051") as channel: stub = helloworld_pb2_grpc.GreeterStub(channel) response, call = stub.SayHello.with_call( - helloworld_pb2.HelloRequest(name='you'), + helloworld_pb2.HelloRequest(name="you"), metadata=( - ('initial-metadata-1', 'The value should be str'), - ('binary-metadata-bin', - b'With -bin surffix, the value can be bytes'), - ('accesstoken', 'gRPC Python is great'), - )) + ("initial-metadata-1", "The value should be str"), + ( + "binary-metadata-bin", + b"With -bin surffix, the value can be bytes", + ), + ("accesstoken", "gRPC Python is great"), + ), + ) print("Greeter client received: " + response.message) for key, value in call.trailing_metadata(): - print('Greeter client received trailing metadata: key=%s value=%s' % - (key, value)) + print( + "Greeter client received trailing metadata: key=%s value=%s" + % (key, value) + ) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() run() diff --git a/examples/python/metadata/metadata_server.py b/examples/python/metadata/metadata_server.py index 1562e1434634b..8340ff632b402 100644 --- a/examples/python/metadata/metadata_server.py +++ b/examples/python/metadata/metadata_server.py @@ -24,26 +24,27 @@ class Greeter(helloworld_pb2_grpc.GreeterServicer): - def SayHello(self, request, context): for key, value in context.invocation_metadata(): - print('Received initial metadata: key=%s value=%s' % (key, value)) + print("Received initial metadata: key=%s value=%s" % (key, value)) - context.set_trailing_metadata(( - ('checksum-bin', b'I agree'), - ('retry', 'false'), - )) - return helloworld_pb2.HelloReply(message='Hello, %s!' % request.name) + context.set_trailing_metadata( + ( + ("checksum-bin", b"I agree"), + ("retry", "false"), + ) + ) + return helloworld_pb2.HelloReply(message="Hello, %s!" 
% request.name) def serve(): server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) helloworld_pb2_grpc.add_GreeterServicer_to_server(Greeter(), server) - server.add_insecure_port('[::]:50051') + server.add_insecure_port("[::]:50051") server.start() server.wait_for_termination() -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() serve() diff --git a/examples/python/multiplex/multiplex_client.py b/examples/python/multiplex/multiplex_client.py index d97d6c65bf405..5faf4355fd0bb 100644 --- a/examples/python/multiplex/multiplex_client.py +++ b/examples/python/multiplex/multiplex_client.py @@ -30,7 +30,8 @@ def make_route_note(message, latitude, longitude): return route_guide_pb2.RouteNote( message=message, - location=route_guide_pb2.Point(latitude=latitude, longitude=longitude)) + location=route_guide_pb2.Point(latitude=latitude, longitude=longitude), + ) def guide_get_one_feature(route_guide_stub, point): @@ -48,15 +49,18 @@ def guide_get_one_feature(route_guide_stub, point): def guide_get_feature(route_guide_stub): guide_get_one_feature( route_guide_stub, - route_guide_pb2.Point(latitude=409146138, longitude=-746188906)) - guide_get_one_feature(route_guide_stub, - route_guide_pb2.Point(latitude=0, longitude=0)) + route_guide_pb2.Point(latitude=409146138, longitude=-746188906), + ) + guide_get_one_feature( + route_guide_stub, route_guide_pb2.Point(latitude=0, longitude=0) + ) def guide_list_features(route_guide_stub): rectangle = route_guide_pb2.Rectangle( lo=route_guide_pb2.Point(latitude=400000000, longitude=-750000000), - hi=route_guide_pb2.Point(latitude=420000000, longitude=-730000000)) + hi=route_guide_pb2.Point(latitude=420000000, longitude=-730000000), + ) print("Looking for features between 40, -75 and 42, -73") features = route_guide_stub.ListFeatures(rectangle) @@ -101,19 +105,21 @@ def generate_messages(): def guide_route_chat(route_guide_stub): responses = route_guide_stub.RouteChat(generate_messages()) for response in responses: - print("Received message %s at %s" % - (response.message, response.location)) + print( + "Received message %s at %s" % (response.message, response.location) + ) def run(): # NOTE(gRPC Python Team): .close() is possible on a channel and should be # used in circumstances in which the with statement does not fit the needs # of the code. 
- with grpc.insecure_channel('localhost:50051') as channel: + with grpc.insecure_channel("localhost:50051") as channel: greeter_stub = helloworld_pb2_grpc.GreeterStub(channel) route_guide_stub = route_guide_pb2_grpc.RouteGuideStub(channel) greeter_response = greeter_stub.SayHello( - helloworld_pb2.HelloRequest(name='you')) + helloworld_pb2.HelloRequest(name="you") + ) print("Greeter client received: " + greeter_response.message) print("-------------- GetFeature --------------") guide_get_feature(route_guide_stub) @@ -125,6 +131,6 @@ def run(): guide_route_chat(route_guide_stub) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() run() diff --git a/examples/python/multiplex/multiplex_server.py b/examples/python/multiplex/multiplex_server.py index eb1902aafcd10..4ad58ee0f4816 100644 --- a/examples/python/multiplex/multiplex_server.py +++ b/examples/python/multiplex/multiplex_server.py @@ -46,9 +46,11 @@ def _get_distance(start, end): delta_lat_rad = math.radians(lat_2 - lat_1) delta_lon_rad = math.radians(lon_2 - lon_1) - a = (pow(math.sin(delta_lat_rad / 2), 2) + - (math.cos(lat_rad_1) * math.cos(lat_rad_2) * - pow(math.sin(delta_lon_rad / 2), 2))) + a = pow(math.sin(delta_lat_rad / 2), 2) + ( + math.cos(lat_rad_1) + * math.cos(lat_rad_2) + * pow(math.sin(delta_lon_rad / 2), 2) + ) c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a)) R = 6371000 # metres @@ -56,10 +58,10 @@ def _get_distance(start, end): class _GreeterServicer(helloworld_pb2_grpc.GreeterServicer): - def SayHello(self, request, context): return helloworld_pb2.HelloReply( - message='Hello, {}!'.format(request.name)) + message="Hello, {}!".format(request.name) + ) class _RouteGuideServicer(route_guide_pb2_grpc.RouteGuideServicer): @@ -81,10 +83,12 @@ def ListFeatures(self, request, context): top = max(request.lo.latitude, request.hi.latitude) bottom = min(request.lo.latitude, request.hi.latitude) for feature in self.db: - if (feature.location.longitude >= left and - feature.location.longitude <= right and - feature.location.latitude >= bottom and - feature.location.latitude <= top): + if ( + feature.location.longitude >= left + and feature.location.longitude <= right + and feature.location.latitude >= bottom + and feature.location.latitude <= top + ): yield feature def RecordRoute(self, request_iterator, context): @@ -103,10 +107,12 @@ def RecordRoute(self, request_iterator, context): prev_point = point elapsed_time = time.time() - start_time - return route_guide_pb2.RouteSummary(point_count=point_count, - feature_count=feature_count, - distance=int(distance), - elapsed_time=int(elapsed_time)) + return route_guide_pb2.RouteSummary( + point_count=point_count, + feature_count=feature_count, + distance=int(distance), + elapsed_time=int(elapsed_time), + ) def RouteChat(self, request_iterator, context): prev_notes = [] @@ -119,15 +125,17 @@ def RouteChat(self, request_iterator, context): def serve(): server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) - helloworld_pb2_grpc.add_GreeterServicer_to_server(_GreeterServicer(), - server) + helloworld_pb2_grpc.add_GreeterServicer_to_server( + _GreeterServicer(), server + ) route_guide_pb2_grpc.add_RouteGuideServicer_to_server( - _RouteGuideServicer(), server) - server.add_insecure_port('[::]:50051') + _RouteGuideServicer(), server + ) + server.add_insecure_port("[::]:50051") server.start() server.wait_for_termination() -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() serve() diff --git 
a/examples/python/multiplex/route_guide_resources.py b/examples/python/multiplex/route_guide_resources.py index ace85d6f9d431..7cee03adabc6c 100644 --- a/examples/python/multiplex/route_guide_resources.py +++ b/examples/python/multiplex/route_guide_resources.py @@ -21,10 +21,10 @@ def read_route_guide_database(): """Reads the route guide database. - Returns: - The full contents of the route guide database as a sequence of - route_guide_pb2.Features. - """ + Returns: + The full contents of the route guide database as a sequence of + route_guide_pb2.Features. + """ feature_list = [] with open("route_guide_db.json") as route_guide_db_file: for item in json.load(route_guide_db_file): @@ -32,6 +32,8 @@ def read_route_guide_database(): name=item["name"], location=route_guide_pb2.Point( latitude=item["location"]["latitude"], - longitude=item["location"]["longitude"])) + longitude=item["location"]["longitude"], + ), + ) feature_list.append(feature) return feature_list diff --git a/examples/python/multiplex/run_codegen.py b/examples/python/multiplex/run_codegen.py index be8915fe201fb..d3836699fe35c 100644 --- a/examples/python/multiplex/run_codegen.py +++ b/examples/python/multiplex/run_codegen.py @@ -15,17 +15,21 @@ from grpc_tools import protoc -protoc.main(( - '', - '-I../../protos', - '--python_out=.', - '--grpc_python_out=.', - '../../protos/helloworld.proto', -)) -protoc.main(( - '', - '-I../../protos', - '--python_out=.', - '--grpc_python_out=.', - '../../protos/route_guide.proto', -)) +protoc.main( + ( + "", + "-I../../protos", + "--python_out=.", + "--grpc_python_out=.", + "../../protos/helloworld.proto", + ) +) +protoc.main( + ( + "", + "-I../../protos", + "--python_out=.", + "--grpc_python_out=.", + "../../protos/route_guide.proto", + ) +) diff --git a/examples/python/multiprocessing/client.py b/examples/python/multiprocessing/client.py index 67681c7b9f205..17103fa4ddcfb 100644 --- a/examples/python/multiprocessing/client.py +++ b/examples/python/multiprocessing/client.py @@ -49,43 +49,50 @@ def _shutdown_worker(): def _initialize_worker(server_address): global _worker_channel_singleton # pylint: disable=global-statement global _worker_stub_singleton # pylint: disable=global-statement - _LOGGER.info('Initializing worker process.') + _LOGGER.info("Initializing worker process.") _worker_channel_singleton = grpc.insecure_channel(server_address) _worker_stub_singleton = prime_pb2_grpc.PrimeCheckerStub( - _worker_channel_singleton) + _worker_channel_singleton + ) atexit.register(_shutdown_worker) def _run_worker_query(primality_candidate): - _LOGGER.info('Checking primality of %s.', primality_candidate) + _LOGGER.info("Checking primality of %s.", primality_candidate) return _worker_stub_singleton.check( - prime_pb2.PrimeCandidate(candidate=primality_candidate)) + prime_pb2.PrimeCandidate(candidate=primality_candidate) + ) def _calculate_primes(server_address): - worker_pool = multiprocessing.Pool(processes=_PROCESS_COUNT, - initializer=_initialize_worker, - initargs=(server_address,)) + worker_pool = multiprocessing.Pool( + processes=_PROCESS_COUNT, + initializer=_initialize_worker, + initargs=(server_address,), + ) check_range = range(2, _MAXIMUM_CANDIDATE) primality = worker_pool.map(_run_worker_query, check_range) - primes = zip(check_range, map(operator.attrgetter('isPrime'), primality)) + primes = zip(check_range, map(operator.attrgetter("isPrime"), primality)) return tuple(primes) def main(): - msg = 'Determine the primality of the first {} integers.'.format( - _MAXIMUM_CANDIDATE) + msg 
= "Determine the primality of the first {} integers.".format( + _MAXIMUM_CANDIDATE + ) parser = argparse.ArgumentParser(description=msg) - parser.add_argument('server_address', - help='The address of the server (e.g. localhost:50051)') + parser.add_argument( + "server_address", + help="The address of the server (e.g. localhost:50051)", + ) args = parser.parse_args() primes = _calculate_primes(args.server_address) print(primes) -if __name__ == '__main__': +if __name__ == "__main__": handler = logging.StreamHandler(sys.stdout) - formatter = logging.Formatter('[PID %(process)d] %(message)s') + formatter = logging.Formatter("[PID %(process)d] %(message)s") handler.setFormatter(formatter) _LOGGER.addHandler(handler) _LOGGER.setLevel(logging.INFO) diff --git a/examples/python/multiprocessing/server.py b/examples/python/multiprocessing/server.py index 48e49eccd8da3..f5ddd5d01e5bb 100644 --- a/examples/python/multiprocessing/server.py +++ b/examples/python/multiprocessing/server.py @@ -47,9 +47,8 @@ def is_prime(n): class PrimeChecker(prime_pb2_grpc.PrimeCheckerServicer): - def check(self, request, context): - _LOGGER.info('Determining primality of %s', request.candidate) + _LOGGER.info("Determining primality of %s", request.candidate) return prime_pb2.Primality(isPrime=is_prime(request.candidate)) @@ -63,12 +62,15 @@ def _wait_forever(server): def _run_server(bind_address): """Start a server in a subprocess.""" - _LOGGER.info('Starting new server.') - options = (('grpc.so_reuseport', 1),) - - server = grpc.server(futures.ThreadPoolExecutor( - max_workers=_THREAD_CONCURRENCY,), - options=options) + _LOGGER.info("Starting new server.") + options = (("grpc.so_reuseport", 1),) + + server = grpc.server( + futures.ThreadPoolExecutor( + max_workers=_THREAD_CONCURRENCY, + ), + options=options, + ) prime_pb2_grpc.add_PrimeCheckerServicer_to_server(PrimeChecker(), server) server.add_insecure_port(bind_address) server.start() @@ -82,7 +84,7 @@ def _reserve_port(): sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 0: raise RuntimeError("Failed to set SO_REUSEPORT.") - sock.bind(('', 0)) + sock.bind(("", 0)) try: yield sock.getsockname()[1] finally: @@ -91,7 +93,7 @@ def _reserve_port(): def main(): with _reserve_port() as port: - bind_address = 'localhost:{}'.format(port) + bind_address = "localhost:{}".format(port) _LOGGER.info("Binding to '%s'", bind_address) sys.stdout.flush() workers = [] @@ -99,17 +101,18 @@ def main(): # NOTE: It is imperative that the worker subprocesses be forked before # any gRPC servers start up. See # https://github.com/grpc/grpc/issues/16001 for more details. 
- worker = multiprocessing.Process(target=_run_server, - args=(bind_address,)) + worker = multiprocessing.Process( + target=_run_server, args=(bind_address,) + ) worker.start() workers.append(worker) for worker in workers: worker.join() -if __name__ == '__main__': +if __name__ == "__main__": handler = logging.StreamHandler(sys.stdout) - formatter = logging.Formatter('[PID %(process)d] %(message)s') + formatter = logging.Formatter("[PID %(process)d] %(message)s") handler.setFormatter(formatter) _LOGGER.addHandler(handler) _LOGGER.setLevel(logging.INFO) diff --git a/examples/python/multiprocessing/test/_multiprocessing_example_test.py b/examples/python/multiprocessing/test/_multiprocessing_example_test.py index b8b7714125466..e323c6e67e1ec 100644 --- a/examples/python/multiprocessing/test/_multiprocessing_example_test.py +++ b/examples/python/multiprocessing/test/_multiprocessing_example_test.py @@ -23,9 +23,10 @@ import unittest _BINARY_DIR = os.path.realpath( - os.path.join(os.path.dirname(os.path.abspath(__file__)), '..')) -_SERVER_PATH = os.path.join(_BINARY_DIR, 'server') -_CLIENT_PATH = os.path.join(_BINARY_DIR, 'client') + os.path.join(os.path.dirname(os.path.abspath(__file__)), "..") +) +_SERVER_PATH = os.path.join(_BINARY_DIR, "server") +_CLIENT_PATH = os.path.join(_BINARY_DIR, "client") def is_prime(n): @@ -41,34 +42,35 @@ def _get_server_address(server_stream): server_stream.seek(0) line = server_stream.readline() while line: - matches = re.search('Binding to \'(.+)\'', line) + matches = re.search("Binding to '(.+)'", line) if matches is not None: return matches.groups()[0] line = server_stream.readline() class MultiprocessingExampleTest(unittest.TestCase): - def test_multiprocessing_example(self): - server_stdout = tempfile.TemporaryFile(mode='r') + server_stdout = tempfile.TemporaryFile(mode="r") server_process = subprocess.Popen((_SERVER_PATH,), stdout=server_stdout) server_address = _get_server_address(server_stdout) - client_stdout = tempfile.TemporaryFile(mode='r') - client_process = subprocess.Popen(( - _CLIENT_PATH, - server_address, - ), - stdout=client_stdout) + client_stdout = tempfile.TemporaryFile(mode="r") + client_process = subprocess.Popen( + ( + _CLIENT_PATH, + server_address, + ), + stdout=client_stdout, + ) client_process.wait() server_process.terminate() client_stdout.seek(0) - results = ast.literal_eval(client_stdout.read().strip().split('\n')[-1]) + results = ast.literal_eval(client_stdout.read().strip().split("\n")[-1]) values = tuple(result[0] for result in results) self.assertSequenceEqual(range(2, 10000), values) for result in results: self.assertEqual(is_prime(result[0]), result[1]) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main(verbosity=2) diff --git a/examples/python/no_codegen/greeter_client.py b/examples/python/no_codegen/greeter_client.py index c16d4db2576e9..47f1bca7358f0 100644 --- a/examples/python/no_codegen/greeter_client.py +++ b/examples/python/no_codegen/greeter_client.py @@ -35,7 +35,7 @@ logging.basicConfig() -response = services.Greeter.SayHello(protos.HelloRequest(name='you'), - 'localhost:50051', - insecure=True) +response = services.Greeter.SayHello( + protos.HelloRequest(name="you"), "localhost:50051", insecure=True +) print("Greeter client received: " + response.message) diff --git a/examples/python/no_codegen/greeter_server.py b/examples/python/no_codegen/greeter_server.py index 6ca052d377816..34d09a098c84f 100644 --- a/examples/python/no_codegen/greeter_server.py +++ 
b/examples/python/no_codegen/greeter_server.py @@ -22,19 +22,18 @@ class Greeter(services.GreeterServicer): - def SayHello(self, request, context): - return protos.HelloReply(message='Hello, %s!' % request.name) + return protos.HelloReply(message="Hello, %s!" % request.name) def serve(): server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) services.add_GreeterServicer_to_server(Greeter(), server) - server.add_insecure_port('[::]:50051') + server.add_insecure_port("[::]:50051") server.start() server.wait_for_termination() -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() serve() diff --git a/examples/python/retry/async_retry_client.py b/examples/python/retry/async_retry_client.py index faadf6a57d4f8..49cd12797fa7a 100644 --- a/examples/python/retry/async_retry_client.py +++ b/examples/python/retry/async_retry_client.py @@ -20,39 +20,44 @@ import grpc helloworld_pb2, helloworld_pb2_grpc = grpc.protos_and_services( - "helloworld.proto") + "helloworld.proto" +) async def run() -> None: # The ServiceConfig proto definition can be found: # https://github.com/grpc/grpc-proto/blob/ec886024c2f7b7f597ba89d5b7d60c3f94627b17/grpc/service_config/service_config.proto#L377 - service_config_json = json.dumps({ - "methodConfig": [{ - # To apply retry to all methods, put [{}] in the "name" field - "name": [{ - "service": "helloworld.Greeter", - "method": "SayHello" - }], - "retryPolicy": { - "maxAttempts": 5, - "initialBackoff": "0.1s", - "maxBackoff": "1s", - "backoffMultiplier": 2, - "retryableStatusCodes": ["UNAVAILABLE"], - }, - }] - }) + service_config_json = json.dumps( + { + "methodConfig": [ + { + # To apply retry to all methods, put [{}] in the "name" field + "name": [ + {"service": "helloworld.Greeter", "method": "SayHello"} + ], + "retryPolicy": { + "maxAttempts": 5, + "initialBackoff": "0.1s", + "maxBackoff": "1s", + "backoffMultiplier": 2, + "retryableStatusCodes": ["UNAVAILABLE"], + }, + } + ] + } + ) options = [] # NOTE: the retry feature will be enabled by default >=v1.40.0 options.append(("grpc.enable_retries", 1)) options.append(("grpc.service_config", service_config_json)) - async with grpc.aio.insecure_channel('localhost:50051', - options=options) as channel: + async with grpc.aio.insecure_channel( + "localhost:50051", options=options + ) as channel: stub = helloworld_pb2_grpc.GreeterStub(channel) - response = await stub.SayHello(helloworld_pb2.HelloRequest(name='you')) + response = await stub.SayHello(helloworld_pb2.HelloRequest(name="you")) print("Greeter client received: " + response.message) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() asyncio.run(run()) diff --git a/examples/python/retry/flaky_server.py b/examples/python/retry/flaky_server.py index caa166bcdb923..d38584f8bdf07 100644 --- a/examples/python/retry/flaky_server.py +++ b/examples/python/retry/flaky_server.py @@ -21,38 +21,42 @@ import grpc helloworld_pb2, helloworld_pb2_grpc = grpc.protos_and_services( - "helloworld.proto") + "helloworld.proto" +) class ErrorInjectingGreeter(helloworld_pb2_grpc.GreeterServicer): - def __init__(self): self._counter = collections.defaultdict(int) async def SayHello( - self, request: helloworld_pb2.HelloRequest, - context: grpc.aio.ServicerContext) -> helloworld_pb2.HelloReply: + self, + request: helloworld_pb2.HelloRequest, + context: grpc.aio.ServicerContext, + ) -> helloworld_pb2.HelloReply: self._counter[context.peer()] += 1 if self._counter[context.peer()] < 5: if random.random() < 0.75: - logging.info('Injecting 
error to RPC from %s', context.peer()) - await context.abort(grpc.StatusCode.UNAVAILABLE, - 'injected error') - logging.info('Successfully responding to RPC from %s', context.peer()) - return helloworld_pb2.HelloReply(message='Hello, %s!' % request.name) + logging.info("Injecting error to RPC from %s", context.peer()) + await context.abort( + grpc.StatusCode.UNAVAILABLE, "injected error" + ) + logging.info("Successfully responding to RPC from %s", context.peer()) + return helloworld_pb2.HelloReply(message="Hello, %s!" % request.name) async def serve() -> None: server = grpc.aio.server() - helloworld_pb2_grpc.add_GreeterServicer_to_server(ErrorInjectingGreeter(), - server) - listen_addr = '[::]:50051' + helloworld_pb2_grpc.add_GreeterServicer_to_server( + ErrorInjectingGreeter(), server + ) + listen_addr = "[::]:50051" server.add_insecure_port(listen_addr) logging.info("Starting flaky server on %s", listen_addr) await server.start() await server.wait_for_termination() -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.INFO) asyncio.run(serve()) diff --git a/examples/python/retry/retry_client.py b/examples/python/retry/retry_client.py index 486340e44d071..6547675f3034b 100644 --- a/examples/python/retry/retry_client.py +++ b/examples/python/retry/retry_client.py @@ -19,38 +19,42 @@ import grpc helloworld_pb2, helloworld_pb2_grpc = grpc.protos_and_services( - "helloworld.proto") + "helloworld.proto" +) def run(): # The ServiceConfig proto definition can be found: # https://github.com/grpc/grpc-proto/blob/ec886024c2f7b7f597ba89d5b7d60c3f94627b17/grpc/service_config/service_config.proto#L377 - service_config_json = json.dumps({ - "methodConfig": [{ - # To apply retry to all methods, put [{}] in the "name" field - "name": [{ - "service": "helloworld.Greeter", - "method": "SayHello" - }], - "retryPolicy": { - "maxAttempts": 5, - "initialBackoff": "0.1s", - "maxBackoff": "1s", - "backoffMultiplier": 2, - "retryableStatusCodes": ["UNAVAILABLE"], - }, - }] - }) + service_config_json = json.dumps( + { + "methodConfig": [ + { + # To apply retry to all methods, put [{}] in the "name" field + "name": [ + {"service": "helloworld.Greeter", "method": "SayHello"} + ], + "retryPolicy": { + "maxAttempts": 5, + "initialBackoff": "0.1s", + "maxBackoff": "1s", + "backoffMultiplier": 2, + "retryableStatusCodes": ["UNAVAILABLE"], + }, + } + ] + } + ) options = [] # NOTE: the retry feature will be enabled by default >=v1.40.0 options.append(("grpc.enable_retries", 1)) options.append(("grpc.service_config", service_config_json)) - with grpc.insecure_channel('localhost:50051', options=options) as channel: + with grpc.insecure_channel("localhost:50051", options=options) as channel: stub = helloworld_pb2_grpc.GreeterStub(channel) - response = stub.SayHello(helloworld_pb2.HelloRequest(name='you')) + response = stub.SayHello(helloworld_pb2.HelloRequest(name="you")) print("Greeter client received: " + response.message) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() run() diff --git a/examples/python/route_guide/asyncio_route_guide_client.py b/examples/python/route_guide/asyncio_route_guide_client.py index f387c860ef1bf..c874cc5369590 100644 --- a/examples/python/route_guide/asyncio_route_guide_client.py +++ b/examples/python/route_guide/asyncio_route_guide_client.py @@ -24,16 +24,19 @@ import route_guide_resources -def make_route_note(message: str, latitude: int, - longitude: int) -> route_guide_pb2.RouteNote: +def make_route_note( + message: str, 
latitude: int, longitude: int +) -> route_guide_pb2.RouteNote: return route_guide_pb2.RouteNote( message=message, - location=route_guide_pb2.Point(latitude=latitude, longitude=longitude)) + location=route_guide_pb2.Point(latitude=latitude, longitude=longitude), + ) # Performs an unary call -async def guide_get_one_feature(stub: route_guide_pb2_grpc.RouteGuideStub, - point: route_guide_pb2.Point) -> None: +async def guide_get_one_feature( + stub: route_guide_pb2_grpc.RouteGuideStub, point: route_guide_pb2.Point +) -> None: feature = await stub.GetFeature(point) if not feature.location: print("Server returned incomplete feature") @@ -50,20 +53,25 @@ async def guide_get_feature(stub: route_guide_pb2_grpc.RouteGuideStub) -> None: # and scheduled in the event loop so that they can run concurrently task_group = asyncio.gather( guide_get_one_feature( - stub, route_guide_pb2.Point(latitude=409146138, - longitude=-746188906)), - guide_get_one_feature(stub, - route_guide_pb2.Point(latitude=0, longitude=0))) + stub, + route_guide_pb2.Point(latitude=409146138, longitude=-746188906), + ), + guide_get_one_feature( + stub, route_guide_pb2.Point(latitude=0, longitude=0) + ), + ) # Wait until the Future is resolved await task_group # Performs a server-streaming call async def guide_list_features( - stub: route_guide_pb2_grpc.RouteGuideStub) -> None: + stub: route_guide_pb2_grpc.RouteGuideStub, +) -> None: rectangle = route_guide_pb2.Rectangle( lo=route_guide_pb2.Point(latitude=400000000, longitude=-750000000), - hi=route_guide_pb2.Point(latitude=420000000, longitude=-730000000)) + hi=route_guide_pb2.Point(latitude=420000000, longitude=-730000000), + ) print("Looking for features between 40, -75 and 42, -73") features = stub.ListFeatures(rectangle) @@ -73,7 +81,7 @@ async def guide_list_features( def generate_route( - feature_list: List[route_guide_pb2.Feature] + feature_list: List[route_guide_pb2.Feature], ) -> Iterable[route_guide_pb2.Point]: for _ in range(0, 10): random_feature = random.choice(feature_list) @@ -118,7 +126,7 @@ async def guide_route_chat(stub: route_guide_pb2_grpc.RouteGuideStub) -> None: async def main() -> None: - async with grpc.aio.insecure_channel('localhost:50051') as channel: + async with grpc.aio.insecure_channel("localhost:50051") as channel: stub = route_guide_pb2_grpc.RouteGuideStub(channel) print("-------------- GetFeature --------------") await guide_get_feature(stub) @@ -130,6 +138,6 @@ async def main() -> None: await guide_route_chat(stub) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.INFO) asyncio.get_event_loop().run_until_complete(main()) diff --git a/examples/python/route_guide/asyncio_route_guide_server.py b/examples/python/route_guide/asyncio_route_guide_server.py index b948da1712544..1c3901e47c101 100644 --- a/examples/python/route_guide/asyncio_route_guide_server.py +++ b/examples/python/route_guide/asyncio_route_guide_server.py @@ -25,8 +25,9 @@ import route_guide_resources -def get_feature(feature_db: Iterable[route_guide_pb2.Feature], - point: route_guide_pb2.Point) -> route_guide_pb2.Feature: +def get_feature( + feature_db: Iterable[route_guide_pb2.Feature], point: route_guide_pb2.Point +) -> route_guide_pb2.Feature: """Returns Feature at given location or None.""" for feature in feature_db: if feature.location == point: @@ -34,8 +35,9 @@ def get_feature(feature_db: Iterable[route_guide_pb2.Feature], return None -def get_distance(start: route_guide_pb2.Point, - end: route_guide_pb2.Point) -> float: +def get_distance( 
+ start: route_guide_pb2.Point, end: route_guide_pb2.Point +) -> float: """Distance between two points.""" coord_factor = 10000000.0 lat_1 = start.latitude / coord_factor @@ -48,9 +50,11 @@ def get_distance(start: route_guide_pb2.Point, delta_lon_rad = math.radians(lon_2 - lon_1) # Formula is based on http://mathforum.org/library/drmath/view/51879.html - a = (pow(math.sin(delta_lat_rad / 2), 2) + - (math.cos(lat_rad_1) * math.cos(lat_rad_2) * - pow(math.sin(delta_lon_rad / 2), 2))) + a = pow(math.sin(delta_lat_rad / 2), 2) + ( + math.cos(lat_rad_1) + * math.cos(lat_rad_2) + * pow(math.sin(delta_lon_rad / 2), 2) + ) c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a)) R = 6371000 # metres @@ -63,8 +67,9 @@ class RouteGuideServicer(route_guide_pb2_grpc.RouteGuideServicer): def __init__(self) -> None: self.db = route_guide_resources.read_route_guide_database() - def GetFeature(self, request: route_guide_pb2.Point, - unused_context) -> route_guide_pb2.Feature: + def GetFeature( + self, request: route_guide_pb2.Point, unused_context + ) -> route_guide_pb2.Feature: feature = get_feature(self.db, request) if feature is None: return route_guide_pb2.Feature(name="", location=request) @@ -72,21 +77,26 @@ def GetFeature(self, request: route_guide_pb2.Point, return feature async def ListFeatures( - self, request: route_guide_pb2.Rectangle, - unused_context) -> AsyncIterable[route_guide_pb2.Feature]: + self, request: route_guide_pb2.Rectangle, unused_context + ) -> AsyncIterable[route_guide_pb2.Feature]: left = min(request.lo.longitude, request.hi.longitude) right = max(request.lo.longitude, request.hi.longitude) top = max(request.lo.latitude, request.hi.latitude) bottom = min(request.lo.latitude, request.hi.latitude) for feature in self.db: - if (feature.location.longitude >= left and - feature.location.longitude <= right and - feature.location.latitude >= bottom and - feature.location.latitude <= top): + if ( + feature.location.longitude >= left + and feature.location.longitude <= right + and feature.location.latitude >= bottom + and feature.location.latitude <= top + ): yield feature - async def RecordRoute(self, request_iterator: AsyncIterable[ - route_guide_pb2.Point], unused_context) -> route_guide_pb2.RouteSummary: + async def RecordRoute( + self, + request_iterator: AsyncIterable[route_guide_pb2.Point], + unused_context, + ) -> route_guide_pb2.RouteSummary: point_count = 0 feature_count = 0 distance = 0.0 @@ -102,14 +112,18 @@ async def RecordRoute(self, request_iterator: AsyncIterable[ prev_point = point elapsed_time = time.time() - start_time - return route_guide_pb2.RouteSummary(point_count=point_count, - feature_count=feature_count, - distance=int(distance), - elapsed_time=int(elapsed_time)) + return route_guide_pb2.RouteSummary( + point_count=point_count, + feature_count=feature_count, + distance=int(distance), + elapsed_time=int(elapsed_time), + ) async def RouteChat( - self, request_iterator: AsyncIterable[route_guide_pb2.RouteNote], - unused_context) -> AsyncIterable[route_guide_pb2.RouteNote]: + self, + request_iterator: AsyncIterable[route_guide_pb2.RouteNote], + unused_context, + ) -> AsyncIterable[route_guide_pb2.RouteNote]: prev_notes = [] async for new_note in request_iterator: for prev_note in prev_notes: @@ -121,12 +135,13 @@ async def RouteChat( async def serve() -> None: server = grpc.aio.server() route_guide_pb2_grpc.add_RouteGuideServicer_to_server( - RouteGuideServicer(), server) - server.add_insecure_port('[::]:50051') + RouteGuideServicer(), server + ) + 
server.add_insecure_port("[::]:50051") await server.start() await server.wait_for_termination() -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.INFO) asyncio.get_event_loop().run_until_complete(serve()) diff --git a/examples/python/route_guide/route_guide_client.py b/examples/python/route_guide/route_guide_client.py index c7383d3b6ec1b..2210be6c1bfd3 100644 --- a/examples/python/route_guide/route_guide_client.py +++ b/examples/python/route_guide/route_guide_client.py @@ -27,7 +27,8 @@ def make_route_note(message, latitude, longitude): return route_guide_pb2.RouteNote( message=message, - location=route_guide_pb2.Point(latitude=latitude, longitude=longitude)) + location=route_guide_pb2.Point(latitude=latitude, longitude=longitude), + ) def guide_get_one_feature(stub, point): @@ -44,14 +45,16 @@ def guide_get_one_feature(stub, point): def guide_get_feature(stub): guide_get_one_feature( - stub, route_guide_pb2.Point(latitude=409146138, longitude=-746188906)) + stub, route_guide_pb2.Point(latitude=409146138, longitude=-746188906) + ) guide_get_one_feature(stub, route_guide_pb2.Point(latitude=0, longitude=0)) def guide_list_features(stub): rectangle = route_guide_pb2.Rectangle( lo=route_guide_pb2.Point(latitude=400000000, longitude=-750000000), - hi=route_guide_pb2.Point(latitude=420000000, longitude=-730000000)) + hi=route_guide_pb2.Point(latitude=420000000, longitude=-730000000), + ) print("Looking for features between 40, -75 and 42, -73") features = stub.ListFeatures(rectangle) @@ -94,15 +97,16 @@ def generate_messages(): def guide_route_chat(stub): responses = stub.RouteChat(generate_messages()) for response in responses: - print("Received message %s at %s" % - (response.message, response.location)) + print( + "Received message %s at %s" % (response.message, response.location) + ) def run(): # NOTE(gRPC Python Team): .close() is possible on a channel and should be # used in circumstances in which the with statement does not fit the needs # of the code. - with grpc.insecure_channel('localhost:50051') as channel: + with grpc.insecure_channel("localhost:50051") as channel: stub = route_guide_pb2_grpc.RouteGuideStub(channel) print("-------------- GetFeature --------------") guide_get_feature(stub) @@ -114,6 +118,6 @@ def run(): guide_route_chat(stub) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() run() diff --git a/examples/python/route_guide/route_guide_resources.py b/examples/python/route_guide/route_guide_resources.py index ace85d6f9d431..7cee03adabc6c 100644 --- a/examples/python/route_guide/route_guide_resources.py +++ b/examples/python/route_guide/route_guide_resources.py @@ -21,10 +21,10 @@ def read_route_guide_database(): """Reads the route guide database. - Returns: - The full contents of the route guide database as a sequence of - route_guide_pb2.Features. - """ + Returns: + The full contents of the route guide database as a sequence of + route_guide_pb2.Features. 
+ """ feature_list = [] with open("route_guide_db.json") as route_guide_db_file: for item in json.load(route_guide_db_file): @@ -32,6 +32,8 @@ def read_route_guide_database(): name=item["name"], location=route_guide_pb2.Point( latitude=item["location"]["latitude"], - longitude=item["location"]["longitude"])) + longitude=item["location"]["longitude"], + ), + ) feature_list.append(feature) return feature_list diff --git a/examples/python/route_guide/route_guide_server.py b/examples/python/route_guide/route_guide_server.py index e58f2ad299c16..f44ac43f778c0 100644 --- a/examples/python/route_guide/route_guide_server.py +++ b/examples/python/route_guide/route_guide_server.py @@ -45,9 +45,11 @@ def get_distance(start, end): delta_lon_rad = math.radians(lon_2 - lon_1) # Formula is based on http://mathforum.org/library/drmath/view/51879.html - a = (pow(math.sin(delta_lat_rad / 2), 2) + - (math.cos(lat_rad_1) * math.cos(lat_rad_2) * - pow(math.sin(delta_lon_rad / 2), 2))) + a = pow(math.sin(delta_lat_rad / 2), 2) + ( + math.cos(lat_rad_1) + * math.cos(lat_rad_2) + * pow(math.sin(delta_lon_rad / 2), 2) + ) c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a)) R = 6371000 # metres @@ -73,10 +75,12 @@ def ListFeatures(self, request, context): top = max(request.lo.latitude, request.hi.latitude) bottom = min(request.lo.latitude, request.hi.latitude) for feature in self.db: - if (feature.location.longitude >= left and - feature.location.longitude <= right and - feature.location.latitude >= bottom and - feature.location.latitude <= top): + if ( + feature.location.longitude >= left + and feature.location.longitude <= right + and feature.location.latitude >= bottom + and feature.location.latitude <= top + ): yield feature def RecordRoute(self, request_iterator, context): @@ -95,10 +99,12 @@ def RecordRoute(self, request_iterator, context): prev_point = point elapsed_time = time.time() - start_time - return route_guide_pb2.RouteSummary(point_count=point_count, - feature_count=feature_count, - distance=int(distance), - elapsed_time=int(elapsed_time)) + return route_guide_pb2.RouteSummary( + point_count=point_count, + feature_count=feature_count, + distance=int(distance), + elapsed_time=int(elapsed_time), + ) def RouteChat(self, request_iterator, context): prev_notes = [] @@ -112,12 +118,13 @@ def RouteChat(self, request_iterator, context): def serve(): server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) route_guide_pb2_grpc.add_RouteGuideServicer_to_server( - RouteGuideServicer(), server) - server.add_insecure_port('[::]:50051') + RouteGuideServicer(), server + ) + server.add_insecure_port("[::]:50051") server.start() server.wait_for_termination() -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() serve() diff --git a/examples/python/route_guide/run_codegen.py b/examples/python/route_guide/run_codegen.py index 8df562d3497b1..aa237cd8c9738 100644 --- a/examples/python/route_guide/run_codegen.py +++ b/examples/python/route_guide/run_codegen.py @@ -15,10 +15,12 @@ from grpc_tools import protoc -protoc.main(( - '', - '-I../../protos', - '--python_out=.', - '--grpc_python_out=.', - '../../protos/route_guide.proto', -)) +protoc.main( + ( + "", + "-I../../protos", + "--python_out=.", + "--grpc_python_out=.", + "../../protos/route_guide.proto", + ) +) diff --git a/examples/python/timeout/greeter_client.py b/examples/python/timeout/greeter_client.py index d1487994ac575..2420275bf19b3 100644 --- a/examples/python/timeout/greeter_client.py +++ 
b/examples/python/timeout/greeter_client.py @@ -20,26 +20,28 @@ import helloworld_pb2_grpc -def unary_call(stub: helloworld_pb2_grpc.GreeterStub, request_id: int, - message: str): +def unary_call( + stub: helloworld_pb2_grpc.GreeterStub, request_id: int, message: str +): print("call:", request_id) try: - response = stub.SayHello(helloworld_pb2.HelloRequest(name=message), - timeout=3) + response = stub.SayHello( + helloworld_pb2.HelloRequest(name=message), timeout=3 + ) print(f"Greeter client received: {response.message}") except grpc.RpcError as rpc_error: print(f"Call failed with code: {rpc_error.code()}") def run(): - with grpc.insecure_channel('localhost:50051') as channel: + with grpc.insecure_channel("localhost:50051") as channel: stub = helloworld_pb2_grpc.GreeterStub(channel) # Should success - unary_call(stub, 1, 'you') + unary_call(stub, 1, "you") # Should fail with DEADLINE_EXCEEDED - unary_call(stub, 2, '[delay] you') + unary_call(stub, 2, "[delay] you") -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() run() diff --git a/examples/python/timeout/greeter_server.py b/examples/python/timeout/greeter_server.py index 14b050b926ba0..6506b85dd9821 100644 --- a/examples/python/timeout/greeter_server.py +++ b/examples/python/timeout/greeter_server.py @@ -23,7 +23,6 @@ class Greeter(helloworld_pb2_grpc.GreeterServicer): - def SayHello(self, request, context): message = request.name if message.startswith("[delay]"): @@ -32,15 +31,15 @@ def SayHello(self, request, context): def serve(): - port = '50051' + port = "50051" server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) helloworld_pb2_grpc.add_GreeterServicer_to_server(Greeter(), server) - server.add_insecure_port('[::]:' + port) + server.add_insecure_port("[::]:" + port) server.start() print("Server started, listening on " + port) server.wait_for_termination() -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() serve() diff --git a/examples/python/uds/async_greeter_client.py b/examples/python/uds/async_greeter_client.py index 7c386cc84f1ef..1513032419c26 100644 --- a/examples/python/uds/async_greeter_client.py +++ b/examples/python/uds/async_greeter_client.py @@ -22,15 +22,16 @@ async def run() -> None: - uds_addresses = ['unix:helloworld.sock', 'unix:///tmp/helloworld.sock'] + uds_addresses = ["unix:helloworld.sock", "unix:///tmp/helloworld.sock"] for uds_address in uds_addresses: async with grpc.aio.insecure_channel(uds_address) as channel: stub = helloworld_pb2_grpc.GreeterStub(channel) response = await stub.SayHello( - helloworld_pb2.HelloRequest(name='you')) - logging.info('Received: %s', response.message) + helloworld_pb2.HelloRequest(name="you") + ) + logging.info("Received: %s", response.message) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.INFO) asyncio.run(run()) diff --git a/examples/python/uds/async_greeter_server.py b/examples/python/uds/async_greeter_server.py index b27d6aaa9de51..f55b55ceeaf6a 100644 --- a/examples/python/uds/async_greeter_server.py +++ b/examples/python/uds/async_greeter_server.py @@ -22,25 +22,26 @@ class Greeter(helloworld_pb2_grpc.GreeterServicer): - async def SayHello( - self, request: helloworld_pb2.HelloRequest, - context: grpc.aio.ServicerContext) -> helloworld_pb2.HelloReply: + self, + request: helloworld_pb2.HelloRequest, + context: grpc.aio.ServicerContext, + ) -> helloworld_pb2.HelloReply: del request - return helloworld_pb2.HelloReply(message=f'Hello to {context.peer()}!') + return 
helloworld_pb2.HelloReply(message=f"Hello to {context.peer()}!") async def serve() -> None: - uds_addresses = ['unix:helloworld.sock', 'unix:///tmp/helloworld.sock'] + uds_addresses = ["unix:helloworld.sock", "unix:///tmp/helloworld.sock"] server = grpc.aio.server() helloworld_pb2_grpc.add_GreeterServicer_to_server(Greeter(), server) for uds_address in uds_addresses: server.add_insecure_port(uds_address) - logging.info('Server listening on: %s', uds_address) + logging.info("Server listening on: %s", uds_address) await server.start() await server.wait_for_termination() -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.INFO) asyncio.run(serve()) diff --git a/examples/python/uds/greeter_client.py b/examples/python/uds/greeter_client.py index 4acb9de8aa6b7..b8cc51e685877 100644 --- a/examples/python/uds/greeter_client.py +++ b/examples/python/uds/greeter_client.py @@ -23,14 +23,14 @@ def run(): - uds_addresses = ['unix:helloworld.sock', 'unix:///tmp/helloworld.sock'] + uds_addresses = ["unix:helloworld.sock", "unix:///tmp/helloworld.sock"] for uds_address in uds_addresses: with grpc.insecure_channel(uds_address) as channel: stub = helloworld_pb2_grpc.GreeterStub(channel) - response = stub.SayHello(helloworld_pb2.HelloRequest(name='you')) - logging.info('Received: %s', response.message) + response = stub.SayHello(helloworld_pb2.HelloRequest(name="you")) + logging.info("Received: %s", response.message) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.INFO) run() diff --git a/examples/python/uds/greeter_server.py b/examples/python/uds/greeter_server.py index 9288cb9551136..aca3c877f312c 100644 --- a/examples/python/uds/greeter_server.py +++ b/examples/python/uds/greeter_server.py @@ -22,23 +22,22 @@ class Greeter(helloworld_pb2_grpc.GreeterServicer): - def SayHello(self, request, context): del request - return helloworld_pb2.HelloReply(message=f'Hello to {context.peer()}!') + return helloworld_pb2.HelloReply(message=f"Hello to {context.peer()}!") def serve(): - uds_addresses = ['unix:helloworld.sock', 'unix:///tmp/helloworld.sock'] + uds_addresses = ["unix:helloworld.sock", "unix:///tmp/helloworld.sock"] server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) helloworld_pb2_grpc.add_GreeterServicer_to_server(Greeter(), server) for uds_address in uds_addresses: server.add_insecure_port(uds_address) - logging.info('Server listening on: %s', uds_address) + logging.info("Server listening on: %s", uds_address) server.start() server.wait_for_termination() -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.INFO) serve() diff --git a/examples/python/wait_for_ready/asyncio_wait_for_ready_example.py b/examples/python/wait_for_ready/asyncio_wait_for_ready_example.py index 97922062c1329..ffacddfca8078 100644 --- a/examples/python/wait_for_ready/asyncio_wait_for_ready_example.py +++ b/examples/python/wait_for_ready/asyncio_wait_for_ready_example.py @@ -22,7 +22,8 @@ import grpc helloworld_pb2, helloworld_pb2_grpc = grpc.protos_and_services( - "helloworld.proto") + "helloworld.proto" +) _LOGGER = logging.getLogger(__name__) _LOGGER.setLevel(logging.INFO) @@ -34,32 +35,35 @@ def get_free_loopback_tcp_port() -> Iterable[str]: tcp_socket = socket.socket(socket.AF_INET6) else: tcp_socket = socket.socket(socket.AF_INET) - tcp_socket.bind(('', 0)) + tcp_socket.bind(("", 0)) address_tuple = tcp_socket.getsockname() yield f"localhost:{address_tuple[1]}" tcp_socket.close() class 
Greeter(helloworld_pb2_grpc.GreeterServicer): - - async def SayHello(self, request: helloworld_pb2.HelloRequest, - unused_context) -> helloworld_pb2.HelloReply: - return helloworld_pb2.HelloReply(message=f'Hello, {request.name}!') + async def SayHello( + self, request: helloworld_pb2.HelloRequest, unused_context + ) -> helloworld_pb2.HelloReply: + return helloworld_pb2.HelloReply(message=f"Hello, {request.name}!") def create_server(server_address: str): server = grpc.aio.server() helloworld_pb2_grpc.add_GreeterServicer_to_server(Greeter(), server) bound_port = server.add_insecure_port(server_address) - assert bound_port == int(server_address.split(':')[-1]) + assert bound_port == int(server_address.split(":")[-1]) return server -async def process(stub: helloworld_pb2_grpc.GreeterStub, - wait_for_ready: bool = None) -> None: +async def process( + stub: helloworld_pb2_grpc.GreeterStub, wait_for_ready: bool = None +) -> None: try: - response = await stub.SayHello(helloworld_pb2.HelloRequest(name='you'), - wait_for_ready=wait_for_ready) + response = await stub.SayHello( + helloworld_pb2.HelloRequest(name="you"), + wait_for_ready=wait_for_ready, + ) message = response.message except grpc.aio.AioRpcError as rpc_error: assert rpc_error.code() == grpc.StatusCode.UNAVAILABLE @@ -67,8 +71,11 @@ async def process(stub: helloworld_pb2_grpc.GreeterStub, message = rpc_error else: assert wait_for_ready - _LOGGER.info("Wait-for-ready %s, client received: %s", - "enabled" if wait_for_ready else "disabled", message) + _LOGGER.info( + "Wait-for-ready %s, client received: %s", + "enabled" if wait_for_ready else "disabled", + message, + ) async def main() -> None: @@ -80,10 +87,12 @@ async def main() -> None: # Fire an RPC without wait_for_ready fail_fast_task = asyncio.get_event_loop().create_task( - process(stub, wait_for_ready=False)) + process(stub, wait_for_ready=False) + ) # Fire an RPC with wait_for_ready wait_for_ready_task = asyncio.get_event_loop().create_task( - process(stub, wait_for_ready=True)) + process(stub, wait_for_ready=True) + ) # Wait for the channel entering TRANSIENT FAILURE state. state = channel.get_state() @@ -104,6 +113,6 @@ async def main() -> None: await channel.close() -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.INFO) asyncio.get_event_loop().run_until_complete(main()) diff --git a/examples/python/wait_for_ready/test/_wait_for_ready_example_test.py b/examples/python/wait_for_ready/test/_wait_for_ready_example_test.py index ebc8164ad30ee..6a433024064c2 100644 --- a/examples/python/wait_for_ready/test/_wait_for_ready_example_test.py +++ b/examples/python/wait_for_ready/test/_wait_for_ready_example_test.py @@ -22,17 +22,17 @@ class WaitForReadyExampleTest(unittest.TestCase): - def test_wait_for_ready_example(self): wait_for_ready_example.main() # No unhandled exception raised, no deadlock, test passed! def test_asyncio_wait_for_ready_example(self): asyncio.get_event_loop().run_until_complete( - asyncio_wait_for_ready_example.main()) + asyncio_wait_for_ready_example.main() + ) # No unhandled exception raised, no deadlock, test passed! 
-if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main(verbosity=2) diff --git a/examples/python/wait_for_ready/wait_for_ready_example.py b/examples/python/wait_for_ready/wait_for_ready_example.py index c3fa90df2b893..8aca4a591ec32 100644 --- a/examples/python/wait_for_ready/wait_for_ready_example.py +++ b/examples/python/wait_for_ready/wait_for_ready_example.py @@ -22,7 +22,8 @@ import grpc helloworld_pb2, helloworld_pb2_grpc = grpc.protos_and_services( - "helloworld.proto") + "helloworld.proto" +) _LOGGER = logging.getLogger(__name__) _LOGGER.setLevel(logging.INFO) @@ -34,30 +35,31 @@ def get_free_loopback_tcp_port(): tcp_socket = socket.socket(socket.AF_INET6) else: tcp_socket = socket.socket(socket.AF_INET) - tcp_socket.bind(('', 0)) + tcp_socket.bind(("", 0)) address_tuple = tcp_socket.getsockname() yield "localhost:%s" % (address_tuple[1]) tcp_socket.close() class Greeter(helloworld_pb2_grpc.GreeterServicer): - def SayHello(self, request, unused_context): - return helloworld_pb2.HelloReply(message='Hello, %s!' % request.name) + return helloworld_pb2.HelloReply(message="Hello, %s!" % request.name) def create_server(server_address): server = grpc.server(futures.ThreadPoolExecutor()) helloworld_pb2_grpc.add_GreeterServicer_to_server(Greeter(), server) bound_port = server.add_insecure_port(server_address) - assert bound_port == int(server_address.split(':')[-1]) + assert bound_port == int(server_address.split(":")[-1]) return server def process(stub, wait_for_ready=None): try: - response = stub.SayHello(helloworld_pb2.HelloRequest(name='you'), - wait_for_ready=wait_for_ready) + response = stub.SayHello( + helloworld_pb2.HelloRequest(name="you"), + wait_for_ready=wait_for_ready, + ) message = response.message except grpc.RpcError as rpc_error: assert rpc_error.code() == grpc.StatusCode.UNAVAILABLE @@ -65,19 +67,24 @@ def process(stub, wait_for_ready=None): message = rpc_error else: assert wait_for_ready - _LOGGER.info("Wait-for-ready %s, client received: %s", - "enabled" if wait_for_ready else "disabled", message) + _LOGGER.info( + "Wait-for-ready %s, client received: %s", + "enabled" if wait_for_ready else "disabled", + message, + ) def main(): # Pick a random free port with get_free_loopback_tcp_port() as server_address: - # Register connectivity event to notify main thread transient_failure_event = threading.Event() def wait_for_transient_failure(channel_connectivity): - if channel_connectivity == grpc.ChannelConnectivity.TRANSIENT_FAILURE: + if ( + channel_connectivity + == grpc.ChannelConnectivity.TRANSIENT_FAILURE + ): transient_failure_event.set() # Create gRPC channel @@ -86,12 +93,14 @@ def wait_for_transient_failure(channel_connectivity): stub = helloworld_pb2_grpc.GreeterStub(channel) # Fire an RPC without wait_for_ready - thread_disabled_wait_for_ready = threading.Thread(target=process, - args=(stub, False)) + thread_disabled_wait_for_ready = threading.Thread( + target=process, args=(stub, False) + ) thread_disabled_wait_for_ready.start() # Fire an RPC with wait_for_ready - thread_enabled_wait_for_ready = threading.Thread(target=process, - args=(stub, True)) + thread_enabled_wait_for_ready = threading.Thread( + target=process, args=(stub, True) + ) thread_enabled_wait_for_ready.start() # Wait for the channel entering TRANSIENT FAILURE state. 
@@ -108,6 +117,6 @@ def wait_for_transient_failure(channel_connectivity): channel.close() -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.INFO) main() diff --git a/examples/python/wait_for_ready/wait_for_ready_with_client_timeout_example_client.py b/examples/python/wait_for_ready/wait_for_ready_with_client_timeout_example_client.py index f911de2f91600..0c0d3cb16ad23 100644 --- a/examples/python/wait_for_ready/wait_for_ready_with_client_timeout_example_client.py +++ b/examples/python/wait_for_ready/wait_for_ready_with_client_timeout_example_client.py @@ -28,7 +28,8 @@ import grpc helloworld_pb2, helloworld_pb2_grpc = grpc.protos_and_services( - "helloworld.proto") + "helloworld.proto" +) _LOGGER = logging.getLogger(__name__) _LOGGER.setLevel(logging.INFO) @@ -37,8 +38,10 @@ def wait_for_metadata(response_future, event): metadata: Sequence[Tuple[str, str]] = response_future.initial_metadata() for key, value in metadata: - print('Greeter client received initial metadata: key=%s value=%s' % - (key, value)) + print( + "Greeter client received initial metadata: key=%s value=%s" + % (key, value) + ) event.set() @@ -55,20 +58,22 @@ def check_status(response_future, wait_success): def main(): # Create gRPC channel - with grpc.insecure_channel('localhost:50051') as channel: + with grpc.insecure_channel("localhost:50051") as channel: stub = helloworld_pb2_grpc.GreeterStub(channel) event_for_delay = threading.Event() # Server will delay send initial metadata back for this RPC response_future_delay = stub.SayHelloStreamReply( - helloworld_pb2.HelloRequest(name='you'), wait_for_ready=True) + helloworld_pb2.HelloRequest(name="you"), wait_for_ready=True + ) # Fire RPC and wait for metadata - thread_with_delay = threading.Thread(target=wait_for_metadata, - args=(response_future_delay, - event_for_delay), - daemon=True) + thread_with_delay = threading.Thread( + target=wait_for_metadata, + args=(response_future_delay, event_for_delay), + daemon=True, + ) thread_with_delay.start() # Wait on client side with 7 seconds timeout @@ -76,6 +81,6 @@ def main(): check_status(response_future_delay, event_for_delay.wait(timeout)) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.INFO) main() diff --git a/examples/python/wait_for_ready/wait_for_ready_with_client_timeout_example_server.py b/examples/python/wait_for_ready/wait_for_ready_with_client_timeout_example_server.py index 92754debd1f7c..7259d25e94f78 100644 --- a/examples/python/wait_for_ready/wait_for_ready_with_client_timeout_example_server.py +++ b/examples/python/wait_for_ready/wait_for_ready_with_client_timeout_example_server.py @@ -29,9 +29,10 @@ import grpc helloworld_pb2, helloworld_pb2_grpc = grpc.protos_and_services( - "helloworld.proto") + "helloworld.proto" +) -_INITIAL_METADATA = ((b'initial-md', 'initial-md-value'),) +_INITIAL_METADATA = ((b"initial-md", "initial-md-value"),) def starting_up_server(): @@ -45,7 +46,6 @@ def do_work(): class Greeter(helloworld_pb2_grpc.GreeterServicer): - def SayHelloStreamReply(self, request, servicer_context): # Suppose server will take some time to setup, client can set the time it willing to wait # for server to up and running. @@ -60,19 +60,20 @@ def SayHelloStreamReply(self, request, servicer_context): # Sending actual response. 
for i in range(3): - yield helloworld_pb2.HelloReply(message='Hello %s times %s' % - (request.name, i)) + yield helloworld_pb2.HelloReply( + message="Hello %s times %s" % (request.name, i) + ) def serve(): server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) helloworld_pb2_grpc.add_GreeterServicer_to_server(Greeter(), server) - server.add_insecure_port('[::]:50051') + server.add_insecure_port("[::]:50051") print("starting server") server.start() server.wait_for_termination() -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() serve() diff --git a/examples/python/xds/client.py b/examples/python/xds/client.py index b422dcdc0cbd2..4a56573eeee47 100644 --- a/examples/python/xds/client.py +++ b/examples/python/xds/client.py @@ -35,19 +35,20 @@ def run(server_address, secure): channel = grpc.insecure_channel(server_address) with channel: stub = helloworld_pb2_grpc.GreeterStub(channel) - response = stub.SayHello(helloworld_pb2.HelloRequest(name='you')) + response = stub.SayHello(helloworld_pb2.HelloRequest(name="you")) print("Greeter client received: " + response.message) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser(description=_DESCRIPTION) - parser.add_argument("server", - default=None, - help="The address of the server.") + parser.add_argument( + "server", default=None, help="The address of the server." + ) parser.add_argument( "--xds-creds", action="store_true", - help="If specified, uses xDS credentials to connect to the server.") + help="If specified, uses xDS credentials to connect to the server.", + ) args = parser.parse_args() logging.basicConfig() run(args.server, args.xds_creds) diff --git a/examples/python/xds/server.py b/examples/python/xds/server.py index cedcffc2338ed..bea103876bfe0 100644 --- a/examples/python/xds/server.py +++ b/examples/python/xds/server.py @@ -34,24 +34,28 @@ logger = logging.getLogger() console_handler = logging.StreamHandler() -formatter = logging.Formatter(fmt='%(asctime)s: %(levelname)-8s %(message)s') +formatter = logging.Formatter(fmt="%(asctime)s: %(levelname)-8s %(message)s") console_handler.setFormatter(formatter) logger.addHandler(console_handler) class Greeter(helloworld_pb2_grpc.GreeterServicer): - def __init__(self, hostname: str): self._hostname = hostname if hostname else socket.gethostname() - def SayHello(self, request: helloworld_pb2.HelloRequest, - context: grpc.ServicerContext) -> helloworld_pb2.HelloReply: + def SayHello( + self, + request: helloworld_pb2.HelloRequest, + context: grpc.ServicerContext, + ) -> helloworld_pb2.HelloReply: return helloworld_pb2.HelloReply( - message=f"Hello {request.name} from {self._hostname}!") + message=f"Hello {request.name} from {self._hostname}!" + ) -def _configure_maintenance_server(server: grpc.Server, - maintenance_port: int) -> None: +def _configure_maintenance_server( + server: grpc.Server, maintenance_port: int +) -> None: listen_address = f"{_LISTEN_HOST}:{maintenance_port}" server.add_insecure_port(listen_address) @@ -60,13 +64,15 @@ def _configure_maintenance_server(server: grpc.Server, health_servicer = health.HealthServicer( experimental_non_blocking=True, experimental_thread_pool=futures.ThreadPoolExecutor( - max_workers=_THREAD_POOL_SIZE)) + max_workers=_THREAD_POOL_SIZE + ), + ) # Create a tuple of all of the services we want to export via reflection. 
services = tuple( service.full_name - for service in helloworld_pb2.DESCRIPTOR.services_by_name.values()) + ( - reflection.SERVICE_NAME, health.SERVICE_NAME) + for service in helloworld_pb2.DESCRIPTOR.services_by_name.values() + ) + (reflection.SERVICE_NAME, health.SERVICE_NAME) # Mark all services as healthy. health_pb2_grpc.add_HealthServicer_to_server(health_servicer, server) @@ -75,8 +81,9 @@ def _configure_maintenance_server(server: grpc.Server, reflection.enable_server_reflection(services, server) -def _configure_greeter_server(server: grpc.Server, port: int, secure_mode: bool, - hostname) -> None: +def _configure_greeter_server( + server: grpc.Server, port: int, secure_mode: bool, hostname +) -> None: # Add the application servicer to the server. helloworld_pb2_grpc.add_GreeterServicer_to_server(Greeter(hostname), server) listen_address = f"{_LISTEN_HOST}:{port}" @@ -92,12 +99,14 @@ def _configure_greeter_server(server: grpc.Server, port: int, secure_mode: bool, server.add_secure_port(listen_address, server_creds) -def serve(port: int, hostname: str, maintenance_port: int, - secure_mode: bool) -> None: +def serve( + port: int, hostname: str, maintenance_port: int, secure_mode: bool +) -> None: if port == maintenance_port: # If maintenance port and port are the same, start a single server. server = grpc.server( - futures.ThreadPoolExecutor(max_workers=_THREAD_POOL_SIZE)) + futures.ThreadPoolExecutor(max_workers=_THREAD_POOL_SIZE) + ) _configure_greeter_server(server, port, secure_mode, hostname) _configure_maintenance_server(server, maintenance_port) server.start() @@ -108,12 +117,14 @@ def serve(port: int, hostname: str, maintenance_port: int, # Otherwise, start two different servers. greeter_server = grpc.server( futures.ThreadPoolExecutor(max_workers=_THREAD_POOL_SIZE), - xds=secure_mode) + xds=secure_mode, + ) _configure_greeter_server(greeter_server, port, secure_mode, hostname) greeter_server.start() logger.info("Greeter server listening on port %d", port) maintenance_server = grpc.server( - futures.ThreadPoolExecutor(max_workers=_THREAD_POOL_SIZE)) + futures.ThreadPoolExecutor(max_workers=_THREAD_POOL_SIZE) + ) _configure_maintenance_server(maintenance_server, maintenance_port) maintenance_server.start() logger.info("Maintenance server listening on port %d", maintenance_port) @@ -121,22 +132,27 @@ def serve(port: int, hostname: str, maintenance_port: int, maintenance_server.wait_for_termination() -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser(description=_DESCRIPTION) - parser.add_argument("port", - default=50051, - type=int, - nargs="?", - help="The port on which to listen.") - parser.add_argument("hostname", - type=str, - default=None, - nargs="?", - help="The name clients will see in responses.") + parser.add_argument( + "port", + default=50051, + type=int, + nargs="?", + help="The port on which to listen.", + ) + parser.add_argument( + "hostname", + type=str, + default=None, + nargs="?", + help="The name clients will see in responses.", + ) parser.add_argument( "--xds-creds", action="store_true", - help="If specified, uses xDS credentials to connect to the server.") + help="If specified, uses xDS credentials to connect to the server.", + ) args = parser.parse_args() logging.basicConfig() logger.setLevel(logging.INFO) diff --git a/setup.cfg b/setup.cfg index 7bf1b5e4236bd..3ac6b513ae3a9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -12,10 +12,6 @@ inplace=1 [build_package_protos] exclude=.*protoc_plugin/protoc_plugin_test\.proto$ -# 
Style settings -[yapf] -based_on_style = google - [metadata] license_files = LICENSE diff --git a/setup.py b/setup.py index 6c07738ff0e8c..57364bfb04495 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ # files used by boring SSL. from distutils.unixccompiler import UnixCCompiler -UnixCCompiler.src_extensions.append('.S') +UnixCCompiler.src_extensions.append(".S") del UnixCCompiler from distutils import cygwinccompiler @@ -44,42 +44,46 @@ from setuptools.command import egg_info # Redirect the manifest template from MANIFEST.in to PYTHON-MANIFEST.in. -egg_info.manifest_maker.template = 'PYTHON-MANIFEST.in' +egg_info.manifest_maker.template = "PYTHON-MANIFEST.in" PY3 = sys.version_info.major == 3 -PYTHON_STEM = os.path.join('src', 'python', 'grpcio') +PYTHON_STEM = os.path.join("src", "python", "grpcio") CORE_INCLUDE = ( - 'include', - '.', + "include", + ".", +) +ABSL_INCLUDE = (os.path.join("third_party", "abseil-cpp"),) +ADDRESS_SORTING_INCLUDE = ( + os.path.join("third_party", "address_sorting", "include"), ) -ABSL_INCLUDE = (os.path.join('third_party', 'abseil-cpp'),) -ADDRESS_SORTING_INCLUDE = (os.path.join('third_party', 'address_sorting', - 'include'),) CARES_INCLUDE = ( - os.path.join('third_party', 'cares', 'cares', 'include'), - os.path.join('third_party', 'cares'), - os.path.join('third_party', 'cares', 'cares'), + os.path.join("third_party", "cares", "cares", "include"), + os.path.join("third_party", "cares"), + os.path.join("third_party", "cares", "cares"), +) +if "darwin" in sys.platform: + CARES_INCLUDE += (os.path.join("third_party", "cares", "config_darwin"),) +if "freebsd" in sys.platform: + CARES_INCLUDE += (os.path.join("third_party", "cares", "config_freebsd"),) +if "linux" in sys.platform: + CARES_INCLUDE += (os.path.join("third_party", "cares", "config_linux"),) +if "openbsd" in sys.platform: + CARES_INCLUDE += (os.path.join("third_party", "cares", "config_openbsd"),) +RE2_INCLUDE = (os.path.join("third_party", "re2"),) +SSL_INCLUDE = ( + os.path.join("third_party", "boringssl-with-bazel", "src", "include"), +) +UPB_INCLUDE = (os.path.join("third_party", "upb"),) +UPB_GRPC_GENERATED_INCLUDE = ( + os.path.join("src", "core", "ext", "upb-generated"), ) -if 'darwin' in sys.platform: - CARES_INCLUDE += (os.path.join('third_party', 'cares', 'config_darwin'),) -if 'freebsd' in sys.platform: - CARES_INCLUDE += (os.path.join('third_party', 'cares', 'config_freebsd'),) -if 'linux' in sys.platform: - CARES_INCLUDE += (os.path.join('third_party', 'cares', 'config_linux'),) -if 'openbsd' in sys.platform: - CARES_INCLUDE += (os.path.join('third_party', 'cares', 'config_openbsd'),) -RE2_INCLUDE = (os.path.join('third_party', 're2'),) -SSL_INCLUDE = (os.path.join('third_party', 'boringssl-with-bazel', 'src', - 'include'),) -UPB_INCLUDE = (os.path.join('third_party', 'upb'),) -UPB_GRPC_GENERATED_INCLUDE = (os.path.join('src', 'core', 'ext', - 'upb-generated'),) -UPBDEFS_GRPC_GENERATED_INCLUDE = (os.path.join('src', 'core', 'ext', - 'upbdefs-generated'),) -UTF8_RANGE_INCLUDE = (os.path.join('third_party', 'utf8_range'),) -XXHASH_INCLUDE = (os.path.join('third_party', 'xxhash'),) -ZLIB_INCLUDE = (os.path.join('third_party', 'zlib'),) -README = os.path.join(PYTHON_STEM, 'README.rst') +UPBDEFS_GRPC_GENERATED_INCLUDE = ( + os.path.join("src", "core", "ext", "upbdefs-generated"), +) +UTF8_RANGE_INCLUDE = (os.path.join("third_party", "utf8_range"),) +XXHASH_INCLUDE = (os.path.join("third_party", "xxhash"),) +ZLIB_INCLUDE = (os.path.join("third_party", "zlib"),) +README = 
os.path.join(PYTHON_STEM, "README.rst") # Ensure we're in the proper directory whether or not we're being used by pip. os.chdir(os.path.dirname(os.path.abspath(__file__))) @@ -96,28 +100,29 @@ _parallel_compile_patch.monkeypatch_compile_maybe() _spawn_patch.monkeypatch_spawn() -LICENSE = 'Apache License 2.0' +LICENSE = "Apache License 2.0" CLASSIFIERS = [ - 'Development Status :: 5 - Production/Stable', - 'Programming Language :: Python', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', - 'License :: OSI Approved :: Apache Software License', + "Development Status :: 5 - Production/Stable", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "License :: OSI Approved :: Apache Software License", ] def _env_bool_value(env_name, default): """Parses a bool option from an environment variable""" - return os.environ.get(env_name, default).upper() not in ['FALSE', '0', ''] + return os.environ.get(env_name, default).upper() not in ["FALSE", "0", ""] -BUILD_WITH_BORING_SSL_ASM = _env_bool_value('GRPC_BUILD_WITH_BORING_SSL_ASM', - 'True') +BUILD_WITH_BORING_SSL_ASM = _env_bool_value( + "GRPC_BUILD_WITH_BORING_SSL_ASM", "True" +) # Export this environment variable to override the platform variant that will # be chosen for boringssl assembly optimizations. This option is useful when @@ -125,42 +130,46 @@ def _env_bool_value(env_name, default): # doesn't match the platform we are targetting. # Example value: "linux-aarch64" BUILD_OVERRIDE_BORING_SSL_ASM_PLATFORM = os.environ.get( - 'GRPC_BUILD_OVERRIDE_BORING_SSL_ASM_PLATFORM', '') + "GRPC_BUILD_OVERRIDE_BORING_SSL_ASM_PLATFORM", "" +) # Environment variable to determine whether or not the Cython extension should # *use* Cython or use the generated C files. Note that this requires the C files # to have been generated by building first *with* Cython support. Even if this # is set to false, if the script detects that the generated `.c` file isn't # present, then it will still attempt to use Cython. -BUILD_WITH_CYTHON = _env_bool_value('GRPC_PYTHON_BUILD_WITH_CYTHON', 'False') +BUILD_WITH_CYTHON = _env_bool_value("GRPC_PYTHON_BUILD_WITH_CYTHON", "False") # Export this variable to use the system installation of openssl. You need to # have the header files installed (in /usr/include/openssl) and during # runtime, the shared library must be installed -BUILD_WITH_SYSTEM_OPENSSL = _env_bool_value('GRPC_PYTHON_BUILD_SYSTEM_OPENSSL', - 'False') +BUILD_WITH_SYSTEM_OPENSSL = _env_bool_value( + "GRPC_PYTHON_BUILD_SYSTEM_OPENSSL", "False" +) # Export this variable to use the system installation of zlib. You need to # have the header files installed (in /usr/include/) and during # runtime, the shared library must be installed -BUILD_WITH_SYSTEM_ZLIB = _env_bool_value('GRPC_PYTHON_BUILD_SYSTEM_ZLIB', - 'False') +BUILD_WITH_SYSTEM_ZLIB = _env_bool_value( + "GRPC_PYTHON_BUILD_SYSTEM_ZLIB", "False" +) # Export this variable to use the system installation of cares. 
You need to # have the header files installed (in /usr/include/) and during # runtime, the shared library must be installed -BUILD_WITH_SYSTEM_CARES = _env_bool_value('GRPC_PYTHON_BUILD_SYSTEM_CARES', - 'False') +BUILD_WITH_SYSTEM_CARES = _env_bool_value( + "GRPC_PYTHON_BUILD_SYSTEM_CARES", "False" +) # Export this variable to use the system installation of re2. You need to # have the header files installed (in /usr/include/re2) and during # runtime, the shared library must be installed -BUILD_WITH_SYSTEM_RE2 = _env_bool_value('GRPC_PYTHON_BUILD_SYSTEM_RE2', 'False') +BUILD_WITH_SYSTEM_RE2 = _env_bool_value("GRPC_PYTHON_BUILD_SYSTEM_RE2", "False") # Export this variable to use the system installation of abseil. You need to # have the header files installed (in /usr/include/absl) and during # runtime, the shared library must be installed -BUILD_WITH_SYSTEM_ABSL = os.environ.get('GRPC_PYTHON_BUILD_SYSTEM_ABSL', False) +BUILD_WITH_SYSTEM_ABSL = os.environ.get("GRPC_PYTHON_BUILD_SYSTEM_ABSL", False) # Export this variable to force building the python extension with a statically linked libstdc++. # At least on linux, this is normally not needed as we can build manylinux-compatible wheels on linux just fine @@ -170,7 +179,8 @@ def _env_bool_value(env_name, default): # of GCC (we require >=5.1) but still uses old-enough libstdc++ symbols. # TODO(jtattermusch): remove this workaround once issues with crosscompiler version are resolved. BUILD_WITH_STATIC_LIBSTDCXX = _env_bool_value( - 'GRPC_PYTHON_BUILD_WITH_STATIC_LIBSTDCXX', 'False') + "GRPC_PYTHON_BUILD_WITH_STATIC_LIBSTDCXX", "False" +) # For local development use only: This skips building gRPC Core and its # dependencies, including protobuf and boringssl. This allows "incremental" @@ -183,44 +193,53 @@ def _env_bool_value(env_name, default): # make HAS_SYSTEM_OPENSSL_ALPN=0 # # TODO(ericgribkoff) Respect the BUILD_WITH_SYSTEM_* flags alongside this option -USE_PREBUILT_GRPC_CORE = _env_bool_value('GRPC_PYTHON_USE_PREBUILT_GRPC_CORE', - 'False') +USE_PREBUILT_GRPC_CORE = _env_bool_value( + "GRPC_PYTHON_USE_PREBUILT_GRPC_CORE", "False" +) # If this environmental variable is set, GRPC will not try to be compatible with # libc versions old than the one it was compiled against. DISABLE_LIBC_COMPATIBILITY = _env_bool_value( - 'GRPC_PYTHON_DISABLE_LIBC_COMPATIBILITY', 'False') + "GRPC_PYTHON_DISABLE_LIBC_COMPATIBILITY", "False" +) # Environment variable to determine whether or not to enable coverage analysis # in Cython modules. -ENABLE_CYTHON_TRACING = _env_bool_value('GRPC_PYTHON_ENABLE_CYTHON_TRACING', - 'False') +ENABLE_CYTHON_TRACING = _env_bool_value( + "GRPC_PYTHON_ENABLE_CYTHON_TRACING", "False" +) # Environment variable specifying whether or not there's interest in setting up # documentation building. 
ENABLE_DOCUMENTATION_BUILD = _env_bool_value( - 'GRPC_PYTHON_ENABLE_DOCUMENTATION_BUILD', 'False') + "GRPC_PYTHON_ENABLE_DOCUMENTATION_BUILD", "False" +) def check_linker_need_libatomic(): """Test if linker on system needs libatomic.""" - code_test = (b'#include <atomic>\n' + - b'int main() { return std::atomic<int>{}; }') - cxx = shlex.split(os.environ.get('CXX', 'c++')) - cpp_test = subprocess.Popen(cxx + ['-x', 'c++', '-std=c++14', '-'], - stdin=PIPE, - stdout=PIPE, - stderr=PIPE) + code_test = ( + b"#include <atomic>\n" + + b"int main() { return std::atomic<int>{}; }" + ) + cxx = shlex.split(os.environ.get("CXX", "c++")) + cpp_test = subprocess.Popen( + cxx + ["-x", "c++", "-std=c++14", "-"], + stdin=PIPE, + stdout=PIPE, + stderr=PIPE, + ) cpp_test.communicate(input=code_test) if cpp_test.returncode == 0: return False # Double-check to see if -latomic actually can solve the problem. # https://github.com/grpc/grpc/issues/22491 - cpp_test = subprocess.Popen(cxx + - ['-x', 'c++', '-std=c++14', '-', '-latomic'], - stdin=PIPE, - stdout=PIPE, - stderr=PIPE) + cpp_test = subprocess.Popen( + cxx + ["-x", "c++", "-std=c++14", "-", "-latomic"], + stdin=PIPE, + stdout=PIPE, + stderr=PIPE, + ) cpp_test.communicate(input=code_test) return cpp_test.returncode == 0 @@ -232,119 +251,141 @@ def check_linker_need_libatomic(): # We can also use these variables as a way to inject environment-specific # compiler/linker flags. We assume GCC-like compilers and/or MinGW as a # reasonable default. -EXTRA_ENV_COMPILE_ARGS = os.environ.get('GRPC_PYTHON_CFLAGS', None) -EXTRA_ENV_LINK_ARGS = os.environ.get('GRPC_PYTHON_LDFLAGS', None) +EXTRA_ENV_COMPILE_ARGS = os.environ.get("GRPC_PYTHON_CFLAGS", None) +EXTRA_ENV_LINK_ARGS = os.environ.get("GRPC_PYTHON_LDFLAGS", None) if EXTRA_ENV_COMPILE_ARGS is None: - EXTRA_ENV_COMPILE_ARGS = ' -std=c++14' - if 'win32' in sys.platform: if sys.version_info < (3, 5): - EXTRA_ENV_COMPILE_ARGS += ' -D_hypot=hypot' # We use define flags here and don't directly add to DEFINE_MACROS below to # ensure that the expert user/builder has a way of turning it off (via the # envvars) without adding yet more GRPC-specific envvars. 
# See https://sourceforge.net/p/mingw-w64/bugs/363/ - if '32' in platform.architecture()[0]: - EXTRA_ENV_COMPILE_ARGS += ' -D_ftime=_ftime32 -D_timeb=__timeb32 -D_ftime_s=_ftime32_s' + if "32" in platform.architecture()[0]: + EXTRA_ENV_COMPILE_ARGS += ( + " -D_ftime=_ftime32 -D_timeb=__timeb32" + " -D_ftime_s=_ftime32_s" + ) else: - EXTRA_ENV_COMPILE_ARGS += ' -D_ftime=_ftime64 -D_timeb=__timeb64' + EXTRA_ENV_COMPILE_ARGS += ( + " -D_ftime=_ftime64 -D_timeb=__timeb64" + ) else: # We need to statically link the C++ Runtime, only the C runtime is # available dynamically - EXTRA_ENV_COMPILE_ARGS += ' /MT' + EXTRA_ENV_COMPILE_ARGS += " /MT" elif "linux" in sys.platform: - EXTRA_ENV_COMPILE_ARGS += ' -fvisibility=hidden -fno-wrapv -fno-exceptions' + EXTRA_ENV_COMPILE_ARGS += ( + " -fvisibility=hidden -fno-wrapv -fno-exceptions" + ) elif "darwin" in sys.platform: - EXTRA_ENV_COMPILE_ARGS += ' -stdlib=libc++ -fvisibility=hidden -fno-wrapv -fno-exceptions -DHAVE_UNISTD_H' + EXTRA_ENV_COMPILE_ARGS += ( + " -stdlib=libc++ -fvisibility=hidden -fno-wrapv -fno-exceptions" + " -DHAVE_UNISTD_H" + ) if EXTRA_ENV_LINK_ARGS is None: - EXTRA_ENV_LINK_ARGS = '' + EXTRA_ENV_LINK_ARGS = "" if "linux" in sys.platform or "darwin" in sys.platform: - EXTRA_ENV_LINK_ARGS += ' -lpthread' + EXTRA_ENV_LINK_ARGS += " -lpthread" if check_linker_need_libatomic(): - EXTRA_ENV_LINK_ARGS += ' -latomic' + EXTRA_ENV_LINK_ARGS += " -latomic" elif "win32" in sys.platform and sys.version_info < (3, 5): msvcr = cygwinccompiler.get_msvcr()[0] EXTRA_ENV_LINK_ARGS += ( - ' -static-libgcc -static-libstdc++ -mcrtdll={msvcr}' - ' -static -lshlwapi'.format(msvcr=msvcr)) + " -static-libgcc -static-libstdc++ -mcrtdll={msvcr}" + " -static -lshlwapi".format(msvcr=msvcr) + ) if "linux" in sys.platform: - EXTRA_ENV_LINK_ARGS += ' -static-libgcc' + EXTRA_ENV_LINK_ARGS += " -static-libgcc" EXTRA_COMPILE_ARGS = shlex.split(EXTRA_ENV_COMPILE_ARGS) EXTRA_LINK_ARGS = shlex.split(EXTRA_ENV_LINK_ARGS) if BUILD_WITH_STATIC_LIBSTDCXX: - EXTRA_LINK_ARGS.append('-static-libstdc++') + EXTRA_LINK_ARGS.append("-static-libstdc++") CYTHON_EXTENSION_PACKAGE_NAMES = () -CYTHON_EXTENSION_MODULE_NAMES = ('grpc._cython.cygrpc',) +CYTHON_EXTENSION_MODULE_NAMES = ("grpc._cython.cygrpc",) CYTHON_HELPER_C_FILES = () CORE_C_FILES = tuple(grpc_core_dependencies.CORE_SOURCE_FILES) if "win32" in sys.platform: - CORE_C_FILES = filter(lambda x: 'third_party/cares' not in x, CORE_C_FILES) + CORE_C_FILES = filter(lambda x: "third_party/cares" not in x, CORE_C_FILES) if BUILD_WITH_SYSTEM_OPENSSL: - CORE_C_FILES = filter(lambda x: 'third_party/boringssl' not in x, - CORE_C_FILES) - CORE_C_FILES = filter(lambda x: 'src/boringssl' not in x, CORE_C_FILES) - SSL_INCLUDE = (os.path.join('/usr', 'include', 'openssl'),) + CORE_C_FILES = filter( + lambda x: "third_party/boringssl" not in x, CORE_C_FILES + ) + CORE_C_FILES = filter(lambda x: "src/boringssl" not in x, CORE_C_FILES) + SSL_INCLUDE = (os.path.join("/usr", "include", "openssl"),) if BUILD_WITH_SYSTEM_ZLIB: - CORE_C_FILES = filter(lambda x: 'third_party/zlib' not in x, CORE_C_FILES) - ZLIB_INCLUDE = (os.path.join('/usr', 'include'),) + CORE_C_FILES = filter(lambda x: "third_party/zlib" not in x, CORE_C_FILES) + ZLIB_INCLUDE = (os.path.join("/usr", "include"),) if BUILD_WITH_SYSTEM_CARES: - CORE_C_FILES = filter(lambda x: 'third_party/cares' not in x, CORE_C_FILES) - CARES_INCLUDE = (os.path.join('/usr', 'include'),) + CORE_C_FILES = filter(lambda x: "third_party/cares" not in x, CORE_C_FILES) + CARES_INCLUDE = 
(os.path.join("/usr", "include"),) if BUILD_WITH_SYSTEM_RE2: - CORE_C_FILES = filter(lambda x: 'third_party/re2' not in x, CORE_C_FILES) - RE2_INCLUDE = (os.path.join('/usr', 'include', 're2'),) + CORE_C_FILES = filter(lambda x: "third_party/re2" not in x, CORE_C_FILES) + RE2_INCLUDE = (os.path.join("/usr", "include", "re2"),) if BUILD_WITH_SYSTEM_ABSL: - CORE_C_FILES = filter(lambda x: 'third_party/abseil-cpp' not in x, - CORE_C_FILES) - ABSL_INCLUDE = (os.path.join('/usr', 'include'),) - -EXTENSION_INCLUDE_DIRECTORIES = ((PYTHON_STEM,) + CORE_INCLUDE + ABSL_INCLUDE + - ADDRESS_SORTING_INCLUDE + CARES_INCLUDE + - RE2_INCLUDE + SSL_INCLUDE + UPB_INCLUDE + - UPB_GRPC_GENERATED_INCLUDE + - UPBDEFS_GRPC_GENERATED_INCLUDE + - UTF8_RANGE_INCLUDE + XXHASH_INCLUDE + - ZLIB_INCLUDE) + CORE_C_FILES = filter( + lambda x: "third_party/abseil-cpp" not in x, CORE_C_FILES + ) + ABSL_INCLUDE = (os.path.join("/usr", "include"),) + +EXTENSION_INCLUDE_DIRECTORIES = ( + (PYTHON_STEM,) + + CORE_INCLUDE + + ABSL_INCLUDE + + ADDRESS_SORTING_INCLUDE + + CARES_INCLUDE + + RE2_INCLUDE + + SSL_INCLUDE + + UPB_INCLUDE + + UPB_GRPC_GENERATED_INCLUDE + + UPBDEFS_GRPC_GENERATED_INCLUDE + + UTF8_RANGE_INCLUDE + + XXHASH_INCLUDE + + ZLIB_INCLUDE +) EXTENSION_LIBRARIES = () if "linux" in sys.platform: - EXTENSION_LIBRARIES += ('rt',) + EXTENSION_LIBRARIES += ("rt",) if not "win32" in sys.platform: - EXTENSION_LIBRARIES += ('m',) + EXTENSION_LIBRARIES += ("m",) if "win32" in sys.platform: EXTENSION_LIBRARIES += ( - 'advapi32', - 'bcrypt', - 'dbghelp', - 'ws2_32', + "advapi32", + "bcrypt", + "dbghelp", + "ws2_32", ) if BUILD_WITH_SYSTEM_OPENSSL: EXTENSION_LIBRARIES += ( - 'ssl', - 'crypto', + "ssl", + "crypto", ) if BUILD_WITH_SYSTEM_ZLIB: - EXTENSION_LIBRARIES += ('z',) + EXTENSION_LIBRARIES += ("z",) if BUILD_WITH_SYSTEM_CARES: - EXTENSION_LIBRARIES += ('cares',) + EXTENSION_LIBRARIES += ("cares",) if BUILD_WITH_SYSTEM_RE2: - EXTENSION_LIBRARIES += ('re2',) + EXTENSION_LIBRARIES += ("re2",) if BUILD_WITH_SYSTEM_ABSL: EXTENSION_LIBRARIES += tuple( - lib.stem[3:] for lib in pathlib.Path('/usr').glob('lib*/libabsl_*.so')) + lib.stem[3:] for lib in pathlib.Path("/usr").glob("lib*/libabsl_*.so") + ) -DEFINE_MACROS = (('_WIN32_WINNT', 0x600),) +DEFINE_MACROS = (("_WIN32_WINNT", 0x600),) asm_files = [] @@ -353,69 +394,76 @@ def check_linker_need_libatomic(): # the binary. 
def _quote_build_define(argument): if "win32" in sys.platform: - return '"\\\"{}\\\""'.format(argument) + return '"\\"{}\\""'.format(argument) return '"{}"'.format(argument) DEFINE_MACROS += ( ("GRPC_XDS_USER_AGENT_NAME_SUFFIX", _quote_build_define("Python")), - ("GRPC_XDS_USER_AGENT_VERSION_SUFFIX", - _quote_build_define(_metadata.__version__)), + ( + "GRPC_XDS_USER_AGENT_VERSION_SUFFIX", + _quote_build_define(_metadata.__version__), + ), ) -asm_key = '' +asm_key = "" if BUILD_WITH_BORING_SSL_ASM and not BUILD_WITH_SYSTEM_OPENSSL: - boringssl_asm_platform = BUILD_OVERRIDE_BORING_SSL_ASM_PLATFORM if BUILD_OVERRIDE_BORING_SSL_ASM_PLATFORM else util.get_platform( + boringssl_asm_platform = ( + BUILD_OVERRIDE_BORING_SSL_ASM_PLATFORM + if BUILD_OVERRIDE_BORING_SSL_ASM_PLATFORM + else util.get_platform() ) - LINUX_X86_64 = 'linux-x86_64' - LINUX_ARM = 'linux-arm' - LINUX_AARCH64 = 'linux-aarch64' + LINUX_X86_64 = "linux-x86_64" + LINUX_ARM = "linux-arm" + LINUX_AARCH64 = "linux-aarch64" if LINUX_X86_64 == boringssl_asm_platform: - asm_key = 'crypto_linux_x86_64' + asm_key = "crypto_linux_x86_64" elif LINUX_ARM == boringssl_asm_platform: - asm_key = 'crypto_linux_arm' + asm_key = "crypto_linux_arm" elif LINUX_AARCH64 == boringssl_asm_platform: - asm_key = 'crypto_linux_aarch64' + asm_key = "crypto_linux_aarch64" elif "mac" in boringssl_asm_platform and "x86_64" in boringssl_asm_platform: - asm_key = 'crypto_apple_x86_64' + asm_key = "crypto_apple_x86_64" elif "mac" in boringssl_asm_platform and "arm64" in boringssl_asm_platform: - asm_key = 'crypto_apple_aarch64' + asm_key = "crypto_apple_aarch64" else: - print("ASM Builds for BoringSSL currently not supported on:", - boringssl_asm_platform) + print( + "ASM Builds for BoringSSL currently not supported on:", + boringssl_asm_platform, + ) if asm_key: asm_files = grpc_core_dependencies.ASM_SOURCE_FILES[asm_key] else: - DEFINE_MACROS += (('OPENSSL_NO_ASM', 1),) + DEFINE_MACROS += (("OPENSSL_NO_ASM", 1),) if not DISABLE_LIBC_COMPATIBILITY: - DEFINE_MACROS += (('GPR_BACKWARDS_COMPATIBILITY_MODE', 1),) + DEFINE_MACROS += (("GPR_BACKWARDS_COMPATIBILITY_MODE", 1),) if "win32" in sys.platform: # TODO(zyc): Re-enable c-ares on x64 and x86 windows after fixing the # ares_library_init compilation issue DEFINE_MACROS += ( - ('WIN32_LEAN_AND_MEAN', 1), - ('CARES_STATICLIB', 1), - ('GRPC_ARES', 0), - ('NTDDI_VERSION', 0x06000000), - ('NOMINMAX', 1), + ("WIN32_LEAN_AND_MEAN", 1), + ("CARES_STATICLIB", 1), + ("GRPC_ARES", 0), + ("NTDDI_VERSION", 0x06000000), + ("NOMINMAX", 1), ) - if '64bit' in platform.architecture()[0]: - DEFINE_MACROS += (('MS_WIN64', 1),) + if "64bit" in platform.architecture()[0]: + DEFINE_MACROS += (("MS_WIN64", 1),) elif sys.version_info >= (3, 5): # For some reason, this is needed to get access to inet_pton/inet_ntop # on msvc, but only for 32 bits - DEFINE_MACROS += (('NTDDI_VERSION', 0x06000000),) + DEFINE_MACROS += (("NTDDI_VERSION", 0x06000000),) else: DEFINE_MACROS += ( - ('HAVE_CONFIG_H', 1), - ('GRPC_ENABLE_FORK_SUPPORT', 1), + ("HAVE_CONFIG_H", 1), + ("GRPC_ENABLE_FORK_SUPPORT", 1), ) # Fix for multiprocessing support on Apple devices. # TODO(vigneshbabu): Remove this once the poll poller gets fork support. -DEFINE_MACROS += (('GRPC_DO_NOT_INSTANTIATE_POSIX_POLLER', 1),) +DEFINE_MACROS += (("GRPC_DO_NOT_INSTANTIATE_POSIX_POLLER", 1),) # Fix for Cython build issue in aarch64. # It's required to define this macro before include <inttypes.h>. 
@@ -424,44 +472,48 @@ def _quote_build_define(argument): # but we're still having issue in aarch64, so we manually define the macro here. # TODO(xuanwn): Figure out what's going on in the aarch64 build so we can support # gcc + Bazel. -DEFINE_MACROS += (('__STDC_FORMAT_MACROS', None),) +DEFINE_MACROS += (("__STDC_FORMAT_MACROS", None),) LDFLAGS = tuple(EXTRA_LINK_ARGS) CFLAGS = tuple(EXTRA_COMPILE_ARGS) if "linux" in sys.platform or "darwin" in sys.platform: - pymodinit_type = 'PyObject*' if PY3 else 'void' + pymodinit_type = "PyObject*" if PY3 else "void" pymodinit = 'extern "C" __attribute__((visibility ("default"))) {}'.format( - pymodinit_type) - DEFINE_MACROS += (('PyMODINIT_FUNC', pymodinit),) - DEFINE_MACROS += (('GRPC_POSIX_FORK_ALLOW_PTHREAD_ATFORK', 1),) + pymodinit_type + ) + DEFINE_MACROS += (("PyMODINIT_FUNC", pymodinit),) + DEFINE_MACROS += (("GRPC_POSIX_FORK_ALLOW_PTHREAD_ATFORK", 1),) # By default, Python3 distutils enforces compatibility of # c plugins (.so files) with the OSX version Python was built with. # We need OSX 10.10, the oldest which supports C++ thread_local. # Python 3.9: Mac OS Big Sur sysconfig.get_config_var('MACOSX_DEPLOYMENT_TARGET') returns int (11) -if 'darwin' in sys.platform: - mac_target = sysconfig.get_config_var('MACOSX_DEPLOYMENT_TARGET') +if "darwin" in sys.platform: + mac_target = sysconfig.get_config_var("MACOSX_DEPLOYMENT_TARGET") if mac_target: mac_target = pkg_resources.parse_version(str(mac_target)) - if mac_target < pkg_resources.parse_version('10.10.0'): - os.environ['MACOSX_DEPLOYMENT_TARGET'] = '10.10' - os.environ['_PYTHON_HOST_PLATFORM'] = re.sub( - r'macosx-[0-9]+\.[0-9]+-(.+)', r'macosx-10.10-\1', - util.get_platform()) + if mac_target < pkg_resources.parse_version("10.10.0"): + os.environ["MACOSX_DEPLOYMENT_TARGET"] = "10.10" + os.environ["_PYTHON_HOST_PLATFORM"] = re.sub( + r"macosx-[0-9]+\.[0-9]+-(.+)", + r"macosx-10.10-\1", + util.get_platform(), + ) def cython_extensions_and_necessity(): cython_module_files = [ - os.path.join(PYTHON_STEM, - name.replace('.', '/') + '.pyx') + os.path.join(PYTHON_STEM, name.replace(".", "/") + ".pyx") for name in CYTHON_EXTENSION_MODULE_NAMES ] - config = os.environ.get('CONFIG', 'opt') - prefix = 'libs/' + config + '/' + config = os.environ.get("CONFIG", "opt") + prefix = "libs/" + config + "/" if USE_PREBUILT_GRPC_CORE: extra_objects = [ - prefix + 'libares.a', prefix + 'libboringssl.a', - prefix + 'libgpr.a', prefix + 'libgrpc.a' + prefix + "libares.a", + prefix + "libboringssl.a", + prefix + "libgpr.a", + prefix + "libgrpc.a", ] core_c_files = [] else: @@ -470,42 +522,56 @@ def cython_extensions_and_necessity(): extensions = [ _extension.Extension( name=module_name, - sources=([module_file] + list(CYTHON_HELPER_C_FILES) + - core_c_files + asm_files), + sources=( + [module_file] + + list(CYTHON_HELPER_C_FILES) + + core_c_files + + asm_files + ), include_dirs=list(EXTENSION_INCLUDE_DIRECTORIES), libraries=list(EXTENSION_LIBRARIES), define_macros=list(DEFINE_MACROS), extra_objects=extra_objects, extra_compile_args=list(CFLAGS), extra_link_args=list(LDFLAGS), - ) for (module_name, module_file - ) in zip(list(CYTHON_EXTENSION_MODULE_NAMES), cython_module_files) + ) + for (module_name, module_file) in zip( + list(CYTHON_EXTENSION_MODULE_NAMES), cython_module_files + ) ] need_cython = BUILD_WITH_CYTHON if not BUILD_WITH_CYTHON: - need_cython = need_cython or not commands.check_and_update_cythonization( - extensions) + need_cython = ( + need_cython + or not 
commands.check_and_update_cythonization(extensions) + ) # TODO: the strategy for conditional compiling and exposing the aio Cython # dependencies will be revisited by https://github.com/grpc/grpc/issues/19728 - return commands.try_cythonize(extensions, - linetracing=ENABLE_CYTHON_TRACING, - mandatory=BUILD_WITH_CYTHON), need_cython + return ( + commands.try_cythonize( + extensions, + linetracing=ENABLE_CYTHON_TRACING, + mandatory=BUILD_WITH_CYTHON, + ), + need_cython, + ) CYTHON_EXTENSION_MODULES, need_cython = cython_extensions_and_necessity() PACKAGE_DIRECTORIES = { - '': PYTHON_STEM, + "": PYTHON_STEM, } INSTALL_REQUIRES = () EXTRAS_REQUIRES = { - 'protobuf': 'grpcio-tools>={version}'.format(version=grpc_version.VERSION), + "protobuf": "grpcio-tools>={version}".format(version=grpc_version.VERSION), } -SETUP_REQUIRES = INSTALL_REQUIRES + ( - 'Sphinx~=1.8.1',) if ENABLE_DOCUMENTATION_BUILD else () +SETUP_REQUIRES = ( + INSTALL_REQUIRES + ("Sphinx~=1.8.1",) if ENABLE_DOCUMENTATION_BUILD else () +) try: import Cython @@ -514,62 +580,65 @@ def cython_extensions_and_necessity(): sys.stderr.write( "You requested a Cython build via GRPC_PYTHON_BUILD_WITH_CYTHON, " "but do not have Cython installed. We won't stop you from using " - "other commands, but the extension files will fail to build.\n") + "other commands, but the extension files will fail to build.\n" + ) elif need_cython: sys.stderr.write( - 'We could not find Cython. Setup may take 10-20 minutes.\n') - SETUP_REQUIRES += ('cython>=0.23',) + "We could not find Cython. Setup may take 10-20 minutes.\n" + ) + SETUP_REQUIRES += ("cython>=0.23",) COMMAND_CLASS = { - 'doc': commands.SphinxDocumentation, - 'build_project_metadata': commands.BuildProjectMetadata, - 'build_py': commands.BuildPy, - 'build_ext': commands.BuildExt, - 'gather': commands.Gather, - 'clean': commands.Clean, + "doc": commands.SphinxDocumentation, + "build_project_metadata": commands.BuildProjectMetadata, + "build_py": commands.BuildPy, + "build_ext": commands.BuildExt, + "gather": commands.Gather, + "clean": commands.Clean, } # Ensure that package data is copied over before any commands have been run: -credentials_dir = os.path.join(PYTHON_STEM, 'grpc', '_cython', '_credentials') +credentials_dir = os.path.join(PYTHON_STEM, "grpc", "_cython", "_credentials") try: os.mkdir(credentials_dir) except OSError: pass -shutil.copyfile(os.path.join('etc', 'roots.pem'), - os.path.join(credentials_dir, 'roots.pem')) +shutil.copyfile( + os.path.join("etc", "roots.pem"), os.path.join(credentials_dir, "roots.pem") +) PACKAGE_DATA = { # Binaries that may or may not be present in the final installation, but are # mentioned here for completeness. 
- 'grpc._cython': [ - '_credentials/roots.pem', - '_windows/grpc_c.32.python', - '_windows/grpc_c.64.python', + "grpc._cython": [ + "_credentials/roots.pem", + "_windows/grpc_c.32.python", + "_windows/grpc_c.64.python", ], } PACKAGES = setuptools.find_packages(PYTHON_STEM) setuptools.setup( - name='grpcio', + name="grpcio", version=grpc_version.VERSION, - description='HTTP/2-based RPC framework', - author='The gRPC Authors', - author_email='grpc-io@googlegroups.com', - url='https://grpc.io', + description="HTTP/2-based RPC framework", + author="The gRPC Authors", + author_email="grpc-io@googlegroups.com", + url="https://grpc.io", project_urls={ "Source Code": "https://github.com/grpc/grpc", "Bug Tracker": "https://github.com/grpc/grpc/issues", - 'Documentation': 'https://grpc.github.io/grpc/python', + "Documentation": "https://grpc.github.io/grpc/python", }, license=LICENSE, classifiers=CLASSIFIERS, - long_description_content_type='text/x-rst', + long_description_content_type="text/x-rst", long_description=open(README).read(), ext_modules=CYTHON_EXTENSION_MODULES, packages=list(PACKAGES), package_dir=PACKAGE_DIRECTORIES, package_data=PACKAGE_DATA, - python_requires='>=3.7', + python_requires=">=3.7", install_requires=INSTALL_REQUIRES, extras_require=EXTRAS_REQUIRES, setup_requires=SETUP_REQUIRES, diff --git a/src/abseil-cpp/gen_build_yaml.py b/src/abseil-cpp/gen_build_yaml.py index 8dfacb7911e5c..6bf9a6d71cb7b 100755 --- a/src/abseil-cpp/gen_build_yaml.py +++ b/src/abseil-cpp/gen_build_yaml.py @@ -17,14 +17,15 @@ import os import yaml -BUILDS_YAML_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), - 'preprocessed_builds.yaml') +BUILDS_YAML_PATH = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "preprocessed_builds.yaml" +) with open(BUILDS_YAML_PATH) as f: builds = yaml.safe_load(f) for build in builds: - build['build'] = 'private' - build['build_system'] = [] - build['language'] = 'c' - build['secure'] = False -print(yaml.dump({'libs': builds})) + build["build"] = "private" + build["build_system"] = [] + build["language"] = "c" + build["secure"] = False +print(yaml.dump({"libs": builds})) diff --git a/src/abseil-cpp/preprocessed_builds.yaml.gen.py b/src/abseil-cpp/preprocessed_builds.yaml.gen.py index 599cc98de7f38..7cc767534cb30 100755 --- a/src/abseil-cpp/preprocessed_builds.yaml.gen.py +++ b/src/abseil-cpp/preprocessed_builds.yaml.gen.py @@ -29,185 +29,201 @@ # Rule object representing the rule of Bazel BUILD. 
Rule = collections.namedtuple( - "Rule", "type name package srcs hdrs textual_hdrs deps visibility testonly") + "Rule", "type name package srcs hdrs textual_hdrs deps visibility testonly" +) def get_elem_value(elem, name): - """Returns the value of XML element with the given name.""" - for child in elem: - if child.attrib.get("name") == name: - if child.tag == "string": - return child.attrib.get("value") - elif child.tag == "boolean": - return child.attrib.get("value") == "true" - elif child.tag == "list": - return [nested_child.attrib.get("value") for nested_child in child] - else: - raise "Cannot recognize tag: " + child.tag - return None + """Returns the value of XML element with the given name.""" + for child in elem: + if child.attrib.get("name") == name: + if child.tag == "string": + return child.attrib.get("value") + elif child.tag == "boolean": + return child.attrib.get("value") == "true" + elif child.tag == "list": + return [ + nested_child.attrib.get("value") for nested_child in child + ] + else: + raise "Cannot recognize tag: " + child.tag + return None def normalize_paths(paths): - """Returns the list of normalized path.""" - # e.g. ["//absl/strings:dir/header.h"] -> ["absl/strings/dir/header.h"] - return [path.lstrip("/").replace(":", "/") for path in paths] + """Returns the list of normalized path.""" + # e.g. ["//absl/strings:dir/header.h"] -> ["absl/strings/dir/header.h"] + return [path.lstrip("/").replace(":", "/") for path in paths] def parse_bazel_rule(elem, package): - """Returns a rule from bazel XML rule.""" - return Rule( - type=elem.attrib["class"], - name=get_elem_value(elem, "name"), - package=package, - srcs=normalize_paths(get_elem_value(elem, "srcs") or []), - hdrs=normalize_paths(get_elem_value(elem, "hdrs") or []), - textual_hdrs=normalize_paths(get_elem_value(elem, "textual_hdrs") or []), - deps=get_elem_value(elem, "deps") or [], - visibility=get_elem_value(elem, "visibility") or [], - testonly=get_elem_value(elem, "testonly") or False) + """Returns a rule from bazel XML rule.""" + return Rule( + type=elem.attrib["class"], + name=get_elem_value(elem, "name"), + package=package, + srcs=normalize_paths(get_elem_value(elem, "srcs") or []), + hdrs=normalize_paths(get_elem_value(elem, "hdrs") or []), + textual_hdrs=normalize_paths( + get_elem_value(elem, "textual_hdrs") or [] + ), + deps=get_elem_value(elem, "deps") or [], + visibility=get_elem_value(elem, "visibility") or [], + testonly=get_elem_value(elem, "testonly") or False, + ) def read_bazel_build(package): - """Runs bazel query on given package file and returns all cc rules.""" - # Use a wrapper version of bazel in gRPC not to use system-wide bazel - # to avoid bazel conflict when running on Kokoro. - BAZEL_BIN = "../../tools/bazel" - result = subprocess.check_output( - [BAZEL_BIN, "query", package + ":all", "--output", "xml"]) - root = ET.fromstring(result) - return [ - parse_bazel_rule(elem, package) - for elem in root - if elem.tag == "rule" and elem.attrib["class"].startswith("cc_") - ] + """Runs bazel query on given package file and returns all cc rules.""" + # Use a wrapper version of bazel in gRPC not to use system-wide bazel + # to avoid bazel conflict when running on Kokoro. 
+ BAZEL_BIN = "../../tools/bazel" + result = subprocess.check_output( + [BAZEL_BIN, "query", package + ":all", "--output", "xml"] + ) + root = ET.fromstring(result) + return [ + parse_bazel_rule(elem, package) + for elem in root + if elem.tag == "rule" and elem.attrib["class"].startswith("cc_") + ] def collect_bazel_rules(root_path): - """Collects and returns all bazel rules from root path recursively.""" - rules = [] - for cur, _, _ in os.walk(root_path): - build_path = os.path.join(cur, "BUILD.bazel") - if os.path.exists(build_path): - rules.extend(read_bazel_build("//" + cur)) - return rules + """Collects and returns all bazel rules from root path recursively.""" + rules = [] + for cur, _, _ in os.walk(root_path): + build_path = os.path.join(cur, "BUILD.bazel") + if os.path.exists(build_path): + rules.extend(read_bazel_build("//" + cur)) + return rules def parse_cmake_rule(rule, package): - """Returns a rule from absl cmake rule. - Reference: https://github.com/abseil/abseil-cpp/blob/master/CMake/AbseilHelpers.cmake - """ - kv = {} - bucket = None - lines = rule.splitlines() - for line in lines[1:-1]: - if CAPITAL_WORD.match(line.strip()): - bucket = kv.setdefault(line.strip(), []) - else: - if bucket is not None: - bucket.append(line.strip()) - else: - raise ValueError("Illegal syntax: {}".format(rule)) - return Rule( - type=lines[0].rstrip("("), - name="absl::" + kv["NAME"][0], - package=package, - srcs=[package + "/" + f.strip('"') for f in kv.get("SRCS", [])], - hdrs=[package + "/" + f.strip('"') for f in kv.get("HDRS", [])], - textual_hdrs=[], - deps=kv.get("DEPS", []), - visibility="PUBLIC" in kv, - testonly="TESTONLY" in kv, - ) + """Returns a rule from absl cmake rule. + Reference: https://github.com/abseil/abseil-cpp/blob/master/CMake/AbseilHelpers.cmake + """ + kv = {} + bucket = None + lines = rule.splitlines() + for line in lines[1:-1]: + if CAPITAL_WORD.match(line.strip()): + bucket = kv.setdefault(line.strip(), []) + else: + if bucket is not None: + bucket.append(line.strip()) + else: + raise ValueError("Illegal syntax: {}".format(rule)) + return Rule( + type=lines[0].rstrip("("), + name="absl::" + kv["NAME"][0], + package=package, + srcs=[package + "/" + f.strip('"') for f in kv.get("SRCS", [])], + hdrs=[package + "/" + f.strip('"') for f in kv.get("HDRS", [])], + textual_hdrs=[], + deps=kv.get("DEPS", []), + visibility="PUBLIC" in kv, + testonly="TESTONLY" in kv, + ) def read_cmake_build(build_path, package): - """Parses given CMakeLists.txt file and returns all cc rules.""" - rules = [] - with open(build_path, "r") as f: - src = f.read() - for begin_mo in ABSEIL_CMAKE_RULE_BEGIN.finditer(src): - end_mo = ABSEIL_CMAKE_RULE_END.search(src[begin_mo.start(0):]) - expr = src[begin_mo.start(0):begin_mo.start(0) + end_mo.start(0) + 1] - rules.append(parse_cmake_rule(expr, package)) - return rules + """Parses given CMakeLists.txt file and returns all cc rules.""" + rules = [] + with open(build_path, "r") as f: + src = f.read() + for begin_mo in ABSEIL_CMAKE_RULE_BEGIN.finditer(src): + end_mo = ABSEIL_CMAKE_RULE_END.search(src[begin_mo.start(0) :]) + expr = src[ + begin_mo.start(0) : begin_mo.start(0) + end_mo.start(0) + 1 + ] + rules.append(parse_cmake_rule(expr, package)) + return rules def collect_cmake_rules(root_path): - """Collects and returns all cmake rules from root path recursively.""" - rules = [] - for cur, _, _ in os.walk(root_path): - build_path = os.path.join(cur, "CMakeLists.txt") - if os.path.exists(build_path): - rules.extend(read_cmake_build(build_path, cur)) 
- return rules + """Collects and returns all cmake rules from root path recursively.""" + rules = [] + for cur, _, _ in os.walk(root_path): + build_path = os.path.join(cur, "CMakeLists.txt") + if os.path.exists(build_path): + rules.extend(read_cmake_build(build_path, cur)) + return rules def pairing_bazel_and_cmake_rules(bazel_rules, cmake_rules): - """Returns a pair map between bazel rules and cmake rules based on - the similarity of the file list in the rule. This is because - cmake build and bazel build of abseil are not identical. - """ - pair_map = {} - for rule in bazel_rules: - best_crule, best_similarity = None, 0 - for crule in cmake_rules: - similarity = len( - set(rule.srcs + rule.hdrs + rule.textual_hdrs).intersection( - set(crule.srcs + crule.hdrs + crule.textual_hdrs))) - if similarity > best_similarity: - best_crule, best_similarity = crule, similarity - if best_crule: - pair_map[(rule.package, rule.name)] = best_crule.name - return pair_map + """Returns a pair map between bazel rules and cmake rules based on + the similarity of the file list in the rule. This is because + cmake build and bazel build of abseil are not identical. + """ + pair_map = {} + for rule in bazel_rules: + best_crule, best_similarity = None, 0 + for crule in cmake_rules: + similarity = len( + set(rule.srcs + rule.hdrs + rule.textual_hdrs).intersection( + set(crule.srcs + crule.hdrs + crule.textual_hdrs) + ) + ) + if similarity > best_similarity: + best_crule, best_similarity = crule, similarity + if best_crule: + pair_map[(rule.package, rule.name)] = best_crule.name + return pair_map def resolve_hdrs(files): - return [ABSEIL_PATH + "/" + f for f in files if f.endswith((".h", ".inc"))] + return [ABSEIL_PATH + "/" + f for f in files if f.endswith((".h", ".inc"))] def resolve_srcs(files): - return [ABSEIL_PATH + "/" + f for f in files if f.endswith(".cc")] + return [ABSEIL_PATH + "/" + f for f in files if f.endswith(".cc")] def resolve_deps(targets): - return [(t[2:] if t.startswith("//") else t) for t in targets] + return [(t[2:] if t.startswith("//") else t) for t in targets] def generate_builds(root_path): - """Generates builds from all BUILD files under absl directory.""" - bazel_rules = list( - filter(lambda r: r.type == "cc_library" and not r.testonly, - collect_bazel_rules(root_path))) - cmake_rules = list( - filter(lambda r: r.type == "absl_cc_library" and not r.testonly, - collect_cmake_rules(root_path))) - pair_map = pairing_bazel_and_cmake_rules(bazel_rules, cmake_rules) - builds = [] - for rule in sorted(bazel_rules, key=lambda r: r.package[2:] + ":" + r.name): - p = { - "name": - rule.package[2:] + ":" + rule.name, - "cmake_target": - pair_map.get((rule.package, rule.name)) or "", - "headers": - sorted(resolve_hdrs(rule.srcs + rule.hdrs + rule.textual_hdrs)), - "src": - sorted(resolve_srcs(rule.srcs + rule.hdrs + rule.textual_hdrs)), - "deps": - sorted(resolve_deps(rule.deps)), - } - builds.append(p) - return builds + """Generates builds from all BUILD files under absl directory.""" + bazel_rules = list( + filter( + lambda r: r.type == "cc_library" and not r.testonly, + collect_bazel_rules(root_path), + ) + ) + cmake_rules = list( + filter( + lambda r: r.type == "absl_cc_library" and not r.testonly, + collect_cmake_rules(root_path), + ) + ) + pair_map = pairing_bazel_and_cmake_rules(bazel_rules, cmake_rules) + builds = [] + for rule in sorted(bazel_rules, key=lambda r: r.package[2:] + ":" + r.name): + p = { + "name": rule.package[2:] + ":" + rule.name, + "cmake_target": 
pair_map.get((rule.package, rule.name)) or "", + "headers": sorted( + resolve_hdrs(rule.srcs + rule.hdrs + rule.textual_hdrs) + ), + "src": sorted( + resolve_srcs(rule.srcs + rule.hdrs + rule.textual_hdrs) + ), + "deps": sorted(resolve_deps(rule.deps)), + } + builds.append(p) + return builds def main(): - previous_dir = os.getcwd() - os.chdir(ABSEIL_PATH) - builds = generate_builds("absl") - os.chdir(previous_dir) - with open(OUTPUT_PATH, 'w') as outfile: - outfile.write(yaml.dump(builds, indent=2)) + previous_dir = os.getcwd() + os.chdir(ABSEIL_PATH) + builds = generate_builds("absl") + os.chdir(previous_dir) + with open(OUTPUT_PATH, "w") as outfile: + outfile.write(yaml.dump(builds, indent=2)) if __name__ == "__main__": - main() + main() diff --git a/src/benchmark/gen_build_yaml.py b/src/benchmark/gen_build_yaml.py index 45c39ffb9703a..4aea994500d12 100755 --- a/src/benchmark/gen_build_yaml.py +++ b/src/benchmark/gen_build_yaml.py @@ -19,27 +19,23 @@ import glob import yaml -os.chdir(os.path.dirname(sys.argv[0]) + '/../..') +os.chdir(os.path.dirname(sys.argv[0]) + "/../..") out = {} -out['libs'] = [{ - 'name': - 'benchmark', - 'build': - 'private', - 'language': - 'c++', - 'secure': - False, - 'defaults': - 'benchmark', - 'src': - sorted(glob.glob('third_party/benchmark/src/*.cc')), - 'headers': - sorted( - glob.glob('third_party/benchmark/src/*.h') + - glob.glob('third_party/benchmark/include/benchmark/*.h')), -}] +out["libs"] = [ + { + "name": "benchmark", + "build": "private", + "language": "c++", + "secure": False, + "defaults": "benchmark", + "src": sorted(glob.glob("third_party/benchmark/src/*.cc")), + "headers": sorted( + glob.glob("third_party/benchmark/src/*.h") + + glob.glob("third_party/benchmark/include/benchmark/*.h") + ), + } +] print(yaml.dump(out)) diff --git a/src/boringssl/gen_build_yaml.py b/src/boringssl/gen_build_yaml.py index 1b77fedad01e1..ec3b716e35a8d 100755 --- a/src/boringssl/gen_build_yaml.py +++ b/src/boringssl/gen_build_yaml.py @@ -20,112 +20,114 @@ run_dir = os.path.dirname(sys.argv[0]) sources_path = os.path.abspath( - os.path.join(run_dir, - '../../third_party/boringssl-with-bazel/sources.json')) + os.path.join(run_dir, "../../third_party/boringssl-with-bazel/sources.json") +) try: - with open(sources_path, 'r') as s: + with open(sources_path, "r") as s: sources = json.load(s) except IOError: sources_path = os.path.abspath( - os.path.join(run_dir, - '../../../../third_party/openssl/boringssl/sources.json')) - with open(sources_path, 'r') as s: + os.path.join( + run_dir, "../../../../third_party/openssl/boringssl/sources.json" + ) + ) + with open(sources_path, "r") as s: sources = json.load(s) def map_dir(filename): - return 'third_party/boringssl-with-bazel/' + filename + return "third_party/boringssl-with-bazel/" + filename class Grpc(object): - """Adapter for boring-SSL json sources files. 
""" + """Adapter for boring-SSL json sources files.""" def __init__(self, sources): self.yaml = None self.WriteFiles(sources) def WriteFiles(self, files): - test_binaries = ['ssl_test', 'crypto_test'] + test_binaries = ["ssl_test", "crypto_test"] asm_outputs = { - key: value for key, value in files.items() if any( - f.endswith(".S") or f.endswith(".asm") for f in value) + key: value + for key, value in files.items() + if any(f.endswith(".S") or f.endswith(".asm") for f in value) } self.yaml = { - '#': - 'generated with src/boringssl/gen_build_yaml.py', - 'raw_boringssl_build_output_for_debugging': { - 'files': files, + "#": "generated with src/boringssl/gen_build_yaml.py", + "raw_boringssl_build_output_for_debugging": { + "files": files, }, - 'libs': [ + "libs": [ { - 'name': - 'boringssl', - 'build': - 'private', - 'language': - 'c', - 'secure': - False, - 'src': - sorted( - map_dir(f) for f in files['ssl'] + files['crypto']), - 'asm_src': { - k: [map_dir(f) for f in value - ] for k, value in asm_outputs.items() + "name": "boringssl", + "build": "private", + "language": "c", + "secure": False, + "src": sorted( + map_dir(f) for f in files["ssl"] + files["crypto"] + ), + "asm_src": { + k: [map_dir(f) for f in value] + for k, value in asm_outputs.items() }, - 'headers': - sorted( - map_dir(f) - # We want to include files['fips_fragments'], but not build them as objects. - # See https://boringssl-review.googlesource.com/c/boringssl/+/16946 - for f in files['ssl_headers'] + - files['ssl_internal_headers'] + - files['crypto_headers'] + - files['crypto_internal_headers'] + - files['fips_fragments']), - 'boringssl': - True, - 'defaults': - 'boringssl', + "headers": sorted( + map_dir(f) + # We want to include files['fips_fragments'], but not build them as objects. 
+ # See https://boringssl-review.googlesource.com/c/boringssl/+/16946 + for f in files["ssl_headers"] + + files["ssl_internal_headers"] + + files["crypto_headers"] + + files["crypto_internal_headers"] + + files["fips_fragments"] + ), + "boringssl": True, + "defaults": "boringssl", }, { - 'name': 'boringssl_test_util', - 'build': 'private', - 'language': 'c++', - 'secure': False, - 'boringssl': True, - 'defaults': 'boringssl', - 'src': [map_dir(f) for f in sorted(files['test_support'])], + "name": "boringssl_test_util", + "build": "private", + "language": "c++", + "secure": False, + "boringssl": True, + "defaults": "boringssl", + "src": [map_dir(f) for f in sorted(files["test_support"])], + }, + ], + "targets": [ + { + "name": "boringssl_%s" % test, + "build": "test", + "run": False, + "secure": False, + "language": "c++", + "src": sorted(map_dir(f) for f in files[test]), + "vs_proj_dir": "test/boringssl", + "boringssl": True, + "defaults": "boringssl", + "deps": [ + "boringssl_test_util", + "boringssl", + ], + } + for test in test_binaries + ], + "tests": [ + { + "name": "boringssl_%s" % test, + "args": [], + "exclude_configs": ["asan", "ubsan"], + "ci_platforms": ["linux", "mac", "posix", "windows"], + "platforms": ["linux", "mac", "posix", "windows"], + "flaky": False, + "gtest": True, + "language": "c++", + "boringssl": True, + "defaults": "boringssl", + "cpu_cost": 1.0, } + for test in test_binaries ], - 'targets': [{ - 'name': 'boringssl_%s' % test, - 'build': 'test', - 'run': False, - 'secure': False, - 'language': 'c++', - 'src': sorted(map_dir(f) for f in files[test]), - 'vs_proj_dir': 'test/boringssl', - 'boringssl': True, - 'defaults': 'boringssl', - 'deps': [ - 'boringssl_test_util', - 'boringssl', - ] - } for test in test_binaries], - 'tests': [{ - 'name': 'boringssl_%s' % test, - 'args': [], - 'exclude_configs': ['asan', 'ubsan'], - 'ci_platforms': ['linux', 'mac', 'posix', 'windows'], - 'platforms': ['linux', 'mac', 'posix', 'windows'], - 'flaky': False, - 'gtest': True, - 'language': 'c++', - 'boringssl': True, - 'defaults': 'boringssl', - 'cpu_cost': 1.0 - } for test in test_binaries] } diff --git a/src/c-ares/gen_build_yaml.py b/src/c-ares/gen_build_yaml.py index 607704ca278fe..00a2e319c179e 100755 --- a/src/c-ares/gen_build_yaml.py +++ b/src/c-ares/gen_build_yaml.py @@ -19,7 +19,7 @@ import sys import yaml -os.chdir(os.path.dirname(sys.argv[0]) + '/../..') +os.chdir(os.path.dirname(sys.argv[0]) + "/../..") out = {} @@ -30,127 +30,124 @@ def gen_ares_build(x): subprocess.call("third_party/cares/cares/configure", shell=True) def config_platform(x): - if 'darwin' in sys.platform: - return 'src/cares/cares/config_darwin/ares_config.h' - if 'freebsd' in sys.platform: - return 'src/cares/cares/config_freebsd/ares_config.h' - if 'linux' in sys.platform: - return 'src/cares/cares/config_linux/ares_config.h' - if 'openbsd' in sys.platform: - return 'src/cares/cares/config_openbsd/ares_config.h' - if not os.path.isfile('third_party/cares/cares/ares_config.h'): + if "darwin" in sys.platform: + return "src/cares/cares/config_darwin/ares_config.h" + if "freebsd" in sys.platform: + return "src/cares/cares/config_freebsd/ares_config.h" + if "linux" in sys.platform: + return "src/cares/cares/config_linux/ares_config.h" + if "openbsd" in sys.platform: + return "src/cares/cares/config_openbsd/ares_config.h" + if not os.path.isfile("third_party/cares/cares/ares_config.h"): gen_ares_build(x) - return 'third_party/cares/cares/ares_config.h' + return "third_party/cares/cares/ares_config.h" def 
ares_build(x): - if os.path.isfile('src/cares/cares/ares_build.h'): - return 'src/cares/cares/ares_build.h' - if not os.path.isfile('third_party/cares/cares/include/ares_build.h'): + if os.path.isfile("src/cares/cares/ares_build.h"): + return "src/cares/cares/ares_build.h" + if not os.path.isfile("third_party/cares/cares/include/ares_build.h"): gen_ares_build(x) - return 'third_party/cares/cares/include/ares_build.h' + return "third_party/cares/cares/include/ares_build.h" - out['libs'] = [{ - 'name': - 'ares', - 'defaults': - 'ares', - 'build': - 'private', - 'language': - 'c', - 'secure': - False, - 'src': [ - "third_party/cares/cares/src/lib/ares_init.c", - "third_party/cares/cares/src/lib/ares_expand_string.c", - "third_party/cares/cares/src/lib/ares_strcasecmp.c", - "third_party/cares/cares/src/lib/ares_destroy.c", - "third_party/cares/cares/src/lib/ares_free_string.c", - "third_party/cares/cares/src/lib/ares__timeval.c", - "third_party/cares/cares/src/lib/ares_library_init.c", - "third_party/cares/cares/src/lib/ares_getsock.c", - "third_party/cares/cares/src/lib/ares_process.c", - "third_party/cares/cares/src/lib/ares_create_query.c", - "third_party/cares/cares/src/lib/ares_fds.c", - "third_party/cares/cares/src/lib/ares_gethostbyname.c", - "third_party/cares/cares/src/lib/ares_mkquery.c", - "third_party/cares/cares/src/lib/ares_freeaddrinfo.c", - "third_party/cares/cares/src/lib/ares_strdup.c", - "third_party/cares/cares/src/lib/ares_timeout.c", - "third_party/cares/cares/src/lib/ares_getnameinfo.c", - "third_party/cares/cares/src/lib/ares_parse_soa_reply.c", - "third_party/cares/cares/src/lib/ares_parse_naptr_reply.c", - "third_party/cares/cares/src/lib/ares_parse_a_reply.c", - "third_party/cares/cares/src/lib/ares_send.c", - "third_party/cares/cares/src/lib/ares_nowarn.c", - "third_party/cares/cares/src/lib/ares__sortaddrinfo.c", - "third_party/cares/cares/src/lib/ares_android.c", - "third_party/cares/cares/src/lib/ares_strerror.c", - "third_party/cares/cares/src/lib/ares_parse_caa_reply.c", - "third_party/cares/cares/src/lib/ares__close_sockets.c", - "third_party/cares/cares/src/lib/ares_llist.c", - "third_party/cares/cares/src/lib/ares_parse_aaaa_reply.c", - "third_party/cares/cares/src/lib/ares_getaddrinfo.c", - "third_party/cares/cares/src/lib/ares_parse_ns_reply.c", - "third_party/cares/cares/src/lib/windows_port.c", - "third_party/cares/cares/src/lib/bitncmp.c", - "third_party/cares/cares/src/lib/ares_strsplit.c", - "third_party/cares/cares/src/lib/ares_data.c", - "third_party/cares/cares/src/lib/ares_free_hostent.c", - "third_party/cares/cares/src/lib/ares_platform.c", - "third_party/cares/cares/src/lib/ares_parse_txt_reply.c", - "third_party/cares/cares/src/lib/ares__parse_into_addrinfo.c", - "third_party/cares/cares/src/lib/ares_gethostbyaddr.c", - "third_party/cares/cares/src/lib/ares_parse_srv_reply.c", - "third_party/cares/cares/src/lib/ares_version.c", - "third_party/cares/cares/src/lib/ares_getenv.c", - "third_party/cares/cares/src/lib/ares_search.c", - "third_party/cares/cares/src/lib/ares_parse_mx_reply.c", - "third_party/cares/cares/src/lib/ares__get_hostent.c", - "third_party/cares/cares/src/lib/ares__readaddrinfo.c", - "third_party/cares/cares/src/lib/ares_parse_ptr_reply.c", - "third_party/cares/cares/src/lib/ares__read_line.c", - "third_party/cares/cares/src/lib/ares_query.c", - "third_party/cares/cares/src/lib/ares_options.c", - "third_party/cares/cares/src/lib/inet_net_pton.c", - "third_party/cares/cares/src/lib/ares_expand_name.c", - 
"third_party/cares/cares/src/lib/inet_ntop.c", - "third_party/cares/cares/src/lib/ares_cancel.c", - "third_party/cares/cares/src/lib/ares_writev.c", - ], - 'headers': [ - "third_party/cares/ares_build.h", - "third_party/cares/cares/include/ares_version.h", - "third_party/cares/cares/include/ares.h", - "third_party/cares/cares/include/ares_rules.h", - "third_party/cares/cares/include/ares_dns.h", - "third_party/cares/cares/src/lib/ares_data.h", - "third_party/cares/cares/src/lib/ares_strsplit.h", - "third_party/cares/cares/src/lib/bitncmp.h", - "third_party/cares/cares/src/lib/ares_iphlpapi.h", - "third_party/cares/cares/src/lib/ares_inet_net_pton.h", - "third_party/cares/cares/src/lib/ares_getenv.h", - "third_party/cares/cares/src/lib/ares_platform.h", - "third_party/cares/cares/src/lib/ares_writev.h", - "third_party/cares/cares/src/lib/ares_private.h", - "third_party/cares/cares/src/lib/ares_setup.h", - "third_party/cares/cares/src/lib/config-win32.h", - "third_party/cares/cares/src/lib/ares_strcasecmp.h", - "third_party/cares/cares/src/lib/setup_once.h", - "third_party/cares/cares/src/lib/ares_ipv6.h", - "third_party/cares/cares/src/lib/ares_library_init.h", - "third_party/cares/cares/src/lib/ares_nameser.h", - "third_party/cares/cares/src/lib/ares_strdup.h", - "third_party/cares/cares/src/lib/config-dos.h", - "third_party/cares/cares/src/lib/ares_llist.h", - "third_party/cares/cares/src/lib/ares_nowarn.h", - "third_party/cares/cares/src/lib/ares_android.h", - "third_party/cares/config_darwin/ares_config.h", - "third_party/cares/config_freebsd/ares_config.h", - "third_party/cares/config_linux/ares_config.h", - "third_party/cares/config_openbsd/ares_config.h" - ], - }] + out["libs"] = [ + { + "name": "ares", + "defaults": "ares", + "build": "private", + "language": "c", + "secure": False, + "src": [ + "third_party/cares/cares/src/lib/ares_init.c", + "third_party/cares/cares/src/lib/ares_expand_string.c", + "third_party/cares/cares/src/lib/ares_strcasecmp.c", + "third_party/cares/cares/src/lib/ares_destroy.c", + "third_party/cares/cares/src/lib/ares_free_string.c", + "third_party/cares/cares/src/lib/ares__timeval.c", + "third_party/cares/cares/src/lib/ares_library_init.c", + "third_party/cares/cares/src/lib/ares_getsock.c", + "third_party/cares/cares/src/lib/ares_process.c", + "third_party/cares/cares/src/lib/ares_create_query.c", + "third_party/cares/cares/src/lib/ares_fds.c", + "third_party/cares/cares/src/lib/ares_gethostbyname.c", + "third_party/cares/cares/src/lib/ares_mkquery.c", + "third_party/cares/cares/src/lib/ares_freeaddrinfo.c", + "third_party/cares/cares/src/lib/ares_strdup.c", + "third_party/cares/cares/src/lib/ares_timeout.c", + "third_party/cares/cares/src/lib/ares_getnameinfo.c", + "third_party/cares/cares/src/lib/ares_parse_soa_reply.c", + "third_party/cares/cares/src/lib/ares_parse_naptr_reply.c", + "third_party/cares/cares/src/lib/ares_parse_a_reply.c", + "third_party/cares/cares/src/lib/ares_send.c", + "third_party/cares/cares/src/lib/ares_nowarn.c", + "third_party/cares/cares/src/lib/ares__sortaddrinfo.c", + "third_party/cares/cares/src/lib/ares_android.c", + "third_party/cares/cares/src/lib/ares_strerror.c", + "third_party/cares/cares/src/lib/ares_parse_caa_reply.c", + "third_party/cares/cares/src/lib/ares__close_sockets.c", + "third_party/cares/cares/src/lib/ares_llist.c", + "third_party/cares/cares/src/lib/ares_parse_aaaa_reply.c", + "third_party/cares/cares/src/lib/ares_getaddrinfo.c", + "third_party/cares/cares/src/lib/ares_parse_ns_reply.c", + 
"third_party/cares/cares/src/lib/windows_port.c", + "third_party/cares/cares/src/lib/bitncmp.c", + "third_party/cares/cares/src/lib/ares_strsplit.c", + "third_party/cares/cares/src/lib/ares_data.c", + "third_party/cares/cares/src/lib/ares_free_hostent.c", + "third_party/cares/cares/src/lib/ares_platform.c", + "third_party/cares/cares/src/lib/ares_parse_txt_reply.c", + "third_party/cares/cares/src/lib/ares__parse_into_addrinfo.c", + "third_party/cares/cares/src/lib/ares_gethostbyaddr.c", + "third_party/cares/cares/src/lib/ares_parse_srv_reply.c", + "third_party/cares/cares/src/lib/ares_version.c", + "third_party/cares/cares/src/lib/ares_getenv.c", + "third_party/cares/cares/src/lib/ares_search.c", + "third_party/cares/cares/src/lib/ares_parse_mx_reply.c", + "third_party/cares/cares/src/lib/ares__get_hostent.c", + "third_party/cares/cares/src/lib/ares__readaddrinfo.c", + "third_party/cares/cares/src/lib/ares_parse_ptr_reply.c", + "third_party/cares/cares/src/lib/ares__read_line.c", + "third_party/cares/cares/src/lib/ares_query.c", + "third_party/cares/cares/src/lib/ares_options.c", + "third_party/cares/cares/src/lib/inet_net_pton.c", + "third_party/cares/cares/src/lib/ares_expand_name.c", + "third_party/cares/cares/src/lib/inet_ntop.c", + "third_party/cares/cares/src/lib/ares_cancel.c", + "third_party/cares/cares/src/lib/ares_writev.c", + ], + "headers": [ + "third_party/cares/ares_build.h", + "third_party/cares/cares/include/ares_version.h", + "third_party/cares/cares/include/ares.h", + "third_party/cares/cares/include/ares_rules.h", + "third_party/cares/cares/include/ares_dns.h", + "third_party/cares/cares/src/lib/ares_data.h", + "third_party/cares/cares/src/lib/ares_strsplit.h", + "third_party/cares/cares/src/lib/bitncmp.h", + "third_party/cares/cares/src/lib/ares_iphlpapi.h", + "third_party/cares/cares/src/lib/ares_inet_net_pton.h", + "third_party/cares/cares/src/lib/ares_getenv.h", + "third_party/cares/cares/src/lib/ares_platform.h", + "third_party/cares/cares/src/lib/ares_writev.h", + "third_party/cares/cares/src/lib/ares_private.h", + "third_party/cares/cares/src/lib/ares_setup.h", + "third_party/cares/cares/src/lib/config-win32.h", + "third_party/cares/cares/src/lib/ares_strcasecmp.h", + "third_party/cares/cares/src/lib/setup_once.h", + "third_party/cares/cares/src/lib/ares_ipv6.h", + "third_party/cares/cares/src/lib/ares_library_init.h", + "third_party/cares/cares/src/lib/ares_nameser.h", + "third_party/cares/cares/src/lib/ares_strdup.h", + "third_party/cares/cares/src/lib/config-dos.h", + "third_party/cares/cares/src/lib/ares_llist.h", + "third_party/cares/cares/src/lib/ares_nowarn.h", + "third_party/cares/cares/src/lib/ares_android.h", + "third_party/cares/config_darwin/ares_config.h", + "third_party/cares/config_freebsd/ares_config.h", + "third_party/cares/config_linux/ares_config.h", + "third_party/cares/config_openbsd/ares_config.h", + ], + } + ] except: pass diff --git a/src/csharp/Grpc.Tools.Tests/scripts/fakeprotoc.py b/src/csharp/Grpc.Tools.Tests/scripts/fakeprotoc.py index 29fb9f0f3c23a..01f4dbfb76e83 100755 --- a/src/csharp/Grpc.Tools.Tests/scripts/fakeprotoc.py +++ b/src/csharp/Grpc.Tools.Tests/scripts/fakeprotoc.py @@ -117,7 +117,11 @@ def _parse_protoc_arguments(protoc_args, projectdir): # msbuild integration uses, but it's not the only way. 
(name, value) = arg.split("=", 1) - if name == "--dependency_out" or name == "--grpc_out" or name == "--csharp_out": + if ( + name == "--dependency_out" + or name == "--grpc_out" + or name == "--csharp_out" + ): # For args that contain a path, make the path absolute and normalize it # to make it easier to assert equality in tests. value = _normalized_absolute_path(value) @@ -152,7 +156,8 @@ def _normalized_relative_to_projectdir(file, projectdir): """Convert a file path to one relative to the project directory.""" try: return _normalize_slashes( - os.path.relpath(os.path.abspath(file), projectdir)) + os.path.relpath(os.path.abspath(file), projectdir) + ) except ValueError: # On Windows if the paths are on different drives then we get this error # Just return the absolute path @@ -170,7 +175,7 @@ def _normalize_slashes(path): def _write_or_update_results_json(log_dir, protofile, protoc_arg_dict): - """ Write or update the results JSON file """ + """Write or update the results JSON file""" # Read existing json. # Since protoc may be called more than once each build/test if there is @@ -182,9 +187,9 @@ def _write_or_update_results_json(log_dir, protofile, protoc_arg_dict): results_json = json.load(forig) else: results_json = {} - results_json['Files'] = {} + results_json["Files"] = {} - results_json['Files'][protofile] = protoc_arg_dict + results_json["Files"][protofile] = protoc_arg_dict results_json["Metadata"] = {"timestamp": str(datetime.datetime.now())} with open(fname, "w") as fout: @@ -227,8 +232,9 @@ def _is_grpc_out_file(csfile): return csfile.endswith("Grpc.cs") -def _generate_cs_files(protofile, cs_files_to_generate, grpc_out_dir, - csharp_out_dir, projectdir): +def _generate_cs_files( + protofile, cs_files_to_generate, grpc_out_dir, csharp_out_dir, projectdir +): """Create expected cs files.""" _write_debug("\ngenerate_cs_files") @@ -262,8 +268,13 @@ def _generate_cs_files(protofile, cs_files_to_generate, grpc_out_dir, print("// Generated by fake protoc: %s" % timestamp, file=fout) -def _create_dependency_file(protofile, cs_files_to_generate, dependencyfile, - grpc_out_dir, csharp_out_dir): +def _create_dependency_file( + protofile, + cs_files_to_generate, + dependencyfile, + grpc_out_dir, + csharp_out_dir, +): """Create the expected dependency file.""" _write_debug("\ncreate_dependency_file") @@ -312,21 +323,21 @@ def _get_argument_last_occurrence_or_none(protoc_arg_dict, name): def main(): # Check environment variables for the additional arguments used in the tests. 
- projectdir = _getenv('FAKEPROTOC_PROJECTDIR') + projectdir = _getenv("FAKEPROTOC_PROJECTDIR") if not projectdir: print("FAKEPROTOC_PROJECTDIR not set") sys.exit(1) projectdir = os.path.abspath(projectdir) # Output directory for generated files and output file - protoc_outdir = _getenv('FAKEPROTOC_OUTDIR') + protoc_outdir = _getenv("FAKEPROTOC_OUTDIR") if not protoc_outdir: print("FAKEPROTOC_OUTDIR not set") sys.exit(1) protoc_outdir = os.path.abspath(protoc_outdir) # Get list of expected generated files from env variable - generate_expected = _getenv('FAKEPROTOC_GENERATE_EXPECTED') + generate_expected = _getenv("FAKEPROTOC_GENERATE_EXPECTED") if not generate_expected: print("FAKEPROTOC_GENERATE_EXPECTED not set") sys.exit(1) @@ -338,9 +349,13 @@ def main(): _open_debug_log("%s/fakeprotoc_log.txt" % log_dir) _write_debug( - ("##### fakeprotoc called at %s\n" + "FAKEPROTOC_PROJECTDIR = %s\n" + - "FAKEPROTOC_GENERATE_EXPECTED = %s\n") % - (datetime.datetime.now(), projectdir, generate_expected)) + ( + "##### fakeprotoc called at %s\n" + + "FAKEPROTOC_PROJECTDIR = %s\n" + + "FAKEPROTOC_GENERATE_EXPECTED = %s\n" + ) + % (datetime.datetime.now(), projectdir, generate_expected) + ) proto_to_generated = _parse_generate_expected(generate_expected) protoc_args = _read_protoc_arguments() @@ -349,43 +364,52 @@ def main(): # If argument was passed multiple times, take the last occurrence of it. # TODO(jtattermusch): handle multiple occurrences of the same argument dependencyfile = _get_argument_last_occurrence_or_none( - protoc_arg_dict, '--dependency_out') - grpcout = _get_argument_last_occurrence_or_none(protoc_arg_dict, - '--grpc_out') - csharpout = _get_argument_last_occurrence_or_none(protoc_arg_dict, - '--csharp_out') + protoc_arg_dict, "--dependency_out" + ) + grpcout = _get_argument_last_occurrence_or_none( + protoc_arg_dict, "--grpc_out" + ) + csharpout = _get_argument_last_occurrence_or_none( + protoc_arg_dict, "--csharp_out" + ) # --grpc_out might not be set in which case use --csharp_out if grpcout is None: grpcout = csharpout - if len(protoc_arg_dict.get('protofile')) != 1: + if len(protoc_arg_dict.get("protofile")) != 1: # regular protoc can process multiple .proto files passed at once, but we know # the Grpc.Tools msbuild integration only ever passes one .proto file per invocation. print( - "Expecting to get exactly one .proto file argument per fakeprotoc invocation." + "Expecting to get exactly one .proto file argument per fakeprotoc" + " invocation." 
) sys.exit(1) - protofile = protoc_arg_dict.get('protofile')[0] + protofile = protoc_arg_dict.get("protofile")[0] cs_files_to_generate = _get_cs_files_to_generate( - protofile=protofile, proto_to_generated=proto_to_generated) - - _create_dependency_file(protofile=protofile, - cs_files_to_generate=cs_files_to_generate, - dependencyfile=dependencyfile, - grpc_out_dir=grpcout, - csharp_out_dir=csharpout) - - _generate_cs_files(protofile=protofile, - cs_files_to_generate=cs_files_to_generate, - grpc_out_dir=grpcout, - csharp_out_dir=csharpout, - projectdir=projectdir) - - _write_or_update_results_json(log_dir=log_dir, - protofile=protofile, - protoc_arg_dict=protoc_arg_dict) + protofile=protofile, proto_to_generated=proto_to_generated + ) + + _create_dependency_file( + protofile=protofile, + cs_files_to_generate=cs_files_to_generate, + dependencyfile=dependencyfile, + grpc_out_dir=grpcout, + csharp_out_dir=csharpout, + ) + + _generate_cs_files( + protofile=protofile, + cs_files_to_generate=cs_files_to_generate, + grpc_out_dir=grpcout, + csharp_out_dir=csharpout, + projectdir=projectdir, + ) + + _write_or_update_results_json( + log_dir=log_dir, protofile=protofile, protoc_arg_dict=protoc_arg_dict + ) _close_debug_log() diff --git a/src/objective-c/change-comments.py b/src/objective-c/change-comments.py index 82645a419f5f5..a2fa264fa6b0c 100755 --- a/src/objective-c/change-comments.py +++ b/src/objective-c/change-comments.py @@ -22,9 +22,11 @@ sys.exit() for file_name in sys.argv[1:]: - - print("Modifying format of {file} comments in place...".format( - file=file_name,)) + print( + "Modifying format of {file} comments in place...".format( + file=file_name, + ) + ) # Input @@ -54,7 +56,7 @@ def flush_output(): # Pattern matching - comment_regex = r'^(\s*)//\s(.*)$' + comment_regex = r"^(\s*)//\s(.*)$" def is_comment(line): return re.search(comment_regex, line) @@ -84,8 +86,11 @@ def format_as_block(comment_block): if len(comment_block) == 1: return [indent + "/** " + content(comment_block[0]) + " */\n"] - block = ["/**"] + [" * " + content(line) for line in comment_block - ] + [" */"] + block = ( + ["/**"] + + [" * " + content(line) for line in comment_block] + + [" */"] + ) return [indent + line.rstrip() + "\n" for line in block] # Main algorithm @@ -97,7 +102,7 @@ def format_as_block(comment_block): comment_block = [] # Get all lines in the same comment block. We could restrict the indentation # to be the same as the first line of the block, but it's probably ok. - while (next_line(is_comment)): + while next_line(is_comment): comment_block.append(read_line()) for line in format_as_block(comment_block): diff --git a/src/php/bin/xds_manager.py b/src/php/bin/xds_manager.py index 8ba6733bba333..85057c4bf0a06 100755 --- a/src/php/bin/xds_manager.py +++ b/src/php/bin/xds_manager.py @@ -25,55 +25,69 @@ # processes and reports back to the main PHP interop client each # of the child RPCs' status code. 
-if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--tmp_file1', nargs='?', default='') - parser.add_argument('--tmp_file2', nargs='?', default='') - parser.add_argument('--bootstrap_path', nargs='?', default='') + parser.add_argument("--tmp_file1", nargs="?", default="") + parser.add_argument("--tmp_file2", nargs="?", default="") + parser.add_argument("--bootstrap_path", nargs="?", default="") args = parser.parse_args() - server_address = '' + server_address = "" rpcs_started = [] open_processes = {} client_env = dict(os.environ) - client_env['GRPC_XDS_BOOTSTRAP'] = args.bootstrap_path + client_env["GRPC_XDS_BOOTSTRAP"] = args.bootstrap_path while True: # tmp_file1 contains a list of RPCs (and their spec) the parent process # wants executed - f1 = open(args.tmp_file1, 'r+') + f1 = open(args.tmp_file1, "r+") fcntl.flock(f1, fcntl.LOCK_EX) while True: key = f1.readline() if not key: break key = key.strip() - if key.startswith('server_address'): + if key.startswith("server_address"): if not server_address: server_address = key[15:] elif not key in rpcs_started: # format here needs to be in sync with # src/php/tests/interop/xds_client.php - items = key.split('|') + items = key.split("|") num = items[0] metadata = items[2] timeout_sec = items[3] - if items[1] == 'UnaryCall': - p = subprocess.Popen([ - 'php', '-d', 'extension=grpc.so', '-d', - 'extension=pthreads.so', - 'src/php/tests/interop/xds_unary_call.php', - '--server=' + server_address, '--num=' + str(num), - '--metadata=' + metadata, '--timeout_sec=' + timeout_sec - ], - env=client_env) - elif items[1] == 'EmptyCall': - p = subprocess.Popen([ - 'php', '-d', 'extension=grpc.so', '-d', - 'extension=pthreads.so', - 'src/php/tests/interop/xds_empty_call.php', - '--server=' + server_address, '--num=' + str(num), - '--metadata=' + metadata, '--timeout_sec=' + timeout_sec - ], - env=client_env) + if items[1] == "UnaryCall": + p = subprocess.Popen( + [ + "php", + "-d", + "extension=grpc.so", + "-d", + "extension=pthreads.so", + "src/php/tests/interop/xds_unary_call.php", + "--server=" + server_address, + "--num=" + str(num), + "--metadata=" + metadata, + "--timeout_sec=" + timeout_sec, + ], + env=client_env, + ) + elif items[1] == "EmptyCall": + p = subprocess.Popen( + [ + "php", + "-d", + "extension=grpc.so", + "-d", + "extension=pthreads.so", + "src/php/tests/interop/xds_empty_call.php", + "--server=" + server_address, + "--num=" + str(num), + "--metadata=" + metadata, + "--timeout_sec=" + timeout_sec, + ], + env=client_env, + ) else: continue rpcs_started.append(key) @@ -82,7 +96,7 @@ fcntl.flock(f1, fcntl.LOCK_UN) f1.close() # tmp_file2 contains the RPC result of each key received from tmp_file1 - f2 = open(args.tmp_file2, 'a') + f2 = open(args.tmp_file2, "a") fcntl.flock(f2, fcntl.LOCK_EX) keys_to_delete = [] for key, process in open_processes.items(): @@ -90,7 +104,7 @@ if result is not None: # format here needs to be in sync with # src/php/tests/interop/xds_client.php - f2.write(key + ',' + str(process.returncode) + "\n") + f2.write(key + "," + str(process.returncode) + "\n") keys_to_delete.append(key) for key in keys_to_delete: del open_processes[key] diff --git a/src/proto/gen_build_yaml.py b/src/proto/gen_build_yaml.py index dea444ff40e22..a3d7de11e6e34 100755 --- a/src/proto/gen_build_yaml.py +++ b/src/proto/gen_build_yaml.py @@ -32,52 +32,54 @@ def update_deps(key, proto_filename, deps, deps_external, is_trans, visited): imp_proto = imp.group(1) # This indicates an 
external dependency, which we should handle # differently and not traverse recursively - if imp_proto.startswith('google/'): + if imp_proto.startswith("google/"): if key not in deps_external: deps_external[key] = [] deps_external[key].append(imp_proto[:-6]) continue # In case that the path is changed by copybara, # revert the change to avoid file error. - if imp_proto.startswith('third_party/grpc'): + if imp_proto.startswith("third_party/grpc"): imp_proto = imp_proto[17:] if key not in deps: deps[key] = [] deps[key].append(imp_proto[:-6]) if is_trans: - update_deps(key, imp_proto, deps, deps_external, is_trans, - visited) + update_deps( + key, imp_proto, deps, deps_external, is_trans, visited + ) def main(): proto_dir = os.path.abspath(os.path.dirname(sys.argv[0])) - os.chdir(os.path.join(proto_dir, '../..')) + os.chdir(os.path.join(proto_dir, "../..")) deps = {} deps_trans = {} deps_external = {} deps_external_trans = {} - for root, dirs, files in os.walk('src/proto'): + for root, dirs, files in os.walk("src/proto"): for f in files: - if f[-6:] != '.proto': + if f[-6:] != ".proto": continue look_at = os.path.join(root, f) deps_for = look_at[:-6] # First level deps update_deps(deps_for, look_at, deps, deps_external, False, []) # Transitive deps - update_deps(deps_for, look_at, deps_trans, deps_external_trans, - True, []) + update_deps( + deps_for, look_at, deps_trans, deps_external_trans, True, [] + ) json = { - 'proto_deps': deps, - 'proto_transitive_deps': deps_trans, - 'proto_external_deps': deps_external, - 'proto_transitive_external_deps': deps_external_trans + "proto_deps": deps, + "proto_transitive_deps": deps_trans, + "proto_external_deps": deps_external, + "proto_transitive_external_deps": deps_external_trans, } print(yaml.dump(json)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/src/python/grpcio/_parallel_compile_patch.py b/src/python/grpcio/_parallel_compile_patch.py index e4d50c38311d6..4adc3630b5a45 100644 --- a/src/python/grpcio/_parallel_compile_patch.py +++ b/src/python/grpcio/_parallel_compile_patch.py @@ -22,28 +22,33 @@ try: BUILD_EXT_COMPILER_JOBS = int( - os.environ['GRPC_PYTHON_BUILD_EXT_COMPILER_JOBS']) + os.environ["GRPC_PYTHON_BUILD_EXT_COMPILER_JOBS"] + ) except KeyError: import multiprocessing + BUILD_EXT_COMPILER_JOBS = multiprocessing.cpu_count() except ValueError: BUILD_EXT_COMPILER_JOBS = 1 # monkey-patch for parallel compilation -def _parallel_compile(self, - sources, - output_dir=None, - macros=None, - include_dirs=None, - debug=0, - extra_preargs=None, - extra_postargs=None, - depends=None): +def _parallel_compile( + self, + sources, + output_dir=None, + macros=None, + include_dirs=None, + debug=0, + extra_preargs=None, + extra_postargs=None, + depends=None, +): # setup the same way as distutils.ccompiler.CCompiler # https://github.com/python/cpython/blob/31368a4f0e531c19affe2a1becd25fc316bc7501/Lib/distutils/ccompiler.py#L564 macros, objects, extra_postargs, pp_opts, build = self._setup_compile( - str(output_dir), macros, include_dirs, sources, depends, extra_postargs) + str(output_dir), macros, include_dirs, sources, depends, extra_postargs + ) cc_args = self._get_cc_args(pp_opts, debug, extra_preargs) def _compile_single_file(obj): @@ -55,8 +60,10 @@ def _compile_single_file(obj): # run compilation of individual files in parallel import multiprocessing.pool + multiprocessing.pool.ThreadPool(BUILD_EXT_COMPILER_JOBS).map( - _compile_single_file, objects) + _compile_single_file, objects + ) return objects diff --git 
a/src/python/grpcio/_spawn_patch.py b/src/python/grpcio/_spawn_patch.py index 421c7dfee83b1..e5e1f1ef85c7b 100644 --- a/src/python/grpcio/_spawn_patch.py +++ b/src/python/grpcio/_spawn_patch.py @@ -32,23 +32,24 @@ def _commandfile_spawn(self, command): command_length = sum([len(arg) for arg in command]) - if os.name == 'nt' and command_length > MAX_COMMAND_LENGTH: + if os.name == "nt" and command_length > MAX_COMMAND_LENGTH: # Even if this command doesn't support the @command_file, it will # fail as is so we try blindly - print('Command line length exceeded, using command file') - print(' '.join(command)) + print("Command line length exceeded, using command file") + print(" ".join(command)) temporary_directory = tempfile.mkdtemp() command_filename = os.path.abspath( - os.path.join(temporary_directory, 'command')) - with open(command_filename, 'w') as command_file: + os.path.join(temporary_directory, "command") + ) + with open(command_filename, "w") as command_file: escaped_args = [ - '"' + arg.replace('\\', '\\\\') + '"' for arg in command[1:] + '"' + arg.replace("\\", "\\\\") + '"' for arg in command[1:] ] # add each arg on a separate line to avoid hitting the # "line in command file contains 131071 or more characters" error # (can happen for extra long link commands) - command_file.write(' \n'.join(escaped_args)) - modified_command = command[:1] + ['@{}'.format(command_filename)] + command_file.write(" \n".join(escaped_args)) + modified_command = command[:1] + ["@{}".format(command_filename)] try: _classic_spawn(self, modified_command) finally: @@ -59,5 +60,5 @@ def _commandfile_spawn(self, command): def monkeypatch_spawn(): """Monkeypatching is dumb, but it's either that or we become maintainers of - something much, much bigger.""" + something much, much bigger.""" ccompiler.CCompiler.spawn = _commandfile_spawn diff --git a/src/python/grpcio/commands.py b/src/python/grpcio/commands.py index 6d8228ffa052a..3739097a745ec 100644 --- a/src/python/grpcio/commands.py +++ b/src/python/grpcio/commands.py @@ -31,10 +31,10 @@ import support PYTHON_STEM = os.path.dirname(os.path.abspath(__file__)) -GRPC_STEM = os.path.abspath(PYTHON_STEM + '../../../../') -PROTO_STEM = os.path.join(GRPC_STEM, 'src', 'proto') -PROTO_GEN_STEM = os.path.join(GRPC_STEM, 'src', 'python', 'gens') -CYTHON_STEM = os.path.join(PYTHON_STEM, 'grpc', '_cython') +GRPC_STEM = os.path.abspath(PYTHON_STEM + "../../../../") +PROTO_STEM = os.path.join(GRPC_STEM, "src", "proto") +PROTO_GEN_STEM = os.path.join(GRPC_STEM, "src", "python", "gens") +CYTHON_STEM = os.path.join(PYTHON_STEM, "grpc", "_cython") class CommandError(Exception): @@ -46,9 +46,9 @@ class CommandError(Exception): def _get_grpc_custom_bdist(decorated_basename, target_bdist_basename): """Returns a string path to a bdist file for Linux to install. - If we can retrieve a pre-compiled bdist from online, uses it. Else, emits a - warning and builds from source. - """ + If we can retrieve a pre-compiled bdist from online, uses it. Else, emits a + warning and builds from source. + """ # TODO(atash): somehow the name that's returned from `wheel` is different # between different versions of 'wheel' (but from a compatibility standpoint, # the names are compatible); we should have some way of determining name @@ -58,28 +58,35 @@ def _get_grpc_custom_bdist(decorated_basename, target_bdist_basename): # Break import style to ensure that setup.py has had a chance to install the # relevant package. 
from urllib import request + decorated_path = decorated_basename + GRPC_CUSTOM_BDIST_EXT try: - url = BINARIES_REPOSITORY + '/{target}'.format(target=decorated_path) + url = BINARIES_REPOSITORY + "/{target}".format(target=decorated_path) bdist_data = request.urlopen(url).read() except IOError as error: - raise CommandError('{}\n\nCould not find the bdist {}: {}'.format( - traceback.format_exc(), decorated_path, error.message)) + raise CommandError( + "{}\n\nCould not find the bdist {}: {}".format( + traceback.format_exc(), decorated_path, error.message + ) + ) # Our chosen local bdist path. bdist_path = target_bdist_basename + GRPC_CUSTOM_BDIST_EXT try: - with open(bdist_path, 'w') as bdist_file: + with open(bdist_path, "w") as bdist_file: bdist_file.write(bdist_data) except IOError as error: - raise CommandError('{}\n\nCould not write grpcio bdist: {}'.format( - traceback.format_exc(), error.message)) + raise CommandError( + "{}\n\nCould not write grpcio bdist: {}".format( + traceback.format_exc(), error.message + ) + ) return bdist_path class SphinxDocumentation(setuptools.Command): """Command to generate documentation via sphinx.""" - description = 'generate sphinx documentation' + description = "generate sphinx documentation" user_options = [] def initialize_options(self): @@ -92,19 +99,22 @@ def run(self): # We import here to ensure that setup.py has had a chance to install the # relevant package eggs first. import sphinx.cmd.build - source_dir = os.path.join(GRPC_STEM, 'doc', 'python', 'sphinx') - target_dir = os.path.join(GRPC_STEM, 'doc', 'build') + + source_dir = os.path.join(GRPC_STEM, "doc", "python", "sphinx") + target_dir = os.path.join(GRPC_STEM, "doc", "build") exit_code = sphinx.cmd.build.build_main( - ['-b', 'html', '-W', '--keep-going', source_dir, target_dir]) + ["-b", "html", "-W", "--keep-going", source_dir, target_dir] + ) if exit_code != 0: raise CommandError( - "Documentation generation has warnings or errors") + "Documentation generation has warnings or errors" + ) class BuildProjectMetadata(setuptools.Command): """Command to generate project metadata in a module.""" - description = 'build grpcio project metadata files' + description = "build grpcio project metadata files" user_options = [] def initialize_options(self): @@ -114,95 +124,112 @@ def finalize_options(self): pass def run(self): - with open(os.path.join(PYTHON_STEM, 'grpc/_grpcio_metadata.py'), - 'w') as module_file: - module_file.write('__version__ = """{}"""'.format( - self.distribution.get_version())) + with open( + os.path.join(PYTHON_STEM, "grpc/_grpcio_metadata.py"), "w" + ) as module_file: + module_file.write( + '__version__ = """{}"""'.format(self.distribution.get_version()) + ) class BuildPy(build_py.build_py): """Custom project build command.""" def run(self): - self.run_command('build_project_metadata') + self.run_command("build_project_metadata") build_py.build_py.run(self) def _poison_extensions(extensions, message): """Includes a file that will always fail to compile in all extensions.""" - poison_filename = os.path.join(PYTHON_STEM, 'poison.c') - with open(poison_filename, 'w') as poison: - poison.write('#error {}'.format(message)) + poison_filename = os.path.join(PYTHON_STEM, "poison.c") + with open(poison_filename, "w") as poison: + poison.write("#error {}".format(message)) for extension in extensions: extension.sources = [poison_filename] def check_and_update_cythonization(extensions): """Replace .pyx files with their generated counterparts and return whether or - not cythonization 
still needs to occur.""" + not cythonization still needs to occur.""" for extension in extensions: generated_pyx_sources = [] other_sources = [] for source in extension.sources: base, file_ext = os.path.splitext(source) - if file_ext == '.pyx': - generated_pyx_source = next((base + gen_ext for gen_ext in ( - '.c', - '.cpp', - ) if os.path.isfile(base + gen_ext)), None) + if file_ext == ".pyx": + generated_pyx_source = next( + ( + base + gen_ext + for gen_ext in ( + ".c", + ".cpp", + ) + if os.path.isfile(base + gen_ext) + ), + None, + ) if generated_pyx_source: generated_pyx_sources.append(generated_pyx_source) else: - sys.stderr.write('Cython-generated files are missing...\n') + sys.stderr.write("Cython-generated files are missing...\n") return False else: other_sources.append(source) extension.sources = generated_pyx_sources + other_sources - sys.stderr.write('Found cython-generated files...\n') + sys.stderr.write("Found cython-generated files...\n") return True def try_cythonize(extensions, linetracing=False, mandatory=True): """Attempt to cythonize the extensions. - Args: - extensions: A list of `distutils.extension.Extension`. - linetracing: A bool indicating whether or not to enable linetracing. - mandatory: Whether or not having Cython-generated files is mandatory. If it - is, extensions will be poisoned when they can't be fully generated. - """ + Args: + extensions: A list of `distutils.extension.Extension`. + linetracing: A bool indicating whether or not to enable linetracing. + mandatory: Whether or not having Cython-generated files is mandatory. If it + is, extensions will be poisoned when they can't be fully generated. + """ try: # Break import style to ensure we have access to Cython post-setup_requires import Cython.Build except ImportError: if mandatory: sys.stderr.write( - "This package needs to generate C files with Cython but it cannot. " - "Poisoning extension sources to disallow extension commands...") + "This package needs to generate C files with Cython but it" + " cannot. Poisoning extension sources to disallow extension" + " commands..." + ) _poison_extensions( extensions, - "Extensions have been poisoned due to missing Cython-generated code." + ( + "Extensions have been poisoned due to missing" + " Cython-generated code." + ), ) return extensions cython_compiler_directives = {} if linetracing: - additional_define_macros = [('CYTHON_TRACE_NOGIL', '1')] - cython_compiler_directives['linetrace'] = True + additional_define_macros = [("CYTHON_TRACE_NOGIL", "1")] + cython_compiler_directives["linetrace"] = True return Cython.Build.cythonize( extensions, include_path=[ - include_dir for extension in extensions + include_dir + for extension in extensions for include_dir in extension.include_dirs - ] + [CYTHON_STEM], - compiler_directives=cython_compiler_directives) + ] + + [CYTHON_STEM], + compiler_directives=cython_compiler_directives, + ) class BuildExt(build_ext.build_ext): """Custom build_ext command to enable compiler-specific flags.""" C_OPTIONS = { - 'unix': ('-pthread',), - 'msvc': (), + "unix": ("-pthread",), + "msvc": (), } LINK_OPTIONS = {} @@ -214,30 +241,32 @@ def get_ext_filename(self, ext_name): # so that the resulting file name matches the target architecture and we end up with a well-formed # wheel. 
filename = build_ext.build_ext.get_ext_filename(self, ext_name) - orig_ext_suffix = sysconfig.get_config_var('EXT_SUFFIX') - new_ext_suffix = os.getenv('GRPC_PYTHON_OVERRIDE_EXT_SUFFIX') + orig_ext_suffix = sysconfig.get_config_var("EXT_SUFFIX") + new_ext_suffix = os.getenv("GRPC_PYTHON_OVERRIDE_EXT_SUFFIX") if new_ext_suffix and filename.endswith(orig_ext_suffix): - filename = filename[:-len(orig_ext_suffix)] + new_ext_suffix + filename = filename[: -len(orig_ext_suffix)] + new_ext_suffix return filename def build_extensions(self): - def compiler_ok_with_extra_std(): """Test if default compiler is okay with specifying c++ version when invoked in C mode. GCC is okay with this, while clang is not. """ try: # TODO(lidiz) Remove the generated a.out for success tests. - cc = os.environ.get('CC', 'cc') - cc_test = subprocess.Popen([cc, '-x', 'c', '-std=c++14', '-'], - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - _, cc_err = cc_test.communicate(input=b'int main(){return 0;}') - return not 'invalid argument' in str(cc_err) + cc = os.environ.get("CC", "cc") + cc_test = subprocess.Popen( + [cc, "-x", "c", "-std=c++14", "-"], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + _, cc_err = cc_test.communicate(input=b"int main(){return 0;}") + return not "invalid argument" in str(cc_err) except: - sys.stderr.write('Non-fatal exception:' + - traceback.format_exc() + '\n') + sys.stderr.write( + "Non-fatal exception:" + traceback.format_exc() + "\n" + ) return False # This special conditioning is here due to difference of compiler @@ -253,16 +282,17 @@ def compiler_ok_with_extra_std(): old_compile = self.compiler._compile def new_compile(obj, src, ext, cc_args, extra_postargs, pp_opts): - if src.endswith('.c'): + if src.endswith(".c"): extra_postargs = [ - arg for arg in extra_postargs if not '-std=c++' in arg + arg for arg in extra_postargs if not "-std=c++" in arg ] - elif src.endswith('.cc') or src.endswith('.cpp'): + elif src.endswith(".cc") or src.endswith(".cpp"): extra_postargs = [ - arg for arg in extra_postargs if not '-std=gnu99' in arg + arg for arg in extra_postargs if not "-std=gnu99" in arg ] - return old_compile(obj, src, ext, cc_args, extra_postargs, - pp_opts) + return old_compile( + obj, src, ext, cc_args, extra_postargs, pp_opts + ) self.compiler._compile = new_compile @@ -270,11 +300,13 @@ def new_compile(obj, src, ext, cc_args, extra_postargs, pp_opts): if compiler in BuildExt.C_OPTIONS: for extension in self.extensions: extension.extra_compile_args += list( - BuildExt.C_OPTIONS[compiler]) + BuildExt.C_OPTIONS[compiler] + ) if compiler in BuildExt.LINK_OPTIONS: for extension in self.extensions: extension.extra_link_args += list( - BuildExt.LINK_OPTIONS[compiler]) + BuildExt.LINK_OPTIONS[compiler] + ) if not check_and_update_cythonization(self.extensions): self.extensions = try_cythonize(self.extensions) try: @@ -283,16 +315,17 @@ def new_compile(obj, src, ext, cc_args, extra_postargs, pp_opts): formatted_exception = traceback.format_exc() support.diagnose_build_ext_error(self, error, formatted_exception) raise CommandError( - "Failed `build_ext` step:\n{}".format(formatted_exception)) + "Failed `build_ext` step:\n{}".format(formatted_exception) + ) class Gather(setuptools.Command): """Command to gather project dependencies.""" - description = 'gather dependencies for grpcio' + description = "gather dependencies for grpcio" user_options = [ - ('test', 't', 'flag indicating to gather test dependencies'), - ('install', 'i', 
'flag indicating to gather install dependencies') + ("test", "t", "flag indicating to gather test dependencies"), + ("install", "i", "flag indicating to gather install dependencies"), ] def initialize_options(self): @@ -306,7 +339,8 @@ def finalize_options(self): def run(self): if self.install and self.distribution.install_requires: self.distribution.fetch_build_eggs( - self.distribution.install_requires) + self.distribution.install_requires + ) if self.test and self.distribution.tests_require: self.distribution.fetch_build_eggs(self.distribution.tests_require) @@ -314,20 +348,21 @@ def run(self): class Clean(setuptools.Command): """Command to clean build artifacts.""" - description = 'Clean build artifacts.' + description = "Clean build artifacts." user_options = [ - ('all', 'a', 'a phony flag to allow our script to continue'), + ("all", "a", "a phony flag to allow our script to continue"), ] _FILE_PATTERNS = ( - 'python_build', - 'src/python/grpcio/__pycache__/', - 'src/python/grpcio/grpc/_cython/cygrpc.cpp', - 'src/python/grpcio/grpc/_cython/*.so', - 'src/python/grpcio/grpcio.egg-info/', + "python_build", + "src/python/grpcio/__pycache__/", + "src/python/grpcio/grpc/_cython/cygrpc.cpp", + "src/python/grpcio/grpc/_cython/*.so", + "src/python/grpcio/grpcio.egg-info/", ) _CURRENT_DIRECTORY = os.path.normpath( - os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../..")) + os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../..") + ) def initialize_options(self): self.all = False @@ -338,12 +373,14 @@ def finalize_options(self): def run(self): for path_spec in self._FILE_PATTERNS: this_glob = os.path.normpath( - os.path.join(Clean._CURRENT_DIRECTORY, path_spec)) + os.path.join(Clean._CURRENT_DIRECTORY, path_spec) + ) abs_paths = glob.glob(this_glob) for path in abs_paths: if not str(path).startswith(Clean._CURRENT_DIRECTORY): raise ValueError( - "Cowardly refusing to delete {}.".format(path)) + "Cowardly refusing to delete {}.".format(path) + ) print("Removing {}".format(os.path.relpath(path))) if os.path.isfile(path): os.remove(str(path)) diff --git a/src/python/grpcio/grpc/__init__.py b/src/python/grpcio/grpc/__init__.py index ce7446dc90574..afbf861619ba6 100644 --- a/src/python/grpcio/grpc/__init__.py +++ b/src/python/grpcio/grpc/__init__.py @@ -219,12 +219,15 @@ class ChannelConnectivity(enum.Enum): to recover. SHUTDOWN: The channel has seen a failure from which it cannot recover. """ - IDLE = (_cygrpc.ConnectivityState.idle, 'idle') - CONNECTING = (_cygrpc.ConnectivityState.connecting, 'connecting') - READY = (_cygrpc.ConnectivityState.ready, 'ready') - TRANSIENT_FAILURE = (_cygrpc.ConnectivityState.transient_failure, - 'transient failure') - SHUTDOWN = (_cygrpc.ConnectivityState.shutdown, 'shutdown') + + IDLE = (_cygrpc.ConnectivityState.idle, "idle") + CONNECTING = (_cygrpc.ConnectivityState.connecting, "connecting") + READY = (_cygrpc.ConnectivityState.ready, "ready") + TRANSIENT_FAILURE = ( + _cygrpc.ConnectivityState.transient_failure, + "transient failure", + ) + SHUTDOWN = (_cygrpc.ConnectivityState.shutdown, "shutdown") @enum.unique @@ -256,27 +259,36 @@ class StatusCode(enum.Enum): UNAVAILABLE: The service is currently unavailable. DATA_LOSS: Unrecoverable data loss or corruption. 
""" - OK = (_cygrpc.StatusCode.ok, 'ok') - CANCELLED = (_cygrpc.StatusCode.cancelled, 'cancelled') - UNKNOWN = (_cygrpc.StatusCode.unknown, 'unknown') - INVALID_ARGUMENT = (_cygrpc.StatusCode.invalid_argument, 'invalid argument') - DEADLINE_EXCEEDED = (_cygrpc.StatusCode.deadline_exceeded, - 'deadline exceeded') - NOT_FOUND = (_cygrpc.StatusCode.not_found, 'not found') - ALREADY_EXISTS = (_cygrpc.StatusCode.already_exists, 'already exists') - PERMISSION_DENIED = (_cygrpc.StatusCode.permission_denied, - 'permission denied') - RESOURCE_EXHAUSTED = (_cygrpc.StatusCode.resource_exhausted, - 'resource exhausted') - FAILED_PRECONDITION = (_cygrpc.StatusCode.failed_precondition, - 'failed precondition') - ABORTED = (_cygrpc.StatusCode.aborted, 'aborted') - OUT_OF_RANGE = (_cygrpc.StatusCode.out_of_range, 'out of range') - UNIMPLEMENTED = (_cygrpc.StatusCode.unimplemented, 'unimplemented') - INTERNAL = (_cygrpc.StatusCode.internal, 'internal') - UNAVAILABLE = (_cygrpc.StatusCode.unavailable, 'unavailable') - DATA_LOSS = (_cygrpc.StatusCode.data_loss, 'data loss') - UNAUTHENTICATED = (_cygrpc.StatusCode.unauthenticated, 'unauthenticated') + + OK = (_cygrpc.StatusCode.ok, "ok") + CANCELLED = (_cygrpc.StatusCode.cancelled, "cancelled") + UNKNOWN = (_cygrpc.StatusCode.unknown, "unknown") + INVALID_ARGUMENT = (_cygrpc.StatusCode.invalid_argument, "invalid argument") + DEADLINE_EXCEEDED = ( + _cygrpc.StatusCode.deadline_exceeded, + "deadline exceeded", + ) + NOT_FOUND = (_cygrpc.StatusCode.not_found, "not found") + ALREADY_EXISTS = (_cygrpc.StatusCode.already_exists, "already exists") + PERMISSION_DENIED = ( + _cygrpc.StatusCode.permission_denied, + "permission denied", + ) + RESOURCE_EXHAUSTED = ( + _cygrpc.StatusCode.resource_exhausted, + "resource exhausted", + ) + FAILED_PRECONDITION = ( + _cygrpc.StatusCode.failed_precondition, + "failed precondition", + ) + ABORTED = (_cygrpc.StatusCode.aborted, "aborted") + OUT_OF_RANGE = (_cygrpc.StatusCode.out_of_range, "out of range") + UNIMPLEMENTED = (_cygrpc.StatusCode.unimplemented, "unimplemented") + INTERNAL = (_cygrpc.StatusCode.internal, "internal") + UNAVAILABLE = (_cygrpc.StatusCode.unavailable, "unavailable") + DATA_LOSS = (_cygrpc.StatusCode.data_loss, "data loss") + UNAUTHENTICATED = (_cygrpc.StatusCode.unauthenticated, "unauthenticated") ############################# gRPC Status ################################ @@ -459,8 +471,9 @@ class UnaryStreamClientInterceptor(abc.ABC): """Affords intercepting unary-stream invocations.""" @abc.abstractmethod - def intercept_unary_stream(self, continuation, client_call_details, - request): + def intercept_unary_stream( + self, continuation, client_call_details, request + ): """Intercepts a unary-stream invocation. Args: @@ -493,8 +506,9 @@ class StreamUnaryClientInterceptor(abc.ABC): """Affords intercepting stream-unary invocations.""" @abc.abstractmethod - def intercept_stream_unary(self, continuation, client_call_details, - request_iterator): + def intercept_stream_unary( + self, continuation, client_call_details, request_iterator + ): """Intercepts a stream-unary invocation asynchronously. Args: @@ -527,8 +541,9 @@ class StreamStreamClientInterceptor(abc.ABC): """Affords intercepting stream-stream invocations.""" @abc.abstractmethod - def intercept_stream_stream(self, continuation, client_call_details, - request_iterator): + def intercept_stream_stream( + self, continuation, client_call_details, request_iterator + ): """Intercepts a stream-stream invocation. 
Args: @@ -662,13 +677,15 @@ class UnaryUnaryMultiCallable(abc.ABC): """Affords invoking a unary-unary RPC from client-side.""" @abc.abstractmethod - def __call__(self, - request, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None): + def __call__( + self, + request, + timeout=None, + metadata=None, + credentials=None, + wait_for_ready=None, + compression=None, + ): """Synchronously invokes the underlying RPC. Args: @@ -694,13 +711,15 @@ def __call__(self, raise NotImplementedError() @abc.abstractmethod - def with_call(self, - request, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None): + def with_call( + self, + request, + timeout=None, + metadata=None, + credentials=None, + wait_for_ready=None, + compression=None, + ): """Synchronously invokes the underlying RPC. Args: @@ -726,13 +745,15 @@ def with_call(self, raise NotImplementedError() @abc.abstractmethod - def future(self, - request, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None): + def future( + self, + request, + timeout=None, + metadata=None, + credentials=None, + wait_for_ready=None, + compression=None, + ): """Asynchronously invokes the underlying RPC. Args: @@ -761,13 +782,15 @@ class UnaryStreamMultiCallable(abc.ABC): """Affords invoking a unary-stream RPC from client-side.""" @abc.abstractmethod - def __call__(self, - request, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None): + def __call__( + self, + request, + timeout=None, + metadata=None, + credentials=None, + wait_for_ready=None, + compression=None, + ): """Invokes the underlying RPC. Args: @@ -795,13 +818,15 @@ class StreamUnaryMultiCallable(abc.ABC): """Affords invoking a stream-unary RPC from client-side.""" @abc.abstractmethod - def __call__(self, - request_iterator, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None): + def __call__( + self, + request_iterator, + timeout=None, + metadata=None, + credentials=None, + wait_for_ready=None, + compression=None, + ): """Synchronously invokes the underlying RPC. Args: @@ -828,13 +853,15 @@ def __call__(self, raise NotImplementedError() @abc.abstractmethod - def with_call(self, - request_iterator, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None): + def with_call( + self, + request_iterator, + timeout=None, + metadata=None, + credentials=None, + wait_for_ready=None, + compression=None, + ): """Synchronously invokes the underlying RPC on the client. Args: @@ -861,13 +888,15 @@ def with_call(self, raise NotImplementedError() @abc.abstractmethod - def future(self, - request_iterator, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None): + def future( + self, + request_iterator, + timeout=None, + metadata=None, + credentials=None, + wait_for_ready=None, + compression=None, + ): """Asynchronously invokes the underlying RPC on the client. Args: @@ -896,13 +925,15 @@ class StreamStreamMultiCallable(abc.ABC): """Affords invoking a stream-stream RPC on client-side.""" @abc.abstractmethod - def __call__(self, - request_iterator, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None): + def __call__( + self, + request_iterator, + timeout=None, + metadata=None, + credentials=None, + wait_for_ready=None, + compression=None, + ): """Invokes the underlying RPC on the client. 
Args: @@ -968,10 +999,9 @@ def unsubscribe(self, callback): raise NotImplementedError() @abc.abstractmethod - def unary_unary(self, - method, - request_serializer=None, - response_deserializer=None): + def unary_unary( + self, method, request_serializer=None, response_deserializer=None + ): """Creates a UnaryUnaryMultiCallable for a unary-unary method. Args: @@ -988,10 +1018,9 @@ def unary_unary(self, raise NotImplementedError() @abc.abstractmethod - def unary_stream(self, - method, - request_serializer=None, - response_deserializer=None): + def unary_stream( + self, method, request_serializer=None, response_deserializer=None + ): """Creates a UnaryStreamMultiCallable for a unary-stream method. Args: @@ -1008,10 +1037,9 @@ def unary_stream(self, raise NotImplementedError() @abc.abstractmethod - def stream_unary(self, - method, - request_serializer=None, - response_deserializer=None): + def stream_unary( + self, method, request_serializer=None, response_deserializer=None + ): """Creates a StreamUnaryMultiCallable for a stream-unary method. Args: @@ -1028,10 +1056,9 @@ def stream_unary(self, raise NotImplementedError() @abc.abstractmethod - def stream_stream(self, - method, - request_serializer=None, - response_deserializer=None): + def stream_stream( + self, method, request_serializer=None, response_deserializer=None + ): """Creates a StreamStreamMultiCallable for a stream-stream method. Args: @@ -1493,9 +1520,9 @@ def wait_for_termination(self, timeout=None): ################################# Functions ################################ -def unary_unary_rpc_method_handler(behavior, - request_deserializer=None, - response_serializer=None): +def unary_unary_rpc_method_handler( + behavior, request_deserializer=None, response_serializer=None +): """Creates an RpcMethodHandler for a unary-unary RPC method. Args: @@ -1508,14 +1535,22 @@ def unary_unary_rpc_method_handler(behavior, An RpcMethodHandler object that is typically used by grpc.Server. """ from grpc import _utilities # pylint: disable=cyclic-import - return _utilities.RpcMethodHandler(False, False, request_deserializer, - response_serializer, behavior, None, - None, None) - -def unary_stream_rpc_method_handler(behavior, - request_deserializer=None, - response_serializer=None): + return _utilities.RpcMethodHandler( + False, + False, + request_deserializer, + response_serializer, + behavior, + None, + None, + None, + ) + + +def unary_stream_rpc_method_handler( + behavior, request_deserializer=None, response_serializer=None +): """Creates an RpcMethodHandler for a unary-stream RPC method. Args: @@ -1528,14 +1563,22 @@ def unary_stream_rpc_method_handler(behavior, An RpcMethodHandler object that is typically used by grpc.Server. """ from grpc import _utilities # pylint: disable=cyclic-import - return _utilities.RpcMethodHandler(False, True, request_deserializer, - response_serializer, None, behavior, - None, None) - -def stream_unary_rpc_method_handler(behavior, - request_deserializer=None, - response_serializer=None): + return _utilities.RpcMethodHandler( + False, + True, + request_deserializer, + response_serializer, + None, + behavior, + None, + None, + ) + + +def stream_unary_rpc_method_handler( + behavior, request_deserializer=None, response_serializer=None +): """Creates an RpcMethodHandler for a stream-unary RPC method. Args: @@ -1548,14 +1591,22 @@ def stream_unary_rpc_method_handler(behavior, An RpcMethodHandler object that is typically used by grpc.Server. 
""" from grpc import _utilities # pylint: disable=cyclic-import - return _utilities.RpcMethodHandler(True, False, request_deserializer, - response_serializer, None, None, - behavior, None) - -def stream_stream_rpc_method_handler(behavior, - request_deserializer=None, - response_serializer=None): + return _utilities.RpcMethodHandler( + True, + False, + request_deserializer, + response_serializer, + None, + None, + behavior, + None, + ) + + +def stream_stream_rpc_method_handler( + behavior, request_deserializer=None, response_serializer=None +): """Creates an RpcMethodHandler for a stream-stream RPC method. Args: @@ -1568,9 +1619,17 @@ def stream_stream_rpc_method_handler(behavior, An RpcMethodHandler object that is typically used by grpc.Server. """ from grpc import _utilities # pylint: disable=cyclic-import - return _utilities.RpcMethodHandler(True, True, request_deserializer, - response_serializer, None, None, None, - behavior) + + return _utilities.RpcMethodHandler( + True, + True, + request_deserializer, + response_serializer, + None, + None, + None, + behavior, + ) def method_handlers_generic_handler(service, method_handlers): @@ -1587,12 +1646,13 @@ def method_handlers_generic_handler(service, method_handlers): with add_generic_rpc_handlers() before starting the server. """ from grpc import _utilities # pylint: disable=cyclic-import + return _utilities.DictionaryGenericHandler(service, method_handlers) -def ssl_channel_credentials(root_certificates=None, - private_key=None, - certificate_chain=None): +def ssl_channel_credentials( + root_certificates=None, private_key=None, certificate_chain=None +): """Creates a ChannelCredentials for use with an SSL-enabled Channel. Args: @@ -1608,8 +1668,10 @@ def ssl_channel_credentials(root_certificates=None, A ChannelCredentials for use with an SSL-enabled Channel. """ return ChannelCredentials( - _cygrpc.SSLChannelCredentials(root_certificates, private_key, - certificate_chain)) + _cygrpc.SSLChannelCredentials( + root_certificates, private_key, certificate_chain + ) + ) def xds_channel_credentials(fallback_credentials=None): @@ -1621,10 +1683,14 @@ def xds_channel_credentials(fallback_credentials=None): establish a secure connection via xDS. If no fallback_credentials argument is supplied, a default SSLChannelCredentials is used. """ - fallback_credentials = ssl_channel_credentials( - ) if fallback_credentials is None else fallback_credentials + fallback_credentials = ( + ssl_channel_credentials() + if fallback_credentials is None + else fallback_credentials + ) return ChannelCredentials( - _cygrpc.XDSChannelCredentials(fallback_credentials._credentials)) + _cygrpc.XDSChannelCredentials(fallback_credentials._credentials) + ) def metadata_call_credentials(metadata_plugin, name=None): @@ -1638,8 +1704,10 @@ def metadata_call_credentials(metadata_plugin, name=None): A CallCredentials. 
""" from grpc import _plugin_wrapping # pylint: disable=cyclic-import + return _plugin_wrapping.metadata_plugin_call_credentials( - metadata_plugin, name) + metadata_plugin, name + ) def access_token_call_credentials(access_token): @@ -1655,8 +1723,10 @@ def access_token_call_credentials(access_token): """ from grpc import _auth # pylint: disable=cyclic-import from grpc import _plugin_wrapping # pylint: disable=cyclic-import + return _plugin_wrapping.metadata_plugin_call_credentials( - _auth.AccessTokenAuthMetadataPlugin(access_token), None) + _auth.AccessTokenAuthMetadataPlugin(access_token), None + ) def composite_call_credentials(*call_credentials): @@ -1670,8 +1740,12 @@ def composite_call_credentials(*call_credentials): """ return CallCredentials( _cygrpc.CompositeCallCredentials( - tuple(single_call_credentials._credentials - for single_call_credentials in call_credentials))) + tuple( + single_call_credentials._credentials + for single_call_credentials in call_credentials + ) + ) + ) def composite_channel_credentials(channel_credentials, *call_credentials): @@ -1687,14 +1761,20 @@ def composite_channel_credentials(channel_credentials, *call_credentials): """ return ChannelCredentials( _cygrpc.CompositeChannelCredentials( - tuple(single_call_credentials._credentials - for single_call_credentials in call_credentials), - channel_credentials._credentials)) + tuple( + single_call_credentials._credentials + for single_call_credentials in call_credentials + ), + channel_credentials._credentials, + ) + ) -def ssl_server_credentials(private_key_certificate_chain_pairs, - root_certificates=None, - require_client_auth=False): +def ssl_server_credentials( + private_key_certificate_chain_pairs, + root_certificates=None, + require_client_auth=False, +): """Creates a ServerCredentials for use with an SSL-enabled Server. Args: @@ -1713,17 +1793,24 @@ def ssl_server_credentials(private_key_certificate_chain_pairs, """ if not private_key_certificate_chain_pairs: raise ValueError( - 'At least one private key-certificate chain pair is required!') + "At least one private key-certificate chain pair is required!" + ) elif require_client_auth and root_certificates is None: raise ValueError( - 'Illegal to require client auth without providing root certificates!' + "Illegal to require client auth without providing root" + " certificates!" ) else: return ServerCredentials( - _cygrpc.server_credentials_ssl(root_certificates, [ - _cygrpc.SslPemKeyCertPair(key, pem) - for key, pem in private_key_certificate_chain_pairs - ], require_client_auth)) + _cygrpc.server_credentials_ssl( + root_certificates, + [ + _cygrpc.SslPemKeyCertPair(key, pem) + for key, pem in private_key_certificate_chain_pairs + ], + require_client_auth, + ) + ) def xds_server_credentials(fallback_credentials): @@ -1735,7 +1822,8 @@ def xds_server_credentials(fallback_credentials): establish a secure connection via xDS. No default value is provided. """ return ServerCredentials( - _cygrpc.xds_server_credentials(fallback_credentials._credentials)) + _cygrpc.xds_server_credentials(fallback_credentials._credentials) + ) def insecure_server_credentials(): @@ -1749,8 +1837,9 @@ def insecure_server_credentials(): return ServerCredentials(_cygrpc.insecure_server_credentials()) -def ssl_server_certificate_configuration(private_key_certificate_chain_pairs, - root_certificates=None): +def ssl_server_certificate_configuration( + private_key_certificate_chain_pairs, root_certificates=None +): """Creates a ServerCertificateConfiguration for use with a Server. 
Args: @@ -1766,18 +1855,25 @@ def ssl_server_certificate_configuration(private_key_certificate_chain_pairs, """ if private_key_certificate_chain_pairs: return ServerCertificateConfiguration( - _cygrpc.server_certificate_config_ssl(root_certificates, [ - _cygrpc.SslPemKeyCertPair(key, pem) - for key, pem in private_key_certificate_chain_pairs - ])) + _cygrpc.server_certificate_config_ssl( + root_certificates, + [ + _cygrpc.SslPemKeyCertPair(key, pem) + for key, pem in private_key_certificate_chain_pairs + ], + ) + ) else: raise ValueError( - 'At least one private key-certificate chain pair is required!') + "At least one private key-certificate chain pair is required!" + ) -def dynamic_ssl_server_credentials(initial_certificate_configuration, - certificate_configuration_fetcher, - require_client_authentication=False): +def dynamic_ssl_server_credentials( + initial_certificate_configuration, + certificate_configuration_fetcher, + require_client_authentication=False, +): """Creates a ServerCredentials for use with an SSL-enabled Server. Args: @@ -1801,7 +1897,10 @@ def dynamic_ssl_server_credentials(initial_certificate_configuration, return ServerCredentials( _cygrpc.server_credentials_ssl_dynamic_cert_config( initial_certificate_configuration, - certificate_configuration_fetcher, require_client_authentication)) + certificate_configuration_fetcher, + require_client_authentication, + ) + ) @enum.unique @@ -1812,6 +1911,7 @@ class LocalConnectionType(enum.Enum): UDS: Unix domain socket connections LOCAL_TCP: Local TCP connections. """ + UDS = _cygrpc.LocalConnectionType.uds LOCAL_TCP = _cygrpc.LocalConnectionType.local_tcp @@ -1843,7 +1943,8 @@ def local_channel_credentials(local_connect_type=LocalConnectionType.LOCAL_TCP): A ChannelCredentials for use with a local Channel """ return ChannelCredentials( - _cygrpc.channel_credentials_local(local_connect_type.value)) + _cygrpc.channel_credentials_local(local_connect_type.value) + ) def local_server_credentials(local_connect_type=LocalConnectionType.LOCAL_TCP): @@ -1873,7 +1974,8 @@ def local_server_credentials(local_connect_type=LocalConnectionType.LOCAL_TCP): A ServerCredentials for use with a local Server """ return ServerCredentials( - _cygrpc.server_credentials_local(local_connect_type.value)) + _cygrpc.server_credentials_local(local_connect_type.value) + ) def alts_channel_credentials(service_accounts=None): @@ -1894,7 +1996,8 @@ def alts_channel_credentials(service_accounts=None): A ChannelCredentials for use with an ALTS-enabled Channel """ return ChannelCredentials( - _cygrpc.channel_credentials_alts(service_accounts or [])) + _cygrpc.channel_credentials_alts(service_accounts or []) + ) def alts_server_credentials(): @@ -1925,7 +2028,9 @@ def compute_engine_channel_credentials(call_credentials): """ return ChannelCredentials( _cygrpc.channel_credentials_compute_engine( - call_credentials._credentials)) + call_credentials._credentials + ) + ) def channel_ready_future(channel): @@ -1942,6 +2047,7 @@ def channel_ready_future(channel): ChannelConnectivity.READY. """ from grpc import _utilities # pylint: disable=cyclic-import + return _utilities.channel_ready_future(channel) @@ -1961,8 +2067,10 @@ def insecure_channel(target, options=None, compression=None): A Channel. 
""" from grpc import _channel # pylint: disable=cyclic-import - return _channel.Channel(target, () if options is None else options, None, - compression) + + return _channel.Channel( + target, () if options is None else options, None, compression + ) def secure_channel(target, credentials, options=None, compression=None): @@ -1983,12 +2091,18 @@ def secure_channel(target, credentials, options=None, compression=None): """ from grpc import _channel # pylint: disable=cyclic-import from grpc.experimental import _insecure_channel_credentials + if credentials._credentials is _insecure_channel_credentials: raise ValueError( - "secure_channel cannot be called with insecure credentials." + - " Call insecure_channel instead.") - return _channel.Channel(target, () if options is None else options, - credentials._credentials, compression) + "secure_channel cannot be called with insecure credentials." + + " Call insecure_channel instead." + ) + return _channel.Channel( + target, + () if options is None else options, + credentials._credentials, + compression, + ) def intercept_channel(channel, *interceptors): @@ -2014,16 +2128,19 @@ def intercept_channel(channel, *interceptors): StreamStreamClientInterceptor. """ from grpc import _interceptor # pylint: disable=cyclic-import + return _interceptor.intercept_channel(channel, *interceptors) -def server(thread_pool, - handlers=None, - interceptors=None, - options=None, - maximum_concurrent_rpcs=None, - compression=None, - xds=False): +def server( + thread_pool, + handlers=None, + interceptors=None, + options=None, + maximum_concurrent_rpcs=None, + compression=None, + xds=False, +): """Creates a Server with which RPCs can be serviced. Args: @@ -2051,16 +2168,22 @@ def server(thread_pool, A Server object. """ from grpc import _server # pylint: disable=cyclic-import - return _server.create_server(thread_pool, - () if handlers is None else handlers, - () if interceptors is None else interceptors, - () if options is None else options, - maximum_concurrent_rpcs, compression, xds) + + return _server.create_server( + thread_pool, + () if handlers is None else handlers, + () if interceptors is None else interceptors, + () if options is None else options, + maximum_concurrent_rpcs, + compression, + xds, + ) @contextlib.contextmanager def _create_servicer_context(rpc_event, state, request_deserializer): from grpc import _server # pylint: disable=cyclic-import + context = _server._Context(rpc_event, state, request_deserializer) yield context context._finalize_state() # pylint: disable=protected-access @@ -2070,11 +2193,12 @@ def _create_servicer_context(rpc_event, state, request_deserializer): class Compression(enum.IntEnum): """Indicates the compression method to be used for an RPC. - Attributes: - NoCompression: Do not use compression algorithm. - Deflate: Use "Deflate" compression algorithm. - Gzip: Use "Gzip" compression algorithm. + Attributes: + NoCompression: Do not use compression algorithm. + Deflate: Use "Deflate" compression algorithm. + Gzip: Use "Gzip" compression algorithm. 
""" + NoCompression = _compression.NoCompression Deflate = _compression.Deflate Gzip = _compression.Gzip @@ -2083,70 +2207,70 @@ class Compression(enum.IntEnum): ################################### __all__ ################################# __all__ = ( - 'FutureTimeoutError', - 'FutureCancelledError', - 'Future', - 'ChannelConnectivity', - 'StatusCode', - 'Status', - 'RpcError', - 'RpcContext', - 'Call', - 'ChannelCredentials', - 'CallCredentials', - 'AuthMetadataContext', - 'AuthMetadataPluginCallback', - 'AuthMetadataPlugin', - 'Compression', - 'ClientCallDetails', - 'ServerCertificateConfiguration', - 'ServerCredentials', - 'LocalConnectionType', - 'UnaryUnaryMultiCallable', - 'UnaryStreamMultiCallable', - 'StreamUnaryMultiCallable', - 'StreamStreamMultiCallable', - 'UnaryUnaryClientInterceptor', - 'UnaryStreamClientInterceptor', - 'StreamUnaryClientInterceptor', - 'StreamStreamClientInterceptor', - 'Channel', - 'ServicerContext', - 'RpcMethodHandler', - 'HandlerCallDetails', - 'GenericRpcHandler', - 'ServiceRpcHandler', - 'Server', - 'ServerInterceptor', - 'unary_unary_rpc_method_handler', - 'unary_stream_rpc_method_handler', - 'stream_unary_rpc_method_handler', - 'stream_stream_rpc_method_handler', - 'method_handlers_generic_handler', - 'ssl_channel_credentials', - 'metadata_call_credentials', - 'access_token_call_credentials', - 'composite_call_credentials', - 'composite_channel_credentials', - 'compute_engine_channel_credentials', - 'local_channel_credentials', - 'local_server_credentials', - 'alts_channel_credentials', - 'alts_server_credentials', - 'ssl_server_credentials', - 'ssl_server_certificate_configuration', - 'dynamic_ssl_server_credentials', - 'channel_ready_future', - 'insecure_channel', - 'secure_channel', - 'intercept_channel', - 'server', - 'protos', - 'services', - 'protos_and_services', - 'xds_channel_credentials', - 'xds_server_credentials', - 'insecure_server_credentials', + "FutureTimeoutError", + "FutureCancelledError", + "Future", + "ChannelConnectivity", + "StatusCode", + "Status", + "RpcError", + "RpcContext", + "Call", + "ChannelCredentials", + "CallCredentials", + "AuthMetadataContext", + "AuthMetadataPluginCallback", + "AuthMetadataPlugin", + "Compression", + "ClientCallDetails", + "ServerCertificateConfiguration", + "ServerCredentials", + "LocalConnectionType", + "UnaryUnaryMultiCallable", + "UnaryStreamMultiCallable", + "StreamUnaryMultiCallable", + "StreamStreamMultiCallable", + "UnaryUnaryClientInterceptor", + "UnaryStreamClientInterceptor", + "StreamUnaryClientInterceptor", + "StreamStreamClientInterceptor", + "Channel", + "ServicerContext", + "RpcMethodHandler", + "HandlerCallDetails", + "GenericRpcHandler", + "ServiceRpcHandler", + "Server", + "ServerInterceptor", + "unary_unary_rpc_method_handler", + "unary_stream_rpc_method_handler", + "stream_unary_rpc_method_handler", + "stream_stream_rpc_method_handler", + "method_handlers_generic_handler", + "ssl_channel_credentials", + "metadata_call_credentials", + "access_token_call_credentials", + "composite_call_credentials", + "composite_channel_credentials", + "compute_engine_channel_credentials", + "local_channel_credentials", + "local_server_credentials", + "alts_channel_credentials", + "alts_server_credentials", + "ssl_server_credentials", + "ssl_server_certificate_configuration", + "dynamic_ssl_server_credentials", + "channel_ready_future", + "insecure_channel", + "secure_channel", + "intercept_channel", + "server", + "protos", + "services", + "protos_and_services", + "xds_channel_credentials", + 
"xds_server_credentials", + "insecure_server_credentials", ) ############################### Extension Shims ################################ @@ -2154,21 +2278,25 @@ class Compression(enum.IntEnum): # Here to maintain backwards compatibility; avoid using these in new code! try: import grpc_tools - sys.modules.update({'grpc.tools': grpc_tools}) + + sys.modules.update({"grpc.tools": grpc_tools}) except ImportError: pass try: import grpc_health - sys.modules.update({'grpc.health': grpc_health}) + + sys.modules.update({"grpc.health": grpc_health}) except ImportError: pass try: import grpc_reflection - sys.modules.update({'grpc.reflection': grpc_reflection}) + + sys.modules.update({"grpc.reflection": grpc_reflection}) except ImportError: pass # Prevents import order issue in the case of renamed path. if sys.version_info >= (3, 6) and __name__ == "grpc": from grpc import aio # pylint: disable=ungrouped-imports - sys.modules.update({'grpc.aio': aio}) + + sys.modules.update({"grpc.aio": aio}) diff --git a/src/python/grpcio/grpc/_auth.py b/src/python/grpcio/grpc/_auth.py index 2095957072f35..9cef38b69105e 100644 --- a/src/python/grpcio/grpc/_auth.py +++ b/src/python/grpcio/grpc/_auth.py @@ -19,14 +19,18 @@ import grpc -def _sign_request(callback: grpc.AuthMetadataPluginCallback, - token: Optional[str], error: Optional[Exception]): - metadata = (('authorization', 'Bearer {}'.format(token)),) +def _sign_request( + callback: grpc.AuthMetadataPluginCallback, + token: Optional[str], + error: Optional[Exception], +): + metadata = (("authorization", "Bearer {}".format(token)),) callback(metadata, error) class GoogleCallCredentials(grpc.AuthMetadataPlugin): """Metadata wrapper for GoogleCredentials from the oauth2client library.""" + _is_jwt: bool _credentials: Any @@ -35,19 +39,23 @@ def __init__(self, credentials: Any): self._credentials = credentials # Hack to determine if these are JWT creds and we need to pass # additional_claims when getting a token - self._is_jwt = 'additional_claims' in inspect.getfullargspec( - credentials.get_access_token).args + self._is_jwt = ( + "additional_claims" + in inspect.getfullargspec(credentials.get_access_token).args + ) - def __call__(self, context: grpc.AuthMetadataContext, - callback: grpc.AuthMetadataPluginCallback): + def __call__( + self, + context: grpc.AuthMetadataContext, + callback: grpc.AuthMetadataPluginCallback, + ): try: if self._is_jwt: access_token = self._credentials.get_access_token( additional_claims={ - 'aud': - context. 
- service_url # pytype: disable=attribute-error - }).access_token + "aud": context.service_url # pytype: disable=attribute-error + } + ).access_token else: access_token = self._credentials.get_access_token().access_token except Exception as exception: # pylint: disable=broad-except @@ -58,11 +66,15 @@ def __call__(self, context: grpc.AuthMetadataContext, class AccessTokenAuthMetadataPlugin(grpc.AuthMetadataPlugin): """Metadata wrapper for raw access token credentials.""" + _access_token: str def __init__(self, access_token: str): self._access_token = access_token - def __call__(self, context: grpc.AuthMetadataContext, - callback: grpc.AuthMetadataPluginCallback): + def __call__( + self, + context: grpc.AuthMetadataContext, + callback: grpc.AuthMetadataPluginCallback, + ): _sign_request(callback, self._access_token, None) diff --git a/src/python/grpcio/grpc/_channel.py b/src/python/grpcio/grpc/_channel.py index 074a3f0e785b5..d8b42255acce3 100644 --- a/src/python/grpcio/grpc/_channel.py +++ b/src/python/grpcio/grpc/_channel.py @@ -22,8 +22,17 @@ import threading import time import types -from typing import (Any, Callable, Iterator, List, Optional, Sequence, Set, - Tuple, Union) +from typing import ( + Any, + Callable, + Iterator, + List, + Optional, + Sequence, + Set, + Tuple, + Union, +) import grpc # pytype: disable=pyi-error from grpc import _common # pytype: disable=pyi-error @@ -43,14 +52,15 @@ _LOGGER = logging.getLogger(__name__) -_USER_AGENT = 'grpc-python/{}'.format(_grpcio_metadata.__version__) +_USER_AGENT = "grpc-python/{}".format(_grpcio_metadata.__version__) _EMPTY_FLAGS = 0 # NOTE(rbellevi): No guarantees are given about the maintenance of this # environment variable. -_DEFAULT_SINGLE_THREADED_UNARY_STREAM = os.getenv( - "GRPC_SINGLE_THREADED_UNARY_STREAM") is not None +_DEFAULT_SINGLE_THREADED_UNARY_STREAM = ( + os.getenv("GRPC_SINGLE_THREADED_UNARY_STREAM") is not None +) _UNARY_UNARY_INITIAL_DUE = ( cygrpc.OperationType.send_initial_metadata, @@ -80,28 +90,32 @@ ) _CHANNEL_SUBSCRIPTION_CALLBACK_ERROR_LOG_MESSAGE = ( - 'Exception calling channel subscription callback!') + "Exception calling channel subscription callback!" 
+) -_OK_RENDEZVOUS_REPR_FORMAT = ('<{} of RPC that terminated with:\n' - '\tstatus = {}\n' - '\tdetails = "{}"\n' - '>') +_OK_RENDEZVOUS_REPR_FORMAT = ( + '<{} of RPC that terminated with:\n\tstatus = {}\n\tdetails = "{}"\n>' +) -_NON_OK_RENDEZVOUS_REPR_FORMAT = ('<{} of RPC that terminated with:\n' - '\tstatus = {}\n' - '\tdetails = "{}"\n' - '\tdebug_error_string = "{}"\n' - '>') +_NON_OK_RENDEZVOUS_REPR_FORMAT = ( + "<{} of RPC that terminated with:\n" + "\tstatus = {}\n" + '\tdetails = "{}"\n' + '\tdebug_error_string = "{}"\n' + ">" +) def _deadline(timeout: Optional[float]) -> Optional[float]: return None if timeout is None else time.time() + timeout -def _unknown_code_details(unknown_cygrpc_code: Optional[grpc.StatusCode], - details: Optional[str]) -> str: +def _unknown_code_details( + unknown_cygrpc_code: Optional[grpc.StatusCode], details: Optional[str] +) -> str: return 'Server sent unknown code {} and details "{}"'.format( - unknown_cygrpc_code, details) + unknown_cygrpc_code, details + ) class _RPCState(object): @@ -120,10 +134,14 @@ class _RPCState(object): rpc_end_time: Optional[datetime] method: Optional[str] - def __init__(self, due: Sequence[cygrpc.OperationType], - initial_metadata: Optional[MetadataType], - trailing_metadata: Optional[MetadataType], - code: Optional[grpc.StatusCode], details: Optional[str]): + def __init__( + self, + due: Sequence[cygrpc.OperationType], + initial_metadata: Optional[MetadataType], + trailing_metadata: Optional[MetadataType], + code: Optional[grpc.StatusCode], + details: Optional[str], + ): # `condition` guards all members of _RPCState. `notify_all` is called on # `condition` when the state of the RPC has changed. self.condition = threading.Condition() @@ -169,8 +187,9 @@ def _abort(state: _RPCState, code: grpc.StatusCode, details: str) -> None: def _handle_event( - event: cygrpc.BaseEvent, state: _RPCState, - response_deserializer: Optional[DeserializingFunction] + event: cygrpc.BaseEvent, + state: _RPCState, + response_deserializer: Optional[DeserializingFunction], ) -> List[NullaryCallbackType]: callbacks = [] for batch_operation in event.batch_operations: @@ -181,10 +200,11 @@ def _handle_event( elif operation_type == cygrpc.OperationType.receive_message: serialized_response = batch_operation.message() if serialized_response is not None: - response = _common.deserialize(serialized_response, - response_deserializer) + response = _common.deserialize( + serialized_response, response_deserializer + ) if response is None: - details = 'Exception deserializing response!' + details = "Exception deserializing response!" 
_abort(state, grpc.StatusCode.INTERNAL, details) else: state.response = response @@ -192,11 +212,13 @@ def _handle_event( state.trailing_metadata = batch_operation.trailing_metadata() if state.code is None: code = _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE.get( - batch_operation.code()) + batch_operation.code() + ) if code is None: state.code = grpc.StatusCode.UNKNOWN state.details = _unknown_code_details( - code, batch_operation.details()) + code, batch_operation.details() + ) else: state.code = code state.details = batch_operation.details() @@ -209,9 +231,8 @@ def _handle_event( def _event_handler( - state: _RPCState, - response_deserializer: Optional[DeserializingFunction]) -> UserTag: - + state: _RPCState, response_deserializer: Optional[DeserializingFunction] +) -> UserTag: def handle_event(event): with state.condition: callbacks = _handle_event(event, state, response_deserializer) @@ -223,20 +244,23 @@ def handle_event(event): except Exception as e: # pylint: disable=broad-except # NOTE(rbellevi): We suppress but log errors here so as not to # kill the channel spin thread. - logging.error('Exception in callback %s: %s', - repr(callback.func), repr(e)) + logging.error( + "Exception in callback %s: %s", repr(callback.func), repr(e) + ) return done and state.fork_epoch >= cygrpc.get_fork_epoch() return handle_event # TODO(xuanwn): Create a base class for IntegratedCall and SegregatedCall. -#pylint: disable=too-many-statements -def _consume_request_iterator(request_iterator: Iterator, state: _RPCState, - call: Union[cygrpc.IntegratedCall, - cygrpc.SegregatedCall], - request_serializer: SerializingFunction, - event_handler: Optional[UserTag]) -> None: +# pylint: disable=too-many-statements +def _consume_request_iterator( + request_iterator: Iterator, + state: _RPCState, + call: Union[cygrpc.IntegratedCall, cygrpc.SegregatedCall], + request_serializer: SerializingFunction, + event_handler: Optional[UserTag], +) -> None: """Consume a request supplied by the user.""" def consume_request_iterator(): # pylint: disable=too-many-branches @@ -254,10 +278,11 @@ def consume_request_iterator(): # pylint: disable=too-many-branches cygrpc.return_from_user_request_generator() return_from_user_request_generator_invoked = True code = grpc.StatusCode.UNKNOWN - details = 'Exception iterating requests!' + details = "Exception iterating requests!" _LOGGER.exception(details) - call.cancel(_common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[code], - details) + call.cancel( + _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[code], details + ) _abort(state, code, details) return finally: @@ -268,31 +293,39 @@ def consume_request_iterator(): # pylint: disable=too-many-branches if state.code is None and not state.cancelled: if serialized_request is None: code = grpc.StatusCode.INTERNAL - details = 'Exception serializing request!' + details = "Exception serializing request!" 
call.cancel( _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[code], - details) + details, + ) _abort(state, code, details) return else: state.due.add(cygrpc.OperationType.send_message) - operations = (cygrpc.SendMessageOperation( - serialized_request, _EMPTY_FLAGS),) + operations = ( + cygrpc.SendMessageOperation( + serialized_request, _EMPTY_FLAGS + ), + ) operating = call.operate(operations, event_handler) if not operating: state.due.remove(cygrpc.OperationType.send_message) return def _done(): - return (state.code is not None or - cygrpc.OperationType.send_message - not in state.due) - - _common.wait(state.condition.wait, - _done, - spin_cb=functools.partial( - cygrpc.block_if_fork_in_progress, - state)) + return ( + state.code is not None + or cygrpc.OperationType.send_message + not in state.due + ) + + _common.wait( + state.condition.wait, + _done, + spin_cb=functools.partial( + cygrpc.block_if_fork_in_progress, state + ), + ) if state.code is not None: return else: @@ -301,14 +334,17 @@ def _done(): if state.code is None: state.due.add(cygrpc.OperationType.send_close_from_client) operations = ( - cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),) + cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), + ) operating = call.operate(operations, event_handler) if not operating: state.due.remove( - cygrpc.OperationType.send_close_from_client) + cygrpc.OperationType.send_close_from_client + ) consumption_thread = cygrpc.ForkManagedThread( - target=consume_request_iterator) + target=consume_request_iterator + ) consumption_thread.setDaemon(True) consumption_thread.start() @@ -317,14 +353,18 @@ def _rpc_state_string(class_name: str, rpc_state: _RPCState) -> str: """Calculates error string for RPC.""" with rpc_state.condition: if rpc_state.code is None: - return '<{} object>'.format(class_name) + return "<{} object>".format(class_name) elif rpc_state.code is grpc.StatusCode.OK: - return _OK_RENDEZVOUS_REPR_FORMAT.format(class_name, rpc_state.code, - rpc_state.details) + return _OK_RENDEZVOUS_REPR_FORMAT.format( + class_name, rpc_state.code, rpc_state.details + ) else: return _NON_OK_RENDEZVOUS_REPR_FORMAT.format( - class_name, rpc_state.code, rpc_state.details, - rpc_state.debug_error_string) + class_name, + rpc_state.code, + rpc_state.details, + rpc_state.debug_error_string, + ) class _InactiveRpcError(grpc.RpcError, grpc.Call, grpc.Future): @@ -336,13 +376,18 @@ class _InactiveRpcError(grpc.RpcError, grpc.Call, grpc.Future): Attributes: _state: An instance of _RPCState. 
""" + _state: _RPCState def __init__(self, state: _RPCState): with state.condition: - self._state = _RPCState((), copy.deepcopy(state.initial_metadata), - copy.deepcopy(state.trailing_metadata), - state.code, copy.deepcopy(state.details)) + self._state = _RPCState( + (), + copy.deepcopy(state.initial_metadata), + copy.deepcopy(state.trailing_metadata), + state.code, + copy.deepcopy(state.details), + ) self._state.response = copy.copy(state.response) self._state.debug_error_string = copy.copy(state.debug_error_string) @@ -386,17 +431,20 @@ def done(self) -> bool: """See grpc.Future.done.""" return True - def result(self, timeout: Optional[float] = None) -> Any: # pylint: disable=unused-argument + def result( + self, timeout: Optional[float] = None + ) -> Any: # pylint: disable=unused-argument """See grpc.Future.result.""" raise self - def exception(self, timeout: Optional[float] = None) -> Optional[Exception]: # pylint: disable=unused-argument + def exception( + self, timeout: Optional[float] = None # pylint: disable=unused-argument + ) -> Optional[Exception]: """See grpc.Future.exception.""" return self def traceback( - self, - timeout: Optional[float] = None # pylint: disable=unused-argument + self, timeout: Optional[float] = None # pylint: disable=unused-argument ) -> Optional[types.TracebackType]: """See grpc.Future.traceback.""" try: @@ -407,7 +455,8 @@ def traceback( def add_done_callback( self, fn: Callable[[grpc.Future], None], - timeout: Optional[float] = None) -> None: # pylint: disable=unused-argument + timeout: Optional[float] = None, # pylint: disable=unused-argument + ) -> None: """See grpc.Future.add_done_callback.""" fn(self) @@ -425,15 +474,19 @@ class _Rendezvous(grpc.RpcError, grpc.RpcContext): _deadline: A float representing the deadline of the RPC in seconds. Or possibly None, to represent an RPC with no deadline at all. """ + _state: _RPCState _call: Union[cygrpc.SegregatedCall, cygrpc.IntegratedCall] _response_deserializer: Optional[DeserializingFunction] _deadline: Optional[float] - def __init__(self, state: _RPCState, call: Union[cygrpc.SegregatedCall, - cygrpc.IntegratedCall], - response_deserializer: Optional[DeserializingFunction], - deadline: Optional[float]): + def __init__( + self, + state: _RPCState, + call: Union[cygrpc.SegregatedCall, cygrpc.IntegratedCall], + response_deserializer: Optional[DeserializingFunction], + deadline: Optional[float], + ): super(_Rendezvous, self).__init__() self._state = state self._call = call @@ -458,9 +511,10 @@ def cancel(self) -> bool: with self._state.condition: if self._state.code is None: code = grpc.StatusCode.CANCELLED - details = 'Locally cancelled by application!' + details = "Locally cancelled by application!" self._call.cancel( - _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[code], details) + _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[code], details + ) self._state.cancelled = True _abort(self._state, code, details) self._state.condition.notify_all() @@ -505,15 +559,18 @@ def __del__(self) -> None: with self._state.condition: if self._state.code is None: self._state.code = grpc.StatusCode.CANCELLED - self._state.details = 'Cancelled upon garbage collection!' + self._state.details = "Cancelled upon garbage collection!" 
self._state.cancelled = True self._call.cancel( _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE[self._state.code], - self._state.details) + self._state.details, + ) self._state.condition.notify_all() -class _SingleThreadedRendezvous(_Rendezvous, grpc.Call, grpc.Future): # pylint: disable=too-many-ancestors +class _SingleThreadedRendezvous( + _Rendezvous, grpc.Call, grpc.Future +): # pylint: disable=too-many-ancestors """An RPC iterator operating entirely on a single thread. The __next__ method of _SingleThreadedRendezvous does not depend on the @@ -526,6 +583,7 @@ class cannot completely fulfill the grpc.Future interface. The result, This means that these methods are safe to call from add_done_callback handlers. """ + _state: _RPCState def _is_complete(self) -> bool: @@ -556,7 +614,8 @@ def result(self, timeout: Optional[float] = None) -> Any: with self._state.condition: if not self._is_complete(): raise grpc.experimental.UsageError( - "_SingleThreadedRendezvous only supports result() when the RPC is complete." + "_SingleThreadedRendezvous only supports result() when the" + " RPC is complete." ) if self._state.code is grpc.StatusCode.OK: return self._state.response @@ -578,7 +637,8 @@ def exception(self, timeout: Optional[float] = None) -> Optional[Exception]: with self._state.condition: if not self._is_complete(): raise grpc.experimental.UsageError( - "_SingleThreadedRendezvous only supports exception() when the RPC is complete." + "_SingleThreadedRendezvous only supports exception() when" + " the RPC is complete." ) if self._state.code is grpc.StatusCode.OK: return None @@ -588,8 +648,8 @@ def exception(self, timeout: Optional[float] = None) -> Optional[Exception]: return self def traceback( - self, - timeout: Optional[float] = None) -> Optional[types.TracebackType]: + self, timeout: Optional[float] = None + ) -> Optional[types.TracebackType]: """Access the traceback of the exception raised by the computation. This method will never block. Instead, it will raise an exception @@ -602,7 +662,8 @@ def traceback( with self._state.condition: if not self._is_complete(): raise grpc.experimental.UsageError( - "_SingleThreadedRendezvous only supports traceback() when the RPC is complete." + "_SingleThreadedRendezvous only supports traceback() when" + " the RPC is complete." ) if self._state.code is grpc.StatusCode.OK: return None @@ -636,7 +697,8 @@ def trailing_metadata(self) -> Optional[MetadataType]: with self._state.condition: if self._state.trailing_metadata is None: raise grpc.experimental.UsageError( - "Cannot get trailing metadata until RPC is completed.") + "Cannot get trailing metadata until RPC is completed." + ) return self._state.trailing_metadata def code(self) -> Optional[grpc.StatusCode]: @@ -644,7 +706,8 @@ def code(self) -> Optional[grpc.StatusCode]: with self._state.condition: if self._state.code is None: raise grpc.experimental.UsageError( - "Cannot get code until RPC is completed.") + "Cannot get code until RPC is completed." + ) return self._state.code def details(self) -> Optional[str]: @@ -652,14 +715,16 @@ def details(self) -> Optional[str]: with self._state.condition: if self._state.details is None: raise grpc.experimental.UsageError( - "Cannot get details until RPC is completed.") + "Cannot get details until RPC is completed." 
+ ) return _common.decode(self._state.details) def _consume_next_event(self) -> Optional[cygrpc.BaseEvent]: event = self._call.next_event() with self._state.condition: - callbacks = _handle_event(event, self._state, - self._response_deserializer) + callbacks = _handle_event( + event, self._state, self._response_deserializer + ) for callback in callbacks: # NOTE(gnossen): We intentionally allow exceptions to bubble up # to the user when running on a single thread. @@ -674,7 +739,9 @@ def _next_response(self) -> Any: response = self._state.response self._state.response = None return response - elif cygrpc.OperationType.receive_message not in self._state.due: + elif ( + cygrpc.OperationType.receive_message not in self._state.due + ): if self._state.code is grpc.StatusCode.OK: raise StopIteration() elif self._state.code is not None: @@ -696,7 +763,8 @@ def _next(self) -> Any: # no data race on `due`. self._state.due.add(cygrpc.OperationType.receive_message) operating = self._call.operate( - (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),), None) + (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),), None + ) if not operating: self._state.due.remove(cygrpc.OperationType.receive_message) elif self._state.code is grpc.StatusCode.OK: @@ -709,11 +777,14 @@ def debug_error_string(self) -> Optional[str]: with self._state.condition: if self._state.debug_error_string is None: raise grpc.experimental.UsageError( - "Cannot get debug error string until RPC is completed.") + "Cannot get debug error string until RPC is completed." + ) return _common.decode(self._state.debug_error_string) -class _MultiThreadedRendezvous(_Rendezvous, grpc.Call, grpc.Future): # pylint: disable=too-many-ancestors +class _MultiThreadedRendezvous( + _Rendezvous, grpc.Call, grpc.Future +): # pylint: disable=too-many-ancestors """An RPC iterator that depends on a channel spin thread. This iterator relies upon a per-channel thread running in the background, @@ -723,6 +794,7 @@ class _MultiThreadedRendezvous(_Rendezvous, grpc.Call, grpc.Future): # pylint: This extra thread allows _MultiThreadedRendezvous to fulfill the grpc.Future interface and to mediate a bidirection streaming RPC. """ + _state: _RPCState def initial_metadata(self) -> Optional[MetadataType]: @@ -795,9 +867,9 @@ def result(self, timeout: Optional[float] = None) -> Any: See grpc.Future.result for the full API contract. """ with self._state.condition: - timed_out = _common.wait(self._state.condition.wait, - self._is_complete, - timeout=timeout) + timed_out = _common.wait( + self._state.condition.wait, self._is_complete, timeout=timeout + ) if timed_out: raise grpc.FutureTimeoutError() else: @@ -814,9 +886,9 @@ def exception(self, timeout: Optional[float] = None) -> Optional[Exception]: See grpc.Future.exception for the full API contract. """ with self._state.condition: - timed_out = _common.wait(self._state.condition.wait, - self._is_complete, - timeout=timeout) + timed_out = _common.wait( + self._state.condition.wait, self._is_complete, timeout=timeout + ) if timed_out: raise grpc.FutureTimeoutError() else: @@ -828,16 +900,16 @@ def exception(self, timeout: Optional[float] = None) -> Optional[Exception]: return self def traceback( - self, - timeout: Optional[float] = None) -> Optional[types.TracebackType]: + self, timeout: Optional[float] = None + ) -> Optional[types.TracebackType]: """Access the traceback of the exception raised by the computation. See grpc.future.traceback for the full API contract. 
""" with self._state.condition: - timed_out = _common.wait(self._state.condition.wait, - self._is_complete, - timeout=timeout) + timed_out = _common.wait( + self._state.condition.wait, self._is_complete, timeout=timeout + ) if timed_out: raise grpc.FutureTimeoutError() else: @@ -862,12 +934,14 @@ def add_done_callback(self, fn: Callable[[grpc.Future], None]) -> None: def _next(self) -> Any: with self._state.condition: if self._state.code is None: - event_handler = _event_handler(self._state, - self._response_deserializer) + event_handler = _event_handler( + self._state, self._response_deserializer + ) self._state.due.add(cygrpc.OperationType.receive_message) operating = self._call.operate( (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),), - event_handler) + event_handler, + ) if not operating: self._state.due.remove(cygrpc.OperationType.receive_message) elif self._state.code is grpc.StatusCode.OK: @@ -876,10 +950,10 @@ def _next(self) -> Any: raise self def _response_ready(): - return (self._state.response is not None or - (cygrpc.OperationType.receive_message - not in self._state.due and - self._state.code is not None)) + return self._state.response is not None or ( + cygrpc.OperationType.receive_message not in self._state.due + and self._state.code is not None + ) _common.wait(self._state.condition.wait, _response_ready) if self._state.response is not None: @@ -894,14 +968,20 @@ def _response_ready(): def _start_unary_request( - request: Any, timeout: Optional[float], - request_serializer: SerializingFunction + request: Any, + timeout: Optional[float], + request_serializer: SerializingFunction, ) -> Tuple[Optional[float], Optional[bytes], Optional[grpc.RpcError]]: deadline = _deadline(timeout) serialized_request = _common.serialize(request, request_serializer) if serialized_request is None: - state = _RPCState((), (), (), grpc.StatusCode.INTERNAL, - 'Exception serializing request!') + state = _RPCState( + (), + (), + (), + grpc.StatusCode.INTERNAL, + "Exception serializing request!", + ) error = _InactiveRpcError(state) return deadline, None, error else: @@ -909,8 +989,10 @@ def _start_unary_request( def _end_unary_response_blocking( - state: _RPCState, call: cygrpc.SegregatedCall, with_call: bool, - deadline: Optional[float] + state: _RPCState, + call: cygrpc.SegregatedCall, + with_call: bool, + deadline: Optional[float], ) -> Union[ResponseType, Tuple[ResponseType, grpc.Call]]: if state.code is grpc.StatusCode.OK: if with_call: @@ -923,12 +1005,13 @@ def _end_unary_response_blocking( def _stream_unary_invocation_operations( - metadata: Optional[MetadataType], - initial_metadata_flags: int) -> Sequence[Sequence[cygrpc.Operation]]: + metadata: Optional[MetadataType], initial_metadata_flags: int +) -> Sequence[Sequence[cygrpc.Operation]]: return ( ( - cygrpc.SendInitialMetadataOperation(metadata, - initial_metadata_flags), + cygrpc.SendInitialMetadataOperation( + metadata, initial_metadata_flags + ), cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), ), @@ -939,11 +1022,15 @@ def _stream_unary_invocation_operations( def _stream_unary_invocation_operations_and_tags( metadata: Optional[MetadataType], initial_metadata_flags: int ) -> Sequence[Tuple[Sequence[cygrpc.Operation], Optional[UserTag]]]: - return tuple(( - operations, - None, - ) for operations in _stream_unary_invocation_operations( - metadata, initial_metadata_flags)) + return tuple( + ( + operations, + None, + ) + for operations in _stream_unary_invocation_operations( + metadata, 
initial_metadata_flags + ) + ) def _determine_deadline(user_deadline: Optional[float]) -> Optional[float]: @@ -967,10 +1054,14 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): _context: Any # pylint: disable=too-many-arguments - def __init__(self, channel: cygrpc.Channel, - managed_call: IntegratedCallFactory, method: bytes, - request_serializer: Optional[SerializingFunction], - response_deserializer: Optional[DeserializingFunction]): + def __init__( + self, + channel: cygrpc.Channel, + managed_call: IntegratedCallFactory, + method: bytes, + request_serializer: Optional[SerializingFunction], + response_deserializer: Optional[DeserializingFunction], + ): self._channel = channel self._managed_call = managed_call self._method = method @@ -979,24 +1070,35 @@ def __init__(self, channel: cygrpc.Channel, self._context = cygrpc.build_census_context() def _prepare( - self, request: Any, timeout: Optional[float], - metadata: Optional[MetadataType], wait_for_ready: Optional[bool], - compression: Optional[grpc.Compression] - ) -> Tuple[Optional[_RPCState], Optional[Sequence[cygrpc.Operation]], - Optional[float], Optional[grpc.RpcError]]: + self, + request: Any, + timeout: Optional[float], + metadata: Optional[MetadataType], + wait_for_ready: Optional[bool], + compression: Optional[grpc.Compression], + ) -> Tuple[ + Optional[_RPCState], + Optional[Sequence[cygrpc.Operation]], + Optional[float], + Optional[grpc.RpcError], + ]: deadline, serialized_request, rendezvous = _start_unary_request( - request, timeout, self._request_serializer) + request, timeout, self._request_serializer + ) initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready( - wait_for_ready) + wait_for_ready + ) augmented_metadata = _compression.augment_metadata( - metadata, compression) + metadata, compression + ) if serialized_request is None: return None, None, None, rendezvous else: state = _RPCState(_UNARY_UNARY_INITIAL_DUE, None, None, None, None) operations = ( - cygrpc.SendInitialMetadataOperation(augmented_metadata, - initial_metadata_flags), + cygrpc.SendInitialMetadataOperation( + augmented_metadata, initial_metadata_flags + ), cygrpc.SendMessageOperation(serialized_request, _EMPTY_FLAGS), cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS), @@ -1012,35 +1114,50 @@ def _blocking( metadata: Optional[MetadataType] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, - compression: Optional[grpc.Compression] = None + compression: Optional[grpc.Compression] = None, ) -> Tuple[_RPCState, cygrpc.SegregatedCall]: state, operations, deadline, rendezvous = self._prepare( - request, timeout, metadata, wait_for_ready, compression) + request, timeout, metadata, wait_for_ready, compression + ) if state is None: raise rendezvous # pylint: disable-msg=raising-bad-type else: call = self._channel.segregated_call( cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, - self._method, None, _determine_deadline(deadline), metadata, - None if credentials is None else credentials._credentials, (( - operations, - None, - ),), self._context) + self._method, + None, + _determine_deadline(deadline), + metadata, + None if credentials is None else credentials._credentials, + ( + ( + operations, + None, + ), + ), + self._context, + ) state.rpc_start_time = datetime.utcnow() state.method = _common.decode(self._method) event = call.next_event() _handle_event(event, state, self._response_deserializer) return state, call - def 
__call__(self, - request: Any, - timeout: Optional[float] = None, - metadata: Optional[MetadataType] = None, - credentials: Optional[grpc.CallCredentials] = None, - wait_for_ready: Optional[bool] = None, - compression: Optional[grpc.Compression] = None) -> Any: - state, call, = self._blocking(request, timeout, metadata, credentials, - wait_for_ready, compression) + def __call__( + self, + request: Any, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None, + ) -> Any: + ( + state, + call, + ) = self._blocking( + request, timeout, metadata, credentials, wait_for_ready, compression + ) return _end_unary_response_blocking(state, call, False, None) def with_call( @@ -1050,10 +1167,14 @@ def with_call( metadata: Optional[MetadataType] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, - compression: Optional[grpc.Compression] = None + compression: Optional[grpc.Compression] = None, ) -> Tuple[Any, grpc.Call]: - state, call, = self._blocking(request, timeout, metadata, credentials, - wait_for_ready, compression) + ( + state, + call, + ) = self._blocking( + request, timeout, metadata, credentials, wait_for_ready, compression + ) return _end_unary_response_blocking(state, call, True, None) def future( @@ -1063,24 +1184,31 @@ def future( metadata: Optional[MetadataType] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, - compression: Optional[grpc.Compression] = None + compression: Optional[grpc.Compression] = None, ) -> _MultiThreadedRendezvous: state, operations, deadline, rendezvous = self._prepare( - request, timeout, metadata, wait_for_ready, compression) + request, timeout, metadata, wait_for_ready, compression + ) if state is None: raise rendezvous # pylint: disable-msg=raising-bad-type else: event_handler = _event_handler(state, self._response_deserializer) call = self._managed_call( cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, - self._method, None, deadline, metadata, + self._method, + None, + deadline, + metadata, None if credentials is None else credentials._credentials, - (operations,), event_handler, self._context) + (operations,), + event_handler, + self._context, + ) state.rpc_start_time = datetime.utcnow() state.method = _common.decode(self._method) - return _MultiThreadedRendezvous(state, call, - self._response_deserializer, - deadline) + return _MultiThreadedRendezvous( + state, call, self._response_deserializer, deadline + ) class _SingleThreadedUnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): @@ -1091,9 +1219,13 @@ class _SingleThreadedUnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): _context: Any # pylint: disable=too-many-arguments - def __init__(self, channel: cygrpc.Channel, method: bytes, - request_serializer: SerializingFunction, - response_deserializer: DeserializingFunction): + def __init__( + self, + channel: cygrpc.Channel, + method: bytes, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + ): self._channel = channel self._method = method self._request_serializer = request_serializer @@ -1107,39 +1239,59 @@ def __call__( # pylint: disable=too-many-locals metadata: Optional[MetadataType] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, - compression: Optional[grpc.Compression] = None + 
compression: Optional[grpc.Compression] = None, ) -> _SingleThreadedRendezvous: deadline = _deadline(timeout) - serialized_request = _common.serialize(request, - self._request_serializer) + serialized_request = _common.serialize( + request, self._request_serializer + ) if serialized_request is None: - state = _RPCState((), (), (), grpc.StatusCode.INTERNAL, - 'Exception serializing request!') + state = _RPCState( + (), + (), + (), + grpc.StatusCode.INTERNAL, + "Exception serializing request!", + ) raise _InactiveRpcError(state) state = _RPCState(_UNARY_STREAM_INITIAL_DUE, None, None, None, None) - call_credentials = None if credentials is None else credentials._credentials + call_credentials = ( + None if credentials is None else credentials._credentials + ) initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready( - wait_for_ready) + wait_for_ready + ) augmented_metadata = _compression.augment_metadata( - metadata, compression) + metadata, compression + ) operations = ( - (cygrpc.SendInitialMetadataOperation(augmented_metadata, - initial_metadata_flags), - cygrpc.SendMessageOperation(serialized_request, _EMPTY_FLAGS), - cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS)), + ( + cygrpc.SendInitialMetadataOperation( + augmented_metadata, initial_metadata_flags + ), + cygrpc.SendMessageOperation(serialized_request, _EMPTY_FLAGS), + cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), + ), (cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),), (cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),), ) operations_and_tags = tuple((ops, None) for ops in operations) call = self._channel.segregated_call( - cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, self._method, - None, _determine_deadline(deadline), metadata, call_credentials, - operations_and_tags, self._context) + cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, + self._method, + None, + _determine_deadline(deadline), + metadata, + call_credentials, + operations_and_tags, + self._context, + ) state.rpc_start_time = datetime.utcnow() state.method = _common.decode(self._method) - return _SingleThreadedRendezvous(state, call, - self._response_deserializer, deadline) + return _SingleThreadedRendezvous( + state, call, self._response_deserializer, deadline + ) class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): @@ -1151,10 +1303,14 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): _context: Any # pylint: disable=too-many-arguments - def __init__(self, channel: cygrpc.Channel, - managed_call: IntegratedCallFactory, method: bytes, - request_serializer: SerializingFunction, - response_deserializer: DeserializingFunction): + def __init__( + self, + channel: cygrpc.Channel, + managed_call: IntegratedCallFactory, + method: bytes, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + ): self._channel = channel self._managed_call = managed_call self._method = method @@ -1169,24 +1325,29 @@ def __call__( # pylint: disable=too-many-locals metadata: Optional[MetadataType] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, - compression: Optional[ - grpc.Compression] = None) -> _MultiThreadedRendezvous: + compression: Optional[grpc.Compression] = None, + ) -> _MultiThreadedRendezvous: deadline, serialized_request, rendezvous = _start_unary_request( - request, timeout, self._request_serializer) + request, timeout, self._request_serializer + ) initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready( - 
wait_for_ready) + wait_for_ready + ) if serialized_request is None: raise rendezvous # pylint: disable-msg=raising-bad-type else: augmented_metadata = _compression.augment_metadata( - metadata, compression) + metadata, compression + ) state = _RPCState(_UNARY_STREAM_INITIAL_DUE, None, None, None, None) operations = ( ( - cygrpc.SendInitialMetadataOperation(augmented_metadata, - initial_metadata_flags), - cygrpc.SendMessageOperation(serialized_request, - _EMPTY_FLAGS), + cygrpc.SendInitialMetadataOperation( + augmented_metadata, initial_metadata_flags + ), + cygrpc.SendMessageOperation( + serialized_request, _EMPTY_FLAGS + ), cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), ), @@ -1194,15 +1355,20 @@ def __call__( # pylint: disable=too-many-locals ) call = self._managed_call( cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, - self._method, None, _determine_deadline(deadline), metadata, + self._method, + None, + _determine_deadline(deadline), + metadata, None if credentials is None else credentials._credentials, - operations, _event_handler(state, self._response_deserializer), - self._context) + operations, + _event_handler(state, self._response_deserializer), + self._context, + ) state.rpc_start_time = datetime.utcnow() state.method = _common.decode(self._method) - return _MultiThreadedRendezvous(state, call, - self._response_deserializer, - deadline) + return _MultiThreadedRendezvous( + state, call, self._response_deserializer, deadline + ) class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): @@ -1214,10 +1380,14 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): _context: Any # pylint: disable=too-many-arguments - def __init__(self, channel: cygrpc.Channel, - managed_call: IntegratedCallFactory, method: bytes, - request_serializer: Optional[SerializingFunction], - response_deserializer: Optional[DeserializingFunction]): + def __init__( + self, + channel: cygrpc.Channel, + managed_call: IntegratedCallFactory, + method: bytes, + request_serializer: Optional[SerializingFunction], + response_deserializer: Optional[DeserializingFunction], + ): self._channel = channel self._managed_call = managed_call self._method = method @@ -1226,27 +1396,39 @@ def __init__(self, channel: cygrpc.Channel, self._context = cygrpc.build_census_context() def _blocking( - self, request_iterator: Iterator, timeout: Optional[float], + self, + request_iterator: Iterator, + timeout: Optional[float], metadata: Optional[MetadataType], credentials: Optional[grpc.CallCredentials], - wait_for_ready: Optional[bool], compression: Optional[grpc.Compression] + wait_for_ready: Optional[bool], + compression: Optional[grpc.Compression], ) -> Tuple[_RPCState, cygrpc.SegregatedCall]: deadline = _deadline(timeout) state = _RPCState(_STREAM_UNARY_INITIAL_DUE, None, None, None, None) initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready( - wait_for_ready) + wait_for_ready + ) augmented_metadata = _compression.augment_metadata( - metadata, compression) + metadata, compression + ) call = self._channel.segregated_call( - cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, self._method, - None, _determine_deadline(deadline), augmented_metadata, + cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, + self._method, + None, + _determine_deadline(deadline), + augmented_metadata, None if credentials is None else credentials._credentials, _stream_unary_invocation_operations_and_tags( - augmented_metadata, initial_metadata_flags), self._context) 
+ augmented_metadata, initial_metadata_flags + ), + self._context, + ) state.rpc_start_time = datetime.utcnow() state.method = _common.decode(self._method) - _consume_request_iterator(request_iterator, state, call, - self._request_serializer, None) + _consume_request_iterator( + request_iterator, state, call, self._request_serializer, None + ) while True: event = call.next_event() with state.condition: @@ -1256,15 +1438,26 @@ def _blocking( break return state, call - def __call__(self, - request_iterator: Iterator, - timeout: Optional[float] = None, - metadata: Optional[MetadataType] = None, - credentials: Optional[grpc.CallCredentials] = None, - wait_for_ready: Optional[bool] = None, - compression: Optional[grpc.Compression] = None) -> Any: - state, call, = self._blocking(request_iterator, timeout, metadata, - credentials, wait_for_ready, compression) + def __call__( + self, + request_iterator: Iterator, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None, + ) -> Any: + ( + state, + call, + ) = self._blocking( + request_iterator, + timeout, + metadata, + credentials, + wait_for_ready, + compression, + ) return _end_unary_response_blocking(state, call, False, None) def with_call( @@ -1274,10 +1467,19 @@ def with_call( metadata: Optional[MetadataType] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, - compression: Optional[grpc.Compression] = None + compression: Optional[grpc.Compression] = None, ) -> Tuple[Any, grpc.Call]: - state, call, = self._blocking(request_iterator, timeout, metadata, - credentials, wait_for_ready, compression) + ( + state, + call, + ) = self._blocking( + request_iterator, + timeout, + metadata, + credentials, + wait_for_ready, + compression, + ) return _end_unary_response_blocking(state, call, True, None) def future( @@ -1287,28 +1489,42 @@ def future( metadata: Optional[MetadataType] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, - compression: Optional[grpc.Compression] = None + compression: Optional[grpc.Compression] = None, ) -> _MultiThreadedRendezvous: deadline = _deadline(timeout) state = _RPCState(_STREAM_UNARY_INITIAL_DUE, None, None, None, None) event_handler = _event_handler(state, self._response_deserializer) initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready( - wait_for_ready) + wait_for_ready + ) augmented_metadata = _compression.augment_metadata( - metadata, compression) + metadata, compression + ) call = self._managed_call( - cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, self._method, - None, deadline, augmented_metadata, + cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, + self._method, + None, + deadline, + augmented_metadata, None if credentials is None else credentials._credentials, - _stream_unary_invocation_operations(metadata, - initial_metadata_flags), - event_handler, self._context) + _stream_unary_invocation_operations( + metadata, initial_metadata_flags + ), + event_handler, + self._context, + ) state.rpc_start_time = datetime.utcnow() state.method = _common.decode(self._method) - _consume_request_iterator(request_iterator, state, call, - self._request_serializer, event_handler) - return _MultiThreadedRendezvous(state, call, - self._response_deserializer, deadline) + _consume_request_iterator( + request_iterator, + state, + call, + 
self._request_serializer, + event_handler, + ) + return _MultiThreadedRendezvous( + state, call, self._response_deserializer, deadline + ) class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): @@ -1320,12 +1536,14 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): _context: Any # pylint: disable=too-many-arguments - def __init__(self, - channel: cygrpc.Channel, - managed_call: IntegratedCallFactory, - method: bytes, - request_serializer: Optional[SerializingFunction] = None, - response_deserializer: Optional[DeserializingFunction] = None): + def __init__( + self, + channel: cygrpc.Channel, + managed_call: IntegratedCallFactory, + method: bytes, + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None, + ): self._channel = channel self._managed_call = managed_call self._method = method @@ -1340,34 +1558,49 @@ def __call__( metadata: Optional[MetadataType] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, - compression: Optional[grpc.Compression] = None + compression: Optional[grpc.Compression] = None, ) -> _MultiThreadedRendezvous: deadline = _deadline(timeout) state = _RPCState(_STREAM_STREAM_INITIAL_DUE, None, None, None, None) initial_metadata_flags = _InitialMetadataFlags().with_wait_for_ready( - wait_for_ready) + wait_for_ready + ) augmented_metadata = _compression.augment_metadata( - metadata, compression) + metadata, compression + ) operations = ( ( - cygrpc.SendInitialMetadataOperation(augmented_metadata, - initial_metadata_flags), + cygrpc.SendInitialMetadataOperation( + augmented_metadata, initial_metadata_flags + ), cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), ), (cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),), ) event_handler = _event_handler(state, self._response_deserializer) call = self._managed_call( - cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, self._method, - None, _determine_deadline(deadline), augmented_metadata, + cygrpc.PropagationConstants.GRPC_PROPAGATE_DEFAULTS, + self._method, + None, + _determine_deadline(deadline), + augmented_metadata, None if credentials is None else credentials._credentials, - operations, event_handler, self._context) + operations, + event_handler, + self._context, + ) state.rpc_start_time = datetime.utcnow() state.method = _common.decode(self._method) - _consume_request_iterator(request_iterator, state, call, - self._request_serializer, event_handler) - return _MultiThreadedRendezvous(state, call, - self._response_deserializer, deadline) + _consume_request_iterator( + request_iterator, + state, + call, + self._request_serializer, + event_handler, + ) + return _MultiThreadedRendezvous( + state, call, self._response_deserializer, deadline + ) class _InitialMetadataFlags(int): @@ -1380,11 +1613,16 @@ def __new__(cls, value: int = _EMPTY_FLAGS): def with_wait_for_ready(self, wait_for_ready: Optional[bool]) -> int: if wait_for_ready is not None: if wait_for_ready: - return self.__class__(self | cygrpc.InitialMetadataFlags.wait_for_ready | \ - cygrpc.InitialMetadataFlags.wait_for_ready_explicitly_set) + return self.__class__( + self + | cygrpc.InitialMetadataFlags.wait_for_ready + | cygrpc.InitialMetadataFlags.wait_for_ready_explicitly_set + ) elif not wait_for_ready: - return self.__class__(self & ~cygrpc.InitialMetadataFlags.wait_for_ready | \ - cygrpc.InitialMetadataFlags.wait_for_ready_explicitly_set) + return self.__class__( + self & 
~cygrpc.InitialMetadataFlags.wait_for_ready + | cygrpc.InitialMetadataFlags.wait_for_ready_explicitly_set + ) return self @@ -1404,14 +1642,14 @@ def reset_postfork_child(self) -> None: def __del__(self): try: - self.channel.close(cygrpc.StatusCode.cancelled, - 'Channel deallocated!') + self.channel.close( + cygrpc.StatusCode.cancelled, "Channel deallocated!" + ) except (TypeError, AttributeError): pass def _run_channel_spin_thread(state: _ChannelCallState) -> None: - def channel_spin(): while True: cygrpc.block_if_fork_in_progress(state) @@ -1431,13 +1669,18 @@ def channel_spin(): def _channel_managed_call_management(state: _ChannelCallState): - # pylint: disable=too-many-arguments - def create(flags: int, method: bytes, host: Optional[str], - deadline: Optional[float], metadata: Optional[MetadataType], - credentials: Optional[cygrpc.CallCredentials], - operations: Sequence[Sequence[cygrpc.Operation]], - event_handler: UserTag, context) -> cygrpc.IntegratedCall: + def create( + flags: int, + method: bytes, + host: Optional[str], + deadline: Optional[float], + metadata: Optional[MetadataType], + credentials: Optional[cygrpc.CallCredentials], + operations: Sequence[Sequence[cygrpc.Operation]], + event_handler: UserTag, + context, + ) -> cygrpc.IntegratedCall: """Creates a cygrpc.IntegratedCall. Args: @@ -1456,14 +1699,24 @@ def create(flags: int, method: bytes, host: Optional[str], Returns: A cygrpc.IntegratedCall with which to conduct an RPC. """ - operations_and_tags = tuple(( - operation, - event_handler, - ) for operation in operations) + operations_and_tags = tuple( + ( + operation, + event_handler, + ) + for operation in operations + ) with state.lock: - call = state.channel.integrated_call(flags, method, host, deadline, - metadata, credentials, - operations_and_tags, context) + call = state.channel.integrated_call( + flags, + method, + host, + deadline, + metadata, + credentials, + operations_and_tags, + context, + ) if state.managed_calls == 0: state.managed_calls = 1 _run_channel_spin_thread(state) @@ -1481,8 +1734,14 @@ class _ChannelConnectivityState(object): connectivity: grpc.ChannelConnectivity try_to_connect: bool # TODO(xuanwn): Refactor this: https://github.com/grpc/grpc/issues/31704 - callbacks_and_connectivities: List[Sequence[Union[Callable[ - [grpc.ChannelConnectivity], None], Optional[grpc.ChannelConnectivity]]]] + callbacks_and_connectivities: List[ + Sequence[ + Union[ + Callable[[grpc.ChannelConnectivity], None], + Optional[grpc.ChannelConnectivity], + ] + ] + ] delivering: bool def __init__(self, channel: grpc.Channel): @@ -1503,11 +1762,14 @@ def reset_postfork_child(self) -> None: def _deliveries( - state: _ChannelConnectivityState + state: _ChannelConnectivityState, ) -> List[Callable[[grpc.ChannelConnectivity], None]]: callbacks_needing_update = [] for callback_and_connectivity in state.callbacks_and_connectivities: - callback, callback_connectivity, = callback_and_connectivity + ( + callback, + callback_connectivity, + ) = callback_and_connectivity if callback_connectivity is not state.connectivity: callbacks_needing_update.append(callback) callback_and_connectivity[1] = state.connectivity @@ -1517,7 +1779,7 @@ def _deliveries( def _deliver( state: _ChannelConnectivityState, initial_connectivity: grpc.ChannelConnectivity, - initial_callbacks: Sequence[Callable[[grpc.ChannelConnectivity], None]] + initial_callbacks: Sequence[Callable[[grpc.ChannelConnectivity], None]], ) -> None: connectivity = initial_connectivity callbacks = initial_callbacks @@ -1528,7 
+1790,8 @@ def _deliver( callback(connectivity) except Exception: # pylint: disable=broad-except _LOGGER.exception( - _CHANNEL_SUBSCRIPTION_CALLBACK_ERROR_LOG_MESSAGE) + _CHANNEL_SUBSCRIPTION_CALLBACK_ERROR_LOG_MESSAGE + ) with state.lock: callbacks = _deliveries(state) if callbacks: @@ -1539,42 +1802,53 @@ def _deliver( def _spawn_delivery( - state: _ChannelConnectivityState, - callbacks: Sequence[Callable[[grpc.ChannelConnectivity], - None]]) -> None: - delivering_thread = cygrpc.ForkManagedThread(target=_deliver, - args=( - state, - state.connectivity, - callbacks, - )) + state: _ChannelConnectivityState, + callbacks: Sequence[Callable[[grpc.ChannelConnectivity], None]], +) -> None: + delivering_thread = cygrpc.ForkManagedThread( + target=_deliver, + args=( + state, + state.connectivity, + callbacks, + ), + ) delivering_thread.setDaemon(True) delivering_thread.start() state.delivering = True # NOTE(https://github.com/grpc/grpc/issues/3064): We'd rather not poll. -def _poll_connectivity(state: _ChannelConnectivityState, channel: grpc.Channel, - initial_try_to_connect: bool) -> None: +def _poll_connectivity( + state: _ChannelConnectivityState, + channel: grpc.Channel, + initial_try_to_connect: bool, +) -> None: try_to_connect = initial_try_to_connect connectivity = channel.check_connectivity_state(try_to_connect) with state.lock: state.connectivity = ( - _common. - CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[connectivity]) + _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[ + connectivity + ] + ) callbacks = tuple( - callback for callback, unused_but_known_to_be_none_connectivity in - state.callbacks_and_connectivities) + callback for callback, _ in state.callbacks_and_connectivities + ) for callback_and_connectivity in state.callbacks_and_connectivities: callback_and_connectivity[1] = state.connectivity if callbacks: _spawn_delivery(state, callbacks) while True: - event = channel.watch_connectivity_state(connectivity, - time.time() + 0.2) + event = channel.watch_connectivity_state( + connectivity, time.time() + 0.2 + ) cygrpc.block_if_fork_in_progress(state) with state.lock: - if not state.callbacks_and_connectivities and not state.try_to_connect: + if ( + not state.callbacks_and_connectivities + and not state.try_to_connect + ): state.polling = False state.connectivity = None break @@ -1585,21 +1859,26 @@ def _poll_connectivity(state: _ChannelConnectivityState, channel: grpc.Channel, with state.lock: state.connectivity = ( _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[ - connectivity]) + connectivity + ] + ) if not state.delivering: callbacks = _deliveries(state) if callbacks: _spawn_delivery(state, callbacks) -def _subscribe(state: _ChannelConnectivityState, - callback: Callable[[grpc.ChannelConnectivity], - None], try_to_connect: bool) -> None: +def _subscribe( + state: _ChannelConnectivityState, + callback: Callable[[grpc.ChannelConnectivity], None], + try_to_connect: bool, +) -> None: with state.lock: if not state.callbacks_and_connectivities and not state.polling: polling_thread = cygrpc.ForkManagedThread( target=_poll_connectivity, - args=(state, state.channel, bool(try_to_connect))) + args=(state, state.channel, bool(try_to_connect)), + ) polling_thread.setDaemon(True) polling_thread.start() state.polling = True @@ -1608,41 +1887,54 @@ def _subscribe(state: _ChannelConnectivityState, _spawn_delivery(state, (callback,)) state.try_to_connect |= bool(try_to_connect) state.callbacks_and_connectivities.append( - [callback, state.connectivity]) + [callback, 
state.connectivity] + ) else: state.try_to_connect |= bool(try_to_connect) state.callbacks_and_connectivities.append([callback, None]) -def _unsubscribe(state: _ChannelConnectivityState, - callback: Callable[[grpc.ChannelConnectivity], None]) -> None: +def _unsubscribe( + state: _ChannelConnectivityState, + callback: Callable[[grpc.ChannelConnectivity], None], +) -> None: with state.lock: for index, (subscribed_callback, unused_connectivity) in enumerate( - state.callbacks_and_connectivities): + state.callbacks_and_connectivities + ): if callback == subscribed_callback: state.callbacks_and_connectivities.pop(index) break def _augment_options( - base_options: Sequence[ChannelArgumentType], - compression: Optional[grpc.Compression] + base_options: Sequence[ChannelArgumentType], + compression: Optional[grpc.Compression], ) -> Sequence[ChannelArgumentType]: compression_option = _compression.create_channel_option(compression) - return tuple(base_options) + compression_option + (( - cygrpc.ChannelArgKey.primary_user_agent_string, - _USER_AGENT, - ),) + return ( + tuple(base_options) + + compression_option + + ( + ( + cygrpc.ChannelArgKey.primary_user_agent_string, + _USER_AGENT, + ), + ) + ) def _separate_channel_options( - options: Sequence[ChannelArgumentType] + options: Sequence[ChannelArgumentType], ) -> Tuple[Sequence[ChannelArgumentType], Sequence[ChannelArgumentType]]: """Separates core channel options from Python channel options.""" core_options = [] python_options = [] for pair in options: - if pair[0] == grpc.experimental.ChannelOptions.SingleThreadedUnaryStream: + if ( + pair[0] + == grpc.experimental.ChannelOptions.SingleThreadedUnaryStream + ): python_options.append(pair) else: core_options.append(pair) @@ -1651,14 +1943,19 @@ def _separate_channel_options( class Channel(grpc.Channel): """A cygrpc.Channel-backed implementation of grpc.Channel.""" + _single_threaded_unary_stream: bool _channel: cygrpc.Channel _call_state: _ChannelCallState _connectivity_state: _ChannelConnectivityState - def __init__(self, target: str, options: Sequence[ChannelArgumentType], - credentials: Optional[grpc.ChannelCredentials], - compression: Optional[grpc.Compression]): + def __init__( + self, + target: str, + options: Sequence[ChannelArgumentType], + credentials: Optional[grpc.ChannelCredentials], + compression: Optional[grpc.Compression], + ): """Constructor. Args: @@ -1669,11 +1966,15 @@ def __init__(self, target: str, options: Sequence[ChannelArgumentType], used over the lifetime of the channel. 
""" python_options, core_options = _separate_channel_options(options) - self._single_threaded_unary_stream = _DEFAULT_SINGLE_THREADED_UNARY_STREAM + self._single_threaded_unary_stream = ( + _DEFAULT_SINGLE_THREADED_UNARY_STREAM + ) self._process_python_options(python_options) self._channel = cygrpc.Channel( - _common.encode(target), _augment_options(core_options, compression), - credentials) + _common.encode(target), + _augment_options(core_options, compression), + credentials, + ) self._call_state = _ChannelCallState(self._channel) self._connectivity_state = _ChannelConnectivityState(self._channel) cygrpc.fork_register_channel(self) @@ -1681,36 +1982,47 @@ def __init__(self, target: str, options: Sequence[ChannelArgumentType], cygrpc.gevent_increment_channel_count() def _process_python_options( - self, python_options: Sequence[ChannelArgumentType]) -> None: + self, python_options: Sequence[ChannelArgumentType] + ) -> None: """Sets channel attributes according to python-only channel options.""" for pair in python_options: - if pair[0] == grpc.experimental.ChannelOptions.SingleThreadedUnaryStream: + if ( + pair[0] + == grpc.experimental.ChannelOptions.SingleThreadedUnaryStream + ): self._single_threaded_unary_stream = True - def subscribe(self, - callback: Callable[[grpc.ChannelConnectivity], None], - try_to_connect: Optional[bool] = None) -> None: + def subscribe( + self, + callback: Callable[[grpc.ChannelConnectivity], None], + try_to_connect: Optional[bool] = None, + ) -> None: _subscribe(self._connectivity_state, callback, try_to_connect) def unsubscribe( - self, callback: Callable[[grpc.ChannelConnectivity], None]) -> None: + self, callback: Callable[[grpc.ChannelConnectivity], None] + ) -> None: _unsubscribe(self._connectivity_state, callback) def unary_unary( self, method: str, request_serializer: Optional[SerializingFunction] = None, - response_deserializer: Optional[DeserializingFunction] = None + response_deserializer: Optional[DeserializingFunction] = None, ) -> grpc.UnaryUnaryMultiCallable: return _UnaryUnaryMultiCallable( - self._channel, _channel_managed_call_management(self._call_state), - _common.encode(method), request_serializer, response_deserializer) + self._channel, + _channel_managed_call_management(self._call_state), + _common.encode(method), + request_serializer, + response_deserializer, + ) def unary_stream( self, method: str, request_serializer: Optional[SerializingFunction] = None, - response_deserializer: Optional[DeserializingFunction] = None + response_deserializer: Optional[DeserializingFunction] = None, ) -> grpc.UnaryStreamMultiCallable: # NOTE(rbellevi): Benchmarks have shown that running a unary-stream RPC # on a single Python thread results in an appreciable speed-up. However, @@ -1718,34 +2030,47 @@ def unary_stream( # remains the default. 
if self._single_threaded_unary_stream: return _SingleThreadedUnaryStreamMultiCallable( - self._channel, _common.encode(method), request_serializer, - response_deserializer) + self._channel, + _common.encode(method), + request_serializer, + response_deserializer, + ) else: return _UnaryStreamMultiCallable( self._channel, _channel_managed_call_management(self._call_state), - _common.encode(method), request_serializer, - response_deserializer) + _common.encode(method), + request_serializer, + response_deserializer, + ) def stream_unary( self, method: str, request_serializer: Optional[SerializingFunction] = None, - response_deserializer: Optional[DeserializingFunction] = None + response_deserializer: Optional[DeserializingFunction] = None, ) -> grpc.StreamUnaryMultiCallable: return _StreamUnaryMultiCallable( - self._channel, _channel_managed_call_management(self._call_state), - _common.encode(method), request_serializer, response_deserializer) + self._channel, + _channel_managed_call_management(self._call_state), + _common.encode(method), + request_serializer, + response_deserializer, + ) def stream_stream( self, method: str, request_serializer: Optional[SerializingFunction] = None, - response_deserializer: Optional[DeserializingFunction] = None + response_deserializer: Optional[DeserializingFunction] = None, ) -> grpc.StreamStreamMultiCallable: return _StreamStreamMultiCallable( - self._channel, _channel_managed_call_management(self._call_state), - _common.encode(method), request_serializer, response_deserializer) + self._channel, + _channel_managed_call_management(self._call_state), + _common.encode(method), + request_serializer, + response_deserializer, + ) def _unsubscribe_all(self) -> None: state = self._connectivity_state @@ -1755,15 +2080,16 @@ def _unsubscribe_all(self) -> None: def _close(self) -> None: self._unsubscribe_all() - self._channel.close(cygrpc.StatusCode.cancelled, 'Channel closed!') + self._channel.close(cygrpc.StatusCode.cancelled, "Channel closed!") cygrpc.fork_unregister_channel(self) if cygrpc.g_gevent_activated: cygrpc.gevent_decrement_channel_count() def _close_on_fork(self) -> None: self._unsubscribe_all() - self._channel.close_on_fork(cygrpc.StatusCode.cancelled, - 'Channel closed due to fork') + self._channel.close_on_fork( + cygrpc.StatusCode.cancelled, "Channel closed due to fork" + ) def __enter__(self): return self diff --git a/src/python/grpcio/grpc/_common.py b/src/python/grpcio/grpc/_common.py index 3b8fd0ff97d89..475f0510cf84e 100644 --- a/src/python/grpcio/grpc/_common.py +++ b/src/python/grpcio/grpc/_common.py @@ -25,16 +25,11 @@ _LOGGER = logging.getLogger(__name__) CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY = { - cygrpc.ConnectivityState.idle: - grpc.ChannelConnectivity.IDLE, - cygrpc.ConnectivityState.connecting: - grpc.ChannelConnectivity.CONNECTING, - cygrpc.ConnectivityState.ready: - grpc.ChannelConnectivity.READY, - cygrpc.ConnectivityState.transient_failure: - grpc.ChannelConnectivity.TRANSIENT_FAILURE, - cygrpc.ConnectivityState.shutdown: - grpc.ChannelConnectivity.SHUTDOWN, + cygrpc.ConnectivityState.idle: grpc.ChannelConnectivity.IDLE, + cygrpc.ConnectivityState.connecting: grpc.ChannelConnectivity.CONNECTING, + cygrpc.ConnectivityState.ready: grpc.ChannelConnectivity.READY, + cygrpc.ConnectivityState.transient_failure: grpc.ChannelConnectivity.TRANSIENT_FAILURE, + cygrpc.ConnectivityState.shutdown: grpc.ChannelConnectivity.SHUTDOWN, } CYGRPC_STATUS_CODE_TO_STATUS_CODE = { @@ -63,26 +58,30 @@ MAXIMUM_WAIT_TIMEOUT = 0.1 
-_ERROR_MESSAGE_PORT_BINDING_FAILED = 'Failed to bind to address %s; set ' \ - 'GRPC_VERBOSITY=debug environment variable to see detailed error message.' +_ERROR_MESSAGE_PORT_BINDING_FAILED = ( + "Failed to bind to address %s; set " + "GRPC_VERBOSITY=debug environment variable to see detailed error message." +) def encode(s: AnyStr) -> bytes: if isinstance(s, bytes): return s else: - return s.encode('utf8') + return s.encode("utf8") def decode(b: AnyStr) -> str: if isinstance(b, bytes): - return b.decode('utf-8', 'replace') + return b.decode("utf-8", "replace") return b -def _transform(message: Any, transformer: Union[SerializingFunction, - DeserializingFunction, None], - exception_message: str) -> Any: +def _transform( + message: Any, + transformer: Union[SerializingFunction, DeserializingFunction, None], + exception_message: str, +) -> Any: if transformer is None: return message else: @@ -94,30 +93,37 @@ def _transform(message: Any, transformer: Union[SerializingFunction, def serialize(message: Any, serializer: Optional[SerializingFunction]) -> bytes: - return _transform(message, serializer, 'Exception serializing message!') + return _transform(message, serializer, "Exception serializing message!") -def deserialize(serialized_message: bytes, - deserializer: Optional[DeserializingFunction]) -> Any: - return _transform(serialized_message, deserializer, - 'Exception deserializing message!') +def deserialize( + serialized_message: bytes, deserializer: Optional[DeserializingFunction] +) -> Any: + return _transform( + serialized_message, deserializer, "Exception deserializing message!" + ) def fully_qualified_method(group: str, method: str) -> str: - return '/{}/{}'.format(group, method) + return "/{}/{}".format(group, method) -def _wait_once(wait_fn: Callable[..., bool], timeout: float, - spin_cb: Optional[Callable[[], None]]): +def _wait_once( + wait_fn: Callable[..., bool], + timeout: float, + spin_cb: Optional[Callable[[], None]], +): wait_fn(timeout=timeout) if spin_cb is not None: spin_cb() -def wait(wait_fn: Callable[..., bool], - wait_complete_fn: Callable[[], bool], - timeout: Optional[float] = None, - spin_cb: Optional[Callable[[], None]] = None) -> bool: +def wait( + wait_fn: Callable[..., bool], + wait_complete_fn: Callable[[], bool], + timeout: Optional[float] = None, + spin_cb: Optional[Callable[[], None]] = None, +) -> bool: """Blocks waiting for an event without blocking the thread indefinitely. See https://github.com/grpc/grpc/issues/19464 for full context. 
CPython's diff --git a/src/python/grpcio/grpc/_compression.py b/src/python/grpcio/grpc/_compression.py index 5eb6f2ac6d892..07fa6f8434f86 100644 --- a/src/python/grpcio/grpc/_compression.py +++ b/src/python/grpcio/grpc/_compression.py @@ -25,34 +25,42 @@ Gzip = cygrpc.CompressionAlgorithm.gzip _METADATA_STRING_MAPPING = { - NoCompression: 'identity', - Deflate: 'deflate', - Gzip: 'gzip', + NoCompression: "identity", + Deflate: "deflate", + Gzip: "gzip", } def _compression_algorithm_to_metadata_value( - compression: grpc.Compression) -> str: + compression: grpc.Compression, +) -> str: return _METADATA_STRING_MAPPING[compression] def compression_algorithm_to_metadata(compression: grpc.Compression): - return (cygrpc.GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY, - _compression_algorithm_to_metadata_value(compression)) + return ( + cygrpc.GRPC_COMPRESSION_REQUEST_ALGORITHM_MD_KEY, + _compression_algorithm_to_metadata_value(compression), + ) def create_channel_option(compression: Optional[grpc.Compression]): - return ((cygrpc.GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM, - int(compression)),) if compression else () + return ( + ((cygrpc.GRPC_COMPRESSION_CHANNEL_DEFAULT_ALGORITHM, int(compression)),) + if compression + else () + ) -def augment_metadata(metadata: Optional[MetadataType], - compression: Optional[grpc.Compression]): +def augment_metadata( + metadata: Optional[MetadataType], compression: Optional[grpc.Compression] +): if not metadata and not compression: return None base_metadata = tuple(metadata) if metadata else () compression_metadata = ( - compression_algorithm_to_metadata(compression),) if compression else () + (compression_algorithm_to_metadata(compression),) if compression else () + ) return base_metadata + compression_metadata diff --git a/src/python/grpcio/grpc/_interceptor.py b/src/python/grpcio/grpc/_interceptor.py index 865ff17d35b88..36bce4e3ba5f6 100644 --- a/src/python/grpcio/grpc/_interceptor.py +++ b/src/python/grpcio/grpc/_interceptor.py @@ -37,8 +37,8 @@ def _continuation(self, thunk: Callable, index: int) -> Callable: return lambda context: self._intercept_at(thunk, index, context) def _intercept_at( - self, thunk: Callable, index: int, - context: grpc.HandlerCallDetails) -> grpc.RpcMethodHandler: + self, thunk: Callable, index: int, context: grpc.HandlerCallDetails + ) -> grpc.RpcMethodHandler: if index < len(self.interceptors): interceptor = self.interceptors[index] thunk = self._continuation(thunk, index + 1) @@ -46,30 +46,41 @@ def _intercept_at( else: return thunk(context) - def execute(self, thunk: Callable, - context: grpc.HandlerCallDetails) -> grpc.RpcMethodHandler: + def execute( + self, thunk: Callable, context: grpc.HandlerCallDetails + ) -> grpc.RpcMethodHandler: return self._intercept_at(thunk, 0, context) def service_pipeline( - interceptors: Optional[Sequence[grpc.ServerInterceptor]] + interceptors: Optional[Sequence[grpc.ServerInterceptor]], ) -> Optional[_ServicePipeline]: return _ServicePipeline(interceptors) if interceptors else None class _ClientCallDetails( - collections.namedtuple('_ClientCallDetails', - ('method', 'timeout', 'metadata', 'credentials', - 'wait_for_ready', 'compression')), - grpc.ClientCallDetails): + collections.namedtuple( + "_ClientCallDetails", + ( + "method", + "timeout", + "metadata", + "credentials", + "wait_for_ready", + "compression", + ), + ), + grpc.ClientCallDetails, +): pass def _unwrap_client_call_details( call_details: grpc.ClientCallDetails, - default_details: grpc.ClientCallDetails -) -> Tuple[str, float, 
MetadataType, grpc.CallCredentials, bool, - grpc.Compression]: + default_details: grpc.ClientCallDetails, +) -> Tuple[ + str, float, MetadataType, grpc.CallCredentials, bool, grpc.Compression +]: try: method = call_details.method # pytype: disable=attribute-error except AttributeError: @@ -86,24 +97,38 @@ def _unwrap_client_call_details( metadata = default_details.metadata # pytype: disable=attribute-error try: - credentials = call_details.credentials # pytype: disable=attribute-error + credentials = ( + call_details.credentials + ) # pytype: disable=attribute-error except AttributeError: - credentials = default_details.credentials # pytype: disable=attribute-error + credentials = ( + default_details.credentials + ) # pytype: disable=attribute-error try: - wait_for_ready = call_details.wait_for_ready # pytype: disable=attribute-error + wait_for_ready = ( + call_details.wait_for_ready + ) # pytype: disable=attribute-error except AttributeError: - wait_for_ready = default_details.wait_for_ready # pytype: disable=attribute-error + wait_for_ready = ( + default_details.wait_for_ready + ) # pytype: disable=attribute-error try: - compression = call_details.compression # pytype: disable=attribute-error + compression = ( + call_details.compression + ) # pytype: disable=attribute-error except AttributeError: - compression = default_details.compression # pytype: disable=attribute-error + compression = ( + default_details.compression + ) # pytype: disable=attribute-error return method, timeout, metadata, credentials, wait_for_ready, compression -class _FailureOutcome(grpc.RpcError, grpc.Future, grpc.Call): # pylint: disable=too-many-ancestors +class _FailureOutcome( + grpc.RpcError, grpc.Future, grpc.Call +): # pylint: disable=too-many-ancestors _exception: Exception _traceback: types.TracebackType @@ -122,7 +147,7 @@ def code(self) -> Optional[grpc.StatusCode]: return grpc.StatusCode.INTERNAL def details(self) -> Optional[str]: - return 'Exception raised while intercepting the RPC' + return "Exception raised while intercepting the RPC" def cancel(self) -> bool: return False @@ -146,13 +171,12 @@ def result(self, ignored_timeout: Optional[float] = None): raise self._exception def exception( - self, - ignored_timeout: Optional[float] = None) -> Optional[Exception]: + self, ignored_timeout: Optional[float] = None + ) -> Optional[Exception]: return self._exception def traceback( - self, - ignored_timeout: Optional[float] = None + self, ignored_timeout: Optional[float] = None ) -> Optional[types.TracebackType]: return self._traceback @@ -231,25 +255,33 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): _method: str _interceptor: grpc.UnaryUnaryClientInterceptor - def __init__(self, thunk: Callable, method: str, - interceptor: grpc.UnaryUnaryClientInterceptor): + def __init__( + self, + thunk: Callable, + method: str, + interceptor: grpc.UnaryUnaryClientInterceptor, + ): self._thunk = thunk self._method = method self._interceptor = interceptor - def __call__(self, - request: Any, - timeout: Optional[float] = None, - metadata: Optional[MetadataType] = None, - credentials: Optional[grpc.CallCredentials] = None, - wait_for_ready: Optional[bool] = None, - compression: Optional[grpc.Compression] = None) -> Any: - response, ignored_call = self._with_call(request, - timeout=timeout, - metadata=metadata, - credentials=credentials, - wait_for_ready=wait_for_ready, - compression=compression) + def __call__( + self, + request: Any, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = 
None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None, + ) -> Any: + response, ignored_call = self._with_call( + request, + timeout=timeout, + metadata=metadata, + credentials=credentials, + wait_for_ready=wait_for_ready, + compression=compression, + ) return response def _with_call( @@ -259,17 +291,26 @@ def _with_call( metadata: Optional[MetadataType] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, - compression: Optional[grpc.Compression] = None + compression: Optional[grpc.Compression] = None, ) -> Tuple[Any, grpc.Call]: - client_call_details = _ClientCallDetails(self._method, timeout, - metadata, credentials, - wait_for_ready, compression) + client_call_details = _ClientCallDetails( + self._method, + timeout, + metadata, + credentials, + wait_for_ready, + compression, + ) def continuation(new_details, request): - (new_method, new_timeout, new_metadata, new_credentials, - new_wait_for_ready, - new_compression) = (_unwrap_client_call_details( - new_details, client_call_details)) + ( + new_method, + new_timeout, + new_metadata, + new_credentials, + new_wait_for_ready, + new_compression, + ) = _unwrap_client_call_details(new_details, client_call_details) try: response, call = self._thunk(new_method).with_call( request, @@ -277,16 +318,17 @@ def continuation(new_details, request): metadata=new_metadata, credentials=new_credentials, wait_for_ready=new_wait_for_ready, - compression=new_compression) + compression=new_compression, + ) return _UnaryOutcome(response, call) except grpc.RpcError as rpc_error: return rpc_error except Exception as exception: # pylint:disable=broad-except return _FailureOutcome(exception, sys.exc_info()[2]) - call = self._interceptor.intercept_unary_unary(continuation, - client_call_details, - request) + call = self._interceptor.intercept_unary_unary( + continuation, client_call_details, request + ) return call.result(), call def with_call( @@ -296,42 +338,57 @@ def with_call( metadata: Optional[MetadataType] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, - compression: Optional[grpc.Compression] = None + compression: Optional[grpc.Compression] = None, ) -> Tuple[Any, grpc.Call]: - return self._with_call(request, - timeout=timeout, - metadata=metadata, - credentials=credentials, - wait_for_ready=wait_for_ready, - compression=compression) - - def future(self, - request: Any, - timeout: Optional[float] = None, - metadata: Optional[MetadataType] = None, - credentials: Optional[grpc.CallCredentials] = None, - wait_for_ready: Optional[bool] = None, - compression: Optional[grpc.Compression] = None) -> Any: - client_call_details = _ClientCallDetails(self._method, timeout, - metadata, credentials, - wait_for_ready, compression) + return self._with_call( + request, + timeout=timeout, + metadata=metadata, + credentials=credentials, + wait_for_ready=wait_for_ready, + compression=compression, + ) + + def future( + self, + request: Any, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None, + ) -> Any: + client_call_details = _ClientCallDetails( + self._method, + timeout, + metadata, + credentials, + wait_for_ready, + compression, + ) def continuation(new_details, request): - (new_method, new_timeout, new_metadata, 
new_credentials, - new_wait_for_ready, - new_compression) = (_unwrap_client_call_details( - new_details, client_call_details)) + ( + new_method, + new_timeout, + new_metadata, + new_credentials, + new_wait_for_ready, + new_compression, + ) = _unwrap_client_call_details(new_details, client_call_details) return self._thunk(new_method).future( request, timeout=new_timeout, metadata=new_metadata, credentials=new_credentials, wait_for_ready=new_wait_for_ready, - compression=new_compression) + compression=new_compression, + ) try: return self._interceptor.intercept_unary_unary( - continuation, client_call_details, request) + continuation, client_call_details, request + ) except Exception as exception: # pylint:disable=broad-except return _FailureOutcome(exception, sys.exc_info()[2]) @@ -341,38 +398,56 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): _method: str _interceptor: grpc.UnaryStreamClientInterceptor - def __init__(self, thunk: Callable, method: str, - interceptor: grpc.UnaryStreamClientInterceptor): + def __init__( + self, + thunk: Callable, + method: str, + interceptor: grpc.UnaryStreamClientInterceptor, + ): self._thunk = thunk self._method = method self._interceptor = interceptor - def __call__(self, - request: Any, - timeout: Optional[float] = None, - metadata: Optional[MetadataType] = None, - credentials: Optional[grpc.CallCredentials] = None, - wait_for_ready: Optional[bool] = None, - compression: Optional[grpc.Compression] = None): - client_call_details = _ClientCallDetails(self._method, timeout, - metadata, credentials, - wait_for_ready, compression) + def __call__( + self, + request: Any, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None, + ): + client_call_details = _ClientCallDetails( + self._method, + timeout, + metadata, + credentials, + wait_for_ready, + compression, + ) def continuation(new_details, request): - (new_method, new_timeout, new_metadata, new_credentials, - new_wait_for_ready, - new_compression) = (_unwrap_client_call_details( - new_details, client_call_details)) - return self._thunk(new_method)(request, - timeout=new_timeout, - metadata=new_metadata, - credentials=new_credentials, - wait_for_ready=new_wait_for_ready, - compression=new_compression) + ( + new_method, + new_timeout, + new_metadata, + new_credentials, + new_wait_for_ready, + new_compression, + ) = _unwrap_client_call_details(new_details, client_call_details) + return self._thunk(new_method)( + request, + timeout=new_timeout, + metadata=new_metadata, + credentials=new_credentials, + wait_for_ready=new_wait_for_ready, + compression=new_compression, + ) try: return self._interceptor.intercept_unary_stream( - continuation, client_call_details, request) + continuation, client_call_details, request + ) except Exception as exception: # pylint:disable=broad-except return _FailureOutcome(exception, sys.exc_info()[2]) @@ -382,25 +457,33 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): _method: str _interceptor: grpc.StreamUnaryClientInterceptor - def __init__(self, thunk: Callable, method: str, - interceptor: grpc.StreamUnaryClientInterceptor): + def __init__( + self, + thunk: Callable, + method: str, + interceptor: grpc.StreamUnaryClientInterceptor, + ): self._thunk = thunk self._method = method self._interceptor = interceptor - def __call__(self, - request_iterator: RequestIterableType, - timeout: 
Optional[float] = None, - metadata: Optional[MetadataType] = None, - credentials: Optional[grpc.CallCredentials] = None, - wait_for_ready: Optional[bool] = None, - compression: Optional[grpc.Compression] = None) -> Any: - response, ignored_call = self._with_call(request_iterator, - timeout=timeout, - metadata=metadata, - credentials=credentials, - wait_for_ready=wait_for_ready, - compression=compression) + def __call__( + self, + request_iterator: RequestIterableType, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None, + ) -> Any: + response, ignored_call = self._with_call( + request_iterator, + timeout=timeout, + metadata=metadata, + credentials=credentials, + wait_for_ready=wait_for_ready, + compression=compression, + ) return response def _with_call( @@ -410,17 +493,26 @@ def _with_call( metadata: Optional[MetadataType] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, - compression: Optional[grpc.Compression] = None + compression: Optional[grpc.Compression] = None, ) -> Tuple[Any, grpc.Call]: - client_call_details = _ClientCallDetails(self._method, timeout, - metadata, credentials, - wait_for_ready, compression) + client_call_details = _ClientCallDetails( + self._method, + timeout, + metadata, + credentials, + wait_for_ready, + compression, + ) def continuation(new_details, request_iterator): - (new_method, new_timeout, new_metadata, new_credentials, - new_wait_for_ready, - new_compression) = (_unwrap_client_call_details( - new_details, client_call_details)) + ( + new_method, + new_timeout, + new_metadata, + new_credentials, + new_wait_for_ready, + new_compression, + ) = _unwrap_client_call_details(new_details, client_call_details) try: response, call = self._thunk(new_method).with_call( request_iterator, @@ -428,16 +520,17 @@ def continuation(new_details, request_iterator): metadata=new_metadata, credentials=new_credentials, wait_for_ready=new_wait_for_ready, - compression=new_compression) + compression=new_compression, + ) return _UnaryOutcome(response, call) except grpc.RpcError as rpc_error: return rpc_error except Exception as exception: # pylint:disable=broad-except return _FailureOutcome(exception, sys.exc_info()[2]) - call = self._interceptor.intercept_stream_unary(continuation, - client_call_details, - request_iterator) + call = self._interceptor.intercept_stream_unary( + continuation, client_call_details, request_iterator + ) return call.result(), call def with_call( @@ -447,42 +540,57 @@ def with_call( metadata: Optional[MetadataType] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, - compression: Optional[grpc.Compression] = None + compression: Optional[grpc.Compression] = None, ) -> Tuple[Any, grpc.Call]: - return self._with_call(request_iterator, - timeout=timeout, - metadata=metadata, - credentials=credentials, - wait_for_ready=wait_for_ready, - compression=compression) - - def future(self, - request_iterator: RequestIterableType, - timeout: Optional[float] = None, - metadata: Optional[MetadataType] = None, - credentials: Optional[grpc.CallCredentials] = None, - wait_for_ready: Optional[bool] = None, - compression: Optional[grpc.Compression] = None) -> Any: - client_call_details = _ClientCallDetails(self._method, timeout, - metadata, credentials, - wait_for_ready, compression) + return self._with_call( + 
request_iterator, + timeout=timeout, + metadata=metadata, + credentials=credentials, + wait_for_ready=wait_for_ready, + compression=compression, + ) + + def future( + self, + request_iterator: RequestIterableType, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None, + ) -> Any: + client_call_details = _ClientCallDetails( + self._method, + timeout, + metadata, + credentials, + wait_for_ready, + compression, + ) def continuation(new_details, request_iterator): - (new_method, new_timeout, new_metadata, new_credentials, - new_wait_for_ready, - new_compression) = (_unwrap_client_call_details( - new_details, client_call_details)) + ( + new_method, + new_timeout, + new_metadata, + new_credentials, + new_wait_for_ready, + new_compression, + ) = _unwrap_client_call_details(new_details, client_call_details) return self._thunk(new_method).future( request_iterator, timeout=new_timeout, metadata=new_metadata, credentials=new_credentials, wait_for_ready=new_wait_for_ready, - compression=new_compression) + compression=new_compression, + ) try: return self._interceptor.intercept_stream_unary( - continuation, client_call_details, request_iterator) + continuation, client_call_details, request_iterator + ) except Exception as exception: # pylint:disable=broad-except return _FailureOutcome(exception, sys.exc_info()[2]) @@ -492,60 +600,85 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): _method: str _interceptor: grpc.StreamStreamClientInterceptor - def __init__(self, thunk: Callable, method: str, - interceptor: grpc.StreamStreamClientInterceptor): + def __init__( + self, + thunk: Callable, + method: str, + interceptor: grpc.StreamStreamClientInterceptor, + ): self._thunk = thunk self._method = method self._interceptor = interceptor - def __call__(self, - request_iterator: RequestIterableType, - timeout: Optional[float] = None, - metadata: Optional[MetadataType] = None, - credentials: Optional[grpc.CallCredentials] = None, - wait_for_ready: Optional[bool] = None, - compression: Optional[grpc.Compression] = None): - client_call_details = _ClientCallDetails(self._method, timeout, - metadata, credentials, - wait_for_ready, compression) + def __call__( + self, + request_iterator: RequestIterableType, + timeout: Optional[float] = None, + metadata: Optional[MetadataType] = None, + credentials: Optional[grpc.CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[grpc.Compression] = None, + ): + client_call_details = _ClientCallDetails( + self._method, + timeout, + metadata, + credentials, + wait_for_ready, + compression, + ) def continuation(new_details, request_iterator): - (new_method, new_timeout, new_metadata, new_credentials, - new_wait_for_ready, - new_compression) = (_unwrap_client_call_details( - new_details, client_call_details)) - return self._thunk(new_method)(request_iterator, - timeout=new_timeout, - metadata=new_metadata, - credentials=new_credentials, - wait_for_ready=new_wait_for_ready, - compression=new_compression) + ( + new_method, + new_timeout, + new_metadata, + new_credentials, + new_wait_for_ready, + new_compression, + ) = _unwrap_client_call_details(new_details, client_call_details) + return self._thunk(new_method)( + request_iterator, + timeout=new_timeout, + metadata=new_metadata, + credentials=new_credentials, + wait_for_ready=new_wait_for_ready, + 
compression=new_compression, + ) try: return self._interceptor.intercept_stream_stream( - continuation, client_call_details, request_iterator) + continuation, client_call_details, request_iterator + ) except Exception as exception: # pylint:disable=broad-except return _FailureOutcome(exception, sys.exc_info()[2]) class _Channel(grpc.Channel): _channel: grpc.Channel - _interceptor: Union[grpc.UnaryUnaryClientInterceptor, - grpc.UnaryStreamClientInterceptor, - grpc.StreamStreamClientInterceptor, - grpc.StreamUnaryClientInterceptor] - - def __init__(self, channel: grpc.Channel, - interceptor: Union[grpc.UnaryUnaryClientInterceptor, - grpc.UnaryStreamClientInterceptor, - grpc.StreamStreamClientInterceptor, - grpc.StreamUnaryClientInterceptor]): + _interceptor: Union[ + grpc.UnaryUnaryClientInterceptor, + grpc.UnaryStreamClientInterceptor, + grpc.StreamStreamClientInterceptor, + grpc.StreamUnaryClientInterceptor, + ] + + def __init__( + self, + channel: grpc.Channel, + interceptor: Union[ + grpc.UnaryUnaryClientInterceptor, + grpc.UnaryStreamClientInterceptor, + grpc.StreamStreamClientInterceptor, + grpc.StreamUnaryClientInterceptor, + ], + ): self._channel = channel self._interceptor = interceptor - def subscribe(self, - callback: Callable, - try_to_connect: Optional[bool] = False): + def subscribe( + self, callback: Callable, try_to_connect: Optional[bool] = False + ): self._channel.subscribe(callback, try_to_connect=try_to_connect) def unsubscribe(self, callback: Callable): @@ -555,10 +688,11 @@ def unary_unary( self, method: str, request_serializer: Optional[SerializingFunction] = None, - response_deserializer: Optional[DeserializingFunction] = None + response_deserializer: Optional[DeserializingFunction] = None, ) -> grpc.UnaryUnaryMultiCallable: - thunk = lambda m: self._channel.unary_unary(m, request_serializer, - response_deserializer) + thunk = lambda m: self._channel.unary_unary( + m, request_serializer, response_deserializer + ) if isinstance(self._interceptor, grpc.UnaryUnaryClientInterceptor): return _UnaryUnaryMultiCallable(thunk, method, self._interceptor) else: @@ -568,10 +702,11 @@ def unary_stream( self, method: str, request_serializer: Optional[SerializingFunction] = None, - response_deserializer: Optional[DeserializingFunction] = None + response_deserializer: Optional[DeserializingFunction] = None, ) -> grpc.UnaryStreamMultiCallable: - thunk = lambda m: self._channel.unary_stream(m, request_serializer, - response_deserializer) + thunk = lambda m: self._channel.unary_stream( + m, request_serializer, response_deserializer + ) if isinstance(self._interceptor, grpc.UnaryStreamClientInterceptor): return _UnaryStreamMultiCallable(thunk, method, self._interceptor) else: @@ -581,10 +716,11 @@ def stream_unary( self, method: str, request_serializer: Optional[SerializingFunction] = None, - response_deserializer: Optional[DeserializingFunction] = None + response_deserializer: Optional[DeserializingFunction] = None, ) -> grpc.StreamUnaryMultiCallable: - thunk = lambda m: self._channel.stream_unary(m, request_serializer, - response_deserializer) + thunk = lambda m: self._channel.stream_unary( + m, request_serializer, response_deserializer + ) if isinstance(self._interceptor, grpc.StreamUnaryClientInterceptor): return _StreamUnaryMultiCallable(thunk, method, self._interceptor) else: @@ -594,10 +730,11 @@ def stream_stream( self, method: str, request_serializer: Optional[SerializingFunction] = None, - response_deserializer: Optional[DeserializingFunction] = None + response_deserializer: 
Optional[DeserializingFunction] = None, ) -> grpc.StreamStreamMultiCallable: - thunk = lambda m: self._channel.stream_stream(m, request_serializer, - response_deserializer) + thunk = lambda m: self._channel.stream_stream( + m, request_serializer, response_deserializer + ) if isinstance(self._interceptor, grpc.StreamStreamClientInterceptor): return _StreamStreamMultiCallable(thunk, method, self._interceptor) else: @@ -619,20 +756,30 @@ def close(self): def intercept_channel( channel: grpc.Channel, - *interceptors: Optional[Sequence[Union[grpc.UnaryUnaryClientInterceptor, - grpc.UnaryStreamClientInterceptor, - grpc.StreamStreamClientInterceptor, - grpc.StreamUnaryClientInterceptor]]] + *interceptors: Optional[ + Sequence[ + Union[ + grpc.UnaryUnaryClientInterceptor, + grpc.UnaryStreamClientInterceptor, + grpc.StreamStreamClientInterceptor, + grpc.StreamUnaryClientInterceptor, + ] + ] + ], ) -> grpc.Channel: for interceptor in reversed(list(interceptors)): - if not isinstance(interceptor, grpc.UnaryUnaryClientInterceptor) and \ - not isinstance(interceptor, grpc.UnaryStreamClientInterceptor) and \ - not isinstance(interceptor, grpc.StreamUnaryClientInterceptor) and \ - not isinstance(interceptor, grpc.StreamStreamClientInterceptor): - raise TypeError('interceptor must be ' - 'grpc.UnaryUnaryClientInterceptor or ' - 'grpc.UnaryStreamClientInterceptor or ' - 'grpc.StreamUnaryClientInterceptor or ' - 'grpc.StreamStreamClientInterceptor or ') + if ( + not isinstance(interceptor, grpc.UnaryUnaryClientInterceptor) + and not isinstance(interceptor, grpc.UnaryStreamClientInterceptor) + and not isinstance(interceptor, grpc.StreamUnaryClientInterceptor) + and not isinstance(interceptor, grpc.StreamStreamClientInterceptor) + ): + raise TypeError( + "interceptor must be " + "grpc.UnaryUnaryClientInterceptor or " + "grpc.UnaryStreamClientInterceptor or " + "grpc.StreamUnaryClientInterceptor or " + "grpc.StreamStreamClientInterceptor or " + ) channel = _Channel(channel, interceptor) return channel diff --git a/src/python/grpcio/grpc/_observability.py b/src/python/grpcio/grpc/_observability.py index 8e36baf75090f..781bfbb3bc896 100644 --- a/src/python/grpcio/grpc/_observability.py +++ b/src/python/grpcio/grpc/_observability.py @@ -25,16 +25,17 @@ _LOGGER = logging.getLogger(__name__) _channel = Any # _channel.py imports this module. -ClientCallTracerCapsule = TypeVar('ClientCallTracerCapsule') -ServerCallTracerFactoryCapsule = TypeVar('ServerCallTracerFactoryCapsule') +ClientCallTracerCapsule = TypeVar("ClientCallTracerCapsule") +ServerCallTracerFactoryCapsule = TypeVar("ServerCallTracerFactoryCapsule") _plugin_lock: threading.RLock = threading.RLock() -_OBSERVABILITY_PLUGIN: Optional[ObservabilityPlugin] = None # pylint: disable=used-before-assignment +_OBSERVABILITY_PLUGIN: Optional["ObservabilityPlugin"] = None -class ObservabilityPlugin(Generic[ClientCallTracerCapsule, - ServerCallTracerFactoryCapsule], - metaclass=abc.ABCMeta): +class ObservabilityPlugin( + Generic[ClientCallTracerCapsule, ServerCallTracerFactoryCapsule], + metaclass=abc.ABCMeta, +): """Abstract base class for observability plugin. *This is a semi-private class that was intended for the exclusive use of @@ -51,12 +52,14 @@ class ObservabilityPlugin(Generic[ClientCallTracerCapsule, _stats_enabled: A bool indicates whether tracing is enabled. _tracing_enabled: A bool indicates whether stats(metrics) is enabled. 
""" + _tracing_enabled: bool = False _stats_enabled: bool = False @abc.abstractmethod def create_client_call_tracer( - self, method_name: bytes) -> ClientCallTracerCapsule: + self, method_name: bytes + ) -> ClientCallTracerCapsule: """Creates a ClientCallTracerCapsule. After register the plugin, if tracing or stats is enabled, this method @@ -76,7 +79,8 @@ def create_client_call_tracer( @abc.abstractmethod def delete_client_call_tracer( - self, client_call_tracer: ClientCallTracerCapsule) -> None: + self, client_call_tracer: ClientCallTracerCapsule + ) -> None: """Deletes the ClientCallTracer stored in ClientCallTracerCapsule. After register the plugin, if tracing or stats is enabled, this method @@ -91,8 +95,9 @@ def delete_client_call_tracer( raise NotImplementedError() @abc.abstractmethod - def save_trace_context(self, trace_id: str, span_id: str, - is_sampled: bool) -> None: + def save_trace_context( + self, trace_id: str, span_id: str, is_sampled: bool + ) -> None: """Saves the trace_id and span_id related to the current span. After register the plugin, if tracing is enabled, this method will be @@ -112,7 +117,8 @@ def save_trace_context(self, trace_id: str, span_id: str, @abc.abstractmethod def create_server_call_tracer_factory( - self) -> ServerCallTracerFactoryCapsule: + self, + ) -> ServerCallTracerFactoryCapsule: """Creates a ServerCallTracerFactoryCapsule. After register the plugin, if tracing or stats is enabled, this method @@ -129,8 +135,9 @@ def create_server_call_tracer_factory( raise NotImplementedError() @abc.abstractmethod - def record_rpc_latency(self, method: str, rpc_latency: float, - status_code: Any) -> None: + def record_rpc_latency( + self, method: str, rpc_latency: float, status_code: Any + ) -> None: """Record the latency of the RPC. After register the plugin, if stats is enabled, this method will be diff --git a/src/python/grpcio/grpc/_plugin_wrapping.py b/src/python/grpcio/grpc/_plugin_wrapping.py index 942264cdaea23..79900ee1dae35 100644 --- a/src/python/grpcio/grpc/_plugin_wrapping.py +++ b/src/python/grpcio/grpc/_plugin_wrapping.py @@ -26,15 +26,19 @@ class _AuthMetadataContext( - collections.namedtuple('AuthMetadataContext', ( - 'service_url', - 'method_name', - )), grpc.AuthMetadataContext): + collections.namedtuple( + "AuthMetadataContext", + ( + "service_url", + "method_name", + ), + ), + grpc.AuthMetadataContext, +): pass class _CallbackState(object): - def __init__(self): self.lock = threading.Lock() self.called = False @@ -49,24 +53,29 @@ def __init__(self, state: _CallbackState, callback: Callable): self._state = state self._callback = callback - def __call__(self, metadata: MetadataType, - error: Optional[Type[BaseException]]): + def __call__( + self, metadata: MetadataType, error: Optional[Type[BaseException]] + ): with self._state.lock: if self._state.exception is None: if self._state.called: raise RuntimeError( - 'AuthMetadataPluginCallback invoked more than once!') + "AuthMetadataPluginCallback invoked more than once!" 
+ ) else: self._state.called = True else: raise RuntimeError( 'AuthMetadataPluginCallback raised exception "{}"!'.format( - self._state.exception)) + self._state.exception + ) + ) if error is None: self._callback(metadata, cygrpc.StatusCode.ok, None) else: - self._callback(None, cygrpc.StatusCode.internal, - _common.encode(str(error))) + self._callback( + None, cygrpc.StatusCode.internal, _common.encode(str(error)) + ) class _Plugin(object): @@ -88,27 +97,31 @@ def __init__(self, metadata_plugin: grpc.AuthMetadataPlugin): pass def __call__(self, service_url: str, method_name: str, callback: Callable): - context = _AuthMetadataContext(_common.decode(service_url), - _common.decode(method_name)) + context = _AuthMetadataContext( + _common.decode(service_url), _common.decode(method_name) + ) callback_state = _CallbackState() try: self._metadata_plugin( - context, _AuthMetadataPluginCallback(callback_state, callback)) + context, _AuthMetadataPluginCallback(callback_state, callback) + ) except Exception as exception: # pylint: disable=broad-except _LOGGER.exception( 'AuthMetadataPluginCallback "%s" raised exception!', - self._metadata_plugin) + self._metadata_plugin, + ) with callback_state.lock: callback_state.exception = exception if callback_state.called: return - callback(None, cygrpc.StatusCode.internal, - _common.encode(str(exception))) + callback( + None, cygrpc.StatusCode.internal, _common.encode(str(exception)) + ) def metadata_plugin_call_credentials( - metadata_plugin: grpc.AuthMetadataPlugin, - name: Optional[str]) -> grpc.CallCredentials: + metadata_plugin: grpc.AuthMetadataPlugin, name: Optional[str] +) -> grpc.CallCredentials: if name is None: try: effective_name = metadata_plugin.__name__ @@ -117,5 +130,7 @@ def metadata_plugin_call_credentials( else: effective_name = name return grpc.CallCredentials( - cygrpc.MetadataPluginCallCredentials(_Plugin(metadata_plugin), - _common.encode(effective_name))) + cygrpc.MetadataPluginCallCredentials( + _Plugin(metadata_plugin), _common.encode(effective_name) + ) + ) diff --git a/src/python/grpcio/grpc/_runtime_protos.py b/src/python/grpcio/grpc/_runtime_protos.py index fcc37038dacbd..7ff887e685400 100644 --- a/src/python/grpcio/grpc/_runtime_protos.py +++ b/src/python/grpcio/grpc/_runtime_protos.py @@ -19,8 +19,12 @@ _REQUIRED_SYMBOLS = ("_protos", "_services", "_protos_and_services") _MINIMUM_VERSION = (3, 5, 0) -_UNINSTALLED_TEMPLATE = "Install the grpcio-tools package (1.32.0+) to use the {} function." -_VERSION_ERROR_TEMPLATE = "The {} function is only on available on Python 3.X interpreters." +_UNINSTALLED_TEMPLATE = ( + "Install the grpcio-tools package (1.32.0+) to use the {} function." +) +_VERSION_ERROR_TEMPLATE = ( + "The {} function is only on available on Python 3.X interpreters." 
+) def _has_runtime_proto_symbols(mod: types.ModuleType) -> bool: @@ -30,6 +34,7 @@ def _has_runtime_proto_symbols(mod: types.ModuleType) -> bool: def _is_grpc_tools_importable() -> bool: try: import grpc_tools # pylint: disable=unused-import # pytype: disable=import-error + return True except ImportError as e: # NOTE: It's possible that we're encountering a transitive ImportError, so @@ -57,8 +62,9 @@ def _call_with_lazy_import( if not _is_grpc_tools_importable(): raise NotImplementedError(_UNINSTALLED_TEMPLATE.format(fn_name)) import grpc_tools.protoc # pytype: disable=import-error + if _has_runtime_proto_symbols(grpc_tools.protoc): - fn = getattr(grpc_tools.protoc, '_' + fn_name) + fn = getattr(grpc_tools.protoc, "_" + fn_name) return fn(protobuf_path) else: raise NotImplementedError(_UNINSTALLED_TEMPLATE.format(fn_name)) diff --git a/src/python/grpcio/grpc/_server.py b/src/python/grpcio/grpc/_server.py index d6802bfeadc1f..5d05eead76c36 100644 --- a/src/python/grpcio/grpc/_server.py +++ b/src/python/grpcio/grpc/_server.py @@ -22,8 +22,19 @@ import threading import time import traceback -from typing import (Any, Callable, Iterable, Iterator, List, Mapping, Optional, - Sequence, Set, Tuple, Union) +from typing import ( + Any, + Callable, + Iterable, + Iterator, + List, + Mapping, + Optional, + Sequence, + Set, + Tuple, + Union, +) import grpc # pytype: disable=pyi-error from grpc import _common # pytype: disable=pyi-error @@ -42,22 +53,24 @@ _LOGGER = logging.getLogger(__name__) -_SHUTDOWN_TAG = 'shutdown' -_REQUEST_CALL_TAG = 'request_call' +_SHUTDOWN_TAG = "shutdown" +_REQUEST_CALL_TAG = "request_call" -_RECEIVE_CLOSE_ON_SERVER_TOKEN = 'receive_close_on_server' -_SEND_INITIAL_METADATA_TOKEN = 'send_initial_metadata' -_RECEIVE_MESSAGE_TOKEN = 'receive_message' -_SEND_MESSAGE_TOKEN = 'send_message' +_RECEIVE_CLOSE_ON_SERVER_TOKEN = "receive_close_on_server" +_SEND_INITIAL_METADATA_TOKEN = "send_initial_metadata" +_RECEIVE_MESSAGE_TOKEN = "receive_message" +_SEND_MESSAGE_TOKEN = "send_message" _SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN = ( - 'send_initial_metadata * send_message') -_SEND_STATUS_FROM_SERVER_TOKEN = 'send_status_from_server' + "send_initial_metadata * send_message" +) +_SEND_STATUS_FROM_SERVER_TOKEN = "send_status_from_server" _SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN = ( - 'send_initial_metadata * send_status_from_server') + "send_initial_metadata * send_status_from_server" +) -_OPEN = 'open' -_CLOSED = 'closed' -_CANCELLED = 'cancelled' +_OPEN = "open" +_CLOSED = "closed" +_CANCELLED = "cancelled" _EMPTY_FLAGS = 0 @@ -81,8 +94,9 @@ def _completion_code(state: _RPCState) -> cygrpc.StatusCode: return _application_code(state.code) -def _abortion_code(state: _RPCState, - code: cygrpc.StatusCode) -> cygrpc.StatusCode: +def _abortion_code( + state: _RPCState, code: cygrpc.StatusCode +) -> cygrpc.StatusCode: if state.code is None: return code else: @@ -90,14 +104,19 @@ def _abortion_code(state: _RPCState, def _details(state: _RPCState) -> bytes: - return b'' if state.details is None else state.details + return b"" if state.details is None else state.details class _HandlerCallDetails( - collections.namedtuple('_HandlerCallDetails', ( - 'method', - 'invocation_metadata', - )), grpc.HandlerCallDetails): + collections.namedtuple( + "_HandlerCallDetails", + ( + "method", + "invocation_metadata", + ), + ), + grpc.HandlerCallDetails, +): pass @@ -140,8 +159,9 @@ def _raise_rpc_error(state: _RPCState) -> None: raise rpc_error -def _possibly_finish_call(state: 
_RPCState, - token: str) -> ServerTagCallbackType: +def _possibly_finish_call( + state: _RPCState, token: str +) -> ServerTagCallbackType: state.due.remove(token) if not _is_rpc_state_active(state) and not state.due: callbacks = state.callbacks @@ -152,7 +172,6 @@ def _possibly_finish_call(state: _RPCState, def _send_status_from_server(state: _RPCState, token: str) -> ServerCallbackTag: - def send_status_from_server(unused_send_status_from_server_event): with state.condition: return _possibly_finish_call(state, token) @@ -161,13 +180,15 @@ def send_status_from_server(unused_send_status_from_server_event): def _get_initial_metadata( - state: _RPCState, - metadata: Optional[MetadataType]) -> Optional[MetadataType]: + state: _RPCState, metadata: Optional[MetadataType] +) -> Optional[MetadataType]: with state.condition: if state.compression_algorithm: compression_metadata = ( _compression.compression_algorithm_to_metadata( - state.compression_algorithm),) + state.compression_algorithm + ), + ) if metadata is None: return compression_metadata else: @@ -177,39 +198,49 @@ def _get_initial_metadata( def _get_initial_metadata_operation( - state: _RPCState, metadata: Optional[MetadataType]) -> cygrpc.Operation: + state: _RPCState, metadata: Optional[MetadataType] +) -> cygrpc.Operation: operation = cygrpc.SendInitialMetadataOperation( - _get_initial_metadata(state, metadata), _EMPTY_FLAGS) + _get_initial_metadata(state, metadata), _EMPTY_FLAGS + ) return operation -def _abort(state: _RPCState, call: cygrpc.Call, code: cygrpc.StatusCode, - details: bytes) -> None: +def _abort( + state: _RPCState, call: cygrpc.Call, code: cygrpc.StatusCode, details: bytes +) -> None: if state.client is not _CANCELLED: effective_code = _abortion_code(state, code) effective_details = details if state.details is None else state.details if state.initial_metadata_allowed: operations = ( _get_initial_metadata_operation(state, None), - cygrpc.SendStatusFromServerOperation(state.trailing_metadata, - effective_code, - effective_details, - _EMPTY_FLAGS), + cygrpc.SendStatusFromServerOperation( + state.trailing_metadata, + effective_code, + effective_details, + _EMPTY_FLAGS, + ), ) token = _SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN else: - operations = (cygrpc.SendStatusFromServerOperation( - state.trailing_metadata, effective_code, effective_details, - _EMPTY_FLAGS),) + operations = ( + cygrpc.SendStatusFromServerOperation( + state.trailing_metadata, + effective_code, + effective_details, + _EMPTY_FLAGS, + ), + ) token = _SEND_STATUS_FROM_SERVER_TOKEN - call.start_server_batch(operations, - _send_status_from_server(state, token)) + call.start_server_batch( + operations, _send_status_from_server(state, token) + ) state.statused = True state.due.add(token) def _receive_close_on_server(state: _RPCState) -> ServerCallbackTag: - def receive_close_on_server(receive_close_on_server_event): with state.condition: if receive_close_on_server_event.batch_operations[0].cancelled(): @@ -223,10 +254,10 @@ def receive_close_on_server(receive_close_on_server_event): def _receive_message( - state: _RPCState, call: cygrpc.Call, - request_deserializer: Optional[DeserializingFunction] + state: _RPCState, + call: cygrpc.Call, + request_deserializer: Optional[DeserializingFunction], ) -> ServerCallbackTag: - def receive_message(receive_message_event): serialized_request = _serialized_request(receive_message_event) if serialized_request is None: @@ -236,12 +267,17 @@ def receive_message(receive_message_event): 
state.condition.notify_all() return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN) else: - request = _common.deserialize(serialized_request, - request_deserializer) + request = _common.deserialize( + serialized_request, request_deserializer + ) with state.condition: if request is None: - _abort(state, call, cygrpc.StatusCode.internal, - b'Exception deserializing request!') + _abort( + state, + call, + cygrpc.StatusCode.internal, + b"Exception deserializing request!", + ) else: state.request = request state.condition.notify_all() @@ -251,7 +287,6 @@ def receive_message(receive_message_event): def _send_initial_metadata(state: _RPCState) -> ServerCallbackTag: - def send_initial_metadata(unused_send_initial_metadata_event): with state.condition: return _possibly_finish_call(state, _SEND_INITIAL_METADATA_TOKEN) @@ -260,7 +295,6 @@ def send_initial_metadata(unused_send_initial_metadata_event): def _send_message(state: _RPCState, token: str) -> ServerCallbackTag: - def send_message(unused_send_message_event): with state.condition: state.condition.notify_all() @@ -274,8 +308,12 @@ class _Context(grpc.ServicerContext): _state: _RPCState request_deserializer: Optional[DeserializingFunction] - def __init__(self, rpc_event: cygrpc.BaseEvent, state: _RPCState, - request_deserializer: Optional[DeserializingFunction]): + def __init__( + self, + rpc_event: cygrpc.BaseEvent, + state: _RPCState, + request_deserializer: Optional[DeserializingFunction], + ): self._rpc_event = rpc_event self._state = state self._request_deserializer = request_deserializer @@ -334,13 +372,15 @@ def send_initial_metadata(self, initial_metadata: MetadataType) -> None: else: if self._state.initial_metadata_allowed: operation = _get_initial_metadata_operation( - self._state, initial_metadata) + self._state, initial_metadata + ) self._rpc_event.call.start_server_batch( - (operation,), _send_initial_metadata(self._state)) + (operation,), _send_initial_metadata(self._state) + ) self._state.initial_metadata_allowed = False self._state.due.add(_SEND_INITIAL_METADATA_TOKEN) else: - raise ValueError('Initial metadata no longer allowed!') + raise ValueError("Initial metadata no longer allowed!") def set_trailing_metadata(self, trailing_metadata: MetadataType) -> None: with self._state.condition: @@ -353,9 +393,10 @@ def abort(self, code: grpc.StatusCode, details: str) -> None: # treat OK like other invalid arguments: fail the RPC if code == grpc.StatusCode.OK: _LOGGER.error( - 'abort() called with StatusCode.OK; returning UNKNOWN') + "abort() called with StatusCode.OK; returning UNKNOWN" + ) code = grpc.StatusCode.UNKNOWN - details = '' + details = "" with self._state.condition: self._state.code = code self._state.details = _common.encode(details) @@ -389,8 +430,12 @@ class _RequestIterator(object): _call: cygrpc.Call _request_deserializer: Optional[DeserializingFunction] - def __init__(self, state: _RPCState, call: cygrpc.Call, - request_deserializer: Optional[DeserializingFunction]): + def __init__( + self, + state: _RPCState, + call: cygrpc.Call, + request_deserializer: Optional[DeserializingFunction], + ): self._state = state self._call = call self._request_deserializer = request_deserializer @@ -403,15 +448,19 @@ def _raise_or_start_receive_message(self) -> None: else: self._call.start_server_batch( (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),), - _receive_message(self._state, self._call, - self._request_deserializer)) + _receive_message( + self._state, self._call, self._request_deserializer + ), + ) 
self._state.due.add(_RECEIVE_MESSAGE_TOKEN) def _look_for_request(self) -> Any: if self._state.client is _CANCELLED: _raise_rpc_error(self._state) - elif (self._state.request is None and - _RECEIVE_MESSAGE_TOKEN not in self._state.due): + elif ( + self._state.request is None + and _RECEIVE_MESSAGE_TOKEN not in self._state.due + ): raise StopIteration() else: request = self._state.request @@ -440,10 +489,10 @@ def next(self) -> Any: def _unary_request( - rpc_event: cygrpc.BaseEvent, state: _RPCState, - request_deserializer: Optional[DeserializingFunction] + rpc_event: cygrpc.BaseEvent, + state: _RPCState, + request_deserializer: Optional[DeserializingFunction], ) -> Callable[[], Any]: - def unary_request(): with state.condition: if not _is_rpc_state_active(state): @@ -451,18 +500,24 @@ def unary_request(): else: rpc_event.call.start_server_batch( (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),), - _receive_message(state, rpc_event.call, - request_deserializer)) + _receive_message( + state, rpc_event.call, request_deserializer + ), + ) state.due.add(_RECEIVE_MESSAGE_TOKEN) while True: state.condition.wait() if state.request is None: if state.client is _CLOSED: details = '"{}" requires exactly one request message.'.format( - rpc_event.call_details.method) - _abort(state, rpc_event.call, - cygrpc.StatusCode.unimplemented, - _common.encode(details)) + rpc_event.call_details.method + ) + _abort( + state, + rpc_event.call, + cygrpc.StatusCode.unimplemented, + _common.encode(details), + ) return None elif state.client is _CANCELLED: return None @@ -480,40 +535,56 @@ def _call_behavior( behavior: ArityAgnosticMethodHandler, argument: Any, request_deserializer: Optional[DeserializingFunction], - send_response_callback: Optional[Callable[[ResponseType], None]] = None + send_response_callback: Optional[Callable[[ResponseType], None]] = None, ) -> Tuple[Union[ResponseType, Iterator[ResponseType]], bool]: from grpc import _create_servicer_context # pytype: disable=pyi-error - with _create_servicer_context(rpc_event, state, - request_deserializer) as context: + + with _create_servicer_context( + rpc_event, state, request_deserializer + ) as context: try: response_or_iterator = None if send_response_callback is not None: - response_or_iterator = behavior(argument, context, - send_response_callback) + response_or_iterator = behavior( + argument, context, send_response_callback + ) else: response_or_iterator = behavior(argument, context) return response_or_iterator, True except Exception as exception: # pylint: disable=broad-except with state.condition: if state.aborted: - _abort(state, rpc_event.call, cygrpc.StatusCode.unknown, - b'RPC Aborted') + _abort( + state, + rpc_event.call, + cygrpc.StatusCode.unknown, + b"RPC Aborted", + ) elif exception not in state.rpc_errors: try: - details = 'Exception calling application: {}'.format( - exception) + details = "Exception calling application: {}".format( + exception + ) except Exception: # pylint: disable=broad-except - details = 'Calling application raised unprintable Exception!' + details = ( + "Calling application raised unprintable Exception!" 
+ ) traceback.print_exc() _LOGGER.exception(details) - _abort(state, rpc_event.call, cygrpc.StatusCode.unknown, - _common.encode(details)) + _abort( + state, + rpc_event.call, + cygrpc.StatusCode.unknown, + _common.encode(details), + ) return None, False def _take_response_from_response_iterator( - rpc_event: cygrpc.BaseEvent, state: _RPCState, - response_iterator: Iterator[ResponseType]) -> Tuple[ResponseType, bool]: + rpc_event: cygrpc.BaseEvent, + state: _RPCState, + response_iterator: Iterator[ResponseType], +) -> Tuple[ResponseType, bool]: try: return next(response_iterator), True except StopIteration: @@ -521,31 +592,47 @@ def _take_response_from_response_iterator( except Exception as exception: # pylint: disable=broad-except with state.condition: if state.aborted: - _abort(state, rpc_event.call, cygrpc.StatusCode.unknown, - b'RPC Aborted') + _abort( + state, + rpc_event.call, + cygrpc.StatusCode.unknown, + b"RPC Aborted", + ) elif exception not in state.rpc_errors: - details = 'Exception iterating responses: {}'.format(exception) + details = "Exception iterating responses: {}".format(exception) _LOGGER.exception(details) - _abort(state, rpc_event.call, cygrpc.StatusCode.unknown, - _common.encode(details)) + _abort( + state, + rpc_event.call, + cygrpc.StatusCode.unknown, + _common.encode(details), + ) return None, False def _serialize_response( - rpc_event: cygrpc.BaseEvent, state: _RPCState, response: Any, - response_serializer: Optional[SerializingFunction]) -> Optional[bytes]: + rpc_event: cygrpc.BaseEvent, + state: _RPCState, + response: Any, + response_serializer: Optional[SerializingFunction], +) -> Optional[bytes]: serialized_response = _common.serialize(response, response_serializer) if serialized_response is None: with state.condition: - _abort(state, rpc_event.call, cygrpc.StatusCode.internal, - b'Failed to serialize response!') + _abort( + state, + rpc_event.call, + cygrpc.StatusCode.internal, + b"Failed to serialize response!", + ) return None else: return serialized_response def _get_send_message_op_flags_from_state( - state: _RPCState) -> Union[int, cygrpc.WriteFlag]: + state: _RPCState, +) -> Union[int, cygrpc.WriteFlag]: if state.disable_next_compression: return cygrpc.WriteFlag.no_compress else: @@ -557,8 +644,9 @@ def _reset_per_message_state(state: _RPCState) -> None: state.disable_next_compression = False -def _send_response(rpc_event: cygrpc.BaseEvent, state: _RPCState, - serialized_response: bytes) -> bool: +def _send_response( + rpc_event: cygrpc.BaseEvent, state: _RPCState, serialized_response: bytes +) -> bool: with state.condition: if not _is_rpc_state_active(state): return False @@ -568,17 +656,22 @@ def _send_response(rpc_event: cygrpc.BaseEvent, state: _RPCState, _get_initial_metadata_operation(state, None), cygrpc.SendMessageOperation( serialized_response, - _get_send_message_op_flags_from_state(state)), + _get_send_message_op_flags_from_state(state), + ), ) state.initial_metadata_allowed = False token = _SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN else: - operations = (cygrpc.SendMessageOperation( - serialized_response, - _get_send_message_op_flags_from_state(state)),) + operations = ( + cygrpc.SendMessageOperation( + serialized_response, + _get_send_message_op_flags_from_state(state), + ), + ) token = _SEND_MESSAGE_TOKEN - rpc_event.call.start_server_batch(operations, - _send_message(state, token)) + rpc_event.call.start_server_batch( + operations, _send_message(state, token) + ) state.due.add(token) _reset_per_message_state(state) while True: @@ 
-587,16 +680,19 @@ def _send_response(rpc_event: cygrpc.BaseEvent, state: _RPCState, return _is_rpc_state_active(state) -def _status(rpc_event: cygrpc.BaseEvent, state: _RPCState, - serialized_response: Optional[bytes]) -> None: +def _status( + rpc_event: cygrpc.BaseEvent, + state: _RPCState, + serialized_response: Optional[bytes], +) -> None: with state.condition: if state.client is not _CANCELLED: code = _completion_code(state) details = _details(state) operations = [ - cygrpc.SendStatusFromServerOperation(state.trailing_metadata, - code, details, - _EMPTY_FLAGS), + cygrpc.SendStatusFromServerOperation( + state.trailing_metadata, code, details, _EMPTY_FLAGS + ), ] if state.initial_metadata_allowed: operations.append(_get_initial_metadata_operation(state, None)) @@ -604,30 +700,38 @@ def _status(rpc_event: cygrpc.BaseEvent, state: _RPCState, operations.append( cygrpc.SendMessageOperation( serialized_response, - _get_send_message_op_flags_from_state(state))) + _get_send_message_op_flags_from_state(state), + ) + ) rpc_event.call.start_server_batch( operations, - _send_status_from_server(state, _SEND_STATUS_FROM_SERVER_TOKEN)) + _send_status_from_server(state, _SEND_STATUS_FROM_SERVER_TOKEN), + ) state.statused = True _reset_per_message_state(state) state.due.add(_SEND_STATUS_FROM_SERVER_TOKEN) def _unary_response_in_pool( - rpc_event: cygrpc.BaseEvent, state: _RPCState, - behavior: ArityAgnosticMethodHandler, argument_thunk: Callable[[], Any], - request_deserializer: Optional[SerializingFunction], - response_serializer: Optional[SerializingFunction]) -> None: + rpc_event: cygrpc.BaseEvent, + state: _RPCState, + behavior: ArityAgnosticMethodHandler, + argument_thunk: Callable[[], Any], + request_deserializer: Optional[SerializingFunction], + response_serializer: Optional[SerializingFunction], +) -> None: cygrpc.install_context_from_request_call_event(rpc_event) try: argument = argument_thunk() if argument is not None: - response, proceed = _call_behavior(rpc_event, state, behavior, - argument, request_deserializer) + response, proceed = _call_behavior( + rpc_event, state, behavior, argument, request_deserializer + ) if proceed: serialized_response = _serialize_response( - rpc_event, state, response, response_serializer) + rpc_event, state, response, response_serializer + ) if serialized_response is not None: _status(rpc_event, state, serialized_response) except Exception: # pylint: disable=broad-except @@ -637,39 +741,48 @@ def _unary_response_in_pool( def _stream_response_in_pool( - rpc_event: cygrpc.BaseEvent, state: _RPCState, - behavior: ArityAgnosticMethodHandler, argument_thunk: Callable[[], Any], - request_deserializer: Optional[DeserializingFunction], - response_serializer: Optional[SerializingFunction]) -> None: + rpc_event: cygrpc.BaseEvent, + state: _RPCState, + behavior: ArityAgnosticMethodHandler, + argument_thunk: Callable[[], Any], + request_deserializer: Optional[DeserializingFunction], + response_serializer: Optional[SerializingFunction], +) -> None: cygrpc.install_context_from_request_call_event(rpc_event) def send_response(response: Any) -> None: if response is None: _status(rpc_event, state, None) else: - serialized_response = _serialize_response(rpc_event, state, - response, - response_serializer) + serialized_response = _serialize_response( + rpc_event, state, response, response_serializer + ) if serialized_response is not None: _send_response(rpc_event, state, serialized_response) try: argument = argument_thunk() if argument is not None: - if hasattr(behavior, 
'experimental_non_blocking' - ) and behavior.experimental_non_blocking: - _call_behavior(rpc_event, - state, - behavior, - argument, - request_deserializer, - send_response_callback=send_response) + if ( + hasattr(behavior, "experimental_non_blocking") + and behavior.experimental_non_blocking + ): + _call_behavior( + rpc_event, + state, + behavior, + argument, + request_deserializer, + send_response_callback=send_response, + ) else: response_iterator, proceed = _call_behavior( - rpc_event, state, behavior, argument, request_deserializer) + rpc_event, state, behavior, argument, request_deserializer + ) if proceed: _send_message_callback_to_blocking_iterator_adapter( - rpc_event, state, send_response, response_iterator) + rpc_event, state, send_response, response_iterator + ) except Exception: # pylint: disable=broad-except traceback.print_exc() finally: @@ -681,12 +794,15 @@ def _is_rpc_state_active(state: _RPCState) -> bool: def _send_message_callback_to_blocking_iterator_adapter( - rpc_event: cygrpc.BaseEvent, state: _RPCState, - send_response_callback: Callable[[ResponseType], None], - response_iterator: Iterator[ResponseType]) -> None: + rpc_event: cygrpc.BaseEvent, + state: _RPCState, + send_response_callback: Callable[[ResponseType], None], + response_iterator: Iterator[ResponseType], +) -> None: while True: response, proceed = _take_response_from_response_iterator( - rpc_event, state, response_iterator) + rpc_event, state, response_iterator + ) if proceed: send_response_callback(response) if not _is_rpc_state_active(state): @@ -697,80 +813,115 @@ def _send_message_callback_to_blocking_iterator_adapter( def _select_thread_pool_for_behavior( behavior: ArityAgnosticMethodHandler, - default_thread_pool: futures.ThreadPoolExecutor + default_thread_pool: futures.ThreadPoolExecutor, ) -> futures.ThreadPoolExecutor: - if hasattr(behavior, 'experimental_thread_pool') and isinstance( - behavior.experimental_thread_pool, futures.ThreadPoolExecutor): + if hasattr(behavior, "experimental_thread_pool") and isinstance( + behavior.experimental_thread_pool, futures.ThreadPoolExecutor + ): return behavior.experimental_thread_pool else: return default_thread_pool def _handle_unary_unary( - rpc_event: cygrpc.BaseEvent, state: _RPCState, - method_handler: grpc.RpcMethodHandler, - default_thread_pool: futures.ThreadPoolExecutor) -> futures.Future: - unary_request = _unary_request(rpc_event, state, - method_handler.request_deserializer) - thread_pool = _select_thread_pool_for_behavior(method_handler.unary_unary, - default_thread_pool) - return thread_pool.submit(_unary_response_in_pool, rpc_event, state, - method_handler.unary_unary, unary_request, - method_handler.request_deserializer, - method_handler.response_serializer) + rpc_event: cygrpc.BaseEvent, + state: _RPCState, + method_handler: grpc.RpcMethodHandler, + default_thread_pool: futures.ThreadPoolExecutor, +) -> futures.Future: + unary_request = _unary_request( + rpc_event, state, method_handler.request_deserializer + ) + thread_pool = _select_thread_pool_for_behavior( + method_handler.unary_unary, default_thread_pool + ) + return thread_pool.submit( + _unary_response_in_pool, + rpc_event, + state, + method_handler.unary_unary, + unary_request, + method_handler.request_deserializer, + method_handler.response_serializer, + ) def _handle_unary_stream( - rpc_event: cygrpc.BaseEvent, state: _RPCState, - method_handler: grpc.RpcMethodHandler, - default_thread_pool: futures.ThreadPoolExecutor) -> futures.Future: - unary_request = 
_unary_request(rpc_event, state, - method_handler.request_deserializer) - thread_pool = _select_thread_pool_for_behavior(method_handler.unary_stream, - default_thread_pool) - return thread_pool.submit(_stream_response_in_pool, rpc_event, state, - method_handler.unary_stream, unary_request, - method_handler.request_deserializer, - method_handler.response_serializer) + rpc_event: cygrpc.BaseEvent, + state: _RPCState, + method_handler: grpc.RpcMethodHandler, + default_thread_pool: futures.ThreadPoolExecutor, +) -> futures.Future: + unary_request = _unary_request( + rpc_event, state, method_handler.request_deserializer + ) + thread_pool = _select_thread_pool_for_behavior( + method_handler.unary_stream, default_thread_pool + ) + return thread_pool.submit( + _stream_response_in_pool, + rpc_event, + state, + method_handler.unary_stream, + unary_request, + method_handler.request_deserializer, + method_handler.response_serializer, + ) def _handle_stream_unary( - rpc_event: cygrpc.BaseEvent, state: _RPCState, - method_handler: grpc.RpcMethodHandler, - default_thread_pool: futures.ThreadPoolExecutor) -> futures.Future: - request_iterator = _RequestIterator(state, rpc_event.call, - method_handler.request_deserializer) - thread_pool = _select_thread_pool_for_behavior(method_handler.stream_unary, - default_thread_pool) - return thread_pool.submit(_unary_response_in_pool, rpc_event, state, - method_handler.stream_unary, - lambda: request_iterator, - method_handler.request_deserializer, - method_handler.response_serializer) + rpc_event: cygrpc.BaseEvent, + state: _RPCState, + method_handler: grpc.RpcMethodHandler, + default_thread_pool: futures.ThreadPoolExecutor, +) -> futures.Future: + request_iterator = _RequestIterator( + state, rpc_event.call, method_handler.request_deserializer + ) + thread_pool = _select_thread_pool_for_behavior( + method_handler.stream_unary, default_thread_pool + ) + return thread_pool.submit( + _unary_response_in_pool, + rpc_event, + state, + method_handler.stream_unary, + lambda: request_iterator, + method_handler.request_deserializer, + method_handler.response_serializer, + ) def _handle_stream_stream( - rpc_event: cygrpc.BaseEvent, state: _RPCState, - method_handler: grpc.RpcMethodHandler, - default_thread_pool: futures.ThreadPoolExecutor) -> futures.Future: - request_iterator = _RequestIterator(state, rpc_event.call, - method_handler.request_deserializer) - thread_pool = _select_thread_pool_for_behavior(method_handler.stream_stream, - default_thread_pool) - return thread_pool.submit(_stream_response_in_pool, rpc_event, state, - method_handler.stream_stream, - lambda: request_iterator, - method_handler.request_deserializer, - method_handler.response_serializer) + rpc_event: cygrpc.BaseEvent, + state: _RPCState, + method_handler: grpc.RpcMethodHandler, + default_thread_pool: futures.ThreadPoolExecutor, +) -> futures.Future: + request_iterator = _RequestIterator( + state, rpc_event.call, method_handler.request_deserializer + ) + thread_pool = _select_thread_pool_for_behavior( + method_handler.stream_stream, default_thread_pool + ) + return thread_pool.submit( + _stream_response_in_pool, + rpc_event, + state, + method_handler.stream_stream, + lambda: request_iterator, + method_handler.request_deserializer, + method_handler.response_serializer, + ) def _find_method_handler( - rpc_event: cygrpc.BaseEvent, generic_handlers: List[grpc.GenericRpcHandler], - interceptor_pipeline: Optional[_interceptor._ServicePipeline] + rpc_event: cygrpc.BaseEvent, + generic_handlers: 
List[grpc.GenericRpcHandler], + interceptor_pipeline: Optional[_interceptor._ServicePipeline], ) -> Optional[grpc.RpcMethodHandler]: - def query_handlers( - handler_call_details: _HandlerCallDetails + handler_call_details: _HandlerCallDetails, ) -> Optional[grpc.RpcMethodHandler]: for generic_handler in generic_handlers: method_handler = generic_handler.service(handler_call_details) @@ -780,91 +931,126 @@ def query_handlers( handler_call_details = _HandlerCallDetails( _common.decode(rpc_event.call_details.method), - rpc_event.invocation_metadata) + rpc_event.invocation_metadata, + ) if interceptor_pipeline is not None: - return interceptor_pipeline.execute(query_handlers, - handler_call_details) + return interceptor_pipeline.execute( + query_handlers, handler_call_details + ) else: return query_handlers(handler_call_details) -def _reject_rpc(rpc_event: cygrpc.BaseEvent, status: cygrpc.StatusCode, - details: bytes) -> _RPCState: +def _reject_rpc( + rpc_event: cygrpc.BaseEvent, status: cygrpc.StatusCode, details: bytes +) -> _RPCState: rpc_state = _RPCState() operations = ( _get_initial_metadata_operation(rpc_state, None), cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS), - cygrpc.SendStatusFromServerOperation(None, status, details, - _EMPTY_FLAGS), + cygrpc.SendStatusFromServerOperation( + None, status, details, _EMPTY_FLAGS + ), + ) + rpc_event.call.start_server_batch( + operations, + lambda ignored_event: ( + rpc_state, + (), + ), ) - rpc_event.call.start_server_batch(operations, lambda ignored_event: ( - rpc_state, - (), - )) return rpc_state def _handle_with_method_handler( - rpc_event: cygrpc.BaseEvent, method_handler: grpc.RpcMethodHandler, - thread_pool: futures.ThreadPoolExecutor + rpc_event: cygrpc.BaseEvent, + method_handler: grpc.RpcMethodHandler, + thread_pool: futures.ThreadPoolExecutor, ) -> Tuple[_RPCState, futures.Future]: state = _RPCState() with state.condition: rpc_event.call.start_server_batch( (cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),), - _receive_close_on_server(state)) + _receive_close_on_server(state), + ) state.due.add(_RECEIVE_CLOSE_ON_SERVER_TOKEN) if method_handler.request_streaming: if method_handler.response_streaming: - return state, _handle_stream_stream(rpc_event, state, - method_handler, thread_pool) + return state, _handle_stream_stream( + rpc_event, state, method_handler, thread_pool + ) else: - return state, _handle_stream_unary(rpc_event, state, - method_handler, thread_pool) + return state, _handle_stream_unary( + rpc_event, state, method_handler, thread_pool + ) else: if method_handler.response_streaming: - return state, _handle_unary_stream(rpc_event, state, - method_handler, thread_pool) + return state, _handle_unary_stream( + rpc_event, state, method_handler, thread_pool + ) else: - return state, _handle_unary_unary(rpc_event, state, - method_handler, thread_pool) + return state, _handle_unary_unary( + rpc_event, state, method_handler, thread_pool + ) def _handle_call( - rpc_event: cygrpc.BaseEvent, generic_handlers: List[grpc.GenericRpcHandler], + rpc_event: cygrpc.BaseEvent, + generic_handlers: List[grpc.GenericRpcHandler], interceptor_pipeline: Optional[_interceptor._ServicePipeline], - thread_pool: futures.ThreadPoolExecutor, concurrency_exceeded: bool + thread_pool: futures.ThreadPoolExecutor, + concurrency_exceeded: bool, ) -> Tuple[Optional[_RPCState], Optional[futures.Future]]: if not rpc_event.success: return None, None if rpc_event.call_details.method is not None: try: - method_handler = _find_method_handler(rpc_event, 
generic_handlers, - interceptor_pipeline) + method_handler = _find_method_handler( + rpc_event, generic_handlers, interceptor_pipeline + ) except Exception as exception: # pylint: disable=broad-except - details = 'Exception servicing handler: {}'.format(exception) + details = "Exception servicing handler: {}".format(exception) _LOGGER.exception(details) - return _reject_rpc(rpc_event, cygrpc.StatusCode.unknown, - b'Error in service handler!'), None + return ( + _reject_rpc( + rpc_event, + cygrpc.StatusCode.unknown, + b"Error in service handler!", + ), + None, + ) if method_handler is None: - return _reject_rpc(rpc_event, cygrpc.StatusCode.unimplemented, - b'Method not found!'), None + return ( + _reject_rpc( + rpc_event, + cygrpc.StatusCode.unimplemented, + b"Method not found!", + ), + None, + ) elif concurrency_exceeded: - return _reject_rpc(rpc_event, cygrpc.StatusCode.resource_exhausted, - b'Concurrent RPC limit exceeded!'), None + return ( + _reject_rpc( + rpc_event, + cygrpc.StatusCode.resource_exhausted, + b"Concurrent RPC limit exceeded!", + ), + None, + ) else: - return _handle_with_method_handler(rpc_event, method_handler, - thread_pool) + return _handle_with_method_handler( + rpc_event, method_handler, thread_pool + ) else: return None, None @enum.unique class _ServerStage(enum.Enum): - STOPPED = 'stopped' - STARTED = 'started' - GRACE = 'grace' + STOPPED = "stopped" + STARTED = "started" + GRACE = "grace" class _ServerState(object): @@ -884,12 +1070,15 @@ class _ServerState(object): server_deallocated: bool # pylint: disable=too-many-arguments - def __init__(self, completion_queue: cygrpc.CompletionQueue, - server: cygrpc.Server, - generic_handlers: Sequence[grpc.GenericRpcHandler], - interceptor_pipeline: Optional[_interceptor._ServicePipeline], - thread_pool: futures.ThreadPoolExecutor, - maximum_concurrent_rpcs: Optional[int]): + def __init__( + self, + completion_queue: cygrpc.CompletionQueue, + server: cygrpc.Server, + generic_handlers: Sequence[grpc.GenericRpcHandler], + interceptor_pipeline: Optional[_interceptor._ServicePipeline], + thread_pool: futures.ThreadPoolExecutor, + maximum_concurrent_rpcs: Optional[int], + ): self.lock = threading.RLock() self.completion_queue = completion_queue self.server = server @@ -911,8 +1100,8 @@ def __init__(self, completion_queue: cygrpc.CompletionQueue, def _add_generic_handlers( - state: _ServerState, - generic_handlers: Iterable[grpc.GenericRpcHandler]) -> None: + state: _ServerState, generic_handlers: Iterable[grpc.GenericRpcHandler] +) -> None: with state.lock: state.generic_handlers.extend(generic_handlers) @@ -922,16 +1111,21 @@ def _add_insecure_port(state: _ServerState, address: bytes) -> int: return state.server.add_http2_port(address) -def _add_secure_port(state: _ServerState, address: bytes, - server_credentials: grpc.ServerCredentials) -> int: +def _add_secure_port( + state: _ServerState, + address: bytes, + server_credentials: grpc.ServerCredentials, +) -> int: with state.lock: - return state.server.add_http2_port(address, - server_credentials._credentials) + return state.server.add_http2_port( + address, server_credentials._credentials + ) def _request_call(state: _ServerState) -> None: - state.server.request_call(state.completion_queue, state.completion_queue, - _REQUEST_CALL_TAG) + state.server.request_call( + state.completion_queue, state.completion_queue, _REQUEST_CALL_TAG + ) state.due.add(_REQUEST_CALL_TAG) @@ -952,8 +1146,9 @@ def _on_call_completed(state: _ServerState) -> None: state.active_rpc_count -= 1 -def 
_process_event_and_continue(state: _ServerState, - event: cygrpc.BaseEvent) -> bool: +def _process_event_and_continue( + state: _ServerState, event: cygrpc.BaseEvent +) -> bool: should_continue = True if event.tag is _SHUTDOWN_TAG: with state.lock: @@ -964,18 +1159,23 @@ def _process_event_and_continue(state: _ServerState, with state.lock: state.due.remove(_REQUEST_CALL_TAG) concurrency_exceeded = ( - state.maximum_concurrent_rpcs is not None and - state.active_rpc_count >= state.maximum_concurrent_rpcs) - rpc_state, rpc_future = _handle_call(event, state.generic_handlers, - state.interceptor_pipeline, - state.thread_pool, - concurrency_exceeded) + state.maximum_concurrent_rpcs is not None + and state.active_rpc_count >= state.maximum_concurrent_rpcs + ) + rpc_state, rpc_future = _handle_call( + event, + state.generic_handlers, + state.interceptor_pipeline, + state.thread_pool, + concurrency_exceeded, + ) if rpc_state is not None: state.rpc_states.add(rpc_state) if rpc_future is not None: state.active_rpc_count += 1 rpc_future.add_done_callback( - lambda unused_future: _on_call_completed(state)) + lambda unused_future: _on_call_completed(state) + ) if state.stage is _ServerStage.STARTED: _request_call(state) elif _stop_serving(state): @@ -986,7 +1186,7 @@ def _process_event_and_continue(state: _ServerState, try: callback() except Exception: # pylint: disable=broad-except - _LOGGER.exception('Exception calling callback!') + _LOGGER.exception("Exception calling callback!") if rpc_state is not None: with state.lock: state.rpc_states.remove(rpc_state) @@ -1047,7 +1247,7 @@ def cancel_all_calls_after_grace(): def _start(state: _ServerState) -> None: with state.lock: if state.stage is not _ServerStage.STOPPED: - raise ValueError('Cannot start already-started server!') + raise ValueError("Cannot start already-started server!") state.server.start() state.stage = _ServerStage.STARTED _request_call(state) @@ -1058,18 +1258,20 @@ def _start(state: _ServerState) -> None: def _validate_generic_rpc_handlers( - generic_rpc_handlers: Iterable[grpc.GenericRpcHandler]) -> None: + generic_rpc_handlers: Iterable[grpc.GenericRpcHandler], +) -> None: for generic_rpc_handler in generic_rpc_handlers: - service_attribute = getattr(generic_rpc_handler, 'service', None) + service_attribute = getattr(generic_rpc_handler, "service", None) if service_attribute is None: raise AttributeError( '"{}" must conform to grpc.GenericRpcHandler type but does ' - 'not have "service" method!'.format(generic_rpc_handler)) + 'not have "service" method!'.format(generic_rpc_handler) + ) def _augment_options( - base_options: Sequence[ChannelArgumentType], - compression: Optional[grpc.Compression] + base_options: Sequence[ChannelArgumentType], + compression: Optional[grpc.Compression], ) -> Sequence[ChannelArgumentType]: compression_option = _compression.create_channel_option(compression) return tuple(base_options) + compression_option @@ -1079,35 +1281,48 @@ class _Server(grpc.Server): _state: _ServerState # pylint: disable=too-many-arguments - def __init__(self, thread_pool: futures.ThreadPoolExecutor, - generic_handlers: Sequence[grpc.GenericRpcHandler], - interceptors: Sequence[grpc.ServerInterceptor], - options: Sequence[ChannelArgumentType], - maximum_concurrent_rpcs: Optional[int], - compression: Optional[grpc.Compression], xds: bool): + def __init__( + self, + thread_pool: futures.ThreadPoolExecutor, + generic_handlers: Sequence[grpc.GenericRpcHandler], + interceptors: Sequence[grpc.ServerInterceptor], + options: 
Sequence[ChannelArgumentType], + maximum_concurrent_rpcs: Optional[int], + compression: Optional[grpc.Compression], + xds: bool, + ): completion_queue = cygrpc.CompletionQueue() server = cygrpc.Server(_augment_options(options, compression), xds) server.register_completion_queue(completion_queue) - self._state = _ServerState(completion_queue, server, generic_handlers, - _interceptor.service_pipeline(interceptors), - thread_pool, maximum_concurrent_rpcs) + self._state = _ServerState( + completion_queue, + server, + generic_handlers, + _interceptor.service_pipeline(interceptors), + thread_pool, + maximum_concurrent_rpcs, + ) def add_generic_rpc_handlers( - self, - generic_rpc_handlers: Iterable[grpc.GenericRpcHandler]) -> None: + self, generic_rpc_handlers: Iterable[grpc.GenericRpcHandler] + ) -> None: _validate_generic_rpc_handlers(generic_rpc_handlers) _add_generic_handlers(self._state, generic_rpc_handlers) def add_insecure_port(self, address: str) -> int: return _common.validate_port_binding_result( - address, _add_insecure_port(self._state, _common.encode(address))) + address, _add_insecure_port(self._state, _common.encode(address)) + ) - def add_secure_port(self, address: str, - server_credentials: grpc.ServerCredentials) -> int: + def add_secure_port( + self, address: str, server_credentials: grpc.ServerCredentials + ) -> int: return _common.validate_port_binding_result( address, - _add_secure_port(self._state, _common.encode(address), - server_credentials)) + _add_secure_port( + self._state, _common.encode(address), server_credentials + ), + ) def start(self) -> None: _start(self._state) @@ -1116,27 +1331,38 @@ def wait_for_termination(self, timeout: Optional[float] = None) -> bool: # NOTE(https://bugs.python.org/issue35935) # Remove this workaround once threading.Event.wait() is working with # CTRL+C across platforms. - return _common.wait(self._state.termination_event.wait, - self._state.termination_event.is_set, - timeout=timeout) + return _common.wait( + self._state.termination_event.wait, + self._state.termination_event.is_set, + timeout=timeout, + ) def stop(self, grace: Optional[float]) -> threading.Event: return _stop(self._state, grace) def __del__(self): - if hasattr(self, '_state'): + if hasattr(self, "_state"): # We can not grab a lock in __del__(), so set a flag to signal the # serving daemon thread (if it exists) to initiate shutdown. 
self._state.server_deallocated = True -def create_server(thread_pool: futures.ThreadPoolExecutor, - generic_rpc_handlers: Sequence[grpc.GenericRpcHandler], - interceptors: Sequence[grpc.ServerInterceptor], - options: Sequence[ChannelArgumentType], - maximum_concurrent_rpcs: Optional[int], - compression: Optional[grpc.Compression], - xds: bool) -> _Server: +def create_server( + thread_pool: futures.ThreadPoolExecutor, + generic_rpc_handlers: Sequence[grpc.GenericRpcHandler], + interceptors: Sequence[grpc.ServerInterceptor], + options: Sequence[ChannelArgumentType], + maximum_concurrent_rpcs: Optional[int], + compression: Optional[grpc.Compression], + xds: bool, +) -> _Server: _validate_generic_rpc_handlers(generic_rpc_handlers) - return _Server(thread_pool, generic_rpc_handlers, interceptors, options, - maximum_concurrent_rpcs, compression, xds) + return _Server( + thread_pool, + generic_rpc_handlers, + interceptors, + options, + maximum_concurrent_rpcs, + compression, + xds, + ) diff --git a/src/python/grpcio/grpc/_simple_stubs.py b/src/python/grpcio/grpc/_simple_stubs.py index 54c2a2d5dbd7f..7772860957ba4 100644 --- a/src/python/grpcio/grpc/_simple_stubs.py +++ b/src/python/grpcio/grpc/_simple_stubs.py @@ -18,27 +18,43 @@ import logging import os import threading -from typing import (Any, AnyStr, Callable, Dict, Iterator, Optional, Sequence, - Tuple, TypeVar, Union) +from typing import ( + Any, + AnyStr, + Callable, + Dict, + Iterator, + Optional, + Sequence, + Tuple, + TypeVar, + Union, +) import grpc from grpc.experimental import experimental_api -RequestType = TypeVar('RequestType') -ResponseType = TypeVar('ResponseType') +RequestType = TypeVar("RequestType") +ResponseType = TypeVar("ResponseType") OptionsType = Sequence[Tuple[str, str]] -CacheKey = Tuple[str, OptionsType, Optional[grpc.ChannelCredentials], - Optional[grpc.Compression]] +CacheKey = Tuple[ + str, + OptionsType, + Optional[grpc.ChannelCredentials], + Optional[grpc.Compression], +] _LOGGER = logging.getLogger(__name__) _EVICTION_PERIOD_KEY = "GRPC_PYTHON_MANAGED_CHANNEL_EVICTION_SECONDS" if _EVICTION_PERIOD_KEY in os.environ: _EVICTION_PERIOD = datetime.timedelta( - seconds=float(os.environ[_EVICTION_PERIOD_KEY])) - _LOGGER.debug("Setting managed channel eviction period to %s", - _EVICTION_PERIOD) + seconds=float(os.environ[_EVICTION_PERIOD_KEY]) + ) + _LOGGER.debug( + "Setting managed channel eviction period to %s", _EVICTION_PERIOD + ) else: _EVICTION_PERIOD = datetime.timedelta(minutes=10) @@ -57,16 +73,22 @@ _DEFAULT_TIMEOUT = 60.0 -def _create_channel(target: str, options: Sequence[Tuple[str, str]], - channel_credentials: Optional[grpc.ChannelCredentials], - compression: Optional[grpc.Compression]) -> grpc.Channel: +def _create_channel( + target: str, + options: Sequence[Tuple[str, str]], + channel_credentials: Optional[grpc.ChannelCredentials], + compression: Optional[grpc.Compression], +) -> grpc.Channel: _LOGGER.debug( - f"Creating secure channel with credentials '{channel_credentials}', " + - f"options '{options}' and compression '{compression}'") - return grpc.secure_channel(target, - credentials=channel_credentials, - options=options, - compression=compression) + f"Creating secure channel with credentials '{channel_credentials}', " + + f"options '{options}' and compression '{compression}'" + ) + return grpc.secure_channel( + target, + credentials=channel_credentials, + options=options, + compression=compression, + ) class ChannelCache: @@ -82,7 +104,8 @@ class ChannelCache: def __init__(self): self._mapping = 
collections.OrderedDict() self._eviction_thread = threading.Thread( - target=ChannelCache._perform_evictions, daemon=True) + target=ChannelCache._perform_evictions, daemon=True + ) self._eviction_thread.start() @staticmethod @@ -95,8 +118,9 @@ def get(): def _evict_locked(self, key: CacheKey): channel, _ = self._mapping.pop(key) - _LOGGER.debug("Evicting channel %s with configuration %s.", channel, - key) + _LOGGER.debug( + "Evicting channel %s with configuration %s.", channel, key + ) channel.close() del channel @@ -113,7 +137,8 @@ def _perform_evictions(): # And immediately reevaluate. else: key, (_, eviction_time) = next( - iter(ChannelCache._singleton._mapping.items())) + iter(ChannelCache._singleton._mapping.items()) + ) now = datetime.datetime.now() if eviction_time <= now: ChannelCache._singleton._evict_locked(key) @@ -127,16 +152,23 @@ def _perform_evictions(): # criteria are not met. ChannelCache._condition.wait(timeout=time_to_eviction) - def get_channel(self, target: str, options: Sequence[Tuple[str, str]], - channel_credentials: Optional[grpc.ChannelCredentials], - insecure: bool, - compression: Optional[grpc.Compression]) -> grpc.Channel: + def get_channel( + self, + target: str, + options: Sequence[Tuple[str, str]], + channel_credentials: Optional[grpc.ChannelCredentials], + insecure: bool, + compression: Optional[grpc.Compression], + ) -> grpc.Channel: if insecure and channel_credentials: - raise ValueError("The insecure option is mutually exclusive with " + - "the channel_credentials option. Please use one " + - "or the other.") + raise ValueError( + "The insecure option is mutually exclusive with " + + "the channel_credentials option. Please use one " + + "or the other." + ) if insecure: - channel_credentials = grpc.experimental.insecure_channel_credentials( + channel_credentials = ( + grpc.experimental.insecure_channel_credentials() ) elif channel_credentials is None: _LOGGER.debug("Defaulting to SSL channel credentials.") @@ -147,16 +179,23 @@ def get_channel(self, target: str, options: Sequence[Tuple[str, str]], if channel_data is not None: channel = channel_data[0] self._mapping.pop(key) - self._mapping[key] = (channel, datetime.datetime.now() + - _EVICTION_PERIOD) + self._mapping[key] = ( + channel, + datetime.datetime.now() + _EVICTION_PERIOD, + ) return channel else: - channel = _create_channel(target, options, channel_credentials, - compression) - self._mapping[key] = (channel, datetime.datetime.now() + - _EVICTION_PERIOD) - if len(self._mapping) == 1 or len( - self._mapping) >= _MAXIMUM_CHANNELS: + channel = _create_channel( + target, options, channel_credentials, compression + ) + self._mapping[key] = ( + channel, + datetime.datetime.now() + _EVICTION_PERIOD, + ) + if ( + len(self._mapping) == 1 + or len(self._mapping) >= _MAXIMUM_CHANNELS + ): self._condition.notify() return channel @@ -179,7 +218,7 @@ def unary_unary( compression: Optional[grpc.Compression] = None, wait_for_ready: Optional[bool] = None, timeout: Optional[float] = _DEFAULT_TIMEOUT, - metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None + metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None, ) -> ResponseType: """Invokes a unary-unary RPC without an explicitly specified channel. @@ -233,17 +272,20 @@ def unary_unary( Returns: The response to the RPC. 
""" - channel = ChannelCache.get().get_channel(target, options, - channel_credentials, insecure, - compression) - multicallable = channel.unary_unary(method, request_serializer, - response_deserializer) + channel = ChannelCache.get().get_channel( + target, options, channel_credentials, insecure, compression + ) + multicallable = channel.unary_unary( + method, request_serializer, response_deserializer + ) wait_for_ready = wait_for_ready if wait_for_ready is not None else True - return multicallable(request, - metadata=metadata, - wait_for_ready=wait_for_ready, - credentials=call_credentials, - timeout=timeout) + return multicallable( + request, + metadata=metadata, + wait_for_ready=wait_for_ready, + credentials=call_credentials, + timeout=timeout, + ) @experimental_api @@ -260,7 +302,7 @@ def unary_stream( compression: Optional[grpc.Compression] = None, wait_for_ready: Optional[bool] = None, timeout: Optional[float] = _DEFAULT_TIMEOUT, - metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None + metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None, ) -> Iterator[ResponseType]: """Invokes a unary-stream RPC without an explicitly specified channel. @@ -313,17 +355,20 @@ def unary_stream( Returns: An iterator of responses. """ - channel = ChannelCache.get().get_channel(target, options, - channel_credentials, insecure, - compression) - multicallable = channel.unary_stream(method, request_serializer, - response_deserializer) + channel = ChannelCache.get().get_channel( + target, options, channel_credentials, insecure, compression + ) + multicallable = channel.unary_stream( + method, request_serializer, response_deserializer + ) wait_for_ready = wait_for_ready if wait_for_ready is not None else True - return multicallable(request, - metadata=metadata, - wait_for_ready=wait_for_ready, - credentials=call_credentials, - timeout=timeout) + return multicallable( + request, + metadata=metadata, + wait_for_ready=wait_for_ready, + credentials=call_credentials, + timeout=timeout, + ) @experimental_api @@ -340,7 +385,7 @@ def stream_unary( compression: Optional[grpc.Compression] = None, wait_for_ready: Optional[bool] = None, timeout: Optional[float] = _DEFAULT_TIMEOUT, - metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None + metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None, ) -> ResponseType: """Invokes a stream-unary RPC without an explicitly specified channel. @@ -393,17 +438,20 @@ def stream_unary( Returns: The response to the RPC. 
""" - channel = ChannelCache.get().get_channel(target, options, - channel_credentials, insecure, - compression) - multicallable = channel.stream_unary(method, request_serializer, - response_deserializer) + channel = ChannelCache.get().get_channel( + target, options, channel_credentials, insecure, compression + ) + multicallable = channel.stream_unary( + method, request_serializer, response_deserializer + ) wait_for_ready = wait_for_ready if wait_for_ready is not None else True - return multicallable(request_iterator, - metadata=metadata, - wait_for_ready=wait_for_ready, - credentials=call_credentials, - timeout=timeout) + return multicallable( + request_iterator, + metadata=metadata, + wait_for_ready=wait_for_ready, + credentials=call_credentials, + timeout=timeout, + ) @experimental_api @@ -420,7 +468,7 @@ def stream_stream( compression: Optional[grpc.Compression] = None, wait_for_ready: Optional[bool] = None, timeout: Optional[float] = _DEFAULT_TIMEOUT, - metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None + metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None, ) -> Iterator[ResponseType]: """Invokes a stream-stream RPC without an explicitly specified channel. @@ -473,14 +521,17 @@ def stream_stream( Returns: An iterator of responses. """ - channel = ChannelCache.get().get_channel(target, options, - channel_credentials, insecure, - compression) - multicallable = channel.stream_stream(method, request_serializer, - response_deserializer) + channel = ChannelCache.get().get_channel( + target, options, channel_credentials, insecure, compression + ) + multicallable = channel.stream_stream( + method, request_serializer, response_deserializer + ) wait_for_ready = wait_for_ready if wait_for_ready is not None else True - return multicallable(request_iterator, - metadata=metadata, - wait_for_ready=wait_for_ready, - credentials=call_credentials, - timeout=timeout) + return multicallable( + request_iterator, + metadata=metadata, + wait_for_ready=wait_for_ready, + credentials=call_credentials, + timeout=timeout, + ) diff --git a/src/python/grpcio/grpc/_typing.py b/src/python/grpcio/grpc/_typing.py index d2a0b47215384..93ecb5c3cf239 100644 --- a/src/python/grpcio/grpc/_typing.py +++ b/src/python/grpcio/grpc/_typing.py @@ -13,8 +13,18 @@ # limitations under the License. """Common types for gRPC Sync API""" -from typing import (TYPE_CHECKING, Any, Callable, Iterable, Iterator, Optional, - Sequence, Tuple, TypeVar, Union) +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Iterable, + Iterator, + Optional, + Sequence, + Tuple, + TypeVar, + Union, +) from grpc._cython import cygrpc @@ -22,8 +32,8 @@ from grpc import ServicerContext from grpc._server import _RPCState -RequestType = TypeVar('RequestType') -ResponseType = TypeVar('ResponseType') +RequestType = TypeVar("RequestType") +ResponseType = TypeVar("ResponseType") SerializingFunction = Callable[[Any], bytes] DeserializingFunction = Callable[[bytes], Any] MetadataType = Sequence[Tuple[str, Union[str, bytes]]] @@ -33,26 +43,53 @@ RequestIterableType = Iterable[Any] ResponseIterableType = Iterable[Any] UserTag = Callable[[cygrpc.BaseEvent], bool] -IntegratedCallFactory = Callable[[ - int, bytes, None, Optional[float], Optional[MetadataType], Optional[ - cygrpc.CallCredentials], Sequence[Sequence[cygrpc. 
- Operation]], UserTag, Any -], cygrpc.IntegratedCall] -ServerTagCallbackType = Tuple[Optional['_RPCState'], - Sequence[NullaryCallbackType]] +IntegratedCallFactory = Callable[ + [ + int, + bytes, + None, + Optional[float], + Optional[MetadataType], + Optional[cygrpc.CallCredentials], + Sequence[Sequence[cygrpc.Operation]], + UserTag, + Any, + ], + cygrpc.IntegratedCall, +] +ServerTagCallbackType = Tuple[ + Optional["_RPCState"], Sequence[NullaryCallbackType] +] ServerCallbackTag = Callable[[cygrpc.BaseEvent], ServerTagCallbackType] ArityAgnosticMethodHandler = Union[ - Callable[[RequestType, 'ServicerContext', Callable[[ResponseType], None]], - ResponseType], - Callable[[RequestType, 'ServicerContext', Callable[[ResponseType], None]], - Iterator[ResponseType]], - Callable[[ - Iterator[RequestType], 'ServicerContext', Callable[[ResponseType], None] - ], ResponseType], Callable[[ - Iterator[RequestType], 'ServicerContext', Callable[[ResponseType], None] - ], Iterator[ResponseType]], Callable[[RequestType, 'ServicerContext'], - ResponseType], - Callable[[RequestType, 'ServicerContext'], Iterator[ResponseType]], - Callable[[Iterator[RequestType], 'ServicerContext'], - ResponseType], Callable[[Iterator[RequestType], 'ServicerContext'], - Iterator[ResponseType]]] + Callable[ + [RequestType, "ServicerContext", Callable[[ResponseType], None]], + ResponseType, + ], + Callable[ + [RequestType, "ServicerContext", Callable[[ResponseType], None]], + Iterator[ResponseType], + ], + Callable[ + [ + Iterator[RequestType], + "ServicerContext", + Callable[[ResponseType], None], + ], + ResponseType, + ], + Callable[ + [ + Iterator[RequestType], + "ServicerContext", + Callable[[ResponseType], None], + ], + Iterator[ResponseType], + ], + Callable[[RequestType, "ServicerContext"], ResponseType], + Callable[[RequestType, "ServicerContext"], Iterator[ResponseType]], + Callable[[Iterator[RequestType], "ServicerContext"], ResponseType], + Callable[ + [Iterator[RequestType], "ServicerContext"], Iterator[ResponseType] + ], +] diff --git a/src/python/grpcio/grpc/_utilities.py b/src/python/grpcio/grpc/_utilities.py index 3dafa7a03d3d9..1ab8c4fa8f133 100644 --- a/src/python/grpcio/grpc/_utilities.py +++ b/src/python/grpcio/grpc/_utilities.py @@ -26,20 +26,26 @@ _LOGGER = logging.getLogger(__name__) _DONE_CALLBACK_EXCEPTION_LOG_MESSAGE = ( - 'Exception calling connectivity future "done" callback!') + 'Exception calling connectivity future "done" callback!' 
+) class RpcMethodHandler( - collections.namedtuple('_RpcMethodHandler', ( - 'request_streaming', - 'response_streaming', - 'request_deserializer', - 'response_serializer', - 'unary_unary', - 'unary_stream', - 'stream_unary', - 'stream_stream', - )), grpc.RpcMethodHandler): + collections.namedtuple( + "_RpcMethodHandler", + ( + "request_streaming", + "response_streaming", + "request_deserializer", + "response_serializer", + "unary_unary", + "unary_stream", + "stream_unary", + "stream_stream", + ), + ), + grpc.RpcMethodHandler, +): pass @@ -47,8 +53,9 @@ class DictionaryGenericHandler(grpc.ServiceRpcHandler): _name: str _method_handlers: Dict[str, grpc.RpcMethodHandler] - def __init__(self, service: str, - method_handlers: Dict[str, grpc.RpcMethodHandler]): + def __init__( + self, service: str, method_handlers: Dict[str, grpc.RpcMethodHandler] + ): self._name = service self._method_handlers = { _common.fully_qualified_method(service, method): method_handler @@ -62,7 +69,9 @@ def service( self, handler_call_details: grpc.HandlerCallDetails ) -> Optional[grpc.RpcMethodHandler]: details_method = handler_call_details.method - return self._method_handlers.get(details_method) # pytype: disable=attribute-error + return self._method_handlers.get( + details_method + ) # pytype: disable=attribute-error class _ChannelReadyFuture(grpc.Future): @@ -100,8 +109,10 @@ def _block(self, timeout: Optional[float]) -> None: def _update(self, connectivity: Optional[grpc.ChannelConnectivity]) -> None: with self._condition: - if (not self._cancelled and - connectivity is grpc.ChannelConnectivity.READY): + if ( + not self._cancelled + and connectivity is grpc.ChannelConnectivity.READY + ): self._matured = True self._channel.unsubscribe(self._update) self._condition.notify_all() diff --git a/src/python/grpcio/grpc/aio/__init__.py b/src/python/grpcio/grpc/aio/__init__.py index 3436d2ef98c88..a4e104ad51b5d 100644 --- a/src/python/grpcio/grpc/aio/__init__.py +++ b/src/python/grpcio/grpc/aio/__init__.py @@ -59,37 +59,37 @@ ################################### __all__ ################################# __all__ = ( - 'init_grpc_aio', - 'shutdown_grpc_aio', - 'AioRpcError', - 'RpcContext', - 'Call', - 'UnaryUnaryCall', - 'UnaryStreamCall', - 'StreamUnaryCall', - 'StreamStreamCall', - 'Channel', - 'UnaryUnaryMultiCallable', - 'UnaryStreamMultiCallable', - 'StreamUnaryMultiCallable', - 'StreamStreamMultiCallable', - 'ClientCallDetails', - 'ClientInterceptor', - 'UnaryStreamClientInterceptor', - 'UnaryUnaryClientInterceptor', - 'StreamUnaryClientInterceptor', - 'StreamStreamClientInterceptor', - 'InterceptedUnaryUnaryCall', - 'ServerInterceptor', - 'insecure_channel', - 'server', - 'Server', - 'ServicerContext', - 'EOF', - 'secure_channel', - 'AbortError', - 'BaseError', - 'UsageError', - 'InternalError', - 'Metadata', + "init_grpc_aio", + "shutdown_grpc_aio", + "AioRpcError", + "RpcContext", + "Call", + "UnaryUnaryCall", + "UnaryStreamCall", + "StreamUnaryCall", + "StreamStreamCall", + "Channel", + "UnaryUnaryMultiCallable", + "UnaryStreamMultiCallable", + "StreamUnaryMultiCallable", + "StreamStreamMultiCallable", + "ClientCallDetails", + "ClientInterceptor", + "UnaryStreamClientInterceptor", + "UnaryUnaryClientInterceptor", + "StreamUnaryClientInterceptor", + "StreamStreamClientInterceptor", + "InterceptedUnaryUnaryCall", + "ServerInterceptor", + "insecure_channel", + "server", + "Server", + "ServicerContext", + "EOF", + "secure_channel", + "AbortError", + "BaseError", + "UsageError", + "InternalError", + "Metadata", ) diff 
--git a/src/python/grpcio/grpc/aio/_base_call.py b/src/python/grpcio/grpc/aio/_base_call.py index a1226158bac54..cc648f3e1d939 100644 --- a/src/python/grpcio/grpc/aio/_base_call.py +++ b/src/python/grpcio/grpc/aio/_base_call.py @@ -30,7 +30,7 @@ from ._typing import RequestType from ._typing import ResponseType -__all__ = 'RpcContext', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall' +__all__ = "RpcContext", "Call", "UnaryUnaryCall", "UnaryStreamCall" class RpcContext(metaclass=ABCMeta): @@ -135,9 +135,9 @@ async def wait_for_connection(self) -> None: """ -class UnaryUnaryCall(Generic[RequestType, ResponseType], - Call, - metaclass=ABCMeta): +class UnaryUnaryCall( + Generic[RequestType, ResponseType], Call, metaclass=ABCMeta +): """The abstract base class of an unary-unary RPC on the client-side.""" @abstractmethod @@ -149,10 +149,9 @@ def __await__(self) -> Generator[Any, None, ResponseType]: """ -class UnaryStreamCall(Generic[RequestType, ResponseType], - Call, - metaclass=ABCMeta): - +class UnaryStreamCall( + Generic[RequestType, ResponseType], Call, metaclass=ABCMeta +): @abstractmethod def __aiter__(self) -> AsyncIterator[ResponseType]: """Returns the async iterator representation that yields messages. @@ -176,10 +175,9 @@ async def read(self) -> Union[EOFType, ResponseType]: """ -class StreamUnaryCall(Generic[RequestType, ResponseType], - Call, - metaclass=ABCMeta): - +class StreamUnaryCall( + Generic[RequestType, ResponseType], Call, metaclass=ABCMeta +): @abstractmethod async def write(self, request: RequestType) -> None: """Writes one message to the stream. @@ -205,10 +203,9 @@ def __await__(self) -> Generator[Any, None, ResponseType]: """ -class StreamStreamCall(Generic[RequestType, ResponseType], - Call, - metaclass=ABCMeta): - +class StreamStreamCall( + Generic[RequestType, ResponseType], Call, metaclass=ABCMeta +): @abstractmethod def __aiter__(self) -> AsyncIterator[ResponseType]: """Returns the async iterator representation that yields messages. diff --git a/src/python/grpcio/grpc/aio/_base_channel.py b/src/python/grpcio/grpc/aio/_base_channel.py index 04b92a424033b..2fb8e75d3a9c9 100644 --- a/src/python/grpcio/grpc/aio/_base_channel.py +++ b/src/python/grpcio/grpc/aio/_base_channel.py @@ -39,7 +39,7 @@ def __call__( metadata: Optional[MetadataType] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, - compression: Optional[grpc.Compression] = None + compression: Optional[grpc.Compression] = None, ) -> _base_call.UnaryUnaryCall[RequestType, ResponseType]: """Asynchronously invokes the underlying RPC. @@ -77,7 +77,7 @@ def __call__( metadata: Optional[MetadataType] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, - compression: Optional[grpc.Compression] = None + compression: Optional[grpc.Compression] = None, ) -> _base_call.UnaryStreamCall[RequestType, ResponseType]: """Asynchronously invokes the underlying RPC. @@ -114,7 +114,7 @@ def __call__( metadata: Optional[MetadataType] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, - compression: Optional[grpc.Compression] = None + compression: Optional[grpc.Compression] = None, ) -> _base_call.StreamUnaryCall: """Asynchronously invokes the underlying RPC. 
@@ -152,7 +152,7 @@ def __call__( metadata: Optional[MetadataType] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, - compression: Optional[grpc.Compression] = None + compression: Optional[grpc.Compression] = None, ) -> _base_call.StreamStreamCall: """Asynchronously invokes the underlying RPC. @@ -218,8 +218,9 @@ async def close(self, grace: Optional[float] = None): """ @abc.abstractmethod - def get_state(self, - try_to_connect: bool = False) -> grpc.ChannelConnectivity: + def get_state( + self, try_to_connect: bool = False + ) -> grpc.ChannelConnectivity: """Checks the connectivity state of a channel. This is an EXPERIMENTAL API. @@ -270,7 +271,7 @@ def unary_unary( self, method: str, request_serializer: Optional[SerializingFunction] = None, - response_deserializer: Optional[DeserializingFunction] = None + response_deserializer: Optional[DeserializingFunction] = None, ) -> UnaryUnaryMultiCallable: """Creates a UnaryUnaryMultiCallable for a unary-unary method. @@ -291,7 +292,7 @@ def unary_stream( self, method: str, request_serializer: Optional[SerializingFunction] = None, - response_deserializer: Optional[DeserializingFunction] = None + response_deserializer: Optional[DeserializingFunction] = None, ) -> UnaryStreamMultiCallable: """Creates a UnaryStreamMultiCallable for a unary-stream method. @@ -312,7 +313,7 @@ def stream_unary( self, method: str, request_serializer: Optional[SerializingFunction] = None, - response_deserializer: Optional[DeserializingFunction] = None + response_deserializer: Optional[DeserializingFunction] = None, ) -> StreamUnaryMultiCallable: """Creates a StreamUnaryMultiCallable for a stream-unary method. @@ -333,7 +334,7 @@ def stream_stream( self, method: str, request_serializer: Optional[SerializingFunction] = None, - response_deserializer: Optional[DeserializingFunction] = None + response_deserializer: Optional[DeserializingFunction] = None, ) -> StreamStreamMultiCallable: """Creates a StreamStreamMultiCallable for a stream-stream method. diff --git a/src/python/grpcio/grpc/aio/_base_server.py b/src/python/grpcio/grpc/aio/_base_server.py index a86bbbad09f6b..904e8f35e4a92 100644 --- a/src/python/grpcio/grpc/aio/_base_server.py +++ b/src/python/grpcio/grpc/aio/_base_server.py @@ -30,8 +30,8 @@ class Server(abc.ABC): @abc.abstractmethod def add_generic_rpc_handlers( - self, - generic_rpc_handlers: Sequence[grpc.GenericRpcHandler]) -> None: + self, generic_rpc_handlers: Sequence[grpc.GenericRpcHandler] + ) -> None: """Registers GenericRpcHandlers with this Server. This method is only safe to call before the server is started. @@ -59,8 +59,9 @@ def add_insecure_port(self, address: str) -> int: """ @abc.abstractmethod - def add_secure_port(self, address: str, - server_credentials: grpc.ServerCredentials) -> int: + def add_secure_port( + self, address: str, server_credentials: grpc.ServerCredentials + ) -> int: """Opens a secure port for accepting RPCs. A port is a communication endpoint that used by networking protocols, @@ -110,8 +111,9 @@ async def stop(self, grace: Optional[float]) -> None: """ @abc.abstractmethod - async def wait_for_termination(self, - timeout: Optional[float] = None) -> bool: + async def wait_for_termination( + self, timeout: Optional[float] = None + ) -> bool: """Continues current coroutine once the server stops. This is an EXPERIMENTAL API. 
@@ -162,8 +164,9 @@ async def write(self, message: ResponseType) -> None: """ @abc.abstractmethod - async def send_initial_metadata(self, - initial_metadata: MetadataType) -> None: + async def send_initial_metadata( + self, initial_metadata: MetadataType + ) -> None: """Sends the initial metadata value to the client. This method need not be called by implementations if they have no @@ -177,8 +180,9 @@ async def send_initial_metadata(self, async def abort( self, code: grpc.StatusCode, - details: str = '', - trailing_metadata: MetadataType = tuple()) -> NoReturn: + details: str = "", + trailing_metadata: MetadataType = tuple(), + ) -> NoReturn: """Raises an exception to terminate the RPC with a non-OK status. The code and details passed as arguments will supercede any existing diff --git a/src/python/grpcio/grpc/aio/_call.py b/src/python/grpcio/grpc/aio/_call.py index fcc90066c00ea..cb32f235fe5a4 100644 --- a/src/python/grpcio/grpc/aio/_call.py +++ b/src/python/grpcio/grpc/aio/_call.py @@ -35,24 +35,27 @@ from ._typing import ResponseType from ._typing import SerializingFunction -__all__ = 'AioRpcError', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall' +__all__ = "AioRpcError", "Call", "UnaryUnaryCall", "UnaryStreamCall" -_LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!' -_GC_CANCELLATION_DETAILS = 'Cancelled upon garbage collection!' -_RPC_ALREADY_FINISHED_DETAILS = 'RPC already finished.' +_LOCAL_CANCELLATION_DETAILS = "Locally cancelled by application!" +_GC_CANCELLATION_DETAILS = "Cancelled upon garbage collection!" +_RPC_ALREADY_FINISHED_DETAILS = "RPC already finished." _RPC_HALF_CLOSED_DETAILS = 'RPC is half closed after calling "done_writing".' -_API_STYLE_ERROR = 'The iterator and read/write APIs may not be mixed on a single RPC.' - -_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n' - '\tstatus = {}\n' - '\tdetails = "{}"\n' - '>') - -_NON_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n' - '\tstatus = {}\n' - '\tdetails = "{}"\n' - '\tdebug_error_string = "{}"\n' - '>') +_API_STYLE_ERROR = ( + "The iterator and read/write APIs may not be mixed on a single RPC." +) + +_OK_CALL_REPRESENTATION = ( + '<{} of RPC that terminated with:\n\tstatus = {}\n\tdetails = "{}"\n>' +) + +_NON_OK_CALL_REPRESENTATION = ( + "<{} of RPC that terminated with:\n" + "\tstatus = {}\n" + '\tdetails = "{}"\n' + '\tdebug_error_string = "{}"\n' + ">" +) _LOGGER = logging.getLogger(__name__) @@ -70,12 +73,14 @@ class AioRpcError(grpc.RpcError): _trailing_metadata: Optional[Metadata] _debug_error_string: Optional[str] - def __init__(self, - code: grpc.StatusCode, - initial_metadata: Metadata, - trailing_metadata: Metadata, - details: Optional[str] = None, - debug_error_string: Optional[str] = None) -> None: + def __init__( + self, + code: grpc.StatusCode, + initial_metadata: Metadata, + trailing_metadata: Metadata, + details: Optional[str] = None, + debug_error_string: Optional[str] = None, + ) -> None: """Constructor. 
Args: @@ -135,9 +140,12 @@ def debug_error_string(self) -> str: def _repr(self) -> str: """Assembles the error string for the RPC error.""" - return _NON_OK_CALL_REPRESENTATION.format(self.__class__.__name__, - self._code, self._details, - self._debug_error_string) + return _NON_OK_CALL_REPRESENTATION.format( + self.__class__.__name__, + self._code, + self._details, + self._debug_error_string, + ) def __repr__(self) -> str: return self._repr() @@ -146,8 +154,9 @@ def __str__(self) -> str: return self._repr() -def _create_rpc_error(initial_metadata: Metadata, - status: cygrpc.AioRpcStatus) -> AioRpcError: +def _create_rpc_error( + initial_metadata: Metadata, status: cygrpc.AioRpcStatus +) -> AioRpcError: return AioRpcError( _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()], Metadata.from_tuple(initial_metadata), @@ -162,6 +171,7 @@ class Call: Implements logic around final status, metadata and cancellation. """ + _loop: asyncio.AbstractEventLoop _code: grpc.StatusCode _cython_call: cygrpc._AioCall @@ -169,10 +179,14 @@ class Call: _request_serializer: SerializingFunction _response_deserializer: DeserializingFunction - def __init__(self, cython_call: cygrpc._AioCall, metadata: Metadata, - request_serializer: SerializingFunction, - response_deserializer: DeserializingFunction, - loop: asyncio.AbstractEventLoop) -> None: + def __init__( + self, + cython_call: cygrpc._AioCall, + metadata: Metadata, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + loop: asyncio.AbstractEventLoop, + ) -> None: self._loop = loop self._cython_call = cython_call self._metadata = tuple(metadata) @@ -181,7 +195,7 @@ def __init__(self, cython_call: cygrpc._AioCall, metadata: Metadata, def __del__(self) -> None: # The '_cython_call' object might be destructed before Call object - if hasattr(self, '_cython_call'): + if hasattr(self, "_cython_call"): if not self._cython_call.done(): self._cancel(_GC_CANCELLATION_DETAILS) @@ -214,8 +228,9 @@ async def initial_metadata(self) -> Metadata: return Metadata.from_tuple(raw_metadata_tuple) async def trailing_metadata(self) -> Metadata: - raw_metadata_tuple = (await - self._cython_call.status()).trailing_metadata() + raw_metadata_tuple = ( + await self._cython_call.status() + ).trailing_metadata() return Metadata.from_tuple(raw_metadata_tuple) async def code(self) -> grpc.StatusCode: @@ -233,8 +248,9 @@ async def _raise_for_status(self) -> None: raise asyncio.CancelledError() code = await self.code() if code != grpc.StatusCode.OK: - raise _create_rpc_error(await self.initial_metadata(), await - self._cython_call.status()) + raise _create_rpc_error( + await self.initial_metadata(), await self._cython_call.status() + ) def _repr(self) -> str: return repr(self._cython_call) @@ -287,8 +303,10 @@ def __await__(self) -> Generator[Any, None, ResponseType]: if self._cython_call.is_locally_cancelled(): raise asyncio.CancelledError() else: - raise _create_rpc_error(self._cython_call._initial_metadata, - self._cython_call._status) + raise _create_rpc_error( + self._cython_call._initial_metadata, + self._cython_call._status, + ) else: return response @@ -346,8 +364,9 @@ async def _read(self) -> ResponseType: if raw_response is cygrpc.EOF: return cygrpc.EOF else: - return _common.deserialize(raw_response, - self._response_deserializer) + return _common.deserialize( + raw_response, self._response_deserializer + ) async def read(self) -> ResponseType: if self.done(): @@ -370,14 +389,16 @@ class _StreamRequestMixin(Call): _request_style: 
_APIStyle def _init_stream_request_mixin( - self, request_iterator: Optional[RequestIterableType]): + self, request_iterator: Optional[RequestIterableType] + ): self._metadata_sent = asyncio.Event() self._done_writing_flag = False # If user passes in an async iterator, create a consumer Task. if request_iterator is not None: self._async_request_poller = self._loop.create_task( - self._consume_request_iterator(request_iterator)) + self._consume_request_iterator(request_iterator) + ) self._request_style = _APIStyle.ASYNC_GENERATOR else: self._async_request_poller = None @@ -399,17 +420,23 @@ def _metadata_sent_observer(self): self._metadata_sent.set() async def _consume_request_iterator( - self, request_iterator: RequestIterableType) -> None: + self, request_iterator: RequestIterableType + ) -> None: try: if inspect.isasyncgen(request_iterator) or hasattr( - request_iterator, '__aiter__'): + request_iterator, "__aiter__" + ): async for request in request_iterator: try: await self._write(request) except AioRpcError as rpc_error: _LOGGER.debug( - 'Exception while consuming the request_iterator: %s', - rpc_error) + ( + "Exception while consuming the" + " request_iterator: %s" + ), + rpc_error, + ) return else: for request in request_iterator: @@ -417,8 +444,12 @@ async def _consume_request_iterator( await self._write(request) except AioRpcError as rpc_error: _LOGGER.debug( - 'Exception while consuming the request_iterator: %s', - rpc_error) + ( + "Exception while consuming the" + " request_iterator: %s" + ), + rpc_error, + ) return await self._done_writing() @@ -426,8 +457,10 @@ async def _consume_request_iterator( # Client iterators can raise exceptions, which we should handle by # cancelling the RPC and logging the client's error. No exceptions # should escape this function. - _LOGGER.debug('Client request_iterator raised exception:\n%s', - traceback.format_exc()) + _LOGGER.debug( + "Client request_iterator raised exception:\n%s", + traceback.format_exc(), + ) self.cancel() async def _write(self, request: RequestType) -> None: @@ -440,8 +473,9 @@ async def _write(self, request: RequestType) -> None: if self.done(): await self._raise_for_status() - serialized_request = _common.serialize(request, - self._request_serializer) + serialized_request = _common.serialize( + request, self._request_serializer + ) try: await self._cython_call.send_serialized_message(serialized_request) except cygrpc.InternalError: @@ -488,41 +522,55 @@ class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall): Returned when an instance of `UnaryUnaryMultiCallable` object is called. 
""" + _request: RequestType _invocation_task: asyncio.Task # pylint: disable=too-many-arguments - def __init__(self, request: RequestType, deadline: Optional[float], - metadata: Metadata, - credentials: Optional[grpc.CallCredentials], - wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, - method: bytes, request_serializer: SerializingFunction, - response_deserializer: DeserializingFunction, - loop: asyncio.AbstractEventLoop) -> None: + def __init__( + self, + request: RequestType, + deadline: Optional[float], + metadata: Metadata, + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], + channel: cygrpc.AioChannel, + method: bytes, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + loop: asyncio.AbstractEventLoop, + ) -> None: super().__init__( channel.call(method, deadline, credentials, wait_for_ready), - metadata, request_serializer, response_deserializer, loop) + metadata, + request_serializer, + response_deserializer, + loop, + ) self._request = request self._invocation_task = loop.create_task(self._invoke()) self._init_unary_response_mixin(self._invocation_task) async def _invoke(self) -> ResponseType: - serialized_request = _common.serialize(self._request, - self._request_serializer) + serialized_request = _common.serialize( + self._request, self._request_serializer + ) # NOTE(lidiz) asyncio.CancelledError is not a good transport for status, # because the asyncio.Task class do not cache the exception object. # https://github.com/python/cpython/blob/edad4d89e357c92f70c0324b937845d652b20afd/Lib/asyncio/tasks.py#L785 try: serialized_response = await self._cython_call.unary_unary( - serialized_request, self._metadata) + serialized_request, self._metadata + ) except asyncio.CancelledError: if not self.cancelled(): self.cancel() if self._cython_call.is_ok(): - return _common.deserialize(serialized_response, - self._response_deserializer) + return _common.deserialize( + serialized_response, self._response_deserializer + ) else: return cygrpc.EOF @@ -537,31 +585,45 @@ class UnaryStreamCall(_StreamResponseMixin, Call, _base_call.UnaryStreamCall): Returned when an instance of `UnaryStreamMultiCallable` object is called. 
""" + _request: RequestType _send_unary_request_task: asyncio.Task # pylint: disable=too-many-arguments - def __init__(self, request: RequestType, deadline: Optional[float], - metadata: Metadata, - credentials: Optional[grpc.CallCredentials], - wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, - method: bytes, request_serializer: SerializingFunction, - response_deserializer: DeserializingFunction, - loop: asyncio.AbstractEventLoop) -> None: + def __init__( + self, + request: RequestType, + deadline: Optional[float], + metadata: Metadata, + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], + channel: cygrpc.AioChannel, + method: bytes, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + loop: asyncio.AbstractEventLoop, + ) -> None: super().__init__( channel.call(method, deadline, credentials, wait_for_ready), - metadata, request_serializer, response_deserializer, loop) + metadata, + request_serializer, + response_deserializer, + loop, + ) self._request = request self._send_unary_request_task = loop.create_task( - self._send_unary_request()) + self._send_unary_request() + ) self._init_stream_response_mixin(self._send_unary_request_task) async def _send_unary_request(self) -> ResponseType: - serialized_request = _common.serialize(self._request, - self._request_serializer) + serialized_request = _common.serialize( + self._request, self._request_serializer + ) try: await self._cython_call.initiate_unary_stream( - serialized_request, self._metadata) + serialized_request, self._metadata + ) except asyncio.CancelledError: if not self.cancelled(): self.cancel() @@ -574,24 +636,35 @@ async def wait_for_connection(self) -> None: # pylint: disable=too-many-ancestors -class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call, - _base_call.StreamUnaryCall): +class StreamUnaryCall( + _StreamRequestMixin, _UnaryResponseMixin, Call, _base_call.StreamUnaryCall +): """Object for managing stream-unary RPC calls. Returned when an instance of `StreamUnaryMultiCallable` object is called. 
""" # pylint: disable=too-many-arguments - def __init__(self, request_iterator: Optional[RequestIterableType], - deadline: Optional[float], metadata: Metadata, - credentials: Optional[grpc.CallCredentials], - wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, - method: bytes, request_serializer: SerializingFunction, - response_deserializer: DeserializingFunction, - loop: asyncio.AbstractEventLoop) -> None: + def __init__( + self, + request_iterator: Optional[RequestIterableType], + deadline: Optional[float], + metadata: Metadata, + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], + channel: cygrpc.AioChannel, + method: bytes, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + loop: asyncio.AbstractEventLoop, + ) -> None: super().__init__( channel.call(method, deadline, credentials, wait_for_ready), - metadata, request_serializer, response_deserializer, loop) + metadata, + request_serializer, + response_deserializer, + loop, + ) self._init_stream_request_mixin(request_iterator) self._init_unary_response_mixin(loop.create_task(self._conduct_rpc())) @@ -599,38 +672,52 @@ def __init__(self, request_iterator: Optional[RequestIterableType], async def _conduct_rpc(self) -> ResponseType: try: serialized_response = await self._cython_call.stream_unary( - self._metadata, self._metadata_sent_observer) + self._metadata, self._metadata_sent_observer + ) except asyncio.CancelledError: if not self.cancelled(): self.cancel() raise if self._cython_call.is_ok(): - return _common.deserialize(serialized_response, - self._response_deserializer) + return _common.deserialize( + serialized_response, self._response_deserializer + ) else: return cygrpc.EOF -class StreamStreamCall(_StreamRequestMixin, _StreamResponseMixin, Call, - _base_call.StreamStreamCall): +class StreamStreamCall( + _StreamRequestMixin, _StreamResponseMixin, Call, _base_call.StreamStreamCall +): """Object for managing stream-stream RPC calls. Returned when an instance of `StreamStreamMultiCallable` object is called. 
""" + _initializer: asyncio.Task # pylint: disable=too-many-arguments - def __init__(self, request_iterator: Optional[RequestIterableType], - deadline: Optional[float], metadata: Metadata, - credentials: Optional[grpc.CallCredentials], - wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, - method: bytes, request_serializer: SerializingFunction, - response_deserializer: DeserializingFunction, - loop: asyncio.AbstractEventLoop) -> None: + def __init__( + self, + request_iterator: Optional[RequestIterableType], + deadline: Optional[float], + metadata: Metadata, + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], + channel: cygrpc.AioChannel, + method: bytes, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + loop: asyncio.AbstractEventLoop, + ) -> None: super().__init__( channel.call(method, deadline, credentials, wait_for_ready), - metadata, request_serializer, response_deserializer, loop) + metadata, + request_serializer, + response_deserializer, + loop, + ) self._initializer = self._loop.create_task(self._prepare_rpc()) self._init_stream_request_mixin(request_iterator) self._init_stream_response_mixin(self._initializer) @@ -643,7 +730,8 @@ async def _prepare_rpc(self): """ try: await self._cython_call.initiate_stream_stream( - self._metadata, self._metadata_sent_observer) + self._metadata, self._metadata_sent_observer + ) except asyncio.CancelledError: if not self.cancelled(): self.cancel() diff --git a/src/python/grpcio/grpc/aio/_channel.py b/src/python/grpcio/grpc/aio/_channel.py index f40e413a48727..19c263c41f4d8 100644 --- a/src/python/grpcio/grpc/aio/_channel.py +++ b/src/python/grpcio/grpc/aio/_channel.py @@ -47,28 +47,36 @@ from ._typing import SerializingFunction from ._utils import _timeout_to_deadline -_USER_AGENT = 'grpc-python-asyncio/{}'.format(_grpcio_metadata.__version__) +_USER_AGENT = "grpc-python-asyncio/{}".format(_grpcio_metadata.__version__) if sys.version_info[1] < 7: def _all_tasks() -> Iterable[asyncio.Task]: - return asyncio.Task.all_tasks() + return asyncio.Task.all_tasks() # pylint: disable=no-member + else: def _all_tasks() -> Iterable[asyncio.Task]: return asyncio.all_tasks() -def _augment_channel_arguments(base_options: ChannelArgumentType, - compression: Optional[grpc.Compression]): +def _augment_channel_arguments( + base_options: ChannelArgumentType, compression: Optional[grpc.Compression] +): compression_channel_argument = _compression.create_channel_option( - compression) - user_agent_channel_argument = (( - cygrpc.ChannelArgKey.primary_user_agent_string, - _USER_AGENT, - ),) - return tuple(base_options - ) + compression_channel_argument + user_agent_channel_argument + compression + ) + user_agent_channel_argument = ( + ( + cygrpc.ChannelArgKey.primary_user_agent_string, + _USER_AGENT, + ), + ) + return ( + tuple(base_options) + + compression_channel_argument + + user_agent_channel_argument + ) class _BaseMultiCallable: @@ -76,6 +84,7 @@ class _BaseMultiCallable: Handles the initialization logic and stores common attributes. 
""" + _loop: asyncio.AbstractEventLoop _channel: cygrpc.AioChannel _method: bytes @@ -106,21 +115,23 @@ def __init__( @staticmethod def _init_metadata( - metadata: Optional[Metadata] = None, - compression: Optional[grpc.Compression] = None) -> Metadata: + metadata: Optional[Metadata] = None, + compression: Optional[grpc.Compression] = None, + ) -> Metadata: """Based on the provided values for or initialise the final metadata, as it should be used for the current call. """ metadata = metadata or Metadata() if compression: metadata = Metadata( - *_compression.augment_metadata(metadata, compression)) + *_compression.augment_metadata(metadata, compression) + ) return metadata -class UnaryUnaryMultiCallable(_BaseMultiCallable, - _base_channel.UnaryUnaryMultiCallable): - +class UnaryUnaryMultiCallable( + _BaseMultiCallable, _base_channel.UnaryUnaryMultiCallable +): def __call__( self, request: RequestType, @@ -129,29 +140,43 @@ def __call__( metadata: Optional[Metadata] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, - compression: Optional[grpc.Compression] = None + compression: Optional[grpc.Compression] = None, ) -> _base_call.UnaryUnaryCall[RequestType, ResponseType]: - metadata = self._init_metadata(metadata, compression) if not self._interceptors: - call = UnaryUnaryCall(request, _timeout_to_deadline(timeout), - metadata, credentials, wait_for_ready, - self._channel, self._method, - self._request_serializer, - self._response_deserializer, self._loop) + call = UnaryUnaryCall( + request, + _timeout_to_deadline(timeout), + metadata, + credentials, + wait_for_ready, + self._channel, + self._method, + self._request_serializer, + self._response_deserializer, + self._loop, + ) else: call = InterceptedUnaryUnaryCall( - self._interceptors, request, timeout, metadata, credentials, - wait_for_ready, self._channel, self._method, - self._request_serializer, self._response_deserializer, - self._loop) + self._interceptors, + request, + timeout, + metadata, + credentials, + wait_for_ready, + self._channel, + self._method, + self._request_serializer, + self._response_deserializer, + self._loop, + ) return call -class UnaryStreamMultiCallable(_BaseMultiCallable, - _base_channel.UnaryStreamMultiCallable): - +class UnaryStreamMultiCallable( + _BaseMultiCallable, _base_channel.UnaryStreamMultiCallable +): def __call__( self, request: RequestType, @@ -160,30 +185,45 @@ def __call__( metadata: Optional[Metadata] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, - compression: Optional[grpc.Compression] = None + compression: Optional[grpc.Compression] = None, ) -> _base_call.UnaryStreamCall[RequestType, ResponseType]: - metadata = self._init_metadata(metadata, compression) deadline = _timeout_to_deadline(timeout) if not self._interceptors: - call = UnaryStreamCall(request, deadline, metadata, credentials, - wait_for_ready, self._channel, self._method, - self._request_serializer, - self._response_deserializer, self._loop) + call = UnaryStreamCall( + request, + deadline, + metadata, + credentials, + wait_for_ready, + self._channel, + self._method, + self._request_serializer, + self._response_deserializer, + self._loop, + ) else: call = InterceptedUnaryStreamCall( - self._interceptors, request, deadline, metadata, credentials, - wait_for_ready, self._channel, self._method, - self._request_serializer, self._response_deserializer, - self._loop) + self._interceptors, + request, + deadline, + metadata, + credentials, + 
wait_for_ready, + self._channel, + self._method, + self._request_serializer, + self._response_deserializer, + self._loop, + ) return call -class StreamUnaryMultiCallable(_BaseMultiCallable, - _base_channel.StreamUnaryMultiCallable): - +class StreamUnaryMultiCallable( + _BaseMultiCallable, _base_channel.StreamUnaryMultiCallable +): def __call__( self, request_iterator: Optional[RequestIterableType] = None, @@ -191,30 +231,45 @@ def __call__( metadata: Optional[Metadata] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, - compression: Optional[grpc.Compression] = None + compression: Optional[grpc.Compression] = None, ) -> _base_call.StreamUnaryCall: - metadata = self._init_metadata(metadata, compression) deadline = _timeout_to_deadline(timeout) if not self._interceptors: - call = StreamUnaryCall(request_iterator, deadline, metadata, - credentials, wait_for_ready, self._channel, - self._method, self._request_serializer, - self._response_deserializer, self._loop) + call = StreamUnaryCall( + request_iterator, + deadline, + metadata, + credentials, + wait_for_ready, + self._channel, + self._method, + self._request_serializer, + self._response_deserializer, + self._loop, + ) else: call = InterceptedStreamUnaryCall( - self._interceptors, request_iterator, deadline, metadata, - credentials, wait_for_ready, self._channel, self._method, - self._request_serializer, self._response_deserializer, - self._loop) + self._interceptors, + request_iterator, + deadline, + metadata, + credentials, + wait_for_ready, + self._channel, + self._method, + self._request_serializer, + self._response_deserializer, + self._loop, + ) return call -class StreamStreamMultiCallable(_BaseMultiCallable, - _base_channel.StreamStreamMultiCallable): - +class StreamStreamMultiCallable( + _BaseMultiCallable, _base_channel.StreamStreamMultiCallable +): def __call__( self, request_iterator: Optional[RequestIterableType] = None, @@ -222,23 +277,38 @@ def __call__( metadata: Optional[Metadata] = None, credentials: Optional[grpc.CallCredentials] = None, wait_for_ready: Optional[bool] = None, - compression: Optional[grpc.Compression] = None + compression: Optional[grpc.Compression] = None, ) -> _base_call.StreamStreamCall: - metadata = self._init_metadata(metadata, compression) deadline = _timeout_to_deadline(timeout) if not self._interceptors: - call = StreamStreamCall(request_iterator, deadline, metadata, - credentials, wait_for_ready, self._channel, - self._method, self._request_serializer, - self._response_deserializer, self._loop) + call = StreamStreamCall( + request_iterator, + deadline, + metadata, + credentials, + wait_for_ready, + self._channel, + self._method, + self._request_serializer, + self._response_deserializer, + self._loop, + ) else: call = InterceptedStreamStreamCall( - self._interceptors, request_iterator, deadline, metadata, - credentials, wait_for_ready, self._channel, self._method, - self._request_serializer, self._response_deserializer, - self._loop) + self._interceptors, + request_iterator, + deadline, + metadata, + credentials, + wait_for_ready, + self._channel, + self._method, + self._request_serializer, + self._response_deserializer, + self._loop, + ) return call @@ -251,10 +321,14 @@ class Channel(_base_channel.Channel): _stream_unary_interceptors: List[StreamUnaryClientInterceptor] _stream_stream_interceptors: List[StreamStreamClientInterceptor] - def __init__(self, target: str, options: ChannelArgumentType, - credentials: Optional[grpc.ChannelCredentials], - 
compression: Optional[grpc.Compression], - interceptors: Optional[Sequence[ClientInterceptor]]): + def __init__( + self, + target: str, + options: ChannelArgumentType, + credentials: Optional[grpc.ChannelCredentials], + compression: Optional[grpc.Compression], + interceptors: Optional[Sequence[ClientInterceptor]], + ): """Constructor. Args: @@ -283,17 +357,20 @@ def __init__(self, target: str, options: ChannelArgumentType, self._stream_stream_interceptors.append(interceptor) else: raise ValueError( - "Interceptor {} must be ".format(interceptor) + - "{} or ".format(UnaryUnaryClientInterceptor.__name__) + - "{} or ".format(UnaryStreamClientInterceptor.__name__) + - "{} or ".format(StreamUnaryClientInterceptor.__name__) + - "{}. ".format(StreamStreamClientInterceptor.__name__)) + "Interceptor {} must be ".format(interceptor) + + "{} or ".format(UnaryUnaryClientInterceptor.__name__) + + "{} or ".format(UnaryStreamClientInterceptor.__name__) + + "{} or ".format(StreamUnaryClientInterceptor.__name__) + + "{}. ".format(StreamStreamClientInterceptor.__name__) + ) self._loop = cygrpc.get_working_loop() self._channel = cygrpc.AioChannel( _common.encode(target), - _augment_channel_arguments(options, compression), credentials, - self._loop) + _augment_channel_arguments(options, compression), + credentials, + self._loop, + ) async def __aenter__(self): return self @@ -330,7 +407,7 @@ async def _close(self, grace): # pylint: disable=too-many-branches # but not available until 3.9 or 3.8.3. So, we have to keep it # for a while. # TODO(lidiz) drop this hack after 3.8 deprecation - if 'frame' in str(attribute_error): + if "frame" in str(attribute_error): continue else: raise @@ -341,21 +418,22 @@ async def _close(self, grace): # pylint: disable=too-many-branches # Locate ones created by `aio.Call`. 
frame = stack[0] - candidate = frame.f_locals.get('self') + candidate = frame.f_locals.get("self") if candidate: if isinstance(candidate, _base_call.Call): - if hasattr(candidate, '_channel'): + if hasattr(candidate, "_channel"): # For intercepted Call object if candidate._channel is not self._channel: continue - elif hasattr(candidate, '_cython_call'): + elif hasattr(candidate, "_cython_call"): # For normal Call object if candidate._cython_call._channel is not self._channel: continue else: # Unidentified Call object raise cygrpc.InternalError( - f'Unrecognized call object: {candidate}') + f"Unrecognized call object: {candidate}" + ) calls.append(candidate) call_tasks.append(task) @@ -376,12 +454,13 @@ async def close(self, grace: Optional[float] = None): await self._close(grace) def __del__(self): - if hasattr(self, '_channel'): + if hasattr(self, "_channel"): if not self._channel.closed(): self._channel.close() - def get_state(self, - try_to_connect: bool = False) -> grpc.ChannelConnectivity: + def get_state( + self, try_to_connect: bool = False + ) -> grpc.ChannelConnectivity: result = self._channel.check_connectivity_state(try_to_connect) return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[result] @@ -390,7 +469,8 @@ async def wait_for_state_change( last_observed_state: grpc.ChannelConnectivity, ) -> None: assert await self._channel.watch_connectivity_state( - last_observed_state.value[0], None) + last_observed_state.value[0], None + ) async def channel_ready(self) -> None: state = self.get_state(try_to_connect=True) @@ -402,56 +482,73 @@ def unary_unary( self, method: str, request_serializer: Optional[SerializingFunction] = None, - response_deserializer: Optional[DeserializingFunction] = None + response_deserializer: Optional[DeserializingFunction] = None, ) -> UnaryUnaryMultiCallable: - return UnaryUnaryMultiCallable(self._channel, _common.encode(method), - request_serializer, - response_deserializer, - self._unary_unary_interceptors, [self], - self._loop) + return UnaryUnaryMultiCallable( + self._channel, + _common.encode(method), + request_serializer, + response_deserializer, + self._unary_unary_interceptors, + [self], + self._loop, + ) def unary_stream( self, method: str, request_serializer: Optional[SerializingFunction] = None, - response_deserializer: Optional[DeserializingFunction] = None + response_deserializer: Optional[DeserializingFunction] = None, ) -> UnaryStreamMultiCallable: - return UnaryStreamMultiCallable(self._channel, _common.encode(method), - request_serializer, - response_deserializer, - self._unary_stream_interceptors, [self], - self._loop) + return UnaryStreamMultiCallable( + self._channel, + _common.encode(method), + request_serializer, + response_deserializer, + self._unary_stream_interceptors, + [self], + self._loop, + ) def stream_unary( self, method: str, request_serializer: Optional[SerializingFunction] = None, - response_deserializer: Optional[DeserializingFunction] = None + response_deserializer: Optional[DeserializingFunction] = None, ) -> StreamUnaryMultiCallable: - return StreamUnaryMultiCallable(self._channel, _common.encode(method), - request_serializer, - response_deserializer, - self._stream_unary_interceptors, [self], - self._loop) + return StreamUnaryMultiCallable( + self._channel, + _common.encode(method), + request_serializer, + response_deserializer, + self._stream_unary_interceptors, + [self], + self._loop, + ) def stream_stream( self, method: str, request_serializer: Optional[SerializingFunction] = None, - 
response_deserializer: Optional[DeserializingFunction] = None + response_deserializer: Optional[DeserializingFunction] = None, ) -> StreamStreamMultiCallable: - return StreamStreamMultiCallable(self._channel, _common.encode(method), - request_serializer, - response_deserializer, - self._stream_stream_interceptors, - [self], self._loop) + return StreamStreamMultiCallable( + self._channel, + _common.encode(method), + request_serializer, + response_deserializer, + self._stream_stream_interceptors, + [self], + self._loop, + ) def insecure_channel( - target: str, - options: Optional[ChannelArgumentType] = None, - compression: Optional[grpc.Compression] = None, - interceptors: Optional[Sequence[ClientInterceptor]] = None): + target: str, + options: Optional[ChannelArgumentType] = None, + compression: Optional[grpc.Compression] = None, + interceptors: Optional[Sequence[ClientInterceptor]] = None, +): """Creates an insecure asynchronous Channel to a server. Args: @@ -466,15 +563,22 @@ def insecure_channel( Returns: A Channel. """ - return Channel(target, () if options is None else options, None, - compression, interceptors) - - -def secure_channel(target: str, - credentials: grpc.ChannelCredentials, - options: Optional[ChannelArgumentType] = None, - compression: Optional[grpc.Compression] = None, - interceptors: Optional[Sequence[ClientInterceptor]] = None): + return Channel( + target, + () if options is None else options, + None, + compression, + interceptors, + ) + + +def secure_channel( + target: str, + credentials: grpc.ChannelCredentials, + options: Optional[ChannelArgumentType] = None, + compression: Optional[grpc.Compression] = None, + interceptors: Optional[Sequence[ClientInterceptor]] = None, +): """Creates a secure asynchronous Channel to a server. Args: @@ -490,5 +594,10 @@ def secure_channel(target: str, Returns: An aio.Channel. """ - return Channel(target, () if options is None else options, - credentials._credentials, compression, interceptors) + return Channel( + target, + () if options is None else options, + credentials._credentials, + compression, + interceptors, + ) diff --git a/src/python/grpcio/grpc/aio/_interceptor.py b/src/python/grpcio/grpc/aio/_interceptor.py index 05f166e3b0b75..953ed2d18b7be 100644 --- a/src/python/grpcio/grpc/aio/_interceptor.py +++ b/src/python/grpcio/grpc/aio/_interceptor.py @@ -17,8 +17,16 @@ import asyncio import collections import functools -from typing import (AsyncIterable, Awaitable, Callable, Iterator, List, - Optional, Sequence, Union) +from typing import ( + AsyncIterable, + Awaitable, + Callable, + Iterator, + List, + Optional, + Sequence, + Union, +) import grpc from grpc._cython import cygrpc @@ -42,7 +50,7 @@ from ._typing import SerializingFunction from ._utils import _timeout_to_deadline -_LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!' +_LOCAL_CANCELLATION_DETAILS = "Locally cancelled by application!" class ServerInterceptor(metaclass=ABCMeta): @@ -53,9 +61,11 @@ class ServerInterceptor(metaclass=ABCMeta): @abstractmethod async def intercept_service( - self, continuation: Callable[[grpc.HandlerCallDetails], - Awaitable[grpc.RpcMethodHandler]], - handler_call_details: grpc.HandlerCallDetails + self, + continuation: Callable[ + [grpc.HandlerCallDetails], Awaitable[grpc.RpcMethodHandler] + ], + handler_call_details: grpc.HandlerCallDetails, ) -> grpc.RpcMethodHandler: """Intercepts incoming RPCs before handing them over to a handler. 
@@ -74,10 +84,12 @@ async def intercept_service( class ClientCallDetails( - collections.namedtuple( - 'ClientCallDetails', - ('method', 'timeout', 'metadata', 'credentials', 'wait_for_ready')), - grpc.ClientCallDetails): + collections.namedtuple( + "ClientCallDetails", + ("method", "timeout", "metadata", "credentials", "wait_for_ready"), + ), + grpc.ClientCallDetails, +): """Describes an RPC to be invoked. This is an EXPERIMENTAL API. @@ -107,10 +119,13 @@ class UnaryUnaryClientInterceptor(ClientInterceptor, metaclass=ABCMeta): @abstractmethod async def intercept_unary_unary( - self, continuation: Callable[[ClientCallDetails, RequestType], - UnaryUnaryCall], - client_call_details: ClientCallDetails, - request: RequestType) -> Union[UnaryUnaryCall, ResponseType]: + self, + continuation: Callable[ + [ClientCallDetails, RequestType], UnaryUnaryCall + ], + client_call_details: ClientCallDetails, + request: RequestType, + ) -> Union[UnaryUnaryCall, ResponseType]: """Intercepts a unary-unary invocation asynchronously. Args: @@ -140,9 +155,12 @@ class UnaryStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta): @abstractmethod async def intercept_unary_stream( - self, continuation: Callable[[ClientCallDetails, RequestType], - UnaryStreamCall], - client_call_details: ClientCallDetails, request: RequestType + self, + continuation: Callable[ + [ClientCallDetails, RequestType], UnaryStreamCall + ], + client_call_details: ClientCallDetails, + request: RequestType, ) -> Union[ResponseIterableType, UnaryStreamCall]: """Intercepts a unary-stream invocation asynchronously. @@ -178,8 +196,9 @@ class StreamUnaryClientInterceptor(ClientInterceptor, metaclass=ABCMeta): @abstractmethod async def intercept_stream_unary( self, - continuation: Callable[[ClientCallDetails, RequestType], - StreamUnaryCall], + continuation: Callable[ + [ClientCallDetails, RequestType], StreamUnaryCall + ], client_call_details: ClientCallDetails, request_iterator: RequestIterableType, ) -> StreamUnaryCall: @@ -219,8 +238,9 @@ class StreamStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta): @abstractmethod async def intercept_stream_stream( self, - continuation: Callable[[ClientCallDetails, RequestType], - StreamStreamCall], + continuation: Callable[ + [ClientCallDetails, RequestType], StreamStreamCall + ], client_call_details: ClientCallDetails, request_iterator: RequestIterableType, ) -> Union[ResponseIterableType, StreamStreamCall]: @@ -285,14 +305,15 @@ def __init__(self, interceptors_task: asyncio.Task) -> None: self._interceptors_task = interceptors_task self._pending_add_done_callbacks = [] self._interceptors_task.add_done_callback( - self._fire_or_add_pending_done_callbacks) + self._fire_or_add_pending_done_callbacks + ) def __del__(self): self.cancel() def _fire_or_add_pending_done_callbacks( - self, interceptors_task: asyncio.Task) -> None: - + self, interceptors_task: asyncio.Task + ) -> None: if not self._pending_add_done_callbacks: return @@ -310,14 +331,16 @@ def _fire_or_add_pending_done_callbacks( callback(self) else: for callback in self._pending_add_done_callbacks: - callback = functools.partial(self._wrap_add_done_callback, - callback) + callback = functools.partial( + self._wrap_add_done_callback, callback + ) call.add_done_callback(callback) self._pending_add_done_callbacks = [] - def _wrap_add_done_callback(self, callback: DoneCallbackType, - unused_call: _base_call.Call) -> None: + def _wrap_add_done_callback( + self, callback: DoneCallbackType, unused_call: _base_call.Call + ) -> None: 
callback(self) def cancel(self) -> bool: @@ -426,7 +449,7 @@ async def debug_error_string(self) -> Optional[str]: except AioRpcError as err: return err.debug_error_string() except asyncio.CancelledError: - return '' + return "" return await call.debug_error_string() @@ -436,7 +459,6 @@ async def wait_for_connection(self) -> None: class _InterceptedUnaryResponseMixin: - def __await__(self): call = yield from self._interceptors_task.__await__() response = yield from call.__await__() @@ -452,26 +474,28 @@ def _init_stream_response_mixin(self) -> None: self._response_aiter = None async def _wait_for_interceptor_task_response_iterator( - self) -> ResponseType: + self, + ) -> ResponseType: call = await self._interceptors_task async for response in call: yield response def __aiter__(self) -> AsyncIterable[ResponseType]: if self._response_aiter is None: - self._response_aiter = self._wait_for_interceptor_task_response_iterator( + self._response_aiter = ( + self._wait_for_interceptor_task_response_iterator() ) return self._response_aiter async def read(self) -> ResponseType: if self._response_aiter is None: - self._response_aiter = self._wait_for_interceptor_task_response_iterator( + self._response_aiter = ( + self._wait_for_interceptor_task_response_iterator() ) return await self._response_aiter.asend(None) class _InterceptedStreamRequestMixin: - _write_to_iterator_async_gen: Optional[AsyncIterable[RequestType]] _write_to_iterator_queue: Optional[asyncio.Queue] _status_code_task: Optional[asyncio.Task] @@ -479,14 +503,14 @@ class _InterceptedStreamRequestMixin: _FINISH_ITERATOR_SENTINEL = object() def _init_stream_request_mixin( - self, request_iterator: Optional[RequestIterableType] + self, request_iterator: Optional[RequestIterableType] ) -> RequestIterableType: - if request_iterator is None: # We provide our own request iterator which is a proxy - # of the futures writes that will be done by the caller. + # of the futures writes that will be done by the caller. self._write_to_iterator_queue = asyncio.Queue(maxsize=1) - self._write_to_iterator_async_gen = self._proxy_writes_as_request_iterator( + self._write_to_iterator_async_gen = ( + self._proxy_writes_as_request_iterator() ) self._status_code_task = None request_iterator = self._write_to_iterator_async_gen @@ -500,12 +524,16 @@ async def _proxy_writes_as_request_iterator(self): while True: value = await self._write_to_iterator_queue.get() - if value is _InterceptedStreamRequestMixin._FINISH_ITERATOR_SENTINEL: + if ( + value + is _InterceptedStreamRequestMixin._FINISH_ITERATOR_SENTINEL + ): break yield value - async def _write_to_iterator_queue_interruptible(self, request: RequestType, - call: InterceptedCall): + async def _write_to_iterator_queue_interruptible( + self, request: RequestType, call: InterceptedCall + ): # Write the specified 'request' to the request iterator queue using the # specified 'call' to allow for interruption of the write in the case # of abrupt termination of the call. 
@@ -513,9 +541,14 @@ async def _write_to_iterator_queue_interruptible(self, request: RequestType, self._status_code_task = self._loop.create_task(call.code()) await asyncio.wait( - (self._loop.create_task(self._write_to_iterator_queue.put(request)), - self._status_code_task), - return_when=asyncio.FIRST_COMPLETED) + ( + self._loop.create_task( + self._write_to_iterator_queue.put(request) + ), + self._status_code_task, + ), + return_when=asyncio.FIRST_COMPLETED, + ) async def write(self, request: RequestType) -> None: # If no queue was created it means that requests @@ -556,11 +589,13 @@ async def done_writing(self) -> None: raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) await self._write_to_iterator_queue_interruptible( - _InterceptedStreamRequestMixin._FINISH_ITERATOR_SENTINEL, call) + _InterceptedStreamRequestMixin._FINISH_ITERATOR_SENTINEL, call + ) -class InterceptedUnaryUnaryCall(_InterceptedUnaryResponseMixin, InterceptedCall, - _base_call.UnaryUnaryCall): +class InterceptedUnaryUnaryCall( + _InterceptedUnaryResponseMixin, InterceptedCall, _base_call.UnaryUnaryCall +): """Used for running a `UnaryUnaryCall` wrapped by interceptors. For the `__await__` method is it is proxied to the intercepted call only when @@ -571,43 +606,64 @@ class InterceptedUnaryUnaryCall(_InterceptedUnaryResponseMixin, InterceptedCall, _channel: cygrpc.AioChannel # pylint: disable=too-many-arguments - def __init__(self, interceptors: Sequence[UnaryUnaryClientInterceptor], - request: RequestType, timeout: Optional[float], - metadata: Metadata, - credentials: Optional[grpc.CallCredentials], - wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, - method: bytes, request_serializer: SerializingFunction, - response_deserializer: DeserializingFunction, - loop: asyncio.AbstractEventLoop) -> None: + def __init__( + self, + interceptors: Sequence[UnaryUnaryClientInterceptor], + request: RequestType, + timeout: Optional[float], + metadata: Metadata, + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], + channel: cygrpc.AioChannel, + method: bytes, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + loop: asyncio.AbstractEventLoop, + ) -> None: self._loop = loop self._channel = channel interceptors_task = loop.create_task( - self._invoke(interceptors, method, timeout, metadata, credentials, - wait_for_ready, request, request_serializer, - response_deserializer)) + self._invoke( + interceptors, + method, + timeout, + metadata, + credentials, + wait_for_ready, + request, + request_serializer, + response_deserializer, + ) + ) super().__init__(interceptors_task) # pylint: disable=too-many-arguments async def _invoke( - self, interceptors: Sequence[UnaryUnaryClientInterceptor], - method: bytes, timeout: Optional[float], - metadata: Optional[Metadata], - credentials: Optional[grpc.CallCredentials], - wait_for_ready: Optional[bool], request: RequestType, - request_serializer: SerializingFunction, - response_deserializer: DeserializingFunction) -> UnaryUnaryCall: + self, + interceptors: Sequence[UnaryUnaryClientInterceptor], + method: bytes, + timeout: Optional[float], + metadata: Optional[Metadata], + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], + request: RequestType, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + ) -> UnaryUnaryCall: """Run the RPC call wrapped in interceptors""" async def _run_interceptor( - interceptors: List[UnaryUnaryClientInterceptor], 
- client_call_details: ClientCallDetails, - request: RequestType) -> _base_call.UnaryUnaryCall: - + interceptors: List[UnaryUnaryClientInterceptor], + client_call_details: ClientCallDetails, + request: RequestType, + ) -> _base_call.UnaryUnaryCall: if interceptors: - continuation = functools.partial(_run_interceptor, - interceptors[1:]) + continuation = functools.partial( + _run_interceptor, interceptors[1:] + ) call_or_response = await interceptors[0].intercept_unary_unary( - continuation, client_call_details, request) + continuation, client_call_details, request + ) if isinstance(call_or_response, _base_call.UnaryUnaryCall): return call_or_response @@ -616,24 +672,32 @@ async def _run_interceptor( else: return UnaryUnaryCall( - request, _timeout_to_deadline(client_call_details.timeout), + request, + _timeout_to_deadline(client_call_details.timeout), client_call_details.metadata, client_call_details.credentials, - client_call_details.wait_for_ready, self._channel, - client_call_details.method, request_serializer, - response_deserializer, self._loop) - - client_call_details = ClientCallDetails(method, timeout, metadata, - credentials, wait_for_ready) - return await _run_interceptor(list(interceptors), client_call_details, - request) + client_call_details.wait_for_ready, + self._channel, + client_call_details.method, + request_serializer, + response_deserializer, + self._loop, + ) + + client_call_details = ClientCallDetails( + method, timeout, metadata, credentials, wait_for_ready + ) + return await _run_interceptor( + list(interceptors), client_call_details, request + ) def time_remaining(self) -> Optional[float]: raise NotImplementedError() -class InterceptedUnaryStreamCall(_InterceptedStreamResponseMixin, - InterceptedCall, _base_call.UnaryStreamCall): +class InterceptedUnaryStreamCall( + _InterceptedStreamResponseMixin, InterceptedCall, _base_call.UnaryStreamCall +): """Used for running a `UnaryStreamCall` wrapped by interceptors.""" _loop: asyncio.AbstractEventLoop @@ -641,33 +705,52 @@ class InterceptedUnaryStreamCall(_InterceptedStreamResponseMixin, _last_returned_call_from_interceptors = Optional[_base_call.UnaryStreamCall] # pylint: disable=too-many-arguments - def __init__(self, interceptors: Sequence[UnaryStreamClientInterceptor], - request: RequestType, timeout: Optional[float], - metadata: Metadata, - credentials: Optional[grpc.CallCredentials], - wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, - method: bytes, request_serializer: SerializingFunction, - response_deserializer: DeserializingFunction, - loop: asyncio.AbstractEventLoop) -> None: + def __init__( + self, + interceptors: Sequence[UnaryStreamClientInterceptor], + request: RequestType, + timeout: Optional[float], + metadata: Metadata, + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], + channel: cygrpc.AioChannel, + method: bytes, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + loop: asyncio.AbstractEventLoop, + ) -> None: self._loop = loop self._channel = channel self._init_stream_response_mixin() self._last_returned_call_from_interceptors = None interceptors_task = loop.create_task( - self._invoke(interceptors, method, timeout, metadata, credentials, - wait_for_ready, request, request_serializer, - response_deserializer)) + self._invoke( + interceptors, + method, + timeout, + metadata, + credentials, + wait_for_ready, + request, + request_serializer, + response_deserializer, + ) + ) super().__init__(interceptors_task) # pylint: 
disable=too-many-arguments async def _invoke( - self, interceptors: Sequence[UnaryUnaryClientInterceptor], - method: bytes, timeout: Optional[float], - metadata: Optional[Metadata], - credentials: Optional[grpc.CallCredentials], - wait_for_ready: Optional[bool], request: RequestType, - request_serializer: SerializingFunction, - response_deserializer: DeserializingFunction) -> UnaryStreamCall: + self, + interceptors: Sequence[UnaryUnaryClientInterceptor], + method: bytes, + timeout: Optional[float], + metadata: Optional[Metadata], + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], + request: RequestType, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + ) -> UnaryStreamCall: """Run the RPC call wrapped in interceptors""" async def _run_interceptor( @@ -675,46 +758,64 @@ async def _run_interceptor( client_call_details: ClientCallDetails, request: RequestType, ) -> _base_call.UnaryUnaryCall: - if interceptors: - continuation = functools.partial(_run_interceptor, - interceptors[1:]) + continuation = functools.partial( + _run_interceptor, interceptors[1:] + ) call_or_response_iterator = await interceptors[ - 0].intercept_unary_stream(continuation, client_call_details, - request) - - if isinstance(call_or_response_iterator, - _base_call.UnaryStreamCall): - self._last_returned_call_from_interceptors = call_or_response_iterator + 0 + ].intercept_unary_stream( + continuation, client_call_details, request + ) + + if isinstance( + call_or_response_iterator, _base_call.UnaryStreamCall + ): + self._last_returned_call_from_interceptors = ( + call_or_response_iterator + ) else: - self._last_returned_call_from_interceptors = UnaryStreamCallResponseIterator( - self._last_returned_call_from_interceptors, - call_or_response_iterator) + self._last_returned_call_from_interceptors = ( + UnaryStreamCallResponseIterator( + self._last_returned_call_from_interceptors, + call_or_response_iterator, + ) + ) return self._last_returned_call_from_interceptors else: self._last_returned_call_from_interceptors = UnaryStreamCall( - request, _timeout_to_deadline(client_call_details.timeout), + request, + _timeout_to_deadline(client_call_details.timeout), client_call_details.metadata, client_call_details.credentials, - client_call_details.wait_for_ready, self._channel, - client_call_details.method, request_serializer, - response_deserializer, self._loop) + client_call_details.wait_for_ready, + self._channel, + client_call_details.method, + request_serializer, + response_deserializer, + self._loop, + ) return self._last_returned_call_from_interceptors - client_call_details = ClientCallDetails(method, timeout, metadata, - credentials, wait_for_ready) - return await _run_interceptor(list(interceptors), client_call_details, - request) + client_call_details = ClientCallDetails( + method, timeout, metadata, credentials, wait_for_ready + ) + return await _run_interceptor( + list(interceptors), client_call_details, request + ) def time_remaining(self) -> Optional[float]: raise NotImplementedError() -class InterceptedStreamUnaryCall(_InterceptedUnaryResponseMixin, - _InterceptedStreamRequestMixin, - InterceptedCall, _base_call.StreamUnaryCall): +class InterceptedStreamUnaryCall( + _InterceptedUnaryResponseMixin, + _InterceptedStreamRequestMixin, + InterceptedCall, + _base_call.StreamUnaryCall, +): """Used for running a `StreamUnaryCall` wrapped by interceptors. 
For the `__await__` method is it is proxied to the intercepted call only when @@ -725,69 +826,97 @@ class InterceptedStreamUnaryCall(_InterceptedUnaryResponseMixin, _channel: cygrpc.AioChannel # pylint: disable=too-many-arguments - def __init__(self, interceptors: Sequence[StreamUnaryClientInterceptor], - request_iterator: Optional[RequestIterableType], - timeout: Optional[float], metadata: Metadata, - credentials: Optional[grpc.CallCredentials], - wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, - method: bytes, request_serializer: SerializingFunction, - response_deserializer: DeserializingFunction, - loop: asyncio.AbstractEventLoop) -> None: + def __init__( + self, + interceptors: Sequence[StreamUnaryClientInterceptor], + request_iterator: Optional[RequestIterableType], + timeout: Optional[float], + metadata: Metadata, + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], + channel: cygrpc.AioChannel, + method: bytes, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + loop: asyncio.AbstractEventLoop, + ) -> None: self._loop = loop self._channel = channel request_iterator = self._init_stream_request_mixin(request_iterator) interceptors_task = loop.create_task( - self._invoke(interceptors, method, timeout, metadata, credentials, - wait_for_ready, request_iterator, request_serializer, - response_deserializer)) + self._invoke( + interceptors, + method, + timeout, + metadata, + credentials, + wait_for_ready, + request_iterator, + request_serializer, + response_deserializer, + ) + ) super().__init__(interceptors_task) # pylint: disable=too-many-arguments async def _invoke( - self, interceptors: Sequence[StreamUnaryClientInterceptor], - method: bytes, timeout: Optional[float], - metadata: Optional[Metadata], - credentials: Optional[grpc.CallCredentials], - wait_for_ready: Optional[bool], - request_iterator: RequestIterableType, - request_serializer: SerializingFunction, - response_deserializer: DeserializingFunction) -> StreamUnaryCall: + self, + interceptors: Sequence[StreamUnaryClientInterceptor], + method: bytes, + timeout: Optional[float], + metadata: Optional[Metadata], + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], + request_iterator: RequestIterableType, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + ) -> StreamUnaryCall: """Run the RPC call wrapped in interceptors""" async def _run_interceptor( interceptors: Iterator[UnaryUnaryClientInterceptor], client_call_details: ClientCallDetails, - request_iterator: RequestIterableType + request_iterator: RequestIterableType, ) -> _base_call.StreamUnaryCall: - if interceptors: - continuation = functools.partial(_run_interceptor, - interceptors[1:]) + continuation = functools.partial( + _run_interceptor, interceptors[1:] + ) return await interceptors[0].intercept_stream_unary( - continuation, client_call_details, request_iterator) + continuation, client_call_details, request_iterator + ) else: return StreamUnaryCall( request_iterator, _timeout_to_deadline(client_call_details.timeout), client_call_details.metadata, client_call_details.credentials, - client_call_details.wait_for_ready, self._channel, - client_call_details.method, request_serializer, - response_deserializer, self._loop) - - client_call_details = ClientCallDetails(method, timeout, metadata, - credentials, wait_for_ready) - return await _run_interceptor(list(interceptors), client_call_details, - request_iterator) + 
client_call_details.wait_for_ready, + self._channel, + client_call_details.method, + request_serializer, + response_deserializer, + self._loop, + ) + + client_call_details = ClientCallDetails( + method, timeout, metadata, credentials, wait_for_ready + ) + return await _run_interceptor( + list(interceptors), client_call_details, request_iterator + ) def time_remaining(self) -> Optional[float]: raise NotImplementedError() -class InterceptedStreamStreamCall(_InterceptedStreamResponseMixin, - _InterceptedStreamRequestMixin, - InterceptedCall, _base_call.StreamStreamCall): +class InterceptedStreamStreamCall( + _InterceptedStreamResponseMixin, + _InterceptedStreamRequestMixin, + InterceptedCall, + _base_call.StreamStreamCall, +): """Used for running a `StreamStreamCall` wrapped by interceptors.""" _loop: asyncio.AbstractEventLoop @@ -795,59 +924,84 @@ class InterceptedStreamStreamCall(_InterceptedStreamResponseMixin, _last_returned_call_from_interceptors = Optional[_base_call.UnaryStreamCall] # pylint: disable=too-many-arguments - def __init__(self, interceptors: Sequence[StreamStreamClientInterceptor], - request_iterator: Optional[RequestIterableType], - timeout: Optional[float], metadata: Metadata, - credentials: Optional[grpc.CallCredentials], - wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, - method: bytes, request_serializer: SerializingFunction, - response_deserializer: DeserializingFunction, - loop: asyncio.AbstractEventLoop) -> None: + def __init__( + self, + interceptors: Sequence[StreamStreamClientInterceptor], + request_iterator: Optional[RequestIterableType], + timeout: Optional[float], + metadata: Metadata, + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], + channel: cygrpc.AioChannel, + method: bytes, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + loop: asyncio.AbstractEventLoop, + ) -> None: self._loop = loop self._channel = channel self._init_stream_response_mixin() request_iterator = self._init_stream_request_mixin(request_iterator) self._last_returned_call_from_interceptors = None interceptors_task = loop.create_task( - self._invoke(interceptors, method, timeout, metadata, credentials, - wait_for_ready, request_iterator, request_serializer, - response_deserializer)) + self._invoke( + interceptors, + method, + timeout, + metadata, + credentials, + wait_for_ready, + request_iterator, + request_serializer, + response_deserializer, + ) + ) super().__init__(interceptors_task) # pylint: disable=too-many-arguments async def _invoke( - self, interceptors: Sequence[StreamStreamClientInterceptor], - method: bytes, timeout: Optional[float], - metadata: Optional[Metadata], - credentials: Optional[grpc.CallCredentials], - wait_for_ready: Optional[bool], - request_iterator: RequestIterableType, - request_serializer: SerializingFunction, - response_deserializer: DeserializingFunction) -> StreamStreamCall: + self, + interceptors: Sequence[StreamStreamClientInterceptor], + method: bytes, + timeout: Optional[float], + metadata: Optional[Metadata], + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], + request_iterator: RequestIterableType, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + ) -> StreamStreamCall: """Run the RPC call wrapped in interceptors""" async def _run_interceptor( interceptors: List[StreamStreamClientInterceptor], client_call_details: ClientCallDetails, - request_iterator: RequestIterableType + request_iterator: 
RequestIterableType, ) -> _base_call.StreamStreamCall: - if interceptors: - continuation = functools.partial(_run_interceptor, - interceptors[1:]) + continuation = functools.partial( + _run_interceptor, interceptors[1:] + ) call_or_response_iterator = await interceptors[ - 0].intercept_stream_stream(continuation, - client_call_details, - request_iterator) - - if isinstance(call_or_response_iterator, - _base_call.StreamStreamCall): - self._last_returned_call_from_interceptors = call_or_response_iterator + 0 + ].intercept_stream_stream( + continuation, client_call_details, request_iterator + ) + + if isinstance( + call_or_response_iterator, _base_call.StreamStreamCall + ): + self._last_returned_call_from_interceptors = ( + call_or_response_iterator + ) else: - self._last_returned_call_from_interceptors = StreamStreamCallResponseIterator( - self._last_returned_call_from_interceptors, - call_or_response_iterator) + self._last_returned_call_from_interceptors = ( + StreamStreamCallResponseIterator( + self._last_returned_call_from_interceptors, + call_or_response_iterator, + ) + ) return self._last_returned_call_from_interceptors else: self._last_returned_call_from_interceptors = StreamStreamCall( @@ -855,15 +1009,21 @@ async def _run_interceptor( _timeout_to_deadline(client_call_details.timeout), client_call_details.metadata, client_call_details.credentials, - client_call_details.wait_for_ready, self._channel, - client_call_details.method, request_serializer, - response_deserializer, self._loop) + client_call_details.wait_for_ready, + self._channel, + client_call_details.method, + request_serializer, + response_deserializer, + self._loop, + ) return self._last_returned_call_from_interceptors - client_call_details = ClientCallDetails(method, timeout, metadata, - credentials, wait_for_ready) - return await _run_interceptor(list(interceptors), client_call_details, - request_iterator) + client_call_details = ClientCallDetails( + method, timeout, metadata, credentials, wait_for_ready + ) + return await _run_interceptor( + list(interceptors), client_call_details, request_iterator + ) def time_remaining(self) -> Optional[float]: raise NotImplementedError() @@ -871,6 +1031,7 @@ def time_remaining(self) -> Optional[float]: class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall): """Final UnaryUnaryCall class finished with a response.""" + _response: ResponseType def __init__(self, response: ResponseType) -> None: @@ -901,7 +1062,7 @@ async def code(self) -> grpc.StatusCode: return grpc.StatusCode.OK async def details(self) -> str: - return '' + return "" async def debug_error_string(self) -> Optional[str]: return None @@ -918,13 +1079,14 @@ async def wait_for_connection(self) -> None: class _StreamCallResponseIterator: - _call: Union[_base_call.UnaryStreamCall, _base_call.StreamStreamCall] _response_iterator: AsyncIterable[ResponseType] - def __init__(self, call: Union[_base_call.UnaryStreamCall, - _base_call.StreamStreamCall], - response_iterator: AsyncIterable[ResponseType]) -> None: + def __init__( + self, + call: Union[_base_call.UnaryStreamCall, _base_call.StreamStreamCall], + response_iterator: AsyncIterable[ResponseType], + ) -> None: self._response_iterator = response_iterator self._call = call @@ -965,8 +1127,9 @@ async def wait_for_connection(self) -> None: return await self._call.wait_for_connection() -class UnaryStreamCallResponseIterator(_StreamCallResponseIterator, - _base_call.UnaryStreamCall): +class UnaryStreamCallResponseIterator( + _StreamCallResponseIterator, 
_base_call.UnaryStreamCall +): """UnaryStreamCall class wich uses an alternative response iterator.""" async def read(self) -> ResponseType: @@ -975,8 +1138,9 @@ async def read(self) -> ResponseType: raise NotImplementedError() -class StreamStreamCallResponseIterator(_StreamCallResponseIterator, - _base_call.StreamStreamCall): +class StreamStreamCallResponseIterator( + _StreamCallResponseIterator, _base_call.StreamStreamCall +): """StreamStreamCall class wich uses an alternative response iterator.""" async def read(self) -> ResponseType: diff --git a/src/python/grpcio/grpc/aio/_metadata.py b/src/python/grpcio/grpc/aio/_metadata.py index 970f62c0590db..230cb0df7864c 100644 --- a/src/python/grpcio/grpc/aio/_metadata.py +++ b/src/python/grpcio/grpc/aio/_metadata.py @@ -108,7 +108,7 @@ def __eq__(self, other: Any) -> bool: return tuple(self) == other return NotImplemented # pytype: disable=bad-return-type - def __add__(self, other: Any) -> 'Metadata': + def __add__(self, other: Any) -> "Metadata": if isinstance(other, self.__class__): return Metadata(*(tuple(self) + tuple(other))) if isinstance(other, tuple): diff --git a/src/python/grpcio/grpc/aio/_server.py b/src/python/grpcio/grpc/aio/_server.py index 1465ab6bbb0ee..8daa18ddd5155 100644 --- a/src/python/grpcio/grpc/aio/_server.py +++ b/src/python/grpcio/grpc/aio/_server.py @@ -26,8 +26,9 @@ from ._typing import ChannelArgumentType -def _augment_channel_arguments(base_options: ChannelArgumentType, - compression: Optional[grpc.Compression]): +def _augment_channel_arguments( + base_options: ChannelArgumentType, compression: Optional[grpc.Compression] +): compression_option = _compression.create_channel_option(compression) return tuple(base_options) + compression_option @@ -35,30 +36,39 @@ def _augment_channel_arguments(base_options: ChannelArgumentType, class Server(_base_server.Server): """Serves RPCs.""" - def __init__(self, thread_pool: Optional[Executor], - generic_handlers: Optional[Sequence[grpc.GenericRpcHandler]], - interceptors: Optional[Sequence[Any]], - options: ChannelArgumentType, - maximum_concurrent_rpcs: Optional[int], - compression: Optional[grpc.Compression]): + def __init__( + self, + thread_pool: Optional[Executor], + generic_handlers: Optional[Sequence[grpc.GenericRpcHandler]], + interceptors: Optional[Sequence[Any]], + options: ChannelArgumentType, + maximum_concurrent_rpcs: Optional[int], + compression: Optional[grpc.Compression], + ): self._loop = cygrpc.get_working_loop() if interceptors: invalid_interceptors = [ - interceptor for interceptor in interceptors + interceptor + for interceptor in interceptors if not isinstance(interceptor, ServerInterceptor) ] if invalid_interceptors: raise ValueError( - 'Interceptor must be ServerInterceptor, the ' - f'following are invalid: {invalid_interceptors}') + "Interceptor must be ServerInterceptor, the " + f"following are invalid: {invalid_interceptors}" + ) self._server = cygrpc.AioServer( - self._loop, thread_pool, generic_handlers, interceptors, + self._loop, + thread_pool, + generic_handlers, + interceptors, _augment_channel_arguments(options, compression), - maximum_concurrent_rpcs) + maximum_concurrent_rpcs, + ) def add_generic_rpc_handlers( - self, - generic_rpc_handlers: Sequence[grpc.GenericRpcHandler]) -> None: + self, generic_rpc_handlers: Sequence[grpc.GenericRpcHandler] + ) -> None: """Registers GenericRpcHandlers with this Server. This method is only safe to call before the server is started. 
@@ -82,10 +92,12 @@ def add_insecure_port(self, address: str) -> int: An integer port on which the server will accept RPC requests. """ return _common.validate_port_binding_result( - address, self._server.add_insecure_port(_common.encode(address))) + address, self._server.add_insecure_port(_common.encode(address)) + ) - def add_secure_port(self, address: str, - server_credentials: grpc.ServerCredentials) -> int: + def add_secure_port( + self, address: str, server_credentials: grpc.ServerCredentials + ) -> int: """Opens a secure port for accepting RPCs. This method may only be called before starting the server. @@ -101,8 +113,10 @@ def add_secure_port(self, address: str, """ return _common.validate_port_binding_result( address, - self._server.add_secure_port(_common.encode(address), - server_credentials)) + self._server.add_secure_port( + _common.encode(address), server_credentials + ), + ) async def start(self) -> None: """Starts this Server. @@ -135,8 +149,9 @@ async def stop(self, grace: Optional[float]) -> None: """ await self._server.shutdown(grace) - async def wait_for_termination(self, - timeout: Optional[float] = None) -> bool: + async def wait_for_termination( + self, timeout: Optional[float] = None + ) -> bool: """Block current coroutine until the server stops. This is an EXPERIMENTAL API. @@ -165,7 +180,7 @@ def __del__(self): The Cython AioServer doesn't hold a ref-count to this class. It should be safe to slightly extend the underlying Cython object's life span. """ - if hasattr(self, '_server'): + if hasattr(self, "_server"): if self._server.is_running(): cygrpc.schedule_coro_threadsafe( self._server.shutdown(None), @@ -173,12 +188,14 @@ def __del__(self): ) -def server(migration_thread_pool: Optional[Executor] = None, - handlers: Optional[Sequence[grpc.GenericRpcHandler]] = None, - interceptors: Optional[Sequence[Any]] = None, - options: Optional[ChannelArgumentType] = None, - maximum_concurrent_rpcs: Optional[int] = None, - compression: Optional[grpc.Compression] = None): +def server( + migration_thread_pool: Optional[Executor] = None, + handlers: Optional[Sequence[grpc.GenericRpcHandler]] = None, + interceptors: Optional[Sequence[Any]] = None, + options: Optional[ChannelArgumentType] = None, + maximum_concurrent_rpcs: Optional[int] = None, + compression: Optional[grpc.Compression] = None, +): """Creates a Server with which RPCs can be serviced. Args: @@ -203,7 +220,11 @@ def server(migration_thread_pool: Optional[Executor] = None, Returns: A Server object. """ - return Server(migration_thread_pool, () if handlers is None else handlers, - () if interceptors is None else interceptors, - () if options is None else options, maximum_concurrent_rpcs, - compression) + return Server( + migration_thread_pool, + () if handlers is None else handlers, + () if interceptors is None else interceptors, + () if options is None else options, + maximum_concurrent_rpcs, + compression, + ) diff --git a/src/python/grpcio/grpc/aio/_typing.py b/src/python/grpcio/grpc/aio/_typing.py index f9c0eb10fc773..0bc32b22e6fe8 100644 --- a/src/python/grpcio/grpc/aio/_typing.py +++ b/src/python/grpcio/grpc/aio/_typing.py @@ -13,8 +13,16 @@ # limitations under the License. 
"""Common types for gRPC Async API""" -from typing import (Any, AsyncIterable, Callable, Iterable, Sequence, Tuple, - TypeVar, Union) +from typing import ( + Any, + AsyncIterable, + Callable, + Iterable, + Sequence, + Tuple, + TypeVar, + Union, +) from grpc._cython.cygrpc import EOF @@ -22,8 +30,8 @@ from ._metadata import MetadataKey from ._metadata import MetadataValue -RequestType = TypeVar('RequestType') -ResponseType = TypeVar('ResponseType') +RequestType = TypeVar("RequestType") +ResponseType = TypeVar("ResponseType") SerializingFunction = Callable[[Any], bytes] DeserializingFunction = Callable[[bytes], Any] MetadatumType = Tuple[MetadataKey, MetadataValue] diff --git a/src/python/grpcio/grpc/beta/_client_adaptations.py b/src/python/grpcio/grpc/beta/_client_adaptations.py index 652ae0ea171f9..012149212a238 100644 --- a/src/python/grpcio/grpc/beta/_client_adaptations.py +++ b/src/python/grpcio/grpc/beta/_client_adaptations.py @@ -24,14 +24,22 @@ # pylint: disable=too-many-arguments,too-many-locals,unused-argument _STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS = { - grpc.StatusCode.CANCELLED: - (face.Abortion.Kind.CANCELLED, face.CancellationError), - grpc.StatusCode.UNKNOWN: - (face.Abortion.Kind.REMOTE_FAILURE, face.RemoteError), - grpc.StatusCode.DEADLINE_EXCEEDED: - (face.Abortion.Kind.EXPIRED, face.ExpirationError), - grpc.StatusCode.UNIMPLEMENTED: - (face.Abortion.Kind.LOCAL_FAILURE, face.LocalError), + grpc.StatusCode.CANCELLED: ( + face.Abortion.Kind.CANCELLED, + face.CancellationError, + ), + grpc.StatusCode.UNKNOWN: ( + face.Abortion.Kind.REMOTE_FAILURE, + face.RemoteError, + ), + grpc.StatusCode.DEADLINE_EXCEEDED: ( + face.Abortion.Kind.EXPIRED, + face.ExpirationError, + ), + grpc.StatusCode.UNIMPLEMENTED: ( + face.Abortion.Kind.LOCAL_FAILURE, + face.LocalError, + ), } @@ -51,28 +59,33 @@ def _abortion(rpc_error_call): code = rpc_error_call.code() pair = _STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS.get(code) error_kind = face.Abortion.Kind.LOCAL_FAILURE if pair is None else pair[0] - return face.Abortion(error_kind, rpc_error_call.initial_metadata(), - rpc_error_call.trailing_metadata(), code, - rpc_error_call.details()) + return face.Abortion( + error_kind, + rpc_error_call.initial_metadata(), + rpc_error_call.trailing_metadata(), + code, + rpc_error_call.details(), + ) def _abortion_error(rpc_error_call): code = rpc_error_call.code() pair = _STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS.get(code) exception_class = face.AbortionError if pair is None else pair[1] - return exception_class(rpc_error_call.initial_metadata(), - rpc_error_call.trailing_metadata(), code, - rpc_error_call.details()) + return exception_class( + rpc_error_call.initial_metadata(), + rpc_error_call.trailing_metadata(), + code, + rpc_error_call.details(), + ) class _InvocationProtocolContext(interfaces.GRPCInvocationContext): - def disable_next_request_compression(self): pass # TODO(https://github.com/grpc/grpc/issues/4078): design, implement. 
class _Rendezvous(future.Future, face.Call): - def __init__(self, response_future, response_iterator, call): self._future = response_future self._iterator = response_iterator @@ -145,7 +158,6 @@ def time_remaining(self): return self._call.time_remaining() def add_abortion_callback(self, abortion_callback): - def done_callback(): if self.code() is not grpc.StatusCode.OK: abortion_callback(_abortion(self._call)) @@ -169,125 +181,202 @@ def details(self): return self._call.details() -def _blocking_unary_unary(channel, group, method, timeout, with_call, - protocol_options, metadata, metadata_transformer, - request, request_serializer, response_deserializer): +def _blocking_unary_unary( + channel, + group, + method, + timeout, + with_call, + protocol_options, + metadata, + metadata_transformer, + request, + request_serializer, + response_deserializer, +): try: multi_callable = channel.unary_unary( _common.fully_qualified_method(group, method), request_serializer=request_serializer, - response_deserializer=response_deserializer) + response_deserializer=response_deserializer, + ) effective_metadata = _effective_metadata(metadata, metadata_transformer) if with_call: response, call = multi_callable.with_call( request, timeout=timeout, metadata=_metadata.unbeta(effective_metadata), - credentials=_credentials(protocol_options)) + credentials=_credentials(protocol_options), + ) return response, _Rendezvous(None, None, call) else: - return multi_callable(request, - timeout=timeout, - metadata=_metadata.unbeta(effective_metadata), - credentials=_credentials(protocol_options)) + return multi_callable( + request, + timeout=timeout, + metadata=_metadata.unbeta(effective_metadata), + credentials=_credentials(protocol_options), + ) except grpc.RpcError as rpc_error_call: raise _abortion_error(rpc_error_call) -def _future_unary_unary(channel, group, method, timeout, protocol_options, - metadata, metadata_transformer, request, - request_serializer, response_deserializer): +def _future_unary_unary( + channel, + group, + method, + timeout, + protocol_options, + metadata, + metadata_transformer, + request, + request_serializer, + response_deserializer, +): multi_callable = channel.unary_unary( _common.fully_qualified_method(group, method), request_serializer=request_serializer, - response_deserializer=response_deserializer) + response_deserializer=response_deserializer, + ) effective_metadata = _effective_metadata(metadata, metadata_transformer) response_future = multi_callable.future( request, timeout=timeout, metadata=_metadata.unbeta(effective_metadata), - credentials=_credentials(protocol_options)) + credentials=_credentials(protocol_options), + ) return _Rendezvous(response_future, None, response_future) -def _unary_stream(channel, group, method, timeout, protocol_options, metadata, - metadata_transformer, request, request_serializer, - response_deserializer): +def _unary_stream( + channel, + group, + method, + timeout, + protocol_options, + metadata, + metadata_transformer, + request, + request_serializer, + response_deserializer, +): multi_callable = channel.unary_stream( _common.fully_qualified_method(group, method), request_serializer=request_serializer, - response_deserializer=response_deserializer) + response_deserializer=response_deserializer, + ) effective_metadata = _effective_metadata(metadata, metadata_transformer) response_iterator = multi_callable( request, timeout=timeout, metadata=_metadata.unbeta(effective_metadata), - credentials=_credentials(protocol_options)) + 
credentials=_credentials(protocol_options), + ) return _Rendezvous(None, response_iterator, response_iterator) -def _blocking_stream_unary(channel, group, method, timeout, with_call, - protocol_options, metadata, metadata_transformer, - request_iterator, request_serializer, - response_deserializer): +def _blocking_stream_unary( + channel, + group, + method, + timeout, + with_call, + protocol_options, + metadata, + metadata_transformer, + request_iterator, + request_serializer, + response_deserializer, +): try: multi_callable = channel.stream_unary( _common.fully_qualified_method(group, method), request_serializer=request_serializer, - response_deserializer=response_deserializer) + response_deserializer=response_deserializer, + ) effective_metadata = _effective_metadata(metadata, metadata_transformer) if with_call: response, call = multi_callable.with_call( request_iterator, timeout=timeout, metadata=_metadata.unbeta(effective_metadata), - credentials=_credentials(protocol_options)) + credentials=_credentials(protocol_options), + ) return response, _Rendezvous(None, None, call) else: - return multi_callable(request_iterator, - timeout=timeout, - metadata=_metadata.unbeta(effective_metadata), - credentials=_credentials(protocol_options)) + return multi_callable( + request_iterator, + timeout=timeout, + metadata=_metadata.unbeta(effective_metadata), + credentials=_credentials(protocol_options), + ) except grpc.RpcError as rpc_error_call: raise _abortion_error(rpc_error_call) -def _future_stream_unary(channel, group, method, timeout, protocol_options, - metadata, metadata_transformer, request_iterator, - request_serializer, response_deserializer): +def _future_stream_unary( + channel, + group, + method, + timeout, + protocol_options, + metadata, + metadata_transformer, + request_iterator, + request_serializer, + response_deserializer, +): multi_callable = channel.stream_unary( _common.fully_qualified_method(group, method), request_serializer=request_serializer, - response_deserializer=response_deserializer) + response_deserializer=response_deserializer, + ) effective_metadata = _effective_metadata(metadata, metadata_transformer) response_future = multi_callable.future( request_iterator, timeout=timeout, metadata=_metadata.unbeta(effective_metadata), - credentials=_credentials(protocol_options)) + credentials=_credentials(protocol_options), + ) return _Rendezvous(response_future, None, response_future) -def _stream_stream(channel, group, method, timeout, protocol_options, metadata, - metadata_transformer, request_iterator, request_serializer, - response_deserializer): +def _stream_stream( + channel, + group, + method, + timeout, + protocol_options, + metadata, + metadata_transformer, + request_iterator, + request_serializer, + response_deserializer, +): multi_callable = channel.stream_stream( _common.fully_qualified_method(group, method), request_serializer=request_serializer, - response_deserializer=response_deserializer) + response_deserializer=response_deserializer, + ) effective_metadata = _effective_metadata(metadata, metadata_transformer) response_iterator = multi_callable( request_iterator, timeout=timeout, metadata=_metadata.unbeta(effective_metadata), - credentials=_credentials(protocol_options)) + credentials=_credentials(protocol_options), + ) return _Rendezvous(None, response_iterator, response_iterator) class _UnaryUnaryMultiCallable(face.UnaryUnaryMultiCallable): - - def __init__(self, channel, group, method, metadata_transformer, - request_serializer, response_deserializer): + 
def __init__( + self, + channel, + group, + method, + metadata_transformer, + request_serializer, + response_deserializer, + ): self._channel = channel self._group = group self._method = method @@ -295,39 +384,64 @@ def __init__(self, channel, group, method, metadata_transformer, self._request_serializer = request_serializer self._response_deserializer = response_deserializer - def __call__(self, - request, - timeout, - metadata=None, - with_call=False, - protocol_options=None): - return _blocking_unary_unary(self._channel, self._group, self._method, - timeout, with_call, protocol_options, - metadata, self._metadata_transformer, - request, self._request_serializer, - self._response_deserializer) + def __call__( + self, + request, + timeout, + metadata=None, + with_call=False, + protocol_options=None, + ): + return _blocking_unary_unary( + self._channel, + self._group, + self._method, + timeout, + with_call, + protocol_options, + metadata, + self._metadata_transformer, + request, + self._request_serializer, + self._response_deserializer, + ) def future(self, request, timeout, metadata=None, protocol_options=None): - return _future_unary_unary(self._channel, self._group, self._method, - timeout, protocol_options, metadata, - self._metadata_transformer, request, - self._request_serializer, - self._response_deserializer) - - def event(self, - request, - receiver, - abortion_callback, - timeout, - metadata=None, - protocol_options=None): + return _future_unary_unary( + self._channel, + self._group, + self._method, + timeout, + protocol_options, + metadata, + self._metadata_transformer, + request, + self._request_serializer, + self._response_deserializer, + ) + + def event( + self, + request, + receiver, + abortion_callback, + timeout, + metadata=None, + protocol_options=None, + ): raise NotImplementedError() class _UnaryStreamMultiCallable(face.UnaryStreamMultiCallable): - - def __init__(self, channel, group, method, metadata_transformer, - request_serializer, response_deserializer): + def __init__( + self, + channel, + group, + method, + metadata_transformer, + request_serializer, + response_deserializer, + ): self._channel = channel self._group = group self._method = method @@ -336,26 +450,41 @@ def __init__(self, channel, group, method, metadata_transformer, self._response_deserializer = response_deserializer def __call__(self, request, timeout, metadata=None, protocol_options=None): - return _unary_stream(self._channel, self._group, self._method, timeout, - protocol_options, metadata, - self._metadata_transformer, request, - self._request_serializer, - self._response_deserializer) - - def event(self, - request, - receiver, - abortion_callback, - timeout, - metadata=None, - protocol_options=None): + return _unary_stream( + self._channel, + self._group, + self._method, + timeout, + protocol_options, + metadata, + self._metadata_transformer, + request, + self._request_serializer, + self._response_deserializer, + ) + + def event( + self, + request, + receiver, + abortion_callback, + timeout, + metadata=None, + protocol_options=None, + ): raise NotImplementedError() class _StreamUnaryMultiCallable(face.StreamUnaryMultiCallable): - - def __init__(self, channel, group, method, metadata_transformer, - request_serializer, response_deserializer): + def __init__( + self, + channel, + group, + method, + metadata_transformer, + request_serializer, + response_deserializer, + ): self._channel = channel self._group = group self._method = method @@ -363,43 +492,65 @@ def __init__(self, channel, group, 
method, metadata_transformer, self._request_serializer = request_serializer self._response_deserializer = response_deserializer - def __call__(self, - request_iterator, - timeout, - metadata=None, - with_call=False, - protocol_options=None): - return _blocking_stream_unary(self._channel, self._group, self._method, - timeout, with_call, protocol_options, - metadata, self._metadata_transformer, - request_iterator, - self._request_serializer, - self._response_deserializer) - - def future(self, - request_iterator, - timeout, - metadata=None, - protocol_options=None): - return _future_stream_unary(self._channel, self._group, self._method, - timeout, protocol_options, metadata, - self._metadata_transformer, - request_iterator, self._request_serializer, - self._response_deserializer) - - def event(self, - receiver, - abortion_callback, - timeout, - metadata=None, - protocol_options=None): + def __call__( + self, + request_iterator, + timeout, + metadata=None, + with_call=False, + protocol_options=None, + ): + return _blocking_stream_unary( + self._channel, + self._group, + self._method, + timeout, + with_call, + protocol_options, + metadata, + self._metadata_transformer, + request_iterator, + self._request_serializer, + self._response_deserializer, + ) + + def future( + self, request_iterator, timeout, metadata=None, protocol_options=None + ): + return _future_stream_unary( + self._channel, + self._group, + self._method, + timeout, + protocol_options, + metadata, + self._metadata_transformer, + request_iterator, + self._request_serializer, + self._response_deserializer, + ) + + def event( + self, + receiver, + abortion_callback, + timeout, + metadata=None, + protocol_options=None, + ): raise NotImplementedError() class _StreamStreamMultiCallable(face.StreamStreamMultiCallable): - - def __init__(self, channel, group, method, metadata_transformer, - request_serializer, response_deserializer): + def __init__( + self, + channel, + group, + method, + metadata_transformer, + request_serializer, + response_deserializer, + ): self._channel = channel self._group = group self._method = method @@ -407,256 +558,391 @@ def __init__(self, channel, group, method, metadata_transformer, self._request_serializer = request_serializer self._response_deserializer = response_deserializer - def __call__(self, - request_iterator, - timeout, - metadata=None, - protocol_options=None): - return _stream_stream(self._channel, self._group, self._method, timeout, - protocol_options, metadata, - self._metadata_transformer, request_iterator, - self._request_serializer, - self._response_deserializer) - - def event(self, - receiver, - abortion_callback, - timeout, - metadata=None, - protocol_options=None): + def __call__( + self, request_iterator, timeout, metadata=None, protocol_options=None + ): + return _stream_stream( + self._channel, + self._group, + self._method, + timeout, + protocol_options, + metadata, + self._metadata_transformer, + request_iterator, + self._request_serializer, + self._response_deserializer, + ) + + def event( + self, + receiver, + abortion_callback, + timeout, + metadata=None, + protocol_options=None, + ): raise NotImplementedError() class _GenericStub(face.GenericStub): - - def __init__(self, channel, metadata_transformer, request_serializers, - response_deserializers): + def __init__( + self, + channel, + metadata_transformer, + request_serializers, + response_deserializers, + ): self._channel = channel self._metadata_transformer = metadata_transformer self._request_serializers = request_serializers 
or {} self._response_deserializers = response_deserializers or {} - def blocking_unary_unary(self, - group, - method, - request, - timeout, - metadata=None, - with_call=None, - protocol_options=None): - request_serializer = self._request_serializers.get(( - group, - method, - )) - response_deserializer = self._response_deserializers.get(( - group, - method, - )) - return _blocking_unary_unary(self._channel, group, method, timeout, - with_call, protocol_options, metadata, - self._metadata_transformer, request, - request_serializer, response_deserializer) - - def future_unary_unary(self, - group, - method, - request, - timeout, - metadata=None, - protocol_options=None): - request_serializer = self._request_serializers.get(( - group, - method, - )) - response_deserializer = self._response_deserializers.get(( - group, - method, - )) - return _future_unary_unary(self._channel, group, method, timeout, - protocol_options, metadata, - self._metadata_transformer, request, - request_serializer, response_deserializer) - - def inline_unary_stream(self, - group, - method, - request, - timeout, - metadata=None, - protocol_options=None): - request_serializer = self._request_serializers.get(( - group, - method, - )) - response_deserializer = self._response_deserializers.get(( - group, - method, - )) - return _unary_stream(self._channel, group, method, timeout, - protocol_options, metadata, - self._metadata_transformer, request, - request_serializer, response_deserializer) - - def blocking_stream_unary(self, - group, - method, - request_iterator, - timeout, - metadata=None, - with_call=None, - protocol_options=None): - request_serializer = self._request_serializers.get(( + def blocking_unary_unary( + self, + group, + method, + request, + timeout, + metadata=None, + with_call=None, + protocol_options=None, + ): + request_serializer = self._request_serializers.get( + ( + group, + method, + ) + ) + response_deserializer = self._response_deserializers.get( + ( + group, + method, + ) + ) + return _blocking_unary_unary( + self._channel, group, method, - )) - response_deserializer = self._response_deserializers.get(( + timeout, + with_call, + protocol_options, + metadata, + self._metadata_transformer, + request, + request_serializer, + response_deserializer, + ) + + def future_unary_unary( + self, + group, + method, + request, + timeout, + metadata=None, + protocol_options=None, + ): + request_serializer = self._request_serializers.get( + ( + group, + method, + ) + ) + response_deserializer = self._response_deserializers.get( + ( + group, + method, + ) + ) + return _future_unary_unary( + self._channel, group, method, - )) - return _blocking_stream_unary(self._channel, group, method, timeout, - with_call, protocol_options, metadata, - self._metadata_transformer, - request_iterator, request_serializer, - response_deserializer) - - def future_stream_unary(self, - group, - method, - request_iterator, - timeout, - metadata=None, - protocol_options=None): - request_serializer = self._request_serializers.get(( + timeout, + protocol_options, + metadata, + self._metadata_transformer, + request, + request_serializer, + response_deserializer, + ) + + def inline_unary_stream( + self, + group, + method, + request, + timeout, + metadata=None, + protocol_options=None, + ): + request_serializer = self._request_serializers.get( + ( + group, + method, + ) + ) + response_deserializer = self._response_deserializers.get( + ( + group, + method, + ) + ) + return _unary_stream( + self._channel, group, method, - )) - 
response_deserializer = self._response_deserializers.get(( + timeout, + protocol_options, + metadata, + self._metadata_transformer, + request, + request_serializer, + response_deserializer, + ) + + def blocking_stream_unary( + self, + group, + method, + request_iterator, + timeout, + metadata=None, + with_call=None, + protocol_options=None, + ): + request_serializer = self._request_serializers.get( + ( + group, + method, + ) + ) + response_deserializer = self._response_deserializers.get( + ( + group, + method, + ) + ) + return _blocking_stream_unary( + self._channel, group, method, - )) - return _future_stream_unary(self._channel, group, method, timeout, - protocol_options, metadata, - self._metadata_transformer, - request_iterator, request_serializer, - response_deserializer) - - def inline_stream_stream(self, - group, - method, - request_iterator, - timeout, - metadata=None, - protocol_options=None): - request_serializer = self._request_serializers.get(( + timeout, + with_call, + protocol_options, + metadata, + self._metadata_transformer, + request_iterator, + request_serializer, + response_deserializer, + ) + + def future_stream_unary( + self, + group, + method, + request_iterator, + timeout, + metadata=None, + protocol_options=None, + ): + request_serializer = self._request_serializers.get( + ( + group, + method, + ) + ) + response_deserializer = self._response_deserializers.get( + ( + group, + method, + ) + ) + return _future_stream_unary( + self._channel, group, method, - )) - response_deserializer = self._response_deserializers.get(( + timeout, + protocol_options, + metadata, + self._metadata_transformer, + request_iterator, + request_serializer, + response_deserializer, + ) + + def inline_stream_stream( + self, + group, + method, + request_iterator, + timeout, + metadata=None, + protocol_options=None, + ): + request_serializer = self._request_serializers.get( + ( + group, + method, + ) + ) + response_deserializer = self._response_deserializers.get( + ( + group, + method, + ) + ) + return _stream_stream( + self._channel, group, method, - )) - return _stream_stream(self._channel, group, method, timeout, - protocol_options, metadata, - self._metadata_transformer, request_iterator, - request_serializer, response_deserializer) - - def event_unary_unary(self, - group, - method, - request, - receiver, - abortion_callback, - timeout, - metadata=None, - protocol_options=None): + timeout, + protocol_options, + metadata, + self._metadata_transformer, + request_iterator, + request_serializer, + response_deserializer, + ) + + def event_unary_unary( + self, + group, + method, + request, + receiver, + abortion_callback, + timeout, + metadata=None, + protocol_options=None, + ): raise NotImplementedError() - def event_unary_stream(self, - group, - method, - request, - receiver, - abortion_callback, - timeout, - metadata=None, - protocol_options=None): + def event_unary_stream( + self, + group, + method, + request, + receiver, + abortion_callback, + timeout, + metadata=None, + protocol_options=None, + ): raise NotImplementedError() - def event_stream_unary(self, - group, - method, - receiver, - abortion_callback, - timeout, - metadata=None, - protocol_options=None): + def event_stream_unary( + self, + group, + method, + receiver, + abortion_callback, + timeout, + metadata=None, + protocol_options=None, + ): raise NotImplementedError() - def event_stream_stream(self, - group, - method, - receiver, - abortion_callback, - timeout, - metadata=None, - protocol_options=None): + def event_stream_stream( + 
self, + group, + method, + receiver, + abortion_callback, + timeout, + metadata=None, + protocol_options=None, + ): raise NotImplementedError() def unary_unary(self, group, method): - request_serializer = self._request_serializers.get(( - group, - method, - )) - response_deserializer = self._response_deserializers.get(( + request_serializer = self._request_serializers.get( + ( + group, + method, + ) + ) + response_deserializer = self._response_deserializers.get( + ( + group, + method, + ) + ) + return _UnaryUnaryMultiCallable( + self._channel, group, method, - )) - return _UnaryUnaryMultiCallable(self._channel, group, method, - self._metadata_transformer, - request_serializer, - response_deserializer) + self._metadata_transformer, + request_serializer, + response_deserializer, + ) def unary_stream(self, group, method): - request_serializer = self._request_serializers.get(( + request_serializer = self._request_serializers.get( + ( + group, + method, + ) + ) + response_deserializer = self._response_deserializers.get( + ( + group, + method, + ) + ) + return _UnaryStreamMultiCallable( + self._channel, group, method, - )) - response_deserializer = self._response_deserializers.get(( - group, - method, - )) - return _UnaryStreamMultiCallable(self._channel, group, method, - self._metadata_transformer, - request_serializer, - response_deserializer) + self._metadata_transformer, + request_serializer, + response_deserializer, + ) def stream_unary(self, group, method): - request_serializer = self._request_serializers.get(( - group, - method, - )) - response_deserializer = self._response_deserializers.get(( + request_serializer = self._request_serializers.get( + ( + group, + method, + ) + ) + response_deserializer = self._response_deserializers.get( + ( + group, + method, + ) + ) + return _StreamUnaryMultiCallable( + self._channel, group, method, - )) - return _StreamUnaryMultiCallable(self._channel, group, method, - self._metadata_transformer, - request_serializer, - response_deserializer) + self._metadata_transformer, + request_serializer, + response_deserializer, + ) def stream_stream(self, group, method): - request_serializer = self._request_serializers.get(( - group, - method, - )) - response_deserializer = self._response_deserializers.get(( + request_serializer = self._request_serializers.get( + ( + group, + method, + ) + ) + response_deserializer = self._response_deserializers.get( + ( + group, + method, + ) + ) + return _StreamStreamMultiCallable( + self._channel, group, method, - )) - return _StreamStreamMultiCallable(self._channel, group, method, - self._metadata_transformer, - request_serializer, - response_deserializer) + self._metadata_transformer, + request_serializer, + response_deserializer, + ) def __enter__(self): return self @@ -666,7 +952,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): class _DynamicStub(face.DynamicStub): - def __init__(self, backing_generic_stub, group, cardinalities): self._generic_stub = backing_generic_stub self._group = group @@ -683,8 +968,9 @@ def __getattr__(self, attr): elif method_cardinality is cardinality.Cardinality.STREAM_STREAM: return self._generic_stub.stream_stream(self._group, attr) else: - raise AttributeError('_DynamicStub object has no attribute "%s"!' % - attr) + raise AttributeError( + '_DynamicStub object has no attribute "%s"!' 
% attr + ) def __enter__(self): return self @@ -693,14 +979,37 @@ def __exit__(self, exc_type, exc_val, exc_tb): return False -def generic_stub(channel, host, metadata_transformer, request_serializers, - response_deserializers): - return _GenericStub(channel, metadata_transformer, request_serializers, - response_deserializers) - - -def dynamic_stub(channel, service, cardinalities, host, metadata_transformer, - request_serializers, response_deserializers): +def generic_stub( + channel, + host, + metadata_transformer, + request_serializers, + response_deserializers, +): + return _GenericStub( + channel, + metadata_transformer, + request_serializers, + response_deserializers, + ) + + +def dynamic_stub( + channel, + service, + cardinalities, + host, + metadata_transformer, + request_serializers, + response_deserializers, +): return _DynamicStub( - _GenericStub(channel, metadata_transformer, request_serializers, - response_deserializers), service, cardinalities) + _GenericStub( + channel, + metadata_transformer, + request_serializers, + response_deserializers, + ), + service, + cardinalities, + ) diff --git a/src/python/grpcio/grpc/beta/_metadata.py b/src/python/grpcio/grpc/beta/_metadata.py index b7c85352853d0..301010878d67f 100644 --- a/src/python/grpcio/grpc/beta/_metadata.py +++ b/src/python/grpcio/grpc/beta/_metadata.py @@ -15,24 +15,27 @@ import collections -_Metadatum = collections.namedtuple('_Metadatum', ( - 'key', - 'value', -)) +_Metadatum = collections.namedtuple( + "_Metadatum", + ( + "key", + "value", + ), +) def _beta_metadatum(key, value): - beta_key = key if isinstance(key, (bytes,)) else key.encode('ascii') - beta_value = value if isinstance(value, (bytes,)) else value.encode('ascii') + beta_key = key if isinstance(key, (bytes,)) else key.encode("ascii") + beta_value = value if isinstance(value, (bytes,)) else value.encode("ascii") return _Metadatum(beta_key, beta_value) def _metadatum(beta_key, beta_value): - key = beta_key if isinstance(beta_key, (str,)) else beta_key.decode('utf8') - if isinstance(beta_value, (str,)) or key[-4:] == '-bin': + key = beta_key if isinstance(beta_key, (str,)) else beta_key.decode("utf8") + if isinstance(beta_value, (str,)) or key[-4:] == "-bin": value = beta_value else: - value = beta_value.decode('utf8') + value = beta_value.decode("utf8") return _Metadatum(key, value) @@ -49,4 +52,5 @@ def unbeta(beta_metadata): else: return tuple( _metadatum(beta_key, beta_value) - for beta_key, beta_value in beta_metadata) + for beta_key, beta_value in beta_metadata + ) diff --git a/src/python/grpcio/grpc/beta/_server_adaptations.py b/src/python/grpcio/grpc/beta/_server_adaptations.py index 8843a3c5502e1..a6f730bb29bf6 100644 --- a/src/python/grpcio/grpc/beta/_server_adaptations.py +++ b/src/python/grpcio/grpc/beta/_server_adaptations.py @@ -33,7 +33,6 @@ class _ServerProtocolContext(interfaces.GRPCServicerContext): - def __init__(self, servicer_context): self._servicer_context = servicer_context @@ -45,7 +44,6 @@ def disable_next_response_compression(self): class _FaceServicerContext(face.ServicerContext): - def __init__(self, servicer_context): self._servicer_context = servicer_context @@ -57,7 +55,8 @@ def time_remaining(self): def add_abortion_callback(self, abortion_callback): raise NotImplementedError( - 'add_abortion_callback no longer supported server-side!') + "add_abortion_callback no longer supported server-side!" 
+ ) def cancel(self): self._servicer_context.cancel() @@ -70,11 +69,13 @@ def invocation_metadata(self): def initial_metadata(self, initial_metadata): self._servicer_context.send_initial_metadata( - _metadata.unbeta(initial_metadata)) + _metadata.unbeta(initial_metadata) + ) def terminal_metadata(self, terminal_metadata): self._servicer_context.set_terminal_metadata( - _metadata.unbeta(terminal_metadata)) + _metadata.unbeta(terminal_metadata) + ) def code(self, code): self._servicer_context.set_code(code) @@ -84,25 +85,24 @@ def details(self, details): def _adapt_unary_request_inline(unary_request_inline): - def adaptation(request, servicer_context): - return unary_request_inline(request, - _FaceServicerContext(servicer_context)) + return unary_request_inline( + request, _FaceServicerContext(servicer_context) + ) return adaptation def _adapt_stream_request_inline(stream_request_inline): - def adaptation(request_iterator, servicer_context): - return stream_request_inline(request_iterator, - _FaceServicerContext(servicer_context)) + return stream_request_inline( + request_iterator, _FaceServicerContext(servicer_context) + ) return adaptation class _Callback(stream.Consumer): - def __init__(self): self._condition = threading.Condition() self._values = [] @@ -155,8 +155,9 @@ def draw_all_values(self): self._condition.wait() -def _run_request_pipe_thread(request_iterator, request_consumer, - servicer_context): +def _run_request_pipe_thread( + request_iterator, request_consumer, servicer_context +): thread_joined = threading.Event() def pipe_requests(): @@ -174,26 +175,28 @@ def pipe_requests(): def _adapt_unary_unary_event(unary_unary_event): - def adaptation(request, servicer_context): callback = _Callback() if not servicer_context.add_callback(callback.cancel): raise abandonment.Abandoned() - unary_unary_event(request, callback.consume_and_terminate, - _FaceServicerContext(servicer_context)) + unary_unary_event( + request, + callback.consume_and_terminate, + _FaceServicerContext(servicer_context), + ) return callback.draw_all_values()[0] return adaptation def _adapt_unary_stream_event(unary_stream_event): - def adaptation(request, servicer_context): callback = _Callback() if not servicer_context.add_callback(callback.cancel): raise abandonment.Abandoned() - unary_stream_event(request, callback, - _FaceServicerContext(servicer_context)) + unary_stream_event( + request, callback, _FaceServicerContext(servicer_context) + ) while True: response = callback.draw_one_value() if response is None: @@ -205,31 +208,33 @@ def adaptation(request, servicer_context): def _adapt_stream_unary_event(stream_unary_event): - def adaptation(request_iterator, servicer_context): callback = _Callback() if not servicer_context.add_callback(callback.cancel): raise abandonment.Abandoned() request_consumer = stream_unary_event( callback.consume_and_terminate, - _FaceServicerContext(servicer_context)) - _run_request_pipe_thread(request_iterator, request_consumer, - servicer_context) + _FaceServicerContext(servicer_context), + ) + _run_request_pipe_thread( + request_iterator, request_consumer, servicer_context + ) return callback.draw_all_values()[0] return adaptation def _adapt_stream_stream_event(stream_stream_event): - def adaptation(request_iterator, servicer_context): callback = _Callback() if not servicer_context.add_callback(callback.cancel): raise abandonment.Abandoned() request_consumer = stream_stream_event( - callback, _FaceServicerContext(servicer_context)) - _run_request_pipe_thread(request_iterator, 
request_consumer, - servicer_context) + callback, _FaceServicerContext(servicer_context) + ) + _run_request_pipe_thread( + request_iterator, request_consumer, servicer_context + ) while True: response = callback.draw_one_value() if response is None: @@ -241,66 +246,125 @@ def adaptation(request_iterator, servicer_context): class _SimpleMethodHandler( - collections.namedtuple('_MethodHandler', ( - 'request_streaming', - 'response_streaming', - 'request_deserializer', - 'response_serializer', - 'unary_unary', - 'unary_stream', - 'stream_unary', - 'stream_stream', - )), grpc.RpcMethodHandler): + collections.namedtuple( + "_MethodHandler", + ( + "request_streaming", + "response_streaming", + "request_deserializer", + "response_serializer", + "unary_unary", + "unary_stream", + "stream_unary", + "stream_stream", + ), + ), + grpc.RpcMethodHandler, +): pass -def _simple_method_handler(implementation, request_deserializer, - response_serializer): +def _simple_method_handler( + implementation, request_deserializer, response_serializer +): if implementation.style is style.Service.INLINE: if implementation.cardinality is cardinality.Cardinality.UNARY_UNARY: return _SimpleMethodHandler( - False, False, request_deserializer, response_serializer, + False, + False, + request_deserializer, + response_serializer, _adapt_unary_request_inline(implementation.unary_unary_inline), - None, None, None) + None, + None, + None, + ) elif implementation.cardinality is cardinality.Cardinality.UNARY_STREAM: return _SimpleMethodHandler( - False, True, request_deserializer, response_serializer, None, + False, + True, + request_deserializer, + response_serializer, + None, _adapt_unary_request_inline(implementation.unary_stream_inline), - None, None) + None, + None, + ) elif implementation.cardinality is cardinality.Cardinality.STREAM_UNARY: return _SimpleMethodHandler( - True, False, request_deserializer, response_serializer, None, + True, + False, + request_deserializer, + response_serializer, + None, None, _adapt_stream_request_inline( - implementation.stream_unary_inline), None) - elif implementation.cardinality is cardinality.Cardinality.STREAM_STREAM: + implementation.stream_unary_inline + ), + None, + ) + elif ( + implementation.cardinality is cardinality.Cardinality.STREAM_STREAM + ): return _SimpleMethodHandler( - True, True, request_deserializer, response_serializer, None, - None, None, + True, + True, + request_deserializer, + response_serializer, + None, + None, + None, _adapt_stream_request_inline( - implementation.stream_stream_inline)) + implementation.stream_stream_inline + ), + ) elif implementation.style is style.Service.EVENT: if implementation.cardinality is cardinality.Cardinality.UNARY_UNARY: return _SimpleMethodHandler( - False, False, request_deserializer, response_serializer, + False, + False, + request_deserializer, + response_serializer, _adapt_unary_unary_event(implementation.unary_unary_event), - None, None, None) + None, + None, + None, + ) elif implementation.cardinality is cardinality.Cardinality.UNARY_STREAM: return _SimpleMethodHandler( - False, True, request_deserializer, response_serializer, None, + False, + True, + request_deserializer, + response_serializer, + None, _adapt_unary_stream_event(implementation.unary_stream_event), - None, None) + None, + None, + ) elif implementation.cardinality is cardinality.Cardinality.STREAM_UNARY: return _SimpleMethodHandler( - True, False, request_deserializer, response_serializer, None, + True, + False, + request_deserializer, + response_serializer, 
+ None, None, _adapt_stream_unary_event(implementation.stream_unary_event), - None) - elif implementation.cardinality is cardinality.Cardinality.STREAM_STREAM: + None, + ) + elif ( + implementation.cardinality is cardinality.Cardinality.STREAM_STREAM + ): return _SimpleMethodHandler( - True, True, request_deserializer, response_serializer, None, - None, None, - _adapt_stream_stream_event(implementation.stream_stream_event)) + True, + True, + request_deserializer, + response_serializer, + None, + None, + None, + _adapt_stream_stream_event(implementation.stream_stream_event), + ) raise ValueError() @@ -314,36 +378,44 @@ def _flatten_method_pair_map(method_pair_map): class _GenericRpcHandler(grpc.GenericRpcHandler): - - def __init__(self, method_implementations, multi_method_implementation, - request_deserializers, response_serializers): + def __init__( + self, + method_implementations, + multi_method_implementation, + request_deserializers, + response_serializers, + ): self._method_implementations = _flatten_method_pair_map( - method_implementations) + method_implementations + ) self._request_deserializers = _flatten_method_pair_map( - request_deserializers) + request_deserializers + ) self._response_serializers = _flatten_method_pair_map( - response_serializers) + response_serializers + ) self._multi_method_implementation = multi_method_implementation def service(self, handler_call_details): method_implementation = self._method_implementations.get( - handler_call_details.method) + handler_call_details.method + ) if method_implementation is not None: return _simple_method_handler( method_implementation, self._request_deserializers.get(handler_call_details.method), - self._response_serializers.get(handler_call_details.method)) + self._response_serializers.get(handler_call_details.method), + ) elif self._multi_method_implementation is None: return None else: try: - return None #TODO(nathaniel): call the multimethod. + return None # TODO(nathaniel): call the multimethod. 
except face.NoSuchMethodError: return None class _Server(interfaces.Server): - def __init__(self, grpc_server): self._grpc_server = grpc_server @@ -368,13 +440,20 @@ def __exit__(self, exc_type, exc_val, exc_tb): return False -def server(service_implementations, multi_method_implementation, - request_deserializers, response_serializers, thread_pool, - thread_pool_size): - generic_rpc_handler = _GenericRpcHandler(service_implementations, - multi_method_implementation, - request_deserializers, - response_serializers) +def server( + service_implementations, + multi_method_implementation, + request_deserializers, + response_serializers, + thread_pool, + thread_pool_size, +): + generic_rpc_handler = _GenericRpcHandler( + service_implementations, + multi_method_implementation, + request_deserializers, + response_serializers, + ) if thread_pool is None: effective_thread_pool = logging_pool.pool( _DEFAULT_POOL_SIZE if thread_pool_size is None else thread_pool_size @@ -382,4 +461,5 @@ def server(service_implementations, multi_method_implementation, else: effective_thread_pool = thread_pool return _Server( - grpc.server(effective_thread_pool, handlers=(generic_rpc_handler,))) + grpc.server(effective_thread_pool, handlers=(generic_rpc_handler,)) + ) diff --git a/src/python/grpcio/grpc/beta/implementations.py b/src/python/grpcio/grpc/beta/implementations.py index 43312aac7c825..ffa4f0d4bfe00 100644 --- a/src/python/grpcio/grpc/beta/implementations.py +++ b/src/python/grpcio/grpc/beta/implementations.py @@ -25,8 +25,7 @@ from grpc.beta import _server_adaptations from grpc.beta import interfaces # pylint: disable=unused-import from grpc.framework.common import cardinality # pylint: disable=unused-import -from grpc.framework.interfaces.face import \ - face # pylint: disable=unused-import +from grpc.framework.interfaces.face import face # pylint: disable=unused-import # pylint: disable=too-many-arguments @@ -36,9 +35,7 @@ def metadata_call_credentials(metadata_plugin, name=None): - def plugin(context, callback): - def wrapped_callback(beta_metadata, error): callback(_metadata.unbeta(beta_metadata), error) @@ -50,12 +47,12 @@ def wrapped_callback(beta_metadata, error): def google_call_credentials(credentials): """Construct CallCredentials from GoogleCredentials. - Args: - credentials: A GoogleCredentials object from the oauth2client library. + Args: + credentials: A GoogleCredentials object from the oauth2client library. - Returns: - A CallCredentials object for use in a GRPCCallOptions object. - """ + Returns: + A CallCredentials object for use in a GRPCCallOptions object. + """ return metadata_call_credentials(_auth.GoogleCallCredentials(credentials)) @@ -67,10 +64,10 @@ def google_call_credentials(credentials): class Channel(object): """A channel to a remote host through which RPCs may be conducted. - Only the "subscribe" and "unsubscribe" methods are supported for application - use. This class' instance constructor and all other attributes are - unsupported. - """ + Only the "subscribe" and "unsubscribe" methods are supported for application + use. This class' instance constructor and all other attributes are + unsupported. + """ def __init__(self, channel): self._channel = channel @@ -78,71 +75,80 @@ def __init__(self, channel): def subscribe(self, callback, try_to_connect=None): """Subscribes to this Channel's connectivity. - Args: - callback: A callable to be invoked and passed an - interfaces.ChannelConnectivity identifying this Channel's connectivity. 
- The callable will be invoked immediately upon subscription and again for - every change to this Channel's connectivity thereafter until it is - unsubscribed. - try_to_connect: A boolean indicating whether or not this Channel should - attempt to connect if it is not already connected and ready to conduct - RPCs. - """ + Args: + callback: A callable to be invoked and passed an + interfaces.ChannelConnectivity identifying this Channel's connectivity. + The callable will be invoked immediately upon subscription and again for + every change to this Channel's connectivity thereafter until it is + unsubscribed. + try_to_connect: A boolean indicating whether or not this Channel should + attempt to connect if it is not already connected and ready to conduct + RPCs. + """ self._channel.subscribe(callback, try_to_connect=try_to_connect) def unsubscribe(self, callback): """Unsubscribes a callback from this Channel's connectivity. - Args: - callback: A callable previously registered with this Channel from having - been passed to its "subscribe" method. - """ + Args: + callback: A callable previously registered with this Channel from having + been passed to its "subscribe" method. + """ self._channel.unsubscribe(callback) def insecure_channel(host, port): """Creates an insecure Channel to a remote host. - Args: - host: The name of the remote host to which to connect. - port: The port of the remote host to which to connect. - If None only the 'host' part will be used. + Args: + host: The name of the remote host to which to connect. + port: The port of the remote host to which to connect. + If None only the 'host' part will be used. - Returns: - A Channel to the remote host through which RPCs may be conducted. - """ - channel = grpc.insecure_channel(host if port is None else '%s:%d' % - (host, port)) + Returns: + A Channel to the remote host through which RPCs may be conducted. + """ + channel = grpc.insecure_channel( + host if port is None else "%s:%d" % (host, port) + ) return Channel(channel) def secure_channel(host, port, channel_credentials): """Creates a secure Channel to a remote host. - Args: - host: The name of the remote host to which to connect. - port: The port of the remote host to which to connect. - If None only the 'host' part will be used. - channel_credentials: A ChannelCredentials. + Args: + host: The name of the remote host to which to connect. + port: The port of the remote host to which to connect. + If None only the 'host' part will be used. + channel_credentials: A ChannelCredentials. - Returns: - A secure Channel to the remote host through which RPCs may be conducted. - """ + Returns: + A secure Channel to the remote host through which RPCs may be conducted. + """ channel = grpc.secure_channel( - host if port is None else '%s:%d' % (host, port), channel_credentials) + host if port is None else "%s:%d" % (host, port), channel_credentials + ) return Channel(channel) class StubOptions(object): """A value encapsulating the various options for creation of a Stub. - This class and its instances have no supported interface - it exists to define - the type of its instances and its instances exist to be passed to other - functions. - """ + This class and its instances have no supported interface - it exists to define + the type of its instances and its instances exist to be passed to other + functions. 
+ """ - def __init__(self, host, request_serializers, response_deserializers, - metadata_transformer, thread_pool, thread_pool_size): + def __init__( + self, + host, + request_serializers, + response_deserializers, + metadata_transformer, + thread_pool, + thread_pool_size, + ): self.host = host self.request_serializers = request_serializers self.response_deserializers = response_deserializers @@ -154,69 +160,78 @@ def __init__(self, host, request_serializers, response_deserializers, _EMPTY_STUB_OPTIONS = StubOptions(None, None, None, None, None, None) -def stub_options(host=None, - request_serializers=None, - response_deserializers=None, - metadata_transformer=None, - thread_pool=None, - thread_pool_size=None): +def stub_options( + host=None, + request_serializers=None, + response_deserializers=None, + metadata_transformer=None, + thread_pool=None, + thread_pool_size=None, +): """Creates a StubOptions value to be passed at stub creation. - All parameters are optional and should always be passed by keyword. + All parameters are optional and should always be passed by keyword. - Args: - host: A host string to set on RPC calls. - request_serializers: A dictionary from service name-method name pair to - request serialization behavior. - response_deserializers: A dictionary from service name-method name pair to - response deserialization behavior. - metadata_transformer: A callable that given a metadata object produces - another metadata object to be used in the underlying communication on the - wire. - thread_pool: A thread pool to use in stubs. - thread_pool_size: The size of thread pool to create for use in stubs; - ignored if thread_pool has been passed. - - Returns: - A StubOptions value created from the passed parameters. - """ - return StubOptions(host, request_serializers, response_deserializers, - metadata_transformer, thread_pool, thread_pool_size) + Args: + host: A host string to set on RPC calls. + request_serializers: A dictionary from service name-method name pair to + request serialization behavior. + response_deserializers: A dictionary from service name-method name pair to + response deserialization behavior. + metadata_transformer: A callable that given a metadata object produces + another metadata object to be used in the underlying communication on the + wire. + thread_pool: A thread pool to use in stubs. + thread_pool_size: The size of thread pool to create for use in stubs; + ignored if thread_pool has been passed. + + Returns: + A StubOptions value created from the passed parameters. + """ + return StubOptions( + host, + request_serializers, + response_deserializers, + metadata_transformer, + thread_pool, + thread_pool_size, + ) def generic_stub(channel, options=None): """Creates a face.GenericStub on which RPCs can be made. - Args: - channel: A Channel for use by the created stub. - options: A StubOptions customizing the created stub. + Args: + channel: A Channel for use by the created stub. + options: A StubOptions customizing the created stub. - Returns: - A face.GenericStub on which RPCs can be made. - """ + Returns: + A face.GenericStub on which RPCs can be made. 
+ """ effective_options = _EMPTY_STUB_OPTIONS if options is None else options return _client_adaptations.generic_stub( channel._channel, # pylint: disable=protected-access effective_options.host, effective_options.metadata_transformer, effective_options.request_serializers, - effective_options.response_deserializers) + effective_options.response_deserializers, + ) def dynamic_stub(channel, service, cardinalities, options=None): """Creates a face.DynamicStub with which RPCs can be invoked. - Args: - channel: A Channel for the returned face.DynamicStub to use. - service: The package-qualified full name of the service. - cardinalities: A dictionary from RPC method name to cardinality.Cardinality - value identifying the cardinality of the RPC method. - options: An optional StubOptions value further customizing the functionality - of the returned face.DynamicStub. - - Returns: - A face.DynamicStub with which RPCs can be invoked. - """ + Args: + channel: A Channel for the returned face.DynamicStub to use. + service: The package-qualified full name of the service. + cardinalities: A dictionary from RPC method name to cardinality.Cardinality + value identifying the cardinality of the RPC method. + options: An optional StubOptions value further customizing the functionality + of the returned face.DynamicStub. + + Returns: + A face.DynamicStub with which RPCs can be invoked. + """ effective_options = _EMPTY_STUB_OPTIONS if options is None else options return _client_adaptations.dynamic_stub( channel._channel, # pylint: disable=protected-access @@ -225,7 +240,8 @@ def dynamic_stub(channel, service, cardinalities, options=None): effective_options.host, effective_options.metadata_transformer, effective_options.request_serializers, - effective_options.response_deserializers) + effective_options.response_deserializers, + ) ServerCredentials = grpc.ServerCredentials @@ -235,14 +251,21 @@ def dynamic_stub(channel, service, cardinalities, options=None): class ServerOptions(object): """A value encapsulating the various options for creation of a Server. - This class and its instances have no supported interface - it exists to define - the type of its instances and its instances exist to be passed to other - functions. - """ + This class and its instances have no supported interface - it exists to define + the type of its instances and its instances exist to be passed to other + functions. 
+ """ - def __init__(self, multi_method_implementation, request_deserializers, - response_serializers, thread_pool, thread_pool_size, - default_timeout, maximum_timeout): + def __init__( + self, + multi_method_implementation, + request_deserializers, + response_serializers, + thread_pool, + thread_pool_size, + default_timeout, + maximum_timeout, + ): self.multi_method_implementation = multi_method_implementation self.request_deserializers = request_deserializers self.response_serializers = response_serializers @@ -255,57 +278,68 @@ def __init__(self, multi_method_implementation, request_deserializers, _EMPTY_SERVER_OPTIONS = ServerOptions(None, None, None, None, None, None, None) -def server_options(multi_method_implementation=None, - request_deserializers=None, - response_serializers=None, - thread_pool=None, - thread_pool_size=None, - default_timeout=None, - maximum_timeout=None): +def server_options( + multi_method_implementation=None, + request_deserializers=None, + response_serializers=None, + thread_pool=None, + thread_pool_size=None, + default_timeout=None, + maximum_timeout=None, +): """Creates a ServerOptions value to be passed at server creation. - All parameters are optional and should always be passed by keyword. - - Args: - multi_method_implementation: A face.MultiMethodImplementation to be called - to service an RPC if the server has no specific method implementation for - the name of the RPC for which service was requested. - request_deserializers: A dictionary from service name-method name pair to - request deserialization behavior. - response_serializers: A dictionary from service name-method name pair to - response serialization behavior. - thread_pool: A thread pool to use in stubs. - thread_pool_size: The size of thread pool to create for use in stubs; - ignored if thread_pool has been passed. - default_timeout: A duration in seconds to allow for RPC service when - servicing RPCs that did not include a timeout value when invoked. - maximum_timeout: A duration in seconds to allow for RPC service when - servicing RPCs no matter what timeout value was passed when the RPC was - invoked. - - Returns: - A StubOptions value created from the passed parameters. - """ - return ServerOptions(multi_method_implementation, request_deserializers, - response_serializers, thread_pool, thread_pool_size, - default_timeout, maximum_timeout) + All parameters are optional and should always be passed by keyword. + + Args: + multi_method_implementation: A face.MultiMethodImplementation to be called + to service an RPC if the server has no specific method implementation for + the name of the RPC for which service was requested. + request_deserializers: A dictionary from service name-method name pair to + request deserialization behavior. + response_serializers: A dictionary from service name-method name pair to + response serialization behavior. + thread_pool: A thread pool to use in stubs. + thread_pool_size: The size of thread pool to create for use in stubs; + ignored if thread_pool has been passed. + default_timeout: A duration in seconds to allow for RPC service when + servicing RPCs that did not include a timeout value when invoked. + maximum_timeout: A duration in seconds to allow for RPC service when + servicing RPCs no matter what timeout value was passed when the RPC was + invoked. + + Returns: + A StubOptions value created from the passed parameters. 
+ """ + return ServerOptions( + multi_method_implementation, + request_deserializers, + response_serializers, + thread_pool, + thread_pool_size, + default_timeout, + maximum_timeout, + ) def server(service_implementations, options=None): """Creates an interfaces.Server with which RPCs can be serviced. - Args: - service_implementations: A dictionary from service name-method name pair to - face.MethodImplementation. - options: An optional ServerOptions value further customizing the - functionality of the returned Server. + Args: + service_implementations: A dictionary from service name-method name pair to + face.MethodImplementation. + options: An optional ServerOptions value further customizing the + functionality of the returned Server. - Returns: - An interfaces.Server with which RPCs can be serviced. - """ + Returns: + An interfaces.Server with which RPCs can be serviced. + """ effective_options = _EMPTY_SERVER_OPTIONS if options is None else options return _server_adaptations.server( - service_implementations, effective_options.multi_method_implementation, + service_implementations, + effective_options.multi_method_implementation, effective_options.request_deserializers, - effective_options.response_serializers, effective_options.thread_pool, - effective_options.thread_pool_size) + effective_options.response_serializers, + effective_options.thread_pool, + effective_options.thread_pool_size, + ) diff --git a/src/python/grpcio/grpc/beta/interfaces.py b/src/python/grpcio/grpc/beta/interfaces.py index e29b173a4347f..c29b291585491 100644 --- a/src/python/grpcio/grpc/beta/interfaces.py +++ b/src/python/grpcio/grpc/beta/interfaces.py @@ -27,10 +27,10 @@ class GRPCCallOptions(object): """A value encapsulating gRPC-specific options passed on RPC invocation. - This class and its instances have no supported interface - it exists to - define the type of its instances and its instances exist to be passed to - other functions. - """ + This class and its instances have no supported interface - it exists to + define the type of its instances and its instances exist to be passed to + other functions. + """ def __init__(self, disable_compression, subcall_of, credentials): self.disable_compression = disable_compression @@ -41,14 +41,14 @@ def __init__(self, disable_compression, subcall_of, credentials): def grpc_call_options(disable_compression=False, credentials=None): """Creates a GRPCCallOptions value to be passed at RPC invocation. - All parameters are optional and should always be passed by keyword. + All parameters are optional and should always be passed by keyword. - Args: - disable_compression: A boolean indicating whether or not compression should - be disabled for the request object of the RPC. Only valid for - request-unary RPCs. - credentials: A CallCredentials object to use for the invoked RPC. - """ + Args: + disable_compression: A boolean indicating whether or not compression should + be disabled for the request object of the RPC. Only valid for + request-unary RPCs. + credentials: A CallCredentials object to use for the invoked RPC. + """ return GRPCCallOptions(disable_compression, None, credentials) @@ -64,9 +64,9 @@ class GRPCServicerContext(abc.ABC): def peer(self): """Identifies the peer that invoked the RPC being serviced. - Returns: - A string identifying the peer that invoked the RPC being serviced. - """ + Returns: + A string identifying the peer that invoked the RPC being serviced. 
+ """ raise NotImplementedError() @abc.abstractmethod @@ -91,73 +91,73 @@ class Server(abc.ABC): def add_insecure_port(self, address): """Reserves a port for insecure RPC service once this Server becomes active. - This method may only be called before calling this Server's start method is - called. + This method may only be called before calling this Server's start method is + called. - Args: - address: The address for which to open a port. + Args: + address: The address for which to open a port. - Returns: - An integer port on which RPCs will be serviced after this link has been - started. This is typically the same number as the port number contained - in the passed address, but will likely be different if the port number - contained in the passed address was zero. - """ + Returns: + An integer port on which RPCs will be serviced after this link has been + started. This is typically the same number as the port number contained + in the passed address, but will likely be different if the port number + contained in the passed address was zero. + """ raise NotImplementedError() @abc.abstractmethod def add_secure_port(self, address, server_credentials): """Reserves a port for secure RPC service after this Server becomes active. - This method may only be called before calling this Server's start method is - called. + This method may only be called before calling this Server's start method is + called. - Args: - address: The address for which to open a port. - server_credentials: A ServerCredentials. - - Returns: - An integer port on which RPCs will be serviced after this link has been - started. This is typically the same number as the port number contained - in the passed address, but will likely be different if the port number - contained in the passed address was zero. - """ + Args: + address: The address for which to open a port. + server_credentials: A ServerCredentials. + + Returns: + An integer port on which RPCs will be serviced after this link has been + started. This is typically the same number as the port number contained + in the passed address, but will likely be different if the port number + contained in the passed address was zero. + """ raise NotImplementedError() @abc.abstractmethod def start(self): """Starts this Server's service of RPCs. - This method may only be called while the server is not serving RPCs (i.e. it - is not idempotent). - """ + This method may only be called while the server is not serving RPCs (i.e. it + is not idempotent). + """ raise NotImplementedError() @abc.abstractmethod def stop(self, grace): """Stops this Server's service of RPCs. - All calls to this method immediately stop service of new RPCs. When existing - RPCs are aborted is controlled by the grace period parameter passed to this - method. - - This method may be called at any time and is idempotent. Passing a smaller - grace value than has been passed in a previous call will have the effect of - stopping the Server sooner. Passing a larger grace value than has been - passed in a previous call will not have the effect of stopping the server - later. - - Args: - grace: A duration of time in seconds to allow existing RPCs to complete - before being aborted by this Server's stopping. May be zero for - immediate abortion of all in-progress RPCs. - - Returns: - A threading.Event that will be set when this Server has completely - stopped. 
The returned event may not be set until after the full grace - period (if some ongoing RPC continues for the full length of the period) - of it may be set much sooner (such as if this Server had no RPCs underway - at the time it was stopped or if all RPCs that it had underway completed - very early in the grace period). - """ + All calls to this method immediately stop service of new RPCs. When existing + RPCs are aborted is controlled by the grace period parameter passed to this + method. + + This method may be called at any time and is idempotent. Passing a smaller + grace value than has been passed in a previous call will have the effect of + stopping the Server sooner. Passing a larger grace value than has been + passed in a previous call will not have the effect of stopping the server + later. + + Args: + grace: A duration of time in seconds to allow existing RPCs to complete + before being aborted by this Server's stopping. May be zero for + immediate abortion of all in-progress RPCs. + + Returns: + A threading.Event that will be set when this Server has completely + stopped. The returned event may not be set until after the full grace + period (if some ongoing RPC continues for the full length of the period) + of it may be set much sooner (such as if this Server had no RPCs underway + at the time it was stopped or if all RPCs that it had underway completed + very early in the grace period). + """ raise NotImplementedError() diff --git a/src/python/grpcio/grpc/beta/utilities.py b/src/python/grpcio/grpc/beta/utilities.py index fe3ce606c9491..90e54715cff72 100644 --- a/src/python/grpcio/grpc/beta/utilities.py +++ b/src/python/grpcio/grpc/beta/utilities.py @@ -23,11 +23,11 @@ from grpc.framework.foundation import future _DONE_CALLBACK_EXCEPTION_LOG_MESSAGE = ( - 'Exception calling connectivity future "done" callback!') + 'Exception calling connectivity future "done" callback!' +) class _ChannelReadyFuture(future.Future): - def __init__(self, channel): self._condition = threading.Condition() self._channel = channel @@ -56,8 +56,10 @@ def _block(self, timeout): def _update(self, connectivity): with self._condition: - if (not self._cancelled and - connectivity is interfaces.ChannelConnectivity.READY): + if ( + not self._cancelled + and connectivity is interfaces.ChannelConnectivity.READY + ): self._matured = True self._channel.unsubscribe(self._update) self._condition.notify_all() @@ -68,7 +70,8 @@ def _update(self, connectivity): for done_callback in done_callbacks: callable_util.call_logging_exceptions( - done_callback, _DONE_CALLBACK_EXCEPTION_LOG_MESSAGE, self) + done_callback, _DONE_CALLBACK_EXCEPTION_LOG_MESSAGE, self + ) def cancel(self): with self._condition: @@ -83,7 +86,8 @@ def cancel(self): for done_callback in done_callbacks: callable_util.call_logging_exceptions( - done_callback, _DONE_CALLBACK_EXCEPTION_LOG_MESSAGE, self) + done_callback, _DONE_CALLBACK_EXCEPTION_LOG_MESSAGE, self + ) return True @@ -132,18 +136,18 @@ def __del__(self): def channel_ready_future(channel): """Creates a future.Future tracking when an implementations.Channel is ready. - Cancelling the returned future.Future does not tell the given - implementations.Channel to abandon attempts it may have been making to - connect; cancelling merely deactivates the return future.Future's - subscription to the given implementations.Channel's connectivity. 
+ Cancelling the returned future.Future does not tell the given + implementations.Channel to abandon attempts it may have been making to + connect; cancelling merely deactivates the return future.Future's + subscription to the given implementations.Channel's connectivity. - Args: - channel: An implementations.Channel. + Args: + channel: An implementations.Channel. - Returns: - A future.Future that matures when the given Channel has connectivity - interfaces.ChannelConnectivity.READY. - """ + Returns: + A future.Future that matures when the given Channel has connectivity + interfaces.ChannelConnectivity.READY. + """ ready_future = _ChannelReadyFuture(channel) ready_future.start() return ready_future diff --git a/src/python/grpcio/grpc/experimental/__init__.py b/src/python/grpcio/grpc/experimental/__init__.py index f0d142c981e43..32a53bf7f3434 100644 --- a/src/python/grpcio/grpc/experimental/__init__.py +++ b/src/python/grpcio/grpc/experimental/__init__.py @@ -30,11 +30,12 @@ class ChannelOptions(object): """Indicates a channel option unique to gRPC Python. - This enumeration is part of an EXPERIMENTAL API. + This enumeration is part of an EXPERIMENTAL API. - Attributes: - SingleThreadedUnaryStream: Perform unary-stream RPCs on a single thread. + Attributes: + SingleThreadedUnaryStream: Perform unary-stream RPCs on a single thread. """ + SingleThreadedUnaryStream = "SingleThreadedUnaryStream" @@ -45,7 +46,8 @@ class UsageError(Exception): # It's important that there be a single insecure credentials object so that its # hash is deterministic and can be used for indexing in the simple stubs cache. _insecure_channel_credentials = grpc.ChannelCredentials( - _cygrpc.channel_credentials_insecure()) + _cygrpc.channel_credentials_insecure() +) def insecure_channel_credentials(): @@ -63,14 +65,16 @@ class ExperimentalApiWarning(Warning): def _warn_experimental(api_name, stack_offset): if api_name not in _EXPERIMENTAL_APIS_USED: _EXPERIMENTAL_APIS_USED.add(api_name) - msg = ("'{}' is an experimental API. It is subject to change or ". - format(api_name) + - "removal between minor releases. Proceed with caution.") + msg = ( + "'{}' is an experimental API. It is subject to change or ".format( + api_name + ) + + "removal between minor releases. Proceed with caution." 
+ ) warnings.warn(msg, ExperimentalApiWarning, stacklevel=2 + stack_offset) def experimental_api(f): - @functools.wraps(f) def _wrapper(*args, **kwargs): _warn_experimental(f.__name__, 1) @@ -109,15 +113,16 @@ def wrap_server_method_handler(wrapper, handler): return handler._replace(stream_unary=wrapper(handler.stream_unary)) else: return handler._replace( - stream_stream=wrapper(handler.stream_stream)) + stream_stream=wrapper(handler.stream_stream) + ) __all__ = ( - 'ChannelOptions', - 'ExperimentalApiWarning', - 'UsageError', - 'insecure_channel_credentials', - 'wrap_server_method_handler', + "ChannelOptions", + "ExperimentalApiWarning", + "UsageError", + "insecure_channel_credentials", + "wrap_server_method_handler", ) if sys.version_info > (3, 6): @@ -125,4 +130,5 @@ def wrap_server_method_handler(wrapper, handler): from grpc._simple_stubs import stream_unary from grpc._simple_stubs import unary_stream from grpc._simple_stubs import unary_unary + __all__ = __all__ + (unary_unary, unary_stream, stream_unary, stream_stream) diff --git a/src/python/grpcio/grpc/framework/common/cardinality.py b/src/python/grpcio/grpc/framework/common/cardinality.py index c98735622d789..3d3d4d3427cb8 100644 --- a/src/python/grpcio/grpc/framework/common/cardinality.py +++ b/src/python/grpcio/grpc/framework/common/cardinality.py @@ -20,7 +20,7 @@ class Cardinality(enum.Enum): """Describes the streaming semantics of an RPC method.""" - UNARY_UNARY = 'request-unary/response-unary' - UNARY_STREAM = 'request-unary/response-streaming' - STREAM_UNARY = 'request-streaming/response-unary' - STREAM_STREAM = 'request-streaming/response-streaming' + UNARY_UNARY = "request-unary/response-unary" + UNARY_STREAM = "request-unary/response-streaming" + STREAM_UNARY = "request-streaming/response-unary" + STREAM_STREAM = "request-streaming/response-streaming" diff --git a/src/python/grpcio/grpc/framework/common/style.py b/src/python/grpcio/grpc/framework/common/style.py index f6138d417ff3f..10bf5f17697af 100644 --- a/src/python/grpcio/grpc/framework/common/style.py +++ b/src/python/grpcio/grpc/framework/common/style.py @@ -20,5 +20,5 @@ class Service(enum.Enum): """Describes the control flow style of RPC method implementation.""" - INLINE = 'inline' - EVENT = 'event' + INLINE = "inline" + EVENT = "event" diff --git a/src/python/grpcio/grpc/framework/foundation/abandonment.py b/src/python/grpcio/grpc/framework/foundation/abandonment.py index 660ce991c418d..c4cb7d5c07254 100644 --- a/src/python/grpcio/grpc/framework/foundation/abandonment.py +++ b/src/python/grpcio/grpc/framework/foundation/abandonment.py @@ -17,6 +17,6 @@ class Abandoned(Exception): """Indicates that some computation is being abandoned. - Abandoning a computation is different than returning a value or raising - an exception indicating some operational or programming defect. - """ + Abandoning a computation is different than returning a value or raising + an exception indicating some operational or programming defect. + """ diff --git a/src/python/grpcio/grpc/framework/foundation/callable_util.py b/src/python/grpcio/grpc/framework/foundation/callable_util.py index 0a638eb62e8ed..b64131b40294f 100644 --- a/src/python/grpcio/grpc/framework/foundation/callable_util.py +++ b/src/python/grpcio/grpc/framework/foundation/callable_util.py @@ -25,14 +25,14 @@ class Outcome(ABC): """A sum type describing the outcome of some call. - Attributes: - kind: One of Kind.RETURNED or Kind.RAISED respectively indicating that the - call returned a value or raised an exception. 
- return_value: The value returned by the call. Must be present if kind is - Kind.RETURNED. - exception: The exception raised by the call. Must be present if kind is - Kind.RAISED. - """ + Attributes: + kind: One of Kind.RETURNED or Kind.RAISED respectively indicating that the + call returned a value or raised an exception. + return_value: The value returned by the call. Must be present if kind is + Kind.RETURNED. + exception: The exception raised by the call. Must be present if kind is + Kind.RAISED. + """ @enum.unique class Kind(enum.Enum): @@ -43,15 +43,19 @@ class Kind(enum.Enum): class _EasyOutcome( - collections.namedtuple('_EasyOutcome', - ['kind', 'return_value', 'exception']), Outcome): + collections.namedtuple( + "_EasyOutcome", ["kind", "return_value", "exception"] + ), + Outcome, +): """A trivial implementation of Outcome.""" def _call_logging_exceptions(behavior, message, *args, **kwargs): try: - return _EasyOutcome(Outcome.Kind.RETURNED, behavior(*args, **kwargs), - None) + return _EasyOutcome( + Outcome.Kind.RETURNED, behavior(*args, **kwargs), None + ) except Exception as e: # pylint: disable=broad-except _LOGGER.exception(message) return _EasyOutcome(Outcome.Kind.RAISED, None, e) @@ -60,16 +64,16 @@ def _call_logging_exceptions(behavior, message, *args, **kwargs): def with_exceptions_logged(behavior, message): """Wraps a callable in a try-except that logs any exceptions it raises. - Args: - behavior: Any callable. - message: A string to log if the behavior raises an exception. + Args: + behavior: Any callable. + message: A string to log if the behavior raises an exception. - Returns: - A callable that when executed invokes the given behavior. The returned - callable takes the same arguments as the given behavior but returns a - future.Outcome describing whether the given behavior returned a value or - raised an exception. - """ + Returns: + A callable that when executed invokes the given behavior. The returned + callable takes the same arguments as the given behavior but returns a + future.Outcome describing whether the given behavior returned a value or + raised an exception. + """ @functools.wraps(behavior) def wrapped_behavior(*args, **kwargs): @@ -81,14 +85,14 @@ def wrapped_behavior(*args, **kwargs): def call_logging_exceptions(behavior, message, *args, **kwargs): """Calls a behavior in a try-except that logs any exceptions it raises. - Args: - behavior: Any callable. - message: A string to log if the behavior raises an exception. - *args: Positional arguments to pass to the given behavior. - **kwargs: Keyword arguments to pass to the given behavior. + Args: + behavior: Any callable. + message: A string to log if the behavior raises an exception. + *args: Positional arguments to pass to the given behavior. + **kwargs: Keyword arguments to pass to the given behavior. - Returns: - An Outcome describing whether the given behavior returned a value or raised - an exception. - """ + Returns: + An Outcome describing whether the given behavior returned a value or raised + an exception. + """ return _call_logging_exceptions(behavior, message, *args, **kwargs) diff --git a/src/python/grpcio/grpc/framework/foundation/future.py b/src/python/grpcio/grpc/framework/foundation/future.py index c7996aa8a5655..73b0d0bdbe118 100644 --- a/src/python/grpcio/grpc/framework/foundation/future.py +++ b/src/python/grpcio/grpc/framework/foundation/future.py @@ -45,9 +45,9 @@ class CancelledError(Exception): class Future(abc.ABC): """A representation of a computation in another control flow. 
- Computations represented by a Future may be yet to be begun, may be ongoing, - or may have already completed. - """ + Computations represented by a Future may be yet to be begun, may be ongoing, + or may have already completed. + """ # NOTE(nathaniel): This isn't the return type that I would want to have if it # were up to me. Were this interface being written from scratch, the return @@ -63,17 +63,17 @@ class Future(abc.ABC): def cancel(self): """Attempts to cancel the computation. - This method does not block. - - Returns: - True if the computation has not yet begun, will not be allowed to take - place, and determination of both was possible without blocking. False - under all other circumstances including but not limited to the - computation's already having begun, the computation's already having - finished, and the computation's having been scheduled for execution on a - remote system for which a determination of whether or not it commenced - before being cancelled cannot be made without blocking. - """ + This method does not block. + + Returns: + True if the computation has not yet begun, will not be allowed to take + place, and determination of both was possible without blocking. False + under all other circumstances including but not limited to the + computation's already having begun, the computation's already having + finished, and the computation's having been scheduled for execution on a + remote system for which a determination of whether or not it commenced + before being cancelled cannot be made without blocking. + """ raise NotImplementedError() # NOTE(nathaniel): Here too this isn't the return type that I'd want this @@ -94,27 +94,27 @@ def cancel(self): def cancelled(self): """Describes whether the computation was cancelled. - This method does not block. + This method does not block. - Returns: - True if the computation was cancelled any time before its result became - immediately available. False under all other circumstances including but - not limited to this object's cancel method not having been called and - the computation's result having become immediately available. - """ + Returns: + True if the computation was cancelled any time before its result became + immediately available. False under all other circumstances including but + not limited to this object's cancel method not having been called and + the computation's result having become immediately available. + """ raise NotImplementedError() @abc.abstractmethod def running(self): """Describes whether the computation is taking place. - This method does not block. + This method does not block. - Returns: - True if the computation is scheduled to take place in the future or is - taking place now, or False if the computation took place in the past or - was cancelled. - """ + Returns: + True if the computation is scheduled to take place in the future or is + taking place now, or False if the computation took place in the past or + was cancelled. + """ raise NotImplementedError() # NOTE(nathaniel): These aren't quite the semantics I'd like here either. I @@ -125,95 +125,95 @@ def running(self): def done(self): """Describes whether the computation has taken place. - This method does not block. + This method does not block. - Returns: - True if the computation is known to have either completed or have been - unscheduled or interrupted. False if the computation may possibly be - executing or scheduled to execute later. 
- """ + Returns: + True if the computation is known to have either completed or have been + unscheduled or interrupted. False if the computation may possibly be + executing or scheduled to execute later. + """ raise NotImplementedError() @abc.abstractmethod def result(self, timeout=None): """Accesses the outcome of the computation or raises its exception. - This method may return immediately or may block. + This method may return immediately or may block. - Args: - timeout: The length of time in seconds to wait for the computation to - finish or be cancelled, or None if this method should block until the - computation has finished or is cancelled no matter how long that takes. + Args: + timeout: The length of time in seconds to wait for the computation to + finish or be cancelled, or None if this method should block until the + computation has finished or is cancelled no matter how long that takes. - Returns: - The return value of the computation. + Returns: + The return value of the computation. - Raises: - TimeoutError: If a timeout value is passed and the computation does not - terminate within the allotted time. - CancelledError: If the computation was cancelled. - Exception: If the computation raised an exception, this call will raise - the same exception. - """ + Raises: + TimeoutError: If a timeout value is passed and the computation does not + terminate within the allotted time. + CancelledError: If the computation was cancelled. + Exception: If the computation raised an exception, this call will raise + the same exception. + """ raise NotImplementedError() @abc.abstractmethod def exception(self, timeout=None): """Return the exception raised by the computation. - This method may return immediately or may block. + This method may return immediately or may block. - Args: - timeout: The length of time in seconds to wait for the computation to - terminate or be cancelled, or None if this method should block until - the computation is terminated or is cancelled no matter how long that - takes. + Args: + timeout: The length of time in seconds to wait for the computation to + terminate or be cancelled, or None if this method should block until + the computation is terminated or is cancelled no matter how long that + takes. - Returns: - The exception raised by the computation, or None if the computation did - not raise an exception. + Returns: + The exception raised by the computation, or None if the computation did + not raise an exception. - Raises: - TimeoutError: If a timeout value is passed and the computation does not - terminate within the allotted time. - CancelledError: If the computation was cancelled. - """ + Raises: + TimeoutError: If a timeout value is passed and the computation does not + terminate within the allotted time. + CancelledError: If the computation was cancelled. + """ raise NotImplementedError() @abc.abstractmethod def traceback(self, timeout=None): """Access the traceback of the exception raised by the computation. - This method may return immediately or may block. + This method may return immediately or may block. - Args: - timeout: The length of time in seconds to wait for the computation to - terminate or be cancelled, or None if this method should block until - the computation is terminated or is cancelled no matter how long that - takes. 
+ Args: + timeout: The length of time in seconds to wait for the computation to + terminate or be cancelled, or None if this method should block until + the computation is terminated or is cancelled no matter how long that + takes. - Returns: - The traceback of the exception raised by the computation, or None if the - computation did not raise an exception. + Returns: + The traceback of the exception raised by the computation, or None if the + computation did not raise an exception. - Raises: - TimeoutError: If a timeout value is passed and the computation does not - terminate within the allotted time. - CancelledError: If the computation was cancelled. - """ + Raises: + TimeoutError: If a timeout value is passed and the computation does not + terminate within the allotted time. + CancelledError: If the computation was cancelled. + """ raise NotImplementedError() @abc.abstractmethod def add_done_callback(self, fn): """Adds a function to be called at completion of the computation. - The callback will be passed this Future object describing the outcome of - the computation. + The callback will be passed this Future object describing the outcome of + the computation. - If the computation has already completed, the callback will be called - immediately. + If the computation has already completed, the callback will be called + immediately. - Args: - fn: A callable taking this Future object as its single parameter. - """ + Args: + fn: A callable taking this Future object as its single parameter. + """ raise NotImplementedError() diff --git a/src/python/grpcio/grpc/framework/foundation/logging_pool.py b/src/python/grpcio/grpc/framework/foundation/logging_pool.py index 53d2cd008254c..a4e140f174de5 100644 --- a/src/python/grpcio/grpc/framework/foundation/logging_pool.py +++ b/src/python/grpcio/grpc/framework/foundation/logging_pool.py @@ -27,8 +27,9 @@ def _wrapping(*args, **kwargs): return behavior(*args, **kwargs) except Exception: _LOGGER.exception( - 'Unexpected exception from %s executed in logging pool!', - behavior) + "Unexpected exception from %s executed in logging pool!", + behavior, + ) raise return _wrapping @@ -50,9 +51,9 @@ def submit(self, fn, *args, **kwargs): return self._backing_pool.submit(_wrap(fn), *args, **kwargs) def map(self, func, *iterables, **kwargs): - return self._backing_pool.map(_wrap(func), - *iterables, - timeout=kwargs.get('timeout', None)) + return self._backing_pool.map( + _wrap(func), *iterables, timeout=kwargs.get("timeout", None) + ) def shutdown(self, wait=True): self._backing_pool.shutdown(wait=wait) @@ -61,11 +62,11 @@ def shutdown(self, wait=True): def pool(max_workers): """Creates a thread pool that logs exceptions raised by the tasks within it. - Args: - max_workers: The maximum number of worker threads to allow the pool. + Args: + max_workers: The maximum number of worker threads to allow the pool. - Returns: - A futures.ThreadPoolExecutor-compatible thread pool that logs exceptions - raised by the tasks executed within it. - """ + Returns: + A futures.ThreadPoolExecutor-compatible thread pool that logs exceptions + raised by the tasks executed within it. 
+ """ return _LoggingPool(futures.ThreadPoolExecutor(max_workers)) diff --git a/src/python/grpcio/grpc/framework/foundation/stream.py b/src/python/grpcio/grpc/framework/foundation/stream.py index 150a22435eedb..70ca1d915756a 100644 --- a/src/python/grpcio/grpc/framework/foundation/stream.py +++ b/src/python/grpcio/grpc/framework/foundation/stream.py @@ -23,9 +23,9 @@ class Consumer(abc.ABC): def consume(self, value): """Accepts a value. - Args: - value: Any value accepted by this Consumer. - """ + Args: + value: Any value accepted by this Consumer. + """ raise NotImplementedError() @abc.abstractmethod @@ -37,7 +37,7 @@ def terminate(self): def consume_and_terminate(self, value): """Supplies a value and signals that no more values will be supplied. - Args: - value: Any value accepted by this Consumer. - """ + Args: + value: Any value accepted by this Consumer. + """ raise NotImplementedError() diff --git a/src/python/grpcio/grpc/framework/interfaces/base/base.py b/src/python/grpcio/grpc/framework/interfaces/base/base.py index 8caee325c2c50..d1c0b07911693 100644 --- a/src/python/grpcio/grpc/framework/interfaces/base/base.py +++ b/src/python/grpcio/grpc/framework/interfaces/base/base.py @@ -56,37 +56,37 @@ def __init__(self, code, details): class Outcome(object): """The outcome of an operation. - Attributes: - kind: A Kind value coarsely identifying how the operation terminated. - code: An application-specific code value or None if no such value was - provided. - details: An application-specific details value or None if no such value was - provided. - """ + Attributes: + kind: A Kind value coarsely identifying how the operation terminated. + code: An application-specific code value or None if no such value was + provided. + details: An application-specific details value or None if no such value was + provided. + """ @enum.unique class Kind(enum.Enum): """Ways in which an operation can terminate.""" - COMPLETED = 'completed' - CANCELLED = 'cancelled' - EXPIRED = 'expired' - LOCAL_SHUTDOWN = 'local shutdown' - REMOTE_SHUTDOWN = 'remote shutdown' - RECEPTION_FAILURE = 'reception failure' - TRANSMISSION_FAILURE = 'transmission failure' - LOCAL_FAILURE = 'local failure' - REMOTE_FAILURE = 'remote failure' + COMPLETED = "completed" + CANCELLED = "cancelled" + EXPIRED = "expired" + LOCAL_SHUTDOWN = "local shutdown" + REMOTE_SHUTDOWN = "remote shutdown" + RECEPTION_FAILURE = "reception failure" + TRANSMISSION_FAILURE = "transmission failure" + LOCAL_FAILURE = "local failure" + REMOTE_FAILURE = "remote failure" class Completion(abc.ABC): """An aggregate of the values exchanged upon operation completion. - Attributes: - terminal_metadata: A terminal metadata value for the operaton. - code: A code value for the operation. - message: A message value for the operation. - """ + Attributes: + terminal_metadata: A terminal metadata value for the operaton. + code: A code value for the operation. + message: A message value for the operation. + """ class OperationContext(abc.ABC): @@ -96,37 +96,37 @@ class OperationContext(abc.ABC): def outcome(self): """Indicates the operation's outcome (or that the operation is ongoing). - Returns: - None if the operation is still active or the Outcome value for the - operation if it has terminated. - """ + Returns: + None if the operation is still active or the Outcome value for the + operation if it has terminated. + """ raise NotImplementedError() @abc.abstractmethod def add_termination_callback(self, callback): """Adds a function to be called upon operation termination. 
- Args: - callback: A callable to be passed an Outcome value on operation - termination. - - Returns: - None if the operation has not yet terminated and the passed callback will - later be called when it does terminate, or if the operation has already - terminated an Outcome value describing the operation termination and the - passed callback will not be called as a result of this method call. - """ + Args: + callback: A callable to be passed an Outcome value on operation + termination. + + Returns: + None if the operation has not yet terminated and the passed callback will + later be called when it does terminate, or if the operation has already + terminated an Outcome value describing the operation termination and the + passed callback will not be called as a result of this method call. + """ raise NotImplementedError() @abc.abstractmethod def time_remaining(self): """Describes the length of allowed time remaining for the operation. - Returns: - A nonnegative float indicating the length of allowed time in seconds - remaining for the operation to complete before it is considered to have - timed out. Zero is returned if the operation has terminated. - """ + Returns: + A nonnegative float indicating the length of allowed time in seconds + remaining for the operation to complete before it is considered to have + timed out. Zero is returned if the operation has terminated. + """ raise NotImplementedError() @abc.abstractmethod @@ -138,9 +138,9 @@ def cancel(self): def fail(self, exception): """Indicates that the operation has failed. - Args: - exception: An exception germane to the operation failure. May be None. - """ + Args: + exception: An exception germane to the operation failure. May be None. + """ raise NotImplementedError() @@ -148,23 +148,25 @@ class Operator(abc.ABC): """An interface through which to participate in an operation.""" @abc.abstractmethod - def advance(self, - initial_metadata=None, - payload=None, - completion=None, - allowance=None): + def advance( + self, + initial_metadata=None, + payload=None, + completion=None, + allowance=None, + ): """Progresses the operation. - Args: - initial_metadata: An initial metadata value. Only one may ever be - communicated in each direction for an operation, and they must be - communicated no later than either the first payload or the completion. - payload: A payload value. - completion: A Completion value. May only ever be non-None once in either - direction, and no payloads may be passed after it has been communicated. - allowance: A positive integer communicating the number of additional - payloads allowed to be passed by the remote side of the operation. - """ + Args: + initial_metadata: An initial metadata value. Only one may ever be + communicated in each direction for an operation, and they must be + communicated no later than either the first payload or the completion. + payload: A payload value. + completion: A Completion value. May only ever be non-None once in either + direction, and no payloads may be passed after it has been communicated. + allowance: A positive integer communicating the number of additional + payloads allowed to be passed by the remote side of the operation. + """ raise NotImplementedError() @@ -175,37 +177,36 @@ class ProtocolReceiver(abc.ABC): def context(self, protocol_context): """Accepts the protocol context object for the operation. - Args: - protocol_context: The protocol context object for the operation. - """ + Args: + protocol_context: The protocol context object for the operation. 
+ """ raise NotImplementedError() class Subscription(abc.ABC): """Describes customer code's interest in values from the other side. - Attributes: - kind: A Kind value describing the overall kind of this value. - termination_callback: A callable to be passed the Outcome associated with - the operation after it has terminated. Must be non-None if kind is - Kind.TERMINATION_ONLY. Must be None otherwise. - allowance: A callable behavior that accepts positive integers representing - the number of additional payloads allowed to be passed to the other side - of the operation. Must be None if kind is Kind.FULL. Must not be None - otherwise. - operator: An Operator to be passed values from the other side of the - operation. Must be non-None if kind is Kind.FULL. Must be None otherwise. - protocol_receiver: A ProtocolReceiver to be passed protocol objects as they - become available during the operation. Must be non-None if kind is - Kind.FULL. - """ + Attributes: + kind: A Kind value describing the overall kind of this value. + termination_callback: A callable to be passed the Outcome associated with + the operation after it has terminated. Must be non-None if kind is + Kind.TERMINATION_ONLY. Must be None otherwise. + allowance: A callable behavior that accepts positive integers representing + the number of additional payloads allowed to be passed to the other side + of the operation. Must be None if kind is Kind.FULL. Must not be None + otherwise. + operator: An Operator to be passed values from the other side of the + operation. Must be non-None if kind is Kind.FULL. Must be None otherwise. + protocol_receiver: A ProtocolReceiver to be passed protocol objects as they + become available during the operation. Must be non-None if kind is + Kind.FULL. + """ @enum.unique class Kind(enum.Enum): - - NONE = 'none' - TERMINATION_ONLY = 'termination only' - FULL = 'full' + NONE = "none" + TERMINATION_ONLY = "termination only" + FULL = "full" class Servicer(abc.ABC): @@ -215,24 +216,24 @@ class Servicer(abc.ABC): def service(self, group, method, context, output_operator): """Services an operation. - Args: - group: The group identifier of the operation to be serviced. - method: The method identifier of the operation to be serviced. - context: An OperationContext object affording contextual information and - actions. - output_operator: An Operator that will accept output values of the - operation. - - Returns: - A Subscription via which this object may or may not accept more values of - the operation. - - Raises: - NoSuchMethodError: If this Servicer does not handle operations with the - given group and method. - abandonment.Abandoned: If the operation has been aborted and there no - longer is any reason to service the operation. - """ + Args: + group: The group identifier of the operation to be serviced. + method: The method identifier of the operation to be serviced. + context: An OperationContext object affording contextual information and + actions. + output_operator: An Operator that will accept output values of the + operation. + + Returns: + A Subscription via which this object may or may not accept more values of + the operation. + + Raises: + NoSuchMethodError: If this Servicer does not handle operations with the + given group and method. + abandonment.Abandoned: If the operation has been aborted and there no + longer is any reason to service the operation. + """ raise NotImplementedError() @@ -248,78 +249,80 @@ def start(self): def stop(self, grace): """Stops this object's service of operations. 
- This object will refuse service of new operations as soon as this method is - called but operations under way at the time of the call may be given a - grace period during which they are allowed to finish. - - Args: - grace: A duration of time in seconds to allow ongoing operations to - terminate before being forcefully terminated by the stopping of this - End. May be zero to terminate all ongoing operations and immediately - stop. - - Returns: - A threading.Event that will be set to indicate all operations having - terminated and this End having completely stopped. The returned event - may not be set until after the full grace period (if some ongoing - operation continues for the full length of the period) or it may be set - much sooner (if for example this End had no operations in progress at - the time its stop method was called). - """ + This object will refuse service of new operations as soon as this method is + called but operations under way at the time of the call may be given a + grace period during which they are allowed to finish. + + Args: + grace: A duration of time in seconds to allow ongoing operations to + terminate before being forcefully terminated by the stopping of this + End. May be zero to terminate all ongoing operations and immediately + stop. + + Returns: + A threading.Event that will be set to indicate all operations having + terminated and this End having completely stopped. The returned event + may not be set until after the full grace period (if some ongoing + operation continues for the full length of the period) or it may be set + much sooner (if for example this End had no operations in progress at + the time its stop method was called). + """ raise NotImplementedError() @abc.abstractmethod - def operate(self, - group, - method, - subscription, - timeout, - initial_metadata=None, - payload=None, - completion=None, - protocol_options=None): + def operate( + self, + group, + method, + subscription, + timeout, + initial_metadata=None, + payload=None, + completion=None, + protocol_options=None, + ): """Commences an operation. - Args: - group: The group identifier of the invoked operation. - method: The method identifier of the invoked operation. - subscription: A Subscription to which the results of the operation will be - passed. - timeout: A length of time in seconds to allow for the operation. - initial_metadata: An initial metadata value to be sent to the other side - of the operation. May be None if the initial metadata will be later - passed via the returned operator or if there will be no initial metadata - passed at all. - payload: An initial payload for the operation. - completion: A Completion value indicating the end of transmission to the - other side of the operation. - protocol_options: A value specified by the provider of a Base interface - implementation affording custom state and behavior. - - Returns: - A pair of objects affording information about the operation and action - continuing the operation. The first element of the returned pair is an - OperationContext for the operation and the second element of the - returned pair is an Operator to which operation values not passed in - this call should later be passed. - """ + Args: + group: The group identifier of the invoked operation. + method: The method identifier of the invoked operation. + subscription: A Subscription to which the results of the operation will be + passed. + timeout: A length of time in seconds to allow for the operation. 
+ initial_metadata: An initial metadata value to be sent to the other side + of the operation. May be None if the initial metadata will be later + passed via the returned operator or if there will be no initial metadata + passed at all. + payload: An initial payload for the operation. + completion: A Completion value indicating the end of transmission to the + other side of the operation. + protocol_options: A value specified by the provider of a Base interface + implementation affording custom state and behavior. + + Returns: + A pair of objects affording information about the operation and action + continuing the operation. The first element of the returned pair is an + OperationContext for the operation and the second element of the + returned pair is an Operator to which operation values not passed in + this call should later be passed. + """ raise NotImplementedError() @abc.abstractmethod def operation_stats(self): """Reports the number of terminated operations broken down by outcome. - Returns: - A dictionary from Outcome.Kind value to an integer identifying the number - of operations that terminated with that outcome kind. - """ + Returns: + A dictionary from Outcome.Kind value to an integer identifying the number + of operations that terminated with that outcome kind. + """ raise NotImplementedError() @abc.abstractmethod def add_idle_action(self, action): """Adds an action to be called when this End has no ongoing operations. - Args: - action: A callable that accepts no arguments. - """ + Args: + action: A callable that accepts no arguments. + """ raise NotImplementedError() diff --git a/src/python/grpcio/grpc/framework/interfaces/base/utilities.py b/src/python/grpcio/grpc/framework/interfaces/base/utilities.py index 281db62b5d4dd..d188339b1eb01 100644 --- a/src/python/grpcio/grpc/framework/interfaces/base/utilities.py +++ b/src/python/grpcio/grpc/framework/interfaces/base/utilities.py @@ -18,54 +18,66 @@ from grpc.framework.interfaces.base import base -class _Completion(base.Completion, - collections.namedtuple('_Completion', ( - 'terminal_metadata', - 'code', - 'message', - ))): +class _Completion( + base.Completion, + collections.namedtuple( + "_Completion", + ( + "terminal_metadata", + "code", + "message", + ), + ), +): """A trivial implementation of base.Completion.""" -class _Subscription(base.Subscription, - collections.namedtuple('_Subscription', ( - 'kind', - 'termination_callback', - 'allowance', - 'operator', - 'protocol_receiver', - ))): +class _Subscription( + base.Subscription, + collections.namedtuple( + "_Subscription", + ( + "kind", + "termination_callback", + "allowance", + "operator", + "protocol_receiver", + ), + ), +): """A trivial implementation of base.Subscription.""" -_NONE_SUBSCRIPTION = _Subscription(base.Subscription.Kind.NONE, None, None, - None, None) +_NONE_SUBSCRIPTION = _Subscription( + base.Subscription.Kind.NONE, None, None, None, None +) def completion(terminal_metadata, code, message): """Creates a base.Completion aggregating the given operation values. - Args: - terminal_metadata: A terminal metadata value for an operaton. - code: A code value for an operation. - message: A message value for an operation. + Args: + terminal_metadata: A terminal metadata value for an operaton. + code: A code value for an operation. + message: A message value for an operation. - Returns: - A base.Completion aggregating the given operation values. - """ + Returns: + A base.Completion aggregating the given operation values. 
+ """ return _Completion(terminal_metadata, code, message) def full_subscription(operator, protocol_receiver): """Creates a "full" base.Subscription for the given base.Operator. - Args: - operator: A base.Operator to be used in an operation. - protocol_receiver: A base.ProtocolReceiver to be used in an operation. - - Returns: - A base.Subscription of kind base.Subscription.Kind.FULL wrapping the given - base.Operator and base.ProtocolReceiver. - """ - return _Subscription(base.Subscription.Kind.FULL, None, None, operator, - protocol_receiver) + Args: + operator: A base.Operator to be used in an operation. + protocol_receiver: A base.ProtocolReceiver to be used in an operation. + + Returns: + A base.Subscription of kind base.Subscription.Kind.FULL wrapping the given + base.Operator and base.ProtocolReceiver. + """ + return _Subscription( + base.Subscription.Kind.FULL, None, None, operator, protocol_receiver + ) diff --git a/src/python/grpcio/grpc/framework/interfaces/face/face.py b/src/python/grpcio/grpc/framework/interfaces/face/face.py index ed0de6a7de262..9239fcc9eb996 100644 --- a/src/python/grpcio/grpc/framework/interfaces/face/face.py +++ b/src/python/grpcio/grpc/framework/interfaces/face/face.py @@ -30,62 +30,66 @@ class NoSuchMethodError(Exception): """Raised by customer code to indicate an unrecognized method. - Attributes: - group: The group of the unrecognized method. - name: The name of the unrecognized method. - """ + Attributes: + group: The group of the unrecognized method. + name: The name of the unrecognized method. + """ def __init__(self, group, method): """Constructor. - Args: - group: The group identifier of the unrecognized RPC name. - method: The method identifier of the unrecognized RPC name. - """ + Args: + group: The group identifier of the unrecognized RPC name. + method: The method identifier of the unrecognized RPC name. + """ super(NoSuchMethodError, self).__init__() self.group = group self.method = method def __repr__(self): - return 'face.NoSuchMethodError(%s, %s)' % ( + return "face.NoSuchMethodError(%s, %s)" % ( self.group, self.method, ) class Abortion( - collections.namedtuple('Abortion', ( - 'kind', - 'initial_metadata', - 'terminal_metadata', - 'code', - 'details', - ))): + collections.namedtuple( + "Abortion", + ( + "kind", + "initial_metadata", + "terminal_metadata", + "code", + "details", + ), + ) +): """A value describing RPC abortion. - Attributes: - kind: A Kind value identifying how the RPC failed. - initial_metadata: The initial metadata from the other side of the RPC or - None if no initial metadata value was received. - terminal_metadata: The terminal metadata from the other side of the RPC or - None if no terminal metadata value was received. - code: The code value from the other side of the RPC or None if no code value - was received. - details: The details value from the other side of the RPC or None if no - details value was received. - """ + Attributes: + kind: A Kind value identifying how the RPC failed. + initial_metadata: The initial metadata from the other side of the RPC or + None if no initial metadata value was received. + terminal_metadata: The terminal metadata from the other side of the RPC or + None if no terminal metadata value was received. + code: The code value from the other side of the RPC or None if no code value + was received. + details: The details value from the other side of the RPC or None if no + details value was received. 
+ """ @enum.unique class Kind(enum.Enum): """Types of RPC abortion.""" - CANCELLED = 'cancelled' - EXPIRED = 'expired' - LOCAL_SHUTDOWN = 'local shutdown' - REMOTE_SHUTDOWN = 'remote shutdown' - NETWORK_FAILURE = 'network failure' - LOCAL_FAILURE = 'local failure' - REMOTE_FAILURE = 'remote failure' + CANCELLED = "cancelled" + EXPIRED = "expired" + LOCAL_SHUTDOWN = "local shutdown" + REMOTE_SHUTDOWN = "remote shutdown" + NETWORK_FAILURE = "network failure" + LOCAL_FAILURE = "local failure" + REMOTE_FAILURE = "remote failure" class AbortionError(Exception, metaclass=abc.ABCMeta): @@ -99,7 +103,7 @@ class AbortionError(Exception, metaclass=abc.ABCMeta): was received. details: The details value from the other side of the RPC or None if no details value was received. - """ + """ def __init__(self, initial_metadata, terminal_metadata, code, details): super(AbortionError, self).__init__() @@ -109,8 +113,11 @@ def __init__(self, initial_metadata, terminal_metadata, code, details): self.details = details def __str__(self): - return '%s(code=%s, details="%s")' % (self.__class__.__name__, - self.code, self.details) + return '%s(code=%s, details="%s")' % ( + self.__class__.__name__, + self.code, + self.details, + ) class CancellationError(AbortionError): @@ -153,39 +160,39 @@ def is_active(self): def time_remaining(self): """Describes the length of allowed time remaining for the RPC. - Returns: - A nonnegative float indicating the length of allowed time in seconds - remaining for the RPC to complete before it is considered to have timed - out. - """ + Returns: + A nonnegative float indicating the length of allowed time in seconds + remaining for the RPC to complete before it is considered to have timed + out. + """ raise NotImplementedError() @abc.abstractmethod def add_abortion_callback(self, abortion_callback): """Registers a callback to be called if the RPC is aborted. - Args: - abortion_callback: A callable to be called and passed an Abortion value - in the event of RPC abortion. - """ + Args: + abortion_callback: A callable to be called and passed an Abortion value + in the event of RPC abortion. + """ raise NotImplementedError() @abc.abstractmethod def cancel(self): """Cancels the RPC. - Idempotent and has no effect if the RPC has already terminated. - """ + Idempotent and has no effect if the RPC has already terminated. + """ raise NotImplementedError() @abc.abstractmethod def protocol_context(self): """Accesses a custom object specified by an implementation provider. - Returns: - A value specified by the provider of a Face interface implementation - affording custom state and behavior. - """ + Returns: + A value specified by the provider of a Face interface implementation + affording custom state and behavior. + """ raise NotImplementedError() @@ -196,52 +203,52 @@ class Call(RpcContext, metaclass=abc.ABCMeta): def initial_metadata(self): """Accesses the initial metadata from the service-side of the RPC. - This method blocks until the value is available or is known not to have been - emitted from the service-side of the RPC. + This method blocks until the value is available or is known not to have been + emitted from the service-side of the RPC. - Returns: - The initial metadata object emitted by the service-side of the RPC, or - None if there was no such value. - """ + Returns: + The initial metadata object emitted by the service-side of the RPC, or + None if there was no such value. 
+ """ raise NotImplementedError() @abc.abstractmethod def terminal_metadata(self): """Accesses the terminal metadata from the service-side of the RPC. - This method blocks until the value is available or is known not to have been - emitted from the service-side of the RPC. + This method blocks until the value is available or is known not to have been + emitted from the service-side of the RPC. - Returns: - The terminal metadata object emitted by the service-side of the RPC, or - None if there was no such value. - """ + Returns: + The terminal metadata object emitted by the service-side of the RPC, or + None if there was no such value. + """ raise NotImplementedError() @abc.abstractmethod def code(self): """Accesses the code emitted by the service-side of the RPC. - This method blocks until the value is available or is known not to have been - emitted from the service-side of the RPC. + This method blocks until the value is available or is known not to have been + emitted from the service-side of the RPC. - Returns: - The code object emitted by the service-side of the RPC, or None if there - was no such value. - """ + Returns: + The code object emitted by the service-side of the RPC, or None if there + was no such value. + """ raise NotImplementedError() @abc.abstractmethod def details(self): """Accesses the details value emitted by the service-side of the RPC. - This method blocks until the value is available or is known not to have been - emitted from the service-side of the RPC. + This method blocks until the value is available or is known not to have been + emitted from the service-side of the RPC. - Returns: - The details value emitted by the service-side of the RPC, or None if there - was no such value. - """ + Returns: + The details value emitted by the service-side of the RPC, or None if there + was no such value. + """ raise NotImplementedError() @@ -252,65 +259,65 @@ class ServicerContext(RpcContext, metaclass=abc.ABCMeta): def invocation_metadata(self): """Accesses the metadata from the invocation-side of the RPC. - This method blocks until the value is available or is known not to have been - emitted from the invocation-side of the RPC. + This method blocks until the value is available or is known not to have been + emitted from the invocation-side of the RPC. - Returns: - The metadata object emitted by the invocation-side of the RPC, or None if - there was no such value. - """ + Returns: + The metadata object emitted by the invocation-side of the RPC, or None if + there was no such value. + """ raise NotImplementedError() @abc.abstractmethod def initial_metadata(self, initial_metadata): """Accepts the service-side initial metadata value of the RPC. - This method need not be called by method implementations if they have no - service-side initial metadata to transmit. + This method need not be called by method implementations if they have no + service-side initial metadata to transmit. - Args: - initial_metadata: The service-side initial metadata value of the RPC to - be transmitted to the invocation side of the RPC. - """ + Args: + initial_metadata: The service-side initial metadata value of the RPC to + be transmitted to the invocation side of the RPC. + """ raise NotImplementedError() @abc.abstractmethod def terminal_metadata(self, terminal_metadata): """Accepts the service-side terminal metadata value of the RPC. - This method need not be called by method implementations if they have no - service-side terminal metadata to transmit. 
+ This method need not be called by method implementations if they have no + service-side terminal metadata to transmit. - Args: - terminal_metadata: The service-side terminal metadata value of the RPC to - be transmitted to the invocation side of the RPC. - """ + Args: + terminal_metadata: The service-side terminal metadata value of the RPC to + be transmitted to the invocation side of the RPC. + """ raise NotImplementedError() @abc.abstractmethod def code(self, code): """Accepts the service-side code of the RPC. - This method need not be called by method implementations if they have no - code to transmit. + This method need not be called by method implementations if they have no + code to transmit. - Args: - code: The code of the RPC to be transmitted to the invocation side of the - RPC. - """ + Args: + code: The code of the RPC to be transmitted to the invocation side of the + RPC. + """ raise NotImplementedError() @abc.abstractmethod def details(self, details): """Accepts the service-side details of the RPC. - This method need not be called by method implementations if they have no - service-side details to transmit. + This method need not be called by method implementations if they have no + service-side details to transmit. - Args: - details: The service-side details value of the RPC to be transmitted to - the invocation side of the RPC. - """ + Args: + details: The service-side details value of the RPC to be transmitted to + the invocation side of the RPC. + """ raise NotImplementedError() @@ -321,31 +328,31 @@ class ResponseReceiver(abc.ABC): def initial_metadata(self, initial_metadata): """Receives the initial metadata from the service-side of the RPC. - Args: - initial_metadata: The initial metadata object emitted from the - service-side of the RPC. - """ + Args: + initial_metadata: The initial metadata object emitted from the + service-side of the RPC. + """ raise NotImplementedError() @abc.abstractmethod def response(self, response): """Receives a response from the service-side of the RPC. - Args: - response: A response object emitted from the service-side of the RPC. - """ + Args: + response: A response object emitted from the service-side of the RPC. + """ raise NotImplementedError() @abc.abstractmethod def complete(self, terminal_metadata, code, details): """Receives the completion values emitted from the service-side of the RPC. - Args: - terminal_metadata: The terminal metadata object emitted from the - service-side of the RPC. - code: The code object emitted from the service-side of the RPC. - details: The details object emitted from the service-side of the RPC. - """ + Args: + terminal_metadata: The terminal metadata object emitted from the + service-side of the RPC. + code: The code object emitted from the service-side of the RPC. + details: The details object emitted from the service-side of the RPC. + """ raise NotImplementedError() @@ -353,77 +360,81 @@ class UnaryUnaryMultiCallable(abc.ABC): """Affords invoking a unary-unary RPC in any call style.""" @abc.abstractmethod - def __call__(self, - request, - timeout, - metadata=None, - with_call=False, - protocol_options=None): + def __call__( + self, + request, + timeout, + metadata=None, + with_call=False, + protocol_options=None, + ): """Synchronously invokes the underlying RPC. - Args: - request: The request value for the RPC. - timeout: A duration of time in seconds to allow for the RPC. - metadata: A metadata value to be passed to the service-side of - the RPC. 
- with_call: Whether or not to include return a Call for the RPC in addition - to the response. - protocol_options: A value specified by the provider of a Face interface - implementation affording custom state and behavior. - - Returns: - The response value for the RPC, and a Call for the RPC if with_call was - set to True at invocation. - - Raises: - AbortionError: Indicating that the RPC was aborted. - """ + Args: + request: The request value for the RPC. + timeout: A duration of time in seconds to allow for the RPC. + metadata: A metadata value to be passed to the service-side of + the RPC. + with_call: Whether or not to include return a Call for the RPC in addition + to the response. + protocol_options: A value specified by the provider of a Face interface + implementation affording custom state and behavior. + + Returns: + The response value for the RPC, and a Call for the RPC if with_call was + set to True at invocation. + + Raises: + AbortionError: Indicating that the RPC was aborted. + """ raise NotImplementedError() @abc.abstractmethod def future(self, request, timeout, metadata=None, protocol_options=None): """Asynchronously invokes the underlying RPC. - Args: - request: The request value for the RPC. - timeout: A duration of time in seconds to allow for the RPC. - metadata: A metadata value to be passed to the service-side of - the RPC. - protocol_options: A value specified by the provider of a Face interface - implementation affording custom state and behavior. - - Returns: - An object that is both a Call for the RPC and a future.Future. In the - event of RPC completion, the return Future's result value will be the - response value of the RPC. In the event of RPC abortion, the returned - Future's exception value will be an AbortionError. - """ + Args: + request: The request value for the RPC. + timeout: A duration of time in seconds to allow for the RPC. + metadata: A metadata value to be passed to the service-side of + the RPC. + protocol_options: A value specified by the provider of a Face interface + implementation affording custom state and behavior. + + Returns: + An object that is both a Call for the RPC and a future.Future. In the + event of RPC completion, the return Future's result value will be the + response value of the RPC. In the event of RPC abortion, the returned + Future's exception value will be an AbortionError. + """ raise NotImplementedError() @abc.abstractmethod - def event(self, - request, - receiver, - abortion_callback, - timeout, - metadata=None, - protocol_options=None): + def event( + self, + request, + receiver, + abortion_callback, + timeout, + metadata=None, + protocol_options=None, + ): """Asynchronously invokes the underlying RPC. - Args: - request: The request value for the RPC. - receiver: A ResponseReceiver to be passed the response data of the RPC. - abortion_callback: A callback to be called and passed an Abortion value - in the event of RPC abortion. - timeout: A duration of time in seconds to allow for the RPC. - metadata: A metadata value to be passed to the service-side of - the RPC. - protocol_options: A value specified by the provider of a Face interface - implementation affording custom state and behavior. - - Returns: - A Call for the RPC. - """ + Args: + request: The request value for the RPC. + receiver: A ResponseReceiver to be passed the response data of the RPC. + abortion_callback: A callback to be called and passed an Abortion value + in the event of RPC abortion. + timeout: A duration of time in seconds to allow for the RPC. 
+ metadata: A metadata value to be passed to the service-side of + the RPC. + protocol_options: A value specified by the provider of a Face interface + implementation affording custom state and behavior. + + Returns: + A Call for the RPC. + """ raise NotImplementedError() @@ -434,45 +445,47 @@ class UnaryStreamMultiCallable(abc.ABC): def __call__(self, request, timeout, metadata=None, protocol_options=None): """Invokes the underlying RPC. - Args: - request: The request value for the RPC. - timeout: A duration of time in seconds to allow for the RPC. - metadata: A metadata value to be passed to the service-side of - the RPC. - protocol_options: A value specified by the provider of a Face interface - implementation affording custom state and behavior. - - Returns: - An object that is both a Call for the RPC and an iterator of response - values. Drawing response values from the returned iterator may raise - AbortionError indicating abortion of the RPC. - """ + Args: + request: The request value for the RPC. + timeout: A duration of time in seconds to allow for the RPC. + metadata: A metadata value to be passed to the service-side of + the RPC. + protocol_options: A value specified by the provider of a Face interface + implementation affording custom state and behavior. + + Returns: + An object that is both a Call for the RPC and an iterator of response + values. Drawing response values from the returned iterator may raise + AbortionError indicating abortion of the RPC. + """ raise NotImplementedError() @abc.abstractmethod - def event(self, - request, - receiver, - abortion_callback, - timeout, - metadata=None, - protocol_options=None): + def event( + self, + request, + receiver, + abortion_callback, + timeout, + metadata=None, + protocol_options=None, + ): """Asynchronously invokes the underlying RPC. - Args: - request: The request value for the RPC. - receiver: A ResponseReceiver to be passed the response data of the RPC. - abortion_callback: A callback to be called and passed an Abortion value - in the event of RPC abortion. - timeout: A duration of time in seconds to allow for the RPC. - metadata: A metadata value to be passed to the service-side of - the RPC. - protocol_options: A value specified by the provider of a Face interface - implementation affording custom state and behavior. - - Returns: - A Call object for the RPC. - """ + Args: + request: The request value for the RPC. + receiver: A ResponseReceiver to be passed the response data of the RPC. + abortion_callback: A callback to be called and passed an Abortion value + in the event of RPC abortion. + timeout: A duration of time in seconds to allow for the RPC. + metadata: A metadata value to be passed to the service-side of + the RPC. + protocol_options: A value specified by the provider of a Face interface + implementation affording custom state and behavior. + + Returns: + A Call object for the RPC. + """ raise NotImplementedError() @@ -480,80 +493,82 @@ class StreamUnaryMultiCallable(abc.ABC): """Affords invoking a stream-unary RPC in any call style.""" @abc.abstractmethod - def __call__(self, - request_iterator, - timeout, - metadata=None, - with_call=False, - protocol_options=None): + def __call__( + self, + request_iterator, + timeout, + metadata=None, + with_call=False, + protocol_options=None, + ): """Synchronously invokes the underlying RPC. - Args: - request_iterator: An iterator that yields request values for the RPC. - timeout: A duration of time in seconds to allow for the RPC. 
- metadata: A metadata value to be passed to the service-side of - the RPC. - with_call: Whether or not to include return a Call for the RPC in addition - to the response. - protocol_options: A value specified by the provider of a Face interface - implementation affording custom state and behavior. - - Returns: - The response value for the RPC, and a Call for the RPC if with_call was - set to True at invocation. - - Raises: - AbortionError: Indicating that the RPC was aborted. - """ + Args: + request_iterator: An iterator that yields request values for the RPC. + timeout: A duration of time in seconds to allow for the RPC. + metadata: A metadata value to be passed to the service-side of + the RPC. + with_call: Whether or not to include return a Call for the RPC in addition + to the response. + protocol_options: A value specified by the provider of a Face interface + implementation affording custom state and behavior. + + Returns: + The response value for the RPC, and a Call for the RPC if with_call was + set to True at invocation. + + Raises: + AbortionError: Indicating that the RPC was aborted. + """ raise NotImplementedError() @abc.abstractmethod - def future(self, - request_iterator, - timeout, - metadata=None, - protocol_options=None): + def future( + self, request_iterator, timeout, metadata=None, protocol_options=None + ): """Asynchronously invokes the underlying RPC. - Args: - request_iterator: An iterator that yields request values for the RPC. - timeout: A duration of time in seconds to allow for the RPC. - metadata: A metadata value to be passed to the service-side of - the RPC. - protocol_options: A value specified by the provider of a Face interface - implementation affording custom state and behavior. - - Returns: - An object that is both a Call for the RPC and a future.Future. In the - event of RPC completion, the return Future's result value will be the - response value of the RPC. In the event of RPC abortion, the returned - Future's exception value will be an AbortionError. - """ + Args: + request_iterator: An iterator that yields request values for the RPC. + timeout: A duration of time in seconds to allow for the RPC. + metadata: A metadata value to be passed to the service-side of + the RPC. + protocol_options: A value specified by the provider of a Face interface + implementation affording custom state and behavior. + + Returns: + An object that is both a Call for the RPC and a future.Future. In the + event of RPC completion, the return Future's result value will be the + response value of the RPC. In the event of RPC abortion, the returned + Future's exception value will be an AbortionError. + """ raise NotImplementedError() @abc.abstractmethod - def event(self, - receiver, - abortion_callback, - timeout, - metadata=None, - protocol_options=None): + def event( + self, + receiver, + abortion_callback, + timeout, + metadata=None, + protocol_options=None, + ): """Asynchronously invokes the underlying RPC. - Args: - receiver: A ResponseReceiver to be passed the response data of the RPC. - abortion_callback: A callback to be called and passed an Abortion value - in the event of RPC abortion. - timeout: A duration of time in seconds to allow for the RPC. - metadata: A metadata value to be passed to the service-side of - the RPC. - protocol_options: A value specified by the provider of a Face interface - implementation affording custom state and behavior. 
- - Returns: - A single object that is both a Call object for the RPC and a - stream.Consumer to which the request values of the RPC should be passed. - """ + Args: + receiver: A ResponseReceiver to be passed the response data of the RPC. + abortion_callback: A callback to be called and passed an Abortion value + in the event of RPC abortion. + timeout: A duration of time in seconds to allow for the RPC. + metadata: A metadata value to be passed to the service-side of + the RPC. + protocol_options: A value specified by the provider of a Face interface + implementation affording custom state and behavior. + + Returns: + A single object that is both a Call object for the RPC and a + stream.Consumer to which the request values of the RPC should be passed. + """ raise NotImplementedError() @@ -561,97 +576,97 @@ class StreamStreamMultiCallable(abc.ABC): """Affords invoking a stream-stream RPC in any call style.""" @abc.abstractmethod - def __call__(self, - request_iterator, - timeout, - metadata=None, - protocol_options=None): + def __call__( + self, request_iterator, timeout, metadata=None, protocol_options=None + ): """Invokes the underlying RPC. - Args: - request_iterator: An iterator that yields request values for the RPC. - timeout: A duration of time in seconds to allow for the RPC. - metadata: A metadata value to be passed to the service-side of - the RPC. - protocol_options: A value specified by the provider of a Face interface - implementation affording custom state and behavior. - - Returns: - An object that is both a Call for the RPC and an iterator of response - values. Drawing response values from the returned iterator may raise - AbortionError indicating abortion of the RPC. - """ + Args: + request_iterator: An iterator that yields request values for the RPC. + timeout: A duration of time in seconds to allow for the RPC. + metadata: A metadata value to be passed to the service-side of + the RPC. + protocol_options: A value specified by the provider of a Face interface + implementation affording custom state and behavior. + + Returns: + An object that is both a Call for the RPC and an iterator of response + values. Drawing response values from the returned iterator may raise + AbortionError indicating abortion of the RPC. + """ raise NotImplementedError() @abc.abstractmethod - def event(self, - receiver, - abortion_callback, - timeout, - metadata=None, - protocol_options=None): + def event( + self, + receiver, + abortion_callback, + timeout, + metadata=None, + protocol_options=None, + ): """Asynchronously invokes the underlying RPC. - Args: - receiver: A ResponseReceiver to be passed the response data of the RPC. - abortion_callback: A callback to be called and passed an Abortion value - in the event of RPC abortion. - timeout: A duration of time in seconds to allow for the RPC. - metadata: A metadata value to be passed to the service-side of - the RPC. - protocol_options: A value specified by the provider of a Face interface - implementation affording custom state and behavior. - - Returns: - A single object that is both a Call object for the RPC and a - stream.Consumer to which the request values of the RPC should be passed. - """ + Args: + receiver: A ResponseReceiver to be passed the response data of the RPC. + abortion_callback: A callback to be called and passed an Abortion value + in the event of RPC abortion. + timeout: A duration of time in seconds to allow for the RPC. + metadata: A metadata value to be passed to the service-side of + the RPC. 
+ protocol_options: A value specified by the provider of a Face interface + implementation affording custom state and behavior. + + Returns: + A single object that is both a Call object for the RPC and a + stream.Consumer to which the request values of the RPC should be passed. + """ raise NotImplementedError() class MethodImplementation(abc.ABC): """A sum type that describes a method implementation. - Attributes: - cardinality: A cardinality.Cardinality value. - style: A style.Service value. - unary_unary_inline: The implementation of the method as a callable value - that takes a request value and a ServicerContext object and returns a - response value. Only non-None if cardinality is - cardinality.Cardinality.UNARY_UNARY and style is style.Service.INLINE. - unary_stream_inline: The implementation of the method as a callable value - that takes a request value and a ServicerContext object and returns an - iterator of response values. Only non-None if cardinality is - cardinality.Cardinality.UNARY_STREAM and style is style.Service.INLINE. - stream_unary_inline: The implementation of the method as a callable value - that takes an iterator of request values and a ServicerContext object and - returns a response value. Only non-None if cardinality is - cardinality.Cardinality.STREAM_UNARY and style is style.Service.INLINE. - stream_stream_inline: The implementation of the method as a callable value - that takes an iterator of request values and a ServicerContext object and - returns an iterator of response values. Only non-None if cardinality is - cardinality.Cardinality.STREAM_STREAM and style is style.Service.INLINE. - unary_unary_event: The implementation of the method as a callable value that - takes a request value, a response callback to which to pass the response - value of the RPC, and a ServicerContext. Only non-None if cardinality is - cardinality.Cardinality.UNARY_UNARY and style is style.Service.EVENT. - unary_stream_event: The implementation of the method as a callable value - that takes a request value, a stream.Consumer to which to pass the - response values of the RPC, and a ServicerContext. Only non-None if - cardinality is cardinality.Cardinality.UNARY_STREAM and style is - style.Service.EVENT. - stream_unary_event: The implementation of the method as a callable value - that takes a response callback to which to pass the response value of the - RPC and a ServicerContext and returns a stream.Consumer to which the - request values of the RPC should be passed. Only non-None if cardinality - is cardinality.Cardinality.STREAM_UNARY and style is style.Service.EVENT. - stream_stream_event: The implementation of the method as a callable value - that takes a stream.Consumer to which to pass the response values of the - RPC and a ServicerContext and returns a stream.Consumer to which the - request values of the RPC should be passed. Only non-None if cardinality - is cardinality.Cardinality.STREAM_STREAM and style is - style.Service.EVENT. - """ + Attributes: + cardinality: A cardinality.Cardinality value. + style: A style.Service value. + unary_unary_inline: The implementation of the method as a callable value + that takes a request value and a ServicerContext object and returns a + response value. Only non-None if cardinality is + cardinality.Cardinality.UNARY_UNARY and style is style.Service.INLINE. + unary_stream_inline: The implementation of the method as a callable value + that takes a request value and a ServicerContext object and returns an + iterator of response values. 
Only non-None if cardinality is + cardinality.Cardinality.UNARY_STREAM and style is style.Service.INLINE. + stream_unary_inline: The implementation of the method as a callable value + that takes an iterator of request values and a ServicerContext object and + returns a response value. Only non-None if cardinality is + cardinality.Cardinality.STREAM_UNARY and style is style.Service.INLINE. + stream_stream_inline: The implementation of the method as a callable value + that takes an iterator of request values and a ServicerContext object and + returns an iterator of response values. Only non-None if cardinality is + cardinality.Cardinality.STREAM_STREAM and style is style.Service.INLINE. + unary_unary_event: The implementation of the method as a callable value that + takes a request value, a response callback to which to pass the response + value of the RPC, and a ServicerContext. Only non-None if cardinality is + cardinality.Cardinality.UNARY_UNARY and style is style.Service.EVENT. + unary_stream_event: The implementation of the method as a callable value + that takes a request value, a stream.Consumer to which to pass the + response values of the RPC, and a ServicerContext. Only non-None if + cardinality is cardinality.Cardinality.UNARY_STREAM and style is + style.Service.EVENT. + stream_unary_event: The implementation of the method as a callable value + that takes a response callback to which to pass the response value of the + RPC and a ServicerContext and returns a stream.Consumer to which the + request values of the RPC should be passed. Only non-None if cardinality + is cardinality.Cardinality.STREAM_UNARY and style is style.Service.EVENT. + stream_stream_event: The implementation of the method as a callable value + that takes a stream.Consumer to which to pass the response values of the + RPC and a ServicerContext and returns a stream.Consumer to which the + request values of the RPC should be passed. Only non-None if cardinality + is cardinality.Cardinality.STREAM_STREAM and style is + style.Service.EVENT. + """ class MultiMethodImplementation(abc.ABC): @@ -661,27 +676,27 @@ class MultiMethodImplementation(abc.ABC): def service(self, group, method, response_consumer, context): """Services an RPC. - Args: - group: The group identifier of the RPC. - method: The method identifier of the RPC. - response_consumer: A stream.Consumer to be called to accept the response - values of the RPC. - context: a ServicerContext object. - - Returns: - A stream.Consumer with which to accept the request values of the RPC. The - consumer returned from this method may or may not be invoked to - completion: in the case of RPC abortion, RPC Framework will simply stop - passing values to this object. Implementations must not assume that this - object will be called to completion of the request stream or even called - at all. - - Raises: - abandonment.Abandoned: May or may not be raised when the RPC has been - aborted. - NoSuchMethodError: If this MultiMethod does not recognize the given group - and name for the RPC and is not able to service the RPC. - """ + Args: + group: The group identifier of the RPC. + method: The method identifier of the RPC. + response_consumer: A stream.Consumer to be called to accept the response + values of the RPC. + context: a ServicerContext object. + + Returns: + A stream.Consumer with which to accept the request values of the RPC. 
The + consumer returned from this method may or may not be invoked to + completion: in the case of RPC abortion, RPC Framework will simply stop + passing values to this object. Implementations must not assume that this + object will be called to completion of the request stream or even called + at all. + + Raises: + abandonment.Abandoned: May or may not be raised when the RPC has been + aborted. + NoSuchMethodError: If this MultiMethod does not recognize the given group + and name for the RPC and is not able to service the RPC. + """ raise NotImplementedError() @@ -689,361 +704,381 @@ class GenericStub(abc.ABC): """Affords RPC invocation via generic methods.""" @abc.abstractmethod - def blocking_unary_unary(self, - group, - method, - request, - timeout, - metadata=None, - with_call=False, - protocol_options=None): + def blocking_unary_unary( + self, + group, + method, + request, + timeout, + metadata=None, + with_call=False, + protocol_options=None, + ): """Invokes a unary-request-unary-response method. - This method blocks until either returning the response value of the RPC - (in the event of RPC completion) or raising an exception (in the event of - RPC abortion). - - Args: - group: The group identifier of the RPC. - method: The method identifier of the RPC. - request: The request value for the RPC. - timeout: A duration of time in seconds to allow for the RPC. - metadata: A metadata value to be passed to the service-side of the RPC. - with_call: Whether or not to include return a Call for the RPC in addition - to the response. - protocol_options: A value specified by the provider of a Face interface - implementation affording custom state and behavior. - - Returns: - The response value for the RPC, and a Call for the RPC if with_call was - set to True at invocation. - - Raises: - AbortionError: Indicating that the RPC was aborted. - """ + This method blocks until either returning the response value of the RPC + (in the event of RPC completion) or raising an exception (in the event of + RPC abortion). + + Args: + group: The group identifier of the RPC. + method: The method identifier of the RPC. + request: The request value for the RPC. + timeout: A duration of time in seconds to allow for the RPC. + metadata: A metadata value to be passed to the service-side of the RPC. + with_call: Whether or not to include return a Call for the RPC in addition + to the response. + protocol_options: A value specified by the provider of a Face interface + implementation affording custom state and behavior. + + Returns: + The response value for the RPC, and a Call for the RPC if with_call was + set to True at invocation. + + Raises: + AbortionError: Indicating that the RPC was aborted. + """ raise NotImplementedError() @abc.abstractmethod - def future_unary_unary(self, - group, - method, - request, - timeout, - metadata=None, - protocol_options=None): + def future_unary_unary( + self, + group, + method, + request, + timeout, + metadata=None, + protocol_options=None, + ): """Invokes a unary-request-unary-response method. - Args: - group: The group identifier of the RPC. - method: The method identifier of the RPC. - request: The request value for the RPC. - timeout: A duration of time in seconds to allow for the RPC. - metadata: A metadata value to be passed to the service-side of the RPC. - protocol_options: A value specified by the provider of a Face interface - implementation affording custom state and behavior. - - Returns: - An object that is both a Call for the RPC and a future.Future. 
In the - event of RPC completion, the return Future's result value will be the - response value of the RPC. In the event of RPC abortion, the returned - Future's exception value will be an AbortionError. - """ + Args: + group: The group identifier of the RPC. + method: The method identifier of the RPC. + request: The request value for the RPC. + timeout: A duration of time in seconds to allow for the RPC. + metadata: A metadata value to be passed to the service-side of the RPC. + protocol_options: A value specified by the provider of a Face interface + implementation affording custom state and behavior. + + Returns: + An object that is both a Call for the RPC and a future.Future. In the + event of RPC completion, the return Future's result value will be the + response value of the RPC. In the event of RPC abortion, the returned + Future's exception value will be an AbortionError. + """ raise NotImplementedError() @abc.abstractmethod - def inline_unary_stream(self, - group, - method, - request, - timeout, - metadata=None, - protocol_options=None): + def inline_unary_stream( + self, + group, + method, + request, + timeout, + metadata=None, + protocol_options=None, + ): """Invokes a unary-request-stream-response method. - Args: - group: The group identifier of the RPC. - method: The method identifier of the RPC. - request: The request value for the RPC. - timeout: A duration of time in seconds to allow for the RPC. - metadata: A metadata value to be passed to the service-side of the RPC. - protocol_options: A value specified by the provider of a Face interface - implementation affording custom state and behavior. - - Returns: - An object that is both a Call for the RPC and an iterator of response - values. Drawing response values from the returned iterator may raise - AbortionError indicating abortion of the RPC. - """ + Args: + group: The group identifier of the RPC. + method: The method identifier of the RPC. + request: The request value for the RPC. + timeout: A duration of time in seconds to allow for the RPC. + metadata: A metadata value to be passed to the service-side of the RPC. + protocol_options: A value specified by the provider of a Face interface + implementation affording custom state and behavior. + + Returns: + An object that is both a Call for the RPC and an iterator of response + values. Drawing response values from the returned iterator may raise + AbortionError indicating abortion of the RPC. + """ raise NotImplementedError() @abc.abstractmethod - def blocking_stream_unary(self, - group, - method, - request_iterator, - timeout, - metadata=None, - with_call=False, - protocol_options=None): + def blocking_stream_unary( + self, + group, + method, + request_iterator, + timeout, + metadata=None, + with_call=False, + protocol_options=None, + ): """Invokes a stream-request-unary-response method. - This method blocks until either returning the response value of the RPC - (in the event of RPC completion) or raising an exception (in the event of - RPC abortion). - - Args: - group: The group identifier of the RPC. - method: The method identifier of the RPC. - request_iterator: An iterator that yields request values for the RPC. - timeout: A duration of time in seconds to allow for the RPC. - metadata: A metadata value to be passed to the service-side of the RPC. - with_call: Whether or not to include return a Call for the RPC in addition - to the response. - protocol_options: A value specified by the provider of a Face interface - implementation affording custom state and behavior. 
- - Returns: - The response value for the RPC, and a Call for the RPC if with_call was - set to True at invocation. - - Raises: - AbortionError: Indicating that the RPC was aborted. - """ + This method blocks until either returning the response value of the RPC + (in the event of RPC completion) or raising an exception (in the event of + RPC abortion). + + Args: + group: The group identifier of the RPC. + method: The method identifier of the RPC. + request_iterator: An iterator that yields request values for the RPC. + timeout: A duration of time in seconds to allow for the RPC. + metadata: A metadata value to be passed to the service-side of the RPC. + with_call: Whether or not to include return a Call for the RPC in addition + to the response. + protocol_options: A value specified by the provider of a Face interface + implementation affording custom state and behavior. + + Returns: + The response value for the RPC, and a Call for the RPC if with_call was + set to True at invocation. + + Raises: + AbortionError: Indicating that the RPC was aborted. + """ raise NotImplementedError() @abc.abstractmethod - def future_stream_unary(self, - group, - method, - request_iterator, - timeout, - metadata=None, - protocol_options=None): + def future_stream_unary( + self, + group, + method, + request_iterator, + timeout, + metadata=None, + protocol_options=None, + ): """Invokes a stream-request-unary-response method. - Args: - group: The group identifier of the RPC. - method: The method identifier of the RPC. - request_iterator: An iterator that yields request values for the RPC. - timeout: A duration of time in seconds to allow for the RPC. - metadata: A metadata value to be passed to the service-side of the RPC. - protocol_options: A value specified by the provider of a Face interface - implementation affording custom state and behavior. - - Returns: - An object that is both a Call for the RPC and a future.Future. In the - event of RPC completion, the return Future's result value will be the - response value of the RPC. In the event of RPC abortion, the returned - Future's exception value will be an AbortionError. - """ + Args: + group: The group identifier of the RPC. + method: The method identifier of the RPC. + request_iterator: An iterator that yields request values for the RPC. + timeout: A duration of time in seconds to allow for the RPC. + metadata: A metadata value to be passed to the service-side of the RPC. + protocol_options: A value specified by the provider of a Face interface + implementation affording custom state and behavior. + + Returns: + An object that is both a Call for the RPC and a future.Future. In the + event of RPC completion, the return Future's result value will be the + response value of the RPC. In the event of RPC abortion, the returned + Future's exception value will be an AbortionError. + """ raise NotImplementedError() @abc.abstractmethod - def inline_stream_stream(self, - group, - method, - request_iterator, - timeout, - metadata=None, - protocol_options=None): + def inline_stream_stream( + self, + group, + method, + request_iterator, + timeout, + metadata=None, + protocol_options=None, + ): """Invokes a stream-request-stream-response method. - Args: - group: The group identifier of the RPC. - method: The method identifier of the RPC. - request_iterator: An iterator that yields request values for the RPC. - timeout: A duration of time in seconds to allow for the RPC. - metadata: A metadata value to be passed to the service-side of the RPC. 
- protocol_options: A value specified by the provider of a Face interface - implementation affording custom state and behavior. - - Returns: - An object that is both a Call for the RPC and an iterator of response - values. Drawing response values from the returned iterator may raise - AbortionError indicating abortion of the RPC. - """ + Args: + group: The group identifier of the RPC. + method: The method identifier of the RPC. + request_iterator: An iterator that yields request values for the RPC. + timeout: A duration of time in seconds to allow for the RPC. + metadata: A metadata value to be passed to the service-side of the RPC. + protocol_options: A value specified by the provider of a Face interface + implementation affording custom state and behavior. + + Returns: + An object that is both a Call for the RPC and an iterator of response + values. Drawing response values from the returned iterator may raise + AbortionError indicating abortion of the RPC. + """ raise NotImplementedError() @abc.abstractmethod - def event_unary_unary(self, - group, - method, - request, - receiver, - abortion_callback, - timeout, - metadata=None, - protocol_options=None): + def event_unary_unary( + self, + group, + method, + request, + receiver, + abortion_callback, + timeout, + metadata=None, + protocol_options=None, + ): """Event-driven invocation of a unary-request-unary-response method. - Args: - group: The group identifier of the RPC. - method: The method identifier of the RPC. - request: The request value for the RPC. - receiver: A ResponseReceiver to be passed the response data of the RPC. - abortion_callback: A callback to be called and passed an Abortion value - in the event of RPC abortion. - timeout: A duration of time in seconds to allow for the RPC. - metadata: A metadata value to be passed to the service-side of the RPC. - protocol_options: A value specified by the provider of a Face interface - implementation affording custom state and behavior. - - Returns: - A Call for the RPC. - """ + Args: + group: The group identifier of the RPC. + method: The method identifier of the RPC. + request: The request value for the RPC. + receiver: A ResponseReceiver to be passed the response data of the RPC. + abortion_callback: A callback to be called and passed an Abortion value + in the event of RPC abortion. + timeout: A duration of time in seconds to allow for the RPC. + metadata: A metadata value to be passed to the service-side of the RPC. + protocol_options: A value specified by the provider of a Face interface + implementation affording custom state and behavior. + + Returns: + A Call for the RPC. + """ raise NotImplementedError() @abc.abstractmethod - def event_unary_stream(self, - group, - method, - request, - receiver, - abortion_callback, - timeout, - metadata=None, - protocol_options=None): + def event_unary_stream( + self, + group, + method, + request, + receiver, + abortion_callback, + timeout, + metadata=None, + protocol_options=None, + ): """Event-driven invocation of a unary-request-stream-response method. - Args: - group: The group identifier of the RPC. - method: The method identifier of the RPC. - request: The request value for the RPC. - receiver: A ResponseReceiver to be passed the response data of the RPC. - abortion_callback: A callback to be called and passed an Abortion value - in the event of RPC abortion. - timeout: A duration of time in seconds to allow for the RPC. - metadata: A metadata value to be passed to the service-side of the RPC. 
- protocol_options: A value specified by the provider of a Face interface - implementation affording custom state and behavior. - - Returns: - A Call for the RPC. - """ + Args: + group: The group identifier of the RPC. + method: The method identifier of the RPC. + request: The request value for the RPC. + receiver: A ResponseReceiver to be passed the response data of the RPC. + abortion_callback: A callback to be called and passed an Abortion value + in the event of RPC abortion. + timeout: A duration of time in seconds to allow for the RPC. + metadata: A metadata value to be passed to the service-side of the RPC. + protocol_options: A value specified by the provider of a Face interface + implementation affording custom state and behavior. + + Returns: + A Call for the RPC. + """ raise NotImplementedError() @abc.abstractmethod - def event_stream_unary(self, - group, - method, - receiver, - abortion_callback, - timeout, - metadata=None, - protocol_options=None): + def event_stream_unary( + self, + group, + method, + receiver, + abortion_callback, + timeout, + metadata=None, + protocol_options=None, + ): """Event-driven invocation of a unary-request-unary-response method. - Args: - group: The group identifier of the RPC. - method: The method identifier of the RPC. - receiver: A ResponseReceiver to be passed the response data of the RPC. - abortion_callback: A callback to be called and passed an Abortion value - in the event of RPC abortion. - timeout: A duration of time in seconds to allow for the RPC. - metadata: A metadata value to be passed to the service-side of the RPC. - protocol_options: A value specified by the provider of a Face interface - implementation affording custom state and behavior. - - Returns: - A pair of a Call object for the RPC and a stream.Consumer to which the - request values of the RPC should be passed. - """ + Args: + group: The group identifier of the RPC. + method: The method identifier of the RPC. + receiver: A ResponseReceiver to be passed the response data of the RPC. + abortion_callback: A callback to be called and passed an Abortion value + in the event of RPC abortion. + timeout: A duration of time in seconds to allow for the RPC. + metadata: A metadata value to be passed to the service-side of the RPC. + protocol_options: A value specified by the provider of a Face interface + implementation affording custom state and behavior. + + Returns: + A pair of a Call object for the RPC and a stream.Consumer to which the + request values of the RPC should be passed. + """ raise NotImplementedError() @abc.abstractmethod - def event_stream_stream(self, - group, - method, - receiver, - abortion_callback, - timeout, - metadata=None, - protocol_options=None): + def event_stream_stream( + self, + group, + method, + receiver, + abortion_callback, + timeout, + metadata=None, + protocol_options=None, + ): """Event-driven invocation of a unary-request-stream-response method. - Args: - group: The group identifier of the RPC. - method: The method identifier of the RPC. - receiver: A ResponseReceiver to be passed the response data of the RPC. - abortion_callback: A callback to be called and passed an Abortion value - in the event of RPC abortion. - timeout: A duration of time in seconds to allow for the RPC. - metadata: A metadata value to be passed to the service-side of the RPC. - protocol_options: A value specified by the provider of a Face interface - implementation affording custom state and behavior. 
- - Returns: - A pair of a Call object for the RPC and a stream.Consumer to which the - request values of the RPC should be passed. - """ + Args: + group: The group identifier of the RPC. + method: The method identifier of the RPC. + receiver: A ResponseReceiver to be passed the response data of the RPC. + abortion_callback: A callback to be called and passed an Abortion value + in the event of RPC abortion. + timeout: A duration of time in seconds to allow for the RPC. + metadata: A metadata value to be passed to the service-side of the RPC. + protocol_options: A value specified by the provider of a Face interface + implementation affording custom state and behavior. + + Returns: + A pair of a Call object for the RPC and a stream.Consumer to which the + request values of the RPC should be passed. + """ raise NotImplementedError() @abc.abstractmethod def unary_unary(self, group, method): """Creates a UnaryUnaryMultiCallable for a unary-unary method. - Args: - group: The group identifier of the RPC. - method: The method identifier of the RPC. + Args: + group: The group identifier of the RPC. + method: The method identifier of the RPC. - Returns: - A UnaryUnaryMultiCallable value for the named unary-unary method. - """ + Returns: + A UnaryUnaryMultiCallable value for the named unary-unary method. + """ raise NotImplementedError() @abc.abstractmethod def unary_stream(self, group, method): """Creates a UnaryStreamMultiCallable for a unary-stream method. - Args: - group: The group identifier of the RPC. - method: The method identifier of the RPC. + Args: + group: The group identifier of the RPC. + method: The method identifier of the RPC. - Returns: - A UnaryStreamMultiCallable value for the name unary-stream method. - """ + Returns: + A UnaryStreamMultiCallable value for the name unary-stream method. + """ raise NotImplementedError() @abc.abstractmethod def stream_unary(self, group, method): """Creates a StreamUnaryMultiCallable for a stream-unary method. - Args: - group: The group identifier of the RPC. - method: The method identifier of the RPC. + Args: + group: The group identifier of the RPC. + method: The method identifier of the RPC. - Returns: - A StreamUnaryMultiCallable value for the named stream-unary method. - """ + Returns: + A StreamUnaryMultiCallable value for the named stream-unary method. + """ raise NotImplementedError() @abc.abstractmethod def stream_stream(self, group, method): """Creates a StreamStreamMultiCallable for a stream-stream method. - Args: - group: The group identifier of the RPC. - method: The method identifier of the RPC. + Args: + group: The group identifier of the RPC. + method: The method identifier of the RPC. - Returns: - A StreamStreamMultiCallable value for the named stream-stream method. - """ + Returns: + A StreamStreamMultiCallable value for the named stream-stream method. + """ raise NotImplementedError() class DynamicStub(abc.ABC): """Affords RPC invocation via attributes corresponding to afforded methods. - Instances of this type may be scoped to a single group so that attribute - access is unambiguous. 
- - Instances of this type respond to attribute access as follows: if the - requested attribute is the name of a unary-unary method, the value of the - attribute will be a UnaryUnaryMultiCallable with which to invoke an RPC; if - the requested attribute is the name of a unary-stream method, the value of the - attribute will be a UnaryStreamMultiCallable with which to invoke an RPC; if - the requested attribute is the name of a stream-unary method, the value of the - attribute will be a StreamUnaryMultiCallable with which to invoke an RPC; and - if the requested attribute is the name of a stream-stream method, the value of - the attribute will be a StreamStreamMultiCallable with which to invoke an RPC. - """ + Instances of this type may be scoped to a single group so that attribute + access is unambiguous. + + Instances of this type respond to attribute access as follows: if the + requested attribute is the name of a unary-unary method, the value of the + attribute will be a UnaryUnaryMultiCallable with which to invoke an RPC; if + the requested attribute is the name of a unary-stream method, the value of the + attribute will be a UnaryStreamMultiCallable with which to invoke an RPC; if + the requested attribute is the name of a stream-unary method, the value of the + attribute will be a StreamUnaryMultiCallable with which to invoke an RPC; and + if the requested attribute is the name of a stream-stream method, the value of + the attribute will be a StreamStreamMultiCallable with which to invoke an RPC. + """ diff --git a/src/python/grpcio/grpc/framework/interfaces/face/utilities.py b/src/python/grpcio/grpc/framework/interfaces/face/utilities.py index f27bd676155f6..01807a1602683 100644 --- a/src/python/grpcio/grpc/framework/interfaces/face/utilities.py +++ b/src/python/grpcio/grpc/framework/interfaces/face/utilities.py @@ -22,147 +22,224 @@ from grpc.framework.interfaces.face import face -class _MethodImplementation(face.MethodImplementation, - collections.namedtuple('_MethodImplementation', [ - 'cardinality', - 'style', - 'unary_unary_inline', - 'unary_stream_inline', - 'stream_unary_inline', - 'stream_stream_inline', - 'unary_unary_event', - 'unary_stream_event', - 'stream_unary_event', - 'stream_stream_event', - ])): +class _MethodImplementation( + face.MethodImplementation, + collections.namedtuple( + "_MethodImplementation", + [ + "cardinality", + "style", + "unary_unary_inline", + "unary_stream_inline", + "stream_unary_inline", + "stream_stream_inline", + "unary_unary_event", + "unary_stream_event", + "stream_unary_event", + "stream_stream_event", + ], + ), +): pass def unary_unary_inline(behavior): """Creates an face.MethodImplementation for the given behavior. - Args: - behavior: The implementation of a unary-unary RPC method as a callable value - that takes a request value and an face.ServicerContext object and - returns a response value. - - Returns: - An face.MethodImplementation derived from the given behavior. - """ - return _MethodImplementation(cardinality.Cardinality.UNARY_UNARY, - style.Service.INLINE, behavior, None, None, - None, None, None, None, None) + Args: + behavior: The implementation of a unary-unary RPC method as a callable value + that takes a request value and an face.ServicerContext object and + returns a response value. + + Returns: + An face.MethodImplementation derived from the given behavior. 
+ """ + return _MethodImplementation( + cardinality.Cardinality.UNARY_UNARY, + style.Service.INLINE, + behavior, + None, + None, + None, + None, + None, + None, + None, + ) def unary_stream_inline(behavior): """Creates an face.MethodImplementation for the given behavior. - Args: - behavior: The implementation of a unary-stream RPC method as a callable - value that takes a request value and an face.ServicerContext object and - returns an iterator of response values. - - Returns: - An face.MethodImplementation derived from the given behavior. - """ - return _MethodImplementation(cardinality.Cardinality.UNARY_STREAM, - style.Service.INLINE, None, behavior, None, - None, None, None, None, None) + Args: + behavior: The implementation of a unary-stream RPC method as a callable + value that takes a request value and an face.ServicerContext object and + returns an iterator of response values. + + Returns: + An face.MethodImplementation derived from the given behavior. + """ + return _MethodImplementation( + cardinality.Cardinality.UNARY_STREAM, + style.Service.INLINE, + None, + behavior, + None, + None, + None, + None, + None, + None, + ) def stream_unary_inline(behavior): """Creates an face.MethodImplementation for the given behavior. - Args: - behavior: The implementation of a stream-unary RPC method as a callable - value that takes an iterator of request values and an - face.ServicerContext object and returns a response value. - - Returns: - An face.MethodImplementation derived from the given behavior. - """ - return _MethodImplementation(cardinality.Cardinality.STREAM_UNARY, - style.Service.INLINE, None, None, behavior, - None, None, None, None, None) + Args: + behavior: The implementation of a stream-unary RPC method as a callable + value that takes an iterator of request values and an + face.ServicerContext object and returns a response value. + + Returns: + An face.MethodImplementation derived from the given behavior. + """ + return _MethodImplementation( + cardinality.Cardinality.STREAM_UNARY, + style.Service.INLINE, + None, + None, + behavior, + None, + None, + None, + None, + None, + ) def stream_stream_inline(behavior): """Creates an face.MethodImplementation for the given behavior. - Args: - behavior: The implementation of a stream-stream RPC method as a callable - value that takes an iterator of request values and an - face.ServicerContext object and returns an iterator of response values. - - Returns: - An face.MethodImplementation derived from the given behavior. - """ - return _MethodImplementation(cardinality.Cardinality.STREAM_STREAM, - style.Service.INLINE, None, None, None, - behavior, None, None, None, None) + Args: + behavior: The implementation of a stream-stream RPC method as a callable + value that takes an iterator of request values and an + face.ServicerContext object and returns an iterator of response values. + + Returns: + An face.MethodImplementation derived from the given behavior. + """ + return _MethodImplementation( + cardinality.Cardinality.STREAM_STREAM, + style.Service.INLINE, + None, + None, + None, + behavior, + None, + None, + None, + None, + ) def unary_unary_event(behavior): """Creates an face.MethodImplementation for the given behavior. - Args: - behavior: The implementation of a unary-unary RPC method as a callable - value that takes a request value, a response callback to which to pass - the response value of the RPC, and an face.ServicerContext. - - Returns: - An face.MethodImplementation derived from the given behavior. 
- """ - return _MethodImplementation(cardinality.Cardinality.UNARY_UNARY, - style.Service.EVENT, None, None, None, None, - behavior, None, None, None) + Args: + behavior: The implementation of a unary-unary RPC method as a callable + value that takes a request value, a response callback to which to pass + the response value of the RPC, and an face.ServicerContext. + + Returns: + An face.MethodImplementation derived from the given behavior. + """ + return _MethodImplementation( + cardinality.Cardinality.UNARY_UNARY, + style.Service.EVENT, + None, + None, + None, + None, + behavior, + None, + None, + None, + ) def unary_stream_event(behavior): """Creates an face.MethodImplementation for the given behavior. - Args: - behavior: The implementation of a unary-stream RPC method as a callable - value that takes a request value, a stream.Consumer to which to pass the - the response values of the RPC, and an face.ServicerContext. - - Returns: - An face.MethodImplementation derived from the given behavior. - """ - return _MethodImplementation(cardinality.Cardinality.UNARY_STREAM, - style.Service.EVENT, None, None, None, None, - None, behavior, None, None) + Args: + behavior: The implementation of a unary-stream RPC method as a callable + value that takes a request value, a stream.Consumer to which to pass the + the response values of the RPC, and an face.ServicerContext. + + Returns: + An face.MethodImplementation derived from the given behavior. + """ + return _MethodImplementation( + cardinality.Cardinality.UNARY_STREAM, + style.Service.EVENT, + None, + None, + None, + None, + None, + behavior, + None, + None, + ) def stream_unary_event(behavior): """Creates an face.MethodImplementation for the given behavior. - Args: - behavior: The implementation of a stream-unary RPC method as a callable - value that takes a response callback to which to pass the response value - of the RPC and an face.ServicerContext and returns a stream.Consumer to - which the request values of the RPC should be passed. - - Returns: - An face.MethodImplementation derived from the given behavior. - """ - return _MethodImplementation(cardinality.Cardinality.STREAM_UNARY, - style.Service.EVENT, None, None, None, None, - None, None, behavior, None) + Args: + behavior: The implementation of a stream-unary RPC method as a callable + value that takes a response callback to which to pass the response value + of the RPC and an face.ServicerContext and returns a stream.Consumer to + which the request values of the RPC should be passed. + + Returns: + An face.MethodImplementation derived from the given behavior. + """ + return _MethodImplementation( + cardinality.Cardinality.STREAM_UNARY, + style.Service.EVENT, + None, + None, + None, + None, + None, + None, + behavior, + None, + ) def stream_stream_event(behavior): """Creates an face.MethodImplementation for the given behavior. - Args: - behavior: The implementation of a stream-stream RPC method as a callable - value that takes a stream.Consumer to which to pass the response values - of the RPC and an face.ServicerContext and returns a stream.Consumer to - which the request values of the RPC should be passed. - - Returns: - An face.MethodImplementation derived from the given behavior. 
- """ - return _MethodImplementation(cardinality.Cardinality.STREAM_STREAM, - style.Service.EVENT, None, None, None, None, - None, None, None, behavior) + Args: + behavior: The implementation of a stream-stream RPC method as a callable + value that takes a stream.Consumer to which to pass the response values + of the RPC and an face.ServicerContext and returns a stream.Consumer to + which the request values of the RPC should be passed. + + Returns: + An face.MethodImplementation derived from the given behavior. + """ + return _MethodImplementation( + cardinality.Cardinality.STREAM_STREAM, + style.Service.EVENT, + None, + None, + None, + None, + None, + None, + None, + behavior, + ) diff --git a/src/python/grpcio/support.py b/src/python/grpcio/support.py index 3d64b3170ca48..6f0a59df7c19c 100644 --- a/src/python/grpcio/support.py +++ b/src/python/grpcio/support.py @@ -37,22 +37,23 @@ (check your environment variables or try re-installing?) """ if sys.version_info[0] == 2: - PYTHON_REPRESENTATION = 'python' + PYTHON_REPRESENTATION = "python" elif sys.version_info[0] == 3: - PYTHON_REPRESENTATION = 'python3' + PYTHON_REPRESENTATION = "python3" else: - raise NotImplementedError('Unsupported Python version: %s' % sys.version) + raise NotImplementedError("Unsupported Python version: %s" % sys.version) C_CHECKS = { - C_PYTHON_DEV: - C_PYTHON_DEV_ERROR_MESSAGE.replace('', PYTHON_REPRESENTATION), + C_PYTHON_DEV: C_PYTHON_DEV_ERROR_MESSAGE.replace( + "", PYTHON_REPRESENTATION + ), } def _compile(compiler, source_string): tempdir = tempfile.mkdtemp() - cpath = os.path.join(tempdir, 'a.c') - with open(cpath, 'w') as cfile: + cpath = os.path.join(tempdir, "a.c") + with open(cpath, "w") as cfile: cfile.write(source_string) try: compiler.compile([cpath]) @@ -67,7 +68,9 @@ def _expect_compile(compiler, source_string, error_message): sys.stderr.write(error_message) raise commands.CommandError( "Diagnostics found a compilation environment issue:\n{}".format( - error_message)) + error_message + ) + ) def diagnose_compile_error(build_ext, error): @@ -75,27 +78,32 @@ def diagnose_compile_error(build_ext, error): for c_check, message in C_CHECKS.items(): _expect_compile(build_ext.compiler, c_check, message) python_sources = [ - source for source in build_ext.get_source_files() - if source.startswith('./src/python') and source.endswith('c') + source + for source in build_ext.get_source_files() + if source.startswith("./src/python") and source.endswith("c") ] for source in python_sources: if not os.path.isfile(source): - raise commands.CommandError(( - "Diagnostics found a missing Python extension source file:\n{}\n\n" - "This is usually because the Cython sources haven't been transpiled " - "into C yet and you're building from source.\n" - "Try setting the environment variable " - "`GRPC_PYTHON_BUILD_WITH_CYTHON=1` when invoking `setup.py` or " - "when using `pip`, e.g.:\n\n" - "pip install -rrequirements.txt\n" - "GRPC_PYTHON_BUILD_WITH_CYTHON=1 pip install .").format(source)) + raise commands.CommandError( + ( + "Diagnostics found a missing Python extension source" + " file:\n{}\n\nThis is usually because the Cython sources" + " haven't been transpiled into C yet and you're building" + " from source.\nTry setting the environment variable" + " `GRPC_PYTHON_BUILD_WITH_CYTHON=1` when invoking" + " `setup.py` or when using `pip`, e.g.:\n\npip install" + " -rrequirements.txt\nGRPC_PYTHON_BUILD_WITH_CYTHON=1 pip" + " install ." 
+ ).format(source) + ) def diagnose_attribute_error(build_ext, error): - if any('_needs_stub' in arg for arg in error.args): + if any("_needs_stub" in arg for arg in error.args): raise commands.CommandError( - "We expect a missing `_needs_stub` attribute from older versions of " - "setuptools. Consider upgrading setuptools.") + "We expect a missing `_needs_stub` attribute from older versions of" + " setuptools. Consider upgrading setuptools." + ) _ERROR_DIAGNOSES = { @@ -108,10 +116,11 @@ def diagnose_build_ext_error(build_ext, error, formatted): diagnostic = _ERROR_DIAGNOSES.get(type(error)) if diagnostic is None: raise commands.CommandError( - "\n\nWe could not diagnose your build failure. If you are unable to " - "proceed, please file an issue at http://www.github.com/grpc/grpc " - "with `[Python install]` in the title; please attach the whole log " - "(including everything that may have appeared above the Python " - "backtrace).\n\n{}".format(formatted)) + "\n\nWe could not diagnose your build failure. If you are unable to" + " proceed, please file an issue at http://www.github.com/grpc/grpc" + " with `[Python install]` in the title; please attach the whole log" + " (including everything that may have appeared above the Python" + " backtrace).\n\n{}".format(formatted) + ) else: diagnostic(build_ext, error) diff --git a/src/python/grpcio_admin/grpc_admin/__init__.py b/src/python/grpcio_admin/grpc_admin/__init__.py index 95e70858b3b52..567c8fa359c53 100644 --- a/src/python/grpcio_admin/grpc_admin/__init__.py +++ b/src/python/grpcio_admin/grpc_admin/__init__.py @@ -25,7 +25,7 @@ def add_admin_servicers(server): a separate library, and the documentation of the predefined admin services is usually scattered. It can be time consuming to get the dependency management, module initialization, and library import right for each one of - them. + them. This API provides a convenient way to create a gRPC server to expose admin services. With this, any new admin services that you may add in the future @@ -39,4 +39,4 @@ def add_admin_servicers(server): grpc_csds.add_csds_servicer(server) -__all__ = ['add_admin_servicers'] +__all__ = ["add_admin_servicers"] diff --git a/src/python/grpcio_admin/setup.py b/src/python/grpcio_admin/setup.py index 2d966cdc0aa0b..b144b34dc86b6 100644 --- a/src/python/grpcio_admin/setup.py +++ b/src/python/grpcio_admin/setup.py @@ -19,7 +19,7 @@ import setuptools _PACKAGE_PATH = os.path.realpath(os.path.dirname(__file__)) -_README_PATH = os.path.join(_PACKAGE_PATH, 'README.rst') +_README_PATH = os.path.join(_PACKAGE_PATH, "README.rst") # Ensure we're in the proper directory whether or not we're being used by pip. 
os.chdir(os.path.dirname(os.path.abspath(__file__))) @@ -28,33 +28,35 @@ import grpc_version CLASSIFIERS = [ - 'Development Status :: 5 - Production/Stable', - 'Programming Language :: Python', - 'Programming Language :: Python :: 3', - 'License :: OSI Approved :: Apache Software License', + "Development Status :: 5 - Production/Stable", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", ] PACKAGE_DIRECTORIES = { - '': '.', + "": ".", } INSTALL_REQUIRES = ( - 'grpcio-channelz>={version}'.format(version=grpc_version.VERSION), - 'grpcio-csds>={version}'.format(version=grpc_version.VERSION), + "grpcio-channelz>={version}".format(version=grpc_version.VERSION), + "grpcio-csds>={version}".format(version=grpc_version.VERSION), ) SETUP_REQUIRES = INSTALL_REQUIRES -setuptools.setup(name='grpcio-admin', - version=grpc_version.VERSION, - license='Apache License 2.0', - description='a collection of admin services', - long_description=open(_README_PATH, 'r').read(), - author='The gRPC Authors', - author_email='grpc-io@googlegroups.com', - classifiers=CLASSIFIERS, - url='https://grpc.io', - package_dir=PACKAGE_DIRECTORIES, - packages=setuptools.find_packages('.'), - python_requires='>=3.6', - install_requires=INSTALL_REQUIRES, - setup_requires=SETUP_REQUIRES) +setuptools.setup( + name="grpcio-admin", + version=grpc_version.VERSION, + license="Apache License 2.0", + description="a collection of admin services", + long_description=open(_README_PATH, "r").read(), + author="The gRPC Authors", + author_email="grpc-io@googlegroups.com", + classifiers=CLASSIFIERS, + url="https://grpc.io", + package_dir=PACKAGE_DIRECTORIES, + packages=setuptools.find_packages("."), + python_requires=">=3.6", + install_requires=INSTALL_REQUIRES, + setup_requires=SETUP_REQUIRES, +) diff --git a/src/python/grpcio_channelz/channelz_commands.py b/src/python/grpcio_channelz/channelz_commands.py index dbbce2fda5aad..c42522a6ed09d 100644 --- a/src/python/grpcio_channelz/channelz_commands.py +++ b/src/python/grpcio_channelz/channelz_commands.py @@ -19,16 +19,17 @@ import setuptools ROOT_DIR = os.path.abspath(os.path.dirname(os.path.abspath(__file__))) -CHANNELZ_PROTO = os.path.join(ROOT_DIR, - '../../proto/grpc/channelz/channelz.proto') -LICENSE = os.path.join(ROOT_DIR, '../../../LICENSE') +CHANNELZ_PROTO = os.path.join( + ROOT_DIR, "../../proto/grpc/channelz/channelz.proto" +) +LICENSE = os.path.join(ROOT_DIR, "../../../LICENSE") class Preprocess(setuptools.Command): """Command to copy proto modules from grpc/src/proto and LICENSE from the root directory""" - description = '' + description = "" user_options = [] def initialize_options(self): @@ -41,15 +42,16 @@ def run(self): if os.path.isfile(CHANNELZ_PROTO): shutil.copyfile( CHANNELZ_PROTO, - os.path.join(ROOT_DIR, 'grpc_channelz/v1/channelz.proto')) + os.path.join(ROOT_DIR, "grpc_channelz/v1/channelz.proto"), + ) if os.path.isfile(LICENSE): - shutil.copyfile(LICENSE, os.path.join(ROOT_DIR, 'LICENSE')) + shutil.copyfile(LICENSE, os.path.join(ROOT_DIR, "LICENSE")) class BuildPackageProtos(setuptools.Command): """Command to generate project *_pb2.py modules from proto files.""" - description = 'build grpc protobuf modules' + description = "build grpc protobuf modules" user_options = [] def initialize_options(self): @@ -64,4 +66,5 @@ def run(self): # to `self.distribution.package_dir` (and get a key error if it's not # there). 
from grpc_tools import command - command.build_package_protos(self.distribution.package_dir['']) + + command.build_package_protos(self.distribution.package_dir[""]) diff --git a/src/python/grpcio_channelz/grpc_channelz/v1/_async.py b/src/python/grpcio_channelz/grpc_channelz/v1/_async.py index 463f5e14dcf14..47f3d6963f4a3 100644 --- a/src/python/grpcio_channelz/grpc_channelz/v1/_async.py +++ b/src/python/grpcio_channelz/grpc_channelz/v1/_async.py @@ -14,8 +14,7 @@ """AsyncIO version of Channelz servicer.""" from grpc.experimental import aio -from grpc_channelz.v1._servicer import \ - ChannelzServicer as _SyncChannelzServicer +from grpc_channelz.v1._servicer import ChannelzServicer as _SyncChannelzServicer import grpc_channelz.v1.channelz_pb2 as _channelz_pb2 import grpc_channelz.v1.channelz_pb2_grpc as _channelz_pb2_grpc @@ -25,45 +24,45 @@ class ChannelzServicer(_channelz_pb2_grpc.ChannelzServicer): @staticmethod async def GetTopChannels( - request: _channelz_pb2.GetTopChannelsRequest, - context: aio.ServicerContext + request: _channelz_pb2.GetTopChannelsRequest, + context: aio.ServicerContext, ) -> _channelz_pb2.GetTopChannelsResponse: return _SyncChannelzServicer.GetTopChannels(request, context) @staticmethod async def GetServers( - request: _channelz_pb2.GetServersRequest, - context: aio.ServicerContext) -> _channelz_pb2.GetServersResponse: + request: _channelz_pb2.GetServersRequest, context: aio.ServicerContext + ) -> _channelz_pb2.GetServersResponse: return _SyncChannelzServicer.GetServers(request, context) @staticmethod async def GetServer( - request: _channelz_pb2.GetServerRequest, - context: aio.ServicerContext) -> _channelz_pb2.GetServerResponse: + request: _channelz_pb2.GetServerRequest, context: aio.ServicerContext + ) -> _channelz_pb2.GetServerResponse: return _SyncChannelzServicer.GetServer(request, context) @staticmethod async def GetServerSockets( - request: _channelz_pb2.GetServerSocketsRequest, - context: aio.ServicerContext + request: _channelz_pb2.GetServerSocketsRequest, + context: aio.ServicerContext, ) -> _channelz_pb2.GetServerSocketsResponse: return _SyncChannelzServicer.GetServerSockets(request, context) @staticmethod async def GetChannel( - request: _channelz_pb2.GetChannelRequest, - context: aio.ServicerContext) -> _channelz_pb2.GetChannelResponse: + request: _channelz_pb2.GetChannelRequest, context: aio.ServicerContext + ) -> _channelz_pb2.GetChannelResponse: return _SyncChannelzServicer.GetChannel(request, context) @staticmethod async def GetSubchannel( - request: _channelz_pb2.GetSubchannelRequest, - context: aio.ServicerContext + request: _channelz_pb2.GetSubchannelRequest, + context: aio.ServicerContext, ) -> _channelz_pb2.GetSubchannelResponse: return _SyncChannelzServicer.GetSubchannel(request, context) @staticmethod async def GetSocket( - request: _channelz_pb2.GetSocketRequest, - context: aio.ServicerContext) -> _channelz_pb2.GetSocketResponse: + request: _channelz_pb2.GetSocketRequest, context: aio.ServicerContext + ) -> _channelz_pb2.GetSocketResponse: return _SyncChannelzServicer.GetSocket(request, context) diff --git a/src/python/grpcio_channelz/grpc_channelz/v1/_servicer.py b/src/python/grpcio_channelz/grpc_channelz/v1/_servicer.py index 2d44976ec1d2c..b167bef044745 100644 --- a/src/python/grpcio_channelz/grpc_channelz/v1/_servicer.py +++ b/src/python/grpcio_channelz/grpc_channelz/v1/_servicer.py @@ -63,9 +63,11 @@ def GetServer(request, context): def GetServerSockets(request, context): try: return json_format.Parse( - 
cygrpc.channelz_get_server_sockets(request.server_id, - request.start_socket_id, - request.max_results), + cygrpc.channelz_get_server_sockets( + request.server_id, + request.start_socket_id, + request.max_results, + ), _channelz_pb2.GetServerSocketsResponse(), ) except ValueError as e: diff --git a/src/python/grpcio_channelz/grpc_channelz/v1/channelz.py b/src/python/grpcio_channelz/grpc_channelz/v1/channelz.py index 605150b79a7a8..bfd5851096973 100644 --- a/src/python/grpcio_channelz/grpc_channelz/v1/channelz.py +++ b/src/python/grpcio_channelz/grpc_channelz/v1/channelz.py @@ -43,13 +43,14 @@ from grpc_channelz.v1 import _async as aio def add_channelz_servicer(server): - if isinstance(server, grpc.experimental.aio.Server): _channelz_pb2_grpc.add_ChannelzServicer_to_server( - aio.ChannelzServicer(), server) + aio.ChannelzServicer(), server + ) else: _channelz_pb2_grpc.add_ChannelzServicer_to_server( - ChannelzServicer(), server) + ChannelzServicer(), server + ) add_channelz_servicer.__doc__ = _add_channelz_servicer_doc @@ -63,7 +64,8 @@ def add_channelz_servicer(server): def add_channelz_servicer(server): _channelz_pb2_grpc.add_ChannelzServicer_to_server( - ChannelzServicer(), server) + ChannelzServicer(), server + ) add_channelz_servicer.__doc__ = _add_channelz_servicer_doc diff --git a/src/python/grpcio_channelz/setup.py b/src/python/grpcio_channelz/setup.py index f0a9c6168ee05..7a7ce0dfba2b0 100644 --- a/src/python/grpcio_channelz/setup.py +++ b/src/python/grpcio_channelz/setup.py @@ -19,7 +19,7 @@ import setuptools _PACKAGE_PATH = os.path.realpath(os.path.dirname(__file__)) -_README_PATH = os.path.join(_PACKAGE_PATH, 'README.rst') +_README_PATH = os.path.join(_PACKAGE_PATH, "README.rst") # Ensure we're in the proper directory whether or not we're being used by pip. 
os.chdir(os.path.dirname(os.path.abspath(__file__))) @@ -31,7 +31,7 @@ class _NoOpCommand(setuptools.Command): """No-op command.""" - description = '' + description = "" user_options = [] def initialize_options(self): @@ -45,61 +45,63 @@ def run(self): CLASSIFIERS = [ - 'Development Status :: 5 - Production/Stable', - 'Programming Language :: Python', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.4', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', - 'License :: OSI Approved :: Apache Software License', + "Development Status :: 5 - Production/Stable", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.4", + "Programming Language :: Python :: 3.5", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "License :: OSI Approved :: Apache Software License", ] PACKAGE_DIRECTORIES = { - '': '.', + "": ".", } INSTALL_REQUIRES = ( - 'protobuf>=4.21.6', - 'grpcio>={version}'.format(version=grpc_version.VERSION), + "protobuf>=4.21.6", + "grpcio>={version}".format(version=grpc_version.VERSION), ) try: import channelz_commands as _channelz_commands # we are in the build environment, otherwise the above import fails - SETUP_REQUIRES = ('grpcio-tools=={version}'.format( - version=grpc_version.VERSION),) + SETUP_REQUIRES = ( + "grpcio-tools=={version}".format(version=grpc_version.VERSION), + ) COMMAND_CLASS = { # Run preprocess from the repository *before* doing any packaging! 
- 'preprocess': _channelz_commands.Preprocess, - 'build_package_protos': _channelz_commands.BuildPackageProtos, + "preprocess": _channelz_commands.Preprocess, + "build_package_protos": _channelz_commands.BuildPackageProtos, } except ImportError: SETUP_REQUIRES = () COMMAND_CLASS = { # wire up commands to no-op not to break the external dependencies - 'preprocess': _NoOpCommand, - 'build_package_protos': _NoOpCommand, + "preprocess": _NoOpCommand, + "build_package_protos": _NoOpCommand, } setuptools.setup( - name='grpcio-channelz', + name="grpcio-channelz", version=grpc_version.VERSION, - license='Apache License 2.0', - description='Channel Level Live Debug Information Service for gRPC', - long_description=open(_README_PATH, 'r').read(), - author='The gRPC Authors', - author_email='grpc-io@googlegroups.com', + license="Apache License 2.0", + description="Channel Level Live Debug Information Service for gRPC", + long_description=open(_README_PATH, "r").read(), + author="The gRPC Authors", + author_email="grpc-io@googlegroups.com", classifiers=CLASSIFIERS, - url='https://grpc.io', + url="https://grpc.io", package_dir=PACKAGE_DIRECTORIES, - packages=setuptools.find_packages('.'), - python_requires='>=3.6', + packages=setuptools.find_packages("."), + python_requires=">=3.6", install_requires=INSTALL_REQUIRES, setup_requires=SETUP_REQUIRES, - cmdclass=COMMAND_CLASS) + cmdclass=COMMAND_CLASS, +) diff --git a/src/python/grpcio_csds/grpc_csds/__init__.py b/src/python/grpcio_csds/grpc_csds/__init__.py index 118ba3e319c81..a2c0c14480058 100644 --- a/src/python/grpcio_csds/grpc_csds/__init__.py +++ b/src/python/grpcio_csds/grpc_csds/__init__.py @@ -20,13 +20,15 @@ class ClientStatusDiscoveryServiceServicer( - csds_pb2_grpc.ClientStatusDiscoveryServiceServicer): + csds_pb2_grpc.ClientStatusDiscoveryServiceServicer +): """CSDS Servicer works for both the sync API and asyncio API.""" @staticmethod def FetchClientStatus(request, unused_context): client_config = csds_pb2.ClientConfig.FromString( - cygrpc.dump_xds_configs()) + cygrpc.dump_xds_configs() + ) response = csds_pb2.ClientStatusResponse() response.config.append(client_config) return response @@ -35,7 +37,8 @@ def FetchClientStatus(request, unused_context): def StreamClientStatus(request_iterator, context): for request in request_iterator: yield ClientStatusDiscoveryServiceServicer.FetchClientStatus( - request, context) + request, context + ) def add_csds_servicer(server): @@ -50,7 +53,8 @@ def add_csds_servicer(server): server: A gRPC server to which the CSDS service will be added. """ csds_pb2_grpc.add_ClientStatusDiscoveryServiceServicer_to_server( - ClientStatusDiscoveryServiceServicer(), server) + ClientStatusDiscoveryServiceServicer(), server + ) -__all__ = ['ClientStatusDiscoveryServiceServicer', 'add_csds_servicer'] +__all__ = ["ClientStatusDiscoveryServiceServicer", "add_csds_servicer"] diff --git a/src/python/grpcio_csds/setup.py b/src/python/grpcio_csds/setup.py index 6523648516b6e..120f127e7ae3d 100644 --- a/src/python/grpcio_csds/setup.py +++ b/src/python/grpcio_csds/setup.py @@ -19,7 +19,7 @@ import setuptools _PACKAGE_PATH = os.path.realpath(os.path.dirname(__file__)) -_README_PATH = os.path.join(_PACKAGE_PATH, 'README.rst') +_README_PATH = os.path.join(_PACKAGE_PATH, "README.rst") # Ensure we're in the proper directory whether or not we're being used by pip. 
os.chdir(os.path.dirname(os.path.abspath(__file__))) @@ -28,34 +28,36 @@ import grpc_version CLASSIFIERS = [ - 'Development Status :: 5 - Production/Stable', - 'Programming Language :: Python', - 'Programming Language :: Python :: 3', - 'License :: OSI Approved :: Apache Software License', + "Development Status :: 5 - Production/Stable", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", ] PACKAGE_DIRECTORIES = { - '': '.', + "": ".", } INSTALL_REQUIRES = ( - 'protobuf>=4.21.6', - 'xds-protos>=0.0.7', - 'grpcio>={version}'.format(version=grpc_version.VERSION), + "protobuf>=4.21.6", + "xds-protos>=0.0.7", + "grpcio>={version}".format(version=grpc_version.VERSION), ) SETUP_REQUIRES = INSTALL_REQUIRES -setuptools.setup(name='grpcio-csds', - version=grpc_version.VERSION, - license='Apache License 2.0', - description='xDS configuration dump library', - long_description=open(_README_PATH, 'r').read(), - author='The gRPC Authors', - author_email='grpc-io@googlegroups.com', - classifiers=CLASSIFIERS, - url='https://grpc.io', - package_dir=PACKAGE_DIRECTORIES, - packages=setuptools.find_packages('.'), - python_requires='>=3.6', - install_requires=INSTALL_REQUIRES, - setup_requires=SETUP_REQUIRES) +setuptools.setup( + name="grpcio-csds", + version=grpc_version.VERSION, + license="Apache License 2.0", + description="xDS configuration dump library", + long_description=open(_README_PATH, "r").read(), + author="The gRPC Authors", + author_email="grpc-io@googlegroups.com", + classifiers=CLASSIFIERS, + url="https://grpc.io", + package_dir=PACKAGE_DIRECTORIES, + packages=setuptools.find_packages("."), + python_requires=">=3.6", + install_requires=INSTALL_REQUIRES, + setup_requires=SETUP_REQUIRES, +) diff --git a/src/python/grpcio_health_checking/grpc_health/v1/_async.py b/src/python/grpcio_health_checking/grpc_health/v1/_async.py index b56a945c61424..3788050f21b32 100644 --- a/src/python/grpcio_health_checking/grpc_health/v1/_async.py +++ b/src/python/grpcio_health_checking/grpc_health/v1/_async.py @@ -24,8 +24,10 @@ class HealthServicer(_health_pb2_grpc.HealthServicer): """An AsyncIO implementation of health checking servicer.""" + _server_status: MutableMapping[ - str, '_health_pb2.HealthCheckResponse.ServingStatus'] + str, "_health_pb2.HealthCheckResponse.ServingStatus" + ] _server_watchers: MutableMapping[str, asyncio.Condition] _gracefully_shutting_down: bool @@ -34,8 +36,9 @@ def __init__(self) -> None: self._server_watchers = collections.defaultdict(asyncio.Condition) self._gracefully_shutting_down = False - async def Check(self, request: _health_pb2.HealthCheckRequest, - context) -> None: + async def Check( + self, request: _health_pb2.HealthCheckRequest, context + ) -> None: status = self._server_status.get(request.service) if status is None: @@ -43,8 +46,9 @@ async def Check(self, request: _health_pb2.HealthCheckRequest, else: return _health_pb2.HealthCheckResponse(status=status) - async def Watch(self, request: _health_pb2.HealthCheckRequest, - context) -> None: + async def Watch( + self, request: _health_pb2.HealthCheckRequest, context + ) -> None: condition = self._server_watchers[request.service] last_status = None try: @@ -52,7 +56,8 @@ async def Watch(self, request: _health_pb2.HealthCheckRequest, while True: status = self._server_status.get( request.service, - _health_pb2.HealthCheckResponse.SERVICE_UNKNOWN) + _health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, + ) # NOTE(lidiz) If the observed status is the 
same, it means # there are missing intermediate statuses. It's considered @@ -60,7 +65,8 @@ async def Watch(self, request: _health_pb2.HealthCheckRequest, if status != last_status: # Responds with current health state await context.write( - _health_pb2.HealthCheckResponse(status=status)) + _health_pb2.HealthCheckResponse(status=status) + ) # Records the last sent status last_status = status @@ -72,8 +78,10 @@ async def Watch(self, request: _health_pb2.HealthCheckRequest, del self._server_watchers[request.service] async def _set( - self, service: str, - status: _health_pb2.HealthCheckResponse.ServingStatus) -> None: + self, + service: str, + status: _health_pb2.HealthCheckResponse.ServingStatus, + ) -> None: if service in self._server_watchers: condition = self._server_watchers.get(service) async with condition: @@ -83,8 +91,10 @@ async def _set( self._server_status[service] = status async def set( - self, service: str, - status: _health_pb2.HealthCheckResponse.ServingStatus) -> None: + self, + service: str, + status: _health_pb2.HealthCheckResponse.ServingStatus, + ) -> None: """Sets the status of a service. Args: @@ -109,5 +119,6 @@ async def enter_graceful_shutdown(self) -> None: else: self._gracefully_shutting_down = True for service in self._server_status: - await self._set(service, - _health_pb2.HealthCheckResponse.NOT_SERVING) + await self._set( + service, _health_pb2.HealthCheckResponse.NOT_SERVING + ) diff --git a/src/python/grpcio_health_checking/grpc_health/v1/health.py b/src/python/grpcio_health_checking/grpc_health/v1/health.py index e5f0e032134b5..db8756fa78a67 100644 --- a/src/python/grpcio_health_checking/grpc_health/v1/health.py +++ b/src/python/grpcio_health_checking/grpc_health/v1/health.py @@ -26,13 +26,12 @@ from . import _async as aio # pylint: disable=unused-import # The service name of the health checking servicer. -SERVICE_NAME = _health_pb2.DESCRIPTOR.services_by_name['Health'].full_name +SERVICE_NAME = _health_pb2.DESCRIPTOR.services_by_name["Health"].full_name # The entry of overall health for the entire server. -OVERALL_HEALTH = '' +OVERALL_HEALTH = "" -class _Watcher(): - +class _Watcher: def __init__(self): self._condition = threading.Condition() self._responses = collections.deque() @@ -68,7 +67,6 @@ def close(self): def _watcher_to_send_response_callback_adapter(watcher): - def send_response_callback(response): if response is None: watcher.close() @@ -81,22 +79,24 @@ def send_response_callback(response): class HealthServicer(_health_pb2_grpc.HealthServicer): """Servicer handling RPCs for service statuses.""" - def __init__(self, - experimental_non_blocking=True, - experimental_thread_pool=None): + def __init__( + self, experimental_non_blocking=True, experimental_thread_pool=None + ): self._lock = threading.RLock() self._server_status = {"": _health_pb2.HealthCheckResponse.SERVING} self._send_response_callbacks = {} - self.Watch.__func__.experimental_non_blocking = experimental_non_blocking + self.Watch.__func__.experimental_non_blocking = ( + experimental_non_blocking + ) self.Watch.__func__.experimental_thread_pool = experimental_thread_pool self._gracefully_shutting_down = False def _on_close_callback(self, send_response_callback, service): - def callback(): with self._lock: self._send_response_callbacks[service].remove( - send_response_callback) + send_response_callback + ) send_response_callback(None) return callback @@ -119,19 +119,24 @@ def Watch(self, request, context, send_response_callback=None): # generator. 
blocking_watcher = _Watcher() send_response_callback = _watcher_to_send_response_callback_adapter( - blocking_watcher) + blocking_watcher + ) service = request.service with self._lock: status = self._server_status.get(service) if status is None: - status = _health_pb2.HealthCheckResponse.SERVICE_UNKNOWN # pylint: disable=no-member + status = ( + _health_pb2.HealthCheckResponse.SERVICE_UNKNOWN + ) # pylint: disable=no-member send_response_callback( - _health_pb2.HealthCheckResponse(status=status)) + _health_pb2.HealthCheckResponse(status=status) + ) if service not in self._send_response_callbacks: self._send_response_callbacks[service] = set() self._send_response_callbacks[service].add(send_response_callback) context.add_callback( - self._on_close_callback(send_response_callback, service)) + self._on_close_callback(send_response_callback, service) + ) return blocking_watcher def set(self, service, status): @@ -149,9 +154,11 @@ def set(self, service, status): self._server_status[service] = status if service in self._send_response_callbacks: for send_response_callback in self._send_response_callbacks[ - service]: + service + ]: send_response_callback( - _health_pb2.HealthCheckResponse(status=status)) + _health_pb2.HealthCheckResponse(status=status) + ) def enter_graceful_shutdown(self): """Permanently sets the status of all services to NOT_SERVING. @@ -167,6 +174,7 @@ def enter_graceful_shutdown(self): return else: for service in self._server_status: - self.set(service, - _health_pb2.HealthCheckResponse.NOT_SERVING) # pylint: disable=no-member + self.set( + service, _health_pb2.HealthCheckResponse.NOT_SERVING + ) # pylint: disable=no-member self._gracefully_shutting_down = True diff --git a/src/python/grpcio_health_checking/health_commands.py b/src/python/grpcio_health_checking/health_commands.py index 874dec7343a98..74df84ad7bf68 100644 --- a/src/python/grpcio_health_checking/health_commands.py +++ b/src/python/grpcio_health_checking/health_commands.py @@ -19,15 +19,15 @@ import setuptools ROOT_DIR = os.path.abspath(os.path.dirname(os.path.abspath(__file__))) -HEALTH_PROTO = os.path.join(ROOT_DIR, '../../proto/grpc/health/v1/health.proto') -LICENSE = os.path.join(ROOT_DIR, '../../../LICENSE') +HEALTH_PROTO = os.path.join(ROOT_DIR, "../../proto/grpc/health/v1/health.proto") +LICENSE = os.path.join(ROOT_DIR, "../../../LICENSE") class Preprocess(setuptools.Command): """Command to copy proto modules from grpc/src/proto and LICENSE from the root directory""" - description = '' + description = "" user_options = [] def initialize_options(self): @@ -40,15 +40,16 @@ def run(self): if os.path.isfile(HEALTH_PROTO): shutil.copyfile( HEALTH_PROTO, - os.path.join(ROOT_DIR, 'grpc_health/v1/health.proto')) + os.path.join(ROOT_DIR, "grpc_health/v1/health.proto"), + ) if os.path.isfile(LICENSE): - shutil.copyfile(LICENSE, os.path.join(ROOT_DIR, 'LICENSE')) + shutil.copyfile(LICENSE, os.path.join(ROOT_DIR, "LICENSE")) class BuildPackageProtos(setuptools.Command): """Command to generate project *_pb2.py modules from proto files.""" - description = 'build grpc protobuf modules' + description = "build grpc protobuf modules" user_options = [] def initialize_options(self): @@ -63,4 +64,5 @@ def run(self): # to `self.distribution.package_dir` (and get a key error if it's not # there). 
from grpc_tools import command - command.build_package_protos(self.distribution.package_dir['']) + + command.build_package_protos(self.distribution.package_dir[""]) diff --git a/src/python/grpcio_health_checking/setup.py b/src/python/grpcio_health_checking/setup.py index 6c4725bf0fc2e..446b2a10e1533 100644 --- a/src/python/grpcio_health_checking/setup.py +++ b/src/python/grpcio_health_checking/setup.py @@ -18,7 +18,7 @@ import setuptools _PACKAGE_PATH = os.path.realpath(os.path.dirname(__file__)) -_README_PATH = os.path.join(_PACKAGE_PATH, 'README.rst') +_README_PATH = os.path.join(_PACKAGE_PATH, "README.rst") # Ensure we're in the proper directory whether or not we're being used by pip. os.chdir(os.path.dirname(os.path.abspath(__file__))) @@ -30,7 +30,7 @@ class _NoOpCommand(setuptools.Command): """No-op command.""" - description = '' + description = "" user_options = [] def initialize_options(self): @@ -44,60 +44,63 @@ def run(self): CLASSIFIERS = [ - 'Development Status :: 5 - Production/Stable', - 'Programming Language :: Python', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.4', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', - 'License :: OSI Approved :: Apache Software License', + "Development Status :: 5 - Production/Stable", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.4", + "Programming Language :: Python :: 3.5", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "License :: OSI Approved :: Apache Software License", ] PACKAGE_DIRECTORIES = { - '': '.', + "": ".", } INSTALL_REQUIRES = ( - 'protobuf>=4.21.6', - 'grpcio>={version}'.format(version=grpc_version.VERSION), + "protobuf>=4.21.6", + "grpcio>={version}".format(version=grpc_version.VERSION), ) try: import health_commands as _health_commands # we are in the build environment, otherwise the above import fails - SETUP_REQUIRES = ('grpcio-tools=={version}'.format( - version=grpc_version.VERSION),) + SETUP_REQUIRES = ( + "grpcio-tools=={version}".format(version=grpc_version.VERSION), + ) COMMAND_CLASS = { # Run preprocess from the repository *before* doing any packaging! 
- 'preprocess': _health_commands.Preprocess, - 'build_package_protos': _health_commands.BuildPackageProtos, + "preprocess": _health_commands.Preprocess, + "build_package_protos": _health_commands.BuildPackageProtos, } except ImportError: SETUP_REQUIRES = () COMMAND_CLASS = { # wire up commands to no-op not to break the external dependencies - 'preprocess': _NoOpCommand, - 'build_package_protos': _NoOpCommand, + "preprocess": _NoOpCommand, + "build_package_protos": _NoOpCommand, } -setuptools.setup(name='grpcio-health-checking', - version=grpc_version.VERSION, - description='Standard Health Checking Service for gRPC', - long_description=open(_README_PATH, 'r').read(), - author='The gRPC Authors', - author_email='grpc-io@googlegroups.com', - url='https://grpc.io', - license='Apache License 2.0', - classifiers=CLASSIFIERS, - package_dir=PACKAGE_DIRECTORIES, - packages=setuptools.find_packages('.'), - python_requires='>=3.6', - install_requires=INSTALL_REQUIRES, - setup_requires=SETUP_REQUIRES, - cmdclass=COMMAND_CLASS) +setuptools.setup( + name="grpcio-health-checking", + version=grpc_version.VERSION, + description="Standard Health Checking Service for gRPC", + long_description=open(_README_PATH, "r").read(), + author="The gRPC Authors", + author_email="grpc-io@googlegroups.com", + url="https://grpc.io", + license="Apache License 2.0", + classifiers=CLASSIFIERS, + package_dir=PACKAGE_DIRECTORIES, + packages=setuptools.find_packages("."), + python_requires=">=3.6", + install_requires=INSTALL_REQUIRES, + setup_requires=SETUP_REQUIRES, + cmdclass=COMMAND_CLASS, +) diff --git a/src/python/grpcio_observability/grpc_observability/__init__.py b/src/python/grpcio_observability/grpc_observability/__init__.py index 909317be364ee..6a10a17fb0b26 100644 --- a/src/python/grpcio_observability/grpc_observability/__init__.py +++ b/src/python/grpcio_observability/grpc_observability/__init__.py @@ -14,4 +14,4 @@ from grpc_observability._gcp_observability import GCPOpenCensusObservability -__all__ = ('GCPOpenCensusObservability',) +__all__ = ("GCPOpenCensusObservability",) diff --git a/src/python/grpcio_observability/grpc_observability/_gcp_observability.py b/src/python/grpcio_observability/grpc_observability/_gcp_observability.py index 674676bf3b210..083060c95ab8a 100644 --- a/src/python/grpcio_observability/grpc_observability/_gcp_observability.py +++ b/src/python/grpcio_observability/grpc_observability/_gcp_observability.py @@ -30,7 +30,9 @@ _LOGGER = logging.getLogger(__name__) ClientCallTracerCapsule = Any # it appears only once in the function signature -ServerCallTracerFactoryCapsule = Any # it appears only once in the function signature +ServerCallTracerFactoryCapsule = ( + Any # it appears only once in the function signature +) grpc_observability = Any # grpc_observability.py imports this module. 
GRPC_STATUS_CODE_TO_STRING = { @@ -68,16 +70,19 @@ class GcpObservabilityPythonConfig: def get(): with GcpObservabilityPythonConfig._lock: if GcpObservabilityPythonConfig._singleton is None: - GcpObservabilityPythonConfig._singleton = GcpObservabilityPythonConfig( + GcpObservabilityPythonConfig._singleton = ( + GcpObservabilityPythonConfig() ) return GcpObservabilityPythonConfig._singleton - def set_configuration(self, - project_id: str, - sampling_rate: Optional[float] = 0.0, - labels: Optional[Mapping[str, str]] = None, - tracing_enabled: bool = False, - stats_enabled: bool = False) -> None: + def set_configuration( + self, + project_id: str, + sampling_rate: Optional[float] = 0.0, + labels: Optional[Mapping[str, str]] = None, + tracing_enabled: bool = False, + stats_enabled: bool = False, + ) -> None: self.project_id = project_id self.stats_enabled = stats_enabled self.tracing_enabled = tracing_enabled @@ -99,6 +104,7 @@ class GCPOpenCensusObservability(grpc._observability.ObservabilityPlugin): config: Configuration for GCP OpenCensus Observability. exporter: Exporter used to export data. """ + config: GcpObservabilityPythonConfig exporter: "grpc_observability.Exporter" @@ -110,7 +116,8 @@ def __init__(self, exporter: "grpc_observability.Exporter" = None): else: self.exporter = OpenCensusExporter(self.config.get().labels) config_valid = _cyobservability.set_gcp_observability_config( - self.config) + self.config + ) if not config_valid: raise ValueError("Invalid configuration") @@ -122,7 +129,7 @@ def __init__(self, exporter: "grpc_observability.Exporter" = None): def __enter__(self): try: _cyobservability.cyobservability_init(self.exporter) - #TODO(xuanwn): Use specific exceptons + # TODO(xuanwn): Use specific exceptons except Exception as e: # pylint: disable=broad-except _LOGGER.exception("GCPOpenCensusObservability failed with: %s", e) @@ -147,40 +154,49 @@ def exit(self) -> None: grpc._observability.observability_deinit() def create_client_call_tracer( - self, method_name: bytes) -> ClientCallTracerCapsule: + self, method_name: bytes + ) -> ClientCallTracerCapsule: current_span = execution_context.get_current_span() if current_span: # Propagate existing OC context - trace_id = current_span.context_tracer.trace_id.encode('utf8') - parent_span_id = current_span.span_id.encode('utf8') + trace_id = current_span.context_tracer.trace_id.encode("utf8") + parent_span_id = current_span.span_id.encode("utf8") capsule = _cyobservability.create_client_call_tracer( - method_name, trace_id, parent_span_id) + method_name, trace_id, parent_span_id + ) else: - trace_id = span_context_module.generate_trace_id().encode('utf8') + trace_id = span_context_module.generate_trace_id().encode("utf8") capsule = _cyobservability.create_client_call_tracer( - method_name, trace_id) + method_name, trace_id + ) return capsule def create_server_call_tracer_factory( - self) -> ServerCallTracerFactoryCapsule: + self, + ) -> ServerCallTracerFactoryCapsule: capsule = _cyobservability.create_server_call_tracer_factory_capsule() return capsule def delete_client_call_tracer( - self, client_call_tracer: ClientCallTracerCapsule) -> None: + self, client_call_tracer: ClientCallTracerCapsule + ) -> None: _cyobservability.delete_client_call_tracer(client_call_tracer) - def save_trace_context(self, trace_id: str, span_id: str, - is_sampled: bool) -> None: + def save_trace_context( + self, trace_id: str, span_id: str, is_sampled: bool + ) -> None: trace_options = trace_options_module.TraceOptions(0) 
trace_options.set_enabled(is_sampled) span_context = span_context_module.SpanContext( - trace_id=trace_id, span_id=span_id, trace_options=trace_options) + trace_id=trace_id, span_id=span_id, trace_options=trace_options + ) current_tracer = execution_context.get_opencensus_tracer() current_tracer.span_context = span_context - def record_rpc_latency(self, method: str, rpc_latency: float, - status_code: grpc.StatusCode) -> None: + def record_rpc_latency( + self, method: str, rpc_latency: float, status_code: grpc.StatusCode + ) -> None: status_code = GRPC_STATUS_CODE_TO_STRING.get(status_code, "UNKNOWN") - _cyobservability._record_rpc_latency(self.exporter, method, rpc_latency, - status_code) + _cyobservability._record_rpc_latency( + self.exporter, method, rpc_latency, status_code + ) diff --git a/src/python/grpcio_observability/grpc_observability/_observability.py b/src/python/grpcio_observability/grpc_observability/_observability.py index 178d71f2d72be..1f0ea328f3f88 100644 --- a/src/python/grpcio_observability/grpc_observability/_observability.py +++ b/src/python/grpcio_observability/grpc_observability/_observability.py @@ -55,6 +55,7 @@ class StatsData: labels: A dictionary that maps label tags associated with this metric to corresponding label value. """ + name: "grpc_observability._cyobservability.MetricsName" measure_double: bool value_int: int = 0 @@ -88,6 +89,7 @@ class TracingData: description. The timeStamp have a format which can be converted to Python datetime.datetime, e.g. 2023-05-29 17:07:09.895 """ + name: str start_time: str end_time: str diff --git a/src/python/grpcio_observability/grpc_observability/_open_census_exporter.py b/src/python/grpcio_observability/grpc_observability/_open_census_exporter.py index a78d70770578b..bbd92fc8e3bf9 100644 --- a/src/python/grpcio_observability/grpc_observability/_open_census_exporter.py +++ b/src/python/grpcio_observability/grpc_observability/_open_census_exporter.py @@ -21,13 +21,14 @@ class OpenCensusExporter(_observability.Exporter): - - def export_stats_data(self, - stats_data: List[_observability.StatsData]) -> None: + def export_stats_data( + self, stats_data: List[_observability.StatsData] + ) -> None: # TODO(xuanwn): Add implementation raise NotImplementedError() def export_tracing_data( - self, tracing_data: List[_observability.TracingData]) -> None: + self, tracing_data: List[_observability.TracingData] + ) -> None: # TODO(xuanwn): Add implementation raise NotImplementedError() diff --git a/src/python/grpcio_reflection/grpc_reflection/v1alpha/_async.py b/src/python/grpcio_reflection/grpc_reflection/v1alpha/_async.py index 1806fe35bccfe..ef3a99abb3aef 100644 --- a/src/python/grpcio_reflection/grpc_reflection/v1alpha/_async.py +++ b/src/python/grpcio_reflection/grpc_reflection/v1alpha/_async.py @@ -24,31 +24,39 @@ class ReflectionServicer(BaseReflectionServicer): """Servicer handling RPCs for service statuses.""" async def ServerReflectionInfo( - self, request_iterator: AsyncIterable[ - _reflection_pb2.ServerReflectionRequest], unused_context + self, + request_iterator: AsyncIterable[ + _reflection_pb2.ServerReflectionRequest + ], + unused_context, ) -> AsyncIterable[_reflection_pb2.ServerReflectionResponse]: async for request in request_iterator: - if request.HasField('file_by_filename'): + if request.HasField("file_by_filename"): yield self._file_by_filename(request.file_by_filename) - elif request.HasField('file_containing_symbol'): + elif request.HasField("file_containing_symbol"): yield self._file_containing_symbol( - 
request.file_containing_symbol) - elif request.HasField('file_containing_extension'): + request.file_containing_symbol + ) + elif request.HasField("file_containing_extension"): yield self._file_containing_extension( request.file_containing_extension.containing_type, - request.file_containing_extension.extension_number) - elif request.HasField('all_extension_numbers_of_type'): + request.file_containing_extension.extension_number, + ) + elif request.HasField("all_extension_numbers_of_type"): yield self._all_extension_numbers_of_type( - request.all_extension_numbers_of_type) - elif request.HasField('list_services'): + request.all_extension_numbers_of_type + ) + elif request.HasField("list_services"): yield self._list_services() else: yield _reflection_pb2.ServerReflectionResponse( error_response=_reflection_pb2.ErrorResponse( error_code=grpc.StatusCode.INVALID_ARGUMENT.value[0], - error_message=grpc.StatusCode.INVALID_ARGUMENT.value[1]. - encode(), - )) + error_message=grpc.StatusCode.INVALID_ARGUMENT.value[ + 1 + ].encode(), + ) + ) __all__ = [ diff --git a/src/python/grpcio_reflection/grpc_reflection/v1alpha/_base.py b/src/python/grpcio_reflection/grpc_reflection/v1alpha/_base.py index ff8ed7c501d0f..e73abc9346ee2 100644 --- a/src/python/grpcio_reflection/grpc_reflection/v1alpha/_base.py +++ b/src/python/grpcio_reflection/grpc_reflection/v1alpha/_base.py @@ -27,7 +27,8 @@ def _not_found_error(): error_response=_reflection_pb2.ErrorResponse( error_code=grpc.StatusCode.NOT_FOUND.value[0], error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(), - )) + ) + ) def _collect_transitive_dependencies(descriptor, seen_files): @@ -52,7 +53,9 @@ def _file_descriptor_response(descriptor): return _reflection_pb2.ServerReflectionResponse( file_descriptor_response=_reflection_pb2.FileDescriptorResponse( - file_descriptor_proto=(serialized_proto_list)),) + file_descriptor_proto=(serialized_proto_list) + ), + ) class BaseReflectionServicer(_reflection_pb2_grpc.ServerReflectionServicer): @@ -79,7 +82,8 @@ def _file_by_filename(self, filename): def _file_containing_symbol(self, fully_qualified_name): try: descriptor = self._pool.FindFileContainingSymbol( - fully_qualified_name) + fully_qualified_name + ) except KeyError: return _not_found_error() else: @@ -88,11 +92,14 @@ def _file_containing_symbol(self, fully_qualified_name): def _file_containing_extension(self, containing_type, extension_number): try: message_descriptor = self._pool.FindMessageTypeByName( - containing_type) + containing_type + ) extension_descriptor = self._pool.FindExtensionByNumber( - message_descriptor, extension_number) + message_descriptor, extension_number + ) descriptor = self._pool.FindFileContainingSymbol( - extension_descriptor.full_name) + extension_descriptor.full_name + ) except KeyError: return _not_found_error() else: @@ -101,25 +108,35 @@ def _file_containing_extension(self, containing_type, extension_number): def _all_extension_numbers_of_type(self, containing_type): try: message_descriptor = self._pool.FindMessageTypeByName( - containing_type) + containing_type + ) extension_numbers = tuple( - sorted(extension.number for extension in - self._pool.FindAllExtensions(message_descriptor))) + sorted( + extension.number + for extension in self._pool.FindAllExtensions( + message_descriptor + ) + ) + ) except KeyError: return _not_found_error() else: return _reflection_pb2.ServerReflectionResponse( - all_extension_numbers_response=_reflection_pb2. 
- ExtensionNumberResponse( + all_extension_numbers_response=_reflection_pb2.ExtensionNumberResponse( base_type_name=message_descriptor.full_name, - extension_number=extension_numbers)) + extension_number=extension_numbers, + ) + ) def _list_services(self): return _reflection_pb2.ServerReflectionResponse( - list_services_response=_reflection_pb2.ListServiceResponse(service=[ - _reflection_pb2.ServiceResponse(name=service_name) - for service_name in self._service_names - ])) + list_services_response=_reflection_pb2.ListServiceResponse( + service=[ + _reflection_pb2.ServiceResponse(name=service_name) + for service_name in self._service_names + ] + ) + ) -__all__ = ['BaseReflectionServicer'] +__all__ = ["BaseReflectionServicer"] diff --git a/src/python/grpcio_reflection/grpc_reflection/v1alpha/proto_reflection_descriptor_database.py b/src/python/grpcio_reflection/grpc_reflection/v1alpha/proto_reflection_descriptor_database.py index 685ad95a82a3f..876461b2a1748 100644 --- a/src/python/grpcio_reflection/grpc_reflection/v1alpha/proto_reflection_descriptor_database.py +++ b/src/python/grpcio_reflection/grpc_reflection/v1alpha/proto_reflection_descriptor_database.py @@ -152,17 +152,19 @@ def FindAllExtensionNumbers(self, extendee_name: str) -> Iterable[int]: if extendee_name in self._cached_extension_numbers: return self._cached_extension_numbers[extendee_name] request = ServerReflectionRequest( - all_extension_numbers_of_type=extendee_name) + all_extension_numbers_of_type=extendee_name + ) response = self._do_one_request(request, key=extendee_name) all_extension_numbers: ExtensionNumberResponse = ( - response.all_extension_numbers_response) + response.all_extension_numbers_response + ) numbers = list(all_extension_numbers.extension_number) self._cached_extension_numbers[extendee_name] = numbers return numbers def FindFileContainingExtension( - self, extendee_name: str, - extension_number: int) -> FileDescriptorProto: + self, extendee_name: str, extension_number: int + ) -> FileDescriptorProto: """ Find the file which defines an extension for the given message type and field number. 
@@ -182,41 +184,49 @@ def FindFileContainingExtension( """ try: - return super().FindFileContainingExtension(extendee_name, - extension_number) + return super().FindFileContainingExtension( + extendee_name, extension_number + ) except KeyError: pass request = ServerReflectionRequest( file_containing_extension=ExtensionRequest( - containing_type=extendee_name, - extension_number=extension_number)) - response = self._do_one_request(request, - key=(extendee_name, extension_number)) + containing_type=extendee_name, extension_number=extension_number + ) + ) + response = self._do_one_request( + request, key=(extendee_name, extension_number) + ) file_desc = response.file_descriptor_response self._add_file_from_response(file_desc) - return super().FindFileContainingExtension(extendee_name, - extension_number) + return super().FindFileContainingExtension( + extendee_name, extension_number + ) - def _do_one_request(self, request: ServerReflectionRequest, - key: Any) -> ServerReflectionResponse: + def _do_one_request( + self, request: ServerReflectionRequest, key: Any + ) -> ServerReflectionResponse: response = self._stub.ServerReflectionInfo(iter([request])) res = next(response) if res.WhichOneof("message_response") == "error_response": # Only NOT_FOUND errors are expected at this layer error_code = res.error_response.error_code - assert (error_code == grpc.StatusCode.NOT_FOUND.value[0] - ), "unexpected error response: " + repr(res.error_response) + assert ( + error_code == grpc.StatusCode.NOT_FOUND.value[0] + ), "unexpected error response: " + repr(res.error_response) raise KeyError(key) return res def _add_file_from_response( - self, file_descriptor: FileDescriptorResponse) -> None: + self, file_descriptor: FileDescriptorResponse + ) -> None: protos: List[bytes] = file_descriptor.file_descriptor_proto for proto in protos: desc = FileDescriptorProto() desc.ParseFromString(proto) if desc.name not in self._known_files: - self._logger.info("Loading descriptors from file: %s", - desc.name) + self._logger.info( + "Loading descriptors from file: %s", desc.name + ) self._known_files.add(desc.name) self.Add(desc) diff --git a/src/python/grpcio_reflection/grpc_reflection/v1alpha/reflection.py b/src/python/grpcio_reflection/grpc_reflection/v1alpha/reflection.py index 3f0eb982b9a25..1c1807f49aa2c 100644 --- a/src/python/grpcio_reflection/grpc_reflection/v1alpha/reflection.py +++ b/src/python/grpcio_reflection/grpc_reflection/v1alpha/reflection.py @@ -21,7 +21,8 @@ from grpc_reflection.v1alpha._base import BaseReflectionServicer SERVICE_NAME = _reflection_pb2.DESCRIPTOR.services_by_name[ - 'ServerReflection'].full_name + "ServerReflection" +].full_name class ReflectionServicer(BaseReflectionServicer): @@ -30,27 +31,32 @@ class ReflectionServicer(BaseReflectionServicer): def ServerReflectionInfo(self, request_iterator, context): # pylint: disable=unused-argument for request in request_iterator: - if request.HasField('file_by_filename'): + if request.HasField("file_by_filename"): yield self._file_by_filename(request.file_by_filename) - elif request.HasField('file_containing_symbol'): + elif request.HasField("file_containing_symbol"): yield self._file_containing_symbol( - request.file_containing_symbol) - elif request.HasField('file_containing_extension'): + request.file_containing_symbol + ) + elif request.HasField("file_containing_extension"): yield self._file_containing_extension( request.file_containing_extension.containing_type, - request.file_containing_extension.extension_number) - elif 
request.HasField('all_extension_numbers_of_type'): + request.file_containing_extension.extension_number, + ) + elif request.HasField("all_extension_numbers_of_type"): yield self._all_extension_numbers_of_type( - request.all_extension_numbers_of_type) - elif request.HasField('list_services'): + request.all_extension_numbers_of_type + ) + elif request.HasField("list_services"): yield self._list_services() else: yield _reflection_pb2.ServerReflectionResponse( error_response=_reflection_pb2.ErrorResponse( error_code=grpc.StatusCode.INVALID_ARGUMENT.value[0], - error_message=grpc.StatusCode.INVALID_ARGUMENT.value[1]. - encode(), - )) + error_message=grpc.StatusCode.INVALID_ARGUMENT.value[ + 1 + ].encode(), + ) + ) _enable_server_reflection_doc = """Enables server reflection on a server. @@ -63,17 +69,21 @@ def ServerReflectionInfo(self, request_iterator, context): if sys.version_info[0] >= 3 and sys.version_info[1] >= 6: # Exposes AsyncReflectionServicer as public API. - from grpc.experimental import aio as grpc_aio # pylint: disable=ungrouped-imports + # pylint: disable=ungrouped-imports + from grpc.experimental import aio as grpc_aio + # pylint: enable=ungrouped-imports from . import _async as aio def enable_server_reflection(service_names, server, pool=None): if isinstance(server, grpc_aio.Server): _reflection_pb2_grpc.add_ServerReflectionServicer_to_server( - aio.ReflectionServicer(service_names, pool=pool), server) + aio.ReflectionServicer(service_names, pool=pool), server + ) else: _reflection_pb2_grpc.add_ServerReflectionServicer_to_server( - ReflectionServicer(service_names, pool=pool), server) + ReflectionServicer(service_names, pool=pool), server + ) enable_server_reflection.__doc__ = _enable_server_reflection_doc @@ -87,7 +97,8 @@ def enable_server_reflection(service_names, server, pool=None): def enable_server_reflection(service_names, server, pool=None): _reflection_pb2_grpc.add_ServerReflectionServicer_to_server( - ReflectionServicer(service_names, pool=pool), server) + ReflectionServicer(service_names, pool=pool), server + ) enable_server_reflection.__doc__ = _enable_server_reflection_doc diff --git a/src/python/grpcio_reflection/reflection_commands.py b/src/python/grpcio_reflection/reflection_commands.py index 311ca4c4dbae2..07ce59bfca81e 100644 --- a/src/python/grpcio_reflection/reflection_commands.py +++ b/src/python/grpcio_reflection/reflection_commands.py @@ -20,15 +20,16 @@ ROOT_DIR = os.path.abspath(os.path.dirname(os.path.abspath(__file__))) REFLECTION_PROTO = os.path.join( - ROOT_DIR, '../../proto/grpc/reflection/v1alpha/reflection.proto') -LICENSE = os.path.join(ROOT_DIR, '../../../LICENSE') + ROOT_DIR, "../../proto/grpc/reflection/v1alpha/reflection.proto" +) +LICENSE = os.path.join(ROOT_DIR, "../../../LICENSE") class Preprocess(setuptools.Command): """Command to copy proto modules from grpc/src/proto and LICENSE from the root directory""" - description = '' + description = "" user_options = [] def initialize_options(self): @@ -41,16 +42,18 @@ def run(self): if os.path.isfile(REFLECTION_PROTO): shutil.copyfile( REFLECTION_PROTO, - os.path.join(ROOT_DIR, - 'grpc_reflection/v1alpha/reflection.proto')) + os.path.join( + ROOT_DIR, "grpc_reflection/v1alpha/reflection.proto" + ), + ) if os.path.isfile(LICENSE): - shutil.copyfile(LICENSE, os.path.join(ROOT_DIR, 'LICENSE')) + shutil.copyfile(LICENSE, os.path.join(ROOT_DIR, "LICENSE")) class BuildPackageProtos(setuptools.Command): """Command to generate project *_pb2.py modules from proto files.""" - description = 'build grpc 
protobuf modules' + description = "build grpc protobuf modules" user_options = [] def initialize_options(self): @@ -65,4 +68,5 @@ def run(self): # to `self.distribution.package_dir` (and get a key error if it's not # there). from grpc_tools import command - command.build_package_protos(self.distribution.package_dir['']) + + command.build_package_protos(self.distribution.package_dir[""]) diff --git a/src/python/grpcio_reflection/setup.py b/src/python/grpcio_reflection/setup.py index bd3f53821f944..a92e11f807d28 100644 --- a/src/python/grpcio_reflection/setup.py +++ b/src/python/grpcio_reflection/setup.py @@ -19,7 +19,7 @@ import setuptools _PACKAGE_PATH = os.path.realpath(os.path.dirname(__file__)) -_README_PATH = os.path.join(_PACKAGE_PATH, 'README.rst') +_README_PATH = os.path.join(_PACKAGE_PATH, "README.rst") # Ensure we're in the proper directory whether or not we're being used by pip. os.chdir(os.path.dirname(os.path.abspath(__file__))) @@ -31,7 +31,7 @@ class _NoOpCommand(setuptools.Command): """No-op command.""" - description = '' + description = "" user_options = [] def initialize_options(self): @@ -45,60 +45,63 @@ def run(self): CLASSIFIERS = [ - 'Development Status :: 5 - Production/Stable', - 'Programming Language :: Python', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.4', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', - 'License :: OSI Approved :: Apache Software License', + "Development Status :: 5 - Production/Stable", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.4", + "Programming Language :: Python :: 3.5", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "License :: OSI Approved :: Apache Software License", ] PACKAGE_DIRECTORIES = { - '': '.', + "": ".", } INSTALL_REQUIRES = ( - 'protobuf>=4.21.6', - 'grpcio>={version}'.format(version=grpc_version.VERSION), + "protobuf>=4.21.6", + "grpcio>={version}".format(version=grpc_version.VERSION), ) try: import reflection_commands as _reflection_commands # we are in the build environment, otherwise the above import fails - SETUP_REQUIRES = ('grpcio-tools=={version}'.format( - version=grpc_version.VERSION),) + SETUP_REQUIRES = ( + "grpcio-tools=={version}".format(version=grpc_version.VERSION), + ) COMMAND_CLASS = { # Run preprocess from the repository *before* doing any packaging! 
- 'preprocess': _reflection_commands.Preprocess, - 'build_package_protos': _reflection_commands.BuildPackageProtos, + "preprocess": _reflection_commands.Preprocess, + "build_package_protos": _reflection_commands.BuildPackageProtos, } except ImportError: SETUP_REQUIRES = () COMMAND_CLASS = { # wire up commands to no-op not to break the external dependencies - 'preprocess': _NoOpCommand, - 'build_package_protos': _NoOpCommand, + "preprocess": _NoOpCommand, + "build_package_protos": _NoOpCommand, } -setuptools.setup(name='grpcio-reflection', - version=grpc_version.VERSION, - license='Apache License 2.0', - description='Standard Protobuf Reflection Service for gRPC', - long_description=open(_README_PATH, 'r').read(), - author='The gRPC Authors', - author_email='grpc-io@googlegroups.com', - classifiers=CLASSIFIERS, - url='https://grpc.io', - package_dir=PACKAGE_DIRECTORIES, - packages=setuptools.find_packages('.'), - python_requires='>=3.6', - install_requires=INSTALL_REQUIRES, - setup_requires=SETUP_REQUIRES, - cmdclass=COMMAND_CLASS) +setuptools.setup( + name="grpcio-reflection", + version=grpc_version.VERSION, + license="Apache License 2.0", + description="Standard Protobuf Reflection Service for gRPC", + long_description=open(_README_PATH, "r").read(), + author="The gRPC Authors", + author_email="grpc-io@googlegroups.com", + classifiers=CLASSIFIERS, + url="https://grpc.io", + package_dir=PACKAGE_DIRECTORIES, + packages=setuptools.find_packages("."), + python_requires=">=3.6", + install_requires=INSTALL_REQUIRES, + setup_requires=SETUP_REQUIRES, + cmdclass=COMMAND_CLASS, +) diff --git a/src/python/grpcio_status/grpc_status/_async.py b/src/python/grpcio_status/grpc_status/_async.py index bbd3be8971a92..9f58c8a81f330 100644 --- a/src/python/grpcio_status/grpc_status/_async.py +++ b/src/python/grpcio_status/grpc_status/_async.py @@ -41,16 +41,18 @@ async def from_call(call: aio.Call): rich_status = status_pb2.Status.FromString(value) if code.value[0] != rich_status.code: raise ValueError( - 'Code in Status proto (%s) doesn\'t match status code (%s)' - % (code_to_grpc_status_code(rich_status.code), code)) + "Code in Status proto (%s) doesn't match status code (%s)" + % (code_to_grpc_status_code(rich_status.code), code) + ) if details != rich_status.message: raise ValueError( - 'Message in Status proto (%s) doesn\'t match status details (%s)' - % (rich_status.message, details)) + "Message in Status proto (%s) doesn't match status details" + " (%s)" % (rich_status.message, details) + ) return rich_status return None __all__ = [ - 'from_call', + "from_call", ] diff --git a/src/python/grpcio_status/grpc_status/_common.py b/src/python/grpcio_status/grpc_status/_common.py index 4bec0ba1372ac..66677d849d647 100644 --- a/src/python/grpcio_status/grpc_status/_common.py +++ b/src/python/grpcio_status/grpc_status/_common.py @@ -17,11 +17,11 @@ _CODE_TO_GRPC_CODE_MAPPING = {x.value[0]: x for x in grpc.StatusCode} -GRPC_DETAILS_METADATA_KEY = 'grpc-status-details-bin' +GRPC_DETAILS_METADATA_KEY = "grpc-status-details-bin" def code_to_grpc_status_code(code): try: return _CODE_TO_GRPC_CODE_MAPPING[code] except KeyError: - raise ValueError('Invalid status code %s' % code) + raise ValueError("Invalid status code %s" % code) diff --git a/src/python/grpcio_status/grpc_status/rpc_status.py b/src/python/grpcio_status/grpc_status/rpc_status.py index 432e414f708b4..a3f10a8b18f41 100644 --- a/src/python/grpcio_status/grpc_status/rpc_status.py +++ b/src/python/grpcio_status/grpc_status/rpc_status.py @@ -24,9 +24,9 @@ 
class _Status( - collections.namedtuple('_Status', - ('code', 'details', 'trailing_metadata')), - grpc.Status): + collections.namedtuple("_Status", ("code", "details", "trailing_metadata")), + grpc.Status, +): pass @@ -52,12 +52,14 @@ def from_call(call): rich_status = status_pb2.Status.FromString(value) if call.code().value[0] != rich_status.code: raise ValueError( - 'Code in Status proto (%s) doesn\'t match status code (%s)' - % (code_to_grpc_status_code(rich_status.code), call.code())) + "Code in Status proto (%s) doesn't match status code (%s)" + % (code_to_grpc_status_code(rich_status.code), call.code()) + ) if call.details() != rich_status.message: raise ValueError( - 'Message in Status proto (%s) doesn\'t match status details (%s)' - % (rich_status.message, call.details())) + "Message in Status proto (%s) doesn't match status details" + " (%s)" % (rich_status.message, call.details()) + ) return rich_status return None @@ -74,17 +76,21 @@ def to_status(status): Returns: A grpc.Status instance representing the input google.rpc.status.Status message. """ - return _Status(code=code_to_grpc_status_code(status.code), - details=status.message, - trailing_metadata=((GRPC_DETAILS_METADATA_KEY, - status.SerializeToString()),)) + return _Status( + code=code_to_grpc_status_code(status.code), + details=status.message, + trailing_metadata=( + (GRPC_DETAILS_METADATA_KEY, status.SerializeToString()), + ), + ) __all__ = [ - 'from_call', - 'to_status', + "from_call", + "to_status", ] if sys.version_info[0] >= 3 and sys.version_info[1] >= 6: from . import _async as aio # pylint: disable=unused-import - __all__.append('aio') + + __all__.append("aio") diff --git a/src/python/grpcio_status/setup.py b/src/python/grpcio_status/setup.py index 593891c64e5fb..1fce94d2773d0 100644 --- a/src/python/grpcio_status/setup.py +++ b/src/python/grpcio_status/setup.py @@ -18,7 +18,7 @@ import setuptools _PACKAGE_PATH = os.path.realpath(os.path.dirname(__file__)) -_README_PATH = os.path.join(_PACKAGE_PATH, 'README.rst') +_README_PATH = os.path.join(_PACKAGE_PATH, "README.rst") # Ensure we're in the proper directory whether or not we're being used by pip. 
os.chdir(os.path.dirname(os.path.abspath(__file__))) @@ -30,7 +30,7 @@ class _NoOpCommand(setuptools.Command): """No-op command.""" - description = '' + description = "" user_options = [] def initialize_options(self): @@ -44,28 +44,28 @@ def run(self): CLASSIFIERS = [ - 'Development Status :: 5 - Production/Stable', - 'Programming Language :: Python', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.4', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', - 'License :: OSI Approved :: Apache Software License', + "Development Status :: 5 - Production/Stable", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.4", + "Programming Language :: Python :: 3.5", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "License :: OSI Approved :: Apache Software License", ] PACKAGE_DIRECTORIES = { - '': '.', + "": ".", } INSTALL_REQUIRES = ( - 'protobuf>=4.21.6', - 'grpcio>={version}'.format(version=grpc_version.VERSION), - 'googleapis-common-protos>=1.5.5', + "protobuf>=4.21.6", + "grpcio>={version}".format(version=grpc_version.VERSION), + "googleapis-common-protos>=1.5.5", ) try: @@ -74,27 +74,29 @@ def run(self): # we are in the build environment, otherwise the above import fails COMMAND_CLASS = { # Run preprocess from the repository *before* doing any packaging! 
- 'preprocess': _status_commands.Preprocess, - 'build_package_protos': _NoOpCommand, + "preprocess": _status_commands.Preprocess, + "build_package_protos": _NoOpCommand, } except ImportError: COMMAND_CLASS = { # wire up commands to no-op not to break the external dependencies - 'preprocess': _NoOpCommand, - 'build_package_protos': _NoOpCommand, + "preprocess": _NoOpCommand, + "build_package_protos": _NoOpCommand, } -setuptools.setup(name='grpcio-status', - version=grpc_version.VERSION, - description='Status proto mapping for gRPC', - long_description=open(_README_PATH, 'r').read(), - author='The gRPC Authors', - author_email='grpc-io@googlegroups.com', - url='https://grpc.io', - license='Apache License 2.0', - classifiers=CLASSIFIERS, - package_dir=PACKAGE_DIRECTORIES, - packages=setuptools.find_packages('.'), - python_requires='>=3.6', - install_requires=INSTALL_REQUIRES, - cmdclass=COMMAND_CLASS) +setuptools.setup( + name="grpcio-status", + version=grpc_version.VERSION, + description="Status proto mapping for gRPC", + long_description=open(_README_PATH, "r").read(), + author="The gRPC Authors", + author_email="grpc-io@googlegroups.com", + url="https://grpc.io", + license="Apache License 2.0", + classifiers=CLASSIFIERS, + package_dir=PACKAGE_DIRECTORIES, + packages=setuptools.find_packages("."), + python_requires=">=3.6", + install_requires=INSTALL_REQUIRES, + cmdclass=COMMAND_CLASS, +) diff --git a/src/python/grpcio_status/status_commands.py b/src/python/grpcio_status/status_commands.py index 8306f3c0278b8..25bc4694d9f00 100644 --- a/src/python/grpcio_status/status_commands.py +++ b/src/python/grpcio_status/status_commands.py @@ -20,15 +20,16 @@ ROOT_DIR = os.path.abspath(os.path.dirname(os.path.abspath(__file__))) STATUS_PROTO = os.path.join( - ROOT_DIR, '../../../third_party/googleapis/google/rpc/status.proto') -PACKAGE_STATUS_PROTO_PATH = 'grpc_status/google/rpc' -LICENSE = os.path.join(ROOT_DIR, '../../../LICENSE') + ROOT_DIR, "../../../third_party/googleapis/google/rpc/status.proto" +) +PACKAGE_STATUS_PROTO_PATH = "grpc_status/google/rpc" +LICENSE = os.path.join(ROOT_DIR, "../../../LICENSE") class Preprocess(setuptools.Command): """Command to copy LICENSE from root directory.""" - description = '' + description = "" user_options = [] def initialize_options(self): @@ -43,7 +44,9 @@ def run(self): os.makedirs(PACKAGE_STATUS_PROTO_PATH) shutil.copyfile( STATUS_PROTO, - os.path.join(ROOT_DIR, PACKAGE_STATUS_PROTO_PATH, - 'status.proto')) + os.path.join( + ROOT_DIR, PACKAGE_STATUS_PROTO_PATH, "status.proto" + ), + ) if os.path.isfile(LICENSE): - shutil.copyfile(LICENSE, os.path.join(ROOT_DIR, 'LICENSE')) + shutil.copyfile(LICENSE, os.path.join(ROOT_DIR, "LICENSE")) diff --git a/src/python/grpcio_testing/grpc_testing/__init__.py b/src/python/grpcio_testing/grpc_testing/__init__.py index b67b5deb6f509..4d84007af827f 100644 --- a/src/python/grpcio_testing/grpc_testing/__init__.py +++ b/src/python/grpcio_testing/grpc_testing/__init__.py @@ -493,8 +493,9 @@ class Server(abc.ABC): """A server with which to test a system that services RPCs.""" @abc.abstractmethod - def invoke_unary_unary(self, method_descriptor, invocation_metadata, - request, timeout): + def invoke_unary_unary( + self, method_descriptor, invocation_metadata, request, timeout + ): """Invokes an RPC to be serviced by the system under test. 
Args: @@ -511,8 +512,9 @@ def invoke_unary_unary(self, method_descriptor, invocation_metadata, raise NotImplementedError() @abc.abstractmethod - def invoke_unary_stream(self, method_descriptor, invocation_metadata, - request, timeout): + def invoke_unary_stream( + self, method_descriptor, invocation_metadata, request, timeout + ): """Invokes an RPC to be serviced by the system under test. Args: @@ -529,8 +531,9 @@ def invoke_unary_stream(self, method_descriptor, invocation_metadata, raise NotImplementedError() @abc.abstractmethod - def invoke_stream_unary(self, method_descriptor, invocation_metadata, - timeout): + def invoke_stream_unary( + self, method_descriptor, invocation_metadata, timeout + ): """Invokes an RPC to be serviced by the system under test. Args: @@ -546,8 +549,9 @@ def invoke_stream_unary(self, method_descriptor, invocation_metadata, raise NotImplementedError() @abc.abstractmethod - def invoke_stream_stream(self, method_descriptor, invocation_metadata, - timeout): + def invoke_stream_stream( + self, method_descriptor, invocation_metadata, timeout + ): """Invokes an RPC to be serviced by the system under test. Args: @@ -640,6 +644,7 @@ def strict_real_time(): A Time backed by the "system" (Python interpreter's) time. """ from grpc_testing import _time + return _time.StrictRealTime() @@ -659,6 +664,7 @@ def strict_fake_time(now): A Time that simulates the passage of time. """ from grpc_testing import _time + return _time.StrictFakeTime(now) @@ -675,6 +681,7 @@ def channel(service_descriptors, time): A Channel for use in tests. """ from grpc_testing import _channel + return _channel.testing_channel(service_descriptors, time) @@ -692,4 +699,5 @@ def server_from_dictionary(descriptors_to_servicers, time): A Server for use in tests. """ from grpc_testing import _server + return _server.server_from_dictionary(descriptors_to_servicers, time) diff --git a/src/python/grpcio_testing/grpc_testing/_channel/_channel.py b/src/python/grpcio_testing/grpc_testing/_channel/_channel.py index 0c1941e6bea46..170533f63ea33 100644 --- a/src/python/grpcio_testing/grpc_testing/_channel/_channel.py +++ b/src/python/grpcio_testing/grpc_testing/_channel/_channel.py @@ -21,7 +21,6 @@ # test infrastructure. 
# pylint: disable=unused-argument class TestingChannel(grpc_testing.Channel): - def __init__(self, time, state): self._time = time self._state = state @@ -32,28 +31,24 @@ def subscribe(self, callback, try_to_connect=False): def unsubscribe(self, callback): raise NotImplementedError() - def unary_unary(self, - method, - request_serializer=None, - response_deserializer=None): + def unary_unary( + self, method, request_serializer=None, response_deserializer=None + ): return _multi_callable.UnaryUnary(method, self._state) - def unary_stream(self, - method, - request_serializer=None, - response_deserializer=None): + def unary_stream( + self, method, request_serializer=None, response_deserializer=None + ): return _multi_callable.UnaryStream(method, self._state) - def stream_unary(self, - method, - request_serializer=None, - response_deserializer=None): + def stream_unary( + self, method, request_serializer=None, response_deserializer=None + ): return _multi_callable.StreamUnary(method, self._state) - def stream_stream(self, - method, - request_serializer=None, - response_deserializer=None): + def stream_stream( + self, method, request_serializer=None, response_deserializer=None + ): return _multi_callable.StreamStream(method, self._state) def _close(self): diff --git a/src/python/grpcio_testing/grpc_testing/_channel/_channel_rpc.py b/src/python/grpcio_testing/grpc_testing/_channel/_channel_rpc.py index 54499b3b55fea..6c4531bdeff96 100644 --- a/src/python/grpcio_testing/grpc_testing/_channel/_channel_rpc.py +++ b/src/python/grpcio_testing/grpc_testing/_channel/_channel_rpc.py @@ -16,7 +16,6 @@ class _UnaryUnary(grpc_testing.UnaryUnaryChannelRpc): - def __init__(self, rpc_state): self._rpc_state = rpc_state @@ -27,12 +26,12 @@ def cancelled(self): self._rpc_state.cancelled() def terminate(self, response, trailing_metadata, code, details): - self._rpc_state.terminate_with_response(response, trailing_metadata, - code, details) + self._rpc_state.terminate_with_response( + response, trailing_metadata, code, details + ) class _UnaryStream(grpc_testing.UnaryStreamChannelRpc): - def __init__(self, rpc_state): self._rpc_state = rpc_state @@ -50,7 +49,6 @@ def terminate(self, trailing_metadata, code, details): class _StreamUnary(grpc_testing.StreamUnaryChannelRpc): - def __init__(self, rpc_state): self._rpc_state = rpc_state @@ -67,12 +65,12 @@ def cancelled(self): self._rpc_state.cancelled() def terminate(self, response, trailing_metadata, code, details): - self._rpc_state.terminate_with_response(response, trailing_metadata, - code, details) + self._rpc_state.terminate_with_response( + response, trailing_metadata, code, details + ) class _StreamStream(grpc_testing.StreamStreamChannelRpc): - def __init__(self, rpc_state): self._rpc_state = rpc_state @@ -97,15 +95,19 @@ def terminate(self, trailing_metadata, code, details): def unary_unary(channel_state, method_descriptor): rpc_state = channel_state.take_rpc_state(method_descriptor) - invocation_metadata, request = ( - rpc_state.take_invocation_metadata_and_request()) + ( + invocation_metadata, + request, + ) = rpc_state.take_invocation_metadata_and_request() return invocation_metadata, request, _UnaryUnary(rpc_state) def unary_stream(channel_state, method_descriptor): rpc_state = channel_state.take_rpc_state(method_descriptor) - invocation_metadata, request = ( - rpc_state.take_invocation_metadata_and_request()) + ( + invocation_metadata, + request, + ) = rpc_state.take_invocation_metadata_and_request() return invocation_metadata, request, 
_UnaryStream(rpc_state) diff --git a/src/python/grpcio_testing/grpc_testing/_channel/_channel_state.py b/src/python/grpcio_testing/grpc_testing/_channel/_channel_state.py index 779d59e59ad3a..91a095dd79170 100644 --- a/src/python/grpcio_testing/grpc_testing/_channel/_channel_state.py +++ b/src/python/grpcio_testing/grpc_testing/_channel/_channel_state.py @@ -20,24 +20,31 @@ class State(_common.ChannelHandler): - def __init__(self): self._condition = threading.Condition() self._rpc_states = collections.defaultdict(list) - def invoke_rpc(self, method_full_rpc_name, invocation_metadata, requests, - requests_closed, timeout): - rpc_state = _rpc_state.State(invocation_metadata, requests, - requests_closed) + def invoke_rpc( + self, + method_full_rpc_name, + invocation_metadata, + requests, + requests_closed, + timeout, + ): + rpc_state = _rpc_state.State( + invocation_metadata, requests, requests_closed + ) with self._condition: self._rpc_states[method_full_rpc_name].append(rpc_state) self._condition.notify_all() return rpc_state def take_rpc_state(self, method_descriptor): - method_full_rpc_name = '/{}/{}'.format( + method_full_rpc_name = "/{}/{}".format( method_descriptor.containing_service.full_name, - method_descriptor.name) + method_descriptor.name, + ) with self._condition: while True: method_rpc_states = self._rpc_states[method_full_rpc_name] diff --git a/src/python/grpcio_testing/grpc_testing/_channel/_invocation.py b/src/python/grpcio_testing/grpc_testing/_channel/_invocation.py index d7205ca579396..b1b7a6645915a 100644 --- a/src/python/grpcio_testing/grpc_testing/_channel/_invocation.py +++ b/src/python/grpcio_testing/grpc_testing/_channel/_invocation.py @@ -23,7 +23,7 @@ def _cancel(handler): - return handler.cancel(grpc.StatusCode.CANCELLED, 'Locally cancelled!') + return handler.cancel(grpc.StatusCode.CANCELLED, "Locally cancelled!") def _is_active(handler): @@ -58,7 +58,6 @@ def _details(handler): class _Call(grpc.Call): - def __init__(self, handler): self._handler = handler @@ -88,7 +87,6 @@ def details(self): class _RpcErrorCall(grpc.RpcError, grpc.Call): - def __init__(self, handler): self._handler = handler @@ -128,7 +126,6 @@ def _next(handler): class _HandlerExtras(object): - def __init__(self): self.condition = threading.Condition() self.unary_response = _NOT_YET_OBSERVED @@ -137,7 +134,7 @@ def __init__(self): def _with_extras_cancel(handler, extras): with extras.condition: - if handler.cancel(grpc.StatusCode.CANCELLED, 'Locally cancelled!'): + if handler.cancel(grpc.StatusCode.CANCELLED, "Locally cancelled!"): extras.cancelled = True return True else: @@ -171,11 +168,11 @@ def _with_extras_unary_response(handler, extras): def _exception(unused_handler): - raise NotImplementedError('TODO!') + raise NotImplementedError("TODO!") def _traceback(unused_handler): - raise NotImplementedError('TODO!') + raise NotImplementedError("TODO!") def _add_done_callback(handler, callback, future): @@ -185,7 +182,6 @@ def _add_done_callback(handler, callback, future): class _FutureCall(grpc.Future, grpc.Call): - def __init__(self, handler, extras): self._handler = handler self._extras = extras @@ -237,7 +233,6 @@ def details(self): def consume_requests(request_iterator, handler): - def _consume(): while True: try: @@ -249,7 +244,7 @@ def _consume(): handler.close_requests() break except Exception: # pylint: disable=broad-except - details = 'Exception iterating requests!' + details = "Exception iterating requests!" 
_LOGGER.exception(details) handler.cancel(grpc.StatusCode.UNKNOWN, details) @@ -286,7 +281,6 @@ def future_call(handler): class ResponseIteratorCall(grpc.Call): - def __init__(self, handler): self._handler = handler diff --git a/src/python/grpcio_testing/grpc_testing/_channel/_multi_callable.py b/src/python/grpcio_testing/grpc_testing/_channel/_multi_callable.py index 2b2f5761f5689..ee8f8db9a830c 100644 --- a/src/python/grpcio_testing/grpc_testing/_channel/_multi_callable.py +++ b/src/python/grpcio_testing/grpc_testing/_channel/_multi_callable.py @@ -20,97 +20,117 @@ # All per-call credentials parameters are unused by this test infrastructure. # pylint: disable=unused-argument class UnaryUnary(grpc.UnaryUnaryMultiCallable): - def __init__(self, method_full_rpc_name, channel_handler): self._method_full_rpc_name = method_full_rpc_name self._channel_handler = channel_handler def __call__(self, request, timeout=None, metadata=None, credentials=None): rpc_handler = self._channel_handler.invoke_rpc( - self._method_full_rpc_name, _common.fuss_with_metadata(metadata), - [request], True, timeout) + self._method_full_rpc_name, + _common.fuss_with_metadata(metadata), + [request], + True, + timeout, + ) return _invocation.blocking_unary_response(rpc_handler) def with_call(self, request, timeout=None, metadata=None, credentials=None): rpc_handler = self._channel_handler.invoke_rpc( - self._method_full_rpc_name, _common.fuss_with_metadata(metadata), - [request], True, timeout) + self._method_full_rpc_name, + _common.fuss_with_metadata(metadata), + [request], + True, + timeout, + ) return _invocation.blocking_unary_response_with_call(rpc_handler) def future(self, request, timeout=None, metadata=None, credentials=None): rpc_handler = self._channel_handler.invoke_rpc( - self._method_full_rpc_name, _common.fuss_with_metadata(metadata), - [request], True, timeout) + self._method_full_rpc_name, + _common.fuss_with_metadata(metadata), + [request], + True, + timeout, + ) return _invocation.future_call(rpc_handler) class UnaryStream(grpc.StreamStreamMultiCallable): - def __init__(self, method_full_rpc_name, channel_handler): self._method_full_rpc_name = method_full_rpc_name self._channel_handler = channel_handler def __call__(self, request, timeout=None, metadata=None, credentials=None): rpc_handler = self._channel_handler.invoke_rpc( - self._method_full_rpc_name, _common.fuss_with_metadata(metadata), - [request], True, timeout) + self._method_full_rpc_name, + _common.fuss_with_metadata(metadata), + [request], + True, + timeout, + ) return _invocation.ResponseIteratorCall(rpc_handler) class StreamUnary(grpc.StreamUnaryMultiCallable): - def __init__(self, method_full_rpc_name, channel_handler): self._method_full_rpc_name = method_full_rpc_name self._channel_handler = channel_handler - def __call__(self, - request_iterator, - timeout=None, - metadata=None, - credentials=None): + def __call__( + self, request_iterator, timeout=None, metadata=None, credentials=None + ): rpc_handler = self._channel_handler.invoke_rpc( - self._method_full_rpc_name, _common.fuss_with_metadata(metadata), - [], False, timeout) + self._method_full_rpc_name, + _common.fuss_with_metadata(metadata), + [], + False, + timeout, + ) _invocation.consume_requests(request_iterator, rpc_handler) return _invocation.blocking_unary_response(rpc_handler) - def with_call(self, - request_iterator, - timeout=None, - metadata=None, - credentials=None): + def with_call( + self, request_iterator, timeout=None, metadata=None, credentials=None + ): rpc_handler 
= self._channel_handler.invoke_rpc( - self._method_full_rpc_name, _common.fuss_with_metadata(metadata), - [], False, timeout) + self._method_full_rpc_name, + _common.fuss_with_metadata(metadata), + [], + False, + timeout, + ) _invocation.consume_requests(request_iterator, rpc_handler) return _invocation.blocking_unary_response_with_call(rpc_handler) - def future(self, - request_iterator, - timeout=None, - metadata=None, - credentials=None): + def future( + self, request_iterator, timeout=None, metadata=None, credentials=None + ): rpc_handler = self._channel_handler.invoke_rpc( - self._method_full_rpc_name, _common.fuss_with_metadata(metadata), - [], False, timeout) + self._method_full_rpc_name, + _common.fuss_with_metadata(metadata), + [], + False, + timeout, + ) _invocation.consume_requests(request_iterator, rpc_handler) return _invocation.future_call(rpc_handler) class StreamStream(grpc.StreamStreamMultiCallable): - def __init__(self, method_full_rpc_name, channel_handler): self._method_full_rpc_name = method_full_rpc_name self._channel_handler = channel_handler - def __call__(self, - request_iterator, - timeout=None, - metadata=None, - credentials=None): + def __call__( + self, request_iterator, timeout=None, metadata=None, credentials=None + ): rpc_handler = self._channel_handler.invoke_rpc( - self._method_full_rpc_name, _common.fuss_with_metadata(metadata), - [], False, timeout) + self._method_full_rpc_name, + _common.fuss_with_metadata(metadata), + [], + False, + timeout, + ) _invocation.consume_requests(request_iterator, rpc_handler) return _invocation.ResponseIteratorCall(rpc_handler) diff --git a/src/python/grpcio_testing/grpc_testing/_channel/_rpc_state.py b/src/python/grpcio_testing/grpc_testing/_channel/_rpc_state.py index a548ef0f12e6c..6bbe40bcfc920 100644 --- a/src/python/grpcio_testing/grpc_testing/_channel/_rpc_state.py +++ b/src/python/grpcio_testing/grpc_testing/_channel/_rpc_state.py @@ -19,7 +19,6 @@ class State(_common.ChannelRpcHandler): - def __init__(self, invocation_metadata, requests, requests_closed): self._condition = threading.Condition() self._invocation_metadata = invocation_metadata @@ -63,23 +62,28 @@ def take_response(self): if self._code is grpc.StatusCode.OK: if self._responses: response = self._responses.pop(0) - return _common.ChannelRpcRead(response, None, None, - None) + return _common.ChannelRpcRead( + response, None, None, None + ) else: - return _common.ChannelRpcRead(None, - self._trailing_metadata, - grpc.StatusCode.OK, - self._details) + return _common.ChannelRpcRead( + None, + self._trailing_metadata, + grpc.StatusCode.OK, + self._details, + ) elif self._code is None: if self._responses: response = self._responses.pop(0) - return _common.ChannelRpcRead(response, None, None, - None) + return _common.ChannelRpcRead( + response, None, None, None + ) else: self._condition.wait() else: - return _common.ChannelRpcRead(None, self._trailing_metadata, - self._code, self._details) + return _common.ChannelRpcRead( + None, self._trailing_metadata, self._code, self._details + ) def termination(self): with self._condition: @@ -105,7 +109,7 @@ def cancel(self, code, details): def take_invocation_metadata(self): with self._condition: if self._invocation_metadata is None: - raise ValueError('Expected invocation metadata!') + raise ValueError("Expected invocation metadata!") else: invocation_metadata = self._invocation_metadata self._invocation_metadata = None @@ -114,9 +118,9 @@ def take_invocation_metadata(self): def 
take_invocation_metadata_and_request(self): with self._condition: if self._invocation_metadata is None: - raise ValueError('Expected invocation metadata!') + raise ValueError("Expected invocation metadata!") elif not self._requests: - raise ValueError('Expected at least one request!') + raise ValueError("Expected at least one request!") else: invocation_metadata = self._invocation_metadata self._invocation_metadata = None @@ -125,7 +129,8 @@ def take_invocation_metadata_and_request(self): def send_initial_metadata(self, initial_metadata): with self._condition: self._initial_metadata = _common.fuss_with_metadata( - initial_metadata) + initial_metadata + ) self._condition.notify_all() def take_request(self): @@ -150,14 +155,16 @@ def send_response(self, response): self._responses.append(response) self._condition.notify_all() - def terminate_with_response(self, response, trailing_metadata, code, - details): + def terminate_with_response( + self, response, trailing_metadata, code, details + ): with self._condition: if self._initial_metadata is None: self._initial_metadata = _common.FUSSED_EMPTY_METADATA self._responses.append(response) self._trailing_metadata = _common.fuss_with_metadata( - trailing_metadata) + trailing_metadata + ) self._code = code self._details = details self._condition.notify_all() @@ -167,7 +174,8 @@ def terminate(self, trailing_metadata, code, details): if self._initial_metadata is None: self._initial_metadata = _common.FUSSED_EMPTY_METADATA self._trailing_metadata = _common.fuss_with_metadata( - trailing_metadata) + trailing_metadata + ) self._code = code self._details = details self._condition.notify_all() @@ -180,8 +188,9 @@ def cancelled(self): elif self._code is None: self._condition.wait() else: - raise ValueError('Status code unexpectedly {}!'.format( - self._code)) + raise ValueError( + "Status code unexpectedly {}!".format(self._code) + ) def is_active(self): raise NotImplementedError() diff --git a/src/python/grpcio_testing/grpc_testing/_common.py b/src/python/grpcio_testing/grpc_testing/_common.py index 01bf20780d712..6e1e16c93de46 100644 --- a/src/python/grpcio_testing/grpc_testing/_common.py +++ b/src/python/grpcio_testing/grpc_testing/_common.py @@ -18,10 +18,12 @@ def _fuss(tuplified_metadata): - return tuplified_metadata + (( - 'grpc.metadata_added_by_runtime', - 'gRPC is allowed to add metadata in transmission and does so.', - ),) + return tuplified_metadata + ( + ( + "grpc.metadata_added_by_runtime", + "gRPC is allowed to add metadata in transmission and does so.", + ), + ) FUSSED_EMPTY_METADATA = _fuss(()) @@ -38,24 +40,28 @@ def rpc_names(service_descriptors): rpc_names_to_descriptors = {} for service_descriptor in service_descriptors: for method_descriptor in service_descriptor.methods_by_name.values(): - rpc_name = '/{}/{}'.format(service_descriptor.full_name, - method_descriptor.name) + rpc_name = "/{}/{}".format( + service_descriptor.full_name, method_descriptor.name + ) rpc_names_to_descriptors[rpc_name] = method_descriptor return rpc_names_to_descriptors class ChannelRpcRead( - collections.namedtuple('ChannelRpcRead', ( - 'response', - 'trailing_metadata', - 'code', - 'details', - ))): + collections.namedtuple( + "ChannelRpcRead", + ( + "response", + "trailing_metadata", + "code", + "details", + ), + ) +): pass class ChannelRpcHandler(abc.ABC): - @abc.abstractmethod def initial_metadata(self): raise NotImplementedError() @@ -94,19 +100,28 @@ def add_callback(self, callback): class ChannelHandler(abc.ABC): - @abc.abstractmethod - def 
invoke_rpc(self, method_full_rpc_name, invocation_metadata, requests, - requests_closed, timeout): + def invoke_rpc( + self, + method_full_rpc_name, + invocation_metadata, + requests, + requests_closed, + timeout, + ): raise NotImplementedError() class ServerRpcRead( - collections.namedtuple('ServerRpcRead', ( - 'request', - 'requests_closed', - 'terminated', - ))): + collections.namedtuple( + "ServerRpcRead", + ( + "request", + "requests_closed", + "terminated", + ), + ) +): pass @@ -115,7 +130,6 @@ class ServerRpcRead( class ServerRpcHandler(abc.ABC): - @abc.abstractmethod def send_initial_metadata(self, initial_metadata): raise NotImplementedError() @@ -138,23 +152,26 @@ def add_termination_callback(self, callback): class Serverish(abc.ABC): - @abc.abstractmethod - def invoke_unary_unary(self, method_descriptor, handler, - invocation_metadata, request, deadline): + def invoke_unary_unary( + self, method_descriptor, handler, invocation_metadata, request, deadline + ): raise NotImplementedError() @abc.abstractmethod - def invoke_unary_stream(self, method_descriptor, handler, - invocation_metadata, request, deadline): + def invoke_unary_stream( + self, method_descriptor, handler, invocation_metadata, request, deadline + ): raise NotImplementedError() @abc.abstractmethod - def invoke_stream_unary(self, method_descriptor, handler, - invocation_metadata, deadline): + def invoke_stream_unary( + self, method_descriptor, handler, invocation_metadata, deadline + ): raise NotImplementedError() @abc.abstractmethod - def invoke_stream_stream(self, method_descriptor, handler, - invocation_metadata, deadline): + def invoke_stream_stream( + self, method_descriptor, handler, invocation_metadata, deadline + ): raise NotImplementedError() diff --git a/src/python/grpcio_testing/grpc_testing/_server/__init__.py b/src/python/grpcio_testing/grpc_testing/_server/__init__.py index 5f035a91cab4e..73c4eaa5f56b1 100644 --- a/src/python/grpcio_testing/grpc_testing/_server/__init__.py +++ b/src/python/grpcio_testing/grpc_testing/_server/__init__.py @@ -16,5 +16,6 @@ def server_from_dictionary(descriptors_to_servicers, time): - return _server.server_from_descriptor_to_servicers(descriptors_to_servicers, - time) + return _server.server_from_descriptor_to_servicers( + descriptors_to_servicers, time + ) diff --git a/src/python/grpcio_testing/grpc_testing/_server/_handler.py b/src/python/grpcio_testing/grpc_testing/_server/_handler.py index 100d8195f624a..32a785d99f2e4 100644 --- a/src/python/grpcio_testing/grpc_testing/_server/_handler.py +++ b/src/python/grpcio_testing/grpc_testing/_server/_handler.py @@ -22,7 +22,6 @@ class Handler(_common.ServerRpcHandler): - @abc.abstractmethod def initial_metadata(self): raise NotImplementedError() @@ -53,7 +52,6 @@ def stream_response_termination(self): class _Handler(Handler): - def __init__(self, requests_closed): self._condition = threading.Condition() self._requests = [] @@ -121,7 +119,8 @@ def initial_metadata(self): self._condition.wait() else: raise ValueError( - 'No initial metadata despite status code!') + "No initial metadata despite status code!" + ) else: return self._initial_metadata @@ -140,7 +139,7 @@ def take_response(self): elif self._code is None: self._condition.wait() else: - raise ValueError('No more responses!') + raise ValueError("No more responses!") def requests_closed(self): with self._condition: @@ -163,7 +162,7 @@ def unary_response_termination(self): with self._condition: while True: if self._code is _CLIENT_INACTIVE: - raise ValueError('Huh? 
Cancelled but wanting status?') + raise ValueError("Huh? Cancelled but wanting status?") elif self._code is None: self._condition.wait() else: @@ -181,7 +180,7 @@ def stream_response_termination(self): with self._condition: while True: if self._code is _CLIENT_INACTIVE: - raise ValueError('Huh? Cancelled but wanting status?') + raise ValueError("Huh? Cancelled but wanting status?") elif self._code is None: self._condition.wait() else: @@ -194,7 +193,7 @@ def expire(self): self._initial_metadata = _common.FUSSED_EMPTY_METADATA self._trailing_metadata = _common.FUSSED_EMPTY_METADATA self._code = grpc.StatusCode.DEADLINE_EXCEEDED - self._details = 'Took too much time!' + self._details = "Took too much time!" termination_callbacks = self._termination_callbacks self._termination_callbacks = None self._condition.notify_all() diff --git a/src/python/grpcio_testing/grpc_testing/_server/_rpc.py b/src/python/grpcio_testing/grpc_testing/_server/_rpc.py index 736b714dc6d0e..cadf091de479a 100644 --- a/src/python/grpcio_testing/grpc_testing/_server/_rpc.py +++ b/src/python/grpcio_testing/grpc_testing/_server/_rpc.py @@ -23,7 +23,6 @@ class Rpc(object): - def __init__(self, handler, invocation_metadata): self._condition = threading.Condition() self._handler = handler @@ -50,7 +49,7 @@ def call_back(): try: callback() except Exception: # pylint: disable=broad-except - _LOGGER.exception('Exception calling server-side callback!') + _LOGGER.exception("Exception calling server-side callback!") callback_calling_thread = threading.Thread(target=call_back) callback_calling_thread.start() @@ -71,7 +70,7 @@ def _complete(self): code = grpc.StatusCode.OK else: code = self._pending_code - details = '' if self._pending_details is None else self._pending_details + details = "" if self._pending_details is None else self._pending_details self._terminate(trailing_metadata, code, details) def _abort(self, code, details): @@ -83,16 +82,19 @@ def add_rpc_error(self, rpc_error): def application_cancel(self): with self._condition: - self._abort(grpc.StatusCode.CANCELLED, - 'Cancelled by server-side application!') + self._abort( + grpc.StatusCode.CANCELLED, + "Cancelled by server-side application!", + ) def application_exception_abort(self, exception): with self._condition: if exception not in self._rpc_errors: - _LOGGER.exception('Exception calling application!') + _LOGGER.exception("Exception calling application!") self._abort( grpc.StatusCode.UNKNOWN, - 'Exception calling application: {}'.format(exception)) + "Exception calling application: {}".format(exception), + ) def extrinsic_abort(self): with self._condition: diff --git a/src/python/grpcio_testing/grpc_testing/_server/_server.py b/src/python/grpcio_testing/grpc_testing/_server/_server.py index 6d256d848f7d1..170a4b3c0d2f8 100644 --- a/src/python/grpcio_testing/grpc_testing/_server/_server.py +++ b/src/python/grpcio_testing/grpc_testing/_server/_server.py @@ -29,7 +29,6 @@ def _implementation(descriptors_to_servicers, method_descriptor): def _unary_unary_service(request): - def service(implementation, rpc, servicer_context): _service.unary_unary(implementation, rpc, request, servicer_context) @@ -37,7 +36,6 @@ def service(implementation, rpc, servicer_context): def _unary_stream_service(request): - def service(implementation, rpc, servicer_context): _service.unary_stream(implementation, rpc, request, servicer_context) @@ -45,7 +43,6 @@ def service(implementation, rpc, servicer_context): def _stream_unary_service(handler): - def service(implementation, rpc, 
servicer_context): _service.stream_unary(implementation, rpc, handler, servicer_context) @@ -53,7 +50,6 @@ def service(implementation, rpc, servicer_context): def _stream_stream_service(handler): - def service(implementation, rpc, servicer_context): _service.stream_stream(implementation, rpc, handler, servicer_context) @@ -61,46 +57,79 @@ def service(implementation, rpc, servicer_context): class _Serverish(_common.Serverish): - def __init__(self, descriptors_to_servicers, time): self._descriptors_to_servicers = descriptors_to_servicers self._time = time - def _invoke(self, service_behavior, method_descriptor, handler, - invocation_metadata, deadline): - implementation = _implementation(self._descriptors_to_servicers, - method_descriptor) + def _invoke( + self, + service_behavior, + method_descriptor, + handler, + invocation_metadata, + deadline, + ): + implementation = _implementation( + self._descriptors_to_servicers, method_descriptor + ) rpc = _rpc.Rpc(handler, invocation_metadata) if handler.add_termination_callback(rpc.extrinsic_abort): servicer_context = _servicer_context.ServicerContext( - rpc, self._time, deadline) - service_thread = threading.Thread(target=service_behavior, - args=( - implementation, - rpc, - servicer_context, - )) + rpc, self._time, deadline + ) + service_thread = threading.Thread( + target=service_behavior, + args=( + implementation, + rpc, + servicer_context, + ), + ) service_thread.start() - def invoke_unary_unary(self, method_descriptor, handler, - invocation_metadata, request, deadline): - self._invoke(_unary_unary_service(request), method_descriptor, handler, - invocation_metadata, deadline) - - def invoke_unary_stream(self, method_descriptor, handler, - invocation_metadata, request, deadline): - self._invoke(_unary_stream_service(request), method_descriptor, handler, - invocation_metadata, deadline) - - def invoke_stream_unary(self, method_descriptor, handler, - invocation_metadata, deadline): - self._invoke(_stream_unary_service(handler), method_descriptor, handler, - invocation_metadata, deadline) - - def invoke_stream_stream(self, method_descriptor, handler, - invocation_metadata, deadline): - self._invoke(_stream_stream_service(handler), method_descriptor, - handler, invocation_metadata, deadline) + def invoke_unary_unary( + self, method_descriptor, handler, invocation_metadata, request, deadline + ): + self._invoke( + _unary_unary_service(request), + method_descriptor, + handler, + invocation_metadata, + deadline, + ) + + def invoke_unary_stream( + self, method_descriptor, handler, invocation_metadata, request, deadline + ): + self._invoke( + _unary_stream_service(request), + method_descriptor, + handler, + invocation_metadata, + deadline, + ) + + def invoke_stream_unary( + self, method_descriptor, handler, invocation_metadata, deadline + ): + self._invoke( + _stream_unary_service(handler), + method_descriptor, + handler, + invocation_metadata, + deadline, + ) + + def invoke_stream_stream( + self, method_descriptor, handler, invocation_metadata, deadline + ): + self._invoke( + _stream_stream_service(handler), + method_descriptor, + handler, + invocation_metadata, + deadline, + ) def _deadline_and_handler(requests_closed, time, timeout): @@ -108,45 +137,51 @@ def _deadline_and_handler(requests_closed, time, timeout): return None, _handler.handler_without_deadline(requests_closed) else: deadline = time.time() + timeout - handler = _handler.handler_with_deadline(requests_closed, time, - deadline) + handler = _handler.handler_with_deadline( + 
requests_closed, time, deadline + ) return deadline, handler class _Server(grpc_testing.Server): - def __init__(self, serverish, time): self._serverish = serverish self._time = time - def invoke_unary_unary(self, method_descriptor, invocation_metadata, - request, timeout): + def invoke_unary_unary( + self, method_descriptor, invocation_metadata, request, timeout + ): deadline, handler = _deadline_and_handler(True, self._time, timeout) - self._serverish.invoke_unary_unary(method_descriptor, handler, - invocation_metadata, request, - deadline) + self._serverish.invoke_unary_unary( + method_descriptor, handler, invocation_metadata, request, deadline + ) return _server_rpc.UnaryUnaryServerRpc(handler) - def invoke_unary_stream(self, method_descriptor, invocation_metadata, - request, timeout): + def invoke_unary_stream( + self, method_descriptor, invocation_metadata, request, timeout + ): deadline, handler = _deadline_and_handler(True, self._time, timeout) - self._serverish.invoke_unary_stream(method_descriptor, handler, - invocation_metadata, request, - deadline) + self._serverish.invoke_unary_stream( + method_descriptor, handler, invocation_metadata, request, deadline + ) return _server_rpc.UnaryStreamServerRpc(handler) - def invoke_stream_unary(self, method_descriptor, invocation_metadata, - timeout): + def invoke_stream_unary( + self, method_descriptor, invocation_metadata, timeout + ): deadline, handler = _deadline_and_handler(False, self._time, timeout) - self._serverish.invoke_stream_unary(method_descriptor, handler, - invocation_metadata, deadline) + self._serverish.invoke_stream_unary( + method_descriptor, handler, invocation_metadata, deadline + ) return _server_rpc.StreamUnaryServerRpc(handler) - def invoke_stream_stream(self, method_descriptor, invocation_metadata, - timeout): + def invoke_stream_stream( + self, method_descriptor, invocation_metadata, timeout + ): deadline, handler = _deadline_and_handler(False, self._time, timeout) - self._serverish.invoke_stream_stream(method_descriptor, handler, - invocation_metadata, deadline) + self._serverish.invoke_stream_stream( + method_descriptor, handler, invocation_metadata, deadline + ) return _server_rpc.StreamStreamServerRpc(handler) diff --git a/src/python/grpcio_testing/grpc_testing/_server/_server_rpc.py b/src/python/grpcio_testing/grpc_testing/_server/_server_rpc.py index 30de8ff0e2b91..d068587c0d94f 100644 --- a/src/python/grpcio_testing/grpc_testing/_server/_server_rpc.py +++ b/src/python/grpcio_testing/grpc_testing/_server/_server_rpc.py @@ -16,7 +16,6 @@ class UnaryUnaryServerRpc(grpc_testing.UnaryUnaryServerRpc): - def __init__(self, handler): self._handler = handler @@ -31,7 +30,6 @@ def termination(self): class UnaryStreamServerRpc(grpc_testing.UnaryStreamServerRpc): - def __init__(self, handler): self._handler = handler @@ -49,7 +47,6 @@ def termination(self): class StreamUnaryServerRpc(grpc_testing.StreamUnaryServerRpc): - def __init__(self, handler): self._handler = handler @@ -70,7 +67,6 @@ def termination(self): class StreamStreamServerRpc(grpc_testing.StreamStreamServerRpc): - def __init__(self, handler): self._handler = handler diff --git a/src/python/grpcio_testing/grpc_testing/_server/_service.py b/src/python/grpcio_testing/grpc_testing/_server/_service.py index 661257e275e79..fe936f9cde0dd 100644 --- a/src/python/grpcio_testing/grpc_testing/_server/_service.py +++ b/src/python/grpcio_testing/grpc_testing/_server/_service.py @@ -18,7 +18,6 @@ class _RequestIterator(object): - def __init__(self, rpc, handler): 
self._rpc = rpc self._handler = handler @@ -81,10 +80,12 @@ def unary_stream(implementation, rpc, request, servicer_context): def stream_unary(implementation, rpc, handler, servicer_context): - _unary_response(_RequestIterator(rpc, handler), implementation, rpc, - servicer_context) + _unary_response( + _RequestIterator(rpc, handler), implementation, rpc, servicer_context + ) def stream_stream(implementation, rpc, handler, servicer_context): - _stream_response(_RequestIterator(rpc, handler), implementation, rpc, - servicer_context) + _stream_response( + _RequestIterator(rpc, handler), implementation, rpc, servicer_context + ) diff --git a/src/python/grpcio_testing/grpc_testing/_server/_servicer_context.py b/src/python/grpcio_testing/grpc_testing/_server/_servicer_context.py index c63750f978597..01949c9009a93 100644 --- a/src/python/grpcio_testing/grpc_testing/_server/_servicer_context.py +++ b/src/python/grpcio_testing/grpc_testing/_server/_servicer_context.py @@ -17,7 +17,6 @@ class ServicerContext(grpc.ServicerContext): - def __init__(self, rpc, time, deadline): self._rpc = rpc self._time = time @@ -61,17 +60,20 @@ def set_compression(self): def send_initial_metadata(self, initial_metadata): initial_metadata_sent = self._rpc.send_initial_metadata( - _common.fuss_with_metadata(initial_metadata)) + _common.fuss_with_metadata(initial_metadata) + ) if not initial_metadata_sent: raise ValueError( - 'ServicerContext.send_initial_metadata called too late!') + "ServicerContext.send_initial_metadata called too late!" + ) def disable_next_message_compression(self): raise NotImplementedError() def set_trailing_metadata(self, trailing_metadata): self._rpc.set_trailing_metadata( - _common.fuss_with_metadata(trailing_metadata)) + _common.fuss_with_metadata(trailing_metadata) + ) def abort(self, code, details): with self._rpc._condition: diff --git a/src/python/grpcio_testing/grpc_testing/_time.py b/src/python/grpcio_testing/grpc_testing/_time.py index 9692c34e6f38b..71eb54086869e 100644 --- a/src/python/grpcio_testing/grpc_testing/_time.py +++ b/src/python/grpcio_testing/grpc_testing/_time.py @@ -42,18 +42,21 @@ def _call_in_thread(behaviors): class _State(object): - def __init__(self): self.condition = threading.Condition() self.times_to_behaviors = collections.defaultdict(list) class _Delta( - collections.namedtuple('_Delta', ( - 'mature_behaviors', - 'earliest_mature_time', - 'earliest_immature_time', - ))): + collections.namedtuple( + "_Delta", + ( + "mature_behaviors", + "earliest_mature_time", + "earliest_immature_time", + ), + ) +): pass @@ -66,19 +69,20 @@ def _process(state, now): if earliest_mature_time is None: earliest_mature_time = earliest_time earliest_mature_behaviors = state.times_to_behaviors.pop( - earliest_time) + earliest_time + ) mature_behaviors.extend(earliest_mature_behaviors) else: earliest_immature_time = earliest_time break else: earliest_immature_time = None - return _Delta(mature_behaviors, earliest_mature_time, - earliest_immature_time) + return _Delta( + mature_behaviors, earliest_mature_time, earliest_immature_time + ) class _Future(grpc.Future): - def __init__(self, state, behavior, time): self._state = state self._behavior = behavior @@ -91,7 +95,8 @@ def cancel(self): return True else: behaviors_at_time = self._state.times_to_behaviors.get( - self._time) + self._time + ) if behaviors_at_time is None: return False else: @@ -126,7 +131,6 @@ def add_done_callback(self, fn): class StrictRealTime(grpc_testing.Time): - def __init__(self): self._state = _State() 
self._active = False @@ -153,9 +157,10 @@ def _activity(self): def _ensure_called_through(self, time): with self._state.condition: - while ((self._state.times_to_behaviors and - min(self._state.times_to_behaviors) < time) or - (self._calling is not None and self._calling < time)): + while ( + self._state.times_to_behaviors + and min(self._state.times_to_behaviors) < time + ) or (self._calling is not None and self._calling < time): self._state.condition.wait() def _call_at(self, behavior, time): @@ -189,7 +194,6 @@ def sleep_until(self, time): class StrictFakeTime(grpc_testing.Time): - def __init__(self, time): self._state = _State() self._time = time diff --git a/src/python/grpcio_testing/setup.py b/src/python/grpcio_testing/setup.py index a983d8d43dd82..fad6d799bde47 100644 --- a/src/python/grpcio_testing/setup.py +++ b/src/python/grpcio_testing/setup.py @@ -19,7 +19,7 @@ import setuptools _PACKAGE_PATH = os.path.realpath(os.path.dirname(__file__)) -_README_PATH = os.path.join(_PACKAGE_PATH, 'README.rst') +_README_PATH = os.path.join(_PACKAGE_PATH, "README.rst") # Ensure we're in the proper directory whether or not we're being used by pip. os.chdir(os.path.dirname(os.path.abspath(__file__))) @@ -31,7 +31,7 @@ class _NoOpCommand(setuptools.Command): """No-op command.""" - description = '' + description = "" user_options = [] def initialize_options(self): @@ -45,12 +45,12 @@ def run(self): PACKAGE_DIRECTORIES = { - '': '.', + "": ".", } INSTALL_REQUIRES = ( - 'protobuf>=4.21.6', - 'grpcio>={version}'.format(version=grpc_version.VERSION), + "protobuf>=4.21.6", + "grpcio>={version}".format(version=grpc_version.VERSION), ) try: @@ -59,23 +59,25 @@ def run(self): # we are in the build environment, otherwise the above import fails COMMAND_CLASS = { # Run preprocess from the repository *before* doing any packaging! 
- 'preprocess': _testing_commands.Preprocess, + "preprocess": _testing_commands.Preprocess, } except ImportError: COMMAND_CLASS = { # wire up commands to no-op not to break the external dependencies - 'preprocess': _NoOpCommand, + "preprocess": _NoOpCommand, } -setuptools.setup(name='grpcio-testing', - version=grpc_version.VERSION, - license='Apache License 2.0', - description='Testing utilities for gRPC Python', - long_description=open(_README_PATH, 'r').read(), - author='The gRPC Authors', - author_email='grpc-io@googlegroups.com', - url='https://grpc.io', - package_dir=PACKAGE_DIRECTORIES, - packages=setuptools.find_packages('.'), - install_requires=INSTALL_REQUIRES, - cmdclass=COMMAND_CLASS) +setuptools.setup( + name="grpcio-testing", + version=grpc_version.VERSION, + license="Apache License 2.0", + description="Testing utilities for gRPC Python", + long_description=open(_README_PATH, "r").read(), + author="The gRPC Authors", + author_email="grpc-io@googlegroups.com", + url="https://grpc.io", + package_dir=PACKAGE_DIRECTORIES, + packages=setuptools.find_packages("."), + install_requires=INSTALL_REQUIRES, + cmdclass=COMMAND_CLASS, +) diff --git a/src/python/grpcio_testing/testing_commands.py b/src/python/grpcio_testing/testing_commands.py index fb40d37efb639..b7374814ac6d2 100644 --- a/src/python/grpcio_testing/testing_commands.py +++ b/src/python/grpcio_testing/testing_commands.py @@ -19,13 +19,13 @@ import setuptools ROOT_DIR = os.path.abspath(os.path.dirname(os.path.abspath(__file__))) -LICENSE = os.path.join(ROOT_DIR, '../../../LICENSE') +LICENSE = os.path.join(ROOT_DIR, "../../../LICENSE") class Preprocess(setuptools.Command): """Command to copy LICENSE from root directory.""" - description = '' + description = "" user_options = [] def initialize_options(self): @@ -36,4 +36,4 @@ def finalize_options(self): def run(self): if os.path.isfile(LICENSE): - shutil.copyfile(LICENSE, os.path.join(ROOT_DIR, 'LICENSE')) + shutil.copyfile(LICENSE, os.path.join(ROOT_DIR, "LICENSE")) diff --git a/src/python/grpcio_tests/commands.py b/src/python/grpcio_tests/commands.py index b9336d899db7b..22f458b8c9b62 100644 --- a/src/python/grpcio_tests/commands.py +++ b/src/python/grpcio_tests/commands.py @@ -30,10 +30,10 @@ from setuptools.command import test PYTHON_STEM = os.path.dirname(os.path.abspath(__file__)) -GRPC_STEM = os.path.abspath(PYTHON_STEM + '../../../../') -GRPC_PROTO_STEM = os.path.join(GRPC_STEM, 'src', 'proto') -PROTO_STEM = os.path.join(PYTHON_STEM, 'src', 'proto') -PYTHON_PROTO_TOP_LEVEL = os.path.join(PYTHON_STEM, 'src') +GRPC_STEM = os.path.abspath(PYTHON_STEM + "../../../../") +GRPC_PROTO_STEM = os.path.join(GRPC_STEM, "src", "proto") +PROTO_STEM = os.path.join(PYTHON_STEM, "src", "proto") +PYTHON_PROTO_TOP_LEVEL = os.path.join(PYTHON_STEM, "src") class CommandError(object): @@ -41,8 +41,7 @@ class CommandError(object): class GatherProto(setuptools.Command): - - description = 'gather proto dependencies' + description = "gather proto dependencies" user_options = [] def initialize_options(self): @@ -61,8 +60,8 @@ def run(self): pass shutil.copytree(GRPC_PROTO_STEM, PROTO_STEM) for root, _, _ in os.walk(PYTHON_PROTO_TOP_LEVEL): - path = os.path.join(root, '__init__.py') - open(path, 'a').close() + path = os.path.join(root, "__init__.py") + open(path, "a").close() class BuildPy(build_py.build_py): @@ -70,16 +69,16 @@ class BuildPy(build_py.build_py): def run(self): try: - self.run_command('build_package_protos') + self.run_command("build_package_protos") except CommandError as error: - 
sys.stderr.write('warning: %s\n' % error.message) + sys.stderr.write("warning: %s\n" % error.message) build_py.build_py.run(self) class TestLite(setuptools.Command): """Command to run tests without fetching or building anything.""" - description = 'run tests without fetching or building anything.' + description = "run tests without fetching or building anything." user_options = [] def initialize_options(self): @@ -93,12 +92,13 @@ def run(self): self._add_eggs_to_path() import tests + loader = tests.Loader() - loader.loadTestsFromNames(['tests']) + loader.loadTestsFromNames(["tests"]) runner = tests.Runner(dedicated_threads=True) result = runner.run(loader.suite) if not result.wasSuccessful(): - sys.exit('Test failure') + sys.exit("Test failure") def _add_eggs_to_path(self): """Fetch install and test requirements""" @@ -113,7 +113,7 @@ class TestPy3Only(setuptools.Command): directory. """ - description = 'run tests for py3+ features' + description = "run tests for py3+ features" user_options = [] def initialize_options(self): @@ -125,12 +125,13 @@ def finalize_options(self): def run(self): self._add_eggs_to_path() import tests + loader = tests.Loader() - loader.loadTestsFromNames(['tests_py3_only']) + loader.loadTestsFromNames(["tests_py3_only"]) runner = tests.Runner() result = runner.run(loader.suite) if not result.wasSuccessful(): - sys.exit('Test failure') + sys.exit("Test failure") def _add_eggs_to_path(self): self.distribution.fetch_build_eggs(self.distribution.install_requires) @@ -140,7 +141,7 @@ def _add_eggs_to_path(self): class TestAio(setuptools.Command): """Command to run aio tests without fetching or building anything.""" - description = 'run aio tests without fetching or building anything.' + description = "run aio tests without fetching or building anything." user_options = [] def initialize_options(self): @@ -153,15 +154,16 @@ def run(self): self._add_eggs_to_path() import tests + loader = tests.Loader() - loader.loadTestsFromNames(['tests_aio']) + loader.loadTestsFromNames(["tests_aio"]) # Even without dedicated threads, the framework will somehow spawn a # new thread for tests to run upon. New thread doesn't have event loop # attached by default, so initialization is needed. runner = tests.Runner(dedicated_threads=False) result = runner.run(loader.suite) if not result.wasSuccessful(): - sys.exit('Test failure') + sys.exit("Test failure") def _add_eggs_to_path(self): """Fetch install and test requirements""" @@ -174,67 +176,68 @@ class TestGevent(setuptools.Command): BANNED_TESTS = ( # Fork support is not compatible with gevent - 'fork._fork_interop_test.ForkInteropTest', + "fork._fork_interop_test.ForkInteropTest", # These tests send a lot of RPCs and are really slow on gevent. They will # eventually succeed, but need to dig into performance issues. 
- 'unit._cython._no_messages_server_completion_queue_per_call_test.Test.test_rpcs', - 'unit._cython._no_messages_single_server_completion_queue_test.Test.test_rpcs', - 'unit._compression_test', + "unit._cython._no_messages_server_completion_queue_per_call_test.Test.test_rpcs", + "unit._cython._no_messages_single_server_completion_queue_test.Test.test_rpcs", + "unit._compression_test", # TODO(https://github.com/grpc/grpc/issues/16890) enable this test - 'unit._cython._channel_test.ChannelTest.test_multiple_channels_lonely_connectivity', + "unit._cython._channel_test.ChannelTest.test_multiple_channels_lonely_connectivity", # I have no idea why this doesn't work in gevent, but it shouldn't even be # using the c-core - 'testing._client_test.ClientTest.test_infinite_request_stream_real_time', + "testing._client_test.ClientTest.test_infinite_request_stream_real_time", # TODO(https://github.com/grpc/grpc/issues/15743) enable this test - 'unit._session_cache_test.SSLSessionCacheTest.testSSLSessionCacheLRU', + "unit._session_cache_test.SSLSessionCacheTest.testSSLSessionCacheLRU", # TODO(https://github.com/grpc/grpc/issues/14789) enable this test - 'unit._server_ssl_cert_config_test', + "unit._server_ssl_cert_config_test", # TODO(https://github.com/grpc/grpc/issues/14901) enable this test - 'protoc_plugin._python_plugin_test.PythonPluginTest', - 'protoc_plugin._python_plugin_test.SimpleStubsPluginTest', + "protoc_plugin._python_plugin_test.PythonPluginTest", + "protoc_plugin._python_plugin_test.SimpleStubsPluginTest", # Beta API is unsupported for gevent - 'protoc_plugin.beta_python_plugin_test', - 'unit.beta._beta_features_test', + "protoc_plugin.beta_python_plugin_test", + "unit.beta._beta_features_test", # TODO(https://github.com/grpc/grpc/issues/15411) unpin gevent version # This test will stuck while running higher version of gevent - 'unit._auth_context_test.AuthContextTest.testSessionResumption', + "unit._auth_context_test.AuthContextTest.testSessionResumption", # TODO(https://github.com/grpc/grpc/issues/15411) enable these tests - 'unit._channel_ready_future_test.ChannelReadyFutureTest.test_immediately_connectable_channel_connectivity', + "unit._channel_ready_future_test.ChannelReadyFutureTest.test_immediately_connectable_channel_connectivity", "unit._cython._channel_test.ChannelTest.test_single_channel_lonely_connectivity", - 'unit._exit_test.ExitTest.test_in_flight_unary_unary_call', - 'unit._exit_test.ExitTest.test_in_flight_unary_stream_call', - 'unit._exit_test.ExitTest.test_in_flight_stream_unary_call', - 'unit._exit_test.ExitTest.test_in_flight_stream_stream_call', - 'unit._exit_test.ExitTest.test_in_flight_partial_unary_stream_call', - 'unit._exit_test.ExitTest.test_in_flight_partial_stream_unary_call', - 'unit._exit_test.ExitTest.test_in_flight_partial_stream_stream_call', + "unit._exit_test.ExitTest.test_in_flight_unary_unary_call", + "unit._exit_test.ExitTest.test_in_flight_unary_stream_call", + "unit._exit_test.ExitTest.test_in_flight_stream_unary_call", + "unit._exit_test.ExitTest.test_in_flight_stream_stream_call", + "unit._exit_test.ExitTest.test_in_flight_partial_unary_stream_call", + "unit._exit_test.ExitTest.test_in_flight_partial_stream_unary_call", + "unit._exit_test.ExitTest.test_in_flight_partial_stream_stream_call", # TODO(https://github.com/grpc/grpc/issues/18980): Reenable. 
- 'unit._signal_handling_test.SignalHandlingTest', - 'unit._metadata_flags_test', - 'health_check._health_servicer_test.HealthServicerTest.test_cancelled_watch_removed_from_watch_list', + "unit._signal_handling_test.SignalHandlingTest", + "unit._metadata_flags_test", + "health_check._health_servicer_test.HealthServicerTest.test_cancelled_watch_removed_from_watch_list", # TODO(https://github.com/grpc/grpc/issues/17330) enable these three tests - 'channelz._channelz_servicer_test.ChannelzServicerTest.test_many_subchannels', - 'channelz._channelz_servicer_test.ChannelzServicerTest.test_many_subchannels_and_sockets', - 'channelz._channelz_servicer_test.ChannelzServicerTest.test_streaming_rpc', + "channelz._channelz_servicer_test.ChannelzServicerTest.test_many_subchannels", + "channelz._channelz_servicer_test.ChannelzServicerTest.test_many_subchannels_and_sockets", + "channelz._channelz_servicer_test.ChannelzServicerTest.test_streaming_rpc", # TODO(https://github.com/grpc/grpc/issues/15411) enable this test - 'unit._cython._channel_test.ChannelTest.test_negative_deadline_connectivity', + "unit._cython._channel_test.ChannelTest.test_negative_deadline_connectivity", # TODO(https://github.com/grpc/grpc/issues/15411) enable this test - 'unit._local_credentials_test.LocalCredentialsTest', + "unit._local_credentials_test.LocalCredentialsTest", # TODO(https://github.com/grpc/grpc/issues/22020) LocalCredentials # aren't supported with custom io managers. - 'unit._contextvars_propagation_test', - 'testing._time_test.StrictRealTimeTest', + "unit._contextvars_propagation_test", + "testing._time_test.StrictRealTimeTest", ) BANNED_WINDOWS_TESTS = ( # TODO(https://github.com/grpc/grpc/pull/15411) enable this test - 'unit._dns_resolver_test.DNSResolverTest.test_connect_loopback', + "unit._dns_resolver_test.DNSResolverTest.test_connect_loopback", # TODO(https://github.com/grpc/grpc/pull/15411) enable this test - 'unit._server_test.ServerTest.test_failed_port_binding_exception', + "unit._server_test.ServerTest.test_failed_port_binding_exception", ) BANNED_MACOS_TESTS = ( # TODO(https://github.com/grpc/grpc/issues/15411) enable this test - 'unit._dynamic_stubs_test.DynamicStubTest',) - description = 'run tests with gevent. Assumes grpc/gevent are installed' + "unit._dynamic_stubs_test.DynamicStubTest", + ) + description = "run tests with gevent. 
Assumes grpc/gevent are installed" user_options = [] def initialize_options(self): @@ -247,6 +250,7 @@ def finalize_options(self): def run(self): import gevent from gevent import monkey + monkey.patch_all() threadpool = gevent.hub.get_hub().threadpool @@ -262,38 +266,39 @@ def run(self): import grpc.experimental.gevent import tests + grpc.experimental.gevent.init_gevent() import gevent import tests + loader = tests.Loader() - loader.loadTestsFromNames(['tests', 'tests_gevent']) + loader.loadTestsFromNames(["tests", "tests_gevent"]) runner = tests.Runner() - if sys.platform == 'win32': + if sys.platform == "win32": runner.skip_tests(self.BANNED_TESTS + self.BANNED_WINDOWS_TESTS) - elif sys.platform == 'darwin': + elif sys.platform == "darwin": runner.skip_tests(self.BANNED_TESTS + self.BANNED_MACOS_TESTS) else: runner.skip_tests(self.BANNED_TESTS) result = gevent.spawn(runner.run, loader.suite) result.join() if not result.value.wasSuccessful(): - sys.exit('Test failure') + sys.exit("Test failure") class RunInterop(test.test): - - description = 'run interop test client/server' + description = "run interop test client/server" user_options = [ - ('args=', None, 'pass-thru arguments for the client/server'), - ('client', None, 'flag indicating to run the client'), - ('server', None, 'flag indicating to run the server'), - ('use-asyncio', None, 'flag indicating to run the asyncio stack') + ("args=", None, "pass-thru arguments for the client/server"), + ("client", None, "flag indicating to run the client"), + ("server", None, "flag indicating to run the server"), + ("use-asyncio", None, "flag indicating to run the asyncio stack"), ] def initialize_options(self): - self.args = '' + self.args = "" self.client = False self.server = False self.use_asyncio = False @@ -301,12 +306,14 @@ def initialize_options(self): def finalize_options(self): if self.client and self.server: raise _errors.DistutilsOptionError( - 'you may only specify one of client or server') + "you may only specify one of client or server" + ) def run(self): if self.distribution.install_requires: self.distribution.fetch_build_eggs( - self.distribution.install_requires) + self.distribution.install_requires + ) if self.distribution.tests_require: self.distribution.fetch_build_eggs(self.distribution.tests_require) if self.client: @@ -321,10 +328,12 @@ def run_server(self): import asyncio from tests_aio.interop import server + sys.argv[1:] = self.args.split() asyncio.get_event_loop().run_until_complete(server.serve()) else: from tests.interop import server + sys.argv[1:] = self.args.split() server.serve() @@ -332,17 +341,17 @@ def run_client(self): # We import here to ensure that our setuptools parent has had a chance to # edit the Python system path. from tests.interop import client + sys.argv[1:] = self.args.split() client.test_interoperability() class RunFork(test.test): - - description = 'run fork test client' - user_options = [('args=', 'a', 'pass-thru arguments for the client')] + description = "run fork test client" + user_options = [("args=", "a", "pass-thru arguments for the client")] def initialize_options(self): - self.args = '' + self.args = "" def finalize_options(self): # distutils requires this override. 
@@ -351,11 +360,13 @@ def finalize_options(self): def run(self): if self.distribution.install_requires: self.distribution.fetch_build_eggs( - self.distribution.install_requires) + self.distribution.install_requires + ) if self.distribution.tests_require: self.distribution.fetch_build_eggs(self.distribution.tests_require) # We import here to ensure that our setuptools parent has had a chance to # edit the Python system path. from tests.fork import client + sys.argv[1:] = self.args.split() client.test_fork() diff --git a/src/python/grpcio_tests/setup.py b/src/python/grpcio_tests/setup.py index fbf2e3835fd72..8409dcc95197b 100644 --- a/src/python/grpcio_tests/setup.py +++ b/src/python/grpcio_tests/setup.py @@ -30,72 +30,77 @@ import commands import grpc_version -LICENSE = 'Apache License 2.0' +LICENSE = "Apache License 2.0" PACKAGE_DIRECTORIES = { - '': '.', + "": ".", } INSTALL_REQUIRES = ( - 'coverage>=4.0', 'grpcio>={version}'.format(version=grpc_version.VERSION), - 'grpcio-channelz>={version}'.format(version=grpc_version.VERSION), - 'grpcio-status>={version}'.format(version=grpc_version.VERSION), - 'grpcio-tools>={version}'.format(version=grpc_version.VERSION), - 'grpcio-health-checking>={version}'.format(version=grpc_version.VERSION), - 'oauth2client>=1.4.7', 'protobuf>=4.21.6rc1,!=4.22.0.*', - 'google-auth>=1.17.2', 'requests>=2.14.2') + "coverage>=4.0", + "grpcio>={version}".format(version=grpc_version.VERSION), + "grpcio-channelz>={version}".format(version=grpc_version.VERSION), + "grpcio-status>={version}".format(version=grpc_version.VERSION), + "grpcio-tools>={version}".format(version=grpc_version.VERSION), + "grpcio-health-checking>={version}".format(version=grpc_version.VERSION), + "oauth2client>=1.4.7", + "protobuf>=4.21.6rc1,!=4.22.0.*", + "google-auth>=1.17.2", + "requests>=2.14.2", +) COMMAND_CLASS = { # Run `preprocess` *before* doing any packaging! 
- 'preprocess': commands.GatherProto, - 'build_package_protos': grpc_tools.command.BuildPackageProtos, - 'build_py': commands.BuildPy, - 'run_fork': commands.RunFork, - 'run_interop': commands.RunInterop, - 'test_lite': commands.TestLite, - 'test_gevent': commands.TestGevent, - 'test_aio': commands.TestAio, - 'test_py3_only': commands.TestPy3Only, + "preprocess": commands.GatherProto, + "build_package_protos": grpc_tools.command.BuildPackageProtos, + "build_py": commands.BuildPy, + "run_fork": commands.RunFork, + "run_interop": commands.RunInterop, + "test_lite": commands.TestLite, + "test_gevent": commands.TestGevent, + "test_aio": commands.TestAio, + "test_py3_only": commands.TestPy3Only, } PACKAGE_DATA = { - 'tests.interop': [ - 'credentials/ca.pem', - 'credentials/server1.key', - 'credentials/server1.pem', + "tests.interop": [ + "credentials/ca.pem", + "credentials/server1.key", + "credentials/server1.pem", ], - 'tests.protoc_plugin.protos.invocation_testing': [ - 'same.proto', 'compiler.proto' + "tests.protoc_plugin.protos.invocation_testing": [ + "same.proto", + "compiler.proto", ], - 'tests.protoc_plugin.protos.invocation_testing.split_messages': [ - 'messages.proto', + "tests.protoc_plugin.protos.invocation_testing.split_messages": [ + "messages.proto", ], - 'tests.protoc_plugin.protos.invocation_testing.split_services': [ - 'services.proto', + "tests.protoc_plugin.protos.invocation_testing.split_services": [ + "services.proto", ], - 'tests.testing.proto': [ - 'requests.proto', - 'services.proto', + "tests.testing.proto": [ + "requests.proto", + "services.proto", ], - 'tests.unit': [ - 'credentials/ca.pem', - 'credentials/server1.key', - 'credentials/server1.pem', + "tests.unit": [ + "credentials/ca.pem", + "credentials/server1.key", + "credentials/server1.pem", ], - 'tests': ['tests.json'], + "tests": ["tests.json"], } -TEST_SUITE = 'tests' -TEST_LOADER = 'tests:Loader' -TEST_RUNNER = 'tests:Runner' +TEST_SUITE = "tests" +TEST_LOADER = "tests:Loader" +TEST_RUNNER = "tests:Runner" TESTS_REQUIRE = INSTALL_REQUIRES -PACKAGES = setuptools.find_packages('.') +PACKAGES = setuptools.find_packages(".") if __name__ == "__main__": multiprocessing.freeze_support() setuptools.setup( - name='grpcio-tests', + name="grpcio-tests", version=grpc_version.VERSION, license=LICENSE, packages=list(PACKAGES), diff --git a/src/python/grpcio_tests/tests/_loader.py b/src/python/grpcio_tests/tests/_loader.py index ef5cfaee36ddb..0d82ce401163f 100644 --- a/src/python/grpcio_tests/tests/_loader.py +++ b/src/python/grpcio_tests/tests/_loader.py @@ -23,7 +23,7 @@ import coverage -TEST_MODULE_REGEX = r'^.*_test$' +TEST_MODULE_REGEX = r"^.*_test$" # Determines the path og a given path relative to the first matching @@ -32,11 +32,11 @@ def _relativize_to_sys_path(path): for sys_path in sys.path: if path.startswith(sys_path): - relative = path[len(sys_path):] + relative = path[len(sys_path) :] if not relative: return "" if relative.startswith(os.path.sep): - relative = relative[len(os.path.sep):] + relative = relative[len(os.path.sep) :] if not relative.endswith(os.path.sep): relative += os.path.sep return relative @@ -50,14 +50,14 @@ def _relative_path_to_module_prefix(path): class Loader(object): """Test loader for setuptools test suite support. - Attributes: - suite (unittest.TestSuite): All tests collected by the loader. - loader (unittest.TestLoader): Standard Python unittest loader to be ran per - module discovered. 
- module_matcher (re.RegexObject): A regular expression object to match - against module names and determine whether or not the discovered module - contributes to the test suite. - """ + Attributes: + suite (unittest.TestSuite): All tests collected by the loader. + loader (unittest.TestLoader): Standard Python unittest loader to be ran per + module discovered. + module_matcher (re.RegexObject): A regular expression object to match + against module names and determine whether or not the discovered module + contributes to the test suite. + """ def __init__(self): self.suite = unittest.TestSuite() @@ -66,13 +66,14 @@ def __init__(self): def loadTestsFromNames(self, names, module=None): """Function mirroring TestLoader::loadTestsFromNames, as expected by - setuptools.setup argument `test_loader`.""" + setuptools.setup argument `test_loader`.""" # ensure that we capture decorators and definitions (else our coverage # measure unnecessarily suffers) coverage_context = coverage.Coverage(data_suffix=True) coverage_context.start() imported_modules = tuple( - importlib.import_module(name) for name in names) + importlib.import_module(name) for name in names + ) for imported_module in imported_modules: self.visit_module(imported_module) for imported_module in imported_modules: @@ -88,18 +89,20 @@ def loadTestsFromNames(self, names, module=None): def walk_packages(self, package_paths): """Walks over the packages, dispatching `visit_module` calls. - Args: - package_paths (list): A list of paths over which to walk through modules - along. - """ + Args: + package_paths (list): A list of paths over which to walk through modules + along. + """ for path in package_paths: self._walk_package(path) def _walk_package(self, package_path): prefix = _relative_path_to_module_prefix( - _relativize_to_sys_path(package_path)) - for importer, module_name, is_package in (pkgutil.walk_packages( - [package_path], prefix)): + _relativize_to_sys_path(package_path) + ) + for importer, module_name, is_package in pkgutil.walk_packages( + [package_path], prefix + ): found_module = importer.find_module(module_name) module = None if module_name in sys.modules: @@ -111,10 +114,10 @@ def _walk_package(self, package_path): def visit_module(self, module): """Visits the module, adding discovered tests to the test suite. - Args: - module (module): Module to match against self.module_matcher; if matched - it has its tests loaded via self.loader into self.suite. - """ + Args: + module (module): Module to match against self.module_matcher; if matched + it has its tests loaded via self.loader into self.suite. + """ if self.module_matcher.match(module.__name__): module_suite = self.loader.loadTestsFromModule(module) self.suite.addTest(module_suite) @@ -123,12 +126,12 @@ def visit_module(self, module): def iterate_suite_cases(suite): """Generator over all unittest.TestCases in a unittest.TestSuite. - Args: - suite (unittest.TestSuite): Suite to iterate over in the generator. + Args: + suite (unittest.TestSuite): Suite to iterate over in the generator. - Returns: - generator: A generator over all unittest.TestCases in `suite`. - """ + Returns: + generator: A generator over all unittest.TestCases in `suite`. 
+ """ for item in suite: if isinstance(item, unittest.TestSuite): for child_item in iterate_suite_cases(item): @@ -136,5 +139,6 @@ def iterate_suite_cases(suite): elif isinstance(item, unittest.TestCase): yield item else: - raise ValueError('unexpected suite item of type {}'.format( - type(item))) + raise ValueError( + "unexpected suite item of type {}".format(type(item)) + ) diff --git a/src/python/grpcio_tests/tests/_result.py b/src/python/grpcio_tests/tests/_result.py index 2283004146a4c..6f621d7e75e25 100644 --- a/src/python/grpcio_tests/tests/_result.py +++ b/src/python/grpcio_tests/tests/_result.py @@ -27,46 +27,50 @@ class CaseResult( - collections.namedtuple('CaseResult', [ - 'id', 'name', 'kind', 'stdout', 'stderr', 'skip_reason', 'traceback' - ])): + collections.namedtuple( + "CaseResult", + ["id", "name", "kind", "stdout", "stderr", "skip_reason", "traceback"], + ) +): """A serializable result of a single test case. - Attributes: - id (object): Any serializable object used to denote the identity of this - test case. - name (str or None): A human-readable name of the test case. - kind (CaseResult.Kind): The kind of test result. - stdout (object or None): Output on stdout, or None if nothing was captured. - stderr (object or None): Output on stderr, or None if nothing was captured. - skip_reason (object or None): The reason the test was skipped. Must be - something if self.kind is CaseResult.Kind.SKIP, else None. - traceback (object or None): The traceback of the test. Must be something if - self.kind is CaseResult.Kind.{ERROR, FAILURE, EXPECTED_FAILURE}, else - None. - """ + Attributes: + id (object): Any serializable object used to denote the identity of this + test case. + name (str or None): A human-readable name of the test case. + kind (CaseResult.Kind): The kind of test result. + stdout (object or None): Output on stdout, or None if nothing was captured. + stderr (object or None): Output on stderr, or None if nothing was captured. + skip_reason (object or None): The reason the test was skipped. Must be + something if self.kind is CaseResult.Kind.SKIP, else None. + traceback (object or None): The traceback of the test. Must be something if + self.kind is CaseResult.Kind.{ERROR, FAILURE, EXPECTED_FAILURE}, else + None. + """ class Kind(object): - UNTESTED = 'untested' - RUNNING = 'running' - ERROR = 'error' - FAILURE = 'failure' - SUCCESS = 'success' - SKIP = 'skip' - EXPECTED_FAILURE = 'expected failure' - UNEXPECTED_SUCCESS = 'unexpected success' - - def __new__(cls, - id=None, - name=None, - kind=None, - stdout=None, - stderr=None, - skip_reason=None, - traceback=None): + UNTESTED = "untested" + RUNNING = "running" + ERROR = "error" + FAILURE = "failure" + SUCCESS = "success" + SKIP = "skip" + EXPECTED_FAILURE = "expected failure" + UNEXPECTED_SUCCESS = "unexpected success" + + def __new__( + cls, + id=None, + name=None, + kind=None, + stdout=None, + stderr=None, + skip_reason=None, + traceback=None, + ): """Helper keyword constructor for the namedtuple. 
- See this class' attributes for information on the arguments.""" + See this class' attributes for information on the arguments.""" assert id is not None assert name is None or isinstance(name, str) if kind is CaseResult.Kind.UNTESTED: @@ -87,53 +91,58 @@ def __new__(cls, pass else: assert False - return super(cls, CaseResult).__new__(cls, id, name, kind, stdout, - stderr, skip_reason, traceback) - - def updated(self, - name=None, - kind=None, - stdout=None, - stderr=None, - skip_reason=None, - traceback=None): + return super(cls, CaseResult).__new__( + cls, id, name, kind, stdout, stderr, skip_reason, traceback + ) + + def updated( + self, + name=None, + kind=None, + stdout=None, + stderr=None, + skip_reason=None, + traceback=None, + ): """Get a new validated CaseResult with the fields updated. - See this class' attributes for information on the arguments.""" + See this class' attributes for information on the arguments.""" name = self.name if name is None else name kind = self.kind if kind is None else kind stdout = self.stdout if stdout is None else stdout stderr = self.stderr if stderr is None else stderr skip_reason = self.skip_reason if skip_reason is None else skip_reason traceback = self.traceback if traceback is None else traceback - return CaseResult(id=self.id, - name=name, - kind=kind, - stdout=stdout, - stderr=stderr, - skip_reason=skip_reason, - traceback=traceback) + return CaseResult( + id=self.id, + name=name, + kind=kind, + stdout=stdout, + stderr=stderr, + skip_reason=skip_reason, + traceback=traceback, + ) class AugmentedResult(unittest.TestResult): """unittest.Result that keeps track of additional information. - Uses CaseResult objects to store test-case results, providing additional - information beyond that of the standard Python unittest library, such as - standard output. + Uses CaseResult objects to store test-case results, providing additional + information beyond that of the standard Python unittest library, such as + standard output. - Attributes: - id_map (callable): A unary callable mapping unittest.TestCase objects to - unique identifiers. - cases (dict): A dictionary mapping from the identifiers returned by id_map - to CaseResult objects corresponding to those IDs. - """ + Attributes: + id_map (callable): A unary callable mapping unittest.TestCase objects to + unique identifiers. + cases (dict): A dictionary mapping from the identifiers returned by id_map + to CaseResult objects corresponding to those IDs. + """ def __init__(self, id_map): """Initialize the object with an identifier mapping. 
- Arguments: - id_map (callable): Corresponds to the attribute `id_map`.""" + Arguments: + id_map (callable): Corresponds to the attribute `id_map`.""" super(AugmentedResult, self).__init__() self.id_map = id_map self.cases = None @@ -147,73 +156,82 @@ def startTest(self, test): """See unittest.TestResult.startTest.""" super(AugmentedResult, self).startTest(test) case_id = self.id_map(test) - self.cases[case_id] = CaseResult(id=case_id, - name=test.id(), - kind=CaseResult.Kind.RUNNING) + self.cases[case_id] = CaseResult( + id=case_id, name=test.id(), kind=CaseResult.Kind.RUNNING + ) def addError(self, test, err): """See unittest.TestResult.addError.""" super(AugmentedResult, self).addError(test, err) case_id = self.id_map(test) self.cases[case_id] = self.cases[case_id].updated( - kind=CaseResult.Kind.ERROR, traceback=err) + kind=CaseResult.Kind.ERROR, traceback=err + ) def addFailure(self, test, err): """See unittest.TestResult.addFailure.""" super(AugmentedResult, self).addFailure(test, err) case_id = self.id_map(test) self.cases[case_id] = self.cases[case_id].updated( - kind=CaseResult.Kind.FAILURE, traceback=err) + kind=CaseResult.Kind.FAILURE, traceback=err + ) def addSuccess(self, test): """See unittest.TestResult.addSuccess.""" super(AugmentedResult, self).addSuccess(test) case_id = self.id_map(test) self.cases[case_id] = self.cases[case_id].updated( - kind=CaseResult.Kind.SUCCESS) + kind=CaseResult.Kind.SUCCESS + ) def addSkip(self, test, reason): """See unittest.TestResult.addSkip.""" super(AugmentedResult, self).addSkip(test, reason) case_id = self.id_map(test) self.cases[case_id] = self.cases[case_id].updated( - kind=CaseResult.Kind.SKIP, skip_reason=reason) + kind=CaseResult.Kind.SKIP, skip_reason=reason + ) def addExpectedFailure(self, test, err): """See unittest.TestResult.addExpectedFailure.""" super(AugmentedResult, self).addExpectedFailure(test, err) case_id = self.id_map(test) self.cases[case_id] = self.cases[case_id].updated( - kind=CaseResult.Kind.EXPECTED_FAILURE, traceback=err) + kind=CaseResult.Kind.EXPECTED_FAILURE, traceback=err + ) def addUnexpectedSuccess(self, test): """See unittest.TestResult.addUnexpectedSuccess.""" super(AugmentedResult, self).addUnexpectedSuccess(test) case_id = self.id_map(test) self.cases[case_id] = self.cases[case_id].updated( - kind=CaseResult.Kind.UNEXPECTED_SUCCESS) + kind=CaseResult.Kind.UNEXPECTED_SUCCESS + ) def set_output(self, test, stdout, stderr): """Set the output attributes for the CaseResult corresponding to a test. - Args: - test (unittest.TestCase): The TestCase to set the outputs of. - stdout (str): Output from stdout to assign to self.id_map(test). - stderr (str): Output from stderr to assign to self.id_map(test). - """ + Args: + test (unittest.TestCase): The TestCase to set the outputs of. + stdout (str): Output from stdout to assign to self.id_map(test). + stderr (str): Output from stderr to assign to self.id_map(test). + """ case_id = self.id_map(test) self.cases[case_id] = self.cases[case_id].updated( - stdout=stdout.decode(), stderr=stderr.decode()) + stdout=stdout.decode(), stderr=stderr.decode() + ) def augmented_results(self, filter): """Convenience method to retrieve filtered case results. - Args: - filter (callable): A unary predicate to filter over CaseResult objects. - """ - return (self.cases[case_id] - for case_id in self.cases - if filter(self.cases[case_id])) + Args: + filter (callable): A unary predicate to filter over CaseResult objects. 
+ """ + return ( + self.cases[case_id] + for case_id in self.cases + if filter(self.cases[case_id]) + ) class CoverageResult(AugmentedResult): @@ -231,7 +249,7 @@ def __init__(self, id_map): def startTest(self, test): """See unittest.TestResult.startTest. - Additionally initializes and begins code coverage tracking.""" + Additionally initializes and begins code coverage tracking.""" super(CoverageResult, self).startTest(test) self.coverage_context = coverage.Coverage(data_suffix=True) self.coverage_context.start() @@ -239,7 +257,7 @@ def startTest(self, test): def stopTest(self, test): """See unittest.TestResult.stopTest. - Additionally stops and deinitializes code coverage tracking.""" + Additionally stops and deinitializes code coverage tracking.""" super(CoverageResult, self).stopTest(test) self.coverage_context.stop() self.coverage_context.save() @@ -248,14 +266,15 @@ def stopTest(self, test): class _Colors(object): """Namespaced constants for terminal color magic numbers.""" - HEADER = '\033[95m' - INFO = '\033[94m' - OK = '\033[92m' - WARN = '\033[93m' - FAIL = '\033[91m' - BOLD = '\033[1m' - UNDERLINE = '\033[4m' - END = '\033[0m' + + HEADER = "\033[95m" + INFO = "\033[94m" + OK = "\033[92m" + WARN = "\033[93m" + FAIL = "\033[91m" + BOLD = "\033[1m" + UNDERLINE = "\033[4m" + END = "\033[0m" class TerminalResult(CoverageResult): @@ -264,19 +283,20 @@ class TerminalResult(CoverageResult): def __init__(self, out, id_map): """Initialize the result object. - Args: - out (file-like): Output file to which terminal-colored live results will - be written. - id_map (callable): See AugmentedResult.__init__. - """ + Args: + out (file-like): Output file to which terminal-colored live results will + be written. + id_map (callable): See AugmentedResult.__init__. 
+ """ super(TerminalResult, self).__init__(id_map=id_map) self.out = out def startTestRun(self): """See unittest.TestResult.startTestRun.""" super(TerminalResult, self).startTestRun() - self.out.write(_Colors.HEADER + 'Testing gRPC Python...\n' + - _Colors.END) + self.out.write( + _Colors.HEADER + "Testing gRPC Python...\n" + _Colors.END + ) def stopTestRun(self): """See unittest.TestResult.stopTestRun.""" @@ -287,57 +307,63 @@ def stopTestRun(self): def addError(self, test, err): """See unittest.TestResult.addError.""" super(TerminalResult, self).addError(test, err) - self.out.write(_Colors.FAIL + 'ERROR {}\n'.format(test.id()) + - _Colors.END) + self.out.write( + _Colors.FAIL + "ERROR {}\n".format(test.id()) + _Colors.END + ) self.out.flush() def addFailure(self, test, err): """See unittest.TestResult.addFailure.""" super(TerminalResult, self).addFailure(test, err) - self.out.write(_Colors.FAIL + 'FAILURE {}\n'.format(test.id()) + - _Colors.END) + self.out.write( + _Colors.FAIL + "FAILURE {}\n".format(test.id()) + _Colors.END + ) self.out.flush() def addSuccess(self, test): """See unittest.TestResult.addSuccess.""" super(TerminalResult, self).addSuccess(test) - self.out.write(_Colors.OK + 'SUCCESS {}\n'.format(test.id()) + - _Colors.END) + self.out.write( + _Colors.OK + "SUCCESS {}\n".format(test.id()) + _Colors.END + ) self.out.flush() def addSkip(self, test, reason): """See unittest.TestResult.addSkip.""" super(TerminalResult, self).addSkip(test, reason) - self.out.write(_Colors.INFO + 'SKIP {}\n'.format(test.id()) + - _Colors.END) + self.out.write( + _Colors.INFO + "SKIP {}\n".format(test.id()) + _Colors.END + ) self.out.flush() def addExpectedFailure(self, test, err): """See unittest.TestResult.addExpectedFailure.""" super(TerminalResult, self).addExpectedFailure(test, err) - self.out.write(_Colors.INFO + 'FAILURE_OK {}\n'.format(test.id()) + - _Colors.END) + self.out.write( + _Colors.INFO + "FAILURE_OK {}\n".format(test.id()) + _Colors.END + ) self.out.flush() def addUnexpectedSuccess(self, test): """See unittest.TestResult.addUnexpectedSuccess.""" super(TerminalResult, self).addUnexpectedSuccess(test) - self.out.write(_Colors.INFO + 'UNEXPECTED_OK {}\n'.format(test.id()) + - _Colors.END) + self.out.write( + _Colors.INFO + "UNEXPECTED_OK {}\n".format(test.id()) + _Colors.END + ) self.out.flush() def _traceback_string(type, value, trace): """Generate a descriptive string of a Python exception traceback. - Args: - type (class): The type of the exception. - value (Exception): The value of the exception. - trace (traceback): Traceback of the exception. + Args: + type (class): The type of the exception. + value (Exception): The value of the exception. + trace (traceback): Traceback of the exception. - Returns: - str: Formatted exception descriptive string. - """ + Returns: + str: Formatted exception descriptive string. + """ buffer = io.StringIO() traceback.print_exception(type, value, trace, file=buffer) return buffer.getvalue() @@ -346,94 +372,153 @@ def _traceback_string(type, value, trace): def summary(result): """A summary string of a result object. - Args: - result (AugmentedResult): The result object to get the summary of. + Args: + result (AugmentedResult): The result object to get the summary of. - Returns: - str: The summary string. - """ + Returns: + str: The summary string. 
+ """ assert isinstance(result, AugmentedResult) untested = list( result.augmented_results( - lambda case_result: case_result.kind is CaseResult.Kind.UNTESTED)) + lambda case_result: case_result.kind is CaseResult.Kind.UNTESTED + ) + ) running = list( result.augmented_results( - lambda case_result: case_result.kind is CaseResult.Kind.RUNNING)) + lambda case_result: case_result.kind is CaseResult.Kind.RUNNING + ) + ) failures = list( result.augmented_results( - lambda case_result: case_result.kind is CaseResult.Kind.FAILURE)) + lambda case_result: case_result.kind is CaseResult.Kind.FAILURE + ) + ) errors = list( result.augmented_results( - lambda case_result: case_result.kind is CaseResult.Kind.ERROR)) + lambda case_result: case_result.kind is CaseResult.Kind.ERROR + ) + ) successes = list( result.augmented_results( - lambda case_result: case_result.kind is CaseResult.Kind.SUCCESS)) + lambda case_result: case_result.kind is CaseResult.Kind.SUCCESS + ) + ) skips = list( result.augmented_results( - lambda case_result: case_result.kind is CaseResult.Kind.SKIP)) + lambda case_result: case_result.kind is CaseResult.Kind.SKIP + ) + ) expected_failures = list( - result.augmented_results(lambda case_result: case_result.kind is - CaseResult.Kind.EXPECTED_FAILURE)) + result.augmented_results( + lambda case_result: case_result.kind + is CaseResult.Kind.EXPECTED_FAILURE + ) + ) unexpected_successes = list( - result.augmented_results(lambda case_result: case_result.kind is - CaseResult.Kind.UNEXPECTED_SUCCESS)) + result.augmented_results( + lambda case_result: case_result.kind + is CaseResult.Kind.UNEXPECTED_SUCCESS + ) + ) running_names = [case.name for case in running] - finished_count = (len(failures) + len(errors) + len(successes) + - len(expected_failures) + len(unexpected_successes)) - statistics = ('{finished} tests finished:\n' - '\t{successful} successful\n' - '\t{unsuccessful} unsuccessful\n' - '\t{skipped} skipped\n' - '\t{expected_fail} expected failures\n' - '\t{unexpected_successful} unexpected successes\n' - 'Interrupted Tests:\n' - '\t{interrupted}\n'.format( - finished=finished_count, - successful=len(successes), - unsuccessful=(len(failures) + len(errors)), - skipped=len(skips), - expected_fail=len(expected_failures), - unexpected_successful=len(unexpected_successes), - interrupted=str(running_names))) - tracebacks = '\n\n'.join([ - (_Colors.FAIL + '{test_name}' + _Colors.END + '\n' + _Colors.BOLD + - 'traceback:' + _Colors.END + '\n' + '{traceback}\n' + _Colors.BOLD + - 'stdout:' + _Colors.END + '\n' + '{stdout}\n' + _Colors.BOLD + - 'stderr:' + _Colors.END + '\n' + '{stderr}\n').format( - test_name=result.name, - traceback=_traceback_string(*result.traceback), - stdout=result.stdout, - stderr=result.stderr) - for result in itertools.chain(failures, errors) - ]) - notes = 'Unexpected successes: {}\n'.format( - [result.name for result in unexpected_successes]) - return statistics + '\nErrors/Failures: \n' + tracebacks + '\n' + notes + finished_count = ( + len(failures) + + len(errors) + + len(successes) + + len(expected_failures) + + len(unexpected_successes) + ) + statistics = ( + "{finished} tests finished:\n" + "\t{successful} successful\n" + "\t{unsuccessful} unsuccessful\n" + "\t{skipped} skipped\n" + "\t{expected_fail} expected failures\n" + "\t{unexpected_successful} unexpected successes\n" + "Interrupted Tests:\n" + "\t{interrupted}\n".format( + finished=finished_count, + successful=len(successes), + unsuccessful=(len(failures) + len(errors)), + skipped=len(skips), + 
expected_fail=len(expected_failures), + unexpected_successful=len(unexpected_successes), + interrupted=str(running_names), + ) + ) + tracebacks = "\n\n".join( + [ + ( + _Colors.FAIL + + "{test_name}" + + _Colors.END + + "\n" + + _Colors.BOLD + + "traceback:" + + _Colors.END + + "\n" + + "{traceback}\n" + + _Colors.BOLD + + "stdout:" + + _Colors.END + + "\n" + + "{stdout}\n" + + _Colors.BOLD + + "stderr:" + + _Colors.END + + "\n" + + "{stderr}\n" + ).format( + test_name=result.name, + traceback=_traceback_string(*result.traceback), + stdout=result.stdout, + stderr=result.stderr, + ) + for result in itertools.chain(failures, errors) + ] + ) + notes = "Unexpected successes: {}\n".format( + [result.name for result in unexpected_successes] + ) + return statistics + "\nErrors/Failures: \n" + tracebacks + "\n" + notes def jenkins_junit_xml(result): """An XML tree object that when written is recognizable by Jenkins. - Args: - result (AugmentedResult): The result object to get the junit xml output of. + Args: + result (AugmentedResult): The result object to get the junit xml output of. - Returns: - ElementTree.ElementTree: The XML tree. - """ + Returns: + ElementTree.ElementTree: The XML tree. + """ assert isinstance(result, AugmentedResult) - root = ElementTree.Element('testsuites') - suite = ElementTree.SubElement(root, 'testsuite', { - 'name': 'Python gRPC tests', - }) + root = ElementTree.Element("testsuites") + suite = ElementTree.SubElement( + root, + "testsuite", + { + "name": "Python gRPC tests", + }, + ) for case in result.cases.values(): if case.kind is CaseResult.Kind.SUCCESS: - ElementTree.SubElement(suite, 'testcase', { - 'name': case.name, - }) + ElementTree.SubElement( + suite, + "testcase", + { + "name": case.name, + }, + ) elif case.kind in (CaseResult.Kind.ERROR, CaseResult.Kind.FAILURE): - case_xml = ElementTree.SubElement(suite, 'testcase', { - 'name': case.name, - }) - error_xml = ElementTree.SubElement(case_xml, 'error', {}) - error_xml.text = ''.format(case.stderr, case.traceback) + case_xml = ElementTree.SubElement( + suite, + "testcase", + { + "name": case.name, + }, + ) + error_xml = ElementTree.SubElement(case_xml, "error", {}) + error_xml.text = "".format(case.stderr, case.traceback) return ElementTree.ElementTree(element=root) diff --git a/src/python/grpcio_tests/tests/_runner.py b/src/python/grpcio_tests/tests/_runner.py index 8254c0a8d60a4..ff6a7644f4563 100644 --- a/src/python/grpcio_tests/tests/_runner.py +++ b/src/python/grpcio_tests/tests/_runner.py @@ -33,20 +33,20 @@ class CaptureFile(object): """A context-managed file to redirect output to a byte array. - Use by invoking `start` (`__enter__`) and at some point invoking `stop` - (`__exit__`). At any point after the initial call to `start` call `output` to - get the current redirected output. Note that we don't currently use file - locking, so calling `output` between calls to `start` and `stop` may muddle - the result (you should only be doing this during a Python-handled interrupt as - a last ditch effort to provide output to the user). - - Attributes: - _redirected_fd (int): File descriptor of file to redirect writes from. - _saved_fd (int): A copy of the original value of the redirected file - descriptor. - _into_file (TemporaryFile or None): File to which writes are redirected. - Only non-None when self is started. - """ + Use by invoking `start` (`__enter__`) and at some point invoking `stop` + (`__exit__`). At any point after the initial call to `start` call `output` to + get the current redirected output. 
Note that we don't currently use file + locking, so calling `output` between calls to `start` and `stop` may muddle + the result (you should only be doing this during a Python-handled interrupt as + a last ditch effort to provide output to the user). + + Attributes: + _redirected_fd (int): File descriptor of file to redirect writes from. + _saved_fd (int): A copy of the original value of the redirected file + descriptor. + _into_file (TemporaryFile or None): File to which writes are redirected. + Only non-None when self is started. + """ def __init__(self, fd): self._redirected_fd = fd @@ -74,11 +74,11 @@ def stop(self): def write_bypass(self, value): """Bypass the redirection and write directly to the original file. - Arguments: - value (str): What to write to the original file. - """ + Arguments: + value (str): What to write to the original file. + """ if not isinstance(value, bytes): - value = value.encode('ascii') + value = value.encode("ascii") if self._saved_fd is None: os.write(self._redirect_fd, value) else: @@ -96,15 +96,15 @@ def close(self): os.close(self._saved_fd) -class AugmentedCase(collections.namedtuple('AugmentedCase', ['case', 'id'])): +class AugmentedCase(collections.namedtuple("AugmentedCase", ["case", "id"])): """A test case with a guaranteed unique externally specified identifier. - Attributes: - case (unittest.TestCase): TestCase we're decorating with an additional - identifier. - id (object): Any identifier that may be considered 'unique' for testing - purposes. - """ + Attributes: + case (unittest.TestCase): TestCase we're decorating with an additional + identifier. + id (object): Any identifier that may be considered 'unique' for testing + purposes. + """ def __new__(cls, case, id=None): if id is None: @@ -115,7 +115,6 @@ def __new__(cls, case, id=None): # NOTE(lidiz) This complex wrapper is not triggering setUpClass nor # tearDownClass. Do not use those methods, or fix this wrapper! class Runner(object): - def __init__(self, dedicated_threads=False): """Constructs the Runner object. 
@@ -132,7 +131,7 @@ def skip_tests(self, tests): def run(self, suite): """See setuptools' test_runner setup argument for information.""" # only run test cases with id starting with given prefix - testcase_filter = os.getenv('GRPC_PYTHON_TESTRUNNER_FILTER') + testcase_filter = os.getenv("GRPC_PYTHON_TESTRUNNER_FILTER") filtered_cases = [] for case in _loader.iterate_suite_cases(suite): if not testcase_filter or case.id().startswith(testcase_filter): @@ -143,11 +142,14 @@ def run(self, suite): augmented_cases = [ AugmentedCase(case, uuid.uuid4()) for case in filtered_cases ] - case_id_by_case = dict((augmented_case.case, augmented_case.id) - for augmented_case in augmented_cases) + case_id_by_case = dict( + (augmented_case.case, augmented_case.id) + for augmented_case in augmented_cases + ) result_out = io.StringIO() result = _result.TerminalResult( - result_out, id_map=lambda case: case_id_by_case[case]) + result_out, id_map=lambda case: case_id_by_case[case] + ) stdout_pipe = CaptureFile(sys.stdout.fileno()) stderr_pipe = CaptureFile(sys.stderr.fileno()) kill_flag = [False] @@ -159,19 +161,27 @@ def sigint_handler(signal_number, frame): def fault_handler(signal_number, frame): stdout_pipe.write_bypass( - 'Received fault signal {}\nstdout:\n{}\n\nstderr:{}\n'.format( - signal_number, stdout_pipe.output(), stderr_pipe.output())) + "Received fault signal {}\nstdout:\n{}\n\nstderr:{}\n".format( + signal_number, stdout_pipe.output(), stderr_pipe.output() + ) + ) os._exit(1) def check_kill_self(): if kill_flag[0]: - stdout_pipe.write_bypass('Stopping tests short...') + stdout_pipe.write_bypass("Stopping tests short...") result.stopTestRun() stdout_pipe.write_bypass(result_out.getvalue()) - stdout_pipe.write_bypass('\ninterrupted stdout:\n{}\n'.format( - stdout_pipe.output().decode())) - stderr_pipe.write_bypass('\ninterrupted stderr:\n{}\n'.format( - stderr_pipe.output().decode())) + stdout_pipe.write_bypass( + "\ninterrupted stdout:\n{}\n".format( + stdout_pipe.output().decode() + ) + ) + stderr_pipe.write_bypass( + "\ninterrupted stderr:\n{}\n".format( + stderr_pipe.output().decode() + ) + ) os._exit(1) def try_set_handler(name, handler): @@ -180,14 +190,14 @@ def try_set_handler(name, handler): except AttributeError: pass - try_set_handler('SIGINT', sigint_handler) - try_set_handler('SIGBUS', fault_handler) - try_set_handler('SIGABRT', fault_handler) - try_set_handler('SIGFPE', fault_handler) - try_set_handler('SIGILL', fault_handler) + try_set_handler("SIGINT", sigint_handler) + try_set_handler("SIGBUS", fault_handler) + try_set_handler("SIGABRT", fault_handler) + try_set_handler("SIGFPE", fault_handler) + try_set_handler("SIGILL", fault_handler) # Sometimes output will lag after a test has successfully finished; we # ignore such writes to our pipes. - try_set_handler('SIGPIPE', signal.SIG_IGN) + try_set_handler("SIGPIPE", signal.SIG_IGN) # Run the tests result.startTestRun() @@ -196,13 +206,15 @@ def try_set_handler(name, handler): if skipped_test in augmented_case.case.id(): break else: - sys.stdout.write('Running {}\n'.format( - augmented_case.case.id())) + sys.stdout.write( + "Running {}\n".format(augmented_case.case.id()) + ) sys.stdout.flush() if self._dedicated_threads: # (Deprecated) Spawns dedicated thread for each test case. 
case_thread = threading.Thread( - target=augmented_case.case.run, args=(result,)) + target=augmented_case.case.run, args=(result,) + ) try: with stdout_pipe, stderr_pipe: case_thread.start() @@ -215,8 +227,11 @@ def try_set_handler(name, handler): # re-raise the exception after forcing the with-block to end raise # Records the result of the test case run. - result.set_output(augmented_case.case, stdout_pipe.output(), - stderr_pipe.output()) + result.set_output( + augmented_case.case, + stdout_pipe.output(), + stderr_pipe.output(), + ) sys.stdout.write(result_out.getvalue()) sys.stdout.flush() result_out.truncate(0) @@ -232,6 +247,6 @@ def try_set_handler(name, handler): sys.stdout.write(result_out.getvalue()) sys.stdout.flush() signal.signal(signal.SIGINT, signal.SIG_DFL) - with open('report.xml', 'wb') as report_xml_file: + with open("report.xml", "wb") as report_xml_file: _result.jenkins_junit_xml(result).write(report_xml_file) return result diff --git a/src/python/grpcio_tests/tests/_sanity/_sanity_test.py b/src/python/grpcio_tests/tests/_sanity/_sanity_test.py index 1c7c9ec05d9f9..af0b2ee57c04b 100644 --- a/src/python/grpcio_tests/tests/_sanity/_sanity_test.py +++ b/src/python/grpcio_tests/tests/_sanity/_sanity_test.py @@ -20,26 +20,29 @@ class SanityTest(unittest.TestCase): - maxDiff = 32768 - TEST_PKG_MODULE_NAME = 'tests' - TEST_PKG_PATH = 'tests' + TEST_PKG_MODULE_NAME = "tests" + TEST_PKG_PATH = "tests" def testTestsJsonUpToDate(self): """Autodiscovers all test suites and checks that tests.json is up to date""" loader = tests.Loader() loader.loadTestsFromNames([self.TEST_PKG_MODULE_NAME]) - test_suite_names = sorted({ - test_case_class.id().rsplit('.', 1)[0] for test_case_class in - tests._loader.iterate_suite_cases(loader.suite) - }) - - tests_json_string = pkgutil.get_data(self.TEST_PKG_PATH, 'tests.json') + test_suite_names = sorted( + { + test_case_class.id().rsplit(".", 1)[0] + for test_case_class in tests._loader.iterate_suite_cases( + loader.suite + ) + } + ) + + tests_json_string = pkgutil.get_data(self.TEST_PKG_PATH, "tests.json") tests_json = tests_json_string.decode() self.assertSequenceEqual(tests_json, test_suite_names) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/admin/test_admin.py b/src/python/grpcio_tests/tests/admin/test_admin.py index adc4878282805..b7501cc222071 100644 --- a/src/python/grpcio_tests/tests/admin/test_admin.py +++ b/src/python/grpcio_tests/tests/admin/test_admin.py @@ -26,17 +26,17 @@ from grpc_csds import csds_pb2_grpc -@unittest.skipIf(sys.version_info[0] < 3, - 'ProtoBuf descriptor has moved on from Python2') +@unittest.skipIf( + sys.version_info[0] < 3, "ProtoBuf descriptor has moved on from Python2" +) class TestAdmin(unittest.TestCase): - def setUp(self): self._server = grpc.server(ThreadPoolExecutor()) - port = self._server.add_insecure_port('localhost:0') + port = self._server.add_insecure_port("localhost:0") grpc_admin.add_admin_servicers(self._server) self._server.start() - self._channel = grpc.insecure_channel('localhost:%s' % port) + self._channel = grpc.insecure_channel("localhost:%s" % port) def tearDown(self): self._channel.close() diff --git a/src/python/grpcio_tests/tests/channelz/_channelz_servicer_test.py b/src/python/grpcio_tests/tests/channelz/_channelz_servicer_test.py index 34c6d15eb5ef3..78333fc62c7df 100644 --- a/src/python/grpcio_tests/tests/channelz/_channelz_servicer_test.py +++ 
b/src/python/grpcio_tests/tests/channelz/_channelz_servicer_test.py @@ -25,16 +25,16 @@ from tests.unit import test_common from tests.unit.framework.common import test_constants -_SUCCESSFUL_UNARY_UNARY = '/test/SuccessfulUnaryUnary' -_FAILED_UNARY_UNARY = '/test/FailedUnaryUnary' -_SUCCESSFUL_STREAM_STREAM = '/test/SuccessfulStreamStream' +_SUCCESSFUL_UNARY_UNARY = "/test/SuccessfulUnaryUnary" +_FAILED_UNARY_UNARY = "/test/FailedUnaryUnary" +_SUCCESSFUL_STREAM_STREAM = "/test/SuccessfulStreamStream" -_REQUEST = b'\x00\x00\x00' -_RESPONSE = b'\x01\x01\x01' +_REQUEST = b"\x00\x00\x00" +_RESPONSE = b"\x01\x01\x01" -_DISABLE_REUSE_PORT = (('grpc.so_reuseport', 0),) -_ENABLE_CHANNELZ = (('grpc.enable_channelz', 1),) -_DISABLE_CHANNELZ = (('grpc.enable_channelz', 0),) +_DISABLE_REUSE_PORT = (("grpc.so_reuseport", 0),) +_ENABLE_CHANNELZ = (("grpc.enable_channelz", 1),) +_DISABLE_CHANNELZ = (("grpc.enable_channelz", 0),) def _successful_unary_unary(request, servicer_context): @@ -52,7 +52,6 @@ def _successful_stream_stream(request_iterator, servicer_context): class _GenericHandler(grpc.GenericRpcHandler): - def service(self, handler_call_details): if handler_call_details.method == _SUCCESSFUL_UNARY_UNARY: return grpc.unary_unary_rpc_method_handler(_successful_unary_unary) @@ -60,25 +59,27 @@ def service(self, handler_call_details): return grpc.unary_unary_rpc_method_handler(_failed_unary_unary) elif handler_call_details.method == _SUCCESSFUL_STREAM_STREAM: return grpc.stream_stream_rpc_method_handler( - _successful_stream_stream) + _successful_stream_stream + ) else: return None class _ChannelServerPair(object): - def __init__(self): # Server will enable channelz service - self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=3), - options=_DISABLE_REUSE_PORT + - _ENABLE_CHANNELZ) - port = self.server.add_insecure_port('[::]:0') + self.server = grpc.server( + futures.ThreadPoolExecutor(max_workers=3), + options=_DISABLE_REUSE_PORT + _ENABLE_CHANNELZ, + ) + port = self.server.add_insecure_port("[::]:0") self.server.add_generic_rpc_handlers((_GenericHandler(),)) self.server.start() # Channel will enable channelz service... 
- self.channel = grpc.insecure_channel('localhost:%d' % port, - _ENABLE_CHANNELZ) + self.channel = grpc.insecure_channel( + "localhost:%d" % port, _ENABLE_CHANNELZ + ) def _generate_channel_server_pairs(n): @@ -91,28 +92,34 @@ def _close_channel_server_pairs(pairs): pair.channel.close() -@unittest.skipIf(sys.version_info[0] < 3, - 'ProtoBuf descriptor has moved on from Python2') +@unittest.skipIf( + sys.version_info[0] < 3, "ProtoBuf descriptor has moved on from Python2" +) class ChannelzServicerTest(unittest.TestCase): - def _send_successful_unary_unary(self, idx): - _, r = self._pairs[idx].channel.unary_unary( - _SUCCESSFUL_UNARY_UNARY).with_call(_REQUEST) + _, r = ( + self._pairs[idx] + .channel.unary_unary(_SUCCESSFUL_UNARY_UNARY) + .with_call(_REQUEST) + ) self.assertEqual(r.code(), grpc.StatusCode.OK) def _send_failed_unary_unary(self, idx): try: self._pairs[idx].channel.unary_unary(_FAILED_UNARY_UNARY).with_call( - _REQUEST) + _REQUEST + ) except grpc.RpcError: return else: self.fail("This call supposed to fail") def _send_successful_stream_stream(self, idx): - response_iterator = self._pairs[idx].channel.stream_stream( - _SUCCESSFUL_STREAM_STREAM).__call__( - iter([_REQUEST] * test_constants.STREAM_LENGTH)) + response_iterator = ( + self._pairs[idx] + .channel.stream_stream(_SUCCESSFUL_STREAM_STREAM) + .__call__(iter([_REQUEST] * test_constants.STREAM_LENGTH)) + ) cnt = 0 for _ in response_iterator: cnt += 1 @@ -121,7 +128,8 @@ def _send_successful_stream_stream(self, idx): def _get_channel_id(self, idx): """Channel id may not be consecutive""" resp = self._channelz_stub.GetTopChannels( - channelz_pb2.GetTopChannelsRequest(start_channel_id=0)) + channelz_pb2.GetTopChannelsRequest(start_channel_id=0) + ) self.assertGreater(len(resp.channel), idx) return resp.channel[idx].ref.channel_id @@ -129,17 +137,19 @@ def setUp(self): self._pairs = [] # This server is for Channelz info fetching only # It self should not enable Channelz - self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=3), - options=_DISABLE_REUSE_PORT + - _DISABLE_CHANNELZ) - port = self._server.add_insecure_port('[::]:0') + self._server = grpc.server( + futures.ThreadPoolExecutor(max_workers=3), + options=_DISABLE_REUSE_PORT + _DISABLE_CHANNELZ, + ) + port = self._server.add_insecure_port("[::]:0") channelz.add_channelz_servicer(self._server) self._server.start() # This channel is used to fetch Channelz info only # Channelz should not be enabled - self._channel = grpc.insecure_channel('localhost:%d' % port, - _DISABLE_CHANNELZ) + self._channel = grpc.insecure_channel( + "localhost:%d" % port, _DISABLE_CHANNELZ + ) self._channelz_stub = channelz_pb2_grpc.ChannelzStub(self._channel) def tearDown(self): @@ -150,14 +160,16 @@ def tearDown(self): def test_get_top_channels_basic(self): self._pairs = _generate_channel_server_pairs(1) resp = self._channelz_stub.GetTopChannels( - channelz_pb2.GetTopChannelsRequest(start_channel_id=0)) + channelz_pb2.GetTopChannelsRequest(start_channel_id=0) + ) self.assertEqual(len(resp.channel), 1) self.assertEqual(resp.end, True) def test_get_top_channels_high_start_id(self): self._pairs = _generate_channel_server_pairs(1) resp = self._channelz_stub.GetTopChannels( - channelz_pb2.GetTopChannelsRequest(start_channel_id=10000)) + channelz_pb2.GetTopChannelsRequest(start_channel_id=10000) + ) self.assertEqual(len(resp.channel), 0) self.assertEqual(resp.end, True) @@ -165,7 +177,8 @@ def test_successful_request(self): self._pairs = _generate_channel_server_pairs(1) 
self._send_successful_unary_unary(0) resp = self._channelz_stub.GetChannel( - channelz_pb2.GetChannelRequest(channel_id=self._get_channel_id(0))) + channelz_pb2.GetChannelRequest(channel_id=self._get_channel_id(0)) + ) self.assertEqual(resp.channel.data.calls_started, 1) self.assertEqual(resp.channel.data.calls_succeeded, 1) self.assertEqual(resp.channel.data.calls_failed, 0) @@ -174,7 +187,8 @@ def test_failed_request(self): self._pairs = _generate_channel_server_pairs(1) self._send_failed_unary_unary(0) resp = self._channelz_stub.GetChannel( - channelz_pb2.GetChannelRequest(channel_id=self._get_channel_id(0))) + channelz_pb2.GetChannelRequest(channel_id=self._get_channel_id(0)) + ) self.assertEqual(resp.channel.data.calls_started, 1) self.assertEqual(resp.channel.data.calls_succeeded, 0) self.assertEqual(resp.channel.data.calls_failed, 1) @@ -188,7 +202,8 @@ def test_many_requests(self): for i in range(k_failed): self._send_failed_unary_unary(0) resp = self._channelz_stub.GetChannel( - channelz_pb2.GetChannelRequest(channel_id=self._get_channel_id(0))) + channelz_pb2.GetChannelRequest(channel_id=self._get_channel_id(0)) + ) self.assertEqual(resp.channel.data.calls_started, k_success + k_failed) self.assertEqual(resp.channel.data.calls_succeeded, k_success) self.assertEqual(resp.channel.data.calls_failed, k_failed) @@ -197,7 +212,8 @@ def test_many_channel(self): k_channels = 4 self._pairs = _generate_channel_server_pairs(k_channels) resp = self._channelz_stub.GetTopChannels( - channelz_pb2.GetTopChannelsRequest(start_channel_id=0)) + channelz_pb2.GetTopChannelsRequest(start_channel_id=0) + ) self.assertEqual(len(resp.channel), k_channels) def test_many_requests_many_channel(self): @@ -214,28 +230,32 @@ def test_many_requests_many_channel(self): # The first channel saw only successes resp = self._channelz_stub.GetChannel( - channelz_pb2.GetChannelRequest(channel_id=self._get_channel_id(0))) + channelz_pb2.GetChannelRequest(channel_id=self._get_channel_id(0)) + ) self.assertEqual(resp.channel.data.calls_started, k_success) self.assertEqual(resp.channel.data.calls_succeeded, k_success) self.assertEqual(resp.channel.data.calls_failed, 0) # The second channel saw only failures resp = self._channelz_stub.GetChannel( - channelz_pb2.GetChannelRequest(channel_id=self._get_channel_id(1))) + channelz_pb2.GetChannelRequest(channel_id=self._get_channel_id(1)) + ) self.assertEqual(resp.channel.data.calls_started, k_failed) self.assertEqual(resp.channel.data.calls_succeeded, 0) self.assertEqual(resp.channel.data.calls_failed, k_failed) # The third channel saw both successes and failures resp = self._channelz_stub.GetChannel( - channelz_pb2.GetChannelRequest(channel_id=self._get_channel_id(2))) + channelz_pb2.GetChannelRequest(channel_id=self._get_channel_id(2)) + ) self.assertEqual(resp.channel.data.calls_started, k_success + k_failed) self.assertEqual(resp.channel.data.calls_succeeded, k_success) self.assertEqual(resp.channel.data.calls_failed, k_failed) # The fourth channel saw nothing resp = self._channelz_stub.GetChannel( - channelz_pb2.GetChannelRequest(channel_id=self._get_channel_id(3))) + channelz_pb2.GetChannelRequest(channel_id=self._get_channel_id(3)) + ) self.assertEqual(resp.channel.data.calls_started, 0) self.assertEqual(resp.channel.data.calls_succeeded, 0) self.assertEqual(resp.channel.data.calls_failed, 0) @@ -253,7 +273,8 @@ def test_many_subchannels(self): self._send_failed_unary_unary(2) gtc_resp = self._channelz_stub.GetTopChannels( - 
channelz_pb2.GetTopChannelsRequest(start_channel_id=0)) + channelz_pb2.GetTopChannelsRequest(start_channel_id=0) + ) self.assertEqual(len(gtc_resp.channel), k_channels) for i in range(k_channels): # If no call performed in the channel, there shouldn't be any subchannel @@ -265,31 +286,45 @@ def test_many_subchannels(self): self.assertGreater(len(gtc_resp.channel[i].subchannel_ref), 0) gsc_resp = self._channelz_stub.GetSubchannel( channelz_pb2.GetSubchannelRequest( - subchannel_id=gtc_resp.channel[i].subchannel_ref[0]. - subchannel_id)) - self.assertEqual(gtc_resp.channel[i].data.calls_started, - gsc_resp.subchannel.data.calls_started) - self.assertEqual(gtc_resp.channel[i].data.calls_succeeded, - gsc_resp.subchannel.data.calls_succeeded) - self.assertEqual(gtc_resp.channel[i].data.calls_failed, - gsc_resp.subchannel.data.calls_failed) + subchannel_id=gtc_resp.channel[i] + .subchannel_ref[0] + .subchannel_id + ) + ) + self.assertEqual( + gtc_resp.channel[i].data.calls_started, + gsc_resp.subchannel.data.calls_started, + ) + self.assertEqual( + gtc_resp.channel[i].data.calls_succeeded, + gsc_resp.subchannel.data.calls_succeeded, + ) + self.assertEqual( + gtc_resp.channel[i].data.calls_failed, + gsc_resp.subchannel.data.calls_failed, + ) def test_server_basic(self): self._pairs = _generate_channel_server_pairs(1) resp = self._channelz_stub.GetServers( - channelz_pb2.GetServersRequest(start_server_id=0)) + channelz_pb2.GetServersRequest(start_server_id=0) + ) self.assertEqual(len(resp.server), 1) def test_get_one_server(self): self._pairs = _generate_channel_server_pairs(1) gss_resp = self._channelz_stub.GetServers( - channelz_pb2.GetServersRequest(start_server_id=0)) + channelz_pb2.GetServersRequest(start_server_id=0) + ) self.assertEqual(len(gss_resp.server), 1) gs_resp = self._channelz_stub.GetServer( channelz_pb2.GetServerRequest( - server_id=gss_resp.server[0].ref.server_id)) - self.assertEqual(gss_resp.server[0].ref.server_id, - gs_resp.server.ref.server_id) + server_id=gss_resp.server[0].ref.server_id + ) + ) + self.assertEqual( + gss_resp.server[0].ref.server_id, gs_resp.server.ref.server_id + ) def test_server_call(self): self._pairs = _generate_channel_server_pairs(1) @@ -301,10 +336,12 @@ def test_server_call(self): self._send_failed_unary_unary(0) resp = self._channelz_stub.GetServers( - channelz_pb2.GetServersRequest(start_server_id=0)) + channelz_pb2.GetServersRequest(start_server_id=0) + ) self.assertEqual(len(resp.server), 1) - self.assertEqual(resp.server[0].data.calls_started, - k_success + k_failed) + self.assertEqual( + resp.server[0].data.calls_started, k_success + k_failed + ) self.assertEqual(resp.server[0].data.calls_succeeded, k_success) self.assertEqual(resp.server[0].data.calls_failed, k_failed) @@ -321,7 +358,8 @@ def test_many_subchannels_and_sockets(self): self._send_failed_unary_unary(2) gtc_resp = self._channelz_stub.GetTopChannels( - channelz_pb2.GetTopChannelsRequest(start_channel_id=0)) + channelz_pb2.GetTopChannelsRequest(start_channel_id=0) + ) self.assertEqual(len(gtc_resp.channel), k_channels) for i in range(k_channels): # If no call performed in the channel, there shouldn't be any subchannel @@ -333,32 +371,47 @@ def test_many_subchannels_and_sockets(self): self.assertGreater(len(gtc_resp.channel[i].subchannel_ref), 0) gsc_resp = self._channelz_stub.GetSubchannel( channelz_pb2.GetSubchannelRequest( - subchannel_id=gtc_resp.channel[i].subchannel_ref[0]. 
- subchannel_id)) + subchannel_id=gtc_resp.channel[i] + .subchannel_ref[0] + .subchannel_id + ) + ) self.assertEqual(len(gsc_resp.subchannel.socket_ref), 1) gs_resp = self._channelz_stub.GetSocket( channelz_pb2.GetSocketRequest( - socket_id=gsc_resp.subchannel.socket_ref[0].socket_id)) - self.assertEqual(gsc_resp.subchannel.data.calls_started, - gs_resp.socket.data.streams_started) - self.assertEqual(gsc_resp.subchannel.data.calls_started, - gs_resp.socket.data.streams_succeeded) + socket_id=gsc_resp.subchannel.socket_ref[0].socket_id + ) + ) + self.assertEqual( + gsc_resp.subchannel.data.calls_started, + gs_resp.socket.data.streams_started, + ) + self.assertEqual( + gsc_resp.subchannel.data.calls_started, + gs_resp.socket.data.streams_succeeded, + ) # Calls started == messages sent, only valid for unary calls - self.assertEqual(gsc_resp.subchannel.data.calls_started, - gs_resp.socket.data.messages_sent) + self.assertEqual( + gsc_resp.subchannel.data.calls_started, + gs_resp.socket.data.messages_sent, + ) # Only receive responses when the RPC was successful - self.assertEqual(gsc_resp.subchannel.data.calls_succeeded, - gs_resp.socket.data.messages_received) + self.assertEqual( + gsc_resp.subchannel.data.calls_succeeded, + gs_resp.socket.data.messages_received, + ) if gs_resp.socket.remote.HasField("tcpip_address"): address = gs_resp.socket.remote.tcpip_address.ip_address self.assertTrue( - len(address) == 4 or len(address) == 16, address) + len(address) == 4 or len(address) == 16, address + ) if gs_resp.socket.local.HasField("tcpip_address"): address = gs_resp.socket.local.tcpip_address.ip_address self.assertTrue( - len(address) == 4 or len(address) == 16, address) + len(address) == 4 or len(address) == 16, address + ) def test_streaming_rpc(self): self._pairs = _generate_channel_server_pairs(1) @@ -367,7 +420,8 @@ def test_streaming_rpc(self): self._send_successful_stream_stream(0) gc_resp = self._channelz_stub.GetChannel( - channelz_pb2.GetChannelRequest(channel_id=self._get_channel_id(0))) + channelz_pb2.GetChannelRequest(channel_id=self._get_channel_id(0)) + ) self.assertEqual(gc_resp.channel.data.calls_started, 1) self.assertEqual(gc_resp.channel.data.calls_succeeded, 1) self.assertEqual(gc_resp.channel.data.calls_failed, 0) @@ -377,9 +431,16 @@ def test_streaming_rpc(self): while True: gsc_resp = self._channelz_stub.GetSubchannel( channelz_pb2.GetSubchannelRequest( - subchannel_id=gc_resp.channel.subchannel_ref[0]. 
- subchannel_id)) - if gsc_resp.subchannel.data.calls_started == gsc_resp.subchannel.data.calls_succeeded + gsc_resp.subchannel.data.calls_failed: + subchannel_id=gc_resp.channel.subchannel_ref[ + 0 + ].subchannel_id + ) + ) + if ( + gsc_resp.subchannel.data.calls_started + == gsc_resp.subchannel.data.calls_succeeded + + gsc_resp.subchannel.data.calls_failed + ): break self.assertEqual(gsc_resp.subchannel.data.calls_started, 1) self.assertEqual(gsc_resp.subchannel.data.calls_failed, 0) @@ -390,16 +451,24 @@ def test_streaming_rpc(self): while True: gs_resp = self._channelz_stub.GetSocket( channelz_pb2.GetSocketRequest( - socket_id=gsc_resp.subchannel.socket_ref[0].socket_id)) - if gs_resp.socket.data.streams_started == gs_resp.socket.data.streams_succeeded + gs_resp.socket.data.streams_failed: + socket_id=gsc_resp.subchannel.socket_ref[0].socket_id + ) + ) + if ( + gs_resp.socket.data.streams_started + == gs_resp.socket.data.streams_succeeded + + gs_resp.socket.data.streams_failed + ): break self.assertEqual(gs_resp.socket.data.streams_started, 1) self.assertEqual(gs_resp.socket.data.streams_succeeded, 1) self.assertEqual(gs_resp.socket.data.streams_failed, 0) - self.assertEqual(gs_resp.socket.data.messages_sent, - test_constants.STREAM_LENGTH) - self.assertEqual(gs_resp.socket.data.messages_received, - test_constants.STREAM_LENGTH) + self.assertEqual( + gs_resp.socket.data.messages_sent, test_constants.STREAM_LENGTH + ) + self.assertEqual( + gs_resp.socket.data.messages_received, test_constants.STREAM_LENGTH + ) def test_server_sockets(self): self._pairs = _generate_channel_server_pairs(1) @@ -407,7 +476,8 @@ def test_server_sockets(self): self._send_failed_unary_unary(0) gs_resp = self._channelz_stub.GetServers( - channelz_pb2.GetServersRequest(start_server_id=0)) + channelz_pb2.GetServersRequest(start_server_id=0) + ) self.assertEqual(len(gs_resp.server), 1) self.assertEqual(gs_resp.server[0].data.calls_started, 2) self.assertEqual(gs_resp.server[0].data.calls_succeeded, 1) @@ -415,7 +485,9 @@ def test_server_sockets(self): gss_resp = self._channelz_stub.GetServerSockets( channelz_pb2.GetServerSocketsRequest( - server_id=gs_resp.server[0].ref.server_id, start_socket_id=0)) + server_id=gs_resp.server[0].ref.server_id, start_socket_id=0 + ) + ) # If the RPC call failed, it will raise a grpc.RpcError # So, if there is no exception raised, considered pass @@ -423,13 +495,16 @@ def test_server_listen_sockets(self): self._pairs = _generate_channel_server_pairs(1) gss_resp = self._channelz_stub.GetServers( - channelz_pb2.GetServersRequest(start_server_id=0)) + channelz_pb2.GetServersRequest(start_server_id=0) + ) self.assertEqual(len(gss_resp.server), 1) self.assertEqual(len(gss_resp.server[0].listen_socket), 1) gs_resp = self._channelz_stub.GetSocket( channelz_pb2.GetSocketRequest( - socket_id=gss_resp.server[0].listen_socket[0].socket_id)) + socket_id=gss_resp.server[0].listen_socket[0].socket_id + ) + ) # If the RPC call failed, it will raise a grpc.RpcError # So, if there is no exception raised, considered pass @@ -437,38 +512,42 @@ def test_server_listen_sockets(self): def test_invalid_query_get_server(self): try: self._channelz_stub.GetServer( - channelz_pb2.GetServerRequest(server_id=10000)) + channelz_pb2.GetServerRequest(server_id=10000) + ) except BaseException as e: - self.assertIn('StatusCode.NOT_FOUND', str(e)) + self.assertIn("StatusCode.NOT_FOUND", str(e)) else: - self.fail('Invalid query not detected') + self.fail("Invalid query not detected") def 
test_invalid_query_get_channel(self): try: self._channelz_stub.GetChannel( - channelz_pb2.GetChannelRequest(channel_id=10000)) + channelz_pb2.GetChannelRequest(channel_id=10000) + ) except BaseException as e: - self.assertIn('StatusCode.NOT_FOUND', str(e)) + self.assertIn("StatusCode.NOT_FOUND", str(e)) else: - self.fail('Invalid query not detected') + self.fail("Invalid query not detected") def test_invalid_query_get_subchannel(self): try: self._channelz_stub.GetSubchannel( - channelz_pb2.GetSubchannelRequest(subchannel_id=10000)) + channelz_pb2.GetSubchannelRequest(subchannel_id=10000) + ) except BaseException as e: - self.assertIn('StatusCode.NOT_FOUND', str(e)) + self.assertIn("StatusCode.NOT_FOUND", str(e)) else: - self.fail('Invalid query not detected') + self.fail("Invalid query not detected") def test_invalid_query_get_socket(self): try: self._channelz_stub.GetSocket( - channelz_pb2.GetSocketRequest(socket_id=10000)) + channelz_pb2.GetSocketRequest(socket_id=10000) + ) except BaseException as e: - self.assertIn('StatusCode.NOT_FOUND', str(e)) + self.assertIn("StatusCode.NOT_FOUND", str(e)) else: - self.fail('Invalid query not detected') + self.fail("Invalid query not detected") def test_invalid_query_get_server_sockets(self): try: @@ -476,12 +555,13 @@ def test_invalid_query_get_server_sockets(self): channelz_pb2.GetServerSocketsRequest( server_id=10000, start_socket_id=0, - )) + ) + ) except BaseException as e: - self.assertIn('StatusCode.NOT_FOUND', str(e)) + self.assertIn("StatusCode.NOT_FOUND", str(e)) else: - self.fail('Invalid query not detected') + self.fail("Invalid query not detected") -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/csds/test_csds.py b/src/python/grpcio_tests/tests/csds/test_csds.py index 0a5fee8f31736..8823444d59be7 100644 --- a/src/python/grpcio_tests/tests/csds/test_csds.py +++ b/src/python/grpcio_tests/tests/csds/test_csds.py @@ -27,7 +27,7 @@ import grpc import grpc_csds -_DUMMY_XDS_ADDRESS = 'xds:///foo.bar' +_DUMMY_XDS_ADDRESS = "xds:///foo.bar" _DUMMY_BOOTSTRAP_FILE = """ { \"xds_servers\": [ @@ -57,25 +57,26 @@ """ -@unittest.skipIf(sys.version_info[0] < 3, - 'ProtoBuf descriptor has moved on from Python2') +@unittest.skipIf( + sys.version_info[0] < 3, "ProtoBuf descriptor has moved on from Python2" +) class TestCsds(unittest.TestCase): - def setUp(self): - os.environ['GRPC_XDS_BOOTSTRAP_CONFIG'] = _DUMMY_BOOTSTRAP_FILE + os.environ["GRPC_XDS_BOOTSTRAP_CONFIG"] = _DUMMY_BOOTSTRAP_FILE self._server = grpc.server(ThreadPoolExecutor()) - port = self._server.add_insecure_port('localhost:0') + port = self._server.add_insecure_port("localhost:0") grpc_csds.add_csds_servicer(self._server) self._server.start() - self._channel = grpc.insecure_channel('localhost:%s' % port) + self._channel = grpc.insecure_channel("localhost:%s" % port) self._stub = csds_pb2_grpc.ClientStatusDiscoveryServiceStub( - self._channel) + self._channel + ) def tearDown(self): self._channel.close() self._server.stop(0) - os.environ.pop('GRPC_XDS_BOOTSTRAP_CONFIG', None) + os.environ.pop("GRPC_XDS_BOOTSTRAP_CONFIG", None) def get_xds_config_dump(self): return self._stub.FetchClientStatus(csds_pb2.ClientStatusRequest()) @@ -83,17 +84,18 @@ def get_xds_config_dump(self): def test_has_node(self): resp = self.get_xds_config_dump() self.assertEqual(1, len(resp.config)) - self.assertEqual('python_test_csds', resp.config[0].node.id) - self.assertEqual('test', resp.config[0].node.cluster) + 
self.assertEqual("python_test_csds", resp.config[0].node.id) + self.assertEqual("test", resp.config[0].node.cluster) def test_no_lds_found(self): dummy_channel = grpc.insecure_channel(_DUMMY_XDS_ADDRESS) # Force the XdsClient to initialize and request a resource with self.assertRaises(grpc.RpcError) as rpc_error: - dummy_channel.unary_unary('')(b'', wait_for_ready=False, timeout=1) - self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, - rpc_error.exception.code()) + dummy_channel.unary_unary("")(b"", wait_for_ready=False, timeout=1) + self.assertEqual( + grpc.StatusCode.DEADLINE_EXCEEDED, rpc_error.exception.code() + ) # The resource request will fail with DOES_NOT_EXIST (after 15s) while True: @@ -104,14 +106,16 @@ def test_no_lds_found(self): for xds_config in config["config"][0].get("xdsConfig", []): if "listenerConfig" in xds_config: listener = xds_config["listenerConfig"][ - "dynamicListeners"][0] - if listener['clientStatus'] == 'REQUESTED': + "dynamicListeners" + ][0] + if listener["clientStatus"] == "REQUESTED": ok = True break for generic_xds_config in config["config"][0].get( - "genericXdsConfigs", []): + "genericXdsConfigs", [] + ): if "Listener" in generic_xds_config["typeUrl"]: - if generic_xds_config['clientStatus'] == 'REQUESTED': + if generic_xds_config["clientStatus"] == "REQUESTED": ok = True break except KeyError as e: @@ -123,15 +127,16 @@ def test_no_lds_found(self): dummy_channel.close() -@unittest.skipIf(sys.version_info[0] < 3, - 'ProtoBuf descriptor has moved on from Python2') +@unittest.skipIf( + sys.version_info[0] < 3, "ProtoBuf descriptor has moved on from Python2" +) class TestCsdsStream(TestCsds): - def get_xds_config_dump(self): - if not hasattr(self, 'request_queue'): + if not hasattr(self, "request_queue"): request_queue = queue.Queue() response_iterator = self._stub.StreamClientStatus( - iter(request_queue.get, None)) + iter(request_queue.get, None) + ) request_queue.put(csds_pb2.ClientStatusRequest()) return next(response_iterator) diff --git a/src/python/grpcio_tests/tests/fork/_fork_interop_test.py b/src/python/grpcio_tests/tests/fork/_fork_interop_test.py index 7512e026780c1..90c61cb7ff714 100644 --- a/src/python/grpcio_tests/tests/fork/_fork_interop_test.py +++ b/src/python/grpcio_tests/tests/fork/_fork_interop_test.py @@ -30,8 +30,11 @@ def _dump_streams(name, streams): assert len(streams) == 2 for stream_name, stream in zip(("STDOUT", "STDERR"), streams): stream.seek(0) - sys.stderr.write("{} {}:\n{}\n".format(name, stream_name, - stream.read().decode("ascii"))) + sys.stderr.write( + "{} {}:\n{}\n".format( + name, stream_name, stream.read().decode("ascii") + ) + ) stream.close() sys.stderr.flush() @@ -65,12 +68,13 @@ def _dump_streams(name, streams): @unittest.skipUnless( sys.platform.startswith("linux"), - "not supported on windows, and fork+exec networking blocked on mac") + "not supported on windows, and fork+exec networking blocked on mac", +) @unittest.skipUnless( os.getenv("GRPC_ENABLE_FORK_SUPPORT") is not None, - "Core must be built with fork support to run this test.") + "Core must be built with fork support to run this test.", +) class ForkInteropTest(unittest.TestCase): - def setUp(self): self._port = None start_server_script = """if True: @@ -94,11 +98,13 @@ def setUp(self): """ self._streams = tuple(tempfile.TemporaryFile() for _ in range(2)) self._server_process = subprocess.Popen( - [sys.executable, '-c', start_server_script], + [sys.executable, "-c", start_server_script], stdout=self._streams[0], - stderr=self._streams[1]) - timer = 
threading.Timer(_SUBPROCESS_TIMEOUT_S, - self._server_process.kill) + stderr=self._streams[1], + ) + timer = threading.Timer( + _SUBPROCESS_TIMEOUT_S, self._server_process.kill + ) interval_secs = 2.0 cumulative_secs = 0.0 try: @@ -115,15 +121,17 @@ def setUp(self): if self._port is None: # Timeout self._streams[0].seek(0) - sys.stderr.write("Server STDOUT:\n{}\n".format( - self._streams[0].read())) + sys.stderr.write( + "Server STDOUT:\n{}\n".format(self._streams[0].read()) + ) self._streams[1].seek(0) - sys.stderr.write("Server STDERR:\n{}\n".format( - self._streams[1].read())) + sys.stderr.write( + "Server STDERR:\n{}\n".format(self._streams[1].read()) + ) sys.stderr.flush() raise Exception("Failed to get port from server.") except ValueError: - raise Exception('Failed to get port from server') + raise Exception("Failed to get port from server") finally: timer.cancel() @@ -150,19 +158,23 @@ def testInProgressBidiContinueCall(self): def testInProgressBidiSameChannelAsyncCall(self): self._verifyTestCase( - methods.TestCase.IN_PROGRESS_BIDI_SAME_CHANNEL_ASYNC_CALL) + methods.TestCase.IN_PROGRESS_BIDI_SAME_CHANNEL_ASYNC_CALL + ) def testInProgressBidiSameChannelBlockingCall(self): self._verifyTestCase( - methods.TestCase.IN_PROGRESS_BIDI_SAME_CHANNEL_BLOCKING_CALL) + methods.TestCase.IN_PROGRESS_BIDI_SAME_CHANNEL_BLOCKING_CALL + ) def testInProgressBidiNewChannelAsyncCall(self): self._verifyTestCase( - methods.TestCase.IN_PROGRESS_BIDI_NEW_CHANNEL_ASYNC_CALL) + methods.TestCase.IN_PROGRESS_BIDI_NEW_CHANNEL_ASYNC_CALL + ) def testInProgressBidiNewChannelBlockingCall(self): self._verifyTestCase( - methods.TestCase.IN_PROGRESS_BIDI_NEW_CHANNEL_BLOCKING_CALL) + methods.TestCase.IN_PROGRESS_BIDI_NEW_CHANNEL_BLOCKING_CALL + ) def tearDown(self): self._server_process.kill() @@ -201,9 +213,9 @@ def _print_backtraces(self, pid): def _verifyTestCase(self, test_case): script = _CLIENT_FORK_SCRIPT_TEMPLATE % (test_case.name, self._port) streams = tuple(tempfile.TemporaryFile() for _ in range(2)) - process = subprocess.Popen([sys.executable, '-c', script], - stdout=streams[0], - stderr=streams[1]) + process = subprocess.Popen( + [sys.executable, "-c", script], stdout=streams[0], stderr=streams[1] + ) try: process.wait(timeout=_SUBPROCESS_TIMEOUT_S) self.assertEqual(0, process.returncode) @@ -216,5 +228,5 @@ def _verifyTestCase(self, test_case): _dump_streams("Server", self._streams) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/fork/client.py b/src/python/grpcio_tests/tests/fork/client.py index 852e6da4d698f..6bd7ef5dd931e 100644 --- a/src/python/grpcio_tests/tests/fork/client.py +++ b/src/python/grpcio_tests/tests/fork/client.py @@ -21,31 +21,38 @@ def _args(): - def parse_bool(value): - if value == 'true': + if value == "true": return True - if value == 'false': + if value == "false": return False - raise argparse.ArgumentTypeError('Only true/false allowed') + raise argparse.ArgumentTypeError("Only true/false allowed") parser = argparse.ArgumentParser() - parser.add_argument('--server_host', - default="localhost", - type=str, - help='the host to which to connect') - parser.add_argument('--server_port', - type=int, - required=True, - help='the port to which to connect') - parser.add_argument('--test_case', - default='large_unary', - type=str, - help='the test case to execute') - parser.add_argument('--use_tls', - default=False, - type=parse_bool, - help='require a secure connection') + parser.add_argument( + 
"--server_host", + default="localhost", + type=str, + help="the host to which to connect", + ) + parser.add_argument( + "--server_port", + type=int, + required=True, + help="the port to which to connect", + ) + parser.add_argument( + "--test_case", + default="large_unary", + type=str, + help="the test case to execute", + ) + parser.add_argument( + "--use_tls", + default=False, + type=parse_bool, + help="require a secure connection", + ) return parser.parse_args() @@ -60,13 +67,13 @@ def _test_case_from_arg(test_case_arg): def test_fork(): logging.basicConfig(level=logging.INFO) args = vars(_args()) - if args['test_case'] == "all": + if args["test_case"] == "all": for test_case in methods.TestCase: test_case.run_test(args) else: - test_case = _test_case_from_arg(args['test_case']) + test_case = _test_case_from_arg(args["test_case"]) test_case.run_test(args) -if __name__ == '__main__': +if __name__ == "__main__": test_fork() diff --git a/src/python/grpcio_tests/tests/fork/methods.py b/src/python/grpcio_tests/tests/fork/methods.py index 2106e3b8e08f3..d69815183489c 100644 --- a/src/python/grpcio_tests/tests/fork/methods.py +++ b/src/python/grpcio_tests/tests/fork/methods.py @@ -39,8 +39,8 @@ def _channel(args): - target = '{}:{}'.format(args['server_host'], args['server_port']) - if args['use_tls']: + target = "{}:{}".format(args["server_host"], args["server_port"]) + if args["use_tls"]: channel_credentials = grpc.ssl_channel_credentials() channel = grpc.secure_channel(target, channel_credentials) else: @@ -50,11 +50,15 @@ def _channel(args): def _validate_payload_type_and_length(response, expected_type, expected_length): if response.payload.type is not expected_type: - raise ValueError('expected payload type %s, got %s' % - (expected_type, type(response.payload.type))) + raise ValueError( + "expected payload type %s, got %s" + % (expected_type, type(response.payload.type)) + ) elif len(response.payload.body) != expected_length: - raise ValueError('expected payload body size %d, got %d' % - (expected_length, len(response.payload.body))) + raise ValueError( + "expected payload body size %d, got %d" + % (expected_length, len(response.payload.body)) + ) def _async_unary(stub): @@ -62,7 +66,8 @@ def _async_unary(stub): request = messages_pb2.SimpleRequest( response_type=messages_pb2.COMPRESSABLE, response_size=size, - payload=messages_pb2.Payload(body=b'\x00' * 271828)) + payload=messages_pb2.Payload(body=b"\x00" * 271828), + ) response_future = stub.UnaryCall.future(request, timeout=_RPC_TIMEOUT_S) response = response_future.result() @@ -74,13 +79,13 @@ def _blocking_unary(stub): request = messages_pb2.SimpleRequest( response_type=messages_pb2.COMPRESSABLE, response_size=size, - payload=messages_pb2.Payload(body=b'\x00' * 271828)) + payload=messages_pb2.Payload(body=b"\x00" * 271828), + ) response = stub.UnaryCall(request, timeout=_RPC_TIMEOUT_S) _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, size) class _Pipe(object): - def __init__(self): self._condition = threading.Condition() self._values = [] @@ -119,7 +124,6 @@ def __exit__(self, type, value, traceback): class _ChildProcess(object): - def __init__(self, task, args=None): if args is None: args = () @@ -134,13 +138,14 @@ def __init__(self, task, args=None): def _child_main(self): import faulthandler + faulthandler.enable(all_threads=True) try: self._task(*self._args) except grpc.RpcError as rpc_error: traceback.print_exc() - self._exceptions.put('RpcError: %s' % rpc_error) + self._exceptions.put("RpcError: %s" % 
rpc_error) except Exception as e: # pylint: disable=broad-except traceback.print_exc() self._exceptions.put(e) @@ -217,9 +222,11 @@ def _print_backtraces(self): finally: for stream_name, stream in zip(("STDOUT", "STDERR"), streams): stream.seek(0) - sys.stderr.write("gdb {}:\n{}\n".format( - stream_name, - stream.read().decode("ascii"))) + sys.stderr.write( + "gdb {}:\n{}\n".format( + stream_name, stream.read().decode("ascii") + ) + ) stream.close() sys.stderr.flush() @@ -228,24 +235,26 @@ def finish(self): sys.stderr.write("Exit code: {}\n".format(self._rc)) if not terminated: self._print_backtraces() - raise RuntimeError('Child process did not terminate') + raise RuntimeError("Child process did not terminate") if self._rc != 0: - raise ValueError('Child process failed with exitcode %d' % self._rc) + raise ValueError("Child process failed with exitcode %d" % self._rc) try: exception = self._exceptions.get(block=False) - raise ValueError('Child process failed: "%s": "%s"' % - (repr(exception), exception)) + raise ValueError( + 'Child process failed: "%s": "%s"' + % (repr(exception), exception) + ) except queue.Empty: pass def _async_unary_same_channel(channel): - def child_target(): try: _async_unary(stub) raise Exception( - 'Child should not be able to re-use channel after fork') + "Child should not be able to re-use channel after fork" + ) except ValueError as expected_value_error: pass @@ -258,7 +267,6 @@ def child_target(): def _async_unary_new_channel(channel, args): - def child_target(): with _channel(args) as child_channel: child_stub = test_pb2_grpc.TestServiceStub(child_channel) @@ -274,12 +282,12 @@ def child_target(): def _blocking_unary_same_channel(channel): - def child_target(): try: _blocking_unary(stub) raise Exception( - 'Child should not be able to re-use channel after fork') + "Child should not be able to re-use channel after fork" + ) except ValueError as expected_value_error: pass @@ -291,7 +299,6 @@ def child_target(): def _blocking_unary_new_channel(channel, args): - def child_target(): with _channel(args) as child_channel: child_stub = test_pb2_grpc.TestServiceStub(child_channel) @@ -307,7 +314,6 @@ def child_target(): # Verify that the fork channel registry can handle already closed channels def _close_channel_before_fork(channel, args): - def child_target(): new_channel.close() with _channel(args) as child_channel: @@ -327,12 +333,10 @@ def child_target(): def _connectivity_watch(channel, args): - parent_states = [] parent_channel_ready_event = threading.Event() def child_target(): - child_channel_ready_event = threading.Event() def child_connectivity_callback(state): @@ -344,11 +348,12 @@ def child_connectivity_callback(state): child_channel.subscribe(child_connectivity_callback) _async_unary(child_stub) if not child_channel_ready_event.wait(timeout=_RPC_TIMEOUT_S): - raise ValueError('Channel did not move to READY') + raise ValueError("Channel did not move to READY") if len(parent_states) > 1: raise ValueError( - 'Received connectivity updates on parent callback', - parent_states) + "Received connectivity updates on parent callback", + parent_states, + ) child_channel.unsubscribe(child_connectivity_callback) def parent_connectivity_callback(state): @@ -362,13 +367,14 @@ def parent_connectivity_callback(state): child_process.start() _async_unary(stub) if not parent_channel_ready_event.wait(timeout=_RPC_TIMEOUT_S): - raise ValueError('Channel did not move to READY') + raise ValueError("Channel did not move to READY") 
channel.unsubscribe(parent_connectivity_callback) child_process.finish() def _ping_pong_with_child_processes_after_first_response( - channel, args, child_target, run_after_close=True): + channel, args, child_target, run_after_close=True +): request_response_sizes = ( 31415, 9, @@ -386,31 +392,38 @@ def _ping_pong_with_child_processes_after_first_response( parent_bidi_call = stub.FullDuplexCall(pipe) child_processes = [] first_message_received = False - for response_size, payload_size in zip(request_response_sizes, - request_payload_sizes): + for response_size, payload_size in zip( + request_response_sizes, request_payload_sizes + ): request = messages_pb2.StreamingOutputCallRequest( response_type=messages_pb2.COMPRESSABLE, - response_parameters=(messages_pb2.ResponseParameters( - size=response_size),), - payload=messages_pb2.Payload(body=b'\x00' * payload_size)) + response_parameters=( + messages_pb2.ResponseParameters(size=response_size), + ), + payload=messages_pb2.Payload(body=b"\x00" * payload_size), + ) pipe.add(request) if first_message_received: - child_process = _ChildProcess(child_target, - (parent_bidi_call, channel, args)) + child_process = _ChildProcess( + child_target, (parent_bidi_call, channel, args) + ) child_process.start() child_processes.append(child_process) response = next(parent_bidi_call) first_message_received = True - child_process = _ChildProcess(child_target, - (parent_bidi_call, channel, args)) + child_process = _ChildProcess( + child_target, (parent_bidi_call, channel, args) + ) child_process.start() child_processes.append(child_process) - _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, - response_size) + _validate_payload_type_and_length( + response, messages_pb2.COMPRESSABLE, response_size + ) pipe.close() if run_after_close: - child_process = _ChildProcess(child_target, - (parent_bidi_call, channel, args)) + child_process = _ChildProcess( + child_target, (parent_bidi_call, channel, args) + ) child_process.start() child_processes.append(child_process) for child_process in child_processes: @@ -418,99 +431,109 @@ def _ping_pong_with_child_processes_after_first_response( def _in_progress_bidi_continue_call(channel): - def child_target(parent_bidi_call, parent_channel, args): stub = test_pb2_grpc.TestServiceStub(parent_channel) try: _async_unary(stub) raise Exception( - 'Child should not be able to re-use channel after fork') + "Child should not be able to re-use channel after fork" + ) except ValueError as expected_value_error: pass inherited_code = parent_bidi_call.code() inherited_details = parent_bidi_call.details() if inherited_code != grpc.StatusCode.CANCELLED: - raise ValueError('Expected inherited code CANCELLED, got %s' % - inherited_code) - if inherited_details != 'Channel closed due to fork': raise ValueError( - 'Expected inherited details Channel closed due to fork, got %s' - % inherited_details) + "Expected inherited code CANCELLED, got %s" % inherited_code + ) + if inherited_details != "Channel closed due to fork": + raise ValueError( + "Expected inherited details Channel closed due to fork, got %s" + % inherited_details + ) # Don't run child_target after closing the parent call, as the call may have # received a status from the server before fork occurs. 
- _ping_pong_with_child_processes_after_first_response(channel, - None, - child_target, - run_after_close=False) + _ping_pong_with_child_processes_after_first_response( + channel, None, child_target, run_after_close=False + ) def _in_progress_bidi_same_channel_async_call(channel): - def child_target(parent_bidi_call, parent_channel, args): stub = test_pb2_grpc.TestServiceStub(parent_channel) try: _async_unary(stub) raise Exception( - 'Child should not be able to re-use channel after fork') + "Child should not be able to re-use channel after fork" + ) except ValueError as expected_value_error: pass _ping_pong_with_child_processes_after_first_response( - channel, None, child_target) + channel, None, child_target + ) def _in_progress_bidi_same_channel_blocking_call(channel): - def child_target(parent_bidi_call, parent_channel, args): stub = test_pb2_grpc.TestServiceStub(parent_channel) try: _blocking_unary(stub) raise Exception( - 'Child should not be able to re-use channel after fork') + "Child should not be able to re-use channel after fork" + ) except ValueError as expected_value_error: pass _ping_pong_with_child_processes_after_first_response( - channel, None, child_target) + channel, None, child_target + ) def _in_progress_bidi_new_channel_async_call(channel, args): - def child_target(parent_bidi_call, parent_channel, args): with _channel(args) as channel: stub = test_pb2_grpc.TestServiceStub(channel) _async_unary(stub) _ping_pong_with_child_processes_after_first_response( - channel, args, child_target) + channel, args, child_target + ) def _in_progress_bidi_new_channel_blocking_call(channel, args): - def child_target(parent_bidi_call, parent_channel, args): with _channel(args) as channel: stub = test_pb2_grpc.TestServiceStub(channel) _blocking_unary(stub) _ping_pong_with_child_processes_after_first_response( - channel, args, child_target) + channel, args, child_target + ) @enum.unique class TestCase(enum.Enum): - - CONNECTIVITY_WATCH = 'connectivity_watch' - CLOSE_CHANNEL_BEFORE_FORK = 'close_channel_before_fork' - ASYNC_UNARY_SAME_CHANNEL = 'async_unary_same_channel' - ASYNC_UNARY_NEW_CHANNEL = 'async_unary_new_channel' - BLOCKING_UNARY_SAME_CHANNEL = 'blocking_unary_same_channel' - BLOCKING_UNARY_NEW_CHANNEL = 'blocking_unary_new_channel' - IN_PROGRESS_BIDI_CONTINUE_CALL = 'in_progress_bidi_continue_call' - IN_PROGRESS_BIDI_SAME_CHANNEL_ASYNC_CALL = 'in_progress_bidi_same_channel_async_call' - IN_PROGRESS_BIDI_SAME_CHANNEL_BLOCKING_CALL = 'in_progress_bidi_same_channel_blocking_call' - IN_PROGRESS_BIDI_NEW_CHANNEL_ASYNC_CALL = 'in_progress_bidi_new_channel_async_call' - IN_PROGRESS_BIDI_NEW_CHANNEL_BLOCKING_CALL = 'in_progress_bidi_new_channel_blocking_call' + CONNECTIVITY_WATCH = "connectivity_watch" + CLOSE_CHANNEL_BEFORE_FORK = "close_channel_before_fork" + ASYNC_UNARY_SAME_CHANNEL = "async_unary_same_channel" + ASYNC_UNARY_NEW_CHANNEL = "async_unary_new_channel" + BLOCKING_UNARY_SAME_CHANNEL = "blocking_unary_same_channel" + BLOCKING_UNARY_NEW_CHANNEL = "blocking_unary_new_channel" + IN_PROGRESS_BIDI_CONTINUE_CALL = "in_progress_bidi_continue_call" + IN_PROGRESS_BIDI_SAME_CHANNEL_ASYNC_CALL = ( + "in_progress_bidi_same_channel_async_call" + ) + IN_PROGRESS_BIDI_SAME_CHANNEL_BLOCKING_CALL = ( + "in_progress_bidi_same_channel_blocking_call" + ) + IN_PROGRESS_BIDI_NEW_CHANNEL_ASYNC_CALL = ( + "in_progress_bidi_new_channel_async_call" + ) + IN_PROGRESS_BIDI_NEW_CHANNEL_BLOCKING_CALL = ( + "in_progress_bidi_new_channel_blocking_call" + ) def run_test(self, args): _LOGGER.info("Running 
%s", self) @@ -538,8 +561,9 @@ def run_test(self, args): elif self is TestCase.IN_PROGRESS_BIDI_NEW_CHANNEL_BLOCKING_CALL: _in_progress_bidi_new_channel_blocking_call(channel, args) else: - raise NotImplementedError('Test case "%s" not implemented!' % - self.name) + raise NotImplementedError( + 'Test case "%s" not implemented!' % self.name + ) channel.close() diff --git a/src/python/grpcio_tests/tests/health_check/_health_servicer_test.py b/src/python/grpcio_tests/tests/health_check/_health_servicer_test.py index 980e0b75a0372..31580892e8163 100644 --- a/src/python/grpcio_tests/tests/health_check/_health_servicer_test.py +++ b/src/python/grpcio_tests/tests/health_check/_health_servicer_test.py @@ -29,10 +29,10 @@ from tests.unit import thread_pool from tests.unit.framework.common import test_constants -_SERVING_SERVICE = 'grpc.test.TestServiceServing' -_UNKNOWN_SERVICE = 'grpc.test.TestServiceUnknown' -_NOT_SERVING_SERVICE = 'grpc.test.TestServiceNotServing' -_WATCH_SERVICE = 'grpc.test.WatchService' +_SERVING_SERVICE = "grpc.test.TestServiceServing" +_UNKNOWN_SERVICE = "grpc.test.TestServiceUnknown" +_NOT_SERVING_SERVICE = "grpc.test.TestServiceNotServing" +_WATCH_SERVICE = "grpc.test.WatchService" def _consume_responses(response_iterator, response_queue): @@ -41,29 +41,33 @@ def _consume_responses(response_iterator, response_queue): class BaseWatchTests(object): - - @unittest.skipIf(sys.version_info[0] < 3, - 'ProtoBuf descriptor has moved on from Python2') + @unittest.skipIf( + sys.version_info[0] < 3, "ProtoBuf descriptor has moved on from Python2" + ) class WatchTests(unittest.TestCase): - def start_server(self, non_blocking=False, thread_pool=None): self._thread_pool = thread_pool self._servicer = health.HealthServicer( experimental_non_blocking=non_blocking, - experimental_thread_pool=thread_pool) - self._servicer.set(_SERVING_SERVICE, - health_pb2.HealthCheckResponse.SERVING) - self._servicer.set(_UNKNOWN_SERVICE, - health_pb2.HealthCheckResponse.UNKNOWN) - self._servicer.set(_NOT_SERVING_SERVICE, - health_pb2.HealthCheckResponse.NOT_SERVING) + experimental_thread_pool=thread_pool, + ) + self._servicer.set( + _SERVING_SERVICE, health_pb2.HealthCheckResponse.SERVING + ) + self._servicer.set( + _UNKNOWN_SERVICE, health_pb2.HealthCheckResponse.UNKNOWN + ) + self._servicer.set( + _NOT_SERVING_SERVICE, health_pb2.HealthCheckResponse.NOT_SERVING + ) self._server = test_common.test_server() - port = self._server.add_insecure_port('[::]:0') + port = self._server.add_insecure_port("[::]:0") health_pb2_grpc.add_HealthServicer_to_server( - self._servicer, self._server) + self._servicer, self._server + ) self._server.start() - self._channel = grpc.insecure_channel('localhost:%d' % port) + self._channel = grpc.insecure_channel("localhost:%d" % port) self._stub = health_pb2_grpc.HealthStub(self._channel) def tearDown(self): @@ -71,16 +75,18 @@ def tearDown(self): self._channel.close() def test_watch_empty_service(self): - request = health_pb2.HealthCheckRequest(service='') + request = health_pb2.HealthCheckRequest(service="") response_queue = queue.Queue() rendezvous = self._stub.Watch(request) - thread = threading.Thread(target=_consume_responses, - args=(rendezvous, response_queue)) + thread = threading.Thread( + target=_consume_responses, args=(rendezvous, response_queue) + ) thread.start() response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT) - self.assertEqual(health_pb2.HealthCheckResponse.SERVING, - response.status) + self.assertEqual( + 
health_pb2.HealthCheckResponse.SERVING, response.status + ) rendezvous.cancel() thread.join() @@ -93,25 +99,31 @@ def test_watch_new_service(self): request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE) response_queue = queue.Queue() rendezvous = self._stub.Watch(request) - thread = threading.Thread(target=_consume_responses, - args=(rendezvous, response_queue)) + thread = threading.Thread( + target=_consume_responses, args=(rendezvous, response_queue) + ) thread.start() response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT) - self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, - response.status) + self.assertEqual( + health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, response.status + ) - self._servicer.set(_WATCH_SERVICE, - health_pb2.HealthCheckResponse.SERVING) + self._servicer.set( + _WATCH_SERVICE, health_pb2.HealthCheckResponse.SERVING + ) response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT) - self.assertEqual(health_pb2.HealthCheckResponse.SERVING, - response.status) + self.assertEqual( + health_pb2.HealthCheckResponse.SERVING, response.status + ) - self._servicer.set(_WATCH_SERVICE, - health_pb2.HealthCheckResponse.NOT_SERVING) + self._servicer.set( + _WATCH_SERVICE, health_pb2.HealthCheckResponse.NOT_SERVING + ) response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT) - self.assertEqual(health_pb2.HealthCheckResponse.NOT_SERVING, - response.status) + self.assertEqual( + health_pb2.HealthCheckResponse.NOT_SERVING, response.status + ) rendezvous.cancel() thread.join() @@ -121,16 +133,19 @@ def test_watch_service_isolation(self): request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE) response_queue = queue.Queue() rendezvous = self._stub.Watch(request) - thread = threading.Thread(target=_consume_responses, - args=(rendezvous, response_queue)) + thread = threading.Thread( + target=_consume_responses, args=(rendezvous, response_queue) + ) thread.start() response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT) - self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, - response.status) + self.assertEqual( + health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, response.status + ) - self._servicer.set('some-other-service', - health_pb2.HealthCheckResponse.SERVING) + self._servicer.set( + "some-other-service", health_pb2.HealthCheckResponse.SERVING + ) with self.assertRaises(queue.Empty): response_queue.get(timeout=test_constants.SHORT_TIMEOUT) @@ -144,32 +159,43 @@ def test_two_watchers(self): response_queue2 = queue.Queue() rendezvous1 = self._stub.Watch(request) rendezvous2 = self._stub.Watch(request) - thread1 = threading.Thread(target=_consume_responses, - args=(rendezvous1, response_queue1)) - thread2 = threading.Thread(target=_consume_responses, - args=(rendezvous2, response_queue2)) + thread1 = threading.Thread( + target=_consume_responses, args=(rendezvous1, response_queue1) + ) + thread2 = threading.Thread( + target=_consume_responses, args=(rendezvous2, response_queue2) + ) thread1.start() thread2.start() response1 = response_queue1.get( - timeout=test_constants.SHORT_TIMEOUT) + timeout=test_constants.SHORT_TIMEOUT + ) response2 = response_queue2.get( - timeout=test_constants.SHORT_TIMEOUT) - self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, - response1.status) - self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, - response2.status) - - self._servicer.set(_WATCH_SERVICE, - health_pb2.HealthCheckResponse.SERVING) + timeout=test_constants.SHORT_TIMEOUT + ) + 
self.assertEqual( + health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, response1.status + ) + self.assertEqual( + health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, response2.status + ) + + self._servicer.set( + _WATCH_SERVICE, health_pb2.HealthCheckResponse.SERVING + ) response1 = response_queue1.get( - timeout=test_constants.SHORT_TIMEOUT) + timeout=test_constants.SHORT_TIMEOUT + ) response2 = response_queue2.get( - timeout=test_constants.SHORT_TIMEOUT) - self.assertEqual(health_pb2.HealthCheckResponse.SERVING, - response1.status) - self.assertEqual(health_pb2.HealthCheckResponse.SERVING, - response2.status) + timeout=test_constants.SHORT_TIMEOUT + ) + self.assertEqual( + health_pb2.HealthCheckResponse.SERVING, response1.status + ) + self.assertEqual( + health_pb2.HealthCheckResponse.SERVING, response2.status + ) rendezvous1.cancel() rendezvous2.cancel() @@ -183,63 +209,72 @@ def test_cancelled_watch_removed_from_watch_list(self): request = health_pb2.HealthCheckRequest(service=_WATCH_SERVICE) response_queue = queue.Queue() rendezvous = self._stub.Watch(request) - thread = threading.Thread(target=_consume_responses, - args=(rendezvous, response_queue)) + thread = threading.Thread( + target=_consume_responses, args=(rendezvous, response_queue) + ) thread.start() response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT) - self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, - response.status) + self.assertEqual( + health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, response.status + ) rendezvous.cancel() - self._servicer.set(_WATCH_SERVICE, - health_pb2.HealthCheckResponse.SERVING) + self._servicer.set( + _WATCH_SERVICE, health_pb2.HealthCheckResponse.SERVING + ) thread.join() # Wait, if necessary, for serving thread to process client cancellation timeout = time.time() + test_constants.TIME_ALLOWANCE - while (time.time() < timeout and - self._servicer._send_response_callbacks[_WATCH_SERVICE]): + while ( + time.time() < timeout + and self._servicer._send_response_callbacks[_WATCH_SERVICE] + ): time.sleep(1) self.assertFalse( self._servicer._send_response_callbacks[_WATCH_SERVICE], - 'watch set should be empty') + "watch set should be empty", + ) self.assertTrue(response_queue.empty()) def test_graceful_shutdown(self): - request = health_pb2.HealthCheckRequest(service='') + request = health_pb2.HealthCheckRequest(service="") response_queue = queue.Queue() rendezvous = self._stub.Watch(request) - thread = threading.Thread(target=_consume_responses, - args=(rendezvous, response_queue)) + thread = threading.Thread( + target=_consume_responses, args=(rendezvous, response_queue) + ) thread.start() response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT) - self.assertEqual(health_pb2.HealthCheckResponse.SERVING, - response.status) + self.assertEqual( + health_pb2.HealthCheckResponse.SERVING, response.status + ) self._servicer.enter_graceful_shutdown() response = response_queue.get(timeout=test_constants.SHORT_TIMEOUT) - self.assertEqual(health_pb2.HealthCheckResponse.NOT_SERVING, - response.status) + self.assertEqual( + health_pb2.HealthCheckResponse.NOT_SERVING, response.status + ) # This should be a no-op. 
- self._servicer.set('', health_pb2.HealthCheckResponse.SERVING) + self._servicer.set("", health_pb2.HealthCheckResponse.SERVING) rendezvous.cancel() thread.join() self.assertTrue(response_queue.empty()) -@unittest.skipIf(sys.version_info[0] < 3, - 'ProtoBuf descriptor has moved on from Python2') +@unittest.skipIf( + sys.version_info[0] < 3, "ProtoBuf descriptor has moved on from Python2" +) class HealthServicerTest(BaseWatchTests.WatchTests): - def setUp(self): self._thread_pool = thread_pool.RecordingThreadPool(max_workers=None) - super(HealthServicerTest, - self).start_server(non_blocking=True, - thread_pool=self._thread_pool) + super(HealthServicerTest, self).start_server( + non_blocking=True, thread_pool=self._thread_pool + ) def test_check_empty_service(self): request = health_pb2.HealthCheckRequest() @@ -259,29 +294,31 @@ def test_check_unknown_service(self): def test_check_not_serving_service(self): request = health_pb2.HealthCheckRequest(service=_NOT_SERVING_SERVICE) resp = self._stub.Check(request) - self.assertEqual(health_pb2.HealthCheckResponse.NOT_SERVING, - resp.status) + self.assertEqual( + health_pb2.HealthCheckResponse.NOT_SERVING, resp.status + ) def test_check_not_found_service(self): - request = health_pb2.HealthCheckRequest(service='not-found') + request = health_pb2.HealthCheckRequest(service="not-found") with self.assertRaises(grpc.RpcError) as context: resp = self._stub.Check(request) self.assertEqual(grpc.StatusCode.NOT_FOUND, context.exception.code()) def test_health_service_name(self): - self.assertEqual(health.SERVICE_NAME, 'grpc.health.v1.Health') + self.assertEqual(health.SERVICE_NAME, "grpc.health.v1.Health") -@unittest.skipIf(sys.version_info[0] < 3, - 'ProtoBuf descriptor has moved on from Python2') +@unittest.skipIf( + sys.version_info[0] < 3, "ProtoBuf descriptor has moved on from Python2" +) class HealthServicerBackwardsCompatibleWatchTest(BaseWatchTests.WatchTests): - def setUp(self): - super(HealthServicerBackwardsCompatibleWatchTest, - self).start_server(non_blocking=False, thread_pool=None) + super(HealthServicerBackwardsCompatibleWatchTest, self).start_server( + non_blocking=False, thread_pool=None + ) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/http2/negative_http2_client.py b/src/python/grpcio_tests/tests/http2/negative_http2_client.py index 138f61995c92b..74ac1de9bea0b 100644 --- a/src/python/grpcio_tests/tests/http2/negative_http2_client.py +++ b/src/python/grpcio_tests/tests/http2/negative_http2_client.py @@ -24,23 +24,29 @@ def _validate_payload_type_and_length(response, expected_type, expected_length): if response.payload.type is not expected_type: - raise ValueError('expected payload type %s, got %s' % - (expected_type, type(response.payload.type))) + raise ValueError( + "expected payload type %s, got %s" + % (expected_type, type(response.payload.type)) + ) elif len(response.payload.body) != expected_length: - raise ValueError('expected payload body size %d, got %d' % - (expected_length, len(response.payload.body))) + raise ValueError( + "expected payload body size %d, got %d" + % (expected_length, len(response.payload.body)) + ) def _expect_status_code(call, expected_code): if call.code() != expected_code: - raise ValueError('expected code %s, got %s' % - (expected_code, call.code())) + raise ValueError( + "expected code %s, got %s" % (expected_code, call.code()) + ) def _expect_status_details(call, expected_details): if call.details() 
!= expected_details: - raise ValueError('expected message %s, got %s' % - (expected_details, call.details())) + raise ValueError( + "expected message %s, got %s" % (expected_details, call.details()) + ) def _validate_status_code_and_details(call, expected_code, expected_details): @@ -55,71 +61,85 @@ def _validate_status_code_and_details(call, expected_code, expected_details): _SIMPLE_REQUEST = messages_pb2.SimpleRequest( response_type=messages_pb2.COMPRESSABLE, response_size=_RESPONSE_SIZE, - payload=messages_pb2.Payload(body=b'\x00' * _REQUEST_SIZE)) + payload=messages_pb2.Payload(body=b"\x00" * _REQUEST_SIZE), +) def _goaway(stub): first_response = stub.UnaryCall(_SIMPLE_REQUEST) - _validate_payload_type_and_length(first_response, messages_pb2.COMPRESSABLE, - _RESPONSE_SIZE) + _validate_payload_type_and_length( + first_response, messages_pb2.COMPRESSABLE, _RESPONSE_SIZE + ) time.sleep(1) second_response = stub.UnaryCall(_SIMPLE_REQUEST) - _validate_payload_type_and_length(second_response, - messages_pb2.COMPRESSABLE, _RESPONSE_SIZE) + _validate_payload_type_and_length( + second_response, messages_pb2.COMPRESSABLE, _RESPONSE_SIZE + ) def _rst_after_header(stub): resp_future = stub.UnaryCall.future(_SIMPLE_REQUEST) - _validate_status_code_and_details(resp_future, grpc.StatusCode.INTERNAL, - "Received RST_STREAM with error code 0") + _validate_status_code_and_details( + resp_future, + grpc.StatusCode.INTERNAL, + "Received RST_STREAM with error code 0", + ) def _rst_during_data(stub): resp_future = stub.UnaryCall.future(_SIMPLE_REQUEST) - _validate_status_code_and_details(resp_future, grpc.StatusCode.INTERNAL, - "Received RST_STREAM with error code 0") + _validate_status_code_and_details( + resp_future, + grpc.StatusCode.INTERNAL, + "Received RST_STREAM with error code 0", + ) def _rst_after_data(stub): resp_future = stub.UnaryCall.future(_SIMPLE_REQUEST) - _validate_status_code_and_details(resp_future, grpc.StatusCode.INTERNAL, - "Received RST_STREAM with error code 0") + _validate_status_code_and_details( + resp_future, + grpc.StatusCode.INTERNAL, + "Received RST_STREAM with error code 0", + ) def _ping(stub): response = stub.UnaryCall(_SIMPLE_REQUEST) - _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, - _RESPONSE_SIZE) + _validate_payload_type_and_length( + response, messages_pb2.COMPRESSABLE, _RESPONSE_SIZE + ) def _max_streams(stub): # send one req to ensure server sets MAX_STREAMS response = stub.UnaryCall(_SIMPLE_REQUEST) - _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, - _RESPONSE_SIZE) + _validate_payload_type_and_length( + response, messages_pb2.COMPRESSABLE, _RESPONSE_SIZE + ) # give the streams a workout futures = [] for _ in range(15): futures.append(stub.UnaryCall.future(_SIMPLE_REQUEST)) for future in futures: - _validate_payload_type_and_length(future.result(), - messages_pb2.COMPRESSABLE, - _RESPONSE_SIZE) + _validate_payload_type_and_length( + future.result(), messages_pb2.COMPRESSABLE, _RESPONSE_SIZE + ) def _run_test_case(test_case, stub): - if test_case == 'goaway': + if test_case == "goaway": _goaway(stub) - elif test_case == 'rst_after_header': + elif test_case == "rst_after_header": _rst_after_header(stub) - elif test_case == 'rst_during_data': + elif test_case == "rst_during_data": _rst_during_data(stub) - elif test_case == 'rst_after_data': + elif test_case == "rst_after_data": _rst_after_data(stub) - elif test_case == 'ping': + elif test_case == "ping": _ping(stub) - elif test_case == 'max_streams': + elif test_case == 
"max_streams": _max_streams(stub) else: raise ValueError("Invalid test case: %s" % test_case) @@ -127,23 +147,29 @@ def _run_test_case(test_case, stub): def _args(): parser = argparse.ArgumentParser() - parser.add_argument('--server_host', - help='the host to which to connect', - type=str, - default="127.0.0.1") - parser.add_argument('--server_port', - help='the port to which to connect', - type=int, - default="8080") - parser.add_argument('--test_case', - help='the test case to execute', - type=str, - default="goaway") + parser.add_argument( + "--server_host", + help="the host to which to connect", + type=str, + default="127.0.0.1", + ) + parser.add_argument( + "--server_port", + help="the port to which to connect", + type=int, + default="8080", + ) + parser.add_argument( + "--test_case", + help="the test case to execute", + type=str, + default="goaway", + ) return parser.parse_args() def _stub(server_host, server_port): - target = '{}:{}'.format(server_host, server_port) + target = "{}:{}".format(server_host, server_port) channel = grpc.insecure_channel(target) grpc.channel_ready_future(channel).result() return test_pb2_grpc.TestServiceStub(channel) @@ -155,5 +181,5 @@ def main(): _run_test_case(args.test_case, stub) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/src/python/grpcio_tests/tests/interop/_insecure_intraop_test.py b/src/python/grpcio_tests/tests/interop/_insecure_intraop_test.py index 27e5dcdd759d2..bb8ada4c58fee 100644 --- a/src/python/grpcio_tests/tests/interop/_insecure_intraop_test.py +++ b/src/python/grpcio_tests/tests/interop/_insecure_intraop_test.py @@ -25,23 +25,26 @@ from tests.unit import test_common -@unittest.skipIf(sys.version_info[0] < 3, - 'ProtoBuf descriptor has moved on from Python2') -class InsecureIntraopTest(_intraop_test_case.IntraopTestCase, - unittest.TestCase): - +@unittest.skipIf( + sys.version_info[0] < 3, "ProtoBuf descriptor has moved on from Python2" +) +class InsecureIntraopTest( + _intraop_test_case.IntraopTestCase, unittest.TestCase +): def setUp(self): self.server = test_common.test_server() - test_pb2_grpc.add_TestServiceServicer_to_server(service.TestService(), - self.server) - port = self.server.add_insecure_port('[::]:0') + test_pb2_grpc.add_TestServiceServicer_to_server( + service.TestService(), self.server + ) + port = self.server.add_insecure_port("[::]:0") self.server.start() self.stub = test_pb2_grpc.TestServiceStub( - grpc.insecure_channel('localhost:{}'.format(port))) + grpc.insecure_channel("localhost:{}".format(port)) + ) def tearDown(self): self.server.stop(None) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/interop/_intraop_test_case.py b/src/python/grpcio_tests/tests/interop/_intraop_test_case.py index 007db7ab41b31..c35ee2a7e93b0 100644 --- a/src/python/grpcio_tests/tests/interop/_intraop_test_case.py +++ b/src/python/grpcio_tests/tests/interop/_intraop_test_case.py @@ -19,9 +19,9 @@ class IntraopTestCase(object): """Unit test methods. - This class must be mixed in with unittest.TestCase and a class that defines - setUp and tearDown methods that manage a stub attribute. - """ + This class must be mixed in with unittest.TestCase and a class that defines + setUp and tearDown methods that manage a stub attribute. 
+ """ def testEmptyUnary(self): methods.TestCase.EMPTY_UNARY.test_interoperability(self.stub, None) @@ -40,12 +40,15 @@ def testPingPong(self): def testCancelAfterBegin(self): methods.TestCase.CANCEL_AFTER_BEGIN.test_interoperability( - self.stub, None) + self.stub, None + ) def testCancelAfterFirstResponse(self): methods.TestCase.CANCEL_AFTER_FIRST_RESPONSE.test_interoperability( - self.stub, None) + self.stub, None + ) def testTimeoutOnSleepingServer(self): methods.TestCase.TIMEOUT_ON_SLEEPING_SERVER.test_interoperability( - self.stub, None) + self.stub, None + ) diff --git a/src/python/grpcio_tests/tests/interop/_secure_intraop_test.py b/src/python/grpcio_tests/tests/interop/_secure_intraop_test.py index 0ec88a2cd997d..3572dabbc121d 100644 --- a/src/python/grpcio_tests/tests/interop/_secure_intraop_test.py +++ b/src/python/grpcio_tests/tests/interop/_secure_intraop_test.py @@ -24,34 +24,43 @@ from tests.interop import service from tests.unit import test_common -_SERVER_HOST_OVERRIDE = 'foo.test.google.fr' +_SERVER_HOST_OVERRIDE = "foo.test.google.fr" -@unittest.skipIf(sys.version_info[0] < 3, - 'ProtoBuf descriptor has moved on from Python2') +@unittest.skipIf( + sys.version_info[0] < 3, "ProtoBuf descriptor has moved on from Python2" +) class SecureIntraopTest(_intraop_test_case.IntraopTestCase, unittest.TestCase): - def setUp(self): self.server = test_common.test_server() - test_pb2_grpc.add_TestServiceServicer_to_server(service.TestService(), - self.server) + test_pb2_grpc.add_TestServiceServicer_to_server( + service.TestService(), self.server + ) port = self.server.add_secure_port( - '[::]:0', - grpc.ssl_server_credentials([(resources.private_key(), - resources.certificate_chain())])) + "[::]:0", + grpc.ssl_server_credentials( + [(resources.private_key(), resources.certificate_chain())] + ), + ) self.server.start() self.stub = test_pb2_grpc.TestServiceStub( grpc.secure_channel( - 'localhost:{}'.format(port), + "localhost:{}".format(port), grpc.ssl_channel_credentials( - resources.test_root_certificates()), (( - 'grpc.ssl_target_name_override', + resources.test_root_certificates() + ), + ( + ( + "grpc.ssl_target_name_override", _SERVER_HOST_OVERRIDE, - ),))) + ), + ), + ) + ) def tearDown(self): self.server.stop(None) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/interop/client.py b/src/python/grpcio_tests/tests/interop/client.py index c95e81353d9ca..14095344b171d 100644 --- a/src/python/grpcio_tests/tests/interop/client.py +++ b/src/python/grpcio_tests/tests/interop/client.py @@ -27,73 +27,102 @@ def parse_interop_client_args(): parser = argparse.ArgumentParser() - parser.add_argument('--server_host', - default="localhost", - type=str, - help='the host to which to connect') - parser.add_argument('--server_port', - type=int, - required=True, - help='the port to which to connect') - parser.add_argument('--test_case', - default='large_unary', - type=str, - help='the test case to execute') - parser.add_argument('--use_tls', - default=False, - type=resources.parse_bool, - help='require a secure connection') - parser.add_argument('--use_alts', - default=False, - type=resources.parse_bool, - help='require an ALTS secure connection') - parser.add_argument('--use_test_ca', - default=False, - type=resources.parse_bool, - help='replace platform root CAs with ca.pem') - parser.add_argument('--custom_credentials_type', - choices=["compute_engine_channel_creds"], - default=None, - help='use google default 
credentials') - parser.add_argument('--server_host_override', - type=str, - help='the server host to which to claim to connect') - parser.add_argument('--oauth_scope', - type=str, - help='scope for OAuth tokens') - parser.add_argument('--default_service_account', - type=str, - help='email address of the default service account') + parser.add_argument( + "--server_host", + default="localhost", + type=str, + help="the host to which to connect", + ) + parser.add_argument( + "--server_port", + type=int, + required=True, + help="the port to which to connect", + ) + parser.add_argument( + "--test_case", + default="large_unary", + type=str, + help="the test case to execute", + ) + parser.add_argument( + "--use_tls", + default=False, + type=resources.parse_bool, + help="require a secure connection", + ) + parser.add_argument( + "--use_alts", + default=False, + type=resources.parse_bool, + help="require an ALTS secure connection", + ) + parser.add_argument( + "--use_test_ca", + default=False, + type=resources.parse_bool, + help="replace platform root CAs with ca.pem", + ) + parser.add_argument( + "--custom_credentials_type", + choices=["compute_engine_channel_creds"], + default=None, + help="use google default credentials", + ) + parser.add_argument( + "--server_host_override", + type=str, + help="the server host to which to claim to connect", + ) + parser.add_argument( + "--oauth_scope", type=str, help="scope for OAuth tokens" + ) + parser.add_argument( + "--default_service_account", + type=str, + help="email address of the default service account", + ) parser.add_argument( "--grpc_test_use_grpclb_with_child_policy", type=str, help=( "If non-empty, set a static service config on channels created by " - + "grpc::CreateTestChannel, that configures the grpclb LB policy " + - "with a child policy being the value of this flag (e.g. round_robin " - + "or pick_first).")) + + "grpc::CreateTestChannel, that configures the grpclb LB policy " + + "with a child policy being the value of this flag (e.g." + " round_robin " + "or pick_first)." 
+ ), + ) return parser.parse_args() def _create_call_credentials(args): - if args.test_case == 'oauth2_auth_token': + if args.test_case == "oauth2_auth_token": google_credentials, unused_project_id = google_auth.default( - scopes=[args.oauth_scope]) + scopes=[args.oauth_scope] + ) google_credentials.refresh(google_auth.transport.requests.Request()) return grpc.access_token_call_credentials(google_credentials.token) - elif args.test_case == 'compute_engine_creds': + elif args.test_case == "compute_engine_creds": google_credentials, unused_project_id = google_auth.default( - scopes=[args.oauth_scope]) + scopes=[args.oauth_scope] + ) return grpc.metadata_call_credentials( google_auth.transport.grpc.AuthMetadataPlugin( credentials=google_credentials, - request=google_auth.transport.requests.Request())) - elif args.test_case == 'jwt_token_creds': - google_credentials = google_auth_jwt.OnDemandCredentials.from_service_account_file( - os.environ[google_auth.environment_vars.CREDENTIALS]) + request=google_auth.transport.requests.Request(), + ) + ) + elif args.test_case == "jwt_token_creds": + google_credentials = ( + google_auth_jwt.OnDemandCredentials.from_service_account_file( + os.environ[google_auth.environment_vars.CREDENTIALS] + ) + ) return grpc.metadata_call_credentials( google_auth.transport.grpc.AuthMetadataPlugin( - credentials=google_credentials, request=None)) + credentials=google_credentials, request=None + ) + ) else: return None @@ -103,24 +132,34 @@ def get_secure_channel_parameters(args): channel_opts = () if args.grpc_test_use_grpclb_with_child_policy: - channel_opts += (( - "grpc.service_config", - '{"loadBalancingConfig": [{"grpclb": {"childPolicy": [{"%s": {}}]}}]}' - % args.grpc_test_use_grpclb_with_child_policy),) + channel_opts += ( + ( + "grpc.service_config", + '{"loadBalancingConfig": [{"grpclb": {"childPolicy": [{"%s":' + " {}}]}}]}" % args.grpc_test_use_grpclb_with_child_policy, + ), + ) if args.custom_credentials_type is not None: if args.custom_credentials_type == "compute_engine_channel_creds": assert call_credentials is None google_credentials, unused_project_id = google_auth.default( - scopes=[args.oauth_scope]) + scopes=[args.oauth_scope] + ) call_creds = grpc.metadata_call_credentials( google_auth.transport.grpc.AuthMetadataPlugin( credentials=google_credentials, - request=google_auth.transport.requests.Request())) + request=google_auth.transport.requests.Request(), + ) + ) channel_credentials = grpc.compute_engine_channel_credentials( - call_creds) + call_creds + ) else: - raise ValueError("Unknown credentials type '{}'".format( - args.custom_credentials_type)) + raise ValueError( + "Unknown credentials type '{}'".format( + args.custom_credentials_type + ) + ) elif args.use_tls: if args.use_test_ca: root_certificates = resources.test_root_certificates() @@ -130,13 +169,16 @@ def get_secure_channel_parameters(args): channel_credentials = grpc.ssl_channel_credentials(root_certificates) if call_credentials is not None: channel_credentials = grpc.composite_channel_credentials( - channel_credentials, call_credentials) + channel_credentials, call_credentials + ) if args.server_host_override: - channel_opts += (( - 'grpc.ssl_target_name_override', - args.server_host_override, - ),) + channel_opts += ( + ( + "grpc.ssl_target_name_override", + args.server_host_override, + ), + ) elif args.use_alts: channel_credentials = grpc.alts_channel_credentials() @@ -144,9 +186,13 @@ def get_secure_channel_parameters(args): def _create_channel(args): - target = 
'{}:{}'.format(args.server_host, args.server_port) + target = "{}:{}".format(args.server_host, args.server_port) - if args.use_tls or args.use_alts or args.custom_credentials_type is not None: + if ( + args.use_tls + or args.use_alts + or args.custom_credentials_type is not None + ): channel_credentials, options = get_secure_channel_parameters(args) return grpc.secure_channel(target, channel_credentials, options) else: @@ -176,5 +222,5 @@ def test_interoperability(): test_case.test_interoperability(stub, args) -if __name__ == '__main__': +if __name__ == "__main__": test_interoperability() diff --git a/src/python/grpcio_tests/tests/interop/methods.py b/src/python/grpcio_tests/tests/interop/methods.py index 44a1c38bb93d4..0e98abc9f4e0f 100644 --- a/src/python/grpcio_tests/tests/interop/methods.py +++ b/src/python/grpcio_tests/tests/interop/methods.py @@ -17,6 +17,7 @@ # please refer to comments in the "bazel_namespace_package_hack" module. try: from tests import bazel_namespace_package_hack + bazel_namespace_package_hack.sys_path_to_site_dir_hack() except ImportError: pass @@ -42,14 +43,16 @@ def _expect_status_code(call, expected_code): if call.code() != expected_code: - raise ValueError('expected code %s, got %s' % - (expected_code, call.code())) + raise ValueError( + "expected code %s, got %s" % (expected_code, call.code()) + ) def _expect_status_details(call, expected_details): if call.details() != expected_details: - raise ValueError('expected message %s, got %s' % - (expected_details, call.details())) + raise ValueError( + "expected message %s, got %s" % (expected_details, call.details()) + ) def _validate_status_code_and_details(call, expected_code, expected_details): @@ -59,24 +62,31 @@ def _validate_status_code_and_details(call, expected_code, expected_details): def _validate_payload_type_and_length(response, expected_type, expected_length): if response.payload.type is not expected_type: - raise ValueError('expected payload type %s, got %s' % - (expected_type, type(response.payload.type))) + raise ValueError( + "expected payload type %s, got %s" + % (expected_type, type(response.payload.type)) + ) elif len(response.payload.body) != expected_length: - raise ValueError('expected payload body size %d, got %d' % - (expected_length, len(response.payload.body))) + raise ValueError( + "expected payload body size %d, got %d" + % (expected_length, len(response.payload.body)) + ) -def _large_unary_common_behavior(stub, fill_username, fill_oauth_scope, - call_credentials): +def _large_unary_common_behavior( + stub, fill_username, fill_oauth_scope, call_credentials +): size = 314159 request = messages_pb2.SimpleRequest( response_type=messages_pb2.COMPRESSABLE, response_size=size, - payload=messages_pb2.Payload(body=b'\x00' * 271828), + payload=messages_pb2.Payload(body=b"\x00" * 271828), fill_username=fill_username, - fill_oauth_scope=fill_oauth_scope) - response_future = stub.UnaryCall.future(request, - credentials=call_credentials) + fill_oauth_scope=fill_oauth_scope, + ) + response_future = stub.UnaryCall.future( + request, credentials=call_credentials + ) response = response_future.result() _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, size) return response @@ -85,8 +95,9 @@ def _large_unary_common_behavior(stub, fill_username, fill_oauth_scope, def _empty_unary(stub): response = stub.EmptyCall(empty_pb2.Empty()) if not isinstance(response, empty_pb2.Empty): - raise TypeError('response is of type "%s", not empty_pb2.Empty!' 
% - type(response)) + raise TypeError( + 'response is of type "%s", not empty_pb2.Empty!' % type(response) + ) def _large_unary(stub): @@ -100,14 +111,18 @@ def _client_streaming(stub): 1828, 45904, ) - payloads = (messages_pb2.Payload(body=b'\x00' * size) - for size in payload_body_sizes) - requests = (messages_pb2.StreamingInputCallRequest(payload=payload) - for payload in payloads) + payloads = ( + messages_pb2.Payload(body=b"\x00" * size) for size in payload_body_sizes + ) + requests = ( + messages_pb2.StreamingInputCallRequest(payload=payload) + for payload in payloads + ) response = stub.StreamingInputCall(requests) if response.aggregated_payload_size != 74922: - raise ValueError('incorrect size %d!' % - response.aggregated_payload_size) + raise ValueError( + "incorrect size %d!" % response.aggregated_payload_size + ) def _server_streaming(stub): @@ -125,15 +140,16 @@ def _server_streaming(stub): messages_pb2.ResponseParameters(size=sizes[1]), messages_pb2.ResponseParameters(size=sizes[2]), messages_pb2.ResponseParameters(size=sizes[3]), - )) + ), + ) response_iterator = stub.StreamingOutputCall(request) for index, response in enumerate(response_iterator): - _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, - sizes[index]) + _validate_payload_type_and_length( + response, messages_pb2.COMPRESSABLE, sizes[index] + ) class _Pipe(object): - def __init__(self): self._condition = threading.Condition() self._values = [] @@ -187,18 +203,21 @@ def _ping_pong(stub): with _Pipe() as pipe: response_iterator = stub.FullDuplexCall(pipe) - for response_size, payload_size in zip(request_response_sizes, - request_payload_sizes): + for response_size, payload_size in zip( + request_response_sizes, request_payload_sizes + ): request = messages_pb2.StreamingOutputCallRequest( response_type=messages_pb2.COMPRESSABLE, - response_parameters=(messages_pb2.ResponseParameters( - size=response_size),), - payload=messages_pb2.Payload(body=b'\x00' * payload_size)) + response_parameters=( + messages_pb2.ResponseParameters(size=response_size), + ), + payload=messages_pb2.Payload(body=b"\x00" * payload_size), + ) pipe.add(request) response = next(response_iterator) - _validate_payload_type_and_length(response, - messages_pb2.COMPRESSABLE, - response_size) + _validate_payload_type_and_length( + response, messages_pb2.COMPRESSABLE, response_size + ) def _cancel_after_begin(stub): @@ -206,9 +225,9 @@ def _cancel_after_begin(stub): response_future = stub.StreamingInputCall.future(pipe) response_future.cancel() if not response_future.cancelled(): - raise ValueError('expected cancelled method to return True') + raise ValueError("expected cancelled method to return True") if response_future.code() is not grpc.StatusCode.CANCELLED: - raise ValueError('expected status code CANCELLED') + raise ValueError("expected status code CANCELLED") def _cancel_after_first_response(stub): @@ -231,9 +250,11 @@ def _cancel_after_first_response(stub): payload_size = request_payload_sizes[0] request = messages_pb2.StreamingOutputCallRequest( response_type=messages_pb2.COMPRESSABLE, - response_parameters=(messages_pb2.ResponseParameters( - size=response_size),), - payload=messages_pb2.Payload(body=b'\x00' * payload_size)) + response_parameters=( + messages_pb2.ResponseParameters(size=response_size), + ), + payload=messages_pb2.Payload(body=b"\x00" * payload_size), + ) pipe.add(request) response = next(response_iterator) # We test the contents of `response` in the Ping Pong test - don't check @@ -246,7 +267,7 @@ def 
_cancel_after_first_response(stub): if rpc_error.code() is not grpc.StatusCode.CANCELLED: raise else: - raise ValueError('expected call to be cancelled') + raise ValueError("expected call to be cancelled") def _timeout_on_sleeping_server(stub): @@ -256,7 +277,8 @@ def _timeout_on_sleeping_server(stub): request = messages_pb2.StreamingOutputCallRequest( response_type=messages_pb2.COMPRESSABLE, - payload=messages_pb2.Payload(body=b'\x00' * request_payload_size)) + payload=messages_pb2.Payload(body=b"\x00" * request_payload_size), + ) pipe.add(request) try: next(response_iterator) @@ -264,7 +286,7 @@ def _timeout_on_sleeping_server(stub): if rpc_error.code() is not grpc.StatusCode.DEADLINE_EXCEEDED: raise else: - raise ValueError('expected call to exceed deadline') + raise ValueError("expected call to exceed deadline") def _empty_stream(stub): @@ -273,13 +295,13 @@ def _empty_stream(stub): pipe.close() try: next(response_iterator) - raise ValueError('expected exactly 0 responses') + raise ValueError("expected exactly 0 responses") except StopIteration: pass def _status_code_and_message(stub): - details = 'test status message' + details = "test status message" code = 2 status = grpc.StatusCode.UNKNOWN # code = 2 @@ -287,8 +309,9 @@ def _status_code_and_message(stub): request = messages_pb2.SimpleRequest( response_type=messages_pb2.COMPRESSABLE, response_size=1, - payload=messages_pb2.Payload(body=b'\x00'), - response_status=messages_pb2.EchoStatus(code=code, message=details)) + payload=messages_pb2.Payload(body=b"\x00"), + response_status=messages_pb2.EchoStatus(code=code, message=details), + ) response_future = stub.UnaryCall.future(request) _validate_status_code_and_details(response_future, status, details) @@ -298,8 +321,9 @@ def _status_code_and_message(stub): request = messages_pb2.StreamingOutputCallRequest( response_type=messages_pb2.COMPRESSABLE, response_parameters=(messages_pb2.ResponseParameters(size=1),), - payload=messages_pb2.Payload(body=b'\x00'), - response_status=messages_pb2.EchoStatus(code=code, message=details)) + payload=messages_pb2.Payload(body=b"\x00"), + response_status=messages_pb2.EchoStatus(code=code, message=details), + ) pipe.add(request) # sends the initial request. 
try: next(response_iterator) @@ -310,40 +334,53 @@ def _status_code_and_message(stub): def _unimplemented_method(test_service_stub): - response_future = (test_service_stub.UnimplementedCall.future( - empty_pb2.Empty())) + response_future = test_service_stub.UnimplementedCall.future( + empty_pb2.Empty() + ) _expect_status_code(response_future, grpc.StatusCode.UNIMPLEMENTED) def _unimplemented_service(unimplemented_service_stub): - response_future = (unimplemented_service_stub.UnimplementedCall.future( - empty_pb2.Empty())) + response_future = unimplemented_service_stub.UnimplementedCall.future( + empty_pb2.Empty() + ) _expect_status_code(response_future, grpc.StatusCode.UNIMPLEMENTED) def _custom_metadata(stub): initial_metadata_value = "test_initial_metadata_value" trailing_metadata_value = b"\x0a\x0b\x0a\x0b\x0a\x0b" - metadata = ((_INITIAL_METADATA_KEY, initial_metadata_value), - (_TRAILING_METADATA_KEY, trailing_metadata_value)) + metadata = ( + (_INITIAL_METADATA_KEY, initial_metadata_value), + (_TRAILING_METADATA_KEY, trailing_metadata_value), + ) def _validate_metadata(response): initial_metadata = dict(response.initial_metadata()) if initial_metadata[_INITIAL_METADATA_KEY] != initial_metadata_value: - raise ValueError('expected initial metadata %s, got %s' % - (initial_metadata_value, - initial_metadata[_INITIAL_METADATA_KEY])) + raise ValueError( + "expected initial metadata %s, got %s" + % ( + initial_metadata_value, + initial_metadata[_INITIAL_METADATA_KEY], + ) + ) trailing_metadata = dict(response.trailing_metadata()) if trailing_metadata[_TRAILING_METADATA_KEY] != trailing_metadata_value: - raise ValueError('expected trailing metadata %s, got %s' % - (trailing_metadata_value, - trailing_metadata[_TRAILING_METADATA_KEY])) + raise ValueError( + "expected trailing metadata %s, got %s" + % ( + trailing_metadata_value, + trailing_metadata[_TRAILING_METADATA_KEY], + ) + ) # Testing with UnaryCall request = messages_pb2.SimpleRequest( response_type=messages_pb2.COMPRESSABLE, response_size=1, - payload=messages_pb2.Payload(body=b'\x00')) + payload=messages_pb2.Payload(body=b"\x00"), + ) response_future = stub.UnaryCall.future(request, metadata=metadata) _validate_metadata(response_future) @@ -352,7 +389,8 @@ def _validate_metadata(response): response_iterator = stub.FullDuplexCall(pipe, metadata=metadata) request = messages_pb2.StreamingOutputCallRequest( response_type=messages_pb2.COMPRESSABLE, - response_parameters=(messages_pb2.ResponseParameters(size=1),)) + response_parameters=(messages_pb2.ResponseParameters(size=1),), + ) pipe.add(request) # Sends the request next(response_iterator) # Causes server to send trailing metadata # Dropping out of the with block closes the pipe @@ -362,50 +400,62 @@ def _validate_metadata(response): def _compute_engine_creds(stub, args): response = _large_unary_common_behavior(stub, True, True, None) if args.default_service_account != response.username: - raise ValueError('expected username %s, got %s' % - (args.default_service_account, response.username)) + raise ValueError( + "expected username %s, got %s" + % (args.default_service_account, response.username) + ) def _oauth2_auth_token(stub, args): json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS] - wanted_email = json.load(open(json_key_filename, 'r'))['client_email'] + wanted_email = json.load(open(json_key_filename, "r"))["client_email"] response = _large_unary_common_behavior(stub, True, True, None) if wanted_email != response.username: - raise ValueError('expected 
username %s, got %s' % - (wanted_email, response.username)) + raise ValueError( + "expected username %s, got %s" % (wanted_email, response.username) + ) if args.oauth_scope.find(response.oauth_scope) == -1: raise ValueError( 'expected to find oauth scope "{}" in received "{}"'.format( - response.oauth_scope, args.oauth_scope)) + response.oauth_scope, args.oauth_scope + ) + ) def _jwt_token_creds(stub, args): json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS] - wanted_email = json.load(open(json_key_filename, 'r'))['client_email'] + wanted_email = json.load(open(json_key_filename, "r"))["client_email"] response = _large_unary_common_behavior(stub, True, False, None) if wanted_email != response.username: - raise ValueError('expected username %s, got %s' % - (wanted_email, response.username)) + raise ValueError( + "expected username %s, got %s" % (wanted_email, response.username) + ) def _per_rpc_creds(stub, args): json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS] - wanted_email = json.load(open(json_key_filename, 'r'))['client_email'] + wanted_email = json.load(open(json_key_filename, "r"))["client_email"] google_credentials, unused_project_id = google_auth.default( - scopes=[args.oauth_scope]) + scopes=[args.oauth_scope] + ) call_credentials = grpc.metadata_call_credentials( google_auth_transport_grpc.AuthMetadataPlugin( credentials=google_credentials, - request=google_auth_transport_requests.Request())) + request=google_auth_transport_requests.Request(), + ) + ) response = _large_unary_common_behavior(stub, True, False, call_credentials) if wanted_email != response.username: - raise ValueError('expected username %s, got %s' % - (wanted_email, response.username)) + raise ValueError( + "expected username %s, got %s" % (wanted_email, response.username) + ) def _special_status_message(stub, args): - details = b'\t\ntest with whitespace\r\nand Unicode BMP \xe2\x98\xba and non-BMP \xf0\x9f\x98\x88\t\n'.decode( - 'utf-8') + details = ( + b"\t\ntest with whitespace\r\nand Unicode BMP \xe2\x98\xba and non-BMP" + b" \xf0\x9f\x98\x88\t\n".decode("utf-8") + ) code = 2 status = grpc.StatusCode.UNKNOWN # code = 2 @@ -413,32 +463,33 @@ def _special_status_message(stub, args): request = messages_pb2.SimpleRequest( response_type=messages_pb2.COMPRESSABLE, response_size=1, - payload=messages_pb2.Payload(body=b'\x00'), - response_status=messages_pb2.EchoStatus(code=code, message=details)) + payload=messages_pb2.Payload(body=b"\x00"), + response_status=messages_pb2.EchoStatus(code=code, message=details), + ) response_future = stub.UnaryCall.future(request) _validate_status_code_and_details(response_future, status, details) @enum.unique class TestCase(enum.Enum): - EMPTY_UNARY = 'empty_unary' - LARGE_UNARY = 'large_unary' - SERVER_STREAMING = 'server_streaming' - CLIENT_STREAMING = 'client_streaming' - PING_PONG = 'ping_pong' - CANCEL_AFTER_BEGIN = 'cancel_after_begin' - CANCEL_AFTER_FIRST_RESPONSE = 'cancel_after_first_response' - EMPTY_STREAM = 'empty_stream' - STATUS_CODE_AND_MESSAGE = 'status_code_and_message' - UNIMPLEMENTED_METHOD = 'unimplemented_method' - UNIMPLEMENTED_SERVICE = 'unimplemented_service' + EMPTY_UNARY = "empty_unary" + LARGE_UNARY = "large_unary" + SERVER_STREAMING = "server_streaming" + CLIENT_STREAMING = "client_streaming" + PING_PONG = "ping_pong" + CANCEL_AFTER_BEGIN = "cancel_after_begin" + CANCEL_AFTER_FIRST_RESPONSE = "cancel_after_first_response" + EMPTY_STREAM = "empty_stream" + STATUS_CODE_AND_MESSAGE = "status_code_and_message" + 
UNIMPLEMENTED_METHOD = "unimplemented_method" + UNIMPLEMENTED_SERVICE = "unimplemented_service" CUSTOM_METADATA = "custom_metadata" - COMPUTE_ENGINE_CREDS = 'compute_engine_creds' - OAUTH2_AUTH_TOKEN = 'oauth2_auth_token' - JWT_TOKEN_CREDS = 'jwt_token_creds' - PER_RPC_CREDS = 'per_rpc_creds' - TIMEOUT_ON_SLEEPING_SERVER = 'timeout_on_sleeping_server' - SPECIAL_STATUS_MESSAGE = 'special_status_message' + COMPUTE_ENGINE_CREDS = "compute_engine_creds" + OAUTH2_AUTH_TOKEN = "oauth2_auth_token" + JWT_TOKEN_CREDS = "jwt_token_creds" + PER_RPC_CREDS = "per_rpc_creds" + TIMEOUT_ON_SLEEPING_SERVER = "timeout_on_sleeping_server" + SPECIAL_STATUS_MESSAGE = "special_status_message" def test_interoperability(self, stub, args): if self is TestCase.EMPTY_UNARY: @@ -478,5 +529,6 @@ def test_interoperability(self, stub, args): elif self is TestCase.SPECIAL_STATUS_MESSAGE: _special_status_message(stub, args) else: - raise NotImplementedError('Test case "%s" not implemented!' % - self.name) + raise NotImplementedError( + 'Test case "%s" not implemented!' % self.name + ) diff --git a/src/python/grpcio_tests/tests/interop/resources.py b/src/python/grpcio_tests/tests/interop/resources.py index a47228a355d09..e5eedd987c124 100644 --- a/src/python/grpcio_tests/tests/interop/resources.py +++ b/src/python/grpcio_tests/tests/interop/resources.py @@ -17,9 +17,9 @@ import os import pkgutil -_ROOT_CERTIFICATES_RESOURCE_PATH = 'credentials/ca.pem' -_PRIVATE_KEY_RESOURCE_PATH = 'credentials/server1.key' -_CERTIFICATE_CHAIN_RESOURCE_PATH = 'credentials/server1.pem' +_ROOT_CERTIFICATES_RESOURCE_PATH = "credentials/ca.pem" +_PRIVATE_KEY_RESOURCE_PATH = "credentials/server1.key" +_CERTIFICATE_CHAIN_RESOURCE_PATH = "credentials/server1.pem" def test_root_certificates(): @@ -35,8 +35,8 @@ def certificate_chain(): def parse_bool(value): - if value == 'true': + if value == "true": return True - if value == 'false': + if value == "false": return False - raise argparse.ArgumentTypeError('Only true/false allowed') + raise argparse.ArgumentTypeError("Only true/false allowed") diff --git a/src/python/grpcio_tests/tests/interop/server.py b/src/python/grpcio_tests/tests/interop/server.py index 6286733eddbb4..d51671ef8a979 100644 --- a/src/python/grpcio_tests/tests/interop/server.py +++ b/src/python/grpcio_tests/tests/interop/server.py @@ -30,18 +30,21 @@ def parse_interop_server_arguments(): parser = argparse.ArgumentParser() - parser.add_argument('--port', - type=int, - required=True, - help='the port on which to serve') - parser.add_argument('--use_tls', - default=False, - type=resources.parse_bool, - help='require a secure connection') - parser.add_argument('--use_alts', - default=False, - type=resources.parse_bool, - help='require an ALTS connection') + parser.add_argument( + "--port", type=int, required=True, help="the port on which to serve" + ) + parser.add_argument( + "--use_tls", + default=False, + type=resources.parse_bool, + help="require a secure connection", + ) + parser.add_argument( + "--use_alts", + default=False, + type=resources.parse_bool, + help="require an ALTS connection", + ) return parser.parse_args() @@ -58,19 +61,20 @@ def serve(): args = parse_interop_server_arguments() server = test_common.test_server() - test_pb2_grpc.add_TestServiceServicer_to_server(service.TestService(), - server) + test_pb2_grpc.add_TestServiceServicer_to_server( + service.TestService(), server + ) if args.use_tls or args.use_alts: credentials = get_server_credentials(args.use_tls) - server.add_secure_port('[::]:{}'.format(args.port), 
credentials) + server.add_secure_port("[::]:{}".format(args.port), credentials) else: - server.add_insecure_port('[::]:{}'.format(args.port)) + server.add_insecure_port("[::]:{}".format(args.port)) server.start() - _LOGGER.info('Server serving.') + _LOGGER.info("Server serving.") server.wait_for_termination() - _LOGGER.info('Server stopped; exiting.') + _LOGGER.info("Server stopped; exiting.") -if __name__ == '__main__': +if __name__ == "__main__": serve() diff --git a/src/python/grpcio_tests/tests/interop/service.py b/src/python/grpcio_tests/tests/interop/service.py index 08bb0c45a2450..dfe76e6a9e1bb 100644 --- a/src/python/grpcio_tests/tests/interop/service.py +++ b/src/python/grpcio_tests/tests/interop/service.py @@ -30,24 +30,27 @@ def _maybe_echo_metadata(servicer_context): """Copies metadata from request to response if it is present.""" invocation_metadata = dict(servicer_context.invocation_metadata()) if _INITIAL_METADATA_KEY in invocation_metadata: - initial_metadatum = (_INITIAL_METADATA_KEY, - invocation_metadata[_INITIAL_METADATA_KEY]) + initial_metadatum = ( + _INITIAL_METADATA_KEY, + invocation_metadata[_INITIAL_METADATA_KEY], + ) servicer_context.send_initial_metadata((initial_metadatum,)) if _TRAILING_METADATA_KEY in invocation_metadata: - trailing_metadatum = (_TRAILING_METADATA_KEY, - invocation_metadata[_TRAILING_METADATA_KEY]) + trailing_metadatum = ( + _TRAILING_METADATA_KEY, + invocation_metadata[_TRAILING_METADATA_KEY], + ) servicer_context.set_trailing_metadata((trailing_metadatum,)) def _maybe_echo_status_and_message(request, servicer_context): """Sets the response context code and details if the request asks for them""" - if request.HasField('response_status'): + if request.HasField("response_status"): servicer_context.set_code(request.response_status.code) servicer_context.set_details(request.response_status.message) class TestService(test_pb2_grpc.TestServiceServicer): - def EmptyCall(self, request, context): _maybe_echo_metadata(context) return empty_pb2.Empty() @@ -56,8 +59,11 @@ def UnaryCall(self, request, context): _maybe_echo_metadata(context) _maybe_echo_status_and_message(request, context) return messages_pb2.SimpleResponse( - payload=messages_pb2.Payload(type=messages_pb2.COMPRESSABLE, - body=b'\x00' * request.response_size)) + payload=messages_pb2.Payload( + type=messages_pb2.COMPRESSABLE, + body=b"\x00" * request.response_size, + ) + ) def StreamingOutputCall(self, request, context): _maybe_echo_status_and_message(request, context) @@ -65,9 +71,11 @@ def StreamingOutputCall(self, request, context): if response_parameters.interval_us != 0: time.sleep(response_parameters.interval_us / _US_IN_A_SECOND) yield messages_pb2.StreamingOutputCallResponse( - payload=messages_pb2.Payload(type=request.response_type, - body=b'\x00' * - response_parameters.size)) + payload=messages_pb2.Payload( + type=request.response_type, + body=b"\x00" * response_parameters.size, + ) + ) def StreamingInputCall(self, request_iterator, context): aggregate_size = 0 @@ -75,7 +83,8 @@ def StreamingInputCall(self, request_iterator, context): if request.payload is not None and request.payload.body: aggregate_size += len(request.payload.body) return messages_pb2.StreamingInputCallResponse( - aggregated_payload_size=aggregate_size) + aggregated_payload_size=aggregate_size + ) def FullDuplexCall(self, request_iterator, context): _maybe_echo_metadata(context) @@ -83,12 +92,15 @@ def FullDuplexCall(self, request_iterator, context): _maybe_echo_status_and_message(request, context) for 
response_parameters in request.response_parameters: if response_parameters.interval_us != 0: - time.sleep(response_parameters.interval_us / - _US_IN_A_SECOND) + time.sleep( + response_parameters.interval_us / _US_IN_A_SECOND + ) yield messages_pb2.StreamingOutputCallResponse( - payload=messages_pb2.Payload(type=request.payload.type, - body=b'\x00' * - response_parameters.size)) + payload=messages_pb2.Payload( + type=request.payload.type, + body=b"\x00" * response_parameters.size, + ) + ) # NOTE(nathaniel): Apparently this is the same as the full-duplex call? # NOTE(atash): It isn't even called in the interop spec (Oct 22 2015)... diff --git a/src/python/grpcio_tests/tests/observability/_observability_test.py b/src/python/grpcio_tests/tests/observability/_observability_test.py index 2123eb5862709..df657934d45c5 100644 --- a/src/python/grpcio_tests/tests/observability/_observability_test.py +++ b/src/python/grpcio_tests/tests/observability/_observability_test.py @@ -27,34 +27,30 @@ logger = logging.getLogger(__name__) -_REQUEST = b'\x00\x00\x00' -_RESPONSE = b'\x00\x00\x00' +_REQUEST = b"\x00\x00\x00" +_RESPONSE = b"\x00\x00\x00" -_UNARY_UNARY = '/test/UnaryUnary' -_UNARY_STREAM = '/test/UnaryStream' -_STREAM_UNARY = '/test/StreamUnary' -_STREAM_STREAM = '/test/StreamStream' +_UNARY_UNARY = "/test/UnaryUnary" +_UNARY_STREAM = "/test/UnaryStream" +_STREAM_UNARY = "/test/StreamUnary" +_STREAM_STREAM = "/test/StreamStream" STREAM_LENGTH = 5 -CONFIG_ENV_VAR_NAME = 'GRPC_GCP_OBSERVABILITY_CONFIG' -CONFIG_FILE_ENV_VAR_NAME = 'GRPC_GCP_OBSERVABILITY_CONFIG_FILE' +CONFIG_ENV_VAR_NAME = "GRPC_GCP_OBSERVABILITY_CONFIG" +CONFIG_FILE_ENV_VAR_NAME = "GRPC_GCP_OBSERVABILITY_CONFIG_FILE" _VALID_CONFIG_TRACING_STATS = { - 'project_id': 'test-project', - 'cloud_trace': { - 'sampling_rate': 1.00 - }, - 'cloud_monitoring': {} + "project_id": "test-project", + "cloud_trace": {"sampling_rate": 1.00}, + "cloud_monitoring": {}, } _VALID_CONFIG_TRACING_ONLY = { - 'project_id': 'test-project', - 'cloud_trace': { - 'sampling_rate': 1.00 - }, + "project_id": "test-project", + "cloud_trace": {"sampling_rate": 1.00}, } _VALID_CONFIG_STATS_ONLY = { - 'project_id': 'test-project', - 'cloud_monitoring': {} + "project_id": "test-project", + "cloud_monitoring": {}, } _VALID_CONFIG_STATS_ONLY_STR = """ { @@ -65,23 +61,27 @@ # Depends on grpc_core::IsTransportSuppliesClientLatencyEnabled, # the following metrcis might not exist. 
_SKIP_VEFIRY = [_cyobservability.MetricsName.CLIENT_TRANSPORT_LATENCY] -_SPAN_PREFIXS = ['Recv', 'Sent', 'Attempt'] +_SPAN_PREFIXS = ["Recv", "Sent", "Attempt"] class TestExporter(_observability.Exporter): - - def __init__(self, metrics: List[_observability.StatsData], - spans: List[_observability.TracingData]): + def __init__( + self, + metrics: List[_observability.StatsData], + spans: List[_observability.TracingData], + ): self.span_collecter = spans self.metric_collecter = metrics self._server = None - def export_stats_data(self, - stats_data: List[_observability.StatsData]) -> None: + def export_stats_data( + self, stats_data: List[_observability.StatsData] + ) -> None: self.metric_collecter.extend(stats_data) def export_tracing_data( - self, tracing_data: List[_observability.TracingData]) -> None: + self, tracing_data: List[_observability.TracingData] + ) -> None: self.span_collecter.extend(tracing_data) @@ -104,7 +104,6 @@ def handle_stream_stream(request_iterator, servicer_context): class _MethodHandler(grpc.RpcMethodHandler): - def __init__(self, request_streaming, response_streaming): self.request_streaming = request_streaming self.response_streaming = response_streaming @@ -125,7 +124,6 @@ def __init__(self, request_streaming, response_streaming): class _GenericHandler(grpc.GenericRpcHandler): - def service(self, handler_call_details): if handler_call_details.method == _UNARY_UNARY: return _MethodHandler(False, False) @@ -140,7 +138,6 @@ def service(self, handler_call_details): class ObservabilityTest(unittest.TestCase): - def setUp(self): self.all_metric = [] self.all_span = [] @@ -149,15 +146,16 @@ def setUp(self): self._port = None def tearDown(self): - os.environ[CONFIG_ENV_VAR_NAME] = '' - os.environ[CONFIG_FILE_ENV_VAR_NAME] = '' + os.environ[CONFIG_ENV_VAR_NAME] = "" + os.environ[CONFIG_FILE_ENV_VAR_NAME] = "" if self._server: self._server.stop(0) def testRecordUnaryUnary(self): self._set_config_file(_VALID_CONFIG_TRACING_STATS) with grpc_observability.GCPOpenCensusObservability( - exporter=self.test_exporter): + exporter=self.test_exporter + ): self._start_server() self.unary_unary_call() @@ -169,24 +167,27 @@ def testRecordUnaryUnary(self): def testThrowErrorWithoutConfig(self): with self.assertRaises(ValueError): with grpc_observability.GCPOpenCensusObservability( - exporter=self.test_exporter): + exporter=self.test_exporter + ): pass def testThrowErrorWithInvalidConfig(self): - _INVALID_CONFIG = 'INVALID' + _INVALID_CONFIG = "INVALID" self._set_config_file(_INVALID_CONFIG) with self.assertRaises(ValueError): with grpc_observability.GCPOpenCensusObservability( - exporter=self.test_exporter): + exporter=self.test_exporter + ): pass def testNoErrorAndDataWithEmptyConfig(self): _EMPTY_CONFIG = {} self._set_config_file(_EMPTY_CONFIG) # Empty config still require project_id - os.environ['GCP_PROJECT'] = 'test-project' + os.environ["GCP_PROJECT"] = "test-project" with grpc_observability.GCPOpenCensusObservability( - exporter=self.test_exporter): + exporter=self.test_exporter + ): self._start_server() self.unary_unary_call() @@ -197,13 +198,15 @@ def testThrowErrorWhenCallingMultipleInit(self): self._set_config_file(_VALID_CONFIG_TRACING_STATS) with self.assertRaises(ValueError): with grpc_observability.GCPOpenCensusObservability( - exporter=self.test_exporter) as o11y: + exporter=self.test_exporter + ) as o11y: grpc._observability.observability_init(o11y) def testRecordUnaryUnaryStatsOnly(self): self._set_config_file(_VALID_CONFIG_STATS_ONLY) with 
grpc_observability.GCPOpenCensusObservability( - exporter=self.test_exporter): + exporter=self.test_exporter + ): self._start_server() self.unary_unary_call() @@ -214,7 +217,8 @@ def testRecordUnaryUnaryStatsOnly(self): def testRecordUnaryUnaryTracingOnly(self): self._set_config_file(_VALID_CONFIG_TRACING_ONLY) with grpc_observability.GCPOpenCensusObservability( - exporter=self.test_exporter): + exporter=self.test_exporter + ): self._start_server() self.unary_unary_call() @@ -225,7 +229,8 @@ def testRecordUnaryUnaryTracingOnly(self): def testRecordUnaryStream(self): self._set_config_file(_VALID_CONFIG_TRACING_STATS) with grpc_observability.GCPOpenCensusObservability( - exporter=self.test_exporter): + exporter=self.test_exporter + ): self._start_server() self.unary_stream_call() @@ -237,7 +242,8 @@ def testRecordUnaryStream(self): def testRecordStreamUnary(self): self._set_config_file(_VALID_CONFIG_TRACING_STATS) with grpc_observability.GCPOpenCensusObservability( - exporter=self.test_exporter): + exporter=self.test_exporter + ): self._start_server() self.stream_unary_call() @@ -249,7 +255,8 @@ def testRecordStreamUnary(self): def testRecordStreamStream(self): self._set_config_file(_VALID_CONFIG_TRACING_STATS) with grpc_observability.GCPOpenCensusObservability( - exporter=self.test_exporter): + exporter=self.test_exporter + ): self._start_server() self.stream_stream_call() @@ -267,7 +274,8 @@ def testNoRecordBeforeInit(self): self._server.stop(0) with grpc_observability.GCPOpenCensusObservability( - exporter=self.test_exporter): + exporter=self.test_exporter + ): self._start_server() self.unary_unary_call() @@ -279,7 +287,8 @@ def testNoRecordBeforeInit(self): def testNoRecordAfterExit(self): self._set_config_file(_VALID_CONFIG_TRACING_STATS) with grpc_observability.GCPOpenCensusObservability( - exporter=self.test_exporter): + exporter=self.test_exporter + ): self._start_server() self.unary_unary_call() @@ -302,14 +311,13 @@ def testTraceSamplingRate(self): _LOWER_BOUND = 10 * 3 _HIGHER_BOUND = 30 * 3 _VALID_CONFIG_TRACING_ONLY_SAMPLE_HALF = { - 'project_id': 'test-project', - 'cloud_trace': { - 'sampling_rate': 0.5 - }, + "project_id": "test-project", + "cloud_trace": {"sampling_rate": 0.5}, } self._set_config_file(_VALID_CONFIG_TRACING_ONLY_SAMPLE_HALF) with grpc_observability.GCPOpenCensusObservability( - exporter=self.test_exporter): + exporter=self.test_exporter + ): self._start_server() for _ in range(_CALLS): self.unary_unary_call() @@ -326,7 +334,8 @@ def testConfigFileOverEnvVar(self): self._set_config_file(_VALID_CONFIG_TRACING_ONLY) with grpc_observability.GCPOpenCensusObservability( - exporter=self.test_exporter): + exporter=self.test_exporter + ): self._start_server() self.unary_unary_call() @@ -336,31 +345,32 @@ def testConfigFileOverEnvVar(self): def _set_config_file(self, config: Dict[str, Any]) -> None: # Using random name here so multiple tests can run with different config files. 
- config_file_path = '/tmp/' + str(random.randint(0, 100000)) - with open(config_file_path, 'w', encoding='utf-8') as f: + config_file_path = "/tmp/" + str(random.randint(0, 100000)) + with open(config_file_path, "w", encoding="utf-8") as f: f.write(json.dumps(config)) os.environ[CONFIG_FILE_ENV_VAR_NAME] = config_file_path def unary_unary_call(self): - with grpc.insecure_channel(f'localhost:{self._port}') as channel: + with grpc.insecure_channel(f"localhost:{self._port}") as channel: multi_callable = channel.unary_unary(_UNARY_UNARY) unused_response, call = multi_callable.with_call(_REQUEST) def unary_stream_call(self): - with grpc.insecure_channel(f'localhost:{self._port}') as channel: + with grpc.insecure_channel(f"localhost:{self._port}") as channel: multi_callable = channel.unary_stream(_UNARY_STREAM) call = multi_callable(_REQUEST) for _ in call: pass def stream_unary_call(self): - with grpc.insecure_channel(f'localhost:{self._port}') as channel: + with grpc.insecure_channel(f"localhost:{self._port}") as channel: multi_callable = channel.stream_unary(_STREAM_UNARY) unused_response, call = multi_callable.with_call( - iter([_REQUEST] * STREAM_LENGTH)) + iter([_REQUEST] * STREAM_LENGTH) + ) def stream_stream_call(self): - with grpc.insecure_channel(f'localhost:{self._port}') as channel: + with grpc.insecure_channel(f"localhost:{self._port}") as channel: multi_callable = channel.stream_stream(_STREAM_STREAM) call = multi_callable(iter([_REQUEST] * STREAM_LENGTH)) for _ in call: @@ -369,30 +379,37 @@ def stream_stream_call(self): def _start_server(self) -> None: self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) self._server.add_generic_rpc_handlers((_GenericHandler(),)) - self._port = self._server.add_insecure_port('[::]:0') + self._port = self._server.add_insecure_port("[::]:0") self._server.start() - def _validate_metrics(self, - metrics: List[_observability.StatsData]) -> None: + def _validate_metrics( + self, metrics: List[_observability.StatsData] + ) -> None: metric_names = set(metric.name for metric in metrics) for name in _cyobservability.MetricsName: if name in _SKIP_VEFIRY: continue if name not in metric_names: - logger.error('metric %s not found in exported metrics: %s!', - name, metric_names) + logger.error( + "metric %s not found in exported metrics: %s!", + name, + metric_names, + ) self.assertTrue(name in metric_names) - def _validate_spans(self, - tracing_data: List[_observability.TracingData]) -> None: + def _validate_spans( + self, tracing_data: List[_observability.TracingData] + ) -> None: span_names = set(data.name for data in tracing_data) for prefix in _SPAN_PREFIXS: prefix_exist = any(prefix in name for name in span_names) if not prefix_exist: logger.error( - 'missing span with prefix %s in exported spans: %s!', - prefix, span_names) + "missing span with prefix %s in exported spans: %s!", + prefix, + span_names, + ) self.assertTrue(prefix_exist) diff --git a/src/python/grpcio_tests/tests/protoc_plugin/_python_plugin_test.py b/src/python/grpcio_tests/tests/protoc_plugin/_python_plugin_test.py index 1b7adb01d1fba..1dda7149d654c 100644 --- a/src/python/grpcio_tests/tests/protoc_plugin/_python_plugin_test.py +++ b/src/python/grpcio_tests/tests/protoc_plugin/_python_plugin_test.py @@ -36,13 +36,12 @@ from tests.unit.framework.common import test_constants # Identifiers of entities we expect to find in the generated module. 
-STUB_IDENTIFIER = 'TestServiceStub' -SERVICER_IDENTIFIER = 'TestServiceServicer' -ADD_SERVICER_TO_SERVER_IDENTIFIER = 'add_TestServiceServicer_to_server' +STUB_IDENTIFIER = "TestServiceStub" +SERVICER_IDENTIFIER = "TestServiceServicer" +ADD_SERVICER_TO_SERVER_IDENTIFIER = "add_TestServiceServicer_to_server" class _ServicerMethods(object): - def __init__(self): self._condition = threading.Condition() self._paused = False @@ -75,7 +74,7 @@ def _control(self): # pylint: disable=invalid-name def UnaryCall(self, request, unused_rpc_context): response = response_pb2.SimpleResponse() response.payload.payload_type = payload_pb2.COMPRESSABLE - response.payload.payload_compressable = 'a' * request.response_size + response.payload.payload_compressable = "a" * request.response_size self._control() return response @@ -83,7 +82,7 @@ def StreamingOutputCall(self, request, unused_rpc_context): for parameter in request.response_parameters: response = response_pb2.StreamingOutputCallResponse() response.payload.payload_type = payload_pb2.COMPRESSABLE - response.payload.payload_compressable = 'a' * parameter.size + response.payload.payload_compressable = "a" * parameter.size self._control() yield response @@ -101,7 +100,7 @@ def FullDuplexCall(self, request_iter, unused_rpc_context): for parameter in request.response_parameters: response = response_pb2.StreamingOutputCallResponse() response.payload.payload_type = payload_pb2.COMPRESSABLE - response.payload.payload_compressable = 'a' * parameter.size + response.payload.payload_compressable = "a" * parameter.size self._control() yield response @@ -111,7 +110,7 @@ def HalfDuplexCall(self, request_iter, unused_rpc_context): for parameter in request.response_parameters: response = response_pb2.StreamingOutputCallResponse() response.payload.payload_type = payload_pb2.COMPRESSABLE - response.payload.payload_compressable = 'a' * parameter.size + response.payload.payload_compressable = "a" * parameter.size self._control() responses.append(response) for response in responses: @@ -119,30 +118,33 @@ def HalfDuplexCall(self, request_iter, unused_rpc_context): class _Service( - collections.namedtuple('_Service', ( - 'servicer_methods', - 'server', - 'stub', - ))): + collections.namedtuple( + "_Service", + ( + "servicer_methods", + "server", + "stub", + ), + ) +): """A live and running service. - Attributes: - servicer_methods: The _ServicerMethods servicing RPCs. - server: The grpc.Server servicing RPCs. - stub: A stub on which to invoke RPCs. - """ + Attributes: + servicer_methods: The _ServicerMethods servicing RPCs. + server: The grpc.Server servicing RPCs. + stub: A stub on which to invoke RPCs. + """ def _CreateService(): """Provides a servicer backend and a stub. - Returns: - A _Service with which to test RPCs. - """ + Returns: + A _Service with which to test RPCs. 
+ """ servicer_methods = _ServicerMethods() class Servicer(getattr(service_pb2_grpc, SERVICER_IDENTIFIER)): - def UnaryCall(self, request, context): return servicer_methods.UnaryCall(request, context) @@ -150,8 +152,9 @@ def StreamingOutputCall(self, request, context): return servicer_methods.StreamingOutputCall(request, context) def StreamingInputCall(self, request_iterator, context): - return servicer_methods.StreamingInputCall(request_iterator, - context) + return servicer_methods.StreamingInputCall( + request_iterator, context + ) def FullDuplexCall(self, request_iterator, context): return servicer_methods.FullDuplexCall(request_iterator, context) @@ -160,11 +163,12 @@ def HalfDuplexCall(self, request_iterator, context): return servicer_methods.HalfDuplexCall(request_iterator, context) server = test_common.test_server() - getattr(service_pb2_grpc, ADD_SERVICER_TO_SERVER_IDENTIFIER)(Servicer(), - server) - port = server.add_insecure_port('[::]:0') + getattr(service_pb2_grpc, ADD_SERVICER_TO_SERVER_IDENTIFIER)( + Servicer(), server + ) + port = server.add_insecure_port("[::]:0") server.start() - channel = grpc.insecure_channel('localhost:{}'.format(port)) + channel = grpc.insecure_channel("localhost:{}".format(port)) stub = getattr(service_pb2_grpc, STUB_IDENTIFIER)(channel) return _Service(servicer_methods, server, stub) @@ -172,20 +176,21 @@ def HalfDuplexCall(self, request_iterator, context): def _CreateIncompleteService(): """Provides a servicer backend that fails to implement methods and its stub. - Returns: - A _Service with which to test RPCs. The returned _Service's - servicer_methods implements none of the methods required of it. - """ + Returns: + A _Service with which to test RPCs. The returned _Service's + servicer_methods implements none of the methods required of it. + """ class Servicer(getattr(service_pb2_grpc, SERVICER_IDENTIFIER)): pass server = test_common.test_server() - getattr(service_pb2_grpc, ADD_SERVICER_TO_SERVER_IDENTIFIER)(Servicer(), - server) - port = server.add_insecure_port('[::]:0') + getattr(service_pb2_grpc, ADD_SERVICER_TO_SERVER_IDENTIFIER)( + Servicer(), server + ) + port = server.add_insecure_port("[::]:0") server.start() - channel = grpc.insecure_channel('localhost:{}'.format(port)) + channel = grpc.insecure_channel("localhost:{}".format(port)) stub = getattr(service_pb2_grpc, STUB_IDENTIFIER)(channel) return _Service(None, server, stub) @@ -194,7 +199,7 @@ def _streaming_input_request_iterator(): for _ in range(3): request = request_pb2.StreamingInputCallRequest() request.payload.payload_type = payload_pb2.COMPRESSABLE - request.payload.payload_compressable = 'a' + request.payload.payload_compressable = "a" yield request @@ -220,18 +225,20 @@ def _full_duplex_request_iterator(): class PythonPluginTest(unittest.TestCase): """Test case for the gRPC Python protoc-plugin. - While reading these tests, remember that the futures API - (`stub.method.future()`) only gives futures for the *response-unary* - methods and does not exist for response-streaming methods. - """ + While reading these tests, remember that the futures API + (`stub.method.future()`) only gives futures for the *response-unary* + methods and does not exist for response-streaming methods. + """ def testImportAttributes(self): # check that we can access the generated module and its members. 
self.assertIsNotNone(getattr(service_pb2_grpc, STUB_IDENTIFIER, None)) self.assertIsNotNone( - getattr(service_pb2_grpc, SERVICER_IDENTIFIER, None)) + getattr(service_pb2_grpc, SERVICER_IDENTIFIER, None) + ) self.assertIsNotNone( - getattr(service_pb2_grpc, ADD_SERVICER_TO_SERVER_IDENTIFIER, None)) + getattr(service_pb2_grpc, ADD_SERVICER_TO_SERVER_IDENTIFIER, None) + ) def testUpDown(self): service = _CreateService() @@ -245,8 +252,9 @@ def testIncompleteServicer(self): request = request_pb2.SimpleRequest(response_size=13) with self.assertRaises(grpc.RpcError) as exception_context: service.stub.UnaryCall(request) - self.assertIs(exception_context.exception.code(), - grpc.StatusCode.UNIMPLEMENTED) + self.assertIs( + exception_context.exception.code(), grpc.StatusCode.UNIMPLEMENTED + ) service.server.stop(None) def testUnaryCall(self): @@ -254,7 +262,8 @@ def testUnaryCall(self): request = request_pb2.SimpleRequest(response_size=13) response = service.stub.UnaryCall(request) expected_response = service.servicer_methods.UnaryCall( - request, 'not a real context!') + request, "not a real context!" + ) self.assertEqual(expected_response, response) service.server.stop(None) @@ -266,7 +275,8 @@ def testUnaryCallFuture(self): response_future = service.stub.UnaryCall.future(request) response = response_future.result() expected_response = service.servicer_methods.UnaryCall( - request, 'not a real RpcContext!') + request, "not a real RpcContext!" + ) self.assertEqual(expected_response, response) service.server.stop(None) @@ -275,11 +285,14 @@ def testUnaryCallFutureExpired(self): request = request_pb2.SimpleRequest(response_size=13) with service.servicer_methods.pause(): response_future = service.stub.UnaryCall.future( - request, timeout=test_constants.SHORT_TIMEOUT) + request, timeout=test_constants.SHORT_TIMEOUT + ) with self.assertRaises(grpc.RpcError) as exception_context: response_future.result() - self.assertIs(exception_context.exception.code(), - grpc.StatusCode.DEADLINE_EXCEEDED) + self.assertIs( + exception_context.exception.code(), + grpc.StatusCode.DEADLINE_EXCEEDED, + ) self.assertIs(response_future.code(), grpc.StatusCode.DEADLINE_EXCEEDED) service.server.stop(None) @@ -307,9 +320,11 @@ def testStreamingOutputCall(self): request = _streaming_output_request() responses = service.stub.StreamingOutputCall(request) expected_responses = service.servicer_methods.StreamingOutputCall( - request, 'not a real RpcContext!') + request, "not a real RpcContext!" 
+ ) for expected_response, response in itertools.zip_longest( - expected_responses, responses): + expected_responses, responses + ): self.assertEqual(expected_response, response) service.server.stop(None) @@ -318,11 +333,14 @@ def testStreamingOutputCallExpired(self): request = _streaming_output_request() with service.servicer_methods.pause(): responses = service.stub.StreamingOutputCall( - request, timeout=test_constants.SHORT_TIMEOUT) + request, timeout=test_constants.SHORT_TIMEOUT + ) with self.assertRaises(grpc.RpcError) as exception_context: list(responses) - self.assertIs(exception_context.exception.code(), - grpc.StatusCode.DEADLINE_EXCEEDED) + self.assertIs( + exception_context.exception.code(), + grpc.StatusCode.DEADLINE_EXCEEDED, + ) service.server.stop(None) def testStreamingOutputCallCancelled(self): @@ -344,16 +362,19 @@ def testStreamingOutputCallFailed(self): self.assertIsNotNone(responses) with self.assertRaises(grpc.RpcError) as exception_context: next(responses) - self.assertIs(exception_context.exception.code(), - grpc.StatusCode.UNKNOWN) + self.assertIs( + exception_context.exception.code(), grpc.StatusCode.UNKNOWN + ) service.server.stop(None) def testStreamingInputCall(self): service = _CreateService() response = service.stub.StreamingInputCall( - _streaming_input_request_iterator()) + _streaming_input_request_iterator() + ) expected_response = service.servicer_methods.StreamingInputCall( - _streaming_input_request_iterator(), 'not a real RpcContext!') + _streaming_input_request_iterator(), "not a real RpcContext!" + ) self.assertEqual(expected_response, response) service.server.stop(None) @@ -361,10 +382,12 @@ def testStreamingInputCallFuture(self): service = _CreateService() with service.servicer_methods.pause(): response_future = service.stub.StreamingInputCall.future( - _streaming_input_request_iterator()) + _streaming_input_request_iterator() + ) response = response_future.result() expected_response = service.servicer_methods.StreamingInputCall( - _streaming_input_request_iterator(), 'not a real RpcContext!') + _streaming_input_request_iterator(), "not a real RpcContext!" 
+ ) self.assertEqual(expected_response, response) service.server.stop(None) @@ -373,21 +396,27 @@ def testStreamingInputCallFutureExpired(self): with service.servicer_methods.pause(): response_future = service.stub.StreamingInputCall.future( _streaming_input_request_iterator(), - timeout=test_constants.SHORT_TIMEOUT) + timeout=test_constants.SHORT_TIMEOUT, + ) with self.assertRaises(grpc.RpcError) as exception_context: response_future.result() self.assertIsInstance(response_future.exception(), grpc.RpcError) - self.assertIs(response_future.exception().code(), - grpc.StatusCode.DEADLINE_EXCEEDED) - self.assertIs(exception_context.exception.code(), - grpc.StatusCode.DEADLINE_EXCEEDED) + self.assertIs( + response_future.exception().code(), + grpc.StatusCode.DEADLINE_EXCEEDED, + ) + self.assertIs( + exception_context.exception.code(), + grpc.StatusCode.DEADLINE_EXCEEDED, + ) service.server.stop(None) def testStreamingInputCallFutureCancelled(self): service = _CreateService() with service.servicer_methods.pause(): response_future = service.stub.StreamingInputCall.future( - _streaming_input_request_iterator()) + _streaming_input_request_iterator() + ) response_future.cancel() self.assertTrue(response_future.cancelled()) with self.assertRaises(grpc.FutureCancelledError): @@ -398,7 +427,8 @@ def testStreamingInputCallFutureFailed(self): service = _CreateService() with service.servicer_methods.fail(): response_future = service.stub.StreamingInputCall.future( - _streaming_input_request_iterator()) + _streaming_input_request_iterator() + ) self.assertIsNotNone(response_future.exception()) self.assertIs(response_future.code(), grpc.StatusCode.UNKNOWN) service.server.stop(None) @@ -407,9 +437,11 @@ def testFullDuplexCall(self): service = _CreateService() responses = service.stub.FullDuplexCall(_full_duplex_request_iterator()) expected_responses = service.servicer_methods.FullDuplexCall( - _full_duplex_request_iterator(), 'not a real RpcContext!') + _full_duplex_request_iterator(), "not a real RpcContext!" 
+ ) for expected_response, response in itertools.zip_longest( - expected_responses, responses): + expected_responses, responses + ): self.assertEqual(expected_response, response) service.server.stop(None) @@ -418,11 +450,14 @@ def testFullDuplexCallExpired(self): service = _CreateService() with service.servicer_methods.pause(): responses = service.stub.FullDuplexCall( - request_iterator, timeout=test_constants.SHORT_TIMEOUT) + request_iterator, timeout=test_constants.SHORT_TIMEOUT + ) with self.assertRaises(grpc.RpcError) as exception_context: list(responses) - self.assertIs(exception_context.exception.code(), - grpc.StatusCode.DEADLINE_EXCEEDED) + self.assertIs( + exception_context.exception.code(), + grpc.StatusCode.DEADLINE_EXCEEDED, + ) service.server.stop(None) def testFullDuplexCallCancelled(self): @@ -433,8 +468,9 @@ def testFullDuplexCallCancelled(self): responses.cancel() with self.assertRaises(grpc.RpcError) as exception_context: next(responses) - self.assertIs(exception_context.exception.code(), - grpc.StatusCode.CANCELLED) + self.assertIs( + exception_context.exception.code(), grpc.StatusCode.CANCELLED + ) service.server.stop(None) def testFullDuplexCallFailed(self): @@ -444,8 +480,9 @@ def testFullDuplexCallFailed(self): responses = service.stub.FullDuplexCall(request_iterator) with self.assertRaises(grpc.RpcError) as exception_context: next(responses) - self.assertIs(exception_context.exception.code(), - grpc.StatusCode.UNKNOWN) + self.assertIs( + exception_context.exception.code(), grpc.StatusCode.UNKNOWN + ) service.server.stop(None) def testHalfDuplexCall(self): @@ -462,9 +499,11 @@ def half_duplex_request_iterator(): responses = service.stub.HalfDuplexCall(half_duplex_request_iterator()) expected_responses = service.servicer_methods.HalfDuplexCall( - half_duplex_request_iterator(), 'not a real RpcContext!') + half_duplex_request_iterator(), "not a real RpcContext!" 
+ ) for expected_response, response in itertools.zip_longest( - expected_responses, responses): + expected_responses, responses + ): self.assertEqual(expected_response, response) service.server.stop(None) @@ -494,50 +533,60 @@ def half_duplex_request_iterator(): with wait(): responses = service.stub.HalfDuplexCall( half_duplex_request_iterator(), - timeout=test_constants.SHORT_TIMEOUT) + timeout=test_constants.SHORT_TIMEOUT, + ) # half-duplex waits for the client to send all info with self.assertRaises(grpc.RpcError) as exception_context: next(responses) - self.assertIs(exception_context.exception.code(), - grpc.StatusCode.DEADLINE_EXCEEDED) + self.assertIs( + exception_context.exception.code(), + grpc.StatusCode.DEADLINE_EXCEEDED, + ) service.server.stop(None) -@unittest.skipIf(sys.version_info[0] < 3 or sys.version_info[1] < 6, - "Unsupported on Python 2.") +@unittest.skipIf( + sys.version_info[0] < 3 or sys.version_info[1] < 6, + "Unsupported on Python 2.", +) class SimpleStubsPluginTest(unittest.TestCase): servicer_methods = _ServicerMethods() class Servicer(service_pb2_grpc.TestServiceServicer): - def UnaryCall(self, request, context): return SimpleStubsPluginTest.servicer_methods.UnaryCall( - request, context) + request, context + ) def StreamingOutputCall(self, request, context): return SimpleStubsPluginTest.servicer_methods.StreamingOutputCall( - request, context) + request, context + ) def StreamingInputCall(self, request_iterator, context): return SimpleStubsPluginTest.servicer_methods.StreamingInputCall( - request_iterator, context) + request_iterator, context + ) def FullDuplexCall(self, request_iterator, context): return SimpleStubsPluginTest.servicer_methods.FullDuplexCall( - request_iterator, context) + request_iterator, context + ) def HalfDuplexCall(self, request_iterator, context): return SimpleStubsPluginTest.servicer_methods.HalfDuplexCall( - request_iterator, context) + request_iterator, context + ) def setUp(self): super(SimpleStubsPluginTest, self).setUp() self._server = test_common.test_server() service_pb2_grpc.add_TestServiceServicer_to_server( - self.Servicer(), self._server) - self._port = self._server.add_insecure_port('[::]:0') + self.Servicer(), self._server + ) + self._port = self._server.add_insecure_port("[::]:0") self._server.start() - self._target = 'localhost:{}'.format(self._port) + self._target = "localhost:{}".format(self._port) def tearDown(self): self._server.stop(None) @@ -548,63 +597,68 @@ def testUnaryCall(self): response = service_pb2_grpc.TestService.UnaryCall( request, self._target, - channel_credentials=grpc.experimental.insecure_channel_credentials( - ), - wait_for_ready=True) + channel_credentials=grpc.experimental.insecure_channel_credentials(), + wait_for_ready=True, + ) expected_response = self.servicer_methods.UnaryCall( - request, 'not a real context!') + request, "not a real context!" + ) self.assertEqual(expected_response, response) def testUnaryCallInsecureSugar(self): request = request_pb2.SimpleRequest(response_size=13) - response = service_pb2_grpc.TestService.UnaryCall(request, - self._target, - insecure=True, - wait_for_ready=True) + response = service_pb2_grpc.TestService.UnaryCall( + request, self._target, insecure=True, wait_for_ready=True + ) expected_response = self.servicer_methods.UnaryCall( - request, 'not a real context!') + request, "not a real context!" 
+ ) self.assertEqual(expected_response, response) def testStreamingOutputCall(self): request = _streaming_output_request() expected_responses = self.servicer_methods.StreamingOutputCall( - request, 'not a real RpcContext!') + request, "not a real RpcContext!" + ) responses = service_pb2_grpc.TestService.StreamingOutputCall( request, self._target, - channel_credentials=grpc.experimental.insecure_channel_credentials( - ), - wait_for_ready=True) + channel_credentials=grpc.experimental.insecure_channel_credentials(), + wait_for_ready=True, + ) for expected_response, response in itertools.zip_longest( - expected_responses, responses): + expected_responses, responses + ): self.assertEqual(expected_response, response) def testStreamingInputCall(self): response = service_pb2_grpc.TestService.StreamingInputCall( _streaming_input_request_iterator(), self._target, - channel_credentials=grpc.experimental.insecure_channel_credentials( - ), - wait_for_ready=True) + channel_credentials=grpc.experimental.insecure_channel_credentials(), + wait_for_ready=True, + ) expected_response = self.servicer_methods.StreamingInputCall( - _streaming_input_request_iterator(), 'not a real RpcContext!') + _streaming_input_request_iterator(), "not a real RpcContext!" + ) self.assertEqual(expected_response, response) def testFullDuplexCall(self): responses = service_pb2_grpc.TestService.FullDuplexCall( _full_duplex_request_iterator(), self._target, - channel_credentials=grpc.experimental.insecure_channel_credentials( - ), - wait_for_ready=True) + channel_credentials=grpc.experimental.insecure_channel_credentials(), + wait_for_ready=True, + ) expected_responses = self.servicer_methods.FullDuplexCall( - _full_duplex_request_iterator(), 'not a real RpcContext!') + _full_duplex_request_iterator(), "not a real RpcContext!" + ) for expected_response, response in itertools.zip_longest( - expected_responses, responses): + expected_responses, responses + ): self.assertEqual(expected_response, response) def testHalfDuplexCall(self): - def half_duplex_request_iterator(): request = request_pb2.StreamingOutputCallRequest() request.response_parameters.add(size=1, interval_us=0) @@ -617,37 +671,48 @@ def half_duplex_request_iterator(): responses = service_pb2_grpc.TestService.HalfDuplexCall( half_duplex_request_iterator(), self._target, - channel_credentials=grpc.experimental.insecure_channel_credentials( - ), - wait_for_ready=True) + channel_credentials=grpc.experimental.insecure_channel_credentials(), + wait_for_ready=True, + ) expected_responses = self.servicer_methods.HalfDuplexCall( - half_duplex_request_iterator(), 'not a real RpcContext!') + half_duplex_request_iterator(), "not a real RpcContext!" + ) for expected_response, response in itertools.zip_longest( - expected_responses, responses): + expected_responses, responses + ): self.assertEqual(expected_response, response) class ModuleMainTest(unittest.TestCase): - """Test case for running `python -m grpc_tools.protoc`. - """ + """Test case for running `python -m grpc_tools.protoc`.""" def test_clean_output(self): if sys.executable is None: raise unittest.SkipTest( - "Running on a interpreter that cannot be invoked from the CLI.") + "Running on a interpreter that cannot be invoked from the CLI." 
+ ) proto_dir_path = os.path.join("src", "proto") - test_proto_path = os.path.join(proto_dir_path, "grpc", "testing", - "empty.proto") + test_proto_path = os.path.join( + proto_dir_path, "grpc", "testing", "empty.proto" + ) streams = tuple(tempfile.TemporaryFile() for _ in range(2)) work_dir = tempfile.mkdtemp() try: - invocation = (sys.executable, "-m", "grpc_tools.protoc", - "--proto_path", proto_dir_path, "--python_out", - work_dir, "--grpc_python_out", work_dir, - test_proto_path) - proc = subprocess.Popen(invocation, - stdout=streams[0], - stderr=streams[1]) + invocation = ( + sys.executable, + "-m", + "grpc_tools.protoc", + "--proto_path", + proto_dir_path, + "--python_out", + work_dir, + "--grpc_python_out", + work_dir, + test_proto_path, + ) + proc = subprocess.Popen( + invocation, stdout=streams[0], stderr=streams[1] + ) proc.wait() outs = [] for stream in streams: @@ -658,5 +723,5 @@ def test_clean_output(self): shutil.rmtree(work_dir) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/protoc_plugin/_split_definitions_test.py b/src/python/grpcio_tests/tests/protoc_plugin/_split_definitions_test.py index 19a295a0467c5..9b8c3acf94870 100644 --- a/src/python/grpcio_tests/tests/protoc_plugin/_split_definitions_test.py +++ b/src/python/grpcio_tests/tests/protoc_plugin/_split_definitions_test.py @@ -31,11 +31,11 @@ from tests.unit import test_common _MESSAGES_IMPORT = b'import "messages.proto";' -_SPLIT_NAMESPACE = b'package grpc_protoc_plugin.invocation_testing.split;' -_COMMON_NAMESPACE = b'package grpc_protoc_plugin.invocation_testing;' +_SPLIT_NAMESPACE = b"package grpc_protoc_plugin.invocation_testing.split;" +_COMMON_NAMESPACE = b"package grpc_protoc_plugin.invocation_testing;" -_RELATIVE_PROTO_PATH = 'relative_proto_path' -_RELATIVE_PYTHON_OUT = 'relative_python_out' +_RELATIVE_PROTO_PATH = "relative_proto_path" +_RELATIVE_PYTHON_OUT = "relative_python_out" _TEST_DIR = os.path.dirname(os.path.realpath(__file__)) @@ -53,7 +53,7 @@ def _system_path(path_insertion): def _create_directory_tree(root, path_components_sequence): created = set() for path_components in path_components_sequence: - thus_far = '' + thus_far = "" for path_component in path_components: relative_path = path.join(thus_far, path_component) if relative_path not in created: @@ -62,29 +62,35 @@ def _create_directory_tree(root, path_components_sequence): thus_far = path.join(thus_far, path_component) -def _massage_proto_content(proto_content, test_name_bytes, - messages_proto_relative_file_name_bytes): - package_substitution = (b'package grpc_protoc_plugin.invocation_testing.' + - test_name_bytes + b';') +def _massage_proto_content( + proto_content, test_name_bytes, messages_proto_relative_file_name_bytes +): + package_substitution = ( + b"package grpc_protoc_plugin.invocation_testing." 
+ + test_name_bytes + + b";" + ) common_namespace_substituted = proto_content.replace( - _COMMON_NAMESPACE, package_substitution) + _COMMON_NAMESPACE, package_substitution + ) split_namespace_substituted = common_namespace_substituted.replace( - _SPLIT_NAMESPACE, package_substitution) + _SPLIT_NAMESPACE, package_substitution + ) message_import_replaced = split_namespace_substituted.replace( _MESSAGES_IMPORT, - b'import "' + messages_proto_relative_file_name_bytes + b'";') + b'import "' + messages_proto_relative_file_name_bytes + b'";', + ) return message_import_replaced def _packagify(directory): for subdirectory, _, _ in os.walk(directory): - init_file_name = path.join(subdirectory, '__init__.py') - with open(init_file_name, 'wb') as init_file: - init_file.write(b'') + init_file_name = path.join(subdirectory, "__init__.py") + with open(init_file_name, "wb") as init_file: + init_file.write(b"") class _Servicer(object): - def __init__(self, response_class): self._response_class = response_class @@ -92,78 +98,98 @@ def Call(self, request, context): return self._response_class() -def _protoc(proto_path, python_out, grpc_python_out_flag, grpc_python_out, - absolute_proto_file_names): +def _protoc( + proto_path, + python_out, + grpc_python_out_flag, + grpc_python_out, + absolute_proto_file_names, +): args = [ - '', - '--proto_path={}'.format(proto_path), + "", + "--proto_path={}".format(proto_path), ] if python_out is not None: - args.append('--python_out={}'.format(python_out)) + args.append("--python_out={}".format(python_out)) if grpc_python_out is not None: - args.append('--grpc_python_out={}:{}'.format(grpc_python_out_flag, - grpc_python_out)) + args.append( + "--grpc_python_out={}:{}".format( + grpc_python_out_flag, grpc_python_out + ) + ) args.extend(absolute_proto_file_names) return protoc.main(args) class _Mid2016ProtocStyle(object): - def name(self): - return 'Mid2016ProtocStyle' + return "Mid2016ProtocStyle" def grpc_in_pb2_expected(self): return True def protoc(self, proto_path, python_out, absolute_proto_file_names): - return (_protoc(proto_path, python_out, 'grpc_1_0', python_out, - absolute_proto_file_names),) + return ( + _protoc( + proto_path, + python_out, + "grpc_1_0", + python_out, + absolute_proto_file_names, + ), + ) class _SingleProtocExecutionProtocStyle(object): - def name(self): - return 'SingleProtocExecutionProtocStyle' + return "SingleProtocExecutionProtocStyle" def grpc_in_pb2_expected(self): return False def protoc(self, proto_path, python_out, absolute_proto_file_names): - return (_protoc(proto_path, python_out, 'grpc_2_0', python_out, - absolute_proto_file_names),) + return ( + _protoc( + proto_path, + python_out, + "grpc_2_0", + python_out, + absolute_proto_file_names, + ), + ) class _ProtoBeforeGrpcProtocStyle(object): - def name(self): - return 'ProtoBeforeGrpcProtocStyle' + return "ProtoBeforeGrpcProtocStyle" def grpc_in_pb2_expected(self): return False def protoc(self, proto_path, python_out, absolute_proto_file_names): - pb2_protoc_exit_code = _protoc(proto_path, python_out, None, None, - absolute_proto_file_names) - pb2_grpc_protoc_exit_code = _protoc(proto_path, None, 'grpc_2_0', - python_out, - absolute_proto_file_names) + pb2_protoc_exit_code = _protoc( + proto_path, python_out, None, None, absolute_proto_file_names + ) + pb2_grpc_protoc_exit_code = _protoc( + proto_path, None, "grpc_2_0", python_out, absolute_proto_file_names + ) return pb2_protoc_exit_code, pb2_grpc_protoc_exit_code class _GrpcBeforeProtoProtocStyle(object): - def name(self): - 
return 'GrpcBeforeProtoProtocStyle' + return "GrpcBeforeProtoProtocStyle" def grpc_in_pb2_expected(self): return False def protoc(self, proto_path, python_out, absolute_proto_file_names): - pb2_grpc_protoc_exit_code = _protoc(proto_path, None, 'grpc_2_0', - python_out, - absolute_proto_file_names) - pb2_protoc_exit_code = _protoc(proto_path, python_out, None, None, - absolute_proto_file_names) + pb2_grpc_protoc_exit_code = _protoc( + proto_path, None, "grpc_2_0", python_out, absolute_proto_file_names + ) + pb2_protoc_exit_code = _protoc( + proto_path, python_out, None, None, absolute_proto_file_names + ) return pb2_grpc_protoc_exit_code, pb2_protoc_exit_code @@ -175,12 +201,12 @@ def protoc(self, proto_path, python_out, absolute_proto_file_names): ) -@unittest.skipIf(platform.python_implementation() == 'PyPy', - 'Skip test if run with PyPy!') +@unittest.skipIf( + platform.python_implementation() == "PyPy", "Skip test if run with PyPy!" +) class _Test(unittest.TestCase, metaclass=abc.ABCMeta): - def setUp(self): - self._directory = tempfile.mkdtemp(suffix=self.NAME, dir='.') + self._directory = tempfile.mkdtemp(suffix=self.NAME, dir=".") self._proto_path = path.join(self._directory, _RELATIVE_PROTO_PATH) self._python_out = path.join(self._directory, _RELATIVE_PYTHON_OUT) @@ -197,24 +223,32 @@ def setUp(self): self.SERVICES_PROTO_FILE_NAME, ), } - messages_proto_relative_file_name_forward_slashes = '/'.join( - self.MESSAGES_PROTO_RELATIVE_DIRECTORY_NAMES + - (self.MESSAGES_PROTO_FILE_NAME,)) + messages_proto_relative_file_name_forward_slashes = "/".join( + self.MESSAGES_PROTO_RELATIVE_DIRECTORY_NAMES + + (self.MESSAGES_PROTO_FILE_NAME,) + ) _create_directory_tree( self._proto_path, - (relative_proto_directory_names for relative_proto_directory_names, - _ in proto_directories_and_names)) + ( + relative_proto_directory_names + for relative_proto_directory_names, _ in proto_directories_and_names + ), + ) self._absolute_proto_file_names = set() for relative_directory_names, file_name in proto_directories_and_names: absolute_proto_file_name = path.join( - self._proto_path, *relative_directory_names + (file_name,)) + self._proto_path, *relative_directory_names + (file_name,) + ) raw_proto_content = pkgutil.get_data( - 'tests.protoc_plugin.protos.invocation_testing', - path.join(*relative_directory_names + (file_name,))) + "tests.protoc_plugin.protos.invocation_testing", + path.join(*relative_directory_names + (file_name,)), + ) massaged_proto_content = _massage_proto_content( - raw_proto_content, self.NAME.encode(), - messages_proto_relative_file_name_forward_slashes.encode()) - with open(absolute_proto_file_name, 'wb') as proto_file: + raw_proto_content, + self.NAME.encode(), + messages_proto_relative_file_name_forward_slashes.encode(), + ) + with open(absolute_proto_file_name, "wb") as proto_file: proto_file.write(massaged_proto_content) self._absolute_proto_file_names.add(absolute_proto_file_name) @@ -223,7 +257,8 @@ def tearDown(self): def _protoc(self): protoc_exit_codes = self.PROTOC_STYLE.protoc( - self._proto_path, self._python_out, self._absolute_proto_file_names) + self._proto_path, self._python_out, self._absolute_proto_file_names + ) for protoc_exit_code in protoc_exit_codes: self.assertEqual(0, protoc_exit_code) @@ -243,7 +278,8 @@ def _protoc(self): self._messages_pb2 = generated_modules[self.EXPECTED_MESSAGES_PB2] self._services_pb2 = generated_modules[self.EXPECTED_SERVICES_PB2] self._services_pb2_grpc = generated_modules[ - self.EXPECTED_SERVICES_PB2_GRPC] + 
self.EXPECTED_SERVICES_PB2_GRPC + ] def _services_modules(self): if self.PROTOC_STYLE.grpc_in_pb2_expected(): @@ -256,7 +292,7 @@ def test_imported_attributes(self): self._messages_pb2.Request self._messages_pb2.Response - self._services_pb2.DESCRIPTOR.services_by_name['TestService'] + self._services_pb2.DESCRIPTOR.services_by_name["TestService"] for services_module in self._services_modules(): services_module.TestServiceStub services_module.TestServiceServicer @@ -268,10 +304,11 @@ def test_call(self): for services_module in self._services_modules(): server = test_common.test_server() services_module.add_TestServiceServicer_to_server( - _Servicer(self._messages_pb2.Response), server) - port = server.add_insecure_port('[::]:0') + _Servicer(self._messages_pb2.Response), server + ) + port = server.add_insecure_port("[::]:0") server.start() - channel = grpc.insecure_channel('localhost:{}'.format(port)) + channel = grpc.insecure_channel("localhost:{}".format(port)) stub = services_module.TestServiceStub(channel) response = stub.Call(self._messages_pb2.Request()) self.assertEqual(self._messages_pb2.Response(), response) @@ -281,74 +318,77 @@ def test_call(self): def _create_test_case_class(split_proto, protoc_style): attributes = {} - name = '{}{}'.format('SplitProto' if split_proto else 'SameProto', - protoc_style.name()) - attributes['NAME'] = name + name = "{}{}".format( + "SplitProto" if split_proto else "SameProto", protoc_style.name() + ) + attributes["NAME"] = name if split_proto: - attributes['MESSAGES_PROTO_RELATIVE_DIRECTORY_NAMES'] = ( - 'split_messages', - 'sub', + attributes["MESSAGES_PROTO_RELATIVE_DIRECTORY_NAMES"] = ( + "split_messages", + "sub", ) - attributes['MESSAGES_PROTO_FILE_NAME'] = 'messages.proto' - attributes['SERVICES_PROTO_RELATIVE_DIRECTORY_NAMES'] = ( - 'split_services',) - attributes['SERVICES_PROTO_FILE_NAME'] = 'services.proto' - attributes['EXPECTED_MESSAGES_PB2'] = 'split_messages.sub.messages_pb2' - attributes['EXPECTED_SERVICES_PB2'] = 'split_services.services_pb2' - attributes['EXPECTED_SERVICES_PB2_GRPC'] = ( - 'split_services.services_pb2_grpc') + attributes["MESSAGES_PROTO_FILE_NAME"] = "messages.proto" + attributes["SERVICES_PROTO_RELATIVE_DIRECTORY_NAMES"] = ( + "split_services", + ) + attributes["SERVICES_PROTO_FILE_NAME"] = "services.proto" + attributes["EXPECTED_MESSAGES_PB2"] = "split_messages.sub.messages_pb2" + attributes["EXPECTED_SERVICES_PB2"] = "split_services.services_pb2" + attributes[ + "EXPECTED_SERVICES_PB2_GRPC" + ] = "split_services.services_pb2_grpc" else: - attributes['MESSAGES_PROTO_RELATIVE_DIRECTORY_NAMES'] = () - attributes['MESSAGES_PROTO_FILE_NAME'] = 'same.proto' - attributes['SERVICES_PROTO_RELATIVE_DIRECTORY_NAMES'] = () - attributes['SERVICES_PROTO_FILE_NAME'] = 'same.proto' - attributes['EXPECTED_MESSAGES_PB2'] = 'same_pb2' - attributes['EXPECTED_SERVICES_PB2'] = 'same_pb2' - attributes['EXPECTED_SERVICES_PB2_GRPC'] = 'same_pb2_grpc' + attributes["MESSAGES_PROTO_RELATIVE_DIRECTORY_NAMES"] = () + attributes["MESSAGES_PROTO_FILE_NAME"] = "same.proto" + attributes["SERVICES_PROTO_RELATIVE_DIRECTORY_NAMES"] = () + attributes["SERVICES_PROTO_FILE_NAME"] = "same.proto" + attributes["EXPECTED_MESSAGES_PB2"] = "same_pb2" + attributes["EXPECTED_SERVICES_PB2"] = "same_pb2" + attributes["EXPECTED_SERVICES_PB2_GRPC"] = "same_pb2_grpc" - attributes['PROTOC_STYLE'] = protoc_style + attributes["PROTOC_STYLE"] = protoc_style - attributes['__module__'] = _Test.__module__ + attributes["__module__"] = _Test.__module__ - return 
type('{}Test'.format(name), (_Test,), attributes) + return type("{}Test".format(name), (_Test,), attributes) def _create_test_case_classes(): for split_proto in ( - False, - True, + False, + True, ): for protoc_style in _PROTOC_STYLES: yield _create_test_case_class(split_proto, protoc_style) class WellKnownTypesTest(unittest.TestCase): - def testWellKnownTypes(self): os.chdir(_TEST_DIR) - out_dir = tempfile.mkdtemp(suffix="wkt_test", dir='.') + out_dir = tempfile.mkdtemp(suffix="wkt_test", dir=".") well_known_protos_include = pkg_resources.resource_filename( - 'grpc_tools', '_proto') + "grpc_tools", "_proto" + ) args = [ - 'grpc_tools.protoc', - '--proto_path=protos', - '--proto_path={}'.format(well_known_protos_include), - '--python_out={}'.format(out_dir), - '--grpc_python_out={}'.format(out_dir), - 'protos/invocation_testing/compiler.proto', + "grpc_tools.protoc", + "--proto_path=protos", + "--proto_path={}".format(well_known_protos_include), + "--python_out={}".format(out_dir), + "--grpc_python_out={}".format(out_dir), + "protos/invocation_testing/compiler.proto", ] rc = protoc.main(args) self.assertEqual(0, rc) def load_tests(loader, tests, pattern): - tests = (tuple( + tests = tuple( loader.loadTestsFromTestCase(test_case_class) - for test_case_class in _create_test_case_classes()) + - tuple(loader.loadTestsFromTestCase(WellKnownTypesTest))) + for test_case_class in _create_test_case_classes() + ) + tuple(loader.loadTestsFromTestCase(WellKnownTypesTest)) return unittest.TestSuite(tests=tests) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/protoc_plugin/beta_python_plugin_test.py b/src/python/grpcio_tests/tests/protoc_plugin/beta_python_plugin_test.py index dfacea16ee9d4..6c343241e34d2 100644 --- a/src/python/grpcio_tests/tests/protoc_plugin/beta_python_plugin_test.py +++ b/src/python/grpcio_tests/tests/protoc_plugin/beta_python_plugin_test.py @@ -32,43 +32,43 @@ from tests.unit.framework.common import test_constants -_RELATIVE_PROTO_PATH = 'relative_proto_path' -_RELATIVE_PYTHON_OUT = 'relative_python_out' +_RELATIVE_PROTO_PATH = "relative_proto_path" +_RELATIVE_PYTHON_OUT = "relative_python_out" _PROTO_FILES_PATH_COMPONENTS = ( ( - 'beta_grpc_plugin_test', - 'payload', - 'test_payload.proto', + "beta_grpc_plugin_test", + "payload", + "test_payload.proto", ), ( - 'beta_grpc_plugin_test', - 'requests', - 'r', - 'test_requests.proto', + "beta_grpc_plugin_test", + "requests", + "r", + "test_requests.proto", ), ( - 'beta_grpc_plugin_test', - 'responses', - 'test_responses.proto', + "beta_grpc_plugin_test", + "responses", + "test_responses.proto", ), ( - 'beta_grpc_plugin_test', - 'service', - 'test_service.proto', + "beta_grpc_plugin_test", + "service", + "test_service.proto", ), ) -_PAYLOAD_PB2 = 'beta_grpc_plugin_test.payload.test_payload_pb2' -_REQUESTS_PB2 = 'beta_grpc_plugin_test.requests.r.test_requests_pb2' -_RESPONSES_PB2 = 'beta_grpc_plugin_test.responses.test_responses_pb2' -_SERVICE_PB2 = 'beta_grpc_plugin_test.service.test_service_pb2' +_PAYLOAD_PB2 = "beta_grpc_plugin_test.payload.test_payload_pb2" +_REQUESTS_PB2 = "beta_grpc_plugin_test.requests.r.test_requests_pb2" +_RESPONSES_PB2 = "beta_grpc_plugin_test.responses.test_responses_pb2" +_SERVICE_PB2 = "beta_grpc_plugin_test.service.test_service_pb2" # Identifiers of entities we expect to find in the generated module. 
-SERVICER_IDENTIFIER = 'BetaTestServiceServicer' -STUB_IDENTIFIER = 'BetaTestServiceStub' -SERVER_FACTORY_IDENTIFIER = 'beta_create_TestService_server' -STUB_FACTORY_IDENTIFIER = 'beta_create_TestService_stub' +SERVICER_IDENTIFIER = "BetaTestServiceServicer" +STUB_IDENTIFIER = "BetaTestServiceStub" +SERVER_FACTORY_IDENTIFIER = "beta_create_TestService_server" +STUB_FACTORY_IDENTIFIER = "beta_create_TestService_stub" @contextlib.contextmanager @@ -82,7 +82,7 @@ def _system_path(path_insertion): def _create_directory_tree(root, path_components_sequence): created = set() for path_components in path_components_sequence: - thus_far = '' + thus_far = "" for path_component in path_components: relative_path = path.join(thus_far, path_component) if relative_path not in created: @@ -94,21 +94,22 @@ def _create_directory_tree(root, path_components_sequence): def _massage_proto_content(raw_proto_content): imports_substituted = raw_proto_content.replace( b'import "tests/protoc_plugin/protos/', - b'import "beta_grpc_plugin_test/') + b'import "beta_grpc_plugin_test/', + ) package_statement_substituted = imports_substituted.replace( - b'package grpc_protoc_plugin;', b'package beta_grpc_protoc_plugin;') + b"package grpc_protoc_plugin;", b"package beta_grpc_protoc_plugin;" + ) return package_statement_substituted def _packagify(directory): for subdirectory, _, _ in os.walk(directory): - init_file_name = path.join(subdirectory, '__init__.py') - with open(init_file_name, 'wb') as init_file: - init_file.write(b'') + init_file_name = path.join(subdirectory, "__init__.py") + with open(init_file_name, "wb") as init_file: + init_file.write(b"") class _ServicerMethods(object): - def __init__(self, payload_pb2, responses_pb2): self._condition = threading.Condition() self._paused = False @@ -143,7 +144,7 @@ def _control(self): # pylint: disable=invalid-name def UnaryCall(self, request, unused_rpc_context): response = self._responses_pb2.SimpleResponse() response.payload.payload_type = self._payload_pb2.COMPRESSABLE - response.payload.payload_compressable = 'a' * request.response_size + response.payload.payload_compressable = "a" * request.response_size self._control() return response @@ -151,7 +152,7 @@ def StreamingOutputCall(self, request, unused_rpc_context): for parameter in request.response_parameters: response = self._responses_pb2.StreamingOutputCallResponse() response.payload.payload_type = self._payload_pb2.COMPRESSABLE - response.payload.payload_compressable = 'a' * parameter.size + response.payload.payload_compressable = "a" * parameter.size self._control() yield response @@ -169,7 +170,7 @@ def FullDuplexCall(self, request_iter, unused_rpc_context): for parameter in request.response_parameters: response = self._responses_pb2.StreamingOutputCallResponse() response.payload.payload_type = self._payload_pb2.COMPRESSABLE - response.payload.payload_compressable = 'a' * parameter.size + response.payload.payload_compressable = "a" * parameter.size self._control() yield response @@ -179,7 +180,7 @@ def HalfDuplexCall(self, request_iter, unused_rpc_context): for parameter in request.response_parameters: response = self._responses_pb2.StreamingOutputCallResponse() response.payload.payload_type = self._payload_pb2.COMPRESSABLE - response.payload.payload_compressable = 'a' * parameter.size + response.payload.payload_compressable = "a" * parameter.size self._control() responses.append(response) for response in responses: @@ -190,18 +191,17 @@ def HalfDuplexCall(self, request_iter, unused_rpc_context): def 
_CreateService(payload_pb2, responses_pb2, service_pb2): """Provides a servicer backend and a stub. - The servicer is just the implementation of the actual servicer passed to the - face player of the python RPC implementation; the two are detached. + The servicer is just the implementation of the actual servicer passed to the + face player of the python RPC implementation; the two are detached. - Yields: - A (servicer_methods, stub) pair where servicer_methods is the back-end of - the service bound to the stub and stub is the stub on which to invoke - RPCs. - """ + Yields: + A (servicer_methods, stub) pair where servicer_methods is the back-end of + the service bound to the stub and stub is the stub on which to invoke + RPCs. + """ servicer_methods = _ServicerMethods(payload_pb2, responses_pb2) class Servicer(getattr(service_pb2, SERVICER_IDENTIFIER)): - def UnaryCall(self, request, context): return servicer_methods.UnaryCall(request, context) @@ -219,9 +219,9 @@ def HalfDuplexCall(self, request_iter, context): servicer = Servicer() server = getattr(service_pb2, SERVER_FACTORY_IDENTIFIER)(servicer) - port = server.add_insecure_port('[::]:0') + port = server.add_insecure_port("[::]:0") server.start() - channel = implementations.insecure_channel('localhost', port) + channel = implementations.insecure_channel("localhost", port) stub = getattr(service_pb2, STUB_FACTORY_IDENTIFIER)(channel) yield servicer_methods, stub server.stop(0) @@ -231,24 +231,24 @@ def HalfDuplexCall(self, request_iter, context): def _CreateIncompleteService(service_pb2): """Provides a servicer backend that fails to implement methods and its stub. - The servicer is just the implementation of the actual servicer passed to the - face player of the python RPC implementation; the two are detached. - Args: - service_pb2: The service_pb2 module generated by this test. - Yields: - A (servicer_methods, stub) pair where servicer_methods is the back-end of - the service bound to the stub and stub is the stub on which to invoke - RPCs. - """ + The servicer is just the implementation of the actual servicer passed to the + face player of the python RPC implementation; the two are detached. + Args: + service_pb2: The service_pb2 module generated by this test. + Yields: + A (servicer_methods, stub) pair where servicer_methods is the back-end of + the service bound to the stub and stub is the stub on which to invoke + RPCs. + """ class Servicer(getattr(service_pb2, SERVICER_IDENTIFIER)): pass servicer = Servicer() server = getattr(service_pb2, SERVER_FACTORY_IDENTIFIER)(servicer) - port = server.add_insecure_port('[::]:0') + port = server.add_insecure_port("[::]:0") server.start() - channel = implementations.insecure_channel('localhost', port) + channel = implementations.insecure_channel("localhost", port) stub = getattr(service_pb2, STUB_FACTORY_IDENTIFIER)(channel) yield None, stub server.stop(0) @@ -258,7 +258,7 @@ def _streaming_input_request_iterator(payload_pb2, requests_pb2): for _ in range(3): request = requests_pb2.StreamingInputCallRequest() request.payload.payload_type = payload_pb2.COMPRESSABLE - request.payload.payload_compressable = 'a' + request.payload.payload_compressable = "a" yield request @@ -284,13 +284,13 @@ def _full_duplex_request_iterator(requests_pb2): class PythonPluginTest(unittest.TestCase): """Test case for the gRPC Python protoc-plugin. 
- While reading these tests, remember that the futures API - (`stub.method.future()`) only gives futures for the *response-unary* - methods and does not exist for response-streaming methods. - """ + While reading these tests, remember that the futures API + (`stub.method.future()`) only gives futures for the *response-unary* + methods and does not exist for response-streaming methods. + """ def setUp(self): - self._directory = tempfile.mkdtemp(dir='.') + self._directory = tempfile.mkdtemp(dir=".") self._proto_path = path.join(self._directory, _RELATIVE_PROTO_PATH) self._python_out = path.join(self._directory, _RELATIVE_PYTHON_OUT) @@ -305,12 +305,14 @@ def setUp(self): self._proto_file_names = set() for proto_file_path_components in _PROTO_FILES_PATH_COMPONENTS: raw_proto_content = pkgutil.get_data( - 'tests.protoc_plugin.protos', - path.join(*proto_file_path_components[1:])) + "tests.protoc_plugin.protos", + path.join(*proto_file_path_components[1:]), + ) massaged_proto_content = _massage_proto_content(raw_proto_content) - proto_file_name = path.join(self._proto_path, - *proto_file_path_components) - with open(proto_file_name, 'wb') as proto_file: + proto_file_name = path.join( + self._proto_path, *proto_file_path_components + ) + with open(proto_file_name, "wb") as proto_file: proto_file.write(massaged_proto_content) self._proto_file_names.add(proto_file_name) @@ -319,10 +321,10 @@ def tearDown(self): def _protoc(self): args = [ - '', - '--proto_path={}'.format(self._proto_path), - '--python_out={}'.format(self._python_out), - '--grpc_python_out=grpc_1_0:{}'.format(self._python_out), + "", + "--proto_path={}".format(self._proto_path), + "--python_out={}".format(self._python_out), + "--grpc_python_out=grpc_1_0:{}".format(self._python_out), ] + list(self._proto_file_names) protoc_exit_code = protoc.main(args) self.assertEqual(0, protoc_exit_code) @@ -340,18 +342,22 @@ def testImportAttributes(self): # check that we can access the generated module and its members. 
self.assertIsNotNone( - getattr(self._service_pb2, SERVICER_IDENTIFIER, None)) + getattr(self._service_pb2, SERVICER_IDENTIFIER, None) + ) self.assertIsNotNone(getattr(self._service_pb2, STUB_IDENTIFIER, None)) self.assertIsNotNone( - getattr(self._service_pb2, SERVER_FACTORY_IDENTIFIER, None)) + getattr(self._service_pb2, SERVER_FACTORY_IDENTIFIER, None) + ) self.assertIsNotNone( - getattr(self._service_pb2, STUB_FACTORY_IDENTIFIER, None)) + getattr(self._service_pb2, STUB_FACTORY_IDENTIFIER, None) + ) def testUpDown(self): self._protoc() - with _CreateService(self._payload_pb2, self._responses_pb2, - self._service_pb2): + with _CreateService( + self._payload_pb2, self._responses_pb2, self._service_pb2 + ): self._requests_pb2.SimpleRequest(response_size=13) def testIncompleteServicer(self): @@ -362,50 +368,57 @@ def testIncompleteServicer(self): try: stub.UnaryCall(request, test_constants.LONG_TIMEOUT) except face.AbortionError as error: - self.assertEqual(interfaces.StatusCode.UNIMPLEMENTED, - error.code) + self.assertEqual( + interfaces.StatusCode.UNIMPLEMENTED, error.code + ) def testUnaryCall(self): self._protoc() - with _CreateService(self._payload_pb2, self._responses_pb2, - self._service_pb2) as (methods, stub): + with _CreateService( + self._payload_pb2, self._responses_pb2, self._service_pb2 + ) as (methods, stub): request = self._requests_pb2.SimpleRequest(response_size=13) response = stub.UnaryCall(request, test_constants.LONG_TIMEOUT) - expected_response = methods.UnaryCall(request, 'not a real context!') + expected_response = methods.UnaryCall(request, "not a real context!") self.assertEqual(expected_response, response) def testUnaryCallFuture(self): self._protoc() - with _CreateService(self._payload_pb2, self._responses_pb2, - self._service_pb2) as (methods, stub): + with _CreateService( + self._payload_pb2, self._responses_pb2, self._service_pb2 + ) as (methods, stub): request = self._requests_pb2.SimpleRequest(response_size=13) # Check that the call does not block waiting for the server to respond. 
with methods.pause(): response_future = stub.UnaryCall.future( - request, test_constants.LONG_TIMEOUT) + request, test_constants.LONG_TIMEOUT + ) response = response_future.result() - expected_response = methods.UnaryCall(request, 'not a real RpcContext!') + expected_response = methods.UnaryCall(request, "not a real RpcContext!") self.assertEqual(expected_response, response) def testUnaryCallFutureExpired(self): self._protoc() - with _CreateService(self._payload_pb2, self._responses_pb2, - self._service_pb2) as (methods, stub): + with _CreateService( + self._payload_pb2, self._responses_pb2, self._service_pb2 + ) as (methods, stub): request = self._requests_pb2.SimpleRequest(response_size=13) with methods.pause(): response_future = stub.UnaryCall.future( - request, test_constants.SHORT_TIMEOUT) + request, test_constants.SHORT_TIMEOUT + ) with self.assertRaises(face.ExpirationError): response_future.result() def testUnaryCallFutureCancelled(self): self._protoc() - with _CreateService(self._payload_pb2, self._responses_pb2, - self._service_pb2) as (methods, stub): + with _CreateService( + self._payload_pb2, self._responses_pb2, self._service_pb2 + ) as (methods, stub): request = self._requests_pb2.SimpleRequest(response_size=13) with methods.pause(): response_future = stub.UnaryCall.future(request, 1) @@ -415,48 +428,58 @@ def testUnaryCallFutureCancelled(self): def testUnaryCallFutureFailed(self): self._protoc() - with _CreateService(self._payload_pb2, self._responses_pb2, - self._service_pb2) as (methods, stub): + with _CreateService( + self._payload_pb2, self._responses_pb2, self._service_pb2 + ) as (methods, stub): request = self._requests_pb2.SimpleRequest(response_size=13) with methods.fail(): response_future = stub.UnaryCall.future( - request, test_constants.LONG_TIMEOUT) + request, test_constants.LONG_TIMEOUT + ) self.assertIsNotNone(response_future.exception()) def testStreamingOutputCall(self): self._protoc() - with _CreateService(self._payload_pb2, self._responses_pb2, - self._service_pb2) as (methods, stub): + with _CreateService( + self._payload_pb2, self._responses_pb2, self._service_pb2 + ) as (methods, stub): request = _streaming_output_request(self._requests_pb2) - responses = stub.StreamingOutputCall(request, - test_constants.LONG_TIMEOUT) + responses = stub.StreamingOutputCall( + request, test_constants.LONG_TIMEOUT + ) expected_responses = methods.StreamingOutputCall( - request, 'not a real RpcContext!') + request, "not a real RpcContext!" 
+ ) for expected_response, response in itertools.zip_longest( - expected_responses, responses): + expected_responses, responses + ): self.assertEqual(expected_response, response) def testStreamingOutputCallExpired(self): self._protoc() - with _CreateService(self._payload_pb2, self._responses_pb2, - self._service_pb2) as (methods, stub): + with _CreateService( + self._payload_pb2, self._responses_pb2, self._service_pb2 + ) as (methods, stub): request = _streaming_output_request(self._requests_pb2) with methods.pause(): responses = stub.StreamingOutputCall( - request, test_constants.SHORT_TIMEOUT) + request, test_constants.SHORT_TIMEOUT + ) with self.assertRaises(face.ExpirationError): list(responses) def testStreamingOutputCallCancelled(self): self._protoc() - with _CreateService(self._payload_pb2, self._responses_pb2, - self._service_pb2) as (methods, stub): + with _CreateService( + self._payload_pb2, self._responses_pb2, self._service_pb2 + ) as (methods, stub): request = _streaming_output_request(self._requests_pb2) - responses = stub.StreamingOutputCall(request, - test_constants.LONG_TIMEOUT) + responses = stub.StreamingOutputCall( + request, test_constants.LONG_TIMEOUT + ) next(responses) responses.cancel() with self.assertRaises(face.CancellationError): @@ -465,8 +488,9 @@ def testStreamingOutputCallCancelled(self): def testStreamingOutputCallFailed(self): self._protoc() - with _CreateService(self._payload_pb2, self._responses_pb2, - self._service_pb2) as (methods, stub): + with _CreateService( + self._payload_pb2, self._responses_pb2, self._service_pb2 + ) as (methods, stub): request = _streaming_output_request(self._requests_pb2) with methods.fail(): responses = stub.StreamingOutputCall(request, 1) @@ -477,60 +501,77 @@ def testStreamingOutputCallFailed(self): def testStreamingInputCall(self): self._protoc() - with _CreateService(self._payload_pb2, self._responses_pb2, - self._service_pb2) as (methods, stub): + with _CreateService( + self._payload_pb2, self._responses_pb2, self._service_pb2 + ) as (methods, stub): response = stub.StreamingInputCall( - _streaming_input_request_iterator(self._payload_pb2, - self._requests_pb2), - test_constants.LONG_TIMEOUT) + _streaming_input_request_iterator( + self._payload_pb2, self._requests_pb2 + ), + test_constants.LONG_TIMEOUT, + ) expected_response = methods.StreamingInputCall( - _streaming_input_request_iterator(self._payload_pb2, - self._requests_pb2), - 'not a real RpcContext!') + _streaming_input_request_iterator( + self._payload_pb2, self._requests_pb2 + ), + "not a real RpcContext!", + ) self.assertEqual(expected_response, response) def testStreamingInputCallFuture(self): self._protoc() - with _CreateService(self._payload_pb2, self._responses_pb2, - self._service_pb2) as (methods, stub): + with _CreateService( + self._payload_pb2, self._responses_pb2, self._service_pb2 + ) as (methods, stub): with methods.pause(): response_future = stub.StreamingInputCall.future( - _streaming_input_request_iterator(self._payload_pb2, - self._requests_pb2), - test_constants.LONG_TIMEOUT) + _streaming_input_request_iterator( + self._payload_pb2, self._requests_pb2 + ), + test_constants.LONG_TIMEOUT, + ) response = response_future.result() expected_response = methods.StreamingInputCall( - _streaming_input_request_iterator(self._payload_pb2, - self._requests_pb2), - 'not a real RpcContext!') + _streaming_input_request_iterator( + self._payload_pb2, self._requests_pb2 + ), + "not a real RpcContext!", + ) self.assertEqual(expected_response, response) def 
testStreamingInputCallFutureExpired(self): self._protoc() - with _CreateService(self._payload_pb2, self._responses_pb2, - self._service_pb2) as (methods, stub): + with _CreateService( + self._payload_pb2, self._responses_pb2, self._service_pb2 + ) as (methods, stub): with methods.pause(): response_future = stub.StreamingInputCall.future( - _streaming_input_request_iterator(self._payload_pb2, - self._requests_pb2), - test_constants.SHORT_TIMEOUT) + _streaming_input_request_iterator( + self._payload_pb2, self._requests_pb2 + ), + test_constants.SHORT_TIMEOUT, + ) with self.assertRaises(face.ExpirationError): response_future.result() - self.assertIsInstance(response_future.exception(), - face.ExpirationError) + self.assertIsInstance( + response_future.exception(), face.ExpirationError + ) def testStreamingInputCallFutureCancelled(self): self._protoc() - with _CreateService(self._payload_pb2, self._responses_pb2, - self._service_pb2) as (methods, stub): + with _CreateService( + self._payload_pb2, self._responses_pb2, self._service_pb2 + ) as (methods, stub): with methods.pause(): response_future = stub.StreamingInputCall.future( - _streaming_input_request_iterator(self._payload_pb2, - self._requests_pb2), - test_constants.LONG_TIMEOUT) + _streaming_input_request_iterator( + self._payload_pb2, self._requests_pb2 + ), + test_constants.LONG_TIMEOUT, + ) response_future.cancel() self.assertTrue(response_future.cancelled()) with self.assertRaises(future.CancelledError): @@ -539,50 +580,61 @@ def testStreamingInputCallFutureCancelled(self): def testStreamingInputCallFutureFailed(self): self._protoc() - with _CreateService(self._payload_pb2, self._responses_pb2, - self._service_pb2) as (methods, stub): + with _CreateService( + self._payload_pb2, self._responses_pb2, self._service_pb2 + ) as (methods, stub): with methods.fail(): response_future = stub.StreamingInputCall.future( - _streaming_input_request_iterator(self._payload_pb2, - self._requests_pb2), - test_constants.LONG_TIMEOUT) + _streaming_input_request_iterator( + self._payload_pb2, self._requests_pb2 + ), + test_constants.LONG_TIMEOUT, + ) self.assertIsNotNone(response_future.exception()) def testFullDuplexCall(self): self._protoc() - with _CreateService(self._payload_pb2, self._responses_pb2, - self._service_pb2) as (methods, stub): + with _CreateService( + self._payload_pb2, self._responses_pb2, self._service_pb2 + ) as (methods, stub): responses = stub.FullDuplexCall( _full_duplex_request_iterator(self._requests_pb2), - test_constants.LONG_TIMEOUT) + test_constants.LONG_TIMEOUT, + ) expected_responses = methods.FullDuplexCall( _full_duplex_request_iterator(self._requests_pb2), - 'not a real RpcContext!') + "not a real RpcContext!", + ) for expected_response, response in itertools.zip_longest( - expected_responses, responses): + expected_responses, responses + ): self.assertEqual(expected_response, response) def testFullDuplexCallExpired(self): self._protoc() request_iterator = _full_duplex_request_iterator(self._requests_pb2) - with _CreateService(self._payload_pb2, self._responses_pb2, - self._service_pb2) as (methods, stub): + with _CreateService( + self._payload_pb2, self._responses_pb2, self._service_pb2 + ) as (methods, stub): with methods.pause(): - responses = stub.FullDuplexCall(request_iterator, - test_constants.SHORT_TIMEOUT) + responses = stub.FullDuplexCall( + request_iterator, test_constants.SHORT_TIMEOUT + ) with self.assertRaises(face.ExpirationError): list(responses) def testFullDuplexCallCancelled(self): self._protoc() - 
with _CreateService(self._payload_pb2, self._responses_pb2, - self._service_pb2) as (methods, stub): + with _CreateService( + self._payload_pb2, self._responses_pb2, self._service_pb2 + ) as (methods, stub): request_iterator = _full_duplex_request_iterator(self._requests_pb2) - responses = stub.FullDuplexCall(request_iterator, - test_constants.LONG_TIMEOUT) + responses = stub.FullDuplexCall( + request_iterator, test_constants.LONG_TIMEOUT + ) next(responses) responses.cancel() with self.assertRaises(face.CancellationError): @@ -592,11 +644,13 @@ def testFullDuplexCallFailed(self): self._protoc() request_iterator = _full_duplex_request_iterator(self._requests_pb2) - with _CreateService(self._payload_pb2, self._responses_pb2, - self._service_pb2) as (methods, stub): + with _CreateService( + self._payload_pb2, self._responses_pb2, self._service_pb2 + ) as (methods, stub): with methods.fail(): - responses = stub.FullDuplexCall(request_iterator, - test_constants.LONG_TIMEOUT) + responses = stub.FullDuplexCall( + request_iterator, test_constants.LONG_TIMEOUT + ) self.assertIsNotNone(responses) with self.assertRaises(face.RemoteError): next(responses) @@ -604,8 +658,9 @@ def testFullDuplexCallFailed(self): def testHalfDuplexCall(self): self._protoc() - with _CreateService(self._payload_pb2, self._responses_pb2, - self._service_pb2) as (methods, stub): + with _CreateService( + self._payload_pb2, self._responses_pb2, self._service_pb2 + ) as (methods, stub): def half_duplex_request_iterator(): request = self._requests_pb2.StreamingOutputCallRequest() @@ -616,10 +671,12 @@ def half_duplex_request_iterator(): request.response_parameters.add(size=3, interval_us=0) yield request - responses = stub.HalfDuplexCall(half_duplex_request_iterator(), - test_constants.LONG_TIMEOUT) + responses = stub.HalfDuplexCall( + half_duplex_request_iterator(), test_constants.LONG_TIMEOUT + ) expected_responses = methods.HalfDuplexCall( - half_duplex_request_iterator(), 'not a real RpcContext!') + half_duplex_request_iterator(), "not a real RpcContext!" 
+ ) for check in itertools.zip_longest(expected_responses, responses): expected_response, response = check self.assertEqual(expected_response, response) @@ -648,15 +705,17 @@ def half_duplex_request_iterator(): while wait_cell[0]: condition.wait() - with _CreateService(self._payload_pb2, self._responses_pb2, - self._service_pb2) as (methods, stub): + with _CreateService( + self._payload_pb2, self._responses_pb2, self._service_pb2 + ) as (methods, stub): with wait(): - responses = stub.HalfDuplexCall(half_duplex_request_iterator(), - test_constants.SHORT_TIMEOUT) + responses = stub.HalfDuplexCall( + half_duplex_request_iterator(), test_constants.SHORT_TIMEOUT + ) # half-duplex waits for the client to send all info with self.assertRaises(face.ExpirationError): next(responses) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/qps/benchmark_client.py b/src/python/grpcio_tests/tests/qps/benchmark_client.py index 865e6713048dd..e5aafc4142f53 100644 --- a/src/python/grpcio_tests/tests/qps/benchmark_client.py +++ b/src/python/grpcio_tests/tests/qps/benchmark_client.py @@ -30,14 +30,16 @@ class GenericStub(object): - def __init__(self, channel): self.UnaryCall = channel.unary_unary( - '/grpc.testing.BenchmarkService/UnaryCall') + "/grpc.testing.BenchmarkService/UnaryCall" + ) self.StreamingFromServer = channel.unary_stream( - '/grpc.testing.BenchmarkService/StreamingFromServer') + "/grpc.testing.BenchmarkService/StreamingFromServer" + ) self.StreamingCall = channel.stream_stream( - '/grpc.testing.BenchmarkService/StreamingCall') + "/grpc.testing.BenchmarkService/StreamingCall" + ) class BenchmarkClient: @@ -47,32 +49,37 @@ class BenchmarkClient: def __init__(self, server, config, hist): # Create the stub - if config.HasField('security_params'): + if config.HasField("security_params"): creds = grpc.ssl_channel_credentials( - resources.test_root_certificates()) + resources.test_root_certificates() + ) channel = test_common.test_secure_channel( - server, creds, config.security_params.server_host_override) + server, creds, config.security_params.server_host_override + ) else: channel = grpc.insecure_channel(server) # waits for the channel to be ready before we start sending messages grpc.channel_ready_future(channel).result() - if config.payload_config.WhichOneof('payload') == 'simple_params': + if config.payload_config.WhichOneof("payload") == "simple_params": self._generic = False self._stub = benchmark_service_pb2_grpc.BenchmarkServiceStub( - channel) + channel + ) payload = messages_pb2.Payload( - body=bytes(b'\0' * - config.payload_config.simple_params.req_size)) + body=bytes(b"\0" * config.payload_config.simple_params.req_size) + ) self._request = messages_pb2.SimpleRequest( payload=payload, - response_size=config.payload_config.simple_params.resp_size) + response_size=config.payload_config.simple_params.resp_size, + ) else: self._generic = True self._stub = GenericStub(channel) - self._request = bytes(b'\0' * - config.payload_config.bytebuf_params.req_size) + self._request = bytes( + b"\0" * config.payload_config.bytebuf_params.req_size + ) self._hist = hist self._response_callbacks = [] @@ -99,11 +106,11 @@ def _handle_response(self, client, query_time): class UnarySyncBenchmarkClient(BenchmarkClient): - def __init__(self, server, config, hist): super(UnarySyncBenchmarkClient, self).__init__(server, config, hist) self._pool = futures.ThreadPoolExecutor( - max_workers=config.outstanding_rpcs_per_channel) + 
max_workers=config.outstanding_rpcs_per_channel + ) def send_request(self): # Send requests in separate threads to support multiple outstanding rpcs @@ -122,13 +129,13 @@ def _dispatch_request(self): class UnaryAsyncBenchmarkClient(BenchmarkClient): - def send_request(self): # Use the Future callback api to support multiple outstanding rpcs start_time = time.time() response_future = self._stub.UnaryCall.future(self._request, _TIMEOUT) response_future.add_done_callback( - lambda resp: self._response_received(start_time, resp)) + lambda resp: self._response_received(start_time, resp) + ) def _response_received(self, start_time, resp): resp.result() @@ -140,7 +147,6 @@ def stop(self): class _SyncStream(object): - def __init__(self, stub, generic, request, handle_response): self._stub = stub self._generic = generic @@ -156,12 +162,13 @@ def send_request(self): def start(self): self._is_streaming = True - response_stream = self._stub.StreamingCall(self._request_generator(), - _TIMEOUT) + response_stream = self._stub.StreamingCall( + self._request_generator(), _TIMEOUT + ) for _ in response_stream: self._handle_response( - self, - time.time() - self._send_time_queue.get_nowait()) + self, time.time() - self._send_time_queue.get_nowait() + ) def stop(self): self._is_streaming = False @@ -176,14 +183,15 @@ def _request_generator(self): class StreamingSyncBenchmarkClient(BenchmarkClient): - def __init__(self, server, config, hist): super(StreamingSyncBenchmarkClient, self).__init__(server, config, hist) self._pool = futures.ThreadPoolExecutor( - max_workers=config.outstanding_rpcs_per_channel) + max_workers=config.outstanding_rpcs_per_channel + ) self._streams = [ - _SyncStream(self._stub, self._generic, self._request, - self._handle_response) + _SyncStream( + self._stub, self._generic, self._request, self._handle_response + ) for _ in range(config.outstanding_rpcs_per_channel) ] self._curr_stream = 0 @@ -205,29 +213,32 @@ def stop(self): class ServerStreamingSyncBenchmarkClient(BenchmarkClient): - def __init__(self, server, config, hist): - super(ServerStreamingSyncBenchmarkClient, - self).__init__(server, config, hist) + super(ServerStreamingSyncBenchmarkClient, self).__init__( + server, config, hist + ) if config.outstanding_rpcs_per_channel == 1: self._pool = None else: self._pool = futures.ThreadPoolExecutor( - max_workers=config.outstanding_rpcs_per_channel) + max_workers=config.outstanding_rpcs_per_channel + ) self._rpcs = [] self._sender = None def send_request(self): if self._pool is None: self._sender = threading.Thread( - target=self._one_stream_streaming_rpc, daemon=True) + target=self._one_stream_streaming_rpc, daemon=True + ) self._sender.start() else: self._pool.submit(self._one_stream_streaming_rpc) def _one_stream_streaming_rpc(self): response_stream = self._stub.StreamingFromServer( - self._request, _TIMEOUT) + self._request, _TIMEOUT + ) self._rpcs.append(response_stream) start_time = time.time() for _ in response_stream: diff --git a/src/python/grpcio_tests/tests/qps/benchmark_server.py b/src/python/grpcio_tests/tests/qps/benchmark_server.py index 644543086b6bb..968aef2f05ccc 100644 --- a/src/python/grpcio_tests/tests/qps/benchmark_server.py +++ b/src/python/grpcio_tests/tests/qps/benchmark_server.py @@ -20,21 +20,22 @@ class BenchmarkServer(benchmark_service_pb2_grpc.BenchmarkServiceServicer): """Synchronous Server implementation for the Benchmark service.""" def UnaryCall(self, request, context): - payload = messages_pb2.Payload(body=b'\0' * request.response_size) + payload = 
messages_pb2.Payload(body=b"\0" * request.response_size) return messages_pb2.SimpleResponse(payload=payload) def StreamingCall(self, request_iterator, context): for request in request_iterator: - payload = messages_pb2.Payload(body=b'\0' * request.response_size) + payload = messages_pb2.Payload(body=b"\0" * request.response_size) yield messages_pb2.SimpleResponse(payload=payload) -class GenericBenchmarkServer(benchmark_service_pb2_grpc.BenchmarkServiceServicer - ): +class GenericBenchmarkServer( + benchmark_service_pb2_grpc.BenchmarkServiceServicer +): """Generic Server implementation for the Benchmark service.""" def __init__(self, resp_size): - self._response = b'\0' * resp_size + self._response = b"\0" * resp_size def UnaryCall(self, request, context): return self._response diff --git a/src/python/grpcio_tests/tests/qps/client_runner.py b/src/python/grpcio_tests/tests/qps/client_runner.py index a03174472c023..eca0155483ca9 100644 --- a/src/python/grpcio_tests/tests/qps/client_runner.py +++ b/src/python/grpcio_tests/tests/qps/client_runner.py @@ -40,13 +40,13 @@ def stop(self): class OpenLoopClientRunner(ClientRunner): - def __init__(self, client, interval_generator): super(OpenLoopClientRunner, self).__init__(client) self._is_running = False self._interval_generator = interval_generator - self._dispatch_thread = threading.Thread(target=self._dispatch_requests, - args=()) + self._dispatch_thread = threading.Thread( + target=self._dispatch_requests, args=() + ) def start(self): self._is_running = True @@ -66,7 +66,6 @@ def _dispatch_requests(self): class ClosedLoopClientRunner(ClientRunner): - def __init__(self, client, request_count, no_ping_pong): super(ClosedLoopClientRunner, self).__init__(client) self._is_running = False diff --git a/src/python/grpcio_tests/tests/qps/histogram.py b/src/python/grpcio_tests/tests/qps/histogram.py index 8139a6ee2fb04..d5fdccb6bfe2f 100644 --- a/src/python/grpcio_tests/tests/qps/histogram.py +++ b/src/python/grpcio_tests/tests/qps/histogram.py @@ -21,8 +21,8 @@ class Histogram(object): """Histogram class used for recording performance testing data. - This class is thread safe. - """ + This class is thread safe. 
+ """ def __init__(self, resolution, max_possible): self._lock = threading.Lock() diff --git a/src/python/grpcio_tests/tests/qps/qps_worker.py b/src/python/grpcio_tests/tests/qps/qps_worker.py index 0708cc06f3e0e..a1260731a0cb5 100644 --- a/src/python/grpcio_tests/tests/qps/qps_worker.py +++ b/src/python/grpcio_tests/tests/qps/qps_worker.py @@ -28,28 +28,33 @@ def run_worker_server(driver_port, server_port): server = test_common.test_server() servicer = worker_server.WorkerServer(server_port) worker_service_pb2_grpc.add_WorkerServiceServicer_to_server( - servicer, server) - server.add_insecure_port('[::]:{}'.format(driver_port)) + servicer, server + ) + server.add_insecure_port("[::]:{}".format(driver_port)) server.start() servicer.wait_for_quit() server.stop(0) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) parser = argparse.ArgumentParser( - description='gRPC Python performance testing worker') + description="gRPC Python performance testing worker" + ) parser.add_argument( - '--driver_port', + "--driver_port", type=int, - dest='driver_port', - help='The port for the worker to expose for driver communication') + dest="driver_port", + help="The port for the worker to expose for driver communication", + ) parser.add_argument( - '--server_port', + "--server_port", type=int, default=None, - dest='server_port', - help='The port for the server if not specified by server config message' + dest="server_port", + help=( + "The port for the server if not specified by server config message" + ), ) args = parser.parse_args() diff --git a/src/python/grpcio_tests/tests/qps/worker_server.py b/src/python/grpcio_tests/tests/qps/worker_server.py index ac70248ef7e9b..82db1a4af74df 100644 --- a/src/python/grpcio_tests/tests/qps/worker_server.py +++ b/src/python/grpcio_tests/tests/qps/worker_server.py @@ -40,7 +40,6 @@ class Snapshotter: - def __init__(self): self._start_time = 0.0 self._end_time = 0.0 @@ -87,7 +86,9 @@ def __init__(self, server_port=None): self._snapshotter = Snapshotter() def RunServer(self, request_iterator, context): - config = next(request_iterator).setup #pylint: disable=stop-iteration-return + # pylint: disable=stop-iteration-return + config = next(request_iterator).setup + # pylint: enable=stop-iteration-return server, port = self._create_server(config) cores = multiprocessing.cpu_count() server.start() @@ -118,44 +119,54 @@ def _create_server(self, config): if config.server_type == control_pb2.ASYNC_SERVER: servicer = benchmark_server.BenchmarkServer() benchmark_service_pb2_grpc.add_BenchmarkServiceServicer_to_server( - servicer, server) + servicer, server + ) elif config.server_type == control_pb2.ASYNC_GENERIC_SERVER: resp_size = config.payload_config.bytebuf_params.resp_size servicer = benchmark_server.GenericBenchmarkServer(resp_size) method_implementations = { - 'StreamingCall': - grpc.stream_stream_rpc_method_handler(servicer.StreamingCall - ), - 'UnaryCall': - grpc.unary_unary_rpc_method_handler(servicer.UnaryCall), + "StreamingCall": grpc.stream_stream_rpc_method_handler( + servicer.StreamingCall + ), + "UnaryCall": grpc.unary_unary_rpc_method_handler( + servicer.UnaryCall + ), } handler = grpc.method_handlers_generic_handler( - 'grpc.testing.BenchmarkService', method_implementations) + "grpc.testing.BenchmarkService", method_implementations + ) server.add_generic_rpc_handlers((handler,)) else: - raise Exception('Unsupported server type {}'.format( - config.server_type)) + raise Exception( + "Unsupported server type 
{}".format(config.server_type) + ) if self._server_port is not None and config.port == 0: server_port = self._server_port else: server_port = config.port - if config.HasField('security_params'): # Use SSL + if config.HasField("security_params"): # Use SSL server_creds = grpc.ssl_server_credentials( - ((resources.private_key(), resources.certificate_chain()),)) - port = server.add_secure_port('[::]:{}'.format(server_port), - server_creds) + ((resources.private_key(), resources.certificate_chain()),) + ) + port = server.add_secure_port( + "[::]:{}".format(server_port), server_creds + ) else: - port = server.add_insecure_port('[::]:{}'.format(server_port)) + port = server.add_insecure_port("[::]:{}".format(server_port)) return (server, port) def RunClient(self, request_iterator, context): - config = next(request_iterator).setup #pylint: disable=stop-iteration-return + # pylint: disable=stop-iteration-return + config = next(request_iterator).setup + # pylint: enable=stop-iteration-return client_runners = [] - qps_data = histogram.Histogram(config.histogram_params.resolution, - config.histogram_params.max_possible) + qps_data = histogram.Histogram( + config.histogram_params.resolution, + config.histogram_params.max_possible, + ) self._snapshotter.snapshot() self._snapshotter.reset() @@ -184,8 +195,9 @@ def RunClient(self, request_iterator, context): def _get_client_status(self, qps_data): latencies = qps_data.get_data() - stats = stats_pb2.ClientStats(latencies=latencies, - **self._snapshotter.stats()) + stats = stats_pb2.ClientStats( + latencies=latencies, **self._snapshotter.stats() + ) return control_pb2.ClientStatus(stats=stats) def _create_client_runner(self, server, config, qps_data): @@ -193,29 +205,35 @@ def _create_client_runner(self, server, config, qps_data): if config.client_type == control_pb2.SYNC_CLIENT: if config.rpc_type == control_pb2.UNARY: client = benchmark_client.UnarySyncBenchmarkClient( - server, config, qps_data) + server, config, qps_data + ) elif config.rpc_type == control_pb2.STREAMING: client = benchmark_client.StreamingSyncBenchmarkClient( - server, config, qps_data) + server, config, qps_data + ) elif config.rpc_type == control_pb2.STREAMING_FROM_SERVER: no_ping_pong = True client = benchmark_client.ServerStreamingSyncBenchmarkClient( - server, config, qps_data) + server, config, qps_data + ) elif config.client_type == control_pb2.ASYNC_CLIENT: if config.rpc_type == control_pb2.UNARY: client = benchmark_client.UnaryAsyncBenchmarkClient( - server, config, qps_data) + server, config, qps_data + ) else: - raise Exception('Async streaming client not supported') + raise Exception("Async streaming client not supported") else: - raise Exception('Unsupported client type {}'.format( - config.client_type)) + raise Exception( + "Unsupported client type {}".format(config.client_type) + ) # In multi-channel tests, we split the load across all channels load_factor = float(config.client_channels) - if config.load_params.WhichOneof('load') == 'closed_loop': + if config.load_params.WhichOneof("load") == "closed_loop": runner = client_runner.ClosedLoopClientRunner( - client, config.outstanding_rpcs_per_channel, no_ping_pong) + client, config.outstanding_rpcs_per_channel, no_ping_pong + ) else: # Open loop Poisson alpha = config.load_params.poisson.offered_load / load_factor diff --git a/src/python/grpcio_tests/tests/reflection/_reflection_client_test.py b/src/python/grpcio_tests/tests/reflection/_reflection_client_test.py index a008867985695..0fad0aa76a5f6 100644 --- 
a/src/python/grpcio_tests/tests/reflection/_reflection_client_test.py +++ b/src/python/grpcio_tests/tests/reflection/_reflection_client_test.py @@ -18,10 +18,12 @@ from google.protobuf.descriptor_pool import DescriptorPool import grpc from grpc_reflection.v1alpha import reflection -from grpc_reflection.v1alpha.proto_reflection_descriptor_database import \ - ProtoReflectionDescriptorDatabase +from grpc_reflection.v1alpha.proto_reflection_descriptor_database import ( + ProtoReflectionDescriptorDatabase, +) from src.proto.grpc.testing import test_pb2 + # Needed to load the EmptyWithExtensions message from src.proto.grpc.testing.proto2 import empty2_extensions_pb2 from tests.unit import test_common @@ -36,7 +38,6 @@ class ReflectionClientTest(unittest.TestCase): - def setUp(self): self._server = test_common.test_server() self._SERVICE_NAMES = ( @@ -61,8 +62,9 @@ def testListServices(self): self.assertCountEqual(self._SERVICE_NAMES, services) def testReflectionServiceName(self): - self.assertEqual(reflection.SERVICE_NAME, - "grpc.reflection.v1alpha.ServerReflection") + self.assertEqual( + reflection.SERVICE_NAME, "grpc.reflection.v1alpha.ServerReflection" + ) def testFindFile(self): file_name = _PROTO_FILE_NAME @@ -110,7 +112,8 @@ def testFindServiceFindMethod(self): self.assertTrue(method_desc.full_name.endswith(method_name)) empty_message_desc = self.desc_pool.FindMessageTypeByName( - _EMPTY_PROTO_SYMBOL_NAME) + _EMPTY_PROTO_SYMBOL_NAME + ) self.assertEqual(empty_message_desc, method_desc.input_type) self.assertEqual(empty_message_desc, method_desc.output_type) diff --git a/src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py b/src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py index ee6a05d005b22..7568abcd17811 100644 --- a/src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py +++ b/src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py @@ -28,11 +28,18 @@ from src.proto.grpc.testing.proto2 import empty2_pb2 from tests.unit import test_common -_EMPTY_PROTO_FILE_NAME = 'src/proto/grpc/testing/empty.proto' -_EMPTY_PROTO_SYMBOL_NAME = 'grpc.testing.Empty' -_SERVICE_NAMES = ('Angstrom', 'Bohr', 'Curie', 'Dyson', 'Einstein', 'Feynman', - 'Galilei') -_EMPTY_EXTENSIONS_SYMBOL_NAME = 'grpc.testing.proto2.EmptyWithExtensions' +_EMPTY_PROTO_FILE_NAME = "src/proto/grpc/testing/empty.proto" +_EMPTY_PROTO_SYMBOL_NAME = "grpc.testing.Empty" +_SERVICE_NAMES = ( + "Angstrom", + "Bohr", + "Curie", + "Dyson", + "Einstein", + "Feynman", + "Galilei", +) +_EMPTY_EXTENSIONS_SYMBOL_NAME = "grpc.testing.proto2.EmptyWithExtensions" _EMPTY_EXTENSIONS_NUMBERS = ( 124, 125, @@ -48,17 +55,17 @@ def _file_descriptor_to_proto(descriptor): return proto.SerializeToString() -@unittest.skipIf(sys.version_info[0] < 3, - 'ProtoBuf descriptor has moved on from Python2') +@unittest.skipIf( + sys.version_info[0] < 3, "ProtoBuf descriptor has moved on from Python2" +) class ReflectionServicerTest(unittest.TestCase): - def setUp(self): self._server = test_common.test_server() reflection.enable_server_reflection(_SERVICE_NAMES, self._server) - port = self._server.add_insecure_port('[::]:0') + port = self._server.add_insecure_port("[::]:0") self._server.start() - self._channel = grpc.insecure_channel('localhost:%d' % port) + self._channel = grpc.insecure_channel("localhost:%d" % port) self._stub = reflection_pb2_grpc.ServerReflectionStub(self._channel) def tearDown(self): @@ -68,47 +75,58 @@ def tearDown(self): def testFileByName(self): requests = ( 
reflection_pb2.ServerReflectionRequest( - file_by_filename=_EMPTY_PROTO_FILE_NAME), + file_by_filename=_EMPTY_PROTO_FILE_NAME + ), reflection_pb2.ServerReflectionRequest( - file_by_filename='i-donut-exist'), + file_by_filename="i-donut-exist" + ), ) responses = tuple(self._stub.ServerReflectionInfo(iter(requests))) expected_responses = ( reflection_pb2.ServerReflectionResponse( - valid_host='', + valid_host="", file_descriptor_response=reflection_pb2.FileDescriptorResponse( file_descriptor_proto=( - _file_descriptor_to_proto(empty_pb2.DESCRIPTOR),))), + _file_descriptor_to_proto(empty_pb2.DESCRIPTOR), + ) + ), + ), reflection_pb2.ServerReflectionResponse( - valid_host='', + valid_host="", error_response=reflection_pb2.ErrorResponse( error_code=grpc.StatusCode.NOT_FOUND.value[0], error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(), - )), + ), + ), ) self.assertEqual(expected_responses, responses) def testFileBySymbol(self): requests = ( reflection_pb2.ServerReflectionRequest( - file_containing_symbol=_EMPTY_PROTO_SYMBOL_NAME), + file_containing_symbol=_EMPTY_PROTO_SYMBOL_NAME + ), reflection_pb2.ServerReflectionRequest( - file_containing_symbol='i.donut.exist.co.uk.org.net.me.name.foo' + file_containing_symbol="i.donut.exist.co.uk.org.net.me.name.foo" ), ) responses = tuple(self._stub.ServerReflectionInfo(iter(requests))) expected_responses = ( reflection_pb2.ServerReflectionResponse( - valid_host='', + valid_host="", file_descriptor_response=reflection_pb2.FileDescriptorResponse( file_descriptor_proto=( - _file_descriptor_to_proto(empty_pb2.DESCRIPTOR),))), + _file_descriptor_to_proto(empty_pb2.DESCRIPTOR), + ) + ), + ), reflection_pb2.ServerReflectionResponse( - valid_host='', + valid_host="", error_response=reflection_pb2.ErrorResponse( error_code=grpc.StatusCode.NOT_FOUND.value[0], error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(), - )), + ), + ), ) self.assertEqual(expected_responses, responses) @@ -118,71 +136,91 @@ def testFileContainingExtension(self): file_containing_extension=reflection_pb2.ExtensionRequest( containing_type=_EMPTY_EXTENSIONS_SYMBOL_NAME, extension_number=125, - ),), + ), + ), reflection_pb2.ServerReflectionRequest( file_containing_extension=reflection_pb2.ExtensionRequest( - containing_type='i.donut.exist.co.uk.org.net.me.name.foo', + containing_type="i.donut.exist.co.uk.org.net.me.name.foo", extension_number=55, - ),), + ), + ), ) responses = tuple(self._stub.ServerReflectionInfo(iter(requests))) expected_responses = ( reflection_pb2.ServerReflectionResponse( - valid_host='', - file_descriptor_response=reflection_pb2. 
- FileDescriptorResponse(file_descriptor_proto=( - _file_descriptor_to_proto(empty2_extensions_pb2.DESCRIPTOR), - _file_descriptor_to_proto(empty2_pb2.DESCRIPTOR), - ))), + valid_host="", + file_descriptor_response=reflection_pb2.FileDescriptorResponse( + file_descriptor_proto=( + _file_descriptor_to_proto( + empty2_extensions_pb2.DESCRIPTOR + ), + _file_descriptor_to_proto(empty2_pb2.DESCRIPTOR), + ) + ), + ), reflection_pb2.ServerReflectionResponse( - valid_host='', + valid_host="", error_response=reflection_pb2.ErrorResponse( error_code=grpc.StatusCode.NOT_FOUND.value[0], error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(), - )), + ), + ), ) self.assertEqual(expected_responses, responses) def testExtensionNumbersOfType(self): requests = ( reflection_pb2.ServerReflectionRequest( - all_extension_numbers_of_type=_EMPTY_EXTENSIONS_SYMBOL_NAME), + all_extension_numbers_of_type=_EMPTY_EXTENSIONS_SYMBOL_NAME + ), reflection_pb2.ServerReflectionRequest( - all_extension_numbers_of_type='i.donut.exist.co.uk.net.name.foo' + all_extension_numbers_of_type="i.donut.exist.co.uk.net.name.foo" ), ) responses = tuple(self._stub.ServerReflectionInfo(iter(requests))) expected_responses = ( reflection_pb2.ServerReflectionResponse( - valid_host='', - all_extension_numbers_response=reflection_pb2. - ExtensionNumberResponse( + valid_host="", + all_extension_numbers_response=reflection_pb2.ExtensionNumberResponse( base_type_name=_EMPTY_EXTENSIONS_SYMBOL_NAME, - extension_number=_EMPTY_EXTENSIONS_NUMBERS)), + extension_number=_EMPTY_EXTENSIONS_NUMBERS, + ), + ), reflection_pb2.ServerReflectionResponse( - valid_host='', + valid_host="", error_response=reflection_pb2.ErrorResponse( error_code=grpc.StatusCode.NOT_FOUND.value[0], error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(), - )), + ), + ), ) self.assertEqual(expected_responses, responses) def testListServices(self): - requests = (reflection_pb2.ServerReflectionRequest(list_services='',),) + requests = ( + reflection_pb2.ServerReflectionRequest( + list_services="", + ), + ) responses = tuple(self._stub.ServerReflectionInfo(iter(requests))) - expected_responses = (reflection_pb2.ServerReflectionResponse( - valid_host='', - list_services_response=reflection_pb2.ListServiceResponse( - service=tuple( - reflection_pb2.ServiceResponse(name=name) - for name in _SERVICE_NAMES))),) + expected_responses = ( + reflection_pb2.ServerReflectionResponse( + valid_host="", + list_services_response=reflection_pb2.ListServiceResponse( + service=tuple( + reflection_pb2.ServiceResponse(name=name) + for name in _SERVICE_NAMES + ) + ), + ), + ) self.assertEqual(expected_responses, responses) def testReflectionServiceName(self): - self.assertEqual(reflection.SERVICE_NAME, - 'grpc.reflection.v1alpha.ServerReflection') + self.assertEqual( + reflection.SERVICE_NAME, "grpc.reflection.v1alpha.ServerReflection" + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/status/_grpc_status_test.py b/src/python/grpcio_tests/tests/status/_grpc_status_test.py index a79dd555abf7b..2573e961f1873 100644 --- a/src/python/grpcio_tests/tests/status/_grpc_status_test.py +++ b/src/python/grpcio_tests/tests/status/_grpc_status_test.py @@ -20,6 +20,7 @@ # please refer to comments in the "bazel_namespace_package_hack" module. 
try: from tests import bazel_namespace_package_hack + bazel_namespace_package_hack.sys_path_to_site_dir_hack() except ImportError: pass @@ -38,19 +39,19 @@ from google.protobuf import any_pb2 from google.rpc import code_pb2, status_pb2, error_details_pb2 -_STATUS_OK = '/test/StatusOK' -_STATUS_NOT_OK = '/test/StatusNotOk' -_ERROR_DETAILS = '/test/ErrorDetails' -_INCONSISTENT = '/test/Inconsistent' -_INVALID_CODE = '/test/InvalidCode' +_STATUS_OK = "/test/StatusOK" +_STATUS_NOT_OK = "/test/StatusNotOk" +_ERROR_DETAILS = "/test/ErrorDetails" +_INCONSISTENT = "/test/Inconsistent" +_INVALID_CODE = "/test/InvalidCode" -_REQUEST = b'\x00\x00\x00' -_RESPONSE = b'\x01\x01\x01' +_REQUEST = b"\x00\x00\x00" +_RESPONSE = b"\x01\x01\x01" -_GRPC_DETAILS_METADATA_KEY = 'grpc-status-details-bin' +_GRPC_DETAILS_METADATA_KEY = "grpc-status-details-bin" -_STATUS_DETAILS = 'This is an error detail' -_STATUS_DETAILS_ANOTHER = 'This is another error detail' +_STATUS_DETAILS = "This is an error detail" +_STATUS_DETAILS_ANOTHER = "This is another error detail" def _ok_unary_unary(request, servicer_context): @@ -64,8 +65,11 @@ def _not_ok_unary_unary(request, servicer_context): def _error_details_unary_unary(request, servicer_context): details = any_pb2.Any() details.Pack( - error_details_pb2.DebugInfo(stack_entries=traceback.format_stack(), - detail='Intentionally invoked')) + error_details_pb2.DebugInfo( + stack_entries=traceback.format_stack(), + detail="Intentionally invoked", + ) + ) rich_status = status_pb2.Status( code=code_pb2.INTERNAL, message=_STATUS_DETAILS, @@ -83,19 +87,19 @@ def _inconsistent_unary_unary(request, servicer_context): servicer_context.set_details(_STATUS_DETAILS_ANOTHER) # User put inconsistent status information in trailing metadata servicer_context.set_trailing_metadata( - ((_GRPC_DETAILS_METADATA_KEY, rich_status.SerializeToString()),)) + ((_GRPC_DETAILS_METADATA_KEY, rich_status.SerializeToString()),) + ) def _invalid_code_unary_unary(request, servicer_context): rich_status = status_pb2.Status( code=42, - message='Invalid code', + message="Invalid code", ) servicer_context.abort_with_status(rpc_status.to_status(rich_status)) class _GenericHandler(grpc.GenericRpcHandler): - def service(self, handler_call_details): if handler_call_details.method == _STATUS_OK: return grpc.unary_unary_rpc_method_handler(_ok_unary_unary) @@ -103,28 +107,31 @@ def service(self, handler_call_details): return grpc.unary_unary_rpc_method_handler(_not_ok_unary_unary) elif handler_call_details.method == _ERROR_DETAILS: return grpc.unary_unary_rpc_method_handler( - _error_details_unary_unary) + _error_details_unary_unary + ) elif handler_call_details.method == _INCONSISTENT: return grpc.unary_unary_rpc_method_handler( - _inconsistent_unary_unary) + _inconsistent_unary_unary + ) elif handler_call_details.method == _INVALID_CODE: return grpc.unary_unary_rpc_method_handler( - _invalid_code_unary_unary) + _invalid_code_unary_unary + ) else: return None -@unittest.skipIf(sys.version_info[0] < 3, - 'ProtoBuf descriptor has moved on from Python2') +@unittest.skipIf( + sys.version_info[0] < 3, "ProtoBuf descriptor has moved on from Python2" +) class StatusTest(unittest.TestCase): - def setUp(self): self._server = test_common.test_server() self._server.add_generic_rpc_handlers((_GenericHandler(),)) - port = self._server.add_insecure_port('[::]:0') + port = self._server.add_insecure_port("[::]:0") self._server.start() - self._channel = grpc.insecure_channel('localhost:%d' % port) + self._channel = 
grpc.insecure_channel("localhost:%d" % port) def tearDown(self): self._server.stop(None) @@ -154,14 +161,15 @@ def test_error_details(self): status = rpc_status.from_call(rpc_error) self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL) - self.assertEqual(status.code, code_pb2.Code.Value('INTERNAL')) + self.assertEqual(status.code, code_pb2.Code.Value("INTERNAL")) # Check if the underlying proto message is intact self.assertEqual( - status.details[0].Is(error_details_pb2.DebugInfo.DESCRIPTOR), True) + status.details[0].Is(error_details_pb2.DebugInfo.DESCRIPTOR), True + ) info = error_details_pb2.DebugInfo() status.details[0].Unpack(info) - self.assertIn('_error_details_unary_unary', info.stack_entries[-1]) + self.assertIn("_error_details_unary_unary", info.stack_entries[-1]) def test_code_message_validation(self): with self.assertRaises(grpc.RpcError) as exception_context: @@ -178,9 +186,9 @@ def test_invalid_code(self): rpc_error = exception_context.exception self.assertEqual(rpc_error.code(), grpc.StatusCode.UNKNOWN) # Invalid status code exception raised during coversion - self.assertIn('Invalid status code', rpc_error.details()) + self.assertIn("Invalid status code", rpc_error.details()) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/stress/client.py b/src/python/grpcio_tests/tests/stress/client.py index e2f57f552e2c3..a7cf40c0852a4 100644 --- a/src/python/grpcio_tests/tests/stress/client.py +++ b/src/python/grpcio_tests/tests/stress/client.py @@ -31,45 +31,58 @@ def _args(): parser = argparse.ArgumentParser( - description='gRPC Python stress test client') + description="gRPC Python stress test client" + ) parser.add_argument( - '--server_addresses', - help='comma separated list of hostname:port to run servers on', - default='localhost:8080', - type=str) + "--server_addresses", + help="comma separated list of hostname:port to run servers on", + default="localhost:8080", + type=str, + ) parser.add_argument( - '--test_cases', - help='comma separated list of testcase:weighting of tests to run', - default='large_unary:100', - type=str) - parser.add_argument('--test_duration_secs', - help='number of seconds to run the stress test', - default=-1, - type=int) - parser.add_argument('--num_channels_per_server', - help='number of channels per server', - default=1, - type=int) - parser.add_argument('--num_stubs_per_channel', - help='number of stubs to create per channel', - default=1, - type=int) - parser.add_argument('--metrics_port', - help='the port to listen for metrics requests on', - default=8081, - type=int) + "--test_cases", + help="comma separated list of testcase:weighting of tests to run", + default="large_unary:100", + type=str, + ) parser.add_argument( - '--use_test_ca', - help='Whether to use our fake CA. Requires --use_tls=true', + "--test_duration_secs", + help="number of seconds to run the stress test", + default=-1, + type=int, + ) + parser.add_argument( + "--num_channels_per_server", + help="number of channels per server", + default=1, + type=int, + ) + parser.add_argument( + "--num_stubs_per_channel", + help="number of stubs to create per channel", + default=1, + type=int, + ) + parser.add_argument( + "--metrics_port", + help="the port to listen for metrics requests on", + default=8081, + type=int, + ) + parser.add_argument( + "--use_test_ca", + help="Whether to use our fake CA. 
Requires --use_tls=true", default=False, - type=bool) - parser.add_argument('--use_tls', - help='Whether to use TLS', - default=False, - type=bool) - parser.add_argument('--server_host_override', - help='the server host to which to claim to connect', - type=str) + type=bool, + ) + parser.add_argument( + "--use_tls", help="Whether to use TLS", default=False, type=bool + ) + parser.add_argument( + "--server_host_override", + help="the server host to which to claim to connect", + type=str, + ) return parser.parse_args() @@ -78,13 +91,13 @@ def _test_case_from_arg(test_case_arg): if test_case_arg == test_case.value: return test_case else: - raise ValueError('No test case {}!'.format(test_case_arg)) + raise ValueError("No test case {}!".format(test_case_arg)) def _parse_weighted_test_cases(test_case_args): weighted_test_cases = {} - for test_case_arg in test_case_args.split(','): - name, weight = test_case_arg.split(':', 1) + for test_case_arg in test_case_args.split(","): + name, weight = test_case_arg.split(":", 1) test_case = _test_case_from_arg(name) weighted_test_cases[test_case] = int(weight) return weighted_test_cases @@ -97,14 +110,17 @@ def _get_channel(target, args): else: root_certificates = None # will load default roots. channel_credentials = grpc.ssl_channel_credentials( - root_certificates=root_certificates) - options = (( - 'grpc.ssl_target_name_override', - args.server_host_override, - ),) - channel = grpc.secure_channel(target, - channel_credentials, - options=options) + root_certificates=root_certificates + ) + options = ( + ( + "grpc.ssl_target_name_override", + args.server_host_override, + ), + ) + channel = grpc.secure_channel( + target, channel_credentials, options=options + ) else: channel = grpc.insecure_channel(target) @@ -115,7 +131,7 @@ def _get_channel(target, args): def run_test(args): test_cases = _parse_weighted_test_cases(args.test_cases) - test_server_targets = args.server_addresses.split(',') + test_server_targets = args.server_addresses.split(",") # Propagate any client exceptions with a queue exception_queue = queue.Queue() stop_event = threading.Event() @@ -124,8 +140,9 @@ def run_test(args): server = grpc.server(futures.ThreadPoolExecutor(max_workers=25)) metrics_pb2_grpc.add_MetricsServiceServicer_to_server( - metrics_server.MetricsServer(hist), server) - server.add_insecure_port('[::]:{}'.format(args.metrics_port)) + metrics_server.MetricsServer(hist), server + ) + server.add_insecure_port("[::]:{}".format(args.metrics_port)) server.start() for test_server_target in test_server_targets: @@ -133,8 +150,9 @@ def run_test(args): channel = _get_channel(test_server_target, args) for _ in range(args.num_stubs_per_channel): stub = test_pb2_grpc.TestServiceStub(channel) - runner = test_runner.TestRunner(stub, test_cases, hist, - exception_queue, stop_event) + runner = test_runner.TestRunner( + stub, test_cases, hist, exception_queue, stop_event + ) runners.append(runner) for runner in runners: @@ -155,5 +173,5 @@ def run_test(args): server.stop(None) -if __name__ == '__main__': +if __name__ == "__main__": run_test(_args()) diff --git a/src/python/grpcio_tests/tests/stress/metrics_server.py b/src/python/grpcio_tests/tests/stress/metrics_server.py index 33a74b4a38850..c090f2facea89 100644 --- a/src/python/grpcio_tests/tests/stress/metrics_server.py +++ b/src/python/grpcio_tests/tests/stress/metrics_server.py @@ -18,11 +18,10 @@ from src.proto.grpc.testing import metrics_pb2 from src.proto.grpc.testing import metrics_pb2_grpc -GAUGE_NAME = 'python_overall_qps' 
+GAUGE_NAME = "python_overall_qps" class MetricsServer(metrics_pb2_grpc.MetricsServiceServicer): - def __init__(self, histogram): self._start_time = time.time() self._histogram = histogram @@ -40,6 +39,6 @@ def GetAllGauges(self, request, context): def GetGauge(self, request, context): if request.name != GAUGE_NAME: - raise Exception('Gauge {} does not exist'.format(request.name)) + raise Exception("Gauge {} does not exist".format(request.name)) qps = self._get_qps() return metrics_pb2.GaugeResponse(name=GAUGE_NAME, long_value=qps) diff --git a/src/python/grpcio_tests/tests/stress/test_runner.py b/src/python/grpcio_tests/tests/stress/test_runner.py index 1b6003fc698ec..c0c2379cb9e83 100644 --- a/src/python/grpcio_tests/tests/stress/test_runner.py +++ b/src/python/grpcio_tests/tests/stress/test_runner.py @@ -33,7 +33,6 @@ def _weighted_test_case_generator(weighted_cases): class TestRunner(threading.Thread): - def __init__(self, stub, test_cases, hist, exception_queue, stop_event): super(TestRunner, self).__init__() self._exception_queue = exception_queue @@ -55,4 +54,8 @@ def run(self): self._exception_queue.put( Exception( "An exception occurred during test {}".format( - test_case), e)) + test_case + ), + e, + ) + ) diff --git a/src/python/grpcio_tests/tests/stress/unary_stream_benchmark.py b/src/python/grpcio_tests/tests/stress/unary_stream_benchmark.py index 21d7e6c60891b..2caaa90e561a6 100644 --- a/src/python/grpcio_tests/tests/stress/unary_stream_benchmark.py +++ b/src/python/grpcio_tests/tests/stress/unary_stream_benchmark.py @@ -26,7 +26,8 @@ _MESSAGE_SIZE = 4 _RESPONSE_COUNT = 32 * 1024 -_SERVER_CODE = """ +_SERVER_CODE = ( + """ import datetime import threading import grpc @@ -47,24 +48,29 @@ def Benchmark(self, request, context): unary_stream_benchmark_pb2_grpc.add_UnaryStreamBenchmarkServiceServicer_to_server(Handler(), server) server.start() server.wait_for_termination() -""" % _PORT +""" + % _PORT +) try: - from src.python.grpcio_tests.tests.stress import \ - unary_stream_benchmark_pb2_grpc + from src.python.grpcio_tests.tests.stress import ( + unary_stream_benchmark_pb2_grpc, + ) from src.python.grpcio_tests.tests.stress import unary_stream_benchmark_pb2 _GRPC_CHANNEL_OPTIONS = [ - ('grpc.max_metadata_size', 16 * 1024 * 1024), - ('grpc.max_receive_message_length', 64 * 1024 * 1024), + ("grpc.max_metadata_size", 16 * 1024 * 1024), + ("grpc.max_receive_message_length", 64 * 1024 * 1024), (grpc.experimental.ChannelOptions.SingleThreadedUnaryStream, 1), ] @contextlib.contextmanager def _running_server(): - server_process = subprocess.Popen([sys.executable, '-c', _SERVER_CODE], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) + server_process = subprocess.Popen( + [sys.executable, "-c", _SERVER_CODE], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) try: yield finally: @@ -77,11 +83,16 @@ def _running_server(): def profile(message_size, response_count): request = unary_stream_benchmark_pb2.BenchmarkRequest( - message_size=message_size, response_count=response_count) - with grpc.insecure_channel('[::]:{}'.format(_PORT), - options=_GRPC_CHANNEL_OPTIONS) as channel: - stub = unary_stream_benchmark_pb2_grpc.UnaryStreamBenchmarkServiceStub( - channel) + message_size=message_size, response_count=response_count + ) + with grpc.insecure_channel( + "[::]:{}".format(_PORT), options=_GRPC_CHANNEL_OPTIONS + ) as channel: + stub = ( + unary_stream_benchmark_pb2_grpc.UnaryStreamBenchmarkServiceStub( + channel + ) + ) start = datetime.datetime.now() call = stub.Benchmark(request, 
wait_for_ready=True) for message in call: @@ -96,7 +107,7 @@ def main(): sys.stdout.write("{}\n".format(latency.total_seconds())) sys.stdout.flush() - if __name__ == '__main__': + if __name__ == "__main__": main() except ImportError: diff --git a/src/python/grpcio_tests/tests/testing/_application_common.py b/src/python/grpcio_tests/tests/testing/_application_common.py index 3226d1fb020b8..3a74567b92f6d 100644 --- a/src/python/grpcio_tests/tests/testing/_application_common.py +++ b/src/python/grpcio_tests/tests/testing/_application_common.py @@ -16,11 +16,11 @@ from tests.testing.proto import requests_pb2 from tests.testing.proto import services_pb2 -SERVICE_NAME = 'tests_of_grpc_testing.FirstService' -UNARY_UNARY_METHOD_NAME = 'UnUn' -UNARY_STREAM_METHOD_NAME = 'UnStre' -STREAM_UNARY_METHOD_NAME = 'StreUn' -STREAM_STREAM_METHOD_NAME = 'StreStre' +SERVICE_NAME = "tests_of_grpc_testing.FirstService" +UNARY_UNARY_METHOD_NAME = "UnUn" +UNARY_STREAM_METHOD_NAME = "UnStre" +STREAM_UNARY_METHOD_NAME = "StreUn" +STREAM_STREAM_METHOD_NAME = "StreStre" UNARY_UNARY_REQUEST = requests_pb2.Up(first_up_field=2) ERRONEOUS_UNARY_UNARY_REQUEST = requests_pb2.Up(first_up_field=3) diff --git a/src/python/grpcio_tests/tests/testing/_application_testing_common.py b/src/python/grpcio_tests/tests/testing/_application_testing_common.py index cac813c04d862..c44eb0b10d1d4 100644 --- a/src/python/grpcio_tests/tests/testing/_application_testing_common.py +++ b/src/python/grpcio_tests/tests/testing/_application_testing_common.py @@ -19,15 +19,15 @@ # TODO(https://github.com/grpc/grpc/issues/11657): Eliminate this entirely. # TODO(https://github.com/protocolbuffers/protobuf/issues/3452): Eliminate this if/else. -if services_pb2.DESCRIPTOR.services_by_name.get('FirstService') is None: - FIRST_SERVICE = 'Fix protobuf issue 3452!' - FIRST_SERVICE_UNUN = 'Fix protobuf issue 3452!' - FIRST_SERVICE_UNSTRE = 'Fix protobuf issue 3452!' - FIRST_SERVICE_STREUN = 'Fix protobuf issue 3452!' - FIRST_SERVICE_STRESTRE = 'Fix protobuf issue 3452!' +if services_pb2.DESCRIPTOR.services_by_name.get("FirstService") is None: + FIRST_SERVICE = "Fix protobuf issue 3452!" + FIRST_SERVICE_UNUN = "Fix protobuf issue 3452!" + FIRST_SERVICE_UNSTRE = "Fix protobuf issue 3452!" + FIRST_SERVICE_STREUN = "Fix protobuf issue 3452!" + FIRST_SERVICE_STRESTRE = "Fix protobuf issue 3452!" 
else: - FIRST_SERVICE = services_pb2.DESCRIPTOR.services_by_name['FirstService'] - FIRST_SERVICE_UNUN = FIRST_SERVICE.methods_by_name['UnUn'] - FIRST_SERVICE_UNSTRE = FIRST_SERVICE.methods_by_name['UnStre'] - FIRST_SERVICE_STREUN = FIRST_SERVICE.methods_by_name['StreUn'] - FIRST_SERVICE_STRESTRE = FIRST_SERVICE.methods_by_name['StreStre'] + FIRST_SERVICE = services_pb2.DESCRIPTOR.services_by_name["FirstService"] + FIRST_SERVICE_UNUN = FIRST_SERVICE.methods_by_name["UnUn"] + FIRST_SERVICE_UNSTRE = FIRST_SERVICE.methods_by_name["UnStre"] + FIRST_SERVICE_STREUN = FIRST_SERVICE.methods_by_name["StreUn"] + FIRST_SERVICE_STRESTRE = FIRST_SERVICE.methods_by_name["StreStre"] diff --git a/src/python/grpcio_tests/tests/testing/_client_application.py b/src/python/grpcio_tests/tests/testing/_client_application.py index 548ed30c931d5..2e713168a086b 100644 --- a/src/python/grpcio_tests/tests/testing/_client_application.py +++ b/src/python/grpcio_tests/tests/testing/_client_application.py @@ -29,18 +29,18 @@ @enum.unique class Scenario(enum.Enum): - UNARY_UNARY = 'unary unary' - UNARY_STREAM = 'unary stream' - STREAM_UNARY = 'stream unary' - STREAM_STREAM = 'stream stream' - CONCURRENT_STREAM_UNARY = 'concurrent stream unary' - CONCURRENT_STREAM_STREAM = 'concurrent stream stream' - CANCEL_UNARY_UNARY = 'cancel unary unary' - CANCEL_UNARY_STREAM = 'cancel unary stream' - INFINITE_REQUEST_STREAM = 'infinite request stream' - - -class Outcome(collections.namedtuple('Outcome', ('kind', 'code', 'details'))): + UNARY_UNARY = "unary unary" + UNARY_STREAM = "unary stream" + STREAM_UNARY = "stream unary" + STREAM_STREAM = "stream stream" + CONCURRENT_STREAM_UNARY = "concurrent stream unary" + CONCURRENT_STREAM_STREAM = "concurrent stream stream" + CANCEL_UNARY_UNARY = "cancel unary unary" + CANCEL_UNARY_STREAM = "cancel unary stream" + INFINITE_REQUEST_STREAM = "infinite request stream" + + +class Outcome(collections.namedtuple("Outcome", ("kind", "code", "details"))): """Outcome of a client application scenario. 
Attributes: @@ -51,9 +51,9 @@ class Outcome(collections.namedtuple('Outcome', ('kind', 'code', 'details'))): @enum.unique class Kind(enum.Enum): - SATISFACTORY = 'satisfactory' - UNSATISFACTORY = 'unsatisfactory' - RPC_ERROR = 'rpc error' + SATISFACTORY = "satisfactory" + UNSATISFACTORY = "unsatisfactory" + RPC_ERROR = "rpc error" _SATISFACTORY_OUTCOME = Outcome(Outcome.Kind.SATISFACTORY, None, None) @@ -61,7 +61,6 @@ class Kind(enum.Enum): class _Pipe(object): - def __init__(self): self._condition = threading.Condition() self._values = [] @@ -117,9 +116,12 @@ def _run_unary_stream(stub): def _run_stream_unary(stub): response, call = stub.StreUn.with_call( - iter((_application_common.STREAM_UNARY_REQUEST,) * 3)) - if (_application_common.STREAM_UNARY_RESPONSE == response and - call.code() is grpc.StatusCode.OK): + iter((_application_common.STREAM_UNARY_REQUEST,) * 3) + ) + if ( + _application_common.STREAM_UNARY_RESPONSE == response + and call.code() is grpc.StatusCode.OK + ): return _SATISFACTORY_OUTCOME else: return _UNSATISFACTORY_OUTCOME @@ -139,9 +141,11 @@ def _run_stream_stream(stub): unexpected_extra_response = False else: unexpected_extra_response = True - if (first_responses == _application_common.TWO_STREAM_STREAM_RESPONSES and - second_responses == _application_common.TWO_STREAM_STREAM_RESPONSES - and not unexpected_extra_response): + if ( + first_responses == _application_common.TWO_STREAM_STREAM_RESPONSES + and second_responses == _application_common.TWO_STREAM_STREAM_RESPONSES + and not unexpected_extra_response + ): return _SATISFACTORY_OUTCOME else: return _UNSATISFACTORY_OUTCOME @@ -149,9 +153,11 @@ def _run_stream_stream(stub): def _run_concurrent_stream_unary(stub): future_calls = tuple( - stub.StreUn.future(iter((_application_common.STREAM_UNARY_REQUEST,) * - 3)) - for _ in range(test_constants.THREAD_CONCURRENCY)) + stub.StreUn.future( + iter((_application_common.STREAM_UNARY_REQUEST,) * 3) + ) + for _ in range(test_constants.THREAD_CONCURRENCY) + ) for future_call in future_calls: if future_call.code() is grpc.StatusCode.OK: response = future_call.result() @@ -190,7 +196,8 @@ def run_stream_stream(index): def _run_cancel_unary_unary(stub): response_future_call = stub.UnUn.future( - _application_common.UNARY_UNARY_REQUEST) + _application_common.UNARY_UNARY_REQUEST + ) initial_metadata = response_future_call.initial_metadata() cancelled = response_future_call.cancel() if initial_metadata is not None and cancelled: @@ -200,14 +207,14 @@ def _run_cancel_unary_unary(stub): def _run_infinite_request_stream(stub): - def infinite_request_iterator(): while True: yield _application_common.STREAM_UNARY_REQUEST response_future_call = stub.StreUn.future( infinite_request_iterator(), - timeout=_application_common.INFINITE_REQUEST_STREAM_TIMEOUT) + timeout=_application_common.INFINITE_REQUEST_STREAM_TIMEOUT, + ) if response_future_call.code() is grpc.StatusCode.DEADLINE_EXCEEDED: return _SATISFACTORY_OUTCOME else: @@ -231,5 +238,6 @@ def run(scenario, channel): try: return _IMPLEMENTATIONS[scenario](stub) except grpc.RpcError as rpc_error: - return Outcome(Outcome.Kind.RPC_ERROR, rpc_error.code(), - rpc_error.details()) + return Outcome( + Outcome.Kind.RPC_ERROR, rpc_error.code(), rpc_error.details() + ) diff --git a/src/python/grpcio_tests/tests/testing/_client_test.py b/src/python/grpcio_tests/tests/testing/_client_test.py index d92a1a9733c34..279f89dbd0cad 100644 --- a/src/python/grpcio_tests/tests/testing/_client_test.py +++ 
b/src/python/grpcio_tests/tests/testing/_client_test.py @@ -30,10 +30,10 @@ # TODO(https://github.com/protocolbuffers/protobuf/issues/3452): Drop this skip. @unittest.skipIf( - services_pb2.DESCRIPTOR.services_by_name.get('FirstService') is None, - 'Fix protobuf issue 3452!') + services_pb2.DESCRIPTOR.services_by_name.get("FirstService") is None, + "Fix protobuf issue 3452!", +) class ClientTest(unittest.TestCase): - def setUp(self): # In this test the client-side application under test executes in # a separate thread while we retain use of the test thread to "play @@ -43,74 +43,108 @@ def setUp(self): self._fake_time = grpc_testing.strict_fake_time(time.time()) self._real_time = grpc_testing.strict_real_time() self._fake_time_channel = grpc_testing.channel( - services_pb2.DESCRIPTOR.services_by_name.values(), self._fake_time) + services_pb2.DESCRIPTOR.services_by_name.values(), self._fake_time + ) self._real_time_channel = grpc_testing.channel( - services_pb2.DESCRIPTOR.services_by_name.values(), self._real_time) + services_pb2.DESCRIPTOR.services_by_name.values(), self._real_time + ) def tearDown(self): self._client_execution_thread_pool.shutdown(wait=True) def test_successful_unary_unary(self): application_future = self._client_execution_thread_pool.submit( - _client_application.run, _client_application.Scenario.UNARY_UNARY, - self._real_time_channel) - invocation_metadata, request, rpc = ( - self._real_time_channel.take_unary_unary( - _application_testing_common.FIRST_SERVICE_UNUN)) + _client_application.run, + _client_application.Scenario.UNARY_UNARY, + self._real_time_channel, + ) + ( + invocation_metadata, + request, + rpc, + ) = self._real_time_channel.take_unary_unary( + _application_testing_common.FIRST_SERVICE_UNUN + ) rpc.send_initial_metadata(()) - rpc.terminate(_application_common.UNARY_UNARY_RESPONSE, (), - grpc.StatusCode.OK, '') + rpc.terminate( + _application_common.UNARY_UNARY_RESPONSE, (), grpc.StatusCode.OK, "" + ) application_return_value = application_future.result() self.assertEqual(_application_common.UNARY_UNARY_REQUEST, request) - self.assertIs(application_return_value.kind, - _client_application.Outcome.Kind.SATISFACTORY) + self.assertIs( + application_return_value.kind, + _client_application.Outcome.Kind.SATISFACTORY, + ) def test_successful_unary_stream(self): application_future = self._client_execution_thread_pool.submit( - _client_application.run, _client_application.Scenario.UNARY_STREAM, - self._fake_time_channel) - invocation_metadata, request, rpc = ( - self._fake_time_channel.take_unary_stream( - _application_testing_common.FIRST_SERVICE_UNSTRE)) + _client_application.run, + _client_application.Scenario.UNARY_STREAM, + self._fake_time_channel, + ) + ( + invocation_metadata, + request, + rpc, + ) = self._fake_time_channel.take_unary_stream( + _application_testing_common.FIRST_SERVICE_UNSTRE + ) rpc.send_initial_metadata(()) - rpc.terminate((), grpc.StatusCode.OK, '') + rpc.terminate((), grpc.StatusCode.OK, "") application_return_value = application_future.result() self.assertEqual(_application_common.UNARY_STREAM_REQUEST, request) - self.assertIs(application_return_value.kind, - _client_application.Outcome.Kind.SATISFACTORY) + self.assertIs( + application_return_value.kind, + _client_application.Outcome.Kind.SATISFACTORY, + ) def test_successful_stream_unary(self): application_future = self._client_execution_thread_pool.submit( - _client_application.run, _client_application.Scenario.STREAM_UNARY, - self._real_time_channel) + _client_application.run, + 
_client_application.Scenario.STREAM_UNARY, + self._real_time_channel, + ) invocation_metadata, rpc = self._real_time_channel.take_stream_unary( - _application_testing_common.FIRST_SERVICE_STREUN) + _application_testing_common.FIRST_SERVICE_STREUN + ) rpc.send_initial_metadata(()) first_request = rpc.take_request() second_request = rpc.take_request() third_request = rpc.take_request() rpc.requests_closed() - rpc.terminate(_application_common.STREAM_UNARY_RESPONSE, (), - grpc.StatusCode.OK, '') + rpc.terminate( + _application_common.STREAM_UNARY_RESPONSE, + (), + grpc.StatusCode.OK, + "", + ) application_return_value = application_future.result() - self.assertEqual(_application_common.STREAM_UNARY_REQUEST, - first_request) - self.assertEqual(_application_common.STREAM_UNARY_REQUEST, - second_request) - self.assertEqual(_application_common.STREAM_UNARY_REQUEST, - third_request) - self.assertIs(application_return_value.kind, - _client_application.Outcome.Kind.SATISFACTORY) + self.assertEqual( + _application_common.STREAM_UNARY_REQUEST, first_request + ) + self.assertEqual( + _application_common.STREAM_UNARY_REQUEST, second_request + ) + self.assertEqual( + _application_common.STREAM_UNARY_REQUEST, third_request + ) + self.assertIs( + application_return_value.kind, + _client_application.Outcome.Kind.SATISFACTORY, + ) def test_successful_stream_stream(self): application_future = self._client_execution_thread_pool.submit( - _client_application.run, _client_application.Scenario.STREAM_STREAM, - self._fake_time_channel) + _client_application.run, + _client_application.Scenario.STREAM_STREAM, + self._fake_time_channel, + ) invocation_metadata, rpc = self._fake_time_channel.take_stream_stream( - _application_testing_common.FIRST_SERVICE_STRESTRE) + _application_testing_common.FIRST_SERVICE_STRESTRE + ) first_request = rpc.take_request() rpc.send_response(_application_common.STREAM_STREAM_RESPONSE) rpc.send_response(_application_common.STREAM_STREAM_RESPONSE) @@ -118,26 +152,34 @@ def test_successful_stream_stream(self): rpc.send_response(_application_common.STREAM_STREAM_RESPONSE) rpc.send_response(_application_common.STREAM_STREAM_RESPONSE) rpc.requests_closed() - rpc.terminate((), grpc.StatusCode.OK, '') + rpc.terminate((), grpc.StatusCode.OK, "") application_return_value = application_future.result() - self.assertEqual(_application_common.STREAM_STREAM_REQUEST, - first_request) - self.assertEqual(_application_common.STREAM_STREAM_REQUEST, - second_request) - self.assertIs(application_return_value.kind, - _client_application.Outcome.Kind.SATISFACTORY) + self.assertEqual( + _application_common.STREAM_STREAM_REQUEST, first_request + ) + self.assertEqual( + _application_common.STREAM_STREAM_REQUEST, second_request + ) + self.assertIs( + application_return_value.kind, + _client_application.Outcome.Kind.SATISFACTORY, + ) def test_concurrent_stream_stream(self): application_future = self._client_execution_thread_pool.submit( _client_application.run, _client_application.Scenario.CONCURRENT_STREAM_STREAM, - self._real_time_channel) + self._real_time_channel, + ) rpcs = [] for _ in range(test_constants.RPC_CONCURRENCY): - invocation_metadata, rpc = ( - self._real_time_channel.take_stream_stream( - _application_testing_common.FIRST_SERVICE_STRESTRE)) + ( + invocation_metadata, + rpc, + ) = self._real_time_channel.take_stream_stream( + _application_testing_common.FIRST_SERVICE_STRESTRE + ) rpcs.append(rpc) requests = {} for rpc in rpcs: @@ -153,70 +195,99 @@ def test_concurrent_stream_stream(self): for rpc in 
rpcs: rpc.requests_closed() for rpc in rpcs: - rpc.terminate((), grpc.StatusCode.OK, '') + rpc.terminate((), grpc.StatusCode.OK, "") application_return_value = application_future.result() for requests_of_one_rpc in requests.values(): for request in requests_of_one_rpc: - self.assertEqual(_application_common.STREAM_STREAM_REQUEST, - request) - self.assertIs(application_return_value.kind, - _client_application.Outcome.Kind.SATISFACTORY) + self.assertEqual( + _application_common.STREAM_STREAM_REQUEST, request + ) + self.assertIs( + application_return_value.kind, + _client_application.Outcome.Kind.SATISFACTORY, + ) def test_cancelled_unary_unary(self): application_future = self._client_execution_thread_pool.submit( _client_application.run, _client_application.Scenario.CANCEL_UNARY_UNARY, - self._fake_time_channel) - invocation_metadata, request, rpc = ( - self._fake_time_channel.take_unary_unary( - _application_testing_common.FIRST_SERVICE_UNUN)) + self._fake_time_channel, + ) + ( + invocation_metadata, + request, + rpc, + ) = self._fake_time_channel.take_unary_unary( + _application_testing_common.FIRST_SERVICE_UNUN + ) rpc.send_initial_metadata(()) rpc.cancelled() application_return_value = application_future.result() self.assertEqual(_application_common.UNARY_UNARY_REQUEST, request) - self.assertIs(application_return_value.kind, - _client_application.Outcome.Kind.SATISFACTORY) + self.assertIs( + application_return_value.kind, + _client_application.Outcome.Kind.SATISFACTORY, + ) def test_status_stream_unary(self): application_future = self._client_execution_thread_pool.submit( _client_application.run, _client_application.Scenario.CONCURRENT_STREAM_UNARY, - self._fake_time_channel) + self._fake_time_channel, + ) rpcs = tuple( self._fake_time_channel.take_stream_unary( - _application_testing_common.FIRST_SERVICE_STREUN)[1] - for _ in range(test_constants.THREAD_CONCURRENCY)) + _application_testing_common.FIRST_SERVICE_STREUN + )[1] + for _ in range(test_constants.THREAD_CONCURRENCY) + ) for rpc in rpcs: rpc.take_request() rpc.take_request() rpc.take_request() rpc.requests_closed() - rpc.send_initial_metadata((( - 'my_metadata_key', - 'My Metadata Value!', - ),)) + rpc.send_initial_metadata( + ( + ( + "my_metadata_key", + "My Metadata Value!", + ), + ) + ) for rpc in rpcs[:-1]: - rpc.terminate(_application_common.STREAM_UNARY_RESPONSE, (), - grpc.StatusCode.OK, '') - rpcs[-1].terminate(_application_common.STREAM_UNARY_RESPONSE, (), - grpc.StatusCode.RESOURCE_EXHAUSTED, - 'nope; not able to handle all those RPCs!') + rpc.terminate( + _application_common.STREAM_UNARY_RESPONSE, + (), + grpc.StatusCode.OK, + "", + ) + rpcs[-1].terminate( + _application_common.STREAM_UNARY_RESPONSE, + (), + grpc.StatusCode.RESOURCE_EXHAUSTED, + "nope; not able to handle all those RPCs!", + ) application_return_value = application_future.result() - self.assertIs(application_return_value.kind, - _client_application.Outcome.Kind.UNSATISFACTORY) + self.assertIs( + application_return_value.kind, + _client_application.Outcome.Kind.UNSATISFACTORY, + ) def test_status_stream_stream(self): code = grpc.StatusCode.DEADLINE_EXCEEDED - details = 'test deadline exceeded!' + details = "test deadline exceeded!" 
application_future = self._client_execution_thread_pool.submit( - _client_application.run, _client_application.Scenario.STREAM_STREAM, - self._real_time_channel) + _client_application.run, + _client_application.Scenario.STREAM_STREAM, + self._real_time_channel, + ) invocation_metadata, rpc = self._real_time_channel.take_stream_stream( - _application_testing_common.FIRST_SERVICE_STRESTRE) + _application_testing_common.FIRST_SERVICE_STRESTRE + ) first_request = rpc.take_request() rpc.send_response(_application_common.STREAM_STREAM_RESPONSE) rpc.send_response(_application_common.STREAM_STREAM_RESPONSE) @@ -227,37 +298,56 @@ def test_status_stream_stream(self): rpc.terminate((), code, details) application_return_value = application_future.result() - self.assertEqual(_application_common.STREAM_STREAM_REQUEST, - first_request) - self.assertEqual(_application_common.STREAM_STREAM_REQUEST, - second_request) - self.assertIs(application_return_value.kind, - _client_application.Outcome.Kind.RPC_ERROR) + self.assertEqual( + _application_common.STREAM_STREAM_REQUEST, first_request + ) + self.assertEqual( + _application_common.STREAM_STREAM_REQUEST, second_request + ) + self.assertIs( + application_return_value.kind, + _client_application.Outcome.Kind.RPC_ERROR, + ) self.assertIs(application_return_value.code, code) self.assertEqual(application_return_value.details, details) def test_misbehaving_server_unary_unary(self): application_future = self._client_execution_thread_pool.submit( - _client_application.run, _client_application.Scenario.UNARY_UNARY, - self._fake_time_channel) - invocation_metadata, request, rpc = ( - self._fake_time_channel.take_unary_unary( - _application_testing_common.FIRST_SERVICE_UNUN)) + _client_application.run, + _client_application.Scenario.UNARY_UNARY, + self._fake_time_channel, + ) + ( + invocation_metadata, + request, + rpc, + ) = self._fake_time_channel.take_unary_unary( + _application_testing_common.FIRST_SERVICE_UNUN + ) rpc.send_initial_metadata(()) - rpc.terminate(_application_common.ERRONEOUS_UNARY_UNARY_RESPONSE, (), - grpc.StatusCode.OK, '') + rpc.terminate( + _application_common.ERRONEOUS_UNARY_UNARY_RESPONSE, + (), + grpc.StatusCode.OK, + "", + ) application_return_value = application_future.result() self.assertEqual(_application_common.UNARY_UNARY_REQUEST, request) - self.assertIs(application_return_value.kind, - _client_application.Outcome.Kind.UNSATISFACTORY) + self.assertIs( + application_return_value.kind, + _client_application.Outcome.Kind.UNSATISFACTORY, + ) def test_misbehaving_server_stream_stream(self): application_future = self._client_execution_thread_pool.submit( - _client_application.run, _client_application.Scenario.STREAM_STREAM, - self._real_time_channel) + _client_application.run, + _client_application.Scenario.STREAM_STREAM, + self._real_time_channel, + ) invocation_metadata, rpc = self._real_time_channel.take_stream_stream( - _application_testing_common.FIRST_SERVICE_STRESTRE) + _application_testing_common.FIRST_SERVICE_STRESTRE + ) first_request = rpc.take_request() rpc.send_response(_application_common.STREAM_STREAM_RESPONSE) rpc.send_response(_application_common.STREAM_STREAM_RESPONSE) @@ -267,42 +357,58 @@ def test_misbehaving_server_stream_stream(self): rpc.send_response(_application_common.STREAM_STREAM_RESPONSE) rpc.send_response(_application_common.STREAM_STREAM_RESPONSE) rpc.requests_closed() - rpc.terminate((), grpc.StatusCode.OK, '') + rpc.terminate((), grpc.StatusCode.OK, "") application_return_value = application_future.result() - 
self.assertEqual(_application_common.STREAM_STREAM_REQUEST, - first_request) - self.assertEqual(_application_common.STREAM_STREAM_REQUEST, - second_request) - self.assertIs(application_return_value.kind, - _client_application.Outcome.Kind.UNSATISFACTORY) + self.assertEqual( + _application_common.STREAM_STREAM_REQUEST, first_request + ) + self.assertEqual( + _application_common.STREAM_STREAM_REQUEST, second_request + ) + self.assertIs( + application_return_value.kind, + _client_application.Outcome.Kind.UNSATISFACTORY, + ) def test_infinite_request_stream_real_time(self): application_future = self._client_execution_thread_pool.submit( _client_application.run, _client_application.Scenario.INFINITE_REQUEST_STREAM, - self._real_time_channel) + self._real_time_channel, + ) invocation_metadata, rpc = self._real_time_channel.take_stream_unary( - _application_testing_common.FIRST_SERVICE_STREUN) + _application_testing_common.FIRST_SERVICE_STREUN + ) rpc.send_initial_metadata(()) first_request = rpc.take_request() second_request = rpc.take_request() third_request = rpc.take_request() self._real_time.sleep_for( - _application_common.INFINITE_REQUEST_STREAM_TIMEOUT) - rpc.terminate(_application_common.STREAM_UNARY_RESPONSE, (), - grpc.StatusCode.DEADLINE_EXCEEDED, '') + _application_common.INFINITE_REQUEST_STREAM_TIMEOUT + ) + rpc.terminate( + _application_common.STREAM_UNARY_RESPONSE, + (), + grpc.StatusCode.DEADLINE_EXCEEDED, + "", + ) application_return_value = application_future.result() - self.assertEqual(_application_common.STREAM_UNARY_REQUEST, - first_request) - self.assertEqual(_application_common.STREAM_UNARY_REQUEST, - second_request) - self.assertEqual(_application_common.STREAM_UNARY_REQUEST, - third_request) - self.assertIs(application_return_value.kind, - _client_application.Outcome.Kind.SATISFACTORY) - - -if __name__ == '__main__': + self.assertEqual( + _application_common.STREAM_UNARY_REQUEST, first_request + ) + self.assertEqual( + _application_common.STREAM_UNARY_REQUEST, second_request + ) + self.assertEqual( + _application_common.STREAM_UNARY_REQUEST, third_request + ) + self.assertIs( + application_return_value.kind, + _client_application.Outcome.Kind.SATISFACTORY, + ) + + +if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/testing/_server_application.py b/src/python/grpcio_tests/tests/testing/_server_application.py index ffab6f5b933bb..90d51dd1640b8 100644 --- a/src/python/grpcio_tests/tests/testing/_server_application.py +++ b/src/python/grpcio_tests/tests/testing/_server_application.py @@ -37,37 +37,47 @@ def UnUn(self, request, context): elif request == _application_common.ABORT_REQUEST: with self._abort_lock: try: - context.abort(grpc.StatusCode.PERMISSION_DENIED, - "Denying permission to test abort.") + context.abort( + grpc.StatusCode.PERMISSION_DENIED, + "Denying permission to test abort.", + ) except Exception as e: # pylint: disable=broad-except - self._abort_response = _application_common.ABORT_SUCCESS_RESPONSE + self._abort_response = ( + _application_common.ABORT_SUCCESS_RESPONSE + ) else: - self._abort_status = _application_common.ABORT_FAILURE_RESPONSE + self._abort_status = ( + _application_common.ABORT_FAILURE_RESPONSE + ) return None # NOTE: For the linter. 
elif request == _application_common.ABORT_SUCCESS_QUERY: with self._abort_lock: return self._abort_response else: context.set_code(grpc.StatusCode.INVALID_ARGUMENT) - context.set_details('Something is wrong with your request!') + context.set_details("Something is wrong with your request!") return services_pb2.Down() def UnStre(self, request, context): if _application_common.UNARY_STREAM_REQUEST != request: context.set_code(grpc.StatusCode.INVALID_ARGUMENT) - context.set_details('Something is wrong with your request!') + context.set_details("Something is wrong with your request!") return yield services_pb2.Strange() # pylint: disable=unreachable def StreUn(self, request_iterator, context): - context.send_initial_metadata((( - 'server_application_metadata_key', - 'Hi there!', - ),)) + context.send_initial_metadata( + ( + ( + "server_application_metadata_key", + "Hi there!", + ), + ) + ) for request in request_iterator: if request != _application_common.STREAM_UNARY_REQUEST: context.set_code(grpc.StatusCode.INVALID_ARGUMENT) - context.set_details('Something is wrong with your request!') + context.set_details("Something is wrong with your request!") return services_pb2.Strange() elif not context.is_active(): return services_pb2.Strange() @@ -75,12 +85,14 @@ def StreUn(self, request_iterator, context): return _application_common.STREAM_UNARY_RESPONSE def StreStre(self, request_iterator, context): - valid_requests = (_application_common.STREAM_STREAM_REQUEST, - _application_common.STREAM_STREAM_MUTATING_REQUEST) + valid_requests = ( + _application_common.STREAM_STREAM_REQUEST, + _application_common.STREAM_STREAM_MUTATING_REQUEST, + ) for request in request_iterator: if request not in valid_requests: context.set_code(grpc.StatusCode.INVALID_ARGUMENT) - context.set_details('Something is wrong with your request!') + context.set_details("Something is wrong with your request!") return elif not context.is_active(): return @@ -90,6 +102,7 @@ def StreStre(self, request_iterator, context): elif request == _application_common.STREAM_STREAM_MUTATING_REQUEST: response = services_pb2.Bottom() for i in range( - _application_common.STREAM_STREAM_MUTATING_COUNT): + _application_common.STREAM_STREAM_MUTATING_COUNT + ): response.first_bottom_field = i yield response diff --git a/src/python/grpcio_tests/tests/testing/_server_test.py b/src/python/grpcio_tests/tests/testing/_server_test.py index 617a41b7e54de..5206950707fca 100644 --- a/src/python/grpcio_tests/tests/testing/_server_test.py +++ b/src/python/grpcio_tests/tests/testing/_server_test.py @@ -25,7 +25,6 @@ class FirstServiceServicerTest(unittest.TestCase): - def setUp(self): self._real_time = grpc_testing.strict_real_time() self._fake_time = grpc_testing.strict_fake_time(time.time()) @@ -34,14 +33,19 @@ def setUp(self): _application_testing_common.FIRST_SERVICE: servicer } self._real_time_server = grpc_testing.server_from_dictionary( - descriptors_to_servicers, self._real_time) + descriptors_to_servicers, self._real_time + ) self._fake_time_server = grpc_testing.server_from_dictionary( - descriptors_to_servicers, self._fake_time) + descriptors_to_servicers, self._fake_time + ) def test_successful_unary_unary(self): rpc = self._real_time_server.invoke_unary_unary( - _application_testing_common.FIRST_SERVICE_UNUN, (), - _application_common.UNARY_UNARY_REQUEST, None) + _application_testing_common.FIRST_SERVICE_UNUN, + (), + _application_common.UNARY_UNARY_REQUEST, + None, + ) initial_metadata = rpc.initial_metadata() response, trailing_metadata, code, details 
= rpc.termination() @@ -50,8 +54,11 @@ def test_successful_unary_unary(self): def test_successful_unary_stream(self): rpc = self._real_time_server.invoke_unary_stream( - _application_testing_common.FIRST_SERVICE_UNSTRE, (), - _application_common.UNARY_STREAM_REQUEST, None) + _application_testing_common.FIRST_SERVICE_UNSTRE, + (), + _application_common.UNARY_STREAM_REQUEST, + None, + ) initial_metadata = rpc.initial_metadata() trailing_metadata, code, details = rpc.termination() @@ -59,7 +66,8 @@ def test_successful_unary_stream(self): def test_successful_stream_unary(self): rpc = self._real_time_server.invoke_stream_unary( - _application_testing_common.FIRST_SERVICE_STREUN, (), None) + _application_testing_common.FIRST_SERVICE_STREUN, (), None + ) rpc.send_request(_application_common.STREAM_UNARY_REQUEST) rpc.send_request(_application_common.STREAM_UNARY_REQUEST) rpc.send_request(_application_common.STREAM_UNARY_REQUEST) @@ -72,7 +80,8 @@ def test_successful_stream_unary(self): def test_successful_stream_stream(self): rpc = self._real_time_server.invoke_stream_stream( - _application_testing_common.FIRST_SERVICE_STRESTRE, (), None) + _application_testing_common.FIRST_SERVICE_STRESTRE, (), None + ) rpc.send_request(_application_common.STREAM_STREAM_REQUEST) initial_metadata = rpc.initial_metadata() responses = [ @@ -81,23 +90,27 @@ def test_successful_stream_stream(self): ] rpc.send_request(_application_common.STREAM_STREAM_REQUEST) rpc.send_request(_application_common.STREAM_STREAM_REQUEST) - responses.extend([ - rpc.take_response(), - rpc.take_response(), - rpc.take_response(), - rpc.take_response(), - ]) + responses.extend( + [ + rpc.take_response(), + rpc.take_response(), + rpc.take_response(), + rpc.take_response(), + ] + ) rpc.requests_closed() trailing_metadata, code, details = rpc.termination() for response in responses: - self.assertEqual(_application_common.STREAM_STREAM_RESPONSE, - response) + self.assertEqual( + _application_common.STREAM_STREAM_RESPONSE, response + ) self.assertIs(code, grpc.StatusCode.OK) def test_mutating_stream_stream(self): rpc = self._real_time_server.invoke_stream_stream( - _application_testing_common.FIRST_SERVICE_STRESTRE, (), None) + _application_testing_common.FIRST_SERVICE_STRESTRE, (), None + ) rpc.send_request(_application_common.STREAM_STREAM_MUTATING_REQUEST) initial_metadata = rpc.initial_metadata() responses = [ @@ -105,10 +118,12 @@ def test_mutating_stream_stream(self): for _ in range(_application_common.STREAM_STREAM_MUTATING_COUNT) ] rpc.send_request(_application_common.STREAM_STREAM_MUTATING_REQUEST) - responses.extend([ - rpc.take_response() - for _ in range(_application_common.STREAM_STREAM_MUTATING_COUNT) - ]) + responses.extend( + [ + rpc.take_response() + for _ in range(_application_common.STREAM_STREAM_MUTATING_COUNT) + ] + ) rpc.requests_closed() _, _, _ = rpc.termination() expected_responses = ( @@ -121,8 +136,11 @@ def test_mutating_stream_stream(self): def test_server_rpc_idempotence(self): rpc = self._real_time_server.invoke_unary_unary( - _application_testing_common.FIRST_SERVICE_UNUN, (), - _application_common.UNARY_UNARY_REQUEST, None) + _application_testing_common.FIRST_SERVICE_UNUN, + (), + _application_common.UNARY_UNARY_REQUEST, + None, + ) first_initial_metadata = rpc.initial_metadata() second_initial_metadata = rpc.initial_metadata() third_initial_metadata = rpc.initial_metadata() @@ -131,8 +149,8 @@ def test_server_rpc_idempotence(self): third_termination = rpc.termination() for later_initial_metadata in ( - 
second_initial_metadata, - third_initial_metadata, + second_initial_metadata, + third_initial_metadata, ): self.assertEqual(first_initial_metadata, later_initial_metadata) response = first_termination[0] @@ -140,8 +158,8 @@ def test_server_rpc_idempotence(self): code = first_termination[2] details = first_termination[3] for later_termination in ( - second_termination, - third_termination, + second_termination, + third_termination, ): self.assertEqual(response, later_termination[0]) self.assertEqual(terminal_metadata, later_termination[1]) @@ -152,8 +170,11 @@ def test_server_rpc_idempotence(self): def test_misbehaving_client_unary_unary(self): rpc = self._real_time_server.invoke_unary_unary( - _application_testing_common.FIRST_SERVICE_UNUN, (), - _application_common.ERRONEOUS_UNARY_UNARY_REQUEST, None) + _application_testing_common.FIRST_SERVICE_UNUN, + (), + _application_common.ERRONEOUS_UNARY_UNARY_REQUEST, + None, + ) initial_metadata = rpc.initial_metadata() response, trailing_metadata, code, details = rpc.termination() @@ -161,14 +182,17 @@ def test_misbehaving_client_unary_unary(self): def test_infinite_request_stream_real_time(self): rpc = self._real_time_server.invoke_stream_unary( - _application_testing_common.FIRST_SERVICE_STREUN, (), - _application_common.INFINITE_REQUEST_STREAM_TIMEOUT) + _application_testing_common.FIRST_SERVICE_STREUN, + (), + _application_common.INFINITE_REQUEST_STREAM_TIMEOUT, + ) rpc.send_request(_application_common.STREAM_UNARY_REQUEST) rpc.send_request(_application_common.STREAM_UNARY_REQUEST) rpc.send_request(_application_common.STREAM_UNARY_REQUEST) initial_metadata = rpc.initial_metadata() self._real_time.sleep_for( - _application_common.INFINITE_REQUEST_STREAM_TIMEOUT * 2) + _application_common.INFINITE_REQUEST_STREAM_TIMEOUT * 2 + ) rpc.send_request(_application_common.STREAM_UNARY_REQUEST) response, trailing_metadata, code, details = rpc.termination() @@ -176,14 +200,17 @@ def test_infinite_request_stream_real_time(self): def test_infinite_request_stream_fake_time(self): rpc = self._fake_time_server.invoke_stream_unary( - _application_testing_common.FIRST_SERVICE_STREUN, (), - _application_common.INFINITE_REQUEST_STREAM_TIMEOUT) + _application_testing_common.FIRST_SERVICE_STREUN, + (), + _application_common.INFINITE_REQUEST_STREAM_TIMEOUT, + ) rpc.send_request(_application_common.STREAM_UNARY_REQUEST) rpc.send_request(_application_common.STREAM_UNARY_REQUEST) rpc.send_request(_application_common.STREAM_UNARY_REQUEST) initial_metadata = rpc.initial_metadata() self._fake_time.sleep_for( - _application_common.INFINITE_REQUEST_STREAM_TIMEOUT * 2) + _application_common.INFINITE_REQUEST_STREAM_TIMEOUT * 2 + ) rpc.send_request(_application_common.STREAM_UNARY_REQUEST) response, trailing_metadata, code, details = rpc.termination() @@ -191,17 +218,23 @@ def test_infinite_request_stream_fake_time(self): def test_servicer_context_abort(self): rpc = self._real_time_server.invoke_unary_unary( - _application_testing_common.FIRST_SERVICE_UNUN, (), - _application_common.ABORT_REQUEST, None) + _application_testing_common.FIRST_SERVICE_UNUN, + (), + _application_common.ABORT_REQUEST, + None, + ) _, _, code, _ = rpc.termination() self.assertIs(code, grpc.StatusCode.PERMISSION_DENIED) rpc = self._real_time_server.invoke_unary_unary( - _application_testing_common.FIRST_SERVICE_UNUN, (), - _application_common.ABORT_SUCCESS_QUERY, None) + _application_testing_common.FIRST_SERVICE_UNUN, + (), + _application_common.ABORT_SUCCESS_QUERY, + None, + ) response, _, code, _ = 
rpc.termination() self.assertEqual(_application_common.ABORT_SUCCESS_RESPONSE, response) self.assertIs(code, grpc.StatusCode.OK) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/testing/_time_test.py b/src/python/grpcio_tests/tests/testing/_time_test.py index cab665c045c7f..3e5273f32e7a0 100644 --- a/src/python/grpcio_tests/tests/testing/_time_test.py +++ b/src/python/grpcio_tests/tests/testing/_time_test.py @@ -25,11 +25,10 @@ # eventually run what needs to be run (and risk timing out) or declare # that the scheduler didn't schedule work reasonably fast enough. We # choose the latter for this test. -_PATHOLOGICAL_SCHEDULING = 'pathological thread scheduling!' +_PATHOLOGICAL_SCHEDULING = "pathological thread scheduling!" class _TimeNoter(object): - def __init__(self, time): self._condition = threading.Condition() self._time = time @@ -45,7 +44,6 @@ def call_times(self): class TimeTest(object): - def test_sleep_for(self): start_time = self._time.time() self._time.sleep_for(_QUANTUM) @@ -102,11 +100,14 @@ def test_many(self): for test_event in test_events: possibly_cancelled_futures[test_event] = self._time.call_in( - test_event.set, _QUANTUM * (2 + random.random())) + test_event.set, _QUANTUM * (2 + random.random()) + ) for _ in range(_MANY): background_noise_futures.append( - self._time.call_in(threading.Event().set, - _QUANTUM * 1000 * random.random())) + self._time.call_in( + threading.Event().set, _QUANTUM * 1000 * random.random() + ) + ) self._time.sleep_for(_QUANTUM) cancelled = set() for test_event, test_future in possibly_cancelled_futures.items(): @@ -116,7 +117,8 @@ def test_many(self): for test_event in test_events: (self.assertFalse if test_event in cancelled else self.assertTrue)( - test_event.is_set()) + test_event.is_set() + ) for background_noise_future in background_noise_futures: background_noise_future.cancel() @@ -149,17 +151,16 @@ def test_same_behavior_used_several_times(self): class StrictRealTimeTest(TimeTest, unittest.TestCase): - def setUp(self): self._time = grpc_testing.strict_real_time() class StrictFakeTimeTest(TimeTest, unittest.TestCase): - def setUp(self): self._time = grpc_testing.strict_fake_time( - random.randint(0, int(time.time()))) + random.randint(0, int(time.time())) + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_abort_test.py b/src/python/grpcio_tests/tests/unit/_abort_test.py index 84604726c993c..731bb741bec3f 100644 --- a/src/python/grpcio_tests/tests/unit/_abort_test.py +++ b/src/python/grpcio_tests/tests/unit/_abort_test.py @@ -24,21 +24,21 @@ from tests.unit import test_common from tests.unit.framework.common import test_constants -_ABORT = '/test/abort' -_ABORT_WITH_STATUS = '/test/AbortWithStatus' -_INVALID_CODE = '/test/InvalidCode' +_ABORT = "/test/abort" +_ABORT_WITH_STATUS = "/test/AbortWithStatus" +_INVALID_CODE = "/test/InvalidCode" -_REQUEST = b'\x00\x00\x00' -_RESPONSE = b'\x00\x00\x00' +_REQUEST = b"\x00\x00\x00" +_RESPONSE = b"\x00\x00\x00" -_ABORT_DETAILS = 'Abandon ship!' -_ABORT_METADATA = (('a-trailing-metadata', '42'),) +_ABORT_DETAILS = "Abandon ship!" 
+_ABORT_METADATA = (("a-trailing-metadata", "42"),) class _Status( - collections.namedtuple('_Status', - ('code', 'details', 'trailing_metadata')), - grpc.Status): + collections.namedtuple("_Status", ("code", "details", "trailing_metadata")), + grpc.Status, +): pass @@ -55,7 +55,7 @@ def abort_unary_unary(request, servicer_context): grpc.StatusCode.INTERNAL, _ABORT_DETAILS, ) - raise Exception('This line should not be executed!') + raise Exception("This line should not be executed!") def abort_with_status_unary_unary(request, servicer_context): @@ -64,8 +64,9 @@ def abort_with_status_unary_unary(request, servicer_context): code=grpc.StatusCode.INTERNAL, details=_ABORT_DETAILS, trailing_metadata=_ABORT_METADATA, - )) - raise Exception('This line should not be executed!') + ) + ) + raise Exception("This line should not be executed!") def invalid_code_unary_unary(request, servicer_context): @@ -76,29 +77,29 @@ def invalid_code_unary_unary(request, servicer_context): class _GenericHandler(grpc.GenericRpcHandler): - def service(self, handler_call_details): if handler_call_details.method == _ABORT: return grpc.unary_unary_rpc_method_handler(abort_unary_unary) elif handler_call_details.method == _ABORT_WITH_STATUS: return grpc.unary_unary_rpc_method_handler( - abort_with_status_unary_unary) + abort_with_status_unary_unary + ) elif handler_call_details.method == _INVALID_CODE: return grpc.stream_stream_rpc_method_handler( - invalid_code_unary_unary) + invalid_code_unary_unary + ) else: return None class AbortTest(unittest.TestCase): - def setUp(self): self._server = test_common.test_server() - port = self._server.add_insecure_port('[::]:0') + port = self._server.add_insecure_port("[::]:0") self._server.add_generic_rpc_handlers((_GenericHandler(),)) self._server.start() - self._channel = grpc.insecure_channel('localhost:%d' % port) + self._channel = grpc.insecure_channel("localhost:%d" % port) def tearDown(self): self._channel.close() @@ -149,6 +150,6 @@ def test_invalid_code(self): self.assertEqual(rpc_error.details(), _ABORT_DETAILS) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_api_test.py b/src/python/grpcio_tests/tests/unit/_api_test.py index 44243ab0c3457..10a5fefcf567d 100644 --- a/src/python/grpcio_tests/tests/unit/_api_test.py +++ b/src/python/grpcio_tests/tests/unit/_api_test.py @@ -22,99 +22,100 @@ class AllTest(unittest.TestCase): - def testAll(self): expected_grpc_code_elements = ( - 'FutureTimeoutError', - 'FutureCancelledError', - 'Future', - 'ChannelConnectivity', - 'Compression', - 'StatusCode', - 'Status', - 'RpcError', - 'RpcContext', - 'Call', - 'ChannelCredentials', - 'CallCredentials', - 'AuthMetadataContext', - 'AuthMetadataPluginCallback', - 'AuthMetadataPlugin', - 'ServerCertificateConfiguration', - 'ServerCredentials', - 'UnaryUnaryMultiCallable', - 'UnaryStreamMultiCallable', - 'StreamUnaryMultiCallable', - 'StreamStreamMultiCallable', - 'UnaryUnaryClientInterceptor', - 'UnaryStreamClientInterceptor', - 'StreamUnaryClientInterceptor', - 'StreamStreamClientInterceptor', - 'Channel', - 'ServicerContext', - 'RpcMethodHandler', - 'HandlerCallDetails', - 'GenericRpcHandler', - 'ServiceRpcHandler', - 'Server', - 'ServerInterceptor', - 'LocalConnectionType', - 'local_channel_credentials', - 'local_server_credentials', - 'alts_channel_credentials', - 'alts_server_credentials', - 'unary_unary_rpc_method_handler', - 'unary_stream_rpc_method_handler', - 
'stream_unary_rpc_method_handler', - 'ClientCallDetails', - 'stream_stream_rpc_method_handler', - 'method_handlers_generic_handler', - 'ssl_channel_credentials', - 'metadata_call_credentials', - 'access_token_call_credentials', - 'composite_call_credentials', - 'composite_channel_credentials', - 'compute_engine_channel_credentials', - 'ssl_server_credentials', - 'ssl_server_certificate_configuration', - 'dynamic_ssl_server_credentials', - 'channel_ready_future', - 'insecure_channel', - 'secure_channel', - 'intercept_channel', - 'server', - 'protos', - 'services', - 'protos_and_services', - 'xds_channel_credentials', - 'xds_server_credentials', - 'insecure_server_credentials', + "FutureTimeoutError", + "FutureCancelledError", + "Future", + "ChannelConnectivity", + "Compression", + "StatusCode", + "Status", + "RpcError", + "RpcContext", + "Call", + "ChannelCredentials", + "CallCredentials", + "AuthMetadataContext", + "AuthMetadataPluginCallback", + "AuthMetadataPlugin", + "ServerCertificateConfiguration", + "ServerCredentials", + "UnaryUnaryMultiCallable", + "UnaryStreamMultiCallable", + "StreamUnaryMultiCallable", + "StreamStreamMultiCallable", + "UnaryUnaryClientInterceptor", + "UnaryStreamClientInterceptor", + "StreamUnaryClientInterceptor", + "StreamStreamClientInterceptor", + "Channel", + "ServicerContext", + "RpcMethodHandler", + "HandlerCallDetails", + "GenericRpcHandler", + "ServiceRpcHandler", + "Server", + "ServerInterceptor", + "LocalConnectionType", + "local_channel_credentials", + "local_server_credentials", + "alts_channel_credentials", + "alts_server_credentials", + "unary_unary_rpc_method_handler", + "unary_stream_rpc_method_handler", + "stream_unary_rpc_method_handler", + "ClientCallDetails", + "stream_stream_rpc_method_handler", + "method_handlers_generic_handler", + "ssl_channel_credentials", + "metadata_call_credentials", + "access_token_call_credentials", + "composite_call_credentials", + "composite_channel_credentials", + "compute_engine_channel_credentials", + "ssl_server_credentials", + "ssl_server_certificate_configuration", + "dynamic_ssl_server_credentials", + "channel_ready_future", + "insecure_channel", + "secure_channel", + "intercept_channel", + "server", + "protos", + "services", + "protos_and_services", + "xds_channel_credentials", + "xds_server_credentials", + "insecure_server_credentials", ) - self.assertCountEqual(expected_grpc_code_elements, - _from_grpc_import_star.GRPC_ELEMENTS) + self.assertCountEqual( + expected_grpc_code_elements, _from_grpc_import_star.GRPC_ELEMENTS + ) class ChannelConnectivityTest(unittest.TestCase): - def testChannelConnectivity(self): - self.assertSequenceEqual(( - grpc.ChannelConnectivity.IDLE, - grpc.ChannelConnectivity.CONNECTING, - grpc.ChannelConnectivity.READY, - grpc.ChannelConnectivity.TRANSIENT_FAILURE, - grpc.ChannelConnectivity.SHUTDOWN, - ), tuple(grpc.ChannelConnectivity)) + self.assertSequenceEqual( + ( + grpc.ChannelConnectivity.IDLE, + grpc.ChannelConnectivity.CONNECTING, + grpc.ChannelConnectivity.READY, + grpc.ChannelConnectivity.TRANSIENT_FAILURE, + grpc.ChannelConnectivity.SHUTDOWN, + ), + tuple(grpc.ChannelConnectivity), + ) class ChannelTest(unittest.TestCase): - def test_secure_channel(self): channel_credentials = grpc.ssl_channel_credentials() - channel = grpc.secure_channel('google.com:443', channel_credentials) + channel = grpc.secure_channel("google.com:443", channel_credentials) channel.close() -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main(verbosity=2) 
diff --git a/src/python/grpcio_tests/tests/unit/_auth_context_test.py b/src/python/grpcio_tests/tests/unit/_auth_context_test.py index f82a583dcd0d2..039c908c3e56a 100644 --- a/src/python/grpcio_tests/tests/unit/_auth_context_test.py +++ b/src/python/grpcio_tests/tests/unit/_auth_context_test.py @@ -24,53 +24,60 @@ from tests.unit import resources from tests.unit import test_common -_REQUEST = b'\x00\x00\x00' -_RESPONSE = b'\x00\x00\x00' +_REQUEST = b"\x00\x00\x00" +_RESPONSE = b"\x00\x00\x00" -_UNARY_UNARY = '/test/UnaryUnary' +_UNARY_UNARY = "/test/UnaryUnary" -_SERVER_HOST_OVERRIDE = 'foo.test.google.fr' +_SERVER_HOST_OVERRIDE = "foo.test.google.fr" _CLIENT_IDS = ( - b'*.test.google.fr', - b'waterzooi.test.google.be', - b'*.test.youtube.com', - b'192.168.1.3', + b"*.test.google.fr", + b"waterzooi.test.google.be", + b"*.test.youtube.com", + b"192.168.1.3", ) -_ID = 'id' -_ID_KEY = 'id_key' -_AUTH_CTX = 'auth_ctx' +_ID = "id" +_ID_KEY = "id_key" +_AUTH_CTX = "auth_ctx" _PRIVATE_KEY = resources.private_key() _CERTIFICATE_CHAIN = resources.certificate_chain() _TEST_ROOT_CERTIFICATES = resources.test_root_certificates() _SERVER_CERTS = ((_PRIVATE_KEY, _CERTIFICATE_CHAIN),) -_PROPERTY_OPTIONS = (( - 'grpc.ssl_target_name_override', - _SERVER_HOST_OVERRIDE, -),) +_PROPERTY_OPTIONS = ( + ( + "grpc.ssl_target_name_override", + _SERVER_HOST_OVERRIDE, + ), +) def handle_unary_unary(request, servicer_context): - return pickle.dumps({ - _ID: servicer_context.peer_identities(), - _ID_KEY: servicer_context.peer_identity_key(), - _AUTH_CTX: servicer_context.auth_context() - }) + return pickle.dumps( + { + _ID: servicer_context.peer_identities(), + _ID_KEY: servicer_context.peer_identity_key(), + _AUTH_CTX: servicer_context.auth_context(), + } + ) class AuthContextTest(unittest.TestCase): - def testInsecure(self): - handler = grpc.method_handlers_generic_handler('test', { - 'UnaryUnary': - grpc.unary_unary_rpc_method_handler(handle_unary_unary) - }) + handler = grpc.method_handlers_generic_handler( + "test", + { + "UnaryUnary": grpc.unary_unary_rpc_method_handler( + handle_unary_unary + ) + }, + ) server = test_common.test_server() server.add_generic_rpc_handlers((handler,)) - port = server.add_insecure_port('[::]:0') + port = server.add_insecure_port("[::]:0") server.start() - with grpc.insecure_channel('localhost:%d' % port) as channel: + with grpc.insecure_channel("localhost:%d" % port) as channel: response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) server.stop(None) @@ -79,26 +86,35 @@ def testInsecure(self): self.assertIsNone(auth_data[_ID_KEY]) self.assertDictEqual( { - 'security_level': [b'TSI_SECURITY_NONE'], - 'transport_security_type': [b'insecure'], - }, auth_data[_AUTH_CTX]) + "security_level": [b"TSI_SECURITY_NONE"], + "transport_security_type": [b"insecure"], + }, + auth_data[_AUTH_CTX], + ) def testSecureNoCert(self): - handler = grpc.method_handlers_generic_handler('test', { - 'UnaryUnary': - grpc.unary_unary_rpc_method_handler(handle_unary_unary) - }) + handler = grpc.method_handlers_generic_handler( + "test", + { + "UnaryUnary": grpc.unary_unary_rpc_method_handler( + handle_unary_unary + ) + }, + ) server = test_common.test_server() server.add_generic_rpc_handlers((handler,)) server_cred = grpc.ssl_server_credentials(_SERVER_CERTS) - port = server.add_secure_port('[::]:0', server_cred) + port = server.add_secure_port("[::]:0", server_cred) server.start() channel_creds = grpc.ssl_channel_credentials( - root_certificates=_TEST_ROOT_CERTIFICATES) - channel = 
grpc.secure_channel('localhost:{}'.format(port), - channel_creds, - options=_PROPERTY_OPTIONS) + root_certificates=_TEST_ROOT_CERTIFICATES + ) + channel = grpc.secure_channel( + "localhost:{}".format(port), + channel_creds, + options=_PROPERTY_OPTIONS, + ) response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) channel.close() server.stop(None) @@ -108,32 +124,42 @@ def testSecureNoCert(self): self.assertIsNone(auth_data[_ID_KEY]) self.assertDictEqual( { - 'security_level': [b'TSI_PRIVACY_AND_INTEGRITY'], - 'transport_security_type': [b'ssl'], - 'ssl_session_reused': [b'false'], - }, auth_data[_AUTH_CTX]) + "security_level": [b"TSI_PRIVACY_AND_INTEGRITY"], + "transport_security_type": [b"ssl"], + "ssl_session_reused": [b"false"], + }, + auth_data[_AUTH_CTX], + ) def testSecureClientCert(self): - handler = grpc.method_handlers_generic_handler('test', { - 'UnaryUnary': - grpc.unary_unary_rpc_method_handler(handle_unary_unary) - }) + handler = grpc.method_handlers_generic_handler( + "test", + { + "UnaryUnary": grpc.unary_unary_rpc_method_handler( + handle_unary_unary + ) + }, + ) server = test_common.test_server() server.add_generic_rpc_handlers((handler,)) server_cred = grpc.ssl_server_credentials( _SERVER_CERTS, root_certificates=_TEST_ROOT_CERTIFICATES, - require_client_auth=True) - port = server.add_secure_port('[::]:0', server_cred) + require_client_auth=True, + ) + port = server.add_secure_port("[::]:0", server_cred) server.start() channel_creds = grpc.ssl_channel_credentials( root_certificates=_TEST_ROOT_CERTIFICATES, private_key=_PRIVATE_KEY, - certificate_chain=_CERTIFICATE_CHAIN) - channel = grpc.secure_channel('localhost:{}'.format(port), - channel_creds, - options=_PROPERTY_OPTIONS) + certificate_chain=_CERTIFICATE_CHAIN, + ) + channel = grpc.secure_channel( + "localhost:{}".format(port), + channel_creds, + options=_PROPERTY_OPTIONS, + ) response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) channel.close() @@ -142,55 +168,69 @@ def testSecureClientCert(self): auth_data = pickle.loads(response) auth_ctx = auth_data[_AUTH_CTX] self.assertCountEqual(_CLIENT_IDS, auth_data[_ID]) - self.assertEqual('x509_subject_alternative_name', auth_data[_ID_KEY]) - self.assertSequenceEqual([b'ssl'], auth_ctx['transport_security_type']) - self.assertSequenceEqual([b'*.test.google.com'], - auth_ctx['x509_common_name']) - - def _do_one_shot_client_rpc(self, channel_creds, channel_options, port, - expect_ssl_session_reused): - channel = grpc.secure_channel('localhost:{}'.format(port), - channel_creds, - options=channel_options) + self.assertEqual("x509_subject_alternative_name", auth_data[_ID_KEY]) + self.assertSequenceEqual([b"ssl"], auth_ctx["transport_security_type"]) + self.assertSequenceEqual( + [b"*.test.google.com"], auth_ctx["x509_common_name"] + ) + + def _do_one_shot_client_rpc( + self, channel_creds, channel_options, port, expect_ssl_session_reused + ): + channel = grpc.secure_channel( + "localhost:{}".format(port), channel_creds, options=channel_options + ) response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) auth_data = pickle.loads(response) - self.assertEqual(expect_ssl_session_reused, - auth_data[_AUTH_CTX]['ssl_session_reused']) + self.assertEqual( + expect_ssl_session_reused, + auth_data[_AUTH_CTX]["ssl_session_reused"], + ) channel.close() def testSessionResumption(self): # Set up a secure server - handler = grpc.method_handlers_generic_handler('test', { - 'UnaryUnary': - grpc.unary_unary_rpc_method_handler(handle_unary_unary) - }) + handler = grpc.method_handlers_generic_handler( 
+ "test", + { + "UnaryUnary": grpc.unary_unary_rpc_method_handler( + handle_unary_unary + ) + }, + ) server = test_common.test_server() server.add_generic_rpc_handlers((handler,)) server_cred = grpc.ssl_server_credentials(_SERVER_CERTS) - port = server.add_secure_port('[::]:0', server_cred) + port = server.add_secure_port("[::]:0", server_cred) server.start() # Create a cache for TLS session tickets cache = session_cache.ssl_session_cache_lru(1) channel_creds = grpc.ssl_channel_credentials( - root_certificates=_TEST_ROOT_CERTIFICATES) + root_certificates=_TEST_ROOT_CERTIFICATES + ) channel_options = _PROPERTY_OPTIONS + ( - ('grpc.ssl_session_cache', cache),) + ("grpc.ssl_session_cache", cache), + ) # Initial connection has no session to resume - self._do_one_shot_client_rpc(channel_creds, - channel_options, - port, - expect_ssl_session_reused=[b'false']) + self._do_one_shot_client_rpc( + channel_creds, + channel_options, + port, + expect_ssl_session_reused=[b"false"], + ) # Subsequent connections resume sessions - self._do_one_shot_client_rpc(channel_creds, - channel_options, - port, - expect_ssl_session_reused=[b'true']) + self._do_one_shot_client_rpc( + channel_creds, + channel_options, + port, + expect_ssl_session_reused=[b"true"], + ) server.stop(None) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_auth_test.py b/src/python/grpcio_tests/tests/unit/_auth_test.py index 345239e0b8708..3127c36561be8 100644 --- a/src/python/grpcio_tests/tests/unit/_auth_test.py +++ b/src/python/grpcio_tests/tests/unit/_auth_test.py @@ -22,27 +22,25 @@ class MockGoogleCreds(object): - def get_access_token(self): - token = collections.namedtuple('MockAccessTokenInfo', - ('access_token', 'expires_in')) - token.access_token = 'token' + token = collections.namedtuple( + "MockAccessTokenInfo", ("access_token", "expires_in") + ) + token.access_token = "token" return token class MockExceptionGoogleCreds(object): - def get_access_token(self): raise Exception() class GoogleCallCredentialsTest(unittest.TestCase): - def test_google_call_credentials_success(self): callback_event = threading.Event() def mock_callback(metadata, error): - self.assertEqual(metadata, (('authorization', 'Bearer token'),)) + self.assertEqual(metadata, (("authorization", "Bearer token"),)) self.assertIsNone(error) callback_event.set() @@ -63,20 +61,19 @@ def mock_callback(metadata, error): class AccessTokenAuthMetadataPluginTest(unittest.TestCase): - def test_google_call_credentials_success(self): callback_event = threading.Event() def mock_callback(metadata, error): - self.assertEqual(metadata, (('authorization', 'Bearer token'),)) + self.assertEqual(metadata, (("authorization", "Bearer token"),)) self.assertIsNone(error) callback_event.set() - metadata_plugin = _auth.AccessTokenAuthMetadataPlugin('token') + metadata_plugin = _auth.AccessTokenAuthMetadataPlugin("token") metadata_plugin(None, mock_callback) self.assertTrue(callback_event.wait(1.0)) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_channel_args_test.py b/src/python/grpcio_tests/tests/unit/_channel_args_test.py index d71906f6f41de..e47f51a1e2f8a 100644 --- a/src/python/grpcio_tests/tests/unit/_channel_args_test.py +++ b/src/python/grpcio_tests/tests/unit/_channel_args_test.py @@ -21,45 +21,44 @@ class TestPointerWrapper(object): - def __int__(self): return 123456 
TEST_CHANNEL_ARGS = ( - ('arg1', b'bytes_val'), - ('arg2', 'str_val'), - ('arg3', 1), - (b'arg4', 'str_val'), - ('arg6', TestPointerWrapper()), + ("arg1", b"bytes_val"), + ("arg2", "str_val"), + ("arg3", 1), + (b"arg4", "str_val"), + ("arg6", TestPointerWrapper()), ) INVALID_TEST_CHANNEL_ARGS = [ - { - 'foo': 'bar' - }, - (('key',),), - 'str', + {"foo": "bar"}, + (("key",),), + "str", ] class ChannelArgsTest(unittest.TestCase): - def test_client(self): - grpc.insecure_channel('localhost:8080', options=TEST_CHANNEL_ARGS) + grpc.insecure_channel("localhost:8080", options=TEST_CHANNEL_ARGS) def test_server(self): - grpc.server(futures.ThreadPoolExecutor(max_workers=1), - options=TEST_CHANNEL_ARGS) + grpc.server( + futures.ThreadPoolExecutor(max_workers=1), options=TEST_CHANNEL_ARGS + ) def test_invalid_client_args(self): for invalid_arg in INVALID_TEST_CHANNEL_ARGS: - self.assertRaises(ValueError, - grpc.insecure_channel, - 'localhost:8080', - options=invalid_arg) + self.assertRaises( + ValueError, + grpc.insecure_channel, + "localhost:8080", + options=invalid_arg, + ) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_channel_close_test.py b/src/python/grpcio_tests/tests/unit/_channel_close_test.py index 47f52b4890ec7..4e5f215af89a9 100644 --- a/src/python/grpcio_tests/tests/unit/_channel_close_test.py +++ b/src/python/grpcio_tests/tests/unit/_channel_close_test.py @@ -28,12 +28,11 @@ _SOME_TIME = 5 _MORE_TIME = 10 -_STREAM_URI = 'Meffod' -_UNARY_URI = 'MeffodMan' +_STREAM_URI = "Meffod" +_UNARY_URI = "MeffodMan" class _StreamingMethodHandler(grpc.RpcMethodHandler): - request_streaming = True response_streaming = True request_deserializer = None @@ -45,7 +44,6 @@ def stream_stream(self, request_iterator, servicer_context): class _UnaryMethodHandler(grpc.RpcMethodHandler): - request_streaming = False response_streaming = False request_deserializer = None @@ -60,7 +58,6 @@ def unary_unary(self, request, servicer_context): class _GenericHandler(grpc.GenericRpcHandler): - def service(self, handler_call_details): if handler_call_details.method == _STREAM_URI: return _STREAMING_METHOD_HANDLER @@ -72,7 +69,6 @@ def service(self, handler_call_details): class _Pipe(object): - def __init__(self, values): self._condition = threading.Condition() self._values = list(values) @@ -114,19 +110,19 @@ def __exit__(self, type, value, traceback): class ChannelCloseTest(unittest.TestCase): - def setUp(self): self._server = test_common.test_server( - max_workers=test_constants.THREAD_CONCURRENCY) + max_workers=test_constants.THREAD_CONCURRENCY + ) self._server.add_generic_rpc_handlers((_GENERIC_HANDLER,)) - self._port = self._server.add_insecure_port('[::]:0') + self._port = self._server.add_insecure_port("[::]:0") self._server.start() def tearDown(self): self._server.stop(None) def test_close_immediately_after_call_invocation(self): - channel = grpc.insecure_channel('localhost:{}'.format(self._port)) + channel = grpc.insecure_channel("localhost:{}".format(self._port)) multi_callable = channel.stream_stream(_STREAM_URI) request_iterator = _Pipe(()) response_iterator = multi_callable(request_iterator) @@ -136,9 +132,9 @@ def test_close_immediately_after_call_invocation(self): self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED) def test_close_while_call_active(self): - channel = grpc.insecure_channel('localhost:{}'.format(self._port)) + channel = 
grpc.insecure_channel("localhost:{}".format(self._port)) multi_callable = channel.stream_stream(_STREAM_URI) - request_iterator = _Pipe((b'abc',)) + request_iterator = _Pipe((b"abc",)) response_iterator = multi_callable(request_iterator) next(response_iterator) channel.close() @@ -147,10 +143,11 @@ def test_close_while_call_active(self): self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED) def test_context_manager_close_while_call_active(self): - with grpc.insecure_channel('localhost:{}'.format( - self._port)) as channel: # pylint: disable=bad-continuation + with grpc.insecure_channel( + "localhost:{}".format(self._port) + ) as channel: # pylint: disable=bad-continuation multi_callable = channel.stream_stream(_STREAM_URI) - request_iterator = _Pipe((b'abc',)) + request_iterator = _Pipe((b"abc",)) response_iterator = multi_callable(request_iterator) next(response_iterator) request_iterator.close() @@ -158,12 +155,14 @@ def test_context_manager_close_while_call_active(self): self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED) def test_context_manager_close_while_many_calls_active(self): - with grpc.insecure_channel('localhost:{}'.format( - self._port)) as channel: # pylint: disable=bad-continuation + with grpc.insecure_channel( + "localhost:{}".format(self._port) + ) as channel: # pylint: disable=bad-continuation multi_callable = channel.stream_stream(_STREAM_URI) request_iterators = tuple( - _Pipe((b'abc',)) - for _ in range(test_constants.THREAD_CONCURRENCY)) + _Pipe((b"abc",)) + for _ in range(test_constants.THREAD_CONCURRENCY) + ) response_iterators = [] for request_iterator in request_iterators: response_iterator = multi_callable(request_iterator) @@ -176,9 +175,9 @@ def test_context_manager_close_while_many_calls_active(self): self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED) def test_many_concurrent_closes(self): - channel = grpc.insecure_channel('localhost:{}'.format(self._port)) + channel = grpc.insecure_channel("localhost:{}".format(self._port)) multi_callable = channel.stream_stream(_STREAM_URI) - request_iterator = _Pipe((b'abc',)) + request_iterator = _Pipe((b"abc",)) response_iterator = multi_callable(request_iterator) next(response_iterator) start = time.time() @@ -192,7 +191,7 @@ def sleep_some_time_then_close(): close_thread = threading.Thread(target=sleep_some_time_then_close) close_thread.start() while True: - request_iterator.add(b'def') + request_iterator.add(b"def") time.sleep(_BEAT) if end < time.time(): break @@ -201,12 +200,13 @@ def sleep_some_time_then_close(): self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED) def test_exception_in_callback(self): - with grpc.insecure_channel('localhost:{}'.format( - self._port)) as channel: + with grpc.insecure_channel( + "localhost:{}".format(self._port) + ) as channel: stream_multi_callable = channel.stream_stream(_STREAM_URI) - endless_iterator = itertools.repeat(b'abc') + endless_iterator = itertools.repeat(b"abc") stream_response_iterator = stream_multi_callable(endless_iterator) - future = channel.unary_unary(_UNARY_URI).future(b'abc') + future = channel.unary_unary(_UNARY_URI).future(b"abc") def on_done_callback(future): raise Exception("This should not cause a deadlock.") @@ -215,6 +215,6 @@ def on_done_callback(future): future.result() -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py 
b/src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py index 912d8290a4c4c..458e76f93d0f8 100644 --- a/src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py +++ b/src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py @@ -33,7 +33,6 @@ def _last_connectivity_is_not_ready(connectivities): class _Callback(object): - def __init__(self): self._condition = threading.Condition() self._connectivities = [] @@ -58,16 +57,16 @@ def block_until_connectivities_satisfy(self, predicate): class ChannelConnectivityTest(unittest.TestCase): - def test_lonely_channel_connectivity(self): callback = _Callback() - channel = grpc.insecure_channel('localhost:12345') + channel = grpc.insecure_channel("localhost:12345") channel.subscribe(callback.update, try_to_connect=False) first_connectivities = callback.block_until_connectivities_satisfy(bool) channel.subscribe(callback.update, try_to_connect=True) second_connectivities = callback.block_until_connectivities_satisfy( - lambda connectivities: 2 <= len(connectivities)) + lambda connectivities: 2 <= len(connectivities) + ) # Wait for a connection that will never happen. time.sleep(test_constants.SHORT_TIMEOUT) third_connectivities = callback.connectivities() @@ -78,8 +77,9 @@ def test_lonely_channel_connectivity(self): channel.close() - self.assertSequenceEqual((grpc.ChannelConnectivity.IDLE,), - first_connectivities) + self.assertSequenceEqual( + (grpc.ChannelConnectivity.IDLE,), first_connectivities + ) self.assertNotIn(grpc.ChannelConnectivity.READY, second_connectivities) self.assertNotIn(grpc.ChannelConnectivity.READY, third_connectivities) self.assertNotIn(grpc.ChannelConnectivity.READY, fourth_connectivities) @@ -87,70 +87,88 @@ def test_lonely_channel_connectivity(self): def test_immediately_connectable_channel_connectivity(self): recording_thread_pool = thread_pool.RecordingThreadPool( - max_workers=None) - server = grpc.server(recording_thread_pool, - options=(('grpc.so_reuseport', 0),)) - port = server.add_insecure_port('[::]:0') + max_workers=None + ) + server = grpc.server( + recording_thread_pool, options=(("grpc.so_reuseport", 0),) + ) + port = server.add_insecure_port("[::]:0") server.start() first_callback = _Callback() second_callback = _Callback() - channel = grpc.insecure_channel('localhost:{}'.format(port)) + channel = grpc.insecure_channel("localhost:{}".format(port)) channel.subscribe(first_callback.update, try_to_connect=False) - first_connectivities = first_callback.block_until_connectivities_satisfy( - bool) + first_connectivities = ( + first_callback.block_until_connectivities_satisfy(bool) + ) # Wait for a connection that will never happen because try_to_connect=True # has not yet been passed. time.sleep(test_constants.SHORT_TIMEOUT) second_connectivities = first_callback.connectivities() channel.subscribe(second_callback.update, try_to_connect=True) - third_connectivities = first_callback.block_until_connectivities_satisfy( - lambda connectivities: 2 <= len(connectivities)) - fourth_connectivities = second_callback.block_until_connectivities_satisfy( - bool) + third_connectivities = ( + first_callback.block_until_connectivities_satisfy( + lambda connectivities: 2 <= len(connectivities) + ) + ) + fourth_connectivities = ( + second_callback.block_until_connectivities_satisfy(bool) + ) # Wait for a connection that will happen (or may already have happened). 
first_callback.block_until_connectivities_satisfy( - _ready_in_connectivities) + _ready_in_connectivities + ) second_callback.block_until_connectivities_satisfy( - _ready_in_connectivities) + _ready_in_connectivities + ) channel.close() server.stop(None) - self.assertSequenceEqual((grpc.ChannelConnectivity.IDLE,), - first_connectivities) - self.assertSequenceEqual((grpc.ChannelConnectivity.IDLE,), - second_connectivities) - self.assertNotIn(grpc.ChannelConnectivity.TRANSIENT_FAILURE, - third_connectivities) - self.assertNotIn(grpc.ChannelConnectivity.SHUTDOWN, - third_connectivities) - self.assertNotIn(grpc.ChannelConnectivity.TRANSIENT_FAILURE, - fourth_connectivities) - self.assertNotIn(grpc.ChannelConnectivity.SHUTDOWN, - fourth_connectivities) + self.assertSequenceEqual( + (grpc.ChannelConnectivity.IDLE,), first_connectivities + ) + self.assertSequenceEqual( + (grpc.ChannelConnectivity.IDLE,), second_connectivities + ) + self.assertNotIn( + grpc.ChannelConnectivity.TRANSIENT_FAILURE, third_connectivities + ) + self.assertNotIn( + grpc.ChannelConnectivity.SHUTDOWN, third_connectivities + ) + self.assertNotIn( + grpc.ChannelConnectivity.TRANSIENT_FAILURE, fourth_connectivities + ) + self.assertNotIn( + grpc.ChannelConnectivity.SHUTDOWN, fourth_connectivities + ) self.assertFalse(recording_thread_pool.was_used()) def test_reachable_then_unreachable_channel_connectivity(self): recording_thread_pool = thread_pool.RecordingThreadPool( - max_workers=None) - server = grpc.server(recording_thread_pool, - options=(('grpc.so_reuseport', 0),)) - port = server.add_insecure_port('[::]:0') + max_workers=None + ) + server = grpc.server( + recording_thread_pool, options=(("grpc.so_reuseport", 0),) + ) + port = server.add_insecure_port("[::]:0") server.start() callback = _Callback() - channel = grpc.insecure_channel('localhost:{}'.format(port)) + channel = grpc.insecure_channel("localhost:{}".format(port)) channel.subscribe(callback.update, try_to_connect=True) callback.block_until_connectivities_satisfy(_ready_in_connectivities) # Now take down the server and confirm that channel readiness is repudiated. 
server.stop(None) callback.block_until_connectivities_satisfy( - _last_connectivity_is_not_ready) + _last_connectivity_is_not_ready + ) channel.unsubscribe(callback.update) channel.close() self.assertFalse(recording_thread_pool.was_used()) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py b/src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py index 84a6f9196b3a3..e6e819e9ca843 100644 --- a/src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py +++ b/src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py @@ -24,7 +24,6 @@ class _Callback(object): - def __init__(self): self._condition = threading.Condition() self._value = None @@ -42,9 +41,8 @@ def block_until_called(self): class ChannelReadyFutureTest(unittest.TestCase): - def test_lonely_channel_connectivity(self): - channel = grpc.insecure_channel('localhost:12345') + channel = grpc.insecure_channel("localhost:12345") callback = _Callback() ready_future = grpc.channel_ready_future(channel) @@ -65,18 +63,21 @@ def test_lonely_channel_connectivity(self): def test_immediately_connectable_channel_connectivity(self): recording_thread_pool = thread_pool.RecordingThreadPool( - max_workers=None) - server = grpc.server(recording_thread_pool, - options=(('grpc.so_reuseport', 0),)) - port = server.add_insecure_port('[::]:0') + max_workers=None + ) + server = grpc.server( + recording_thread_pool, options=(("grpc.so_reuseport", 0),) + ) + port = server.add_insecure_port("[::]:0") server.start() - channel = grpc.insecure_channel('localhost:{}'.format(port)) + channel = grpc.insecure_channel("localhost:{}".format(port)) callback = _Callback() ready_future = grpc.channel_ready_future(channel) ready_future.add_done_callback(callback.accept_value) self.assertIsNone( - ready_future.result(timeout=test_constants.LONG_TIMEOUT)) + ready_future.result(timeout=test_constants.LONG_TIMEOUT) + ) value_passed_to_callback = callback.block_until_called() self.assertIs(ready_future, value_passed_to_callback) self.assertFalse(ready_future.cancelled()) @@ -93,6 +94,6 @@ def test_immediately_connectable_channel_connectivity(self): server.stop(None) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_compression_test.py b/src/python/grpcio_tests/tests/unit/_compression_test.py index 30a5d6bf6855e..be2a528ea9bea 100644 --- a/src/python/grpcio_tests/tests/unit/_compression_test.py +++ b/src/python/grpcio_tests/tests/unit/_compression_test.py @@ -27,17 +27,17 @@ from tests.unit import _tcp_proxy from tests.unit.framework.common import test_constants -_UNARY_UNARY = '/test/UnaryUnary' -_UNARY_STREAM = '/test/UnaryStream' -_STREAM_UNARY = '/test/StreamUnary' -_STREAM_STREAM = '/test/StreamStream' +_UNARY_UNARY = "/test/UnaryUnary" +_UNARY_STREAM = "/test/UnaryStream" +_STREAM_UNARY = "/test/StreamUnary" +_STREAM_STREAM = "/test/StreamStream" # Cut down on test time. 
_STREAM_LENGTH = test_constants.STREAM_LENGTH // 16 -_HOST = 'localhost' +_HOST = "localhost" -_REQUEST = b'\x00' * 100 +_REQUEST = b"\x00" * 100 _COMPRESSION_RATIO_THRESHOLD = 0.05 _COMPRESSION_METHODS = ( None, @@ -47,24 +47,23 @@ grpc.Compression.Gzip, ) _COMPRESSION_NAMES = { - None: 'Uncompressed', - grpc.Compression.NoCompression: 'NoCompression', - grpc.Compression.Deflate: 'DeflateCompression', - grpc.Compression.Gzip: 'GzipCompression', + None: "Uncompressed", + grpc.Compression.NoCompression: "NoCompression", + grpc.Compression.Deflate: "DeflateCompression", + grpc.Compression.Gzip: "GzipCompression", } _TEST_OPTIONS = { - 'client_streaming': (True, False), - 'server_streaming': (True, False), - 'channel_compression': _COMPRESSION_METHODS, - 'multicallable_compression': _COMPRESSION_METHODS, - 'server_compression': _COMPRESSION_METHODS, - 'server_call_compression': _COMPRESSION_METHODS, + "client_streaming": (True, False), + "server_streaming": (True, False), + "channel_compression": _COMPRESSION_METHODS, + "multicallable_compression": _COMPRESSION_METHODS, + "server_compression": _COMPRESSION_METHODS, + "server_call_compression": _COMPRESSION_METHODS, } def _make_handle_unary_unary(pre_response_callback): - def _handle_unary(request, servicer_context): if pre_response_callback: pre_response_callback(request, servicer_context) @@ -74,7 +73,6 @@ def _handle_unary(request, servicer_context): def _make_handle_unary_stream(pre_response_callback): - def _handle_unary_stream(request, servicer_context): if pre_response_callback: pre_response_callback(request, servicer_context) @@ -85,7 +83,6 @@ def _handle_unary_stream(request, servicer_context): def _make_handle_stream_unary(pre_response_callback): - def _handle_stream_unary(request_iterator, servicer_context): if pre_response_callback: pre_response_callback(request_iterator, servicer_context) @@ -99,7 +96,6 @@ def _handle_stream_unary(request_iterator, servicer_context): def _make_handle_stream_stream(pre_response_callback): - def _handle_stream(request_iterator, servicer_context): # TODO(issue:#6891) We should be able to remove this loop, # and replace with return; yield @@ -111,8 +107,9 @@ def _handle_stream(request_iterator, servicer_context): return _handle_stream -def set_call_compression(compression_method, request_or_iterator, - servicer_context): +def set_call_compression( + compression_method, request_or_iterator, servicer_context +): del request_or_iterator servicer_context.set_compression(compression_method) @@ -123,14 +120,14 @@ def disable_next_compression(request, servicer_context): def disable_first_compression(request, servicer_context): - if int(request.decode('ascii')) == 0: + if int(request.decode("ascii")) == 0: servicer_context.disable_next_message_compression() class _MethodHandler(grpc.RpcMethodHandler): - - def __init__(self, request_streaming, response_streaming, - pre_response_callback): + def __init__( + self, request_streaming, response_streaming, pre_response_callback + ): self.request_streaming = request_streaming self.response_streaming = response_streaming self.request_deserializer = None @@ -142,7 +139,8 @@ def __init__(self, request_streaming, response_streaming, if self.request_streaming and self.response_streaming: self.stream_stream = _make_handle_stream_stream( - pre_response_callback) + pre_response_callback + ) elif not self.request_streaming and not self.response_streaming: self.unary_unary = _make_handle_unary_unary(pre_response_callback) elif not self.request_streaming and 
self.response_streaming: @@ -152,7 +150,6 @@ def __init__(self, request_streaming, response_streaming, class _GenericHandler(grpc.GenericRpcHandler): - def __init__(self, pre_response_callback): self._pre_response_callback = pre_response_callback @@ -170,53 +167,82 @@ def service(self, handler_call_details): @contextlib.contextmanager -def _instrumented_client_server_pair(channel_kwargs, server_kwargs, - server_handler): +def _instrumented_client_server_pair( + channel_kwargs, server_kwargs, server_handler +): server = grpc.server(futures.ThreadPoolExecutor(), **server_kwargs) server.add_generic_rpc_handlers((server_handler,)) - server_port = server.add_insecure_port('{}:0'.format(_HOST)) + server_port = server.add_insecure_port("{}:0".format(_HOST)) server.start() with _tcp_proxy.TcpProxy(_HOST, _HOST, server_port) as proxy: proxy_port = proxy.get_port() - with grpc.insecure_channel('{}:{}'.format(_HOST, proxy_port), - **channel_kwargs) as client_channel: + with grpc.insecure_channel( + "{}:{}".format(_HOST, proxy_port), **channel_kwargs + ) as client_channel: try: yield client_channel, proxy, server finally: server.stop(None) -def _get_byte_counts(channel_kwargs, multicallable_kwargs, client_function, - server_kwargs, server_handler, message): - with _instrumented_client_server_pair(channel_kwargs, server_kwargs, - server_handler) as pipeline: +def _get_byte_counts( + channel_kwargs, + multicallable_kwargs, + client_function, + server_kwargs, + server_handler, + message, +): + with _instrumented_client_server_pair( + channel_kwargs, server_kwargs, server_handler + ) as pipeline: client_channel, proxy, server = pipeline client_function(client_channel, multicallable_kwargs, message) return proxy.get_byte_count() -def _get_compression_ratios(client_function, first_channel_kwargs, - first_multicallable_kwargs, first_server_kwargs, - first_server_handler, second_channel_kwargs, - second_multicallable_kwargs, second_server_kwargs, - second_server_handler, message): +def _get_compression_ratios( + client_function, + first_channel_kwargs, + first_multicallable_kwargs, + first_server_kwargs, + first_server_handler, + second_channel_kwargs, + second_multicallable_kwargs, + second_server_kwargs, + second_server_handler, + message, +): first_bytes_sent, first_bytes_received = _get_byte_counts( - first_channel_kwargs, first_multicallable_kwargs, client_function, - first_server_kwargs, first_server_handler, message) + first_channel_kwargs, + first_multicallable_kwargs, + client_function, + first_server_kwargs, + first_server_handler, + message, + ) second_bytes_sent, second_bytes_received = _get_byte_counts( - second_channel_kwargs, second_multicallable_kwargs, client_function, - second_server_kwargs, second_server_handler, message) - return ((second_bytes_sent - first_bytes_sent) / float(first_bytes_sent), - (second_bytes_received - first_bytes_received) / - float(first_bytes_received)) + second_channel_kwargs, + second_multicallable_kwargs, + client_function, + second_server_kwargs, + second_server_handler, + message, + ) + return ( + (second_bytes_sent - first_bytes_sent) / float(first_bytes_sent), + (second_bytes_received - first_bytes_received) + / float(first_bytes_received), + ) def _unary_unary_client(channel, multicallable_kwargs, message): multi_callable = channel.unary_unary(_UNARY_UNARY) response = multi_callable(message, **multicallable_kwargs) if response != message: - raise RuntimeError("Request '{}' != Response '{}'".format( - message, response)) + raise RuntimeError( + "Request '{}' 
!= Response '{}'".format(message, response) + ) def _unary_stream_client(channel, multicallable_kwargs, message): @@ -224,8 +250,9 @@ def _unary_stream_client(channel, multicallable_kwargs, message): response_iterator = multi_callable(message, **multicallable_kwargs) for response in response_iterator: if response != message: - raise RuntimeError("Request '{}' != Response '{}'".format( - message, response)) + raise RuntimeError( + "Request '{}' != Response '{}'".format(message, response) + ) def _stream_unary_client(channel, multicallable_kwargs, message): @@ -233,49 +260,67 @@ def _stream_unary_client(channel, multicallable_kwargs, message): requests = (_REQUEST for _ in range(_STREAM_LENGTH)) response = multi_callable(requests, **multicallable_kwargs) if response != message: - raise RuntimeError("Request '{}' != Response '{}'".format( - message, response)) + raise RuntimeError( + "Request '{}' != Response '{}'".format(message, response) + ) def _stream_stream_client(channel, multicallable_kwargs, message): multi_callable = channel.stream_stream(_STREAM_STREAM) - request_prefix = str(0).encode('ascii') * 100 + request_prefix = str(0).encode("ascii") * 100 requests = ( - request_prefix + str(i).encode('ascii') for i in range(_STREAM_LENGTH)) + request_prefix + str(i).encode("ascii") for i in range(_STREAM_LENGTH) + ) response_iterator = multi_callable(requests, **multicallable_kwargs) for i, response in enumerate(response_iterator): - if int(response.decode('ascii')) != i: - raise RuntimeError("Request '{}' != Response '{}'".format( - i, response)) + if int(response.decode("ascii")) != i: + raise RuntimeError( + "Request '{}' != Response '{}'".format(i, response) + ) class CompressionTest(unittest.TestCase): - def assertCompressed(self, compression_ratio): self.assertLess( compression_ratio, -1.0 * _COMPRESSION_RATIO_THRESHOLD, - msg='Actual compression ratio: {}'.format(compression_ratio)) + msg="Actual compression ratio: {}".format(compression_ratio), + ) def assertNotCompressed(self, compression_ratio): self.assertGreaterEqual( compression_ratio, -1.0 * _COMPRESSION_RATIO_THRESHOLD, - msg='Actual compession ratio: {}'.format(compression_ratio)) - - def assertConfigurationCompressed(self, client_streaming, server_streaming, - channel_compression, - multicallable_compression, - server_compression, - server_call_compression): - client_side_compressed = channel_compression or multicallable_compression + msg="Actual compession ratio: {}".format(compression_ratio), + ) + + def assertConfigurationCompressed( + self, + client_streaming, + server_streaming, + channel_compression, + multicallable_compression, + server_compression, + server_call_compression, + ): + client_side_compressed = ( + channel_compression or multicallable_compression + ) server_side_compressed = server_compression or server_call_compression - channel_kwargs = { - 'compression': channel_compression, - } if channel_compression else {} - multicallable_kwargs = { - 'compression': multicallable_compression, - } if multicallable_compression else {} + channel_kwargs = ( + { + "compression": channel_compression, + } + if channel_compression + else {} + ) + multicallable_kwargs = ( + { + "compression": multicallable_compression, + } + if multicallable_compression + else {} + ) client_function = None if not client_streaming and not server_streaming: @@ -287,57 +332,100 @@ def assertConfigurationCompressed(self, client_streaming, server_streaming, else: client_function = _stream_stream_client - server_kwargs = { - 'compression': 
server_compression, - } if server_compression else {} - server_handler = _GenericHandler( - functools.partial(set_call_compression, grpc.Compression.Gzip) - ) if server_call_compression else _GenericHandler(None) - _get_compression_ratios(client_function, {}, {}, {}, - _GenericHandler(None), channel_kwargs, - multicallable_kwargs, server_kwargs, - server_handler, _REQUEST) + server_kwargs = ( + { + "compression": server_compression, + } + if server_compression + else {} + ) + server_handler = ( + _GenericHandler( + functools.partial(set_call_compression, grpc.Compression.Gzip) + ) + if server_call_compression + else _GenericHandler(None) + ) + _get_compression_ratios( + client_function, + {}, + {}, + {}, + _GenericHandler(None), + channel_kwargs, + multicallable_kwargs, + server_kwargs, + server_handler, + _REQUEST, + ) def testDisableNextCompressionStreaming(self): server_kwargs = { - 'compression': grpc.Compression.Deflate, + "compression": grpc.Compression.Deflate, } - _get_compression_ratios(_stream_stream_client, {}, {}, {}, - _GenericHandler(None), {}, {}, server_kwargs, - _GenericHandler(disable_next_compression), - _REQUEST) + _get_compression_ratios( + _stream_stream_client, + {}, + {}, + {}, + _GenericHandler(None), + {}, + {}, + server_kwargs, + _GenericHandler(disable_next_compression), + _REQUEST, + ) def testDisableNextCompressionStreamingResets(self): server_kwargs = { - 'compression': grpc.Compression.Deflate, + "compression": grpc.Compression.Deflate, } - _get_compression_ratios(_stream_stream_client, {}, {}, {}, - _GenericHandler(None), {}, {}, server_kwargs, - _GenericHandler(disable_first_compression), - _REQUEST) + _get_compression_ratios( + _stream_stream_client, + {}, + {}, + {}, + _GenericHandler(None), + {}, + {}, + server_kwargs, + _GenericHandler(disable_first_compression), + _REQUEST, + ) def _get_compression_str(name, value): - return '{}{}'.format(name, _COMPRESSION_NAMES[value]) - - -def _get_compression_test_name(client_streaming, server_streaming, - channel_compression, multicallable_compression, - server_compression, server_call_compression): - client_arity = 'Stream' if client_streaming else 'Unary' - server_arity = 'Stream' if server_streaming else 'Unary' - arity = '{}{}'.format(client_arity, server_arity) - channel_compression_str = _get_compression_str('Channel', - channel_compression) + return "{}{}".format(name, _COMPRESSION_NAMES[value]) + + +def _get_compression_test_name( + client_streaming, + server_streaming, + channel_compression, + multicallable_compression, + server_compression, + server_call_compression, +): + client_arity = "Stream" if client_streaming else "Unary" + server_arity = "Stream" if server_streaming else "Unary" + arity = "{}{}".format(client_arity, server_arity) + channel_compression_str = _get_compression_str( + "Channel", channel_compression + ) multicallable_compression_str = _get_compression_str( - 'Multicallable', multicallable_compression) - server_compression_str = _get_compression_str('Server', server_compression) - server_call_compression_str = _get_compression_str('ServerCall', - server_call_compression) - return 'test{}{}{}{}{}'.format(arity, channel_compression_str, - multicallable_compression_str, - server_compression_str, - server_call_compression_str) + "Multicallable", multicallable_compression + ) + server_compression_str = _get_compression_str("Server", server_compression) + server_call_compression_str = _get_compression_str( + "ServerCall", server_call_compression + ) + return "test{}{}{}{}{}".format( + 
arity, + channel_compression_str, + multicallable_compression_str, + server_compression_str, + server_call_compression_str, + ) def _test_options(): @@ -348,15 +436,17 @@ def _test_options(): for options in _test_options(): def test_compression(**kwargs): - def _test_compression(self): self.assertConfigurationCompressed(**kwargs) return _test_compression - setattr(CompressionTest, _get_compression_test_name(**options), - test_compression(**options)) + setattr( + CompressionTest, + _get_compression_test_name(**options), + test_compression(**options), + ) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_contextvars_propagation_test.py b/src/python/grpcio_tests/tests/unit/_contextvars_propagation_test.py index ccf4a07a0a7c0..6f3b601ceb231 100644 --- a/src/python/grpcio_tests/tests/unit/_contextvars_propagation_test.py +++ b/src/python/grpcio_tests/tests/unit/_contextvars_propagation_test.py @@ -36,13 +36,13 @@ def _unary_unary_handler(request, context): def contextvars_supported(): try: import contextvars + return True except ImportError: return False class _GenericHandler(grpc.GenericRpcHandler): - def service(self, handler_call_details): if handler_call_details.method == _UNARY_UNARY: return grpc.unary_unary_rpc_method_handler(_unary_unary_handler) @@ -54,7 +54,7 @@ def service(self, handler_call_details): def _server(): try: server = test_common.test_server() - target = 'localhost:0' + target = "localhost:0" port = server.add_insecure_port(target) server.add_generic_rpc_handlers((_GenericHandler(),)) server.start() @@ -73,15 +73,17 @@ def set_up_expected_context(): test_var.set(_EXPECTED_VALUE) class TestCallCredentials(grpc.AuthMetadataPlugin): - def __call__(self, context, callback): - if test_var.get( - ) != _EXPECTED_VALUE and not test_common.running_under_gevent(): + if ( + test_var.get() != _EXPECTED_VALUE + and not test_common.running_under_gevent() + ): # contextvars do not work under gevent, but the rest of this # test is still valuable as a test of concurrent runs of the # metadata credentials code path. 
- raise AssertionError("{} != {}".format(test_var.get(), - _EXPECTED_VALUE)) + raise AssertionError( + "{} != {}".format(test_var.get(), _EXPECTED_VALUE) + ) callback((), None) def assert_called(self, test): @@ -94,7 +96,6 @@ def set_up_expected_context(): pass class TestCallCredentials(grpc.AuthMetadataPlugin): - def __call__(self, context, callback): callback((), None) @@ -102,7 +103,6 @@ def __call__(self, context, callback): # TODO(https://github.com/grpc/grpc/issues/22257) @unittest.skipIf(os.name == "nt", "LocalCredentials not supported on Windows.") class ContextVarsPropagationTest(unittest.TestCase): - def test_propagation_to_auth_plugin(self): set_up_expected_context() with _server() as port: @@ -110,9 +110,11 @@ def test_propagation_to_auth_plugin(self): local_credentials = grpc.local_channel_credentials() test_call_credentials = TestCallCredentials() call_credentials = grpc.metadata_call_credentials( - test_call_credentials, "test call credentials") + test_call_credentials, "test call credentials" + ) composite_credentials = grpc.composite_channel_credentials( - local_credentials, call_credentials) + local_credentials, call_credentials + ) with grpc.secure_channel(target, composite_credentials) as channel: stub = channel.unary_unary(_UNARY_UNARY) response = stub(_REQUEST, wait_for_ready=True) @@ -128,15 +130,18 @@ def test_concurrent_propagation(self): local_credentials = grpc.local_channel_credentials() test_call_credentials = TestCallCredentials() call_credentials = grpc.metadata_call_credentials( - test_call_credentials, "test call credentials") + test_call_credentials, "test call credentials" + ) composite_credentials = grpc.composite_channel_credentials( - local_credentials, call_credentials) + local_credentials, call_credentials + ) wait_group = test_common.WaitGroup(_THREAD_COUNT) def _run_on_thread(exception_queue): try: - with grpc.secure_channel(target, - composite_credentials) as channel: + with grpc.secure_channel( + target, composite_credentials + ) as channel: stub = channel.unary_unary(_UNARY_UNARY) wait_group.done() wait_group.wait() @@ -161,6 +166,6 @@ def _run_on_thread(exception_queue): raise q.get() -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_credentials_test.py b/src/python/grpcio_tests/tests/unit/_credentials_test.py index ac2bfe2222038..e3a6e74c444d6 100644 --- a/src/python/grpcio_tests/tests/unit/_credentials_test.py +++ b/src/python/grpcio_tests/tests/unit/_credentials_test.py @@ -20,49 +20,54 @@ class CredentialsTest(unittest.TestCase): - def test_call_credentials_composition(self): - first = grpc.access_token_call_credentials('abc') - second = grpc.access_token_call_credentials('def') - third = grpc.access_token_call_credentials('ghi') + first = grpc.access_token_call_credentials("abc") + second = grpc.access_token_call_credentials("def") + third = grpc.access_token_call_credentials("ghi") first_and_second = grpc.composite_call_credentials(first, second) first_second_and_third = grpc.composite_call_credentials( - first, second, third) + first, second, third + ) self.assertIsInstance(first_and_second, grpc.CallCredentials) self.assertIsInstance(first_second_and_third, grpc.CallCredentials) def test_channel_credentials_composition(self): - first_call_credentials = grpc.access_token_call_credentials('abc') - second_call_credentials = grpc.access_token_call_credentials('def') - third_call_credentials = grpc.access_token_call_credentials('ghi') + 
first_call_credentials = grpc.access_token_call_credentials("abc") + second_call_credentials = grpc.access_token_call_credentials("def") + third_call_credentials = grpc.access_token_call_credentials("ghi") channel_credentials = grpc.ssl_channel_credentials() channel_and_first = grpc.composite_channel_credentials( - channel_credentials, first_call_credentials) + channel_credentials, first_call_credentials + ) channel_first_and_second = grpc.composite_channel_credentials( - channel_credentials, first_call_credentials, - second_call_credentials) + channel_credentials, first_call_credentials, second_call_credentials + ) channel_first_second_and_third = grpc.composite_channel_credentials( - channel_credentials, first_call_credentials, - second_call_credentials, third_call_credentials) + channel_credentials, + first_call_credentials, + second_call_credentials, + third_call_credentials, + ) self.assertIsInstance(channel_and_first, grpc.ChannelCredentials) self.assertIsInstance(channel_first_and_second, grpc.ChannelCredentials) - self.assertIsInstance(channel_first_second_and_third, - grpc.ChannelCredentials) + self.assertIsInstance( + channel_first_second_and_third, grpc.ChannelCredentials + ) def test_invalid_string_certificate(self): self.assertRaises( TypeError, grpc.ssl_channel_credentials, - root_certificates='A Certificate', + root_certificates="A Certificate", private_key=None, certificate_chain=None, ) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py b/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py index 3ca0d686d61e6..497249b397aa8 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py @@ -25,11 +25,11 @@ _EMPTY_FLAGS = 0 _EMPTY_METADATA = () -_SERVER_SHUTDOWN_TAG = 'server_shutdown' -_REQUEST_CALL_TAG = 'request_call' -_RECEIVE_CLOSE_ON_SERVER_TAG = 'receive_close_on_server' -_RECEIVE_MESSAGE_TAG = 'receive_message' -_SERVER_COMPLETE_CALL_TAG = 'server_complete_call' +_SERVER_SHUTDOWN_TAG = "server_shutdown" +_REQUEST_CALL_TAG = "request_call" +_RECEIVE_CLOSE_ON_SERVER_TAG = "receive_close_on_server" +_RECEIVE_MESSAGE_TAG = "receive_message" +_SERVER_COMPLETE_CALL_TAG = "server_complete_call" _SUCCESS_CALL_FRACTION = 1.0 / 8.0 _SUCCESSFUL_CALLS = int(test_constants.RPC_CONCURRENCY * _SUCCESS_CALL_FRACTION) @@ -37,7 +37,6 @@ class _State(object): - def __init__(self): self.condition = threading.Condition() self.handlers_released = False @@ -46,12 +45,13 @@ def __init__(self): def _is_cancellation_event(event): - return (event.tag is _RECEIVE_CLOSE_ON_SERVER_TAG and - event.batch_operations[0].cancelled()) + return ( + event.tag is _RECEIVE_CLOSE_ON_SERVER_TAG + and event.batch_operations[0].cancelled() + ) class _Handler(object): - def __init__(self, state, completion_queue, rpc_event): self._state = state self._lock = threading.Lock() @@ -69,25 +69,32 @@ def __call__(self): with self._lock: self._call.start_server_batch( (cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),), - _RECEIVE_CLOSE_ON_SERVER_TAG) + _RECEIVE_CLOSE_ON_SERVER_TAG, + ) self._call.start_server_batch( (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),), - _RECEIVE_MESSAGE_TAG) + _RECEIVE_MESSAGE_TAG, + ) first_event = self._completion_queue.poll() if _is_cancellation_event(first_event): self._completion_queue.poll() else: with self._lock: operations = ( - 
cygrpc.SendInitialMetadataOperation(_EMPTY_METADATA, - _EMPTY_FLAGS), - cygrpc.SendMessageOperation(b'\x79\x57', _EMPTY_FLAGS), + cygrpc.SendInitialMetadataOperation( + _EMPTY_METADATA, _EMPTY_FLAGS + ), + cygrpc.SendMessageOperation(b"\x79\x57", _EMPTY_FLAGS), cygrpc.SendStatusFromServerOperation( - _EMPTY_METADATA, cygrpc.StatusCode.ok, b'test details!', - _EMPTY_FLAGS), + _EMPTY_METADATA, + cygrpc.StatusCode.ok, + b"test details!", + _EMPTY_FLAGS, + ), + ) + self._call.start_server_batch( + operations, _SERVER_COMPLETE_CALL_TAG ) - self._call.start_server_batch(operations, - _SERVER_COMPLETE_CALL_TAG) self._completion_queue.poll() self._completion_queue.poll() @@ -95,8 +102,9 @@ def __call__(self): def _serve(state, server, server_completion_queue, thread_pool): for _ in range(test_constants.RPC_CONCURRENCY): call_completion_queue = cygrpc.CompletionQueue() - server.request_call(call_completion_queue, server_completion_queue, - _REQUEST_CALL_TAG) + server.request_call( + call_completion_queue, server_completion_queue, _REQUEST_CALL_TAG + ) rpc_event = server_completion_queue.poll() thread_pool.submit(_Handler(state, call_completion_queue, rpc_event)) with state.condition: @@ -107,7 +115,6 @@ def _serve(state, server, server_completion_queue, thread_pool): class _QueueDriver(object): - def __init__(self, condition, completion_queue, due): self._condition = condition self._completion_queue = completion_queue @@ -116,7 +123,6 @@ def __init__(self, condition, completion_queue, due): self._returned = False def start(self): - def in_thread(): while True: event = self._completion_queue.poll() @@ -139,21 +145,27 @@ def events(self, at_least): class CancelManyCallsTest(unittest.TestCase): - def testCancelManyCalls(self): server_thread_pool = logging_pool.pool( - test_constants.THREAD_CONCURRENCY) + test_constants.THREAD_CONCURRENCY + ) server_completion_queue = cygrpc.CompletionQueue() - server = cygrpc.Server([( - b'grpc.so_reuseport', - 0, - )], False) + server = cygrpc.Server( + [ + ( + b"grpc.so_reuseport", + 0, + ) + ], + False, + ) server.register_completion_queue(server_completion_queue) - port = server.add_http2_port(b'[::]:0') + port = server.add_http2_port(b"[::]:0") server.start() - channel = cygrpc.Channel('localhost:{}'.format(port).encode(), None, - None) + channel = cygrpc.Channel( + "localhost:{}".format(port).encode(), None, None + ) state = _State() @@ -172,28 +184,46 @@ def testCancelManyCalls(self): with client_condition: client_calls = [] for index in range(test_constants.RPC_CONCURRENCY): - tag = 'client_complete_call_{0:04d}_tag'.format(index) + tag = "client_complete_call_{0:04d}_tag".format(index) client_call = channel.integrated_call( - _EMPTY_FLAGS, b'/twinkies', None, None, _EMPTY_METADATA, - None, (( + _EMPTY_FLAGS, + b"/twinkies", + None, + None, + _EMPTY_METADATA, + None, + ( ( - cygrpc.SendInitialMetadataOperation( - _EMPTY_METADATA, _EMPTY_FLAGS), - cygrpc.SendMessageOperation(b'\x45\x56', - _EMPTY_FLAGS), - cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), - cygrpc.ReceiveInitialMetadataOperation( - _EMPTY_FLAGS), - cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), - cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), + ( + cygrpc.SendInitialMetadataOperation( + _EMPTY_METADATA, _EMPTY_FLAGS + ), + cygrpc.SendMessageOperation( + b"\x45\x56", _EMPTY_FLAGS + ), + cygrpc.SendCloseFromClientOperation( + _EMPTY_FLAGS + ), + cygrpc.ReceiveInitialMetadataOperation( + _EMPTY_FLAGS + ), + cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), + 
cygrpc.ReceiveStatusOnClientOperation( + _EMPTY_FLAGS + ), + ), + tag, ), - tag, - ),)) + ), + ) client_due.add(tag) client_calls.append(client_call) - client_events_future = test_utilities.SimpleFuture(lambda: tuple( - channel.next_call_event() for _ in range(_SUCCESSFUL_CALLS))) + client_events_future = test_utilities.SimpleFuture( + lambda: tuple( + channel.next_call_event() for _ in range(_SUCCESSFUL_CALLS) + ) + ) with state.condition: while True: @@ -209,14 +239,14 @@ def testCancelManyCalls(self): client_events_future.result() with client_condition: for client_call in client_calls: - client_call.cancel(cygrpc.StatusCode.cancelled, 'Cancelled!') + client_call.cancel(cygrpc.StatusCode.cancelled, "Cancelled!") for _ in range(_UNSUCCESSFUL_CALLS): channel.next_call_event() - channel.close(cygrpc.StatusCode.unknown, 'Cancelled on channel close!') + channel.close(cygrpc.StatusCode.unknown, "Cancelled on channel close!") with state.condition: server.shutdown(server_completion_queue, _SERVER_SHUTDOWN_TAG) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_cython/_channel_test.py b/src/python/grpcio_tests/tests/unit/_cython/_channel_test.py index 8f0b6fedc02fa..34682d3bcb4bf 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_channel_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_channel_test.py @@ -22,7 +22,7 @@ def _channel(): - return cygrpc.Channel(b'localhost:54321', (), None) + return cygrpc.Channel(b"localhost:54321", (), None) def _connectivity_loop(channel): @@ -34,13 +34,14 @@ def _connectivity_loop(channel): def _create_loop_destroy(): channel = _channel() _connectivity_loop(channel) - channel.close(cygrpc.StatusCode.ok, 'Channel close!') + channel.close(cygrpc.StatusCode.ok, "Channel close!") def _in_parallel(behavior, arguments): threads = tuple( threading.Thread(target=behavior, args=arguments) - for _ in range(test_constants.THREAD_CONCURRENCY)) + for _ in range(test_constants.THREAD_CONCURRENCY) + ) for thread in threads: thread.start() for thread in threads: @@ -48,11 +49,10 @@ def _in_parallel(behavior, arguments): class ChannelTest(unittest.TestCase): - def test_single_channel_lonely_connectivity(self): channel = _channel() _connectivity_loop(channel) - channel.close(cygrpc.StatusCode.ok, 'Channel close!') + channel.close(cygrpc.StatusCode.ok, "Channel close!") def test_multiple_channels_lonely_connectivity(self): _in_parallel(_create_loop_destroy, ()) @@ -61,10 +61,10 @@ def test_negative_deadline_connectivity(self): channel = _channel() connectivity = channel.check_connectivity_state(True) channel.watch_connectivity_state(connectivity, -3.14) - channel.close(cygrpc.StatusCode.ok, 'Channel close!') + channel.close(cygrpc.StatusCode.ok, "Channel close!") # NOTE(lidiz) The negative timeout should not trigger SIGABRT. 
# Bug report: https://github.com/grpc/grpc/issues/18244 -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_cython/_common.py b/src/python/grpcio_tests/tests/unit/_cython/_common.py index 42ec655feee7d..1a7c3dd646f12 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_common.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_common.py @@ -23,23 +23,22 @@ EMPTY_FLAGS = 0 INVOCATION_METADATA = ( - ('client-md-key', 'client-md-key'), - ('client-md-key-bin', b'\x00\x01' * 3000), + ("client-md-key", "client-md-key"), + ("client-md-key-bin", b"\x00\x01" * 3000), ) INITIAL_METADATA = ( - ('server-initial-md-key', 'server-initial-md-value'), - ('server-initial-md-key-bin', b'\x00\x02' * 3000), + ("server-initial-md-key", "server-initial-md-value"), + ("server-initial-md-key-bin", b"\x00\x02" * 3000), ) TRAILING_METADATA = ( - ('server-trailing-md-key', 'server-trailing-md-value'), - ('server-trailing-md-key-bin', b'\x00\x03' * 3000), + ("server-trailing-md-key", "server-trailing-md-value"), + ("server-trailing-md-key-bin", b"\x00\x03" * 3000), ) class QueueDriver(object): - def __init__(self, condition, completion_queue): self._condition = condition self._completion_queue = completion_queue @@ -80,44 +79,54 @@ def execute_many_times(behavior): class OperationResult( - collections.namedtuple('OperationResult', ( - 'start_batch_result', - 'completion_type', - 'success', - ))): + collections.namedtuple( + "OperationResult", + ( + "start_batch_result", + "completion_type", + "success", + ), + ) +): pass SUCCESSFUL_OPERATION_RESULT = OperationResult( - cygrpc.CallError.ok, cygrpc.CompletionType.operation_complete, True) + cygrpc.CallError.ok, cygrpc.CompletionType.operation_complete, True +) class RpcTest(object): - def setUp(self): self.server_completion_queue = cygrpc.CompletionQueue() - self.server = cygrpc.Server([(b'grpc.so_reuseport', 0)], False) + self.server = cygrpc.Server([(b"grpc.so_reuseport", 0)], False) self.server.register_completion_queue(self.server_completion_queue) - port = self.server.add_http2_port(b'[::]:0') + port = self.server.add_http2_port(b"[::]:0") self.server.start() - self.channel = cygrpc.Channel('localhost:{}'.format(port).encode(), [], - None) + self.channel = cygrpc.Channel( + "localhost:{}".format(port).encode(), [], None + ) - self._server_shutdown_tag = 'server_shutdown_tag' + self._server_shutdown_tag = "server_shutdown_tag" self.server_condition = threading.Condition() - self.server_driver = QueueDriver(self.server_condition, - self.server_completion_queue) + self.server_driver = QueueDriver( + self.server_condition, self.server_completion_queue + ) with self.server_condition: - self.server_driver.add_due({ - self._server_shutdown_tag, - }) + self.server_driver.add_due( + { + self._server_shutdown_tag, + } + ) self.client_condition = threading.Condition() self.client_completion_queue = cygrpc.CompletionQueue() - self.client_driver = QueueDriver(self.client_condition, - self.client_completion_queue) + self.client_driver = QueueDriver( + self.client_condition, self.client_completion_queue + ) def tearDown(self): - self.server.shutdown(self.server_completion_queue, - self._server_shutdown_tag) + self.server.shutdown( + self.server_completion_queue, self._server_shutdown_tag + ) self.server.cancel_all_calls() diff --git a/src/python/grpcio_tests/tests/unit/_cython/_fork_test.py b/src/python/grpcio_tests/tests/unit/_cython/_fork_test.py index 5a5dedd5f2696..00e380af7647c 100644 --- 
a/src/python/grpcio_tests/tests/unit/_cython/_fork_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_fork_test.py @@ -23,15 +23,13 @@ def _get_number_active_threads(): return cygrpc._fork_state.active_thread_count._num_active_threads -@unittest.skipIf(os.name == 'nt', 'Posix-specific tests') +@unittest.skipIf(os.name == "nt", "Posix-specific tests") class ForkPosixTester(unittest.TestCase): - def setUp(self): self._saved_fork_support_flag = cygrpc._GRPC_ENABLE_FORK_SUPPORT cygrpc._GRPC_ENABLE_FORK_SUPPORT = True def testForkManagedThread(self): - def cb(): self.assertEqual(1, _get_number_active_threads()) @@ -41,7 +39,6 @@ def cb(): self.assertEqual(0, _get_number_active_threads()) def testForkManagedThreadThrowsException(self): - def cb(): self.assertEqual(1, _get_number_active_threads()) raise Exception("expected exception") @@ -55,11 +52,9 @@ def tearDown(self): cygrpc._GRPC_ENABLE_FORK_SUPPORT = self._saved_fork_support_flag -@unittest.skipUnless(os.name == 'nt', 'Windows-specific tests') +@unittest.skipUnless(os.name == "nt", "Windows-specific tests") class ForkWindowsTester(unittest.TestCase): - def testForkManagedThreadIsNoOp(self): - def cb(): pass @@ -68,5 +63,5 @@ def cb(): thread.join() -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_cython/_no_messages_server_completion_queue_per_call_test.py b/src/python/grpcio_tests/tests/unit/_cython/_no_messages_server_completion_queue_per_call_test.py index 144a2fcae3f17..bff43358ef405 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_no_messages_server_completion_queue_per_call_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_no_messages_server_completion_queue_per_call_test.py @@ -23,75 +23,118 @@ class Test(_common.RpcTest, unittest.TestCase): - def _do_rpcs(self): server_call_condition = threading.Condition() server_call_completion_queue = cygrpc.CompletionQueue() - server_call_driver = _common.QueueDriver(server_call_condition, - server_call_completion_queue) + server_call_driver = _common.QueueDriver( + server_call_condition, server_call_completion_queue + ) - server_request_call_tag = 'server_request_call_tag' - server_send_initial_metadata_tag = 'server_send_initial_metadata_tag' - server_complete_rpc_tag = 'server_complete_rpc_tag' + server_request_call_tag = "server_request_call_tag" + server_send_initial_metadata_tag = "server_send_initial_metadata_tag" + server_complete_rpc_tag = "server_complete_rpc_tag" with self.server_condition: server_request_call_start_batch_result = self.server.request_call( - server_call_completion_queue, self.server_completion_queue, - server_request_call_tag) - self.server_driver.add_due({ + server_call_completion_queue, + self.server_completion_queue, server_request_call_tag, - }) - - client_receive_initial_metadata_tag = 'client_receive_initial_metadata_tag' - client_complete_rpc_tag = 'client_complete_rpc_tag' + ) + self.server_driver.add_due( + { + server_request_call_tag, + } + ) + + client_receive_initial_metadata_tag = ( + "client_receive_initial_metadata_tag" + ) + client_complete_rpc_tag = "client_complete_rpc_tag" client_call = self.channel.integrated_call( - _common.EMPTY_FLAGS, b'/twinkies', None, None, - _common.INVOCATION_METADATA, None, [( - [ - cygrpc.ReceiveInitialMetadataOperation(_common.EMPTY_FLAGS), - ], - client_receive_initial_metadata_tag, - )]) - client_call.operate([ - cygrpc.SendInitialMetadataOperation(_common.INVOCATION_METADATA, - _common.EMPTY_FLAGS), - 
cygrpc.SendCloseFromClientOperation(_common.EMPTY_FLAGS), - cygrpc.ReceiveStatusOnClientOperation(_common.EMPTY_FLAGS), - ], client_complete_rpc_tag) - - client_events_future = test_utilities.SimpleFuture(lambda: [ - self.channel.next_call_event(), - self.channel.next_call_event(), - ]) + _common.EMPTY_FLAGS, + b"/twinkies", + None, + None, + _common.INVOCATION_METADATA, + None, + [ + ( + [ + cygrpc.ReceiveInitialMetadataOperation( + _common.EMPTY_FLAGS + ), + ], + client_receive_initial_metadata_tag, + ) + ], + ) + client_call.operate( + [ + cygrpc.SendInitialMetadataOperation( + _common.INVOCATION_METADATA, _common.EMPTY_FLAGS + ), + cygrpc.SendCloseFromClientOperation(_common.EMPTY_FLAGS), + cygrpc.ReceiveStatusOnClientOperation(_common.EMPTY_FLAGS), + ], + client_complete_rpc_tag, + ) + + client_events_future = test_utilities.SimpleFuture( + lambda: [ + self.channel.next_call_event(), + self.channel.next_call_event(), + ] + ) server_request_call_event = self.server_driver.event_with_tag( - server_request_call_tag) + server_request_call_tag + ) with server_call_condition: server_send_initial_metadata_start_batch_result = ( - server_request_call_event.call.start_server_batch([ - cygrpc.SendInitialMetadataOperation( - _common.INITIAL_METADATA, _common.EMPTY_FLAGS), - ], server_send_initial_metadata_tag)) - server_call_driver.add_due({ - server_send_initial_metadata_tag, - }) + server_request_call_event.call.start_server_batch( + [ + cygrpc.SendInitialMetadataOperation( + _common.INITIAL_METADATA, _common.EMPTY_FLAGS + ), + ], + server_send_initial_metadata_tag, + ) + ) + server_call_driver.add_due( + { + server_send_initial_metadata_tag, + } + ) server_send_initial_metadata_event = server_call_driver.event_with_tag( - server_send_initial_metadata_tag) + server_send_initial_metadata_tag + ) with server_call_condition: server_complete_rpc_start_batch_result = ( - server_request_call_event.call.start_server_batch([ - cygrpc.ReceiveCloseOnServerOperation(_common.EMPTY_FLAGS), - cygrpc.SendStatusFromServerOperation( - _common.TRAILING_METADATA, cygrpc.StatusCode.ok, - b'test details', _common.EMPTY_FLAGS), - ], server_complete_rpc_tag)) - server_call_driver.add_due({ - server_complete_rpc_tag, - }) + server_request_call_event.call.start_server_batch( + [ + cygrpc.ReceiveCloseOnServerOperation( + _common.EMPTY_FLAGS + ), + cygrpc.SendStatusFromServerOperation( + _common.TRAILING_METADATA, + cygrpc.StatusCode.ok, + b"test details", + _common.EMPTY_FLAGS, + ), + ], + server_complete_rpc_tag, + ) + ) + server_call_driver.add_due( + { + server_complete_rpc_tag, + } + ) server_complete_rpc_event = server_call_driver.event_with_tag( - server_complete_rpc_tag) + server_complete_rpc_tag + ) client_events = client_events_future.result() if client_events[0].tag is client_receive_initial_metadata_tag: @@ -102,31 +145,40 @@ def _do_rpcs(self): client_receive_initial_metadata_event = client_events[1] return ( - _common.OperationResult(server_request_call_start_batch_result, - server_request_call_event.completion_type, - server_request_call_event.success), + _common.OperationResult( + server_request_call_start_batch_result, + server_request_call_event.completion_type, + server_request_call_event.success, + ), _common.OperationResult( cygrpc.CallError.ok, client_receive_initial_metadata_event.completion_type, - client_receive_initial_metadata_event.success), - _common.OperationResult(cygrpc.CallError.ok, - client_complete_rpc_event.completion_type, - client_complete_rpc_event.success), + 
client_receive_initial_metadata_event.success, + ), + _common.OperationResult( + cygrpc.CallError.ok, + client_complete_rpc_event.completion_type, + client_complete_rpc_event.success, + ), _common.OperationResult( server_send_initial_metadata_start_batch_result, server_send_initial_metadata_event.completion_type, - server_send_initial_metadata_event.success), - _common.OperationResult(server_complete_rpc_start_batch_result, - server_complete_rpc_event.completion_type, - server_complete_rpc_event.success), + server_send_initial_metadata_event.success, + ), + _common.OperationResult( + server_complete_rpc_start_batch_result, + server_complete_rpc_event.completion_type, + server_complete_rpc_event.success, + ), ) def test_rpcs(self): - expecteds = [(_common.SUCCESSFUL_OPERATION_RESULT,) * 5 - ] * _common.RPC_COUNT + expecteds = [ + (_common.SUCCESSFUL_OPERATION_RESULT,) * 5 + ] * _common.RPC_COUNT actuallys = _common.execute_many_times(self._do_rpcs) self.assertSequenceEqual(expecteds, actuallys) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_cython/_no_messages_single_server_completion_queue_test.py b/src/python/grpcio_tests/tests/unit/_cython/_no_messages_single_server_completion_queue_test.py index 38964768db708..9e7d3b6865873 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_no_messages_single_server_completion_queue_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_no_messages_single_server_completion_queue_test.py @@ -23,104 +23,153 @@ class Test(_common.RpcTest, unittest.TestCase): - def _do_rpcs(self): - server_request_call_tag = 'server_request_call_tag' - server_send_initial_metadata_tag = 'server_send_initial_metadata_tag' - server_complete_rpc_tag = 'server_complete_rpc_tag' + server_request_call_tag = "server_request_call_tag" + server_send_initial_metadata_tag = "server_send_initial_metadata_tag" + server_complete_rpc_tag = "server_complete_rpc_tag" with self.server_condition: server_request_call_start_batch_result = self.server.request_call( - self.server_completion_queue, self.server_completion_queue, - server_request_call_tag) - self.server_driver.add_due({ + self.server_completion_queue, + self.server_completion_queue, server_request_call_tag, - }) - - client_receive_initial_metadata_tag = 'client_receive_initial_metadata_tag' - client_complete_rpc_tag = 'client_complete_rpc_tag' + ) + self.server_driver.add_due( + { + server_request_call_tag, + } + ) + + client_receive_initial_metadata_tag = ( + "client_receive_initial_metadata_tag" + ) + client_complete_rpc_tag = "client_complete_rpc_tag" client_call = self.channel.integrated_call( - _common.EMPTY_FLAGS, b'/twinkies', None, None, - _common.INVOCATION_METADATA, None, [ + _common.EMPTY_FLAGS, + b"/twinkies", + None, + None, + _common.INVOCATION_METADATA, + None, + [ ( [ cygrpc.SendInitialMetadataOperation( - _common.INVOCATION_METADATA, _common.EMPTY_FLAGS), + _common.INVOCATION_METADATA, _common.EMPTY_FLAGS + ), cygrpc.SendCloseFromClientOperation( - _common.EMPTY_FLAGS), + _common.EMPTY_FLAGS + ), cygrpc.ReceiveStatusOnClientOperation( - _common.EMPTY_FLAGS), + _common.EMPTY_FLAGS + ), ], client_complete_rpc_tag, ), - ]) - client_call.operate([ - cygrpc.ReceiveInitialMetadataOperation(_common.EMPTY_FLAGS), - ], client_receive_initial_metadata_tag) + ], + ) + client_call.operate( + [ + cygrpc.ReceiveInitialMetadataOperation(_common.EMPTY_FLAGS), + ], + client_receive_initial_metadata_tag, + ) - client_events_future = 
test_utilities.SimpleFuture(lambda: [ - self.channel.next_call_event(), - self.channel.next_call_event(), - ]) + client_events_future = test_utilities.SimpleFuture( + lambda: [ + self.channel.next_call_event(), + self.channel.next_call_event(), + ] + ) server_request_call_event = self.server_driver.event_with_tag( - server_request_call_tag) + server_request_call_tag + ) with self.server_condition: server_send_initial_metadata_start_batch_result = ( - server_request_call_event.call.start_server_batch([ - cygrpc.SendInitialMetadataOperation( - _common.INITIAL_METADATA, _common.EMPTY_FLAGS), - ], server_send_initial_metadata_tag)) - self.server_driver.add_due({ - server_send_initial_metadata_tag, - }) + server_request_call_event.call.start_server_batch( + [ + cygrpc.SendInitialMetadataOperation( + _common.INITIAL_METADATA, _common.EMPTY_FLAGS + ), + ], + server_send_initial_metadata_tag, + ) + ) + self.server_driver.add_due( + { + server_send_initial_metadata_tag, + } + ) server_send_initial_metadata_event = self.server_driver.event_with_tag( - server_send_initial_metadata_tag) + server_send_initial_metadata_tag + ) with self.server_condition: server_complete_rpc_start_batch_result = ( - server_request_call_event.call.start_server_batch([ - cygrpc.ReceiveCloseOnServerOperation(_common.EMPTY_FLAGS), - cygrpc.SendStatusFromServerOperation( - _common.TRAILING_METADATA, cygrpc.StatusCode.ok, - 'test details', _common.EMPTY_FLAGS), - ], server_complete_rpc_tag)) - self.server_driver.add_due({ - server_complete_rpc_tag, - }) + server_request_call_event.call.start_server_batch( + [ + cygrpc.ReceiveCloseOnServerOperation( + _common.EMPTY_FLAGS + ), + cygrpc.SendStatusFromServerOperation( + _common.TRAILING_METADATA, + cygrpc.StatusCode.ok, + "test details", + _common.EMPTY_FLAGS, + ), + ], + server_complete_rpc_tag, + ) + ) + self.server_driver.add_due( + { + server_complete_rpc_tag, + } + ) server_complete_rpc_event = self.server_driver.event_with_tag( - server_complete_rpc_tag) + server_complete_rpc_tag + ) client_events = client_events_future.result() client_receive_initial_metadata_event = client_events[0] client_complete_rpc_event = client_events[1] return ( - _common.OperationResult(server_request_call_start_batch_result, - server_request_call_event.completion_type, - server_request_call_event.success), + _common.OperationResult( + server_request_call_start_batch_result, + server_request_call_event.completion_type, + server_request_call_event.success, + ), _common.OperationResult( cygrpc.CallError.ok, client_receive_initial_metadata_event.completion_type, - client_receive_initial_metadata_event.success), - _common.OperationResult(cygrpc.CallError.ok, - client_complete_rpc_event.completion_type, - client_complete_rpc_event.success), + client_receive_initial_metadata_event.success, + ), + _common.OperationResult( + cygrpc.CallError.ok, + client_complete_rpc_event.completion_type, + client_complete_rpc_event.success, + ), _common.OperationResult( server_send_initial_metadata_start_batch_result, server_send_initial_metadata_event.completion_type, - server_send_initial_metadata_event.success), - _common.OperationResult(server_complete_rpc_start_batch_result, - server_complete_rpc_event.completion_type, - server_complete_rpc_event.success), + server_send_initial_metadata_event.success, + ), + _common.OperationResult( + server_complete_rpc_start_batch_result, + server_complete_rpc_event.completion_type, + server_complete_rpc_event.success, + ), ) def test_rpcs(self): - expecteds = 
[(_common.SUCCESSFUL_OPERATION_RESULT,) * 5 - ] * _common.RPC_COUNT + expecteds = [ + (_common.SUCCESSFUL_OPERATION_RESULT,) * 5 + ] * _common.RPC_COUNT actuallys = _common.execute_many_times(self._do_rpcs) self.assertSequenceEqual(expecteds, actuallys) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py b/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py index 701ebcee5c28a..7f87b2b8a8c6e 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py @@ -25,7 +25,6 @@ class _ServerDriver(object): - def __init__(self, completion_queue, shutdown_tag): self._condition = threading.Condition() self._completion_queue = completion_queue @@ -34,7 +33,6 @@ def __init__(self, completion_queue, shutdown_tag): self._saw_shutdown_tag = False def start(self): - def in_thread(): while True: event = self._completion_queue.poll() @@ -66,7 +64,6 @@ def events(self): class _QueueDriver(object): - def __init__(self, condition, completion_queue, due): self._condition = condition self._completion_queue = completion_queue @@ -75,7 +72,6 @@ def __init__(self, condition, completion_queue, due): self._returned = False def start(self): - def in_thread(): while True: event = self._completion_queue.poll() @@ -110,53 +106,71 @@ def events(self): class ReadSomeButNotAllResponsesTest(unittest.TestCase): - def testReadSomeButNotAllResponses(self): server_completion_queue = cygrpc.CompletionQueue() - server = cygrpc.Server([( - b'grpc.so_reuseport', - 0, - )], False) + server = cygrpc.Server( + [ + ( + b"grpc.so_reuseport", + 0, + ) + ], + False, + ) server.register_completion_queue(server_completion_queue) - port = server.add_http2_port(b'[::]:0') + port = server.add_http2_port(b"[::]:0") server.start() - channel = cygrpc.Channel('localhost:{}'.format(port).encode(), set(), - None) + channel = cygrpc.Channel( + "localhost:{}".format(port).encode(), set(), None + ) - server_shutdown_tag = 'server_shutdown_tag' - server_driver = _ServerDriver(server_completion_queue, - server_shutdown_tag) + server_shutdown_tag = "server_shutdown_tag" + server_driver = _ServerDriver( + server_completion_queue, server_shutdown_tag + ) server_driver.start() client_condition = threading.Condition() client_due = set() server_call_condition = threading.Condition() - server_send_initial_metadata_tag = 'server_send_initial_metadata_tag' - server_send_first_message_tag = 'server_send_first_message_tag' - server_send_second_message_tag = 'server_send_second_message_tag' - server_complete_rpc_tag = 'server_complete_rpc_tag' - server_call_due = set(( - server_send_initial_metadata_tag, - server_send_first_message_tag, - server_send_second_message_tag, - server_complete_rpc_tag, - )) + server_send_initial_metadata_tag = "server_send_initial_metadata_tag" + server_send_first_message_tag = "server_send_first_message_tag" + server_send_second_message_tag = "server_send_second_message_tag" + server_complete_rpc_tag = "server_complete_rpc_tag" + server_call_due = set( + ( + server_send_initial_metadata_tag, + server_send_first_message_tag, + server_send_second_message_tag, + server_complete_rpc_tag, + ) + ) server_call_completion_queue = cygrpc.CompletionQueue() - server_call_driver = _QueueDriver(server_call_condition, - server_call_completion_queue, - server_call_due) + 
server_call_driver = _QueueDriver( + server_call_condition, server_call_completion_queue, server_call_due + ) server_call_driver.start() - server_rpc_tag = 'server_rpc_tag' - request_call_result = server.request_call(server_call_completion_queue, - server_completion_queue, - server_rpc_tag) + server_rpc_tag = "server_rpc_tag" + request_call_result = server.request_call( + server_call_completion_queue, + server_completion_queue, + server_rpc_tag, + ) - client_receive_initial_metadata_tag = 'client_receive_initial_metadata_tag' - client_complete_rpc_tag = 'client_complete_rpc_tag' + client_receive_initial_metadata_tag = ( + "client_receive_initial_metadata_tag" + ) + client_complete_rpc_tag = "client_complete_rpc_tag" client_call = channel.segregated_call( - _EMPTY_FLAGS, b'/twinkies', None, None, _EMPTY_METADATA, None, ( + _EMPTY_FLAGS, + b"/twinkies", + None, + None, + _EMPTY_METADATA, + None, + ( ( [ cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS), @@ -166,76 +180,111 @@ def testReadSomeButNotAllResponses(self): ( [ cygrpc.SendInitialMetadataOperation( - _EMPTY_METADATA, _EMPTY_FLAGS), + _EMPTY_METADATA, _EMPTY_FLAGS + ), cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), ], client_complete_rpc_tag, ), - )) - client_receive_initial_metadata_event_future = test_utilities.SimpleFuture( - client_call.next_event) + ), + ) + client_receive_initial_metadata_event_future = ( + test_utilities.SimpleFuture(client_call.next_event) + ) server_rpc_event = server_driver.first_event() with server_call_condition: server_send_initial_metadata_start_batch_result = ( - server_rpc_event.call.start_server_batch([ - cygrpc.SendInitialMetadataOperation(_EMPTY_METADATA, - _EMPTY_FLAGS), - ], server_send_initial_metadata_tag)) + server_rpc_event.call.start_server_batch( + [ + cygrpc.SendInitialMetadataOperation( + _EMPTY_METADATA, _EMPTY_FLAGS + ), + ], + server_send_initial_metadata_tag, + ) + ) server_send_first_message_start_batch_result = ( - server_rpc_event.call.start_server_batch([ - cygrpc.SendMessageOperation(b'\x07', _EMPTY_FLAGS), - ], server_send_first_message_tag)) + server_rpc_event.call.start_server_batch( + [ + cygrpc.SendMessageOperation(b"\x07", _EMPTY_FLAGS), + ], + server_send_first_message_tag, + ) + ) server_send_initial_metadata_event = server_call_driver.event_with_tag( - server_send_initial_metadata_tag) + server_send_initial_metadata_tag + ) server_send_first_message_event = server_call_driver.event_with_tag( - server_send_first_message_tag) + server_send_first_message_tag + ) with server_call_condition: server_send_second_message_start_batch_result = ( - server_rpc_event.call.start_server_batch([ - cygrpc.SendMessageOperation(b'\x07', _EMPTY_FLAGS), - ], server_send_second_message_tag)) + server_rpc_event.call.start_server_batch( + [ + cygrpc.SendMessageOperation(b"\x07", _EMPTY_FLAGS), + ], + server_send_second_message_tag, + ) + ) server_complete_rpc_start_batch_result = ( - server_rpc_event.call.start_server_batch([ - cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS), - cygrpc.SendStatusFromServerOperation( - (), cygrpc.StatusCode.ok, b'test details', - _EMPTY_FLAGS), - ], server_complete_rpc_tag)) + server_rpc_event.call.start_server_batch( + [ + cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS), + cygrpc.SendStatusFromServerOperation( + (), + cygrpc.StatusCode.ok, + b"test details", + _EMPTY_FLAGS, + ), + ], + server_complete_rpc_tag, + ) + ) server_send_second_message_event = server_call_driver.event_with_tag( - 
server_send_second_message_tag) + server_send_second_message_tag + ) server_complete_rpc_event = server_call_driver.event_with_tag( - server_complete_rpc_tag) + server_complete_rpc_tag + ) server_call_driver.events() - client_recieve_initial_metadata_event = client_receive_initial_metadata_event_future.result( + client_recieve_initial_metadata_event = ( + client_receive_initial_metadata_event_future.result() ) - client_receive_first_message_tag = 'client_receive_first_message_tag' - client_call.operate([ - cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), - ], client_receive_first_message_tag) + client_receive_first_message_tag = "client_receive_first_message_tag" + client_call.operate( + [ + cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), + ], + client_receive_first_message_tag, + ) client_receive_first_message_event = client_call.next_event() client_call_cancel_result = client_call.cancel( - cygrpc.StatusCode.cancelled, 'Cancelled during test!') + cygrpc.StatusCode.cancelled, "Cancelled during test!" + ) client_complete_rpc_event = client_call.next_event() - channel.close(cygrpc.StatusCode.unknown, 'Channel closed!') + channel.close(cygrpc.StatusCode.unknown, "Channel closed!") server.shutdown(server_completion_queue, server_shutdown_tag) server.cancel_all_calls() server_driver.events() self.assertEqual(cygrpc.CallError.ok, request_call_result) - self.assertEqual(cygrpc.CallError.ok, - server_send_initial_metadata_start_batch_result) + self.assertEqual( + cygrpc.CallError.ok, server_send_initial_metadata_start_batch_result + ) self.assertIs(server_rpc_tag, server_rpc_event.tag) - self.assertEqual(cygrpc.CompletionType.operation_complete, - server_rpc_event.completion_type) + self.assertEqual( + cygrpc.CompletionType.operation_complete, + server_rpc_event.completion_type, + ) self.assertIsInstance(server_rpc_event.call, cygrpc.Call) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_cython/_server_test.py b/src/python/grpcio_tests/tests/unit/_cython/_server_test.py index 60b068243c046..2ee36818a07ca 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_server_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_server_test.py @@ -21,29 +21,31 @@ class Test(unittest.TestCase): - def test_lonely_server(self): server_call_completion_queue = cygrpc.CompletionQueue() server_shutdown_completion_queue = cygrpc.CompletionQueue() server = cygrpc.Server(None, False) server.register_completion_queue(server_call_completion_queue) server.register_completion_queue(server_shutdown_completion_queue) - port = server.add_http2_port(b'[::]:0') + port = server.add_http2_port(b"[::]:0") server.start() - server_request_call_tag = 'server_request_call_tag' + server_request_call_tag = "server_request_call_tag" server_request_call_start_batch_result = server.request_call( - server_call_completion_queue, server_call_completion_queue, - server_request_call_tag) + server_call_completion_queue, + server_call_completion_queue, + server_request_call_tag, + ) time.sleep(4) - server_shutdown_tag = 'server_shutdown_tag' + server_shutdown_tag = "server_shutdown_tag" server_shutdown_result = server.shutdown( - server_shutdown_completion_queue, server_shutdown_tag) + server_shutdown_completion_queue, server_shutdown_tag + ) server_request_call_event = server_call_completion_queue.poll() server_shutdown_event = server_shutdown_completion_queue.poll() -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(verbosity=2) diff --git 
a/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py b/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py index 58bc579373e95..1e66e2bd7eedb 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py @@ -23,117 +23,150 @@ from tests.unit import test_common from tests.unit._cython import test_utilities -_SSL_HOST_OVERRIDE = b'foo.test.google.fr' -_CALL_CREDENTIALS_METADATA_KEY = 'call-creds-key' -_CALL_CREDENTIALS_METADATA_VALUE = 'call-creds-value' +_SSL_HOST_OVERRIDE = b"foo.test.google.fr" +_CALL_CREDENTIALS_METADATA_KEY = "call-creds-key" +_CALL_CREDENTIALS_METADATA_VALUE = "call-creds-value" _EMPTY_FLAGS = 0 def _metadata_plugin(context, callback): - callback((( - _CALL_CREDENTIALS_METADATA_KEY, - _CALL_CREDENTIALS_METADATA_VALUE, - ),), cygrpc.StatusCode.ok, b'') + callback( + ( + ( + _CALL_CREDENTIALS_METADATA_KEY, + _CALL_CREDENTIALS_METADATA_VALUE, + ), + ), + cygrpc.StatusCode.ok, + b"", + ) class TypeSmokeTest(unittest.TestCase): - def testCompletionQueueUpDown(self): completion_queue = cygrpc.CompletionQueue() del completion_queue def testServerUpDown(self): - server = cygrpc.Server(set([( - b'grpc.so_reuseport', - 0, - )]), False) + server = cygrpc.Server( + set( + [ + ( + b"grpc.so_reuseport", + 0, + ) + ] + ), + False, + ) del server def testChannelUpDown(self): - channel = cygrpc.Channel(b'[::]:0', None, None) - channel.close(cygrpc.StatusCode.cancelled, 'Test method anyway!') + channel = cygrpc.Channel(b"[::]:0", None, None) + channel.close(cygrpc.StatusCode.cancelled, "Test method anyway!") def test_metadata_plugin_call_credentials_up_down(self): - cygrpc.MetadataPluginCallCredentials(_metadata_plugin, - b'test plugin name!') + cygrpc.MetadataPluginCallCredentials( + _metadata_plugin, b"test plugin name!" 
+ ) def testServerStartNoExplicitShutdown(self): - server = cygrpc.Server([( - b'grpc.so_reuseport', - 0, - )], False) + server = cygrpc.Server( + [ + ( + b"grpc.so_reuseport", + 0, + ) + ], + False, + ) completion_queue = cygrpc.CompletionQueue() server.register_completion_queue(completion_queue) - port = server.add_http2_port(b'[::]:0') + port = server.add_http2_port(b"[::]:0") self.assertIsInstance(port, int) server.start() del server def testServerStartShutdown(self): completion_queue = cygrpc.CompletionQueue() - server = cygrpc.Server([ - ( - b'grpc.so_reuseport', - 0, - ), - ], False) - server.add_http2_port(b'[::]:0') + server = cygrpc.Server( + [ + ( + b"grpc.so_reuseport", + 0, + ), + ], + False, + ) + server.add_http2_port(b"[::]:0") server.register_completion_queue(completion_queue) server.start() shutdown_tag = object() server.shutdown(completion_queue, shutdown_tag) event = completion_queue.poll() - self.assertEqual(cygrpc.CompletionType.operation_complete, - event.completion_type) + self.assertEqual( + cygrpc.CompletionType.operation_complete, event.completion_type + ) self.assertIs(shutdown_tag, event.tag) del server del completion_queue class ServerClientMixin(object): - def setUpMixin(self, server_credentials, client_credentials, host_override): self.server_completion_queue = cygrpc.CompletionQueue() - self.server = cygrpc.Server([( - b'grpc.so_reuseport', - 0, - )], False) + self.server = cygrpc.Server( + [ + ( + b"grpc.so_reuseport", + 0, + ) + ], + False, + ) self.server.register_completion_queue(self.server_completion_queue) if server_credentials: - self.port = self.server.add_http2_port(b'[::]:0', - server_credentials) + self.port = self.server.add_http2_port( + b"[::]:0", server_credentials + ) else: - self.port = self.server.add_http2_port(b'[::]:0') + self.port = self.server.add_http2_port(b"[::]:0") self.server.start() self.client_completion_queue = cygrpc.CompletionQueue() if client_credentials: - client_channel_arguments = (( - cygrpc.ChannelArgKey.ssl_target_name_override, - host_override, - ),) + client_channel_arguments = ( + ( + cygrpc.ChannelArgKey.ssl_target_name_override, + host_override, + ), + ) self.client_channel = cygrpc.Channel( - 'localhost:{}'.format(self.port).encode(), - client_channel_arguments, client_credentials) + "localhost:{}".format(self.port).encode(), + client_channel_arguments, + client_credentials, + ) else: self.client_channel = cygrpc.Channel( - 'localhost:{}'.format(self.port).encode(), set(), None) + "localhost:{}".format(self.port).encode(), set(), None + ) if host_override: self.host_argument = None # default host self.expected_host = host_override else: # arbitrary host name necessitating no further identification - self.host_argument = b'hostess' + self.host_argument = b"hostess" self.expected_host = self.host_argument def tearDownMixin(self): - self.client_channel.close(cygrpc.StatusCode.ok, 'test being torn down!') + self.client_channel.close(cygrpc.StatusCode.ok, "test being torn down!") del self.client_channel del self.server del self.client_completion_queue del self.server_completion_queue - def _perform_queue_operations(self, operations, call, queue, deadline, - description): + def _perform_queue_operations( + self, operations, call, queue, deadline, description + ): """Perform the operations with given call, queue, and deadline. 
Invocation errors are reported with as an exception with `description` @@ -147,13 +180,16 @@ def performer(): call_result = call.start_client_batch(operations, tag) self.assertEqual(cygrpc.CallError.ok, call_result) event = queue.poll(deadline=deadline) - self.assertEqual(cygrpc.CompletionType.operation_complete, - event.completion_type) + self.assertEqual( + cygrpc.CompletionType.operation_complete, + event.completion_type, + ) self.assertTrue(event.success) self.assertIs(tag, event.tag) except Exception as error: - raise Exception("Error in '{}': {}".format( - description, error.message)) + raise Exception( + "Error in '{}': {}".format(description, error.message) + ) return event return test_utilities.SimpleFuture(performer) @@ -161,24 +197,26 @@ def performer(): def test_echo(self): DEADLINE = time.time() + 5 DEADLINE_TOLERANCE = 0.25 - CLIENT_METADATA_ASCII_KEY = 'key' - CLIENT_METADATA_ASCII_VALUE = 'val' - CLIENT_METADATA_BIN_KEY = 'key-bin' - CLIENT_METADATA_BIN_VALUE = b'\0' * 1000 - SERVER_INITIAL_METADATA_KEY = 'init_me_me_me' - SERVER_INITIAL_METADATA_VALUE = 'whodawha?' - SERVER_TRAILING_METADATA_KEY = 'california_is_in_a_drought' - SERVER_TRAILING_METADATA_VALUE = 'zomg it is' + CLIENT_METADATA_ASCII_KEY = "key" + CLIENT_METADATA_ASCII_VALUE = "val" + CLIENT_METADATA_BIN_KEY = "key-bin" + CLIENT_METADATA_BIN_VALUE = b"\0" * 1000 + SERVER_INITIAL_METADATA_KEY = "init_me_me_me" + SERVER_INITIAL_METADATA_VALUE = "whodawha?" + SERVER_TRAILING_METADATA_KEY = "california_is_in_a_drought" + SERVER_TRAILING_METADATA_VALUE = "zomg it is" SERVER_STATUS_CODE = cygrpc.StatusCode.ok - SERVER_STATUS_DETAILS = 'our work is never over' - REQUEST = b'in death a member of project mayhem has a name' - RESPONSE = b'his name is robert paulson' - METHOD = b'twinkies' + SERVER_STATUS_DETAILS = "our work is never over" + REQUEST = b"in death a member of project mayhem has a name" + RESPONSE = b"his name is robert paulson" + METHOD = b"twinkies" server_request_tag = object() request_call_result = self.server.request_call( - self.server_completion_queue, self.server_completion_queue, - server_request_tag) + self.server_completion_queue, + self.server_completion_queue, + server_request_tag, + ) self.assertEqual(cygrpc.CallError.ok, request_call_result) @@ -194,12 +232,18 @@ def test_echo(self): ), ) client_call = self.client_channel.integrated_call( - 0, METHOD, self.host_argument, DEADLINE, client_initial_metadata, - None, [ + 0, + METHOD, + self.host_argument, + DEADLINE, + client_initial_metadata, + None, + [ ( [ cygrpc.SendInitialMetadataOperation( - client_initial_metadata, _EMPTY_FLAGS), + client_initial_metadata, _EMPTY_FLAGS + ), cygrpc.SendMessageOperation(REQUEST, _EMPTY_FLAGS), cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS), @@ -208,51 +252,74 @@ def test_echo(self): ], client_call_tag, ), - ]) + ], + ) client_event_future = test_utilities.SimpleFuture( - self.client_channel.next_call_event) + self.client_channel.next_call_event + ) request_event = self.server_completion_queue.poll(deadline=DEADLINE) - self.assertEqual(cygrpc.CompletionType.operation_complete, - request_event.completion_type) + self.assertEqual( + cygrpc.CompletionType.operation_complete, + request_event.completion_type, + ) self.assertIsInstance(request_event.call, cygrpc.Call) self.assertIs(server_request_tag, request_event.tag) self.assertTrue( - test_common.metadata_transmitted(client_initial_metadata, - request_event.invocation_metadata)) + 
test_common.metadata_transmitted( + client_initial_metadata, request_event.invocation_metadata + ) + ) self.assertEqual(METHOD, request_event.call_details.method) self.assertEqual(self.expected_host, request_event.call_details.host) - self.assertLess(abs(DEADLINE - request_event.call_details.deadline), - DEADLINE_TOLERANCE) + self.assertLess( + abs(DEADLINE - request_event.call_details.deadline), + DEADLINE_TOLERANCE, + ) server_call_tag = object() server_call = request_event.call - server_start_batch_result = server_call.start_server_batch([ - cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), - ], server_call_tag) + server_start_batch_result = server_call.start_server_batch( + [ + cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), + ], + server_call_tag, + ) self.assertEqual(cygrpc.CallError.ok, server_start_batch_result) server_message_event = self.server_completion_queue.poll( - deadline=DEADLINE) + deadline=DEADLINE + ) server_call_tag = object() - server_initial_metadata = (( - SERVER_INITIAL_METADATA_KEY, - SERVER_INITIAL_METADATA_VALUE, - ),) - server_trailing_metadata = (( - SERVER_TRAILING_METADATA_KEY, - SERVER_TRAILING_METADATA_VALUE, - ),) - server_start_batch_result = server_call.start_server_batch([ - cygrpc.SendInitialMetadataOperation(server_initial_metadata, - _EMPTY_FLAGS), - cygrpc.SendMessageOperation(RESPONSE, _EMPTY_FLAGS), - cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS), - cygrpc.SendStatusFromServerOperation( - server_trailing_metadata, SERVER_STATUS_CODE, - SERVER_STATUS_DETAILS, _EMPTY_FLAGS) - ], server_call_tag) + server_initial_metadata = ( + ( + SERVER_INITIAL_METADATA_KEY, + SERVER_INITIAL_METADATA_VALUE, + ), + ) + server_trailing_metadata = ( + ( + SERVER_TRAILING_METADATA_KEY, + SERVER_TRAILING_METADATA_VALUE, + ), + ) + server_start_batch_result = server_call.start_server_batch( + [ + cygrpc.SendInitialMetadataOperation( + server_initial_metadata, _EMPTY_FLAGS + ), + cygrpc.SendMessageOperation(RESPONSE, _EMPTY_FLAGS), + cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS), + cygrpc.SendStatusFromServerOperation( + server_trailing_metadata, + SERVER_STATUS_CODE, + SERVER_STATUS_DETAILS, + _EMPTY_FLAGS, + ), + ], + server_call_tag, + ) self.assertEqual(cygrpc.CallError.ok, server_start_batch_result) server_event = self.server_completion_queue.poll(deadline=DEADLINE) @@ -264,31 +331,43 @@ def test_echo(self): # we expect each op type to be unique self.assertNotIn(client_result.type(), found_client_op_types) found_client_op_types.add(client_result.type()) - if client_result.type( - ) == cygrpc.OperationType.receive_initial_metadata: + if ( + client_result.type() + == cygrpc.OperationType.receive_initial_metadata + ): self.assertTrue( test_common.metadata_transmitted( server_initial_metadata, - client_result.initial_metadata())) + client_result.initial_metadata(), + ) + ) elif client_result.type() == cygrpc.OperationType.receive_message: self.assertEqual(RESPONSE, client_result.message()) - elif client_result.type( - ) == cygrpc.OperationType.receive_status_on_client: + elif ( + client_result.type() + == cygrpc.OperationType.receive_status_on_client + ): self.assertTrue( test_common.metadata_transmitted( server_trailing_metadata, - client_result.trailing_metadata())) + client_result.trailing_metadata(), + ) + ) self.assertEqual(SERVER_STATUS_DETAILS, client_result.details()) self.assertEqual(SERVER_STATUS_CODE, client_result.code()) self.assertEqual( - set([ - cygrpc.OperationType.send_initial_metadata, - cygrpc.OperationType.send_message, - 
cygrpc.OperationType.send_close_from_client, - cygrpc.OperationType.receive_initial_metadata, - cygrpc.OperationType.receive_message, - cygrpc.OperationType.receive_status_on_client - ]), found_client_op_types) + set( + [ + cygrpc.OperationType.send_initial_metadata, + cygrpc.OperationType.send_message, + cygrpc.OperationType.send_close_from_client, + cygrpc.OperationType.receive_initial_metadata, + cygrpc.OperationType.receive_message, + cygrpc.OperationType.receive_status_on_client, + ] + ), + found_client_op_types, + ) self.assertEqual(1, len(server_message_event.batch_operations)) found_server_op_types = set() @@ -297,12 +376,19 @@ def test_echo(self): found_server_op_types.add(server_result.type()) if server_result.type() == cygrpc.OperationType.receive_message: self.assertEqual(REQUEST, server_result.message()) - elif server_result.type( - ) == cygrpc.OperationType.receive_close_on_server: + elif ( + server_result.type() + == cygrpc.OperationType.receive_close_on_server + ): self.assertFalse(server_result.cancelled()) - self.assertEqual(set([ - cygrpc.OperationType.receive_message, - ]), found_server_op_types) + self.assertEqual( + set( + [ + cygrpc.OperationType.receive_message, + ] + ), + found_server_op_types, + ) self.assertEqual(4, len(server_event.batch_operations)) found_server_op_types = set() @@ -311,16 +397,22 @@ def test_echo(self): found_server_op_types.add(server_result.type()) if server_result.type() == cygrpc.OperationType.receive_message: self.assertEqual(REQUEST, server_result.message()) - elif server_result.type( - ) == cygrpc.OperationType.receive_close_on_server: + elif ( + server_result.type() + == cygrpc.OperationType.receive_close_on_server + ): self.assertFalse(server_result.cancelled()) self.assertEqual( - set([ - cygrpc.OperationType.send_initial_metadata, - cygrpc.OperationType.send_message, - cygrpc.OperationType.receive_close_on_server, - cygrpc.OperationType.send_status_from_server - ]), found_server_op_types) + set( + [ + cygrpc.OperationType.send_initial_metadata, + cygrpc.OperationType.send_message, + cygrpc.OperationType.receive_close_on_server, + cygrpc.OperationType.send_status_from_server, + ] + ), + found_server_op_types, + ) del client_call del server_call @@ -328,88 +420,128 @@ def test_echo(self): def test_6522(self): DEADLINE = time.time() + 5 DEADLINE_TOLERANCE = 0.25 - METHOD = b'twinkies' + METHOD = b"twinkies" empty_metadata = () # Prologue server_request_tag = object() - self.server.request_call(self.server_completion_queue, - self.server_completion_queue, - server_request_tag) + self.server.request_call( + self.server_completion_queue, + self.server_completion_queue, + server_request_tag, + ) client_call = self.client_channel.segregated_call( - 0, METHOD, self.host_argument, DEADLINE, None, None, - ([( + 0, + METHOD, + self.host_argument, + DEADLINE, + None, + None, + ( [ - cygrpc.SendInitialMetadataOperation(empty_metadata, - _EMPTY_FLAGS), - cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS), - ], - object(), + ( + [ + cygrpc.SendInitialMetadataOperation( + empty_metadata, _EMPTY_FLAGS + ), + cygrpc.ReceiveInitialMetadataOperation( + _EMPTY_FLAGS + ), + ], + object(), + ), + ( + [ + cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), + ], + object(), + ), + ] ), - ( - [ - cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS), - ], - object(), - )])) + ) client_initial_metadata_event_future = test_utilities.SimpleFuture( - client_call.next_event) + client_call.next_event + ) request_event = 
self.server_completion_queue.poll(deadline=DEADLINE) server_call = request_event.call def perform_server_operations(operations, description): - return self._perform_queue_operations(operations, server_call, - self.server_completion_queue, - DEADLINE, description) - - server_event_future = perform_server_operations([ - cygrpc.SendInitialMetadataOperation(empty_metadata, _EMPTY_FLAGS), - ], "Server prologue") + return self._perform_queue_operations( + operations, + server_call, + self.server_completion_queue, + DEADLINE, + description, + ) + + server_event_future = perform_server_operations( + [ + cygrpc.SendInitialMetadataOperation( + empty_metadata, _EMPTY_FLAGS + ), + ], + "Server prologue", + ) client_initial_metadata_event_future.result() # force completion server_event_future.result() # Messaging for _ in range(10): - client_call.operate([ - cygrpc.SendMessageOperation(b'', _EMPTY_FLAGS), - cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), - ], "Client message") + client_call.operate( + [ + cygrpc.SendMessageOperation(b"", _EMPTY_FLAGS), + cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), + ], + "Client message", + ) client_message_event_future = test_utilities.SimpleFuture( - client_call.next_event) - server_event_future = perform_server_operations([ - cygrpc.SendMessageOperation(b'', _EMPTY_FLAGS), - cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), - ], "Server receive") + client_call.next_event + ) + server_event_future = perform_server_operations( + [ + cygrpc.SendMessageOperation(b"", _EMPTY_FLAGS), + cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS), + ], + "Server receive", + ) client_message_event_future.result() # force completion server_event_future.result() # Epilogue - client_call.operate([ - cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), - ], "Client epilogue") + client_call.operate( + [ + cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS), + ], + "Client epilogue", + ) # One for ReceiveStatusOnClient, one for SendCloseFromClient. 
- client_events_future = test_utilities.SimpleFuture(lambda: { - client_call.next_event(), - client_call.next_event(), - }) + client_events_future = test_utilities.SimpleFuture( + lambda: { + client_call.next_event(), + client_call.next_event(), + } + ) - server_event_future = perform_server_operations([ - cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS), - cygrpc.SendStatusFromServerOperation( - empty_metadata, cygrpc.StatusCode.ok, b'', _EMPTY_FLAGS) - ], "Server epilogue") + server_event_future = perform_server_operations( + [ + cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS), + cygrpc.SendStatusFromServerOperation( + empty_metadata, cygrpc.StatusCode.ok, b"", _EMPTY_FLAGS + ), + ], + "Server epilogue", + ) client_events_future.result() # force completion server_event_future.result() class InsecureServerInsecureClient(unittest.TestCase, ServerClientMixin): - def setUp(self): self.setUpMixin(None, None, None) @@ -418,21 +550,26 @@ def tearDown(self): class SecureServerSecureClient(unittest.TestCase, ServerClientMixin): - def setUp(self): server_credentials = cygrpc.server_credentials_ssl( - None, [ - cygrpc.SslPemKeyCertPair(resources.private_key(), - resources.certificate_chain()) - ], False) + None, + [ + cygrpc.SslPemKeyCertPair( + resources.private_key(), resources.certificate_chain() + ) + ], + False, + ) client_credentials = cygrpc.SSLChannelCredentials( - resources.test_root_certificates(), None, None) - self.setUpMixin(server_credentials, client_credentials, - _SSL_HOST_OVERRIDE) + resources.test_root_certificates(), None, None + ) + self.setUpMixin( + server_credentials, client_credentials, _SSL_HOST_OVERRIDE + ) def tearDown(self): self.tearDownMixin() -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_cython/test_utilities.py b/src/python/grpcio_tests/tests/unit/_cython/test_utilities.py index 7d5eaaaa84261..9dd62ad65575e 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/test_utilities.py +++ b/src/python/grpcio_tests/tests/unit/_cython/test_utilities.py @@ -21,7 +21,6 @@ class SimpleFuture(object): """A simple future mechanism.""" def __init__(self, function, *args, **kwargs): - def wrapped_function(): try: self._result = function(*args, **kwargs) @@ -36,8 +35,8 @@ def wrapped_function(): def result(self): """The resulting value of this future. - Re-raises any exceptions. - """ + Re-raises any exceptions. 
+ """ self._thread.join() if self._error: # TODO(atash): re-raise exceptions in a way that preserves tracebacks @@ -46,7 +45,7 @@ def result(self): class CompletionQueuePollFuture(SimpleFuture): - def __init__(self, completion_queue, deadline): - super(CompletionQueuePollFuture, - self).__init__(lambda: completion_queue.poll(deadline=deadline)) + super(CompletionQueuePollFuture, self).__init__( + lambda: completion_queue.poll(deadline=deadline) + ) diff --git a/src/python/grpcio_tests/tests/unit/_dns_resolver_test.py b/src/python/grpcio_tests/tests/unit/_dns_resolver_test.py index f4196aaaac169..62a95a021357f 100644 --- a/src/python/grpcio_tests/tests/unit/_dns_resolver_test.py +++ b/src/python/grpcio_tests/tests/unit/_dns_resolver_test.py @@ -21,13 +21,12 @@ from tests.unit import test_common from tests.unit.framework.common import test_constants -_METHOD = '/ANY/METHOD' -_REQUEST = b'\x00\x00\x00' +_METHOD = "/ANY/METHOD" +_REQUEST = b"\x00\x00\x00" _RESPONSE = _REQUEST class GenericHandler(grpc.GenericRpcHandler): - def service(self, unused_handler_details): return grpc.unary_unary_rpc_method_handler( lambda request, unused_context: request, @@ -35,11 +34,10 @@ def service(self, unused_handler_details): class DNSResolverTest(unittest.TestCase): - def setUp(self): self._server = test_common.test_server() self._server.add_generic_rpc_handlers((GenericHandler(),)) - self._port = self._server.add_insecure_port('[::]:0') + self._port = self._server.add_insecure_port("[::]:0") self._server.start() def tearDown(self): @@ -53,15 +51,18 @@ def test_connect_loopback(self): # it returns the expected responses even when DNS64 dns servers # are used on the test worker (and for purposes of this # test the use of loopback4 vs loopback46 makes no difference). - with grpc.insecure_channel('loopback46.unittest.grpc.io:%d' % - self._port) as channel: + with grpc.insecure_channel( + "loopback46.unittest.grpc.io:%d" % self._port + ) as channel: self.assertEqual( channel.unary_unary(_METHOD)( _REQUEST, timeout=10, - ), _RESPONSE) + ), + _RESPONSE, + ) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_dynamic_stubs_test.py b/src/python/grpcio_tests/tests/unit/_dynamic_stubs_test.py index 7f95188b910ab..90f0779a4a2d4 100644 --- a/src/python/grpcio_tests/tests/unit/_dynamic_stubs_test.py +++ b/src/python/grpcio_tests/tests/unit/_dynamic_stubs_test.py @@ -45,7 +45,6 @@ def _grpc_tools_unimportable(): def _collect_errors(fn): - @functools.wraps(fn) def _wrapped(error_queue): try: @@ -58,7 +57,6 @@ def _wrapped(error_queue): def _python3_check(fn): - @functools.wraps(fn) def _wrapped(): if sys.version_info[0] == 3: @@ -71,7 +69,8 @@ def _wrapped(): def _run_in_subprocess(test_case): sys.path.insert( - 0, os.path.join(os.path.realpath(os.path.dirname(__file__)), "..")) + 0, os.path.join(os.path.realpath(os.path.dirname(__file__)), "..") + ) error_queue = multiprocessing.Queue() proc = multiprocessing.Process(target=test_case, args=(error_queue,)) proc.start() @@ -80,17 +79,21 @@ def _run_in_subprocess(test_case): if not error_queue.empty(): raise error_queue.get() assert proc.exitcode == 0, "Process exited with code {}".format( - proc.exitcode) + proc.exitcode + ) def _assert_unimplemented(msg_substr): import grpc + try: protos, services = grpc.protos_and_services( - "tests/unit/data/foo/bar.proto") + "tests/unit/data/foo/bar.proto" + ) except NotImplementedError as e: assert msg_substr in str(e), "{} was not in 
'{}'".format( - msg_substr, str(e)) + msg_substr, str(e) + ) else: assert False, "Did not raise NotImplementedError" @@ -99,8 +102,10 @@ def _assert_unimplemented(msg_substr): @_python3_check def _test_sunny_day(): import grpc + protos, services = grpc.protos_and_services( - os.path.join(_DATA_DIR, "foo", "bar.proto")) + os.path.join(_DATA_DIR, "foo", "bar.proto") + ) assert protos.BarMessage is not None assert services.BarStub is not None @@ -109,8 +114,10 @@ def _test_sunny_day(): @_python3_check def _test_well_known_types(): import grpc + protos, services = grpc.protos_and_services( - os.path.join(_DATA_DIR, "foo", "bar_with_wkt.proto")) + os.path.join(_DATA_DIR, "foo", "bar_with_wkt.proto") + ) assert protos.BarMessage is not None assert services.BarStub is not None @@ -127,7 +134,6 @@ def _test_grpc_tools_unimportable(): # if run directly on Windows, but not if started by the test runner. @unittest.skipIf(os.name == "nt", "Windows multiprocessing unsupported") class DynamicStubTest(unittest.TestCase): - def test_sunny_day(self): _run_in_subprocess(_test_sunny_day) diff --git a/src/python/grpcio_tests/tests/unit/_empty_message_test.py b/src/python/grpcio_tests/tests/unit/_empty_message_test.py index 918dbe73d1af9..e2dc1594202a7 100644 --- a/src/python/grpcio_tests/tests/unit/_empty_message_test.py +++ b/src/python/grpcio_tests/tests/unit/_empty_message_test.py @@ -20,13 +20,13 @@ from tests.unit import test_common from tests.unit.framework.common import test_constants -_REQUEST = b'' -_RESPONSE = b'' +_REQUEST = b"" +_RESPONSE = b"" -_UNARY_UNARY = '/test/UnaryUnary' -_UNARY_STREAM = '/test/UnaryStream' -_STREAM_UNARY = '/test/StreamUnary' -_STREAM_STREAM = '/test/StreamStream' +_UNARY_UNARY = "/test/UnaryUnary" +_UNARY_STREAM = "/test/UnaryStream" +_STREAM_UNARY = "/test/StreamUnary" +_STREAM_STREAM = "/test/StreamStream" def handle_unary_unary(request, servicer_context): @@ -50,7 +50,6 @@ def handle_stream_stream(request_iterator, servicer_context): class _MethodHandler(grpc.RpcMethodHandler): - def __init__(self, request_streaming, response_streaming): self.request_streaming = request_streaming self.response_streaming = response_streaming @@ -71,7 +70,6 @@ def __init__(self, request_streaming, response_streaming): class _GenericHandler(grpc.GenericRpcHandler): - def service(self, handler_call_details): if handler_call_details.method == _UNARY_UNARY: return _MethodHandler(False, False) @@ -86,13 +84,12 @@ def service(self, handler_call_details): class EmptyMessageTest(unittest.TestCase): - def setUp(self): self._server = test_common.test_server() self._server.add_generic_rpc_handlers((_GenericHandler(),)) - port = self._server.add_insecure_port('[::]:0') + port = self._server.add_insecure_port("[::]:0") self._server.start() - self._channel = grpc.insecure_channel('localhost:%d' % port) + self._channel = grpc.insecure_channel("localhost:%d" % port) def tearDown(self): self._server.stop(0) @@ -104,21 +101,25 @@ def testUnaryUnary(self): def testUnaryStream(self): response_iterator = self._channel.unary_stream(_UNARY_STREAM)(_REQUEST) - self.assertSequenceEqual([_RESPONSE] * test_constants.STREAM_LENGTH, - list(response_iterator)) + self.assertSequenceEqual( + [_RESPONSE] * test_constants.STREAM_LENGTH, list(response_iterator) + ) def testStreamUnary(self): - response = self._channel.stream_unary(_STREAM_UNARY)(iter( - [_REQUEST] * test_constants.STREAM_LENGTH)) + response = self._channel.stream_unary(_STREAM_UNARY)( + iter([_REQUEST] * test_constants.STREAM_LENGTH) + ) 
self.assertEqual(_RESPONSE, response) def testStreamStream(self): - response_iterator = self._channel.stream_stream(_STREAM_STREAM)(iter( - [_REQUEST] * test_constants.STREAM_LENGTH)) - self.assertSequenceEqual([_RESPONSE] * test_constants.STREAM_LENGTH, - list(response_iterator)) + response_iterator = self._channel.stream_stream(_STREAM_STREAM)( + iter([_REQUEST] * test_constants.STREAM_LENGTH) + ) + self.assertSequenceEqual( + [_RESPONSE] * test_constants.STREAM_LENGTH, list(response_iterator) + ) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py b/src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py index e58007ad3eda0..4f07477fac9c5 100644 --- a/src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py +++ b/src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py @@ -22,19 +22,18 @@ from tests.unit.framework.common import test_constants _UNICODE_ERROR_MESSAGES = [ - b'\xe2\x80\x9d'.decode('utf-8'), - b'abc\x80\xd0\xaf'.decode('latin-1'), - b'\xc3\xa9'.decode('utf-8'), + b"\xe2\x80\x9d".decode("utf-8"), + b"abc\x80\xd0\xaf".decode("latin-1"), + b"\xc3\xa9".decode("utf-8"), ] -_REQUEST = b'\x00\x00\x00' -_RESPONSE = b'\x00\x00\x00' +_REQUEST = b"\x00\x00\x00" +_RESPONSE = b"\x00\x00\x00" -_UNARY_UNARY = '/test/UnaryUnary' +_UNARY_UNARY = "/test/UnaryUnary" class _MethodHandler(grpc.RpcMethodHandler): - def __init__(self, request_streaming=None, response_streaming=None): self.request_streaming = request_streaming self.response_streaming = response_streaming @@ -46,12 +45,11 @@ def __init__(self, request_streaming=None, response_streaming=None): def unary_unary(self, request, servicer_context): servicer_context.set_code(grpc.StatusCode.UNKNOWN) - servicer_context.set_details(request.decode('utf-8')) + servicer_context.set_details(request.decode("utf-8")) return _RESPONSE class _GenericHandler(grpc.GenericRpcHandler): - def __init__(self, test): self._test = test @@ -60,14 +58,14 @@ def service(self, handler_call_details): class ErrorMessageEncodingTest(unittest.TestCase): - def setUp(self): self._server = test_common.test_server() self._server.add_generic_rpc_handlers( - (_GenericHandler(weakref.proxy(self)),)) - port = self._server.add_insecure_port('[::]:0') + (_GenericHandler(weakref.proxy(self)),) + ) + port = self._server.add_insecure_port("[::]:0") self._server.start() - self._channel = grpc.insecure_channel('localhost:%d' % port) + self._channel = grpc.insecure_channel("localhost:%d" % port) def tearDown(self): self._server.stop(0) @@ -77,11 +75,11 @@ def testMessageEncoding(self): for message in _UNICODE_ERROR_MESSAGES: multi_callable = self._channel.unary_unary(_UNARY_UNARY) with self.assertRaises(grpc.RpcError) as cm: - multi_callable(message.encode('utf-8')) + multi_callable(message.encode("utf-8")) self.assertEqual(cm.exception.code(), grpc.StatusCode.UNKNOWN) self.assertEqual(cm.exception.details(), message) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_exit_scenarios.py b/src/python/grpcio_tests/tests/unit/_exit_scenarios.py index 301afb6c27c72..c1f9816df08c7 100644 --- a/src/python/grpcio_tests/tests/unit/_exit_scenarios.py +++ b/src/python/grpcio_tests/tests/unit/_exit_scenarios.py @@ -24,27 +24,27 @@ WAIT_TIME = 1000 -REQUEST = b'request' - -UNSTARTED_SERVER = 'unstarted_server' -RUNNING_SERVER = 'running_server' 
-POLL_CONNECTIVITY_NO_SERVER = 'poll_connectivity_no_server' -POLL_CONNECTIVITY = 'poll_connectivity' -IN_FLIGHT_UNARY_UNARY_CALL = 'in_flight_unary_unary_call' -IN_FLIGHT_UNARY_STREAM_CALL = 'in_flight_unary_stream_call' -IN_FLIGHT_STREAM_UNARY_CALL = 'in_flight_stream_unary_call' -IN_FLIGHT_STREAM_STREAM_CALL = 'in_flight_stream_stream_call' -IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL = 'in_flight_partial_unary_stream_call' -IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL = 'in_flight_partial_stream_unary_call' -IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL = 'in_flight_partial_stream_stream_call' - -UNARY_UNARY = b'/test/UnaryUnary' -UNARY_STREAM = b'/test/UnaryStream' -STREAM_UNARY = b'/test/StreamUnary' -STREAM_STREAM = b'/test/StreamStream' -PARTIAL_UNARY_STREAM = b'/test/PartialUnaryStream' -PARTIAL_STREAM_UNARY = b'/test/PartialStreamUnary' -PARTIAL_STREAM_STREAM = b'/test/PartialStreamStream' +REQUEST = b"request" + +UNSTARTED_SERVER = "unstarted_server" +RUNNING_SERVER = "running_server" +POLL_CONNECTIVITY_NO_SERVER = "poll_connectivity_no_server" +POLL_CONNECTIVITY = "poll_connectivity" +IN_FLIGHT_UNARY_UNARY_CALL = "in_flight_unary_unary_call" +IN_FLIGHT_UNARY_STREAM_CALL = "in_flight_unary_stream_call" +IN_FLIGHT_STREAM_UNARY_CALL = "in_flight_stream_unary_call" +IN_FLIGHT_STREAM_STREAM_CALL = "in_flight_stream_stream_call" +IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL = "in_flight_partial_unary_stream_call" +IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL = "in_flight_partial_stream_unary_call" +IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL = "in_flight_partial_stream_stream_call" + +UNARY_UNARY = b"/test/UnaryUnary" +UNARY_STREAM = b"/test/UnaryStream" +STREAM_UNARY = b"/test/StreamUnary" +STREAM_STREAM = b"/test/StreamStream" +PARTIAL_UNARY_STREAM = b"/test/PartialUnaryStream" +PARTIAL_STREAM_UNARY = b"/test/PartialStreamUnary" +PARTIAL_STREAM_STREAM = b"/test/PartialStreamStream" TEST_TO_METHOD = { IN_FLIGHT_UNARY_UNARY_CALL: UNARY_UNARY, @@ -87,12 +87,11 @@ def hang_stream_stream(request_iterator, servicer_context): def hang_partial_stream_stream(request_iterator, servicer_context): for _ in range(test_constants.STREAM_LENGTH // 2): - yield next(request_iterator) #pylint: disable=stop-iteration-return + yield next(request_iterator) # pylint: disable=stop-iteration-return time.sleep(WAIT_TIME) class MethodHandler(grpc.RpcMethodHandler): - def __init__(self, request_streaming, response_streaming, partial_hang): self.request_streaming = request_streaming self.response_streaming = response_streaming @@ -122,7 +121,6 @@ def __init__(self, request_streaming, response_streaming, partial_hang): class GenericHandler(grpc.GenericRpcHandler): - def service(self, handler_call_details): if handler_call_details.method == UNARY_UNARY: return MethodHandler(False, False, False) @@ -146,7 +144,6 @@ def service(self, handler_call_details): # current jobs complete. Because we submit jobs that will # never finish, we don't want to block exit on these jobs. 
class DaemonPool(object): - def submit(self, fn, *args, **kwargs): thread = threading.Thread(target=fn, args=args, kwargs=kwargs) thread.daemon = True @@ -161,27 +158,27 @@ def infinite_request_iterator(): yield REQUEST -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() parser = argparse.ArgumentParser() - parser.add_argument('scenario', type=str) - parser.add_argument('--wait_for_interrupt', - dest='wait_for_interrupt', - action='store_true') + parser.add_argument("scenario", type=str) + parser.add_argument( + "--wait_for_interrupt", dest="wait_for_interrupt", action="store_true" + ) args = parser.parse_args() if args.scenario == UNSTARTED_SERVER: - server = grpc.server(DaemonPool(), options=(('grpc.so_reuseport', 0),)) + server = grpc.server(DaemonPool(), options=(("grpc.so_reuseport", 0),)) if args.wait_for_interrupt: time.sleep(WAIT_TIME) elif args.scenario == RUNNING_SERVER: - server = grpc.server(DaemonPool(), options=(('grpc.so_reuseport', 0),)) - port = server.add_insecure_port('[::]:0') + server = grpc.server(DaemonPool(), options=(("grpc.so_reuseport", 0),)) + port = server.add_insecure_port("[::]:0") server.start() if args.wait_for_interrupt: time.sleep(WAIT_TIME) elif args.scenario == POLL_CONNECTIVITY_NO_SERVER: - channel = grpc.insecure_channel('localhost:12345') + channel = grpc.insecure_channel("localhost:12345") def connectivity_callback(connectivity): pass @@ -190,10 +187,10 @@ def connectivity_callback(connectivity): if args.wait_for_interrupt: time.sleep(WAIT_TIME) elif args.scenario == POLL_CONNECTIVITY: - server = grpc.server(DaemonPool(), options=(('grpc.so_reuseport', 0),)) - port = server.add_insecure_port('[::]:0') + server = grpc.server(DaemonPool(), options=(("grpc.so_reuseport", 0),)) + port = server.add_insecure_port("[::]:0") server.start() - channel = grpc.insecure_channel('localhost:%d' % port) + channel = grpc.insecure_channel("localhost:%d" % port) def connectivity_callback(connectivity): pass @@ -204,11 +201,11 @@ def connectivity_callback(connectivity): else: handler = GenericHandler() - server = grpc.server(DaemonPool(), options=(('grpc.so_reuseport', 0),)) - port = server.add_insecure_port('[::]:0') + server = grpc.server(DaemonPool(), options=(("grpc.so_reuseport", 0),)) + port = server.add_insecure_port("[::]:0") server.add_generic_rpc_handlers((handler,)) server.start() - channel = grpc.insecure_channel('localhost:%d' % port) + channel = grpc.insecure_channel("localhost:%d" % port) method = TEST_TO_METHOD[args.scenario] @@ -216,20 +213,27 @@ def connectivity_callback(connectivity): multi_callable = channel.unary_unary(method) future = multi_callable.future(REQUEST) result, call = multi_callable.with_call(REQUEST) - elif (args.scenario == IN_FLIGHT_UNARY_STREAM_CALL or - args.scenario == IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL): + elif ( + args.scenario == IN_FLIGHT_UNARY_STREAM_CALL + or args.scenario == IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL + ): multi_callable = channel.unary_stream(method) response_iterator = multi_callable(REQUEST) for response in response_iterator: pass - elif (args.scenario == IN_FLIGHT_STREAM_UNARY_CALL or - args.scenario == IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL): + elif ( + args.scenario == IN_FLIGHT_STREAM_UNARY_CALL + or args.scenario == IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL + ): multi_callable = channel.stream_unary(method) future = multi_callable.future(infinite_request_iterator()) result, call = multi_callable.with_call( - iter([REQUEST] * test_constants.STREAM_LENGTH)) - elif (args.scenario == 
IN_FLIGHT_STREAM_STREAM_CALL or - args.scenario == IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL): + iter([REQUEST] * test_constants.STREAM_LENGTH) + ) + elif ( + args.scenario == IN_FLIGHT_STREAM_STREAM_CALL + or args.scenario == IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL + ): multi_callable = channel.stream_stream(method) response_iterator = multi_callable(infinite_request_iterator()) for response in response_iterator: diff --git a/src/python/grpcio_tests/tests/unit/_exit_test.py b/src/python/grpcio_tests/tests/unit/_exit_test.py index d3c3d6a22e737..8ff4610762527 100644 --- a/src/python/grpcio_tests/tests/unit/_exit_test.py +++ b/src/python/grpcio_tests/tests/unit/_exit_test.py @@ -31,11 +31,13 @@ from tests.unit import _exit_scenarios SCENARIO_FILE = os.path.abspath( - os.path.join(os.path.dirname(os.path.realpath(__file__)), - '_exit_scenarios.py')) + os.path.join( + os.path.dirname(os.path.realpath(__file__)), "_exit_scenarios.py" + ) +) INTERPRETER = sys.executable BASE_COMMAND = [INTERPRETER, SCENARIO_FILE] -BASE_SIGTERM_COMMAND = BASE_COMMAND + ['--wait_for_interrupt'] +BASE_SIGTERM_COMMAND = BASE_COMMAND + ["--wait_for_interrupt"] INIT_TIME = datetime.timedelta(seconds=1) WAIT_CHECK_INTERVAL = datetime.timedelta(milliseconds=100) @@ -65,7 +67,7 @@ def _process_wait_with_timeout(process, timeout=WAIT_CHECK_DEFAULT_TIMEOUT): while (process.poll() is None) and (datetime.datetime.now() < deadline): time.sleep(WAIT_CHECK_INTERVAL.total_seconds()) if process.returncode is None: - raise RuntimeError('Process failed to exit within %s' % timeout) + raise RuntimeError("Process failed to exit within %s" % timeout) def interrupt_and_wait(process): @@ -83,133 +85,154 @@ def wait(process): # TODO(lidiz) enable exit tests once the root cause found. -@unittest.skip('https://github.com/grpc/grpc/issues/23982') -@unittest.skip('https://github.com/grpc/grpc/issues/23028') +@unittest.skip("https://github.com/grpc/grpc/issues/23982") +@unittest.skip("https://github.com/grpc/grpc/issues/23028") class ExitTest(unittest.TestCase): - def test_unstarted_server(self): - process = subprocess.Popen(BASE_COMMAND + - [_exit_scenarios.UNSTARTED_SERVER], - stdout=sys.stdout, - stderr=sys.stderr) + process = subprocess.Popen( + BASE_COMMAND + [_exit_scenarios.UNSTARTED_SERVER], + stdout=sys.stdout, + stderr=sys.stderr, + ) wait(process) def test_unstarted_server_terminate(self): - process = subprocess.Popen(BASE_SIGTERM_COMMAND + - [_exit_scenarios.UNSTARTED_SERVER], - stdout=sys.stdout) + process = subprocess.Popen( + BASE_SIGTERM_COMMAND + [_exit_scenarios.UNSTARTED_SERVER], + stdout=sys.stdout, + ) interrupt_and_wait(process) def test_running_server(self): - process = subprocess.Popen(BASE_COMMAND + - [_exit_scenarios.RUNNING_SERVER], - stdout=sys.stdout, - stderr=sys.stderr) + process = subprocess.Popen( + BASE_COMMAND + [_exit_scenarios.RUNNING_SERVER], + stdout=sys.stdout, + stderr=sys.stderr, + ) wait(process) def test_running_server_terminate(self): - process = subprocess.Popen(BASE_SIGTERM_COMMAND + - [_exit_scenarios.RUNNING_SERVER], - stdout=sys.stdout, - stderr=sys.stderr) + process = subprocess.Popen( + BASE_SIGTERM_COMMAND + [_exit_scenarios.RUNNING_SERVER], + stdout=sys.stdout, + stderr=sys.stderr, + ) interrupt_and_wait(process) def test_poll_connectivity_no_server(self): process = subprocess.Popen( BASE_COMMAND + [_exit_scenarios.POLL_CONNECTIVITY_NO_SERVER], stdout=sys.stdout, - stderr=sys.stderr) + stderr=sys.stderr, + ) wait(process) def test_poll_connectivity_no_server_terminate(self): process = 
subprocess.Popen( - BASE_SIGTERM_COMMAND + - [_exit_scenarios.POLL_CONNECTIVITY_NO_SERVER], + BASE_SIGTERM_COMMAND + + [_exit_scenarios.POLL_CONNECTIVITY_NO_SERVER], stdout=sys.stdout, - stderr=sys.stderr) + stderr=sys.stderr, + ) interrupt_and_wait(process) def test_poll_connectivity(self): - process = subprocess.Popen(BASE_COMMAND + - [_exit_scenarios.POLL_CONNECTIVITY], - stdout=sys.stdout, - stderr=sys.stderr) + process = subprocess.Popen( + BASE_COMMAND + [_exit_scenarios.POLL_CONNECTIVITY], + stdout=sys.stdout, + stderr=sys.stderr, + ) wait(process) def test_poll_connectivity_terminate(self): - process = subprocess.Popen(BASE_SIGTERM_COMMAND + - [_exit_scenarios.POLL_CONNECTIVITY], - stdout=sys.stdout, - stderr=sys.stderr) + process = subprocess.Popen( + BASE_SIGTERM_COMMAND + [_exit_scenarios.POLL_CONNECTIVITY], + stdout=sys.stdout, + stderr=sys.stderr, + ) interrupt_and_wait(process) - @unittest.skipIf(os.name == 'nt', - 'os.kill does not have required permission on Windows') + @unittest.skipIf( + os.name == "nt", "os.kill does not have required permission on Windows" + ) def test_in_flight_unary_unary_call(self): - process = subprocess.Popen(BASE_COMMAND + - [_exit_scenarios.IN_FLIGHT_UNARY_UNARY_CALL], - stdout=sys.stdout, - stderr=sys.stderr) + process = subprocess.Popen( + BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_UNARY_UNARY_CALL], + stdout=sys.stdout, + stderr=sys.stderr, + ) interrupt_and_wait(process) - @unittest.skipIf(os.name == 'nt', - 'os.kill does not have required permission on Windows') + @unittest.skipIf( + os.name == "nt", "os.kill does not have required permission on Windows" + ) def test_in_flight_unary_stream_call(self): process = subprocess.Popen( BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_UNARY_STREAM_CALL], stdout=sys.stdout, - stderr=sys.stderr) + stderr=sys.stderr, + ) interrupt_and_wait(process) - @unittest.skipIf(os.name == 'nt', - 'os.kill does not have required permission on Windows') + @unittest.skipIf( + os.name == "nt", "os.kill does not have required permission on Windows" + ) def test_in_flight_stream_unary_call(self): process = subprocess.Popen( BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_STREAM_UNARY_CALL], stdout=sys.stdout, - stderr=sys.stderr) + stderr=sys.stderr, + ) interrupt_and_wait(process) - @unittest.skipIf(os.name == 'nt', - 'os.kill does not have required permission on Windows') + @unittest.skipIf( + os.name == "nt", "os.kill does not have required permission on Windows" + ) def test_in_flight_stream_stream_call(self): process = subprocess.Popen( BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_STREAM_STREAM_CALL], stdout=sys.stdout, - stderr=sys.stderr) + stderr=sys.stderr, + ) interrupt_and_wait(process) - @unittest.skipIf(os.name == 'nt', - 'os.kill does not have required permission on Windows') + @unittest.skipIf( + os.name == "nt", "os.kill does not have required permission on Windows" + ) def test_in_flight_partial_unary_stream_call(self): process = subprocess.Popen( - BASE_COMMAND + - [_exit_scenarios.IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL], + BASE_COMMAND + + [_exit_scenarios.IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL], stdout=sys.stdout, - stderr=sys.stderr) + stderr=sys.stderr, + ) interrupt_and_wait(process) - @unittest.skipIf(os.name == 'nt', - 'os.kill does not have required permission on Windows') + @unittest.skipIf( + os.name == "nt", "os.kill does not have required permission on Windows" + ) def test_in_flight_partial_stream_unary_call(self): process = subprocess.Popen( - BASE_COMMAND + - 
[_exit_scenarios.IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL], + BASE_COMMAND + + [_exit_scenarios.IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL], stdout=sys.stdout, - stderr=sys.stderr) + stderr=sys.stderr, + ) interrupt_and_wait(process) - @unittest.skipIf(os.name == 'nt', - 'os.kill does not have required permission on Windows') + @unittest.skipIf( + os.name == "nt", "os.kill does not have required permission on Windows" + ) def test_in_flight_partial_stream_stream_call(self): process = subprocess.Popen( - BASE_COMMAND + - [_exit_scenarios.IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL], + BASE_COMMAND + + [_exit_scenarios.IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL], stdout=sys.stdout, - stderr=sys.stderr) + stderr=sys.stderr, + ) interrupt_and_wait(process) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_from_grpc_import_star.py b/src/python/grpcio_tests/tests/unit/_from_grpc_import_star.py index 1ada25382de5b..b5f8427cf4eaf 100644 --- a/src/python/grpcio_tests/tests/unit/_from_grpc_import_star.py +++ b/src/python/grpcio_tests/tests/unit/_from_grpc_import_star.py @@ -19,5 +19,7 @@ _AFTER_IMPORT = tuple(globals()) GRPC_ELEMENTS = tuple( - element for element in _AFTER_IMPORT - if element not in _BEFORE_IMPORT and element != '_BEFORE_IMPORT') + element + for element in _AFTER_IMPORT + if element not in _BEFORE_IMPORT and element != "_BEFORE_IMPORT" +) diff --git a/src/python/grpcio_tests/tests/unit/_grpc_shutdown_test.py b/src/python/grpcio_tests/tests/unit/_grpc_shutdown_test.py index b1f43e061b0a3..75716d32a91ab 100644 --- a/src/python/grpcio_tests/tests/unit/_grpc_shutdown_test.py +++ b/src/python/grpcio_tests/tests/unit/_grpc_shutdown_test.py @@ -24,7 +24,6 @@ class GrpcShutdownTest(unittest.TestCase): - def test_channel_close_with_connectivity_watcher(self): """Originated by https://github.com/grpc/grpc/issues/20299. 
@@ -34,8 +33,10 @@ def test_channel_close_with_connectivity_watcher(self): connection_failed = threading.Event() def on_state_change(state): - if state in (grpc.ChannelConnectivity.TRANSIENT_FAILURE, - grpc.ChannelConnectivity.SHUTDOWN): + if state in ( + grpc.ChannelConnectivity.TRANSIENT_FAILURE, + grpc.ChannelConnectivity.SHUTDOWN, + ): connection_failed.set() # Connects to an void address, and subscribes state changes @@ -50,5 +51,5 @@ def on_state_change(state): channel.close() -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_interceptor_test.py b/src/python/grpcio_tests/tests/unit/_interceptor_test.py index d8f3c90f415ab..60119da4d74d4 100644 --- a/src/python/grpcio_tests/tests/unit/_interceptor_test.py +++ b/src/python/grpcio_tests/tests/unit/_interceptor_test.py @@ -29,16 +29,16 @@ from tests.unit.framework.common import test_control _SERIALIZE_REQUEST = lambda bytestring: bytestring * 2 -_DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) // 2:] +_DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) // 2 :] _SERIALIZE_RESPONSE = lambda bytestring: bytestring * 3 -_DESERIALIZE_RESPONSE = lambda bytestring: bytestring[:len(bytestring) // 3] +_DESERIALIZE_RESPONSE = lambda bytestring: bytestring[: len(bytestring) // 3] -_EXCEPTION_REQUEST = b'\x09\x0a' +_EXCEPTION_REQUEST = b"\x09\x0a" -_UNARY_UNARY = '/test/UnaryUnary' -_UNARY_STREAM = '/test/UnaryStream' -_STREAM_UNARY = '/test/StreamUnary' -_STREAM_STREAM = '/test/StreamStream' +_UNARY_UNARY = "/test/UnaryUnary" +_UNARY_STREAM = "/test/UnaryStream" +_STREAM_UNARY = "/test/StreamUnary" +_STREAM_STREAM = "/test/StreamStream" class _ApplicationErrorStandin(Exception): @@ -46,7 +46,6 @@ class _ApplicationErrorStandin(Exception): class _Callback(object): - def __init__(self): self._condition = threading.Condition() self._value = None @@ -66,17 +65,20 @@ def value(self): class _Handler(object): - def __init__(self, control): self._control = control def handle_unary_unary(self, request, servicer_context): self._control.control() if servicer_context is not None: - servicer_context.set_trailing_metadata((( - 'testkey', - 'testvalue', - ),)) + servicer_context.set_trailing_metadata( + ( + ( + "testkey", + "testvalue", + ), + ) + ) if request == _EXCEPTION_REQUEST: raise _ApplicationErrorStandin() return request @@ -89,10 +91,14 @@ def handle_unary_stream(self, request, servicer_context): yield request self._control.control() if servicer_context is not None: - servicer_context.set_trailing_metadata((( - 'testkey', - 'testvalue', - ),)) + servicer_context.set_trailing_metadata( + ( + ( + "testkey", + "testvalue", + ), + ) + ) def handle_stream_unary(self, request_iterator, servicer_context): if servicer_context is not None: @@ -104,21 +110,29 @@ def handle_stream_unary(self, request_iterator, servicer_context): response_elements.append(request) self._control.control() if servicer_context is not None: - servicer_context.set_trailing_metadata((( - 'testkey', - 'testvalue', - ),)) + servicer_context.set_trailing_metadata( + ( + ( + "testkey", + "testvalue", + ), + ) + ) if _EXCEPTION_REQUEST in response_elements: raise _ApplicationErrorStandin() - return b''.join(response_elements) + return b"".join(response_elements) def handle_stream_stream(self, request_iterator, servicer_context): self._control.control() if servicer_context is not None: - servicer_context.set_trailing_metadata((( - 'testkey', - 'testvalue', - ),)) + 
servicer_context.set_trailing_metadata( + ( + ( + "testkey", + "testvalue", + ), + ) + ) for request in request_iterator: if request == _EXCEPTION_REQUEST: raise _ApplicationErrorStandin() @@ -128,10 +142,17 @@ def handle_stream_stream(self, request_iterator, servicer_context): class _MethodHandler(grpc.RpcMethodHandler): - - def __init__(self, request_streaming, response_streaming, - request_deserializer, response_serializer, unary_unary, - unary_stream, stream_unary, stream_stream): + def __init__( + self, + request_streaming, + response_streaming, + request_deserializer, + response_serializer, + unary_unary, + unary_stream, + stream_unary, + stream_stream, + ): self.request_streaming = request_streaming self.response_streaming = response_streaming self.request_deserializer = request_deserializer @@ -143,26 +164,54 @@ def __init__(self, request_streaming, response_streaming, class _GenericHandler(grpc.GenericRpcHandler): - def __init__(self, handler): self._handler = handler def service(self, handler_call_details): if handler_call_details.method == _UNARY_UNARY: - return _MethodHandler(False, False, None, None, - self._handler.handle_unary_unary, None, None, - None) + return _MethodHandler( + False, + False, + None, + None, + self._handler.handle_unary_unary, + None, + None, + None, + ) elif handler_call_details.method == _UNARY_STREAM: - return _MethodHandler(False, True, _DESERIALIZE_REQUEST, - _SERIALIZE_RESPONSE, None, - self._handler.handle_unary_stream, None, None) + return _MethodHandler( + False, + True, + _DESERIALIZE_REQUEST, + _SERIALIZE_RESPONSE, + None, + self._handler.handle_unary_stream, + None, + None, + ) elif handler_call_details.method == _STREAM_UNARY: - return _MethodHandler(True, False, _DESERIALIZE_REQUEST, - _SERIALIZE_RESPONSE, None, None, - self._handler.handle_stream_unary, None) + return _MethodHandler( + True, + False, + _DESERIALIZE_REQUEST, + _SERIALIZE_RESPONSE, + None, + None, + self._handler.handle_stream_unary, + None, + ) elif handler_call_details.method == _STREAM_STREAM: - return _MethodHandler(True, True, None, None, None, None, None, - self._handler.handle_stream_stream) + return _MethodHandler( + True, + True, + None, + None, + None, + None, + None, + self._handler.handle_stream_stream, + ) else: return None @@ -172,15 +221,19 @@ def _unary_unary_multi_callable(channel): def _unary_stream_multi_callable(channel): - return channel.unary_stream(_UNARY_STREAM, - request_serializer=_SERIALIZE_REQUEST, - response_deserializer=_DESERIALIZE_RESPONSE) + return channel.unary_stream( + _UNARY_STREAM, + request_serializer=_SERIALIZE_REQUEST, + response_deserializer=_DESERIALIZE_RESPONSE, + ) def _stream_unary_multi_callable(channel): - return channel.stream_unary(_STREAM_UNARY, - request_serializer=_SERIALIZE_REQUEST, - response_deserializer=_DESERIALIZE_RESPONSE) + return channel.stream_unary( + _STREAM_UNARY, + request_serializer=_SERIALIZE_REQUEST, + response_deserializer=_DESERIALIZE_RESPONSE, + ) def _stream_stream_multi_callable(channel): @@ -188,110 +241,128 @@ def _stream_stream_multi_callable(channel): class _ClientCallDetails( - collections.namedtuple( - '_ClientCallDetails', - ('method', 'timeout', 'metadata', 'credentials')), - grpc.ClientCallDetails): + collections.namedtuple( + "_ClientCallDetails", ("method", "timeout", "metadata", "credentials") + ), + grpc.ClientCallDetails, +): pass -class _GenericClientInterceptor(grpc.UnaryUnaryClientInterceptor, - grpc.UnaryStreamClientInterceptor, - grpc.StreamUnaryClientInterceptor, - 
grpc.StreamStreamClientInterceptor): - +class _GenericClientInterceptor( + grpc.UnaryUnaryClientInterceptor, + grpc.UnaryStreamClientInterceptor, + grpc.StreamUnaryClientInterceptor, + grpc.StreamStreamClientInterceptor, +): def __init__(self, interceptor_function): self._fn = interceptor_function def intercept_unary_unary(self, continuation, client_call_details, request): new_details, new_request_iterator, postprocess = self._fn( - client_call_details, iter((request,)), False, False) + client_call_details, iter((request,)), False, False + ) response = continuation(new_details, next(new_request_iterator)) return postprocess(response) if postprocess else response - def intercept_unary_stream(self, continuation, client_call_details, - request): + def intercept_unary_stream( + self, continuation, client_call_details, request + ): new_details, new_request_iterator, postprocess = self._fn( - client_call_details, iter((request,)), False, True) + client_call_details, iter((request,)), False, True + ) response_it = continuation(new_details, new_request_iterator) return postprocess(response_it) if postprocess else response_it - def intercept_stream_unary(self, continuation, client_call_details, - request_iterator): + def intercept_stream_unary( + self, continuation, client_call_details, request_iterator + ): new_details, new_request_iterator, postprocess = self._fn( - client_call_details, request_iterator, True, False) + client_call_details, request_iterator, True, False + ) response = continuation(new_details, next(new_request_iterator)) return postprocess(response) if postprocess else response - def intercept_stream_stream(self, continuation, client_call_details, - request_iterator): + def intercept_stream_stream( + self, continuation, client_call_details, request_iterator + ): new_details, new_request_iterator, postprocess = self._fn( - client_call_details, request_iterator, True, True) + client_call_details, request_iterator, True, True + ) response_it = continuation(new_details, new_request_iterator) return postprocess(response_it) if postprocess else response_it -class _LoggingInterceptor(grpc.ServerInterceptor, - grpc.UnaryUnaryClientInterceptor, - grpc.UnaryStreamClientInterceptor, - grpc.StreamUnaryClientInterceptor, - grpc.StreamStreamClientInterceptor): - +class _LoggingInterceptor( + grpc.ServerInterceptor, + grpc.UnaryUnaryClientInterceptor, + grpc.UnaryStreamClientInterceptor, + grpc.StreamUnaryClientInterceptor, + grpc.StreamStreamClientInterceptor, +): def __init__(self, tag, record): self.tag = tag self.record = record def intercept_service(self, continuation, handler_call_details): - self.record.append(self.tag + ':intercept_service') + self.record.append(self.tag + ":intercept_service") return continuation(handler_call_details) def intercept_unary_unary(self, continuation, client_call_details, request): - self.record.append(self.tag + ':intercept_unary_unary') + self.record.append(self.tag + ":intercept_unary_unary") result = continuation(client_call_details, request) assert isinstance( - result, - grpc.Call), '{} ({}) is not an instance of grpc.Call'.format( - result, type(result)) + result, grpc.Call + ), "{} ({}) is not an instance of grpc.Call".format( + result, type(result) + ) assert isinstance( - result, - grpc.Future), '{} ({}) is not an instance of grpc.Future'.format( - result, type(result)) + result, grpc.Future + ), "{} ({}) is not an instance of grpc.Future".format( + result, type(result) + ) return result - def intercept_unary_stream(self, continuation, 
client_call_details, - request): - self.record.append(self.tag + ':intercept_unary_stream') + def intercept_unary_stream( + self, continuation, client_call_details, request + ): + self.record.append(self.tag + ":intercept_unary_stream") return continuation(client_call_details, request) - def intercept_stream_unary(self, continuation, client_call_details, - request_iterator): - self.record.append(self.tag + ':intercept_stream_unary') + def intercept_stream_unary( + self, continuation, client_call_details, request_iterator + ): + self.record.append(self.tag + ":intercept_stream_unary") result = continuation(client_call_details, request_iterator) assert isinstance( - result, - grpc.Call), '{} is not an instance of grpc.Call'.format(result) + result, grpc.Call + ), "{} is not an instance of grpc.Call".format(result) assert isinstance( - result, - grpc.Future), '{} is not an instance of grpc.Future'.format(result) + result, grpc.Future + ), "{} is not an instance of grpc.Future".format(result) return result - def intercept_stream_stream(self, continuation, client_call_details, - request_iterator): - self.record.append(self.tag + ':intercept_stream_stream') + def intercept_stream_stream( + self, continuation, client_call_details, request_iterator + ): + self.record.append(self.tag + ":intercept_stream_stream") return continuation(client_call_details, request_iterator) class _DefectiveClientInterceptor(grpc.UnaryUnaryClientInterceptor): - - def intercept_unary_unary(self, ignored_continuation, - ignored_client_call_details, ignored_request): + def intercept_unary_unary( + self, ignored_continuation, ignored_client_call_details, ignored_request + ): raise test_control.Defect() def _wrap_request_iterator_stream_interceptor(wrapper): - - def intercept_call(client_call_details, request_iterator, request_streaming, - ignored_response_streaming): + def intercept_call( + client_call_details, + request_iterator, + request_streaming, + ignored_response_streaming, + ): if request_streaming: return client_call_details, wrapper(request_iterator), None else: @@ -301,26 +372,33 @@ def intercept_call(client_call_details, request_iterator, request_streaming, def _append_request_header_interceptor(header, value): - - def intercept_call(client_call_details, request_iterator, - ignored_request_streaming, ignored_response_streaming): + def intercept_call( + client_call_details, + request_iterator, + ignored_request_streaming, + ignored_response_streaming, + ): metadata = [] if client_call_details.metadata: metadata = list(client_call_details.metadata) - metadata.append(( - header, - value, - )) + metadata.append( + ( + header, + value, + ) + ) client_call_details = _ClientCallDetails( - client_call_details.method, client_call_details.timeout, metadata, - client_call_details.credentials) + client_call_details.method, + client_call_details.timeout, + metadata, + client_call_details.credentials, + ) return client_call_details, request_iterator, None return _GenericClientInterceptor(intercept_call) class _GenericServerInterceptor(grpc.ServerInterceptor): - def __init__(self, fn): self._fn = fn @@ -329,18 +407,17 @@ def intercept_service(self, continuation, handler_call_details): def _filter_server_interceptor(condition, interceptor): - def intercept_service(continuation, handler_call_details): if condition(handler_call_details): - return interceptor.intercept_service(continuation, - handler_call_details) + return interceptor.intercept_service( + continuation, handler_call_details + ) return 
continuation(handler_call_details) return _GenericServerInterceptor(intercept_service) class InterceptorTest(unittest.TestCase): - def setUp(self): self._control = test_control.PauseFailControl() self._handler = _Handler(self._control) @@ -348,21 +425,24 @@ def setUp(self): self._record = [] conditional_interceptor = _filter_server_interceptor( - lambda x: ('secret', '42') in x.invocation_metadata, - _LoggingInterceptor('s3', self._record)) - - self._server = grpc.server(self._server_pool, - options=(('grpc.so_reuseport', 0),), - interceptors=( - _LoggingInterceptor('s1', self._record), - conditional_interceptor, - _LoggingInterceptor('s2', self._record), - )) - port = self._server.add_insecure_port('[::]:0') + lambda x: ("secret", "42") in x.invocation_metadata, + _LoggingInterceptor("s3", self._record), + ) + + self._server = grpc.server( + self._server_pool, + options=(("grpc.so_reuseport", 0),), + interceptors=( + _LoggingInterceptor("s1", self._record), + conditional_interceptor, + _LoggingInterceptor("s2", self._record), + ), + ) + port = self._server.add_insecure_port("[::]:0") self._server.add_generic_rpc_handlers((_GenericHandler(self._handler),)) self._server.start() - self._channel = grpc.insecure_channel('localhost:%d' % port) + self._channel = grpc.insecure_channel("localhost:%d" % port) def tearDown(self): self._server.stop(None) @@ -370,7 +450,6 @@ def tearDown(self): self._channel.close() def testTripleRequestMessagesClientInterceptor(self): - def triple(request_iterator): while True: try: @@ -384,14 +463,19 @@ def triple(request_iterator): interceptor = _wrap_request_iterator_stream_interceptor(triple) channel = grpc.intercept_channel(self._channel, interceptor) requests = tuple( - b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + b"\x07\x08" for _ in range(test_constants.STREAM_LENGTH) + ) multi_callable = _stream_stream_multi_callable(channel) response_iterator = multi_callable( iter(requests), metadata=( - ('test', - 'InterceptedStreamRequestBlockingUnaryResponseWithCall'),)) + ( + "test", + "InterceptedStreamRequestBlockingUnaryResponseWithCall", + ), + ), + ) responses = tuple(response_iterator) self.assertEqual(len(responses), 3 * test_constants.STREAM_LENGTH) @@ -400,8 +484,12 @@ def triple(request_iterator): response_iterator = multi_callable( iter(requests), metadata=( - ('test', - 'InterceptedStreamRequestBlockingUnaryResponseWithCall'),)) + ( + "test", + "InterceptedStreamRequestBlockingUnaryResponseWithCall", + ), + ), + ) responses = tuple(response_iterator) self.assertEqual(len(responses), test_constants.STREAM_LENGTH) @@ -410,25 +498,30 @@ def testDefectiveClientInterceptor(self): interceptor = _DefectiveClientInterceptor() defective_channel = grpc.intercept_channel(self._channel, interceptor) - request = b'\x07\x08' + request = b"\x07\x08" multi_callable = _unary_unary_multi_callable(defective_channel) call_future = multi_callable.future( request, - metadata=(('test', - 'InterceptedUnaryRequestBlockingUnaryResponse'),)) + metadata=( + ("test", "InterceptedUnaryRequestBlockingUnaryResponse"), + ), + ) self.assertIsNotNone(call_future.exception()) self.assertEqual(call_future.code(), grpc.StatusCode.INTERNAL) def testInterceptedHeaderManipulationWithServerSideVerification(self): - request = b'\x07\x08' + request = b"\x07\x08" channel = grpc.intercept_channel( - self._channel, _append_request_header_interceptor('secret', '42')) + self._channel, _append_request_header_interceptor("secret", "42") + ) channel = grpc.intercept_channel( - channel, 
_LoggingInterceptor('c1', self._record), - _LoggingInterceptor('c2', self._record)) + channel, + _LoggingInterceptor("c1", self._record), + _LoggingInterceptor("c2", self._record), + ) self._record[:] = [] @@ -436,34 +529,52 @@ def testInterceptedHeaderManipulationWithServerSideVerification(self): multi_callable.with_call( request, metadata=( - ('test', - 'InterceptedUnaryRequestBlockingUnaryResponseWithCall'),)) - - self.assertSequenceEqual(self._record, [ - 'c1:intercept_unary_unary', 'c2:intercept_unary_unary', - 's1:intercept_service', 's3:intercept_service', - 's2:intercept_service' - ]) + ( + "test", + "InterceptedUnaryRequestBlockingUnaryResponseWithCall", + ), + ), + ) + + self.assertSequenceEqual( + self._record, + [ + "c1:intercept_unary_unary", + "c2:intercept_unary_unary", + "s1:intercept_service", + "s3:intercept_service", + "s2:intercept_service", + ], + ) def testInterceptedUnaryRequestBlockingUnaryResponse(self): - request = b'\x07\x08' + request = b"\x07\x08" self._record[:] = [] channel = grpc.intercept_channel( - self._channel, _LoggingInterceptor('c1', self._record), - _LoggingInterceptor('c2', self._record)) + self._channel, + _LoggingInterceptor("c1", self._record), + _LoggingInterceptor("c2", self._record), + ) multi_callable = _unary_unary_multi_callable(channel) multi_callable( request, - metadata=(('test', - 'InterceptedUnaryRequestBlockingUnaryResponse'),)) - - self.assertSequenceEqual(self._record, [ - 'c1:intercept_unary_unary', 'c2:intercept_unary_unary', - 's1:intercept_service', 's2:intercept_service' - ]) + metadata=( + ("test", "InterceptedUnaryRequestBlockingUnaryResponse"), + ), + ) + + self.assertSequenceEqual( + self._record, + [ + "c1:intercept_unary_unary", + "c2:intercept_unary_unary", + "s1:intercept_service", + "s2:intercept_service", + ], + ) def testInterceptedUnaryRequestBlockingUnaryResponseWithError(self): request = _EXCEPTION_REQUEST @@ -471,15 +582,19 @@ def testInterceptedUnaryRequestBlockingUnaryResponseWithError(self): self._record[:] = [] channel = grpc.intercept_channel( - self._channel, _LoggingInterceptor('c1', self._record), - _LoggingInterceptor('c2', self._record)) + self._channel, + _LoggingInterceptor("c1", self._record), + _LoggingInterceptor("c2", self._record), + ) multi_callable = _unary_unary_multi_callable(channel) with self.assertRaises(grpc.RpcError) as exception_context: multi_callable( request, - metadata=(('test', - 'InterceptedUnaryRequestBlockingUnaryResponse'),)) + metadata=( + ("test", "InterceptedUnaryRequestBlockingUnaryResponse"), + ), + ) exception = exception_context.exception self.assertFalse(exception.cancelled()) self.assertFalse(exception.running()) @@ -489,11 +604,13 @@ def testInterceptedUnaryRequestBlockingUnaryResponseWithError(self): self.assertIsInstance(exception.exception(), grpc.RpcError) def testInterceptedUnaryRequestBlockingUnaryResponseWithCall(self): - request = b'\x07\x08' + request = b"\x07\x08" channel = grpc.intercept_channel( - self._channel, _LoggingInterceptor('c1', self._record), - _LoggingInterceptor('c2', self._record)) + self._channel, + _LoggingInterceptor("c1", self._record), + _LoggingInterceptor("c2", self._record), + ) self._record[:] = [] @@ -501,64 +618,92 @@ def testInterceptedUnaryRequestBlockingUnaryResponseWithCall(self): multi_callable.with_call( request, metadata=( - ('test', - 'InterceptedUnaryRequestBlockingUnaryResponseWithCall'),)) - - self.assertSequenceEqual(self._record, [ - 'c1:intercept_unary_unary', 'c2:intercept_unary_unary', - 's1:intercept_service', 
's2:intercept_service' - ]) + ( + "test", + "InterceptedUnaryRequestBlockingUnaryResponseWithCall", + ), + ), + ) + + self.assertSequenceEqual( + self._record, + [ + "c1:intercept_unary_unary", + "c2:intercept_unary_unary", + "s1:intercept_service", + "s2:intercept_service", + ], + ) def testInterceptedUnaryRequestFutureUnaryResponse(self): - request = b'\x07\x08' + request = b"\x07\x08" self._record[:] = [] channel = grpc.intercept_channel( - self._channel, _LoggingInterceptor('c1', self._record), - _LoggingInterceptor('c2', self._record)) + self._channel, + _LoggingInterceptor("c1", self._record), + _LoggingInterceptor("c2", self._record), + ) multi_callable = _unary_unary_multi_callable(channel) response_future = multi_callable.future( request, - metadata=(('test', 'InterceptedUnaryRequestFutureUnaryResponse'),)) + metadata=(("test", "InterceptedUnaryRequestFutureUnaryResponse"),), + ) response_future.result() - self.assertSequenceEqual(self._record, [ - 'c1:intercept_unary_unary', 'c2:intercept_unary_unary', - 's1:intercept_service', 's2:intercept_service' - ]) + self.assertSequenceEqual( + self._record, + [ + "c1:intercept_unary_unary", + "c2:intercept_unary_unary", + "s1:intercept_service", + "s2:intercept_service", + ], + ) def testInterceptedUnaryRequestStreamResponse(self): - request = b'\x37\x58' + request = b"\x37\x58" self._record[:] = [] channel = grpc.intercept_channel( - self._channel, _LoggingInterceptor('c1', self._record), - _LoggingInterceptor('c2', self._record)) + self._channel, + _LoggingInterceptor("c1", self._record), + _LoggingInterceptor("c2", self._record), + ) multi_callable = _unary_stream_multi_callable(channel) response_iterator = multi_callable( request, - metadata=(('test', 'InterceptedUnaryRequestStreamResponse'),)) + metadata=(("test", "InterceptedUnaryRequestStreamResponse"),), + ) tuple(response_iterator) - self.assertSequenceEqual(self._record, [ - 'c1:intercept_unary_stream', 'c2:intercept_unary_stream', - 's1:intercept_service', 's2:intercept_service' - ]) + self.assertSequenceEqual( + self._record, + [ + "c1:intercept_unary_stream", + "c2:intercept_unary_stream", + "s1:intercept_service", + "s2:intercept_service", + ], + ) def testInterceptedUnaryRequestStreamResponseWithError(self): request = _EXCEPTION_REQUEST self._record[:] = [] channel = grpc.intercept_channel( - self._channel, _LoggingInterceptor('c1', self._record), - _LoggingInterceptor('c2', self._record)) + self._channel, + _LoggingInterceptor("c1", self._record), + _LoggingInterceptor("c2", self._record), + ) multi_callable = _unary_stream_multi_callable(channel) response_iterator = multi_callable( request, - metadata=(('test', 'InterceptedUnaryRequestStreamResponse'),)) + metadata=(("test", "InterceptedUnaryRequestStreamResponse"),), + ) with self.assertRaises(grpc.RpcError) as exception_context: tuple(response_iterator) exception = exception_context.exception @@ -571,82 +716,117 @@ def testInterceptedUnaryRequestStreamResponseWithError(self): def testInterceptedStreamRequestBlockingUnaryResponse(self): requests = tuple( - b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + b"\x07\x08" for _ in range(test_constants.STREAM_LENGTH) + ) request_iterator = iter(requests) self._record[:] = [] channel = grpc.intercept_channel( - self._channel, _LoggingInterceptor('c1', self._record), - _LoggingInterceptor('c2', self._record)) + self._channel, + _LoggingInterceptor("c1", self._record), + _LoggingInterceptor("c2", self._record), + ) multi_callable = _stream_unary_multi_callable(channel) 
multi_callable( request_iterator, - metadata=(('test', - 'InterceptedStreamRequestBlockingUnaryResponse'),)) - - self.assertSequenceEqual(self._record, [ - 'c1:intercept_stream_unary', 'c2:intercept_stream_unary', - 's1:intercept_service', 's2:intercept_service' - ]) + metadata=( + ("test", "InterceptedStreamRequestBlockingUnaryResponse"), + ), + ) + + self.assertSequenceEqual( + self._record, + [ + "c1:intercept_stream_unary", + "c2:intercept_stream_unary", + "s1:intercept_service", + "s2:intercept_service", + ], + ) def testInterceptedStreamRequestBlockingUnaryResponseWithCall(self): requests = tuple( - b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + b"\x07\x08" for _ in range(test_constants.STREAM_LENGTH) + ) request_iterator = iter(requests) self._record[:] = [] channel = grpc.intercept_channel( - self._channel, _LoggingInterceptor('c1', self._record), - _LoggingInterceptor('c2', self._record)) + self._channel, + _LoggingInterceptor("c1", self._record), + _LoggingInterceptor("c2", self._record), + ) multi_callable = _stream_unary_multi_callable(channel) multi_callable.with_call( request_iterator, metadata=( - ('test', - 'InterceptedStreamRequestBlockingUnaryResponseWithCall'),)) - - self.assertSequenceEqual(self._record, [ - 'c1:intercept_stream_unary', 'c2:intercept_stream_unary', - 's1:intercept_service', 's2:intercept_service' - ]) + ( + "test", + "InterceptedStreamRequestBlockingUnaryResponseWithCall", + ), + ), + ) + + self.assertSequenceEqual( + self._record, + [ + "c1:intercept_stream_unary", + "c2:intercept_stream_unary", + "s1:intercept_service", + "s2:intercept_service", + ], + ) def testInterceptedStreamRequestFutureUnaryResponse(self): requests = tuple( - b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + b"\x07\x08" for _ in range(test_constants.STREAM_LENGTH) + ) request_iterator = iter(requests) self._record[:] = [] channel = grpc.intercept_channel( - self._channel, _LoggingInterceptor('c1', self._record), - _LoggingInterceptor('c2', self._record)) + self._channel, + _LoggingInterceptor("c1", self._record), + _LoggingInterceptor("c2", self._record), + ) multi_callable = _stream_unary_multi_callable(channel) response_future = multi_callable.future( request_iterator, - metadata=(('test', 'InterceptedStreamRequestFutureUnaryResponse'),)) + metadata=(("test", "InterceptedStreamRequestFutureUnaryResponse"),), + ) response_future.result() - self.assertSequenceEqual(self._record, [ - 'c1:intercept_stream_unary', 'c2:intercept_stream_unary', - 's1:intercept_service', 's2:intercept_service' - ]) + self.assertSequenceEqual( + self._record, + [ + "c1:intercept_stream_unary", + "c2:intercept_stream_unary", + "s1:intercept_service", + "s2:intercept_service", + ], + ) def testInterceptedStreamRequestFutureUnaryResponseWithError(self): requests = tuple( - _EXCEPTION_REQUEST for _ in range(test_constants.STREAM_LENGTH)) + _EXCEPTION_REQUEST for _ in range(test_constants.STREAM_LENGTH) + ) request_iterator = iter(requests) self._record[:] = [] channel = grpc.intercept_channel( - self._channel, _LoggingInterceptor('c1', self._record), - _LoggingInterceptor('c2', self._record)) + self._channel, + _LoggingInterceptor("c1", self._record), + _LoggingInterceptor("c2", self._record), + ) multi_callable = _stream_unary_multi_callable(channel) response_future = multi_callable.future( request_iterator, - metadata=(('test', 'InterceptedStreamRequestFutureUnaryResponse'),)) + metadata=(("test", "InterceptedStreamRequestFutureUnaryResponse"),), + ) with 
self.assertRaises(grpc.RpcError) as exception_context: response_future.result() exception = exception_context.exception @@ -659,39 +839,52 @@ def testInterceptedStreamRequestFutureUnaryResponseWithError(self): def testInterceptedStreamRequestStreamResponse(self): requests = tuple( - b'\x77\x58' for _ in range(test_constants.STREAM_LENGTH)) + b"\x77\x58" for _ in range(test_constants.STREAM_LENGTH) + ) request_iterator = iter(requests) self._record[:] = [] channel = grpc.intercept_channel( - self._channel, _LoggingInterceptor('c1', self._record), - _LoggingInterceptor('c2', self._record)) + self._channel, + _LoggingInterceptor("c1", self._record), + _LoggingInterceptor("c2", self._record), + ) multi_callable = _stream_stream_multi_callable(channel) response_iterator = multi_callable( request_iterator, - metadata=(('test', 'InterceptedStreamRequestStreamResponse'),)) + metadata=(("test", "InterceptedStreamRequestStreamResponse"),), + ) tuple(response_iterator) - self.assertSequenceEqual(self._record, [ - 'c1:intercept_stream_stream', 'c2:intercept_stream_stream', - 's1:intercept_service', 's2:intercept_service' - ]) + self.assertSequenceEqual( + self._record, + [ + "c1:intercept_stream_stream", + "c2:intercept_stream_stream", + "s1:intercept_service", + "s2:intercept_service", + ], + ) def testInterceptedStreamRequestStreamResponseWithError(self): requests = tuple( - _EXCEPTION_REQUEST for _ in range(test_constants.STREAM_LENGTH)) + _EXCEPTION_REQUEST for _ in range(test_constants.STREAM_LENGTH) + ) request_iterator = iter(requests) self._record[:] = [] channel = grpc.intercept_channel( - self._channel, _LoggingInterceptor('c1', self._record), - _LoggingInterceptor('c2', self._record)) + self._channel, + _LoggingInterceptor("c1", self._record), + _LoggingInterceptor("c2", self._record), + ) multi_callable = _stream_stream_multi_callable(channel) response_iterator = multi_callable( request_iterator, - metadata=(('test', 'InterceptedStreamRequestStreamResponse'),)) + metadata=(("test", "InterceptedStreamRequestStreamResponse"),), + ) with self.assertRaises(grpc.RpcError) as exception_context: tuple(response_iterator) exception = exception_context.exception @@ -703,6 +896,6 @@ def testInterceptedStreamRequestStreamResponseWithError(self): self.assertIsInstance(exception.exception(), grpc.RpcError) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py b/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py index c56b719c408a7..a19966131c5c7 100644 --- a/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py +++ b/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py @@ -21,14 +21,14 @@ from tests.unit.framework.common import test_constants _SERIALIZE_REQUEST = lambda bytestring: bytestring * 2 -_DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) // 2:] +_DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) // 2 :] _SERIALIZE_RESPONSE = lambda bytestring: bytestring * 3 -_DESERIALIZE_RESPONSE = lambda bytestring: bytestring[:len(bytestring) // 3] +_DESERIALIZE_RESPONSE = lambda bytestring: bytestring[: len(bytestring) // 3] -_UNARY_UNARY = '/test/UnaryUnary' -_UNARY_STREAM = '/test/UnaryStream' -_STREAM_UNARY = '/test/StreamUnary' -_STREAM_STREAM = '/test/StreamStream' +_UNARY_UNARY = "/test/UnaryUnary" +_UNARY_STREAM = "/test/UnaryStream" +_STREAM_UNARY = "/test/StreamUnary" +_STREAM_STREAM = "/test/StreamStream" def 
_unary_unary_multi_callable(channel): @@ -36,15 +36,19 @@ def _unary_unary_multi_callable(channel): def _unary_stream_multi_callable(channel): - return channel.unary_stream(_UNARY_STREAM, - request_serializer=_SERIALIZE_REQUEST, - response_deserializer=_DESERIALIZE_RESPONSE) + return channel.unary_stream( + _UNARY_STREAM, + request_serializer=_SERIALIZE_REQUEST, + response_deserializer=_DESERIALIZE_RESPONSE, + ) def _stream_unary_multi_callable(channel): - return channel.stream_unary(_STREAM_UNARY, - request_serializer=_SERIALIZE_REQUEST, - response_deserializer=_DESERIALIZE_RESPONSE) + return channel.stream_unary( + _STREAM_UNARY, + request_serializer=_SERIALIZE_REQUEST, + response_deserializer=_DESERIALIZE_RESPONSE, + ) def _stream_stream_multi_callable(channel): @@ -52,9 +56,8 @@ def _stream_stream_multi_callable(channel): class InvalidMetadataTest(unittest.TestCase): - def setUp(self): - self._channel = grpc.insecure_channel('localhost:8080') + self._channel = grpc.insecure_channel("localhost:8080") self._unary_unary = _unary_unary_multi_callable(self._channel) self._unary_stream = _unary_stream_multi_callable(self._channel) self._stream_unary = _stream_unary_multi_callable(self._channel) @@ -64,31 +67,31 @@ def tearDown(self): self._channel.close() def testUnaryRequestBlockingUnaryResponse(self): - request = b'\x07\x08' - metadata = (('InVaLiD', 'UnaryRequestBlockingUnaryResponse'),) + request = b"\x07\x08" + metadata = (("InVaLiD", "UnaryRequestBlockingUnaryResponse"),) expected_error_details = "metadata was invalid: %s" % metadata with self.assertRaises(ValueError) as exception_context: self._unary_unary(request, metadata=metadata) self.assertIn(expected_error_details, str(exception_context.exception)) def testUnaryRequestBlockingUnaryResponseWithCall(self): - request = b'\x07\x08' - metadata = (('InVaLiD', 'UnaryRequestBlockingUnaryResponseWithCall'),) + request = b"\x07\x08" + metadata = (("InVaLiD", "UnaryRequestBlockingUnaryResponseWithCall"),) expected_error_details = "metadata was invalid: %s" % metadata with self.assertRaises(ValueError) as exception_context: self._unary_unary.with_call(request, metadata=metadata) self.assertIn(expected_error_details, str(exception_context.exception)) def testUnaryRequestFutureUnaryResponse(self): - request = b'\x07\x08' - metadata = (('InVaLiD', 'UnaryRequestFutureUnaryResponse'),) + request = b"\x07\x08" + metadata = (("InVaLiD", "UnaryRequestFutureUnaryResponse"),) expected_error_details = "metadata was invalid: %s" % metadata with self.assertRaises(ValueError) as exception_context: self._unary_unary.future(request, metadata=metadata) def testUnaryRequestStreamResponse(self): - request = b'\x37\x58' - metadata = (('InVaLiD', 'UnaryRequestStreamResponse'),) + request = b"\x37\x58" + metadata = (("InVaLiD", "UnaryRequestStreamResponse"),) expected_error_details = "metadata was invalid: %s" % metadata with self.assertRaises(ValueError) as exception_context: self._unary_stream(request, metadata=metadata) @@ -96,8 +99,9 @@ def testUnaryRequestStreamResponse(self): def testStreamRequestBlockingUnaryResponse(self): request_iterator = ( - b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) - metadata = (('InVaLiD', 'StreamRequestBlockingUnaryResponse'),) + b"\x07\x08" for _ in range(test_constants.STREAM_LENGTH) + ) + metadata = (("InVaLiD", "StreamRequestBlockingUnaryResponse"),) expected_error_details = "metadata was invalid: %s" % metadata with self.assertRaises(ValueError) as exception_context: self._stream_unary(request_iterator, 
metadata=metadata) @@ -105,8 +109,9 @@ def testStreamRequestBlockingUnaryResponse(self): def testStreamRequestBlockingUnaryResponseWithCall(self): request_iterator = ( - b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) - metadata = (('InVaLiD', 'StreamRequestBlockingUnaryResponseWithCall'),) + b"\x07\x08" for _ in range(test_constants.STREAM_LENGTH) + ) + metadata = (("InVaLiD", "StreamRequestBlockingUnaryResponseWithCall"),) expected_error_details = "metadata was invalid: %s" % metadata multi_callable = _stream_unary_multi_callable(self._channel) with self.assertRaises(ValueError) as exception_context: @@ -115,8 +120,9 @@ def testStreamRequestBlockingUnaryResponseWithCall(self): def testStreamRequestFutureUnaryResponse(self): request_iterator = ( - b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) - metadata = (('InVaLiD', 'StreamRequestFutureUnaryResponse'),) + b"\x07\x08" for _ in range(test_constants.STREAM_LENGTH) + ) + metadata = (("InVaLiD", "StreamRequestFutureUnaryResponse"),) expected_error_details = "metadata was invalid: %s" % metadata with self.assertRaises(ValueError) as exception_context: self._stream_unary.future(request_iterator, metadata=metadata) @@ -124,17 +130,18 @@ def testStreamRequestFutureUnaryResponse(self): def testStreamRequestStreamResponse(self): request_iterator = ( - b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) - metadata = (('InVaLiD', 'StreamRequestStreamResponse'),) + b"\x07\x08" for _ in range(test_constants.STREAM_LENGTH) + ) + metadata = (("InVaLiD", "StreamRequestStreamResponse"),) expected_error_details = "metadata was invalid: %s" % metadata with self.assertRaises(ValueError) as exception_context: self._stream_stream(request_iterator, metadata=metadata) self.assertIn(expected_error_details, str(exception_context.exception)) def testInvalidMetadata(self): - self.assertRaises(TypeError, self._unary_unary, b'', metadata=42) + self.assertRaises(TypeError, self._unary_unary, b"", metadata=42) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_invocation_defects_test.py b/src/python/grpcio_tests/tests/unit/_invocation_defects_test.py index 904d8b914f278..b22ab01659311 100644 --- a/src/python/grpcio_tests/tests/unit/_invocation_defects_test.py +++ b/src/python/grpcio_tests/tests/unit/_invocation_defects_test.py @@ -22,34 +22,38 @@ from tests.unit.framework.common import test_control _SERIALIZE_REQUEST = lambda bytestring: bytestring * 2 -_DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) // 2:] +_DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) // 2 :] _SERIALIZE_RESPONSE = lambda bytestring: bytestring * 3 -_DESERIALIZE_RESPONSE = lambda bytestring: bytestring[:len(bytestring) // 3] +_DESERIALIZE_RESPONSE = lambda bytestring: bytestring[: len(bytestring) // 3] -_UNARY_UNARY = '/test/UnaryUnary' -_UNARY_UNARY_NESTED_EXCEPTION = '/test/UnaryUnaryNestedException' -_UNARY_STREAM = '/test/UnaryStream' -_STREAM_UNARY = '/test/StreamUnary' -_STREAM_STREAM = '/test/StreamStream' -_DEFECTIVE_GENERIC_RPC_HANDLER = '/test/DefectiveGenericRpcHandler' +_UNARY_UNARY = "/test/UnaryUnary" +_UNARY_UNARY_NESTED_EXCEPTION = "/test/UnaryUnaryNestedException" +_UNARY_STREAM = "/test/UnaryStream" +_STREAM_UNARY = "/test/StreamUnary" +_STREAM_STREAM = "/test/StreamStream" +_DEFECTIVE_GENERIC_RPC_HANDLER = "/test/DefectiveGenericRpcHandler" class _Handler(object): - def __init__(self, control): self._control = 
control def handle_unary_unary(self, request, servicer_context): self._control.control() if servicer_context is not None: - servicer_context.set_trailing_metadata((( - 'testkey', - 'testvalue', - ),)) + servicer_context.set_trailing_metadata( + ( + ( + "testkey", + "testvalue", + ), + ) + ) return request - def handle_unary_unary_with_nested_exception(self, request, - servicer_context): + def handle_unary_unary_with_nested_exception( + self, request, servicer_context + ): raise test_control.NestedDefect() def handle_unary_stream(self, request, servicer_context): @@ -58,10 +62,14 @@ def handle_unary_stream(self, request, servicer_context): yield request self._control.control() if servicer_context is not None: - servicer_context.set_trailing_metadata((( - 'testkey', - 'testvalue', - ),)) + servicer_context.set_trailing_metadata( + ( + ( + "testkey", + "testvalue", + ), + ) + ) def handle_stream_unary(self, request_iterator, servicer_context): if servicer_context is not None: @@ -73,19 +81,27 @@ def handle_stream_unary(self, request_iterator, servicer_context): response_elements.append(request) self._control.control() if servicer_context is not None: - servicer_context.set_trailing_metadata((( - 'testkey', - 'testvalue', - ),)) - return b''.join(response_elements) + servicer_context.set_trailing_metadata( + ( + ( + "testkey", + "testvalue", + ), + ) + ) + return b"".join(response_elements) def handle_stream_stream(self, request_iterator, servicer_context): self._control.control() if servicer_context is not None: - servicer_context.set_trailing_metadata((( - 'testkey', - 'testvalue', - ),)) + servicer_context.set_trailing_metadata( + ( + ( + "testkey", + "testvalue", + ), + ) + ) for request in request_iterator: self._control.control() yield request @@ -96,10 +112,17 @@ def defective_generic_rpc_handler(self): class _MethodHandler(grpc.RpcMethodHandler): - - def __init__(self, request_streaming, response_streaming, - request_deserializer, response_serializer, unary_unary, - unary_stream, stream_unary, stream_stream): + def __init__( + self, + request_streaming, + response_streaming, + request_deserializer, + response_serializer, + unary_unary, + unary_stream, + stream_unary, + stream_stream, + ): self.request_streaming = request_streaming self.response_streaming = response_streaming self.request_deserializer = request_deserializer @@ -111,39 +134,72 @@ def __init__(self, request_streaming, response_streaming, class _GenericHandler(grpc.GenericRpcHandler): - def __init__(self, handler): self._handler = handler def service(self, handler_call_details): if handler_call_details.method == _UNARY_UNARY: - return _MethodHandler(False, False, None, None, - self._handler.handle_unary_unary, None, None, - None) + return _MethodHandler( + False, + False, + None, + None, + self._handler.handle_unary_unary, + None, + None, + None, + ) elif handler_call_details.method == _UNARY_STREAM: - return _MethodHandler(False, True, _DESERIALIZE_REQUEST, - _SERIALIZE_RESPONSE, None, - self._handler.handle_unary_stream, None, None) + return _MethodHandler( + False, + True, + _DESERIALIZE_REQUEST, + _SERIALIZE_RESPONSE, + None, + self._handler.handle_unary_stream, + None, + None, + ) elif handler_call_details.method == _STREAM_UNARY: - return _MethodHandler(True, False, _DESERIALIZE_REQUEST, - _SERIALIZE_RESPONSE, None, None, - self._handler.handle_stream_unary, None) + return _MethodHandler( + True, + False, + _DESERIALIZE_REQUEST, + _SERIALIZE_RESPONSE, + None, + None, + self._handler.handle_stream_unary, + None, + ) 
elif handler_call_details.method == _STREAM_STREAM: - return _MethodHandler(True, True, None, None, None, None, None, - self._handler.handle_stream_stream) + return _MethodHandler( + True, + True, + None, + None, + None, + None, + None, + self._handler.handle_stream_stream, + ) elif handler_call_details.method == _DEFECTIVE_GENERIC_RPC_HANDLER: return self._handler.defective_generic_rpc_handler() elif handler_call_details.method == _UNARY_UNARY_NESTED_EXCEPTION: return _MethodHandler( - False, False, None, None, - self._handler.handle_unary_unary_with_nested_exception, None, - None, None) + False, + False, + None, + None, + self._handler.handle_unary_unary_with_nested_exception, + None, + None, + None, + ) else: return None class FailAfterFewIterationsCounter(object): - def __init__(self, high, bytestring): self._current = 0 self._high = high @@ -167,15 +223,19 @@ def _unary_unary_multi_callable(channel): def _unary_stream_multi_callable(channel): - return channel.unary_stream(_UNARY_STREAM, - request_serializer=_SERIALIZE_REQUEST, - response_deserializer=_DESERIALIZE_RESPONSE) + return channel.unary_stream( + _UNARY_STREAM, + request_serializer=_SERIALIZE_REQUEST, + response_deserializer=_DESERIALIZE_RESPONSE, + ) def _stream_unary_multi_callable(channel): - return channel.stream_unary(_STREAM_UNARY, - request_serializer=_SERIALIZE_REQUEST, - response_deserializer=_DESERIALIZE_RESPONSE) + return channel.stream_unary( + _STREAM_UNARY, + request_serializer=_SERIALIZE_REQUEST, + response_deserializer=_DESERIALIZE_RESPONSE, + ) def _stream_stream_multi_callable(channel): @@ -198,11 +258,11 @@ def setUp(self): self._handler = _Handler(self._control) self._server = test_common.test_server() - port = self._server.add_insecure_port('[::]:0') + port = self._server.add_insecure_port("[::]:0") self._server.add_generic_rpc_handlers((_GenericHandler(self._handler),)) self._server.start() - self._channel = grpc.insecure_channel('localhost:%d' % port) + self._channel = grpc.insecure_channel("localhost:%d" % port) def tearDown(self): self._server.stop(0) @@ -215,79 +275,92 @@ def testIterableStreamRequestBlockingUnaryResponse(self): with self.assertRaises(grpc.RpcError) as exception_context: multi_callable( requests, - metadata=(('test', - 'IterableStreamRequestBlockingUnaryResponse'),)) + metadata=( + ("test", "IterableStreamRequestBlockingUnaryResponse"), + ), + ) - self.assertIs(grpc.StatusCode.UNKNOWN, - exception_context.exception.code()) + self.assertIs( + grpc.StatusCode.UNKNOWN, exception_context.exception.code() + ) def testIterableStreamRequestFutureUnaryResponse(self): requests = object() multi_callable = _stream_unary_multi_callable(self._channel) response_future = multi_callable.future( requests, - metadata=(('test', 'IterableStreamRequestFutureUnaryResponse'),)) + metadata=(("test", "IterableStreamRequestFutureUnaryResponse"),), + ) with self.assertRaises(grpc.RpcError) as exception_context: response_future.result() - self.assertIs(grpc.StatusCode.UNKNOWN, - exception_context.exception.code()) + self.assertIs( + grpc.StatusCode.UNKNOWN, exception_context.exception.code() + ) def testIterableStreamRequestStreamResponse(self): requests = object() multi_callable = _stream_stream_multi_callable(self._channel) response_iterator = multi_callable( requests, - metadata=(('test', 'IterableStreamRequestStreamResponse'),)) + metadata=(("test", "IterableStreamRequestStreamResponse"),), + ) with self.assertRaises(grpc.RpcError) as exception_context: next(response_iterator) - 
self.assertIs(grpc.StatusCode.UNKNOWN, - exception_context.exception.code()) + self.assertIs( + grpc.StatusCode.UNKNOWN, exception_context.exception.code() + ) def testIteratorStreamRequestStreamResponse(self): requests_iterator = FailAfterFewIterationsCounter( - test_constants.STREAM_LENGTH // 2, b'\x07\x08') + test_constants.STREAM_LENGTH // 2, b"\x07\x08" + ) multi_callable = _stream_stream_multi_callable(self._channel) response_iterator = multi_callable( requests_iterator, - metadata=(('test', 'IteratorStreamRequestStreamResponse'),)) + metadata=(("test", "IteratorStreamRequestStreamResponse"),), + ) with self.assertRaises(grpc.RpcError) as exception_context: for _ in range(test_constants.STREAM_LENGTH // 2 + 1): next(response_iterator) - self.assertIs(grpc.StatusCode.UNKNOWN, - exception_context.exception.code()) + self.assertIs( + grpc.StatusCode.UNKNOWN, exception_context.exception.code() + ) def testDefectiveGenericRpcHandlerUnaryResponse(self): - request = b'\x07\x08' + request = b"\x07\x08" multi_callable = _defective_handler_multi_callable(self._channel) with self.assertRaises(grpc.RpcError) as exception_context: - multi_callable(request, - metadata=(('test', - 'DefectiveGenericRpcHandlerUnary'),)) + multi_callable( + request, metadata=(("test", "DefectiveGenericRpcHandlerUnary"),) + ) - self.assertIs(grpc.StatusCode.UNKNOWN, - exception_context.exception.code()) + self.assertIs( + grpc.StatusCode.UNKNOWN, exception_context.exception.code() + ) def testNestedExceptionGenericRpcHandlerUnaryResponse(self): - request = b'\x07\x08' + request = b"\x07\x08" multi_callable = _defective_nested_exception_handler_multi_callable( - self._channel) + self._channel + ) with self.assertRaises(grpc.RpcError) as exception_context: - multi_callable(request, - metadata=(('test', - 'DefectiveGenericRpcHandlerUnary'),)) + multi_callable( + request, metadata=(("test", "DefectiveGenericRpcHandlerUnary"),) + ) - self.assertIs(grpc.StatusCode.UNKNOWN, - exception_context.exception.code()) + self.assertIs( + grpc.StatusCode.UNKNOWN, exception_context.exception.code() + ) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_local_credentials_test.py b/src/python/grpcio_tests/tests/unit/_local_credentials_test.py index ce92feed4b37b..165f6ca16eb9c 100644 --- a/src/python/grpcio_tests/tests/unit/_local_credentials_test.py +++ b/src/python/grpcio_tests/tests/unit/_local_credentials_test.py @@ -21,58 +21,68 @@ class _GenericHandler(grpc.GenericRpcHandler): - def service(self, handler_call_details): return grpc.unary_unary_rpc_method_handler( - lambda request, unused_context: request) + lambda request, unused_context: request + ) class LocalCredentialsTest(unittest.TestCase): - def _create_server(self): server = grpc.server(ThreadPoolExecutor()) server.add_generic_rpc_handlers((_GenericHandler(),)) return server - @unittest.skipIf(os.name == 'nt', - 'TODO(https://github.com/grpc/grpc/issues/20078)') + @unittest.skipIf( + os.name == "nt", "TODO(https://github.com/grpc/grpc/issues/20078)" + ) def test_local_tcp(self): - server_addr = 'localhost:{}' + server_addr = "localhost:{}" channel_creds = grpc.local_channel_credentials( - grpc.LocalConnectionType.LOCAL_TCP) + grpc.LocalConnectionType.LOCAL_TCP + ) server_creds = grpc.local_server_credentials( - grpc.LocalConnectionType.LOCAL_TCP) + grpc.LocalConnectionType.LOCAL_TCP + ) server = self._create_server() port = server.add_secure_port(server_addr.format(0), 
server_creds) server.start() - with grpc.secure_channel(server_addr.format(port), - channel_creds) as channel: + with grpc.secure_channel( + server_addr.format(port), channel_creds + ) as channel: self.assertEqual( - b'abc', - channel.unary_unary('/test/method')(b'abc', - wait_for_ready=True)) + b"abc", + channel.unary_unary("/test/method")( + b"abc", wait_for_ready=True + ), + ) server.stop(None) - @unittest.skipIf(os.name == 'nt', - 'Unix Domain Socket is not supported on Windows') + @unittest.skipIf( + os.name == "nt", "Unix Domain Socket is not supported on Windows" + ) def test_uds(self): - server_addr = 'unix:/tmp/grpc_fullstack_test' + server_addr = "unix:/tmp/grpc_fullstack_test" channel_creds = grpc.local_channel_credentials( - grpc.LocalConnectionType.UDS) + grpc.LocalConnectionType.UDS + ) server_creds = grpc.local_server_credentials( - grpc.LocalConnectionType.UDS) + grpc.LocalConnectionType.UDS + ) server = self._create_server() server.add_secure_port(server_addr, server_creds) server.start() with grpc.secure_channel(server_addr, channel_creds) as channel: self.assertEqual( - b'abc', - channel.unary_unary('/test/method')(b'abc', - wait_for_ready=True)) + b"abc", + channel.unary_unary("/test/method")( + b"abc", wait_for_ready=True + ), + ) server.stop(None) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/src/python/grpcio_tests/tests/unit/_logging_test.py b/src/python/grpcio_tests/tests/unit/_logging_test.py index caae51452ccc5..0b2435c24d198 100644 --- a/src/python/grpcio_tests/tests/unit/_logging_test.py +++ b/src/python/grpcio_tests/tests/unit/_logging_test.py @@ -24,7 +24,6 @@ class LoggingTest(unittest.TestCase): - def test_logger_not_occupied(self): script = """if True: import logging @@ -44,7 +43,7 @@ def test_handler_found(self): import grpc """ out, err = self._verifyScriptSucceeds(script) - self.assertEqual(0, len(err), 'unexpected output to stderr') + self.assertEqual(0, len(err), "unexpected output to stderr") def test_can_configure_logger(self): script = """if True: @@ -84,16 +83,20 @@ def test_grpc_logger(self): self._verifyScriptSucceeds(script) def _verifyScriptSucceeds(self, script): - process = subprocess.Popen([INTERPRETER, '-c', script], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) + process = subprocess.Popen( + [INTERPRETER, "-c", script], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) out, err = process.communicate() self.assertEqual( - 0, process.returncode, - 'process failed with exit code %d (stdout: %s, stderr: %s)' % - (process.returncode, out, err)) + 0, + process.returncode, + "process failed with exit code %d (stdout: %s, stderr: %s)" + % (process.returncode, out, err), + ) return out, err -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py b/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py index 89c028b307be8..3c530058dc37b 100644 --- a/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py +++ b/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py @@ -23,43 +23,49 @@ from tests.unit.framework.common import test_constants from tests.unit.framework.common import test_control -_SERIALIZED_REQUEST = b'\x46\x47\x48' -_SERIALIZED_RESPONSE = b'\x49\x50\x51' +_SERIALIZED_REQUEST = b"\x46\x47\x48" +_SERIALIZED_RESPONSE = b"\x49\x50\x51" _REQUEST_SERIALIZER = lambda unused_request: _SERIALIZED_REQUEST _REQUEST_DESERIALIZER = lambda 
unused_serialized_request: object() _RESPONSE_SERIALIZER = lambda unused_response: _SERIALIZED_RESPONSE _RESPONSE_DESERIALIZER = lambda unused_serialized_response: object() -_SERVICE = 'test.TestService' -_UNARY_UNARY = 'UnaryUnary' -_UNARY_STREAM = 'UnaryStream' -_STREAM_UNARY = 'StreamUnary' -_STREAM_STREAM = 'StreamStream' +_SERVICE = "test.TestService" +_UNARY_UNARY = "UnaryUnary" +_UNARY_STREAM = "UnaryStream" +_STREAM_UNARY = "StreamUnary" +_STREAM_STREAM = "StreamStream" -_CLIENT_METADATA = (('client-md-key', 'client-md-key'), ('client-md-key-bin', - b'\x00\x01')) +_CLIENT_METADATA = ( + ("client-md-key", "client-md-key"), + ("client-md-key-bin", b"\x00\x01"), +) -_SERVER_INITIAL_METADATA = (('server-initial-md-key', - 'server-initial-md-value'), - ('server-initial-md-key-bin', b'\x00\x02')) +_SERVER_INITIAL_METADATA = ( + ("server-initial-md-key", "server-initial-md-value"), + ("server-initial-md-key-bin", b"\x00\x02"), +) -_SERVER_TRAILING_METADATA = (('server-trailing-md-key', - 'server-trailing-md-value'), - ('server-trailing-md-key-bin', b'\x00\x03')) +_SERVER_TRAILING_METADATA = ( + ("server-trailing-md-key", "server-trailing-md-value"), + ("server-trailing-md-key-bin", b"\x00\x03"), +) _NON_OK_CODE = grpc.StatusCode.NOT_FOUND -_DETAILS = 'Test details!' +_DETAILS = "Test details!" # calling abort should always fail an RPC, even for "invalid" codes _ABORT_CODES = (_NON_OK_CODE, 3, grpc.StatusCode.OK) -_EXPECTED_CLIENT_CODES = (_NON_OK_CODE, grpc.StatusCode.UNKNOWN, - grpc.StatusCode.UNKNOWN) -_EXPECTED_DETAILS = (_DETAILS, _DETAILS, '') +_EXPECTED_CLIENT_CODES = ( + _NON_OK_CODE, + grpc.StatusCode.UNKNOWN, + grpc.StatusCode.UNKNOWN, +) +_EXPECTED_DETAILS = (_DETAILS, _DETAILS, "") class _Servicer(object): - def __init__(self): self._lock = threading.Lock() self._abort_call = False @@ -170,62 +176,74 @@ def received_client_metadata(self): def _generic_handler(servicer): method_handlers = { - _UNARY_UNARY: - grpc.unary_unary_rpc_method_handler( - servicer.unary_unary, - request_deserializer=_REQUEST_DESERIALIZER, - response_serializer=_RESPONSE_SERIALIZER), - _UNARY_STREAM: - grpc.unary_stream_rpc_method_handler(servicer.unary_stream), - _STREAM_UNARY: - grpc.stream_unary_rpc_method_handler(servicer.stream_unary), - _STREAM_STREAM: - grpc.stream_stream_rpc_method_handler( - servicer.stream_stream, - request_deserializer=_REQUEST_DESERIALIZER, - response_serializer=_RESPONSE_SERIALIZER), + _UNARY_UNARY: grpc.unary_unary_rpc_method_handler( + servicer.unary_unary, + request_deserializer=_REQUEST_DESERIALIZER, + response_serializer=_RESPONSE_SERIALIZER, + ), + _UNARY_STREAM: grpc.unary_stream_rpc_method_handler( + servicer.unary_stream + ), + _STREAM_UNARY: grpc.stream_unary_rpc_method_handler( + servicer.stream_unary + ), + _STREAM_STREAM: grpc.stream_stream_rpc_method_handler( + servicer.stream_stream, + request_deserializer=_REQUEST_DESERIALIZER, + response_serializer=_RESPONSE_SERIALIZER, + ), } return grpc.method_handlers_generic_handler(_SERVICE, method_handlers) class MetadataCodeDetailsTest(unittest.TestCase): - def setUp(self): self._servicer = _Servicer() self._server = test_common.test_server() self._server.add_generic_rpc_handlers( - (_generic_handler(self._servicer),)) - port = self._server.add_insecure_port('[::]:0') + (_generic_handler(self._servicer),) + ) + port = self._server.add_insecure_port("[::]:0") self._server.start() - self._channel = grpc.insecure_channel('localhost:{}'.format(port)) + self._channel = grpc.insecure_channel("localhost:{}".format(port)) 
self._unary_unary = self._channel.unary_unary( - '/'.join(( - '', - _SERVICE, - _UNARY_UNARY, - )), + "/".join( + ( + "", + _SERVICE, + _UNARY_UNARY, + ) + ), request_serializer=_REQUEST_SERIALIZER, response_deserializer=_RESPONSE_DESERIALIZER, ) self._unary_stream = self._channel.unary_stream( - '/'.join(( - '', - _SERVICE, - _UNARY_STREAM, - )),) + "/".join( + ( + "", + _SERVICE, + _UNARY_STREAM, + ) + ), + ) self._stream_unary = self._channel.stream_unary( - '/'.join(( - '', - _SERVICE, - _STREAM_UNARY, - )),) + "/".join( + ( + "", + _SERVICE, + _STREAM_UNARY, + ) + ), + ) self._stream_stream = self._channel.stream_stream( - '/'.join(( - '', - _SERVICE, - _STREAM_STREAM, - )), + "/".join( + ( + "", + _SERVICE, + _STREAM_STREAM, + ) + ), request_serializer=_REQUEST_SERIALIZER, response_deserializer=_RESPONSE_DESERIALIZER, ) @@ -238,37 +256,51 @@ def testSuccessfulUnaryUnary(self): self._servicer.set_details(_DETAILS) unused_response, call = self._unary_unary.with_call( - object(), metadata=_CLIENT_METADATA) + object(), metadata=_CLIENT_METADATA + ) self.assertTrue( test_common.metadata_transmitted( - _CLIENT_METADATA, self._servicer.received_client_metadata())) + _CLIENT_METADATA, self._servicer.received_client_metadata() + ) + ) self.assertTrue( - test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, - call.initial_metadata())) + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, call.initial_metadata() + ) + ) self.assertTrue( - test_common.metadata_transmitted(_SERVER_TRAILING_METADATA, - call.trailing_metadata())) + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, call.trailing_metadata() + ) + ) self.assertIs(grpc.StatusCode.OK, call.code()) def testSuccessfulUnaryStream(self): self._servicer.set_details(_DETAILS) - response_iterator_call = self._unary_stream(_SERIALIZED_REQUEST, - metadata=_CLIENT_METADATA) + response_iterator_call = self._unary_stream( + _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA + ) received_initial_metadata = response_iterator_call.initial_metadata() list(response_iterator_call) self.assertTrue( test_common.metadata_transmitted( - _CLIENT_METADATA, self._servicer.received_client_metadata())) + _CLIENT_METADATA, self._servicer.received_client_metadata() + ) + ) self.assertTrue( - test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, - received_initial_metadata)) + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, received_initial_metadata + ) + ) self.assertTrue( test_common.metadata_transmitted( _SERVER_TRAILING_METADATA, - response_iterator_call.trailing_metadata())) + response_iterator_call.trailing_metadata(), + ) + ) self.assertIs(grpc.StatusCode.OK, response_iterator_call.code()) def testSuccessfulStreamUnary(self): @@ -276,43 +308,58 @@ def testSuccessfulStreamUnary(self): unused_response, call = self._stream_unary.with_call( iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH), - metadata=_CLIENT_METADATA) + metadata=_CLIENT_METADATA, + ) self.assertTrue( test_common.metadata_transmitted( - _CLIENT_METADATA, self._servicer.received_client_metadata())) + _CLIENT_METADATA, self._servicer.received_client_metadata() + ) + ) self.assertTrue( - test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, - call.initial_metadata())) + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, call.initial_metadata() + ) + ) self.assertTrue( - test_common.metadata_transmitted(_SERVER_TRAILING_METADATA, - call.trailing_metadata())) + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, 
call.trailing_metadata() + ) + ) self.assertIs(grpc.StatusCode.OK, call.code()) def testSuccessfulStreamStream(self): self._servicer.set_details(_DETAILS) - response_iterator_call = self._stream_stream(iter( - [object()] * test_constants.STREAM_LENGTH), - metadata=_CLIENT_METADATA) + response_iterator_call = self._stream_stream( + iter([object()] * test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA, + ) received_initial_metadata = response_iterator_call.initial_metadata() list(response_iterator_call) self.assertTrue( test_common.metadata_transmitted( - _CLIENT_METADATA, self._servicer.received_client_metadata())) + _CLIENT_METADATA, self._servicer.received_client_metadata() + ) + ) self.assertTrue( - test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, - received_initial_metadata)) + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, received_initial_metadata + ) + ) self.assertTrue( test_common.metadata_transmitted( _SERVER_TRAILING_METADATA, - response_iterator_call.trailing_metadata())) + response_iterator_call.trailing_metadata(), + ) + ) self.assertIs(grpc.StatusCode.OK, response_iterator_call.code()) def testAbortedUnaryUnary(self): - test_cases = zip(_ABORT_CODES, _EXPECTED_CLIENT_CODES, - _EXPECTED_DETAILS) + test_cases = zip( + _ABORT_CODES, _EXPECTED_CLIENT_CODES, _EXPECTED_DETAILS + ) for abort_code, expected_code, expected_details in test_cases: self._servicer.set_code(abort_code) self._servicer.set_details(_DETAILS) @@ -323,81 +370,104 @@ def testAbortedUnaryUnary(self): self.assertTrue( test_common.metadata_transmitted( - _CLIENT_METADATA, - self._servicer.received_client_metadata())) + _CLIENT_METADATA, self._servicer.received_client_metadata() + ) + ) self.assertTrue( test_common.metadata_transmitted( _SERVER_INITIAL_METADATA, - exception_context.exception.initial_metadata())) + exception_context.exception.initial_metadata(), + ) + ) self.assertTrue( test_common.metadata_transmitted( _SERVER_TRAILING_METADATA, - exception_context.exception.trailing_metadata())) + exception_context.exception.trailing_metadata(), + ) + ) self.assertIs(expected_code, exception_context.exception.code()) - self.assertEqual(expected_details, - exception_context.exception.details()) + self.assertEqual( + expected_details, exception_context.exception.details() + ) def testAbortedUnaryStream(self): - test_cases = zip(_ABORT_CODES, _EXPECTED_CLIENT_CODES, - _EXPECTED_DETAILS) + test_cases = zip( + _ABORT_CODES, _EXPECTED_CLIENT_CODES, _EXPECTED_DETAILS + ) for abort_code, expected_code, expected_details in test_cases: self._servicer.set_code(abort_code) self._servicer.set_details(_DETAILS) self._servicer.set_abort_call() response_iterator_call = self._unary_stream( - _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA) - received_initial_metadata = \ + _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA + ) + received_initial_metadata = ( response_iterator_call.initial_metadata() + ) with self.assertRaises(grpc.RpcError): self.assertEqual(len(list(response_iterator_call)), 0) self.assertTrue( test_common.metadata_transmitted( - _CLIENT_METADATA, - self._servicer.received_client_metadata())) + _CLIENT_METADATA, self._servicer.received_client_metadata() + ) + ) self.assertTrue( - test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, - received_initial_metadata)) + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, received_initial_metadata + ) + ) self.assertTrue( test_common.metadata_transmitted( _SERVER_TRAILING_METADATA, - response_iterator_call.trailing_metadata())) + 
response_iterator_call.trailing_metadata(), + ) + ) self.assertIs(expected_code, response_iterator_call.code()) self.assertEqual(expected_details, response_iterator_call.details()) def testAbortedStreamUnary(self): - test_cases = zip(_ABORT_CODES, _EXPECTED_CLIENT_CODES, - _EXPECTED_DETAILS) + test_cases = zip( + _ABORT_CODES, _EXPECTED_CLIENT_CODES, _EXPECTED_DETAILS + ) for abort_code, expected_code, expected_details in test_cases: self._servicer.set_code(abort_code) self._servicer.set_details(_DETAILS) self._servicer.set_abort_call() with self.assertRaises(grpc.RpcError) as exception_context: - self._stream_unary.with_call(iter([_SERIALIZED_REQUEST] * - test_constants.STREAM_LENGTH), - metadata=_CLIENT_METADATA) + self._stream_unary.with_call( + iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA, + ) self.assertTrue( test_common.metadata_transmitted( - _CLIENT_METADATA, - self._servicer.received_client_metadata())) + _CLIENT_METADATA, self._servicer.received_client_metadata() + ) + ) self.assertTrue( test_common.metadata_transmitted( _SERVER_INITIAL_METADATA, - exception_context.exception.initial_metadata())) + exception_context.exception.initial_metadata(), + ) + ) self.assertTrue( test_common.metadata_transmitted( _SERVER_TRAILING_METADATA, - exception_context.exception.trailing_metadata())) + exception_context.exception.trailing_metadata(), + ) + ) self.assertIs(expected_code, exception_context.exception.code()) - self.assertEqual(expected_details, - exception_context.exception.details()) + self.assertEqual( + expected_details, exception_context.exception.details() + ) def testAbortedStreamStream(self): - test_cases = zip(_ABORT_CODES, _EXPECTED_CLIENT_CODES, - _EXPECTED_DETAILS) + test_cases = zip( + _ABORT_CODES, _EXPECTED_CLIENT_CODES, _EXPECTED_DETAILS + ) for abort_code, expected_code, expected_details in test_cases: self._servicer.set_code(abort_code) self._servicer.set_details(_DETAILS) @@ -405,23 +475,30 @@ def testAbortedStreamStream(self): response_iterator_call = self._stream_stream( iter([object()] * test_constants.STREAM_LENGTH), - metadata=_CLIENT_METADATA) - received_initial_metadata = \ + metadata=_CLIENT_METADATA, + ) + received_initial_metadata = ( response_iterator_call.initial_metadata() + ) with self.assertRaises(grpc.RpcError): self.assertEqual(len(list(response_iterator_call)), 0) self.assertTrue( test_common.metadata_transmitted( - _CLIENT_METADATA, - self._servicer.received_client_metadata())) + _CLIENT_METADATA, self._servicer.received_client_metadata() + ) + ) self.assertTrue( - test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, - received_initial_metadata)) + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, received_initial_metadata + ) + ) self.assertTrue( test_common.metadata_transmitted( _SERVER_TRAILING_METADATA, - response_iterator_call.trailing_metadata())) + response_iterator_call.trailing_metadata(), + ) + ) self.assertIs(expected_code, response_iterator_call.code()) self.assertEqual(expected_details, response_iterator_call.details()) @@ -434,15 +511,21 @@ def testCustomCodeUnaryUnary(self): self.assertTrue( test_common.metadata_transmitted( - _CLIENT_METADATA, self._servicer.received_client_metadata())) + _CLIENT_METADATA, self._servicer.received_client_metadata() + ) + ) self.assertTrue( test_common.metadata_transmitted( _SERVER_INITIAL_METADATA, - exception_context.exception.initial_metadata())) + exception_context.exception.initial_metadata(), + ) + ) self.assertTrue( 
test_common.metadata_transmitted( _SERVER_TRAILING_METADATA, - exception_context.exception.trailing_metadata())) + exception_context.exception.trailing_metadata(), + ) + ) self.assertIs(_NON_OK_CODE, exception_context.exception.code()) self.assertEqual(_DETAILS, exception_context.exception.details()) @@ -450,22 +533,29 @@ def testCustomCodeUnaryStream(self): self._servicer.set_code(_NON_OK_CODE) self._servicer.set_details(_DETAILS) - response_iterator_call = self._unary_stream(_SERIALIZED_REQUEST, - metadata=_CLIENT_METADATA) + response_iterator_call = self._unary_stream( + _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA + ) received_initial_metadata = response_iterator_call.initial_metadata() with self.assertRaises(grpc.RpcError): list(response_iterator_call) self.assertTrue( test_common.metadata_transmitted( - _CLIENT_METADATA, self._servicer.received_client_metadata())) + _CLIENT_METADATA, self._servicer.received_client_metadata() + ) + ) self.assertTrue( - test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, - received_initial_metadata)) + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, received_initial_metadata + ) + ) self.assertTrue( test_common.metadata_transmitted( _SERVER_TRAILING_METADATA, - response_iterator_call.trailing_metadata())) + response_iterator_call.trailing_metadata(), + ) + ) self.assertIs(_NON_OK_CODE, response_iterator_call.code()) self.assertEqual(_DETAILS, response_iterator_call.details()) @@ -474,21 +564,28 @@ def testCustomCodeStreamUnary(self): self._servicer.set_details(_DETAILS) with self.assertRaises(grpc.RpcError) as exception_context: - self._stream_unary.with_call(iter([_SERIALIZED_REQUEST] * - test_constants.STREAM_LENGTH), - metadata=_CLIENT_METADATA) + self._stream_unary.with_call( + iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA, + ) self.assertTrue( test_common.metadata_transmitted( - _CLIENT_METADATA, self._servicer.received_client_metadata())) + _CLIENT_METADATA, self._servicer.received_client_metadata() + ) + ) self.assertTrue( test_common.metadata_transmitted( _SERVER_INITIAL_METADATA, - exception_context.exception.initial_metadata())) + exception_context.exception.initial_metadata(), + ) + ) self.assertTrue( test_common.metadata_transmitted( _SERVER_TRAILING_METADATA, - exception_context.exception.trailing_metadata())) + exception_context.exception.trailing_metadata(), + ) + ) self.assertIs(_NON_OK_CODE, exception_context.exception.code()) self.assertEqual(_DETAILS, exception_context.exception.details()) @@ -496,23 +593,30 @@ def testCustomCodeStreamStream(self): self._servicer.set_code(_NON_OK_CODE) self._servicer.set_details(_DETAILS) - response_iterator_call = self._stream_stream(iter( - [object()] * test_constants.STREAM_LENGTH), - metadata=_CLIENT_METADATA) + response_iterator_call = self._stream_stream( + iter([object()] * test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA, + ) received_initial_metadata = response_iterator_call.initial_metadata() with self.assertRaises(grpc.RpcError) as exception_context: list(response_iterator_call) self.assertTrue( test_common.metadata_transmitted( - _CLIENT_METADATA, self._servicer.received_client_metadata())) + _CLIENT_METADATA, self._servicer.received_client_metadata() + ) + ) self.assertTrue( - test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, - received_initial_metadata)) + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, received_initial_metadata + ) + ) self.assertTrue( test_common.metadata_transmitted( 
_SERVER_TRAILING_METADATA, - exception_context.exception.trailing_metadata())) + exception_context.exception.trailing_metadata(), + ) + ) self.assertIs(_NON_OK_CODE, exception_context.exception.code()) self.assertEqual(_DETAILS, exception_context.exception.details()) @@ -526,15 +630,21 @@ def testCustomCodeExceptionUnaryUnary(self): self.assertTrue( test_common.metadata_transmitted( - _CLIENT_METADATA, self._servicer.received_client_metadata())) + _CLIENT_METADATA, self._servicer.received_client_metadata() + ) + ) self.assertTrue( test_common.metadata_transmitted( _SERVER_INITIAL_METADATA, - exception_context.exception.initial_metadata())) + exception_context.exception.initial_metadata(), + ) + ) self.assertTrue( test_common.metadata_transmitted( _SERVER_TRAILING_METADATA, - exception_context.exception.trailing_metadata())) + exception_context.exception.trailing_metadata(), + ) + ) self.assertIs(_NON_OK_CODE, exception_context.exception.code()) self.assertEqual(_DETAILS, exception_context.exception.details()) @@ -543,22 +653,29 @@ def testCustomCodeExceptionUnaryStream(self): self._servicer.set_details(_DETAILS) self._servicer.set_exception() - response_iterator_call = self._unary_stream(_SERIALIZED_REQUEST, - metadata=_CLIENT_METADATA) + response_iterator_call = self._unary_stream( + _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA + ) received_initial_metadata = response_iterator_call.initial_metadata() with self.assertRaises(grpc.RpcError): list(response_iterator_call) self.assertTrue( test_common.metadata_transmitted( - _CLIENT_METADATA, self._servicer.received_client_metadata())) + _CLIENT_METADATA, self._servicer.received_client_metadata() + ) + ) self.assertTrue( - test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, - received_initial_metadata)) + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, received_initial_metadata + ) + ) self.assertTrue( test_common.metadata_transmitted( _SERVER_TRAILING_METADATA, - response_iterator_call.trailing_metadata())) + response_iterator_call.trailing_metadata(), + ) + ) self.assertIs(_NON_OK_CODE, response_iterator_call.code()) self.assertEqual(_DETAILS, response_iterator_call.details()) @@ -568,21 +685,28 @@ def testCustomCodeExceptionStreamUnary(self): self._servicer.set_exception() with self.assertRaises(grpc.RpcError) as exception_context: - self._stream_unary.with_call(iter([_SERIALIZED_REQUEST] * - test_constants.STREAM_LENGTH), - metadata=_CLIENT_METADATA) + self._stream_unary.with_call( + iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA, + ) self.assertTrue( test_common.metadata_transmitted( - _CLIENT_METADATA, self._servicer.received_client_metadata())) + _CLIENT_METADATA, self._servicer.received_client_metadata() + ) + ) self.assertTrue( test_common.metadata_transmitted( _SERVER_INITIAL_METADATA, - exception_context.exception.initial_metadata())) + exception_context.exception.initial_metadata(), + ) + ) self.assertTrue( test_common.metadata_transmitted( _SERVER_TRAILING_METADATA, - exception_context.exception.trailing_metadata())) + exception_context.exception.trailing_metadata(), + ) + ) self.assertIs(_NON_OK_CODE, exception_context.exception.code()) self.assertEqual(_DETAILS, exception_context.exception.details()) @@ -591,23 +715,30 @@ def testCustomCodeExceptionStreamStream(self): self._servicer.set_details(_DETAILS) self._servicer.set_exception() - response_iterator_call = self._stream_stream(iter( - [object()] * test_constants.STREAM_LENGTH), - metadata=_CLIENT_METADATA) + 
response_iterator_call = self._stream_stream( + iter([object()] * test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA, + ) received_initial_metadata = response_iterator_call.initial_metadata() with self.assertRaises(grpc.RpcError): list(response_iterator_call) self.assertTrue( test_common.metadata_transmitted( - _CLIENT_METADATA, self._servicer.received_client_metadata())) + _CLIENT_METADATA, self._servicer.received_client_metadata() + ) + ) self.assertTrue( - test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, - received_initial_metadata)) + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, received_initial_metadata + ) + ) self.assertTrue( test_common.metadata_transmitted( _SERVER_TRAILING_METADATA, - response_iterator_call.trailing_metadata())) + response_iterator_call.trailing_metadata(), + ) + ) self.assertIs(_NON_OK_CODE, response_iterator_call.code()) self.assertEqual(_DETAILS, response_iterator_call.details()) @@ -621,15 +752,21 @@ def testCustomCodeReturnNoneUnaryUnary(self): self.assertTrue( test_common.metadata_transmitted( - _CLIENT_METADATA, self._servicer.received_client_metadata())) + _CLIENT_METADATA, self._servicer.received_client_metadata() + ) + ) self.assertTrue( test_common.metadata_transmitted( _SERVER_INITIAL_METADATA, - exception_context.exception.initial_metadata())) + exception_context.exception.initial_metadata(), + ) + ) self.assertTrue( test_common.metadata_transmitted( _SERVER_TRAILING_METADATA, - exception_context.exception.trailing_metadata())) + exception_context.exception.trailing_metadata(), + ) + ) self.assertIs(_NON_OK_CODE, exception_context.exception.code()) self.assertEqual(_DETAILS, exception_context.exception.details()) @@ -639,27 +776,33 @@ def testCustomCodeReturnNoneStreamUnary(self): self._servicer.set_return_none() with self.assertRaises(grpc.RpcError) as exception_context: - self._stream_unary.with_call(iter([_SERIALIZED_REQUEST] * - test_constants.STREAM_LENGTH), - metadata=_CLIENT_METADATA) + self._stream_unary.with_call( + iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA, + ) self.assertTrue( test_common.metadata_transmitted( - _CLIENT_METADATA, self._servicer.received_client_metadata())) + _CLIENT_METADATA, self._servicer.received_client_metadata() + ) + ) self.assertTrue( test_common.metadata_transmitted( _SERVER_INITIAL_METADATA, - exception_context.exception.initial_metadata())) + exception_context.exception.initial_metadata(), + ) + ) self.assertTrue( test_common.metadata_transmitted( _SERVER_TRAILING_METADATA, - exception_context.exception.trailing_metadata())) + exception_context.exception.trailing_metadata(), + ) + ) self.assertIs(_NON_OK_CODE, exception_context.exception.code()) self.assertEqual(_DETAILS, exception_context.exception.details()) class _InspectServicer(_Servicer): - def __init__(self): super(_InspectServicer, self).__init__() self.actual_code = None @@ -675,22 +818,24 @@ def unary_unary(self, request, context): class InspectContextTest(unittest.TestCase): - def setUp(self): self._servicer = _InspectServicer() self._server = test_common.test_server() self._server.add_generic_rpc_handlers( - (_generic_handler(self._servicer),)) - port = self._server.add_insecure_port('[::]:0') + (_generic_handler(self._servicer),) + ) + port = self._server.add_insecure_port("[::]:0") self._server.start() - self._channel = grpc.insecure_channel('localhost:{}'.format(port)) + self._channel = grpc.insecure_channel("localhost:{}".format(port)) self._unary_unary = 
self._channel.unary_unary( - '/'.join(( - '', - _SERVICE, - _UNARY_UNARY, - )), + "/".join( + ( + "", + _SERVICE, + _UNARY_UNARY, + ) + ), request_serializer=_REQUEST_SERIALIZER, response_deserializer=_RESPONSE_DESERIALIZER, ) @@ -710,12 +855,14 @@ def testCodeDetailsInContext(self): self.assertEqual(_NON_OK_CODE, err.code()) self.assertEqual(self._servicer.actual_code, _NON_OK_CODE) - self.assertEqual(self._servicer.actual_details.decode('utf-8'), - _DETAILS) - self.assertEqual(self._servicer.actual_trailing_metadata, - _SERVER_TRAILING_METADATA) + self.assertEqual( + self._servicer.actual_details.decode("utf-8"), _DETAILS + ) + self.assertEqual( + self._servicer.actual_trailing_metadata, _SERVER_TRAILING_METADATA + ) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py b/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py index 598a3fb832c52..a67a496860fee 100644 --- a/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py +++ b/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py @@ -28,13 +28,13 @@ from tests.unit.framework.common import get_socket from tests.unit.framework.common import test_constants -_UNARY_UNARY = '/test/UnaryUnary' -_UNARY_STREAM = '/test/UnaryStream' -_STREAM_UNARY = '/test/StreamUnary' -_STREAM_STREAM = '/test/StreamStream' +_UNARY_UNARY = "/test/UnaryUnary" +_UNARY_STREAM = "/test/UnaryStream" +_STREAM_UNARY = "/test/StreamUnary" +_STREAM_STREAM = "/test/StreamStream" -_REQUEST = b'\x00\x00\x00' -_RESPONSE = b'\x00\x00\x00' +_REQUEST = b"\x00\x00\x00" +_RESPONSE = b"\x00\x00\x00" def handle_unary_unary(test, request, servicer_context): @@ -58,7 +58,6 @@ def handle_stream_stream(test, request_iterator, servicer_context): class _MethodHandler(grpc.RpcMethodHandler): - def __init__(self, test, request_streaming, response_streaming): self.request_streaming = request_streaming self.response_streaming = response_streaming @@ -70,20 +69,23 @@ def __init__(self, test, request_streaming, response_streaming): self.stream_stream = None if self.request_streaming and self.response_streaming: self.stream_stream = lambda req, ctx: handle_stream_stream( - test, req, ctx) + test, req, ctx + ) elif self.request_streaming: self.stream_unary = lambda req, ctx: handle_stream_unary( - test, req, ctx) + test, req, ctx + ) elif self.response_streaming: self.unary_stream = lambda req, ctx: handle_unary_stream( - test, req, ctx) + test, req, ctx + ) else: self.unary_unary = lambda req, ctx: handle_unary_unary( - test, req, ctx) + test, req, ctx + ) class _GenericHandler(grpc.GenericRpcHandler): - def __init__(self, test): self._test = test @@ -104,36 +106,39 @@ def create_phony_channel(): """Creating phony channels is a workaround for retries""" host, port, sock = get_socket(sock_options=(socket.SO_REUSEADDR,)) sock.close() - return grpc.insecure_channel('{}:{}'.format(host, port)) + return grpc.insecure_channel("{}:{}".format(host, port)) def perform_unary_unary_call(channel, wait_for_ready=None): channel.unary_unary(_UNARY_UNARY).__call__( _REQUEST, timeout=test_constants.LONG_TIMEOUT, - wait_for_ready=wait_for_ready) + wait_for_ready=wait_for_ready, + ) def perform_unary_unary_with_call(channel, wait_for_ready=None): channel.unary_unary(_UNARY_UNARY).with_call( _REQUEST, timeout=test_constants.LONG_TIMEOUT, - wait_for_ready=wait_for_ready) + wait_for_ready=wait_for_ready, + ) def perform_unary_unary_future(channel, wait_for_ready=None): 
channel.unary_unary(_UNARY_UNARY).future( _REQUEST, timeout=test_constants.LONG_TIMEOUT, - wait_for_ready=wait_for_ready).result( - timeout=test_constants.LONG_TIMEOUT) + wait_for_ready=wait_for_ready, + ).result(timeout=test_constants.LONG_TIMEOUT) def perform_unary_stream_call(channel, wait_for_ready=None): response_iterator = channel.unary_stream(_UNARY_STREAM).__call__( _REQUEST, timeout=test_constants.LONG_TIMEOUT, - wait_for_ready=wait_for_ready) + wait_for_ready=wait_for_ready, + ) for _ in response_iterator: pass @@ -142,43 +147,49 @@ def perform_stream_unary_call(channel, wait_for_ready=None): channel.stream_unary(_STREAM_UNARY).__call__( iter([_REQUEST] * test_constants.STREAM_LENGTH), timeout=test_constants.LONG_TIMEOUT, - wait_for_ready=wait_for_ready) + wait_for_ready=wait_for_ready, + ) def perform_stream_unary_with_call(channel, wait_for_ready=None): channel.stream_unary(_STREAM_UNARY).with_call( iter([_REQUEST] * test_constants.STREAM_LENGTH), timeout=test_constants.LONG_TIMEOUT, - wait_for_ready=wait_for_ready) + wait_for_ready=wait_for_ready, + ) def perform_stream_unary_future(channel, wait_for_ready=None): channel.stream_unary(_STREAM_UNARY).future( iter([_REQUEST] * test_constants.STREAM_LENGTH), timeout=test_constants.LONG_TIMEOUT, - wait_for_ready=wait_for_ready).result( - timeout=test_constants.LONG_TIMEOUT) + wait_for_ready=wait_for_ready, + ).result(timeout=test_constants.LONG_TIMEOUT) def perform_stream_stream_call(channel, wait_for_ready=None): response_iterator = channel.stream_stream(_STREAM_STREAM).__call__( iter([_REQUEST] * test_constants.STREAM_LENGTH), timeout=test_constants.LONG_TIMEOUT, - wait_for_ready=wait_for_ready) + wait_for_ready=wait_for_ready, + ) for _ in response_iterator: pass _ALL_CALL_CASES = [ - perform_unary_unary_call, perform_unary_unary_with_call, - perform_unary_unary_future, perform_unary_stream_call, - perform_stream_unary_call, perform_stream_unary_with_call, - perform_stream_unary_future, perform_stream_stream_call + perform_unary_unary_call, + perform_unary_unary_with_call, + perform_unary_unary_future, + perform_unary_stream_call, + perform_stream_unary_call, + perform_stream_unary_with_call, + perform_stream_unary_future, + perform_stream_stream_call, ] class MetadataFlagsTest(unittest.TestCase): - def check_connection_does_failfast(self, fn, channel, wait_for_ready=None): try: fn(channel, wait_for_ready) @@ -194,9 +205,9 @@ def test_call_wait_for_ready_default(self): def test_call_wait_for_ready_disabled(self): for perform_call in _ALL_CALL_CASES: with create_phony_channel() as channel: - self.check_connection_does_failfast(perform_call, - channel, - wait_for_ready=False) + self.check_connection_does_failfast( + perform_call, channel, wait_for_ready=False + ) def test_call_wait_for_ready_enabled(self): # To test the wait mechanism, Python thread is required to make @@ -210,11 +221,14 @@ def test_call_wait_for_ready_enabled(self): host, port, sock = get_socket(sock_options=(socket.SO_REUSEADDR,)) sock.close() - addr = '{}:{}'.format(host, port) + addr = "{}:{}".format(host, port) wg = test_common.WaitGroup(len(_ALL_CALL_CASES)) def wait_for_transient_failure(channel_connectivity): - if channel_connectivity == grpc.ChannelConnectivity.TRANSIENT_FAILURE: + if ( + channel_connectivity + == grpc.ChannelConnectivity.TRANSIENT_FAILURE + ): wg.done() def test_call(perform_call): @@ -231,8 +245,9 @@ def test_call(perform_call): test_threads = [] for perform_call in _ALL_CALL_CASES: - test_thread = threading.Thread(target=test_call, - 
args=(perform_call,)) + test_thread = threading.Thread( + target=test_call, args=(perform_call,) + ) test_thread.daemon = True test_thread.exception = None test_thread.start() @@ -255,6 +270,6 @@ def test_call(perform_call): raise unhandled_exceptions.get(True) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_metadata_test.py b/src/python/grpcio_tests/tests/unit/_metadata_test.py index d975228d3b0e5..7110177fa1825 100644 --- a/src/python/grpcio_tests/tests/unit/_metadata_test.py +++ b/src/python/grpcio_tests/tests/unit/_metadata_test.py @@ -23,59 +23,63 @@ from tests.unit import test_common from tests.unit.framework.common import test_constants -_CHANNEL_ARGS = (('grpc.primary_user_agent', 'primary-agent'), - ('grpc.secondary_user_agent', 'secondary-agent')) +_CHANNEL_ARGS = ( + ("grpc.primary_user_agent", "primary-agent"), + ("grpc.secondary_user_agent", "secondary-agent"), +) -_REQUEST = b'\x00\x00\x00' -_RESPONSE = b'\x00\x00\x00' +_REQUEST = b"\x00\x00\x00" +_RESPONSE = b"\x00\x00\x00" -_UNARY_UNARY = '/test/UnaryUnary' -_UNARY_STREAM = '/test/UnaryStream' -_STREAM_UNARY = '/test/StreamUnary' -_STREAM_STREAM = '/test/StreamStream' +_UNARY_UNARY = "/test/UnaryUnary" +_UNARY_STREAM = "/test/UnaryStream" +_STREAM_UNARY = "/test/StreamUnary" +_STREAM_STREAM = "/test/StreamStream" _INVOCATION_METADATA = ( ( - b'invocation-md-key', - u'invocation-md-value', + b"invocation-md-key", + "invocation-md-value", ), ( - u'invocation-md-key-bin', - b'\x00\x01', + "invocation-md-key-bin", + b"\x00\x01", ), ) _EXPECTED_INVOCATION_METADATA = ( ( - 'invocation-md-key', - 'invocation-md-value', + "invocation-md-key", + "invocation-md-value", ), ( - 'invocation-md-key-bin', - b'\x00\x01', + "invocation-md-key-bin", + b"\x00\x01", ), ) -_INITIAL_METADATA = ((b'initial-md-key', u'initial-md-value'), - (u'initial-md-key-bin', b'\x00\x02')) +_INITIAL_METADATA = ( + (b"initial-md-key", "initial-md-value"), + ("initial-md-key-bin", b"\x00\x02"), +) _EXPECTED_INITIAL_METADATA = ( ( - 'initial-md-key', - 'initial-md-value', + "initial-md-key", + "initial-md-value", ), ( - 'initial-md-key-bin', - b'\x00\x02', + "initial-md-key-bin", + b"\x00\x02", ), ) _TRAILING_METADATA = ( ( - 'server-trailing-md-key', - 'server-trailing-md-value', + "server-trailing-md-key", + "server-trailing-md-value", ), ( - 'server-trailing-md-key-bin', - b'\x00\x03', + "server-trailing-md-key-bin", + b"\x00\x03", ), ) _EXPECTED_TRAILING_METADATA = _TRAILING_METADATA @@ -83,20 +87,23 @@ def _user_agent(metadata): for key, val in metadata: - if key == 'user-agent': + if key == "user-agent": return val - raise KeyError('No user agent!') + raise KeyError("No user agent!") def validate_client_metadata(test, servicer_context): invocation_metadata = servicer_context.invocation_metadata() test.assertTrue( - test_common.metadata_transmitted(_EXPECTED_INVOCATION_METADATA, - invocation_metadata)) + test_common.metadata_transmitted( + _EXPECTED_INVOCATION_METADATA, invocation_metadata + ) + ) user_agent = _user_agent(invocation_metadata) test.assertTrue( - user_agent.startswith('primary-agent ' + _channel._USER_AGENT)) - test.assertTrue(user_agent.endswith('secondary-agent')) + user_agent.startswith("primary-agent " + _channel._USER_AGENT) + ) + test.assertTrue(user_agent.endswith("secondary-agent")) def handle_unary_unary(test, request, servicer_context): @@ -135,7 +142,6 @@ def handle_stream_stream(test, request_iterator, 
servicer_context): class _MethodHandler(grpc.RpcMethodHandler): - def __init__(self, test, request_streaming, response_streaming): self.request_streaming = request_streaming self.response_streaming = response_streaming @@ -156,7 +162,6 @@ def __init__(self, test, request_streaming, response_streaming): class _GenericHandler(grpc.GenericRpcHandler): - def __init__(self, test): self._test = test @@ -174,15 +179,16 @@ def service(self, handler_call_details): class MetadataTest(unittest.TestCase): - def setUp(self): self._server = test_common.test_server() self._server.add_generic_rpc_handlers( - (_GenericHandler(weakref.proxy(self)),)) - port = self._server.add_insecure_port('[::]:0') + (_GenericHandler(weakref.proxy(self)),) + ) + port = self._server.add_insecure_port("[::]:0") self._server.start() - self._channel = grpc.insecure_channel('localhost:%d' % port, - options=_CHANNEL_ARGS) + self._channel = grpc.insecure_channel( + "localhost:%d" % port, options=_CHANNEL_ARGS + ) def tearDown(self): self._server.stop(0) @@ -191,52 +197,72 @@ def tearDown(self): def testUnaryUnary(self): multi_callable = self._channel.unary_unary(_UNARY_UNARY) unused_response, call = multi_callable.with_call( - _REQUEST, metadata=_INVOCATION_METADATA) + _REQUEST, metadata=_INVOCATION_METADATA + ) self.assertTrue( - test_common.metadata_transmitted(_EXPECTED_INITIAL_METADATA, - call.initial_metadata())) + test_common.metadata_transmitted( + _EXPECTED_INITIAL_METADATA, call.initial_metadata() + ) + ) self.assertTrue( - test_common.metadata_transmitted(_EXPECTED_TRAILING_METADATA, - call.trailing_metadata())) + test_common.metadata_transmitted( + _EXPECTED_TRAILING_METADATA, call.trailing_metadata() + ) + ) def testUnaryStream(self): multi_callable = self._channel.unary_stream(_UNARY_STREAM) call = multi_callable(_REQUEST, metadata=_INVOCATION_METADATA) self.assertTrue( - test_common.metadata_transmitted(_EXPECTED_INITIAL_METADATA, - call.initial_metadata())) + test_common.metadata_transmitted( + _EXPECTED_INITIAL_METADATA, call.initial_metadata() + ) + ) for _ in call: pass self.assertTrue( - test_common.metadata_transmitted(_EXPECTED_TRAILING_METADATA, - call.trailing_metadata())) + test_common.metadata_transmitted( + _EXPECTED_TRAILING_METADATA, call.trailing_metadata() + ) + ) def testStreamUnary(self): multi_callable = self._channel.stream_unary(_STREAM_UNARY) unused_response, call = multi_callable.with_call( iter([_REQUEST] * test_constants.STREAM_LENGTH), - metadata=_INVOCATION_METADATA) + metadata=_INVOCATION_METADATA, + ) self.assertTrue( - test_common.metadata_transmitted(_EXPECTED_INITIAL_METADATA, - call.initial_metadata())) + test_common.metadata_transmitted( + _EXPECTED_INITIAL_METADATA, call.initial_metadata() + ) + ) self.assertTrue( - test_common.metadata_transmitted(_EXPECTED_TRAILING_METADATA, - call.trailing_metadata())) + test_common.metadata_transmitted( + _EXPECTED_TRAILING_METADATA, call.trailing_metadata() + ) + ) def testStreamStream(self): multi_callable = self._channel.stream_stream(_STREAM_STREAM) - call = multi_callable(iter([_REQUEST] * test_constants.STREAM_LENGTH), - metadata=_INVOCATION_METADATA) + call = multi_callable( + iter([_REQUEST] * test_constants.STREAM_LENGTH), + metadata=_INVOCATION_METADATA, + ) self.assertTrue( - test_common.metadata_transmitted(_EXPECTED_INITIAL_METADATA, - call.initial_metadata())) + test_common.metadata_transmitted( + _EXPECTED_INITIAL_METADATA, call.initial_metadata() + ) + ) for _ in call: pass self.assertTrue( - 
test_common.metadata_transmitted(_EXPECTED_TRAILING_METADATA, - call.trailing_metadata())) + test_common.metadata_transmitted( + _EXPECTED_TRAILING_METADATA, call.trailing_metadata() + ) + ) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_reconnect_test.py b/src/python/grpcio_tests/tests/unit/_reconnect_test.py index 90d010b9360b5..d412533251cac 100644 --- a/src/python/grpcio_tests/tests/unit/_reconnect_test.py +++ b/src/python/grpcio_tests/tests/unit/_reconnect_test.py @@ -24,10 +24,10 @@ from tests.unit.framework.common import bound_socket from tests.unit.framework.common import test_constants -_REQUEST = b'\x00\x00\x00' -_RESPONSE = b'\x00\x00\x01' +_REQUEST = b"\x00\x00\x00" +_RESPONSE = b"\x00\x00\x01" -_UNARY_UNARY = '/test/UnaryUnary' +_UNARY_UNARY = "/test/UnaryUnary" def _handle_unary_unary(unused_request, unused_servicer_context): @@ -35,16 +35,19 @@ def _handle_unary_unary(unused_request, unused_servicer_context): class ReconnectTest(unittest.TestCase): - def test_reconnect(self): server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY) - handler = grpc.method_handlers_generic_handler('test', { - 'UnaryUnary': - grpc.unary_unary_rpc_method_handler(_handle_unary_unary) - }) - options = (('grpc.so_reuseport', 1),) + handler = grpc.method_handlers_generic_handler( + "test", + { + "UnaryUnary": grpc.unary_unary_rpc_method_handler( + _handle_unary_unary + ) + }, + ) + options = (("grpc.so_reuseport", 1),) with bound_socket() as (host, port): - addr = '{}:{}'.format(host, port) + addr = "{}:{}".format(host, port) server = grpc.server(server_pool, (handler,), options=options) server.add_insecure_port(addr) server.start() @@ -64,6 +67,6 @@ def test_reconnect(self): channel.close() -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py b/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py index bd3272176ef17..3fc04f06a18bc 100644 --- a/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py +++ b/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py @@ -24,17 +24,16 @@ from tests.unit import test_common from tests.unit.framework.common import test_constants -_REQUEST = b'\x00\x00\x00' -_RESPONSE = b'\x00\x00\x00' +_REQUEST = b"\x00\x00\x00" +_RESPONSE = b"\x00\x00\x00" -_UNARY_UNARY = '/test/UnaryUnary' -_UNARY_STREAM = '/test/UnaryStream' -_STREAM_UNARY = '/test/StreamUnary' -_STREAM_STREAM = '/test/StreamStream' +_UNARY_UNARY = "/test/UnaryUnary" +_UNARY_STREAM = "/test/UnaryStream" +_STREAM_UNARY = "/test/StreamUnary" +_STREAM_STREAM = "/test/StreamStream" class _TestTrigger(object): - def __init__(self, total_call_count): self._total_call_count = total_call_count self._pending_calls = 0 @@ -93,7 +92,6 @@ def handle_stream_stream(trigger, request_iterator, servicer_context): class _MethodHandler(grpc.RpcMethodHandler): - def __init__(self, trigger, request_streaming, response_streaming): self.request_streaming = request_streaming self.response_streaming = response_streaming @@ -104,8 +102,9 @@ def __init__(self, trigger, request_streaming, response_streaming): self.stream_unary = None self.stream_stream = None if self.request_streaming and self.response_streaming: - self.stream_stream = ( - lambda x, y: handle_stream_stream(trigger, x, y)) + self.stream_stream = lambda x, y: handle_stream_stream( + trigger, x, y + ) elif 
self.request_streaming: self.stream_unary = lambda x, y: handle_stream_unary(trigger, x, y) elif self.response_streaming: @@ -115,7 +114,6 @@ def __init__(self, trigger, request_streaming, response_streaming): class _GenericHandler(grpc.GenericRpcHandler): - def __init__(self, trigger): self._trigger = trigger @@ -133,18 +131,18 @@ def service(self, handler_call_details): class ResourceExhaustedTest(unittest.TestCase): - def setUp(self): self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY) self._trigger = _TestTrigger(test_constants.THREAD_CONCURRENCY) self._server = grpc.server( self._server_pool, handlers=(_GenericHandler(self._trigger),), - options=(('grpc.so_reuseport', 0),), - maximum_concurrent_rpcs=test_constants.THREAD_CONCURRENCY) - port = self._server.add_insecure_port('[::]:0') + options=(("grpc.so_reuseport", 0),), + maximum_concurrent_rpcs=test_constants.THREAD_CONCURRENCY, + ) + port = self._server.add_insecure_port("[::]:0") self._server.start() - self._channel = grpc.insecure_channel('localhost:%d' % port) + self._channel = grpc.insecure_channel("localhost:%d" % port) def tearDown(self): self._server.stop(0) @@ -161,12 +159,16 @@ def testUnaryUnary(self): with self.assertRaises(grpc.RpcError) as exception_context: multi_callable(_REQUEST) - self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED, - exception_context.exception.code()) + self.assertEqual( + grpc.StatusCode.RESOURCE_EXHAUSTED, + exception_context.exception.code(), + ) future_exception = multi_callable.future(_REQUEST) - self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED, - future_exception.exception().code()) + self.assertEqual( + grpc.StatusCode.RESOURCE_EXHAUSTED, + future_exception.exception().code(), + ) self._trigger.trigger() for future in futures: @@ -186,8 +188,10 @@ def testUnaryStream(self): with self.assertRaises(grpc.RpcError) as exception_context: next(multi_callable(_REQUEST)) - self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED, - exception_context.exception.code()) + self.assertEqual( + grpc.StatusCode.RESOURCE_EXHAUSTED, + exception_context.exception.code(), + ) self._trigger.trigger() @@ -212,12 +216,16 @@ def testStreamUnary(self): with self.assertRaises(grpc.RpcError) as exception_context: multi_callable(request) - self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED, - exception_context.exception.code()) + self.assertEqual( + grpc.StatusCode.RESOURCE_EXHAUSTED, + exception_context.exception.code(), + ) future_exception = multi_callable.future(request) - self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED, - future_exception.exception().code()) + self.assertEqual( + grpc.StatusCode.RESOURCE_EXHAUSTED, + future_exception.exception().code(), + ) self._trigger.trigger() @@ -239,8 +247,10 @@ def testStreamStream(self): with self.assertRaises(grpc.RpcError) as exception_context: next(multi_callable(request)) - self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED, - exception_context.exception.code()) + self.assertEqual( + grpc.StatusCode.RESOURCE_EXHAUSTED, + exception_context.exception.code(), + ) self._trigger.trigger() @@ -254,6 +264,6 @@ def testStreamStream(self): self.assertEqual(_RESPONSE, response) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_rpc_part_1_test.py b/src/python/grpcio_tests/tests/unit/_rpc_part_1_test.py index c3354a621a109..778934b716efe 100644 --- a/src/python/grpcio_tests/tests/unit/_rpc_part_1_test.py +++ 
b/src/python/grpcio_tests/tests/unit/_rpc_part_1_test.py @@ -22,13 +22,15 @@ import grpc from grpc.framework.foundation import logging_pool +from tests.unit._rpc_test_helpers import ( + stream_stream_non_blocking_multi_callable, +) +from tests.unit._rpc_test_helpers import ( + unary_stream_non_blocking_multi_callable, +) from tests.unit._rpc_test_helpers import BaseRPCTest from tests.unit._rpc_test_helpers import Callback from tests.unit._rpc_test_helpers import TIMEOUT_SHORT -from tests.unit._rpc_test_helpers import \ - stream_stream_non_blocking_multi_callable -from tests.unit._rpc_test_helpers import \ - unary_stream_non_blocking_multi_callable from tests.unit._rpc_test_helpers import stream_stream_multi_callable from tests.unit._rpc_test_helpers import stream_unary_multi_callable from tests.unit._rpc_test_helpers import unary_stream_multi_callable @@ -37,10 +39,10 @@ class RPCPart1Test(BaseRPCTest, unittest.TestCase): - def testExpiredStreamRequestBlockingUnaryResponse(self): requests = tuple( - b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + b"\x07\x08" for _ in range(test_constants.STREAM_LENGTH) + ) request_iterator = iter(requests) multi_callable = stream_unary_multi_callable(self._channel) @@ -49,20 +51,25 @@ def testExpiredStreamRequestBlockingUnaryResponse(self): multi_callable( request_iterator, timeout=TIMEOUT_SHORT, - metadata=(('test', - 'ExpiredStreamRequestBlockingUnaryResponse'),)) + metadata=( + ("test", "ExpiredStreamRequestBlockingUnaryResponse"), + ), + ) self.assertIsInstance(exception_context.exception, grpc.RpcError) self.assertIsInstance(exception_context.exception, grpc.Call) self.assertIsNotNone(exception_context.exception.initial_metadata()) - self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, - exception_context.exception.code()) + self.assertIs( + grpc.StatusCode.DEADLINE_EXCEEDED, + exception_context.exception.code(), + ) self.assertIsNotNone(exception_context.exception.details()) self.assertIsNotNone(exception_context.exception.trailing_metadata()) def testExpiredStreamRequestFutureUnaryResponse(self): requests = tuple( - b'\x07\x18' for _ in range(test_constants.STREAM_LENGTH)) + b"\x07\x18" for _ in range(test_constants.STREAM_LENGTH) + ) request_iterator = iter(requests) callback = Callback() @@ -71,7 +78,8 @@ def testExpiredStreamRequestFutureUnaryResponse(self): response_future = multi_callable.future( request_iterator, timeout=TIMEOUT_SHORT, - metadata=(('test', 'ExpiredStreamRequestFutureUnaryResponse'),)) + metadata=(("test", "ExpiredStreamRequestFutureUnaryResponse"),), + ) with self.assertRaises(grpc.FutureTimeoutError): response_future.result(timeout=TIMEOUT_SHORT / 2.0) response_future.add_done_callback(callback) @@ -80,8 +88,10 @@ def testExpiredStreamRequestFutureUnaryResponse(self): with self.assertRaises(grpc.RpcError) as exception_context: response_future.result() self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, response_future.code()) - self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, - exception_context.exception.code()) + self.assertIs( + grpc.StatusCode.DEADLINE_EXCEEDED, + exception_context.exception.code(), + ) self.assertIsInstance(response_future.exception(), grpc.RpcError) self.assertIsNotNone(response_future.traceback()) self.assertIs(response_future, value_passed_to_callback) @@ -92,40 +102,46 @@ def testExpiredStreamRequestFutureUnaryResponse(self): def testExpiredStreamRequestStreamResponse(self): self._expired_stream_request_stream_response( - stream_stream_multi_callable(self._channel)) + 
stream_stream_multi_callable(self._channel) + ) def testExpiredStreamRequestStreamResponseNonBlocking(self): self._expired_stream_request_stream_response( - stream_stream_non_blocking_multi_callable(self._channel)) + stream_stream_non_blocking_multi_callable(self._channel) + ) def testFailedUnaryRequestBlockingUnaryResponse(self): - request = b'\x37\x17' + request = b"\x37\x17" multi_callable = unary_unary_multi_callable(self._channel) with self._control.fail(): with self.assertRaises(grpc.RpcError) as exception_context: multi_callable.with_call( request, - metadata=(('test', - 'FailedUnaryRequestBlockingUnaryResponse'),)) - - self.assertIs(grpc.StatusCode.UNKNOWN, - exception_context.exception.code()) + metadata=( + ("test", "FailedUnaryRequestBlockingUnaryResponse"), + ), + ) + + self.assertIs( + grpc.StatusCode.UNKNOWN, exception_context.exception.code() + ) # sanity checks on to make sure returned string contains default members # of the error debug_error_string = exception_context.exception.debug_error_string() - self.assertIn('grpc_status', debug_error_string) - self.assertIn('grpc_message', debug_error_string) + self.assertIn("grpc_status", debug_error_string) + self.assertIn("grpc_message", debug_error_string) def testFailedUnaryRequestFutureUnaryResponse(self): - request = b'\x37\x17' + request = b"\x37\x17" callback = Callback() multi_callable = unary_unary_multi_callable(self._channel) with self._control.fail(): response_future = multi_callable.future( request, - metadata=(('test', 'FailedUnaryRequestFutureUnaryResponse'),)) + metadata=(("test", "FailedUnaryRequestFutureUnaryResponse"),), + ) response_future.add_done_callback(callback) value_passed_to_callback = callback.value() @@ -133,25 +149,30 @@ def testFailedUnaryRequestFutureUnaryResponse(self): self.assertIsInstance(response_future, grpc.Call) with self.assertRaises(grpc.RpcError) as exception_context: response_future.result() - self.assertIs(grpc.StatusCode.UNKNOWN, - exception_context.exception.code()) + self.assertIs( + grpc.StatusCode.UNKNOWN, exception_context.exception.code() + ) self.assertIsInstance(response_future.exception(), grpc.RpcError) self.assertIsNotNone(response_future.traceback()) - self.assertIs(grpc.StatusCode.UNKNOWN, - response_future.exception().code()) + self.assertIs( + grpc.StatusCode.UNKNOWN, response_future.exception().code() + ) self.assertIs(response_future, value_passed_to_callback) def testFailedUnaryRequestStreamResponse(self): self._failed_unary_request_stream_response( - unary_stream_multi_callable(self._channel)) + unary_stream_multi_callable(self._channel) + ) def testFailedUnaryRequestStreamResponseNonBlocking(self): self._failed_unary_request_stream_response( - unary_stream_non_blocking_multi_callable(self._channel)) + unary_stream_non_blocking_multi_callable(self._channel) + ) def testFailedStreamRequestBlockingUnaryResponse(self): requests = tuple( - b'\x47\x58' for _ in range(test_constants.STREAM_LENGTH)) + b"\x47\x58" for _ in range(test_constants.STREAM_LENGTH) + ) request_iterator = iter(requests) multi_callable = stream_unary_multi_callable(self._channel) @@ -159,15 +180,19 @@ def testFailedStreamRequestBlockingUnaryResponse(self): with self.assertRaises(grpc.RpcError) as exception_context: multi_callable( request_iterator, - metadata=(('test', - 'FailedStreamRequestBlockingUnaryResponse'),)) + metadata=( + ("test", "FailedStreamRequestBlockingUnaryResponse"), + ), + ) - self.assertIs(grpc.StatusCode.UNKNOWN, - exception_context.exception.code()) + self.assertIs( + 
grpc.StatusCode.UNKNOWN, exception_context.exception.code() + ) def testFailedStreamRequestFutureUnaryResponse(self): requests = tuple( - b'\x07\x18' for _ in range(test_constants.STREAM_LENGTH)) + b"\x07\x18" for _ in range(test_constants.STREAM_LENGTH) + ) request_iterator = iter(requests) callback = Callback() @@ -175,62 +200,73 @@ def testFailedStreamRequestFutureUnaryResponse(self): with self._control.fail(): response_future = multi_callable.future( request_iterator, - metadata=(('test', 'FailedStreamRequestFutureUnaryResponse'),)) + metadata=(("test", "FailedStreamRequestFutureUnaryResponse"),), + ) response_future.add_done_callback(callback) value_passed_to_callback = callback.value() with self.assertRaises(grpc.RpcError) as exception_context: response_future.result() self.assertIs(grpc.StatusCode.UNKNOWN, response_future.code()) - self.assertIs(grpc.StatusCode.UNKNOWN, - exception_context.exception.code()) + self.assertIs( + grpc.StatusCode.UNKNOWN, exception_context.exception.code() + ) self.assertIsInstance(response_future.exception(), grpc.RpcError) self.assertIsNotNone(response_future.traceback()) self.assertIs(response_future, value_passed_to_callback) def testFailedStreamRequestStreamResponse(self): self._failed_stream_request_stream_response( - stream_stream_multi_callable(self._channel)) + stream_stream_multi_callable(self._channel) + ) def testFailedStreamRequestStreamResponseNonBlocking(self): self._failed_stream_request_stream_response( - stream_stream_non_blocking_multi_callable(self._channel)) + stream_stream_non_blocking_multi_callable(self._channel) + ) def testIgnoredUnaryRequestFutureUnaryResponse(self): - request = b'\x37\x17' + request = b"\x37\x17" multi_callable = unary_unary_multi_callable(self._channel) multi_callable.future( request, - metadata=(('test', 'IgnoredUnaryRequestFutureUnaryResponse'),)) + metadata=(("test", "IgnoredUnaryRequestFutureUnaryResponse"),), + ) def testIgnoredUnaryRequestStreamResponse(self): self._ignored_unary_stream_request_future_unary_response( - unary_stream_multi_callable(self._channel)) + unary_stream_multi_callable(self._channel) + ) def testIgnoredUnaryRequestStreamResponseNonBlocking(self): self._ignored_unary_stream_request_future_unary_response( - unary_stream_non_blocking_multi_callable(self._channel)) + unary_stream_non_blocking_multi_callable(self._channel) + ) def testIgnoredStreamRequestFutureUnaryResponse(self): requests = tuple( - b'\x07\x18' for _ in range(test_constants.STREAM_LENGTH)) + b"\x07\x18" for _ in range(test_constants.STREAM_LENGTH) + ) request_iterator = iter(requests) multi_callable = stream_unary_multi_callable(self._channel) multi_callable.future( request_iterator, - metadata=(('test', 'IgnoredStreamRequestFutureUnaryResponse'),)) + metadata=(("test", "IgnoredStreamRequestFutureUnaryResponse"),), + ) def testIgnoredStreamRequestStreamResponse(self): self._ignored_stream_request_stream_response( - stream_stream_multi_callable(self._channel)) + stream_stream_multi_callable(self._channel) + ) def testIgnoredStreamRequestStreamResponseNonBlocking(self): self._ignored_stream_request_stream_response( - stream_stream_non_blocking_multi_callable(self._channel)) + stream_stream_non_blocking_multi_callable(self._channel) + ) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main(verbosity=3) diff --git a/src/python/grpcio_tests/tests/unit/_rpc_part_2_test.py b/src/python/grpcio_tests/tests/unit/_rpc_part_2_test.py index a8e6ddeb53473..5d4f46f9d8b70 100644 --- 
a/src/python/grpcio_tests/tests/unit/_rpc_part_2_test.py +++ b/src/python/grpcio_tests/tests/unit/_rpc_part_2_test.py @@ -22,13 +22,15 @@ import grpc from grpc.framework.foundation import logging_pool +from tests.unit._rpc_test_helpers import ( + stream_stream_non_blocking_multi_callable, +) +from tests.unit._rpc_test_helpers import ( + unary_stream_non_blocking_multi_callable, +) from tests.unit._rpc_test_helpers import BaseRPCTest from tests.unit._rpc_test_helpers import Callback from tests.unit._rpc_test_helpers import TIMEOUT_SHORT -from tests.unit._rpc_test_helpers import \ - stream_stream_non_blocking_multi_callable -from tests.unit._rpc_test_helpers import \ - unary_stream_non_blocking_multi_callable from tests.unit._rpc_test_helpers import stream_stream_multi_callable from tests.unit._rpc_test_helpers import stream_unary_multi_callable from tests.unit._rpc_test_helpers import unary_stream_multi_callable @@ -37,59 +39,65 @@ class RPCPart2Test(BaseRPCTest, unittest.TestCase): - def testDefaultThreadPoolIsUsed(self): self._consume_one_stream_response_unary_request( - unary_stream_multi_callable(self._channel)) + unary_stream_multi_callable(self._channel) + ) self.assertFalse(self._thread_pool.was_used()) def testExperimentalThreadPoolIsUsed(self): self._consume_one_stream_response_unary_request( - unary_stream_non_blocking_multi_callable(self._channel)) + unary_stream_non_blocking_multi_callable(self._channel) + ) self.assertTrue(self._thread_pool.was_used()) def testUnrecognizedMethod(self): - request = b'abc' + request = b"abc" with self.assertRaises(grpc.RpcError) as exception_context: - self._channel.unary_unary('NoSuchMethod')(request) + self._channel.unary_unary("NoSuchMethod")(request) - self.assertEqual(grpc.StatusCode.UNIMPLEMENTED, - exception_context.exception.code()) + self.assertEqual( + grpc.StatusCode.UNIMPLEMENTED, exception_context.exception.code() + ) def testSuccessfulUnaryRequestBlockingUnaryResponse(self): - request = b'\x07\x08' + request = b"\x07\x08" expected_response = self._handler.handle_unary_unary(request, None) multi_callable = unary_unary_multi_callable(self._channel) response = multi_callable( request, - metadata=(('test', 'SuccessfulUnaryRequestBlockingUnaryResponse'),)) + metadata=(("test", "SuccessfulUnaryRequestBlockingUnaryResponse"),), + ) self.assertEqual(expected_response, response) def testSuccessfulUnaryRequestBlockingUnaryResponseWithCall(self): - request = b'\x07\x08' + request = b"\x07\x08" expected_response = self._handler.handle_unary_unary(request, None) multi_callable = unary_unary_multi_callable(self._channel) response, call = multi_callable.with_call( request, - metadata=(('test', - 'SuccessfulUnaryRequestBlockingUnaryResponseWithCall'),)) + metadata=( + ("test", "SuccessfulUnaryRequestBlockingUnaryResponseWithCall"), + ), + ) self.assertEqual(expected_response, response) self.assertIs(grpc.StatusCode.OK, call.code()) - self.assertEqual('', call.debug_error_string()) + self.assertEqual("", call.debug_error_string()) def testSuccessfulUnaryRequestFutureUnaryResponse(self): - request = b'\x07\x08' + request = b"\x07\x08" expected_response = self._handler.handle_unary_unary(request, None) multi_callable = unary_unary_multi_callable(self._channel) response_future = multi_callable.future( request, - metadata=(('test', 'SuccessfulUnaryRequestFutureUnaryResponse'),)) + metadata=(("test", "SuccessfulUnaryRequestFutureUnaryResponse"),), + ) response = response_future.result() self.assertIsInstance(response_future, grpc.Future) @@ -99,61 
+107,76 @@ def testSuccessfulUnaryRequestFutureUnaryResponse(self): self.assertIsNone(response_future.traceback()) def testSuccessfulUnaryRequestStreamResponse(self): - request = b'\x37\x58' + request = b"\x37\x58" expected_responses = tuple( - self._handler.handle_unary_stream(request, None)) + self._handler.handle_unary_stream(request, None) + ) multi_callable = unary_stream_multi_callable(self._channel) response_iterator = multi_callable( request, - metadata=(('test', 'SuccessfulUnaryRequestStreamResponse'),)) + metadata=(("test", "SuccessfulUnaryRequestStreamResponse"),), + ) responses = tuple(response_iterator) self.assertSequenceEqual(expected_responses, responses) def testSuccessfulStreamRequestBlockingUnaryResponse(self): requests = tuple( - b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + b"\x07\x08" for _ in range(test_constants.STREAM_LENGTH) + ) expected_response = self._handler.handle_stream_unary( - iter(requests), None) + iter(requests), None + ) request_iterator = iter(requests) multi_callable = stream_unary_multi_callable(self._channel) response = multi_callable( request_iterator, - metadata=(('test', - 'SuccessfulStreamRequestBlockingUnaryResponse'),)) + metadata=( + ("test", "SuccessfulStreamRequestBlockingUnaryResponse"), + ), + ) self.assertEqual(expected_response, response) def testSuccessfulStreamRequestBlockingUnaryResponseWithCall(self): requests = tuple( - b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + b"\x07\x08" for _ in range(test_constants.STREAM_LENGTH) + ) expected_response = self._handler.handle_stream_unary( - iter(requests), None) + iter(requests), None + ) request_iterator = iter(requests) multi_callable = stream_unary_multi_callable(self._channel) response, call = multi_callable.with_call( request_iterator, metadata=( - ('test', - 'SuccessfulStreamRequestBlockingUnaryResponseWithCall'),)) + ( + "test", + "SuccessfulStreamRequestBlockingUnaryResponseWithCall", + ), + ), + ) self.assertEqual(expected_response, response) self.assertIs(grpc.StatusCode.OK, call.code()) def testSuccessfulStreamRequestFutureUnaryResponse(self): requests = tuple( - b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + b"\x07\x08" for _ in range(test_constants.STREAM_LENGTH) + ) expected_response = self._handler.handle_stream_unary( - iter(requests), None) + iter(requests), None + ) request_iterator = iter(requests) multi_callable = stream_unary_multi_callable(self._channel) response_future = multi_callable.future( request_iterator, - metadata=(('test', 'SuccessfulStreamRequestFutureUnaryResponse'),)) + metadata=(("test", "SuccessfulStreamRequestFutureUnaryResponse"),), + ) response = response_future.result() self.assertEqual(expected_response, response) @@ -162,35 +185,40 @@ def testSuccessfulStreamRequestFutureUnaryResponse(self): def testSuccessfulStreamRequestStreamResponse(self): requests = tuple( - b'\x77\x58' for _ in range(test_constants.STREAM_LENGTH)) + b"\x77\x58" for _ in range(test_constants.STREAM_LENGTH) + ) expected_responses = tuple( - self._handler.handle_stream_stream(iter(requests), None)) + self._handler.handle_stream_stream(iter(requests), None) + ) request_iterator = iter(requests) multi_callable = stream_stream_multi_callable(self._channel) response_iterator = multi_callable( request_iterator, - metadata=(('test', 'SuccessfulStreamRequestStreamResponse'),)) + metadata=(("test", "SuccessfulStreamRequestStreamResponse"),), + ) responses = tuple(response_iterator) self.assertSequenceEqual(expected_responses, responses) def 
testSequentialInvocations(self): - first_request = b'\x07\x08' - second_request = b'\x0809' + first_request = b"\x07\x08" + second_request = b"\x0809" expected_first_response = self._handler.handle_unary_unary( - first_request, None) + first_request, None + ) expected_second_response = self._handler.handle_unary_unary( - second_request, None) + second_request, None + ) multi_callable = unary_unary_multi_callable(self._channel) - first_response = multi_callable(first_request, - metadata=(('test', - 'SequentialInvocations'),)) - second_response = multi_callable(second_request, - metadata=(('test', - 'SequentialInvocations'),)) + first_response = multi_callable( + first_request, metadata=(("test", "SequentialInvocations"),) + ) + second_response = multi_callable( + second_request, metadata=(("test", "SequentialInvocations"),) + ) self.assertEqual(expected_first_response, first_response) self.assertEqual(expected_second_response, second_response) @@ -198,11 +226,14 @@ def testSequentialInvocations(self): def testConcurrentBlockingInvocations(self): pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY) requests = tuple( - b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + b"\x07\x08" for _ in range(test_constants.STREAM_LENGTH) + ) expected_response = self._handler.handle_stream_unary( - iter(requests), None) - expected_responses = [expected_response - ] * test_constants.THREAD_CONCURRENCY + iter(requests), None + ) + expected_responses = [ + expected_response + ] * test_constants.THREAD_CONCURRENCY response_futures = [None] * test_constants.THREAD_CONCURRENCY multi_callable = stream_unary_multi_callable(self._channel) @@ -211,21 +242,26 @@ def testConcurrentBlockingInvocations(self): response_future = pool.submit( multi_callable, request_iterator, - metadata=(('test', 'ConcurrentBlockingInvocations'),)) + metadata=(("test", "ConcurrentBlockingInvocations"),), + ) response_futures[index] = response_future responses = tuple( - response_future.result() for response_future in response_futures) + response_future.result() for response_future in response_futures + ) pool.shutdown(wait=True) self.assertSequenceEqual(expected_responses, responses) def testConcurrentFutureInvocations(self): requests = tuple( - b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + b"\x07\x08" for _ in range(test_constants.STREAM_LENGTH) + ) expected_response = self._handler.handle_stream_unary( - iter(requests), None) - expected_responses = [expected_response - ] * test_constants.THREAD_CONCURRENCY + iter(requests), None + ) + expected_responses = [ + expected_response + ] * test_constants.THREAD_CONCURRENCY response_futures = [None] * test_constants.THREAD_CONCURRENCY multi_callable = stream_unary_multi_callable(self._channel) @@ -233,23 +269,24 @@ def testConcurrentFutureInvocations(self): request_iterator = iter(requests) response_future = multi_callable.future( request_iterator, - metadata=(('test', 'ConcurrentFutureInvocations'),)) + metadata=(("test", "ConcurrentFutureInvocations"),), + ) response_futures[index] = response_future responses = tuple( - response_future.result() for response_future in response_futures) + response_future.result() for response_future in response_futures + ) self.assertSequenceEqual(expected_responses, responses) def testWaitingForSomeButNotAllConcurrentFutureInvocations(self): pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY) - request = b'\x67\x68' + request = b"\x67\x68" expected_response = self._handler.handle_unary_unary(request, None) response_futures 
= [None] * test_constants.THREAD_CONCURRENCY lock = threading.Lock() test_is_running_cell = [True] def wrap_future(future): - def wrap(): try: return future.result() @@ -266,15 +303,21 @@ def wrap(): inner_response_future = multi_callable.future( request, metadata=( - ('test', - 'WaitingForSomeButNotAllConcurrentFutureInvocations'),)) + ( + "test", + "WaitingForSomeButNotAllConcurrentFutureInvocations", + ), + ), + ) outer_response_future = pool.submit( - wrap_future(inner_response_future)) + wrap_future(inner_response_future) + ) response_futures[index] = outer_response_future some_completed_response_futures_iterator = itertools.islice( futures.as_completed(response_futures), - test_constants.THREAD_CONCURRENCY // 2) + test_constants.THREAD_CONCURRENCY // 2, + ) for response_future in some_completed_response_futures_iterator: self.assertEqual(expected_response, response_future.result()) with lock: @@ -282,44 +325,53 @@ def wrap(): def testConsumingOneStreamResponseUnaryRequest(self): self._consume_one_stream_response_unary_request( - unary_stream_multi_callable(self._channel)) + unary_stream_multi_callable(self._channel) + ) def testConsumingOneStreamResponseUnaryRequestNonBlocking(self): self._consume_one_stream_response_unary_request( - unary_stream_non_blocking_multi_callable(self._channel)) + unary_stream_non_blocking_multi_callable(self._channel) + ) def testConsumingSomeButNotAllStreamResponsesUnaryRequest(self): self._consume_some_but_not_all_stream_responses_unary_request( - unary_stream_multi_callable(self._channel)) + unary_stream_multi_callable(self._channel) + ) def testConsumingSomeButNotAllStreamResponsesUnaryRequestNonBlocking(self): self._consume_some_but_not_all_stream_responses_unary_request( - unary_stream_non_blocking_multi_callable(self._channel)) + unary_stream_non_blocking_multi_callable(self._channel) + ) def testConsumingSomeButNotAllStreamResponsesStreamRequest(self): self._consume_some_but_not_all_stream_responses_stream_request( - stream_stream_multi_callable(self._channel)) + stream_stream_multi_callable(self._channel) + ) def testConsumingSomeButNotAllStreamResponsesStreamRequestNonBlocking(self): self._consume_some_but_not_all_stream_responses_stream_request( - stream_stream_non_blocking_multi_callable(self._channel)) + stream_stream_non_blocking_multi_callable(self._channel) + ) def testConsumingTooManyStreamResponsesStreamRequest(self): self._consume_too_many_stream_responses_stream_request( - stream_stream_multi_callable(self._channel)) + stream_stream_multi_callable(self._channel) + ) def testConsumingTooManyStreamResponsesStreamRequestNonBlocking(self): self._consume_too_many_stream_responses_stream_request( - stream_stream_non_blocking_multi_callable(self._channel)) + stream_stream_non_blocking_multi_callable(self._channel) + ) def testCancelledUnaryRequestUnaryResponse(self): - request = b'\x07\x17' + request = b"\x07\x17" multi_callable = unary_unary_multi_callable(self._channel) with self._control.pause(): response_future = multi_callable.future( request, - metadata=(('test', 'CancelledUnaryRequestUnaryResponse'),)) + metadata=(("test", "CancelledUnaryRequestUnaryResponse"),), + ) response_future.cancel() self.assertIs(grpc.StatusCode.CANCELLED, response_future.code()) @@ -333,22 +385,26 @@ def testCancelledUnaryRequestUnaryResponse(self): def testCancelledUnaryRequestStreamResponse(self): self._cancelled_unary_request_stream_response( - unary_stream_multi_callable(self._channel)) + unary_stream_multi_callable(self._channel) + ) def 
testCancelledUnaryRequestStreamResponseNonBlocking(self): self._cancelled_unary_request_stream_response( - unary_stream_non_blocking_multi_callable(self._channel)) + unary_stream_non_blocking_multi_callable(self._channel) + ) def testCancelledStreamRequestUnaryResponse(self): requests = tuple( - b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + b"\x07\x08" for _ in range(test_constants.STREAM_LENGTH) + ) request_iterator = iter(requests) multi_callable = stream_unary_multi_callable(self._channel) with self._control.pause(): response_future = multi_callable.future( request_iterator, - metadata=(('test', 'CancelledStreamRequestUnaryResponse'),)) + metadata=(("test", "CancelledStreamRequestUnaryResponse"),), + ) self._control.block_until_paused() response_future.cancel() @@ -366,14 +422,16 @@ def testCancelledStreamRequestUnaryResponse(self): def testCancelledStreamRequestStreamResponse(self): self._cancelled_stream_request_stream_response( - stream_stream_multi_callable(self._channel)) + stream_stream_multi_callable(self._channel) + ) def testCancelledStreamRequestStreamResponseNonBlocking(self): self._cancelled_stream_request_stream_response( - stream_stream_non_blocking_multi_callable(self._channel)) + stream_stream_non_blocking_multi_callable(self._channel) + ) def testExpiredUnaryRequestBlockingUnaryResponse(self): - request = b'\x07\x17' + request = b"\x07\x17" multi_callable = unary_unary_multi_callable(self._channel) with self._control.pause(): @@ -381,18 +439,22 @@ def testExpiredUnaryRequestBlockingUnaryResponse(self): multi_callable.with_call( request, timeout=TIMEOUT_SHORT, - metadata=(('test', - 'ExpiredUnaryRequestBlockingUnaryResponse'),)) + metadata=( + ("test", "ExpiredUnaryRequestBlockingUnaryResponse"), + ), + ) self.assertIsInstance(exception_context.exception, grpc.Call) self.assertIsNotNone(exception_context.exception.initial_metadata()) - self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, - exception_context.exception.code()) + self.assertIs( + grpc.StatusCode.DEADLINE_EXCEEDED, + exception_context.exception.code(), + ) self.assertIsNotNone(exception_context.exception.details()) self.assertIsNotNone(exception_context.exception.trailing_metadata()) def testExpiredUnaryRequestFutureUnaryResponse(self): - request = b'\x07\x17' + request = b"\x07\x17" callback = Callback() multi_callable = unary_unary_multi_callable(self._channel) @@ -400,7 +462,8 @@ def testExpiredUnaryRequestFutureUnaryResponse(self): response_future = multi_callable.future( request, timeout=TIMEOUT_SHORT, - metadata=(('test', 'ExpiredUnaryRequestFutureUnaryResponse'),)) + metadata=(("test", "ExpiredUnaryRequestFutureUnaryResponse"),), + ) response_future.add_done_callback(callback) value_passed_to_callback = callback.value() @@ -411,22 +474,28 @@ def testExpiredUnaryRequestFutureUnaryResponse(self): self.assertIsNotNone(response_future.trailing_metadata()) with self.assertRaises(grpc.RpcError) as exception_context: response_future.result() - self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, - exception_context.exception.code()) + self.assertIs( + grpc.StatusCode.DEADLINE_EXCEEDED, + exception_context.exception.code(), + ) self.assertIsInstance(response_future.exception(), grpc.RpcError) self.assertIsNotNone(response_future.traceback()) - self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, - response_future.exception().code()) + self.assertIs( + grpc.StatusCode.DEADLINE_EXCEEDED, + response_future.exception().code(), + ) def testExpiredUnaryRequestStreamResponse(self): 
self._expired_unary_request_stream_response( - unary_stream_multi_callable(self._channel)) + unary_stream_multi_callable(self._channel) + ) def testExpiredUnaryRequestStreamResponseNonBlocking(self): self._expired_unary_request_stream_response( - unary_stream_non_blocking_multi_callable(self._channel)) + unary_stream_non_blocking_multi_callable(self._channel) + ) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_rpc_test_helpers.py b/src/python/grpcio_tests/tests/unit/_rpc_test_helpers.py index 49c08e009a7b4..1027be1c677db 100644 --- a/src/python/grpcio_tests/tests/unit/_rpc_test_helpers.py +++ b/src/python/grpcio_tests/tests/unit/_rpc_test_helpers.py @@ -25,22 +25,21 @@ from tests.unit.framework.common import test_control _SERIALIZE_REQUEST = lambda bytestring: bytestring * 2 -_DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) // 2:] +_DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) // 2 :] _SERIALIZE_RESPONSE = lambda bytestring: bytestring * 3 -_DESERIALIZE_RESPONSE = lambda bytestring: bytestring[:len(bytestring) // 3] +_DESERIALIZE_RESPONSE = lambda bytestring: bytestring[: len(bytestring) // 3] -_UNARY_UNARY = '/test/UnaryUnary' -_UNARY_STREAM = '/test/UnaryStream' -_UNARY_STREAM_NON_BLOCKING = '/test/UnaryStreamNonBlocking' -_STREAM_UNARY = '/test/StreamUnary' -_STREAM_STREAM = '/test/StreamStream' -_STREAM_STREAM_NON_BLOCKING = '/test/StreamStreamNonBlocking' +_UNARY_UNARY = "/test/UnaryUnary" +_UNARY_STREAM = "/test/UnaryStream" +_UNARY_STREAM_NON_BLOCKING = "/test/UnaryStreamNonBlocking" +_STREAM_UNARY = "/test/StreamUnary" +_STREAM_STREAM = "/test/StreamStream" +_STREAM_STREAM_NON_BLOCKING = "/test/StreamStreamNonBlocking" TIMEOUT_SHORT = datetime.timedelta(seconds=4).total_seconds() class Callback(object): - def __init__(self): self._condition = threading.Condition() self._value = None @@ -60,23 +59,30 @@ def value(self): class _Handler(object): - def __init__(self, control, thread_pool): self._control = control self._thread_pool = thread_pool - non_blocking_functions = (self.handle_unary_stream_non_blocking, - self.handle_stream_stream_non_blocking) + non_blocking_functions = ( + self.handle_unary_stream_non_blocking, + self.handle_stream_stream_non_blocking, + ) for non_blocking_function in non_blocking_functions: non_blocking_function.__func__.experimental_non_blocking = True - non_blocking_function.__func__.experimental_thread_pool = self._thread_pool + non_blocking_function.__func__.experimental_thread_pool = ( + self._thread_pool + ) def handle_unary_unary(self, request, servicer_context): self._control.control() if servicer_context is not None: - servicer_context.set_trailing_metadata((( - 'testkey', - 'testvalue', - ),)) + servicer_context.set_trailing_metadata( + ( + ( + "testkey", + "testvalue", + ), + ) + ) # TODO(https://github.com/grpc/grpc/issues/8483): test the values # returned by these methods rather than only "smoke" testing that # the return after having been called. 
@@ -90,22 +96,31 @@ def handle_unary_stream(self, request, servicer_context): yield request self._control.control() if servicer_context is not None: - servicer_context.set_trailing_metadata((( - 'testkey', - 'testvalue', - ),)) - - def handle_unary_stream_non_blocking(self, request, servicer_context, - on_next): + servicer_context.set_trailing_metadata( + ( + ( + "testkey", + "testvalue", + ), + ) + ) + + def handle_unary_stream_non_blocking( + self, request, servicer_context, on_next + ): for _ in range(test_constants.STREAM_LENGTH): self._control.control() on_next(request) self._control.control() if servicer_context is not None: - servicer_context.set_trailing_metadata((( - 'testkey', - 'testvalue', - ),)) + servicer_context.set_trailing_metadata( + ( + ( + "testkey", + "testvalue", + ), + ) + ) on_next(None) def handle_stream_unary(self, request_iterator, servicer_context): @@ -118,32 +133,45 @@ def handle_stream_unary(self, request_iterator, servicer_context): response_elements.append(request) self._control.control() if servicer_context is not None: - servicer_context.set_trailing_metadata((( - 'testkey', - 'testvalue', - ),)) - return b''.join(response_elements) + servicer_context.set_trailing_metadata( + ( + ( + "testkey", + "testvalue", + ), + ) + ) + return b"".join(response_elements) def handle_stream_stream(self, request_iterator, servicer_context): self._control.control() if servicer_context is not None: - servicer_context.set_trailing_metadata((( - 'testkey', - 'testvalue', - ),)) + servicer_context.set_trailing_metadata( + ( + ( + "testkey", + "testvalue", + ), + ) + ) for request in request_iterator: self._control.control() yield request self._control.control() - def handle_stream_stream_non_blocking(self, request_iterator, - servicer_context, on_next): + def handle_stream_stream_non_blocking( + self, request_iterator, servicer_context, on_next + ): self._control.control() if servicer_context is not None: - servicer_context.set_trailing_metadata((( - 'testkey', - 'testvalue', - ),)) + servicer_context.set_trailing_metadata( + ( + ( + "testkey", + "testvalue", + ), + ) + ) for request in request_iterator: self._control.control() on_next(request) @@ -152,10 +180,17 @@ def handle_stream_stream_non_blocking(self, request_iterator, class _MethodHandler(grpc.RpcMethodHandler): - - def __init__(self, request_streaming, response_streaming, - request_deserializer, response_serializer, unary_unary, - unary_stream, stream_unary, stream_stream): + def __init__( + self, + request_streaming, + response_streaming, + request_deserializer, + response_serializer, + unary_unary, + unary_stream, + stream_unary, + stream_stream, + ): self.request_streaming = request_streaming self.response_streaming = response_streaming self.request_deserializer = request_deserializer @@ -167,34 +202,76 @@ def __init__(self, request_streaming, response_streaming, class _GenericHandler(grpc.GenericRpcHandler): - def __init__(self, handler): self._handler = handler def service(self, handler_call_details): if handler_call_details.method == _UNARY_UNARY: - return _MethodHandler(False, False, None, None, - self._handler.handle_unary_unary, None, None, - None) + return _MethodHandler( + False, + False, + None, + None, + self._handler.handle_unary_unary, + None, + None, + None, + ) elif handler_call_details.method == _UNARY_STREAM: - return _MethodHandler(False, True, _DESERIALIZE_REQUEST, - _SERIALIZE_RESPONSE, None, - self._handler.handle_unary_stream, None, None) + return _MethodHandler( + False, + True, + 
_DESERIALIZE_REQUEST, + _SERIALIZE_RESPONSE, + None, + self._handler.handle_unary_stream, + None, + None, + ) elif handler_call_details.method == _UNARY_STREAM_NON_BLOCKING: return _MethodHandler( - False, True, _DESERIALIZE_REQUEST, _SERIALIZE_RESPONSE, None, - self._handler.handle_unary_stream_non_blocking, None, None) + False, + True, + _DESERIALIZE_REQUEST, + _SERIALIZE_RESPONSE, + None, + self._handler.handle_unary_stream_non_blocking, + None, + None, + ) elif handler_call_details.method == _STREAM_UNARY: - return _MethodHandler(True, False, _DESERIALIZE_REQUEST, - _SERIALIZE_RESPONSE, None, None, - self._handler.handle_stream_unary, None) + return _MethodHandler( + True, + False, + _DESERIALIZE_REQUEST, + _SERIALIZE_RESPONSE, + None, + None, + self._handler.handle_stream_unary, + None, + ) elif handler_call_details.method == _STREAM_STREAM: - return _MethodHandler(True, True, None, None, None, None, None, - self._handler.handle_stream_stream) + return _MethodHandler( + True, + True, + None, + None, + None, + None, + None, + self._handler.handle_stream_stream, + ) elif handler_call_details.method == _STREAM_STREAM_NON_BLOCKING: return _MethodHandler( - True, True, None, None, None, None, None, - self._handler.handle_stream_stream_non_blocking) + True, + True, + None, + None, + None, + None, + None, + self._handler.handle_stream_stream_non_blocking, + ) else: return None @@ -204,21 +281,27 @@ def unary_unary_multi_callable(channel): def unary_stream_multi_callable(channel): - return channel.unary_stream(_UNARY_STREAM, - request_serializer=_SERIALIZE_REQUEST, - response_deserializer=_DESERIALIZE_RESPONSE) + return channel.unary_stream( + _UNARY_STREAM, + request_serializer=_SERIALIZE_REQUEST, + response_deserializer=_DESERIALIZE_RESPONSE, + ) def unary_stream_non_blocking_multi_callable(channel): - return channel.unary_stream(_UNARY_STREAM_NON_BLOCKING, - request_serializer=_SERIALIZE_REQUEST, - response_deserializer=_DESERIALIZE_RESPONSE) + return channel.unary_stream( + _UNARY_STREAM_NON_BLOCKING, + request_serializer=_SERIALIZE_REQUEST, + response_deserializer=_DESERIALIZE_RESPONSE, + ) def stream_unary_multi_callable(channel): - return channel.stream_unary(_STREAM_UNARY, - request_serializer=_SERIALIZE_REQUEST, - response_deserializer=_DESERIALIZE_RESPONSE) + return channel.stream_unary( + _STREAM_UNARY, + request_serializer=_SERIALIZE_REQUEST, + response_deserializer=_DESERIALIZE_RESPONSE, + ) def stream_stream_multi_callable(channel): @@ -230,64 +313,74 @@ def stream_stream_non_blocking_multi_callable(channel): class BaseRPCTest(object): - def setUp(self): self._control = test_control.PauseFailControl() self._thread_pool = thread_pool.RecordingThreadPool(max_workers=None) self._handler = _Handler(self._control, self._thread_pool) self._server = test_common.test_server() - port = self._server.add_insecure_port('[::]:0') + port = self._server.add_insecure_port("[::]:0") self._server.add_generic_rpc_handlers((_GenericHandler(self._handler),)) self._server.start() - self._channel = grpc.insecure_channel('localhost:%d' % port) + self._channel = grpc.insecure_channel("localhost:%d" % port) def tearDown(self): self._server.stop(None) self._channel.close() def _consume_one_stream_response_unary_request(self, multi_callable): - request = b'\x57\x38' + request = b"\x57\x38" response_iterator = multi_callable( request, - metadata=(('test', 'ConsumingOneStreamResponseUnaryRequest'),)) + metadata=(("test", "ConsumingOneStreamResponseUnaryRequest"),), + ) next(response_iterator) def 
_consume_some_but_not_all_stream_responses_unary_request( - self, multi_callable): - request = b'\x57\x38' + self, multi_callable + ): + request = b"\x57\x38" response_iterator = multi_callable( request, - metadata=(('test', - 'ConsumingSomeButNotAllStreamResponsesUnaryRequest'),)) + metadata=( + ("test", "ConsumingSomeButNotAllStreamResponsesUnaryRequest"), + ), + ) for _ in range(test_constants.STREAM_LENGTH // 2): next(response_iterator) def _consume_some_but_not_all_stream_responses_stream_request( - self, multi_callable): + self, multi_callable + ): requests = tuple( - b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH)) + b"\x67\x88" for _ in range(test_constants.STREAM_LENGTH) + ) request_iterator = iter(requests) response_iterator = multi_callable( request_iterator, - metadata=(('test', - 'ConsumingSomeButNotAllStreamResponsesStreamRequest'),)) + metadata=( + ("test", "ConsumingSomeButNotAllStreamResponsesStreamRequest"), + ), + ) for _ in range(test_constants.STREAM_LENGTH // 2): next(response_iterator) def _consume_too_many_stream_responses_stream_request(self, multi_callable): requests = tuple( - b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH)) + b"\x67\x88" for _ in range(test_constants.STREAM_LENGTH) + ) request_iterator = iter(requests) response_iterator = multi_callable( request_iterator, - metadata=(('test', - 'ConsumingTooManyStreamResponsesStreamRequest'),)) + metadata=( + ("test", "ConsumingTooManyStreamResponsesStreamRequest"), + ), + ) for _ in range(test_constants.STREAM_LENGTH): next(response_iterator) for _ in range(test_constants.STREAM_LENGTH): @@ -300,19 +393,21 @@ def _consume_too_many_stream_responses_stream_request(self, multi_callable): self.assertIsNotNone(response_iterator.trailing_metadata()) def _cancelled_unary_request_stream_response(self, multi_callable): - request = b'\x07\x19' + request = b"\x07\x19" with self._control.pause(): response_iterator = multi_callable( request, - metadata=(('test', 'CancelledUnaryRequestStreamResponse'),)) + metadata=(("test", "CancelledUnaryRequestStreamResponse"),), + ) self._control.block_until_paused() response_iterator.cancel() with self.assertRaises(grpc.RpcError) as exception_context: next(response_iterator) - self.assertIs(grpc.StatusCode.CANCELLED, - exception_context.exception.code()) + self.assertIs( + grpc.StatusCode.CANCELLED, exception_context.exception.code() + ) self.assertIsNotNone(response_iterator.initial_metadata()) self.assertIs(grpc.StatusCode.CANCELLED, response_iterator.code()) self.assertIsNotNone(response_iterator.details()) @@ -320,13 +415,15 @@ def _cancelled_unary_request_stream_response(self, multi_callable): def _cancelled_stream_request_stream_response(self, multi_callable): requests = tuple( - b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)) + b"\x07\x08" for _ in range(test_constants.STREAM_LENGTH) + ) request_iterator = iter(requests) with self._control.pause(): response_iterator = multi_callable( request_iterator, - metadata=(('test', 'CancelledStreamRequestStreamResponse'),)) + metadata=(("test", "CancelledStreamRequestStreamResponse"),), + ) response_iterator.cancel() with self.assertRaises(grpc.RpcError): @@ -337,24 +434,29 @@ def _cancelled_stream_request_stream_response(self, multi_callable): self.assertIsNotNone(response_iterator.trailing_metadata()) def _expired_unary_request_stream_response(self, multi_callable): - request = b'\x07\x19' + request = b"\x07\x19" with self._control.pause(): with self.assertRaises(grpc.RpcError) as exception_context: 
response_iterator = multi_callable( request, timeout=test_constants.SHORT_TIMEOUT, - metadata=(('test', 'ExpiredUnaryRequestStreamResponse'),)) + metadata=(("test", "ExpiredUnaryRequestStreamResponse"),), + ) next(response_iterator) - self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, - exception_context.exception.code()) - self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, - response_iterator.code()) + self.assertIs( + grpc.StatusCode.DEADLINE_EXCEEDED, + exception_context.exception.code(), + ) + self.assertIs( + grpc.StatusCode.DEADLINE_EXCEEDED, response_iterator.code() + ) def _expired_stream_request_stream_response(self, multi_callable): requests = tuple( - b'\x67\x18' for _ in range(test_constants.STREAM_LENGTH)) + b"\x67\x18" for _ in range(test_constants.STREAM_LENGTH) + ) request_iterator = iter(requests) with self._control.pause(): @@ -362,56 +464,68 @@ def _expired_stream_request_stream_response(self, multi_callable): response_iterator = multi_callable( request_iterator, timeout=test_constants.SHORT_TIMEOUT, - metadata=(('test', 'ExpiredStreamRequestStreamResponse'),)) + metadata=(("test", "ExpiredStreamRequestStreamResponse"),), + ) next(response_iterator) - self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, - exception_context.exception.code()) - self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, - response_iterator.code()) + self.assertIs( + grpc.StatusCode.DEADLINE_EXCEEDED, + exception_context.exception.code(), + ) + self.assertIs( + grpc.StatusCode.DEADLINE_EXCEEDED, response_iterator.code() + ) def _failed_unary_request_stream_response(self, multi_callable): - request = b'\x37\x17' + request = b"\x37\x17" with self.assertRaises(grpc.RpcError) as exception_context: with self._control.fail(): response_iterator = multi_callable( request, - metadata=(('test', 'FailedUnaryRequestStreamResponse'),)) + metadata=(("test", "FailedUnaryRequestStreamResponse"),), + ) next(response_iterator) - self.assertIs(grpc.StatusCode.UNKNOWN, - exception_context.exception.code()) + self.assertIs( + grpc.StatusCode.UNKNOWN, exception_context.exception.code() + ) def _failed_stream_request_stream_response(self, multi_callable): requests = tuple( - b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH)) + b"\x67\x88" for _ in range(test_constants.STREAM_LENGTH) + ) request_iterator = iter(requests) with self._control.fail(): with self.assertRaises(grpc.RpcError) as exception_context: response_iterator = multi_callable( request_iterator, - metadata=(('test', 'FailedStreamRequestStreamResponse'),)) + metadata=(("test", "FailedStreamRequestStreamResponse"),), + ) tuple(response_iterator) - self.assertIs(grpc.StatusCode.UNKNOWN, - exception_context.exception.code()) + self.assertIs( + grpc.StatusCode.UNKNOWN, exception_context.exception.code() + ) self.assertIs(grpc.StatusCode.UNKNOWN, response_iterator.code()) def _ignored_unary_stream_request_future_unary_response( - self, multi_callable): - request = b'\x37\x17' + self, multi_callable + ): + request = b"\x37\x17" - multi_callable(request, - metadata=(('test', - 'IgnoredUnaryRequestStreamResponse'),)) + multi_callable( + request, metadata=(("test", "IgnoredUnaryRequestStreamResponse"),) + ) def _ignored_stream_request_stream_response(self, multi_callable): requests = tuple( - b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH)) + b"\x67\x88" for _ in range(test_constants.STREAM_LENGTH) + ) request_iterator = iter(requests) - multi_callable(request_iterator, - metadata=(('test', - 'IgnoredStreamRequestStreamResponse'),)) + multi_callable( + 
request_iterator, + metadata=(("test", "IgnoredStreamRequestStreamResponse"),), + ) diff --git a/src/python/grpcio_tests/tests/unit/_server_shutdown_scenarios.py b/src/python/grpcio_tests/tests/unit/_server_shutdown_scenarios.py index 7729e2eda9482..9190f108f7bc5 100644 --- a/src/python/grpcio_tests/tests/unit/_server_shutdown_scenarios.py +++ b/src/python/grpcio_tests/tests/unit/_server_shutdown_scenarios.py @@ -27,14 +27,14 @@ WAIT_TIME = 1000 -REQUEST = b'request' -RESPONSE = b'response' +REQUEST = b"request" +RESPONSE = b"response" -SERVER_RAISES_EXCEPTION = 'server_raises_exception' -SERVER_DEALLOCATED = 'server_deallocated' -SERVER_FORK_CAN_EXIT = 'server_fork_can_exit' +SERVER_RAISES_EXCEPTION = "server_raises_exception" +SERVER_DEALLOCATED = "server_deallocated" +SERVER_FORK_CAN_EXIT = "server_fork_can_exit" -FORK_EXIT = '/test/ForkExit' +FORK_EXIT = "/test/ForkExit" def fork_and_exit(request, servicer_context): @@ -45,7 +45,6 @@ def fork_and_exit(request, servicer_context): class GenericHandler(grpc.GenericRpcHandler): - def service(self, handler_call_details): if handler_call_details.method == FORK_EXIT: return grpc.unary_unary_rpc_method_handler(fork_and_exit) @@ -55,7 +54,7 @@ def service(self, handler_call_details): def run_server(port_queue): server = test_common.test_server() - port = server.add_insecure_port('[::]:0') + port = server.add_insecure_port("[::]:0") port_queue.put(port) server.add_generic_rpc_handlers((GenericHandler(),)) server.start() @@ -81,17 +80,17 @@ def run_test(args): thread.daemon = True thread.start() port = port_queue.get() - channel = grpc.insecure_channel('localhost:%d' % port) + channel = grpc.insecure_channel("localhost:%d" % port) multi_callable = channel.unary_unary(FORK_EXIT) result, call = multi_callable.with_call(REQUEST, wait_for_ready=True) os.wait() else: - raise ValueError('unknown test scenario') + raise ValueError("unknown test scenario") -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() parser = argparse.ArgumentParser() - parser.add_argument('scenario', type=str) + parser.add_argument("scenario", type=str) args = parser.parse_args() run_test(args) diff --git a/src/python/grpcio_tests/tests/unit/_server_shutdown_test.py b/src/python/grpcio_tests/tests/unit/_server_shutdown_test.py index 7067f1f4e3fbc..162cfbdc73aef 100644 --- a/src/python/grpcio_tests/tests/unit/_server_shutdown_test.py +++ b/src/python/grpcio_tests/tests/unit/_server_shutdown_test.py @@ -28,8 +28,11 @@ from tests.unit import _server_shutdown_scenarios SCENARIO_FILE = os.path.abspath( - os.path.join(os.path.dirname(os.path.realpath(__file__)), - '_server_shutdown_scenarios.py')) + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "_server_shutdown_scenarios.py", + ) +) INTERPRETER = sys.executable BASE_COMMAND = [INTERPRETER, SCENARIO_FILE] @@ -58,32 +61,34 @@ def wait(process): class ServerShutdown(unittest.TestCase): - # Currently we shut down a server (if possible) after the Python server # instance is garbage collected. This behavior may change in the future. 
def test_deallocated_server_stops(self): process = subprocess.Popen( BASE_COMMAND + [_server_shutdown_scenarios.SERVER_DEALLOCATED], stdout=sys.stdout, - stderr=sys.stderr) + stderr=sys.stderr, + ) wait(process) def test_server_exception_exits(self): process = subprocess.Popen( BASE_COMMAND + [_server_shutdown_scenarios.SERVER_RAISES_EXCEPTION], stdout=sys.stdout, - stderr=sys.stderr) + stderr=sys.stderr, + ) wait(process) - @unittest.skipIf(os.name == 'nt', 'fork not supported on windows') + @unittest.skipIf(os.name == "nt", "fork not supported on windows") def test_server_fork_can_exit(self): process = subprocess.Popen( BASE_COMMAND + [_server_shutdown_scenarios.SERVER_FORK_CAN_EXIT], stdout=sys.stdout, - stderr=sys.stderr) + stderr=sys.stderr, + ) wait(process) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_server_ssl_cert_config_test.py b/src/python/grpcio_tests/tests/unit/_server_ssl_cert_config_test.py index 2ac9e95f8baf4..4062cba36d4ab 100644 --- a/src/python/grpcio_tests/tests/unit/_server_ssl_cert_config_test.py +++ b/src/python/grpcio_tests/tests/unit/_server_ssl_cert_config_test.py @@ -49,28 +49,36 @@ CA_2_PEM = resources.cert_hier_2_root_ca_cert() CLIENT_KEY_1_PEM = resources.cert_hier_1_client_1_key() -CLIENT_CERT_CHAIN_1_PEM = (resources.cert_hier_1_client_1_cert() + - resources.cert_hier_1_intermediate_ca_cert()) +CLIENT_CERT_CHAIN_1_PEM = ( + resources.cert_hier_1_client_1_cert() + + resources.cert_hier_1_intermediate_ca_cert() +) CLIENT_KEY_2_PEM = resources.cert_hier_2_client_1_key() -CLIENT_CERT_CHAIN_2_PEM = (resources.cert_hier_2_client_1_cert() + - resources.cert_hier_2_intermediate_ca_cert()) +CLIENT_CERT_CHAIN_2_PEM = ( + resources.cert_hier_2_client_1_cert() + + resources.cert_hier_2_intermediate_ca_cert() +) SERVER_KEY_1_PEM = resources.cert_hier_1_server_1_key() -SERVER_CERT_CHAIN_1_PEM = (resources.cert_hier_1_server_1_cert() + - resources.cert_hier_1_intermediate_ca_cert()) +SERVER_CERT_CHAIN_1_PEM = ( + resources.cert_hier_1_server_1_cert() + + resources.cert_hier_1_intermediate_ca_cert() +) SERVER_KEY_2_PEM = resources.cert_hier_2_server_1_key() -SERVER_CERT_CHAIN_2_PEM = (resources.cert_hier_2_server_1_cert() + - resources.cert_hier_2_intermediate_ca_cert()) +SERVER_CERT_CHAIN_2_PEM = ( + resources.cert_hier_2_server_1_cert() + + resources.cert_hier_2_intermediate_ca_cert() +) # for use with the CertConfigFetcher. 
Roughly a simple custom mock # implementation -Call = collections.namedtuple('Call', ['did_raise', 'returned_cert_config']) +Call = collections.namedtuple("Call", ["did_raise", "returned_cert_config"]) def _create_channel(port, credentials): - return grpc.secure_channel('localhost:{}'.format(port), credentials) + return grpc.secure_channel("localhost:{}".format(port), credentials) def _create_client_stub(channel, expect_success): @@ -82,7 +90,6 @@ def _create_client_stub(channel, expect_success): class CertConfigFetcher(object): - def __init__(self): self._lock = threading.Lock() self._calls = [] @@ -97,7 +104,8 @@ def reset(self): def configure(self, should_raise, cert_config): assert not (should_raise and cert_config), ( - "should not specify both should_raise and a cert_config at the same time" + "should not specify both should_raise and a cert_config at the same" + " time" ) with self._lock: self._should_raise = should_raise @@ -111,14 +119,13 @@ def __call__(self): with self._lock: if self._should_raise: self._calls.append(Call(True, None)) - raise ValueError('just for fun, should not affect the test') + raise ValueError("just for fun, should not affect the test") else: self._calls.append(Call(False, self._cert_config)) return self._cert_config class _ServerSSLCertReloadTest(unittest.TestCase, metaclass=abc.ABCMeta): - def __init__(self, *args, **kwargs): super(_ServerSSLCertReloadTest, self).__init__(*args, **kwargs) self.server = None @@ -131,17 +138,20 @@ def require_client_auth(self): def setUp(self): self.server = test_common.test_server() services_pb2_grpc.add_FirstServiceServicer_to_server( - _server_application.FirstServiceServicer(), self.server) + _server_application.FirstServiceServicer(), self.server + ) switch_cert_on_client_num = 10 initial_cert_config = grpc.ssl_server_certificate_configuration( [(SERVER_KEY_1_PEM, SERVER_CERT_CHAIN_1_PEM)], - root_certificates=CA_2_PEM) + root_certificates=CA_2_PEM, + ) self.cert_config_fetcher = CertConfigFetcher() server_credentials = grpc.dynamic_ssl_server_credentials( initial_cert_config, self.cert_config_fetcher, - require_client_authentication=self.require_client_auth()) - self.port = self.server.add_secure_port('[::]:0', server_credentials) + require_client_authentication=self.require_client_auth(), + ) + self.port = self.server.add_secure_port("[::]:0", server_credentials) self.server.start() def tearDown(self): @@ -165,18 +175,23 @@ def _perform_rpc(self, client_stub, expect_success): # the handshake is complete, so the TSI handshaker returns the # TSI_PROTOCOL_FAILURE result. This result does not have a # corresponding status code, so this yields an UNKNOWN status. 
- self.assertTrue(exception_context.exception.code( - ) in [grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.UNKNOWN]) - - def _do_one_shot_client_rpc(self, - expect_success, - root_certificates=None, - private_key=None, - certificate_chain=None): + self.assertTrue( + exception_context.exception.code() + in [grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.UNKNOWN] + ) + + def _do_one_shot_client_rpc( + self, + expect_success, + root_certificates=None, + private_key=None, + certificate_chain=None, + ): credentials = grpc.ssl_channel_credentials( root_certificates=root_certificates, private_key=private_key, - certificate_chain=certificate_chain) + certificate_chain=certificate_chain, + ) with _create_channel(self.port, credentials) as client_channel: client_stub = _create_client_stub(client_channel, expect_success) self._perform_rpc(client_stub, expect_success) @@ -184,10 +199,12 @@ def _do_one_shot_client_rpc(self, def _test(self): # things should work... self.cert_config_fetcher.configure(False, None) - self._do_one_shot_client_rpc(True, - root_certificates=CA_1_PEM, - private_key=CLIENT_KEY_2_PEM, - certificate_chain=CLIENT_CERT_CHAIN_2_PEM) + self._do_one_shot_client_rpc( + True, + root_certificates=CA_1_PEM, + private_key=CLIENT_KEY_2_PEM, + certificate_chain=CLIENT_CERT_CHAIN_2_PEM, + ) actual_calls = self.cert_config_fetcher.getCalls() self.assertEqual(len(actual_calls), 1) self.assertFalse(actual_calls[0].did_raise) @@ -197,24 +214,28 @@ def _test(self): # fails because client trusts ca2 and so will reject server self.cert_config_fetcher.reset() self.cert_config_fetcher.configure(False, None) - self._do_one_shot_client_rpc(False, - root_certificates=CA_2_PEM, - private_key=CLIENT_KEY_2_PEM, - certificate_chain=CLIENT_CERT_CHAIN_2_PEM) + self._do_one_shot_client_rpc( + False, + root_certificates=CA_2_PEM, + private_key=CLIENT_KEY_2_PEM, + certificate_chain=CLIENT_CERT_CHAIN_2_PEM, + ) actual_calls = self.cert_config_fetcher.getCalls() self.assertGreaterEqual(len(actual_calls), 1) self.assertFalse(actual_calls[0].did_raise) for i, call in enumerate(actual_calls): - self.assertFalse(call.did_raise, 'i= {}'.format(i)) - self.assertIsNone(call.returned_cert_config, 'i= {}'.format(i)) + self.assertFalse(call.did_raise, "i= {}".format(i)) + self.assertIsNone(call.returned_cert_config, "i= {}".format(i)) # should work again... 
self.cert_config_fetcher.reset() self.cert_config_fetcher.configure(True, None) - self._do_one_shot_client_rpc(True, - root_certificates=CA_1_PEM, - private_key=CLIENT_KEY_2_PEM, - certificate_chain=CLIENT_CERT_CHAIN_2_PEM) + self._do_one_shot_client_rpc( + True, + root_certificates=CA_1_PEM, + private_key=CLIENT_KEY_2_PEM, + certificate_chain=CLIENT_CERT_CHAIN_2_PEM, + ) actual_calls = self.cert_config_fetcher.getCalls() self.assertEqual(len(actual_calls), 1) self.assertTrue(actual_calls[0].did_raise) @@ -225,23 +246,27 @@ def _test(self): # so server will reject self.cert_config_fetcher.reset() self.cert_config_fetcher.configure(False, None) - self._do_one_shot_client_rpc(not self.require_client_auth(), - root_certificates=CA_1_PEM, - private_key=CLIENT_KEY_1_PEM, - certificate_chain=CLIENT_CERT_CHAIN_1_PEM) + self._do_one_shot_client_rpc( + not self.require_client_auth(), + root_certificates=CA_1_PEM, + private_key=CLIENT_KEY_1_PEM, + certificate_chain=CLIENT_CERT_CHAIN_1_PEM, + ) actual_calls = self.cert_config_fetcher.getCalls() self.assertGreaterEqual(len(actual_calls), 1) for i, call in enumerate(actual_calls): - self.assertFalse(call.did_raise, 'i= {}'.format(i)) - self.assertIsNone(call.returned_cert_config, 'i= {}'.format(i)) + self.assertFalse(call.did_raise, "i= {}".format(i)) + self.assertIsNone(call.returned_cert_config, "i= {}".format(i)) # should work again... self.cert_config_fetcher.reset() self.cert_config_fetcher.configure(False, None) - self._do_one_shot_client_rpc(True, - root_certificates=CA_1_PEM, - private_key=CLIENT_KEY_2_PEM, - certificate_chain=CLIENT_CERT_CHAIN_2_PEM) + self._do_one_shot_client_rpc( + True, + root_certificates=CA_1_PEM, + private_key=CLIENT_KEY_2_PEM, + certificate_chain=CLIENT_CERT_CHAIN_2_PEM, + ) actual_calls = self.cert_config_fetcher.getCalls() self.assertEqual(len(actual_calls), 1) self.assertFalse(actual_calls[0].did_raise) @@ -255,7 +280,9 @@ def _test(self): grpc.ssl_channel_credentials( root_certificates=CA_1_PEM, private_key=CLIENT_KEY_2_PEM, - certificate_chain=CLIENT_CERT_CHAIN_2_PEM)) + certificate_chain=CLIENT_CERT_CHAIN_2_PEM, + ), + ) persistent_client_stub_A = _create_client_stub(channel_A, True) self._perform_rpc(persistent_client_stub_A, True) actual_calls = self.cert_config_fetcher.getCalls() @@ -270,7 +297,9 @@ def _test(self): grpc.ssl_channel_credentials( root_certificates=CA_1_PEM, private_key=CLIENT_KEY_2_PEM, - certificate_chain=CLIENT_CERT_CHAIN_2_PEM)) + certificate_chain=CLIENT_CERT_CHAIN_2_PEM, + ), + ) persistent_client_stub_B = _create_client_stub(channel_B, True) self._perform_rpc(persistent_client_stub_B, True) actual_calls = self.cert_config_fetcher.getCalls() @@ -282,28 +311,34 @@ def _test(self): # server switch cert... 
cert_config = grpc.ssl_server_certificate_configuration( [(SERVER_KEY_2_PEM, SERVER_CERT_CHAIN_2_PEM)], - root_certificates=CA_1_PEM) + root_certificates=CA_1_PEM, + ) self.cert_config_fetcher.reset() self.cert_config_fetcher.configure(False, cert_config) - self._do_one_shot_client_rpc(False, - root_certificates=CA_1_PEM, - private_key=CLIENT_KEY_2_PEM, - certificate_chain=CLIENT_CERT_CHAIN_2_PEM) + self._do_one_shot_client_rpc( + False, + root_certificates=CA_1_PEM, + private_key=CLIENT_KEY_2_PEM, + certificate_chain=CLIENT_CERT_CHAIN_2_PEM, + ) actual_calls = self.cert_config_fetcher.getCalls() self.assertGreaterEqual(len(actual_calls), 1) self.assertFalse(actual_calls[0].did_raise) for i, call in enumerate(actual_calls): - self.assertFalse(call.did_raise, 'i= {}'.format(i)) - self.assertEqual(call.returned_cert_config, cert_config, - 'i= {}'.format(i)) + self.assertFalse(call.did_raise, "i= {}".format(i)) + self.assertEqual( + call.returned_cert_config, cert_config, "i= {}".format(i) + ) # now should work again... self.cert_config_fetcher.reset() self.cert_config_fetcher.configure(False, None) - self._do_one_shot_client_rpc(True, - root_certificates=CA_2_PEM, - private_key=CLIENT_KEY_1_PEM, - certificate_chain=CLIENT_CERT_CHAIN_1_PEM) + self._do_one_shot_client_rpc( + True, + root_certificates=CA_2_PEM, + private_key=CLIENT_KEY_1_PEM, + certificate_chain=CLIENT_CERT_CHAIN_1_PEM, + ) actual_calls = self.cert_config_fetcher.getCalls() self.assertEqual(len(actual_calls), 1) self.assertFalse(actual_calls[0].did_raise) @@ -312,28 +347,32 @@ def _test(self): # client should be rejected by server if with_client_auth self.cert_config_fetcher.reset() self.cert_config_fetcher.configure(False, None) - self._do_one_shot_client_rpc(not self.require_client_auth(), - root_certificates=CA_2_PEM, - private_key=CLIENT_KEY_2_PEM, - certificate_chain=CLIENT_CERT_CHAIN_2_PEM) + self._do_one_shot_client_rpc( + not self.require_client_auth(), + root_certificates=CA_2_PEM, + private_key=CLIENT_KEY_2_PEM, + certificate_chain=CLIENT_CERT_CHAIN_2_PEM, + ) actual_calls = self.cert_config_fetcher.getCalls() self.assertGreaterEqual(len(actual_calls), 1) for i, call in enumerate(actual_calls): - self.assertFalse(call.did_raise, 'i= {}'.format(i)) - self.assertIsNone(call.returned_cert_config, 'i= {}'.format(i)) + self.assertFalse(call.did_raise, "i= {}".format(i)) + self.assertIsNone(call.returned_cert_config, "i= {}".format(i)) # here client should reject server... 
self.cert_config_fetcher.reset() self.cert_config_fetcher.configure(False, None) - self._do_one_shot_client_rpc(False, - root_certificates=CA_1_PEM, - private_key=CLIENT_KEY_2_PEM, - certificate_chain=CLIENT_CERT_CHAIN_2_PEM) + self._do_one_shot_client_rpc( + False, + root_certificates=CA_1_PEM, + private_key=CLIENT_KEY_2_PEM, + certificate_chain=CLIENT_CERT_CHAIN_2_PEM, + ) actual_calls = self.cert_config_fetcher.getCalls() self.assertGreaterEqual(len(actual_calls), 1) for i, call in enumerate(actual_calls): - self.assertFalse(call.did_raise, 'i= {}'.format(i)) - self.assertIsNone(call.returned_cert_config, 'i= {}'.format(i)) + self.assertFalse(call.did_raise, "i= {}".format(i)) + self.assertIsNone(call.returned_cert_config, "i= {}".format(i)) # persistent clients should continue to work self.cert_config_fetcher.reset() @@ -353,7 +392,6 @@ def _test(self): class ServerSSLCertConfigFetcherParamsChecks(unittest.TestCase): - def test_check_on_initial_config(self): with self.assertRaises(TypeError): grpc.dynamic_ssl_server_credentials(None, str) @@ -363,7 +401,8 @@ def test_check_on_initial_config(self): def test_check_on_config_fetcher(self): cert_config = grpc.ssl_server_certificate_configuration( [(SERVER_KEY_2_PEM, SERVER_CERT_CHAIN_2_PEM)], - root_certificates=CA_1_PEM) + root_certificates=CA_1_PEM, + ) with self.assertRaises(TypeError): grpc.dynamic_ssl_server_credentials(cert_config, None) with self.assertRaises(TypeError): @@ -371,7 +410,6 @@ def test_check_on_config_fetcher(self): class ServerSSLCertReloadTestWithClientAuth(_ServerSSLCertReloadTest): - def require_client_auth(self): return True @@ -379,7 +417,6 @@ def require_client_auth(self): class ServerSSLCertReloadTestWithoutClientAuth(_ServerSSLCertReloadTest): - def require_client_auth(self): return False @@ -404,106 +441,127 @@ def require_client_auth(self): def setUp(self): self.server = test_common.test_server() services_pb2_grpc.add_FirstServiceServicer_to_server( - _server_application.FirstServiceServicer(), self.server) + _server_application.FirstServiceServicer(), self.server + ) self.cert_config_A = grpc.ssl_server_certificate_configuration( [(SERVER_KEY_1_PEM, SERVER_CERT_CHAIN_1_PEM)], - root_certificates=CA_2_PEM) + root_certificates=CA_2_PEM, + ) self.cert_config_B = grpc.ssl_server_certificate_configuration( [(SERVER_KEY_2_PEM, SERVER_CERT_CHAIN_2_PEM)], - root_certificates=CA_1_PEM) + root_certificates=CA_1_PEM, + ) self.cert_config_fetcher = CertConfigFetcher() server_credentials = grpc.dynamic_ssl_server_credentials( self.cert_config_A, self.cert_config_fetcher, - require_client_authentication=True) - self.port = self.server.add_secure_port('[::]:0', server_credentials) + require_client_authentication=True, + ) + self.port = self.server.add_secure_port("[::]:0", server_credentials) self.server.start() def test_cert_config_reuse(self): - # succeed with A self.cert_config_fetcher.reset() self.cert_config_fetcher.configure(False, self.cert_config_A) - self._do_one_shot_client_rpc(True, - root_certificates=CA_1_PEM, - private_key=CLIENT_KEY_2_PEM, - certificate_chain=CLIENT_CERT_CHAIN_2_PEM) + self._do_one_shot_client_rpc( + True, + root_certificates=CA_1_PEM, + private_key=CLIENT_KEY_2_PEM, + certificate_chain=CLIENT_CERT_CHAIN_2_PEM, + ) actual_calls = self.cert_config_fetcher.getCalls() self.assertEqual(len(actual_calls), 1) self.assertFalse(actual_calls[0].did_raise) - self.assertEqual(actual_calls[0].returned_cert_config, - self.cert_config_A) + self.assertEqual( + actual_calls[0].returned_cert_config, 
self.cert_config_A + ) # fail with A self.cert_config_fetcher.reset() self.cert_config_fetcher.configure(False, self.cert_config_A) - self._do_one_shot_client_rpc(False, - root_certificates=CA_2_PEM, - private_key=CLIENT_KEY_1_PEM, - certificate_chain=CLIENT_CERT_CHAIN_1_PEM) + self._do_one_shot_client_rpc( + False, + root_certificates=CA_2_PEM, + private_key=CLIENT_KEY_1_PEM, + certificate_chain=CLIENT_CERT_CHAIN_1_PEM, + ) actual_calls = self.cert_config_fetcher.getCalls() self.assertGreaterEqual(len(actual_calls), 1) self.assertFalse(actual_calls[0].did_raise) for i, call in enumerate(actual_calls): - self.assertFalse(call.did_raise, 'i= {}'.format(i)) - self.assertEqual(call.returned_cert_config, self.cert_config_A, - 'i= {}'.format(i)) + self.assertFalse(call.did_raise, "i= {}".format(i)) + self.assertEqual( + call.returned_cert_config, self.cert_config_A, "i= {}".format(i) + ) # succeed again with A self.cert_config_fetcher.reset() self.cert_config_fetcher.configure(False, self.cert_config_A) - self._do_one_shot_client_rpc(True, - root_certificates=CA_1_PEM, - private_key=CLIENT_KEY_2_PEM, - certificate_chain=CLIENT_CERT_CHAIN_2_PEM) + self._do_one_shot_client_rpc( + True, + root_certificates=CA_1_PEM, + private_key=CLIENT_KEY_2_PEM, + certificate_chain=CLIENT_CERT_CHAIN_2_PEM, + ) actual_calls = self.cert_config_fetcher.getCalls() self.assertEqual(len(actual_calls), 1) self.assertFalse(actual_calls[0].did_raise) - self.assertEqual(actual_calls[0].returned_cert_config, - self.cert_config_A) + self.assertEqual( + actual_calls[0].returned_cert_config, self.cert_config_A + ) # succeed with B self.cert_config_fetcher.reset() self.cert_config_fetcher.configure(False, self.cert_config_B) - self._do_one_shot_client_rpc(True, - root_certificates=CA_2_PEM, - private_key=CLIENT_KEY_1_PEM, - certificate_chain=CLIENT_CERT_CHAIN_1_PEM) + self._do_one_shot_client_rpc( + True, + root_certificates=CA_2_PEM, + private_key=CLIENT_KEY_1_PEM, + certificate_chain=CLIENT_CERT_CHAIN_1_PEM, + ) actual_calls = self.cert_config_fetcher.getCalls() self.assertEqual(len(actual_calls), 1) self.assertFalse(actual_calls[0].did_raise) - self.assertEqual(actual_calls[0].returned_cert_config, - self.cert_config_B) + self.assertEqual( + actual_calls[0].returned_cert_config, self.cert_config_B + ) # fail with B self.cert_config_fetcher.reset() self.cert_config_fetcher.configure(False, self.cert_config_B) - self._do_one_shot_client_rpc(False, - root_certificates=CA_1_PEM, - private_key=CLIENT_KEY_2_PEM, - certificate_chain=CLIENT_CERT_CHAIN_2_PEM) + self._do_one_shot_client_rpc( + False, + root_certificates=CA_1_PEM, + private_key=CLIENT_KEY_2_PEM, + certificate_chain=CLIENT_CERT_CHAIN_2_PEM, + ) actual_calls = self.cert_config_fetcher.getCalls() self.assertGreaterEqual(len(actual_calls), 1) self.assertFalse(actual_calls[0].did_raise) for i, call in enumerate(actual_calls): - self.assertFalse(call.did_raise, 'i= {}'.format(i)) - self.assertEqual(call.returned_cert_config, self.cert_config_B, - 'i= {}'.format(i)) + self.assertFalse(call.did_raise, "i= {}".format(i)) + self.assertEqual( + call.returned_cert_config, self.cert_config_B, "i= {}".format(i) + ) # succeed again with B self.cert_config_fetcher.reset() self.cert_config_fetcher.configure(False, self.cert_config_B) - self._do_one_shot_client_rpc(True, - root_certificates=CA_2_PEM, - private_key=CLIENT_KEY_1_PEM, - certificate_chain=CLIENT_CERT_CHAIN_1_PEM) + self._do_one_shot_client_rpc( + True, + root_certificates=CA_2_PEM, + private_key=CLIENT_KEY_1_PEM, + 
certificate_chain=CLIENT_CERT_CHAIN_1_PEM, + ) actual_calls = self.cert_config_fetcher.getCalls() self.assertEqual(len(actual_calls), 1) self.assertFalse(actual_calls[0].did_raise) - self.assertEqual(actual_calls[0].returned_cert_config, - self.cert_config_B) + self.assertEqual( + actual_calls[0].returned_cert_config, self.cert_config_B + ) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_server_test.py b/src/python/grpcio_tests/tests/unit/_server_test.py index 2cddaf4b81fb4..f1ee5addf9a8f 100644 --- a/src/python/grpcio_tests/tests/unit/_server_test.py +++ b/src/python/grpcio_tests/tests/unit/_server_test.py @@ -22,48 +22,52 @@ class _ActualGenericRpcHandler(grpc.GenericRpcHandler): - def service(self, handler_call_details): return None class ServerTest(unittest.TestCase): - def test_not_a_generic_rpc_handler_at_construction(self): with self.assertRaises(AttributeError) as exception_context: - grpc.server(futures.ThreadPoolExecutor(max_workers=5), - handlers=[ - _ActualGenericRpcHandler(), - object(), - ]) - self.assertIn('grpc.GenericRpcHandler', - str(exception_context.exception)) + grpc.server( + futures.ThreadPoolExecutor(max_workers=5), + handlers=[ + _ActualGenericRpcHandler(), + object(), + ], + ) + self.assertIn( + "grpc.GenericRpcHandler", str(exception_context.exception) + ) def test_not_a_generic_rpc_handler_after_construction(self): server = grpc.server(futures.ThreadPoolExecutor(max_workers=5)) with self.assertRaises(AttributeError) as exception_context: - server.add_generic_rpc_handlers([ - _ActualGenericRpcHandler(), - object(), - ]) - self.assertIn('grpc.GenericRpcHandler', - str(exception_context.exception)) + server.add_generic_rpc_handlers( + [ + _ActualGenericRpcHandler(), + object(), + ] + ) + self.assertIn( + "grpc.GenericRpcHandler", str(exception_context.exception) + ) def test_failed_port_binding_exception(self): - server = grpc.server(None, options=(('grpc.so_reuseport', 0),)) - port = server.add_insecure_port('localhost:0') + server = grpc.server(None, options=(("grpc.so_reuseport", 0),)) + port = server.add_insecure_port("localhost:0") bind_address = "localhost:%d" % port with self.assertRaises(RuntimeError): server.add_insecure_port(bind_address) - server_credentials = grpc.ssl_server_credentials([ - (resources.private_key(), resources.certificate_chain()) - ]) + server_credentials = grpc.ssl_server_credentials( + [(resources.private_key(), resources.certificate_chain())] + ) with self.assertRaises(RuntimeError): server.add_secure_port(bind_address, server_credentials) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_server_wait_for_termination_test.py b/src/python/grpcio_tests/tests/unit/_server_wait_for_termination_test.py index 62f6cda1be041..e317cc32c8143 100644 --- a/src/python/grpcio_tests/tests/unit/_server_wait_for_termination_test.py +++ b/src/python/grpcio_tests/tests/unit/_server_wait_for_termination_test.py @@ -34,16 +34,17 @@ def _block_on_waiting(server, termination_event, timeout=None): class ServerWaitForTerminationTest(unittest.TestCase): - def test_unblock_by_invoking_stop(self): termination_event = threading.Event() server = grpc.server(futures.ThreadPoolExecutor()) - wait_thread = threading.Thread(target=_block_on_waiting, - args=( - server, - termination_event, - )) + wait_thread = threading.Thread( + target=_block_on_waiting, + 
args=( + server, + termination_event, + ), + ) wait_thread.daemon = True wait_thread.start() time.sleep(_WAIT_FOR_BLOCKING.total_seconds()) @@ -56,11 +57,13 @@ def test_unblock_by_del(self): termination_event = threading.Event() server = grpc.server(futures.ThreadPoolExecutor()) - wait_thread = threading.Thread(target=_block_on_waiting, - args=( - server, - termination_event, - )) + wait_thread = threading.Thread( + target=_block_on_waiting, + args=( + server, + termination_event, + ), + ) wait_thread.daemon = True wait_thread.start() time.sleep(_WAIT_FOR_BLOCKING.total_seconds()) @@ -74,12 +77,14 @@ def test_unblock_by_timeout(self): termination_event = threading.Event() server = grpc.server(futures.ThreadPoolExecutor()) - wait_thread = threading.Thread(target=_block_on_waiting, - args=( - server, - termination_event, - test_constants.SHORT_TIMEOUT / 2, - )) + wait_thread = threading.Thread( + target=_block_on_waiting, + args=( + server, + termination_event, + test_constants.SHORT_TIMEOUT / 2, + ), + ) wait_thread.daemon = True wait_thread.start() @@ -87,5 +92,5 @@ def test_unblock_by_timeout(self): self.assertTrue(termination_event.is_set()) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_session_cache_test.py b/src/python/grpcio_tests/tests/unit/_session_cache_test.py index 6091219126394..acf671d1ca98e 100644 --- a/src/python/grpcio_tests/tests/unit/_session_cache_test.py +++ b/src/python/grpcio_tests/tests/unit/_session_cache_test.py @@ -24,58 +24,65 @@ from tests.unit import resources from tests.unit import test_common -_REQUEST = b'\x00\x00\x00' -_RESPONSE = b'\x00\x00\x00' +_REQUEST = b"\x00\x00\x00" +_RESPONSE = b"\x00\x00\x00" -_UNARY_UNARY = '/test/UnaryUnary' +_UNARY_UNARY = "/test/UnaryUnary" -_SERVER_HOST_OVERRIDE = 'foo.test.google.fr' -_ID = 'id' -_ID_KEY = 'id_key' -_AUTH_CTX = 'auth_ctx' +_SERVER_HOST_OVERRIDE = "foo.test.google.fr" +_ID = "id" +_ID_KEY = "id_key" +_AUTH_CTX = "auth_ctx" _PRIVATE_KEY = resources.private_key() _CERTIFICATE_CHAIN = resources.certificate_chain() _TEST_ROOT_CERTIFICATES = resources.test_root_certificates() _SERVER_CERTS = ((_PRIVATE_KEY, _CERTIFICATE_CHAIN),) -_PROPERTY_OPTIONS = (( - 'grpc.ssl_target_name_override', - _SERVER_HOST_OVERRIDE, -),) +_PROPERTY_OPTIONS = ( + ( + "grpc.ssl_target_name_override", + _SERVER_HOST_OVERRIDE, + ), +) def handle_unary_unary(request, servicer_context): - return pickle.dumps({ - _ID: servicer_context.peer_identities(), - _ID_KEY: servicer_context.peer_identity_key(), - _AUTH_CTX: servicer_context.auth_context() - }) + return pickle.dumps( + { + _ID: servicer_context.peer_identities(), + _ID_KEY: servicer_context.peer_identity_key(), + _AUTH_CTX: servicer_context.auth_context(), + } + ) def start_secure_server(): handler = grpc.method_handlers_generic_handler( - 'test', - {'UnaryUnary': grpc.unary_unary_rpc_method_handler(handle_unary_unary)}) + "test", + {"UnaryUnary": grpc.unary_unary_rpc_method_handler(handle_unary_unary)}, + ) server = test_common.test_server() server.add_generic_rpc_handlers((handler,)) server_cred = grpc.ssl_server_credentials(_SERVER_CERTS) - port = server.add_secure_port('[::]:0', server_cred) + port = server.add_secure_port("[::]:0", server_cred) server.start() return server, port class SSLSessionCacheTest(unittest.TestCase): - - def _do_one_shot_client_rpc(self, channel_creds, channel_options, port, - expect_ssl_session_reused): - channel = grpc.secure_channel('localhost:{}'.format(port), - 
channel_creds, - options=channel_options) + def _do_one_shot_client_rpc( + self, channel_creds, channel_options, port, expect_ssl_session_reused + ): + channel = grpc.secure_channel( + "localhost:{}".format(port), channel_creds, options=channel_options + ) response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) auth_data = pickle.loads(response) - self.assertEqual(expect_ssl_session_reused, - auth_data[_AUTH_CTX]['ssl_session_reused']) + self.assertEqual( + expect_ssl_session_reused, + auth_data[_AUTH_CTX]["ssl_session_reused"], + ) channel.close() def testSSLSessionCacheLRU(self): @@ -83,58 +90,74 @@ def testSSLSessionCacheLRU(self): cache = session_cache.ssl_session_cache_lru(1) channel_creds = grpc.ssl_channel_credentials( - root_certificates=_TEST_ROOT_CERTIFICATES) + root_certificates=_TEST_ROOT_CERTIFICATES + ) channel_options = _PROPERTY_OPTIONS + ( - ('grpc.ssl_session_cache', cache),) + ("grpc.ssl_session_cache", cache), + ) # Initial connection has no session to resume - self._do_one_shot_client_rpc(channel_creds, - channel_options, - port_1, - expect_ssl_session_reused=[b'false']) + self._do_one_shot_client_rpc( + channel_creds, + channel_options, + port_1, + expect_ssl_session_reused=[b"false"], + ) # Connection to server_1 resumes from initial session - self._do_one_shot_client_rpc(channel_creds, - channel_options, - port_1, - expect_ssl_session_reused=[b'true']) + self._do_one_shot_client_rpc( + channel_creds, + channel_options, + port_1, + expect_ssl_session_reused=[b"true"], + ) # Connection to a different server with the same name overwrites the cache entry server_2, port_2 = start_secure_server() - self._do_one_shot_client_rpc(channel_creds, - channel_options, - port_2, - expect_ssl_session_reused=[b'false']) - self._do_one_shot_client_rpc(channel_creds, - channel_options, - port_2, - expect_ssl_session_reused=[b'true']) + self._do_one_shot_client_rpc( + channel_creds, + channel_options, + port_2, + expect_ssl_session_reused=[b"false"], + ) + self._do_one_shot_client_rpc( + channel_creds, + channel_options, + port_2, + expect_ssl_session_reused=[b"true"], + ) server_2.stop(None) # Connection to server_1 now falls back to full TLS handshake - self._do_one_shot_client_rpc(channel_creds, - channel_options, - port_1, - expect_ssl_session_reused=[b'false']) + self._do_one_shot_client_rpc( + channel_creds, + channel_options, + port_1, + expect_ssl_session_reused=[b"false"], + ) # Re-creating server_1 causes old sessions to become invalid server_1.stop(None) server_1, port_1 = start_secure_server() # Old sessions should no longer be valid - self._do_one_shot_client_rpc(channel_creds, - channel_options, - port_1, - expect_ssl_session_reused=[b'false']) + self._do_one_shot_client_rpc( + channel_creds, + channel_options, + port_1, + expect_ssl_session_reused=[b"false"], + ) # Resumption should work for subsequent connections - self._do_one_shot_client_rpc(channel_creds, - channel_options, - port_1, - expect_ssl_session_reused=[b'true']) + self._do_one_shot_client_rpc( + channel_creds, + channel_options, + port_1, + expect_ssl_session_reused=[b"true"], + ) server_1.stop(None) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_signal_client.py b/src/python/grpcio_tests/tests/unit/_signal_client.py index eac83b1844ad4..56563c2007502 100644 --- a/src/python/grpcio_tests/tests/unit/_signal_client.py +++ b/src/python/grpcio_tests/tests/unit/_signal_client.py @@ -28,7 +28,7 @@ UNARY_UNARY 
= "/test/Unary" UNARY_STREAM = "/test/ServerStreaming" -_MESSAGE = b'\x00\x00\x00' +_MESSAGE = b"\x00\x00\x00" _ASSERTION_MESSAGE = "Control flow should never reach here." @@ -55,8 +55,9 @@ def main_unary(server_target): with grpc.insecure_channel(server_target) as channel: multicallable = channel.unary_unary(UNARY_UNARY) signal.signal(signal.SIGINT, handle_sigint) - per_process_rpc_future = multicallable.future(_MESSAGE, - wait_for_ready=True) + per_process_rpc_future = multicallable.future( + _MESSAGE, wait_for_ready=True + ) result = per_process_rpc_future.result() assert False, _ASSERTION_MESSAGE @@ -67,7 +68,8 @@ def main_streaming(server_target): with grpc.insecure_channel(server_target) as channel: signal.signal(signal.SIGINT, handle_sigint) per_process_rpc_future = channel.unary_stream(UNARY_STREAM)( - _MESSAGE, wait_for_ready=True) + _MESSAGE, wait_for_ready=True + ) for result in per_process_rpc_future: pass assert False, _ASSERTION_MESSAGE @@ -90,8 +92,9 @@ def main_streaming_with_exception(server_target): """Initiate a streaming RPC with a signal handler that will raise.""" channel = grpc.insecure_channel(server_target) try: - for _ in channel.unary_stream(UNARY_STREAM)(_MESSAGE, - wait_for_ready=True): + for _ in channel.unary_stream(UNARY_STREAM)( + _MESSAGE, wait_for_ready=True + ): pass except KeyboardInterrupt: sys.stderr.write("Running signal handler.\n") @@ -101,16 +104,18 @@ def main_streaming_with_exception(server_target): channel.close() -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Signal test client.') - parser.add_argument('server', help='Server target') - parser.add_argument('arity', help='Arity', choices=('unary', 'streaming')) - parser.add_argument('--exception', - help='Whether the signal throws an exception', - action='store_true') - parser.add_argument('--gevent', - help='Whether to run under gevent.', - action='store_true') +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Signal test client.") + parser.add_argument("server", help="Server target") + parser.add_argument("arity", help="Arity", choices=("unary", "streaming")) + parser.add_argument( + "--exception", + help="Whether the signal throws an exception", + action="store_true", + ) + parser.add_argument( + "--gevent", help="Whether to run under gevent.", action="store_true" + ) args = parser.parse_args() if args.gevent: from gevent import monkey @@ -119,13 +124,14 @@ def main_streaming_with_exception(server_target): monkey.patch_all() import grpc.experimental.gevent + grpc.experimental.gevent.init_gevent() - if args.arity == 'unary' and not args.exception: + if args.arity == "unary" and not args.exception: main_unary(args.server) - elif args.arity == 'streaming' and not args.exception: + elif args.arity == "streaming" and not args.exception: main_streaming(args.server) - elif args.arity == 'unary' and args.exception: + elif args.arity == "unary" and args.exception: main_unary_with_exception(args.server) else: main_streaming_with_exception(args.server) diff --git a/src/python/grpcio_tests/tests/unit/_signal_handling_test.py b/src/python/grpcio_tests/tests/unit/_signal_handling_test.py index df38159d54683..b16261d63e20a 100644 --- a/src/python/grpcio_tests/tests/unit/_signal_handling_test.py +++ b/src/python/grpcio_tests/tests/unit/_signal_handling_test.py @@ -37,9 +37,10 @@ client_name = sys.argv[1].split("/")[-1] del sys.argv[1] # For compatibility with test runner. 
_CLIENT_PATH = os.path.realpath( - os.path.join(os.path.dirname(os.path.abspath(__file__)), client_name)) + os.path.join(os.path.dirname(os.path.abspath(__file__)), client_name) + ) -_HOST = 'localhost' +_HOST = "localhost" # The gevent test harness cannot run the monkeypatch code for the child process, # so we need to instrument it manually. @@ -47,16 +48,17 @@ class _GenericHandler(grpc.GenericRpcHandler): - def __init__(self): self._connected_clients_lock = threading.RLock() self._connected_clients_event = threading.Event() self._connected_clients = 0 self._unary_unary_handler = grpc.unary_unary_rpc_method_handler( - self._handle_unary_unary) + self._handle_unary_unary + ) self._unary_stream_handler = grpc.unary_stream_rpc_method_handler( - self._handle_unary_stream) + self._handle_unary_stream + ) def _on_client_connect(self): with self._connected_clients_lock: @@ -129,10 +131,9 @@ def _start_client(args, stdout, stderr): class SignalHandlingTest(unittest.TestCase): - def setUp(self): self._server = test_common.test_server() - self._port = self._server.add_insecure_port('{}:0'.format(_HOST)) + self._port = self._server.add_insecure_port("{}:0".format(_HOST)) self._handler = _GenericHandler() self._server.add_generic_rpc_handlers((self._handler,)) self._server.start() @@ -140,58 +141,69 @@ def setUp(self): def tearDown(self): self._server.stop(None) - @unittest.skipIf(os.name == 'nt', 'SIGINT not supported on windows') + @unittest.skipIf(os.name == "nt", "SIGINT not supported on windows") def testUnary(self): """Tests that the server unary code path does not stall signal handlers.""" - server_target = '{}:{}'.format(_HOST, self._port) - with tempfile.TemporaryFile(mode='r') as client_stdout: - with tempfile.TemporaryFile(mode='r') as client_stderr: - client = _start_client((server_target, 'unary') + _GEVENT_ARG, - client_stdout, client_stderr) + server_target = "{}:{}".format(_HOST, self._port) + with tempfile.TemporaryFile(mode="r") as client_stdout: + with tempfile.TemporaryFile(mode="r") as client_stderr: + client = _start_client( + (server_target, "unary") + _GEVENT_ARG, + client_stdout, + client_stderr, + ) self._handler.await_connected_client() client.send_signal(signal.SIGINT) self.assertFalse(client.wait(), msg=_read_stream(client_stderr)) client_stdout.seek(0) - self.assertIn(_signal_client.SIGTERM_MESSAGE, - client_stdout.read()) + self.assertIn( + _signal_client.SIGTERM_MESSAGE, client_stdout.read() + ) - @unittest.skipIf(os.name == 'nt', 'SIGINT not supported on windows') + @unittest.skipIf(os.name == "nt", "SIGINT not supported on windows") def testStreaming(self): """Tests that the server streaming code path does not stall signal handlers.""" - server_target = '{}:{}'.format(_HOST, self._port) - with tempfile.TemporaryFile(mode='r') as client_stdout: - with tempfile.TemporaryFile(mode='r') as client_stderr: + server_target = "{}:{}".format(_HOST, self._port) + with tempfile.TemporaryFile(mode="r") as client_stdout: + with tempfile.TemporaryFile(mode="r") as client_stderr: client = _start_client( - (server_target, 'streaming') + _GEVENT_ARG, client_stdout, - client_stderr) + (server_target, "streaming") + _GEVENT_ARG, + client_stdout, + client_stderr, + ) self._handler.await_connected_client() client.send_signal(signal.SIGINT) self.assertFalse(client.wait(), msg=_read_stream(client_stderr)) client_stdout.seek(0) - self.assertIn(_signal_client.SIGTERM_MESSAGE, - client_stdout.read()) + self.assertIn( + _signal_client.SIGTERM_MESSAGE, client_stdout.read() + ) - 
@unittest.skipIf(os.name == 'nt', 'SIGINT not supported on windows') + @unittest.skipIf(os.name == "nt", "SIGINT not supported on windows") def testUnaryWithException(self): - server_target = '{}:{}'.format(_HOST, self._port) - with tempfile.TemporaryFile(mode='r') as client_stdout: - with tempfile.TemporaryFile(mode='r') as client_stderr: + server_target = "{}:{}".format(_HOST, self._port) + with tempfile.TemporaryFile(mode="r") as client_stdout: + with tempfile.TemporaryFile(mode="r") as client_stderr: client = _start_client( - ('--exception', server_target, 'unary') + _GEVENT_ARG, - client_stdout, client_stderr) + ("--exception", server_target, "unary") + _GEVENT_ARG, + client_stdout, + client_stderr, + ) self._handler.await_connected_client() client.send_signal(signal.SIGINT) client.wait() self.assertEqual(0, client.returncode) - @unittest.skipIf(os.name == 'nt', 'SIGINT not supported on windows') + @unittest.skipIf(os.name == "nt", "SIGINT not supported on windows") def testStreamingHandlerWithException(self): - server_target = '{}:{}'.format(_HOST, self._port) - with tempfile.TemporaryFile(mode='r') as client_stdout: - with tempfile.TemporaryFile(mode='r') as client_stderr: + server_target = "{}:{}".format(_HOST, self._port) + with tempfile.TemporaryFile(mode="r") as client_stdout: + with tempfile.TemporaryFile(mode="r") as client_stderr: client = _start_client( - ('--exception', server_target, 'streaming') + _GEVENT_ARG, - client_stdout, client_stderr) + ("--exception", server_target, "streaming") + _GEVENT_ARG, + client_stdout, + client_stderr, + ) self._handler.await_connected_client() client.send_signal(signal.SIGINT) client.wait() @@ -199,6 +211,6 @@ def testStreamingHandlerWithException(self): self.assertEqual(0, client.returncode) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_tcp_proxy.py b/src/python/grpcio_tests/tests/unit/_tcp_proxy.py index 84dc0e2d6cf8e..781ba970e7195 100644 --- a/src/python/grpcio_tests/tests/unit/_tcp_proxy.py +++ b/src/python/grpcio_tests/tests/unit/_tcp_proxy.py @@ -65,9 +65,11 @@ def __init__(self, bind_address, gateway_address, gateway_port): def start(self): _, self._port, self._listen_socket = get_socket( - bind_address=self._bind_address) - self._proxy_socket = _init_proxy_socket(self._gateway_address, - self._gateway_port) + bind_address=self._bind_address + ) + self._proxy_socket = _init_proxy_socket( + self._gateway_address, self._gateway_port + ) self._thread.start() def get_port(self): @@ -92,7 +94,7 @@ def _handle_reads(self, sockets_to_read): else: self._client_sockets.remove(socket_to_read) else: - raise RuntimeError('Unidentified socket appeared in read set.') + raise RuntimeError("Unidentified socket appeared in read set.") def _handle_writes(self, sockets_to_write): for socket_to_write in sockets_to_write: @@ -108,11 +110,15 @@ def _handle_writes(self, sockets_to_write): def _run_proxy(self): while not self._stop_event.is_set(): expected_reads = (self._listen_socket, self._proxy_socket) + tuple( - self._client_sockets) + self._client_sockets + ) expected_writes = expected_reads sockets_to_read, sockets_to_write, _ = select.select( - expected_reads, expected_writes, (), - _TCP_PROXY_TIMEOUT.total_seconds()) + expected_reads, + expected_writes, + (), + _TCP_PROXY_TIMEOUT.total_seconds(), + ) self._handle_reads(sockets_to_read) self._handle_writes(sockets_to_write) for client_socket in self._client_sockets: diff --git 
a/src/python/grpcio_tests/tests/unit/_version_test.py b/src/python/grpcio_tests/tests/unit/_version_test.py index a81e51e56c5ac..7bf8cf4bf37f6 100644 --- a/src/python/grpcio_tests/tests/unit/_version_test.py +++ b/src/python/grpcio_tests/tests/unit/_version_test.py @@ -21,11 +21,10 @@ class VersionTest(unittest.TestCase): - def test_get_version(self): self.assertEqual(grpc.__version__, _grpcio_metadata.__version__) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_xds_credentials_test.py b/src/python/grpcio_tests/tests/unit/_xds_credentials_test.py index 64594061c01ba..977d564888dee 100644 --- a/src/python/grpcio_tests/tests/unit/_xds_credentials_test.py +++ b/src/python/grpcio_tests/tests/unit/_xds_credentials_test.py @@ -26,10 +26,10 @@ class _GenericHandler(grpc.GenericRpcHandler): - def service(self, handler_call_details): return grpc.unary_unary_rpc_method_handler( - lambda request, unused_context: request) + lambda request, unused_context: request + ) @contextlib.contextmanager @@ -37,7 +37,8 @@ def xds_channel_server_without_xds(server_fallback_creds): server = grpc.server(futures.ThreadPoolExecutor()) server.add_generic_rpc_handlers((_GenericHandler(),)) server_server_fallback_creds = grpc.ssl_server_credentials( - ((resources.private_key(), resources.certificate_chain()),)) + ((resources.private_key(), resources.certificate_chain()),) + ) server_creds = grpc.xds_server_credentials(server_fallback_creds) port = server.add_secure_port("localhost:0", server_creds) server.start() @@ -48,27 +49,31 @@ def xds_channel_server_without_xds(server_fallback_creds): class XdsCredentialsTest(unittest.TestCase): - def test_xds_creds_fallback_ssl(self): # Since there is no xDS server, the fallback credentials will be used. # In this case, SSL credentials. server_fallback_creds = grpc.ssl_server_credentials( - ((resources.private_key(), resources.certificate_chain()),)) + ((resources.private_key(), resources.certificate_chain()),) + ) with xds_channel_server_without_xds( - server_fallback_creds) as server_address: - override_options = (("grpc.ssl_target_name_override", - "foo.test.google.fr"),) + server_fallback_creds + ) as server_address: + override_options = ( + ("grpc.ssl_target_name_override", "foo.test.google.fr"), + ) channel_fallback_creds = grpc.ssl_channel_credentials( root_certificates=resources.test_root_certificates(), private_key=resources.private_key(), - certificate_chain=resources.certificate_chain()) + certificate_chain=resources.certificate_chain(), + ) channel_creds = grpc.xds_channel_credentials(channel_fallback_creds) - with grpc.secure_channel(server_address, - channel_creds, - options=override_options) as channel: + with grpc.secure_channel( + server_address, channel_creds, options=override_options + ) as channel: request = b"abc" response = channel.unary_unary("/test/method")( - request, wait_for_ready=True) + request, wait_for_ready=True + ) self.assertEqual(response, request) def test_xds_creds_fallback_insecure(self): @@ -76,14 +81,17 @@ def test_xds_creds_fallback_insecure(self): # In this case, insecure. 
server_fallback_creds = grpc.insecure_server_credentials() with xds_channel_server_without_xds( - server_fallback_creds) as server_address: - channel_fallback_creds = grpc.experimental.insecure_channel_credentials( + server_fallback_creds + ) as server_address: + channel_fallback_creds = ( + grpc.experimental.insecure_channel_credentials() ) channel_creds = grpc.xds_channel_credentials(channel_fallback_creds) with grpc.secure_channel(server_address, channel_creds) as channel: request = b"abc" response = channel.unary_unary("/test/method")( - request, wait_for_ready=True) + request, wait_for_ready=True + ) self.assertEqual(response, request) def test_start_xds_server(self): diff --git a/src/python/grpcio_tests/tests/unit/beta/_beta_features_test.py b/src/python/grpcio_tests/tests/unit/beta/_beta_features_test.py index 0ce3a127392a1..97f424bcbbdfb 100644 --- a/src/python/grpcio_tests/tests/unit/beta/_beta_features_test.py +++ b/src/python/grpcio_tests/tests/unit/beta/_beta_features_test.py @@ -25,23 +25,22 @@ from tests.unit.beta import test_utilities from tests.unit.framework.common import test_constants -_SERVER_HOST_OVERRIDE = 'foo.test.google.fr' +_SERVER_HOST_OVERRIDE = "foo.test.google.fr" -_PER_RPC_CREDENTIALS_METADATA_KEY = b'my-call-credentials-metadata-key' -_PER_RPC_CREDENTIALS_METADATA_VALUE = b'my-call-credentials-metadata-value' +_PER_RPC_CREDENTIALS_METADATA_KEY = b"my-call-credentials-metadata-key" +_PER_RPC_CREDENTIALS_METADATA_VALUE = b"my-call-credentials-metadata-value" -_GROUP = 'group' -_UNARY_UNARY = 'unary-unary' -_UNARY_STREAM = 'unary-stream' -_STREAM_UNARY = 'stream-unary' -_STREAM_STREAM = 'stream-stream' +_GROUP = "group" +_UNARY_UNARY = "unary-unary" +_UNARY_STREAM = "unary-stream" +_STREAM_UNARY = "stream-unary" +_STREAM_STREAM = "stream-stream" -_REQUEST = b'abc' -_RESPONSE = b'123' +_REQUEST = b"abc" +_RESPONSE = b"123" class _Servicer(object): - def __init__(self): self._condition = threading.Condition() self._peer = None @@ -101,7 +100,6 @@ def block_until_serviced(self): class _BlockingIterator(object): - def __init__(self, upstream): self._condition = threading.Condition() self._upstream = upstream @@ -133,24 +131,33 @@ def allow(self): def _metadata_plugin(context, callback): - callback([ - (_PER_RPC_CREDENTIALS_METADATA_KEY, _PER_RPC_CREDENTIALS_METADATA_VALUE) - ], None) + callback( + [ + ( + _PER_RPC_CREDENTIALS_METADATA_KEY, + _PER_RPC_CREDENTIALS_METADATA_VALUE, + ) + ], + None, + ) class BetaFeaturesTest(unittest.TestCase): - def setUp(self): self._servicer = _Servicer() method_implementations = { - (_GROUP, _UNARY_UNARY): - utilities.unary_unary_inline(self._servicer.unary_unary), - (_GROUP, _UNARY_STREAM): - utilities.unary_stream_inline(self._servicer.unary_stream), - (_GROUP, _STREAM_UNARY): - utilities.stream_unary_inline(self._servicer.stream_unary), - (_GROUP, _STREAM_STREAM): - utilities.stream_stream_inline(self._servicer.stream_stream), + (_GROUP, _UNARY_UNARY): utilities.unary_unary_inline( + self._servicer.unary_unary + ), + (_GROUP, _UNARY_STREAM): utilities.unary_stream_inline( + self._servicer.unary_stream + ), + (_GROUP, _STREAM_UNARY): utilities.stream_unary_inline( + self._servicer.stream_unary + ), + (_GROUP, _STREAM_STREAM): utilities.stream_stream_inline( + self._servicer.stream_stream + ), } cardinalities = { @@ -161,29 +168,36 @@ def setUp(self): } server_options = implementations.server_options( - thread_pool_size=test_constants.POOL_SIZE) - self._server = implementations.server(method_implementations, - 
options=server_options) - server_credentials = implementations.ssl_server_credentials([ - ( - resources.private_key(), - resources.certificate_chain(), - ), - ]) - port = self._server.add_secure_port('[::]:0', server_credentials) + thread_pool_size=test_constants.POOL_SIZE + ) + self._server = implementations.server( + method_implementations, options=server_options + ) + server_credentials = implementations.ssl_server_credentials( + [ + ( + resources.private_key(), + resources.certificate_chain(), + ), + ] + ) + port = self._server.add_secure_port("[::]:0", server_credentials) self._server.start() self._channel_credentials = implementations.ssl_channel_credentials( - resources.test_root_certificates()) + resources.test_root_certificates() + ) self._call_credentials = implementations.metadata_call_credentials( - _metadata_plugin) + _metadata_plugin + ) channel = test_utilities.not_really_secure_channel( - 'localhost', port, self._channel_credentials, _SERVER_HOST_OVERRIDE) + "localhost", port, self._channel_credentials, _SERVER_HOST_OVERRIDE + ) stub_options = implementations.stub_options( - thread_pool_size=test_constants.POOL_SIZE) - self._dynamic_stub = implementations.dynamic_stub(channel, - _GROUP, - cardinalities, - options=stub_options) + thread_pool_size=test_constants.POOL_SIZE + ) + self._dynamic_stub = implementations.dynamic_stub( + channel, _GROUP, cardinalities, options=stub_options + ) def tearDown(self): self._dynamic_stub = None @@ -191,46 +205,56 @@ def tearDown(self): def test_unary_unary(self): call_options = interfaces.grpc_call_options( - disable_compression=True, credentials=self._call_credentials) - response = getattr(self._dynamic_stub, - _UNARY_UNARY)(_REQUEST, - test_constants.LONG_TIMEOUT, - protocol_options=call_options) + disable_compression=True, credentials=self._call_credentials + ) + response = getattr(self._dynamic_stub, _UNARY_UNARY)( + _REQUEST, test_constants.LONG_TIMEOUT, protocol_options=call_options + ) self.assertEqual(_RESPONSE, response) self.assertIsNotNone(self._servicer.peer()) invocation_metadata = [ (metadatum.key, metadatum.value) for metadatum in self._servicer._invocation_metadata ] - self.assertIn((_PER_RPC_CREDENTIALS_METADATA_KEY, - _PER_RPC_CREDENTIALS_METADATA_VALUE), - invocation_metadata) + self.assertIn( + ( + _PER_RPC_CREDENTIALS_METADATA_KEY, + _PER_RPC_CREDENTIALS_METADATA_VALUE, + ), + invocation_metadata, + ) def test_unary_stream(self): call_options = interfaces.grpc_call_options( - disable_compression=True, credentials=self._call_credentials) + disable_compression=True, credentials=self._call_credentials + ) response_iterator = getattr(self._dynamic_stub, _UNARY_STREAM)( - _REQUEST, - test_constants.LONG_TIMEOUT, - protocol_options=call_options) + _REQUEST, test_constants.LONG_TIMEOUT, protocol_options=call_options + ) self._servicer.block_until_serviced() self.assertIsNotNone(self._servicer.peer()) invocation_metadata = [ (metadatum.key, metadatum.value) for metadatum in self._servicer._invocation_metadata ] - self.assertIn((_PER_RPC_CREDENTIALS_METADATA_KEY, - _PER_RPC_CREDENTIALS_METADATA_VALUE), - invocation_metadata) + self.assertIn( + ( + _PER_RPC_CREDENTIALS_METADATA_KEY, + _PER_RPC_CREDENTIALS_METADATA_VALUE, + ), + invocation_metadata, + ) def test_stream_unary(self): call_options = interfaces.grpc_call_options( - credentials=self._call_credentials) + credentials=self._call_credentials + ) request_iterator = _BlockingIterator(iter((_REQUEST,))) response_future = getattr(self._dynamic_stub, _STREAM_UNARY).future( 
request_iterator, test_constants.LONG_TIMEOUT, - protocol_options=call_options) + protocol_options=call_options, + ) response_future.protocol_context().disable_next_request_compression() request_iterator.allow() response_future.protocol_context().disable_next_request_compression() @@ -242,18 +266,24 @@ def test_stream_unary(self): (metadatum.key, metadatum.value) for metadatum in self._servicer._invocation_metadata ] - self.assertIn((_PER_RPC_CREDENTIALS_METADATA_KEY, - _PER_RPC_CREDENTIALS_METADATA_VALUE), - invocation_metadata) + self.assertIn( + ( + _PER_RPC_CREDENTIALS_METADATA_KEY, + _PER_RPC_CREDENTIALS_METADATA_VALUE, + ), + invocation_metadata, + ) def test_stream_stream(self): call_options = interfaces.grpc_call_options( - credentials=self._call_credentials) + credentials=self._call_credentials + ) request_iterator = _BlockingIterator(iter((_REQUEST,))) response_iterator = getattr(self._dynamic_stub, _STREAM_STREAM)( request_iterator, test_constants.SHORT_TIMEOUT, - protocol_options=call_options) + protocol_options=call_options, + ) response_iterator.protocol_context().disable_next_request_compression() request_iterator.allow() response = next(response_iterator) @@ -266,24 +296,31 @@ def test_stream_stream(self): (metadatum.key, metadatum.value) for metadatum in self._servicer._invocation_metadata ] - self.assertIn((_PER_RPC_CREDENTIALS_METADATA_KEY, - _PER_RPC_CREDENTIALS_METADATA_VALUE), - invocation_metadata) + self.assertIn( + ( + _PER_RPC_CREDENTIALS_METADATA_KEY, + _PER_RPC_CREDENTIALS_METADATA_VALUE, + ), + invocation_metadata, + ) class ContextManagementAndLifecycleTest(unittest.TestCase): - def setUp(self): self._servicer = _Servicer() self._method_implementations = { - (_GROUP, _UNARY_UNARY): - utilities.unary_unary_inline(self._servicer.unary_unary), - (_GROUP, _UNARY_STREAM): - utilities.unary_stream_inline(self._servicer.unary_stream), - (_GROUP, _STREAM_UNARY): - utilities.stream_unary_inline(self._servicer.stream_unary), - (_GROUP, _STREAM_STREAM): - utilities.stream_stream_inline(self._servicer.stream_stream), + (_GROUP, _UNARY_UNARY): utilities.unary_unary_inline( + self._servicer.unary_unary + ), + (_GROUP, _UNARY_STREAM): utilities.unary_stream_inline( + self._servicer.unary_stream + ), + (_GROUP, _STREAM_UNARY): utilities.stream_unary_inline( + self._servicer.stream_unary + ), + (_GROUP, _STREAM_STREAM): utilities.stream_stream_inline( + self._servicer.stream_stream + ), } self._cardinalities = { @@ -294,41 +331,49 @@ def setUp(self): } self._server_options = implementations.server_options( - thread_pool_size=test_constants.POOL_SIZE) - self._server_credentials = implementations.ssl_server_credentials([ - ( - resources.private_key(), - resources.certificate_chain(), - ), - ]) + thread_pool_size=test_constants.POOL_SIZE + ) + self._server_credentials = implementations.ssl_server_credentials( + [ + ( + resources.private_key(), + resources.certificate_chain(), + ), + ] + ) self._channel_credentials = implementations.ssl_channel_credentials( - resources.test_root_certificates()) + resources.test_root_certificates() + ) self._stub_options = implementations.stub_options( - thread_pool_size=test_constants.POOL_SIZE) + thread_pool_size=test_constants.POOL_SIZE + ) def test_stub_context(self): - server = implementations.server(self._method_implementations, - options=self._server_options) - port = server.add_secure_port('[::]:0', self._server_credentials) + server = implementations.server( + self._method_implementations, options=self._server_options + ) + port = 
server.add_secure_port("[::]:0", self._server_credentials) server.start() channel = test_utilities.not_really_secure_channel( - 'localhost', port, self._channel_credentials, _SERVER_HOST_OVERRIDE) - dynamic_stub = implementations.dynamic_stub(channel, - _GROUP, - self._cardinalities, - options=self._stub_options) + "localhost", port, self._channel_credentials, _SERVER_HOST_OVERRIDE + ) + dynamic_stub = implementations.dynamic_stub( + channel, _GROUP, self._cardinalities, options=self._stub_options + ) for _ in range(100): with dynamic_stub: pass for _ in range(10): with dynamic_stub: call_options = interfaces.grpc_call_options( - disable_compression=True) - response = getattr(dynamic_stub, - _UNARY_UNARY)(_REQUEST, - test_constants.LONG_TIMEOUT, - protocol_options=call_options) + disable_compression=True + ) + response = getattr(dynamic_stub, _UNARY_UNARY)( + _REQUEST, + test_constants.LONG_TIMEOUT, + protocol_options=call_options, + ) self.assertEqual(_RESPONSE, response) self.assertIsNotNone(self._servicer.peer()) @@ -336,20 +381,22 @@ def test_stub_context(self): def test_server_lifecycle(self): for _ in range(100): - server = implementations.server(self._method_implementations, - options=self._server_options) - port = server.add_secure_port('[::]:0', self._server_credentials) + server = implementations.server( + self._method_implementations, options=self._server_options + ) + port = server.add_secure_port("[::]:0", self._server_credentials) server.start() server.stop(test_constants.SHORT_TIMEOUT).wait() for _ in range(100): - server = implementations.server(self._method_implementations, - options=self._server_options) - server.add_secure_port('[::]:0', self._server_credentials) - server.add_insecure_port('[::]:0') + server = implementations.server( + self._method_implementations, options=self._server_options + ) + server.add_secure_port("[::]:0", self._server_credentials) + server.add_insecure_port("[::]:0") with server: server.stop(test_constants.SHORT_TIMEOUT) server.stop(test_constants.SHORT_TIMEOUT) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/beta/_connectivity_channel_test.py b/src/python/grpcio_tests/tests/unit/beta/_connectivity_channel_test.py index 1416902eab8e3..bdff2a52ecbd2 100644 --- a/src/python/grpcio_tests/tests/unit/beta/_connectivity_channel_test.py +++ b/src/python/grpcio_tests/tests/unit/beta/_connectivity_channel_test.py @@ -19,7 +19,6 @@ class ConnectivityStatesTest(unittest.TestCase): - def testBetaConnectivityStates(self): self.assertIsNotNone(interfaces.ChannelConnectivity.IDLE) self.assertIsNotNone(interfaces.ChannelConnectivity.CONNECTING) @@ -28,5 +27,5 @@ def testBetaConnectivityStates(self): self.assertIsNotNone(interfaces.ChannelConnectivity.FATAL_FAILURE) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/beta/_implementations_test.py b/src/python/grpcio_tests/tests/unit/beta/_implementations_test.py index cc2a2ea0a87c8..d9ab2568fd6c8 100644 --- a/src/python/grpcio_tests/tests/unit/beta/_implementations_test.py +++ b/src/python/grpcio_tests/tests/unit/beta/_implementations_test.py @@ -23,33 +23,39 @@ class ChannelCredentialsTest(unittest.TestCase): - def test_runtime_provided_root_certificates(self): channel_credentials = implementations.ssl_channel_credentials() - self.assertIsInstance(channel_credentials, - implementations.ChannelCredentials) + self.assertIsInstance( + 
channel_credentials, implementations.ChannelCredentials + ) def test_application_provided_root_certificates(self): channel_credentials = implementations.ssl_channel_credentials( - resources.test_root_certificates()) - self.assertIsInstance(channel_credentials, - implementations.ChannelCredentials) + resources.test_root_certificates() + ) + self.assertIsInstance( + channel_credentials, implementations.ChannelCredentials + ) class CallCredentialsTest(unittest.TestCase): - def test_google_call_credentials(self): creds = oauth2client_client.GoogleCredentials( - 'token', 'client_id', 'secret', 'refresh_token', - datetime.datetime(2008, 6, 24), 'https://refresh.uri.com/', - 'user_agent') + "token", + "client_id", + "secret", + "refresh_token", + datetime.datetime(2008, 6, 24), + "https://refresh.uri.com/", + "user_agent", + ) call_creds = implementations.google_call_credentials(creds) self.assertIsInstance(call_creds, implementations.CallCredentials) def test_access_token_call_credentials(self): - call_creds = implementations.access_token_call_credentials('token') + call_creds = implementations.access_token_call_credentials("token") self.assertIsInstance(call_creds, implementations.CallCredentials) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/beta/_not_found_test.py b/src/python/grpcio_tests/tests/unit/beta/_not_found_test.py index 27fdecb8b7bbf..ad56155ce5068 100644 --- a/src/python/grpcio_tests/tests/unit/beta/_not_found_test.py +++ b/src/python/grpcio_tests/tests/unit/beta/_not_found_test.py @@ -23,11 +23,10 @@ class NotFoundTest(unittest.TestCase): - def setUp(self): self._server = implementations.server({}) - port = self._server.add_insecure_port('[::]:0') - channel = implementations.insecure_channel('localhost', port) + port = self._server.add_insecure_port("[::]:0") + channel = implementations.insecure_channel("localhost", port) self._generic_stub = implementations.generic_stub(channel) self._server.start() @@ -37,24 +36,32 @@ def tearDown(self): def test_blocking_unary_unary_not_found(self): with self.assertRaises(face.LocalError) as exception_assertion_context: - self._generic_stub.blocking_unary_unary('groop', - 'meffod', - b'abc', - test_constants.LONG_TIMEOUT, - with_call=True) - self.assertIs(exception_assertion_context.exception.code, - interfaces.StatusCode.UNIMPLEMENTED) + self._generic_stub.blocking_unary_unary( + "groop", + "meffod", + b"abc", + test_constants.LONG_TIMEOUT, + with_call=True, + ) + self.assertIs( + exception_assertion_context.exception.code, + interfaces.StatusCode.UNIMPLEMENTED, + ) def test_future_stream_unary_not_found(self): rpc_future = self._generic_stub.future_stream_unary( - 'grupe', 'mevvod', iter([b'def']), test_constants.LONG_TIMEOUT) + "grupe", "mevvod", iter([b"def"]), test_constants.LONG_TIMEOUT + ) with self.assertRaises(face.LocalError) as exception_assertion_context: rpc_future.result() - self.assertIs(exception_assertion_context.exception.code, - interfaces.StatusCode.UNIMPLEMENTED) - self.assertIs(rpc_future.exception().code, - interfaces.StatusCode.UNIMPLEMENTED) + self.assertIs( + exception_assertion_context.exception.code, + interfaces.StatusCode.UNIMPLEMENTED, + ) + self.assertIs( + rpc_future.exception().code, interfaces.StatusCode.UNIMPLEMENTED + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/beta/_utilities_test.py 
b/src/python/grpcio_tests/tests/unit/beta/_utilities_test.py index 25773036f1644..30970eaa1233d 100644 --- a/src/python/grpcio_tests/tests/unit/beta/_utilities_test.py +++ b/src/python/grpcio_tests/tests/unit/beta/_utilities_test.py @@ -25,7 +25,6 @@ class _Callback(object): - def __init__(self): self._condition = threading.Condition() self._value = None @@ -42,11 +41,10 @@ def block_until_called(self): return self._value -@unittest.skip('https://github.com/grpc/grpc/issues/16134') +@unittest.skip("https://github.com/grpc/grpc/issues/16134") class ChannelConnectivityTest(unittest.TestCase): - def test_lonely_channel_connectivity(self): - channel = implementations.insecure_channel('localhost', 12345) + channel = implementations.insecure_channel("localhost", 12345) callback = _Callback() ready_future = utilities.channel_ready_future(channel) @@ -65,16 +63,17 @@ def test_lonely_channel_connectivity(self): def test_immediately_connectable_channel_connectivity(self): server = implementations.server({}) - port = server.add_insecure_port('[::]:0') + port = server.add_insecure_port("[::]:0") server.start() - channel = implementations.insecure_channel('localhost', port) + channel = implementations.insecure_channel("localhost", port) callback = _Callback() try: ready_future = utilities.channel_ready_future(channel) ready_future.add_done_callback(callback.accept_value) self.assertIsNone( - ready_future.result(timeout=test_constants.LONG_TIMEOUT)) + ready_future.result(timeout=test_constants.LONG_TIMEOUT) + ) value_passed_to_callback = callback.block_until_called() self.assertIs(ready_future, value_passed_to_callback) self.assertFalse(ready_future.cancelled()) @@ -90,5 +89,5 @@ def test_immediately_connectable_channel_connectivity(self): server.stop(0) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/beta/test_utilities.py b/src/python/grpcio_tests/tests/unit/beta/test_utilities.py index c8d920d35e9df..b6519e5a72416 100644 --- a/src/python/grpcio_tests/tests/unit/beta/test_utilities.py +++ b/src/python/grpcio_tests/tests/unit/beta/test_utilities.py @@ -17,24 +17,31 @@ from grpc.beta import implementations -def not_really_secure_channel(host, port, channel_credentials, - server_host_override): +def not_really_secure_channel( + host, port, channel_credentials, server_host_override +): """Creates an insecure Channel to a remote host. - Args: - host: The name of the remote host to which to connect. - port: The port of the remote host to which to connect. - channel_credentials: The implementations.ChannelCredentials with which to - connect. - server_host_override: The target name used for SSL host name checking. + Args: + host: The name of the remote host to which to connect. + port: The port of the remote host to which to connect. + channel_credentials: The implementations.ChannelCredentials with which to + connect. + server_host_override: The target name used for SSL host name checking. - Returns: - An implementations.Channel to the remote host through which RPCs may be - conducted. - """ - target = '%s:%d' % (host, port) - channel = grpc.secure_channel(target, channel_credentials, (( - 'grpc.ssl_target_name_override', - server_host_override, - ),)) + Returns: + An implementations.Channel to the remote host through which RPCs may be + conducted. 
+ """ + target = "%s:%d" % (host, port) + channel = grpc.secure_channel( + target, + channel_credentials, + ( + ( + "grpc.ssl_target_name_override", + server_host_override, + ), + ), + ) return implementations.Channel(channel) diff --git a/src/python/grpcio_tests/tests/unit/framework/common/__init__.py b/src/python/grpcio_tests/tests/unit/framework/common/__init__.py index 709f6175b2e1f..d3c36a31eefe7 100644 --- a/src/python/grpcio_tests/tests/unit/framework/common/__init__.py +++ b/src/python/grpcio_tests/tests/unit/framework/common/__init__.py @@ -17,16 +17,20 @@ import os import socket -_DEFAULT_SOCK_OPTIONS = (socket.SO_REUSEADDR, - socket.SO_REUSEPORT) if os.name != 'nt' else ( - socket.SO_REUSEADDR,) +_DEFAULT_SOCK_OPTIONS = ( + (socket.SO_REUSEADDR, socket.SO_REUSEPORT) + if os.name != "nt" + else (socket.SO_REUSEADDR,) +) _UNRECOVERABLE_ERRNOS = (errno.EADDRINUSE, errno.ENOSR) -def get_socket(bind_address='localhost', - port=0, - listen=True, - sock_options=_DEFAULT_SOCK_OPTIONS): +def get_socket( + bind_address="localhost", + port=0, + listen=True, + sock_options=_DEFAULT_SOCK_OPTIONS, +): """Opens a socket. Useful for reserving a port for a system-under-test. @@ -47,7 +51,7 @@ def get_socket(bind_address='localhost', if socket.has_ipv6: address_families = (socket.AF_INET6, socket.AF_INET) else: - address_families = (socket.AF_INET) + address_families = socket.AF_INET for address_family in address_families: try: sock = socket.socket(address_family, socket.SOCK_STREAM) @@ -68,15 +72,20 @@ def get_socket(bind_address='localhost', except socket.error: # pylint: disable=duplicate-except sock.close() continue - raise RuntimeError("Failed to bind to {} with sock_options {}".format( - bind_address, sock_options)) + raise RuntimeError( + "Failed to bind to {} with sock_options {}".format( + bind_address, sock_options + ) + ) @contextlib.contextmanager -def bound_socket(bind_address='localhost', - port=0, - listen=True, - sock_options=_DEFAULT_SOCK_OPTIONS): +def bound_socket( + bind_address="localhost", + port=0, + listen=True, + sock_options=_DEFAULT_SOCK_OPTIONS, +): """Opens a socket bound to an arbitrary port. Useful for reserving a port for a system-under-test. @@ -92,10 +101,12 @@ def bound_socket(bind_address='localhost', - the address to which the socket is bound - the port to which the socket is bound """ - host, port, sock = get_socket(bind_address=bind_address, - port=port, - listen=listen, - sock_options=sock_options) + host, port, sock = get_socket( + bind_address=bind_address, + port=port, + listen=listen, + sock_options=sock_options, + ) try: yield host, port finally: diff --git a/src/python/grpcio_tests/tests/unit/framework/common/test_control.py b/src/python/grpcio_tests/tests/unit/framework/common/test_control.py index 8ef040751ebf9..849c33bd2d00e 100644 --- a/src/python/grpcio_tests/tests/unit/framework/common/test_control.py +++ b/src/python/grpcio_tests/tests/unit/framework/common/test_control.py @@ -21,26 +21,26 @@ class Defect(Exception): """Simulates a programming defect raised into in a system under test. - Use of a standard exception type is too easily misconstrued as an actual - defect in either the test infrastructure or the system under test. - """ + Use of a standard exception type is too easily misconstrued as an actual + defect in either the test infrastructure or the system under test. 
+ """ class NestedDefect(Exception): """Simulates a nested programming defect raised into in a system under test.""" def __str__(self): - raise Exception('Nested Exception') + raise Exception("Nested Exception") class Control(abc.ABC): """An object that accepts program control from a system under test. - Systems under test passed a Control should call its control() method - frequently during execution. The control() method may block, raise an - exception, or do nothing, all according to the enclosing test's desire for - the system under test to simulate freezing, failing, or functioning. - """ + Systems under test passed a Control should call its control() method + frequently during execution. The control() method may block, raise an + exception, or do nothing, all according to the enclosing test's desire for + the system under test to simulate freezing, failing, or functioning. + """ @abc.abstractmethod def control(self): @@ -51,10 +51,10 @@ def control(self): class PauseFailControl(Control): """A Control that can be used to pause or fail code under control. - This object is only safe for use from two threads: one of the system under - test calling control and the other from the test system calling pause, - block_until_paused, and fail. - """ + This object is only safe for use from two threads: one of the system under + test calling control and the other from the test system calling pause, + block_until_paused, and fail. + """ def __init__(self): self._condition = threading.Condition() @@ -86,8 +86,8 @@ def pause(self): def block_until_paused(self): """Blocks controlling code until code under control is paused. - May only be called within the context of a pause call. - """ + May only be called within the context of a pause call. + """ with self._condition: while not self._paused: self._condition.wait() diff --git a/src/python/grpcio_tests/tests/unit/framework/foundation/_logging_pool_test.py b/src/python/grpcio_tests/tests/unit/framework/foundation/_logging_pool_test.py index c4ea03177cc88..684efb5e8422a 100644 --- a/src/python/grpcio_tests/tests/unit/framework/foundation/_logging_pool_test.py +++ b/src/python/grpcio_tests/tests/unit/framework/foundation/_logging_pool_test.py @@ -22,7 +22,6 @@ class _CallableObject(object): - def __init__(self): self._lock = threading.Lock() self._passed_values = [] @@ -37,7 +36,6 @@ def passed_values(self): class LoggingPoolTest(unittest.TestCase): - def testUpAndDown(self): pool = logging_pool.pool(_POOL_SIZE) pool.shutdown(wait=True) @@ -65,9 +63,10 @@ def testCallableObjectExecuted(self): with logging_pool.pool(_POOL_SIZE) as pool: future = pool.submit(callable_object, passed_object) self.assertIsNone(future.result()) - self.assertSequenceEqual((passed_object,), - callable_object.passed_values()) + self.assertSequenceEqual( + (passed_object,), callable_object.passed_values() + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/framework/foundation/stream_testing.py b/src/python/grpcio_tests/tests/unit/framework/foundation/stream_testing.py index dd5c5b3b031b4..dc16173766c53 100644 --- a/src/python/grpcio_tests/tests/unit/framework/foundation/stream_testing.py +++ b/src/python/grpcio_tests/tests/unit/framework/foundation/stream_testing.py @@ -19,10 +19,10 @@ class TestConsumer(stream.Consumer): """A stream.Consumer instrumented for testing. - Attributes: - calls: A sequence of value-termination pairs describing the history of calls - made on this object. 
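Several of the hunks above (Defect, Control, PauseFailControl, TestConsumer) touch nothing but docstrings: black appears to re-indent multi-line docstring bodies to the indentation of the opening triple quote, so docstrings whose continuation lines sat shallower than the opening quotes under the old formatting show up here as whitespace-only changes. A schematic sketch of that change, reusing the PauseFailControl docstring copied from the hunk above (the Before/After class names are invented labels, and only the docstrings are kept):

    # Continuation lines indented two spaces less than the opening quotes,
    # as the old formatting left them.
    class PauseFailControlBefore:
        """A Control that can be used to pause or fail code under control.

      This object is only safe for use from two threads: one of the system under
      test calling control and the other from the test system calling pause,
      block_until_paused, and fail.
      """

    # After black: every docstring line is re-indented to the level of the
    # opening triple quote.
    class PauseFailControlAfter:
        """A Control that can be used to pause or fail code under control.

        This object is only safe for use from two threads: one of the system under
        test calling control and the other from the test system calling pause,
        block_until_paused, and fail.
        """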
- """ + Attributes: + calls: A sequence of value-termination pairs describing the history of calls + made on this object. + """ def __init__(self): self.calls = [] diff --git a/src/python/grpcio_tests/tests/unit/resources.py b/src/python/grpcio_tests/tests/unit/resources.py index 6efd870fc86d9..a0975388727f2 100644 --- a/src/python/grpcio_tests/tests/unit/resources.py +++ b/src/python/grpcio_tests/tests/unit/resources.py @@ -16,9 +16,9 @@ import os import pkgutil -_ROOT_CERTIFICATES_RESOURCE_PATH = 'credentials/ca.pem' -_PRIVATE_KEY_RESOURCE_PATH = 'credentials/server1.key' -_CERTIFICATE_CHAIN_RESOURCE_PATH = 'credentials/server1.pem' +_ROOT_CERTIFICATES_RESOURCE_PATH = "credentials/ca.pem" +_PRIVATE_KEY_RESOURCE_PATH = "credentials/server1.key" +_CERTIFICATE_CHAIN_RESOURCE_PATH = "credentials/server1.pem" def test_root_certificates(): @@ -35,79 +35,81 @@ def certificate_chain(): def cert_hier_1_root_ca_cert(): return pkgutil.get_data( - __name__, 'credentials/certificate_hierarchy_1/certs/ca.cert.pem') + __name__, "credentials/certificate_hierarchy_1/certs/ca.cert.pem" + ) def cert_hier_1_intermediate_ca_cert(): return pkgutil.get_data( __name__, - 'credentials/certificate_hierarchy_1/intermediate/certs/intermediate.cert.pem' + "credentials/certificate_hierarchy_1/intermediate/certs/intermediate.cert.pem", ) def cert_hier_1_client_1_key(): return pkgutil.get_data( __name__, - 'credentials/certificate_hierarchy_1/intermediate/private/client.key.pem' + "credentials/certificate_hierarchy_1/intermediate/private/client.key.pem", ) def cert_hier_1_client_1_cert(): return pkgutil.get_data( __name__, - 'credentials/certificate_hierarchy_1/intermediate/certs/client.cert.pem' + "credentials/certificate_hierarchy_1/intermediate/certs/client.cert.pem", ) def cert_hier_1_server_1_key(): return pkgutil.get_data( __name__, - 'credentials/certificate_hierarchy_1/intermediate/private/localhost-1.key.pem' + "credentials/certificate_hierarchy_1/intermediate/private/localhost-1.key.pem", ) def cert_hier_1_server_1_cert(): return pkgutil.get_data( __name__, - 'credentials/certificate_hierarchy_1/intermediate/certs/localhost-1.cert.pem' + "credentials/certificate_hierarchy_1/intermediate/certs/localhost-1.cert.pem", ) def cert_hier_2_root_ca_cert(): return pkgutil.get_data( - __name__, 'credentials/certificate_hierarchy_2/certs/ca.cert.pem') + __name__, "credentials/certificate_hierarchy_2/certs/ca.cert.pem" + ) def cert_hier_2_intermediate_ca_cert(): return pkgutil.get_data( __name__, - 'credentials/certificate_hierarchy_2/intermediate/certs/intermediate.cert.pem' + "credentials/certificate_hierarchy_2/intermediate/certs/intermediate.cert.pem", ) def cert_hier_2_client_1_key(): return pkgutil.get_data( __name__, - 'credentials/certificate_hierarchy_2/intermediate/private/client.key.pem' + "credentials/certificate_hierarchy_2/intermediate/private/client.key.pem", ) def cert_hier_2_client_1_cert(): return pkgutil.get_data( __name__, - 'credentials/certificate_hierarchy_2/intermediate/certs/client.cert.pem' + "credentials/certificate_hierarchy_2/intermediate/certs/client.cert.pem", ) def cert_hier_2_server_1_key(): return pkgutil.get_data( __name__, - 'credentials/certificate_hierarchy_2/intermediate/private/localhost-1.key.pem' + "credentials/certificate_hierarchy_2/intermediate/private/localhost-1.key.pem", ) def cert_hier_2_server_1_cert(): return pkgutil.get_data( __name__, - 'credentials/certificate_hierarchy_2/intermediate/certs/localhost-1.cert.pem' + 
"credentials/certificate_hierarchy_2/intermediate/certs/localhost-1.cert.pem", ) diff --git a/src/python/grpcio_tests/tests/unit/test_common.py b/src/python/grpcio_tests/tests/unit/test_common.py index 0889cc9617a5e..97dcbf5910189 100644 --- a/src/python/grpcio_tests/tests/unit/test_common.py +++ b/src/python/grpcio_tests/tests/unit/test_common.py @@ -20,42 +20,42 @@ import grpc INVOCATION_INITIAL_METADATA = ( - ('0', 'abc'), - ('1', 'def'), - ('2', 'ghi'), + ("0", "abc"), + ("1", "def"), + ("2", "ghi"), ) SERVICE_INITIAL_METADATA = ( - ('3', 'jkl'), - ('4', 'mno'), - ('5', 'pqr'), + ("3", "jkl"), + ("4", "mno"), + ("5", "pqr"), ) SERVICE_TERMINAL_METADATA = ( - ('6', 'stu'), - ('7', 'vwx'), - ('8', 'yza'), + ("6", "stu"), + ("7", "vwx"), + ("8", "yza"), ) -DETAILS = 'test details' +DETAILS = "test details" def metadata_transmitted(original_metadata, transmitted_metadata): """Judges whether or not metadata was acceptably transmitted. - gRPC is allowed to insert key-value pairs into the metadata values given by - applications and to reorder key-value pairs with different keys but it is not - allowed to alter existing key-value pairs or to reorder key-value pairs with - the same key. - - Args: - original_metadata: A metadata value used in a test of gRPC. An iterable over - iterables of length 2. - transmitted_metadata: A metadata value corresponding to original_metadata - after having been transmitted via gRPC. An iterable over iterables of - length 2. - - Returns: - A boolean indicating whether transmitted_metadata accurately reflects - original_metadata after having been transmitted via gRPC. - """ + gRPC is allowed to insert key-value pairs into the metadata values given by + applications and to reorder key-value pairs with different keys but it is not + allowed to alter existing key-value pairs or to reorder key-value pairs with + the same key. + + Args: + original_metadata: A metadata value used in a test of gRPC. An iterable over + iterables of length 2. + transmitted_metadata: A metadata value corresponding to original_metadata + after having been transmitted via gRPC. An iterable over iterables of + length 2. + + Returns: + A boolean indicating whether transmitted_metadata accurately reflects + original_metadata after having been transmitted via gRPC. + """ original = collections.defaultdict(list) for key, value in original_metadata: original[key].append(value) @@ -81,35 +81,42 @@ def metadata_transmitted(original_metadata, transmitted_metadata): def test_secure_channel(target, channel_credentials, server_host_override): """Creates an insecure Channel to a remote host. - Args: - host: The name of the remote host to which to connect. - port: The port of the remote host to which to connect. - channel_credentials: The implementations.ChannelCredentials with which to - connect. - server_host_override: The target name used for SSL host name checking. - - Returns: - An implementations.Channel to the remote host through which RPCs may be - conducted. - """ - channel = grpc.secure_channel(target, channel_credentials, (( - 'grpc.ssl_target_name_override', - server_host_override, - ),)) + Args: + host: The name of the remote host to which to connect. + port: The port of the remote host to which to connect. + channel_credentials: The implementations.ChannelCredentials with which to + connect. + server_host_override: The target name used for SSL host name checking. + + Returns: + An implementations.Channel to the remote host through which RPCs may be + conducted. 
+ """ + channel = grpc.secure_channel( + target, + channel_credentials, + ( + ( + "grpc.ssl_target_name_override", + server_host_override, + ), + ), + ) return channel def test_server(max_workers=10, reuse_port=False): """Creates an insecure grpc server. - These servers have SO_REUSEPORT disabled to prevent cross-talk. - """ - return grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers), - options=(('grpc.so_reuseport', int(reuse_port)),)) + These servers have SO_REUSEPORT disabled to prevent cross-talk. + """ + return grpc.server( + futures.ThreadPoolExecutor(max_workers=max_workers), + options=(("grpc.so_reuseport", int(reuse_port)),), + ) class WaitGroup(object): - def __init__(self, n=0): self.count = n self.cv = threading.Condition() @@ -141,4 +148,5 @@ def running_under_gevent(): return False else: import socket + return socket.socket is gevent.socket.socket diff --git a/src/python/grpcio_tests/tests_aio/_sanity/_sanity_test.py b/src/python/grpcio_tests/tests_aio/_sanity/_sanity_test.py index e74dec0739b4a..db894df547cac 100644 --- a/src/python/grpcio_tests/tests_aio/_sanity/_sanity_test.py +++ b/src/python/grpcio_tests/tests_aio/_sanity/_sanity_test.py @@ -18,10 +18,9 @@ class AioSanityTest(_sanity_test.SanityTest): + TEST_PKG_MODULE_NAME = "tests_aio" + TEST_PKG_PATH = "tests_aio" - TEST_PKG_MODULE_NAME = 'tests_aio' - TEST_PKG_PATH = 'tests_aio' - -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/benchmark/benchmark_client.py b/src/python/grpcio_tests/tests_aio/benchmark/benchmark_client.py index 301dba4bae8ea..7daf88556ed31 100644 --- a/src/python/grpcio_tests/tests_aio/benchmark/benchmark_client.py +++ b/src/python/grpcio_tests/tests_aio/benchmark/benchmark_client.py @@ -30,59 +30,78 @@ class GenericStub(object): - def __init__(self, channel: aio.Channel): self.UnaryCall = channel.unary_unary( - '/grpc.testing.BenchmarkService/UnaryCall') + "/grpc.testing.BenchmarkService/UnaryCall" + ) self.StreamingFromServer = channel.unary_stream( - '/grpc.testing.BenchmarkService/StreamingFromServer') + "/grpc.testing.BenchmarkService/StreamingFromServer" + ) self.StreamingCall = channel.stream_stream( - '/grpc.testing.BenchmarkService/StreamingCall') + "/grpc.testing.BenchmarkService/StreamingCall" + ) class BenchmarkClient(abc.ABC): """Benchmark client interface that exposes a non-blocking send_request().""" - def __init__(self, address: str, config: control_pb2.ClientConfig, - hist: histogram.Histogram): + def __init__( + self, + address: str, + config: control_pb2.ClientConfig, + hist: histogram.Histogram, + ): # Disables underlying reuse of subchannels - unique_option = (('iv', random.random()),) + unique_option = (("iv", random.random()),) # Parses the channel argument from config channel_args = tuple( - (arg.name, arg.str_value) if arg.HasField('str_value') else ( - arg.name, int(arg.int_value)) for arg in config.channel_args) + (arg.name, arg.str_value) + if arg.HasField("str_value") + else (arg.name, int(arg.int_value)) + for arg in config.channel_args + ) # Creates the channel - if config.HasField('security_params'): + if config.HasField("security_params"): channel_credentials = grpc.ssl_channel_credentials( - resources.test_root_certificates(),) - server_host_override_option = (( - 'grpc.ssl_target_name_override', - config.security_params.server_host_override, - ),) + resources.test_root_certificates(), + ) + server_host_override_option = ( + ( + "grpc.ssl_target_name_override", + 
config.security_params.server_host_override, + ), + ) self._channel = aio.secure_channel( - address, channel_credentials, - unique_option + channel_args + server_host_override_option) + address, + channel_credentials, + unique_option + channel_args + server_host_override_option, + ) else: - self._channel = aio.insecure_channel(address, - options=unique_option + - channel_args) + self._channel = aio.insecure_channel( + address, options=unique_option + channel_args + ) # Creates the stub - if config.payload_config.WhichOneof('payload') == 'simple_params': + if config.payload_config.WhichOneof("payload") == "simple_params": self._generic = False self._stub = benchmark_service_pb2_grpc.BenchmarkServiceStub( - self._channel) + self._channel + ) payload = messages_pb2.Payload( - body=b'\0' * config.payload_config.simple_params.req_size) + body=b"\0" * config.payload_config.simple_params.req_size + ) self._request = messages_pb2.SimpleRequest( payload=payload, - response_size=config.payload_config.simple_params.resp_size) + response_size=config.payload_config.simple_params.resp_size, + ) else: self._generic = True self._stub = GenericStub(self._channel) - self._request = b'\0' * config.payload_config.bytebuf_params.req_size + self._request = ( + b"\0" * config.payload_config.bytebuf_params.req_size + ) self._hist = hist self._response_callbacks = [] @@ -99,9 +118,12 @@ def _record_query_time(self, query_time: float) -> None: class UnaryAsyncBenchmarkClient(BenchmarkClient): - - def __init__(self, address: str, config: control_pb2.ClientConfig, - hist: histogram.Histogram): + def __init__( + self, + address: str, + config: control_pb2.ClientConfig, + hist: histogram.Histogram, + ): super().__init__(address, config, hist) self._running = None self._stopped = asyncio.Event() @@ -129,9 +151,12 @@ async def stop(self) -> None: class StreamingAsyncBenchmarkClient(BenchmarkClient): - - def __init__(self, address: str, config: control_pb2.ClientConfig, - hist: histogram.Histogram): + def __init__( + self, + address: str, + config: control_pb2.ClientConfig, + hist: histogram.Histogram, + ): super().__init__(address, config, hist) self._running = None self._stopped = asyncio.Event() @@ -159,9 +184,12 @@ async def stop(self): class ServerStreamingAsyncBenchmarkClient(BenchmarkClient): - - def __init__(self, address: str, config: control_pb2.ClientConfig, - hist: histogram.Histogram): + def __init__( + self, + address: str, + config: control_pb2.ClientConfig, + hist: histogram.Histogram, + ): super().__init__(address, config, hist) self._running = None self._stopped = asyncio.Event() @@ -177,7 +205,8 @@ async def run(self): await super().run() self._running = True senders = ( - self._one_server_streaming_call() for _ in range(self._concurrency)) + self._one_server_streaming_call() for _ in range(self._concurrency) + ) await asyncio.gather(*senders) self._stopped.set() diff --git a/src/python/grpcio_tests/tests_aio/benchmark/benchmark_servicer.py b/src/python/grpcio_tests/tests_aio/benchmark/benchmark_servicer.py index b519554a56cba..2003ca3159f90 100644 --- a/src/python/grpcio_tests/tests_aio/benchmark/benchmark_servicer.py +++ b/src/python/grpcio_tests/tests_aio/benchmark/benchmark_servicer.py @@ -24,29 +24,29 @@ class BenchmarkServicer(benchmark_service_pb2_grpc.BenchmarkServiceServicer): - async def UnaryCall(self, request, unused_context): - payload = messages_pb2.Payload(body=b'\0' * request.response_size) + payload = messages_pb2.Payload(body=b"\0" * request.response_size) return 
messages_pb2.SimpleResponse(payload=payload) async def StreamingFromServer(self, request, unused_context): - payload = messages_pb2.Payload(body=b'\0' * request.response_size) + payload = messages_pb2.Payload(body=b"\0" * request.response_size) # Sends response at full capacity! while True: yield messages_pb2.SimpleResponse(payload=payload) async def StreamingCall(self, request_iterator, unused_context): async for request in request_iterator: - payload = messages_pb2.Payload(body=b'\0' * request.response_size) + payload = messages_pb2.Payload(body=b"\0" * request.response_size) yield messages_pb2.SimpleResponse(payload=payload) class GenericBenchmarkServicer( - benchmark_service_pb2_grpc.BenchmarkServiceServicer): + benchmark_service_pb2_grpc.BenchmarkServiceServicer +): """Generic (no-codec) Server implementation for the Benchmark service.""" def __init__(self, resp_size): - self._response = '\0' * resp_size + self._response = "\0" * resp_size async def UnaryCall(self, unused_request, unused_context): return self._response diff --git a/src/python/grpcio_tests/tests_aio/benchmark/server.py b/src/python/grpcio_tests/tests_aio/benchmark/server.py index 561298a626bb5..c118f7763e803 100644 --- a/src/python/grpcio_tests/tests_aio/benchmark/server.py +++ b/src/python/grpcio_tests/tests_aio/benchmark/server.py @@ -25,13 +25,14 @@ async def _start_async_server(): server = aio.server() - port = server.add_insecure_port('localhost:%s' % 50051) + port = server.add_insecure_port("localhost:%s" % 50051) servicer = benchmark_servicer.BenchmarkServicer() benchmark_service_pb2_grpc.add_BenchmarkServiceServicer_to_server( - servicer, server) + servicer, server + ) await server.start() - logging.info('Benchmark server started at :%d' % port) + logging.info("Benchmark server started at :%d" % port) await server.wait_for_termination() @@ -41,6 +42,6 @@ def main(): loop.run_forever() -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) main() diff --git a/src/python/grpcio_tests/tests_aio/benchmark/worker.py b/src/python/grpcio_tests/tests_aio/benchmark/worker.py index dc16f050872ec..68a5661b9482b 100644 --- a/src/python/grpcio_tests/tests_aio/benchmark/worker.py +++ b/src/python/grpcio_tests/tests_aio/benchmark/worker.py @@ -27,9 +27,10 @@ async def run_worker_server(port: int) -> None: servicer = worker_servicer.WorkerServicer() worker_service_pb2_grpc.add_WorkerServiceServicer_to_server( - servicer, server) + servicer, server + ) - server.add_insecure_port('[::]:{}'.format(port)) + server.add_insecure_port("[::]:{}".format(port)) await server.start() @@ -37,21 +38,25 @@ async def run_worker_server(port: int) -> None: await server.stop(None) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) parser = argparse.ArgumentParser( - description='gRPC Python performance testing worker') - parser.add_argument('--driver_port', - type=int, - dest='port', - help='The port the worker should listen on') - parser.add_argument('--uvloop', - action='store_true', - help='Use uvloop or not') + description="gRPC Python performance testing worker" + ) + parser.add_argument( + "--driver_port", + type=int, + dest="port", + help="The port the worker should listen on", + ) + parser.add_argument( + "--uvloop", action="store_true", help="Use uvloop or not" + ) args = parser.parse_args() if args.uvloop: import uvloop + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) loop = uvloop.new_event_loop() asyncio.set_event_loop(loop) diff --git 
a/src/python/grpcio_tests/tests_aio/benchmark/worker_servicer.py b/src/python/grpcio_tests/tests_aio/benchmark/worker_servicer.py index 684beddb4ab17..adee76fd0631d 100644 --- a/src/python/grpcio_tests/tests_aio/benchmark/worker_servicer.py +++ b/src/python/grpcio_tests/tests_aio/benchmark/worker_servicer.py @@ -36,18 +36,19 @@ _NUM_CORES = multiprocessing.cpu_count() _WORKER_ENTRY_FILE = os.path.join( - os.path.split(os.path.abspath(__file__))[0], 'worker.py') + os.path.split(os.path.abspath(__file__))[0], "worker.py" +) _LOGGER = logging.getLogger(__name__) class _SubWorker( - collections.namedtuple('_SubWorker', - ['process', 'port', 'channel', 'stub'])): + collections.namedtuple("_SubWorker", ["process", "port", "channel", "stub"]) +): """A data class that holds information about a child qps worker.""" def _repr(self): - return f'<_SubWorker pid={self.process.pid} port={self.port}>' + return f"<_SubWorker pid={self.process.pid} port={self.port}>" def __repr__(self): return self._repr() @@ -56,80 +57,94 @@ def __str__(self): return self._repr() -def _get_server_status(start_time: float, end_time: float, - port: int) -> control_pb2.ServerStatus: +def _get_server_status( + start_time: float, end_time: float, port: int +) -> control_pb2.ServerStatus: """Creates ServerStatus proto message.""" end_time = time.monotonic() elapsed_time = end_time - start_time # TODO(lidiz) Collect accurate time system to compute QPS/core-second. - stats = stats_pb2.ServerStats(time_elapsed=elapsed_time, - time_user=elapsed_time, - time_system=elapsed_time) + stats = stats_pb2.ServerStats( + time_elapsed=elapsed_time, + time_user=elapsed_time, + time_system=elapsed_time, + ) return control_pb2.ServerStatus(stats=stats, port=port, cores=_NUM_CORES) def _create_server(config: control_pb2.ServerConfig) -> Tuple[aio.Server, int]: """Creates a server object according to the ServerConfig.""" channel_args = tuple( - (arg.name, - arg.str_value) if arg.HasField('str_value') else (arg.name, - int(arg.int_value)) - for arg in config.channel_args) + (arg.name, arg.str_value) + if arg.HasField("str_value") + else (arg.name, int(arg.int_value)) + for arg in config.channel_args + ) - server = aio.server(options=channel_args + (('grpc.so_reuseport', 1),)) + server = aio.server(options=channel_args + (("grpc.so_reuseport", 1),)) if config.server_type == control_pb2.ASYNC_SERVER: servicer = benchmark_servicer.BenchmarkServicer() benchmark_service_pb2_grpc.add_BenchmarkServiceServicer_to_server( - servicer, server) + servicer, server + ) elif config.server_type == control_pb2.ASYNC_GENERIC_SERVER: resp_size = config.payload_config.bytebuf_params.resp_size servicer = benchmark_servicer.GenericBenchmarkServicer(resp_size) method_implementations = { - 'StreamingCall': - grpc.stream_stream_rpc_method_handler(servicer.StreamingCall), - 'UnaryCall': - grpc.unary_unary_rpc_method_handler(servicer.UnaryCall), + "StreamingCall": grpc.stream_stream_rpc_method_handler( + servicer.StreamingCall + ), + "UnaryCall": grpc.unary_unary_rpc_method_handler( + servicer.UnaryCall + ), } handler = grpc.method_handlers_generic_handler( - 'grpc.testing.BenchmarkService', method_implementations) + "grpc.testing.BenchmarkService", method_implementations + ) server.add_generic_rpc_handlers((handler,)) else: - raise NotImplementedError('Unsupported server type {}'.format( - config.server_type)) + raise NotImplementedError( + "Unsupported server type {}".format(config.server_type) + ) - if config.HasField('security_params'): # Use SSL + if 
config.HasField("security_params"): # Use SSL server_creds = grpc.ssl_server_credentials( - ((resources.private_key(), resources.certificate_chain()),)) - port = server.add_secure_port('[::]:{}'.format(config.port), - server_creds) + ((resources.private_key(), resources.certificate_chain()),) + ) + port = server.add_secure_port( + "[::]:{}".format(config.port), server_creds + ) else: - port = server.add_insecure_port('[::]:{}'.format(config.port)) + port = server.add_insecure_port("[::]:{}".format(config.port)) return server, port def _get_client_status( - start_time: float, end_time: float, - qps_data: histogram.Histogram) -> control_pb2.ClientStatus: + start_time: float, end_time: float, qps_data: histogram.Histogram +) -> control_pb2.ClientStatus: """Creates ClientStatus proto message.""" latencies = qps_data.get_data() end_time = time.monotonic() elapsed_time = end_time - start_time # TODO(lidiz) Collect accurate time system to compute QPS/core-second. - stats = stats_pb2.ClientStats(latencies=latencies, - time_elapsed=elapsed_time, - time_user=elapsed_time, - time_system=elapsed_time) + stats = stats_pb2.ClientStats( + latencies=latencies, + time_elapsed=elapsed_time, + time_user=elapsed_time, + time_system=elapsed_time, + ) return control_pb2.ClientStatus(stats=stats) def _create_client( - server: str, config: control_pb2.ClientConfig, - qps_data: histogram.Histogram) -> benchmark_client.BenchmarkClient: + server: str, config: control_pb2.ClientConfig, qps_data: histogram.Histogram +) -> benchmark_client.BenchmarkClient: """Creates a client object according to the ClientConfig.""" - if config.load_params.WhichOneof('load') != 'closed_loop': + if config.load_params.WhichOneof("load") != "closed_loop": raise NotImplementedError( - f'Unsupported load parameter {config.load_params}') + f"Unsupported load parameter {config.load_params}" + ) if config.client_type == control_pb2.ASYNC_CLIENT: if config.rpc_type == control_pb2.UNARY: @@ -140,10 +155,12 @@ def _create_client( client_type = benchmark_client.ServerStreamingAsyncBenchmarkClient else: raise NotImplementedError( - f'Unsupported rpc_type [{config.rpc_type}]') + f"Unsupported rpc_type [{config.rpc_type}]" + ) else: raise NotImplementedError( - f'Unsupported client type {config.client_type}') + f"Unsupported client type {config.client_type}" + ) return client_type(server, config, qps_data) @@ -159,14 +176,17 @@ async def _create_sub_worker() -> _SubWorker: """Creates a child qps worker as a subprocess.""" port = _pick_an_unused_port() - _LOGGER.info('Creating sub worker at port [%d]...', port) - process = await asyncio.create_subprocess_exec(sys.executable, - _WORKER_ENTRY_FILE, - '--driver_port', str(port)) - _LOGGER.info('Created sub worker process for port [%d] at pid [%d]', port, - process.pid) - channel = aio.insecure_channel(f'localhost:{port}') - _LOGGER.info('Waiting for sub worker at port [%d]', port) + _LOGGER.info("Creating sub worker at port [%d]...", port) + process = await asyncio.create_subprocess_exec( + sys.executable, _WORKER_ENTRY_FILE, "--driver_port", str(port) + ) + _LOGGER.info( + "Created sub worker process for port [%d] at pid [%d]", + port, + process.pid, + ) + channel = aio.insecure_channel(f"localhost:{port}") + _LOGGER.info("Waiting for sub worker at port [%d]", port) await channel.channel_ready() stub = worker_service_pb2_grpc.WorkerServiceStub(channel) return _SubWorker( @@ -187,7 +207,7 @@ def __init__(self): async def _run_single_server(self, config, request_iterator, context): server, port = 
_create_server(config) await server.start() - _LOGGER.info('Server started at port [%d]', port) + _LOGGER.info("Server started at port [%d]", port) start_time = time.monotonic() await context.write(_get_server_status(start_time, start_time, port)) @@ -203,15 +223,15 @@ async def _run_single_server(self, config, request_iterator, context): async def RunServer(self, request_iterator, context): config_request = await context.read() config = config_request.setup - _LOGGER.info('Received ServerConfig: %s', config) + _LOGGER.info("Received ServerConfig: %s", config) if config.server_processes <= 0: - _LOGGER.info('Using server_processes == [%d]', _NUM_CORES) + _LOGGER.info("Using server_processes == [%d]", _NUM_CORES) config.server_processes = _NUM_CORES if config.port == 0: config.port = _pick_an_unused_port() - _LOGGER.info('Port picked [%d]', config.port) + _LOGGER.info("Port picked [%d]", config.port) if config.server_processes == 1: # If server_processes == 1, start the server in this process. @@ -219,7 +239,8 @@ async def RunServer(self, request_iterator, context): else: # If server_processes > 1, offload to other processes. sub_workers = await asyncio.gather( - *[_create_sub_worker() for _ in range(config.server_processes)]) + *[_create_sub_worker() for _ in range(config.server_processes)] + ) calls = [worker.stub.RunServer() for worker in sub_workers] @@ -236,9 +257,10 @@ async def RunServer(self, request_iterator, context): start_time, start_time, config.port, - )) + ) + ) - _LOGGER.info('Servers are ready to serve.') + _LOGGER.info("Servers are ready to serve.") async for request in request_iterator: end_time = time.monotonic() @@ -263,20 +285,22 @@ async def RunServer(self, request_iterator, context): for worker in sub_workers: await worker.stub.QuitWorker(control_pb2.Void()) await worker.channel.close() - _LOGGER.info('Waiting for [%s] to quit...', worker) + _LOGGER.info("Waiting for [%s] to quit...", worker) await worker.process.wait() async def _run_single_client(self, config, request_iterator, context): running_tasks = [] - qps_data = histogram.Histogram(config.histogram_params.resolution, - config.histogram_params.max_possible) + qps_data = histogram.Histogram( + config.histogram_params.resolution, + config.histogram_params.max_possible, + ) start_time = time.monotonic() # Create a client for each channel as asyncio.Task for i in range(config.client_channels): server = config.server_targets[i % len(config.server_targets)] client = _create_client(server, config, qps_data) - _LOGGER.info('Client created against server [%s]', server) + _LOGGER.info("Client created against server [%s]", server) running_tasks.append(self._loop.create_task(client.run())) end_time = time.monotonic() @@ -298,12 +322,13 @@ async def _run_single_client(self, config, request_iterator, context): async def RunClient(self, request_iterator, context): config_request = await context.read() config = config_request.setup - _LOGGER.info('Received ClientConfig: %s', config) + _LOGGER.info("Received ClientConfig: %s", config) if config.client_processes <= 0: - _LOGGER.info('client_processes can\'t be [%d]', - config.client_processes) - _LOGGER.info('Using client_processes == [%d]', _NUM_CORES) + _LOGGER.info( + "client_processes can't be [%d]", config.client_processes + ) + _LOGGER.info("Using client_processes == [%d]", _NUM_CORES) config.client_processes = _NUM_CORES if config.client_processes == 1: @@ -312,7 +337,8 @@ async def RunClient(self, request_iterator, context): else: # If client_processes > 1, offload the 
work to other processes. sub_workers = await asyncio.gather( - *[_create_sub_worker() for _ in range(config.client_processes)]) + *[_create_sub_worker() for _ in range(config.client_processes)] + ) calls = [worker.stub.RunClient() for worker in sub_workers] @@ -324,29 +350,35 @@ async def RunClient(self, request_iterator, context): await call.read() start_time = time.monotonic() - result = histogram.Histogram(config.histogram_params.resolution, - config.histogram_params.max_possible) + result = histogram.Histogram( + config.histogram_params.resolution, + config.histogram_params.max_possible, + ) end_time = time.monotonic() - await context.write(_get_client_status(start_time, end_time, - result)) + await context.write( + _get_client_status(start_time, end_time, result) + ) async for request in request_iterator: end_time = time.monotonic() for call in calls: - _LOGGER.debug('Fetching status...') + _LOGGER.debug("Fetching status...") await call.write(request) sub_status = await call.read() result.merge(sub_status.stats.latencies) - _LOGGER.debug('Update from sub worker count=[%d]', - sub_status.stats.latencies.count) + _LOGGER.debug( + "Update from sub worker count=[%d]", + sub_status.stats.latencies.count, + ) status = _get_client_status(start_time, end_time, result) if request.mark.reset: result.reset() start_time = time.monotonic() - _LOGGER.debug('Reporting count=[%d]', - status.stats.latencies.count) + _LOGGER.debug( + "Reporting count=[%d]", status.stats.latencies.count + ) await context.write(status) for call in calls: @@ -355,16 +387,16 @@ async def RunClient(self, request_iterator, context): for worker in sub_workers: await worker.stub.QuitWorker(control_pb2.Void()) await worker.channel.close() - _LOGGER.info('Waiting for sub worker [%s] to quit...', worker) + _LOGGER.info("Waiting for sub worker [%s] to quit...", worker) await worker.process.wait() - _LOGGER.info('Sub worker [%s] quit', worker) + _LOGGER.info("Sub worker [%s] quit", worker) @staticmethod async def CoreCount(unused_request, unused_context): return control_pb2.CoreResponse(cores=_NUM_CORES) async def QuitWorker(self, unused_request, unused_context): - _LOGGER.info('QuitWorker command received.') + _LOGGER.info("QuitWorker command received.") self._quit_event.set() return control_pb2.Void() diff --git a/src/python/grpcio_tests/tests_aio/channelz/channelz_servicer_test.py b/src/python/grpcio_tests/tests_aio/channelz/channelz_servicer_test.py index 8c62826a64166..12b8fd4163295 100644 --- a/src/python/grpcio_tests/tests_aio/channelz/channelz_servicer_test.py +++ b/src/python/grpcio_tests/tests_aio/channelz/channelz_servicer_test.py @@ -26,16 +26,16 @@ from tests.unit.framework.common import test_constants from tests_aio.unit._test_base import AioTestBase -_SUCCESSFUL_UNARY_UNARY = '/test/SuccessfulUnaryUnary' -_FAILED_UNARY_UNARY = '/test/FailedUnaryUnary' -_SUCCESSFUL_STREAM_STREAM = '/test/SuccessfulStreamStream' +_SUCCESSFUL_UNARY_UNARY = "/test/SuccessfulUnaryUnary" +_FAILED_UNARY_UNARY = "/test/FailedUnaryUnary" +_SUCCESSFUL_STREAM_STREAM = "/test/SuccessfulStreamStream" -_REQUEST = b'\x00\x00\x00' -_RESPONSE = b'\x01\x01\x01' +_REQUEST = b"\x00\x00\x00" +_RESPONSE = b"\x01\x01\x01" -_DISABLE_REUSE_PORT = (('grpc.so_reuseport', 0),) -_ENABLE_CHANNELZ = (('grpc.enable_channelz', 1),) -_DISABLE_CHANNELZ = (('grpc.enable_channelz', 0),) +_DISABLE_REUSE_PORT = (("grpc.so_reuseport", 0),) +_ENABLE_CHANNELZ = (("grpc.enable_channelz", 1),) +_DISABLE_CHANNELZ = (("grpc.enable_channelz", 0),) _LARGE_UNASSIGNED_ID = 10000 @@ 
-55,7 +55,6 @@ async def _successful_stream_stream(request_iterator, servicer_context): class _GenericHandler(grpc.GenericRpcHandler): - def service(self, handler_call_details): if handler_call_details.method == _SUCCESSFUL_UNARY_UNARY: return grpc.unary_unary_rpc_method_handler(_successful_unary_unary) @@ -63,15 +62,15 @@ def service(self, handler_call_details): return grpc.unary_unary_rpc_method_handler(_failed_unary_unary) elif handler_call_details.method == _SUCCESSFUL_STREAM_STREAM: return grpc.stream_stream_rpc_method_handler( - _successful_stream_stream) + _successful_stream_stream + ) else: return None class _ChannelServerPair: - def __init__(self): - self.address = '' + self.address = "" self.server = None self.channel = None self.server_ref_id = None @@ -80,24 +79,27 @@ def __init__(self): async def start(self): # Server will enable channelz service self.server = aio.server(options=_DISABLE_REUSE_PORT + _ENABLE_CHANNELZ) - port = self.server.add_insecure_port('[::]:0') - self.address = 'localhost:%d' % port + port = self.server.add_insecure_port("[::]:0") + self.address = "localhost:%d" % port self.server.add_generic_rpc_handlers((_GenericHandler(),)) await self.server.start() # Channel will enable channelz service... - self.channel = aio.insecure_channel(self.address, - options=_ENABLE_CHANNELZ) + self.channel = aio.insecure_channel( + self.address, options=_ENABLE_CHANNELZ + ) async def bind_channelz(self, channelz_stub): resp = await channelz_stub.GetTopChannels( - channelz_pb2.GetTopChannelsRequest(start_channel_id=0)) + channelz_pb2.GetTopChannelsRequest(start_channel_id=0) + ) for channel in resp.channel: if channel.data.target == self.address: self.channel_ref_id = channel.ref.channel_id resp = await channelz_stub.GetServers( - channelz_pb2.GetServersRequest(start_server_id=0)) + channelz_pb2.GetServersRequest(start_server_id=0) + ) self.server_ref_id = resp.server[-1].ref.server_id async def stop(self): @@ -121,20 +123,21 @@ async def _destroy_channel_server_pairs(pairs): class ChannelzServicerTest(AioTestBase): - async def setUp(self): # This server is for Channelz info fetching only # It self should not enable Channelz - self._server = aio.server(options=_DISABLE_REUSE_PORT + - _DISABLE_CHANNELZ) - port = self._server.add_insecure_port('[::]:0') + self._server = aio.server( + options=_DISABLE_REUSE_PORT + _DISABLE_CHANNELZ + ) + port = self._server.add_insecure_port("[::]:0") channelz.add_channelz_servicer(self._server) await self._server.start() # This channel is used to fetch Channelz info only # Channelz should not be enabled - self._channel = aio.insecure_channel('localhost:%d' % port, - options=_DISABLE_CHANNELZ) + self._channel = aio.insecure_channel( + "localhost:%d" % port, options=_DISABLE_CHANNELZ + ) self._channelz_stub = channelz_pb2_grpc.ChannelzStub(self._channel) async def tearDown(self): @@ -144,7 +147,8 @@ async def tearDown(self): async def _get_server_by_ref_id(self, ref_id): """Server id may not be consecutive""" resp = await self._channelz_stub.GetServers( - channelz_pb2.GetServersRequest(start_server_id=ref_id)) + channelz_pb2.GetServersRequest(start_server_id=ref_id) + ) self.assertEqual(ref_id, resp.server[0].ref.server_id) return resp.server[0] @@ -161,8 +165,9 @@ async def _send_failed_unary_unary(self, pair): self.fail("This call supposed to fail") async def _send_successful_stream_stream(self, pair): - call = pair.channel.stream_stream(_SUCCESSFUL_STREAM_STREAM)(iter( - [_REQUEST] * test_constants.STREAM_LENGTH)) + call = 
pair.channel.stream_stream(_SUCCESSFUL_STREAM_STREAM)( + iter([_REQUEST] * test_constants.STREAM_LENGTH) + ) cnt = 0 async for _ in call: cnt += 1 @@ -173,7 +178,9 @@ async def test_get_top_channels_high_start_id(self): resp = await self._channelz_stub.GetTopChannels( channelz_pb2.GetTopChannelsRequest( - start_channel_id=_LARGE_UNASSIGNED_ID)) + start_channel_id=_LARGE_UNASSIGNED_ID + ) + ) self.assertEqual(len(resp.channel), 0) self.assertEqual(resp.end, True) @@ -184,7 +191,8 @@ async def test_successful_request(self): await self._send_successful_unary_unary(pairs[0]) resp = await self._channelz_stub.GetChannel( - channelz_pb2.GetChannelRequest(channel_id=pairs[0].channel_ref_id)) + channelz_pb2.GetChannelRequest(channel_id=pairs[0].channel_ref_id) + ) self.assertEqual(resp.channel.data.calls_started, 1) self.assertEqual(resp.channel.data.calls_succeeded, 1) @@ -197,7 +205,8 @@ async def test_failed_request(self): await self._send_failed_unary_unary(pairs[0]) resp = await self._channelz_stub.GetChannel( - channelz_pb2.GetChannelRequest(channel_id=pairs[0].channel_ref_id)) + channelz_pb2.GetChannelRequest(channel_id=pairs[0].channel_ref_id) + ) self.assertEqual(resp.channel.data.calls_started, 1) self.assertEqual(resp.channel.data.calls_succeeded, 0) self.assertEqual(resp.channel.data.calls_failed, 1) @@ -214,7 +223,8 @@ async def test_many_requests(self): for i in range(k_failed): await self._send_failed_unary_unary(pairs[0]) resp = await self._channelz_stub.GetChannel( - channelz_pb2.GetChannelRequest(channel_id=pairs[0].channel_ref_id)) + channelz_pb2.GetChannelRequest(channel_id=pairs[0].channel_ref_id) + ) self.assertEqual(resp.channel.data.calls_started, k_success + k_failed) self.assertEqual(resp.channel.data.calls_succeeded, k_success) self.assertEqual(resp.channel.data.calls_failed, k_failed) @@ -223,8 +233,9 @@ async def test_many_requests(self): async def test_many_requests_many_channel(self): k_channels = 4 - pairs = await _create_channel_server_pairs(k_channels, - self._channelz_stub) + pairs = await _create_channel_server_pairs( + k_channels, self._channelz_stub + ) k_success = 11 k_failed = 13 for i in range(k_success): @@ -236,28 +247,32 @@ async def test_many_requests_many_channel(self): # The first channel saw only successes resp = await self._channelz_stub.GetChannel( - channelz_pb2.GetChannelRequest(channel_id=pairs[0].channel_ref_id)) + channelz_pb2.GetChannelRequest(channel_id=pairs[0].channel_ref_id) + ) self.assertEqual(resp.channel.data.calls_started, k_success) self.assertEqual(resp.channel.data.calls_succeeded, k_success) self.assertEqual(resp.channel.data.calls_failed, 0) # The second channel saw only failures resp = await self._channelz_stub.GetChannel( - channelz_pb2.GetChannelRequest(channel_id=pairs[1].channel_ref_id)) + channelz_pb2.GetChannelRequest(channel_id=pairs[1].channel_ref_id) + ) self.assertEqual(resp.channel.data.calls_started, k_failed) self.assertEqual(resp.channel.data.calls_succeeded, 0) self.assertEqual(resp.channel.data.calls_failed, k_failed) # The third channel saw both successes and failures resp = await self._channelz_stub.GetChannel( - channelz_pb2.GetChannelRequest(channel_id=pairs[2].channel_ref_id)) + channelz_pb2.GetChannelRequest(channel_id=pairs[2].channel_ref_id) + ) self.assertEqual(resp.channel.data.calls_started, k_success + k_failed) self.assertEqual(resp.channel.data.calls_succeeded, k_success) self.assertEqual(resp.channel.data.calls_failed, k_failed) # The fourth channel saw nothing resp = await 
self._channelz_stub.GetChannel( - channelz_pb2.GetChannelRequest(channel_id=pairs[3].channel_ref_id)) + channelz_pb2.GetChannelRequest(channel_id=pairs[3].channel_ref_id) + ) self.assertEqual(resp.channel.data.calls_started, 0) self.assertEqual(resp.channel.data.calls_succeeded, 0) self.assertEqual(resp.channel.data.calls_failed, 0) @@ -266,8 +281,9 @@ async def test_many_requests_many_channel(self): async def test_many_subchannels(self): k_channels = 4 - pairs = await _create_channel_server_pairs(k_channels, - self._channelz_stub) + pairs = await _create_channel_server_pairs( + k_channels, self._channelz_stub + ) k_success = 17 k_failed = 19 for i in range(k_success): @@ -280,7 +296,9 @@ async def test_many_subchannels(self): for i in range(k_channels): gc_resp = await self._channelz_stub.GetChannel( channelz_pb2.GetChannelRequest( - channel_id=pairs[i].channel_ref_id)) + channel_id=pairs[i].channel_ref_id + ) + ) # If no call performed in the channel, there shouldn't be any subchannel if gc_resp.channel.data.calls_started == 0: self.assertEqual(len(gc_resp.channel.subchannel_ref), 0) @@ -290,14 +308,23 @@ async def test_many_subchannels(self): self.assertGreater(len(gc_resp.channel.subchannel_ref), 0) gsc_resp = await self._channelz_stub.GetSubchannel( channelz_pb2.GetSubchannelRequest( - subchannel_id=gc_resp.channel.subchannel_ref[0]. - subchannel_id)) - self.assertEqual(gc_resp.channel.data.calls_started, - gsc_resp.subchannel.data.calls_started) - self.assertEqual(gc_resp.channel.data.calls_succeeded, - gsc_resp.subchannel.data.calls_succeeded) - self.assertEqual(gc_resp.channel.data.calls_failed, - gsc_resp.subchannel.data.calls_failed) + subchannel_id=gc_resp.channel.subchannel_ref[ + 0 + ].subchannel_id + ) + ) + self.assertEqual( + gc_resp.channel.data.calls_started, + gsc_resp.subchannel.data.calls_started, + ) + self.assertEqual( + gc_resp.channel.data.calls_succeeded, + gsc_resp.subchannel.data.calls_succeeded, + ) + self.assertEqual( + gc_resp.channel.data.calls_failed, + gsc_resp.subchannel.data.calls_failed, + ) await _destroy_channel_server_pairs(pairs) @@ -320,8 +347,9 @@ async def test_server_call(self): async def test_many_subchannels_and_sockets(self): k_channels = 4 - pairs = await _create_channel_server_pairs(k_channels, - self._channelz_stub) + pairs = await _create_channel_server_pairs( + k_channels, self._channelz_stub + ) k_success = 3 k_failed = 5 for i in range(k_success): @@ -334,7 +362,9 @@ async def test_many_subchannels_and_sockets(self): for i in range(k_channels): gc_resp = await self._channelz_stub.GetChannel( channelz_pb2.GetChannelRequest( - channel_id=pairs[i].channel_ref_id)) + channel_id=pairs[i].channel_ref_id + ) + ) # If no call performed in the channel, there shouldn't be any subchannel if gc_resp.channel.data.calls_started == 0: @@ -345,19 +375,28 @@ async def test_many_subchannels_and_sockets(self): self.assertGreater(len(gc_resp.channel.subchannel_ref), 0) gsc_resp = await self._channelz_stub.GetSubchannel( channelz_pb2.GetSubchannelRequest( - subchannel_id=gc_resp.channel.subchannel_ref[0]. 
- subchannel_id)) + subchannel_id=gc_resp.channel.subchannel_ref[ + 0 + ].subchannel_id + ) + ) self.assertEqual(len(gsc_resp.subchannel.socket_ref), 1) gs_resp = await self._channelz_stub.GetSocket( channelz_pb2.GetSocketRequest( - socket_id=gsc_resp.subchannel.socket_ref[0].socket_id)) - self.assertEqual(gsc_resp.subchannel.data.calls_started, - gs_resp.socket.data.streams_started) + socket_id=gsc_resp.subchannel.socket_ref[0].socket_id + ) + ) + self.assertEqual( + gsc_resp.subchannel.data.calls_started, + gs_resp.socket.data.streams_started, + ) self.assertEqual(0, gs_resp.socket.data.streams_failed) # Calls started == messages sent, only valid for unary calls - self.assertEqual(gsc_resp.subchannel.data.calls_started, - gs_resp.socket.data.messages_sent) + self.assertEqual( + gsc_resp.subchannel.data.calls_started, + gs_resp.socket.data.messages_sent, + ) await _destroy_channel_server_pairs(pairs) @@ -368,7 +407,8 @@ async def test_streaming_rpc(self): await self._send_successful_stream_stream(pairs[0]) gc_resp = await self._channelz_stub.GetChannel( - channelz_pb2.GetChannelRequest(channel_id=pairs[0].channel_ref_id)) + channelz_pb2.GetChannelRequest(channel_id=pairs[0].channel_ref_id) + ) self.assertEqual(gc_resp.channel.data.calls_started, 1) self.assertEqual(gc_resp.channel.data.calls_succeeded, 1) self.assertEqual(gc_resp.channel.data.calls_failed, 0) @@ -378,9 +418,16 @@ async def test_streaming_rpc(self): while True: gsc_resp = await self._channelz_stub.GetSubchannel( channelz_pb2.GetSubchannelRequest( - subchannel_id=gc_resp.channel.subchannel_ref[0]. - subchannel_id)) - if gsc_resp.subchannel.data.calls_started == gsc_resp.subchannel.data.calls_succeeded + gsc_resp.subchannel.data.calls_failed: + subchannel_id=gc_resp.channel.subchannel_ref[ + 0 + ].subchannel_id + ) + ) + if ( + gsc_resp.subchannel.data.calls_started + == gsc_resp.subchannel.data.calls_succeeded + + gsc_resp.subchannel.data.calls_failed + ): break self.assertEqual(gsc_resp.subchannel.data.calls_started, 1) self.assertEqual(gsc_resp.subchannel.data.calls_failed, 0) @@ -391,16 +438,24 @@ async def test_streaming_rpc(self): while True: gs_resp = await self._channelz_stub.GetSocket( channelz_pb2.GetSocketRequest( - socket_id=gsc_resp.subchannel.socket_ref[0].socket_id)) - if gs_resp.socket.data.streams_started == gs_resp.socket.data.streams_succeeded + gs_resp.socket.data.streams_failed: + socket_id=gsc_resp.subchannel.socket_ref[0].socket_id + ) + ) + if ( + gs_resp.socket.data.streams_started + == gs_resp.socket.data.streams_succeeded + + gs_resp.socket.data.streams_failed + ): break self.assertEqual(gs_resp.socket.data.streams_started, 1) self.assertEqual(gs_resp.socket.data.streams_failed, 0) self.assertEqual(gs_resp.socket.data.streams_succeeded, 1) - self.assertEqual(gs_resp.socket.data.messages_sent, - test_constants.STREAM_LENGTH) - self.assertEqual(gs_resp.socket.data.messages_received, - test_constants.STREAM_LENGTH) + self.assertEqual( + gs_resp.socket.data.messages_sent, test_constants.STREAM_LENGTH + ) + self.assertEqual( + gs_resp.socket.data.messages_received, test_constants.STREAM_LENGTH + ) await _destroy_channel_server_pairs(pairs) @@ -416,8 +471,10 @@ async def test_server_sockets(self): self.assertEqual(resp.data.calls_failed, 1) gss_resp = await self._channelz_stub.GetServerSockets( - channelz_pb2.GetServerSocketsRequest(server_id=resp.ref.server_id, - start_socket_id=0)) + channelz_pb2.GetServerSocketsRequest( + server_id=resp.ref.server_id, start_socket_id=0 + ) + ) # If the RPC call 
failed, it will raise a grpc.RpcError # So, if there is no exception raised, considered pass await _destroy_channel_server_pairs(pairs) @@ -430,7 +487,9 @@ async def test_server_listen_sockets(self): gs_resp = await self._channelz_stub.GetSocket( channelz_pb2.GetSocketRequest( - socket_id=resp.listen_socket[0].socket_id)) + socket_id=resp.listen_socket[0].socket_id + ) + ) # If the RPC call failed, it will raise a grpc.RpcError # So, if there is no exception raised, considered pass await _destroy_channel_server_pairs(pairs) @@ -438,31 +497,40 @@ async def test_server_listen_sockets(self): async def test_invalid_query_get_server(self): with self.assertRaises(aio.AioRpcError) as exception_context: await self._channelz_stub.GetServer( - channelz_pb2.GetServerRequest(server_id=_LARGE_UNASSIGNED_ID)) - self.assertEqual(grpc.StatusCode.NOT_FOUND, - exception_context.exception.code()) + channelz_pb2.GetServerRequest(server_id=_LARGE_UNASSIGNED_ID) + ) + self.assertEqual( + grpc.StatusCode.NOT_FOUND, exception_context.exception.code() + ) async def test_invalid_query_get_channel(self): with self.assertRaises(aio.AioRpcError) as exception_context: await self._channelz_stub.GetChannel( - channelz_pb2.GetChannelRequest(channel_id=_LARGE_UNASSIGNED_ID)) - self.assertEqual(grpc.StatusCode.NOT_FOUND, - exception_context.exception.code()) + channelz_pb2.GetChannelRequest(channel_id=_LARGE_UNASSIGNED_ID) + ) + self.assertEqual( + grpc.StatusCode.NOT_FOUND, exception_context.exception.code() + ) async def test_invalid_query_get_subchannel(self): with self.assertRaises(aio.AioRpcError) as exception_context: await self._channelz_stub.GetSubchannel( channelz_pb2.GetSubchannelRequest( - subchannel_id=_LARGE_UNASSIGNED_ID)) - self.assertEqual(grpc.StatusCode.NOT_FOUND, - exception_context.exception.code()) + subchannel_id=_LARGE_UNASSIGNED_ID + ) + ) + self.assertEqual( + grpc.StatusCode.NOT_FOUND, exception_context.exception.code() + ) async def test_invalid_query_get_socket(self): with self.assertRaises(aio.AioRpcError) as exception_context: await self._channelz_stub.GetSocket( - channelz_pb2.GetSocketRequest(socket_id=_LARGE_UNASSIGNED_ID)) - self.assertEqual(grpc.StatusCode.NOT_FOUND, - exception_context.exception.code()) + channelz_pb2.GetSocketRequest(socket_id=_LARGE_UNASSIGNED_ID) + ) + self.assertEqual( + grpc.StatusCode.NOT_FOUND, exception_context.exception.code() + ) async def test_invalid_query_get_server_sockets(self): with self.assertRaises(aio.AioRpcError) as exception_context: @@ -470,11 +538,13 @@ async def test_invalid_query_get_server_sockets(self): channelz_pb2.GetServerSocketsRequest( server_id=_LARGE_UNASSIGNED_ID, start_socket_id=0, - )) - self.assertEqual(grpc.StatusCode.NOT_FOUND, - exception_context.exception.code()) + ) + ) + self.assertEqual( + grpc.StatusCode.NOT_FOUND, exception_context.exception.code() + ) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/health_check/health_servicer_test.py b/src/python/grpcio_tests/tests_aio/health_check/health_servicer_test.py index 7c6776ecd750e..3bb0f6c31c4d4 100644 --- a/src/python/grpcio_tests/tests_aio/health_check/health_servicer_test.py +++ b/src/python/grpcio_tests/tests_aio/health_check/health_servicer_test.py @@ -28,10 +28,10 @@ from tests.unit.framework.common import test_constants from tests_aio.unit._test_base import AioTestBase -_SERVING_SERVICE = 'grpc.test.TestServiceServing' -_UNKNOWN_SERVICE = 
'grpc.test.TestServiceUnknown' -_NOT_SERVING_SERVICE = 'grpc.test.TestServiceNotServing' -_WATCH_SERVICE = 'grpc.test.WatchService' +_SERVING_SERVICE = "grpc.test.TestServiceServing" +_UNKNOWN_SERVICE = "grpc.test.TestServiceUnknown" +_NOT_SERVING_SERVICE = "grpc.test.TestServiceNotServing" +_WATCH_SERVICE = "grpc.test.WatchService" _LARGE_NUMBER_OF_STATUS_CHANGES = 1000 @@ -42,22 +42,25 @@ async def _pipe_to_queue(call, queue): class HealthServicerTest(AioTestBase): - async def setUp(self): self._servicer = health.aio.HealthServicer() - await self._servicer.set(_SERVING_SERVICE, - health_pb2.HealthCheckResponse.SERVING) - await self._servicer.set(_UNKNOWN_SERVICE, - health_pb2.HealthCheckResponse.UNKNOWN) - await self._servicer.set(_NOT_SERVING_SERVICE, - health_pb2.HealthCheckResponse.NOT_SERVING) + await self._servicer.set( + _SERVING_SERVICE, health_pb2.HealthCheckResponse.SERVING + ) + await self._servicer.set( + _UNKNOWN_SERVICE, health_pb2.HealthCheckResponse.UNKNOWN + ) + await self._servicer.set( + _NOT_SERVING_SERVICE, health_pb2.HealthCheckResponse.NOT_SERVING + ) self._server = aio.server() - port = self._server.add_insecure_port('[::]:0') - health_pb2_grpc.add_HealthServicer_to_server(self._servicer, - self._server) + port = self._server.add_insecure_port("[::]:0") + health_pb2_grpc.add_HealthServicer_to_server( + self._servicer, self._server + ) await self._server.start() - self._channel = aio.insecure_channel('localhost:%d' % port) + self._channel = aio.insecure_channel("localhost:%d" % port) self._stub = health_pb2_grpc.HealthStub(self._channel) async def tearDown(self): @@ -82,18 +85,19 @@ async def test_check_unknown_service(self): async def test_check_not_serving_service(self): request = health_pb2.HealthCheckRequest(service=_NOT_SERVING_SERVICE) resp = await self._stub.Check(request) - self.assertEqual(health_pb2.HealthCheckResponse.NOT_SERVING, - resp.status) + self.assertEqual( + health_pb2.HealthCheckResponse.NOT_SERVING, resp.status + ) async def test_check_not_found_service(self): - request = health_pb2.HealthCheckRequest(service='not-found') + request = health_pb2.HealthCheckRequest(service="not-found") with self.assertRaises(aio.AioRpcError) as context: await self._stub.Check(request) self.assertEqual(grpc.StatusCode.NOT_FOUND, context.exception.code()) async def test_health_service_name(self): - self.assertEqual(health.SERVICE_NAME, 'grpc.health.v1.Health') + self.assertEqual(health.SERVICE_NAME, "grpc.health.v1.Health") async def test_watch_empty_service(self): request = health_pb2.HealthCheckRequest(service=health.OVERALL_HEALTH) @@ -102,8 +106,9 @@ async def test_watch_empty_service(self): queue = asyncio.Queue() task = self.loop.create_task(_pipe_to_queue(call, queue)) - self.assertEqual(health_pb2.HealthCheckResponse.SERVING, - (await queue.get()).status) + self.assertEqual( + health_pb2.HealthCheckResponse.SERVING, (await queue.get()).status + ) call.cancel() @@ -118,18 +123,25 @@ async def test_watch_new_service(self): queue = asyncio.Queue() task = self.loop.create_task(_pipe_to_queue(call, queue)) - self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, - (await queue.get()).status) - - await self._servicer.set(_WATCH_SERVICE, - health_pb2.HealthCheckResponse.SERVING) - self.assertEqual(health_pb2.HealthCheckResponse.SERVING, - (await queue.get()).status) - - await self._servicer.set(_WATCH_SERVICE, - health_pb2.HealthCheckResponse.NOT_SERVING) - self.assertEqual(health_pb2.HealthCheckResponse.NOT_SERVING, - (await queue.get()).status) + 
self.assertEqual( + health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, + (await queue.get()).status, + ) + + await self._servicer.set( + _WATCH_SERVICE, health_pb2.HealthCheckResponse.SERVING + ) + self.assertEqual( + health_pb2.HealthCheckResponse.SERVING, (await queue.get()).status + ) + + await self._servicer.set( + _WATCH_SERVICE, health_pb2.HealthCheckResponse.NOT_SERVING + ) + self.assertEqual( + health_pb2.HealthCheckResponse.NOT_SERVING, + (await queue.get()).status, + ) call.cancel() @@ -144,11 +156,14 @@ async def test_watch_service_isolation(self): queue = asyncio.Queue() task = self.loop.create_task(_pipe_to_queue(call, queue)) - self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, - (await queue.get()).status) + self.assertEqual( + health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, + (await queue.get()).status, + ) - await self._servicer.set('some-other-service', - health_pb2.HealthCheckResponse.SERVING) + await self._servicer.set( + "some-other-service", health_pb2.HealthCheckResponse.SERVING + ) # The change of health status in other service should be isolated. # Hence, no additional notification should be observed. with self.assertRaises(asyncio.TimeoutError): @@ -170,17 +185,24 @@ async def test_two_watchers(self): task1 = self.loop.create_task(_pipe_to_queue(call1, queue1)) task2 = self.loop.create_task(_pipe_to_queue(call2, queue2)) - self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, - (await queue1.get()).status) - self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, - (await queue2.get()).status) - - await self._servicer.set(_WATCH_SERVICE, - health_pb2.HealthCheckResponse.SERVING) - self.assertEqual(health_pb2.HealthCheckResponse.SERVING, - (await queue1.get()).status) - self.assertEqual(health_pb2.HealthCheckResponse.SERVING, - (await queue2.get()).status) + self.assertEqual( + health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, + (await queue1.get()).status, + ) + self.assertEqual( + health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, + (await queue2.get()).status, + ) + + await self._servicer.set( + _WATCH_SERVICE, health_pb2.HealthCheckResponse.SERVING + ) + self.assertEqual( + health_pb2.HealthCheckResponse.SERVING, (await queue1.get()).status + ) + self.assertEqual( + health_pb2.HealthCheckResponse.SERVING, (await queue2.get()).status + ) call1.cancel() call2.cancel() @@ -200,22 +222,27 @@ async def test_cancelled_watch_removed_from_watch_list(self): queue = asyncio.Queue() task = self.loop.create_task(_pipe_to_queue(call, queue)) - self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, - (await queue.get()).status) + self.assertEqual( + health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, + (await queue.get()).status, + ) call.cancel() - await self._servicer.set(_WATCH_SERVICE, - health_pb2.HealthCheckResponse.SERVING) + await self._servicer.set( + _WATCH_SERVICE, health_pb2.HealthCheckResponse.SERVING + ) with self.assertRaises(asyncio.CancelledError): await task # Wait for the serving coroutine to process client cancellation. 
timeout = time.monotonic() + test_constants.TIME_ALLOWANCE - while (time.monotonic() < timeout and self._servicer._server_watchers): + while time.monotonic() < timeout and self._servicer._server_watchers: await asyncio.sleep(1) - self.assertFalse(self._servicer._server_watchers, - 'There should not be any watcher left') + self.assertFalse( + self._servicer._server_watchers, + "There should not be any watcher left", + ) self.assertTrue(queue.empty()) async def test_graceful_shutdown(self): @@ -224,20 +251,25 @@ async def test_graceful_shutdown(self): queue = asyncio.Queue() task = self.loop.create_task(_pipe_to_queue(call, queue)) - self.assertEqual(health_pb2.HealthCheckResponse.SERVING, - (await queue.get()).status) + self.assertEqual( + health_pb2.HealthCheckResponse.SERVING, (await queue.get()).status + ) await self._servicer.enter_graceful_shutdown() - self.assertEqual(health_pb2.HealthCheckResponse.NOT_SERVING, - (await queue.get()).status) + self.assertEqual( + health_pb2.HealthCheckResponse.NOT_SERVING, + (await queue.get()).status, + ) # This should be a no-op. - await self._servicer.set(health.OVERALL_HEALTH, - health_pb2.HealthCheckResponse.SERVING) + await self._servicer.set( + health.OVERALL_HEALTH, health_pb2.HealthCheckResponse.SERVING + ) resp = await self._stub.Check(request) - self.assertEqual(health_pb2.HealthCheckResponse.NOT_SERVING, - resp.status) + self.assertEqual( + health_pb2.HealthCheckResponse.NOT_SERVING, resp.status + ) call.cancel() @@ -252,8 +284,10 @@ async def test_no_duplicate_status(self): queue = asyncio.Queue() task = self.loop.create_task(_pipe_to_queue(call, queue)) - self.assertEqual(health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, - (await queue.get()).status) + self.assertEqual( + health_pb2.HealthCheckResponse.SERVICE_UNKNOWN, + (await queue.get()).status, + ) last_status = health_pb2.HealthCheckResponse.SERVICE_UNKNOWN for _ in range(_LARGE_NUMBER_OF_STATUS_CHANGES): @@ -275,6 +309,6 @@ async def test_no_duplicate_status(self): self.assertTrue(queue.empty()) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/interop/client.py b/src/python/grpcio_tests/tests_aio/interop/client.py index a4c5e12cedab2..f9f7d370e5858 100644 --- a/src/python/grpcio_tests/tests_aio/interop/client.py +++ b/src/python/grpcio_tests/tests_aio/interop/client.py @@ -28,11 +28,17 @@ def _create_channel(args): - target = f'{args.server_host}:{args.server_port}' + target = f"{args.server_host}:{args.server_port}" - if args.use_tls or args.use_alts or args.custom_credentials_type is not None: - channel_credentials, options = interop_client_lib.get_secure_channel_parameters( - args) + if ( + args.use_tls + or args.use_alts + or args.custom_credentials_type is not None + ): + ( + channel_credentials, + options, + ) = interop_client_lib.get_secure_channel_parameters(args) return aio.secure_channel(target, channel_credentials, options) else: return aio.insecure_channel(target) @@ -47,7 +53,6 @@ def _test_case_from_arg(test_case_arg): async def test_interoperability(): - args = interop_client_lib.parse_interop_client_args() channel = _create_channel(args) stub = interop_client_lib.create_stub(channel, args) @@ -55,7 +60,7 @@ async def test_interoperability(): await methods.test_interoperability(test_case, stub, args) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) asyncio.get_event_loop().set_debug(True) 
asyncio.get_event_loop().run_until_complete(test_interoperability()) diff --git a/src/python/grpcio_tests/tests_aio/interop/local_interop_test.py b/src/python/grpcio_tests/tests_aio/interop/local_interop_test.py index 0db15be3a94c7..a0347670edf18 100644 --- a/src/python/grpcio_tests/tests_aio/interop/local_interop_test.py +++ b/src/python/grpcio_tests/tests_aio/interop/local_interop_test.py @@ -25,7 +25,7 @@ from tests_aio.unit._test_base import AioTestBase from tests_aio.unit._test_server import start_test_server -_SERVER_HOST_OVERRIDE = 'foo.test.google.fr' +_SERVER_HOST_OVERRIDE = "foo.test.google.fr" class InteropTestCaseMixin: @@ -34,67 +34,81 @@ class InteropTestCaseMixin: This class must be mixed in with unittest.TestCase and a class that defines setUp and tearDown methods that manage a stub attribute. """ + _stub: test_pb2_grpc.TestServiceStub async def test_empty_unary(self): - await methods.test_interoperability(methods.TestCase.EMPTY_UNARY, - self._stub, None) + await methods.test_interoperability( + methods.TestCase.EMPTY_UNARY, self._stub, None + ) async def test_large_unary(self): - await methods.test_interoperability(methods.TestCase.LARGE_UNARY, - self._stub, None) + await methods.test_interoperability( + methods.TestCase.LARGE_UNARY, self._stub, None + ) async def test_server_streaming(self): - await methods.test_interoperability(methods.TestCase.SERVER_STREAMING, - self._stub, None) + await methods.test_interoperability( + methods.TestCase.SERVER_STREAMING, self._stub, None + ) async def test_client_streaming(self): - await methods.test_interoperability(methods.TestCase.CLIENT_STREAMING, - self._stub, None) + await methods.test_interoperability( + methods.TestCase.CLIENT_STREAMING, self._stub, None + ) async def test_ping_pong(self): - await methods.test_interoperability(methods.TestCase.PING_PONG, - self._stub, None) + await methods.test_interoperability( + methods.TestCase.PING_PONG, self._stub, None + ) async def test_cancel_after_begin(self): - await methods.test_interoperability(methods.TestCase.CANCEL_AFTER_BEGIN, - self._stub, None) + await methods.test_interoperability( + methods.TestCase.CANCEL_AFTER_BEGIN, self._stub, None + ) async def test_cancel_after_first_response(self): await methods.test_interoperability( - methods.TestCase.CANCEL_AFTER_FIRST_RESPONSE, self._stub, None) + methods.TestCase.CANCEL_AFTER_FIRST_RESPONSE, self._stub, None + ) async def test_timeout_on_sleeping_server(self): await methods.test_interoperability( - methods.TestCase.TIMEOUT_ON_SLEEPING_SERVER, self._stub, None) + methods.TestCase.TIMEOUT_ON_SLEEPING_SERVER, self._stub, None + ) async def test_empty_stream(self): - await methods.test_interoperability(methods.TestCase.EMPTY_STREAM, - self._stub, None) + await methods.test_interoperability( + methods.TestCase.EMPTY_STREAM, self._stub, None + ) async def test_status_code_and_message(self): await methods.test_interoperability( - methods.TestCase.STATUS_CODE_AND_MESSAGE, self._stub, None) + methods.TestCase.STATUS_CODE_AND_MESSAGE, self._stub, None + ) async def test_unimplemented_method(self): await methods.test_interoperability( - methods.TestCase.UNIMPLEMENTED_METHOD, self._stub, None) + methods.TestCase.UNIMPLEMENTED_METHOD, self._stub, None + ) async def test_unimplemented_service(self): await methods.test_interoperability( - methods.TestCase.UNIMPLEMENTED_SERVICE, self._stub, None) + methods.TestCase.UNIMPLEMENTED_SERVICE, self._stub, None + ) async def test_custom_metadata(self): - await 
methods.test_interoperability(methods.TestCase.CUSTOM_METADATA, - self._stub, None) + await methods.test_interoperability( + methods.TestCase.CUSTOM_METADATA, self._stub, None + ) async def test_special_status_message(self): await methods.test_interoperability( - methods.TestCase.SPECIAL_STATUS_MESSAGE, self._stub, None) + methods.TestCase.SPECIAL_STATUS_MESSAGE, self._stub, None + ) class InsecureLocalInteropTest(InteropTestCaseMixin, AioTestBase): - async def setUp(self): address, self._server = await start_test_server() self._channel = aio.insecure_channel(address) @@ -106,22 +120,26 @@ async def tearDown(self): class SecureLocalInteropTest(InteropTestCaseMixin, AioTestBase): - async def setUp(self): - server_credentials = grpc.ssl_server_credentials([ - (resources.private_key(), resources.certificate_chain()) - ]) + server_credentials = grpc.ssl_server_credentials( + [(resources.private_key(), resources.certificate_chain())] + ) channel_credentials = grpc.ssl_channel_credentials( - resources.test_root_certificates()) - channel_options = (( - 'grpc.ssl_target_name_override', - _SERVER_HOST_OVERRIDE, - ),) + resources.test_root_certificates() + ) + channel_options = ( + ( + "grpc.ssl_target_name_override", + _SERVER_HOST_OVERRIDE, + ), + ) address, self._server = await start_test_server( - secure=True, server_credentials=server_credentials) - self._channel = aio.secure_channel(address, channel_credentials, - channel_options) + secure=True, server_credentials=server_credentials + ) + self._channel = aio.secure_channel( + address, channel_credentials, channel_options + ) self._stub = test_pb2_grpc.TestServiceStub(self._channel) async def tearDown(self): @@ -129,6 +147,6 @@ async def tearDown(self): await self._server.stop(None) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.INFO) unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/interop/methods.py b/src/python/grpcio_tests/tests_aio/interop/methods.py index 6524a5ed0b010..ef887433fb28d 100644 --- a/src/python/grpcio_tests/tests_aio/interop/methods.py +++ b/src/python/grpcio_tests/tests_aio/interop/methods.py @@ -40,51 +40,65 @@ _TRAILING_METADATA_KEY = "x-grpc-test-echo-trailing-bin" -async def _expect_status_code(call: aio.Call, - expected_code: grpc.StatusCode) -> None: +async def _expect_status_code( + call: aio.Call, expected_code: grpc.StatusCode +) -> None: code = await call.code() if code != expected_code: - raise ValueError('expected code %s, got %s' % - (expected_code, await call.code())) + raise ValueError( + "expected code %s, got %s" % (expected_code, await call.code()) + ) async def _expect_status_details(call: aio.Call, expected_details: str) -> None: details = await call.details() if details != expected_details: - raise ValueError('expected message %s, got %s' % - (expected_details, await call.details())) + raise ValueError( + "expected message %s, got %s" + % (expected_details, await call.details()) + ) -async def _validate_status_code_and_details(call: aio.Call, - expected_code: grpc.StatusCode, - expected_details: str) -> None: +async def _validate_status_code_and_details( + call: aio.Call, expected_code: grpc.StatusCode, expected_details: str +) -> None: await _expect_status_code(call, expected_code) await _expect_status_details(call, expected_details) -def _validate_payload_type_and_length(response: Union[ - messages_pb2.SimpleResponse, messages_pb2.StreamingOutputCallResponse], - expected_type: Any, - expected_length: int) -> None: +def 
_validate_payload_type_and_length( + response: Union[ + messages_pb2.SimpleResponse, messages_pb2.StreamingOutputCallResponse + ], + expected_type: Any, + expected_length: int, +) -> None: if response.payload.type is not expected_type: - raise ValueError('expected payload type %s, got %s' % - (expected_type, type(response.payload.type))) + raise ValueError( + "expected payload type %s, got %s" + % (expected_type, type(response.payload.type)) + ) elif len(response.payload.body) != expected_length: - raise ValueError('expected payload body size %d, got %d' % - (expected_length, len(response.payload.body))) + raise ValueError( + "expected payload body size %d, got %d" + % (expected_length, len(response.payload.body)) + ) async def _large_unary_common_behavior( - stub: test_pb2_grpc.TestServiceStub, fill_username: bool, - fill_oauth_scope: bool, call_credentials: Optional[grpc.CallCredentials] + stub: test_pb2_grpc.TestServiceStub, + fill_username: bool, + fill_oauth_scope: bool, + call_credentials: Optional[grpc.CallCredentials], ) -> messages_pb2.SimpleResponse: size = 314159 request = messages_pb2.SimpleRequest( response_type=messages_pb2.COMPRESSABLE, response_size=size, - payload=messages_pb2.Payload(body=b'\x00' * 271828), + payload=messages_pb2.Payload(body=b"\x00" * 271828), fill_username=fill_username, - fill_oauth_scope=fill_oauth_scope) + fill_oauth_scope=fill_oauth_scope, + ) response = await stub.UnaryCall(request, credentials=call_credentials) _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, size) return response @@ -93,8 +107,9 @@ async def _large_unary_common_behavior( async def _empty_unary(stub: test_pb2_grpc.TestServiceStub) -> None: response = await stub.EmptyCall(empty_pb2.Empty()) if not isinstance(response, empty_pb2.Empty): - raise TypeError('response is of type "%s", not empty_pb2.Empty!' % - type(response)) + raise TypeError( + 'response is of type "%s", not empty_pb2.Empty!' % type(response) + ) async def _large_unary(stub: test_pb2_grpc.TestServiceStub) -> None: @@ -112,12 +127,14 @@ async def _client_streaming(stub: test_pb2_grpc.TestServiceStub) -> None: async def request_gen(): for size in payload_body_sizes: yield messages_pb2.StreamingInputCallRequest( - payload=messages_pb2.Payload(body=b'\x00' * size)) + payload=messages_pb2.Payload(body=b"\x00" * size) + ) response = await stub.StreamingInputCall(request_gen()) if response.aggregated_payload_size != sum(payload_body_sizes): - raise ValueError('incorrect size %d!' % - response.aggregated_payload_size) + raise ValueError( + "incorrect size %d!" 
% response.aggregated_payload_size + ) async def _server_streaming(stub: test_pb2_grpc.TestServiceStub) -> None: @@ -135,12 +152,14 @@ async def _server_streaming(stub: test_pb2_grpc.TestServiceStub) -> None: messages_pb2.ResponseParameters(size=sizes[1]), messages_pb2.ResponseParameters(size=sizes[2]), messages_pb2.ResponseParameters(size=sizes[3]), - )) + ), + ) call = stub.StreamingOutputCall(request) for size in sizes: response = await call.read() - _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, - size) + _validate_payload_type_and_length( + response, messages_pb2.COMPRESSABLE, size + ) async def _ping_pong(stub: test_pb2_grpc.TestServiceStub) -> None: @@ -158,30 +177,34 @@ async def _ping_pong(stub: test_pb2_grpc.TestServiceStub) -> None: ) call = stub.FullDuplexCall() - for response_size, payload_size in zip(request_response_sizes, - request_payload_sizes): + for response_size, payload_size in zip( + request_response_sizes, request_payload_sizes + ): request = messages_pb2.StreamingOutputCallRequest( response_type=messages_pb2.COMPRESSABLE, - response_parameters=(messages_pb2.ResponseParameters( - size=response_size),), - payload=messages_pb2.Payload(body=b'\x00' * payload_size)) + response_parameters=( + messages_pb2.ResponseParameters(size=response_size), + ), + payload=messages_pb2.Payload(body=b"\x00" * payload_size), + ) await call.write(request) response = await call.read() - _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, - response_size) + _validate_payload_type_and_length( + response, messages_pb2.COMPRESSABLE, response_size + ) await call.done_writing() - await _validate_status_code_and_details(call, grpc.StatusCode.OK, '') + await _validate_status_code_and_details(call, grpc.StatusCode.OK, "") async def _cancel_after_begin(stub: test_pb2_grpc.TestServiceStub): call = stub.StreamingInputCall() call.cancel() if not call.cancelled(): - raise ValueError('expected cancelled method to return True') + raise ValueError("expected cancelled method to return True") code = await call.code() if code is not grpc.StatusCode.CANCELLED: - raise ValueError('expected status code CANCELLED') + raise ValueError("expected status code CANCELLED") async def _cancel_after_first_response(stub: test_pb2_grpc.TestServiceStub): @@ -204,9 +227,11 @@ async def _cancel_after_first_response(stub: test_pb2_grpc.TestServiceStub): payload_size = request_payload_sizes[0] request = messages_pb2.StreamingOutputCallRequest( response_type=messages_pb2.COMPRESSABLE, - response_parameters=(messages_pb2.ResponseParameters( - size=response_size),), - payload=messages_pb2.Payload(body=b'\x00' * payload_size)) + response_parameters=( + messages_pb2.ResponseParameters(size=response_size), + ), + payload=messages_pb2.Payload(body=b"\x00" * payload_size), + ) await call.write(request) await call.read() @@ -218,7 +243,7 @@ async def _cancel_after_first_response(stub: test_pb2_grpc.TestServiceStub): except asyncio.CancelledError: assert await call.code() is grpc.StatusCode.CANCELLED else: - raise ValueError('expected call to be cancelled') + raise ValueError("expected call to be cancelled") async def _timeout_on_sleeping_server(stub: test_pb2_grpc.TestServiceStub): @@ -229,9 +254,13 @@ async def _timeout_on_sleeping_server(stub: test_pb2_grpc.TestServiceStub): request = messages_pb2.StreamingOutputCallRequest( response_type=messages_pb2.COMPRESSABLE, - payload=messages_pb2.Payload(body=b'\x00' * request_payload_size), - response_parameters=(messages_pb2.ResponseParameters( - 
interval_us=int(time_limit.total_seconds() * 2 * 10**6)),)) + payload=messages_pb2.Payload(body=b"\x00" * request_payload_size), + response_parameters=( + messages_pb2.ResponseParameters( + interval_us=int(time_limit.total_seconds() * 2 * 10**6) + ), + ), + ) await call.write(request) await call.done_writing() try: @@ -240,7 +269,7 @@ async def _timeout_on_sleeping_server(stub: test_pb2_grpc.TestServiceStub): if rpc_error.code() is not grpc.StatusCode.DEADLINE_EXCEEDED: raise else: - raise ValueError('expected call to exceed deadline') + raise ValueError("expected call to exceed deadline") async def _empty_stream(stub: test_pb2_grpc.TestServiceStub): @@ -250,16 +279,18 @@ async def _empty_stream(stub: test_pb2_grpc.TestServiceStub): async def _status_code_and_message(stub: test_pb2_grpc.TestServiceStub): - details = 'test status message' + details = "test status message" status = grpc.StatusCode.UNKNOWN # code = 2 # Test with a UnaryCall request = messages_pb2.SimpleRequest( response_type=messages_pb2.COMPRESSABLE, response_size=1, - payload=messages_pb2.Payload(body=b'\x00'), - response_status=messages_pb2.EchoStatus(code=status.value[0], - message=details)) + payload=messages_pb2.Payload(body=b"\x00"), + response_status=messages_pb2.EchoStatus( + code=status.value[0], message=details + ), + ) call = stub.UnaryCall(request) await _validate_status_code_and_details(call, status, details) @@ -268,9 +299,11 @@ async def _status_code_and_message(stub: test_pb2_grpc.TestServiceStub): request = messages_pb2.StreamingOutputCallRequest( response_type=messages_pb2.COMPRESSABLE, response_parameters=(messages_pb2.ResponseParameters(size=1),), - payload=messages_pb2.Payload(body=b'\x00'), - response_status=messages_pb2.EchoStatus(code=status.value[0], - message=details)) + payload=messages_pb2.Payload(body=b"\x00"), + response_status=messages_pb2.EchoStatus( + code=status.value[0], message=details + ), + ) await call.write(request) # sends the initial request. 
await call.done_writing() try: @@ -301,21 +334,30 @@ async def _custom_metadata(stub: test_pb2_grpc.TestServiceStub): async def _validate_metadata(call): initial_metadata = await call.initial_metadata() if initial_metadata[_INITIAL_METADATA_KEY] != initial_metadata_value: - raise ValueError('expected initial metadata %s, got %s' % - (initial_metadata_value, - initial_metadata[_INITIAL_METADATA_KEY])) + raise ValueError( + "expected initial metadata %s, got %s" + % ( + initial_metadata_value, + initial_metadata[_INITIAL_METADATA_KEY], + ) + ) trailing_metadata = await call.trailing_metadata() if trailing_metadata[_TRAILING_METADATA_KEY] != trailing_metadata_value: - raise ValueError('expected trailing metadata %s, got %s' % - (trailing_metadata_value, - trailing_metadata[_TRAILING_METADATA_KEY])) + raise ValueError( + "expected trailing metadata %s, got %s" + % ( + trailing_metadata_value, + trailing_metadata[_TRAILING_METADATA_KEY], + ) + ) # Testing with UnaryCall request = messages_pb2.SimpleRequest( response_type=messages_pb2.COMPRESSABLE, response_size=1, - payload=messages_pb2.Payload(body=b'\x00')) + payload=messages_pb2.Payload(body=b"\x00"), + ) call = stub.UnaryCall(request, metadata=metadata) await _validate_metadata(call) @@ -323,97 +365,116 @@ async def _validate_metadata(call): call = stub.FullDuplexCall(metadata=metadata) request = messages_pb2.StreamingOutputCallRequest( response_type=messages_pb2.COMPRESSABLE, - response_parameters=(messages_pb2.ResponseParameters(size=1),)) + response_parameters=(messages_pb2.ResponseParameters(size=1),), + ) await call.write(request) await call.read() await call.done_writing() await _validate_metadata(call) -async def _compute_engine_creds(stub: test_pb2_grpc.TestServiceStub, - args: argparse.Namespace): +async def _compute_engine_creds( + stub: test_pb2_grpc.TestServiceStub, args: argparse.Namespace +): response = await _large_unary_common_behavior(stub, True, True, None) if args.default_service_account != response.username: - raise ValueError('expected username %s, got %s' % - (args.default_service_account, response.username)) + raise ValueError( + "expected username %s, got %s" + % (args.default_service_account, response.username) + ) -async def _oauth2_auth_token(stub: test_pb2_grpc.TestServiceStub, - args: argparse.Namespace): +async def _oauth2_auth_token( + stub: test_pb2_grpc.TestServiceStub, args: argparse.Namespace +): json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS] - wanted_email = json.load(open(json_key_filename, 'r'))['client_email'] + wanted_email = json.load(open(json_key_filename, "r"))["client_email"] response = await _large_unary_common_behavior(stub, True, True, None) if wanted_email != response.username: - raise ValueError('expected username %s, got %s' % - (wanted_email, response.username)) + raise ValueError( + "expected username %s, got %s" % (wanted_email, response.username) + ) if args.oauth_scope.find(response.oauth_scope) == -1: raise ValueError( 'expected to find oauth scope "{}" in received "{}"'.format( - response.oauth_scope, args.oauth_scope)) + response.oauth_scope, args.oauth_scope + ) + ) async def _jwt_token_creds(stub: test_pb2_grpc.TestServiceStub): json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS] - wanted_email = json.load(open(json_key_filename, 'r'))['client_email'] + wanted_email = json.load(open(json_key_filename, "r"))["client_email"] response = await _large_unary_common_behavior(stub, True, False, None) if wanted_email != response.username: - 
raise ValueError('expected username %s, got %s' % - (wanted_email, response.username)) + raise ValueError( + "expected username %s, got %s" % (wanted_email, response.username) + ) -async def _per_rpc_creds(stub: test_pb2_grpc.TestServiceStub, - args: argparse.Namespace): +async def _per_rpc_creds( + stub: test_pb2_grpc.TestServiceStub, args: argparse.Namespace +): json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS] - wanted_email = json.load(open(json_key_filename, 'r'))['client_email'] + wanted_email = json.load(open(json_key_filename, "r"))["client_email"] google_credentials, unused_project_id = google_auth.default( - scopes=[args.oauth_scope]) + scopes=[args.oauth_scope] + ) call_credentials = grpc.metadata_call_credentials( google_auth_transport_grpc.AuthMetadataPlugin( credentials=google_credentials, - request=google_auth_transport_requests.Request())) - response = await _large_unary_common_behavior(stub, True, False, - call_credentials) + request=google_auth_transport_requests.Request(), + ) + ) + response = await _large_unary_common_behavior( + stub, True, False, call_credentials + ) if wanted_email != response.username: - raise ValueError('expected username %s, got %s' % - (wanted_email, response.username)) + raise ValueError( + "expected username %s, got %s" % (wanted_email, response.username) + ) async def _special_status_message(stub: test_pb2_grpc.TestServiceStub): - details = b'\t\ntest with whitespace\r\nand Unicode BMP \xe2\x98\xba and non-BMP \xf0\x9f\x98\x88\t\n'.decode( - 'utf-8') + details = ( + b"\t\ntest with whitespace\r\nand Unicode BMP \xe2\x98\xba and non-BMP" + b" \xf0\x9f\x98\x88\t\n".decode("utf-8") + ) status = grpc.StatusCode.UNKNOWN # code = 2 # Test with a UnaryCall request = messages_pb2.SimpleRequest( response_type=messages_pb2.COMPRESSABLE, response_size=1, - payload=messages_pb2.Payload(body=b'\x00'), - response_status=messages_pb2.EchoStatus(code=status.value[0], - message=details)) + payload=messages_pb2.Payload(body=b"\x00"), + response_status=messages_pb2.EchoStatus( + code=status.value[0], message=details + ), + ) call = stub.UnaryCall(request) await _validate_status_code_and_details(call, status, details) @enum.unique class TestCase(enum.Enum): - EMPTY_UNARY = 'empty_unary' - LARGE_UNARY = 'large_unary' - SERVER_STREAMING = 'server_streaming' - CLIENT_STREAMING = 'client_streaming' - PING_PONG = 'ping_pong' - CANCEL_AFTER_BEGIN = 'cancel_after_begin' - CANCEL_AFTER_FIRST_RESPONSE = 'cancel_after_first_response' - TIMEOUT_ON_SLEEPING_SERVER = 'timeout_on_sleeping_server' - EMPTY_STREAM = 'empty_stream' - STATUS_CODE_AND_MESSAGE = 'status_code_and_message' - UNIMPLEMENTED_METHOD = 'unimplemented_method' - UNIMPLEMENTED_SERVICE = 'unimplemented_service' + EMPTY_UNARY = "empty_unary" + LARGE_UNARY = "large_unary" + SERVER_STREAMING = "server_streaming" + CLIENT_STREAMING = "client_streaming" + PING_PONG = "ping_pong" + CANCEL_AFTER_BEGIN = "cancel_after_begin" + CANCEL_AFTER_FIRST_RESPONSE = "cancel_after_first_response" + TIMEOUT_ON_SLEEPING_SERVER = "timeout_on_sleeping_server" + EMPTY_STREAM = "empty_stream" + STATUS_CODE_AND_MESSAGE = "status_code_and_message" + UNIMPLEMENTED_METHOD = "unimplemented_method" + UNIMPLEMENTED_SERVICE = "unimplemented_service" CUSTOM_METADATA = "custom_metadata" - COMPUTE_ENGINE_CREDS = 'compute_engine_creds' - OAUTH2_AUTH_TOKEN = 'oauth2_auth_token' - JWT_TOKEN_CREDS = 'jwt_token_creds' - PER_RPC_CREDS = 'per_rpc_creds' - SPECIAL_STATUS_MESSAGE = 'special_status_message' + COMPUTE_ENGINE_CREDS = 
"compute_engine_creds" + OAUTH2_AUTH_TOKEN = "oauth2_auth_token" + JWT_TOKEN_CREDS = "jwt_token_creds" + PER_RPC_CREDS = "per_rpc_creds" + SPECIAL_STATUS_MESSAGE = "special_status_message" _TEST_CASE_IMPLEMENTATION_MAPPING = { @@ -439,9 +500,10 @@ class TestCase(enum.Enum): async def test_interoperability( - case: TestCase, - stub: test_pb2_grpc.TestServiceStub, - args: Optional[argparse.Namespace] = None) -> None: + case: TestCase, + stub: test_pb2_grpc.TestServiceStub, + args: Optional[argparse.Namespace] = None, +) -> None: method = _TEST_CASE_IMPLEMENTATION_MAPPING.get(case) if method is None: raise NotImplementedError(f'Test case "{case}" not implemented!') @@ -453,6 +515,6 @@ async def test_interoperability( if args is not None: await method(stub, args) else: - raise ValueError(f'Failed to run case [{case}]: args is None') + raise ValueError(f"Failed to run case [{case}]: args is None") else: - raise ValueError(f'Invalid number of parameters [{num_params}]') + raise ValueError(f"Invalid number of parameters [{num_params}]") diff --git a/src/python/grpcio_tests/tests_aio/interop/server.py b/src/python/grpcio_tests/tests_aio/interop/server.py index e40c831a876fe..7786a4206a6d6 100644 --- a/src/python/grpcio_tests/tests_aio/interop/server.py +++ b/src/python/grpcio_tests/tests_aio/interop/server.py @@ -33,17 +33,18 @@ async def serve(): if args.use_tls or args.use_alts: credentials = interop_server_lib.get_server_credentials(args.use_tls) address, server = await _test_server.start_test_server( - port=args.port, secure=True, server_credentials=credentials) + port=args.port, secure=True, server_credentials=credentials + ) else: address, server = await _test_server.start_test_server( port=args.port, secure=False, ) - _LOGGER.info('Server serving at %s', address) + _LOGGER.info("Server serving at %s", address) await server.wait_for_termination() - _LOGGER.info('Server stopped; exiting.') + _LOGGER.info("Server stopped; exiting.") -if __name__ == '__main__': +if __name__ == "__main__": asyncio.get_event_loop().run_until_complete(serve()) diff --git a/src/python/grpcio_tests/tests_aio/reflection/reflection_servicer_test.py b/src/python/grpcio_tests/tests_aio/reflection/reflection_servicer_test.py index 888ec448fa5cd..de89aca588285 100644 --- a/src/python/grpcio_tests/tests_aio/reflection/reflection_servicer_test.py +++ b/src/python/grpcio_tests/tests_aio/reflection/reflection_servicer_test.py @@ -28,11 +28,18 @@ from src.proto.grpc.testing.proto2 import empty2_pb2 from tests_aio.unit._test_base import AioTestBase -_EMPTY_PROTO_FILE_NAME = 'src/proto/grpc/testing/empty.proto' -_EMPTY_PROTO_SYMBOL_NAME = 'grpc.testing.Empty' -_SERVICE_NAMES = ('Angstrom', 'Bohr', 'Curie', 'Dyson', 'Einstein', 'Feynman', - 'Galilei') -_EMPTY_EXTENSIONS_SYMBOL_NAME = 'grpc.testing.proto2.EmptyWithExtensions' +_EMPTY_PROTO_FILE_NAME = "src/proto/grpc/testing/empty.proto" +_EMPTY_PROTO_SYMBOL_NAME = "grpc.testing.Empty" +_SERVICE_NAMES = ( + "Angstrom", + "Bohr", + "Curie", + "Dyson", + "Einstein", + "Feynman", + "Galilei", +) +_EMPTY_EXTENSIONS_SYMBOL_NAME = "grpc.testing.proto2.EmptyWithExtensions" _EMPTY_EXTENSIONS_NUMBERS = ( 124, 125, @@ -49,14 +56,13 @@ def _file_descriptor_to_proto(descriptor): class ReflectionServicerTest(AioTestBase): - async def setUp(self): self._server = aio.server() reflection.enable_server_reflection(_SERVICE_NAMES, self._server) - port = self._server.add_insecure_port('[::]:0') + port = self._server.add_insecure_port("[::]:0") await self._server.start() - self._channel = 
aio.insecure_channel('localhost:%d' % port) + self._channel = aio.insecure_channel("localhost:%d" % port) self._stub = reflection_pb2_grpc.ServerReflectionStub(self._channel) async def tearDown(self): @@ -66,34 +72,41 @@ async def tearDown(self): async def test_file_by_name(self): requests = ( reflection_pb2.ServerReflectionRequest( - file_by_filename=_EMPTY_PROTO_FILE_NAME), + file_by_filename=_EMPTY_PROTO_FILE_NAME + ), reflection_pb2.ServerReflectionRequest( - file_by_filename='i-donut-exist'), + file_by_filename="i-donut-exist" + ), ) responses = [] async for response in self._stub.ServerReflectionInfo(iter(requests)): responses.append(response) expected_responses = ( reflection_pb2.ServerReflectionResponse( - valid_host='', + valid_host="", file_descriptor_response=reflection_pb2.FileDescriptorResponse( file_descriptor_proto=( - _file_descriptor_to_proto(empty_pb2.DESCRIPTOR),))), + _file_descriptor_to_proto(empty_pb2.DESCRIPTOR), + ) + ), + ), reflection_pb2.ServerReflectionResponse( - valid_host='', + valid_host="", error_response=reflection_pb2.ErrorResponse( error_code=grpc.StatusCode.NOT_FOUND.value[0], error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(), - )), + ), + ), ) self.assertSequenceEqual(expected_responses, responses) async def test_file_by_symbol(self): requests = ( reflection_pb2.ServerReflectionRequest( - file_containing_symbol=_EMPTY_PROTO_SYMBOL_NAME), + file_containing_symbol=_EMPTY_PROTO_SYMBOL_NAME + ), reflection_pb2.ServerReflectionRequest( - file_containing_symbol='i.donut.exist.co.uk.org.net.me.name.foo' + file_containing_symbol="i.donut.exist.co.uk.org.net.me.name.foo" ), ) responses = [] @@ -101,16 +114,20 @@ async def test_file_by_symbol(self): responses.append(response) expected_responses = ( reflection_pb2.ServerReflectionResponse( - valid_host='', + valid_host="", file_descriptor_response=reflection_pb2.FileDescriptorResponse( file_descriptor_proto=( - _file_descriptor_to_proto(empty_pb2.DESCRIPTOR),))), + _file_descriptor_to_proto(empty_pb2.DESCRIPTOR), + ) + ), + ), reflection_pb2.ServerReflectionResponse( - valid_host='', + valid_host="", error_response=reflection_pb2.ErrorResponse( error_code=grpc.StatusCode.NOT_FOUND.value[0], error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(), - )), + ), + ), ) self.assertSequenceEqual(expected_responses, responses) @@ -120,39 +137,47 @@ async def test_file_containing_extension(self): file_containing_extension=reflection_pb2.ExtensionRequest( containing_type=_EMPTY_EXTENSIONS_SYMBOL_NAME, extension_number=125, - ),), + ), + ), reflection_pb2.ServerReflectionRequest( file_containing_extension=reflection_pb2.ExtensionRequest( - containing_type='i.donut.exist.co.uk.org.net.me.name.foo', + containing_type="i.donut.exist.co.uk.org.net.me.name.foo", extension_number=55, - ),), + ), + ), ) responses = [] async for response in self._stub.ServerReflectionInfo(iter(requests)): responses.append(response) expected_responses = ( reflection_pb2.ServerReflectionResponse( - valid_host='', - file_descriptor_response=reflection_pb2. 
- FileDescriptorResponse(file_descriptor_proto=( - _file_descriptor_to_proto(empty2_extensions_pb2.DESCRIPTOR), - _file_descriptor_to_proto(empty2_pb2.DESCRIPTOR), - ))), + valid_host="", + file_descriptor_response=reflection_pb2.FileDescriptorResponse( + file_descriptor_proto=( + _file_descriptor_to_proto( + empty2_extensions_pb2.DESCRIPTOR + ), + _file_descriptor_to_proto(empty2_pb2.DESCRIPTOR), + ) + ), + ), reflection_pb2.ServerReflectionResponse( - valid_host='', + valid_host="", error_response=reflection_pb2.ErrorResponse( error_code=grpc.StatusCode.NOT_FOUND.value[0], error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(), - )), + ), + ), ) self.assertSequenceEqual(expected_responses, responses) async def test_extension_numbers_of_type(self): requests = ( reflection_pb2.ServerReflectionRequest( - all_extension_numbers_of_type=_EMPTY_EXTENSIONS_SYMBOL_NAME), + all_extension_numbers_of_type=_EMPTY_EXTENSIONS_SYMBOL_NAME + ), reflection_pb2.ServerReflectionRequest( - all_extension_numbers_of_type='i.donut.exist.co.uk.net.name.foo' + all_extension_numbers_of_type="i.donut.exist.co.uk.net.name.foo" ), ) responses = [] @@ -160,38 +185,50 @@ async def test_extension_numbers_of_type(self): responses.append(response) expected_responses = ( reflection_pb2.ServerReflectionResponse( - valid_host='', - all_extension_numbers_response=reflection_pb2. - ExtensionNumberResponse( + valid_host="", + all_extension_numbers_response=reflection_pb2.ExtensionNumberResponse( base_type_name=_EMPTY_EXTENSIONS_SYMBOL_NAME, - extension_number=_EMPTY_EXTENSIONS_NUMBERS)), + extension_number=_EMPTY_EXTENSIONS_NUMBERS, + ), + ), reflection_pb2.ServerReflectionResponse( - valid_host='', + valid_host="", error_response=reflection_pb2.ErrorResponse( error_code=grpc.StatusCode.NOT_FOUND.value[0], error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(), - )), + ), + ), ) self.assertSequenceEqual(expected_responses, responses) async def test_list_services(self): - requests = (reflection_pb2.ServerReflectionRequest(list_services='',),) + requests = ( + reflection_pb2.ServerReflectionRequest( + list_services="", + ), + ) responses = [] async for response in self._stub.ServerReflectionInfo(iter(requests)): responses.append(response) - expected_responses = (reflection_pb2.ServerReflectionResponse( - valid_host='', - list_services_response=reflection_pb2.ListServiceResponse( - service=tuple( - reflection_pb2.ServiceResponse(name=name) - for name in _SERVICE_NAMES))),) + expected_responses = ( + reflection_pb2.ServerReflectionResponse( + valid_host="", + list_services_response=reflection_pb2.ListServiceResponse( + service=tuple( + reflection_pb2.ServiceResponse(name=name) + for name in _SERVICE_NAMES + ) + ), + ), + ) self.assertSequenceEqual(expected_responses, responses) def test_reflection_service_name(self): - self.assertEqual(reflection.SERVICE_NAME, - 'grpc.reflection.v1alpha.ServerReflection') + self.assertEqual( + reflection.SERVICE_NAME, "grpc.reflection.v1alpha.ServerReflection" + ) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/status/grpc_status_test.py b/src/python/grpcio_tests/tests_aio/status/grpc_status_test.py index df5b75b9cd9e7..4e7dbb12ba1ef 100644 --- a/src/python/grpcio_tests/tests_aio/status/grpc_status_test.py +++ b/src/python/grpcio_tests/tests_aio/status/grpc_status_test.py @@ -27,19 +27,19 @@ from tests_aio.unit._test_base import AioTestBase -_STATUS_OK = '/test/StatusOK' 
-_STATUS_NOT_OK = '/test/StatusNotOk' -_ERROR_DETAILS = '/test/ErrorDetails' -_INCONSISTENT = '/test/Inconsistent' -_INVALID_CODE = '/test/InvalidCode' +_STATUS_OK = "/test/StatusOK" +_STATUS_NOT_OK = "/test/StatusNotOk" +_ERROR_DETAILS = "/test/ErrorDetails" +_INCONSISTENT = "/test/Inconsistent" +_INVALID_CODE = "/test/InvalidCode" -_REQUEST = b'\x00\x00\x00' -_RESPONSE = b'\x01\x01\x01' +_REQUEST = b"\x00\x00\x00" +_RESPONSE = b"\x01\x01\x01" -_GRPC_DETAILS_METADATA_KEY = 'grpc-status-details-bin' +_GRPC_DETAILS_METADATA_KEY = "grpc-status-details-bin" -_STATUS_DETAILS = 'This is an error detail' -_STATUS_DETAILS_ANOTHER = 'This is another error detail' +_STATUS_DETAILS = "This is an error detail" +_STATUS_DETAILS_ANOTHER = "This is another error detail" async def _ok_unary_unary(request, servicer_context): @@ -53,8 +53,11 @@ async def _not_ok_unary_unary(request, servicer_context): async def _error_details_unary_unary(request, servicer_context): details = any_pb2.Any() details.Pack( - error_details_pb2.DebugInfo(stack_entries=traceback.format_stack(), - detail='Intentionally invoked')) + error_details_pb2.DebugInfo( + stack_entries=traceback.format_stack(), + detail="Intentionally invoked", + ) + ) rich_status = status_pb2.Status( code=code_pb2.INTERNAL, message=_STATUS_DETAILS, @@ -72,19 +75,19 @@ async def _inconsistent_unary_unary(request, servicer_context): servicer_context.set_details(_STATUS_DETAILS_ANOTHER) # User put inconsistent status information in trailing metadata servicer_context.set_trailing_metadata( - ((_GRPC_DETAILS_METADATA_KEY, rich_status.SerializeToString()),)) + ((_GRPC_DETAILS_METADATA_KEY, rich_status.SerializeToString()),) + ) async def _invalid_code_unary_unary(request, servicer_context): rich_status = status_pb2.Status( code=42, - message='Invalid code', + message="Invalid code", ) await servicer_context.abort_with_status(rpc_status.to_status(rich_status)) class _GenericHandler(grpc.GenericRpcHandler): - def service(self, handler_call_details): if handler_call_details.method == _STATUS_OK: return grpc.unary_unary_rpc_method_handler(_ok_unary_unary) @@ -92,26 +95,28 @@ def service(self, handler_call_details): return grpc.unary_unary_rpc_method_handler(_not_ok_unary_unary) elif handler_call_details.method == _ERROR_DETAILS: return grpc.unary_unary_rpc_method_handler( - _error_details_unary_unary) + _error_details_unary_unary + ) elif handler_call_details.method == _INCONSISTENT: return grpc.unary_unary_rpc_method_handler( - _inconsistent_unary_unary) + _inconsistent_unary_unary + ) elif handler_call_details.method == _INVALID_CODE: return grpc.unary_unary_rpc_method_handler( - _invalid_code_unary_unary) + _invalid_code_unary_unary + ) else: return None class StatusTest(AioTestBase): - async def setUp(self): self._server = aio.server() self._server.add_generic_rpc_handlers((_GenericHandler(),)) - port = self._server.add_insecure_port('[::]:0') + port = self._server.add_insecure_port("[::]:0") await self._server.start() - self._channel = aio.insecure_channel('localhost:%d' % port) + self._channel = aio.insecure_channel("localhost:%d" % port) async def tearDown(self): await self._server.stop(None) @@ -143,14 +148,15 @@ async def test_error_details(self): status = await rpc_status.aio.from_call(call) self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL) - self.assertEqual(status.code, code_pb2.Code.Value('INTERNAL')) + self.assertEqual(status.code, code_pb2.Code.Value("INTERNAL")) # Check if the underlying proto message is intact - 
self.assertTrue(status.details[0].Is( - error_details_pb2.DebugInfo.DESCRIPTOR)) + self.assertTrue( + status.details[0].Is(error_details_pb2.DebugInfo.DESCRIPTOR) + ) info = error_details_pb2.DebugInfo() status.details[0].Unpack(info) - self.assertIn('_error_details_unary_unary', info.stack_entries[-1]) + self.assertIn("_error_details_unary_unary", info.stack_entries[-1]) async def test_code_message_validation(self): call = self._channel.unary_unary(_INCONSISTENT)(_REQUEST) @@ -169,9 +175,9 @@ async def test_invalid_code(self): rpc_error = exception_context.exception self.assertEqual(rpc_error.code(), grpc.StatusCode.UNKNOWN) # Invalid status code exception raised during coversion - self.assertIn('Invalid status code', rpc_error.details()) + self.assertIn("Invalid status code", rpc_error.details()) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/_common.py b/src/python/grpcio_tests/tests_aio/unit/_common.py index 05947733a0822..b29dfb4889e6c 100644 --- a/src/python/grpcio_tests/tests_aio/unit/_common.py +++ b/src/python/grpcio_tests/tests_aio/unit/_common.py @@ -24,21 +24,23 @@ from tests.unit.framework.common import test_constants -ADHOC_METHOD = '/test/AdHoc' +ADHOC_METHOD = "/test/AdHoc" def seen_metadata(expected: Metadata, actual: Metadata): return not bool(set(tuple(expected)) - set(tuple(actual))) -def seen_metadatum(expected_key: MetadataKey, expected_value: MetadataValue, - actual: Metadata) -> bool: +def seen_metadatum( + expected_key: MetadataKey, expected_value: MetadataValue, actual: Metadata +) -> bool: obtained = actual[expected_key] return obtained == expected_value -async def block_until_certain_state(channel: aio.Channel, - expected_state: grpc.ChannelConnectivity): +async def block_until_certain_state( + channel: aio.Channel, expected_state: grpc.ChannelConnectivity +): state = channel.get_state() while state != expected_state: await channel.wait_for_state_change(state) @@ -67,15 +69,16 @@ def second_callback(call): async def validation(): await asyncio.wait_for( - asyncio.gather(first_callback_ran.wait(), - second_callback_ran.wait()), - test_constants.SHORT_TIMEOUT) + asyncio.gather( + first_callback_ran.wait(), second_callback_ran.wait() + ), + test_constants.SHORT_TIMEOUT, + ) return validation() class CountingRequestIterator: - def __init__(self, request_iterator): self.request_cnt = 0 self._request_iterator = request_iterator @@ -90,7 +93,6 @@ def __aiter__(self): class CountingResponseIterator: - def __init__(self, response_iterator): self.response_cnt = 0 self._response_iterator = response_iterator @@ -106,6 +108,7 @@ def __aiter__(self): class AdhocGenericHandler(grpc.GenericRpcHandler): """A generic handler to plugin testing server methods on the fly.""" + _handler: grpc.RpcMethodHandler def __init__(self): diff --git a/src/python/grpcio_tests/tests_aio/unit/_constants.py b/src/python/grpcio_tests/tests_aio/unit/_constants.py index ab7e06f8fca61..74d15c37dce34 100644 --- a/src/python/grpcio_tests/tests_aio/unit/_constants.py +++ b/src/python/grpcio_tests/tests_aio/unit/_constants.py @@ -15,5 +15,5 @@ # If we use an unreachable IP, depending on the network stack, we might not get # with an RST fast enough. This used to cause tests to flake under different # platforms. 
-UNREACHABLE_TARGET = 'foo/bar' +UNREACHABLE_TARGET = "foo/bar" UNARY_CALL_WITH_SLEEP_VALUE = 0.2 diff --git a/src/python/grpcio_tests/tests_aio/unit/_metadata_test.py b/src/python/grpcio_tests/tests_aio/unit/_metadata_test.py index c0594cb06abe3..98a932f7280c1 100644 --- a/src/python/grpcio_tests/tests_aio/unit/_metadata_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/_metadata_test.py @@ -22,8 +22,11 @@ class TestTypeMetadata(unittest.TestCase): """Tests for the metadata type""" _DEFAULT_DATA = (("key1", "value1"), ("key2", "value2")) - _MULTI_ENTRY_DATA = (("key1", "value1"), ("key1", "other value 1"), - ("key2", "value2")) + _MULTI_ENTRY_DATA = ( + ("key1", "value1"), + ("key1", "other value 1"), + ("key2", "value2"), + ) def test_init_metadata(self): test_cases = { @@ -37,8 +40,9 @@ def test_init_metadata(self): self.assertEqual(len(metadata), len(args)) def test_get_item(self): - metadata = Metadata(("key", "value1"), ("key", "value2"), - ("key2", "other value")) + metadata = Metadata( + ("key", "value1"), ("key", "value2"), ("key2", "other value") + ) self.assertEqual(metadata["key"], "value1") self.assertEqual(metadata["key2"], "other value") self.assertEqual(metadata.get("key"), "value1") @@ -88,8 +92,9 @@ def test_set(self): metadata["key1"] = override_value self.assertEqual(metadata["key1"], override_value) - self.assertEqual(metadata.get_all("key1"), - [override_value, "other value 1"]) + self.assertEqual( + metadata.get_all("key1"), [override_value, "other value 1"] + ) empty_metadata = Metadata() for _ in range(3): @@ -132,6 +137,6 @@ def test_metadata_from_tuple(self): self.assertEqual(expected, Metadata.from_tuple(source)) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/_test_base.py b/src/python/grpcio_tests/tests_aio/unit/_test_base.py index fcd1e90a5a05b..07fbc85311838 100644 --- a/src/python/grpcio_tests/tests_aio/unit/_test_base.py +++ b/src/python/grpcio_tests/tests_aio/unit/_test_base.py @@ -20,13 +20,12 @@ from grpc.experimental import aio -__all__ = 'AioTestBase' +__all__ = "AioTestBase" -_COROUTINE_FUNCTION_ALLOWLIST = ['setUp', 'tearDown'] +_COROUTINE_FUNCTION_ALLOWLIST = ["setUp", "tearDown"] def _async_to_sync_decorator(f: Callable, loop: asyncio.AbstractEventLoop): - @functools.wraps(f) def wrapper(*args, **kwargs): return loop.run_until_complete(f(*args, **kwargs)) @@ -60,7 +59,7 @@ def __getattribute__(self, name): attr = super().__getattribute__(name) # If possible, converts the coroutine into a sync function. - if name.startswith('test_') or name in _COROUTINE_FUNCTION_ALLOWLIST: + if name.startswith("test_") or name in _COROUTINE_FUNCTION_ALLOWLIST: if asyncio.iscoroutinefunction(attr): return _async_to_sync_decorator(attr, self._TEST_LOOP) # For other attributes, let them pass. 
diff --git a/src/python/grpcio_tests/tests_aio/unit/_test_server.py b/src/python/grpcio_tests/tests_aio/unit/_test_server.py index 0119fda37c56d..bcc29cd2a6b09 100644 --- a/src/python/grpcio_tests/tests_aio/unit/_test_server.py +++ b/src/python/grpcio_tests/tests_aio/unit/_test_server.py @@ -32,48 +32,60 @@ async def _maybe_echo_metadata(servicer_context): """Copies metadata from request to response if it is present.""" invocation_metadata = dict(servicer_context.invocation_metadata()) if _INITIAL_METADATA_KEY in invocation_metadata: - initial_metadatum = (_INITIAL_METADATA_KEY, - invocation_metadata[_INITIAL_METADATA_KEY]) + initial_metadatum = ( + _INITIAL_METADATA_KEY, + invocation_metadata[_INITIAL_METADATA_KEY], + ) await servicer_context.send_initial_metadata((initial_metadatum,)) if _TRAILING_METADATA_KEY in invocation_metadata: - trailing_metadatum = (_TRAILING_METADATA_KEY, - invocation_metadata[_TRAILING_METADATA_KEY]) + trailing_metadatum = ( + _TRAILING_METADATA_KEY, + invocation_metadata[_TRAILING_METADATA_KEY], + ) servicer_context.set_trailing_metadata((trailing_metadatum,)) -async def _maybe_echo_status(request: messages_pb2.SimpleRequest, - servicer_context): +async def _maybe_echo_status( + request: messages_pb2.SimpleRequest, servicer_context +): """Echos the RPC status if demanded by the request.""" - if request.HasField('response_status'): - await servicer_context.abort(request.response_status.code, - request.response_status.message) + if request.HasField("response_status"): + await servicer_context.abort( + request.response_status.code, request.response_status.message + ) class TestServiceServicer(test_pb2_grpc.TestServiceServicer): - async def UnaryCall(self, request, context): await _maybe_echo_metadata(context) await _maybe_echo_status(request, context) return messages_pb2.SimpleResponse( - payload=messages_pb2.Payload(type=messages_pb2.COMPRESSABLE, - body=b'\x00' * request.response_size)) + payload=messages_pb2.Payload( + type=messages_pb2.COMPRESSABLE, + body=b"\x00" * request.response_size, + ) + ) async def EmptyCall(self, request, context): return empty_pb2.Empty() async def StreamingOutputCall( - self, request: messages_pb2.StreamingOutputCallRequest, - unused_context): + self, request: messages_pb2.StreamingOutputCallRequest, unused_context + ): for response_parameters in request.response_parameters: if response_parameters.interval_us != 0: await asyncio.sleep( - datetime.timedelta(microseconds=response_parameters. 
- interval_us).total_seconds()) + datetime.timedelta( + microseconds=response_parameters.interval_us + ).total_seconds() + ) if response_parameters.size != 0: yield messages_pb2.StreamingOutputCallResponse( - payload=messages_pb2.Payload(type=request.response_type, - body=b'\x00' * - response_parameters.size)) + payload=messages_pb2.Payload( + type=request.response_type, + body=b"\x00" * response_parameters.size, + ) + ) else: yield messages_pb2.StreamingOutputCallResponse() @@ -90,7 +102,8 @@ async def StreamingInputCall(self, request_async_iterator, unused_context): if request.payload is not None and request.payload.body: aggregate_size += len(request.payload.body) return messages_pb2.StreamingInputCallResponse( - aggregated_payload_size=aggregate_size) + aggregated_payload_size=aggregate_size + ) async def FullDuplexCall(self, request_async_iterator, context): await _maybe_echo_metadata(context) @@ -99,13 +112,17 @@ async def FullDuplexCall(self, request_async_iterator, context): for response_parameters in request.response_parameters: if response_parameters.interval_us != 0: await asyncio.sleep( - datetime.timedelta(microseconds=response_parameters. - interval_us).total_seconds()) + datetime.timedelta( + microseconds=response_parameters.interval_us + ).total_seconds() + ) if response_parameters.size != 0: yield messages_pb2.StreamingOutputCallResponse( - payload=messages_pb2.Payload(type=request.payload.type, - body=b'\x00' * - response_parameters.size)) + payload=messages_pb2.Payload( + type=request.payload.type, + body=b"\x00" * response_parameters.size, + ) + ) else: yield messages_pb2.StreamingOutputCallResponse() @@ -114,23 +131,23 @@ def _create_extra_generic_handler(servicer: TestServiceServicer): # Add programatically extra methods not provided by the proto file # that are used during the tests rpc_method_handlers = { - 'UnaryCallWithSleep': - grpc.unary_unary_rpc_method_handler( - servicer.UnaryCallWithSleep, - request_deserializer=messages_pb2.SimpleRequest.FromString, - response_serializer=messages_pb2.SimpleResponse. 
- SerializeToString) + "UnaryCallWithSleep": grpc.unary_unary_rpc_method_handler( + servicer.UnaryCallWithSleep, + request_deserializer=messages_pb2.SimpleRequest.FromString, + response_serializer=messages_pb2.SimpleResponse.SerializeToString, + ) } - return grpc.method_handlers_generic_handler('grpc.testing.TestService', - rpc_method_handlers) + return grpc.method_handlers_generic_handler( + "grpc.testing.TestService", rpc_method_handlers + ) -async def start_test_server(port=0, - secure=False, - server_credentials=None, - interceptors=None): - server = aio.server(options=(('grpc.so_reuseport', 0),), - interceptors=interceptors) +async def start_test_server( + port=0, secure=False, server_credentials=None, interceptors=None +): + server = aio.server( + options=(("grpc.so_reuseport", 0),), interceptors=interceptors + ) servicer = TestServiceServicer() test_pb2_grpc.add_TestServiceServicer_to_server(servicer, server) @@ -138,14 +155,14 @@ async def start_test_server(port=0, if secure: if server_credentials is None: - server_credentials = grpc.ssl_server_credentials([ - (resources.private_key(), resources.certificate_chain()) - ]) - port = server.add_secure_port('[::]:%d' % port, server_credentials) + server_credentials = grpc.ssl_server_credentials( + [(resources.private_key(), resources.certificate_chain())] + ) + port = server.add_secure_port("[::]:%d" % port, server_credentials) else: - port = server.add_insecure_port('[::]:%d' % port) + port = server.add_insecure_port("[::]:%d" % port) await server.start() # NOTE(lidizheng) returning the server to prevent it from deallocation - return 'localhost:%d' % port, server + return "localhost:%d" % port, server diff --git a/src/python/grpcio_tests/tests_aio/unit/abort_test.py b/src/python/grpcio_tests/tests_aio/unit/abort_test.py index 45ef9481a8c61..a6543b9de70e1 100644 --- a/src/python/grpcio_tests/tests_aio/unit/abort_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/abort_test.py @@ -24,25 +24,24 @@ from tests.unit.framework.common import test_constants from tests_aio.unit._test_base import AioTestBase -_UNARY_UNARY_ABORT = '/test/UnaryUnaryAbort' -_SUPPRESS_ABORT = '/test/SuppressAbort' -_REPLACE_ABORT = '/test/ReplaceAbort' -_ABORT_AFTER_REPLY = '/test/AbortAfterReply' +_UNARY_UNARY_ABORT = "/test/UnaryUnaryAbort" +_SUPPRESS_ABORT = "/test/SuppressAbort" +_REPLACE_ABORT = "/test/ReplaceAbort" +_ABORT_AFTER_REPLY = "/test/AbortAfterReply" -_REQUEST = b'\x00\x00\x00' -_RESPONSE = b'\x01\x01\x01' +_REQUEST = b"\x00\x00\x00" +_RESPONSE = b"\x01\x01\x01" _NUM_STREAM_RESPONSES = 5 _ABORT_CODE = grpc.StatusCode.RESOURCE_EXHAUSTED -_ABORT_DETAILS = 'Phony error details' +_ABORT_DETAILS = "Phony error details" class _GenericHandler(grpc.GenericRpcHandler): - @staticmethod async def _unary_unary_abort(unused_request, context): await context.abort(_ABORT_CODE, _ABORT_DETAILS) - raise RuntimeError('This line should not be executed') + raise RuntimeError("This line should not be executed") @staticmethod async def _suppress_abort(unused_request, context): @@ -57,14 +56,15 @@ async def _replace_abort(unused_request, context): try: await context.abort(_ABORT_CODE, _ABORT_DETAILS) except aio.AbortError as e: - await context.abort(grpc.StatusCode.INVALID_ARGUMENT, - 'Override abort!') + await context.abort( + grpc.StatusCode.INVALID_ARGUMENT, "Override abort!" 
+ ) @staticmethod async def _abort_after_reply(unused_request, context): yield _RESPONSE await context.abort(_ABORT_CODE, _ABORT_DETAILS) - raise RuntimeError('This line should not be executed') + raise RuntimeError("This line should not be executed") def service(self, handler_details): if handler_details.method == _UNARY_UNARY_ABORT: @@ -79,14 +79,13 @@ def service(self, handler_details): async def _start_test_server(): server = aio.server() - port = server.add_insecure_port('[::]:0') + port = server.add_insecure_port("[::]:0") server.add_generic_rpc_handlers((_GenericHandler(),)) await server.start() - return 'localhost:%d' % port, server + return "localhost:%d" % port, server class TestAbort(AioTestBase): - async def setUp(self): address, self._server = await _start_test_server() self._channel = aio.insecure_channel(address) @@ -147,6 +146,6 @@ async def test_abort_after_reply(self): self.assertEqual(_ABORT_DETAILS, await call.details()) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/aio_rpc_error_test.py b/src/python/grpcio_tests/tests_aio/unit/aio_rpc_error_test.py index 730871d1be317..83cf239dacc89 100644 --- a/src/python/grpcio_tests/tests_aio/unit/aio_rpc_error_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/aio_rpc_error_test.py @@ -23,30 +23,36 @@ from tests_aio.unit._test_base import AioTestBase _TEST_INITIAL_METADATA = aio.Metadata( - ('initial metadata key', 'initial metadata value')) + ("initial metadata key", "initial metadata value") +) _TEST_TRAILING_METADATA = aio.Metadata( - ('trailing metadata key', 'trailing metadata value')) -_TEST_DEBUG_ERROR_STRING = '{This is a debug string}' + ("trailing metadata key", "trailing metadata value") +) +_TEST_DEBUG_ERROR_STRING = "{This is a debug string}" class TestAioRpcError(unittest.TestCase): - def test_attributes(self): - aio_rpc_error = AioRpcError(grpc.StatusCode.CANCELLED, - initial_metadata=_TEST_INITIAL_METADATA, - trailing_metadata=_TEST_TRAILING_METADATA, - details="details", - debug_error_string=_TEST_DEBUG_ERROR_STRING) + aio_rpc_error = AioRpcError( + grpc.StatusCode.CANCELLED, + initial_metadata=_TEST_INITIAL_METADATA, + trailing_metadata=_TEST_TRAILING_METADATA, + details="details", + debug_error_string=_TEST_DEBUG_ERROR_STRING, + ) self.assertEqual(aio_rpc_error.code(), grpc.StatusCode.CANCELLED) - self.assertEqual(aio_rpc_error.details(), 'details') - self.assertEqual(aio_rpc_error.initial_metadata(), - _TEST_INITIAL_METADATA) - self.assertEqual(aio_rpc_error.trailing_metadata(), - _TEST_TRAILING_METADATA) - self.assertEqual(aio_rpc_error.debug_error_string(), - _TEST_DEBUG_ERROR_STRING) - - -if __name__ == '__main__': + self.assertEqual(aio_rpc_error.details(), "details") + self.assertEqual( + aio_rpc_error.initial_metadata(), _TEST_INITIAL_METADATA + ) + self.assertEqual( + aio_rpc_error.trailing_metadata(), _TEST_TRAILING_METADATA + ) + self.assertEqual( + aio_rpc_error.debug_error_string(), _TEST_DEBUG_ERROR_STRING + ) + + +if __name__ == "__main__": logging.basicConfig() unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/auth_context_test.py b/src/python/grpcio_tests/tests_aio/unit/auth_context_test.py index 819e72e630706..78c7c22094c3c 100644 --- a/src/python/grpcio_tests/tests_aio/unit/auth_context_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/auth_context_test.py @@ -24,54 +24,62 @@ from tests.unit import resources from 
tests_aio.unit._test_base import AioTestBase -_REQUEST = b'\x00\x00\x00' -_RESPONSE = b'\x00\x00\x00' +_REQUEST = b"\x00\x00\x00" +_RESPONSE = b"\x00\x00\x00" -_UNARY_UNARY = '/test/UnaryUnary' +_UNARY_UNARY = "/test/UnaryUnary" -_SERVER_HOST_OVERRIDE = 'foo.test.google.fr' +_SERVER_HOST_OVERRIDE = "foo.test.google.fr" _CLIENT_IDS = ( - b'*.test.google.fr', - b'waterzooi.test.google.be', - b'*.test.youtube.com', - b'192.168.1.3', + b"*.test.google.fr", + b"waterzooi.test.google.be", + b"*.test.youtube.com", + b"192.168.1.3", ) -_ID = 'id' -_ID_KEY = 'id_key' -_AUTH_CTX = 'auth_ctx' +_ID = "id" +_ID_KEY = "id_key" +_AUTH_CTX = "auth_ctx" _PRIVATE_KEY = resources.private_key() _CERTIFICATE_CHAIN = resources.certificate_chain() _TEST_ROOT_CERTIFICATES = resources.test_root_certificates() _SERVER_CERTS = ((_PRIVATE_KEY, _CERTIFICATE_CHAIN),) -_PROPERTY_OPTIONS = (( - 'grpc.ssl_target_name_override', - _SERVER_HOST_OVERRIDE, -),) +_PROPERTY_OPTIONS = ( + ( + "grpc.ssl_target_name_override", + _SERVER_HOST_OVERRIDE, + ), +) -async def handle_unary_unary(unused_request: bytes, - servicer_context: aio.ServicerContext): - return pickle.dumps({ - _ID: servicer_context.peer_identities(), - _ID_KEY: servicer_context.peer_identity_key(), - _AUTH_CTX: servicer_context.auth_context() - }) +async def handle_unary_unary( + unused_request: bytes, servicer_context: aio.ServicerContext +): + return pickle.dumps( + { + _ID: servicer_context.peer_identities(), + _ID_KEY: servicer_context.peer_identity_key(), + _AUTH_CTX: servicer_context.auth_context(), + } + ) class TestAuthContext(AioTestBase): - async def test_insecure(self): - handler = grpc.method_handlers_generic_handler('test', { - 'UnaryUnary': - grpc.unary_unary_rpc_method_handler(handle_unary_unary) - }) + handler = grpc.method_handlers_generic_handler( + "test", + { + "UnaryUnary": grpc.unary_unary_rpc_method_handler( + handle_unary_unary + ) + }, + ) server = aio.server() server.add_generic_rpc_handlers((handler,)) - port = server.add_insecure_port('[::]:0') + port = server.add_insecure_port("[::]:0") await server.start() - async with aio.insecure_channel('localhost:%d' % port) as channel: + async with aio.insecure_channel("localhost:%d" % port) as channel: response = await channel.unary_unary(_UNARY_UNARY)(_REQUEST) await server.stop(None) @@ -80,26 +88,35 @@ async def test_insecure(self): self.assertIsNone(auth_data[_ID_KEY]) self.assertDictEqual( { - 'security_level': [b'TSI_SECURITY_NONE'], - 'transport_security_type': [b'insecure'], - }, auth_data[_AUTH_CTX]) + "security_level": [b"TSI_SECURITY_NONE"], + "transport_security_type": [b"insecure"], + }, + auth_data[_AUTH_CTX], + ) async def test_secure_no_cert(self): - handler = grpc.method_handlers_generic_handler('test', { - 'UnaryUnary': - grpc.unary_unary_rpc_method_handler(handle_unary_unary) - }) + handler = grpc.method_handlers_generic_handler( + "test", + { + "UnaryUnary": grpc.unary_unary_rpc_method_handler( + handle_unary_unary + ) + }, + ) server = aio.server() server.add_generic_rpc_handlers((handler,)) server_cred = grpc.ssl_server_credentials(_SERVER_CERTS) - port = server.add_secure_port('[::]:0', server_cred) + port = server.add_secure_port("[::]:0", server_cred) await server.start() channel_creds = grpc.ssl_channel_credentials( - root_certificates=_TEST_ROOT_CERTIFICATES) - channel = aio.secure_channel('localhost:{}'.format(port), - channel_creds, - options=_PROPERTY_OPTIONS) + root_certificates=_TEST_ROOT_CERTIFICATES + ) + channel = aio.secure_channel( + 
"localhost:{}".format(port), + channel_creds, + options=_PROPERTY_OPTIONS, + ) response = await channel.unary_unary(_UNARY_UNARY)(_REQUEST) await channel.close() await server.stop(None) @@ -109,32 +126,42 @@ async def test_secure_no_cert(self): self.assertIsNone(auth_data[_ID_KEY]) self.assertDictEqual( { - 'security_level': [b'TSI_PRIVACY_AND_INTEGRITY'], - 'transport_security_type': [b'ssl'], - 'ssl_session_reused': [b'false'], - }, auth_data[_AUTH_CTX]) + "security_level": [b"TSI_PRIVACY_AND_INTEGRITY"], + "transport_security_type": [b"ssl"], + "ssl_session_reused": [b"false"], + }, + auth_data[_AUTH_CTX], + ) async def test_secure_client_cert(self): - handler = grpc.method_handlers_generic_handler('test', { - 'UnaryUnary': - grpc.unary_unary_rpc_method_handler(handle_unary_unary) - }) + handler = grpc.method_handlers_generic_handler( + "test", + { + "UnaryUnary": grpc.unary_unary_rpc_method_handler( + handle_unary_unary + ) + }, + ) server = aio.server() server.add_generic_rpc_handlers((handler,)) server_cred = grpc.ssl_server_credentials( _SERVER_CERTS, root_certificates=_TEST_ROOT_CERTIFICATES, - require_client_auth=True) - port = server.add_secure_port('[::]:0', server_cred) + require_client_auth=True, + ) + port = server.add_secure_port("[::]:0", server_cred) await server.start() channel_creds = grpc.ssl_channel_credentials( root_certificates=_TEST_ROOT_CERTIFICATES, private_key=_PRIVATE_KEY, - certificate_chain=_CERTIFICATE_CHAIN) - channel = aio.secure_channel('localhost:{}'.format(port), - channel_creds, - options=_PROPERTY_OPTIONS) + certificate_chain=_CERTIFICATE_CHAIN, + ) + channel = aio.secure_channel( + "localhost:{}".format(port), + channel_creds, + options=_PROPERTY_OPTIONS, + ) response = await channel.unary_unary(_UNARY_UNARY)(_REQUEST) await channel.close() @@ -143,55 +170,69 @@ async def test_secure_client_cert(self): auth_data = pickle.loads(response) auth_ctx = auth_data[_AUTH_CTX] self.assertCountEqual(_CLIENT_IDS, auth_data[_ID]) - self.assertEqual('x509_subject_alternative_name', auth_data[_ID_KEY]) - self.assertSequenceEqual([b'ssl'], auth_ctx['transport_security_type']) - self.assertSequenceEqual([b'*.test.google.com'], - auth_ctx['x509_common_name']) - - async def _do_one_shot_client_rpc(self, channel_creds, channel_options, - port, expect_ssl_session_reused): - channel = aio.secure_channel('localhost:{}'.format(port), - channel_creds, - options=channel_options) + self.assertEqual("x509_subject_alternative_name", auth_data[_ID_KEY]) + self.assertSequenceEqual([b"ssl"], auth_ctx["transport_security_type"]) + self.assertSequenceEqual( + [b"*.test.google.com"], auth_ctx["x509_common_name"] + ) + + async def _do_one_shot_client_rpc( + self, channel_creds, channel_options, port, expect_ssl_session_reused + ): + channel = aio.secure_channel( + "localhost:{}".format(port), channel_creds, options=channel_options + ) response = await channel.unary_unary(_UNARY_UNARY)(_REQUEST) auth_data = pickle.loads(response) - self.assertEqual(expect_ssl_session_reused, - auth_data[_AUTH_CTX]['ssl_session_reused']) + self.assertEqual( + expect_ssl_session_reused, + auth_data[_AUTH_CTX]["ssl_session_reused"], + ) await channel.close() async def test_session_resumption(self): # Set up a secure server - handler = grpc.method_handlers_generic_handler('test', { - 'UnaryUnary': - grpc.unary_unary_rpc_method_handler(handle_unary_unary) - }) + handler = grpc.method_handlers_generic_handler( + "test", + { + "UnaryUnary": grpc.unary_unary_rpc_method_handler( + handle_unary_unary + ) + }, + ) 
server = aio.server() server.add_generic_rpc_handlers((handler,)) server_cred = grpc.ssl_server_credentials(_SERVER_CERTS) - port = server.add_secure_port('[::]:0', server_cred) + port = server.add_secure_port("[::]:0", server_cred) await server.start() # Create a cache for TLS session tickets cache = session_cache.ssl_session_cache_lru(1) channel_creds = grpc.ssl_channel_credentials( - root_certificates=_TEST_ROOT_CERTIFICATES) + root_certificates=_TEST_ROOT_CERTIFICATES + ) channel_options = _PROPERTY_OPTIONS + ( - ('grpc.ssl_session_cache', cache),) + ("grpc.ssl_session_cache", cache), + ) # Initial connection has no session to resume - await self._do_one_shot_client_rpc(channel_creds, - channel_options, - port, - expect_ssl_session_reused=[b'false']) + await self._do_one_shot_client_rpc( + channel_creds, + channel_options, + port, + expect_ssl_session_reused=[b"false"], + ) # Subsequent connections resume sessions - await self._do_one_shot_client_rpc(channel_creds, - channel_options, - port, - expect_ssl_session_reused=[b'true']) + await self._do_one_shot_client_rpc( + channel_creds, + channel_options, + port, + expect_ssl_session_reused=[b"true"], + ) await server.stop(None) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main() diff --git a/src/python/grpcio_tests/tests_aio/unit/call_test.py b/src/python/grpcio_tests/tests_aio/unit/call_test.py index e0642d4e89d23..6885f0301e85c 100644 --- a/src/python/grpcio_tests/tests_aio/unit/call_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/call_test.py @@ -33,7 +33,7 @@ _NUM_STREAM_RESPONSES = 5 _RESPONSE_PAYLOAD_SIZE = 42 _REQUEST_PAYLOAD_SIZE = 7 -_LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!' +_LOCAL_CANCEL_DETAILS_EXPECTATION = "Locally cancelled by application!" 
_RESPONSE_INTERVAL_US = int(_SHORT_TIMEOUT_S * 1000 * 1000) _INFINITE_INTERVAL_US = 2**31 - 1 @@ -41,8 +41,7 @@ _NONDETERMINISTIC_SERVER_SLEEP_MAX_US = 1000 -class _MulticallableTestMixin(): - +class _MulticallableTestMixin: async def setUp(self): address, self._server = await start_test_server() self._channel = aio.insecure_channel(address) @@ -54,7 +53,6 @@ async def tearDown(self): class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase): - async def test_call_to_string(self): call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) @@ -91,8 +89,9 @@ async def test_call_rpc_error(self): with self.assertRaises(aio.AioRpcError) as exception_context: await call - self.assertEqual(grpc.StatusCode.UNAVAILABLE, - exception_context.exception.code()) + self.assertEqual( + grpc.StatusCode.UNAVAILABLE, exception_context.exception.code() + ) self.assertTrue(call.done()) self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code()) @@ -103,7 +102,7 @@ async def test_call_code_awaitable(self): async def test_call_details_awaitable(self): call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) - self.assertEqual('', await call.details()) + self.assertEqual("", await call.details()) async def test_call_initial_metadata_awaitable(self): call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) @@ -169,8 +168,10 @@ async def coro(): await call - self.assertEqual([grpc.StatusCode.OK, grpc.StatusCode.OK], await - asyncio.gather(task1, task2)) + self.assertEqual( + [grpc.StatusCode.OK, grpc.StatusCode.OK], + await asyncio.gather(task1, task2), + ) async def test_cancel_unary_unary(self): call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) @@ -186,8 +187,9 @@ async def test_cancel_unary_unary(self): # The info in the RpcError should match the info in Call object. self.assertTrue(call.cancelled()) self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) - self.assertEqual(await call.details(), - 'Locally cancelled by application!') + self.assertEqual( + await call.details(), "Locally cancelled by application!" 
+ ) async def test_cancel_unary_unary_in_task(self): coro_started = asyncio.Event() @@ -214,14 +216,14 @@ async def test_passing_credentials_fails_over_insecure_channel(self): grpc.access_token_call_credentials("def"), ) with self.assertRaisesRegex( - aio.UsageError, - "Call credentials are only valid on secure channels"): - self._stub.UnaryCall(messages_pb2.SimpleRequest(), - credentials=call_credentials) + aio.UsageError, "Call credentials are only valid on secure channels" + ): + self._stub.UnaryCall( + messages_pb2.SimpleRequest(), credentials=call_credentials + ) class TestUnaryStreamCall(_MulticallableTestMixin, AioTestBase): - async def test_call_rpc_error(self): channel = aio.insecure_channel(UNREACHABLE_TARGET) request = messages_pb2.StreamingOutputCallRequest() @@ -232,8 +234,9 @@ async def test_call_rpc_error(self): async for response in call: pass - self.assertEqual(grpc.StatusCode.UNAVAILABLE, - exception_context.exception.code()) + self.assertEqual( + grpc.StatusCode.UNAVAILABLE, exception_context.exception.code() + ) self.assertTrue(call.done()) self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code()) @@ -247,7 +250,8 @@ async def test_cancel_unary_stream(self): messages_pb2.ResponseParameters( size=_RESPONSE_PAYLOAD_SIZE, interval_us=_RESPONSE_INTERVAL_US, - )) + ) + ) # Invokes the actual RPC call = self._stub.StreamingOutputCall(request) @@ -259,8 +263,9 @@ async def test_cancel_unary_stream(self): self.assertTrue(call.cancel()) self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) - self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, await - call.details()) + self.assertEqual( + _LOCAL_CANCEL_DETAILS_EXPECTATION, await call.details() + ) self.assertFalse(call.cancel()) with self.assertRaises(asyncio.CancelledError): @@ -275,7 +280,8 @@ async def test_multiple_cancel_unary_stream(self): messages_pb2.ResponseParameters( size=_RESPONSE_PAYLOAD_SIZE, interval_us=_RESPONSE_INTERVAL_US, - )) + ) + ) # Invokes the actual RPC call = self._stub.StreamingOutputCall(request) @@ -302,7 +308,8 @@ async def test_early_cancel_unary_stream(self): messages_pb2.ResponseParameters( size=_RESPONSE_PAYLOAD_SIZE, interval_us=_RESPONSE_INTERVAL_US, - )) + ) + ) # Invokes the actual RPC call = self._stub.StreamingOutputCall(request) @@ -317,8 +324,9 @@ async def test_early_cancel_unary_stream(self): self.assertTrue(call.cancelled()) self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) - self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, await - call.details()) + self.assertEqual( + _LOCAL_CANCEL_DETAILS_EXPECTATION, await call.details() + ) async def test_late_cancel_unary_stream(self): """Test cancellation after received all messages.""" @@ -326,23 +334,28 @@ async def test_late_cancel_unary_stream(self): request = messages_pb2.StreamingOutputCallRequest() for _ in range(_NUM_STREAM_RESPONSES): request.response_parameters.append( - messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,)) + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE, + ) + ) # Invokes the actual RPC call = self._stub.StreamingOutputCall(request) for _ in range(_NUM_STREAM_RESPONSES): response = await call.read() - self.assertIs(type(response), - messages_pb2.StreamingOutputCallResponse) + self.assertIs( + type(response), messages_pb2.StreamingOutputCallResponse + ) self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) # After all messages received, it is possible that the final state # is received or on its way. 
It's basically a data race, so our # expectation here is do not crash :) call.cancel() - self.assertIn(await call.code(), - [grpc.StatusCode.OK, grpc.StatusCode.CANCELLED]) + self.assertIn( + await call.code(), [grpc.StatusCode.OK, grpc.StatusCode.CANCELLED] + ) async def test_too_many_reads_unary_stream(self): """Test calling read after received all messages fails.""" @@ -350,15 +363,19 @@ async def test_too_many_reads_unary_stream(self): request = messages_pb2.StreamingOutputCallRequest() for _ in range(_NUM_STREAM_RESPONSES): request.response_parameters.append( - messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,)) + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE, + ) + ) # Invokes the actual RPC call = self._stub.StreamingOutputCall(request) for _ in range(_NUM_STREAM_RESPONSES): response = await call.read() - self.assertIs(type(response), - messages_pb2.StreamingOutputCallResponse) + self.assertIs( + type(response), messages_pb2.StreamingOutputCallResponse + ) self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) self.assertIs(await call.read(), aio.EOF) @@ -372,21 +389,25 @@ async def test_unary_stream_async_generator(self): request = messages_pb2.StreamingOutputCallRequest() for _ in range(_NUM_STREAM_RESPONSES): request.response_parameters.append( - messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,)) + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE, + ) + ) # Invokes the actual RPC call = self._stub.StreamingOutputCall(request) self.assertFalse(call.cancelled()) async for response in call: - self.assertIs(type(response), - messages_pb2.StreamingOutputCallResponse) + self.assertIs( + type(response), messages_pb2.StreamingOutputCallResponse + ) self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) self.assertEqual(await call.code(), grpc.StatusCode.OK) async def test_cancel_unary_stream_with_many_interleavings(self): - """ A cheap alternative to a structured fuzzer. + """A cheap alternative to a structured fuzzer. Certain classes of error only appear for very specific interleavings of coroutines. 
Rather than inserting semi-private asyncio.Events throughout @@ -401,7 +422,8 @@ async def test_cancel_unary_stream_with_many_interleavings(self): for sleep_range in sleep_ranges: for _ in range(_NONDETERMINISTIC_ITERATIONS): interval_us = random.randrange( - _NONDETERMINISTIC_SERVER_SLEEP_MAX_US) + _NONDETERMINISTIC_SERVER_SLEEP_MAX_US + ) sleep_secs = sleep_range * random.random() coro_started = asyncio.Event() @@ -412,7 +434,8 @@ async def test_cancel_unary_stream_with_many_interleavings(self): messages_pb2.ResponseParameters( size=1, interval_us=interval_us, - )) + ) + ) # Invokes the actual RPC call = self._stub.StreamingOutputCall(request) @@ -452,7 +475,8 @@ async def test_cancel_unary_stream_in_task_using_read(self): messages_pb2.ResponseParameters( size=_RESPONSE_PAYLOAD_SIZE, interval_us=_INFINITE_INTERVAL_US, - )) + ) + ) # Invokes the actual RPC call = self._stub.StreamingOutputCall(request) @@ -481,7 +505,8 @@ async def test_cancel_unary_stream_in_task_using_async_for(self): messages_pb2.ResponseParameters( size=_RESPONSE_PAYLOAD_SIZE, interval_us=_INFINITE_INTERVAL_US, - )) + ) + ) # Invokes the actual RPC call = self._stub.StreamingOutputCall(request) @@ -506,16 +531,21 @@ async def test_time_remaining(self): request = messages_pb2.StreamingOutputCallRequest() # First message comes back immediately request.response_parameters.append( - messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,)) + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE, + ) + ) # Second message comes back after a unit of wait time request.response_parameters.append( messages_pb2.ResponseParameters( size=_RESPONSE_PAYLOAD_SIZE, interval_us=_RESPONSE_INTERVAL_US, - )) + ) + ) - call = self._stub.StreamingOutputCall(request, - timeout=_SHORT_TIMEOUT_S * 2) + call = self._stub.StreamingOutputCall( + request, timeout=_SHORT_TIMEOUT_S * 2 + ) response = await call.read() self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) @@ -540,27 +570,28 @@ async def test_empty_responses(self): request = messages_pb2.StreamingOutputCallRequest() for _ in range(_NUM_STREAM_RESPONSES): request.response_parameters.append( - messages_pb2.ResponseParameters()) + messages_pb2.ResponseParameters() + ) # Invokes the actual RPC call = self._stub.StreamingOutputCall(request) for _ in range(_NUM_STREAM_RESPONSES): response = await call.read() - self.assertIs(type(response), - messages_pb2.StreamingOutputCallResponse) - self.assertEqual(b'', response.SerializeToString()) + self.assertIs( + type(response), messages_pb2.StreamingOutputCallResponse + ) + self.assertEqual(b"", response.SerializeToString()) self.assertEqual(grpc.StatusCode.OK, await call.code()) class TestStreamUnaryCall(_MulticallableTestMixin, AioTestBase): - async def test_cancel_stream_unary(self): call = self._stub.StreamingInputCall() # Prepares the request - payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + payload = messages_pb2.Payload(body=b"\0" * _REQUEST_PAYLOAD_SIZE) request = messages_pb2.StreamingInputCallRequest(payload=payload) # Sends out requests @@ -600,7 +631,7 @@ async def test_write_after_done_writing(self): call = self._stub.StreamingInputCall() # Prepares the request - payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + payload = messages_pb2.Payload(body=b"\0" * _REQUEST_PAYLOAD_SIZE) request = messages_pb2.StreamingInputCallRequest(payload=payload) # Sends out requests @@ -615,8 +646,10 @@ async def test_write_after_done_writing(self): response = await call 
self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse) - self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, - response.aggregated_payload_size) + self.assertEqual( + _NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, + response.aggregated_payload_size, + ) self.assertEqual(await call.code(), grpc.StatusCode.OK) @@ -627,7 +660,8 @@ async def test_error_in_async_generator(self): messages_pb2.ResponseParameters( size=_RESPONSE_PAYLOAD_SIZE, interval_us=_RESPONSE_INTERVAL_US, - )) + ) + ) # We expect the request iterator to receive the exception request_iterator_received_the_exception = asyncio.Event() @@ -658,7 +692,7 @@ async def cancel_later(): async def test_normal_iterable_requests(self): # Prepares the request - payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + payload = messages_pb2.Payload(body=b"\0" * _REQUEST_PAYLOAD_SIZE) request = messages_pb2.StreamingInputCallRequest(payload=payload) requests = [request] * _NUM_STREAM_RESPONSES @@ -668,8 +702,10 @@ async def test_normal_iterable_requests(self): # RPC should succeed response = await call self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse) - self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, - response.aggregated_payload_size) + self.assertEqual( + _NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, + response.aggregated_payload_size, + ) self.assertEqual(await call.code(), grpc.StatusCode.OK) @@ -682,8 +718,9 @@ async def test_call_rpc_error(self): with self.assertRaises(aio.AioRpcError) as exception_context: await call - self.assertEqual(grpc.StatusCode.UNAVAILABLE, - exception_context.exception.code()) + self.assertEqual( + grpc.StatusCode.UNAVAILABLE, exception_context.exception.code() + ) self.assertTrue(call.done()) self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code()) @@ -704,15 +741,17 @@ async def test_timeout(self): # Prepares the request that stream in a ping-pong manner. 
_STREAM_OUTPUT_REQUEST_ONE_RESPONSE = messages_pb2.StreamingOutputCallRequest() _STREAM_OUTPUT_REQUEST_ONE_RESPONSE.response_parameters.append( - messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) -_STREAM_OUTPUT_REQUEST_ONE_EMPTY_RESPONSE = messages_pb2.StreamingOutputCallRequest( + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE) +) +_STREAM_OUTPUT_REQUEST_ONE_EMPTY_RESPONSE = ( + messages_pb2.StreamingOutputCallRequest() ) _STREAM_OUTPUT_REQUEST_ONE_EMPTY_RESPONSE.response_parameters.append( - messages_pb2.ResponseParameters()) + messages_pb2.ResponseParameters() +) class TestStreamStreamCall(_MulticallableTestMixin, AioTestBase): - async def test_cancel(self): # Invokes the actual RPC call = self._stub.FullDuplexCall() @@ -720,8 +759,9 @@ async def test_cancel(self): for _ in range(_NUM_STREAM_RESPONSES): await call.write(_STREAM_OUTPUT_REQUEST_ONE_RESPONSE) response = await call.read() - self.assertIsInstance(response, - messages_pb2.StreamingOutputCallResponse) + self.assertIsInstance( + response, messages_pb2.StreamingOutputCallResponse + ) self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) # Cancels the RPC @@ -798,7 +838,6 @@ async def test_late_cancel(self): self.assertEqual(grpc.StatusCode.OK, await call.code()) async def test_async_generator(self): - async def request_generator(): yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE @@ -810,7 +849,6 @@ async def request_generator(): self.assertEqual(await call.code(), grpc.StatusCode.OK) async def test_too_many_reads(self): - async def request_generator(): for _ in range(_NUM_STREAM_RESPONSES): yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE @@ -852,7 +890,8 @@ async def test_error_in_async_generator(self): messages_pb2.ResponseParameters( size=_RESPONSE_PAYLOAD_SIZE, interval_us=_RESPONSE_INTERVAL_US, - )) + ) + ) # We expect the request iterator to receive the exception request_iterator_received_the_exception = asyncio.Event() @@ -875,8 +914,9 @@ async def cancel_later(): with self.assertRaises(asyncio.CancelledError): async for response in call: - self.assertEqual(_RESPONSE_PAYLOAD_SIZE, - len(response.payload.body)) + self.assertEqual( + _RESPONSE_PAYLOAD_SIZE, len(response.payload.body) + ) await request_iterator_received_the_exception.wait() @@ -898,11 +938,11 @@ async def test_empty_ping_pong(self): for _ in range(_NUM_STREAM_RESPONSES): await call.write(_STREAM_OUTPUT_REQUEST_ONE_EMPTY_RESPONSE) response = await call.read() - self.assertEqual(b'', response.SerializeToString()) + self.assertEqual(b"", response.SerializeToString()) await call.done_writing() self.assertEqual(await call.code(), grpc.StatusCode.OK) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/channel_argument_test.py b/src/python/grpcio_tests/tests_aio/unit/channel_argument_test.py index 6eb4c3c2d574d..99eede30914bd 100644 --- a/src/python/grpcio_tests/tests_aio/unit/channel_argument_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/channel_argument_test.py @@ -31,9 +31,9 @@ _RANDOM_SEED = 42 -_ENABLE_REUSE_PORT = 'SO_REUSEPORT enabled' -_DISABLE_REUSE_PORT = 'SO_REUSEPORT disabled' -_SOCKET_OPT_SO_REUSEPORT = 'grpc.so_reuseport' +_ENABLE_REUSE_PORT = "SO_REUSEPORT enabled" +_DISABLE_REUSE_PORT = "SO_REUSEPORT disabled" +_SOCKET_OPT_SO_REUSEPORT = "grpc.so_reuseport" _OPTIONS = ( (_ENABLE_REUSE_PORT, ((_SOCKET_OPT_SO_REUSEPORT, 1),)), (_DISABLE_REUSE_PORT, 
((_SOCKET_OPT_SO_REUSEPORT, 0),)), @@ -41,44 +41,41 @@ _NUM_SERVER_CREATED = 5 -_GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH = 'grpc.max_receive_message_length' +_GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH = "grpc.max_receive_message_length" _MAX_MESSAGE_LENGTH = 1024 _ADDRESS_TOKEN_ERRNO = errno.EADDRINUSE, errno.ENOSR class _TestPointerWrapper(object): - def __int__(self): return 123456 _TEST_CHANNEL_ARGS = ( - ('arg1', b'bytes_val'), - ('arg2', 'str_val'), - ('arg3', 1), - (b'arg4', 'str_val'), - ('arg6', _TestPointerWrapper()), + ("arg1", b"bytes_val"), + ("arg2", "str_val"), + ("arg3", 1), + (b"arg4", "str_val"), + ("arg6", _TestPointerWrapper()), ) _INVALID_TEST_CHANNEL_ARGS = [ - { - 'foo': 'bar' - }, - (('key',),), - 'str', + {"foo": "bar"}, + (("key",),), + "str", ] async def test_if_reuse_port_enabled(server: aio.Server): - port = server.add_insecure_port('localhost:0') + port = server.add_insecure_port("localhost:0") await server.start() try: with common.bound_socket( - bind_address='localhost', - port=port, - listen=False, + bind_address="localhost", + port=port, + listen=False, ) as (unused_host, bound_port): assert bound_port == port except OSError as e: @@ -92,16 +89,18 @@ async def test_if_reuse_port_enabled(server: aio.Server): class TestChannelArgument(AioTestBase): - async def setUp(self): random.seed(_RANDOM_SEED) - @unittest.skipIf(platform.system() == 'Windows', - 'SO_REUSEPORT only available in Linux-like OS.') - @unittest.skipIf('aarch64' in platform.machine(), - 'SO_REUSEPORT needs to be enabled in Core\'s port.h.') + @unittest.skipIf( + platform.system() == "Windows", + "SO_REUSEPORT only available in Linux-like OS.", + ) + @unittest.skipIf( + "aarch64" in platform.machine(), + "SO_REUSEPORT needs to be enabled in Core's port.h.", + ) async def test_server_so_reuse_port_is_set_properly(self): - async def test_body(): fact, options = random.choice(_OPTIONS) server = aio.server(options=options) @@ -109,11 +108,12 @@ async def test_body(): result = await test_if_reuse_port_enabled(server) if fact == _ENABLE_REUSE_PORT and not result: self.fail( - 'Enabled reuse port in options, but not observed in socket' + "Enabled reuse port in options, but not observed in" + " socket" ) elif fact == _DISABLE_REUSE_PORT and result: self.fail( - 'Disabled reuse port in options, but observed in socket' + "Disabled reuse port in options, but observed in socket" ) finally: await server.stop(None) @@ -123,7 +123,7 @@ async def test_body(): async def test_client(self): # Do not segfault, or raise exception! 
- channel = aio.insecure_channel('[::]:0', options=_TEST_CHANNEL_ARGS) + channel = aio.insecure_channel("[::]:0", options=_TEST_CHANNEL_ARGS) await channel.close() async def test_server(self): @@ -133,47 +133,60 @@ async def test_server(self): async def test_invalid_client_args(self): for invalid_arg in _INVALID_TEST_CHANNEL_ARGS: - self.assertRaises((ValueError, TypeError), - aio.insecure_channel, - '[::]:0', - options=invalid_arg) + self.assertRaises( + (ValueError, TypeError), + aio.insecure_channel, + "[::]:0", + options=invalid_arg, + ) async def test_max_message_length_applied(self): address, server = await start_test_server() async with aio.insecure_channel( - address, - options=((_GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH, - _MAX_MESSAGE_LENGTH),)) as channel: + address, + options=( + (_GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH, _MAX_MESSAGE_LENGTH), + ), + ) as channel: stub = test_pb2_grpc.TestServiceStub(channel) request = messages_pb2.StreamingOutputCallRequest() # First request will pass request.response_parameters.append( - messages_pb2.ResponseParameters(size=_MAX_MESSAGE_LENGTH // 2,)) + messages_pb2.ResponseParameters( + size=_MAX_MESSAGE_LENGTH // 2, + ) + ) # Second request should fail request.response_parameters.append( - messages_pb2.ResponseParameters(size=_MAX_MESSAGE_LENGTH * 2,)) + messages_pb2.ResponseParameters( + size=_MAX_MESSAGE_LENGTH * 2, + ) + ) call = stub.StreamingOutputCall(request) response = await call.read() - self.assertEqual(_MAX_MESSAGE_LENGTH // 2, - len(response.payload.body)) + self.assertEqual( + _MAX_MESSAGE_LENGTH // 2, len(response.payload.body) + ) with self.assertRaises(aio.AioRpcError) as exception_context: await call.read() rpc_error = exception_context.exception - self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED, - rpc_error.code()) + self.assertEqual( + grpc.StatusCode.RESOURCE_EXHAUSTED, rpc_error.code() + ) self.assertIn(str(_MAX_MESSAGE_LENGTH), rpc_error.details()) - self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED, await - call.code()) + self.assertEqual( + grpc.StatusCode.RESOURCE_EXHAUSTED, await call.code() + ) await server.stop(None) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/channel_ready_test.py b/src/python/grpcio_tests/tests_aio/unit/channel_ready_test.py index 46e4d208ccf65..18d3f423269f0 100644 --- a/src/python/grpcio_tests/tests_aio/unit/channel_ready_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/channel_ready_test.py @@ -31,10 +31,10 @@ class TestChannelReady(AioTestBase): - async def setUp(self): address, self._port, self._socket = get_socket( - listen=False, sock_options=(socket.SO_REUSEADDR,)) + listen=False, sock_options=(socket.SO_REUSEADDR,) + ) self._channel = aio.insecure_channel(f"{address}:{self._port}") self._socket.close() @@ -44,11 +44,13 @@ async def tearDown(self): async def test_channel_ready_success(self): # Start `channel_ready` as another Task channel_ready_task = self.loop.create_task( - self._channel.channel_ready()) + self._channel.channel_ready() + ) # Wait for TRANSIENT_FAILURE await _common.block_until_certain_state( - self._channel, grpc.ChannelConnectivity.TRANSIENT_FAILURE) + self._channel, grpc.ChannelConnectivity.TRANSIENT_FAILURE + ) try: # Start the server @@ -61,10 +63,11 @@ async def test_channel_ready_success(self): async def test_channel_ready_blocked(self): with self.assertRaises(asyncio.TimeoutError): - await 
asyncio.wait_for(self._channel.channel_ready(), - test_constants.SHORT_TIMEOUT) + await asyncio.wait_for( + self._channel.channel_ready(), test_constants.SHORT_TIMEOUT + ) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/channel_test.py b/src/python/grpcio_tests/tests_aio/unit/channel_test.py index 699fe798f8240..26ef006fa1548 100644 --- a/src/python/grpcio_tests/tests_aio/unit/channel_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/channel_test.py @@ -28,13 +28,13 @@ from tests_aio.unit._test_base import AioTestBase from tests_aio.unit._test_server import start_test_server -_UNARY_CALL_METHOD = '/grpc.testing.TestService/UnaryCall' -_UNARY_CALL_METHOD_WITH_SLEEP = '/grpc.testing.TestService/UnaryCallWithSleep' -_STREAMING_OUTPUT_CALL_METHOD = '/grpc.testing.TestService/StreamingOutputCall' +_UNARY_CALL_METHOD = "/grpc.testing.TestService/UnaryCall" +_UNARY_CALL_METHOD_WITH_SLEEP = "/grpc.testing.TestService/UnaryCallWithSleep" +_STREAMING_OUTPUT_CALL_METHOD = "/grpc.testing.TestService/StreamingOutputCall" _INVOCATION_METADATA = ( - ('x-grpc-test-echo-initial', 'initial-md-value'), - ('x-grpc-test-echo-trailing-bin', b'\x00\x02'), + ("x-grpc-test-echo-initial", "initial-md-value"), + ("x-grpc-test-echo-trailing-bin", b"\x00\x02"), ) _NUM_STREAM_RESPONSES = 5 @@ -43,7 +43,6 @@ class TestChannel(AioTestBase): - async def setUp(self): self._server_target, self._server = await start_test_server() @@ -55,7 +54,8 @@ async def test_async_context(self): hi = channel.unary_unary( _UNARY_CALL_METHOD, request_serializer=messages_pb2.SimpleRequest.SerializeToString, - response_deserializer=messages_pb2.SimpleResponse.FromString) + response_deserializer=messages_pb2.SimpleResponse.FromString, + ) await hi(messages_pb2.SimpleRequest()) async def test_unary_unary(self): @@ -63,7 +63,8 @@ async def test_unary_unary(self): hi = channel.unary_unary( _UNARY_CALL_METHOD, request_serializer=messages_pb2.SimpleRequest.SerializeToString, - response_deserializer=messages_pb2.SimpleResponse.FromString) + response_deserializer=messages_pb2.SimpleResponse.FromString, + ) response = await hi(messages_pb2.SimpleRequest()) self.assertIsInstance(response, messages_pb2.SimpleResponse) @@ -77,20 +78,32 @@ async def test_unary_call_times_out(self): ) with self.assertRaises(grpc.RpcError) as exception_context: - await hi(messages_pb2.SimpleRequest(), - timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2) - - _, details = grpc.StatusCode.DEADLINE_EXCEEDED.value # pylint: disable=unused-variable - self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, - exception_context.exception.code()) - self.assertEqual(details.title(), - exception_context.exception.details()) + await hi( + messages_pb2.SimpleRequest(), + timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2, + ) + + ( + _, + details, + ) = ( + grpc.StatusCode.DEADLINE_EXCEEDED.value + ) # pylint: disable=unused-variable + self.assertEqual( + grpc.StatusCode.DEADLINE_EXCEEDED, + exception_context.exception.code(), + ) + self.assertEqual( + details.title(), exception_context.exception.details() + ) self.assertIsNotNone(exception_context.exception.initial_metadata()) self.assertIsNotNone( - exception_context.exception.trailing_metadata()) + exception_context.exception.trailing_metadata() + ) - @unittest.skipIf(os.name == 'nt', - 'TODO: https://github.com/grpc/grpc/issues/21658') + @unittest.skipIf( + os.name == "nt", "TODO: https://github.com/grpc/grpc/issues/21658" + ) 
async def test_unary_call_does_not_times_out(self): async with aio.insecure_channel(self._server_target) as channel: hi = channel.unary_unary( @@ -99,8 +112,10 @@ async def test_unary_call_does_not_times_out(self): response_deserializer=messages_pb2.SimpleResponse.FromString, ) - call = hi(messages_pb2.SimpleRequest(), - timeout=UNARY_CALL_WITH_SLEEP_VALUE * 5) + call = hi( + messages_pb2.SimpleRequest(), + timeout=UNARY_CALL_WITH_SLEEP_VALUE * 5, + ) self.assertEqual(await call.code(), grpc.StatusCode.OK) async def test_unary_stream(self): @@ -111,7 +126,8 @@ async def test_unary_stream(self): request = messages_pb2.StreamingOutputCallRequest() for _ in range(_NUM_STREAM_RESPONSES): request.response_parameters.append( - messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE) + ) # Invokes the actual RPC call = stub.StreamingOutputCall(request) @@ -120,8 +136,9 @@ async def test_unary_stream(self): response_cnt = 0 async for response in call: response_cnt += 1 - self.assertIs(type(response), - messages_pb2.StreamingOutputCallResponse) + self.assertIs( + type(response), messages_pb2.StreamingOutputCallResponse + ) self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt) @@ -136,7 +153,7 @@ async def test_stream_unary_using_write(self): call = stub.StreamingInputCall() # Prepares the request - payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + payload = messages_pb2.Payload(body=b"\0" * _REQUEST_PAYLOAD_SIZE) request = messages_pb2.StreamingInputCallRequest(payload=payload) # Sends out requests @@ -147,8 +164,10 @@ async def test_stream_unary_using_write(self): # Validates the responses response = await call self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse) - self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, - response.aggregated_payload_size) + self.assertEqual( + _NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, + response.aggregated_payload_size, + ) self.assertEqual(await call.code(), grpc.StatusCode.OK) await channel.close() @@ -158,7 +177,7 @@ async def test_stream_unary_using_async_gen(self): stub = test_pb2_grpc.TestServiceStub(channel) # Prepares the request - payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + payload = messages_pb2.Payload(body=b"\0" * _REQUEST_PAYLOAD_SIZE) request = messages_pb2.StreamingInputCallRequest(payload=payload) async def gen(): @@ -171,8 +190,10 @@ async def gen(): # Validates the responses response = await call self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse) - self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, - response.aggregated_payload_size) + self.assertEqual( + _NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, + response.aggregated_payload_size, + ) self.assertEqual(await call.code(), grpc.StatusCode.OK) await channel.close() @@ -187,13 +208,15 @@ async def test_stream_stream_using_read_write(self): # Prepares the request request = messages_pb2.StreamingOutputCallRequest() request.response_parameters.append( - messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE) + ) for _ in range(_NUM_STREAM_RESPONSES): await call.write(request) response = await call.read() - self.assertIsInstance(response, - messages_pb2.StreamingOutputCallResponse) + self.assertIsInstance( + response, messages_pb2.StreamingOutputCallResponse + ) 
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) await call.done_writing() @@ -208,7 +231,8 @@ async def test_stream_stream_using_async_gen(self): # Prepares the request request = messages_pb2.StreamingOutputCallRequest() request.response_parameters.append( - messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE) + ) async def gen(): for _ in range(_NUM_STREAM_RESPONSES): @@ -218,14 +242,15 @@ async def gen(): call = stub.FullDuplexCall(gen()) async for response in call: - self.assertIsInstance(response, - messages_pb2.StreamingOutputCallResponse) + self.assertIsInstance( + response, messages_pb2.StreamingOutputCallResponse + ) self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) self.assertEqual(grpc.StatusCode.OK, await call.code()) await channel.close() -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/client_stream_stream_interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/client_stream_stream_interceptor_test.py index 13ad9b075db33..125dd8fe7f1ed 100644 --- a/src/python/grpcio_tests/tests_aio/unit/client_stream_stream_interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/client_stream_stream_interceptor_test.py @@ -30,9 +30,9 @@ class _StreamStreamInterceptorEmpty(aio.StreamStreamClientInterceptor): - - async def intercept_stream_stream(self, continuation, client_call_details, - request_iterator): + async def intercept_stream_stream( + self, continuation, client_call_details, request_iterator + ): return await continuation(client_call_details, request_iterator) def assert_in_final_state(self, test: unittest.TestCase): @@ -40,24 +40,26 @@ def assert_in_final_state(self, test: unittest.TestCase): class _StreamStreamInterceptorWithRequestAndResponseIterator( - aio.StreamStreamClientInterceptor): - - async def intercept_stream_stream(self, continuation, client_call_details, - request_iterator): + aio.StreamStreamClientInterceptor +): + async def intercept_stream_stream( + self, continuation, client_call_details, request_iterator + ): self.request_iterator = CountingRequestIterator(request_iterator) call = await continuation(client_call_details, self.request_iterator) self.response_iterator = CountingResponseIterator(call) return self.response_iterator def assert_in_final_state(self, test: unittest.TestCase): - test.assertEqual(_NUM_STREAM_REQUESTS, - self.request_iterator.request_cnt) - test.assertEqual(_NUM_STREAM_RESPONSES, - self.response_iterator.response_cnt) + test.assertEqual( + _NUM_STREAM_REQUESTS, self.request_iterator.request_cnt + ) + test.assertEqual( + _NUM_STREAM_RESPONSES, self.response_iterator.response_cnt + ) class TestStreamStreamClientInterceptor(AioTestBase): - async def setUp(self): self._server_target, self._server = await start_test_server() @@ -65,22 +67,22 @@ async def tearDown(self): await self._server.stop(None) async def test_intercepts(self): - for interceptor_class in ( - _StreamStreamInterceptorEmpty, - _StreamStreamInterceptorWithRequestAndResponseIterator): - + _StreamStreamInterceptorEmpty, + _StreamStreamInterceptorWithRequestAndResponseIterator, + ): with self.subTest(name=interceptor_class): interceptor = interceptor_class() - channel = aio.insecure_channel(self._server_target, - interceptors=[interceptor]) + channel = aio.insecure_channel( + self._server_target, interceptors=[interceptor] + ) stub = 
test_pb2_grpc.TestServiceStub(channel) # Prepares the request request = messages_pb2.StreamingOutputCallRequest() request.response_parameters.append( - messages_pb2.ResponseParameters( - size=_RESPONSE_PAYLOAD_SIZE)) + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE) + ) async def request_iterator(): for _ in range(_NUM_STREAM_REQUESTS): @@ -94,16 +96,18 @@ async def request_iterator(): async for response in call: response_cnt += 1 self.assertIsInstance( - response, messages_pb2.StreamingOutputCallResponse) - self.assertEqual(_RESPONSE_PAYLOAD_SIZE, - len(response.payload.body)) + response, messages_pb2.StreamingOutputCallResponse + ) + self.assertEqual( + _RESPONSE_PAYLOAD_SIZE, len(response.payload.body) + ) self.assertEqual(response_cnt, _NUM_STREAM_RESPONSES) self.assertEqual(await call.code(), grpc.StatusCode.OK) self.assertEqual(await call.initial_metadata(), aio.Metadata()) self.assertEqual(await call.trailing_metadata(), aio.Metadata()) - self.assertEqual(await call.details(), '') - self.assertEqual(await call.debug_error_string(), '') + self.assertEqual(await call.details(), "") + self.assertEqual(await call.debug_error_string(), "") self.assertEqual(call.cancel(), False) self.assertEqual(call.cancelled(), False) self.assertEqual(call.done(), True) @@ -114,20 +118,21 @@ async def request_iterator(): async def test_intercepts_using_write_and_read(self): for interceptor_class in ( - _StreamStreamInterceptorEmpty, - _StreamStreamInterceptorWithRequestAndResponseIterator): - + _StreamStreamInterceptorEmpty, + _StreamStreamInterceptorWithRequestAndResponseIterator, + ): with self.subTest(name=interceptor_class): interceptor = interceptor_class() - channel = aio.insecure_channel(self._server_target, - interceptors=[interceptor]) + channel = aio.insecure_channel( + self._server_target, interceptors=[interceptor] + ) stub = test_pb2_grpc.TestServiceStub(channel) # Prepares the request request = messages_pb2.StreamingOutputCallRequest() request.response_parameters.append( - messages_pb2.ResponseParameters( - size=_RESPONSE_PAYLOAD_SIZE)) + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE) + ) call = stub.FullDuplexCall() @@ -135,17 +140,19 @@ async def test_intercepts_using_write_and_read(self): await call.write(request) response = await call.read() self.assertIsInstance( - response, messages_pb2.StreamingOutputCallResponse) - self.assertEqual(_RESPONSE_PAYLOAD_SIZE, - len(response.payload.body)) + response, messages_pb2.StreamingOutputCallResponse + ) + self.assertEqual( + _RESPONSE_PAYLOAD_SIZE, len(response.payload.body) + ) await call.done_writing() self.assertEqual(await call.code(), grpc.StatusCode.OK) self.assertEqual(await call.initial_metadata(), aio.Metadata()) self.assertEqual(await call.trailing_metadata(), aio.Metadata()) - self.assertEqual(await call.details(), '') - self.assertEqual(await call.debug_error_string(), '') + self.assertEqual(await call.details(), "") + self.assertEqual(await call.debug_error_string(), "") self.assertEqual(call.cancel(), False) self.assertEqual(call.cancelled(), False) self.assertEqual(call.done(), True) @@ -156,21 +163,21 @@ async def test_intercepts_using_write_and_read(self): async def test_multiple_interceptors_request_iterator(self): for interceptor_class in ( - _StreamStreamInterceptorEmpty, - _StreamStreamInterceptorWithRequestAndResponseIterator): - + _StreamStreamInterceptorEmpty, + _StreamStreamInterceptorWithRequestAndResponseIterator, + ): with self.subTest(name=interceptor_class): - interceptors = 
[interceptor_class(), interceptor_class()] - channel = aio.insecure_channel(self._server_target, - interceptors=interceptors) + channel = aio.insecure_channel( + self._server_target, interceptors=interceptors + ) stub = test_pb2_grpc.TestServiceStub(channel) # Prepares the request request = messages_pb2.StreamingOutputCallRequest() request.response_parameters.append( - messages_pb2.ResponseParameters( - size=_RESPONSE_PAYLOAD_SIZE)) + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE) + ) call = stub.FullDuplexCall() @@ -178,17 +185,19 @@ async def test_multiple_interceptors_request_iterator(self): await call.write(request) response = await call.read() self.assertIsInstance( - response, messages_pb2.StreamingOutputCallResponse) - self.assertEqual(_RESPONSE_PAYLOAD_SIZE, - len(response.payload.body)) + response, messages_pb2.StreamingOutputCallResponse + ) + self.assertEqual( + _RESPONSE_PAYLOAD_SIZE, len(response.payload.body) + ) await call.done_writing() self.assertEqual(await call.code(), grpc.StatusCode.OK) self.assertEqual(await call.initial_metadata(), aio.Metadata()) self.assertEqual(await call.trailing_metadata(), aio.Metadata()) - self.assertEqual(await call.details(), '') - self.assertEqual(await call.debug_error_string(), '') + self.assertEqual(await call.details(), "") + self.assertEqual(await call.debug_error_string(), "") self.assertEqual(call.cancel(), False) self.assertEqual(call.cancelled(), False) self.assertEqual(call.done(), True) @@ -199,6 +208,6 @@ async def test_multiple_interceptors_request_iterator(self): await channel.close() -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/client_stream_unary_interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/client_stream_unary_interceptor_test.py index ff99920c7f0b7..106be6cc34967 100644 --- a/src/python/grpcio_tests/tests_aio/unit/client_stream_unary_interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/client_stream_unary_interceptor_test.py @@ -36,9 +36,9 @@ class _StreamUnaryInterceptorEmpty(aio.StreamUnaryClientInterceptor): - - async def intercept_stream_unary(self, continuation, client_call_details, - request_iterator): + async def intercept_stream_unary( + self, continuation, client_call_details, request_iterator + ): return await continuation(client_call_details, request_iterator) def assert_in_final_state(self, test: unittest.TestCase): @@ -46,21 +46,22 @@ def assert_in_final_state(self, test: unittest.TestCase): class _StreamUnaryInterceptorWithRequestIterator( - aio.StreamUnaryClientInterceptor): - - async def intercept_stream_unary(self, continuation, client_call_details, - request_iterator): + aio.StreamUnaryClientInterceptor +): + async def intercept_stream_unary( + self, continuation, client_call_details, request_iterator + ): self.request_iterator = CountingRequestIterator(request_iterator) call = await continuation(client_call_details, self.request_iterator) return call def assert_in_final_state(self, test: unittest.TestCase): - test.assertEqual(_NUM_STREAM_REQUESTS, - self.request_iterator.request_cnt) + test.assertEqual( + _NUM_STREAM_REQUESTS, self.request_iterator.request_cnt + ) class TestStreamUnaryClientInterceptor(AioTestBase): - async def setUp(self): self._server_target, self._server = await start_test_server() @@ -68,19 +69,23 @@ async def tearDown(self): await self._server.stop(None) async def test_intercepts(self): - for interceptor_class in 
(_StreamUnaryInterceptorEmpty, - _StreamUnaryInterceptorWithRequestIterator): - + for interceptor_class in ( + _StreamUnaryInterceptorEmpty, + _StreamUnaryInterceptorWithRequestIterator, + ): with self.subTest(name=interceptor_class): interceptor = interceptor_class() - channel = aio.insecure_channel(self._server_target, - interceptors=[interceptor]) + channel = aio.insecure_channel( + self._server_target, interceptors=[interceptor] + ) stub = test_pb2_grpc.TestServiceStub(channel) - payload = messages_pb2.Payload(body=b'\0' * - _REQUEST_PAYLOAD_SIZE) + payload = messages_pb2.Payload( + body=b"\0" * _REQUEST_PAYLOAD_SIZE + ) request = messages_pb2.StreamingInputCallRequest( - payload=payload) + payload=payload + ) async def request_iterator(): for _ in range(_NUM_STREAM_REQUESTS): @@ -90,13 +95,15 @@ async def request_iterator(): response = await call - self.assertEqual(_NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE, - response.aggregated_payload_size) + self.assertEqual( + _NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE, + response.aggregated_payload_size, + ) self.assertEqual(await call.code(), grpc.StatusCode.OK) self.assertEqual(await call.initial_metadata(), aio.Metadata()) self.assertEqual(await call.trailing_metadata(), aio.Metadata()) - self.assertEqual(await call.details(), '') - self.assertEqual(await call.debug_error_string(), '') + self.assertEqual(await call.details(), "") + self.assertEqual(await call.debug_error_string(), "") self.assertEqual(call.cancel(), False) self.assertEqual(call.cancelled(), False) self.assertEqual(call.done(), True) @@ -106,19 +113,23 @@ async def request_iterator(): await channel.close() async def test_intercepts_using_write(self): - for interceptor_class in (_StreamUnaryInterceptorEmpty, - _StreamUnaryInterceptorWithRequestIterator): - + for interceptor_class in ( + _StreamUnaryInterceptorEmpty, + _StreamUnaryInterceptorWithRequestIterator, + ): with self.subTest(name=interceptor_class): interceptor = interceptor_class() - channel = aio.insecure_channel(self._server_target, - interceptors=[interceptor]) + channel = aio.insecure_channel( + self._server_target, interceptors=[interceptor] + ) stub = test_pb2_grpc.TestServiceStub(channel) - payload = messages_pb2.Payload(body=b'\0' * - _REQUEST_PAYLOAD_SIZE) + payload = messages_pb2.Payload( + body=b"\0" * _REQUEST_PAYLOAD_SIZE + ) request = messages_pb2.StreamingInputCallRequest( - payload=payload) + payload=payload + ) call = stub.StreamingInputCall() @@ -129,13 +140,15 @@ async def test_intercepts_using_write(self): response = await call - self.assertEqual(_NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE, - response.aggregated_payload_size) + self.assertEqual( + _NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE, + response.aggregated_payload_size, + ) self.assertEqual(await call.code(), grpc.StatusCode.OK) self.assertEqual(await call.initial_metadata(), aio.Metadata()) self.assertEqual(await call.trailing_metadata(), aio.Metadata()) - self.assertEqual(await call.details(), '') - self.assertEqual(await call.debug_error_string(), '') + self.assertEqual(await call.details(), "") + self.assertEqual(await call.debug_error_string(), "") self.assertEqual(call.cancel(), False) self.assertEqual(call.cancelled(), False) self.assertEqual(call.done(), True) @@ -145,20 +158,24 @@ async def test_intercepts_using_write(self): await channel.close() async def test_add_done_callback_interceptor_task_not_finished(self): - for interceptor_class in (_StreamUnaryInterceptorEmpty, - _StreamUnaryInterceptorWithRequestIterator): - + for 
interceptor_class in ( + _StreamUnaryInterceptorEmpty, + _StreamUnaryInterceptorWithRequestIterator, + ): with self.subTest(name=interceptor_class): interceptor = interceptor_class() - channel = aio.insecure_channel(self._server_target, - interceptors=[interceptor]) + channel = aio.insecure_channel( + self._server_target, interceptors=[interceptor] + ) stub = test_pb2_grpc.TestServiceStub(channel) - payload = messages_pb2.Payload(body=b'\0' * - _REQUEST_PAYLOAD_SIZE) + payload = messages_pb2.Payload( + body=b"\0" * _REQUEST_PAYLOAD_SIZE + ) request = messages_pb2.StreamingInputCallRequest( - payload=payload) + payload=payload + ) async def request_iterator(): for _ in range(_NUM_STREAM_REQUESTS): @@ -175,20 +192,24 @@ async def request_iterator(): await channel.close() async def test_add_done_callback_interceptor_task_finished(self): - for interceptor_class in (_StreamUnaryInterceptorEmpty, - _StreamUnaryInterceptorWithRequestIterator): - + for interceptor_class in ( + _StreamUnaryInterceptorEmpty, + _StreamUnaryInterceptorWithRequestIterator, + ): with self.subTest(name=interceptor_class): interceptor = interceptor_class() - channel = aio.insecure_channel(self._server_target, - interceptors=[interceptor]) + channel = aio.insecure_channel( + self._server_target, interceptors=[interceptor] + ) stub = test_pb2_grpc.TestServiceStub(channel) - payload = messages_pb2.Payload(body=b'\0' * - _REQUEST_PAYLOAD_SIZE) + payload = messages_pb2.Payload( + body=b"\0" * _REQUEST_PAYLOAD_SIZE + ) request = messages_pb2.StreamingInputCallRequest( - payload=payload) + payload=payload + ) async def request_iterator(): for _ in range(_NUM_STREAM_REQUESTS): @@ -205,20 +226,23 @@ async def request_iterator(): await channel.close() async def test_multiple_interceptors_request_iterator(self): - for interceptor_class in (_StreamUnaryInterceptorEmpty, - _StreamUnaryInterceptorWithRequestIterator): - + for interceptor_class in ( + _StreamUnaryInterceptorEmpty, + _StreamUnaryInterceptorWithRequestIterator, + ): with self.subTest(name=interceptor_class): - interceptors = [interceptor_class(), interceptor_class()] - channel = aio.insecure_channel(self._server_target, - interceptors=interceptors) + channel = aio.insecure_channel( + self._server_target, interceptors=interceptors + ) stub = test_pb2_grpc.TestServiceStub(channel) - payload = messages_pb2.Payload(body=b'\0' * - _REQUEST_PAYLOAD_SIZE) + payload = messages_pb2.Payload( + body=b"\0" * _REQUEST_PAYLOAD_SIZE + ) request = messages_pb2.StreamingInputCallRequest( - payload=payload) + payload=payload + ) async def request_iterator(): for _ in range(_NUM_STREAM_REQUESTS): @@ -228,13 +252,15 @@ async def request_iterator(): response = await call - self.assertEqual(_NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE, - response.aggregated_payload_size) + self.assertEqual( + _NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE, + response.aggregated_payload_size, + ) self.assertEqual(await call.code(), grpc.StatusCode.OK) self.assertEqual(await call.initial_metadata(), aio.Metadata()) self.assertEqual(await call.trailing_metadata(), aio.Metadata()) - self.assertEqual(await call.details(), '') - self.assertEqual(await call.debug_error_string(), '') + self.assertEqual(await call.details(), "") + self.assertEqual(await call.debug_error_string(), "") self.assertEqual(call.cancel(), False) self.assertEqual(call.cancelled(), False) self.assertEqual(call.done(), True) @@ -245,18 +271,22 @@ async def request_iterator(): await channel.close() async def 
test_intercepts_request_iterator_rpc_error(self): - for interceptor_class in (_StreamUnaryInterceptorEmpty, - _StreamUnaryInterceptorWithRequestIterator): - + for interceptor_class in ( + _StreamUnaryInterceptorEmpty, + _StreamUnaryInterceptorWithRequestIterator, + ): with self.subTest(name=interceptor_class): channel = aio.insecure_channel( - UNREACHABLE_TARGET, interceptors=[interceptor_class()]) + UNREACHABLE_TARGET, interceptors=[interceptor_class()] + ) stub = test_pb2_grpc.TestServiceStub(channel) - payload = messages_pb2.Payload(body=b'\0' * - _REQUEST_PAYLOAD_SIZE) + payload = messages_pb2.Payload( + body=b"\0" * _REQUEST_PAYLOAD_SIZE + ) request = messages_pb2.StreamingInputCallRequest( - payload=payload) + payload=payload + ) # When there is an error the request iterator is no longer # consumed. @@ -269,26 +299,32 @@ async def request_iterator(): with self.assertRaises(aio.AioRpcError) as exception_context: await call - self.assertEqual(grpc.StatusCode.UNAVAILABLE, - exception_context.exception.code()) + self.assertEqual( + grpc.StatusCode.UNAVAILABLE, + exception_context.exception.code(), + ) self.assertTrue(call.done()) self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code()) await channel.close() async def test_intercepts_request_iterator_rpc_error_using_write(self): - for interceptor_class in (_StreamUnaryInterceptorEmpty, - _StreamUnaryInterceptorWithRequestIterator): - + for interceptor_class in ( + _StreamUnaryInterceptorEmpty, + _StreamUnaryInterceptorWithRequestIterator, + ): with self.subTest(name=interceptor_class): channel = aio.insecure_channel( - UNREACHABLE_TARGET, interceptors=[interceptor_class()]) + UNREACHABLE_TARGET, interceptors=[interceptor_class()] + ) stub = test_pb2_grpc.TestServiceStub(channel) - payload = messages_pb2.Payload(body=b'\0' * - _REQUEST_PAYLOAD_SIZE) + payload = messages_pb2.Payload( + body=b"\0" * _REQUEST_PAYLOAD_SIZE + ) request = messages_pb2.StreamingInputCallRequest( - payload=payload) + payload=payload + ) call = stub.StreamingInputCall() @@ -300,31 +336,32 @@ async def test_intercepts_request_iterator_rpc_error_using_write(self): with self.assertRaises(aio.AioRpcError) as exception_context: await call - self.assertEqual(grpc.StatusCode.UNAVAILABLE, - exception_context.exception.code()) + self.assertEqual( + grpc.StatusCode.UNAVAILABLE, + exception_context.exception.code(), + ) self.assertTrue(call.done()) self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code()) await channel.close() async def test_cancel_before_rpc(self): - interceptor_reached = asyncio.Event() wait_for_ever = self.loop.create_future() class Interceptor(aio.StreamUnaryClientInterceptor): - - async def intercept_stream_unary(self, continuation, - client_call_details, - request_iterator): + async def intercept_stream_unary( + self, continuation, client_call_details, request_iterator + ): interceptor_reached.set() await wait_for_ever - channel = aio.insecure_channel(self._server_target, - interceptors=[Interceptor()]) + channel = aio.insecure_channel( + self._server_target, interceptors=[Interceptor()] + ) stub = test_pb2_grpc.TestServiceStub(channel) - payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + payload = messages_pb2.Payload(body=b"\0" * _REQUEST_PAYLOAD_SIZE) request = messages_pb2.StreamingInputCallRequest(payload=payload) call = stub.StreamingInputCall() @@ -351,24 +388,23 @@ async def intercept_stream_unary(self, continuation, await channel.close() async def test_cancel_after_rpc(self): - interceptor_reached = 
asyncio.Event() wait_for_ever = self.loop.create_future() class Interceptor(aio.StreamUnaryClientInterceptor): - - async def intercept_stream_unary(self, continuation, - client_call_details, - request_iterator): + async def intercept_stream_unary( + self, continuation, client_call_details, request_iterator + ): call = await continuation(client_call_details, request_iterator) interceptor_reached.set() await wait_for_ever - channel = aio.insecure_channel(self._server_target, - interceptors=[Interceptor()]) + channel = aio.insecure_channel( + self._server_target, interceptors=[Interceptor()] + ) stub = test_pb2_grpc.TestServiceStub(channel) - payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + payload = messages_pb2.Payload(body=b"\0" * _REQUEST_PAYLOAD_SIZE) request = messages_pb2.StreamingInputCallRequest(payload=payload) call = stub.StreamingInputCall() @@ -397,18 +433,23 @@ async def intercept_stream_unary(self, continuation, async def test_cancel_while_writing(self): # Test cancelation before making any write or after doing at least 1 for num_writes_before_cancel in (0, 1): - with self.subTest(name="Num writes before cancel: {}".format( - num_writes_before_cancel)): - + with self.subTest( + name="Num writes before cancel: {}".format( + num_writes_before_cancel + ) + ): channel = aio.insecure_channel( UNREACHABLE_TARGET, - interceptors=[_StreamUnaryInterceptorWithRequestIterator()]) + interceptors=[_StreamUnaryInterceptorWithRequestIterator()], + ) stub = test_pb2_grpc.TestServiceStub(channel) - payload = messages_pb2.Payload(body=b'\0' * - _REQUEST_PAYLOAD_SIZE) + payload = messages_pb2.Payload( + body=b"\0" * _REQUEST_PAYLOAD_SIZE + ) request = messages_pb2.StreamingInputCallRequest( - payload=payload) + payload=payload + ) call = stub.StreamingInputCall() @@ -428,21 +469,20 @@ async def test_cancel_while_writing(self): await channel.close() async def test_cancel_by_the_interceptor(self): - class Interceptor(aio.StreamUnaryClientInterceptor): - - async def intercept_stream_unary(self, continuation, - client_call_details, - request_iterator): + async def intercept_stream_unary( + self, continuation, client_call_details, request_iterator + ): call = await continuation(client_call_details, request_iterator) call.cancel() return call - channel = aio.insecure_channel(UNREACHABLE_TARGET, - interceptors=[Interceptor()]) + channel = aio.insecure_channel( + UNREACHABLE_TARGET, interceptors=[Interceptor()] + ) stub = test_pb2_grpc.TestServiceStub(channel) - payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + payload = messages_pb2.Payload(body=b"\0" * _REQUEST_PAYLOAD_SIZE) request = messages_pb2.StreamingInputCallRequest(payload=payload) call = stub.StreamingInputCall() @@ -461,22 +501,21 @@ async def intercept_stream_unary(self, continuation, await channel.close() async def test_exception_raised_by_interceptor(self): - class InterceptorException(Exception): pass class Interceptor(aio.StreamUnaryClientInterceptor): - - async def intercept_stream_unary(self, continuation, - client_call_details, - request_iterator): + async def intercept_stream_unary( + self, continuation, client_call_details, request_iterator + ): raise InterceptorException - channel = aio.insecure_channel(UNREACHABLE_TARGET, - interceptors=[Interceptor()]) + channel = aio.insecure_channel( + UNREACHABLE_TARGET, interceptors=[Interceptor()] + ) stub = test_pb2_grpc.TestServiceStub(channel) - payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + payload = 
messages_pb2.Payload(body=b"\0" * _REQUEST_PAYLOAD_SIZE) request = messages_pb2.StreamingInputCallRequest(payload=payload) call = stub.StreamingInputCall() @@ -492,10 +531,11 @@ async def intercept_stream_unary(self, continuation, async def test_intercepts_prohibit_mixing_style(self): channel = aio.insecure_channel( - self._server_target, interceptors=[_StreamUnaryInterceptorEmpty()]) + self._server_target, interceptors=[_StreamUnaryInterceptorEmpty()] + ) stub = test_pb2_grpc.TestServiceStub(channel) - payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + payload = messages_pb2.Payload(body=b"\0" * _REQUEST_PAYLOAD_SIZE) request = messages_pb2.StreamingInputCallRequest(payload=payload) async def request_iterator(): @@ -513,6 +553,6 @@ async def request_iterator(): await channel.close() -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/client_unary_stream_interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/client_unary_stream_interceptor_test.py index f0c0cba8eb9ca..c86cb36b02b59 100644 --- a/src/python/grpcio_tests/tests_aio/unit/client_unary_stream_interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/client_unary_stream_interceptor_test.py @@ -37,9 +37,9 @@ class _UnaryStreamInterceptorEmpty(aio.UnaryStreamClientInterceptor): - - async def intercept_unary_stream(self, continuation, client_call_details, - request): + async def intercept_unary_stream( + self, continuation, client_call_details, request + ): return await continuation(client_call_details, request) def assert_in_final_state(self, test: unittest.TestCase): @@ -47,21 +47,22 @@ def assert_in_final_state(self, test: unittest.TestCase): class _UnaryStreamInterceptorWithResponseIterator( - aio.UnaryStreamClientInterceptor): - - async def intercept_unary_stream(self, continuation, client_call_details, - request): + aio.UnaryStreamClientInterceptor +): + async def intercept_unary_stream( + self, continuation, client_call_details, request + ): call = await continuation(client_call_details, request) self.response_iterator = CountingResponseIterator(call) return self.response_iterator def assert_in_final_state(self, test: unittest.TestCase): - test.assertEqual(_NUM_STREAM_RESPONSES, - self.response_iterator.response_cnt) + test.assertEqual( + _NUM_STREAM_RESPONSES, self.response_iterator.response_cnt + ) class TestUnaryStreamClientInterceptor(AioTestBase): - async def setUp(self): self._server_target, self._server = await start_test_server() @@ -69,19 +70,26 @@ async def tearDown(self): await self._server.stop(None) async def test_intercepts(self): - for interceptor_class in (_UnaryStreamInterceptorEmpty, - _UnaryStreamInterceptorWithResponseIterator): - + for interceptor_class in ( + _UnaryStreamInterceptorEmpty, + _UnaryStreamInterceptorWithResponseIterator, + ): with self.subTest(name=interceptor_class): interceptor = interceptor_class() request = messages_pb2.StreamingOutputCallRequest() - request.response_parameters.extend([ - messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE) - ] * _NUM_STREAM_RESPONSES) + request.response_parameters.extend( + [ + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE + ) + ] + * _NUM_STREAM_RESPONSES + ) - channel = aio.insecure_channel(self._server_target, - interceptors=[interceptor]) + channel = aio.insecure_channel( + self._server_target, interceptors=[interceptor] + ) stub = test_pb2_grpc.TestServiceStub(channel) 
call = stub.StreamingOutputCall(request) @@ -90,17 +98,19 @@ async def test_intercepts(self): response_cnt = 0 async for response in call: response_cnt += 1 - self.assertIs(type(response), - messages_pb2.StreamingOutputCallResponse) - self.assertEqual(_RESPONSE_PAYLOAD_SIZE, - len(response.payload.body)) + self.assertIs( + type(response), messages_pb2.StreamingOutputCallResponse + ) + self.assertEqual( + _RESPONSE_PAYLOAD_SIZE, len(response.payload.body) + ) self.assertEqual(response_cnt, _NUM_STREAM_RESPONSES) self.assertEqual(await call.code(), grpc.StatusCode.OK) self.assertEqual(await call.initial_metadata(), aio.Metadata()) self.assertEqual(await call.trailing_metadata(), aio.Metadata()) - self.assertEqual(await call.details(), '') - self.assertEqual(await call.debug_error_string(), '') + self.assertEqual(await call.details(), "") + self.assertEqual(await call.debug_error_string(), "") self.assertEqual(call.cancel(), False) self.assertEqual(call.cancelled(), False) self.assertEqual(call.done(), True) @@ -110,19 +120,26 @@ async def test_intercepts(self): await channel.close() async def test_add_done_callback_interceptor_task_not_finished(self): - for interceptor_class in (_UnaryStreamInterceptorEmpty, - _UnaryStreamInterceptorWithResponseIterator): - + for interceptor_class in ( + _UnaryStreamInterceptorEmpty, + _UnaryStreamInterceptorWithResponseIterator, + ): with self.subTest(name=interceptor_class): interceptor = interceptor_class() request = messages_pb2.StreamingOutputCallRequest() - request.response_parameters.extend([ - messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE) - ] * _NUM_STREAM_RESPONSES) + request.response_parameters.extend( + [ + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE + ) + ] + * _NUM_STREAM_RESPONSES + ) - channel = aio.insecure_channel(self._server_target, - interceptors=[interceptor]) + channel = aio.insecure_channel( + self._server_target, interceptors=[interceptor] + ) stub = test_pb2_grpc.TestServiceStub(channel) call = stub.StreamingOutputCall(request) @@ -136,19 +153,26 @@ async def test_add_done_callback_interceptor_task_not_finished(self): await channel.close() async def test_add_done_callback_interceptor_task_finished(self): - for interceptor_class in (_UnaryStreamInterceptorEmpty, - _UnaryStreamInterceptorWithResponseIterator): - + for interceptor_class in ( + _UnaryStreamInterceptorEmpty, + _UnaryStreamInterceptorWithResponseIterator, + ): with self.subTest(name=interceptor_class): interceptor = interceptor_class() request = messages_pb2.StreamingOutputCallRequest() - request.response_parameters.extend([ - messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE) - ] * _NUM_STREAM_RESPONSES) + request.response_parameters.extend( + [ + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE + ) + ] + * _NUM_STREAM_RESPONSES + ) - channel = aio.insecure_channel(self._server_target, - interceptors=[interceptor]) + channel = aio.insecure_channel( + self._server_target, interceptors=[interceptor] + ) stub = test_pb2_grpc.TestServiceStub(channel) call = stub.StreamingOutputCall(request) @@ -169,14 +193,16 @@ async def test_add_done_callback_interceptor_task_finished(self): async def test_response_iterator_using_read(self): interceptor = _UnaryStreamInterceptorWithResponseIterator() - channel = aio.insecure_channel(self._server_target, - interceptors=[interceptor]) + channel = aio.insecure_channel( + self._server_target, interceptors=[interceptor] + ) stub = test_pb2_grpc.TestServiceStub(channel) request = 
messages_pb2.StreamingOutputCallRequest() request.response_parameters.extend( - [messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)] * - _NUM_STREAM_RESPONSES) + [messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)] + * _NUM_STREAM_RESPONSES + ) call = stub.StreamingOutputCall(request) @@ -184,43 +210,53 @@ async def test_response_iterator_using_read(self): for response in range(_NUM_STREAM_RESPONSES): response = await call.read() response_cnt += 1 - self.assertIs(type(response), - messages_pb2.StreamingOutputCallResponse) + self.assertIs( + type(response), messages_pb2.StreamingOutputCallResponse + ) self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) self.assertEqual(response_cnt, _NUM_STREAM_RESPONSES) - self.assertEqual(interceptor.response_iterator.response_cnt, - _NUM_STREAM_RESPONSES) + self.assertEqual( + interceptor.response_iterator.response_cnt, _NUM_STREAM_RESPONSES + ) self.assertEqual(await call.code(), grpc.StatusCode.OK) await channel.close() async def test_multiple_interceptors_response_iterator(self): - for interceptor_class in (_UnaryStreamInterceptorEmpty, - _UnaryStreamInterceptorWithResponseIterator): - + for interceptor_class in ( + _UnaryStreamInterceptorEmpty, + _UnaryStreamInterceptorWithResponseIterator, + ): with self.subTest(name=interceptor_class): - interceptors = [interceptor_class(), interceptor_class()] - channel = aio.insecure_channel(self._server_target, - interceptors=interceptors) + channel = aio.insecure_channel( + self._server_target, interceptors=interceptors + ) stub = test_pb2_grpc.TestServiceStub(channel) request = messages_pb2.StreamingOutputCallRequest() - request.response_parameters.extend([ - messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE) - ] * _NUM_STREAM_RESPONSES) + request.response_parameters.extend( + [ + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE + ) + ] + * _NUM_STREAM_RESPONSES + ) call = stub.StreamingOutputCall(request) response_cnt = 0 async for response in call: response_cnt += 1 - self.assertIs(type(response), - messages_pb2.StreamingOutputCallResponse) - self.assertEqual(_RESPONSE_PAYLOAD_SIZE, - len(response.payload.body)) + self.assertIs( + type(response), messages_pb2.StreamingOutputCallResponse + ) + self.assertEqual( + _RESPONSE_PAYLOAD_SIZE, len(response.payload.body) + ) self.assertEqual(response_cnt, _NUM_STREAM_RESPONSES) self.assertEqual(await call.code(), grpc.StatusCode.OK) @@ -228,13 +264,14 @@ async def test_multiple_interceptors_response_iterator(self): await channel.close() async def test_intercepts_response_iterator_rpc_error(self): - for interceptor_class in (_UnaryStreamInterceptorEmpty, - _UnaryStreamInterceptorWithResponseIterator): - + for interceptor_class in ( + _UnaryStreamInterceptorEmpty, + _UnaryStreamInterceptorWithResponseIterator, + ): with self.subTest(name=interceptor_class): - channel = aio.insecure_channel( - UNREACHABLE_TARGET, interceptors=[interceptor_class()]) + UNREACHABLE_TARGET, interceptors=[interceptor_class()] + ) request = messages_pb2.StreamingOutputCallRequest() stub = test_pb2_grpc.TestServiceStub(channel) call = stub.StreamingOutputCall(request) @@ -243,27 +280,29 @@ async def test_intercepts_response_iterator_rpc_error(self): async for response in call: pass - self.assertEqual(grpc.StatusCode.UNAVAILABLE, - exception_context.exception.code()) + self.assertEqual( + grpc.StatusCode.UNAVAILABLE, + exception_context.exception.code(), + ) self.assertTrue(call.done()) self.assertEqual(grpc.StatusCode.UNAVAILABLE, 
await call.code()) await channel.close() async def test_cancel_before_rpc(self): - interceptor_reached = asyncio.Event() wait_for_ever = self.loop.create_future() class Interceptor(aio.UnaryStreamClientInterceptor): - - async def intercept_unary_stream(self, continuation, - client_call_details, request): + async def intercept_unary_stream( + self, continuation, client_call_details, request + ): interceptor_reached.set() await wait_for_ever - channel = aio.insecure_channel(UNREACHABLE_TARGET, - interceptors=[Interceptor()]) + channel = aio.insecure_channel( + UNREACHABLE_TARGET, interceptors=[Interceptor()] + ) request = messages_pb2.StreamingOutputCallRequest() stub = test_pb2_grpc.TestServiceStub(channel) call = stub.StreamingOutputCall(request) @@ -286,20 +325,20 @@ async def intercept_unary_stream(self, continuation, await channel.close() async def test_cancel_after_rpc(self): - interceptor_reached = asyncio.Event() wait_for_ever = self.loop.create_future() class Interceptor(aio.UnaryStreamClientInterceptor): - - async def intercept_unary_stream(self, continuation, - client_call_details, request): + async def intercept_unary_stream( + self, continuation, client_call_details, request + ): call = await continuation(client_call_details, request) interceptor_reached.set() await wait_for_ever - channel = aio.insecure_channel(UNREACHABLE_TARGET, - interceptors=[Interceptor()]) + channel = aio.insecure_channel( + UNREACHABLE_TARGET, interceptors=[Interceptor()] + ) request = messages_pb2.StreamingOutputCallRequest() stub = test_pb2_grpc.TestServiceStub(channel) call = stub.StreamingOutputCall(request) @@ -324,12 +363,14 @@ async def intercept_unary_stream(self, continuation, async def test_cancel_consuming_response_iterator(self): request = messages_pb2.StreamingOutputCallRequest() request.response_parameters.extend( - [messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)] * - _NUM_STREAM_RESPONSES) + [messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)] + * _NUM_STREAM_RESPONSES + ) channel = aio.insecure_channel( self._server_target, - interceptors=[_UnaryStreamInterceptorWithResponseIterator()]) + interceptors=[_UnaryStreamInterceptorWithResponseIterator()], + ) stub = test_pb2_grpc.TestServiceStub(channel) call = stub.StreamingOutputCall(request) @@ -343,17 +384,17 @@ async def test_cancel_consuming_response_iterator(self): await channel.close() async def test_cancel_by_the_interceptor(self): - class Interceptor(aio.UnaryStreamClientInterceptor): - - async def intercept_unary_stream(self, continuation, - client_call_details, request): + async def intercept_unary_stream( + self, continuation, client_call_details, request + ): call = await continuation(client_call_details, request) call.cancel() return call - channel = aio.insecure_channel(UNREACHABLE_TARGET, - interceptors=[Interceptor()]) + channel = aio.insecure_channel( + UNREACHABLE_TARGET, interceptors=[Interceptor()] + ) request = messages_pb2.StreamingOutputCallRequest() stub = test_pb2_grpc.TestServiceStub(channel) call = stub.StreamingOutputCall(request) @@ -368,18 +409,18 @@ async def intercept_unary_stream(self, continuation, await channel.close() async def test_exception_raised_by_interceptor(self): - class InterceptorException(Exception): pass class Interceptor(aio.UnaryStreamClientInterceptor): - - async def intercept_unary_stream(self, continuation, - client_call_details, request): + async def intercept_unary_stream( + self, continuation, client_call_details, request + ): raise InterceptorException - channel 
= aio.insecure_channel(UNREACHABLE_TARGET, - interceptors=[Interceptor()]) + channel = aio.insecure_channel( + UNREACHABLE_TARGET, interceptors=[Interceptor()] + ) request = messages_pb2.StreamingOutputCallRequest() stub = test_pb2_grpc.TestServiceStub(channel) call = stub.StreamingOutputCall(request) @@ -391,6 +432,6 @@ async def intercept_unary_stream(self, continuation, await channel.close() -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/client_unary_unary_interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/client_unary_unary_interceptor_test.py index bcb5df5481939..4789b309504ea 100644 --- a/src/python/grpcio_tests/tests_aio/unit/client_unary_unary_interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/client_unary_unary_interceptor_test.py @@ -27,16 +27,15 @@ from tests_aio.unit._test_server import _TRAILING_METADATA_KEY from tests_aio.unit._test_server import start_test_server -_LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!' +_LOCAL_CANCEL_DETAILS_EXPECTATION = "Locally cancelled by application!" _INITIAL_METADATA_TO_INJECT = aio.Metadata( - (_INITIAL_METADATA_KEY, 'extra info'), - (_TRAILING_METADATA_KEY, b'\x13\x37'), + (_INITIAL_METADATA_KEY, "extra info"), + (_TRAILING_METADATA_KEY, b"\x13\x37"), ) _TIMEOUT_CHECK_IF_CALLBACK_WAS_CALLED = 1.0 class TestUnaryUnaryClientInterceptor(AioTestBase): - async def setUp(self): self._server_target, self._server = await start_test_server() @@ -44,7 +43,6 @@ async def tearDown(self): await self._server.stop(None) def test_invalid_interceptor(self): - class InvalidInterceptor: """Just an invalid Interceptor""" @@ -52,26 +50,28 @@ class InvalidInterceptor: aio.insecure_channel("", interceptors=[InvalidInterceptor()]) async def test_executed_right_order(self): - interceptors_executed = [] class Interceptor(aio.UnaryUnaryClientInterceptor): """Interceptor used for testing if the interceptor is being called""" - async def intercept_unary_unary(self, continuation, - client_call_details, request): + async def intercept_unary_unary( + self, continuation, client_call_details, request + ): interceptors_executed.append(self) call = await continuation(client_call_details, request) return call interceptors = [Interceptor() for i in range(2)] - async with aio.insecure_channel(self._server_target, - interceptors=interceptors) as channel: + async with aio.insecure_channel( + self._server_target, interceptors=interceptors + ) as channel: multicallable = channel.unary_unary( - '/grpc.testing.TestService/UnaryCall', + "/grpc.testing.TestService/UnaryCall", request_serializer=messages_pb2.SimpleRequest.SerializeToString, - response_deserializer=messages_pb2.SimpleResponse.FromString) + response_deserializer=messages_pb2.SimpleResponse.FromString, + ) call = multicallable(messages_pb2.SimpleRequest()) response = await call @@ -94,15 +94,15 @@ def test_modify_credentials(self): raise NotImplementedError() async def test_status_code_Ok(self): - class StatusCodeOkInterceptor(aio.UnaryUnaryClientInterceptor): """Interceptor used for observing status code Ok returned by the RPC""" def __init__(self): self.status_code_Ok_observed = False - async def intercept_unary_unary(self, continuation, - client_call_details, request): + async def intercept_unary_unary( + self, continuation, client_call_details, request + ): call = await continuation(client_call_details, request) code = await call.code() if code 
== grpc.StatusCode.OK: @@ -112,58 +112,63 @@ async def intercept_unary_unary(self, continuation, interceptor = StatusCodeOkInterceptor() - async with aio.insecure_channel(self._server_target, - interceptors=[interceptor]) as channel: - + async with aio.insecure_channel( + self._server_target, interceptors=[interceptor] + ) as channel: # when no error StatusCode.OK must be observed multicallable = channel.unary_unary( - '/grpc.testing.TestService/UnaryCall', + "/grpc.testing.TestService/UnaryCall", request_serializer=messages_pb2.SimpleRequest.SerializeToString, - response_deserializer=messages_pb2.SimpleResponse.FromString) + response_deserializer=messages_pb2.SimpleResponse.FromString, + ) await multicallable(messages_pb2.SimpleRequest()) self.assertTrue(interceptor.status_code_Ok_observed) async def test_add_timeout(self): - class TimeoutInterceptor(aio.UnaryUnaryClientInterceptor): """Interceptor used for adding a timeout to the RPC""" - async def intercept_unary_unary(self, continuation, - client_call_details, request): + async def intercept_unary_unary( + self, continuation, client_call_details, request + ): new_client_call_details = aio.ClientCallDetails( method=client_call_details.method, timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2, metadata=client_call_details.metadata, credentials=client_call_details.credentials, - wait_for_ready=client_call_details.wait_for_ready) + wait_for_ready=client_call_details.wait_for_ready, + ) return await continuation(new_client_call_details, request) interceptor = TimeoutInterceptor() - async with aio.insecure_channel(self._server_target, - interceptors=[interceptor]) as channel: - + async with aio.insecure_channel( + self._server_target, interceptors=[interceptor] + ) as channel: multicallable = channel.unary_unary( - '/grpc.testing.TestService/UnaryCallWithSleep', + "/grpc.testing.TestService/UnaryCallWithSleep", request_serializer=messages_pb2.SimpleRequest.SerializeToString, - response_deserializer=messages_pb2.SimpleResponse.FromString) + response_deserializer=messages_pb2.SimpleResponse.FromString, + ) call = multicallable(messages_pb2.SimpleRequest()) with self.assertRaises(aio.AioRpcError) as exception_context: await call - self.assertEqual(exception_context.exception.code(), - grpc.StatusCode.DEADLINE_EXCEEDED) + self.assertEqual( + exception_context.exception.code(), + grpc.StatusCode.DEADLINE_EXCEEDED, + ) self.assertTrue(call.done()) - self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, await - call.code()) + self.assertEqual( + grpc.StatusCode.DEADLINE_EXCEEDED, await call.code() + ) async def test_retry(self): - class RetryInterceptor(aio.UnaryUnaryClientInterceptor): """Simulates a Retry Interceptor which ends up by making two RPC calls.""" @@ -171,15 +176,16 @@ class RetryInterceptor(aio.UnaryUnaryClientInterceptor): def __init__(self): self.calls = [] - async def intercept_unary_unary(self, continuation, - client_call_details, request): - + async def intercept_unary_unary( + self, continuation, client_call_details, request + ): new_client_call_details = aio.ClientCallDetails( method=client_call_details.method, timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2, metadata=client_call_details.metadata, credentials=client_call_details.credentials, - wait_for_ready=client_call_details.wait_for_ready) + wait_for_ready=client_call_details.wait_for_ready, + ) try: call = await continuation(new_client_call_details, request) @@ -194,7 +200,8 @@ async def intercept_unary_unary(self, continuation, timeout=None, 
metadata=client_call_details.metadata, credentials=client_call_details.credentials, - wait_for_ready=client_call_details.wait_for_ready) + wait_for_ready=client_call_details.wait_for_ready, + ) call = await continuation(new_client_call_details, request) self.calls.append(call) @@ -202,13 +209,14 @@ async def intercept_unary_unary(self, continuation, interceptor = RetryInterceptor() - async with aio.insecure_channel(self._server_target, - interceptors=[interceptor]) as channel: - + async with aio.insecure_channel( + self._server_target, interceptors=[interceptor] + ) as channel: multicallable = channel.unary_unary( - '/grpc.testing.TestService/UnaryCallWithSleep', + "/grpc.testing.TestService/UnaryCallWithSleep", request_serializer=messages_pb2.SimpleRequest.SerializeToString, - response_deserializer=messages_pb2.SimpleResponse.FromString) + response_deserializer=messages_pb2.SimpleResponse.FromString, + ) call = multicallable(messages_pb2.SimpleRequest()) @@ -219,17 +227,19 @@ async def intercept_unary_unary(self, continuation, # Check that two calls were made, first one finishing with # a deadline and second one finishing ok.. self.assertEqual(len(interceptor.calls), 2) - self.assertEqual(await interceptor.calls[0].code(), - grpc.StatusCode.DEADLINE_EXCEEDED) - self.assertEqual(await interceptor.calls[1].code(), - grpc.StatusCode.OK) + self.assertEqual( + await interceptor.calls[0].code(), + grpc.StatusCode.DEADLINE_EXCEEDED, + ) + self.assertEqual( + await interceptor.calls[1].code(), grpc.StatusCode.OK + ) async def test_retry_with_multiple_interceptors(self): - class RetryInterceptor(aio.UnaryUnaryClientInterceptor): - - async def intercept_unary_unary(self, continuation, - client_call_details, request): + async def intercept_unary_unary( + self, continuation, client_call_details, request + ): # Simulate retry twice for _ in range(2): call = await continuation(client_call_details, request) @@ -237,12 +247,12 @@ async def intercept_unary_unary(self, continuation, return result class AnotherInterceptor(aio.UnaryUnaryClientInterceptor): - def __init__(self): self.called_times = 0 - async def intercept_unary_unary(self, continuation, - client_call_details, request): + async def intercept_unary_unary( + self, continuation, client_call_details, request + ): self.called_times += 1 call = await continuation(client_call_details, request) result = await call @@ -252,14 +262,14 @@ async def intercept_unary_unary(self, continuation, retry_interceptor = RetryInterceptor() another_interceptor = AnotherInterceptor() async with aio.insecure_channel( - self._server_target, - interceptors=[retry_interceptor, - another_interceptor]) as channel: - + self._server_target, + interceptors=[retry_interceptor, another_interceptor], + ) as channel: multicallable = channel.unary_unary( - '/grpc.testing.TestService/UnaryCallWithSleep', + "/grpc.testing.TestService/UnaryCallWithSleep", request_serializer=messages_pb2.SimpleRequest.SerializeToString, - response_deserializer=messages_pb2.SimpleResponse.FromString) + response_deserializer=messages_pb2.SimpleResponse.FromString, + ) call = multicallable(messages_pb2.SimpleRequest()) @@ -269,34 +279,37 @@ async def intercept_unary_unary(self, continuation, self.assertEqual(another_interceptor.called_times, 2) async def test_rpcresponse(self): - class Interceptor(aio.UnaryUnaryClientInterceptor): """Raw responses are seen as reegular calls""" - async def intercept_unary_unary(self, continuation, - client_call_details, request): + async def intercept_unary_unary( + self, 
continuation, client_call_details, request + ): call = await continuation(client_call_details, request) response = await call return call class ResponseInterceptor(aio.UnaryUnaryClientInterceptor): """Return a raw response""" + response = messages_pb2.SimpleResponse() - async def intercept_unary_unary(self, continuation, - client_call_details, request): + async def intercept_unary_unary( + self, continuation, client_call_details, request + ): return ResponseInterceptor.response interceptor, interceptor_response = Interceptor(), ResponseInterceptor() async with aio.insecure_channel( - self._server_target, - interceptors=[interceptor, interceptor_response]) as channel: - + self._server_target, + interceptors=[interceptor, interceptor_response], + ) as channel: multicallable = channel.unary_unary( - '/grpc.testing.TestService/UnaryCall', + "/grpc.testing.TestService/UnaryCall", request_serializer=messages_pb2.SimpleRequest.SerializeToString, - response_deserializer=messages_pb2.SimpleResponse.FromString) + response_deserializer=messages_pb2.SimpleResponse.FromString, + ) call = multicallable(messages_pb2.SimpleRequest()) response = await call @@ -310,14 +323,13 @@ async def intercept_unary_unary(self, continuation, self.assertFalse(call.cancel()) self.assertFalse(call.cancelled()) self.assertEqual(await call.code(), grpc.StatusCode.OK) - self.assertEqual(await call.details(), '') + self.assertEqual(await call.details(), "") self.assertEqual(await call.initial_metadata(), None) self.assertEqual(await call.trailing_metadata(), None) self.assertEqual(await call.debug_error_string(), None) class TestInterceptedUnaryUnaryCall(AioTestBase): - async def setUp(self): self._server_target, self._server = await start_test_server() @@ -325,22 +337,21 @@ async def tearDown(self): await self._server.stop(None) async def test_call_ok(self): - class Interceptor(aio.UnaryUnaryClientInterceptor): - - async def intercept_unary_unary(self, continuation, - client_call_details, request): + async def intercept_unary_unary( + self, continuation, client_call_details, request + ): call = await continuation(client_call_details, request) return call - async with aio.insecure_channel(self._server_target, - interceptors=[Interceptor() - ]) as channel: - + async with aio.insecure_channel( + self._server_target, interceptors=[Interceptor()] + ) as channel: multicallable = channel.unary_unary( - '/grpc.testing.TestService/UnaryCall', + "/grpc.testing.TestService/UnaryCall", request_serializer=messages_pb2.SimpleRequest.SerializeToString, - response_deserializer=messages_pb2.SimpleResponse.FromString) + response_deserializer=messages_pb2.SimpleResponse.FromString, + ) call = multicallable(messages_pb2.SimpleRequest()) response = await call @@ -348,28 +359,27 @@ async def intercept_unary_unary(self, continuation, self.assertFalse(call.cancelled()) self.assertEqual(type(response), messages_pb2.SimpleResponse) self.assertEqual(await call.code(), grpc.StatusCode.OK) - self.assertEqual(await call.details(), '') + self.assertEqual(await call.details(), "") self.assertEqual(await call.initial_metadata(), aio.Metadata()) self.assertEqual(await call.trailing_metadata(), aio.Metadata()) async def test_call_ok_awaited(self): - class Interceptor(aio.UnaryUnaryClientInterceptor): - - async def intercept_unary_unary(self, continuation, - client_call_details, request): + async def intercept_unary_unary( + self, continuation, client_call_details, request + ): call = await continuation(client_call_details, request) await call return call - 
async with aio.insecure_channel(self._server_target, - interceptors=[Interceptor() - ]) as channel: - + async with aio.insecure_channel( + self._server_target, interceptors=[Interceptor()] + ) as channel: multicallable = channel.unary_unary( - '/grpc.testing.TestService/UnaryCall', + "/grpc.testing.TestService/UnaryCall", request_serializer=messages_pb2.SimpleRequest.SerializeToString, - response_deserializer=messages_pb2.SimpleResponse.FromString) + response_deserializer=messages_pb2.SimpleResponse.FromString, + ) call = multicallable(messages_pb2.SimpleRequest()) response = await call @@ -377,97 +387,98 @@ async def intercept_unary_unary(self, continuation, self.assertFalse(call.cancelled()) self.assertEqual(type(response), messages_pb2.SimpleResponse) self.assertEqual(await call.code(), grpc.StatusCode.OK) - self.assertEqual(await call.details(), '') + self.assertEqual(await call.details(), "") self.assertEqual(await call.initial_metadata(), aio.Metadata()) self.assertEqual(await call.trailing_metadata(), aio.Metadata()) async def test_call_rpc_error(self): - class Interceptor(aio.UnaryUnaryClientInterceptor): - - async def intercept_unary_unary(self, continuation, - client_call_details, request): + async def intercept_unary_unary( + self, continuation, client_call_details, request + ): call = await continuation(client_call_details, request) return call - async with aio.insecure_channel(self._server_target, - interceptors=[Interceptor() - ]) as channel: - + async with aio.insecure_channel( + self._server_target, interceptors=[Interceptor()] + ) as channel: multicallable = channel.unary_unary( - '/grpc.testing.TestService/UnaryCallWithSleep', + "/grpc.testing.TestService/UnaryCallWithSleep", request_serializer=messages_pb2.SimpleRequest.SerializeToString, - response_deserializer=messages_pb2.SimpleResponse.FromString) + response_deserializer=messages_pb2.SimpleResponse.FromString, + ) call = multicallable( messages_pb2.SimpleRequest(), - timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2) + timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2, + ) with self.assertRaises(aio.AioRpcError) as exception_context: await call self.assertTrue(call.done()) self.assertFalse(call.cancelled()) - self.assertEqual(await call.code(), - grpc.StatusCode.DEADLINE_EXCEEDED) - self.assertEqual(await call.details(), 'Deadline Exceeded') + self.assertEqual( + await call.code(), grpc.StatusCode.DEADLINE_EXCEEDED + ) + self.assertEqual(await call.details(), "Deadline Exceeded") self.assertEqual(await call.initial_metadata(), aio.Metadata()) self.assertEqual(await call.trailing_metadata(), aio.Metadata()) async def test_call_rpc_error_awaited(self): - class Interceptor(aio.UnaryUnaryClientInterceptor): - - async def intercept_unary_unary(self, continuation, - client_call_details, request): + async def intercept_unary_unary( + self, continuation, client_call_details, request + ): call = await continuation(client_call_details, request) await call return call - async with aio.insecure_channel(self._server_target, - interceptors=[Interceptor() - ]) as channel: - + async with aio.insecure_channel( + self._server_target, interceptors=[Interceptor()] + ) as channel: multicallable = channel.unary_unary( - '/grpc.testing.TestService/UnaryCallWithSleep', + "/grpc.testing.TestService/UnaryCallWithSleep", request_serializer=messages_pb2.SimpleRequest.SerializeToString, - response_deserializer=messages_pb2.SimpleResponse.FromString) + response_deserializer=messages_pb2.SimpleResponse.FromString, + ) call = multicallable( 
messages_pb2.SimpleRequest(), - timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2) + timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2, + ) with self.assertRaises(aio.AioRpcError) as exception_context: await call self.assertTrue(call.done()) self.assertFalse(call.cancelled()) - self.assertEqual(await call.code(), - grpc.StatusCode.DEADLINE_EXCEEDED) - self.assertEqual(await call.details(), 'Deadline Exceeded') + self.assertEqual( + await call.code(), grpc.StatusCode.DEADLINE_EXCEEDED + ) + self.assertEqual(await call.details(), "Deadline Exceeded") self.assertEqual(await call.initial_metadata(), aio.Metadata()) self.assertEqual(await call.trailing_metadata(), aio.Metadata()) async def test_cancel_before_rpc(self): - interceptor_reached = asyncio.Event() wait_for_ever = self.loop.create_future() class Interceptor(aio.UnaryUnaryClientInterceptor): - - async def intercept_unary_unary(self, continuation, - client_call_details, request): + async def intercept_unary_unary( + self, continuation, client_call_details, request + ): interceptor_reached.set() await wait_for_ever - async with aio.insecure_channel(self._server_target, - interceptors=[Interceptor() - ]) as channel: - + async with aio.insecure_channel( + self._server_target, interceptors=[Interceptor()] + ) as channel: multicallable = channel.unary_unary( - '/grpc.testing.TestService/UnaryCall', + "/grpc.testing.TestService/UnaryCall", request_serializer=messages_pb2.SimpleRequest.SerializeToString, - response_deserializer=messages_pb2.SimpleResponse.FromString) + response_deserializer=messages_pb2.SimpleResponse.FromString, + ) call = multicallable(messages_pb2.SimpleRequest()) self.assertFalse(call.cancelled()) @@ -482,33 +493,33 @@ async def intercept_unary_unary(self, continuation, self.assertTrue(call.cancelled()) self.assertTrue(call.done()) self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) - self.assertEqual(await call.details(), - _LOCAL_CANCEL_DETAILS_EXPECTATION) + self.assertEqual( + await call.details(), _LOCAL_CANCEL_DETAILS_EXPECTATION + ) self.assertEqual(await call.initial_metadata(), None) self.assertEqual(await call.trailing_metadata(), None) async def test_cancel_after_rpc(self): - interceptor_reached = asyncio.Event() wait_for_ever = self.loop.create_future() class Interceptor(aio.UnaryUnaryClientInterceptor): - - async def intercept_unary_unary(self, continuation, - client_call_details, request): + async def intercept_unary_unary( + self, continuation, client_call_details, request + ): call = await continuation(client_call_details, request) await call interceptor_reached.set() await wait_for_ever - async with aio.insecure_channel(self._server_target, - interceptors=[Interceptor() - ]) as channel: - + async with aio.insecure_channel( + self._server_target, interceptors=[Interceptor()] + ) as channel: multicallable = channel.unary_unary( - '/grpc.testing.TestService/UnaryCall', + "/grpc.testing.TestService/UnaryCall", request_serializer=messages_pb2.SimpleRequest.SerializeToString, - response_deserializer=messages_pb2.SimpleResponse.FromString) + response_deserializer=messages_pb2.SimpleResponse.FromString, + ) call = multicallable(messages_pb2.SimpleRequest()) self.assertFalse(call.cancelled()) @@ -523,30 +534,30 @@ async def intercept_unary_unary(self, continuation, self.assertTrue(call.cancelled()) self.assertTrue(call.done()) self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) - self.assertEqual(await call.details(), - _LOCAL_CANCEL_DETAILS_EXPECTATION) + self.assertEqual( + await 
call.details(), _LOCAL_CANCEL_DETAILS_EXPECTATION + ) self.assertEqual(await call.initial_metadata(), None) self.assertEqual(await call.trailing_metadata(), None) async def test_cancel_inside_interceptor_after_rpc_awaiting(self): - class Interceptor(aio.UnaryUnaryClientInterceptor): - - async def intercept_unary_unary(self, continuation, - client_call_details, request): + async def intercept_unary_unary( + self, continuation, client_call_details, request + ): call = await continuation(client_call_details, request) call.cancel() await call return call - async with aio.insecure_channel(self._server_target, - interceptors=[Interceptor() - ]) as channel: - + async with aio.insecure_channel( + self._server_target, interceptors=[Interceptor()] + ) as channel: multicallable = channel.unary_unary( - '/grpc.testing.TestService/UnaryCall', + "/grpc.testing.TestService/UnaryCall", request_serializer=messages_pb2.SimpleRequest.SerializeToString, - response_deserializer=messages_pb2.SimpleResponse.FromString) + response_deserializer=messages_pb2.SimpleResponse.FromString, + ) call = multicallable(messages_pb2.SimpleRequest()) with self.assertRaises(asyncio.CancelledError): @@ -555,29 +566,29 @@ async def intercept_unary_unary(self, continuation, self.assertTrue(call.cancelled()) self.assertTrue(call.done()) self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) - self.assertEqual(await call.details(), - _LOCAL_CANCEL_DETAILS_EXPECTATION) + self.assertEqual( + await call.details(), _LOCAL_CANCEL_DETAILS_EXPECTATION + ) self.assertEqual(await call.initial_metadata(), None) self.assertEqual(await call.trailing_metadata(), None) async def test_cancel_inside_interceptor_after_rpc_not_awaiting(self): - class Interceptor(aio.UnaryUnaryClientInterceptor): - - async def intercept_unary_unary(self, continuation, - client_call_details, request): + async def intercept_unary_unary( + self, continuation, client_call_details, request + ): call = await continuation(client_call_details, request) call.cancel() return call - async with aio.insecure_channel(self._server_target, - interceptors=[Interceptor() - ]) as channel: - + async with aio.insecure_channel( + self._server_target, interceptors=[Interceptor()] + ) as channel: multicallable = channel.unary_unary( - '/grpc.testing.TestService/UnaryCall', + "/grpc.testing.TestService/UnaryCall", request_serializer=messages_pb2.SimpleRequest.SerializeToString, - response_deserializer=messages_pb2.SimpleResponse.FromString) + response_deserializer=messages_pb2.SimpleResponse.FromString, + ) call = multicallable(messages_pb2.SimpleRequest()) with self.assertRaises(asyncio.CancelledError): @@ -586,21 +597,24 @@ async def intercept_unary_unary(self, continuation, self.assertTrue(call.cancelled()) self.assertTrue(call.done()) self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) - self.assertEqual(await call.details(), - _LOCAL_CANCEL_DETAILS_EXPECTATION) + self.assertEqual( + await call.details(), _LOCAL_CANCEL_DETAILS_EXPECTATION + ) self.assertEqual(await call.initial_metadata(), aio.Metadata()) self.assertEqual( - await call.trailing_metadata(), aio.Metadata(), - "When the raw response is None, empty metadata is returned") + await call.trailing_metadata(), + aio.Metadata(), + "When the raw response is None, empty metadata is returned", + ) async def test_initial_metadata_modification(self): - class Interceptor(aio.UnaryUnaryClientInterceptor): - - async def intercept_unary_unary(self, continuation, - client_call_details, request): - new_metadata = 
aio.Metadata(*client_call_details.metadata, - *_INITIAL_METADATA_TO_INJECT) + async def intercept_unary_unary( + self, continuation, client_call_details, request + ): + new_metadata = aio.Metadata( + *client_call_details.metadata, *_INITIAL_METADATA_TO_INJECT + ) new_details = aio.ClientCallDetails( method=client_call_details.method, timeout=client_call_details.timeout, @@ -610,9 +624,9 @@ async def intercept_unary_unary(self, continuation, ) return await continuation(new_details, request) - async with aio.insecure_channel(self._server_target, - interceptors=[Interceptor() - ]) as channel: + async with aio.insecure_channel( + self._server_target, interceptors=[Interceptor()] + ) as channel: stub = test_pb2_grpc.TestServiceStub(channel) call = stub.UnaryCall(messages_pb2.SimpleRequest()) @@ -621,17 +635,21 @@ async def intercept_unary_unary(self, continuation, _common.seen_metadatum( expected_key=_INITIAL_METADATA_KEY, expected_value=_INITIAL_METADATA_TO_INJECT[ - _INITIAL_METADATA_KEY], + _INITIAL_METADATA_KEY + ], actual=await call.initial_metadata(), - )) + ) + ) # Expected to see the echoed trailing metadata self.assertTrue( _common.seen_metadatum( expected_key=_TRAILING_METADATA_KEY, expected_value=_INITIAL_METADATA_TO_INJECT[ - _TRAILING_METADATA_KEY], + _TRAILING_METADATA_KEY + ], actual=await call.trailing_metadata(), - )) + ) + ) self.assertEqual(await call.code(), grpc.StatusCode.OK) async def test_add_done_callback_before_finishes(self): @@ -642,22 +660,21 @@ def callback(call): called.set() class Interceptor(aio.UnaryUnaryClientInterceptor): - - async def intercept_unary_unary(self, continuation, - client_call_details, request): - + async def intercept_unary_unary( + self, continuation, client_call_details, request + ): await interceptor_can_continue.wait() call = await continuation(client_call_details, request) return call - async with aio.insecure_channel(self._server_target, - interceptors=[Interceptor() - ]) as channel: - + async with aio.insecure_channel( + self._server_target, interceptors=[Interceptor()] + ) as channel: multicallable = channel.unary_unary( - '/grpc.testing.TestService/UnaryCall', + "/grpc.testing.TestService/UnaryCall", request_serializer=messages_pb2.SimpleRequest.SerializeToString, - response_deserializer=messages_pb2.SimpleResponse.FromString) + response_deserializer=messages_pb2.SimpleResponse.FromString, + ) call = multicallable(messages_pb2.SimpleRequest()) call.add_done_callback(callback) interceptor_can_continue.set() @@ -665,8 +682,8 @@ async def intercept_unary_unary(self, continuation, try: await asyncio.wait_for( - called.wait(), - timeout=_TIMEOUT_CHECK_IF_CALLBACK_WAS_CALLED) + called.wait(), timeout=_TIMEOUT_CHECK_IF_CALLBACK_WAS_CALLED + ) except: self.fail("Callback was not called") @@ -677,21 +694,20 @@ def callback(call): called.set() class Interceptor(aio.UnaryUnaryClientInterceptor): - - async def intercept_unary_unary(self, continuation, - client_call_details, request): - + async def intercept_unary_unary( + self, continuation, client_call_details, request + ): call = await continuation(client_call_details, request) return call - async with aio.insecure_channel(self._server_target, - interceptors=[Interceptor() - ]) as channel: - + async with aio.insecure_channel( + self._server_target, interceptors=[Interceptor()] + ) as channel: multicallable = channel.unary_unary( - '/grpc.testing.TestService/UnaryCall', + "/grpc.testing.TestService/UnaryCall", request_serializer=messages_pb2.SimpleRequest.SerializeToString, - 
response_deserializer=messages_pb2.SimpleResponse.FromString) + response_deserializer=messages_pb2.SimpleResponse.FromString, + ) call = multicallable(messages_pb2.SimpleRequest()) await call @@ -700,8 +716,8 @@ async def intercept_unary_unary(self, continuation, try: await asyncio.wait_for( - called.wait(), - timeout=_TIMEOUT_CHECK_IF_CALLBACK_WAS_CALLED) + called.wait(), timeout=_TIMEOUT_CHECK_IF_CALLBACK_WAS_CALLED + ) except: self.fail("Callback was not called") @@ -712,21 +728,20 @@ def callback(call): called.set() class Interceptor(aio.UnaryUnaryClientInterceptor): - - async def intercept_unary_unary(self, continuation, - client_call_details, request): - + async def intercept_unary_unary( + self, continuation, client_call_details, request + ): call = await continuation(client_call_details, request) return call - async with aio.insecure_channel(self._server_target, - interceptors=[Interceptor() - ]) as channel: - + async with aio.insecure_channel( + self._server_target, interceptors=[Interceptor()] + ) as channel: multicallable = channel.unary_unary( - '/grpc.testing.TestService/UnaryCall', + "/grpc.testing.TestService/UnaryCall", request_serializer=messages_pb2.SimpleRequest.SerializeToString, - response_deserializer=messages_pb2.SimpleResponse.FromString) + response_deserializer=messages_pb2.SimpleResponse.FromString, + ) call = multicallable(messages_pb2.SimpleRequest()) call.add_done_callback(callback) @@ -735,12 +750,12 @@ async def intercept_unary_unary(self, continuation, try: await asyncio.wait_for( - called.wait(), - timeout=_TIMEOUT_CHECK_IF_CALLBACK_WAS_CALLED) + called.wait(), timeout=_TIMEOUT_CHECK_IF_CALLBACK_WAS_CALLED + ) except: self.fail("Callback was not called") -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py b/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py index 8d481a9a3b0a6..62063beaaf0c7 100644 --- a/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py @@ -26,12 +26,11 @@ from tests_aio.unit._test_base import AioTestBase from tests_aio.unit._test_server import start_test_server -_UNARY_CALL_METHOD_WITH_SLEEP = '/grpc.testing.TestService/UnaryCallWithSleep' +_UNARY_CALL_METHOD_WITH_SLEEP = "/grpc.testing.TestService/UnaryCallWithSleep" _LONG_TIMEOUT_THAT_SHOULD_NOT_EXPIRE = 60 class TestCloseChannel(AioTestBase): - async def setUp(self): self._server_target, self._server = await start_test_server() @@ -134,6 +133,6 @@ async def test_channel_isolation(self): self.assertTrue(call2.cancelled()) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/compatibility_test.py b/src/python/grpcio_tests/tests_aio/unit/compatibility_test.py index 4749f39d035bb..255ec59319d85 100644 --- a/src/python/grpcio_tests/tests_aio/unit/compatibility_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/compatibility_test.py @@ -36,40 +36,44 @@ _NUM_STREAM_RESPONSES = 5 _REQUEST_PAYLOAD_SIZE = 7 _RESPONSE_PAYLOAD_SIZE = 42 -_REQUEST = b'\x03\x07' +_REQUEST = b"\x03\x07" def _unique_options() -> Sequence[Tuple[str, float]]: - return (('iv', random.random()),) + return (("iv", random.random()),) @unittest.skipIf( - os.environ.get('GRPC_ASYNCIO_ENGINE', '').lower() == 'custom_io_manager', - 'Compatible mode needs POLLER completion 
queue.') + os.environ.get("GRPC_ASYNCIO_ENGINE", "").lower() == "custom_io_manager", + "Compatible mode needs POLLER completion queue.", +) class TestCompatibility(AioTestBase): - async def setUp(self): self._async_server = aio.server( - options=(('grpc.so_reuseport', 0),), - migration_thread_pool=ThreadPoolExecutor()) + options=(("grpc.so_reuseport", 0),), + migration_thread_pool=ThreadPoolExecutor(), + ) - test_pb2_grpc.add_TestServiceServicer_to_server(TestServiceServicer(), - self._async_server) + test_pb2_grpc.add_TestServiceServicer_to_server( + TestServiceServicer(), self._async_server + ) self._adhoc_handlers = _common.AdhocGenericHandler() self._async_server.add_generic_rpc_handlers((self._adhoc_handlers,)) - port = self._async_server.add_insecure_port('[::]:0') - address = 'localhost:%d' % port + port = self._async_server.add_insecure_port("[::]:0") + address = "localhost:%d" % port await self._async_server.start() # Create async stub - self._async_channel = aio.insecure_channel(address, - options=_unique_options()) + self._async_channel = aio.insecure_channel( + address, options=_unique_options() + ) self._async_stub = test_pb2_grpc.TestServiceStub(self._async_channel) # Create sync stub - self._sync_channel = grpc.insecure_channel(address, - options=_unique_options()) + self._sync_channel = grpc.insecure_channel( + address, options=_unique_options() + ) self._sync_stub = test_pb2_grpc.TestServiceStub(self._sync_channel) async def tearDown(self): @@ -91,14 +95,16 @@ def thread_work(): async def test_unary_unary(self): # Calling async API in this thread - await self._async_stub.UnaryCall(messages_pb2.SimpleRequest(), - timeout=test_constants.LONG_TIMEOUT) + await self._async_stub.UnaryCall( + messages_pb2.SimpleRequest(), timeout=test_constants.LONG_TIMEOUT + ) # Calling sync API in a different thread def sync_work() -> None: response, call = self._sync_stub.UnaryCall.with_call( messages_pb2.SimpleRequest(), - timeout=test_constants.LONG_TIMEOUT) + timeout=test_constants.LONG_TIMEOUT, + ) self.assertIsInstance(response, messages_pb2.SimpleResponse) self.assertEqual(grpc.StatusCode.OK, call.code()) @@ -108,7 +114,8 @@ async def test_unary_stream(self): request = messages_pb2.StreamingOutputCallRequest() for _ in range(_NUM_STREAM_RESPONSES): request.response_parameters.append( - messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE) + ) # Calling async API in this thread call = self._async_stub.StreamingOutputCall(request) @@ -127,7 +134,7 @@ def sync_work() -> None: await self._run_in_another_thread(sync_work) async def test_stream_unary(self): - payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + payload = messages_pb2.Payload(body=b"\0" * _REQUEST_PAYLOAD_SIZE) request = messages_pb2.StreamingInputCallRequest(payload=payload) # Calling async API in this thread @@ -136,22 +143,28 @@ async def gen(): yield request response = await self._async_stub.StreamingInputCall(gen()) - self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, - response.aggregated_payload_size) + self.assertEqual( + _NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, + response.aggregated_payload_size, + ) # Calling sync API in a different thread def sync_work() -> None: response = self._sync_stub.StreamingInputCall( - iter([request] * _NUM_STREAM_RESPONSES)) - self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, - response.aggregated_payload_size) + iter([request] * _NUM_STREAM_RESPONSES) + ) + 
self.assertEqual( + _NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, + response.aggregated_payload_size, + ) await self._run_in_another_thread(sync_work) async def test_stream_stream(self): request = messages_pb2.StreamingOutputCallRequest() request.response_parameters.append( - messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE) + ) # Calling async API in this thread call = self._async_stub.FullDuplexCall() @@ -174,24 +187,23 @@ def sync_work() -> None: await self._run_in_another_thread(sync_work) async def test_server(self): - class GenericHandlers(grpc.GenericRpcHandler): - def service(self, handler_call_details): return grpc.unary_unary_rpc_method_handler(lambda x, _: x) # It's fine to instantiate server object in the event loop thread. # The server will spawn its own serving thread. - server = grpc.server(ThreadPoolExecutor(), - handlers=(GenericHandlers(),)) - port = server.add_insecure_port('localhost:0') + server = grpc.server( + ThreadPoolExecutor(), handlers=(GenericHandlers(),) + ) + port = server.add_insecure_port("localhost:0") server.start() def sync_work() -> None: for _ in range(100): - with grpc.insecure_channel('localhost:%d' % port) as channel: - response = channel.unary_unary('/test/test')(b'\x07\x08') - self.assertEqual(response, b'\x07\x08') + with grpc.insecure_channel("localhost:%d" % port) as channel: + response = channel.unary_unary("/test/test")(b"\x07\x08") + self.assertEqual(response, b"\x07\x08") await self._run_in_another_thread(sync_work) @@ -200,11 +212,11 @@ async def test_many_loop(self): # Run another loop in another thread def sync_work(): - async def async_work(): # Create async stub - async_channel = aio.insecure_channel(address, - options=_unique_options()) + async_channel = aio.insecure_channel( + address, options=_unique_options() + ) async_stub = test_pb2_grpc.TestServiceStub(async_channel) call = async_stub.UnaryCall(messages_pb2.SimpleRequest()) @@ -219,18 +231,18 @@ async def async_work(): await server.stop(None) async def test_sync_unary_unary_success(self): - @grpc.unary_unary_rpc_method_handler def echo_unary_unary(request: bytes, unused_context): return request self._adhoc_handlers.set_adhoc_handler(echo_unary_unary) - response = await self._async_channel.unary_unary(_common.ADHOC_METHOD - )(_REQUEST) + response = await self._async_channel.unary_unary(_common.ADHOC_METHOD)( + _REQUEST + ) self.assertEqual(_REQUEST, response) async def test_sync_unary_unary_metadata(self): - metadata = (('unique', 'key-42'),) + metadata = (("unique", "key-42"),) @grpc.unary_unary_rpc_method_handler def metadata_unary_unary(request: bytes, context: grpc.ServicerContext): @@ -240,37 +252,40 @@ def metadata_unary_unary(request: bytes, context: grpc.ServicerContext): self._adhoc_handlers.set_adhoc_handler(metadata_unary_unary) call = self._async_channel.unary_unary(_common.ADHOC_METHOD)(_REQUEST) self.assertTrue( - _common.seen_metadata(aio.Metadata(*metadata), await - call.initial_metadata())) + _common.seen_metadata( + aio.Metadata(*metadata), await call.initial_metadata() + ) + ) async def test_sync_unary_unary_abort(self): - @grpc.unary_unary_rpc_method_handler def abort_unary_unary(request: bytes, context: grpc.ServicerContext): - context.abort(grpc.StatusCode.INTERNAL, 'Test') + context.abort(grpc.StatusCode.INTERNAL, "Test") self._adhoc_handlers.set_adhoc_handler(abort_unary_unary) with self.assertRaises(aio.AioRpcError) as exception_context: - await 
self._async_channel.unary_unary(_common.ADHOC_METHOD - )(_REQUEST) - self.assertEqual(grpc.StatusCode.INTERNAL, - exception_context.exception.code()) + await self._async_channel.unary_unary(_common.ADHOC_METHOD)( + _REQUEST + ) + self.assertEqual( + grpc.StatusCode.INTERNAL, exception_context.exception.code() + ) async def test_sync_unary_unary_set_code(self): - @grpc.unary_unary_rpc_method_handler def set_code_unary_unary(request: bytes, context: grpc.ServicerContext): context.set_code(grpc.StatusCode.INTERNAL) self._adhoc_handlers.set_adhoc_handler(set_code_unary_unary) with self.assertRaises(aio.AioRpcError) as exception_context: - await self._async_channel.unary_unary(_common.ADHOC_METHOD - )(_REQUEST) - self.assertEqual(grpc.StatusCode.INTERNAL, - exception_context.exception.code()) + await self._async_channel.unary_unary(_common.ADHOC_METHOD)( + _REQUEST + ) + self.assertEqual( + grpc.StatusCode.INTERNAL, exception_context.exception.code() + ) async def test_sync_unary_stream_success(self): - @grpc.unary_stream_rpc_method_handler def echo_unary_stream(request: bytes, unused_context): for _ in range(_NUM_STREAM_RESPONSES): @@ -282,86 +297,92 @@ def echo_unary_stream(request: bytes, unused_context): self.assertEqual(_REQUEST, response) async def test_sync_unary_stream_error(self): - @grpc.unary_stream_rpc_method_handler def error_unary_stream(request: bytes, unused_context): for _ in range(_NUM_STREAM_RESPONSES): yield request - raise RuntimeError('Test') + raise RuntimeError("Test") self._adhoc_handlers.set_adhoc_handler(error_unary_stream) call = self._async_channel.unary_stream(_common.ADHOC_METHOD)(_REQUEST) with self.assertRaises(aio.AioRpcError) as exception_context: async for response in call: self.assertEqual(_REQUEST, response) - self.assertEqual(grpc.StatusCode.UNKNOWN, - exception_context.exception.code()) + self.assertEqual( + grpc.StatusCode.UNKNOWN, exception_context.exception.code() + ) async def test_sync_stream_unary_success(self): - @grpc.stream_unary_rpc_method_handler - def echo_stream_unary(request_iterator: Iterable[bytes], - unused_context): + def echo_stream_unary( + request_iterator: Iterable[bytes], unused_context + ): self.assertEqual(len(list(request_iterator)), _NUM_STREAM_RESPONSES) return _REQUEST self._adhoc_handlers.set_adhoc_handler(echo_stream_unary) request_iterator = iter([_REQUEST] * _NUM_STREAM_RESPONSES) - response = await self._async_channel.stream_unary(_common.ADHOC_METHOD - )(request_iterator) + response = await self._async_channel.stream_unary(_common.ADHOC_METHOD)( + request_iterator + ) self.assertEqual(_REQUEST, response) async def test_sync_stream_unary_error(self): - @grpc.stream_unary_rpc_method_handler - def echo_stream_unary(request_iterator: Iterable[bytes], - unused_context): + def echo_stream_unary( + request_iterator: Iterable[bytes], unused_context + ): self.assertEqual(len(list(request_iterator)), _NUM_STREAM_RESPONSES) - raise RuntimeError('Test') + raise RuntimeError("Test") self._adhoc_handlers.set_adhoc_handler(echo_stream_unary) request_iterator = iter([_REQUEST] * _NUM_STREAM_RESPONSES) with self.assertRaises(aio.AioRpcError) as exception_context: response = await self._async_channel.stream_unary( - _common.ADHOC_METHOD)(request_iterator) - self.assertEqual(grpc.StatusCode.UNKNOWN, - exception_context.exception.code()) + _common.ADHOC_METHOD + )(request_iterator) + self.assertEqual( + grpc.StatusCode.UNKNOWN, exception_context.exception.code() + ) async def test_sync_stream_stream_success(self): - 
@grpc.stream_stream_rpc_method_handler - def echo_stream_stream(request_iterator: Iterable[bytes], - unused_context): + def echo_stream_stream( + request_iterator: Iterable[bytes], unused_context + ): for request in request_iterator: yield request self._adhoc_handlers.set_adhoc_handler(echo_stream_stream) request_iterator = iter([_REQUEST] * _NUM_STREAM_RESPONSES) - call = self._async_channel.stream_stream( - _common.ADHOC_METHOD)(request_iterator) + call = self._async_channel.stream_stream(_common.ADHOC_METHOD)( + request_iterator + ) async for response in call: self.assertEqual(_REQUEST, response) async def test_sync_stream_stream_error(self): - @grpc.stream_stream_rpc_method_handler - def echo_stream_stream(request_iterator: Iterable[bytes], - unused_context): + def echo_stream_stream( + request_iterator: Iterable[bytes], unused_context + ): for request in request_iterator: yield request - raise RuntimeError('test') + raise RuntimeError("test") self._adhoc_handlers.set_adhoc_handler(echo_stream_stream) request_iterator = iter([_REQUEST] * _NUM_STREAM_RESPONSES) - call = self._async_channel.stream_stream( - _common.ADHOC_METHOD)(request_iterator) + call = self._async_channel.stream_stream(_common.ADHOC_METHOD)( + request_iterator + ) with self.assertRaises(aio.AioRpcError) as exception_context: async for response in call: self.assertEqual(_REQUEST, response) - self.assertEqual(grpc.StatusCode.UNKNOWN, - exception_context.exception.code()) + self.assertEqual( + grpc.StatusCode.UNKNOWN, exception_context.exception.code() + ) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/compression_test.py b/src/python/grpcio_tests/tests_aio/unit/compression_test.py index eb28a93f58c37..74d3a8a79917a 100644 --- a/src/python/grpcio_tests/tests_aio/unit/compression_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/compression_test.py @@ -25,19 +25,23 @@ from tests_aio.unit import _common from tests_aio.unit._test_base import AioTestBase -_GZIP_CHANNEL_ARGUMENT = ('grpc.default_compression_algorithm', 2) -_GZIP_DISABLED_CHANNEL_ARGUMENT = ('grpc.compression_enabled_algorithms_bitset', - 3) +_GZIP_CHANNEL_ARGUMENT = ("grpc.default_compression_algorithm", 2) +_GZIP_DISABLED_CHANNEL_ARGUMENT = ( + "grpc.compression_enabled_algorithms_bitset", + 3, +) _DEFLATE_DISABLED_CHANNEL_ARGUMENT = ( - 'grpc.compression_enabled_algorithms_bitset', 5) + "grpc.compression_enabled_algorithms_bitset", + 5, +) -_TEST_UNARY_UNARY = '/test/TestUnaryUnary' -_TEST_SET_COMPRESSION = '/test/TestSetCompression' -_TEST_DISABLE_COMPRESSION_UNARY = '/test/TestDisableCompressionUnary' -_TEST_DISABLE_COMPRESSION_STREAM = '/test/TestDisableCompressionStream' +_TEST_UNARY_UNARY = "/test/TestUnaryUnary" +_TEST_SET_COMPRESSION = "/test/TestSetCompression" +_TEST_DISABLE_COMPRESSION_UNARY = "/test/TestDisableCompressionUnary" +_TEST_DISABLE_COMPRESSION_STREAM = "/test/TestDisableCompressionStream" -_REQUEST = b'\x01' * 100 -_RESPONSE = b'\x02' * 100 +_REQUEST = b"\x01" * 100 +_RESPONSE = b"\x02" * 100 async def _test_unary_unary(unused_request, unused_context): @@ -58,7 +62,8 @@ async def _test_set_compression(unused_request_iterator, context): pass else: raise ValueError( - 'Expecting exceptions if set_compression is not effective') + "Expecting exceptions if set_compression is not effective" + ) async def _test_disable_compression_unary(request, context): @@ -78,33 +83,33 @@ async def 
_test_disable_compression_stream(unused_request_iterator, context): _ROUTING_TABLE = { - _TEST_UNARY_UNARY: - grpc.unary_unary_rpc_method_handler(_test_unary_unary), - _TEST_SET_COMPRESSION: - grpc.stream_stream_rpc_method_handler(_test_set_compression), - _TEST_DISABLE_COMPRESSION_UNARY: - grpc.unary_unary_rpc_method_handler(_test_disable_compression_unary), - _TEST_DISABLE_COMPRESSION_STREAM: - grpc.stream_stream_rpc_method_handler(_test_disable_compression_stream), + _TEST_UNARY_UNARY: grpc.unary_unary_rpc_method_handler(_test_unary_unary), + _TEST_SET_COMPRESSION: grpc.stream_stream_rpc_method_handler( + _test_set_compression + ), + _TEST_DISABLE_COMPRESSION_UNARY: grpc.unary_unary_rpc_method_handler( + _test_disable_compression_unary + ), + _TEST_DISABLE_COMPRESSION_STREAM: grpc.stream_stream_rpc_method_handler( + _test_disable_compression_stream + ), } class _GenericHandler(grpc.GenericRpcHandler): - def service(self, handler_call_details): return _ROUTING_TABLE.get(handler_call_details.method) async def _start_test_server(options=None): server = aio.server(options=options) - port = server.add_insecure_port('[::]:0') + port = server.add_insecure_port("[::]:0") server.add_generic_rpc_handlers((_GenericHandler(),)) await server.start() - return f'localhost:{port}', server + return f"localhost:{port}", server class TestCompression(AioTestBase): - async def setUp(self): server_options = (_GZIP_DISABLED_CHANNEL_ARGUMENT,) self._address, self._server = await _start_test_server(server_options) @@ -117,7 +122,8 @@ async def tearDown(self): async def test_channel_level_compression_baned_compression(self): # GZIP is disabled, this call should fail async with aio.insecure_channel( - self._address, compression=grpc.Compression.Gzip) as channel: + self._address, compression=grpc.Compression.Gzip + ) as channel: multicallable = channel.unary_unary(_TEST_UNARY_UNARY) call = multicallable(_REQUEST) with self.assertRaises(aio.AioRpcError) as exception_context: @@ -128,7 +134,8 @@ async def test_channel_level_compression_baned_compression(self): async def test_channel_level_compression_allowed_compression(self): # Deflate is allowed, this call should succeed async with aio.insecure_channel( - self._address, compression=grpc.Compression.Deflate) as channel: + self._address, compression=grpc.Compression.Deflate + ) as channel: multicallable = channel.unary_unary(_TEST_UNARY_UNARY) call = multicallable(_REQUEST) self.assertEqual(grpc.StatusCode.OK, await call.code()) @@ -160,14 +167,16 @@ async def test_server_call_level_compression(self): async def test_server_disable_compression_unary(self): multicallable = self._channel.unary_unary( - _TEST_DISABLE_COMPRESSION_UNARY) + _TEST_DISABLE_COMPRESSION_UNARY + ) call = multicallable(_REQUEST) self.assertEqual(_RESPONSE, await call) self.assertEqual(grpc.StatusCode.OK, await call.code()) async def test_server_disable_compression_stream(self): multicallable = self._channel.stream_stream( - _TEST_DISABLE_COMPRESSION_STREAM) + _TEST_DISABLE_COMPRESSION_STREAM + ) call = multicallable() await call.write(_REQUEST) await call.done_writing() @@ -178,11 +187,11 @@ async def test_server_disable_compression_stream(self): async def test_server_default_compression_algorithm(self): server = aio.server(compression=grpc.Compression.Deflate) - port = server.add_insecure_port('[::]:0') + port = server.add_insecure_port("[::]:0") server.add_generic_rpc_handlers((_GenericHandler(),)) await server.start() - async with aio.insecure_channel(f'localhost:{port}') as channel: + async 
with aio.insecure_channel(f"localhost:{port}") as channel: multicallable = channel.unary_unary(_TEST_UNARY_UNARY) call = multicallable(_REQUEST) self.assertEqual(_RESPONSE, await call) @@ -191,6 +200,6 @@ async def test_server_default_compression_algorithm(self): await server.stop(None) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/connectivity_test.py b/src/python/grpcio_tests/tests_aio/unit/connectivity_test.py index 8ed2689cb836c..9af401a377e4a 100644 --- a/src/python/grpcio_tests/tests_aio/unit/connectivity_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/connectivity_test.py @@ -31,27 +31,32 @@ class TestConnectivityState(AioTestBase): - async def setUp(self): self._server_address, self._server = await start_test_server() async def tearDown(self): await self._server.stop(None) - @unittest.skipIf('aarch64' in platform.machine(), - 'The transient failure propagation is slower on aarch64') + @unittest.skipIf( + "aarch64" in platform.machine(), + "The transient failure propagation is slower on aarch64", + ) async def test_unavailable_backend(self): async with aio.insecure_channel(UNREACHABLE_TARGET) as channel: - self.assertEqual(grpc.ChannelConnectivity.IDLE, - channel.get_state(False)) - self.assertEqual(grpc.ChannelConnectivity.IDLE, - channel.get_state(True)) + self.assertEqual( + grpc.ChannelConnectivity.IDLE, channel.get_state(False) + ) + self.assertEqual( + grpc.ChannelConnectivity.IDLE, channel.get_state(True) + ) # Should not time out await asyncio.wait_for( _common.block_until_certain_state( - channel, grpc.ChannelConnectivity.TRANSIENT_FAILURE), - test_constants.SHORT_TIMEOUT) + channel, grpc.ChannelConnectivity.TRANSIENT_FAILURE + ), + test_constants.SHORT_TIMEOUT, + ) async def test_normal_backend(self): async with aio.insecure_channel(self._server_address) as channel: @@ -61,26 +66,32 @@ async def test_normal_backend(self): # Should not time out await asyncio.wait_for( _common.block_until_certain_state( - channel, grpc.ChannelConnectivity.READY), - test_constants.SHORT_TIMEOUT) + channel, grpc.ChannelConnectivity.READY + ), + test_constants.SHORT_TIMEOUT, + ) async def test_timeout(self): async with aio.insecure_channel(self._server_address) as channel: - self.assertEqual(grpc.ChannelConnectivity.IDLE, - channel.get_state(False)) + self.assertEqual( + grpc.ChannelConnectivity.IDLE, channel.get_state(False) + ) # If timed out, the function should return None. 
with self.assertRaises(asyncio.TimeoutError): await asyncio.wait_for( _common.block_until_certain_state( - channel, grpc.ChannelConnectivity.READY), - test_constants.SHORT_TIMEOUT) + channel, grpc.ChannelConnectivity.READY + ), + test_constants.SHORT_TIMEOUT, + ) async def test_shutdown(self): channel = aio.insecure_channel(self._server_address) - self.assertEqual(grpc.ChannelConnectivity.IDLE, - channel.get_state(False)) + self.assertEqual( + grpc.ChannelConnectivity.IDLE, channel.get_state(False) + ) # Waiting for changes in a separate coroutine wait_started = asyncio.Event() @@ -94,11 +105,13 @@ async def a_pending_wait(): await channel.close() - self.assertEqual(grpc.ChannelConnectivity.SHUTDOWN, - channel.get_state(True)) + self.assertEqual( + grpc.ChannelConnectivity.SHUTDOWN, channel.get_state(True) + ) - self.assertEqual(grpc.ChannelConnectivity.SHUTDOWN, - channel.get_state(False)) + self.assertEqual( + grpc.ChannelConnectivity.SHUTDOWN, channel.get_state(False) + ) # Make sure there isn't any exception in the task await pending_task @@ -107,9 +120,10 @@ async def a_pending_wait(): # segfault or abort. with self.assertRaises(aio.UsageError): await channel.wait_for_state_change( - grpc.ChannelConnectivity.SHUTDOWN) + grpc.ChannelConnectivity.SHUTDOWN + ) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/context_peer_test.py b/src/python/grpcio_tests/tests_aio/unit/context_peer_test.py index 743d6599ef33c..531f28646be41 100644 --- a/src/python/grpcio_tests/tests_aio/unit/context_peer_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/context_peer_test.py @@ -30,38 +30,38 @@ from tests_aio.unit._test_server import TestServiceServicer from tests_aio.unit._test_server import start_test_server -_REQUEST = b'\x03\x07' -_TEST_METHOD = '/test/UnaryUnary' +_REQUEST = b"\x03\x07" +_TEST_METHOD = "/test/UnaryUnary" class TestContextPeer(AioTestBase): - async def test_peer(self): - @grpc.unary_unary_rpc_method_handler - async def check_peer_unary_unary(request: bytes, - context: aio.ServicerContext): + async def check_peer_unary_unary( + request: bytes, context: aio.ServicerContext + ): self.assertEqual(_REQUEST, request) # The peer address could be ipv4 or ipv6 - self.assertIn('ip', context.peer()) + self.assertIn("ip", context.peer()) return request # Creates a server server = aio.server() handlers = grpc.method_handlers_generic_handler( - 'test', {'UnaryUnary': check_peer_unary_unary}) + "test", {"UnaryUnary": check_peer_unary_unary} + ) server.add_generic_rpc_handlers((handlers,)) - port = server.add_insecure_port('[::]:0') + port = server.add_insecure_port("[::]:0") await server.start() # Creates a channel - async with aio.insecure_channel('localhost:%d' % port) as channel: + async with aio.insecure_channel("localhost:%d" % port) as channel: response = await channel.unary_unary(_TEST_METHOD)(_REQUEST) self.assertEqual(_REQUEST, response) await server.stop(None) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/done_callback_test.py b/src/python/grpcio_tests/tests_aio/unit/done_callback_test.py index 42a070f2e81aa..b6a4a352147fe 100644 --- a/src/python/grpcio_tests/tests_aio/unit/done_callback_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/done_callback_test.py @@ -29,14 +29,13 @@ _NUM_STREAM_RESPONSES = 5 _REQUEST_PAYLOAD_SIZE = 7 
_RESPONSE_PAYLOAD_SIZE = 42 -_REQUEST = b'\x01\x02\x03' -_RESPONSE = b'\x04\x05\x06' -_TEST_METHOD = '/test/Test' -_FAKE_METHOD = '/test/Fake' +_REQUEST = b"\x01\x02\x03" +_RESPONSE = b"\x04\x05\x06" +_TEST_METHOD = "/test/Test" +_FAKE_METHOD = "/test/Fake" class TestClientSideDoneCallback(AioTestBase): - async def setUp(self): address, self._server = await start_test_server() self._channel = aio.insecure_channel(address) @@ -65,7 +64,8 @@ async def test_unary_stream(self): request = messages_pb2.StreamingOutputCallRequest() for _ in range(_NUM_STREAM_RESPONSES): request.response_parameters.append( - messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE) + ) call = self._stub.StreamingOutputCall(request) validation = inject_callbacks(call) @@ -73,8 +73,9 @@ async def test_unary_stream(self): response_cnt = 0 async for response in call: response_cnt += 1 - self.assertIsInstance(response, - messages_pb2.StreamingOutputCallResponse) + self.assertIsInstance( + response, messages_pb2.StreamingOutputCallResponse + ) self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt) @@ -83,7 +84,7 @@ async def test_unary_stream(self): await validation async def test_stream_unary(self): - payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + payload = messages_pb2.Payload(body=b"\0" * _REQUEST_PAYLOAD_SIZE) request = messages_pb2.StreamingInputCallRequest(payload=payload) async def gen(): @@ -95,8 +96,10 @@ async def gen(): response = await call self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse) - self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, - response.aggregated_payload_size) + self.assertEqual( + _NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, + response.aggregated_payload_size, + ) self.assertEqual(grpc.StatusCode.OK, await call.code()) await validation @@ -107,13 +110,15 @@ async def test_stream_stream(self): request = messages_pb2.StreamingOutputCallRequest() request.response_parameters.append( - messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE) + ) for _ in range(_NUM_STREAM_RESPONSES): await call.write(request) response = await call.read() - self.assertIsInstance(response, - messages_pb2.StreamingOutputCallResponse) + self.assertIsInstance( + response, messages_pb2.StreamingOutputCallResponse + ) self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) await call.done_writing() @@ -123,11 +128,10 @@ async def test_stream_stream(self): class TestServerSideDoneCallback(AioTestBase): - async def setUp(self): self._server = aio.server() - port = self._server.add_insecure_port('[::]:0') - self._channel = aio.insecure_channel('localhost:%d' % port) + port = self._server.add_insecure_port("[::]:0") + self._channel = aio.insecure_channel("localhost:%d" % port) async def tearDown(self): await self._channel.close() @@ -136,7 +140,7 @@ async def tearDown(self): async def _register_method_handler(self, method_handler): """Registers method handler and starts the server""" generic_handler = grpc.method_handlers_generic_handler( - 'test', + "test", dict(Test=method_handler), ) self._server.add_generic_rpc_handlers((generic_handler,)) @@ -151,7 +155,8 @@ async def test_handler(request: bytes, context: aio.ServicerContext): return _RESPONSE await self._register_method_handler( - grpc.unary_unary_rpc_method_handler(test_handler)) + 
grpc.unary_unary_rpc_method_handler(test_handler) + ) response = await self._channel.unary_unary(_TEST_METHOD)(_REQUEST) self.assertEqual(_RESPONSE, response) @@ -168,7 +173,8 @@ async def test_handler(request: bytes, context: aio.ServicerContext): yield _RESPONSE await self._register_method_handler( - grpc.unary_stream_rpc_method_handler(test_handler)) + grpc.unary_stream_rpc_method_handler(test_handler) + ) call = self._channel.unary_stream(_TEST_METHOD)(_REQUEST) async for response in call: self.assertEqual(_RESPONSE, response) @@ -187,7 +193,8 @@ async def test_handler(request_iterator, context: aio.ServicerContext): return _RESPONSE await self._register_method_handler( - grpc.stream_unary_rpc_method_handler(test_handler)) + grpc.stream_unary_rpc_method_handler(test_handler) + ) call = self._channel.stream_unary(_TEST_METHOD)() for _ in range(_NUM_STREAM_RESPONSES): await call.write(_REQUEST) @@ -208,7 +215,8 @@ async def test_handler(request_iterator, context: aio.ServicerContext): return _RESPONSE await self._register_method_handler( - grpc.stream_stream_rpc_method_handler(test_handler)) + grpc.stream_stream_rpc_method_handler(test_handler) + ) call = self._channel.stream_stream(_TEST_METHOD)() for _ in range(_NUM_STREAM_RESPONSES): await call.write(_REQUEST) @@ -226,10 +234,11 @@ async def test_error_in_handler(self): async def test_handler(request: bytes, context: aio.ServicerContext): self.assertEqual(_REQUEST, request) validation_future.set_result(inject_callbacks(context)) - raise RuntimeError('A test RuntimeError') + raise RuntimeError("A test RuntimeError") await self._register_method_handler( - grpc.unary_unary_rpc_method_handler(test_handler)) + grpc.unary_unary_rpc_method_handler(test_handler) + ) with self.assertRaises(aio.AioRpcError) as exception_context: await self._channel.unary_unary(_TEST_METHOD)(_REQUEST) rpc_error = exception_context.exception @@ -246,14 +255,15 @@ async def test_handler(request: bytes, context: aio.ServicerContext): self.assertEqual(_REQUEST, request) def exception_raiser(unused_context): - raise RuntimeError('A test RuntimeError') + raise RuntimeError("A test RuntimeError") context.add_done_callback(exception_raiser) validation_future.set_result(inject_callbacks(context)) return _RESPONSE await self._register_method_handler( - grpc.unary_unary_rpc_method_handler(test_handler)) + grpc.unary_unary_rpc_method_handler(test_handler) + ) response = await self._channel.unary_unary(_TEST_METHOD)(_REQUEST) self.assertEqual(_RESPONSE, response) @@ -271,6 +281,6 @@ def exception_raiser(unused_context): self.assertEqual(grpc.StatusCode.UNIMPLEMENTED, rpc_error.code()) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/init_test.py b/src/python/grpcio_tests/tests_aio/unit/init_test.py index b7889b9942b29..819d26690b56a 100644 --- a/src/python/grpcio_tests/tests_aio/unit/init_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/init_test.py @@ -16,18 +16,19 @@ class TestInit(unittest.TestCase): - def test_grpc(self): import grpc # pylint: disable=wrong-import-position - channel = grpc.aio.insecure_channel('phony') + + channel = grpc.aio.insecure_channel("phony") self.assertIsInstance(channel, grpc.aio.Channel) def test_grpc_dot_aio(self): import grpc.aio # pylint: disable=wrong-import-position - channel = grpc.aio.insecure_channel('phony') + + channel = grpc.aio.insecure_channel("phony") self.assertIsInstance(channel, grpc.aio.Channel) 
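A minimal side-by-side sketch of the two call layouts visible throughout these hunks; `make_channel` is a hypothetical stand-in (not an API from this patch) so the snippet runs on its own.

def make_channel(target, options=(), interceptors=()):
    """Hypothetical helper standing in for a channel constructor."""
    return (target, tuple(options), tuple(interceptors))

# Layout used by the removed ("-") lines: continuation arguments aligned
# under the opening parenthesis.
channel = make_channel("localhost:50051",
                       options=(("grpc.so_reuseport", 0),),
                       interceptors=[])

# Layout used by the added ("+") lines: one argument per line, a trailing
# comma, and the closing parenthesis dedented onto its own line.
channel = make_channel(
    "localhost:50051",
    options=(("grpc.so_reuseport", 0),),
    interceptors=[],
)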
-if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/metadata_test.py b/src/python/grpcio_tests/tests_aio/unit/metadata_test.py index 4043d19e3179d..47b9cf51f43e8 100644 --- a/src/python/grpcio_tests/tests_aio/unit/metadata_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/metadata_test.py @@ -25,33 +25,34 @@ from tests_aio.unit import _common from tests_aio.unit._test_base import AioTestBase -_TEST_CLIENT_TO_SERVER = '/test/TestClientToServer' -_TEST_SERVER_TO_CLIENT = '/test/TestServerToClient' -_TEST_TRAILING_METADATA = '/test/TestTrailingMetadata' -_TEST_ECHO_INITIAL_METADATA = '/test/TestEchoInitialMetadata' -_TEST_GENERIC_HANDLER = '/test/TestGenericHandler' -_TEST_UNARY_STREAM = '/test/TestUnaryStream' -_TEST_STREAM_UNARY = '/test/TestStreamUnary' -_TEST_STREAM_STREAM = '/test/TestStreamStream' -_TEST_INSPECT_CONTEXT = '/test/TestInspectContext' - -_REQUEST = b'\x00\x00\x00' -_RESPONSE = b'\x01\x01\x01' +_TEST_CLIENT_TO_SERVER = "/test/TestClientToServer" +_TEST_SERVER_TO_CLIENT = "/test/TestServerToClient" +_TEST_TRAILING_METADATA = "/test/TestTrailingMetadata" +_TEST_ECHO_INITIAL_METADATA = "/test/TestEchoInitialMetadata" +_TEST_GENERIC_HANDLER = "/test/TestGenericHandler" +_TEST_UNARY_STREAM = "/test/TestUnaryStream" +_TEST_STREAM_UNARY = "/test/TestStreamUnary" +_TEST_STREAM_STREAM = "/test/TestStreamStream" +_TEST_INSPECT_CONTEXT = "/test/TestInspectContext" + +_REQUEST = b"\x00\x00\x00" +_RESPONSE = b"\x01\x01\x01" _INITIAL_METADATA_FROM_CLIENT_TO_SERVER = aio.Metadata( - ('client-to-server', 'question'), - ('client-to-server-bin', b'\x07\x07\x07'), + ("client-to-server", "question"), + ("client-to-server-bin", b"\x07\x07\x07"), ) _INITIAL_METADATA_FROM_SERVER_TO_CLIENT = aio.Metadata( - ('server-to-client', 'answer'), - ('server-to-client-bin', b'\x06\x06\x06'), + ("server-to-client", "answer"), + ("server-to-client-bin", b"\x06\x06\x06"), ) _TRAILING_METADATA = aio.Metadata( - ('a-trailing-metadata', 'stack-trace'), - ('a-trailing-metadata-bin', b'\x05\x05\x05'), + ("a-trailing-metadata", "stack-trace"), + ("a-trailing-metadata-bin", b"\x05\x05\x05"), ) _INITIAL_METADATA_FOR_GENERIC_HANDLER = aio.Metadata( - ('a-must-have-key', 'secret'),) + ("a-must-have-key", "secret"), +) _INVALID_METADATA_TEST_CASES = ( ( @@ -72,49 +73,55 @@ ), ( TypeError, - (('normal', object()),), + (("normal", object()),), ), ) _NON_OK_CODE = grpc.StatusCode.NOT_FOUND -_DETAILS = 'Test details!' +_DETAILS = "Test details!" 
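A minimal sketch of how `aio.Metadata` values like the constants above behave as plain tuples, mirroring the behaviour exercised by `test_compatibility_with_tuple` further down in this file; it assumes `grpcio` is installed.

from grpc import aio

metadata = aio.Metadata(("key", "42"), ("key-2", "value"))

# Metadata compares equal to the tuple of its (key, value) pairs ...
assert metadata == tuple(metadata)
# ... and concatenating a tuple of pairs yields the combined sequence.
assert metadata + (("third", "3"),) == tuple(metadata) + (("third", "3"),)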
class _TestGenericHandlerForMethods(grpc.GenericRpcHandler): - def __init__(self): self._routing_table = { - _TEST_CLIENT_TO_SERVER: - grpc.unary_unary_rpc_method_handler(self._test_client_to_server - ), - _TEST_SERVER_TO_CLIENT: - grpc.unary_unary_rpc_method_handler(self._test_server_to_client - ), - _TEST_TRAILING_METADATA: - grpc.unary_unary_rpc_method_handler(self._test_trailing_metadata - ), - _TEST_UNARY_STREAM: - grpc.unary_stream_rpc_method_handler(self._test_unary_stream), - _TEST_STREAM_UNARY: - grpc.stream_unary_rpc_method_handler(self._test_stream_unary), - _TEST_STREAM_STREAM: - grpc.stream_stream_rpc_method_handler(self._test_stream_stream), - _TEST_INSPECT_CONTEXT: - grpc.unary_unary_rpc_method_handler(self._test_inspect_context), + _TEST_CLIENT_TO_SERVER: grpc.unary_unary_rpc_method_handler( + self._test_client_to_server + ), + _TEST_SERVER_TO_CLIENT: grpc.unary_unary_rpc_method_handler( + self._test_server_to_client + ), + _TEST_TRAILING_METADATA: grpc.unary_unary_rpc_method_handler( + self._test_trailing_metadata + ), + _TEST_UNARY_STREAM: grpc.unary_stream_rpc_method_handler( + self._test_unary_stream + ), + _TEST_STREAM_UNARY: grpc.stream_unary_rpc_method_handler( + self._test_stream_unary + ), + _TEST_STREAM_STREAM: grpc.stream_stream_rpc_method_handler( + self._test_stream_stream + ), + _TEST_INSPECT_CONTEXT: grpc.unary_unary_rpc_method_handler( + self._test_inspect_context + ), } @staticmethod async def _test_client_to_server(request, context): assert _REQUEST == request - assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER, - context.invocation_metadata()) + assert _common.seen_metadata( + _INITIAL_METADATA_FROM_CLIENT_TO_SERVER, + context.invocation_metadata(), + ) return _RESPONSE @staticmethod async def _test_server_to_client(request, context): assert _REQUEST == request await context.send_initial_metadata( - _INITIAL_METADATA_FROM_SERVER_TO_CLIENT) + _INITIAL_METADATA_FROM_SERVER_TO_CLIENT + ) return _RESPONSE @staticmethod @@ -126,19 +133,25 @@ async def _test_trailing_metadata(request, context): @staticmethod async def _test_unary_stream(request, context): assert _REQUEST == request - assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER, - context.invocation_metadata()) + assert _common.seen_metadata( + _INITIAL_METADATA_FROM_CLIENT_TO_SERVER, + context.invocation_metadata(), + ) await context.send_initial_metadata( - _INITIAL_METADATA_FROM_SERVER_TO_CLIENT) + _INITIAL_METADATA_FROM_SERVER_TO_CLIENT + ) yield _RESPONSE context.set_trailing_metadata(_TRAILING_METADATA) @staticmethod async def _test_stream_unary(request_iterator, context): - assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER, - context.invocation_metadata()) + assert _common.seen_metadata( + _INITIAL_METADATA_FROM_CLIENT_TO_SERVER, + context.invocation_metadata(), + ) await context.send_initial_metadata( - _INITIAL_METADATA_FROM_SERVER_TO_CLIENT) + _INITIAL_METADATA_FROM_SERVER_TO_CLIENT + ) async for request in request_iterator: assert _REQUEST == request @@ -148,10 +161,13 @@ async def _test_stream_unary(request_iterator, context): @staticmethod async def _test_stream_stream(request_iterator, context): - assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER, - context.invocation_metadata()) + assert _common.seen_metadata( + _INITIAL_METADATA_FROM_CLIENT_TO_SERVER, + context.invocation_metadata(), + ) await context.send_initial_metadata( - _INITIAL_METADATA_FROM_SERVER_TO_CLIENT) + _INITIAL_METADATA_FROM_SERVER_TO_CLIENT + ) 
async for request in request_iterator: assert _REQUEST == request @@ -177,31 +193,33 @@ def service(self, handler_call_details): class _TestGenericHandlerItself(grpc.GenericRpcHandler): - @staticmethod async def _method(request, unused_context): assert _REQUEST == request return _RESPONSE def service(self, handler_call_details): - assert _common.seen_metadata(_INITIAL_METADATA_FOR_GENERIC_HANDLER, - handler_call_details.invocation_metadata) + assert _common.seen_metadata( + _INITIAL_METADATA_FOR_GENERIC_HANDLER, + handler_call_details.invocation_metadata, + ) return grpc.unary_unary_rpc_method_handler(self._method) async def _start_test_server(): server = aio.server() - port = server.add_insecure_port('[::]:0') - server.add_generic_rpc_handlers(( - _TestGenericHandlerForMethods(), - _TestGenericHandlerItself(), - )) + port = server.add_insecure_port("[::]:0") + server.add_generic_rpc_handlers( + ( + _TestGenericHandlerForMethods(), + _TestGenericHandlerItself(), + ) + ) await server.start() - return 'localhost:%d' % port, server + return "localhost:%d" % port, server class TestMetadata(AioTestBase): - async def setUp(self): address, self._server = await _start_test_server() self._client = aio.insecure_channel(address) @@ -212,8 +230,9 @@ async def tearDown(self): async def test_from_client_to_server(self): multicallable = self._client.unary_unary(_TEST_CLIENT_TO_SERVER) - call = multicallable(_REQUEST, - metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER) + call = multicallable( + _REQUEST, metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER + ) self.assertEqual(_RESPONSE, await call) self.assertEqual(grpc.StatusCode.OK, await call.code()) @@ -221,8 +240,10 @@ async def test_from_server_to_client(self): multicallable = self._client.unary_unary(_TEST_SERVER_TO_CLIENT) call = multicallable(_REQUEST) - self.assertEqual(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT, await - call.initial_metadata()) + self.assertEqual( + _INITIAL_METADATA_FROM_SERVER_TO_CLIENT, + await call.initial_metadata(), + ) self.assertEqual(_RESPONSE, await call) self.assertEqual(grpc.StatusCode.OK, await call.code()) @@ -236,12 +257,15 @@ async def test_trailing_metadata(self): async def test_from_client_to_server_with_list(self): multicallable = self._client.unary_unary(_TEST_CLIENT_TO_SERVER) call = multicallable( - _REQUEST, metadata=list(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER)) # pytype: disable=wrong-arg-types + _REQUEST, metadata=list(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER) + ) # pytype: disable=wrong-arg-types self.assertEqual(_RESPONSE, await call) self.assertEqual(grpc.StatusCode.OK, await call.code()) - @unittest.skipIf(platform.system() == 'Windows', - 'https://github.com/grpc/grpc/issues/21943') + @unittest.skipIf( + platform.system() == "Windows", + "https://github.com/grpc/grpc/issues/21943", + ) async def test_invalid_metadata(self): multicallable = self._client.unary_unary(_TEST_CLIENT_TO_SERVER) for exception_type, metadata in _INVALID_METADATA_TEST_CASES: @@ -252,22 +276,28 @@ async def test_invalid_metadata(self): async def test_generic_handler(self): multicallable = self._client.unary_unary(_TEST_GENERIC_HANDLER) - call = multicallable(_REQUEST, - metadata=_INITIAL_METADATA_FOR_GENERIC_HANDLER) + call = multicallable( + _REQUEST, metadata=_INITIAL_METADATA_FOR_GENERIC_HANDLER + ) self.assertEqual(_RESPONSE, await call) self.assertEqual(grpc.StatusCode.OK, await call.code()) async def test_unary_stream(self): multicallable = self._client.unary_stream(_TEST_UNARY_STREAM) - call = multicallable(_REQUEST, - 
metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER) + call = multicallable( + _REQUEST, metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER + ) self.assertTrue( - _common.seen_metadata(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT, await - call.initial_metadata())) + _common.seen_metadata( + _INITIAL_METADATA_FROM_SERVER_TO_CLIENT, + await call.initial_metadata(), + ) + ) - self.assertSequenceEqual([_RESPONSE], - [request async for request in call]) + self.assertSequenceEqual( + [_RESPONSE], [request async for request in call] + ) self.assertEqual(_TRAILING_METADATA, await call.trailing_metadata()) self.assertEqual(grpc.StatusCode.OK, await call.code()) @@ -279,8 +309,11 @@ async def test_stream_unary(self): await call.done_writing() self.assertTrue( - _common.seen_metadata(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT, await - call.initial_metadata())) + _common.seen_metadata( + _INITIAL_METADATA_FROM_SERVER_TO_CLIENT, + await call.initial_metadata(), + ) + ) self.assertEqual(_RESPONSE, await call) self.assertEqual(_TRAILING_METADATA, await call.trailing_metadata()) @@ -293,22 +326,27 @@ async def test_stream_stream(self): await call.done_writing() self.assertTrue( - _common.seen_metadata(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT, await - call.initial_metadata())) - self.assertSequenceEqual([_RESPONSE], - [request async for request in call]) + _common.seen_metadata( + _INITIAL_METADATA_FROM_SERVER_TO_CLIENT, + await call.initial_metadata(), + ) + ) + self.assertSequenceEqual( + [_RESPONSE], [request async for request in call] + ) self.assertEqual(_TRAILING_METADATA, await call.trailing_metadata()) self.assertEqual(grpc.StatusCode.OK, await call.code()) async def test_compatibility_with_tuple(self): - metadata_obj = aio.Metadata(('key', '42'), ('key-2', 'value')) + metadata_obj = aio.Metadata(("key", "42"), ("key-2", "value")) self.assertEqual(metadata_obj, tuple(metadata_obj)) self.assertEqual(tuple(metadata_obj), metadata_obj) - expected_sum = tuple(metadata_obj) + (('third', '3'),) - self.assertEqual(expected_sum, metadata_obj + (('third', '3'),)) - self.assertEqual(expected_sum, metadata_obj + aio.Metadata( - ('third', '3'))) + expected_sum = tuple(metadata_obj) + (("third", "3"),) + self.assertEqual(expected_sum, metadata_obj + (("third", "3"),)) + self.assertEqual( + expected_sum, metadata_obj + aio.Metadata(("third", "3")) + ) async def test_inspect_context(self): multicallable = self._client.unary_unary(_TEST_INSPECT_CONTEXT) @@ -320,6 +358,6 @@ async def test_inspect_context(self): self.assertEqual(_NON_OK_CODE, err.code()) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/outside_init_test.py b/src/python/grpcio_tests/tests_aio/unit/outside_init_test.py index 79a7518585f6d..e3f8ba7e98338 100644 --- a/src/python/grpcio_tests/tests_aio/unit/outside_init_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/outside_init_test.py @@ -28,7 +28,6 @@ class TestOutsideInit(unittest.TestCase): - def test_behavior_outside_asyncio(self): # Ensures non-AsyncIO object can be initiated channel_creds = grpc.ssl_channel_credentials() @@ -37,8 +36,8 @@ def test_behavior_outside_asyncio(self): # NOTE(lidiz) This behavior is bound with GAPIC generator, and required # by test frameworks like pytest. In test frameworks, objects shared # across cases need to be created outside of AsyncIO coroutines. 
- aio.insecure_channel('') - aio.secure_channel('', channel_creds) + aio.insecure_channel("") + aio.secure_channel("", channel_creds) aio.server() aio.init_grpc_aio() aio.shutdown_grpc_aio() diff --git a/src/python/grpcio_tests/tests_aio/unit/secure_call_test.py b/src/python/grpcio_tests/tests_aio/unit/secure_call_test.py index a5b03f43ae47b..c14ccced35680 100644 --- a/src/python/grpcio_tests/tests_aio/unit/secure_call_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/secure_call_test.py @@ -25,7 +25,7 @@ from tests_aio.unit._test_base import AioTestBase from tests_aio.unit._test_server import start_test_server -_SERVER_HOST_OVERRIDE = 'foo.test.google.fr' +_SERVER_HOST_OVERRIDE = "foo.test.google.fr" _NUM_STREAM_RESPONSES = 5 _RESPONSE_PAYLOAD_SIZE = 42 @@ -34,20 +34,25 @@ class _SecureCallMixin: """A Mixin to run the call tests over a secure channel.""" async def setUp(self): - server_credentials = grpc.ssl_server_credentials([ - (resources.private_key(), resources.certificate_chain()) - ]) + server_credentials = grpc.ssl_server_credentials( + [(resources.private_key(), resources.certificate_chain())] + ) channel_credentials = grpc.ssl_channel_credentials( - resources.test_root_certificates()) + resources.test_root_certificates() + ) self._server_address, self._server = await start_test_server( - secure=True, server_credentials=server_credentials) - channel_options = (( - 'grpc.ssl_target_name_override', - _SERVER_HOST_OVERRIDE, - ),) - self._channel = aio.secure_channel(self._server_address, - channel_credentials, channel_options) + secure=True, server_credentials=server_credentials + ) + channel_options = ( + ( + "grpc.ssl_target_name_override", + _SERVER_HOST_OVERRIDE, + ), + ) + self._channel = aio.secure_channel( + self._server_address, channel_credentials, channel_options + ) self._stub = test_pb2_grpc.TestServiceStub(self._channel) async def tearDown(self): @@ -69,8 +74,9 @@ async def test_call_with_credentials(self): grpc.access_token_call_credentials("abc"), grpc.access_token_call_credentials("def"), ) - call = self._stub.UnaryCall(messages_pb2.SimpleRequest(), - credentials=call_credentials) + call = self._stub.UnaryCall( + messages_pb2.SimpleRequest(), credentials=call_credentials + ) response = await call self.assertIsInstance(response, messages_pb2.SimpleResponse) @@ -82,18 +88,23 @@ class TestUnaryStreamSecureCall(_SecureCallMixin, AioTestBase): async def test_unary_stream_async_generator_secure(self): request = messages_pb2.StreamingOutputCallRequest() request.response_parameters.extend( - messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,) - for _ in range(_NUM_STREAM_RESPONSES)) + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE, + ) + for _ in range(_NUM_STREAM_RESPONSES) + ) call_credentials = grpc.composite_call_credentials( grpc.access_token_call_credentials("abc"), grpc.access_token_call_credentials("def"), ) - call = self._stub.StreamingOutputCall(request, - credentials=call_credentials) + call = self._stub.StreamingOutputCall( + request, credentials=call_credentials + ) async for response in call: - self.assertIsInstance(response, - messages_pb2.StreamingOutputCallResponse) + self.assertIsInstance( + response, messages_pb2.StreamingOutputCallResponse + ) self.assertEqual(len(response.payload.body), _RESPONSE_PAYLOAD_SIZE) self.assertEqual(await call.code(), grpc.StatusCode.OK) @@ -102,14 +113,14 @@ async def test_unary_stream_async_generator_secure(self): # Prepares the request that stream in a ping-pong manner. 
_STREAM_OUTPUT_REQUEST_ONE_RESPONSE = messages_pb2.StreamingOutputCallRequest() _STREAM_OUTPUT_REQUEST_ONE_RESPONSE.response_parameters.append( - messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE) +) class TestStreamStreamSecureCall(_SecureCallMixin, AioTestBase): _STREAM_ITERATIONS = 2 async def test_async_generator_secure_channel(self): - async def request_generator(): for _ in range(self._STREAM_ITERATIONS): yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE @@ -119,14 +130,15 @@ async def request_generator(): grpc.access_token_call_credentials("def"), ) - call = self._stub.FullDuplexCall(request_generator(), - credentials=call_credentials) + call = self._stub.FullDuplexCall( + request_generator(), credentials=call_credentials + ) async for response in call: self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) self.assertEqual(await call.code(), grpc.StatusCode.OK) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py index 18f5df09de19e..1975403e3aa64 100644 --- a/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/server_interceptor_test.py @@ -34,50 +34,59 @@ class _LoggingInterceptor(aio.ServerInterceptor): - def __init__(self, tag: str, record: list) -> None: self.tag = tag self.record = record async def intercept_service( - self, continuation: Callable[[grpc.HandlerCallDetails], - Awaitable[grpc.RpcMethodHandler]], - handler_call_details: grpc.HandlerCallDetails + self, + continuation: Callable[ + [grpc.HandlerCallDetails], Awaitable[grpc.RpcMethodHandler] + ], + handler_call_details: grpc.HandlerCallDetails, ) -> grpc.RpcMethodHandler: - self.record.append(self.tag + ':intercept_service') + self.record.append(self.tag + ":intercept_service") return await continuation(handler_call_details) class _GenericInterceptor(aio.ServerInterceptor): - def __init__( - self, fn: Callable[[ - Callable[[grpc.HandlerCallDetails], - Awaitable[grpc.RpcMethodHandler]], grpc.HandlerCallDetails - ], Any] + self, + fn: Callable[ + [ + Callable[ + [grpc.HandlerCallDetails], Awaitable[grpc.RpcMethodHandler] + ], + grpc.HandlerCallDetails, + ], + Any, + ], ) -> None: self._fn = fn async def intercept_service( - self, continuation: Callable[[grpc.HandlerCallDetails], - Awaitable[grpc.RpcMethodHandler]], - handler_call_details: grpc.HandlerCallDetails + self, + continuation: Callable[ + [grpc.HandlerCallDetails], Awaitable[grpc.RpcMethodHandler] + ], + handler_call_details: grpc.HandlerCallDetails, ) -> grpc.RpcMethodHandler: return await self._fn(continuation, handler_call_details) def _filter_server_interceptor( - condition: Callable, - interceptor: aio.ServerInterceptor) -> aio.ServerInterceptor: - + condition: Callable, interceptor: aio.ServerInterceptor +) -> aio.ServerInterceptor: async def intercept_service( - continuation: Callable[[grpc.HandlerCallDetails], - Awaitable[grpc.RpcMethodHandler]], - handler_call_details: grpc.HandlerCallDetails + continuation: Callable[ + [grpc.HandlerCallDetails], Awaitable[grpc.RpcMethodHandler] + ], + handler_call_details: grpc.HandlerCallDetails, ) -> grpc.RpcMethodHandler: if condition(handler_call_details): - return await interceptor.intercept_service(continuation, - handler_call_details) + return await 
interceptor.intercept_service( + continuation, handler_call_details + ) return await continuation(handler_call_details) return _GenericInterceptor(intercept_service) @@ -90,30 +99,37 @@ def __init__(self, cache_store=None): self.cache_store = cache_store or {} async def intercept_service( - self, continuation: Callable[[grpc.HandlerCallDetails], - Awaitable[grpc.RpcMethodHandler]], - handler_call_details: grpc.HandlerCallDetails + self, + continuation: Callable[ + [grpc.HandlerCallDetails], Awaitable[grpc.RpcMethodHandler] + ], + handler_call_details: grpc.HandlerCallDetails, ) -> grpc.RpcMethodHandler: # Get the actual handler handler = await continuation(handler_call_details) # Only intercept unary call RPCs - if handler and (handler.request_streaming or # pytype: disable=attribute-error - handler.response_streaming): # pytype: disable=attribute-error + if handler and ( + handler.request_streaming + or handler.response_streaming # pytype: disable=attribute-error + ): # pytype: disable=attribute-error return handler - def wrapper(behavior: Callable[ - [messages_pb2.SimpleRequest, aio.ServicerContext], - messages_pb2.SimpleResponse]): - + def wrapper( + behavior: Callable[ + [messages_pb2.SimpleRequest, aio.ServicerContext], + messages_pb2.SimpleResponse, + ] + ): @functools.wraps(behavior) async def wrapper( - request: messages_pb2.SimpleRequest, - context: aio.ServicerContext + request: messages_pb2.SimpleRequest, + context: aio.ServicerContext, ) -> messages_pb2.SimpleResponse: if request.response_size not in self.cache_store: self.cache_store[request.response_size] = await behavior( - request, context) + request, context + ) return self.cache_store[request.response_size] return wrapper @@ -122,7 +138,7 @@ async def wrapper( async def _create_server_stub_pair( - *interceptors: aio.ServerInterceptor + *interceptors: aio.ServerInterceptor, ) -> Tuple[aio.Server, test_pb2_grpc.TestServiceStub]: """Creates a server-stub pair with given interceptors. @@ -134,130 +150,160 @@ async def _create_server_stub_pair( class TestServerInterceptor(AioTestBase): - async def test_invalid_interceptor(self): - class InvalidInterceptor: """Just an invalid Interceptor""" with self.assertRaises(ValueError): server_target, _ = await start_test_server( - interceptors=(InvalidInterceptor(),)) + interceptors=(InvalidInterceptor(),) + ) async def test_executed_right_order(self): record = [] - server_target, _ = await start_test_server(interceptors=( - _LoggingInterceptor('log1', record), - _LoggingInterceptor('log2', record), - )) + server_target, _ = await start_test_server( + interceptors=( + _LoggingInterceptor("log1", record), + _LoggingInterceptor("log2", record), + ) + ) async with aio.insecure_channel(server_target) as channel: multicallable = channel.unary_unary( - '/grpc.testing.TestService/UnaryCall', + "/grpc.testing.TestService/UnaryCall", request_serializer=messages_pb2.SimpleRequest.SerializeToString, - response_deserializer=messages_pb2.SimpleResponse.FromString) + response_deserializer=messages_pb2.SimpleResponse.FromString, + ) call = multicallable(messages_pb2.SimpleRequest()) response = await call # Check that all interceptors were executed, and were executed # in the right order. 
- self.assertSequenceEqual([ - 'log1:intercept_service', - 'log2:intercept_service', - ], record) + self.assertSequenceEqual( + [ + "log1:intercept_service", + "log2:intercept_service", + ], + record, + ) self.assertIsInstance(response, messages_pb2.SimpleResponse) async def test_response_ok(self): record = [] server_target, _ = await start_test_server( - interceptors=(_LoggingInterceptor('log1', record),)) + interceptors=(_LoggingInterceptor("log1", record),) + ) async with aio.insecure_channel(server_target) as channel: multicallable = channel.unary_unary( - '/grpc.testing.TestService/UnaryCall', + "/grpc.testing.TestService/UnaryCall", request_serializer=messages_pb2.SimpleRequest.SerializeToString, - response_deserializer=messages_pb2.SimpleResponse.FromString) + response_deserializer=messages_pb2.SimpleResponse.FromString, + ) call = multicallable(messages_pb2.SimpleRequest()) response = await call code = await call.code() - self.assertSequenceEqual(['log1:intercept_service'], record) + self.assertSequenceEqual(["log1:intercept_service"], record) self.assertIsInstance(response, messages_pb2.SimpleResponse) self.assertEqual(code, grpc.StatusCode.OK) async def test_apply_different_interceptors_by_metadata(self): record = [] conditional_interceptor = _filter_server_interceptor( - lambda x: ('secret', '42') in x.invocation_metadata, - _LoggingInterceptor('log3', record)) - server_target, _ = await start_test_server(interceptors=( - _LoggingInterceptor('log1', record), - conditional_interceptor, - _LoggingInterceptor('log2', record), - )) + lambda x: ("secret", "42") in x.invocation_metadata, + _LoggingInterceptor("log3", record), + ) + server_target, _ = await start_test_server( + interceptors=( + _LoggingInterceptor("log1", record), + conditional_interceptor, + _LoggingInterceptor("log2", record), + ) + ) async with aio.insecure_channel(server_target) as channel: multicallable = channel.unary_unary( - '/grpc.testing.TestService/UnaryCall', + "/grpc.testing.TestService/UnaryCall", request_serializer=messages_pb2.SimpleRequest.SerializeToString, - response_deserializer=messages_pb2.SimpleResponse.FromString) - - metadata = aio.Metadata(('key', 'value'),) - call = multicallable(messages_pb2.SimpleRequest(), - metadata=metadata) + response_deserializer=messages_pb2.SimpleResponse.FromString, + ) + + metadata = aio.Metadata( + ("key", "value"), + ) + call = multicallable( + messages_pb2.SimpleRequest(), metadata=metadata + ) await call - self.assertSequenceEqual([ - 'log1:intercept_service', - 'log2:intercept_service', - ], record) + self.assertSequenceEqual( + [ + "log1:intercept_service", + "log2:intercept_service", + ], + record, + ) record.clear() - metadata = aio.Metadata(('key', 'value'), ('secret', '42')) - call = multicallable(messages_pb2.SimpleRequest(), - metadata=metadata) + metadata = aio.Metadata(("key", "value"), ("secret", "42")) + call = multicallable( + messages_pb2.SimpleRequest(), metadata=metadata + ) await call - self.assertSequenceEqual([ - 'log1:intercept_service', - 'log3:intercept_service', - 'log2:intercept_service', - ], record) + self.assertSequenceEqual( + [ + "log1:intercept_service", + "log3:intercept_service", + "log2:intercept_service", + ], + record, + ) async def test_response_caching(self): # Prepares a preset value to help testing - interceptor = _CacheInterceptor({ - 42: - messages_pb2.SimpleResponse(payload=messages_pb2.Payload( - body=b'\x42')) - }) + interceptor = _CacheInterceptor( + { + 42: messages_pb2.SimpleResponse( + 
payload=messages_pb2.Payload(body=b"\x42") + ) + } + ) # Constructs a server with the cache interceptor server, stub = await _create_server_stub_pair(interceptor) # Tests if the cache store is used response = await stub.UnaryCall( - messages_pb2.SimpleRequest(response_size=42)) + messages_pb2.SimpleRequest(response_size=42) + ) self.assertEqual(1, len(interceptor.cache_store[42].payload.body)) self.assertEqual(interceptor.cache_store[42], response) # Tests response can be cached response = await stub.UnaryCall( - messages_pb2.SimpleRequest(response_size=1337)) + messages_pb2.SimpleRequest(response_size=1337) + ) self.assertEqual(1337, len(interceptor.cache_store[1337].payload.body)) self.assertEqual(interceptor.cache_store[1337], response) response = await stub.UnaryCall( - messages_pb2.SimpleRequest(response_size=1337)) + messages_pb2.SimpleRequest(response_size=1337) + ) self.assertEqual(interceptor.cache_store[1337], response) async def test_interceptor_unary_stream(self): record = [] server, stub = await _create_server_stub_pair( - _LoggingInterceptor('log_unary_stream', record)) + _LoggingInterceptor("log_unary_stream", record) + ) # Prepares the request request = messages_pb2.StreamingOutputCallRequest() for _ in range(_NUM_STREAM_RESPONSES): request.response_parameters.append( - messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,)) + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE, + ) + ) # Tests if the cache store is used call = stub.StreamingOutputCall(request) @@ -267,20 +313,24 @@ async def test_interceptor_unary_stream(self): self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) self.assertEqual(await call.code(), grpc.StatusCode.OK) - self.assertSequenceEqual([ - 'log_unary_stream:intercept_service', - ], record) + self.assertSequenceEqual( + [ + "log_unary_stream:intercept_service", + ], + record, + ) async def test_interceptor_stream_unary(self): record = [] server, stub = await _create_server_stub_pair( - _LoggingInterceptor('log_stream_unary', record)) + _LoggingInterceptor("log_stream_unary", record) + ) # Invokes the actual RPC call = stub.StreamingInputCall() # Prepares the request - payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + payload = messages_pb2.Payload(body=b"\0" * _REQUEST_PAYLOAD_SIZE) request = messages_pb2.StreamingInputCallRequest(payload=payload) # Sends out requests @@ -291,22 +341,28 @@ async def test_interceptor_stream_unary(self): # Validates the responses response = await call self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse) - self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, - response.aggregated_payload_size) + self.assertEqual( + _NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, + response.aggregated_payload_size, + ) self.assertEqual(await call.code(), grpc.StatusCode.OK) - self.assertSequenceEqual([ - 'log_stream_unary:intercept_service', - ], record) + self.assertSequenceEqual( + [ + "log_stream_unary:intercept_service", + ], + record, + ) async def test_interceptor_stream_stream(self): record = [] server, stub = await _create_server_stub_pair( - _LoggingInterceptor('log_stream_stream', record)) + _LoggingInterceptor("log_stream_stream", record) + ) # Prepares the request - payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + payload = messages_pb2.Payload(body=b"\0" * _REQUEST_PAYLOAD_SIZE) request = messages_pb2.StreamingInputCallRequest(payload=payload) async def gen(): @@ -319,16 +375,21 @@ async def gen(): # Validates the 
responses response = await call self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse) - self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, - response.aggregated_payload_size) + self.assertEqual( + _NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, + response.aggregated_payload_size, + ) self.assertEqual(await call.code(), grpc.StatusCode.OK) - self.assertSequenceEqual([ - 'log_stream_stream:intercept_service', - ], record) + self.assertSequenceEqual( + [ + "log_stream_stream:intercept_service", + ], + record, + ) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/server_test.py b/src/python/grpcio_tests/tests_aio/unit/server_test.py index 2f0f6de11007f..790cb27f746ab 100644 --- a/src/python/grpcio_tests/tests_aio/unit/server_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/server_test.py @@ -24,85 +24,87 @@ from tests.unit.framework.common import test_constants from tests_aio.unit._test_base import AioTestBase -_SIMPLE_UNARY_UNARY = '/test/SimpleUnaryUnary' -_BLOCK_FOREVER = '/test/BlockForever' -_BLOCK_BRIEFLY = '/test/BlockBriefly' -_UNARY_STREAM_ASYNC_GEN = '/test/UnaryStreamAsyncGen' -_UNARY_STREAM_READER_WRITER = '/test/UnaryStreamReaderWriter' -_UNARY_STREAM_EVILLY_MIXED = '/test/UnaryStreamEvillyMixed' -_STREAM_UNARY_ASYNC_GEN = '/test/StreamUnaryAsyncGen' -_STREAM_UNARY_READER_WRITER = '/test/StreamUnaryReaderWriter' -_STREAM_UNARY_EVILLY_MIXED = '/test/StreamUnaryEvillyMixed' -_STREAM_STREAM_ASYNC_GEN = '/test/StreamStreamAsyncGen' -_STREAM_STREAM_READER_WRITER = '/test/StreamStreamReaderWriter' -_STREAM_STREAM_EVILLY_MIXED = '/test/StreamStreamEvillyMixed' -_UNIMPLEMENTED_METHOD = '/test/UnimplementedMethod' -_ERROR_IN_STREAM_STREAM = '/test/ErrorInStreamStream' -_ERROR_IN_STREAM_UNARY = '/test/ErrorInStreamUnary' -_ERROR_WITHOUT_RAISE_IN_UNARY_UNARY = '/test/ErrorWithoutRaiseInUnaryUnary' -_ERROR_WITHOUT_RAISE_IN_STREAM_STREAM = '/test/ErrorWithoutRaiseInStreamStream' -_INVALID_TRAILING_METADATA = '/test/InvalidTrailingMetadata' - -_REQUEST = b'\x00\x00\x00' -_RESPONSE = b'\x01\x01\x01' +_SIMPLE_UNARY_UNARY = "/test/SimpleUnaryUnary" +_BLOCK_FOREVER = "/test/BlockForever" +_BLOCK_BRIEFLY = "/test/BlockBriefly" +_UNARY_STREAM_ASYNC_GEN = "/test/UnaryStreamAsyncGen" +_UNARY_STREAM_READER_WRITER = "/test/UnaryStreamReaderWriter" +_UNARY_STREAM_EVILLY_MIXED = "/test/UnaryStreamEvillyMixed" +_STREAM_UNARY_ASYNC_GEN = "/test/StreamUnaryAsyncGen" +_STREAM_UNARY_READER_WRITER = "/test/StreamUnaryReaderWriter" +_STREAM_UNARY_EVILLY_MIXED = "/test/StreamUnaryEvillyMixed" +_STREAM_STREAM_ASYNC_GEN = "/test/StreamStreamAsyncGen" +_STREAM_STREAM_READER_WRITER = "/test/StreamStreamReaderWriter" +_STREAM_STREAM_EVILLY_MIXED = "/test/StreamStreamEvillyMixed" +_UNIMPLEMENTED_METHOD = "/test/UnimplementedMethod" +_ERROR_IN_STREAM_STREAM = "/test/ErrorInStreamStream" +_ERROR_IN_STREAM_UNARY = "/test/ErrorInStreamUnary" +_ERROR_WITHOUT_RAISE_IN_UNARY_UNARY = "/test/ErrorWithoutRaiseInUnaryUnary" +_ERROR_WITHOUT_RAISE_IN_STREAM_STREAM = "/test/ErrorWithoutRaiseInStreamStream" +_INVALID_TRAILING_METADATA = "/test/InvalidTrailingMetadata" + +_REQUEST = b"\x00\x00\x00" +_RESPONSE = b"\x01\x01\x01" _NUM_STREAM_REQUESTS = 3 _NUM_STREAM_RESPONSES = 5 _MAXIMUM_CONCURRENT_RPCS = 5 class _GenericHandler(grpc.GenericRpcHandler): - def __init__(self): self._called = asyncio.get_event_loop().create_future() self._routing_table = { - _SIMPLE_UNARY_UNARY: - 
grpc.unary_unary_rpc_method_handler(self._unary_unary), - _BLOCK_FOREVER: - grpc.unary_unary_rpc_method_handler(self._block_forever), - _BLOCK_BRIEFLY: - grpc.unary_unary_rpc_method_handler(self._block_briefly), - _UNARY_STREAM_ASYNC_GEN: - grpc.unary_stream_rpc_method_handler( - self._unary_stream_async_gen), - _UNARY_STREAM_READER_WRITER: - grpc.unary_stream_rpc_method_handler( - self._unary_stream_reader_writer), - _UNARY_STREAM_EVILLY_MIXED: - grpc.unary_stream_rpc_method_handler( - self._unary_stream_evilly_mixed), - _STREAM_UNARY_ASYNC_GEN: - grpc.stream_unary_rpc_method_handler( - self._stream_unary_async_gen), - _STREAM_UNARY_READER_WRITER: - grpc.stream_unary_rpc_method_handler( - self._stream_unary_reader_writer), - _STREAM_UNARY_EVILLY_MIXED: - grpc.stream_unary_rpc_method_handler( - self._stream_unary_evilly_mixed), - _STREAM_STREAM_ASYNC_GEN: - grpc.stream_stream_rpc_method_handler( - self._stream_stream_async_gen), - _STREAM_STREAM_READER_WRITER: - grpc.stream_stream_rpc_method_handler( - self._stream_stream_reader_writer), - _STREAM_STREAM_EVILLY_MIXED: - grpc.stream_stream_rpc_method_handler( - self._stream_stream_evilly_mixed), - _ERROR_IN_STREAM_STREAM: - grpc.stream_stream_rpc_method_handler( - self._error_in_stream_stream), - _ERROR_IN_STREAM_UNARY: - grpc.stream_unary_rpc_method_handler( - self._value_error_in_stream_unary), - _ERROR_WITHOUT_RAISE_IN_UNARY_UNARY: - grpc.unary_unary_rpc_method_handler( - self._error_without_raise_in_unary_unary), - _ERROR_WITHOUT_RAISE_IN_STREAM_STREAM: - grpc.stream_stream_rpc_method_handler( - self._error_without_raise_in_stream_stream), - _INVALID_TRAILING_METADATA: - grpc.unary_unary_rpc_method_handler( - self._invalid_trailing_metadata), + _SIMPLE_UNARY_UNARY: grpc.unary_unary_rpc_method_handler( + self._unary_unary + ), + _BLOCK_FOREVER: grpc.unary_unary_rpc_method_handler( + self._block_forever + ), + _BLOCK_BRIEFLY: grpc.unary_unary_rpc_method_handler( + self._block_briefly + ), + _UNARY_STREAM_ASYNC_GEN: grpc.unary_stream_rpc_method_handler( + self._unary_stream_async_gen + ), + _UNARY_STREAM_READER_WRITER: grpc.unary_stream_rpc_method_handler( + self._unary_stream_reader_writer + ), + _UNARY_STREAM_EVILLY_MIXED: grpc.unary_stream_rpc_method_handler( + self._unary_stream_evilly_mixed + ), + _STREAM_UNARY_ASYNC_GEN: grpc.stream_unary_rpc_method_handler( + self._stream_unary_async_gen + ), + _STREAM_UNARY_READER_WRITER: grpc.stream_unary_rpc_method_handler( + self._stream_unary_reader_writer + ), + _STREAM_UNARY_EVILLY_MIXED: grpc.stream_unary_rpc_method_handler( + self._stream_unary_evilly_mixed + ), + _STREAM_STREAM_ASYNC_GEN: grpc.stream_stream_rpc_method_handler( + self._stream_stream_async_gen + ), + _STREAM_STREAM_READER_WRITER: grpc.stream_stream_rpc_method_handler( + self._stream_stream_reader_writer + ), + _STREAM_STREAM_EVILLY_MIXED: grpc.stream_stream_rpc_method_handler( + self._stream_stream_evilly_mixed + ), + _ERROR_IN_STREAM_STREAM: grpc.stream_stream_rpc_method_handler( + self._error_in_stream_stream + ), + _ERROR_IN_STREAM_UNARY: grpc.stream_unary_rpc_method_handler( + self._value_error_in_stream_unary + ), + _ERROR_WITHOUT_RAISE_IN_UNARY_UNARY: grpc.unary_unary_rpc_method_handler( + self._error_without_raise_in_unary_unary + ), + _ERROR_WITHOUT_RAISE_IN_STREAM_STREAM: grpc.stream_stream_rpc_method_handler( + self._error_without_raise_in_stream_stream + ), + _INVALID_TRAILING_METADATA: grpc.unary_unary_rpc_method_handler( + self._invalid_trailing_metadata + ), } @staticmethod @@ -182,7 +184,7 @@ async def 
_stream_stream_evilly_mixed(self, request_iterator, context): async def _error_in_stream_stream(self, request_iterator, unused_context): async for request in request_iterator: assert _REQUEST == request - raise RuntimeError('A testing RuntimeError!') + raise RuntimeError("A testing RuntimeError!") yield _RESPONSE async def _value_error_in_stream_unary(self, request_iterator, context): @@ -191,14 +193,15 @@ async def _value_error_in_stream_unary(self, request_iterator, context): assert _REQUEST == request request_count += 1 if request_count >= 1: - raise ValueError('A testing RuntimeError!') + raise ValueError("A testing RuntimeError!") async def _error_without_raise_in_unary_unary(self, request, context): assert _REQUEST == request context.set_code(grpc.StatusCode.INTERNAL) - async def _error_without_raise_in_stream_stream(self, request_iterator, - context): + async def _error_without_raise_in_stream_stream( + self, request_iterator, context + ): async for request in request_iterator: assert _REQUEST == request context.set_code(grpc.StatusCode.INTERNAL) @@ -206,11 +209,10 @@ async def _error_without_raise_in_stream_stream(self, request_iterator, async def _invalid_trailing_metadata(self, request, context): assert _REQUEST == request for invalid_metadata in [ - 42, {}, { - 'error': 'error' - }, [{ - 'error': "error" - }] + 42, + {}, + {"error": "error"}, + [{"error": "error"}], ]: try: context.set_trailing_metadata(invalid_metadata) @@ -218,14 +220,15 @@ async def _invalid_trailing_metadata(self, request, context): pass else: raise ValueError( - f'No TypeError raised for invalid metadata: {invalid_metadata}' + "No TypeError raised for invalid metadata:" + f" {invalid_metadata}" ) - await context.abort(grpc.StatusCode.DATA_LOSS, - details="invalid abort", - trailing_metadata=({ - 'error': ('error1', 'error2') - })) + await context.abort( + grpc.StatusCode.DATA_LOSS, + details="invalid abort", + trailing_metadata=({"error": ("error1", "error2")}), + ) def service(self, handler_details): if not self._called.done(): @@ -238,15 +241,14 @@ async def wait_for_call(self): async def _start_test_server(): server = aio.server() - port = server.add_insecure_port('[::]:0') + port = server.add_insecure_port("[::]:0") generic_handler = _GenericHandler() server.add_generic_rpc_handlers((generic_handler,)) await server.start() - return 'localhost:%d' % port, server, generic_handler + return "localhost:%d" % port, server, generic_handler class TestServer(AioTestBase): - async def setUp(self): addr, self._server, self._generic_handler = await _start_test_server() self._channel = aio.insecure_channel(addr) @@ -274,7 +276,8 @@ async def test_unary_stream_async_generator(self): async def test_unary_stream_reader_writer(self): unary_stream_call = self._channel.unary_stream( - _UNARY_STREAM_READER_WRITER) + _UNARY_STREAM_READER_WRITER + ) call = unary_stream_call(_REQUEST) for _ in range(_NUM_STREAM_RESPONSES): @@ -285,7 +288,8 @@ async def test_unary_stream_reader_writer(self): async def test_unary_stream_evilly_mixed(self): unary_stream_call = self._channel.unary_stream( - _UNARY_STREAM_EVILLY_MIXED) + _UNARY_STREAM_EVILLY_MIXED + ) call = unary_stream_call(_REQUEST) # Uses reader API @@ -328,7 +332,8 @@ def request_gen(): async def test_stream_unary_reader_writer(self): stream_unary_call = self._channel.stream_unary( - _STREAM_UNARY_READER_WRITER) + _STREAM_UNARY_READER_WRITER + ) call = stream_unary_call() for _ in range(_NUM_STREAM_REQUESTS): @@ -341,7 +346,8 @@ async def 
test_stream_unary_reader_writer(self): async def test_stream_unary_evilly_mixed(self): stream_unary_call = self._channel.stream_unary( - _STREAM_UNARY_EVILLY_MIXED) + _STREAM_UNARY_EVILLY_MIXED + ) call = stream_unary_call() for _ in range(_NUM_STREAM_REQUESTS): @@ -354,7 +360,8 @@ async def test_stream_unary_evilly_mixed(self): async def test_stream_stream_async_generator(self): stream_stream_call = self._channel.stream_stream( - _STREAM_STREAM_ASYNC_GEN) + _STREAM_STREAM_ASYNC_GEN + ) call = stream_stream_call() for _ in range(_NUM_STREAM_REQUESTS): @@ -369,7 +376,8 @@ async def test_stream_stream_async_generator(self): async def test_stream_stream_reader_writer(self): stream_stream_call = self._channel.stream_stream( - _STREAM_STREAM_READER_WRITER) + _STREAM_STREAM_READER_WRITER + ) call = stream_stream_call() for _ in range(_NUM_STREAM_REQUESTS): @@ -384,7 +392,8 @@ async def test_stream_stream_reader_writer(self): async def test_stream_stream_evilly_mixed(self): stream_stream_call = self._channel.stream_stream( - _STREAM_STREAM_EVILLY_MIXED) + _STREAM_STREAM_EVILLY_MIXED + ) call = stream_stream_call() for _ in range(_NUM_STREAM_REQUESTS): @@ -413,8 +422,9 @@ async def test_graceful_shutdown_success(self): shutdown_start_time = time.time() await self._server.stop(test_constants.SHORT_TIMEOUT) grace_period_length = time.time() - shutdown_start_time - self.assertGreater(grace_period_length, - test_constants.SHORT_TIMEOUT / 3) + self.assertGreater( + grace_period_length, test_constants.SHORT_TIMEOUT / 3 + ) # Validates the states. self.assertEqual(_RESPONSE, await call) @@ -428,8 +438,9 @@ async def test_graceful_shutdown_failed(self): with self.assertRaises(aio.AioRpcError) as exception_context: await call - self.assertEqual(grpc.StatusCode.UNAVAILABLE, - exception_context.exception.code()) + self.assertEqual( + grpc.StatusCode.UNAVAILABLE, exception_context.exception.code() + ) async def test_concurrent_graceful_shutdown(self): call = self._channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST) @@ -443,8 +454,9 @@ async def test_concurrent_graceful_shutdown(self): self._server.stop(test_constants.LONG_TIMEOUT), ) grace_period_length = time.time() - shutdown_start_time - self.assertGreater(grace_period_length, - test_constants.SHORT_TIMEOUT / 3) + self.assertGreater( + grace_period_length, test_constants.SHORT_TIMEOUT / 3 + ) self.assertEqual(_RESPONSE, await call) self.assertTrue(call.done()) @@ -463,8 +475,9 @@ async def test_concurrent_graceful_shutdown_immediate(self): with self.assertRaises(aio.AioRpcError) as exception_context: await call - self.assertEqual(grpc.StatusCode.UNAVAILABLE, - exception_context.exception.code()) + self.assertEqual( + grpc.StatusCode.UNAVAILABLE, exception_context.exception.code() + ) async def test_shutdown_before_call(self): await self._server.stop(None) @@ -483,7 +496,8 @@ async def test_unimplemented(self): async def test_shutdown_during_stream_stream(self): stream_stream_call = self._channel.stream_stream( - _STREAM_STREAM_ASYNC_GEN) + _STREAM_STREAM_ASYNC_GEN + ) call = stream_stream_call() # Don't half close the RPC yet, keep it alive. @@ -495,7 +509,8 @@ async def test_shutdown_during_stream_stream(self): async def test_error_in_stream_stream(self): stream_stream_call = self._channel.stream_stream( - _ERROR_IN_STREAM_STREAM) + _ERROR_IN_STREAM_STREAM + ) call = stream_stream_call() # Don't half close the RPC yet, keep it alive. 
@@ -506,7 +521,8 @@ async def test_error_in_stream_stream(self): async def test_error_without_raise_in_unary_unary(self): call = self._channel.unary_unary(_ERROR_WITHOUT_RAISE_IN_UNARY_UNARY)( - _REQUEST) + _REQUEST + ) with self.assertRaises(aio.AioRpcError) as exception_context: await call @@ -516,7 +532,8 @@ async def test_error_without_raise_in_unary_unary(self): async def test_error_without_raise_in_stream_stream(self): call = self._channel.stream_stream( - _ERROR_WITHOUT_RAISE_IN_STREAM_STREAM)() + _ERROR_WITHOUT_RAISE_IN_STREAM_STREAM + )() for _ in range(_NUM_STREAM_REQUESTS): await call.write(_REQUEST) @@ -539,23 +556,23 @@ async def request_gen(): self.assertEqual(grpc.StatusCode.UNKNOWN, rpc_error.code()) async def test_port_binding_exception(self): - server = aio.server(options=(('grpc.so_reuseport', 0),)) - port = server.add_insecure_port('localhost:0') + server = aio.server(options=(("grpc.so_reuseport", 0),)) + port = server.add_insecure_port("localhost:0") bind_address = "localhost:%d" % port with self.assertRaises(RuntimeError): server.add_insecure_port(bind_address) - server_credentials = grpc.ssl_server_credentials([ - (resources.private_key(), resources.certificate_chain()) - ]) + server_credentials = grpc.ssl_server_credentials( + [(resources.private_key(), resources.certificate_chain())] + ) with self.assertRaises(RuntimeError): server.add_secure_port(bind_address, server_credentials) async def test_maximum_concurrent_rpcs(self): # Build the server with concurrent rpc argument server = aio.server(maximum_concurrent_rpcs=_MAXIMUM_CONCURRENT_RPCS) - port = server.add_insecure_port('localhost:0') + port = server.add_insecure_port("localhost:0") bind_address = "localhost:%d" % port server.add_generic_rpc_handlers((_GenericHandler(),)) await server.start() @@ -566,7 +583,8 @@ async def test_maximum_concurrent_rpcs(self): for _ in range(3 * _MAXIMUM_CONCURRENT_RPCS): rpcs.append(channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST)) task = self.loop.create_task( - asyncio.wait(rpcs, return_when=asyncio.FIRST_EXCEPTION)) + asyncio.wait(rpcs, return_when=asyncio.FIRST_EXCEPTION) + ) # Each batch took test_constants.SHORT_TIMEOUT /2 start_time = time.time() await task @@ -584,9 +602,9 @@ async def test_invalid_trailing_metadata(self): rpc_error = exception_context.exception self.assertEqual(grpc.StatusCode.UNKNOWN, rpc_error.code()) - self.assertIn('trailing', rpc_error.details()) + self.assertIn("trailing", rpc_error.details()) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/server_time_remaining_test.py b/src/python/grpcio_tests/tests_aio/unit/server_time_remaining_test.py index 340e4cc350a5a..bdb6a6e44f9f2 100644 --- a/src/python/grpcio_tests/tests_aio/unit/server_time_remaining_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/server_time_remaining_test.py @@ -25,19 +25,18 @@ from tests_aio.unit._common import AdhocGenericHandler from tests_aio.unit._test_base import AioTestBase -_REQUEST = b'\x09\x05' +_REQUEST = b"\x09\x05" _REQUEST_TIMEOUT_S = datetime.timedelta(seconds=5).total_seconds() class TestServerTimeRemaining(AioTestBase): - async def setUp(self): # Create async server - self._server = aio.server(options=(('grpc.so_reuseport', 0),)) + self._server = aio.server(options=(("grpc.so_reuseport", 0),)) self._adhoc_handlers = AdhocGenericHandler() self._server.add_generic_rpc_handlers((self._adhoc_handlers,)) - port = 
self._server.add_insecure_port('[::]:0') - address = 'localhost:%d' % port + port = self._server.add_insecure_port("[::]:0") + address = "localhost:%d" % port await self._server.start() # Create async channel self._channel = aio.insecure_channel(address) @@ -50,15 +49,17 @@ async def test_servicer_context_time_remaining(self): seen_time_remaining = [] @grpc.unary_unary_rpc_method_handler - def log_time_remaining(request: bytes, - context: grpc.ServicerContext) -> bytes: + def log_time_remaining( + request: bytes, context: grpc.ServicerContext + ) -> bytes: seen_time_remaining.append(context.time_remaining()) return b"" # Check if the deadline propagates properly self._adhoc_handlers.set_adhoc_handler(log_time_remaining) await self._channel.unary_unary(ADHOC_METHOD)( - _REQUEST, timeout=_REQUEST_TIMEOUT_S) + _REQUEST, timeout=_REQUEST_TIMEOUT_S + ) self.assertGreater(seen_time_remaining[0], _REQUEST_TIMEOUT_S / 2) # Check if there is no timeout, the time_remaining will be None self._adhoc_handlers.set_adhoc_handler(log_time_remaining) @@ -66,6 +67,6 @@ def log_time_remaining(request: bytes, self.assertIsNone(seen_time_remaining[1]) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/timeout_test.py b/src/python/grpcio_tests/tests_aio/unit/timeout_test.py index dab0f5113f4c9..b6834ba4e815e 100644 --- a/src/python/grpcio_tests/tests_aio/unit/timeout_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/timeout_test.py @@ -28,13 +28,13 @@ _SLEEP_TIME_UNIT_S = datetime.timedelta(seconds=1).total_seconds() -_TEST_SLEEPY_UNARY_UNARY = '/test/Test/SleepyUnaryUnary' -_TEST_SLEEPY_UNARY_STREAM = '/test/Test/SleepyUnaryStream' -_TEST_SLEEPY_STREAM_UNARY = '/test/Test/SleepyStreamUnary' -_TEST_SLEEPY_STREAM_STREAM = '/test/Test/SleepyStreamStream' +_TEST_SLEEPY_UNARY_UNARY = "/test/Test/SleepyUnaryUnary" +_TEST_SLEEPY_UNARY_STREAM = "/test/Test/SleepyUnaryStream" +_TEST_SLEEPY_STREAM_UNARY = "/test/Test/SleepyStreamUnary" +_TEST_SLEEPY_STREAM_STREAM = "/test/Test/SleepyStreamStream" -_REQUEST = b'\x00\x00\x00' -_RESPONSE = b'\x01\x01\x01' +_REQUEST = b"\x00\x00\x00" +_RESPONSE = b"\x01\x01\x01" async def _test_sleepy_unary_unary(unused_request, unused_context): @@ -62,40 +62,44 @@ async def _test_sleepy_stream_stream(unused_request_iterator, context): _ROUTING_TABLE = { - _TEST_SLEEPY_UNARY_UNARY: - grpc.unary_unary_rpc_method_handler(_test_sleepy_unary_unary), - _TEST_SLEEPY_UNARY_STREAM: - grpc.unary_stream_rpc_method_handler(_test_sleepy_unary_stream), - _TEST_SLEEPY_STREAM_UNARY: - grpc.stream_unary_rpc_method_handler(_test_sleepy_stream_unary), - _TEST_SLEEPY_STREAM_STREAM: - grpc.stream_stream_rpc_method_handler(_test_sleepy_stream_stream) + _TEST_SLEEPY_UNARY_UNARY: grpc.unary_unary_rpc_method_handler( + _test_sleepy_unary_unary + ), + _TEST_SLEEPY_UNARY_STREAM: grpc.unary_stream_rpc_method_handler( + _test_sleepy_unary_stream + ), + _TEST_SLEEPY_STREAM_UNARY: grpc.stream_unary_rpc_method_handler( + _test_sleepy_stream_unary + ), + _TEST_SLEEPY_STREAM_STREAM: grpc.stream_stream_rpc_method_handler( + _test_sleepy_stream_stream + ), } class _GenericHandler(grpc.GenericRpcHandler): - def service(self, handler_call_details): return _ROUTING_TABLE.get(handler_call_details.method) async def _start_test_server(): server = aio.server() - port = server.add_insecure_port('[::]:0') + port = server.add_insecure_port("[::]:0") server.add_generic_rpc_handlers((_GenericHandler(),)) await 
server.start() - return f'localhost:{port}', server + return f"localhost:{port}", server class TestTimeout(AioTestBase): - async def setUp(self): address, self._server = await _start_test_server() self._client = aio.insecure_channel(address) - self.assertEqual(grpc.ChannelConnectivity.IDLE, - self._client.get_state(True)) - await _common.block_until_certain_state(self._client, - grpc.ChannelConnectivity.READY) + self.assertEqual( + grpc.ChannelConnectivity.IDLE, self._client.get_state(True) + ) + await _common.block_until_certain_state( + self._client, grpc.ChannelConnectivity.READY + ) async def tearDown(self): await self._client.close() @@ -173,6 +177,6 @@ async def test_stream_stream_deadline_exceeded(self): self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, rpc_error.code()) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/wait_for_connection_test.py b/src/python/grpcio_tests/tests_aio/unit/wait_for_connection_test.py index a49a1241c4f77..2cc8dc7c28241 100644 --- a/src/python/grpcio_tests/tests_aio/unit/wait_for_connection_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/wait_for_connection_test.py @@ -29,8 +29,8 @@ from tests_aio.unit._test_base import AioTestBase from tests_aio.unit._test_server import start_test_server -_REQUEST = b'\x01\x02\x03' -_TEST_METHOD = '/test/Test' +_REQUEST = b"\x01\x02\x03" +_TEST_METHOD = "/test/Test" _NUM_STREAM_RESPONSES = 5 _REQUEST_PAYLOAD_SIZE = 7 @@ -64,7 +64,8 @@ async def test_unary_stream_ok(self): request = messages_pb2.StreamingOutputCallRequest() for _ in range(_NUM_STREAM_RESPONSES): request.response_parameters.append( - messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE) + ) call = self._stub.StreamingOutputCall(request) @@ -74,8 +75,9 @@ async def test_unary_stream_ok(self): response_cnt = 0 async for response in call: response_cnt += 1 - self.assertIs(type(response), - messages_pb2.StreamingOutputCallResponse) + self.assertIs( + type(response), messages_pb2.StreamingOutputCallResponse + ) self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) self.assertEqual(_NUM_STREAM_RESPONSES, response_cnt) @@ -87,7 +89,7 @@ async def test_stream_unary_ok(self): # No exception raised and no message swallowed. 
await call.wait_for_connection() - payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + payload = messages_pb2.Payload(body=b"\0" * _REQUEST_PAYLOAD_SIZE) request = messages_pb2.StreamingInputCallRequest(payload=payload) for _ in range(_NUM_STREAM_RESPONSES): @@ -96,8 +98,10 @@ async def test_stream_unary_ok(self): response = await call self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse) - self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, - response.aggregated_payload_size) + self.assertEqual( + _NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, + response.aggregated_payload_size, + ) self.assertEqual(await call.code(), grpc.StatusCode.OK) @@ -109,13 +113,15 @@ async def test_stream_stream_ok(self): request = messages_pb2.StreamingOutputCallRequest() request.response_parameters.append( - messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE) + ) for _ in range(_NUM_STREAM_RESPONSES): await call.write(request) response = await call.read() - self.assertIsInstance(response, - messages_pb2.StreamingOutputCallResponse) + self.assertIsInstance( + response, messages_pb2.StreamingOutputCallResponse + ) self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) await call.done_writing() @@ -155,6 +161,6 @@ async def test_stream_stream_error(self): self.assertEqual(grpc.StatusCode.UNAVAILABLE, rpc_error.code()) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/wait_for_ready_test.py b/src/python/grpcio_tests/tests_aio/unit/wait_for_ready_test.py index 303c138642aeb..334fd60cb194a 100644 --- a/src/python/grpcio_tests/tests_aio/unit/wait_for_ready_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/wait_for_ready_test.py @@ -37,20 +37,25 @@ async def _perform_unary_unary(stub, wait_for_ready): - await stub.UnaryCall(messages_pb2.SimpleRequest(), - timeout=test_constants.LONG_TIMEOUT, - wait_for_ready=wait_for_ready) + await stub.UnaryCall( + messages_pb2.SimpleRequest(), + timeout=test_constants.LONG_TIMEOUT, + wait_for_ready=wait_for_ready, + ) async def _perform_unary_stream(stub, wait_for_ready): request = messages_pb2.StreamingOutputCallRequest() for _ in range(_NUM_STREAM_RESPONSES): request.response_parameters.append( - messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE) + ) - call = stub.StreamingOutputCall(request, - timeout=test_constants.LONG_TIMEOUT, - wait_for_ready=wait_for_ready) + call = stub.StreamingOutputCall( + request, + timeout=test_constants.LONG_TIMEOUT, + wait_for_ready=wait_for_ready, + ) for _ in range(_NUM_STREAM_RESPONSES): await call.read() @@ -58,25 +63,29 @@ async def _perform_unary_stream(stub, wait_for_ready): async def _perform_stream_unary(stub, wait_for_ready): - payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + payload = messages_pb2.Payload(body=b"\0" * _REQUEST_PAYLOAD_SIZE) request = messages_pb2.StreamingInputCallRequest(payload=payload) async def gen(): for _ in range(_NUM_STREAM_RESPONSES): yield request - await stub.StreamingInputCall(gen(), - timeout=test_constants.LONG_TIMEOUT, - wait_for_ready=wait_for_ready) + await stub.StreamingInputCall( + gen(), + timeout=test_constants.LONG_TIMEOUT, + wait_for_ready=wait_for_ready, + ) async def _perform_stream_stream(stub, wait_for_ready): - call = 
stub.FullDuplexCall(timeout=test_constants.LONG_TIMEOUT, - wait_for_ready=wait_for_ready) + call = stub.FullDuplexCall( + timeout=test_constants.LONG_TIMEOUT, wait_for_ready=wait_for_ready + ) request = messages_pb2.StreamingOutputCallRequest() request.response_parameters.append( - messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE) + ) for _ in range(_NUM_STREAM_RESPONSES): await call.write(request) @@ -96,7 +105,6 @@ async def _perform_stream_stream(stub, wait_for_ready): class TestWaitForReady(AioTestBase): - async def setUp(self): address, self._port, self._socket = get_socket(listen=False) self._channel = aio.insecure_channel(f"{address}:{self._port}") @@ -122,8 +130,10 @@ async def test_call_wait_for_ready_disabled(self): """RPC should fail immediately after connection failed.""" await self._connection_fails_fast(False) - @unittest.skipIf(platform.system() == 'Windows', - 'https://github.com/grpc/grpc/pull/26729') + @unittest.skipIf( + platform.system() == "Windows", + "https://github.com/grpc/grpc/pull/26729", + ) async def test_call_wait_for_ready_enabled(self): """RPC will wait until the connection is ready.""" for action in _RPC_ACTIONS: @@ -133,7 +143,8 @@ async def test_call_wait_for_ready_enabled(self): # Wait for TRANSIENT_FAILURE, and RPC is not aborting await _common.block_until_certain_state( - self._channel, grpc.ChannelConnectivity.TRANSIENT_FAILURE) + self._channel, grpc.ChannelConnectivity.TRANSIENT_FAILURE + ) try: # Start the server @@ -146,6 +157,6 @@ async def test_call_wait_for_ready_enabled(self): await server.stop(None) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_gevent/unit/_test_server.py b/src/python/grpcio_tests/tests_gevent/unit/_test_server.py index 82327aa7098da..68c36093151d0 100644 --- a/src/python/grpcio_tests/tests_gevent/unit/_test_server.py +++ b/src/python/grpcio_tests/tests_gevent/unit/_test_server.py @@ -25,7 +25,6 @@ class TestServiceServicer(test_pb2_grpc.TestServiceServicer): - def UnaryCall(self, request, context): return messages_pb2.SimpleResponse() @@ -37,25 +36,26 @@ def UnaryCallWithSleep(self, unused_request, unused_context): def start_test_server(port: int = 0) -> Tuple[str, Any]: server = grpc.server(futures.ThreadPoolExecutor()) servicer = TestServiceServicer() - test_pb2_grpc.add_TestServiceServicer_to_server(TestServiceServicer(), - server) + test_pb2_grpc.add_TestServiceServicer_to_server( + TestServiceServicer(), server + ) server.add_generic_rpc_handlers((_create_extra_generic_handler(servicer),)) - port = server.add_insecure_port('[::]:%d' % port) + port = server.add_insecure_port("[::]:%d" % port) server.start() - return 'localhost:%d' % port, server + return "localhost:%d" % port, server def _create_extra_generic_handler(servicer: TestServiceServicer) -> Any: # Add programatically extra methods not provided by the proto file # that are used during the tests rpc_method_handlers = { - 'UnaryCallWithSleep': - grpc.unary_unary_rpc_method_handler( - servicer.UnaryCallWithSleep, - request_deserializer=messages_pb2.SimpleRequest.FromString, - response_serializer=messages_pb2.SimpleResponse. 
- SerializeToString) + "UnaryCallWithSleep": grpc.unary_unary_rpc_method_handler( + servicer.UnaryCallWithSleep, + request_deserializer=messages_pb2.SimpleRequest.FromString, + response_serializer=messages_pb2.SimpleResponse.SerializeToString, + ) } - return grpc.method_handlers_generic_handler('grpc.testing.TestService', - rpc_method_handlers) + return grpc.method_handlers_generic_handler( + "grpc.testing.TestService", rpc_method_handlers + ) diff --git a/src/python/grpcio_tests/tests_gevent/unit/close_channel_test.py b/src/python/grpcio_tests/tests_gevent/unit/close_channel_test.py index ca73fd685d59a..47fdb2c22e783 100644 --- a/src/python/grpcio_tests/tests_gevent/unit/close_channel_test.py +++ b/src/python/grpcio_tests/tests_gevent/unit/close_channel_test.py @@ -23,11 +23,10 @@ from src.proto.grpc.testing import test_pb2_grpc from tests_gevent.unit._test_server import start_test_server -_UNARY_CALL_METHOD_WITH_SLEEP = '/grpc.testing.TestService/UnaryCallWithSleep' +_UNARY_CALL_METHOD_WITH_SLEEP = "/grpc.testing.TestService/UnaryCallWithSleep" class CloseChannelTest(unittest.TestCase): - def setUp(self): self._server_target, self._server = start_test_server() self._channel = grpc.insecure_channel(self._server_target) @@ -101,5 +100,5 @@ def _global_exception_handler(self, exctype, value, tb): sys.__excepthook__(exctype, value, tb) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_py3_only/interop/xds_interop_client.py b/src/python/grpcio_tests/tests_py3_only/interop/xds_interop_client.py index 582da6a1c0ffb..7ac32f108445d 100644 --- a/src/python/grpcio_tests/tests_py3_only/interop/xds_interop_client.py +++ b/src/python/grpcio_tests/tests_py3_only/interop/xds_interop_client.py @@ -34,7 +34,7 @@ logger = logging.getLogger() console_handler = logging.StreamHandler() -formatter = logging.Formatter(fmt='%(asctime)s: %(levelname)-8s %(message)s') +formatter = logging.Formatter(fmt="%(asctime)s: %(levelname)-8s %(message)s") console_handler.setFormatter(formatter) logger.addHandler(console_handler) @@ -76,7 +76,8 @@ def __init__(self, start: int, end: int): self._rpcs_needed = end - start self._rpcs_by_peer = collections.defaultdict(int) self._rpcs_by_method = collections.defaultdict( - lambda: collections.defaultdict(int)) + lambda: collections.defaultdict(int) + ) self._condition = threading.Condition() self._no_remote_peer = 0 @@ -93,11 +94,13 @@ def on_rpc_complete(self, request_id: int, peer: str, method: str) -> None: self._condition.notify() def await_rpc_stats_response( - self, timeout_sec: int) -> messages_pb2.LoadBalancerStatsResponse: + self, timeout_sec: int + ) -> messages_pb2.LoadBalancerStatsResponse: """Blocks until a full response has been collected.""" with self._condition: - self._condition.wait_for(lambda: not self._rpcs_needed, - timeout=float(timeout_sec)) + self._condition.wait_for( + lambda: not self._rpcs_needed, timeout=float(timeout_sec) + ) response = messages_pb2.LoadBalancerStatsResponse() for peer, count in self._rpcs_by_peer.items(): response.rpcs_by_peer[peer] = count @@ -119,7 +122,8 @@ def await_rpc_stats_response( # Mapping[method, Mapping[status_code, count]] _global_rpc_statuses: Mapping[str, Mapping[int, int]] = collections.defaultdict( - lambda: collections.defaultdict(int)) + lambda: collections.defaultdict(int) +) def _handle_sigint(sig, frame) -> None: @@ -128,15 +132,16 @@ def _handle_sigint(sig, frame) -> None: _global_server.stop(None) -class 
_LoadBalancerStatsServicer(test_pb2_grpc.LoadBalancerStatsServiceServicer - ): - +class _LoadBalancerStatsServicer( + test_pb2_grpc.LoadBalancerStatsServiceServicer +): def __init__(self): super(_LoadBalancerStatsServicer).__init__() def GetClientStats( - self, request: messages_pb2.LoadBalancerStatsRequest, - context: grpc.ServicerContext + self, + request: messages_pb2.LoadBalancerStatsRequest, + context: grpc.ServicerContext, ) -> messages_pb2.LoadBalancerStatsResponse: logger.info("Received stats request.") start = None @@ -154,8 +159,9 @@ def GetClientStats( return response def GetClientAccumulatedStats( - self, request: messages_pb2.LoadBalancerAccumulatedStatsRequest, - context: grpc.ServicerContext + self, + request: messages_pb2.LoadBalancerAccumulatedStatsRequest, + context: grpc.ServicerContext, ) -> messages_pb2.LoadBalancerAccumulatedStatsResponse: logger.info("Received cumulative stats request.") response = messages_pb2.LoadBalancerAccumulatedStatsResponse() @@ -163,39 +169,48 @@ def GetClientAccumulatedStats( for method in _SUPPORTED_METHODS: caps_method = _METHOD_CAMEL_TO_CAPS_SNAKE[method] response.num_rpcs_started_by_method[ - caps_method] = _global_rpcs_started[method] + caps_method + ] = _global_rpcs_started[method] response.num_rpcs_succeeded_by_method[ - caps_method] = _global_rpcs_succeeded[method] + caps_method + ] = _global_rpcs_succeeded[method] response.num_rpcs_failed_by_method[ - caps_method] = _global_rpcs_failed[method] + caps_method + ] = _global_rpcs_failed[method] response.stats_per_method[ - caps_method].rpcs_started = _global_rpcs_started[method] + caps_method + ].rpcs_started = _global_rpcs_started[method] for code, count in _global_rpc_statuses[method].items(): response.stats_per_method[caps_method].result[code] = count logger.info("Returning cumulative stats response.") return response -def _start_rpc(method: str, metadata: Sequence[Tuple[str, str]], - request_id: int, stub: test_pb2_grpc.TestServiceStub, - timeout: float, futures: Mapping[int, Tuple[grpc.Future, - str]]) -> None: +def _start_rpc( + method: str, + metadata: Sequence[Tuple[str, str]], + request_id: int, + stub: test_pb2_grpc.TestServiceStub, + timeout: float, + futures: Mapping[int, Tuple[grpc.Future, str]], +) -> None: logger.debug(f"Sending {method} request to backend: {request_id}") if method == "UnaryCall": - future = stub.UnaryCall.future(messages_pb2.SimpleRequest(), - metadata=metadata, - timeout=timeout) + future = stub.UnaryCall.future( + messages_pb2.SimpleRequest(), metadata=metadata, timeout=timeout + ) elif method == "EmptyCall": - future = stub.EmptyCall.future(empty_pb2.Empty(), - metadata=metadata, - timeout=timeout) + future = stub.EmptyCall.future( + empty_pb2.Empty(), metadata=metadata, timeout=timeout + ) else: raise ValueError(f"Unrecognized method '{method}'.") futures[request_id] = (future, method) -def _on_rpc_done(rpc_id: int, future: grpc.Future, method: str, - print_response: bool) -> None: +def _on_rpc_done( + rpc_id: int, future: grpc.Future, method: str, print_response: bool +) -> None: exception = future.exception() hostname = "" with _global_lock: @@ -232,8 +247,9 @@ def _on_rpc_done(rpc_id: int, future: grpc.Future, method: str, watcher.on_rpc_complete(rpc_id, hostname, method) -def _remove_completed_rpcs(futures: Mapping[int, grpc.Future], - print_response: bool) -> None: +def _remove_completed_rpcs( + futures: Mapping[int, grpc.Future], print_response: bool +) -> None: logger.debug("Removing completed RPCs") done = [] for future_id, (future, method) 
in futures.items(): @@ -258,9 +274,16 @@ class _ChannelConfiguration: When accessing any of its members, the lock member should be held. """ - def __init__(self, method: str, metadata: Sequence[Tuple[str, str]], - qps: int, server: str, rpc_timeout_sec: int, - print_response: bool, secure_mode: bool): + def __init__( + self, + method: str, + metadata: Sequence[Tuple[str, str]], + qps: int, + server: str, + rpc_timeout_sec: int, + print_response: bool, + secure_mode: bool, + ): # condition is signalled when a change is made to the config. self.condition = threading.Condition() @@ -291,7 +314,8 @@ def _run_single_channel(config: _ChannelConfiguration) -> None: with config.condition: if config.qps == 0: config.condition.wait( - timeout=_CONFIG_CHANGE_TIMEOUT.total_seconds()) + timeout=_CONFIG_CHANGE_TIMEOUT.total_seconds() + ) continue else: duration_per_query = 1.0 / float(config.qps) @@ -302,8 +326,14 @@ def _run_single_channel(config: _ChannelConfiguration) -> None: _global_rpcs_started[config.method] += 1 start = time.time() end = start + duration_per_query - _start_rpc(config.method, config.metadata, request_id, stub, - float(config.rpc_timeout_sec), futures) + _start_rpc( + config.method, + config.metadata, + request_id, + stub, + float(config.rpc_timeout_sec), + futures, + ) print_response = config.print_response _remove_completed_rpcs(futures, config.print_response) logger.debug(f"Currently {len(futures)} in-flight RPCs") @@ -315,17 +345,19 @@ def _run_single_channel(config: _ChannelConfiguration) -> None: class _XdsUpdateClientConfigureServicer( - test_pb2_grpc.XdsUpdateClientConfigureServiceServicer): - - def __init__(self, per_method_configs: Mapping[str, _ChannelConfiguration], - qps: int): + test_pb2_grpc.XdsUpdateClientConfigureServiceServicer +): + def __init__( + self, per_method_configs: Mapping[str, _ChannelConfiguration], qps: int + ): super(_XdsUpdateClientConfigureServicer).__init__() self._per_method_configs = per_method_configs self._qps = qps def Configure( - self, request: messages_pb2.ClientConfigureRequest, - context: grpc.ServicerContext + self, + request: messages_pb2.ClientConfigureRequest, + context: grpc.ServicerContext, ) -> messages_pb2.ClientConfigureResponse: logger.info("Received Configure RPC: %s", request) method_strs = [_METHOD_ENUM_TO_STR[t] for t in request.types] @@ -334,9 +366,11 @@ def Configure( channel_config = self._per_method_configs[method] if method in method_strs: qps = self._qps - metadata = ((md.key, md.value) - for md in request.metadata - if md.type == method_enum) + metadata = ( + (md.key, md.value) + for md in request.metadata + if md.type == method_enum + ) # For backward compatibility, do not change timeout when we # receive a default value timeout. 
if request.timeout_sec == 0: @@ -361,13 +395,15 @@ class _MethodHandle: _channel_threads: List[threading.Thread] - def __init__(self, num_channels: int, - channel_config: _ChannelConfiguration): + def __init__( + self, num_channels: int, channel_config: _ChannelConfiguration + ): """Creates and starts a group of threads running the indicated method.""" self._channel_threads = [] for i in range(num_channels): - thread = threading.Thread(target=_run_single_channel, - args=(channel_config,)) + thread = threading.Thread( + target=_run_single_channel, args=(channel_config,) + ) thread.start() self._channel_threads.append(thread) @@ -377,8 +413,11 @@ def stop(self) -> None: channel_thread.join() -def _run(args: argparse.Namespace, methods: Sequence[str], - per_method_metadata: PerMethodMetadataType) -> None: +def _run( + args: argparse.Namespace, + methods: Sequence[str], + per_method_metadata: PerMethodMetadataType, +) -> None: logger.info("Starting python xDS Interop Client.") global _global_server # pylint: disable=global-statement method_handles = [] @@ -389,17 +428,25 @@ def _run(args: argparse.Namespace, methods: Sequence[str], else: qps = 0 channel_config = _ChannelConfiguration( - method, per_method_metadata.get(method, []), qps, args.server, - args.rpc_timeout_sec, args.print_response, args.secure_mode) + method, + per_method_metadata.get(method, []), + qps, + args.server, + args.rpc_timeout_sec, + args.print_response, + args.secure_mode, + ) channel_configs[method] = channel_config method_handles.append(_MethodHandle(args.num_channels, channel_config)) _global_server = grpc.server(futures.ThreadPoolExecutor()) _global_server.add_insecure_port(f"0.0.0.0:{args.stats_port}") test_pb2_grpc.add_LoadBalancerStatsServiceServicer_to_server( - _LoadBalancerStatsServicer(), _global_server) + _LoadBalancerStatsServicer(), _global_server + ) test_pb2_grpc.add_XdsUpdateClientConfigureServiceServicer_to_server( _XdsUpdateClientConfigureServicer(channel_configs, args.qps), - _global_server) + _global_server, + ) channelz.add_channelz_servicer(_global_server) grpc_admin.add_admin_servicers(_global_server) _global_server.start() @@ -415,7 +462,8 @@ def parse_metadata_arg(metadata_arg: str) -> PerMethodMetadataType: elems = metadatum.split(":") if len(elems) != 3: raise ValueError( - f"'{metadatum}' was not in the form 'METHOD:KEY:VALUE'") + f"'{metadatum}' was not in the form 'METHOD:KEY:VALUE'" + ) if elems[0] not in _SUPPORTED_METHODS: raise ValueError(f"Unrecognized method '{elems[0]}'") per_method_metadata[elems[0]].append((elems[1], elems[2])) @@ -425,8 +473,9 @@ def parse_metadata_arg(metadata_arg: str) -> PerMethodMetadataType: def parse_rpc_arg(rpc_arg: str) -> Sequence[str]: methods = rpc_arg.split(",") if set(methods) - set(_SUPPORTED_METHODS): - raise ValueError("--rpc supported methods: {}".format( - ", ".join(_SUPPORTED_METHODS))) + raise ValueError( + "--rpc supported methods: {}".format(", ".join(_SUPPORTED_METHODS)) + ) return methods @@ -441,61 +490,72 @@ def bool_arg(arg: str) -> bool: if __name__ == "__main__": parser = argparse.ArgumentParser( - description='Run Python XDS interop client.') + description="Run Python XDS interop client." 
+ ) parser.add_argument( "--num_channels", default=1, type=int, - help="The number of channels from which to send requests.") - parser.add_argument("--print_response", - default="False", - type=bool_arg, - help="Write RPC response to STDOUT.") + help="The number of channels from which to send requests.", + ) + parser.add_argument( + "--print_response", + default="False", + type=bool_arg, + help="Write RPC response to STDOUT.", + ) parser.add_argument( "--qps", default=1, type=int, - help="The number of queries to send from each channel per second.") - parser.add_argument("--rpc_timeout_sec", - default=30, - type=int, - help="The per-RPC timeout in seconds.") - parser.add_argument("--server", - default="localhost:50051", - help="The address of the server.") + help="The number of queries to send from each channel per second.", + ) + parser.add_argument( + "--rpc_timeout_sec", + default=30, + type=int, + help="The per-RPC timeout in seconds.", + ) + parser.add_argument( + "--server", default="localhost:50051", help="The address of the server." + ) parser.add_argument( "--stats_port", default=50052, type=int, - help="The port on which to expose the peer distribution stats service.") + help="The port on which to expose the peer distribution stats service.", + ) parser.add_argument( "--secure_mode", default="False", type=bool_arg, - help="If specified, uses xDS credentials to connect to the server.") - parser.add_argument('--verbose', - help='verbose log output', - default=False, - action='store_true') - parser.add_argument("--log_file", - default=None, - type=str, - help="A file to log to.") + help="If specified, uses xDS credentials to connect to the server.", + ) + parser.add_argument( + "--verbose", + help="verbose log output", + default=False, + action="store_true", + ) + parser.add_argument( + "--log_file", default=None, type=str, help="A file to log to." + ) rpc_help = "A comma-delimited list of RPC methods to run. Must be one of " rpc_help += ", ".join(_SUPPORTED_METHODS) rpc_help += "." parser.add_argument("--rpc", default="UnaryCall", type=str, help=rpc_help) metadata_help = ( - "A comma-delimited list of 3-tuples of the form " + - "METHOD:KEY:VALUE, e.g. " + - "EmptyCall:key1:value1,UnaryCall:key2:value2,EmptyCall:k3:v3") + "A comma-delimited list of 3-tuples of the form " + + "METHOD:KEY:VALUE, e.g. 
" + + "EmptyCall:key1:value1,UnaryCall:key2:value2,EmptyCall:k3:v3" + ) parser.add_argument("--metadata", default="", type=str, help=metadata_help) args = parser.parse_args() signal.signal(signal.SIGINT, _handle_sigint) if args.verbose: logger.setLevel(logging.DEBUG) if args.log_file: - file_handler = logging.FileHandler(args.log_file, mode='a') + file_handler = logging.FileHandler(args.log_file, mode="a") file_handler.setFormatter(formatter) logger.addHandler(file_handler) _run(args, parse_rpc_arg(args.rpc), parse_metadata_arg(args.metadata)) diff --git a/src/python/grpcio_tests/tests_py3_only/interop/xds_interop_client_test.py b/src/python/grpcio_tests/tests_py3_only/interop/xds_interop_client_test.py index 98f6e388b191c..17419fe2d77b8 100644 --- a/src/python/grpcio_tests/tests_py3_only/interop/xds_interop_client_test.py +++ b/src/python/grpcio_tests/tests_py3_only/interop/xds_interop_client_test.py @@ -59,41 +59,49 @@ def _set_union(a: Iterable, b: Iterable) -> Set: def _start_python_with_args( file: str, args: List[str] ) -> Tuple[subprocess.Popen, tempfile.TemporaryFile, tempfile.TemporaryFile]: - with tempfile.TemporaryFile(mode='r') as stdout: - with tempfile.TemporaryFile(mode='r') as stderr: - proc = subprocess.Popen((sys.executable, file) + tuple(args), - stdout=stdout, - stderr=stderr) + with tempfile.TemporaryFile(mode="r") as stdout: + with tempfile.TemporaryFile(mode="r") as stderr: + proc = subprocess.Popen( + (sys.executable, file) + tuple(args), + stdout=stdout, + stderr=stderr, + ) yield proc, stdout, stderr -def _dump_stream(process_name: str, stream_name: str, - stream: tempfile.TemporaryFile): +def _dump_stream( + process_name: str, stream_name: str, stream: tempfile.TemporaryFile +): sys.stderr.write(f"{process_name} {stream_name}:\n") stream.seek(0) sys.stderr.write(stream.read()) -def _dump_streams(process_name: str, stdout: tempfile.TemporaryFile, - stderr: tempfile.TemporaryFile): +def _dump_streams( + process_name: str, + stdout: tempfile.TemporaryFile, + stderr: tempfile.TemporaryFile, +): _dump_stream(process_name, "stdout", stdout) _dump_stream(process_name, "stderr", stderr) sys.stderr.write(f"End {process_name} output.\n") def _index_accumulated_stats( - response: messages_pb2.LoadBalancerAccumulatedStatsResponse + response: messages_pb2.LoadBalancerAccumulatedStatsResponse, ) -> Mapping[str, Mapping[int, int]]: indexed = collections.defaultdict(lambda: collections.defaultdict(int)) for _, method_str in _METHODS: for status in response.stats_per_method[method_str].result.keys(): indexed[method_str][status] = response.stats_per_method[ - method_str].result[status] + method_str + ].result[status] return indexed -def _subtract_indexed_stats(a: Mapping[str, Mapping[int, int]], - b: Mapping[str, Mapping[int, int]]): +def _subtract_indexed_stats( + a: Mapping[str, Mapping[int, int]], b: Mapping[str, Mapping[int, int]] +): c = collections.defaultdict(lambda: collections.defaultdict(int)) all_methods = _set_union(a.keys(), b.keys()) for method in all_methods: @@ -103,26 +111,29 @@ def _subtract_indexed_stats(a: Mapping[str, Mapping[int, int]], return c -def _collect_stats(stats_port: int, - duration: int) -> Mapping[str, Mapping[int, int]]: +def _collect_stats( + stats_port: int, duration: int +) -> Mapping[str, Mapping[int, int]]: settings = { "target": f"localhost:{stats_port}", "insecure": True, } response = test_pb2_grpc.LoadBalancerStatsService.GetClientAccumulatedStats( - messages_pb2.LoadBalancerAccumulatedStatsRequest(), **settings) + 
messages_pb2.LoadBalancerAccumulatedStatsRequest(), **settings + ) before = _index_accumulated_stats(response) time.sleep(duration) response = test_pb2_grpc.LoadBalancerStatsService.GetClientAccumulatedStats( - messages_pb2.LoadBalancerAccumulatedStatsRequest(), **settings) + messages_pb2.LoadBalancerAccumulatedStatsRequest(), **settings + ) after = _index_accumulated_stats(response) return _subtract_indexed_stats(after, before) class XdsInteropClientTest(unittest.TestCase): - - def _assert_client_consistent(self, server_port: int, stats_port: int, - qps: int, num_channels: int): + def _assert_client_consistent( + self, server_port: int, stats_port: int, qps: int, num_channels: int + ): settings = { "target": f"localhost:{stats_port}", "insecure": True, @@ -131,7 +142,8 @@ def _assert_client_consistent(self, server_port: int, stats_port: int, target_method, target_method_str = _METHODS[i % len(_METHODS)] test_pb2_grpc.XdsUpdateClientConfigureService.Configure( messages_pb2.ClientConfigureRequest(types=[target_method]), - **settings) + **settings, + ) delta = _collect_stats(stats_port, _ITERATION_DURATION_SECONDS) logging.info("Delta: %s", delta) for _, method_str in _METHODS: @@ -145,27 +157,34 @@ def test_configure_consistency(self): _, server_port, socket = framework_common.get_socket() with _start_python_with_args( - _SERVER_PATH, - [f"--port={server_port}", f"--maintenance_port={server_port}" - ]) as (server, server_stdout, server_stderr): + _SERVER_PATH, + [f"--port={server_port}", f"--maintenance_port={server_port}"], + ) as (server, server_stdout, server_stderr): # Send RPC to server to make sure it's running. logging.info("Sending RPC to server.") - test_pb2_grpc.TestService.EmptyCall(empty_pb2.Empty(), - f"localhost:{server_port}", - insecure=True, - wait_for_ready=True) + test_pb2_grpc.TestService.EmptyCall( + empty_pb2.Empty(), + f"localhost:{server_port}", + insecure=True, + wait_for_ready=True, + ) logging.info("Server successfully started.") socket.close() _, stats_port, stats_socket = framework_common.get_socket() - with _start_python_with_args(_CLIENT_PATH, [ + with _start_python_with_args( + _CLIENT_PATH, + [ f"--server=localhost:{server_port}", - f"--stats_port={stats_port}", f"--qps={_QPS}", - f"--num_channels={_NUM_CHANNELS}" - ]) as (client, client_stdout, client_stderr): + f"--stats_port={stats_port}", + f"--qps={_QPS}", + f"--num_channels={_NUM_CHANNELS}", + ], + ) as (client, client_stdout, client_stderr): stats_socket.close() try: - self._assert_client_consistent(server_port, stats_port, - _QPS, _NUM_CHANNELS) + self._assert_client_consistent( + server_port, stats_port, _QPS, _NUM_CHANNELS + ) except: _dump_streams("server", server_stdout, server_stderr) _dump_streams("client", client_stdout, client_stderr) @@ -177,6 +196,6 @@ def test_configure_consistency(self): client.wait(timeout=_SUBPROCESS_TIMEOUT_SECONDS) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_py3_only/interop/xds_interop_server.py b/src/python/grpcio_tests/tests_py3_only/interop/xds_interop_server.py index d5e1ce5c71876..2e38fc58533d0 100644 --- a/src/python/grpcio_tests/tests_py3_only/interop/xds_interop_server.py +++ b/src/python/grpcio_tests/tests_py3_only/interop/xds_interop_server.py @@ -46,40 +46,44 @@ logger = logging.getLogger() console_handler = logging.StreamHandler() -formatter = logging.Formatter(fmt='%(asctime)s: %(levelname)-8s %(message)s') +formatter = 
logging.Formatter(fmt="%(asctime)s: %(levelname)-8s %(message)s") console_handler.setFormatter(formatter) logger.addHandler(console_handler) class TestService(test_pb2_grpc.TestServiceServicer): - def __init__(self, server_id, hostname): self._server_id = server_id self._hostname = hostname - def EmptyCall(self, _: empty_pb2.Empty, - context: grpc.ServicerContext) -> empty_pb2.Empty: - context.send_initial_metadata((('hostname', self._hostname),)) + def EmptyCall( + self, _: empty_pb2.Empty, context: grpc.ServicerContext + ) -> empty_pb2.Empty: + context.send_initial_metadata((("hostname", self._hostname),)) return empty_pb2.Empty() - def UnaryCall(self, request: messages_pb2.SimpleRequest, - context: grpc.ServicerContext) -> messages_pb2.SimpleResponse: - context.send_initial_metadata((('hostname', self._hostname),)) + def UnaryCall( + self, request: messages_pb2.SimpleRequest, context: grpc.ServicerContext + ) -> messages_pb2.SimpleResponse: + context.send_initial_metadata((("hostname", self._hostname),)) response = messages_pb2.SimpleResponse() response.server_id = self._server_id response.hostname = self._hostname return response -def _configure_maintenance_server(server: grpc.Server, - maintenance_port: int) -> None: +def _configure_maintenance_server( + server: grpc.Server, maintenance_port: int +) -> None: channelz.add_channelz_servicer(server) listen_address = f"{_LISTEN_HOST}:{maintenance_port}" server.add_insecure_port(listen_address) health_servicer = grpc_health.HealthServicer( experimental_non_blocking=True, experimental_thread_pool=futures.ThreadPoolExecutor( - max_workers=_THREAD_POOL_SIZE)) + max_workers=_THREAD_POOL_SIZE + ), + ) health_pb2_grpc.add_HealthServicer_to_server(health_servicer, server) SERVICE_NAMES = ( @@ -93,10 +97,12 @@ def _configure_maintenance_server(server: grpc.Server, reflection.enable_server_reflection(SERVICE_NAMES, server) -def _configure_test_server(server: grpc.Server, port: int, secure_mode: bool, - server_id: str) -> None: +def _configure_test_server( + server: grpc.Server, port: int, secure_mode: bool, server_id: str +) -> None: test_pb2_grpc.add_TestServiceServicer_to_server( - TestService(server_id, socket.gethostname()), server) + TestService(server_id, socket.gethostname()), server + ) listen_address = f"{_LISTEN_HOST}:{port}" if not secure_mode: server.add_insecure_port(listen_address) @@ -107,11 +113,13 @@ def _configure_test_server(server: grpc.Server, port: int, secure_mode: bool, server.add_secure_port(listen_address, server_creds) -def _run(port: int, maintenance_port: int, secure_mode: bool, - server_id: str) -> None: +def _run( + port: int, maintenance_port: int, secure_mode: bool, server_id: str +) -> None: if port == maintenance_port: server = grpc.server( - futures.ThreadPoolExecutor(max_workers=_THREAD_POOL_SIZE)) + futures.ThreadPoolExecutor(max_workers=_THREAD_POOL_SIZE) + ) _configure_test_server(server, port, secure_mode, server_id) _configure_maintenance_server(server, maintenance_port) server.start() @@ -120,13 +128,15 @@ def _run(port: int, maintenance_port: int, secure_mode: bool, server.wait_for_termination() else: maintenance_server = grpc.server( - futures.ThreadPoolExecutor(max_workers=_THREAD_POOL_SIZE)) + futures.ThreadPoolExecutor(max_workers=_THREAD_POOL_SIZE) + ) _configure_maintenance_server(maintenance_server, maintenance_port) maintenance_server.start() logger.info("Maintenance server listening on port %d", maintenance_port) test_server = grpc.server( 
futures.ThreadPoolExecutor(max_workers=_THREAD_POOL_SIZE), - xds=secure_mode) + xds=secure_mode, + ) _configure_test_server(test_server, port, secure_mode, server_id) test_server.start() logger.info("Test server listening on port %d", port) @@ -145,28 +155,35 @@ def bool_arg(arg: str) -> bool: if __name__ == "__main__": parser = argparse.ArgumentParser( - description="Run Python xDS interop server.") - parser.add_argument("--port", - type=int, - default=8080, - help="Port for test server.") - parser.add_argument("--maintenance_port", - type=int, - default=8080, - help="Port for servers besides test server.") + description="Run Python xDS interop server." + ) + parser.add_argument( + "--port", type=int, default=8080, help="Port for test server." + ) + parser.add_argument( + "--maintenance_port", + type=int, + default=8080, + help="Port for servers besides test server.", + ) parser.add_argument( "--secure_mode", type=bool_arg, default="False", - help="If specified, uses xDS to retrieve server credentials.") - parser.add_argument("--server_id", - type=str, - default="python_server", - help="The server ID to return in responses..") - parser.add_argument('--verbose', - help='verbose log output', - default=False, - action='store_true') + help="If specified, uses xDS to retrieve server credentials.", + ) + parser.add_argument( + "--server_id", + type=str, + default="python_server", + help="The server ID to return in responses..", + ) + parser.add_argument( + "--verbose", + help="verbose log output", + default=False, + action="store_true", + ) args = parser.parse_args() if args.verbose: logger.setLevel(logging.DEBUG) @@ -174,6 +191,7 @@ def bool_arg(arg: str) -> bool: logger.setLevel(logging.INFO) if args.secure_mode and args.port == args.maintenance_port: raise ValueError( - "--port and --maintenance_port must not be the same when --secure_mode is set." + "--port and --maintenance_port must not be the same when" + " --secure_mode is set." ) _run(args.port, args.maintenance_port, args.secure_mode, args.server_id) diff --git a/src/python/grpcio_tests/tests_py3_only/unit/_leak_test.py b/src/python/grpcio_tests/tests_py3_only/unit/_leak_test.py index f8b8382e31c4b..c917bd1052141 100644 --- a/src/python/grpcio_tests/tests_py3_only/unit/_leak_test.py +++ b/src/python/grpcio_tests/tests_py3_only/unit/_leak_test.py @@ -27,8 +27,8 @@ import grpc -_TEST_METHOD = '/test/Test' -_REQUEST = b'\x23\x33' +_TEST_METHOD = "/test/Test" +_REQUEST = b"\x23\x33" _LARGE_NUM_OF_ITERATIONS = 5000 # If MAX_RSS inflated more than this size, the test is failed. 
@@ -51,19 +51,19 @@ def _pretty_print_bytes(x): class _GenericHandler(grpc.GenericRpcHandler): - def service(self, handler_call_details): if handler_call_details.method == _TEST_METHOD: return grpc.unary_unary_rpc_method_handler(lambda x, _: x) def _start_a_test_server(): - server = grpc.server(ThreadPoolExecutor(max_workers=1), - options=(('grpc.so_reuseport', 0),)) + server = grpc.server( + ThreadPoolExecutor(max_workers=1), options=(("grpc.so_reuseport", 0),) + ) server.add_generic_rpc_handlers((_GenericHandler(),)) - port = server.add_insecure_port('localhost:0') + port = server.add_insecure_port("localhost:0") server.start() - return 'localhost:%d' % port, server + return "localhost:%d" % port, server def _perform_an_rpc(address): @@ -74,7 +74,6 @@ def _perform_an_rpc(address): class TestLeak(unittest.TestCase): - def test_leak_with_single_shot_rpcs(self): address, server = _start_a_test_server() @@ -88,9 +87,12 @@ def test_leak_with_single_shot_rpcs(self): # Fails the test if memory leak detected. diff = _get_max_rss() - before if diff > _FAIL_THRESHOLD: - self.fail("Max RSS inflated {} > {}".format( - _pretty_print_bytes(diff), - _pretty_print_bytes(_FAIL_THRESHOLD))) + self.fail( + "Max RSS inflated {} > {}".format( + _pretty_print_bytes(diff), + _pretty_print_bytes(_FAIL_THRESHOLD), + ) + ) if __name__ == "__main__": diff --git a/src/python/grpcio_tests/tests_py3_only/unit/_simple_stubs_test.py b/src/python/grpcio_tests/tests_py3_only/unit/_simple_stubs_test.py index f3e9a71a44540..771097936f693 100644 --- a/src/python/grpcio_tests/tests_py3_only/unit/_simple_stubs_test.py +++ b/src/python/grpcio_tests/tests_py3_only/unit/_simple_stubs_test.py @@ -100,7 +100,6 @@ def _on_done(): class _GenericHandler(grpc.GenericRpcHandler): - def service(self, handler_call_details): if handler_call_details.method == _UNARY_UNARY: return grpc.unary_unary_rpc_method_handler(_unary_unary_handler) @@ -126,7 +125,7 @@ def _time_invocation(to_time: Callable[[], None]) -> datetime.timedelta: def _server(credentials: Optional[grpc.ServerCredentials]): try: server = test_common.test_server() - target = '[::]:0' + target = "[::]:0" if credentials is None: port = server.add_insecure_port(target) else: @@ -139,7 +138,6 @@ def _server(credentials: Optional[grpc.ServerCredentials]): class SimpleStubsTest(unittest.TestCase): - def assert_cached(self, to_check: Callable[[str], None]) -> None: """Asserts that a function caches intermediate data/state. 
@@ -161,17 +159,21 @@ def assert_cached(self, to_check: Callable[[str], None]) -> None: runs.append(_time_invocation(lambda: to_check(text))) initial_runs.append(runs[0]) cached_runs.extend(runs[1:]) - average_cold = sum((run for run in initial_runs), - datetime.timedelta()) / len(initial_runs) - average_warm = sum((run for run in cached_runs), - datetime.timedelta()) / len(cached_runs) + average_cold = sum( + (run for run in initial_runs), datetime.timedelta() + ) / len(initial_runs) + average_warm = sum( + (run for run in cached_runs), datetime.timedelta() + ) / len(cached_runs) self.assertLess(average_warm, average_cold) - def assert_eventually(self, - predicate: Callable[[], bool], - *, - timeout: Optional[datetime.timedelta] = None, - message: Optional[Callable[[], str]] = None) -> None: + def assert_eventually( + self, + predicate: Callable[[], bool], + *, + timeout: Optional[datetime.timedelta] = None, + message: Optional[Callable[[], str]] = None, + ) -> None: message = message or (lambda: "Proposition did not evaluate to true") timeout = timeout or datetime.timedelta(seconds=10) end = datetime.datetime.now() + timeout @@ -184,30 +186,31 @@ def assert_eventually(self, def test_unary_unary_insecure(self): with _server(None) as port: - target = f'localhost:{port}' + target = f"localhost:{port}" response = grpc.experimental.unary_unary( _REQUEST, target, _UNARY_UNARY, - channel_credentials=grpc.experimental. - insecure_channel_credentials(), - timeout=None) + channel_credentials=grpc.experimental.insecure_channel_credentials(), + timeout=None, + ) self.assertEqual(_REQUEST, response) def test_unary_unary_secure(self): with _server(grpc.local_server_credentials()) as port: - target = f'localhost:{port}' + target = f"localhost:{port}" response = grpc.experimental.unary_unary( _REQUEST, target, _UNARY_UNARY, channel_credentials=grpc.local_channel_credentials(), - timeout=None) + timeout=None, + ) self.assertEqual(_REQUEST, response) def test_channels_cached(self): with _server(grpc.local_server_credentials()) as port: - target = f'localhost:{port}' + target = f"localhost:{port}" test_name = inspect.stack()[0][3] args = (_REQUEST, target, _UNARY_UNARY) kwargs = {"channel_credentials": grpc.local_channel_credentials()} @@ -221,22 +224,22 @@ def _invoke(seed: str): def test_channels_evicted(self): with _server(grpc.local_server_credentials()) as port: - target = f'localhost:{port}' + target = f"localhost:{port}" response = grpc.experimental.unary_unary( _REQUEST, target, _UNARY_UNARY, - channel_credentials=grpc.local_channel_credentials()) + channel_credentials=grpc.local_channel_credentials(), + ) self.assert_eventually( - lambda: grpc._simple_stubs.ChannelCache.get( - )._test_only_channel_count() == 0, - message=lambda: - f"{grpc._simple_stubs.ChannelCache.get()._test_only_channel_count()} remain" + lambda: grpc._simple_stubs.ChannelCache.get()._test_only_channel_count() + == 0, + message=lambda: f"{grpc._simple_stubs.ChannelCache.get()._test_only_channel_count()} remain", ) def test_total_channels_enforced(self): with _server(grpc.local_server_credentials()) as port: - target = f'localhost:{port}' + target = f"localhost:{port}" for i in range(_STRESS_EPOCHS): # Ensure we get a new channel each time. 
options = (("foo", str(i)),) @@ -246,100 +249,106 @@ def test_total_channels_enforced(self): target, _UNARY_UNARY, options=options, - channel_credentials=grpc.local_channel_credentials()) + channel_credentials=grpc.local_channel_credentials(), + ) self.assert_eventually( - lambda: grpc._simple_stubs.ChannelCache.get( - )._test_only_channel_count() <= _MAXIMUM_CHANNELS + 1, - message=lambda: - f"{grpc._simple_stubs.ChannelCache.get()._test_only_channel_count()} channels remain" + lambda: grpc._simple_stubs.ChannelCache.get()._test_only_channel_count() + <= _MAXIMUM_CHANNELS + 1, + message=lambda: f"{grpc._simple_stubs.ChannelCache.get()._test_only_channel_count()} channels remain", ) def test_unary_stream(self): with _server(grpc.local_server_credentials()) as port: - target = f'localhost:{port}' + target = f"localhost:{port}" for response in grpc.experimental.unary_stream( - _REQUEST, - target, - _UNARY_STREAM, - channel_credentials=grpc.local_channel_credentials()): + _REQUEST, + target, + _UNARY_STREAM, + channel_credentials=grpc.local_channel_credentials(), + ): self.assertEqual(_REQUEST, response) def test_stream_unary(self): - def request_iter(): for _ in range(_CLIENT_REQUEST_COUNT): yield _REQUEST with _server(grpc.local_server_credentials()) as port: - target = f'localhost:{port}' + target = f"localhost:{port}" response = grpc.experimental.stream_unary( request_iter(), target, _STREAM_UNARY, - channel_credentials=grpc.local_channel_credentials()) + channel_credentials=grpc.local_channel_credentials(), + ) self.assertEqual(_REQUEST, response) def test_stream_stream(self): - def request_iter(): for _ in range(_CLIENT_REQUEST_COUNT): yield _REQUEST with _server(grpc.local_server_credentials()) as port: - target = f'localhost:{port}' + target = f"localhost:{port}" for response in grpc.experimental.stream_stream( - request_iter(), - target, - _STREAM_STREAM, - channel_credentials=grpc.local_channel_credentials()): + request_iter(), + target, + _STREAM_STREAM, + channel_credentials=grpc.local_channel_credentials(), + ): self.assertEqual(_REQUEST, response) def test_default_ssl(self): _private_key = resources.private_key() _certificate_chain = resources.certificate_chain() _server_certs = ((_private_key, _certificate_chain),) - _server_host_override = 'foo.test.google.fr' + _server_host_override = "foo.test.google.fr" _test_root_certificates = resources.test_root_certificates() - _property_options = (( - 'grpc.ssl_target_name_override', - _server_host_override, - ),) - cert_dir = os.path.join(os.path.dirname(resources.__file__), - "credentials") + _property_options = ( + ( + "grpc.ssl_target_name_override", + _server_host_override, + ), + ) + cert_dir = os.path.join( + os.path.dirname(resources.__file__), "credentials" + ) cert_file = os.path.join(cert_dir, "ca.pem") with _env("GRPC_DEFAULT_SSL_ROOTS_FILE_PATH", cert_file): server_creds = grpc.ssl_server_credentials(_server_certs) with _server(server_creds) as port: - target = f'localhost:{port}' + target = f"localhost:{port}" response = grpc.experimental.unary_unary( - _REQUEST, target, _UNARY_UNARY, options=_property_options) + _REQUEST, target, _UNARY_UNARY, options=_property_options + ) def test_insecure_sugar(self): with _server(None) as port: - target = f'localhost:{port}' - response = grpc.experimental.unary_unary(_REQUEST, - target, - _UNARY_UNARY, - insecure=True) + target = f"localhost:{port}" + response = grpc.experimental.unary_unary( + _REQUEST, target, _UNARY_UNARY, insecure=True + ) self.assertEqual(_REQUEST, response) def 
test_insecure_sugar_mutually_exclusive(self): with _server(None) as port: - target = f'localhost:{port}' + target = f"localhost:{port}" with self.assertRaises(ValueError): response = grpc.experimental.unary_unary( _REQUEST, target, _UNARY_UNARY, insecure=True, - channel_credentials=grpc.local_channel_credentials()) + channel_credentials=grpc.local_channel_credentials(), + ) def test_default_wait_for_ready(self): addr, port, sock = get_socket() sock.close() - target = f'{addr}:{port}' + target = f"{addr}:{port}" channel = grpc._simple_stubs.ChannelCache.get().get_channel( - target, (), None, True, None) + target, (), None, True, None + ) rpc_finished_event = threading.Event() rpc_failed_event = threading.Event() server = None @@ -354,8 +363,10 @@ def _on_connectivity_changed(connectivity): server.add_generic_rpc_handlers((_GenericHandler(),)) server.start() channel.unsubscribe(_on_connectivity_changed) - elif connectivity in (grpc.ChannelConnectivity.IDLE, - grpc.ChannelConnectivity.CONNECTING): + elif connectivity in ( + grpc.ChannelConnectivity.IDLE, + grpc.ChannelConnectivity.CONNECTING, + ): pass else: self.fail("Encountered unknown state.") @@ -364,11 +375,9 @@ def _on_connectivity_changed(connectivity): def _send_rpc(): try: - response = grpc.experimental.unary_unary(_REQUEST, - target, - _UNARY_UNARY, - timeout=None, - insecure=True) + response = grpc.experimental.unary_unary( + _REQUEST, target, _UNARY_UNARY, timeout=None, insecure=True + ) rpc_finished_event.set() except Exception as e: rpc_failed_event.set() @@ -383,15 +392,18 @@ def _send_rpc(): def assert_times_out(self, invocation_args): with _server(None) as port: - target = f'localhost:{port}' + target = f"localhost:{port}" with self.assertRaises(grpc.RpcError) as cm: - response = grpc.experimental.unary_unary(_REQUEST, - target, - _BLACK_HOLE, - insecure=True, - **invocation_args) - self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, - cm.exception.code()) + response = grpc.experimental.unary_unary( + _REQUEST, + target, + _BLACK_HOLE, + insecure=True, + **invocation_args, + ) + self.assertEqual( + grpc.StatusCode.DEADLINE_EXCEEDED, cm.exception.code() + ) def test_default_timeout(self): not_present = object() diff --git a/src/re2/gen_build_yaml.py b/src/re2/gen_build_yaml.py index cd66a13e46e8f..063bb4fa58e3c 100755 --- a/src/re2/gen_build_yaml.py +++ b/src/re2/gen_build_yaml.py @@ -19,30 +19,30 @@ import glob import yaml -os.chdir(os.path.dirname(sys.argv[0]) + '/../..') +os.chdir(os.path.dirname(sys.argv[0]) + "/../..") out = {} -out['libs'] = [{ - #TODO @donnadionne: extracting the list of source files from bazel build to reduce duplication - 'name': - 're2', - 'build': - 'private', - 'language': - 'c', - 'secure': - False, - 'src': - sorted( - glob.glob('third_party/re2/re2/*.cc') + [ - "third_party/re2/util/pcre.cc", "third_party/re2/util/rune.cc", - "third_party/re2/util/strutil.cc" - ]), - 'headers': - sorted( - glob.glob('third_party/re2/re2/*.h') + - glob.glob('third_party/re2/util/*.h')), -}] +out["libs"] = [ + { + # TODO @donnadionne: extracting the list of source files from bazel build to reduce duplication + "name": "re2", + "build": "private", + "language": "c", + "secure": False, + "src": sorted( + glob.glob("third_party/re2/re2/*.cc") + + [ + "third_party/re2/util/pcre.cc", + "third_party/re2/util/rune.cc", + "third_party/re2/util/strutil.cc", + ] + ), + "headers": sorted( + glob.glob("third_party/re2/re2/*.h") + + glob.glob("third_party/re2/util/*.h") + ), + } +] print(yaml.dump(out)) diff --git 
a/src/upb/gen_build_yaml.py b/src/upb/gen_build_yaml.py index 42a7128ae7074..b7824001bdb9e 100755 --- a/src/upb/gen_build_yaml.py +++ b/src/upb/gen_build_yaml.py @@ -24,173 +24,175 @@ out = {} try: - out['libs'] = [{ - 'name': 'upb', - 'build': 'all', - 'language': 'c', - 'src': [ - "third_party/utf8_range/naive.c", - "third_party/utf8_range/range2-neon.c", - "third_party/utf8_range/range2-sse.c", - "third_party/upb/upb/base/status.c", - "third_party/upb/upb/collections/array.c", - "third_party/upb/upb/collections/map_sorter.c", - "third_party/upb/upb/collections/map.c", - "third_party/upb/upb/hash/common.c", - "third_party/upb/upb/json/decode.c", - "third_party/upb/upb/json/encode.c", - "third_party/upb/upb/lex/atoi.c", - "third_party/upb/upb/lex/round_trip.c", - "third_party/upb/upb/lex/strtod.c", - "third_party/upb/upb/lex/unicode.c", - "third_party/upb/upb/mem/alloc.c", - "third_party/upb/upb/mem/arena.c", - "third_party/upb/upb/message/accessors.c", - "third_party/upb/upb/message/message.c", - "third_party/upb/upb/mini_table/common.c", - "third_party/upb/upb/mini_table/decode.c", - "third_party/upb/upb/mini_table/encode.c", - "third_party/upb/upb/mini_table/extension_registry.c", - "third_party/upb/upb/reflection/def_builder.c", - "third_party/upb/upb/reflection/def_pool.c", - "third_party/upb/upb/reflection/def_type.c", - "third_party/upb/upb/reflection/desc_state.c", - "third_party/upb/upb/reflection/enum_def.c", - "third_party/upb/upb/reflection/enum_reserved_range.c", - "third_party/upb/upb/reflection/enum_value_def.c", - "third_party/upb/upb/reflection/extension_range.c", - "third_party/upb/upb/reflection/field_def.c", - "third_party/upb/upb/reflection/file_def.c", - "third_party/upb/upb/reflection/message_def.c", - "third_party/upb/upb/reflection/message_reserved_range.c", - "third_party/upb/upb/reflection/message.c", - "third_party/upb/upb/reflection/method_def.c", - "third_party/upb/upb/reflection/oneof_def.c", - "third_party/upb/upb/reflection/service_def.c", - "third_party/upb/upb/text/encode.c", - "third_party/upb/upb/wire/decode_fast.c", - "third_party/upb/upb/wire/decode.c", - "third_party/upb/upb/wire/encode.c", - "third_party/upb/upb/wire/eps_copy_input_stream.c", - "third_party/upb/upb/wire/reader.c", - "src/core/ext/upb-generated/google/protobuf/descriptor.upb.c", - "src/core/ext/upbdefs-generated/google/protobuf/descriptor.upbdefs.c", - ], - 'headers': [ - "third_party/utf8_range/utf8_range.h", - "third_party/upb/upb/alloc.h", - "third_party/upb/upb/arena.h", - "third_party/upb/upb/array.h", - "third_party/upb/upb/base/descriptor_constants.h", - "third_party/upb/upb/base/log2.h", - "third_party/upb/upb/base/status.h", - "third_party/upb/upb/base/string_view.h", - "third_party/upb/upb/collections/array_internal.h", - "third_party/upb/upb/collections/array.h", - "third_party/upb/upb/collections/map_gencode_util.h", - "third_party/upb/upb/collections/map_internal.h", - "third_party/upb/upb/collections/map_sorter_internal.h", - "third_party/upb/upb/collections/map.h", - "third_party/upb/upb/collections/message_value.h", - "third_party/upb/upb/decode.h", - "third_party/upb/upb/def.h", - "third_party/upb/upb/def.hpp", - "third_party/upb/upb/encode.h", - "third_party/upb/upb/extension_registry.h", - "third_party/upb/upb/hash/common.h", - "third_party/upb/upb/hash/int_table.h", - "third_party/upb/upb/hash/str_table.h", - "third_party/upb/upb/json_decode.h", - "third_party/upb/upb/json_encode.h", - "third_party/upb/upb/json/decode.h", - "third_party/upb/upb/json/encode.h", - 
"third_party/upb/upb/lex/atoi.h", - "third_party/upb/upb/lex/round_trip.h", - "third_party/upb/upb/lex/strtod.h", - "third_party/upb/upb/lex/unicode.h", - "third_party/upb/upb/map.h", - "third_party/upb/upb/mem/alloc.h", - "third_party/upb/upb/mem/arena_internal.h", - "third_party/upb/upb/mem/arena.h", - "third_party/upb/upb/message/accessors_internal.h", - "third_party/upb/upb/message/accessors.h", - "third_party/upb/upb/message/extension_internal.h", - "third_party/upb/upb/message/internal.h", - "third_party/upb/upb/message/message.h", - "third_party/upb/upb/mini_table.h", - "third_party/upb/upb/mini_table/common_internal.h", - "third_party/upb/upb/mini_table/common.h", - "third_party/upb/upb/mini_table/decode.h", - "third_party/upb/upb/mini_table/encode_internal.h", - "third_party/upb/upb/mini_table/encode_internal.hpp", - "third_party/upb/upb/mini_table/enum_internal.h", - "third_party/upb/upb/mini_table/extension_internal.h", - "third_party/upb/upb/mini_table/extension_registry.h", - "third_party/upb/upb/mini_table/field_internal.h", - "third_party/upb/upb/mini_table/file_internal.h", - "third_party/upb/upb/mini_table/message_internal.h", - "third_party/upb/upb/mini_table/sub_internal.h", - "third_party/upb/upb/mini_table/types.h", - "third_party/upb/upb/msg.h", - "third_party/upb/upb/port/atomic.h", - "third_party/upb/upb/port/def.inc", - "third_party/upb/upb/port/undef.inc", - "third_party/upb/upb/port/vsnprintf_compat.h", - "third_party/upb/upb/reflection.h", - "third_party/upb/upb/reflection.hpp", - "third_party/upb/upb/reflection/common.h", - "third_party/upb/upb/reflection/def_builder_internal.h", - "third_party/upb/upb/reflection/def_pool_internal.h", - "third_party/upb/upb/reflection/def_pool.h", - "third_party/upb/upb/reflection/def_type.h", - "third_party/upb/upb/reflection/def.h", - "third_party/upb/upb/reflection/def.hpp", - "third_party/upb/upb/reflection/desc_state_internal.h", - "third_party/upb/upb/reflection/enum_def_internal.h", - "third_party/upb/upb/reflection/enum_def.h", - "third_party/upb/upb/reflection/enum_reserved_range_internal.h", - "third_party/upb/upb/reflection/enum_reserved_range.h", - "third_party/upb/upb/reflection/enum_value_def_internal.h", - "third_party/upb/upb/reflection/enum_value_def.h", - "third_party/upb/upb/reflection/extension_range_internal.h", - "third_party/upb/upb/reflection/extension_range.h", - "third_party/upb/upb/reflection/field_def_internal.h", - "third_party/upb/upb/reflection/field_def.h", - "third_party/upb/upb/reflection/file_def_internal.h", - "third_party/upb/upb/reflection/file_def.h", - "third_party/upb/upb/reflection/message_def_internal.h", - "third_party/upb/upb/reflection/message_def.h", - "third_party/upb/upb/reflection/message_reserved_range_internal.h", - "third_party/upb/upb/reflection/message_reserved_range.h", - "third_party/upb/upb/reflection/message.h", - "third_party/upb/upb/reflection/message.hpp", - "third_party/upb/upb/reflection/method_def_internal.h", - "third_party/upb/upb/reflection/method_def.h", - "third_party/upb/upb/reflection/oneof_def_internal.h", - "third_party/upb/upb/reflection/oneof_def.h", - "third_party/upb/upb/reflection/service_def_internal.h", - "third_party/upb/upb/reflection/service_def.h", - "third_party/upb/upb/status.h", - "third_party/upb/upb/string_view.h", - "third_party/upb/upb/text_encode.h", - "third_party/upb/upb/text/encode.h", - "third_party/upb/upb/upb.h", - "third_party/upb/upb/upb.hpp", - "third_party/upb/upb/wire/common_internal.h", - "third_party/upb/upb/wire/common.h", 
- "third_party/upb/upb/wire/decode_fast.h", - "third_party/upb/upb/wire/decode_internal.h", - "third_party/upb/upb/wire/decode.h", - "third_party/upb/upb/wire/encode.h", - "third_party/upb/upb/wire/eps_copy_input_stream.h", - "third_party/upb/upb/wire/reader.h", - "third_party/upb/upb/wire/swap_internal.h", - "third_party/upb/upb/wire/types.h", - "src/core/ext/upb-generated/google/protobuf/descriptor.upb.h", - "src/core/ext/upbdefs-generated/google/protobuf/descriptor.upbdefs.h", - ], - 'secure': False, - }] + out["libs"] = [ + { + "name": "upb", + "build": "all", + "language": "c", + "src": [ + "third_party/utf8_range/naive.c", + "third_party/utf8_range/range2-neon.c", + "third_party/utf8_range/range2-sse.c", + "third_party/upb/upb/base/status.c", + "third_party/upb/upb/collections/array.c", + "third_party/upb/upb/collections/map_sorter.c", + "third_party/upb/upb/collections/map.c", + "third_party/upb/upb/hash/common.c", + "third_party/upb/upb/json/decode.c", + "third_party/upb/upb/json/encode.c", + "third_party/upb/upb/lex/atoi.c", + "third_party/upb/upb/lex/round_trip.c", + "third_party/upb/upb/lex/strtod.c", + "third_party/upb/upb/lex/unicode.c", + "third_party/upb/upb/mem/alloc.c", + "third_party/upb/upb/mem/arena.c", + "third_party/upb/upb/message/accessors.c", + "third_party/upb/upb/message/message.c", + "third_party/upb/upb/mini_table/common.c", + "third_party/upb/upb/mini_table/decode.c", + "third_party/upb/upb/mini_table/encode.c", + "third_party/upb/upb/mini_table/extension_registry.c", + "third_party/upb/upb/reflection/def_builder.c", + "third_party/upb/upb/reflection/def_pool.c", + "third_party/upb/upb/reflection/def_type.c", + "third_party/upb/upb/reflection/desc_state.c", + "third_party/upb/upb/reflection/enum_def.c", + "third_party/upb/upb/reflection/enum_reserved_range.c", + "third_party/upb/upb/reflection/enum_value_def.c", + "third_party/upb/upb/reflection/extension_range.c", + "third_party/upb/upb/reflection/field_def.c", + "third_party/upb/upb/reflection/file_def.c", + "third_party/upb/upb/reflection/message_def.c", + "third_party/upb/upb/reflection/message_reserved_range.c", + "third_party/upb/upb/reflection/message.c", + "third_party/upb/upb/reflection/method_def.c", + "third_party/upb/upb/reflection/oneof_def.c", + "third_party/upb/upb/reflection/service_def.c", + "third_party/upb/upb/text/encode.c", + "third_party/upb/upb/wire/decode_fast.c", + "third_party/upb/upb/wire/decode.c", + "third_party/upb/upb/wire/encode.c", + "third_party/upb/upb/wire/eps_copy_input_stream.c", + "third_party/upb/upb/wire/reader.c", + "src/core/ext/upb-generated/google/protobuf/descriptor.upb.c", + "src/core/ext/upbdefs-generated/google/protobuf/descriptor.upbdefs.c", + ], + "headers": [ + "third_party/utf8_range/utf8_range.h", + "third_party/upb/upb/alloc.h", + "third_party/upb/upb/arena.h", + "third_party/upb/upb/array.h", + "third_party/upb/upb/base/descriptor_constants.h", + "third_party/upb/upb/base/log2.h", + "third_party/upb/upb/base/status.h", + "third_party/upb/upb/base/string_view.h", + "third_party/upb/upb/collections/array_internal.h", + "third_party/upb/upb/collections/array.h", + "third_party/upb/upb/collections/map_gencode_util.h", + "third_party/upb/upb/collections/map_internal.h", + "third_party/upb/upb/collections/map_sorter_internal.h", + "third_party/upb/upb/collections/map.h", + "third_party/upb/upb/collections/message_value.h", + "third_party/upb/upb/decode.h", + "third_party/upb/upb/def.h", + "third_party/upb/upb/def.hpp", + "third_party/upb/upb/encode.h", + 
"third_party/upb/upb/extension_registry.h", + "third_party/upb/upb/hash/common.h", + "third_party/upb/upb/hash/int_table.h", + "third_party/upb/upb/hash/str_table.h", + "third_party/upb/upb/json_decode.h", + "third_party/upb/upb/json_encode.h", + "third_party/upb/upb/json/decode.h", + "third_party/upb/upb/json/encode.h", + "third_party/upb/upb/lex/atoi.h", + "third_party/upb/upb/lex/round_trip.h", + "third_party/upb/upb/lex/strtod.h", + "third_party/upb/upb/lex/unicode.h", + "third_party/upb/upb/map.h", + "third_party/upb/upb/mem/alloc.h", + "third_party/upb/upb/mem/arena_internal.h", + "third_party/upb/upb/mem/arena.h", + "third_party/upb/upb/message/accessors_internal.h", + "third_party/upb/upb/message/accessors.h", + "third_party/upb/upb/message/extension_internal.h", + "third_party/upb/upb/message/internal.h", + "third_party/upb/upb/message/message.h", + "third_party/upb/upb/mini_table.h", + "third_party/upb/upb/mini_table/common_internal.h", + "third_party/upb/upb/mini_table/common.h", + "third_party/upb/upb/mini_table/decode.h", + "third_party/upb/upb/mini_table/encode_internal.h", + "third_party/upb/upb/mini_table/encode_internal.hpp", + "third_party/upb/upb/mini_table/enum_internal.h", + "third_party/upb/upb/mini_table/extension_internal.h", + "third_party/upb/upb/mini_table/extension_registry.h", + "third_party/upb/upb/mini_table/field_internal.h", + "third_party/upb/upb/mini_table/file_internal.h", + "third_party/upb/upb/mini_table/message_internal.h", + "third_party/upb/upb/mini_table/sub_internal.h", + "third_party/upb/upb/mini_table/types.h", + "third_party/upb/upb/msg.h", + "third_party/upb/upb/port/atomic.h", + "third_party/upb/upb/port/def.inc", + "third_party/upb/upb/port/undef.inc", + "third_party/upb/upb/port/vsnprintf_compat.h", + "third_party/upb/upb/reflection.h", + "third_party/upb/upb/reflection.hpp", + "third_party/upb/upb/reflection/common.h", + "third_party/upb/upb/reflection/def_builder_internal.h", + "third_party/upb/upb/reflection/def_pool_internal.h", + "third_party/upb/upb/reflection/def_pool.h", + "third_party/upb/upb/reflection/def_type.h", + "third_party/upb/upb/reflection/def.h", + "third_party/upb/upb/reflection/def.hpp", + "third_party/upb/upb/reflection/desc_state_internal.h", + "third_party/upb/upb/reflection/enum_def_internal.h", + "third_party/upb/upb/reflection/enum_def.h", + "third_party/upb/upb/reflection/enum_reserved_range_internal.h", + "third_party/upb/upb/reflection/enum_reserved_range.h", + "third_party/upb/upb/reflection/enum_value_def_internal.h", + "third_party/upb/upb/reflection/enum_value_def.h", + "third_party/upb/upb/reflection/extension_range_internal.h", + "third_party/upb/upb/reflection/extension_range.h", + "third_party/upb/upb/reflection/field_def_internal.h", + "third_party/upb/upb/reflection/field_def.h", + "third_party/upb/upb/reflection/file_def_internal.h", + "third_party/upb/upb/reflection/file_def.h", + "third_party/upb/upb/reflection/message_def_internal.h", + "third_party/upb/upb/reflection/message_def.h", + "third_party/upb/upb/reflection/message_reserved_range_internal.h", + "third_party/upb/upb/reflection/message_reserved_range.h", + "third_party/upb/upb/reflection/message.h", + "third_party/upb/upb/reflection/message.hpp", + "third_party/upb/upb/reflection/method_def_internal.h", + "third_party/upb/upb/reflection/method_def.h", + "third_party/upb/upb/reflection/oneof_def_internal.h", + "third_party/upb/upb/reflection/oneof_def.h", + "third_party/upb/upb/reflection/service_def_internal.h", + 
"third_party/upb/upb/reflection/service_def.h", + "third_party/upb/upb/status.h", + "third_party/upb/upb/string_view.h", + "third_party/upb/upb/text_encode.h", + "third_party/upb/upb/text/encode.h", + "third_party/upb/upb/upb.h", + "third_party/upb/upb/upb.hpp", + "third_party/upb/upb/wire/common_internal.h", + "third_party/upb/upb/wire/common.h", + "third_party/upb/upb/wire/decode_fast.h", + "third_party/upb/upb/wire/decode_internal.h", + "third_party/upb/upb/wire/decode.h", + "third_party/upb/upb/wire/encode.h", + "third_party/upb/upb/wire/eps_copy_input_stream.h", + "third_party/upb/upb/wire/reader.h", + "third_party/upb/upb/wire/swap_internal.h", + "third_party/upb/upb/wire/types.h", + "src/core/ext/upb-generated/google/protobuf/descriptor.upb.h", + "src/core/ext/upbdefs-generated/google/protobuf/descriptor.upbdefs.h", + ], + "secure": False, + } + ] except: pass diff --git a/src/zlib/gen_build_yaml.py b/src/zlib/gen_build_yaml.py index c58ff8ab3a978..fdd0ccee0af3a 100755 --- a/src/zlib/gen_build_yaml.py +++ b/src/zlib/gen_build_yaml.py @@ -19,42 +19,39 @@ import sys import yaml -os.chdir(os.path.dirname(sys.argv[0]) + '/../..') +os.chdir(os.path.dirname(sys.argv[0]) + "/../..") out = {} try: - with open('third_party/zlib/CMakeLists.txt') as f: + with open("third_party/zlib/CMakeLists.txt") as f: cmake = f.read() def cmpath(x): - return 'third_party/zlib/%s' % x.replace('${CMAKE_CURRENT_BINARY_DIR}/', - '') + return "third_party/zlib/%s" % x.replace( + "${CMAKE_CURRENT_BINARY_DIR}/", "" + ) def cmvar(name): - regex = r'set\(\s*' + regex = r"set\(\s*" regex += name - regex += r'([^)]*)\)' + regex += r"([^)]*)\)" return [cmpath(x) for x in re.search(regex, cmake).group(1).split()] - out['libs'] = [{ - 'name': - 'z', - 'zlib': - True, - 'defaults': - 'zlib', - 'build': - 'private', - 'language': - 'c', - 'secure': - False, - 'src': - sorted(cmvar('ZLIB_SRCS')), - 'headers': - sorted(cmvar('ZLIB_PUBLIC_HDRS') + cmvar('ZLIB_PRIVATE_HDRS')), - }] + out["libs"] = [ + { + "name": "z", + "zlib": True, + "defaults": "zlib", + "build": "private", + "language": "c", + "secure": False, + "src": sorted(cmvar("ZLIB_SRCS")), + "headers": sorted( + cmvar("ZLIB_PUBLIC_HDRS") + cmvar("ZLIB_PRIVATE_HDRS") + ), + } + ] except: pass diff --git a/test/core/end2end/fuzzers/generate_client_examples_of_bad_closing_streams.py b/test/core/end2end/fuzzers/generate_client_examples_of_bad_closing_streams.py index cb47cae5cc40c..5ee5403a2246f 100755 --- a/test/core/end2end/fuzzers/generate_client_examples_of_bad_closing_streams.py +++ b/test/core/end2end/fuzzers/generate_client_examples_of_bad_closing_streams.py @@ -19,16 +19,16 @@ os.chdir(os.path.dirname(sys.argv[0])) streams = { - 'server_hanging_response_1_header': - ([0, 0, 0, 4, 0, 0, 0, 0, 0] + # settings frame - [0, 0, 0, 1, 5, 0, 0, 0, 1] # trailers - ), - 'server_hanging_response_2_header2': - ([0, 0, 0, 4, 0, 0, 0, 0, 0] + # settings frame - [0, 0, 0, 1, 4, 0, 0, 0, 1] + # headers - [0, 0, 0, 1, 5, 0, 0, 0, 1] # trailers - ), + "server_hanging_response_1_header": ( + [0, 0, 0, 4, 0, 0, 0, 0, 0] + + [0, 0, 0, 1, 5, 0, 0, 0, 1] # settings frame # trailers + ), + "server_hanging_response_2_header2": ( + [0, 0, 0, 4, 0, 0, 0, 0, 0] + + [0, 0, 0, 1, 4, 0, 0, 0, 1] # settings frame + + [0, 0, 0, 1, 5, 0, 0, 0, 1] # headers # trailers + ), } for name, stream in streams.items(): - open('client_fuzzer_corpus/%s' % name, 'w').write(bytearray(stream)) + open("client_fuzzer_corpus/%s" % name, "w").write(bytearray(stream)) diff --git a/test/core/http/test_server.py 
b/test/core/http/test_server.py index abd26c53d3d1c..e8bb0770bab00 100755 --- a/test/core/http/test_server.py +++ b/test/core/http/test_server.py @@ -22,48 +22,56 @@ import sys _PEM = os.path.abspath( - os.path.join(os.path.dirname(sys.argv[0]), '../../..', - 'src/core/tsi/test_creds/server1.pem')) + os.path.join( + os.path.dirname(sys.argv[0]), + "../../..", + "src/core/tsi/test_creds/server1.pem", + ) +) _KEY = os.path.abspath( - os.path.join(os.path.dirname(sys.argv[0]), '../../..', - 'src/core/tsi/test_creds/server1.key')) + os.path.join( + os.path.dirname(sys.argv[0]), + "../../..", + "src/core/tsi/test_creds/server1.key", + ) +) print(_PEM) open(_PEM).close() -argp = argparse.ArgumentParser(description='Server for httpcli_test') -argp.add_argument('-p', '--port', default=10080, type=int) -argp.add_argument('-s', '--ssl', default=False, action='store_true') +argp = argparse.ArgumentParser(description="Server for httpcli_test") +argp.add_argument("-p", "--port", default=10080, type=int) +argp.add_argument("-s", "--ssl", default=False, action="store_true") args = argp.parse_args() -print('server running on port %d' % args.port) +print("server running on port %d" % args.port) class Handler(BaseHTTPRequestHandler): - def good(self): self.send_response(200) - self.send_header('Content-Type', 'text/html') + self.send_header("Content-Type", "text/html") self.end_headers() self.wfile.write( - '<html><head><title>Hello world!</title></head>'.encode('ascii')) + "<html><head><title>Hello world!</title></head>".encode("ascii") + ) self.wfile.write( - '<body><p>This is a test</p></body></html>'.encode('ascii')) + "<body><p>This is a test</p></body></html>
".encode("ascii") + ) def do_GET(self): - if self.path == '/get': + if self.path == "/get": self.good() def do_POST(self): - content_len = self.headers.get('content-length') - content = self.rfile.read(int(content_len)).decode('ascii') - if self.path == '/post' and content == 'hello': + content_len = self.headers.get("content-length") + content = self.rfile.read(int(content_len)).decode("ascii") + if self.path == "/post" and content == "hello": self.good() -httpd = HTTPServer(('localhost', args.port), Handler) +httpd = HTTPServer(("localhost", args.port), Handler) if args.ssl: - httpd.socket = ssl.wrap_socket(httpd.socket, - certfile=_PEM, - keyfile=_KEY, - server_side=True) + httpd.socket = ssl.wrap_socket( + httpd.socket, certfile=_PEM, keyfile=_KEY, server_side=True + ) httpd.serve_forever() diff --git a/test/cpp/naming/gen_build_yaml.py b/test/cpp/naming/gen_build_yaml.py index 24713aa5b8fc2..a65f2b3e9b2ae 100755 --- a/test/cpp/naming/gen_build_yaml.py +++ b/test/cpp/naming/gen_build_yaml.py @@ -20,64 +20,83 @@ import yaml -_LOCAL_DNS_SERVER_ADDRESS = '127.0.0.1:15353' +_LOCAL_DNS_SERVER_ADDRESS = "127.0.0.1:15353" def _append_zone_name(name, zone_name): - return '%s.%s' % (name, zone_name) + return "%s.%s" % (name, zone_name) def _build_expected_addrs_cmd_arg(expected_addrs): out = [] for addr in expected_addrs: - out.append('%s,%s' % (addr['address'], str(addr['is_balancer']))) - return ';'.join(out) + out.append("%s,%s" % (addr["address"], str(addr["is_balancer"]))) + return ";".join(out) def _resolver_test_cases(resolver_component_data): out = [] - for test_case in resolver_component_data['resolver_component_tests']: + for test_case in resolver_component_data["resolver_component_tests"]: target_name = _append_zone_name( - test_case['record_to_resolve'], - resolver_component_data['resolver_tests_common_zone_name']) - out.append({ - 'test_title': - target_name, - 'arg_names_and_values': [ - ('target_name', target_name), - ('do_ordered_address_comparison', - test_case['do_ordered_address_comparison']), - ('expected_addrs', - _build_expected_addrs_cmd_arg(test_case['expected_addrs'])), - ('expected_chosen_service_config', - (test_case['expected_chosen_service_config'] or '')), - ('expected_service_config_error', - (test_case['expected_service_config_error'] or '')), - ('expected_lb_policy', (test_case['expected_lb_policy'] or '')), - ('enable_srv_queries', test_case['enable_srv_queries']), - ('enable_txt_queries', test_case['enable_txt_queries']), - ('inject_broken_nameserver_list', - test_case['inject_broken_nameserver_list']), - ], - }) + test_case["record_to_resolve"], + resolver_component_data["resolver_tests_common_zone_name"], + ) + out.append( + { + "test_title": target_name, + "arg_names_and_values": [ + ("target_name", target_name), + ( + "do_ordered_address_comparison", + test_case["do_ordered_address_comparison"], + ), + ( + "expected_addrs", + _build_expected_addrs_cmd_arg( + test_case["expected_addrs"] + ), + ), + ( + "expected_chosen_service_config", + (test_case["expected_chosen_service_config"] or ""), + ), + ( + "expected_service_config_error", + (test_case["expected_service_config_error"] or ""), + ), + ( + "expected_lb_policy", + (test_case["expected_lb_policy"] or ""), + ), + ("enable_srv_queries", test_case["enable_srv_queries"]), + ("enable_txt_queries", test_case["enable_txt_queries"]), + ( + "inject_broken_nameserver_list", + test_case["inject_broken_nameserver_list"], + ), + ], + } + ) return out def main(): - resolver_component_data = '' - with 
open('test/cpp/naming/resolver_test_record_groups.yaml') as f: + resolver_component_data = "" + with open("test/cpp/naming/resolver_test_record_groups.yaml") as f: resolver_component_data = yaml.safe_load(f) json = { - 'resolver_tests_common_zone_name': - resolver_component_data['resolver_tests_common_zone_name'], + "resolver_tests_common_zone_name": resolver_component_data[ + "resolver_tests_common_zone_name" + ], # this data is required by the resolver_component_tests_runner.py.template - 'resolver_component_test_cases': - _resolver_test_cases(resolver_component_data), + "resolver_component_test_cases": _resolver_test_cases( + resolver_component_data + ), } print(yaml.safe_dump(json)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/test/cpp/naming/manual_run_resolver_component_test.py b/test/cpp/naming/manual_run_resolver_component_test.py index fc6812bf59126..6a4059dc98449 100644 --- a/test/cpp/naming/manual_run_resolver_component_test.py +++ b/test/cpp/naming/manual_run_resolver_component_test.py @@ -19,24 +19,26 @@ # The c-ares test suite doesn't get ran regularly on Windows, but # this script provides a way to run a lot of the tests manually. -_MSBUILD_CONFIG = os.environ['CONFIG'] -os.chdir(os.path.join('..', '..', os.getcwd())) +_MSBUILD_CONFIG = os.environ["CONFIG"] +os.chdir(os.path.join("..", "..", os.getcwd())) # This port is arbitrary, but it needs to be available. _DNS_SERVER_PORT = 15353 -subprocess.call([ - sys.executable, - 'test\\cpp\\naming\\resolver_component_tests_runner.py', - '--test_bin_path', - 'cmake\\build\\%s\\resolver_component_test.exe' % _MSBUILD_CONFIG, - '--dns_server_bin_path', - 'test\\cpp\\naming\\utils\\dns_server.py', - '--records_config_path', - 'test\\cpp\\naming\\resolver_test_record_groups.yaml', - '--dns_server_port', - str(_DNS_SERVER_PORT), - '--dns_resolver_bin_path', - 'test\\cpp\\naming\\utils\\dns_resolver.py', - '--tcp_connect_bin_path', - 'test\\cpp\\naming\\utils\\tcp_connect.py', -]) +subprocess.call( + [ + sys.executable, + "test\\cpp\\naming\\resolver_component_tests_runner.py", + "--test_bin_path", + "cmake\\build\\%s\\resolver_component_test.exe" % _MSBUILD_CONFIG, + "--dns_server_bin_path", + "test\\cpp\\naming\\utils\\dns_server.py", + "--records_config_path", + "test\\cpp\\naming\\resolver_test_record_groups.yaml", + "--dns_server_port", + str(_DNS_SERVER_PORT), + "--dns_resolver_bin_path", + "test\\cpp\\naming\\utils\\dns_resolver.py", + "--tcp_connect_bin_path", + "test\\cpp\\naming\\utils\\tcp_connect.py", + ] +) diff --git a/test/cpp/naming/utils/dns_resolver.py b/test/cpp/naming/utils/dns_resolver.py index 773f0c1e796fe..91b01e856d1c2 100755 --- a/test/cpp/naming/utils/dns_resolver.py +++ b/test/cpp/naming/utils/dns_resolver.py @@ -24,27 +24,35 @@ def main(): - argp = argparse.ArgumentParser(description='Make DNS queries for A records') - argp.add_argument('-s', - '--server_host', - default='127.0.0.1', - type=str, - help='Host for DNS server to listen on for TCP and UDP.') - argp.add_argument('-p', - '--server_port', - default=53, - type=int, - help='Port that the DNS server is listening on.') - argp.add_argument('-n', - '--qname', - default=None, - type=str, - help=('Name of the record to query for. 
')) - argp.add_argument('-t', - '--timeout', - default=1, - type=int, - help=('Force process exit after this number of seconds.')) + argp = argparse.ArgumentParser(description="Make DNS queries for A records") + argp.add_argument( + "-s", + "--server_host", + default="127.0.0.1", + type=str, + help="Host for DNS server to listen on for TCP and UDP.", + ) + argp.add_argument( + "-p", + "--server_port", + default=53, + type=int, + help="Port that the DNS server is listening on.", + ) + argp.add_argument( + "-n", + "--qname", + default=None, + type=str, + help="Name of the record to query for. ", + ) + argp.add_argument( + "-t", + "--timeout", + default=1, + type=int, + help="Force process exit after this number of seconds.", + ) args = argp.parse_args() def OnResolverResultAvailable(result): @@ -62,5 +70,5 @@ def BeginQuery(reactor, qname): task.react(BeginQuery, [args.qname]) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/test/cpp/naming/utils/dns_server.py b/test/cpp/naming/utils/dns_server.py index 245c801d0beeb..02e1541875bef 100755 --- a/test/cpp/naming/utils/dns_server.py +++ b/test/cpp/naming/utils/dns_server.py @@ -39,12 +39,13 @@ import twisted.names.server import yaml -_SERVER_HEALTH_CHECK_RECORD_NAME = 'health-check-local-dns-server-is-alive.resolver-tests.grpctestingexp' # missing end '.' for twisted syntax -_SERVER_HEALTH_CHECK_RECORD_DATA = '123.123.123.123' +_SERVER_HEALTH_CHECK_RECORD_NAME = ( # missing end '.' for twisted syntax + "health-check-local-dns-server-is-alive.resolver-tests.grpctestingexp" +) +_SERVER_HEALTH_CHECK_RECORD_DATA = "123.123.123.123" class NoFileAuthority(authority.FileAuthority): - def __init__(self, soa, records): # skip FileAuthority common.ResolverBase.__init__(self) @@ -56,82 +57,89 @@ def start_local_dns_server(args): all_records = {} def _push_record(name, r): - name = name.encode('ascii') - print('pushing record: |%s|' % name) + name = name.encode("ascii") + print("pushing record: |%s|" % name) if all_records.get(name) is not None: all_records[name].append(r) return all_records[name] = [r] def _maybe_split_up_txt_data(name, txt_data, r_ttl): - txt_data = txt_data.encode('ascii') + txt_data = txt_data.encode("ascii") start = 0 txt_data_list = [] while len(txt_data[start:]) > 0: next_read = len(txt_data[start:]) if next_read > 255: next_read = 255 - txt_data_list.append(txt_data[start:start + next_read]) + txt_data_list.append(txt_data[start : start + next_read]) start += next_read _push_record(name, dns.Record_TXT(*txt_data_list, ttl=r_ttl)) with open(args.records_config_path) as config: test_records_config = yaml.safe_load(config) - common_zone_name = test_records_config['resolver_tests_common_zone_name'] - for group in test_records_config['resolver_component_tests']: - for name in group['records'].keys(): - for record in group['records'][name]: - r_type = record['type'] - r_data = record['data'] - r_ttl = int(record['TTL']) - record_full_name = '%s.%s' % (name, common_zone_name) - assert record_full_name[-1] == '.' + common_zone_name = test_records_config["resolver_tests_common_zone_name"] + for group in test_records_config["resolver_component_tests"]: + for name in group["records"].keys(): + for record in group["records"][name]: + r_type = record["type"] + r_data = record["data"] + r_ttl = int(record["TTL"]) + record_full_name = "%s.%s" % (name, common_zone_name) + assert record_full_name[-1] == "." 
record_full_name = record_full_name[:-1] - if r_type == 'A': - _push_record(record_full_name, - dns.Record_A(r_data, ttl=r_ttl)) - if r_type == 'AAAA': - _push_record(record_full_name, - dns.Record_AAAA(r_data, ttl=r_ttl)) - if r_type == 'SRV': - p, w, port, target = r_data.split(' ') + if r_type == "A": + _push_record( + record_full_name, dns.Record_A(r_data, ttl=r_ttl) + ) + if r_type == "AAAA": + _push_record( + record_full_name, dns.Record_AAAA(r_data, ttl=r_ttl) + ) + if r_type == "SRV": + p, w, port, target = r_data.split(" ") p = int(p) w = int(w) port = int(port) target_full_name = ( - '%s.%s' % (target, common_zone_name)).encode('ascii') + "%s.%s" % (target, common_zone_name) + ).encode("ascii") _push_record( record_full_name, - dns.Record_SRV(p, w, port, target_full_name, ttl=r_ttl)) - if r_type == 'TXT': + dns.Record_SRV(p, w, port, target_full_name, ttl=r_ttl), + ) + if r_type == "TXT": _maybe_split_up_txt_data(record_full_name, r_data, r_ttl) # Add an optional IPv4 record is specified if args.add_a_record: - extra_host, extra_host_ipv4 = args.add_a_record.split(':') + extra_host, extra_host_ipv4 = args.add_a_record.split(":") _push_record(extra_host, dns.Record_A(extra_host_ipv4, ttl=0)) # Server health check record - _push_record(_SERVER_HEALTH_CHECK_RECORD_NAME, - dns.Record_A(_SERVER_HEALTH_CHECK_RECORD_DATA, ttl=0)) - soa_record = dns.Record_SOA(mname=common_zone_name.encode('ascii')) + _push_record( + _SERVER_HEALTH_CHECK_RECORD_NAME, + dns.Record_A(_SERVER_HEALTH_CHECK_RECORD_DATA, ttl=0), + ) + soa_record = dns.Record_SOA(mname=common_zone_name.encode("ascii")) test_domain_com = NoFileAuthority( - soa=(common_zone_name.encode('ascii'), soa_record), + soa=(common_zone_name.encode("ascii"), soa_record), records=all_records, ) server = twisted.names.server.DNSServerFactory( - authorities=[test_domain_com], verbose=2) + authorities=[test_domain_com], verbose=2 + ) server.noisy = 2 twisted.internet.reactor.listenTCP(args.port, server) dns_proto = twisted.names.dns.DNSDatagramProtocol(server) dns_proto.noisy = 2 twisted.internet.reactor.listenUDP(args.port, dns_proto) - print('starting local dns server on 127.0.0.1:%s' % args.port) - print('starting twisted.internet.reactor') + print("starting local dns server on 127.0.0.1:%s" % args.port) + print("starting twisted.internet.reactor") twisted.internet.reactor.suggestThreadPoolSize(1) twisted.internet.reactor.run() def _quit_on_signal(signum, _frame): - print('Received SIGNAL %d. Quitting with exit code 0' % signum) + print("Received SIGNAL %d. Quitting with exit code 0" % signum) twisted.internet.reactor.stop() sys.stdout.flush() sys.exit(0) @@ -146,35 +154,44 @@ def flush_stdout_loop(): sys.stdout.flush() time.sleep(sleep_time) num_timeouts_so_far += 1 - print('Process timeout reached, or cancelled. Exitting 0.') + print("Process timeout reached, or cancelled. Exitting 0.") os.kill(os.getpid(), signal.SIGTERM) def main(): argp = argparse.ArgumentParser( - description='Local DNS Server for resolver tests') - argp.add_argument('-p', - '--port', - default=None, - type=int, - help='Port for DNS server to listen on for TCP and UDP.') + description="Local DNS Server for resolver tests" + ) argp.add_argument( - '-r', - '--records_config_path', + "-p", + "--port", + default=None, + type=int, + help="Port for DNS server to listen on for TCP and UDP.", + ) + argp.add_argument( + "-r", + "--records_config_path", default=None, type=str, - help=('Directory of resolver_test_record_groups.yaml file. 
' - 'Defaults to path needed when the test is invoked as part ' - 'of run_tests.py.')) + help=( + "Directory of resolver_test_record_groups.yaml file. " + "Defaults to path needed when the test is invoked as part " + "of run_tests.py." + ), + ) argp.add_argument( - '--add_a_record', + "--add_a_record", default=None, type=str, - help=('Add an A record via the command line. Useful for when we ' - 'need to serve a one-off A record that is under a ' - 'different domain then the rest the records configured in ' - '--records_config_path (which all need to be under the ' - 'same domain). Format: :')) + help=( + "Add an A record via the command line. Useful for when we " + "need to serve a one-off A record that is under a " + "different domain then the rest the records configured in " + "--records_config_path (which all need to be under the " + "same domain). Format: :" + ), + ) args = argp.parse_args() signal.signal(signal.SIGTERM, _quit_on_signal) signal.signal(signal.SIGINT, _quit_on_signal) @@ -184,5 +201,5 @@ def main(): start_local_dns_server(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/test/cpp/naming/utils/run_dns_server_for_lb_interop_tests.py b/test/cpp/naming/utils/run_dns_server_for_lb_interop_tests.py index 0c9cf9f238a1d..403b23dbde3ec 100755 --- a/test/cpp/naming/utils/run_dns_server_for_lb_interop_tests.py +++ b/test/cpp/naming/utils/run_dns_server_for_lb_interop_tests.py @@ -24,85 +24,102 @@ import yaml argp = argparse.ArgumentParser( - description='Runs a DNS server for LB interop tests') -argp.add_argument('-l', - '--grpclb_ips', - default=None, - type=str, - help='Comma-separated list of IP addresses of balancers') + description="Runs a DNS server for LB interop tests" +) argp.add_argument( - '-f', - '--fallback_ips', + "-l", + "--grpclb_ips", default=None, type=str, - help='Comma-separated list of IP addresses of fallback servers') + help="Comma-separated list of IP addresses of balancers", +) argp.add_argument( - '-c', - '--cause_no_error_no_data_for_balancer_a_record', + "-f", + "--fallback_ips", + default=None, + type=str, + help="Comma-separated list of IP addresses of fallback servers", +) +argp.add_argument( + "-c", + "--cause_no_error_no_data_for_balancer_a_record", default=False, - action='store_const', + action="store_const", const=True, - help=('Used for testing the case in which the grpclb ' - 'balancer A record lookup results in a DNS NOERROR response ' - 'but with no ANSWER section i.e. no addresses')) + help=( + "Used for testing the case in which the grpclb " + "balancer A record lookup results in a DNS NOERROR response " + "but with no ANSWER section i.e. 
no addresses" + ), +) args = argp.parse_args() balancer_records = [] -grpclb_ips = args.grpclb_ips.split(',') +grpclb_ips = args.grpclb_ips.split(",") if grpclb_ips[0]: for ip in grpclb_ips: - balancer_records.append({ - 'TTL': '2100', - 'data': ip, - 'type': 'A', - }) + balancer_records.append( + { + "TTL": "2100", + "data": ip, + "type": "A", + } + ) fallback_records = [] -fallback_ips = args.fallback_ips.split(',') +fallback_ips = args.fallback_ips.split(",") if fallback_ips[0]: for ip in fallback_ips: - fallback_records.append({ - 'TTL': '2100', - 'data': ip, - 'type': 'A', - }) + fallback_records.append( + { + "TTL": "2100", + "data": ip, + "type": "A", + } + ) records_config_yaml = { - 'resolver_tests_common_zone_name': - 'test.google.fr.', - 'resolver_component_tests': [{ - 'records': { - '_grpclb._tcp.server': [{ - 'TTL': '2100', - 'data': '0 0 12000 balancer', - 'type': 'SRV' - },], - 'balancer': balancer_records, - 'server': fallback_records, + "resolver_tests_common_zone_name": "test.google.fr.", + "resolver_component_tests": [ + { + "records": { + "_grpclb._tcp.server": [ + { + "TTL": "2100", + "data": "0 0 12000 balancer", + "type": "SRV", + }, + ], + "balancer": balancer_records, + "server": fallback_records, + } } - }] + ], } if args.cause_no_error_no_data_for_balancer_a_record: - balancer_records = records_config_yaml['resolver_component_tests'][0][ - 'records']['balancer'] + balancer_records = records_config_yaml["resolver_component_tests"][0][ + "records" + ]["balancer"] assert not balancer_records # Insert a TXT record at the balancer.test.google.fr. domain. # This TXT record won't actually be resolved or used by gRPC clients; # inserting this record is just a way get the balancer.test.google.fr. # A record queries to return NOERROR DNS responses that also have no # ANSWER section, in order to simulate this failure case. - balancer_records.append({ - 'TTL': '2100', - 'data': 'arbitrary string that wont actually be resolved', - 'type': 'TXT', - }) + balancer_records.append( + { + "TTL": "2100", + "data": "arbitrary string that wont actually be resolved", + "type": "TXT", + } + ) # Generate the actual DNS server records config file records_config_path = tempfile.mktemp() -with open(records_config_path, 'w') as records_config_generated: +with open(records_config_path, "w") as records_config_generated: records_config_generated.write(yaml.dump(records_config_yaml)) -with open(records_config_path, 'r') as records_config_generated: - sys.stderr.write('===== DNS server records config: =====\n') +with open(records_config_path, "r") as records_config_generated: + sys.stderr.write("===== DNS server records config: =====\n") sys.stderr.write(records_config_generated.read()) - sys.stderr.write('======================================\n') + sys.stderr.write("======================================\n") # Run the DNS server # Note that we need to add the extra @@ -110,10 +127,12 @@ # OAuth creds and ALTS creds to work. # TODO(apolcyn): should metadata.google.internal always resolve # to 169.254.169.254? 
-subprocess.check_output([ - '/var/local/git/grpc/test/cpp/naming/utils/dns_server.py', - '--port=53', - '--records_config_path', - records_config_path, - '--add_a_record=metadata.google.internal:169.254.169.254', -]) +subprocess.check_output( + [ + "/var/local/git/grpc/test/cpp/naming/utils/dns_server.py", + "--port=53", + "--records_config_path", + records_config_path, + "--add_a_record=metadata.google.internal:169.254.169.254", + ] +) diff --git a/test/cpp/naming/utils/tcp_connect.py b/test/cpp/naming/utils/tcp_connect.py index e41f870b74664..f06a5e4655e65 100755 --- a/test/cpp/naming/utils/tcp_connect.py +++ b/test/cpp/naming/utils/tcp_connect.py @@ -23,26 +23,34 @@ def main(): argp = argparse.ArgumentParser( - description='Open a TCP handshake to a server') - argp.add_argument('-s', - '--server_host', - default=None, - type=str, - help='Server host name or IP.') - argp.add_argument('-p', - '--server_port', - default=0, - type=int, - help='Port that the server is listening on.') - argp.add_argument('-t', - '--timeout', - default=1, - type=int, - help='Force process exit after this number of seconds.') + description="Open a TCP handshake to a server" + ) + argp.add_argument( + "-s", + "--server_host", + default=None, + type=str, + help="Server host name or IP.", + ) + argp.add_argument( + "-p", + "--server_port", + default=0, + type=int, + help="Port that the server is listening on.", + ) + argp.add_argument( + "-t", + "--timeout", + default=1, + type=int, + help="Force process exit after this number of seconds.", + ) args = argp.parse_args() - socket.create_connection([args.server_host, args.server_port], - timeout=args.timeout) + socket.create_connection( + [args.server_host, args.server_port], timeout=args.timeout + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/test/cpp/qps/json_run_localhost_scenario_gen.py b/test/cpp/qps/json_run_localhost_scenario_gen.py index c7c5028f5a0f1..72d29ab84789d 100755 --- a/test/cpp/qps/json_run_localhost_scenario_gen.py +++ b/test/cpp/qps/json_run_localhost_scenario_gen.py @@ -24,5 +24,6 @@ gen.generate_scenarios_bzl( gen.generate_json_run_localhost_scenarios(), - os.path.join(script_dir, 'json_run_localhost_scenarios.bzl'), - 'JSON_RUN_LOCALHOST_SCENARIOS') + os.path.join(script_dir, "json_run_localhost_scenarios.bzl"), + "JSON_RUN_LOCALHOST_SCENARIOS", +) diff --git a/test/cpp/qps/qps_json_driver_scenario_gen.py b/test/cpp/qps/qps_json_driver_scenario_gen.py index 76bff59f2c853..3f5f2e2471afc 100755 --- a/test/cpp/qps/qps_json_driver_scenario_gen.py +++ b/test/cpp/qps/qps_json_driver_scenario_gen.py @@ -24,5 +24,6 @@ gen.generate_scenarios_bzl( gen.generate_qps_json_driver_scenarios(), - os.path.join(script_dir, 'qps_json_driver_scenarios.bzl'), - 'QPS_JSON_DRIVER_SCENARIOS') + os.path.join(script_dir, "qps_json_driver_scenarios.bzl"), + "QPS_JSON_DRIVER_SCENARIOS", +) diff --git a/test/cpp/qps/scenario_generator_helper.py b/test/cpp/qps/scenario_generator_helper.py index 14666c7ad56b4..907925b5e521b 100755 --- a/test/cpp/qps/scenario_generator_helper.py +++ b/test/cpp/qps/scenario_generator_helper.py @@ -23,7 +23,8 @@ import yaml run_tests_root = os.path.abspath( - os.path.join(os.path.dirname(sys.argv[0]), '../../../tools/run_tests')) + os.path.join(os.path.dirname(sys.argv[0]), "../../../tools/run_tests") +) sys.path.append(run_tests_root) import performance.scenario_config as scenario_config @@ -48,28 +49,32 @@ def _mutate_scenario(scenario_json): """Modifies vanilla benchmark scenario config to make it more 
suitable for running as a unit test.""" # tweak parameters to get fast test times scenario_json = dict(scenario_json) - scenario_json['warmup_seconds'] = 0 - scenario_json['benchmark_seconds'] = 1 + scenario_json["warmup_seconds"] = 0 + scenario_json["benchmark_seconds"] = 1 outstanding_rpcs_divisor = 1 - if scenario_json['client_config'][ - 'client_type'] == 'SYNC_CLIENT' or scenario_json['server_config'][ - 'server_type'] == 'SYNC_SERVER': + if ( + scenario_json["client_config"]["client_type"] == "SYNC_CLIENT" + or scenario_json["server_config"]["server_type"] == "SYNC_SERVER" + ): # reduce the number of threads needed for scenarios that use synchronous API outstanding_rpcs_divisor = 10 - scenario_json['client_config']['outstanding_rpcs_per_channel'] = max( - 1, scenario_json['client_config']['outstanding_rpcs_per_channel'] // - outstanding_rpcs_divisor) + scenario_json["client_config"]["outstanding_rpcs_per_channel"] = max( + 1, + scenario_json["client_config"]["outstanding_rpcs_per_channel"] + // outstanding_rpcs_divisor, + ) # Some scenarios use high channel count since when actually # benchmarking, we want to saturate the machine that runs the benchmark. # For unit test, this is an overkill. max_client_channels = 16 - if scenario_json['client_config']['rpc_type'] == 'STREAMING_FROM_SERVER': + if scenario_json["client_config"]["rpc_type"] == "STREAMING_FROM_SERVER": # streaming from server scenarios tend to have trouble shutting down # quickly if there are too many channels. max_client_channels = 4 - scenario_json['client_config']['client_channels'] = min( - max_client_channels, scenario_json['client_config']['client_channels']) + scenario_json["client_config"]["client_channels"] = min( + max_client_channels, scenario_json["client_config"]["client_channels"] + ) return scenario_config.remove_nonproto_fields(scenario_json) @@ -78,7 +83,7 @@ def generate_json_run_localhost_scenarios(): return [ _mutate_scenario(scenario_json) for scenario_json in scenario_config.CXXLanguage().scenarios() - if 'scalable' in scenario_json.get('CATEGORIES', []) + if "scalable" in scenario_json.get("CATEGORIES", []) ] @@ -86,7 +91,7 @@ def generate_qps_json_driver_scenarios(): return [ _mutate_scenario(scenario_json) for scenario_json in scenario_config.CXXLanguage().scenarios() - if 'inproc' in scenario_json.get('CATEGORIES', []) + if "inproc" in scenario_json.get("CATEGORIES", []) ] @@ -94,19 +99,21 @@ def generate_scenarios_bzl(json_scenarios, bzl_filename, bzl_variablename): """Generate .bzl file that defines a variable with JSON scenario configs.""" all_scenarios = [] for scenario in json_scenarios: - scenario_name = scenario['name'] + scenario_name = scenario["name"] # argument will be passed as "--scenarios_json" to the test binary # the string needs to be quoted in \' to ensure it gets passed as a single argument in shell - scenarios_json_arg_str = '\\\'%s\\\'' % json.dumps( - {'scenarios': [scenario]}) + scenarios_json_arg_str = "\\'%s\\'" % json.dumps( + {"scenarios": [scenario]} + ) all_scenarios.append((scenario_name, scenarios_json_arg_str)) - with open(bzl_filename, 'w') as f: + with open(bzl_filename, "w") as f: f.write(_COPYRIGHT) f.write( - '"""AUTOGENERATED: configuration of benchmark scenarios to be run as bazel test"""\n\n' + '"""AUTOGENERATED: configuration of benchmark scenarios to be run' + ' as bazel test"""\n\n' ) - f.write('%s = {\n' % bzl_variablename) + f.write("%s = {\n" % bzl_variablename) for scenario in all_scenarios: f.write(" \"%s\": '%s',\n" % (scenario[0], scenario[1])) - 
f.write('}\n') + f.write("}\n") diff --git a/test/distrib/bazel/python/helloworld.py b/test/distrib/bazel/python/helloworld.py index 496f75e36883f..1cf33796d473d 100644 --- a/test/distrib/bazel/python/helloworld.py +++ b/test/distrib/bazel/python/helloworld.py @@ -25,19 +25,19 @@ import helloworld_pb2 import helloworld_pb2_grpc -_HOST = 'localhost' -_SERVER_ADDRESS = '{}:0'.format(_HOST) +_HOST = "localhost" +_SERVER_ADDRESS = "{}:0".format(_HOST) class Greeter(helloworld_pb2_grpc.GreeterServicer): - def SayHello(self, request, context): - request_in_flight = datetime.datetime.now() - \ - request.request_initiation.ToDatetime() + request_in_flight = ( + datetime.datetime.now() - request.request_initiation.ToDatetime() + ) request_duration = duration_pb2.Duration() request_duration.FromTimedelta(request_in_flight) return helloworld_pb2.HelloReply( - message='Hello, %s!' % request.name, + message="Hello, %s!" % request.name, request_duration=request_duration, ) @@ -55,22 +55,23 @@ def _listening_server(): class ImportTest(unittest.TestCase): - def test_import(self): with _listening_server() as port: - with grpc.insecure_channel('{}:{}'.format(_HOST, port)) as channel: + with grpc.insecure_channel("{}:{}".format(_HOST, port)) as channel: stub = helloworld_pb2_grpc.GreeterStub(channel) request_timestamp = timestamp_pb2.Timestamp() request_timestamp.GetCurrentTime() - response = stub.SayHello(helloworld_pb2.HelloRequest( - name='you', - request_initiation=request_timestamp, - ), - wait_for_ready=True) + response = stub.SayHello( + helloworld_pb2.HelloRequest( + name="you", + request_initiation=request_timestamp, + ), + wait_for_ready=True, + ) self.assertEqual(response.message, "Hello, you!") self.assertGreater(response.request_duration.nanos, 0) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main() diff --git a/test/distrib/bazel/python/helloworld_moved.py b/test/distrib/bazel/python/helloworld_moved.py index c902735862a22..7a459ddf4a208 100644 --- a/test/distrib/bazel/python/helloworld_moved.py +++ b/test/distrib/bazel/python/helloworld_moved.py @@ -26,21 +26,22 @@ from google.protobuf import timestamp_pb2 from google.cloud import helloworld_pb2 from google.cloud import helloworld_pb2_grpc + # isort: on -_HOST = 'localhost' -_SERVER_ADDRESS = '{}:0'.format(_HOST) +_HOST = "localhost" +_SERVER_ADDRESS = "{}:0".format(_HOST) class Greeter(helloworld_pb2_grpc.GreeterServicer): - def SayHello(self, request, context): - request_in_flight = datetime.datetime.now() - \ - request.request_initiation.ToDatetime() + request_in_flight = ( + datetime.datetime.now() - request.request_initiation.ToDatetime() + ) request_duration = duration_pb2.Duration() request_duration.FromTimedelta(request_in_flight) return helloworld_pb2.HelloReply( - message='Hello, %s!' % request.name, + message="Hello, %s!" 
% request.name, request_duration=request_duration, ) @@ -58,22 +59,23 @@ def _listening_server(): class ImportTest(unittest.TestCase): - def test_import(self): with _listening_server() as port: - with grpc.insecure_channel('{}:{}'.format(_HOST, port)) as channel: + with grpc.insecure_channel("{}:{}".format(_HOST, port)) as channel: stub = helloworld_pb2_grpc.GreeterStub(channel) request_timestamp = timestamp_pb2.Timestamp() request_timestamp.GetCurrentTime() - response = stub.SayHello(helloworld_pb2.HelloRequest( - name='you', - request_initiation=request_timestamp, - ), - wait_for_ready=True) + response = stub.SayHello( + helloworld_pb2.HelloRequest( + name="you", + request_initiation=request_timestamp, + ), + wait_for_ready=True, + ) self.assertEqual(response.message, "Hello, you!") self.assertGreater(response.request_duration.nanos, 0) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main() diff --git a/test/distrib/bazel/python/namespaced/upper/example/import_no_strip_test.py b/test/distrib/bazel/python/namespaced/upper/example/import_no_strip_test.py index f0c44539c20f3..4ec57c8a37f46 100644 --- a/test/distrib/bazel/python/namespaced/upper/example/import_no_strip_test.py +++ b/test/distrib/bazel/python/namespaced/upper/example/import_no_strip_test.py @@ -17,23 +17,25 @@ class ImportTest(unittest.TestCase): - def test_import(self): - from foo.bar.namespaced.upper.example.namespaced_example_pb2 import \ - NamespacedExample + from foo.bar.namespaced.upper.example.namespaced_example_pb2 import ( + NamespacedExample, + ) + namespaced_example = NamespacedExample() namespaced_example.value = "hello" # Superfluous assert, important part is namespaced example was imported. self.assertEqual(namespaced_example.value, "hello") def test_grpc(self): - from foo.bar.namespaced.upper.example.namespaced_example_pb2_grpc import \ - NamespacedServiceStub + from foo.bar.namespaced.upper.example.namespaced_example_pb2_grpc import ( + NamespacedServiceStub, + ) # No error from import self.assertEqual(1, 1) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main() diff --git a/test/distrib/bazel/python/namespaced/upper/example/import_strip_test.py b/test/distrib/bazel/python/namespaced/upper/example/import_strip_test.py index 1d0d14f48a246..56b9975d95037 100644 --- a/test/distrib/bazel/python/namespaced/upper/example/import_strip_test.py +++ b/test/distrib/bazel/python/namespaced/upper/example/import_strip_test.py @@ -17,9 +17,9 @@ class ImportTest(unittest.TestCase): - def test_import(self): from foo.bar.namespaced_example_pb2 import NamespacedExample + namespaced_example = NamespacedExample() namespaced_example.value = "hello" # Superfluous assert, important part is namespaced example was imported. 
@@ -32,6 +32,6 @@ def test_grpc(self): self.assertEqual(1, 1) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main() diff --git a/test/distrib/bazel/python/namespaced/upper/example/no_import_no_strip_test.py b/test/distrib/bazel/python/namespaced/upper/example/no_import_no_strip_test.py index 6702e52ed1162..d745c4ea34c29 100644 --- a/test/distrib/bazel/python/namespaced/upper/example/no_import_no_strip_test.py +++ b/test/distrib/bazel/python/namespaced/upper/example/no_import_no_strip_test.py @@ -17,23 +17,25 @@ class ImportTest(unittest.TestCase): - def test_import(self): - from namespaced.upper.example.namespaced_example_pb2 import \ - NamespacedExample + from namespaced.upper.example.namespaced_example_pb2 import ( + NamespacedExample, + ) + namespaced_example = NamespacedExample() namespaced_example.value = "hello" # Superfluous assert, important part is namespaced example was imported. self.assertEqual(namespaced_example.value, "hello") def test_grpc(self): - from namespaced.upper.example.namespaced_example_pb2_grpc import \ - NamespacedServiceStub + from namespaced.upper.example.namespaced_example_pb2_grpc import ( + NamespacedServiceStub, + ) # No error from import self.assertEqual(1, 1) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main() diff --git a/test/distrib/bazel/python/namespaced/upper/example/no_import_strip_test.py b/test/distrib/bazel/python/namespaced/upper/example/no_import_strip_test.py index 8e756a04f9fe8..190241a1f8390 100644 --- a/test/distrib/bazel/python/namespaced/upper/example/no_import_strip_test.py +++ b/test/distrib/bazel/python/namespaced/upper/example/no_import_strip_test.py @@ -17,9 +17,9 @@ class ImportTest(unittest.TestCase): - def test_import(self): from namespaced_example_pb2 import NamespacedExample + namespaced_example = NamespacedExample() namespaced_example.value = "hello" # Superfluous assert, important part is namespaced example was imported. @@ -32,6 +32,6 @@ def test_grpc(self): self.assertEqual(1, 1) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() unittest.main() diff --git a/test/distrib/gcf/python/main.py b/test/distrib/gcf/python/main.py index 8f649932b9f28..41cd56979fa1e 100644 --- a/test/distrib/gcf/python/main.py +++ b/test/distrib/gcf/python/main.py @@ -17,14 +17,14 @@ ps_client = pubsub_v1.PublisherClient() _PROJECT_ID = "grpc-testing" -_PUBSUB_TOPIC = 'gcf-distribtest-topic' +_PUBSUB_TOPIC = "gcf-distribtest-topic" @functions_framework.http def test_publish(request): topic_path = ps_client.topic_path(_PROJECT_ID, _PUBSUB_TOPIC) message = '{"function": "TEST"}' - message_bytes = message.encode('utf-8') + message_bytes = message.encode("utf-8") for _ in range(100): future = ps_client.publish(topic_path, data=message_bytes) diff --git a/test/distrib/python/distribtest.py b/test/distrib/python/distribtest.py index 0fc2854c7783e..586ec2f3b73e4 100644 --- a/test/distrib/python/distribtest.py +++ b/test/distrib/python/distribtest.py @@ -16,6 +16,6 @@ # This code doesn't do much but makes sure the native extension is loaded # which is what we are testing here. 
-channel = grpc.insecure_channel('localhost:1000') +channel = grpc.insecure_channel("localhost:1000") del channel -print('Success!') +print("Success!") diff --git a/test/http2_test/http2_base_server.py b/test/http2_test/http2_base_server.py index dff2bbdb87b2f..a665f9fae0911 100644 --- a/test/http2_test/http2_base_server.py +++ b/test/http2_test/http2_base_server.py @@ -28,18 +28,17 @@ class H2ProtocolBaseServer(twisted.internet.protocol.Protocol): - def __init__(self): self._conn = h2.connection.H2Connection(client_side=False) self._recv_buffer = {} self._handlers = {} - self._handlers['ConnectionMade'] = self.on_connection_made_default - self._handlers['DataReceived'] = self.on_data_received_default - self._handlers['WindowUpdated'] = self.on_window_update_default - self._handlers['RequestReceived'] = self.on_request_received_default - self._handlers['SendDone'] = self.on_send_done_default - self._handlers['ConnectionLost'] = self.on_connection_lost - self._handlers['PingAcknowledged'] = self.on_ping_acknowledged_default + self._handlers["ConnectionMade"] = self.on_connection_made_default + self._handlers["DataReceived"] = self.on_data_received_default + self._handlers["WindowUpdated"] = self.on_window_update_default + self._handlers["RequestReceived"] = self.on_request_received_default + self._handlers["SendDone"] = self.on_send_done_default + self._handlers["ConnectionLost"] = self.on_connection_lost + self._handlers["PingAcknowledged"] = self.on_ping_acknowledged_default self._stream_status = {} self._send_remaining = {} self._outstanding_pings = 0 @@ -48,19 +47,19 @@ def set_handlers(self, handlers): self._handlers = handlers def connectionMade(self): - self._handlers['ConnectionMade']() + self._handlers["ConnectionMade"]() def connectionLost(self, reason): - self._handlers['ConnectionLost'](reason) + self._handlers["ConnectionLost"](reason) def on_connection_made_default(self): - logging.info('Connection Made') + logging.info("Connection Made") self._conn.initiate_connection() self.transport.setTcpNoDelay(True) self.transport.write(self._conn.data_to_send()) def on_connection_lost(self, reason): - logging.info('Disconnected %s' % reason) + logging.info("Disconnected %s" % reason) def dataReceived(self, data): try: @@ -72,29 +71,36 @@ def dataReceived(self, data): if self._conn.data_to_send: self.transport.write(self._conn.data_to_send()) for event in events: - if isinstance(event, h2.events.RequestReceived - ) and self._handlers.has_key('RequestReceived'): - logging.info('RequestReceived Event for stream: %d' % - event.stream_id) - self._handlers['RequestReceived'](event) - elif isinstance(event, h2.events.DataReceived - ) and self._handlers.has_key('DataReceived'): - logging.info('DataReceived Event for stream: %d' % - event.stream_id) - self._handlers['DataReceived'](event) - elif isinstance(event, h2.events.WindowUpdated - ) and self._handlers.has_key('WindowUpdated'): - logging.info('WindowUpdated Event for stream: %d' % - event.stream_id) - self._handlers['WindowUpdated'](event) - elif isinstance(event, h2.events.PingAcknowledged - ) and self._handlers.has_key('PingAcknowledged'): - logging.info('PingAcknowledged Event') - self._handlers['PingAcknowledged'](event) + if isinstance( + event, h2.events.RequestReceived + ) and self._handlers.has_key("RequestReceived"): + logging.info( + "RequestReceived Event for stream: %d" % event.stream_id + ) + self._handlers["RequestReceived"](event) + elif isinstance( + event, h2.events.DataReceived + ) and 
self._handlers.has_key("DataReceived"): + logging.info( + "DataReceived Event for stream: %d" % event.stream_id + ) + self._handlers["DataReceived"](event) + elif isinstance( + event, h2.events.WindowUpdated + ) and self._handlers.has_key("WindowUpdated"): + logging.info( + "WindowUpdated Event for stream: %d" % event.stream_id + ) + self._handlers["WindowUpdated"](event) + elif isinstance( + event, h2.events.PingAcknowledged + ) and self._handlers.has_key("PingAcknowledged"): + logging.info("PingAcknowledged Event") + self._handlers["PingAcknowledged"](event) self.transport.write(self._conn.data_to_send()) def on_ping_acknowledged_default(self, event): - logging.info('ping acknowledged') + logging.info("ping acknowledged") self._outstanding_pings -= 1 def on_data_received_default(self, event): @@ -102,51 +108,53 @@ def on_data_received_default(self, event): self._recv_buffer[event.stream_id] += event.data def on_request_received_default(self, event): - self._recv_buffer[event.stream_id] = '' + self._recv_buffer[event.stream_id] = "" self._stream_id = event.stream_id self._stream_status[event.stream_id] = True self._conn.send_headers( stream_id=event.stream_id, headers=[ - (':status', '200'), - ('content-type', 'application/grpc'), - ('grpc-encoding', 'identity'), - ('grpc-accept-encoding', 'identity,deflate,gzip'), + (":status", "200"), + ("content-type", "application/grpc"), + ("grpc-encoding", "identity"), + ("grpc-accept-encoding", "identity,deflate,gzip"), ], ) self.transport.write(self._conn.data_to_send()) - def on_window_update_default(self, - _, - pad_length=None, - read_chunk_size=_READ_CHUNK_SIZE): + def on_window_update_default( + self, _, pad_length=None, read_chunk_size=_READ_CHUNK_SIZE + ): # try to resume sending on all active streams (update might be for connection) for stream_id in self._send_remaining: - self.default_send(stream_id, - pad_length=pad_length, - read_chunk_size=read_chunk_size) + self.default_send( + stream_id, + pad_length=pad_length, + read_chunk_size=read_chunk_size, + ) def send_reset_stream(self): self._conn.reset_stream(self._stream_id) self.transport.write(self._conn.data_to_send()) - def setup_send(self, - data_to_send, - stream_id, - pad_length=None, - read_chunk_size=_READ_CHUNK_SIZE): - logging.info('Setting up data to send for stream_id: %d' % stream_id) + def setup_send( + self, + data_to_send, + stream_id, + pad_length=None, + read_chunk_size=_READ_CHUNK_SIZE, + ): + logging.info("Setting up data to send for stream_id: %d" % stream_id) self._send_remaining[stream_id] = len(data_to_send) self._send_offset = 0 self._data_to_send = data_to_send - self.default_send(stream_id, - pad_length=pad_length, - read_chunk_size=read_chunk_size) - - def default_send(self, - stream_id, - pad_length=None, - read_chunk_size=_READ_CHUNK_SIZE): + self.default_send( + stream_id, pad_length=pad_length, read_chunk_size=read_chunk_size + ) + + def default_send( + self, stream_id, pad_length=None, read_chunk_size=_READ_CHUNK_SIZE + ): if not self._send_remaining.has_key(stream_id): # not setup to send data yet return @@ -156,40 +164,49 @@ def default_send(self, padding_bytes = pad_length + 1 if pad_length is not None else 0 if lfcw - padding_bytes <= 0: logging.info( - 'Stream %d. lfcw: %d. padding bytes: %d. not enough quota yet' - % (stream_id, lfcw, padding_bytes)) + "Stream %d. lfcw: %d. padding bytes: %d. 
not enough" + " quota yet" % (stream_id, lfcw, padding_bytes) + ) break chunk_size = min(lfcw - padding_bytes, read_chunk_size) bytes_to_send = min(chunk_size, self._send_remaining[stream_id]) logging.info( - 'flow_control_window = %d. sending [%d:%d] stream_id %d. includes %d total padding bytes' - % (lfcw, self._send_offset, self._send_offset + bytes_to_send + - padding_bytes, stream_id, padding_bytes)) + "flow_control_window = %d. sending [%d:%d] stream_id %d." + " includes %d total padding bytes" + % ( + lfcw, + self._send_offset, + self._send_offset + bytes_to_send + padding_bytes, + stream_id, + padding_bytes, + ) + ) # The receiver might allow sending frames larger than the http2 minimum # max frame size (16384), but this test should never send more than 16384 # for simplicity (which is always legal). if bytes_to_send + padding_bytes > _MIN_SETTINGS_MAX_FRAME_SIZE: - raise ValueError("overload: sending %d" % - (bytes_to_send + padding_bytes)) - data = self._data_to_send[self._send_offset:self._send_offset + - bytes_to_send] + raise ValueError( + "overload: sending %d" % (bytes_to_send + padding_bytes) + ) + data = self._data_to_send[ + self._send_offset : self._send_offset + bytes_to_send + ] try: - self._conn.send_data(stream_id, - data, - end_stream=False, - pad_length=pad_length) + self._conn.send_data( + stream_id, data, end_stream=False, pad_length=pad_length + ) except h2.exceptions.ProtocolError: - logging.info('Stream %d is closed' % stream_id) + logging.info("Stream %d is closed" % stream_id) break self._send_remaining[stream_id] -= bytes_to_send self._send_offset += bytes_to_send if self._send_remaining[stream_id] == 0: - self._handlers['SendDone'](stream_id) + self._handlers["SendDone"](stream_id) def default_ping(self): - logging.info('sending ping') + logging.info("sending ping") self._outstanding_pings += 1 - self._conn.ping(b'\x00' * 8) + self._conn.ping(b"\x00" * 8) self.transport.write(self._conn.data_to_send()) def on_send_done_default(self, stream_id): @@ -197,33 +214,36 @@ def on_send_done_default(self, stream_id): self._stream_status[stream_id] = False self.default_send_trailer(stream_id) else: - logging.error('Stream %d is already closed' % stream_id) + logging.error("Stream %d is already closed" % stream_id) def default_send_trailer(self, stream_id): - logging.info('Sending trailer for stream id %d' % stream_id) - self._conn.send_headers(stream_id, - headers=[('grpc-status', '0')], - end_stream=True) + logging.info("Sending trailer for stream id %d" % stream_id) + self._conn.send_headers( + stream_id, headers=[("grpc-status", "0")], end_stream=True + ) self.transport.write(self._conn.data_to_send()) @staticmethod def default_response_data(response_size): sresp = messages_pb2.SimpleResponse() - sresp.payload.body = b'\x00' * response_size + sresp.payload.body = b"\x00" * response_size serialized_resp_proto = sresp.SerializeToString() - response_data = b'\x00' + struct.pack( - 'i', len(serialized_resp_proto))[::-1] + serialized_resp_proto + response_data = ( + b"\x00" + + struct.pack("i", len(serialized_resp_proto))[::-1] + + serialized_resp_proto + ) return response_data def parse_received_data(self, stream_id): - """ returns a grpc framed string of bytes containing response proto of the size - asked in request """ + """returns a grpc framed string of bytes containing response proto of the size + asked in request""" recv_buffer = self._recv_buffer[stream_id] - grpc_msg_size = struct.unpack('i', recv_buffer[1:5][::-1])[0] + grpc_msg_size = struct.unpack("i", 
recv_buffer[1:5][::-1])[0] if len(recv_buffer) != _GRPC_HEADER_SIZE + grpc_msg_size: return None - req_proto_str = recv_buffer[5:5 + grpc_msg_size] + req_proto_str = recv_buffer[5 : 5 + grpc_msg_size] sr = messages_pb2.SimpleRequest() sr.ParseFromString(req_proto_str) - logging.info('Parsed simple request for stream %d' % stream_id) + logging.info("Parsed simple request for stream %d" % stream_id) return sr diff --git a/test/http2_test/http2_server_health_check.py b/test/http2_test/http2_server_health_check.py index 11cf40793e392..c91e4aa395a69 100644 --- a/test/http2_test/http2_server_health_check.py +++ b/test/http2_test/http2_server_health_check.py @@ -19,17 +19,17 @@ # Utility to healthcheck the http2 server. Used when starting the server to # verify that the server is live before tests begin. -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--server_host', type=str, default='localhost') - parser.add_argument('--server_port', type=int, default=8080) + parser.add_argument("--server_host", type=str, default="localhost") + parser.add_argument("--server_port", type=int, default=8080) args = parser.parse_args() server_host = args.server_host server_port = args.server_port - conn = hyper.HTTP20Connection('%s:%d' % (server_host, server_port)) - conn.request('POST', '/grpc.testing.TestService/UnaryCall') + conn = hyper.HTTP20Connection("%s:%d" % (server_host, server_port)) + conn.request("POST", "/grpc.testing.TestService/UnaryCall") resp = conn.get_response() - if resp.headers.get('grpc-encoding') is None: + if resp.headers.get("grpc-encoding") is None: sys.exit(1) else: sys.exit(0) diff --git a/test/http2_test/http2_test_server.py b/test/http2_test/http2_test_server.py index ed402e09704e3..64fd04aa0a20c 100644 --- a/test/http2_test/http2_test_server.py +++ b/test/http2_test/http2_test_server.py @@ -31,40 +31,38 @@ import twisted.internet.reactor _TEST_CASE_MAPPING = { - 'rst_after_header': test_rst_after_header.TestcaseRstStreamAfterHeader, - 'rst_after_data': test_rst_after_data.TestcaseRstStreamAfterData, - 'rst_during_data': test_rst_during_data.TestcaseRstStreamDuringData, - 'goaway': test_goaway.TestcaseGoaway, - 'ping': test_ping.TestcasePing, - 'max_streams': test_max_streams.TestcaseSettingsMaxStreams, - + "rst_after_header": test_rst_after_header.TestcaseRstStreamAfterHeader, + "rst_after_data": test_rst_after_data.TestcaseRstStreamAfterData, + "rst_during_data": test_rst_during_data.TestcaseRstStreamDuringData, + "goaway": test_goaway.TestcaseGoaway, + "ping": test_ping.TestcasePing, + "max_streams": test_max_streams.TestcaseSettingsMaxStreams, # Positive tests below: - 'data_frame_padding': test_data_frame_padding.TestDataFramePadding, - 'no_df_padding_sanity_test': test_data_frame_padding.TestDataFramePadding, + "data_frame_padding": test_data_frame_padding.TestDataFramePadding, + "no_df_padding_sanity_test": test_data_frame_padding.TestDataFramePadding, } _exit_code = 0 class H2Factory(twisted.internet.protocol.Factory): - def __init__(self, testcase): - logging.info('Creating H2Factory for new connection (%s)', testcase) + logging.info("Creating H2Factory for new connection (%s)", testcase) self._num_streams = 0 self._testcase = testcase def buildProtocol(self, addr): self._num_streams += 1 - logging.info('New Connection: %d' % self._num_streams) + logging.info("New Connection: %d" % self._num_streams) if not _TEST_CASE_MAPPING.has_key(self._testcase): - logging.error('Unknown test case: %s' % self._testcase) - 
assert (0) + logging.error("Unknown test case: %s" % self._testcase) + assert 0 else: t = _TEST_CASE_MAPPING[self._testcase] - if self._testcase == 'goaway': + if self._testcase == "goaway": return t(self._num_streams).get_base_server() - elif self._testcase == 'no_df_padding_sanity_test': + elif self._testcase == "no_df_padding_sanity_test": return t(use_padding=False).get_base_server() else: return t().get_base_server() @@ -73,14 +71,18 @@ def buildProtocol(self, addr): def parse_arguments(): parser = argparse.ArgumentParser() parser.add_argument( - '--base_port', + "--base_port", type=int, default=8080, - help='base port to run the servers (default: 8080). One test server is ' - 'started on each incrementing port, beginning with base_port, in the ' - 'following order: data_frame_padding,goaway,max_streams,' - 'no_df_padding_sanity_test,ping,rst_after_data,rst_after_header,' - 'rst_during_data') + help=( + "base port to run the servers (default: 8080). One test server is " + "started on each incrementing port, beginning with base_port, in" + " the " + "following order: data_frame_padding,goaway,max_streams," + "no_df_padding_sanity_test,ping,rst_after_data,rst_after_header," + "rst_during_data" + ), + ) return parser.parse_args() @@ -92,32 +94,35 @@ def listen_error(reason): # with exit code 1. global _exit_code _exit_code = 1 - logging.error('Listening failed: %s' % reason.value) + logging.error("Listening failed: %s" % reason.value) twisted.internet.reactor.stop() deferred.addErrback(listen_error) def start_test_servers(base_port): - """ Start one server per test case on incrementing port numbers - beginning with base_port """ + """Start one server per test case on incrementing port numbers + beginning with base_port""" index = 0 for test_case in sorted(_TEST_CASE_MAPPING.keys()): portnum = base_port + index - logging.warning('serving on port %d : %s' % (portnum, test_case)) + logging.warning("serving on port %d : %s" % (portnum, test_case)) endpoint = twisted.internet.endpoints.TCP4ServerEndpoint( - twisted.internet.reactor, portnum, backlog=128) + twisted.internet.reactor, portnum, backlog=128 + ) # Wait until the reactor is running before calling endpoint.listen(). twisted.internet.reactor.callWhenRunning(listen, endpoint, test_case) index += 1 -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig( - format= - '%(levelname) -10s %(asctime)s %(module)s:%(lineno)s | %(message)s', - level=logging.INFO) + format=( + "%(levelname) -10s %(asctime)s %(module)s:%(lineno)s | %(message)s" + ), + level=logging.INFO, + ) args = parse_arguments() start_test_servers(args.base_port) twisted.internet.reactor.run() diff --git a/test/http2_test/test_data_frame_padding.py b/test/http2_test/test_data_frame_padding.py index cfd86956e878e..65d89d6a2fef9 100644 --- a/test/http2_test/test_data_frame_padding.py +++ b/test/http2_test/test_data_frame_padding.py @@ -28,14 +28,15 @@ class TestDataFramePadding(object): In response to an incoming request, this test sends headers, followed by data, followed by a reset stream frame. Client asserts that the RPC failed. Client needs to deliver the complete message to the application layer. 
- """ + """ def __init__(self, use_padding=True): self._base_server = http2_base_server.H2ProtocolBaseServer() - self._base_server._handlers['DataReceived'] = self.on_data_received - self._base_server._handlers['WindowUpdated'] = self.on_window_update + self._base_server._handlers["DataReceived"] = self.on_data_received + self._base_server._handlers["WindowUpdated"] = self.on_window_update self._base_server._handlers[ - 'RequestReceived'] = self.on_request_received + "RequestReceived" + ] = self.on_request_received # _total_updates maps stream ids to total flow control updates received self._total_updates = {} @@ -52,43 +53,54 @@ def get_base_server(self): return self._base_server def on_data_received(self, event): - logging.info('on data received. Stream id: %d. Data length: %d' % - (event.stream_id, len(event.data))) + logging.info( + "on data received. Stream id: %d. Data length: %d" + % (event.stream_id, len(event.data)) + ) self._base_server.on_data_received_default(event) if len(event.data) == 0: return sr = self._base_server.parse_received_data(event.stream_id) - stream_bytes = '' + stream_bytes = "" # Check if full grpc msg has been read into the recv buffer yet if sr: response_data = self._base_server.default_response_data( - sr.response_size) - logging.info('Stream id: %d. total resp size: %d' % - (event.stream_id, len(response_data))) + sr.response_size + ) + logging.info( + "Stream id: %d. total resp size: %d" + % (event.stream_id, len(response_data)) + ) # Begin sending the response. Add ``self._pad_length`` padding to each # data frame and split the whole message into data frames each carrying # only self._read_chunk_size of data. # The purpose is to have the majority of the data frame response bytes # be padding bytes, since ``self._pad_length`` >> ``self._read_chunk_size``. - self._base_server.setup_send(response_data, - event.stream_id, - pad_length=self._pad_length, - read_chunk_size=self._read_chunk_size) + self._base_server.setup_send( + response_data, + event.stream_id, + pad_length=self._pad_length, + read_chunk_size=self._read_chunk_size, + ) def on_request_received(self, event): self._base_server.on_request_received_default(event) - logging.info('on request received. Stream id: %s.' % event.stream_id) + logging.info("on request received. Stream id: %s." % event.stream_id) self._total_updates[event.stream_id] = 0 # Log debug info and try to resume sending on all currently active streams. def on_window_update(self, event): - logging.info('on window update. Stream id: %s. Delta: %s' % - (event.stream_id, event.delta)) + logging.info( + "on window update. Stream id: %s. Delta: %s" + % (event.stream_id, event.delta) + ) self._total_updates[event.stream_id] += event.delta total = self._total_updates[event.stream_id] - logging.info('... - total updates for stream %d : %d' % - (event.stream_id, total)) + logging.info( + "... - total updates for stream %d : %d" % (event.stream_id, total) + ) self._base_server.on_window_update_default( event, pad_length=self._pad_length, - read_chunk_size=self._read_chunk_size) + read_chunk_size=self._read_chunk_size, + ) diff --git a/test/http2_test/test_goaway.py b/test/http2_test/test_goaway.py index 8e9664c95d566..3c6323e40e5ee 100644 --- a/test/http2_test/test_goaway.py +++ b/test/http2_test/test_goaway.py @@ -19,20 +19,21 @@ class TestcaseGoaway(object): - """ + """ This test does the following: Process incoming request normally, i.e. send headers, data and trailers. Then send a GOAWAY frame with the stream id of the processed request. 
It checks that the next request is made on a different TCP connection. - """ + """ def __init__(self, iteration): self._base_server = http2_base_server.H2ProtocolBaseServer() self._base_server._handlers[ - 'RequestReceived'] = self.on_request_received - self._base_server._handlers['DataReceived'] = self.on_data_received - self._base_server._handlers['SendDone'] = self.on_send_done - self._base_server._handlers['ConnectionLost'] = self.on_connection_lost + "RequestReceived" + ] = self.on_request_received + self._base_server._handlers["DataReceived"] = self.on_data_received + self._base_server._handlers["SendDone"] = self.on_send_done + self._base_server._handlers["ConnectionLost"] = self.on_connection_lost self._ready_to_send = False self._iteration = iteration @@ -40,17 +41,17 @@ def get_base_server(self): return self._base_server def on_connection_lost(self, reason): - logging.info('Disconnect received. Count %d' % self._iteration) + logging.info("Disconnect received. Count %d" % self._iteration) # _iteration == 2 => Two different connections have been used. if self._iteration == 2: self._base_server.on_connection_lost(reason) def on_send_done(self, stream_id): self._base_server.on_send_done_default(stream_id) - logging.info('Sending GOAWAY for stream %d:' % stream_id) - self._base_server._conn.close_connection(error_code=0, - additional_data=None, - last_stream_id=stream_id) + logging.info("Sending GOAWAY for stream %d:" % stream_id) + self._base_server._conn.close_connection( + error_code=0, additional_data=None, last_stream_id=stream_id + ) self._base_server._stream_status[stream_id] = False def on_request_received(self, event): @@ -61,8 +62,9 @@ def on_data_received(self, event): self._base_server.on_data_received_default(event) sr = self._base_server.parse_received_data(event.stream_id) if sr: - logging.info('Creating response size = %s' % sr.response_size) + logging.info("Creating response size = %s" % sr.response_size) response_data = self._base_server.default_response_data( - sr.response_size) + sr.response_size + ) self._ready_to_send = True self._base_server.setup_send(response_data, event.stream_id) diff --git a/test/http2_test/test_max_streams.py b/test/http2_test/test_max_streams.py index 62538433ae1b1..c0a8cfd4b3db3 100644 --- a/test/http2_test/test_max_streams.py +++ b/test/http2_test/test_max_streams.py @@ -22,31 +22,34 @@ class TestcaseSettingsMaxStreams(object): """ This test sets MAX_CONCURRENT_STREAMS to 1 and asserts that at any point only 1 stream is active. 
- """ + """ def __init__(self): self._base_server = http2_base_server.H2ProtocolBaseServer() - self._base_server._handlers['DataReceived'] = self.on_data_received - self._base_server._handlers['ConnectionMade'] = self.on_connection_made + self._base_server._handlers["DataReceived"] = self.on_data_received + self._base_server._handlers["ConnectionMade"] = self.on_connection_made def get_base_server(self): return self._base_server def on_connection_made(self): - logging.info('Connection Made') + logging.info("Connection Made") self._base_server._conn.initiate_connection() self._base_server._conn.update_settings( - {hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS: 1}) + {hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS: 1} + ) self._base_server.transport.setTcpNoDelay(True) self._base_server.transport.write( - self._base_server._conn.data_to_send()) + self._base_server._conn.data_to_send() + ) def on_data_received(self, event): self._base_server.on_data_received_default(event) sr = self._base_server.parse_received_data(event.stream_id) if sr: - logging.info('Creating response of size = %s' % sr.response_size) + logging.info("Creating response of size = %s" % sr.response_size) response_data = self._base_server.default_response_data( - sr.response_size) + sr.response_size + ) self._base_server.setup_send(response_data, event.stream_id) # TODO (makdharma): Add assertion to check number of live streams diff --git a/test/http2_test/test_ping.py b/test/http2_test/test_ping.py index 4bf1d6c4067d9..49db39332cdf1 100644 --- a/test/http2_test/test_ping.py +++ b/test/http2_test/test_ping.py @@ -22,14 +22,15 @@ class TestcasePing(object): This test injects PING frames before and after header and data. Keeps count of outstanding ping response and asserts when the count is non-zero at the end of the test. - """ + """ def __init__(self): self._base_server = http2_base_server.H2ProtocolBaseServer() self._base_server._handlers[ - 'RequestReceived'] = self.on_request_received - self._base_server._handlers['DataReceived'] = self.on_data_received - self._base_server._handlers['ConnectionLost'] = self.on_connection_lost + "RequestReceived" + ] = self.on_request_received + self._base_server._handlers["DataReceived"] = self.on_data_received + self._base_server._handlers["ConnectionLost"] = self.on_connection_lost def get_base_server(self): return self._base_server @@ -43,15 +44,18 @@ def on_data_received(self, event): self._base_server.on_data_received_default(event) sr = self._base_server.parse_received_data(event.stream_id) if sr: - logging.info('Creating response size = %s' % sr.response_size) + logging.info("Creating response size = %s" % sr.response_size) response_data = self._base_server.default_response_data( - sr.response_size) + sr.response_size + ) self._base_server.default_ping() self._base_server.setup_send(response_data, event.stream_id) self._base_server.default_ping() def on_connection_lost(self, reason): - logging.info('Disconnect received. Ping Count %d' % - self._base_server._outstanding_pings) - assert (self._base_server._outstanding_pings == 0) + logging.info( + "Disconnect received. 
Ping Count %d" + % self._base_server._outstanding_pings + ) + assert self._base_server._outstanding_pings == 0 self._base_server.on_connection_lost(reason) diff --git a/test/http2_test/test_rst_after_data.py b/test/http2_test/test_rst_after_data.py index c867dc7a4406b..fb4d457a1f36d 100644 --- a/test/http2_test/test_rst_after_data.py +++ b/test/http2_test/test_rst_after_data.py @@ -20,12 +20,12 @@ class TestcaseRstStreamAfterData(object): In response to an incoming request, this test sends headers, followed by data, followed by a reset stream frame. Client asserts that the RPC failed. Client needs to deliver the complete message to the application layer. - """ + """ def __init__(self): self._base_server = http2_base_server.H2ProtocolBaseServer() - self._base_server._handlers['DataReceived'] = self.on_data_received - self._base_server._handlers['SendDone'] = self.on_send_done + self._base_server._handlers["DataReceived"] = self.on_data_received + self._base_server._handlers["SendDone"] = self.on_send_done def get_base_server(self): return self._base_server @@ -35,7 +35,8 @@ def on_data_received(self, event): sr = self._base_server.parse_received_data(event.stream_id) if sr: response_data = self._base_server.default_response_data( - sr.response_size) + sr.response_size + ) self._ready_to_send = True self._base_server.setup_send(response_data, event.stream_id) # send reset stream diff --git a/test/http2_test/test_rst_after_header.py b/test/http2_test/test_rst_after_header.py index 1e2ddcbd4cf27..70c3d251b1775 100644 --- a/test/http2_test/test_rst_after_header.py +++ b/test/http2_test/test_rst_after_header.py @@ -19,12 +19,13 @@ class TestcaseRstStreamAfterHeader(object): """ In response to an incoming request, this test sends headers, followed by a reset stream frame. Client asserts that the RPC failed. - """ + """ def __init__(self): self._base_server = http2_base_server.H2ProtocolBaseServer() self._base_server._handlers[ - 'RequestReceived'] = self.on_request_received + "RequestReceived" + ] = self.on_request_received def get_base_server(self): return self._base_server diff --git a/test/http2_test/test_rst_during_data.py b/test/http2_test/test_rst_during_data.py index c34954e4f69cf..d3fe3737b6c92 100644 --- a/test/http2_test/test_rst_during_data.py +++ b/test/http2_test/test_rst_during_data.py @@ -20,12 +20,12 @@ class TestcaseRstStreamDuringData(object): In response to an incoming request, this test sends headers, followed by some data, followed by a reset stream frame. Client asserts that the RPC failed and does not deliver the message to the application. 
- """ + """ def __init__(self): self._base_server = http2_base_server.H2ProtocolBaseServer() - self._base_server._handlers['DataReceived'] = self.on_data_received - self._base_server._handlers['SendDone'] = self.on_send_done + self._base_server._handlers["DataReceived"] = self.on_data_received + self._base_server._handlers["SendDone"] = self.on_send_done def get_base_server(self): return self._base_server @@ -35,12 +35,14 @@ def on_data_received(self, event): sr = self._base_server.parse_received_data(event.stream_id) if sr: response_data = self._base_server.default_response_data( - sr.response_size) + sr.response_size + ) self._ready_to_send = True response_len = len(response_data) - truncated_response_data = response_data[0:response_len / 2] - self._base_server.setup_send(truncated_response_data, - event.stream_id) + truncated_response_data = response_data[0 : response_len / 2] + self._base_server.setup_send( + truncated_response_data, event.stream_id + ) def on_send_done(self, stream_id): self._base_server.send_reset_stream() diff --git a/tools/buildgen/_mako_renderer.py b/tools/buildgen/_mako_renderer.py index 8a46fbe49dbe5..13099603c3518 100755 --- a/tools/buildgen/_mako_renderer.py +++ b/tools/buildgen/_mako_renderer.py @@ -32,10 +32,11 @@ from mako.template import Template import yaml -PROJECT_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", - "..") +PROJECT_ROOT = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "..", ".." +) # TODO(lidiz) find a better way for plugins to reference each other -sys.path.append(os.path.join(PROJECT_ROOT, 'tools', 'buildgen', 'plugins')) +sys.path.append(os.path.join(PROJECT_ROOT, "tools", "buildgen", "plugins")) def out(msg: str) -> None: @@ -43,8 +44,10 @@ def out(msg: str) -> None: def showhelp() -> None: - out('mako-renderer.py [-o out] [-m cache] [-P preprocessed_input] [-d dict] [-d dict...]' - ' [-t template] [-w preprocessed_output]') + out( + "mako-renderer.py [-o out] [-m cache] [-P preprocessed_input] [-d dict]" + " [-d dict...] 
[-t template] [-w preprocessed_output]" + ) def render_template(template: Template, context: Context) -> None: @@ -72,40 +75,40 @@ def main(argv: List[str]) -> None: output_merged = None try: - opts, args = getopt.getopt(argv, 'hM:m:o:t:P:') + opts, args = getopt.getopt(argv, "hM:m:o:t:P:") except getopt.GetoptError: - out('Unknown option') + out("Unknown option") showhelp() sys.exit(2) for opt, arg in opts: - if opt == '-h': - out('Displaying showhelp') + if opt == "-h": + out("Displaying showhelp") showhelp() sys.exit() - elif opt == '-o': + elif opt == "-o": if got_output: - out('Got more than one output') + out("Got more than one output") showhelp() sys.exit(3) got_output = True output_name = arg - elif opt == '-m': + elif opt == "-m": if module_directory is not None: - out('Got more than one cache directory') + out("Got more than one cache directory") showhelp() sys.exit(4) module_directory = arg - elif opt == '-M': + elif opt == "-M": if output_merged is not None: - out('Got more than one output merged path') + out("Got more than one output merged path") showhelp() sys.exit(5) output_merged = arg - elif opt == '-P': + elif opt == "-P": assert not got_preprocessed_input assert json_dict == {} - with open(arg, 'rb') as dict_file: + with open(arg, "rb") as dict_file: dictionary = pickle.load(dict_file) got_preprocessed_input = True @@ -117,13 +120,16 @@ def main(argv: List[str]) -> None: for src in srcs: if isinstance(src, str): assert len(srcs) == 1 - template = Template(src, - filename=arg, - module_directory=module_directory, - lookup=TemplateLookup(directories=['.'])) - with open(output_name, 'w') as output_file: - render_template(template, Context(output_file, - **dictionary)) + template = Template( + src, + filename=arg, + module_directory=module_directory, + lookup=TemplateLookup(directories=["."]), + ) + with open(output_name, "w") as output_file: + render_template( + template, Context(output_file, **dictionary) + ) else: # we have optional control data: this template represents # a directory @@ -136,12 +142,12 @@ def main(argv: List[str]) -> None: shutil.rmtree(output_name, ignore_errors=True) cleared_dir = True items = [] - if 'foreach' in src: - for el in dictionary[src['foreach']]: - if 'cond' in src: + if "foreach" in src: + for el in dictionary[src["foreach"]]: + if "cond" in src: args = dict(dictionary) - args['selected'] = el - if not eval(src['cond'], {}, args): + args["selected"] = el + if not eval(src["cond"], {}, args): continue items.append(el) assert items @@ -149,24 +155,25 @@ def main(argv: List[str]) -> None: items = [None] for item in items: args = dict(dictionary) - args['selected'] = item + args["selected"] = item item_output_name = os.path.join( - output_name, - Template(src['output_name']).render(**args)) + output_name, Template(src["output_name"]).render(**args) + ) if not os.path.exists(os.path.dirname(item_output_name)): os.makedirs(os.path.dirname(item_output_name)) template = Template( - src['template'], + src["template"], filename=arg, module_directory=module_directory, - lookup=TemplateLookup(directories=['.'])) - with open(item_output_name, 'w') as output_file: + lookup=TemplateLookup(directories=["."]), + ) + with open(item_output_name, "w") as output_file: render_template(template, Context(output_file, **args)) if not got_input and not preprocessed_output: - out('Got nothing to do') + out("Got nothing to do") showhelp() -if __name__ == '__main__': +if __name__ == "__main__": main(sys.argv[1:]) diff --git a/tools/buildgen/_utils.py 
b/tools/buildgen/_utils.py index c3dfb8d5c8a92..e6634e8a4175d 100644 --- a/tools/buildgen/_utils.py +++ b/tools/buildgen/_utils.py @@ -23,7 +23,7 @@ def import_python_module(path: str) -> types.ModuleType: """Imports the Python file at the given path, returns a module object.""" - module_name = os.path.basename(path).replace('.py', '') + module_name = os.path.basename(path).replace(".py", "") spec = importlib.util.spec_from_file_location(module_name, path) module = importlib.util.module_from_spec(spec) sys.modules[module_name] = module @@ -59,7 +59,7 @@ def merge_json(dst: Union[Mapping, List], add: Union[Mapping, List]) -> None: if isinstance(dst, dict) and isinstance(add, dict): for k, v in list(add.items()): if k in dst: - if k.startswith('#'): + if k.startswith("#"): continue merge_json(dst[k], v) else: @@ -68,5 +68,6 @@ def merge_json(dst: Union[Mapping, List], add: Union[Mapping, List]) -> None: dst.extend(add) else: raise TypeError( - 'Tried to merge incompatible objects %s %s\n\n%r\n\n%r' % - (type(dst).__name__, type(add).__name__, dst, add)) + "Tried to merge incompatible objects %s %s\n\n%r\n\n%r" + % (type(dst).__name__, type(add).__name__, dst, add) + ) diff --git a/tools/buildgen/build_cleaner.py b/tools/buildgen/build_cleaner.py index 72546631c47ff..6071ec5053537 100755 --- a/tools/buildgen/build_cleaner.py +++ b/tools/buildgen/build_cleaner.py @@ -21,20 +21,35 @@ import yaml -TEST = (os.environ.get('TEST', 'false') == 'true') +TEST = os.environ.get("TEST", "false") == "true" _TOP_LEVEL_KEYS = [ - 'settings', 'proto_deps', 'filegroups', 'libs', 'targets', 'vspackages' + "settings", + "proto_deps", + "filegroups", + "libs", + "targets", + "vspackages", ] _ELEM_KEYS = [ - 'name', 'gtest', 'cpu_cost', 'flaky', 'build', 'run', 'language', - 'public_headers', 'headers', 'src', 'deps' + "name", + "gtest", + "cpu_cost", + "flaky", + "build", + "run", + "language", + "public_headers", + "headers", + "src", + "deps", ] def repr_ordered_dict(dumper, odict): - return dumper.represent_mapping('tag:yaml.org,2002:map', - list(odict.items())) + return dumper.represent_mapping( + "tag:yaml.org,2002:map", list(odict.items()) + ) yaml.add_representer(collections.OrderedDict, repr_ordered_dict) @@ -43,7 +58,7 @@ def repr_ordered_dict(dumper, odict): def _rebuild_as_ordered_dict(indict, special_keys): outdict = collections.OrderedDict() for key in sorted(indict.keys()): - if '#' in key: + if "#" in key: outdict[key] = indict[key] for key in special_keys: if key in indict: @@ -51,18 +66,18 @@ def _rebuild_as_ordered_dict(indict, special_keys): for key in sorted(indict.keys()): if key in special_keys: continue - if '#' in key: + if "#" in key: continue outdict[key] = indict[key] return outdict def _clean_elem(indict): - for name in ['public_headers', 'headers', 'src']: + for name in ["public_headers", "headers", "src"]: if name not in indict: continue inlist = indict[name] - protos = list(x for x in inlist if os.path.splitext(x)[1] == '.proto') + protos = list(x for x in inlist if os.path.splitext(x)[1] == ".proto") others = set(x for x in inlist if x not in protos) indict[name] = protos + sorted(others) return _rebuild_as_ordered_dict(indict, _ELEM_KEYS) @@ -71,21 +86,23 @@ def _clean_elem(indict): def cleaned_build_yaml_dict_as_string(indict): """Takes dictionary which represents yaml file and returns the cleaned-up yaml string""" js = _rebuild_as_ordered_dict(indict, _TOP_LEVEL_KEYS) - for grp in ['filegroups', 'libs', 'targets']: + for grp in ["filegroups", "libs", "targets"]: if grp not in 
js: continue - js[grp] = sorted([_clean_elem(x) for x in js[grp]], - key=lambda x: (x.get('language', '_'), x['name'])) + js[grp] = sorted( + [_clean_elem(x) for x in js[grp]], + key=lambda x: (x.get("language", "_"), x["name"]), + ) output = yaml.dump(js, indent=2, width=80, default_flow_style=False) # massage out trailing whitespace lines = [] for line in output.splitlines(): - lines.append(line.rstrip() + '\n') - output = ''.join(lines) + lines.append(line.rstrip() + "\n") + output = "".join(lines) return output -if __name__ == '__main__': +if __name__ == "__main__": for filename in sys.argv[1:]: with open(filename) as f: js = yaml.safe_load(f) @@ -94,8 +111,9 @@ def cleaned_build_yaml_dict_as_string(indict): with open(filename) as f: if not f.read() == output: raise Exception( - 'Looks like build-cleaner.py has not been run for file "%s"?' - % filename) + "Looks like build-cleaner.py has not been run for file" + ' "%s"?' % filename + ) else: - with open(filename, 'w') as f: + with open(filename, "w") as f: f.write(output) diff --git a/tools/buildgen/extract_metadata_from_bazel_xml.py b/tools/buildgen/extract_metadata_from_bazel_xml.py index eb89e77040621..044eddf765162 100755 --- a/tools/buildgen/extract_metadata_from_bazel_xml.py +++ b/tools/buildgen/extract_metadata_from_bazel_xml.py @@ -63,12 +63,9 @@ class ExternalProtoLibrary: http_archive in Bazel. """ - def __init__(self, - destination, - proto_prefix, - urls=None, - hash="", - strip_prefix=""): + def __init__( + self, destination, proto_prefix, urls=None, hash="", strip_prefix="" + ): self.destination = destination self.proto_prefix = proto_prefix if urls is None: @@ -80,24 +77,27 @@ def __init__(self, EXTERNAL_PROTO_LIBRARIES = { - 'envoy_api': - ExternalProtoLibrary(destination='third_party/envoy-api', - proto_prefix='third_party/envoy-api/'), - 'com_google_googleapis': - ExternalProtoLibrary(destination='third_party/googleapis', - proto_prefix='third_party/googleapis/'), - 'com_github_cncf_udpa': - ExternalProtoLibrary(destination='third_party/xds', - proto_prefix='third_party/xds/'), - 'opencensus_proto': - ExternalProtoLibrary(destination='third_party/opencensus-proto/src', - proto_prefix='third_party/opencensus-proto/src/'), + "envoy_api": ExternalProtoLibrary( + destination="third_party/envoy-api", + proto_prefix="third_party/envoy-api/", + ), + "com_google_googleapis": ExternalProtoLibrary( + destination="third_party/googleapis", + proto_prefix="third_party/googleapis/", + ), + "com_github_cncf_udpa": ExternalProtoLibrary( + destination="third_party/xds", proto_prefix="third_party/xds/" + ), + "opencensus_proto": ExternalProtoLibrary( + destination="third_party/opencensus-proto/src", + proto_prefix="third_party/opencensus-proto/src/", + ), } def _maybe_get_internal_path(name: str) -> Optional[str]: for key in EXTERNAL_PROTO_LIBRARIES: - if name.startswith('@' + key): + if name.startswith("@" + key): return key return None @@ -105,51 +105,52 @@ def _maybe_get_internal_path(name: str) -> Optional[str]: def _bazel_query_xml_tree(query: str) -> ET.Element: """Get xml output of bazel query invocation, parsed as XML tree""" output = subprocess.check_output( - ['tools/bazel', 'query', '--noimplicit_deps', '--output', 'xml', query]) + ["tools/bazel", "query", "--noimplicit_deps", "--output", "xml", query] + ) return ET.fromstring(output) def _rule_dict_from_xml_node(rule_xml_node): """Converts XML node representing a rule (obtained from "bazel query --output xml") to a dictionary that contains all the metadata we will need.""" 
result = { - 'class': rule_xml_node.attrib.get('class'), - 'name': rule_xml_node.attrib.get('name'), - 'srcs': [], - 'hdrs': [], - 'deps': [], - 'data': [], - 'tags': [], - 'args': [], - 'generator_function': None, - 'size': None, - 'flaky': False, - 'actual': None, # the real target name for aliases + "class": rule_xml_node.attrib.get("class"), + "name": rule_xml_node.attrib.get("name"), + "srcs": [], + "hdrs": [], + "deps": [], + "data": [], + "tags": [], + "args": [], + "generator_function": None, + "size": None, + "flaky": False, + "actual": None, # the real target name for aliases } for child in rule_xml_node: # all the metadata we want is stored under "list" tags - if child.tag == 'list': - list_name = child.attrib['name'] - if list_name in ['srcs', 'hdrs', 'deps', 'data', 'tags', 'args']: - result[list_name] += [item.attrib['value'] for item in child] - if child.tag == 'string': - string_name = child.attrib['name'] - if string_name in ['generator_function', 'size']: - result[string_name] = child.attrib['value'] - if child.tag == 'boolean': - bool_name = child.attrib['name'] - if bool_name in ['flaky']: - result[bool_name] = child.attrib['value'] == 'true' - if child.tag == 'label': + if child.tag == "list": + list_name = child.attrib["name"] + if list_name in ["srcs", "hdrs", "deps", "data", "tags", "args"]: + result[list_name] += [item.attrib["value"] for item in child] + if child.tag == "string": + string_name = child.attrib["name"] + if string_name in ["generator_function", "size"]: + result[string_name] = child.attrib["value"] + if child.tag == "boolean": + bool_name = child.attrib["name"] + if bool_name in ["flaky"]: + result[bool_name] = child.attrib["value"] == "true" + if child.tag == "label": # extract actual name for alias rules - label_name = child.attrib['name'] - if label_name in ['actual']: - actual_name = child.attrib.get('value', None) + label_name = child.attrib["name"] + if label_name in ["actual"]: + actual_name = child.attrib.get("value", None) if actual_name: - result['actual'] = actual_name + result["actual"] = actual_name # HACK: since we do a lot of transitive dependency scanning, # make it seem that the actual name is a dependency of the alias rule # (aliases don't have dependencies themselves) - result['deps'].append(actual_name) + result["deps"].append(actual_name) return result @@ -157,53 +158,53 @@ def _extract_rules_from_bazel_xml(xml_tree): """Extract bazel rules from an XML tree node obtained from "bazel query --output xml" command.""" result = {} for child in xml_tree: - if child.tag == 'rule': + if child.tag == "rule": rule_dict = _rule_dict_from_xml_node(child) - rule_clazz = rule_dict['class'] - rule_name = rule_dict['name'] + rule_clazz = rule_dict["class"] + rule_name = rule_dict["name"] if rule_clazz in [ - 'cc_library', - 'cc_binary', - 'cc_test', - 'cc_proto_library', - 'cc_proto_gen_validate', - 'proto_library', - 'upb_proto_library', - 'upb_proto_reflection_library', - 'alias', + "cc_library", + "cc_binary", + "cc_test", + "cc_proto_library", + "cc_proto_gen_validate", + "proto_library", + "upb_proto_library", + "upb_proto_reflection_library", + "alias", ]: if rule_name in result: - raise Exception('Rule %s already present' % rule_name) + raise Exception("Rule %s already present" % rule_name) result[rule_name] = rule_dict return result def _get_bazel_label(target_name: str) -> str: - if target_name.startswith('@'): + if target_name.startswith("@"): return target_name - if ':' in target_name: - return '//%s' % target_name + if ":" in 
target_name: + return "//%s" % target_name else: - return '//:%s' % target_name + return "//:%s" % target_name def _extract_source_file_path(label: str) -> str: """Gets relative path to source file from bazel deps listing""" - if label.startswith('//'): - label = label[len('//'):] + if label.startswith("//"): + label = label[len("//") :] # labels in form //:src/core/lib/surface/call_test_only.h - if label.startswith(':'): - label = label[len(':'):] + if label.startswith(":"): + label = label[len(":") :] # labels in form //test/core/util:port.cc - label = label.replace(':', '/') + label = label.replace(":", "/") return label def _extract_public_headers(bazel_rule: BuildMetadata) -> List[str]: """Gets list of public headers from a bazel rule""" result = [] - for dep in bazel_rule['hdrs']: - if dep.startswith('//:include/') and dep.endswith('.h'): + for dep in bazel_rule["hdrs"]: + if dep.startswith("//:include/") and dep.endswith(".h"): result.append(_extract_source_file_path(dep)) return list(sorted(result)) @@ -211,9 +212,12 @@ def _extract_public_headers(bazel_rule: BuildMetadata) -> List[str]: def _extract_nonpublic_headers(bazel_rule: BuildMetadata) -> List[str]: """Gets list of non-public headers from a bazel rule""" result = [] - for dep in bazel_rule['hdrs']: - if dep.startswith('//') and not dep.startswith( - '//:include/') and dep.endswith('.h'): + for dep in bazel_rule["hdrs"]: + if ( + dep.startswith("//") + and not dep.startswith("//:include/") + and dep.endswith(".h") + ): result.append(_extract_source_file_path(dep)) return list(sorted(result)) @@ -221,9 +225,9 @@ def _extract_nonpublic_headers(bazel_rule: BuildMetadata) -> List[str]: def _extract_sources(bazel_rule: BuildMetadata) -> List[str]: """Gets list of source files from a bazel rule""" result = [] - for src in bazel_rule['srcs']: - if src.endswith('.cc') or src.endswith('.c') or src.endswith('.proto'): - if src.startswith('//'): + for src in bazel_rule["srcs"]: + if src.endswith(".cc") or src.endswith(".c") or src.endswith(".proto"): + if src.startswith("//"): # This source file is local to gRPC result.append(_extract_source_file_path(src)) else: @@ -235,20 +239,26 @@ def _extract_sources(bazel_rule: BuildMetadata) -> List[str]: if external_proto_library_name is not None: result.append( src.replace( - '@%s//' % external_proto_library_name, + "@%s//" % external_proto_library_name, EXTERNAL_PROTO_LIBRARIES[ - external_proto_library_name].proto_prefix). - replace(':', '/')) + external_proto_library_name + ].proto_prefix, + ).replace(":", "/") + ) return list(sorted(result)) -def _extract_deps(bazel_rule: BuildMetadata, - bazel_rules: BuildDict) -> List[str]: +def _extract_deps( + bazel_rule: BuildMetadata, bazel_rules: BuildDict +) -> List[str]: """Gets list of deps from from a bazel rule""" - deps = set(bazel_rule['deps']) - for src in bazel_rule['srcs']: - if not src.endswith('.cc') and not src.endswith( - '.c') and not src.endswith('.proto'): + deps = set(bazel_rule["deps"]) + for src in bazel_rule["srcs"]: + if ( + not src.endswith(".cc") + and not src.endswith(".c") + and not src.endswith(".proto") + ): if src in bazel_rules: # This label doesn't point to a source file, but another Bazel # target. 
This is required for :pkg_cc_proto_validate targets, @@ -257,8 +267,9 @@ def _extract_deps(bazel_rule: BuildMetadata, return list(sorted(list(deps))) -def _create_target_from_bazel_rule(target_name: str, - bazel_rules: BuildDict) -> BuildMetadata: +def _create_target_from_bazel_rule( + target_name: str, bazel_rules: BuildDict +) -> BuildMetadata: """Create build.yaml-like target definition from bazel metadata""" bazel_rule = bazel_rules[_get_bazel_label(target_name)] @@ -267,31 +278,31 @@ def _create_target_from_bazel_rule(target_name: str, # and only later we will populate the public fields (once we do some extra # postprocessing). result = { - 'name': target_name, - '_PUBLIC_HEADERS_BAZEL': _extract_public_headers(bazel_rule), - '_HEADERS_BAZEL': _extract_nonpublic_headers(bazel_rule), - '_SRC_BAZEL': _extract_sources(bazel_rule), - '_DEPS_BAZEL': _extract_deps(bazel_rule, bazel_rules), - 'public_headers': bazel_rule['_COLLAPSED_PUBLIC_HEADERS'], - 'headers': bazel_rule['_COLLAPSED_HEADERS'], - 'src': bazel_rule['_COLLAPSED_SRCS'], - 'deps': bazel_rule['_COLLAPSED_DEPS'], + "name": target_name, + "_PUBLIC_HEADERS_BAZEL": _extract_public_headers(bazel_rule), + "_HEADERS_BAZEL": _extract_nonpublic_headers(bazel_rule), + "_SRC_BAZEL": _extract_sources(bazel_rule), + "_DEPS_BAZEL": _extract_deps(bazel_rule, bazel_rules), + "public_headers": bazel_rule["_COLLAPSED_PUBLIC_HEADERS"], + "headers": bazel_rule["_COLLAPSED_HEADERS"], + "src": bazel_rule["_COLLAPSED_SRCS"], + "deps": bazel_rule["_COLLAPSED_DEPS"], } return result def _external_dep_name_from_bazel_dependency(bazel_dep: str) -> Optional[str]: """Returns name of dependency if external bazel dependency is provided or None""" - if bazel_dep.startswith('@com_google_absl//'): + if bazel_dep.startswith("@com_google_absl//"): # special case for add dependency on one of the absl libraries (there is not just one absl library) - prefixlen = len('@com_google_absl//') + prefixlen = len("@com_google_absl//") return bazel_dep[prefixlen:] - elif bazel_dep == '//external:upb_lib': - return 'upb' - elif bazel_dep == '//external:benchmark': - return 'benchmark' - elif bazel_dep == '//external:libssl': - return 'libssl' + elif bazel_dep == "//external:upb_lib": + return "upb" + elif bazel_dep == "//external:benchmark": + return "benchmark" + elif bazel_dep == "//external:libssl": + return "libssl" else: # all the other external deps such as protobuf, cares, zlib # don't need to be listed explicitly, they are handled automatically @@ -300,8 +311,8 @@ def _external_dep_name_from_bazel_dependency(bazel_dep: str) -> Optional[str]: def _compute_transitive_metadata( - rule_name: str, bazel_rules: Any, - bazel_label_to_dep_name: Dict[str, str]) -> None: + rule_name: str, bazel_rules: Any, bazel_label_to_dep_name: Dict[str, str] +) -> None: """Computes the final build metadata for Bazel target with rule_name. 
The dependencies that will appear on the deps list are: @@ -355,23 +366,27 @@ def _compute_transitive_metadata( if external_dep_name_maybe is None: if "_PROCESSING_DONE" not in bazel_rules[dep]: # This item is not processed before, compute now - _compute_transitive_metadata(dep, bazel_rules, - bazel_label_to_dep_name) - transitive_deps.update(bazel_rules[dep].get( - '_TRANSITIVE_DEPS', [])) + _compute_transitive_metadata( + dep, bazel_rules, bazel_label_to_dep_name + ) + transitive_deps.update( + bazel_rules[dep].get("_TRANSITIVE_DEPS", []) + ) collapsed_deps.update( - collapsed_deps, bazel_rules[dep].get('_COLLAPSED_DEPS', [])) - exclude_deps.update(bazel_rules[dep].get('_EXCLUDE_DEPS', [])) + collapsed_deps, bazel_rules[dep].get("_COLLAPSED_DEPS", []) + ) + exclude_deps.update(bazel_rules[dep].get("_EXCLUDE_DEPS", [])) # This dep is a public target, add it as a dependency if dep in bazel_label_to_dep_name: transitive_deps.update([bazel_label_to_dep_name[dep]]) - collapsed_deps.update(collapsed_deps, - [bazel_label_to_dep_name[dep]]) + collapsed_deps.update( + collapsed_deps, [bazel_label_to_dep_name[dep]] + ) # Add all the transitive deps of our every public dep to exclude # list since we want to avoid building sources that are already # built by our dependencies - exclude_deps.update(bazel_rules[dep]['_TRANSITIVE_DEPS']) + exclude_deps.update(bazel_rules[dep]["_TRANSITIVE_DEPS"]) continue # This dep is an external target, add it as a dependency @@ -385,7 +400,8 @@ def _compute_transitive_metadata( # Calculate transitive public deps (needed for collapsing sources) transitive_public_deps = set( - [x for x in transitive_deps if x in bazel_label_to_dep_name]) + [x for x in transitive_deps if x in bazel_label_to_dep_name] + ) # Remove intermediate targets that our public dependencies already depend # on. This is the step that further shorten the deps list. @@ -415,19 +431,22 @@ def _compute_transitive_metadata( if dep in bazel_rules: collapsed_srcs.update(_extract_sources(bazel_rules[dep])) collapsed_public_headers.update( - _extract_public_headers(bazel_rules[dep])) + _extract_public_headers(bazel_rules[dep]) + ) collapsed_headers.update( - _extract_nonpublic_headers(bazel_rules[dep])) + _extract_nonpublic_headers(bazel_rules[dep]) + ) # This item is a "visited" flag - bazel_rule['_PROCESSING_DONE'] = True + bazel_rule["_PROCESSING_DONE"] = True # Following items are described in the docstinrg. - bazel_rule['_TRANSITIVE_DEPS'] = list(sorted(transitive_deps)) - bazel_rule['_COLLAPSED_DEPS'] = list(sorted(collapsed_deps)) - bazel_rule['_COLLAPSED_SRCS'] = list(sorted(collapsed_srcs)) - bazel_rule['_COLLAPSED_PUBLIC_HEADERS'] = list( - sorted(collapsed_public_headers)) - bazel_rule['_COLLAPSED_HEADERS'] = list(sorted(collapsed_headers)) - bazel_rule['_EXCLUDE_DEPS'] = list(sorted(exclude_deps)) + bazel_rule["_TRANSITIVE_DEPS"] = list(sorted(transitive_deps)) + bazel_rule["_COLLAPSED_DEPS"] = list(sorted(collapsed_deps)) + bazel_rule["_COLLAPSED_SRCS"] = list(sorted(collapsed_srcs)) + bazel_rule["_COLLAPSED_PUBLIC_HEADERS"] = list( + sorted(collapsed_public_headers) + ) + bazel_rule["_COLLAPSED_HEADERS"] = list(sorted(collapsed_headers)) + bazel_rule["_EXCLUDE_DEPS"] = list(sorted(exclude_deps)) # TODO(jtattermusch): deduplicate with transitive_dependencies.py (which has a @@ -440,8 +459,9 @@ def _compute_transitive_metadata( # bazel builds it is customary to define larger number of smaller # "sublibraries". 
The need for elision (and expansion) of intermediate libraries # can be re-evaluated in the future. -def _populate_transitive_metadata(bazel_rules: Any, - public_dep_names: Iterable[str]) -> None: +def _populate_transitive_metadata( + bazel_rules: Any, public_dep_names: Iterable[str] +) -> None: """Add 'transitive_deps' field for each of the rules""" # Create the map between Bazel label and public dependency name bazel_label_to_dep_name = {} @@ -451,28 +471,30 @@ def _populate_transitive_metadata(bazel_rules: Any, # Make sure we reached all the Bazel rules # TODO(lidiz) potentially we could only update a subset of rules for rule_name in bazel_rules: - if '_PROCESSING_DONE' not in bazel_rules[rule_name]: - _compute_transitive_metadata(rule_name, bazel_rules, - bazel_label_to_dep_name) + if "_PROCESSING_DONE" not in bazel_rules[rule_name]: + _compute_transitive_metadata( + rule_name, bazel_rules, bazel_label_to_dep_name + ) def update_test_metadata_with_transitive_metadata( - all_extra_metadata: BuildDict, bazel_rules: BuildDict) -> None: + all_extra_metadata: BuildDict, bazel_rules: BuildDict +) -> None: """Patches test build metadata with transitive metadata.""" for lib_name, lib_dict in list(all_extra_metadata.items()): # Skip if it isn't not an test - if lib_dict.get('build') != 'test' or lib_dict.get('_TYPE') != 'target': + if lib_dict.get("build") != "test" or lib_dict.get("_TYPE") != "target": continue bazel_rule = bazel_rules[_get_bazel_label(lib_name)] - if '//external:benchmark' in bazel_rule['_TRANSITIVE_DEPS']: - lib_dict['benchmark'] = True - lib_dict['defaults'] = 'benchmark' + if "//external:benchmark" in bazel_rule["_TRANSITIVE_DEPS"]: + lib_dict["benchmark"] = True + lib_dict["defaults"] = "benchmark" - if '//external:gtest' in bazel_rule['_TRANSITIVE_DEPS']: - lib_dict['gtest'] = True - lib_dict['language'] = 'c++' + if "//external:gtest" in bazel_rule["_TRANSITIVE_DEPS"]: + lib_dict["gtest"] = True + lib_dict["language"] = "c++" def _get_transitive_protos(bazel_rules, t): @@ -485,12 +507,12 @@ def _get_transitive_protos(bazel_rules, t): name = que.pop(0) rule = bazel_rules.get(name, None) if rule: - for dep in rule['deps']: + for dep in rule["deps"]: if dep not in visited: visited.add(dep) que.append(dep) - for src in rule['srcs']: - if src.endswith('.proto'): + for src in rule["srcs"]: + if src.endswith(".proto"): ret.append(src) return list(set(ret)) @@ -498,36 +520,45 @@ def _get_transitive_protos(bazel_rules, t): def _expand_upb_proto_library_rules(bazel_rules): # Expand the .proto files from UPB proto library rules into the pre-generated # upb.h and upb.c files. 
- GEN_UPB_ROOT = '//:src/core/ext/upb-generated/' - GEN_UPBDEFS_ROOT = '//:src/core/ext/upbdefs-generated/' - EXTERNAL_LINKS = [('@com_google_protobuf//', 'src/'), - ('@com_google_googleapis//', ''), - ('@com_github_cncf_udpa//', ''), - ('@com_envoyproxy_protoc_gen_validate//', ''), - ('@envoy_api//', ''), ('@opencensus_proto//', '')] + GEN_UPB_ROOT = "//:src/core/ext/upb-generated/" + GEN_UPBDEFS_ROOT = "//:src/core/ext/upbdefs-generated/" + EXTERNAL_LINKS = [ + ("@com_google_protobuf//", "src/"), + ("@com_google_googleapis//", ""), + ("@com_github_cncf_udpa//", ""), + ("@com_envoyproxy_protoc_gen_validate//", ""), + ("@envoy_api//", ""), + ("@opencensus_proto//", ""), + ] for name, bazel_rule in bazel_rules.items(): - gen_func = bazel_rule.get('generator_function', None) - if gen_func in ('grpc_upb_proto_library', - 'grpc_upb_proto_reflection_library'): + gen_func = bazel_rule.get("generator_function", None) + if gen_func in ( + "grpc_upb_proto_library", + "grpc_upb_proto_reflection_library", + ): # get proto dependency - deps = bazel_rule['deps'] + deps = bazel_rule["deps"] if len(deps) != 1: raise Exception( - 'upb rule "{0}" should have 1 proto dependency but has "{1}"' - .format(name, deps)) + 'upb rule "{0}" should have 1 proto dependency but has' + ' "{1}"'.format(name, deps) + ) # deps is not properly fetched from bazel query for upb_proto_library target # so add the upb dependency manually - bazel_rule['deps'] = [ - '//external:upb_lib', '//external:upb_lib_descriptor', - '//external:upb_generated_code_support__only_for_generated_code_do_not_use__i_give_permission_to_break_me' + bazel_rule["deps"] = [ + "//external:upb_lib", + "//external:upb_lib_descriptor", + "//external:upb_generated_code_support__only_for_generated_code_do_not_use__i_give_permission_to_break_me", ] # populate the upb_proto_library rule with pre-generated upb headers # and sources using proto_rule protos = _get_transitive_protos(bazel_rules, deps[0]) if len(protos) == 0: raise Exception( - 'upb rule "{0}" should have at least one proto file.'. 
- format(name)) + 'upb rule "{0}" should have at least one proto file.'.format( + name + ) + ) srcs = [] hdrs = [] for proto_src in protos: @@ -536,23 +567,35 @@ def _expand_upb_proto_library_rules(bazel_rules): prefix_to_strip = external_link[0] + external_link[1] if not proto_src.startswith(prefix_to_strip): raise Exception( - 'Source file "{0}" in upb rule {1} does not have the expected prefix "{2}"' - .format(proto_src, name, prefix_to_strip)) - proto_src = proto_src[len(prefix_to_strip):] + 'Source file "{0}" in upb rule {1} does not' + ' have the expected prefix "{2}"'.format( + proto_src, name, prefix_to_strip + ) + ) + proto_src = proto_src[len(prefix_to_strip) :] break - if proto_src.startswith('@'): + if proto_src.startswith("@"): raise Exception('"{0}" is unknown workspace.'.format(name)) proto_src = _extract_source_file_path(proto_src) - ext = '.upb' if gen_func == 'grpc_upb_proto_library' else '.upbdefs' - root = GEN_UPB_ROOT if gen_func == 'grpc_upb_proto_library' else GEN_UPBDEFS_ROOT - srcs.append(root + proto_src.replace('.proto', ext + '.c')) - hdrs.append(root + proto_src.replace('.proto', ext + '.h')) - bazel_rule['srcs'] = srcs - bazel_rule['hdrs'] = hdrs - - -def _generate_build_metadata(build_extra_metadata: BuildDict, - bazel_rules: BuildDict) -> BuildDict: + ext = ( + ".upb" + if gen_func == "grpc_upb_proto_library" + else ".upbdefs" + ) + root = ( + GEN_UPB_ROOT + if gen_func == "grpc_upb_proto_library" + else GEN_UPBDEFS_ROOT + ) + srcs.append(root + proto_src.replace(".proto", ext + ".c")) + hdrs.append(root + proto_src.replace(".proto", ext + ".h")) + bazel_rule["srcs"] = srcs + bazel_rule["hdrs"] = hdrs + + +def _generate_build_metadata( + build_extra_metadata: BuildDict, bazel_rules: BuildDict +) -> BuildDict: """Generate build metadata in build.yaml-like format bazel build metadata and build.yaml-specific "extra metadata".""" lib_names = list(build_extra_metadata.keys()) result = {} @@ -572,38 +615,48 @@ def _generate_build_metadata(build_extra_metadata: BuildDict, # The rename step needs to be made after we're done with most of processing logic # otherwise the already-renamed libraries will have different names than expected for lib_name in lib_names: - to_name = build_extra_metadata.get(lib_name, {}).get('_RENAME', None) + to_name = build_extra_metadata.get(lib_name, {}).get("_RENAME", None) if to_name: # store lib under the new name and also change its 'name' property if to_name in result: - raise Exception('Cannot rename target ' + str(lib_name) + ', ' + - str(to_name) + ' already exists.') + raise Exception( + "Cannot rename target " + + str(lib_name) + + ", " + + str(to_name) + + " already exists." 
+ ) lib_dict = result.pop(lib_name) - lib_dict['name'] = to_name + lib_dict["name"] = to_name result[to_name] = lib_dict # dep names need to be updated as well for lib_dict_to_update in list(result.values()): - lib_dict_to_update['deps'] = list([ - to_name if dep == lib_name else dep - for dep in lib_dict_to_update['deps'] - ]) + lib_dict_to_update["deps"] = list( + [ + to_name if dep == lib_name else dep + for dep in lib_dict_to_update["deps"] + ] + ) return result def _convert_to_build_yaml_like(lib_dict: BuildMetadata) -> BuildYaml: lib_names = [ - lib_name for lib_name in list(lib_dict.keys()) - if lib_dict[lib_name].get('_TYPE', 'library') == 'library' + lib_name + for lib_name in list(lib_dict.keys()) + if lib_dict[lib_name].get("_TYPE", "library") == "library" ] target_names = [ - lib_name for lib_name in list(lib_dict.keys()) - if lib_dict[lib_name].get('_TYPE', 'library') == 'target' + lib_name + for lib_name in list(lib_dict.keys()) + if lib_dict[lib_name].get("_TYPE", "library") == "target" ] test_names = [ - lib_name for lib_name in list(lib_dict.keys()) - if lib_dict[lib_name].get('_TYPE', 'library') == 'test' + lib_name + for lib_name in list(lib_dict.keys()) + if lib_dict[lib_name].get("_TYPE", "library") == "test" ] # list libraries and targets in predefined order @@ -614,29 +667,31 @@ def _convert_to_build_yaml_like(lib_dict: BuildMetadata) -> BuildYaml: # get rid of temporary private fields prefixed with "_" and some other useless fields for lib in lib_list: for field_to_remove in [ - k for k in list(lib.keys()) if k.startswith('_') + k for k in list(lib.keys()) if k.startswith("_") ]: lib.pop(field_to_remove, None) for target in target_list: for field_to_remove in [ - k for k in list(target.keys()) if k.startswith('_') + k for k in list(target.keys()) if k.startswith("_") ]: target.pop(field_to_remove, None) - target.pop('public_headers', - None) # public headers make no sense for targets + target.pop( + "public_headers", None + ) # public headers make no sense for targets for test in test_list: for field_to_remove in [ - k for k in list(test.keys()) if k.startswith('_') + k for k in list(test.keys()) if k.startswith("_") ]: test.pop(field_to_remove, None) - test.pop('public_headers', - None) # public headers make no sense for tests + test.pop( + "public_headers", None + ) # public headers make no sense for tests build_yaml_like = { - 'libs': lib_list, - 'filegroups': [], - 'targets': target_list, - 'tests': test_list, + "libs": lib_list, + "filegroups": [], + "targets": target_list, + "tests": test_list, } return build_yaml_like @@ -645,10 +700,10 @@ def _extract_cc_tests(bazel_rules: BuildDict) -> List[str]: """Gets list of cc_test tests from bazel rules""" result = [] for bazel_rule in list(bazel_rules.values()): - if bazel_rule['class'] == 'cc_test': - test_name = bazel_rule['name'] - if test_name.startswith('//'): - prefixlen = len('//') + if bazel_rule["class"] == "cc_test": + test_name = bazel_rule["name"] + if test_name.startswith("//"): + prefixlen = len("//") result.append(test_name[prefixlen:]) return list(sorted(result)) @@ -657,111 +712,129 @@ def _exclude_unwanted_cc_tests(tests: List[str]) -> List[str]: """Filters out bazel tests that we don't want to run with other build systems or we cannot build them reasonably""" # most qps tests are autogenerated, we are fine without them - tests = [test for test in tests if not test.startswith('test/cpp/qps:')] + tests = [test for test in tests if not test.startswith("test/cpp/qps:")] # microbenchmarks aren't needed 
for checking correctness tests = [ - test for test in tests - if not test.startswith('test/cpp/microbenchmarks:') + test + for test in tests + if not test.startswith("test/cpp/microbenchmarks:") ] tests = [ - test for test in tests - if not test.startswith('test/core/promise/benchmark:') + test + for test in tests + if not test.startswith("test/core/promise/benchmark:") ] # we have trouble with census dependency outside of bazel tests = [ - test for test in tests - if not test.startswith('test/cpp/ext/filters/census:') and - not test.startswith('test/core/xds:xds_channel_stack_modifier_test') and - not test.startswith('test/cpp/ext/gcp:') and - not test.startswith('test/cpp/ext/filters/logging:') and - not test.startswith('test/cpp/interop:observability_interop') + test + for test in tests + if not test.startswith("test/cpp/ext/filters/census:") + and not test.startswith("test/core/xds:xds_channel_stack_modifier_test") + and not test.startswith("test/cpp/ext/gcp:") + and not test.startswith("test/cpp/ext/filters/logging:") + and not test.startswith("test/cpp/interop:observability_interop") ] # missing opencensus/stats/stats.h tests = [ - test for test in tests if not test.startswith( - 'test/cpp/end2end:server_load_reporting_end2end_test') + test + for test in tests + if not test.startswith( + "test/cpp/end2end:server_load_reporting_end2end_test" + ) ] tests = [ - test for test in tests if not test.startswith( - 'test/cpp/server/load_reporter:lb_load_reporter_test') + test + for test in tests + if not test.startswith( + "test/cpp/server/load_reporter:lb_load_reporter_test" + ) ] # The test uses --running_under_bazel cmdline argument # To avoid the trouble needing to adjust it, we just skip the test tests = [ - test for test in tests if not test.startswith( - 'test/cpp/naming:resolver_component_tests_runner_invoker') + test + for test in tests + if not test.startswith( + "test/cpp/naming:resolver_component_tests_runner_invoker" + ) ] # the test requires 'client_crash_test_server' to be built tests = [ - test for test in tests - if not test.startswith('test/cpp/end2end:time_change_test') + test + for test in tests + if not test.startswith("test/cpp/end2end:time_change_test") ] # the test requires 'client_crash_test_server' to be built tests = [ - test for test in tests - if not test.startswith('test/cpp/end2end:client_crash_test') + test + for test in tests + if not test.startswith("test/cpp/end2end:client_crash_test") ] # the test requires 'server_crash_test_client' to be built tests = [ - test for test in tests - if not test.startswith('test/cpp/end2end:server_crash_test') + test + for test in tests + if not test.startswith("test/cpp/end2end:server_crash_test") ] # test never existed under build.yaml and it fails -> skip it tests = [ - test for test in tests - if not test.startswith('test/core/tsi:ssl_session_cache_test') + test + for test in tests + if not test.startswith("test/core/tsi:ssl_session_cache_test") ] # the binary of this test does not get built with cmake tests = [ - test for test in tests - if not test.startswith('test/cpp/util:channelz_sampler_test') + test + for test in tests + if not test.startswith("test/cpp/util:channelz_sampler_test") ] # we don't need to generate fuzzers outside of bazel - tests = [test for test in tests if not test.endswith('_fuzzer')] + tests = [test for test in tests if not test.endswith("_fuzzer")] return tests def _generate_build_extra_metadata_for_tests( - tests: List[str], bazel_rules: BuildDict) -> BuildDict: + tests: List[str], bazel_rules: 
BuildDict +) -> BuildDict: """For given tests, generate the "extra metadata" that we need for our "build.yaml"-like output. The extra metadata is generated from the bazel rule metadata by using a bunch of heuristics.""" test_metadata = {} for test in tests: - test_dict = {'build': 'test', '_TYPE': 'target'} + test_dict = {"build": "test", "_TYPE": "target"} bazel_rule = bazel_rules[_get_bazel_label(test)] - bazel_tags = bazel_rule['tags'] - if 'manual' in bazel_tags: + bazel_tags = bazel_rule["tags"] + if "manual" in bazel_tags: # don't run the tests marked as "manual" - test_dict['run'] = False + test_dict["run"] = False - if bazel_rule['flaky']: + if bazel_rule["flaky"]: # don't run tests that are marked as "flaky" under bazel # because that would only add noise for the run_tests.py tests # and seeing more failures for tests that we already know are flaky # doesn't really help anything - test_dict['run'] = False + test_dict["run"] = False - if 'no_uses_polling' in bazel_tags: - test_dict['uses_polling'] = False + if "no_uses_polling" in bazel_tags: + test_dict["uses_polling"] = False - if 'grpc_fuzzer' == bazel_rule['generator_function']: + if "grpc_fuzzer" == bazel_rule["generator_function"]: # currently we hand-list fuzzers instead of generating them automatically # because there's no way to obtain maxlen property from bazel BUILD file. - print(('skipping fuzzer ' + test)) + print(("skipping fuzzer " + test)) continue - if 'bazel_only' in bazel_tags: + if "bazel_only" in bazel_tags: continue # if any tags that restrict platform compatibility are present, @@ -771,42 +844,43 @@ def _generate_build_extra_metadata_for_tests( # is made (for tests where uses_polling=True). So for now, we just # assume all tests are compatible with linux and ignore the "no_linux" tag # completely. - known_platform_tags = set(['no_windows', 'no_mac']) + known_platform_tags = set(["no_windows", "no_mac"]) if set(bazel_tags).intersection(known_platform_tags): platforms = [] # assume all tests are compatible with linux and posix - platforms.append('linux') + platforms.append("linux") platforms.append( - 'posix') # there is no posix-specific tag in bazel BUILD - if 'no_mac' not in bazel_tags: - platforms.append('mac') - if 'no_windows' not in bazel_tags: - platforms.append('windows') - test_dict['platforms'] = platforms - - cmdline_args = bazel_rule['args'] + "posix" + ) # there is no posix-specific tag in bazel BUILD + if "no_mac" not in bazel_tags: + platforms.append("mac") + if "no_windows" not in bazel_tags: + platforms.append("windows") + test_dict["platforms"] = platforms + + cmdline_args = bazel_rule["args"] if cmdline_args: - test_dict['args'] = list(cmdline_args) + test_dict["args"] = list(cmdline_args) - if test.startswith('test/cpp'): - test_dict['language'] = 'c++' + if test.startswith("test/cpp"): + test_dict["language"] = "c++" - elif test.startswith('test/core'): - test_dict['language'] = 'c' + elif test.startswith("test/core"): + test_dict["language"] = "c" else: - raise Exception('wrong test' + test) + raise Exception("wrong test" + test) # short test name without the path. 
# There can be name collisions, but we will resolve them later simple_test_name = os.path.basename(_extract_source_file_path(test)) - test_dict['_RENAME'] = simple_test_name + test_dict["_RENAME"] = simple_test_name test_metadata[test] = test_dict # detect duplicate test names tests_by_simple_name = {} for test_name, test_dict in list(test_metadata.items()): - simple_test_name = test_dict['_RENAME'] + simple_test_name = test_dict["_RENAME"] if simple_test_name not in tests_by_simple_name: tests_by_simple_name[simple_test_name] = [] tests_by_simple_name[simple_test_name].append(test_name) @@ -815,37 +889,40 @@ def _generate_build_extra_metadata_for_tests( for collision_list in list(tests_by_simple_name.values()): if len(collision_list) > 1: for test_name in collision_list: - long_name = test_name.replace('/', '_').replace(':', '_') - print(( - 'short name of "%s" collides with another test, renaming to %s' - % (test_name, long_name))) - test_metadata[test_name]['_RENAME'] = long_name + long_name = test_name.replace("/", "_").replace(":", "_") + print( + 'short name of "%s" collides with another test, renaming' + " to %s" % (test_name, long_name) + ) + test_metadata[test_name]["_RENAME"] = long_name return test_metadata -def _parse_http_archives(xml_tree: ET.Element) -> 'List[ExternalProtoLibrary]': +def _parse_http_archives(xml_tree: ET.Element) -> "List[ExternalProtoLibrary]": """Parse Bazel http_archive rule into ExternalProtoLibrary objects.""" result = [] for xml_http_archive in xml_tree: - if xml_http_archive.tag != 'rule' or xml_http_archive.attrib[ - 'class'] != 'http_archive': + if ( + xml_http_archive.tag != "rule" + or xml_http_archive.attrib["class"] != "http_archive" + ): continue # A distilled Python representation of Bazel http_archive http_archive = dict() for xml_node in xml_http_archive: - if xml_node.attrib['name'] == 'name': - http_archive["name"] = xml_node.attrib['value'] - if xml_node.attrib['name'] == 'urls': + if xml_node.attrib["name"] == "name": + http_archive["name"] = xml_node.attrib["value"] + if xml_node.attrib["name"] == "urls": http_archive["urls"] = [] for url_node in xml_node: - http_archive["urls"].append(url_node.attrib['value']) - if xml_node.attrib['name'] == 'url': - http_archive["urls"] = [xml_node.attrib['value']] - if xml_node.attrib['name'] == 'sha256': - http_archive["hash"] = xml_node.attrib['value'] - if xml_node.attrib['name'] == 'strip_prefix': - http_archive["strip_prefix"] = xml_node.attrib['value'] + http_archive["urls"].append(url_node.attrib["value"]) + if xml_node.attrib["name"] == "url": + http_archive["urls"] = [xml_node.attrib["value"]] + if xml_node.attrib["name"] == "sha256": + http_archive["hash"] = xml_node.attrib["value"] + if xml_node.attrib["name"] == "strip_prefix": + http_archive["strip_prefix"] = xml_node.attrib["value"] if http_archive["name"] not in EXTERNAL_PROTO_LIBRARIES: # If this http archive is not one of the external proto libraries, # we don't want to include it as a CMake target @@ -860,7 +937,7 @@ def _parse_http_archives(xml_tree: ET.Element) -> 'List[ExternalProtoLibrary]': def _generate_external_proto_libraries() -> List[Dict[str, Any]]: """Generates the build metadata for external proto libraries""" - xml_tree = _bazel_query_xml_tree('kind(http_archive, //external:*)') + xml_tree = _bazel_query_xml_tree("kind(http_archive, //external:*)") libraries = _parse_http_archives(xml_tree) libraries.sort(key=lambda x: x.destination) return list(map(lambda x: x.__dict__, libraries)) @@ -868,12 +945,18 @@ def 
_generate_external_proto_libraries() -> List[Dict[str, Any]]: def _detect_and_print_issues(build_yaml_like: BuildYaml) -> None: """Try detecting some unusual situations and warn about them.""" - for tgt in build_yaml_like['targets']: - if tgt['build'] == 'test': - for src in tgt['src']: - if src.startswith('src/') and not src.endswith('.proto'): - print(('source file from under "src/" tree used in test ' + - tgt['name'] + ': ' + src)) + for tgt in build_yaml_like["targets"]: + if tgt["build"] == "test": + for src in tgt["src"]: + if src.startswith("src/") and not src.endswith(".proto"): + print( + ( + 'source file from under "src/" tree used in test ' + + tgt["name"] + + ": " + + src + ) + ) # extra metadata that will be used to construct build.yaml @@ -881,206 +964,184 @@ def _detect_and_print_issues(build_yaml_like: BuildYaml) -> None: # _TYPE: whether this is library, target or test # _RENAME: whether this target should be renamed to a different name (to match expectations of make and cmake builds) _BUILD_EXTRA_METADATA = { - 'third_party/address_sorting:address_sorting': { - 'language': 'c', - 'build': 'all', - '_RENAME': 'address_sorting' - }, - 'gpr': { - 'language': 'c', - 'build': 'all', - }, - 'grpc': { - 'language': 'c', - 'build': 'all', - 'baselib': True, - 'generate_plugin_registry': True - }, - 'grpc++': { - 'language': 'c++', - 'build': 'all', - 'baselib': True, - }, - 'grpc++_alts': { - 'language': 'c++', - 'build': 'all', - 'baselib': True + "third_party/address_sorting:address_sorting": { + "language": "c", + "build": "all", + "_RENAME": "address_sorting", }, - 'grpc++_error_details': { - 'language': 'c++', - 'build': 'all' + "gpr": { + "language": "c", + "build": "all", }, - 'grpc++_reflection': { - 'language': 'c++', - 'build': 'all' + "grpc": { + "language": "c", + "build": "all", + "baselib": True, + "generate_plugin_registry": True, }, - 'grpc_authorization_provider': { - 'language': 'c++', - 'build': 'all' + "grpc++": { + "language": "c++", + "build": "all", + "baselib": True, }, - 'grpc++_unsecure': { - 'language': 'c++', - 'build': 'all', - 'baselib': True, + "grpc++_alts": {"language": "c++", "build": "all", "baselib": True}, + "grpc++_error_details": {"language": "c++", "build": "all"}, + "grpc++_reflection": {"language": "c++", "build": "all"}, + "grpc_authorization_provider": {"language": "c++", "build": "all"}, + "grpc++_unsecure": { + "language": "c++", + "build": "all", + "baselib": True, }, - 'grpc_unsecure': { - 'language': 'c', - 'build': 'all', - 'baselib': True, - 'generate_plugin_registry': True + "grpc_unsecure": { + "language": "c", + "build": "all", + "baselib": True, + "generate_plugin_registry": True, }, - 'grpcpp_channelz': { - 'language': 'c++', - 'build': 'all' + "grpcpp_channelz": {"language": "c++", "build": "all"}, + "grpc++_test": { + "language": "c++", + "build": "private", }, - 'grpc++_test': { - 'language': 'c++', - 'build': 'private', + "src/compiler:grpc_plugin_support": { + "language": "c++", + "build": "protoc", + "_RENAME": "grpc_plugin_support", }, - 'src/compiler:grpc_plugin_support': { - 'language': 'c++', - 'build': 'protoc', - '_RENAME': 'grpc_plugin_support' + "src/compiler:grpc_cpp_plugin": { + "language": "c++", + "build": "protoc", + "_TYPE": "target", + "_RENAME": "grpc_cpp_plugin", }, - 'src/compiler:grpc_cpp_plugin': { - 'language': 'c++', - 'build': 'protoc', - '_TYPE': 'target', - '_RENAME': 'grpc_cpp_plugin' + "src/compiler:grpc_csharp_plugin": { + "language": "c++", + "build": "protoc", + "_TYPE": "target", + 
"_RENAME": "grpc_csharp_plugin", }, - 'src/compiler:grpc_csharp_plugin': { - 'language': 'c++', - 'build': 'protoc', - '_TYPE': 'target', - '_RENAME': 'grpc_csharp_plugin' + "src/compiler:grpc_node_plugin": { + "language": "c++", + "build": "protoc", + "_TYPE": "target", + "_RENAME": "grpc_node_plugin", }, - 'src/compiler:grpc_node_plugin': { - 'language': 'c++', - 'build': 'protoc', - '_TYPE': 'target', - '_RENAME': 'grpc_node_plugin' + "src/compiler:grpc_objective_c_plugin": { + "language": "c++", + "build": "protoc", + "_TYPE": "target", + "_RENAME": "grpc_objective_c_plugin", }, - 'src/compiler:grpc_objective_c_plugin': { - 'language': 'c++', - 'build': 'protoc', - '_TYPE': 'target', - '_RENAME': 'grpc_objective_c_plugin' + "src/compiler:grpc_php_plugin": { + "language": "c++", + "build": "protoc", + "_TYPE": "target", + "_RENAME": "grpc_php_plugin", }, - 'src/compiler:grpc_php_plugin': { - 'language': 'c++', - 'build': 'protoc', - '_TYPE': 'target', - '_RENAME': 'grpc_php_plugin' + "src/compiler:grpc_python_plugin": { + "language": "c++", + "build": "protoc", + "_TYPE": "target", + "_RENAME": "grpc_python_plugin", }, - 'src/compiler:grpc_python_plugin': { - 'language': 'c++', - 'build': 'protoc', - '_TYPE': 'target', - '_RENAME': 'grpc_python_plugin' + "src/compiler:grpc_ruby_plugin": { + "language": "c++", + "build": "protoc", + "_TYPE": "target", + "_RENAME": "grpc_ruby_plugin", }, - 'src/compiler:grpc_ruby_plugin': { - 'language': 'c++', - 'build': 'protoc', - '_TYPE': 'target', - '_RENAME': 'grpc_ruby_plugin' - }, - # TODO(jtattermusch): consider adding grpc++_core_stats - # test support libraries - 'test/core/util:grpc_test_util': { - 'language': 'c', - 'build': 'private', - '_RENAME': 'grpc_test_util' + "test/core/util:grpc_test_util": { + "language": "c", + "build": "private", + "_RENAME": "grpc_test_util", }, - 'test/core/util:grpc_test_util_unsecure': { - 'language': 'c', - 'build': 'private', - '_RENAME': 'grpc_test_util_unsecure' + "test/core/util:grpc_test_util_unsecure": { + "language": "c", + "build": "private", + "_RENAME": "grpc_test_util_unsecure", }, # TODO(jtattermusch): consider adding grpc++_test_util_unsecure - it doesn't seem to be used by bazel build (don't forget to set secure: False) - 'test/cpp/util:test_config': { - 'language': 'c++', - 'build': 'private', - '_RENAME': 'grpc++_test_config' + "test/cpp/util:test_config": { + "language": "c++", + "build": "private", + "_RENAME": "grpc++_test_config", }, - 'test/cpp/util:test_util': { - 'language': 'c++', - 'build': 'private', - '_RENAME': 'grpc++_test_util' + "test/cpp/util:test_util": { + "language": "c++", + "build": "private", + "_RENAME": "grpc++_test_util", }, - # benchmark support libraries - 'test/cpp/microbenchmarks:helpers': { - 'language': 'c++', - 'build': 'test', - 'defaults': 'benchmark', - '_RENAME': 'benchmark_helpers' + "test/cpp/microbenchmarks:helpers": { + "language": "c++", + "build": "test", + "defaults": "benchmark", + "_RENAME": "benchmark_helpers", }, - 'test/cpp/interop:interop_client': { - 'language': 'c++', - 'build': 'test', - 'run': False, - '_TYPE': 'target', - '_RENAME': 'interop_client' + "test/cpp/interop:interop_client": { + "language": "c++", + "build": "test", + "run": False, + "_TYPE": "target", + "_RENAME": "interop_client", }, - 'test/cpp/interop:interop_server': { - 'language': 'c++', - 'build': 'test', - 'run': False, - '_TYPE': 'target', - '_RENAME': 'interop_server' + "test/cpp/interop:interop_server": { + "language": "c++", + "build": "test", + "run": False, + 
"_TYPE": "target", + "_RENAME": "interop_server", }, - 'test/cpp/interop:xds_interop_client': { - 'language': 'c++', - 'build': 'test', - 'run': False, - '_TYPE': 'target', - '_RENAME': 'xds_interop_client' + "test/cpp/interop:xds_interop_client": { + "language": "c++", + "build": "test", + "run": False, + "_TYPE": "target", + "_RENAME": "xds_interop_client", }, - 'test/cpp/interop:xds_interop_server': { - 'language': 'c++', - 'build': 'test', - 'run': False, - '_TYPE': 'target', - '_RENAME': 'xds_interop_server' + "test/cpp/interop:xds_interop_server": { + "language": "c++", + "build": "test", + "run": False, + "_TYPE": "target", + "_RENAME": "xds_interop_server", }, - 'test/cpp/interop:http2_client': { - 'language': 'c++', - 'build': 'test', - 'run': False, - '_TYPE': 'target', - '_RENAME': 'http2_client' + "test/cpp/interop:http2_client": { + "language": "c++", + "build": "test", + "run": False, + "_TYPE": "target", + "_RENAME": "http2_client", }, - 'test/cpp/qps:qps_json_driver': { - 'language': 'c++', - 'build': 'test', - 'run': False, - '_TYPE': 'target', - '_RENAME': 'qps_json_driver' + "test/cpp/qps:qps_json_driver": { + "language": "c++", + "build": "test", + "run": False, + "_TYPE": "target", + "_RENAME": "qps_json_driver", }, - 'test/cpp/qps:qps_worker': { - 'language': 'c++', - 'build': 'test', - 'run': False, - '_TYPE': 'target', - '_RENAME': 'qps_worker' + "test/cpp/qps:qps_worker": { + "language": "c++", + "build": "test", + "run": False, + "_TYPE": "target", + "_RENAME": "qps_worker", }, - 'test/cpp/util:grpc_cli': { - 'language': 'c++', - 'build': 'test', - 'run': False, - '_TYPE': 'target', - '_RENAME': 'grpc_cli' + "test/cpp/util:grpc_cli": { + "language": "c++", + "build": "test", + "run": False, + "_TYPE": "target", + "_RENAME": "grpc_cli", }, - # TODO(jtattermusch): create_jwt and verify_jwt breaks distribtests because it depends on grpc_test_utils and thus requires tests to be built # For now it's ok to disable them as these binaries aren't very useful anyway. # 'test/core/security:create_jwt': { 'language': 'c', 'build': 'tool', '_TYPE': 'target', '_RENAME': 'grpc_create_jwt' }, # 'test/core/security:verify_jwt': { 'language': 'c', 'build': 'tool', '_TYPE': 'target', '_RENAME': 'grpc_verify_jwt' }, - # TODO(jtattermusch): add remaining tools such as grpc_print_google_default_creds_token (they are not used by bazel build) - # TODO(jtattermusch): these fuzzers had no build.yaml equivalent # test/core/compression:message_compress_fuzzer # test/core/compression:message_decompress_fuzzer @@ -1115,7 +1176,8 @@ def _detect_and_print_issues(build_yaml_like: BuildYaml) -> None: bazel_rules = {} for query in _BAZEL_DEPS_QUERIES: bazel_rules.update( - _extract_rules_from_bazel_xml(_bazel_query_xml_tree(query))) + _extract_rules_from_bazel_xml(_bazel_query_xml_tree(query)) + ) # Step 1.5: The sources for UPB protos are pre-generated, so we want # to expand the UPB proto library bazel rules into the generated @@ -1174,7 +1236,8 @@ def _detect_and_print_issues(build_yaml_like: BuildYaml) -> None: all_extra_metadata = {} all_extra_metadata.update(_BUILD_EXTRA_METADATA) all_extra_metadata.update( - _generate_build_extra_metadata_for_tests(tests, bazel_rules)) + _generate_build_extra_metadata_for_tests(tests, bazel_rules) +) # Step 4: Compute the build metadata that will be used in the final build.yaml. 
# The final build metadata includes transitive dependencies, and sources/headers @@ -1245,7 +1308,8 @@ def _detect_and_print_issues(build_yaml_like: BuildYaml) -> None: # will be a soft error that doesn't block existing target from successfully # built. build_yaml_like[ - 'external_proto_libraries'] = _generate_external_proto_libraries() + "external_proto_libraries" +] = _generate_external_proto_libraries() # detect and report some suspicious situations we've seen before _detect_and_print_issues(build_yaml_like) @@ -1257,6 +1321,7 @@ def _detect_and_print_issues(build_yaml_like: BuildYaml) -> None: # TODO(jtattermusch): The "cleanup" function is taken from the legacy # build system (which used build.yaml) and can be eventually removed. build_yaml_string = build_cleaner.cleaned_build_yaml_dict_as_string( - build_yaml_like) -with open('build_autogenerated.yaml', 'w') as file: + build_yaml_like +) +with open("build_autogenerated.yaml", "w") as file: file.write(build_yaml_string) diff --git a/tools/buildgen/generate_projects.py b/tools/buildgen/generate_projects.py index 7b8241991f518..06ab102720372 100644 --- a/tools/buildgen/generate_projects.py +++ b/tools/buildgen/generate_projects.py @@ -25,44 +25,51 @@ import _utils import yaml -PROJECT_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", - "..") +PROJECT_ROOT = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "..", ".." +) os.chdir(PROJECT_ROOT) # TODO(lidiz) find a better way for plugins to reference each other -sys.path.append(os.path.join(PROJECT_ROOT, 'tools', 'buildgen', 'plugins')) +sys.path.append(os.path.join(PROJECT_ROOT, "tools", "buildgen", "plugins")) # from tools.run_tests.python_utils import jobset jobset = _utils.import_python_module( - os.path.join(PROJECT_ROOT, 'tools', 'run_tests', 'python_utils', - 'jobset.py')) + os.path.join( + PROJECT_ROOT, "tools", "run_tests", "python_utils", "jobset.py" + ) +) -PREPROCESSED_BUILD = '.preprocessed_build' -test = {} if os.environ.get('TEST', 'false') == 'true' else None +PREPROCESSED_BUILD = ".preprocessed_build" +test = {} if os.environ.get("TEST", "false") == "true" else None -assert sys.argv[1:], 'run generate_projects.sh instead of this directly' +assert sys.argv[1:], "run generate_projects.sh instead of this directly" parser = argparse.ArgumentParser() -parser.add_argument('build_files', - nargs='+', - default=[], - help="build files describing build specs") -parser.add_argument('--templates', - nargs='+', - default=[], - help="mako template files to render") -parser.add_argument('--output_merged', - '-m', - default='', - type=str, - help="merge intermediate results to a file") -parser.add_argument('--jobs', - '-j', - default=multiprocessing.cpu_count(), - type=int, - help="maximum parallel jobs") -parser.add_argument('--base', - default='.', - type=str, - help="base path for generated files") +parser.add_argument( + "build_files", + nargs="+", + default=[], + help="build files describing build specs", +) +parser.add_argument( + "--templates", nargs="+", default=[], help="mako template files to render" +) +parser.add_argument( + "--output_merged", + "-m", + default="", + type=str, + help="merge intermediate results to a file", +) +parser.add_argument( + "--jobs", + "-j", + default=multiprocessing.cpu_count(), + type=int, + help="maximum parallel jobs", +) +parser.add_argument( + "--base", default=".", type=str, help="base path for generated files" +) args = parser.parse_args() @@ -70,14 +77,14 @@ def preprocess_build_files() -> _utils.Bunch: 
"""Merges build yaml into a one dictionary then pass it to plugins.""" build_spec = dict() for build_file in args.build_files: - with open(build_file, 'r') as f: + with open(build_file, "r") as f: _utils.merge_json(build_spec, yaml.safe_load(f.read())) # Executes plugins. Plugins update the build spec in-place. - for py_file in sorted(glob.glob('tools/buildgen/plugins/*.py')): + for py_file in sorted(glob.glob("tools/buildgen/plugins/*.py")): plugin = _utils.import_python_module(py_file) plugin.mako_plugin(build_spec) if args.output_merged: - with open(args.output_merged, 'w') as f: + with open(args.output_merged, "w") as f: f.write(yaml.dump(build_spec)) # Makes build_spec sort of immutable and dot-accessible return _utils.to_bunch(build_spec) @@ -86,18 +93,18 @@ def preprocess_build_files() -> _utils.Bunch: def generate_template_render_jobs(templates: List[str]) -> List[jobset.JobSpec]: """Generate JobSpecs for each one of the template rendering work.""" jobs = [] - base_cmd = [sys.executable, 'tools/buildgen/_mako_renderer.py'] + base_cmd = [sys.executable, "tools/buildgen/_mako_renderer.py"] for template in sorted(templates, reverse=True): root, f = os.path.split(template) - if os.path.splitext(f)[1] == '.template': - out_dir = args.base + root[len('templates'):] + if os.path.splitext(f)[1] == ".template": + out_dir = args.base + root[len("templates") :] out = os.path.join(out_dir, os.path.splitext(f)[0]) if not os.path.exists(out_dir): os.makedirs(out_dir) cmd = base_cmd[:] - cmd.append('-P') + cmd.append("-P") cmd.append(PREPROCESSED_BUILD) - cmd.append('-o') + cmd.append("-o") if test is None: cmd.append(out) else: @@ -105,37 +112,41 @@ def generate_template_render_jobs(templates: List[str]) -> List[jobset.JobSpec]: test[out] = tf[1] os.close(tf[0]) cmd.append(test[out]) - cmd.append(args.base + '/' + root + '/' + f) - jobs.append(jobset.JobSpec(cmd, shortname=out, - timeout_seconds=None)) + cmd.append(args.base + "/" + root + "/" + f) + jobs.append( + jobset.JobSpec(cmd, shortname=out, timeout_seconds=None) + ) return jobs def main() -> None: templates = args.templates if not templates: - for root, _, files in os.walk('templates'): + for root, _, files in os.walk("templates"): for f in files: templates.append(os.path.join(root, f)) build_spec = preprocess_build_files() - with open(PREPROCESSED_BUILD, 'wb') as f: + with open(PREPROCESSED_BUILD, "wb") as f: pickle.dump(build_spec, f) - err_cnt, _ = jobset.run(generate_template_render_jobs(templates), - maxjobs=args.jobs) + err_cnt, _ = jobset.run( + generate_template_render_jobs(templates), maxjobs=args.jobs + ) if err_cnt != 0: - print('ERROR: %s error(s) found while generating projects.' % err_cnt, - file=sys.stderr) + print( + "ERROR: %s error(s) found while generating projects." 
% err_cnt, + file=sys.stderr, + ) sys.exit(1) if test is not None: for s, g in test.items(): if os.path.isfile(g): - assert 0 == os.system('diff %s %s' % (s, g)), s + assert 0 == os.system("diff %s %s" % (s, g)), s os.unlink(g) else: - assert 0 == os.system('diff -r %s %s' % (s, g)), s + assert 0 == os.system("diff -r %s %s" % (s, g)), s shutil.rmtree(g, ignore_errors=True) diff --git a/tools/buildgen/plugins/check_attrs.py b/tools/buildgen/plugins/check_attrs.py index 14a5940149368..8f14277f43330 100644 --- a/tools/buildgen/plugins/check_attrs.py +++ b/tools/buildgen/plugins/check_attrs.py @@ -19,91 +19,95 @@ def anything(): def one_of(values): - return lambda v: ('{0} is not in [{1}]'.format(v, values) - if v not in values else None) + return lambda v: ( + "{0} is not in [{1}]".format(v, values) if v not in values else None + ) def subset_of(values): - return lambda v: ('{0} is not subset of [{1}]'.format(v, values) - if not all(e in values for e in v) else None) + return lambda v: ( + "{0} is not subset of [{1}]".format(v, values) + if not all(e in values for e in v) + else None + ) VALID_ATTRIBUTE_KEYS_MAP = { - 'filegroup': { - 'deps': anything(), - 'headers': anything(), - 'plugin': anything(), - 'public_headers': anything(), - 'src': anything(), - 'uses': anything(), + "filegroup": { + "deps": anything(), + "headers": anything(), + "plugin": anything(), + "public_headers": anything(), + "src": anything(), + "uses": anything(), }, - 'lib': { - 'asm_src': anything(), - 'baselib': anything(), - 'boringssl': one_of((True,)), - 'build_system': anything(), - 'build': anything(), - 'cmake_target': anything(), - 'defaults': anything(), - 'deps_linkage': one_of(('static',)), - 'deps': anything(), - 'dll': one_of((True, 'only')), - 'filegroups': anything(), - 'generate_plugin_registry': anything(), - 'headers': anything(), - 'language': one_of(('c', 'c++', 'csharp')), - 'LDFLAGS': anything(), - 'platforms': subset_of(('linux', 'mac', 'posix', 'windows')), - 'public_headers': anything(), - 'secure': one_of(('check', True, False)), - 'src': anything(), - 'vs_proj_dir': anything(), - 'zlib': one_of((True,)), + "lib": { + "asm_src": anything(), + "baselib": anything(), + "boringssl": one_of((True,)), + "build_system": anything(), + "build": anything(), + "cmake_target": anything(), + "defaults": anything(), + "deps_linkage": one_of(("static",)), + "deps": anything(), + "dll": one_of((True, "only")), + "filegroups": anything(), + "generate_plugin_registry": anything(), + "headers": anything(), + "language": one_of(("c", "c++", "csharp")), + "LDFLAGS": anything(), + "platforms": subset_of(("linux", "mac", "posix", "windows")), + "public_headers": anything(), + "secure": one_of(("check", True, False)), + "src": anything(), + "vs_proj_dir": anything(), + "zlib": one_of((True,)), }, - 'target': { - 'args': anything(), - 'benchmark': anything(), - 'boringssl': one_of((True,)), - 'build': anything(), - 'ci_platforms': anything(), - 'corpus_dirs': anything(), - 'cpu_cost': anything(), - 'defaults': anything(), - 'deps': anything(), - 'dict': anything(), - 'exclude_configs': anything(), - 'exclude_iomgrs': anything(), - 'excluded_poll_engines': anything(), - 'filegroups': anything(), - 'flaky': one_of((True, False)), - 'gtest': one_of((True, False)), - 'headers': anything(), - 'language': one_of(('c', 'c89', 'c++', 'csharp')), - 'maxlen': anything(), - 'platforms': subset_of(('linux', 'mac', 'posix', 'windows')), - 'run': one_of((True, False)), - 'secure': one_of(('check', True, False)), - 'src': 
anything(), - 'timeout_seconds': anything(), - 'uses_polling': anything(), - 'vs_proj_dir': anything(), - 'zlib': one_of((True,)), + "target": { + "args": anything(), + "benchmark": anything(), + "boringssl": one_of((True,)), + "build": anything(), + "ci_platforms": anything(), + "corpus_dirs": anything(), + "cpu_cost": anything(), + "defaults": anything(), + "deps": anything(), + "dict": anything(), + "exclude_configs": anything(), + "exclude_iomgrs": anything(), + "excluded_poll_engines": anything(), + "filegroups": anything(), + "flaky": one_of((True, False)), + "gtest": one_of((True, False)), + "headers": anything(), + "language": one_of(("c", "c89", "c++", "csharp")), + "maxlen": anything(), + "platforms": subset_of(("linux", "mac", "posix", "windows")), + "run": one_of((True, False)), + "secure": one_of(("check", True, False)), + "src": anything(), + "timeout_seconds": anything(), + "uses_polling": anything(), + "vs_proj_dir": anything(), + "zlib": one_of((True,)), + }, + "external_proto_library": { + "destination": anything(), + "proto_prefix": anything(), + "urls": anything(), + "hash": anything(), + "strip_prefix": anything(), }, - 'external_proto_library': { - 'destination': anything(), - 'proto_prefix': anything(), - 'urls': anything(), - 'hash': anything(), - 'strip_prefix': anything(), - } } def check_attributes(entity, kind, errors): attributes = VALID_ATTRIBUTE_KEYS_MAP[kind] - name = entity.get('name', anything()) + name = entity.get("name", anything()) for key, value in list(entity.items()): - if key == 'name': + if key == "name": continue validator = attributes.get(key) if validator: @@ -111,10 +115,15 @@ def check_attributes(entity, kind, errors): if error: errors.append( "{0}({1}) has an invalid value for '{2}': {3}".format( - name, kind, key, error)) + name, kind, key, error + ) + ) else: - errors.append("{0}({1}) has an invalid attribute '{2}'".format( - name, kind, key)) + errors.append( + "{0}({1}) has an invalid attribute '{2}'".format( + name, kind, key + ) + ) def mako_plugin(dictionary): @@ -126,13 +135,13 @@ def mako_plugin(dictionary): """ errors = [] - for filegroup in dictionary.get('filegroups', {}): - check_attributes(filegroup, 'filegroup', errors) - for lib in dictionary.get('libs', {}): - check_attributes(lib, 'lib', errors) - for target in dictionary.get('targets', {}): - check_attributes(target, 'target', errors) - for target in dictionary.get('external_proto_libraries', {}): - check_attributes(target, 'external_proto_library', errors) + for filegroup in dictionary.get("filegroups", {}): + check_attributes(filegroup, "filegroup", errors) + for lib in dictionary.get("libs", {}): + check_attributes(lib, "lib", errors) + for target in dictionary.get("targets", {}): + check_attributes(target, "target", errors) + for target in dictionary.get("external_proto_libraries", {}): + check_attributes(target, "external_proto_library", errors) if errors: - raise Exception('\n'.join(errors)) + raise Exception("\n".join(errors)) diff --git a/tools/buildgen/plugins/expand_bin_attrs.py b/tools/buildgen/plugins/expand_bin_attrs.py index d5acd8d06c876..e8de2b04ba6ff 100755 --- a/tools/buildgen/plugins/expand_bin_attrs.py +++ b/tools/buildgen/plugins/expand_bin_attrs.py @@ -21,26 +21,26 @@ def mako_plugin(dictionary): """The exported plugin code for expand_filegroups. - The list of libs in the build.yaml file can contain "filegroups" tags. - These refer to the filegroups in the root object. 
We will expand and - merge filegroups on the src, headers and public_headers properties. + The list of libs in the build.yaml file can contain "filegroups" tags. + These refer to the filegroups in the root object. We will expand and + merge filegroups on the src, headers and public_headers properties. - """ + """ - targets = dictionary.get('targets') - default_platforms = ['windows', 'posix', 'linux', 'mac'] + targets = dictionary.get("targets") + default_platforms = ["windows", "posix", "linux", "mac"] for tgt in targets: - tgt['flaky'] = tgt.get('flaky', False) - tgt['platforms'] = sorted(tgt.get('platforms', default_platforms)) - tgt['ci_platforms'] = sorted(tgt.get('ci_platforms', tgt['platforms'])) - tgt['boringssl'] = tgt.get('boringssl', False) - tgt['zlib'] = tgt.get('zlib', False) - tgt['ares'] = tgt.get('ares', False) - tgt['gtest'] = tgt.get('gtest', False) - - libs = dictionary.get('libs') + tgt["flaky"] = tgt.get("flaky", False) + tgt["platforms"] = sorted(tgt.get("platforms", default_platforms)) + tgt["ci_platforms"] = sorted(tgt.get("ci_platforms", tgt["platforms"])) + tgt["boringssl"] = tgt.get("boringssl", False) + tgt["zlib"] = tgt.get("zlib", False) + tgt["ares"] = tgt.get("ares", False) + tgt["gtest"] = tgt.get("gtest", False) + + libs = dictionary.get("libs") for lib in libs: - lib['boringssl'] = lib.get('boringssl', False) - lib['zlib'] = lib.get('zlib', False) - lib['ares'] = lib.get('ares', False) + lib["boringssl"] = lib.get("boringssl", False) + lib["zlib"] = lib.get("zlib", False) + lib["ares"] = lib.get("ares", False) diff --git a/tools/buildgen/plugins/expand_version.py b/tools/buildgen/plugins/expand_version.py index fc70e9323be6f..ddf53eabdc289 100755 --- a/tools/buildgen/plugins/expand_version.py +++ b/tools/buildgen/plugins/expand_version.py @@ -21,111 +21,119 @@ import re LANGUAGES = [ - 'core', - 'cpp', - 'csharp', - 'node', - 'objc', - 'php', - 'python', - 'ruby', + "core", + "cpp", + "csharp", + "node", + "objc", + "php", + "python", + "ruby", ] class Version: - def __init__(self, version_str, override_major=None): self.tag = None - if '-' in version_str: - version_str, self.tag = version_str.split('-') + if "-" in version_str: + version_str, self.tag = version_str.split("-") self.major, self.minor, self.patch = [ - int(x) for x in version_str.split('.') + int(x) for x in version_str.split(".") ] if override_major: self.major = override_major def __str__(self): """Version string in a somewhat idiomatic style for most languages""" - version_str = '%d.%d.%d' % (self.major, self.minor, self.patch) + version_str = "%d.%d.%d" % (self.major, self.minor, self.patch) if self.tag: - version_str += '-%s' % self.tag + version_str += "-%s" % self.tag return version_str def pep440(self): """Version string in Python PEP440 style""" - s = '%d.%d.%d' % (self.major, self.minor, self.patch) + s = "%d.%d.%d" % (self.major, self.minor, self.patch) if self.tag: # we need to translate from grpc version tags to pep440 version # tags; this code is likely to be a little ad-hoc - if self.tag == 'dev': - s += '.dev0' - elif len(self.tag) >= 3 and self.tag[0:3] == 'pre': - s += 'rc%d' % int(self.tag[3:]) + if self.tag == "dev": + s += ".dev0" + elif len(self.tag) >= 3 and self.tag[0:3] == "pre": + s += "rc%d" % int(self.tag[3:]) else: raise Exception( - 'Don\'t know how to translate version tag "%s" to pep440' % - self.tag) + 'Don\'t know how to translate version tag "%s" to pep440' + % self.tag + ) return s def ruby(self): """Version string in Ruby style""" if self.tag: - return 
'%d.%d.%d.%s' % (self.major, self.minor, self.patch, - self.tag) + return "%d.%d.%d.%s" % ( + self.major, + self.minor, + self.patch, + self.tag, + ) else: - return '%d.%d.%d' % (self.major, self.minor, self.patch) + return "%d.%d.%d" % (self.major, self.minor, self.patch) def php(self): """Version string for PHP PECL package""" - s = '%d.%d.%d' % (self.major, self.minor, self.patch) + s = "%d.%d.%d" % (self.major, self.minor, self.patch) if self.tag: - if self.tag == 'dev': - s += 'dev' - elif len(self.tag) >= 3 and self.tag[0:3] == 'pre': - s += 'RC%d' % int(self.tag[3:]) + if self.tag == "dev": + s += "dev" + elif len(self.tag) >= 3 and self.tag[0:3] == "pre": + s += "RC%d" % int(self.tag[3:]) else: raise Exception( - 'Don\'t know how to translate version tag "%s" to PECL version' - % self.tag) + 'Don\'t know how to translate version tag "%s" to PECL' + " version" % self.tag + ) return s def php_stability(self): """stability string for PHP PECL package.xml file""" if self.tag: - return 'beta' + return "beta" else: - return 'stable' + return "stable" def php_composer(self): """Version string for PHP Composer package""" - return '%d.%d.%d' % (self.major, self.minor, self.patch) + return "%d.%d.%d" % (self.major, self.minor, self.patch) def php_current_version(self): - return '7.4' + return "7.4" def php_debian_version(self): - return 'buster' + return "buster" def mako_plugin(dictionary): """Expand version numbers: - - for each language, ensure there's a language_version tag in - settings (defaulting to the master version tag) - - expand version strings to major, minor, patch, and tag - """ + - for each language, ensure there's a language_version tag in + settings (defaulting to the master version tag) + - expand version strings to major, minor, patch, and tag + """ - settings = dictionary['settings'] - version_str = settings['version'] + settings = dictionary["settings"] + version_str = settings["version"] master_version = Version(version_str) - settings['version'] = master_version + settings["version"] = master_version for language in LANGUAGES: - version_tag = '%s_version' % language - override_major = settings.get('%s_major_version' % language, None) + version_tag = "%s_version" % language + override_major = settings.get("%s_major_version" % language, None) if version_tag in settings: - settings[version_tag] = Version(settings[version_tag], - override_major=override_major) + settings[version_tag] = Version( + settings[version_tag], override_major=override_major + ) else: - settings[version_tag] = Version(version_str, - override_major=override_major) - settings['protobuf_major_minor_version'] = ('.'.join( - settings['protobuf_version'].split('.')[:2])) + settings[version_tag] = Version( + version_str, override_major=override_major + ) + settings["protobuf_major_minor_version"] = ".".join( + settings["protobuf_version"].split(".")[:2] + ) diff --git a/tools/buildgen/plugins/list_api.py b/tools/buildgen/plugins/list_api.py index 63a635a909ea8..c5f2a85fea9d1 100755 --- a/tools/buildgen/plugins/list_api.py +++ b/tools/buildgen/plugins/list_api.py @@ -22,33 +22,33 @@ import yaml -_RE_API = r'(?:GPRAPI|GRPCAPI|CENSUSAPI)([^#;]*);' +_RE_API = r"(?:GPRAPI|GRPCAPI|CENSUSAPI)([^#;]*);" def list_c_apis(filenames): for filename in filenames: - with open(filename, 'r') as f: + with open(filename, "r") as f: text = f.read() for m in re.finditer(_RE_API, text): - api_declaration = re.sub('[ \r\n\t]+', ' ', m.group(1)) - type_and_name, args_and_close = api_declaration.split('(', 1) - args = 
args_and_close[:args_and_close.rfind(')')].strip() - last_space = type_and_name.rfind(' ') - last_star = type_and_name.rfind('*') + api_declaration = re.sub("[ \r\n\t]+", " ", m.group(1)) + type_and_name, args_and_close = api_declaration.split("(", 1) + args = args_and_close[: args_and_close.rfind(")")].strip() + last_space = type_and_name.rfind(" ") + last_star = type_and_name.rfind("*") type_end = max(last_space, last_star) - return_type = type_and_name[0:type_end + 1].strip() - name = type_and_name[type_end + 1:].strip() + return_type = type_and_name[0 : type_end + 1].strip() + name = type_and_name[type_end + 1 :].strip() yield { - 'return_type': return_type, - 'name': name, - 'arguments': args, - 'header': filename + "return_type": return_type, + "name": name, + "arguments": args, + "header": filename, } def headers_under(directory): for root, dirnames, filenames in os.walk(directory): - for filename in fnmatch.filter(filenames, '*.h'): + for filename in fnmatch.filter(filenames, "*.h"): yield os.path.join(root, filename) @@ -56,15 +56,15 @@ def mako_plugin(dictionary): apis = [] headers = [] - for lib in dictionary['libs']: - if lib['name'] in ['grpc', 'gpr']: - headers.extend(lib['public_headers']) + for lib in dictionary["libs"]: + if lib["name"] in ["grpc", "gpr"]: + headers.extend(lib["public_headers"]) apis.extend(list_c_apis(sorted(set(headers)))) - dictionary['c_apis'] = apis + dictionary["c_apis"] = apis -if __name__ == '__main__': +if __name__ == "__main__": print( - (yaml.dump([api for api in list_c_apis(headers_under('include/grpc')) - ]))) + (yaml.dump([api for api in list_c_apis(headers_under("include/grpc"))])) + ) diff --git a/tools/buildgen/plugins/list_protos.py b/tools/buildgen/plugins/list_protos.py index 0aa5fe55d29fe..2863de7e27934 100755 --- a/tools/buildgen/plugins/list_protos.py +++ b/tools/buildgen/plugins/list_protos.py @@ -24,30 +24,30 @@ def mako_plugin(dictionary): """The exported plugin code for list_protos. - Some projects generators may want to get the full list of unique .proto files - that are being included in a project. This code extracts all files referenced - in any library or target that ends in .proto, and builds and exports that as - a list called "protos". + Some projects generators may want to get the full list of unique .proto files + that are being included in a project. This code extracts all files referenced + in any library or target that ends in .proto, and builds and exports that as + a list called "protos". - """ + """ - libs = dictionary.get('libs', []) - targets = dictionary.get('targets', []) + libs = dictionary.get("libs", []) + targets = dictionary.get("targets", []) - proto_re = re.compile('(.*)\\.proto') + proto_re = re.compile("(.*)\\.proto") protos = set() for lib in libs: - for src in lib.get('src', []): + for src in lib.get("src", []): m = proto_re.match(src) if m: protos.add(m.group(1)) for tgt in targets: - for src in tgt.get('src', []): + for src in tgt.get("src", []): m = proto_re.match(src) if m: protos.add(m.group(1)) protos = sorted(protos) - dictionary['protos'] = protos + dictionary["protos"] = protos diff --git a/tools/buildgen/plugins/transitive_dependencies.py b/tools/buildgen/plugins/transitive_dependencies.py index 97c726b74952d..1cb8a1342176c 100644 --- a/tools/buildgen/plugins/transitive_dependencies.py +++ b/tools/buildgen/plugins/transitive_dependencies.py @@ -51,17 +51,18 @@ def mako_plugin(dictionary): transitive_deps property to each with the transitive closure of those dependency lists. 
The result list is sorted in a topological ordering. """ - lib_map = {lib['name']: lib for lib in dictionary.get('libs')} + lib_map = {lib["name"]: lib for lib in dictionary.get("libs")} for target_name, target_list in list(dictionary.items()): for target in target_list: if isinstance(target, dict): - if 'deps' in target or target_name == 'libs': - if not 'deps' in target: + if "deps" in target or target_name == "libs": + if not "deps" in target: # make sure all the libs have the "deps" field populated - target['deps'] = [] - target['transitive_deps'] = transitive_deps(lib_map, target) + target["deps"] = [] + target["transitive_deps"] = transitive_deps(lib_map, target) - python_dependencies = dictionary.get('python_dependencies') - python_dependencies['transitive_deps'] = transitive_deps( - lib_map, python_dependencies) + python_dependencies = dictionary.get("python_dependencies") + python_dependencies["transitive_deps"] = transitive_deps( + lib_map, python_dependencies + ) diff --git a/tools/buildgen/plugins/verify_duplicate_sources.py b/tools/buildgen/plugins/verify_duplicate_sources.py index 06b2b1cb426ea..5b84278c29ef2 100644 --- a/tools/buildgen/plugins/verify_duplicate_sources.py +++ b/tools/buildgen/plugins/verify_duplicate_sources.py @@ -23,20 +23,22 @@ def mako_plugin(dictionary): """ errors = [] target_groups = ( - ('gpr', 'grpc', 'grpc++'), - ('gpr', 'grpc_unsecure', 'grpc++_unsecure'), + ("gpr", "grpc", "grpc++"), + ("gpr", "grpc_unsecure", "grpc++_unsecure"), ) - lib_map = {lib['name']: lib for lib in dictionary.get('libs')} + lib_map = {lib["name"]: lib for lib in dictionary.get("libs")} for target_group in target_groups: src_map = {} for target in target_group: - for src in lib_map[target]['src']: - if src.endswith('.cc'): + for src in lib_map[target]["src"]: + if src.endswith(".cc"): if src in src_map: errors.append( - 'Source {0} is used in both {1} and {2}'.format( - src, src_map[src], target)) + "Source {0} is used in both {1} and {2}".format( + src, src_map[src], target + ) + ) else: src_map[src] = target if errors: - raise Exception('\n'.join(errors)) + raise Exception("\n".join(errors)) diff --git a/tools/codegen/core/experiments_compiler.py b/tools/codegen/core/experiments_compiler.py index 4377eb0c6a603..3b1188b8e7886 100644 --- a/tools/codegen/core/experiments_compiler.py +++ b/tools/codegen/core/experiments_compiler.py @@ -64,21 +64,21 @@ """ -def ToCStr(s, encoding='ascii'): +def ToCStr(s, encoding="ascii"): if isinstance(s, str): s = s.encode(encoding) - result = '' + result = "" for c in s: c = chr(c) if isinstance(c, int) else c - if not (32 <= ord(c) < 127) or c in ('\\', '"'): - result += '\\%03o' % ord(c) + if not (32 <= ord(c) < 127) or c in ("\\", '"'): + result += "\\%03o" % ord(c) else: result += c return '"' + result + '"' def SnakeToPascal(s): - return ''.join(x.capitalize() for x in s.split('_')) + return "".join(x.capitalize() for x in s.split("_")) def PutBanner(files, banner, prefix): @@ -88,7 +88,7 @@ def PutBanner(files, banner, prefix): if not line: print(prefix, file=f) else: - print('%s %s' % (prefix, line), file=f) + print("%s %s" % (prefix, line), file=f) print(file=f) @@ -97,97 +97,111 @@ def PutCopyright(file, prefix): with open(__file__) as my_source: copyright = [] for line in my_source: - if line[0] != '#': + if line[0] != "#": break for line in my_source: - if line[0] == '#': + if line[0] == "#": copyright.append(line) break for line in my_source: - if line[0] != '#': + if line[0] != "#": break copyright.append(line) PutBanner([file], 
[line[2:].rstrip() for line in copyright], prefix) class ExperimentDefinition(object): - def __init__(self, attributes): self._error = False - if 'name' not in attributes: + if "name" not in attributes: print("ERROR: experiment with no name: %r" % attributes) self._error = True - if 'description' not in attributes: - print("ERROR: no description for experiment %s" % - attributes['name']) + if "description" not in attributes: + print( + "ERROR: no description for experiment %s" % attributes["name"] + ) self._error = True - if 'owner' not in attributes: - print("ERROR: no owner for experiment %s" % attributes['name']) + if "owner" not in attributes: + print("ERROR: no owner for experiment %s" % attributes["name"]) self._error = True - if 'expiry' not in attributes: - print("ERROR: no expiry for experiment %s" % attributes['name']) + if "expiry" not in attributes: + print("ERROR: no expiry for experiment %s" % attributes["name"]) self._error = True - if attributes['name'] == 'monitoring_experiment': - if attributes['expiry'] != 'never-ever': + if attributes["name"] == "monitoring_experiment": + if attributes["expiry"] != "never-ever": print("ERROR: monitoring_experiment should never expire") self._error = True if self._error: print("Failed to create experiment definition") return self._allow_in_fuzzing_config = True - self._name = attributes['name'] - self._description = attributes['description'] - self._expiry = attributes['expiry'] + self._name = attributes["name"] + self._description = attributes["description"] + self._expiry = attributes["expiry"] self._default = None self._additional_constraints = {} self._test_tags = [] - if 'allow_in_fuzzing_config' in attributes: + if "allow_in_fuzzing_config" in attributes: self._allow_in_fuzzing_config = attributes[ - 'allow_in_fuzzing_config'] + "allow_in_fuzzing_config" + ] - if 'test_tags' in attributes: - self._test_tags = attributes['test_tags'] + if "test_tags" in attributes: + self._test_tags = attributes["test_tags"] def IsValid(self, check_expiry=False): if self._error: return False if not check_expiry: return True - if self._name == 'monitoring_experiment' and self._expiry == 'never-ever': + if ( + self._name == "monitoring_experiment" + and self._expiry == "never-ever" + ): return True today = datetime.date.today() two_quarters_from_now = today + datetime.timedelta(days=180) - expiry = datetime.datetime.strptime(self._expiry, '%Y/%m/%d').date() + expiry = datetime.datetime.strptime(self._expiry, "%Y/%m/%d").date() if expiry < today: - print("WARNING: experiment %s expired on %s" % - (self._name, self._expiry)) + print( + "WARNING: experiment %s expired on %s" + % (self._name, self._expiry) + ) if expiry > two_quarters_from_now: - print("WARNING: experiment %s expires far in the future on %s" % - (self._name, self._expiry)) + print( + "WARNING: experiment %s expires far in the future on %s" + % (self._name, self._expiry) + ) print("expiry should be no more than two quarters from now") return not self._error def AddRolloutSpecification(self, allowed_defaults, rollout_attributes): if self._error or self._default is not None: return False - if rollout_attributes['name'] != self._name: + if rollout_attributes["name"] != self._name: print( - "ERROR: Rollout specification does not apply to this experiment: %s" - % self._name) + "ERROR: Rollout specification does not apply to this" + " experiment: %s" % self._name + ) return False - if 'default' not in rollout_attributes: - print("ERROR: no default for experiment %s" % - 
rollout_attributes['name']) + if "default" not in rollout_attributes: + print( + "ERROR: no default for experiment %s" + % rollout_attributes["name"] + ) self._error = True - if rollout_attributes['default'] not in allowed_defaults: - print("ERROR: invalid default for experiment %s: %r" % - (rollout_attributes['name'], rollout_attributes['default'])) + if rollout_attributes["default"] not in allowed_defaults: + print( + "ERROR: invalid default for experiment %s: %r" + % (rollout_attributes["name"], rollout_attributes["default"]) + ) self._error = True - if 'additional_constraints' in rollout_attributes: + if "additional_constraints" in rollout_attributes: self._additional_constraints = rollout_attributes[ - 'additional_constraints'] - self._default = rollout_attributes['default'] + "additional_constraints" + ] + self._default = rollout_attributes["default"] return True @property @@ -216,12 +230,9 @@ def additional_constraints(self): class ExperimentsCompiler(object): - - def __init__(self, - defaults, - final_return, - final_define, - bzl_list_for_defaults=None): + def __init__( + self, defaults, final_return, final_define, bzl_list_for_defaults=None + ): self._defaults = defaults self._final_return = final_return self._final_define = final_define @@ -231,32 +242,41 @@ def __init__(self, def AddExperimentDefinition(self, experiment_definition): if experiment_definition.name in self._experiment_definitions: - print("ERROR: Duplicate experiment definition: %s" % - experiment_definition.name) + print( + "ERROR: Duplicate experiment definition: %s" + % experiment_definition.name + ) return False self._experiment_definitions[ - experiment_definition.name] = experiment_definition + experiment_definition.name + ] = experiment_definition return True def AddRolloutSpecification(self, rollout_attributes): - if 'name' not in rollout_attributes: - print("ERROR: experiment with no name: %r in rollout_attribute" % - rollout_attributes) + if "name" not in rollout_attributes: + print( + "ERROR: experiment with no name: %r in rollout_attribute" + % rollout_attributes + ) return False - if rollout_attributes['name'] not in self._experiment_definitions: - print("WARNING: rollout for an undefined experiment: %s ignored" % - rollout_attributes['name']) - return (self._experiment_definitions[ - rollout_attributes['name']].AddRolloutSpecification( - self._defaults, rollout_attributes)) + if rollout_attributes["name"] not in self._experiment_definitions: + print( + "WARNING: rollout for an undefined experiment: %s ignored" + % rollout_attributes["name"] + ) + return self._experiment_definitions[ + rollout_attributes["name"] + ].AddRolloutSpecification(self._defaults, rollout_attributes) def GenerateExperimentsHdr(self, output_file): - with open(output_file, 'w') as H: + with open(output_file, "w") as H: PutCopyright(H, "//") PutBanner( [H], - ["Auto generated by tools/codegen/core/gen_experiments.py"] + - _CODEGEN_PLACEHOLDER_TEXT.splitlines(), "//") + ["Auto generated by tools/codegen/core/gen_experiments.py"] + + _CODEGEN_PLACEHOLDER_TEXT.splitlines(), + "//", + ) print("#ifndef GRPC_SRC_CORE_LIB_EXPERIMENTS_EXPERIMENTS_H", file=H) print("#define GRPC_SRC_CORE_LIB_EXPERIMENTS_EXPERIMENTS_H", file=H) @@ -264,7 +284,7 @@ def GenerateExperimentsHdr(self, output_file): print("#include ", file=H) print(file=H) print("#include ", file=H) - print("#include \"src/core/lib/experiments/config.h\"", file=H) + print('#include "src/core/lib/experiments/config.h"', file=H) print(file=H) print("namespace grpc_core {", file=H) 
print(file=H) @@ -272,80 +292,109 @@ def GenerateExperimentsHdr(self, output_file): for _, exp in self._experiment_definitions.items(): define_fmt = self._final_define[exp.default] if define_fmt: - print(define_fmt % - ("GRPC_EXPERIMENT_IS_INCLUDED_%s" % exp.name.upper()), - file=H) + print( + define_fmt + % ("GRPC_EXPERIMENT_IS_INCLUDED_%s" % exp.name.upper()), + file=H, + ) print( - "inline bool Is%sEnabled() { %s }" % - (SnakeToPascal(exp.name), self._final_return[exp.default]), - file=H) + "inline bool Is%sEnabled() { %s }" + % ( + SnakeToPascal(exp.name), + self._final_return[exp.default], + ), + file=H, + ) print("#else", file=H) for i, (_, exp) in enumerate(self._experiment_definitions.items()): - print("#define GRPC_EXPERIMENT_IS_INCLUDED_%s" % - exp.name.upper(), - file=H) print( - "inline bool Is%sEnabled() { return IsExperimentEnabled(%d); }" + "#define GRPC_EXPERIMENT_IS_INCLUDED_%s" % exp.name.upper(), + file=H, + ) + print( + "inline bool Is%sEnabled() { return" + " IsExperimentEnabled(%d); }" % (SnakeToPascal(exp.name), i), - file=H) + file=H, + ) print(file=H) - print("constexpr const size_t kNumExperiments = %d;" % - len(self._experiment_definitions.keys()), - file=H) print( - "extern const ExperimentMetadata g_experiment_metadata[kNumExperiments];", - file=H) + "constexpr const size_t kNumExperiments = %d;" + % len(self._experiment_definitions.keys()), + file=H, + ) + print( + ( + "extern const ExperimentMetadata" + " g_experiment_metadata[kNumExperiments];" + ), + file=H, + ) print(file=H) print("#endif", file=H) print("} // namespace grpc_core", file=H) print(file=H) - print("#endif // GRPC_SRC_CORE_LIB_EXPERIMENTS_EXPERIMENTS_H", - file=H) + print( + "#endif // GRPC_SRC_CORE_LIB_EXPERIMENTS_EXPERIMENTS_H", file=H + ) def GenerateExperimentsSrc(self, output_file): - with open(output_file, 'w') as C: + with open(output_file, "w") as C: PutCopyright(C, "//") PutBanner( [C], ["Auto generated by tools/codegen/core/gen_experiments.py"], - "//") + "//", + ) print("#include ", file=C) - print("#include \"src/core/lib/experiments/experiments.h\"", file=C) + print('#include "src/core/lib/experiments/experiments.h"', file=C) print(file=C) print("#ifndef GRPC_EXPERIMENTS_ARE_FINAL", file=C) print("namespace {", file=C) have_defaults = set() for _, exp in self._experiment_definitions.items(): - print("const char* const description_%s = %s;" % - (exp.name, ToCStr(exp.description)), - file=C) print( - "const char* const additional_constraints_%s = %s;" % - (exp.name, ToCStr(json.dumps(exp.additional_constraints))), - file=C) + "const char* const description_%s = %s;" + % (exp.name, ToCStr(exp.description)), + file=C, + ) + print( + "const char* const additional_constraints_%s = %s;" + % ( + exp.name, + ToCStr(json.dumps(exp.additional_constraints)), + ), + file=C, + ) have_defaults.add(self._defaults[exp.default]) - if 'kDefaultForDebugOnly' in have_defaults: + if "kDefaultForDebugOnly" in have_defaults: print("#ifdef NDEBUG", file=C) - if 'kDefaultForDebugOnly' in have_defaults: + if "kDefaultForDebugOnly" in have_defaults: print("const bool kDefaultForDebugOnly = false;", file=C) print("#else", file=C) - if 'kDefaultForDebugOnly' in have_defaults: + if "kDefaultForDebugOnly" in have_defaults: print("const bool kDefaultForDebugOnly = true;", file=C) print("#endif", file=C) print("}", file=C) print(file=C) print("namespace grpc_core {", file=C) print(file=C) - print("const ExperimentMetadata g_experiment_metadata[] = {", - file=C) + print( + "const ExperimentMetadata 
g_experiment_metadata[] = {", file=C + ) for _, exp in self._experiment_definitions.items(): print( " {%s, description_%s, additional_constraints_%s, %s, %s}," - % (ToCStr(exp.name), exp.name, exp.name, - self._defaults[exp.default], - 'true' if exp.allow_in_fuzzing_config else 'false'), - file=C) + % ( + ToCStr(exp.name), + exp.name, + exp.name, + self._defaults[exp.default], + "true" if exp.allow_in_fuzzing_config else "false", + ), + file=C, + ) print("};", file=C) print(file=C) print("} // namespace grpc_core", file=C) @@ -358,37 +407,43 @@ def GenExperimentsBzl(self, output_file): bzl_to_tags_to_experiments = dict( (key, collections.defaultdict(list)) for key in self._bzl_list_for_defaults.keys() - if key is not None) + if key is not None + ) for _, exp in self._experiment_definitions.items(): for tag in exp.test_tags: bzl_to_tags_to_experiments[exp.default][tag].append(exp.name) - with open(output_file, 'w') as B: + with open(output_file, "w") as B: PutCopyright(B, "#") PutBanner( [B], ["Auto generated by tools/codegen/core/gen_experiments.py"], - "#") + "#", + ) print( - "\"\"\"Dictionary of tags to experiments so we know when to test different experiments.\"\"\"", - file=B) + ( + '"""Dictionary of tags to experiments so we know when to' + ' test different experiments."""' + ), + file=B, + ) bzl_to_tags_to_experiments = sorted( (self._bzl_list_for_defaults[default], tags_to_experiments) - for default, tags_to_experiments in - bzl_to_tags_to_experiments.items() - if self._bzl_list_for_defaults[default] is not None) + for default, tags_to_experiments in bzl_to_tags_to_experiments.items() + if self._bzl_list_for_defaults[default] is not None + ) print(file=B) print("EXPERIMENTS = {", file=B) for key, tags_to_experiments in bzl_to_tags_to_experiments: - print(" \"%s\": {" % key, file=B) + print(' "%s": {' % key, file=B) for tag, experiments in sorted(tags_to_experiments.items()): - print(" \"%s\": [" % tag, file=B) + print(' "%s": [' % tag, file=B) for experiment in sorted(experiments): - print(" \"%s\"," % experiment, file=B) + print(' "%s",' % experiment, file=B) print(" ],", file=B) print(" },", file=B) print("}", file=B) diff --git a/tools/codegen/core/gen_config_vars.py b/tools/codegen/core/gen_config_vars.py index 1f5f02ed1f213..62fa5e3ec1f6a 100755 --- a/tools/codegen/core/gen_config_vars.py +++ b/tools/codegen/core/gen_config_vars.py @@ -28,55 +28,55 @@ import yaml -with open('src/core/lib/config/config_vars.yaml') as f: +with open("src/core/lib/config/config_vars.yaml") as f: attrs = yaml.safe_load(f.read(), Loader=yaml.FullLoader) error = False today = datetime.date.today() two_quarters_from_now = today + datetime.timedelta(days=180) for attr in attrs: - if 'name' not in attr: + if "name" not in attr: print("config has no name: %r" % attr) error = True continue - if 'experiment' in attr['name'] and attr['name'] != 'experiments': - print('use experiment system for experiments') + if "experiment" in attr["name"] and attr["name"] != "experiments": + print("use experiment system for experiments") error = True - if 'description' not in attr: - print("no description for %s" % attr['name']) + if "description" not in attr: + print("no description for %s" % attr["name"]) error = True - if 'default' not in attr: - print("no default for %s" % attr['name']) + if "default" not in attr: + print("no default for %s" % attr["name"]) error = True if error: sys.exit(1) -def c_str(s, encoding='ascii'): +def c_str(s, encoding="ascii"): if s is None: return '""' if isinstance(s, str): s = 
s.encode(encoding) - result = '' + result = "" for c in s: c = chr(c) if isinstance(c, int) else c - if not (32 <= ord(c) < 127) or c in ('\\', '"'): - result += '\\%03o' % ord(c) + if not (32 <= ord(c) < 127) or c in ("\\", '"'): + result += "\\%03o" % ord(c) else: result += c return '"' + result + '"' def snake_to_pascal(s): - return ''.join(x.capitalize() for x in s.split('_')) + return "".join(x.capitalize() for x in s.split("_")) # utility: print a big comment block into a set of files def put_banner(files, banner): for f in files: for line in banner: - print('// %s' % line, file=f) + print("// %s" % line, file=f) print(file=f) @@ -85,14 +85,14 @@ def put_copyright(file): with open(sys.argv[0]) as my_source: copyright = [] for line in my_source: - if line[0] != '#': + if line[0] != "#": break for line in my_source: - if line[0] == '#': + if line[0] == "#": copyright.append(line) break for line in my_source: - if line[0] != '#': + if line[0] != "#": break copyright.append(line) put_banner([file], [line[2:].rstrip() for line in copyright]) @@ -130,7 +130,7 @@ def put_copyright(file): "int": 0, "bool": 1, "string": 2, - "comma_separated_string": 3 + "comma_separated_string": 3, } @@ -152,7 +152,7 @@ def int_default_value(x, name): def string_default_value(x, name): if x is None: - return "\"\"" + return '""' if x.startswith("$"): return x[1:] return c_str(x) @@ -167,23 +167,28 @@ def string_default_value(x, name): TO_STRING = { "int": "$", - "bool": "$?\"true\":\"false\"", - "string": "\"\\\"\", absl::CEscape($), \"\\\"\"", - "comma_separated_string": "\"\\\"\", absl::CEscape($), \"\\\"\"", + "bool": '$?"true":"false"', + "string": '"\\"", absl::CEscape($), "\\""', + "comma_separated_string": '"\\"", absl::CEscape($), "\\""', } -attrs_in_packing_order = sorted(attrs, - key=lambda a: SORT_ORDER_FOR_PACKING[a['type']]) +attrs_in_packing_order = sorted( + attrs, key=lambda a: SORT_ORDER_FOR_PACKING[a["type"]] +) -with open('test/core/util/fuzz_config_vars.proto', 'w') as P: +with open("test/core/util/fuzz_config_vars.proto", "w") as P: put_copyright(P) - put_banner([P], [ - "", "Automatically generated by tools/codegen/core/gen_config_vars.py", - "" - ]) + put_banner( + [P], + [ + "", + "Automatically generated by tools/codegen/core/gen_config_vars.py", + "", + ], + ) - print("syntax = \"proto3\";", file=P) + print('syntax = "proto3";', file=P) print(file=P) print("package grpc.testing;", file=P) print(file=P) @@ -191,90 +196,121 @@ def string_default_value(x, name): for attr in attrs_in_packing_order: if attr.get("fuzz", False) == False: continue - print(" optional %s %s = %d;" % - (PROTO_TYPE[attr['type']], attr['name'], - binascii.crc32(attr['name'].encode('ascii')) & 0x1FFFFFFF), - file=P) + print( + " optional %s %s = %d;" + % ( + PROTO_TYPE[attr["type"]], + attr["name"], + binascii.crc32(attr["name"].encode("ascii")) & 0x1FFFFFFF, + ), + file=P, + ) print("};", file=P) -with open('test/core/util/fuzz_config_vars.h', 'w') as H: +with open("test/core/util/fuzz_config_vars.h", "w") as H: put_copyright(H) - put_banner([H], [ - "", "Automatically generated by tools/codegen/core/gen_config_vars.py", - "" - ]) + put_banner( + [H], + [ + "", + "Automatically generated by tools/codegen/core/gen_config_vars.py", + "", + ], + ) print("#ifndef GRPC_TEST_CORE_UTIL_FUZZ_CONFIG_VARS_H", file=H) print("#define GRPC_TEST_CORE_UTIL_FUZZ_CONFIG_VARS_H", file=H) print(file=H) print("#include ", file=H) print(file=H) - print("#include \"test/core/util/fuzz_config_vars.pb.h\"", file=H) - print("#include 
\"src/core/lib/config/config_vars.h\"", file=H) + print('#include "test/core/util/fuzz_config_vars.pb.h"', file=H) + print('#include "src/core/lib/config/config_vars.h"', file=H) print(file=H) print("namespace grpc_core {", file=H) print(file=H) print( - "ConfigVars::Overrides OverridesFromFuzzConfigVars(const grpc::testing::FuzzConfigVars& vars);", - file=H) + ( + "ConfigVars::Overrides OverridesFromFuzzConfigVars(const" + " grpc::testing::FuzzConfigVars& vars);" + ), + file=H, + ) print( "void ApplyFuzzConfigVars(const grpc::testing::FuzzConfigVars& vars);", - file=H) + file=H, + ) print(file=H) print("} // namespace grpc_core", file=H) print(file=H) print("#endif // GRPC_TEST_CORE_UTIL_FUZZ_CONFIG_VARS_H", file=H) -with open('test/core/util/fuzz_config_vars.cc', 'w') as C: +with open("test/core/util/fuzz_config_vars.cc", "w") as C: put_copyright(C) - put_banner([C], [ - "", "Automatically generated by tools/codegen/core/gen_config_vars.py", - "" - ]) - - print("#include \"test/core/util/fuzz_config_vars.h\"", file=C) - print("#include \"test/core/util/fuzz_config_vars_helpers.h\"", file=C) + put_banner( + [C], + [ + "", + "Automatically generated by tools/codegen/core/gen_config_vars.py", + "", + ], + ) + + print('#include "test/core/util/fuzz_config_vars.h"', file=C) + print('#include "test/core/util/fuzz_config_vars_helpers.h"', file=C) print(file=C) print("namespace grpc_core {", file=C) print(file=C) print( - "ConfigVars::Overrides OverridesFromFuzzConfigVars(const grpc::testing::FuzzConfigVars& vars) {", - file=C) + ( + "ConfigVars::Overrides OverridesFromFuzzConfigVars(const" + " grpc::testing::FuzzConfigVars& vars) {" + ), + file=C, + ) print(" ConfigVars::Overrides overrides;", file=C) for attr in attrs_in_packing_order: fuzz = attr.get("fuzz", False) if not fuzz: continue - print(" if (vars.has_%s()) {" % attr['name'], file=C) + print(" if (vars.has_%s()) {" % attr["name"], file=C) if isinstance(fuzz, str): - print(" overrides.%s = %s(vars.%s());" % - (attr['name'], fuzz, attr['name']), - file=C) + print( + " overrides.%s = %s(vars.%s());" + % (attr["name"], fuzz, attr["name"]), + file=C, + ) else: - print(" overrides.%s = vars.%s();" % - (attr['name'], attr['name']), - file=C) + print( + " overrides.%s = vars.%s();" % (attr["name"], attr["name"]), + file=C, + ) print(" }", file=C) print(" return overrides;", file=C) print("}", file=C) print( "void ApplyFuzzConfigVars(const grpc::testing::FuzzConfigVars& vars) {", - file=C) - print(" ConfigVars::SetOverrides(OverridesFromFuzzConfigVars(vars));", - file=C) + file=C, + ) + print( + " ConfigVars::SetOverrides(OverridesFromFuzzConfigVars(vars));", file=C + ) print("}", file=C) print(file=C) print("} // namespace grpc_core", file=C) -with open('src/core/lib/config/config_vars.h', 'w') as H: +with open("src/core/lib/config/config_vars.h", "w") as H: put_copyright(H) - put_banner([H], [ - "", "Automatically generated by tools/codegen/core/gen_config_vars.py", - "" - ]) + put_banner( + [H], + [ + "", + "Automatically generated by tools/codegen/core/gen_config_vars.py", + "", + ], + ) print("#ifndef GRPC_SRC_CORE_LIB_CONFIG_CONFIG_VARS_H", file=H) print("#define GRPC_SRC_CORE_LIB_CONFIG_CONFIG_VARS_H", file=H) @@ -284,8 +320,8 @@ def string_default_value(x, name): print("#include ", file=H) print("#include ", file=H) print("#include ", file=H) - print("#include \"absl/strings/string_view.h\"", file=H) - print("#include \"absl/types/optional.h\"", file=H) + print('#include "absl/strings/string_view.h"', file=H) + print('#include 
"absl/types/optional.h"', file=H) print(file=H) print("namespace grpc_core {", file=H) print(file=H) @@ -293,122 +329,160 @@ def string_default_value(x, name): print(" public:", file=H) print(" struct Overrides {", file=H) for attr in attrs_in_packing_order: - print(" absl::optional<%s> %s;" % - (MEMBER_TYPE[attr['type']], attr['name']), - file=H) + print( + " absl::optional<%s> %s;" + % (MEMBER_TYPE[attr["type"]], attr["name"]), + file=H, + ) print(" };", file=H) print(" ConfigVars(const ConfigVars&) = delete;", file=H) print(" ConfigVars& operator=(const ConfigVars&) = delete;", file=H) - print(" // Get the core configuration; if it does not exist, create it.", - file=H) + print( + " // Get the core configuration; if it does not exist, create it.", + file=H, + ) print(" static const ConfigVars& Get() {", file=H) print(" auto* p = config_vars_.load(std::memory_order_acquire);", file=H) print(" if (p != nullptr) return *p;", file=H) print(" return Load();", file=H) print(" }", file=H) print(" static void SetOverrides(const Overrides& overrides);", file=H) - print(" // Drop the config vars. Users must ensure no other threads are", - file=H) + print( + " // Drop the config vars. Users must ensure no other threads are", + file=H, + ) print(" // accessing the configuration.", file=H) print(" static void Reset();", file=H) print(" std::string ToString() const;", file=H) for attr in attrs: - for line in attr['description'].splitlines(): + for line in attr["description"].splitlines(): print(" // %s" % line, file=H) - if attr.get('force-load-on-access', False): - print(" %s %s() const;" % - (MEMBER_TYPE[attr['type']], snake_to_pascal(attr['name'])), - file=H) + if attr.get("force-load-on-access", False): + print( + " %s %s() const;" + % (MEMBER_TYPE[attr["type"]], snake_to_pascal(attr["name"])), + file=H, + ) else: - print(" %s %s() const { return %s_; }" % - (RETURN_TYPE[attr['type']], snake_to_pascal( - attr['name']), attr['name']), - file=H) + print( + " %s %s() const { return %s_; }" + % ( + RETURN_TYPE[attr["type"]], + snake_to_pascal(attr["name"]), + attr["name"], + ), + file=H, + ) print(" private:", file=H) print(" explicit ConfigVars(const Overrides& overrides);", file=H) print(" static const ConfigVars& Load();", file=H) print(" static std::atomic config_vars_;", file=H) for attr in attrs_in_packing_order: - if attr.get('force-load-on-access', False): + if attr.get("force-load-on-access", False): continue - print(" %s %s_;" % (MEMBER_TYPE[attr['type']], attr['name']), file=H) + print(" %s %s_;" % (MEMBER_TYPE[attr["type"]], attr["name"]), file=H) for attr in attrs_in_packing_order: - if attr.get('force-load-on-access', False) == False: + if attr.get("force-load-on-access", False) == False: continue - print(" absl::optional<%s> override_%s_;" % - (MEMBER_TYPE[attr['type']], attr['name']), - file=H) + print( + " absl::optional<%s> override_%s_;" + % (MEMBER_TYPE[attr["type"]], attr["name"]), + file=H, + ) print("};", file=H) print(file=H) print("} // namespace grpc_core", file=H) print(file=H) print("#endif // GRPC_SRC_CORE_LIB_CONFIG_CONFIG_VARS_H", file=H) -with open('src/core/lib/config/config_vars.cc', 'w') as C: +with open("src/core/lib/config/config_vars.cc", "w") as C: put_copyright(C) - put_banner([C], [ - "", "Automatically generated by tools/codegen/core/gen_config_vars.py", - "" - ]) + put_banner( + [C], + [ + "", + "Automatically generated by tools/codegen/core/gen_config_vars.py", + "", + ], + ) print("#include ", file=C) - print("#include 
\"src/core/lib/config/config_vars.h\"", file=C) - print("#include \"src/core/lib/config/load_config.h\"", file=C) - print("#include \"absl/strings/escaping.h\"", file=C) - print("#include \"absl/flags/flag.h\"", file=C) + print('#include "src/core/lib/config/config_vars.h"', file=C) + print('#include "src/core/lib/config/load_config.h"', file=C) + print('#include "absl/strings/escaping.h"', file=C) + print('#include "absl/flags/flag.h"', file=C) print(file=C) for attr in attrs: - if 'prelude' in attr: - print(attr['prelude'], file=C) + if "prelude" in attr: + print(attr["prelude"], file=C) for attr in attrs: - print("ABSL_FLAG(%s, %s, {}, %s);" % - (FLAG_TYPE[attr["type"]], 'grpc_' + attr['name'], - c_str(attr['description'])), - file=C) + print( + "ABSL_FLAG(%s, %s, {}, %s);" + % ( + FLAG_TYPE[attr["type"]], + "grpc_" + attr["name"], + c_str(attr["description"]), + ), + file=C, + ) print(file=C) print("namespace grpc_core {", file=C) print(file=C) print("ConfigVars::ConfigVars(const Overrides& overrides) :", file=C) initializers = [ - "%s_(LoadConfig(FLAGS_grpc_%s, \"GRPC_%s\", overrides.%s, %s))" % - (attr['name'], attr['name'], attr['name'].upper(), attr['name'], - DEFAULT_VALUE[attr['type']](attr['default'], attr['name'])) + '%s_(LoadConfig(FLAGS_grpc_%s, "GRPC_%s", overrides.%s, %s))' + % ( + attr["name"], + attr["name"], + attr["name"].upper(), + attr["name"], + DEFAULT_VALUE[attr["type"]](attr["default"], attr["name"]), + ) for attr in attrs_in_packing_order - if attr.get('force-load-on-access', False) == False + if attr.get("force-load-on-access", False) == False ] initializers += [ - "override_%s_(overrides.%s)" % (attr['name'], attr['name']) + "override_%s_(overrides.%s)" % (attr["name"], attr["name"]) for attr in attrs_in_packing_order - if attr.get('force-load-on-access', False) + if attr.get("force-load-on-access", False) ] print(",".join(initializers), file=C) print("{}", file=C) print(file=C) for attr in attrs: - if attr.get('force-load-on-access', False): + if attr.get("force-load-on-access", False): print( - "%s ConfigVars::%s() const { return LoadConfig(FLAGS_grpc_%s, \"GRPC_%s\", override_%s_, %s); }" - % (MEMBER_TYPE[attr['type']], snake_to_pascal(attr['name']), - attr['name'], attr['name'].upper(), attr['name'], - DEFAULT_VALUE[attr['type']](attr['default'], attr['name'])), - file=C) + "%s ConfigVars::%s() const { return LoadConfig(FLAGS_grpc_%s," + ' "GRPC_%s", override_%s_, %s); }' + % ( + MEMBER_TYPE[attr["type"]], + snake_to_pascal(attr["name"]), + attr["name"], + attr["name"].upper(), + attr["name"], + DEFAULT_VALUE[attr["type"]](attr["default"], attr["name"]), + ), + file=C, + ) print(file=C) print("std::string ConfigVars::ToString() const {", file=C) print(" return absl::StrCat(", file=C) for i, attr in enumerate(attrs): if i: print(",", file=C) - print(c_str(", " + attr['name'] + ": "), file=C) + print(c_str(", " + attr["name"] + ": "), file=C) else: - print(c_str(attr['name'] + ": "), file=C) - print(",", - TO_STRING[attr['type']].replace( - "$", - snake_to_pascal(attr['name']) + "()"), - file=C) + print(c_str(attr["name"] + ": "), file=C) + print( + ",", + TO_STRING[attr["type"]].replace( + "$", snake_to_pascal(attr["name"]) + "()" + ), + file=C, + ) print(");}", file=C) print(file=C) print("}", file=C) diff --git a/tools/codegen/core/gen_experiments.py b/tools/codegen/core/gen_experiments.py index 31d88f99bdbae..7984ee1361544 100755 --- a/tools/codegen/core/gen_experiments.py +++ b/tools/codegen/core/gen_experiments.py @@ -29,31 +29,31 @@ import yaml 
DEFAULTS = { - 'broken': 'false', - False: 'false', - True: 'true', - 'debug': 'kDefaultForDebugOnly', + "broken": "false", + False: "false", + True: "true", + "debug": "kDefaultForDebugOnly", } FINAL_RETURN = { - 'broken': 'return false;', - False: 'return false;', - True: 'return true;', - 'debug': '\n#ifdef NDEBUG\nreturn false;\n#else\nreturn true;\n#endif\n', + "broken": "return false;", + False: "return false;", + True: "return true;", + "debug": "\n#ifdef NDEBUG\nreturn false;\n#else\nreturn true;\n#endif\n", } FINAL_DEFINE = { - 'broken': None, + "broken": None, False: None, - True: '#define %s', - 'debug': '#ifndef NDEBUG\n#define %s\n#endif', + True: "#define %s", + "debug": "#ifndef NDEBUG\n#define %s\n#endif", } BZL_LIST_FOR_DEFAULTS = { - 'broken': None, - False: 'off', - True: 'on', - 'debug': 'dbg', + "broken": None, + False: "off", + True: "on", + "debug": "dbg", } @@ -70,45 +70,46 @@ def ParseCommandLineArguments(args): # intentionally, We want more formatting than this class can provide. flag_parser = argparse.ArgumentParser() flag_parser.add_argument( - '--check', - action='store_false', - help='If specified, disables checking experiment expiry dates', + "--check", + action="store_false", + help="If specified, disables checking experiment expiry dates", ) flag_parser.add_argument( - '--disable_gen_hdrs', - action='store_true', - help='If specified, disables generation of experiments hdr files', + "--disable_gen_hdrs", + action="store_true", + help="If specified, disables generation of experiments hdr files", ) flag_parser.add_argument( - '--disable_gen_srcs', - action='store_true', - help='If specified, disables generation of experiments source files', + "--disable_gen_srcs", + action="store_true", + help="If specified, disables generation of experiments source files", ) flag_parser.add_argument( - '--disable_gen_bzl', - action='store_true', - help='If specified, disables generation of experiments.bzl file', + "--disable_gen_bzl", + action="store_true", + help="If specified, disables generation of experiments.bzl file", ) return flag_parser.parse_args(args) args = ParseCommandLineArguments(sys.argv[1:]) -with open('src/core/lib/experiments/experiments.yaml') as f: +with open("src/core/lib/experiments/experiments.yaml") as f: attrs = yaml.safe_load(f.read()) -with open('src/core/lib/experiments/rollouts.yaml') as f: +with open("src/core/lib/experiments/rollouts.yaml") as f: rollouts = yaml.safe_load(f.read()) -compiler = exp.ExperimentsCompiler(DEFAULTS, FINAL_RETURN, FINAL_DEFINE, - BZL_LIST_FOR_DEFAULTS) +compiler = exp.ExperimentsCompiler( + DEFAULTS, FINAL_RETURN, FINAL_DEFINE, BZL_LIST_FOR_DEFAULTS +) experiment_annotation = "gRPC Experiments: " for attr in attrs: exp_definition = exp.ExperimentDefinition(attr) if not exp_definition.IsValid(args.check): sys.exit(1) - experiment_annotation += exp_definition.name + ':0,' + experiment_annotation += exp_definition.name + ":0," if not compiler.AddExperimentDefinition(exp_definition): print("Experiment = %s ERROR adding" % exp_definition.name) sys.exit(1) @@ -124,12 +125,12 @@ def ParseCommandLineArguments(args): if not args.disable_gen_hdrs: print("Generating experiments headers") - compiler.GenerateExperimentsHdr('src/core/lib/experiments/experiments.h') + compiler.GenerateExperimentsHdr("src/core/lib/experiments/experiments.h") if not args.disable_gen_srcs: print("Generating experiments srcs") - compiler.GenerateExperimentsSrc('src/core/lib/experiments/experiments.cc') + 
compiler.GenerateExperimentsSrc("src/core/lib/experiments/experiments.cc") if not args.disable_gen_bzl: print("Generating experiments.bzl") - compiler.GenExperimentsBzl('bazel/experiments.bzl') + compiler.GenExperimentsBzl("bazel/experiments.bzl") diff --git a/tools/codegen/core/gen_grpc_tls_credentials_options.py b/tools/codegen/core/gen_grpc_tls_credentials_options.py index ceec9d1960594..7aae7774353e2 100755 --- a/tools/codegen/core/gen_grpc_tls_credentials_options.py +++ b/tools/codegen/core/gen_grpc_tls_credentials_options.py @@ -35,138 +35,198 @@ class DataMember: test_name: str # The name to use for the associated test test_value_1: str # Test-specific value to use for comparison test_value_2: str # Test-specific value (different from test_value_1) - default_initializer: str = '' # If non-empty, this will be used as the default initialization of this field - getter_comment: str = '' # Comment to add before the getter for this field - special_getter_return_type: str = '' # Override for the return type of getter (eg. const std::string&) - override_getter: str = '' # Override for the entire getter method. Relevant for certificate_verifier and certificate_provider - setter_comment: str = '' # Commend to add before the setter for this field + default_initializer: str = ( # If non-empty, this will be used as the default initialization of this field + "" + ) + getter_comment: str = "" # Comment to add before the getter for this field + special_getter_return_type: str = ( # Override for the return type of getter (eg. const std::string&) + "" + ) + override_getter: str = ( # Override for the entire getter method. Relevant for certificate_verifier and certificate_provider + "" + ) + setter_comment: str = "" # Commend to add before the setter for this field setter_move_semantics: bool = False # Should the setter use move-semantics - special_comparator: str = '' # If non-empty, this will be used in `operator==` + special_comparator: str = ( # If non-empty, this will be used in `operator==` + "" + ) _DATA_MEMBERS = [ - DataMember(name='cert_request_type', - type='grpc_ssl_client_certificate_request_type', - default_initializer='GRPC_SSL_DONT_REQUEST_CLIENT_CERTIFICATE', - test_name="DifferentCertRequestType", - test_value_1="GRPC_SSL_DONT_REQUEST_CLIENT_CERTIFICATE", - test_value_2="GRPC_SSL_REQUEST_CLIENT_CERTIFICATE_AND_VERIFY"), - DataMember(name='verify_server_cert', - type='bool', - default_initializer='true', - test_name="DifferentVerifyServerCert", - test_value_1="false", - test_value_2="true"), - DataMember(name='min_tls_version', - type='grpc_tls_version', - default_initializer='grpc_tls_version::TLS1_2', - test_name="DifferentMinTlsVersion", - test_value_1="grpc_tls_version::TLS1_2", - test_value_2="grpc_tls_version::TLS1_3"), - DataMember(name='max_tls_version', - type='grpc_tls_version', - default_initializer='grpc_tls_version::TLS1_3', - test_name="DifferentMaxTlsVersion", - test_value_1="grpc_tls_version::TLS1_2", - test_value_2="grpc_tls_version::TLS1_3"), DataMember( - name='certificate_verifier', - type='grpc_core::RefCountedPtr', + name="cert_request_type", + type="grpc_ssl_client_certificate_request_type", + default_initializer="GRPC_SSL_DONT_REQUEST_CLIENT_CERTIFICATE", + test_name="DifferentCertRequestType", + test_value_1="GRPC_SSL_DONT_REQUEST_CLIENT_CERTIFICATE", + test_value_2="GRPC_SSL_REQUEST_CLIENT_CERTIFICATE_AND_VERIFY", + ), + DataMember( + name="verify_server_cert", + type="bool", + default_initializer="true", + test_name="DifferentVerifyServerCert", + 
test_value_1="false", + test_value_2="true", + ), + DataMember( + name="min_tls_version", + type="grpc_tls_version", + default_initializer="grpc_tls_version::TLS1_2", + test_name="DifferentMinTlsVersion", + test_value_1="grpc_tls_version::TLS1_2", + test_value_2="grpc_tls_version::TLS1_3", + ), + DataMember( + name="max_tls_version", + type="grpc_tls_version", + default_initializer="grpc_tls_version::TLS1_3", + test_name="DifferentMaxTlsVersion", + test_value_1="grpc_tls_version::TLS1_2", + test_value_2="grpc_tls_version::TLS1_3", + ), + DataMember( + name="certificate_verifier", + type="grpc_core::RefCountedPtr", override_getter="""grpc_tls_certificate_verifier* certificate_verifier() { return certificate_verifier_.get(); }""", setter_move_semantics=True, - special_comparator= - '(certificate_verifier_ == other.certificate_verifier_ || (certificate_verifier_ != nullptr && other.certificate_verifier_ != nullptr && certificate_verifier_->Compare(other.certificate_verifier_.get()) == 0))', + special_comparator=( + "(certificate_verifier_ == other.certificate_verifier_ ||" + " (certificate_verifier_ != nullptr && other.certificate_verifier_" + " != nullptr &&" + " certificate_verifier_->Compare(other.certificate_verifier_.get())" + " == 0))" + ), test_name="DifferentCertificateVerifier", test_value_1="MakeRefCounted()", - test_value_2="MakeRefCounted(nullptr, \"\")"), - DataMember(name='check_call_host', - type='bool', - default_initializer='true', - test_name="DifferentCheckCallHost", - test_value_1="false", - test_value_2="true"), + test_value_2='MakeRefCounted(nullptr, "")', + ), DataMember( - name='certificate_provider', - type='grpc_core::RefCountedPtr', - getter_comment= - 'Returns the distributor from certificate_provider_ if it is set, nullptr otherwise.', - override_getter= - """grpc_tls_certificate_distributor* certificate_distributor() { + name="check_call_host", + type="bool", + default_initializer="true", + test_name="DifferentCheckCallHost", + test_value_1="false", + test_value_2="true", + ), + DataMember( + name="certificate_provider", + type="grpc_core::RefCountedPtr", + getter_comment=( + "Returns the distributor from certificate_provider_ if it is set," + " nullptr otherwise." + ), + override_getter="""grpc_tls_certificate_distributor* certificate_distributor() { if (certificate_provider_ != nullptr) { return certificate_provider_->distributor().get(); } return nullptr; }""", setter_move_semantics=True, - special_comparator= - '(certificate_provider_ == other.certificate_provider_ || (certificate_provider_ != nullptr && other.certificate_provider_ != nullptr && certificate_provider_->Compare(other.certificate_provider_.get()) == 0))', + special_comparator=( + "(certificate_provider_ == other.certificate_provider_ ||" + " (certificate_provider_ != nullptr && other.certificate_provider_" + " != nullptr &&" + " certificate_provider_->Compare(other.certificate_provider_.get())" + " == 0))" + ), test_name="DifferentCertificateProvider", - test_value_1= - "MakeRefCounted(\"root_cert_1\", PemKeyCertPairList())", - test_value_2= - "MakeRefCounted(\"root_cert_2\", PemKeyCertPairList())" + test_value_1=( + 'MakeRefCounted("root_cert_1",' + " PemKeyCertPairList())" + ), + test_value_2=( + 'MakeRefCounted("root_cert_2",' + " PemKeyCertPairList())" + ), ), DataMember( - name='watch_root_cert', - type='bool', - default_initializer='false', - setter_comment= - 'If need to watch the updates of root certificates with name |root_cert_name|. The default value is false. 
If used in tls_credentials, it should always be set to true unless the root certificates are not needed.', + name="watch_root_cert", + type="bool", + default_initializer="false", + setter_comment=( + "If need to watch the updates of root certificates with name" + " |root_cert_name|. The default value is false. If used in" + " tls_credentials, it should always be set to true unless the root" + " certificates are not needed." + ), test_name="DifferentWatchRootCert", test_value_1="false", - test_value_2="true"), + test_value_2="true", + ), DataMember( - name='root_cert_name', - type='std::string', - special_getter_return_type='const std::string&', - setter_comment= - 'Sets the name of root certificates being watched, if |set_watch_root_cert| is called. If not set, an empty string will be used as the name.', + name="root_cert_name", + type="std::string", + special_getter_return_type="const std::string&", + setter_comment=( + "Sets the name of root certificates being watched, if" + " |set_watch_root_cert| is called. If not set, an empty string will" + " be used as the name." + ), setter_move_semantics=True, test_name="DifferentRootCertName", - test_value_1="\"root_cert_name_1\"", - test_value_2="\"root_cert_name_2\""), + test_value_1='"root_cert_name_1"', + test_value_2='"root_cert_name_2"', + ), DataMember( - name='watch_identity_pair', - type='bool', - default_initializer='false', - setter_comment= - 'If need to watch the updates of identity certificates with name |identity_cert_name|. The default value is false. If used in tls_credentials, it should always be set to true unless the identity key-cert pairs are not needed.', + name="watch_identity_pair", + type="bool", + default_initializer="false", + setter_comment=( + "If need to watch the updates of identity certificates with name" + " |identity_cert_name|. The default value is false. If used in" + " tls_credentials, it should always be set to true unless the" + " identity key-cert pairs are not needed." + ), test_name="DifferentWatchIdentityPair", test_value_1="false", - test_value_2="true"), + test_value_2="true", + ), DataMember( - name='identity_cert_name', - type='std::string', - special_getter_return_type='const std::string&', - setter_comment= - 'Sets the name of identity key-cert pairs being watched, if |set_watch_identity_pair| is called. If not set, an empty string will be used as the name.', + name="identity_cert_name", + type="std::string", + special_getter_return_type="const std::string&", + setter_comment=( + "Sets the name of identity key-cert pairs being watched, if" + " |set_watch_identity_pair| is called. If not set, an empty string" + " will be used as the name." 
+ ), setter_move_semantics=True, test_name="DifferentIdentityCertName", - test_value_1="\"identity_cert_name_1\"", - test_value_2="\"identity_cert_name_2\""), - DataMember(name='tls_session_key_log_file_path', - type='std::string', - special_getter_return_type='const std::string&', - setter_move_semantics=True, - test_name="DifferentTlsSessionKeyLogFilePath", - test_value_1="\"file_path_1\"", - test_value_2="\"file_path_2\""), + test_value_1='"identity_cert_name_1"', + test_value_2='"identity_cert_name_2"', + ), + DataMember( + name="tls_session_key_log_file_path", + type="std::string", + special_getter_return_type="const std::string&", + setter_move_semantics=True, + test_name="DifferentTlsSessionKeyLogFilePath", + test_value_1='"file_path_1"', + test_value_2='"file_path_2"', + ), DataMember( - name='crl_directory', - type='std::string', - special_getter_return_type='const std::string&', - setter_comment= - ' gRPC will enforce CRLs on all handshakes from all hashed CRL files inside of the crl_directory. If not set, an empty string will be used, which will not enable CRL checking. Only supported for OpenSSL version > 1.1.', + name="crl_directory", + type="std::string", + special_getter_return_type="const std::string&", + setter_comment=( + " gRPC will enforce CRLs on all handshakes from all hashed CRL" + " files inside of the crl_directory. If not set, an empty string" + " will be used, which will not enable CRL checking. Only supported" + " for OpenSSL version > 1.1." + ), setter_move_semantics=True, test_name="DifferentCrlDirectory", - test_value_1="\"crl_directory_1\"", - test_value_2="\"crl_directory_2\"") + test_value_1='"crl_directory_1"', + test_value_2='"crl_directory_2"', + ), ] # print copyright notice from this file def put_copyright(f, year): - print("""// + print( + """// // // Copyright %s gRPC authors. // @@ -183,8 +243,10 @@ def put_copyright(f, year): // limitations under the License. // // -""" % (year), - file=f) +""" + % (year), + file=f, + ) # Prints differences between two files @@ -193,10 +255,9 @@ def get_file_differences(file1, file2): file1_text = f1.readlines() with open(file2) as f2: file2_text = f2.readlines() - return difflib.unified_diff(file1_text, - file2_text, - fromfile=file1, - tofile=file2) + return difflib.unified_diff( + file1_text, file2_text, fromfile=file1, tofile=file2 + ) # Is this script executed in test mode? 
@@ -204,17 +265,20 @@ def get_file_differences(file1, file2): if len(sys.argv) > 1 and sys.argv[1] == "--test": test_mode = True -HEADER_FILE_NAME = 'src/core/lib/security/credentials/tls/grpc_tls_credentials_options.h' +HEADER_FILE_NAME = ( + "src/core/lib/security/credentials/tls/grpc_tls_credentials_options.h" +) # Generate src/core/lib/security/credentials/tls/grpc_tls_credentials_options.h header_file_name = HEADER_FILE_NAME -if (test_mode): +if test_mode: header_file_name = tempfile.NamedTemporaryFile(delete=False).name -H = open(header_file_name, 'w') +H = open(header_file_name, "w") -put_copyright(H, '2018') +put_copyright(H, "2018") print( - '// Generated by tools/codegen/core/gen_grpc_tls_credentials_options.py\n', - file=H) + "// Generated by tools/codegen/core/gen_grpc_tls_credentials_options.py\n", + file=H, +) print( """#ifndef GRPC_SRC_CORE_LIB_SECURITY_CREDENTIALS_TLS_GRPC_TLS_CREDENTIALS_OPTIONS_H #define GRPC_SRC_CORE_LIB_SECURITY_CREDENTIALS_TLS_GRPC_TLS_CREDENTIALS_OPTIONS_H @@ -239,88 +303,128 @@ def get_file_differences(file1, file2): public: ~grpc_tls_credentials_options() override = default; """, - file=H) + file=H, +) # Print out getters for all data members print(" // Getters for member fields.", file=H) for data_member in _DATA_MEMBERS: - if data_member.getter_comment != '': + if data_member.getter_comment != "": print(" // " + data_member.getter_comment, file=H) if data_member.override_getter: print(" " + data_member.override_getter, file=H) else: print( - " %s %s() const { return %s; }" % - (data_member.special_getter_return_type if - data_member.special_getter_return_type != '' else data_member.type, - data_member.name, data_member.name + '_'), - file=H) + " %s %s() const { return %s; }" + % ( + data_member.special_getter_return_type + if data_member.special_getter_return_type != "" + else data_member.type, + data_member.name, + data_member.name + "_", + ), + file=H, + ) # Print out setters for all data members print("", file=H) print(" // Setters for member fields.", file=H) for data_member in _DATA_MEMBERS: - if data_member.setter_comment != '': + if data_member.setter_comment != "": print(" // " + data_member.setter_comment, file=H) - if (data_member.setter_move_semantics): - print(" void set_%s(%s %s) { %s_ = std::move(%s); }" % - (data_member.name, data_member.type, data_member.name, - data_member.name, data_member.name), - file=H) + if data_member.setter_move_semantics: + print( + " void set_%s(%s %s) { %s_ = std::move(%s); }" + % ( + data_member.name, + data_member.type, + data_member.name, + data_member.name, + data_member.name, + ), + file=H, + ) else: - print(" void set_%s(%s %s) { %s_ = %s; }" % - (data_member.name, data_member.type, data_member.name, - data_member.name, data_member.name), - file=H) + print( + " void set_%s(%s %s) { %s_ = %s; }" + % ( + data_member.name, + data_member.type, + data_member.name, + data_member.name, + data_member.name, + ), + file=H, + ) # Write out operator== -print("\n bool operator==(const grpc_tls_credentials_options& other) const {", - file=H) +print( + "\n bool operator==(const grpc_tls_credentials_options& other) const {", + file=H, +) operator_equal_content = " return " for i in range(len(_DATA_MEMBERS)): - if (i != 0): + if i != 0: operator_equal_content += " " - if (_DATA_MEMBERS[i].special_comparator != ''): + if _DATA_MEMBERS[i].special_comparator != "": operator_equal_content += _DATA_MEMBERS[i].special_comparator else: - operator_equal_content += _DATA_MEMBERS[ - i].name + "_ == other." 
+ _DATA_MEMBERS[i].name + "_" - if (i != len(_DATA_MEMBERS) - 1): - operator_equal_content += ' &&\n' + operator_equal_content += ( + _DATA_MEMBERS[i].name + "_ == other." + _DATA_MEMBERS[i].name + "_" + ) + if i != len(_DATA_MEMBERS) - 1: + operator_equal_content += " &&\n" print(operator_equal_content + ";\n }", file=H) -#Print out data member declarations +# Print out data member declarations print("\n private:", file=H) for data_member in _DATA_MEMBERS: - if data_member.default_initializer == '': - print(" %s %s_;" % ( - data_member.type, - data_member.name, - ), file=H) + if data_member.default_initializer == "": + print( + " %s %s_;" + % ( + data_member.type, + data_member.name, + ), + file=H, + ) else: - print(" %s %s_ = %s;" % (data_member.type, data_member.name, - data_member.default_initializer), - file=H) + print( + " %s %s_ = %s;" + % ( + data_member.type, + data_member.name, + data_member.default_initializer, + ), + file=H, + ) # Print out file ending -print("""}; +print( + """}; #endif // GRPC_SRC_CORE_LIB_SECURITY_CREDENTIALS_TLS_GRPC_TLS_CREDENTIALS_OPTIONS_H""", - file=H) + file=H, +) H.close() # Generate test/core/security/grpc_tls_credentials_options_comparator_test.cc -TEST_FILE_NAME = 'test/core/security/grpc_tls_credentials_options_comparator_test.cc' +TEST_FILE_NAME = ( + "test/core/security/grpc_tls_credentials_options_comparator_test.cc" +) test_file_name = TEST_FILE_NAME -if (test_mode): +if test_mode: test_file_name = tempfile.NamedTemporaryFile(delete=False).name -T = open(test_file_name, 'w') +T = open(test_file_name, "w") -put_copyright(T, '2022') -print('// Generated by tools/codegen/core/gen_grpc_tls_credentials_options.py', - file=T) -print(""" +put_copyright(T, "2022") +print( + "// Generated by tools/codegen/core/gen_grpc_tls_credentials_options.py", + file=T, +) +print( + """ #include #include @@ -334,11 +438,13 @@ def get_file_differences(file1, file2): namespace grpc_core { namespace { """, - file=T) + file=T, +) # Generate negative test for each negative member for data_member in _DATA_MEMBERS: - print("""TEST(TlsCredentialsOptionsComparatorTest, %s) { + print( + """TEST(TlsCredentialsOptionsComparatorTest, %s) { auto* options_1 = grpc_tls_credentials_options_create(); auto* options_2 = grpc_tls_credentials_options_create(); options_1->set_%s(%s); @@ -347,12 +453,20 @@ def get_file_differences(file1, file2): EXPECT_FALSE(*options_2 == *options_1); delete options_1; delete options_2; -}""" % (data_member.test_name, data_member.name, data_member.test_value_1, - data_member.name, data_member.test_value_2), - file=T) +}""" + % ( + data_member.test_name, + data_member.name, + data_member.test_value_1, + data_member.name, + data_member.test_value_2, + ), + file=T, + ) # Print out file ending -print(""" +print( + """ } // namespace } // namespace grpc_core @@ -364,10 +478,11 @@ def get_file_differences(file1, file2): grpc_shutdown(); return result; }""", - file=T) + file=T, +) T.close() -if (test_mode): +if test_mode: header_diff = get_file_differences(header_file_name, HEADER_FILE_NAME) test_diff = get_file_differences(test_file_name, TEST_FILE_NAME) os.unlink(header_file_name) @@ -378,8 +493,9 @@ def get_file_differences(file1, file2): header_error = True if header_error: print( - HEADER_FILE_NAME + - ' should not be manually modified. Please make changes to tools/distrib/gen_grpc_tls_credentials_options.py instead.' + HEADER_FILE_NAME + + " should not be manually modified. 
Please make changes to" + " tools/distrib/gen_grpc_tls_credentials_options.py instead." ) test_error = False for line in test_diff: @@ -387,8 +503,9 @@ def get_file_differences(file1, file2): test_error = True if test_error: print( - TEST_FILE_NAME + - ' should not be manually modified. Please make changes to tools/distrib/gen_grpc_tls_credentials_options.py instead.' + TEST_FILE_NAME + + " should not be manually modified. Please make changes to" + " tools/distrib/gen_grpc_tls_credentials_options.py instead." ) - if (header_error or test_error): + if header_error or test_error: sys.exit(1) diff --git a/tools/codegen/core/gen_header_frame.py b/tools/codegen/core/gen_header_frame.py index a501330ec9e05..0d7d4c9511966 100755 --- a/tools/codegen/core/gen_header_frame.py +++ b/tools/codegen/core/gen_header_frame.py @@ -26,20 +26,20 @@ def append_never_indexed(payload_line, n, count, key, value, value_is_huff): payload_line.append(0x10) - assert (len(key) <= 126) + assert len(key) <= 126 payload_line.append(len(key)) payload_line.extend(ord(c) for c in key) - assert (len(value) <= 126) + assert len(value) <= 126 payload_line.append(len(value) + (0x80 if value_is_huff else 0)) payload_line.extend(value) def append_inc_indexed(payload_line, n, count, key, value, value_is_huff): payload_line.append(0x40) - assert (len(key) <= 126) + assert len(key) <= 126 payload_line.append(len(key)) payload_line.extend(ord(c) for c in key) - assert (len(value) <= 126) + assert len(value) <= 126 payload_line.append(len(value) + (0x80 if value_is_huff else 0)) payload_line.extend(value) @@ -50,21 +50,21 @@ def append_pre_indexed(payload_line, n, count, key, value, value_is_huff): def esc_c(line): - out = "\"" + out = '"' last_was_hex = False for c in line: if 32 <= c < 127: if c in hex_bytes and last_was_hex: - out += "\"\"" + out += '""' if c != ord('"'): out += chr(c) else: - out += "\\\"" + out += '\\"' last_was_hex = False else: out += "\\x%02x" % c last_was_hex = True - return out + "\"" + return out + '"' def output_c(payload_bytes): @@ -76,61 +76,61 @@ def output_hex(payload_bytes): all_bytes = [] for line in payload_bytes: all_bytes.extend(line) - print(('{%s}' % ', '.join('0x%02x' % c for c in all_bytes))) + print(("{%s}" % ", ".join("0x%02x" % c for c in all_bytes))) def output_hexstr(payload_bytes): all_bytes = [] for line in payload_bytes: all_bytes.extend(line) - print(('%s' % ''.join('%02x' % c for c in all_bytes))) + print(("%s" % "".join("%02x" % c for c in all_bytes))) _COMPRESSORS = { - 'never': append_never_indexed, - 'inc': append_inc_indexed, - 'pre': append_pre_indexed, + "never": append_never_indexed, + "inc": append_inc_indexed, + "pre": append_pre_indexed, } _OUTPUTS = { - 'c': output_c, - 'hex': output_hex, - 'hexstr': output_hexstr, + "c": output_c, + "hex": output_hex, + "hexstr": output_hexstr, } -argp = argparse.ArgumentParser('Generate header frames') -argp.add_argument('--set_end_stream', - default=False, - action='store_const', - const=True) -argp.add_argument('--no_framing', - default=False, - action='store_const', - const=True) -argp.add_argument('--compression', - choices=sorted(_COMPRESSORS.keys()), - default='never') -argp.add_argument('--huff', default=False, action='store_const', const=True) -argp.add_argument('--output', default='c', choices=sorted(_OUTPUTS.keys())) +argp = argparse.ArgumentParser("Generate header frames") +argp.add_argument( + "--set_end_stream", default=False, action="store_const", const=True +) +argp.add_argument( + "--no_framing", default=False, 
action="store_const", const=True +) +argp.add_argument( + "--compression", choices=sorted(_COMPRESSORS.keys()), default="never" +) +argp.add_argument("--huff", default=False, action="store_const", const=True) +argp.add_argument("--output", default="c", choices=sorted(_OUTPUTS.keys())) args = argp.parse_args() # parse input, fill in vals vals = [] for line in sys.stdin: line = line.strip() - if line == '': + if line == "": continue - if line[0] == '#': + if line[0] == "#": continue - key_tail, value = line[1:].split(':') + key_tail, value = line[1:].split(":") key = (line[0] + key_tail).strip() - value = value.strip().encode('ascii') + value = value.strip().encode("ascii") if args.huff: from hpack.huffman import HuffmanEncoder from hpack.huffman_constants import REQUEST_CODES from hpack.huffman_constants import REQUEST_CODES_LENGTH - value = HuffmanEncoder(REQUEST_CODES, - REQUEST_CODES_LENGTH).encode(value) + + value = HuffmanEncoder(REQUEST_CODES, REQUEST_CODES_LENGTH).encode( + value + ) vals.append((key, value)) # generate frame payload binary data @@ -141,8 +141,9 @@ def output_hexstr(payload_bytes): n = 0 for key, value in vals: payload_line = [] - _COMPRESSORS[args.compression](payload_line, n, len(vals), key, value, - args.huff) + _COMPRESSORS[args.compression]( + payload_line, n, len(vals), key, value, args.huff + ) n += 1 payload_len += len(payload_line) payload_bytes.append(payload_line) @@ -152,20 +153,22 @@ def output_hexstr(payload_bytes): flags = 0x04 # END_HEADERS if args.set_end_stream: flags |= 0x01 # END_STREAM - payload_bytes[0].extend([ - (payload_len >> 16) & 0xff, - (payload_len >> 8) & 0xff, - (payload_len) & 0xff, - # header frame - 0x01, - # flags - flags, - # stream id - 0x00, - 0x00, - 0x00, - 0x01 - ]) + payload_bytes[0].extend( + [ + (payload_len >> 16) & 0xFF, + (payload_len >> 8) & 0xFF, + (payload_len) & 0xFF, + # header frame + 0x01, + # flags + flags, + # stream id + 0x00, + 0x00, + 0x00, + 0x01, + ] + ) hex_bytes = [ord(c) for c in "abcdefABCDEF0123456789"] diff --git a/tools/codegen/core/gen_if_list.py b/tools/codegen/core/gen_if_list.py index d2dad87393472..0411ef3ef8860 100755 --- a/tools/codegen/core/gen_if_list.py +++ b/tools/codegen/core/gen_if_list.py @@ -21,23 +21,23 @@ def put_banner(files, banner): for f in files: for line in banner: - print('// %s' % line, file=f) - print('', file=f) + print("// %s" % line, file=f) + print("", file=f) -with open('src/core/lib/gprpp/if_list.h', 'w') as H: +with open("src/core/lib/gprpp/if_list.h", "w") as H: # copy-paste copyright notice from this file with open(sys.argv[0]) as my_source: copyright = [] for line in my_source: - if line[0] != '#': + if line[0] != "#": break for line in my_source: - if line[0] == '#': + if line[0] == "#": copyright.append(line) break for line in my_source: - if line[0] != '#': + if line[0] != "#": break copyright.append(line) put_banner([H], [line[2:].rstrip() for line in copyright]) @@ -46,31 +46,36 @@ def put_banner(files, banner): print("#ifndef GRPC_CORE_LIB_GPRPP_IF_LIST_H", file=H) print("#define GRPC_CORE_LIB_GPRPP_IF_LIST_H", file=H) - print('', file=H) - print('#include ', file=H) - print('', file=H) + print("", file=H) + print("#include ", file=H) + print("", file=H) print("#include ", file=H) - print('', file=H) + print("", file=H) print("namespace grpc_core {", file=H) for n in range(1, 64): - print('', file=H) + print("", file=H) print( - "template auto IfList(CheckArg input, ActionArg action_arg, ActionFail action_fail, %s, %s) {" + "template auto IfList(CheckArg 
input, ActionArg" + " action_arg, ActionFail action_fail, %s, %s) {" % ( ", ".join("typename Check%d" % (i) for i in range(0, n)), ", ".join("typename Action%d" % (i) for i in range(0, n)), ", ".join("Check%d check%d" % (i, i) for i in range(0, n)), ", ".join("Action%d action%d" % (i, i) for i in range(0, n)), ), - file=H) + file=H, + ) for i in range(0, n): - print(" if (check%d(input)) return action%d(action_arg);" % (i, i), - file=H) + print( + " if (check%d(input)) return action%d(action_arg);" % (i, i), + file=H, + ) print(" return action_fail(action_arg);", file=H) print("}", file=H) - print('', file=H) + print("", file=H) print("}", file=H) - print('', file=H) + print("", file=H) print("#endif // GRPC_CORE_LIB_GPRPP_IF_LIST_H", file=H) diff --git a/tools/codegen/core/gen_server_registered_method_bad_client_test_body.py b/tools/codegen/core/gen_server_registered_method_bad_client_test_body.py index f4f25a69a9363..496894c1514f8 100755 --- a/tools/codegen/core/gen_server_registered_method_bad_client_test_body.py +++ b/tools/codegen/core/gen_server_registered_method_bad_client_test_body.py @@ -16,21 +16,21 @@ def esc_c(line): - out = "\"" + out = '"' last_was_hex = False for c in line: if 32 <= c < 127: if c in hex_bytes and last_was_hex: - out += "\"\"" + out += '""' if c != ord('"'): out += chr(c) else: - out += "\\\"" + out += '\\"' last_was_hex = False else: out += "\\x%02x" % c last_was_hex = True - return out + "\"" + return out + '"' done = set() @@ -38,19 +38,36 @@ def esc_c(line): for message_length in range(0, 3): for send_message_length in range(0, message_length + 1): payload = [ - 0, (message_length >> 24) & 0xff, (message_length >> 16) & 0xff, - (message_length >> 8) & 0xff, (message_length) & 0xff + 0, + (message_length >> 24) & 0xFF, + (message_length >> 16) & 0xFF, + (message_length >> 8) & 0xFF, + (message_length) & 0xFF, ] + send_message_length * [0] for frame_length in range(0, len(payload) + 1): - is_end = frame_length == len( - payload) and send_message_length == message_length - frame = [(frame_length >> 16) & 0xff, (frame_length >> 8) & 0xff, - (frame_length) & 0xff, 0, 1 if is_end else 0, 0, 0, 0, 1 - ] + payload[0:frame_length] + is_end = ( + frame_length == len(payload) + and send_message_length == message_length + ) + frame = [ + (frame_length >> 16) & 0xFF, + (frame_length >> 8) & 0xFF, + (frame_length) & 0xFF, + 0, + 1 if is_end else 0, + 0, + 0, + 0, + 1, + ] + payload[0:frame_length] text = esc_c(frame) if text not in done: print( - ('GRPC_RUN_BAD_CLIENT_TEST(verifier_%s, PFX_STR %s, %s);' % - ('succeeds' if is_end else 'fails', text, - '0' if is_end else 'GRPC_BAD_CLIENT_DISCONNECT'))) + "GRPC_RUN_BAD_CLIENT_TEST(verifier_%s, PFX_STR %s, %s);" + % ( + "succeeds" if is_end else "fails", + text, + "0" if is_end else "GRPC_BAD_CLIENT_DISCONNECT", + ) + ) done.add(text) diff --git a/tools/codegen/core/gen_settings_ids.py b/tools/codegen/core/gen_settings_ids.py index ff532a4f091e0..1ac8740748663 100755 --- a/tools/codegen/core/gen_settings_ids.py +++ b/tools/codegen/core/gen_settings_ids.py @@ -23,48 +23,60 @@ _MAX_HEADER_LIST_SIZE = 16 * 1024 * 1024 -Setting = collections.namedtuple('Setting', 'id default min max on_error') -OnError = collections.namedtuple('OnError', 'behavior code') -clamp_invalid_value = OnError('CLAMP_INVALID_VALUE', 'PROTOCOL_ERROR') -disconnect_on_invalid_value = lambda e: OnError('DISCONNECT_ON_INVALID_VALUE', e - ) -DecoratedSetting = collections.namedtuple('DecoratedSetting', - 'enum name setting') +Setting = 
collections.namedtuple("Setting", "id default min max on_error") +OnError = collections.namedtuple("OnError", "behavior code") +clamp_invalid_value = OnError("CLAMP_INVALID_VALUE", "PROTOCOL_ERROR") +disconnect_on_invalid_value = lambda e: OnError( + "DISCONNECT_ON_INVALID_VALUE", e +) +DecoratedSetting = collections.namedtuple( + "DecoratedSetting", "enum name setting" +) _SETTINGS = { - 'HEADER_TABLE_SIZE': - Setting(1, 4096, 0, 0xffffffff, clamp_invalid_value), - 'ENABLE_PUSH': - Setting(2, 1, 0, 1, disconnect_on_invalid_value('PROTOCOL_ERROR')), - 'MAX_CONCURRENT_STREAMS': - Setting(3, 0xffffffff, 0, 0xffffffff, - disconnect_on_invalid_value('PROTOCOL_ERROR')), - 'INITIAL_WINDOW_SIZE': - Setting(4, 65535, 0, 0x7fffffff, - disconnect_on_invalid_value('FLOW_CONTROL_ERROR')), - 'MAX_FRAME_SIZE': - Setting(5, 16384, 16384, 16777215, - disconnect_on_invalid_value('PROTOCOL_ERROR')), - 'MAX_HEADER_LIST_SIZE': - Setting(6, _MAX_HEADER_LIST_SIZE, 0, _MAX_HEADER_LIST_SIZE, - clamp_invalid_value), - 'GRPC_ALLOW_TRUE_BINARY_METADATA': - Setting(0xfe03, 0, 0, 1, clamp_invalid_value), - 'GRPC_PREFERRED_RECEIVE_CRYPTO_FRAME_SIZE': - Setting(0xfe04, 0, 16384, 0x7fffffff, clamp_invalid_value), + "HEADER_TABLE_SIZE": Setting(1, 4096, 0, 0xFFFFFFFF, clamp_invalid_value), + "ENABLE_PUSH": Setting( + 2, 1, 0, 1, disconnect_on_invalid_value("PROTOCOL_ERROR") + ), + "MAX_CONCURRENT_STREAMS": Setting( + 3, + 0xFFFFFFFF, + 0, + 0xFFFFFFFF, + disconnect_on_invalid_value("PROTOCOL_ERROR"), + ), + "INITIAL_WINDOW_SIZE": Setting( + 4, + 65535, + 0, + 0x7FFFFFFF, + disconnect_on_invalid_value("FLOW_CONTROL_ERROR"), + ), + "MAX_FRAME_SIZE": Setting( + 5, 16384, 16384, 16777215, disconnect_on_invalid_value("PROTOCOL_ERROR") + ), + "MAX_HEADER_LIST_SIZE": Setting( + 6, _MAX_HEADER_LIST_SIZE, 0, _MAX_HEADER_LIST_SIZE, clamp_invalid_value + ), + "GRPC_ALLOW_TRUE_BINARY_METADATA": Setting( + 0xFE03, 0, 0, 1, clamp_invalid_value + ), + "GRPC_PREFERRED_RECEIVE_CRYPTO_FRAME_SIZE": Setting( + 0xFE04, 0, 16384, 0x7FFFFFFF, clamp_invalid_value + ), } -H = open('src/core/ext/transport/chttp2/transport/http2_settings.h', 'w') -C = open('src/core/ext/transport/chttp2/transport/http2_settings.cc', 'w') +H = open("src/core/ext/transport/chttp2/transport/http2_settings.h", "w") +C = open("src/core/ext/transport/chttp2/transport/http2_settings.cc", "w") # utility: print a big comment block into a set of files def put_banner(files, banner): for f in files: - print('/*', file=f) + print("/*", file=f) for line in banner: - print(' * %s' % line, file=f) - print(' */', file=f) + print(" * %s" % line, file=f) + print(" */", file=f) print(file=f) @@ -72,37 +84,42 @@ def put_banner(files, banner): with open(sys.argv[0]) as my_source: copyright = [] for line in my_source: - if line[0] != '#': + if line[0] != "#": break for line in my_source: - if line[0] == '#': + if line[0] == "#": copyright.append(line) break for line in my_source: - if line[0] != '#': + if line[0] != "#": break copyright.append(line) put_banner([H, C], [line[2:].rstrip() for line in copyright]) put_banner( [H, C], - ["Automatically generated by tools/codegen/core/gen_settings_ids.py"]) + ["Automatically generated by tools/codegen/core/gen_settings_ids.py"], +) -print("#ifndef GRPC_CORE_EXT_TRANSPORT_CHTTP2_TRANSPORT_HTTP2_SETTINGS_H", - file=H) -print("#define GRPC_CORE_EXT_TRANSPORT_CHTTP2_TRANSPORT_HTTP2_SETTINGS_H", - file=H) +print( + "#ifndef GRPC_CORE_EXT_TRANSPORT_CHTTP2_TRANSPORT_HTTP2_SETTINGS_H", file=H +) +print( + "#define 
GRPC_CORE_EXT_TRANSPORT_CHTTP2_TRANSPORT_HTTP2_SETTINGS_H", file=H +) print(file=H) print("#include ", file=H) print("#include ", file=H) print(file=H) print("#include ", file=C) -print("#include \"src/core/ext/transport/chttp2/transport/http2_settings.h\"", - file=C) +print( + '#include "src/core/ext/transport/chttp2/transport/http2_settings.h"', + file=C, +) print(file=C) -print("#include \"src/core/lib/gpr/useful.h\"", file=C) -print("#include \"src/core/lib/transport/http2_errors.h\"", file=C) +print('#include "src/core/lib/gpr/useful.h"', file=C) +print('#include "src/core/lib/transport/http2_errors.h"', file=C) print(file=C) p = perfection.hash_parameters(sorted(x.id for x in list(_SETTINGS.values()))) @@ -121,57 +138,77 @@ def hash(i): for name, setting in _SETTINGS.items() ] -print('typedef enum {', file=H) +print("typedef enum {", file=H) for decorated_setting in sorted(decorated_settings): - print(' GRPC_CHTTP2_SETTINGS_%s = %d, /* wire id %d */' % - (decorated_setting.name, decorated_setting.enum, - decorated_setting.setting.id), - file=H) -print('} grpc_chttp2_setting_id;', file=H) + print( + " GRPC_CHTTP2_SETTINGS_%s = %d, /* wire id %d */" + % ( + decorated_setting.name, + decorated_setting.enum, + decorated_setting.setting.id, + ), + file=H, + ) +print("} grpc_chttp2_setting_id;", file=H) print(file=H) -print('#define GRPC_CHTTP2_NUM_SETTINGS %d' % - (max(x.enum for x in decorated_settings) + 1), - file=H) - -print('extern const uint16_t grpc_setting_id_to_wire_id[];', file=H) -print('const uint16_t grpc_setting_id_to_wire_id[] = {%s};' % - ','.join('%d' % s for s in p.slots), - file=C) +print( + "#define GRPC_CHTTP2_NUM_SETTINGS %d" + % (max(x.enum for x in decorated_settings) + 1), + file=H, +) + +print("extern const uint16_t grpc_setting_id_to_wire_id[];", file=H) +print( + "const uint16_t grpc_setting_id_to_wire_id[] = {%s};" + % ",".join("%d" % s for s in p.slots), + file=C, +) print(file=H) print( - "bool grpc_wire_id_to_setting_id(uint32_t wire_id, grpc_chttp2_setting_id *out);", - file=H) + ( + "bool grpc_wire_id_to_setting_id(uint32_t wire_id," + " grpc_chttp2_setting_id *out);" + ), + file=H, +) cgargs = { - 'r': ','.join('%d' % (r if r is not None else 0) for r in p.r), - 't': p.t, - 'offset': abs(p.offset), - 'offset_sign': '+' if p.offset > 0 else '-' + "r": ",".join("%d" % (r if r is not None else 0) for r in p.r), + "t": p.t, + "offset": abs(p.offset), + "offset_sign": "+" if p.offset > 0 else "-", } -print(""" +print( + """ bool grpc_wire_id_to_setting_id(uint32_t wire_id, grpc_chttp2_setting_id *out) { uint32_t i = wire_id %(offset_sign)s %(offset)d; uint32_t x = i %% %(t)d; uint32_t y = i / %(t)d; uint32_t h = x; switch (y) { -""" % cgargs, - file=C) +""" + % cgargs, + file=C, +) for i, r in enumerate(p.r): if not r: continue if r < 0: - print('case %d: h -= %d; break;' % (i, -r), file=C) + print("case %d: h -= %d; break;" % (i, -r), file=C) else: - print('case %d: h += %d; break;' % (i, r), file=C) -print(""" + print("case %d: h += %d; break;" % (i, r), file=C) +print( + """ } *out = static_cast(h); return h < GPR_ARRAY_SIZE(grpc_setting_id_to_wire_id) && grpc_setting_id_to_wire_id[h] == wire_id; } -""" % cgargs, - file=C) +""" + % cgargs, + file=C, +) -print(""" +print( + """ typedef enum { GRPC_CHTTP2_CLAMP_INVALID_VALUE, GRPC_CHTTP2_DISCONNECT_ON_INVALID_VALUE @@ -188,32 +225,46 @@ def hash(i): extern const grpc_chttp2_setting_parameters grpc_chttp2_settings_parameters[GRPC_CHTTP2_NUM_SETTINGS]; """, - file=H) + file=H, +) print( - "const 
grpc_chttp2_setting_parameters grpc_chttp2_settings_parameters[GRPC_CHTTP2_NUM_SETTINGS] = {", - file=C) + ( + "const grpc_chttp2_setting_parameters" + " grpc_chttp2_settings_parameters[GRPC_CHTTP2_NUM_SETTINGS] = {" + ), + file=C, +) i = 0 for decorated_setting in sorted(decorated_settings): while i < decorated_setting.enum: print( - "{NULL, 0, 0, 0, GRPC_CHTTP2_DISCONNECT_ON_INVALID_VALUE, GRPC_HTTP2_PROTOCOL_ERROR},", - file=C) + ( + "{NULL, 0, 0, 0, GRPC_CHTTP2_DISCONNECT_ON_INVALID_VALUE," + " GRPC_HTTP2_PROTOCOL_ERROR}," + ), + file=C, + ) i += 1 - print("{\"%s\", %du, %du, %du, GRPC_CHTTP2_%s, GRPC_HTTP2_%s}," % ( - decorated_setting.name, - decorated_setting.setting.default, - decorated_setting.setting.min, - decorated_setting.setting.max, - decorated_setting.setting.on_error.behavior, - decorated_setting.setting.on_error.code, - ), - file=C) + print( + '{"%s", %du, %du, %du, GRPC_CHTTP2_%s, GRPC_HTTP2_%s},' + % ( + decorated_setting.name, + decorated_setting.setting.default, + decorated_setting.setting.min, + decorated_setting.setting.max, + decorated_setting.setting.on_error.behavior, + decorated_setting.setting.on_error.code, + ), + file=C, + ) i += 1 print("};", file=C) print(file=H) -print("#endif /* GRPC_CORE_EXT_TRANSPORT_CHTTP2_TRANSPORT_HTTP2_SETTINGS_H */", - file=H) +print( + "#endif /* GRPC_CORE_EXT_TRANSPORT_CHTTP2_TRANSPORT_HTTP2_SETTINGS_H */", + file=H, +) H.close() C.close() diff --git a/tools/codegen/core/gen_stats_data.py b/tools/codegen/core/gen_stats_data.py index f4710ecf230b9..cd2e379606b38 100755 --- a/tools/codegen/core/gen_stats_data.py +++ b/tools/codegen/core/gen_stats_data.py @@ -24,35 +24,39 @@ import yaml -with open('src/core/lib/debug/stats_data.yaml') as f: +with open("src/core/lib/debug/stats_data.yaml") as f: attrs = yaml.safe_load(f.read(), Loader=yaml.Loader) -REQUIRED_FIELDS = ['name', 'doc'] +REQUIRED_FIELDS = ["name", "doc"] def make_type(name, fields): - return (collections.namedtuple( - name, ' '.join(list(set(REQUIRED_FIELDS + fields)))), []) + return ( + collections.namedtuple( + name, " ".join(list(set(REQUIRED_FIELDS + fields))) + ), + [], + ) -def c_str(s, encoding='ascii'): +def c_str(s, encoding="ascii"): if isinstance(s, str): s = s.encode(encoding) - result = '' + result = "" for c in s: c = chr(c) if isinstance(c, int) else c - if not (32 <= ord(c) < 127) or c in ('\\', '"'): - result += '\\%03o' % ord(c) + if not (32 <= ord(c) < 127) or c in ("\\", '"'): + result += "\\%03o" % ord(c) else: result += c return '"' + result + '"' types = ( - make_type('Counter', []), - make_type('Histogram', ['max', 'buckets']), + make_type("Counter", []), + make_type("Histogram", ["max", "buckets"]), ) -Shape = collections.namedtuple('Shape', 'max buckets') +Shape = collections.namedtuple("Shape", "max buckets") inst_map = dict((t[0].__name__, t[1]) for t in types) @@ -104,7 +108,7 @@ def find_ideal_shift(mapped_bounds, max_size): def gen_map_table(mapped_bounds, shift_data): - #print("gen_map_table(%s, %s)" % (mapped_bounds, shift_data)) + # print("gen_map_table(%s, %s)" % (mapped_bounds, shift_data)) tbl = [] cur = 0 mapped_bounds = [x >> shift_data[0] for x in mapped_bounds] @@ -132,13 +136,13 @@ def decl_static_table(values, type): def type_for_uint_table(table): mv = max(table) if mv < 2**8: - return 'uint8_t' + return "uint8_t" elif mv < 2**16: - return 'uint16_t' + return "uint16_t" elif mv < 2**32: - return 'uint32_t' + return "uint32_t" else: - return 'uint64_t' + return "uint64_t" def merge_cases(cases): @@ -148,8 +152,11 @@ def 
merge_cases(cases): left_len = l // 2 left = cases[0:left_len] right = cases[left_len:] - return 'if (value < %d) {\n%s\n} else {\n%s\n}' % ( - left[-1][0], merge_cases(left), merge_cases(right)) + return "if (value < %d) {\n%s\n} else {\n%s\n}" % ( + left[-1][0], + merge_cases(left), + merge_cases(right), + ) def gen_bucket_code(shape): @@ -164,7 +171,8 @@ def gen_bucket_code(shape): else: mul = math.pow( float(shape.max) / bounds[-1], - 1.0 / (shape.buckets + 1 - len(bounds))) + 1.0 / (shape.buckets + 1 - len(bounds)), + ) nextb = int(math.ceil(bounds[-1] * mul)) if nextb <= bounds[-1] + 1: nextb = bounds[-1] + 1 @@ -172,18 +180,20 @@ def gen_bucket_code(shape): done_trivial = True first_nontrivial = len(bounds) bounds.append(nextb) - bounds_idx = decl_static_table(bounds, 'int') - #print first_nontrivial, shift_data, bounds - #if shift_data is not None: print [hex(x >> shift_data[0]) for x in code_bounds[first_nontrivial:]] + bounds_idx = decl_static_table(bounds, "int") + # print first_nontrivial, shift_data, bounds + # if shift_data is not None: print [hex(x >> shift_data[0]) for x in code_bounds[first_nontrivial:]] if first_nontrivial is None: - return ('return grpc_core::Clamp(value, 0, %d);\n' % shape.max, - bounds_idx) - cases = [(0, 'return 0;'), (first_nontrivial, 'return value;')] + return ( + "return grpc_core::Clamp(value, 0, %d);\n" % shape.max, + bounds_idx, + ) + cases = [(0, "return 0;"), (first_nontrivial, "return value;")] if done_trivial: first_nontrivial_code = dbl2u64(first_nontrivial) last_code = first_nontrivial_code while True: - code = '' + code = "" first_nontrivial = u642dbl(first_nontrivial_code) code_bounds_index = None for i, b in enumerate(bounds): @@ -191,35 +201,44 @@ def gen_bucket_code(shape): code_bounds_index = i break code_bounds = [dbl2u64(x) - first_nontrivial_code for x in bounds] - shift_data = find_ideal_shift(code_bounds[code_bounds_index:], - 65536) + shift_data = find_ideal_shift( + code_bounds[code_bounds_index:], 65536 + ) if not shift_data: break - map_table = gen_map_table(code_bounds[code_bounds_index:], - shift_data) + map_table = gen_map_table( + code_bounds[code_bounds_index:], shift_data + ) if not map_table: break if map_table[-1] < 5: break map_table_idx = decl_static_table( [x + code_bounds_index for x in map_table], - type_for_uint_table(map_table)) + type_for_uint_table(map_table), + ) last_code = ( - (len(map_table) - 1) << shift_data[0]) + first_nontrivial_code - code += 'DblUint val;\n' - code += 'val.dbl = value;\n' - code += 'const int bucket = ' - code += 'kStatsTable%d[((val.uint - %dull) >> %d)];\n' % ( - map_table_idx, first_nontrivial_code, shift_data[0]) - code += 'return bucket - (value < kStatsTable%d[bucket]);' % bounds_idx + (len(map_table) - 1) << shift_data[0] + ) + first_nontrivial_code + code += "DblUint val;\n" + code += "val.dbl = value;\n" + code += "const int bucket = " + code += "kStatsTable%d[((val.uint - %dull) >> %d)];\n" % ( + map_table_idx, + first_nontrivial_code, + shift_data[0], + ) + code += ( + "return bucket - (value < kStatsTable%d[bucket]);" % bounds_idx + ) cases.append((int(u642dbl(last_code)) + 1, code)) first_nontrivial_code = last_code last = u642dbl(last_code) + 1 for i, b in enumerate(bounds[:-2]): if bounds[i + 1] < last: continue - cases.append((bounds[i + 1], 'return %d;' % i)) - cases.append((None, 'return %d;' % (len(bounds) - 2))) + cases.append((bounds[i + 1], "return %d;" % i)) + cases.append((None, "return %d;" % (len(bounds) - 2))) return (merge_cases(cases), bounds_idx) @@ 
-227,39 +246,39 @@ def gen_bucket_code(shape): def put_banner(files, banner): for f in files: for line in banner: - print('// %s' % line, file=f) + print("// %s" % line, file=f) print(file=f) shapes = set() -for histogram in inst_map['Histogram']: +for histogram in inst_map["Histogram"]: shapes.add(Shape(max=histogram.max, buckets=histogram.buckets)) def snake_to_pascal(name): - return ''.join([x.capitalize() for x in name.split('_')]) + return "".join([x.capitalize() for x in name.split("_")]) -with open('src/core/lib/debug/stats_data.h', 'w') as H: +with open("src/core/lib/debug/stats_data.h", "w") as H: # copy-paste copyright notice from this file with open(sys.argv[0]) as my_source: copyright = [] for line in my_source: - if line[0] != '#': + if line[0] != "#": break for line in my_source: - if line[0] == '#': + if line[0] == "#": copyright.append(line) break for line in my_source: - if line[0] != '#': + if line[0] != "#": break copyright.append(line) put_banner([H], [line[2:].rstrip() for line in copyright]) put_banner( - [H], - ["Automatically generated by tools/codegen/core/gen_stats_data.py"]) + [H], ["Automatically generated by tools/codegen/core/gen_stats_data.py"] + ) print("#ifndef GRPC_SRC_CORE_LIB_DEBUG_STATS_DATA_H", file=H) print("#define GRPC_SRC_CORE_LIB_DEBUG_STATS_DATA_H", file=H) @@ -268,140 +287,186 @@ def snake_to_pascal(name): print("#include ", file=H) print("#include ", file=H) print("#include ", file=H) - print("#include \"src/core/lib/debug/histogram_view.h\"", file=H) - print("#include \"absl/strings/string_view.h\"", file=H) - print("#include \"src/core/lib/gprpp/per_cpu.h\"", file=H) + print('#include "src/core/lib/debug/histogram_view.h"', file=H) + print('#include "absl/strings/string_view.h"', file=H) + print('#include "src/core/lib/gprpp/per_cpu.h"', file=H) print(file=H) print("namespace grpc_core {", file=H) for shape in shapes: - print("class HistogramCollector_%d_%d;" % (shape.max, shape.buckets), - file=H) + print( + "class HistogramCollector_%d_%d;" % (shape.max, shape.buckets), + file=H, + ) print("class Histogram_%d_%d {" % (shape.max, shape.buckets), file=H) print(" public:", file=H) print(" static int BucketFor(int value);", file=H) print(" const uint64_t* buckets() const { return buckets_; }", file=H) print( - " friend Histogram_%d_%d operator-(const Histogram_%d_%d& left, const Histogram_%d_%d& right);" - % (shape.max, shape.buckets, shape.max, shape.buckets, shape.max, - shape.buckets), - file=H) + " friend Histogram_%d_%d operator-(const Histogram_%d_%d& left," + " const Histogram_%d_%d& right);" + % ( + shape.max, + shape.buckets, + shape.max, + shape.buckets, + shape.max, + shape.buckets, + ), + file=H, + ) print(" private:", file=H) - print(" friend class HistogramCollector_%d_%d;" % - (shape.max, shape.buckets), - file=H) + print( + " friend class HistogramCollector_%d_%d;" + % (shape.max, shape.buckets), + file=H, + ) print(" uint64_t buckets_[%d]{};" % shape.buckets, file=H) print("};", file=H) - print("class HistogramCollector_%d_%d {" % (shape.max, shape.buckets), - file=H) + print( + "class HistogramCollector_%d_%d {" % (shape.max, shape.buckets), + file=H, + ) print(" public:", file=H) print(" void Increment(int value) {", file=H) - print(" buckets_[Histogram_%d_%d::BucketFor(value)]" % - (shape.max, shape.buckets), - file=H) + print( + " buckets_[Histogram_%d_%d::BucketFor(value)]" + % (shape.max, shape.buckets), + file=H, + ) print(" .fetch_add(1, std::memory_order_relaxed);", file=H) print(" }", file=H) - print(" void 
Collect(Histogram_%d_%d* result) const;" % - (shape.max, shape.buckets), - file=H) + print( + " void Collect(Histogram_%d_%d* result) const;" + % (shape.max, shape.buckets), + file=H, + ) print(" private:", file=H) print(" std::atomic buckets_[%d]{};" % shape.buckets, file=H) print("};", file=H) print("struct GlobalStats {", file=H) print(" enum class Counter {", file=H) - for ctr in inst_map['Counter']: + for ctr in inst_map["Counter"]: print(" k%s," % snake_to_pascal(ctr.name), file=H) print(" COUNT", file=H) print(" };", file=H) print(" enum class Histogram {", file=H) - for ctr in inst_map['Histogram']: + for ctr in inst_map["Histogram"]: print(" k%s," % snake_to_pascal(ctr.name), file=H) print(" COUNT", file=H) print(" };", file=H) print(" GlobalStats();", file=H) print( - " static const absl::string_view counter_name[static_cast(Counter::COUNT)];", - file=H) + ( + " static const absl::string_view" + " counter_name[static_cast(Counter::COUNT)];" + ), + file=H, + ) print( - " static const absl::string_view histogram_name[static_cast(Histogram::COUNT)];", - file=H) + ( + " static const absl::string_view" + " histogram_name[static_cast(Histogram::COUNT)];" + ), + file=H, + ) print( - " static const absl::string_view counter_doc[static_cast(Counter::COUNT)];", - file=H) + ( + " static const absl::string_view" + " counter_doc[static_cast(Counter::COUNT)];" + ), + file=H, + ) print( - " static const absl::string_view histogram_doc[static_cast(Histogram::COUNT)];", - file=H) + ( + " static const absl::string_view" + " histogram_doc[static_cast(Histogram::COUNT)];" + ), + file=H, + ) print(" union {", file=H) print(" struct {", file=H) - for ctr in inst_map['Counter']: + for ctr in inst_map["Counter"]: print(" uint64_t %s;" % ctr.name, file=H) print(" };", file=H) print(" uint64_t counters[static_cast(Counter::COUNT)];", file=H) print(" };", file=H) - for ctr in inst_map['Histogram']: - print(" Histogram_%d_%d %s;" % (ctr.max, ctr.buckets, ctr.name), - file=H) + for ctr in inst_map["Histogram"]: + print( + " Histogram_%d_%d %s;" % (ctr.max, ctr.buckets, ctr.name), file=H + ) print(" HistogramView histogram(Histogram which) const;", file=H) print( " std::unique_ptr Diff(const GlobalStats& other) const;", - file=H) + file=H, + ) print("};", file=H) print("class GlobalStatsCollector {", file=H) print(" public:", file=H) print(" std::unique_ptr Collect() const;", file=H) - for ctr in inst_map['Counter']: + for ctr in inst_map["Counter"]: print( - " void Increment%s() { data_.this_cpu().%s.fetch_add(1, std::memory_order_relaxed); }" + " void Increment%s() { data_.this_cpu().%s.fetch_add(1," + " std::memory_order_relaxed); }" % (snake_to_pascal(ctr.name), ctr.name), - file=H) - for ctr in inst_map['Histogram']: + file=H, + ) + for ctr in inst_map["Histogram"]: print( - " void Increment%s(int value) { data_.this_cpu().%s.Increment(value); }" + " void Increment%s(int value) {" + " data_.this_cpu().%s.Increment(value); }" % (snake_to_pascal(ctr.name), ctr.name), - file=H) + file=H, + ) print(" private:", file=H) print(" struct Data {", file=H) - for ctr in inst_map['Counter']: + for ctr in inst_map["Counter"]: print(" std::atomic %s{0};" % ctr.name, file=H) - for ctr in inst_map['Histogram']: - print(" HistogramCollector_%d_%d %s;" % - (ctr.max, ctr.buckets, ctr.name), - file=H) + for ctr in inst_map["Histogram"]: + print( + " HistogramCollector_%d_%d %s;" + % (ctr.max, ctr.buckets, ctr.name), + file=H, + ) print(" };", file=H) print( - " PerCpu 
data_{PerCpuOptions().SetCpusPerShard(4).SetMaxShards(32)};", - file=H) + ( + " PerCpu" + " data_{PerCpuOptions().SetCpusPerShard(4).SetMaxShards(32)};" + ), + file=H, + ) print("};", file=H) print("}", file=H) print(file=H) print("#endif // GRPC_SRC_CORE_LIB_DEBUG_STATS_DATA_H", file=H) -with open('src/core/lib/debug/stats_data.cc', 'w') as C: +with open("src/core/lib/debug/stats_data.cc", "w") as C: # copy-paste copyright notice from this file with open(sys.argv[0]) as my_source: copyright = [] for line in my_source: - if line[0] != '#': + if line[0] != "#": break for line in my_source: - if line[0] == '#': + if line[0] == "#": copyright.append(line) break for line in my_source: - if line[0] != '#': + if line[0] != "#": break copyright.append(line) put_banner([C], [line[2:].rstrip() for line in copyright]) put_banner( - [C], - ["Automatically generated by tools/codegen/core/gen_stats_data.py"]) + [C], ["Automatically generated by tools/codegen/core/gen_stats_data.py"] + ) print("#include ", file=C) print(file=C) - print("#include \"src/core/lib/debug/stats_data.h\"", file=C) + print('#include "src/core/lib/debug/stats_data.h"', file=C) print("#include ", file=C) print(file=C) @@ -417,99 +482,142 @@ def snake_to_pascal(name): for shape in shapes: print( - "void HistogramCollector_%d_%d::Collect(Histogram_%d_%d* result) const {" - % (shape.max, shape.buckets, shape.max, shape.buckets), - file=C) + "void HistogramCollector_%d_%d::Collect(Histogram_%d_%d* result)" + " const {" % (shape.max, shape.buckets, shape.max, shape.buckets), + file=C, + ) print(" for (int i=0; i<%d; i++) {" % shape.buckets, file=C) print( - " result->buckets_[i] += buckets_[i].load(std::memory_order_relaxed);", - file=C) + ( + " result->buckets_[i] +=" + " buckets_[i].load(std::memory_order_relaxed);" + ), + file=C, + ) print(" }", file=C) print("}", file=C) print( - "Histogram_%d_%d operator-(const Histogram_%d_%d& left, const Histogram_%d_%d& right) {" - % (shape.max, shape.buckets, shape.max, shape.buckets, shape.max, - shape.buckets), - file=C) + "Histogram_%d_%d operator-(const Histogram_%d_%d& left, const" + " Histogram_%d_%d& right) {" + % ( + shape.max, + shape.buckets, + shape.max, + shape.buckets, + shape.max, + shape.buckets, + ), + file=C, + ) print(" Histogram_%d_%d result;" % (shape.max, shape.buckets), file=C) print(" for (int i=0; i<%d; i++) {" % shape.buckets, file=C) - print(" result.buckets_[i] = left.buckets_[i] - right.buckets_[i];", - file=C) + print( + " result.buckets_[i] = left.buckets_[i] - right.buckets_[i];", + file=C, + ) print(" }", file=C) print(" return result;", file=C) print("}", file=C) for typename, instances in sorted(inst_map.items()): print( - "const absl::string_view GlobalStats::%s_name[static_cast(%s::COUNT)] = {" + "const absl::string_view" + " GlobalStats::%s_name[static_cast(%s::COUNT)] = {" % (typename.lower(), typename), - file=C) + file=C, + ) for inst in instances: print(" %s," % c_str(inst.name), file=C) print("};", file=C) print( - "const absl::string_view GlobalStats::%s_doc[static_cast(%s::COUNT)] = {" + "const absl::string_view" + " GlobalStats::%s_doc[static_cast(%s::COUNT)] = {" % (typename.lower(), typename), - file=C) + file=C, + ) for inst in instances: print(" %s," % c_str(inst.doc), file=C) print("};", file=C) print("namespace {", file=C) for i, tbl in enumerate(static_tables): - print("const %s kStatsTable%d[%d] = {%s};" % - (tbl[0], i, len(tbl[1]), ','.join('%s' % x for x in tbl[1])), - file=C) + print( + "const %s kStatsTable%d[%d] = {%s};" + % (tbl[0], 
i, len(tbl[1]), ",".join("%s" % x for x in tbl[1])), + file=C, + ) print("} // namespace", file=C) for shape, code in zip(shapes, histo_code): - print(("int Histogram_%d_%d::BucketFor(int value) {%s}") % - (shape.max, shape.buckets, code), - file=C) + print( + "int Histogram_%d_%d::BucketFor(int value) {%s}" + % (shape.max, shape.buckets, code), + file=C, + ) - print("GlobalStats::GlobalStats() : %s {}" % - ",".join("%s{0}" % ctr.name for ctr in inst_map['Counter']), - file=C) + print( + "GlobalStats::GlobalStats() : %s {}" + % ",".join("%s{0}" % ctr.name for ctr in inst_map["Counter"]), + file=C, + ) - print("HistogramView GlobalStats::histogram(Histogram which) const {", - file=C) + print( + "HistogramView GlobalStats::histogram(Histogram which) const {", file=C + ) print(" switch (which) {", file=C) print(" default: GPR_UNREACHABLE_CODE(return HistogramView());", file=C) - for inst in inst_map['Histogram']: + for inst in inst_map["Histogram"]: print(" case Histogram::k%s:" % snake_to_pascal(inst.name), file=C) print( - " return HistogramView{&Histogram_%d_%d::BucketFor, kStatsTable%d, %d, %s.buckets()};" - % (inst.max, inst.buckets, histo_bucket_boundaries[Shape( - inst.max, inst.buckets)], inst.buckets, inst.name), - file=C) + " return HistogramView{&Histogram_%d_%d::BucketFor," + " kStatsTable%d, %d, %s.buckets()};" + % ( + inst.max, + inst.buckets, + histo_bucket_boundaries[Shape(inst.max, inst.buckets)], + inst.buckets, + inst.name, + ), + file=C, + ) print(" }", file=C) print("}", file=C) print( "std::unique_ptr GlobalStatsCollector::Collect() const {", - file=C) + file=C, + ) print(" auto result = std::make_unique();", file=C) print(" for (const auto& data : data_) {", file=C) - for ctr in inst_map['Counter']: - print(" result->%s += data.%s.load(std::memory_order_relaxed);" % - (ctr.name, ctr.name), - file=C) - for h in inst_map['Histogram']: + for ctr in inst_map["Counter"]: + print( + " result->%s += data.%s.load(std::memory_order_relaxed);" + % (ctr.name, ctr.name), + file=C, + ) + for h in inst_map["Histogram"]: print(" data.%s.Collect(&result->%s);" % (h.name, h.name), file=C) print(" }", file=C) print(" return result;", file=C) print("}", file=C) print( - "std::unique_ptr GlobalStats::Diff(const GlobalStats& other) const {", - file=C) + ( + "std::unique_ptr GlobalStats::Diff(const GlobalStats&" + " other) const {" + ), + file=C, + ) print(" auto result = std::make_unique();", file=C) - for ctr in inst_map['Counter']: - print(" result->%s = %s - other.%s;" % (ctr.name, ctr.name, ctr.name), - file=C) - for h in inst_map['Histogram']: - print(" result->%s = %s - other.%s;" % (h.name, h.name, h.name), - file=C) + for ctr in inst_map["Counter"]: + print( + " result->%s = %s - other.%s;" % (ctr.name, ctr.name, ctr.name), + file=C, + ) + for h in inst_map["Histogram"]: + print( + " result->%s = %s - other.%s;" % (h.name, h.name, h.name), file=C + ) print(" return result;", file=C) print("}", file=C) diff --git a/tools/codegen/core/gen_switch.py b/tools/codegen/core/gen_switch.py index 02c79700e7766..80e53c6416d8c 100755 --- a/tools/codegen/core/gen_switch.py +++ b/tools/codegen/core/gen_switch.py @@ -20,26 +20,26 @@ # utility: print a big comment block into a set of files def put_banner(files, banner): for f in files: - print('/*', file=f) + print("/*", file=f) for line in banner: - print(' * %s' % line, file=f) - print(' */', file=f) - print('', file=f) + print(" * %s" % line, file=f) + print(" */", file=f) + print("", file=f) -with open('src/core/lib/promise/detail/switch.h', 
'w') as H: +with open("src/core/lib/promise/detail/switch.h", "w") as H: # copy-paste copyright notice from this file with open(sys.argv[0]) as my_source: copyright = [] for line in my_source: - if line[0] != '#': + if line[0] != "#": break for line in my_source: - if line[0] == '#': + if line[0] == "#": copyright.append(line) break for line in my_source: - if line[0] != '#': + if line[0] != "#": break copyright.append(line) put_banner([H], [line[2:].rstrip() for line in copyright]) @@ -48,20 +48,23 @@ def put_banner(files, banner): print("#ifndef GRPC_CORE_LIB_PROMISE_DETAIL_SWITCH_H", file=H) print("#define GRPC_CORE_LIB_PROMISE_DETAIL_SWITCH_H", file=H) - print('', file=H) - print('#include ', file=H) - print('', file=H) + print("", file=H) + print("#include ", file=H) + print("", file=H) print("#include ", file=H) - print('', file=H) + print("", file=H) print("namespace grpc_core {", file=H) for n in range(1, 33): - print('', file=H) - print("template R Switch(char idx, %s) {" % ( - ", ".join("typename F%d" % i for i in range(0, n)), - ", ".join("F%d f%d" % (i, i) for i in range(0, n)), - ), - file=H) + print("", file=H) + print( + "template R Switch(char idx, %s) {" + % ( + ", ".join("typename F%d" % i for i in range(0, n)), + ", ".join("F%d f%d" % (i, i) for i in range(0, n)), + ), + file=H, + ) print(" switch (idx) {", file=H) for i in range(0, n): print(" case %d: return f%d();" % (i, i), file=H) @@ -69,7 +72,7 @@ def put_banner(files, banner): print(" abort();", file=H) print("}", file=H) - print('', file=H) + print("", file=H) print("}", file=H) - print('', file=H) + print("", file=H) print("#endif // GRPC_CORE_LIB_PROMISE_DETAIL_SWITCH_H", file=H) diff --git a/tools/codegen/core/gen_upb_api_from_bazel_xml.py b/tools/codegen/core/gen_upb_api_from_bazel_xml.py index 58b9be7900564..663154fe03196 100755 --- a/tools/codegen/core/gen_upb_api_from_bazel_xml.py +++ b/tools/codegen/core/gen_upb_api_from_bazel_xml.py @@ -38,35 +38,35 @@ import xml.etree.ElementTree # Rule object representing the UPB rule of Bazel BUILD. 
-Rule = collections.namedtuple('Rule', 'name type srcs deps proto_files') +Rule = collections.namedtuple("Rule", "name type srcs deps proto_files") -BAZEL_BIN = 'tools/bazel' +BAZEL_BIN = "tools/bazel" def parse_bazel_rule(elem): - '''Returns a rule from bazel XML rule.''' + """Returns a rule from bazel XML rule.""" srcs = [] deps = [] for child in elem: - if child.tag == 'list' and child.attrib['name'] == 'srcs': + if child.tag == "list" and child.attrib["name"] == "srcs": for tag in child: - if tag.tag == 'label': - srcs.append(tag.attrib['value']) - if child.tag == 'list' and child.attrib['name'] == 'deps': + if tag.tag == "label": + srcs.append(tag.attrib["value"]) + if child.tag == "list" and child.attrib["name"] == "deps": for tag in child: - if tag.tag == 'label': - deps.append(tag.attrib['value']) - if child.tag == 'label': + if tag.tag == "label": + deps.append(tag.attrib["value"]) + if child.tag == "label": # extract actual name for alias rules - label_name = child.attrib['name'] - if label_name in ['actual']: - actual_name = child.attrib.get('value', None) + label_name = child.attrib["name"] + if label_name in ["actual"]: + actual_name = child.attrib.get("value", None) if actual_name: # HACK: since we do a lot of transitive dependency scanning, # make it seem that the actual name is a dependency of the alias rule # (aliases don't have dependencies themselves) deps.append(actual_name) - return Rule(elem.attrib['name'], elem.attrib['class'], srcs, deps, []) + return Rule(elem.attrib["name"], elem.attrib["class"], srcs, deps, []) def get_transitive_protos(rules, t): @@ -84,62 +84,77 @@ def get_transitive_protos(rules, t): visited.add(dep) que.append(dep) for src in rule.srcs: - if src.endswith('.proto'): + if src.endswith(".proto"): ret.append(src) return list(set(ret)) def read_upb_bazel_rules(): - '''Runs bazel query on given package file and returns all upb rules.''' + """Runs bazel query on given package file and returns all upb rules.""" # Use a wrapper version of bazel in gRPC not to use system-wide bazel # to avoid bazel conflict when running on Kokoro. 
result = subprocess.check_output( - [BAZEL_BIN, 'query', '--output', 'xml', '--noimplicit_deps', '//:all']) + [BAZEL_BIN, "query", "--output", "xml", "--noimplicit_deps", "//:all"] + ) root = xml.etree.ElementTree.fromstring(result) rules = [ parse_bazel_rule(elem) for elem in root - if elem.tag == 'rule' and elem.attrib['class'] in [ - 'upb_proto_library', - 'upb_proto_reflection_library', + if elem.tag == "rule" + and elem.attrib["class"] + in [ + "upb_proto_library", + "upb_proto_reflection_library", ] ] # query all dependencies of upb rules to get a list of proto files all_deps = [dep for rule in rules for dep in rule.deps] - result = subprocess.check_output([ - BAZEL_BIN, 'query', '--output', 'xml', '--noimplicit_deps', - ' union '.join('deps({0})'.format(d) for d in all_deps) - ]) + result = subprocess.check_output( + [ + BAZEL_BIN, + "query", + "--output", + "xml", + "--noimplicit_deps", + " union ".join("deps({0})".format(d) for d in all_deps), + ] + ) root = xml.etree.ElementTree.fromstring(result) dep_rules = {} for dep_rule in ( - parse_bazel_rule(elem) for elem in root if elem.tag == 'rule'): + parse_bazel_rule(elem) for elem in root if elem.tag == "rule" + ): dep_rules[dep_rule.name] = dep_rule # add proto files to upb rules transitively for rule in rules: - if not rule.type.startswith('upb_proto_'): + if not rule.type.startswith("upb_proto_"): continue if len(rule.deps) == 1: rule.proto_files.extend( - get_transitive_protos(dep_rules, rule.deps[0])) + get_transitive_protos(dep_rules, rule.deps[0]) + ) return rules def build_upb_bazel_rules(rules): - result = subprocess.check_output([BAZEL_BIN, 'build'] + - [rule.name for rule in rules]) + result = subprocess.check_output( + [BAZEL_BIN, "build"] + [rule.name for rule in rules] + ) def get_upb_path(proto_path, ext): - return proto_path.replace(':', '/').replace('.proto', ext) + return proto_path.replace(":", "/").replace(".proto", ext) def get_bazel_bin_root_path(elink): - BAZEL_BIN_ROOT = 'bazel-bin/' - if elink[0].startswith('@'): + BAZEL_BIN_ROOT = "bazel-bin/" + if elink[0].startswith("@"): # external - result = os.path.join(BAZEL_BIN_ROOT, 'external', - elink[0].replace('@', '').replace('//', '')) + result = os.path.join( + BAZEL_BIN_ROOT, + "external", + elink[0].replace("@", "").replace("//", ""), + ) if elink[1]: result = os.path.join(result, elink[1]) return result @@ -149,70 +164,81 @@ def get_bazel_bin_root_path(elink): def get_external_link(file): - EXTERNAL_LINKS = [('@com_google_protobuf//', 'src/'), - ('@com_google_googleapis//', ''), - ('@com_github_cncf_udpa//', ''), - ('@com_envoyproxy_protoc_gen_validate//', ''), - ('@envoy_api//', ''), ('@opencensus_proto//', '')] + EXTERNAL_LINKS = [ + ("@com_google_protobuf//", "src/"), + ("@com_google_googleapis//", ""), + ("@com_github_cncf_udpa//", ""), + ("@com_envoyproxy_protoc_gen_validate//", ""), + ("@envoy_api//", ""), + ("@opencensus_proto//", ""), + ] for external_link in EXTERNAL_LINKS: if file.startswith(external_link[0]): return external_link - return ('//', '') + return ("//", "") def copy_upb_generated_files(rules, args): files = {} for rule in rules: - if rule.type == 'upb_proto_library': - frag = '.upb' + if rule.type == "upb_proto_library": + frag = ".upb" output_dir = args.upb_out else: - frag = '.upbdefs' + frag = ".upbdefs" output_dir = args.upbdefs_out for proto_file in rule.proto_files: elink = get_external_link(proto_file) prefix_to_strip = elink[0] + elink[1] if not proto_file.startswith(prefix_to_strip): raise Exception( - 'Source file "{0}" in 
does not have the expected prefix "{1}"' - .format(proto_file, prefix_to_strip)) - proto_file = proto_file[len(prefix_to_strip):] - for ext in ('.h', '.c'): + 'Source file "{0}" in does not have the expected prefix' + ' "{1}"'.format(proto_file, prefix_to_strip) + ) + proto_file = proto_file[len(prefix_to_strip) :] + for ext in (".h", ".c"): file = get_upb_path(proto_file, frag + ext) src = os.path.join(get_bazel_bin_root_path(elink), file) dst = os.path.join(output_dir, file) files[src] = dst for src, dst in files.items(): if args.verbose: - print('Copy:') - print(' {0}'.format(src)) - print(' -> {0}'.format(dst)) + print("Copy:") + print(" {0}".format(src)) + print(" -> {0}".format(dst)) os.makedirs(os.path.split(dst)[0], exist_ok=True) shutil.copyfile(src, dst) -parser = argparse.ArgumentParser(description='UPB code-gen from bazel') -parser.add_argument('--verbose', default=False, action='store_true') -parser.add_argument('--upb_out', - default='src/core/ext/upb-generated', - help='Output directory for upb targets') -parser.add_argument('--upbdefs_out', - default='src/core/ext/upbdefs-generated', - help='Output directory for upbdefs targets') +parser = argparse.ArgumentParser(description="UPB code-gen from bazel") +parser.add_argument("--verbose", default=False, action="store_true") +parser.add_argument( + "--upb_out", + default="src/core/ext/upb-generated", + help="Output directory for upb targets", +) +parser.add_argument( + "--upbdefs_out", + default="src/core/ext/upbdefs-generated", + help="Output directory for upbdefs targets", +) def main(): args = parser.parse_args() rules = read_upb_bazel_rules() if args.verbose: - print('Rules:') + print("Rules:") for rule in rules: - print(' name={0} type={1} proto_files={2}'.format( - rule.name, rule.type, rule.proto_files)) + print( + " name={0} type={1} proto_files={2}".format( + rule.name, rule.type, rule.proto_files + ) + ) if rules: build_upb_bazel_rules(rules) copy_upb_generated_files(rules, args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tools/codegen/core/optimize_arena_pool_sizes.py b/tools/codegen/core/optimize_arena_pool_sizes.py index bfae1c6217a6f..e5f094dfd7b1c 100755 --- a/tools/codegen/core/optimize_arena_pool_sizes.py +++ b/tools/codegen/core/optimize_arena_pool_sizes.py @@ -26,8 +26,8 @@ import sys # A single allocation, negative size => free -Allocation = collections.namedtuple('Allocation', 'size ptr') -Active = collections.namedtuple('Active', 'id size') +Allocation = collections.namedtuple("Allocation", "size ptr") +Active = collections.namedtuple("Active", "id size") # Read through all the captures, and build up scrubbed traces arenas = [] @@ -38,8 +38,9 @@ sizes = set() for filename in sys.argv[1:]: for line in open(filename): - m = re.search(r'ARENA 0x([0-9a-f]+) ALLOC ([0-9]+) @ 0x([0-9a-f]+)', - line) + m = re.search( + r"ARENA 0x([0-9a-f]+) ALLOC ([0-9]+) @ 0x([0-9a-f]+)", line + ) if m: size = int(m.group(2)) if size > biggest: @@ -49,13 +50,13 @@ active[m.group(3)] = Active(m.group(1), size) building[m.group(1)].append(size) sizes.add(size) - m = re.search(r'FREE 0x([0-9a-f]+)', line) + m = re.search(r"FREE 0x([0-9a-f]+)", line) if m: # We may have spurious frees, so make sure there's an outstanding allocation last = active.pop(m.group(1), None) if last is not None: building[last.id].append(-last.size) - m = re.search(r'DESTRUCT_ARENA 0x([0-9a-f]+)', line) + m = re.search(r"DESTRUCT_ARENA 0x([0-9a-f]+)", line) if m: trace = building.pop(m.group(1), None) if trace: @@ -89,8 +90,9 
@@ def outstanding_bytes(pool_sizes, trace): def measure(pool_sizes): max_outstanding = 0 for trace in arenas: - max_outstanding = max(max_outstanding, - outstanding_bytes(pool_sizes, trace)) + max_outstanding = max( + max_outstanding, outstanding_bytes(pool_sizes, trace) + ) return max_outstanding @@ -120,8 +122,10 @@ def add(l): m = measure(top) step += 1 if step % 1000 == 0: - print("iter %d; pending=%d; top=%r/%d" % - (step, len(testq), top, measure(top))) + print( + "iter %d; pending=%d; top=%r/%d" + % (step, len(testq), top, measure(top)) + ) for i in sizes: if i >= top[-1]: continue diff --git a/tools/debug/core/chttp2_ref_leak.py b/tools/debug/core/chttp2_ref_leak.py index 9403853a91b2a..a4660f9c1c6b3 100755 --- a/tools/debug/core/chttp2_ref_leak.py +++ b/tools/debug/core/chttp2_ref_leak.py @@ -22,7 +22,7 @@ def new_obj(): - return ['destroy'] + return ["destroy"] outstanding = collections.defaultdict(new_obj) @@ -32,13 +32,14 @@ def new_obj(): for line in sys.stdin: m = re.search( - r'chttp2:( ref|unref):0x([a-fA-F0-9]+) [^ ]+ ([^[]+) \[(.*)\]', line) + r"chttp2:( ref|unref):0x([a-fA-F0-9]+) [^ ]+ ([^[]+) \[(.*)\]", line + ) if m: - if m.group(1) == ' ref': + if m.group(1) == " ref": outstanding[m.group(2)].append(m.group(3)) else: outstanding[m.group(2)].remove(m.group(3)) for obj, remaining in list(outstanding.items()): if remaining: - print(('LEAKED: %s %r' % (obj, remaining))) + print(("LEAKED: %s %r" % (obj, remaining))) diff --git a/tools/debug/core/error_ref_leak.py b/tools/debug/core/error_ref_leak.py index d65ecabfab1b9..1ff04e6455f9a 100644 --- a/tools/debug/core/error_ref_leak.py +++ b/tools/debug/core/error_ref_leak.py @@ -27,21 +27,21 @@ errs = [] for line in data: # if we care about the line - if re.search(r'error.cc', line): + if re.search(r"error.cc", line): # str manip to cut off left part of log line - line = line.partition('error.cc:')[-1] - line = re.sub(r'\d+] ', r'', line) + line = line.partition("error.cc:")[-1] + line = re.sub(r"\d+] ", r"", line) line = line.strip().split() err = line[0].strip(":") if line[1] == "create": - assert (err not in errs) + assert err not in errs errs.append(err) elif line[0] == "realloc": errs.remove(line[1]) errs.append(line[3]) # explicitly look for the last dereference elif line[1] == "1" and line[3] == "0": - assert (err in errs) + assert err in errs errs.remove(err) print(("leaked:", errs)) diff --git a/tools/distrib/add-iwyu.py b/tools/distrib/add-iwyu.py index 5d180c4256d52..496f329bdcab3 100755 --- a/tools/distrib/add-iwyu.py +++ b/tools/distrib/add-iwyu.py @@ -21,7 +21,7 @@ def to_inc(filename): """Given filename, synthesize what should go in an include statement to get that file""" if filename.startswith("include/"): - return '<%s>' % filename[len("include/"):] + return "<%s>" % filename[len("include/") :] return '"%s"' % filename @@ -30,40 +30,40 @@ def set_pragmas(filename, pragmas): lines = [] saw_first_define = False for line in open(filename).read().splitlines(): - if line.startswith('// IWYU pragma: '): + if line.startswith("// IWYU pragma: "): continue lines.append(line) - if not saw_first_define and line.startswith('#define '): + if not saw_first_define and line.startswith("#define "): saw_first_define = True - lines.append('') + lines.append("") for pragma in pragmas: - lines.append('// IWYU pragma: %s' % pragma) - lines.append('') - open(filename, 'w').write('\n'.join(lines) + '\n') + lines.append("// IWYU pragma: %s" % pragma) + lines.append("") + open(filename, "w").write("\n".join(lines) + "\n") def 
set_exports(pub, cg): """In file pub, mark the include for cg with IWYU pragma: export""" lines = [] for line in open(pub).read().splitlines(): - if line.startswith('#include %s' % to_inc(cg)): - lines.append('#include %s // IWYU pragma: export' % to_inc(cg)) + if line.startswith("#include %s" % to_inc(cg)): + lines.append("#include %s // IWYU pragma: export" % to_inc(cg)) else: lines.append(line) - open(pub, 'w').write('\n'.join(lines) + '\n') + open(pub, "w").write("\n".join(lines) + "\n") CG_ROOTS_GRPC = ( - (r'sync', 'grpc/support/sync.h', False), - (r'atm', 'grpc/support/atm.h', False), - (r'grpc_types', 'grpc/grpc.h', True), - (r'gpr_types', 'grpc/grpc.h', True), - (r'compression_types', 'grpc/compression.h', True), - (r'connectivity_state', 'grpc/grpc.h', True), + (r"sync", "grpc/support/sync.h", False), + (r"atm", "grpc/support/atm.h", False), + (r"grpc_types", "grpc/grpc.h", True), + (r"gpr_types", "grpc/grpc.h", True), + (r"compression_types", "grpc/compression.h", True), + (r"connectivity_state", "grpc/grpc.h", True), ) CG_ROOTS_GRPCPP = [ - (r'status_code_enum', 'grpcpp/support/status.h', False), + (r"status_code_enum", "grpcpp/support/status.h", False), ] @@ -74,14 +74,14 @@ def fix_tree(tree, cg_roots): # The same, but for things with '/impl/codegen' in their names cg_reverse_map = collections.defaultdict(list) for root, dirs, files in os.walk(tree): - root_map = cg_reverse_map if '/impl/codegen' in root else reverse_map + root_map = cg_reverse_map if "/impl/codegen" in root else reverse_map for filename in files: root_map[filename].append(root) # For each thing in '/impl/codegen' figure out what exports it for filename, paths in cg_reverse_map.items(): print("****", filename) # Exclude non-headers - if not filename.endswith('.h'): + if not filename.endswith(".h"): continue pragmas = [] # Check for our 'special' headers: if we see one of these, we just @@ -89,16 +89,17 @@ def fix_tree(tree, cg_roots): for root, target, friend in cg_roots: print(root, target, friend) if filename.startswith(root): - pragmas = ['private, include <%s>' % target] + pragmas = ["private, include <%s>" % target] if friend: pragmas.append('friend "src/.*"') if len(paths) == 1: path = paths[0] - if filename.startswith(root + '.'): - set_exports('include/' + target, path + '/' + filename) - if filename.startswith(root + '_'): - set_exports(path + '/' + root + '.h', - path + '/' + filename) + if filename.startswith(root + "."): + set_exports("include/" + target, path + "/" + filename) + if filename.startswith(root + "_"): + set_exports( + path + "/" + root + ".h", path + "/" + filename + ) # If the path for a file in /impl/codegen is ambiguous, just don't bother if not pragmas and len(paths) == 1: path = paths[0] @@ -108,22 +109,22 @@ def fix_tree(tree, cg_roots): # And that it too is unambiguous if len(proper) == 1: # Build the two relevant pathnames - cg = path + '/' + filename - pub = proper[0] + '/' + filename + cg = path + "/" + filename + pub = proper[0] + "/" + filename # And see if the public file actually includes the /impl/codegen file - if ('#include %s' % to_inc(cg)) in open(pub).read(): + if ("#include %s" % to_inc(cg)) in open(pub).read(): # Finally, if it does, we'll set that pragma - pragmas = ['private, include %s' % to_inc(pub)] + pragmas = ["private, include %s" % to_inc(pub)] # And mark the export set_exports(pub, cg) # If we can't find a good alternative include to point people to, # mark things private anyway... 
we don't want to recommend people include # from impl/codegen if not pragmas: - pragmas = ['private'] + pragmas = ["private"] for path in paths: - set_pragmas(path + '/' + filename, pragmas) + set_pragmas(path + "/" + filename, pragmas) -fix_tree('include/grpc', CG_ROOTS_GRPC) -fix_tree('include/grpcpp', CG_ROOTS_GRPCPP) +fix_tree("include/grpc", CG_ROOTS_GRPC) +fix_tree("include/grpcpp", CG_ROOTS_GRPCPP) diff --git a/tools/distrib/yapf_code.sh b/tools/distrib/black_code.sh similarity index 75% rename from tools/distrib/yapf_code.sh rename to tools/distrib/black_code.sh index c1b60d1269fbd..6908409496722 100755 --- a/tools/distrib/yapf_code.sh +++ b/tools/distrib/black_code.sh @@ -15,8 +15,8 @@ set -ex -ACTION=${1:---in-place} -[[ $ACTION == '--in-place' ]] || [[ $ACTION == '--diff' ]] +ACTION="${1:-}" +[[ $ACTION == '' ]] || [[ $ACTION == '--diff' ]] || [[ $ACTION == '--check' ]] # change to root directory cd "$(dirname "${0}")/../.." @@ -29,10 +29,10 @@ DIRS=( 'setup.py' ) -VIRTUALENV=yapf_virtual_environment +VIRTUALENV=black_virtual_environment python3 -m virtualenv $VIRTUALENV -p $(which python3) PYTHON=${VIRTUALENV}/bin/python -"$PYTHON" -m pip install yapf==0.30.0 +"$PYTHON" -m pip install black==23.3.0 -$PYTHON -m yapf $ACTION --parallel --recursive --style=setup.cfg "${DIRS[@]}" -e "**/site-packages/**/*" +$PYTHON -m black --config=black.toml $ACTION "${DIRS[@]}" diff --git a/tools/distrib/c-ish/check_documentation.py b/tools/distrib/c-ish/check_documentation.py index eb4bbc1d7c1c4..59c596adff33b 100755 --- a/tools/distrib/c-ish/check_documentation.py +++ b/tools/distrib/c-ish/check_documentation.py @@ -22,15 +22,19 @@ # where do we run _TARGET_DIRS = [ - 'include/grpc', 'include/grpc++', 'src/core', 'src/cpp', 'test/core', - 'test/cpp' + "include/grpc", + "include/grpc++", + "src/core", + "src/cpp", + "test/core", + "test/cpp", ] # which file extensions do we care about -_INTERESTING_EXTENSIONS = ['.c', '.h', '.cc'] +_INTERESTING_EXTENSIONS = [".c", ".h", ".cc"] # find our home -_ROOT = os.path.abspath(os.path.join(os.path.dirname(sys.argv[0]), '../../..')) +_ROOT = os.path.abspath(os.path.join(os.path.dirname(sys.argv[0]), "../../..")) os.chdir(_ROOT) errors = 0 @@ -39,10 +43,10 @@ printed_banner = False for target_dir in _TARGET_DIRS: for root, dirs, filenames in os.walk(target_dir): - if 'README.md' not in filenames: + if "README.md" not in filenames: if not printed_banner: - print('Missing README.md') - print('=================') + print("Missing README.md") + print("=================") printed_banner = True print(root) errors += 1 @@ -57,12 +61,12 @@ path = os.path.join(root, filename) with open(path) as f: contents = f.read() - if '\\file' not in contents: + if "\\file" not in contents: if not printed_banner: - print('Missing \\file comment') - print('======================') + print("Missing \\file comment") + print("======================") printed_banner = True print(path) errors += 1 -assert errors == 0, 'error count = %d' % errors +assert errors == 0, "error count = %d" % errors diff --git a/tools/distrib/check_copyright.py b/tools/distrib/check_copyright.py index 0e6d45c5d3977..0a61e0853503a 100755 --- a/tools/distrib/check_copyright.py +++ b/tools/distrib/check_copyright.py @@ -22,23 +22,22 @@ import sys # find our home -ROOT = os.path.abspath(os.path.join(os.path.dirname(sys.argv[0]), '../..')) +ROOT = os.path.abspath(os.path.join(os.path.dirname(sys.argv[0]), "../..")) os.chdir(ROOT) # parse command line -argp = argparse.ArgumentParser(description='copyright 
checker') -argp.add_argument('-o', - '--output', - default='details', - choices=['list', 'details']) -argp.add_argument('-s', '--skips', default=0, action='store_const', const=1) -argp.add_argument('-a', '--ancient', default=0, action='store_const', const=1) -argp.add_argument('--precommit', action='store_true') -argp.add_argument('--fix', action='store_true') +argp = argparse.ArgumentParser(description="copyright checker") +argp.add_argument( + "-o", "--output", default="details", choices=["list", "details"] +) +argp.add_argument("-s", "--skips", default=0, action="store_const", const=1) +argp.add_argument("-a", "--ancient", default=0, action="store_const", const=1) +argp.add_argument("--precommit", action="store_true") +argp.add_argument("--fix", action="store_true") args = argp.parse_args() # open the license text -with open('NOTICE.txt') as f: +with open("NOTICE.txt") as f: LICENSE_NOTICE = f.read().splitlines() # license format by file extension @@ -46,28 +45,28 @@ # that given a line of license text, returns what should # be in the file LICENSE_PREFIX_RE = { - '.bat': r'@rem\s*', - '.c': r'\s*(?://|\*)\s*', - '.cc': r'\s*(?://|\*)\s*', - '.h': r'\s*(?://|\*)\s*', - '.m': r'\s*\*\s*', - '.mm': r'\s*\*\s*', - '.php': r'\s*\*\s*', - '.js': r'\s*\*\s*', - '.py': r'#\s*', - '.pyx': r'#\s*', - '.pxd': r'#\s*', - '.pxi': r'#\s*', - '.rb': r'#\s*', - '.sh': r'#\s*', - '.proto': r'//\s*', - '.cs': r'//\s*', - '.mak': r'#\s*', - '.bazel': r'#\s*', - '.bzl': r'#\s*', - 'Makefile': r'#\s*', - 'Dockerfile': r'#\s*', - 'BUILD': r'#\s*', + ".bat": r"@rem\s*", + ".c": r"\s*(?://|\*)\s*", + ".cc": r"\s*(?://|\*)\s*", + ".h": r"\s*(?://|\*)\s*", + ".m": r"\s*\*\s*", + ".mm": r"\s*\*\s*", + ".php": r"\s*\*\s*", + ".js": r"\s*\*\s*", + ".py": r"#\s*", + ".pyx": r"#\s*", + ".pxd": r"#\s*", + ".pxi": r"#\s*", + ".rb": r"#\s*", + ".sh": r"#\s*", + ".proto": r"//\s*", + ".cs": r"//\s*", + ".mak": r"#\s*", + ".bazel": r"#\s*", + ".bzl": r"#\s*", + "Makefile": r"#\s*", + "Dockerfile": r"#\s*", + "BUILD": r"#\s*", } # The key is the file extension, while the value is a tuple of fields @@ -77,126 +76,155 @@ # If header and footer are irrelevant for a specific file extension, they are # set to None. 
LICENSE_PREFIX_TEXT = { - '.bat': (None, '@rem', None), - '.c': (None, '//', None), - '.cc': (None, '//', None), - '.h': (None, '//', None), - '.m': ('/**', ' *', ' */'), - '.mm': ('/**', ' *', ' */'), - '.php': ('/**', ' *', ' */'), - '.js': ('/**', ' *', ' */'), - '.py': (None, '#', None), - '.pyx': (None, '#', None), - '.pxd': (None, '#', None), - '.pxi': (None, '#', None), - '.rb': (None, '#', None), - '.sh': (None, '#', None), - '.proto': (None, '//', None), - '.cs': (None, '//', None), - '.mak': (None, '#', None), - '.bazel': (None, '#', None), - '.bzl': (None, '#', None), - 'Makefile': (None, '#', None), - 'Dockerfile': (None, '#', None), - 'BUILD': (None, '#', None), + ".bat": (None, "@rem", None), + ".c": (None, "//", None), + ".cc": (None, "//", None), + ".h": (None, "//", None), + ".m": ("/**", " *", " */"), + ".mm": ("/**", " *", " */"), + ".php": ("/**", " *", " */"), + ".js": ("/**", " *", " */"), + ".py": (None, "#", None), + ".pyx": (None, "#", None), + ".pxd": (None, "#", None), + ".pxi": (None, "#", None), + ".rb": (None, "#", None), + ".sh": (None, "#", None), + ".proto": (None, "//", None), + ".cs": (None, "//", None), + ".mak": (None, "#", None), + ".bazel": (None, "#", None), + ".bzl": (None, "#", None), + "Makefile": (None, "#", None), + "Dockerfile": (None, "#", None), + "BUILD": (None, "#", None), } -_EXEMPT = frozenset(( - # Generated protocol compiler output. - 'examples/python/helloworld/helloworld_pb2.py', - 'examples/python/helloworld/helloworld_pb2_grpc.py', - 'examples/python/multiplex/helloworld_pb2.py', - 'examples/python/multiplex/helloworld_pb2_grpc.py', - 'examples/python/multiplex/route_guide_pb2.py', - 'examples/python/multiplex/route_guide_pb2_grpc.py', - 'examples/python/route_guide/route_guide_pb2.py', - 'examples/python/route_guide/route_guide_pb2_grpc.py', - - # Generated doxygen config file - 'tools/doxygen/Doxyfile.php', - - # An older file originally from outside gRPC. - 'src/php/tests/bootstrap.php', - # census.proto copied from github - 'tools/grpcz/census.proto', - # status.proto copied from googleapis - 'src/proto/grpc/status/status.proto', - - # Gradle wrappers used to build for Android - 'examples/android/helloworld/gradlew.bat', - 'src/android/test/interop/gradlew.bat', - - # Designer-generated source - 'examples/csharp/HelloworldXamarin/Droid/Resources/Resource.designer.cs', - 'examples/csharp/HelloworldXamarin/iOS/ViewController.designer.cs', - - # BoringSSL generated header. It has commit version information at the head - # of the file so we cannot check the license info. - 'src/boringssl/boringssl_prefix_symbols.h', -)) - -_ENFORCE_CPP_STYLE_COMMENT_PATH_PREFIX = tuple([ - 'include/grpc++/', - 'include/grpcpp/', - 'src/core/', - 'src/cpp/', - 'test/core/', - 'test/cpp/', - 'fuzztest/', -]) - -RE_YEAR = r'Copyright (?P[0-9]+\-)?(?P[0-9]+) ([Tt]he )?gRPC [Aa]uthors(\.|)' +_EXEMPT = frozenset( + ( + # Generated protocol compiler output. + "examples/python/helloworld/helloworld_pb2.py", + "examples/python/helloworld/helloworld_pb2_grpc.py", + "examples/python/multiplex/helloworld_pb2.py", + "examples/python/multiplex/helloworld_pb2_grpc.py", + "examples/python/multiplex/route_guide_pb2.py", + "examples/python/multiplex/route_guide_pb2_grpc.py", + "examples/python/route_guide/route_guide_pb2.py", + "examples/python/route_guide/route_guide_pb2_grpc.py", + # Generated doxygen config file + "tools/doxygen/Doxyfile.php", + # An older file originally from outside gRPC. 
+ "src/php/tests/bootstrap.php", + # census.proto copied from github + "tools/grpcz/census.proto", + # status.proto copied from googleapis + "src/proto/grpc/status/status.proto", + # Gradle wrappers used to build for Android + "examples/android/helloworld/gradlew.bat", + "src/android/test/interop/gradlew.bat", + # Designer-generated source + "examples/csharp/HelloworldXamarin/Droid/Resources/Resource.designer.cs", + "examples/csharp/HelloworldXamarin/iOS/ViewController.designer.cs", + # BoringSSL generated header. It has commit version information at the head + # of the file so we cannot check the license info. + "src/boringssl/boringssl_prefix_symbols.h", + ) +) + +_ENFORCE_CPP_STYLE_COMMENT_PATH_PREFIX = tuple( + [ + "include/grpc++/", + "include/grpcpp/", + "src/core/", + "src/cpp/", + "test/core/", + "test/cpp/", + "fuzztest/", + ] +) + +RE_YEAR = ( + r"Copyright (?P[0-9]+\-)?(?P[0-9]+) ([Tt]he )?gRPC" + r" [Aa]uthors(\.|)" +) RE_LICENSE = dict( - (k, r'\n'.join(LICENSE_PREFIX_RE[k] + - (RE_YEAR if re.search(RE_YEAR, line) else re.escape(line)) - for line in LICENSE_NOTICE)) - for k, v in list(LICENSE_PREFIX_RE.items())) - -RE_C_STYLE_COMMENT_START = r'^/\*\s*\n' -RE_C_STYLE_COMMENT_OPTIONAL_LINE = r'(?:\s*\*\s*\n)*' -RE_C_STYLE_COMMENT_END = r'\s*\*/' -RE_C_STYLE_COMMENT_LICENSE = RE_C_STYLE_COMMENT_START + RE_C_STYLE_COMMENT_OPTIONAL_LINE + r'\n'.join( - r'\s*(?:\*)\s*' + (RE_YEAR if re.search(RE_YEAR, line) else re.escape(line)) + ( + k, + r"\n".join( + LICENSE_PREFIX_RE[k] + + (RE_YEAR if re.search(RE_YEAR, line) else re.escape(line)) + for line in LICENSE_NOTICE + ), + ) + for k, v in list(LICENSE_PREFIX_RE.items()) +) + +RE_C_STYLE_COMMENT_START = r"^/\*\s*\n" +RE_C_STYLE_COMMENT_OPTIONAL_LINE = r"(?:\s*\*\s*\n)*" +RE_C_STYLE_COMMENT_END = r"\s*\*/" +RE_C_STYLE_COMMENT_LICENSE = ( + RE_C_STYLE_COMMENT_START + + RE_C_STYLE_COMMENT_OPTIONAL_LINE + + r"\n".join( + r"\s*(?:\*)\s*" + + (RE_YEAR if re.search(RE_YEAR, line) else re.escape(line)) + for line in LICENSE_NOTICE + ) + + r"\n" + + RE_C_STYLE_COMMENT_OPTIONAL_LINE + + RE_C_STYLE_COMMENT_END +) +RE_CPP_STYLE_COMMENT_LICENSE = r"\n".join( + r"\s*(?://)\s*" + (RE_YEAR if re.search(RE_YEAR, line) else re.escape(line)) for line in LICENSE_NOTICE -) + r'\n' + RE_C_STYLE_COMMENT_OPTIONAL_LINE + RE_C_STYLE_COMMENT_END -RE_CPP_STYLE_COMMENT_LICENSE = r'\n'.join( - r'\s*(?://)\s*' + (RE_YEAR if re.search(RE_YEAR, line) else re.escape(line)) - for line in LICENSE_NOTICE) +) YEAR = datetime.datetime.now().year -LICENSE_YEAR = f'Copyright {YEAR} gRPC authors.' +LICENSE_YEAR = f"Copyright {YEAR} gRPC authors." def join_license_text(header, prefix, footer, notice): - text = (header + '\n') if header else "" + text = (header + "\n") if header else "" def add_prefix(prefix, line): # Don't put whitespace between prefix and empty line to avoid having # trailing whitespaces. 
- return prefix + ('' if len(line) == 0 else ' ') + line + return prefix + ("" if len(line) == 0 else " ") + line - text += '\n'.join( + text += "\n".join( add_prefix(prefix, (LICENSE_YEAR if re.search(RE_YEAR, line) else line)) - for line in LICENSE_NOTICE) - text += '\n' + for line in LICENSE_NOTICE + ) + text += "\n" if footer: - text += footer + '\n' + text += footer + "\n" return text LICENSE_TEXT = dict( - (k, - join_license_text(LICENSE_PREFIX_TEXT[k][0], LICENSE_PREFIX_TEXT[k][1], - LICENSE_PREFIX_TEXT[k][2], LICENSE_NOTICE)) - for k, v in list(LICENSE_PREFIX_TEXT.items())) + ( + k, + join_license_text( + LICENSE_PREFIX_TEXT[k][0], + LICENSE_PREFIX_TEXT[k][1], + LICENSE_PREFIX_TEXT[k][2], + LICENSE_NOTICE, + ), + ) + for k, v in list(LICENSE_PREFIX_TEXT.items()) +) if args.precommit: - FILE_LIST_COMMAND = 'git status -z | grep -Poz \'(?<=^[MARC][MARCD ] )[^\s]+\'' + FILE_LIST_COMMAND = ( + "git status -z | grep -Poz '(?<=^[MARC][MARCD ] )[^\s]+'" + ) else: - FILE_LIST_COMMAND = 'git ls-tree -r --name-only -r HEAD | ' \ - 'grep -v ^third_party/ |' \ - 'grep -v "\(ares_config.h\|ares_build.h\)"' + FILE_LIST_COMMAND = ( + "git ls-tree -r --name-only -r HEAD | " + "grep -v ^third_party/ |" + 'grep -v "\(ares_config.h\|ares_build.h\)"' + ) def load(name): @@ -205,18 +233,18 @@ def load(name): def save(name, text): - with open(name, 'w') as f: + with open(name, "w") as f: f.write(text) -assert (re.search(RE_LICENSE['Makefile'], load('Makefile'))) +assert re.search(RE_LICENSE["Makefile"], load("Makefile")) def log(cond, why, filename): if not cond: return - if args.output == 'details': - print(('%s: %s' % (why, filename))) + if args.output == "details": + print(("%s: %s" % (why, filename))) else: print(filename) @@ -226,18 +254,18 @@ def write_copyright(license_text, file_text, filename): lines = file_text.split("\n") if lines and lines[0].startswith("#!"): shebang = lines[0] + "\n" - file_text = file_text[len(shebang):] + file_text = file_text[len(shebang) :] rewritten_text = shebang + license_text + "\n" + file_text - with open(filename, 'w') as f: + with open(filename, "w") as f: f.write(rewritten_text) def replace_copyright(license_text, file_text, filename): m = re.search(RE_C_STYLE_COMMENT_LICENSE, text) if m: - rewritten_text = license_text + file_text[m.end():] - with open(filename, 'w') as f: + rewritten_text = license_text + file_text[m.end() :] + with open(filename, "w") as f: f.write(rewritten_text) return True return False @@ -247,8 +275,11 @@ def replace_copyright(license_text, file_text, filename): ok = True filename_list = [] try: - filename_list = subprocess.check_output(FILE_LIST_COMMAND, - shell=True).decode().splitlines() + filename_list = ( + subprocess.check_output(FILE_LIST_COMMAND, shell=True) + .decode() + .splitlines() + ) except subprocess.CalledProcessError: sys.exit(0) @@ -257,13 +288,18 @@ def replace_copyright(license_text, file_text, filename): if filename in _EXEMPT: continue # Skip check for upb generated code. 
- if (filename.endswith('.upb.h') or filename.endswith('.upb.c') or - filename.endswith('.upbdefs.h') or filename.endswith('.upbdefs.c')): + if ( + filename.endswith(".upb.h") + or filename.endswith(".upb.c") + or filename.endswith(".upbdefs.h") + or filename.endswith(".upbdefs.c") + ): continue ext = os.path.splitext(filename)[1] base = os.path.basename(filename) if filename.startswith(_ENFORCE_CPP_STYLE_COMMENT_PATH_PREFIX) and ext in [ - '.cc', '.h' + ".cc", + ".h", ]: enforce_cpp_style_comment = True re_license = RE_CPP_STYLE_COMMENT_LICENSE @@ -275,7 +311,7 @@ def replace_copyright(license_text, file_text, filename): re_license = RE_LICENSE[base] license_text = LICENSE_TEXT[base] else: - log(args.skips, 'skip', filename) + log(args.skips, "skip", filename) continue try: text = load(filename) @@ -285,8 +321,11 @@ def replace_copyright(license_text, file_text, filename): if m: pass elif enforce_cpp_style_comment: - log(1, 'copyright missing or does not use cpp-style copyright header', - filename) + log( + 1, + "copyright missing or does not use cpp-style copyright header", + filename, + ) if args.fix: # Attempt fix: search for c-style copyright header and replace it # with cpp-style copyright header. If that doesn't work @@ -294,17 +333,18 @@ def replace_copyright(license_text, file_text, filename): if not replace_copyright(license_text, text, filename): write_copyright(license_text, text, filename) ok = False - elif 'DO NOT EDIT' not in text: + elif "DO NOT EDIT" not in text: if args.fix: write_copyright(license_text, text, filename) - log(1, 'copyright missing (fixed)', filename) + log(1, "copyright missing (fixed)", filename) else: - log(1, 'copyright missing', filename) + log(1, "copyright missing", filename) ok = False if not ok and not args.fix: print( - 'You may use following command to automatically fix copyright headers:') - print(' tools/distrib/check_copyright.py --fix') + "You may use following command to automatically fix copyright headers:" + ) + print(" tools/distrib/check_copyright.py --fix") sys.exit(0 if ok else 1) diff --git a/tools/distrib/check_include_guards.py b/tools/distrib/check_include_guards.py index a8efd358a5873..b42b53b065fc7 100755 --- a/tools/distrib/check_include_guards.py +++ b/tools/distrib/check_include_guards.py @@ -23,66 +23,84 @@ def build_valid_guard(fpath): - guard_components = fpath.replace('++', 'XX').replace('.', - '_').upper().split('/') - if fpath.startswith('include/'): - return '_'.join(guard_components[1:]) + guard_components = ( + fpath.replace("++", "XX").replace(".", "_").upper().split("/") + ) + if fpath.startswith("include/"): + return "_".join(guard_components[1:]) else: - return 'GRPC_' + '_'.join(guard_components) + return "GRPC_" + "_".join(guard_components) def load(fpath): - with open(fpath, 'r') as f: + with open(fpath, "r") as f: return f.read() def save(fpath, contents): - with open(fpath, 'w') as f: + with open(fpath, "w") as f: f.write(contents) class GuardValidator(object): - def __init__(self): - self.ifndef_re = re.compile(r'#ifndef ([A-Z][A-Z_0-9]*)') - self.define_re = re.compile(r'#define ([A-Z][A-Z_0-9]*)') + self.ifndef_re = re.compile(r"#ifndef ([A-Z][A-Z_0-9]*)") + self.define_re = re.compile(r"#define ([A-Z][A-Z_0-9]*)") self.endif_c_core_re = re.compile( - r'#endif /\* (?: *\\\n *)?([A-Z][A-Z_0-9]*) (?:\\\n *)?\*/$') - self.endif_re = re.compile(r'#endif // ([A-Z][A-Z_0-9]*)') + r"#endif /\* (?: *\\\n *)?([A-Z][A-Z_0-9]*) (?:\\\n *)?\*/$" + ) + self.endif_re = re.compile(r"#endif // ([A-Z][A-Z_0-9]*)") 
self.comments_then_includes_re = re.compile( - r'^((//.*?$|/\*.*?\*/|[ \r\n\t])*)(([ \r\n\t]|#include .*)*)(#ifndef [^\n]*\n#define [^\n]*\n)', - re.DOTALL | re.MULTILINE) + ( + r"^((//.*?$|/\*.*?\*/|[ \r\n\t])*)(([ \r\n\t]|#include" + r" .*)*)(#ifndef [^\n]*\n#define [^\n]*\n)" + ), + re.DOTALL | re.MULTILINE, + ) self.failed = False def _is_c_core_header(self, fpath): - return 'include' in fpath and not ( - 'grpc++' in fpath or 'grpcpp' in fpath or 'event_engine' in fpath or - fpath.endswith('/grpc_audit_logging.h') or - fpath.endswith('/json.h')) + return "include" in fpath and not ( + "grpc++" in fpath + or "grpcpp" in fpath + or "event_engine" in fpath + or fpath.endswith("/grpc_audit_logging.h") + or fpath.endswith("/json.h") + ) def fail(self, fpath, regexp, fcontents, match_txt, correct, fix): c_core_header = self._is_c_core_header(fpath) self.failed = True invalid_guards_msg_template = ( - '{0}: Missing preprocessor guards (RE {1}). ' - 'Please wrap your code around the following guards:\n' - '#ifndef {2}\n' - '#define {2}\n' - '...\n' - '... epic code ...\n' - '...\n') + ('#endif /* {2} */' - if c_core_header else '#endif // {2}') + "{0}: Missing preprocessor guards (RE {1}). " + "Please wrap your code around the following guards:\n" + "#ifndef {2}\n" + "#define {2}\n" + "...\n" + "... epic code ...\n" + "...\n" + + ("#endif /* {2} */" if c_core_header else "#endif // {2}") + ) if not match_txt: print( - (invalid_guards_msg_template.format(fpath, regexp.pattern, - build_valid_guard(fpath)))) + ( + invalid_guards_msg_template.format( + fpath, regexp.pattern, build_valid_guard(fpath) + ) + ) + ) return fcontents - print((('{}: Wrong preprocessor guards (RE {}):' - '\n\tFound {}, expected {}').format(fpath, regexp.pattern, - match_txt, correct))) + print( + ( + ( + "{}: Wrong preprocessor guards (RE {}):" + "\n\tFound {}, expected {}" + ).format(fpath, regexp.pattern, match_txt, correct) + ) + ) if fix: - print(('Fixing {}...\n'.format(fpath))) + print("Fixing {}...\n".format(fpath)) fixed_fcontents = re.sub(match_txt, correct, fcontents) if fixed_fcontents: self.failed = False @@ -99,25 +117,27 @@ def check(self, fpath, fix): match = self.ifndef_re.search(fcontents) if not match: - print(('something drastically wrong with: %s' % fpath)) + print(("something drastically wrong with: %s" % fpath)) return False # failed if match.lastindex is None: # No ifndef. Request manual addition with hints - self.fail(fpath, match.re, match.string, '', '', False) + self.fail(fpath, match.re, match.string, "", "", False) return False # failed # Does the guard end with a '_H'? running_guard = match.group(1) - if not running_guard.endswith('_H'): - fcontents = self.fail(fpath, match.re, match.string, match.group(1), - valid_guard, fix) + if not running_guard.endswith("_H"): + fcontents = self.fail( + fpath, match.re, match.string, match.group(1), valid_guard, fix + ) if fix: save(fpath, fcontents) # Is it the expected one based on the file path? if running_guard != valid_guard: - fcontents = self.fail(fpath, match.re, match.string, match.group(1), - valid_guard, fix) + fcontents = self.fail( + fpath, match.re, match.string, match.group(1), valid_guard, fix + ) if fix: save(fpath, fcontents) @@ -125,13 +145,14 @@ def check(self, fpath, fix): match = self.define_re.search(fcontents) if match.lastindex is None: # No define. 
Request manual addition with hints - self.fail(fpath, match.re, match.string, '', '', False) + self.fail(fpath, match.re, match.string, "", "", False) return False # failed # Is the #define guard the same as the #ifndef guard? if match.group(1) != running_guard: - fcontents = self.fail(fpath, match.re, match.string, match.group(1), - valid_guard, fix) + fcontents = self.fail( + fpath, match.re, match.string, match.group(1), valid_guard, fix + ) if fix: save(fpath, fcontents) @@ -139,72 +160,88 @@ def check(self, fpath, fix): flines = fcontents.rstrip().splitlines() # Use findall and use the last result if there are multiple matches, # i.e. nested include guards. - match = self.endif_c_core_re.findall('\n'.join(flines[-3:])) + match = self.endif_c_core_re.findall("\n".join(flines[-3:])) if not match and not c_core_header: - match = self.endif_re.findall('\n'.join(flines[-3:])) + match = self.endif_re.findall("\n".join(flines[-3:])) if not match: # No endif. Check if we have the last line as just '#endif' and if so # replace it with a properly commented one. - if flines[-1] == '#endif': - flines[-1] = ( - '#endif' + - (' /* {} */\n'.format(valid_guard) - if c_core_header else ' // {}\n'.format(valid_guard))) + if flines[-1] == "#endif": + flines[-1] = "#endif" + ( + " /* {} */\n".format(valid_guard) + if c_core_header + else " // {}\n".format(valid_guard) + ) if fix: - fcontents = '\n'.join(flines) + fcontents = "\n".join(flines) save(fpath, fcontents) else: # something else is wrong, bail out self.fail( fpath, self.endif_c_core_re if c_core_header else self.endif_re, - flines[-1], '', '', False) + flines[-1], + "", + "", + False, + ) elif match[-1] != running_guard: # Is the #endif guard the same as the #ifndef and #define guards? - fcontents = self.fail(fpath, self.endif_re, fcontents, match[-1], - valid_guard, fix) + fcontents = self.fail( + fpath, self.endif_re, fcontents, match[-1], valid_guard, fix + ) if fix: save(fpath, fcontents) match = self.comments_then_includes_re.search(fcontents) - assert (match) + assert match bad_includes = match.group(3) if bad_includes: print( "includes after initial comments but before include guards in", - fpath) + fpath, + ) if fix: - fcontents = fcontents[:match.start(3)] + match.group( - 5) + match.group(3) + fcontents[match.end(5):] + fcontents = ( + fcontents[: match.start(3)] + + match.group(5) + + match.group(3) + + fcontents[match.end(5) :] + ) save(fpath, fcontents) return not self.failed # Did the check succeed? 
(ie, not failed) # find our home -ROOT = os.path.abspath(os.path.join(os.path.dirname(sys.argv[0]), '../..')) +ROOT = os.path.abspath(os.path.join(os.path.dirname(sys.argv[0]), "../..")) os.chdir(ROOT) # parse command line -argp = argparse.ArgumentParser(description='include guard checker') -argp.add_argument('-f', '--fix', default=False, action='store_true') -argp.add_argument('--precommit', default=False, action='store_true') +argp = argparse.ArgumentParser(description="include guard checker") +argp.add_argument("-f", "--fix", default=False, action="store_true") +argp.add_argument("--precommit", default=False, action="store_true") args = argp.parse_args() -grep_filter = r"grep -E '^(include|src/core|src/cpp|test/core|test/cpp|fuzztest/)/.*\.h$'" +grep_filter = ( + r"grep -E '^(include|src/core|src/cpp|test/core|test/cpp|fuzztest/)/.*\.h$'" +) if args.precommit: - git_command = 'git diff --name-only HEAD' + git_command = "git diff --name-only HEAD" else: - git_command = 'git ls-tree -r --name-only -r HEAD' + git_command = "git ls-tree -r --name-only -r HEAD" -FILE_LIST_COMMAND = ' | '.join((git_command, grep_filter)) +FILE_LIST_COMMAND = " | ".join((git_command, grep_filter)) # scan files ok = True filename_list = [] try: - filename_list = subprocess.check_output(FILE_LIST_COMMAND, - shell=True).decode().splitlines() + filename_list = ( + subprocess.check_output(FILE_LIST_COMMAND, shell=True) + .decode() + .splitlines() + ) # Filter out non-existent files (ie, file removed or renamed) filename_list = (f for f in filename_list if os.path.isfile(f)) except subprocess.CalledProcessError: @@ -214,8 +251,12 @@ def check(self, fpath, fix): for filename in filename_list: # Skip check for upb generated code. - if (filename.endswith('.upb.h') or filename.endswith('.upb.c') or - filename.endswith('.upbdefs.h') or filename.endswith('.upbdefs.c')): + if ( + filename.endswith(".upb.h") + or filename.endswith(".upb.c") + or filename.endswith(".upbdefs.h") + or filename.endswith(".upbdefs.c") + ): continue ok = ok and validator.check(filename, args.fix) diff --git a/tools/distrib/check_naked_includes.py b/tools/distrib/check_naked_includes.py index 961e376e5dc0a..b3893aa31e69e 100755 --- a/tools/distrib/check_naked_includes.py +++ b/tools/distrib/check_naked_includes.py @@ -22,38 +22,38 @@ import sys # find our home -ROOT = os.path.abspath(os.path.join(os.path.dirname(sys.argv[0]), '../..')) +ROOT = os.path.abspath(os.path.join(os.path.dirname(sys.argv[0]), "../..")) os.chdir(ROOT) # parse command line -argp = argparse.ArgumentParser(description='include guard checker') -argp.add_argument('-f', '--fix', default=False, action='store_true') +argp = argparse.ArgumentParser(description="include guard checker") +argp.add_argument("-f", "--fix", default=False, action="store_true") args = argp.parse_args() # error count errors = 0 CHECK_SUBDIRS = [ - 'src/core', - 'src/cpp', - 'test/core', - 'test/cpp', - 'fuzztest', + "src/core", + "src/cpp", + "test/core", + "test/cpp", + "fuzztest", ] for subdir in CHECK_SUBDIRS: for root, dirs, files in os.walk(subdir): for f in files: - if f.endswith('.h') or f.endswith('.cc'): + if f.endswith(".h") or f.endswith(".cc"): fpath = os.path.join(root, f) - output = open(fpath, 'r').readlines() + output = open(fpath, "r").readlines() changed = False - for (i, line) in enumerate(output): + for i, line in enumerate(output): m = re.match(r'^#include "([^"]*)"(.*)', line) if not m: continue include = m.group(1) - if '/' in include: + if "/" in include: continue expect_path = 
os.path.join(root, include) trailing = m.group(2) @@ -62,12 +62,16 @@ changed = True errors += 1 output[i] = '#include "{0}"{1}\n'.format( - expect_path, trailing) - print("Found naked include '{0}' in {1}".format( - include, fpath)) + expect_path, trailing + ) + print( + "Found naked include '{0}' in {1}".format( + include, fpath + ) + ) if changed and args.fix: - open(fpath, 'w').writelines(output) + open(fpath, "w").writelines(output) if errors > 0: - print('{} errors found.'.format(errors)) + print("{} errors found.".format(errors)) sys.exit(1) diff --git a/tools/distrib/check_namespace_qualification.py b/tools/distrib/check_namespace_qualification.py index 4ec45770dd8c6..bb409f04422ef 100755 --- a/tools/distrib/check_namespace_qualification.py +++ b/tools/distrib/check_namespace_qualification.py @@ -25,27 +25,27 @@ def load(fpath): - with open(fpath, 'r') as f: + with open(fpath, "r") as f: return f.readlines() def save(fpath, contents): - with open(fpath, 'w') as f: + with open(fpath, "w") as f: f.write(contents) class QualificationValidator(object): - def __init__(self): - self.fully_qualified_re = re.compile(r'([ (<])::(grpc[A-Za-z_:])') + self.fully_qualified_re = re.compile(r"([ (<])::(grpc[A-Za-z_:])") self.using_re = re.compile( - r'(using +|using +[A-Za-z_]+ *= *|namespace [A-Za-z_]+ *= *)::') - self.define_re = re.compile(r'^#define') + r"(using +|using +[A-Za-z_]+ *= *|namespace [A-Za-z_]+ *= *)::" + ) + self.define_re = re.compile(r"^#define") def check(self, fpath, fix): fcontents = load(fpath) failed = False - for (i, line) in enumerate(fcontents): + for i, line in enumerate(fcontents): if not self.fully_qualified_re.search(line): continue # skip `using` statements @@ -56,12 +56,12 @@ def check(self, fpath, fix): continue # fully-qualified namespace found, which may be unnecessary if fix: - fcontents[i] = self.fully_qualified_re.sub(r'\1\2', line) + fcontents[i] = self.fully_qualified_re.sub(r"\1\2", line) else: print("Found in %s:%d - %s" % (fpath, i, line.strip())) failed = True if fix: - save(fpath, ''.join(fcontents)) + save(fpath, "".join(fcontents)) return not failed @@ -76,34 +76,38 @@ def check(self, fpath, fix): # multi-line #define statements are not handled "src/core/lib/gprpp/global_config_env.h", "src/core/lib/profiling/timers.h", - "src/core/lib/gprpp/crash.h" + "src/core/lib/gprpp/crash.h", ] # find our home -ROOT = os.path.abspath(os.path.join(os.path.dirname(sys.argv[0]), '../..')) +ROOT = os.path.abspath(os.path.join(os.path.dirname(sys.argv[0]), "../..")) os.chdir(ROOT) # parse command line argp = argparse.ArgumentParser( - description='c++ namespace full qualification checker') -argp.add_argument('-f', '--fix', default=False, action='store_true') -argp.add_argument('--precommit', default=False, action='store_true') + description="c++ namespace full qualification checker" +) +argp.add_argument("-f", "--fix", default=False, action="store_true") +argp.add_argument("--precommit", default=False, action="store_true") args = argp.parse_args() grep_filter = r"grep -E '^(include|src|test).*\.(h|cc)$'" if args.precommit: - git_command = 'git diff --name-only HEAD' + git_command = "git diff --name-only HEAD" else: - git_command = 'git ls-tree -r --name-only -r HEAD' + git_command = "git ls-tree -r --name-only -r HEAD" -FILE_LIST_COMMAND = ' | '.join((git_command, grep_filter)) +FILE_LIST_COMMAND = " | ".join((git_command, grep_filter)) # scan files ok = True filename_list = [] try: - filename_list = subprocess.check_output(FILE_LIST_COMMAND, - 
shell=True).decode().splitlines() + filename_list = ( + subprocess.check_output(FILE_LIST_COMMAND, shell=True) + .decode() + .splitlines() + ) # Filter out non-existent files (ie, file removed or renamed) filename_list = (f for f in filename_list if os.path.isfile(f)) except subprocess.CalledProcessError: @@ -113,9 +117,13 @@ def check(self, fpath, fix): for filename in filename_list: # Skip check for upb generated code and ignored files. - if (filename.endswith('.upb.h') or filename.endswith('.upb.c') or - filename.endswith('.upbdefs.h') or - filename.endswith('.upbdefs.c') or filename in IGNORED_FILES): + if ( + filename.endswith(".upb.h") + or filename.endswith(".upb.c") + or filename.endswith(".upbdefs.h") + or filename.endswith(".upbdefs.c") + or filename in IGNORED_FILES + ): continue ok = validator.check(filename, args.fix) and ok diff --git a/tools/distrib/check_redundant_namespace_qualifiers.py b/tools/distrib/check_redundant_namespace_qualifiers.py index 0241e81d684e3..0322332209b8b 100755 --- a/tools/distrib/check_redundant_namespace_qualifiers.py +++ b/tools/distrib/check_redundant_namespace_qualifiers.py @@ -31,7 +31,7 @@ def find_closing_mustache(contents, initial_depth): if contents[0] == '"': contents = contents[1:] while contents[0] != '"': - if contents.startswith('\\\\'): + if contents.startswith("\\\\"): contents = contents[2:] elif contents.startswith('\\"'): contents = contents[2:] @@ -39,19 +39,22 @@ def find_closing_mustache(contents, initial_depth): contents = contents[1:] contents = contents[1:] # And characters that might confuse us. - elif contents.startswith("'{'") or contents.startswith( - "'\"'") or contents.startswith("'}'"): + elif ( + contents.startswith("'{'") + or contents.startswith("'\"'") + or contents.startswith("'}'") + ): contents = contents[3:] # Skip over comments. elif contents.startswith("//"): - contents = contents[contents.find('\n'):] + contents = contents[contents.find("\n") :] elif contents.startswith("/*"): - contents = contents[contents.find('*/') + 2:] + contents = contents[contents.find("*/") + 2 :] # Count up or down if we see a mustache. 
- elif contents[0] == '{': + elif contents[0] == "{": contents = contents[1:] depth += 1 - elif contents[0] == '}': + elif contents[0] == "}": contents = contents[1:] depth -= 1 if depth == 0: @@ -65,26 +68,32 @@ def find_closing_mustache(contents, initial_depth): def is_a_define_statement(match, body): """See if the matching line begins with #define""" # This does not yet help with multi-line defines - m = re.search(r"^#define.*{}$".format(match.group(0)), body[:match.end()], - re.MULTILINE) + m = re.search( + r"^#define.*{}$".format(match.group(0)), + body[: match.end()], + re.MULTILINE, + ) return m is not None def update_file(contents, namespaces): """Scan the contents of a file, and for top-level namespaces in namespaces remove redundant usages.""" - output = '' + output = "" while contents: - m = re.search(r'namespace ([a-zA-Z0-9_]*) {', contents) + m = re.search(r"namespace ([a-zA-Z0-9_]*) {", contents) if not m: output += contents break - output += contents[:m.end()] - contents = contents[m.end():] + output += contents[: m.end()] + contents = contents[m.end() :] end = find_closing_mustache(contents, 1) if end is None: - print('Failed to find closing mustache for namespace {}'.format( - m.group(1))) - print('Remaining text:') + print( + "Failed to find closing mustache for namespace {}".format( + m.group(1) + ) + ) + print("Remaining text:") print(contents) sys.exit(1) body = contents[:end] @@ -92,18 +101,18 @@ def update_file(contents, namespaces): if namespace in namespaces: while body: # Find instances of 'namespace::' - m = re.search(r'\b' + namespace + r'::\b', body) + m = re.search(r"\b" + namespace + r"::\b", body) if not m: break # Ignore instances of '::namespace::' -- these are usually meant to be there. - if m.start() >= 2 and body[m.start() - 2:].startswith('::'): - output += body[:m.end()] + if m.start() >= 2 and body[m.start() - 2 :].startswith("::"): + output += body[: m.end()] # Ignore #defines, since they may be used anywhere elif is_a_define_statement(m, body): - output += body[:m.end()] + output += body[: m.end()] else: - output += body[:m.start()] - body = body[m.end():] + output += body[: m.start()] + body = body[m.end() :] output += body contents = contents[end:] return output @@ -132,18 +141,22 @@ def update_file(contents, namespaces): ::foo::a; } """ -output = update_file(_TEST, ['foo']) +output = update_file(_TEST, ["foo"]) if output != _TEST_EXPECTED: import difflib - print('FAILED: self check') - print('\n'.join( - difflib.ndiff(_TEST_EXPECTED.splitlines(1), output.splitlines(1)))) + + print("FAILED: self check") + print( + "\n".join( + difflib.ndiff(_TEST_EXPECTED.splitlines(1), output.splitlines(1)) + ) + ) sys.exit(1) # Main loop. 
-Config = collections.namedtuple('Config', ['dirs', 'namespaces']) +Config = collections.namedtuple("Config", ["dirs", "namespaces"]) -_CONFIGURATION = (Config(['src/core', 'test/core'], ['grpc_core']),) +_CONFIGURATION = (Config(["src/core", "test/core"], ["grpc_core"]),) changed = [] @@ -151,7 +164,7 @@ def update_file(contents, namespaces): for dir in config.dirs: for root, dirs, files in os.walk(dir): for file in files: - if file.endswith('.cc') or file.endswith('.h'): + if file.endswith(".cc") or file.endswith(".h"): path = os.path.join(root, file) try: with open(path) as f: @@ -161,11 +174,11 @@ def update_file(contents, namespaces): updated = update_file(contents, config.namespaces) if updated != contents: changed.append(path) - with open(os.path.join(root, file), 'w') as f: + with open(os.path.join(root, file), "w") as f: f.write(updated) if changed: - print('The following files were changed:') + print("The following files were changed:") for path in changed: - print(' ' + path) + print(" " + path) sys.exit(1) diff --git a/tools/distrib/fix_build_deps.py b/tools/distrib/fix_build_deps.py index 612e4694d76f1..6ad3684fdddad 100755 --- a/tools/distrib/fix_build_deps.py +++ b/tools/distrib/fix_build_deps.py @@ -25,7 +25,7 @@ import run_buildozer # find our home -ROOT = os.path.abspath(os.path.join(os.path.dirname(sys.argv[0]), '../..')) +ROOT = os.path.abspath(os.path.join(os.path.dirname(sys.argv[0]), "../..")) os.chdir(ROOT) vendors = collections.defaultdict(list) @@ -41,256 +41,159 @@ # TODO(ctiller): ideally we wouldn't hardcode a bunch of paths here. # We can likely parse out BUILD files from dependencies to generate this index. EXTERNAL_DEPS = { - 'absl/algorithm/container.h': - 'absl/algorithm:container', - 'absl/base/attributes.h': - 'absl/base:core_headers', - 'absl/base/call_once.h': - 'absl/base', + "absl/algorithm/container.h": "absl/algorithm:container", + "absl/base/attributes.h": "absl/base:core_headers", + "absl/base/call_once.h": "absl/base", # TODO(ctiller) remove this - 'absl/base/internal/endian.h': - 'absl/base', - 'absl/base/thread_annotations.h': - 'absl/base:core_headers', - 'absl/container/flat_hash_map.h': - 'absl/container:flat_hash_map', - 'absl/container/flat_hash_set.h': - 'absl/container:flat_hash_set', - 'absl/container/inlined_vector.h': - 'absl/container:inlined_vector', - 'absl/cleanup/cleanup.h': - 'absl/cleanup', - 'absl/debugging/failure_signal_handler.h': - 'absl/debugging:failure_signal_handler', - 'absl/debugging/stacktrace.h': - 'absl/debugging:stacktrace', - 'absl/debugging/symbolize.h': - 'absl/debugging:symbolize', - 'absl/flags/flag.h': - 'absl/flags:flag', - 'absl/flags/marshalling.h': - 'absl/flags:marshalling', - 'absl/flags/parse.h': - 'absl/flags:parse', - 'absl/functional/any_invocable.h': - 'absl/functional:any_invocable', - 'absl/functional/bind_front.h': - 'absl/functional:bind_front', - 'absl/functional/function_ref.h': - 'absl/functional:function_ref', - 'absl/hash/hash.h': - 'absl/hash', - 'absl/memory/memory.h': - 'absl/memory', - 'absl/meta/type_traits.h': - 'absl/meta:type_traits', - 'absl/numeric/int128.h': - 'absl/numeric:int128', - 'absl/random/random.h': - 'absl/random', - 'absl/random/distributions.h': - 'absl/random:distributions', - 'absl/random/uniform_int_distribution.h': - 'absl/random:distributions', - 'absl/status/status.h': - 'absl/status', - 'absl/status/statusor.h': - 'absl/status:statusor', - 'absl/strings/ascii.h': - 'absl/strings', - 'absl/strings/cord.h': - 'absl/strings:cord', - 'absl/strings/escaping.h': 
- 'absl/strings', - 'absl/strings/match.h': - 'absl/strings', - 'absl/strings/numbers.h': - 'absl/strings', - 'absl/strings/str_cat.h': - 'absl/strings', - 'absl/strings/str_format.h': - 'absl/strings:str_format', - 'absl/strings/str_join.h': - 'absl/strings', - 'absl/strings/str_replace.h': - 'absl/strings', - 'absl/strings/str_split.h': - 'absl/strings', - 'absl/strings/string_view.h': - 'absl/strings', - 'absl/strings/strip.h': - 'absl/strings', - 'absl/strings/substitute.h': - 'absl/strings', - 'absl/synchronization/mutex.h': - 'absl/synchronization', - 'absl/synchronization/notification.h': - 'absl/synchronization', - 'absl/time/clock.h': - 'absl/time', - 'absl/time/time.h': - 'absl/time', - 'absl/types/optional.h': - 'absl/types:optional', - 'absl/types/span.h': - 'absl/types:span', - 'absl/types/variant.h': - 'absl/types:variant', - 'absl/utility/utility.h': - 'absl/utility', - 'address_sorting/address_sorting.h': - 'address_sorting', - 'ares.h': - 'cares', - 'fuzztest/fuzztest.h': ['fuzztest', 'fuzztest_main'], - 'google/api/monitored_resource.pb.h': - 'google/api:monitored_resource_cc_proto', - 'google/devtools/cloudtrace/v2/tracing.grpc.pb.h': - 'googleapis_trace_grpc_service', - 'google/logging/v2/logging.grpc.pb.h': - 'googleapis_logging_grpc_service', - 'google/logging/v2/logging.pb.h': - 'googleapis_logging_cc_proto', - 'google/logging/v2/log_entry.pb.h': - 'googleapis_logging_cc_proto', - 'google/monitoring/v3/metric_service.grpc.pb.h': - 'googleapis_monitoring_grpc_service', - 'gmock/gmock.h': - 'gtest', - 'gtest/gtest.h': - 'gtest', - 'opencensus/exporters/stats/stackdriver/stackdriver_exporter.h': - 'opencensus-stats-stackdriver_exporter', - 'opencensus/exporters/trace/stackdriver/stackdriver_exporter.h': - 'opencensus-trace-stackdriver_exporter', - 'opencensus/trace/context_util.h': - 'opencensus-trace-context_util', - 'opencensus/trace/propagation/grpc_trace_bin.h': - 'opencensus-trace-propagation', - 'opencensus/tags/context_util.h': - 'opencensus-tags-context_util', - 'opencensus/trace/span_context.h': - 'opencensus-trace-span_context', - 'openssl/base.h': - 'libssl', - 'openssl/bio.h': - 'libssl', - 'openssl/bn.h': - 'libcrypto', - 'openssl/buffer.h': - 'libcrypto', - 'openssl/crypto.h': - 'libcrypto', - 'openssl/digest.h': - 'libssl', - 'openssl/engine.h': - 'libcrypto', - 'openssl/err.h': - 'libcrypto', - 'openssl/evp.h': - 'libcrypto', - 'openssl/hmac.h': - 'libcrypto', - 'openssl/pem.h': - 'libcrypto', - 'openssl/rsa.h': - 'libcrypto', - 'openssl/sha.h': - 'libcrypto', - 'openssl/ssl.h': - 'libssl', - 'openssl/tls1.h': - 'libssl', - 'openssl/x509.h': - 'libcrypto', - 'openssl/x509v3.h': - 'libcrypto', - 're2/re2.h': - 're2', - 'upb/arena.h': - 'upb_lib', - 'upb/base/string_view.h': - 'upb_lib', - 'upb/collections/map.h': - 'upb_collections_lib', - 'upb/def.h': - 'upb_lib', - 'upb/json_encode.h': - 'upb_json_lib', - 'upb/mem/arena.h': - 'upb_lib', - 'upb/text_encode.h': - 'upb_textformat_lib', - 'upb/def.hpp': - 'upb_reflection', - 'upb/upb.h': - 'upb_lib', - 'upb/upb.hpp': - 'upb_lib', - 'xxhash.h': - 'xxhash', - 'zlib.h': - 'madler_zlib', + "absl/base/internal/endian.h": "absl/base", + "absl/base/thread_annotations.h": "absl/base:core_headers", + "absl/container/flat_hash_map.h": "absl/container:flat_hash_map", + "absl/container/flat_hash_set.h": "absl/container:flat_hash_set", + "absl/container/inlined_vector.h": "absl/container:inlined_vector", + "absl/cleanup/cleanup.h": "absl/cleanup", + "absl/debugging/failure_signal_handler.h": ( + 
"absl/debugging:failure_signal_handler" + ), + "absl/debugging/stacktrace.h": "absl/debugging:stacktrace", + "absl/debugging/symbolize.h": "absl/debugging:symbolize", + "absl/flags/flag.h": "absl/flags:flag", + "absl/flags/marshalling.h": "absl/flags:marshalling", + "absl/flags/parse.h": "absl/flags:parse", + "absl/functional/any_invocable.h": "absl/functional:any_invocable", + "absl/functional/bind_front.h": "absl/functional:bind_front", + "absl/functional/function_ref.h": "absl/functional:function_ref", + "absl/hash/hash.h": "absl/hash", + "absl/memory/memory.h": "absl/memory", + "absl/meta/type_traits.h": "absl/meta:type_traits", + "absl/numeric/int128.h": "absl/numeric:int128", + "absl/random/random.h": "absl/random", + "absl/random/distributions.h": "absl/random:distributions", + "absl/random/uniform_int_distribution.h": "absl/random:distributions", + "absl/status/status.h": "absl/status", + "absl/status/statusor.h": "absl/status:statusor", + "absl/strings/ascii.h": "absl/strings", + "absl/strings/cord.h": "absl/strings:cord", + "absl/strings/escaping.h": "absl/strings", + "absl/strings/match.h": "absl/strings", + "absl/strings/numbers.h": "absl/strings", + "absl/strings/str_cat.h": "absl/strings", + "absl/strings/str_format.h": "absl/strings:str_format", + "absl/strings/str_join.h": "absl/strings", + "absl/strings/str_replace.h": "absl/strings", + "absl/strings/str_split.h": "absl/strings", + "absl/strings/string_view.h": "absl/strings", + "absl/strings/strip.h": "absl/strings", + "absl/strings/substitute.h": "absl/strings", + "absl/synchronization/mutex.h": "absl/synchronization", + "absl/synchronization/notification.h": "absl/synchronization", + "absl/time/clock.h": "absl/time", + "absl/time/time.h": "absl/time", + "absl/types/optional.h": "absl/types:optional", + "absl/types/span.h": "absl/types:span", + "absl/types/variant.h": "absl/types:variant", + "absl/utility/utility.h": "absl/utility", + "address_sorting/address_sorting.h": "address_sorting", + "ares.h": "cares", + "fuzztest/fuzztest.h": ["fuzztest", "fuzztest_main"], + "google/api/monitored_resource.pb.h": ( + "google/api:monitored_resource_cc_proto" + ), + "google/devtools/cloudtrace/v2/tracing.grpc.pb.h": ( + "googleapis_trace_grpc_service" + ), + "google/logging/v2/logging.grpc.pb.h": "googleapis_logging_grpc_service", + "google/logging/v2/logging.pb.h": "googleapis_logging_cc_proto", + "google/logging/v2/log_entry.pb.h": "googleapis_logging_cc_proto", + "google/monitoring/v3/metric_service.grpc.pb.h": ( + "googleapis_monitoring_grpc_service" + ), + "gmock/gmock.h": "gtest", + "gtest/gtest.h": "gtest", + "opencensus/exporters/stats/stackdriver/stackdriver_exporter.h": ( + "opencensus-stats-stackdriver_exporter" + ), + "opencensus/exporters/trace/stackdriver/stackdriver_exporter.h": ( + "opencensus-trace-stackdriver_exporter" + ), + "opencensus/trace/context_util.h": "opencensus-trace-context_util", + "opencensus/trace/propagation/grpc_trace_bin.h": ( + "opencensus-trace-propagation" + ), + "opencensus/tags/context_util.h": "opencensus-tags-context_util", + "opencensus/trace/span_context.h": "opencensus-trace-span_context", + "openssl/base.h": "libssl", + "openssl/bio.h": "libssl", + "openssl/bn.h": "libcrypto", + "openssl/buffer.h": "libcrypto", + "openssl/crypto.h": "libcrypto", + "openssl/digest.h": "libssl", + "openssl/engine.h": "libcrypto", + "openssl/err.h": "libcrypto", + "openssl/evp.h": "libcrypto", + "openssl/hmac.h": "libcrypto", + "openssl/pem.h": "libcrypto", + "openssl/rsa.h": "libcrypto", + 
"openssl/sha.h": "libcrypto", + "openssl/ssl.h": "libssl", + "openssl/tls1.h": "libssl", + "openssl/x509.h": "libcrypto", + "openssl/x509v3.h": "libcrypto", + "re2/re2.h": "re2", + "upb/arena.h": "upb_lib", + "upb/base/string_view.h": "upb_lib", + "upb/collections/map.h": "upb_collections_lib", + "upb/def.h": "upb_lib", + "upb/json_encode.h": "upb_json_lib", + "upb/mem/arena.h": "upb_lib", + "upb/text_encode.h": "upb_textformat_lib", + "upb/def.hpp": "upb_reflection", + "upb/upb.h": "upb_lib", + "upb/upb.hpp": "upb_lib", + "xxhash.h": "xxhash", + "zlib.h": "madler_zlib", } INTERNAL_DEPS = { - "test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.h": - "//test/core/event_engine/fuzzing_event_engine", - "test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.pb.h": - "//test/core/event_engine/fuzzing_event_engine:fuzzing_event_engine_proto", - 'google/api/expr/v1alpha1/syntax.upb.h': - 'google_type_expr_upb', - 'google/rpc/status.upb.h': - 'google_rpc_status_upb', - 'google/protobuf/any.upb.h': - 'protobuf_any_upb', - 'google/protobuf/duration.upb.h': - 'protobuf_duration_upb', - 'google/protobuf/struct.upb.h': - 'protobuf_struct_upb', - 'google/protobuf/timestamp.upb.h': - 'protobuf_timestamp_upb', - 'google/protobuf/wrappers.upb.h': - 'protobuf_wrappers_upb', - 'grpc/status.h': - 'grpc_public_hdrs', - 'src/proto/grpc/channelz/channelz.grpc.pb.h': - '//src/proto/grpc/channelz:channelz_proto', - 'src/proto/grpc/core/stats.pb.h': - '//src/proto/grpc/core:stats_proto', - 'src/proto/grpc/health/v1/health.upb.h': - 'grpc_health_upb', - 'src/proto/grpc/lb/v1/load_reporter.grpc.pb.h': - '//src/proto/grpc/lb/v1:load_reporter_proto', - 'src/proto/grpc/lb/v1/load_balancer.upb.h': - 'grpc_lb_upb', - 'src/proto/grpc/reflection/v1alpha/reflection.grpc.pb.h': - '//src/proto/grpc/reflection/v1alpha:reflection_proto', - 'src/proto/grpc/gcp/transport_security_common.upb.h': - 'alts_upb', - 'src/proto/grpc/gcp/handshaker.upb.h': - 'alts_upb', - 'src/proto/grpc/gcp/altscontext.upb.h': - 'alts_upb', - 'src/proto/grpc/lookup/v1/rls.upb.h': - 'rls_upb', - 'src/proto/grpc/lookup/v1/rls_config.upb.h': - 'rls_config_upb', - 'src/proto/grpc/lookup/v1/rls_config.upbdefs.h': - 'rls_config_upbdefs', - 'src/proto/grpc/testing/xds/v3/csds.grpc.pb.h': - '//src/proto/grpc/testing/xds/v3:csds_proto', - 'xds/data/orca/v3/orca_load_report.upb.h': - 'xds_orca_upb', - 'xds/service/orca/v3/orca.upb.h': - 'xds_orca_service_upb', - 'xds/type/v3/typed_struct.upb.h': - 'xds_type_upb', + "test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.h": ( + "//test/core/event_engine/fuzzing_event_engine" + ), + "test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.pb.h": "//test/core/event_engine/fuzzing_event_engine:fuzzing_event_engine_proto", + "google/api/expr/v1alpha1/syntax.upb.h": "google_type_expr_upb", + "google/rpc/status.upb.h": "google_rpc_status_upb", + "google/protobuf/any.upb.h": "protobuf_any_upb", + "google/protobuf/duration.upb.h": "protobuf_duration_upb", + "google/protobuf/struct.upb.h": "protobuf_struct_upb", + "google/protobuf/timestamp.upb.h": "protobuf_timestamp_upb", + "google/protobuf/wrappers.upb.h": "protobuf_wrappers_upb", + "grpc/status.h": "grpc_public_hdrs", + "src/proto/grpc/channelz/channelz.grpc.pb.h": ( + "//src/proto/grpc/channelz:channelz_proto" + ), + "src/proto/grpc/core/stats.pb.h": "//src/proto/grpc/core:stats_proto", + "src/proto/grpc/health/v1/health.upb.h": "grpc_health_upb", + "src/proto/grpc/lb/v1/load_reporter.grpc.pb.h": ( + 
"//src/proto/grpc/lb/v1:load_reporter_proto" + ), + "src/proto/grpc/lb/v1/load_balancer.upb.h": "grpc_lb_upb", + "src/proto/grpc/reflection/v1alpha/reflection.grpc.pb.h": ( + "//src/proto/grpc/reflection/v1alpha:reflection_proto" + ), + "src/proto/grpc/gcp/transport_security_common.upb.h": "alts_upb", + "src/proto/grpc/gcp/handshaker.upb.h": "alts_upb", + "src/proto/grpc/gcp/altscontext.upb.h": "alts_upb", + "src/proto/grpc/lookup/v1/rls.upb.h": "rls_upb", + "src/proto/grpc/lookup/v1/rls_config.upb.h": "rls_config_upb", + "src/proto/grpc/lookup/v1/rls_config.upbdefs.h": "rls_config_upbdefs", + "src/proto/grpc/testing/xds/v3/csds.grpc.pb.h": ( + "//src/proto/grpc/testing/xds/v3:csds_proto" + ), + "xds/data/orca/v3/orca_load_report.upb.h": "xds_orca_upb", + "xds/service/orca/v3/orca.upb.h": "xds_orca_service_upb", + "xds/type/v3/typed_struct.upb.h": "xds_type_upb", } class FakeSelects: - def config_setting_group(self, **kwargs): pass @@ -302,35 +205,43 @@ def config_setting_group(self, **kwargs): # Convert the source or header target to a relative path. def _get_filename(name, parsing_path): - filename = '%s%s' % ( - (parsing_path + '/' if - (parsing_path and not name.startswith('//')) else ''), name) - filename = filename.replace('//:', '') - filename = filename.replace('//src/core:', 'src/core/') - filename = filename.replace('//src/cpp/ext/filters/census:', - 'src/cpp/ext/filters/census/') + filename = "%s%s" % ( + ( + parsing_path + "/" + if (parsing_path and not name.startswith("//")) + else "" + ), + name, + ) + filename = filename.replace("//:", "") + filename = filename.replace("//src/core:", "src/core/") + filename = filename.replace( + "//src/cpp/ext/filters/census:", "src/cpp/ext/filters/census/" + ) return filename -def grpc_cc_library(name, - hdrs=[], - public_hdrs=[], - srcs=[], - select_deps=None, - tags=[], - deps=[], - external_deps=[], - proto=None, - **kwargs): +def grpc_cc_library( + name, + hdrs=[], + public_hdrs=[], + srcs=[], + select_deps=None, + tags=[], + deps=[], + external_deps=[], + proto=None, + **kwargs, +): global args global num_cc_libraries global num_opted_out_cc_libraries global parsing_path - assert (parsing_path is not None) - name = '//%s:%s' % (parsing_path, name) + assert parsing_path is not None + name = "//%s:%s" % (parsing_path, name) num_cc_libraries += 1 - if select_deps or 'nofixdeps' in tags: - if args.whats_left and not select_deps and 'nofixdeps' not in tags: + if select_deps or "nofixdeps" in tags: + if args.whats_left and not select_deps and "nofixdeps" not in tags: num_opted_out_cc_libraries += 1 print("Not opted in: {}".format(name)) no_update.add(name) @@ -338,11 +249,13 @@ def grpc_cc_library(name, # avoid_dep is the internal way of saying prefer something else # we add grpc_avoid_dep to allow internal grpc-only stuff to avoid each # other, whilst not biasing dependent projects - if 'avoid_dep' in tags or 'grpc_avoid_dep' in tags: + if "avoid_dep" in tags or "grpc_avoid_dep" in tags: avoidness[name] += 10 if proto: - proto_hdr = '%s%s' % ((parsing_path + '/' if parsing_path else ''), - proto.replace('.proto', '.pb.h')) + proto_hdr = "%s%s" % ( + (parsing_path + "/" if parsing_path else ""), + proto.replace(".proto", ".pb.h"), + ) skip_headers[name].add(proto_hdr) for hdr in hdrs + public_hdrs: @@ -352,7 +265,7 @@ def grpc_cc_library(name, original_external_deps[name] = frozenset(external_deps) for src in hdrs + public_hdrs + srcs: for line in open(_get_filename(src, parsing_path)): - m = re.search(r'^#include <(.*)>', line) + m = 
re.search(r"^#include <(.*)>", line) if m: inc.add(m.group(1)) m = re.search(r'^#include "(.*)"', line) @@ -363,27 +276,28 @@ def grpc_cc_library(name, def grpc_proto_library(name, srcs, **kwargs): global parsing_path - assert (parsing_path is not None) - name = '//%s:%s' % (parsing_path, name) + assert parsing_path is not None + name = "//%s:%s" % (parsing_path, name) for src in srcs: - proto_hdr = src.replace('.proto', '.pb.h') + proto_hdr = src.replace(".proto", ".pb.h") vendors[_get_filename(proto_hdr, parsing_path)].append(name) def buildozer(cmd, target): - buildozer_commands.append('%s|%s' % (cmd, target)) + buildozer_commands.append("%s|%s" % (cmd, target)) def buildozer_set_list(name, values, target, via=""): if not values: - buildozer('remove %s' % name, target) + buildozer("remove %s" % name, target) return adjust = via if via else name - buildozer('set %s %s' % (adjust, ' '.join('"%s"' % s for s in values)), - target) + buildozer( + "set %s %s" % (adjust, " ".join('"%s"' % s for s in values)), target + ) if via: - buildozer('remove %s' % name, target) - buildozer('rename %s %s' % (via, name), target) + buildozer("remove %s" % name, target) + buildozer("rename %s %s" % (via, name), target) def score_edit_distance(proposed, existing): @@ -417,107 +331,120 @@ def score_best(proposed, existing): SCORERS = { - 'edit_distance': score_edit_distance, - 'list_size': score_list_size, - 'best': score_best, + "edit_distance": score_edit_distance, + "list_size": score_list_size, + "best": score_best, } -parser = argparse.ArgumentParser(description='Fix build dependencies') -parser.add_argument('targets', - nargs='*', - default=[], - help='targets to fix (empty => all)') -parser.add_argument('--score', - type=str, - default='edit_distance', - help='scoring function to use: one of ' + - ', '.join(SCORERS.keys())) -parser.add_argument('--whats_left', - action='store_true', - default=False, - help='show what is left to opt in') -parser.add_argument('--explain', - action='store_true', - default=False, - help='try to explain some decisions') +parser = argparse.ArgumentParser(description="Fix build dependencies") +parser.add_argument( + "targets", nargs="*", default=[], help="targets to fix (empty => all)" +) parser.add_argument( - '--why', + "--score", + type=str, + default="edit_distance", + help="scoring function to use: one of " + ", ".join(SCORERS.keys()), +) +parser.add_argument( + "--whats_left", + action="store_true", + default=False, + help="show what is left to opt in", +) +parser.add_argument( + "--explain", + action="store_true", + default=False, + help="try to explain some decisions", +) +parser.add_argument( + "--why", type=str, default=None, - help='with --explain, target why a given dependency is needed') + help="with --explain, target why a given dependency is needed", +) args = parser.parse_args() for dirname in [ - "", - "src/core", - "src/cpp/ext/gcp", - "test/core/backoff", - "test/core/uri", - "test/core/util", - "test/core/end2end", - "test/core/event_engine", - "test/core/filters", - "test/core/promise", - "test/core/resource_quota", - "test/core/transport/chaotic_good", - "fuzztest", - "fuzztest/core/channel", + "", + "src/core", + "src/cpp/ext/gcp", + "test/core/backoff", + "test/core/uri", + "test/core/util", + "test/core/end2end", + "test/core/event_engine", + "test/core/filters", + "test/core/promise", + "test/core/resource_quota", + "test/core/transport/chaotic_good", + "fuzztest", + "fuzztest/core/channel", ]: parsing_path = dirname exec( - open('%sBUILD' % (dirname 
+ '/' if dirname else ''), 'r').read(), { - 'load': lambda filename, *args: None, - 'licenses': lambda licenses: None, - 'package': lambda **kwargs: None, - 'exports_files': lambda files, visibility=None: None, - 'bool_flag': lambda **kwargs: None, - 'config_setting': lambda **kwargs: None, - 'selects': FakeSelects(), - 'python_config_settings': lambda **kwargs: None, - 'grpc_cc_binary': grpc_cc_library, - 'grpc_cc_library': grpc_cc_library, - 'grpc_cc_test': grpc_cc_library, - 'grpc_core_end2end_test': lambda **kwargs: None, - 'grpc_fuzzer': grpc_cc_library, - 'grpc_fuzz_test': grpc_cc_library, - 'grpc_proto_fuzzer': grpc_cc_library, - 'grpc_proto_library': grpc_proto_library, - 'select': lambda d: d["//conditions:default"], - 'glob': lambda files: None, - 'grpc_end2end_tests': lambda: None, - 'grpc_upb_proto_library': lambda name, **kwargs: None, - 'grpc_upb_proto_reflection_library': lambda name, **kwargs: None, - 'grpc_generate_one_off_targets': lambda: None, - 'grpc_generate_one_off_internal_targets': lambda: None, - 'grpc_package': lambda **kwargs: None, - 'filegroup': lambda name, **kwargs: None, - 'sh_library': lambda name, **kwargs: None, - }, {}) + open("%sBUILD" % (dirname + "/" if dirname else ""), "r").read(), + { + "load": lambda filename, *args: None, + "licenses": lambda licenses: None, + "package": lambda **kwargs: None, + "exports_files": lambda files, visibility=None: None, + "bool_flag": lambda **kwargs: None, + "config_setting": lambda **kwargs: None, + "selects": FakeSelects(), + "python_config_settings": lambda **kwargs: None, + "grpc_cc_binary": grpc_cc_library, + "grpc_cc_library": grpc_cc_library, + "grpc_cc_test": grpc_cc_library, + "grpc_core_end2end_test": lambda **kwargs: None, + "grpc_fuzzer": grpc_cc_library, + "grpc_fuzz_test": grpc_cc_library, + "grpc_proto_fuzzer": grpc_cc_library, + "grpc_proto_library": grpc_proto_library, + "select": lambda d: d["//conditions:default"], + "glob": lambda files: None, + "grpc_end2end_tests": lambda: None, + "grpc_upb_proto_library": lambda name, **kwargs: None, + "grpc_upb_proto_reflection_library": lambda name, **kwargs: None, + "grpc_generate_one_off_targets": lambda: None, + "grpc_generate_one_off_internal_targets": lambda: None, + "grpc_package": lambda **kwargs: None, + "filegroup": lambda name, **kwargs: None, + "sh_library": lambda name, **kwargs: None, + }, + {}, + ) parsing_path = None if args.whats_left: - print("{}/{} libraries are opted in".format( - num_cc_libraries - num_opted_out_cc_libraries, num_cc_libraries)) + print( + "{}/{} libraries are opted in".format( + num_cc_libraries - num_opted_out_cc_libraries, num_cc_libraries + ) + ) def make_relative_path(dep, lib): if lib is None: return dep - lib_path = lib[:lib.rfind(':') + 1] + lib_path = lib[: lib.rfind(":") + 1] if dep.startswith(lib_path): - return dep[len(lib_path):] + return dep[len(lib_path) :] return dep if args.whats_left: - print("{}/{} libraries are opted in".format( - num_cc_libraries - num_opted_out_cc_libraries, num_cc_libraries)) + print( + "{}/{} libraries are opted in".format( + num_cc_libraries - num_opted_out_cc_libraries, num_cc_libraries + ) + ) # Keeps track of all possible sets of dependencies that could satify the # problem. (models the list monad in Haskell!) 
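The Choices class defined just below collects "any one of these targets will do" groups per header and later scores every combination against the target's existing deps list. A minimal standalone sketch of that expansion idea, with hypothetical names and none of the removal/scoring logic the real class carries:

import itertools

def candidate_dep_sets(one_of_groups):
    # one_of_groups: each entry lists targets that could satisfy one header,
    # e.g. [("//:grpc", "//:grpc_unsecure"), ("//:gpr",)].
    # Every way of picking one target per group is a candidate dep set.
    for picks in itertools.product(*one_of_groups):
        yield sorted(set(picks))

for deps in candidate_dep_sets([("//:grpc", "//:grpc_unsecure"), ("//:gpr",)]):
    print(deps)

The real script then ranks these candidates with the selected SCORERS entry (edit distance by default) and emits buildozer commands for the winner.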
class Choices: - def __init__(self, library, substitutions): self.library = library self.to_add = [] @@ -527,14 +454,20 @@ def __init__(self, library, substitutions): def add_one_of(self, choices, trigger): if not choices: return - choices = sum([self.apply_substitutions(choice) for choice in choices], - []) + choices = sum( + [self.apply_substitutions(choice) for choice in choices], [] + ) if args.explain and (args.why is None or args.why in choices): - print("{}: Adding one of {} for {}".format(self.library, choices, - trigger)) + print( + "{}: Adding one of {} for {}".format( + self.library, choices, trigger + ) + ) self.to_add.append( tuple( - make_relative_path(choice, self.library) for choice in choices)) + make_relative_path(choice, self.library) for choice in choices + ) + ) def add(self, choice, trigger): self.add_one_of([choice], trigger) @@ -581,25 +514,29 @@ def make_library(library): # we need a little trickery here since grpc_base has channel.cc, which calls grpc_init # which is in grpc, which is illegal but hard to change # once EventEngine lands we can clean this up - deps = Choices(library, {'//:grpc_base': ['//:grpc', '//:grpc_unsecure']} - if library.startswith('//test/') else {}) + deps = Choices( + library, + {"//:grpc_base": ["//:grpc", "//:grpc_unsecure"]} + if library.startswith("//test/") + else {}, + ) external_deps = Choices(None, {}) for hdr in hdrs: if hdr in skip_headers[library]: continue - if hdr == 'systemd/sd-daemon.h': + if hdr == "systemd/sd-daemon.h": continue - if hdr == 'src/core/lib/profiling/stap_probes.h': + if hdr == "src/core/lib/profiling/stap_probes.h": continue - if hdr.startswith('src/libfuzzer/'): + if hdr.startswith("src/libfuzzer/"): continue - if hdr == 'grpc/grpc.h' and library.startswith('//test:'): + if hdr == "grpc/grpc.h" and library.startswith("//test:"): # not the root build including grpc.h ==> //:grpc - deps.add_one_of(['//:grpc', '//:grpc_unsecure'], hdr) + deps.add_one_of(["//:grpc", "//:grpc_unsecure"], hdr) continue if hdr in INTERNAL_DEPS: @@ -608,8 +545,8 @@ def make_library(library): for d in dep: deps.add(d, hdr) else: - if not ('//' in dep): - dep = '//:' + dep + if not ("//" in dep): + dep = "//:" + dep deps.add(dep, hdr) continue @@ -617,11 +554,11 @@ def make_library(library): deps.add_one_of(vendors[hdr], hdr) continue - if 'include/' + hdr in vendors: - deps.add_one_of(vendors['include/' + hdr], hdr) + if "include/" + hdr in vendors: + deps.add_one_of(vendors["include/" + hdr], hdr) continue - if '.' not in hdr: + if "." 
not in hdr: # assume a c++ system include continue @@ -633,58 +570,62 @@ def make_library(library): external_deps.add(EXTERNAL_DEPS[hdr], hdr) continue - if hdr.startswith('opencensus/'): - trail = hdr[len('opencensus/'):] - trail = trail[:trail.find('/')] - external_deps.add('opencensus-' + trail, hdr) + if hdr.startswith("opencensus/"): + trail = hdr[len("opencensus/") :] + trail = trail[: trail.find("/")] + external_deps.add("opencensus-" + trail, hdr) continue - if hdr.startswith('envoy/'): + if hdr.startswith("envoy/"): path, file = os.path.split(hdr) - file = file.split('.') - path = path.split('/') - dep = '_'.join(path[:-1] + [file[1]]) + file = file.split(".") + path = path.split("/") + dep = "_".join(path[:-1] + [file[1]]) deps.add(dep, hdr) continue - if hdr.startswith('google/protobuf/') and not hdr.endswith('.upb.h'): - external_deps.add('protobuf_headers', hdr) + if hdr.startswith("google/protobuf/") and not hdr.endswith(".upb.h"): + external_deps.add("protobuf_headers", hdr) continue - if '/' not in hdr: + if "/" not in hdr: # assume a system include continue is_sys_include = False for sys_path in [ - 'sys', - 'arpa', - 'gperftools', - 'netinet', - 'linux', - 'android', - 'mach', - 'net', - 'CoreFoundation', + "sys", + "arpa", + "gperftools", + "netinet", + "linux", + "android", + "mach", + "net", + "CoreFoundation", ]: - if hdr.startswith(sys_path + '/'): + if hdr.startswith(sys_path + "/"): is_sys_include = True break if is_sys_include: # assume a system include continue - print("# ERROR: can't categorize header: %s used by %s" % - (hdr, library)) + print( + "# ERROR: can't categorize header: %s used by %s" % (hdr, library) + ) error = True deps.remove(library) deps = sorted( - deps.best(lambda x: SCORERS[args.score](x, original_deps[library]))) + deps.best(lambda x: SCORERS[args.score](x, original_deps[library])) + ) external_deps = sorted( - external_deps.best(lambda x: SCORERS[args.score] - (x, original_external_deps[library]))) + external_deps.best( + lambda x: SCORERS[args.score](x, original_external_deps[library]) + ) + ) return (library, error, deps, external_deps) @@ -705,8 +646,8 @@ def main() -> None: if lib_error: error = True continue - buildozer_set_list('external_deps', external_deps, library, via='deps') - buildozer_set_list('deps', deps, library) + buildozer_set_list("external_deps", external_deps, library, via="deps") + buildozer_set_list("deps", deps, library) run_buildozer.run_buildozer(buildozer_commands) diff --git a/tools/distrib/gen_compilation_database.py b/tools/distrib/gen_compilation_database.py index f2c93ade71c5a..15e5e3a84cfb6 100755 --- a/tools/distrib/gen_compilation_database.py +++ b/tools/distrib/gen_compilation_database.py @@ -39,21 +39,33 @@ def generateCompilationDatabase(args): "--remote_download_outputs=all", ] - subprocess.check_call(["bazel", "build"] + bazel_options + [ - "--aspects=@bazel_compdb//:aspects.bzl%compilation_database_aspect", - "--output_groups=compdb_files,header_files" - ] + args.bazel_targets) - - execroot = subprocess.check_output(["bazel", "info", "execution_root"] + - bazel_options).decode().strip() + subprocess.check_call( + ["bazel", "build"] + + bazel_options + + [ + "--aspects=@bazel_compdb//:aspects.bzl%compilation_database_aspect", + "--output_groups=compdb_files,header_files", + ] + + args.bazel_targets + ) + + execroot = ( + subprocess.check_output( + ["bazel", "info", "execution_root"] + bazel_options + ) + .decode() + .strip() + ) compdb = [] for compdb_file in 
Path(execroot).glob("**/*.compile_commands.json"): compdb.extend( json.loads( - "[" + - compdb_file.read_text().replace("__EXEC_ROOT__", execroot) + - "]")) + "[" + + compdb_file.read_text().replace("__EXEC_ROOT__", execroot) + + "]" + ) + ) if args.dedup_targets: compdb_map = {target["file"]: target for target in compdb} @@ -126,13 +138,14 @@ def fixCompilationDatabase(args, db): if __name__ == "__main__": parser = argparse.ArgumentParser( - description='Generate JSON compilation database') - parser.add_argument('--include_external', action='store_true') - parser.add_argument('--include_genfiles', action='store_true') - parser.add_argument('--include_headers', action='store_true') - parser.add_argument('--vscode', action='store_true') - parser.add_argument('--ignore_system_headers', action='store_true') - parser.add_argument('--dedup_targets', action='store_true') - parser.add_argument('bazel_targets', nargs='*', default=["//..."]) + description="Generate JSON compilation database" + ) + parser.add_argument("--include_external", action="store_true") + parser.add_argument("--include_genfiles", action="store_true") + parser.add_argument("--include_headers", action="store_true") + parser.add_argument("--vscode", action="store_true") + parser.add_argument("--ignore_system_headers", action="store_true") + parser.add_argument("--dedup_targets", action="store_true") + parser.add_argument("bazel_targets", nargs="*", default=["//..."]) args = parser.parse_args() fixCompilationDatabase(args, generateCompilationDatabase(args)) diff --git a/tools/distrib/isort_code.sh b/tools/distrib/isort_code.sh index 97c01651330bf..34955e8460a3d 100755 --- a/tools/distrib/isort_code.sh +++ b/tools/distrib/isort_code.sh @@ -31,7 +31,6 @@ DIRS=( 'test' 'tools' 'setup.py' - 'tools/run_tests/xds_k8s_test_driver' ) VIRTUALENV=isort_virtual_environment @@ -40,21 +39,4 @@ python3 -m virtualenv $VIRTUALENV PYTHON=${VIRTUALENV}/bin/python "$PYTHON" -m pip install isort==5.9.2 -$PYTHON -m isort $ACTION \ - --force-sort-within-sections \ - --force-single-line-imports --single-line-exclusions=typing \ - --src "examples/python/data_transmission" \ - --src "examples/python/async_streaming" \ - --src "tools/run_tests/xds_k8s_test_driver" \ - --src "src/python/grpcio_tests" \ - --src "tools/run_tests" \ - --project "examples" \ - --project "src" \ - --thirdparty "grpc" \ - --skip-glob "third_party/*" \ - --skip-glob "*/env/*" \ - --skip-glob "*pb2*.py" \ - --skip-glob "*pb2*.pyi" \ - --skip-glob "**/site-packages/**/*" \ - --dont-follow-links \ - "${DIRS[@]}" +$PYTHON -m isort $ACTION --settings-path=black.toml --dont-follow-links "${DIRS[@]}" diff --git a/tools/distrib/pylint_code.sh b/tools/distrib/pylint_code.sh index b68bf2ad8ea1a..76a25c7dd851f 100755 --- a/tools/distrib/pylint_code.sh +++ b/tools/distrib/pylint_code.sh @@ -47,7 +47,10 @@ PYTHON=$VIRTUALENV/bin/python $PYTHON -m pip install --upgrade pip==19.3.1 # TODO(https://github.com/grpc/grpc/issues/23394): Update Pylint. 
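On the gen_compilation_database.py change above: each Bazel target emits a *.compile_commands.json fragment, and the script stitches them together by bracketing the fragment text so json.loads can parse it as one array. A rough standalone sketch of that merge, assuming each fragment is a comma-separated run of JSON objects containing __EXEC_ROOT__ placeholders (the fragment layout shown here is illustrative, not taken from the aspect's actual output):

import json

def merge_fragments(fragments, execroot):
    """Stitch per-target compile_commands fragments into one list."""
    compdb = []
    for text in fragments:
        # The fragment has no enclosing brackets, so wrapping it in "[...]"
        # after substituting the execroot yields a parseable JSON array.
        compdb.extend(
            json.loads("[" + text.replace("__EXEC_ROOT__", execroot) + "]")
        )
    return compdb

print(
    merge_fragments(
        ['{"file": "__EXEC_ROOT__/a.cc", "command": "clang++ -c a.cc", "directory": "__EXEC_ROOT__"}'],
        "/tmp/execroot",
    )
)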
-$PYTHON -m pip install --upgrade astroid==2.3.3 pylint==2.2.2 "isort>=4.3.0,<5.0.0" +$PYTHON -m pip install --upgrade astroid==2.3.3 \ + pylint==2.2.2 \ + toml==0.10.2 \ + "isort>=4.3.0,<5.0.0" EXIT=0 for dir in "${DIRS[@]}"; do diff --git a/tools/distrib/python/check_grpcio_tools.py b/tools/distrib/python/check_grpcio_tools.py index 376ea55d226f3..495b2f718a64c 100755 --- a/tools/distrib/python/check_grpcio_tools.py +++ b/tools/distrib/python/check_grpcio_tools.py @@ -22,11 +22,12 @@ submodule_commit_hash = _make.protobuf_submodule_commit_hash() -with open(_make.GRPC_PYTHON_PROTOC_LIB_DEPS, 'r') as _protoc_lib_deps_file: +with open(_make.GRPC_PYTHON_PROTOC_LIB_DEPS, "r") as _protoc_lib_deps_file: content = _protoc_lib_deps_file.read().splitlines() -testString = (_make.COMMIT_HASH_PREFIX + submodule_commit_hash + - _make.COMMIT_HASH_SUFFIX) +testString = ( + _make.COMMIT_HASH_PREFIX + submodule_commit_hash + _make.COMMIT_HASH_SUFFIX +) if testString not in content: print(OUT_OF_DATE_MESSAGE.format(_make.GRPC_PYTHON_PROTOC_LIB_DEPS)) diff --git a/tools/distrib/python/docgen.py b/tools/distrib/python/docgen.py index 9cbd6bbab6e84..00d00940e73bc 100755 --- a/tools/distrib/python/docgen.py +++ b/tools/distrib/python/docgen.py @@ -26,90 +26,104 @@ import grpc_version parser = argparse.ArgumentParser() -parser.add_argument('--repo-owner', - type=str, - help=('Owner of the GitHub repository to be pushed')) -parser.add_argument('--doc-branch', - type=str, - default='python-doc-%s' % grpc_version.VERSION) +parser.add_argument( + "--repo-owner", type=str, help="Owner of the GitHub repository to be pushed" +) +parser.add_argument( + "--doc-branch", type=str, default="python-doc-%s" % grpc_version.VERSION +) args = parser.parse_args() SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) -PROJECT_ROOT = os.path.abspath(os.path.join(SCRIPT_DIR, '..', '..', '..')) +PROJECT_ROOT = os.path.abspath(os.path.join(SCRIPT_DIR, "..", "..", "..")) -SETUP_PATH = os.path.join(PROJECT_ROOT, 'setup.py') -REQUIREMENTS_PATH = os.path.join(PROJECT_ROOT, 'requirements.bazel.txt') -DOC_PATH = os.path.join(PROJECT_ROOT, 'doc/build') +SETUP_PATH = os.path.join(PROJECT_ROOT, "setup.py") +REQUIREMENTS_PATH = os.path.join(PROJECT_ROOT, "requirements.bazel.txt") +DOC_PATH = os.path.join(PROJECT_ROOT, "doc/build") if "VIRTUAL_ENV" in os.environ: - VIRTUALENV_DIR = os.environ['VIRTUAL_ENV'] - PYTHON_PATH = os.path.join(VIRTUALENV_DIR, 'bin', 'python') + VIRTUALENV_DIR = os.environ["VIRTUAL_ENV"] + PYTHON_PATH = os.path.join(VIRTUALENV_DIR, "bin", "python") subprocess_arguments_list = [] else: - VIRTUALENV_DIR = os.path.join(SCRIPT_DIR, 'distrib_virtualenv') - PYTHON_PATH = os.path.join(VIRTUALENV_DIR, 'bin', 'python') + VIRTUALENV_DIR = os.path.join(SCRIPT_DIR, "distrib_virtualenv") + PYTHON_PATH = os.path.join(VIRTUALENV_DIR, "bin", "python") subprocess_arguments_list = [ - ['python3', '-m', 'virtualenv', VIRTUALENV_DIR], + ["python3", "-m", "virtualenv", VIRTUALENV_DIR], ] subprocess_arguments_list += [ - [PYTHON_PATH, '-m', 'pip', 'install', '--upgrade', 'pip==19.3.1'], - [PYTHON_PATH, '-m', 'pip', 'install', '-r', REQUIREMENTS_PATH], - [PYTHON_PATH, '-m', 'pip', 'install', '--upgrade', 'Sphinx'], - [PYTHON_PATH, SETUP_PATH, 'doc'], + [PYTHON_PATH, "-m", "pip", "install", "--upgrade", "pip==19.3.1"], + [PYTHON_PATH, "-m", "pip", "install", "-r", REQUIREMENTS_PATH], + [PYTHON_PATH, "-m", "pip", "install", "--upgrade", "Sphinx"], + [PYTHON_PATH, SETUP_PATH, "doc"], ] for subprocess_arguments in subprocess_arguments_list: - 
print('Running command: {}'.format(subprocess_arguments)) + print("Running command: {}".format(subprocess_arguments)) subprocess.check_call(args=subprocess_arguments) if not args.repo_owner or not args.doc_branch: - tty_width = int(os.popen('stty size', 'r').read().split()[1]) - print('-' * tty_width) - print('Please check generated Python doc inside doc/build') + tty_width = int(os.popen("stty size", "r").read().split()[1]) + print("-" * tty_width) + print("Please check generated Python doc inside doc/build") print( - 'To push to a GitHub repo, please provide repo owner and doc branch name' + "To push to a GitHub repo, please provide repo owner and doc branch" + " name" ) else: # Create a temporary directory out of tree, checkout gh-pages from the # specified repository, edit it, and push it. It's up to the user to then go # onto GitHub and make a PR against grpc/grpc:gh-pages. repo_parent_dir = tempfile.mkdtemp() - print('Documentation parent directory: {}'.format(repo_parent_dir)) - repo_dir = os.path.join(repo_parent_dir, 'grpc') - python_doc_dir = os.path.join(repo_dir, 'python') + print("Documentation parent directory: {}".format(repo_parent_dir)) + repo_dir = os.path.join(repo_parent_dir, "grpc") + python_doc_dir = os.path.join(repo_dir, "python") doc_branch = args.doc_branch - print('Cloning your repository...') - subprocess.check_call([ - 'git', - 'clone', - '--branch', - 'gh-pages', - 'https://github.com/grpc/grpc', - ], - cwd=repo_parent_dir) - subprocess.check_call(['git', 'checkout', '-b', doc_branch], cwd=repo_dir) - subprocess.check_call([ - 'git', 'remote', 'add', 'ssh-origin', - 'git@github.com:%s/grpc.git' % args.repo_owner - ], - cwd=repo_dir) - print('Updating documentation...') + print("Cloning your repository...") + subprocess.check_call( + [ + "git", + "clone", + "--branch", + "gh-pages", + "https://github.com/grpc/grpc", + ], + cwd=repo_parent_dir, + ) + subprocess.check_call(["git", "checkout", "-b", doc_branch], cwd=repo_dir) + subprocess.check_call( + [ + "git", + "remote", + "add", + "ssh-origin", + "git@github.com:%s/grpc.git" % args.repo_owner, + ], + cwd=repo_dir, + ) + print("Updating documentation...") shutil.rmtree(python_doc_dir, ignore_errors=True) shutil.copytree(DOC_PATH, python_doc_dir) - print('Attempting to push documentation to %s/%s...' % - (args.repo_owner, doc_branch)) + print( + "Attempting to push documentation to %s/%s..." + % (args.repo_owner, doc_branch) + ) try: - subprocess.check_call(['git', 'add', '--all'], cwd=repo_dir) + subprocess.check_call(["git", "add", "--all"], cwd=repo_dir) subprocess.check_call( - ['git', 'commit', '-m', 'Auto-update Python documentation'], - cwd=repo_dir) + ["git", "commit", "-m", "Auto-update Python documentation"], + cwd=repo_dir, + ) subprocess.check_call( - ['git', 'push', '--set-upstream', 'ssh-origin', doc_branch], - cwd=repo_dir) + ["git", "push", "--set-upstream", "ssh-origin", doc_branch], + cwd=repo_dir, + ) except subprocess.CalledProcessError: - print('Failed to push documentation. Examine this directory and push ' - 'manually: {}'.format(repo_parent_dir)) + print( + "Failed to push documentation. 
Examine this directory and push " + "manually: {}".format(repo_parent_dir) + ) sys.exit(1) shutil.rmtree(repo_parent_dir) diff --git a/tools/distrib/python/grpc_prefixed/generate.py b/tools/distrib/python/grpc_prefixed/generate.py index c685a1e113162..c8ac8cfde3b4e 100644 --- a/tools/distrib/python/grpc_prefixed/generate.py +++ b/tools/distrib/python/grpc_prefixed/generate.py @@ -30,12 +30,13 @@ import jinja2 WORK_PATH = os.path.realpath(os.path.dirname(__file__)) -LICENSE = os.path.join(WORK_PATH, '../../../../LICENSE') -BUILD_PATH = os.path.join(WORK_PATH, 'build') -DIST_PATH = os.path.join(WORK_PATH, 'dist') +LICENSE = os.path.join(WORK_PATH, "../../../../LICENSE") +BUILD_PATH = os.path.join(WORK_PATH, "build") +DIST_PATH = os.path.join(WORK_PATH, "dist") env = jinja2.Environment( - loader=jinja2.FileSystemLoader(os.path.join(WORK_PATH, 'templates'))) + loader=jinja2.FileSystemLoader(os.path.join(WORK_PATH, "templates")) +) LOGGER = logging.getLogger(__name__) POPEN_TIMEOUT_S = datetime.timedelta(minutes=1).total_seconds() @@ -44,10 +45,11 @@ @dataclasses.dataclass class PackageMeta: """Meta-info of a PyPI package.""" + name: str name_long: str destination_package: str - version: str = '1.0.0' + version: str = "1.0.0" def clean() -> None: @@ -68,81 +70,112 @@ def generate_package(meta: PackageMeta) -> None: os.makedirs(package_path, exist_ok=True) # Copy license - shutil.copyfile(LICENSE, os.path.join(package_path, 'LICENSE')) + shutil.copyfile(LICENSE, os.path.join(package_path, "LICENSE")) # Generates source code for template_name in env.list_templates(): template = env.get_template(template_name) with open( - os.path.join(package_path, - template_name.replace('.template', '')), 'w') as f: + os.path.join(package_path, template_name.replace(".template", "")), + "w", + ) as f: f.write(template.render(dataclasses.asdict(meta))) # Creates wheel - job = subprocess.Popen([ - sys.executable, - os.path.join(package_path, 'setup.py'), 'sdist', '--dist-dir', DIST_PATH - ], - cwd=package_path, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT) + job = subprocess.Popen( + [ + sys.executable, + os.path.join(package_path, "setup.py"), + "sdist", + "--dist-dir", + DIST_PATH, + ], + cwd=package_path, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) outs, _ = job.communicate(timeout=POPEN_TIMEOUT_S) # Logs result if job.returncode != 0: - LOGGER.error('Wheel creation failed with %d', job.returncode) + LOGGER.error("Wheel creation failed with %d", job.returncode) LOGGER.error(outs) else: - LOGGER.info('Package <%s> generated', meta.name) + LOGGER.info("Package <%s> generated", meta.name) def main(): clean() generate_package( - PackageMeta(name='grpc', - name_long='gRPC Python', - destination_package='grpcio')) + PackageMeta( + name="grpc", name_long="gRPC Python", destination_package="grpcio" + ) + ) generate_package( - PackageMeta(name='grpc-status', - name_long='gRPC Rich Error Status', - destination_package='grpcio-status')) + PackageMeta( + name="grpc-status", + name_long="gRPC Rich Error Status", + destination_package="grpcio-status", + ) + ) generate_package( - PackageMeta(name='grpc-channelz', - name_long='gRPC Channel Tracing', - destination_package='grpcio-channelz')) + PackageMeta( + name="grpc-channelz", + name_long="gRPC Channel Tracing", + destination_package="grpcio-channelz", + ) + ) generate_package( - PackageMeta(name='grpc-tools', - name_long='ProtoBuf Code Generator', - destination_package='grpcio-tools')) + PackageMeta( + name="grpc-tools", + name_long="ProtoBuf Code 
Generator", + destination_package="grpcio-tools", + ) + ) generate_package( - PackageMeta(name='grpc-reflection', - name_long='gRPC Reflection', - destination_package='grpcio-reflection')) + PackageMeta( + name="grpc-reflection", + name_long="gRPC Reflection", + destination_package="grpcio-reflection", + ) + ) generate_package( - PackageMeta(name='grpc-testing', - name_long='gRPC Testing Utility', - destination_package='grpcio-testing')) + PackageMeta( + name="grpc-testing", + name_long="gRPC Testing Utility", + destination_package="grpcio-testing", + ) + ) generate_package( - PackageMeta(name='grpc-health-checking', - name_long='gRPC Health Checking', - destination_package='grpcio-health-checking')) + PackageMeta( + name="grpc-health-checking", + name_long="gRPC Health Checking", + destination_package="grpcio-health-checking", + ) + ) generate_package( - PackageMeta(name='grpc-csds', - name_long='gRPC Client Status Discovery Service', - destination_package='grpcio-csds')) + PackageMeta( + name="grpc-csds", + name_long="gRPC Client Status Discovery Service", + destination_package="grpcio-csds", + ) + ) generate_package( - PackageMeta(name='grpc-admin', - name_long='gRPC Admin Interface', - destination_package='grpcio-admin')) + PackageMeta( + name="grpc-admin", + name_long="gRPC Admin Interface", + destination_package="grpcio-admin", + ) + ) if __name__ == "__main__": diff --git a/tools/distrib/python/grpcio_tools/_parallel_compile_patch.py b/tools/distrib/python/grpcio_tools/_parallel_compile_patch.py index b259061e9170f..d5ac317ae5b38 100644 --- a/tools/distrib/python/grpcio_tools/_parallel_compile_patch.py +++ b/tools/distrib/python/grpcio_tools/_parallel_compile_patch.py @@ -22,26 +22,31 @@ try: BUILD_EXT_COMPILER_JOBS = int( - os.environ['GRPC_PYTHON_BUILD_EXT_COMPILER_JOBS']) + os.environ["GRPC_PYTHON_BUILD_EXT_COMPILER_JOBS"] + ) except KeyError: import multiprocessing + BUILD_EXT_COMPILER_JOBS = multiprocessing.cpu_count() # monkey-patch for parallel compilation -def _parallel_compile(self, - sources, - output_dir=None, - macros=None, - include_dirs=None, - debug=0, - extra_preargs=None, - extra_postargs=None, - depends=None): +def _parallel_compile( + self, + sources, + output_dir=None, + macros=None, + include_dirs=None, + debug=0, + extra_preargs=None, + extra_postargs=None, + depends=None, +): # setup the same way as distutils.ccompiler.CCompiler # https://github.com/python/cpython/blob/31368a4f0e531c19affe2a1becd25fc316bc7501/Lib/distutils/ccompiler.py#L564 macros, objects, extra_postargs, pp_opts, build = self._setup_compile( - output_dir, macros, include_dirs, sources, depends, extra_postargs) + output_dir, macros, include_dirs, sources, depends, extra_postargs + ) cc_args = self._get_cc_args(pp_opts, debug, extra_preargs) def _compile_single_file(obj): @@ -53,8 +58,10 @@ def _compile_single_file(obj): # run compilation of individual files in parallel import multiprocessing.pool + multiprocessing.pool.ThreadPool(BUILD_EXT_COMPILER_JOBS).map( - _compile_single_file, objects) + _compile_single_file, objects + ) return objects diff --git a/tools/distrib/python/grpcio_tools/grpc_tools/command.py b/tools/distrib/python/grpcio_tools/grpc_tools/command.py index 3fd16687f4b13..5b0b7eb6a79fa 100644 --- a/tools/distrib/python/grpcio_tools/grpc_tools/command.py +++ b/tools/distrib/python/grpcio_tools/grpc_tools/command.py @@ -25,35 +25,42 @@ def build_package_protos(package_root, strict_mode=False): inclusion_root = os.path.abspath(package_root) for root, _, files in 
os.walk(inclusion_root): for filename in files: - if filename.endswith('.proto'): - proto_files.append(os.path.abspath(os.path.join(root, - filename))) + if filename.endswith(".proto"): + proto_files.append( + os.path.abspath(os.path.join(root, filename)) + ) well_known_protos_include = pkg_resources.resource_filename( - 'grpc_tools', '_proto') + "grpc_tools", "_proto" + ) for proto_file in proto_files: command = [ - 'grpc_tools.protoc', - '--proto_path={}'.format(inclusion_root), - '--proto_path={}'.format(well_known_protos_include), - '--python_out={}'.format(inclusion_root), - '--pyi_out={}'.format(inclusion_root), - '--grpc_python_out={}'.format(inclusion_root), + "grpc_tools.protoc", + "--proto_path={}".format(inclusion_root), + "--proto_path={}".format(well_known_protos_include), + "--python_out={}".format(inclusion_root), + "--pyi_out={}".format(inclusion_root), + "--grpc_python_out={}".format(inclusion_root), ] + [proto_file] if protoc.main(command) != 0: if strict_mode: - raise Exception('error: {} failed'.format(command)) + raise Exception("error: {} failed".format(command)) else: - sys.stderr.write('warning: {} failed'.format(command)) + sys.stderr.write("warning: {} failed".format(command)) class BuildPackageProtos(setuptools.Command): """Command to generate project *_pb2.py modules from proto files.""" - description = 'build grpc protobuf modules' - user_options = [('strict-mode', 's', - 'exit with non-zero value if the proto compiling fails.')] + description = "build grpc protobuf modules" + user_options = [ + ( + "strict-mode", + "s", + "exit with non-zero value if the proto compiling fails.", + ) + ] def initialize_options(self): self.strict_mode = False @@ -66,5 +73,6 @@ def run(self): # directory is provided as an 'include' directory. We assume it's the '' key # to `self.distribution.package_dir` (and get a key error if it's not # there). - build_package_protos(self.distribution.package_dir[''], - self.strict_mode) + build_package_protos( + self.distribution.package_dir[""], self.strict_mode + ) diff --git a/tools/distrib/python/grpcio_tools/grpc_tools/protoc.py b/tools/distrib/python/grpcio_tools/grpc_tools/protoc.py index ce7b82487f768..10419ed1ca25d 100644 --- a/tools/distrib/python/grpcio_tools/grpc_tools/protoc.py +++ b/tools/distrib/python/grpcio_tools/grpc_tools/protoc.py @@ -29,10 +29,10 @@ def main(command_arguments): """Run the protocol buffer compiler with the given command-line arguments. - Args: - command_arguments: a list of strings representing command line arguments to - `protoc`. - """ + Args: + command_arguments: a list of strings representing command line arguments to + `protoc`. 
+ """ command_arguments = [argument.encode() for argument in command_arguments] return _protoc_compiler.run_main(command_arguments) @@ -52,19 +52,25 @@ def _maybe_install_proto_finders(): global _FINDERS_INSTALLED with _FINDERS_INSTALLED_LOCK: if not _FINDERS_INSTALLED: - sys.meta_path.extend([ - ProtoFinder(_PROTO_MODULE_SUFFIX, - _protoc_compiler.get_protos), - ProtoFinder(_SERVICE_MODULE_SUFFIX, - _protoc_compiler.get_services) - ]) + sys.meta_path.extend( + [ + ProtoFinder( + _PROTO_MODULE_SUFFIX, _protoc_compiler.get_protos + ), + ProtoFinder( + _SERVICE_MODULE_SUFFIX, + _protoc_compiler.get_services, + ), + ] + ) sys.path.append( - pkg_resources.resource_filename('grpc_tools', '_proto')) + pkg_resources.resource_filename("grpc_tools", "_proto") + ) _FINDERS_INSTALLED = True def _module_name_to_proto_file(suffix, module_name): components = module_name.split(".") - proto_name = components[-1][:-1 * len(suffix)] + proto_name = components[-1][: -1 * len(suffix)] # NOTE(rbellevi): The Protobuf library expects this path to use # forward slashes on every platform. return "/".join(components[:-1] + [proto_name + ".proto"]) @@ -77,8 +83,9 @@ def _proto_file_to_module_name(suffix, proto_file): def _protos(protobuf_path): """Returns a gRPC module generated from the indicated proto file.""" _maybe_install_proto_finders() - module_name = _proto_file_to_module_name(_PROTO_MODULE_SUFFIX, - protobuf_path) + module_name = _proto_file_to_module_name( + _PROTO_MODULE_SUFFIX, protobuf_path + ) module = importlib.import_module(module_name) return module @@ -86,8 +93,9 @@ def _services(protobuf_path): """Returns a module generated from the indicated proto file.""" _maybe_install_proto_finders() _protos(protobuf_path) - module_name = _proto_file_to_module_name(_SERVICE_MODULE_SUFFIX, - protobuf_path) + module_name = _proto_file_to_module_name( + _SERVICE_MODULE_SUFFIX, protobuf_path + ) module = importlib.import_module(module_name) return module @@ -99,9 +107,9 @@ def _protos_and_services(protobuf_path): _proto_code_cache_lock = threading.RLock() class ProtoLoader(importlib.abc.Loader): - - def __init__(self, suffix, codegen_fn, module_name, protobuf_path, - proto_root): + def __init__( + self, suffix, codegen_fn, module_name, protobuf_path, proto_root + ): self._suffix = suffix self._codegen_fn = codegen_fn self._module_name = module_name @@ -113,8 +121,9 @@ def create_module(self, spec): def _generated_file_to_module_name(self, filepath): components = filepath.split(os.path.sep) - return ".".join(components[:-1] + - [os.path.splitext(components[-1])[0]]) + return ".".join( + components[:-1] + [os.path.splitext(components[-1])[0]] + ) def exec_module(self, module): assert module.__name__ == self._module_name @@ -125,8 +134,9 @@ def exec_module(self, module): exec(code, module.__dict__) else: files = self._codegen_fn( - self._protobuf_path.encode('ascii'), - [path.encode('ascii') for path in sys.path]) + self._protobuf_path.encode("ascii"), + [path.encode("ascii") for path in sys.path], + ) # NOTE: The files are returned in topological order of dependencies. Each # entry is guaranteed to depend only on the modules preceding it in the # list and the last entry is guaranteed to be our requested module. We @@ -134,7 +144,8 @@ def exec_module(self, module): # don't have to regenerate code that has already been generated by protoc. 
for f in files[:-1]: module_name = self._generated_file_to_module_name( - f[0].decode('ascii')) + f[0].decode("ascii") + ) if module_name not in sys.modules: if module_name not in _proto_code_cache: _proto_code_cache[module_name] = f[1] @@ -142,7 +153,6 @@ def exec_module(self, module): exec(files[-1][1], module.__dict__) class ProtoFinder(importlib.abc.MetaPathFinder): - def __init__(self, suffix, codegen_fn): self._suffix = suffix self._codegen_fn = codegen_fn @@ -160,14 +170,20 @@ def find_spec(self, fullname, path, target=None): else: return importlib.machinery.ModuleSpec( fullname, - ProtoLoader(self._suffix, self._codegen_fn, fullname, - filepath, search_path)) + ProtoLoader( + self._suffix, + self._codegen_fn, + fullname, + filepath, + search_path, + ), + ) # NOTE(rbellevi): We provide an environment variable that enables users to completely # disable this behavior if it is not desired, e.g. for performance reasons. if not os.getenv(_DISABLE_DYNAMIC_STUBS): _maybe_install_proto_finders() -if __name__ == '__main__': - proto_include = pkg_resources.resource_filename('grpc_tools', '_proto') - sys.exit(main(sys.argv + ['-I{}'.format(proto_include)])) +if __name__ == "__main__": + proto_include = pkg_resources.resource_filename("grpc_tools", "_proto") + sys.exit(main(sys.argv + ["-I{}".format(proto_include)])) diff --git a/tools/distrib/python/grpcio_tools/grpc_tools/test/protoc_test.py b/tools/distrib/python/grpcio_tools/grpc_tools/test/protoc_test.py index 1b72568f15e82..de1a7a56d6fbf 100644 --- a/tools/distrib/python/grpcio_tools/grpc_tools/test/protoc_test.py +++ b/tools/distrib/python/grpcio_tools/grpc_tools/test/protoc_test.py @@ -27,7 +27,6 @@ # TODO(https://github.com/grpc/grpc/issues/23847): Deduplicate this mechanism with # the grpcio_tests module. 
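The ProtoFinder/ProtoLoader machinery above turns imports such as foo.bar_pb2 into on-the-fly protoc invocations; the key piece is a suffix-based mapping between module names and .proto paths. A standalone sketch mirroring _module_name_to_proto_file (the "_pb2" suffix used here is illustrative):

def module_name_to_proto_file(suffix, module_name):
    components = module_name.split(".")
    # Strip the generated-module suffix to recover the proto file stem.
    proto_name = components[-1][: -1 * len(suffix)]
    # The Protobuf library expects forward slashes on every platform.
    return "/".join(components[:-1] + [proto_name + ".proto"])

assert module_name_to_proto_file("_pb2", "foo.bar_pb2") == "foo/bar.proto"

The finders are only installed when the dynamic-stub behavior is not disabled via the environment variable checked at module import time, so the tests below exercise them in subprocesses.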
def _wrap_in_subprocess(error_queue, fn): - @functools.wraps(fn) def _wrapped(): try: @@ -48,7 +47,8 @@ def _run_in_subprocess(test_case): if not error_queue.empty(): raise error_queue.get() assert proc.exitcode == 0, "Process exited with code {}".format( - proc.exitcode) + proc.exitcode + ) @contextlib.contextmanager @@ -64,16 +64,20 @@ def _augmented_syspath(new_paths): def _test_import_protos(): from grpc_tools import protoc + with _augmented_syspath( - ("tools/distrib/python/grpcio_tools/grpc_tools/test/",)): + ("tools/distrib/python/grpcio_tools/grpc_tools/test/",) + ): protos = protoc._protos("simple.proto") assert protos.SimpleMessage is not None def _test_import_services(): from grpc_tools import protoc + with _augmented_syspath( - ("tools/distrib/python/grpcio_tools/grpc_tools/test/",)): + ("tools/distrib/python/grpcio_tools/grpc_tools/test/",) + ): protos = protoc._protos("simple.proto") services = protoc._services("simple.proto") assert services.SimpleMessageServiceStub is not None @@ -81,39 +85,50 @@ def _test_import_services(): def _test_import_services_without_protos(): from grpc_tools import protoc + with _augmented_syspath( - ("tools/distrib/python/grpcio_tools/grpc_tools/test/",)): + ("tools/distrib/python/grpcio_tools/grpc_tools/test/",) + ): services = protoc._services("simple.proto") assert services.SimpleMessageServiceStub is not None def _test_proto_module_imported_once(): from grpc_tools import protoc + with _augmented_syspath( - ("tools/distrib/python/grpcio_tools/grpc_tools/test/",)): + ("tools/distrib/python/grpcio_tools/grpc_tools/test/",) + ): protos = protoc._protos("simple.proto") services = protoc._services("simple.proto") complicated_protos = protoc._protos("complicated.proto") simple_message = protos.SimpleMessage() complicated_message = complicated_protos.ComplicatedMessage() - assert (simple_message.simpler_message.simplest_message.__class__ is - complicated_message.simplest_message.__class__) + assert ( + simple_message.simpler_message.simplest_message.__class__ + is complicated_message.simplest_message.__class__ + ) def _test_static_dynamic_combo(): with _augmented_syspath( - ("tools/distrib/python/grpcio_tools/grpc_tools/test/",)): + ("tools/distrib/python/grpcio_tools/grpc_tools/test/",) + ): from grpc_tools import protoc # isort:skip import complicated_pb2 + protos = protoc._protos("simple.proto") static_message = complicated_pb2.ComplicatedMessage() dynamic_message = protos.SimpleMessage() - assert (dynamic_message.simpler_message.simplest_message.__class__ is - static_message.simplest_message.__class__) + assert ( + dynamic_message.simpler_message.simplest_message.__class__ + is static_message.simplest_message.__class__ + ) def _test_combined_import(): from grpc_tools import protoc + protos, services = protoc._protos_and_services("simple.proto") assert protos.SimpleMessage is not None assert services.SimpleMessageServiceStub is not None @@ -121,6 +136,7 @@ def _test_combined_import(): def _test_syntax_errors(): from grpc_tools import protoc + try: protos = protoc._protos("flawed.proto") except Exception as e: @@ -133,7 +149,6 @@ def _test_syntax_errors(): class ProtocTest(unittest.TestCase): - def test_import_protos(self): _run_in_subprocess(_test_import_protos) @@ -156,5 +171,5 @@ def test_syntax_errors(self): _run_in_subprocess(_test_syntax_errors) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tools/distrib/python/grpcio_tools/setup.py b/tools/distrib/python/grpcio_tools/setup.py index 
3c4e7ce99c110..9b5f3d828a0e3 100644 --- a/tools/distrib/python/grpcio_tools/setup.py +++ b/tools/distrib/python/grpcio_tools/setup.py @@ -34,10 +34,10 @@ # TODO(atash) add flag to disable Cython use _PACKAGE_PATH = os.path.realpath(os.path.dirname(__file__)) -_README_PATH = os.path.join(_PACKAGE_PATH, 'README.rst') +_README_PATH = os.path.join(_PACKAGE_PATH, "README.rst") os.chdir(os.path.dirname(os.path.abspath(__file__))) -sys.path.insert(0, os.path.abspath('.')) +sys.path.insert(0, os.path.abspath(".")) import _parallel_compile_patch import protoc_lib_deps @@ -53,10 +53,10 @@ _parallel_compile_patch.monkeypatch_compile_maybe() CLASSIFIERS = [ - 'Development Status :: 5 - Production/Stable', - 'Programming Language :: Python', - 'Programming Language :: Python :: 3', - 'License :: OSI Approved :: Apache Software License', + "Development Status :: 5 - Production/Stable", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", ] PY3 = sys.version_info.major == 3 @@ -64,13 +64,13 @@ def _env_bool_value(env_name, default): """Parses a bool option from an environment variable""" - return os.environ.get(env_name, default).upper() not in ['FALSE', '0', ''] + return os.environ.get(env_name, default).upper() not in ["FALSE", "0", ""] # Environment variable to determine whether or not the Cython extension should # *use* Cython or use the generated C files. Note that this requires the C files # to have been generated by building first *with* Cython support. -BUILD_WITH_CYTHON = _env_bool_value('GRPC_PYTHON_BUILD_WITH_CYTHON', 'False') +BUILD_WITH_CYTHON = _env_bool_value("GRPC_PYTHON_BUILD_WITH_CYTHON", "False") # Export this variable to force building the python extension with a statically linked libstdc++. # At least on linux, this is normally not needed as we can build manylinux-compatible wheels on linux just fine @@ -80,28 +80,34 @@ def _env_bool_value(env_name, default): # of GCC (we require >=5.1) but still uses old-enough libstdc++ symbols. # TODO(jtattermusch): remove this workaround once issues with crosscompiler version are resolved. BUILD_WITH_STATIC_LIBSTDCXX = _env_bool_value( - 'GRPC_PYTHON_BUILD_WITH_STATIC_LIBSTDCXX', 'False') + "GRPC_PYTHON_BUILD_WITH_STATIC_LIBSTDCXX", "False" +) def check_linker_need_libatomic(): """Test if linker on system needs libatomic.""" - code_test = (b'#include \n' + - b'int main() { return std::atomic{}; }') - cxx = os.environ.get('CXX', 'c++') - cpp_test = subprocess.Popen([cxx, '-x', 'c++', '-std=c++14', '-'], - stdin=PIPE, - stdout=PIPE, - stderr=PIPE) + code_test = ( + b"#include \n" + + b"int main() { return std::atomic{}; }" + ) + cxx = os.environ.get("CXX", "c++") + cpp_test = subprocess.Popen( + [cxx, "-x", "c++", "-std=c++14", "-"], + stdin=PIPE, + stdout=PIPE, + stderr=PIPE, + ) cpp_test.communicate(input=code_test) if cpp_test.returncode == 0: return False # Double-check to see if -latomic actually can solve the problem. # https://github.com/grpc/grpc/issues/22491 cpp_test = subprocess.Popen( - [cxx, '-x', 'c++', '-std=c++14', '-', '-latomic'], + [cxx, "-x", "c++", "-std=c++14", "-", "-latomic"], stdin=PIPE, stdout=PIPE, - stderr=PIPE) + stderr=PIPE, + ) cpp_test.communicate(input=code_test) return cpp_test.returncode == 0 @@ -117,10 +123,10 @@ def get_ext_filename(self, ext_name): # so that the resulting file name matches the target architecture and we end up with a well-formed # wheel. 
filename = build_ext.build_ext.get_ext_filename(self, ext_name) - orig_ext_suffix = sysconfig.get_config_var('EXT_SUFFIX') - new_ext_suffix = os.getenv('GRPC_PYTHON_OVERRIDE_EXT_SUFFIX') + orig_ext_suffix = sysconfig.get_config_var("EXT_SUFFIX") + new_ext_suffix = os.getenv("GRPC_PYTHON_OVERRIDE_EXT_SUFFIX") if new_ext_suffix and filename.endswith(orig_ext_suffix): - filename = filename[:-len(orig_ext_suffix)] + new_ext_suffix + filename = filename[: -len(orig_ext_suffix)] + new_ext_suffix return filename @@ -131,28 +137,33 @@ def get_ext_filename(self, ext_name): # We can also use these variables as a way to inject environment-specific # compiler/linker flags. We assume GCC-like compilers and/or MinGW as a # reasonable default. -EXTRA_ENV_COMPILE_ARGS = os.environ.get('GRPC_PYTHON_CFLAGS', None) -EXTRA_ENV_LINK_ARGS = os.environ.get('GRPC_PYTHON_LDFLAGS', None) +EXTRA_ENV_COMPILE_ARGS = os.environ.get("GRPC_PYTHON_CFLAGS", None) +EXTRA_ENV_LINK_ARGS = os.environ.get("GRPC_PYTHON_LDFLAGS", None) if EXTRA_ENV_COMPILE_ARGS is None: - EXTRA_ENV_COMPILE_ARGS = '-std=c++14' - if 'win32' in sys.platform: + EXTRA_ENV_COMPILE_ARGS = "-std=c++14" + if "win32" in sys.platform: if sys.version_info < (3, 5): # We use define flags here and don't directly add to DEFINE_MACROS below to # ensure that the expert user/builder has a way of turning it off (via the # envvars) without adding yet more GRPC-specific envvars. # See https://sourceforge.net/p/mingw-w64/bugs/363/ - if '32' in platform.architecture()[0]: - EXTRA_ENV_COMPILE_ARGS += ' -D_ftime=_ftime32 -D_timeb=__timeb32 -D_ftime_s=_ftime32_s -D_hypot=hypot' + if "32" in platform.architecture()[0]: + EXTRA_ENV_COMPILE_ARGS += ( + " -D_ftime=_ftime32 -D_timeb=__timeb32" + " -D_ftime_s=_ftime32_s -D_hypot=hypot" + ) else: - EXTRA_ENV_COMPILE_ARGS += ' -D_ftime=_ftime64 -D_timeb=__timeb64 -D_hypot=hypot' + EXTRA_ENV_COMPILE_ARGS += ( + " -D_ftime=_ftime64 -D_timeb=__timeb64 -D_hypot=hypot" + ) else: # We need to statically link the C++ Runtime, only the C runtime is # available dynamically - EXTRA_ENV_COMPILE_ARGS += ' /MT' + EXTRA_ENV_COMPILE_ARGS += " /MT" elif "linux" in sys.platform or "darwin" in sys.platform: - EXTRA_ENV_COMPILE_ARGS += ' -fno-wrapv -frtti' + EXTRA_ENV_COMPILE_ARGS += " -fno-wrapv -frtti" if EXTRA_ENV_LINK_ARGS is None: - EXTRA_ENV_LINK_ARGS = '' + EXTRA_ENV_LINK_ARGS = "" # NOTE(rbellevi): Clang on Mac OS will make all static symbols (both # variables and objects) global weak symbols. When a process loads the # protobuf wheel's shared object library before loading *this* C extension, @@ -173,23 +184,25 @@ def get_ext_filename(self, ext_name): # more modern ABIs (ELF et al.), Mach-O prepends an underscore to the names # of C functions. 
if "darwin" in sys.platform: - EXTRA_ENV_LINK_ARGS += ' -Wl,-exported_symbol,_{}'.format( - _EXT_INIT_SYMBOL) + EXTRA_ENV_LINK_ARGS += " -Wl,-exported_symbol,_{}".format( + _EXT_INIT_SYMBOL + ) if "linux" in sys.platform or "darwin" in sys.platform: - EXTRA_ENV_LINK_ARGS += ' -lpthread' + EXTRA_ENV_LINK_ARGS += " -lpthread" if check_linker_need_libatomic(): - EXTRA_ENV_LINK_ARGS += ' -latomic' + EXTRA_ENV_LINK_ARGS += " -latomic" elif "win32" in sys.platform and sys.version_info < (3, 5): msvcr = cygwinccompiler.get_msvcr()[0] EXTRA_ENV_LINK_ARGS += ( - ' -static-libgcc -static-libstdc++ -mcrtdll={msvcr}' - ' -static -lshlwapi'.format(msvcr=msvcr)) + " -static-libgcc -static-libstdc++ -mcrtdll={msvcr}" + " -static -lshlwapi".format(msvcr=msvcr) + ) EXTRA_COMPILE_ARGS = shlex.split(EXTRA_ENV_COMPILE_ARGS) EXTRA_LINK_ARGS = shlex.split(EXTRA_ENV_LINK_ARGS) if BUILD_WITH_STATIC_LIBSTDCXX: - EXTRA_LINK_ARGS.append('-static-libstdc++') + EXTRA_LINK_ARGS.append("-static-libstdc++") CC_FILES = [os.path.normpath(cc_file) for cc_file in protoc_lib_deps.CC_FILES] PROTO_FILES = [ @@ -200,44 +213,50 @@ def get_ext_filename(self, ext_name): ] PROTO_INCLUDE = os.path.normpath(protoc_lib_deps.PROTO_INCLUDE) -GRPC_PYTHON_TOOLS_PACKAGE = 'grpc_tools' -GRPC_PYTHON_PROTO_RESOURCES_NAME = '_proto' +GRPC_PYTHON_TOOLS_PACKAGE = "grpc_tools" +GRPC_PYTHON_PROTO_RESOURCES_NAME = "_proto" DEFINE_MACROS = () if "win32" in sys.platform: DEFINE_MACROS += ( - ('WIN32_LEAN_AND_MEAN', 1), + ("WIN32_LEAN_AND_MEAN", 1), # avoid https://github.com/abseil/abseil-cpp/issues/1425 - ('NOMINMAX', 1), + ("NOMINMAX", 1), ) - if '64bit' in platform.architecture()[0]: - DEFINE_MACROS += (('MS_WIN64', 1),) + if "64bit" in platform.architecture()[0]: + DEFINE_MACROS += (("MS_WIN64", 1),) elif "linux" in sys.platform or "darwin" in sys.platform: - DEFINE_MACROS += (('HAVE_PTHREAD', 1),) + DEFINE_MACROS += (("HAVE_PTHREAD", 1),) # By default, Python3 distutils enforces compatibility of # c plugins (.so files) with the OSX version Python was built with. # We need OSX 10.10, the oldest which supports C++ thread_local. 
-if 'darwin' in sys.platform: - mac_target = sysconfig.get_config_var('MACOSX_DEPLOYMENT_TARGET') - if mac_target and (pkg_resources.parse_version(mac_target) < - pkg_resources.parse_version('10.10.0')): - os.environ['MACOSX_DEPLOYMENT_TARGET'] = '10.10' - os.environ['_PYTHON_HOST_PLATFORM'] = re.sub( - r'macosx-[0-9]+\.[0-9]+-(.+)', r'macosx-10.10-\1', - util.get_platform()) +if "darwin" in sys.platform: + mac_target = sysconfig.get_config_var("MACOSX_DEPLOYMENT_TARGET") + if mac_target and ( + pkg_resources.parse_version(mac_target) + < pkg_resources.parse_version("10.10.0") + ): + os.environ["MACOSX_DEPLOYMENT_TARGET"] = "10.10" + os.environ["_PYTHON_HOST_PLATFORM"] = re.sub( + r"macosx-[0-9]+\.[0-9]+-(.+)", + r"macosx-10.10-\1", + util.get_platform(), + ) def package_data(): - tools_path = GRPC_PYTHON_TOOLS_PACKAGE.replace('.', os.path.sep) - proto_resources_path = os.path.join(tools_path, - GRPC_PYTHON_PROTO_RESOURCES_NAME) + tools_path = GRPC_PYTHON_TOOLS_PACKAGE.replace(".", os.path.sep) + proto_resources_path = os.path.join( + tools_path, GRPC_PYTHON_PROTO_RESOURCES_NAME + ) proto_files = [] for proto_file in PROTO_FILES: source = os.path.join(PROTO_INCLUDE, proto_file) target = os.path.join(proto_resources_path, proto_file) - relative_target = os.path.join(GRPC_PYTHON_PROTO_RESOURCES_NAME, - proto_file) + relative_target = os.path.join( + GRPC_PYTHON_PROTO_RESOURCES_NAME, proto_file + ) try: os.makedirs(os.path.dirname(target)) except OSError as error: @@ -252,25 +271,26 @@ def package_data(): def extension_modules(): if BUILD_WITH_CYTHON: - plugin_sources = [os.path.join('grpc_tools', '_protoc_compiler.pyx')] + plugin_sources = [os.path.join("grpc_tools", "_protoc_compiler.pyx")] else: - plugin_sources = [os.path.join('grpc_tools', '_protoc_compiler.cpp')] + plugin_sources = [os.path.join("grpc_tools", "_protoc_compiler.cpp")] plugin_sources += [ - os.path.join('grpc_tools', 'main.cc'), - os.path.join('grpc_root', 'src', 'compiler', 'python_generator.cc'), - os.path.join('grpc_root', 'src', 'compiler', 'proto_parser_helper.cc') + os.path.join("grpc_tools", "main.cc"), + os.path.join("grpc_root", "src", "compiler", "python_generator.cc"), + os.path.join("grpc_root", "src", "compiler", "proto_parser_helper.cc"), ] + CC_FILES plugin_ext = extension.Extension( - name='grpc_tools._protoc_compiler', + name="grpc_tools._protoc_compiler", sources=plugin_sources, include_dirs=[ - '.', - 'grpc_root', - os.path.join('grpc_root', 'include'), - ] + CC_INCLUDES, - language='c++', + ".", + "grpc_root", + os.path.join("grpc_root", "include"), + ] + + CC_INCLUDES, + language="c++", define_macros=list(DEFINE_MACROS), extra_compile_args=list(EXTRA_COMPILE_ARGS), extra_link_args=list(EXTRA_LINK_ARGS), @@ -278,37 +298,37 @@ def extension_modules(): extensions = [plugin_ext] if BUILD_WITH_CYTHON: from Cython import Build + return Build.cythonize(extensions) else: return extensions setuptools.setup( - name='grpcio-tools', + name="grpcio-tools", version=grpc_version.VERSION, - description='Protobuf code generator for gRPC', - long_description_content_type='text/x-rst', - long_description=open(_README_PATH, 'r').read(), - author='The gRPC Authors', - author_email='grpc-io@googlegroups.com', - url='https://grpc.io', + description="Protobuf code generator for gRPC", + long_description_content_type="text/x-rst", + long_description=open(_README_PATH, "r").read(), + author="The gRPC Authors", + author_email="grpc-io@googlegroups.com", + url="https://grpc.io", project_urls={ - "Source Code": - 
"https://github.com/grpc/grpc/tree/master/tools/distrib/python/grpcio_tools", - "Bug Tracker": - "https://github.com/grpc/grpc/issues", + "Source Code": "https://github.com/grpc/grpc/tree/master/tools/distrib/python/grpcio_tools", + "Bug Tracker": "https://github.com/grpc/grpc/issues", }, - license='Apache License 2.0', + license="Apache License 2.0", classifiers=CLASSIFIERS, ext_modules=extension_modules(), - packages=setuptools.find_packages('.'), - python_requires='>=3.7', + packages=setuptools.find_packages("."), + python_requires=">=3.7", install_requires=[ - 'protobuf>=4.21.6,<5.0dev', - 'grpcio>={version}'.format(version=grpc_version.VERSION), - 'setuptools', + "protobuf>=4.21.6,<5.0dev", + "grpcio>={version}".format(version=grpc_version.VERSION), + "setuptools", ], package_data=package_data(), cmdclass={ - 'build_ext': BuildExt, - }) + "build_ext": BuildExt, + }, +) diff --git a/tools/distrib/python/make_grpcio_tools.py b/tools/distrib/python/make_grpcio_tools.py index 34132723d789f..234a178d82abd 100755 --- a/tools/distrib/python/make_grpcio_tools.py +++ b/tools/distrib/python/make_grpcio_tools.py @@ -60,69 +60,77 @@ COMMIT_HASH_SUFFIX = '"' EXTERNAL_LINKS = [ - ('@com_google_absl//', 'third_party/abseil-cpp/'), - ('@com_google_protobuf//', 'third_party/protobuf/'), - ('@utf8_range//:', 'third_party/utf8_range/'), + ("@com_google_absl//", "third_party/abseil-cpp/"), + ("@com_google_protobuf//", "third_party/protobuf/"), + ("@utf8_range//:", "third_party/utf8_range/"), ] -PROTOBUF_PROTO_PREFIX = '@com_google_protobuf//src/' +PROTOBUF_PROTO_PREFIX = "@com_google_protobuf//src/" # will be added to include path when building grpcio_tools CC_INCLUDES = [ - os.path.join('third_party', 'abseil-cpp'), - os.path.join('third_party', 'protobuf', 'src'), - os.path.join('third_party', 'utf8_range'), + os.path.join("third_party", "abseil-cpp"), + os.path.join("third_party", "protobuf", "src"), + os.path.join("third_party", "utf8_range"), ] # include path for .proto files -PROTO_INCLUDE = os.path.join('third_party', 'protobuf', 'src') +PROTO_INCLUDE = os.path.join("third_party", "protobuf", "src") # the target directory is relative to the grpcio_tools package root. -GRPCIO_TOOLS_ROOT_PREFIX = 'tools/distrib/python/grpcio_tools/' +GRPCIO_TOOLS_ROOT_PREFIX = "tools/distrib/python/grpcio_tools/" # Pairs of (source, target) directories to copy # from the grpc repo root to the grpcio_tools build root. 
COPY_FILES_SOURCE_TARGET_PAIRS = [ - ('include', 'grpc_root/include'), - ('src/compiler', 'grpc_root/src/compiler'), - ('third_party/abseil-cpp/absl', 'third_party/abseil-cpp/absl'), - ('third_party/protobuf/src', 'third_party/protobuf/src'), - ('third_party/utf8_range', 'third_party/utf8_range') + ("include", "grpc_root/include"), + ("src/compiler", "grpc_root/src/compiler"), + ("third_party/abseil-cpp/absl", "third_party/abseil-cpp/absl"), + ("third_party/protobuf/src", "third_party/protobuf/src"), + ("third_party/utf8_range", "third_party/utf8_range"), ] # grpc repo root GRPC_ROOT = os.path.abspath( - os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..', '..')) + os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "..") +) # the directory under which to probe for the current protobuf commit SHA -GRPC_PROTOBUF_SUBMODULE_ROOT = os.path.join(GRPC_ROOT, 'third_party', - 'protobuf') +GRPC_PROTOBUF_SUBMODULE_ROOT = os.path.join( + GRPC_ROOT, "third_party", "protobuf" +) # the file to generate -GRPC_PYTHON_PROTOC_LIB_DEPS = os.path.join(GRPC_ROOT, 'tools', 'distrib', - 'python', 'grpcio_tools', - 'protoc_lib_deps.py') +GRPC_PYTHON_PROTOC_LIB_DEPS = os.path.join( + GRPC_ROOT, + "tools", + "distrib", + "python", + "grpcio_tools", + "protoc_lib_deps.py", +) # the script to run for getting dependencies -BAZEL_DEPS = os.path.join(GRPC_ROOT, 'tools', 'distrib', 'python', - 'bazel_deps.sh') +BAZEL_DEPS = os.path.join( + GRPC_ROOT, "tools", "distrib", "python", "bazel_deps.sh" +) # the bazel target to scrape to get list of sources for the build -BAZEL_DEPS_PROTOC_LIB_QUERY = '@com_google_protobuf//:protoc_lib' +BAZEL_DEPS_PROTOC_LIB_QUERY = "@com_google_protobuf//:protoc_lib" BAZEL_DEPS_COMMON_PROTOS_QUERIES = [ - '@com_google_protobuf//:well_known_type_protos', + "@com_google_protobuf//:well_known_type_protos", # has both plugin.proto and descriptor.proto - '@com_google_protobuf//:compiler_plugin_proto', + "@com_google_protobuf//:compiler_plugin_proto", ] def protobuf_submodule_commit_hash(): """Gets the commit hash for the HEAD of the protobuf submodule currently - checked out.""" + checked out.""" cwd = os.getcwd() os.chdir(GRPC_PROTOBUF_SUBMODULE_ROOT) - output = subprocess.check_output(['git', 'rev-parse', 'HEAD']) + output = subprocess.check_output(["git", "rev-parse", "HEAD"]) os.chdir(cwd) return output.decode("ascii").splitlines()[0].strip() @@ -138,11 +146,11 @@ def _pretty_print_list(items): """Pretty print python list""" formatted = pprint.pformat(items, indent=4) # add newline after opening bracket (and fix indent of the next line) - if formatted.startswith('['): - formatted = formatted[0] + '\n ' + formatted[1:] + if formatted.startswith("["): + formatted = formatted[0] + "\n " + formatted[1:] # add newline before closing bracket - if formatted.endswith(']'): - formatted = formatted[:-1] + '\n' + formatted[-1] + if formatted.endswith("]"): + formatted = formatted[:-1] + "\n" + formatted[-1] return formatted @@ -150,13 +158,13 @@ def _bazel_name_to_file_path(name): """Transform bazel reference to source file name.""" for link in EXTERNAL_LINKS: if name.startswith(link[0]): - filepath = link[1] + name[len(link[0]):].replace(':', '/') + filepath = link[1] + name[len(link[0]) :].replace(":", "/") # For some reason, the WKT sources (such as wrappers.pb.cc) # end up being reported by bazel as having an extra 'wkt/google/protobuf' # in path. Removing it makes the compilation pass. # TODO(jtattermusch) Get rid of this hack. 
- return filepath.replace('wkt/google/protobuf/', '') + return filepath.replace("wkt/google/protobuf/", "") return None @@ -167,7 +175,7 @@ def _generate_deps_file_content(): # Collect .cc files (that will be later included in the native extension build) cc_files = [] for name in cc_files_output: - if name.endswith('.cc'): + if name.endswith(".cc"): filepath = _bazel_name_to_file_path(name) if filepath: cc_files.append(filepath) @@ -177,9 +185,9 @@ def _generate_deps_file_content(): for target in BAZEL_DEPS_COMMON_PROTOS_QUERIES: raw_proto_files += _bazel_query(target) proto_files = [ - name[len(PROTOBUF_PROTO_PREFIX):].replace(':', '/') + name[len(PROTOBUF_PROTO_PREFIX) :].replace(":", "/") for name in raw_proto_files - if name.endswith('.proto') and name.startswith(PROTOBUF_PROTO_PREFIX) + if name.endswith(".proto") and name.startswith(PROTOBUF_PROTO_PREFIX) ] commit_hash = protobuf_submodule_commit_hash() @@ -190,19 +198,21 @@ def _generate_deps_file_content(): proto_files=_pretty_print_list(sorted(set(proto_files))), cc_includes=_pretty_print_list(CC_INCLUDES), proto_include=repr(PROTO_INCLUDE), - commit_hash_expr=commit_hash_expr) + commit_hash_expr=commit_hash_expr, + ) return deps_file_content def _copy_source_tree(source, target): """Copies source directory to a given target directory.""" - print('Copying contents of %s to %s' % (source, target)) + print("Copying contents of %s to %s" % (source, target)) # TODO(jtattermusch): It is unclear why this legacy code needs to copy # the source directory to the target via the following boilerplate. # Should this code be simplified? for source_dir, _, files in os.walk(source): target_dir = os.path.abspath( - os.path.join(target, os.path.relpath(source_dir, source))) + os.path.join(target, os.path.relpath(source_dir, source)) + ) try: os.makedirs(target_dir) except OSError as error: @@ -210,9 +220,11 @@ def _copy_source_tree(source, target): raise for relative_file in files: source_file = os.path.abspath( - os.path.join(source_dir, relative_file)) + os.path.join(source_dir, relative_file) + ) target_file = os.path.abspath( - os.path.join(target_dir, relative_file)) + os.path.join(target_dir, relative_file) + ) shutil.copyfile(source_file, target_file) @@ -226,14 +238,15 @@ def main(): for source, target in COPY_FILES_SOURCE_TARGET_PAIRS: # convert the slashes in the relative path to platform-specific path dividers. # All paths are relative to GRPC_ROOT - source_abs = os.path.join(GRPC_ROOT, os.path.join(*source.split('/'))) + source_abs = os.path.join(GRPC_ROOT, os.path.join(*source.split("/"))) # for targets, add grpcio_tools root prefix target = GRPCIO_TOOLS_ROOT_PREFIX + target - target_abs = os.path.join(GRPC_ROOT, os.path.join(*target.split('/'))) + target_abs = os.path.join(GRPC_ROOT, os.path.join(*target.split("/"))) _copy_source_tree(source_abs, target_abs) print( - 'The necessary source files were copied under the grpcio_tools package root.' + "The necessary source files were copied under the grpcio_tools package" + " root." ) print() @@ -253,11 +266,11 @@ def main(): traceback.print_exc(file=sys.stderr) return # If we successfully got the dependencies, truncate and rewrite the deps file. - with open(GRPC_PYTHON_PROTOC_LIB_DEPS, 'w') as deps_file: + with open(GRPC_PYTHON_PROTOC_LIB_DEPS, "w") as deps_file: deps_file.write(protoc_lib_deps_content) print('File "%s" updated.' 
% GRPC_PYTHON_PROTOC_LIB_DEPS) - print('Done.') + print("Done.") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tools/distrib/python/xds_protos/build.py b/tools/distrib/python/xds_protos/build.py index 6434272bc7774..6cbfc0fcc1526 100644 --- a/tools/distrib/python/xds_protos/build.py +++ b/tools/distrib/python/xds_protos/build.py @@ -22,89 +22,106 @@ # We might not want to compile all the protos EXCLUDE_PROTO_PACKAGES_LIST = [ # Requires extra dependency to Prometheus protos - 'envoy/service/metrics/v2', - 'envoy/service/metrics/v3', - 'envoy/service/metrics/v4alpha', + "envoy/service/metrics/v2", + "envoy/service/metrics/v3", + "envoy/service/metrics/v4alpha", ] # Compute the pathes WORK_DIR = os.path.dirname(os.path.abspath(__file__)) -GRPC_ROOT = os.path.abspath(os.path.join(WORK_DIR, '..', '..', '..', '..')) -XDS_PROTO_ROOT = os.path.join(GRPC_ROOT, 'third_party', 'envoy-api') -UDPA_PROTO_ROOT = os.path.join(GRPC_ROOT, 'third_party', 'udpa') -GOOGLEAPIS_ROOT = os.path.join(GRPC_ROOT, 'third_party', 'googleapis') -VALIDATE_ROOT = os.path.join(GRPC_ROOT, 'third_party', 'protoc-gen-validate') -OPENCENSUS_PROTO_ROOT = os.path.join(GRPC_ROOT, 'third_party', - 'opencensus-proto', 'src') -OPENTELEMETRY_PROTO_ROOT = os.path.join(GRPC_ROOT, 'third_party', - 'opentelemetry') +GRPC_ROOT = os.path.abspath(os.path.join(WORK_DIR, "..", "..", "..", "..")) +XDS_PROTO_ROOT = os.path.join(GRPC_ROOT, "third_party", "envoy-api") +UDPA_PROTO_ROOT = os.path.join(GRPC_ROOT, "third_party", "udpa") +GOOGLEAPIS_ROOT = os.path.join(GRPC_ROOT, "third_party", "googleapis") +VALIDATE_ROOT = os.path.join(GRPC_ROOT, "third_party", "protoc-gen-validate") +OPENCENSUS_PROTO_ROOT = os.path.join( + GRPC_ROOT, "third_party", "opencensus-proto", "src" +) +OPENTELEMETRY_PROTO_ROOT = os.path.join( + GRPC_ROOT, "third_party", "opentelemetry" +) WELL_KNOWN_PROTOS_INCLUDE = pkg_resources.resource_filename( - 'grpc_tools', '_proto') + "grpc_tools", "_proto" +) OUTPUT_PATH = WORK_DIR # Prepare the test file generation -TEST_FILE_NAME = 'generated_file_import_test.py' +TEST_FILE_NAME = "generated_file_import_test.py" TEST_IMPORTS = [] # The pkgutil-style namespace packaging __init__.py -PKGUTIL_STYLE_INIT = "__path__ = __import__('pkgutil').extend_path(__path__, __name__)\n" +PKGUTIL_STYLE_INIT = ( + "__path__ = __import__('pkgutil').extend_path(__path__, __name__)\n" +) NAMESPACE_PACKAGES = ["google"] -def add_test_import(proto_package_path: str, - file_name: str, - service: bool = False): - TEST_IMPORTS.append("from %s import %s\n" % (proto_package_path.replace( - '/', '.'), file_name.replace('.proto', '_pb2'))) +def add_test_import( + proto_package_path: str, file_name: str, service: bool = False +): + TEST_IMPORTS.append( + "from %s import %s\n" + % ( + proto_package_path.replace("/", "."), + file_name.replace(".proto", "_pb2"), + ) + ) if service: - TEST_IMPORTS.append("from %s import %s\n" % (proto_package_path.replace( - '/', '.'), file_name.replace('.proto', '_pb2_grpc'))) + TEST_IMPORTS.append( + "from %s import %s\n" + % ( + proto_package_path.replace("/", "."), + file_name.replace(".proto", "_pb2_grpc"), + ) + ) # Prepare Protoc command COMPILE_PROTO_ONLY = [ - 'grpc_tools.protoc', - '--proto_path={}'.format(XDS_PROTO_ROOT), - '--proto_path={}'.format(UDPA_PROTO_ROOT), - '--proto_path={}'.format(GOOGLEAPIS_ROOT), - '--proto_path={}'.format(VALIDATE_ROOT), - '--proto_path={}'.format(WELL_KNOWN_PROTOS_INCLUDE), - '--proto_path={}'.format(OPENCENSUS_PROTO_ROOT), - 
'--proto_path={}'.format(OPENTELEMETRY_PROTO_ROOT), - '--python_out={}'.format(OUTPUT_PATH), + "grpc_tools.protoc", + "--proto_path={}".format(XDS_PROTO_ROOT), + "--proto_path={}".format(UDPA_PROTO_ROOT), + "--proto_path={}".format(GOOGLEAPIS_ROOT), + "--proto_path={}".format(VALIDATE_ROOT), + "--proto_path={}".format(WELL_KNOWN_PROTOS_INCLUDE), + "--proto_path={}".format(OPENCENSUS_PROTO_ROOT), + "--proto_path={}".format(OPENTELEMETRY_PROTO_ROOT), + "--python_out={}".format(OUTPUT_PATH), ] -COMPILE_BOTH = COMPILE_PROTO_ONLY + ['--grpc_python_out={}'.format(OUTPUT_PATH)] +COMPILE_BOTH = COMPILE_PROTO_ONLY + ["--grpc_python_out={}".format(OUTPUT_PATH)] def has_grpc_service(proto_package_path: str) -> bool: - return proto_package_path.startswith('envoy/service') + return proto_package_path.startswith("envoy/service") -def compile_protos(proto_root: str, sub_dir: str = '.') -> None: +def compile_protos(proto_root: str, sub_dir: str = ".") -> None: for root, _, files in os.walk(os.path.join(proto_root, sub_dir)): proto_package_path = os.path.relpath(root, proto_root) if proto_package_path in EXCLUDE_PROTO_PACKAGES_LIST: - print(f'Skipping package {proto_package_path}') + print(f"Skipping package {proto_package_path}") continue for file_name in files: - if file_name.endswith('.proto'): + if file_name.endswith(".proto"): # Compile proto if has_grpc_service(proto_package_path): - return_code = protoc.main(COMPILE_BOTH + - [os.path.join(root, file_name)]) + return_code = protoc.main( + COMPILE_BOTH + [os.path.join(root, file_name)] + ) add_test_import(proto_package_path, file_name, service=True) else: - return_code = protoc.main(COMPILE_PROTO_ONLY + - [os.path.join(root, file_name)]) - add_test_import(proto_package_path, - file_name, - service=False) + return_code = protoc.main( + COMPILE_PROTO_ONLY + [os.path.join(root, file_name)] + ) + add_test_import( + proto_package_path, file_name, service=False + ) if return_code != 0: - raise Exception('error: {} failed'.format(COMPILE_BOTH)) + raise Exception("error: {} failed".format(COMPILE_BOTH)) def create_init_file(path: str, package_path: str = "") -> None: - with open(os.path.join(path, "__init__.py"), 'w') as f: + with open(os.path.join(path, "__init__.py"), "w") as f: # Apply the pkgutil-style namespace packaging, which is compatible for 2 # and 3. 
Here is the full table of namespace compatibility: # https://github.com/pypa/sample-namespace-packages/blob/master/table.md @@ -117,27 +134,32 @@ def main(): compile_protos(XDS_PROTO_ROOT) compile_protos(UDPA_PROTO_ROOT) # We don't want to compile the entire GCP surface API, just the essential ones - compile_protos(GOOGLEAPIS_ROOT, os.path.join('google', 'api')) - compile_protos(GOOGLEAPIS_ROOT, os.path.join('google', 'rpc')) - compile_protos(GOOGLEAPIS_ROOT, os.path.join('google', 'longrunning')) - compile_protos(GOOGLEAPIS_ROOT, os.path.join('google', 'logging')) - compile_protos(GOOGLEAPIS_ROOT, os.path.join('google', 'type')) - compile_protos(VALIDATE_ROOT, 'validate') + compile_protos(GOOGLEAPIS_ROOT, os.path.join("google", "api")) + compile_protos(GOOGLEAPIS_ROOT, os.path.join("google", "rpc")) + compile_protos(GOOGLEAPIS_ROOT, os.path.join("google", "longrunning")) + compile_protos(GOOGLEAPIS_ROOT, os.path.join("google", "logging")) + compile_protos(GOOGLEAPIS_ROOT, os.path.join("google", "type")) + compile_protos(VALIDATE_ROOT, "validate") compile_protos(OPENCENSUS_PROTO_ROOT) compile_protos(OPENTELEMETRY_PROTO_ROOT) # Generate __init__.py files for all modules create_init_file(WORK_DIR) for proto_root_module in [ - 'envoy', 'google', 'opencensus', 'udpa', 'validate', 'xds', - 'opentelemetry' + "envoy", + "google", + "opencensus", + "udpa", + "validate", + "xds", + "opentelemetry", ]: for root, _, _ in os.walk(os.path.join(WORK_DIR, proto_root_module)): package_path = os.path.relpath(root, WORK_DIR) create_init_file(root, package_path) # Generate test file - with open(os.path.join(WORK_DIR, TEST_FILE_NAME), 'w') as f: + with open(os.path.join(WORK_DIR, TEST_FILE_NAME), "w") as f: f.writelines(TEST_IMPORTS) diff --git a/tools/distrib/python/xds_protos/setup.py b/tools/distrib/python/xds_protos/setup.py index 3459104096325..c4542556bd5b5 100644 --- a/tools/distrib/python/xds_protos/setup.py +++ b/tools/distrib/python/xds_protos/setup.py @@ -19,35 +19,36 @@ import setuptools WORK_DIR = os.path.dirname(os.path.abspath(__file__)) -EXCLUDE_PYTHON_FILES = ['generated_file_import_test.py', 'build.py'] +EXCLUDE_PYTHON_FILES = ["generated_file_import_test.py", "build.py"] # Use setuptools to build Python package -with open(os.path.join(WORK_DIR, 'README.rst'), 'r') as f: +with open(os.path.join(WORK_DIR, "README.rst"), "r") as f: LONG_DESCRIPTION = f.read() PACKAGES = setuptools.find_packages(where=".", exclude=EXCLUDE_PYTHON_FILES) CLASSIFIERS = [ - 'Development Status :: 3 - Alpha', - 'Programming Language :: Python', - 'Programming Language :: Python :: 3', - 'License :: OSI Approved :: Apache Software License', + "Development Status :: 3 - Alpha", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", ] INSTALL_REQUIRES = [ - 'grpcio>=1.49.0', - 'protobuf>=4.21.6,<5.0dev', + "grpcio>=1.49.0", + "protobuf>=4.21.6,<5.0dev", ] -SETUP_REQUIRES = INSTALL_REQUIRES + ['grpcio-tools'] +SETUP_REQUIRES = INSTALL_REQUIRES + ["grpcio-tools"] setuptools.setup( - name='xds-protos', - version='0.0.12', + name="xds-protos", + version="0.0.12", packages=PACKAGES, - description='Generated Python code from envoyproxy/data-plane-api', - long_description_content_type='text/x-rst', + description="Generated Python code from envoyproxy/data-plane-api", + long_description_content_type="text/x-rst", long_description=LONG_DESCRIPTION, - author='The gRPC Authors', - author_email='grpc-io@googlegroups.com', - url='https://grpc.io', - 
license='Apache License 2.0', - python_requires='>=3.7', + author="The gRPC Authors", + author_email="grpc-io@googlegroups.com", + url="https://grpc.io", + license="Apache License 2.0", + python_requires=">=3.7", install_requires=INSTALL_REQUIRES, setup_requires=SETUP_REQUIRES, - classifiers=CLASSIFIERS) + classifiers=CLASSIFIERS, +) diff --git a/tools/distrib/run_buildozer.py b/tools/distrib/run_buildozer.py index 5f6e01c6d305b..baaa16626b92e 100644 --- a/tools/distrib/run_buildozer.py +++ b/tools/distrib/run_buildozer.py @@ -22,9 +22,9 @@ def run_buildozer(buildozer_commands): return ok_statuses = (0, 3) temp = tempfile.NamedTemporaryFile() - open(temp.name, 'w').write('\n'.join(buildozer_commands)) - c = ['tools/distrib/buildozer.sh', '-f', temp.name] + open(temp.name, "w").write("\n".join(buildozer_commands)) + c = ["tools/distrib/buildozer.sh", "-f", temp.name] r = subprocess.call(c) if r not in ok_statuses: - print('{} failed with status {}'.format(c, r)) + print("{} failed with status {}".format(c, r)) sys.exit(1) diff --git a/tools/distrib/run_clang_tidy.py b/tools/distrib/run_clang_tidy.py index 4b0edcfd41fe9..d7c22b373e330 100755 --- a/tools/distrib/run_clang_tidy.py +++ b/tools/distrib/run_clang_tidy.py @@ -20,21 +20,25 @@ import sys sys.path.append( - os.path.join(os.path.dirname(sys.argv[0]), '..', 'run_tests', - 'python_utils')) + os.path.join( + os.path.dirname(sys.argv[0]), "..", "run_tests", "python_utils" + ) +) import jobset -clang_tidy = os.environ.get('CLANG_TIDY', 'clang-tidy') +clang_tidy = os.environ.get("CLANG_TIDY", "clang-tidy") -argp = argparse.ArgumentParser(description='Run clang-tidy against core') -argp.add_argument('files', nargs='+', help='Files to tidy') -argp.add_argument('--fix', dest='fix', action='store_true') -argp.add_argument('-j', - '--jobs', - type=int, - default=multiprocessing.cpu_count(), - help='Number of CPUs to use') -argp.add_argument('--only-changed', dest='only_changed', action='store_true') +argp = argparse.ArgumentParser(description="Run clang-tidy against core") +argp.add_argument("files", nargs="+", help="Files to tidy") +argp.add_argument("--fix", dest="fix", action="store_true") +argp.add_argument( + "-j", + "--jobs", + type=int, + default=multiprocessing.cpu_count(), + help="Number of CPUs to use", +) +argp.add_argument("--only-changed", dest="only_changed", action="store_true") argp.set_defaults(fix=False, only_changed=False) args = argp.parse_args() @@ -46,18 +50,19 @@ config = f.read() cmdline = [ clang_tidy, - '--config=' + config, + "--config=" + config, ] if args.fix: - cmdline.append('--fix-errors') + cmdline.append("--fix-errors") if args.only_changed: orig_files = set(args.files) actual_files = [] output = subprocess.check_output( - ['git', 'diff', 'upstream/master', 'HEAD', '--name-only']) - for line in output.decode('ascii').splitlines(False): + ["git", "diff", "upstream/master", "HEAD", "--name-only"] + ) + for line in output.decode("ascii").splitlines(False): if line in orig_files: print(("check: %s" % line)) actual_files.append(line) @@ -72,7 +77,8 @@ cmdline + [filename], shortname=filename, timeout_seconds=15 * 60, - )) + ) + ) num_fails, res_set = jobset.run(jobs, maxjobs=args.jobs, quiet_success=True) sys.exit(num_fails) diff --git a/tools/distrib/sanitize.sh b/tools/distrib/sanitize.sh index 7cfc6559a866c..a31f1b8662847 100755 --- a/tools/distrib/sanitize.sh +++ b/tools/distrib/sanitize.sh @@ -27,7 +27,7 @@ tools/distrib/check_trailing_newlines.sh --fix tools/run_tests/sanity/check_port_platform.py --fix 
tools/run_tests/sanity/check_include_style.py --fix || true tools/distrib/check_namespace_qualification.py --fix -tools/distrib/yapf_code.sh +tools/distrib/black_code.sh tools/distrib/isort_code.sh tools/distrib/check_redundant_namespace_qualifiers.py || true tools/codegen/core/gen_grpc_tls_credentials_options.py diff --git a/tools/distrib/update_flakes.py b/tools/distrib/update_flakes.py index 16e8304e8368e..f730039beaaa3 100755 --- a/tools/distrib/update_flakes.py +++ b/tools/distrib/update_flakes.py @@ -27,28 +27,29 @@ def include_test(test): - if '@' in test: + if "@" in test: return False if test.startswith("//test/cpp/qps:"): return False return True -TEST_DIRS = ['test/core', 'test/cpp'] +TEST_DIRS = ["test/core", "test/cpp"] tests = {} already_flaky = set() for test_dir in TEST_DIRS: for line in subprocess.check_output( - ['bazel', 'query', 'tests({}/...)'.format(test_dir)]).splitlines(): - test = line.strip().decode('utf-8') + ["bazel", "query", "tests({}/...)".format(test_dir)] + ).splitlines(): + test = line.strip().decode("utf-8") if not include_test(test): continue tests[test] = False for test_dir in TEST_DIRS: for line in subprocess.check_output( - ['bazel', 'query', - 'attr(flaky, 1, tests({}/...))'.format(test_dir)]).splitlines(): - test = line.strip().decode('utf-8') + ["bazel", "query", "attr(flaky, 1, tests({}/...))".format(test_dir)] + ).splitlines(): + test = line.strip().decode("utf-8") if not include_test(test): continue already_flaky.add(test) @@ -57,15 +58,16 @@ def include_test(test): client = bigquery.Client() for row in client.query( - update_flakes_query.QUERY.format( - lookback_hours=lookback_hours)).result(): + update_flakes_query.QUERY.format(lookback_hours=lookback_hours) +).result(): if "/macos/" in row.job_name: continue # we know mac stuff is flaky if row.test_binary not in tests: - m = re.match(r'^//test/core/end2end:([^@]*)@([^@]*)(.*)', - row.test_binary) + m = re.match( + r"^//test/core/end2end:([^@]*)@([^@]*)(.*)", row.test_binary + ) if m: - flaky_e2e.add('{}@{}{}'.format(m.group(1), m.group(2), m.group(3))) + flaky_e2e.add("{}@{}{}".format(m.group(1), m.group(2), m.group(3))) print("will mark end2end test {} as flaky".format(row.test_binary)) else: print("skip obsolete test {}".format(row.test_binary)) @@ -76,29 +78,33 @@ def include_test(test): buildozer_commands = [] for test, flaky in sorted(tests.items()): if flaky: - buildozer_commands.append('set flaky True|{}'.format(test)) + buildozer_commands.append("set flaky True|{}".format(test)) elif test in already_flaky: - buildozer_commands.append('remove flaky|{}'.format(test)) + buildozer_commands.append("remove flaky|{}".format(test)) -with open('test/core/end2end/flaky.bzl', 'w') as f: +with open("test/core/end2end/flaky.bzl", "w") as f: with open(sys.argv[0]) as my_source: for line in my_source: - if line[0] != '#': + if line[0] != "#": break for line in my_source: - if line[0] == '#': + if line[0] == "#": print(line.strip(), file=f) break for line in my_source: - if line[0] != '#': + if line[0] != "#": break print(line.strip(), file=f) print( - "\"\"\"A list of flaky tests, consumed by generate_tests.bzl to set flaky attrs.\"\"\"", - file=f) + ( + '"""A list of flaky tests, consumed by generate_tests.bzl to set' + ' flaky attrs."""' + ), + file=f, + ) print("FLAKY_TESTS = [", file=f) for line in sorted(list(flaky_e2e)): - print(" \"{}\",".format(line), file=f) + print(' "{}",'.format(line), file=f) print("]", file=f) run_buildozer.run_buildozer(buildozer_commands) diff --git 
a/tools/gcp/utils/big_query_utils.py b/tools/gcp/utils/big_query_utils.py index 1853e62666e0e..630164d52a37d 100755 --- a/tools/gcp/utils/big_query_utils.py +++ b/tools/gcp/utils/big_query_utils.py @@ -29,110 +29,127 @@ def create_big_query(): - """Authenticates with cloud platform and gets a BiqQuery service object - """ + """Authenticates with cloud platform and gets a BiqQuery service object""" creds = GoogleCredentials.get_application_default() - return discovery.build('bigquery', - 'v2', - credentials=creds, - cache_discovery=False) + return discovery.build( + "bigquery", "v2", credentials=creds, cache_discovery=False + ) def create_dataset(biq_query, project_id, dataset_id): is_success = True body = { - 'datasetReference': { - 'projectId': project_id, - 'datasetId': dataset_id - } + "datasetReference": {"projectId": project_id, "datasetId": dataset_id} } try: - dataset_req = biq_query.datasets().insert(projectId=project_id, - body=body) + dataset_req = biq_query.datasets().insert( + projectId=project_id, body=body + ) dataset_req.execute(num_retries=NUM_RETRIES) except HttpError as http_error: if http_error.resp.status == 409: - print('Warning: The dataset %s already exists' % dataset_id) + print("Warning: The dataset %s already exists" % dataset_id) else: # Note: For more debugging info, print "http_error.content" - print('Error in creating dataset: %s. Err: %s' % - (dataset_id, http_error)) + print( + "Error in creating dataset: %s. Err: %s" + % (dataset_id, http_error) + ) is_success = False return is_success -def create_table(big_query, project_id, dataset_id, table_id, table_schema, - description): - fields = [{ - 'name': field_name, - 'type': field_type, - 'description': field_description - } for (field_name, field_type, field_description) in table_schema] - return create_table2(big_query, project_id, dataset_id, table_id, fields, - description) - - -def create_partitioned_table(big_query, - project_id, - dataset_id, - table_id, - table_schema, - description, - partition_type='DAY', - expiration_ms=_EXPIRATION_MS): +def create_table( + big_query, project_id, dataset_id, table_id, table_schema, description +): + fields = [ + { + "name": field_name, + "type": field_type, + "description": field_description, + } + for (field_name, field_type, field_description) in table_schema + ] + return create_table2( + big_query, project_id, dataset_id, table_id, fields, description + ) + + +def create_partitioned_table( + big_query, + project_id, + dataset_id, + table_id, + table_schema, + description, + partition_type="DAY", + expiration_ms=_EXPIRATION_MS, +): """Creates a partitioned table. By default, a date-paritioned table is created with - each partition lasting 30 days after it was last modified. - """ - fields = [{ - 'name': field_name, - 'type': field_type, - 'description': field_description - } for (field_name, field_type, field_description) in table_schema] - return create_table2(big_query, project_id, dataset_id, table_id, fields, - description, partition_type, expiration_ms) - - -def create_table2(big_query, - project_id, - dataset_id, - table_id, - fields_schema, - description, - partition_type=None, - expiration_ms=None): + each partition lasting 30 days after it was last modified. 
+ """ + fields = [ + { + "name": field_name, + "type": field_type, + "description": field_description, + } + for (field_name, field_type, field_description) in table_schema + ] + return create_table2( + big_query, + project_id, + dataset_id, + table_id, + fields, + description, + partition_type, + expiration_ms, + ) + + +def create_table2( + big_query, + project_id, + dataset_id, + table_id, + fields_schema, + description, + partition_type=None, + expiration_ms=None, +): is_success = True body = { - 'description': description, - 'schema': { - 'fields': fields_schema + "description": description, + "schema": {"fields": fields_schema}, + "tableReference": { + "datasetId": dataset_id, + "projectId": project_id, + "tableId": table_id, }, - 'tableReference': { - 'datasetId': dataset_id, - 'projectId': project_id, - 'tableId': table_id - } } if partition_type and expiration_ms: body["timePartitioning"] = { "type": partition_type, - "expirationMs": expiration_ms + "expirationMs": expiration_ms, } try: - table_req = big_query.tables().insert(projectId=project_id, - datasetId=dataset_id, - body=body) + table_req = big_query.tables().insert( + projectId=project_id, datasetId=dataset_id, body=body + ) res = table_req.execute(num_retries=NUM_RETRIES) - print('Successfully created %s "%s"' % (res['kind'], res['id'])) + print('Successfully created %s "%s"' % (res["kind"], res["id"])) except HttpError as http_error: if http_error.resp.status == 409: - print('Warning: Table %s already exists' % table_id) + print("Warning: Table %s already exists" % table_id) else: - print('Error in creating table: %s. Err: %s' % - (table_id, http_error)) + print( + "Error in creating table: %s. Err: %s" % (table_id, http_error) + ) is_success = False return is_success @@ -141,64 +158,68 @@ def patch_table(big_query, project_id, dataset_id, table_id, fields_schema): is_success = True body = { - 'schema': { - 'fields': fields_schema + "schema": {"fields": fields_schema}, + "tableReference": { + "datasetId": dataset_id, + "projectId": project_id, + "tableId": table_id, }, - 'tableReference': { - 'datasetId': dataset_id, - 'projectId': project_id, - 'tableId': table_id - } } try: - table_req = big_query.tables().patch(projectId=project_id, - datasetId=dataset_id, - tableId=table_id, - body=body) + table_req = big_query.tables().patch( + projectId=project_id, + datasetId=dataset_id, + tableId=table_id, + body=body, + ) res = table_req.execute(num_retries=NUM_RETRIES) - print('Successfully patched %s "%s"' % (res['kind'], res['id'])) + print('Successfully patched %s "%s"' % (res["kind"], res["id"])) except HttpError as http_error: - print('Error in creating table: %s. Err: %s' % (table_id, http_error)) + print("Error in creating table: %s. Err: %s" % (table_id, http_error)) is_success = False return is_success def insert_rows(big_query, project_id, dataset_id, table_id, rows_list): is_success = True - body = {'rows': rows_list} + body = {"rows": rows_list} try: - insert_req = big_query.tabledata().insertAll(projectId=project_id, - datasetId=dataset_id, - tableId=table_id, - body=body) + insert_req = big_query.tabledata().insertAll( + projectId=project_id, + datasetId=dataset_id, + tableId=table_id, + body=body, + ) res = insert_req.execute(num_retries=NUM_RETRIES) - if res.get('insertErrors', None): - print('Error inserting rows! Response: %s' % res) + if res.get("insertErrors", None): + print("Error inserting rows! 
Response: %s" % res) is_success = False except HttpError as http_error: - print('Error inserting rows to the table %s' % table_id) - print('Error message: %s' % http_error) + print("Error inserting rows to the table %s" % table_id) + print("Error message: %s" % http_error) is_success = False return is_success def sync_query_job(big_query, project_id, query, timeout=5000): - query_data = {'query': query, 'timeoutMs': timeout} + query_data = {"query": query, "timeoutMs": timeout} query_job = None try: - query_job = big_query.jobs().query( - projectId=project_id, - body=query_data).execute(num_retries=NUM_RETRIES) + query_job = ( + big_query.jobs() + .query(projectId=project_id, body=query_data) + .execute(num_retries=NUM_RETRIES) + ) except HttpError as http_error: - print('Query execute job failed with error: %s' % http_error) + print("Query execute job failed with error: %s" % http_error) print(http_error.content) return query_job - # List of (column name, column type, description) tuples + + def make_row(unique_row_id, row_values_dict): - """row_values_dict is a dictionary of column name and column value. - """ - return {'insertId': unique_row_id, 'json': row_values_dict} + """row_values_dict is a dictionary of column name and column value.""" + return {"insertId": unique_row_id, "json": row_values_dict} diff --git a/tools/interop_matrix/client_matrix.py b/tools/interop_matrix/client_matrix.py index b442451513816..54cbaa38b688d 100644 --- a/tools/interop_matrix/client_matrix.py +++ b/tools/interop_matrix/client_matrix.py @@ -19,12 +19,12 @@ def get_github_repo(lang): return { - 'dart': 'https://github.com/grpc/grpc-dart.git', - 'go': 'https://github.com/grpc/grpc-go.git', - 'java': 'https://github.com/grpc/grpc-java.git', - 'node': 'https://github.com/grpc/grpc-node.git', + "dart": "https://github.com/grpc/grpc-dart.git", + "go": "https://github.com/grpc/grpc-go.git", + "java": "https://github.com/grpc/grpc-java.git", + "node": "https://github.com/grpc/grpc-node.git", # all other languages use the grpc.git repo. - }.get(lang, 'https://github.com/grpc/grpc.git') + }.get(lang, "https://github.com/grpc/grpc.git") def get_release_tags(lang): @@ -45,21 +45,21 @@ def should_build_docker_interop_image_from_release_tag(lang): # All dockerfile definitions live in grpc/grpc repository. # For language that have a separate repo, we need to use # dockerfile definitions from head of grpc/grpc. - if lang in ['go', 'java', 'node']: + if lang in ["go", "java", "node"]: return False return True # Dictionary of default runtimes per language LANG_RUNTIME_MATRIX = { - 'cxx': ['cxx'], # This is actually debian8. - 'go': ['go1.8', 'go1.11', 'go1.16', 'go1.19'], - 'java': ['java'], - 'python': ['python', 'pythonasyncio'], - 'node': ['node'], - 'ruby': ['ruby'], - 'php': ['php7'], - 'csharp': ['csharp', 'csharpcoreclr'], + "cxx": ["cxx"], # This is actually debian8. + "go": ["go1.8", "go1.11", "go1.16", "go1.19"], + "java": ["java"], + "python": ["python", "pythonasyncio"], + "node": ["node"], + "ruby": ["ruby"], + "php": ["php7"], + "csharp": ["csharp", "csharpcoreclr"], } @@ -74,575 +74,855 @@ def __init__(self, patch=[], runtimes=[], testcases_file=None): # Dictionary of known releases for given language. 
LANG_RELEASE_MATRIX = { - 'cxx': - OrderedDict([ - ('v1.0.1', ReleaseInfo(testcases_file='cxx__v1.0.1')), - ('v1.1.4', ReleaseInfo(testcases_file='cxx__v1.0.1')), - ('v1.2.5', ReleaseInfo(testcases_file='cxx__v1.0.1')), - ('v1.3.9', ReleaseInfo(testcases_file='cxx__v1.0.1')), - ('v1.4.2', ReleaseInfo(testcases_file='cxx__v1.0.1')), - ('v1.6.6', ReleaseInfo(testcases_file='cxx__v1.0.1')), - ('v1.7.2', ReleaseInfo(testcases_file='cxx__v1.0.1')), - ('v1.8.0', ReleaseInfo(testcases_file='cxx__v1.0.1')), - ('v1.9.1', ReleaseInfo(testcases_file='cxx__v1.0.1')), - ('v1.10.1', ReleaseInfo(testcases_file='cxx__v1.0.1')), - ('v1.11.1', ReleaseInfo(testcases_file='cxx__v1.0.1')), - ('v1.12.0', ReleaseInfo(testcases_file='cxx__v1.0.1')), - ('v1.13.0', ReleaseInfo(testcases_file='cxx__v1.0.1')), - ('v1.14.1', ReleaseInfo(testcases_file='cxx__v1.0.1')), - ('v1.15.0', ReleaseInfo(testcases_file='cxx__v1.0.1')), - ('v1.16.0', ReleaseInfo(testcases_file='cxx__v1.0.1')), - ('v1.17.1', ReleaseInfo(testcases_file='cxx__v1.0.1')), - ('v1.18.0', ReleaseInfo(testcases_file='cxx__v1.0.1')), - ('v1.19.0', ReleaseInfo(testcases_file='cxx__v1.0.1')), - ('v1.20.0', ReleaseInfo(testcases_file='cxx__v1.31.1')), - ('v1.21.4', ReleaseInfo(testcases_file='cxx__v1.31.1')), - ('v1.22.0', ReleaseInfo(testcases_file='cxx__v1.31.1')), - ('v1.22.1', ReleaseInfo(testcases_file='cxx__v1.31.1')), - ('v1.23.0', ReleaseInfo(testcases_file='cxx__v1.31.1')), - ('v1.24.0', ReleaseInfo(testcases_file='cxx__v1.31.1')), - ('v1.25.0', ReleaseInfo(testcases_file='cxx__v1.31.1')), - ('v1.26.0', ReleaseInfo(testcases_file='cxx__v1.31.1')), - ('v1.27.3', ReleaseInfo(testcases_file='cxx__v1.31.1')), - ('v1.30.0', ReleaseInfo(testcases_file='cxx__v1.31.1')), - ('v1.31.1', ReleaseInfo(testcases_file='cxx__v1.31.1')), - ('v1.32.0', ReleaseInfo()), - ('v1.33.2', ReleaseInfo()), - ('v1.34.0', ReleaseInfo()), - ('v1.35.0', ReleaseInfo()), - ('v1.36.3', ReleaseInfo()), - ('v1.37.0', ReleaseInfo()), - ('v1.38.0', ReleaseInfo()), - ('v1.39.0', ReleaseInfo()), - ('v1.41.1', ReleaseInfo()), - ('v1.42.0', ReleaseInfo()), - ('v1.43.0', ReleaseInfo()), - ('v1.44.0', ReleaseInfo()), - ('v1.46.2', ReleaseInfo()), - ('v1.47.1', ReleaseInfo()), - ('v1.48.3', ReleaseInfo()), - ('v1.49.1', ReleaseInfo()), - ('v1.52.0', ReleaseInfo()), - ('v1.53.0', ReleaseInfo()), - ('v1.54.0', ReleaseInfo()), - ('v1.55.0', ReleaseInfo()), - ]), - 'go': - OrderedDict([ - ('v1.0.5', - ReleaseInfo(runtimes=['go1.8'], testcases_file='go__v1.0.5')), - ('v1.2.1', - ReleaseInfo(runtimes=['go1.8'], testcases_file='go__v1.0.5')), - ('v1.3.0', - ReleaseInfo(runtimes=['go1.8'], testcases_file='go__v1.0.5')), - ('v1.4.2', - ReleaseInfo(runtimes=['go1.8'], testcases_file='go__v1.0.5')), - ('v1.5.2', - ReleaseInfo(runtimes=['go1.8'], testcases_file='go__v1.0.5')), - ('v1.6.0', - ReleaseInfo(runtimes=['go1.8'], testcases_file='go__v1.0.5')), - ('v1.7.4', - ReleaseInfo(runtimes=['go1.8'], testcases_file='go__v1.0.5')), - ('v1.8.2', - ReleaseInfo(runtimes=['go1.8'], testcases_file='go__v1.0.5')), - ('v1.9.2', - ReleaseInfo(runtimes=['go1.8'], testcases_file='go__v1.0.5')), - ('v1.10.1', - ReleaseInfo(runtimes=['go1.8'], testcases_file='go__v1.0.5')), - ('v1.11.3', - ReleaseInfo(runtimes=['go1.8'], testcases_file='go__v1.0.5')), - ('v1.12.2', - ReleaseInfo(runtimes=['go1.8'], testcases_file='go__v1.0.5')), - ('v1.13.0', - ReleaseInfo(runtimes=['go1.8'], testcases_file='go__v1.0.5')), - ('v1.14.0', - ReleaseInfo(runtimes=['go1.8'], testcases_file='go__v1.0.5')), - ('v1.15.0', - 
ReleaseInfo(runtimes=['go1.8'], testcases_file='go__v1.0.5')), - ('v1.16.0', - ReleaseInfo(runtimes=['go1.8'], testcases_file='go__v1.0.5')), - ('v1.17.0', - ReleaseInfo(runtimes=['go1.11'], testcases_file='go__v1.0.5')), - ('v1.18.0', - ReleaseInfo(runtimes=['go1.11'], testcases_file='go__v1.0.5')), - ('v1.19.0', - ReleaseInfo(runtimes=['go1.11'], testcases_file='go__v1.0.5')), - ('v1.20.0', - ReleaseInfo(runtimes=['go1.11'], testcases_file='go__v1.20.0')), - ('v1.21.3', - ReleaseInfo(runtimes=['go1.11'], testcases_file='go__v1.20.0')), - ('v1.22.3', - ReleaseInfo(runtimes=['go1.11'], testcases_file='go__v1.20.0')), - ('v1.23.1', - ReleaseInfo(runtimes=['go1.11'], testcases_file='go__v1.20.0')), - ('v1.24.0', - ReleaseInfo(runtimes=['go1.11'], testcases_file='go__v1.20.0')), - ('v1.25.0', - ReleaseInfo(runtimes=['go1.11'], testcases_file='go__v1.20.0')), - ('v1.26.0', - ReleaseInfo(runtimes=['go1.11'], testcases_file='go__v1.20.0')), - ('v1.27.1', - ReleaseInfo(runtimes=['go1.11'], testcases_file='go__v1.20.0')), - ('v1.28.0', - ReleaseInfo(runtimes=['go1.11'], testcases_file='go__v1.20.0')), - ('v1.29.0', - ReleaseInfo(runtimes=['go1.11'], testcases_file='go__v1.20.0')), - ('v1.30.0', - ReleaseInfo(runtimes=['go1.11'], testcases_file='go__v1.20.0')), - ('v1.31.1', - ReleaseInfo(runtimes=['go1.11'], testcases_file='go__v1.20.0')), - ('v1.32.0', - ReleaseInfo(runtimes=['go1.11'], testcases_file='go__v1.20.0')), - ('v1.33.1', - ReleaseInfo(runtimes=['go1.11'], testcases_file='go__v1.20.0')), - ('v1.34.0', ReleaseInfo(runtimes=['go1.11'])), - ('v1.35.0', ReleaseInfo(runtimes=['go1.11'])), - ('v1.36.0', ReleaseInfo(runtimes=['go1.11'])), - ('v1.37.0', ReleaseInfo(runtimes=['go1.11'])), + "cxx": OrderedDict( + [ + ("v1.0.1", ReleaseInfo(testcases_file="cxx__v1.0.1")), + ("v1.1.4", ReleaseInfo(testcases_file="cxx__v1.0.1")), + ("v1.2.5", ReleaseInfo(testcases_file="cxx__v1.0.1")), + ("v1.3.9", ReleaseInfo(testcases_file="cxx__v1.0.1")), + ("v1.4.2", ReleaseInfo(testcases_file="cxx__v1.0.1")), + ("v1.6.6", ReleaseInfo(testcases_file="cxx__v1.0.1")), + ("v1.7.2", ReleaseInfo(testcases_file="cxx__v1.0.1")), + ("v1.8.0", ReleaseInfo(testcases_file="cxx__v1.0.1")), + ("v1.9.1", ReleaseInfo(testcases_file="cxx__v1.0.1")), + ("v1.10.1", ReleaseInfo(testcases_file="cxx__v1.0.1")), + ("v1.11.1", ReleaseInfo(testcases_file="cxx__v1.0.1")), + ("v1.12.0", ReleaseInfo(testcases_file="cxx__v1.0.1")), + ("v1.13.0", ReleaseInfo(testcases_file="cxx__v1.0.1")), + ("v1.14.1", ReleaseInfo(testcases_file="cxx__v1.0.1")), + ("v1.15.0", ReleaseInfo(testcases_file="cxx__v1.0.1")), + ("v1.16.0", ReleaseInfo(testcases_file="cxx__v1.0.1")), + ("v1.17.1", ReleaseInfo(testcases_file="cxx__v1.0.1")), + ("v1.18.0", ReleaseInfo(testcases_file="cxx__v1.0.1")), + ("v1.19.0", ReleaseInfo(testcases_file="cxx__v1.0.1")), + ("v1.20.0", ReleaseInfo(testcases_file="cxx__v1.31.1")), + ("v1.21.4", ReleaseInfo(testcases_file="cxx__v1.31.1")), + ("v1.22.0", ReleaseInfo(testcases_file="cxx__v1.31.1")), + ("v1.22.1", ReleaseInfo(testcases_file="cxx__v1.31.1")), + ("v1.23.0", ReleaseInfo(testcases_file="cxx__v1.31.1")), + ("v1.24.0", ReleaseInfo(testcases_file="cxx__v1.31.1")), + ("v1.25.0", ReleaseInfo(testcases_file="cxx__v1.31.1")), + ("v1.26.0", ReleaseInfo(testcases_file="cxx__v1.31.1")), + ("v1.27.3", ReleaseInfo(testcases_file="cxx__v1.31.1")), + ("v1.30.0", ReleaseInfo(testcases_file="cxx__v1.31.1")), + ("v1.31.1", ReleaseInfo(testcases_file="cxx__v1.31.1")), + ("v1.32.0", ReleaseInfo()), + ("v1.33.2", ReleaseInfo()), + ("v1.34.0", 
ReleaseInfo()), + ("v1.35.0", ReleaseInfo()), + ("v1.36.3", ReleaseInfo()), + ("v1.37.0", ReleaseInfo()), + ("v1.38.0", ReleaseInfo()), + ("v1.39.0", ReleaseInfo()), + ("v1.41.1", ReleaseInfo()), + ("v1.42.0", ReleaseInfo()), + ("v1.43.0", ReleaseInfo()), + ("v1.44.0", ReleaseInfo()), + ("v1.46.2", ReleaseInfo()), + ("v1.47.1", ReleaseInfo()), + ("v1.48.3", ReleaseInfo()), + ("v1.49.1", ReleaseInfo()), + ("v1.52.0", ReleaseInfo()), + ("v1.53.0", ReleaseInfo()), + ("v1.54.0", ReleaseInfo()), + ("v1.55.0", ReleaseInfo()), + ] + ), + "go": OrderedDict( + [ + ( + "v1.0.5", + ReleaseInfo(runtimes=["go1.8"], testcases_file="go__v1.0.5"), + ), + ( + "v1.2.1", + ReleaseInfo(runtimes=["go1.8"], testcases_file="go__v1.0.5"), + ), + ( + "v1.3.0", + ReleaseInfo(runtimes=["go1.8"], testcases_file="go__v1.0.5"), + ), + ( + "v1.4.2", + ReleaseInfo(runtimes=["go1.8"], testcases_file="go__v1.0.5"), + ), + ( + "v1.5.2", + ReleaseInfo(runtimes=["go1.8"], testcases_file="go__v1.0.5"), + ), + ( + "v1.6.0", + ReleaseInfo(runtimes=["go1.8"], testcases_file="go__v1.0.5"), + ), + ( + "v1.7.4", + ReleaseInfo(runtimes=["go1.8"], testcases_file="go__v1.0.5"), + ), + ( + "v1.8.2", + ReleaseInfo(runtimes=["go1.8"], testcases_file="go__v1.0.5"), + ), + ( + "v1.9.2", + ReleaseInfo(runtimes=["go1.8"], testcases_file="go__v1.0.5"), + ), + ( + "v1.10.1", + ReleaseInfo(runtimes=["go1.8"], testcases_file="go__v1.0.5"), + ), + ( + "v1.11.3", + ReleaseInfo(runtimes=["go1.8"], testcases_file="go__v1.0.5"), + ), + ( + "v1.12.2", + ReleaseInfo(runtimes=["go1.8"], testcases_file="go__v1.0.5"), + ), + ( + "v1.13.0", + ReleaseInfo(runtimes=["go1.8"], testcases_file="go__v1.0.5"), + ), + ( + "v1.14.0", + ReleaseInfo(runtimes=["go1.8"], testcases_file="go__v1.0.5"), + ), + ( + "v1.15.0", + ReleaseInfo(runtimes=["go1.8"], testcases_file="go__v1.0.5"), + ), + ( + "v1.16.0", + ReleaseInfo(runtimes=["go1.8"], testcases_file="go__v1.0.5"), + ), + ( + "v1.17.0", + ReleaseInfo(runtimes=["go1.11"], testcases_file="go__v1.0.5"), + ), + ( + "v1.18.0", + ReleaseInfo(runtimes=["go1.11"], testcases_file="go__v1.0.5"), + ), + ( + "v1.19.0", + ReleaseInfo(runtimes=["go1.11"], testcases_file="go__v1.0.5"), + ), + ( + "v1.20.0", + ReleaseInfo(runtimes=["go1.11"], testcases_file="go__v1.20.0"), + ), + ( + "v1.21.3", + ReleaseInfo(runtimes=["go1.11"], testcases_file="go__v1.20.0"), + ), + ( + "v1.22.3", + ReleaseInfo(runtimes=["go1.11"], testcases_file="go__v1.20.0"), + ), + ( + "v1.23.1", + ReleaseInfo(runtimes=["go1.11"], testcases_file="go__v1.20.0"), + ), + ( + "v1.24.0", + ReleaseInfo(runtimes=["go1.11"], testcases_file="go__v1.20.0"), + ), + ( + "v1.25.0", + ReleaseInfo(runtimes=["go1.11"], testcases_file="go__v1.20.0"), + ), + ( + "v1.26.0", + ReleaseInfo(runtimes=["go1.11"], testcases_file="go__v1.20.0"), + ), + ( + "v1.27.1", + ReleaseInfo(runtimes=["go1.11"], testcases_file="go__v1.20.0"), + ), + ( + "v1.28.0", + ReleaseInfo(runtimes=["go1.11"], testcases_file="go__v1.20.0"), + ), + ( + "v1.29.0", + ReleaseInfo(runtimes=["go1.11"], testcases_file="go__v1.20.0"), + ), + ( + "v1.30.0", + ReleaseInfo(runtimes=["go1.11"], testcases_file="go__v1.20.0"), + ), + ( + "v1.31.1", + ReleaseInfo(runtimes=["go1.11"], testcases_file="go__v1.20.0"), + ), + ( + "v1.32.0", + ReleaseInfo(runtimes=["go1.11"], testcases_file="go__v1.20.0"), + ), + ( + "v1.33.1", + ReleaseInfo(runtimes=["go1.11"], testcases_file="go__v1.20.0"), + ), + ("v1.34.0", ReleaseInfo(runtimes=["go1.11"])), + ("v1.35.0", ReleaseInfo(runtimes=["go1.11"])), + ("v1.36.0", 
ReleaseInfo(runtimes=["go1.11"])), + ("v1.37.0", ReleaseInfo(runtimes=["go1.11"])), # NOTE: starting from release v1.38.0, use runtimes=['go1.16'] - ('v1.38.1', ReleaseInfo(runtimes=['go1.16'])), - ('v1.39.1', ReleaseInfo(runtimes=['go1.16'])), - ('v1.40.0', ReleaseInfo(runtimes=['go1.16'])), - ('v1.41.0', ReleaseInfo(runtimes=['go1.16'])), - ('v1.42.0', ReleaseInfo(runtimes=['go1.16'])), - ('v1.43.0', ReleaseInfo(runtimes=['go1.16'])), - ('v1.44.0', ReleaseInfo(runtimes=['go1.16'])), - ('v1.45.0', ReleaseInfo(runtimes=['go1.16'])), - ('v1.46.0', ReleaseInfo(runtimes=['go1.16'])), - ('v1.47.0', ReleaseInfo(runtimes=['go1.16'])), - ('v1.48.0', ReleaseInfo(runtimes=['go1.16'])), - ('v1.49.0', ReleaseInfo(runtimes=['go1.16'])), - ('v1.50.1', ReleaseInfo(runtimes=['go1.16'])), - ('v1.51.0', ReleaseInfo(runtimes=['go1.16'])), - ('v1.52.3', ReleaseInfo(runtimes=['go1.19'])), - ('v1.53.0', ReleaseInfo(runtimes=['go1.19'])), - ('v1.54.1', ReleaseInfo(runtimes=['go1.19'])), - ('v1.55.0', ReleaseInfo(runtimes=['go1.19'])), - ]), - 'java': - OrderedDict([ - ('v1.0.3', - ReleaseInfo(runtimes=['java_oracle8'], - testcases_file='java__v1.0.3')), - ('v1.1.2', - ReleaseInfo(runtimes=['java_oracle8'], - testcases_file='java__v1.0.3')), - ('v1.2.0', - ReleaseInfo(runtimes=['java_oracle8'], - testcases_file='java__v1.0.3')), - ('v1.3.1', - ReleaseInfo(runtimes=['java_oracle8'], - testcases_file='java__v1.0.3')), - ('v1.4.0', - ReleaseInfo(runtimes=['java_oracle8'], - testcases_file='java__v1.0.3')), - ('v1.5.0', - ReleaseInfo(runtimes=['java_oracle8'], - testcases_file='java__v1.0.3')), - ('v1.6.1', - ReleaseInfo(runtimes=['java_oracle8'], - testcases_file='java__v1.0.3')), - ('v1.7.1', ReleaseInfo(testcases_file='java__v1.0.3')), - ('v1.8.0', - ReleaseInfo(runtimes=['java_oracle8'], - testcases_file='java__v1.0.3')), - ('v1.9.1', - ReleaseInfo(runtimes=['java_oracle8'], - testcases_file='java__v1.0.3')), - ('v1.10.1', - ReleaseInfo(runtimes=['java_oracle8'], - testcases_file='java__v1.0.3')), - ('v1.11.0', - ReleaseInfo(runtimes=['java_oracle8'], - testcases_file='java__v1.0.3')), - ('v1.12.1', ReleaseInfo(testcases_file='java__v1.0.3')), - ('v1.13.2', ReleaseInfo(testcases_file='java__v1.0.3')), - ('v1.14.0', - ReleaseInfo(runtimes=['java_oracle8'], - testcases_file='java__v1.0.3')), - ('v1.15.1', ReleaseInfo(testcases_file='java__v1.0.3')), - ('v1.16.1', - ReleaseInfo(runtimes=['java_oracle8'], - testcases_file='java__v1.0.3')), - ('v1.17.2', ReleaseInfo(testcases_file='java__v1.0.3')), - ('v1.18.0', - ReleaseInfo(runtimes=['java_oracle8'], - testcases_file='java__v1.0.3')), - ('v1.19.0', - ReleaseInfo(runtimes=['java_oracle8'], - testcases_file='java__v1.0.3')), - ('v1.20.0', ReleaseInfo(runtimes=['java_oracle8'])), - ('v1.21.1', ReleaseInfo()), - ('v1.22.2', ReleaseInfo()), - ('v1.23.0', ReleaseInfo()), - ('v1.24.0', ReleaseInfo()), - ('v1.25.0', ReleaseInfo()), - ('v1.26.1', ReleaseInfo()), - ('v1.27.2', ReleaseInfo()), - ('v1.28.1', ReleaseInfo()), - ('v1.29.0', ReleaseInfo()), - ('v1.30.2', ReleaseInfo()), - ('v1.31.2', ReleaseInfo()), - ('v1.32.3', ReleaseInfo()), - ('v1.33.1', ReleaseInfo()), - ('v1.34.1', ReleaseInfo()), - ('v1.35.1', ReleaseInfo()), - ('v1.36.3', ReleaseInfo()), - ('v1.37.1', ReleaseInfo()), - ('v1.38.1', ReleaseInfo()), - ('v1.39.0', ReleaseInfo()), - ('v1.40.2', ReleaseInfo()), - ('v1.41.3', ReleaseInfo()), - ('v1.42.3', ReleaseInfo()), - ('v1.43.3', ReleaseInfo()), - ('v1.44.2', ReleaseInfo()), - ('v1.45.3', ReleaseInfo()), - ('v1.46.1', ReleaseInfo()), - ('v1.47.1', 
ReleaseInfo()), - ('v1.48.2', ReleaseInfo()), - ('v1.49.2', ReleaseInfo()), - ('v1.50.2', ReleaseInfo()), - ('v1.51.1', ReleaseInfo()), - ('v1.52.0', ReleaseInfo()), - ('v1.54.0', ReleaseInfo()), - ]), - 'python': - OrderedDict( - [ - ('v1.0.x', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.0.x')), - ('v1.1.4', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.0.x')), - ('v1.2.5', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.0.x')), - ('v1.3.9', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.0.x')), - ('v1.4.2', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.0.x')), - ('v1.6.6', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.0.x')), - ('v1.7.2', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.0.x')), - ('v1.8.1', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.0.x')), - ('v1.9.1', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.0.x')), - ('v1.10.1', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.0.x')), - ('v1.11.1', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.11.1')), - ('v1.12.0', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.11.1')), - ('v1.13.0', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.11.1')), - ('v1.14.1', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.11.1')), - ('v1.15.0', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.11.1')), - ('v1.16.0', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.11.1')), - ('v1.17.1', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.11.1')), - ('v1.18.0', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.18.0')), - ('v1.19.0', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.18.0')), - ('v1.20.0', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.18.0')), - ('v1.21.4', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.18.0')), - ('v1.22.0', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.18.0')), - ('v1.22.1', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.18.0')), - ('v1.23.0', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.18.0')), - ('v1.24.0', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.18.0')), - ('v1.25.0', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.18.0')), - ('v1.26.0', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.18.0')), - ('v1.27.3', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.18.0')), - ('v1.30.0', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.18.0')), - ('v1.31.1', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.18.0')), - ('v1.32.0', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.18.0')), - ('v1.33.2', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.18.0')), - ('v1.34.0', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.18.0')), - ('v1.35.0', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.18.0')), - ('v1.36.3', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.18.0')), - ('v1.37.0', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.18.0')), - ('v1.38.0', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.18.0')), - ('v1.39.0', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.18.0')), - ('v1.41.1', - 
ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.41.1')), - ('v1.42.0', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.41.1')), - ('v1.43.2', - ReleaseInfo(runtimes=['python'], - testcases_file='python__v1.41.1')), - ('v1.44.0', - ReleaseInfo(runtimes=['python'], - testcases_file='python__master')), - ('v1.46.2', - ReleaseInfo(runtimes=['python'], - testcases_file='python__master')), - ('v1.47.1', - ReleaseInfo(runtimes=['python'], - testcases_file='python__master')), - ('v1.48.3', - ReleaseInfo(runtimes=['python'], - testcases_file='python__master')), - ('v1.49.1', - ReleaseInfo(runtimes=['python'], - testcases_file='python__master')), - ('v1.52.0', - ReleaseInfo(runtimes=['python'], - testcases_file='python__master')), - ('v1.53.0', - ReleaseInfo(runtimes=['python'], - testcases_file='python__master')), - ('v1.54.0', - ReleaseInfo(runtimes=['python'], - testcases_file='python__master')), - ('v1.55.0', - ReleaseInfo(runtimes=['python'], - testcases_file='python__master')), - ]), - 'node': - OrderedDict([ - ('v1.0.1', ReleaseInfo(testcases_file='node__v1.0.1')), - ('v1.1.4', ReleaseInfo(testcases_file='node__v1.1.4')), - ('v1.2.5', ReleaseInfo(testcases_file='node__v1.1.4')), - ('v1.3.9', ReleaseInfo(testcases_file='node__v1.1.4')), - ('v1.4.2', ReleaseInfo(testcases_file='node__v1.1.4')), - ('v1.6.6', ReleaseInfo(testcases_file='node__v1.1.4')), + ("v1.38.1", ReleaseInfo(runtimes=["go1.16"])), + ("v1.39.1", ReleaseInfo(runtimes=["go1.16"])), + ("v1.40.0", ReleaseInfo(runtimes=["go1.16"])), + ("v1.41.0", ReleaseInfo(runtimes=["go1.16"])), + ("v1.42.0", ReleaseInfo(runtimes=["go1.16"])), + ("v1.43.0", ReleaseInfo(runtimes=["go1.16"])), + ("v1.44.0", ReleaseInfo(runtimes=["go1.16"])), + ("v1.45.0", ReleaseInfo(runtimes=["go1.16"])), + ("v1.46.0", ReleaseInfo(runtimes=["go1.16"])), + ("v1.47.0", ReleaseInfo(runtimes=["go1.16"])), + ("v1.48.0", ReleaseInfo(runtimes=["go1.16"])), + ("v1.49.0", ReleaseInfo(runtimes=["go1.16"])), + ("v1.50.1", ReleaseInfo(runtimes=["go1.16"])), + ("v1.51.0", ReleaseInfo(runtimes=["go1.16"])), + ("v1.52.3", ReleaseInfo(runtimes=["go1.19"])), + ("v1.53.0", ReleaseInfo(runtimes=["go1.19"])), + ("v1.54.1", ReleaseInfo(runtimes=["go1.19"])), + ("v1.55.0", ReleaseInfo(runtimes=["go1.19"])), + ] + ), + "java": OrderedDict( + [ + ( + "v1.0.3", + ReleaseInfo( + runtimes=["java_oracle8"], testcases_file="java__v1.0.3" + ), + ), + ( + "v1.1.2", + ReleaseInfo( + runtimes=["java_oracle8"], testcases_file="java__v1.0.3" + ), + ), + ( + "v1.2.0", + ReleaseInfo( + runtimes=["java_oracle8"], testcases_file="java__v1.0.3" + ), + ), + ( + "v1.3.1", + ReleaseInfo( + runtimes=["java_oracle8"], testcases_file="java__v1.0.3" + ), + ), + ( + "v1.4.0", + ReleaseInfo( + runtimes=["java_oracle8"], testcases_file="java__v1.0.3" + ), + ), + ( + "v1.5.0", + ReleaseInfo( + runtimes=["java_oracle8"], testcases_file="java__v1.0.3" + ), + ), + ( + "v1.6.1", + ReleaseInfo( + runtimes=["java_oracle8"], testcases_file="java__v1.0.3" + ), + ), + ("v1.7.1", ReleaseInfo(testcases_file="java__v1.0.3")), + ( + "v1.8.0", + ReleaseInfo( + runtimes=["java_oracle8"], testcases_file="java__v1.0.3" + ), + ), + ( + "v1.9.1", + ReleaseInfo( + runtimes=["java_oracle8"], testcases_file="java__v1.0.3" + ), + ), + ( + "v1.10.1", + ReleaseInfo( + runtimes=["java_oracle8"], testcases_file="java__v1.0.3" + ), + ), + ( + "v1.11.0", + ReleaseInfo( + runtimes=["java_oracle8"], testcases_file="java__v1.0.3" + ), + ), + ("v1.12.1", ReleaseInfo(testcases_file="java__v1.0.3")), + ("v1.13.2", 
ReleaseInfo(testcases_file="java__v1.0.3")), + ( + "v1.14.0", + ReleaseInfo( + runtimes=["java_oracle8"], testcases_file="java__v1.0.3" + ), + ), + ("v1.15.1", ReleaseInfo(testcases_file="java__v1.0.3")), + ( + "v1.16.1", + ReleaseInfo( + runtimes=["java_oracle8"], testcases_file="java__v1.0.3" + ), + ), + ("v1.17.2", ReleaseInfo(testcases_file="java__v1.0.3")), + ( + "v1.18.0", + ReleaseInfo( + runtimes=["java_oracle8"], testcases_file="java__v1.0.3" + ), + ), + ( + "v1.19.0", + ReleaseInfo( + runtimes=["java_oracle8"], testcases_file="java__v1.0.3" + ), + ), + ("v1.20.0", ReleaseInfo(runtimes=["java_oracle8"])), + ("v1.21.1", ReleaseInfo()), + ("v1.22.2", ReleaseInfo()), + ("v1.23.0", ReleaseInfo()), + ("v1.24.0", ReleaseInfo()), + ("v1.25.0", ReleaseInfo()), + ("v1.26.1", ReleaseInfo()), + ("v1.27.2", ReleaseInfo()), + ("v1.28.1", ReleaseInfo()), + ("v1.29.0", ReleaseInfo()), + ("v1.30.2", ReleaseInfo()), + ("v1.31.2", ReleaseInfo()), + ("v1.32.3", ReleaseInfo()), + ("v1.33.1", ReleaseInfo()), + ("v1.34.1", ReleaseInfo()), + ("v1.35.1", ReleaseInfo()), + ("v1.36.3", ReleaseInfo()), + ("v1.37.1", ReleaseInfo()), + ("v1.38.1", ReleaseInfo()), + ("v1.39.0", ReleaseInfo()), + ("v1.40.2", ReleaseInfo()), + ("v1.41.3", ReleaseInfo()), + ("v1.42.3", ReleaseInfo()), + ("v1.43.3", ReleaseInfo()), + ("v1.44.2", ReleaseInfo()), + ("v1.45.3", ReleaseInfo()), + ("v1.46.1", ReleaseInfo()), + ("v1.47.1", ReleaseInfo()), + ("v1.48.2", ReleaseInfo()), + ("v1.49.2", ReleaseInfo()), + ("v1.50.2", ReleaseInfo()), + ("v1.51.1", ReleaseInfo()), + ("v1.52.0", ReleaseInfo()), + ("v1.54.0", ReleaseInfo()), + ] + ), + "python": OrderedDict( + [ + ( + "v1.0.x", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.0.x" + ), + ), + ( + "v1.1.4", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.0.x" + ), + ), + ( + "v1.2.5", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.0.x" + ), + ), + ( + "v1.3.9", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.0.x" + ), + ), + ( + "v1.4.2", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.0.x" + ), + ), + ( + "v1.6.6", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.0.x" + ), + ), + ( + "v1.7.2", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.0.x" + ), + ), + ( + "v1.8.1", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.0.x" + ), + ), + ( + "v1.9.1", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.0.x" + ), + ), + ( + "v1.10.1", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.0.x" + ), + ), + ( + "v1.11.1", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.11.1" + ), + ), + ( + "v1.12.0", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.11.1" + ), + ), + ( + "v1.13.0", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.11.1" + ), + ), + ( + "v1.14.1", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.11.1" + ), + ), + ( + "v1.15.0", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.11.1" + ), + ), + ( + "v1.16.0", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.11.1" + ), + ), + ( + "v1.17.1", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.11.1" + ), + ), + ( + "v1.18.0", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.18.0" + ), + ), + ( + "v1.19.0", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.18.0" + ), + ), + ( + "v1.20.0", + ReleaseInfo( + runtimes=["python"], 
testcases_file="python__v1.18.0" + ), + ), + ( + "v1.21.4", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.18.0" + ), + ), + ( + "v1.22.0", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.18.0" + ), + ), + ( + "v1.22.1", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.18.0" + ), + ), + ( + "v1.23.0", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.18.0" + ), + ), + ( + "v1.24.0", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.18.0" + ), + ), + ( + "v1.25.0", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.18.0" + ), + ), + ( + "v1.26.0", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.18.0" + ), + ), + ( + "v1.27.3", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.18.0" + ), + ), + ( + "v1.30.0", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.18.0" + ), + ), + ( + "v1.31.1", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.18.0" + ), + ), + ( + "v1.32.0", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.18.0" + ), + ), + ( + "v1.33.2", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.18.0" + ), + ), + ( + "v1.34.0", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.18.0" + ), + ), + ( + "v1.35.0", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.18.0" + ), + ), + ( + "v1.36.3", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.18.0" + ), + ), + ( + "v1.37.0", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.18.0" + ), + ), + ( + "v1.38.0", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.18.0" + ), + ), + ( + "v1.39.0", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.18.0" + ), + ), + ( + "v1.41.1", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.41.1" + ), + ), + ( + "v1.42.0", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.41.1" + ), + ), + ( + "v1.43.2", + ReleaseInfo( + runtimes=["python"], testcases_file="python__v1.41.1" + ), + ), + ( + "v1.44.0", + ReleaseInfo( + runtimes=["python"], testcases_file="python__master" + ), + ), + ( + "v1.46.2", + ReleaseInfo( + runtimes=["python"], testcases_file="python__master" + ), + ), + ( + "v1.47.1", + ReleaseInfo( + runtimes=["python"], testcases_file="python__master" + ), + ), + ( + "v1.48.3", + ReleaseInfo( + runtimes=["python"], testcases_file="python__master" + ), + ), + ( + "v1.49.1", + ReleaseInfo( + runtimes=["python"], testcases_file="python__master" + ), + ), + ( + "v1.52.0", + ReleaseInfo( + runtimes=["python"], testcases_file="python__master" + ), + ), + ( + "v1.53.0", + ReleaseInfo( + runtimes=["python"], testcases_file="python__master" + ), + ), + ( + "v1.54.0", + ReleaseInfo( + runtimes=["python"], testcases_file="python__master" + ), + ), + ( + "v1.55.0", + ReleaseInfo( + runtimes=["python"], testcases_file="python__master" + ), + ), + ] + ), + "node": OrderedDict( + [ + ("v1.0.1", ReleaseInfo(testcases_file="node__v1.0.1")), + ("v1.1.4", ReleaseInfo(testcases_file="node__v1.1.4")), + ("v1.2.5", ReleaseInfo(testcases_file="node__v1.1.4")), + ("v1.3.9", ReleaseInfo(testcases_file="node__v1.1.4")), + ("v1.4.2", ReleaseInfo(testcases_file="node__v1.1.4")), + ("v1.6.6", ReleaseInfo(testcases_file="node__v1.1.4")), # TODO: https://github.com/grpc/grpc-node/issues/235. 
# ('v1.7.2', ReleaseInfo()), - ('v1.8.4', ReleaseInfo()), - ('v1.9.1', ReleaseInfo()), - ('v1.10.0', ReleaseInfo()), - ('v1.11.3', ReleaseInfo()), - ('v1.12.4', ReleaseInfo()), - ]), - 'ruby': - OrderedDict([ - ('v1.0.1', - ReleaseInfo(patch=[ - 'tools/dockerfile/interoptest/grpc_interop_ruby/Dockerfile', - 'tools/dockerfile/interoptest/grpc_interop_ruby/build_interop.sh', - ], - testcases_file='ruby__v1.0.1')), - ('v1.1.4', ReleaseInfo(testcases_file='ruby__v1.1.4')), - ('v1.2.5', ReleaseInfo(testcases_file='ruby__v1.1.4')), - ('v1.3.9', ReleaseInfo(testcases_file='ruby__v1.1.4')), - ('v1.4.2', ReleaseInfo(testcases_file='ruby__v1.1.4')), - ('v1.6.6', ReleaseInfo(testcases_file='ruby__v1.1.4')), - ('v1.7.2', ReleaseInfo(testcases_file='ruby__v1.1.4')), - ('v1.8.0', ReleaseInfo(testcases_file='ruby__v1.1.4')), - ('v1.9.1', ReleaseInfo(testcases_file='ruby__v1.1.4')), - ('v1.10.1', ReleaseInfo(testcases_file='ruby__v1.1.4')), - ('v1.11.1', ReleaseInfo(testcases_file='ruby__v1.1.4')), - ('v1.12.0', ReleaseInfo(testcases_file='ruby__v1.1.4')), - ('v1.13.0', ReleaseInfo(testcases_file='ruby__v1.1.4')), - ('v1.14.1', ReleaseInfo(testcases_file='ruby__v1.1.4')), - ('v1.15.0', ReleaseInfo(testcases_file='ruby__v1.1.4')), - ('v1.16.0', ReleaseInfo(testcases_file='ruby__v1.1.4')), - ('v1.17.1', ReleaseInfo(testcases_file='ruby__v1.1.4')), - ('v1.18.0', - ReleaseInfo(patch=[ - 'tools/dockerfile/interoptest/grpc_interop_ruby/build_interop.sh', - ])), - ('v1.19.0', ReleaseInfo()), - ('v1.20.0', ReleaseInfo()), - ('v1.21.4', ReleaseInfo()), - ('v1.22.0', ReleaseInfo()), - ('v1.22.1', ReleaseInfo()), - ('v1.23.0', ReleaseInfo()), - ('v1.24.0', ReleaseInfo()), - ('v1.25.0', ReleaseInfo()), + ("v1.8.4", ReleaseInfo()), + ("v1.9.1", ReleaseInfo()), + ("v1.10.0", ReleaseInfo()), + ("v1.11.3", ReleaseInfo()), + ("v1.12.4", ReleaseInfo()), + ] + ), + "ruby": OrderedDict( + [ + ( + "v1.0.1", + ReleaseInfo( + patch=[ + "tools/dockerfile/interoptest/grpc_interop_ruby/Dockerfile", + "tools/dockerfile/interoptest/grpc_interop_ruby/build_interop.sh", + ], + testcases_file="ruby__v1.0.1", + ), + ), + ("v1.1.4", ReleaseInfo(testcases_file="ruby__v1.1.4")), + ("v1.2.5", ReleaseInfo(testcases_file="ruby__v1.1.4")), + ("v1.3.9", ReleaseInfo(testcases_file="ruby__v1.1.4")), + ("v1.4.2", ReleaseInfo(testcases_file="ruby__v1.1.4")), + ("v1.6.6", ReleaseInfo(testcases_file="ruby__v1.1.4")), + ("v1.7.2", ReleaseInfo(testcases_file="ruby__v1.1.4")), + ("v1.8.0", ReleaseInfo(testcases_file="ruby__v1.1.4")), + ("v1.9.1", ReleaseInfo(testcases_file="ruby__v1.1.4")), + ("v1.10.1", ReleaseInfo(testcases_file="ruby__v1.1.4")), + ("v1.11.1", ReleaseInfo(testcases_file="ruby__v1.1.4")), + ("v1.12.0", ReleaseInfo(testcases_file="ruby__v1.1.4")), + ("v1.13.0", ReleaseInfo(testcases_file="ruby__v1.1.4")), + ("v1.14.1", ReleaseInfo(testcases_file="ruby__v1.1.4")), + ("v1.15.0", ReleaseInfo(testcases_file="ruby__v1.1.4")), + ("v1.16.0", ReleaseInfo(testcases_file="ruby__v1.1.4")), + ("v1.17.1", ReleaseInfo(testcases_file="ruby__v1.1.4")), + ( + "v1.18.0", + ReleaseInfo( + patch=[ + "tools/dockerfile/interoptest/grpc_interop_ruby/build_interop.sh", + ] + ), + ), + ("v1.19.0", ReleaseInfo()), + ("v1.20.0", ReleaseInfo()), + ("v1.21.4", ReleaseInfo()), + ("v1.22.0", ReleaseInfo()), + ("v1.22.1", ReleaseInfo()), + ("v1.23.0", ReleaseInfo()), + ("v1.24.0", ReleaseInfo()), + ("v1.25.0", ReleaseInfo()), # TODO: https://github.com/grpc/grpc/issues/18262. 
# If you are not encountering the error in above issue # go ahead and upload the docker image for new releases. - ('v1.26.0', ReleaseInfo()), - ('v1.27.3', ReleaseInfo()), - ('v1.30.0', ReleaseInfo()), - ('v1.31.1', ReleaseInfo()), - ('v1.32.0', ReleaseInfo()), - ('v1.33.2', ReleaseInfo()), - ('v1.34.0', ReleaseInfo()), - ('v1.35.0', ReleaseInfo()), - ('v1.36.3', ReleaseInfo()), - ('v1.37.0', ReleaseInfo()), - ('v1.38.0', ReleaseInfo()), - ('v1.39.0', ReleaseInfo()), - ('v1.41.1', ReleaseInfo()), - ('v1.42.0', ReleaseInfo()), - ('v1.43.0', ReleaseInfo()), - ('v1.44.0', ReleaseInfo()), - ('v1.46.2', ReleaseInfo()), - ('v1.47.1', ReleaseInfo()), - ('v1.48.3', ReleaseInfo()), - ('v1.49.1', ReleaseInfo()), - ('v1.52.0', ReleaseInfo()), - ('v1.53.0', ReleaseInfo()), - ('v1.54.0', ReleaseInfo()), - ('v1.55.0', ReleaseInfo()), - ]), - 'php': - OrderedDict([ - ('v1.0.1', ReleaseInfo(testcases_file='php__v1.0.1')), - ('v1.1.4', ReleaseInfo(testcases_file='php__v1.0.1')), - ('v1.2.5', ReleaseInfo(testcases_file='php__v1.0.1')), - ('v1.3.9', ReleaseInfo(testcases_file='php__v1.0.1')), - ('v1.4.2', ReleaseInfo(testcases_file='php__v1.0.1')), - ('v1.6.6', ReleaseInfo(testcases_file='php__v1.0.1')), - ('v1.7.2', ReleaseInfo(testcases_file='php__v1.0.1')), - ('v1.8.0', ReleaseInfo(testcases_file='php__v1.0.1')), - ('v1.9.1', ReleaseInfo(testcases_file='php__v1.0.1')), - ('v1.10.1', ReleaseInfo(testcases_file='php__v1.0.1')), - ('v1.11.1', ReleaseInfo(testcases_file='php__v1.0.1')), - ('v1.12.0', ReleaseInfo(testcases_file='php__v1.0.1')), - ('v1.13.0', ReleaseInfo(testcases_file='php__v1.0.1')), - ('v1.14.1', ReleaseInfo(testcases_file='php__v1.0.1')), - ('v1.15.0', ReleaseInfo(testcases_file='php__v1.0.1')), - ('v1.16.0', ReleaseInfo(testcases_file='php__v1.0.1')), - ('v1.17.1', ReleaseInfo(testcases_file='php__v1.0.1')), - ('v1.18.0', ReleaseInfo()), + ("v1.26.0", ReleaseInfo()), + ("v1.27.3", ReleaseInfo()), + ("v1.30.0", ReleaseInfo()), + ("v1.31.1", ReleaseInfo()), + ("v1.32.0", ReleaseInfo()), + ("v1.33.2", ReleaseInfo()), + ("v1.34.0", ReleaseInfo()), + ("v1.35.0", ReleaseInfo()), + ("v1.36.3", ReleaseInfo()), + ("v1.37.0", ReleaseInfo()), + ("v1.38.0", ReleaseInfo()), + ("v1.39.0", ReleaseInfo()), + ("v1.41.1", ReleaseInfo()), + ("v1.42.0", ReleaseInfo()), + ("v1.43.0", ReleaseInfo()), + ("v1.44.0", ReleaseInfo()), + ("v1.46.2", ReleaseInfo()), + ("v1.47.1", ReleaseInfo()), + ("v1.48.3", ReleaseInfo()), + ("v1.49.1", ReleaseInfo()), + ("v1.52.0", ReleaseInfo()), + ("v1.53.0", ReleaseInfo()), + ("v1.54.0", ReleaseInfo()), + ("v1.55.0", ReleaseInfo()), + ] + ), + "php": OrderedDict( + [ + ("v1.0.1", ReleaseInfo(testcases_file="php__v1.0.1")), + ("v1.1.4", ReleaseInfo(testcases_file="php__v1.0.1")), + ("v1.2.5", ReleaseInfo(testcases_file="php__v1.0.1")), + ("v1.3.9", ReleaseInfo(testcases_file="php__v1.0.1")), + ("v1.4.2", ReleaseInfo(testcases_file="php__v1.0.1")), + ("v1.6.6", ReleaseInfo(testcases_file="php__v1.0.1")), + ("v1.7.2", ReleaseInfo(testcases_file="php__v1.0.1")), + ("v1.8.0", ReleaseInfo(testcases_file="php__v1.0.1")), + ("v1.9.1", ReleaseInfo(testcases_file="php__v1.0.1")), + ("v1.10.1", ReleaseInfo(testcases_file="php__v1.0.1")), + ("v1.11.1", ReleaseInfo(testcases_file="php__v1.0.1")), + ("v1.12.0", ReleaseInfo(testcases_file="php__v1.0.1")), + ("v1.13.0", ReleaseInfo(testcases_file="php__v1.0.1")), + ("v1.14.1", ReleaseInfo(testcases_file="php__v1.0.1")), + ("v1.15.0", ReleaseInfo(testcases_file="php__v1.0.1")), + ("v1.16.0", ReleaseInfo(testcases_file="php__v1.0.1")), + 
("v1.17.1", ReleaseInfo(testcases_file="php__v1.0.1")), + ("v1.18.0", ReleaseInfo()), # v1.19 and v1.20 were deliberately omitted here because of an issue. # See https://github.com/grpc/grpc/issues/18264 - ('v1.21.4', ReleaseInfo()), - ('v1.22.0', ReleaseInfo()), - ('v1.22.1', ReleaseInfo()), - ('v1.23.0', ReleaseInfo()), - ('v1.24.0', ReleaseInfo()), - ('v1.25.0', ReleaseInfo()), - ('v1.26.0', ReleaseInfo()), - ('v1.27.3', ReleaseInfo()), - ('v1.30.0', ReleaseInfo()), - ('v1.31.1', ReleaseInfo()), - ('v1.32.0', ReleaseInfo()), - ('v1.33.2', ReleaseInfo()), - ('v1.34.0', ReleaseInfo()), - ('v1.35.0', ReleaseInfo()), - ('v1.36.3', ReleaseInfo()), - ('v1.37.0', ReleaseInfo()), - ('v1.38.0', ReleaseInfo()), - ('v1.39.0', ReleaseInfo()), - ('v1.41.1', ReleaseInfo()), - ('v1.42.0', ReleaseInfo()), - ('v1.43.0', ReleaseInfo()), - ('v1.44.0', ReleaseInfo()), - ('v1.46.2', ReleaseInfo()), - ('v1.47.1', ReleaseInfo()), - ('v1.48.3', ReleaseInfo()), - ('v1.49.1', ReleaseInfo()), - ('v1.52.0', ReleaseInfo()), - ('v1.53.0', ReleaseInfo()), - ('v1.54.0', ReleaseInfo()), - ('v1.55.0', ReleaseInfo()), - ]), - 'csharp': - OrderedDict([ - ('v1.0.1', - ReleaseInfo(patch=[ - 'tools/dockerfile/interoptest/grpc_interop_csharp/Dockerfile', - 'tools/dockerfile/interoptest/grpc_interop_csharpcoreclr/Dockerfile', - ], - testcases_file='csharp__v1.1.4')), - ('v1.1.4', ReleaseInfo(testcases_file='csharp__v1.1.4')), - ('v1.2.5', ReleaseInfo(testcases_file='csharp__v1.1.4')), - ('v1.3.9', ReleaseInfo(testcases_file='csharp__v1.3.9')), - ('v1.4.2', ReleaseInfo(testcases_file='csharp__v1.3.9')), - ('v1.6.6', ReleaseInfo(testcases_file='csharp__v1.3.9')), - ('v1.7.2', ReleaseInfo(testcases_file='csharp__v1.3.9')), - ('v1.8.0', ReleaseInfo(testcases_file='csharp__v1.3.9')), - ('v1.9.1', ReleaseInfo(testcases_file='csharp__v1.3.9')), - ('v1.10.1', ReleaseInfo(testcases_file='csharp__v1.3.9')), - ('v1.11.1', ReleaseInfo(testcases_file='csharp__v1.3.9')), - ('v1.12.0', ReleaseInfo(testcases_file='csharp__v1.3.9')), - ('v1.13.0', ReleaseInfo(testcases_file='csharp__v1.3.9')), - ('v1.14.1', ReleaseInfo(testcases_file='csharp__v1.3.9')), - ('v1.15.0', ReleaseInfo(testcases_file='csharp__v1.3.9')), - ('v1.16.0', ReleaseInfo(testcases_file='csharp__v1.3.9')), - ('v1.17.1', ReleaseInfo(testcases_file='csharp__v1.3.9')), - ('v1.18.0', ReleaseInfo(testcases_file='csharp__v1.18.0')), - ('v1.19.0', ReleaseInfo(testcases_file='csharp__v1.18.0')), - ('v1.20.0', ReleaseInfo(testcases_file='csharp__v1.20.0')), - ('v1.20.0', ReleaseInfo(testcases_file='csharp__v1.20.0')), - ('v1.21.4', ReleaseInfo(testcases_file='csharp__v1.20.0')), - ('v1.22.0', ReleaseInfo(testcases_file='csharp__v1.20.0')), - ('v1.22.1', ReleaseInfo(testcases_file='csharp__v1.20.0')), - ('v1.23.0', ReleaseInfo(testcases_file='csharp__v1.20.0')), - ('v1.24.0', ReleaseInfo(testcases_file='csharp__v1.20.0')), - ('v1.25.0', ReleaseInfo(testcases_file='csharp__v1.20.0')), - ('v1.26.0', ReleaseInfo(testcases_file='csharp__v1.20.0')), - ('v1.27.3', ReleaseInfo(testcases_file='csharp__v1.20.0')), - ('v1.30.0', ReleaseInfo(testcases_file='csharp__v1.20.0')), - ('v1.31.1', ReleaseInfo(testcases_file='csharp__v1.20.0')), - ('v1.32.0', ReleaseInfo(testcases_file='csharp__v1.20.0')), - ('v1.33.2', ReleaseInfo(testcases_file='csharp__v1.20.0')), - ('v1.34.0', ReleaseInfo(testcases_file='csharp__v1.20.0')), - ('v1.35.0', ReleaseInfo(testcases_file='csharp__v1.20.0')), - ('v1.36.3', ReleaseInfo(testcases_file='csharp__v1.20.0')), - ('v1.37.0', 
ReleaseInfo(testcases_file='csharp__v1.20.0')), - ('v1.38.1', ReleaseInfo(testcases_file='csharp__v1.20.0')), - ('v1.39.1', ReleaseInfo(testcases_file='csharp__v1.20.0')), - ('v1.41.1', ReleaseInfo(testcases_file='csharp__v1.20.0')), - ('v1.42.0', ReleaseInfo(testcases_file='csharp__v1.20.0')), - ('v1.43.0', ReleaseInfo()), - ('v1.44.0', ReleaseInfo()), - ('v1.46.2', ReleaseInfo()), - ]), + ("v1.21.4", ReleaseInfo()), + ("v1.22.0", ReleaseInfo()), + ("v1.22.1", ReleaseInfo()), + ("v1.23.0", ReleaseInfo()), + ("v1.24.0", ReleaseInfo()), + ("v1.25.0", ReleaseInfo()), + ("v1.26.0", ReleaseInfo()), + ("v1.27.3", ReleaseInfo()), + ("v1.30.0", ReleaseInfo()), + ("v1.31.1", ReleaseInfo()), + ("v1.32.0", ReleaseInfo()), + ("v1.33.2", ReleaseInfo()), + ("v1.34.0", ReleaseInfo()), + ("v1.35.0", ReleaseInfo()), + ("v1.36.3", ReleaseInfo()), + ("v1.37.0", ReleaseInfo()), + ("v1.38.0", ReleaseInfo()), + ("v1.39.0", ReleaseInfo()), + ("v1.41.1", ReleaseInfo()), + ("v1.42.0", ReleaseInfo()), + ("v1.43.0", ReleaseInfo()), + ("v1.44.0", ReleaseInfo()), + ("v1.46.2", ReleaseInfo()), + ("v1.47.1", ReleaseInfo()), + ("v1.48.3", ReleaseInfo()), + ("v1.49.1", ReleaseInfo()), + ("v1.52.0", ReleaseInfo()), + ("v1.53.0", ReleaseInfo()), + ("v1.54.0", ReleaseInfo()), + ("v1.55.0", ReleaseInfo()), + ] + ), + "csharp": OrderedDict( + [ + ( + "v1.0.1", + ReleaseInfo( + patch=[ + "tools/dockerfile/interoptest/grpc_interop_csharp/Dockerfile", + "tools/dockerfile/interoptest/grpc_interop_csharpcoreclr/Dockerfile", + ], + testcases_file="csharp__v1.1.4", + ), + ), + ("v1.1.4", ReleaseInfo(testcases_file="csharp__v1.1.4")), + ("v1.2.5", ReleaseInfo(testcases_file="csharp__v1.1.4")), + ("v1.3.9", ReleaseInfo(testcases_file="csharp__v1.3.9")), + ("v1.4.2", ReleaseInfo(testcases_file="csharp__v1.3.9")), + ("v1.6.6", ReleaseInfo(testcases_file="csharp__v1.3.9")), + ("v1.7.2", ReleaseInfo(testcases_file="csharp__v1.3.9")), + ("v1.8.0", ReleaseInfo(testcases_file="csharp__v1.3.9")), + ("v1.9.1", ReleaseInfo(testcases_file="csharp__v1.3.9")), + ("v1.10.1", ReleaseInfo(testcases_file="csharp__v1.3.9")), + ("v1.11.1", ReleaseInfo(testcases_file="csharp__v1.3.9")), + ("v1.12.0", ReleaseInfo(testcases_file="csharp__v1.3.9")), + ("v1.13.0", ReleaseInfo(testcases_file="csharp__v1.3.9")), + ("v1.14.1", ReleaseInfo(testcases_file="csharp__v1.3.9")), + ("v1.15.0", ReleaseInfo(testcases_file="csharp__v1.3.9")), + ("v1.16.0", ReleaseInfo(testcases_file="csharp__v1.3.9")), + ("v1.17.1", ReleaseInfo(testcases_file="csharp__v1.3.9")), + ("v1.18.0", ReleaseInfo(testcases_file="csharp__v1.18.0")), + ("v1.19.0", ReleaseInfo(testcases_file="csharp__v1.18.0")), + ("v1.20.0", ReleaseInfo(testcases_file="csharp__v1.20.0")), + ("v1.20.0", ReleaseInfo(testcases_file="csharp__v1.20.0")), + ("v1.21.4", ReleaseInfo(testcases_file="csharp__v1.20.0")), + ("v1.22.0", ReleaseInfo(testcases_file="csharp__v1.20.0")), + ("v1.22.1", ReleaseInfo(testcases_file="csharp__v1.20.0")), + ("v1.23.0", ReleaseInfo(testcases_file="csharp__v1.20.0")), + ("v1.24.0", ReleaseInfo(testcases_file="csharp__v1.20.0")), + ("v1.25.0", ReleaseInfo(testcases_file="csharp__v1.20.0")), + ("v1.26.0", ReleaseInfo(testcases_file="csharp__v1.20.0")), + ("v1.27.3", ReleaseInfo(testcases_file="csharp__v1.20.0")), + ("v1.30.0", ReleaseInfo(testcases_file="csharp__v1.20.0")), + ("v1.31.1", ReleaseInfo(testcases_file="csharp__v1.20.0")), + ("v1.32.0", ReleaseInfo(testcases_file="csharp__v1.20.0")), + ("v1.33.2", ReleaseInfo(testcases_file="csharp__v1.20.0")), + ("v1.34.0", 
ReleaseInfo(testcases_file="csharp__v1.20.0")), + ("v1.35.0", ReleaseInfo(testcases_file="csharp__v1.20.0")), + ("v1.36.3", ReleaseInfo(testcases_file="csharp__v1.20.0")), + ("v1.37.0", ReleaseInfo(testcases_file="csharp__v1.20.0")), + ("v1.38.1", ReleaseInfo(testcases_file="csharp__v1.20.0")), + ("v1.39.1", ReleaseInfo(testcases_file="csharp__v1.20.0")), + ("v1.41.1", ReleaseInfo(testcases_file="csharp__v1.20.0")), + ("v1.42.0", ReleaseInfo(testcases_file="csharp__v1.20.0")), + ("v1.43.0", ReleaseInfo()), + ("v1.44.0", ReleaseInfo()), + ("v1.46.2", ReleaseInfo()), + ] + ), } diff --git a/tools/interop_matrix/create_matrix_images.py b/tools/interop_matrix/create_matrix_images.py index 5ad639c3d4cc3..6d86540567f3e 100755 --- a/tools/interop_matrix/create_matrix_images.py +++ b/tools/interop_matrix/create_matrix_images.py @@ -29,70 +29,99 @@ import client_matrix python_util_dir = os.path.abspath( - os.path.join(os.path.dirname(__file__), '../run_tests/python_utils')) + os.path.join(os.path.dirname(__file__), "../run_tests/python_utils") +) sys.path.append(python_util_dir) import dockerjob import jobset -_IMAGE_BUILDER = 'tools/run_tests/dockerize/build_interop_image.sh' +_IMAGE_BUILDER = "tools/run_tests/dockerize/build_interop_image.sh" _LANGUAGES = list(client_matrix.LANG_RUNTIME_MATRIX.keys()) # All gRPC release tags, flattened, deduped and sorted. _RELEASES = sorted( list( - set(release + set( + release for release_dict in list(client_matrix.LANG_RELEASE_MATRIX.values()) - for release in list(release_dict.keys())))) + for release in list(release_dict.keys()) + ) + ) +) # Destination directory inside docker image to keep extra info from build time. -_BUILD_INFO = '/var/local/build_info' - -argp = argparse.ArgumentParser(description='Run interop tests.') -argp.add_argument('--gcr_path', - default='gcr.io/grpc-testing', - help='Path of docker images in Google Container Registry') - -argp.add_argument('--release', - default='master', - choices=['all', 'master'] + _RELEASES, - help='github commit tag to checkout. When building all ' - 'releases defined in client_matrix.py, use "all". Valid only ' - 'with --git_checkout.') - -argp.add_argument('-l', - '--language', - choices=['all'] + sorted(_LANGUAGES), - nargs='+', - default=['all'], - help='Test languages to build docker images for.') - -argp.add_argument('--git_checkout', - action='store_true', - help='Use a separate git clone tree for building grpc stack. ' - 'Required when using --release flag. By default, current' - 'tree and the sibling will be used for building grpc stack.') - -argp.add_argument('--git_checkout_root', - default='/export/hda3/tmp/grpc_matrix', - help='Directory under which grpc-go/java/main repo will be ' - 'cloned. Valid only with --git_checkout.') - -argp.add_argument('--keep', - action='store_true', - help='keep the created local images after uploading to GCR') - -argp.add_argument('--reuse_git_root', - default=False, - action='store_const', - const=True, - help='reuse the repo dir. 
If False, the existing git root ' - 'directory will removed before a clean checkout, because ' - 'reusing the repo can cause git checkout error if you switch ' - 'between releases.') +_BUILD_INFO = "/var/local/build_info" + +argp = argparse.ArgumentParser(description="Run interop tests.") +argp.add_argument( + "--gcr_path", + default="gcr.io/grpc-testing", + help="Path of docker images in Google Container Registry", +) + +argp.add_argument( + "--release", + default="master", + choices=["all", "master"] + _RELEASES, + help=( + "github commit tag to checkout. When building all " + 'releases defined in client_matrix.py, use "all". Valid only ' + "with --git_checkout." + ), +) + +argp.add_argument( + "-l", + "--language", + choices=["all"] + sorted(_LANGUAGES), + nargs="+", + default=["all"], + help="Test languages to build docker images for.", +) + +argp.add_argument( + "--git_checkout", + action="store_true", + help=( + "Use a separate git clone tree for building grpc stack. " + "Required when using --release flag. By default, current" + "tree and the sibling will be used for building grpc stack." + ), +) + +argp.add_argument( + "--git_checkout_root", + default="/export/hda3/tmp/grpc_matrix", + help=( + "Directory under which grpc-go/java/main repo will be " + "cloned. Valid only with --git_checkout." + ), +) + +argp.add_argument( + "--keep", + action="store_true", + help="keep the created local images after uploading to GCR", +) + +argp.add_argument( + "--reuse_git_root", + default=False, + action="store_const", + const=True, + help=( + "reuse the repo dir. If False, the existing git root " + "directory will removed before a clean checkout, because " + "reusing the repo can cause git checkout error if you switch " + "between releases." + ), +) argp.add_argument( - '--upload_images', - action='store_true', - help='If set, images will be uploaded to container registry after building.' + "--upload_images", + action="store_true", + help=( + "If set, images will be uploaded to container registry after building." + ), ) args = argp.parse_args() @@ -101,38 +130,38 @@ def add_files_to_image(image, with_files, label=None): """Add files to a docker image. - image: docker image name, i.e. grpc_interop_java:26328ad8 - with_files: additional files to include in the docker image. - label: label string to attach to the image. - """ - tag_idx = image.find(':') + image: docker image name, i.e. grpc_interop_java:26328ad8 + with_files: additional files to include in the docker image. + label: label string to attach to the image. + """ + tag_idx = image.find(":") if tag_idx == -1: - jobset.message('FAILED', - 'invalid docker image %s' % image, - do_newline=True) + jobset.message( + "FAILED", "invalid docker image %s" % image, do_newline=True + ) sys.exit(1) - orig_tag = '%s_' % image - subprocess.check_output(['docker', 'tag', image, orig_tag]) + orig_tag = "%s_" % image + subprocess.check_output(["docker", "tag", image, orig_tag]) - lines = ['FROM ' + orig_tag] + lines = ["FROM " + orig_tag] if label: - lines.append('LABEL %s' % label) + lines.append("LABEL %s" % label) temp_dir = tempfile.mkdtemp() - atexit.register(lambda: subprocess.call(['rm', '-rf', temp_dir])) + atexit.register(lambda: subprocess.call(["rm", "-rf", temp_dir])) # Copy with_files inside the tmp directory, which will be the docker build # context. 
for f in with_files: shutil.copy(f, temp_dir) - lines.append('COPY %s %s/' % (os.path.basename(f), _BUILD_INFO)) + lines.append("COPY %s %s/" % (os.path.basename(f), _BUILD_INFO)) # Create a Dockerfile. - with open(os.path.join(temp_dir, 'Dockerfile'), 'w') as f: - f.write('\n'.join(lines)) + with open(os.path.join(temp_dir, "Dockerfile"), "w") as f: + f.write("\n".join(lines)) - jobset.message('START', 'Repackaging %s' % image, do_newline=True) - build_cmd = ['docker', 'build', '--rm', '--tag', image, temp_dir] + jobset.message("START", "Repackaging %s" % image, do_newline=True) + build_cmd = ["docker", "build", "--rm", "--tag", image, temp_dir] subprocess.check_output(build_cmd) dockerjob.remove_image(orig_tag, skip_nonexistent=True) @@ -140,22 +169,24 @@ def add_files_to_image(image, with_files, label=None): def build_image_jobspec(runtime, env, gcr_tag, stack_base): """Build interop docker image for a language with runtime. - runtime: a string, for example go1.8. - env: dictionary of env to passed to the build script. - gcr_tag: the tag for the docker image (i.e. v1.3.0). - stack_base: the local gRPC repo path. - """ - basename = 'grpc_interop_%s' % runtime - tag = '%s/%s:%s' % (args.gcr_path, basename, gcr_tag) - build_env = {'INTEROP_IMAGE': tag, 'BASE_NAME': basename} + runtime: a string, for example go1.8. + env: dictionary of env to passed to the build script. + gcr_tag: the tag for the docker image (i.e. v1.3.0). + stack_base: the local gRPC repo path. + """ + basename = "grpc_interop_%s" % runtime + tag = "%s/%s:%s" % (args.gcr_path, basename, gcr_tag) + build_env = {"INTEROP_IMAGE": tag, "BASE_NAME": basename} build_env.update(env) image_builder_path = _IMAGE_BUILDER if client_matrix.should_build_docker_interop_image_from_release_tag(lang): image_builder_path = os.path.join(stack_base, _IMAGE_BUILDER) - build_job = jobset.JobSpec(cmdline=[image_builder_path], - environ=build_env, - shortname='build_docker_%s' % runtime, - timeout_seconds=30 * 60) + build_job = jobset.JobSpec( + cmdline=[image_builder_path], + environ=build_env, + shortname="build_docker_%s" % runtime, + timeout_seconds=30 * 60, + ) build_job.tag = tag return build_job @@ -163,31 +194,36 @@ def build_image_jobspec(runtime, env, gcr_tag, stack_base): def build_all_images_for_lang(lang): """Build all docker images for a language across releases and runtimes.""" if not args.git_checkout: - if args.release != 'master': + if args.release != "master": print( - 'Cannot use --release without also enabling --git_checkout.\n') + "Cannot use --release without also enabling --git_checkout.\n" + ) sys.exit(1) releases = [args.release] else: - if args.release == 'all': + if args.release == "all": releases = client_matrix.get_release_tags(lang) else: # Build a particular release. - if args.release not in ['master' - ] + client_matrix.get_release_tags(lang): - jobset.message('SKIPPED', - '%s for %s is not defined' % - (args.release, lang), - do_newline=True) + if args.release not in ["master"] + client_matrix.get_release_tags( + lang + ): + jobset.message( + "SKIPPED", + "%s for %s is not defined" % (args.release, lang), + do_newline=True, + ) return [] releases = [args.release] images = [] for release in releases: images += build_all_images_for_release(lang, release) - jobset.message('SUCCESS', - 'All docker images built for %s at %s.' % (lang, releases), - do_newline=True) + jobset.message( + "SUCCESS", + "All docker images built for %s at %s." 
% (lang, releases), + do_newline=True, + ) return images @@ -198,14 +234,14 @@ def build_all_images_for_release(lang, release): env = {} # If we not using current tree or the sibling for grpc stack, do checkout. - stack_base = '' + stack_base = "" if args.git_checkout: stack_base = checkout_grpc_stack(lang, release) var = { - 'go': 'GRPC_GO_ROOT', - 'java': 'GRPC_JAVA_ROOT', - 'node': 'GRPC_NODE_ROOT' - }.get(lang, 'GRPC_ROOT') + "go": "GRPC_GO_ROOT", + "java": "GRPC_JAVA_ROOT", + "node": "GRPC_NODE_ROOT", + }.get(lang, "GRPC_ROOT") env[var] = stack_base for runtime in client_matrix.get_runtimes_for_lang_release(lang, release): @@ -213,28 +249,30 @@ def build_all_images_for_release(lang, release): docker_images.append(job.tag) build_jobs.append(job) - jobset.message('START', 'Building interop docker images.', do_newline=True) - print('Jobs to run: \n%s\n' % '\n'.join(str(j) for j in build_jobs)) + jobset.message("START", "Building interop docker images.", do_newline=True) + print("Jobs to run: \n%s\n" % "\n".join(str(j) for j in build_jobs)) - num_failures, _ = jobset.run(build_jobs, - newline_on_success=True, - maxjobs=multiprocessing.cpu_count()) + num_failures, _ = jobset.run( + build_jobs, newline_on_success=True, maxjobs=multiprocessing.cpu_count() + ) if num_failures: - jobset.message('FAILED', - 'Failed to build interop docker images.', - do_newline=True) + jobset.message( + "FAILED", "Failed to build interop docker images.", do_newline=True + ) docker_images_cleanup.extend(docker_images) sys.exit(1) - jobset.message('SUCCESS', - 'All docker images built for %s at %s.' % (lang, release), - do_newline=True) + jobset.message( + "SUCCESS", + "All docker images built for %s at %s." % (lang, release), + do_newline=True, + ) - if release != 'master': - commit_log = os.path.join(stack_base, 'commit_log') + if release != "master": + commit_log = os.path.join(stack_base, "commit_log") if os.path.exists(commit_log): for image in docker_images: - add_files_to_image(image, [commit_log], 'release=%s' % release) + add_files_to_image(image, [commit_log], "release=%s" % release) return docker_images @@ -256,16 +294,18 @@ def maybe_apply_patches_on_git_tag(stack_base, lang, release): files_to_patch = release_info.patch if not files_to_patch: return - patch_file_relative_path = 'patches/%s_%s/git_repo.patch' % (lang, release) + patch_file_relative_path = "patches/%s_%s/git_repo.patch" % (lang, release) patch_file = os.path.abspath( - os.path.join(os.path.dirname(__file__), patch_file_relative_path)) + os.path.join(os.path.dirname(__file__), patch_file_relative_path) + ) if not os.path.exists(patch_file): - jobset.message('FAILED', - 'expected patch file |%s| to exist' % patch_file) + jobset.message( + "FAILED", "expected patch file |%s| to exist" % patch_file + ) sys.exit(1) - subprocess.check_output(['git', 'apply', patch_file], - cwd=stack_base, - stderr=subprocess.STDOUT) + subprocess.check_output( + ["git", "apply", patch_file], cwd=stack_base, stderr=subprocess.STDOUT + ) # TODO(jtattermusch): this really would need simplification and refactoring # - "git add" and "git commit" can easily be done in a single command @@ -274,17 +314,23 @@ def maybe_apply_patches_on_git_tag(stack_base, lang, release): # - we only allow a single patch with name "git_repo.patch". A better design # would be to allow multiple patches that can have more descriptive names. 
for repo_relative_path in files_to_patch: - subprocess.check_output(['git', 'add', repo_relative_path], - cwd=stack_base, - stderr=subprocess.STDOUT) - subprocess.check_output([ - 'git', 'commit', '-m', - ('Hack performed on top of %s git ' - 'tag in order to build and run the %s ' - 'interop tests on that tag.' % (lang, release)) - ], - cwd=stack_base, - stderr=subprocess.STDOUT) + subprocess.check_output( + ["git", "add", repo_relative_path], + cwd=stack_base, + stderr=subprocess.STDOUT, + ) + subprocess.check_output( + [ + "git", + "commit", + "-m", + "Hack performed on top of %s git " + "tag in order to build and run the %s " + "interop tests on that tag." % (lang, release), + ], + cwd=stack_base, + stderr=subprocess.STDOUT, + ) def checkout_grpc_stack(lang, release): @@ -302,59 +348,71 @@ def checkout_grpc_stack(lang, release): # Clean up leftover repo dir if necessary. if not args.reuse_git_root and os.path.exists(stack_base): - jobset.message('START', 'Removing git checkout root.', do_newline=True) + jobset.message("START", "Removing git checkout root.", do_newline=True) shutil.rmtree(stack_base) if not os.path.exists(stack_base): - subprocess.check_call(['git', 'clone', '--recursive', repo], - cwd=os.path.dirname(stack_base)) + subprocess.check_call( + ["git", "clone", "--recursive", repo], + cwd=os.path.dirname(stack_base), + ) # git checkout. - jobset.message('START', - 'git checkout %s from %s' % (release, stack_base), - do_newline=True) + jobset.message( + "START", + "git checkout %s from %s" % (release, stack_base), + do_newline=True, + ) # We should NEVER do checkout on current tree !!! assert not os.path.dirname(__file__).startswith(stack_base) - output = subprocess.check_output(['git', 'checkout', release], - cwd=stack_base, - stderr=subprocess.STDOUT) + output = subprocess.check_output( + ["git", "checkout", release], cwd=stack_base, stderr=subprocess.STDOUT + ) maybe_apply_patches_on_git_tag(stack_base, lang, release) - commit_log = subprocess.check_output(['git', 'log', '-1'], cwd=stack_base) - jobset.message('SUCCESS', - 'git checkout', - '%s: %s' % (str(output), commit_log), - do_newline=True) + commit_log = subprocess.check_output(["git", "log", "-1"], cwd=stack_base) + jobset.message( + "SUCCESS", + "git checkout", + "%s: %s" % (str(output), commit_log), + do_newline=True, + ) # git submodule update - jobset.message('START', - 'git submodule update --init at %s from %s' % - (release, stack_base), - do_newline=True) - subprocess.check_call(['git', 'submodule', 'update', '--init'], - cwd=stack_base, - stderr=subprocess.STDOUT) - jobset.message('SUCCESS', - 'git submodule update --init', - '%s: %s' % (str(output), commit_log), - do_newline=True) + jobset.message( + "START", + "git submodule update --init at %s from %s" % (release, stack_base), + do_newline=True, + ) + subprocess.check_call( + ["git", "submodule", "update", "--init"], + cwd=stack_base, + stderr=subprocess.STDOUT, + ) + jobset.message( + "SUCCESS", + "git submodule update --init", + "%s: %s" % (str(output), commit_log), + do_newline=True, + ) # Write git log to commit_log so it can be packaged with the docker image. 
- with open(os.path.join(stack_base, 'commit_log'), 'wb') as f: + with open(os.path.join(stack_base, "commit_log"), "wb") as f: f.write(commit_log) return stack_base -languages = args.language if args.language != ['all'] else _LANGUAGES +languages = args.language if args.language != ["all"] else _LANGUAGES for lang in languages: docker_images = build_all_images_for_lang(lang) for image in docker_images: if args.upload_images: - jobset.message('START', 'Uploading %s' % image, do_newline=True) + jobset.message("START", "Uploading %s" % image, do_newline=True) # docker image name must be in the format /: - assert image.startswith(args.gcr_path) and image.find(':') != -1 - subprocess.call(['gcloud', 'docker', '--', 'push', image]) + assert image.startswith(args.gcr_path) and image.find(":") != -1 + subprocess.call(["gcloud", "docker", "--", "push", image]) else: # Uploading (and overwriting images) by default can easily break things. print( - 'Not uploading image %s, run with --upload_images to upload.' % - image) + "Not uploading image %s, run with --upload_images to upload." + % image + ) diff --git a/tools/interop_matrix/run_interop_matrix_tests.py b/tools/interop_matrix/run_interop_matrix_tests.py index 42f77d38e2f5c..abeae00fc22ed 100755 --- a/tools/interop_matrix/run_interop_matrix_tests.py +++ b/tools/interop_matrix/run_interop_matrix_tests.py @@ -30,7 +30,8 @@ import client_matrix python_util_dir = os.path.abspath( - os.path.join(os.path.dirname(__file__), '../run_tests/python_utils')) + os.path.join(os.path.dirname(__file__), "../run_tests/python_utils") +) sys.path.append(python_util_dir) import dockerjob import jobset @@ -44,68 +45,90 @@ # All gRPC release tags, flattened, deduped and sorted. _RELEASES = sorted( list( - set(release + set( + release for release_dict in list(client_matrix.LANG_RELEASE_MATRIX.values()) - for release in list(release_dict.keys())))) - -argp = argparse.ArgumentParser(description='Run interop tests.') -argp.add_argument('-j', '--jobs', default=multiprocessing.cpu_count(), type=int) -argp.add_argument('--gcr_path', - default='gcr.io/grpc-testing', - help='Path of docker images in Google Container Registry') -argp.add_argument('--release', - default='all', - choices=['all'] + _RELEASES, - help='Release tags to test. When testing all ' - 'releases defined in client_matrix.py, use "all".') -argp.add_argument('-l', - '--language', - choices=['all'] + sorted(_LANGUAGES), - nargs='+', - default=['all'], - help='Languages to test') + for release in list(release_dict.keys()) + ) + ) +) + +argp = argparse.ArgumentParser(description="Run interop tests.") +argp.add_argument("-j", "--jobs", default=multiprocessing.cpu_count(), type=int) +argp.add_argument( + "--gcr_path", + default="gcr.io/grpc-testing", + help="Path of docker images in Google Container Registry", +) +argp.add_argument( + "--release", + default="all", + choices=["all"] + _RELEASES, + help=( + "Release tags to test. When testing all " + 'releases defined in client_matrix.py, use "all".' 
+ ), +) +argp.add_argument( + "-l", + "--language", + choices=["all"] + sorted(_LANGUAGES), + nargs="+", + default=["all"], + help="Languages to test", +) +argp.add_argument( + "--keep", + action="store_true", + help="keep the created local images after finishing the tests.", +) argp.add_argument( - '--keep', - action='store_true', - help='keep the created local images after finishing the tests.') -argp.add_argument('--report_file', - default='report.xml', - help='The result file to create.') -argp.add_argument('--allow_flakes', - default=False, - action='store_const', - const=True, - help=('Allow flaky tests to show as passing (re-runs failed ' - 'tests up to five times)')) -argp.add_argument('--bq_result_table', - default='', - type=str, - nargs='?', - help='Upload test results to a specified BQ table.') + "--report_file", default="report.xml", help="The result file to create." +) +argp.add_argument( + "--allow_flakes", + default=False, + action="store_const", + const=True, + help=( + "Allow flaky tests to show as passing (re-runs failed " + "tests up to five times)" + ), +) +argp.add_argument( + "--bq_result_table", + default="", + type=str, + nargs="?", + help="Upload test results to a specified BQ table.", +) # Requests will be routed through specified VIP by default. # See go/grpc-interop-tests (internal-only) for details. -argp.add_argument('--server_host', - default='74.125.206.210', - type=str, - nargs='?', - help='The gateway to backend services.') +argp.add_argument( + "--server_host", + default="74.125.206.210", + type=str, + nargs="?", + help="The gateway to backend services.", +) def _get_test_images_for_lang(lang, release_arg, image_path_prefix): """Find docker images for a language across releases and runtimes. - Returns dictionary of list of (, ) keyed by runtime. - """ - if release_arg == 'all': + Returns dictionary of list of (, ) keyed by runtime. + """ + if release_arg == "all": # Use all defined releases for given language releases = client_matrix.get_release_tags(lang) else: # Look for a particular release. 
if release_arg not in client_matrix.get_release_tags(lang): - jobset.message('SKIPPED', - 'release %s for %s is not defined' % - (release_arg, lang), - do_newline=True) + jobset.message( + "SKIPPED", + "release %s for %s is not defined" % (release_arg, lang), + do_newline=True, + ) return {} releases = [release_arg] @@ -113,8 +136,11 @@ def _get_test_images_for_lang(lang, release_arg, image_path_prefix): images = {} for tag in releases: for runtime in client_matrix.get_runtimes_for_lang_release(lang, tag): - image_name = '%s/grpc_interop_%s:%s' % (image_path_prefix, runtime, - tag) + image_name = "%s/grpc_interop_%s:%s" % ( + image_path_prefix, + runtime, + tag, + ) image_tuple = (tag, image_name) if runtime not in images: @@ -131,28 +157,29 @@ def _read_test_cases_file(lang, runtime, release): testcases_file = release_info.testcases_file if not testcases_file: # TODO(jtattermusch): remove the double-underscore, it is pointless - testcases_file = '%s__master' % lang + testcases_file = "%s__master" % lang # For csharp, the testcases file used depends on the runtime # TODO(jtattermusch): remove this odd specialcase - if lang == 'csharp' and runtime == 'csharpcoreclr': - testcases_file = testcases_file.replace('csharp_', 'csharpcoreclr_') + if lang == "csharp" and runtime == "csharpcoreclr": + testcases_file = testcases_file.replace("csharp_", "csharpcoreclr_") - testcases_filepath = os.path.join(os.path.dirname(__file__), 'testcases', - testcases_file) + testcases_filepath = os.path.join( + os.path.dirname(__file__), "testcases", testcases_file + ) lines = [] with open(testcases_filepath) as f: for line in f.readlines(): - line = re.sub('\\#.*$', '', line) # remove hash comments + line = re.sub("\\#.*$", "", line) # remove hash comments line = line.strip() - if line and not line.startswith('echo'): + if line and not line.startswith("echo"): # Each non-empty line is a treated as a test case command lines.append(line) return lines def _cleanup_docker_image(image): - jobset.message('START', 'Cleanup docker image %s' % image, do_newline=True) + jobset.message("START", "Cleanup docker image %s" % image, do_newline=True) dockerjob.remove_image(image, skip_nonexistent=True) @@ -171,35 +198,42 @@ def _generate_test_case_jobspecs(lang, runtime, release, suite_name): # what it currently being done seems fragile. 
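# A standalone sketch of the command-line parsing performed just below, written
# as comments and code so it can be read alongside the real implementation.
# The test case line and the substitute host here are illustrative assumptions,
# not values taken from an actual testcases file.
import re

_example_line = (
    "docker run -i --rm=true example_image /interop_client"
    " --server_host=test.sandbox.googleapis.com"
    " --server_host_override=test.sandbox.googleapis.com"
    " --test_case=large_unary"
)

# --test_case=<name> yields the short test case name used in the job shortname.
_m = re.search(r"--test_case=(\w+)", _example_line)
assert _m and _m.group(1) == "large_unary"

# The *.sandbox.googleapis.com host is captured twice: the full host name and
# the short prefix that also goes into the job shortname.
_m = re.search(r"--server_host_override=((.*).sandbox.googleapis.com)", _example_line)
assert _m and _m.group(1) == "test.sandbox.googleapis.com"
assert _m.group(2) == "test"

# Finally, the stored --server_host is rewritten to the gateway passed on this
# script's own command line (an illustrative address is used here).
_rewritten = re.sub(r"--server_host=[^ ]*", "--server_host=10.0.0.1", _example_line)
assert "--server_host=10.0.0.1" in _rewritten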
# Extract test case name from the command line - m = re.search(r'--test_case=(\w+)', line) - testcase_name = m.group(1) if m else 'unknown_test' + m = re.search(r"--test_case=(\w+)", line) + testcase_name = m.group(1) if m else "unknown_test" # Extract the server name from the command line - if '--server_host_override=' in line: + if "--server_host_override=" in line: m = re.search( - r'--server_host_override=((.*).sandbox.googleapis.com)', line) + r"--server_host_override=((.*).sandbox.googleapis.com)", line + ) else: - m = re.search(r'--server_host=((.*).sandbox.googleapis.com)', line) - server = m.group(1) if m else 'unknown_server' - server_short = m.group(2) if m else 'unknown_server' + m = re.search(r"--server_host=((.*).sandbox.googleapis.com)", line) + server = m.group(1) if m else "unknown_server" + server_short = m.group(2) if m else "unknown_server" # replace original server_host argument - assert '--server_host=' in line - line = re.sub(r'--server_host=[^ ]*', - r'--server_host=%s' % args.server_host, line) + assert "--server_host=" in line + line = re.sub( + r"--server_host=[^ ]*", r"--server_host=%s" % args.server_host, line + ) # some interop tests don't set server_host_override (see #17407), # but we need to use it if different host is set via cmdline args. - if args.server_host != server and not '--server_host_override=' in line: - line = re.sub(r'(--server_host=[^ ]*)', - r'\1 --server_host_override=%s' % server, line) - - spec = jobset.JobSpec(cmdline=line, - shortname='%s:%s:%s:%s' % - (suite_name, lang, server_short, testcase_name), - timeout_seconds=_TEST_TIMEOUT_SECONDS, - shell=True, - flake_retries=5 if args.allow_flakes else 0) + if args.server_host != server and not "--server_host_override=" in line: + line = re.sub( + r"(--server_host=[^ ]*)", + r"\1 --server_host_override=%s" % server, + line, + ) + + spec = jobset.JobSpec( + cmdline=line, + shortname="%s:%s:%s:%s" + % (suite_name, lang, server_short, testcase_name), + timeout_seconds=_TEST_TIMEOUT_SECONDS, + shell=True, + flake_retries=5 if args.allow_flakes else 0, + ) job_spec_list.append(spec) return job_spec_list @@ -207,51 +241,61 @@ def _generate_test_case_jobspecs(lang, runtime, release, suite_name): def _pull_image_for_lang(lang, image, release): """Pull an image for a given language form the image registry.""" cmdline = [ - 'time gcloud docker -- pull %s && time docker run --rm=true %s /bin/true' - % (image, image) + "time gcloud docker -- pull %s && time docker run --rm=true %s" + " /bin/true" % (image, image) ] - return jobset.JobSpec(cmdline=cmdline, - shortname='pull_image_{}'.format(image), - timeout_seconds=_PULL_IMAGE_TIMEOUT_SECONDS, - shell=True, - flake_retries=2) + return jobset.JobSpec( + cmdline=cmdline, + shortname="pull_image_{}".format(image), + timeout_seconds=_PULL_IMAGE_TIMEOUT_SECONDS, + shell=True, + flake_retries=2, + ) def _test_release(lang, runtime, release, image, xml_report_tree, skip_tests): total_num_failures = 0 - suite_name = '%s__%s_%s' % (lang, runtime, release) - job_spec_list = _generate_test_case_jobspecs(lang, runtime, release, - suite_name) + suite_name = "%s__%s_%s" % (lang, runtime, release) + job_spec_list = _generate_test_case_jobspecs( + lang, runtime, release, suite_name + ) if not job_spec_list: - jobset.message('FAILED', 'No test cases were found.', do_newline=True) + jobset.message("FAILED", "No test cases were found.", do_newline=True) total_num_failures += 1 else: - num_failures, resultset = jobset.run(job_spec_list, - newline_on_success=True, - 
add_env={'docker_image': image}, - maxjobs=args.jobs, - skip_jobs=skip_tests) + num_failures, resultset = jobset.run( + job_spec_list, + newline_on_success=True, + add_env={"docker_image": image}, + maxjobs=args.jobs, + skip_jobs=skip_tests, + ) if args.bq_result_table and resultset: upload_test_results.upload_interop_results_to_bq( - resultset, args.bq_result_table) + resultset, args.bq_result_table + ) if skip_tests: - jobset.message('FAILED', 'Tests were skipped', do_newline=True) + jobset.message("FAILED", "Tests were skipped", do_newline=True) total_num_failures += 1 if num_failures: total_num_failures += num_failures - report_utils.append_junit_xml_results(xml_report_tree, resultset, - 'grpc_interop_matrix', suite_name, - str(uuid.uuid4())) + report_utils.append_junit_xml_results( + xml_report_tree, + resultset, + "grpc_interop_matrix", + suite_name, + str(uuid.uuid4()), + ) return total_num_failures def _run_tests_for_lang(lang, runtime, images, xml_report_tree): """Find and run all test cases for a language. - images is a list of (, ) tuple. - """ + images is a list of (, ) tuple. + """ skip_tests = False total_num_failures = 0 @@ -270,43 +314,45 @@ def _run_tests_for_lang(lang, runtime, images, xml_report_tree): # NOTE(rbellevi): We batch docker pull operations to maximize # parallelism, without letting the disk usage grow unbounded. - pull_failures, _ = jobset.run(pull_specs, - newline_on_success=True, - maxjobs=max_pull_jobs) + pull_failures, _ = jobset.run( + pull_specs, newline_on_success=True, maxjobs=max_pull_jobs + ) if pull_failures: jobset.message( - 'FAILED', - 'Image download failed. Skipping tests for language "%s"' % - lang, - do_newline=True) + "FAILED", + 'Image download failed. Skipping tests for language "%s"' + % lang, + do_newline=True, + ) skip_tests = True for release, image in images[chunk_start:chunk_end]: - total_num_failures += _test_release(lang, runtime, release, image, - xml_report_tree, skip_tests) + total_num_failures += _test_release( + lang, runtime, release, image, xml_report_tree, skip_tests + ) if not args.keep: for _, image in images[chunk_start:chunk_end]: _cleanup_docker_image(image) if not total_num_failures: - jobset.message('SUCCESS', - 'All {} tests passed'.format(lang), - do_newline=True) + jobset.message( + "SUCCESS", "All {} tests passed".format(lang), do_newline=True + ) else: - jobset.message('FAILED', - 'Some {} tests failed'.format(lang), - do_newline=True) + jobset.message( + "FAILED", "Some {} tests failed".format(lang), do_newline=True + ) return total_num_failures -languages = args.language if args.language != ['all'] else _LANGUAGES +languages = args.language if args.language != ["all"] else _LANGUAGES total_num_failures = 0 _xml_report_tree = report_utils.new_junit_xml_tree() for lang in languages: docker_images = _get_test_images_for_lang(lang, args.release, args.gcr_path) for runtime in sorted(docker_images.keys()): - total_num_failures += _run_tests_for_lang(lang, runtime, - docker_images[runtime], - _xml_report_tree) + total_num_failures += _run_tests_for_lang( + lang, runtime, docker_images[runtime], _xml_report_tree + ) report_utils.create_xml_report_file(_xml_report_tree, args.report_file) diff --git a/tools/mkowners/mkowners.py b/tools/mkowners/mkowners.py index ef09a999e3cb4..041609b4018e6 100755 --- a/tools/mkowners/mkowners.py +++ b/tools/mkowners/mkowners.py @@ -24,21 +24,26 @@ # Find the root of the git tree # -git_root = (subprocess.check_output(['git', 'rev-parse', '--show-toplevel' - ]).decode('utf-8').strip()) 
+git_root = ( + subprocess.check_output(["git", "rev-parse", "--show-toplevel"]) + .decode("utf-8") + .strip() +) # # Parse command line arguments # -default_out = os.path.join(git_root, '.github', 'CODEOWNERS') +default_out = os.path.join(git_root, ".github", "CODEOWNERS") -argp = argparse.ArgumentParser('Generate .github/CODEOWNERS file') -argp.add_argument('--out', - '-o', - type=str, - default=default_out, - help='Output file (default %s)' % default_out) +argp = argparse.ArgumentParser("Generate .github/CODEOWNERS file") +argp.add_argument( + "--out", + "-o", + type=str, + default=default_out, + help="Output file (default %s)" % default_out, +) args = argp.parse_args() # @@ -46,17 +51,17 @@ # owners_files = [ - os.path.join(root, 'OWNERS') + os.path.join(root, "OWNERS") for root, dirs, files in os.walk(git_root) - if 'OWNERS' in files + if "OWNERS" in files ] # # Parse owners files # -Owners = collections.namedtuple('Owners', 'parent directives dir') -Directive = collections.namedtuple('Directive', 'who globs') +Owners = collections.namedtuple("Owners", "parent directives dir") +Directive = collections.namedtuple("Directive", "who globs") def parse_owners(filename): @@ -69,29 +74,33 @@ def parse_owners(filename): # line := directive | comment if not line: continue - if line[0] == '#': + if line[0] == "#": continue # it's a directive directive = None - if line == 'set noparent': + if line == "set noparent": parent = False - elif line == '*': - directive = Directive(who='*', globs=[]) - elif ' ' in line: - (who, globs) = line.split(' ', 1) - globs_list = [glob for glob in globs.split(' ') if glob] + elif line == "*": + directive = Directive(who="*", globs=[]) + elif " " in line: + (who, globs) = line.split(" ", 1) + globs_list = [glob for glob in globs.split(" ") if glob] directive = Directive(who=who, globs=globs_list) else: directive = Directive(who=line, globs=[]) if directive: directives.append(directive) - return Owners(parent=parent, - directives=directives, - dir=os.path.relpath(os.path.dirname(filename), git_root)) + return Owners( + parent=parent, + directives=directives, + dir=os.path.relpath(os.path.dirname(filename), git_root), + ) -owners_data = sorted([parse_owners(filename) for filename in owners_files], - key=operator.attrgetter('dir')) +owners_data = sorted( + [parse_owners(filename) for filename in owners_files], + key=operator.attrgetter("dir"), +) # # Modify owners so that parented OWNERS files point to the actual @@ -109,7 +118,7 @@ def parse_owners(filename): rel = os.path.relpath(owners.dir, possible_parent.dir) # '..' ==> we had to walk up from possible_parent to get to owners # ==> not a parent - if '..' in rel: + if ".." in rel: continue depth = len(rel.split(os.sep)) if not best_parent or depth < best_parent_score: @@ -129,7 +138,7 @@ def parse_owners(filename): def full_dir(rules_dir, sub_path): - return os.path.join(rules_dir, sub_path) if rules_dir != '.' else sub_path + return os.path.join(rules_dir, sub_path) if rules_dir != "." 
else sub_path # glob using git @@ -141,9 +150,13 @@ def git_glob(glob): if glob in gg_cache: return gg_cache[glob] r = set( - subprocess.check_output([ - 'git', 'ls-files', os.path.join(git_root, glob) - ]).decode('utf-8').strip().splitlines()) + subprocess.check_output( + ["git", "ls-files", os.path.join(git_root, glob)] + ) + .decode("utf-8") + .strip() + .splitlines() + ) gg_cache[glob] = r return r @@ -152,15 +165,17 @@ def expand_directives(root, directives): globs = collections.OrderedDict() # build a table of glob --> owners for directive in directives: - for glob in directive.globs or ['**']: + for glob in directive.globs or ["**"]: if glob not in globs: globs[glob] = [] if directive.who not in globs[glob]: globs[glob].append(directive.who) # expand owners for intersecting globs - sorted_globs = sorted(list(globs.keys()), - key=lambda g: len(git_glob(full_dir(root, g))), - reverse=True) + sorted_globs = sorted( + list(globs.keys()), + key=lambda g: len(git_glob(full_dir(root, g))), + reverse=True, + ) out_globs = collections.OrderedDict() for glob_add in sorted_globs: who_add = globs[glob_add] @@ -193,8 +208,9 @@ def add_parent_to_globs(parent, globs, globs_dir): intersect = files_parent.intersection(files_child) gglob_who_orig = gglob_who.copy() if intersect: - for f in sorted(files_child - ): # sorted to ensure merge stability + for f in sorted( + files_child + ): # sorted to ensure merge stability if f not in intersect: who = gglob_who_orig.copy() globs[os.path.relpath(f, start=globs_dir)] = who @@ -203,15 +219,15 @@ def add_parent_to_globs(parent, globs, globs_dir): gglob_who.append(who) add_parent_to_globs(owners.parent, globs, globs_dir) return - assert (False) + assert False todo = owners_data.copy() done = set() -with open(args.out, 'w') as out: - out.write('# Auto-generated by the tools/mkowners/mkowners.py tool\n') - out.write('# Uses OWNERS files in different modules throughout the\n') - out.write('# repository as the source of truth for module ownership.\n') +with open(args.out, "w") as out: + out.write("# Auto-generated by the tools/mkowners/mkowners.py tool\n") + out.write("# Uses OWNERS files in different modules throughout the\n") + out.write("# repository as the source of truth for module ownership.\n") written_globs = [] while todo: head, *todo = todo @@ -235,7 +251,8 @@ def add_parent_to_globs(parent, globs, globs_dir): # affected differently by this rule and CODEOWNERS is order dependent break if not skip: - out.write('/%s %s\n' % - (full_dir(head.dir, glob), ' '.join(owners))) + out.write( + "/%s %s\n" % (full_dir(head.dir, glob), " ".join(owners)) + ) written_globs.append((glob, owners, head.dir)) done.add(head.dir) diff --git a/tools/profiling/bloat/bloat_diff.py b/tools/profiling/bloat/bloat_diff.py index dc4b96650e8a0..5ab47e40e8e26 100755 --- a/tools/profiling/bloat/bloat_diff.py +++ b/tools/profiling/bloat/bloat_diff.py @@ -26,39 +26,49 @@ import sys sys.path.append( - os.path.join(os.path.dirname(sys.argv[0]), '..', '..', 'run_tests', - 'python_utils')) + os.path.join( + os.path.dirname(sys.argv[0]), "..", "..", "run_tests", "python_utils" + ) +) import check_on_pr -argp = argparse.ArgumentParser(description='Perform diff on microbenchmarks') +argp = argparse.ArgumentParser(description="Perform diff on microbenchmarks") -argp.add_argument('-d', - '--diff_base', - type=str, - help='Commit or branch to compare the current one to') +argp.add_argument( + "-d", + "--diff_base", + type=str, + help="Commit or branch to compare the current one to", +) 
-argp.add_argument('-j', '--jobs', type=int, default=multiprocessing.cpu_count()) +argp.add_argument("-j", "--jobs", type=int, default=multiprocessing.cpu_count()) args = argp.parse_args() # the libraries for which check bloat difference is calculated LIBS = [ - 'libgrpc.so', - 'libgrpc++.so', + "libgrpc.so", + "libgrpc++.so", ] def _build(output_dir): """Perform the cmake build under the output_dir.""" shutil.rmtree(output_dir, ignore_errors=True) - subprocess.check_call('mkdir -p %s' % output_dir, shell=True, cwd='.') - subprocess.check_call([ - 'cmake', '-DgRPC_BUILD_TESTS=OFF', '-DBUILD_SHARED_LIBS=ON', - '-DCMAKE_BUILD_TYPE=RelWithDebInfo', '-DCMAKE_C_FLAGS="-gsplit-dwarf"', - '-DCMAKE_CXX_FLAGS="-gsplit-dwarf"', '..' - ], - cwd=output_dir) - subprocess.check_call('make -j%d' % args.jobs, shell=True, cwd=output_dir) + subprocess.check_call("mkdir -p %s" % output_dir, shell=True, cwd=".") + subprocess.check_call( + [ + "cmake", + "-DgRPC_BUILD_TESTS=OFF", + "-DBUILD_SHARED_LIBS=ON", + "-DCMAKE_BUILD_TYPE=RelWithDebInfo", + '-DCMAKE_C_FLAGS="-gsplit-dwarf"', + '-DCMAKE_CXX_FLAGS="-gsplit-dwarf"', + "..", + ], + cwd=output_dir, + ) + subprocess.check_call("make -j%d" % args.jobs, shell=True, cwd=output_dir) def _rank_diff_bytes(diff_bytes): @@ -76,54 +86,72 @@ def _rank_diff_bytes(diff_bytes): return 3 * mul -_build('bloat_diff_new') +_build("bloat_diff_new") if args.diff_base: - where_am_i = subprocess.check_output( - ['git', 'rev-parse', '--abbrev-ref', 'HEAD']).decode().strip() + where_am_i = ( + subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"]) + .decode() + .strip() + ) # checkout the diff base (="old") - subprocess.check_call(['git', 'checkout', args.diff_base]) - subprocess.check_call(['git', 'submodule', 'update']) + subprocess.check_call(["git", "checkout", args.diff_base]) + subprocess.check_call(["git", "submodule", "update"]) try: - _build('bloat_diff_old') + _build("bloat_diff_old") finally: # restore the original revision (="new") - subprocess.check_call(['git', 'checkout', where_am_i]) - subprocess.check_call(['git', 'submodule', 'update']) + subprocess.check_call(["git", "checkout", where_am_i]) + subprocess.check_call(["git", "submodule", "update"]) -pathlib.Path('bloaty-build').mkdir(exist_ok=True) +pathlib.Path("bloaty-build").mkdir(exist_ok=True) subprocess.check_call( - ['cmake', '-G', 'Unix Makefiles', '../third_party/bloaty'], - cwd='bloaty-build') -subprocess.check_call('make -j%d' % args.jobs, shell=True, cwd='bloaty-build') + ["cmake", "-G", "Unix Makefiles", "../third_party/bloaty"], + cwd="bloaty-build", +) +subprocess.check_call("make -j%d" % args.jobs, shell=True, cwd="bloaty-build") -text = '' +text = "" diff_size = 0 for lib in LIBS: - text += '****************************************************************\n\n' - text += lib + '\n\n' - old_version = glob.glob('bloat_diff_old/%s' % lib) - new_version = glob.glob('bloat_diff_new/%s' % lib) + text += ( + "****************************************************************\n\n" + ) + text += lib + "\n\n" + old_version = glob.glob("bloat_diff_old/%s" % lib) + new_version = glob.glob("bloat_diff_new/%s" % lib) for filename in [old_version, new_version]: if filename: - subprocess.check_call('strip %s -o %s.stripped' % - (filename[0], filename[0]), - shell=True) + subprocess.check_call( + "strip %s -o %s.stripped" % (filename[0], filename[0]), + shell=True, + ) assert len(new_version) == 1 - cmd = 'bloaty-build/bloaty -d compileunits,symbols' + cmd = "bloaty-build/bloaty -d compileunits,symbols" 
if old_version: assert len(old_version) == 1 text += subprocess.check_output( - '%s -n 0 --debug-file=%s --debug-file=%s %s.stripped -- %s.stripped' - % (cmd, new_version[0], old_version[0], new_version[0], - old_version[0]), - shell=True).decode() + "%s -n 0 --debug-file=%s --debug-file=%s %s.stripped -- %s.stripped" + % ( + cmd, + new_version[0], + old_version[0], + new_version[0], + old_version[0], + ), + shell=True, + ).decode() sections = [ - x for x in csv.reader( + x + for x in csv.reader( subprocess.check_output( - 'bloaty-build/bloaty -n 0 --csv %s -- %s' % - (new_version[0], old_version[0]), - shell=True).decode().splitlines()) + "bloaty-build/bloaty -n 0 --csv %s -- %s" + % (new_version[0], old_version[0]), + shell=True, + ) + .decode() + .splitlines() + ) ] print(sections) for section in sections[1:]: @@ -135,14 +163,16 @@ def _rank_diff_bytes(diff_bytes): continue diff_size += int(section[2]) else: - text += subprocess.check_output('%s %s.stripped -n 0 --debug-file=%s' % - (cmd, new_version[0], new_version[0]), - shell=True).decode() - text += '\n\n' + text += subprocess.check_output( + "%s %s.stripped -n 0 --debug-file=%s" + % (cmd, new_version[0], new_version[0]), + shell=True, + ).decode() + text += "\n\n" severity = _rank_diff_bytes(diff_size) print("SEVERITY: %d" % severity) print(text) -check_on_pr.check_on_pr('Bloat Difference', '```\n%s\n```' % text) -check_on_pr.label_significance_on_pr('bloat', severity) +check_on_pr.check_on_pr("Bloat Difference", "```\n%s\n```" % text) +check_on_pr.label_significance_on_pr("bloat", severity) diff --git a/tools/profiling/ios_bin/binary_size.py b/tools/profiling/ios_bin/binary_size.py index 5082fbe89647f..ee849fab8150a 100755 --- a/tools/profiling/ios_bin/binary_size.py +++ b/tools/profiling/ios_bin/binary_size.py @@ -25,22 +25,27 @@ from parse_link_map import parse_link_map sys.path.append( - os.path.join(os.path.dirname(sys.argv[0]), '..', '..', 'run_tests', - 'python_utils')) + os.path.join( + os.path.dirname(sys.argv[0]), "..", "..", "run_tests", "python_utils" + ) +) import check_on_pr # Only show diff 1KB or greater _DIFF_THRESHOLD = 1000 -_SIZE_LABELS = ('Core', 'ObjC', 'BoringSSL', 'Protobuf', 'Total') +_SIZE_LABELS = ("Core", "ObjC", "BoringSSL", "Protobuf", "Total") argp = argparse.ArgumentParser( - description='Binary size diff of gRPC Objective-C sample') + description="Binary size diff of gRPC Objective-C sample" +) -argp.add_argument('-d', - '--diff_base', - type=str, - help='Commit or branch to compare the current one to') +argp.add_argument( + "-d", + "--diff_base", + type=str, + help="Commit or branch to compare the current one to", +) args = argp.parse_args() @@ -55,103 +60,117 @@ def dir_size(dir): def get_size(where): - build_dir = 'src/objective-c/examples/Sample/Build/Build-%s/' % where - link_map_filename = 'Build/Intermediates.noindex/Sample.build/Release-iphoneos/Sample.build/Sample-LinkMap-normal-arm64.txt' + build_dir = "src/objective-c/examples/Sample/Build/Build-%s/" % where + link_map_filename = "Build/Intermediates.noindex/Sample.build/Release-iphoneos/Sample.build/Sample-LinkMap-normal-arm64.txt" # IMPORTANT: order needs to match labels in _SIZE_LABELS return parse_link_map(build_dir + link_map_filename) def build(where): - subprocess.check_call(['make', 'clean']) - shutil.rmtree('src/objective-c/examples/Sample/Build/Build-%s' % where, - ignore_errors=True) + subprocess.check_call(["make", "clean"]) + shutil.rmtree( + "src/objective-c/examples/Sample/Build/Build-%s" % where, + ignore_errors=True, + ) 
subprocess.check_call( - 'CONFIG=opt EXAMPLE_PATH=src/objective-c/examples/Sample SCHEME=Sample ./build_one_example.sh', + ( + "CONFIG=opt EXAMPLE_PATH=src/objective-c/examples/Sample" + " SCHEME=Sample ./build_one_example.sh" + ), shell=True, - cwd='src/objective-c/tests') - os.rename('src/objective-c/examples/Sample/Build/Build', - 'src/objective-c/examples/Sample/Build/Build-%s' % where) + cwd="src/objective-c/tests", + ) + os.rename( + "src/objective-c/examples/Sample/Build/Build", + "src/objective-c/examples/Sample/Build/Build-%s" % where, + ) def _render_row(new, label, old): """Render row in 3-column output format.""" try: - formatted_new = '{:,}'.format(int(new)) + formatted_new = "{:,}".format(int(new)) except: formatted_new = new try: - formatted_old = '{:,}'.format(int(old)) + formatted_old = "{:,}".format(int(old)) except: formatted_old = old - return '{:>15}{:>15}{:>15}\n'.format(formatted_new, label, formatted_old) + return "{:>15}{:>15}{:>15}\n".format(formatted_new, label, formatted_old) def _diff_sign(new, old, diff_threshold=None): """Generate diff sign based on values""" - diff_sign = ' ' - if diff_threshold is not None and abs(new_size[i] - - old_size[i]) >= diff_threshold: - diff_sign += '!' + diff_sign = " " + if ( + diff_threshold is not None + and abs(new_size[i] - old_size[i]) >= diff_threshold + ): + diff_sign += "!" if new > old: - diff_sign += '(>)' + diff_sign += "(>)" elif new < old: - diff_sign += '(<)' + diff_sign += "(<)" else: - diff_sign += '(=)' + diff_sign += "(=)" return diff_sign -text = 'Objective-C binary sizes\n' +text = "Objective-C binary sizes\n" -build('new') -new_size = get_size('new') +build("new") +new_size = get_size("new") old_size = None if args.diff_base: - old = 'old' - where_am_i = subprocess.check_output( - ['git', 'rev-parse', '--abbrev-ref', 'HEAD']).decode().strip() - subprocess.check_call(['git', 'checkout', '--', '.']) - subprocess.check_call(['git', 'checkout', args.diff_base]) - subprocess.check_call(['git', 'submodule', 'update', '--force']) + old = "old" + where_am_i = ( + subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"]) + .decode() + .strip() + ) + subprocess.check_call(["git", "checkout", "--", "."]) + subprocess.check_call(["git", "checkout", args.diff_base]) + subprocess.check_call(["git", "submodule", "update", "--force"]) try: - build('old') - old_size = get_size('old') + build("old") + old_size = get_size("old") finally: - subprocess.check_call(['git', 'checkout', '--', '.']) - subprocess.check_call(['git', 'checkout', where_am_i]) - subprocess.check_call(['git', 'submodule', 'update', '--force']) + subprocess.check_call(["git", "checkout", "--", "."]) + subprocess.check_call(["git", "checkout", where_am_i]) + subprocess.check_call(["git", "submodule", "update", "--force"]) -text += '**********************STATIC******************\n' -text += _render_row('New size', '', 'Old size') +text += "**********************STATIC******************\n" +text += _render_row("New size", "", "Old size") if old_size == None: for i in range(0, len(_SIZE_LABELS)): if i == len(_SIZE_LABELS) - 1: # skip line before rendering "Total" - text += '\n' - text += _render_row(new_size[i], _SIZE_LABELS[i], '') + text += "\n" + text += _render_row(new_size[i], _SIZE_LABELS[i], "") else: has_diff = False # go through all labels but "Total" for i in range(0, len(_SIZE_LABELS) - 1): if abs(new_size[i] - old_size[i]) >= _DIFF_THRESHOLD: has_diff = True - diff_sign = _diff_sign(new_size[i], - old_size[i], - 
diff_threshold=_DIFF_THRESHOLD) - text += _render_row(new_size[i], _SIZE_LABELS[i] + diff_sign, - old_size[i]) + diff_sign = _diff_sign( + new_size[i], old_size[i], diff_threshold=_DIFF_THRESHOLD + ) + text += _render_row( + new_size[i], _SIZE_LABELS[i] + diff_sign, old_size[i] + ) # render the "Total" i = len(_SIZE_LABELS) - 1 diff_sign = _diff_sign(new_size[i], old_size[i]) # skip line before rendering "Total" - text += '\n' + text += "\n" text += _render_row(new_size[i], _SIZE_LABELS[i] + diff_sign, old_size[i]) if not has_diff: - text += '\n No significant differences in binary sizes\n' -text += '\n' + text += "\n No significant differences in binary sizes\n" +text += "\n" print(text) -check_on_pr.check_on_pr('ObjC Binary Size', '```\n%s\n```' % text) +check_on_pr.check_on_pr("ObjC Binary Size", "```\n%s\n```" % text) diff --git a/tools/profiling/ios_bin/parse_link_map.py b/tools/profiling/ios_bin/parse_link_map.py index 1b0917a0693c2..36f095fac9bba 100755 --- a/tools/profiling/ios_bin/parse_link_map.py +++ b/tools/profiling/ios_bin/parse_link_map.py @@ -36,7 +36,7 @@ def parse_link_map(filename): objc_size = 0 protobuf_size = 0 - lines = open(filename, encoding='utf-8', errors='ignore').readlines() + lines = open(filename, encoding="utf-8", errors="ignore").readlines() for line in lines: line_stripped = line[:-1] if "# Object files:" == line_stripped: @@ -53,53 +53,66 @@ def parse_link_map(filename): continue if state == "object": - segs = re.search('(\[ *[0-9]*\]) (.*)', line_stripped) + segs = re.search("(\[ *[0-9]*\]) (.*)", line_stripped) table_tag[segs.group(1)] = segs.group(2) if state == "section": - if len(line_stripped) == 0 or line_stripped[0] == '#': + if len(line_stripped) == 0 or line_stripped[0] == "#": continue - segs = re.search('^(.+?)\s+(.+?)\s+.*', line_stripped) + segs = re.search("^(.+?)\s+(.+?)\s+.*", line_stripped) section_total_size += int(segs.group(2), 16) if state == "symbol": - if len(line_stripped) == 0 or line_stripped[0] == '#': + if len(line_stripped) == 0 or line_stripped[0] == "#": continue - segs = re.search('^.+?\s+(.+?)\s+(\[.+?\]).*', line_stripped) + segs = re.search("^.+?\s+(.+?)\s+(\[.+?\]).*", line_stripped) if not segs: continue target = table_tag[segs.group(2)] - target_stripped = re.search('^(.*?)(\(.+?\))?$', target).group(1) + target_stripped = re.search("^(.*?)(\(.+?\))?$", target).group(1) size = int(segs.group(1), 16) if not target_stripped in table_stats_symbol: table_stats_symbol[target_stripped] = 0 table_stats_symbol[target_stripped] += size - if 'BoringSSL' in target_stripped: + if "BoringSSL" in target_stripped: boringssl_size += size - elif 'libgRPC-Core' in target_stripped: + elif "libgRPC-Core" in target_stripped: core_size += size - elif 'libgRPC-RxLibrary' in target_stripped or \ - 'libgRPC' in target_stripped or \ - 'libgRPC-ProtoLibrary' in target_stripped: + elif ( + "libgRPC-RxLibrary" in target_stripped + or "libgRPC" in target_stripped + or "libgRPC-ProtoLibrary" in target_stripped + ): objc_size += size - elif 'libProtobuf' in target_stripped: + elif "libProtobuf" in target_stripped: protobuf_size += size for target in table_stats_symbol: symbol_total_size += table_stats_symbol[target] - return core_size, objc_size, boringssl_size, protobuf_size, symbol_total_size + return ( + core_size, + objc_size, + boringssl_size, + protobuf_size, + symbol_total_size, + ) def main(): filename = sys.argv[1] - core_size, objc_size, boringssl_size, protobuf_size, total_size = parse_link_map( - filename) - print(('Core 
size:{:,}'.format(core_size))) - print(('ObjC size:{:,}'.format(objc_size))) - print(('BoringSSL size:{:,}'.format(boringssl_size))) - print(('Protobuf size:{:,}\n'.format(protobuf_size))) - print(('Total size:{:,}'.format(total_size))) + ( + core_size, + objc_size, + boringssl_size, + protobuf_size, + total_size, + ) = parse_link_map(filename) + print("Core size:{:,}".format(core_size)) + print("ObjC size:{:,}".format(objc_size)) + print("BoringSSL size:{:,}".format(boringssl_size)) + print("Protobuf size:{:,}\n".format(protobuf_size)) + print("Total size:{:,}".format(total_size)) if __name__ == "__main__": diff --git a/tools/profiling/memory/memory_diff.py b/tools/profiling/memory/memory_diff.py index 2f0473f24d731..eb2af958e33bc 100755 --- a/tools/profiling/memory/memory_diff.py +++ b/tools/profiling/memory/memory_diff.py @@ -27,67 +27,88 @@ import sys sys.path.append( - os.path.join(os.path.dirname(sys.argv[0]), '..', '..', 'run_tests', - 'python_utils')) + os.path.join( + os.path.dirname(sys.argv[0]), "..", "..", "run_tests", "python_utils" + ) +) import check_on_pr -argp = argparse.ArgumentParser(description='Perform diff on memory benchmarks') +argp = argparse.ArgumentParser(description="Perform diff on memory benchmarks") -argp.add_argument('-d', - '--diff_base', - type=str, - help='Commit or branch to compare the current one to') +argp.add_argument( + "-d", + "--diff_base", + type=str, + help="Commit or branch to compare the current one to", +) -argp.add_argument('-j', '--jobs', type=int, default=multiprocessing.cpu_count()) +argp.add_argument("-j", "--jobs", type=int, default=multiprocessing.cpu_count()) args = argp.parse_args() _INTERESTING = { - 'call/client': - (rb'client call memory usage: ([0-9\.]+) bytes per call', float), - 'call/server': - (rb'server call memory usage: ([0-9\.]+) bytes per call', float), - 'channel/client': - (rb'client channel memory usage: ([0-9\.]+) bytes per channel', float), - 'channel/server': - (rb'server channel memory usage: ([0-9\.]+) bytes per channel', float), + "call/client": ( + rb"client call memory usage: ([0-9\.]+) bytes per call", + float, + ), + "call/server": ( + rb"server call memory usage: ([0-9\.]+) bytes per call", + float, + ), + "channel/client": ( + rb"client channel memory usage: ([0-9\.]+) bytes per channel", + float, + ), + "channel/server": ( + rb"server channel memory usage: ([0-9\.]+) bytes per channel", + float, + ), } _SCENARIOS = { - 'default': [], - 'minstack': ['--scenario_config=minstack'], + "default": [], + "minstack": ["--scenario_config=minstack"], } _BENCHMARKS = { - 'call': ['--benchmark_names=call', '--size=50000'], - 'channel': ['--benchmark_names=channel', '--size=10000'], + "call": ["--benchmark_names=call", "--size=50000"], + "channel": ["--benchmark_names=channel", "--size=10000"], } def _run(): """Build with Bazel, then run, and extract interesting lines from the output.""" - subprocess.check_call([ - 'tools/bazel', 'build', '-c', 'opt', - 'test/core/memory_usage/memory_usage_test' - ]) + subprocess.check_call( + [ + "tools/bazel", + "build", + "-c", + "opt", + "test/core/memory_usage/memory_usage_test", + ] + ) ret = {} for name, benchmark_args in _BENCHMARKS.items(): for scenario, extra_args in _SCENARIOS.items(): - #TODO(chenancy) Remove when minstack is implemented for channel - if name == 'channel' and scenario == 'minstack': + # TODO(chenancy) Remove when minstack is implemented for channel + if name == "channel" and scenario == "minstack": continue try: - output = subprocess.check_output([ - 
'bazel-bin/test/core/memory_usage/memory_usage_test', - ] + benchmark_args + extra_args) + output = subprocess.check_output( + [ + "bazel-bin/test/core/memory_usage/memory_usage_test", + ] + + benchmark_args + + extra_args + ) except subprocess.CalledProcessError as e: - print('Error running benchmark:', e) + print("Error running benchmark:", e) continue for line in output.splitlines(): for key, (pattern, conversion) in _INTERESTING.items(): m = re.match(pattern, line) if m: - ret[scenario + ': ' + key] = conversion(m.group(1)) + ret[scenario + ": " + key] = conversion(m.group(1)) return ret @@ -95,45 +116,50 @@ def _run(): old = None if args.diff_base: - where_am_i = subprocess.check_output( - ['git', 'rev-parse', '--abbrev-ref', 'HEAD']).decode().strip() + where_am_i = ( + subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"]) + .decode() + .strip() + ) # checkout the diff base (="old") - subprocess.check_call(['git', 'checkout', args.diff_base]) + subprocess.check_call(["git", "checkout", args.diff_base]) try: old = _run() finally: # restore the original revision (="cur") - subprocess.check_call(['git', 'checkout', where_am_i]) + subprocess.check_call(["git", "checkout", where_am_i]) -text = '' +text = "" if old is None: print(cur) for key, value in sorted(cur.items()): - text += '{}: {}\n'.format(key, value) + text += "{}: {}\n".format(key, value) else: print(cur, old) call_diff_size = 0 channel_diff_size = 0 for scenario in _SCENARIOS.keys(): for key, value in sorted(_INTERESTING.items()): - key = scenario + ': ' + key + key = scenario + ": " + key if key in cur: if key not in old: - text += '{}: {}\n'.format(key, cur[key]) + text += "{}: {}\n".format(key, cur[key]) else: - text += '{}: {} -> {}\n'.format(key, old[key], cur[key]) - if 'call' in key: + text += "{}: {} -> {}\n".format(key, old[key], cur[key]) + if "call" in key: call_diff_size += cur[key] - old[key] else: channel_diff_size += cur[key] - old[key] print("CALL_DIFF_SIZE: %f" % call_diff_size) print("CHANNEL_DIFF_SIZE: %f" % channel_diff_size) - check_on_pr.label_increase_decrease_on_pr('per-call-memory', call_diff_size, - 64) - check_on_pr.label_increase_decrease_on_pr('per-channel-memory', - channel_diff_size, 1000) - #TODO(chennancy)Change significant value when minstack also runs for channel + check_on_pr.label_increase_decrease_on_pr( + "per-call-memory", call_diff_size, 64 + ) + check_on_pr.label_increase_decrease_on_pr( + "per-channel-memory", channel_diff_size, 1000 + ) + # TODO(chennancy)Change significant value when minstack also runs for channel print(text) -check_on_pr.check_on_pr('Memory Difference', '```\n%s\n```' % text) +check_on_pr.check_on_pr("Memory Difference", "```\n%s\n```" % text) diff --git a/tools/profiling/microbenchmarks/bm2bq.py b/tools/profiling/microbenchmarks/bm2bq.py index 705f1babfaeb6..dfc11aac82783 100755 --- a/tools/profiling/microbenchmarks/bm2bq.py +++ b/tools/profiling/microbenchmarks/bm2bq.py @@ -27,24 +27,25 @@ columns = [] for row in json.loads( - # TODO(jtattermusch): make sure the dataset name is not hardcoded - subprocess.check_output( - ['bq', '--format=json', 'show', - 'microbenchmarks.microbenchmarks']))['schema']['fields']: - columns.append((row['name'], row['type'].lower())) + # TODO(jtattermusch): make sure the dataset name is not hardcoded + subprocess.check_output( + ["bq", "--format=json", "show", "microbenchmarks.microbenchmarks"] + ) +)["schema"]["fields"]: + columns.append((row["name"], row["type"].lower())) SANITIZE = { - 'integer': int, - 'float': 
float, - 'boolean': bool, - 'string': str, - 'timestamp': str, + "integer": int, + "float": float, + "boolean": bool, + "string": str, + "timestamp": str, } # TODO(jtattermusch): add proper argparse argument, rather than trying # to emulate with manual argv inspection. -if sys.argv[1] == '--schema': - print(',\n'.join('%s:%s' % (k, t.upper()) for k, t in columns)) +if sys.argv[1] == "--schema": + print(",\n".join("%s:%s" % (k, t.upper()) for k, t in columns)) sys.exit(0) with open(sys.argv[1]) as f: @@ -63,7 +64,7 @@ sane_row = {} for name, sql_type in columns: if name in row: - if row[name] == '': + if row[name] == "": continue sane_row[name] = SANITIZE[sql_type](row[name]) writer.writerow(sane_row) diff --git a/tools/profiling/microbenchmarks/bm_diff/bm_build.py b/tools/profiling/microbenchmarks/bm_diff/bm_build.py index 33c2a57146ccc..2cc2a8479fbe6 100755 --- a/tools/profiling/microbenchmarks/bm_diff/bm_build.py +++ b/tools/profiling/microbenchmarks/bm_diff/bm_build.py @@ -25,26 +25,33 @@ def _args(): - argp = argparse.ArgumentParser(description='Builds microbenchmarks') - argp.add_argument('-b', - '--benchmarks', - nargs='+', - choices=bm_constants._AVAILABLE_BENCHMARK_TESTS, - default=bm_constants._AVAILABLE_BENCHMARK_TESTS, - help='Which benchmarks to build') + argp = argparse.ArgumentParser(description="Builds microbenchmarks") argp.add_argument( - '-j', - '--jobs', + "-b", + "--benchmarks", + nargs="+", + choices=bm_constants._AVAILABLE_BENCHMARK_TESTS, + default=bm_constants._AVAILABLE_BENCHMARK_TESTS, + help="Which benchmarks to build", + ) + argp.add_argument( + "-j", + "--jobs", type=int, default=multiprocessing.cpu_count(), - help= - 'Deprecated. Bazel chooses number of CPUs to build with automatically.') + help=( + "Deprecated. Bazel chooses number of CPUs to build with" + " automatically." + ), + ) argp.add_argument( - '-n', - '--name', + "-n", + "--name", type=str, - help= - 'Unique name of this build. To be used as a handle to pass to the other bm* scripts' + help=( + "Unique name of this build. 
To be used as a handle to pass to the" + " other bm* scripts" + ), ) args = argp.parse_args() assert args.name @@ -53,31 +60,39 @@ def _args(): def _build_cmd(cfg, benchmarks): bazel_targets = [ - '//test/cpp/microbenchmarks:%s' % benchmark for benchmark in benchmarks + "//test/cpp/microbenchmarks:%s" % benchmark for benchmark in benchmarks ] # --dynamic_mode=off makes sure that we get a monolithic binary that can be safely # moved outside of the bazel-bin directory - return ['tools/bazel', 'build', - '--config=%s' % cfg, '--dynamic_mode=off'] + bazel_targets + return [ + "tools/bazel", + "build", + "--config=%s" % cfg, + "--dynamic_mode=off", + ] + bazel_targets def _build_config_and_copy(cfg, benchmarks, dest_dir): """Build given config and copy resulting binaries to dest_dir/CONFIG""" subprocess.check_call(_build_cmd(cfg, benchmarks)) - cfg_dir = dest_dir + '/%s' % cfg + cfg_dir = dest_dir + "/%s" % cfg os.makedirs(cfg_dir) - subprocess.check_call(['cp'] + [ - 'bazel-bin/test/cpp/microbenchmarks/%s' % benchmark - for benchmark in benchmarks - ] + [cfg_dir]) + subprocess.check_call( + ["cp"] + + [ + "bazel-bin/test/cpp/microbenchmarks/%s" % benchmark + for benchmark in benchmarks + ] + + [cfg_dir] + ) def build(name, benchmarks, jobs): - dest_dir = 'bm_diff_%s' % name + dest_dir = "bm_diff_%s" % name shutil.rmtree(dest_dir, ignore_errors=True) - _build_config_and_copy('opt', benchmarks, dest_dir) + _build_config_and_copy("opt", benchmarks, dest_dir) -if __name__ == '__main__': +if __name__ == "__main__": args = _args() build(args.name, args.benchmarks, args.jobs) diff --git a/tools/profiling/microbenchmarks/bm_diff/bm_constants.py b/tools/profiling/microbenchmarks/bm_diff/bm_constants.py index 99b4b2413bc8b..a8a1246bc6057 100644 --- a/tools/profiling/microbenchmarks/bm_diff/bm_constants.py +++ b/tools/profiling/microbenchmarks/bm_diff/bm_constants.py @@ -16,22 +16,29 @@ """ Configurable constants for the bm_*.py family """ _AVAILABLE_BENCHMARK_TESTS = [ - 'bm_fullstack_unary_ping_pong', - 'bm_fullstack_streaming_ping_pong', - 'bm_fullstack_streaming_pump', - 'bm_closure', - 'bm_cq', - 'bm_call_create', - 'bm_chttp2_hpack', - 'bm_chttp2_transport', - 'bm_pollset', + "bm_fullstack_unary_ping_pong", + "bm_fullstack_streaming_ping_pong", + "bm_fullstack_streaming_pump", + "bm_closure", + "bm_cq", + "bm_call_create", + "bm_chttp2_hpack", + "bm_chttp2_transport", + "bm_pollset", ] -_INTERESTING = ('cpu_time', 'real_time', 'locks_per_iteration', - 'allocs_per_iteration', 'writes_per_iteration', - 'atm_cas_per_iteration', 'atm_add_per_iteration', - 'nows_per_iteration', 'cli_transport_stalls_per_iteration', - 'cli_stream_stalls_per_iteration', - 'svr_transport_stalls_per_iteration', - 'svr_stream_stalls_per_iteration', - 'http2_pings_sent_per_iteration') +_INTERESTING = ( + "cpu_time", + "real_time", + "locks_per_iteration", + "allocs_per_iteration", + "writes_per_iteration", + "atm_cas_per_iteration", + "atm_add_per_iteration", + "nows_per_iteration", + "cli_transport_stalls_per_iteration", + "cli_stream_stalls_per_iteration", + "svr_transport_stalls_per_iteration", + "svr_stream_stalls_per_iteration", + "http2_pings_sent_per_iteration", +) diff --git a/tools/profiling/microbenchmarks/bm_diff/bm_diff.py b/tools/profiling/microbenchmarks/bm_diff/bm_diff.py index 647547a0c6d10..c34b915491efe 100755 --- a/tools/profiling/microbenchmarks/bm_diff/bm_diff.py +++ b/tools/profiling/microbenchmarks/bm_diff/bm_diff.py @@ -22,7 +22,7 @@ import subprocess import sys 
-sys.path.append(os.path.join(os.path.dirname(sys.argv[0]), '..')) +sys.path.append(os.path.join(os.path.dirname(sys.argv[0]), "..")) import bm_constants import bm_json @@ -33,7 +33,7 @@ def _median(ary): - assert (len(ary)) + assert len(ary) ary = sorted(ary) n = len(ary) if n % 2 == 0: @@ -44,38 +44,46 @@ def _median(ary): def _args(): argp = argparse.ArgumentParser( - description='Perform diff on microbenchmarks') - argp.add_argument('-t', - '--track', - choices=sorted(bm_constants._INTERESTING), - nargs='+', - default=sorted(bm_constants._INTERESTING), - help='Which metrics to track') - argp.add_argument('-b', - '--benchmarks', - nargs='+', - choices=bm_constants._AVAILABLE_BENCHMARK_TESTS, - default=bm_constants._AVAILABLE_BENCHMARK_TESTS, - help='Which benchmarks to run') + description="Perform diff on microbenchmarks" + ) + argp.add_argument( + "-t", + "--track", + choices=sorted(bm_constants._INTERESTING), + nargs="+", + default=sorted(bm_constants._INTERESTING), + help="Which metrics to track", + ) argp.add_argument( - '-l', - '--loops', + "-b", + "--benchmarks", + nargs="+", + choices=bm_constants._AVAILABLE_BENCHMARK_TESTS, + default=bm_constants._AVAILABLE_BENCHMARK_TESTS, + help="Which benchmarks to run", + ) + argp.add_argument( + "-l", + "--loops", type=int, default=20, - help= - 'Number of times to loops the benchmarks. Must match what was passed to bm_run.py' + help=( + "Number of times to loops the benchmarks. Must match what was" + " passed to bm_run.py" + ), + ) + argp.add_argument( + "-r", + "--regex", + type=str, + default="", + help="Regex to filter benchmarks run", + ) + argp.add_argument("-n", "--new", type=str, help="New benchmark name") + argp.add_argument("-o", "--old", type=str, help="Old benchmark name") + argp.add_argument( + "-v", "--verbose", type=bool, help="Print details of before/after" ) - argp.add_argument('-r', - '--regex', - type=str, - default="", - help='Regex to filter benchmarks run') - argp.add_argument('-n', '--new', type=str, help='New benchmark name') - argp.add_argument('-o', '--old', type=str, help='Old benchmark name') - argp.add_argument('-v', - '--verbose', - type=bool, - help='Print details of before/after') args = argp.parse_args() global verbose if args.verbose: @@ -91,11 +99,10 @@ def _maybe_print(str): class Benchmark: - def __init__(self): self.samples = { True: collections.defaultdict(list), - False: collections.defaultdict(list) + False: collections.defaultdict(list), } self.final = {} self.speedup = {} @@ -112,20 +119,22 @@ def process(self, track, new_name, old_name): if not new or not old: continue mdn_diff = abs(_median(new) - _median(old)) - _maybe_print('%s: %s=%r %s=%r mdn_diff=%r' % - (f, new_name, new, old_name, old, mdn_diff)) + _maybe_print( + "%s: %s=%r %s=%r mdn_diff=%r" + % (f, new_name, new, old_name, old, mdn_diff) + ) s = bm_speedup.speedup(new, old, 1e-5) self.speedup[f] = s if abs(s) > 3: if mdn_diff > 0.5: - self.final[f] = '%+d%%' % s + self.final[f] = "%+d%%" % s return self.final.keys() def skip(self): return not self.final def row(self, flds): - return [self.final[f] if f in self.final else '' for f in flds] + return [self.final[f] if f in self.final else "" for f in flds] def speedup(self, name): if name in self.speedup: @@ -155,7 +164,7 @@ def _read_json(filename, badjson_files, nonexistant_files): def fmt_dict(d): - return ''.join([" " + k + ": " + str(d[k]) + "\n" for k in d]) + return "".join([" " + k + ": " + str(d[k]) + "\n" for k in d]) def diff(bms, loops, regex, track, old, new): @@ -165,29 
+174,41 @@ def diff(bms, loops, regex, track, old, new): nonexistant_files = {} for bm in bms: for loop in range(0, loops): - for line in subprocess.check_output([ - 'bm_diff_%s/opt/%s' % (old, bm), '--benchmark_list_tests', - '--benchmark_filter=%s' % regex - ]).splitlines(): - line = line.decode('UTF-8') - stripped_line = line.strip().replace("/", "_").replace( - "<", "_").replace(">", "_").replace(", ", "_") + for line in subprocess.check_output( + [ + "bm_diff_%s/opt/%s" % (old, bm), + "--benchmark_list_tests", + "--benchmark_filter=%s" % regex, + ] + ).splitlines(): + line = line.decode("UTF-8") + stripped_line = ( + line.strip() + .replace("/", "_") + .replace("<", "_") + .replace(">", "_") + .replace(", ", "_") + ) js_new_opt = _read_json( - '%s.%s.opt.%s.%d.json' % (bm, stripped_line, new, loop), - badjson_files, nonexistant_files) + "%s.%s.opt.%s.%d.json" % (bm, stripped_line, new, loop), + badjson_files, + nonexistant_files, + ) js_old_opt = _read_json( - '%s.%s.opt.%s.%d.json' % (bm, stripped_line, old, loop), - badjson_files, nonexistant_files) + "%s.%s.opt.%s.%d.json" % (bm, stripped_line, old, loop), + badjson_files, + nonexistant_files, + ) if js_new_opt: for row in bm_json.expand_json(js_new_opt): - name = row['cpp_name'] - if name.endswith('_mean') or name.endswith('_stddev'): + name = row["cpp_name"] + if name.endswith("_mean") or name.endswith("_stddev"): continue benchmarks[name].add_sample(track, row, True) if js_old_opt: for row in bm_json.expand_json(js_old_opt): - name = row['cpp_name'] - if name.endswith('_mean') or name.endswith('_stddev'): + name = row["cpp_name"] + if name.endswith("_mean") or name.endswith("_stddev"): continue benchmarks[name].add_sample(track, row, False) @@ -203,11 +224,12 @@ def diff(bms, loops, regex, track, old, new): _NOISY = ["BM_WellFlushed"] for name, bm in benchmarks.items(): if name in _NOISY: - print("skipping noisy benchmark '%s' for labelling evaluation" % - name) + print( + "skipping noisy benchmark '%s' for labelling evaluation" % name + ) if bm.skip(): continue - d = bm.speedup['cpu_time'] + d = bm.speedup["cpu_time"] if d is None: continue histogram.append(d) @@ -231,7 +253,7 @@ def diff(bms, loops, regex, track, old, new): significance = 3 significance *= mul - headers = ['Benchmark'] + fields + headers = ["Benchmark"] + fields rows = [] for name in sorted(benchmarks.keys()): if benchmarks[name].skip(): @@ -239,24 +261,40 @@ def diff(bms, loops, regex, track, old, new): rows.append([name] + benchmarks[name].row(fields)) note = None if len(badjson_files): - note = 'Corrupt JSON data (indicates timeout or crash): \n%s' % fmt_dict( - badjson_files) + note = ( + "Corrupt JSON data (indicates timeout or crash): \n%s" + % fmt_dict(badjson_files) + ) if len(nonexistant_files): if note: - note += '\n\nMissing files (indicates new benchmark): \n%s' % fmt_dict( - nonexistant_files) + note += ( + "\n\nMissing files (indicates new benchmark): \n%s" + % fmt_dict(nonexistant_files) + ) else: - note = '\n\nMissing files (indicates new benchmark): \n%s' % fmt_dict( - nonexistant_files) + note = ( + "\n\nMissing files (indicates new benchmark): \n%s" + % fmt_dict(nonexistant_files) + ) if rows: - return tabulate.tabulate(rows, headers=headers, - floatfmt='+.2f'), note, significance + return ( + tabulate.tabulate(rows, headers=headers, floatfmt="+.2f"), + note, + significance, + ) else: return None, note, 0 -if __name__ == '__main__': +if __name__ == "__main__": args = _args() - diff, note = diff(args.benchmarks, args.loops, args.regex, 
args.track, - args.old, args.new, args.counters) - print('%s\n%s' % (note, diff if diff else "No performance differences")) + diff, note = diff( + args.benchmarks, + args.loops, + args.regex, + args.track, + args.old, + args.new, + args.counters, + ) + print("%s\n%s" % (note, diff if diff else "No performance differences")) diff --git a/tools/profiling/microbenchmarks/bm_diff/bm_main.py b/tools/profiling/microbenchmarks/bm_diff/bm_main.py index 26a04faf659e6..429eca97b5de1 100755 --- a/tools/profiling/microbenchmarks/bm_diff/bm_main.py +++ b/tools/profiling/microbenchmarks/bm_diff/bm_main.py @@ -23,12 +23,21 @@ import sys sys.path.append( - os.path.join(os.path.dirname(sys.argv[0]), '..', '..', 'run_tests', - 'python_utils')) + os.path.join( + os.path.dirname(sys.argv[0]), "..", "..", "run_tests", "python_utils" + ) +) sys.path.append( - os.path.join(os.path.dirname(sys.argv[0]), '..', '..', '..', 'run_tests', - 'python_utils')) + os.path.join( + os.path.dirname(sys.argv[0]), + "..", + "..", + "..", + "run_tests", + "python_utils", + ) +) import bm_build import bm_constants @@ -40,51 +49,67 @@ def _args(): argp = argparse.ArgumentParser( - description='Perform diff on microbenchmarks') - argp.add_argument('-t', - '--track', - choices=sorted(bm_constants._INTERESTING), - nargs='+', - default=sorted(bm_constants._INTERESTING), - help='Which metrics to track') - argp.add_argument('-b', - '--benchmarks', - nargs='+', - choices=bm_constants._AVAILABLE_BENCHMARK_TESTS, - default=bm_constants._AVAILABLE_BENCHMARK_TESTS, - help='Which benchmarks to run') - argp.add_argument('-d', - '--diff_base', - type=str, - help='Commit or branch to compare the current one to') + description="Perform diff on microbenchmarks" + ) + argp.add_argument( + "-t", + "--track", + choices=sorted(bm_constants._INTERESTING), + nargs="+", + default=sorted(bm_constants._INTERESTING), + help="Which metrics to track", + ) + argp.add_argument( + "-b", + "--benchmarks", + nargs="+", + choices=bm_constants._AVAILABLE_BENCHMARK_TESTS, + default=bm_constants._AVAILABLE_BENCHMARK_TESTS, + help="Which benchmarks to run", + ) + argp.add_argument( + "-d", + "--diff_base", + type=str, + help="Commit or branch to compare the current one to", + ) + argp.add_argument( + "-o", + "--old", + default="old", + type=str, + help='Name of baseline run to compare to. Usually just called "old"', + ) argp.add_argument( - '-o', - '--old', - default='old', + "-r", + "--regex", type=str, - help='Name of baseline run to compare to. Usually just called "old"') - argp.add_argument('-r', - '--regex', - type=str, - default="", - help='Regex to filter benchmarks run') + default="", + help="Regex to filter benchmarks run", + ) argp.add_argument( - '-l', - '--loops', + "-l", + "--loops", type=int, default=10, - help= - 'Number of times to loops the benchmarks. More loops cuts down on noise' + help=( + "Number of times to loops the benchmarks. 
More loops cuts down on" + " noise" + ), + ) + argp.add_argument( + "-j", + "--jobs", + type=int, + default=multiprocessing.cpu_count(), + help="Number of CPUs to use", + ) + argp.add_argument( + "--pr_comment_name", + type=str, + default="microbenchmarks", + help="Name that Jenkins will use to comment on the PR", ) - argp.add_argument('-j', - '--jobs', - type=int, - default=multiprocessing.cpu_count(), - help='Number of CPUs to use') - argp.add_argument('--pr_comment_name', - type=str, - default="microbenchmarks", - help='Name that Jenkins will use to comment on the PR') args = argp.parse_args() assert args.diff_base or args.old, "One of diff_base or old must be set!" if args.loops < 3: @@ -107,44 +132,51 @@ def inner(*args): def main(args): - - bm_build.build('new', args.benchmarks, args.jobs) + bm_build.build("new", args.benchmarks, args.jobs) old = args.old if args.diff_base: - old = 'old' + old = "old" where_am_i = subprocess.check_output( - ['git', 'rev-parse', '--abbrev-ref', 'HEAD']).strip() - subprocess.check_call(['git', 'checkout', args.diff_base]) + ["git", "rev-parse", "--abbrev-ref", "HEAD"] + ).strip() + subprocess.check_call(["git", "checkout", args.diff_base]) try: bm_build.build(old, args.benchmarks, args.jobs) finally: - subprocess.check_call(['git', 'checkout', where_am_i]) - subprocess.check_call(['git', 'submodule', 'update']) + subprocess.check_call(["git", "checkout", where_am_i]) + subprocess.check_call(["git", "submodule", "update"]) jobs_list = [] - jobs_list += bm_run.create_jobs('new', args.benchmarks, args.loops, - args.regex) - jobs_list += bm_run.create_jobs(old, args.benchmarks, args.loops, - args.regex) + jobs_list += bm_run.create_jobs( + "new", args.benchmarks, args.loops, args.regex + ) + jobs_list += bm_run.create_jobs( + old, args.benchmarks, args.loops, args.regex + ) # shuffle all jobs to eliminate noise from GCE CPU drift random.shuffle(jobs_list, random.SystemRandom().random) jobset.run(jobs_list, maxjobs=args.jobs) - diff, note, significance = bm_diff.diff(args.benchmarks, args.loops, - args.regex, args.track, old, 'new') + diff, note, significance = bm_diff.diff( + args.benchmarks, args.loops, args.regex, args.track, old, "new" + ) if diff: - text = '[%s] Performance differences noted:\n%s' % ( - args.pr_comment_name, diff) + text = "[%s] Performance differences noted:\n%s" % ( + args.pr_comment_name, + diff, + ) else: - text = '[%s] No significant performance differences' % args.pr_comment_name + text = ( + "[%s] No significant performance differences" % args.pr_comment_name + ) if note: - text = note + '\n\n' + text - print('%s' % text) - check_on_pr.check_on_pr('Benchmark', '```\n%s\n```' % text) + text = note + "\n\n" + text + print("%s" % text) + check_on_pr.check_on_pr("Benchmark", "```\n%s\n```" % text) -if __name__ == '__main__': +if __name__ == "__main__": args = _args() main(args) diff --git a/tools/profiling/microbenchmarks/bm_diff/bm_run.py b/tools/profiling/microbenchmarks/bm_diff/bm_run.py index 85a8f5e5dd489..25b2772b95ae0 100755 --- a/tools/profiling/microbenchmarks/bm_diff/bm_run.py +++ b/tools/profiling/microbenchmarks/bm_diff/bm_run.py @@ -27,78 +27,107 @@ import jobset sys.path.append( - os.path.join(os.path.dirname(sys.argv[0]), '..', '..', '..', 'run_tests', - 'python_utils')) + os.path.join( + os.path.dirname(sys.argv[0]), + "..", + "..", + "..", + "run_tests", + "python_utils", + ) +) def _args(): - argp = argparse.ArgumentParser(description='Runs microbenchmarks') - argp.add_argument('-b', - '--benchmarks', - nargs='+', 
- choices=bm_constants._AVAILABLE_BENCHMARK_TESTS, - default=bm_constants._AVAILABLE_BENCHMARK_TESTS, - help='Benchmarks to run') - argp.add_argument('-j', - '--jobs', - type=int, - default=multiprocessing.cpu_count(), - help='Number of CPUs to use') + argp = argparse.ArgumentParser(description="Runs microbenchmarks") + argp.add_argument( + "-b", + "--benchmarks", + nargs="+", + choices=bm_constants._AVAILABLE_BENCHMARK_TESTS, + default=bm_constants._AVAILABLE_BENCHMARK_TESTS, + help="Benchmarks to run", + ) + argp.add_argument( + "-j", + "--jobs", + type=int, + default=multiprocessing.cpu_count(), + help="Number of CPUs to use", + ) argp.add_argument( - '-n', - '--name', + "-n", + "--name", type=str, - help= - 'Unique name of the build to run. Needs to match the handle passed to bm_build.py' + help=( + "Unique name of the build to run. Needs to match the handle passed" + " to bm_build.py" + ), ) - argp.add_argument('-r', - '--regex', - type=str, - default="", - help='Regex to filter benchmarks run') argp.add_argument( - '-l', - '--loops', + "-r", + "--regex", + type=str, + default="", + help="Regex to filter benchmarks run", + ) + argp.add_argument( + "-l", + "--loops", type=int, default=20, - help= - 'Number of times to loops the benchmarks. More loops cuts down on noise' + help=( + "Number of times to loops the benchmarks. More loops cuts down on" + " noise" + ), ) - argp.add_argument('--counters', dest='counters', action='store_true') - argp.add_argument('--no-counters', dest='counters', action='store_false') + argp.add_argument("--counters", dest="counters", action="store_true") + argp.add_argument("--no-counters", dest="counters", action="store_false") argp.set_defaults(counters=True) args = argp.parse_args() assert args.name if args.loops < 3: - print("WARNING: This run will likely be noisy. Increase loops to at " - "least 3.") + print( + "WARNING: This run will likely be noisy. Increase loops to at " + "least 3." 
+ ) return args def _collect_bm_data(bm, cfg, name, regex, idx, loops): jobs_list = [] - for line in subprocess.check_output([ - 'bm_diff_%s/%s/%s' % (name, cfg, bm), '--benchmark_list_tests', - '--benchmark_filter=%s' % regex - ]).splitlines(): - line = line.decode('UTF-8') - stripped_line = line.strip().replace("/", - "_").replace("<", "_").replace( - ">", "_").replace(", ", "_") + for line in subprocess.check_output( + [ + "bm_diff_%s/%s/%s" % (name, cfg, bm), + "--benchmark_list_tests", + "--benchmark_filter=%s" % regex, + ] + ).splitlines(): + line = line.decode("UTF-8") + stripped_line = ( + line.strip() + .replace("/", "_") + .replace("<", "_") + .replace(">", "_") + .replace(", ", "_") + ) cmd = [ - 'bm_diff_%s/%s/%s' % (name, cfg, bm), - '--benchmark_filter=^%s$' % line, - '--benchmark_out=%s.%s.%s.%s.%d.json' % - (bm, stripped_line, cfg, name, idx), - '--benchmark_out_format=json', + "bm_diff_%s/%s/%s" % (name, cfg, bm), + "--benchmark_filter=^%s$" % line, + "--benchmark_out=%s.%s.%s.%s.%d.json" + % (bm, stripped_line, cfg, name, idx), + "--benchmark_out_format=json", ] jobs_list.append( - jobset.JobSpec(cmd, - shortname='%s %s %s %s %d/%d' % - (bm, line, cfg, name, idx + 1, loops), - verbose_success=True, - cpu_cost=2, - timeout_seconds=60 * 60)) # one hour + jobset.JobSpec( + cmd, + shortname="%s %s %s %s %d/%d" + % (bm, line, cfg, name, idx + 1, loops), + verbose_success=True, + cpu_cost=2, + timeout_seconds=60 * 60, + ) + ) # one hour return jobs_list @@ -106,13 +135,14 @@ def create_jobs(name, benchmarks, loops, regex): jobs_list = [] for loop in range(0, loops): for bm in benchmarks: - jobs_list += _collect_bm_data(bm, 'opt', name, regex, loop, loops) + jobs_list += _collect_bm_data(bm, "opt", name, regex, loop, loops) random.shuffle(jobs_list, random.SystemRandom().random) return jobs_list -if __name__ == '__main__': +if __name__ == "__main__": args = _args() - jobs_list = create_jobs(args.name, args.benchmarks, args.loops, args.regex, - args.counters) + jobs_list = create_jobs( + args.name, args.benchmarks, args.loops, args.regex, args.counters + ) jobset.run(jobs_list, maxjobs=args.jobs) diff --git a/tools/profiling/microbenchmarks/bm_json.py b/tools/profiling/microbenchmarks/bm_json.py index cb466f231380e..50bd47591a8a1 100644 --- a/tools/profiling/microbenchmarks/bm_json.py +++ b/tools/profiling/microbenchmarks/bm_json.py @@ -20,179 +20,182 @@ # template arguments and dynamic arguments of individual benchmark types # Example benchmark name: "BM_UnaryPingPong/0/0" _BM_SPECS = { - 'BM_UnaryPingPong': { - 'tpl': ['fixture', 'client_mutator', 'server_mutator'], - 'dyn': ['request_size', 'response_size'], - }, - 'BM_PumpStreamClientToServer': { - 'tpl': ['fixture'], - 'dyn': ['request_size'], - }, - 'BM_PumpStreamServerToClient': { - 'tpl': ['fixture'], - 'dyn': ['request_size'], - }, - 'BM_StreamingPingPong': { - 'tpl': ['fixture', 'client_mutator', 'server_mutator'], - 'dyn': ['request_size', 'request_count'], - }, - 'BM_StreamingPingPongMsgs': { - 'tpl': ['fixture', 'client_mutator', 'server_mutator'], - 'dyn': ['request_size'], - }, - 'BM_PumpStreamServerToClient_Trickle': { - 'tpl': [], - 'dyn': ['request_size', 'bandwidth_kilobits'], - }, - 'BM_PumpUnbalancedUnary_Trickle': { - 'tpl': [], - 'dyn': ['cli_req_size', 'svr_req_size', 'bandwidth_kilobits'], - }, - 'BM_ErrorStringOnNewError': { - 'tpl': ['fixture'], - 'dyn': [], - }, - 'BM_ErrorStringRepeatedly': { - 'tpl': ['fixture'], - 'dyn': [], - }, - 'BM_ErrorGetStatus': { - 'tpl': ['fixture'], - 'dyn': [], - }, - 
'BM_ErrorGetStatusCode': { - 'tpl': ['fixture'], - 'dyn': [], - }, - 'BM_ErrorHttpError': { - 'tpl': ['fixture'], - 'dyn': [], - }, - 'BM_HasClearGrpcStatus': { - 'tpl': ['fixture'], - 'dyn': [], - }, - 'BM_IsolatedFilter': { - 'tpl': ['fixture', 'client_mutator'], - 'dyn': [], - }, - 'BM_HpackEncoderEncodeHeader': { - 'tpl': ['fixture'], - 'dyn': ['end_of_stream', 'request_size'], - }, - 'BM_HpackParserParseHeader': { - 'tpl': ['fixture'], - 'dyn': [], - }, - 'BM_CallCreateDestroy': { - 'tpl': ['fixture'], - 'dyn': [], - }, - 'BM_Zalloc': { - 'tpl': [], - 'dyn': ['request_size'], - }, - 'BM_PollEmptyPollset_SpeedOfLight': { - 'tpl': [], - 'dyn': ['request_size', 'request_count'], - }, - 'BM_StreamCreateSendInitialMetadataDestroy': { - 'tpl': ['fixture'], - 'dyn': [], - }, - 'BM_TransportStreamSend': { - 'tpl': [], - 'dyn': ['request_size'], - }, - 'BM_TransportStreamRecv': { - 'tpl': [], - 'dyn': ['request_size'], - }, - 'BM_StreamingPingPongWithCoalescingApi': { - 'tpl': ['fixture', 'client_mutator', 'server_mutator'], - 'dyn': ['request_size', 'request_count', 'end_of_stream'], - }, - 'BM_Base16SomeStuff': { - 'tpl': [], - 'dyn': ['request_size'], - } + "BM_UnaryPingPong": { + "tpl": ["fixture", "client_mutator", "server_mutator"], + "dyn": ["request_size", "response_size"], + }, + "BM_PumpStreamClientToServer": { + "tpl": ["fixture"], + "dyn": ["request_size"], + }, + "BM_PumpStreamServerToClient": { + "tpl": ["fixture"], + "dyn": ["request_size"], + }, + "BM_StreamingPingPong": { + "tpl": ["fixture", "client_mutator", "server_mutator"], + "dyn": ["request_size", "request_count"], + }, + "BM_StreamingPingPongMsgs": { + "tpl": ["fixture", "client_mutator", "server_mutator"], + "dyn": ["request_size"], + }, + "BM_PumpStreamServerToClient_Trickle": { + "tpl": [], + "dyn": ["request_size", "bandwidth_kilobits"], + }, + "BM_PumpUnbalancedUnary_Trickle": { + "tpl": [], + "dyn": ["cli_req_size", "svr_req_size", "bandwidth_kilobits"], + }, + "BM_ErrorStringOnNewError": { + "tpl": ["fixture"], + "dyn": [], + }, + "BM_ErrorStringRepeatedly": { + "tpl": ["fixture"], + "dyn": [], + }, + "BM_ErrorGetStatus": { + "tpl": ["fixture"], + "dyn": [], + }, + "BM_ErrorGetStatusCode": { + "tpl": ["fixture"], + "dyn": [], + }, + "BM_ErrorHttpError": { + "tpl": ["fixture"], + "dyn": [], + }, + "BM_HasClearGrpcStatus": { + "tpl": ["fixture"], + "dyn": [], + }, + "BM_IsolatedFilter": { + "tpl": ["fixture", "client_mutator"], + "dyn": [], + }, + "BM_HpackEncoderEncodeHeader": { + "tpl": ["fixture"], + "dyn": ["end_of_stream", "request_size"], + }, + "BM_HpackParserParseHeader": { + "tpl": ["fixture"], + "dyn": [], + }, + "BM_CallCreateDestroy": { + "tpl": ["fixture"], + "dyn": [], + }, + "BM_Zalloc": { + "tpl": [], + "dyn": ["request_size"], + }, + "BM_PollEmptyPollset_SpeedOfLight": { + "tpl": [], + "dyn": ["request_size", "request_count"], + }, + "BM_StreamCreateSendInitialMetadataDestroy": { + "tpl": ["fixture"], + "dyn": [], + }, + "BM_TransportStreamSend": { + "tpl": [], + "dyn": ["request_size"], + }, + "BM_TransportStreamRecv": { + "tpl": [], + "dyn": ["request_size"], + }, + "BM_StreamingPingPongWithCoalescingApi": { + "tpl": ["fixture", "client_mutator", "server_mutator"], + "dyn": ["request_size", "request_count", "end_of_stream"], + }, + "BM_Base16SomeStuff": { + "tpl": [], + "dyn": ["request_size"], + }, } def numericalize(s): """Convert abbreviations like '100M' or '10k' to a number.""" if not s: - return '' - if s[-1] == 'k': + return "" + if s[-1] == "k": return float(s[:-1]) * 1024 - if s[-1] == 
'M': + if s[-1] == "M": return float(s[:-1]) * 1024 * 1024 - if 0 <= (ord(s[-1]) - ord('0')) <= 9: + if 0 <= (ord(s[-1]) - ord("0")) <= 9: return float(s) - assert 'not a number: %s' % s + assert "not a number: %s" % s def parse_name(name): cpp_name = name - if '<' not in name and '/' not in name and name not in _BM_SPECS: - return {'name': name, 'cpp_name': name} + if "<" not in name and "/" not in name and name not in _BM_SPECS: + return {"name": name, "cpp_name": name} rest = name out = {} tpl_args = [] dyn_args = [] - if '<' in rest: - tpl_bit = rest[rest.find('<') + 1:rest.rfind('>')] - arg = '' + if "<" in rest: + tpl_bit = rest[rest.find("<") + 1 : rest.rfind(">")] + arg = "" nesting = 0 for c in tpl_bit: - if c == '<': + if c == "<": nesting += 1 arg += c - elif c == '>': + elif c == ">": nesting -= 1 arg += c - elif c == ',': + elif c == ",": if nesting == 0: tpl_args.append(arg.strip()) - arg = '' + arg = "" else: arg += c else: arg += c tpl_args.append(arg.strip()) - rest = rest[:rest.find('<')] + rest[rest.rfind('>') + 1:] - if '/' in rest: - s = rest.split('/') + rest = rest[: rest.find("<")] + rest[rest.rfind(">") + 1 :] + if "/" in rest: + s = rest.split("/") rest = s[0] dyn_args = s[1:] name = rest - assert name in _BM_SPECS, '_BM_SPECS needs to be expanded for %s' % name - assert len(dyn_args) == len(_BM_SPECS[name]['dyn']) - assert len(tpl_args) == len(_BM_SPECS[name]['tpl']) - out['name'] = name - out['cpp_name'] = cpp_name + assert name in _BM_SPECS, "_BM_SPECS needs to be expanded for %s" % name + assert len(dyn_args) == len(_BM_SPECS[name]["dyn"]) + assert len(tpl_args) == len(_BM_SPECS[name]["tpl"]) + out["name"] = name + out["cpp_name"] = cpp_name out.update( - dict((k, numericalize(v)) - for k, v in zip(_BM_SPECS[name]['dyn'], dyn_args))) - out.update(dict(zip(_BM_SPECS[name]['tpl'], tpl_args))) + dict( + (k, numericalize(v)) + for k, v in zip(_BM_SPECS[name]["dyn"], dyn_args) + ) + ) + out.update(dict(zip(_BM_SPECS[name]["tpl"], tpl_args))) return out def expand_json(js): if not js: raise StopIteration() - for bm in js['benchmarks']: - if bm['name'].endswith('_stddev') or bm['name'].endswith('_mean'): + for bm in js["benchmarks"]: + if bm["name"].endswith("_stddev") or bm["name"].endswith("_mean"): continue - context = js['context'] - if 'label' in bm: + context = js["context"] + if "label" in bm: labels_list = [ - s.split(':') - for s in bm['label'].strip().split(' ') - if len(s) and s[0] != '#' + s.split(":") + for s in bm["label"].strip().split(" ") + if len(s) and s[0] != "#" ] for el in labels_list: - el[0] = el[0].replace('/iter', '_per_iteration') + el[0] = el[0].replace("/iter", "_per_iteration") labels = dict(labels_list) else: labels = {} @@ -201,11 +204,11 @@ def expand_json(js): # Link the data to a kokoro job run by adding # well known kokoro env variables as metadata for each row row = { - 'jenkins_build': os.environ.get('KOKORO_BUILD_NUMBER', ''), - 'jenkins_job': os.environ.get('KOKORO_JOB_NAME', ''), + "jenkins_build": os.environ.get("KOKORO_BUILD_NUMBER", ""), + "jenkins_job": os.environ.get("KOKORO_JOB_NAME", ""), } row.update(context) row.update(bm) - row.update(parse_name(row['name'])) + row.update(parse_name(row["name"])) row.update(labels) yield row diff --git a/tools/profiling/qps/qps_diff.py b/tools/profiling/qps/qps_diff.py index 68f3310b1bdf9..77ee16b4a3a79 100755 --- a/tools/profiling/qps/qps_diff.py +++ b/tools/profiling/qps/qps_diff.py @@ -27,58 +27,71 @@ import tabulate sys.path.append( - os.path.join(os.path.dirname(sys.argv[0]), 
'..', 'microbenchmarks', - 'bm_diff')) + os.path.join( + os.path.dirname(sys.argv[0]), "..", "microbenchmarks", "bm_diff" + ) +) import bm_speedup sys.path.append( - os.path.join(os.path.dirname(sys.argv[0]), '..', '..', 'run_tests', - 'python_utils')) + os.path.join( + os.path.dirname(sys.argv[0]), "..", "..", "run_tests", "python_utils" + ) +) import check_on_pr def _args(): - argp = argparse.ArgumentParser(description='Perform diff on QPS Driver') - argp.add_argument('-d', - '--diff_base', - type=str, - help='Commit or branch to compare the current one to') + argp = argparse.ArgumentParser(description="Perform diff on QPS Driver") + argp.add_argument( + "-d", + "--diff_base", + type=str, + help="Commit or branch to compare the current one to", + ) argp.add_argument( - '-l', - '--loops', + "-l", + "--loops", type=int, default=4, - help='Number of loops for each benchmark. More loops cuts down on noise' + help=( + "Number of loops for each benchmark. More loops cuts down on noise" + ), + ) + argp.add_argument( + "-j", + "--jobs", + type=int, + default=multiprocessing.cpu_count(), + help="Number of CPUs to use", ) - argp.add_argument('-j', - '--jobs', - type=int, - default=multiprocessing.cpu_count(), - help='Number of CPUs to use') args = argp.parse_args() assert args.diff_base, "diff_base must be set" return args def _make_cmd(jobs): - return ['make', '-j', '%d' % jobs, 'qps_json_driver', 'qps_worker'] + return ["make", "-j", "%d" % jobs, "qps_json_driver", "qps_worker"] def build(name, jobs): - shutil.rmtree('qps_diff_%s' % name, ignore_errors=True) - subprocess.check_call(['git', 'submodule', 'update']) + shutil.rmtree("qps_diff_%s" % name, ignore_errors=True) + subprocess.check_call(["git", "submodule", "update"]) try: subprocess.check_call(_make_cmd(jobs)) except subprocess.CalledProcessError as e: - subprocess.check_call(['make', 'clean']) + subprocess.check_call(["make", "clean"]) subprocess.check_call(_make_cmd(jobs)) - os.rename('bins', 'qps_diff_%s' % name) + os.rename("bins", "qps_diff_%s" % name) def _run_cmd(name, scenario, fname): return [ - 'qps_diff_%s/opt/qps_json_driver' % name, '--scenarios_json', scenario, - '--json_file_out', fname + "qps_diff_%s/opt/qps_json_driver" % name, + "--scenarios_json", + scenario, + "--json_file_out", + fname, ] @@ -92,7 +105,7 @@ def run(name, scenarios, loops): def _load_qps(fname): try: with open(fname) as f: - return json.loads(f.read())['qps'] + return json.loads(f.read())["qps"] except IOError as e: print(("IOError occurred reading file: %s" % fname)) return None @@ -102,7 +115,7 @@ def _load_qps(fname): def _median(ary): - assert (len(ary)) + assert len(ary) ary = sorted(ary) n = len(ary) if n % 2 == 0: @@ -124,48 +137,55 @@ def diff(scenarios, loops, old, new): new_data[sn].append(_load_qps("%s.%s.%d.json" % (sn, new, i))) # crunch data - headers = ['Benchmark', 'qps'] + headers = ["Benchmark", "qps"] rows = [] for sn in scenarios: mdn_diff = abs(_median(new_data[sn]) - _median(old_data[sn])) - print(('%s: %s=%r %s=%r mdn_diff=%r' % - (sn, new, new_data[sn], old, old_data[sn], mdn_diff))) + print( + "%s: %s=%r %s=%r mdn_diff=%r" + % (sn, new, new_data[sn], old, old_data[sn], mdn_diff) + ) s = bm_speedup.speedup(new_data[sn], old_data[sn], 10e-5) if abs(s) > 3 and mdn_diff > 0.5: - rows.append([sn, '%+d%%' % s]) + rows.append([sn, "%+d%%" % s]) if rows: - return tabulate.tabulate(rows, headers=headers, floatfmt='+.2f') + return tabulate.tabulate(rows, headers=headers, floatfmt="+.2f") else: return None def main(args): - build('new', 
args.jobs) + build("new", args.jobs) if args.diff_base: - where_am_i = subprocess.check_output( - ['git', 'rev-parse', '--abbrev-ref', 'HEAD']).decode().strip() - subprocess.check_call(['git', 'checkout', args.diff_base]) + where_am_i = ( + subprocess.check_output( + ["git", "rev-parse", "--abbrev-ref", "HEAD"] + ) + .decode() + .strip() + ) + subprocess.check_call(["git", "checkout", args.diff_base]) try: - build('old', args.jobs) + build("old", args.jobs) finally: - subprocess.check_call(['git', 'checkout', where_am_i]) - subprocess.check_call(['git', 'submodule', 'update']) + subprocess.check_call(["git", "checkout", where_am_i]) + subprocess.check_call(["git", "submodule", "update"]) - run('new', qps_scenarios._SCENARIOS, args.loops) - run('old', qps_scenarios._SCENARIOS, args.loops) + run("new", qps_scenarios._SCENARIOS, args.loops) + run("old", qps_scenarios._SCENARIOS, args.loops) - diff_output = diff(qps_scenarios._SCENARIOS, args.loops, 'old', 'new') + diff_output = diff(qps_scenarios._SCENARIOS, args.loops, "old", "new") if diff_output: - text = '[qps] Performance differences noted:\n%s' % diff_output + text = "[qps] Performance differences noted:\n%s" % diff_output else: - text = '[qps] No significant performance differences' - print(('%s' % text)) - check_on_pr.check_on_pr('QPS', '```\n%s\n```' % text) + text = "[qps] No significant performance differences" + print(("%s" % text)) + check_on_pr.check_on_pr("QPS", "```\n%s\n```" % text) -if __name__ == '__main__': +if __name__ == "__main__": args = _args() main(args) diff --git a/tools/profiling/qps/qps_scenarios.py b/tools/profiling/qps/qps_scenarios.py index ae2552eda0632..57031d4741bea 100644 --- a/tools/profiling/qps/qps_scenarios.py +++ b/tools/profiling/qps/qps_scenarios.py @@ -14,8 +14,30 @@ """ QPS Scenarios to run """ _SCENARIOS = { - 'large-message-throughput': - '{"scenarios":[{"name":"large-message-throughput", "spawn_local_worker_count": -2, "warmup_seconds": 30, "benchmark_seconds": 270, "num_servers": 1, "server_config": {"async_server_threads": 1, "security_params": null, "server_type": "ASYNC_SERVER"}, "num_clients": 1, "client_config": {"client_type": "ASYNC_CLIENT", "security_params": null, "payload_config": {"simple_params": {"resp_size": 1048576, "req_size": 1048576}}, "client_channels": 1, "async_client_threads": 1, "outstanding_rpcs_per_channel": 1, "rpc_type": "UNARY", "load_params": {"closed_loop": {}}, "histogram_params": {"max_possible": 60000000000.0, "resolution": 0.01}}}]}', - 'multi-channel-64-KiB': - '{"scenarios":[{"name":"multi-channel-64-KiB", "spawn_local_worker_count": -3, "warmup_seconds": 30, "benchmark_seconds": 270, "num_servers": 1, "server_config": {"async_server_threads": 31, "security_params": null, "server_type": "ASYNC_SERVER"}, "num_clients": 2, "client_config": {"client_type": "ASYNC_CLIENT", "security_params": null, "payload_config": {"simple_params": {"resp_size": 65536, "req_size": 65536}}, "client_channels": 32, "async_client_threads": 31, "outstanding_rpcs_per_channel": 100, "rpc_type": "UNARY", "load_params": {"closed_loop": {}}, "histogram_params": {"max_possible": 60000000000.0, "resolution": 0.01}}}]}' + "large-message-throughput": ( + '{"scenarios":[{"name":"large-message-throughput",' + ' "spawn_local_worker_count": -2, "warmup_seconds": 30,' + ' "benchmark_seconds": 270, "num_servers": 1, "server_config":' + ' {"async_server_threads": 1, "security_params": null, "server_type":' + ' "ASYNC_SERVER"}, "num_clients": 1, "client_config": {"client_type":' + ' "ASYNC_CLIENT", 
"security_params": null, "payload_config":' + ' {"simple_params": {"resp_size": 1048576, "req_size": 1048576}},' + ' "client_channels": 1, "async_client_threads": 1,' + ' "outstanding_rpcs_per_channel": 1, "rpc_type": "UNARY",' + ' "load_params": {"closed_loop": {}}, "histogram_params":' + ' {"max_possible": 60000000000.0, "resolution": 0.01}}}]}' + ), + "multi-channel-64-KiB": ( + '{"scenarios":[{"name":"multi-channel-64-KiB",' + ' "spawn_local_worker_count": -3, "warmup_seconds": 30,' + ' "benchmark_seconds": 270, "num_servers": 1, "server_config":' + ' {"async_server_threads": 31, "security_params": null, "server_type":' + ' "ASYNC_SERVER"}, "num_clients": 2, "client_config": {"client_type":' + ' "ASYNC_CLIENT", "security_params": null, "payload_config":' + ' {"simple_params": {"resp_size": 65536, "req_size": 65536}},' + ' "client_channels": 32, "async_client_threads": 31,' + ' "outstanding_rpcs_per_channel": 100, "rpc_type": "UNARY",' + ' "load_params": {"closed_loop": {}}, "histogram_params":' + ' {"max_possible": 60000000000.0, "resolution": 0.01}}}]}' + ), } diff --git a/tools/release/release_notes.py b/tools/release/release_notes.py index d16a541685ec5..e931fee0d7b90 100644 --- a/tools/release/release_notes.py +++ b/tools/release/release_notes.py @@ -1,4 +1,4 @@ -#Copyright 2019 gRPC authors. +# Copyright 2019 gRPC authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -83,38 +83,44 @@ """ HTML_URL = "https://github.com/grpc/grpc/pull/" -API_URL = 'https://api.github.com/repos/grpc/grpc/pulls/' +API_URL = "https://api.github.com/repos/grpc/grpc/pulls/" def get_commit_log(prevRelLabel, relBranch): - """Return the output of 'git log prevRelLabel..relBranch' """ + """Return the output of 'git log prevRelLabel..relBranch'""" import subprocess + glg_command = [ - "git", "log", "--pretty=oneline", "--committer=GitHub", - "%s..%s" % (prevRelLabel, relBranch) + "git", + "log", + "--pretty=oneline", + "--committer=GitHub", + "%s..%s" % (prevRelLabel, relBranch), ] print(("Running ", " ".join(glg_command))) - return subprocess.check_output(glg_command).decode('utf-8', 'ignore') + return subprocess.check_output(glg_command).decode("utf-8", "ignore") def get_pr_data(pr_num): """Get the PR data from github. 
Return 'error' on exception""" - http = urllib3.PoolManager(retries=urllib3.Retry(total=7, backoff_factor=1), - timeout=4.0) + http = urllib3.PoolManager( + retries=urllib3.Retry(total=7, backoff_factor=1), timeout=4.0 + ) url = API_URL + pr_num try: - response = http.request('GET', - url, - headers={'Authorization': 'token %s' % TOKEN}) + response = http.request( + "GET", url, headers={"Authorization": "token %s" % TOKEN} + ) except urllib3.exceptions.HTTPError as e: - print('Request error:', e.reason) - return 'error' - return json.loads(response.data.decode('utf-8')) + print("Request error:", e.reason) + return "error" + return json.loads(response.data.decode("utf-8")) def get_pr_titles(gitLogs): import re + error_count = 0 # PRs with merge commits match_merge_pr = "Merge pull request #(\d+)" @@ -136,21 +142,22 @@ def get_pr_titles(gitLogs): pr = get_pr_data(pr_num) if pr == "error": print( - ("\n***ERROR*** Error in getting data for PR " + pr_num + "\n")) + ("\n***ERROR*** Error in getting data for PR " + pr_num + "\n") + ) error_count += 1 continue rl_no_found = False rl_yes_found = False lang_found = False - for label in pr['labels']: - if label['name'] == 'release notes: yes': + for label in pr["labels"]: + if label["name"] == "release notes: yes": rl_yes_found = True - elif label['name'] == 'release notes: no': + elif label["name"] == "release notes: no": rl_no_found = True - elif label['name'].startswith('lang/'): + elif label["name"].startswith("lang/"): lang_found = True - lang = label['name'].split('/')[1].lower() - #lang = lang[0].upper() + lang[1:] + lang = label["name"].split("/")[1].lower() + # lang = lang[0].upper() + lang[1:] body = pr["title"] if not body.endswith("."): body = body + "." @@ -159,10 +166,12 @@ def get_pr_titles(gitLogs): error_count += 1 continue - prline = "- " + body + " ([#" + pr_num + "](" + HTML_URL + pr_num + "))" + prline = ( + "- " + body + " ([#" + pr_num + "](" + HTML_URL + pr_num + "))" + ) detail = "- " + pr["merged_by"]["login"] + "@ " + prline print(detail) - #if no RL label + # if no RL label if not rl_no_found and not rl_yes_found: print(("Release notes label missing for " + pr_num)) langs_pr["nolabel"].append(detail) @@ -173,8 +182,14 @@ def get_pr_titles(gitLogs): print(("'Release notes:no' found for " + pr_num)) langs_pr["notinrel"].append(detail) elif rl_yes_found: - print(("'Release notes:yes' found for " + pr_num + " with lang " + - lang)) + print( + ( + "'Release notes:yes' found for " + + pr_num + + " with lang " + + lang + ) + ) langs_pr["inrel"].append(detail) langs_pr[lang].append(prline) @@ -204,7 +219,8 @@ def write_draft(langs_pr, file, version, date): file.write("\n") file.write("\n") file.write( - "PRs going into release notes - please check title and fix in Github. Do not edit here.\n" + "PRs going into release notes - please check title and fix in Github." + " Do not edit here.\n" ) file.write("---\n") file.write("\n") @@ -288,53 +304,69 @@ def write_rel_notes(langs_pr, file, version, name): def build_args_parser(): import argparse + parser = argparse.ArgumentParser() - parser.add_argument('release_version', - type=str, - help='New release version e.g. 1.14.0') - parser.add_argument('release_name', - type=str, - help='New release name e.g. gladiolus') - parser.add_argument('release_date', - type=str, - help='Release date e.g. 7/30/18') - parser.add_argument('previous_release_label', - type=str, - help='Previous release branch/tag e.g. 
v1.13.x') - parser.add_argument('release_branch', - type=str, - help='Current release branch e.g. origin/v1.14.x') - parser.add_argument('draft_filename', - type=str, - help='Name of the draft file e.g. draft.md') - parser.add_argument('release_notes_filename', - type=str, - help='Name of the release notes file e.g. relnotes.md') - parser.add_argument('--token', - type=str, - default='', - help='GitHub API token to avoid being rate limited') + parser.add_argument( + "release_version", type=str, help="New release version e.g. 1.14.0" + ) + parser.add_argument( + "release_name", type=str, help="New release name e.g. gladiolus" + ) + parser.add_argument( + "release_date", type=str, help="Release date e.g. 7/30/18" + ) + parser.add_argument( + "previous_release_label", + type=str, + help="Previous release branch/tag e.g. v1.13.x", + ) + parser.add_argument( + "release_branch", + type=str, + help="Current release branch e.g. origin/v1.14.x", + ) + parser.add_argument( + "draft_filename", type=str, help="Name of the draft file e.g. draft.md" + ) + parser.add_argument( + "release_notes_filename", + type=str, + help="Name of the release notes file e.g. relnotes.md", + ) + parser.add_argument( + "--token", + type=str, + default="", + help="GitHub API token to avoid being rate limited", + ) return parser def main(): import os + global TOKEN parser = build_args_parser() args = parser.parse_args() - version, name, date = args.release_version, args.release_name, args.release_date + version, name, date = ( + args.release_version, + args.release_name, + args.release_date, + ) start, end = args.previous_release_label, args.release_branch TOKEN = args.token - if TOKEN == '': + if TOKEN == "": try: TOKEN = os.environ["GITHUB_TOKEN"] except: pass - if TOKEN == '': + if TOKEN == "": print( - "Error: Github API token required. Either include param --token= or set environment variable GITHUB_TOKEN to your github token" + "Error: Github API token required. Either include param" + " --token= or set environment variable" + " GITHUB_TOKEN to your github token" ) return @@ -343,9 +375,9 @@ def main(): draft_file, rel_file = args.draft_filename, args.release_notes_filename filename = os.path.abspath(draft_file) if os.path.exists(filename): - file = open(filename, 'r+') + file = open(filename, "r+") else: - file = open(filename, 'w') + file = open(filename, "w") file.seek(0) write_draft(langs_pr, file, version, date) @@ -355,9 +387,9 @@ def main(): filename = os.path.abspath(rel_file) if os.path.exists(filename): - file = open(filename, 'r+') + file = open(filename, "r+") else: - file = open(filename, 'w') + file = open(filename, "w") file.seek(0) write_rel_notes(langs_pr, file, version, name) diff --git a/tools/release/verify_python_release.py b/tools/release/verify_python_release.py index 186e65bc2172a..b03512aae3944 100644 --- a/tools/release/verify_python_release.py +++ b/tools/release/verify_python_release.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 -#Copyright 2019 gRPC authors. +# Copyright 2019 gRPC authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -52,7 +52,7 @@ def _get_md5_checksum(filename): """Calculate the md5sum for a file.""" hash_md5 = hashlib.md5() - with open(filename, 'rb') as f: + with open(filename, "rb") as f: for chunk in iter(lambda: f.read(4096), b""): hash_md5.update(chunk) return hash_md5.hexdigest() @@ -61,7 +61,8 @@ def _get_md5_checksum(filename): def _get_local_artifacts(): """Get a set of artifacts representing all files in the cwd.""" return set( - Artifact(f, _get_md5_checksum(f)) for f in os.listdir(os.getcwd())) + Artifact(f, _get_md5_checksum(f)) for f in os.listdir(os.getcwd()) + ) def _get_remote_artifacts_for_package(package, version): @@ -71,13 +72,15 @@ def _get_remote_artifacts_for_package(package, version): experience, it has taken a minute on average to be fresh. """ artifacts = set() - payload_resp = requests.get("https://pypi.org/pypi/{}/{}/json".format( - package, version)) + payload_resp = requests.get( + "https://pypi.org/pypi/{}/{}/json".format(package, version) + ) payload_resp.raise_for_status() payload = payload_resp.json() - for download_info in payload['urls']: + for download_info in payload["urls"]: artifacts.add( - Artifact(download_info['filename'], download_info['md5_digest'])) + Artifact(download_info["filename"], download_info["md5_digest"]) + ) return artifacts @@ -111,11 +114,11 @@ def _verify_release(version, packages): parser = argparse.ArgumentParser( "Verify a release. Run this from a directory containing only the" "artifacts to be uploaded. Note that PyPI may take several minutes" - "after the upload to reflect the proper metadata.") + "after the upload to reflect the proper metadata." + ) parser.add_argument("version") - parser.add_argument("packages", - nargs='*', - type=str, - default=_DEFAULT_PACKAGES) + parser.add_argument( + "packages", nargs="*", type=str, default=_DEFAULT_PACKAGES + ) args = parser.parse_args() _verify_release(args.version, args.packages) diff --git a/tools/run_tests/artifacts/artifact_targets.py b/tools/run_tests/artifacts/artifact_targets.py index 15db34a2d5d65..9ec024622fedc 100644 --- a/tools/run_tests/artifacts/artifact_targets.py +++ b/tools/run_tests/artifacts/artifact_targets.py @@ -19,107 +19,116 @@ import string import sys -sys.path.insert(0, os.path.abspath('..')) +sys.path.insert(0, os.path.abspath("..")) import python_utils.jobset as jobset _LATEST_MANYLINUX = "manylinux2014" -def create_docker_jobspec(name, - dockerfile_dir, - shell_command, - environ={}, - flake_retries=0, - timeout_retries=0, - timeout_seconds=30 * 60, - extra_docker_args=None, - verbose_success=False): +def create_docker_jobspec( + name, + dockerfile_dir, + shell_command, + environ={}, + flake_retries=0, + timeout_retries=0, + timeout_seconds=30 * 60, + extra_docker_args=None, + verbose_success=False, +): """Creates jobspec for a task running under docker.""" environ = environ.copy() - environ['ARTIFACTS_OUT'] = 'artifacts/%s' % name + environ["ARTIFACTS_OUT"] = "artifacts/%s" % name docker_args = [] for k, v in list(environ.items()): - docker_args += ['-e', '%s=%s' % (k, v)] + docker_args += ["-e", "%s=%s" % (k, v)] docker_env = { - 'DOCKERFILE_DIR': dockerfile_dir, - 'DOCKER_RUN_SCRIPT': 'tools/run_tests/dockerize/docker_run.sh', - 'DOCKER_RUN_SCRIPT_COMMAND': shell_command, - 'OUTPUT_DIR': 'artifacts' + "DOCKERFILE_DIR": dockerfile_dir, + "DOCKER_RUN_SCRIPT": "tools/run_tests/dockerize/docker_run.sh", + "DOCKER_RUN_SCRIPT_COMMAND": shell_command, + "OUTPUT_DIR": "artifacts", } if extra_docker_args is not None: - docker_env['EXTRA_DOCKER_ARGS'] = 
extra_docker_args + docker_env["EXTRA_DOCKER_ARGS"] = extra_docker_args jobspec = jobset.JobSpec( - cmdline=['tools/run_tests/dockerize/build_and_run_docker.sh'] + - docker_args, + cmdline=["tools/run_tests/dockerize/build_and_run_docker.sh"] + + docker_args, environ=docker_env, - shortname='build_artifact.%s' % (name), + shortname="build_artifact.%s" % (name), timeout_seconds=timeout_seconds, flake_retries=flake_retries, timeout_retries=timeout_retries, - verbose_success=verbose_success) + verbose_success=verbose_success, + ) return jobspec -def create_jobspec(name, - cmdline, - environ={}, - shell=False, - flake_retries=0, - timeout_retries=0, - timeout_seconds=30 * 60, - use_workspace=False, - cpu_cost=1.0, - verbose_success=False): +def create_jobspec( + name, + cmdline, + environ={}, + shell=False, + flake_retries=0, + timeout_retries=0, + timeout_seconds=30 * 60, + use_workspace=False, + cpu_cost=1.0, + verbose_success=False, +): """Creates jobspec.""" environ = environ.copy() if use_workspace: - environ['WORKSPACE_NAME'] = 'workspace_%s' % name - environ['ARTIFACTS_OUT'] = os.path.join('..', 'artifacts', name) - cmdline = ['bash', 'tools/run_tests/artifacts/run_in_workspace.sh' - ] + cmdline + environ["WORKSPACE_NAME"] = "workspace_%s" % name + environ["ARTIFACTS_OUT"] = os.path.join("..", "artifacts", name) + cmdline = [ + "bash", + "tools/run_tests/artifacts/run_in_workspace.sh", + ] + cmdline else: - environ['ARTIFACTS_OUT'] = os.path.join('artifacts', name) - - jobspec = jobset.JobSpec(cmdline=cmdline, - environ=environ, - shortname='build_artifact.%s' % (name), - timeout_seconds=timeout_seconds, - flake_retries=flake_retries, - timeout_retries=timeout_retries, - shell=shell, - cpu_cost=cpu_cost, - verbose_success=verbose_success) + environ["ARTIFACTS_OUT"] = os.path.join("artifacts", name) + + jobspec = jobset.JobSpec( + cmdline=cmdline, + environ=environ, + shortname="build_artifact.%s" % (name), + timeout_seconds=timeout_seconds, + flake_retries=flake_retries, + timeout_retries=timeout_retries, + shell=shell, + cpu_cost=cpu_cost, + verbose_success=verbose_success, + ) return jobspec -_MACOS_COMPAT_FLAG = '-mmacosx-version-min=10.10' +_MACOS_COMPAT_FLAG = "-mmacosx-version-min=10.10" -_ARCH_FLAG_MAP = {'x86': '-m32', 'x64': '-m64'} +_ARCH_FLAG_MAP = {"x86": "-m32", "x64": "-m64"} class PythonArtifact: """Builds Python artifacts.""" def __init__(self, platform, arch, py_version, presubmit=False): - self.name = 'python_%s_%s_%s' % (platform, arch, py_version) + self.name = "python_%s_%s_%s" % (platform, arch, py_version) self.platform = platform self.arch = arch - self.labels = ['artifact', 'python', platform, arch, py_version] + self.labels = ["artifact", "python", platform, arch, py_version] if presubmit: - self.labels.append('presubmit') + self.labels.append("presubmit") self.py_version = py_version if platform == _LATEST_MANYLINUX: - self.labels.append('latest-manylinux') - if 'manylinux' in platform: - self.labels.append('linux') - if 'linux_extra' in platform: + self.labels.append("latest-manylinux") + if "manylinux" in platform: + self.labels.append("linux") + if "linux_extra" in platform: # linux_extra wheels used to be built by a separate kokoro job. # Their build is now much faster, so they can be included # in the regular artifact build. 
- self.labels.append('linux') - if 'musllinux' in platform: - self.labels.append('linux') + self.labels.append("linux") + if "musllinux" in platform: + self.labels.append("linux") def pre_build_jobspecs(self): return [] @@ -129,92 +138,105 @@ def build_jobspec(self, inner_jobs=None): if inner_jobs is not None: # set number of parallel jobs when building native extension # building the native extension is the most time-consuming part of the build - environ['GRPC_PYTHON_BUILD_EXT_COMPILER_JOBS'] = str(inner_jobs) + environ["GRPC_PYTHON_BUILD_EXT_COMPILER_JOBS"] = str(inner_jobs) if self.platform == "macos": - environ['ARCHFLAGS'] = "-arch arm64 -arch x86_64" + environ["ARCHFLAGS"] = "-arch arm64 -arch x86_64" environ["GRPC_UNIVERSAL2_REPAIR"] = "true" - environ['GRPC_BUILD_WITH_BORING_SSL_ASM'] = "false" + environ["GRPC_BUILD_WITH_BORING_SSL_ASM"] = "false" - if self.platform == 'linux_extra': + if self.platform == "linux_extra": # Crosscompilation build for armv7 (e.g. Raspberry Pi) - environ['PYTHON'] = '/opt/python/{}/bin/python3'.format( - self.py_version) - environ['PIP'] = '/opt/python/{}/bin/pip3'.format(self.py_version) - environ['GRPC_SKIP_PIP_CYTHON_UPGRADE'] = 'TRUE' - environ['GRPC_SKIP_TWINE_CHECK'] = 'TRUE' + environ["PYTHON"] = "/opt/python/{}/bin/python3".format( + self.py_version + ) + environ["PIP"] = "/opt/python/{}/bin/pip3".format(self.py_version) + environ["GRPC_SKIP_PIP_CYTHON_UPGRADE"] = "TRUE" + environ["GRPC_SKIP_TWINE_CHECK"] = "TRUE" return create_docker_jobspec( self.name, - 'tools/dockerfile/grpc_artifact_python_linux_{}'.format( - self.arch), - 'tools/run_tests/artifacts/build_artifact_python.sh', + "tools/dockerfile/grpc_artifact_python_linux_{}".format( + self.arch + ), + "tools/run_tests/artifacts/build_artifact_python.sh", environ=environ, - timeout_seconds=60 * 60) - elif 'manylinux' in self.platform: - if self.arch == 'x86': - environ['SETARCH_CMD'] = 'linux32' + timeout_seconds=60 * 60, + ) + elif "manylinux" in self.platform: + if self.arch == "x86": + environ["SETARCH_CMD"] = "linux32" # Inside the manylinux container, the python installations are located in # special places... - environ['PYTHON'] = '/opt/python/{}/bin/python'.format( - self.py_version) - environ['PIP'] = '/opt/python/{}/bin/pip'.format(self.py_version) - environ['GRPC_SKIP_PIP_CYTHON_UPGRADE'] = 'TRUE' - if self.arch == 'aarch64': - environ['GRPC_SKIP_TWINE_CHECK'] = 'TRUE' + environ["PYTHON"] = "/opt/python/{}/bin/python".format( + self.py_version + ) + environ["PIP"] = "/opt/python/{}/bin/pip".format(self.py_version) + environ["GRPC_SKIP_PIP_CYTHON_UPGRADE"] = "TRUE" + if self.arch == "aarch64": + environ["GRPC_SKIP_TWINE_CHECK"] = "TRUE" # As we won't strip the binary with auditwheel (see below), strip # it at link time. - environ['LDFLAGS'] = '-s' + environ["LDFLAGS"] = "-s" else: # only run auditwheel if we're not crosscompiling - environ['GRPC_RUN_AUDITWHEEL_REPAIR'] = 'TRUE' + environ["GRPC_RUN_AUDITWHEEL_REPAIR"] = "TRUE" # only build the packages that depend on grpcio-tools # if we're not crosscompiling. 
# - they require protoc to run on current architecture # - they only have sdist packages anyway, so it's useless to build them again - environ['GRPC_BUILD_GRPCIO_TOOLS_DEPENDENTS'] = 'TRUE' + environ["GRPC_BUILD_GRPCIO_TOOLS_DEPENDENTS"] = "TRUE" return create_docker_jobspec( self.name, - 'tools/dockerfile/grpc_artifact_python_%s_%s' % - (self.platform, self.arch), - 'tools/run_tests/artifacts/build_artifact_python.sh', + "tools/dockerfile/grpc_artifact_python_%s_%s" + % (self.platform, self.arch), + "tools/run_tests/artifacts/build_artifact_python.sh", environ=environ, - timeout_seconds=60 * 60 * 2) - elif 'musllinux' in self.platform: - environ['PYTHON'] = '/opt/python/{}/bin/python'.format( - self.py_version) - environ['PIP'] = '/opt/python/{}/bin/pip'.format(self.py_version) - environ['GRPC_SKIP_PIP_CYTHON_UPGRADE'] = 'TRUE' - environ['GRPC_RUN_AUDITWHEEL_REPAIR'] = 'TRUE' - environ['GRPC_PYTHON_BUILD_WITH_STATIC_LIBSTDCXX'] = 'TRUE' + timeout_seconds=60 * 60 * 2, + ) + elif "musllinux" in self.platform: + environ["PYTHON"] = "/opt/python/{}/bin/python".format( + self.py_version + ) + environ["PIP"] = "/opt/python/{}/bin/pip".format(self.py_version) + environ["GRPC_SKIP_PIP_CYTHON_UPGRADE"] = "TRUE" + environ["GRPC_RUN_AUDITWHEEL_REPAIR"] = "TRUE" + environ["GRPC_PYTHON_BUILD_WITH_STATIC_LIBSTDCXX"] = "TRUE" return create_docker_jobspec( self.name, - 'tools/dockerfile/grpc_artifact_python_%s_%s' % - (self.platform, self.arch), - 'tools/run_tests/artifacts/build_artifact_python.sh', + "tools/dockerfile/grpc_artifact_python_%s_%s" + % (self.platform, self.arch), + "tools/run_tests/artifacts/build_artifact_python.sh", environ=environ, - timeout_seconds=60 * 60 * 2) - elif self.platform == 'windows': - environ['EXT_COMPILER'] = 'msvc' + timeout_seconds=60 * 60 * 2, + ) + elif self.platform == "windows": + environ["EXT_COMPILER"] = "msvc" # For some reason, the batch script %random% always runs with the same # seed. 
We create a random temp-dir here - dir = ''.join( - random.choice(string.ascii_uppercase) for _ in range(10)) - return create_jobspec(self.name, [ - 'tools\\run_tests\\artifacts\\build_artifact_python.bat', - self.py_version, '32' if self.arch == 'x86' else '64' - ], - environ=environ, - timeout_seconds=45 * 60, - use_workspace=True) + dir = "".join( + random.choice(string.ascii_uppercase) for _ in range(10) + ) + return create_jobspec( + self.name, + [ + "tools\\run_tests\\artifacts\\build_artifact_python.bat", + self.py_version, + "32" if self.arch == "x86" else "64", + ], + environ=environ, + timeout_seconds=45 * 60, + use_workspace=True, + ) else: - environ['PYTHON'] = self.py_version - environ['SKIP_PIP_INSTALL'] = 'TRUE' + environ["PYTHON"] = self.py_version + environ["SKIP_PIP_INSTALL"] = "TRUE" return create_jobspec( self.name, - ['tools/run_tests/artifacts/build_artifact_python.sh'], + ["tools/run_tests/artifacts/build_artifact_python.sh"], environ=environ, timeout_seconds=60 * 60 * 2, - use_workspace=True) + use_workspace=True, + ) def __str__(self): return self.name @@ -224,12 +246,12 @@ class RubyArtifact: """Builds ruby native gem.""" def __init__(self, platform, gem_platform, presubmit=False): - self.name = 'ruby_native_gem_%s_%s' % (platform, gem_platform) + self.name = "ruby_native_gem_%s_%s" % (platform, gem_platform) self.platform = platform self.gem_platform = gem_platform - self.labels = ['artifact', 'ruby', platform, gem_platform] + self.labels = ["artifact", "ruby", platform, gem_platform] if presubmit: - self.labels.append('presubmit') + self.labels.append("presubmit") def pre_build_jobspecs(self): return [] @@ -238,55 +260,61 @@ def build_jobspec(self, inner_jobs=None): environ = {} if inner_jobs is not None: # set number of parallel jobs when building native extension - environ['GRPC_RUBY_BUILD_PROCS'] = str(inner_jobs) + environ["GRPC_RUBY_BUILD_PROCS"] = str(inner_jobs) # Ruby build uses docker internally and docker cannot be nested. # We are using a custom workspace instead. 
- return create_jobspec(self.name, [ - 'tools/run_tests/artifacts/build_artifact_ruby.sh', - self.gem_platform - ], - use_workspace=True, - timeout_seconds=90 * 60, - environ=environ) + return create_jobspec( + self.name, + [ + "tools/run_tests/artifacts/build_artifact_ruby.sh", + self.gem_platform, + ], + use_workspace=True, + timeout_seconds=90 * 60, + environ=environ, + ) class PHPArtifact: """Builds PHP PECL package""" def __init__(self, platform, arch, presubmit=False): - self.name = 'php_pecl_package_{0}_{1}'.format(platform, arch) + self.name = "php_pecl_package_{0}_{1}".format(platform, arch) self.platform = platform self.arch = arch - self.labels = ['artifact', 'php', platform, arch] + self.labels = ["artifact", "php", platform, arch] if presubmit: - self.labels.append('presubmit') + self.labels.append("presubmit") def pre_build_jobspecs(self): return [] def build_jobspec(self, inner_jobs=None): del inner_jobs # arg unused as PHP artifact build is basically just packing an archive - if self.platform == 'linux': + if self.platform == "linux": return create_docker_jobspec( self.name, - 'tools/dockerfile/test/php73_zts_debian11_{}'.format(self.arch), - 'tools/run_tests/artifacts/build_artifact_php.sh') + "tools/dockerfile/test/php73_zts_debian11_{}".format(self.arch), + "tools/run_tests/artifacts/build_artifact_php.sh", + ) else: return create_jobspec( - self.name, ['tools/run_tests/artifacts/build_artifact_php.sh'], - use_workspace=True) + self.name, + ["tools/run_tests/artifacts/build_artifact_php.sh"], + use_workspace=True, + ) class ProtocArtifact: """Builds protoc and protoc-plugin artifacts""" def __init__(self, platform, arch, presubmit=False): - self.name = 'protoc_%s_%s' % (platform, arch) + self.name = "protoc_%s_%s" % (platform, arch) self.platform = platform self.arch = arch - self.labels = ['artifact', 'protoc', platform, arch] + self.labels = ["artifact", "protoc", platform, arch] if presubmit: - self.labels.append('presubmit') + self.labels.append("presubmit") def pre_build_jobspecs(self): return [] @@ -295,41 +323,52 @@ def build_jobspec(self, inner_jobs=None): environ = {} if inner_jobs is not None: # set number of parallel jobs when building protoc - environ['GRPC_PROTOC_BUILD_COMPILER_JOBS'] = str(inner_jobs) - - if self.platform != 'windows': - environ['CXXFLAGS'] = '' - environ['LDFLAGS'] = '' - if self.platform == 'linux': - dockerfile_dir = 'tools/dockerfile/grpc_artifact_centos6_{}'.format( - self.arch) - if self.arch == 'aarch64': + environ["GRPC_PROTOC_BUILD_COMPILER_JOBS"] = str(inner_jobs) + + if self.platform != "windows": + environ["CXXFLAGS"] = "" + environ["LDFLAGS"] = "" + if self.platform == "linux": + dockerfile_dir = ( + "tools/dockerfile/grpc_artifact_centos6_{}".format( + self.arch + ) + ) + if self.arch == "aarch64": # for aarch64, use a dockcross manylinux image that will # give us both ready to use crosscompiler and sufficient backward compatibility - dockerfile_dir = 'tools/dockerfile/grpc_artifact_protoc_aarch64' - environ['LDFLAGS'] += ' -static-libgcc -static-libstdc++ -s' + dockerfile_dir = ( + "tools/dockerfile/grpc_artifact_protoc_aarch64" + ) + environ["LDFLAGS"] += " -static-libgcc -static-libstdc++ -s" return create_docker_jobspec( self.name, dockerfile_dir, - 'tools/run_tests/artifacts/build_artifact_protoc.sh', - environ=environ) + "tools/run_tests/artifacts/build_artifact_protoc.sh", + environ=environ, + ) else: - environ[ - 'CXXFLAGS'] += ' -std=c++14 -stdlib=libc++ %s' % _MACOS_COMPAT_FLAG + environ["CXXFLAGS"] += ( + " 
-std=c++14 -stdlib=libc++ %s" % _MACOS_COMPAT_FLAG + ) return create_jobspec( self.name, - ['tools/run_tests/artifacts/build_artifact_protoc.sh'], + ["tools/run_tests/artifacts/build_artifact_protoc.sh"], environ=environ, timeout_seconds=60 * 60, - use_workspace=True) + use_workspace=True, + ) else: - vs_tools_architecture = self.arch # architecture selector passed to vcvarsall.bat - environ['ARCHITECTURE'] = vs_tools_architecture + vs_tools_architecture = ( + self.arch + ) # architecture selector passed to vcvarsall.bat + environ["ARCHITECTURE"] = vs_tools_architecture return create_jobspec( self.name, - ['tools\\run_tests\\artifacts\\build_artifact_protoc.bat'], + ["tools\\run_tests\\artifacts\\build_artifact_protoc.bat"], environ=environ, - use_workspace=True) + use_workspace=True, + ) def __str__(self): return self.name @@ -341,72 +380,100 @@ def _reorder_targets_for_build_speed(targets): # we start building ruby artifacts first, so that they don't end up # being a long tail once everything else finishes. return list( - sorted(targets, - key=lambda target: 0 if target.name.startswith('ruby_') else 1)) + sorted( + targets, + key=lambda target: 0 if target.name.startswith("ruby_") else 1, + ) + ) def targets(): """Gets list of supported targets""" - return _reorder_targets_for_build_speed([ - ProtocArtifact('linux', 'x64', presubmit=True), - ProtocArtifact('linux', 'x86', presubmit=True), - ProtocArtifact('linux', 'aarch64', presubmit=True), - ProtocArtifact('macos', 'x64', presubmit=True), - ProtocArtifact('windows', 'x64', presubmit=True), - ProtocArtifact('windows', 'x86', presubmit=True), - PythonArtifact('manylinux2014', 'x64', 'cp37-cp37m', presubmit=True), - PythonArtifact('manylinux2014', 'x64', 'cp38-cp38', presubmit=True), - PythonArtifact('manylinux2014', 'x64', 'cp39-cp39'), - PythonArtifact('manylinux2014', 'x64', 'cp310-cp310'), - PythonArtifact('manylinux2014', 'x64', 'cp311-cp311', presubmit=True), - PythonArtifact('manylinux2014', 'x86', 'cp37-cp37m', presubmit=True), - PythonArtifact('manylinux2014', 'x86', 'cp38-cp38', presubmit=True), - PythonArtifact('manylinux2014', 'x86', 'cp39-cp39'), - PythonArtifact('manylinux2014', 'x86', 'cp310-cp310'), - PythonArtifact('manylinux2014', 'x86', 'cp311-cp311', presubmit=True), - PythonArtifact('manylinux2014', 'aarch64', 'cp37-cp37m', - presubmit=True), - PythonArtifact('manylinux2014', 'aarch64', 'cp38-cp38', presubmit=True), - PythonArtifact('manylinux2014', 'aarch64', 'cp39-cp39'), - PythonArtifact('manylinux2014', 'aarch64', 'cp310-cp310'), - PythonArtifact('manylinux2014', 'aarch64', 'cp311-cp311'), - PythonArtifact('linux_extra', 'armv7', 'cp37-cp37m', presubmit=True), - PythonArtifact('linux_extra', 'armv7', 'cp38-cp38'), - PythonArtifact('linux_extra', 'armv7', 'cp39-cp39'), - PythonArtifact('linux_extra', 'armv7', 'cp310-cp310'), - PythonArtifact('linux_extra', 'armv7', 'cp311-cp311', presubmit=True), - PythonArtifact('musllinux_1_1', 'x64', 'cp310-cp310'), - PythonArtifact('musllinux_1_1', 'x64', 'cp311-cp311', presubmit=True), - PythonArtifact('musllinux_1_1', 'x64', 'cp37-cp37m', presubmit=True), - PythonArtifact('musllinux_1_1', 'x64', 'cp38-cp38'), - PythonArtifact('musllinux_1_1', 'x64', 'cp39-cp39'), - PythonArtifact('musllinux_1_1', 'x86', 'cp310-cp310'), - PythonArtifact('musllinux_1_1', 'x86', 'cp311-cp311', presubmit=True), - PythonArtifact('musllinux_1_1', 'x86', 'cp37-cp37m', presubmit=True), - PythonArtifact('musllinux_1_1', 'x86', 'cp38-cp38'), - PythonArtifact('musllinux_1_1', 'x86', 'cp39-cp39'), - 
PythonArtifact('macos', 'x64', 'python3.7', presubmit=True), - PythonArtifact('macos', 'x64', 'python3.8'), - PythonArtifact('macos', 'x64', 'python3.9'), - PythonArtifact('macos', 'x64', 'python3.10', presubmit=True), - PythonArtifact('macos', 'x64', 'python3.11', presubmit=True), - PythonArtifact('windows', 'x86', 'Python37_32bit', presubmit=True), - PythonArtifact('windows', 'x86', 'Python38_32bit'), - PythonArtifact('windows', 'x86', 'Python39_32bit'), - PythonArtifact('windows', 'x86', 'Python310_32bit'), - PythonArtifact('windows', 'x86', 'Python311_32bit', presubmit=True), - PythonArtifact('windows', 'x64', 'Python37', presubmit=True), - PythonArtifact('windows', 'x64', 'Python38'), - PythonArtifact('windows', 'x64', 'Python39'), - PythonArtifact('windows', 'x64', 'Python310'), - PythonArtifact('windows', 'x64', 'Python311', presubmit=True), - RubyArtifact('linux', 'x86-mingw32', presubmit=True), - RubyArtifact('linux', 'x64-mingw32', presubmit=True), - RubyArtifact('linux', 'x64-mingw-ucrt', presubmit=True), - RubyArtifact('linux', 'x86_64-linux', presubmit=True), - RubyArtifact('linux', 'x86-linux', presubmit=True), - RubyArtifact('linux', 'x86_64-darwin', presubmit=True), - RubyArtifact('linux', 'arm64-darwin', presubmit=True), - PHPArtifact('linux', 'x64', presubmit=True), - PHPArtifact('macos', 'x64', presubmit=True), - ]) + return _reorder_targets_for_build_speed( + [ + ProtocArtifact("linux", "x64", presubmit=True), + ProtocArtifact("linux", "x86", presubmit=True), + ProtocArtifact("linux", "aarch64", presubmit=True), + ProtocArtifact("macos", "x64", presubmit=True), + ProtocArtifact("windows", "x64", presubmit=True), + ProtocArtifact("windows", "x86", presubmit=True), + PythonArtifact( + "manylinux2014", "x64", "cp37-cp37m", presubmit=True + ), + PythonArtifact("manylinux2014", "x64", "cp38-cp38", presubmit=True), + PythonArtifact("manylinux2014", "x64", "cp39-cp39"), + PythonArtifact("manylinux2014", "x64", "cp310-cp310"), + PythonArtifact( + "manylinux2014", "x64", "cp311-cp311", presubmit=True + ), + PythonArtifact( + "manylinux2014", "x86", "cp37-cp37m", presubmit=True + ), + PythonArtifact("manylinux2014", "x86", "cp38-cp38", presubmit=True), + PythonArtifact("manylinux2014", "x86", "cp39-cp39"), + PythonArtifact("manylinux2014", "x86", "cp310-cp310"), + PythonArtifact( + "manylinux2014", "x86", "cp311-cp311", presubmit=True + ), + PythonArtifact( + "manylinux2014", "aarch64", "cp37-cp37m", presubmit=True + ), + PythonArtifact( + "manylinux2014", "aarch64", "cp38-cp38", presubmit=True + ), + PythonArtifact("manylinux2014", "aarch64", "cp39-cp39"), + PythonArtifact("manylinux2014", "aarch64", "cp310-cp310"), + PythonArtifact("manylinux2014", "aarch64", "cp311-cp311"), + PythonArtifact( + "linux_extra", "armv7", "cp37-cp37m", presubmit=True + ), + PythonArtifact("linux_extra", "armv7", "cp38-cp38"), + PythonArtifact("linux_extra", "armv7", "cp39-cp39"), + PythonArtifact("linux_extra", "armv7", "cp310-cp310"), + PythonArtifact( + "linux_extra", "armv7", "cp311-cp311", presubmit=True + ), + PythonArtifact("musllinux_1_1", "x64", "cp310-cp310"), + PythonArtifact( + "musllinux_1_1", "x64", "cp311-cp311", presubmit=True + ), + PythonArtifact( + "musllinux_1_1", "x64", "cp37-cp37m", presubmit=True + ), + PythonArtifact("musllinux_1_1", "x64", "cp38-cp38"), + PythonArtifact("musllinux_1_1", "x64", "cp39-cp39"), + PythonArtifact("musllinux_1_1", "x86", "cp310-cp310"), + PythonArtifact( + "musllinux_1_1", "x86", "cp311-cp311", presubmit=True + ), + PythonArtifact( + 
"musllinux_1_1", "x86", "cp37-cp37m", presubmit=True + ), + PythonArtifact("musllinux_1_1", "x86", "cp38-cp38"), + PythonArtifact("musllinux_1_1", "x86", "cp39-cp39"), + PythonArtifact("macos", "x64", "python3.7", presubmit=True), + PythonArtifact("macos", "x64", "python3.8"), + PythonArtifact("macos", "x64", "python3.9"), + PythonArtifact("macos", "x64", "python3.10", presubmit=True), + PythonArtifact("macos", "x64", "python3.11", presubmit=True), + PythonArtifact("windows", "x86", "Python37_32bit", presubmit=True), + PythonArtifact("windows", "x86", "Python38_32bit"), + PythonArtifact("windows", "x86", "Python39_32bit"), + PythonArtifact("windows", "x86", "Python310_32bit"), + PythonArtifact("windows", "x86", "Python311_32bit", presubmit=True), + PythonArtifact("windows", "x64", "Python37", presubmit=True), + PythonArtifact("windows", "x64", "Python38"), + PythonArtifact("windows", "x64", "Python39"), + PythonArtifact("windows", "x64", "Python310"), + PythonArtifact("windows", "x64", "Python311", presubmit=True), + RubyArtifact("linux", "x86-mingw32", presubmit=True), + RubyArtifact("linux", "x64-mingw32", presubmit=True), + RubyArtifact("linux", "x64-mingw-ucrt", presubmit=True), + RubyArtifact("linux", "x86_64-linux", presubmit=True), + RubyArtifact("linux", "x86-linux", presubmit=True), + RubyArtifact("linux", "x86_64-darwin", presubmit=True), + RubyArtifact("linux", "arm64-darwin", presubmit=True), + PHPArtifact("linux", "x64", presubmit=True), + PHPArtifact("macos", "x64", presubmit=True), + ] + ) diff --git a/tools/run_tests/artifacts/distribtest_targets.py b/tools/run_tests/artifacts/distribtest_targets.py index ed5d14b913612..76091144a560d 100644 --- a/tools/run_tests/artifacts/distribtest_targets.py +++ b/tools/run_tests/artifacts/distribtest_targets.py @@ -17,124 +17,141 @@ import os.path import sys -sys.path.insert(0, os.path.abspath('..')) +sys.path.insert(0, os.path.abspath("..")) import python_utils.jobset as jobset -def create_docker_jobspec(name, - dockerfile_dir, - shell_command, - environ={}, - flake_retries=0, - timeout_retries=0, - copy_rel_path=None, - timeout_seconds=30 * 60): +def create_docker_jobspec( + name, + dockerfile_dir, + shell_command, + environ={}, + flake_retries=0, + timeout_retries=0, + copy_rel_path=None, + timeout_seconds=30 * 60, +): """Creates jobspec for a task running under docker.""" environ = environ.copy() # the entire repo will be cloned if copy_rel_path is not set. 
if copy_rel_path: - environ['RELATIVE_COPY_PATH'] = copy_rel_path + environ["RELATIVE_COPY_PATH"] = copy_rel_path docker_args = [] for k, v in list(environ.items()): - docker_args += ['-e', '%s=%s' % (k, v)] + docker_args += ["-e", "%s=%s" % (k, v)] docker_env = { - 'DOCKERFILE_DIR': dockerfile_dir, - 'DOCKER_RUN_SCRIPT': 'tools/run_tests/dockerize/docker_run.sh', - 'DOCKER_RUN_SCRIPT_COMMAND': shell_command, + "DOCKERFILE_DIR": dockerfile_dir, + "DOCKER_RUN_SCRIPT": "tools/run_tests/dockerize/docker_run.sh", + "DOCKER_RUN_SCRIPT_COMMAND": shell_command, } jobspec = jobset.JobSpec( - cmdline=['tools/run_tests/dockerize/build_and_run_docker.sh'] + - docker_args, + cmdline=["tools/run_tests/dockerize/build_and_run_docker.sh"] + + docker_args, environ=docker_env, - shortname='distribtest.%s' % (name), + shortname="distribtest.%s" % (name), timeout_seconds=timeout_seconds, flake_retries=flake_retries, - timeout_retries=timeout_retries) + timeout_retries=timeout_retries, + ) return jobspec -def create_jobspec(name, - cmdline, - environ=None, - shell=False, - flake_retries=0, - timeout_retries=0, - use_workspace=False, - timeout_seconds=10 * 60): +def create_jobspec( + name, + cmdline, + environ=None, + shell=False, + flake_retries=0, + timeout_retries=0, + use_workspace=False, + timeout_seconds=10 * 60, +): """Creates jobspec.""" environ = environ.copy() if use_workspace: - environ['WORKSPACE_NAME'] = 'workspace_%s' % name - cmdline = ['bash', 'tools/run_tests/artifacts/run_in_workspace.sh' - ] + cmdline - jobspec = jobset.JobSpec(cmdline=cmdline, - environ=environ, - shortname='distribtest.%s' % (name), - timeout_seconds=timeout_seconds, - flake_retries=flake_retries, - timeout_retries=timeout_retries, - shell=shell) + environ["WORKSPACE_NAME"] = "workspace_%s" % name + cmdline = [ + "bash", + "tools/run_tests/artifacts/run_in_workspace.sh", + ] + cmdline + jobspec = jobset.JobSpec( + cmdline=cmdline, + environ=environ, + shortname="distribtest.%s" % (name), + timeout_seconds=timeout_seconds, + flake_retries=flake_retries, + timeout_retries=timeout_retries, + shell=shell, + ) return jobspec class CSharpDistribTest(object): """Tests C# NuGet package""" - def __init__(self, - platform, - arch, - docker_suffix=None, - use_dotnet_cli=False, - presubmit=False): - self.name = 'csharp_%s_%s' % (platform, arch) + def __init__( + self, + platform, + arch, + docker_suffix=None, + use_dotnet_cli=False, + presubmit=False, + ): + self.name = "csharp_%s_%s" % (platform, arch) self.platform = platform self.arch = arch self.docker_suffix = docker_suffix - self.labels = ['distribtest', 'csharp', platform, arch] + self.labels = ["distribtest", "csharp", platform, arch] if presubmit: - self.labels.append('presubmit') - self.script_suffix = '' + self.labels.append("presubmit") + self.script_suffix = "" if docker_suffix: - self.name += '_%s' % docker_suffix + self.name += "_%s" % docker_suffix self.labels.append(docker_suffix) if use_dotnet_cli: - self.name += '_dotnetcli' - self.script_suffix = '_dotnetcli' - self.labels.append('dotnetcli') + self.name += "_dotnetcli" + self.script_suffix = "_dotnetcli" + self.labels.append("dotnetcli") else: - self.labels.append('olddotnet') + self.labels.append("olddotnet") def pre_build_jobspecs(self): return [] def build_jobspec(self, inner_jobs=None): del inner_jobs # arg unused as there is little opportunity for parallelizing whats inside the distribtests - if self.platform == 'linux': + if self.platform == "linux": return create_docker_jobspec( self.name, - 
'tools/dockerfile/distribtest/csharp_%s_%s' % - (self.docker_suffix, self.arch), - 'test/distrib/csharp/run_distrib_test%s.sh' % - self.script_suffix, - copy_rel_path='test/distrib') - elif self.platform == 'macos': - return create_jobspec(self.name, [ - 'test/distrib/csharp/run_distrib_test%s.sh' % self.script_suffix - ], - environ={ - 'EXTERNAL_GIT_ROOT': '../../../..', - 'SKIP_NETCOREAPP21_DISTRIBTEST': '1', - 'SKIP_NET50_DISTRIBTEST': '1', - }, - use_workspace=True) - elif self.platform == 'windows': + "tools/dockerfile/distribtest/csharp_%s_%s" + % (self.docker_suffix, self.arch), + "test/distrib/csharp/run_distrib_test%s.sh" + % self.script_suffix, + copy_rel_path="test/distrib", + ) + elif self.platform == "macos": + return create_jobspec( + self.name, + [ + "test/distrib/csharp/run_distrib_test%s.sh" + % self.script_suffix + ], + environ={ + "EXTERNAL_GIT_ROOT": "../../../..", + "SKIP_NETCOREAPP21_DISTRIBTEST": "1", + "SKIP_NET50_DISTRIBTEST": "1", + }, + use_workspace=True, + ) + elif self.platform == "windows": # TODO(jtattermusch): re-enable windows distribtest return create_jobspec( self.name, - ['bash', 'tools/run_tests/artifacts/run_distribtest_csharp.sh'], + ["bash", "tools/run_tests/artifacts/run_distribtest_csharp.sh"], environ={}, - use_workspace=True) + use_workspace=True, + ) else: raise Exception("Not supported yet.") @@ -145,23 +162,20 @@ def __str__(self): class PythonDistribTest(object): """Tests Python package""" - def __init__(self, - platform, - arch, - docker_suffix, - source=False, - presubmit=False): + def __init__( + self, platform, arch, docker_suffix, source=False, presubmit=False + ): self.source = source if source: - self.name = 'python_dev_%s_%s_%s' % (platform, arch, docker_suffix) + self.name = "python_dev_%s_%s_%s" % (platform, arch, docker_suffix) else: - self.name = 'python_%s_%s_%s' % (platform, arch, docker_suffix) + self.name = "python_%s_%s_%s" % (platform, arch, docker_suffix) self.platform = platform self.arch = arch self.docker_suffix = docker_suffix - self.labels = ['distribtest', 'python', platform, arch, docker_suffix] + self.labels = ["distribtest", "python", platform, arch, docker_suffix] if presubmit: - self.labels.append('presubmit') + self.labels.append("presubmit") def pre_build_jobspecs(self): return [] @@ -169,23 +183,25 @@ def pre_build_jobspecs(self): def build_jobspec(self, inner_jobs=None): # TODO(jtattermusch): honor inner_jobs arg for this task. 
del inner_jobs - if not self.platform == 'linux': + if not self.platform == "linux": raise Exception("Not supported yet.") if self.source: return create_docker_jobspec( self.name, - 'tools/dockerfile/distribtest/python_dev_%s_%s' % - (self.docker_suffix, self.arch), - 'test/distrib/python/run_source_distrib_test.sh', - copy_rel_path='test/distrib') + "tools/dockerfile/distribtest/python_dev_%s_%s" + % (self.docker_suffix, self.arch), + "test/distrib/python/run_source_distrib_test.sh", + copy_rel_path="test/distrib", + ) else: return create_docker_jobspec( self.name, - 'tools/dockerfile/distribtest/python_%s_%s' % - (self.docker_suffix, self.arch), - 'test/distrib/python/run_binary_distrib_test.sh', - copy_rel_path='test/distrib') + "tools/dockerfile/distribtest/python_%s_%s" + % (self.docker_suffix, self.arch), + "test/distrib/python/run_binary_distrib_test.sh", + copy_rel_path="test/distrib", + ) def __str__(self): return self.name @@ -194,26 +210,32 @@ def __str__(self): class RubyDistribTest(object): """Tests Ruby package""" - def __init__(self, - platform, - arch, - docker_suffix, - ruby_version=None, - source=False, - presubmit=False): - self.package_type = 'binary' + def __init__( + self, + platform, + arch, + docker_suffix, + ruby_version=None, + source=False, + presubmit=False, + ): + self.package_type = "binary" if source: - self.package_type = 'source' - self.name = 'ruby_%s_%s_%s_version_%s_package_type_%s' % ( - platform, arch, docker_suffix, ruby_version or - 'unspecified', self.package_type) + self.package_type = "source" + self.name = "ruby_%s_%s_%s_version_%s_package_type_%s" % ( + platform, + arch, + docker_suffix, + ruby_version or "unspecified", + self.package_type, + ) self.platform = platform self.arch = arch self.docker_suffix = docker_suffix self.ruby_version = ruby_version - self.labels = ['distribtest', 'ruby', platform, arch, docker_suffix] + self.labels = ["distribtest", "ruby", platform, arch, docker_suffix] if presubmit: - self.labels.append('presubmit') + self.labels.append("presubmit") def pre_build_jobspecs(self): return [] @@ -222,22 +244,25 @@ def build_jobspec(self, inner_jobs=None): # TODO(jtattermusch): honor inner_jobs arg for this task. 
del inner_jobs arch_to_gem_arch = { - 'x64': 'x86_64', - 'x86': 'x86', + "x64": "x86_64", + "x86": "x86", } - if not self.platform == 'linux': + if not self.platform == "linux": raise Exception("Not supported yet.") - dockerfile_name = 'tools/dockerfile/distribtest/ruby_%s_%s' % ( - self.docker_suffix, self.arch) + dockerfile_name = "tools/dockerfile/distribtest/ruby_%s_%s" % ( + self.docker_suffix, + self.arch, + ) if self.ruby_version is not None: - dockerfile_name += '_%s' % self.ruby_version + dockerfile_name += "_%s" % self.ruby_version return create_docker_jobspec( self.name, dockerfile_name, - 'test/distrib/ruby/run_distrib_test.sh %s %s %s' % - (arch_to_gem_arch[self.arch], self.platform, self.package_type), - copy_rel_path='test/distrib') + "test/distrib/ruby/run_distrib_test.sh %s %s %s" + % (arch_to_gem_arch[self.arch], self.platform, self.package_type), + copy_rel_path="test/distrib", + ) def __str__(self): return self.name @@ -247,13 +272,13 @@ class PHP7DistribTest(object): """Tests PHP7 package""" def __init__(self, platform, arch, docker_suffix=None, presubmit=False): - self.name = 'php7_%s_%s_%s' % (platform, arch, docker_suffix) + self.name = "php7_%s_%s_%s" % (platform, arch, docker_suffix) self.platform = platform self.arch = arch self.docker_suffix = docker_suffix - self.labels = ['distribtest', 'php', 'php7', platform, arch] + self.labels = ["distribtest", "php", "php7", platform, arch] if presubmit: - self.labels.append('presubmit') + self.labels.append("presubmit") if docker_suffix: self.labels.append(docker_suffix) @@ -263,19 +288,22 @@ def pre_build_jobspecs(self): def build_jobspec(self, inner_jobs=None): # TODO(jtattermusch): honor inner_jobs arg for this task. del inner_jobs - if self.platform == 'linux': + if self.platform == "linux": return create_docker_jobspec( self.name, - 'tools/dockerfile/distribtest/php7_%s_%s' % - (self.docker_suffix, self.arch), - 'test/distrib/php/run_distrib_test.sh', - copy_rel_path='test/distrib') - elif self.platform == 'macos': + "tools/dockerfile/distribtest/php7_%s_%s" + % (self.docker_suffix, self.arch), + "test/distrib/php/run_distrib_test.sh", + copy_rel_path="test/distrib", + ) + elif self.platform == "macos": return create_jobspec( - self.name, ['test/distrib/php/run_distrib_test_macos.sh'], - environ={'EXTERNAL_GIT_ROOT': '../../../..'}, + self.name, + ["test/distrib/php/run_distrib_test_macos.sh"], + environ={"EXTERNAL_GIT_ROOT": "../../../.."}, timeout_seconds=20 * 60, - use_workspace=True) + use_workspace=True, + ) else: raise Exception("Not supported yet.") @@ -286,30 +314,31 @@ def __str__(self): class CppDistribTest(object): """Tests Cpp make install by building examples.""" - def __init__(self, - platform, - arch, - docker_suffix=None, - testcase=None, - presubmit=False): - if platform == 'linux': - self.name = 'cpp_%s_%s_%s_%s' % (platform, arch, docker_suffix, - testcase) + def __init__( + self, platform, arch, docker_suffix=None, testcase=None, presubmit=False + ): + if platform == "linux": + self.name = "cpp_%s_%s_%s_%s" % ( + platform, + arch, + docker_suffix, + testcase, + ) else: - self.name = 'cpp_%s_%s_%s' % (platform, arch, testcase) + self.name = "cpp_%s_%s_%s" % (platform, arch, testcase) self.platform = platform self.arch = arch self.docker_suffix = docker_suffix self.testcase = testcase self.labels = [ - 'distribtest', - 'cpp', + "distribtest", + "cpp", platform, arch, testcase, ] if presubmit: - self.labels.append('presubmit') + self.labels.append("presubmit") if docker_suffix: 
self.labels.append(docker_suffix) @@ -320,23 +349,26 @@ def build_jobspec(self, inner_jobs=None): environ = {} if inner_jobs is not None: # set number of parallel jobs for the C++ build - environ['GRPC_CPP_DISTRIBTEST_BUILD_COMPILER_JOBS'] = str( - inner_jobs) + environ["GRPC_CPP_DISTRIBTEST_BUILD_COMPILER_JOBS"] = str( + inner_jobs + ) - if self.platform == 'linux': + if self.platform == "linux": return create_docker_jobspec( self.name, - 'tools/dockerfile/distribtest/cpp_%s_%s' % - (self.docker_suffix, self.arch), - 'test/distrib/cpp/run_distrib_test_%s.sh' % self.testcase, - timeout_seconds=45 * 60) - elif self.platform == 'windows': + "tools/dockerfile/distribtest/cpp_%s_%s" + % (self.docker_suffix, self.arch), + "test/distrib/cpp/run_distrib_test_%s.sh" % self.testcase, + timeout_seconds=45 * 60, + ) + elif self.platform == "windows": return create_jobspec( self.name, - ['test\\distrib\\cpp\\run_distrib_test_%s.bat' % self.testcase], + ["test\\distrib\\cpp\\run_distrib_test_%s.bat" % self.testcase], environ={}, timeout_seconds=45 * 60, - use_workspace=True) + use_workspace=True, + ) else: raise Exception("Not supported yet.") @@ -348,104 +380,95 @@ def targets(): """Gets list of supported targets""" return [ # C++ - CppDistribTest('linux', 'x64', 'debian10', 'cmake', presubmit=True), - CppDistribTest('linux', - 'x64', - 'debian10', - 'cmake_as_submodule', - presubmit=True), - CppDistribTest('linux', - 'x64', - 'debian10', - 'cmake_as_externalproject', - presubmit=True), - CppDistribTest('linux', - 'x64', - 'debian10', - 'cmake_fetchcontent', - presubmit=True), - CppDistribTest('linux', - 'x64', - 'debian10', - 'cmake_module_install', - presubmit=True), - CppDistribTest('linux', - 'x64', - 'debian10', - 'cmake_pkgconfig', - presubmit=True), - CppDistribTest('linux', - 'x64', - 'debian10_aarch64_cross', - 'cmake_aarch64_cross', - presubmit=True), - CppDistribTest('windows', 'x86', testcase='cmake', presubmit=True), - CppDistribTest('windows', - 'x86', - testcase='cmake_as_externalproject', - presubmit=True), + CppDistribTest("linux", "x64", "debian10", "cmake", presubmit=True), + CppDistribTest( + "linux", "x64", "debian10", "cmake_as_submodule", presubmit=True + ), + CppDistribTest( + "linux", + "x64", + "debian10", + "cmake_as_externalproject", + presubmit=True, + ), + CppDistribTest( + "linux", "x64", "debian10", "cmake_fetchcontent", presubmit=True + ), + CppDistribTest( + "linux", "x64", "debian10", "cmake_module_install", presubmit=True + ), + CppDistribTest( + "linux", "x64", "debian10", "cmake_pkgconfig", presubmit=True + ), + CppDistribTest( + "linux", + "x64", + "debian10_aarch64_cross", + "cmake_aarch64_cross", + presubmit=True, + ), + CppDistribTest("windows", "x86", testcase="cmake", presubmit=True), + CppDistribTest( + "windows", + "x86", + testcase="cmake_as_externalproject", + presubmit=True, + ), # C# - CSharpDistribTest('linux', - 'x64', - 'debian10', - use_dotnet_cli=True, - presubmit=True), - CSharpDistribTest('linux', 'x64', 'ubuntu1604', use_dotnet_cli=True), - CSharpDistribTest('linux', - 'x64', - 'alpine', - use_dotnet_cli=True, - presubmit=True), - CSharpDistribTest('linux', - 'x64', - 'dotnet31', - use_dotnet_cli=True, - presubmit=True), - CSharpDistribTest('linux', - 'x64', - 'dotnet5', - use_dotnet_cli=True, - presubmit=True), - CSharpDistribTest('macos', 'x64', use_dotnet_cli=True, presubmit=True), - CSharpDistribTest('windows', 'x86', presubmit=True), - CSharpDistribTest('windows', 'x64', presubmit=True), + CSharpDistribTest( + "linux", "x64", "debian10", 
use_dotnet_cli=True, presubmit=True + ), + CSharpDistribTest("linux", "x64", "ubuntu1604", use_dotnet_cli=True), + CSharpDistribTest( + "linux", "x64", "alpine", use_dotnet_cli=True, presubmit=True + ), + CSharpDistribTest( + "linux", "x64", "dotnet31", use_dotnet_cli=True, presubmit=True + ), + CSharpDistribTest( + "linux", "x64", "dotnet5", use_dotnet_cli=True, presubmit=True + ), + CSharpDistribTest("macos", "x64", use_dotnet_cli=True, presubmit=True), + CSharpDistribTest("windows", "x86", presubmit=True), + CSharpDistribTest("windows", "x64", presubmit=True), # Python - PythonDistribTest('linux', 'x64', 'buster', presubmit=True), - PythonDistribTest('linux', 'x86', 'buster', presubmit=True), - PythonDistribTest('linux', 'x64', 'fedora34'), - PythonDistribTest('linux', 'x64', 'arch'), - PythonDistribTest('linux', 'x64', 'alpine'), - PythonDistribTest('linux', 'x64', 'ubuntu2004'), - PythonDistribTest('linux', 'aarch64', 'python38_buster', - presubmit=True), - PythonDistribTest('linux', - 'x64', - 'alpine3.7', - source=True, - presubmit=True), - PythonDistribTest('linux', 'x64', 'buster', source=True, - presubmit=True), - PythonDistribTest('linux', 'x86', 'buster', source=True, - presubmit=True), - PythonDistribTest('linux', 'x64', 'fedora34', source=True), - PythonDistribTest('linux', 'x64', 'arch', source=True), - PythonDistribTest('linux', 'x64', 'ubuntu2004', source=True), + PythonDistribTest("linux", "x64", "buster", presubmit=True), + PythonDistribTest("linux", "x86", "buster", presubmit=True), + PythonDistribTest("linux", "x64", "fedora34"), + PythonDistribTest("linux", "x64", "arch"), + PythonDistribTest("linux", "x64", "alpine"), + PythonDistribTest("linux", "x64", "ubuntu2004"), + PythonDistribTest( + "linux", "aarch64", "python38_buster", presubmit=True + ), + PythonDistribTest( + "linux", "x64", "alpine3.7", source=True, presubmit=True + ), + PythonDistribTest( + "linux", "x64", "buster", source=True, presubmit=True + ), + PythonDistribTest( + "linux", "x86", "buster", source=True, presubmit=True + ), + PythonDistribTest("linux", "x64", "fedora34", source=True), + PythonDistribTest("linux", "x64", "arch", source=True), + PythonDistribTest("linux", "x64", "ubuntu2004", source=True), # Ruby - RubyDistribTest('linux', - 'x64', - 'debian10', - ruby_version='ruby_2_6', - source=True, - presubmit=True), - RubyDistribTest('linux', - 'x64', - 'debian10', - ruby_version='ruby_2_7', - presubmit=True), - RubyDistribTest('linux', 'x64', 'centos7'), - RubyDistribTest('linux', 'x64', 'ubuntu1604'), - RubyDistribTest('linux', 'x64', 'ubuntu1804', presubmit=True), + RubyDistribTest( + "linux", + "x64", + "debian10", + ruby_version="ruby_2_6", + source=True, + presubmit=True, + ), + RubyDistribTest( + "linux", "x64", "debian10", ruby_version="ruby_2_7", presubmit=True + ), + RubyDistribTest("linux", "x64", "centos7"), + RubyDistribTest("linux", "x64", "ubuntu1604"), + RubyDistribTest("linux", "x64", "ubuntu1804", presubmit=True), # PHP7 - PHP7DistribTest('linux', 'x64', 'debian10', presubmit=True), - PHP7DistribTest('macos', 'x64', presubmit=True), + PHP7DistribTest("linux", "x64", "debian10", presubmit=True), + PHP7DistribTest("macos", "x64", presubmit=True), ] diff --git a/tools/run_tests/artifacts/package_targets.py b/tools/run_tests/artifacts/package_targets.py index be7c1f544ba83..dabfb452880d9 100644 --- a/tools/run_tests/artifacts/package_targets.py +++ b/tools/run_tests/artifacts/package_targets.py @@ -17,57 +17,64 @@ import os.path import sys -sys.path.insert(0, 
os.path.abspath('..')) +sys.path.insert(0, os.path.abspath("..")) import python_utils.jobset as jobset -def create_docker_jobspec(name, - dockerfile_dir, - shell_command, - environ={}, - flake_retries=0, - timeout_retries=0): +def create_docker_jobspec( + name, + dockerfile_dir, + shell_command, + environ={}, + flake_retries=0, + timeout_retries=0, +): """Creates jobspec for a task running under docker.""" environ = environ.copy() docker_args = [] for k, v in list(environ.items()): - docker_args += ['-e', '%s=%s' % (k, v)] + docker_args += ["-e", "%s=%s" % (k, v)] docker_env = { - 'DOCKERFILE_DIR': dockerfile_dir, - 'DOCKER_RUN_SCRIPT': 'tools/run_tests/dockerize/docker_run.sh', - 'DOCKER_RUN_SCRIPT_COMMAND': shell_command, - 'OUTPUT_DIR': 'artifacts' + "DOCKERFILE_DIR": dockerfile_dir, + "DOCKER_RUN_SCRIPT": "tools/run_tests/dockerize/docker_run.sh", + "DOCKER_RUN_SCRIPT_COMMAND": shell_command, + "OUTPUT_DIR": "artifacts", } jobspec = jobset.JobSpec( - cmdline=['tools/run_tests/dockerize/build_and_run_docker.sh'] + - docker_args, + cmdline=["tools/run_tests/dockerize/build_and_run_docker.sh"] + + docker_args, environ=docker_env, - shortname='build_package.%s' % (name), + shortname="build_package.%s" % (name), timeout_seconds=30 * 60, flake_retries=flake_retries, - timeout_retries=timeout_retries) + timeout_retries=timeout_retries, + ) return jobspec -def create_jobspec(name, - cmdline, - environ=None, - cwd=None, - shell=False, - flake_retries=0, - timeout_retries=0, - cpu_cost=1.0): +def create_jobspec( + name, + cmdline, + environ=None, + cwd=None, + shell=False, + flake_retries=0, + timeout_retries=0, + cpu_cost=1.0, +): """Creates jobspec.""" - jobspec = jobset.JobSpec(cmdline=cmdline, - environ=environ, - cwd=cwd, - shortname='build_package.%s' % (name), - timeout_seconds=10 * 60, - flake_retries=flake_retries, - timeout_retries=timeout_retries, - cpu_cost=cpu_cost, - shell=shell) + jobspec = jobset.JobSpec( + cmdline=cmdline, + environ=environ, + cwd=cwd, + shortname="build_package.%s" % (name), + timeout_seconds=10 * 60, + flake_retries=flake_retries, + timeout_retries=timeout_retries, + cpu_cost=cpu_cost, + shell=shell, + ) return jobspec @@ -76,9 +83,9 @@ class CSharpPackage: def __init__(self, platform): self.platform = platform - self.labels = ['package', 'csharp', self.platform] - self.name = 'csharp_package_nuget_%s' % self.platform - self.labels += ['nuget'] + self.labels = ["package", "csharp", self.platform] + self.name = "csharp_package_nuget_%s" % self.platform + self.labels += ["nuget"] def pre_build_jobspecs(self): return [] @@ -86,24 +93,28 @@ def pre_build_jobspecs(self): def build_jobspec(self, inner_jobs=None): del inner_jobs # arg unused as there is little opportunity for parallelizing environ = { - 'GRPC_CSHARP_BUILD_SINGLE_PLATFORM_NUGET': - os.getenv('GRPC_CSHARP_BUILD_SINGLE_PLATFORM_NUGET', '') + "GRPC_CSHARP_BUILD_SINGLE_PLATFORM_NUGET": os.getenv( + "GRPC_CSHARP_BUILD_SINGLE_PLATFORM_NUGET", "" + ) } - build_script = 'src/csharp/build_nuget.sh' + build_script = "src/csharp/build_nuget.sh" - if self.platform == 'linux': + if self.platform == "linux": return create_docker_jobspec( self.name, - 'tools/dockerfile/test/csharp_debian11_x64', + "tools/dockerfile/test/csharp_debian11_x64", build_script, - environ=environ) + environ=environ, + ) else: - repo_root = os.path.join(os.path.dirname(os.path.abspath(__file__)), - '..', '..', '..') - environ['EXTERNAL_GIT_ROOT'] = repo_root - return create_jobspec(self.name, ['bash', build_script], - environ=environ) + 
repo_root = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "..", "..", ".." + ) + environ["EXTERNAL_GIT_ROOT"] = repo_root + return create_jobspec( + self.name, ["bash", build_script], environ=environ + ) def __str__(self): return self.name @@ -113,8 +124,8 @@ class RubyPackage: """Collects ruby gems created in the artifact phase""" def __init__(self): - self.name = 'ruby_package' - self.labels = ['package', 'ruby', 'linux'] + self.name = "ruby_package" + self.labels = ["package", "ruby", "linux"] def pre_build_jobspecs(self): return [] @@ -122,16 +133,18 @@ def pre_build_jobspecs(self): def build_jobspec(self, inner_jobs=None): del inner_jobs # arg unused as this step simply collects preexisting artifacts return create_docker_jobspec( - self.name, 'tools/dockerfile/grpc_artifact_centos6_x64', - 'tools/run_tests/artifacts/build_package_ruby.sh') + self.name, + "tools/dockerfile/grpc_artifact_centos6_x64", + "tools/run_tests/artifacts/build_package_ruby.sh", + ) class PythonPackage: """Collects python eggs and wheels created in the artifact phase""" def __init__(self): - self.name = 'python_package' - self.labels = ['package', 'python', 'linux'] + self.name = "python_package" + self.labels = ["package", "python", "linux"] def pre_build_jobspecs(self): return [] @@ -143,17 +156,18 @@ def build_jobspec(self, inner_jobs=None): # for artifact building seems natural. return create_docker_jobspec( self.name, - 'tools/dockerfile/grpc_artifact_python_manylinux2014_x64', - 'tools/run_tests/artifacts/build_package_python.sh', - environ={'PYTHON': '/opt/python/cp39-cp39/bin/python'}) + "tools/dockerfile/grpc_artifact_python_manylinux2014_x64", + "tools/run_tests/artifacts/build_package_python.sh", + environ={"PYTHON": "/opt/python/cp39-cp39/bin/python"}, + ) class PHPPackage: """Copy PHP PECL package artifact""" def __init__(self): - self.name = 'php_package' - self.labels = ['package', 'php', 'linux'] + self.name = "php_package" + self.labels = ["package", "php", "linux"] def pre_build_jobspecs(self): return [] @@ -161,17 +175,19 @@ def pre_build_jobspecs(self): def build_jobspec(self, inner_jobs=None): del inner_jobs # arg unused as this step simply collects preexisting artifacts return create_docker_jobspec( - self.name, 'tools/dockerfile/grpc_artifact_centos6_x64', - 'tools/run_tests/artifacts/build_package_php.sh') + self.name, + "tools/dockerfile/grpc_artifact_centos6_x64", + "tools/run_tests/artifacts/build_package_php.sh", + ) def targets(): """Gets list of supported targets""" return [ - CSharpPackage('linux'), - CSharpPackage('macos'), - CSharpPackage('windows'), + CSharpPackage("linux"), + CSharpPackage("macos"), + CSharpPackage("windows"), RubyPackage(), PythonPackage(), - PHPPackage() + PHPPackage(), ] diff --git a/tools/run_tests/lb_interop_tests/gen_build_yaml.py b/tools/run_tests/lb_interop_tests/gen_build_yaml.py index c80835d74abcc..f4e92fa51f251 100755 --- a/tools/run_tests/lb_interop_tests/gen_build_yaml.py +++ b/tools/run_tests/lb_interop_tests/gen_build_yaml.py @@ -32,31 +32,33 @@ def server_sec(transport_sec): - if transport_sec == 'google_default_credentials': - return 'alts', 'alts', 'tls' + if transport_sec == "google_default_credentials": + return "alts", "alts", "tls" return transport_sec, transport_sec, transport_sec def generate_no_balancer_because_lb_a_record_returns_nx_domain(): all_configs = [] for transport_sec in [ - 'insecure', 'alts', 'tls', 'google_default_credentials' + "insecure", + "alts", + "tls", + "google_default_credentials", ]: balancer_sec, 
backend_sec, fallback_sec = server_sec(transport_sec) config = { - 'name': - 'no_balancer_because_lb_a_record_returns_nx_domain_%s' % - transport_sec, - 'skip_langs': [], - 'transport_sec': - transport_sec, - 'balancer_configs': [], - 'backend_configs': [], - 'fallback_configs': [{ - 'transport_sec': fallback_sec, - }], - 'cause_no_error_no_data_for_balancer_a_record': - False, + "name": "no_balancer_because_lb_a_record_returns_nx_domain_%s" + % transport_sec, + "skip_langs": [], + "transport_sec": transport_sec, + "balancer_configs": [], + "backend_configs": [], + "fallback_configs": [ + { + "transport_sec": fallback_sec, + } + ], + "cause_no_error_no_data_for_balancer_a_record": False, } all_configs.append(config) return all_configs @@ -68,23 +70,25 @@ def generate_no_balancer_because_lb_a_record_returns_nx_domain(): def generate_no_balancer_because_lb_a_record_returns_no_data(): all_configs = [] for transport_sec in [ - 'insecure', 'alts', 'tls', 'google_default_credentials' + "insecure", + "alts", + "tls", + "google_default_credentials", ]: balancer_sec, backend_sec, fallback_sec = server_sec(transport_sec) config = { - 'name': - 'no_balancer_because_lb_a_record_returns_no_data_%s' % - transport_sec, - 'skip_langs': [], - 'transport_sec': - transport_sec, - 'balancer_configs': [], - 'backend_configs': [], - 'fallback_configs': [{ - 'transport_sec': fallback_sec, - }], - 'cause_no_error_no_data_for_balancer_a_record': - True, + "name": "no_balancer_because_lb_a_record_returns_no_data_%s" + % transport_sec, + "skip_langs": [], + "transport_sec": transport_sec, + "balancer_configs": [], + "backend_configs": [], + "fallback_configs": [ + { + "transport_sec": fallback_sec, + } + ], + "cause_no_error_no_data_for_balancer_a_record": True, } all_configs.append(config) return all_configs @@ -97,32 +101,35 @@ def generate_client_referred_to_backend(): all_configs = [] for balancer_short_stream in [True, False]: for transport_sec in [ - 'insecure', 'alts', 'tls', 'google_default_credentials' + "insecure", + "alts", + "tls", + "google_default_credentials", ]: balancer_sec, backend_sec, fallback_sec = server_sec(transport_sec) skip_langs = [] - if transport_sec == 'tls': - skip_langs += ['java'] + if transport_sec == "tls": + skip_langs += ["java"] if balancer_short_stream: - skip_langs += ['java'] + skip_langs += ["java"] config = { - 'name': - 'client_referred_to_backend_%s_short_stream_%s' % - (transport_sec, balancer_short_stream), - 'skip_langs': - skip_langs, - 'transport_sec': - transport_sec, - 'balancer_configs': [{ - 'transport_sec': balancer_sec, - 'short_stream': balancer_short_stream, - }], - 'backend_configs': [{ - 'transport_sec': backend_sec, - }], - 'fallback_configs': [], - 'cause_no_error_no_data_for_balancer_a_record': - False, + "name": "client_referred_to_backend_%s_short_stream_%s" + % (transport_sec, balancer_short_stream), + "skip_langs": skip_langs, + "transport_sec": transport_sec, + "balancer_configs": [ + { + "transport_sec": balancer_sec, + "short_stream": balancer_short_stream, + } + ], + "backend_configs": [ + { + "transport_sec": backend_sec, + } + ], + "fallback_configs": [], + "cause_no_error_no_data_for_balancer_a_record": False, } all_configs.append(config) return all_configs @@ -134,33 +141,35 @@ def generate_client_referred_to_backend(): def generate_client_referred_to_backend_fallback_broken(): all_configs = [] for balancer_short_stream in [True, False]: - for transport_sec in ['alts', 'tls', 'google_default_credentials']: + for transport_sec in ["alts", 
"tls", "google_default_credentials"]: balancer_sec, backend_sec, fallback_sec = server_sec(transport_sec) skip_langs = [] - if transport_sec == 'tls': - skip_langs += ['java'] + if transport_sec == "tls": + skip_langs += ["java"] if balancer_short_stream: - skip_langs += ['java'] + skip_langs += ["java"] config = { - 'name': - 'client_referred_to_backend_fallback_broken_%s_short_stream_%s' - % (transport_sec, balancer_short_stream), - 'skip_langs': - skip_langs, - 'transport_sec': - transport_sec, - 'balancer_configs': [{ - 'transport_sec': balancer_sec, - 'short_stream': balancer_short_stream, - }], - 'backend_configs': [{ - 'transport_sec': backend_sec, - }], - 'fallback_configs': [{ - 'transport_sec': 'insecure', - }], - 'cause_no_error_no_data_for_balancer_a_record': - False, + "name": "client_referred_to_backend_fallback_broken_%s_short_stream_%s" + % (transport_sec, balancer_short_stream), + "skip_langs": skip_langs, + "transport_sec": transport_sec, + "balancer_configs": [ + { + "transport_sec": balancer_sec, + "short_stream": balancer_short_stream, + } + ], + "backend_configs": [ + { + "transport_sec": backend_sec, + } + ], + "fallback_configs": [ + { + "transport_sec": "insecure", + } + ], + "cause_no_error_no_data_for_balancer_a_record": False, } all_configs.append(config) return all_configs @@ -173,40 +182,47 @@ def generate_client_referred_to_backend_multiple_backends(): all_configs = [] for balancer_short_stream in [True, False]: for transport_sec in [ - 'insecure', 'alts', 'tls', 'google_default_credentials' + "insecure", + "alts", + "tls", + "google_default_credentials", ]: balancer_sec, backend_sec, fallback_sec = server_sec(transport_sec) skip_langs = [] - if transport_sec == 'tls': - skip_langs += ['java'] + if transport_sec == "tls": + skip_langs += ["java"] if balancer_short_stream: - skip_langs += ['java'] + skip_langs += ["java"] config = { - 'name': - 'client_referred_to_backend_multiple_backends_%s_short_stream_%s' - % (transport_sec, balancer_short_stream), - 'skip_langs': - skip_langs, - 'transport_sec': - transport_sec, - 'balancer_configs': [{ - 'transport_sec': balancer_sec, - 'short_stream': balancer_short_stream, - }], - 'backend_configs': [{ - 'transport_sec': backend_sec, - }, { - 'transport_sec': backend_sec, - }, { - 'transport_sec': backend_sec, - }, { - 'transport_sec': backend_sec, - }, { - 'transport_sec': backend_sec, - }], - 'fallback_configs': [], - 'cause_no_error_no_data_for_balancer_a_record': - False, + "name": "client_referred_to_backend_multiple_backends_%s_short_stream_%s" + % (transport_sec, balancer_short_stream), + "skip_langs": skip_langs, + "transport_sec": transport_sec, + "balancer_configs": [ + { + "transport_sec": balancer_sec, + "short_stream": balancer_short_stream, + } + ], + "backend_configs": [ + { + "transport_sec": backend_sec, + }, + { + "transport_sec": backend_sec, + }, + { + "transport_sec": backend_sec, + }, + { + "transport_sec": backend_sec, + }, + { + "transport_sec": backend_sec, + }, + ], + "fallback_configs": [], + "cause_no_error_no_data_for_balancer_a_record": False, } all_configs.append(config) return all_configs @@ -219,32 +235,37 @@ def generate_client_falls_back_because_no_backends(): all_configs = [] for balancer_short_stream in [True, False]: for transport_sec in [ - 'insecure', 'alts', 'tls', 'google_default_credentials' + "insecure", + "alts", + "tls", + "google_default_credentials", ]: balancer_sec, backend_sec, fallback_sec = server_sec(transport_sec) - skip_langs = ['go', 'java'] - if transport_sec == 
'tls': - skip_langs += ['java'] + skip_langs = ["go", "java"] + if transport_sec == "tls": + skip_langs += ["java"] if balancer_short_stream: - skip_langs += ['java'] + skip_langs += ["java"] config = { - 'name': - 'client_falls_back_because_no_backends_%s_short_stream_%s' % - (transport_sec, balancer_short_stream), - 'skip_langs': - skip_langs, - 'transport_sec': - transport_sec, - 'balancer_configs': [{ - 'transport_sec': balancer_sec, - 'short_stream': balancer_short_stream, - }], - 'backend_configs': [], - 'fallback_configs': [{ - 'transport_sec': fallback_sec, - }], - 'cause_no_error_no_data_for_balancer_a_record': - False, + "name": ( + "client_falls_back_because_no_backends_%s_short_stream_%s" + ) + % (transport_sec, balancer_short_stream), + "skip_langs": skip_langs, + "transport_sec": transport_sec, + "balancer_configs": [ + { + "transport_sec": balancer_sec, + "short_stream": balancer_short_stream, + } + ], + "backend_configs": [], + "fallback_configs": [ + { + "transport_sec": fallback_sec, + } + ], + "cause_no_error_no_data_for_balancer_a_record": False, } all_configs.append(config) return all_configs @@ -255,29 +276,29 @@ def generate_client_falls_back_because_no_backends(): def generate_client_falls_back_because_balancer_connection_broken(): all_configs = [] - for transport_sec in ['alts', 'tls', 'google_default_credentials']: + for transport_sec in ["alts", "tls", "google_default_credentials"]: balancer_sec, backend_sec, fallback_sec = server_sec(transport_sec) skip_langs = [] - if transport_sec == 'tls': - skip_langs = ['java'] + if transport_sec == "tls": + skip_langs = ["java"] config = { - 'name': - 'client_falls_back_because_balancer_connection_broken_%s' % - transport_sec, - 'skip_langs': - skip_langs, - 'transport_sec': - transport_sec, - 'balancer_configs': [{ - 'transport_sec': 'insecure', - 'short_stream': False, - }], - 'backend_configs': [], - 'fallback_configs': [{ - 'transport_sec': fallback_sec, - }], - 'cause_no_error_no_data_for_balancer_a_record': - False, + "name": "client_falls_back_because_balancer_connection_broken_%s" + % transport_sec, + "skip_langs": skip_langs, + "transport_sec": transport_sec, + "balancer_configs": [ + { + "transport_sec": "insecure", + "short_stream": False, + } + ], + "backend_configs": [], + "fallback_configs": [ + { + "transport_sec": fallback_sec, + } + ], + "cause_no_error_no_data_for_balancer_a_record": False, } all_configs.append(config) return all_configs @@ -290,50 +311,51 @@ def generate_client_referred_to_backend_multiple_balancers(): all_configs = [] for balancer_short_stream in [True, False]: for transport_sec in [ - 'insecure', 'alts', 'tls', 'google_default_credentials' + "insecure", + "alts", + "tls", + "google_default_credentials", ]: balancer_sec, backend_sec, fallback_sec = server_sec(transport_sec) skip_langs = [] - if transport_sec == 'tls': - skip_langs += ['java'] + if transport_sec == "tls": + skip_langs += ["java"] if balancer_short_stream: - skip_langs += ['java'] + skip_langs += ["java"] config = { - 'name': - 'client_referred_to_backend_multiple_balancers_%s_short_stream_%s' - % (transport_sec, balancer_short_stream), - 'skip_langs': - skip_langs, - 'transport_sec': - transport_sec, - 'balancer_configs': [ + "name": "client_referred_to_backend_multiple_balancers_%s_short_stream_%s" + % (transport_sec, balancer_short_stream), + "skip_langs": skip_langs, + "transport_sec": transport_sec, + "balancer_configs": [ { - 'transport_sec': balancer_sec, - 'short_stream': balancer_short_stream, + "transport_sec": 
balancer_sec, + "short_stream": balancer_short_stream, }, { - 'transport_sec': balancer_sec, - 'short_stream': balancer_short_stream, + "transport_sec": balancer_sec, + "short_stream": balancer_short_stream, }, { - 'transport_sec': balancer_sec, - 'short_stream': balancer_short_stream, + "transport_sec": balancer_sec, + "short_stream": balancer_short_stream, }, { - 'transport_sec': balancer_sec, - 'short_stream': balancer_short_stream, + "transport_sec": balancer_sec, + "short_stream": balancer_short_stream, }, { - 'transport_sec': balancer_sec, - 'short_stream': balancer_short_stream, + "transport_sec": balancer_sec, + "short_stream": balancer_short_stream, }, ], - 'backend_configs': [{ - 'transport_sec': backend_sec, - },], - 'fallback_configs': [], - 'cause_no_error_no_data_for_balancer_a_record': - False, + "backend_configs": [ + { + "transport_sec": backend_sec, + }, + ], + "fallback_configs": [], + "cause_no_error_no_data_for_balancer_a_record": False, } all_configs.append(config) return all_configs @@ -341,6 +363,12 @@ def generate_client_referred_to_backend_multiple_balancers(): all_scenarios += generate_client_referred_to_backend_multiple_balancers() -print((yaml.dump({ - 'lb_interop_test_scenarios': all_scenarios, -}))) +print( + ( + yaml.dump( + { + "lb_interop_test_scenarios": all_scenarios, + } + ) + ) +) diff --git a/tools/run_tests/performance/bq_upload_result.py b/tools/run_tests/performance/bq_upload_result.py index 727f4d5924d49..cb734efb7281b 100755 --- a/tools/run_tests/performance/bq_upload_result.py +++ b/tools/run_tests/performance/bq_upload_result.py @@ -26,53 +26,64 @@ import uuid gcp_utils_dir = os.path.abspath( - os.path.join(os.path.dirname(__file__), '../../gcp/utils')) + os.path.join(os.path.dirname(__file__), "../../gcp/utils") +) sys.path.append(gcp_utils_dir) import big_query_utils -_PROJECT_ID = 'grpc-testing' +_PROJECT_ID = "grpc-testing" def _upload_netperf_latency_csv_to_bigquery(dataset_id, table_id, result_file): - with open(result_file, 'r') as f: - (col1, col2, col3) = f.read().split(',') + with open(result_file, "r") as f: + (col1, col2, col3) = f.read().split(",") latency50 = float(col1.strip()) * 1000 latency90 = float(col2.strip()) * 1000 latency99 = float(col3.strip()) * 1000 scenario_result = { - 'scenario': { - 'name': 'netperf_tcp_rr' + "scenario": {"name": "netperf_tcp_rr"}, + "summary": { + "latency50": latency50, + "latency90": latency90, + "latency99": latency99, }, - 'summary': { - 'latency50': latency50, - 'latency90': latency90, - 'latency99': latency99 - } } bq = big_query_utils.create_big_query() _create_results_table(bq, dataset_id, table_id) if not _insert_result( - bq, dataset_id, table_id, scenario_result, flatten=False): - print('Error uploading result to bigquery.') + bq, dataset_id, table_id, scenario_result, flatten=False + ): + print("Error uploading result to bigquery.") sys.exit(1) -def _upload_scenario_result_to_bigquery(dataset_id, table_id, result_file, - metadata_file, node_info_file, - prometheus_query_results_file): - with open(result_file, 'r') as f: +def _upload_scenario_result_to_bigquery( + dataset_id, + table_id, + result_file, + metadata_file, + node_info_file, + prometheus_query_results_file, +): + with open(result_file, "r") as f: scenario_result = json.loads(f.read()) bq = big_query_utils.create_big_query() _create_results_table(bq, dataset_id, table_id) - if not _insert_scenario_result(bq, dataset_id, table_id, scenario_result, - metadata_file, node_info_file, - prometheus_query_results_file): - 
print('Error uploading result to bigquery.') + if not _insert_scenario_result( + bq, + dataset_id, + table_id, + scenario_result, + metadata_file, + node_info_file, + prometheus_query_results_file, + ): + print("Error uploading result to bigquery.") sys.exit(1) @@ -81,70 +92,85 @@ def _insert_result(bq, dataset_id, table_id, scenario_result, flatten=True): _flatten_result_inplace(scenario_result) _populate_metadata_inplace(scenario_result) row = big_query_utils.make_row(str(uuid.uuid4()), scenario_result) - return big_query_utils.insert_rows(bq, _PROJECT_ID, dataset_id, table_id, - [row]) - - -def _insert_scenario_result(bq, - dataset_id, - table_id, - scenario_result, - test_metadata_file, - node_info_file, - prometheus_query_results_file, - flatten=True): + return big_query_utils.insert_rows( + bq, _PROJECT_ID, dataset_id, table_id, [row] + ) + + +def _insert_scenario_result( + bq, + dataset_id, + table_id, + scenario_result, + test_metadata_file, + node_info_file, + prometheus_query_results_file, + flatten=True, +): if flatten: _flatten_result_inplace(scenario_result) _populate_metadata_from_file(scenario_result, test_metadata_file) _populate_node_metadata_from_file(scenario_result, node_info_file) - _populate_prometheus_query_results_from_file(scenario_result, - prometheus_query_results_file) + _populate_prometheus_query_results_from_file( + scenario_result, prometheus_query_results_file + ) row = big_query_utils.make_row(str(uuid.uuid4()), scenario_result) - return big_query_utils.insert_rows(bq, _PROJECT_ID, dataset_id, table_id, - [row]) + return big_query_utils.insert_rows( + bq, _PROJECT_ID, dataset_id, table_id, [row] + ) def _create_results_table(bq, dataset_id, table_id): - with open(os.path.dirname(__file__) + '/scenario_result_schema.json', - 'r') as f: + with open( + os.path.dirname(__file__) + "/scenario_result_schema.json", "r" + ) as f: table_schema = json.loads(f.read()) - desc = 'Results of performance benchmarks.' - return big_query_utils.create_table2(bq, _PROJECT_ID, dataset_id, table_id, - table_schema, desc) + desc = "Results of performance benchmarks." + return big_query_utils.create_table2( + bq, _PROJECT_ID, dataset_id, table_id, table_schema, desc + ) def _flatten_result_inplace(scenario_result): """Bigquery is not really great for handling deeply nested data - and repeated fields. To maintain values of some fields while keeping - the schema relatively simple, we artificially leave some of the fields - as JSON strings. 
- """ - scenario_result['scenario']['clientConfig'] = json.dumps( - scenario_result['scenario']['clientConfig']) - scenario_result['scenario']['serverConfig'] = json.dumps( - scenario_result['scenario']['serverConfig']) - scenario_result['latencies'] = json.dumps(scenario_result['latencies']) - scenario_result['serverCpuStats'] = [] - for stats in scenario_result['serverStats']: - scenario_result['serverCpuStats'].append(dict()) - scenario_result['serverCpuStats'][-1]['totalCpuTime'] = stats.pop( - 'totalCpuTime', None) - scenario_result['serverCpuStats'][-1]['idleCpuTime'] = stats.pop( - 'idleCpuTime', None) - for stats in scenario_result['clientStats']: - stats['latencies'] = json.dumps(stats['latencies']) - stats.pop('requestResults', None) - scenario_result['serverCores'] = json.dumps(scenario_result['serverCores']) - scenario_result['clientSuccess'] = json.dumps( - scenario_result['clientSuccess']) - scenario_result['serverSuccess'] = json.dumps( - scenario_result['serverSuccess']) - scenario_result['requestResults'] = json.dumps( - scenario_result.get('requestResults', [])) - scenario_result['serverCpuUsage'] = scenario_result['summary'].pop( - 'serverCpuUsage', None) - scenario_result['summary'].pop('successfulRequestsPerSecond', None) - scenario_result['summary'].pop('failedRequestsPerSecond', None) + and repeated fields. To maintain values of some fields while keeping + the schema relatively simple, we artificially leave some of the fields + as JSON strings. + """ + scenario_result["scenario"]["clientConfig"] = json.dumps( + scenario_result["scenario"]["clientConfig"] + ) + scenario_result["scenario"]["serverConfig"] = json.dumps( + scenario_result["scenario"]["serverConfig"] + ) + scenario_result["latencies"] = json.dumps(scenario_result["latencies"]) + scenario_result["serverCpuStats"] = [] + for stats in scenario_result["serverStats"]: + scenario_result["serverCpuStats"].append(dict()) + scenario_result["serverCpuStats"][-1]["totalCpuTime"] = stats.pop( + "totalCpuTime", None + ) + scenario_result["serverCpuStats"][-1]["idleCpuTime"] = stats.pop( + "idleCpuTime", None + ) + for stats in scenario_result["clientStats"]: + stats["latencies"] = json.dumps(stats["latencies"]) + stats.pop("requestResults", None) + scenario_result["serverCores"] = json.dumps(scenario_result["serverCores"]) + scenario_result["clientSuccess"] = json.dumps( + scenario_result["clientSuccess"] + ) + scenario_result["serverSuccess"] = json.dumps( + scenario_result["serverSuccess"] + ) + scenario_result["requestResults"] = json.dumps( + scenario_result.get("requestResults", []) + ) + scenario_result["serverCpuUsage"] = scenario_result["summary"].pop( + "serverCpuUsage", None + ) + scenario_result["summary"].pop("successfulRequestsPerSecond", None) + scenario_result["summary"].pop("failedRequestsPerSecond", None) def _populate_metadata_inplace(scenario_result): @@ -152,184 +178,219 @@ def _populate_metadata_inplace(scenario_result): # NOTE: Grabbing the Kokoro environment variables will only work if the # driver is running locally on the same machine where Kokoro has started # the job. For our setup, this is currently the case, so just assume that. 
- build_number = os.getenv('KOKORO_BUILD_NUMBER') - build_url = 'https://source.cloud.google.com/results/invocations/%s' % os.getenv( - 'KOKORO_BUILD_ID') - job_name = os.getenv('KOKORO_JOB_NAME') - git_commit = os.getenv('KOKORO_GIT_COMMIT') + build_number = os.getenv("KOKORO_BUILD_NUMBER") + build_url = ( + "https://source.cloud.google.com/results/invocations/%s" + % os.getenv("KOKORO_BUILD_ID") + ) + job_name = os.getenv("KOKORO_JOB_NAME") + git_commit = os.getenv("KOKORO_GIT_COMMIT") # actual commit is the actual head of PR that is getting tested # TODO(jtattermusch): unclear how to obtain on Kokoro - git_actual_commit = os.getenv('ghprbActualCommit') + git_actual_commit = os.getenv("ghprbActualCommit") utc_timestamp = str(calendar.timegm(time.gmtime())) - metadata = {'created': utc_timestamp} + metadata = {"created": utc_timestamp} if build_number: - metadata['buildNumber'] = build_number + metadata["buildNumber"] = build_number if build_url: - metadata['buildUrl'] = build_url + metadata["buildUrl"] = build_url if job_name: - metadata['jobName'] = job_name + metadata["jobName"] = job_name if git_commit: - metadata['gitCommit'] = git_commit + metadata["gitCommit"] = git_commit if git_actual_commit: - metadata['gitActualCommit'] = git_actual_commit + metadata["gitActualCommit"] = git_actual_commit - scenario_result['metadata'] = metadata + scenario_result["metadata"] = metadata def _populate_metadata_from_file(scenario_result, test_metadata_file): utc_timestamp = str(calendar.timegm(time.gmtime())) - metadata = {'created': utc_timestamp} + metadata = {"created": utc_timestamp} _annotation_to_bq_metadata_key_map = { - 'ci_' + key: key for key in ( - 'buildNumber', - 'buildUrl', - 'jobName', - 'gitCommit', - 'gitActualCommit', + "ci_" + key: key + for key in ( + "buildNumber", + "buildUrl", + "jobName", + "gitCommit", + "gitActualCommit", ) } if os.access(test_metadata_file, os.R_OK): - with open(test_metadata_file, 'r') as f: + with open(test_metadata_file, "r") as f: test_metadata = json.loads(f.read()) # eliminate managedFields from metadata set - if 'managedFields' in test_metadata: - del test_metadata['managedFields'] + if "managedFields" in test_metadata: + del test_metadata["managedFields"] - annotations = test_metadata.get('annotations', {}) + annotations = test_metadata.get("annotations", {}) # if use kubectl apply ..., kubectl will append current configuration to # annotation, the field is deleted since it includes a lot of irrelevant # information - if 'kubectl.kubernetes.io/last-applied-configuration' in annotations: - del annotations['kubectl.kubernetes.io/last-applied-configuration'] + if "kubectl.kubernetes.io/last-applied-configuration" in annotations: + del annotations["kubectl.kubernetes.io/last-applied-configuration"] # dump all metadata as JSON to testMetadata field - scenario_result['testMetadata'] = json.dumps(test_metadata) + scenario_result["testMetadata"] = json.dumps(test_metadata) for key, value in _annotation_to_bq_metadata_key_map.items(): if key in annotations: metadata[value] = annotations[key] - scenario_result['metadata'] = metadata + scenario_result["metadata"] = metadata def _populate_node_metadata_from_file(scenario_result, node_info_file): - node_metadata = {'driver': {}, 'servers': [], 'clients': []} + node_metadata = {"driver": {}, "servers": [], "clients": []} _node_info_to_bq_node_metadata_key_map = { - 'Name': 'name', - 'PodIP': 'podIP', - 'NodeName': 'nodeName', + "Name": "name", + "PodIP": "podIP", + "NodeName": "nodeName", } if 
os.access(node_info_file, os.R_OK): - with open(node_info_file, 'r') as f: + with open(node_info_file, "r") as f: file_metadata = json.loads(f.read()) for key, value in _node_info_to_bq_node_metadata_key_map.items(): - node_metadata['driver'][value] = file_metadata['Driver'][key] - for clientNodeInfo in file_metadata['Clients']: - node_metadata['clients'].append({ - value: clientNodeInfo[key] for key, value in - _node_info_to_bq_node_metadata_key_map.items() - }) - for serverNodeInfo in file_metadata['Servers']: - node_metadata['servers'].append({ - value: serverNodeInfo[key] for key, value in - _node_info_to_bq_node_metadata_key_map.items() - }) - - scenario_result['nodeMetadata'] = node_metadata - - -def _populate_prometheus_query_results_from_file(scenario_result, - prometheus_query_result_file): - """Populate the results from Prometheus query to Bigquery table """ + node_metadata["driver"][value] = file_metadata["Driver"][key] + for clientNodeInfo in file_metadata["Clients"]: + node_metadata["clients"].append( + { + value: clientNodeInfo[key] + for key, value in _node_info_to_bq_node_metadata_key_map.items() + } + ) + for serverNodeInfo in file_metadata["Servers"]: + node_metadata["servers"].append( + { + value: serverNodeInfo[key] + for key, value in _node_info_to_bq_node_metadata_key_map.items() + } + ) + + scenario_result["nodeMetadata"] = node_metadata + + +def _populate_prometheus_query_results_from_file( + scenario_result, prometheus_query_result_file +): + """Populate the results from Prometheus query to Bigquery table""" if os.access(prometheus_query_result_file, os.R_OK): - with open(prometheus_query_result_file, 'r', encoding='utf8') as f: + with open(prometheus_query_result_file, "r", encoding="utf8") as f: file_query_results = json.loads(f.read()) - scenario_result['testDurationSeconds'] = file_query_results[ - 'testDurationSeconds'] + scenario_result["testDurationSeconds"] = file_query_results[ + "testDurationSeconds" + ] clientsPrometheusData = [] - if 'clients' in file_query_results: + if "clients" in file_query_results: for client_name, client_data in file_query_results[ - 'clients'].items(): - clientPrometheusData = {'name': client_name} + "clients" + ].items(): + clientPrometheusData = {"name": client_name} containersPrometheusData = [] for container_name, container_data in client_data.items(): containerPrometheusData = { - 'name': container_name, - 'cpuSeconds': container_data['cpuSeconds'], - 'memoryMean': container_data['memoryMean'], + "name": container_name, + "cpuSeconds": container_data["cpuSeconds"], + "memoryMean": container_data["memoryMean"], } containersPrometheusData.append(containerPrometheusData) clientPrometheusData[ - 'containers'] = containersPrometheusData + "containers" + ] = containersPrometheusData clientsPrometheusData.append(clientPrometheusData) - scenario_result['clientsPrometheusData'] = clientsPrometheusData + scenario_result["clientsPrometheusData"] = clientsPrometheusData serversPrometheusData = [] - if 'servers' in file_query_results: + if "servers" in file_query_results: for server_name, server_data in file_query_results[ - 'servers'].items(): - serverPrometheusData = {'name': server_name} + "servers" + ].items(): + serverPrometheusData = {"name": server_name} containersPrometheusData = [] for container_name, container_data in server_data.items(): containerPrometheusData = { - 'name': container_name, - 'cpuSeconds': container_data['cpuSeconds'], - 'memoryMean': container_data['memoryMean'], + "name": container_name, + "cpuSeconds": 
container_data["cpuSeconds"], + "memoryMean": container_data["memoryMean"], } containersPrometheusData.append(containerPrometheusData) serverPrometheusData[ - 'containers'] = containersPrometheusData + "containers" + ] = containersPrometheusData serversPrometheusData.append(serverPrometheusData) - scenario_result['serversPrometheusData'] = serversPrometheusData - - -argp = argparse.ArgumentParser(description='Upload result to big query.') -argp.add_argument('--bq_result_table', - required=True, - default=None, - type=str, - help='Bigquery "dataset.table" to upload results to.') -argp.add_argument('--file_to_upload', - default='scenario_result.json', - type=str, - help='Report file to upload.') -argp.add_argument('--metadata_file_to_upload', - default='metadata.json', - type=str, - help='Metadata file to upload.') -argp.add_argument('--node_info_file_to_upload', - default='node_info.json', - type=str, - help='Node information file to upload.') -argp.add_argument('--prometheus_query_results_to_upload', - default='prometheus_query_result.json', - type=str, - help='Prometheus query result file to upload.') -argp.add_argument('--file_format', - choices=['scenario_result', 'netperf_latency_csv'], - default='scenario_result', - help='Format of the file to upload.') + scenario_result["serversPrometheusData"] = serversPrometheusData + + +argp = argparse.ArgumentParser(description="Upload result to big query.") +argp.add_argument( + "--bq_result_table", + required=True, + default=None, + type=str, + help='Bigquery "dataset.table" to upload results to.', +) +argp.add_argument( + "--file_to_upload", + default="scenario_result.json", + type=str, + help="Report file to upload.", +) +argp.add_argument( + "--metadata_file_to_upload", + default="metadata.json", + type=str, + help="Metadata file to upload.", +) +argp.add_argument( + "--node_info_file_to_upload", + default="node_info.json", + type=str, + help="Node information file to upload.", +) +argp.add_argument( + "--prometheus_query_results_to_upload", + default="prometheus_query_result.json", + type=str, + help="Prometheus query result file to upload.", +) +argp.add_argument( + "--file_format", + choices=["scenario_result", "netperf_latency_csv"], + default="scenario_result", + help="Format of the file to upload.", +) args = argp.parse_args() -dataset_id, table_id = args.bq_result_table.split('.', 2) +dataset_id, table_id = args.bq_result_table.split(".", 2) -if args.file_format == 'netperf_latency_csv': - _upload_netperf_latency_csv_to_bigquery(dataset_id, table_id, - args.file_to_upload) +if args.file_format == "netperf_latency_csv": + _upload_netperf_latency_csv_to_bigquery( + dataset_id, table_id, args.file_to_upload + ) else: - _upload_scenario_result_to_bigquery(dataset_id, table_id, - args.file_to_upload, - args.metadata_file_to_upload, - args.node_info_file_to_upload, - args.prometheus_query_results_to_upload) -print('Successfully uploaded %s, %s, %s and %s to BigQuery.\n' % - (args.file_to_upload, args.metadata_file_to_upload, - args.node_info_file_to_upload, args.prometheus_query_results_to_upload)) + _upload_scenario_result_to_bigquery( + dataset_id, + table_id, + args.file_to_upload, + args.metadata_file_to_upload, + args.node_info_file_to_upload, + args.prometheus_query_results_to_upload, + ) +print( + "Successfully uploaded %s, %s, %s and %s to BigQuery.\n" + % ( + args.file_to_upload, + args.metadata_file_to_upload, + args.node_info_file_to_upload, + args.prometheus_query_results_to_upload, + ) +) diff --git 
a/tools/run_tests/performance/loadtest_concat_yaml.py b/tools/run_tests/performance/loadtest_concat_yaml.py index a6ab10dc40967..2bc1ae006c6ad 100755 --- a/tools/run_tests/performance/loadtest_concat_yaml.py +++ b/tools/run_tests/performance/loadtest_concat_yaml.py @@ -36,30 +36,33 @@ def gen_content_strings(input_files: Iterable[str]) -> Iterable[str]: for input_file in input_files[1:]: with open(input_file) as f: content = f.read() - yield '---\n' + yield "---\n" yield content def main() -> None: - argp = argparse.ArgumentParser(description='Concatenates YAML files.') - argp.add_argument('-i', - '--inputs', - action='extend', - nargs='+', - type=str, - required=True, - help='Input files.') + argp = argparse.ArgumentParser(description="Concatenates YAML files.") argp.add_argument( - '-o', - '--output', + "-i", + "--inputs", + action="extend", + nargs="+", type=str, - help='Concatenated output file. Output to stdout if not set.') + required=True, + help="Input files.", + ) + argp.add_argument( + "-o", + "--output", + type=str, + help="Concatenated output file. Output to stdout if not set.", + ) args = argp.parse_args() - with open(args.output, 'w') if args.output else sys.stdout as f: + with open(args.output, "w") if args.output else sys.stdout as f: for content in gen_content_strings(args.inputs): - print(content, file=f, sep='', end='') + print(content, file=f, sep="", end="") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tools/run_tests/performance/loadtest_config.py b/tools/run_tests/performance/loadtest_config.py index 6f47499072a4b..6c0097011fc7b 100755 --- a/tools/run_tests/performance/loadtest_config.py +++ b/tools/run_tests/performance/loadtest_config.py @@ -52,49 +52,52 @@ def safe_name(language: str) -> str: def default_prefix() -> str: """Constructs and returns a default prefix for LoadTest names.""" - return os.environ.get('USER', 'loadtest') + return os.environ.get("USER", "loadtest") def now_string() -> str: """Returns the current date and time in string format.""" - return datetime.datetime.now().strftime('%Y%m%d%H%M%S') + return datetime.datetime.now().strftime("%Y%m%d%H%M%S") def validate_loadtest_name(name: str) -> None: """Validates that a LoadTest name is in the expected format.""" if len(name) > 253: raise ValueError( - 'LoadTest name must be less than 253 characters long: %s' % name) - if not all(c.isalnum() and not c.isupper() for c in name if c != '-'): - raise ValueError('Invalid characters in LoadTest name: %s' % name) - if not name or not name[0].isalpha() or name[-1] == '-': - raise ValueError('Invalid format for LoadTest name: %s' % name) + "LoadTest name must be less than 253 characters long: %s" % name + ) + if not all(c.isalnum() and not c.isupper() for c in name if c != "-"): + raise ValueError("Invalid characters in LoadTest name: %s" % name) + if not name or not name[0].isalpha() or name[-1] == "-": + raise ValueError("Invalid format for LoadTest name: %s" % name) -def loadtest_base_name(scenario_name: str, - uniquifier_elements: Iterable[str]) -> str: +def loadtest_base_name( + scenario_name: str, uniquifier_elements: Iterable[str] +) -> str: """Constructs and returns the base name for a LoadTest resource.""" - name_elements = scenario_name.split('_') + name_elements = scenario_name.split("_") name_elements.extend(uniquifier_elements) - return '-'.join(element.lower() for element in name_elements) + return "-".join(element.lower() for element in name_elements) -def loadtest_name(prefix: str, scenario_name: str, - 
uniquifier_elements: Iterable[str]) -> str: +def loadtest_name( + prefix: str, scenario_name: str, uniquifier_elements: Iterable[str] +) -> str: """Constructs and returns a valid name for a LoadTest resource.""" base_name = loadtest_base_name(scenario_name, uniquifier_elements) name_elements = [] if prefix: name_elements.append(prefix) name_elements.append(base_name) - name = '-'.join(name_elements) + name = "-".join(name_elements) validate_loadtest_name(name) return name def component_name(elements: Iterable[str]) -> str: """Constructs a component name from possibly empty elements.""" - return '-'.join((e for e in elements if e)) + return "-".join((e for e in elements if e)) def validate_annotations(annotations: Dict[str, str]) -> None: @@ -102,76 +105,83 @@ def validate_annotations(annotations: Dict[str, str]) -> None: These names are automatically added by the config generator. """ - names = set(('scenario', 'uniquifier')).intersection(annotations) + names = set(("scenario", "uniquifier")).intersection(annotations) if names: - raise ValueError('Annotations contain reserved names: %s' % names) + raise ValueError("Annotations contain reserved names: %s" % names) def gen_run_indices(runs_per_test: int) -> Iterable[str]: """Generates run indices for multiple runs, as formatted strings.""" if runs_per_test < 2: - yield '' + yield "" return - index_length = len('{:d}'.format(runs_per_test - 1)) - index_fmt = '{{:0{:d}d}}'.format(index_length) + index_length = len("{:d}".format(runs_per_test - 1)) + index_fmt = "{{:0{:d}d}}".format(index_length) for i in range(runs_per_test): yield index_fmt.format(i) -def scenario_name(base_name: str, client_channels: Optional[int], - server_threads: Optional[int], offered_load: Optional[int]): +def scenario_name( + base_name: str, + client_channels: Optional[int], + server_threads: Optional[int], + offered_load: Optional[int], +): """Constructs scenario name from base name and modifiers.""" elements = [base_name] if client_channels: - elements.append('{:d}channels'.format(client_channels)) + elements.append("{:d}channels".format(client_channels)) if server_threads: - elements.append('{:d}threads'.format(server_threads)) + elements.append("{:d}threads".format(server_threads)) if offered_load: - elements.append('{:d}load'.format(offered_load)) - return '_'.join(elements) + elements.append("{:d}load".format(offered_load)) + return "_".join(elements) def scenario_transform_function( - client_channels: Optional[int], server_threads: Optional[int], - offered_loads: Optional[Iterable[int]] -) -> Optional[Callable[[Iterable[Mapping[str, Any]]], Iterable[Mapping[str, - Any]]]]: + client_channels: Optional[int], + server_threads: Optional[int], + offered_loads: Optional[Iterable[int]], +) -> Optional[ + Callable[[Iterable[Mapping[str, Any]]], Iterable[Mapping[str, Any]]] +]: """Returns a transform to be applied to a list of scenarios.""" if not any((client_channels, server_threads, len(offered_loads))): return lambda s: s def _transform( - scenarios: Iterable[Mapping[str, - Any]]) -> Iterable[Mapping[str, Any]]: + scenarios: Iterable[Mapping[str, Any]] + ) -> Iterable[Mapping[str, Any]]: """Transforms scenarios by inserting num of client channels, number of async_server_threads and offered_load.""" for base_scenario in scenarios: - base_name = base_scenario['name'] + base_name = base_scenario["name"] if client_channels: - base_scenario['client_config'][ - 'client_channels'] = client_channels + base_scenario["client_config"][ + "client_channels" + ] = client_channels 
if server_threads: - base_scenario['server_config'][ - 'async_server_threads'] = server_threads + base_scenario["server_config"][ + "async_server_threads" + ] = server_threads if not offered_loads: - base_scenario['name'] = scenario_name(base_name, - client_channels, - server_threads, 0) + base_scenario["name"] = scenario_name( + base_name, client_channels, server_threads, 0 + ) yield base_scenario return for offered_load in offered_loads: scenario = copy.deepcopy(base_scenario) - scenario['client_config']['load_params'] = { - 'poisson': { - 'offered_load': offered_load - } + scenario["client_config"]["load_params"] = { + "poisson": {"offered_load": offered_load} } - scenario['name'] = scenario_name(base_name, client_channels, - server_threads, offered_load) + scenario["name"] = scenario_name( + base_name, client_channels, server_threads, offered_load + ) yield scenario return _transform @@ -188,8 +198,9 @@ def gen_loadtest_configs( annotations: Mapping[str, str], instances_per_client: int = 1, runs_per_test: int = 1, - scenario_transform: Callable[[Iterable[Mapping[str, Any]]], - List[Dict[str, Any]]] = lambda s: s + scenario_transform: Callable[ + [Iterable[Mapping[str, Any]]], List[Dict[str, Any]] + ] = lambda s: s, ) -> Iterable[Dict[str, Any]]: """Generates LoadTest configurations for a given language config. @@ -203,102 +214,121 @@ def gen_loadtest_configs( scenario_name_regex=scenario_name_regex, category=language_config.category, client_language=language_config.client_language, - server_language=language_config.server_language) + server_language=language_config.server_language, + ) scenarios = scenario_transform( - scenario_config_exporter.gen_scenarios(language_config.language, - scenario_filter)) + scenario_config_exporter.gen_scenarios( + language_config.language, scenario_filter + ) + ) for scenario in scenarios: for run_index in gen_run_indices(runs_per_test): - uniq = (uniquifier_elements + - [run_index] if run_index else uniquifier_elements) - name = loadtest_name(prefix, scenario['name'], uniq) - scenario_str = json.dumps({'scenarios': scenario}, - indent=' ') + '\n' + uniq = ( + uniquifier_elements + [run_index] + if run_index + else uniquifier_elements + ) + name = loadtest_name(prefix, scenario["name"], uniq) + scenario_str = ( + json.dumps({"scenarios": scenario}, indent=" ") + "\n" + ) config = copy.deepcopy(base_config) - metadata = config['metadata'] - metadata['name'] = name - if 'labels' not in metadata: - metadata['labels'] = dict() - metadata['labels']['language'] = safe_name(language_config.language) - metadata['labels']['prefix'] = prefix - if 'annotations' not in metadata: - metadata['annotations'] = dict() - metadata['annotations'].update(annotations) - metadata['annotations'].update({ - 'scenario': scenario['name'], - 'uniquifier': '-'.join(uniq), - }) - - spec = config['spec'] + metadata = config["metadata"] + metadata["name"] = name + if "labels" not in metadata: + metadata["labels"] = dict() + metadata["labels"]["language"] = safe_name(language_config.language) + metadata["labels"]["prefix"] = prefix + if "annotations" not in metadata: + metadata["annotations"] = dict() + metadata["annotations"].update(annotations) + metadata["annotations"].update( + { + "scenario": scenario["name"], + "uniquifier": "-".join(uniq), + } + ) + + spec = config["spec"] # Select clients with the required language. 
clients = [ - client for client in base_config_clients - if client['language'] == cl + client + for client in base_config_clients + if client["language"] == cl ] if not clients: - raise IndexError('Client language not found in template: %s' % - cl) + raise IndexError( + "Client language not found in template: %s" % cl + ) # Validate config for additional client instances. if instances_per_client > 1: c = collections.Counter( - (client.get('name', '') for client in clients)) + (client.get("name", "") for client in clients) + ) if max(c.values()) > 1: raise ValueError( - ('Multiple instances of multiple clients requires ' - 'unique names, name counts for language %s: %s') % - (cl, c.most_common())) + "Multiple instances of multiple clients requires " + "unique names, name counts for language %s: %s" + % (cl, c.most_common()) + ) # Name client instances with an index starting from zero. client_instances = [] for i in range(instances_per_client): client_instances.extend(copy.deepcopy(clients)) - for client in client_instances[-len(clients):]: - client['name'] = component_name((client.get('name', - ''), str(i))) + for client in client_instances[-len(clients) :]: + client["name"] = component_name( + (client.get("name", ""), str(i)) + ) # Set clients to named instances. - spec['clients'] = client_instances + spec["clients"] = client_instances # Select servers with the required language. - servers = copy.deepcopy([ - server for server in base_config_servers - if server['language'] == sl - ]) + servers = copy.deepcopy( + [ + server + for server in base_config_servers + if server["language"] == sl + ] + ) if not servers: - raise IndexError('Server language not found in template: %s' % - sl) + raise IndexError( + "Server language not found in template: %s" % sl + ) # Name servers with an index for consistency with clients. for i, server in enumerate(servers): - server['name'] = component_name((server.get('name', - ''), str(i))) + server["name"] = component_name( + (server.get("name", ""), str(i)) + ) # Set servers to named instances. - spec['servers'] = servers + spec["servers"] = servers # Add driver, if needed. - if 'driver' not in spec: - spec['driver'] = dict() + if "driver" not in spec: + spec["driver"] = dict() # Ensure driver has language and run fields. - driver = spec['driver'] - if 'language' not in driver: - driver['language'] = safe_name('c++') - if 'run' not in driver: - driver['run'] = dict() + driver = spec["driver"] + if "language" not in driver: + driver["language"] = safe_name("c++") + if "run" not in driver: + driver["run"] = dict() # Name the driver with an index for consistency with workers. # There is only one driver, so the index is zero. 
- if 'name' not in driver or not driver['name']: - driver['name'] = '0' + if "name" not in driver or not driver["name"]: + driver["name"] = "0" - spec['scenariosJSON'] = scenario_str + spec["scenariosJSON"] = scenario_str yield config @@ -309,41 +339,44 @@ def parse_key_value_args(args: Optional[Iterable[str]]) -> Dict[str, str]: if args is None: return d for arg in args: - key, equals, value = arg.partition('=') - if equals != '=': - raise ValueError('Expected key=value: ' + value) + key, equals, value = arg.partition("=") + if equals != "=": + raise ValueError("Expected key=value: " + value) d[key] = value return d def clear_empty_fields(config: Dict[str, Any]) -> None: """Clears fields set to empty values by string substitution.""" - spec = config['spec'] - if 'clients' in spec: - for client in spec['clients']: - if 'pool' in client and not client['pool']: - del client['pool'] - if 'servers' in spec: - for server in spec['servers']: - if 'pool' in server and not server['pool']: - del server['pool'] - if 'driver' in spec: - driver = spec['driver'] - if 'pool' in driver and not driver['pool']: - del driver['pool'] - if ('run' in driver and 'image' in driver['run'] and - not driver['run']['image']): - del driver['run']['image'] - if 'results' in spec and not ('bigQueryTable' in spec['results'] and - spec['results']['bigQueryTable']): - del spec['results'] + spec = config["spec"] + if "clients" in spec: + for client in spec["clients"]: + if "pool" in client and not client["pool"]: + del client["pool"] + if "servers" in spec: + for server in spec["servers"]: + if "pool" in server and not server["pool"]: + del server["pool"] + if "driver" in spec: + driver = spec["driver"] + if "pool" in driver and not driver["pool"]: + del driver["pool"] + if ( + "run" in driver + and "image" in driver["run"] + and not driver["run"]["image"] + ): + del driver["run"]["image"] + if "results" in spec and not ( + "bigQueryTable" in spec["results"] and spec["results"]["bigQueryTable"] + ): + del spec["results"] def config_dumper(header_comment: str) -> Type[yaml.SafeDumper]: """Returns a custom dumper to dump configurations in the expected format.""" class ConfigDumper(yaml.SafeDumper): - def expect_stream_start(self): super().expect_stream_start() if isinstance(self.event, yaml.StreamStartEvent): @@ -351,11 +384,11 @@ def expect_stream_start(self): self.write_indicator(header_comment, need_whitespace=False) def str_presenter(dumper, data): - if '\n' in data: - return dumper.represent_scalar('tag:yaml.org,2002:str', - data, - style='|') - return dumper.represent_scalar('tag:yaml.org,2002:str', data) + if "\n" in data: + return dumper.represent_scalar( + "tag:yaml.org,2002:str", data, style="|" + ) + return dumper.represent_scalar("tag:yaml.org,2002:str", data) ConfigDumper.add_representer(str, str_presenter) @@ -365,110 +398,133 @@ def str_presenter(dumper, data): def main() -> None: language_choices = sorted(scenario_config.LANGUAGES.keys()) argp = argparse.ArgumentParser( - description='Generates load test configs from a template.', - fromfile_prefix_chars='@') - argp.add_argument('-l', - '--language', - action='append', - choices=language_choices, - required=True, - help='Language(s) to benchmark.', - dest='languages') - argp.add_argument('-t', - '--template', - type=str, - required=True, - help='LoadTest configuration yaml file template.') - argp.add_argument('-s', - '--substitution', - action='append', - default=[], - help='Template substitution(s), in the form key=value.', - dest='substitutions') - 
argp.add_argument('-p', - '--prefix', - default='', - type=str, - help='Test name prefix.') - argp.add_argument('-u', - '--uniquifier_element', - action='append', - default=[], - help='String element(s) to make the test name unique.', - dest='uniquifier_elements') + description="Generates load test configs from a template.", + fromfile_prefix_chars="@", + ) argp.add_argument( - '-d', - action='store_true', - help='Use creation date and time as an additional uniquifier element.') - argp.add_argument('-a', - '--annotation', - action='append', - default=[], - help='metadata.annotation(s), in the form key=value.', - dest='annotations') - argp.add_argument('-r', - '--regex', - default='.*', - type=str, - help='Regex to select scenarios to run.') + "-l", + "--language", + action="append", + choices=language_choices, + required=True, + help="Language(s) to benchmark.", + dest="languages", + ) argp.add_argument( - '--category', - choices=['all', 'inproc', 'scalable', 'smoketest', 'sweep', 'psm'], - default='all', - help='Select a category of tests to run.') + "-t", + "--template", + type=str, + required=True, + help="LoadTest configuration yaml file template.", + ) argp.add_argument( - '--allow_client_language', - action='append', + "-s", + "--substitution", + action="append", + default=[], + help="Template substitution(s), in the form key=value.", + dest="substitutions", + ) + argp.add_argument( + "-p", "--prefix", default="", type=str, help="Test name prefix." + ) + argp.add_argument( + "-u", + "--uniquifier_element", + action="append", + default=[], + help="String element(s) to make the test name unique.", + dest="uniquifier_elements", + ) + argp.add_argument( + "-d", + action="store_true", + help="Use creation date and time as an additional uniquifier element.", + ) + argp.add_argument( + "-a", + "--annotation", + action="append", + default=[], + help="metadata.annotation(s), in the form key=value.", + dest="annotations", + ) + argp.add_argument( + "-r", + "--regex", + default=".*", + type=str, + help="Regex to select scenarios to run.", + ) + argp.add_argument( + "--category", + choices=["all", "inproc", "scalable", "smoketest", "sweep", "psm"], + default="all", + help="Select a category of tests to run.", + ) + argp.add_argument( + "--allow_client_language", + action="append", choices=language_choices, default=[], - help='Allow cross-language scenarios with this client language.', - dest='allow_client_languages') + help="Allow cross-language scenarios with this client language.", + dest="allow_client_languages", + ) argp.add_argument( - '--allow_server_language', - action='append', + "--allow_server_language", + action="append", choices=language_choices, default=[], - help='Allow cross-language scenarios with this server language.', - dest='allow_server_languages') - argp.add_argument('--instances_per_client', - default=1, - type=int, - help="Number of instances to generate for each client.") - argp.add_argument('--runs_per_test', - default=1, - type=int, - help='Number of copies to generate for each test.') - argp.add_argument('-o', - '--output', - type=str, - help='Output file name. 
Output to stdout if not set.') - argp.add_argument('--client_channels', - type=int, - help='Number of client channels.') - argp.add_argument('--server_threads', - type=int, - help='Number of async server threads.') + help="Allow cross-language scenarios with this server language.", + dest="allow_server_languages", + ) + argp.add_argument( + "--instances_per_client", + default=1, + type=int, + help="Number of instances to generate for each client.", + ) argp.add_argument( - '--offered_loads', + "--runs_per_test", + default=1, + type=int, + help="Number of copies to generate for each test.", + ) + argp.add_argument( + "-o", + "--output", + type=str, + help="Output file name. Output to stdout if not set.", + ) + argp.add_argument( + "--client_channels", type=int, help="Number of client channels." + ) + argp.add_argument( + "--server_threads", type=int, help="Number of async server threads." + ) + argp.add_argument( + "--offered_loads", nargs="*", type=int, default=[], - help='A list of QPS values at which each load test scenario will be run.' + help=( + "A list of QPS values at which each load test scenario will be run." + ), ) args = argp.parse_args() if args.instances_per_client < 1: - argp.error('instances_per_client must be greater than zero.') + argp.error("instances_per_client must be greater than zero.") if args.runs_per_test < 1: - argp.error('runs_per_test must be greater than zero.') + argp.error("runs_per_test must be greater than zero.") # Config generation ignores environment variables that are passed by the # controller at runtime. substitutions = { - 'DRIVER_PORT': '${DRIVER_PORT}', - 'KILL_AFTER': '${KILL_AFTER}', - 'POD_TIMEOUT': '${POD_TIMEOUT}', + "DRIVER_PORT": "${DRIVER_PORT}", + "KILL_AFTER": "${KILL_AFTER}", + "POD_TIMEOUT": "${POD_TIMEOUT}", } # The user can override the ignored variables above by passing them in as @@ -481,53 +537,60 @@ def main() -> None: annotations = parse_key_value_args(args.annotations) - transform = scenario_transform_function(args.client_channels, - args.server_threads, - args.offered_loads) + transform = scenario_transform_function( + args.client_channels, args.server_threads, args.offered_loads + ) with open(args.template) as f: base_config = yaml.safe_load( - string.Template(f.read()).substitute(substitutions)) + string.Template(f.read()).substitute(substitutions) + ) clear_empty_fields(base_config) - spec = base_config['spec'] - base_config_clients = spec['clients'] - del spec['clients'] - base_config_servers = spec['servers'] - del spec['servers'] + spec = base_config["spec"] + base_config_clients = spec["clients"] + del spec["clients"] + base_config_servers = spec["servers"] + del spec["servers"] - client_languages = [''] + args.allow_client_languages - server_languages = [''] + args.allow_server_languages + client_languages = [""] + args.allow_client_languages + server_languages = [""] + args.allow_server_languages config_generators = [] - for l, cl, sl in itertools.product(args.languages, client_languages, - server_languages): + for l, cl, sl in itertools.product( + args.languages, client_languages, server_languages + ): language_config = scenario_config_exporter.LanguageConfig( category=args.category, language=l, client_language=cl, - server_language=sl) + server_language=sl, + ) config_generators.append( - gen_loadtest_configs(base_config, - base_config_clients, - base_config_servers, - args.regex, - language_config, - loadtest_name_prefix=args.prefix, - uniquifier_elements=uniquifier_elements, - annotations=annotations, - 
instances_per_client=args.instances_per_client, - runs_per_test=args.runs_per_test, - scenario_transform=transform)) + gen_loadtest_configs( + base_config, + base_config_clients, + base_config_servers, + args.regex, + language_config, + loadtest_name_prefix=args.prefix, + uniquifier_elements=uniquifier_elements, + annotations=annotations, + instances_per_client=args.instances_per_client, + runs_per_test=args.runs_per_test, + scenario_transform=transform, + ) + ) configs = (config for config in itertools.chain(*config_generators)) - with open(args.output, 'w') if args.output else sys.stdout as f: - yaml.dump_all(configs, - stream=f, - Dumper=config_dumper( - CONFIGURATION_FILE_HEADER_COMMENT.strip()), - default_flow_style=False) + with open(args.output, "w") if args.output else sys.stdout as f: + yaml.dump_all( + configs, + stream=f, + Dumper=config_dumper(CONFIGURATION_FILE_HEADER_COMMENT.strip()), + default_flow_style=False, + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tools/run_tests/performance/loadtest_template.py b/tools/run_tests/performance/loadtest_template.py index 16461e3c06f40..6cdf55a93fdaa 100755 --- a/tools/run_tests/performance/loadtest_template.py +++ b/tools/run_tests/performance/loadtest_template.py @@ -49,8 +49,9 @@ """ -def insert_worker(worker: Dict[str, Any], workers: List[Dict[str, - Any]]) -> None: +def insert_worker( + worker: Dict[str, Any], workers: List[Dict[str, Any]] +) -> None: """Inserts client or server into a list, without inserting duplicates.""" def dump(w): @@ -68,7 +69,7 @@ def uniquify_workers(workermap: Dict[str, List[Dict[str, Any]]]) -> None: if len(workers) <= 1: continue for i, worker in enumerate(workers): - worker['name'] = str(i) + worker["name"] = str(i) def loadtest_template( @@ -80,89 +81,99 @@ def loadtest_template( inject_server_pool: bool, inject_big_query_table: bool, inject_timeout_seconds: bool, - inject_ttl_seconds: bool) -> Dict[str, Any]: # yapf: disable + inject_ttl_seconds: bool) -> Dict[str, Any]: # fmt: skip """Generates the load test template.""" spec = dict() # type: Dict[str, Any] clientmap = dict() # Dict[str, List[Dict[str, Any]]] servermap = dict() # Dict[Str, List[Dict[str, Any]]] template = { - 'apiVersion': 'e2etest.grpc.io/v1', - 'kind': 'LoadTest', - 'metadata': metadata, + "apiVersion": "e2etest.grpc.io/v1", + "kind": "LoadTest", + "metadata": metadata, } for input_file_name in input_file_names: with open(input_file_name) as f: input_config = yaml.safe_load(f.read()) - if input_config.get('apiVersion') != template['apiVersion']: - raise ValueError('Unexpected api version in file {}: {}'.format( - input_file_name, input_config.get('apiVersion'))) - if input_config.get('kind') != template['kind']: - raise ValueError('Unexpected kind in file {}: {}'.format( - input_file_name, input_config.get('kind'))) - - for client in input_config['spec']['clients']: - del client['name'] + if input_config.get("apiVersion") != template["apiVersion"]: + raise ValueError( + "Unexpected api version in file {}: {}".format( + input_file_name, input_config.get("apiVersion") + ) + ) + if input_config.get("kind") != template["kind"]: + raise ValueError( + "Unexpected kind in file {}: {}".format( + input_file_name, input_config.get("kind") + ) + ) + + for client in input_config["spec"]["clients"]: + del client["name"] if inject_client_pool: - client['pool'] = '${client_pool}' - if client['language'] not in clientmap: - clientmap[client['language']] = [] - insert_worker(client, clientmap[client['language']]) + 
client["pool"] = "${client_pool}" + if client["language"] not in clientmap: + clientmap[client["language"]] = [] + insert_worker(client, clientmap[client["language"]]) - for server in input_config['spec']['servers']: - del server['name'] + for server in input_config["spec"]["servers"]: + del server["name"] if inject_server_pool: - server['pool'] = '${server_pool}' - if server['language'] not in servermap: - servermap[server['language']] = [] - insert_worker(server, servermap[server['language']]) + server["pool"] = "${server_pool}" + if server["language"] not in servermap: + servermap[server["language"]] = [] + insert_worker(server, servermap[server["language"]]) - input_spec = input_config['spec'] - del input_spec['clients'] - del input_spec['servers'] - del input_spec['scenariosJSON'] - spec.update(input_config['spec']) + input_spec = input_config["spec"] + del input_spec["clients"] + del input_spec["servers"] + del input_spec["scenariosJSON"] + spec.update(input_config["spec"]) uniquify_workers(clientmap) uniquify_workers(servermap) - spec.update({ - 'clients': - sum((clientmap[language] for language in sorted(clientmap)), - start=[]), - 'servers': - sum((servermap[language] for language in sorted(servermap)), - start=[]), - }) - - if 'driver' not in spec: - spec['driver'] = {'language': 'cxx'} - - driver = spec['driver'] - if 'name' in driver: - del driver['name'] + spec.update( + { + "clients": sum( + (clientmap[language] for language in sorted(clientmap)), + start=[], + ), + "servers": sum( + (servermap[language] for language in sorted(servermap)), + start=[], + ), + } + ) + + if "driver" not in spec: + spec["driver"] = {"language": "cxx"} + + driver = spec["driver"] + if "name" in driver: + del driver["name"] if inject_driver_image: - if 'run' not in driver: - driver['run'] = [{'name': 'main'}] - driver['run'][0]['image'] = '${driver_image}' + if "run" not in driver: + driver["run"] = [{"name": "main"}] + driver["run"][0]["image"] = "${driver_image}" if inject_driver_pool: - driver['pool'] = '${driver_pool}' + driver["pool"] = "${driver_pool}" - if 'run' not in driver: + if "run" not in driver: if inject_driver_pool: - raise ValueError('Cannot inject driver.pool: missing driver.run.') - del spec['driver'] + raise ValueError("Cannot inject driver.pool: missing driver.run.") + del spec["driver"] if inject_big_query_table: - if 'results' not in spec: - spec['results'] = dict() - spec['results']['bigQueryTable'] = '${big_query_table}' + if "results" not in spec: + spec["results"] = dict() + spec["results"]["bigQueryTable"] = "${big_query_table}" if inject_timeout_seconds: - spec['timeoutSeconds'] = '${timeout_seconds}' + spec["timeoutSeconds"] = "${timeout_seconds}" if inject_ttl_seconds: - spec['ttlSeconds'] = '${ttl_seconds}' + spec["ttlSeconds"] = "${ttl_seconds}" - template['spec'] = spec + template["spec"] = spec return template @@ -171,7 +182,6 @@ def template_dumper(header_comment: str) -> Type[yaml.SafeDumper]: """Returns a custom dumper to dump templates in the expected format.""" class TemplateDumper(yaml.SafeDumper): - def expect_stream_start(self): super().expect_stream_start() if isinstance(self.event, yaml.StreamStartEvent): @@ -179,11 +189,11 @@ def expect_stream_start(self): self.write_indicator(header_comment, need_whitespace=False) def str_presenter(dumper, data): - if '\n' in data: - return dumper.represent_scalar('tag:yaml.org,2002:str', - data, - style='|') - return dumper.represent_scalar('tag:yaml.org,2002:str', data) + if "\n" in data: + return 
dumper.represent_scalar( + "tag:yaml.org,2002:str", data, style="|" + ) + return dumper.represent_scalar("tag:yaml.org,2002:str", data) TemplateDumper.add_representer(str, str_presenter) @@ -192,62 +202,74 @@ def str_presenter(dumper, data): def main() -> None: argp = argparse.ArgumentParser( - description='Creates a load test config generator template.', - fromfile_prefix_chars='@') - argp.add_argument('-i', - '--inputs', - action='extend', - nargs='+', - type=str, - help='Input files.') - argp.add_argument('-o', - '--output', - type=str, - help='Output file. Outputs to stdout if not set.') + description="Creates a load test config generator template.", + fromfile_prefix_chars="@", + ) + argp.add_argument( + "-i", + "--inputs", + action="extend", + nargs="+", + type=str, + help="Input files.", + ) + argp.add_argument( + "-o", + "--output", + type=str, + help="Output file. Outputs to stdout if not set.", + ) + argp.add_argument( + "--inject_client_pool", + action="store_true", + help="Set spec.client(s).pool values to '${client_pool}'.", + ) + argp.add_argument( + "--inject_driver_image", + action="store_true", + help="Set spec.driver(s).image values to '${driver_image}'.", + ) + argp.add_argument( + "--inject_driver_pool", + action="store_true", + help="Set spec.driver(s).pool values to '${driver_pool}'.", + ) + argp.add_argument( + "--inject_server_pool", + action="store_true", + help="Set spec.server(s).pool values to '${server_pool}'.", + ) argp.add_argument( - '--inject_client_pool', - action='store_true', - help='Set spec.client(s).pool values to \'${client_pool}\'.') + "--inject_big_query_table", + action="store_true", + help="Set spec.results.bigQueryTable to '${big_query_table}'.", + ) argp.add_argument( - '--inject_driver_image', - action='store_true', - help='Set spec.driver(s).image values to \'${driver_image}\'.') + "--inject_timeout_seconds", + action="store_true", + help="Set spec.timeoutSeconds to '${timeout_seconds}'.", + ) argp.add_argument( - '--inject_driver_pool', - action='store_true', - help='Set spec.driver(s).pool values to \'${driver_pool}\'.') + "--inject_ttl_seconds", action="store_true", help="Set timeout " + ) argp.add_argument( - '--inject_server_pool', - action='store_true', - help='Set spec.server(s).pool values to \'${server_pool}\'.') + "-n", "--name", default="", type=str, help="metadata.name." 
+ ) argp.add_argument( - '--inject_big_query_table', - action='store_true', - help='Set spec.results.bigQueryTable to \'${big_query_table}\'.') - argp.add_argument('--inject_timeout_seconds', - action='store_true', - help='Set spec.timeoutSeconds to \'${timeout_seconds}\'.') - argp.add_argument('--inject_ttl_seconds', - action='store_true', - help='Set timeout ') - argp.add_argument('-n', - '--name', - default='', - type=str, - help='metadata.name.') - argp.add_argument('-a', - '--annotation', - action='append', - type=str, - help='metadata.annotation(s), in the form key=value.', - dest='annotations') + "-a", + "--annotation", + action="append", + type=str, + help="metadata.annotation(s), in the form key=value.", + dest="annotations", + ) args = argp.parse_args() annotations = loadtest_config.parse_key_value_args(args.annotations) - metadata = {'name': args.name} + metadata = {"name": args.name} if annotations: - metadata['annotations'] = annotations + metadata["annotations"] = annotations template = loadtest_template( input_file_names=args.inputs, @@ -258,14 +280,17 @@ def main() -> None: inject_server_pool=args.inject_server_pool, inject_big_query_table=args.inject_big_query_table, inject_timeout_seconds=args.inject_timeout_seconds, - inject_ttl_seconds=args.inject_ttl_seconds) + inject_ttl_seconds=args.inject_ttl_seconds, + ) - with open(args.output, 'w') if args.output else sys.stdout as f: - yaml.dump(template, - stream=f, - Dumper=template_dumper(TEMPLATE_FILE_HEADER_COMMENT.strip()), - default_flow_style=False) + with open(args.output, "w") if args.output else sys.stdout as f: + yaml.dump( + template, + stream=f, + Dumper=template_dumper(TEMPLATE_FILE_HEADER_COMMENT.strip()), + default_flow_style=False, + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tools/run_tests/performance/patch_scenario_results_schema.py b/tools/run_tests/performance/patch_scenario_results_schema.py index 5d38550e93673..44b23da6b4a1e 100755 --- a/tools/run_tests/performance/patch_scenario_results_schema.py +++ b/tools/run_tests/performance/patch_scenario_results_schema.py @@ -26,34 +26,40 @@ import uuid gcp_utils_dir = os.path.abspath( - os.path.join(os.path.dirname(__file__), '../../gcp/utils')) + os.path.join(os.path.dirname(__file__), "../../gcp/utils") +) sys.path.append(gcp_utils_dir) import big_query_utils -_PROJECT_ID = 'grpc-testing' +_PROJECT_ID = "grpc-testing" def _patch_results_table(dataset_id, table_id): bq = big_query_utils.create_big_query() - with open(os.path.dirname(__file__) + '/scenario_result_schema.json', - 'r') as f: + with open( + os.path.dirname(__file__) + "/scenario_result_schema.json", "r" + ) as f: table_schema = json.loads(f.read()) - desc = 'Results of performance benchmarks.' - return big_query_utils.patch_table(bq, _PROJECT_ID, dataset_id, table_id, - table_schema) + desc = "Results of performance benchmarks." + return big_query_utils.patch_table( + bq, _PROJECT_ID, dataset_id, table_id, table_schema + ) argp = argparse.ArgumentParser( - description='Patch schema of scenario results table.') -argp.add_argument('--bq_result_table', - required=True, - default=None, - type=str, - help='Bigquery "dataset.table" to patch.') + description="Patch schema of scenario results table." 
+) +argp.add_argument( + "--bq_result_table", + required=True, + default=None, + type=str, + help='Bigquery "dataset.table" to patch.', +) args = argp.parse_args() -dataset_id, table_id = args.bq_result_table.split('.', 2) +dataset_id, table_id = args.bq_result_table.split(".", 2) _patch_results_table(dataset_id, table_id) -print('Successfully patched schema of %s.\n' % args.bq_result_table) +print("Successfully patched schema of %s.\n" % args.bq_result_table) diff --git a/tools/run_tests/performance/prometheus.py b/tools/run_tests/performance/prometheus.py index df4fe31999401..ced87fc47573e 100644 --- a/tools/run_tests/performance/prometheus.py +++ b/tools/run_tests/performance/prometheus.py @@ -59,19 +59,15 @@ def _fetch_by_query(self, query: str) -> Dict[str, Any]: a time series. """ resp = requests.get( - self.url + '/api/v1/query_range', - { - 'query': query, - 'start': self.start, - 'end': self.end, - 'step': 5 - }, + self.url + "/api/v1/query_range", + {"query": query, "start": self.start, "end": self.end, "step": 5}, ) resp.raise_for_status() return resp.json() - def _fetch_cpu_for_pod(self, container_matcher: str, - pod_name: str) -> Dict[str, List[float]]: + def _fetch_cpu_for_pod( + self, container_matcher: str, pod_name: str + ) -> Dict[str, List[float]]: """Fetches the cpu data for each pod. Fetch total cpu seconds during the time range specified in the Prometheus instance @@ -83,16 +79,22 @@ def _fetch_cpu_for_pod(self, container_matcher: str, """ query = ( 'container_cpu_usage_seconds_total{job="kubernetes-cadvisor",pod="' - + pod_name + '",container=' + container_matcher + '}') - logging.debug('running prometheus query for cpu: %s', query) + + pod_name + + '",container=' + + container_matcher + + "}" + ) + logging.debug("running prometheus query for cpu: %s", query) cpu_data = self._fetch_by_query(query) - logging.debug('raw cpu data: %s', str(cpu_data)) + logging.debug("raw cpu data: %s", str(cpu_data)) cpu_container_name_to_data_list = get_data_list_from_timeseries( - cpu_data) + cpu_data + ) return cpu_container_name_to_data_list - def _fetch_memory_for_pod(self, container_matcher: str, - pod_name: str) -> Dict[str, List[float]]: + def _fetch_memory_for_pod( + self, container_matcher: str, pod_name: str + ) -> Dict[str, List[float]]: """Fetches memory data for each pod. Fetch total memory data during the time range specified in the Prometheus instance @@ -103,21 +105,26 @@ def _fetch_memory_for_pod(self, container_matcher: str, container_matcher: A string consist one or more container name separated by |. """ query = ( - 'container_memory_usage_bytes{job="kubernetes-cadvisor",pod="' + - pod_name + '",container=' + container_matcher + "}") + 'container_memory_usage_bytes{job="kubernetes-cadvisor",pod="' + + pod_name + + '",container=' + + container_matcher + + "}" + ) - logging.debug('running prometheus query for memory: %s', query) + logging.debug("running prometheus query for memory: %s", query) memory_data = self._fetch_by_query(query) - logging.debug('raw memory data: %s', str(memory_data)) + logging.debug("raw memory data: %s", str(memory_data)) memory_container_name_to_data_list = get_data_list_from_timeseries( - memory_data) + memory_data + ) return memory_container_name_to_data_list def fetch_cpu_and_memory_data( - self, container_list: List[str], - pod_dict: Dict[str, List[str]]) -> Dict[str, Any]: + self, container_list: List[str], pod_dict: Dict[str, List[str]] + ) -> Dict[str, Any]: """Fetch total cpu seconds and memory data for multiple pods. 
Args: @@ -134,15 +141,19 @@ def fetch_cpu_and_memory_data( for pod in pod_names: container_data = {} for container, data in self._fetch_cpu_for_pod( - container_matcher, pod).items(): + container_matcher, pod + ).items(): container_data[container] = {} container_data[container][ - 'cpuSeconds'] = compute_total_cpu_seconds(data) + "cpuSeconds" + ] = compute_total_cpu_seconds(data) for container, data in self._fetch_memory_for_pod( - container_matcher, pod).items(): + container_matcher, pod + ).items(): container_data[container][ - 'memoryMean'] = compute_average_memory_usage(data) + "memoryMean" + ] = compute_average_memory_usage(data) pod_data[pod] = container_data processed_data[role] = pod_data @@ -153,7 +164,7 @@ def construct_container_matcher(container_list: List[str]) -> str: """Constructs the container matching string used in the prometheus query.""" if len(container_list) == 0: - raise Exception('no container name provided') + raise Exception("no container name provided") containers_to_fetch = '"' if len(container_list) == 1: @@ -161,7 +172,7 @@ def construct_container_matcher(container_list: List[str]) -> str: else: containers_to_fetch = '~"' + container_list[0] for container in container_list[1:]: - containers_to_fetch = containers_to_fetch + '|' + container + containers_to_fetch = containers_to_fetch + "|" + container containers_to_fetch = containers_to_fetch + '"' return containers_to_fetch @@ -169,11 +180,12 @@ def construct_container_matcher(container_list: List[str]) -> str: def get_data_list_from_timeseries(data: Any) -> Dict[str, List[float]]: """Constructs a Dict as keys are the container names and values are a list of data taken from given timeseries data.""" - if data['status'] != 'success': - raise Exception('command failed: ' + data['status'] + str(data)) - if data['data']['resultType'] != 'matrix': - raise Exception('resultType is not matrix: ' + - data['data']['resultType']) + if data["status"] != "success": + raise Exception("command failed: " + data["status"] + str(data)) + if data["data"]["resultType"] != "matrix": + raise Exception( + "resultType is not matrix: " + data["data"]["resultType"] + ) container_name_to_data_list = {} for res in data["data"]["result"]: @@ -197,25 +209,26 @@ def compute_average_memory_usage(memory_data_list: List[float]) -> float: return statistics.mean(memory_data_list) -def construct_pod_dict(node_info_file: str, - pod_types: List[str]) -> Dict[str, List[str]]: +def construct_pod_dict( + node_info_file: str, pod_types: List[str] +) -> Dict[str, List[str]]: """Constructs a dict of pod names to be queried. Args: node_info_file: The file path contains the pod names to query. - The pods' names are put into a Dict of list that keyed by the + The pods' names are put into a Dict of list that keyed by the role name: clients, servers and driver. 
""" - with open(node_info_file, 'r') as f: + with open(node_info_file, "r") as f: pod_names = json.load(f) - pod_type_to_name = {'clients': [], 'driver': [], 'servers': []} + pod_type_to_name = {"clients": [], "driver": [], "servers": []} - for client in pod_names['Clients']: - pod_type_to_name['clients'].append(client['Name']) - for server in pod_names['Servers']: - pod_type_to_name['servers'].append(server['Name']) + for client in pod_names["Clients"]: + pod_type_to_name["clients"].append(client["Name"]) + for server in pod_names["Servers"]: + pod_type_to_name["servers"].append(server["Name"]) - pod_type_to_name["driver"].append(pod_names['Driver']['Name']) + pod_type_to_name["driver"].append(pod_names["Driver"]["Name"]) pod_names_to_query = {} for pod_type in pod_types: @@ -226,67 +239,73 @@ def construct_pod_dict(node_info_file: str, def convert_UTC_to_epoch(utc_timestamp: str) -> str: """Converts a utc timestamp string to epoch time string.""" parsed_time = parser.parse(utc_timestamp) - epoch = parsed_time.strftime('%s') + epoch = parsed_time.strftime("%s") return epoch def main() -> None: argp = argparse.ArgumentParser( - description='Fetch cpu and memory stats from prometheus') - argp.add_argument('--url', help='Prometheus base url', required=True) + description="Fetch cpu and memory stats from prometheus" + ) + argp.add_argument("--url", help="Prometheus base url", required=True) argp.add_argument( - '--scenario_result_file', - default='scenario_result.json', + "--scenario_result_file", + default="scenario_result.json", type=str, - help='File contains epoch seconds for start and end time', + help="File contains epoch seconds for start and end time", ) argp.add_argument( - '--node_info_file', - default='/var/data/qps_workers/node_info.json', - help='File contains pod name to query the metrics for', + "--node_info_file", + default="/var/data/qps_workers/node_info.json", + help="File contains pod name to query the metrics for", ) argp.add_argument( - '--pod_type', - action='append', - help= - 'Pod type to query the metrics for, the options are driver, client and server', - choices=['driver', 'clients', 'servers'], + "--pod_type", + action="append", + help=( + "Pod type to query the metrics for, the options are driver, client" + " and server" + ), + choices=["driver", "clients", "servers"], required=True, ) argp.add_argument( - '--container_name', - action='append', - help='The container names to query the metrics for', + "--container_name", + action="append", + help="The container names to query the metrics for", required=True, ) argp.add_argument( - '--export_file_name', - default='prometheus_query_result.json', + "--export_file_name", + default="prometheus_query_result.json", type=str, - help='Name of exported JSON file.', + help="Name of exported JSON file.", ) argp.add_argument( - '--quiet', + "--quiet", default=False, - help='Suppress informative output', + help="Suppress informative output", ) argp.add_argument( - '--delay_seconds', + "--delay_seconds", default=0, type=int, - help= - 'Configure delay in seconds to perform Prometheus queries, default is 0', + help=( + "Configure delay in seconds to perform Prometheus queries, default" + " is 0" + ), ) args = argp.parse_args() if not args.quiet: logging.getLogger().setLevel(logging.DEBUG) - with open(args.scenario_result_file, 'r') as q: + with open(args.scenario_result_file, "r") as q: scenario_result = json.load(q) start_time = convert_UTC_to_epoch( - scenario_result['summary']['startTime']) - end_time = 
convert_UTC_to_epoch(scenario_result['summary']['endTime']) + scenario_result["summary"]["startTime"] + ) + end_time = convert_UTC_to_epoch(scenario_result["summary"]["endTime"]) p = Prometheus( url=args.url, start=start_time, @@ -297,14 +316,15 @@ def main() -> None: pod_dict = construct_pod_dict(args.node_info_file, args.pod_type) processed_data = p.fetch_cpu_and_memory_data( - container_list=args.container_name, pod_dict=pod_dict) - processed_data['testDurationSeconds'] = float(end_time) - float(start_time) + container_list=args.container_name, pod_dict=pod_dict + ) + processed_data["testDurationSeconds"] = float(end_time) - float(start_time) logging.debug(json.dumps(processed_data, sort_keys=True, indent=4)) - with open(args.export_file_name, 'w', encoding='utf8') as export_file: + with open(args.export_file_name, "w", encoding="utf8") as export_file: json.dump(processed_data, export_file, sort_keys=True, indent=4) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tools/run_tests/performance/scenario_config.py b/tools/run_tests/performance/scenario_config.py index 25afa13a4b2b5..21ed19916f92d 100644 --- a/tools/run_tests/performance/scenario_config.py +++ b/tools/run_tests/performance/scenario_config.py @@ -20,27 +20,27 @@ JAVA_WARMUP_SECONDS = 15 # Java needs more warmup time for JIT to kick in. BENCHMARK_SECONDS = 30 -SMOKETEST = 'smoketest' -SCALABLE = 'scalable' -INPROC = 'inproc' -SWEEP = 'sweep' -PSM = 'psm' +SMOKETEST = "smoketest" +SCALABLE = "scalable" +INPROC = "inproc" +SWEEP = "sweep" +PSM = "psm" DEFAULT_CATEGORIES = (SCALABLE, SMOKETEST) SECURE_SECARGS = { - 'use_test_ca': True, - 'server_host_override': 'foo.test.google.fr' + "use_test_ca": True, + "server_host_override": "foo.test.google.fr", } HISTOGRAM_PARAMS = { - 'resolution': 0.01, - 'max_possible': 60e9, + "resolution": 0.01, + "max_possible": 60e9, } # target number of RPCs outstanding on across all client channels in # non-ping-pong tests (since we can only specify per-channel numbers, the # actual target will be slightly higher) -OUTSTANDING_REQUESTS = {'async': 6400, 'async-limited': 800, 'sync': 1000} +OUTSTANDING_REQUESTS = {"async": 6400, "async-limited": 800, "sync": 1000} # wide is the number of client channels in multi-channel tests (1 otherwise) WIDE = 64 @@ -59,10 +59,10 @@ def remove_nonproto_fields(scenario): This function removes additional information about the scenario that is not included in the ScenarioConfig protobuf message. 
""" - scenario.pop('CATEGORIES', None) - scenario.pop('CLIENT_LANGUAGE', None) - scenario.pop('SERVER_LANGUAGE', None) - scenario.pop('EXCLUDED_POLL_ENGINES', None) + scenario.pop("CATEGORIES", None) + scenario.pop("CLIENT_LANGUAGE", None) + scenario.pop("SERVER_LANGUAGE", None) + scenario.pop("EXCLUDED_POLL_ENGINES", None) return scenario @@ -76,478 +76,519 @@ def geometric_progression(start, stop, step): def _payload_type(use_generic_payload, req_size, resp_size): r = {} sizes = { - 'req_size': req_size, - 'resp_size': resp_size, + "req_size": req_size, + "resp_size": resp_size, } if use_generic_payload: - r['bytebuf_params'] = sizes + r["bytebuf_params"] = sizes else: - r['simple_params'] = sizes + r["simple_params"] = sizes return r def _load_params(offered_load): r = {} if offered_load is None: - r['closed_loop'] = {} + r["closed_loop"] = {} else: load = {} - load['offered_load'] = offered_load - r['poisson'] = load + load["offered_load"] = offered_load + r["poisson"] = load return r def _add_channel_arg(config, key, value): - if 'channel_args' in config: - channel_args = config['channel_args'] + if "channel_args" in config: + channel_args = config["channel_args"] else: channel_args = [] - config['channel_args'] = channel_args - arg = {'name': key} + config["channel_args"] = channel_args + arg = {"name": key} if isinstance(value, int): - arg['int_value'] = value + arg["int_value"] = value else: - arg['str_value'] = value + arg["str_value"] = value channel_args.append(arg) -def _ping_pong_scenario(name, - rpc_type, - client_type, - server_type, - secure=True, - use_generic_payload=False, - req_size=0, - resp_size=0, - unconstrained_client=None, - client_language=None, - server_language=None, - async_server_threads=0, - client_processes=0, - server_processes=0, - server_threads_per_cq=0, - client_threads_per_cq=0, - warmup_seconds=WARMUP_SECONDS, - categories=None, - channels=None, - outstanding=None, - num_clients=None, - resource_quota_size=None, - messages_per_stream=None, - excluded_poll_engines=None, - minimal_stack=False, - offered_load=None): +def _ping_pong_scenario( + name, + rpc_type, + client_type, + server_type, + secure=True, + use_generic_payload=False, + req_size=0, + resp_size=0, + unconstrained_client=None, + client_language=None, + server_language=None, + async_server_threads=0, + client_processes=0, + server_processes=0, + server_threads_per_cq=0, + client_threads_per_cq=0, + warmup_seconds=WARMUP_SECONDS, + categories=None, + channels=None, + outstanding=None, + num_clients=None, + resource_quota_size=None, + messages_per_stream=None, + excluded_poll_engines=None, + minimal_stack=False, + offered_load=None, +): """Creates a basic ping pong scenario.""" scenario = { - 'name': name, - 'num_servers': 1, - 'num_clients': 1, - 'client_config': { - 'client_type': client_type, - 'security_params': _get_secargs(secure), - 'outstanding_rpcs_per_channel': 1, - 'client_channels': 1, - 'async_client_threads': 1, - 'client_processes': client_processes, - 'threads_per_cq': client_threads_per_cq, - 'rpc_type': rpc_type, - 'histogram_params': HISTOGRAM_PARAMS, - 'channel_args': [], + "name": name, + "num_servers": 1, + "num_clients": 1, + "client_config": { + "client_type": client_type, + "security_params": _get_secargs(secure), + "outstanding_rpcs_per_channel": 1, + "client_channels": 1, + "async_client_threads": 1, + "client_processes": client_processes, + "threads_per_cq": client_threads_per_cq, + "rpc_type": rpc_type, + "histogram_params": HISTOGRAM_PARAMS, + "channel_args": [], 
}, - 'server_config': { - 'server_type': server_type, - 'security_params': _get_secargs(secure), - 'async_server_threads': async_server_threads, - 'server_processes': server_processes, - 'threads_per_cq': server_threads_per_cq, - 'channel_args': [], + "server_config": { + "server_type": server_type, + "security_params": _get_secargs(secure), + "async_server_threads": async_server_threads, + "server_processes": server_processes, + "threads_per_cq": server_threads_per_cq, + "channel_args": [], }, - 'warmup_seconds': warmup_seconds, - 'benchmark_seconds': BENCHMARK_SECONDS, - 'CATEGORIES': list(DEFAULT_CATEGORIES), - 'EXCLUDED_POLL_ENGINES': [], + "warmup_seconds": warmup_seconds, + "benchmark_seconds": BENCHMARK_SECONDS, + "CATEGORIES": list(DEFAULT_CATEGORIES), + "EXCLUDED_POLL_ENGINES": [], } if resource_quota_size: - scenario['server_config']['resource_quota_size'] = resource_quota_size + scenario["server_config"]["resource_quota_size"] = resource_quota_size if use_generic_payload: - if server_type != 'ASYNC_GENERIC_SERVER': - raise Exception('Use ASYNC_GENERIC_SERVER for generic payload.') - scenario['server_config']['payload_config'] = _payload_type( - use_generic_payload, req_size, resp_size) + if server_type != "ASYNC_GENERIC_SERVER": + raise Exception("Use ASYNC_GENERIC_SERVER for generic payload.") + scenario["server_config"]["payload_config"] = _payload_type( + use_generic_payload, req_size, resp_size + ) - scenario['client_config']['payload_config'] = _payload_type( - use_generic_payload, req_size, resp_size) + scenario["client_config"]["payload_config"] = _payload_type( + use_generic_payload, req_size, resp_size + ) # Optimization target of 'throughput' does not work well with epoll1 polling # engine. Use the default value of 'blend' - optimization_target = 'throughput' + optimization_target = "throughput" if unconstrained_client: - outstanding_calls = outstanding if outstanding is not None else OUTSTANDING_REQUESTS[ - unconstrained_client] + outstanding_calls = ( + outstanding + if outstanding is not None + else OUTSTANDING_REQUESTS[unconstrained_client] + ) # clamp buffer usage to something reasonable (16 gig for now) MAX_MEMORY_USE = 16 * 1024 * 1024 * 1024 if outstanding_calls * max(req_size, resp_size) > MAX_MEMORY_USE: - outstanding_calls = max(1, - MAX_MEMORY_USE / max(req_size, resp_size)) + outstanding_calls = max( + 1, MAX_MEMORY_USE / max(req_size, resp_size) + ) wide = channels if channels is not None else WIDE deep = int(math.ceil(1.0 * outstanding_calls / wide)) - scenario[ - 'num_clients'] = num_clients if num_clients is not None else 0 # use as many clients as available. - scenario['client_config']['outstanding_rpcs_per_channel'] = deep - scenario['client_config']['client_channels'] = wide - scenario['client_config']['async_client_threads'] = 0 + scenario["num_clients"] = ( + num_clients if num_clients is not None else 0 + ) # use as many clients as available. 
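For context on the arithmetic in the unconstrained-client branch above: the code first clamps the number of outstanding calls so that outstanding_calls * max(req_size, resp_size) stays within a 16 GiB buffer budget, then fans the calls out over `wide` channels with `deep` RPCs per channel. Below is a minimal standalone sketch of that logic; the helper name `_channel_fanout` and the sample sizes are illustrative only and are not part of this patch.

import math

def _channel_fanout(outstanding_calls, req_size, resp_size, wide=64):
    # Illustrative helper (not in the gRPC tree): mirrors the clamp and
    # wide/deep split used for unconstrained clients in _ping_pong_scenario.
    MAX_MEMORY_USE = 16 * 1024 * 1024 * 1024  # 16 GiB buffer budget
    if outstanding_calls * max(req_size, resp_size) > MAX_MEMORY_USE:
        outstanding_calls = max(1, MAX_MEMORY_USE / max(req_size, resp_size))
    deep = int(math.ceil(1.0 * outstanding_calls / wide))
    return wide, deep

# Example: 30000 outstanding 1 MB calls exceed the 16 GiB budget, so the call
# count is clamped to 16384 and split as 64 channels x 256 RPCs per channel.
print(_channel_fanout(30000, 1024 * 1024, 1024 * 1024))  # -> (64, 256)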
+ scenario["client_config"]["outstanding_rpcs_per_channel"] = deep + scenario["client_config"]["client_channels"] = wide + scenario["client_config"]["async_client_threads"] = 0 if offered_load is not None: - optimization_target = 'latency' + optimization_target = "latency" else: - scenario['client_config']['outstanding_rpcs_per_channel'] = 1 - scenario['client_config']['client_channels'] = 1 - scenario['client_config']['async_client_threads'] = 1 - optimization_target = 'latency' + scenario["client_config"]["outstanding_rpcs_per_channel"] = 1 + scenario["client_config"]["client_channels"] = 1 + scenario["client_config"]["async_client_threads"] = 1 + optimization_target = "latency" - scenario['client_config']['load_params'] = _load_params(offered_load) + scenario["client_config"]["load_params"] = _load_params(offered_load) optimization_channel_arg = { - 'name': 'grpc.optimization_target', - 'str_value': optimization_target + "name": "grpc.optimization_target", + "str_value": optimization_target, } - scenario['client_config']['channel_args'].append(optimization_channel_arg) - scenario['server_config']['channel_args'].append(optimization_channel_arg) + scenario["client_config"]["channel_args"].append(optimization_channel_arg) + scenario["server_config"]["channel_args"].append(optimization_channel_arg) if minimal_stack: - _add_channel_arg(scenario['client_config'], 'grpc.minimal_stack', 1) - _add_channel_arg(scenario['server_config'], 'grpc.minimal_stack', 1) + _add_channel_arg(scenario["client_config"], "grpc.minimal_stack", 1) + _add_channel_arg(scenario["server_config"], "grpc.minimal_stack", 1) if messages_per_stream: - scenario['client_config']['messages_per_stream'] = messages_per_stream + scenario["client_config"]["messages_per_stream"] = messages_per_stream if client_language: # the CLIENT_LANGUAGE field is recognized by run_performance_tests.py - scenario['CLIENT_LANGUAGE'] = client_language + scenario["CLIENT_LANGUAGE"] = client_language if server_language: # the SERVER_LANGUAGE field is recognized by run_performance_tests.py - scenario['SERVER_LANGUAGE'] = server_language + scenario["SERVER_LANGUAGE"] = server_language if categories: - scenario['CATEGORIES'] = categories + scenario["CATEGORIES"] = categories if excluded_poll_engines: # The polling engines for which this scenario is excluded - scenario['EXCLUDED_POLL_ENGINES'] = excluded_poll_engines + scenario["EXCLUDED_POLL_ENGINES"] = excluded_poll_engines return scenario class Language(object): - @property def safename(self): return str(self) class CXXLanguage(Language): - @property def safename(self): - return 'cxx' + return "cxx" def worker_cmdline(self): - return ['cmake/build/qps_worker'] + return ["cmake/build/qps_worker"] def worker_port_offset(self): return 0 def scenarios(self): - yield _ping_pong_scenario('cpp_protobuf_async_unary_5000rpcs_1KB_psm', - rpc_type='UNARY', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', - req_size=1024, - resp_size=1024, - outstanding=5000, - channels=1, - num_clients=1, - secure=False, - async_server_threads=1, - categories=[PSM]) + yield _ping_pong_scenario( + "cpp_protobuf_async_unary_5000rpcs_1KB_psm", + rpc_type="UNARY", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", + req_size=1024, + resp_size=1024, + outstanding=5000, + channels=1, + num_clients=1, + secure=False, + async_server_threads=1, + categories=[PSM], + ) # TODO(ctiller): add 70% load latency test yield _ping_pong_scenario( - 'cpp_protobuf_async_unary_1channel_100rpcs_1MB', - rpc_type='UNARY', - 
client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', + "cpp_protobuf_async_unary_1channel_100rpcs_1MB", + rpc_type="UNARY", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", req_size=1024 * 1024, resp_size=1024 * 1024, - unconstrained_client='async', + unconstrained_client="async", outstanding=100, channels=1, num_clients=1, secure=False, - categories=[SWEEP]) + categories=[SWEEP], + ) yield _ping_pong_scenario( - 'cpp_protobuf_async_streaming_from_client_1channel_1MB', - rpc_type='STREAMING_FROM_CLIENT', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', + "cpp_protobuf_async_streaming_from_client_1channel_1MB", + rpc_type="STREAMING_FROM_CLIENT", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", req_size=1024 * 1024, resp_size=1024 * 1024, - unconstrained_client='async', + unconstrained_client="async", outstanding=1, channels=1, num_clients=1, secure=False, - categories=[SWEEP]) + categories=[SWEEP], + ) # Scenario was added in https://github.com/grpc/grpc/pull/12987, but its purpose is unclear # (beyond excercising some params that other scenarios don't) yield _ping_pong_scenario( - 'cpp_protobuf_async_unary_75Kqps_600channel_60Krpcs_300Breq_50Bresp', - rpc_type='UNARY', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', + "cpp_protobuf_async_unary_75Kqps_600channel_60Krpcs_300Breq_50Bresp", + rpc_type="UNARY", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", req_size=300, resp_size=50, - unconstrained_client='async', + unconstrained_client="async", outstanding=30000, channels=300, offered_load=37500, secure=False, async_server_threads=16, server_threads_per_cq=1, - categories=[SCALABLE]) + categories=[SCALABLE], + ) for secure in [True, False]: - secstr = 'secure' if secure else 'insecure' - smoketest_categories = ([SMOKETEST] if secure else []) - inproc_categories = ([INPROC] if not secure else []) + secstr = "secure" if secure else "insecure" + smoketest_categories = [SMOKETEST] if secure else [] + inproc_categories = [INPROC] if not secure else [] yield _ping_pong_scenario( - 'cpp_generic_async_streaming_ping_pong_%s' % secstr, - rpc_type='STREAMING', - client_type='ASYNC_CLIENT', - server_type='ASYNC_GENERIC_SERVER', + "cpp_generic_async_streaming_ping_pong_%s" % secstr, + rpc_type="STREAMING", + client_type="ASYNC_CLIENT", + server_type="ASYNC_GENERIC_SERVER", use_generic_payload=True, async_server_threads=1, secure=secure, - categories=smoketest_categories + inproc_categories + - [SCALABLE]) + categories=smoketest_categories + + inproc_categories + + [SCALABLE], + ) yield _ping_pong_scenario( - 'cpp_generic_async_streaming_qps_unconstrained_%s' % secstr, - rpc_type='STREAMING', - client_type='ASYNC_CLIENT', - server_type='ASYNC_GENERIC_SERVER', - unconstrained_client='async', + "cpp_generic_async_streaming_qps_unconstrained_%s" % secstr, + rpc_type="STREAMING", + client_type="ASYNC_CLIENT", + server_type="ASYNC_GENERIC_SERVER", + unconstrained_client="async", use_generic_payload=True, secure=secure, client_threads_per_cq=2, server_threads_per_cq=2, minimal_stack=not secure, - categories=smoketest_categories + inproc_categories + - [SCALABLE]) + categories=smoketest_categories + + inproc_categories + + [SCALABLE], + ) for mps in geometric_progression(10, 20, 10): yield _ping_pong_scenario( - 'cpp_generic_async_streaming_qps_unconstrained_%smps_%s' % - (mps, secstr), - rpc_type='STREAMING', - client_type='ASYNC_CLIENT', - server_type='ASYNC_GENERIC_SERVER', - unconstrained_client='async', + 
"cpp_generic_async_streaming_qps_unconstrained_%smps_%s" + % (mps, secstr), + rpc_type="STREAMING", + client_type="ASYNC_CLIENT", + server_type="ASYNC_GENERIC_SERVER", + unconstrained_client="async", use_generic_payload=True, secure=secure, messages_per_stream=mps, minimal_stack=not secure, - categories=smoketest_categories + inproc_categories + - [SCALABLE]) + categories=smoketest_categories + + inproc_categories + + [SCALABLE], + ) for mps in geometric_progression(1, 200, math.sqrt(10)): yield _ping_pong_scenario( - 'cpp_generic_async_streaming_qps_unconstrained_%smps_%s' % - (mps, secstr), - rpc_type='STREAMING', - client_type='ASYNC_CLIENT', - server_type='ASYNC_GENERIC_SERVER', - unconstrained_client='async', + "cpp_generic_async_streaming_qps_unconstrained_%smps_%s" + % (mps, secstr), + rpc_type="STREAMING", + client_type="ASYNC_CLIENT", + server_type="ASYNC_GENERIC_SERVER", + unconstrained_client="async", use_generic_payload=True, secure=secure, messages_per_stream=mps, minimal_stack=not secure, - categories=[SWEEP]) + categories=[SWEEP], + ) yield _ping_pong_scenario( - 'cpp_generic_async_streaming_qps_1channel_1MBmsg_%s' % secstr, - rpc_type='STREAMING', + "cpp_generic_async_streaming_qps_1channel_1MBmsg_%s" % secstr, + rpc_type="STREAMING", req_size=1024 * 1024, resp_size=1024 * 1024, - client_type='ASYNC_CLIENT', - server_type='ASYNC_GENERIC_SERVER', - unconstrained_client='async', + client_type="ASYNC_CLIENT", + server_type="ASYNC_GENERIC_SERVER", + unconstrained_client="async", use_generic_payload=True, secure=secure, minimal_stack=not secure, categories=inproc_categories + [SCALABLE], channels=1, - outstanding=100) + outstanding=100, + ) yield _ping_pong_scenario( - 'cpp_generic_async_streaming_qps_unconstrained_64KBmsg_%s' % - secstr, - rpc_type='STREAMING', + "cpp_generic_async_streaming_qps_unconstrained_64KBmsg_%s" + % secstr, + rpc_type="STREAMING", req_size=64 * 1024, resp_size=64 * 1024, - client_type='ASYNC_CLIENT', - server_type='ASYNC_GENERIC_SERVER', - unconstrained_client='async', + client_type="ASYNC_CLIENT", + server_type="ASYNC_GENERIC_SERVER", + unconstrained_client="async", use_generic_payload=True, secure=secure, minimal_stack=not secure, - categories=inproc_categories + [SCALABLE]) + categories=inproc_categories + [SCALABLE], + ) yield _ping_pong_scenario( - 'cpp_generic_async_streaming_qps_unconstrained_1cq_%s' % secstr, - rpc_type='STREAMING', - client_type='ASYNC_CLIENT', - server_type='ASYNC_GENERIC_SERVER', - unconstrained_client='async-limited', + "cpp_generic_async_streaming_qps_unconstrained_1cq_%s" % secstr, + rpc_type="STREAMING", + client_type="ASYNC_CLIENT", + server_type="ASYNC_GENERIC_SERVER", + unconstrained_client="async-limited", use_generic_payload=True, secure=secure, client_threads_per_cq=1000000, server_threads_per_cq=1000000, - categories=[SWEEP]) + categories=[SWEEP], + ) yield _ping_pong_scenario( - 'cpp_protobuf_async_streaming_qps_unconstrained_1cq_%s' % - secstr, - rpc_type='STREAMING', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', - unconstrained_client='async-limited', + "cpp_protobuf_async_streaming_qps_unconstrained_1cq_%s" + % secstr, + rpc_type="STREAMING", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", + unconstrained_client="async-limited", secure=secure, client_threads_per_cq=1000000, server_threads_per_cq=1000000, - categories=inproc_categories + [SCALABLE]) + categories=inproc_categories + [SCALABLE], + ) yield _ping_pong_scenario( - 'cpp_protobuf_async_unary_qps_unconstrained_1cq_%s' % 
secstr, - rpc_type='UNARY', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', - unconstrained_client='async-limited', + "cpp_protobuf_async_unary_qps_unconstrained_1cq_%s" % secstr, + rpc_type="UNARY", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", + unconstrained_client="async-limited", secure=secure, client_threads_per_cq=1000000, server_threads_per_cq=1000000, - categories=inproc_categories + [SCALABLE]) + categories=inproc_categories + [SCALABLE], + ) yield _ping_pong_scenario( - 'cpp_generic_async_streaming_qps_one_server_core_%s' % secstr, - rpc_type='STREAMING', - client_type='ASYNC_CLIENT', - server_type='ASYNC_GENERIC_SERVER', - unconstrained_client='async-limited', + "cpp_generic_async_streaming_qps_one_server_core_%s" % secstr, + rpc_type="STREAMING", + client_type="ASYNC_CLIENT", + server_type="ASYNC_GENERIC_SERVER", + unconstrained_client="async-limited", use_generic_payload=True, async_server_threads=1, minimal_stack=not secure, secure=secure, - categories=[SWEEP]) + categories=[SWEEP], + ) yield _ping_pong_scenario( - 'cpp_protobuf_async_client_sync_server_unary_qps_unconstrained_%s' + "cpp_protobuf_async_client_sync_server_unary_qps_unconstrained_%s" % (secstr), - rpc_type='UNARY', - client_type='ASYNC_CLIENT', - server_type='SYNC_SERVER', - unconstrained_client='async', + rpc_type="UNARY", + client_type="ASYNC_CLIENT", + server_type="SYNC_SERVER", + unconstrained_client="async", secure=secure, minimal_stack=not secure, - categories=smoketest_categories + inproc_categories + - [SCALABLE]) + categories=smoketest_categories + + inproc_categories + + [SCALABLE], + ) yield _ping_pong_scenario( - 'cpp_protobuf_async_client_unary_1channel_64wide_128Breq_8MBresp_%s' + "cpp_protobuf_async_client_unary_1channel_64wide_128Breq_8MBresp_%s" % (secstr), - rpc_type='UNARY', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', + rpc_type="UNARY", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", channels=1, outstanding=64, req_size=128, resp_size=8 * 1024 * 1024, secure=secure, minimal_stack=not secure, - categories=inproc_categories + [SCALABLE]) + categories=inproc_categories + [SCALABLE], + ) yield _ping_pong_scenario( - 'cpp_protobuf_async_client_sync_server_streaming_qps_unconstrained_%s' + "cpp_protobuf_async_client_sync_server_streaming_qps_unconstrained_%s" % secstr, - rpc_type='STREAMING', - client_type='ASYNC_CLIENT', - server_type='SYNC_SERVER', - unconstrained_client='async', + rpc_type="STREAMING", + client_type="ASYNC_CLIENT", + server_type="SYNC_SERVER", + unconstrained_client="async", secure=secure, minimal_stack=not secure, - categories=[SWEEP]) + categories=[SWEEP], + ) yield _ping_pong_scenario( - 'cpp_protobuf_async_unary_ping_pong_%s_1MB' % secstr, - rpc_type='UNARY', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', + "cpp_protobuf_async_unary_ping_pong_%s_1MB" % secstr, + rpc_type="UNARY", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", req_size=1024 * 1024, resp_size=1024 * 1024, secure=secure, minimal_stack=not secure, - categories=smoketest_categories + inproc_categories + - [SCALABLE]) + categories=smoketest_categories + + inproc_categories + + [SCALABLE], + ) for rpc_type in [ - 'unary', 'streaming', 'streaming_from_client', - 'streaming_from_server' + "unary", + "streaming", + "streaming_from_client", + "streaming_from_server", ]: - for synchronicity in ['sync', 'async']: + for synchronicity in ["sync", "async"]: yield _ping_pong_scenario( - 'cpp_protobuf_%s_%s_ping_pong_%s' % - (synchronicity, 
rpc_type, secstr), + "cpp_protobuf_%s_%s_ping_pong_%s" + % (synchronicity, rpc_type, secstr), rpc_type=rpc_type.upper(), - client_type='%s_CLIENT' % synchronicity.upper(), - server_type='%s_SERVER' % synchronicity.upper(), + client_type="%s_CLIENT" % synchronicity.upper(), + server_type="%s_SERVER" % synchronicity.upper(), async_server_threads=1, minimal_stack=not secure, - secure=secure) + secure=secure, + ) - for size in geometric_progression(1, 1024 * 1024 * 1024 + 1, - 8): + for size in geometric_progression( + 1, 1024 * 1024 * 1024 + 1, 8 + ): yield _ping_pong_scenario( - 'cpp_protobuf_%s_%s_qps_unconstrained_%s_%db' % - (synchronicity, rpc_type, secstr, size), + "cpp_protobuf_%s_%s_qps_unconstrained_%s_%db" + % (synchronicity, rpc_type, secstr, size), rpc_type=rpc_type.upper(), req_size=size, resp_size=size, - client_type='%s_CLIENT' % synchronicity.upper(), - server_type='%s_SERVER' % synchronicity.upper(), + client_type="%s_CLIENT" % synchronicity.upper(), + server_type="%s_SERVER" % synchronicity.upper(), unconstrained_client=synchronicity, secure=secure, minimal_stack=not secure, - categories=[SWEEP]) + categories=[SWEEP], + ) maybe_scalable = [SCALABLE] - if rpc_type == 'streaming_from_server' and synchronicity == 'async' and secure: + if ( + rpc_type == "streaming_from_server" + and synchronicity == "async" + and secure + ): # protobuf_async_streaming_from_server_qps_unconstrained_secure is very flaky # and has extremely high variance so running it isn't really useful. # see b/198275705 maybe_scalable = [SWEEP] yield _ping_pong_scenario( - 'cpp_protobuf_%s_%s_qps_unconstrained_%s' % - (synchronicity, rpc_type, secstr), + "cpp_protobuf_%s_%s_qps_unconstrained_%s" + % (synchronicity, rpc_type, secstr), rpc_type=rpc_type.upper(), - client_type='%s_CLIENT' % synchronicity.upper(), - server_type='%s_SERVER' % synchronicity.upper(), + client_type="%s_CLIENT" % synchronicity.upper(), + server_type="%s_SERVER" % synchronicity.upper(), unconstrained_client=synchronicity, secure=secure, minimal_stack=not secure, server_threads_per_cq=2, client_threads_per_cq=2, - categories=inproc_categories + maybe_scalable) + categories=inproc_categories + maybe_scalable, + ) # TODO(vjpai): Re-enable this test. 
It has a lot of timeouts # and hasn't yet been conclusively identified as a test failure @@ -562,183 +603,212 @@ def scenarios(self): # categories=smoketest_categories+[SCALABLE], # resource_quota_size=500*1024) - if rpc_type == 'streaming': + if rpc_type == "streaming": for mps in geometric_progression(10, 20, 10): yield _ping_pong_scenario( - 'cpp_protobuf_%s_%s_qps_unconstrained_%smps_%s' + "cpp_protobuf_%s_%s_qps_unconstrained_%smps_%s" % (synchronicity, rpc_type, mps, secstr), rpc_type=rpc_type.upper(), - client_type='%s_CLIENT' % synchronicity.upper(), - server_type='%s_SERVER' % synchronicity.upper(), + client_type="%s_CLIENT" % synchronicity.upper(), + server_type="%s_SERVER" % synchronicity.upper(), unconstrained_client=synchronicity, secure=secure, messages_per_stream=mps, minimal_stack=not secure, - categories=inproc_categories + [SCALABLE]) + categories=inproc_categories + [SCALABLE], + ) for mps in geometric_progression(1, 200, math.sqrt(10)): yield _ping_pong_scenario( - 'cpp_protobuf_%s_%s_qps_unconstrained_%smps_%s' + "cpp_protobuf_%s_%s_qps_unconstrained_%smps_%s" % (synchronicity, rpc_type, mps, secstr), rpc_type=rpc_type.upper(), - client_type='%s_CLIENT' % synchronicity.upper(), - server_type='%s_SERVER' % synchronicity.upper(), + client_type="%s_CLIENT" % synchronicity.upper(), + server_type="%s_SERVER" % synchronicity.upper(), unconstrained_client=synchronicity, secure=secure, messages_per_stream=mps, minimal_stack=not secure, - categories=[SWEEP]) + categories=[SWEEP], + ) for channels in geometric_progression( - 1, 20000, math.sqrt(10)): + 1, 20000, math.sqrt(10) + ): for outstanding in geometric_progression( - 1, 200000, math.sqrt(10)): - if synchronicity == 'sync' and outstanding > 1200: + 1, 200000, math.sqrt(10) + ): + if synchronicity == "sync" and outstanding > 1200: continue if outstanding < channels: continue yield _ping_pong_scenario( - 'cpp_protobuf_%s_%s_qps_unconstrained_%s_%d_channels_%d_outstanding' - % (synchronicity, rpc_type, secstr, channels, - outstanding), + "cpp_protobuf_%s_%s_qps_unconstrained_%s_%d_channels_%d_outstanding" + % ( + synchronicity, + rpc_type, + secstr, + channels, + outstanding, + ), rpc_type=rpc_type.upper(), - client_type='%s_CLIENT' % synchronicity.upper(), - server_type='%s_SERVER' % synchronicity.upper(), + client_type="%s_CLIENT" % synchronicity.upper(), + server_type="%s_SERVER" % synchronicity.upper(), unconstrained_client=synchronicity, secure=secure, minimal_stack=not secure, categories=[SWEEP], channels=channels, - outstanding=outstanding) + outstanding=outstanding, + ) def __str__(self): - return 'c++' + return "c++" class CSharpLanguage(Language): """The legacy Grpc.Core implementation from grpc/grpc.""" def worker_cmdline(self): - return ['tools/run_tests/performance/run_worker_csharp.sh'] + return ["tools/run_tests/performance/run_worker_csharp.sh"] def worker_port_offset(self): return 100 def scenarios(self): - yield _ping_pong_scenario('csharp_generic_async_streaming_ping_pong', - rpc_type='STREAMING', - client_type='ASYNC_CLIENT', - server_type='ASYNC_GENERIC_SERVER', - use_generic_payload=True, - categories=[SMOKETEST, SCALABLE]) - - yield _ping_pong_scenario( - 'csharp_generic_async_streaming_ping_pong_insecure_1MB', - rpc_type='STREAMING', - client_type='ASYNC_CLIENT', - server_type='ASYNC_GENERIC_SERVER', + yield _ping_pong_scenario( + "csharp_generic_async_streaming_ping_pong", + rpc_type="STREAMING", + client_type="ASYNC_CLIENT", + server_type="ASYNC_GENERIC_SERVER", + use_generic_payload=True, + 
categories=[SMOKETEST, SCALABLE], + ) + + yield _ping_pong_scenario( + "csharp_generic_async_streaming_ping_pong_insecure_1MB", + rpc_type="STREAMING", + client_type="ASYNC_CLIENT", + server_type="ASYNC_GENERIC_SERVER", req_size=1024 * 1024, resp_size=1024 * 1024, use_generic_payload=True, secure=False, - categories=[SMOKETEST, SCALABLE]) + categories=[SMOKETEST, SCALABLE], + ) yield _ping_pong_scenario( - 'csharp_generic_async_streaming_qps_unconstrained_insecure', - rpc_type='STREAMING', - client_type='ASYNC_CLIENT', - server_type='ASYNC_GENERIC_SERVER', - unconstrained_client='async', + "csharp_generic_async_streaming_qps_unconstrained_insecure", + rpc_type="STREAMING", + client_type="ASYNC_CLIENT", + server_type="ASYNC_GENERIC_SERVER", + unconstrained_client="async", use_generic_payload=True, secure=False, - categories=[SMOKETEST, SCALABLE]) - - yield _ping_pong_scenario('csharp_protobuf_async_streaming_ping_pong', - rpc_type='STREAMING', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER') - - yield _ping_pong_scenario('csharp_protobuf_async_unary_ping_pong', - rpc_type='UNARY', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', - categories=[SMOKETEST, SCALABLE]) - - yield _ping_pong_scenario( - 'csharp_protobuf_sync_to_async_unary_ping_pong', - rpc_type='UNARY', - client_type='SYNC_CLIENT', - server_type='ASYNC_SERVER') - - yield _ping_pong_scenario( - 'csharp_protobuf_async_unary_qps_unconstrained', - rpc_type='UNARY', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', - unconstrained_client='async', - categories=[SMOKETEST, SCALABLE]) - - yield _ping_pong_scenario( - 'csharp_protobuf_async_streaming_qps_unconstrained', - rpc_type='STREAMING', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', - unconstrained_client='async', - categories=[SCALABLE]) - - yield _ping_pong_scenario('csharp_to_cpp_protobuf_sync_unary_ping_pong', - rpc_type='UNARY', - client_type='SYNC_CLIENT', - server_type='SYNC_SERVER', - server_language='c++', - async_server_threads=1, - categories=[SMOKETEST, SCALABLE]) - - yield _ping_pong_scenario( - 'csharp_to_cpp_protobuf_async_streaming_ping_pong', - rpc_type='STREAMING', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', - server_language='c++', - async_server_threads=1) - - yield _ping_pong_scenario( - 'csharp_to_cpp_protobuf_async_unary_qps_unconstrained', - rpc_type='UNARY', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', - unconstrained_client='async', - server_language='c++', - categories=[SCALABLE]) - - yield _ping_pong_scenario( - 'csharp_to_cpp_protobuf_sync_to_async_unary_qps_unconstrained', - rpc_type='UNARY', - client_type='SYNC_CLIENT', - server_type='ASYNC_SERVER', - unconstrained_client='sync', - server_language='c++', - categories=[SCALABLE]) - - yield _ping_pong_scenario( - 'cpp_to_csharp_protobuf_async_unary_qps_unconstrained', - rpc_type='UNARY', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', - unconstrained_client='async', - client_language='c++', - categories=[SCALABLE]) - - yield _ping_pong_scenario('csharp_protobuf_async_unary_ping_pong_1MB', - rpc_type='UNARY', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', - req_size=1024 * 1024, - resp_size=1024 * 1024, - categories=[SMOKETEST, SCALABLE]) + categories=[SMOKETEST, SCALABLE], + ) + + yield _ping_pong_scenario( + "csharp_protobuf_async_streaming_ping_pong", + rpc_type="STREAMING", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", + ) + + yield _ping_pong_scenario( + 
"csharp_protobuf_async_unary_ping_pong", + rpc_type="UNARY", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", + categories=[SMOKETEST, SCALABLE], + ) + + yield _ping_pong_scenario( + "csharp_protobuf_sync_to_async_unary_ping_pong", + rpc_type="UNARY", + client_type="SYNC_CLIENT", + server_type="ASYNC_SERVER", + ) + + yield _ping_pong_scenario( + "csharp_protobuf_async_unary_qps_unconstrained", + rpc_type="UNARY", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", + unconstrained_client="async", + categories=[SMOKETEST, SCALABLE], + ) + + yield _ping_pong_scenario( + "csharp_protobuf_async_streaming_qps_unconstrained", + rpc_type="STREAMING", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", + unconstrained_client="async", + categories=[SCALABLE], + ) + + yield _ping_pong_scenario( + "csharp_to_cpp_protobuf_sync_unary_ping_pong", + rpc_type="UNARY", + client_type="SYNC_CLIENT", + server_type="SYNC_SERVER", + server_language="c++", + async_server_threads=1, + categories=[SMOKETEST, SCALABLE], + ) + + yield _ping_pong_scenario( + "csharp_to_cpp_protobuf_async_streaming_ping_pong", + rpc_type="STREAMING", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", + server_language="c++", + async_server_threads=1, + ) + + yield _ping_pong_scenario( + "csharp_to_cpp_protobuf_async_unary_qps_unconstrained", + rpc_type="UNARY", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", + unconstrained_client="async", + server_language="c++", + categories=[SCALABLE], + ) + + yield _ping_pong_scenario( + "csharp_to_cpp_protobuf_sync_to_async_unary_qps_unconstrained", + rpc_type="UNARY", + client_type="SYNC_CLIENT", + server_type="ASYNC_SERVER", + unconstrained_client="sync", + server_language="c++", + categories=[SCALABLE], + ) + + yield _ping_pong_scenario( + "cpp_to_csharp_protobuf_async_unary_qps_unconstrained", + rpc_type="UNARY", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", + unconstrained_client="async", + client_language="c++", + categories=[SCALABLE], + ) + + yield _ping_pong_scenario( + "csharp_protobuf_async_unary_ping_pong_1MB", + rpc_type="UNARY", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", + req_size=1024 * 1024, + resp_size=1024 * 1024, + categories=[SMOKETEST, SCALABLE], + ) def __str__(self): - return 'csharp' + return "csharp" class DotnetLanguage(Language): @@ -747,142 +817,160 @@ class DotnetLanguage(Language): def worker_cmdline(self): # grpc-dotnet worker is only supported by the new GKE based OSS benchmark # framework, and the worker_cmdline() is only used by run_performance_tests.py - return ['grpc_dotnet_not_supported_by_legacy_performance_runner.sh'] + return ["grpc_dotnet_not_supported_by_legacy_performance_runner.sh"] def worker_port_offset(self): return 1100 def scenarios(self): - yield _ping_pong_scenario('dotnet_generic_async_streaming_ping_pong', - rpc_type='STREAMING', - client_type='ASYNC_CLIENT', - server_type='ASYNC_GENERIC_SERVER', - use_generic_payload=True, - categories=[SMOKETEST, SCALABLE]) - - yield _ping_pong_scenario( - 'dotnet_generic_async_streaming_ping_pong_insecure_1MB', - rpc_type='STREAMING', - client_type='ASYNC_CLIENT', - server_type='ASYNC_GENERIC_SERVER', + yield _ping_pong_scenario( + "dotnet_generic_async_streaming_ping_pong", + rpc_type="STREAMING", + client_type="ASYNC_CLIENT", + server_type="ASYNC_GENERIC_SERVER", + use_generic_payload=True, + categories=[SMOKETEST, SCALABLE], + ) + + yield _ping_pong_scenario( + "dotnet_generic_async_streaming_ping_pong_insecure_1MB", + 
rpc_type="STREAMING", + client_type="ASYNC_CLIENT", + server_type="ASYNC_GENERIC_SERVER", req_size=1024 * 1024, resp_size=1024 * 1024, use_generic_payload=True, secure=False, - categories=[SMOKETEST, SCALABLE]) + categories=[SMOKETEST, SCALABLE], + ) yield _ping_pong_scenario( - 'dotnet_generic_async_streaming_qps_unconstrained_insecure', - rpc_type='STREAMING', - client_type='ASYNC_CLIENT', - server_type='ASYNC_GENERIC_SERVER', - unconstrained_client='async', + "dotnet_generic_async_streaming_qps_unconstrained_insecure", + rpc_type="STREAMING", + client_type="ASYNC_CLIENT", + server_type="ASYNC_GENERIC_SERVER", + unconstrained_client="async", use_generic_payload=True, secure=False, - categories=[SMOKETEST, SCALABLE]) - - yield _ping_pong_scenario('dotnet_protobuf_async_streaming_ping_pong', - rpc_type='STREAMING', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER') - - yield _ping_pong_scenario('dotnet_protobuf_async_unary_ping_pong', - rpc_type='UNARY', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', - categories=[SMOKETEST, SCALABLE]) - - yield _ping_pong_scenario( - 'dotnet_protobuf_sync_to_async_unary_ping_pong', - rpc_type='UNARY', - client_type='SYNC_CLIENT', - server_type='ASYNC_SERVER') - - yield _ping_pong_scenario( - 'dotnet_protobuf_async_unary_qps_unconstrained', - rpc_type='UNARY', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', - unconstrained_client='async', - categories=[SMOKETEST, SCALABLE]) - - yield _ping_pong_scenario( - 'dotnet_protobuf_async_streaming_qps_unconstrained', - rpc_type='STREAMING', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', - unconstrained_client='async', - categories=[SCALABLE]) - - yield _ping_pong_scenario('dotnet_to_cpp_protobuf_sync_unary_ping_pong', - rpc_type='UNARY', - client_type='SYNC_CLIENT', - server_type='SYNC_SERVER', - server_language='c++', - async_server_threads=1, - categories=[SMOKETEST, SCALABLE]) - - yield _ping_pong_scenario( - 'dotnet_to_cpp_protobuf_async_streaming_ping_pong', - rpc_type='STREAMING', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', - server_language='c++', - async_server_threads=1) - - yield _ping_pong_scenario( - 'dotnet_to_cpp_protobuf_async_unary_qps_unconstrained', - rpc_type='UNARY', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', - unconstrained_client='async', - server_language='c++', - categories=[SCALABLE]) - - yield _ping_pong_scenario( - 'dotnet_to_cpp_protobuf_sync_to_async_unary_qps_unconstrained', - rpc_type='UNARY', - client_type='SYNC_CLIENT', - server_type='ASYNC_SERVER', - unconstrained_client='sync', - server_language='c++', - categories=[SCALABLE]) - - yield _ping_pong_scenario( - 'cpp_to_dotnet_protobuf_async_unary_qps_unconstrained', - rpc_type='UNARY', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', - unconstrained_client='async', - client_language='c++', - categories=[SCALABLE]) - - yield _ping_pong_scenario('dotnet_protobuf_async_unary_ping_pong_1MB', - rpc_type='UNARY', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', - req_size=1024 * 1024, - resp_size=1024 * 1024, - categories=[SMOKETEST, SCALABLE]) + categories=[SMOKETEST, SCALABLE], + ) + + yield _ping_pong_scenario( + "dotnet_protobuf_async_streaming_ping_pong", + rpc_type="STREAMING", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", + ) + + yield _ping_pong_scenario( + "dotnet_protobuf_async_unary_ping_pong", + rpc_type="UNARY", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", + categories=[SMOKETEST, SCALABLE], + ) + 
+ yield _ping_pong_scenario( + "dotnet_protobuf_sync_to_async_unary_ping_pong", + rpc_type="UNARY", + client_type="SYNC_CLIENT", + server_type="ASYNC_SERVER", + ) + + yield _ping_pong_scenario( + "dotnet_protobuf_async_unary_qps_unconstrained", + rpc_type="UNARY", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", + unconstrained_client="async", + categories=[SMOKETEST, SCALABLE], + ) + + yield _ping_pong_scenario( + "dotnet_protobuf_async_streaming_qps_unconstrained", + rpc_type="STREAMING", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", + unconstrained_client="async", + categories=[SCALABLE], + ) + + yield _ping_pong_scenario( + "dotnet_to_cpp_protobuf_sync_unary_ping_pong", + rpc_type="UNARY", + client_type="SYNC_CLIENT", + server_type="SYNC_SERVER", + server_language="c++", + async_server_threads=1, + categories=[SMOKETEST, SCALABLE], + ) + + yield _ping_pong_scenario( + "dotnet_to_cpp_protobuf_async_streaming_ping_pong", + rpc_type="STREAMING", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", + server_language="c++", + async_server_threads=1, + ) + + yield _ping_pong_scenario( + "dotnet_to_cpp_protobuf_async_unary_qps_unconstrained", + rpc_type="UNARY", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", + unconstrained_client="async", + server_language="c++", + categories=[SCALABLE], + ) + + yield _ping_pong_scenario( + "dotnet_to_cpp_protobuf_sync_to_async_unary_qps_unconstrained", + rpc_type="UNARY", + client_type="SYNC_CLIENT", + server_type="ASYNC_SERVER", + unconstrained_client="sync", + server_language="c++", + categories=[SCALABLE], + ) + + yield _ping_pong_scenario( + "cpp_to_dotnet_protobuf_async_unary_qps_unconstrained", + rpc_type="UNARY", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", + unconstrained_client="async", + client_language="c++", + categories=[SCALABLE], + ) + + yield _ping_pong_scenario( + "dotnet_protobuf_async_unary_ping_pong_1MB", + rpc_type="UNARY", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", + req_size=1024 * 1024, + resp_size=1024 * 1024, + categories=[SMOKETEST, SCALABLE], + ) def __str__(self): - return 'dotnet' + return "dotnet" class PythonLanguage(Language): - def worker_cmdline(self): - return ['tools/run_tests/performance/run_worker_python.sh'] + return ["tools/run_tests/performance/run_worker_python.sh"] def worker_port_offset(self): return 500 def scenarios(self): yield _ping_pong_scenario( - 'python_protobuf_async_unary_5000rpcs_1KB_psm', - rpc_type='UNARY', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', + "python_protobuf_async_unary_5000rpcs_1KB_psm", + rpc_type="UNARY", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", req_size=1024, resp_size=1024, outstanding=5000, @@ -890,87 +978,102 @@ def scenarios(self): num_clients=1, secure=False, async_server_threads=1, - categories=[PSM]) - - yield _ping_pong_scenario('python_generic_sync_streaming_ping_pong', - rpc_type='STREAMING', - client_type='SYNC_CLIENT', - server_type='ASYNC_GENERIC_SERVER', - use_generic_payload=True, - categories=[SMOKETEST, SCALABLE]) - - yield _ping_pong_scenario('python_protobuf_sync_streaming_ping_pong', - rpc_type='STREAMING', - client_type='SYNC_CLIENT', - server_type='ASYNC_SERVER') - - yield _ping_pong_scenario('python_protobuf_async_unary_ping_pong', - rpc_type='UNARY', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER') - - yield _ping_pong_scenario('python_protobuf_sync_unary_ping_pong', - rpc_type='UNARY', - client_type='SYNC_CLIENT', - 
server_type='ASYNC_SERVER', - categories=[SMOKETEST, SCALABLE]) - - yield _ping_pong_scenario( - 'python_protobuf_sync_unary_qps_unconstrained', - rpc_type='UNARY', - client_type='SYNC_CLIENT', - server_type='ASYNC_SERVER', - unconstrained_client='sync') - - yield _ping_pong_scenario( - 'python_protobuf_sync_streaming_qps_unconstrained', - rpc_type='STREAMING', - client_type='SYNC_CLIENT', - server_type='ASYNC_SERVER', - unconstrained_client='sync') - - yield _ping_pong_scenario('python_to_cpp_protobuf_sync_unary_ping_pong', - rpc_type='UNARY', - client_type='SYNC_CLIENT', - server_type='ASYNC_SERVER', - server_language='c++', - async_server_threads=0, - categories=[SMOKETEST, SCALABLE]) - - yield _ping_pong_scenario( - 'python_to_cpp_protobuf_sync_streaming_ping_pong', - rpc_type='STREAMING', - client_type='SYNC_CLIENT', - server_type='ASYNC_SERVER', - server_language='c++', - async_server_threads=1) - - yield _ping_pong_scenario('python_protobuf_sync_unary_ping_pong_1MB', - rpc_type='UNARY', - client_type='SYNC_CLIENT', - server_type='ASYNC_SERVER', - req_size=1024 * 1024, - resp_size=1024 * 1024, - categories=[SMOKETEST, SCALABLE]) + categories=[PSM], + ) + + yield _ping_pong_scenario( + "python_generic_sync_streaming_ping_pong", + rpc_type="STREAMING", + client_type="SYNC_CLIENT", + server_type="ASYNC_GENERIC_SERVER", + use_generic_payload=True, + categories=[SMOKETEST, SCALABLE], + ) + + yield _ping_pong_scenario( + "python_protobuf_sync_streaming_ping_pong", + rpc_type="STREAMING", + client_type="SYNC_CLIENT", + server_type="ASYNC_SERVER", + ) + + yield _ping_pong_scenario( + "python_protobuf_async_unary_ping_pong", + rpc_type="UNARY", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", + ) + + yield _ping_pong_scenario( + "python_protobuf_sync_unary_ping_pong", + rpc_type="UNARY", + client_type="SYNC_CLIENT", + server_type="ASYNC_SERVER", + categories=[SMOKETEST, SCALABLE], + ) + + yield _ping_pong_scenario( + "python_protobuf_sync_unary_qps_unconstrained", + rpc_type="UNARY", + client_type="SYNC_CLIENT", + server_type="ASYNC_SERVER", + unconstrained_client="sync", + ) + + yield _ping_pong_scenario( + "python_protobuf_sync_streaming_qps_unconstrained", + rpc_type="STREAMING", + client_type="SYNC_CLIENT", + server_type="ASYNC_SERVER", + unconstrained_client="sync", + ) + + yield _ping_pong_scenario( + "python_to_cpp_protobuf_sync_unary_ping_pong", + rpc_type="UNARY", + client_type="SYNC_CLIENT", + server_type="ASYNC_SERVER", + server_language="c++", + async_server_threads=0, + categories=[SMOKETEST, SCALABLE], + ) + + yield _ping_pong_scenario( + "python_to_cpp_protobuf_sync_streaming_ping_pong", + rpc_type="STREAMING", + client_type="SYNC_CLIENT", + server_type="ASYNC_SERVER", + server_language="c++", + async_server_threads=1, + ) + + yield _ping_pong_scenario( + "python_protobuf_sync_unary_ping_pong_1MB", + rpc_type="UNARY", + client_type="SYNC_CLIENT", + server_type="ASYNC_SERVER", + req_size=1024 * 1024, + resp_size=1024 * 1024, + categories=[SMOKETEST, SCALABLE], + ) def __str__(self): - return 'python' + return "python" class PythonAsyncIOLanguage(Language): - def worker_cmdline(self): - return ['tools/run_tests/performance/run_worker_python_asyncio.sh'] + return ["tools/run_tests/performance/run_worker_python_asyncio.sh"] def worker_port_offset(self): return 1200 def scenarios(self): yield _ping_pong_scenario( - 'python_asyncio_protobuf_async_unary_5000rpcs_1KB_psm', - rpc_type='UNARY', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', + 
"python_asyncio_protobuf_async_unary_5000rpcs_1KB_psm", + rpc_type="UNARY", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", req_size=1024, resp_size=1024, outstanding=5000, @@ -978,207 +1081,230 @@ def scenarios(self): num_clients=1, secure=False, async_server_threads=1, - categories=[PSM]) + categories=[PSM], + ) for outstanding in [64, 128, 256, 512]: for channels in [1, 4]: yield _ping_pong_scenario( - 'python_asyncio_protobuf_async_unary_ping_pong_%dx%d_max' % - ( + "python_asyncio_protobuf_async_unary_ping_pong_%dx%d_max" + % ( outstanding, channels, ), - rpc_type='UNARY', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', + rpc_type="UNARY", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", outstanding=outstanding * channels, channels=channels, client_processes=0, server_processes=0, - unconstrained_client='async', - categories=[SCALABLE]) + unconstrained_client="async", + categories=[SCALABLE], + ) yield _ping_pong_scenario( - 'python_asyncio_protobuf_async_unary_ping_pong_%d_1thread' % - outstanding, - rpc_type='UNARY', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', + "python_asyncio_protobuf_async_unary_ping_pong_%d_1thread" + % outstanding, + rpc_type="UNARY", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", outstanding=outstanding, channels=1, client_processes=1, server_processes=1, - unconstrained_client='async', - categories=[SCALABLE]) + unconstrained_client="async", + categories=[SCALABLE], + ) yield _ping_pong_scenario( - 'python_asyncio_generic_async_streaming_ping_pong', - rpc_type='STREAMING', - client_type='ASYNC_CLIENT', - server_type='ASYNC_GENERIC_SERVER', + "python_asyncio_generic_async_streaming_ping_pong", + rpc_type="STREAMING", + client_type="ASYNC_CLIENT", + server_type="ASYNC_GENERIC_SERVER", channels=1, client_processes=1, server_processes=1, use_generic_payload=True, - categories=[SMOKETEST, SCALABLE]) + categories=[SMOKETEST, SCALABLE], + ) yield _ping_pong_scenario( - 'python_asyncio_protobuf_async_streaming_ping_pong', - rpc_type='STREAMING', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', + "python_asyncio_protobuf_async_streaming_ping_pong", + rpc_type="STREAMING", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", channels=1, client_processes=1, server_processes=1, - categories=[SMOKETEST, SCALABLE]) + categories=[SMOKETEST, SCALABLE], + ) yield _ping_pong_scenario( - 'python_asyncio_protobuf_async_unary_ping_pong', - rpc_type='UNARY', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', + "python_asyncio_protobuf_async_unary_ping_pong", + rpc_type="UNARY", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", client_processes=1, server_processes=1, - categories=[SMOKETEST, SCALABLE]) + categories=[SMOKETEST, SCALABLE], + ) yield _ping_pong_scenario( - 'python_asyncio_protobuf_async_unary_ping_pong', - rpc_type='UNARY', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', + "python_asyncio_protobuf_async_unary_ping_pong", + rpc_type="UNARY", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", channels=1, client_processes=1, server_processes=1, - categories=[SMOKETEST, SCALABLE]) + categories=[SMOKETEST, SCALABLE], + ) yield _ping_pong_scenario( - 'python_asyncio_protobuf_async_unary_qps_unconstrained', - rpc_type='UNARY', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', + "python_asyncio_protobuf_async_unary_qps_unconstrained", + rpc_type="UNARY", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", channels=1, - 
unconstrained_client='async') + unconstrained_client="async", + ) yield _ping_pong_scenario( - 'python_asyncio_protobuf_async_streaming_qps_unconstrained', - rpc_type='STREAMING', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', + "python_asyncio_protobuf_async_streaming_qps_unconstrained", + rpc_type="STREAMING", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", channels=1, - unconstrained_client='async') + unconstrained_client="async", + ) yield _ping_pong_scenario( - 'python_asyncio_to_cpp_protobuf_async_unary_ping_pong_1thread', - rpc_type='UNARY', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', - server_language='c++', + "python_asyncio_to_cpp_protobuf_async_unary_ping_pong_1thread", + rpc_type="UNARY", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", + server_language="c++", channels=1, client_processes=1, - unconstrained_client='async', - categories=[SMOKETEST, SCALABLE]) + unconstrained_client="async", + categories=[SMOKETEST, SCALABLE], + ) yield _ping_pong_scenario( - 'python_asyncio_to_cpp_protobuf_async_unary_ping_pong_max', - rpc_type='UNARY', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', - unconstrained_client='async', + "python_asyncio_to_cpp_protobuf_async_unary_ping_pong_max", + rpc_type="UNARY", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", + unconstrained_client="async", channels=1, client_processes=0, - server_language='c++', - categories=[SMOKETEST, SCALABLE]) + server_language="c++", + categories=[SMOKETEST, SCALABLE], + ) yield _ping_pong_scenario( - 'python_asyncio_to_cpp_protobuf_sync_streaming_ping_pong_1thread', - rpc_type='STREAMING', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', + "python_asyncio_to_cpp_protobuf_sync_streaming_ping_pong_1thread", + rpc_type="STREAMING", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", channels=1, client_processes=1, server_processes=1, - unconstrained_client='async', - server_language='c++') + unconstrained_client="async", + server_language="c++", + ) yield _ping_pong_scenario( - 'python_asyncio_protobuf_async_unary_ping_pong_1MB', - rpc_type='UNARY', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', + "python_asyncio_protobuf_async_unary_ping_pong_1MB", + rpc_type="UNARY", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", req_size=1024 * 1024, resp_size=1024 * 1024, channels=1, client_processes=1, server_processes=1, - categories=[SMOKETEST, SCALABLE]) + categories=[SMOKETEST, SCALABLE], + ) def __str__(self): - return 'python_asyncio' + return "python_asyncio" class RubyLanguage(Language): - def worker_cmdline(self): - return ['tools/run_tests/performance/run_worker_ruby.sh'] + return ["tools/run_tests/performance/run_worker_ruby.sh"] def worker_port_offset(self): return 300 def scenarios(self): - yield _ping_pong_scenario('ruby_protobuf_sync_streaming_ping_pong', - rpc_type='STREAMING', - client_type='SYNC_CLIENT', - server_type='SYNC_SERVER', - categories=[SMOKETEST, SCALABLE]) - - yield _ping_pong_scenario('ruby_protobuf_unary_ping_pong', - rpc_type='UNARY', - client_type='SYNC_CLIENT', - server_type='SYNC_SERVER', - categories=[SMOKETEST, SCALABLE]) - - yield _ping_pong_scenario('ruby_protobuf_sync_unary_qps_unconstrained', - rpc_type='UNARY', - client_type='SYNC_CLIENT', - server_type='SYNC_SERVER', - unconstrained_client='sync') - - yield _ping_pong_scenario( - 'ruby_protobuf_sync_streaming_qps_unconstrained', - rpc_type='STREAMING', - client_type='SYNC_CLIENT', - server_type='SYNC_SERVER', - 
unconstrained_client='sync') - - yield _ping_pong_scenario('ruby_to_cpp_protobuf_sync_unary_ping_pong', - rpc_type='UNARY', - client_type='SYNC_CLIENT', - server_type='SYNC_SERVER', - server_language='c++', - async_server_threads=1) - - yield _ping_pong_scenario( - 'ruby_to_cpp_protobuf_sync_streaming_ping_pong', - rpc_type='STREAMING', - client_type='SYNC_CLIENT', - server_type='SYNC_SERVER', - server_language='c++', - async_server_threads=1) - - yield _ping_pong_scenario('ruby_protobuf_unary_ping_pong_1MB', - rpc_type='UNARY', - client_type='SYNC_CLIENT', - server_type='SYNC_SERVER', - req_size=1024 * 1024, - resp_size=1024 * 1024, - categories=[SMOKETEST, SCALABLE]) + yield _ping_pong_scenario( + "ruby_protobuf_sync_streaming_ping_pong", + rpc_type="STREAMING", + client_type="SYNC_CLIENT", + server_type="SYNC_SERVER", + categories=[SMOKETEST, SCALABLE], + ) + + yield _ping_pong_scenario( + "ruby_protobuf_unary_ping_pong", + rpc_type="UNARY", + client_type="SYNC_CLIENT", + server_type="SYNC_SERVER", + categories=[SMOKETEST, SCALABLE], + ) + + yield _ping_pong_scenario( + "ruby_protobuf_sync_unary_qps_unconstrained", + rpc_type="UNARY", + client_type="SYNC_CLIENT", + server_type="SYNC_SERVER", + unconstrained_client="sync", + ) + + yield _ping_pong_scenario( + "ruby_protobuf_sync_streaming_qps_unconstrained", + rpc_type="STREAMING", + client_type="SYNC_CLIENT", + server_type="SYNC_SERVER", + unconstrained_client="sync", + ) + + yield _ping_pong_scenario( + "ruby_to_cpp_protobuf_sync_unary_ping_pong", + rpc_type="UNARY", + client_type="SYNC_CLIENT", + server_type="SYNC_SERVER", + server_language="c++", + async_server_threads=1, + ) + + yield _ping_pong_scenario( + "ruby_to_cpp_protobuf_sync_streaming_ping_pong", + rpc_type="STREAMING", + client_type="SYNC_CLIENT", + server_type="SYNC_SERVER", + server_language="c++", + async_server_threads=1, + ) + + yield _ping_pong_scenario( + "ruby_protobuf_unary_ping_pong_1MB", + rpc_type="UNARY", + client_type="SYNC_CLIENT", + server_type="SYNC_SERVER", + req_size=1024 * 1024, + resp_size=1024 * 1024, + categories=[SMOKETEST, SCALABLE], + ) def __str__(self): - return 'ruby' + return "ruby" class Php7Language(Language): - def __init__(self, php7_protobuf_c=False): super().__init__() self.php7_protobuf_c = php7_protobuf_c @@ -1186,10 +1312,10 @@ def __init__(self, php7_protobuf_c=False): def worker_cmdline(self): if self.php7_protobuf_c: return [ - 'tools/run_tests/performance/run_worker_php.sh', - '--use_protobuf_c_extension' + "tools/run_tests/performance/run_worker_php.sh", + "--use_protobuf_c_extension", ] - return ['tools/run_tests/performance/run_worker_php.sh'] + return ["tools/run_tests/performance/run_worker_php.sh"] def worker_port_offset(self): if self.php7_protobuf_c: @@ -1197,17 +1323,17 @@ def worker_port_offset(self): return 800 def scenarios(self): - php7_extension_mode = 'php7_protobuf_php_extension' + php7_extension_mode = "php7_protobuf_php_extension" if self.php7_protobuf_c: - php7_extension_mode = 'php7_protobuf_c_extension' + php7_extension_mode = "php7_protobuf_c_extension" yield _ping_pong_scenario( - '%s_to_cpp_protobuf_async_unary_5000rpcs_1KB_psm' % - php7_extension_mode, - rpc_type='UNARY', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', - server_language='c++', + "%s_to_cpp_protobuf_async_unary_5000rpcs_1KB_psm" + % php7_extension_mode, + rpc_type="UNARY", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", + server_language="c++", req_size=1024, resp_size=1024, outstanding=5000, @@ -1215,273 +1341,294 
@@ def scenarios(self): num_clients=1, secure=False, async_server_threads=1, - categories=[PSM]) + categories=[PSM], + ) - yield _ping_pong_scenario('%s_to_cpp_protobuf_sync_unary_ping_pong' % - php7_extension_mode, - rpc_type='UNARY', - client_type='SYNC_CLIENT', - server_type='SYNC_SERVER', - server_language='c++', - async_server_threads=1) + yield _ping_pong_scenario( + "%s_to_cpp_protobuf_sync_unary_ping_pong" % php7_extension_mode, + rpc_type="UNARY", + client_type="SYNC_CLIENT", + server_type="SYNC_SERVER", + server_language="c++", + async_server_threads=1, + ) yield _ping_pong_scenario( - '%s_to_cpp_protobuf_sync_streaming_ping_pong' % php7_extension_mode, - rpc_type='STREAMING', - client_type='SYNC_CLIENT', - server_type='SYNC_SERVER', - server_language='c++', - async_server_threads=1) + "%s_to_cpp_protobuf_sync_streaming_ping_pong" % php7_extension_mode, + rpc_type="STREAMING", + client_type="SYNC_CLIENT", + server_type="SYNC_SERVER", + server_language="c++", + async_server_threads=1, + ) # TODO(ddyihai): Investigate why when async_server_threads=1/CPU usage 340%, the QPS performs # better than async_server_threads=0/CPU usage 490%. yield _ping_pong_scenario( - '%s_to_cpp_protobuf_sync_unary_qps_unconstrained' % - php7_extension_mode, - rpc_type='UNARY', - client_type='SYNC_CLIENT', - server_type='ASYNC_SERVER', - server_language='c++', + "%s_to_cpp_protobuf_sync_unary_qps_unconstrained" + % php7_extension_mode, + rpc_type="UNARY", + client_type="SYNC_CLIENT", + server_type="ASYNC_SERVER", + server_language="c++", outstanding=1, async_server_threads=1, - unconstrained_client='sync') + unconstrained_client="sync", + ) yield _ping_pong_scenario( - '%s_to_cpp_protobuf_sync_streaming_qps_unconstrained' % - php7_extension_mode, - rpc_type='STREAMING', - client_type='SYNC_CLIENT', - server_type='ASYNC_SERVER', - server_language='c++', + "%s_to_cpp_protobuf_sync_streaming_qps_unconstrained" + % php7_extension_mode, + rpc_type="STREAMING", + client_type="SYNC_CLIENT", + server_type="ASYNC_SERVER", + server_language="c++", outstanding=1, async_server_threads=1, - unconstrained_client='sync') + unconstrained_client="sync", + ) def __str__(self): if self.php7_protobuf_c: - return 'php7_protobuf_c' - return 'php7' + return "php7_protobuf_c" + return "php7" class JavaLanguage(Language): - def worker_cmdline(self): - return ['tools/run_tests/performance/run_worker_java.sh'] + return ["tools/run_tests/performance/run_worker_java.sh"] def worker_port_offset(self): return 400 def scenarios(self): - yield _ping_pong_scenario('java_protobuf_async_unary_5000rpcs_1KB_psm', - rpc_type='UNARY', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', - req_size=1024, - resp_size=1024, - outstanding=5000, - channels=1, - num_clients=1, - secure=False, - async_server_threads=1, - warmup_seconds=JAVA_WARMUP_SECONDS, - categories=[PSM]) + yield _ping_pong_scenario( + "java_protobuf_async_unary_5000rpcs_1KB_psm", + rpc_type="UNARY", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", + req_size=1024, + resp_size=1024, + outstanding=5000, + channels=1, + num_clients=1, + secure=False, + async_server_threads=1, + warmup_seconds=JAVA_WARMUP_SECONDS, + categories=[PSM], + ) for secure in [True, False]: - secstr = 'secure' if secure else 'insecure' + secstr = "secure" if secure else "insecure" smoketest_categories = ([SMOKETEST] if secure else []) + [SCALABLE] yield _ping_pong_scenario( - 'java_generic_async_streaming_ping_pong_%s' % secstr, - rpc_type='STREAMING', - client_type='ASYNC_CLIENT', - 
server_type='ASYNC_GENERIC_SERVER', + "java_generic_async_streaming_ping_pong_%s" % secstr, + rpc_type="STREAMING", + client_type="ASYNC_CLIENT", + server_type="ASYNC_GENERIC_SERVER", use_generic_payload=True, async_server_threads=1, secure=secure, warmup_seconds=JAVA_WARMUP_SECONDS, - categories=smoketest_categories) + categories=smoketest_categories, + ) yield _ping_pong_scenario( - 'java_protobuf_async_streaming_ping_pong_%s' % secstr, - rpc_type='STREAMING', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', + "java_protobuf_async_streaming_ping_pong_%s" % secstr, + rpc_type="STREAMING", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", async_server_threads=1, secure=secure, - warmup_seconds=JAVA_WARMUP_SECONDS) - - yield _ping_pong_scenario('java_protobuf_async_unary_ping_pong_%s' % - secstr, - rpc_type='UNARY', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', - async_server_threads=1, - secure=secure, - warmup_seconds=JAVA_WARMUP_SECONDS, - categories=smoketest_categories) - - yield _ping_pong_scenario('java_protobuf_unary_ping_pong_%s' % - secstr, - rpc_type='UNARY', - client_type='SYNC_CLIENT', - server_type='SYNC_SERVER', - async_server_threads=1, - secure=secure, - warmup_seconds=JAVA_WARMUP_SECONDS) + warmup_seconds=JAVA_WARMUP_SECONDS, + ) + + yield _ping_pong_scenario( + "java_protobuf_async_unary_ping_pong_%s" % secstr, + rpc_type="UNARY", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", + async_server_threads=1, + secure=secure, + warmup_seconds=JAVA_WARMUP_SECONDS, + categories=smoketest_categories, + ) + + yield _ping_pong_scenario( + "java_protobuf_unary_ping_pong_%s" % secstr, + rpc_type="UNARY", + client_type="SYNC_CLIENT", + server_type="SYNC_SERVER", + async_server_threads=1, + secure=secure, + warmup_seconds=JAVA_WARMUP_SECONDS, + ) yield _ping_pong_scenario( - 'java_protobuf_async_unary_qps_unconstrained_%s' % secstr, - rpc_type='UNARY', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', - unconstrained_client='async', + "java_protobuf_async_unary_qps_unconstrained_%s" % secstr, + rpc_type="UNARY", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", + unconstrained_client="async", secure=secure, warmup_seconds=JAVA_WARMUP_SECONDS, - categories=smoketest_categories + [SCALABLE]) + categories=smoketest_categories + [SCALABLE], + ) yield _ping_pong_scenario( - 'java_protobuf_async_streaming_qps_unconstrained_%s' % secstr, - rpc_type='STREAMING', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', - unconstrained_client='async', + "java_protobuf_async_streaming_qps_unconstrained_%s" % secstr, + rpc_type="STREAMING", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", + unconstrained_client="async", secure=secure, warmup_seconds=JAVA_WARMUP_SECONDS, - categories=[SCALABLE]) + categories=[SCALABLE], + ) yield _ping_pong_scenario( - 'java_generic_async_streaming_qps_unconstrained_%s' % secstr, - rpc_type='STREAMING', - client_type='ASYNC_CLIENT', - server_type='ASYNC_GENERIC_SERVER', - unconstrained_client='async', + "java_generic_async_streaming_qps_unconstrained_%s" % secstr, + rpc_type="STREAMING", + client_type="ASYNC_CLIENT", + server_type="ASYNC_GENERIC_SERVER", + unconstrained_client="async", use_generic_payload=True, secure=secure, warmup_seconds=JAVA_WARMUP_SECONDS, - categories=[SCALABLE]) + categories=[SCALABLE], + ) yield _ping_pong_scenario( - 'java_generic_async_streaming_qps_one_server_core_%s' % secstr, - rpc_type='STREAMING', - client_type='ASYNC_CLIENT', - 
server_type='ASYNC_GENERIC_SERVER', - unconstrained_client='async-limited', + "java_generic_async_streaming_qps_one_server_core_%s" % secstr, + rpc_type="STREAMING", + client_type="ASYNC_CLIENT", + server_type="ASYNC_GENERIC_SERVER", + unconstrained_client="async-limited", use_generic_payload=True, async_server_threads=1, secure=secure, - warmup_seconds=JAVA_WARMUP_SECONDS) + warmup_seconds=JAVA_WARMUP_SECONDS, + ) # TODO(jtattermusch): add scenarios java vs C++ def __str__(self): - return 'java' + return "java" class GoLanguage(Language): - def worker_cmdline(self): - return ['tools/run_tests/performance/run_worker_go.sh'] + return ["tools/run_tests/performance/run_worker_go.sh"] def worker_port_offset(self): return 600 def scenarios(self): - yield _ping_pong_scenario('go_protobuf_async_unary_5000rpcs_1KB_psm', - rpc_type='UNARY', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', - req_size=1024, - resp_size=1024, - outstanding=5000, - channels=1, - num_clients=1, - secure=False, - async_server_threads=1, - categories=[PSM]) + yield _ping_pong_scenario( + "go_protobuf_async_unary_5000rpcs_1KB_psm", + rpc_type="UNARY", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", + req_size=1024, + resp_size=1024, + outstanding=5000, + channels=1, + num_clients=1, + secure=False, + async_server_threads=1, + categories=[PSM], + ) for secure in [True, False]: - secstr = 'secure' if secure else 'insecure' + secstr = "secure" if secure else "insecure" smoketest_categories = ([SMOKETEST] if secure else []) + [SCALABLE] # ASYNC_GENERIC_SERVER for Go actually uses a sync streaming server, # but that's mostly because of lack of better name of the enum value. - yield _ping_pong_scenario('go_generic_sync_streaming_ping_pong_%s' % - secstr, - rpc_type='STREAMING', - client_type='SYNC_CLIENT', - server_type='ASYNC_GENERIC_SERVER', - use_generic_payload=True, - async_server_threads=1, - secure=secure, - categories=smoketest_categories) + yield _ping_pong_scenario( + "go_generic_sync_streaming_ping_pong_%s" % secstr, + rpc_type="STREAMING", + client_type="SYNC_CLIENT", + server_type="ASYNC_GENERIC_SERVER", + use_generic_payload=True, + async_server_threads=1, + secure=secure, + categories=smoketest_categories, + ) yield _ping_pong_scenario( - 'go_protobuf_sync_streaming_ping_pong_%s' % secstr, - rpc_type='STREAMING', - client_type='SYNC_CLIENT', - server_type='SYNC_SERVER', + "go_protobuf_sync_streaming_ping_pong_%s" % secstr, + rpc_type="STREAMING", + client_type="SYNC_CLIENT", + server_type="SYNC_SERVER", async_server_threads=1, - secure=secure) + secure=secure, + ) - yield _ping_pong_scenario('go_protobuf_sync_unary_ping_pong_%s' % - secstr, - rpc_type='UNARY', - client_type='SYNC_CLIENT', - server_type='SYNC_SERVER', - async_server_threads=1, - secure=secure, - categories=smoketest_categories) + yield _ping_pong_scenario( + "go_protobuf_sync_unary_ping_pong_%s" % secstr, + rpc_type="UNARY", + client_type="SYNC_CLIENT", + server_type="SYNC_SERVER", + async_server_threads=1, + secure=secure, + categories=smoketest_categories, + ) # unconstrained_client='async' is intended (client uses goroutines) yield _ping_pong_scenario( - 'go_protobuf_sync_unary_qps_unconstrained_%s' % secstr, - rpc_type='UNARY', - client_type='SYNC_CLIENT', - server_type='SYNC_SERVER', - unconstrained_client='async', + "go_protobuf_sync_unary_qps_unconstrained_%s" % secstr, + rpc_type="UNARY", + client_type="SYNC_CLIENT", + server_type="SYNC_SERVER", + unconstrained_client="async", secure=secure, - 
categories=smoketest_categories + [SCALABLE]) + categories=smoketest_categories + [SCALABLE], + ) # unconstrained_client='async' is intended (client uses goroutines) yield _ping_pong_scenario( - 'go_protobuf_sync_streaming_qps_unconstrained_%s' % secstr, - rpc_type='STREAMING', - client_type='SYNC_CLIENT', - server_type='SYNC_SERVER', - unconstrained_client='async', + "go_protobuf_sync_streaming_qps_unconstrained_%s" % secstr, + rpc_type="STREAMING", + client_type="SYNC_CLIENT", + server_type="SYNC_SERVER", + unconstrained_client="async", secure=secure, - categories=[SCALABLE]) + categories=[SCALABLE], + ) # unconstrained_client='async' is intended (client uses goroutines) # ASYNC_GENERIC_SERVER for Go actually uses a sync streaming server, # but that's mostly because of lack of better name of the enum value. yield _ping_pong_scenario( - 'go_generic_sync_streaming_qps_unconstrained_%s' % secstr, - rpc_type='STREAMING', - client_type='SYNC_CLIENT', - server_type='ASYNC_GENERIC_SERVER', - unconstrained_client='async', + "go_generic_sync_streaming_qps_unconstrained_%s" % secstr, + rpc_type="STREAMING", + client_type="SYNC_CLIENT", + server_type="ASYNC_GENERIC_SERVER", + unconstrained_client="async", use_generic_payload=True, secure=secure, - categories=[SCALABLE]) + categories=[SCALABLE], + ) # TODO(jtattermusch): add scenarios go vs C++ def __str__(self): - return 'go' + return "go" class NodeLanguage(Language): - def __init__(self, node_purejs=False): super().__init__() self.node_purejs = node_purejs def worker_cmdline(self): - fixture = 'native_js' if self.node_purejs else 'native_native' + fixture = "native_js" if self.node_purejs else "native_native" return [ - 'tools/run_tests/performance/run_worker_node.sh', fixture, - '--benchmark_impl=grpc' + "tools/run_tests/performance/run_worker_node.sh", + fixture, + "--benchmark_impl=grpc", ] def worker_port_offset(self): @@ -1490,15 +1637,15 @@ def worker_port_offset(self): return 1000 def scenarios(self): - node_implementation = 'node_purejs' if self.node_purejs else 'node' + node_implementation = "node_purejs" if self.node_purejs else "node" yield _ping_pong_scenario( - '%s_to_node_protobuf_async_unary_5000rpcs_1KB_psm' % - (node_implementation), - rpc_type='UNARY', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', - server_language='node', + "%s_to_node_protobuf_async_unary_5000rpcs_1KB_psm" + % (node_implementation), + rpc_type="UNARY", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", + server_language="node", req_size=1024, resp_size=1024, outstanding=5000, @@ -1506,98 +1653,105 @@ def scenarios(self): num_clients=1, secure=False, async_server_threads=1, - categories=[PSM]) + categories=[PSM], + ) for secure in [True, False]: - secstr = 'secure' if secure else 'insecure' + secstr = "secure" if secure else "insecure" smoketest_categories = ([SMOKETEST] if secure else []) + [SCALABLE] yield _ping_pong_scenario( - '%s_to_node_generic_async_streaming_ping_pong_%s' % - (node_implementation, secstr), - rpc_type='STREAMING', - client_type='ASYNC_CLIENT', - server_type='ASYNC_GENERIC_SERVER', - server_language='node', + "%s_to_node_generic_async_streaming_ping_pong_%s" + % (node_implementation, secstr), + rpc_type="STREAMING", + client_type="ASYNC_CLIENT", + server_type="ASYNC_GENERIC_SERVER", + server_language="node", use_generic_payload=True, async_server_threads=1, secure=secure, - categories=smoketest_categories) + categories=smoketest_categories, + ) yield _ping_pong_scenario( - 
'%s_to_node_protobuf_async_streaming_ping_pong_%s' % - (node_implementation, secstr), - rpc_type='STREAMING', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', - server_language='node', + "%s_to_node_protobuf_async_streaming_ping_pong_%s" + % (node_implementation, secstr), + rpc_type="STREAMING", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", + server_language="node", async_server_threads=1, - secure=secure) + secure=secure, + ) yield _ping_pong_scenario( - '%s_to_node_protobuf_async_unary_ping_pong_%s' % - (node_implementation, secstr), - rpc_type='UNARY', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', - server_language='node', + "%s_to_node_protobuf_async_unary_ping_pong_%s" + % (node_implementation, secstr), + rpc_type="UNARY", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", + server_language="node", async_server_threads=1, secure=secure, - categories=smoketest_categories) + categories=smoketest_categories, + ) yield _ping_pong_scenario( - '%s_to_node_protobuf_async_unary_qps_unconstrained_%s' % - (node_implementation, secstr), - rpc_type='UNARY', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', - server_language='node', - unconstrained_client='async', + "%s_to_node_protobuf_async_unary_qps_unconstrained_%s" + % (node_implementation, secstr), + rpc_type="UNARY", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", + server_language="node", + unconstrained_client="async", secure=secure, - categories=smoketest_categories + [SCALABLE]) + categories=smoketest_categories + [SCALABLE], + ) yield _ping_pong_scenario( - '%s_to_node_protobuf_async_streaming_qps_unconstrained_%s' % - (node_implementation, secstr), - rpc_type='STREAMING', - client_type='ASYNC_CLIENT', - server_type='ASYNC_SERVER', - server_language='node', - unconstrained_client='async', + "%s_to_node_protobuf_async_streaming_qps_unconstrained_%s" + % (node_implementation, secstr), + rpc_type="STREAMING", + client_type="ASYNC_CLIENT", + server_type="ASYNC_SERVER", + server_language="node", + unconstrained_client="async", secure=secure, - categories=[SCALABLE]) + categories=[SCALABLE], + ) yield _ping_pong_scenario( - '%s_to_node_generic_async_streaming_qps_unconstrained_%s' % - (node_implementation, secstr), - rpc_type='STREAMING', - client_type='ASYNC_CLIENT', - server_type='ASYNC_GENERIC_SERVER', - server_language='node', - unconstrained_client='async', + "%s_to_node_generic_async_streaming_qps_unconstrained_%s" + % (node_implementation, secstr), + rpc_type="STREAMING", + client_type="ASYNC_CLIENT", + server_type="ASYNC_GENERIC_SERVER", + server_language="node", + unconstrained_client="async", use_generic_payload=True, secure=secure, - categories=[SCALABLE]) + categories=[SCALABLE], + ) # TODO(murgatroid99): add scenarios node vs C++ def __str__(self): if self.node_purejs: - return 'node_purejs' - return 'node' + return "node_purejs" + return "node" LANGUAGES = { - 'c++': CXXLanguage(), - 'csharp': CSharpLanguage(), - 'dotnet': DotnetLanguage(), - 'ruby': RubyLanguage(), - 'php7': Php7Language(), - 'php7_protobuf_c': Php7Language(php7_protobuf_c=True), - 'java': JavaLanguage(), - 'python': PythonLanguage(), - 'python_asyncio': PythonAsyncIOLanguage(), - 'go': GoLanguage(), - 'node': NodeLanguage(), - 'node_purejs': NodeLanguage(node_purejs=True) + "c++": CXXLanguage(), + "csharp": CSharpLanguage(), + "dotnet": DotnetLanguage(), + "ruby": RubyLanguage(), + "php7": Php7Language(), + "php7_protobuf_c": Php7Language(php7_protobuf_c=True), + "java": JavaLanguage(), + 
"python": PythonLanguage(), + "python_asyncio": PythonAsyncIOLanguage(), + "go": GoLanguage(), + "node": NodeLanguage(), + "node_purejs": NodeLanguage(node_purejs=True), } diff --git a/tools/run_tests/performance/scenario_config_exporter.py b/tools/run_tests/performance/scenario_config_exporter.py index a5becefc06bba..6ce831678358b 100755 --- a/tools/run_tests/performance/scenario_config_exporter.py +++ b/tools/run_tests/performance/scenario_config_exporter.py @@ -48,47 +48,54 @@ # Language parameters for load test config generation. -LanguageConfig = NamedTuple('LanguageConfig', [('category', str), - ('language', str), - ('client_language', str), - ('server_language', str)]) +LanguageConfig = NamedTuple( + "LanguageConfig", + [ + ("category", str), + ("language", str), + ("client_language", str), + ("server_language", str), + ], +) def category_string(categories: Iterable[str], category: str) -> str: """Converts a list of categories into a single string for counting.""" - if category != 'all': - return category if category in categories else '' + if category != "all": + return category if category in categories else "" - main_categories = ('scalable', 'smoketest') + main_categories = ("scalable", "smoketest") s = set(categories) c = [m for m in main_categories if m in s] s.difference_update(main_categories) c.extend(s) - return ' '.join(c) + return " ".join(c) def gen_scenario_languages(category: str) -> Iterable[LanguageConfig]: """Generates tuples containing the languages specified in each scenario.""" for language in scenario_config.LANGUAGES: for scenario in scenario_config.LANGUAGES[language].scenarios(): - client_language = scenario.get('CLIENT_LANGUAGE', '') - server_language = scenario.get('SERVER_LANGUAGE', '') - categories = scenario.get('CATEGORIES', []) - if category != 'all' and category not in categories: + client_language = scenario.get("CLIENT_LANGUAGE", "") + server_language = scenario.get("SERVER_LANGUAGE", "") + categories = scenario.get("CATEGORIES", []) + if category != "all" and category not in categories: continue cat = category_string(categories, category) - yield LanguageConfig(category=cat, - language=language, - client_language=client_language, - server_language=server_language) + yield LanguageConfig( + category=cat, + language=language, + client_language=client_language, + server_language=server_language, + ) def scenario_filter( - scenario_name_regex: str = '.*', - category: str = 'all', - client_language: str = '', - server_language: str = '', + scenario_name_regex: str = ".*", + category: str = "all", + client_language: str = "", + server_language: str = "", ) -> Callable[[Dict[str, Any]], bool]: """Returns a function to filter scenarios to process.""" @@ -99,16 +106,17 @@ def filter_scenario(scenario: Dict[str, Any]) -> bool: # if the 'CATEGORIES' key is missing, treat scenario as part of # 'scalable' and 'smoketest'. This matches the behavior of # run_performance_tests.py. 
- scenario_categories = scenario.get('CATEGORIES', - ['scalable', 'smoketest']) - if category not in scenario_categories and category != 'all': + scenario_categories = scenario.get( + "CATEGORIES", ["scalable", "smoketest"] + ) + if category not in scenario_categories and category != "all": return False - scenario_client_language = scenario.get('CLIENT_LANGUAGE', '') + scenario_client_language = scenario.get("CLIENT_LANGUAGE", "") if client_language != scenario_client_language: return False - scenario_server_language = scenario.get('SERVER_LANGUAGE', '') + scenario_server_language = scenario.get("SERVER_LANGUAGE", "") if server_language != scenario_server_language: return False @@ -118,102 +126,128 @@ def filter_scenario(scenario: Dict[str, Any]) -> bool: def gen_scenarios( - language_name: str, scenario_filter_function: Callable[[Dict[str, Any]], - bool] + language_name: str, + scenario_filter_function: Callable[[Dict[str, Any]], bool], ) -> Iterable[Dict[str, Any]]: """Generates scenarios that match a given filter function.""" return map( scenario_config.remove_nonproto_fields, - filter(scenario_filter_function, - scenario_config.LANGUAGES[language_name].scenarios())) + filter( + scenario_filter_function, + scenario_config.LANGUAGES[language_name].scenarios(), + ), + ) -def dump_to_json_files(scenarios: Iterable[Dict[str, Any]], - filename_prefix: str) -> None: +def dump_to_json_files( + scenarios: Iterable[Dict[str, Any]], filename_prefix: str +) -> None: """Dumps a list of scenarios to JSON files""" count = 0 for scenario in scenarios: - filename = '{}{}.json'.format(filename_prefix, scenario['name']) - print('Writing file {}'.format(filename), file=sys.stderr) - with open(filename, 'w') as outfile: + filename = "{}{}.json".format(filename_prefix, scenario["name"]) + print("Writing file {}".format(filename), file=sys.stderr) + with open(filename, "w") as outfile: # The dump file should have {"scenarios" : []} as the top level # element, when embedded in a LoadTest configuration YAML file. 
- json.dump({'scenarios': [scenario]}, outfile, indent=2) + json.dump({"scenarios": [scenario]}, outfile, indent=2) count += 1 - print('Wrote {} scenarios'.format(count), file=sys.stderr) + print("Wrote {} scenarios".format(count), file=sys.stderr) def main() -> None: language_choices = sorted(scenario_config.LANGUAGES.keys()) - argp = argparse.ArgumentParser(description='Exports scenarios to files.') - argp.add_argument('--export_scenarios', - action='store_true', - help='Export scenarios to JSON files.') - argp.add_argument('--count_scenarios', - action='store_true', - help='Count scenarios for all test languages.') - argp.add_argument('-l', - '--language', - choices=language_choices, - help='Language to export.') - argp.add_argument('-f', - '--filename_prefix', - default='scenario_dump_', - type=str, - help='Prefix for exported JSON file names.') - argp.add_argument('-r', - '--regex', - default='.*', - type=str, - help='Regex to select scenarios to run.') + argp = argparse.ArgumentParser(description="Exports scenarios to files.") argp.add_argument( - '--category', - default='all', - choices=['all', 'inproc', 'scalable', 'smoketest', 'sweep'], - help='Select scenarios for a category of tests.') + "--export_scenarios", + action="store_true", + help="Export scenarios to JSON files.", + ) argp.add_argument( - '--client_language', - default='', + "--count_scenarios", + action="store_true", + help="Count scenarios for all test languages.", + ) + argp.add_argument( + "-l", "--language", choices=language_choices, help="Language to export." + ) + argp.add_argument( + "-f", + "--filename_prefix", + default="scenario_dump_", + type=str, + help="Prefix for exported JSON file names.", + ) + argp.add_argument( + "-r", + "--regex", + default=".*", + type=str, + help="Regex to select scenarios to run.", + ) + argp.add_argument( + "--category", + default="all", + choices=["all", "inproc", "scalable", "smoketest", "sweep"], + help="Select scenarios for a category of tests.", + ) + argp.add_argument( + "--client_language", + default="", choices=language_choices, - help='Select only scenarios with a specified client language.') + help="Select only scenarios with a specified client language.", + ) argp.add_argument( - '--server_language', - default='', + "--server_language", + default="", choices=language_choices, - help='Select only scenarios with a specified server language.') + help="Select only scenarios with a specified server language.", + ) args = argp.parse_args() if args.export_scenarios and not args.language: - print('Dumping scenarios requires a specified language.', - file=sys.stderr) + print( + "Dumping scenarios requires a specified language.", file=sys.stderr + ) argp.print_usage(file=sys.stderr) return if args.export_scenarios: - s_filter = scenario_filter(scenario_name_regex=args.regex, - category=args.category, - client_language=args.client_language, - server_language=args.server_language) + s_filter = scenario_filter( + scenario_name_regex=args.regex, + category=args.category, + client_language=args.client_language, + server_language=args.server_language, + ) scenarios = gen_scenarios(args.language, s_filter) dump_to_json_files(scenarios, args.filename_prefix) if args.count_scenarios: - print('Scenario count for all languages (category: {}):'.format( - args.category)) - print('{:>5} {:16} {:8} {:8} {}'.format('Count', 'Language', 'Client', - 'Server', 'Categories')) + print( + "Scenario count for all languages (category: {}):".format( + args.category + ) + ) + print( + "{:>5} {:16} {:8} 
{:8} {}".format( + "Count", "Language", "Client", "Server", "Categories" + ) + ) c = collections.Counter(gen_scenario_languages(args.category)) total = 0 - for ((cat, l, cl, sl), count) in c.most_common(): - print('{count:5} {l:16} {cl:8} {sl:8} {cat}'.format(l=l, - cl=cl, - sl=sl, - count=count, - cat=cat)) + for (cat, l, cl, sl), count in c.most_common(): + print( + "{count:5} {l:16} {cl:8} {sl:8} {cat}".format( + l=l, cl=cl, sl=sl, count=count, cat=cat + ) + ) total += count - print('\n{:>5} total scenarios (category: {})'.format( - total, args.category)) + print( + "\n{:>5} total scenarios (category: {})".format( + total, args.category + ) + ) if __name__ == "__main__": diff --git a/tools/run_tests/python_utils/bazel_report_helper.py b/tools/run_tests/python_utils/bazel_report_helper.py index c62807a585657..2531de7a3c834 100755 --- a/tools/run_tests/python_utils/bazel_report_helper.py +++ b/tools/run_tests/python_utils/bazel_report_helper.py @@ -20,7 +20,7 @@ import sys import uuid -_ROOT = os.path.abspath(os.path.join(os.path.dirname(sys.argv[0]), '../../..')) +_ROOT = os.path.abspath(os.path.join(os.path.dirname(sys.argv[0]), "../../..")) os.chdir(_ROOT) # How long to sleep before querying Resultstore API and uploading to bigquery @@ -31,44 +31,50 @@ def _platform_string(): """Detect current platform""" - if platform.system() == 'Windows': - return 'windows' - elif platform.system()[:7] == 'MSYS_NT': - return 'windows' - elif platform.system() == 'Darwin': - return 'mac' - elif platform.system() == 'Linux': - return 'linux' + if platform.system() == "Windows": + return "windows" + elif platform.system()[:7] == "MSYS_NT": + return "windows" + elif platform.system() == "Darwin": + return "mac" + elif platform.system() == "Linux": + return "linux" else: - return 'posix' + return "posix" def _append_to_kokoro_bazel_invocations(invocation_id: str) -> None: """Kokoro can display "Bazel" result link on kokoro jobs if told so.""" # to get "bazel" link for kokoro build, we need to upload # the "bazel_invocation_ids" file with bazel invocation ID as artifact. - kokoro_artifacts_dir = os.getenv('KOKORO_ARTIFACTS_DIR') + kokoro_artifacts_dir = os.getenv("KOKORO_ARTIFACTS_DIR") if kokoro_artifacts_dir: # append the bazel invocation UUID to the bazel_invocation_ids file. 
-        with open(os.path.join(kokoro_artifacts_dir, 'bazel_invocation_ids'),
-                  'a') as f:
-            f.write(invocation_id + '\n')
+        with open(
+            os.path.join(kokoro_artifacts_dir, "bazel_invocation_ids"), "a"
+        ) as f:
+            f.write(invocation_id + "\n")
         print(
-            'Added invocation ID %s to kokoro "bazel_invocation_ids" artifact' %
-            invocation_id,
-            file=sys.stderr)
+            'Added invocation ID %s to kokoro "bazel_invocation_ids" artifact'
+            % invocation_id,
+            file=sys.stderr,
+        )
     else:
         print(
-            'Skipped adding invocation ID %s to kokoro "bazel_invocation_ids" artifact'
-            % invocation_id,
-            file=sys.stderr)
+            'Skipped adding invocation ID %s to kokoro "bazel_invocation_ids"'
+            " artifact" % invocation_id,
+            file=sys.stderr,
+        )
         pass


-def _generate_junit_report_string(report_suite_name: str, invocation_id: str,
-                                  success: bool) -> None:
+def _generate_junit_report_string(
+    report_suite_name: str, invocation_id: str, success: bool
+) -> None:
     """Generate sponge_log.xml formatted report, that will make the bazel invocation reachable as a target in resultstore UI / sponge."""
-    bazel_invocation_url = 'https://source.cloud.google.com/results/invocations/%s' % invocation_id
+    bazel_invocation_url = (
+        "https://source.cloud.google.com/results/invocations/%s" % invocation_id
+    )
     package_name = report_suite_name
     # set testcase name to invocation URL. That way, the link will be displayed in some form
     # resultstore UI and sponge even in case the bazel invocation succeeds.
@@ -76,174 +82,210 @@ def _generate_junit_report_string(report_suite_name: str, invocation_id: str,
     if success:
         # unfortunately, neither resultstore UI nor sponge display the "system-err" output (or any other tags)
         # on a passing test case. But at least we tried.
-        test_output_tag = '<system-out>PASSED. See invocation results here: %s</system-out>' % bazel_invocation_url
+        test_output_tag = (
+            "<system-out>PASSED. See invocation results here: %s</system-out>"
+            % bazel_invocation_url
+        )
     else:
         # The failure output will be displayes in both resultstore UI and sponge when clicking on the failing testcase.
-        test_output_tag = '<system-err>FAILED. See bazel invocation results here: %s</system-err>' % bazel_invocation_url
+        test_output_tag = (
+            '<system-err>FAILED. See bazel invocation results'
+            " here: %s</system-err>" % bazel_invocation_url
+        )
     lines = [
-        '<testsuites>',
-        '<testsuite id="1" name="%s" package="%s">' %
-        (report_suite_name, package_name),
+        "<testsuites>",
+        '<testsuite id="1" name="%s" package="%s">'
+        % (report_suite_name, package_name),
         '<testcase name="%s">' % testcase_name,
         test_output_tag,
-        '</testcase>'
-        '</testsuite>',
-        '</testsuites>',
+        "</testcase></testsuite>",
+        "</testsuites>",
     ]
-    return '\n'.join(lines)
+    return "\n".join(lines)


-def _create_bazel_wrapper(report_path: str, report_suite_name: str,
-                          invocation_id: str, upload_results: bool) -> None:
+def _create_bazel_wrapper(
+    report_path: str,
+    report_suite_name: str,
+    invocation_id: str,
+    upload_results: bool,
+) -> None:
     """Create a "bazel wrapper" script that will execute bazel with extra settings and postprocessing."""
     os.makedirs(report_path, exist_ok=True)

-    bazel_wrapper_filename = os.path.join(report_path, 'bazel_wrapper')
-    bazel_wrapper_bat_filename = bazel_wrapper_filename + '.bat'
-    bazel_rc_filename = os.path.join(report_path, 'bazel_wrapper.bazelrc')
+    bazel_wrapper_filename = os.path.join(report_path, "bazel_wrapper")
+    bazel_wrapper_bat_filename = bazel_wrapper_filename + ".bat"
+    bazel_rc_filename = os.path.join(report_path, "bazel_wrapper.bazelrc")

     # put xml reports in a separate directory if requested by GRPC_TEST_REPORT_BASE_DIR
-    report_base_dir = os.getenv('GRPC_TEST_REPORT_BASE_DIR', None)
+    report_base_dir = os.getenv("GRPC_TEST_REPORT_BASE_DIR", None)
     xml_report_path = os.path.abspath(
-        os.path.join(report_base_dir, report_path
-                     ) if report_base_dir else report_path)
+        os.path.join(report_base_dir, report_path)
+        if report_base_dir
+        else report_path
+    )
     os.makedirs(xml_report_path, exist_ok=True)
-    failing_report_filename = os.path.join(xml_report_path, 'sponge_log.xml')
-    success_report_filename = os.path.join(xml_report_path,
-                                           'success_log_to_rename.xml')
+    failing_report_filename = os.path.join(xml_report_path, "sponge_log.xml")
+    success_report_filename = os.path.join(
+        xml_report_path, "success_log_to_rename.xml"
+    )

-    if _platform_string() == 'windows':
-        workspace_status_command = 'tools/remote_build/workspace_status_kokoro.bat'
+    if _platform_string() == "windows":
+        workspace_status_command = (
+            "tools/remote_build/workspace_status_kokoro.bat"
+        )
     else:
-        workspace_status_command = 'tools/remote_build/workspace_status_kokoro.sh'
+        workspace_status_command = (
+            "tools/remote_build/workspace_status_kokoro.sh"
+        )

     # generate RC file with the bazel flags we want to use apply.
     # Using an RC file solves problems with flag ordering in the wrapper.
     # (e.g. some flags need to come after the build/test command)
-    with open(bazel_rc_filename, 'w') as f:
+    with open(bazel_rc_filename, "w") as f:
         f.write('build --invocation_id="%s"\n' % invocation_id)
-        f.write('build --workspace_status_command="%s"\n' %
-                workspace_status_command)
+        f.write(
+            'build --workspace_status_command="%s"\n' % workspace_status_command
+        )

     # generate "failing" and "success" report
     # the "failing" is named as "sponge_log.xml", which is the name picked up by sponge/resultstore
     # so the failing report will be used by default (unless we later replace the report with
     # one that says "success"). That way if something goes wrong before bazel is run,
     # there will at least be a "failing" target that indicates that (we really don't want silent failures).
- with open(failing_report_filename, 'w') as f: + with open(failing_report_filename, "w") as f: f.write( - _generate_junit_report_string(report_suite_name, - invocation_id, - success=False)) - with open(success_report_filename, 'w') as f: + _generate_junit_report_string( + report_suite_name, invocation_id, success=False + ) + ) + with open(success_report_filename, "w") as f: f.write( - _generate_junit_report_string(report_suite_name, - invocation_id, - success=True)) + _generate_junit_report_string( + report_suite_name, invocation_id, success=True + ) + ) # generate the bazel wrapper for linux/macos - with open(bazel_wrapper_filename, 'w') as f: + with open(bazel_wrapper_filename, "w") as f: intro_lines = [ - '#!/bin/bash', - 'set -ex', - '', - 'tools/bazel --bazelrc="%s" "$@" || FAILED=true' % - bazel_rc_filename, - '', + "#!/bin/bash", + "set -ex", + "", + 'tools/bazel --bazelrc="%s" "$@" || FAILED=true' + % bazel_rc_filename, + "", ] if upload_results: upload_results_lines = [ - 'sleep %s' % _UPLOAD_RBE_RESULTS_DELAY_SECONDS, - 'PYTHONHTTPSVERIFY=0 python3 ./tools/run_tests/python_utils/upload_rbe_results.py --invocation_id="%s"' - % invocation_id, - '', + "sleep %s" % _UPLOAD_RBE_RESULTS_DELAY_SECONDS, + "PYTHONHTTPSVERIFY=0 python3" + " ./tools/run_tests/python_utils/upload_rbe_results.py" + ' --invocation_id="%s"' % invocation_id, + "", ] else: upload_results_lines = [] outro_lines = [ 'if [ "$FAILED" != "" ]', - 'then', - ' exit 1', - 'else', - ' # success: plant the pre-generated xml report that says "success"', - ' mv -f %s %s' % - (success_report_filename, failing_report_filename), - 'fi', + "then", + " exit 1", + "else", + ( + " # success: plant the pre-generated xml report that says" + ' "success"' + ), + " mv -f %s %s" + % (success_report_filename, failing_report_filename), + "fi", ] lines = [ - line + '\n' + line + "\n" for line in intro_lines + upload_results_lines + outro_lines ] f.writelines(lines) os.chmod(bazel_wrapper_filename, 0o775) # make the unix wrapper executable # generate bazel wrapper for windows - with open(bazel_wrapper_bat_filename, 'w') as f: + with open(bazel_wrapper_bat_filename, "w") as f: intro_lines = [ - '@echo on', - '', + "@echo on", + "", 'bazel --bazelrc="%s" %%*' % bazel_rc_filename, - 'set BAZEL_EXITCODE=%errorlevel%', - '', + "set BAZEL_EXITCODE=%errorlevel%", + "", ] if upload_results: upload_results_lines = [ - 'sleep %s' % _UPLOAD_RBE_RESULTS_DELAY_SECONDS, - 'python3 tools/run_tests/python_utils/upload_rbe_results.py --invocation_id="%s" || exit /b 1' - % invocation_id, - '', + "sleep %s" % _UPLOAD_RBE_RESULTS_DELAY_SECONDS, + "python3 tools/run_tests/python_utils/upload_rbe_results.py" + ' --invocation_id="%s" || exit /b 1' % invocation_id, + "", ] else: upload_results_lines = [] outro_lines = [ - 'if %BAZEL_EXITCODE% == 0 (', - ' @rem success: plant the pre-generated xml report that says "success"', - ' mv -f %s %s' % - (success_report_filename, failing_report_filename), - ')', - 'exit /b %BAZEL_EXITCODE%', + "if %BAZEL_EXITCODE% == 0 (", + ( + " @rem success: plant the pre-generated xml report that says" + ' "success"' + ), + " mv -f %s %s" + % (success_report_filename, failing_report_filename), + ")", + "exit /b %BAZEL_EXITCODE%", ] lines = [ - line + '\n' + line + "\n" for line in intro_lines + upload_results_lines + outro_lines ] f.writelines(lines) - print('Bazel invocation ID: %s' % invocation_id, file=sys.stderr) - print('Upload test results to BigQuery after bazel runs: %s' % - upload_results, - file=sys.stderr) - print('Generated 
bazel wrapper: %s' % bazel_wrapper_filename, - file=sys.stderr) - print('Generated bazel wrapper: %s' % bazel_wrapper_bat_filename, - file=sys.stderr) + print("Bazel invocation ID: %s" % invocation_id, file=sys.stderr) + print( + "Upload test results to BigQuery after bazel runs: %s" % upload_results, + file=sys.stderr, + ) + print( + "Generated bazel wrapper: %s" % bazel_wrapper_filename, file=sys.stderr + ) + print( + "Generated bazel wrapper: %s" % bazel_wrapper_bat_filename, + file=sys.stderr, + ) -if __name__ == '__main__': +if __name__ == "__main__": # parse command line argp = argparse.ArgumentParser( - description= - 'Generate bazel wrapper to help with bazel test reports in CI.') + description=( + "Generate bazel wrapper to help with bazel test reports in CI." + ) + ) argp.add_argument( - '--report_path', + "--report_path", required=True, type=str, - help= - 'Path under which the bazel wrapper and other files are going to be generated' + help=( + "Path under which the bazel wrapper and other files are going to be" + " generated" + ), + ) + argp.add_argument( + "--report_suite_name", + default="bazel_invocations", + type=str, + help="Test suite name to use in generated XML report", ) - argp.add_argument('--report_suite_name', - default='bazel_invocations', - type=str, - help='Test suite name to use in generated XML report') args = argp.parse_args() # generate new bazel invocation ID @@ -251,8 +293,9 @@ def _create_bazel_wrapper(report_path: str, report_suite_name: str, report_path = args.report_path report_suite_name = args.report_suite_name - upload_results = True if os.getenv('UPLOAD_TEST_RESULTS') else False + upload_results = True if os.getenv("UPLOAD_TEST_RESULTS") else False _append_to_kokoro_bazel_invocations(invocation_id) - _create_bazel_wrapper(report_path, report_suite_name, invocation_id, - upload_results) + _create_bazel_wrapper( + report_path, report_suite_name, invocation_id, upload_results + ) diff --git a/tools/run_tests/python_utils/check_on_pr.py b/tools/run_tests/python_utils/check_on_pr.py index 5bf58f940d757..9d20fdd487b16 100644 --- a/tools/run_tests/python_utils/check_on_pr.py +++ b/tools/run_tests/python_utils/check_on_pr.py @@ -24,8 +24,8 @@ import jwt import requests -_GITHUB_API_PREFIX = 'https://api.github.com' -_GITHUB_REPO = 'grpc/grpc' +_GITHUB_API_PREFIX = "https://api.github.com" +_GITHUB_REPO = "grpc/grpc" _GITHUB_APP_ID = 22338 _INSTALLATION_ID = 519109 @@ -34,82 +34,90 @@ _ACCESS_TOKEN_FETCH_RETRIES_INTERVAL_S = 15 _CHANGE_LABELS = { - -1: 'improvement', - 0: 'none', - 1: 'low', - 2: 'medium', - 3: 'high', + -1: "improvement", + 0: "none", + 1: "low", + 2: "medium", + 3: "high", } _INCREASE_DECREASE = { - -1: 'decrease', - 0: 'neutral', - 1: 'increase', + -1: "decrease", + 0: "neutral", + 1: "increase", } def _jwt_token(): github_app_key = open( - os.path.join(os.environ['KOKORO_KEYSTORE_DIR'], - '73836_grpc_checks_private_key'), 'rb').read() + os.path.join( + os.environ["KOKORO_KEYSTORE_DIR"], "73836_grpc_checks_private_key" + ), + "rb", + ).read() return jwt.encode( { - 'iat': int(time.time()), - 'exp': int(time.time() + 60 * 10), # expire in 10 minutes - 'iss': _GITHUB_APP_ID, + "iat": int(time.time()), + "exp": int(time.time() + 60 * 10), # expire in 10 minutes + "iss": _GITHUB_APP_ID, }, github_app_key, - algorithm='RS256') + algorithm="RS256", + ) def _access_token(): global _ACCESS_TOKEN_CACHE - if _ACCESS_TOKEN_CACHE == None or _ACCESS_TOKEN_CACHE['exp'] < time.time(): + if _ACCESS_TOKEN_CACHE == None or _ACCESS_TOKEN_CACHE["exp"] < 
time.time(): for i in range(_ACCESS_TOKEN_FETCH_RETRIES): resp = requests.post( - url='https://api.github.com/app/installations/%s/access_tokens' + url="https://api.github.com/app/installations/%s/access_tokens" % _INSTALLATION_ID, headers={ - 'Authorization': 'Bearer %s' % _jwt_token(), - 'Accept': 'application/vnd.github.machine-man-preview+json', - }) + "Authorization": "Bearer %s" % _jwt_token(), + "Accept": "application/vnd.github.machine-man-preview+json", + }, + ) try: _ACCESS_TOKEN_CACHE = { - 'token': resp.json()['token'], - 'exp': time.time() + 60 + "token": resp.json()["token"], + "exp": time.time() + 60, } break except (KeyError, ValueError): traceback.print_exc() - print('HTTP Status %d %s' % (resp.status_code, resp.reason)) + print("HTTP Status %d %s" % (resp.status_code, resp.reason)) print("Fetch access token from Github API failed:") print(resp.text) if i != _ACCESS_TOKEN_FETCH_RETRIES - 1: - print('Retrying after %.2f second.' % - _ACCESS_TOKEN_FETCH_RETRIES_INTERVAL_S) + print( + "Retrying after %.2f second." + % _ACCESS_TOKEN_FETCH_RETRIES_INTERVAL_S + ) time.sleep(_ACCESS_TOKEN_FETCH_RETRIES_INTERVAL_S) else: print("error: Unable to fetch access token, exiting...") sys.exit(0) - return _ACCESS_TOKEN_CACHE['token'] + return _ACCESS_TOKEN_CACHE["token"] -def _call(url, method='GET', json=None): - if not url.startswith('https://'): +def _call(url, method="GET", json=None): + if not url.startswith("https://"): url = _GITHUB_API_PREFIX + url headers = { - 'Authorization': 'Bearer %s' % _access_token(), - 'Accept': 'application/vnd.github.antiope-preview+json', + "Authorization": "Bearer %s" % _access_token(), + "Accept": "application/vnd.github.antiope-preview+json", } return requests.request(method=method, url=url, headers=headers, json=json) def _latest_commit(): resp = _call( - '/repos/%s/pulls/%s/commits' % - (_GITHUB_REPO, os.environ['KOKORO_GITHUB_PULL_REQUEST_NUMBER'])) + "/repos/%s/pulls/%s/commits" + % (_GITHUB_REPO, os.environ["KOKORO_GITHUB_PULL_REQUEST_NUMBER"]) + ) return resp.json()[-1] @@ -127,38 +135,43 @@ def check_on_pr(name, summary, success=True): summary: A str in Markdown to be used as the detail information of the check. success: A bool indicates whether the check is succeed or not. """ - if 'KOKORO_GIT_COMMIT' not in os.environ: - print('Missing KOKORO_GIT_COMMIT env var: not checking') + if "KOKORO_GIT_COMMIT" not in os.environ: + print("Missing KOKORO_GIT_COMMIT env var: not checking") return - if 'KOKORO_KEYSTORE_DIR' not in os.environ: - print('Missing KOKORO_KEYSTORE_DIR env var: not checking') + if "KOKORO_KEYSTORE_DIR" not in os.environ: + print("Missing KOKORO_KEYSTORE_DIR env var: not checking") return - if 'KOKORO_GITHUB_PULL_REQUEST_NUMBER' not in os.environ: - print('Missing KOKORO_GITHUB_PULL_REQUEST_NUMBER env var: not checking') + if "KOKORO_GITHUB_PULL_REQUEST_NUMBER" not in os.environ: + print("Missing KOKORO_GITHUB_PULL_REQUEST_NUMBER env var: not checking") return MAX_SUMMARY_LEN = 65400 if len(summary) > MAX_SUMMARY_LEN: # Drop some hints to the log should someone come looking for what really happened! - print('Clipping too long summary') + print("Clipping too long summary") print(summary) - summary = summary[:MAX_SUMMARY_LEN] + '\n\n\n... 
CLIPPED (too long)' - completion_time = str( - datetime.datetime.utcnow().replace(microsecond=0).isoformat()) + 'Z' - resp = _call('/repos/%s/check-runs' % _GITHUB_REPO, - method='POST', - json={ - 'name': name, - 'head_sha': os.environ['KOKORO_GIT_COMMIT'], - 'status': 'completed', - 'completed_at': completion_time, - 'conclusion': 'success' if success else 'failure', - 'output': { - 'title': name, - 'summary': summary, - } - }) - print('Result of Creating/Updating Check on PR:', - json.dumps(resp.json(), indent=2)) + summary = summary[:MAX_SUMMARY_LEN] + "\n\n\n... CLIPPED (too long)" + completion_time = ( + str(datetime.datetime.utcnow().replace(microsecond=0).isoformat()) + "Z" + ) + resp = _call( + "/repos/%s/check-runs" % _GITHUB_REPO, + method="POST", + json={ + "name": name, + "head_sha": os.environ["KOKORO_GIT_COMMIT"], + "status": "completed", + "completed_at": completion_time, + "conclusion": "success" if success else "failure", + "output": { + "title": name, + "summary": summary, + }, + }, + ) + print( + "Result of Creating/Updating Check on PR:", + json.dumps(resp.json(), indent=2), + ) def label_significance_on_pr(name, change, labels=_CHANGE_LABELS): @@ -176,28 +189,30 @@ def label_significance_on_pr(name, change, labels=_CHANGE_LABELS): if change > max(list(labels.keys())): change = max(list(labels.keys())) value = labels[change] - if 'KOKORO_GIT_COMMIT' not in os.environ: - print('Missing KOKORO_GIT_COMMIT env var: not checking') + if "KOKORO_GIT_COMMIT" not in os.environ: + print("Missing KOKORO_GIT_COMMIT env var: not checking") return - if 'KOKORO_KEYSTORE_DIR' not in os.environ: - print('Missing KOKORO_KEYSTORE_DIR env var: not checking') + if "KOKORO_KEYSTORE_DIR" not in os.environ: + print("Missing KOKORO_KEYSTORE_DIR env var: not checking") return - if 'KOKORO_GITHUB_PULL_REQUEST_NUMBER' not in os.environ: - print('Missing KOKORO_GITHUB_PULL_REQUEST_NUMBER env var: not checking') + if "KOKORO_GITHUB_PULL_REQUEST_NUMBER" not in os.environ: + print("Missing KOKORO_GITHUB_PULL_REQUEST_NUMBER env var: not checking") return existing = _call( - '/repos/%s/issues/%s/labels' % - (_GITHUB_REPO, os.environ['KOKORO_GITHUB_PULL_REQUEST_NUMBER']), - method='GET').json() - print('Result of fetching labels on PR:', existing) - new = [x['name'] for x in existing if not x['name'].startswith(name + '/')] - new.append(name + '/' + value) + "/repos/%s/issues/%s/labels" + % (_GITHUB_REPO, os.environ["KOKORO_GITHUB_PULL_REQUEST_NUMBER"]), + method="GET", + ).json() + print("Result of fetching labels on PR:", existing) + new = [x["name"] for x in existing if not x["name"].startswith(name + "/")] + new.append(name + "/" + value) resp = _call( - '/repos/%s/issues/%s/labels' % - (_GITHUB_REPO, os.environ['KOKORO_GITHUB_PULL_REQUEST_NUMBER']), - method='PUT', - json=new) - print('Result of setting labels on PR:', resp.text) + "/repos/%s/issues/%s/labels" + % (_GITHUB_REPO, os.environ["KOKORO_GITHUB_PULL_REQUEST_NUMBER"]), + method="PUT", + json=new, + ) + print("Result of setting labels on PR:", resp.text) def label_increase_decrease_on_pr(name, change, significant): diff --git a/tools/run_tests/python_utils/dockerjob.py b/tools/run_tests/python_utils/dockerjob.py index 310a733a201b3..417ff284109b1 100755 --- a/tools/run_tests/python_utils/dockerjob.py +++ b/tools/run_tests/python_utils/dockerjob.py @@ -26,20 +26,25 @@ sys.path.append(os.path.dirname(os.path.abspath(__file__))) import jobset -_DEVNULL = open(os.devnull, 'w') +_DEVNULL = open(os.devnull, "w") def random_name(base_name): 
"""Randomizes given base name.""" - return '%s_%s' % (base_name, uuid.uuid4()) + return "%s_%s" % (base_name, uuid.uuid4()) def docker_kill(cid): """Kills a docker container. Returns True if successful.""" - return subprocess.call(['docker', 'kill', str(cid)], - stdin=subprocess.PIPE, - stdout=_DEVNULL, - stderr=subprocess.STDOUT) == 0 + return ( + subprocess.call( + ["docker", "kill", str(cid)], + stdin=subprocess.PIPE, + stdout=_DEVNULL, + stderr=subprocess.STDOUT, + ) + == 0 + ) def docker_mapped_port(cid, port, timeout_seconds=15): @@ -47,34 +52,37 @@ def docker_mapped_port(cid, port, timeout_seconds=15): started = time.time() while time.time() - started < timeout_seconds: try: - output = subprocess.check_output('docker port %s %s' % (cid, port), - stderr=_DEVNULL, - shell=True).decode() - return int(output.split(':', 2)[1]) + output = subprocess.check_output( + "docker port %s %s" % (cid, port), stderr=_DEVNULL, shell=True + ).decode() + return int(output.split(":", 2)[1]) except subprocess.CalledProcessError as e: pass - raise Exception('Failed to get exposed port %s for container %s.' % - (port, cid)) + raise Exception( + "Failed to get exposed port %s for container %s." % (port, cid) + ) def docker_ip_address(cid, timeout_seconds=15): """Get port mapped to internal given internal port for given container.""" started = time.time() while time.time() - started < timeout_seconds: - cmd = 'docker inspect %s' % cid + cmd = "docker inspect %s" % cid try: - output = subprocess.check_output(cmd, stderr=_DEVNULL, - shell=True).decode() + output = subprocess.check_output( + cmd, stderr=_DEVNULL, shell=True + ).decode() json_info = json.loads(output) assert len(json_info) == 1 - out = json_info[0]['NetworkSettings']['IPAddress'] + out = json_info[0]["NetworkSettings"]["IPAddress"] if not out: continue return out except subprocess.CalledProcessError as e: pass raise Exception( - 'Non-retryable error: Failed to get ip address of container %s.' % cid) + "Non-retryable error: Failed to get ip address of container %s." 
% cid + ) def wait_for_healthy(cid, shortname, timeout_seconds): @@ -82,17 +90,23 @@ def wait_for_healthy(cid, shortname, timeout_seconds): started = time.time() while time.time() - started < timeout_seconds: try: - output = subprocess.check_output([ - 'docker', 'inspect', '--format="{{.State.Health.Status}}"', cid - ], - stderr=_DEVNULL).decode() - if output.strip('\n') == 'healthy': + output = subprocess.check_output( + [ + "docker", + "inspect", + '--format="{{.State.Health.Status}}"', + cid, + ], + stderr=_DEVNULL, + ).decode() + if output.strip("\n") == "healthy": return except subprocess.CalledProcessError as e: pass time.sleep(1) - raise Exception('Timed out waiting for %s (%s) to pass health check' % - (shortname, cid)) + raise Exception( + "Timed out waiting for %s (%s) to pass health check" % (shortname, cid) + ) def finish_jobs(jobs, suppress_failure=True): @@ -106,10 +120,15 @@ def finish_jobs(jobs, suppress_failure=True): def image_exists(image): """Returns True if given docker image exists.""" - return subprocess.call(['docker', 'inspect', image], - stdin=subprocess.PIPE, - stdout=_DEVNULL, - stderr=subprocess.STDOUT) == 0 + return ( + subprocess.call( + ["docker", "inspect", image], + stdin=subprocess.PIPE, + stdout=_DEVNULL, + stderr=subprocess.STDOUT, + ) + == 0 + ) def remove_image(image, skip_nonexistent=False, max_retries=10): @@ -117,13 +136,18 @@ def remove_image(image, skip_nonexistent=False, max_retries=10): if skip_nonexistent and not image_exists(image): return True for attempt in range(0, max_retries): - if subprocess.call(['docker', 'rmi', '-f', image], - stdin=subprocess.PIPE, - stdout=_DEVNULL, - stderr=subprocess.STDOUT) == 0: + if ( + subprocess.call( + ["docker", "rmi", "-f", image], + stdin=subprocess.PIPE, + stdout=_DEVNULL, + stderr=subprocess.STDOUT, + ) + == 0 + ): return True time.sleep(2) - print('Failed to remove docker image %s' % image) + print("Failed to remove docker image %s" % image) return False @@ -132,10 +156,9 @@ class DockerJob: def __init__(self, spec): self._spec = spec - self._job = jobset.Job(spec, - newline_on_success=True, - travis=True, - add_env={}) + self._job = jobset.Job( + spec, newline_on_success=True, travis=True, add_env={} + ) self._container_name = spec.container_name def mapped_port(self, port): @@ -145,8 +168,9 @@ def ip_address(self): return docker_ip_address(self._container_name) def wait_for_healthy(self, timeout_seconds): - wait_for_healthy(self._container_name, self._spec.shortname, - timeout_seconds) + wait_for_healthy( + self._container_name, self._spec.shortname, timeout_seconds + ) def kill(self, suppress_failure=False): """Sends kill signal to the container.""" diff --git a/tools/run_tests/python_utils/download_and_unzip.py b/tools/run_tests/python_utils/download_and_unzip.py index 440572691fda7..cd4c87a23ff84 100644 --- a/tools/run_tests/python_utils/download_and_unzip.py +++ b/tools/run_tests/python_utils/download_and_unzip.py @@ -33,13 +33,15 @@ def main(): with tempfile.TemporaryFile() as tmp_file: r = requests.get(download_url) if r.status_code != requests.codes.ok: - print("Download %s failed with [%d] \"%s\"" % - (download_url, r.status_code, r.text())) + print( + 'Download %s failed with [%d] "%s"' + % (download_url, r.status_code, r.text()) + ) sys.exit(1) else: tmp_file.write(r.content) print("Successfully downloaded from %s", download_url) - with zipfile.ZipFile(tmp_file, 'r') as target_zip_file: + with zipfile.ZipFile(tmp_file, "r") as target_zip_file: target_zip_file.extractall(destination) 
print("Successfully unzip to %s" % destination) diff --git a/tools/run_tests/python_utils/filter_pull_request_tests.py b/tools/run_tests/python_utils/filter_pull_request_tests.py index 1303958813f00..2ae2d04dfbaa0 100644 --- a/tools/run_tests/python_utils/filter_pull_request_tests.py +++ b/tools/run_tests/python_utils/filter_pull_request_tests.py @@ -23,42 +23,50 @@ class TestSuite: """ - Contains label to identify job as belonging to this test suite and - triggers to identify if changed files are relevant - """ + Contains label to identify job as belonging to this test suite and + triggers to identify if changed files are relevant + """ def __init__(self, labels): """ - Build TestSuite to group tests based on labeling - :param label: strings that should match a jobs's platform, config, language, or test group - """ + Build TestSuite to group tests based on labeling + :param label: strings that should match a jobs's platform, config, language, or test group + """ self.triggers = [] self.labels = labels def add_trigger(self, trigger): """ - Add a regex to list of triggers that determine if a changed file should run tests - :param trigger: regex matching file relevant to tests - """ + Add a regex to list of triggers that determine if a changed file should run tests + :param trigger: regex matching file relevant to tests + """ self.triggers.append(trigger) # Create test suites -_CORE_TEST_SUITE = TestSuite(['c']) -_CPP_TEST_SUITE = TestSuite(['c++']) -_CSHARP_TEST_SUITE = TestSuite(['csharp']) -_NODE_TEST_SUITE = TestSuite(['grpc-node']) -_OBJC_TEST_SUITE = TestSuite(['objc']) -_PHP_TEST_SUITE = TestSuite(['php', 'php7']) -_PYTHON_TEST_SUITE = TestSuite(['python']) -_RUBY_TEST_SUITE = TestSuite(['ruby']) -_LINUX_TEST_SUITE = TestSuite(['linux']) -_WINDOWS_TEST_SUITE = TestSuite(['windows']) -_MACOS_TEST_SUITE = TestSuite(['macos']) +_CORE_TEST_SUITE = TestSuite(["c"]) +_CPP_TEST_SUITE = TestSuite(["c++"]) +_CSHARP_TEST_SUITE = TestSuite(["csharp"]) +_NODE_TEST_SUITE = TestSuite(["grpc-node"]) +_OBJC_TEST_SUITE = TestSuite(["objc"]) +_PHP_TEST_SUITE = TestSuite(["php", "php7"]) +_PYTHON_TEST_SUITE = TestSuite(["python"]) +_RUBY_TEST_SUITE = TestSuite(["ruby"]) +_LINUX_TEST_SUITE = TestSuite(["linux"]) +_WINDOWS_TEST_SUITE = TestSuite(["windows"]) +_MACOS_TEST_SUITE = TestSuite(["macos"]) _ALL_TEST_SUITES = [ - _CORE_TEST_SUITE, _CPP_TEST_SUITE, _CSHARP_TEST_SUITE, _NODE_TEST_SUITE, - _OBJC_TEST_SUITE, _PHP_TEST_SUITE, _PYTHON_TEST_SUITE, _RUBY_TEST_SUITE, - _LINUX_TEST_SUITE, _WINDOWS_TEST_SUITE, _MACOS_TEST_SUITE + _CORE_TEST_SUITE, + _CPP_TEST_SUITE, + _CSHARP_TEST_SUITE, + _NODE_TEST_SUITE, + _OBJC_TEST_SUITE, + _PHP_TEST_SUITE, + _PYTHON_TEST_SUITE, + _RUBY_TEST_SUITE, + _LINUX_TEST_SUITE, + _WINDOWS_TEST_SUITE, + _MACOS_TEST_SUITE, ] # Dictionary of allowlistable files where the key is a regex matching changed files @@ -67,49 +75,49 @@ def add_trigger(self, trigger): # match any of these regexes will trigger all tests # DO NOT CHANGE THIS UNLESS YOU KNOW WHAT YOU ARE DOING (be careful even if you do) _ALLOWLIST_DICT = { - '^doc/': [], - '^examples/': [], - '^include/grpc\+\+/': [_CPP_TEST_SUITE], - '^include/grpcpp/': [_CPP_TEST_SUITE], - '^summerofcode/': [], - '^src/cpp/': [_CPP_TEST_SUITE], - '^src/csharp/': [_CSHARP_TEST_SUITE], - '^src/objective\-c/': [_OBJC_TEST_SUITE], - '^src/php/': [_PHP_TEST_SUITE], - '^src/python/': [_PYTHON_TEST_SUITE], - '^src/ruby/': [_RUBY_TEST_SUITE], - '^templates/': [], - '^test/core/': [_CORE_TEST_SUITE, _CPP_TEST_SUITE], - '^test/cpp/': 
[_CPP_TEST_SUITE], - '^test/distrib/cpp/': [_CPP_TEST_SUITE], - '^test/distrib/csharp/': [_CSHARP_TEST_SUITE], - '^test/distrib/php/': [_PHP_TEST_SUITE], - '^test/distrib/python/': [_PYTHON_TEST_SUITE], - '^test/distrib/ruby/': [_RUBY_TEST_SUITE], - '^tools/run_tests/xds_k8s_test_driver/': [], - '^tools/internal_ci/linux/grpc_xds_k8s.*': [], - '^vsprojects/': [_WINDOWS_TEST_SUITE], - 'composer\.json$': [_PHP_TEST_SUITE], - 'config\.m4$': [_PHP_TEST_SUITE], - 'CONTRIBUTING\.md$': [], - 'Gemfile$': [_RUBY_TEST_SUITE], - 'grpc\.def$': [_WINDOWS_TEST_SUITE], - 'grpc\.gemspec$': [_RUBY_TEST_SUITE], - 'gRPC\.podspec$': [_OBJC_TEST_SUITE], - 'gRPC\-Core\.podspec$': [_OBJC_TEST_SUITE], - 'gRPC\-ProtoRPC\.podspec$': [_OBJC_TEST_SUITE], - 'gRPC\-RxLibrary\.podspec$': [_OBJC_TEST_SUITE], - 'BUILDING\.md$': [], - 'LICENSE$': [], - 'MANIFEST\.md$': [], - 'package\.json$': [_PHP_TEST_SUITE], - 'package\.xml$': [_PHP_TEST_SUITE], - 'PATENTS$': [], - 'PYTHON\-MANIFEST\.in$': [_PYTHON_TEST_SUITE], - 'README\.md$': [], - 'requirements\.txt$': [_PYTHON_TEST_SUITE], - 'setup\.cfg$': [_PYTHON_TEST_SUITE], - 'setup\.py$': [_PYTHON_TEST_SUITE] + "^doc/": [], + "^examples/": [], + "^include/grpc\+\+/": [_CPP_TEST_SUITE], + "^include/grpcpp/": [_CPP_TEST_SUITE], + "^summerofcode/": [], + "^src/cpp/": [_CPP_TEST_SUITE], + "^src/csharp/": [_CSHARP_TEST_SUITE], + "^src/objective\-c/": [_OBJC_TEST_SUITE], + "^src/php/": [_PHP_TEST_SUITE], + "^src/python/": [_PYTHON_TEST_SUITE], + "^src/ruby/": [_RUBY_TEST_SUITE], + "^templates/": [], + "^test/core/": [_CORE_TEST_SUITE, _CPP_TEST_SUITE], + "^test/cpp/": [_CPP_TEST_SUITE], + "^test/distrib/cpp/": [_CPP_TEST_SUITE], + "^test/distrib/csharp/": [_CSHARP_TEST_SUITE], + "^test/distrib/php/": [_PHP_TEST_SUITE], + "^test/distrib/python/": [_PYTHON_TEST_SUITE], + "^test/distrib/ruby/": [_RUBY_TEST_SUITE], + "^tools/run_tests/xds_k8s_test_driver/": [], + "^tools/internal_ci/linux/grpc_xds_k8s.*": [], + "^vsprojects/": [_WINDOWS_TEST_SUITE], + "composer\.json$": [_PHP_TEST_SUITE], + "config\.m4$": [_PHP_TEST_SUITE], + "CONTRIBUTING\.md$": [], + "Gemfile$": [_RUBY_TEST_SUITE], + "grpc\.def$": [_WINDOWS_TEST_SUITE], + "grpc\.gemspec$": [_RUBY_TEST_SUITE], + "gRPC\.podspec$": [_OBJC_TEST_SUITE], + "gRPC\-Core\.podspec$": [_OBJC_TEST_SUITE], + "gRPC\-ProtoRPC\.podspec$": [_OBJC_TEST_SUITE], + "gRPC\-RxLibrary\.podspec$": [_OBJC_TEST_SUITE], + "BUILDING\.md$": [], + "LICENSE$": [], + "MANIFEST\.md$": [], + "package\.json$": [_PHP_TEST_SUITE], + "package\.xml$": [_PHP_TEST_SUITE], + "PATENTS$": [], + "PYTHON\-MANIFEST\.in$": [_PYTHON_TEST_SUITE], + "README\.md$": [], + "requirements\.txt$": [_PYTHON_TEST_SUITE], + "setup\.cfg$": [_PYTHON_TEST_SUITE], + "setup\.py$": [_PYTHON_TEST_SUITE], } # Regex that combines all keys in _ALLOWLIST_DICT @@ -123,24 +131,31 @@ def add_trigger(self, trigger): def _get_changed_files(base_branch): """ - Get list of changed files between current branch and base of target merge branch - """ + Get list of changed files between current branch and base of target merge branch + """ # Get file changes between branch and merge-base of specified branch # Not combined to be Windows friendly - base_commit = subprocess.check_output( - ["git", "merge-base", base_branch, "HEAD"]).decode("UTF-8").rstrip() - return subprocess.check_output( - ["git", "diff", base_commit, "--name-only", - "HEAD"]).decode("UTF-8").splitlines() + base_commit = ( + subprocess.check_output(["git", "merge-base", base_branch, "HEAD"]) + .decode("UTF-8") + .rstrip() + ) + return ( + 
subprocess.check_output( + ["git", "diff", base_commit, "--name-only", "HEAD"] + ) + .decode("UTF-8") + .splitlines() + ) def _can_skip_tests(file_names, triggers): """ - Determines if tests are skippable based on if all files do not match list of regexes - :param file_names: list of changed files generated by _get_changed_files() - :param triggers: list of regexes matching file name that indicates tests should be run - :return: safe to skip tests - """ + Determines if tests are skippable based on if all files do not match list of regexes + :param file_names: list of changed files generated by _get_changed_files() + :param triggers: list of regexes matching file name that indicates tests should be run + :return: safe to skip tests + """ for file_name in file_names: if any(re.match(trigger, file_name) for trigger in triggers): return False @@ -149,57 +164,63 @@ def _can_skip_tests(file_names, triggers): def _remove_irrelevant_tests(tests, skippable_labels): """ - Filters out tests by config or language - will not remove sanitizer tests - :param tests: list of all tests generated by run_tests_matrix.py - :param skippable_labels: list of languages and platforms with skippable tests - :return: list of relevant tests - """ + Filters out tests by config or language - will not remove sanitizer tests + :param tests: list of all tests generated by run_tests_matrix.py + :param skippable_labels: list of languages and platforms with skippable tests + :return: list of relevant tests + """ # test.labels[0] is platform and test.labels[2] is language # We skip a test if both are considered safe to skip - return [test for test in tests if test.labels[0] not in skippable_labels or \ - test.labels[2] not in skippable_labels] + return [ + test + for test in tests + if test.labels[0] not in skippable_labels + or test.labels[2] not in skippable_labels + ] def affects_c_cpp(base_branch): """ - Determines if a pull request's changes affect C/C++. This function exists because - there are pull request tests that only test C/C++ code - :param base_branch: branch that a pull request is requesting to merge into - :return: boolean indicating whether C/C++ changes are made in pull request - """ + Determines if a pull request's changes affect C/C++. 
This function exists because + there are pull request tests that only test C/C++ code + :param base_branch: branch that a pull request is requesting to merge into + :return: boolean indicating whether C/C++ changes are made in pull request + """ changed_files = _get_changed_files(base_branch) # Run all tests if any changed file is not in the allowlist dictionary for changed_file in changed_files: if not re.match(_ALL_TRIGGERS, changed_file): return True return not _can_skip_tests( - changed_files, _CPP_TEST_SUITE.triggers + _CORE_TEST_SUITE.triggers) + changed_files, _CPP_TEST_SUITE.triggers + _CORE_TEST_SUITE.triggers + ) def filter_tests(tests, base_branch): """ - Filters out tests that are safe to ignore - :param tests: list of all tests generated by run_tests_matrix.py - :return: list of relevant tests - """ + Filters out tests that are safe to ignore + :param tests: list of all tests generated by run_tests_matrix.py + :return: list of relevant tests + """ print( - 'Finding file differences between gRPC %s branch and pull request...\n' - % base_branch) + "Finding file differences between gRPC %s branch and pull request...\n" + % base_branch + ) changed_files = _get_changed_files(base_branch) for changed_file in changed_files: - print(' %s' % changed_file) - print('') + print(" %s" % changed_file) + print("") # Run all tests if any changed file is not in the allowlist dictionary for changed_file in changed_files: if not re.match(_ALL_TRIGGERS, changed_file): - return (tests) + return tests # Figure out which language and platform tests to run skippable_labels = [] for test_suite in _ALL_TEST_SUITES: if _can_skip_tests(changed_files, test_suite.triggers): for label in test_suite.labels: - print(' %s tests safe to skip' % label) + print(" %s tests safe to skip" % label) skippable_labels.append(label) tests = _remove_irrelevant_tests(tests, skippable_labels) return tests diff --git a/tools/run_tests/python_utils/jobset.py b/tools/run_tests/python_utils/jobset.py index 2174c00eb7d00..e475b5a82f307 100755 --- a/tools/run_tests/python_utils/jobset.py +++ b/tools/run_tests/python_utils/jobset.py @@ -39,7 +39,7 @@ # characters to the PR description, which leak into the environment here # and cause failures. 
def strip_non_ascii_chars(s): - return ''.join(c for c in s if ord(c) < 128) + return "".join(c for c in s if ord(c) < 128) def sanitized_environment(env): @@ -50,22 +50,22 @@ def sanitized_environment(env): def platform_string(): - if platform.system() == 'Windows': - return 'windows' - elif platform.system()[:7] == 'MSYS_NT': - return 'windows' - elif platform.system() == 'Darwin': - return 'mac' - elif platform.system() == 'Linux': - return 'linux' + if platform.system() == "Windows": + return "windows" + elif platform.system()[:7] == "MSYS_NT": + return "windows" + elif platform.system() == "Darwin": + return "mac" + elif platform.system() == "Linux": + return "linux" else: - return 'posix' + return "posix" # setup a signal handler so that signal.pause registers 'something' # when a child finishes # not using futures and threading to avoid a dependency on subprocess32 -if platform_string() == 'windows': +if platform_string() == "windows": pass else: @@ -81,33 +81,33 @@ def alarm_handler(unused_signum, unused_frame): _KILLED = object() _COLORS = { - 'red': [31, 0], - 'green': [32, 0], - 'yellow': [33, 0], - 'lightgray': [37, 0], - 'gray': [30, 1], - 'purple': [35, 0], - 'cyan': [36, 0] + "red": [31, 0], + "green": [32, 0], + "yellow": [33, 0], + "lightgray": [37, 0], + "gray": [30, 1], + "purple": [35, 0], + "cyan": [36, 0], } -_BEGINNING_OF_LINE = '\x1b[0G' -_CLEAR_LINE = '\x1b[2K' +_BEGINNING_OF_LINE = "\x1b[0G" +_CLEAR_LINE = "\x1b[2K" _TAG_COLOR = { - 'FAILED': 'red', - 'FLAKE': 'purple', - 'TIMEOUT_FLAKE': 'purple', - 'WARNING': 'yellow', - 'TIMEOUT': 'red', - 'PASSED': 'green', - 'START': 'gray', - 'WAITING': 'yellow', - 'SUCCESS': 'green', - 'IDLE': 'gray', - 'SKIPPED': 'cyan' + "FAILED": "red", + "FLAKE": "purple", + "TIMEOUT_FLAKE": "purple", + "WARNING": "yellow", + "TIMEOUT": "red", + "PASSED": "green", + "START": "gray", + "WAITING": "yellow", + "SUCCESS": "green", + "IDLE": "gray", + "SKIPPED": "cyan", } -_FORMAT = '%(asctime)-15s %(message)s' +_FORMAT = "%(asctime)-15s %(message)s" logging.basicConfig(level=logging.INFO, format=_FORMAT) @@ -122,27 +122,41 @@ def eintr_be_gone(fn): def message(tag, msg, explanatory_text=None, do_newline=False): - if message.old_tag == tag and message.old_msg == msg and not explanatory_text: + if ( + message.old_tag == tag + and message.old_msg == msg + and not explanatory_text + ): return message.old_tag = tag message.old_msg = msg if explanatory_text: if isinstance(explanatory_text, bytes): - explanatory_text = explanatory_text.decode('utf8', errors='replace') + explanatory_text = explanatory_text.decode("utf8", errors="replace") while True: try: - if platform_string() == 'windows' or not sys.stdout.isatty(): + if platform_string() == "windows" or not sys.stdout.isatty(): if explanatory_text: logging.info(explanatory_text) - logging.info('%s: %s', tag, msg) + logging.info("%s: %s", tag, msg) else: sys.stdout.write( - '%s%s%s\x1b[%d;%dm%s\x1b[0m: %s%s' % - (_BEGINNING_OF_LINE, _CLEAR_LINE, '\n%s' % - explanatory_text if explanatory_text is not None else '', - _COLORS[_TAG_COLOR[tag]][1], _COLORS[_TAG_COLOR[tag]][0], - tag, msg, '\n' - if do_newline or explanatory_text is not None else '')) + "%s%s%s\x1b[%d;%dm%s\x1b[0m: %s%s" + % ( + _BEGINNING_OF_LINE, + _CLEAR_LINE, + "\n%s" % explanatory_text + if explanatory_text is not None + else "", + _COLORS[_TAG_COLOR[tag]][1], + _COLORS[_TAG_COLOR[tag]][0], + tag, + msg, + "\n" + if do_newline or explanatory_text is not None + else "", + ) + ) sys.stdout.flush() return except IOError as e: @@ 
-150,43 +164,45 @@ def message(tag, msg, explanatory_text=None, do_newline=False): raise -message.old_tag = '' -message.old_msg = '' +message.old_tag = "" +message.old_msg = "" def which(filename): - if '/' in filename: + if "/" in filename: return filename - for path in os.environ['PATH'].split(os.pathsep): + for path in os.environ["PATH"].split(os.pathsep): if os.path.exists(os.path.join(path, filename)): return os.path.join(path, filename) - raise Exception('%s not found' % filename) + raise Exception("%s not found" % filename) class JobSpec(object): """Specifies what to run for a job.""" - def __init__(self, - cmdline, - shortname=None, - environ=None, - cwd=None, - shell=False, - timeout_seconds=5 * 60, - flake_retries=0, - timeout_retries=0, - kill_handler=None, - cpu_cost=1.0, - verbose_success=False, - logfilename=None): + def __init__( + self, + cmdline, + shortname=None, + environ=None, + cwd=None, + shell=False, + timeout_seconds=5 * 60, + flake_retries=0, + timeout_retries=0, + kill_handler=None, + cpu_cost=1.0, + verbose_success=False, + logfilename=None, + ): + """ + Arguments: + cmdline: a list of arguments to pass as the command line + environ: a dictionary of environment variables to set in the child process + kill_handler: a handler that will be called whenever job.kill() is invoked + cpu_cost: number of cores per second this job needs + logfilename: use given file to store job's output, rather than using a temporary file """ - Arguments: - cmdline: a list of arguments to pass as the command line - environ: a dictionary of environment variables to set in the child process - kill_handler: a handler that will be called whenever job.kill() is invoked - cpu_cost: number of cores per second this job needs - logfilename: use given file to store job's output, rather than using a temporary file - """ if environ is None: environ = {} self.cmdline = cmdline @@ -201,13 +217,18 @@ def __init__(self, self.cpu_cost = cpu_cost self.verbose_success = verbose_success self.logfilename = logfilename - if self.logfilename and self.flake_retries != 0 and self.timeout_retries != 0: + if ( + self.logfilename + and self.flake_retries != 0 + and self.timeout_retries != 0 + ): # Forbidden to avoid overwriting the test log when retrying. 
raise Exception( - 'Cannot use custom logfile when retries are enabled') + "Cannot use custom logfile when retries are enabled" + ) def identity(self): - return '%r %r' % (self.cmdline, self.environ) + return "%r %r" % (self.cmdline, self.environ) def __hash__(self): return hash(self.identity()) @@ -219,24 +240,27 @@ def __lt__(self, other): return self.identity() < other.identity() def __repr__(self): - return 'JobSpec(shortname=%s, cmdline=%s)' % (self.shortname, - self.cmdline) + return "JobSpec(shortname=%s, cmdline=%s)" % ( + self.shortname, + self.cmdline, + ) def __str__(self): - return '%s: %s %s' % (self.shortname, ' '.join( - '%s=%s' % kv for kv in list(self.environ.items())), ' '.join( - self.cmdline)) + return "%s: %s %s" % ( + self.shortname, + " ".join("%s=%s" % kv for kv in list(self.environ.items())), + " ".join(self.cmdline), + ) class JobResult(object): - def __init__(self): - self.state = 'UNKNOWN' + self.state = "UNKNOWN" self.returncode = -1 self.elapsed_time = 0 self.num_failures = 0 self.retries = 0 - self.message = '' + self.message = "" self.cpu_estimated = 1 self.cpu_measured = 1 @@ -249,12 +273,9 @@ def read_from_start(f): class Job(object): """Manages one job.""" - def __init__(self, - spec, - newline_on_success, - travis, - add_env, - quiet_success=False): + def __init__( + self, spec, newline_on_success, travis, add_env, quiet_success=False + ): self._spec = spec self._newline_on_success = newline_on_success self._travis = travis @@ -264,7 +285,7 @@ def __init__(self, self._suppress_failure_message = False self._quiet_success = quiet_success if not self._quiet_success: - message('START', spec.shortname, do_newline=self._travis) + message("START", spec.shortname, do_newline=self._travis) self.result = JobResult() self.start() @@ -275,10 +296,11 @@ def start(self): if self._spec.logfilename: # make sure the log directory exists logfile_dir = os.path.dirname( - os.path.abspath(self._spec.logfilename)) + os.path.abspath(self._spec.logfilename) + ) if not os.path.exists(logfile_dir): os.makedirs(logfile_dir) - self._logfile = open(self._spec.logfilename, 'w+') + self._logfile = open(self._spec.logfilename, "w+") else: # macOS: a series of quick os.unlink invocation might cause OS # error during the creation of temporary file. By using @@ -293,16 +315,18 @@ def start(self): # The Unix time command is finicky when used with MSBuild, so we don't use it # with jobs that run MSBuild. 
global measure_cpu_costs - if measure_cpu_costs and not 'vsprojects\\build' in cmdline[0]: - cmdline = ['time', '-p'] + cmdline + if measure_cpu_costs and not "vsprojects\\build" in cmdline[0]: + cmdline = ["time", "-p"] + cmdline else: measure_cpu_costs = False - try_start = lambda: subprocess.Popen(args=cmdline, - stderr=subprocess.STDOUT, - stdout=self._logfile, - cwd=self._spec.cwd, - shell=self._spec.shell, - env=env) + try_start = lambda: subprocess.Popen( + args=cmdline, + stderr=subprocess.STDOUT, + stdout=self._logfile, + cwd=self._spec.cwd, + shell=self._spec.shell, + env=env, + ) delay = 0.3 for i in range(0, 4): try: @@ -310,8 +334,10 @@ def start(self): break except OSError: message( - 'WARNING', 'Failed to start %s, retrying in %f seconds' % - (self._spec.shortname, delay)) + "WARNING", + "Failed to start %s, retrying in %f seconds" + % (self._spec.shortname, delay), + ) time.sleep(delay) delay *= 2 else: @@ -331,12 +357,17 @@ def stdout(self=self): self.result.elapsed_time = elapsed if self._process.returncode != 0: if self._retries < self._spec.flake_retries: - message('FLAKE', - '%s [ret=%d, pid=%d]' % - (self._spec.shortname, self._process.returncode, - self._process.pid), - stdout(), - do_newline=True) + message( + "FLAKE", + "%s [ret=%d, pid=%d]" + % ( + self._spec.shortname, + self._process.returncode, + self._process.pid, + ), + stdout(), + do_newline=True, + ) self._retries += 1 self.result.num_failures += 1 self.result.retries = self._timeout_retries + self._retries @@ -345,51 +376,71 @@ def stdout(self=self): else: self._state = _FAILURE if not self._suppress_failure_message: - message('FAILED', - '%s [ret=%d, pid=%d, time=%.1fsec]' % - (self._spec.shortname, self._process.returncode, - self._process.pid, elapsed), - stdout(), - do_newline=True) - self.result.state = 'FAILED' + message( + "FAILED", + "%s [ret=%d, pid=%d, time=%.1fsec]" + % ( + self._spec.shortname, + self._process.returncode, + self._process.pid, + elapsed, + ), + stdout(), + do_newline=True, + ) + self.result.state = "FAILED" self.result.num_failures += 1 self.result.returncode = self._process.returncode else: self._state = _SUCCESS - measurement = '' + measurement = "" if measure_cpu_costs: m = re.search( - r'real\s+([0-9.]+)\nuser\s+([0-9.]+)\nsys\s+([0-9.]+)', - (stdout()).decode('utf8', errors='replace')) + r"real\s+([0-9.]+)\nuser\s+([0-9.]+)\nsys\s+([0-9.]+)", + (stdout()).decode("utf8", errors="replace"), + ) real = float(m.group(1)) user = float(m.group(2)) sys = float(m.group(3)) if real > 0.5: cores = (user + sys) / real - self.result.cpu_measured = float('%.01f' % cores) - self.result.cpu_estimated = float('%.01f' % - self._spec.cpu_cost) - measurement = '; cpu_cost=%.01f; estimated=%.01f' % ( - self.result.cpu_measured, self.result.cpu_estimated) + self.result.cpu_measured = float("%.01f" % cores) + self.result.cpu_estimated = float( + "%.01f" % self._spec.cpu_cost + ) + measurement = "; cpu_cost=%.01f; estimated=%.01f" % ( + self.result.cpu_measured, + self.result.cpu_estimated, + ) if not self._quiet_success: - message('PASSED', - '%s [time=%.1fsec, retries=%d:%d%s]' % - (self._spec.shortname, elapsed, self._retries, - self._timeout_retries, measurement), - stdout() if self._spec.verbose_success else None, - do_newline=self._newline_on_success or self._travis) - self.result.state = 'PASSED' - elif (self._state == _RUNNING and - self._spec.timeout_seconds is not None and - time.time() - self._start > self._spec.timeout_seconds): + message( + "PASSED", + "%s [time=%.1fsec, 
retries=%d:%d%s]" + % ( + self._spec.shortname, + elapsed, + self._retries, + self._timeout_retries, + measurement, + ), + stdout() if self._spec.verbose_success else None, + do_newline=self._newline_on_success or self._travis, + ) + self.result.state = "PASSED" + elif ( + self._state == _RUNNING + and self._spec.timeout_seconds is not None + and time.time() - self._start > self._spec.timeout_seconds + ): elapsed = time.time() - self._start self.result.elapsed_time = elapsed if self._timeout_retries < self._spec.timeout_retries: - message('TIMEOUT_FLAKE', - '%s [pid=%d]' % - (self._spec.shortname, self._process.pid), - stdout(), - do_newline=True) + message( + "TIMEOUT_FLAKE", + "%s [pid=%d]" % (self._spec.shortname, self._process.pid), + stdout(), + do_newline=True, + ) self._timeout_retries += 1 self.result.num_failures += 1 self.result.retries = self._timeout_retries + self._retries @@ -399,13 +450,15 @@ def stdout(self=self): # NOTE: job is restarted regardless of jobset's max_time setting self.start() else: - message('TIMEOUT', - '%s [pid=%d, time=%.1fsec]' % - (self._spec.shortname, self._process.pid, elapsed), - stdout(), - do_newline=True) + message( + "TIMEOUT", + "%s [pid=%d, time=%.1fsec]" + % (self._spec.shortname, self._process.pid, elapsed), + stdout(), + do_newline=True, + ) self.kill() - self.result.state = 'TIMEOUT' + self.result.state = "TIMEOUT" self.result.num_failures += 1 return self._state @@ -423,9 +476,18 @@ def suppress_failure_message(self): class Jobset(object): """Manages one run of jobs.""" - def __init__(self, check_cancelled, maxjobs, maxjobs_cpu_agnostic, - newline_on_success, travis, stop_on_failure, add_env, - quiet_success, max_time): + def __init__( + self, + check_cancelled, + maxjobs, + maxjobs_cpu_agnostic, + newline_on_success, + travis, + stop_on_failure, + add_env, + quiet_success, + max_time, + ): self._running = set() self._check_cancelled = check_cancelled self._cancelled = False @@ -458,11 +520,13 @@ def cpu_cost(self): def start(self, spec): """Start a job. 
Return True on success, False on failure.""" while True: - if self._max_time > 0 and time.time( - ) - self._start_time > self._max_time: + if ( + self._max_time > 0 + and time.time() - self._start_time > self._max_time + ): skipped_job_result = JobResult() - skipped_job_result.state = 'SKIPPED' - message('SKIPPED', spec.shortname, do_newline=True) + skipped_job_result.state = "SKIPPED" + message("SKIPPED", spec.shortname, do_newline=True) self.resultset[spec.shortname] = [skipped_job_result] return True if self.cancelled(): @@ -476,8 +540,13 @@ def start(self, spec): self.reap(spec.shortname, spec.cpu_cost) if self.cancelled(): return False - job = Job(spec, self._newline_on_success, self._travis, self._add_env, - self._quiet_success) + job = Job( + spec, + self._newline_on_success, + self._travis, + self._add_env, + self._quiet_success, + ) self._running.add(job) if job.GetSpec().shortname not in self.resultset: self.resultset[job.GetSpec().shortname] = [] @@ -501,30 +570,46 @@ def reap(self, waiting_for=None, waiting_for_cost=None): break for job in dead: self._completed += 1 - if not self._quiet_success or job.result.state != 'PASSED': + if not self._quiet_success or job.result.state != "PASSED": self.resultset[job.GetSpec().shortname].append(job.result) self._running.remove(job) if dead: return - if not self._travis and platform_string() != 'windows': - rstr = '' if self._remaining is None else '%d queued, ' % self._remaining + if not self._travis and platform_string() != "windows": + rstr = ( + "" + if self._remaining is None + else "%d queued, " % self._remaining + ) if self._remaining is not None and self._completed > 0: now = time.time() sofar = now - self._start_time - remaining = sofar / self._completed * (self._remaining + - len(self._running)) - rstr = 'ETA %.1f sec; %s' % (remaining, rstr) + remaining = ( + sofar + / self._completed + * (self._remaining + len(self._running)) + ) + rstr = "ETA %.1f sec; %s" % (remaining, rstr) if waiting_for is not None: - wstr = ' next: %s @ %.2f cpu' % (waiting_for, - waiting_for_cost) + wstr = " next: %s @ %.2f cpu" % ( + waiting_for, + waiting_for_cost, + ) else: - wstr = '' + wstr = "" message( - 'WAITING', - '%s%d jobs running, %d complete, %d failed (load %.2f)%s' % - (rstr, len(self._running), self._completed, self._failures, - self.cpu_cost(), wstr)) - if platform_string() == 'windows': + "WAITING", + "%s%d jobs running, %d complete, %d failed (load %.2f)%s" + % ( + rstr, + len(self._running), + self._completed, + self._failures, + self.cpu_cost(), + wstr, + ), + ) + if platform_string() == "windows": time.sleep(0.1) else: signal.alarm(10) @@ -546,7 +631,7 @@ def finish(self): if self.cancelled(): pass # poll cancellation self.reap() - if platform_string() != 'windows': + if platform_string() != "windows": signal.alarm(0) return not self.cancelled() and self._failures == 0 @@ -566,31 +651,41 @@ def tag_remaining(xs): yield (x, n - i - 1) -def run(cmdlines, - check_cancelled=_never_cancelled, - maxjobs=None, - maxjobs_cpu_agnostic=None, - newline_on_success=False, - travis=False, - infinite_runs=False, - stop_on_failure=False, - add_env={}, - skip_jobs=False, - quiet_success=False, - max_time=-1): +def run( + cmdlines, + check_cancelled=_never_cancelled, + maxjobs=None, + maxjobs_cpu_agnostic=None, + newline_on_success=False, + travis=False, + infinite_runs=False, + stop_on_failure=False, + add_env={}, + skip_jobs=False, + quiet_success=False, + max_time=-1, +): if skip_jobs: resultset = {} skipped_job_result = JobResult() - 
skipped_job_result.state = 'SKIPPED' + skipped_job_result.state = "SKIPPED" for job in cmdlines: - message('SKIPPED', job.shortname, do_newline=True) + message("SKIPPED", job.shortname, do_newline=True) resultset[job.shortname] = [skipped_job_result] return 0, resultset js = Jobset( - check_cancelled, maxjobs if maxjobs is not None else _DEFAULT_MAX_JOBS, - maxjobs_cpu_agnostic if maxjobs_cpu_agnostic is not None else - _DEFAULT_MAX_JOBS, newline_on_success, travis, stop_on_failure, add_env, - quiet_success, max_time) + check_cancelled, + maxjobs if maxjobs is not None else _DEFAULT_MAX_JOBS, + maxjobs_cpu_agnostic + if maxjobs_cpu_agnostic is not None + else _DEFAULT_MAX_JOBS, + newline_on_success, + travis, + stop_on_failure, + add_env, + quiet_success, + max_time, + ) for cmdline, remaining in tag_remaining(cmdlines): if not js.start(cmdline): break diff --git a/tools/run_tests/python_utils/port_server.py b/tools/run_tests/python_utils/port_server.py index a530165140a3d..3ff2c2b9c9d28 100755 --- a/tools/run_tests/python_utils/port_server.py +++ b/tools/run_tests/python_utils/port_server.py @@ -35,23 +35,23 @@ # note that all changes must be backwards compatible _MY_VERSION = 21 -if len(sys.argv) == 2 and sys.argv[1] == 'dump_version': +if len(sys.argv) == 2 and sys.argv[1] == "dump_version": print(_MY_VERSION) sys.exit(0) -argp = argparse.ArgumentParser(description='Server for httpcli_test') -argp.add_argument('-p', '--port', default=12345, type=int) -argp.add_argument('-l', '--logfile', default=None, type=str) +argp = argparse.ArgumentParser(description="Server for httpcli_test") +argp.add_argument("-p", "--port", default=12345, type=int) +argp.add_argument("-l", "--logfile", default=None, type=str) args = argp.parse_args() if args.logfile is not None: sys.stdin.close() sys.stderr.close() sys.stdout.close() - sys.stderr = open(args.logfile, 'w') + sys.stderr = open(args.logfile, "w") sys.stdout = sys.stderr -print('port server running on port %d' % args.port) +print("port server running on port %d" % args.port) pool = [] in_use = {} @@ -62,22 +62,82 @@ # ports is used in a Cronet test, the test would fail (see issue #12149). These # ports must be excluded from pool. 
cronet_restricted_ports = [ - 1, 7, 9, 11, 13, 15, 17, 19, 20, 21, 22, 23, 25, 37, 42, 43, 53, 77, 79, 87, - 95, 101, 102, 103, 104, 109, 110, 111, 113, 115, 117, 119, 123, 135, 139, - 143, 179, 389, 465, 512, 513, 514, 515, 526, 530, 531, 532, 540, 556, 563, - 587, 601, 636, 993, 995, 2049, 3659, 4045, 6000, 6665, 6666, 6667, 6668, - 6669, 6697 + 1, + 7, + 9, + 11, + 13, + 15, + 17, + 19, + 20, + 21, + 22, + 23, + 25, + 37, + 42, + 43, + 53, + 77, + 79, + 87, + 95, + 101, + 102, + 103, + 104, + 109, + 110, + 111, + 113, + 115, + 117, + 119, + 123, + 135, + 139, + 143, + 179, + 389, + 465, + 512, + 513, + 514, + 515, + 526, + 530, + 531, + 532, + 540, + 556, + 563, + 587, + 601, + 636, + 993, + 995, + 2049, + 3659, + 4045, + 6000, + 6665, + 6666, + 6667, + 6668, + 6669, + 6697, ] def can_connect(port): # this test is only really useful on unices where SO_REUSE_PORT is available # so on Windows, where this test is expensive, skip it - if platform.system() == 'Windows': + if platform.system() == "Windows": return False s = socket.socket() try: - s.connect(('localhost', port)) + s.connect(("localhost", port)) return True except socket.error as e: return False @@ -89,7 +149,7 @@ def can_bind(port, proto): s = socket.socket(proto, socket.SOCK_STREAM) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) try: - s.bind(('localhost', port)) + s.bind(("localhost", port)) return True except socket.error as e: return False @@ -100,7 +160,8 @@ def can_bind(port, proto): def refill_pool(max_timeout, req): """Scan for ports not marked for being in use""" chk = [ - port for port in range(1025, 32766) + port + for port in range(1025, 32766) if port not in cronet_restricted_ports ] random.shuffle(chk) @@ -113,8 +174,11 @@ def refill_pool(max_timeout, req): continue req.log_message("kill old request %d" % i) del in_use[i] - if can_bind(i, socket.AF_INET) and can_bind( - i, socket.AF_INET6) and not can_connect(i): + if ( + can_bind(i, socket.AF_INET) + and can_bind(i, socket.AF_INET6) + and not can_connect(i) + ): req.log_message("found available port %d" % i) pool.append(i) @@ -144,7 +208,6 @@ def allocate_port(req): class Handler(BaseHTTPRequestHandler): - def setup(self): # If the client is unreachable for 5 seconds, close the connection self.timeout = 5 @@ -153,51 +216,56 @@ def setup(self): def do_GET(self): global keep_running global mu - if self.path == '/get': + if self.path == "/get": # allocate a new port, it will stay bound for ten minutes and until # it's unused self.send_response(200) - self.send_header('Content-Type', 'text/plain') + self.send_header("Content-Type", "text/plain") self.end_headers() p = allocate_port(self) - self.log_message('allocated port %d' % p) - self.wfile.write(str(p).encode('ascii')) - elif self.path[0:6] == '/drop/': + self.log_message("allocated port %d" % p) + self.wfile.write(str(p).encode("ascii")) + elif self.path[0:6] == "/drop/": self.send_response(200) - self.send_header('Content-Type', 'text/plain') + self.send_header("Content-Type", "text/plain") self.end_headers() p = int(self.path[6:]) mu.acquire() if p in in_use: del in_use[p] pool.append(p) - k = 'known' + k = "known" else: - k = 'unknown' + k = "unknown" mu.release() - self.log_message('drop %s port %d' % (k, p)) - elif self.path == '/version_number': + self.log_message("drop %s port %d" % (k, p)) + elif self.path == "/version_number": # fetch a version string and the current process pid self.send_response(200) - self.send_header('Content-Type', 'text/plain') + self.send_header("Content-Type", 
"text/plain") self.end_headers() - self.wfile.write(str(_MY_VERSION).encode('ascii')) - elif self.path == '/dump': + self.wfile.write(str(_MY_VERSION).encode("ascii")) + elif self.path == "/dump": # yaml module is not installed on Macs and Windows machines by default # so we import it lazily (/dump action is only used for debugging) import yaml + self.send_response(200) - self.send_header('Content-Type', 'text/plain') + self.send_header("Content-Type", "text/plain") self.end_headers() mu.acquire() now = time.time() - out = yaml.dump({ - 'pool': pool, - 'in_use': dict((k, now - v) for k, v in list(in_use.items())) - }) + out = yaml.dump( + { + "pool": pool, + "in_use": dict( + (k, now - v) for k, v in list(in_use.items()) + ), + } + ) mu.release() - self.wfile.write(out.encode('ascii')) - elif self.path == '/quitquitquit': + self.wfile.write(out.encode("ascii")) + elif self.path == "/quitquitquit": self.send_response(200) self.end_headers() self.server.shutdown() @@ -207,4 +275,4 @@ class ThreadedHTTPServer(ThreadingMixIn, HTTPServer): """Handle requests in a separate thread""" -ThreadedHTTPServer(('', args.port), Handler).serve_forever() +ThreadedHTTPServer(("", args.port), Handler).serve_forever() diff --git a/tools/run_tests/python_utils/report_utils.py b/tools/run_tests/python_utils/report_utils.py index 5cf7851b5f0cf..e54e15ce9f263 100644 --- a/tools/run_tests/python_utils/report_utils.py +++ b/tools/run_tests/python_utils/report_utils.py @@ -17,7 +17,7 @@ from mako import exceptions from mako.runtime import Context from mako.template import Template -except (ImportError): +except ImportError: pass # Mako not installed but it is ok. import datetime import os @@ -29,38 +29,44 @@ def _filter_msg(msg, output_format): """Filters out nonprintable and illegal characters from the message.""" - if output_format in ['XML', 'HTML']: + if output_format in ["XML", "HTML"]: if isinstance(msg, bytes): - decoded_msg = msg.decode('UTF-8', 'ignore') + decoded_msg = msg.decode("UTF-8", "ignore") else: decoded_msg = msg # keep whitespaces but remove formfeed and vertical tab characters # that make XML report unparsable. 
- filtered_msg = ''.join( - filter(lambda x: x in string.printable and x != '\f' and x != '\v', - decoded_msg)) - if output_format == 'HTML': - filtered_msg = filtered_msg.replace('"', '"') + filtered_msg = "".join( + filter( + lambda x: x in string.printable and x != "\f" and x != "\v", + decoded_msg, + ) + ) + if output_format == "HTML": + filtered_msg = filtered_msg.replace('"', """) return filtered_msg else: return msg def new_junit_xml_tree(): - return ET.ElementTree(ET.Element('testsuites')) + return ET.ElementTree(ET.Element("testsuites")) -def render_junit_xml_report(resultset, - report_file, - suite_package='grpc', - suite_name='tests', - replace_dots=True, - multi_target=False): +def render_junit_xml_report( + resultset, + report_file, + suite_package="grpc", + suite_name="tests", + replace_dots=True, + multi_target=False, +): """Generate JUnit-like XML report.""" if not multi_target: tree = new_junit_xml_tree() - append_junit_xml_results(tree, resultset, suite_package, suite_name, - '1', replace_dots) + append_junit_xml_results( + tree, resultset, suite_package, suite_name, "1", replace_dots + ) create_xml_report_file(tree, report_file) else: # To have each test result displayed as a separate target by the Resultstore/Sponge UI, @@ -68,83 +74,100 @@ def render_junit_xml_report(resultset, for shortname, results in six.iteritems(resultset): one_result = {shortname: results} tree = new_junit_xml_tree() - append_junit_xml_results(tree, one_result, - '%s_%s' % (suite_package, shortname), - '%s_%s' % (suite_name, shortname), '1', - replace_dots) - per_suite_report_file = os.path.join(os.path.dirname(report_file), - shortname, - os.path.basename(report_file)) + append_junit_xml_results( + tree, + one_result, + "%s_%s" % (suite_package, shortname), + "%s_%s" % (suite_name, shortname), + "1", + replace_dots, + ) + per_suite_report_file = os.path.join( + os.path.dirname(report_file), + shortname, + os.path.basename(report_file), + ) create_xml_report_file(tree, per_suite_report_file) def create_xml_report_file(tree, report_file): """Generate JUnit-like report file from xml tree .""" # env variable can be used to override the base location for the reports - base_dir = os.getenv('GRPC_TEST_REPORT_BASE_DIR', None) + base_dir = os.getenv("GRPC_TEST_REPORT_BASE_DIR", None) if base_dir: report_file = os.path.join(base_dir, report_file) # ensure the report directory exists report_dir = os.path.dirname(os.path.abspath(report_file)) if not os.path.exists(report_dir): os.makedirs(report_dir) - tree.write(report_file, encoding='UTF-8') + tree.write(report_file, encoding="UTF-8") -def append_junit_xml_results(tree, - resultset, - suite_package, - suite_name, - id, - replace_dots=True): +def append_junit_xml_results( + tree, resultset, suite_package, suite_name, id, replace_dots=True +): """Append a JUnit-like XML report tree with test results as a new suite.""" if replace_dots: # ResultStore UI displays test suite names containing dots only as the component # after the last dot, which results bad info being displayed in the UI. # We replace dots by another character to avoid this problem. 
- suite_name = suite_name.replace('.', '_') - testsuite = ET.SubElement(tree.getroot(), - 'testsuite', - id=id, - package=suite_package, - name=suite_name, - timestamp=datetime.datetime.now().isoformat()) + suite_name = suite_name.replace(".", "_") + testsuite = ET.SubElement( + tree.getroot(), + "testsuite", + id=id, + package=suite_package, + name=suite_name, + timestamp=datetime.datetime.now().isoformat(), + ) failure_count = 0 error_count = 0 for shortname, results in six.iteritems(resultset): for result in results: - xml_test = ET.SubElement(testsuite, 'testcase', name=shortname) + xml_test = ET.SubElement(testsuite, "testcase", name=shortname) if result.elapsed_time: - xml_test.set('time', str(result.elapsed_time)) - filtered_msg = _filter_msg(result.message, 'XML') - if result.state == 'FAILED': - ET.SubElement(xml_test, 'failure', - message='Failure').text = filtered_msg + xml_test.set("time", str(result.elapsed_time)) + filtered_msg = _filter_msg(result.message, "XML") + if result.state == "FAILED": + ET.SubElement( + xml_test, "failure", message="Failure" + ).text = filtered_msg failure_count += 1 - elif result.state == 'TIMEOUT': - ET.SubElement(xml_test, 'error', - message='Timeout').text = filtered_msg + elif result.state == "TIMEOUT": + ET.SubElement( + xml_test, "error", message="Timeout" + ).text = filtered_msg error_count += 1 - elif result.state == 'SKIPPED': - ET.SubElement(xml_test, 'skipped', message='Skipped') - testsuite.set('failures', str(failure_count)) - testsuite.set('errors', str(error_count)) - - -def render_interop_html_report(client_langs, server_langs, test_cases, - auth_test_cases, http2_cases, http2_server_cases, - resultset, num_failures, cloud_to_prod, - prod_servers, http2_interop): + elif result.state == "SKIPPED": + ET.SubElement(xml_test, "skipped", message="Skipped") + testsuite.set("failures", str(failure_count)) + testsuite.set("errors", str(error_count)) + + +def render_interop_html_report( + client_langs, + server_langs, + test_cases, + auth_test_cases, + http2_cases, + http2_server_cases, + resultset, + num_failures, + cloud_to_prod, + prod_servers, + http2_interop, +): """Generate HTML report for interop tests.""" - template_file = 'tools/run_tests/interop/interop_html_report.template' + template_file = "tools/run_tests/interop/interop_html_report.template" try: mytemplate = Template(filename=template_file, format_exceptions=True) except NameError: print( - 'Mako template is not installed. Skipping HTML report generation.') + "Mako template is not installed. Skipping HTML report generation." 
+ ) return except IOError as e: - print(('Failed to find the template %s: %s' % (template_file, e))) + print(("Failed to find the template %s: %s" % (template_file, e))) return sorted_test_cases = sorted(test_cases) @@ -156,25 +179,25 @@ def render_interop_html_report(client_langs, server_langs, test_cases, sorted_prod_servers = sorted(prod_servers) args = { - 'client_langs': sorted_client_langs, - 'server_langs': sorted_server_langs, - 'test_cases': sorted_test_cases, - 'auth_test_cases': sorted_auth_test_cases, - 'http2_cases': sorted_http2_cases, - 'http2_server_cases': sorted_http2_server_cases, - 'resultset': resultset, - 'num_failures': num_failures, - 'cloud_to_prod': cloud_to_prod, - 'prod_servers': sorted_prod_servers, - 'http2_interop': http2_interop + "client_langs": sorted_client_langs, + "server_langs": sorted_server_langs, + "test_cases": sorted_test_cases, + "auth_test_cases": sorted_auth_test_cases, + "http2_cases": sorted_http2_cases, + "http2_server_cases": sorted_http2_server_cases, + "resultset": resultset, + "num_failures": num_failures, + "cloud_to_prod": cloud_to_prod, + "prod_servers": sorted_prod_servers, + "http2_interop": http2_interop, } - html_report_out_dir = 'reports' + html_report_out_dir = "reports" if not os.path.exists(html_report_out_dir): os.mkdir(html_report_out_dir) - html_file_path = os.path.join(html_report_out_dir, 'index.html') + html_file_path = os.path.join(html_report_out_dir, "index.html") try: - with open(html_file_path, 'w') as output_file: + with open(html_file_path, "w") as output_file: mytemplate.render_context(Context(output_file, **args)) except: print((exceptions.text_error_template().render())) @@ -182,8 +205,8 @@ def render_interop_html_report(client_langs, server_langs, test_cases, def render_perf_profiling_results(output_filepath, profile_names): - with open(output_filepath, 'w') as output_file: - output_file.write('
<ul>\n') + with open(output_filepath, "w") as output_file: + output_file.write("<ul>\n") for name in profile_names: - output_file.write('<li><a href=%s>%s</a></li>\n' % (name, name)) - output_file.write('</ul>\n') + output_file.write("<li><a href=%s>%s</a></li>\n" % (name, name)) + output_file.write("</ul>
\n") diff --git a/tools/run_tests/python_utils/start_port_server.py b/tools/run_tests/python_utils/start_port_server.py index 15eada447733f..7c3b208d2c362 100644 --- a/tools/run_tests/python_utils/start_port_server.py +++ b/tools/run_tests/python_utils/start_port_server.py @@ -38,97 +38,118 @@ def start_port_server(): # otherwise, leave it up try: version = int( - request.urlopen('http://localhost:%d/version_number' % - _PORT_SERVER_PORT).read()) - logging.info('detected port server running version %d', version) + request.urlopen( + "http://localhost:%d/version_number" % _PORT_SERVER_PORT + ).read() + ) + logging.info("detected port server running version %d", version) running = True except Exception as e: - logging.exception('failed to detect port server') + logging.exception("failed to detect port server") running = False if running: current_version = int( - subprocess.check_output([ - sys.executable, # use the same python binary as this process - os.path.abspath('tools/run_tests/python_utils/port_server.py'), - 'dump_version' - ]).decode()) - logging.info('my port server is version %d', current_version) - running = (version >= current_version) + subprocess.check_output( + [ + sys.executable, # use the same python binary as this process + os.path.abspath( + "tools/run_tests/python_utils/port_server.py" + ), + "dump_version", + ] + ).decode() + ) + logging.info("my port server is version %d", current_version) + running = version >= current_version if not running: - logging.info('port_server version mismatch: killing the old one') - request.urlopen('http://localhost:%d/quitquitquit' % - _PORT_SERVER_PORT).read() + logging.info("port_server version mismatch: killing the old one") + request.urlopen( + "http://localhost:%d/quitquitquit" % _PORT_SERVER_PORT + ).read() time.sleep(1) if not running: fd, logfile = tempfile.mkstemp() os.close(fd) - logging.info('starting port_server, with log file %s', logfile) + logging.info("starting port_server, with log file %s", logfile) args = [ sys.executable, - os.path.abspath('tools/run_tests/python_utils/port_server.py'), - '-p', - '%d' % _PORT_SERVER_PORT, '-l', logfile + os.path.abspath("tools/run_tests/python_utils/port_server.py"), + "-p", + "%d" % _PORT_SERVER_PORT, + "-l", + logfile, ] env = dict(os.environ) - env['BUILD_ID'] = 'pleaseDontKillMeJenkins' - if jobset.platform_string() == 'windows': + env["BUILD_ID"] = "pleaseDontKillMeJenkins" + if jobset.platform_string() == "windows": # Working directory of port server needs to be outside of Jenkins # workspace to prevent file lock issues. 
tempdir = tempfile.mkdtemp() if sys.version_info.major == 2: creationflags = 0x00000008 # detached process else: - creationflags = 0 # DETACHED_PROCESS doesn't seem to work with python3 - port_server = subprocess.Popen(args, - env=env, - cwd=tempdir, - creationflags=creationflags, - close_fds=True) + creationflags = ( + 0 # DETACHED_PROCESS doesn't seem to work with python3 + ) + port_server = subprocess.Popen( + args, + env=env, + cwd=tempdir, + creationflags=creationflags, + close_fds=True, + ) else: - port_server = subprocess.Popen(args, - env=env, - preexec_fn=os.setsid, - close_fds=True) + port_server = subprocess.Popen( + args, env=env, preexec_fn=os.setsid, close_fds=True + ) time.sleep(1) # ensure port server is up waits = 0 while True: if waits > 10: logging.warning( - 'killing port server due to excessive start up waits') + "killing port server due to excessive start up waits" + ) port_server.kill() if port_server.poll() is not None: - logging.error('port_server failed to start') + logging.error("port_server failed to start") # try one final time: maybe another build managed to start one time.sleep(1) try: - request.urlopen('http://localhost:%d/get' % - _PORT_SERVER_PORT).read() + request.urlopen( + "http://localhost:%d/get" % _PORT_SERVER_PORT + ).read() logging.info( - 'last ditch attempt to contact port server succeeded') + "last ditch attempt to contact port server succeeded" + ) break except: logging.exception( - 'final attempt to contact port server failed') - port_log = open(logfile, 'r').read() + "final attempt to contact port server failed" + ) + port_log = open(logfile, "r").read() print(port_log) sys.exit(1) try: - port_server_url = 'http://localhost:%d/get' % _PORT_SERVER_PORT + port_server_url = "http://localhost:%d/get" % _PORT_SERVER_PORT request.urlopen(port_server_url).read() - logging.info('port server is up and ready') + logging.info("port server is up and ready") break except socket.timeout: - logging.exception('while waiting for port_server') + logging.exception("while waiting for port_server") time.sleep(1) waits += 1 except IOError: - logging.exception('while waiting for port_server') + logging.exception("while waiting for port_server") time.sleep(1) waits += 1 except: logging.exception( - 'error while contacting port server at "%s".' - 'Will try killing it.', port_server_url) + ( + 'error while contacting port server at "%s".' + "Will try killing it." 
+ ), + port_server_url, + ) port_server.kill() raise diff --git a/tools/run_tests/python_utils/upload_rbe_results.py b/tools/run_tests/python_utils/upload_rbe_results.py index 4dd6b138fa607..8f506fdba59d2 100755 --- a/tools/run_tests/python_utils/upload_rbe_results.py +++ b/tools/run_tests/python_utils/upload_rbe_results.py @@ -25,56 +25,61 @@ import uuid gcp_utils_dir = os.path.abspath( - os.path.join(os.path.dirname(__file__), '../../gcp/utils')) + os.path.join(os.path.dirname(__file__), "../../gcp/utils") +) sys.path.append(gcp_utils_dir) import big_query_utils -_DATASET_ID = 'jenkins_test_results' -_DESCRIPTION = 'Test results from master RBE builds on Kokoro' +_DATASET_ID = "jenkins_test_results" +_DESCRIPTION = "Test results from master RBE builds on Kokoro" # 365 days in milliseconds _EXPIRATION_MS = 365 * 24 * 60 * 60 * 1000 -_PARTITION_TYPE = 'DAY' -_PROJECT_ID = 'grpc-testing' +_PARTITION_TYPE = "DAY" +_PROJECT_ID = "grpc-testing" _RESULTS_SCHEMA = [ - ('job_name', 'STRING', 'Name of Kokoro job'), - ('build_id', 'INTEGER', 'Build ID of Kokoro job'), - ('build_url', 'STRING', 'URL of Kokoro build'), - ('test_target', 'STRING', 'Bazel target path'), - ('test_class_name', 'STRING', 'Name of test class'), - ('test_case', 'STRING', 'Name of test case'), - ('result', 'STRING', 'Test or build result'), - ('timestamp', 'TIMESTAMP', 'Timestamp of test run'), - ('duration', 'FLOAT', 'Duration of the test run'), + ("job_name", "STRING", "Name of Kokoro job"), + ("build_id", "INTEGER", "Build ID of Kokoro job"), + ("build_url", "STRING", "URL of Kokoro build"), + ("test_target", "STRING", "Bazel target path"), + ("test_class_name", "STRING", "Name of test class"), + ("test_case", "STRING", "Name of test case"), + ("result", "STRING", "Test or build result"), + ("timestamp", "TIMESTAMP", "Timestamp of test run"), + ("duration", "FLOAT", "Duration of the test run"), ] -_TABLE_ID = 'rbe_test_results' +_TABLE_ID = "rbe_test_results" def _get_api_key(): """Returns string with API key to access ResultStore. - Intended to be used in Kokoro environment.""" - api_key_directory = os.getenv('KOKORO_GFILE_DIR') - api_key_file = os.path.join(api_key_directory, 'resultstore_api_key') - assert os.path.isfile(api_key_file), 'Must add --api_key arg if not on ' \ - 'Kokoro or Kokoro environment is not set up properly.' - with open(api_key_file, 'r') as f: - return f.read().replace('\n', '') + Intended to be used in Kokoro environment.""" + api_key_directory = os.getenv("KOKORO_GFILE_DIR") + api_key_file = os.path.join(api_key_directory, "resultstore_api_key") + assert os.path.isfile(api_key_file), ( + "Must add --api_key arg if not on " + "Kokoro or Kokoro environment is not set up properly." + ) + with open(api_key_file, "r") as f: + return f.read().replace("\n", "") def _get_invocation_id(): """Returns String of Bazel invocation ID. Intended to be used in - Kokoro environment.""" - bazel_id_directory = os.getenv('KOKORO_ARTIFACTS_DIR') - bazel_id_file = os.path.join(bazel_id_directory, 'bazel_invocation_ids') - assert os.path.isfile(bazel_id_file), 'bazel_invocation_ids file, written ' \ - 'by RBE initialization script, expected but not found.' 
- with open(bazel_id_file, 'r') as f: - return f.read().replace('\n', '') + Kokoro environment.""" + bazel_id_directory = os.getenv("KOKORO_ARTIFACTS_DIR") + bazel_id_file = os.path.join(bazel_id_directory, "bazel_invocation_ids") + assert os.path.isfile(bazel_id_file), ( + "bazel_invocation_ids file, written " + "by RBE initialization script, expected but not found." + ) + with open(bazel_id_file, "r") as f: + return f.read().replace("\n", "") def _parse_test_duration(duration_str): """Parse test duration string in '123.567s' format""" try: - if duration_str.endswith('s'): + if duration_str.endswith("s"): duration_str = duration_str[:-1] return float(duration_str) except: @@ -84,50 +89,54 @@ def _parse_test_duration(duration_str): def _upload_results_to_bq(rows): """Upload test results to a BQ table. - Args: - rows: A list of dictionaries containing data for each row to insert - """ + Args: + rows: A list of dictionaries containing data for each row to insert + """ bq = big_query_utils.create_big_query() - big_query_utils.create_partitioned_table(bq, - _PROJECT_ID, - _DATASET_ID, - _TABLE_ID, - _RESULTS_SCHEMA, - _DESCRIPTION, - partition_type=_PARTITION_TYPE, - expiration_ms=_EXPIRATION_MS) + big_query_utils.create_partitioned_table( + bq, + _PROJECT_ID, + _DATASET_ID, + _TABLE_ID, + _RESULTS_SCHEMA, + _DESCRIPTION, + partition_type=_PARTITION_TYPE, + expiration_ms=_EXPIRATION_MS, + ) max_retries = 3 for attempt in range(max_retries): - if big_query_utils.insert_rows(bq, _PROJECT_ID, _DATASET_ID, _TABLE_ID, - rows): + if big_query_utils.insert_rows( + bq, _PROJECT_ID, _DATASET_ID, _TABLE_ID, rows + ): break else: if attempt < max_retries - 1: - print('Error uploading result to bigquery, will retry.') + print("Error uploading result to bigquery, will retry.") else: print( - 'Error uploading result to bigquery, all attempts failed.') + "Error uploading result to bigquery, all attempts failed." + ) sys.exit(1) def _get_resultstore_data(api_key, invocation_id): """Returns dictionary of test results by querying ResultStore API. - Args: - api_key: String of ResultStore API key - invocation_id: String of ResultStore invocation ID to results from - """ + Args: + api_key: String of ResultStore API key + invocation_id: String of ResultStore invocation ID to results from + """ all_actions = [] - page_token = '' + page_token = "" # ResultStore's API returns data on a limited number of tests. When we exceed # that limit, the 'nextPageToken' field is included in the request to get # subsequent data, so keep requesting until 'nextPageToken' field is omitted. 
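Aside: the comment above describes the token-based pagination that the hunk below implements. A condensed standalone sketch of that loop, for readers skimming the reflowed diff; the helper name fetch_all_actions is illustrative, and the fields/filter query parameters and Content-Type header used by the real script are omitted here:

    import json
    import urllib.request

    def fetch_all_actions(invocation_id, api_key):
        # Keep requesting pages until the response no longer carries
        # a 'nextPageToken' field, i.e. the last page was reached.
        actions, page_token = [], ""
        while True:
            url = (
                "https://resultstore.googleapis.com/v2/invocations/%s"
                "/targets/-/configuredTargets/-/actions?key=%s&pageToken=%s"
                % (invocation_id, api_key, page_token)
            )
            resp = json.loads(
                urllib.request.urlopen(url).read().decode("utf-8")
            )
            actions.extend(resp.get("actions", []))
            if "nextPageToken" not in resp:
                return actions
            page_token = resp["nextPageToken"]
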
while True: req = urllib.request.Request( - url= - 'https://resultstore.googleapis.com/v2/invocations/%s/targets/-/configuredTargets/-/actions?key=%s&pageToken=%s&fields=next_page_token,actions.id,actions.status_attributes,actions.timing,actions.test_action' + url="https://resultstore.googleapis.com/v2/invocations/%s/targets/-/configuredTargets/-/actions?key=%s&pageToken=%s&fields=next_page_token,actions.id,actions.status_attributes,actions.timing,actions.test_action" % (invocation_id, api_key, page_token), - headers={'Content-Type': 'application/json'}) + headers={"Content-Type": "application/json"}, + ) ctx_dict = {} if os.getenv("PYTHONHTTPSVERIFY") == "0": ctx = ssl.create_default_context() @@ -135,43 +144,58 @@ def _get_resultstore_data(api_key, invocation_id): ctx.verify_mode = ssl.CERT_NONE ctx_dict = {"context": ctx} raw_resp = urllib.request.urlopen(req, **ctx_dict).read() - decoded_resp = raw_resp if isinstance( - raw_resp, str) else raw_resp.decode('utf-8', 'ignore') + decoded_resp = ( + raw_resp + if isinstance(raw_resp, str) + else raw_resp.decode("utf-8", "ignore") + ) results = json.loads(decoded_resp) - all_actions.extend(results['actions']) - if 'nextPageToken' not in results: + all_actions.extend(results["actions"]) + if "nextPageToken" not in results: break - page_token = results['nextPageToken'] + page_token = results["nextPageToken"] return all_actions if __name__ == "__main__": # Arguments are necessary if running in a non-Kokoro environment. argp = argparse.ArgumentParser( - description= - 'Fetches results for given RBE invocation and uploads them to BigQuery table.' + description=( + "Fetches results for given RBE invocation and uploads them to" + " BigQuery table." + ) + ) + argp.add_argument( + "--api_key", + default="", + type=str, + help="The API key to read from ResultStore API", + ) + argp.add_argument( + "--invocation_id", + default="", + type=str, + help="UUID of bazel invocation to fetch.", + ) + argp.add_argument( + "--bq_dump_file", + default=None, + type=str, + help="Dump JSON data to file just before uploading", + ) + argp.add_argument( + "--resultstore_dump_file", + default=None, + type=str, + help="Dump JSON data as received from ResultStore API", + ) + argp.add_argument( + "--skip_upload", + default=False, + action="store_const", + const=True, + help="Skip uploading to bigquery", ) - argp.add_argument('--api_key', - default='', - type=str, - help='The API key to read from ResultStore API') - argp.add_argument('--invocation_id', - default='', - type=str, - help='UUID of bazel invocation to fetch.') - argp.add_argument('--bq_dump_file', - default=None, - type=str, - help='Dump JSON data to file just before uploading') - argp.add_argument('--resultstore_dump_file', - default=None, - type=str, - help='Dump JSON data as received from ResultStore API') - argp.add_argument('--skip_upload', - default=False, - action='store_const', - const=True, - help='Skip uploading to bigquery') args = argp.parse_args() api_key = args.api_key or _get_api_key() @@ -179,129 +203,128 @@ def _get_resultstore_data(api_key, invocation_id): resultstore_actions = _get_resultstore_data(api_key, invocation_id) if args.resultstore_dump_file: - with open(args.resultstore_dump_file, 'w') as f: + with open(args.resultstore_dump_file, "w") as f: json.dump(resultstore_actions, f, indent=4, sort_keys=True) print( - ('Dumped resultstore data to file %s' % args.resultstore_dump_file)) + ("Dumped resultstore data to file %s" % args.resultstore_dump_file) + ) # 
google.devtools.resultstore.v2.Action schema: # https://github.com/googleapis/googleapis/blob/master/google/devtools/resultstore/v2/action.proto bq_rows = [] for index, action in enumerate(resultstore_actions): # Filter out non-test related data, such as build results. - if 'testAction' not in action: + if "testAction" not in action: continue # Some test results contain the fileProcessingErrors field, which indicates # an issue with parsing results individual test cases. - if 'fileProcessingErrors' in action: - test_cases = [{ - 'testCase': { - 'caseName': str(action['id']['actionId']), + if "fileProcessingErrors" in action: + test_cases = [ + { + "testCase": { + "caseName": str(action["id"]["actionId"]), + } } - }] + ] # Test timeouts have a different dictionary structure compared to pass and # fail results. - elif action['statusAttributes']['status'] == 'TIMED_OUT': - test_cases = [{ - 'testCase': { - 'caseName': str(action['id']['actionId']), - 'timedOut': True + elif action["statusAttributes"]["status"] == "TIMED_OUT": + test_cases = [ + { + "testCase": { + "caseName": str(action["id"]["actionId"]), + "timedOut": True, + } } - }] + ] # When RBE believes its infrastructure is failing, it will abort and # mark running tests as UNKNOWN. These infrastructure failures may be # related to our tests, so we should investigate if specific tests are # repeatedly being marked as UNKNOWN. - elif action['statusAttributes']['status'] == 'UNKNOWN': - test_cases = [{ - 'testCase': { - 'caseName': str(action['id']['actionId']), - 'unknown': True + elif action["statusAttributes"]["status"] == "UNKNOWN": + test_cases = [ + { + "testCase": { + "caseName": str(action["id"]["actionId"]), + "unknown": True, + } } - }] + ] # Take the timestamp from the previous action, which should be # a close approximation. 
- action['timing'] = { - 'startTime': - resultstore_actions[index - 1]['timing']['startTime'] + action["timing"] = { + "startTime": resultstore_actions[index - 1]["timing"][ + "startTime" + ] } - elif 'testSuite' not in action['testAction']: + elif "testSuite" not in action["testAction"]: continue - elif 'tests' not in action['testAction']['testSuite']: + elif "tests" not in action["testAction"]["testSuite"]: continue else: test_cases = [] - for tests_item in action['testAction']['testSuite']['tests']: - test_cases += tests_item['testSuite']['tests'] + for tests_item in action["testAction"]["testSuite"]["tests"]: + test_cases += tests_item["testSuite"]["tests"] for test_case in test_cases: - if any(s in test_case['testCase'] for s in ['errors', 'failures']): - result = 'FAILED' - elif 'timedOut' in test_case['testCase']: - result = 'TIMEOUT' - elif 'unknown' in test_case['testCase']: - result = 'UNKNOWN' + if any(s in test_case["testCase"] for s in ["errors", "failures"]): + result = "FAILED" + elif "timedOut" in test_case["testCase"]: + result = "TIMEOUT" + elif "unknown" in test_case["testCase"]: + result = "UNKNOWN" else: - result = 'PASSED' + result = "PASSED" try: - bq_rows.append({ - 'insertId': str(uuid.uuid4()), - 'json': { - 'job_name': - os.getenv('KOKORO_JOB_NAME'), - 'build_id': - os.getenv('KOKORO_BUILD_NUMBER'), - 'build_url': - 'https://source.cloud.google.com/results/invocations/%s' + bq_rows.append( + { + "insertId": str(uuid.uuid4()), + "json": { + "job_name": os.getenv("KOKORO_JOB_NAME"), + "build_id": os.getenv("KOKORO_BUILD_NUMBER"), + "build_url": "https://source.cloud.google.com/results/invocations/%s" % invocation_id, - 'test_target': - action['id']['targetId'], - 'test_class_name': - test_case['testCase'].get('className', ''), - 'test_case': - test_case['testCase']['caseName'], - 'result': - result, - 'timestamp': - action['timing']['startTime'], - 'duration': - _parse_test_duration(action['timing']['duration']), + "test_target": action["id"]["targetId"], + "test_class_name": test_case["testCase"].get( + "className", "" + ), + "test_case": test_case["testCase"]["caseName"], + "result": result, + "timestamp": action["timing"]["startTime"], + "duration": _parse_test_duration( + action["timing"]["duration"] + ), + }, } - }) + ) except Exception as e: - print(('Failed to parse test result. Error: %s' % str(e))) + print(("Failed to parse test result. 
Error: %s" % str(e))) print((json.dumps(test_case, indent=4))) - bq_rows.append({ - 'insertId': str(uuid.uuid4()), - 'json': { - 'job_name': - os.getenv('KOKORO_JOB_NAME'), - 'build_id': - os.getenv('KOKORO_BUILD_NUMBER'), - 'build_url': - 'https://source.cloud.google.com/results/invocations/%s' + bq_rows.append( + { + "insertId": str(uuid.uuid4()), + "json": { + "job_name": os.getenv("KOKORO_JOB_NAME"), + "build_id": os.getenv("KOKORO_BUILD_NUMBER"), + "build_url": "https://source.cloud.google.com/results/invocations/%s" % invocation_id, - 'test_target': - action['id']['targetId'], - 'test_class_name': - 'N/A', - 'test_case': - 'N/A', - 'result': - 'UNPARSEABLE', - 'timestamp': - 'N/A', + "test_target": action["id"]["targetId"], + "test_class_name": "N/A", + "test_case": "N/A", + "result": "UNPARSEABLE", + "timestamp": "N/A", + }, } - }) + ) if args.bq_dump_file: - with open(args.bq_dump_file, 'w') as f: + with open(args.bq_dump_file, "w") as f: json.dump(bq_rows, f, indent=4, sort_keys=True) - print(('Dumped BQ data to file %s' % args.bq_dump_file)) + print(("Dumped BQ data to file %s" % args.bq_dump_file)) if not args.skip_upload: # BigQuery sometimes fails with large uploads, so batch 1,000 rows at a time. MAX_ROWS = 1000 for i in range(0, len(bq_rows), MAX_ROWS): - _upload_results_to_bq(bq_rows[i:i + MAX_ROWS]) + _upload_results_to_bq(bq_rows[i : i + MAX_ROWS]) else: - print('Skipped upload to bigquery.') + print("Skipped upload to bigquery.") diff --git a/tools/run_tests/python_utils/upload_test_results.py b/tools/run_tests/python_utils/upload_test_results.py index c18ad6b955cfd..752ffac469016 100644 --- a/tools/run_tests/python_utils/upload_test_results.py +++ b/tools/run_tests/python_utils/upload_test_results.py @@ -24,65 +24,74 @@ import six gcp_utils_dir = os.path.abspath( - os.path.join(os.path.dirname(__file__), '../../gcp/utils')) + os.path.join(os.path.dirname(__file__), "../../gcp/utils") +) sys.path.append(gcp_utils_dir) import big_query_utils -_DATASET_ID = 'jenkins_test_results' -_DESCRIPTION = 'Test results from master job run on Jenkins' +_DATASET_ID = "jenkins_test_results" +_DESCRIPTION = "Test results from master job run on Jenkins" # 365 days in milliseconds _EXPIRATION_MS = 365 * 24 * 60 * 60 * 1000 -_PARTITION_TYPE = 'DAY' -_PROJECT_ID = 'grpc-testing' +_PARTITION_TYPE = "DAY" +_PROJECT_ID = "grpc-testing" _RESULTS_SCHEMA = [ - ('job_name', 'STRING', 'Name of Jenkins job'), - ('build_id', 'INTEGER', 'Build ID of Jenkins job'), - ('build_url', 'STRING', 'URL of Jenkins job'), - ('test_name', 'STRING', 'Individual test name'), - ('language', 'STRING', 'Language of test'), - ('platform', 'STRING', 'Platform used for test'), - ('config', 'STRING', 'Config used for test'), - ('compiler', 'STRING', 'Compiler used for test'), - ('iomgr_platform', 'STRING', 'Iomgr used for test'), - ('result', 'STRING', 'Test result: PASSED, TIMEOUT, FAILED, or SKIPPED'), - ('timestamp', 'TIMESTAMP', 'Timestamp of test run'), - ('elapsed_time', 'FLOAT', 'How long test took to run'), - ('cpu_estimated', 'FLOAT', 'Estimated CPU usage of test'), - ('cpu_measured', 'FLOAT', 'Actual CPU usage of test'), - ('return_code', 'INTEGER', 'Exit code of test'), + ("job_name", "STRING", "Name of Jenkins job"), + ("build_id", "INTEGER", "Build ID of Jenkins job"), + ("build_url", "STRING", "URL of Jenkins job"), + ("test_name", "STRING", "Individual test name"), + ("language", "STRING", "Language of test"), + ("platform", "STRING", "Platform used for test"), + ("config", "STRING", "Config used for 
test"), + ("compiler", "STRING", "Compiler used for test"), + ("iomgr_platform", "STRING", "Iomgr used for test"), + ("result", "STRING", "Test result: PASSED, TIMEOUT, FAILED, or SKIPPED"), + ("timestamp", "TIMESTAMP", "Timestamp of test run"), + ("elapsed_time", "FLOAT", "How long test took to run"), + ("cpu_estimated", "FLOAT", "Estimated CPU usage of test"), + ("cpu_measured", "FLOAT", "Actual CPU usage of test"), + ("return_code", "INTEGER", "Exit code of test"), ] _INTEROP_RESULTS_SCHEMA = [ - ('job_name', 'STRING', 'Name of Jenkins/Kokoro job'), - ('build_id', 'INTEGER', 'Build ID of Jenkins/Kokoro job'), - ('build_url', 'STRING', 'URL of Jenkins/Kokoro job'), - ('test_name', 'STRING', - 'Unique test name combining client, server, and test_name'), - ('suite', 'STRING', - 'Test suite: cloud_to_cloud, cloud_to_prod, or cloud_to_prod_auth'), - ('client', 'STRING', 'Client language'), - ('server', 'STRING', 'Server host name'), - ('test_case', 'STRING', 'Name of test case'), - ('result', 'STRING', 'Test result: PASSED, TIMEOUT, FAILED, or SKIPPED'), - ('timestamp', 'TIMESTAMP', 'Timestamp of test run'), - ('elapsed_time', 'FLOAT', 'How long test took to run'), + ("job_name", "STRING", "Name of Jenkins/Kokoro job"), + ("build_id", "INTEGER", "Build ID of Jenkins/Kokoro job"), + ("build_url", "STRING", "URL of Jenkins/Kokoro job"), + ( + "test_name", + "STRING", + "Unique test name combining client, server, and test_name", + ), + ( + "suite", + "STRING", + "Test suite: cloud_to_cloud, cloud_to_prod, or cloud_to_prod_auth", + ), + ("client", "STRING", "Client language"), + ("server", "STRING", "Server host name"), + ("test_case", "STRING", "Name of test case"), + ("result", "STRING", "Test result: PASSED, TIMEOUT, FAILED, or SKIPPED"), + ("timestamp", "TIMESTAMP", "Timestamp of test run"), + ("elapsed_time", "FLOAT", "How long test took to run"), ] def _get_build_metadata(test_results): """Add Kokoro build metadata to test_results based on environment - variables set by Kokoro. - """ - build_id = os.getenv('KOKORO_BUILD_NUMBER') - build_url = 'https://source.cloud.google.com/results/invocations/%s' % os.getenv( - 'KOKORO_BUILD_ID') - job_name = os.getenv('KOKORO_JOB_NAME') + variables set by Kokoro. + """ + build_id = os.getenv("KOKORO_BUILD_NUMBER") + build_url = ( + "https://source.cloud.google.com/results/invocations/%s" + % os.getenv("KOKORO_BUILD_ID") + ) + job_name = os.getenv("KOKORO_JOB_NAME") if build_id: - test_results['build_id'] = build_id + test_results["build_id"] = build_id if build_url: - test_results['build_url'] = build_url + test_results["build_url"] = build_url if job_name: - test_results['job_name'] = job_name + test_results["job_name"] = job_name def _insert_rows_with_retries(bq, bq_table, bq_rows): @@ -91,16 +100,21 @@ def _insert_rows_with_retries(bq, bq_table, bq_rows): for i in range((len(bq_rows) // 1000) + 1): max_retries = 3 for attempt in range(max_retries): - if big_query_utils.insert_rows(bq, _PROJECT_ID, _DATASET_ID, - bq_table, - bq_rows[i * 1000:(i + 1) * 1000]): + if big_query_utils.insert_rows( + bq, + _PROJECT_ID, + _DATASET_ID, + bq_table, + bq_rows[i * 1000 : (i + 1) * 1000], + ): break else: if attempt < max_retries - 1: - print('Error uploading result to bigquery, will retry.') + print("Error uploading result to bigquery, will retry.") else: print( - 'Error uploading result to bigquery, all attempts failed.' + "Error uploading result to bigquery, all attempts" + " failed." 
) sys.exit(1) @@ -108,33 +122,35 @@ def _insert_rows_with_retries(bq, bq_table, bq_rows): def upload_results_to_bq(resultset, bq_table, extra_fields): """Upload test results to a BQ table. - Args: - resultset: dictionary generated by jobset.run - bq_table: string name of table to create/upload results to in BQ - extra_fields: dict with extra values that will be uploaded along with the results - """ + Args: + resultset: dictionary generated by jobset.run + bq_table: string name of table to create/upload results to in BQ + extra_fields: dict with extra values that will be uploaded along with the results + """ bq = big_query_utils.create_big_query() - big_query_utils.create_partitioned_table(bq, - _PROJECT_ID, - _DATASET_ID, - bq_table, - _RESULTS_SCHEMA, - _DESCRIPTION, - partition_type=_PARTITION_TYPE, - expiration_ms=_EXPIRATION_MS) + big_query_utils.create_partitioned_table( + bq, + _PROJECT_ID, + _DATASET_ID, + bq_table, + _RESULTS_SCHEMA, + _DESCRIPTION, + partition_type=_PARTITION_TYPE, + expiration_ms=_EXPIRATION_MS, + ) bq_rows = [] for shortname, results in six.iteritems(resultset): for result in results: test_results = {} _get_build_metadata(test_results) - test_results['cpu_estimated'] = result.cpu_estimated - test_results['cpu_measured'] = result.cpu_measured - test_results['elapsed_time'] = '%.2f' % result.elapsed_time - test_results['result'] = result.state - test_results['return_code'] = result.returncode - test_results['test_name'] = shortname - test_results['timestamp'] = time.strftime('%Y-%m-%d %H:%M:%S') + test_results["cpu_estimated"] = result.cpu_estimated + test_results["cpu_measured"] = result.cpu_measured + test_results["elapsed_time"] = "%.2f" % result.elapsed_time + test_results["result"] = result.state + test_results["return_code"] = result.returncode + test_results["test_name"] = shortname + test_results["timestamp"] = time.strftime("%Y-%m-%d %H:%M:%S") for field_name, field_value in six.iteritems(extra_fields): test_results[field_name] = field_value row = big_query_utils.make_row(str(uuid.uuid4()), test_results) @@ -145,33 +161,35 @@ def upload_results_to_bq(resultset, bq_table, extra_fields): def upload_interop_results_to_bq(resultset, bq_table): """Upload interop test results to a BQ table. 
- Args: - resultset: dictionary generated by jobset.run - bq_table: string name of table to create/upload results to in BQ - """ + Args: + resultset: dictionary generated by jobset.run + bq_table: string name of table to create/upload results to in BQ + """ bq = big_query_utils.create_big_query() - big_query_utils.create_partitioned_table(bq, - _PROJECT_ID, - _DATASET_ID, - bq_table, - _INTEROP_RESULTS_SCHEMA, - _DESCRIPTION, - partition_type=_PARTITION_TYPE, - expiration_ms=_EXPIRATION_MS) + big_query_utils.create_partitioned_table( + bq, + _PROJECT_ID, + _DATASET_ID, + bq_table, + _INTEROP_RESULTS_SCHEMA, + _DESCRIPTION, + partition_type=_PARTITION_TYPE, + expiration_ms=_EXPIRATION_MS, + ) bq_rows = [] for shortname, results in six.iteritems(resultset): for result in results: test_results = {} _get_build_metadata(test_results) - test_results['elapsed_time'] = '%.2f' % result.elapsed_time - test_results['result'] = result.state - test_results['test_name'] = shortname - test_results['suite'] = shortname.split(':')[0] - test_results['client'] = shortname.split(':')[1] - test_results['server'] = shortname.split(':')[2] - test_results['test_case'] = shortname.split(':')[3] - test_results['timestamp'] = time.strftime('%Y-%m-%d %H:%M:%S') + test_results["elapsed_time"] = "%.2f" % result.elapsed_time + test_results["result"] = result.state + test_results["test_name"] = shortname + test_results["suite"] = shortname.split(":")[0] + test_results["client"] = shortname.split(":")[1] + test_results["server"] = shortname.split(":")[2] + test_results["test_case"] = shortname.split(":")[3] + test_results["timestamp"] = time.strftime("%Y-%m-%d %H:%M:%S") row = big_query_utils.make_row(str(uuid.uuid4()), test_results) bq_rows.append(row) _insert_rows_with_retries(bq, bq_table, bq_rows) diff --git a/tools/run_tests/python_utils/watch_dirs.py b/tools/run_tests/python_utils/watch_dirs.py index b9e2cfa0c8432..409f46d8c0fc0 100755 --- a/tools/run_tests/python_utils/watch_dirs.py +++ b/tools/run_tests/python_utils/watch_dirs.py @@ -40,7 +40,7 @@ def _calculate(self): continue for root, _, files in os.walk(path): for f in files: - if f and f[0] == '.': + if f and f[0] == ".": continue try: st = os.stat(os.path.join(root, f)) @@ -51,8 +51,9 @@ def _calculate(self): if most_recent_change is None: most_recent_change = st.st_mtime else: - most_recent_change = max(most_recent_change, - st.st_mtime) + most_recent_change = max( + most_recent_change, st.st_mtime + ) return most_recent_change def most_recent_change(self): diff --git a/tools/run_tests/run_grpclb_interop_tests.py b/tools/run_tests/run_grpclb_interop_tests.py index 376574b2922b3..da0c6bafb4d96 100755 --- a/tools/run_tests/run_grpclb_interop_tests.py +++ b/tools/run_tests/run_grpclb_interop_tests.py @@ -37,9 +37,9 @@ import python_utils.report_utils as report_utils # Docker doesn't clean up after itself, so we do it on exit. 
-atexit.register(lambda: subprocess.call(['stty', 'echo'])) +atexit.register(lambda: subprocess.call(["stty", "echo"])) -ROOT = os.path.abspath(os.path.join(os.path.dirname(sys.argv[0]), '../..')) +ROOT = os.path.abspath(os.path.join(os.path.dirname(sys.argv[0]), "../..")) os.chdir(ROOT) _FALLBACK_SERVER_PORT = 443 @@ -48,20 +48,19 @@ _TEST_TIMEOUT = 30 -_FAKE_SERVERS_SAFENAME = 'fake_servers' +_FAKE_SERVERS_SAFENAME = "fake_servers" # Use a name that's verified by the test certs -_SERVICE_NAME = 'server.test.google.fr' +_SERVICE_NAME = "server.test.google.fr" class CXXLanguage: - def __init__(self): - self.client_cwd = '/var/local/git/grpc' - self.safename = 'cxx' + self.client_cwd = "/var/local/git/grpc" + self.safename = "cxx" def client_cmd(self, args): - return ['bins/opt/interop_client'] + args + return ["bins/opt/interop_client"] + args def global_env(self): # 1) Set c-ares as the resolver, to @@ -72,24 +71,21 @@ def global_env(self): # GoogleDefaultCredentials to be # able to use the test CA. return { - 'GRPC_DNS_RESOLVER': - 'ares', - 'GRPC_VERBOSITY': - 'DEBUG', - 'GRPC_TRACE': - 'client_channel,glb', - 'GRPC_DEFAULT_SSL_ROOTS_FILE_PATH': - '/var/local/git/grpc/src/core/tsi/test_creds/ca.pem', + "GRPC_DNS_RESOLVER": "ares", + "GRPC_VERBOSITY": "DEBUG", + "GRPC_TRACE": "client_channel,glb", + "GRPC_DEFAULT_SSL_ROOTS_FILE_PATH": ( + "/var/local/git/grpc/src/core/tsi/test_creds/ca.pem" + ), } def __str__(self): - return 'c++' + return "c++" class JavaLanguage: - def __init__(self): - self.client_cwd = '/var/local/git/grpc-java' + self.client_cwd = "/var/local/git/grpc-java" self.safename = str(self) def client_cmd(self, args): @@ -97,43 +93,49 @@ def client_cmd(self, args): # the set of test CA's that the Java runtime of the # docker container will pick up, so that # Java GoogleDefaultCreds can use it. 
- pem_to_der_cmd = ('openssl x509 -outform der ' - '-in /external_mount/src/core/tsi/test_creds/ca.pem ' - '-out /tmp/test_ca.der') + pem_to_der_cmd = ( + "openssl x509 -outform der " + "-in /external_mount/src/core/tsi/test_creds/ca.pem " + "-out /tmp/test_ca.der" + ) keystore_import_cmd = ( - 'keytool -import ' - '-keystore /usr/lib/jvm/java-8-oracle/jre/lib/security/cacerts ' - '-file /tmp/test_ca.der ' - '-deststorepass changeit ' - '-noprompt') + "keytool -import " + "-keystore /usr/lib/jvm/java-8-oracle/jre/lib/security/cacerts " + "-file /tmp/test_ca.der " + "-deststorepass changeit " + "-noprompt" + ) return [ - 'bash', '-c', - ('{pem_to_der_cmd} && ' - '{keystore_import_cmd} && ' - './run-test-client.sh {java_client_args}').format( - pem_to_der_cmd=pem_to_der_cmd, - keystore_import_cmd=keystore_import_cmd, - java_client_args=' '.join(args)) + "bash", + "-c", + ( + "{pem_to_der_cmd} && " + "{keystore_import_cmd} && " + "./run-test-client.sh {java_client_args}" + ).format( + pem_to_der_cmd=pem_to_der_cmd, + keystore_import_cmd=keystore_import_cmd, + java_client_args=" ".join(args), + ), ] def global_env(self): # 1) Enable grpclb # 2) Enable verbose logging return { - 'JAVA_OPTS': ( - '-Dio.grpc.internal.DnsNameResolverProvider.enable_grpclb=true ' - '-Djava.util.logging.config.file=/var/local/grpc_java_logging/logconf.txt' + "JAVA_OPTS": ( + "-Dio.grpc.internal.DnsNameResolverProvider.enable_grpclb=true " + "-Djava.util.logging.config.file=/var/local/grpc_java_logging/logconf.txt" ) } def __str__(self): - return 'java' + return "java" class GoLanguage: - def __init__(self): - self.client_cwd = '/go/src/google.golang.org/grpc/interop/client' + self.client_cwd = "/go/src/google.golang.org/grpc/interop/client" self.safename = str(self) def client_cmd(self, args): @@ -142,38 +144,40 @@ def client_cmd(self, args): # that Go's GoogleDefaultCredentials can use it. # See https://golang.org/src/crypto/x509/root_linux.go. 
return [ - 'bash', '-c', - ('cp /external_mount/src/core/tsi/test_creds/ca.pem ' - '/etc/ssl/certs/ca-certificates.crt && ' - '/go/bin/client {go_client_args}').format( - go_client_args=' '.join(args)) + "bash", + "-c", + ( + "cp /external_mount/src/core/tsi/test_creds/ca.pem " + "/etc/ssl/certs/ca-certificates.crt && " + "/go/bin/client {go_client_args}" + ).format(go_client_args=" ".join(args)), ] def global_env(self): return { - 'GRPC_GO_LOG_VERBOSITY_LEVEL': '3', - 'GRPC_GO_LOG_SEVERITY_LEVEL': 'INFO' + "GRPC_GO_LOG_VERBOSITY_LEVEL": "3", + "GRPC_GO_LOG_SEVERITY_LEVEL": "INFO", } def __str__(self): - return 'go' + return "go" _LANGUAGES = { - 'c++': CXXLanguage(), - 'go': GoLanguage(), - 'java': JavaLanguage(), + "c++": CXXLanguage(), + "go": GoLanguage(), + "java": JavaLanguage(), } def docker_run_cmdline(cmdline, image, docker_args, cwd, environ=None): """Wraps given cmdline array to create 'docker run' cmdline from it.""" # turn environ into -e docker args - docker_cmdline = 'docker run -i --rm=true'.split() + docker_cmdline = "docker run -i --rm=true".split() if environ: for k, v in list(environ.items()): - docker_cmdline += ['-e', '%s=%s' % (k, v)] - return docker_cmdline + ['-w', cwd] + docker_args + [image] + cmdline + docker_cmdline += ["-e", "%s=%s" % (k, v)] + return docker_cmdline + ["-w", cwd] + docker_args + [image] + cmdline def _job_kill_handler(job): @@ -183,60 +187,66 @@ def _job_kill_handler(job): def transport_security_to_args(transport_security): args = [] - if transport_security == 'tls': - args += ['--use_tls=true'] - elif transport_security == 'alts': - args += ['--use_tls=false', '--use_alts=true'] - elif transport_security == 'insecure': - args += ['--use_tls=false'] - elif transport_security == 'google_default_credentials': - args += ['--custom_credentials_type=google_default_credentials'] + if transport_security == "tls": + args += ["--use_tls=true"] + elif transport_security == "alts": + args += ["--use_tls=false", "--use_alts=true"] + elif transport_security == "insecure": + args += ["--use_tls=false"] + elif transport_security == "google_default_credentials": + args += ["--custom_credentials_type=google_default_credentials"] else: - print('Invalid transport security option.') + print("Invalid transport security option.") sys.exit(1) return args -def lb_client_interop_jobspec(language, - dns_server_ip, - docker_image, - transport_security='tls'): +def lb_client_interop_jobspec( + language, dns_server_ip, docker_image, transport_security="tls" +): """Runs a gRPC client under test in a docker container""" interop_only_options = [ - '--server_host=%s' % _SERVICE_NAME, - '--server_port=%d' % _FALLBACK_SERVER_PORT + "--server_host=%s" % _SERVICE_NAME, + "--server_port=%d" % _FALLBACK_SERVER_PORT, ] + transport_security_to_args(transport_security) # Don't set the server host override in any client; # Go and Java default to no override. # We're using a DNS server so there's no need. - if language.safename == 'c++': + if language.safename == "c++": interop_only_options += ['--server_host_override=""'] # Don't set --use_test_ca; we're configuring # clients to use test CA's via alternate means. 
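Aside: a hedged usage sketch of docker_run_cmdline() as defined above, to show the shape of the argv it assembles. The image tag, docker args, and environment values here are invented for illustration; only the ordering reflects the function body shown in this diff:

    # Hypothetical inputs; only the ordering of the assembled argv matters.
    cmd = docker_run_cmdline(
        ["bins/opt/interop_client", "--server_host=server.test.google.fr"],
        image="grpc_interop_cxx:example-tag",
        docker_args=["--net=host"],
        cwd="/var/local/git/grpc",
        environ={"GRPC_DNS_RESOLVER": "ares"},
    )
    # cmd == ["docker", "run", "-i", "--rm=true",
    #         "-e", "GRPC_DNS_RESOLVER=ares",
    #         "-w", "/var/local/git/grpc",
    #         "--net=host",
    #         "grpc_interop_cxx:example-tag",
    #         "bins/opt/interop_client", "--server_host=server.test.google.fr"]
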
- interop_only_options += ['--use_test_ca=false'] + interop_only_options += ["--use_test_ca=false"] client_args = language.client_cmd(interop_only_options) - container_name = dockerjob.random_name('lb_interop_client_%s' % - language.safename) + container_name = dockerjob.random_name( + "lb_interop_client_%s" % language.safename + ) docker_cmdline = docker_run_cmdline( client_args, environ=language.global_env(), image=docker_image, cwd=language.client_cwd, docker_args=[ - '--dns=%s' % dns_server_ip, - '--net=host', - '--name=%s' % container_name, - '-v', - '{grpc_grpc_root_dir}:/external_mount:ro'.format( - grpc_grpc_root_dir=ROOT), - ]) - jobset.message('IDLE', - 'docker_cmdline:\b|%s|' % ' '.join(docker_cmdline), - do_newline=True) - test_job = jobset.JobSpec(cmdline=docker_cmdline, - shortname=('lb_interop_client:%s' % language), - timeout_seconds=_TEST_TIMEOUT, - kill_handler=_job_kill_handler) + "--dns=%s" % dns_server_ip, + "--net=host", + "--name=%s" % container_name, + "-v", + "{grpc_grpc_root_dir}:/external_mount:ro".format( + grpc_grpc_root_dir=ROOT + ), + ], + ) + jobset.message( + "IDLE", + "docker_cmdline:\b|%s|" % " ".join(docker_cmdline), + do_newline=True, + ) + test_job = jobset.JobSpec( + cmdline=docker_cmdline, + shortname="lb_interop_client:%s" % language, + timeout_seconds=_TEST_TIMEOUT, + kill_handler=_job_kill_handler, + ) test_job.container_name = container_name return test_job @@ -244,165 +254,207 @@ def lb_client_interop_jobspec(language, def fallback_server_jobspec(transport_security, shortname): """Create jobspec for running a fallback server""" cmdline = [ - 'bin/server', - '--port=%d' % _FALLBACK_SERVER_PORT, + "bin/server", + "--port=%d" % _FALLBACK_SERVER_PORT, ] + transport_security_to_args(transport_security) - return grpc_server_in_docker_jobspec(server_cmdline=cmdline, - shortname=shortname) + return grpc_server_in_docker_jobspec( + server_cmdline=cmdline, shortname=shortname + ) def backend_server_jobspec(transport_security, shortname): """Create jobspec for running a backend server""" cmdline = [ - 'bin/server', - '--port=%d' % _BACKEND_SERVER_PORT, + "bin/server", + "--port=%d" % _BACKEND_SERVER_PORT, ] + transport_security_to_args(transport_security) - return grpc_server_in_docker_jobspec(server_cmdline=cmdline, - shortname=shortname) + return grpc_server_in_docker_jobspec( + server_cmdline=cmdline, shortname=shortname + ) def grpclb_jobspec(transport_security, short_stream, backend_addrs, shortname): """Create jobspec for running a balancer server""" cmdline = [ - 'bin/fake_grpclb', - '--backend_addrs=%s' % ','.join(backend_addrs), - '--port=%d' % _BALANCER_SERVER_PORT, - '--short_stream=%s' % short_stream, - '--service_name=%s' % _SERVICE_NAME, + "bin/fake_grpclb", + "--backend_addrs=%s" % ",".join(backend_addrs), + "--port=%d" % _BALANCER_SERVER_PORT, + "--short_stream=%s" % short_stream, + "--service_name=%s" % _SERVICE_NAME, ] + transport_security_to_args(transport_security) - return grpc_server_in_docker_jobspec(server_cmdline=cmdline, - shortname=shortname) + return grpc_server_in_docker_jobspec( + server_cmdline=cmdline, shortname=shortname + ) def grpc_server_in_docker_jobspec(server_cmdline, shortname): container_name = dockerjob.random_name(shortname) environ = { - 'GRPC_GO_LOG_VERBOSITY_LEVEL': '3', - 'GRPC_GO_LOG_SEVERITY_LEVEL': 'INFO ', + "GRPC_GO_LOG_VERBOSITY_LEVEL": "3", + "GRPC_GO_LOG_SEVERITY_LEVEL": "INFO ", } docker_cmdline = docker_run_cmdline( server_cmdline, - cwd='/go', + cwd="/go", 
image=docker_images.get(_FAKE_SERVERS_SAFENAME), environ=environ, - docker_args=['--name=%s' % container_name]) - jobset.message('IDLE', - 'docker_cmdline:\b|%s|' % ' '.join(docker_cmdline), - do_newline=True) - server_job = jobset.JobSpec(cmdline=docker_cmdline, - shortname=shortname, - timeout_seconds=30 * 60) + docker_args=["--name=%s" % container_name], + ) + jobset.message( + "IDLE", + "docker_cmdline:\b|%s|" % " ".join(docker_cmdline), + do_newline=True, + ) + server_job = jobset.JobSpec( + cmdline=docker_cmdline, shortname=shortname, timeout_seconds=30 * 60 + ) server_job.container_name = container_name return server_job -def dns_server_in_docker_jobspec(grpclb_ips, fallback_ips, shortname, - cause_no_error_no_data_for_balancer_a_record): +def dns_server_in_docker_jobspec( + grpclb_ips, + fallback_ips, + shortname, + cause_no_error_no_data_for_balancer_a_record, +): container_name = dockerjob.random_name(shortname) run_dns_server_cmdline = [ - 'python', - 'test/cpp/naming/utils/run_dns_server_for_lb_interop_tests.py', - '--grpclb_ips=%s' % ','.join(grpclb_ips), - '--fallback_ips=%s' % ','.join(fallback_ips), + "python", + "test/cpp/naming/utils/run_dns_server_for_lb_interop_tests.py", + "--grpclb_ips=%s" % ",".join(grpclb_ips), + "--fallback_ips=%s" % ",".join(fallback_ips), ] if cause_no_error_no_data_for_balancer_a_record: run_dns_server_cmdline.append( - '--cause_no_error_no_data_for_balancer_a_record') + "--cause_no_error_no_data_for_balancer_a_record" + ) docker_cmdline = docker_run_cmdline( run_dns_server_cmdline, - cwd='/var/local/git/grpc', + cwd="/var/local/git/grpc", image=docker_images.get(_FAKE_SERVERS_SAFENAME), - docker_args=['--name=%s' % container_name]) - jobset.message('IDLE', - 'docker_cmdline:\b|%s|' % ' '.join(docker_cmdline), - do_newline=True) - server_job = jobset.JobSpec(cmdline=docker_cmdline, - shortname=shortname, - timeout_seconds=30 * 60) + docker_args=["--name=%s" % container_name], + ) + jobset.message( + "IDLE", + "docker_cmdline:\b|%s|" % " ".join(docker_cmdline), + do_newline=True, + ) + server_job = jobset.JobSpec( + cmdline=docker_cmdline, shortname=shortname, timeout_seconds=30 * 60 + ) server_job.container_name = container_name return server_job -def build_interop_image_jobspec(lang_safename, basename_prefix='grpc_interop'): +def build_interop_image_jobspec(lang_safename, basename_prefix="grpc_interop"): """Creates jobspec for building interop docker image for a language""" - tag = '%s_%s:%s' % (basename_prefix, lang_safename, uuid.uuid4()) + tag = "%s_%s:%s" % (basename_prefix, lang_safename, uuid.uuid4()) env = { - 'INTEROP_IMAGE': tag, - 'BASE_NAME': '%s_%s' % (basename_prefix, lang_safename), + "INTEROP_IMAGE": tag, + "BASE_NAME": "%s_%s" % (basename_prefix, lang_safename), } build_job = jobset.JobSpec( - cmdline=['tools/run_tests/dockerize/build_interop_image.sh'], + cmdline=["tools/run_tests/dockerize/build_interop_image.sh"], environ=env, - shortname='build_docker_%s' % lang_safename, - timeout_seconds=30 * 60) + shortname="build_docker_%s" % lang_safename, + timeout_seconds=30 * 60, + ) build_job.tag = tag return build_job -argp = argparse.ArgumentParser(description='Run interop tests.') -argp.add_argument('-l', - '--language', - choices=['all'] + sorted(_LANGUAGES), - nargs='+', - default=['all'], - help='Clients to run.') -argp.add_argument('-j', '--jobs', default=multiprocessing.cpu_count(), type=int) -argp.add_argument('-s', - '--scenarios_file', - default=None, - type=str, - help='File containing test scenarios as JSON configs.') 
+argp = argparse.ArgumentParser(description="Run interop tests.") +argp.add_argument( + "-l", + "--language", + choices=["all"] + sorted(_LANGUAGES), + nargs="+", + default=["all"], + help="Clients to run.", +) +argp.add_argument("-j", "--jobs", default=multiprocessing.cpu_count(), type=int) +argp.add_argument( + "-s", + "--scenarios_file", + default=None, + type=str, + help="File containing test scenarios as JSON configs.", +) +argp.add_argument( + "-n", + "--scenario_name", + default=None, + type=str, + help=( + "Useful for manual runs: specify the name of " + "the scenario to run from scenarios_file. Run all scenarios if unset." + ), +) +argp.add_argument( + "--cxx_image_tag", + default=None, + type=str, + help=( + "Setting this skips the clients docker image " + "build step and runs the client from the named " + "image. Only supports running a one client language." + ), +) +argp.add_argument( + "--go_image_tag", + default=None, + type=str, + help=( + "Setting this skips the clients docker image build " + "step and runs the client from the named image. Only " + "supports running a one client language." + ), +) argp.add_argument( - '-n', - '--scenario_name', + "--java_image_tag", default=None, type=str, help=( - 'Useful for manual runs: specify the name of ' - 'the scenario to run from scenarios_file. Run all scenarios if unset.')) -argp.add_argument('--cxx_image_tag', - default=None, - type=str, - help=('Setting this skips the clients docker image ' - 'build step and runs the client from the named ' - 'image. Only supports running a one client language.')) -argp.add_argument('--go_image_tag', - default=None, - type=str, - help=('Setting this skips the clients docker image build ' - 'step and runs the client from the named image. Only ' - 'supports running a one client language.')) -argp.add_argument('--java_image_tag', - default=None, - type=str, - help=('Setting this skips the clients docker image build ' - 'step and runs the client from the named image. Only ' - 'supports running a one client language.')) + "Setting this skips the clients docker image build " + "step and runs the client from the named image. Only " + "supports running a one client language." + ), +) argp.add_argument( - '--servers_image_tag', + "--servers_image_tag", default=None, type=str, - help=('Setting this skips the fake servers docker image ' - 'build step and runs the servers from the named image.')) -argp.add_argument('--no_skips', - default=False, - type=bool, - nargs='?', - const=True, - help=('Useful for manual runs. Setting this overrides test ' - '"skips" configured in test scenarios.')) -argp.add_argument('--verbose', - default=False, - type=bool, - nargs='?', - const=True, - help='Increase logging.') + help=( + "Setting this skips the fake servers docker image " + "build step and runs the servers from the named image." + ), +) +argp.add_argument( + "--no_skips", + default=False, + type=bool, + nargs="?", + const=True, + help=( + "Useful for manual runs. Setting this overrides test " + '"skips" configured in test scenarios.' 
+ ), +) +argp.add_argument( + "--verbose", + default=False, + type=bool, + nargs="?", + const=True, + help="Increase logging.", +) args = argp.parse_args() docker_images = {} build_jobs = [] -if len(args.language) and args.language[0] == 'all': +if len(args.language) and args.language[0] == "all": languages = list(_LANGUAGES.keys()) else: languages = args.language @@ -410,11 +462,11 @@ def build_interop_image_jobspec(lang_safename, basename_prefix='grpc_interop'): l = _LANGUAGES[lang_name] # First check if a pre-built image was supplied, and avoid # rebuilding the particular docker image if so. - if lang_name == 'c++' and args.cxx_image_tag: + if lang_name == "c++" and args.cxx_image_tag: docker_images[str(l.safename)] = args.cxx_image_tag - elif lang_name == 'go' and args.go_image_tag: + elif lang_name == "go" and args.go_image_tag: docker_images[str(l.safename)] = args.go_image_tag - elif lang_name == 'java' and args.java_image_tag: + elif lang_name == "java" and args.java_image_tag: docker_images[str(l.safename)] = args.java_image_tag else: # Build the test client in docker and save the fully @@ -429,69 +481,89 @@ def build_interop_image_jobspec(lang_safename, basename_prefix='grpc_interop'): else: # Build the test servers in docker and save the fully # built image. - job = build_interop_image_jobspec(_FAKE_SERVERS_SAFENAME, - basename_prefix='lb_interop') + job = build_interop_image_jobspec( + _FAKE_SERVERS_SAFENAME, basename_prefix="lb_interop" + ) build_jobs.append(job) docker_images[_FAKE_SERVERS_SAFENAME] = job.tag if build_jobs: - jobset.message('START', 'Building interop docker images.', do_newline=True) - print('Jobs to run: \n%s\n' % '\n'.join(str(j) for j in build_jobs)) - num_failures, _ = jobset.run(build_jobs, - newline_on_success=True, - maxjobs=args.jobs) + jobset.message("START", "Building interop docker images.", do_newline=True) + print("Jobs to run: \n%s\n" % "\n".join(str(j) for j in build_jobs)) + num_failures, _ = jobset.run( + build_jobs, newline_on_success=True, maxjobs=args.jobs + ) if num_failures == 0: - jobset.message('SUCCESS', - 'All docker images built successfully.', - do_newline=True) + jobset.message( + "SUCCESS", "All docker images built successfully.", do_newline=True + ) else: - jobset.message('FAILED', - 'Failed to build interop docker images.', - do_newline=True) + jobset.message( + "FAILED", "Failed to build interop docker images.", do_newline=True + ) sys.exit(1) def wait_until_dns_server_is_up(dns_server_ip): """Probes the DNS server until it's running and safe for tests.""" for i in range(0, 30): - print('Health check: attempt to connect to DNS server over TCP.') - tcp_connect_subprocess = subprocess.Popen([ - os.path.join(os.getcwd(), 'test/cpp/naming/utils/tcp_connect.py'), - '--server_host', dns_server_ip, '--server_port', - str(53), '--timeout', - str(1) - ]) + print("Health check: attempt to connect to DNS server over TCP.") + tcp_connect_subprocess = subprocess.Popen( + [ + os.path.join( + os.getcwd(), "test/cpp/naming/utils/tcp_connect.py" + ), + "--server_host", + dns_server_ip, + "--server_port", + str(53), + "--timeout", + str(1), + ] + ) tcp_connect_subprocess.communicate() if tcp_connect_subprocess.returncode == 0: - print(('Health check: attempt to make an A-record ' - 'query to DNS server.')) - dns_resolver_subprocess = subprocess.Popen([ - os.path.join(os.getcwd(), - 'test/cpp/naming/utils/dns_resolver.py'), - '--qname', - ('health-check-local-dns-server-is-alive.' 
- 'resolver-tests.grpctestingexp'), '--server_host', - dns_server_ip, '--server_port', - str(53) - ], - stdout=subprocess.PIPE) + print( + "Health check: attempt to make an A-record query to DNS server." + ) + dns_resolver_subprocess = subprocess.Popen( + [ + os.path.join( + os.getcwd(), "test/cpp/naming/utils/dns_resolver.py" + ), + "--qname", + ( + "health-check-local-dns-server-is-alive." + "resolver-tests.grpctestingexp" + ), + "--server_host", + dns_server_ip, + "--server_port", + str(53), + ], + stdout=subprocess.PIPE, + ) dns_resolver_stdout, _ = dns_resolver_subprocess.communicate() if dns_resolver_subprocess.returncode == 0: - if '123.123.123.123' in dns_resolver_stdout: - print(('DNS server is up! ' - 'Successfully reached it over UDP and TCP.')) + if "123.123.123.123" in dns_resolver_stdout: + print( + "DNS server is up! " + "Successfully reached it over UDP and TCP." + ) return time.sleep(0.1) - raise Exception(('Failed to reach DNS server over TCP and/or UDP. ' - 'Exitting without running tests.')) + raise Exception( + "Failed to reach DNS server over TCP and/or UDP. " + "Exitting without running tests." + ) def shortname(shortname_prefix, shortname, index): - return '%s_%s_%d' % (shortname_prefix, shortname, index) + return "%s_%s_%d" % (shortname_prefix, shortname, index) def run_one_scenario(scenario_config): - jobset.message('START', 'Run scenario: %s' % scenario_config['name']) + jobset.message("START", "Run scenario: %s" % scenario_config["name"]) server_jobs = {} server_addresses = {} suppress_server_logs = True @@ -499,42 +571,52 @@ def run_one_scenario(scenario_config): backend_addrs = [] fallback_ips = [] grpclb_ips = [] - shortname_prefix = scenario_config['name'] + shortname_prefix = scenario_config["name"] # Start backends - for i in range(len(scenario_config['backend_configs'])): - backend_config = scenario_config['backend_configs'][i] - backend_shortname = shortname(shortname_prefix, 'backend_server', i) + for i in range(len(scenario_config["backend_configs"])): + backend_config = scenario_config["backend_configs"][i] + backend_shortname = shortname(shortname_prefix, "backend_server", i) backend_spec = backend_server_jobspec( - backend_config['transport_sec'], backend_shortname) + backend_config["transport_sec"], backend_shortname + ) backend_job = dockerjob.DockerJob(backend_spec) server_jobs[backend_shortname] = backend_job backend_addrs.append( - '%s:%d' % (backend_job.ip_address(), _BACKEND_SERVER_PORT)) + "%s:%d" % (backend_job.ip_address(), _BACKEND_SERVER_PORT) + ) # Start fallbacks - for i in range(len(scenario_config['fallback_configs'])): - fallback_config = scenario_config['fallback_configs'][i] - fallback_shortname = shortname(shortname_prefix, 'fallback_server', - i) + for i in range(len(scenario_config["fallback_configs"])): + fallback_config = scenario_config["fallback_configs"][i] + fallback_shortname = shortname( + shortname_prefix, "fallback_server", i + ) fallback_spec = fallback_server_jobspec( - fallback_config['transport_sec'], fallback_shortname) + fallback_config["transport_sec"], fallback_shortname + ) fallback_job = dockerjob.DockerJob(fallback_spec) server_jobs[fallback_shortname] = fallback_job fallback_ips.append(fallback_job.ip_address()) # Start balancers - for i in range(len(scenario_config['balancer_configs'])): - balancer_config = scenario_config['balancer_configs'][i] - grpclb_shortname = shortname(shortname_prefix, 'grpclb_server', i) - grpclb_spec = grpclb_jobspec(balancer_config['transport_sec'], - 
balancer_config['short_stream'], - backend_addrs, grpclb_shortname) + for i in range(len(scenario_config["balancer_configs"])): + balancer_config = scenario_config["balancer_configs"][i] + grpclb_shortname = shortname(shortname_prefix, "grpclb_server", i) + grpclb_spec = grpclb_jobspec( + balancer_config["transport_sec"], + balancer_config["short_stream"], + backend_addrs, + grpclb_shortname, + ) grpclb_job = dockerjob.DockerJob(grpclb_spec) server_jobs[grpclb_shortname] = grpclb_job grpclb_ips.append(grpclb_job.ip_address()) # Start DNS server - dns_server_shortname = shortname(shortname_prefix, 'dns_server', 0) + dns_server_shortname = shortname(shortname_prefix, "dns_server", 0) dns_server_spec = dns_server_in_docker_jobspec( - grpclb_ips, fallback_ips, dns_server_shortname, - scenario_config['cause_no_error_no_data_for_balancer_a_record']) + grpclb_ips, + fallback_ips, + dns_server_shortname, + scenario_config["cause_no_error_no_data_for_balancer_a_record"], + ) dns_server_job = dockerjob.DockerJob(dns_server_spec) server_jobs[dns_server_shortname] = dns_server_job # Get the IP address of the docker container running the DNS server. @@ -550,35 +632,42 @@ def run_one_scenario(scenario_config): # Skip languages that are known to not currently # work for this test. if not args.no_skips and lang_name in scenario_config.get( - 'skip_langs', []): + "skip_langs", [] + ): jobset.message( - 'IDLE', 'Skipping scenario: %s for language: %s\n' % - (scenario_config['name'], lang_name)) + "IDLE", + "Skipping scenario: %s for language: %s\n" + % (scenario_config["name"], lang_name), + ) continue lang = _LANGUAGES[lang_name] test_job = lb_client_interop_jobspec( lang, dns_server_ip, docker_image=docker_images.get(lang.safename), - transport_security=scenario_config['transport_sec']) + transport_security=scenario_config["transport_sec"], + ) jobs.append(test_job) jobset.message( - 'IDLE', 'Jobs to run: \n%s\n' % '\n'.join(str(job) for job in jobs)) - num_failures, resultset = jobset.run(jobs, - newline_on_success=True, - maxjobs=args.jobs) - report_utils.render_junit_xml_report(resultset, 'sponge_log.xml') + "IDLE", "Jobs to run: \n%s\n" % "\n".join(str(job) for job in jobs) + ) + num_failures, resultset = jobset.run( + jobs, newline_on_success=True, maxjobs=args.jobs + ) + report_utils.render_junit_xml_report(resultset, "sponge_log.xml") if num_failures: suppress_server_logs = False - jobset.message('FAILED', - 'Scenario: %s. Some tests failed' % - scenario_config['name'], - do_newline=True) + jobset.message( + "FAILED", + "Scenario: %s. Some tests failed" % scenario_config["name"], + do_newline=True, + ) else: - jobset.message('SUCCESS', - 'Scenario: %s. All tests passed' % - scenario_config['name'], - do_newline=True) + jobset.message( + "SUCCESS", + "Scenario: %s. All tests passed" % scenario_config["name"], + do_newline=True, + ) return num_failures finally: # Check if servers are still running. @@ -586,18 +675,21 @@ def run_one_scenario(scenario_config): if not job.is_running(): print('Server "%s" has exited prematurely.' 
% server) suppress_failure = suppress_server_logs and not args.verbose - dockerjob.finish_jobs([j for j in six.itervalues(server_jobs)], - suppress_failure=suppress_failure) + dockerjob.finish_jobs( + [j for j in six.itervalues(server_jobs)], + suppress_failure=suppress_failure, + ) num_failures = 0 -with open(args.scenarios_file, 'r') as scenarios_input: +with open(args.scenarios_file, "r") as scenarios_input: all_scenarios = json.loads(scenarios_input.read()) for scenario in all_scenarios: if args.scenario_name: - if args.scenario_name != scenario['name']: - jobset.message('IDLE', - 'Skipping scenario: %s' % scenario['name']) + if args.scenario_name != scenario["name"]: + jobset.message( + "IDLE", "Skipping scenario: %s" % scenario["name"] + ) continue num_failures += run_one_scenario(scenario) if num_failures == 0: diff --git a/tools/run_tests/run_interop_tests.py b/tools/run_tests/run_interop_tests.py index a4642b8681338..05c4ebcebf75a 100755 --- a/tools/run_tests/run_interop_tests.py +++ b/tools/run_tests/run_interop_tests.py @@ -43,39 +43,43 @@ print(e) # Docker doesn't clean up after itself, so we do it on exit. -atexit.register(lambda: subprocess.call(['stty', 'echo'])) +atexit.register(lambda: subprocess.call(["stty", "echo"])) -ROOT = os.path.abspath(os.path.join(os.path.dirname(sys.argv[0]), '../..')) +ROOT = os.path.abspath(os.path.join(os.path.dirname(sys.argv[0]), "../..")) os.chdir(ROOT) _DEFAULT_SERVER_PORT = 8080 _SKIP_CLIENT_COMPRESSION = [ - 'client_compressed_unary', 'client_compressed_streaming' + "client_compressed_unary", + "client_compressed_streaming", ] _SKIP_SERVER_COMPRESSION = [ - 'server_compressed_unary', 'server_compressed_streaming' + "server_compressed_unary", + "server_compressed_streaming", ] _SKIP_COMPRESSION = _SKIP_CLIENT_COMPRESSION + _SKIP_SERVER_COMPRESSION _SKIP_ADVANCED = [ - 'status_code_and_message', 'custom_metadata', 'unimplemented_method', - 'unimplemented_service' + "status_code_and_message", + "custom_metadata", + "unimplemented_method", + "unimplemented_service", ] -_SKIP_SPECIAL_STATUS_MESSAGE = ['special_status_message'] +_SKIP_SPECIAL_STATUS_MESSAGE = ["special_status_message"] -_ORCA_TEST_CASES = ['orca_per_rpc', 'orca_oob'] +_ORCA_TEST_CASES = ["orca_per_rpc", "orca_oob"] -_GOOGLE_DEFAULT_CREDS_TEST_CASE = 'google_default_credentials' +_GOOGLE_DEFAULT_CREDS_TEST_CASE = "google_default_credentials" _SKIP_GOOGLE_DEFAULT_CREDS = [ _GOOGLE_DEFAULT_CREDS_TEST_CASE, ] -_COMPUTE_ENGINE_CHANNEL_CREDS_TEST_CASE = 'compute_engine_channel_credentials' +_COMPUTE_ENGINE_CHANNEL_CREDS_TEST_CASE = "compute_engine_channel_credentials" _SKIP_COMPUTE_ENGINE_CHANNEL_CREDS = [ _COMPUTE_ENGINE_CHANNEL_CREDS_TEST_CASE, @@ -85,134 +89,138 @@ # disable this test on core-based languages, # see https://github.com/grpc/grpc/issues/9779 -_SKIP_DATA_FRAME_PADDING = ['data_frame_padding'] +_SKIP_DATA_FRAME_PADDING = ["data_frame_padding"] # report suffix "sponge_log.xml" is important for reports to get picked up by internal CI -_DOCKER_BUILD_XML_REPORT = 'interop_docker_build/sponge_log.xml' -_TESTS_XML_REPORT = 'interop_test/sponge_log.xml' +_DOCKER_BUILD_XML_REPORT = "interop_docker_build/sponge_log.xml" +_TESTS_XML_REPORT = "interop_test/sponge_log.xml" class CXXLanguage: - def __init__(self): self.client_cwd = None self.server_cwd = None self.http2_cwd = None - self.safename = 'cxx' + self.safename = "cxx" def client_cmd(self, args): - return ['cmake/build/interop_client'] + args + return ["cmake/build/interop_client"] + args def client_cmd_http2interop(self, 
args): - return ['cmake/build/http2_client'] + args + return ["cmake/build/http2_client"] + args def cloud_to_prod_env(self): return {} def server_cmd(self, args): - return ['cmake/build/interop_server'] + args + return ["cmake/build/interop_server"] + args def global_env(self): return {} def unimplemented_test_cases(self): - return _SKIP_DATA_FRAME_PADDING + \ - _SKIP_SPECIAL_STATUS_MESSAGE + \ - _SKIP_COMPUTE_ENGINE_CHANNEL_CREDS + return ( + _SKIP_DATA_FRAME_PADDING + + _SKIP_SPECIAL_STATUS_MESSAGE + + _SKIP_COMPUTE_ENGINE_CHANNEL_CREDS + ) def unimplemented_test_cases_server(self): return [] def __str__(self): - return 'c++' + return "c++" class AspNetCoreLanguage: - def __init__(self): - self.client_cwd = '../grpc-dotnet/output/InteropTestsClient' - self.server_cwd = '../grpc-dotnet/output/InteropTestsWebsite' + self.client_cwd = "../grpc-dotnet/output/InteropTestsClient" + self.server_cwd = "../grpc-dotnet/output/InteropTestsWebsite" self.safename = str(self) def cloud_to_prod_env(self): return {} def client_cmd(self, args): - return ['dotnet', 'exec', 'InteropTestsClient.dll'] + args + return ["dotnet", "exec", "InteropTestsClient.dll"] + args def server_cmd(self, args): - return ['dotnet', 'exec', 'InteropTestsWebsite.dll'] + args + return ["dotnet", "exec", "InteropTestsWebsite.dll"] + args def global_env(self): return {} def unimplemented_test_cases(self): - return _SKIP_GOOGLE_DEFAULT_CREDS + \ - _SKIP_COMPUTE_ENGINE_CHANNEL_CREDS + \ - _ORCA_TEST_CASES + return ( + _SKIP_GOOGLE_DEFAULT_CREDS + + _SKIP_COMPUTE_ENGINE_CHANNEL_CREDS + + _ORCA_TEST_CASES + ) def unimplemented_test_cases_server(self): return _ORCA_TEST_CASES def __str__(self): - return 'aspnetcore' + return "aspnetcore" class DartLanguage: - def __init__(self): - self.client_cwd = '../grpc-dart/interop' - self.server_cwd = '../grpc-dart/interop' - self.http2_cwd = '../grpc-dart/interop' + self.client_cwd = "../grpc-dart/interop" + self.server_cwd = "../grpc-dart/interop" + self.http2_cwd = "../grpc-dart/interop" self.safename = str(self) def client_cmd(self, args): - return ['dart', 'bin/client.dart'] + args + return ["dart", "bin/client.dart"] + args def cloud_to_prod_env(self): return {} def server_cmd(self, args): - return ['dart', 'bin/server.dart'] + args + return ["dart", "bin/server.dart"] + args def global_env(self): return {} def unimplemented_test_cases(self): - return _SKIP_COMPRESSION + \ - _SKIP_SPECIAL_STATUS_MESSAGE + \ - _SKIP_GOOGLE_DEFAULT_CREDS + \ - _SKIP_COMPUTE_ENGINE_CHANNEL_CREDS + \ - _ORCA_TEST_CASES + return ( + _SKIP_COMPRESSION + + _SKIP_SPECIAL_STATUS_MESSAGE + + _SKIP_GOOGLE_DEFAULT_CREDS + + _SKIP_COMPUTE_ENGINE_CHANNEL_CREDS + + _ORCA_TEST_CASES + ) def unimplemented_test_cases_server(self): - return _SKIP_COMPRESSION + _SKIP_SPECIAL_STATUS_MESSAGE + _ORCA_TEST_CASES + return ( + _SKIP_COMPRESSION + _SKIP_SPECIAL_STATUS_MESSAGE + _ORCA_TEST_CASES + ) def __str__(self): - return 'dart' + return "dart" class JavaLanguage: - def __init__(self): - self.client_cwd = '../grpc-java' - self.server_cwd = '../grpc-java' - self.http2_cwd = '../grpc-java' + self.client_cwd = "../grpc-java" + self.server_cwd = "../grpc-java" + self.http2_cwd = "../grpc-java" self.safename = str(self) def client_cmd(self, args): - return ['./run-test-client.sh'] + args + return ["./run-test-client.sh"] + args def client_cmd_http2interop(self, args): return [ - './interop-testing/build/install/grpc-interop-testing/bin/http2-client' + "./interop-testing/build/install/grpc-interop-testing/bin/http2-client" ] + args 
def cloud_to_prod_env(self): return {} def server_cmd(self, args): - return ['./run-test-server.sh'] + args + return ["./run-test-server.sh"] + args def global_env(self): return {} @@ -223,20 +231,19 @@ def unimplemented_test_cases(self): def unimplemented_test_cases_server(self): # Does not support CompressedRequest feature. # Only supports CompressedResponse feature for unary. - return _SKIP_CLIENT_COMPRESSION + ['server_compressed_streaming'] + return _SKIP_CLIENT_COMPRESSION + ["server_compressed_streaming"] def __str__(self): - return 'java' + return "java" class JavaOkHttpClient: - def __init__(self): - self.client_cwd = '../grpc-java' - self.safename = 'java' + self.client_cwd = "../grpc-java" + self.safename = "java" def client_cmd(self, args): - return ['./run-test-client.sh', '--use_okhttp=true'] + args + return ["./run-test-client.sh", "--use_okhttp=true"] + args def cloud_to_prod_env(self): return {} @@ -248,32 +255,31 @@ def unimplemented_test_cases(self): return _SKIP_DATA_FRAME_PADDING def __str__(self): - return 'javaokhttp' + return "javaokhttp" class GoLanguage: - def __init__(self): # TODO: this relies on running inside docker - self.client_cwd = '/go/src/google.golang.org/grpc/interop/client' - self.server_cwd = '/go/src/google.golang.org/grpc/interop/server' - self.http2_cwd = '/go/src/google.golang.org/grpc/interop/http2' + self.client_cwd = "/go/src/google.golang.org/grpc/interop/client" + self.server_cwd = "/go/src/google.golang.org/grpc/interop/server" + self.http2_cwd = "/go/src/google.golang.org/grpc/interop/http2" self.safename = str(self) def client_cmd(self, args): - return ['go', 'run', 'client.go'] + args + return ["go", "run", "client.go"] + args def client_cmd_http2interop(self, args): - return ['go', 'run', 'negative_http2_client.go'] + args + return ["go", "run", "negative_http2_client.go"] + args def cloud_to_prod_env(self): return {} def server_cmd(self, args): - return ['go', 'run', 'server.go'] + args + return ["go", "run", "server.go"] + args def global_env(self): - return {'GO111MODULE': 'on'} + return {"GO111MODULE": "on"} def unimplemented_test_cases(self): return _SKIP_COMPRESSION @@ -282,22 +288,22 @@ def unimplemented_test_cases_server(self): return _SKIP_COMPRESSION def __str__(self): - return 'go' + return "go" class Http2Server: """Represents the HTTP/2 Interop Test server - This pretends to be a language in order to be built and run, but really it - isn't. - """ + This pretends to be a language in order to be built and run, but really it + isn't. + """ def __init__(self): self.server_cwd = None self.safename = str(self) def server_cmd(self, args): - return ['python test/http2_test/http2_test_server.py'] + return ["python test/http2_test/http2_test_server.py"] def cloud_to_prod_env(self): return {} @@ -306,32 +312,34 @@ def global_env(self): return {} def unimplemented_test_cases(self): - return _TEST_CASES + \ - _SKIP_DATA_FRAME_PADDING + \ - _SKIP_SPECIAL_STATUS_MESSAGE + \ - _SKIP_GOOGLE_DEFAULT_CREDS + \ - _SKIP_COMPUTE_ENGINE_CHANNEL_CREDS + return ( + _TEST_CASES + + _SKIP_DATA_FRAME_PADDING + + _SKIP_SPECIAL_STATUS_MESSAGE + + _SKIP_GOOGLE_DEFAULT_CREDS + + _SKIP_COMPUTE_ENGINE_CHANNEL_CREDS + ) def unimplemented_test_cases_server(self): return _TEST_CASES def __str__(self): - return 'http2' + return "http2" class Http2Client: """Represents the HTTP/2 Interop Test - This pretends to be a language in order to be built and run, but really it - isn't. 
- """ + This pretends to be a language in order to be built and run, but really it + isn't. + """ def __init__(self): self.client_cwd = None self.safename = str(self) def client_cmd(self, args): - return ['tools/http2_interop/http2_interop.test', '-test.v'] + args + return ["tools/http2_interop/http2_interop.test", "-test.v"] + args def cloud_to_prod_env(self): return {} @@ -340,30 +348,33 @@ def global_env(self): return {} def unimplemented_test_cases(self): - return _TEST_CASES + \ - _SKIP_SPECIAL_STATUS_MESSAGE + \ - _SKIP_GOOGLE_DEFAULT_CREDS + \ - _SKIP_COMPUTE_ENGINE_CHANNEL_CREDS + return ( + _TEST_CASES + + _SKIP_SPECIAL_STATUS_MESSAGE + + _SKIP_GOOGLE_DEFAULT_CREDS + + _SKIP_COMPUTE_ENGINE_CHANNEL_CREDS + ) def unimplemented_test_cases_server(self): return _TEST_CASES def __str__(self): - return 'http2' + return "http2" class NodeLanguage: - def __init__(self): - self.client_cwd = '../../../../home/appuser/grpc-node' - self.server_cwd = '../../../../home/appuser/grpc-node' + self.client_cwd = "../../../../home/appuser/grpc-node" + self.server_cwd = "../../../../home/appuser/grpc-node" self.safename = str(self) def client_cmd(self, args): return [ - 'packages/grpc-native-core/deps/grpc/tools/run_tests/interop/with_nvm.sh', - 'node', '--require', './test/fixtures/native_native', - 'test/interop/interop_client.js' + "packages/grpc-native-core/deps/grpc/tools/run_tests/interop/with_nvm.sh", + "node", + "--require", + "./test/fixtures/native_native", + "test/interop/interop_client.js", ] + args def cloud_to_prod_env(self): @@ -371,41 +382,45 @@ def cloud_to_prod_env(self): def server_cmd(self, args): return [ - 'packages/grpc-native-core/deps/grpc/tools/run_tests/interop/with_nvm.sh', - 'node', '--require', './test/fixtures/native_native', - 'test/interop/interop_server.js' + "packages/grpc-native-core/deps/grpc/tools/run_tests/interop/with_nvm.sh", + "node", + "--require", + "./test/fixtures/native_native", + "test/interop/interop_server.js", ] + args def global_env(self): return {} def unimplemented_test_cases(self): - return _SKIP_COMPRESSION + \ - _SKIP_DATA_FRAME_PADDING + \ - _SKIP_GOOGLE_DEFAULT_CREDS + \ - _SKIP_COMPUTE_ENGINE_CHANNEL_CREDS + \ - _ORCA_TEST_CASES + return ( + _SKIP_COMPRESSION + + _SKIP_DATA_FRAME_PADDING + + _SKIP_GOOGLE_DEFAULT_CREDS + + _SKIP_COMPUTE_ENGINE_CHANNEL_CREDS + + _ORCA_TEST_CASES + ) def unimplemented_test_cases_server(self): - return _SKIP_COMPRESSION + \ - _ORCA_TEST_CASES + return _SKIP_COMPRESSION + _ORCA_TEST_CASES def __str__(self): - return 'node' + return "node" class NodePureJSLanguage: - def __init__(self): - self.client_cwd = '../../../../home/appuser/grpc-node' - self.server_cwd = '../../../../home/appuser/grpc-node' + self.client_cwd = "../../../../home/appuser/grpc-node" + self.server_cwd = "../../../../home/appuser/grpc-node" self.safename = str(self) def client_cmd(self, args): return [ - 'packages/grpc-native-core/deps/grpc/tools/run_tests/interop/with_nvm.sh', - 'node', '--require', './test/fixtures/js_js', - 'test/interop/interop_client.js' + "packages/grpc-native-core/deps/grpc/tools/run_tests/interop/with_nvm.sh", + "node", + "--require", + "./test/fixtures/js_js", + "test/interop/interop_client.js", ] + args def cloud_to_prod_env(self): @@ -415,66 +430,71 @@ def global_env(self): return {} def unimplemented_test_cases(self): - return _SKIP_COMPRESSION + \ - _SKIP_DATA_FRAME_PADDING + \ - _SKIP_GOOGLE_DEFAULT_CREDS + \ - _SKIP_COMPUTE_ENGINE_CHANNEL_CREDS + \ - _ORCA_TEST_CASES + return ( + _SKIP_COMPRESSION + + 
_SKIP_DATA_FRAME_PADDING + + _SKIP_GOOGLE_DEFAULT_CREDS + + _SKIP_COMPUTE_ENGINE_CHANNEL_CREDS + + _ORCA_TEST_CASES + ) def unimplemented_test_cases_server(self): return _ORCA_TEST_CASES def __str__(self): - return 'nodepurejs' + return "nodepurejs" class PHP7Language: - def __init__(self): self.client_cwd = None self.server_cwd = None self.safename = str(self) def client_cmd(self, args): - return ['src/php/bin/interop_client.sh'] + args + return ["src/php/bin/interop_client.sh"] + args def cloud_to_prod_env(self): return {} def server_cmd(self, args): - return ['src/php/bin/interop_server.sh'] + args + return ["src/php/bin/interop_server.sh"] + args def global_env(self): return {} def unimplemented_test_cases(self): - return _SKIP_SERVER_COMPRESSION + \ - _SKIP_DATA_FRAME_PADDING + \ - _SKIP_GOOGLE_DEFAULT_CREDS + \ - _SKIP_COMPUTE_ENGINE_CHANNEL_CREDS + \ - _ORCA_TEST_CASES + return ( + _SKIP_SERVER_COMPRESSION + + _SKIP_DATA_FRAME_PADDING + + _SKIP_GOOGLE_DEFAULT_CREDS + + _SKIP_COMPUTE_ENGINE_CHANNEL_CREDS + + _ORCA_TEST_CASES + ) def unimplemented_test_cases_server(self): - return _SKIP_COMPRESSION + \ - _ORCA_TEST_CASES + return _SKIP_COMPRESSION + _ORCA_TEST_CASES def __str__(self): - return 'php7' + return "php7" class ObjcLanguage: - def __init__(self): - self.client_cwd = 'src/objective-c/tests' + self.client_cwd = "src/objective-c/tests" self.safename = str(self) def client_cmd(self, args): # from args, extract the server port and craft xcodebuild command out of it for arg in args: - port = re.search('--server_port=(\d+)', arg) + port = re.search("--server_port=(\d+)", arg) if port: portnum = port.group(1) - cmdline = 'pod install && xcodebuild -workspace Tests.xcworkspace -scheme InteropTestsLocalSSL -destination name="iPhone 6" HOST_PORT_LOCALSSL=localhost:%s test' % portnum + cmdline = ( + "pod install && xcodebuild -workspace Tests.xcworkspace" + ' -scheme InteropTestsLocalSSL -destination name="iPhone 6"' + " HOST_PORT_LOCALSSL=localhost:%s test" % portnum + ) return [cmdline] def cloud_to_prod_env(self): @@ -488,24 +508,24 @@ def unimplemented_test_cases(self): # cmdline argument. Here we return all but one test cases as unimplemented, # and depend upon ObjC test's behavior that it runs all cases even when # we tell it to run just one. 
- return _TEST_CASES[1:] + \ - _SKIP_COMPRESSION + \ - _SKIP_DATA_FRAME_PADDING + \ - _SKIP_SPECIAL_STATUS_MESSAGE + \ - _SKIP_GOOGLE_DEFAULT_CREDS + \ - _SKIP_COMPUTE_ENGINE_CHANNEL_CREDS + \ - _ORCA_TEST_CASES + return ( + _TEST_CASES[1:] + + _SKIP_COMPRESSION + + _SKIP_DATA_FRAME_PADDING + + _SKIP_SPECIAL_STATUS_MESSAGE + + _SKIP_GOOGLE_DEFAULT_CREDS + + _SKIP_COMPUTE_ENGINE_CHANNEL_CREDS + + _ORCA_TEST_CASES + ) def unimplemented_test_cases_server(self): - return _SKIP_COMPRESSION + \ - _ORCA_TEST_CASES + return _SKIP_COMPRESSION + _ORCA_TEST_CASES def __str__(self): - return 'objc' + return "objc" class RubyLanguage: - def __init__(self): self.client_cwd = None self.server_cwd = None @@ -513,8 +533,9 @@ def __init__(self): def client_cmd(self, args): return [ - 'tools/run_tests/interop/with_rvm.sh', 'ruby', - 'src/ruby/pb/test/client.rb' + "tools/run_tests/interop/with_rvm.sh", + "ruby", + "src/ruby/pb/test/client.rb", ] + args def cloud_to_prod_env(self): @@ -522,34 +543,35 @@ def cloud_to_prod_env(self): def server_cmd(self, args): return [ - 'tools/run_tests/interop/with_rvm.sh', 'ruby', - 'src/ruby/pb/test/server.rb' + "tools/run_tests/interop/with_rvm.sh", + "ruby", + "src/ruby/pb/test/server.rb", ] + args def global_env(self): return {} def unimplemented_test_cases(self): - return _SKIP_SERVER_COMPRESSION + \ - _SKIP_DATA_FRAME_PADDING + \ - _SKIP_SPECIAL_STATUS_MESSAGE + \ - _SKIP_GOOGLE_DEFAULT_CREDS + \ - _SKIP_COMPUTE_ENGINE_CHANNEL_CREDS + \ - _ORCA_TEST_CASES + return ( + _SKIP_SERVER_COMPRESSION + + _SKIP_DATA_FRAME_PADDING + + _SKIP_SPECIAL_STATUS_MESSAGE + + _SKIP_GOOGLE_DEFAULT_CREDS + + _SKIP_COMPUTE_ENGINE_CHANNEL_CREDS + + _ORCA_TEST_CASES + ) def unimplemented_test_cases_server(self): - return _SKIP_COMPRESSION + \ - _ORCA_TEST_CASES + return _SKIP_COMPRESSION + _ORCA_TEST_CASES def __str__(self): - return 'ruby' + return "ruby" -_PYTHON_BINARY = 'py39/bin/python' +_PYTHON_BINARY = "py39/bin/python" class PythonLanguage: - def __init__(self): self.client_cwd = None self.server_cwd = None @@ -558,14 +580,17 @@ def __init__(self): def client_cmd(self, args): return [ - _PYTHON_BINARY, 'src/python/grpcio_tests/setup.py', 'run_interop', - '--client', '--args="{}"'.format(' '.join(args)) + _PYTHON_BINARY, + "src/python/grpcio_tests/setup.py", + "run_interop", + "--client", + '--args="{}"'.format(" ".join(args)), ] def client_cmd_http2interop(self, args): return [ _PYTHON_BINARY, - 'src/python/grpcio_tests/tests/http2/negative_http2_client.py', + "src/python/grpcio_tests/tests/http2/negative_http2_client.py", ] + args def cloud_to_prod_env(self): @@ -573,33 +598,36 @@ def cloud_to_prod_env(self): def server_cmd(self, args): return [ - _PYTHON_BINARY, 'src/python/grpcio_tests/setup.py', 'run_interop', - '--server', '--args="{}"'.format(' '.join(args)) + _PYTHON_BINARY, + "src/python/grpcio_tests/setup.py", + "run_interop", + "--server", + '--args="{}"'.format(" ".join(args)), ] def global_env(self): return { - 'LD_LIBRARY_PATH': '{}/libs/opt'.format(DOCKER_WORKDIR_ROOT), - 'PYTHONPATH': '{}/src/python/gens'.format(DOCKER_WORKDIR_ROOT) + "LD_LIBRARY_PATH": "{}/libs/opt".format(DOCKER_WORKDIR_ROOT), + "PYTHONPATH": "{}/src/python/gens".format(DOCKER_WORKDIR_ROOT), } def unimplemented_test_cases(self): - return _SKIP_COMPRESSION + \ - _SKIP_DATA_FRAME_PADDING + \ - _SKIP_GOOGLE_DEFAULT_CREDS + \ - _SKIP_COMPUTE_ENGINE_CHANNEL_CREDS + \ - _ORCA_TEST_CASES + return ( + _SKIP_COMPRESSION + + _SKIP_DATA_FRAME_PADDING + + _SKIP_GOOGLE_DEFAULT_CREDS + + 
_SKIP_COMPUTE_ENGINE_CHANNEL_CREDS + + _ORCA_TEST_CASES + ) def unimplemented_test_cases_server(self): - return _SKIP_COMPRESSION + \ - _ORCA_TEST_CASES + return _SKIP_COMPRESSION + _ORCA_TEST_CASES def __str__(self): - return 'python' + return "python" class PythonAsyncIOLanguage: - def __init__(self): self.client_cwd = None self.server_cwd = None @@ -608,14 +636,18 @@ def __init__(self): def client_cmd(self, args): return [ - _PYTHON_BINARY, 'src/python/grpcio_tests/setup.py', 'run_interop', - '--use-asyncio', '--client', '--args="{}"'.format(' '.join(args)) + _PYTHON_BINARY, + "src/python/grpcio_tests/setup.py", + "run_interop", + "--use-asyncio", + "--client", + '--args="{}"'.format(" ".join(args)), ] def client_cmd_http2interop(self, args): return [ _PYTHON_BINARY, - 'src/python/grpcio_tests/tests/http2/negative_http2_client.py', + "src/python/grpcio_tests/tests/http2/negative_http2_client.py", ] + args def cloud_to_prod_env(self): @@ -623,107 +655,147 @@ def cloud_to_prod_env(self): def server_cmd(self, args): return [ - _PYTHON_BINARY, 'src/python/grpcio_tests/setup.py', 'run_interop', - '--use-asyncio', '--server', '--args="{}"'.format(' '.join(args)) + _PYTHON_BINARY, + "src/python/grpcio_tests/setup.py", + "run_interop", + "--use-asyncio", + "--server", + '--args="{}"'.format(" ".join(args)), ] def global_env(self): return { - 'LD_LIBRARY_PATH': '{}/libs/opt'.format(DOCKER_WORKDIR_ROOT), - 'PYTHONPATH': '{}/src/python/gens'.format(DOCKER_WORKDIR_ROOT) + "LD_LIBRARY_PATH": "{}/libs/opt".format(DOCKER_WORKDIR_ROOT), + "PYTHONPATH": "{}/src/python/gens".format(DOCKER_WORKDIR_ROOT), } def unimplemented_test_cases(self): # TODO(https://github.com/grpc/grpc/issues/21707) - return _SKIP_COMPRESSION + \ - _SKIP_DATA_FRAME_PADDING + \ - _AUTH_TEST_CASES + \ - ['timeout_on_sleeping_server'] + \ - _ORCA_TEST_CASES + return ( + _SKIP_COMPRESSION + + _SKIP_DATA_FRAME_PADDING + + _AUTH_TEST_CASES + + ["timeout_on_sleeping_server"] + + _ORCA_TEST_CASES + ) def unimplemented_test_cases_server(self): # TODO(https://github.com/grpc/grpc/issues/21749) - return _TEST_CASES + \ - _AUTH_TEST_CASES + \ - _HTTP2_TEST_CASES + \ - _HTTP2_SERVER_TEST_CASES + return ( + _TEST_CASES + + _AUTH_TEST_CASES + + _HTTP2_TEST_CASES + + _HTTP2_SERVER_TEST_CASES + ) def __str__(self): - return 'pythonasyncio' + return "pythonasyncio" _LANGUAGES = { - 'c++': CXXLanguage(), - 'aspnetcore': AspNetCoreLanguage(), - 'dart': DartLanguage(), - 'go': GoLanguage(), - 'java': JavaLanguage(), - 'javaokhttp': JavaOkHttpClient(), - 'node': NodeLanguage(), - 'nodepurejs': NodePureJSLanguage(), - 'php7': PHP7Language(), - 'objc': ObjcLanguage(), - 'ruby': RubyLanguage(), - 'python': PythonLanguage(), - 'pythonasyncio': PythonAsyncIOLanguage(), + "c++": CXXLanguage(), + "aspnetcore": AspNetCoreLanguage(), + "dart": DartLanguage(), + "go": GoLanguage(), + "java": JavaLanguage(), + "javaokhttp": JavaOkHttpClient(), + "node": NodeLanguage(), + "nodepurejs": NodePureJSLanguage(), + "php7": PHP7Language(), + "objc": ObjcLanguage(), + "ruby": RubyLanguage(), + "python": PythonLanguage(), + "pythonasyncio": PythonAsyncIOLanguage(), } # languages supported as cloud_to_cloud servers _SERVERS = [ - 'c++', 'node', 'aspnetcore', 'java', 'go', 'ruby', 'python', 'dart', - 'pythonasyncio', 'php7' + "c++", + "node", + "aspnetcore", + "java", + "go", + "ruby", + "python", + "dart", + "pythonasyncio", + "php7", ] _TEST_CASES = [ - 'large_unary', 'empty_unary', 'ping_pong', 'empty_stream', - 'client_streaming', 'server_streaming', 
'cancel_after_begin', - 'cancel_after_first_response', 'timeout_on_sleeping_server', - 'custom_metadata', 'status_code_and_message', 'unimplemented_method', - 'client_compressed_unary', 'server_compressed_unary', - 'client_compressed_streaming', 'server_compressed_streaming', - 'unimplemented_service', 'special_status_message', 'orca_per_rpc', - 'orca_oob' + "large_unary", + "empty_unary", + "ping_pong", + "empty_stream", + "client_streaming", + "server_streaming", + "cancel_after_begin", + "cancel_after_first_response", + "timeout_on_sleeping_server", + "custom_metadata", + "status_code_and_message", + "unimplemented_method", + "client_compressed_unary", + "server_compressed_unary", + "client_compressed_streaming", + "server_compressed_streaming", + "unimplemented_service", + "special_status_message", + "orca_per_rpc", + "orca_oob", ] _AUTH_TEST_CASES = [ - 'compute_engine_creds', - 'jwt_token_creds', - 'oauth2_auth_token', - 'per_rpc_creds', + "compute_engine_creds", + "jwt_token_creds", + "oauth2_auth_token", + "per_rpc_creds", _GOOGLE_DEFAULT_CREDS_TEST_CASE, _COMPUTE_ENGINE_CHANNEL_CREDS_TEST_CASE, ] -_HTTP2_TEST_CASES = ['tls', 'framing'] +_HTTP2_TEST_CASES = ["tls", "framing"] _HTTP2_SERVER_TEST_CASES = [ - 'rst_after_header', 'rst_after_data', 'rst_during_data', 'goaway', 'ping', - 'max_streams', 'data_frame_padding', 'no_df_padding_sanity_test' + "rst_after_header", + "rst_after_data", + "rst_during_data", + "goaway", + "ping", + "max_streams", + "data_frame_padding", + "no_df_padding_sanity_test", ] _GRPC_CLIENT_TEST_CASES_FOR_HTTP2_SERVER_TEST_CASES = { - 'data_frame_padding': 'large_unary', - 'no_df_padding_sanity_test': 'large_unary' + "data_frame_padding": "large_unary", + "no_df_padding_sanity_test": "large_unary", } _HTTP2_SERVER_TEST_CASES_THAT_USE_GRPC_CLIENTS = list( - _GRPC_CLIENT_TEST_CASES_FOR_HTTP2_SERVER_TEST_CASES.keys()) + _GRPC_CLIENT_TEST_CASES_FOR_HTTP2_SERVER_TEST_CASES.keys() +) _LANGUAGES_WITH_HTTP2_CLIENTS_FOR_HTTP2_SERVER_TEST_CASES = [ - 'java', 'go', 'python', 'c++' + "java", + "go", + "python", + "c++", ] -_LANGUAGES_FOR_ALTS_TEST_CASES = ['java', 'go', 'c++', 'python'] +_LANGUAGES_FOR_ALTS_TEST_CASES = ["java", "go", "c++", "python"] -_SERVERS_FOR_ALTS_TEST_CASES = ['java', 'go', 'c++', 'python'] +_SERVERS_FOR_ALTS_TEST_CASES = ["java", "go", "c++", "python"] -_TRANSPORT_SECURITY_OPTIONS = ['tls', 'alts', 'insecure'] +_TRANSPORT_SECURITY_OPTIONS = ["tls", "alts", "insecure"] _CUSTOM_CREDENTIALS_TYPE_OPTIONS = [ - 'tls', 'google_default_credentials', 'compute_engine_channel_creds' + "tls", + "google_default_credentials", + "compute_engine_channel_creds", ] -DOCKER_WORKDIR_ROOT = '/var/local/git/grpc' +DOCKER_WORKDIR_ROOT = "/var/local/git/grpc" def docker_run_cmdline(cmdline, image, docker_args=[], cwd=None, environ=None): @@ -731,18 +803,18 @@ def docker_run_cmdline(cmdline, image, docker_args=[], cwd=None, environ=None): # don't use '-t' even when TTY is available, since that would break # the testcases generated by tools/interop_matrix/create_testcases.sh - docker_cmdline = ['docker', 'run', '-i', '--rm=true'] + docker_cmdline = ["docker", "run", "-i", "--rm=true"] # turn environ into -e docker args if environ: for k, v in list(environ.items()): - docker_cmdline += ['-e', '%s=%s' % (k, v)] + docker_cmdline += ["-e", "%s=%s" % (k, v)] # set working directory workdir = DOCKER_WORKDIR_ROOT if cwd: workdir = os.path.join(workdir, cwd) - docker_cmdline += ['-w', workdir] + docker_cmdline += ["-w", workdir] docker_cmdline += docker_args + [image] + cmdline 
return docker_cmdline @@ -752,82 +824,95 @@ def manual_cmdline(docker_cmdline, docker_image): """Returns docker cmdline adjusted for manual invocation.""" print_cmdline = [] for item in docker_cmdline: - if item.startswith('--name='): + if item.startswith("--name="): continue if item == docker_image: item = "$docker_image" item = item.replace('"', '\\"') # add quotes when necessary if any(character.isspace() for character in item): - item = "\"%s\"" % item + item = '"%s"' % item print_cmdline.append(item) - return ' '.join(print_cmdline) + return " ".join(print_cmdline) def write_cmdlog_maybe(cmdlog, filename): """Returns docker cmdline adjusted for manual invocation.""" if cmdlog: - with open(filename, 'w') as logfile: - logfile.write('#!/bin/bash\n') - logfile.write('# DO NOT MODIFY\n') + with open(filename, "w") as logfile: + logfile.write("#!/bin/bash\n") + logfile.write("# DO NOT MODIFY\n") logfile.write( - '# This file is generated by run_interop_tests.py/create_testcases.sh\n' + "# This file is generated by" + " run_interop_tests.py/create_testcases.sh\n" ) logfile.writelines("%s\n" % line for line in cmdlog) - print('Command log written to file %s' % filename) + print("Command log written to file %s" % filename) def bash_cmdline(cmdline): """Creates bash -c cmdline from args list.""" # Use login shell: # * makes error messages clearer if executables are missing - return ['bash', '-c', ' '.join(cmdline)] + return ["bash", "-c", " ".join(cmdline)] def compute_engine_creds_required(language, test_case): """Returns True if given test requires access to compute engine creds.""" language = str(language) - if test_case == 'compute_engine_creds': + if test_case == "compute_engine_creds": return True - if test_case == 'oauth2_auth_token' and language == 'c++': + if test_case == "oauth2_auth_token" and language == "c++": # C++ oauth2 test uses GCE creds because C++ only supports JWT return True return False -def auth_options(language, test_case, google_default_creds_use_key_file, - service_account_key_file, default_service_account): +def auth_options( + language, + test_case, + google_default_creds_use_key_file, + service_account_key_file, + default_service_account, +): """Returns (cmdline, env) tuple with cloud_to_prod_auth test options.""" language = str(language) cmdargs = [] env = {} - oauth_scope_arg = '--oauth_scope=https://www.googleapis.com/auth/xapi.zoo' - key_file_arg = '--service_account_key_file=%s' % service_account_key_file - default_account_arg = '--default_service_account=%s' % default_service_account + oauth_scope_arg = "--oauth_scope=https://www.googleapis.com/auth/xapi.zoo" + key_file_arg = "--service_account_key_file=%s" % service_account_key_file + default_account_arg = ( + "--default_service_account=%s" % default_service_account + ) - if test_case in ['jwt_token_creds', 'per_rpc_creds', 'oauth2_auth_token']: + if test_case in ["jwt_token_creds", "per_rpc_creds", "oauth2_auth_token"]: if language in [ - 'aspnetcore', 'node', 'php7', 'python', 'ruby', 'nodepurejs' + "aspnetcore", + "node", + "php7", + "python", + "ruby", + "nodepurejs", ]: - env['GOOGLE_APPLICATION_CREDENTIALS'] = service_account_key_file + env["GOOGLE_APPLICATION_CREDENTIALS"] = service_account_key_file else: cmdargs += [key_file_arg] - if test_case in ['per_rpc_creds', 'oauth2_auth_token']: + if test_case in ["per_rpc_creds", "oauth2_auth_token"]: cmdargs += [oauth_scope_arg] - if test_case == 'oauth2_auth_token' and language == 'c++': + if test_case == "oauth2_auth_token" and language == "c++": # C++ 
oauth2 test uses GCE creds and thus needs to know the default account cmdargs += [default_account_arg] - if test_case == 'compute_engine_creds': + if test_case == "compute_engine_creds": cmdargs += [oauth_scope_arg, default_account_arg] if test_case == _GOOGLE_DEFAULT_CREDS_TEST_CASE: if google_default_creds_use_key_file: - env['GOOGLE_APPLICATION_CREDENTIALS'] = service_account_key_file + env["GOOGLE_APPLICATION_CREDENTIALS"] = service_account_key_file cmdargs += [default_account_arg] if test_case == _COMPUTE_ENGINE_CHANNEL_CREDS_TEST_CASE: @@ -846,117 +931,142 @@ def _job_kill_handler(job): time.sleep(2) -def cloud_to_prod_jobspec(language, - test_case, - server_host_nickname, - server_host, - google_default_creds_use_key_file, - docker_image=None, - auth=False, - manual_cmd_log=None, - service_account_key_file=None, - default_service_account=None, - transport_security='tls'): +def cloud_to_prod_jobspec( + language, + test_case, + server_host_nickname, + server_host, + google_default_creds_use_key_file, + docker_image=None, + auth=False, + manual_cmd_log=None, + service_account_key_file=None, + default_service_account=None, + transport_security="tls", +): """Creates jobspec for cloud-to-prod interop test""" container_name = None cmdargs = [ - '--server_host=%s' % server_host, '--server_port=443', - '--test_case=%s' % test_case + "--server_host=%s" % server_host, + "--server_port=443", + "--test_case=%s" % test_case, ] - if transport_security == 'tls': - transport_security_options = ['--use_tls=true'] - elif transport_security == 'google_default_credentials' and str( - language) in ['c++', 'go', 'java', 'javaokhttp']: + if transport_security == "tls": + transport_security_options = ["--use_tls=true"] + elif transport_security == "google_default_credentials" and str( + language + ) in ["c++", "go", "java", "javaokhttp"]: transport_security_options = [ - '--custom_credentials_type=google_default_credentials' + "--custom_credentials_type=google_default_credentials" ] - elif transport_security == 'compute_engine_channel_creds' and str( - language) in ['go', 'java', 'javaokhttp']: + elif transport_security == "compute_engine_channel_creds" and str( + language + ) in ["go", "java", "javaokhttp"]: transport_security_options = [ - '--custom_credentials_type=compute_engine_channel_creds' + "--custom_credentials_type=compute_engine_channel_creds" ] else: print( - 'Invalid transport security option %s in cloud_to_prod_jobspec. Lang: %s' - % (str(language), transport_security)) + "Invalid transport security option %s in cloud_to_prod_jobspec." 
+ " Lang: %s" % (str(language), transport_security) + ) sys.exit(1) cmdargs = cmdargs + transport_security_options environ = dict(language.cloud_to_prod_env(), **language.global_env()) if auth: auth_cmdargs, auth_env = auth_options( - language, test_case, google_default_creds_use_key_file, - service_account_key_file, default_service_account) + language, + test_case, + google_default_creds_use_key_file, + service_account_key_file, + default_service_account, + ) cmdargs += auth_cmdargs environ.update(auth_env) cmdline = bash_cmdline(language.client_cmd(cmdargs)) cwd = language.client_cwd if docker_image: - container_name = dockerjob.random_name('interop_client_%s' % - language.safename) + container_name = dockerjob.random_name( + "interop_client_%s" % language.safename + ) cmdline = docker_run_cmdline( cmdline, image=docker_image, cwd=cwd, environ=environ, - docker_args=['--net=host', - '--name=%s' % container_name]) + docker_args=["--net=host", "--name=%s" % container_name], + ) if manual_cmd_log is not None: if manual_cmd_log == []: - manual_cmd_log.append('echo "Testing ${docker_image:=%s}"' % - docker_image) + manual_cmd_log.append( + 'echo "Testing ${docker_image:=%s}"' % docker_image + ) manual_cmd_log.append(manual_cmdline(cmdline, docker_image)) cwd = None environ = None - suite_name = 'cloud_to_prod_auth' if auth else 'cloud_to_prod' - test_job = jobset.JobSpec(cmdline=cmdline, - cwd=cwd, - environ=environ, - shortname='%s:%s:%s:%s:%s' % - (suite_name, language, server_host_nickname, - test_case, transport_security), - timeout_seconds=_TEST_TIMEOUT, - flake_retries=4 if args.allow_flakes else 0, - timeout_retries=2 if args.allow_flakes else 0, - kill_handler=_job_kill_handler) + suite_name = "cloud_to_prod_auth" if auth else "cloud_to_prod" + test_job = jobset.JobSpec( + cmdline=cmdline, + cwd=cwd, + environ=environ, + shortname="%s:%s:%s:%s:%s" + % ( + suite_name, + language, + server_host_nickname, + test_case, + transport_security, + ), + timeout_seconds=_TEST_TIMEOUT, + flake_retries=4 if args.allow_flakes else 0, + timeout_retries=2 if args.allow_flakes else 0, + kill_handler=_job_kill_handler, + ) if docker_image: test_job.container_name = container_name return test_job -def cloud_to_cloud_jobspec(language, - test_case, - server_name, - server_host, - server_port, - docker_image=None, - transport_security='tls', - manual_cmd_log=None): +def cloud_to_cloud_jobspec( + language, + test_case, + server_name, + server_host, + server_port, + docker_image=None, + transport_security="tls", + manual_cmd_log=None, +): """Creates jobspec for cloud-to-cloud interop test""" interop_only_options = [ - '--server_host_override=foo.test.google.fr', - '--use_test_ca=true', + "--server_host_override=foo.test.google.fr", + "--use_test_ca=true", ] - if transport_security == 'tls': - interop_only_options += ['--use_tls=true'] - elif transport_security == 'alts': - interop_only_options += ['--use_tls=false', '--use_alts=true'] - elif transport_security == 'insecure': - interop_only_options += ['--use_tls=false'] + if transport_security == "tls": + interop_only_options += ["--use_tls=true"] + elif transport_security == "alts": + interop_only_options += ["--use_tls=false", "--use_alts=true"] + elif transport_security == "insecure": + interop_only_options += ["--use_tls=false"] else: print( - 'Invalid transport security option %s in cloud_to_cloud_jobspec.' % - transport_security) + "Invalid transport security option %s in cloud_to_cloud_jobspec." 
+ % transport_security + ) sys.exit(1) client_test_case = test_case if test_case in _HTTP2_SERVER_TEST_CASES_THAT_USE_GRPC_CLIENTS: client_test_case = _GRPC_CLIENT_TEST_CASES_FOR_HTTP2_SERVER_TEST_CASES[ - test_case] + test_case + ] if client_test_case in language.unimplemented_test_cases(): - print('asking client %s to run unimplemented test case %s' % - (repr(language), client_test_case)) + print( + "asking client %s to run unimplemented test case %s" + % (repr(language), client_test_case) + ) sys.exit(1) if test_case in _ORCA_TEST_CASES: @@ -965,9 +1075,9 @@ def cloud_to_cloud_jobspec(language, ] common_options = [ - '--test_case=%s' % client_test_case, - '--server_host=%s' % server_host, - '--server_port=%s' % server_port, + "--test_case=%s" % client_test_case, + "--server_host=%s" % server_host, + "--server_port=%s" % server_port, ] if test_case in _HTTP2_SERVER_TEST_CASES: @@ -977,29 +1087,33 @@ def cloud_to_cloud_jobspec(language, cwd = language.client_cwd else: cmdline = bash_cmdline( - language.client_cmd_http2interop(common_options)) + language.client_cmd_http2interop(common_options) + ) cwd = language.http2_cwd else: cmdline = bash_cmdline( - language.client_cmd(common_options + interop_only_options)) + language.client_cmd(common_options + interop_only_options) + ) cwd = language.client_cwd environ = language.global_env() - if docker_image and language.safename != 'objc': + if docker_image and language.safename != "objc": # we can't run client in docker for objc. - container_name = dockerjob.random_name('interop_client_%s' % - language.safename) + container_name = dockerjob.random_name( + "interop_client_%s" % language.safename + ) cmdline = docker_run_cmdline( cmdline, image=docker_image, environ=environ, cwd=cwd, - docker_args=['--net=host', - '--name=%s' % container_name]) + docker_args=["--net=host", "--name=%s" % container_name], + ) if manual_cmd_log is not None: if manual_cmd_log == []: - manual_cmd_log.append('echo "Testing ${docker_image:=%s}"' % - docker_image) + manual_cmd_log.append( + 'echo "Testing ${docker_image:=%s}"' % docker_image + ) manual_cmd_log.append(manual_cmdline(cmdline, docker_image)) cwd = None @@ -1007,46 +1121,51 @@ def cloud_to_cloud_jobspec(language, cmdline=cmdline, cwd=cwd, environ=environ, - shortname='cloud_to_cloud:%s:%s_server:%s:%s' % - (language, server_name, test_case, transport_security), + shortname="cloud_to_cloud:%s:%s_server:%s:%s" + % (language, server_name, test_case, transport_security), timeout_seconds=_TEST_TIMEOUT, flake_retries=4 if args.allow_flakes else 0, timeout_retries=2 if args.allow_flakes else 0, - kill_handler=_job_kill_handler) + kill_handler=_job_kill_handler, + ) if docker_image: test_job.container_name = container_name return test_job -def server_jobspec(language, - docker_image, - transport_security='tls', - manual_cmd_log=None): +def server_jobspec( + language, docker_image, transport_security="tls", manual_cmd_log=None +): """Create jobspec for running a server""" - container_name = dockerjob.random_name('interop_server_%s' % - language.safename) - server_cmd = ['--port=%s' % _DEFAULT_SERVER_PORT] - if transport_security == 'tls': - server_cmd += ['--use_tls=true'] - elif transport_security == 'alts': - server_cmd += ['--use_tls=false', '--use_alts=true'] - elif transport_security == 'insecure': - server_cmd += ['--use_tls=false'] + container_name = dockerjob.random_name( + "interop_server_%s" % language.safename + ) + server_cmd = ["--port=%s" % _DEFAULT_SERVER_PORT] + if transport_security == "tls": + 
server_cmd += ["--use_tls=true"] + elif transport_security == "alts": + server_cmd += ["--use_tls=false", "--use_alts=true"] + elif transport_security == "insecure": + server_cmd += ["--use_tls=false"] else: - print('Invalid transport security option %s in server_jobspec.' % - transport_security) + print( + "Invalid transport security option %s in server_jobspec." + % transport_security + ) sys.exit(1) cmdline = bash_cmdline(language.server_cmd(server_cmd)) environ = language.global_env() - docker_args = ['--name=%s' % container_name] - if language.safename == 'http2': + docker_args = ["--name=%s" % container_name] + if language.safename == "http2": # we are running the http2 interop server. Open next N ports beginning # with the server port. These ports are used for http2 interop test # (one test case per port). docker_args += list( itertools.chain.from_iterable( - ('-p', str(_DEFAULT_SERVER_PORT + i)) - for i in range(len(_HTTP2_SERVER_TEST_CASES)))) + ("-p", str(_DEFAULT_SERVER_PORT + i)) + for i in range(len(_HTTP2_SERVER_TEST_CASES)) + ) + ) # Enable docker's healthcheck mechanism. # This runs a Python script inside the container every second. The script # pings the http2 server to verify it is ready. The 'health-retries' flag @@ -1056,31 +1175,36 @@ def server_jobspec(language, # or 'docker inspect' can be used to see the health of the container on the # command line. docker_args += [ - '--health-cmd=python test/http2_test/http2_server_health_check.py ' - '--server_host=%s --server_port=%d' % - ('localhost', _DEFAULT_SERVER_PORT), - '--health-interval=1s', - '--health-retries=5', - '--health-timeout=10s', + "--health-cmd=python test/http2_test/http2_server_health_check.py " + "--server_host=%s --server_port=%d" + % ("localhost", _DEFAULT_SERVER_PORT), + "--health-interval=1s", + "--health-retries=5", + "--health-timeout=10s", ] else: - docker_args += ['-p', str(_DEFAULT_SERVER_PORT)] + docker_args += ["-p", str(_DEFAULT_SERVER_PORT)] - docker_cmdline = docker_run_cmdline(cmdline, - image=docker_image, - cwd=language.server_cwd, - environ=environ, - docker_args=docker_args) + docker_cmdline = docker_run_cmdline( + cmdline, + image=docker_image, + cwd=language.server_cwd, + environ=environ, + docker_args=docker_args, + ) if manual_cmd_log is not None: if manual_cmd_log == []: - manual_cmd_log.append('echo "Testing ${docker_image:=%s}"' % - docker_image) + manual_cmd_log.append( + 'echo "Testing ${docker_image:=%s}"' % docker_image + ) manual_cmd_log.append(manual_cmdline(docker_cmdline, docker_image)) - server_job = jobset.JobSpec(cmdline=docker_cmdline, - environ=environ, - shortname='interop_server_%s' % language, - timeout_seconds=30 * 60) + server_job = jobset.JobSpec( + cmdline=docker_cmdline, + environ=environ, + shortname="interop_server_%s" % language, + timeout_seconds=30 * 60, + ) server_job.container_name = container_name return server_job @@ -1088,16 +1212,17 @@ def server_jobspec(language, def build_interop_image_jobspec(language, tag=None): """Creates jobspec for building interop docker image for a language""" if not tag: - tag = 'grpc_interop_%s:%s' % (language.safename, uuid.uuid4()) + tag = "grpc_interop_%s:%s" % (language.safename, uuid.uuid4()) env = { - 'INTEROP_IMAGE': tag, - 'BASE_NAME': 'grpc_interop_%s' % language.safename + "INTEROP_IMAGE": tag, + "BASE_NAME": "grpc_interop_%s" % language.safename, } build_job = jobset.JobSpec( - cmdline=['tools/run_tests/dockerize/build_interop_image.sh'], + cmdline=["tools/run_tests/dockerize/build_interop_image.sh"], 
environ=env, - shortname='build_docker_%s' % (language), - timeout_seconds=30 * 60) + shortname="build_docker_%s" % (language), + timeout_seconds=30 * 60, + ) build_job.tag = tag return build_job @@ -1112,210 +1237,259 @@ def aggregate_http2_results(stdout): passed = 0 failed = 0 failed_cases = [] - for case in results['cases']: - if case.get('skipped', False): + for case in results["cases"]: + if case.get("skipped", False): skipped += 1 else: - if case.get('passed', False): + if case.get("passed", False): passed += 1 else: failed += 1 - failed_cases.append(case.get('name', "NONAME")) + failed_cases.append(case.get("name", "NONAME")) return { - 'passed': passed, - 'failed': failed, - 'skipped': skipped, - 'failed_cases': ', '.join(failed_cases), - 'percent': 1.0 * passed / (passed + failed) + "passed": passed, + "failed": failed, + "skipped": skipped, + "failed_cases": ", ".join(failed_cases), + "percent": 1.0 * passed / (passed + failed), } # A dictionary of prod servers to test against. # See go/grpc-interop-tests (internal-only) for details. prod_servers = { - 'default': 'grpc-test.sandbox.googleapis.com', - 'gateway_v4': 'grpc-test4.sandbox.googleapis.com', + "default": "grpc-test.sandbox.googleapis.com", + "gateway_v4": "grpc-test4.sandbox.googleapis.com", } -argp = argparse.ArgumentParser(description='Run interop tests.') -argp.add_argument('-l', - '--language', - choices=['all'] + sorted(_LANGUAGES), - nargs='+', - default=['all'], - help='Clients to run. Objc client can be only run on OSX.') -argp.add_argument('-j', '--jobs', default=multiprocessing.cpu_count(), type=int) -argp.add_argument('--cloud_to_prod', - default=False, - action='store_const', - const=True, - help='Run cloud_to_prod tests.') -argp.add_argument('--cloud_to_prod_auth', - default=False, - action='store_const', - const=True, - help='Run cloud_to_prod_auth tests.') -argp.add_argument('--google_default_creds_use_key_file', - default=False, - action='store_const', - const=True, - help=('Whether or not we should use a key file for the ' - 'google_default_credentials test case, e.g. by ' - 'setting env var GOOGLE_APPLICATION_CREDENTIALS.')) -argp.add_argument('--prod_servers', - choices=list(prod_servers.keys()), - default=['default'], - nargs='+', - help=('The servers to run cloud_to_prod and ' - 'cloud_to_prod_auth tests against.')) -argp.add_argument('-s', - '--server', - choices=['all'] + sorted(_SERVERS), - nargs='+', - help='Run cloud_to_cloud servers in a separate docker ' + - 'image. Servers can only be started automatically if ' + - '--use_docker option is enabled.', - default=[]) +argp = argparse.ArgumentParser(description="Run interop tests.") argp.add_argument( - '--override_server', - action='append', - type=lambda kv: kv.split('='), - help= - 'Use servername=HOST:PORT to explicitly specify a server. E.g. csharp=localhost:50000', - default=[]) + "-l", + "--language", + choices=["all"] + sorted(_LANGUAGES), + nargs="+", + default=["all"], + help="Clients to run. 
Objc client can be only run on OSX.", +) +argp.add_argument("-j", "--jobs", default=multiprocessing.cpu_count(), type=int) +argp.add_argument( + "--cloud_to_prod", + default=False, + action="store_const", + const=True, + help="Run cloud_to_prod tests.", +) +argp.add_argument( + "--cloud_to_prod_auth", + default=False, + action="store_const", + const=True, + help="Run cloud_to_prod_auth tests.", +) +argp.add_argument( + "--google_default_creds_use_key_file", + default=False, + action="store_const", + const=True, + help=( + "Whether or not we should use a key file for the " + "google_default_credentials test case, e.g. by " + "setting env var GOOGLE_APPLICATION_CREDENTIALS." + ), +) +argp.add_argument( + "--prod_servers", + choices=list(prod_servers.keys()), + default=["default"], + nargs="+", + help=( + "The servers to run cloud_to_prod and cloud_to_prod_auth tests against." + ), +) +argp.add_argument( + "-s", + "--server", + choices=["all"] + sorted(_SERVERS), + nargs="+", + help="Run cloud_to_cloud servers in a separate docker " + + "image. Servers can only be started automatically if " + + "--use_docker option is enabled.", + default=[], +) +argp.add_argument( + "--override_server", + action="append", + type=lambda kv: kv.split("="), + help=( + "Use servername=HOST:PORT to explicitly specify a server. E.g." + " csharp=localhost:50000" + ), + default=[], +) # TODO(jtattermusch): the default service_account_key_file only works when --use_docker is used. argp.add_argument( - '--service_account_key_file', + "--service_account_key_file", type=str, - help='The service account key file to use for some auth interop tests.', - default='/root/service_account/grpc-testing-ebe7c1ac7381.json') + help="The service account key file to use for some auth interop tests.", + default="/root/service_account/grpc-testing-ebe7c1ac7381.json", +) argp.add_argument( - '--default_service_account', + "--default_service_account", type=str, - help='Default GCE service account email to use for some auth interop tests.', - default='830293263384-compute@developer.gserviceaccount.com') + help=( + "Default GCE service account email to use for some auth interop tests." + ), + default="830293263384-compute@developer.gserviceaccount.com", +) argp.add_argument( - '-t', - '--travis', + "-t", + "--travis", default=False, - action='store_const', + action="store_const", const=True, - help='When set, indicates that the script is running on CI (= not locally).' + help=( + "When set, indicates that the script is running on CI (= not locally)." + ), +) +argp.add_argument( + "-v", "--verbose", default=False, action="store_const", const=True ) -argp.add_argument('-v', - '--verbose', - default=False, - action='store_const', - const=True) argp.add_argument( - '--use_docker', + "--use_docker", default=False, - action='store_const', + action="store_const", const=True, - help='Run all the interop tests under docker. That provides ' + - 'additional isolation and prevents the need to install ' + - 'language specific prerequisites. Only available on Linux.') + help="Run all the interop tests under docker. That provides " + + "additional isolation and prevents the need to install " + + "language specific prerequisites. 
Only available on Linux.", +) argp.add_argument( - '--allow_flakes', + "--allow_flakes", default=False, - action='store_const', + action="store_const", const=True, - help= - 'Allow flaky tests to show as passing (re-runs failed tests up to five times)' + help=( + "Allow flaky tests to show as passing (re-runs failed tests up to five" + " times)" + ), ) -argp.add_argument('--manual_run', - default=False, - action='store_const', - const=True, - help='Prepare things for running interop tests manually. ' + - 'Preserve docker images after building them and skip ' - 'actually running the tests. Only print commands to run by ' + - 'hand.') argp.add_argument( - '--http2_interop', + "--manual_run", default=False, - action='store_const', + action="store_const", const=True, - help='Enable HTTP/2 client edge case testing. (Bad client, good server)') + help="Prepare things for running interop tests manually. " + + "Preserve docker images after building them and skip " + "actually running the tests. Only print commands to run by " + "hand.", +) argp.add_argument( - '--http2_server_interop', + "--http2_interop", default=False, - action='store_const', + action="store_const", const=True, - help= - 'Enable HTTP/2 server edge case testing. (Includes positive and negative tests' + help="Enable HTTP/2 client edge case testing. (Bad client, good server)", ) -argp.add_argument('--transport_security', - choices=_TRANSPORT_SECURITY_OPTIONS, - default='tls', - type=str, - nargs='?', - const=True, - help='Which transport security mechanism to use.') argp.add_argument( - '--custom_credentials_type', + "--http2_server_interop", + default=False, + action="store_const", + const=True, + help=( + "Enable HTTP/2 server edge case testing. (Includes positive and" + " negative tests" + ), +) +argp.add_argument( + "--transport_security", + choices=_TRANSPORT_SECURITY_OPTIONS, + default="tls", + type=str, + nargs="?", + const=True, + help="Which transport security mechanism to use.", +) +argp.add_argument( + "--custom_credentials_type", choices=_CUSTOM_CREDENTIALS_TYPE_OPTIONS, default=_CUSTOM_CREDENTIALS_TYPE_OPTIONS, - nargs='+', - help= - 'Credential types to test in the cloud_to_prod setup. Default is to test with all creds types possible.' + nargs="+", + help=( + "Credential types to test in the cloud_to_prod setup. Default is to" + " test with all creds types possible." + ), ) argp.add_argument( - '--skip_compute_engine_creds', + "--skip_compute_engine_creds", default=False, - action='store_const', + action="store_const", const=True, - help='Skip auth tests requiring access to compute engine credentials.') + help="Skip auth tests requiring access to compute engine credentials.", +) argp.add_argument( - '--internal_ci', + "--internal_ci", default=False, - action='store_const', + action="store_const", const=True, help=( - '(Deprecated, has no effect) Put reports into subdirectories to improve ' - 'presentation of results by Internal CI.')) -argp.add_argument('--bq_result_table', - default='', - type=str, - nargs='?', - help='Upload test results to a specified BQ table.') + "(Deprecated, has no effect) Put reports into subdirectories to improve" + " presentation of results by Internal CI." 
+ ), +) +argp.add_argument( + "--bq_result_table", + default="", + type=str, + nargs="?", + help="Upload test results to a specified BQ table.", +) args = argp.parse_args() -servers = set(s for s in itertools.chain.from_iterable( - _SERVERS if x == 'all' else [x] for x in args.server)) +servers = set( + s + for s in itertools.chain.from_iterable( + _SERVERS if x == "all" else [x] for x in args.server + ) +) # ALTS servers are only available for certain languages. -if args.transport_security == 'alts': +if args.transport_security == "alts": servers = servers.intersection(_SERVERS_FOR_ALTS_TEST_CASES) if args.use_docker: if not args.travis: - print('Seen --use_docker flag, will run interop tests under docker.') - print('') + print("Seen --use_docker flag, will run interop tests under docker.") + print("") print( - 'IMPORTANT: The changes you are testing need to be locally committed' + "IMPORTANT: The changes you are testing need to be locally" + " committed" ) print( - 'because only the committed changes in the current branch will be') - print('copied to the docker environment.') + "because only the committed changes in the current branch will be" + ) + print("copied to the docker environment.") time.sleep(5) if args.manual_run and not args.use_docker: - print('--manual_run is only supported with --use_docker option enabled.') + print("--manual_run is only supported with --use_docker option enabled.") sys.exit(1) if not args.use_docker and servers: print( - 'Running interop servers is only supported with --use_docker option enabled.' + "Running interop servers is only supported with --use_docker option" + " enabled." ) sys.exit(1) # we want to include everything but objc in 'all' # because objc won't run on non-mac platforms -all_but_objc = set(six.iterkeys(_LANGUAGES)) - set(['objc']) -languages = set(_LANGUAGES[l] for l in itertools.chain.from_iterable( - all_but_objc if x == 'all' else [x] for x in args.language)) +all_but_objc = set(six.iterkeys(_LANGUAGES)) - set(["objc"]) +languages = set( + _LANGUAGES[l] + for l in itertools.chain.from_iterable( + all_but_objc if x == "all" else [x] for x in args.language + ) +) # ALTS interop clients are only available for certain languages. 
-if args.transport_security == 'alts': +if args.transport_security == "alts": alts_languages = set(_LANGUAGES[l] for l in _LANGUAGES_FOR_ALTS_TEST_CASES) languages = languages.intersection(alts_languages) @@ -1324,7 +1498,8 @@ def aggregate_http2_results(stdout): languages_http2_clients_for_http2_server_interop = set( _LANGUAGES[l] for l in _LANGUAGES_WITH_HTTP2_CLIENTS_FOR_HTTP2_SERVER_TEST_CASES - if 'all' in args.language or l in args.language) + if "all" in args.language or l in args.language + ) http2Interop = Http2Client() if args.http2_interop else None http2InteropServer = Http2Server() if args.http2_server_interop else None @@ -1332,10 +1507,13 @@ def aggregate_http2_results(stdout): docker_images = {} if args.use_docker: # languages for which to build docker images - languages_to_build = set(_LANGUAGES[k] - for k in set([str(l) for l in languages] + - [s for s in servers])) - languages_to_build = languages_to_build | languages_http2_clients_for_http2_server_interop + languages_to_build = set( + _LANGUAGES[k] + for k in set([str(l) for l in languages] + [s for s in servers]) + ) + languages_to_build = ( + languages_to_build | languages_http2_clients_for_http2_server_interop + ) if args.http2_interop: languages_to_build.add(http2Interop) @@ -1345,7 +1523,7 @@ def aggregate_http2_results(stdout): build_jobs = [] for l in languages_to_build: - if str(l) == 'objc': + if str(l) == "objc": # we don't need to build a docker image for objc continue job = build_interop_image_jobspec(l) @@ -1353,27 +1531,32 @@ def aggregate_http2_results(stdout): build_jobs.append(job) if build_jobs: - jobset.message('START', - 'Building interop docker images.', - do_newline=True) + jobset.message( + "START", "Building interop docker images.", do_newline=True + ) if args.verbose: - print('Jobs to run: \n%s\n' % '\n'.join(str(j) for j in build_jobs)) + print("Jobs to run: \n%s\n" % "\n".join(str(j) for j in build_jobs)) - num_failures, build_resultset = jobset.run(build_jobs, - newline_on_success=True, - maxjobs=args.jobs) + num_failures, build_resultset = jobset.run( + build_jobs, newline_on_success=True, maxjobs=args.jobs + ) - report_utils.render_junit_xml_report(build_resultset, - _DOCKER_BUILD_XML_REPORT) + report_utils.render_junit_xml_report( + build_resultset, _DOCKER_BUILD_XML_REPORT + ) if num_failures == 0: - jobset.message('SUCCESS', - 'All docker images built successfully.', - do_newline=True) + jobset.message( + "SUCCESS", + "All docker images built successfully.", + do_newline=True, + ) else: - jobset.message('FAILED', - 'Failed to build interop docker images.', - do_newline=True) + jobset.message( + "FAILED", + "Failed to build interop docker images.", + do_newline=True, + ) for image in six.itervalues(docker_images): dockerjob.remove_image(image, skip_nonexistent=True) sys.exit(1) @@ -1387,70 +1570,87 @@ def aggregate_http2_results(stdout): try: for s in servers: lang = str(s) - spec = server_jobspec(_LANGUAGES[lang], - docker_images.get(lang), - args.transport_security, - manual_cmd_log=server_manual_cmd_log) + spec = server_jobspec( + _LANGUAGES[lang], + docker_images.get(lang), + args.transport_security, + manual_cmd_log=server_manual_cmd_log, + ) if not args.manual_run: job = dockerjob.DockerJob(spec) server_jobs[lang] = job - server_addresses[lang] = ('localhost', - job.mapped_port(_DEFAULT_SERVER_PORT)) + server_addresses[lang] = ( + "localhost", + job.mapped_port(_DEFAULT_SERVER_PORT), + ) else: # don't run the server, set server port to a placeholder value - server_addresses[lang] = 
('localhost', '${SERVER_PORT}') + server_addresses[lang] = ("localhost", "${SERVER_PORT}") http2_server_job = None if args.http2_server_interop: # launch a HTTP2 server emulator that creates edge cases lang = str(http2InteropServer) - spec = server_jobspec(http2InteropServer, - docker_images.get(lang), - manual_cmd_log=server_manual_cmd_log) + spec = server_jobspec( + http2InteropServer, + docker_images.get(lang), + manual_cmd_log=server_manual_cmd_log, + ) if not args.manual_run: http2_server_job = dockerjob.DockerJob(spec) server_jobs[lang] = http2_server_job else: # don't run the server, set server port to a placeholder value - server_addresses[lang] = ('localhost', '${SERVER_PORT}') + server_addresses[lang] = ("localhost", "${SERVER_PORT}") jobs = [] if args.cloud_to_prod: - if args.transport_security not in ['tls']: - print('TLS is always enabled for cloud_to_prod scenarios.') + if args.transport_security not in ["tls"]: + print("TLS is always enabled for cloud_to_prod scenarios.") for server_host_nickname in args.prod_servers: for language in languages: for test_case in _TEST_CASES: if not test_case in language.unimplemented_test_cases(): - if not test_case in _SKIP_ADVANCED + _SKIP_COMPRESSION + _SKIP_SPECIAL_STATUS_MESSAGE + _ORCA_TEST_CASES: - for transport_security in args.custom_credentials_type: + if ( + not test_case + in _SKIP_ADVANCED + + _SKIP_COMPRESSION + + _SKIP_SPECIAL_STATUS_MESSAGE + + _ORCA_TEST_CASES + ): + for ( + transport_security + ) in args.custom_credentials_type: # google_default_credentials not yet supported by all languages - if transport_security == 'google_default_credentials' and str( - language) not in [ - 'c++', 'go', 'java', 'javaokhttp' - ]: + if ( + transport_security + == "google_default_credentials" + and str(language) + not in ["c++", "go", "java", "javaokhttp"] + ): continue # compute_engine_channel_creds not yet supported by all languages - if transport_security == 'compute_engine_channel_creds' and str( - language) not in [ - 'go', 'java', 'javaokhttp' - ]: + if ( + transport_security + == "compute_engine_channel_creds" + and str(language) + not in ["go", "java", "javaokhttp"] + ): continue test_job = cloud_to_prod_jobspec( language, test_case, server_host_nickname, prod_servers[server_host_nickname], - google_default_creds_use_key_file=args. - google_default_creds_use_key_file, + google_default_creds_use_key_file=args.google_default_creds_use_key_file, docker_image=docker_images.get( - str(language)), + str(language) + ), manual_cmd_log=client_manual_cmd_log, - service_account_key_file=args. - service_account_key_file, - default_service_account=args. - default_service_account, - transport_security=transport_security) + service_account_key_file=args.service_account_key_file, + default_service_account=args.default_service_account, + transport_security=transport_security, + ) jobs.append(test_job) if args.http2_interop: for test_case in _HTTP2_TEST_CASES: @@ -1459,52 +1659,63 @@ def aggregate_http2_results(stdout): test_case, server_host_nickname, prod_servers[server_host_nickname], - google_default_creds_use_key_file=args. 
- google_default_creds_use_key_file, + google_default_creds_use_key_file=args.google_default_creds_use_key_file, docker_image=docker_images.get(str(http2Interop)), manual_cmd_log=client_manual_cmd_log, service_account_key_file=args.service_account_key_file, default_service_account=args.default_service_account, - transport_security=args.transport_security) + transport_security=args.transport_security, + ) jobs.append(test_job) if args.cloud_to_prod_auth: - if args.transport_security not in ['tls']: - print('TLS is always enabled for cloud_to_prod scenarios.') + if args.transport_security not in ["tls"]: + print("TLS is always enabled for cloud_to_prod scenarios.") for server_host_nickname in args.prod_servers: for language in languages: for test_case in _AUTH_TEST_CASES: - if (not args.skip_compute_engine_creds or - not compute_engine_creds_required( - language, test_case)): + if ( + not args.skip_compute_engine_creds + or not compute_engine_creds_required( + language, test_case + ) + ): if not test_case in language.unimplemented_test_cases(): if test_case == _GOOGLE_DEFAULT_CREDS_TEST_CASE: - transport_security = 'google_default_credentials' - elif test_case == _COMPUTE_ENGINE_CHANNEL_CREDS_TEST_CASE: - transport_security = 'compute_engine_channel_creds' + transport_security = ( + "google_default_credentials" + ) + elif ( + test_case + == _COMPUTE_ENGINE_CHANNEL_CREDS_TEST_CASE + ): + transport_security = ( + "compute_engine_channel_creds" + ) else: - transport_security = 'tls' - if transport_security not in args.custom_credentials_type: + transport_security = "tls" + if ( + transport_security + not in args.custom_credentials_type + ): continue test_job = cloud_to_prod_jobspec( language, test_case, server_host_nickname, prod_servers[server_host_nickname], - google_default_creds_use_key_file=args. - google_default_creds_use_key_file, + google_default_creds_use_key_file=args.google_default_creds_use_key_file, docker_image=docker_images.get(str(language)), auth=True, manual_cmd_log=client_manual_cmd_log, - service_account_key_file=args. - service_account_key_file, - default_service_account=args. 
- default_service_account, - transport_security=transport_security) + service_account_key_file=args.service_account_key_file, + default_service_account=args.default_service_account, + transport_security=transport_security, + ) jobs.append(test_job) for server in args.override_server: server_name = server[0] - (server_host, server_port) = server[1].split(':') + (server_host, server_port) = server[1].split(":") server_addresses[server_name] = (server_host, server_port) for server_name, server_address in list(server_addresses.items()): @@ -1525,7 +1736,8 @@ def aggregate_http2_results(stdout): server_port, docker_image=docker_images.get(str(language)), transport_security=args.transport_security, - manual_cmd_log=client_manual_cmd_log) + manual_cmd_log=client_manual_cmd_log, + ) jobs.append(test_job) if args.http2_interop: @@ -1541,7 +1753,8 @@ def aggregate_http2_results(stdout): server_port, docker_image=docker_images.get(str(http2Interop)), transport_security=args.transport_security, - manual_cmd_log=client_manual_cmd_log) + manual_cmd_log=client_manual_cmd_log, + ) jobs.append(test_job) if args.http2_server_interop: @@ -1549,7 +1762,8 @@ def aggregate_http2_results(stdout): http2_server_job.wait_for_healthy(timeout_seconds=600) for language in languages_http2_clients_for_http2_server_interop: for test_case in set(_HTTP2_SERVER_TEST_CASES) - set( - _HTTP2_SERVER_TEST_CASES_THAT_USE_GRPC_CLIENTS): + _HTTP2_SERVER_TEST_CASES_THAT_USE_GRPC_CLIENTS + ): offset = sorted(_HTTP2_SERVER_TEST_CASES).index(test_case) server_port = _DEFAULT_SERVER_PORT + offset if not args.manual_run: @@ -1558,10 +1772,11 @@ def aggregate_http2_results(stdout): language, test_case, str(http2InteropServer), - 'localhost', + "localhost", server_port, docker_image=docker_images.get(str(language)), - manual_cmd_log=client_manual_cmd_log) + manual_cmd_log=client_manual_cmd_log, + ) jobs.append(test_job) for language in languages: # HTTP2_SERVER_TEST_CASES_THAT_USE_GRPC_CLIENTS is a subset of @@ -1576,48 +1791,52 @@ def aggregate_http2_results(stdout): server_port = _DEFAULT_SERVER_PORT + offset if not args.manual_run: server_port = http2_server_job.mapped_port(server_port) - if args.transport_security != 'insecure': + if args.transport_security != "insecure": print( - ('Creating grpc client to http2 server test case ' - 'with insecure connection, even though ' - 'args.transport_security is not insecure. Http2 ' - 'test server only supports insecure connections.')) + "Creating grpc client to http2 server test case " + "with insecure connection, even though " + "args.transport_security is not insecure. Http2 " + "test server only supports insecure connections." 
+ ) test_job = cloud_to_cloud_jobspec( language, test_case, str(http2InteropServer), - 'localhost', + "localhost", server_port, docker_image=docker_images.get(str(language)), - transport_security='insecure', - manual_cmd_log=client_manual_cmd_log) + transport_security="insecure", + manual_cmd_log=client_manual_cmd_log, + ) jobs.append(test_job) if not jobs: - print('No jobs to run.') + print("No jobs to run.") for image in six.itervalues(docker_images): dockerjob.remove_image(image, skip_nonexistent=True) sys.exit(1) if args.manual_run: - print('All tests will skipped --manual_run option is active.') + print("All tests will skipped --manual_run option is active.") if args.verbose: - print('Jobs to run: \n%s\n' % '\n'.join(str(job) for job in jobs)) + print("Jobs to run: \n%s\n" % "\n".join(str(job) for job in jobs)) - num_failures, resultset = jobset.run(jobs, - newline_on_success=True, - maxjobs=args.jobs, - skip_jobs=args.manual_run) + num_failures, resultset = jobset.run( + jobs, + newline_on_success=True, + maxjobs=args.jobs, + skip_jobs=args.manual_run, + ) if args.bq_result_table and resultset: upload_interop_results_to_bq(resultset, args.bq_result_table) if num_failures: - jobset.message('FAILED', 'Some tests failed', do_newline=True) + jobset.message("FAILED", "Some tests failed", do_newline=True) else: - jobset.message('SUCCESS', 'All tests passed', do_newline=True) + jobset.message("SUCCESS", "All tests passed", do_newline=True) - write_cmdlog_maybe(server_manual_cmd_log, 'interop_server_cmds.sh') - write_cmdlog_maybe(client_manual_cmd_log, 'interop_client_cmds.sh') + write_cmdlog_maybe(server_manual_cmd_log, "interop_server_cmds.sh") + write_cmdlog_maybe(client_manual_cmd_log, "interop_client_cmds.sh") report_utils.render_junit_xml_report(resultset, _TESTS_XML_REPORT) @@ -1625,8 +1844,9 @@ def aggregate_http2_results(stdout): if "http2" in name: job[0].http2results = aggregate_http2_results(job[0].message) - http2_server_test_cases = (_HTTP2_SERVER_TEST_CASES - if args.http2_server_interop else []) + http2_server_test_cases = ( + _HTTP2_SERVER_TEST_CASES if args.http2_server_interop else [] + ) if num_failures: sys.exit(1) @@ -1642,7 +1862,7 @@ def aggregate_http2_results(stdout): for image in six.itervalues(docker_images): if not args.manual_run: - print('Removing docker image %s' % image) + print("Removing docker image %s" % image) dockerjob.remove_image(image) else: - print('Preserving docker image: %s' % image) + print("Preserving docker image: %s" % image) diff --git a/tools/run_tests/run_microbenchmark.py b/tools/run_tests/run_microbenchmark.py index 27316225db67e..89d9ecb20a74a 100755 --- a/tools/run_tests/run_microbenchmark.py +++ b/tools/run_tests/run_microbenchmark.py @@ -24,26 +24,32 @@ import python_utils.start_port_server as start_port_server sys.path.append( - os.path.join(os.path.dirname(sys.argv[0]), '..', 'profiling', - 'microbenchmarks', 'bm_diff')) + os.path.join( + os.path.dirname(sys.argv[0]), + "..", + "profiling", + "microbenchmarks", + "bm_diff", + ) +) import bm_constants -flamegraph_dir = os.path.join(os.path.expanduser('~'), 'FlameGraph') +flamegraph_dir = os.path.join(os.path.expanduser("~"), "FlameGraph") -os.chdir(os.path.join(os.path.dirname(sys.argv[0]), '../..')) -if not os.path.exists('reports'): - os.makedirs('reports') +os.chdir(os.path.join(os.path.dirname(sys.argv[0]), "../..")) +if not os.path.exists("reports"): + os.makedirs("reports") start_port_server.start_port_server() def fnize(s): - out = '' + out = "" for c in s: - if c in '<>, /': 
- if len(out) and out[-1] == '_': + if c in "<>, /": + if len(out) and out[-1] == "_": continue - out += '_' + out += "_" else: out += c return out @@ -66,8 +72,10 @@ def heading(name): def link(txt, tgt): global index_html - index_html += "
<p><a href=\"%s\">%s</a></p>\n" % (html.escape( - tgt, quote=True), html.escape(txt)) + index_html += '<p><a href="%s">%s</a></p>
\n' % ( + html.escape(tgt, quote=True), + html.escape(txt), + ) def text(txt): @@ -77,30 +85,33 @@ def text(txt): def _bazel_build_benchmark(bm_name, cfg): """Build given benchmark with bazel""" - subprocess.check_call([ - 'tools/bazel', 'build', - '--config=%s' % cfg, - '//test/cpp/microbenchmarks:%s' % bm_name - ]) + subprocess.check_call( + [ + "tools/bazel", + "build", + "--config=%s" % cfg, + "//test/cpp/microbenchmarks:%s" % bm_name, + ] + ) def run_summary(bm_name, cfg, base_json_name): _bazel_build_benchmark(bm_name, cfg) cmd = [ - 'bazel-bin/test/cpp/microbenchmarks/%s' % bm_name, - '--benchmark_out=%s.%s.json' % (base_json_name, cfg), - '--benchmark_out_format=json' + "bazel-bin/test/cpp/microbenchmarks/%s" % bm_name, + "--benchmark_out=%s.%s.json" % (base_json_name, cfg), + "--benchmark_out_format=json", ] if args.summary_time is not None: - cmd += ['--benchmark_min_time=%d' % args.summary_time] - return subprocess.check_output(cmd).decode('UTF-8') + cmd += ["--benchmark_min_time=%d" % args.summary_time] + return subprocess.check_output(cmd).decode("UTF-8") def collect_summary(bm_name, args): # no counters, run microbenchmark and add summary # both to HTML report and to console. - nocounters_heading = 'Summary: %s' % bm_name - nocounters_summary = run_summary(bm_name, 'opt', bm_name) + nocounters_heading = "Summary: %s" % bm_name + nocounters_summary = run_summary(bm_name, "opt", bm_name) heading(nocounters_heading) text(nocounters_summary) print(nocounters_heading) @@ -108,34 +119,41 @@ def collect_summary(bm_name, args): collectors = { - 'summary': collect_summary, + "summary": collect_summary, } -argp = argparse.ArgumentParser(description='Collect data from microbenchmarks') -argp.add_argument('-c', - '--collect', - choices=sorted(collectors.keys()), - nargs='*', - default=sorted(collectors.keys()), - help='Which collectors should be run against each benchmark') -argp.add_argument('-b', - '--benchmarks', - choices=bm_constants._AVAILABLE_BENCHMARK_TESTS, - default=bm_constants._AVAILABLE_BENCHMARK_TESTS, - nargs='+', - type=str, - help='Which microbenchmarks should be run') +argp = argparse.ArgumentParser(description="Collect data from microbenchmarks") +argp.add_argument( + "-c", + "--collect", + choices=sorted(collectors.keys()), + nargs="*", + default=sorted(collectors.keys()), + help="Which collectors should be run against each benchmark", +) argp.add_argument( - '--bq_result_table', - default='', + "-b", + "--benchmarks", + choices=bm_constants._AVAILABLE_BENCHMARK_TESTS, + default=bm_constants._AVAILABLE_BENCHMARK_TESTS, + nargs="+", type=str, - help='Upload results from summary collection to a specified bigquery table.' + help="Which microbenchmarks should be run", ) argp.add_argument( - '--summary_time', + "--bq_result_table", + default="", + type=str, + help=( + "Upload results from summary collection to a specified bigquery table." 
+ ), +) +argp.add_argument( + "--summary_time", default=None, type=int, - help='Minimum time to run benchmarks for the summary collection') + help="Minimum time to run benchmarks for the summary collection", +) args = argp.parse_args() try: @@ -143,8 +161,8 @@ def collect_summary(bm_name, args): for bm_name in args.benchmarks: collectors[collect](bm_name, args) finally: - if not os.path.exists('reports'): - os.makedirs('reports') + if not os.path.exists("reports"): + os.makedirs("reports") index_html += "\n\n" - with open('reports/index.html', 'w') as f: + with open("reports/index.html", "w") as f: f.write(index_html) diff --git a/tools/run_tests/run_performance_tests.py b/tools/run_tests/run_performance_tests.py index 869d58266cabc..606d26778e3dc 100755 --- a/tools/run_tests/run_performance_tests.py +++ b/tools/run_tests/run_performance_tests.py @@ -37,10 +37,10 @@ import python_utils.jobset as jobset import python_utils.report_utils as report_utils -_ROOT = os.path.abspath(os.path.join(os.path.dirname(sys.argv[0]), '../..')) +_ROOT = os.path.abspath(os.path.join(os.path.dirname(sys.argv[0]), "../..")) os.chdir(_ROOT) -_REMOTE_HOST_USERNAME = 'jenkins' +_REMOTE_HOST_USERNAME = "jenkins" _SCENARIO_TIMEOUT = 3 * 60 _WORKER_TIMEOUT = 3 * 60 @@ -59,10 +59,9 @@ def __init__(self, spec, language, host_and_port, perf_file_base_name=None): self.perf_file_base_name = perf_file_base_name def start(self): - self._job = jobset.Job(self._spec, - newline_on_success=True, - travis=True, - add_env={}) + self._job = jobset.Job( + self._spec, newline_on_success=True, travis=True, add_env={} + ) def is_running(self): """Polls a job and returns True if given job is still running.""" @@ -74,93 +73,110 @@ def kill(self): self._job = None -def create_qpsworker_job(language, - shortname=None, - port=10000, - remote_host=None, - perf_cmd=None): - cmdline = (language.worker_cmdline() + ['--driver_port=%s' % port]) +def create_qpsworker_job( + language, shortname=None, port=10000, remote_host=None, perf_cmd=None +): + cmdline = language.worker_cmdline() + ["--driver_port=%s" % port] if remote_host: - host_and_port = '%s:%s' % (remote_host, port) + host_and_port = "%s:%s" % (remote_host, port) else: - host_and_port = 'localhost:%s' % port + host_and_port = "localhost:%s" % port perf_file_base_name = None if perf_cmd: - perf_file_base_name = '%s-%s' % (host_and_port, shortname) + perf_file_base_name = "%s-%s" % (host_and_port, shortname) # specify -o output file so perf.data gets collected when worker stopped - cmdline = perf_cmd + ['-o', '%s-perf.data' % perf_file_base_name - ] + cmdline + cmdline = ( + perf_cmd + ["-o", "%s-perf.data" % perf_file_base_name] + cmdline + ) worker_timeout = _WORKER_TIMEOUT if remote_host: - user_at_host = '%s@%s' % (_REMOTE_HOST_USERNAME, remote_host) - ssh_cmd = ['ssh'] - cmdline = ['timeout', '%s' % (worker_timeout + 30)] + cmdline - ssh_cmd.extend([ - str(user_at_host), - 'cd ~/performance_workspace/grpc/ && %s' % ' '.join(cmdline) - ]) + user_at_host = "%s@%s" % (_REMOTE_HOST_USERNAME, remote_host) + ssh_cmd = ["ssh"] + cmdline = ["timeout", "%s" % (worker_timeout + 30)] + cmdline + ssh_cmd.extend( + [ + str(user_at_host), + "cd ~/performance_workspace/grpc/ && %s" % " ".join(cmdline), + ] + ) cmdline = ssh_cmd jobspec = jobset.JobSpec( cmdline=cmdline, shortname=shortname, - timeout_seconds= - worker_timeout, # workers get restarted after each scenario - verbose_success=True) + timeout_seconds=worker_timeout, # workers get restarted after each scenario + verbose_success=True, + ) 
return QpsWorkerJob(jobspec, language, host_and_port, perf_file_base_name) -def create_scenario_jobspec(scenario_json, - workers, - remote_host=None, - bq_result_table=None, - server_cpu_load=0): +def create_scenario_jobspec( + scenario_json, + workers, + remote_host=None, + bq_result_table=None, + server_cpu_load=0, +): """Runs one scenario using QPS driver.""" # setting QPS_WORKERS env variable here makes sure it works with SSH too. - cmd = 'QPS_WORKERS="%s" ' % ','.join(workers) + cmd = 'QPS_WORKERS="%s" ' % ",".join(workers) if bq_result_table: cmd += 'BQ_RESULT_TABLE="%s" ' % bq_result_table - cmd += 'tools/run_tests/performance/run_qps_driver.sh ' - cmd += '--scenarios_json=%s ' % pipes.quote( - json.dumps({'scenarios': [scenario_json]})) - cmd += '--scenario_result_file=scenario_result.json ' + cmd += "tools/run_tests/performance/run_qps_driver.sh " + cmd += "--scenarios_json=%s " % pipes.quote( + json.dumps({"scenarios": [scenario_json]}) + ) + cmd += "--scenario_result_file=scenario_result.json " if server_cpu_load != 0: - cmd += '--search_param=offered_load --initial_search_value=1000 --targeted_cpu_load=%d --stride=500 --error_tolerance=0.01' % server_cpu_load + cmd += ( + "--search_param=offered_load --initial_search_value=1000" + " --targeted_cpu_load=%d --stride=500 --error_tolerance=0.01" + % server_cpu_load + ) if remote_host: - user_at_host = '%s@%s' % (_REMOTE_HOST_USERNAME, remote_host) + user_at_host = "%s@%s" % (_REMOTE_HOST_USERNAME, remote_host) cmd = 'ssh %s "cd ~/performance_workspace/grpc/ && "%s' % ( - user_at_host, pipes.quote(cmd)) - - return jobset.JobSpec(cmdline=[cmd], - shortname='%s' % scenario_json['name'], - timeout_seconds=_SCENARIO_TIMEOUT, - shell=True, - verbose_success=True) + user_at_host, + pipes.quote(cmd), + ) + + return jobset.JobSpec( + cmdline=[cmd], + shortname="%s" % scenario_json["name"], + timeout_seconds=_SCENARIO_TIMEOUT, + shell=True, + verbose_success=True, + ) def create_quit_jobspec(workers, remote_host=None): """Runs quit using QPS driver.""" # setting QPS_WORKERS env variable here makes sure it works with SSH too. - cmd = 'QPS_WORKERS="%s" cmake/build/qps_json_driver --quit' % ','.join( - w.host_and_port for w in workers) + cmd = 'QPS_WORKERS="%s" cmake/build/qps_json_driver --quit' % ",".join( + w.host_and_port for w in workers + ) if remote_host: - user_at_host = '%s@%s' % (_REMOTE_HOST_USERNAME, remote_host) + user_at_host = "%s@%s" % (_REMOTE_HOST_USERNAME, remote_host) cmd = 'ssh %s "cd ~/performance_workspace/grpc/ && "%s' % ( - user_at_host, pipes.quote(cmd)) - - return jobset.JobSpec(cmdline=[cmd], - shortname='shutdown_workers', - timeout_seconds=_QUIT_WORKER_TIMEOUT, - shell=True, - verbose_success=True) + user_at_host, + pipes.quote(cmd), + ) + + return jobset.JobSpec( + cmdline=[cmd], + shortname="shutdown_workers", + timeout_seconds=_QUIT_WORKER_TIMEOUT, + shell=True, + verbose_success=True, + ) -def create_netperf_jobspec(server_host='localhost', - client_host=None, - bq_result_table=None): +def create_netperf_jobspec( + server_host="localhost", client_host=None, bq_result_table=None +): """Runs netperf benchmark.""" cmd = 'NETPERF_SERVER_HOST="%s" ' % server_host if bq_result_table: @@ -169,52 +185,58 @@ def create_netperf_jobspec(server_host='localhost', # If netperf is running remotely, the env variables populated by Jenkins # won't be available on the client, but we need them for uploading results # to BigQuery. 
- jenkins_job_name = os.getenv('KOKORO_JOB_NAME') + jenkins_job_name = os.getenv("KOKORO_JOB_NAME") if jenkins_job_name: cmd += 'KOKORO_JOB_NAME="%s" ' % jenkins_job_name - jenkins_build_number = os.getenv('KOKORO_BUILD_NUMBER') + jenkins_build_number = os.getenv("KOKORO_BUILD_NUMBER") if jenkins_build_number: cmd += 'KOKORO_BUILD_NUMBER="%s" ' % jenkins_build_number - cmd += 'tools/run_tests/performance/run_netperf.sh' + cmd += "tools/run_tests/performance/run_netperf.sh" if client_host: - user_at_host = '%s@%s' % (_REMOTE_HOST_USERNAME, client_host) + user_at_host = "%s@%s" % (_REMOTE_HOST_USERNAME, client_host) cmd = 'ssh %s "cd ~/performance_workspace/grpc/ && "%s' % ( - user_at_host, pipes.quote(cmd)) - - return jobset.JobSpec(cmdline=[cmd], - shortname='netperf', - timeout_seconds=_NETPERF_TIMEOUT, - shell=True, - verbose_success=True) + user_at_host, + pipes.quote(cmd), + ) + + return jobset.JobSpec( + cmdline=[cmd], + shortname="netperf", + timeout_seconds=_NETPERF_TIMEOUT, + shell=True, + verbose_success=True, + ) def archive_repo(languages): """Archives local version of repo including submodules.""" - cmdline = ['tar', '-cf', '../grpc.tar', '../grpc/'] - if 'java' in languages: - cmdline.append('../grpc-java') - if 'go' in languages: - cmdline.append('../grpc-go') - if 'node' in languages or 'node_purejs' in languages: - cmdline.append('../grpc-node') - - archive_job = jobset.JobSpec(cmdline=cmdline, - shortname='archive_repo', - timeout_seconds=3 * 60) - - jobset.message('START', 'Archiving local repository.', do_newline=True) - num_failures, _ = jobset.run([archive_job], - newline_on_success=True, - maxjobs=1) + cmdline = ["tar", "-cf", "../grpc.tar", "../grpc/"] + if "java" in languages: + cmdline.append("../grpc-java") + if "go" in languages: + cmdline.append("../grpc-go") + if "node" in languages or "node_purejs" in languages: + cmdline.append("../grpc-node") + + archive_job = jobset.JobSpec( + cmdline=cmdline, shortname="archive_repo", timeout_seconds=3 * 60 + ) + + jobset.message("START", "Archiving local repository.", do_newline=True) + num_failures, _ = jobset.run( + [archive_job], newline_on_success=True, maxjobs=1 + ) if num_failures == 0: - jobset.message('SUCCESS', - 'Archive with local repository created successfully.', - do_newline=True) + jobset.message( + "SUCCESS", + "Archive with local repository created successfully.", + do_newline=True, + ) else: - jobset.message('FAILED', - 'Failed to archive local repository.', - do_newline=True) + jobset.message( + "FAILED", "Failed to archive local repository.", do_newline=True + ) sys.exit(1) @@ -223,78 +245,85 @@ def prepare_remote_hosts(hosts, prepare_local=False): prepare_timeout = 10 * 60 prepare_jobs = [] for host in hosts: - user_at_host = '%s@%s' % (_REMOTE_HOST_USERNAME, host) + user_at_host = "%s@%s" % (_REMOTE_HOST_USERNAME, host) prepare_jobs.append( jobset.JobSpec( - cmdline=['tools/run_tests/performance/remote_host_prepare.sh'], - shortname='remote_host_prepare.%s' % host, - environ={'USER_AT_HOST': user_at_host}, - timeout_seconds=prepare_timeout)) + cmdline=["tools/run_tests/performance/remote_host_prepare.sh"], + shortname="remote_host_prepare.%s" % host, + environ={"USER_AT_HOST": user_at_host}, + timeout_seconds=prepare_timeout, + ) + ) if prepare_local: # Prepare localhost as well prepare_jobs.append( jobset.JobSpec( - cmdline=['tools/run_tests/performance/kill_workers.sh'], - shortname='local_prepare', - timeout_seconds=prepare_timeout)) - jobset.message('START', 'Preparing hosts.', do_newline=True) - 
num_failures, _ = jobset.run(prepare_jobs, - newline_on_success=True, - maxjobs=10) + cmdline=["tools/run_tests/performance/kill_workers.sh"], + shortname="local_prepare", + timeout_seconds=prepare_timeout, + ) + ) + jobset.message("START", "Preparing hosts.", do_newline=True) + num_failures, _ = jobset.run( + prepare_jobs, newline_on_success=True, maxjobs=10 + ) if num_failures == 0: - jobset.message('SUCCESS', - 'Prepare step completed successfully.', - do_newline=True) + jobset.message( + "SUCCESS", "Prepare step completed successfully.", do_newline=True + ) else: - jobset.message('FAILED', - 'Failed to prepare remote hosts.', - do_newline=True) + jobset.message( + "FAILED", "Failed to prepare remote hosts.", do_newline=True + ) sys.exit(1) -def build_on_remote_hosts(hosts, - languages=list(scenario_config.LANGUAGES.keys()), - build_local=False): +def build_on_remote_hosts( + hosts, languages=list(scenario_config.LANGUAGES.keys()), build_local=False +): """Builds performance worker on remote hosts (and maybe also locally).""" build_timeout = 45 * 60 # Kokoro VMs (which are local only) do not have caching, so they need more time to build local_build_timeout = 60 * 60 build_jobs = [] for host in hosts: - user_at_host = '%s@%s' % (_REMOTE_HOST_USERNAME, host) + user_at_host = "%s@%s" % (_REMOTE_HOST_USERNAME, host) build_jobs.append( jobset.JobSpec( - cmdline=['tools/run_tests/performance/remote_host_build.sh'] + - languages, - shortname='remote_host_build.%s' % host, - environ={ - 'USER_AT_HOST': user_at_host, - 'CONFIG': 'opt' - }, - timeout_seconds=build_timeout)) + cmdline=["tools/run_tests/performance/remote_host_build.sh"] + + languages, + shortname="remote_host_build.%s" % host, + environ={"USER_AT_HOST": user_at_host, "CONFIG": "opt"}, + timeout_seconds=build_timeout, + ) + ) if build_local: # start port server locally build_jobs.append( jobset.JobSpec( - cmdline=['python', 'tools/run_tests/start_port_server.py'], - shortname='local_start_port_server', - timeout_seconds=2 * 60)) + cmdline=["python", "tools/run_tests/start_port_server.py"], + shortname="local_start_port_server", + timeout_seconds=2 * 60, + ) + ) # Build locally as well build_jobs.append( jobset.JobSpec( - cmdline=['tools/run_tests/performance/build_performance.sh'] + - languages, - shortname='local_build', - environ={'CONFIG': 'opt'}, - timeout_seconds=local_build_timeout)) - jobset.message('START', 'Building.', do_newline=True) - num_failures, _ = jobset.run(build_jobs, - newline_on_success=True, - maxjobs=10) + cmdline=["tools/run_tests/performance/build_performance.sh"] + + languages, + shortname="local_build", + environ={"CONFIG": "opt"}, + timeout_seconds=local_build_timeout, + ) + ) + jobset.message("START", "Building.", do_newline=True) + num_failures, _ = jobset.run( + build_jobs, newline_on_success=True, maxjobs=10 + ) if num_failures == 0: - jobset.message('SUCCESS', 'Built successfully.', do_newline=True) + jobset.message("SUCCESS", "Built successfully.", do_newline=True) else: - jobset.message('FAILED', 'Build failed.', do_newline=True) + jobset.message("FAILED", "Build failed.", do_newline=True) sys.exit(1) @@ -311,51 +340,69 @@ def create_qpsworkers(languages, worker_hosts, perf_cmd=None): workers = [(worker_host, 10000) for worker_host in worker_hosts] return [ - create_qpsworker_job(language, - shortname='qps_worker_%s_%s' % - (language, worker_idx), - port=worker[1] + language.worker_port_offset(), - remote_host=worker[0], - perf_cmd=perf_cmd) + create_qpsworker_job( + language, + 
shortname="qps_worker_%s_%s" % (language, worker_idx), + port=worker[1] + language.worker_port_offset(), + remote_host=worker[0], + perf_cmd=perf_cmd, + ) for language in languages for worker_idx, worker in enumerate(workers) ] -def perf_report_processor_job(worker_host, perf_base_name, output_filename, - flame_graph_reports): - print('Creating perf report collection job for %s' % worker_host) - cmd = '' - if worker_host != 'localhost': +def perf_report_processor_job( + worker_host, perf_base_name, output_filename, flame_graph_reports +): + print("Creating perf report collection job for %s" % worker_host) + cmd = "" + if worker_host != "localhost": user_at_host = "%s@%s" % (_REMOTE_HOST_USERNAME, worker_host) - cmd = "USER_AT_HOST=%s OUTPUT_FILENAME=%s OUTPUT_DIR=%s PERF_BASE_NAME=%s tools/run_tests/performance/process_remote_perf_flamegraphs.sh" % ( - user_at_host, output_filename, flame_graph_reports, perf_base_name) + cmd = ( + "USER_AT_HOST=%s OUTPUT_FILENAME=%s OUTPUT_DIR=%s PERF_BASE_NAME=%s" + " tools/run_tests/performance/process_remote_perf_flamegraphs.sh" + % ( + user_at_host, + output_filename, + flame_graph_reports, + perf_base_name, + ) + ) else: - cmd = "OUTPUT_FILENAME=%s OUTPUT_DIR=%s PERF_BASE_NAME=%s tools/run_tests/performance/process_local_perf_flamegraphs.sh" % ( - output_filename, flame_graph_reports, perf_base_name) - - return jobset.JobSpec(cmdline=cmd, - timeout_seconds=3 * 60, - shell=True, - verbose_success=True, - shortname='process perf report') + cmd = ( + "OUTPUT_FILENAME=%s OUTPUT_DIR=%s PERF_BASE_NAME=%s" + " tools/run_tests/performance/process_local_perf_flamegraphs.sh" + % (output_filename, flame_graph_reports, perf_base_name) + ) + + return jobset.JobSpec( + cmdline=cmd, + timeout_seconds=3 * 60, + shell=True, + verbose_success=True, + shortname="process perf report", + ) -Scenario = collections.namedtuple('Scenario', 'jobspec workers name') +Scenario = collections.namedtuple("Scenario", "jobspec workers name") -def create_scenarios(languages, - workers_by_lang, - remote_host=None, - regex='.*', - category='all', - bq_result_table=None, - netperf=False, - netperf_hosts=[], - server_cpu_load=0): +def create_scenarios( + languages, + workers_by_lang, + remote_host=None, + regex=".*", + category="all", + bq_result_table=None, + netperf=False, + netperf_hosts=[], + server_cpu_load=0, +): """Create jobspecs for scenarios to run.""" all_workers = [ - worker for workers in list(workers_by_lang.values()) + worker + for workers in list(workers_by_lang.values()) for worker in workers ] scenarios = [] @@ -363,7 +410,7 @@ def create_scenarios(languages, if netperf: if not netperf_hosts: - netperf_server = 'localhost' + netperf_server = "localhost" netperf_client = None elif len(netperf_hosts) == 1: netperf_server = netperf_hosts[0] @@ -373,65 +420,88 @@ def create_scenarios(languages, netperf_client = netperf_hosts[1] scenarios.append( Scenario( - create_netperf_jobspec(server_host=netperf_server, - client_host=netperf_client, - bq_result_table=bq_result_table), - _NO_WORKERS, 'netperf')) + create_netperf_jobspec( + server_host=netperf_server, + client_host=netperf_client, + bq_result_table=bq_result_table, + ), + _NO_WORKERS, + "netperf", + ) + ) for language in languages: for scenario_json in language.scenarios(): - if re.search(regex, scenario_json['name']): - categories = scenario_json.get('CATEGORIES', - ['scalable', 'smoketest']) - if category in categories or category == 'all': + if re.search(regex, scenario_json["name"]): + categories = scenario_json.get( + 
"CATEGORIES", ["scalable", "smoketest"] + ) + if category in categories or category == "all": workers = workers_by_lang[str(language)][:] # 'SERVER_LANGUAGE' is an indicator for this script to pick # a server in different language. custom_server_lang = scenario_json.get( - 'SERVER_LANGUAGE', None) + "SERVER_LANGUAGE", None + ) custom_client_lang = scenario_json.get( - 'CLIENT_LANGUAGE', None) + "CLIENT_LANGUAGE", None + ) scenario_json = scenario_config.remove_nonproto_fields( - scenario_json) + scenario_json + ) if custom_server_lang and custom_client_lang: raise Exception( - 'Cannot set both custom CLIENT_LANGUAGE and SERVER_LANGUAGE' - 'in the same scenario') + "Cannot set both custom CLIENT_LANGUAGE and" + " SERVER_LANGUAGEin the same scenario" + ) if custom_server_lang: if not workers_by_lang.get(custom_server_lang, []): - print('Warning: Skipping scenario %s as' % - scenario_json['name']) print( - 'SERVER_LANGUAGE is set to %s yet the language has ' - 'not been selected with -l' % - custom_server_lang) + "Warning: Skipping scenario %s as" + % scenario_json["name"] + ) + print( + "SERVER_LANGUAGE is set to %s yet the language" + " has not been selected with -l" + % custom_server_lang + ) continue - for idx in range(0, scenario_json['num_servers']): + for idx in range(0, scenario_json["num_servers"]): # replace first X workers by workers of a different language workers[idx] = workers_by_lang[custom_server_lang][ - idx] + idx + ] if custom_client_lang: if not workers_by_lang.get(custom_client_lang, []): - print('Warning: Skipping scenario %s as' % - scenario_json['name']) print( - 'CLIENT_LANGUAGE is set to %s yet the language has ' - 'not been selected with -l' % - custom_client_lang) + "Warning: Skipping scenario %s as" + % scenario_json["name"] + ) + print( + "CLIENT_LANGUAGE is set to %s yet the language" + " has not been selected with -l" + % custom_client_lang + ) continue - for idx in range(scenario_json['num_servers'], - len(workers)): + for idx in range( + scenario_json["num_servers"], len(workers) + ): # replace all client workers by workers of a different language, # leave num_server workers as they are server workers. workers[idx] = workers_by_lang[custom_client_lang][ - idx] + idx + ] scenario = Scenario( create_scenario_jobspec( - scenario_json, [w.host_and_port for w in workers], + scenario_json, + [w.host_and_port for w in workers], remote_host=remote_host, bq_result_table=bq_result_table, - server_cpu_load=server_cpu_load), workers, - scenario_json['name']) + server_cpu_load=server_cpu_load, + ), + workers, + scenario_json["name"], + ) scenarios.append(scenario) return scenarios @@ -446,13 +516,13 @@ def finish_qps_workers(jobs, qpsworker_jobs): if job.is_running(): print('QPS worker "%s" is still running.' % job.host_and_port) if retries > 10: - print('Killing all QPS workers.') + print("Killing all QPS workers.") for job in jobs: job.kill() num_killed += 1 retries += 1 time.sleep(3) - print('All QPS workers finished.') + print("All QPS workers finished.") return num_killed @@ -465,127 +535,158 @@ def finish_qps_workers(jobs, qpsworker_jobs): # perf reports directory. # Alos, the perf profiles need to be fetched and processed after each scenario # in order to avoid clobbering the output files. 
-def run_collect_perf_profile_jobs(hosts_and_base_names, scenario_name, - flame_graph_reports): +def run_collect_perf_profile_jobs( + hosts_and_base_names, scenario_name, flame_graph_reports +): perf_report_jobs = [] global profile_output_files for host_and_port in hosts_and_base_names: perf_base_name = hosts_and_base_names[host_and_port] - output_filename = '%s-%s' % (scenario_name, perf_base_name) + output_filename = "%s-%s" % (scenario_name, perf_base_name) # from the base filename, create .svg output filename - host = host_and_port.split(':')[0] - profile_output_files.append('%s.svg' % output_filename) + host = host_and_port.split(":")[0] + profile_output_files.append("%s.svg" % output_filename) perf_report_jobs.append( - perf_report_processor_job(host, perf_base_name, output_filename, - flame_graph_reports)) - - jobset.message('START', - 'Collecting perf reports from qps workers', - do_newline=True) - failures, _ = jobset.run(perf_report_jobs, - newline_on_success=True, - maxjobs=1) - jobset.message('SUCCESS', - 'Collecting perf reports from qps workers', - do_newline=True) + perf_report_processor_job( + host, perf_base_name, output_filename, flame_graph_reports + ) + ) + + jobset.message( + "START", "Collecting perf reports from qps workers", do_newline=True + ) + failures, _ = jobset.run( + perf_report_jobs, newline_on_success=True, maxjobs=1 + ) + jobset.message( + "SUCCESS", "Collecting perf reports from qps workers", do_newline=True + ) return failures def main(): - argp = argparse.ArgumentParser(description='Run performance tests.') - argp.add_argument('-l', - '--language', - choices=['all'] + - sorted(scenario_config.LANGUAGES.keys()), - nargs='+', - required=True, - help='Languages to benchmark.') + argp = argparse.ArgumentParser(description="Run performance tests.") + argp.add_argument( + "-l", + "--language", + choices=["all"] + sorted(scenario_config.LANGUAGES.keys()), + nargs="+", + required=True, + help="Languages to benchmark.", + ) argp.add_argument( - '--remote_driver_host', + "--remote_driver_host", default=None, - help= - 'Run QPS driver on given host. By default, QPS driver is run locally.') - argp.add_argument('--remote_worker_host', - nargs='+', - default=[], - help='Worker hosts where to start QPS workers.') + help=( + "Run QPS driver on given host. By default, QPS driver is run" + " locally." 
+ ), + ) + argp.add_argument( + "--remote_worker_host", + nargs="+", + default=[], + help="Worker hosts where to start QPS workers.", + ) argp.add_argument( - '--dry_run', + "--dry_run", default=False, - action='store_const', + action="store_const", const=True, - help='Just list scenarios to be run, but don\'t run them.') - argp.add_argument('-r', - '--regex', - default='.*', - type=str, - help='Regex to select scenarios to run.') - argp.add_argument('--bq_result_table', - default=None, - type=str, - help='Bigquery "dataset.table" to upload results to.') - argp.add_argument('--category', - choices=['smoketest', 'all', 'scalable', 'sweep'], - default='all', - help='Select a category of tests to run.') - argp.add_argument('--netperf', - default=False, - action='store_const', - const=True, - help='Run netperf benchmark as one of the scenarios.') + help="Just list scenarios to be run, but don't run them.", + ) + argp.add_argument( + "-r", + "--regex", + default=".*", + type=str, + help="Regex to select scenarios to run.", + ) + argp.add_argument( + "--bq_result_table", + default=None, + type=str, + help='Bigquery "dataset.table" to upload results to.', + ) + argp.add_argument( + "--category", + choices=["smoketest", "all", "scalable", "sweep"], + default="all", + help="Select a category of tests to run.", + ) + argp.add_argument( + "--netperf", + default=False, + action="store_const", + const=True, + help="Run netperf benchmark as one of the scenarios.", + ) argp.add_argument( - '--server_cpu_load', + "--server_cpu_load", default=0, type=int, - help='Select a targeted server cpu load to run. 0 means ignore this flag' + help=( + "Select a targeted server cpu load to run. 0 means ignore this flag" + ), ) - argp.add_argument('-x', - '--xml_report', - default='report.xml', - type=str, - help='Name of XML report file to generate.') argp.add_argument( - '--perf_args', - help=('Example usage: "--perf_args=record -F 99 -g". ' - 'Wrap QPS workers in a perf command ' - 'with the arguments to perf specified here. ' - '".svg" flame graph profiles will be ' - 'created for each Qps Worker on each scenario. ' - 'Files will output to "/" ' - 'directory. Output files from running the worker ' - 'under perf are saved in the repo root where its ran. ' - 'Note that the perf "-g" flag is necessary for ' - 'flame graphs generation to work (assuming the binary ' - 'being profiled uses frame pointers, check out ' - '"--call-graph dwarf" option using libunwind otherwise.) ' - 'Also note that the entire "--perf_args=" must ' - 'be wrapped in quotes as in the example usage. ' - 'If the "--perg_args" is unspecified, "perf" will ' - 'not be used at all. ' - 'See http://www.brendangregg.com/perf.html ' - 'for more general perf examples.')) + "-x", + "--xml_report", + default="report.xml", + type=str, + help="Name of XML report file to generate.", + ) + argp.add_argument( + "--perf_args", + help=( + 'Example usage: "--perf_args=record -F 99 -g". ' + "Wrap QPS workers in a perf command " + "with the arguments to perf specified here. " + '".svg" flame graph profiles will be ' + "created for each Qps Worker on each scenario. " + 'Files will output to "/" ' + "directory. Output files from running the worker " + "under perf are saved in the repo root where its ran. " + 'Note that the perf "-g" flag is necessary for ' + "flame graphs generation to work (assuming the binary " + "being profiled uses frame pointers, check out " + '"--call-graph dwarf" option using libunwind otherwise.) 
' + 'Also note that the entire "--perf_args=" must ' + "be wrapped in quotes as in the example usage. " + 'If the "--perg_args" is unspecified, "perf" will ' + "not be used at all. " + "See http://www.brendangregg.com/perf.html " + "for more general perf examples." + ), + ) argp.add_argument( - '--skip_generate_flamegraphs', + "--skip_generate_flamegraphs", default=False, - action='store_const', + action="store_const", const=True, - help=('Turn flame graph generation off. ' - 'May be useful if "perf_args" arguments do not make sense for ' - 'generating flamegraphs (e.g., "--perf_args=stat ...")')) + help=( + "Turn flame graph generation off. " + 'May be useful if "perf_args" arguments do not make sense for ' + 'generating flamegraphs (e.g., "--perf_args=stat ...")' + ), + ) argp.add_argument( - '-f', - '--flame_graph_reports', - default='perf_reports', + "-f", + "--flame_graph_reports", + default="perf_reports", type=str, - help= - 'Name of directory to output flame graph profiles to, if any are created.' + help=( + "Name of directory to output flame graph profiles to, if any are" + " created." + ), ) argp.add_argument( - '-u', - '--remote_host_username', - default='', + "-u", + "--remote_host_username", + default="", type=str, - help='Use a username that isn\'t "Jenkins" to SSH into remote workers.') + help='Use a username that isn\'t "Jenkins" to SSH into remote workers.', + ) args = argp.parse_args() @@ -594,9 +695,12 @@ def main(): _REMOTE_HOST_USERNAME = args.remote_host_username languages = set( - scenario_config.LANGUAGES[l] for l in itertools.chain.from_iterable( - six.iterkeys(scenario_config.LANGUAGES) if x == 'all' else [x] - for x in args.language)) + scenario_config.LANGUAGES[l] + for l in itertools.chain.from_iterable( + six.iterkeys(scenario_config.LANGUAGES) if x == "all" else [x] + for x in args.language + ) + ) # Put together set of remote hosts where to run and build remote_hosts = set() @@ -617,38 +721,42 @@ def main(): if not args.remote_driver_host: build_local = True if not args.dry_run: - build_on_remote_hosts(remote_hosts, - languages=[str(l) for l in languages], - build_local=build_local) + build_on_remote_hosts( + remote_hosts, + languages=[str(l) for l in languages], + build_local=build_local, + ) perf_cmd = None if args.perf_args: - print('Running workers under perf profiler') + print("Running workers under perf profiler") # Expect /usr/bin/perf to be installed here, as is usual - perf_cmd = ['/usr/bin/perf'] - perf_cmd.extend(re.split('\s+', args.perf_args)) + perf_cmd = ["/usr/bin/perf"] + perf_cmd.extend(re.split("\s+", args.perf_args)) - qpsworker_jobs = create_qpsworkers(languages, - args.remote_worker_host, - perf_cmd=perf_cmd) + qpsworker_jobs = create_qpsworkers( + languages, args.remote_worker_host, perf_cmd=perf_cmd + ) # get list of worker addresses for each language. 
workers_by_lang = dict([(str(language), []) for language in languages]) for job in qpsworker_jobs: workers_by_lang[str(job.language)].append(job) - scenarios = create_scenarios(languages, - workers_by_lang=workers_by_lang, - remote_host=args.remote_driver_host, - regex=args.regex, - category=args.category, - bq_result_table=args.bq_result_table, - netperf=args.netperf, - netperf_hosts=args.remote_worker_host, - server_cpu_load=args.server_cpu_load) + scenarios = create_scenarios( + languages, + workers_by_lang=workers_by_lang, + remote_host=args.remote_driver_host, + regex=args.regex, + category=args.category, + bq_result_table=args.bq_result_table, + netperf=args.netperf, + netperf_hosts=args.remote_worker_host, + server_cpu_load=args.server_cpu_load, + ) if not scenarios: - raise Exception('No scenarios to run') + raise Exception("No scenarios to run") total_scenario_failures = 0 qps_workers_killed = 0 @@ -670,50 +778,69 @@ def main(): jobs.append( create_quit_jobspec( scenario.workers, - remote_host=args.remote_driver_host)) + remote_host=args.remote_driver_host, + ) + ) scenario_failures, resultset = jobset.run( - jobs, newline_on_success=True, maxjobs=1) + jobs, newline_on_success=True, maxjobs=1 + ) total_scenario_failures += scenario_failures merged_resultset = dict( - itertools.chain(six.iteritems(merged_resultset), - six.iteritems(resultset))) + itertools.chain( + six.iteritems(merged_resultset), + six.iteritems(resultset), + ) + ) finally: # Consider qps workers that need to be killed as failures qps_workers_killed += finish_qps_workers( - scenario.workers, qpsworker_jobs) - - if perf_cmd and scenario_failures == 0 and not args.skip_generate_flamegraphs: + scenario.workers, qpsworker_jobs + ) + + if ( + perf_cmd + and scenario_failures == 0 + and not args.skip_generate_flamegraphs + ): workers_and_base_names = {} for worker in scenario.workers: if not worker.perf_file_base_name: raise Exception( - 'using perf buf perf report filename is unspecified' + "using perf buf perf report filename is unspecified" ) workers_and_base_names[ - worker.host_and_port] = worker.perf_file_base_name + worker.host_and_port + ] = worker.perf_file_base_name perf_report_failures += run_collect_perf_profile_jobs( - workers_and_base_names, scenario.name, - args.flame_graph_reports) + workers_and_base_names, + scenario.name, + args.flame_graph_reports, + ) # Still write the index.html even if some scenarios failed. 
# 'profile_output_files' will only have names for scenarios that passed if perf_cmd and not args.skip_generate_flamegraphs: # write the index fil to the output dir, with all profiles from all scenarios/workers report_utils.render_perf_profiling_results( - '%s/index.html' % args.flame_graph_reports, profile_output_files) - - report_utils.render_junit_xml_report(merged_resultset, - args.xml_report, - suite_name='benchmarks', - multi_target=True) + "%s/index.html" % args.flame_graph_reports, profile_output_files + ) + + report_utils.render_junit_xml_report( + merged_resultset, + args.xml_report, + suite_name="benchmarks", + multi_target=True, + ) if total_scenario_failures > 0 or qps_workers_killed > 0: - print('%s scenarios failed and %s qps worker jobs killed' % - (total_scenario_failures, qps_workers_killed)) + print( + "%s scenarios failed and %s qps worker jobs killed" + % (total_scenario_failures, qps_workers_killed) + ) sys.exit(1) if perf_report_failures > 0: - print('%s perf profile collection jobs failed' % perf_report_failures) + print("%s perf profile collection jobs failed" % perf_report_failures) sys.exit(1) diff --git a/tools/run_tests/run_tests.py b/tools/run_tests/run_tests.py index 9b2bf25b47beb..b027415335eb0 100755 --- a/tools/run_tests/run_tests.py +++ b/tools/run_tests/run_tests.py @@ -48,23 +48,24 @@ try: from python_utils.upload_test_results import upload_results_to_bq -except (ImportError): +except ImportError: pass # It's ok to not import because this is only necessary to upload results to BQ. gcp_utils_dir = os.path.abspath( - os.path.join(os.path.dirname(__file__), '../gcp/utils')) + os.path.join(os.path.dirname(__file__), "../gcp/utils") +) sys.path.append(gcp_utils_dir) -_ROOT = os.path.abspath(os.path.join(os.path.dirname(sys.argv[0]), '../..')) +_ROOT = os.path.abspath(os.path.join(os.path.dirname(sys.argv[0]), "../..")) os.chdir(_ROOT) _FORCE_ENVIRON_FOR_WRAPPERS = { - 'GRPC_VERBOSITY': 'DEBUG', + "GRPC_VERBOSITY": "DEBUG", } _POLLING_STRATEGIES = { - 'linux': ['epoll1', 'poll'], - 'mac': ['poll'], + "linux": ["epoll1", "poll"], + "mac": ["poll"], } @@ -82,62 +83,68 @@ def run_shell_command(cmd, env=None, cwd=None): except subprocess.CalledProcessError as e: logging.exception( "Error while running command '%s'. Exit status %d. Output:\n%s", - e.cmd, e.returncode, e.output) + e.cmd, + e.returncode, + e.output, + ) raise def max_parallel_tests_for_current_platform(): # Too much test parallelization has only been seen to be a problem # so far on windows. 
- if jobset.platform_string() == 'windows': + if jobset.platform_string() == "windows": return 64 return 1024 def _print_debug_info_epilogue(dockerfile_dir=None): """Use to print useful info for debug/repro just before exiting.""" - print('') - print('=== run_tests.py DEBUG INFO ===') - print('command: \"%s\"' % ' '.join(sys.argv)) + print("") + print("=== run_tests.py DEBUG INFO ===") + print('command: "%s"' % " ".join(sys.argv)) if dockerfile_dir: - print('dockerfile: %s' % dockerfile_dir) - kokoro_job_name = os.getenv('KOKORO_JOB_NAME') + print("dockerfile: %s" % dockerfile_dir) + kokoro_job_name = os.getenv("KOKORO_JOB_NAME") if kokoro_job_name: - print('kokoro job name: %s' % kokoro_job_name) - print('===============================') + print("kokoro job name: %s" % kokoro_job_name) + print("===============================") # SimpleConfig: just compile with CONFIG=config, and run the binary to test class Config(object): - - def __init__(self, - config, - environ=None, - timeout_multiplier=1, - tool_prefix=[], - iomgr_platform='native'): + def __init__( + self, + config, + environ=None, + timeout_multiplier=1, + tool_prefix=[], + iomgr_platform="native", + ): if environ is None: environ = {} self.build_config = config self.environ = environ - self.environ['CONFIG'] = config + self.environ["CONFIG"] = config self.tool_prefix = tool_prefix self.timeout_multiplier = timeout_multiplier self.iomgr_platform = iomgr_platform - def job_spec(self, - cmdline, - timeout_seconds=_DEFAULT_TIMEOUT_SECONDS, - shortname=None, - environ={}, - cpu_cost=1.0, - flaky=False): + def job_spec( + self, + cmdline, + timeout_seconds=_DEFAULT_TIMEOUT_SECONDS, + shortname=None, + environ={}, + cpu_cost=1.0, + flaky=False, + ): """Construct a jobset.JobSpec for a test under this config - Args: - cmdline: a list of strings specifying the command line the test - would like to run - """ + Args: + cmdline: a list of strings specifying the command line the test + would like to run + """ actual_environ = self.environ.copy() for k, v in environ.items(): actual_environ[k] = v @@ -150,96 +157,121 @@ def job_spec(self, shortname=shortname, environ=actual_environ, cpu_cost=cpu_cost, - timeout_seconds=(self.timeout_multiplier * - timeout_seconds if timeout_seconds else None), + timeout_seconds=( + self.timeout_multiplier * timeout_seconds + if timeout_seconds + else None + ), flake_retries=4 if flaky or args.allow_flakes else 0, - timeout_retries=1 if flaky or args.allow_flakes else 0) + timeout_retries=1 if flaky or args.allow_flakes else 0, + ) def get_c_tests(travis, test_lang): out = [] - platforms_str = 'ci_platforms' if travis else 'platforms' - with open('tools/run_tests/generated/tests.json') as f: + platforms_str = "ci_platforms" if travis else "platforms" + with open("tools/run_tests/generated/tests.json") as f: js = json.load(f) return [ - tgt for tgt in js - if tgt['language'] == test_lang and platform_string() in - tgt[platforms_str] and not (travis and tgt['flaky']) + tgt + for tgt in js + if tgt["language"] == test_lang + and platform_string() in tgt[platforms_str] + and not (travis and tgt["flaky"]) ] def _check_compiler(compiler, supported_compilers): if compiler not in supported_compilers: - raise Exception('Compiler %s not supported (on this platform).' % - compiler) + raise Exception( + "Compiler %s not supported (on this platform)." % compiler + ) def _check_arch(arch, supported_archs): if arch not in supported_archs: - raise Exception('Architecture %s not supported.' 
% arch) + raise Exception("Architecture %s not supported." % arch) def _is_use_docker_child(): """Returns True if running running as a --use_docker child.""" - return True if os.getenv('DOCKER_RUN_SCRIPT_COMMAND') else False - - -_PythonConfigVars = collections.namedtuple('_ConfigVars', [ - 'shell', - 'builder', - 'builder_prefix_arguments', - 'venv_relative_python', - 'toolchain', - 'runner', -]) + return True if os.getenv("DOCKER_RUN_SCRIPT_COMMAND") else False + + +_PythonConfigVars = collections.namedtuple( + "_ConfigVars", + [ + "shell", + "builder", + "builder_prefix_arguments", + "venv_relative_python", + "toolchain", + "runner", + ], +) def _python_config_generator(name, major, minor, bits, config_vars): - build = (config_vars.shell + config_vars.builder + - config_vars.builder_prefix_arguments + - [_python_pattern_function(major=major, minor=minor, bits=bits)] + - [name] + config_vars.venv_relative_python + config_vars.toolchain) - run = (config_vars.shell + config_vars.runner + [ - os.path.join(name, config_vars.venv_relative_python[0]), - ]) + build = ( + config_vars.shell + + config_vars.builder + + config_vars.builder_prefix_arguments + + [_python_pattern_function(major=major, minor=minor, bits=bits)] + + [name] + + config_vars.venv_relative_python + + config_vars.toolchain + ) + run = ( + config_vars.shell + + config_vars.runner + + [ + os.path.join(name, config_vars.venv_relative_python[0]), + ] + ) return PythonConfig(name, build, run) def _pypy_config_generator(name, major, config_vars): return PythonConfig( - name, config_vars.shell + config_vars.builder + - config_vars.builder_prefix_arguments + - [_pypy_pattern_function(major=major)] + [name] + - config_vars.venv_relative_python + config_vars.toolchain, - config_vars.shell + config_vars.runner + - [os.path.join(name, config_vars.venv_relative_python[0])]) + name, + config_vars.shell + + config_vars.builder + + config_vars.builder_prefix_arguments + + [_pypy_pattern_function(major=major)] + + [name] + + config_vars.venv_relative_python + + config_vars.toolchain, + config_vars.shell + + config_vars.runner + + [os.path.join(name, config_vars.venv_relative_python[0])], + ) def _python_pattern_function(major, minor, bits): # Bit-ness is handled by the test machine's environment if os.name == "nt": if bits == "64": - return '/c/Python{major}{minor}/python.exe'.format(major=major, - minor=minor, - bits=bits) + return "/c/Python{major}{minor}/python.exe".format( + major=major, minor=minor, bits=bits + ) else: - return '/c/Python{major}{minor}_{bits}bits/python.exe'.format( - major=major, minor=minor, bits=bits) + return "/c/Python{major}{minor}_{bits}bits/python.exe".format( + major=major, minor=minor, bits=bits + ) else: - return 'python{major}.{minor}'.format(major=major, minor=minor) + return "python{major}.{minor}".format(major=major, minor=minor) def _pypy_pattern_function(major): - if major == '2': - return 'pypy' - elif major == '3': - return 'pypy3' + if major == "2": + return "pypy" + elif major == "3": + return "pypy3" else: raise ValueError("Unknown PyPy major version") class CLanguage(object): - def __init__(self, lang_suffix, test_lang): self.lang_suffix = lang_suffix self.platform = platform_string() @@ -248,105 +280,129 @@ def __init__(self, lang_suffix, test_lang): def configure(self, config, args): self.config = config self.args = args - if self.platform == 'windows': - _check_compiler(self.args.compiler, [ - 'default', - 'cmake', - 'cmake_ninja_vs2019', - 'cmake_vs2019', - ]) - _check_arch(self.args.arch, 
['default', 'x64', 'x86']) - - activate_vs_tools = '' - if self.args.compiler == 'cmake_ninja_vs2019' or self.args.compiler == 'cmake' or self.args.compiler == 'default': + if self.platform == "windows": + _check_compiler( + self.args.compiler, + [ + "default", + "cmake", + "cmake_ninja_vs2019", + "cmake_vs2019", + ], + ) + _check_arch(self.args.arch, ["default", "x64", "x86"]) + + activate_vs_tools = "" + if ( + self.args.compiler == "cmake_ninja_vs2019" + or self.args.compiler == "cmake" + or self.args.compiler == "default" + ): # cmake + ninja build is the default because it is faster and supports boringssl assembly optimizations # the compiler used is exactly the same as for cmake_vs2017 - cmake_generator = 'Ninja' - activate_vs_tools = '2019' - elif self.args.compiler == 'cmake_vs2019': - cmake_generator = 'Visual Studio 16 2019' + cmake_generator = "Ninja" + activate_vs_tools = "2019" + elif self.args.compiler == "cmake_vs2019": + cmake_generator = "Visual Studio 16 2019" else: - print('should never reach here.') + print("should never reach here.") sys.exit(1) self._cmake_configure_extra_args = [] self._cmake_generator_windows = cmake_generator # required to pass as cmake "-A" configuration for VS builds (but not for Ninja) - self._cmake_architecture_windows = 'x64' if self.args.arch == 'x64' else 'Win32' + self._cmake_architecture_windows = ( + "x64" if self.args.arch == "x64" else "Win32" + ) # when builing with Ninja, the VS common tools need to be activated first self._activate_vs_tools_windows = activate_vs_tools # "x64_x86" means create 32bit binaries, but use 64bit toolkit to secure more memory for the build - self._vs_tools_architecture_windows = 'x64' if self.args.arch == 'x64' else 'x64_x86' + self._vs_tools_architecture_windows = ( + "x64" if self.args.arch == "x64" else "x64_x86" + ) else: - if self.platform == 'linux': + if self.platform == "linux": # Allow all the known architectures. _check_arch_option has already checked that we're not doing # something illegal when not running under docker. 
- _check_arch(self.args.arch, ['default', 'x64', 'x86', 'arm64']) + _check_arch(self.args.arch, ["default", "x64", "x86", "arm64"]) else: - _check_arch(self.args.arch, ['default']) + _check_arch(self.args.arch, ["default"]) - self._docker_distro, self._cmake_configure_extra_args = self._compiler_options( - self.args.use_docker, self.args.compiler) + ( + self._docker_distro, + self._cmake_configure_extra_args, + ) = self._compiler_options(self.args.use_docker, self.args.compiler) - if self.args.arch == 'x86': + if self.args.arch == "x86": # disable boringssl asm optimizations when on x86 # see https://github.com/grpc/grpc/blob/b5b8578b3f8b4a9ce61ed6677e19d546e43c5c68/tools/run_tests/artifacts/artifact_targets.py#L253 - self._cmake_configure_extra_args.append('-DOPENSSL_NO_ASM=ON') + self._cmake_configure_extra_args.append("-DOPENSSL_NO_ASM=ON") def test_specs(self): out = [] binaries = get_c_tests(self.args.travis, self.test_lang) for target in binaries: - if target.get('boringssl', False): + if target.get("boringssl", False): # cmake doesn't build boringssl tests continue - auto_timeout_scaling = target.get('auto_timeout_scaling', True) - polling_strategies = (_POLLING_STRATEGIES.get( - self.platform, ['all']) if target.get('uses_polling', True) else - ['none']) + auto_timeout_scaling = target.get("auto_timeout_scaling", True) + polling_strategies = ( + _POLLING_STRATEGIES.get(self.platform, ["all"]) + if target.get("uses_polling", True) + else ["none"] + ) for polling_strategy in polling_strategies: env = { - 'GRPC_DEFAULT_SSL_ROOTS_FILE_PATH': - _ROOT + '/src/core/tsi/test_creds/ca.pem', - 'GRPC_POLL_STRATEGY': - polling_strategy, - 'GRPC_VERBOSITY': - 'DEBUG' + "GRPC_DEFAULT_SSL_ROOTS_FILE_PATH": _ROOT + + "/src/core/tsi/test_creds/ca.pem", + "GRPC_POLL_STRATEGY": polling_strategy, + "GRPC_VERBOSITY": "DEBUG", } - resolver = os.environ.get('GRPC_DNS_RESOLVER', None) + resolver = os.environ.get("GRPC_DNS_RESOLVER", None) if resolver: - env['GRPC_DNS_RESOLVER'] = resolver - shortname_ext = '' if polling_strategy == 'all' else ' GRPC_POLL_STRATEGY=%s' % polling_strategy - if polling_strategy in target.get('excluded_poll_engines', []): + env["GRPC_DNS_RESOLVER"] = resolver + shortname_ext = ( + "" + if polling_strategy == "all" + else " GRPC_POLL_STRATEGY=%s" % polling_strategy + ) + if polling_strategy in target.get("excluded_poll_engines", []): continue timeout_scaling = 1 if auto_timeout_scaling: config = self.args.config - if ('asan' in config or config == 'msan' or - config == 'tsan' or config == 'ubsan' or - config == 'helgrind' or config == 'memcheck'): + if ( + "asan" in config + or config == "msan" + or config == "tsan" + or config == "ubsan" + or config == "helgrind" + or config == "memcheck" + ): # Scale overall test timeout if running under various sanitizers. 
# scaling value is based on historical data analysis timeout_scaling *= 3 - if self.config.build_config in target['exclude_configs']: + if self.config.build_config in target["exclude_configs"]: continue - if self.args.iomgr_platform in target.get('exclude_iomgrs', []): + if self.args.iomgr_platform in target.get("exclude_iomgrs", []): continue - if self.platform == 'windows': - if self._cmake_generator_windows == 'Ninja': - binary = 'cmake/build/%s.exe' % target['name'] + if self.platform == "windows": + if self._cmake_generator_windows == "Ninja": + binary = "cmake/build/%s.exe" % target["name"] else: - binary = 'cmake/build/%s/%s.exe' % (_MSBUILD_CONFIG[ - self.config.build_config], target['name']) + binary = "cmake/build/%s/%s.exe" % ( + _MSBUILD_CONFIG[self.config.build_config], + target["name"], + ) else: - binary = 'cmake/build/%s' % target['name'] + binary = "cmake/build/%s" % target["name"] - cpu_cost = target['cpu_cost'] - if cpu_cost == 'capacity': + cpu_cost = target["cpu_cost"] + if cpu_cost == "capacity": cpu_cost = multiprocessing.cpu_count() if os.path.isfile(binary): list_test_command = None @@ -355,119 +411,140 @@ def test_specs(self): # these are the flag defined by gtest and benchmark framework to list # and filter test runs. We use them to split each individual test # into its own JobSpec, and thus into its own process. - if 'benchmark' in target and target['benchmark']: - with open(os.devnull, 'w') as fnull: + if "benchmark" in target and target["benchmark"]: + with open(os.devnull, "w") as fnull: tests = subprocess.check_output( - [binary, '--benchmark_list_tests'], - stderr=fnull) - for line in tests.decode().split('\n'): + [binary, "--benchmark_list_tests"], stderr=fnull + ) + for line in tests.decode().split("\n"): test = line.strip() if not test: continue - cmdline = [binary, - '--benchmark_filter=%s$' % test - ] + target['args'] + cmdline = [ + binary, + "--benchmark_filter=%s$" % test, + ] + target["args"] out.append( self.config.job_spec( cmdline, - shortname='%s %s' % - (' '.join(cmdline), shortname_ext), + shortname="%s %s" + % (" ".join(cmdline), shortname_ext), cpu_cost=cpu_cost, timeout_seconds=target.get( - 'timeout_seconds', - _DEFAULT_TIMEOUT_SECONDS) * - timeout_scaling, - environ=env)) - elif 'gtest' in target and target['gtest']: + "timeout_seconds", + _DEFAULT_TIMEOUT_SECONDS, + ) + * timeout_scaling, + environ=env, + ) + ) + elif "gtest" in target and target["gtest"]: # here we parse the output of --gtest_list_tests to build up a complete # list of the tests contained in a binary for each test, we then # add a job to run, filtering for just that test. 
- with open(os.devnull, 'w') as fnull: + with open(os.devnull, "w") as fnull: tests = subprocess.check_output( - [binary, '--gtest_list_tests'], stderr=fnull) + [binary, "--gtest_list_tests"], stderr=fnull + ) base = None - for line in tests.decode().split('\n'): - i = line.find('#') + for line in tests.decode().split("\n"): + i = line.find("#") if i >= 0: line = line[:i] if not line: continue - if line[0] != ' ': + if line[0] != " ": base = line.strip() else: assert base is not None - assert line[1] == ' ' + assert line[1] == " " test = base + line.strip() - cmdline = [binary, - '--gtest_filter=%s' % test - ] + target['args'] + cmdline = [ + binary, + "--gtest_filter=%s" % test, + ] + target["args"] out.append( self.config.job_spec( cmdline, - shortname='%s %s' % - (' '.join(cmdline), shortname_ext), + shortname="%s %s" + % (" ".join(cmdline), shortname_ext), cpu_cost=cpu_cost, timeout_seconds=target.get( - 'timeout_seconds', - _DEFAULT_TIMEOUT_SECONDS) * - timeout_scaling, - environ=env)) + "timeout_seconds", + _DEFAULT_TIMEOUT_SECONDS, + ) + * timeout_scaling, + environ=env, + ) + ) else: - cmdline = [binary] + target['args'] + cmdline = [binary] + target["args"] shortname = target.get( - 'shortname', - ' '.join(pipes.quote(arg) for arg in cmdline)) + "shortname", + " ".join(pipes.quote(arg) for arg in cmdline), + ) shortname += shortname_ext out.append( self.config.job_spec( cmdline, shortname=shortname, cpu_cost=cpu_cost, - flaky=target.get('flaky', False), + flaky=target.get("flaky", False), timeout_seconds=target.get( - 'timeout_seconds', - _DEFAULT_TIMEOUT_SECONDS) * timeout_scaling, - environ=env)) - elif self.args.regex == '.*' or self.platform == 'windows': - print('\nWARNING: binary not found, skipping', binary) + "timeout_seconds", _DEFAULT_TIMEOUT_SECONDS + ) + * timeout_scaling, + environ=env, + ) + ) + elif self.args.regex == ".*" or self.platform == "windows": + print("\nWARNING: binary not found, skipping", binary) return sorted(out) def pre_build_steps(self): return [] def build_steps(self): - if self.platform == 'windows': - return [[ - 'tools\\run_tests\\helper_scripts\\build_cxx.bat', - '-DgRPC_BUILD_MSVC_MP_COUNT=%d' % self.args.jobs - ] + self._cmake_configure_extra_args] + if self.platform == "windows": + return [ + [ + "tools\\run_tests\\helper_scripts\\build_cxx.bat", + "-DgRPC_BUILD_MSVC_MP_COUNT=%d" % self.args.jobs, + ] + + self._cmake_configure_extra_args + ] else: - return [['tools/run_tests/helper_scripts/build_cxx.sh'] + - self._cmake_configure_extra_args] + return [ + ["tools/run_tests/helper_scripts/build_cxx.sh"] + + self._cmake_configure_extra_args + ] def build_steps_environ(self): """Extra environment variables set for pre_build_steps and build_steps jobs.""" - environ = {'GRPC_RUN_TESTS_CXX_LANGUAGE_SUFFIX': self.lang_suffix} - if self.platform == 'windows': - environ['GRPC_CMAKE_GENERATOR'] = self._cmake_generator_windows + environ = {"GRPC_RUN_TESTS_CXX_LANGUAGE_SUFFIX": self.lang_suffix} + if self.platform == "windows": + environ["GRPC_CMAKE_GENERATOR"] = self._cmake_generator_windows environ[ - 'GRPC_CMAKE_ARCHITECTURE'] = self._cmake_architecture_windows + "GRPC_CMAKE_ARCHITECTURE" + ] = self._cmake_architecture_windows environ[ - 'GRPC_BUILD_ACTIVATE_VS_TOOLS'] = self._activate_vs_tools_windows + "GRPC_BUILD_ACTIVATE_VS_TOOLS" + ] = self._activate_vs_tools_windows environ[ - 'GRPC_BUILD_VS_TOOLS_ARCHITECTURE'] = self._vs_tools_architecture_windows + "GRPC_BUILD_VS_TOOLS_ARCHITECTURE" + ] = self._vs_tools_architecture_windows return environ def 
post_tests_steps(self): - if self.platform == 'windows': + if self.platform == "windows": return [] else: - return [['tools/run_tests/helper_scripts/post_tests_c.sh']] + return [["tools/run_tests/helper_scripts/post_tests_c.sh"]] - def _clang_cmake_configure_extra_args(self, version_suffix=''): + def _clang_cmake_configure_extra_args(self, version_suffix=""): return [ - '-DCMAKE_C_COMPILER=clang%s' % version_suffix, - '-DCMAKE_CXX_COMPILER=clang++%s' % version_suffix, + "-DCMAKE_C_COMPILER=clang%s" % version_suffix, + "-DCMAKE_CXX_COMPILER=clang++%s" % version_suffix, ] def _compiler_options(self, use_docker, compiler): @@ -475,32 +552,37 @@ def _compiler_options(self, use_docker, compiler): if not use_docker and not _is_use_docker_child(): # if not running under docker, we cannot ensure the right compiler version will be used, # so we only allow the non-specific choices. - _check_compiler(compiler, ['default', 'cmake']) - - if compiler == 'default' or compiler == 'cmake': - return ('debian11', []) - elif compiler == 'gcc7': - return ('gcc_7', []) - elif compiler == 'gcc10.2': - return ('debian11', []) - elif compiler == 'gcc10.2_openssl102': - return ('debian11_openssl102', [ - "-DgRPC_SSL_PROVIDER=package", - ]) - elif compiler == 'gcc12': - return ('gcc_12', ["-DCMAKE_CXX_STANDARD=20"]) - elif compiler == 'gcc_musl': - return ('alpine', []) - elif compiler == 'clang6': - return ('clang_6', self._clang_cmake_configure_extra_args()) - elif compiler == 'clang15': - return ('clang_15', self._clang_cmake_configure_extra_args()) + _check_compiler(compiler, ["default", "cmake"]) + + if compiler == "default" or compiler == "cmake": + return ("debian11", []) + elif compiler == "gcc7": + return ("gcc_7", []) + elif compiler == "gcc10.2": + return ("debian11", []) + elif compiler == "gcc10.2_openssl102": + return ( + "debian11_openssl102", + [ + "-DgRPC_SSL_PROVIDER=package", + ], + ) + elif compiler == "gcc12": + return ("gcc_12", ["-DCMAKE_CXX_STANDARD=20"]) + elif compiler == "gcc_musl": + return ("alpine", []) + elif compiler == "clang6": + return ("clang_6", self._clang_cmake_configure_extra_args()) + elif compiler == "clang15": + return ("clang_15", self._clang_cmake_configure_extra_args()) else: - raise Exception('Compiler %s not supported.' % compiler) + raise Exception("Compiler %s not supported." 
% compiler) def dockerfile_dir(self): - return 'tools/dockerfile/test/cxx_%s_%s' % ( - self._docker_distro, _docker_arch_suffix(self.args.arch)) + return "tools/dockerfile/test/cxx_%s_%s" % ( + self._docker_distro, + _docker_arch_suffix(self.args.arch), + ) def __str__(self): return self.lang_suffix @@ -508,7 +590,6 @@ def __str__(self): # This tests Node on grpc/grpc-node and will become the standard for Node testing class RemoteNodeLanguage(object): - def __init__(self): self.platform = platform_string() @@ -517,35 +598,47 @@ def configure(self, config, args): self.args = args # Note: electron ABI only depends on major and minor version, so that's all # we should specify in the compiler argument - _check_compiler(self.args.compiler, [ - 'default', 'node0.12', 'node4', 'node5', 'node6', 'node7', 'node8', - 'electron1.3', 'electron1.6' - ]) - if self.args.compiler == 'default': - self.runtime = 'node' - self.node_version = '8' + _check_compiler( + self.args.compiler, + [ + "default", + "node0.12", + "node4", + "node5", + "node6", + "node7", + "node8", + "electron1.3", + "electron1.6", + ], + ) + if self.args.compiler == "default": + self.runtime = "node" + self.node_version = "8" else: - if self.args.compiler.startswith('electron'): - self.runtime = 'electron' + if self.args.compiler.startswith("electron"): + self.runtime = "electron" self.node_version = self.args.compiler[8:] else: - self.runtime = 'node' + self.runtime = "node" # Take off the word "node" self.node_version = self.args.compiler[4:] # TODO: update with Windows/electron scripts when available for grpc/grpc-node def test_specs(self): - if self.platform == 'windows': + if self.platform == "windows": return [ self.config.job_spec( - ['tools\\run_tests\\helper_scripts\\run_node.bat']) + ["tools\\run_tests\\helper_scripts\\run_node.bat"] + ) ] else: return [ self.config.job_spec( - ['tools/run_tests/helper_scripts/run_grpc-node.sh'], + ["tools/run_tests/helper_scripts/run_grpc-node.sh"], None, - environ=_FORCE_ENVIRON_FOR_WRAPPERS) + environ=_FORCE_ENVIRON_FOR_WRAPPERS, + ) ] def pre_build_steps(self): @@ -562,67 +655,70 @@ def post_tests_steps(self): return [] def dockerfile_dir(self): - return 'tools/dockerfile/test/node_jessie_%s' % _docker_arch_suffix( - self.args.arch) + return "tools/dockerfile/test/node_jessie_%s" % _docker_arch_suffix( + self.args.arch + ) def __str__(self): - return 'grpc-node' + return "grpc-node" class Php7Language(object): - def configure(self, config, args): self.config = config self.args = args - _check_compiler(self.args.compiler, ['default']) + _check_compiler(self.args.compiler, ["default"]) def test_specs(self): return [ - self.config.job_spec(['src/php/bin/run_tests.sh'], - environ=_FORCE_ENVIRON_FOR_WRAPPERS) + self.config.job_spec( + ["src/php/bin/run_tests.sh"], + environ=_FORCE_ENVIRON_FOR_WRAPPERS, + ) ] def pre_build_steps(self): return [] def build_steps(self): - return [['tools/run_tests/helper_scripts/build_php.sh']] + return [["tools/run_tests/helper_scripts/build_php.sh"]] def build_steps_environ(self): """Extra environment variables set for pre_build_steps and build_steps jobs.""" return {} def post_tests_steps(self): - return [['tools/run_tests/helper_scripts/post_tests_php.sh']] + return [["tools/run_tests/helper_scripts/post_tests_php.sh"]] def dockerfile_dir(self): - return 'tools/dockerfile/test/php7_debian11_%s' % _docker_arch_suffix( - self.args.arch) + return "tools/dockerfile/test/php7_debian11_%s" % _docker_arch_suffix( + self.args.arch + ) def __str__(self): - return 
'php7' + return "php7" class PythonConfig( - collections.namedtuple('PythonConfig', ['name', 'build', 'run'])): + collections.namedtuple("PythonConfig", ["name", "build", "run"]) +): """Tuple of commands (named s.t. 'what it says on the tin' applies)""" class PythonLanguage(object): - _TEST_SPECS_FILE = { - 'native': ['src/python/grpcio_tests/tests/tests.json'], - 'gevent': [ - 'src/python/grpcio_tests/tests/tests.json', - 'src/python/grpcio_tests/tests_gevent/tests.json', + "native": ["src/python/grpcio_tests/tests/tests.json"], + "gevent": [ + "src/python/grpcio_tests/tests/tests.json", + "src/python/grpcio_tests/tests_gevent/tests.json", ], - 'asyncio': ['src/python/grpcio_tests/tests_aio/tests.json'], + "asyncio": ["src/python/grpcio_tests/tests_aio/tests.json"], } _TEST_COMMAND = { - 'native': 'test_lite', - 'gevent': 'test_gevent', - 'asyncio': 'test_aio', + "native": "test_lite", + "gevent": "test_gevent", + "asyncio": "test_aio", } def configure(self, config, args): @@ -643,20 +739,25 @@ def test_specs(self): # TODO(https://github.com/grpc/grpc/issues/21401) Fork handlers is not # designed for non-native IO manager. It has a side-effect that # overrides threading settings in C-Core. - if io_platform != 'native': - environment['GRPC_ENABLE_FORK_SUPPORT'] = '0' + if io_platform != "native": + environment["GRPC_ENABLE_FORK_SUPPORT"] = "0" for python_config in self.pythons: - jobs.extend([ - self.config.job_spec( - python_config.run + [self._TEST_COMMAND[io_platform]], - timeout_seconds=8 * 60, - environ=dict( - GRPC_PYTHON_TESTRUNNER_FILTER=str(test_case), - **environment), - shortname='%s.%s.%s' % - (python_config.name, io_platform, test_case), - ) for test_case in test_cases - ]) + jobs.extend( + [ + self.config.job_spec( + python_config.run + + [self._TEST_COMMAND[io_platform]], + timeout_seconds=8 * 60, + environ=dict( + GRPC_PYTHON_TESTRUNNER_FILTER=str(test_case), + **environment, + ), + shortname="%s.%s.%s" + % (python_config.name, io_platform, test_case), + ) + for test_case in test_cases + ] + ) return jobs def pre_build_steps(self): @@ -670,97 +771,114 @@ def build_steps_environ(self): return {} def post_tests_steps(self): - if self.config.build_config != 'gcov': + if self.config.build_config != "gcov": return [] else: - return [['tools/run_tests/helper_scripts/post_tests_python.sh']] + return [["tools/run_tests/helper_scripts/post_tests_python.sh"]] def dockerfile_dir(self): - return 'tools/dockerfile/test/python_%s_%s' % ( + return "tools/dockerfile/test/python_%s_%s" % ( self._python_docker_distro_name(), - _docker_arch_suffix(self.args.arch)) + _docker_arch_suffix(self.args.arch), + ) def _python_docker_distro_name(self): """Choose the docker image to use based on python version.""" - if self.args.compiler == 'python_alpine': - return 'alpine' + if self.args.compiler == "python_alpine": + return "alpine" else: - return 'debian11_default' + return "debian11_default" def _get_pythons(self, args): """Get python runtimes to test with, based on current platform, architecture, compiler etc.""" - if args.iomgr_platform != 'native': + if args.iomgr_platform != "native": raise ValueError( - 'Python builds no longer differentiate IO Manager platforms, please use "native"' + "Python builds no longer differentiate IO Manager platforms," + ' please use "native"' ) - if args.arch == 'x86': - bits = '32' + if args.arch == "x86": + bits = "32" else: - bits = '64' + bits = "64" - if os.name == 'nt': - shell = ['bash'] + if os.name == "nt": + shell = ["bash"] builder = [ os.path.abspath( 
- 'tools/run_tests/helper_scripts/build_python_msys2.sh') + "tools/run_tests/helper_scripts/build_python_msys2.sh" + ) ] - builder_prefix_arguments = ['MINGW{}'.format(bits)] - venv_relative_python = ['Scripts/python.exe'] - toolchain = ['mingw32'] + builder_prefix_arguments = ["MINGW{}".format(bits)] + venv_relative_python = ["Scripts/python.exe"] + toolchain = ["mingw32"] else: shell = [] builder = [ os.path.abspath( - 'tools/run_tests/helper_scripts/build_python.sh') + "tools/run_tests/helper_scripts/build_python.sh" + ) ] builder_prefix_arguments = [] - venv_relative_python = ['bin/python'] - toolchain = ['unix'] + venv_relative_python = ["bin/python"] + toolchain = ["unix"] runner = [ - os.path.abspath('tools/run_tests/helper_scripts/run_python.sh') + os.path.abspath("tools/run_tests/helper_scripts/run_python.sh") ] - config_vars = _PythonConfigVars(shell, builder, - builder_prefix_arguments, - venv_relative_python, toolchain, runner) - python37_config = _python_config_generator(name='py37', - major='3', - minor='7', - bits=bits, - config_vars=config_vars) - python38_config = _python_config_generator(name='py38', - major='3', - minor='8', - bits=bits, - config_vars=config_vars) - python39_config = _python_config_generator(name='py39', - major='3', - minor='9', - bits=bits, - config_vars=config_vars) - python310_config = _python_config_generator(name='py310', - major='3', - minor='10', - bits=bits, - config_vars=config_vars) - pypy27_config = _pypy_config_generator(name='pypy', - major='2', - config_vars=config_vars) - pypy32_config = _pypy_config_generator(name='pypy3', - major='3', - config_vars=config_vars) - - if args.compiler == 'default': - if os.name == 'nt': + config_vars = _PythonConfigVars( + shell, + builder, + builder_prefix_arguments, + venv_relative_python, + toolchain, + runner, + ) + python37_config = _python_config_generator( + name="py37", + major="3", + minor="7", + bits=bits, + config_vars=config_vars, + ) + python38_config = _python_config_generator( + name="py38", + major="3", + minor="8", + bits=bits, + config_vars=config_vars, + ) + python39_config = _python_config_generator( + name="py39", + major="3", + minor="9", + bits=bits, + config_vars=config_vars, + ) + python310_config = _python_config_generator( + name="py310", + major="3", + minor="10", + bits=bits, + config_vars=config_vars, + ) + pypy27_config = _pypy_config_generator( + name="pypy", major="2", config_vars=config_vars + ) + pypy32_config = _pypy_config_generator( + name="pypy3", major="3", config_vars=config_vars + ) + + if args.compiler == "default": + if os.name == "nt": return (python38_config,) - elif os.uname()[0] == 'Darwin': + elif os.uname()[0] == "Darwin": # NOTE(rbellevi): Testing takes significantly longer on # MacOS, so we restrict the number of interpreter versions # tested. 
return (python38_config,) - elif platform.machine() == 'aarch64': + elif platform.machine() == "aarch64": # Currently the python_debian11_default_arm64 docker image # only has python3.9 installed (and that seems sufficient # for arm64 testing) @@ -770,21 +888,21 @@ def _get_pythons(self, args): python37_config, python38_config, ) - elif args.compiler == 'python3.7': + elif args.compiler == "python3.7": return (python37_config,) - elif args.compiler == 'python3.8': + elif args.compiler == "python3.8": return (python38_config,) - elif args.compiler == 'python3.9': + elif args.compiler == "python3.9": return (python39_config,) - elif args.compiler == 'python3.10': + elif args.compiler == "python3.10": return (python310_config,) - elif args.compiler == 'pypy': + elif args.compiler == "pypy": return (pypy27_config,) - elif args.compiler == 'pypy3': + elif args.compiler == "pypy3": return (pypy32_config,) - elif args.compiler == 'python_alpine': + elif args.compiler == "python_alpine": return (python39_config,) - elif args.compiler == 'all_the_cpythons': + elif args.compiler == "all_the_cpythons": return ( python37_config, python38_config, @@ -792,24 +910,25 @@ def _get_pythons(self, args): python310_config, ) else: - raise Exception('Compiler %s not supported.' % args.compiler) + raise Exception("Compiler %s not supported." % args.compiler) def __str__(self): - return 'python' + return "python" class RubyLanguage(object): - def configure(self, config, args): self.config = config self.args = args - _check_compiler(self.args.compiler, ['default']) + _check_compiler(self.args.compiler, ["default"]) def test_specs(self): tests = [ - self.config.job_spec(['tools/run_tests/helper_scripts/run_ruby.sh'], - timeout_seconds=10 * 60, - environ=_FORCE_ENVIRON_FOR_WRAPPERS) + self.config.job_spec( + ["tools/run_tests/helper_scripts/run_ruby.sh"], + timeout_seconds=10 * 60, + environ=_FORCE_ENVIRON_FOR_WRAPPERS, + ) ] # TODO(apolcyn): re-enable the following tests after # https://bugs.ruby-lang.org/issues/15499 is fixed: @@ -823,176 +942,193 @@ def test_specs(self): # b/266212253. 
# - src/ruby/end2end/grpc_class_init_test.rb for test in [ - 'src/ruby/end2end/sig_handling_test.rb', - 'src/ruby/end2end/channel_closing_test.rb', - 'src/ruby/end2end/killed_client_thread_test.rb', - 'src/ruby/end2end/forking_client_test.rb', - 'src/ruby/end2end/multiple_killed_watching_threads_test.rb', - 'src/ruby/end2end/load_grpc_with_gc_stress_test.rb', - 'src/ruby/end2end/client_memory_usage_test.rb', - 'src/ruby/end2end/package_with_underscore_test.rb', - 'src/ruby/end2end/graceful_sig_handling_test.rb', - 'src/ruby/end2end/graceful_sig_stop_test.rb', - 'src/ruby/end2end/errors_load_before_grpc_lib_test.rb', - 'src/ruby/end2end/logger_load_before_grpc_lib_test.rb', - 'src/ruby/end2end/status_codes_load_before_grpc_lib_test.rb', - 'src/ruby/end2end/call_credentials_timeout_test.rb', - 'src/ruby/end2end/call_credentials_returning_bad_metadata_doesnt_kill_background_thread_test.rb' + "src/ruby/end2end/sig_handling_test.rb", + "src/ruby/end2end/channel_closing_test.rb", + "src/ruby/end2end/killed_client_thread_test.rb", + "src/ruby/end2end/forking_client_test.rb", + "src/ruby/end2end/multiple_killed_watching_threads_test.rb", + "src/ruby/end2end/load_grpc_with_gc_stress_test.rb", + "src/ruby/end2end/client_memory_usage_test.rb", + "src/ruby/end2end/package_with_underscore_test.rb", + "src/ruby/end2end/graceful_sig_handling_test.rb", + "src/ruby/end2end/graceful_sig_stop_test.rb", + "src/ruby/end2end/errors_load_before_grpc_lib_test.rb", + "src/ruby/end2end/logger_load_before_grpc_lib_test.rb", + "src/ruby/end2end/status_codes_load_before_grpc_lib_test.rb", + "src/ruby/end2end/call_credentials_timeout_test.rb", + "src/ruby/end2end/call_credentials_returning_bad_metadata_doesnt_kill_background_thread_test.rb", ]: tests.append( - self.config.job_spec(['ruby', test], - shortname=test, - timeout_seconds=20 * 60, - environ=_FORCE_ENVIRON_FOR_WRAPPERS)) + self.config.job_spec( + ["ruby", test], + shortname=test, + timeout_seconds=20 * 60, + environ=_FORCE_ENVIRON_FOR_WRAPPERS, + ) + ) return tests def pre_build_steps(self): - return [['tools/run_tests/helper_scripts/pre_build_ruby.sh']] + return [["tools/run_tests/helper_scripts/pre_build_ruby.sh"]] def build_steps(self): - return [['tools/run_tests/helper_scripts/build_ruby.sh']] + return [["tools/run_tests/helper_scripts/build_ruby.sh"]] def build_steps_environ(self): """Extra environment variables set for pre_build_steps and build_steps jobs.""" return {} def post_tests_steps(self): - return [['tools/run_tests/helper_scripts/post_tests_ruby.sh']] + return [["tools/run_tests/helper_scripts/post_tests_ruby.sh"]] def dockerfile_dir(self): - return 'tools/dockerfile/test/ruby_debian11_%s' % _docker_arch_suffix( - self.args.arch) + return "tools/dockerfile/test/ruby_debian11_%s" % _docker_arch_suffix( + self.args.arch + ) def __str__(self): - return 'ruby' + return "ruby" class CSharpLanguage(object): - def __init__(self): self.platform = platform_string() def configure(self, config, args): self.config = config self.args = args - _check_compiler(self.args.compiler, ['default', 'coreclr', 'mono']) - if self.args.compiler == 'default': + _check_compiler(self.args.compiler, ["default", "coreclr", "mono"]) + if self.args.compiler == "default": # test both runtimes by default - self.test_runtimes = ['coreclr', 'mono'] + self.test_runtimes = ["coreclr", "mono"] else: # only test the specified runtime self.test_runtimes = [self.args.compiler] - if self.platform == 'windows': - _check_arch(self.args.arch, ['default']) - self._cmake_arch_option = 'x64' 
+ if self.platform == "windows": + _check_arch(self.args.arch, ["default"]) + self._cmake_arch_option = "x64" else: - self._docker_distro = 'debian11' + self._docker_distro = "debian11" def test_specs(self): - with open('src/csharp/tests.json') as f: + with open("src/csharp/tests.json") as f: tests_by_assembly = json.load(f) msbuild_config = _MSBUILD_CONFIG[self.config.build_config] - nunit_args = ['--labels=All', '--noresult', '--workers=1'] + nunit_args = ["--labels=All", "--noresult", "--workers=1"] specs = [] for test_runtime in self.test_runtimes: - if test_runtime == 'coreclr': - assembly_extension = '.dll' - assembly_subdir = 'bin/%s/netcoreapp3.1' % msbuild_config - runtime_cmd = ['dotnet', 'exec'] - elif test_runtime == 'mono': - assembly_extension = '.exe' - assembly_subdir = 'bin/%s/net45' % msbuild_config - if self.platform == 'windows': + if test_runtime == "coreclr": + assembly_extension = ".dll" + assembly_subdir = "bin/%s/netcoreapp3.1" % msbuild_config + runtime_cmd = ["dotnet", "exec"] + elif test_runtime == "mono": + assembly_extension = ".exe" + assembly_subdir = "bin/%s/net45" % msbuild_config + if self.platform == "windows": runtime_cmd = [] - elif self.platform == 'mac': + elif self.platform == "mac": # mono before version 5.2 on MacOS defaults to 32bit runtime - runtime_cmd = ['mono', '--arch=64'] + runtime_cmd = ["mono", "--arch=64"] else: - runtime_cmd = ['mono'] + runtime_cmd = ["mono"] else: raise Exception('Illegal runtime "%s" was specified.') for assembly in six.iterkeys(tests_by_assembly): - assembly_file = 'src/csharp/%s/%s/%s%s' % ( - assembly, assembly_subdir, assembly, assembly_extension) + assembly_file = "src/csharp/%s/%s/%s%s" % ( + assembly, + assembly_subdir, + assembly, + assembly_extension, + ) # normally, run each test as a separate process for test in tests_by_assembly[assembly]: - cmdline = runtime_cmd + [assembly_file, - '--test=%s' % test] + nunit_args + cmdline = ( + runtime_cmd + + [assembly_file, "--test=%s" % test] + + nunit_args + ) specs.append( self.config.job_spec( cmdline, - shortname='csharp.%s.%s' % (test_runtime, test), - environ=_FORCE_ENVIRON_FOR_WRAPPERS)) + shortname="csharp.%s.%s" % (test_runtime, test), + environ=_FORCE_ENVIRON_FOR_WRAPPERS, + ) + ) return specs def pre_build_steps(self): - if self.platform == 'windows': - return [['tools\\run_tests\\helper_scripts\\pre_build_csharp.bat']] + if self.platform == "windows": + return [["tools\\run_tests\\helper_scripts\\pre_build_csharp.bat"]] else: - return [['tools/run_tests/helper_scripts/pre_build_csharp.sh']] + return [["tools/run_tests/helper_scripts/pre_build_csharp.sh"]] def build_steps(self): - if self.platform == 'windows': - return [['tools\\run_tests\\helper_scripts\\build_csharp.bat']] + if self.platform == "windows": + return [["tools\\run_tests\\helper_scripts\\build_csharp.bat"]] else: - return [['tools/run_tests/helper_scripts/build_csharp.sh']] + return [["tools/run_tests/helper_scripts/build_csharp.sh"]] def build_steps_environ(self): """Extra environment variables set for pre_build_steps and build_steps jobs.""" - if self.platform == 'windows': - return {'ARCHITECTURE': self._cmake_arch_option} + if self.platform == "windows": + return {"ARCHITECTURE": self._cmake_arch_option} else: return {} def post_tests_steps(self): - if self.platform == 'windows': - return [['tools\\run_tests\\helper_scripts\\post_tests_csharp.bat']] + if self.platform == "windows": + return [["tools\\run_tests\\helper_scripts\\post_tests_csharp.bat"]] else: - return 
[['tools/run_tests/helper_scripts/post_tests_csharp.sh']] + return [["tools/run_tests/helper_scripts/post_tests_csharp.sh"]] def dockerfile_dir(self): - return 'tools/dockerfile/test/csharp_%s_%s' % ( - self._docker_distro, _docker_arch_suffix(self.args.arch)) + return "tools/dockerfile/test/csharp_%s_%s" % ( + self._docker_distro, + _docker_arch_suffix(self.args.arch), + ) def __str__(self): - return 'csharp' + return "csharp" class ObjCLanguage(object): - def configure(self, config, args): self.config = config self.args = args - _check_compiler(self.args.compiler, ['default']) + _check_compiler(self.args.compiler, ["default"]) def test_specs(self): out = [] out.append( self.config.job_spec( - ['src/objective-c/tests/build_one_example.sh'], + ["src/objective-c/tests/build_one_example.sh"], timeout_seconds=20 * 60, - shortname='ios-buildtest-example-sample', + shortname="ios-buildtest-example-sample", cpu_cost=1e6, environ={ - 'SCHEME': 'Sample', - 'EXAMPLE_PATH': 'src/objective-c/examples/Sample', - })) + "SCHEME": "Sample", + "EXAMPLE_PATH": "src/objective-c/examples/Sample", + }, + ) + ) # TODO(jtattermusch): Create bazel target for the sample and remove the test task from here. out.append( self.config.job_spec( - ['src/objective-c/tests/build_one_example.sh'], + ["src/objective-c/tests/build_one_example.sh"], timeout_seconds=20 * 60, - shortname='ios-buildtest-example-switftsample', + shortname="ios-buildtest-example-switftsample", cpu_cost=1e6, environ={ - 'SCHEME': 'SwiftSample', - 'EXAMPLE_PATH': 'src/objective-c/examples/SwiftSample' - })) + "SCHEME": "SwiftSample", + "EXAMPLE_PATH": "src/objective-c/examples/SwiftSample", + }, + ) + ) # Disabled due to #20258 # TODO (mxyan): Reenable this test when #20258 is resolved. # out.append( @@ -1011,11 +1147,13 @@ def test_specs(self): # How does one add the cfstream dependency in bazel? out.append( self.config.job_spec( - ['test/core/iomgr/ios/CFStreamTests/build_and_run_tests.sh'], + ["test/core/iomgr/ios/CFStreamTests/build_and_run_tests.sh"], timeout_seconds=60 * 60, - shortname='ios-test-cfstream-tests', + shortname="ios-test-cfstream-tests", cpu_cost=1e6, - environ=_FORCE_ENVIRON_FOR_WRAPPERS)) + environ=_FORCE_ENVIRON_FOR_WRAPPERS, + ) + ) return sorted(out) def pre_build_steps(self): @@ -1035,37 +1173,39 @@ def dockerfile_dir(self): return None def __str__(self): - return 'objc' + return "objc" class Sanity(object): - def __init__(self, config_file): self.config_file = config_file def configure(self, config, args): self.config = config self.args = args - _check_compiler(self.args.compiler, ['default']) + _check_compiler(self.args.compiler, ["default"]) def test_specs(self): import yaml - with open('tools/run_tests/sanity/%s' % self.config_file, 'r') as f: - environ = {'TEST': 'true'} + + with open("tools/run_tests/sanity/%s" % self.config_file, "r") as f: + environ = {"TEST": "true"} if _is_use_docker_child(): - environ['CLANG_FORMAT_SKIP_DOCKER'] = 'true' - environ['CLANG_TIDY_SKIP_DOCKER'] = 'true' - environ['IWYU_SKIP_DOCKER'] = 'true' + environ["CLANG_FORMAT_SKIP_DOCKER"] = "true" + environ["CLANG_TIDY_SKIP_DOCKER"] = "true" + environ["IWYU_SKIP_DOCKER"] = "true" # sanity tests run tools/bazel wrapper concurrently # and that can result in a download/run race in the wrapper. # under docker we already have the right version of bazel # so we can just disable the wrapper. 
- environ['DISABLE_BAZEL_WRAPPER'] = 'true' + environ["DISABLE_BAZEL_WRAPPER"] = "true" return [ - self.config.job_spec(cmd['script'].split(), - timeout_seconds=45 * 60, - environ=environ, - cpu_cost=cmd.get('cpu_cost', 1)) + self.config.job_spec( + cmd["script"].split(), + timeout_seconds=45 * 60, + environ=environ, + cpu_cost=cmd.get("cpu_cost", 1), + ) for cmd in yaml.safe_load(f) ] @@ -1083,111 +1223,126 @@ def post_tests_steps(self): return [] def dockerfile_dir(self): - return 'tools/dockerfile/test/sanity' + return "tools/dockerfile/test/sanity" def __str__(self): - return 'sanity' + return "sanity" # different configurations we can run under -with open('tools/run_tests/generated/configs.json') as f: +with open("tools/run_tests/generated/configs.json") as f: _CONFIGS = dict( - (cfg['config'], Config(**cfg)) for cfg in ast.literal_eval(f.read())) + (cfg["config"], Config(**cfg)) for cfg in ast.literal_eval(f.read()) + ) _LANGUAGES = { - 'c++': CLanguage('cxx', 'c++'), - 'c': CLanguage('c', 'c'), - 'grpc-node': RemoteNodeLanguage(), - 'php7': Php7Language(), - 'python': PythonLanguage(), - 'ruby': RubyLanguage(), - 'csharp': CSharpLanguage(), - 'objc': ObjCLanguage(), - 'sanity': Sanity('sanity_tests.yaml'), - 'clang-tidy': Sanity('clang_tidy_tests.yaml'), - 'iwyu': Sanity('iwyu_tests.yaml'), + "c++": CLanguage("cxx", "c++"), + "c": CLanguage("c", "c"), + "grpc-node": RemoteNodeLanguage(), + "php7": Php7Language(), + "python": PythonLanguage(), + "ruby": RubyLanguage(), + "csharp": CSharpLanguage(), + "objc": ObjCLanguage(), + "sanity": Sanity("sanity_tests.yaml"), + "clang-tidy": Sanity("clang_tidy_tests.yaml"), + "iwyu": Sanity("iwyu_tests.yaml"), } _MSBUILD_CONFIG = { - 'dbg': 'Debug', - 'opt': 'Release', - 'gcov': 'Debug', + "dbg": "Debug", + "opt": "Release", + "gcov": "Debug", } def _build_step_environ(cfg, extra_env={}): """Environment variables set for each build step.""" - environ = {'CONFIG': cfg, 'GRPC_RUN_TESTS_JOBS': str(args.jobs)} + environ = {"CONFIG": cfg, "GRPC_RUN_TESTS_JOBS": str(args.jobs)} msbuild_cfg = _MSBUILD_CONFIG.get(cfg) if msbuild_cfg: - environ['MSBUILD_CONFIG'] = msbuild_cfg + environ["MSBUILD_CONFIG"] = msbuild_cfg environ.update(extra_env) return environ def _windows_arch_option(arch): """Returns msbuild cmdline option for selected architecture.""" - if arch == 'default' or arch == 'x86': - return '/p:Platform=Win32' - elif arch == 'x64': - return '/p:Platform=x64' + if arch == "default" or arch == "x86": + return "/p:Platform=Win32" + elif arch == "x64": + return "/p:Platform=x64" else: - print('Architecture %s not supported.' % arch) + print("Architecture %s not supported." % arch) sys.exit(1) def _check_arch_option(arch): """Checks that architecture option is valid.""" - if platform_string() == 'windows': + if platform_string() == "windows": _windows_arch_option(arch) - elif platform_string() == 'linux': + elif platform_string() == "linux": # On linux, we need to be running under docker with the right architecture. 
runtime_machine = platform.machine() runtime_arch = platform.architecture()[0] - if arch == 'default': + if arch == "default": return - elif runtime_machine == 'x86_64' and runtime_arch == '64bit' and arch == 'x64': + elif ( + runtime_machine == "x86_64" + and runtime_arch == "64bit" + and arch == "x64" + ): return - elif runtime_machine == 'x86_64' and runtime_arch == '32bit' and arch == 'x86': + elif ( + runtime_machine == "x86_64" + and runtime_arch == "32bit" + and arch == "x86" + ): return - elif runtime_machine == 'aarch64' and runtime_arch == '64bit' and arch == 'arm64': + elif ( + runtime_machine == "aarch64" + and runtime_arch == "64bit" + and arch == "arm64" + ): return else: print( - 'Architecture %s does not match current runtime architecture.' % - arch) + "Architecture %s does not match current runtime architecture." + % arch + ) sys.exit(1) else: - if args.arch != 'default': - print('Architecture %s not supported on current platform.' % - args.arch) + if args.arch != "default": + print( + "Architecture %s not supported on current platform." % args.arch + ) sys.exit(1) def _docker_arch_suffix(arch): """Returns suffix to dockerfile dir to use.""" - if arch == 'default' or arch == 'x64': - return 'x64' - elif arch == 'x86': - return 'x86' - elif arch == 'arm64': - return 'arm64' + if arch == "default" or arch == "x64": + return "x64" + elif arch == "x86": + return "x86" + elif arch == "arm64": + return "arm64" else: - print('Architecture %s not supported with current settings.' % arch) + print("Architecture %s not supported with current settings." % arch) sys.exit(1) def runs_per_test_type(arg_str): """Auxiliary function to parse the "runs_per_test" flag. - Returns: - A positive integer or 0, the latter indicating an infinite number of - runs. + Returns: + A positive integer or 0, the latter indicating an infinite number of + runs. - Raises: - argparse.ArgumentTypeError: Upon invalid input. + Raises: + argparse.ArgumentTypeError: Upon invalid input. """ - if arg_str == 'inf': + if arg_str == "inf": return 0 try: n = int(arg_str) @@ -1195,7 +1350,7 @@ def runs_per_test_type(arg_str): raise ValueError return n except: - msg = '\'{}\' is not a positive integer or \'inf\''.format(arg_str) + msg = "'{}' is not a positive integer or 'inf'".format(arg_str) raise argparse.ArgumentTypeError(msg) @@ -1203,7 +1358,8 @@ def percent_type(arg_str): pct = float(arg_str) if pct > 100 or pct < 0: raise argparse.ArgumentTypeError( - "'%f' is not a valid percentage in the [0, 100] range" % pct) + "'%f' is not a valid percentage in the [0, 100] range" % pct + ) return pct @@ -1216,24 +1372,27 @@ def _shut_down_legacy_server(legacy_server_port): """Shut down legacy version of port server.""" try: version = int( - urllib.request.urlopen('http://localhost:%d/version_number' % - legacy_server_port, - timeout=10).read()) + urllib.request.urlopen( + "http://localhost:%d/version_number" % legacy_server_port, + timeout=10, + ).read() + ) except: pass else: - urllib.request.urlopen('http://localhost:%d/quitquitquit' % - legacy_server_port).read() + urllib.request.urlopen( + "http://localhost:%d/quitquitquit" % legacy_server_port + ).read() def _calculate_num_runs_failures(list_of_results): """Calculate number of runs and failures for a particular test. - Args: - list_of_results: (List) of JobResult object. - Returns: - A tuple of total number of runs and failures. - """ + Args: + list_of_results: (List) of JobResult object. + Returns: + A tuple of total number of runs and failures. 
+ """ num_runs = len(list_of_results) # By default, there is 1 run per JobResult. num_failures = 0 for jobresult in list_of_results: @@ -1253,29 +1412,31 @@ class BuildAndRunError(object): # returns a list of things that failed (or an empty list on success) -def _build_and_run(check_cancelled, - newline_on_success, - xml_report=None, - build_only=False): +def _build_and_run( + check_cancelled, newline_on_success, xml_report=None, build_only=False +): """Do one pass of building & running tests.""" # build latest sequentially - num_failures, resultset = jobset.run(build_steps, - maxjobs=1, - stop_on_failure=True, - newline_on_success=newline_on_success, - travis=args.travis) + num_failures, resultset = jobset.run( + build_steps, + maxjobs=1, + stop_on_failure=True, + newline_on_success=newline_on_success, + travis=args.travis, + ) if num_failures: return [BuildAndRunError.BUILD] if build_only: if xml_report: report_utils.render_junit_xml_report( - resultset, xml_report, suite_name=args.report_suite_name) + resultset, xml_report, suite_name=args.report_suite_name + ) return [] # start antagonists antagonists = [ - subprocess.Popen(['tools/run_tests/python_utils/antagonist.py']) + subprocess.Popen(["tools/run_tests/python_utils/antagonist.py"]) for _ in range(0, args.antagonists) ] start_port_server.start_port_server() @@ -1283,11 +1444,18 @@ def _build_and_run(check_cancelled, num_test_failures = 0 try: infinite_runs = runs_per_test == 0 - one_run = set(spec for language in languages - for spec in language.test_specs() - if (re.search(args.regex, spec.shortname) and - (args.regex_exclude == '' or - not re.search(args.regex_exclude, spec.shortname)))) + one_run = set( + spec + for language in languages + for spec in language.test_specs() + if ( + re.search(args.regex, spec.shortname) + and ( + args.regex_exclude == "" + or not re.search(args.regex_exclude, spec.shortname) + ) + ) + ) # When running on travis, we want out test runs to be as similar as possible # for reproducibility purposes. if args.travis and args.max_time <= 0: @@ -1296,7 +1464,8 @@ def _build_and_run(check_cancelled, # whereas otherwise, we want to shuffle things up to give all tests a # chance to run. massaged_one_run = list( - one_run) # random.sample needs an indexable seq. + one_run + ) # random.sample needs an indexable seq. num_jobs = len(massaged_one_run) # for a random sample, get as many as indicated by the 'sample_percent' # argument. By default this arg is 100, resulting in a shuffle of all @@ -1304,21 +1473,30 @@ def _build_and_run(check_cancelled, sample_size = int(num_jobs * args.sample_percent / 100.0) massaged_one_run = random.sample(massaged_one_run, sample_size) if not isclose(args.sample_percent, 100.0): - assert args.runs_per_test == 1, "Can't do sampling (-p) over multiple runs (-n)." - print("Running %d tests out of %d (~%d%%)" % - (sample_size, num_jobs, args.sample_percent)) + assert ( + args.runs_per_test == 1 + ), "Can't do sampling (-p) over multiple runs (-n)." 
+ print( + "Running %d tests out of %d (~%d%%)" + % (sample_size, num_jobs, args.sample_percent) + ) if infinite_runs: - assert len(massaged_one_run - ) > 0, 'Must have at least one test for a -n inf run' - runs_sequence = (itertools.repeat(massaged_one_run) if infinite_runs - else itertools.repeat(massaged_one_run, runs_per_test)) + assert ( + len(massaged_one_run) > 0 + ), "Must have at least one test for a -n inf run" + runs_sequence = ( + itertools.repeat(massaged_one_run) + if infinite_runs + else itertools.repeat(massaged_one_run, runs_per_test) + ) all_runs = itertools.chain.from_iterable(runs_sequence) if args.quiet_success: jobset.message( - 'START', - 'Running tests quietly, only failing tests will be reported', - do_newline=True) + "START", + "Running tests quietly, only failing tests will be reported", + do_newline=True, + ) num_test_failures, resultset = jobset.run( all_runs, check_cancelled, @@ -1328,49 +1506,57 @@ def _build_and_run(check_cancelled, maxjobs_cpu_agnostic=max_parallel_tests_for_current_platform(), stop_on_failure=args.stop_on_failure, quiet_success=args.quiet_success, - max_time=args.max_time) + max_time=args.max_time, + ) if resultset: for k, v in sorted(resultset.items()): num_runs, num_failures = _calculate_num_runs_failures(v) if num_failures > 0: if num_failures == num_runs: # what about infinite_runs??? - jobset.message('FAILED', k, do_newline=True) + jobset.message("FAILED", k, do_newline=True) else: - jobset.message('FLAKE', - '%s [%d/%d runs flaked]' % - (k, num_failures, num_runs), - do_newline=True) + jobset.message( + "FLAKE", + "%s [%d/%d runs flaked]" + % (k, num_failures, num_runs), + do_newline=True, + ) finally: for antagonist in antagonists: antagonist.kill() if args.bq_result_table and resultset: upload_extra_fields = { - 'compiler': args.compiler, - 'config': args.config, - 'iomgr_platform': args.iomgr_platform, - 'language': args.language[ + "compiler": args.compiler, + "config": args.config, + "iomgr_platform": args.iomgr_platform, + "language": args.language[ 0 ], # args.language is a list but will always have one element when uploading to BQ is enabled. 
- 'platform': platform_string() + "platform": platform_string(), } try: - upload_results_to_bq(resultset, args.bq_result_table, - upload_extra_fields) + upload_results_to_bq( + resultset, args.bq_result_table, upload_extra_fields + ) except NameError as e: logging.warning( - e) # It's fine to ignore since this is not critical + e + ) # It's fine to ignore since this is not critical if xml_report and resultset: report_utils.render_junit_xml_report( resultset, xml_report, suite_name=args.report_suite_name, - multi_target=args.report_multi_target) + multi_target=args.report_multi_target, + ) - number_failures, _ = jobset.run(post_tests_steps, - maxjobs=1, - stop_on_failure=False, - newline_on_success=newline_on_success, - travis=args.travis) + number_failures, _ = jobset.run( + post_tests_steps, + maxjobs=1, + stop_on_failure=False, + newline_on_success=newline_on_success, + travis=args.travis, + ) out = [] if number_failures: @@ -1382,164 +1568,197 @@ def _build_and_run(check_cancelled, # parse command line -argp = argparse.ArgumentParser(description='Run grpc tests.') -argp.add_argument('-c', - '--config', - choices=sorted(_CONFIGS.keys()), - default='opt') +argp = argparse.ArgumentParser(description="Run grpc tests.") +argp.add_argument( + "-c", "--config", choices=sorted(_CONFIGS.keys()), default="opt" +) argp.add_argument( - '-n', - '--runs_per_test', + "-n", + "--runs_per_test", default=1, type=runs_per_test_type, - help='A positive integer or "inf". If "inf", all tests will run in an ' - 'infinite loop. Especially useful in combination with "-f"') -argp.add_argument('-r', '--regex', default='.*', type=str) -argp.add_argument('--regex_exclude', default='', type=str) -argp.add_argument('-j', '--jobs', default=multiprocessing.cpu_count(), type=int) -argp.add_argument('-s', '--slowdown', default=1.0, type=float) -argp.add_argument('-p', - '--sample_percent', - default=100.0, - type=percent_type, - help='Run a random sample with that percentage of tests') + help=( + 'A positive integer or "inf". If "inf", all tests will run in an ' + 'infinite loop. Especially useful in combination with "-f"' + ), +) +argp.add_argument("-r", "--regex", default=".*", type=str) +argp.add_argument("--regex_exclude", default="", type=str) +argp.add_argument("-j", "--jobs", default=multiprocessing.cpu_count(), type=int) +argp.add_argument("-s", "--slowdown", default=1.0, type=float) argp.add_argument( - '-t', - '--travis', + "-p", + "--sample_percent", + default=100.0, + type=percent_type, + help="Run a random sample with that percentage of tests", +) +argp.add_argument( + "-t", + "--travis", default=False, - action='store_const', + action="store_const", const=True, - help='When set, indicates that the script is running on CI (= not locally).' + help=( + "When set, indicates that the script is running on CI (= not locally)." + ), +) +argp.add_argument( + "--newline_on_success", default=False, action="store_const", const=True ) -argp.add_argument('--newline_on_success', - default=False, - action='store_const', - const=True) -argp.add_argument('-l', - '--language', - choices=sorted(_LANGUAGES.keys()), - nargs='+', - required=True) -argp.add_argument('-S', - '--stop_on_failure', - default=False, - action='store_const', - const=True) -argp.add_argument('--use_docker', - default=False, - action='store_const', - const=True, - help='Run all the tests under docker. That provides ' + - 'additional isolation and prevents the need to install ' + - 'language specific prerequisites. 
Only available on Linux.') argp.add_argument( - '--allow_flakes', + "-l", + "--language", + choices=sorted(_LANGUAGES.keys()), + nargs="+", + required=True, +) +argp.add_argument( + "-S", "--stop_on_failure", default=False, action="store_const", const=True +) +argp.add_argument( + "--use_docker", + default=False, + action="store_const", + const=True, + help="Run all the tests under docker. That provides " + + "additional isolation and prevents the need to install " + + "language specific prerequisites. Only available on Linux.", +) +argp.add_argument( + "--allow_flakes", default=False, - action='store_const', + action="store_const", const=True, - help= - 'Allow flaky tests to show as passing (re-runs failed tests up to five times)' + help=( + "Allow flaky tests to show as passing (re-runs failed tests up to five" + " times)" + ), ) argp.add_argument( - '--arch', - choices=['default', 'x86', 'x64', 'arm64'], - default='default', - help= - 'Selects architecture to target. For some platforms "default" is the only supported choice.' + "--arch", + choices=["default", "x86", "x64", "arm64"], + default="default", + help=( + 'Selects architecture to target. For some platforms "default" is the' + " only supported choice." + ), ) argp.add_argument( - '--compiler', + "--compiler", choices=[ - 'default', - 'gcc7', - 'gcc10.2', - 'gcc10.2_openssl102', - 'gcc12', - 'gcc_musl', - 'clang6', - 'clang15', - 'python2.7', - 'python3.5', - 'python3.7', - 'python3.8', - 'python3.9', - 'pypy', - 'pypy3', - 'python_alpine', - 'all_the_cpythons', - 'electron1.3', - 'electron1.6', - 'coreclr', - 'cmake', - 'cmake_ninja_vs2019', - 'cmake_vs2019', - 'mono', + "default", + "gcc7", + "gcc10.2", + "gcc10.2_openssl102", + "gcc12", + "gcc_musl", + "clang6", + "clang15", + "python2.7", + "python3.5", + "python3.7", + "python3.8", + "python3.9", + "pypy", + "pypy3", + "python_alpine", + "all_the_cpythons", + "electron1.3", + "electron1.6", + "coreclr", + "cmake", + "cmake_ninja_vs2019", + "cmake_vs2019", + "mono", ], - default='default', - help= - 'Selects compiler to use. Allowed values depend on the platform and language.' + default="default", + help=( + "Selects compiler to use. Allowed values depend on the platform and" + " language." + ), +) +argp.add_argument( + "--iomgr_platform", + choices=["native", "gevent", "asyncio"], + default="native", + help="Selects iomgr platform to build on", ) -argp.add_argument('--iomgr_platform', - choices=['native', 'gevent', 'asyncio'], - default='native', - help='Selects iomgr platform to build on') -argp.add_argument('--build_only', - default=False, - action='store_const', - const=True, - help='Perform all the build steps but don\'t run any tests.') -argp.add_argument('--measure_cpu_costs', - default=False, - action='store_const', - const=True, - help='Measure the cpu costs of tests') -argp.add_argument('-a', '--antagonists', default=0, type=int) -argp.add_argument('-x', - '--xml_report', - default=None, - type=str, - help='Generates a JUnit-compatible XML report') -argp.add_argument('--report_suite_name', - default='tests', - type=str, - help='Test suite name to use in generated JUnit XML report') argp.add_argument( - '--report_multi_target', + "--build_only", default=False, + action="store_const", const=True, - action='store_const', - help='Generate separate XML report for each test job (Looks better in UIs).' 
+ help="Perform all the build steps but don't run any tests.", ) argp.add_argument( - '--quiet_success', + "--measure_cpu_costs", default=False, - action='store_const', + action="store_const", const=True, - help= - 'Don\'t print anything when a test passes. Passing tests also will not be reported in XML report. ' - + 'Useful when running many iterations of each test (argument -n).') + help="Measure the cpu costs of tests", +) +argp.add_argument("-a", "--antagonists", default=0, type=int) +argp.add_argument( + "-x", + "--xml_report", + default=None, + type=str, + help="Generates a JUnit-compatible XML report", +) +argp.add_argument( + "--report_suite_name", + default="tests", + type=str, + help="Test suite name to use in generated JUnit XML report", +) +argp.add_argument( + "--report_multi_target", + default=False, + const=True, + action="store_const", + help=( + "Generate separate XML report for each test job (Looks better in UIs)." + ), +) +argp.add_argument( + "--quiet_success", + default=False, + action="store_const", + const=True, + help=( + "Don't print anything when a test passes. Passing tests also will not" + " be reported in XML report. " + ) + + "Useful when running many iterations of each test (argument -n).", +) argp.add_argument( - '--force_default_poller', + "--force_default_poller", default=False, - action='store_const', + action="store_const", const=True, - help='Don\'t try to iterate over many polling strategies when they exist') + help="Don't try to iterate over many polling strategies when they exist", +) argp.add_argument( - '--force_use_pollers', + "--force_use_pollers", default=None, type=str, - help='Only use the specified comma-delimited list of polling engines. ' - 'Example: --force_use_pollers epoll1,poll ' - ' (This flag has no effect if --force_default_poller flag is also used)') -argp.add_argument('--max_time', - default=-1, - type=int, - help='Maximum test runtime in seconds') -argp.add_argument('--bq_result_table', - default='', - type=str, - nargs='?', - help='Upload test results to a specified BQ table.') + help=( + "Only use the specified comma-delimited list of polling engines. " + "Example: --force_use_pollers epoll1,poll " + " (This flag has no effect if --force_default_poller flag is also used)" + ), +) +argp.add_argument( + "--max_time", default=-1, type=int, help="Maximum test runtime in seconds" +) +argp.add_argument( + "--bq_result_table", + default="", + type=str, + nargs="?", + help="Upload test results to a specified BQ table.", +) args = argp.parse_args() flaky_tests = set() @@ -1548,7 +1767,7 @@ def _build_and_run(check_cancelled, if args.force_default_poller: _POLLING_STRATEGIES = {} elif args.force_use_pollers: - _POLLING_STRATEGIES[platform_string()] = args.force_use_pollers.split(',') + _POLLING_STRATEGIES[platform_string()] = args.force_use_pollers.split(",") jobset.measure_cpu_costs = args.measure_cpu_costs @@ -1558,51 +1777,55 @@ def _build_and_run(check_cancelled, # TODO(jtattermusch): is this setting applied/being used? if args.travis: - _FORCE_ENVIRON_FOR_WRAPPERS = {'GRPC_TRACE': 'api'} + _FORCE_ENVIRON_FOR_WRAPPERS = {"GRPC_TRACE": "api"} languages = set(_LANGUAGES[l] for l in args.language) for l in languages: l.configure(run_config, args) if len(languages) != 1: - print('Building multiple languages simultaneously is not supported!') + print("Building multiple languages simultaneously is not supported!") sys.exit(1) # If --use_docker was used, respawn the run_tests.py script under a docker container # instead of continuing. 
if args.use_docker: if not args.travis: - print('Seen --use_docker flag, will run tests under docker.') - print('') + print("Seen --use_docker flag, will run tests under docker.") + print("") print( - 'IMPORTANT: The changes you are testing need to be locally committed' + "IMPORTANT: The changes you are testing need to be locally" + " committed" ) print( - 'because only the committed changes in the current branch will be') - print('copied to the docker environment.') + "because only the committed changes in the current branch will be" + ) + print("copied to the docker environment.") time.sleep(5) dockerfile_dirs = set([l.dockerfile_dir() for l in languages]) if len(dockerfile_dirs) > 1: - print('Languages to be tested require running under different docker ' - 'images.') + print( + "Languages to be tested require running under different docker " + "images." + ) sys.exit(1) else: dockerfile_dir = next(iter(dockerfile_dirs)) - child_argv = [arg for arg in sys.argv if not arg == '--use_docker'] - run_tests_cmd = 'python3 tools/run_tests/run_tests.py %s' % ' '.join( - child_argv[1:]) + child_argv = [arg for arg in sys.argv if not arg == "--use_docker"] + run_tests_cmd = "python3 tools/run_tests/run_tests.py %s" % " ".join( + child_argv[1:] + ) env = os.environ.copy() - env['DOCKERFILE_DIR'] = dockerfile_dir - env['DOCKER_RUN_SCRIPT'] = 'tools/run_tests/dockerize/docker_run.sh' - env['DOCKER_RUN_SCRIPT_COMMAND'] = run_tests_cmd + env["DOCKERFILE_DIR"] = dockerfile_dir + env["DOCKER_RUN_SCRIPT"] = "tools/run_tests/dockerize/docker_run.sh" + env["DOCKER_RUN_SCRIPT_COMMAND"] = run_tests_cmd retcode = subprocess.call( - 'tools/run_tests/dockerize/build_and_run_docker.sh', - shell=True, - env=env) + "tools/run_tests/dockerize/build_and_run_docker.sh", shell=True, env=env + ) _print_debug_info_epilogue(dockerfile_dir=dockerfile_dir) sys.exit(retcode) @@ -1612,42 +1835,59 @@ def _build_and_run(check_cancelled, # flakes on downloading dependencies etc.) 
build_steps = list( set( - jobset.JobSpec(cmdline, - environ=_build_step_environ( - build_config, extra_env=l.build_steps_environ()), - timeout_seconds=_PRE_BUILD_STEP_TIMEOUT_SECONDS, - flake_retries=2) + jobset.JobSpec( + cmdline, + environ=_build_step_environ( + build_config, extra_env=l.build_steps_environ() + ), + timeout_seconds=_PRE_BUILD_STEP_TIMEOUT_SECONDS, + flake_retries=2, + ) for l in languages - for cmdline in l.pre_build_steps())) + for cmdline in l.pre_build_steps() + ) +) # collect build steps build_steps.extend( set( - jobset.JobSpec(cmdline, - environ=_build_step_environ( - build_config, extra_env=l.build_steps_environ()), - timeout_seconds=None) + jobset.JobSpec( + cmdline, + environ=_build_step_environ( + build_config, extra_env=l.build_steps_environ() + ), + timeout_seconds=None, + ) for l in languages - for cmdline in l.build_steps())) + for cmdline in l.build_steps() + ) +) # collect post test steps post_tests_steps = list( set( - jobset.JobSpec(cmdline, - environ=_build_step_environ( - build_config, extra_env=l.build_steps_environ())) + jobset.JobSpec( + cmdline, + environ=_build_step_environ( + build_config, extra_env=l.build_steps_environ() + ), + ) for l in languages - for cmdline in l.post_tests_steps())) + for cmdline in l.post_tests_steps() + ) +) runs_per_test = args.runs_per_test -errors = _build_and_run(check_cancelled=lambda: False, - newline_on_success=args.newline_on_success, - xml_report=args.xml_report, - build_only=args.build_only) +errors = _build_and_run( + check_cancelled=lambda: False, + newline_on_success=args.newline_on_success, + xml_report=args.xml_report, + build_only=args.build_only, +) if not errors: - jobset.message('SUCCESS', 'All tests passed', do_newline=True) + jobset.message("SUCCESS", "All tests passed", do_newline=True) else: - jobset.message('FAILED', 'Some tests failed', do_newline=True) + jobset.message("FAILED", "Some tests failed", do_newline=True) if not _is_use_docker_child(): # if --use_docker was used, the outer invocation of run_tests.py will diff --git a/tools/run_tests/run_tests_matrix.py b/tools/run_tests/run_tests_matrix.py index 1fbe6476bd406..90eb52b8032ff 100755 --- a/tools/run_tests/run_tests_matrix.py +++ b/tools/run_tests/run_tests_matrix.py @@ -25,7 +25,7 @@ import python_utils.jobset as jobset import python_utils.report_utils as report_utils -_ROOT = os.path.abspath(os.path.join(os.path.dirname(sys.argv[0]), '../..')) +_ROOT = os.path.abspath(os.path.join(os.path.dirname(sys.argv[0]), "../..")) os.chdir(_ROOT) _DEFAULT_RUNTESTS_TIMEOUT = 1 * 60 * 60 @@ -42,18 +42,18 @@ # Name of the top-level umbrella report that includes all the run_tests.py invocations # Note that the starting letter 't' matters so that the targets are listed AFTER # the per-test breakdown items that start with 'run_tests/' (it is more readable that way) -_MATRIX_REPORT_NAME = 'toplevel_run_tests_invocations' +_MATRIX_REPORT_NAME = "toplevel_run_tests_invocations" def _safe_report_name(name): """Reports with '+' in target name won't show correctly in ResultStore""" - return name.replace('+', 'p') + return name.replace("+", "p") def _report_filename(name): """Generates report file name with directory structure that leads to better presentation by internal CI""" # 'sponge_log.xml' suffix must be there for results to get recognized by kokoro. 
- return '%s/%s' % (_safe_report_name(name), 'sponge_log.xml') + return "%s/%s" % (_safe_report_name(name), "sponge_log.xml") def _matrix_job_logfilename(shortname_for_multi_target): @@ -62,119 +62,162 @@ def _matrix_job_logfilename(shortname_for_multi_target): # for the corresponding 'sponge_log.xml' report. # the shortname_for_multi_target component must be set to match the sponge_log.xml location # because the top-level render_junit_xml_report is called with multi_target=True - sponge_log_name = '%s/%s/%s' % ( - _MATRIX_REPORT_NAME, shortname_for_multi_target, 'sponge_log.log') + sponge_log_name = "%s/%s/%s" % ( + _MATRIX_REPORT_NAME, + shortname_for_multi_target, + "sponge_log.log", + ) # env variable can be used to override the base location for the reports # so we need to match that behavior here too - base_dir = os.getenv('GRPC_TEST_REPORT_BASE_DIR', None) + base_dir = os.getenv("GRPC_TEST_REPORT_BASE_DIR", None) if base_dir: sponge_log_name = os.path.join(base_dir, sponge_log_name) return sponge_log_name -def _docker_jobspec(name, - runtests_args=[], - runtests_envs={}, - inner_jobs=_DEFAULT_INNER_JOBS, - timeout_seconds=None): +def _docker_jobspec( + name, + runtests_args=[], + runtests_envs={}, + inner_jobs=_DEFAULT_INNER_JOBS, + timeout_seconds=None, +): """Run a single instance of run_tests.py in a docker container""" if not timeout_seconds: timeout_seconds = _DEFAULT_RUNTESTS_TIMEOUT - shortname = 'run_tests_%s' % name - test_job = jobset.JobSpec(cmdline=[ - 'python3', 'tools/run_tests/run_tests.py', '--use_docker', '-t', '-j', - str(inner_jobs), '-x', - 'run_tests/%s' % _report_filename(name), '--report_suite_name', - '%s' % _safe_report_name(name) - ] + runtests_args, - environ=runtests_envs, - shortname=shortname, - timeout_seconds=timeout_seconds, - logfilename=_matrix_job_logfilename(shortname)) + shortname = "run_tests_%s" % name + test_job = jobset.JobSpec( + cmdline=[ + "python3", + "tools/run_tests/run_tests.py", + "--use_docker", + "-t", + "-j", + str(inner_jobs), + "-x", + "run_tests/%s" % _report_filename(name), + "--report_suite_name", + "%s" % _safe_report_name(name), + ] + + runtests_args, + environ=runtests_envs, + shortname=shortname, + timeout_seconds=timeout_seconds, + logfilename=_matrix_job_logfilename(shortname), + ) return test_job -def _workspace_jobspec(name, - runtests_args=[], - workspace_name=None, - runtests_envs={}, - inner_jobs=_DEFAULT_INNER_JOBS, - timeout_seconds=None): +def _workspace_jobspec( + name, + runtests_args=[], + workspace_name=None, + runtests_envs={}, + inner_jobs=_DEFAULT_INNER_JOBS, + timeout_seconds=None, +): """Run a single instance of run_tests.py in a separate workspace""" if not workspace_name: - workspace_name = 'workspace_%s' % name + workspace_name = "workspace_%s" % name if not timeout_seconds: timeout_seconds = _DEFAULT_RUNTESTS_TIMEOUT - shortname = 'run_tests_%s' % name - env = {'WORKSPACE_NAME': workspace_name} + shortname = "run_tests_%s" % name + env = {"WORKSPACE_NAME": workspace_name} env.update(runtests_envs) # if report base dir is set, we don't need to ".." 
to come out of the workspace dir - report_dir_prefix = '' if os.getenv('GRPC_TEST_REPORT_BASE_DIR', - None) else '../' - test_job = jobset.JobSpec(cmdline=[ - 'bash', 'tools/run_tests/helper_scripts/run_tests_in_workspace.sh', - '-t', '-j', - str(inner_jobs), '-x', - '%srun_tests/%s' % - (report_dir_prefix, _report_filename(name)), '--report_suite_name', - '%s' % _safe_report_name(name) - ] + runtests_args, - environ=env, - shortname=shortname, - timeout_seconds=timeout_seconds, - logfilename=_matrix_job_logfilename(shortname)) + report_dir_prefix = ( + "" if os.getenv("GRPC_TEST_REPORT_BASE_DIR", None) else "../" + ) + test_job = jobset.JobSpec( + cmdline=[ + "bash", + "tools/run_tests/helper_scripts/run_tests_in_workspace.sh", + "-t", + "-j", + str(inner_jobs), + "-x", + "%srun_tests/%s" % (report_dir_prefix, _report_filename(name)), + "--report_suite_name", + "%s" % _safe_report_name(name), + ] + + runtests_args, + environ=env, + shortname=shortname, + timeout_seconds=timeout_seconds, + logfilename=_matrix_job_logfilename(shortname), + ) return test_job -def _generate_jobs(languages, - configs, - platforms, - iomgr_platforms=['native'], - arch=None, - compiler=None, - labels=[], - extra_args=[], - extra_envs={}, - inner_jobs=_DEFAULT_INNER_JOBS, - timeout_seconds=None): +def _generate_jobs( + languages, + configs, + platforms, + iomgr_platforms=["native"], + arch=None, + compiler=None, + labels=[], + extra_args=[], + extra_envs={}, + inner_jobs=_DEFAULT_INNER_JOBS, + timeout_seconds=None, +): result = [] for language in languages: for platform in platforms: for iomgr_platform in iomgr_platforms: for config in configs: - name = '%s_%s_%s_%s' % (language, platform, config, - iomgr_platform) + name = "%s_%s_%s_%s" % ( + language, + platform, + config, + iomgr_platform, + ) runtests_args = [ - '-l', language, '-c', config, '--iomgr_platform', - iomgr_platform + "-l", + language, + "-c", + config, + "--iomgr_platform", + iomgr_platform, ] if arch or compiler: - name += '_%s_%s' % (arch, compiler) + name += "_%s_%s" % (arch, compiler) runtests_args += [ - '--arch', arch, '--compiler', compiler + "--arch", + arch, + "--compiler", + compiler, ] - if '--build_only' in extra_args: - name += '_buildonly' + if "--build_only" in extra_args: + name += "_buildonly" for extra_env in extra_envs: - name += '_%s_%s' % (extra_env, extra_envs[extra_env]) + name += "_%s_%s" % (extra_env, extra_envs[extra_env]) runtests_args += extra_args - if platform == 'linux': - job = _docker_jobspec(name=name, - runtests_args=runtests_args, - runtests_envs=extra_envs, - inner_jobs=inner_jobs, - timeout_seconds=timeout_seconds) + if platform == "linux": + job = _docker_jobspec( + name=name, + runtests_args=runtests_args, + runtests_envs=extra_envs, + inner_jobs=inner_jobs, + timeout_seconds=timeout_seconds, + ) else: job = _workspace_jobspec( name=name, runtests_args=runtests_args, runtests_envs=extra_envs, inner_jobs=inner_jobs, - timeout_seconds=timeout_seconds) - - job.labels = [platform, config, language, iomgr_platform - ] + labels + timeout_seconds=timeout_seconds, + ) + + job.labels = [ + platform, + config, + language, + iomgr_platform, + ] + labels result.append(job) return result @@ -182,206 +225,230 @@ def _generate_jobs(languages, def _create_test_jobs(extra_args=[], inner_jobs=_DEFAULT_INNER_JOBS): test_jobs = [] # sanity tests - test_jobs += _generate_jobs(languages=['sanity', 'clang-tidy', 'iwyu'], - configs=['dbg'], - platforms=['linux'], - labels=['basictests'], - extra_args=extra_args + - 
['--report_multi_target'], - inner_jobs=inner_jobs) + test_jobs += _generate_jobs( + languages=["sanity", "clang-tidy", "iwyu"], + configs=["dbg"], + platforms=["linux"], + labels=["basictests"], + extra_args=extra_args + ["--report_multi_target"], + inner_jobs=inner_jobs, + ) # supported on all platforms. test_jobs += _generate_jobs( - languages=['c'], - configs=['dbg', 'opt'], - platforms=['linux', 'macos', 'windows'], - labels=['basictests', 'corelang'], - extra_args= - extra_args, # don't use multi_target report because C has too many test cases + languages=["c"], + configs=["dbg", "opt"], + platforms=["linux", "macos", "windows"], + labels=["basictests", "corelang"], + extra_args=extra_args, # don't use multi_target report because C has too many test cases inner_jobs=inner_jobs, - timeout_seconds=_CPP_RUNTESTS_TIMEOUT) + timeout_seconds=_CPP_RUNTESTS_TIMEOUT, + ) # C# tests (both on .NET desktop/mono and .NET core) - test_jobs += _generate_jobs(languages=['csharp'], - configs=['dbg', 'opt'], - platforms=['linux', 'macos', 'windows'], - labels=['basictests', 'multilang'], - extra_args=extra_args + - ['--report_multi_target'], - inner_jobs=inner_jobs) + test_jobs += _generate_jobs( + languages=["csharp"], + configs=["dbg", "opt"], + platforms=["linux", "macos", "windows"], + labels=["basictests", "multilang"], + extra_args=extra_args + ["--report_multi_target"], + inner_jobs=inner_jobs, + ) # ARM64 Linux C# tests - test_jobs += _generate_jobs(languages=['csharp'], - configs=['dbg', 'opt'], - platforms=['linux'], - arch='arm64', - compiler='default', - labels=['basictests_arm64'], - extra_args=extra_args + - ['--report_multi_target'], - inner_jobs=inner_jobs) - - test_jobs += _generate_jobs(languages=['python'], - configs=['opt'], - platforms=['linux', 'macos', 'windows'], - iomgr_platforms=['native'], - labels=['basictests', 'multilang'], - extra_args=extra_args + - ['--report_multi_target'], - inner_jobs=inner_jobs) + test_jobs += _generate_jobs( + languages=["csharp"], + configs=["dbg", "opt"], + platforms=["linux"], + arch="arm64", + compiler="default", + labels=["basictests_arm64"], + extra_args=extra_args + ["--report_multi_target"], + inner_jobs=inner_jobs, + ) + + test_jobs += _generate_jobs( + languages=["python"], + configs=["opt"], + platforms=["linux", "macos", "windows"], + iomgr_platforms=["native"], + labels=["basictests", "multilang"], + extra_args=extra_args + ["--report_multi_target"], + inner_jobs=inner_jobs, + ) # ARM64 Linux Python tests - test_jobs += _generate_jobs(languages=['python'], - configs=['opt'], - platforms=['linux'], - arch='arm64', - compiler='default', - iomgr_platforms=['native'], - labels=['basictests_arm64'], - extra_args=extra_args + - ['--report_multi_target'], - inner_jobs=inner_jobs) + test_jobs += _generate_jobs( + languages=["python"], + configs=["opt"], + platforms=["linux"], + arch="arm64", + compiler="default", + iomgr_platforms=["native"], + labels=["basictests_arm64"], + extra_args=extra_args + ["--report_multi_target"], + inner_jobs=inner_jobs, + ) # supported on linux and mac. 
test_jobs += _generate_jobs( - languages=['c++'], - configs=['dbg', 'opt'], - platforms=['linux', 'macos'], - labels=['basictests', 'corelang'], - extra_args= - extra_args, # don't use multi_target report because C++ has too many test cases + languages=["c++"], + configs=["dbg", "opt"], + platforms=["linux", "macos"], + labels=["basictests", "corelang"], + extra_args=extra_args, # don't use multi_target report because C++ has too many test cases inner_jobs=inner_jobs, - timeout_seconds=_CPP_RUNTESTS_TIMEOUT) + timeout_seconds=_CPP_RUNTESTS_TIMEOUT, + ) - test_jobs += _generate_jobs(languages=['ruby', 'php7'], - configs=['dbg', 'opt'], - platforms=['linux', 'macos'], - labels=['basictests', 'multilang'], - extra_args=extra_args + - ['--report_multi_target'], - inner_jobs=inner_jobs) + test_jobs += _generate_jobs( + languages=["ruby", "php7"], + configs=["dbg", "opt"], + platforms=["linux", "macos"], + labels=["basictests", "multilang"], + extra_args=extra_args + ["--report_multi_target"], + inner_jobs=inner_jobs, + ) # ARM64 Linux Ruby and PHP tests - test_jobs += _generate_jobs(languages=['ruby', 'php7'], - configs=['dbg', 'opt'], - platforms=['linux'], - arch='arm64', - compiler='default', - labels=['basictests_arm64'], - extra_args=extra_args + - ['--report_multi_target'], - inner_jobs=inner_jobs) + test_jobs += _generate_jobs( + languages=["ruby", "php7"], + configs=["dbg", "opt"], + platforms=["linux"], + arch="arm64", + compiler="default", + labels=["basictests_arm64"], + extra_args=extra_args + ["--report_multi_target"], + inner_jobs=inner_jobs, + ) # supported on mac only. - test_jobs += _generate_jobs(languages=['objc'], - configs=['opt'], - platforms=['macos'], - labels=['basictests', 'multilang'], - extra_args=extra_args + - ['--report_multi_target'], - inner_jobs=inner_jobs, - timeout_seconds=_OBJC_RUNTESTS_TIMEOUT) + test_jobs += _generate_jobs( + languages=["objc"], + configs=["opt"], + platforms=["macos"], + labels=["basictests", "multilang"], + extra_args=extra_args + ["--report_multi_target"], + inner_jobs=inner_jobs, + timeout_seconds=_OBJC_RUNTESTS_TIMEOUT, + ) return test_jobs -def _create_portability_test_jobs(extra_args=[], - inner_jobs=_DEFAULT_INNER_JOBS): +def _create_portability_test_jobs( + extra_args=[], inner_jobs=_DEFAULT_INNER_JOBS +): test_jobs = [] # portability C x86 - test_jobs += _generate_jobs(languages=['c'], - configs=['dbg'], - platforms=['linux'], - arch='x86', - compiler='default', - labels=['portability', 'corelang'], - extra_args=extra_args, - inner_jobs=inner_jobs) + test_jobs += _generate_jobs( + languages=["c"], + configs=["dbg"], + platforms=["linux"], + arch="x86", + compiler="default", + labels=["portability", "corelang"], + extra_args=extra_args, + inner_jobs=inner_jobs, + ) # portability C and C++ on x64 for compiler in [ - 'gcc7', - # 'gcc10.2_openssl102', // TODO(b/283304471): Enable this later - 'gcc12', - 'gcc_musl', - 'clang6', - 'clang15' + "gcc7", + # 'gcc10.2_openssl102', // TODO(b/283304471): Enable this later + "gcc12", + "gcc_musl", + "clang6", + "clang15", ]: - test_jobs += _generate_jobs(languages=['c', 'c++'], - configs=['dbg'], - platforms=['linux'], - arch='x64', - compiler=compiler, - labels=['portability', 'corelang'], - extra_args=extra_args, - inner_jobs=inner_jobs, - timeout_seconds=_CPP_RUNTESTS_TIMEOUT) + test_jobs += _generate_jobs( + languages=["c", "c++"], + configs=["dbg"], + platforms=["linux"], + arch="x64", + compiler=compiler, + labels=["portability", "corelang"], + extra_args=extra_args, + 
inner_jobs=inner_jobs, + timeout_seconds=_CPP_RUNTESTS_TIMEOUT, + ) # portability C on Windows 64-bit (x86 is the default) - test_jobs += _generate_jobs(languages=['c'], - configs=['dbg'], - platforms=['windows'], - arch='x64', - compiler='default', - labels=['portability', 'corelang'], - extra_args=extra_args, - inner_jobs=inner_jobs) + test_jobs += _generate_jobs( + languages=["c"], + configs=["dbg"], + platforms=["windows"], + arch="x64", + compiler="default", + labels=["portability", "corelang"], + extra_args=extra_args, + inner_jobs=inner_jobs, + ) # portability C on Windows with the "Visual Studio" cmake # generator, i.e. not using Ninja (to verify that we can still build with msbuild) - test_jobs += _generate_jobs(languages=['c'], - configs=['dbg'], - platforms=['windows'], - arch='default', - compiler='cmake_vs2019', - labels=['portability', 'corelang'], - extra_args=extra_args, - inner_jobs=inner_jobs) + test_jobs += _generate_jobs( + languages=["c"], + configs=["dbg"], + platforms=["windows"], + arch="default", + compiler="cmake_vs2019", + labels=["portability", "corelang"], + extra_args=extra_args, + inner_jobs=inner_jobs, + ) # portability C++ on Windows # TODO(jtattermusch): some of the tests are failing, so we force --build_only - test_jobs += _generate_jobs(languages=['c++'], - configs=['dbg'], - platforms=['windows'], - arch='default', - compiler='default', - labels=['portability', 'corelang'], - extra_args=extra_args + ['--build_only'], - inner_jobs=inner_jobs, - timeout_seconds=_CPP_RUNTESTS_TIMEOUT) + test_jobs += _generate_jobs( + languages=["c++"], + configs=["dbg"], + platforms=["windows"], + arch="default", + compiler="default", + labels=["portability", "corelang"], + extra_args=extra_args + ["--build_only"], + inner_jobs=inner_jobs, + timeout_seconds=_CPP_RUNTESTS_TIMEOUT, + ) # portability C and C++ on Windows using VS2019 (build only) # TODO(jtattermusch): The C tests with exactly the same config are already running as part of the # basictests_c suite (so we force --build_only to avoid running them twice). # The C++ tests aren't all passing, so also force --build_only. 
- test_jobs += _generate_jobs(languages=['c', 'c++'], - configs=['dbg'], - platforms=['windows'], - arch='x64', - compiler='cmake_ninja_vs2019', - labels=['portability', 'corelang'], - extra_args=extra_args + ['--build_only'], - inner_jobs=inner_jobs, - timeout_seconds=_CPP_RUNTESTS_TIMEOUT) + test_jobs += _generate_jobs( + languages=["c", "c++"], + configs=["dbg"], + platforms=["windows"], + arch="x64", + compiler="cmake_ninja_vs2019", + labels=["portability", "corelang"], + extra_args=extra_args + ["--build_only"], + inner_jobs=inner_jobs, + timeout_seconds=_CPP_RUNTESTS_TIMEOUT, + ) # C and C++ with no-exceptions on Linux - test_jobs += _generate_jobs(languages=['c', 'c++'], - configs=['noexcept'], - platforms=['linux'], - labels=['portability', 'corelang'], - extra_args=extra_args, - inner_jobs=inner_jobs, - timeout_seconds=_CPP_RUNTESTS_TIMEOUT) - - test_jobs += _generate_jobs(languages=['python'], - configs=['dbg'], - platforms=['linux'], - arch='default', - compiler='python_alpine', - labels=['portability', 'multilang'], - extra_args=extra_args + - ['--report_multi_target'], - inner_jobs=inner_jobs) + test_jobs += _generate_jobs( + languages=["c", "c++"], + configs=["noexcept"], + platforms=["linux"], + labels=["portability", "corelang"], + extra_args=extra_args, + inner_jobs=inner_jobs, + timeout_seconds=_CPP_RUNTESTS_TIMEOUT, + ) + + test_jobs += _generate_jobs( + languages=["python"], + configs=["dbg"], + platforms=["linux"], + arch="default", + compiler="python_alpine", + labels=["portability", "multilang"], + extra_args=extra_args + ["--report_multi_target"], + inner_jobs=inner_jobs, + ) return test_jobs @@ -403,177 +470,213 @@ def _runs_per_test_type(arg_str): raise ValueError return n except: - msg = '\'{}\' is not a positive integer'.format(arg_str) + msg = "'{}' is not a positive integer".format(arg_str) raise argparse.ArgumentTypeError(msg) if __name__ == "__main__": argp = argparse.ArgumentParser( - description='Run a matrix of run_tests.py tests.') - argp.add_argument('-j', - '--jobs', - default=multiprocessing.cpu_count() / _DEFAULT_INNER_JOBS, - type=int, - help='Number of concurrent run_tests.py instances.') - argp.add_argument('-f', - '--filter', - choices=_allowed_labels(), - nargs='+', - default=[], - help='Filter targets to run by label with AND semantics.') - argp.add_argument('--exclude', - choices=_allowed_labels(), - nargs='+', - default=[], - help='Exclude targets with any of given labels.') - argp.add_argument('--build_only', - default=False, - action='store_const', - const=True, - help='Pass --build_only flag to run_tests.py instances.') + description="Run a matrix of run_tests.py tests." 
+ ) + argp.add_argument( + "-j", + "--jobs", + default=multiprocessing.cpu_count() / _DEFAULT_INNER_JOBS, + type=int, + help="Number of concurrent run_tests.py instances.", + ) + argp.add_argument( + "-f", + "--filter", + choices=_allowed_labels(), + nargs="+", + default=[], + help="Filter targets to run by label with AND semantics.", + ) argp.add_argument( - '--force_default_poller', + "--exclude", + choices=_allowed_labels(), + nargs="+", + default=[], + help="Exclude targets with any of given labels.", + ) + argp.add_argument( + "--build_only", + default=False, + action="store_const", + const=True, + help="Pass --build_only flag to run_tests.py instances.", + ) + argp.add_argument( + "--force_default_poller", + default=False, + action="store_const", + const=True, + help="Pass --force_default_poller to run_tests.py instances.", + ) + argp.add_argument( + "--dry_run", default=False, - action='store_const', + action="store_const", const=True, - help='Pass --force_default_poller to run_tests.py instances.') - argp.add_argument('--dry_run', - default=False, - action='store_const', - const=True, - help='Only print what would be run.') + help="Only print what would be run.", + ) argp.add_argument( - '--filter_pr_tests', + "--filter_pr_tests", default=False, - action='store_const', + action="store_const", const=True, - help='Filters out tests irrelevant to pull request changes.') + help="Filters out tests irrelevant to pull request changes.", + ) argp.add_argument( - '--base_branch', - default='origin/master', + "--base_branch", + default="origin/master", type=str, - help='Branch that pull request is requesting to merge into') - argp.add_argument('--inner_jobs', - default=_DEFAULT_INNER_JOBS, - type=int, - help='Number of jobs in each run_tests.py instance') + help="Branch that pull request is requesting to merge into", + ) argp.add_argument( - '-n', - '--runs_per_test', + "--inner_jobs", + default=_DEFAULT_INNER_JOBS, + type=int, + help="Number of jobs in each run_tests.py instance", + ) + argp.add_argument( + "-n", + "--runs_per_test", default=1, type=_runs_per_test_type, - help='How many times to run each tests. >1 runs implies ' + - 'omitting passing test from the output & reports.') - argp.add_argument('--max_time', - default=-1, - type=int, - help='Maximum amount of time to run tests for' + - '(other tests will be skipped)') + help="How many times to run each tests. >1 runs implies " + + "omitting passing test from the output & reports.", + ) + argp.add_argument( + "--max_time", + default=-1, + type=int, + help="Maximum amount of time to run tests for" + + "(other tests will be skipped)", + ) argp.add_argument( - '--internal_ci', + "--internal_ci", default=False, - action='store_const', + action="store_const", const=True, - help= - '(Deprecated, has no effect) Put reports into subdirectories to improve presentation of ' - 'results by Kokoro.') - argp.add_argument('--bq_result_table', - default='', - type=str, - nargs='?', - help='Upload test results to a specified BQ table.') - argp.add_argument('--extra_args', - default='', - type=str, - nargs=argparse.REMAINDER, - help='Extra test args passed to each sub-script.') + help=( + "(Deprecated, has no effect) Put reports into subdirectories to" + " improve presentation of results by Kokoro." 
+ ), + ) + argp.add_argument( + "--bq_result_table", + default="", + type=str, + nargs="?", + help="Upload test results to a specified BQ table.", + ) + argp.add_argument( + "--extra_args", + default="", + type=str, + nargs=argparse.REMAINDER, + help="Extra test args passed to each sub-script.", + ) args = argp.parse_args() extra_args = [] if args.build_only: - extra_args.append('--build_only') + extra_args.append("--build_only") if args.force_default_poller: - extra_args.append('--force_default_poller') + extra_args.append("--force_default_poller") if args.runs_per_test > 1: - extra_args.append('-n') - extra_args.append('%s' % args.runs_per_test) - extra_args.append('--quiet_success') + extra_args.append("-n") + extra_args.append("%s" % args.runs_per_test) + extra_args.append("--quiet_success") if args.max_time > 0: - extra_args.extend(('--max_time', '%d' % args.max_time)) + extra_args.extend(("--max_time", "%d" % args.max_time)) if args.bq_result_table: - extra_args.append('--bq_result_table') - extra_args.append('%s' % args.bq_result_table) - extra_args.append('--measure_cpu_costs') + extra_args.append("--bq_result_table") + extra_args.append("%s" % args.bq_result_table) + extra_args.append("--measure_cpu_costs") if args.extra_args: extra_args.extend(args.extra_args) - all_jobs = _create_test_jobs(extra_args=extra_args, inner_jobs=args.inner_jobs) + \ - _create_portability_test_jobs(extra_args=extra_args, inner_jobs=args.inner_jobs) + all_jobs = _create_test_jobs( + extra_args=extra_args, inner_jobs=args.inner_jobs + ) + _create_portability_test_jobs( + extra_args=extra_args, inner_jobs=args.inner_jobs + ) jobs = [] for job in all_jobs: if not args.filter or all( - filter in job.labels for filter in args.filter): - if not any(exclude_label in job.labels - for exclude_label in args.exclude): + filter in job.labels for filter in args.filter + ): + if not any( + exclude_label in job.labels for exclude_label in args.exclude + ): jobs.append(job) if not jobs: - jobset.message('FAILED', - 'No test suites match given criteria.', - do_newline=True) + jobset.message( + "FAILED", "No test suites match given criteria.", do_newline=True + ) sys.exit(1) - print('IMPORTANT: The changes you are testing need to be locally committed') - print('because only the committed changes in the current branch will be') - print('copied to the docker environment or into subworkspaces.') + print("IMPORTANT: The changes you are testing need to be locally committed") + print("because only the committed changes in the current branch will be") + print("copied to the docker environment or into subworkspaces.") skipped_jobs = [] if args.filter_pr_tests: - print('Looking for irrelevant tests to skip...') + print("Looking for irrelevant tests to skip...") relevant_jobs = filter_tests(jobs, args.base_branch) if len(relevant_jobs) == len(jobs): - print('No tests will be skipped.') + print("No tests will be skipped.") else: - print('These tests will be skipped:') + print("These tests will be skipped:") skipped_jobs = list(set(jobs) - set(relevant_jobs)) # Sort by shortnames to make printing of skipped tests consistent skipped_jobs.sort(key=lambda job: job.shortname) for job in list(skipped_jobs): - print(' %s' % job.shortname) + print(" %s" % job.shortname) jobs = relevant_jobs - print('Will run these tests:') + print("Will run these tests:") for job in jobs: - print(' %s: "%s"' % (job.shortname, ' '.join(job.cmdline))) - print('') + print(' %s: "%s"' % (job.shortname, " ".join(job.cmdline))) + print("") if args.dry_run: - 
print('--dry_run was used, exiting') + print("--dry_run was used, exiting") sys.exit(1) - jobset.message('START', 'Running test matrix.', do_newline=True) - num_failures, resultset = jobset.run(jobs, - newline_on_success=True, - travis=True, - maxjobs=args.jobs) + jobset.message("START", "Running test matrix.", do_newline=True) + num_failures, resultset = jobset.run( + jobs, newline_on_success=True, travis=True, maxjobs=args.jobs + ) # Merge skipped tests into results to show skipped tests on report.xml if skipped_jobs: ignored_num_skipped_failures, skipped_results = jobset.run( - skipped_jobs, skip_jobs=True) + skipped_jobs, skip_jobs=True + ) resultset.update(skipped_results) - report_utils.render_junit_xml_report(resultset, - _report_filename(_MATRIX_REPORT_NAME), - suite_name=_MATRIX_REPORT_NAME, - multi_target=True) + report_utils.render_junit_xml_report( + resultset, + _report_filename(_MATRIX_REPORT_NAME), + suite_name=_MATRIX_REPORT_NAME, + multi_target=True, + ) if num_failures == 0: - jobset.message('SUCCESS', - 'All run_tests.py instances finished successfully.', - do_newline=True) + jobset.message( + "SUCCESS", + "All run_tests.py instances finished successfully.", + do_newline=True, + ) else: - jobset.message('FAILED', - 'Some run_tests.py instances have failed.', - do_newline=True) + jobset.message( + "FAILED", + "Some run_tests.py instances have failed.", + do_newline=True, + ) sys.exit(1) diff --git a/tools/run_tests/run_xds_tests.py b/tools/run_tests/run_xds_tests.py index 84dedd379b7de..a727d8614ae0a 100755 --- a/tools/run_tests/run_xds_tests.py +++ b/tools/run_tests/run_xds_tests.py @@ -48,8 +48,9 @@ from envoy.extensions.filters.common.fault.v3 import fault_pb2 from envoy.extensions.filters.http.fault.v3 import fault_pb2 from envoy.extensions.filters.http.router.v3 import router_pb2 - from envoy.extensions.filters.network.http_connection_manager.v3 import \ - http_connection_manager_pb2 + from envoy.extensions.filters.network.http_connection_manager.v3 import ( + http_connection_manager_pb2, + ) from envoy.service.status.v3 import csds_pb2 from envoy.service.status.v3 import csds_pb2_grpc except ImportError: @@ -59,36 +60,36 @@ logger = logging.getLogger() console_handler = logging.StreamHandler() -formatter = logging.Formatter(fmt='%(asctime)s: %(levelname)-8s %(message)s') +formatter = logging.Formatter(fmt="%(asctime)s: %(levelname)-8s %(message)s") console_handler.setFormatter(formatter) logger.handlers = [] logger.addHandler(console_handler) logger.setLevel(logging.WARNING) # Suppress excessive logs for gRPC Python -original_grpc_trace = os.environ.pop('GRPC_TRACE', None) -original_grpc_verbosity = os.environ.pop('GRPC_VERBOSITY', None) +original_grpc_trace = os.environ.pop("GRPC_TRACE", None) +original_grpc_verbosity = os.environ.pop("GRPC_VERBOSITY", None) # Suppress not-essential logs for GCP clients -logging.getLogger('google_auth_httplib2').setLevel(logging.WARNING) -logging.getLogger('googleapiclient.discovery').setLevel(logging.WARNING) +logging.getLogger("google_auth_httplib2").setLevel(logging.WARNING) +logging.getLogger("googleapiclient.discovery").setLevel(logging.WARNING) _TEST_CASES = [ - 'backends_restart', - 'change_backend_service', - 'gentle_failover', - 'load_report_based_failover', - 'ping_pong', - 'remove_instance_group', - 'round_robin', - 'secondary_locality_gets_no_requests_on_partial_primary_failure', - 'secondary_locality_gets_requests_on_primary_failure', - 'traffic_splitting', - 'path_matching', - 'header_matching', - 'api_listener', - 
'forwarding_rule_port_match', - 'forwarding_rule_default_port', - 'metadata_filter', + "backends_restart", + "change_backend_service", + "gentle_failover", + "load_report_based_failover", + "ping_pong", + "remove_instance_group", + "round_robin", + "secondary_locality_gets_no_requests_on_partial_primary_failure", + "secondary_locality_gets_requests_on_primary_failure", + "traffic_splitting", + "path_matching", + "header_matching", + "api_listener", + "forwarding_rule_port_match", + "forwarding_rule_default_port", + "metadata_filter", ] # Valid test cases, but not in all. So the tests can only run manually, and @@ -96,23 +97,23 @@ # # TODO: Move them into _TEST_CASES when support is ready in all languages. _ADDITIONAL_TEST_CASES = [ - 'circuit_breaking', - 'timeout', - 'fault_injection', - 'csds', + "circuit_breaking", + "timeout", + "fault_injection", + "csds", ] # Test cases that require the V3 API. Skipped in older runs. -_V3_TEST_CASES = frozenset(['timeout', 'fault_injection', 'csds']) +_V3_TEST_CASES = frozenset(["timeout", "fault_injection", "csds"]) # Test cases that require the alpha API. Skipped for stable API runs. -_ALPHA_TEST_CASES = frozenset(['timeout']) +_ALPHA_TEST_CASES = frozenset(["timeout"]) def parse_test_cases(arg): - if arg == '': + if arg == "": return [] - arg_split = arg.split(',') + arg_split = arg.split(",") test_cases = set() all_test_cases = _TEST_CASES + _ADDITIONAL_TEST_CASES for arg in arg_split: @@ -121,7 +122,7 @@ def parse_test_cases(arg): else: test_cases = test_cases.union([arg]) if not all([test_case in all_test_cases for test_case in test_cases]): - raise Exception('Failed to parse test cases %s' % arg) + raise Exception("Failed to parse test cases %s" % arg) # Perserve order. return [x for x in all_test_cases if x in test_cases] @@ -131,149 +132,206 @@ def parse_port_range(port_arg): port = int(port_arg) return list(range(port, port + 1)) except: - port_min, port_max = port_arg.split(':') + port_min, port_max = port_arg.split(":") return list(range(int(port_min), int(port_max) + 1)) -argp = argparse.ArgumentParser(description='Run xDS interop tests on GCP') +argp = argparse.ArgumentParser(description="Run xDS interop tests on GCP") # TODO(zdapeng): remove default value of project_id and project_num -argp.add_argument('--project_id', default='grpc-testing', help='GCP project id') -argp.add_argument('--project_num', - default='830293263384', - help='GCP project number') +argp.add_argument("--project_id", default="grpc-testing", help="GCP project id") +argp.add_argument( + "--project_num", default="830293263384", help="GCP project number" +) argp.add_argument( - '--gcp_suffix', - default='', - help='Optional suffix for all generated GCP resource names. Useful to ' - 'ensure distinct names across test runs.') + "--gcp_suffix", + default="", + help=( + "Optional suffix for all generated GCP resource names. Useful to " + "ensure distinct names across test runs." + ), +) argp.add_argument( - '--test_case', - default='ping_pong', + "--test_case", + default="ping_pong", type=parse_test_cases, - help='Comma-separated list of test cases to run. Available tests: %s, ' - '(or \'all\' to run every test). ' - 'Alternative tests not included in \'all\': %s' % - (','.join(_TEST_CASES), ','.join(_ADDITIONAL_TEST_CASES))) + help=( + "Comma-separated list of test cases to run. Available tests: %s, " + "(or 'all' to run every test). 
" + "Alternative tests not included in 'all': %s" + ) + % (",".join(_TEST_CASES), ",".join(_ADDITIONAL_TEST_CASES)), +) argp.add_argument( - '--bootstrap_file', - default='', - help='File to reference via GRPC_XDS_BOOTSTRAP. Disables built-in ' - 'bootstrap generation') + "--bootstrap_file", + default="", + help=( + "File to reference via GRPC_XDS_BOOTSTRAP. Disables built-in " + "bootstrap generation" + ), +) argp.add_argument( - '--xds_v3_support', + "--xds_v3_support", default=False, - action='store_true', - help='Support xDS v3 via GRPC_XDS_EXPERIMENTAL_V3_SUPPORT. ' - 'If a pre-created bootstrap file is provided via the --bootstrap_file ' - 'parameter, it should include xds_v3 in its server_features field.') + action="store_true", + help=( + "Support xDS v3 via GRPC_XDS_EXPERIMENTAL_V3_SUPPORT. " + "If a pre-created bootstrap file is provided via the --bootstrap_file " + "parameter, it should include xds_v3 in its server_features field." + ), +) argp.add_argument( - '--client_cmd', + "--client_cmd", default=None, - help='Command to launch xDS test client. {server_uri}, {stats_port} and ' - '{qps} references will be replaced using str.format(). GRPC_XDS_BOOTSTRAP ' - 'will be set for the command') + help=( + "Command to launch xDS test client. {server_uri}, {stats_port} and" + " {qps} references will be replaced using str.format()." + " GRPC_XDS_BOOTSTRAP will be set for the command" + ), +) argp.add_argument( - '--client_hosts', + "--client_hosts", default=None, - help='Comma-separated list of hosts running client processes. If set, ' - '--client_cmd is ignored and client processes are assumed to be running on ' - 'the specified hosts.') -argp.add_argument('--zone', default='us-central1-a') -argp.add_argument('--secondary_zone', - default='us-west1-b', - help='Zone to use for secondary TD locality tests') -argp.add_argument('--qps', default=100, type=int, help='Client QPS') + help=( + "Comma-separated list of hosts running client processes. If set," + " --client_cmd is ignored and client processes are assumed to be" + " running on the specified hosts." + ), +) +argp.add_argument("--zone", default="us-central1-a") argp.add_argument( - '--wait_for_backend_sec', + "--secondary_zone", + default="us-west1-b", + help="Zone to use for secondary TD locality tests", +) +argp.add_argument("--qps", default=100, type=int, help="Client QPS") +argp.add_argument( + "--wait_for_backend_sec", default=1200, type=int, - help='Time limit for waiting for created backend services to report ' - 'healthy when launching or updated GCP resources') + help=( + "Time limit for waiting for created backend services to report " + "healthy when launching or updated GCP resources" + ), +) argp.add_argument( - '--use_existing_gcp_resources', + "--use_existing_gcp_resources", default=False, - action='store_true', - help= - 'If set, find and use already created GCP resources instead of creating new' - ' ones.') + action="store_true", + help=( + "If set, find and use already created GCP resources instead of creating" + " new ones." + ), +) argp.add_argument( - '--keep_gcp_resources', + "--keep_gcp_resources", default=False, - action='store_true', - help= - 'Leave GCP VMs and configuration running after test. Default behavior is ' - 'to delete when tests complete.') -argp.add_argument('--halt_after_fail', - action='store_true', - help='Halt and save the resources when test failed.') + action="store_true", + help=( + "Leave GCP VMs and configuration running after test. 
Default behavior" + " is to delete when tests complete." + ), +) +argp.add_argument( + "--halt_after_fail", + action="store_true", + help="Halt and save the resources when test failed.", +) argp.add_argument( - '--compute_discovery_document', + "--compute_discovery_document", default=None, type=str, - help= - 'If provided, uses this file instead of retrieving via the GCP discovery ' - 'API') + help=( + "If provided, uses this file instead of retrieving via the GCP" + " discovery API" + ), +) argp.add_argument( - '--alpha_compute_discovery_document', + "--alpha_compute_discovery_document", default=None, type=str, - help='If provided, uses this file instead of retrieving via the alpha GCP ' - 'discovery API') -argp.add_argument('--network', - default='global/networks/default', - help='GCP network to use') -_DEFAULT_PORT_RANGE = '8080:8280' -argp.add_argument('--service_port_range', - default=_DEFAULT_PORT_RANGE, - type=parse_port_range, - help='Listening port for created gRPC backends. Specified as ' - 'either a single int or as a range in the format min:max, in ' - 'which case an available port p will be chosen s.t. min <= p ' - '<= max') + help=( + "If provided, uses this file instead of retrieving via the alpha GCP " + "discovery API" + ), +) +argp.add_argument( + "--network", default="global/networks/default", help="GCP network to use" +) +_DEFAULT_PORT_RANGE = "8080:8280" argp.add_argument( - '--stats_port', + "--service_port_range", + default=_DEFAULT_PORT_RANGE, + type=parse_port_range, + help=( + "Listening port for created gRPC backends. Specified as " + "either a single int or as a range in the format min:max, in " + "which case an available port p will be chosen s.t. min <= p " + "<= max" + ), +) +argp.add_argument( + "--stats_port", default=8079, type=int, - help='Local port for the client process to expose the LB stats service') -argp.add_argument('--xds_server', - default='trafficdirector.googleapis.com:443', - help='xDS server') -argp.add_argument('--source_image', - default='projects/debian-cloud/global/images/family/debian-9', - help='Source image for VMs created during the test') -argp.add_argument('--path_to_server_binary', - default=None, - type=str, - help='If set, the server binary must already be pre-built on ' - 'the specified source image') -argp.add_argument('--machine_type', - default='e2-standard-2', - help='Machine type for VMs created during the test') + help="Local port for the client process to expose the LB stats service", +) +argp.add_argument( + "--xds_server", + default="trafficdirector.googleapis.com:443", + help="xDS server", +) +argp.add_argument( + "--source_image", + default="projects/debian-cloud/global/images/family/debian-9", + help="Source image for VMs created during the test", +) +argp.add_argument( + "--path_to_server_binary", + default=None, + type=str, + help=( + "If set, the server binary must already be pre-built on " + "the specified source image" + ), +) argp.add_argument( - '--instance_group_size', + "--machine_type", + default="e2-standard-2", + help="Machine type for VMs created during the test", +) +argp.add_argument( + "--instance_group_size", default=2, type=int, - help='Number of VMs to create per instance group. Certain test cases (e.g., ' - 'round_robin) may not give meaningful results if this is set to a value ' - 'less than 2.') -argp.add_argument('--verbose', - help='verbose log output', - default=False, - action='store_true') + help=( + "Number of VMs to create per instance group. 
Certain test cases (e.g.," + " round_robin) may not give meaningful results if this is set to a" + " value less than 2." + ), +) +argp.add_argument( + "--verbose", help="verbose log output", default=False, action="store_true" +) # TODO(ericgribkoff) Remove this param once the sponge-formatted log files are # visible in all test environments. -argp.add_argument('--log_client_output', - help='Log captured client output', - default=False, - action='store_true') +argp.add_argument( + "--log_client_output", + help="Log captured client output", + default=False, + action="store_true", +) # TODO(ericgribkoff) Remove this flag once all test environments are verified to # have access to the alpha compute APIs. -argp.add_argument('--only_stable_gcp_apis', - help='Do not use alpha compute APIs. Some tests may be ' - 'incompatible with this option (gRPC health checks are ' - 'currently alpha and required for simulating server failure', - default=False, - action='store_true') +argp.add_argument( + "--only_stable_gcp_apis", + help=( + "Do not use alpha compute APIs. Some tests may be " + "incompatible with this option (gRPC health checks are " + "currently alpha and required for simulating server failure" + ), + default=False, + action="store_true", +) args = argp.parse_args() if args.verbose: @@ -281,7 +339,7 @@ def parse_port_range(port_arg): CLIENT_HOSTS = [] if args.client_hosts: - CLIENT_HOSTS = args.client_hosts.split(',') + CLIENT_HOSTS = args.client_hosts.split(",") # Each of the config propagation in the control plane should finish within 600s. # Otherwise, it indicates a bug in the control plane. The config propagation @@ -324,58 +382,68 @@ def parse_port_range(port_arg): ], "server_features": {server_features} }}] -}}""" % (args.network.split('/')[-1], args.zone, args.xds_server) +}}""" % ( + args.network.split("/")[-1], + args.zone, + args.xds_server, +) # TODO(ericgribkoff) Add change_backend_service to this list once TD no longer # sends an update with no localities when adding the MIG to the backend service # can race with the URL map patch. -_TESTS_TO_FAIL_ON_RPC_FAILURE = ['ping_pong', 'round_robin'] +_TESTS_TO_FAIL_ON_RPC_FAILURE = ["ping_pong", "round_robin"] # Tests that run UnaryCall and EmptyCall. -_TESTS_TO_RUN_MULTIPLE_RPCS = ['path_matching', 'header_matching'] +_TESTS_TO_RUN_MULTIPLE_RPCS = ["path_matching", "header_matching"] # Tests that make UnaryCall with test metadata. -_TESTS_TO_SEND_METADATA = ['header_matching'] -_TEST_METADATA_KEY = 'xds_md' -_TEST_METADATA_VALUE_UNARY = 'unary_yranu' -_TEST_METADATA_VALUE_EMPTY = 'empty_ytpme' +_TESTS_TO_SEND_METADATA = ["header_matching"] +_TEST_METADATA_KEY = "xds_md" +_TEST_METADATA_VALUE_UNARY = "unary_yranu" +_TEST_METADATA_VALUE_EMPTY = "empty_ytpme" # Extra RPC metadata whose value is a number, sent with UnaryCall only. 
-_TEST_METADATA_NUMERIC_KEY = 'xds_md_numeric' -_TEST_METADATA_NUMERIC_VALUE = '159' -_PATH_MATCHER_NAME = 'path-matcher' -_BASE_TEMPLATE_NAME = 'test-template' -_BASE_INSTANCE_GROUP_NAME = 'test-ig' -_BASE_HEALTH_CHECK_NAME = 'test-hc' -_BASE_FIREWALL_RULE_NAME = 'test-fw-rule' -_BASE_BACKEND_SERVICE_NAME = 'test-backend-service' -_BASE_URL_MAP_NAME = 'test-map' -_BASE_SERVICE_HOST = 'grpc-test' -_BASE_TARGET_PROXY_NAME = 'test-target-proxy' -_BASE_FORWARDING_RULE_NAME = 'test-forwarding-rule' -_TEST_LOG_BASE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), - '../../reports') -_SPONGE_LOG_NAME = 'sponge_log.log' -_SPONGE_XML_NAME = 'sponge_log.xml' +_TEST_METADATA_NUMERIC_KEY = "xds_md_numeric" +_TEST_METADATA_NUMERIC_VALUE = "159" +_PATH_MATCHER_NAME = "path-matcher" +_BASE_TEMPLATE_NAME = "test-template" +_BASE_INSTANCE_GROUP_NAME = "test-ig" +_BASE_HEALTH_CHECK_NAME = "test-hc" +_BASE_FIREWALL_RULE_NAME = "test-fw-rule" +_BASE_BACKEND_SERVICE_NAME = "test-backend-service" +_BASE_URL_MAP_NAME = "test-map" +_BASE_SERVICE_HOST = "grpc-test" +_BASE_TARGET_PROXY_NAME = "test-target-proxy" +_BASE_FORWARDING_RULE_NAME = "test-forwarding-rule" +_TEST_LOG_BASE_DIR = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "../../reports" +) +_SPONGE_LOG_NAME = "sponge_log.log" +_SPONGE_XML_NAME = "sponge_log.xml" def get_client_stats(num_rpcs, timeout_sec): if CLIENT_HOSTS: hosts = CLIENT_HOSTS else: - hosts = ['localhost'] + hosts = ["localhost"] for host in hosts: - with grpc.insecure_channel('%s:%d' % - (host, args.stats_port)) as channel: + with grpc.insecure_channel( + "%s:%d" % (host, args.stats_port) + ) as channel: stub = test_pb2_grpc.LoadBalancerStatsServiceStub(channel) request = messages_pb2.LoadBalancerStatsRequest() request.num_rpcs = num_rpcs request.timeout_sec = timeout_sec rpc_timeout = timeout_sec + _CONNECTION_TIMEOUT_SEC - logger.debug('Invoking GetClientStats RPC to %s:%d:', host, - args.stats_port) - response = stub.GetClientStats(request, - wait_for_ready=True, - timeout=rpc_timeout) - logger.debug('Invoked GetClientStats RPC to %s: %s', host, - json_format.MessageToJson(response)) + logger.debug( + "Invoking GetClientStats RPC to %s:%d:", host, args.stats_port + ) + response = stub.GetClientStats( + request, wait_for_ready=True, timeout=rpc_timeout + ) + logger.debug( + "Invoked GetClientStats RPC to %s: %s", + host, + json_format.MessageToJson(response), + ) return response @@ -383,18 +451,26 @@ def get_client_accumulated_stats(): if CLIENT_HOSTS: hosts = CLIENT_HOSTS else: - hosts = ['localhost'] + hosts = ["localhost"] for host in hosts: - with grpc.insecure_channel('%s:%d' % - (host, args.stats_port)) as channel: + with grpc.insecure_channel( + "%s:%d" % (host, args.stats_port) + ) as channel: stub = test_pb2_grpc.LoadBalancerStatsServiceStub(channel) request = messages_pb2.LoadBalancerAccumulatedStatsRequest() - logger.debug('Invoking GetClientAccumulatedStats RPC to %s:%d:', - host, args.stats_port) + logger.debug( + "Invoking GetClientAccumulatedStats RPC to %s:%d:", + host, + args.stats_port, + ) response = stub.GetClientAccumulatedStats( - request, wait_for_ready=True, timeout=_CONNECTION_TIMEOUT_SEC) - logger.debug('Invoked GetClientAccumulatedStats RPC to %s: %s', - host, response) + request, wait_for_ready=True, timeout=_CONNECTION_TIMEOUT_SEC + ) + logger.debug( + "Invoked GetClientAccumulatedStats RPC to %s: %s", + host, + response, + ) return response @@ -402,36 +478,43 @@ def get_client_xds_config_dump(): if CLIENT_HOSTS: hosts = 
CLIENT_HOSTS else: - hosts = ['localhost'] + hosts = ["localhost"] for host in hosts: - server_address = '%s:%d' % (host, args.stats_port) + server_address = "%s:%d" % (host, args.stats_port) with grpc.insecure_channel(server_address) as channel: stub = csds_pb2_grpc.ClientStatusDiscoveryServiceStub(channel) - logger.debug('Fetching xDS config dump from %s', server_address) - response = stub.FetchClientStatus(csds_pb2.ClientStatusRequest(), - wait_for_ready=True, - timeout=_CONNECTION_TIMEOUT_SEC) - logger.debug('Fetched xDS config dump from %s', server_address) + logger.debug("Fetching xDS config dump from %s", server_address) + response = stub.FetchClientStatus( + csds_pb2.ClientStatusRequest(), + wait_for_ready=True, + timeout=_CONNECTION_TIMEOUT_SEC, + ) + logger.debug("Fetched xDS config dump from %s", server_address) if len(response.config) != 1: - logger.error('Unexpected number of ClientConfigs %d: %s', - len(response.config), response) + logger.error( + "Unexpected number of ClientConfigs %d: %s", + len(response.config), + response, + ) return None else: # Converting the ClientStatusResponse into JSON, because many # fields are packed in google.protobuf.Any. It will require many # duplicated code to unpack proto message and inspect values. return json_format.MessageToDict( - response.config[0], preserving_proto_field_name=True) + response.config[0], preserving_proto_field_name=True + ) def configure_client(rpc_types, metadata=[], timeout_sec=None): if CLIENT_HOSTS: hosts = CLIENT_HOSTS else: - hosts = ['localhost'] + hosts = ["localhost"] for host in hosts: - with grpc.insecure_channel('%s:%d' % - (host, args.stats_port)) as channel: + with grpc.insecure_channel( + "%s:%d" % (host, args.stats_port) + ) as channel: stub = test_pb2_grpc.XdsUpdateClientConfigureServiceStub(channel) request = messages_pb2.ClientConfigureRequest() request.types.extend(rpc_types) @@ -443,58 +526,63 @@ def configure_client(rpc_types, metadata=[], timeout_sec=None): if timeout_sec: request.timeout_sec = timeout_sec logger.debug( - 'Invoking XdsUpdateClientConfigureService RPC to %s:%d: %s', - host, args.stats_port, request) - stub.Configure(request, - wait_for_ready=True, - timeout=_CONNECTION_TIMEOUT_SEC) - logger.debug('Invoked XdsUpdateClientConfigureService RPC to %s', - host) + "Invoking XdsUpdateClientConfigureService RPC to %s:%d: %s", + host, + args.stats_port, + request, + ) + stub.Configure( + request, wait_for_ready=True, timeout=_CONNECTION_TIMEOUT_SEC + ) + logger.debug( + "Invoked XdsUpdateClientConfigureService RPC to %s", host + ) class RpcDistributionError(Exception): pass -def _verify_rpcs_to_given_backends(backends, timeout_sec, num_rpcs, - allow_failures): +def _verify_rpcs_to_given_backends( + backends, timeout_sec, num_rpcs, allow_failures +): start_time = time.time() error_msg = None - logger.debug('Waiting for %d sec until backends %s receive load' % - (timeout_sec, backends)) + logger.debug( + "Waiting for %d sec until backends %s receive load" + % (timeout_sec, backends) + ) while time.time() - start_time <= timeout_sec: error_msg = None stats = get_client_stats(num_rpcs, timeout_sec) rpcs_by_peer = stats.rpcs_by_peer for backend in backends: if backend not in rpcs_by_peer: - error_msg = 'Backend %s did not receive load' % backend + error_msg = "Backend %s did not receive load" % backend break if not error_msg and len(rpcs_by_peer) > len(backends): - error_msg = 'Unexpected backend received load: %s' % rpcs_by_peer + error_msg = "Unexpected backend received load: %s" % 
rpcs_by_peer if not allow_failures and stats.num_failures > 0: - error_msg = '%d RPCs failed' % stats.num_failures + error_msg = "%d RPCs failed" % stats.num_failures if not error_msg: return raise RpcDistributionError(error_msg) -def wait_until_all_rpcs_go_to_given_backends_or_fail(backends, - timeout_sec, - num_rpcs=_NUM_TEST_RPCS): - _verify_rpcs_to_given_backends(backends, - timeout_sec, - num_rpcs, - allow_failures=True) +def wait_until_all_rpcs_go_to_given_backends_or_fail( + backends, timeout_sec, num_rpcs=_NUM_TEST_RPCS +): + _verify_rpcs_to_given_backends( + backends, timeout_sec, num_rpcs, allow_failures=True + ) -def wait_until_all_rpcs_go_to_given_backends(backends, - timeout_sec, - num_rpcs=_NUM_TEST_RPCS): - _verify_rpcs_to_given_backends(backends, - timeout_sec, - num_rpcs, - allow_failures=False) +def wait_until_all_rpcs_go_to_given_backends( + backends, timeout_sec, num_rpcs=_NUM_TEST_RPCS +): + _verify_rpcs_to_given_backends( + backends, timeout_sec, num_rpcs, allow_failures=False + ) def wait_until_no_rpcs_go_to_given_backends(backends, timeout_sec): @@ -505,15 +593,15 @@ def wait_until_no_rpcs_go_to_given_backends(backends, timeout_sec): rpcs_by_peer = stats.rpcs_by_peer for backend in backends: if backend in rpcs_by_peer: - error_msg = 'Unexpected backend %s receives load' % backend + error_msg = "Unexpected backend %s receives load" % backend break if not error_msg: return - raise Exception('Unexpected RPCs going to given backends') + raise Exception("Unexpected RPCs going to given backends") def wait_until_rpcs_in_flight(rpc_type, timeout_sec, num_rpcs, threshold): - '''Block until the test client reaches the state with the given number + """Block until the test client reaches the state with the given number of RPCs being outstanding stably. Args: @@ -524,31 +612,35 @@ def wait_until_rpcs_in_flight(rpc_type, timeout_sec, num_rpcs, threshold): num_rpcs: Expected number of RPCs to be in-flight. threshold: Number within [0,100], the tolerable percentage by which the actual number of RPCs in-flight can differ from the expected number. - ''' + """ if threshold < 0 or threshold > 100: - raise ValueError('Value error: Threshold should be between 0 to 100') + raise ValueError("Value error: Threshold should be between 0 to 100") threshold_fraction = threshold / 100.0 start_time = time.time() error_msg = None logger.debug( - 'Waiting for %d sec until %d %s RPCs (with %d%% tolerance) in-flight' % - (timeout_sec, num_rpcs, rpc_type, threshold)) + "Waiting for %d sec until %d %s RPCs (with %d%% tolerance) in-flight" + % (timeout_sec, num_rpcs, rpc_type, threshold) + ) while time.time() - start_time <= timeout_sec: - error_msg = _check_rpcs_in_flight(rpc_type, num_rpcs, threshold, - threshold_fraction) + error_msg = _check_rpcs_in_flight( + rpc_type, num_rpcs, threshold, threshold_fraction + ) if error_msg: - logger.debug('Progress: %s', error_msg) + logger.debug("Progress: %s", error_msg) time.sleep(2) else: break # Ensure the number of outstanding RPCs is stable. 
if not error_msg: time.sleep(5) - error_msg = _check_rpcs_in_flight(rpc_type, num_rpcs, threshold, - threshold_fraction) + error_msg = _check_rpcs_in_flight( + rpc_type, num_rpcs, threshold, threshold_fraction + ) if error_msg: - raise Exception("Wrong number of %s RPCs in-flight: %s" % - (rpc_type, error_msg)) + raise Exception( + "Wrong number of %s RPCs in-flight: %s" % (rpc_type, error_msg) + ) def _check_rpcs_in_flight(rpc_type, num_rpcs, threshold, threshold_fraction): @@ -559,16 +651,23 @@ def _check_rpcs_in_flight(rpc_type, num_rpcs, threshold, threshold_fraction): rpcs_failed = stats.num_rpcs_failed_by_method[rpc_type] rpcs_in_flight = rpcs_started - rpcs_succeeded - rpcs_failed if rpcs_in_flight < (num_rpcs * (1 - threshold_fraction)): - error_msg = ('actual(%d) < expected(%d - %d%%)' % - (rpcs_in_flight, num_rpcs, threshold)) + error_msg = "actual(%d) < expected(%d - %d%%)" % ( + rpcs_in_flight, + num_rpcs, + threshold, + ) elif rpcs_in_flight > (num_rpcs * (1 + threshold_fraction)): - error_msg = ('actual(%d) > expected(%d + %d%%)' % - (rpcs_in_flight, num_rpcs, threshold)) + error_msg = "actual(%d) > expected(%d + %d%%)" % ( + rpcs_in_flight, + num_rpcs, + threshold, + ) return error_msg -def compare_distributions(actual_distribution, expected_distribution, - threshold): +def compare_distributions( + actual_distribution, expected_distribution, threshold +): """Compare if two distributions are similar. Args: @@ -588,18 +687,21 @@ def compare_distributions(actual_distribution, expected_distribution, """ if len(expected_distribution) != len(actual_distribution): raise Exception( - 'Error: expected and actual distributions have different size (%d vs %d)' - % (len(expected_distribution), len(actual_distribution))) + "Error: expected and actual distributions have different size (%d" + " vs %d)" % (len(expected_distribution), len(actual_distribution)) + ) if threshold < 0 or threshold > 100: - raise ValueError('Value error: Threshold should be between 0 to 100') + raise ValueError("Value error: Threshold should be between 0 to 100") threshold_fraction = threshold / 100.0 for expected, actual in zip(expected_distribution, actual_distribution): if actual < (expected * (1 - threshold_fraction)): - raise Exception("actual(%f) < expected(%f-%d%%)" % - (actual, expected, threshold)) + raise Exception( + "actual(%f) < expected(%f-%d%%)" % (actual, expected, threshold) + ) if actual > (expected * (1 + threshold_fraction)): - raise Exception("actual(%f) > expected(%f+%d%%)" % - (actual, expected, threshold)) + raise Exception( + "actual(%f) > expected(%f+%d%%)" % (actual, expected, threshold) + ) return True @@ -616,54 +718,74 @@ def compare_expected_instances(stats, expected_instances): """ for rpc_type, expected_peers in list(expected_instances.items()): rpcs_by_peer_for_type = stats.rpcs_by_method[rpc_type] - rpcs_by_peer = rpcs_by_peer_for_type.rpcs_by_peer if rpcs_by_peer_for_type else None - logger.debug('rpc: %s, by_peer: %s', rpc_type, rpcs_by_peer) + rpcs_by_peer = ( + rpcs_by_peer_for_type.rpcs_by_peer + if rpcs_by_peer_for_type + else None + ) + logger.debug("rpc: %s, by_peer: %s", rpc_type, rpcs_by_peer) peers = list(rpcs_by_peer.keys()) if set(peers) != set(expected_peers): - logger.info('unexpected peers for %s, got %s, want %s', rpc_type, - peers, expected_peers) + logger.info( + "unexpected peers for %s, got %s, want %s", + rpc_type, + peers, + expected_peers, + ) return False return True def test_backends_restart(gcp, backend_service, instance_group): - logger.info('Running 
test_backends_restart') + logger.info("Running test_backends_restart") instance_names = get_instance_names(gcp, instance_group) num_instances = len(instance_names) start_time = time.time() - wait_until_all_rpcs_go_to_given_backends(instance_names, - _WAIT_FOR_STATS_SEC) + wait_until_all_rpcs_go_to_given_backends( + instance_names, _WAIT_FOR_STATS_SEC + ) try: resize_instance_group(gcp, instance_group, 0) - wait_until_all_rpcs_go_to_given_backends_or_fail([], - _WAIT_FOR_BACKEND_SEC) + wait_until_all_rpcs_go_to_given_backends_or_fail( + [], _WAIT_FOR_BACKEND_SEC + ) finally: resize_instance_group(gcp, instance_group, num_instances) wait_for_healthy_backends(gcp, backend_service, instance_group) new_instance_names = get_instance_names(gcp, instance_group) - wait_until_all_rpcs_go_to_given_backends(new_instance_names, - _WAIT_FOR_BACKEND_SEC) + wait_until_all_rpcs_go_to_given_backends( + new_instance_names, _WAIT_FOR_BACKEND_SEC + ) -def test_change_backend_service(gcp, original_backend_service, instance_group, - alternate_backend_service, - same_zone_instance_group): - logger.info('Running test_change_backend_service') +def test_change_backend_service( + gcp, + original_backend_service, + instance_group, + alternate_backend_service, + same_zone_instance_group, +): + logger.info("Running test_change_backend_service") original_backend_instances = get_instance_names(gcp, instance_group) - alternate_backend_instances = get_instance_names(gcp, - same_zone_instance_group) - patch_backend_service(gcp, alternate_backend_service, - [same_zone_instance_group]) + alternate_backend_instances = get_instance_names( + gcp, same_zone_instance_group + ) + patch_backend_service( + gcp, alternate_backend_service, [same_zone_instance_group] + ) wait_for_healthy_backends(gcp, original_backend_service, instance_group) - wait_for_healthy_backends(gcp, alternate_backend_service, - same_zone_instance_group) - wait_until_all_rpcs_go_to_given_backends(original_backend_instances, - _WAIT_FOR_STATS_SEC) + wait_for_healthy_backends( + gcp, alternate_backend_service, same_zone_instance_group + ) + wait_until_all_rpcs_go_to_given_backends( + original_backend_instances, _WAIT_FOR_STATS_SEC + ) passed = True try: patch_url_map_backend_service(gcp, alternate_backend_service) - wait_until_all_rpcs_go_to_given_backends(alternate_backend_instances, - _WAIT_FOR_URL_MAP_PATCH_SEC) + wait_until_all_rpcs_go_to_given_backends( + alternate_backend_instances, _WAIT_FOR_URL_MAP_PATCH_SEC + ) except Exception: passed = False raise @@ -673,52 +795,64 @@ def test_change_backend_service(gcp, original_backend_service, instance_group, patch_backend_service(gcp, alternate_backend_service, []) -def test_gentle_failover(gcp, - backend_service, - primary_instance_group, - secondary_instance_group, - swapped_primary_and_secondary=False): - logger.info('Running test_gentle_failover') +def test_gentle_failover( + gcp, + backend_service, + primary_instance_group, + secondary_instance_group, + swapped_primary_and_secondary=False, +): + logger.info("Running test_gentle_failover") num_primary_instances = len(get_instance_names(gcp, primary_instance_group)) min_instances_for_gentle_failover = 3 # Need >50% failure to start failover passed = True try: if num_primary_instances < min_instances_for_gentle_failover: - resize_instance_group(gcp, primary_instance_group, - min_instances_for_gentle_failover) + resize_instance_group( + gcp, primary_instance_group, min_instances_for_gentle_failover + ) patch_backend_service( - gcp, backend_service, - 
[primary_instance_group, secondary_instance_group]) + gcp, + backend_service, + [primary_instance_group, secondary_instance_group], + ) primary_instance_names = get_instance_names(gcp, primary_instance_group) - secondary_instance_names = get_instance_names(gcp, - secondary_instance_group) + secondary_instance_names = get_instance_names( + gcp, secondary_instance_group + ) wait_for_healthy_backends(gcp, backend_service, primary_instance_group) - wait_for_healthy_backends(gcp, backend_service, - secondary_instance_group) - wait_until_all_rpcs_go_to_given_backends(primary_instance_names, - _WAIT_FOR_STATS_SEC) + wait_for_healthy_backends( + gcp, backend_service, secondary_instance_group + ) + wait_until_all_rpcs_go_to_given_backends( + primary_instance_names, _WAIT_FOR_STATS_SEC + ) instances_to_stop = primary_instance_names[:-1] remaining_instances = primary_instance_names[-1:] try: - set_serving_status(instances_to_stop, - gcp.service_port, - serving=False) + set_serving_status( + instances_to_stop, gcp.service_port, serving=False + ) wait_until_all_rpcs_go_to_given_backends( remaining_instances + secondary_instance_names, - _WAIT_FOR_BACKEND_SEC) + _WAIT_FOR_BACKEND_SEC, + ) finally: - set_serving_status(primary_instance_names, - gcp.service_port, - serving=True) + set_serving_status( + primary_instance_names, gcp.service_port, serving=True + ) except RpcDistributionError as e: if not swapped_primary_and_secondary and is_primary_instance_group( - gcp, secondary_instance_group): + gcp, secondary_instance_group + ): # Swap expectation of primary and secondary instance groups. - test_gentle_failover(gcp, - backend_service, - secondary_instance_group, - primary_instance_group, - swapped_primary_and_secondary=True) + test_gentle_failover( + gcp, + backend_service, + secondary_instance_group, + primary_instance_group, + swapped_primary_and_secondary=True, + ) else: passed = False raise e @@ -727,98 +861,123 @@ def test_gentle_failover(gcp, raise finally: if passed or not args.halt_after_fail: - patch_backend_service(gcp, backend_service, - [primary_instance_group]) - resize_instance_group(gcp, primary_instance_group, - num_primary_instances) + patch_backend_service( + gcp, backend_service, [primary_instance_group] + ) + resize_instance_group( + gcp, primary_instance_group, num_primary_instances + ) instance_names = get_instance_names(gcp, primary_instance_group) - wait_until_all_rpcs_go_to_given_backends(instance_names, - _WAIT_FOR_BACKEND_SEC) + wait_until_all_rpcs_go_to_given_backends( + instance_names, _WAIT_FOR_BACKEND_SEC + ) -def test_load_report_based_failover(gcp, backend_service, - primary_instance_group, - secondary_instance_group): - logger.info('Running test_load_report_based_failover') +def test_load_report_based_failover( + gcp, backend_service, primary_instance_group, secondary_instance_group +): + logger.info("Running test_load_report_based_failover") passed = True try: patch_backend_service( - gcp, backend_service, - [primary_instance_group, secondary_instance_group]) + gcp, + backend_service, + [primary_instance_group, secondary_instance_group], + ) primary_instance_names = get_instance_names(gcp, primary_instance_group) - secondary_instance_names = get_instance_names(gcp, - secondary_instance_group) + secondary_instance_names = get_instance_names( + gcp, secondary_instance_group + ) wait_for_healthy_backends(gcp, backend_service, primary_instance_group) - wait_for_healthy_backends(gcp, backend_service, - secondary_instance_group) - 
wait_until_all_rpcs_go_to_given_backends(primary_instance_names, - _WAIT_FOR_STATS_SEC) + wait_for_healthy_backends( + gcp, backend_service, secondary_instance_group + ) + wait_until_all_rpcs_go_to_given_backends( + primary_instance_names, _WAIT_FOR_STATS_SEC + ) # Set primary locality's balance mode to RATE, and RPS to 20% of the # client's QPS. The secondary locality will be used. max_rate = int(args.qps * 1 / 5) - logger.info('Patching backend service to RATE with %d max_rate', - max_rate) + logger.info( + "Patching backend service to RATE with %d max_rate", max_rate + ) patch_backend_service( gcp, - backend_service, [primary_instance_group, secondary_instance_group], - balancing_mode='RATE', - max_rate=max_rate) + backend_service, + [primary_instance_group, secondary_instance_group], + balancing_mode="RATE", + max_rate=max_rate, + ) wait_until_all_rpcs_go_to_given_backends( primary_instance_names + secondary_instance_names, - _WAIT_FOR_BACKEND_SEC) + _WAIT_FOR_BACKEND_SEC, + ) # Set primary locality's balance mode to RATE, and RPS to 120% of the # client's QPS. Only the primary locality will be used. max_rate = int(args.qps * 6 / 5) - logger.info('Patching backend service to RATE with %d max_rate', - max_rate) + logger.info( + "Patching backend service to RATE with %d max_rate", max_rate + ) patch_backend_service( gcp, - backend_service, [primary_instance_group, secondary_instance_group], - balancing_mode='RATE', - max_rate=max_rate) - wait_until_all_rpcs_go_to_given_backends(primary_instance_names, - _WAIT_FOR_BACKEND_SEC) + backend_service, + [primary_instance_group, secondary_instance_group], + balancing_mode="RATE", + max_rate=max_rate, + ) + wait_until_all_rpcs_go_to_given_backends( + primary_instance_names, _WAIT_FOR_BACKEND_SEC + ) logger.info("success") except Exception: passed = False raise finally: if passed or not args.halt_after_fail: - patch_backend_service(gcp, backend_service, - [primary_instance_group]) + patch_backend_service( + gcp, backend_service, [primary_instance_group] + ) instance_names = get_instance_names(gcp, primary_instance_group) - wait_until_all_rpcs_go_to_given_backends(instance_names, - _WAIT_FOR_BACKEND_SEC) + wait_until_all_rpcs_go_to_given_backends( + instance_names, _WAIT_FOR_BACKEND_SEC + ) def test_ping_pong(gcp, backend_service, instance_group): - logger.info('Running test_ping_pong') + logger.info("Running test_ping_pong") wait_for_healthy_backends(gcp, backend_service, instance_group) instance_names = get_instance_names(gcp, instance_group) - wait_until_all_rpcs_go_to_given_backends(instance_names, - _WAIT_FOR_STATS_SEC) + wait_until_all_rpcs_go_to_given_backends( + instance_names, _WAIT_FOR_STATS_SEC + ) -def test_remove_instance_group(gcp, backend_service, instance_group, - same_zone_instance_group): - logger.info('Running test_remove_instance_group') +def test_remove_instance_group( + gcp, backend_service, instance_group, same_zone_instance_group +): + logger.info("Running test_remove_instance_group") passed = True try: - patch_backend_service(gcp, - backend_service, - [instance_group, same_zone_instance_group], - balancing_mode='RATE') + patch_backend_service( + gcp, + backend_service, + [instance_group, same_zone_instance_group], + balancing_mode="RATE", + ) wait_for_healthy_backends(gcp, backend_service, instance_group) - wait_for_healthy_backends(gcp, backend_service, - same_zone_instance_group) + wait_for_healthy_backends( + gcp, backend_service, same_zone_instance_group + ) instance_names = get_instance_names(gcp, instance_group) - 
same_zone_instance_names = get_instance_names(gcp, - same_zone_instance_group) + same_zone_instance_names = get_instance_names( + gcp, same_zone_instance_group + ) try: wait_until_all_rpcs_go_to_given_backends( instance_names + same_zone_instance_names, - _WAIT_FOR_OPERATION_SEC) + _WAIT_FOR_OPERATION_SEC, + ) remaining_instance_group = same_zone_instance_group remaining_instance_names = same_zone_instance_names except RpcDistributionError as e: @@ -827,36 +986,44 @@ def test_remove_instance_group(gcp, backend_service, instance_group, # with the remainder of the test case. try: wait_until_all_rpcs_go_to_given_backends( - instance_names, _WAIT_FOR_STATS_SEC) + instance_names, _WAIT_FOR_STATS_SEC + ) remaining_instance_group = same_zone_instance_group remaining_instance_names = same_zone_instance_names except RpcDistributionError as e: wait_until_all_rpcs_go_to_given_backends( - same_zone_instance_names, _WAIT_FOR_STATS_SEC) + same_zone_instance_names, _WAIT_FOR_STATS_SEC + ) remaining_instance_group = instance_group remaining_instance_names = instance_names - patch_backend_service(gcp, - backend_service, [remaining_instance_group], - balancing_mode='RATE') - wait_until_all_rpcs_go_to_given_backends(remaining_instance_names, - _WAIT_FOR_BACKEND_SEC) + patch_backend_service( + gcp, + backend_service, + [remaining_instance_group], + balancing_mode="RATE", + ) + wait_until_all_rpcs_go_to_given_backends( + remaining_instance_names, _WAIT_FOR_BACKEND_SEC + ) except Exception: passed = False raise finally: if passed or not args.halt_after_fail: patch_backend_service(gcp, backend_service, [instance_group]) - wait_until_all_rpcs_go_to_given_backends(instance_names, - _WAIT_FOR_BACKEND_SEC) + wait_until_all_rpcs_go_to_given_backends( + instance_names, _WAIT_FOR_BACKEND_SEC + ) def test_round_robin(gcp, backend_service, instance_group): - logger.info('Running test_round_robin') + logger.info("Running test_round_robin") wait_for_healthy_backends(gcp, backend_service, instance_group) instance_names = get_instance_names(gcp, instance_group) threshold = 1 - wait_until_all_rpcs_go_to_given_backends(instance_names, - _WAIT_FOR_STATS_SEC) + wait_until_all_rpcs_go_to_given_backends( + instance_names, _WAIT_FOR_STATS_SEC + ) # TODO(ericgribkoff) Delayed config propagation from earlier tests # may result in briefly receiving an empty EDS update, resulting in failed # RPCs. 
Retry distribution validation if this occurs; long-term fix is @@ -869,294 +1036,362 @@ def test_round_robin(gcp, backend_service, instance_group): requests_received = [stats.rpcs_by_peer[x] for x in stats.rpcs_by_peer] total_requests_received = sum(requests_received) if total_requests_received != _NUM_TEST_RPCS: - logger.info('Unexpected RPC failures, retrying: %s', stats) + logger.info("Unexpected RPC failures, retrying: %s", stats) continue expected_requests = total_requests_received / len(instance_names) for instance in instance_names: - if abs(stats.rpcs_by_peer[instance] - - expected_requests) > threshold: + if ( + abs(stats.rpcs_by_peer[instance] - expected_requests) + > threshold + ): raise Exception( - 'RPC peer distribution differs from expected by more than %d ' - 'for instance %s (%s)' % (threshold, instance, stats)) + "RPC peer distribution differs from expected by more than" + " %d for instance %s (%s)" % (threshold, instance, stats) + ) return - raise Exception('RPC failures persisted through %d retries' % max_attempts) + raise Exception("RPC failures persisted through %d retries" % max_attempts) def test_secondary_locality_gets_no_requests_on_partial_primary_failure( - gcp, - backend_service, - primary_instance_group, - secondary_instance_group, - swapped_primary_and_secondary=False): + gcp, + backend_service, + primary_instance_group, + secondary_instance_group, + swapped_primary_and_secondary=False, +): logger.info( - 'Running secondary_locality_gets_no_requests_on_partial_primary_failure' + "Running secondary_locality_gets_no_requests_on_partial_primary_failure" ) passed = True try: patch_backend_service( - gcp, backend_service, - [primary_instance_group, secondary_instance_group]) + gcp, + backend_service, + [primary_instance_group, secondary_instance_group], + ) wait_for_healthy_backends(gcp, backend_service, primary_instance_group) - wait_for_healthy_backends(gcp, backend_service, - secondary_instance_group) + wait_for_healthy_backends( + gcp, backend_service, secondary_instance_group + ) primary_instance_names = get_instance_names(gcp, primary_instance_group) - wait_until_all_rpcs_go_to_given_backends(primary_instance_names, - _WAIT_FOR_STATS_SEC) + wait_until_all_rpcs_go_to_given_backends( + primary_instance_names, _WAIT_FOR_STATS_SEC + ) instances_to_stop = primary_instance_names[:1] remaining_instances = primary_instance_names[1:] try: - set_serving_status(instances_to_stop, - gcp.service_port, - serving=False) - wait_until_all_rpcs_go_to_given_backends(remaining_instances, - _WAIT_FOR_BACKEND_SEC) + set_serving_status( + instances_to_stop, gcp.service_port, serving=False + ) + wait_until_all_rpcs_go_to_given_backends( + remaining_instances, _WAIT_FOR_BACKEND_SEC + ) finally: - set_serving_status(primary_instance_names, - gcp.service_port, - serving=True) + set_serving_status( + primary_instance_names, gcp.service_port, serving=True + ) except RpcDistributionError as e: if not swapped_primary_and_secondary and is_primary_instance_group( - gcp, secondary_instance_group): + gcp, secondary_instance_group + ): # Swap expectation of primary and secondary instance groups. 
test_secondary_locality_gets_no_requests_on_partial_primary_failure( gcp, backend_service, secondary_instance_group, primary_instance_group, - swapped_primary_and_secondary=True) + swapped_primary_and_secondary=True, + ) else: passed = False raise e finally: if passed or not args.halt_after_fail: - patch_backend_service(gcp, backend_service, - [primary_instance_group]) + patch_backend_service( + gcp, backend_service, [primary_instance_group] + ) def test_secondary_locality_gets_requests_on_primary_failure( - gcp, - backend_service, - primary_instance_group, - secondary_instance_group, - swapped_primary_and_secondary=False): - logger.info('Running secondary_locality_gets_requests_on_primary_failure') + gcp, + backend_service, + primary_instance_group, + secondary_instance_group, + swapped_primary_and_secondary=False, +): + logger.info("Running secondary_locality_gets_requests_on_primary_failure") passed = True try: patch_backend_service( - gcp, backend_service, - [primary_instance_group, secondary_instance_group]) + gcp, + backend_service, + [primary_instance_group, secondary_instance_group], + ) wait_for_healthy_backends(gcp, backend_service, primary_instance_group) - wait_for_healthy_backends(gcp, backend_service, - secondary_instance_group) + wait_for_healthy_backends( + gcp, backend_service, secondary_instance_group + ) primary_instance_names = get_instance_names(gcp, primary_instance_group) - secondary_instance_names = get_instance_names(gcp, - secondary_instance_group) - wait_until_all_rpcs_go_to_given_backends(primary_instance_names, - _WAIT_FOR_STATS_SEC) + secondary_instance_names = get_instance_names( + gcp, secondary_instance_group + ) + wait_until_all_rpcs_go_to_given_backends( + primary_instance_names, _WAIT_FOR_STATS_SEC + ) try: - set_serving_status(primary_instance_names, - gcp.service_port, - serving=False) - wait_until_all_rpcs_go_to_given_backends(secondary_instance_names, - _WAIT_FOR_BACKEND_SEC) + set_serving_status( + primary_instance_names, gcp.service_port, serving=False + ) + wait_until_all_rpcs_go_to_given_backends( + secondary_instance_names, _WAIT_FOR_BACKEND_SEC + ) finally: - set_serving_status(primary_instance_names, - gcp.service_port, - serving=True) + set_serving_status( + primary_instance_names, gcp.service_port, serving=True + ) except RpcDistributionError as e: if not swapped_primary_and_secondary and is_primary_instance_group( - gcp, secondary_instance_group): + gcp, secondary_instance_group + ): # Swap expectation of primary and secondary instance groups. test_secondary_locality_gets_requests_on_primary_failure( gcp, backend_service, secondary_instance_group, primary_instance_group, - swapped_primary_and_secondary=True) + swapped_primary_and_secondary=True, + ) else: passed = False raise e finally: if passed or not args.halt_after_fail: - patch_backend_service(gcp, backend_service, - [primary_instance_group]) + patch_backend_service( + gcp, backend_service, [primary_instance_group] + ) -def prepare_services_for_urlmap_tests(gcp, original_backend_service, - instance_group, alternate_backend_service, - same_zone_instance_group): - ''' +def prepare_services_for_urlmap_tests( + gcp, + original_backend_service, + instance_group, + alternate_backend_service, + same_zone_instance_group, +): + """ This function prepares the services to be ready for tests that modifies urlmaps. Returns: Returns original and alternate backend names as lists of strings. 
- ''' - logger.info('waiting for original backends to become healthy') + """ + logger.info("waiting for original backends to become healthy") wait_for_healthy_backends(gcp, original_backend_service, instance_group) - patch_backend_service(gcp, alternate_backend_service, - [same_zone_instance_group]) - logger.info('waiting for alternate to become healthy') - wait_for_healthy_backends(gcp, alternate_backend_service, - same_zone_instance_group) + patch_backend_service( + gcp, alternate_backend_service, [same_zone_instance_group] + ) + logger.info("waiting for alternate to become healthy") + wait_for_healthy_backends( + gcp, alternate_backend_service, same_zone_instance_group + ) original_backend_instances = get_instance_names(gcp, instance_group) - logger.info('original backends instances: %s', original_backend_instances) + logger.info("original backends instances: %s", original_backend_instances) - alternate_backend_instances = get_instance_names(gcp, - same_zone_instance_group) - logger.info('alternate backends instances: %s', alternate_backend_instances) + alternate_backend_instances = get_instance_names( + gcp, same_zone_instance_group + ) + logger.info("alternate backends instances: %s", alternate_backend_instances) # Start with all traffic going to original_backend_service. - logger.info('waiting for traffic to all go to original backends') - wait_until_all_rpcs_go_to_given_backends(original_backend_instances, - _WAIT_FOR_STATS_SEC) + logger.info("waiting for traffic to all go to original backends") + wait_until_all_rpcs_go_to_given_backends( + original_backend_instances, _WAIT_FOR_STATS_SEC + ) return original_backend_instances, alternate_backend_instances -def test_metadata_filter(gcp, original_backend_service, instance_group, - alternate_backend_service, same_zone_instance_group): +def test_metadata_filter( + gcp, + original_backend_service, + instance_group, + alternate_backend_service, + same_zone_instance_group, +): logger.info("Running test_metadata_filter") wait_for_healthy_backends(gcp, original_backend_service, instance_group) original_backend_instances = get_instance_names(gcp, instance_group) - alternate_backend_instances = get_instance_names(gcp, - same_zone_instance_group) - patch_backend_service(gcp, alternate_backend_service, - [same_zone_instance_group]) - wait_for_healthy_backends(gcp, alternate_backend_service, - same_zone_instance_group) + alternate_backend_instances = get_instance_names( + gcp, same_zone_instance_group + ) + patch_backend_service( + gcp, alternate_backend_service, [same_zone_instance_group] + ) + wait_for_healthy_backends( + gcp, alternate_backend_service, same_zone_instance_group + ) passed = True try: with open(bootstrap_path) as f: - md = json.load(f)['node']['metadata'] + md = json.load(f)["node"]["metadata"] match_labels = [] for k, v in list(md.items()): - match_labels.append({'name': k, 'value': v}) + match_labels.append({"name": k, "value": v}) - not_match_labels = [{'name': 'fake', 'value': 'fail'}] + not_match_labels = [{"name": "fake", "value": "fail"}] test_route_rules = [ # test MATCH_ALL [ { - 'priority': 0, - 'matchRules': [{ - 'prefixMatch': - '/', - 'metadataFilters': [{ - 'filterMatchCriteria': 'MATCH_ALL', - 'filterLabels': not_match_labels - }] - }], - 'service': original_backend_service.url + "priority": 0, + "matchRules": [ + { + "prefixMatch": "/", + "metadataFilters": [ + { + "filterMatchCriteria": "MATCH_ALL", + "filterLabels": not_match_labels, + } + ], + } + ], + "service": original_backend_service.url, }, { - 'priority': 
1, - 'matchRules': [{ - 'prefixMatch': - '/', - 'metadataFilters': [{ - 'filterMatchCriteria': 'MATCH_ALL', - 'filterLabels': match_labels - }] - }], - 'service': alternate_backend_service.url + "priority": 1, + "matchRules": [ + { + "prefixMatch": "/", + "metadataFilters": [ + { + "filterMatchCriteria": "MATCH_ALL", + "filterLabels": match_labels, + } + ], + } + ], + "service": alternate_backend_service.url, }, ], # test mixing MATCH_ALL and MATCH_ANY # test MATCH_ALL: super set labels won't match [ { - 'priority': 0, - 'matchRules': [{ - 'prefixMatch': - '/', - 'metadataFilters': [{ - 'filterMatchCriteria': 'MATCH_ALL', - 'filterLabels': not_match_labels + match_labels - }] - }], - 'service': original_backend_service.url + "priority": 0, + "matchRules": [ + { + "prefixMatch": "/", + "metadataFilters": [ + { + "filterMatchCriteria": "MATCH_ALL", + "filterLabels": not_match_labels + + match_labels, + } + ], + } + ], + "service": original_backend_service.url, }, { - 'priority': 1, - 'matchRules': [{ - 'prefixMatch': - '/', - 'metadataFilters': [{ - 'filterMatchCriteria': 'MATCH_ANY', - 'filterLabels': not_match_labels + match_labels - }] - }], - 'service': alternate_backend_service.url + "priority": 1, + "matchRules": [ + { + "prefixMatch": "/", + "metadataFilters": [ + { + "filterMatchCriteria": "MATCH_ANY", + "filterLabels": not_match_labels + + match_labels, + } + ], + } + ], + "service": alternate_backend_service.url, }, ], # test MATCH_ANY [ { - 'priority': 0, - 'matchRules': [{ - 'prefixMatch': - '/', - 'metadataFilters': [{ - 'filterMatchCriteria': 'MATCH_ANY', - 'filterLabels': not_match_labels - }] - }], - 'service': original_backend_service.url + "priority": 0, + "matchRules": [ + { + "prefixMatch": "/", + "metadataFilters": [ + { + "filterMatchCriteria": "MATCH_ANY", + "filterLabels": not_match_labels, + } + ], + } + ], + "service": original_backend_service.url, }, { - 'priority': 1, - 'matchRules': [{ - 'prefixMatch': - '/', - 'metadataFilters': [{ - 'filterMatchCriteria': 'MATCH_ANY', - 'filterLabels': not_match_labels + match_labels - }] - }], - 'service': alternate_backend_service.url + "priority": 1, + "matchRules": [ + { + "prefixMatch": "/", + "metadataFilters": [ + { + "filterMatchCriteria": "MATCH_ANY", + "filterLabels": not_match_labels + + match_labels, + } + ], + } + ], + "service": alternate_backend_service.url, }, ], # test match multiple route rules [ { - 'priority': 0, - 'matchRules': [{ - 'prefixMatch': - '/', - 'metadataFilters': [{ - 'filterMatchCriteria': 'MATCH_ANY', - 'filterLabels': match_labels - }] - }], - 'service': alternate_backend_service.url + "priority": 0, + "matchRules": [ + { + "prefixMatch": "/", + "metadataFilters": [ + { + "filterMatchCriteria": "MATCH_ANY", + "filterLabels": match_labels, + } + ], + } + ], + "service": alternate_backend_service.url, }, { - 'priority': 1, - 'matchRules': [{ - 'prefixMatch': - '/', - 'metadataFilters': [{ - 'filterMatchCriteria': 'MATCH_ALL', - 'filterLabels': match_labels - }] - }], - 'service': original_backend_service.url + "priority": 1, + "matchRules": [ + { + "prefixMatch": "/", + "metadataFilters": [ + { + "filterMatchCriteria": "MATCH_ALL", + "filterLabels": match_labels, + } + ], + } + ], + "service": original_backend_service.url, }, - ] + ], ] for route_rules in test_route_rules: - wait_until_all_rpcs_go_to_given_backends(original_backend_instances, - _WAIT_FOR_STATS_SEC) - patch_url_map_backend_service(gcp, - original_backend_service, - route_rules=route_rules) - 
wait_until_no_rpcs_go_to_given_backends(original_backend_instances, - _WAIT_FOR_STATS_SEC) wait_until_all_rpcs_go_to_given_backends( - alternate_backend_instances, _WAIT_FOR_STATS_SEC) + original_backend_instances, _WAIT_FOR_STATS_SEC + ) + patch_url_map_backend_service( + gcp, original_backend_service, route_rules=route_rules + ) + wait_until_no_rpcs_go_to_given_backends( + original_backend_instances, _WAIT_FOR_STATS_SEC + ) + wait_until_all_rpcs_go_to_given_backends( + alternate_backend_instances, _WAIT_FOR_STATS_SEC + ) patch_url_map_backend_service(gcp, original_backend_service) except Exception: passed = False @@ -1166,56 +1401,74 @@ def test_metadata_filter(gcp, original_backend_service, instance_group, patch_backend_service(gcp, alternate_backend_service, []) -def test_api_listener(gcp, backend_service, instance_group, - alternate_backend_service): +def test_api_listener( + gcp, backend_service, instance_group, alternate_backend_service +): logger.info("Running api_listener") passed = True try: wait_for_healthy_backends(gcp, backend_service, instance_group) backend_instances = get_instance_names(gcp, instance_group) - wait_until_all_rpcs_go_to_given_backends(backend_instances, - _WAIT_FOR_STATS_SEC) + wait_until_all_rpcs_go_to_given_backends( + backend_instances, _WAIT_FOR_STATS_SEC + ) # create a second suite of map+tp+fr with the same host name in host rule # and we have to disable proxyless validation because it needs `0.0.0.0` # ip address in fr for proxyless and also we violate ip:port uniqueness # for test purpose. See https://github.com/grpc/grpc-java/issues/8009 - new_config_suffix = '2' - url_map_2 = create_url_map(gcp, url_map_name + new_config_suffix, - backend_service, service_host_name) + new_config_suffix = "2" + url_map_2 = create_url_map( + gcp, + url_map_name + new_config_suffix, + backend_service, + service_host_name, + ) target_proxy_2 = create_target_proxy( - gcp, target_proxy_name + new_config_suffix, False, url_map_2) + gcp, target_proxy_name + new_config_suffix, False, url_map_2 + ) if not gcp.service_port: raise Exception( - 'Faied to find a valid port for the forwarding rule') + "Faied to find a valid port for the forwarding rule" + ) potential_ip_addresses = [] max_attempts = 10 for i in range(max_attempts): - potential_ip_addresses.append('10.10.10.%d' % - (random.randint(0, 255))) - create_global_forwarding_rule(gcp, - forwarding_rule_name + new_config_suffix, - [gcp.service_port], - potential_ip_addresses, target_proxy_2) + potential_ip_addresses.append( + "10.10.10.%d" % (random.randint(0, 255)) + ) + create_global_forwarding_rule( + gcp, + forwarding_rule_name + new_config_suffix, + [gcp.service_port], + potential_ip_addresses, + target_proxy_2, + ) if gcp.service_port != _DEFAULT_SERVICE_PORT: - patch_url_map_host_rule_with_port(gcp, - url_map_name + new_config_suffix, - backend_service, - service_host_name) - wait_until_all_rpcs_go_to_given_backends(backend_instances, - _WAIT_FOR_STATS_SEC) + patch_url_map_host_rule_with_port( + gcp, + url_map_name + new_config_suffix, + backend_service, + service_host_name, + ) + wait_until_all_rpcs_go_to_given_backends( + backend_instances, _WAIT_FOR_STATS_SEC + ) delete_global_forwarding_rule(gcp, gcp.global_forwarding_rules[0]) delete_target_proxy(gcp, gcp.target_proxies[0]) delete_url_map(gcp, gcp.url_maps[0]) - verify_attempts = int(_WAIT_FOR_URL_MAP_PATCH_SEC / _NUM_TEST_RPCS * - args.qps) + verify_attempts = int( + _WAIT_FOR_URL_MAP_PATCH_SEC / _NUM_TEST_RPCS * args.qps + ) for i in range(verify_attempts): - 
wait_until_all_rpcs_go_to_given_backends(backend_instances, - _WAIT_FOR_STATS_SEC) + wait_until_all_rpcs_go_to_given_backends( + backend_instances, _WAIT_FOR_STATS_SEC + ) # delete host rule for the original host name patch_url_map_backend_service(gcp, alternate_backend_service) - wait_until_no_rpcs_go_to_given_backends(backend_instances, - _WAIT_FOR_STATS_SEC) + wait_until_no_rpcs_go_to_given_backends( + backend_instances, _WAIT_FOR_STATS_SEC + ) except Exception: passed = False @@ -1225,16 +1478,18 @@ def test_api_listener(gcp, backend_service, instance_group, delete_global_forwarding_rules(gcp) delete_target_proxies(gcp) delete_url_maps(gcp) - create_url_map(gcp, url_map_name, backend_service, - service_host_name) + create_url_map( + gcp, url_map_name, backend_service, service_host_name + ) create_target_proxy(gcp, target_proxy_name) - create_global_forwarding_rule(gcp, forwarding_rule_name, - potential_service_ports) + create_global_forwarding_rule( + gcp, forwarding_rule_name, potential_service_ports + ) if gcp.service_port != _DEFAULT_SERVICE_PORT: - patch_url_map_host_rule_with_port(gcp, url_map_name, - backend_service, - service_host_name) - server_uri = service_host_name + ':' + str(gcp.service_port) + patch_url_map_host_rule_with_port( + gcp, url_map_name, backend_service, service_host_name + ) + server_uri = service_host_name + ":" + str(gcp.service_port) else: server_uri = service_host_name return server_uri @@ -1246,28 +1501,36 @@ def test_forwarding_rule_port_match(gcp, backend_service, instance_group): try: wait_for_healthy_backends(gcp, backend_service, instance_group) backend_instances = get_instance_names(gcp, instance_group) - wait_until_all_rpcs_go_to_given_backends(backend_instances, - _WAIT_FOR_STATS_SEC) + wait_until_all_rpcs_go_to_given_backends( + backend_instances, _WAIT_FOR_STATS_SEC + ) delete_global_forwarding_rules(gcp) - create_global_forwarding_rule(gcp, forwarding_rule_name, [ - x for x in parse_port_range(_DEFAULT_PORT_RANGE) - if x != gcp.service_port - ]) - wait_until_no_rpcs_go_to_given_backends(backend_instances, - _WAIT_FOR_STATS_SEC) + create_global_forwarding_rule( + gcp, + forwarding_rule_name, + [ + x + for x in parse_port_range(_DEFAULT_PORT_RANGE) + if x != gcp.service_port + ], + ) + wait_until_no_rpcs_go_to_given_backends( + backend_instances, _WAIT_FOR_STATS_SEC + ) except Exception: passed = False raise finally: if passed or not args.halt_after_fail: delete_global_forwarding_rules(gcp) - create_global_forwarding_rule(gcp, forwarding_rule_name, - potential_service_ports) + create_global_forwarding_rule( + gcp, forwarding_rule_name, potential_service_ports + ) if gcp.service_port != _DEFAULT_SERVICE_PORT: - patch_url_map_host_rule_with_port(gcp, url_map_name, - backend_service, - service_host_name) - server_uri = service_host_name + ':' + str(gcp.service_port) + patch_url_map_host_rule_with_port( + gcp, url_map_name, backend_service, service_host_name + ) + server_uri = service_host_name + ":" + str(gcp.service_port) else: server_uri = service_host_name return server_uri @@ -1280,16 +1543,19 @@ def test_forwarding_rule_default_port(gcp, backend_service, instance_group): wait_for_healthy_backends(gcp, backend_service, instance_group) backend_instances = get_instance_names(gcp, instance_group) if gcp.service_port == _DEFAULT_SERVICE_PORT: - wait_until_all_rpcs_go_to_given_backends(backend_instances, - _WAIT_FOR_STATS_SEC) + wait_until_all_rpcs_go_to_given_backends( + backend_instances, _WAIT_FOR_STATS_SEC + ) delete_global_forwarding_rules(gcp) 
- create_global_forwarding_rule(gcp, forwarding_rule_name, - parse_port_range(_DEFAULT_PORT_RANGE)) - patch_url_map_host_rule_with_port(gcp, url_map_name, - backend_service, - service_host_name) - wait_until_no_rpcs_go_to_given_backends(backend_instances, - _WAIT_FOR_STATS_SEC) + create_global_forwarding_rule( + gcp, forwarding_rule_name, parse_port_range(_DEFAULT_PORT_RANGE) + ) + patch_url_map_host_rule_with_port( + gcp, url_map_name, backend_service, service_host_name + ) + wait_until_no_rpcs_go_to_given_backends( + backend_instances, _WAIT_FOR_STATS_SEC + ) # expect success when no port in client request service uri, and no port in url-map delete_global_forwarding_rule(gcp, gcp.global_forwarding_rules[0]) delete_target_proxy(gcp, gcp.target_proxies[0]) @@ -1299,18 +1565,23 @@ def test_forwarding_rule_default_port(gcp, backend_service, instance_group): potential_ip_addresses = [] max_attempts = 10 for i in range(max_attempts): - potential_ip_addresses.append('10.10.10.%d' % - (random.randint(0, 255))) - create_global_forwarding_rule(gcp, forwarding_rule_name, [80], - potential_ip_addresses) - wait_until_all_rpcs_go_to_given_backends(backend_instances, - _WAIT_FOR_STATS_SEC) + potential_ip_addresses.append( + "10.10.10.%d" % (random.randint(0, 255)) + ) + create_global_forwarding_rule( + gcp, forwarding_rule_name, [80], potential_ip_addresses + ) + wait_until_all_rpcs_go_to_given_backends( + backend_instances, _WAIT_FOR_STATS_SEC + ) # expect failure when no port in client request uri, but specify port in url-map - patch_url_map_host_rule_with_port(gcp, url_map_name, backend_service, - service_host_name) - wait_until_no_rpcs_go_to_given_backends(backend_instances, - _WAIT_FOR_STATS_SEC) + patch_url_map_host_rule_with_port( + gcp, url_map_name, backend_service, service_host_name + ) + wait_until_no_rpcs_go_to_given_backends( + backend_instances, _WAIT_FOR_STATS_SEC + ) except Exception: passed = False raise @@ -1319,59 +1590,79 @@ def test_forwarding_rule_default_port(gcp, backend_service, instance_group): delete_global_forwarding_rules(gcp) delete_target_proxies(gcp) delete_url_maps(gcp) - create_url_map(gcp, url_map_name, backend_service, - service_host_name) + create_url_map( + gcp, url_map_name, backend_service, service_host_name + ) create_target_proxy(gcp, target_proxy_name) - create_global_forwarding_rule(gcp, forwarding_rule_name, - potential_service_ports) + create_global_forwarding_rule( + gcp, forwarding_rule_name, potential_service_ports + ) if gcp.service_port != _DEFAULT_SERVICE_PORT: - patch_url_map_host_rule_with_port(gcp, url_map_name, - backend_service, - service_host_name) - server_uri = service_host_name + ':' + str(gcp.service_port) + patch_url_map_host_rule_with_port( + gcp, url_map_name, backend_service, service_host_name + ) + server_uri = service_host_name + ":" + str(gcp.service_port) else: server_uri = service_host_name return server_uri -def test_traffic_splitting(gcp, original_backend_service, instance_group, - alternate_backend_service, same_zone_instance_group): +def test_traffic_splitting( + gcp, + original_backend_service, + instance_group, + alternate_backend_service, + same_zone_instance_group, +): # This test start with all traffic going to original_backend_service. Then # it updates URL-map to set default action to traffic splitting between # original and alternate. It waits for all backends in both services to # receive traffic, then verifies that weights are expected. 
- logger.info('Running test_traffic_splitting') + logger.info("Running test_traffic_splitting") - original_backend_instances, alternate_backend_instances = prepare_services_for_urlmap_tests( - gcp, original_backend_service, instance_group, - alternate_backend_service, same_zone_instance_group) + ( + original_backend_instances, + alternate_backend_instances, + ) = prepare_services_for_urlmap_tests( + gcp, + original_backend_service, + instance_group, + alternate_backend_service, + same_zone_instance_group, + ) passed = True try: # Patch urlmap, change route action to traffic splitting between # original and alternate. - logger.info('patching url map with traffic splitting') + logger.info("patching url map with traffic splitting") original_service_percentage, alternate_service_percentage = 20, 80 patch_url_map_backend_service( gcp, services_with_weights={ original_backend_service: original_service_percentage, alternate_backend_service: alternate_service_percentage, - }) + }, + ) # Split percentage between instances: [20,80] -> [10,10,40,40]. expected_instance_percentage = [ original_service_percentage * 1.0 / len(original_backend_instances) ] * len(original_backend_instances) + [ - alternate_service_percentage * 1.0 / - len(alternate_backend_instances) - ] * len(alternate_backend_instances) + alternate_service_percentage + * 1.0 + / len(alternate_backend_instances) + ] * len( + alternate_backend_instances + ) # Wait for traffic to go to both services. logger.info( - 'waiting for traffic to go to all backends (including alternate)') + "waiting for traffic to go to all backends (including alternate)" + ) wait_until_all_rpcs_go_to_given_backends( original_backend_instances + alternate_backend_instances, - _WAIT_FOR_STATS_SEC) + _WAIT_FOR_STATS_SEC, + ) # Verify that weights between two services are expected. retry_count = 10 @@ -1388,18 +1679,24 @@ def test_traffic_splitting(gcp, original_backend_service, instance_group, ] try: - compare_distributions(got_instance_percentage, - expected_instance_percentage, 5) + compare_distributions( + got_instance_percentage, expected_instance_percentage, 5 + ) except Exception as e: - logger.info('attempt %d', i) - logger.info('got percentage: %s', got_instance_percentage) - logger.info('expected percentage: %s', - expected_instance_percentage) + logger.info("attempt %d", i) + logger.info("got percentage: %s", got_instance_percentage) + logger.info( + "expected percentage: %s", expected_instance_percentage + ) logger.info(e) if i == retry_count - 1: raise Exception( - 'RPC distribution (%s) differs from expected (%s)' % - (got_instance_percentage, expected_instance_percentage)) + "RPC distribution (%s) differs from expected (%s)" + % ( + got_instance_percentage, + expected_instance_percentage, + ) + ) else: logger.info("success") break @@ -1412,8 +1709,13 @@ def test_traffic_splitting(gcp, original_backend_service, instance_group, patch_backend_service(gcp, alternate_backend_service, []) -def test_path_matching(gcp, original_backend_service, instance_group, - alternate_backend_service, same_zone_instance_group): +def test_path_matching( + gcp, + original_backend_service, + instance_group, + alternate_backend_service, + same_zone_instance_group, +): # This test start with all traffic (UnaryCall and EmptyCall) going to # original_backend_service. # @@ -1421,42 +1723,59 @@ def test_path_matching(gcp, original_backend_service, instance_group, # go different backends. 
It waits for all backends in both services to # receive traffic, then verifies that traffic goes to the expected # backends. - logger.info('Running test_path_matching') + logger.info("Running test_path_matching") - original_backend_instances, alternate_backend_instances = prepare_services_for_urlmap_tests( - gcp, original_backend_service, instance_group, - alternate_backend_service, same_zone_instance_group) + ( + original_backend_instances, + alternate_backend_instances, + ) = prepare_services_for_urlmap_tests( + gcp, + original_backend_service, + instance_group, + alternate_backend_service, + same_zone_instance_group, + ) passed = True try: # A list of tuples (route_rules, expected_instances). test_cases = [ ( - [{ - 'priority': 0, - # FullPath EmptyCall -> alternate_backend_service. - 'matchRules': [{ - 'fullPathMatch': '/grpc.testing.TestService/EmptyCall' - }], - 'service': alternate_backend_service.url - }], + [ + { + "priority": 0, + # FullPath EmptyCall -> alternate_backend_service. + "matchRules": [ + { + "fullPathMatch": ( + "/grpc.testing.TestService/EmptyCall" + ) + } + ], + "service": alternate_backend_service.url, + } + ], { "EmptyCall": alternate_backend_instances, - "UnaryCall": original_backend_instances - }), + "UnaryCall": original_backend_instances, + }, + ), ( - [{ - 'priority': 0, - # Prefix UnaryCall -> alternate_backend_service. - 'matchRules': [{ - 'prefixMatch': '/grpc.testing.TestService/Unary' - }], - 'service': alternate_backend_service.url - }], + [ + { + "priority": 0, + # Prefix UnaryCall -> alternate_backend_service. + "matchRules": [ + {"prefixMatch": "/grpc.testing.TestService/Unary"} + ], + "service": alternate_backend_service.url, + } + ], { "UnaryCall": alternate_backend_instances, - "EmptyCall": original_backend_instances - }), + "EmptyCall": original_backend_instances, + }, + ), ( # This test case is similar to the one above (but with route # services swapped). This test has two routes (full_path and @@ -1465,71 +1784,90 @@ def test_path_matching(gcp, original_backend_service, instance_group, # client to handle duplicate Clusters in the RDS response. [ { - 'priority': 0, + "priority": 0, # Prefix UnaryCall -> original_backend_service. - 'matchRules': [{ - 'prefixMatch': '/grpc.testing.TestService/Unary' - }], - 'service': original_backend_service.url + "matchRules": [ + {"prefixMatch": "/grpc.testing.TestService/Unary"} + ], + "service": original_backend_service.url, }, { - 'priority': 1, + "priority": 1, # FullPath EmptyCall -> alternate_backend_service. - 'matchRules': [{ - 'fullPathMatch': - '/grpc.testing.TestService/EmptyCall' - }], - 'service': alternate_backend_service.url - } + "matchRules": [ + { + "fullPathMatch": ( + "/grpc.testing.TestService/EmptyCall" + ) + } + ], + "service": alternate_backend_service.url, + }, ], { "UnaryCall": original_backend_instances, - "EmptyCall": alternate_backend_instances - }), + "EmptyCall": alternate_backend_instances, + }, + ), ( - [{ - 'priority': 0, - # Regex UnaryCall -> alternate_backend_service. - 'matchRules': [{ - 'regexMatch': - '^\/.*\/UnaryCall$' # Unary methods with any services. - }], - 'service': alternate_backend_service.url - }], + [ + { + "priority": 0, + # Regex UnaryCall -> alternate_backend_service. + "matchRules": [ + { + "regexMatch": ( # Unary methods with any services. 
+ "^\/.*\/UnaryCall$" + ) + } + ], + "service": alternate_backend_service.url, + } + ], { "UnaryCall": alternate_backend_instances, - "EmptyCall": original_backend_instances - }), + "EmptyCall": original_backend_instances, + }, + ), ( - [{ - 'priority': 0, - # ignoreCase EmptyCall -> alternate_backend_service. - 'matchRules': [{ - # Case insensitive matching. - 'fullPathMatch': '/gRpC.tEsTinG.tEstseRvice/empTycaLl', - 'ignoreCase': True, - }], - 'service': alternate_backend_service.url - }], + [ + { + "priority": 0, + # ignoreCase EmptyCall -> alternate_backend_service. + "matchRules": [ + { + # Case insensitive matching. + "fullPathMatch": ( + "/gRpC.tEsTinG.tEstseRvice/empTycaLl" + ), + "ignoreCase": True, + } + ], + "service": alternate_backend_service.url, + } + ], { "UnaryCall": original_backend_instances, - "EmptyCall": alternate_backend_instances - }), + "EmptyCall": alternate_backend_instances, + }, + ), ] - for (route_rules, expected_instances) in test_cases: - logger.info('patching url map with %s', route_rules) - patch_url_map_backend_service(gcp, - original_backend_service, - route_rules=route_rules) + for route_rules, expected_instances in test_cases: + logger.info("patching url map with %s", route_rules) + patch_url_map_backend_service( + gcp, original_backend_service, route_rules=route_rules + ) # Wait for traffic to go to both services. logger.info( - 'waiting for traffic to go to all backends (including alternate)' + "waiting for traffic to go to all backends (including" + " alternate)" ) wait_until_all_rpcs_go_to_given_backends( original_backend_instances + alternate_backend_instances, - _WAIT_FOR_STATS_SEC) + _WAIT_FOR_STATS_SEC, + ) retry_count = 80 # Each attempt takes about 5 seconds, 80 retries is equivalent to 400 @@ -1538,16 +1876,18 @@ def test_path_matching(gcp, original_backend_service, instance_group, stats = get_client_stats(_NUM_TEST_RPCS, _WAIT_FOR_STATS_SEC) if not stats.rpcs_by_method: raise ValueError( - 'stats.rpcs_by_method is None, the interop client stats service does not support this test case' + "stats.rpcs_by_method is None, the interop client stats" + " service does not support this test case" ) - logger.info('attempt %d', i) + logger.info("attempt %d", i) if compare_expected_instances(stats, expected_instances): logger.info("success") break elif i == retry_count - 1: raise Exception( - 'timeout waiting for RPCs to the expected instances: %s' - % expected_instances) + "timeout waiting for RPCs to the expected instances: %s" + % expected_instances + ) except Exception: passed = False raise @@ -1557,8 +1897,13 @@ def test_path_matching(gcp, original_backend_service, instance_group, patch_backend_service(gcp, alternate_backend_service, []) -def test_header_matching(gcp, original_backend_service, instance_group, - alternate_backend_service, same_zone_instance_group): +def test_header_matching( + gcp, + original_backend_service, + instance_group, + alternate_backend_service, + same_zone_instance_group, +): # This test start with all traffic (UnaryCall and EmptyCall) going to # original_backend_service. # @@ -1566,173 +1911,231 @@ def test_header_matching(gcp, original_backend_service, instance_group, # go to different backends. It waits for all backends in both services to # receive traffic, then verifies that traffic goes to the expected # backends. 
- logger.info('Running test_header_matching') + logger.info("Running test_header_matching") - original_backend_instances, alternate_backend_instances = prepare_services_for_urlmap_tests( - gcp, original_backend_service, instance_group, - alternate_backend_service, same_zone_instance_group) + ( + original_backend_instances, + alternate_backend_instances, + ) = prepare_services_for_urlmap_tests( + gcp, + original_backend_service, + instance_group, + alternate_backend_service, + same_zone_instance_group, + ) passed = True try: # A list of tuples (route_rules, expected_instances). test_cases = [ ( - [{ - 'priority': 0, - # Header ExactMatch -> alternate_backend_service. - # EmptyCall is sent with the metadata. - 'matchRules': [{ - 'prefixMatch': - '/', - 'headerMatches': [{ - 'headerName': _TEST_METADATA_KEY, - 'exactMatch': _TEST_METADATA_VALUE_EMPTY - }] - }], - 'service': alternate_backend_service.url - }], + [ + { + "priority": 0, + # Header ExactMatch -> alternate_backend_service. + # EmptyCall is sent with the metadata. + "matchRules": [ + { + "prefixMatch": "/", + "headerMatches": [ + { + "headerName": _TEST_METADATA_KEY, + "exactMatch": _TEST_METADATA_VALUE_EMPTY, + } + ], + } + ], + "service": alternate_backend_service.url, + } + ], { "EmptyCall": alternate_backend_instances, - "UnaryCall": original_backend_instances - }), + "UnaryCall": original_backend_instances, + }, + ), ( - [{ - 'priority': 0, - # Header PrefixMatch -> alternate_backend_service. - # UnaryCall is sent with the metadata. - 'matchRules': [{ - 'prefixMatch': - '/', - 'headerMatches': [{ - 'headerName': _TEST_METADATA_KEY, - 'prefixMatch': _TEST_METADATA_VALUE_UNARY[:2] - }] - }], - 'service': alternate_backend_service.url - }], + [ + { + "priority": 0, + # Header PrefixMatch -> alternate_backend_service. + # UnaryCall is sent with the metadata. + "matchRules": [ + { + "prefixMatch": "/", + "headerMatches": [ + { + "headerName": _TEST_METADATA_KEY, + "prefixMatch": _TEST_METADATA_VALUE_UNARY[ + :2 + ], + } + ], + } + ], + "service": alternate_backend_service.url, + } + ], { "EmptyCall": original_backend_instances, - "UnaryCall": alternate_backend_instances - }), + "UnaryCall": alternate_backend_instances, + }, + ), ( - [{ - 'priority': 0, - # Header SuffixMatch -> alternate_backend_service. - # EmptyCall is sent with the metadata. - 'matchRules': [{ - 'prefixMatch': - '/', - 'headerMatches': [{ - 'headerName': _TEST_METADATA_KEY, - 'suffixMatch': _TEST_METADATA_VALUE_EMPTY[-2:] - }] - }], - 'service': alternate_backend_service.url - }], + [ + { + "priority": 0, + # Header SuffixMatch -> alternate_backend_service. + # EmptyCall is sent with the metadata. + "matchRules": [ + { + "prefixMatch": "/", + "headerMatches": [ + { + "headerName": _TEST_METADATA_KEY, + "suffixMatch": _TEST_METADATA_VALUE_EMPTY[ + -2: + ], + } + ], + } + ], + "service": alternate_backend_service.url, + } + ], { "EmptyCall": alternate_backend_instances, - "UnaryCall": original_backend_instances - }), + "UnaryCall": original_backend_instances, + }, + ), ( - [{ - 'priority': 0, - # Header 'xds_md_numeric' present -> alternate_backend_service. - # UnaryCall is sent with the metadata, so will be sent to alternative. - 'matchRules': [{ - 'prefixMatch': - '/', - 'headerMatches': [{ - 'headerName': _TEST_METADATA_NUMERIC_KEY, - 'presentMatch': True - }] - }], - 'service': alternate_backend_service.url - }], + [ + { + "priority": 0, + # Header 'xds_md_numeric' present -> alternate_backend_service. 
+ # UnaryCall is sent with the metadata, so will be sent to alternative. + "matchRules": [ + { + "prefixMatch": "/", + "headerMatches": [ + { + "headerName": _TEST_METADATA_NUMERIC_KEY, + "presentMatch": True, + } + ], + } + ], + "service": alternate_backend_service.url, + } + ], { "EmptyCall": original_backend_instances, - "UnaryCall": alternate_backend_instances - }), + "UnaryCall": alternate_backend_instances, + }, + ), ( - [{ - 'priority': 0, - # Header invert ExactMatch -> alternate_backend_service. - # UnaryCall is sent with the metadata, so will be sent to - # original. EmptyCall will be sent to alternative. - 'matchRules': [{ - 'prefixMatch': - '/', - 'headerMatches': [{ - 'headerName': _TEST_METADATA_KEY, - 'exactMatch': _TEST_METADATA_VALUE_UNARY, - 'invertMatch': True - }] - }], - 'service': alternate_backend_service.url - }], + [ + { + "priority": 0, + # Header invert ExactMatch -> alternate_backend_service. + # UnaryCall is sent with the metadata, so will be sent to + # original. EmptyCall will be sent to alternative. + "matchRules": [ + { + "prefixMatch": "/", + "headerMatches": [ + { + "headerName": _TEST_METADATA_KEY, + "exactMatch": _TEST_METADATA_VALUE_UNARY, + "invertMatch": True, + } + ], + } + ], + "service": alternate_backend_service.url, + } + ], { "EmptyCall": alternate_backend_instances, - "UnaryCall": original_backend_instances - }), + "UnaryCall": original_backend_instances, + }, + ), ( - [{ - 'priority': 0, - # Header 'xds_md_numeric' range [100,200] -> alternate_backend_service. - # UnaryCall is sent with the metadata in range. - 'matchRules': [{ - 'prefixMatch': - '/', - 'headerMatches': [{ - 'headerName': _TEST_METADATA_NUMERIC_KEY, - 'rangeMatch': { - 'rangeStart': '100', - 'rangeEnd': '200' + [ + { + "priority": 0, + # Header 'xds_md_numeric' range [100,200] -> alternate_backend_service. + # UnaryCall is sent with the metadata in range. + "matchRules": [ + { + "prefixMatch": "/", + "headerMatches": [ + { + "headerName": _TEST_METADATA_NUMERIC_KEY, + "rangeMatch": { + "rangeStart": "100", + "rangeEnd": "200", + }, + } + ], } - }] - }], - 'service': alternate_backend_service.url - }], + ], + "service": alternate_backend_service.url, + } + ], { "EmptyCall": original_backend_instances, - "UnaryCall": alternate_backend_instances - }), + "UnaryCall": alternate_backend_instances, + }, + ), ( - [{ - 'priority': 0, - # Header RegexMatch -> alternate_backend_service. - # EmptyCall is sent with the metadata. - 'matchRules': [{ - 'prefixMatch': - '/', - 'headerMatches': [{ - 'headerName': - _TEST_METADATA_KEY, - 'regexMatch': - "^%s.*%s$" % (_TEST_METADATA_VALUE_EMPTY[:2], - _TEST_METADATA_VALUE_EMPTY[-2:]) - }] - }], - 'service': alternate_backend_service.url - }], + [ + { + "priority": 0, + # Header RegexMatch -> alternate_backend_service. + # EmptyCall is sent with the metadata. 
+ "matchRules": [ + { + "prefixMatch": "/", + "headerMatches": [ + { + "headerName": _TEST_METADATA_KEY, + "regexMatch": "^%s.*%s$" + % ( + _TEST_METADATA_VALUE_EMPTY[:2], + _TEST_METADATA_VALUE_EMPTY[-2:], + ), + } + ], + } + ], + "service": alternate_backend_service.url, + } + ], { "EmptyCall": alternate_backend_instances, - "UnaryCall": original_backend_instances - }), + "UnaryCall": original_backend_instances, + }, + ), ] - for (route_rules, expected_instances) in test_cases: - logger.info('patching url map with %s -> alternative', - route_rules[0]['matchRules']) - patch_url_map_backend_service(gcp, - original_backend_service, - route_rules=route_rules) + for route_rules, expected_instances in test_cases: + logger.info( + "patching url map with %s -> alternative", + route_rules[0]["matchRules"], + ) + patch_url_map_backend_service( + gcp, original_backend_service, route_rules=route_rules + ) # Wait for traffic to go to both services. logger.info( - 'waiting for traffic to go to all backends (including alternate)' + "waiting for traffic to go to all backends (including" + " alternate)" ) wait_until_all_rpcs_go_to_given_backends( original_backend_instances + alternate_backend_instances, - _WAIT_FOR_STATS_SEC) + _WAIT_FOR_STATS_SEC, + ) retry_count = 80 # Each attempt takes about 5 seconds, 80 retries is equivalent to 400 @@ -1741,16 +2144,18 @@ def test_header_matching(gcp, original_backend_service, instance_group, stats = get_client_stats(_NUM_TEST_RPCS, _WAIT_FOR_STATS_SEC) if not stats.rpcs_by_method: raise ValueError( - 'stats.rpcs_by_method is None, the interop client stats service does not support this test case' + "stats.rpcs_by_method is None, the interop client stats" + " service does not support this test case" ) - logger.info('attempt %d', i) + logger.info("attempt %d", i) if compare_expected_instances(stats, expected_instances): logger.info("success") break elif i == retry_count - 1: raise Exception( - 'timeout waiting for RPCs to the expected instances: %s' - % expected_instances) + "timeout waiting for RPCs to the expected instances: %s" + % expected_instances + ) except Exception: passed = False raise @@ -1760,9 +2165,10 @@ def test_header_matching(gcp, original_backend_service, instance_group, patch_backend_service(gcp, alternate_backend_service, []) -def test_circuit_breaking(gcp, original_backend_service, instance_group, - same_zone_instance_group): - ''' +def test_circuit_breaking( + gcp, original_backend_service, instance_group, same_zone_instance_group +): + """ Since backend service circuit_breakers configuration cannot be unset, which causes trouble for restoring validate_for_proxy flag in target proxy/global forwarding rule. This test uses dedicated backend sevices. @@ -1787,8 +2193,8 @@ def test_circuit_breaking(gcp, original_backend_service, instance_group, more_extra_backend_service (with circuit_breakers) -> [] url_map -> [original_backend_service] - ''' - logger.info('Running test_circuit_breaking') + """ + logger.info("Running test_circuit_breaking") additional_backend_services = [] passed = True try: @@ -1797,211 +2203,274 @@ def test_circuit_breaking(gcp, original_backend_service, instance_group, # breakers is resolved or configuring backend service circuit breakers is # enabled for config validation, these dedicated backend services can be # eliminated. 
- extra_backend_service_name = _BASE_BACKEND_SERVICE_NAME + '-extra' + gcp_suffix - more_extra_backend_service_name = _BASE_BACKEND_SERVICE_NAME + '-more-extra' + gcp_suffix - extra_backend_service = add_backend_service(gcp, - extra_backend_service_name) + extra_backend_service_name = ( + _BASE_BACKEND_SERVICE_NAME + "-extra" + gcp_suffix + ) + more_extra_backend_service_name = ( + _BASE_BACKEND_SERVICE_NAME + "-more-extra" + gcp_suffix + ) + extra_backend_service = add_backend_service( + gcp, extra_backend_service_name + ) additional_backend_services.append(extra_backend_service) more_extra_backend_service = add_backend_service( - gcp, more_extra_backend_service_name) + gcp, more_extra_backend_service_name + ) additional_backend_services.append(more_extra_backend_service) # The config validation for proxyless doesn't allow setting # circuit_breakers. Disable validate validate_for_proxyless # for this test. This can be removed when validation # accepts circuit_breakers. - logger.info('disabling validate_for_proxyless in target proxy') + logger.info("disabling validate_for_proxyless in target proxy") set_validate_for_proxyless(gcp, False) extra_backend_service_max_requests = 500 more_extra_backend_service_max_requests = 1000 - patch_backend_service(gcp, - extra_backend_service, [instance_group], - circuit_breakers={ - 'maxRequests': - extra_backend_service_max_requests - }) - logger.info('Waiting for extra backends to become healthy') + patch_backend_service( + gcp, + extra_backend_service, + [instance_group], + circuit_breakers={ + "maxRequests": extra_backend_service_max_requests + }, + ) + logger.info("Waiting for extra backends to become healthy") wait_for_healthy_backends(gcp, extra_backend_service, instance_group) - patch_backend_service(gcp, - more_extra_backend_service, - [same_zone_instance_group], - circuit_breakers={ - 'maxRequests': - more_extra_backend_service_max_requests - }) - logger.info('Waiting for more extra backend to become healthy') - wait_for_healthy_backends(gcp, more_extra_backend_service, - same_zone_instance_group) + patch_backend_service( + gcp, + more_extra_backend_service, + [same_zone_instance_group], + circuit_breakers={ + "maxRequests": more_extra_backend_service_max_requests + }, + ) + logger.info("Waiting for more extra backend to become healthy") + wait_for_healthy_backends( + gcp, more_extra_backend_service, same_zone_instance_group + ) extra_backend_instances = get_instance_names(gcp, instance_group) more_extra_backend_instances = get_instance_names( - gcp, same_zone_instance_group) + gcp, same_zone_instance_group + ) route_rules = [ { - 'priority': 0, + "priority": 0, # UnaryCall -> extra_backend_service - 'matchRules': [{ - 'fullPathMatch': '/grpc.testing.TestService/UnaryCall' - }], - 'service': extra_backend_service.url + "matchRules": [ + {"fullPathMatch": "/grpc.testing.TestService/UnaryCall"} + ], + "service": extra_backend_service.url, }, { - 'priority': 1, + "priority": 1, # EmptyCall -> more_extra_backend_service - 'matchRules': [{ - 'fullPathMatch': '/grpc.testing.TestService/EmptyCall' - }], - 'service': more_extra_backend_service.url + "matchRules": [ + {"fullPathMatch": "/grpc.testing.TestService/EmptyCall"} + ], + "service": more_extra_backend_service.url, }, ] # Make client send UNARY_CALL and EMPTY_CALL. 
- configure_client([ - messages_pb2.ClientConfigureRequest.RpcType.UNARY_CALL, - messages_pb2.ClientConfigureRequest.RpcType.EMPTY_CALL - ]) - logger.info('Patching url map with %s', route_rules) - patch_url_map_backend_service(gcp, - extra_backend_service, - route_rules=route_rules) - logger.info('Waiting for traffic to go to all backends') + configure_client( + [ + messages_pb2.ClientConfigureRequest.RpcType.UNARY_CALL, + messages_pb2.ClientConfigureRequest.RpcType.EMPTY_CALL, + ] + ) + logger.info("Patching url map with %s", route_rules) + patch_url_map_backend_service( + gcp, extra_backend_service, route_rules=route_rules + ) + logger.info("Waiting for traffic to go to all backends") wait_until_all_rpcs_go_to_given_backends( extra_backend_instances + more_extra_backend_instances, - _WAIT_FOR_STATS_SEC) + _WAIT_FOR_STATS_SEC, + ) # Make all calls keep-open. - configure_client([ - messages_pb2.ClientConfigureRequest.RpcType.UNARY_CALL, - messages_pb2.ClientConfigureRequest.RpcType.EMPTY_CALL - ], [(messages_pb2.ClientConfigureRequest.RpcType.UNARY_CALL, - 'rpc-behavior', 'keep-open'), - (messages_pb2.ClientConfigureRequest.RpcType.EMPTY_CALL, - 'rpc-behavior', 'keep-open')]) + configure_client( + [ + messages_pb2.ClientConfigureRequest.RpcType.UNARY_CALL, + messages_pb2.ClientConfigureRequest.RpcType.EMPTY_CALL, + ], + [ + ( + messages_pb2.ClientConfigureRequest.RpcType.UNARY_CALL, + "rpc-behavior", + "keep-open", + ), + ( + messages_pb2.ClientConfigureRequest.RpcType.EMPTY_CALL, + "rpc-behavior", + "keep-open", + ), + ], + ) wait_until_rpcs_in_flight( - 'UNARY_CALL', (_WAIT_FOR_BACKEND_SEC + - int(extra_backend_service_max_requests / args.qps)), - extra_backend_service_max_requests, 1) - logger.info('UNARY_CALL reached stable state (%d)', - extra_backend_service_max_requests) + "UNARY_CALL", + ( + _WAIT_FOR_BACKEND_SEC + + int(extra_backend_service_max_requests / args.qps) + ), + extra_backend_service_max_requests, + 1, + ) + logger.info( + "UNARY_CALL reached stable state (%d)", + extra_backend_service_max_requests, + ) wait_until_rpcs_in_flight( - 'EMPTY_CALL', - (_WAIT_FOR_BACKEND_SEC + - int(more_extra_backend_service_max_requests / args.qps)), - more_extra_backend_service_max_requests, 1) - logger.info('EMPTY_CALL reached stable state (%d)', - more_extra_backend_service_max_requests) + "EMPTY_CALL", + ( + _WAIT_FOR_BACKEND_SEC + + int(more_extra_backend_service_max_requests / args.qps) + ), + more_extra_backend_service_max_requests, + 1, + ) + logger.info( + "EMPTY_CALL reached stable state (%d)", + more_extra_backend_service_max_requests, + ) # Increment circuit breakers max_requests threshold. 
extra_backend_service_max_requests = 800 - patch_backend_service(gcp, - extra_backend_service, [instance_group], - circuit_breakers={ - 'maxRequests': - extra_backend_service_max_requests - }) + patch_backend_service( + gcp, + extra_backend_service, + [instance_group], + circuit_breakers={ + "maxRequests": extra_backend_service_max_requests + }, + ) wait_until_rpcs_in_flight( - 'UNARY_CALL', (_WAIT_FOR_BACKEND_SEC + - int(extra_backend_service_max_requests / args.qps)), - extra_backend_service_max_requests, 1) - logger.info('UNARY_CALL reached stable state after increase (%d)', - extra_backend_service_max_requests) - logger.info('success') + "UNARY_CALL", + ( + _WAIT_FOR_BACKEND_SEC + + int(extra_backend_service_max_requests / args.qps) + ), + extra_backend_service_max_requests, + 1, + ) + logger.info( + "UNARY_CALL reached stable state after increase (%d)", + extra_backend_service_max_requests, + ) + logger.info("success") # Avoid new RPCs being outstanding (some test clients create threads # for sending RPCs) after restoring backend services. configure_client( - [messages_pb2.ClientConfigureRequest.RpcType.UNARY_CALL]) + [messages_pb2.ClientConfigureRequest.RpcType.UNARY_CALL] + ) except Exception: passed = False raise finally: if passed or not args.halt_after_fail: patch_url_map_backend_service(gcp, original_backend_service) - patch_backend_service(gcp, original_backend_service, - [instance_group]) + patch_backend_service( + gcp, original_backend_service, [instance_group] + ) for backend_service in additional_backend_services: delete_backend_service(gcp, backend_service) set_validate_for_proxyless(gcp, True) def test_timeout(gcp, original_backend_service, instance_group): - logger.info('Running test_timeout') + logger.info("Running test_timeout") - logger.info('waiting for original backends to become healthy') + logger.info("waiting for original backends to become healthy") wait_for_healthy_backends(gcp, original_backend_service, instance_group) # UnaryCall -> maxStreamDuration:3s - route_rules = [{ - 'priority': 0, - 'matchRules': [{ - 'fullPathMatch': '/grpc.testing.TestService/UnaryCall' - }], - 'service': original_backend_service.url, - 'routeAction': { - 'maxStreamDuration': { - 'seconds': 3, + route_rules = [ + { + "priority": 0, + "matchRules": [ + {"fullPathMatch": "/grpc.testing.TestService/UnaryCall"} + ], + "service": original_backend_service.url, + "routeAction": { + "maxStreamDuration": { + "seconds": 3, + }, }, - }, - }] - patch_url_map_backend_service(gcp, - original_backend_service, - route_rules=route_rules) + } + ] + patch_url_map_backend_service( + gcp, original_backend_service, route_rules=route_rules + ) # A list of tuples (testcase_name, {client_config}, {expected_results}) test_cases = [ ( - 'timeout_exceeded (UNARY_CALL), timeout_different_route (EMPTY_CALL)', + ( + "timeout_exceeded (UNARY_CALL), timeout_different_route" + " (EMPTY_CALL)" + ), # UnaryCall and EmptyCall both sleep-4. # UnaryCall timeouts, EmptyCall succeeds. 
{ - 'rpc_types': [ + "rpc_types": [ messages_pb2.ClientConfigureRequest.RpcType.UNARY_CALL, messages_pb2.ClientConfigureRequest.RpcType.EMPTY_CALL, ], - 'metadata': [ - (messages_pb2.ClientConfigureRequest.RpcType.UNARY_CALL, - 'rpc-behavior', 'sleep-4'), - (messages_pb2.ClientConfigureRequest.RpcType.EMPTY_CALL, - 'rpc-behavior', 'sleep-4'), + "metadata": [ + ( + messages_pb2.ClientConfigureRequest.RpcType.UNARY_CALL, + "rpc-behavior", + "sleep-4", + ), + ( + messages_pb2.ClientConfigureRequest.RpcType.EMPTY_CALL, + "rpc-behavior", + "sleep-4", + ), ], }, { - 'UNARY_CALL': 4, # DEADLINE_EXCEEDED - 'EMPTY_CALL': 0, + "UNARY_CALL": 4, # DEADLINE_EXCEEDED + "EMPTY_CALL": 0, }, ), ( - 'app_timeout_exceeded', + "app_timeout_exceeded", # UnaryCall only with sleep-2; timeout=1s; calls timeout. { - 'rpc_types': [ + "rpc_types": [ messages_pb2.ClientConfigureRequest.RpcType.UNARY_CALL, ], - 'metadata': [ - (messages_pb2.ClientConfigureRequest.RpcType.UNARY_CALL, - 'rpc-behavior', 'sleep-2'), + "metadata": [ + ( + messages_pb2.ClientConfigureRequest.RpcType.UNARY_CALL, + "rpc-behavior", + "sleep-2", + ), ], - 'timeout_sec': 1, + "timeout_sec": 1, }, { - 'UNARY_CALL': 4, # DEADLINE_EXCEEDED + "UNARY_CALL": 4, # DEADLINE_EXCEEDED }, ), ( - 'timeout_not_exceeded', + "timeout_not_exceeded", # UnaryCall only with no sleep; calls succeed. { - 'rpc_types': [ + "rpc_types": [ messages_pb2.ClientConfigureRequest.RpcType.UNARY_CALL, ], }, { - 'UNARY_CALL': 0, + "UNARY_CALL": 0, }, - ) + ), ] passed = True try: first_case = True - for (testcase_name, client_config, expected_results) in test_cases: - logger.info('starting case %s', testcase_name) + for testcase_name, client_config, expected_results in test_cases: + logger.info("starting case %s", testcase_name) configure_client(**client_config) # wait a second to help ensure the client stops sending RPCs with # the old config. 
We will make multiple attempts if it is failing, @@ -2017,10 +2486,11 @@ def test_timeout(gcp, original_backend_service, instance_group): before_stats = get_client_accumulated_stats() if not before_stats.stats_per_method: raise ValueError( - 'stats.stats_per_method is None, the interop client stats service does not support this test case' + "stats.stats_per_method is None, the interop client stats" + " service does not support this test case" ) for i in range(attempt_count): - logger.info('%s: attempt %d', testcase_name, i) + logger.info("%s: attempt %d", testcase_name, i) test_runtime_secs = 10 time.sleep(test_runtime_secs) @@ -2028,24 +2498,36 @@ def test_timeout(gcp, original_backend_service, instance_group): success = True for rpc, status in list(expected_results.items()): - qty = (after_stats.stats_per_method[rpc].result[status] - - before_stats.stats_per_method[rpc].result[status]) + qty = ( + after_stats.stats_per_method[rpc].result[status] + - before_stats.stats_per_method[rpc].result[status] + ) want = test_runtime_secs * args.qps # Allow 10% deviation from expectation to reduce flakiness - if qty < (want * .9) or qty > (want * 1.1): - logger.info('%s: failed due to %s[%s]: got %d want ~%d', - testcase_name, rpc, status, qty, want) + if qty < (want * 0.9) or qty > (want * 1.1): + logger.info( + "%s: failed due to %s[%s]: got %d want ~%d", + testcase_name, + rpc, + status, + qty, + want, + ) success = False if success: - logger.info('success') + logger.info("success") break - logger.info('%s attempt %d failed', testcase_name, i) + logger.info("%s attempt %d failed", testcase_name, i) before_stats = after_stats else: raise Exception( - '%s: timeout waiting for expected results: %s; got %s' % - (testcase_name, expected_results, - after_stats.stats_per_method)) + "%s: timeout waiting for expected results: %s; got %s" + % ( + testcase_name, + expected_results, + after_stats.stats_per_method, + ) + ) except Exception: passed = False raise @@ -2055,133 +2537,115 @@ def test_timeout(gcp, original_backend_service, instance_group): def test_fault_injection(gcp, original_backend_service, instance_group): - logger.info('Running test_fault_injection') + logger.info("Running test_fault_injection") - logger.info('waiting for original backends to become healthy') + logger.info("waiting for original backends to become healthy") wait_for_healthy_backends(gcp, original_backend_service, instance_group) - testcase_header = 'fi_testcase' + testcase_header = "fi_testcase" def _route(pri, name, fi_policy): return { - 'priority': pri, - 'matchRules': [{ - 'prefixMatch': - '/', - 'headerMatches': [{ - 'headerName': testcase_header, - 'exactMatch': name, - }], - }], - 'service': original_backend_service.url, - 'routeAction': { - 'faultInjectionPolicy': fi_policy - }, + "priority": pri, + "matchRules": [ + { + "prefixMatch": "/", + "headerMatches": [ + { + "headerName": testcase_header, + "exactMatch": name, + } + ], + } + ], + "service": original_backend_service.url, + "routeAction": {"faultInjectionPolicy": fi_policy}, } def _abort(pct): return { - 'abort': { - 'httpStatus': 401, - 'percentage': pct, + "abort": { + "httpStatus": 401, + "percentage": pct, } } def _delay(pct): return { - 'delay': { - 'fixedDelay': { - 'seconds': '20' - }, - 'percentage': pct, + "delay": { + "fixedDelay": {"seconds": "20"}, + "percentage": pct, } } zero_route = _abort(0) zero_route.update(_delay(0)) route_rules = [ - _route(0, 'zero_percent_fault_injection', zero_route), - _route(1, 'always_delay', _delay(100)), - _route(2, 
'always_abort', _abort(100)), - _route(3, 'delay_half', _delay(50)), - _route(4, 'abort_half', _abort(50)), + _route(0, "zero_percent_fault_injection", zero_route), + _route(1, "always_delay", _delay(100)), + _route(2, "always_abort", _abort(100)), + _route(3, "delay_half", _delay(50)), + _route(4, "abort_half", _abort(50)), { - 'priority': 5, - 'matchRules': [{ - 'prefixMatch': '/' - }], - 'service': original_backend_service.url, + "priority": 5, + "matchRules": [{"prefixMatch": "/"}], + "service": original_backend_service.url, }, ] set_validate_for_proxyless(gcp, False) - patch_url_map_backend_service(gcp, - original_backend_service, - route_rules=route_rules) + patch_url_map_backend_service( + gcp, original_backend_service, route_rules=route_rules + ) # A list of tuples (testcase_name, {client_config}, {code: percent}). Each # test case will set the testcase_header with the testcase_name for routing # to the appropriate config for the case, defined above. test_cases = [ ( - 'always_delay', - { - 'timeout_sec': 2 - }, - { - 4: 1 - }, # DEADLINE_EXCEEDED + "always_delay", + {"timeout_sec": 2}, + {4: 1}, # DEADLINE_EXCEEDED ), ( - 'always_abort', + "always_abort", {}, - { - 16: 1 - }, # UNAUTHENTICATED + {16: 1}, # UNAUTHENTICATED ), ( - 'delay_half', - { - 'timeout_sec': 2 - }, - { - 4: .5, - 0: .5 - }, # DEADLINE_EXCEEDED / OK: 50% / 50% + "delay_half", + {"timeout_sec": 2}, + {4: 0.5, 0: 0.5}, # DEADLINE_EXCEEDED / OK: 50% / 50% ), ( - 'abort_half', + "abort_half", {}, - { - 16: .5, - 0: .5 - }, # UNAUTHENTICATED / OK: 50% / 50% + {16: 0.5, 0: 0.5}, # UNAUTHENTICATED / OK: 50% / 50% ), ( - 'zero_percent_fault_injection', + "zero_percent_fault_injection", {}, - { - 0: 1 - }, # OK + {0: 1}, # OK ), ( - 'non_matching_fault_injection', # Not in route_rules, above. + "non_matching_fault_injection", # Not in route_rules, above. 
{}, - { - 0: 1 - }, # OK + {0: 1}, # OK ), ] passed = True try: first_case = True - for (testcase_name, client_config, expected_results) in test_cases: - logger.info('starting case %s', testcase_name) + for testcase_name, client_config, expected_results in test_cases: + logger.info("starting case %s", testcase_name) - client_config['metadata'] = [ - (messages_pb2.ClientConfigureRequest.RpcType.UNARY_CALL, - testcase_header, testcase_name) + client_config["metadata"] = [ + ( + messages_pb2.ClientConfigureRequest.RpcType.UNARY_CALL, + testcase_header, + testcase_name, + ) ] - client_config['rpc_types'] = [ + client_config["rpc_types"] = [ messages_pb2.ClientConfigureRequest.RpcType.UNARY_CALL, ] configure_client(**client_config) @@ -2202,10 +2666,11 @@ def _delay(pct): before_stats = get_client_accumulated_stats() if not before_stats.stats_per_method: raise ValueError( - 'stats.stats_per_method is None, the interop client stats service does not support this test case' + "stats.stats_per_method is None, the interop client stats" + " service does not support this test case" ) for i in range(attempt_count): - logger.info('%s: attempt %d', testcase_name, i) + logger.info("%s: attempt %d", testcase_name, i) test_runtime_secs = 10 time.sleep(test_runtime_secs) @@ -2213,26 +2678,38 @@ def _delay(pct): success = True for status, pct in list(expected_results.items()): - rpc = 'UNARY_CALL' - qty = (after_stats.stats_per_method[rpc].result[status] - - before_stats.stats_per_method[rpc].result[status]) + rpc = "UNARY_CALL" + qty = ( + after_stats.stats_per_method[rpc].result[status] + - before_stats.stats_per_method[rpc].result[status] + ) want = pct * args.qps * test_runtime_secs # Allow 10% deviation from expectation to reduce flakiness VARIANCE_ALLOWED = 0.1 if abs(qty - want) > want * VARIANCE_ALLOWED: - logger.info('%s: failed due to %s[%s]: got %d want ~%d', - testcase_name, rpc, status, qty, want) + logger.info( + "%s: failed due to %s[%s]: got %d want ~%d", + testcase_name, + rpc, + status, + qty, + want, + ) success = False if success: - logger.info('success') + logger.info("success") break - logger.info('%s attempt %d failed', testcase_name, i) + logger.info("%s attempt %d failed", testcase_name, i) before_stats = after_stats else: raise Exception( - '%s: timeout waiting for expected results: %s; got %s' % - (testcase_name, expected_results, - after_stats.stats_per_method)) + "%s: timeout waiting for expected results: %s; got %s" + % ( + testcase_name, + expected_results, + after_stats.stats_per_method, + ) + ) except Exception: passed = False raise @@ -2245,10 +2722,11 @@ def _delay(pct): def test_csds(gcp, original_backend_service, instance_group, server_uri): test_csds_timeout_s = datetime.timedelta(minutes=5).total_seconds() sleep_interval_between_attempts_s = datetime.timedelta( - seconds=2).total_seconds() - logger.info('Running test_csds') + seconds=2 + ).total_seconds() + logger.info("Running test_csds") - logger.info('waiting for original backends to become healthy') + logger.info("waiting for original backends to become healthy") wait_for_healthy_backends(gcp, original_backend_service, instance_group) # Test case timeout: 5 minutes @@ -2256,123 +2734,162 @@ def test_csds(gcp, original_backend_service, instance_group, server_uri): cnt = 0 while time.time() <= deadline: client_config = get_client_xds_config_dump() - logger.info('test_csds attempt %d: received xDS config %s', cnt, - json.dumps(client_config, indent=2)) + logger.info( + "test_csds attempt %d: received xDS config %s", + 
cnt, + json.dumps(client_config, indent=2), + ) if client_config is not None: # Got the xDS config dump, now validate it ok = True try: - if client_config['node']['locality']['zone'] != args.zone: - logger.info('Invalid zone %s != %s', - client_config['node']['locality']['zone'], - args.zone) + if client_config["node"]["locality"]["zone"] != args.zone: + logger.info( + "Invalid zone %s != %s", + client_config["node"]["locality"]["zone"], + args.zone, + ) ok = False seen = set() - for xds_config in client_config.get('xds_config', []): - if 'listener_config' in xds_config: - listener_name = xds_config['listener_config'][ - 'dynamic_listeners'][0]['active_state']['listener'][ - 'name'] + for xds_config in client_config.get("xds_config", []): + if "listener_config" in xds_config: + listener_name = xds_config["listener_config"][ + "dynamic_listeners" + ][0]["active_state"]["listener"]["name"] if listener_name != server_uri: - logger.info('Invalid Listener name %s != %s', - listener_name, server_uri) + logger.info( + "Invalid Listener name %s != %s", + listener_name, + server_uri, + ) ok = False else: - seen.add('lds') - elif 'route_config' in xds_config: + seen.add("lds") + elif "route_config" in xds_config: num_vh = len( - xds_config['route_config']['dynamic_route_configs'] - [0]['route_config']['virtual_hosts']) + xds_config["route_config"]["dynamic_route_configs"][ + 0 + ]["route_config"]["virtual_hosts"] + ) if num_vh <= 0: - logger.info('Invalid number of VirtualHosts %s', - num_vh) + logger.info( + "Invalid number of VirtualHosts %s", num_vh + ) ok = False else: - seen.add('rds') - elif 'cluster_config' in xds_config: - cluster_type = xds_config['cluster_config'][ - 'dynamic_active_clusters'][0]['cluster']['type'] - if cluster_type != 'EDS': - logger.info('Invalid cluster type %s != EDS', - cluster_type) + seen.add("rds") + elif "cluster_config" in xds_config: + cluster_type = xds_config["cluster_config"][ + "dynamic_active_clusters" + ][0]["cluster"]["type"] + if cluster_type != "EDS": + logger.info( + "Invalid cluster type %s != EDS", cluster_type + ) ok = False else: - seen.add('cds') - elif 'endpoint_config' in xds_config: + seen.add("cds") + elif "endpoint_config" in xds_config: sub_zone = xds_config["endpoint_config"][ - "dynamic_endpoint_configs"][0]["endpoint_config"][ - "endpoints"][0]["locality"]["sub_zone"] + "dynamic_endpoint_configs" + ][0]["endpoint_config"]["endpoints"][0]["locality"][ + "sub_zone" + ] if args.zone not in sub_zone: - logger.info('Invalid endpoint sub_zone %s', - sub_zone) + logger.info( + "Invalid endpoint sub_zone %s", sub_zone + ) ok = False else: - seen.add('eds') + seen.add("eds") for generic_xds_config in client_config.get( - 'generic_xds_configs', []): - if re.search(r'\.Listener$', - generic_xds_config['type_url']): - seen.add('lds') + "generic_xds_configs", [] + ): + if re.search( + r"\.Listener$", generic_xds_config["type_url"] + ): + seen.add("lds") listener = generic_xds_config["xds_config"] - if listener['name'] != server_uri: - logger.info('Invalid Listener name %s != %s', - listener_name, server_uri) + if listener["name"] != server_uri: + logger.info( + "Invalid Listener name %s != %s", + listener_name, + server_uri, + ) ok = False - elif re.search(r'\.RouteConfiguration$', - generic_xds_config['type_url']): - seen.add('rds') + elif re.search( + r"\.RouteConfiguration$", generic_xds_config["type_url"] + ): + seen.add("rds") route_config = generic_xds_config["xds_config"] - if not len(route_config['virtual_hosts']): - logger.info('Invalid 
number of VirtualHosts %s', - num_vh) + if not len(route_config["virtual_hosts"]): + logger.info( + "Invalid number of VirtualHosts %s", num_vh + ) ok = False - elif re.search(r'\.Cluster$', - generic_xds_config['type_url']): - seen.add('cds') + elif re.search( + r"\.Cluster$", generic_xds_config["type_url"] + ): + seen.add("cds") cluster = generic_xds_config["xds_config"] - if cluster['type'] != 'EDS': - logger.info('Invalid cluster type %s != EDS', - cluster_type) + if cluster["type"] != "EDS": + logger.info( + "Invalid cluster type %s != EDS", cluster_type + ) ok = False - elif re.search(r'\.ClusterLoadAssignment$', - generic_xds_config['type_url']): - seen.add('eds') + elif re.search( + r"\.ClusterLoadAssignment$", + generic_xds_config["type_url"], + ): + seen.add("eds") endpoint = generic_xds_config["xds_config"] - if args.zone not in endpoint["endpoints"][0][ - "locality"]["sub_zone"]: - logger.info('Invalid endpoint sub_zone %s', - sub_zone) + if ( + args.zone + not in endpoint["endpoints"][0]["locality"][ + "sub_zone" + ] + ): + logger.info( + "Invalid endpoint sub_zone %s", sub_zone + ) ok = False - want = {'lds', 'rds', 'cds', 'eds'} + want = {"lds", "rds", "cds", "eds"} if seen != want: - logger.info('Incomplete xDS config dump, seen=%s', seen) + logger.info("Incomplete xDS config dump, seen=%s", seen) ok = False except: - logger.exception('Error in xDS config dump:') + logger.exception("Error in xDS config dump:") ok = False finally: if ok: # Successfully fetched xDS config, and they looks good. - logger.info('success') + logger.info("success") return - logger.info('test_csds attempt %d failed', cnt) + logger.info("test_csds attempt %d failed", cnt) # Give the client some time to fetch xDS resources time.sleep(sleep_interval_between_attempts_s) cnt += 1 - raise RuntimeError('failed to receive a valid xDS config in %s seconds' % - test_csds_timeout_s) + raise RuntimeError( + "failed to receive a valid xDS config in %s seconds" + % test_csds_timeout_s + ) def set_validate_for_proxyless(gcp, validate_for_proxyless): if not gcp.alpha_compute: logger.debug( - 'Not setting validateForProxy because alpha is not enabled') + "Not setting validateForProxy because alpha is not enabled" + ) return - if len(gcp.global_forwarding_rules) != 1 or len( - gcp.target_proxies) != 1 or len(gcp.url_maps) != 1: + if ( + len(gcp.global_forwarding_rules) != 1 + or len(gcp.target_proxies) != 1 + or len(gcp.url_maps) != 1 + ): logger.debug( - "Global forwarding rule, target proxy or url map not found.") + "Global forwarding rule, target proxy or url map not found." + ) return # This function deletes global_forwarding_rule and target_proxy, then # recreate target_proxy with validateForProxyless=False. 
This is necessary @@ -2384,17 +2901,18 @@ def set_validate_for_proxyless(gcp, validate_for_proxyless): def get_serving_status(instance, service_port): - with grpc.insecure_channel('%s:%d' % (instance, service_port)) as channel: + with grpc.insecure_channel("%s:%d" % (instance, service_port)) as channel: health_stub = health_pb2_grpc.HealthStub(channel) return health_stub.Check(health_pb2.HealthCheckRequest()) def set_serving_status(instances, service_port, serving): - logger.info('setting %s serving status to %s', instances, serving) + logger.info("setting %s serving status to %s", instances, serving) for instance in instances: - with grpc.insecure_channel('%s:%d' % - (instance, service_port)) as channel: - logger.info('setting %s serving status to %s', instance, serving) + with grpc.insecure_channel( + "%s:%d" % (instance, service_port) + ) as channel: + logger.info("setting %s serving status to %s", instance, serving) stub = test_pb2_grpc.XdsUpdateHealthServiceStub(channel) retry_count = 5 for i in range(5): @@ -2403,14 +2921,19 @@ def set_serving_status(instances, service_port, serving): else: stub.SetNotServing(empty_pb2.Empty()) serving_status = get_serving_status(instance, service_port) - logger.info('got instance service status %s', serving_status) - want_status = health_pb2.HealthCheckResponse.SERVING if serving else health_pb2.HealthCheckResponse.NOT_SERVING + logger.info("got instance service status %s", serving_status) + want_status = ( + health_pb2.HealthCheckResponse.SERVING + if serving + else health_pb2.HealthCheckResponse.NOT_SERVING + ) if serving_status.status == want_status: break if i == retry_count - 1: raise Exception( - 'failed to set instance service status after %d retries' - % retry_count) + "failed to set instance service status after %d retries" + % retry_count + ) def is_primary_instance_group(gcp, instance_group): @@ -2420,15 +2943,19 @@ def is_primary_instance_group(gcp, instance_group): instance_names = get_instance_names(gcp, instance_group) stats = get_client_stats(_NUM_TEST_RPCS, _WAIT_FOR_STATS_SEC) return all( - peer in instance_names for peer in list(stats.rpcs_by_peer.keys())) + peer in instance_names for peer in list(stats.rpcs_by_peer.keys()) + ) def get_startup_script(path_to_server_binary, service_port): if path_to_server_binary: - return 'nohup %s --port=%d 1>/dev/null &' % (path_to_server_binary, - service_port) + return "nohup %s --port=%d 1>/dev/null &" % ( + path_to_server_binary, + service_port, + ) else: - return """#!/bin/bash + return ( + """#!/bin/bash sudo apt update sudo apt install -y git default-jdk mkdir java_server @@ -2439,180 +2966,201 @@ def get_startup_script(path_to_server_binary, service_port): ../gradlew installDist -x test -PskipCodegen=true -PskipAndroid=true nohup build/install/grpc-interop-testing/bin/xds-test-server \ - --port=%d 1>/dev/null &""" % service_port + --port=%d 1>/dev/null &""" + % service_port + ) -def create_instance_template(gcp, name, network, source_image, machine_type, - startup_script): +def create_instance_template( + gcp, name, network, source_image, machine_type, startup_script +): config = { - 'name': name, - 'properties': { - 'tags': { - 'items': ['allow-health-checks'] + "name": name, + "properties": { + "tags": {"items": ["allow-health-checks"]}, + "machineType": machine_type, + "serviceAccounts": [ + { + "email": "default", + "scopes": [ + "https://www.googleapis.com/auth/cloud-platform", + ], + } + ], + "networkInterfaces": [ + { + "accessConfigs": [{"type": "ONE_TO_ONE_NAT"}], + "network": 
network, + } + ], + "disks": [ + { + "boot": True, + "initializeParams": {"sourceImage": source_image}, + "autoDelete": True, + } + ], + "metadata": { + "items": [{"key": "startup-script", "value": startup_script}] }, - 'machineType': machine_type, - 'serviceAccounts': [{ - 'email': 'default', - 'scopes': ['https://www.googleapis.com/auth/cloud-platform',] - }], - 'networkInterfaces': [{ - 'accessConfigs': [{ - 'type': 'ONE_TO_ONE_NAT' - }], - 'network': network - }], - 'disks': [{ - 'boot': True, - 'initializeParams': { - 'sourceImage': source_image - }, - 'autoDelete': True - }], - 'metadata': { - 'items': [{ - 'key': 'startup-script', - 'value': startup_script - }] - } - } + }, } - logger.debug('Sending GCP request with body=%s', config) - result = gcp.compute.instanceTemplates().insert( - project=gcp.project, body=config).execute(num_retries=_GCP_API_RETRIES) - wait_for_global_operation(gcp, result['name']) - gcp.instance_template = GcpResource(config['name'], result['targetLink']) + logger.debug("Sending GCP request with body=%s", config) + result = ( + gcp.compute.instanceTemplates() + .insert(project=gcp.project, body=config) + .execute(num_retries=_GCP_API_RETRIES) + ) + wait_for_global_operation(gcp, result["name"]) + gcp.instance_template = GcpResource(config["name"], result["targetLink"]) def add_instance_group(gcp, zone, name, size): config = { - 'name': name, - 'instanceTemplate': gcp.instance_template.url, - 'targetSize': size, - 'namedPorts': [{ - 'name': 'grpc', - 'port': gcp.service_port - }] + "name": name, + "instanceTemplate": gcp.instance_template.url, + "targetSize": size, + "namedPorts": [{"name": "grpc", "port": gcp.service_port}], } - logger.debug('Sending GCP request with body=%s', config) - result = gcp.compute.instanceGroupManagers().insert( - project=gcp.project, zone=zone, - body=config).execute(num_retries=_GCP_API_RETRIES) - wait_for_zone_operation(gcp, zone, result['name']) - result = gcp.compute.instanceGroupManagers().get( - project=gcp.project, zone=zone, - instanceGroupManager=config['name']).execute( - num_retries=_GCP_API_RETRIES) - instance_group = InstanceGroup(config['name'], result['instanceGroup'], - zone) + logger.debug("Sending GCP request with body=%s", config) + result = ( + gcp.compute.instanceGroupManagers() + .insert(project=gcp.project, zone=zone, body=config) + .execute(num_retries=_GCP_API_RETRIES) + ) + wait_for_zone_operation(gcp, zone, result["name"]) + result = ( + gcp.compute.instanceGroupManagers() + .get( + project=gcp.project, zone=zone, instanceGroupManager=config["name"] + ) + .execute(num_retries=_GCP_API_RETRIES) + ) + instance_group = InstanceGroup( + config["name"], result["instanceGroup"], zone + ) gcp.instance_groups.append(instance_group) - wait_for_instance_group_to_reach_expected_size(gcp, instance_group, size, - _WAIT_FOR_OPERATION_SEC) + wait_for_instance_group_to_reach_expected_size( + gcp, instance_group, size, _WAIT_FOR_OPERATION_SEC + ) return instance_group def create_health_check(gcp, name): if gcp.alpha_compute: config = { - 'name': name, - 'type': 'GRPC', - 'grpcHealthCheck': { - 'portSpecification': 'USE_SERVING_PORT' - } + "name": name, + "type": "GRPC", + "grpcHealthCheck": {"portSpecification": "USE_SERVING_PORT"}, } compute_to_use = gcp.alpha_compute else: config = { - 'name': name, - 'type': 'TCP', - 'tcpHealthCheck': { - 'portName': 'grpc' - } + "name": name, + "type": "TCP", + "tcpHealthCheck": {"portName": "grpc"}, } compute_to_use = gcp.compute - logger.debug('Sending GCP request with body=%s', 
config) - result = compute_to_use.healthChecks().insert( - project=gcp.project, body=config).execute(num_retries=_GCP_API_RETRIES) - wait_for_global_operation(gcp, result['name']) - gcp.health_check = GcpResource(config['name'], result['targetLink']) + logger.debug("Sending GCP request with body=%s", config) + result = ( + compute_to_use.healthChecks() + .insert(project=gcp.project, body=config) + .execute(num_retries=_GCP_API_RETRIES) + ) + wait_for_global_operation(gcp, result["name"]) + gcp.health_check = GcpResource(config["name"], result["targetLink"]) def create_health_check_firewall_rule(gcp, name): config = { - 'name': name, - 'direction': 'INGRESS', - 'allowed': [{ - 'IPProtocol': 'tcp' - }], - 'sourceRanges': ['35.191.0.0/16', '130.211.0.0/22'], - 'targetTags': ['allow-health-checks'], + "name": name, + "direction": "INGRESS", + "allowed": [{"IPProtocol": "tcp"}], + "sourceRanges": ["35.191.0.0/16", "130.211.0.0/22"], + "targetTags": ["allow-health-checks"], } - logger.debug('Sending GCP request with body=%s', config) - result = gcp.compute.firewalls().insert( - project=gcp.project, body=config).execute(num_retries=_GCP_API_RETRIES) - wait_for_global_operation(gcp, result['name']) - gcp.health_check_firewall_rule = GcpResource(config['name'], - result['targetLink']) + logger.debug("Sending GCP request with body=%s", config) + result = ( + gcp.compute.firewalls() + .insert(project=gcp.project, body=config) + .execute(num_retries=_GCP_API_RETRIES) + ) + wait_for_global_operation(gcp, result["name"]) + gcp.health_check_firewall_rule = GcpResource( + config["name"], result["targetLink"] + ) def add_backend_service(gcp, name): if gcp.alpha_compute: - protocol = 'GRPC' + protocol = "GRPC" compute_to_use = gcp.alpha_compute else: - protocol = 'HTTP2' + protocol = "HTTP2" compute_to_use = gcp.compute config = { - 'name': name, - 'loadBalancingScheme': 'INTERNAL_SELF_MANAGED', - 'healthChecks': [gcp.health_check.url], - 'portName': 'grpc', - 'protocol': protocol + "name": name, + "loadBalancingScheme": "INTERNAL_SELF_MANAGED", + "healthChecks": [gcp.health_check.url], + "portName": "grpc", + "protocol": protocol, } - logger.debug('Sending GCP request with body=%s', config) - result = compute_to_use.backendServices().insert( - project=gcp.project, body=config).execute(num_retries=_GCP_API_RETRIES) - wait_for_global_operation(gcp, result['name']) - backend_service = GcpResource(config['name'], result['targetLink']) + logger.debug("Sending GCP request with body=%s", config) + result = ( + compute_to_use.backendServices() + .insert(project=gcp.project, body=config) + .execute(num_retries=_GCP_API_RETRIES) + ) + wait_for_global_operation(gcp, result["name"]) + backend_service = GcpResource(config["name"], result["targetLink"]) gcp.backend_services.append(backend_service) return backend_service def create_url_map(gcp, name, backend_service, host_name): config = { - 'name': name, - 'defaultService': backend_service.url, - 'pathMatchers': [{ - 'name': _PATH_MATCHER_NAME, - 'defaultService': backend_service.url, - }], - 'hostRules': [{ - 'hosts': [host_name], - 'pathMatcher': _PATH_MATCHER_NAME - }] + "name": name, + "defaultService": backend_service.url, + "pathMatchers": [ + { + "name": _PATH_MATCHER_NAME, + "defaultService": backend_service.url, + } + ], + "hostRules": [ + {"hosts": [host_name], "pathMatcher": _PATH_MATCHER_NAME} + ], } - logger.debug('Sending GCP request with body=%s', config) - result = gcp.compute.urlMaps().insert( - project=gcp.project, 
body=config).execute(num_retries=_GCP_API_RETRIES) - wait_for_global_operation(gcp, result['name']) - url_map = GcpResource(config['name'], result['targetLink']) + logger.debug("Sending GCP request with body=%s", config) + result = ( + gcp.compute.urlMaps() + .insert(project=gcp.project, body=config) + .execute(num_retries=_GCP_API_RETRIES) + ) + wait_for_global_operation(gcp, result["name"]) + url_map = GcpResource(config["name"], result["targetLink"]) gcp.url_maps.append(url_map) return url_map def patch_url_map_host_rule_with_port(gcp, name, backend_service, host_name): config = { - 'hostRules': [{ - 'hosts': ['%s:%d' % (host_name, gcp.service_port)], - 'pathMatcher': _PATH_MATCHER_NAME - }] + "hostRules": [ + { + "hosts": ["%s:%d" % (host_name, gcp.service_port)], + "pathMatcher": _PATH_MATCHER_NAME, + } + ] } - logger.debug('Sending GCP request with body=%s', config) - result = gcp.compute.urlMaps().patch( - project=gcp.project, urlMap=name, - body=config).execute(num_retries=_GCP_API_RETRIES) - wait_for_global_operation(gcp, result['name']) + logger.debug("Sending GCP request with body=%s", config) + result = ( + gcp.compute.urlMaps() + .patch(project=gcp.project, urlMap=name, body=config) + .execute(num_retries=_GCP_API_RETRIES) + ) + wait_for_global_operation(gcp, result["name"]) def create_target_proxy(gcp, name, validate_for_proxyless=True, url_map=None): @@ -2622,34 +3170,40 @@ def create_target_proxy(gcp, name, validate_for_proxyless=True, url_map=None): arg_url_map_url = gcp.url_maps[0].url if gcp.alpha_compute: config = { - 'name': name, - 'url_map': arg_url_map_url, - 'validate_for_proxyless': validate_for_proxyless + "name": name, + "url_map": arg_url_map_url, + "validate_for_proxyless": validate_for_proxyless, } - logger.debug('Sending GCP request with body=%s', config) - result = gcp.alpha_compute.targetGrpcProxies().insert( - project=gcp.project, - body=config).execute(num_retries=_GCP_API_RETRIES) + logger.debug("Sending GCP request with body=%s", config) + result = ( + gcp.alpha_compute.targetGrpcProxies() + .insert(project=gcp.project, body=config) + .execute(num_retries=_GCP_API_RETRIES) + ) else: config = { - 'name': name, - 'url_map': arg_url_map_url, + "name": name, + "url_map": arg_url_map_url, } - logger.debug('Sending GCP request with body=%s', config) - result = gcp.compute.targetHttpProxies().insert( - project=gcp.project, - body=config).execute(num_retries=_GCP_API_RETRIES) - wait_for_global_operation(gcp, result['name']) - target_proxy = GcpResource(config['name'], result['targetLink']) + logger.debug("Sending GCP request with body=%s", config) + result = ( + gcp.compute.targetHttpProxies() + .insert(project=gcp.project, body=config) + .execute(num_retries=_GCP_API_RETRIES) + ) + wait_for_global_operation(gcp, result["name"]) + target_proxy = GcpResource(config["name"], result["targetLink"]) gcp.target_proxies.append(target_proxy) return target_proxy -def create_global_forwarding_rule(gcp, - name, - potential_ports, - potential_ip_addresses=['0.0.0.0'], - target_proxy=None): +def create_global_forwarding_rule( + gcp, + name, + potential_ports, + potential_ip_addresses=["0.0.0.0"], + target_proxy=None, +): if target_proxy: arg_target_proxy_url = target_proxy.url else: @@ -2662,35 +3216,42 @@ def create_global_forwarding_rule(gcp, for ip_address in potential_ip_addresses: try: config = { - 'name': name, - 'loadBalancingScheme': 'INTERNAL_SELF_MANAGED', - 'portRange': str(port), - 'IPAddress': ip_address, - 'network': args.network, - 'target': 
arg_target_proxy_url, + "name": name, + "loadBalancingScheme": "INTERNAL_SELF_MANAGED", + "portRange": str(port), + "IPAddress": ip_address, + "network": args.network, + "target": arg_target_proxy_url, } - logger.debug('Sending GCP request with body=%s', config) - result = compute_to_use.globalForwardingRules().insert( - project=gcp.project, - body=config).execute(num_retries=_GCP_API_RETRIES) - wait_for_global_operation(gcp, result['name']) - global_forwarding_rule = GcpResource(config['name'], - result['targetLink']) + logger.debug("Sending GCP request with body=%s", config) + result = ( + compute_to_use.globalForwardingRules() + .insert(project=gcp.project, body=config) + .execute(num_retries=_GCP_API_RETRIES) + ) + wait_for_global_operation(gcp, result["name"]) + global_forwarding_rule = GcpResource( + config["name"], result["targetLink"] + ) gcp.global_forwarding_rules.append(global_forwarding_rule) gcp.service_port = port return except googleapiclient.errors.HttpError as http_error: logger.warning( - 'Got error %s when attempting to create forwarding rule to ' - '%s:%d. Retrying with another port.' % - (http_error, ip_address, port)) + "Got error %s when attempting to create forwarding rule to " + "%s:%d. Retrying with another port." + % (http_error, ip_address, port) + ) def get_health_check(gcp, health_check_name): try: - result = gcp.compute.healthChecks().get( - project=gcp.project, healthCheck=health_check_name).execute() - gcp.health_check = GcpResource(health_check_name, result['selfLink']) + result = ( + gcp.compute.healthChecks() + .get(project=gcp.project, healthCheck=health_check_name) + .execute() + ) + gcp.health_check = GcpResource(health_check_name, result["selfLink"]) except Exception as e: gcp.errors.append(e) gcp.health_check = GcpResource(health_check_name, None) @@ -2698,10 +3259,14 @@ def get_health_check(gcp, health_check_name): def get_health_check_firewall_rule(gcp, firewall_name): try: - result = gcp.compute.firewalls().get(project=gcp.project, - firewall=firewall_name).execute() - gcp.health_check_firewall_rule = GcpResource(firewall_name, - result['selfLink']) + result = ( + gcp.compute.firewalls() + .get(project=gcp.project, firewall=firewall_name) + .execute() + ) + gcp.health_check_firewall_rule = GcpResource( + firewall_name, result["selfLink"] + ) except Exception as e: gcp.errors.append(e) gcp.health_check_firewall_rule = GcpResource(firewall_name, None) @@ -2709,9 +3274,12 @@ def get_health_check_firewall_rule(gcp, firewall_name): def get_backend_service(gcp, backend_service_name, record_error=True): try: - result = gcp.compute.backendServices().get( - project=gcp.project, backendService=backend_service_name).execute() - backend_service = GcpResource(backend_service_name, result['selfLink']) + result = ( + gcp.compute.backendServices() + .get(project=gcp.project, backendService=backend_service_name) + .execute() + ) + backend_service = GcpResource(backend_service_name, result["selfLink"]) except Exception as e: if record_error: gcp.errors.append(e) @@ -2722,9 +3290,12 @@ def get_backend_service(gcp, backend_service_name, record_error=True): def get_url_map(gcp, url_map_name, record_error=True): try: - result = gcp.compute.urlMaps().get(project=gcp.project, - urlMap=url_map_name).execute() - url_map = GcpResource(url_map_name, result['selfLink']) + result = ( + gcp.compute.urlMaps() + .get(project=gcp.project, urlMap=url_map_name) + .execute() + ) + url_map = GcpResource(url_map_name, result["selfLink"]) gcp.url_maps.append(url_map) except Exception 
as e: if record_error: @@ -2734,14 +3305,18 @@ def get_url_map(gcp, url_map_name, record_error=True): def get_target_proxy(gcp, target_proxy_name, record_error=True): try: if gcp.alpha_compute: - result = gcp.alpha_compute.targetGrpcProxies().get( - project=gcp.project, - targetGrpcProxy=target_proxy_name).execute() + result = ( + gcp.alpha_compute.targetGrpcProxies() + .get(project=gcp.project, targetGrpcProxy=target_proxy_name) + .execute() + ) else: - result = gcp.compute.targetHttpProxies().get( - project=gcp.project, - targetHttpProxy=target_proxy_name).execute() - target_proxy = GcpResource(target_proxy_name, result['selfLink']) + result = ( + gcp.compute.targetHttpProxies() + .get(project=gcp.project, targetHttpProxy=target_proxy_name) + .execute() + ) + target_proxy = GcpResource(target_proxy_name, result["selfLink"]) gcp.target_proxies.append(target_proxy) except Exception as e: if record_error: @@ -2750,10 +3325,14 @@ def get_target_proxy(gcp, target_proxy_name, record_error=True): def get_global_forwarding_rule(gcp, forwarding_rule_name, record_error=True): try: - result = gcp.compute.globalForwardingRules().get( - project=gcp.project, forwardingRule=forwarding_rule_name).execute() - global_forwarding_rule = GcpResource(forwarding_rule_name, - result['selfLink']) + result = ( + gcp.compute.globalForwardingRules() + .get(project=gcp.project, forwardingRule=forwarding_rule_name) + .execute() + ) + global_forwarding_rule = GcpResource( + forwarding_rule_name, result["selfLink"] + ) gcp.global_forwarding_rules.append(global_forwarding_rule) except Exception as e: if record_error: @@ -2762,9 +3341,12 @@ def get_global_forwarding_rule(gcp, forwarding_rule_name, record_error=True): def get_instance_template(gcp, template_name): try: - result = gcp.compute.instanceTemplates().get( - project=gcp.project, instanceTemplate=template_name).execute() - gcp.instance_template = GcpResource(template_name, result['selfLink']) + result = ( + gcp.compute.instanceTemplates() + .get(project=gcp.project, instanceTemplate=template_name) + .execute() + ) + gcp.instance_template = GcpResource(template_name, result["selfLink"]) except Exception as e: gcp.errors.append(e) gcp.instance_template = GcpResource(template_name, None) @@ -2772,12 +3354,19 @@ def get_instance_template(gcp, template_name): def get_instance_group(gcp, zone, instance_group_name): try: - result = gcp.compute.instanceGroups().get( - project=gcp.project, zone=zone, - instanceGroup=instance_group_name).execute() - gcp.service_port = result['namedPorts'][0]['port'] - instance_group = InstanceGroup(instance_group_name, result['selfLink'], - zone) + result = ( + gcp.compute.instanceGroups() + .get( + project=gcp.project, + zone=zone, + instanceGroup=instance_group_name, + ) + .execute() + ) + gcp.service_port = result["namedPorts"][0]["port"] + instance_group = InstanceGroup( + instance_group_name, result["selfLink"], zone + ) except Exception as e: gcp.errors.append(e) instance_group = InstanceGroup(instance_group_name, None, zone) @@ -2789,21 +3378,30 @@ def delete_global_forwarding_rule(gcp, forwarding_rule_to_delete=None): if not forwarding_rule_to_delete: return try: - logger.debug('Deleting forwarding rule %s', - forwarding_rule_to_delete.name) - result = gcp.compute.globalForwardingRules().delete( - project=gcp.project, - forwardingRule=forwarding_rule_to_delete.name).execute( - num_retries=_GCP_API_RETRIES) - wait_for_global_operation(gcp, result['name']) + logger.debug( + "Deleting forwarding rule %s", 
forwarding_rule_to_delete.name + ) + result = ( + gcp.compute.globalForwardingRules() + .delete( + project=gcp.project, + forwardingRule=forwarding_rule_to_delete.name, + ) + .execute(num_retries=_GCP_API_RETRIES) + ) + wait_for_global_operation(gcp, result["name"]) if forwarding_rule_to_delete in gcp.global_forwarding_rules: gcp.global_forwarding_rules.remove(forwarding_rule_to_delete) else: logger.debug( - 'Forwarding rule %s does not exist in gcp.global_forwarding_rules', - forwarding_rule_to_delete.name) + ( + "Forwarding rule %s does not exist in" + " gcp.global_forwarding_rules" + ), + forwarding_rule_to_delete.name, + ) except googleapiclient.errors.HttpError as http_error: - logger.info('Delete failed: %s', http_error) + logger.info("Delete failed: %s", http_error) def delete_global_forwarding_rules(gcp): @@ -2817,25 +3415,33 @@ def delete_target_proxy(gcp, proxy_to_delete=None): return try: if gcp.alpha_compute: - logger.debug('Deleting grpc proxy %s', proxy_to_delete.name) - result = gcp.alpha_compute.targetGrpcProxies().delete( - project=gcp.project, - targetGrpcProxy=proxy_to_delete.name).execute( - num_retries=_GCP_API_RETRIES) + logger.debug("Deleting grpc proxy %s", proxy_to_delete.name) + result = ( + gcp.alpha_compute.targetGrpcProxies() + .delete( + project=gcp.project, targetGrpcProxy=proxy_to_delete.name + ) + .execute(num_retries=_GCP_API_RETRIES) + ) else: - logger.debug('Deleting http proxy %s', proxy_to_delete.name) - result = gcp.compute.targetHttpProxies().delete( - project=gcp.project, - targetHttpProxy=proxy_to_delete.name).execute( - num_retries=_GCP_API_RETRIES) - wait_for_global_operation(gcp, result['name']) + logger.debug("Deleting http proxy %s", proxy_to_delete.name) + result = ( + gcp.compute.targetHttpProxies() + .delete( + project=gcp.project, targetHttpProxy=proxy_to_delete.name + ) + .execute(num_retries=_GCP_API_RETRIES) + ) + wait_for_global_operation(gcp, result["name"]) if proxy_to_delete in gcp.target_proxies: gcp.target_proxies.remove(proxy_to_delete) else: - logger.debug('Gcp proxy %s does not exist in gcp.target_proxies', - proxy_to_delete.name) + logger.debug( + "Gcp proxy %s does not exist in gcp.target_proxies", + proxy_to_delete.name, + ) except googleapiclient.errors.HttpError as http_error: - logger.info('Delete failed: %s', http_error) + logger.info("Delete failed: %s", http_error) def delete_target_proxies(gcp): @@ -2848,18 +3454,22 @@ def delete_url_map(gcp, url_map_to_delete=None): if not url_map_to_delete: return try: - logger.debug('Deleting url map %s', url_map_to_delete.name) - result = gcp.compute.urlMaps().delete( - project=gcp.project, - urlMap=url_map_to_delete.name).execute(num_retries=_GCP_API_RETRIES) - wait_for_global_operation(gcp, result['name']) + logger.debug("Deleting url map %s", url_map_to_delete.name) + result = ( + gcp.compute.urlMaps() + .delete(project=gcp.project, urlMap=url_map_to_delete.name) + .execute(num_retries=_GCP_API_RETRIES) + ) + wait_for_global_operation(gcp, result["name"]) if url_map_to_delete in gcp.url_maps: gcp.url_maps.remove(url_map_to_delete) else: - logger.debug('Url map %s does not exist in gcp.url_maps', - url_map_to_delete.name) + logger.debug( + "Url map %s does not exist in gcp.url_maps", + url_map_to_delete.name, + ) except googleapiclient.errors.HttpError as http_error: - logger.info('Delete failed: %s', http_error) + logger.info("Delete failed: %s", http_error) def delete_url_maps(gcp): @@ -2870,13 +3480,15 @@ def delete_url_maps(gcp): def delete_backend_service(gcp, 
backend_service): try: - logger.debug('Deleting backend service %s', backend_service.name) - result = gcp.compute.backendServices().delete( - project=gcp.project, backendService=backend_service.name).execute( - num_retries=_GCP_API_RETRIES) - wait_for_global_operation(gcp, result['name']) + logger.debug("Deleting backend service %s", backend_service.name) + result = ( + gcp.compute.backendServices() + .delete(project=gcp.project, backendService=backend_service.name) + .execute(num_retries=_GCP_API_RETRIES) + ) + wait_for_global_operation(gcp, result["name"]) except googleapiclient.errors.HttpError as http_error: - logger.info('Delete failed: %s', http_error) + logger.info("Delete failed: %s", http_error) def delete_backend_services(gcp): @@ -2886,116 +3498,153 @@ def delete_backend_services(gcp): def delete_firewall(gcp): try: - logger.debug('Deleting firewall %s', - gcp.health_check_firewall_rule.name) - result = gcp.compute.firewalls().delete( - project=gcp.project, - firewall=gcp.health_check_firewall_rule.name).execute( - num_retries=_GCP_API_RETRIES) - wait_for_global_operation(gcp, result['name']) + logger.debug( + "Deleting firewall %s", gcp.health_check_firewall_rule.name + ) + result = ( + gcp.compute.firewalls() + .delete( + project=gcp.project, + firewall=gcp.health_check_firewall_rule.name, + ) + .execute(num_retries=_GCP_API_RETRIES) + ) + wait_for_global_operation(gcp, result["name"]) except googleapiclient.errors.HttpError as http_error: - logger.info('Delete failed: %s', http_error) + logger.info("Delete failed: %s", http_error) def delete_health_check(gcp): try: - logger.debug('Deleting health check %s', gcp.health_check.name) - result = gcp.compute.healthChecks().delete( - project=gcp.project, healthCheck=gcp.health_check.name).execute( - num_retries=_GCP_API_RETRIES) - wait_for_global_operation(gcp, result['name']) + logger.debug("Deleting health check %s", gcp.health_check.name) + result = ( + gcp.compute.healthChecks() + .delete(project=gcp.project, healthCheck=gcp.health_check.name) + .execute(num_retries=_GCP_API_RETRIES) + ) + wait_for_global_operation(gcp, result["name"]) except googleapiclient.errors.HttpError as http_error: - logger.info('Delete failed: %s', http_error) + logger.info("Delete failed: %s", http_error) def delete_instance_groups(gcp): for instance_group in gcp.instance_groups: try: - logger.debug('Deleting instance group %s %s', instance_group.name, - instance_group.zone) - result = gcp.compute.instanceGroupManagers().delete( - project=gcp.project, - zone=instance_group.zone, - instanceGroupManager=instance_group.name).execute( - num_retries=_GCP_API_RETRIES) - wait_for_zone_operation(gcp, - instance_group.zone, - result['name'], - timeout_sec=_WAIT_FOR_BACKEND_SEC) + logger.debug( + "Deleting instance group %s %s", + instance_group.name, + instance_group.zone, + ) + result = ( + gcp.compute.instanceGroupManagers() + .delete( + project=gcp.project, + zone=instance_group.zone, + instanceGroupManager=instance_group.name, + ) + .execute(num_retries=_GCP_API_RETRIES) + ) + wait_for_zone_operation( + gcp, + instance_group.zone, + result["name"], + timeout_sec=_WAIT_FOR_BACKEND_SEC, + ) except googleapiclient.errors.HttpError as http_error: - logger.info('Delete failed: %s', http_error) + logger.info("Delete failed: %s", http_error) def delete_instance_template(gcp): try: - logger.debug('Deleting instance template %s', - gcp.instance_template.name) - result = gcp.compute.instanceTemplates().delete( - project=gcp.project, - 
instanceTemplate=gcp.instance_template.name).execute( - num_retries=_GCP_API_RETRIES) - wait_for_global_operation(gcp, result['name']) + logger.debug( + "Deleting instance template %s", gcp.instance_template.name + ) + result = ( + gcp.compute.instanceTemplates() + .delete( + project=gcp.project, instanceTemplate=gcp.instance_template.name + ) + .execute(num_retries=_GCP_API_RETRIES) + ) + wait_for_global_operation(gcp, result["name"]) except googleapiclient.errors.HttpError as http_error: - logger.info('Delete failed: %s', http_error) + logger.info("Delete failed: %s", http_error) -def patch_backend_service(gcp, - backend_service, - instance_groups, - balancing_mode='UTILIZATION', - max_rate=1, - circuit_breakers=None): +def patch_backend_service( + gcp, + backend_service, + instance_groups, + balancing_mode="UTILIZATION", + max_rate=1, + circuit_breakers=None, +): if gcp.alpha_compute: compute_to_use = gcp.alpha_compute else: compute_to_use = gcp.compute config = { - 'backends': [{ - 'group': instance_group.url, - 'balancingMode': balancing_mode, - 'maxRate': max_rate if balancing_mode == 'RATE' else None - } for instance_group in instance_groups], - 'circuitBreakers': circuit_breakers, + "backends": [ + { + "group": instance_group.url, + "balancingMode": balancing_mode, + "maxRate": max_rate if balancing_mode == "RATE" else None, + } + for instance_group in instance_groups + ], + "circuitBreakers": circuit_breakers, } - logger.debug('Sending GCP request with body=%s', config) - result = compute_to_use.backendServices().patch( - project=gcp.project, backendService=backend_service.name, - body=config).execute(num_retries=_GCP_API_RETRIES) - wait_for_global_operation(gcp, - result['name'], - timeout_sec=_WAIT_FOR_BACKEND_SEC) - - -def resize_instance_group(gcp, - instance_group, - new_size, - timeout_sec=_WAIT_FOR_OPERATION_SEC): - result = gcp.compute.instanceGroupManagers().resize( - project=gcp.project, - zone=instance_group.zone, - instanceGroupManager=instance_group.name, - size=new_size).execute(num_retries=_GCP_API_RETRIES) - wait_for_zone_operation(gcp, - instance_group.zone, - result['name'], - timeout_sec=360) - wait_for_instance_group_to_reach_expected_size(gcp, instance_group, - new_size, timeout_sec) - - -def patch_url_map_backend_service(gcp, - backend_service=None, - services_with_weights=None, - route_rules=None, - url_map=None): + logger.debug("Sending GCP request with body=%s", config) + result = ( + compute_to_use.backendServices() + .patch( + project=gcp.project, + backendService=backend_service.name, + body=config, + ) + .execute(num_retries=_GCP_API_RETRIES) + ) + wait_for_global_operation( + gcp, result["name"], timeout_sec=_WAIT_FOR_BACKEND_SEC + ) + + +def resize_instance_group( + gcp, instance_group, new_size, timeout_sec=_WAIT_FOR_OPERATION_SEC +): + result = ( + gcp.compute.instanceGroupManagers() + .resize( + project=gcp.project, + zone=instance_group.zone, + instanceGroupManager=instance_group.name, + size=new_size, + ) + .execute(num_retries=_GCP_API_RETRIES) + ) + wait_for_zone_operation( + gcp, instance_group.zone, result["name"], timeout_sec=360 + ) + wait_for_instance_group_to_reach_expected_size( + gcp, instance_group, new_size, timeout_sec + ) + + +def patch_url_map_backend_service( + gcp, + backend_service=None, + services_with_weights=None, + route_rules=None, + url_map=None, +): if url_map: url_map_name = url_map.name else: url_map_name = gcp.url_maps[0].name - '''change url_map's backend service + """change url_map's backend service Only one of 
backend_service and service_with_weights can be not None. - ''' + """ if gcp.alpha_compute: compute_to_use = gcp.alpha_compute else: @@ -3003,33 +3652,46 @@ def patch_url_map_backend_service(gcp, if backend_service and services_with_weights: raise ValueError( - 'both backend_service and service_with_weights are not None.') + "both backend_service and service_with_weights are not None." + ) default_service = backend_service.url if backend_service else None - default_route_action = { - 'weightedBackendServices': [{ - 'backendService': service.url, - 'weight': w, - } for service, w in list(services_with_weights.items())] - } if services_with_weights else None + default_route_action = ( + { + "weightedBackendServices": [ + { + "backendService": service.url, + "weight": w, + } + for service, w in list(services_with_weights.items()) + ] + } + if services_with_weights + else None + ) config = { - 'pathMatchers': [{ - 'name': _PATH_MATCHER_NAME, - 'defaultService': default_service, - 'defaultRouteAction': default_route_action, - 'routeRules': route_rules, - }] + "pathMatchers": [ + { + "name": _PATH_MATCHER_NAME, + "defaultService": default_service, + "defaultRouteAction": default_route_action, + "routeRules": route_rules, + } + ] } - logger.debug('Sending GCP request with body=%s', config) - result = compute_to_use.urlMaps().patch( - project=gcp.project, urlMap=url_map_name, - body=config).execute(num_retries=_GCP_API_RETRIES) - wait_for_global_operation(gcp, result['name']) + logger.debug("Sending GCP request with body=%s", config) + result = ( + compute_to_use.urlMaps() + .patch(project=gcp.project, urlMap=url_map_name, body=config) + .execute(num_retries=_GCP_API_RETRIES) + ) + wait_for_global_operation(gcp, result["name"]) -def wait_for_instance_group_to_reach_expected_size(gcp, instance_group, - expected_size, timeout_sec): +def wait_for_instance_group_to_reach_expected_size( + gcp, instance_group, expected_size, timeout_sec +): start_time = time.time() while True: current_size = len(get_instance_names(gcp, instance_group)) @@ -3037,102 +3699,121 @@ def wait_for_instance_group_to_reach_expected_size(gcp, instance_group, break if time.time() - start_time > timeout_sec: raise Exception( - 'Instance group had expected size %d but actual size %d' % - (expected_size, current_size)) + "Instance group had expected size %d but actual size %d" + % (expected_size, current_size) + ) time.sleep(2) -def wait_for_global_operation(gcp, - operation, - timeout_sec=_WAIT_FOR_OPERATION_SEC): +def wait_for_global_operation( + gcp, operation, timeout_sec=_WAIT_FOR_OPERATION_SEC +): start_time = time.time() while time.time() - start_time <= timeout_sec: - result = gcp.compute.globalOperations().get( - project=gcp.project, - operation=operation).execute(num_retries=_GCP_API_RETRIES) - if result['status'] == 'DONE': - if 'error' in result: - raise Exception(result['error']) + result = ( + gcp.compute.globalOperations() + .get(project=gcp.project, operation=operation) + .execute(num_retries=_GCP_API_RETRIES) + ) + if result["status"] == "DONE": + if "error" in result: + raise Exception(result["error"]) return time.sleep(2) - raise Exception('Operation %s did not complete within %d' % - (operation, timeout_sec)) + raise Exception( + "Operation %s did not complete within %d" % (operation, timeout_sec) + ) -def wait_for_zone_operation(gcp, - zone, - operation, - timeout_sec=_WAIT_FOR_OPERATION_SEC): +def wait_for_zone_operation( + gcp, zone, operation, timeout_sec=_WAIT_FOR_OPERATION_SEC +): start_time = time.time() 
while time.time() - start_time <= timeout_sec: - result = gcp.compute.zoneOperations().get( - project=gcp.project, zone=zone, - operation=operation).execute(num_retries=_GCP_API_RETRIES) - if result['status'] == 'DONE': - if 'error' in result: - raise Exception(result['error']) + result = ( + gcp.compute.zoneOperations() + .get(project=gcp.project, zone=zone, operation=operation) + .execute(num_retries=_GCP_API_RETRIES) + ) + if result["status"] == "DONE": + if "error" in result: + raise Exception(result["error"]) return time.sleep(2) - raise Exception('Operation %s did not complete within %d' % - (operation, timeout_sec)) + raise Exception( + "Operation %s did not complete within %d" % (operation, timeout_sec) + ) -def wait_for_healthy_backends(gcp, - backend_service, - instance_group, - timeout_sec=_WAIT_FOR_BACKEND_SEC): +def wait_for_healthy_backends( + gcp, backend_service, instance_group, timeout_sec=_WAIT_FOR_BACKEND_SEC +): start_time = time.time() - config = {'group': instance_group.url} + config = {"group": instance_group.url} instance_names = get_instance_names(gcp, instance_group) expected_size = len(instance_names) while time.time() - start_time <= timeout_sec: for instance_name in instance_names: try: status = get_serving_status(instance_name, gcp.service_port) - logger.info('serving status response from %s: %s', - instance_name, status) + logger.info( + "serving status response from %s: %s", instance_name, status + ) except grpc.RpcError as rpc_error: - logger.info('checking serving status of %s failed: %s', - instance_name, rpc_error) - result = gcp.compute.backendServices().getHealth( - project=gcp.project, - backendService=backend_service.name, - body=config).execute(num_retries=_GCP_API_RETRIES) - if 'healthStatus' in result: - logger.info('received GCP healthStatus: %s', result['healthStatus']) + logger.info( + "checking serving status of %s failed: %s", + instance_name, + rpc_error, + ) + result = ( + gcp.compute.backendServices() + .getHealth( + project=gcp.project, + backendService=backend_service.name, + body=config, + ) + .execute(num_retries=_GCP_API_RETRIES) + ) + if "healthStatus" in result: + logger.info("received GCP healthStatus: %s", result["healthStatus"]) healthy = True - for instance in result['healthStatus']: - if instance['healthState'] != 'HEALTHY': + for instance in result["healthStatus"]: + if instance["healthState"] != "HEALTHY": healthy = False break - if healthy and expected_size == len(result['healthStatus']): + if healthy and expected_size == len(result["healthStatus"]): return else: - logger.info('no healthStatus received from GCP') + logger.info("no healthStatus received from GCP") time.sleep(5) - raise Exception('Not all backends became healthy within %d seconds: %s' % - (timeout_sec, result)) + raise Exception( + "Not all backends became healthy within %d seconds: %s" + % (timeout_sec, result) + ) def get_instance_names(gcp, instance_group): instance_names = [] - result = gcp.compute.instanceGroups().listInstances( - project=gcp.project, - zone=instance_group.zone, - instanceGroup=instance_group.name, - body={ - 'instanceState': 'ALL' - }).execute(num_retries=_GCP_API_RETRIES) - if 'items' not in result: + result = ( + gcp.compute.instanceGroups() + .listInstances( + project=gcp.project, + zone=instance_group.zone, + instanceGroup=instance_group.name, + body={"instanceState": "ALL"}, + ) + .execute(num_retries=_GCP_API_RETRIES) + ) + if "items" not in result: return [] - for item in result['items']: + for item in result["items"]: # 
listInstances() returns the full URL of the instance, which ends with # the instance name. compute.instances().get() requires using the # instance name (not the full URL) to look up instance details, so we # just extract the name manually. - instance_name = item['instance'].split('/')[-1] + instance_name = item["instance"].split("/")[-1] instance_names.append(instance_name) - logger.info('retrieved instance names: %s', instance_names) + logger.info("retrieved instance names: %s", instance_names) return instance_names @@ -3151,7 +3832,6 @@ def clean_up(gcp): class InstanceGroup(object): - def __init__(self, name, url, zone): self.name = name self.url = url @@ -3159,14 +3839,12 @@ def __init__(self, name, url, zone): class GcpResource(object): - def __init__(self, name, url): self.name = name self.url = url class GcpState(object): - def __init__(self, compute, alpha_compute, project, project_num): self.compute = compute self.alpha_compute = alpha_compute @@ -3186,23 +3864,29 @@ def __init__(self, compute, alpha_compute, project, project_num): logging.debug( "script start time: %s", - datetime.datetime.now( - datetime.timezone.utc).astimezone().strftime("%Y-%m-%dT%H:%M:%S %Z")) -logging.debug("logging local timezone: %s", - datetime.datetime.now(datetime.timezone.utc).astimezone().tzinfo) + datetime.datetime.now(datetime.timezone.utc) + .astimezone() + .strftime("%Y-%m-%dT%H:%M:%S %Z"), +) +logging.debug( + "logging local timezone: %s", + datetime.datetime.now(datetime.timezone.utc).astimezone().tzinfo, +) alpha_compute = None if args.compute_discovery_document: - with open(args.compute_discovery_document, 'r') as discovery_doc: + with open(args.compute_discovery_document, "r") as discovery_doc: compute = googleapiclient.discovery.build_from_document( - discovery_doc.read()) + discovery_doc.read() + ) if not args.only_stable_gcp_apis and args.alpha_compute_discovery_document: - with open(args.alpha_compute_discovery_document, 'r') as discovery_doc: + with open(args.alpha_compute_discovery_document, "r") as discovery_doc: alpha_compute = googleapiclient.discovery.build_from_document( - discovery_doc.read()) + discovery_doc.read() + ) else: - compute = googleapiclient.discovery.build('compute', 'v1') + compute = googleapiclient.discovery.build("compute", "v1") if not args.only_stable_gcp_apis: - alpha_compute = googleapiclient.discovery.build('compute', 'alpha') + alpha_compute = googleapiclient.discovery.build("compute", "alpha") test_results = {} failed_tests = [] @@ -3220,151 +3904,196 @@ def __init__(self, compute, alpha_compute, project, project_num): num_attempts = 5 for i in range(num_attempts): try: - logger.info('Using GCP suffix %s', gcp_suffix) + logger.info("Using GCP suffix %s", gcp_suffix) create_health_check(gcp, health_check_name) break except googleapiclient.errors.HttpError as http_error: - gcp_suffix = '%s-%04d' % (gcp_suffix, random.randint(0, 9999)) + gcp_suffix = "%s-%04d" % (gcp_suffix, random.randint(0, 9999)) health_check_name = _BASE_HEALTH_CHECK_NAME + gcp_suffix - logger.exception('HttpError when creating health check') + logger.exception("HttpError when creating health check") if gcp.health_check is None: - raise Exception('Failed to create health check name after %d ' - 'attempts' % num_attempts) + raise Exception( + "Failed to create health check name after %d attempts" + % num_attempts + ) firewall_name = _BASE_FIREWALL_RULE_NAME + gcp_suffix backend_service_name = _BASE_BACKEND_SERVICE_NAME + gcp_suffix - alternate_backend_service_name = 
_BASE_BACKEND_SERVICE_NAME + '-alternate' + gcp_suffix - extra_backend_service_name = _BASE_BACKEND_SERVICE_NAME + '-extra' + gcp_suffix - more_extra_backend_service_name = _BASE_BACKEND_SERVICE_NAME + '-more-extra' + gcp_suffix + alternate_backend_service_name = ( + _BASE_BACKEND_SERVICE_NAME + "-alternate" + gcp_suffix + ) + extra_backend_service_name = ( + _BASE_BACKEND_SERVICE_NAME + "-extra" + gcp_suffix + ) + more_extra_backend_service_name = ( + _BASE_BACKEND_SERVICE_NAME + "-more-extra" + gcp_suffix + ) url_map_name = _BASE_URL_MAP_NAME + gcp_suffix - url_map_name_2 = url_map_name + '2' + url_map_name_2 = url_map_name + "2" service_host_name = _BASE_SERVICE_HOST + gcp_suffix target_proxy_name = _BASE_TARGET_PROXY_NAME + gcp_suffix - target_proxy_name_2 = target_proxy_name + '2' + target_proxy_name_2 = target_proxy_name + "2" forwarding_rule_name = _BASE_FORWARDING_RULE_NAME + gcp_suffix - forwarding_rule_name_2 = forwarding_rule_name + '2' + forwarding_rule_name_2 = forwarding_rule_name + "2" template_name = _BASE_TEMPLATE_NAME + gcp_suffix instance_group_name = _BASE_INSTANCE_GROUP_NAME + gcp_suffix - same_zone_instance_group_name = _BASE_INSTANCE_GROUP_NAME + '-same-zone' + gcp_suffix - secondary_zone_instance_group_name = _BASE_INSTANCE_GROUP_NAME + '-secondary-zone' + gcp_suffix + same_zone_instance_group_name = ( + _BASE_INSTANCE_GROUP_NAME + "-same-zone" + gcp_suffix + ) + secondary_zone_instance_group_name = ( + _BASE_INSTANCE_GROUP_NAME + "-secondary-zone" + gcp_suffix + ) potential_service_ports = list(args.service_port_range) random.shuffle(potential_service_ports) if args.use_existing_gcp_resources: - logger.info('Reusing existing GCP resources') + logger.info("Reusing existing GCP resources") get_health_check(gcp, health_check_name) get_health_check_firewall_rule(gcp, firewall_name) backend_service = get_backend_service(gcp, backend_service_name) alternate_backend_service = get_backend_service( - gcp, alternate_backend_service_name) - extra_backend_service = get_backend_service(gcp, - extra_backend_service_name, - record_error=False) + gcp, alternate_backend_service_name + ) + extra_backend_service = get_backend_service( + gcp, extra_backend_service_name, record_error=False + ) more_extra_backend_service = get_backend_service( - gcp, more_extra_backend_service_name, record_error=False) + gcp, more_extra_backend_service_name, record_error=False + ) get_url_map(gcp, url_map_name) get_target_proxy(gcp, target_proxy_name) get_global_forwarding_rule(gcp, forwarding_rule_name) get_url_map(gcp, url_map_name_2, record_error=False) get_target_proxy(gcp, target_proxy_name_2, record_error=False) - get_global_forwarding_rule(gcp, - forwarding_rule_name_2, - record_error=False) + get_global_forwarding_rule( + gcp, forwarding_rule_name_2, record_error=False + ) get_instance_template(gcp, template_name) instance_group = get_instance_group(gcp, args.zone, instance_group_name) same_zone_instance_group = get_instance_group( - gcp, args.zone, same_zone_instance_group_name) + gcp, args.zone, same_zone_instance_group_name + ) secondary_zone_instance_group = get_instance_group( - gcp, args.secondary_zone, secondary_zone_instance_group_name) + gcp, args.secondary_zone, secondary_zone_instance_group_name + ) if gcp.errors: raise Exception(gcp.errors) else: create_health_check_firewall_rule(gcp, firewall_name) backend_service = add_backend_service(gcp, backend_service_name) alternate_backend_service = add_backend_service( - gcp, alternate_backend_service_name) + gcp, 
alternate_backend_service_name + ) create_url_map(gcp, url_map_name, backend_service, service_host_name) create_target_proxy(gcp, target_proxy_name) - create_global_forwarding_rule(gcp, forwarding_rule_name, - potential_service_ports) + create_global_forwarding_rule( + gcp, forwarding_rule_name, potential_service_ports + ) if not gcp.service_port: raise Exception( - 'Failed to find a valid ip:port for the forwarding rule') + "Failed to find a valid ip:port for the forwarding rule" + ) if gcp.service_port != _DEFAULT_SERVICE_PORT: - patch_url_map_host_rule_with_port(gcp, url_map_name, - backend_service, - service_host_name) - startup_script = get_startup_script(args.path_to_server_binary, - gcp.service_port) - create_instance_template(gcp, template_name, args.network, - args.source_image, args.machine_type, - startup_script) - instance_group = add_instance_group(gcp, args.zone, instance_group_name, - _INSTANCE_GROUP_SIZE) + patch_url_map_host_rule_with_port( + gcp, url_map_name, backend_service, service_host_name + ) + startup_script = get_startup_script( + args.path_to_server_binary, gcp.service_port + ) + create_instance_template( + gcp, + template_name, + args.network, + args.source_image, + args.machine_type, + startup_script, + ) + instance_group = add_instance_group( + gcp, args.zone, instance_group_name, _INSTANCE_GROUP_SIZE + ) patch_backend_service(gcp, backend_service, [instance_group]) same_zone_instance_group = add_instance_group( - gcp, args.zone, same_zone_instance_group_name, _INSTANCE_GROUP_SIZE) + gcp, args.zone, same_zone_instance_group_name, _INSTANCE_GROUP_SIZE + ) secondary_zone_instance_group = add_instance_group( - gcp, args.secondary_zone, secondary_zone_instance_group_name, - _INSTANCE_GROUP_SIZE) + gcp, + args.secondary_zone, + secondary_zone_instance_group_name, + _INSTANCE_GROUP_SIZE, + ) wait_for_healthy_backends(gcp, backend_service, instance_group) if args.test_case: client_env = dict(os.environ) if original_grpc_trace: - client_env['GRPC_TRACE'] = original_grpc_trace + client_env["GRPC_TRACE"] = original_grpc_trace if original_grpc_verbosity: - client_env['GRPC_VERBOSITY'] = original_grpc_verbosity + client_env["GRPC_VERBOSITY"] = original_grpc_verbosity bootstrap_server_features = [] if gcp.service_port == _DEFAULT_SERVICE_PORT: server_uri = service_host_name else: - server_uri = service_host_name + ':' + str(gcp.service_port) + server_uri = service_host_name + ":" + str(gcp.service_port) if args.xds_v3_support: - client_env['GRPC_XDS_EXPERIMENTAL_V3_SUPPORT'] = 'true' - bootstrap_server_features.append('xds_v3') + client_env["GRPC_XDS_EXPERIMENTAL_V3_SUPPORT"] = "true" + bootstrap_server_features.append("xds_v3") if args.bootstrap_file: bootstrap_path = os.path.abspath(args.bootstrap_file) else: with tempfile.NamedTemporaryFile(delete=False) as bootstrap_file: bootstrap_file.write( _BOOTSTRAP_TEMPLATE.format( - node_id='projects/%s/networks/%s/nodes/%s' % - (gcp.project_num, args.network.split('/')[-1], - uuid.uuid1()), - server_features=json.dumps( - bootstrap_server_features)).encode('utf-8')) + node_id="projects/%s/networks/%s/nodes/%s" + % ( + gcp.project_num, + args.network.split("/")[-1], + uuid.uuid1(), + ), + server_features=json.dumps(bootstrap_server_features), + ).encode("utf-8") + ) bootstrap_path = bootstrap_file.name - client_env['GRPC_XDS_BOOTSTRAP'] = bootstrap_path - client_env['GRPC_XDS_EXPERIMENTAL_CIRCUIT_BREAKING'] = 'true' - client_env['GRPC_XDS_EXPERIMENTAL_ENABLE_TIMEOUT'] = 'true' - 
client_env['GRPC_XDS_EXPERIMENTAL_FAULT_INJECTION'] = 'true' + client_env["GRPC_XDS_BOOTSTRAP"] = bootstrap_path + client_env["GRPC_XDS_EXPERIMENTAL_CIRCUIT_BREAKING"] = "true" + client_env["GRPC_XDS_EXPERIMENTAL_ENABLE_TIMEOUT"] = "true" + client_env["GRPC_XDS_EXPERIMENTAL_FAULT_INJECTION"] = "true" for test_case in args.test_case: if test_case in _V3_TEST_CASES and not args.xds_v3_support: - logger.info('skipping test %s due to missing v3 support', - test_case) + logger.info( + "skipping test %s due to missing v3 support", test_case + ) continue if test_case in _ALPHA_TEST_CASES and not gcp.alpha_compute: - logger.info('skipping test %s due to missing alpha support', - test_case) + logger.info( + "skipping test %s due to missing alpha support", test_case + ) continue - if test_case in [ - 'api_listener', 'forwarding_rule_port_match', - 'forwarding_rule_default_port' - ] and CLIENT_HOSTS: + if ( + test_case + in [ + "api_listener", + "forwarding_rule_port_match", + "forwarding_rule_default_port", + ] + and CLIENT_HOSTS + ): logger.info( - 'skipping test %s because test configuration is' - 'not compatible with client processes on existing' - 'client hosts', test_case) + ( + "skipping test %s because test configuration is" + "not compatible with client processes on existing" + "client hosts" + ), + test_case, + ) continue - if test_case == 'forwarding_rule_default_port': + if test_case == "forwarding_rule_default_port": server_uri = service_host_name result = jobset.JobResult() log_dir = os.path.join(_TEST_LOG_BASE_DIR, test_case) if not os.path.exists(log_dir): os.makedirs(log_dir) test_log_filename = os.path.join(log_dir, _SPONGE_LOG_NAME) - test_log_file = open(test_log_filename, 'w+') + test_log_file = open(test_log_filename, "w+") client_process = None if test_case in _TESTS_TO_RUN_MULTIPLE_RPCS: @@ -3379,14 +4108,15 @@ def __init__(self, compute, alpha_compute, project, project_num): keyU=_TEST_METADATA_KEY, valueU=_TEST_METADATA_VALUE_UNARY, keyNU=_TEST_METADATA_NUMERIC_KEY, - valueNU=_TEST_METADATA_NUMERIC_VALUE) + valueNU=_TEST_METADATA_NUMERIC_VALUE, + ) else: # Setting the arg explicitly to empty with '--metadata=""' # makes C# client fail # (see https://github.com/commandlineparser/commandline/issues/412), # so instead we just rely on clients using the default when # metadata arg is not specified. - metadata_to_send = '' + metadata_to_send = "" # TODO(ericgribkoff) Temporarily disable fail_on_failed_rpc checks # in the client. This means we will ignore intermittent RPC @@ -3401,7 +4131,7 @@ def __init__(self, compute, alpha_compute, project, project_num): # A fix is to not share resources between tests (though that does # mean the tests will be significantly slower due to creating new # resources). 
- fail_on_failed_rpc = '' + fail_on_failed_rpc = "" try: if not CLIENT_HOSTS: @@ -3411,91 +4141,145 @@ def __init__(self, compute, alpha_compute, project, project_num): qps=args.qps, fail_on_failed_rpc=fail_on_failed_rpc, rpcs_to_send=rpcs_to_send, - metadata_to_send=metadata_to_send) - logger.debug('running client: %s', client_cmd_formatted) + metadata_to_send=metadata_to_send, + ) + logger.debug("running client: %s", client_cmd_formatted) client_cmd = shlex.split(client_cmd_formatted) - client_process = subprocess.Popen(client_cmd, - env=client_env, - stderr=subprocess.STDOUT, - stdout=test_log_file) - if test_case == 'backends_restart': + client_process = subprocess.Popen( + client_cmd, + env=client_env, + stderr=subprocess.STDOUT, + stdout=test_log_file, + ) + if test_case == "backends_restart": test_backends_restart(gcp, backend_service, instance_group) - elif test_case == 'change_backend_service': - test_change_backend_service(gcp, backend_service, - instance_group, - alternate_backend_service, - same_zone_instance_group) - elif test_case == 'gentle_failover': - test_gentle_failover(gcp, backend_service, instance_group, - secondary_zone_instance_group) - elif test_case == 'load_report_based_failover': + elif test_case == "change_backend_service": + test_change_backend_service( + gcp, + backend_service, + instance_group, + alternate_backend_service, + same_zone_instance_group, + ) + elif test_case == "gentle_failover": + test_gentle_failover( + gcp, + backend_service, + instance_group, + secondary_zone_instance_group, + ) + elif test_case == "load_report_based_failover": test_load_report_based_failover( - gcp, backend_service, instance_group, - secondary_zone_instance_group) - elif test_case == 'ping_pong': + gcp, + backend_service, + instance_group, + secondary_zone_instance_group, + ) + elif test_case == "ping_pong": test_ping_pong(gcp, backend_service, instance_group) - elif test_case == 'remove_instance_group': - test_remove_instance_group(gcp, backend_service, - instance_group, - same_zone_instance_group) - elif test_case == 'round_robin': + elif test_case == "remove_instance_group": + test_remove_instance_group( + gcp, + backend_service, + instance_group, + same_zone_instance_group, + ) + elif test_case == "round_robin": test_round_robin(gcp, backend_service, instance_group) - elif test_case == 'secondary_locality_gets_no_requests_on_partial_primary_failure': + elif ( + test_case + == "secondary_locality_gets_no_requests_on_partial_primary_failure" + ): test_secondary_locality_gets_no_requests_on_partial_primary_failure( - gcp, backend_service, instance_group, - secondary_zone_instance_group) - elif test_case == 'secondary_locality_gets_requests_on_primary_failure': + gcp, + backend_service, + instance_group, + secondary_zone_instance_group, + ) + elif ( + test_case + == "secondary_locality_gets_requests_on_primary_failure" + ): test_secondary_locality_gets_requests_on_primary_failure( - gcp, backend_service, instance_group, - secondary_zone_instance_group) - elif test_case == 'traffic_splitting': - test_traffic_splitting(gcp, backend_service, instance_group, - alternate_backend_service, - same_zone_instance_group) - elif test_case == 'path_matching': - test_path_matching(gcp, backend_service, instance_group, - alternate_backend_service, - same_zone_instance_group) - elif test_case == 'header_matching': - test_header_matching(gcp, backend_service, instance_group, - alternate_backend_service, - same_zone_instance_group) - elif test_case == 'circuit_breaking': - 
test_circuit_breaking(gcp, backend_service, instance_group, - same_zone_instance_group) - elif test_case == 'timeout': + gcp, + backend_service, + instance_group, + secondary_zone_instance_group, + ) + elif test_case == "traffic_splitting": + test_traffic_splitting( + gcp, + backend_service, + instance_group, + alternate_backend_service, + same_zone_instance_group, + ) + elif test_case == "path_matching": + test_path_matching( + gcp, + backend_service, + instance_group, + alternate_backend_service, + same_zone_instance_group, + ) + elif test_case == "header_matching": + test_header_matching( + gcp, + backend_service, + instance_group, + alternate_backend_service, + same_zone_instance_group, + ) + elif test_case == "circuit_breaking": + test_circuit_breaking( + gcp, + backend_service, + instance_group, + same_zone_instance_group, + ) + elif test_case == "timeout": test_timeout(gcp, backend_service, instance_group) - elif test_case == 'fault_injection': + elif test_case == "fault_injection": test_fault_injection(gcp, backend_service, instance_group) - elif test_case == 'api_listener': - server_uri = test_api_listener(gcp, backend_service, - instance_group, - alternate_backend_service) - elif test_case == 'forwarding_rule_port_match': + elif test_case == "api_listener": + server_uri = test_api_listener( + gcp, + backend_service, + instance_group, + alternate_backend_service, + ) + elif test_case == "forwarding_rule_port_match": server_uri = test_forwarding_rule_port_match( - gcp, backend_service, instance_group) - elif test_case == 'forwarding_rule_default_port': + gcp, backend_service, instance_group + ) + elif test_case == "forwarding_rule_default_port": server_uri = test_forwarding_rule_default_port( - gcp, backend_service, instance_group) - elif test_case == 'metadata_filter': - test_metadata_filter(gcp, backend_service, instance_group, - alternate_backend_service, - same_zone_instance_group) - elif test_case == 'csds': + gcp, backend_service, instance_group + ) + elif test_case == "metadata_filter": + test_metadata_filter( + gcp, + backend_service, + instance_group, + alternate_backend_service, + same_zone_instance_group, + ) + elif test_case == "csds": test_csds(gcp, backend_service, instance_group, server_uri) else: - logger.error('Unknown test case: %s', test_case) + logger.error("Unknown test case: %s", test_case) sys.exit(1) if client_process and client_process.poll() is not None: raise Exception( - 'Client process exited prematurely with exit code %d' % - client_process.returncode) - result.state = 'PASSED' + "Client process exited prematurely with exit code %d" + % client_process.returncode + ) + result.state = "PASSED" result.returncode = 0 except Exception as e: - logger.exception('Test case %s failed', test_case) + logger.exception("Test case %s failed", test_case) failed_tests.append(test_case) - result.state = 'FAILED' + result.state = "FAILED" result.message = str(e) if args.halt_after_fail: # Stop the test suite if one case failed. @@ -3503,36 +4287,39 @@ def __init__(self, compute, alpha_compute, project, project_num): finally: if client_process: if client_process.returncode: - logger.info('Client exited with code %d' % - client_process.returncode) + logger.info( + "Client exited with code %d" + % client_process.returncode + ) else: client_process.terminate() test_log_file.close() # Workaround for Python 3, as report_utils will invoke decode() on # result.message, which has a default value of ''. 
- result.message = result.message.encode('UTF-8') + result.message = result.message.encode("UTF-8") test_results[test_case] = [result] if args.log_client_output: - logger.info('Client output:') - with open(test_log_filename, 'r') as client_output: + logger.info("Client output:") + with open(test_log_filename, "r") as client_output: logger.info(client_output.read()) if not os.path.exists(_TEST_LOG_BASE_DIR): os.makedirs(_TEST_LOG_BASE_DIR) - report_utils.render_junit_xml_report(test_results, - os.path.join( - _TEST_LOG_BASE_DIR, - _SPONGE_XML_NAME), - suite_name='xds_tests', - multi_target=True) + report_utils.render_junit_xml_report( + test_results, + os.path.join(_TEST_LOG_BASE_DIR, _SPONGE_XML_NAME), + suite_name="xds_tests", + multi_target=True, + ) if failed_tests: - logger.error('Test case(s) %s failed', failed_tests) + logger.error("Test case(s) %s failed", failed_tests) sys.exit(1) finally: keep_resources = args.keep_gcp_resources if args.halt_after_fail and failed_tests: logger.info( - 'Halt after fail triggered, exiting without cleaning up resources') + "Halt after fail triggered, exiting without cleaning up resources" + ) keep_resources = True if not keep_resources: - logger.info('Cleaning up GCP resources. This may take some time.') + logger.info("Cleaning up GCP resources. This may take some time.") clean_up(gcp) diff --git a/tools/run_tests/sanity/check_banned_filenames.py b/tools/run_tests/sanity/check_banned_filenames.py index f9a04c06755ce..0d3f39a62e92e 100755 --- a/tools/run_tests/sanity/check_banned_filenames.py +++ b/tools/run_tests/sanity/check_banned_filenames.py @@ -18,10 +18,10 @@ import sys BANNED_FILENAMES = [ - 'BUILD.gn', + "BUILD.gn", ] -os.chdir(os.path.join(os.path.dirname(sys.argv[0]), '../../..')) +os.chdir(os.path.join(os.path.dirname(sys.argv[0]), "../../..")) bad = [] for filename in BANNED_FILENAMES: diff --git a/tools/run_tests/sanity/check_bazel_workspace.py b/tools/run_tests/sanity/check_bazel_workspace.py index 67d8a86044a13..73b9315aa6f62 100755 --- a/tools/run_tests/sanity/check_bazel_workspace.py +++ b/tools/run_tests/sanity/check_bazel_workspace.py @@ -20,47 +20,72 @@ import subprocess import sys -os.chdir(os.path.join(os.path.dirname(sys.argv[0]), '../../..')) +os.chdir(os.path.join(os.path.dirname(sys.argv[0]), "../../..")) -git_hash_pattern = re.compile('[0-9a-f]{40}') +git_hash_pattern = re.compile("[0-9a-f]{40}") # Parse git hashes from submodules -git_submodules = subprocess.check_output( - 'git submodule', shell=True).decode().strip().split('\n') +git_submodules = ( + subprocess.check_output("git submodule", shell=True) + .decode() + .strip() + .split("\n") +) git_submodule_hashes = { re.search(git_hash_pattern, s).group() for s in git_submodules } -_BAZEL_SKYLIB_DEP_NAME = 'bazel_skylib' -_BAZEL_TOOLCHAINS_DEP_NAME = 'bazel_toolchains' -_BAZEL_COMPDB_DEP_NAME = 'bazel_compdb' -_TWISTED_TWISTED_DEP_NAME = 'com_github_twisted_twisted' -_YAML_PYYAML_DEP_NAME = 'com_github_yaml_pyyaml' -_TWISTED_INCREMENTAL_DEP_NAME = 'com_github_twisted_incremental' -_ZOPEFOUNDATION_ZOPE_INTERFACE_DEP_NAME = 'com_github_zopefoundation_zope_interface' -_TWISTED_CONSTANTLY_DEP_NAME = 'com_github_twisted_constantly' +_BAZEL_SKYLIB_DEP_NAME = "bazel_skylib" +_BAZEL_TOOLCHAINS_DEP_NAME = "bazel_toolchains" +_BAZEL_COMPDB_DEP_NAME = "bazel_compdb" +_TWISTED_TWISTED_DEP_NAME = "com_github_twisted_twisted" +_YAML_PYYAML_DEP_NAME = "com_github_yaml_pyyaml" +_TWISTED_INCREMENTAL_DEP_NAME = "com_github_twisted_incremental" +_ZOPEFOUNDATION_ZOPE_INTERFACE_DEP_NAME = ( 
+ "com_github_zopefoundation_zope_interface" +) +_TWISTED_CONSTANTLY_DEP_NAME = "com_github_twisted_constantly" _GRPC_DEP_NAMES = [ - 'upb', 'boringssl', 'zlib', 'com_google_protobuf', 'com_google_googletest', - 'rules_cc', 'com_github_google_benchmark', 'com_github_cares_cares', - 'com_google_absl', 'com_google_fuzztest', 'io_opencensus_cpp', 'envoy_api', - _BAZEL_SKYLIB_DEP_NAME, _BAZEL_TOOLCHAINS_DEP_NAME, _BAZEL_COMPDB_DEP_NAME, - _TWISTED_TWISTED_DEP_NAME, _YAML_PYYAML_DEP_NAME, - _TWISTED_INCREMENTAL_DEP_NAME, _ZOPEFOUNDATION_ZOPE_INTERFACE_DEP_NAME, - _TWISTED_CONSTANTLY_DEP_NAME, 'io_bazel_rules_go', - 'build_bazel_rules_apple', 'build_bazel_apple_support', - 'com_github_libuv_libuv', 'com_googlesource_code_re2', 'bazel_gazelle', - 'opencensus_proto', 'com_envoyproxy_protoc_gen_validate', - 'com_google_googleapis', 'com_google_libprotobuf_mutator', - 'com_github_cncf_udpa' + "upb", + "boringssl", + "zlib", + "com_google_protobuf", + "com_google_googletest", + "rules_cc", + "com_github_google_benchmark", + "com_github_cares_cares", + "com_google_absl", + "com_google_fuzztest", + "io_opencensus_cpp", + "envoy_api", + _BAZEL_SKYLIB_DEP_NAME, + _BAZEL_TOOLCHAINS_DEP_NAME, + _BAZEL_COMPDB_DEP_NAME, + _TWISTED_TWISTED_DEP_NAME, + _YAML_PYYAML_DEP_NAME, + _TWISTED_INCREMENTAL_DEP_NAME, + _ZOPEFOUNDATION_ZOPE_INTERFACE_DEP_NAME, + _TWISTED_CONSTANTLY_DEP_NAME, + "io_bazel_rules_go", + "build_bazel_rules_apple", + "build_bazel_apple_support", + "com_github_libuv_libuv", + "com_googlesource_code_re2", + "bazel_gazelle", + "opencensus_proto", + "com_envoyproxy_protoc_gen_validate", + "com_google_googleapis", + "com_google_libprotobuf_mutator", + "com_github_cncf_udpa", ] _GRPC_BAZEL_ONLY_DEPS = [ - 'upb', # third_party/upb is checked in locally - 'rules_cc', - 'com_google_absl', - 'com_google_fuzztest', - 'io_opencensus_cpp', + "upb", # third_party/upb is checked in locally + "rules_cc", + "com_google_absl", + "com_google_fuzztest", + "io_opencensus_cpp", _BAZEL_SKYLIB_DEP_NAME, _BAZEL_TOOLCHAINS_DEP_NAME, _BAZEL_COMPDB_DEP_NAME, @@ -69,20 +94,19 @@ _TWISTED_INCREMENTAL_DEP_NAME, _ZOPEFOUNDATION_ZOPE_INTERFACE_DEP_NAME, _TWISTED_CONSTANTLY_DEP_NAME, - 'io_bazel_rules_go', - 'build_bazel_rules_apple', - 'build_bazel_apple_support', - 'com_googlesource_code_re2', - 'bazel_gazelle', - 'opencensus_proto', - 'com_envoyproxy_protoc_gen_validate', - 'com_google_googleapis', - 'com_google_libprotobuf_mutator' + "io_bazel_rules_go", + "build_bazel_rules_apple", + "build_bazel_apple_support", + "com_googlesource_code_re2", + "bazel_gazelle", + "opencensus_proto", + "com_envoyproxy_protoc_gen_validate", + "com_google_googleapis", + "com_google_libprotobuf_mutator", ] class BazelEvalState(object): - def __init__(self, names_and_urls, overridden_name=None): self.names_and_urls = names_and_urls self.overridden_name = overridden_name @@ -102,44 +126,44 @@ def existing_rules(self): return [] def archive(self, **args): - assert self.names_and_urls.get(args['name']) is None - if args['name'] in _GRPC_BAZEL_ONLY_DEPS: - self.names_and_urls[args['name']] = 'dont care' + assert self.names_and_urls.get(args["name"]) is None + if args["name"] in _GRPC_BAZEL_ONLY_DEPS: + self.names_and_urls[args["name"]] = "dont care" return - url = args.get('url', None) + url = args.get("url", None) if not url: # we will only be looking for git commit hashes, so concatenating # the urls is fine. 
- url = ' '.join(args['urls']) - self.names_and_urls[args['name']] = url + url = " ".join(args["urls"]) + self.names_and_urls[args["name"]] = url def git_repository(self, **args): - assert self.names_and_urls.get(args['name']) is None - if args['name'] in _GRPC_BAZEL_ONLY_DEPS: - self.names_and_urls[args['name']] = 'dont care' + assert self.names_and_urls.get(args["name"]) is None + if args["name"] in _GRPC_BAZEL_ONLY_DEPS: + self.names_and_urls[args["name"]] = "dont care" return - self.names_and_urls[args['name']] = args['remote'] + self.names_and_urls[args["name"]] = args["remote"] def grpc_python_deps(self): pass # Parse git hashes from bazel/grpc_deps.bzl {new_}http_archive rules -with open(os.path.join('bazel', 'grpc_deps.bzl'), 'r') as f: +with open(os.path.join("bazel", "grpc_deps.bzl"), "r") as f: names_and_urls = {} eval_state = BazelEvalState(names_and_urls) bazel_file = f.read() # grpc_deps.bzl only defines 'grpc_deps' and 'grpc_test_only_deps', add these # lines to call them. -bazel_file += '\ngrpc_deps()\n' -bazel_file += '\ngrpc_test_only_deps()\n' +bazel_file += "\ngrpc_deps()\n" +bazel_file += "\ngrpc_test_only_deps()\n" build_rules = { - 'native': eval_state, - 'http_archive': lambda **args: eval_state.http_archive(**args), - 'load': lambda a, b: None, - 'git_repository': lambda **args: eval_state.git_repository(**args), - 'grpc_python_deps': lambda: None, + "native": eval_state, + "http_archive": lambda **args: eval_state.http_archive(**args), + "load": lambda a, b: None, + "git_repository": lambda **args: eval_state.git_repository(**args), + "grpc_python_deps": lambda: None, } exec((bazel_file), build_rules) for name in _GRPC_DEP_NAMES: @@ -165,25 +189,29 @@ def grpc_python_deps(self): # not used by any of the targets built by Bazel. 
if len(workspace_git_hashes - git_submodule_hashes) > 0: print( - "Found discrepancies between git submodules and Bazel WORKSPACE dependencies" + "Found discrepancies between git submodules and Bazel WORKSPACE" + " dependencies" ) print(("workspace_git_hashes: %s" % workspace_git_hashes)) print(("git_submodule_hashes: %s" % git_submodule_hashes)) - print(("workspace_git_hashes - git_submodule_hashes: %s" % - (workspace_git_hashes - git_submodule_hashes))) + print( + "workspace_git_hashes - git_submodule_hashes: %s" + % (workspace_git_hashes - git_submodule_hashes) + ) sys.exit(1) # Also check that we can override each dependency for name in _GRPC_DEP_NAMES: names_and_urls_with_overridden_name = {} - state = BazelEvalState(names_and_urls_with_overridden_name, - overridden_name=name) + state = BazelEvalState( + names_and_urls_with_overridden_name, overridden_name=name + ) rules = { - 'native': state, - 'http_archive': lambda **args: state.http_archive(**args), - 'load': lambda a, b: None, - 'git_repository': lambda **args: state.git_repository(**args), - 'grpc_python_deps': lambda *args, **kwargs: None, + "native": state, + "http_archive": lambda **args: state.http_archive(**args), + "load": lambda a, b: None, + "git_repository": lambda **args: state.git_repository(**args), + "grpc_python_deps": lambda *args, **kwargs: None, } exec((bazel_file), rules) assert name not in list(names_and_urls_with_overridden_name.keys()) diff --git a/tools/run_tests/sanity/check_deprecated_grpc++.py b/tools/run_tests/sanity/check_deprecated_grpc++.py index 28f20f404c187..8dac39bdc4f6a 100755 --- a/tools/run_tests/sanity/check_deprecated_grpc++.py +++ b/tools/run_tests/sanity/check_deprecated_grpc++.py @@ -17,33 +17,44 @@ import os import sys -os.chdir(os.path.join(os.path.dirname(sys.argv[0]), '../../..')) +os.chdir(os.path.join(os.path.dirname(sys.argv[0]), "../../..")) expected_files = [ - "include/grpc++/create_channel_posix.h", "include/grpc++/server_builder.h", - "include/grpc++/resource_quota.h", "include/grpc++/create_channel.h", - "include/grpc++/alarm.h", "include/grpc++/server.h", - "include/grpc++/server_context.h", "include/grpc++/client_context.h", - "include/grpc++/server_posix.h", "include/grpc++/grpc++.h", + "include/grpc++/create_channel_posix.h", + "include/grpc++/server_builder.h", + "include/grpc++/resource_quota.h", + "include/grpc++/create_channel.h", + "include/grpc++/alarm.h", + "include/grpc++/server.h", + "include/grpc++/server_context.h", + "include/grpc++/client_context.h", + "include/grpc++/server_posix.h", + "include/grpc++/grpc++.h", "include/grpc++/health_check_service_interface.h", - "include/grpc++/completion_queue.h", "include/grpc++/channel.h", - "include/grpc++/support/sync_stream.h", "include/grpc++/support/status.h", + "include/grpc++/completion_queue.h", + "include/grpc++/channel.h", + "include/grpc++/support/sync_stream.h", + "include/grpc++/support/status.h", "include/grpc++/support/config.h", "include/grpc++/support/status_code_enum.h", "include/grpc++/support/byte_buffer.h", "include/grpc++/support/error_details.h", "include/grpc++/support/async_unary_call.h", "include/grpc++/support/channel_arguments.h", - "include/grpc++/support/async_stream.h", "include/grpc++/support/slice.h", + "include/grpc++/support/async_stream.h", + "include/grpc++/support/slice.h", "include/grpc++/support/stub_options.h", - "include/grpc++/support/string_ref.h", "include/grpc++/support/time.h", + "include/grpc++/support/string_ref.h", + "include/grpc++/support/time.h", 
"include/grpc++/security/auth_metadata_processor.h", "include/grpc++/security/credentials.h", "include/grpc++/security/server_credentials.h", "include/grpc++/security/auth_context.h", "include/grpc++/impl/rpc_method.h", - "include/grpc++/impl/server_builder_option.h", "include/grpc++/impl/call.h", - "include/grpc++/impl/service_type.h", "include/grpc++/impl/grpc_library.h", + "include/grpc++/impl/server_builder_option.h", + "include/grpc++/impl/call.h", + "include/grpc++/impl/service_type.h", + "include/grpc++/impl/grpc_library.h", "include/grpc++/impl/client_unary_call.h", "include/grpc++/impl/channel_argument_option.h", "include/grpc++/impl/rpc_service_method.h", @@ -86,10 +97,10 @@ "include/grpc++/generic/async_generic_service.h", "include/grpc++/generic/generic_stub.h", "include/grpc++/test/mock_stream.h", - "include/grpc++/test/server_context_test_spouse.h" + "include/grpc++/test/server_context_test_spouse.h", ] -file_template = '''// +file_template = """// // // Copyright 2018 gRPC authors. // @@ -117,12 +128,12 @@ #include #endif // GRPCXX_FILE_PATH_NAME_UPPER -''' +""" errors = 0 path_files = [] -for root, dirs, files in os.walk('include/grpc++'): +for root, dirs, files in os.walk("include/grpc++"): for filename in files: path_file = os.path.join(root, filename) path_files.append(path_file) @@ -131,41 +142,43 @@ diff_plus = [file for file in path_files if file not in expected_files] diff_minus = [file for file in expected_files if file not in path_files] for file in diff_minus: - print(('- ', file)) + print(("- ", file)) for file in diff_plus: - print(('+ ', file)) + print(("+ ", file)) errors += 1 if errors > 0: sys.exit(errors) for path_file in expected_files: - relative_path_file = path_file.split('/', 2)[2] + relative_path_file = path_file.split("/", 2)[2] - replace_lower = relative_path_file.replace('+', 'p') + replace_lower = relative_path_file.replace("+", "p") - replace_upper = relative_path_file.replace('/', '_') - replace_upper = replace_upper.replace('.', '_') - replace_upper = replace_upper.upper().replace('+', 'X') + replace_upper = relative_path_file.replace("/", "_") + replace_upper = replace_upper.replace(".", "_") + replace_upper = replace_upper.upper().replace("+", "X") - expected_content = file_template.replace('FILE_PATH_NAME_LOWER', - replace_lower) - expected_content = expected_content.replace('FILE_PATH_NAME_UPPER', - replace_upper) + expected_content = file_template.replace( + "FILE_PATH_NAME_LOWER", replace_lower + ) + expected_content = expected_content.replace( + "FILE_PATH_NAME_UPPER", replace_upper + ) - path_file_expected = path_file + '.expected' + path_file_expected = path_file + ".expected" with open(path_file_expected, "w") as fo: fo.write(expected_content) - if 0 != os.system('diff %s %s' % (path_file_expected, path_file)): - print(('Difference found in file:', path_file)) + if 0 != os.system("diff %s %s" % (path_file_expected, path_file)): + print(("Difference found in file:", path_file)) errors += 1 os.remove(path_file_expected) check_extensions = [".h", ".cc", ".c", ".m"] -for root, dirs, files in os.walk('src'): +for root, dirs, files in os.walk("src"): for filename in files: path_file = os.path.join(root, filename) for ext in check_extensions: @@ -173,10 +186,11 @@ try: with open(path_file, "r") as fi: content = fi.read() - if '#include '), - (r'\n#include "grpc(.*)"', r'\n#include '), + (r'\n#include "include/(.*)"', r"\n#include <\1>"), + (r'\n#include "grpc(.*)"', r"\n#include "), ] -fix = sys.argv[1:] == ['--fix'] +fix = sys.argv[1:] == 
["--fix"] if fix: print("FIXING!") @@ -35,14 +35,17 @@ def check_include_style(directory_root): for root, dirs, files in os.walk(directory_root): for filename in files: path = os.path.join(root, filename) - if os.path.splitext(path)[1] not in ['.c', '.cc', '.h']: + if os.path.splitext(path)[1] not in [".c", ".cc", ".h"]: continue - if filename.endswith('.pb.h') or filename.endswith('.pb.c'): + if filename.endswith(".pb.h") or filename.endswith(".pb.c"): continue # Skip check for upb generated code. - if (filename.endswith('.upb.h') or filename.endswith('.upb.c') or - filename.endswith('.upbdefs.h') or - filename.endswith('.upbdefs.c')): + if ( + filename.endswith(".upb.h") + or filename.endswith(".upb.c") + or filename.endswith(".upbdefs.h") + or filename.endswith(".upbdefs.c") + ): continue with open(path) as f: text = f.read() @@ -52,18 +55,18 @@ def check_include_style(directory_root): if text != original: bad_files.append(path) if fix: - with open(path, 'w') as f: + with open(path, "w") as f: f.write(text) return bad_files all_bad_files = [] -all_bad_files += check_include_style(os.path.join('src', 'core')) -all_bad_files += check_include_style(os.path.join('src', 'cpp')) -all_bad_files += check_include_style(os.path.join('test', 'core')) -all_bad_files += check_include_style(os.path.join('test', 'cpp')) -all_bad_files += check_include_style(os.path.join('include', 'grpc')) -all_bad_files += check_include_style(os.path.join('include', 'grpcpp')) +all_bad_files += check_include_style(os.path.join("src", "core")) +all_bad_files += check_include_style(os.path.join("src", "cpp")) +all_bad_files += check_include_style(os.path.join("test", "core")) +all_bad_files += check_include_style(os.path.join("test", "cpp")) +all_bad_files += check_include_style(os.path.join("include", "grpc")) +all_bad_files += check_include_style(os.path.join("include", "grpcpp")) if all_bad_files: for f in all_bad_files: diff --git a/tools/run_tests/sanity/check_package_name.py b/tools/run_tests/sanity/check_package_name.py index addca8b88fd6f..9611144b83496 100755 --- a/tools/run_tests/sanity/check_package_name.py +++ b/tools/run_tests/sanity/check_package_name.py @@ -17,57 +17,60 @@ import os import sys -os.chdir(os.path.join(os.path.dirname(sys.argv[0]), '../../..')) +os.chdir(os.path.join(os.path.dirname(sys.argv[0]), "../../..")) # Allowance for overrides for specific files EXPECTED_NAMES = { - 'src/proto/grpc/channelz': 'channelz', - 'src/proto/grpc/status': 'status', - 'src/proto/grpc/testing': 'testing', - 'src/proto/grpc/testing/duplicate': 'duplicate', - 'src/proto/grpc/lb/v1': 'lb', - 'src/proto/grpc/testing/xds': 'xds', - 'src/proto/grpc/testing/xds/v3': 'xds_v3', - 'src/proto/grpc/core': 'core', - 'src/proto/grpc/health/v1': 'health', - 'src/proto/grpc/reflection/v1alpha': 'reflection', - 'src/proto/grpc/reflection/v1': 'reflection_v1', + "src/proto/grpc/channelz": "channelz", + "src/proto/grpc/status": "status", + "src/proto/grpc/testing": "testing", + "src/proto/grpc/testing/duplicate": "duplicate", + "src/proto/grpc/lb/v1": "lb", + "src/proto/grpc/testing/xds": "xds", + "src/proto/grpc/testing/xds/v3": "xds_v3", + "src/proto/grpc/core": "core", + "src/proto/grpc/health/v1": "health", + "src/proto/grpc/reflection/v1alpha": "reflection", + "src/proto/grpc/reflection/v1": "reflection_v1", } errors = 0 -for root, dirs, files in os.walk('.'): - if root.startswith('./'): - root = root[len('./'):] +for root, dirs, files in os.walk("."): + if root.startswith("./"): + root = root[len("./") :] # don't check 
third party - if root.startswith('third_party/'): + if root.startswith("third_party/"): continue # only check BUILD files - if 'BUILD' not in files: + if "BUILD" not in files: continue - text = open('%s/BUILD' % root).read() + text = open("%s/BUILD" % root).read() # find a grpc_package clause - pkg_start = text.find('grpc_package(') + pkg_start = text.find("grpc_package(") if pkg_start == -1: continue # parse it, taking into account nested parens - pkg_end = pkg_start + len('grpc_package(') + pkg_end = pkg_start + len("grpc_package(") level = 1 while level == 1: - if text[pkg_end] == ')': + if text[pkg_end] == ")": level -= 1 - elif text[pkg_end] == '(': + elif text[pkg_end] == "(": level += 1 pkg_end += 1 # it's a python statement, so evaluate it to pull out the name of the package - name = eval(text[pkg_start:pkg_end], - {'grpc_package': lambda name, **kwargs: name}) + name = eval( + text[pkg_start:pkg_end], {"grpc_package": lambda name, **kwargs: name} + ) # the name should be the path within the source tree, excepting some special # BUILD files (really we should normalize them too at some point) # TODO(ctiller): normalize all package names expected_name = EXPECTED_NAMES.get(root, root) if name != expected_name: - print("%s/BUILD should define a grpc_package with name=%r, not %r" % - (root, expected_name, name)) + print( + "%s/BUILD should define a grpc_package with name=%r, not %r" + % (root, expected_name, name) + ) errors += 1 if errors != 0: diff --git a/tools/run_tests/sanity/check_port_platform.py b/tools/run_tests/sanity/check_port_platform.py index fb6b68c53824a..9f5d0f55db251 100755 --- a/tools/run_tests/sanity/check_port_platform.py +++ b/tools/run_tests/sanity/check_port_platform.py @@ -17,7 +17,7 @@ import os import sys -os.chdir(os.path.join(os.path.dirname(sys.argv[0]), '../../..')) +os.chdir(os.path.join(os.path.dirname(sys.argv[0]), "../../..")) def check_port_platform_inclusion(directory_root, legal_list): @@ -25,29 +25,32 @@ def check_port_platform_inclusion(directory_root, legal_list): for root, dirs, files in os.walk(directory_root): for filename in files: path = os.path.join(root, filename) - if os.path.splitext(path)[1] not in ['.c', '.cc', '.h']: + if os.path.splitext(path)[1] not in [".c", ".cc", ".h"]: continue if path in [ - os.path.join('include', 'grpc', 'support', - 'port_platform.h'), - os.path.join('include', 'grpc', 'impl', 'codegen', - 'port_platform.h'), + os.path.join("include", "grpc", "support", "port_platform.h"), + os.path.join( + "include", "grpc", "impl", "codegen", "port_platform.h" + ), ]: continue - if filename.endswith('.pb.h') or filename.endswith('.pb.c'): + if filename.endswith(".pb.h") or filename.endswith(".pb.c"): continue # Skip check for upb generated code. - if (filename.endswith('.upb.h') or filename.endswith('.upb.c') or - filename.endswith('.upbdefs.h') or - filename.endswith('.upbdefs.c')): + if ( + filename.endswith(".upb.h") + or filename.endswith(".upb.c") + or filename.endswith(".upbdefs.h") + or filename.endswith(".upbdefs.c") + ): continue with open(path) as f: all_lines_in_file = f.readlines() for index, l in enumerate(all_lines_in_file): - if '#include' in l: + if "#include" in l: if l not in legal_list: bad_files.append(path) - elif all_lines_in_file[index + 1] != '\n': + elif all_lines_in_file[index + 1] != "\n": # Require a blank line after including port_platform.h in # order to prevent the formatter from reording it's # inclusion order upon future changes. 
@@ -57,30 +60,37 @@ def check_port_platform_inclusion(directory_root, legal_list):
 all_bad_files = []
-all_bad_files += check_port_platform_inclusion(os.path.join('src', 'core'), [
-    '#include <grpc/support/port_platform.h>\n',
-])
-all_bad_files += check_port_platform_inclusion(os.path.join(
-    'include', 'grpc'), [
-        '#include <grpc/support/port_platform.h>\n',
-        '#include <grpc/impl/codegen/port_platform.h>\n',
-    ])
+all_bad_files += check_port_platform_inclusion(
+    os.path.join("src", "core"),
+    [
+        "#include <grpc/support/port_platform.h>\n",
+    ],
+)
+all_bad_files += check_port_platform_inclusion(
+    os.path.join("include", "grpc"),
+    [
+        "#include <grpc/support/port_platform.h>\n",
+        "#include <grpc/impl/codegen/port_platform.h>\n",
+    ],
+)
-if sys.argv[1:] == ['--fix']:
+if sys.argv[1:] == ["--fix"]:
     for path in all_bad_files:
-        text = ''
+        text = ""
         found = False
         with open(path) as f:
             for l in f.readlines():
-                if not found and '#include' in l:
-                    text += '#include <grpc/support/port_platform.h>\n\n'
+                if not found and "#include" in l:
+                    text += "#include <grpc/support/port_platform.h>\n\n"
                     found = True
                 text += l
-        with open(path, 'w') as f:
+        with open(path, "w") as f:
             f.write(text)
 else:
     if len(all_bad_files) > 0:
         for f in all_bad_files:
-            print((('port_platform.h is not the first included header or there '
-                    'is not a blank line following its inclusion in %s') % f))
+            print(
+                "port_platform.h is not the first included header or there "
+                "is not a blank line following its inclusion in %s" % f
+            )
         sys.exit(1)
diff --git a/tools/run_tests/sanity/check_qps_scenario_changes.py b/tools/run_tests/sanity/check_qps_scenario_changes.py
index ce3de84e05d9a..a84b18053d8fc 100755
--- a/tools/run_tests/sanity/check_qps_scenario_changes.py
+++ b/tools/run_tests/sanity/check_qps_scenario_changes.py
@@ -18,17 +18,19 @@
 import subprocess
 import sys
-os.chdir(os.path.join(os.path.dirname(sys.argv[0]), '../../../test/cpp/qps'))
-subprocess.check_call(['./json_run_localhost_scenario_gen.py'])
-subprocess.check_call(['./qps_json_driver_scenario_gen.py'])
-subprocess.check_call(['buildifier', '-v', '-r', '.'])
+os.chdir(os.path.join(os.path.dirname(sys.argv[0]), "../../../test/cpp/qps"))
+subprocess.check_call(["./json_run_localhost_scenario_gen.py"])
+subprocess.check_call(["./qps_json_driver_scenario_gen.py"])
+subprocess.check_call(["buildifier", "-v", "-r", "."])
-output = subprocess.check_output(['git', 'status', '--porcelain']).decode()
-qps_json_driver_bzl = 'test/cpp/qps/qps_json_driver_scenarios.bzl'
-json_run_localhost_bzl = 'test/cpp/qps/json_run_localhost_scenarios.bzl'
+output = subprocess.check_output(["git", "status", "--porcelain"]).decode()
+qps_json_driver_bzl = "test/cpp/qps/qps_json_driver_scenarios.bzl"
+json_run_localhost_bzl = "test/cpp/qps/json_run_localhost_scenarios.bzl"
 if qps_json_driver_bzl in output or json_run_localhost_bzl in output:
-    print('qps benchmark scenarios have been updated, please commit '
-          'test/cpp/qps/qps_json_driver_scenarios.bzl and/or '
-          'test/cpp/qps/json_run_localhost_scenarios.bzl')
+    print(
+        "qps benchmark scenarios have been updated, please commit "
+        "test/cpp/qps/qps_json_driver_scenarios.bzl and/or "
+        "test/cpp/qps/json_run_localhost_scenarios.bzl"
+    )
     sys.exit(1)
diff --git a/tools/run_tests/sanity/check_test_filtering.py b/tools/run_tests/sanity/check_test_filtering.py
index d69d579bbf34f..3f2366f8c9db1 100755
--- a/tools/run_tests/sanity/check_test_filtering.py
+++ b/tools/run_tests/sanity/check_test_filtering.py
@@ -20,16 +20,24 @@
 import unittest
 # hack import paths to pick up extra code
-sys.path.insert(0, os.path.abspath('tools/run_tests/'))
+sys.path.insert(0, os.path.abspath("tools/run_tests/"))
 import python_utils.filter_pull_request_tests as filter_pull_request_tests
 from
run_tests_matrix import _create_portability_test_jobs from run_tests_matrix import _create_test_jobs _LIST_OF_LANGUAGE_LABELS = [ - 'c', 'c++', 'csharp', 'grpc-node', 'objc', 'php', 'php7', 'python', 'ruby' + "c", + "c++", + "csharp", + "grpc-node", + "objc", + "php", + "php7", + "python", + "ruby", ] -_LIST_OF_PLATFORM_LABELS = ['linux', 'macos', 'windows'] -_LIST_OF_SANITY_TESTS = ['sanity', 'clang-tidy', 'iwyu'] +_LIST_OF_PLATFORM_LABELS = ["linux", "macos", "windows"] +_LIST_OF_SANITY_TESTS = ["sanity", "clang-tidy", "iwyu"] def has_sanity_tests(job): @@ -40,7 +48,6 @@ def has_sanity_tests(job): class TestFilteringTest(unittest.TestCase): - def generate_all_tests(self): all_jobs = _create_test_jobs() + _create_portability_test_jobs() self.assertIsNotNone(all_jobs) @@ -48,11 +55,11 @@ def generate_all_tests(self): def test_filtering(self, changed_files=[], labels=_LIST_OF_LANGUAGE_LABELS): """ - Default args should filter no tests because changed_files is empty and - default labels should be able to match all jobs - :param changed_files: mock list of changed_files from pull request - :param labels: list of job labels that should be skipped - """ + Default args should filter no tests because changed_files is empty and + default labels should be able to match all jobs + :param changed_files: mock list of changed_files from pull request + :param labels: list of job labels that should be skipped + """ all_jobs = self.generate_all_tests() # Replacing _get_changed_files function to allow specifying changed files in filter_tests function @@ -74,8 +81,9 @@ def _get_changed_files(foo): if has_sanity_tests(job): sanity_tests_in_filtered_jobs += 1 filtered_jobs = [job for job in filtered_jobs if has_sanity_tests(job)] - self.assertEqual(sanity_tests_in_all_jobs, - sanity_tests_in_filtered_jobs) + self.assertEqual( + sanity_tests_in_all_jobs, sanity_tests_in_filtered_jobs + ) for label in labels: for job in filtered_jobs: @@ -88,85 +96,148 @@ def _get_changed_files(foo): for job in all_jobs: if has_sanity_tests(job): continue - if (label in job.labels): + if label in job.labels: jobs_matching_labels += 1 - self.assertEqual(len(filtered_jobs), - len(all_jobs) - jobs_matching_labels) + self.assertEqual( + len(filtered_jobs), len(all_jobs) - jobs_matching_labels + ) def test_individual_language_filters(self): # Changing unlisted file should trigger all languages - self.test_filtering(['ffffoo/bar.baz'], [_LIST_OF_LANGUAGE_LABELS]) + self.test_filtering(["ffffoo/bar.baz"], [_LIST_OF_LANGUAGE_LABELS]) # Changing core should trigger all tests - self.test_filtering(['src/core/foo.bar'], [_LIST_OF_LANGUAGE_LABELS]) + self.test_filtering(["src/core/foo.bar"], [_LIST_OF_LANGUAGE_LABELS]) # Testing individual languages - self.test_filtering(['test/core/foo.bar'], [ - label for label in _LIST_OF_LANGUAGE_LABELS - if label not in filter_pull_request_tests._CORE_TEST_SUITE.labels + - filter_pull_request_tests._CPP_TEST_SUITE.labels - ]) - self.test_filtering(['src/cpp/foo.bar'], [ - label for label in _LIST_OF_LANGUAGE_LABELS - if label not in filter_pull_request_tests._CPP_TEST_SUITE.labels - ]) - self.test_filtering(['src/csharp/foo.bar'], [ - label for label in _LIST_OF_LANGUAGE_LABELS - if label not in filter_pull_request_tests._CSHARP_TEST_SUITE.labels - ]) - self.test_filtering(['src/objective-c/foo.bar'], [ - label for label in _LIST_OF_LANGUAGE_LABELS - if label not in filter_pull_request_tests._OBJC_TEST_SUITE.labels - ]) - self.test_filtering(['src/php/foo.bar'], [ - label for label in 
_LIST_OF_LANGUAGE_LABELS - if label not in filter_pull_request_tests._PHP_TEST_SUITE.labels - ]) - self.test_filtering(['src/python/foo.bar'], [ - label for label in _LIST_OF_LANGUAGE_LABELS - if label not in filter_pull_request_tests._PYTHON_TEST_SUITE.labels - ]) - self.test_filtering(['src/ruby/foo.bar'], [ - label for label in _LIST_OF_LANGUAGE_LABELS - if label not in filter_pull_request_tests._RUBY_TEST_SUITE.labels - ]) + self.test_filtering( + ["test/core/foo.bar"], + [ + label + for label in _LIST_OF_LANGUAGE_LABELS + if label + not in filter_pull_request_tests._CORE_TEST_SUITE.labels + + filter_pull_request_tests._CPP_TEST_SUITE.labels + ], + ) + self.test_filtering( + ["src/cpp/foo.bar"], + [ + label + for label in _LIST_OF_LANGUAGE_LABELS + if label not in filter_pull_request_tests._CPP_TEST_SUITE.labels + ], + ) + self.test_filtering( + ["src/csharp/foo.bar"], + [ + label + for label in _LIST_OF_LANGUAGE_LABELS + if label + not in filter_pull_request_tests._CSHARP_TEST_SUITE.labels + ], + ) + self.test_filtering( + ["src/objective-c/foo.bar"], + [ + label + for label in _LIST_OF_LANGUAGE_LABELS + if label + not in filter_pull_request_tests._OBJC_TEST_SUITE.labels + ], + ) + self.test_filtering( + ["src/php/foo.bar"], + [ + label + for label in _LIST_OF_LANGUAGE_LABELS + if label not in filter_pull_request_tests._PHP_TEST_SUITE.labels + ], + ) + self.test_filtering( + ["src/python/foo.bar"], + [ + label + for label in _LIST_OF_LANGUAGE_LABELS + if label + not in filter_pull_request_tests._PYTHON_TEST_SUITE.labels + ], + ) + self.test_filtering( + ["src/ruby/foo.bar"], + [ + label + for label in _LIST_OF_LANGUAGE_LABELS + if label + not in filter_pull_request_tests._RUBY_TEST_SUITE.labels + ], + ) def test_combined_language_filters(self): - self.test_filtering(['src/cpp/foo.bar', 'test/core/foo.bar'], [ - label for label in _LIST_OF_LANGUAGE_LABELS - if label not in filter_pull_request_tests._CPP_TEST_SUITE.labels and - label not in filter_pull_request_tests._CORE_TEST_SUITE.labels - ]) - self.test_filtering(['src/cpp/foo.bar', "src/csharp/foo.bar"], [ - label for label in _LIST_OF_LANGUAGE_LABELS - if label not in filter_pull_request_tests._CPP_TEST_SUITE.labels and - label not in filter_pull_request_tests._CSHARP_TEST_SUITE.labels - ]) - self.test_filtering([ - 'src/objective-c/foo.bar', 'src/php/foo.bar', "src/python/foo.bar", - "src/ruby/foo.bar" - ], [ - label for label in _LIST_OF_LANGUAGE_LABELS if - label not in filter_pull_request_tests._OBJC_TEST_SUITE.labels and - label not in filter_pull_request_tests._PHP_TEST_SUITE.labels and - label not in filter_pull_request_tests._PYTHON_TEST_SUITE.labels and - label not in filter_pull_request_tests._RUBY_TEST_SUITE.labels - ]) + self.test_filtering( + ["src/cpp/foo.bar", "test/core/foo.bar"], + [ + label + for label in _LIST_OF_LANGUAGE_LABELS + if label not in filter_pull_request_tests._CPP_TEST_SUITE.labels + and label + not in filter_pull_request_tests._CORE_TEST_SUITE.labels + ], + ) + self.test_filtering( + ["src/cpp/foo.bar", "src/csharp/foo.bar"], + [ + label + for label in _LIST_OF_LANGUAGE_LABELS + if label not in filter_pull_request_tests._CPP_TEST_SUITE.labels + and label + not in filter_pull_request_tests._CSHARP_TEST_SUITE.labels + ], + ) + self.test_filtering( + [ + "src/objective-c/foo.bar", + "src/php/foo.bar", + "src/python/foo.bar", + "src/ruby/foo.bar", + ], + [ + label + for label in _LIST_OF_LANGUAGE_LABELS + if label + not in filter_pull_request_tests._OBJC_TEST_SUITE.labels + and label + not in 
filter_pull_request_tests._PHP_TEST_SUITE.labels + and label + not in filter_pull_request_tests._PYTHON_TEST_SUITE.labels + and label + not in filter_pull_request_tests._RUBY_TEST_SUITE.labels + ], + ) def test_platform_filter(self): - self.test_filtering(['vsprojects/foo.bar'], [ - label for label in _LIST_OF_PLATFORM_LABELS - if label not in filter_pull_request_tests._WINDOWS_TEST_SUITE.labels - ]) + self.test_filtering( + ["vsprojects/foo.bar"], + [ + label + for label in _LIST_OF_PLATFORM_LABELS + if label + not in filter_pull_request_tests._WINDOWS_TEST_SUITE.labels + ], + ) def test_allowlist(self): allowlist = filter_pull_request_tests._ALLOWLIST_DICT files_that_should_trigger_all_tests = [ - 'src/core/foo.bar', 'some_file_not_on_the_white_list', 'BUILD', - 'etc/roots.pem', 'Makefile', 'tools/foo' + "src/core/foo.bar", + "some_file_not_on_the_white_list", + "BUILD", + "etc/roots.pem", + "Makefile", + "tools/foo", ] for key in list(allowlist.keys()): for file_name in files_that_should_trigger_all_tests: self.assertFalse(re.match(key, file_name)) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/tools/run_tests/sanity/check_tracer_sanity.py b/tools/run_tests/sanity/check_tracer_sanity.py index 040e0852e23bd..fce69b2c270fb 100755 --- a/tools/run_tests/sanity/check_tracer_sanity.py +++ b/tools/run_tests/sanity/check_tracer_sanity.py @@ -18,29 +18,30 @@ import re import sys -os.chdir(os.path.join(os.path.dirname(sys.argv[0]), '../../..')) +os.chdir(os.path.join(os.path.dirname(sys.argv[0]), "../../..")) errors = 0 tracers = [] -pattern = re.compile("GRPC_TRACER_INITIALIZER\((true|false), \"(.*)\"\)") -for root, dirs, files in os.walk('src/core'): +pattern = re.compile('GRPC_TRACER_INITIALIZER\((true|false), "(.*)"\)') +for root, dirs, files in os.walk("src/core"): for filename in files: path = os.path.join(root, filename) - if os.path.splitext(path)[1] != '.c': + if os.path.splitext(path)[1] != ".c": continue with open(path) as f: text = f.read() for o in pattern.findall(text): tracers.append(o[1]) -with open('doc/environment_variables.md') as f: +with open("doc/environment_variables.md") as f: text = f.read() for t in tracers: if t not in text: - print(( - "ERROR: tracer \"%s\" is not mentioned in doc/environment_variables.md" - % t)) + print( + 'ERROR: tracer "%s" is not mentioned in' + " doc/environment_variables.md" % t + ) errors += 1 assert errors == 0 diff --git a/tools/run_tests/sanity/check_version.py b/tools/run_tests/sanity/check_version.py index d84a172fd4d53..3e483cce06989 100755 --- a/tools/run_tests/sanity/check_version.py +++ b/tools/run_tests/sanity/check_version.py @@ -23,60 +23,69 @@ errors = 0 -os.chdir(os.path.join(os.path.dirname(sys.argv[0]), '../../..')) +os.chdir(os.path.join(os.path.dirname(sys.argv[0]), "../../..")) # hack import paths to pick up extra code -sys.path.insert(0, os.path.abspath('tools/buildgen/plugins')) +sys.path.insert(0, os.path.abspath("tools/buildgen/plugins")) from expand_version import Version try: - branch_name = subprocess.check_output('git rev-parse --abbrev-ref HEAD', - shell=True).decode() + branch_name = subprocess.check_output( + "git rev-parse --abbrev-ref HEAD", shell=True + ).decode() except: - print('WARNING: not a git repository') + print("WARNING: not a git repository") branch_name = None if branch_name is not None: - m = re.match(r'^release-([0-9]+)_([0-9]+)$', branch_name) + m = re.match(r"^release-([0-9]+)_([0-9]+)$", branch_name) if m: - print('RELEASE branch') + print("RELEASE 
branch") # version number should align with the branched version - check_version = lambda version: (version.major == int(m.group(1)) and - version.minor == int(m.group(2))) - warning = 'Version key "%%s" value "%%s" should have a major version %s and minor version %s' % ( - m.group(1), m.group(2)) - elif re.match(r'^debian/.*$', branch_name): + check_version = lambda version: ( + version.major == int(m.group(1)) + and version.minor == int(m.group(2)) + ) + warning = ( + 'Version key "%%s" value "%%s" should have a major version %s and' + " minor version %s" % (m.group(1), m.group(2)) + ) + elif re.match(r"^debian/.*$", branch_name): # no additional version checks for debian branches check_version = lambda version: True else: # all other branches should have a -dev tag - check_version = lambda version: version.tag == 'dev' + check_version = lambda version: version.tag == "dev" warning = 'Version key "%s" value "%s" should have a -dev tag' else: check_version = lambda version: True -with open('build_handwritten.yaml', 'r') as f: +with open("build_handwritten.yaml", "r") as f: build_yaml = yaml.safe_load(f.read()) -settings = build_yaml['settings'] +settings = build_yaml["settings"] -top_version = Version(settings['version']) +top_version = Version(settings["version"]) if not check_version(top_version): errors += 1 - print((warning % ('version', top_version))) + print((warning % ("version", top_version))) for tag, value in list(settings.items()): - if re.match(r'^[a-z]+_version$', tag): + if re.match(r"^[a-z]+_version$", tag): value = Version(value) - if tag != 'core_version': + if tag != "core_version": if value.major != top_version.major: errors += 1 - print(('major version mismatch on %s: %d vs %d' % - (tag, value.major, top_version.major))) + print( + "major version mismatch on %s: %d vs %d" + % (tag, value.major, top_version.major) + ) if value.minor != top_version.minor: errors += 1 - print(('minor version mismatch on %s: %d vs %d' % - (tag, value.minor, top_version.minor))) + print( + "minor version mismatch on %s: %d vs %d" + % (tag, value.minor, top_version.minor) + ) if not check_version(value): errors += 1 print((warning % (tag, value))) diff --git a/tools/run_tests/sanity/core_banned_functions.py b/tools/run_tests/sanity/core_banned_functions.py index 9d2bf283927f6..fbef656345096 100755 --- a/tools/run_tests/sanity/core_banned_functions.py +++ b/tools/run_tests/sanity/core_banned_functions.py @@ -20,76 +20,83 @@ import os import sys -os.chdir(os.path.join(os.path.dirname(sys.argv[0]), '../../..')) +os.chdir(os.path.join(os.path.dirname(sys.argv[0]), "../../..")) # map of banned function signature to allowlist BANNED_EXCEPT = { - 'grpc_slice_from_static_buffer(': ['src/core/lib/slice/slice.cc'], - 'grpc_resource_quota_ref(': ['src/core/lib/resource_quota/api.cc'], - 'grpc_resource_quota_unref(': [ - 'src/core/lib/resource_quota/api.cc', 'src/core/lib/surface/server.cc' + "grpc_slice_from_static_buffer(": ["src/core/lib/slice/slice.cc"], + "grpc_resource_quota_ref(": ["src/core/lib/resource_quota/api.cc"], + "grpc_resource_quota_unref(": [ + "src/core/lib/resource_quota/api.cc", + "src/core/lib/surface/server.cc", ], - 'grpc_error_create(': [ - 'src/core/lib/iomgr/error.cc', 'src/core/lib/iomgr/error_cfstream.cc' + "grpc_error_create(": [ + "src/core/lib/iomgr/error.cc", + "src/core/lib/iomgr/error_cfstream.cc", ], - 'grpc_error_ref(': ['src/core/lib/iomgr/error.cc'], - 'grpc_error_unref(': ['src/core/lib/iomgr/error.cc'], - 'grpc_os_error(': [ - 'src/core/lib/iomgr/error.cc', 
'src/core/lib/iomgr/error.h' + "grpc_error_ref(": ["src/core/lib/iomgr/error.cc"], + "grpc_error_unref(": ["src/core/lib/iomgr/error.cc"], + "grpc_os_error(": [ + "src/core/lib/iomgr/error.cc", + "src/core/lib/iomgr/error.h", ], - 'grpc_wsa_error(': [ - 'src/core/lib/iomgr/error.cc', 'src/core/lib/iomgr/error.h' + "grpc_wsa_error(": [ + "src/core/lib/iomgr/error.cc", + "src/core/lib/iomgr/error.h", ], - 'grpc_log_if_error(': [ - 'src/core/lib/iomgr/error.cc', 'src/core/lib/iomgr/error.h' + "grpc_log_if_error(": [ + "src/core/lib/iomgr/error.cc", + "src/core/lib/iomgr/error.h", ], - 'grpc_slice_malloc(': [ - 'src/core/lib/slice/slice.cc', 'src/core/lib/slice/slice.h' + "grpc_slice_malloc(": [ + "src/core/lib/slice/slice.cc", + "src/core/lib/slice/slice.h", ], - 'grpc_call_cancel(': ['src/core/lib/surface/call.cc'], - 'grpc_channel_destroy(': [ - 'src/core/lib/surface/channel.cc', - 'src/core/tsi/alts/handshaker/alts_shared_resource.cc', + "grpc_call_cancel(": ["src/core/lib/surface/call.cc"], + "grpc_channel_destroy(": [ + "src/core/lib/surface/channel.cc", + "src/core/tsi/alts/handshaker/alts_shared_resource.cc", ], - 'grpc_closure_create(': [ - 'src/core/lib/iomgr/closure.cc', 'src/core/lib/iomgr/closure.h' + "grpc_closure_create(": [ + "src/core/lib/iomgr/closure.cc", + "src/core/lib/iomgr/closure.h", ], - 'grpc_closure_init(': [ - 'src/core/lib/iomgr/closure.cc', 'src/core/lib/iomgr/closure.h' + "grpc_closure_init(": [ + "src/core/lib/iomgr/closure.cc", + "src/core/lib/iomgr/closure.h", ], - 'grpc_closure_sched(': ['src/core/lib/iomgr/closure.cc'], - 'grpc_closure_run(': ['src/core/lib/iomgr/closure.cc'], - 'grpc_closure_list_sched(': ['src/core/lib/iomgr/closure.cc'], - 'grpc_error*': ['src/core/lib/iomgr/error.cc'], - 'grpc_error_string': ['src/core/lib/iomgr/error.cc'], + "grpc_closure_sched(": ["src/core/lib/iomgr/closure.cc"], + "grpc_closure_run(": ["src/core/lib/iomgr/closure.cc"], + "grpc_closure_list_sched(": ["src/core/lib/iomgr/closure.cc"], + "grpc_error*": ["src/core/lib/iomgr/error.cc"], + "grpc_error_string": ["src/core/lib/iomgr/error.cc"], # use grpc_core::CSlice{Ref,Unref} instead inside core # (or prefer grpc_core::Slice!) - 'grpc_slice_ref(': ['src/core/lib/slice/slice.cc'], - 'grpc_slice_unref(': ['src/core/lib/slice/slice.cc'], + "grpc_slice_ref(": ["src/core/lib/slice/slice.cc"], + "grpc_slice_unref(": ["src/core/lib/slice/slice.cc"], # std::random_device needs /dev/random which is not available on all linuxes that we support. # Any usage must be optional and opt-in, so that those platforms can use gRPC without problem. - 'std::random_device': [ - 'src/core/ext/filters/client_channel/lb_policy/rls/rls.cc', - 'src/core/ext/filters/client_channel/resolver/google_c2p/google_c2p_resolver.cc', + "std::random_device": [ + "src/core/ext/filters/client_channel/lb_policy/rls/rls.cc", + "src/core/ext/filters/client_channel/resolver/google_c2p/google_c2p_resolver.cc", ], # use 'grpc_core::Crash' instead - 'GPR_ASSERT(false': [], - + "GPR_ASSERT(false": [], # Use `std::exchange()` instead. - 'absl::exchange': [], + "absl::exchange": [], # Use `std::make_unique()` instead. 
- 'absl::make_unique': [], + "absl::make_unique": [], } errors = 0 num_files = 0 -for root, dirs, files in os.walk('src/core'): - if root.startswith('src/core/tsi'): +for root, dirs, files in os.walk("src/core"): + if root.startswith("src/core/tsi"): continue for filename in files: num_files += 1 path = os.path.join(root, filename) - if os.path.splitext(path)[1] not in ('.h', '.cc'): + if os.path.splitext(path)[1] not in (".h", ".cc"): continue with open(path) as f: text = f.read() diff --git a/tools/run_tests/sanity/sanity_tests.yaml b/tools/run_tests/sanity/sanity_tests.yaml index 3a871d38d4cf8..66fa532c7157c 100644 --- a/tools/run_tests/sanity/sanity_tests.yaml +++ b/tools/run_tests/sanity/sanity_tests.yaml @@ -37,7 +37,7 @@ cpu_cost: 1000 - script: tools/distrib/pylint_code.sh - script: tools/distrib/python/check_grpcio_tools.py -- script: tools/distrib/yapf_code.sh --diff +- script: tools/distrib/black_code.sh --diff cpu_cost: 1000 - script: tools/distrib/isort_code.sh --diff cpu_cost: 1000 diff --git a/tools/run_tests/task_runner.py b/tools/run_tests/task_runner.py index 6b227e21ff67c..fb2502d2a282e 100755 --- a/tools/run_tests/task_runner.py +++ b/tools/run_tests/task_runner.py @@ -36,10 +36,10 @@ def _create_build_map(): """Maps task names and labels to list of tasks to be built.""" target_build_map = dict([(target.name, [target]) for target in _TARGETS]) if len(_TARGETS) > len(list(target_build_map.keys())): - raise Exception('Target names need to be unique') + raise Exception("Target names need to be unique") label_build_map = {} - label_build_map['all'] = [t for t in _TARGETS] # to build all targets + label_build_map["all"] = [t for t in _TARGETS] # to build all targets for target in _TARGETS: for label in target.labels: if label in label_build_map: @@ -48,42 +48,52 @@ def _create_build_map(): label_build_map[label] = [target] if set(target_build_map.keys()).intersection(list(label_build_map.keys())): - raise Exception('Target names need to be distinct from label names') + raise Exception("Target names need to be distinct from label names") return dict(list(target_build_map.items()) + list(label_build_map.items())) _BUILD_MAP = _create_build_map() -argp = argparse.ArgumentParser(description='Runs build/test targets.') -argp.add_argument('-b', - '--build', - choices=sorted(_BUILD_MAP.keys()), - nargs='+', - default=['all'], - help='Target name or target label to build.') -argp.add_argument('-f', - '--filter', - choices=sorted(_BUILD_MAP.keys()), - nargs='+', - default=[], - help='Filter targets to build with AND semantics.') -argp.add_argument('-j', '--jobs', default=multiprocessing.cpu_count(), type=int) -argp.add_argument('-x', - '--xml_report', - default='report_taskrunner_sponge_log.xml', - type=str, - help='Filename for the JUnit-compatible XML report') -argp.add_argument('--dry_run', - default=False, - action='store_const', - const=True, - help='Only print what would be run.') +argp = argparse.ArgumentParser(description="Runs build/test targets.") argp.add_argument( - '--inner_jobs', + "-b", + "--build", + choices=sorted(_BUILD_MAP.keys()), + nargs="+", + default=["all"], + help="Target name or target label to build.", +) +argp.add_argument( + "-f", + "--filter", + choices=sorted(_BUILD_MAP.keys()), + nargs="+", + default=[], + help="Filter targets to build with AND semantics.", +) +argp.add_argument("-j", "--jobs", default=multiprocessing.cpu_count(), type=int) +argp.add_argument( + "-x", + "--xml_report", + default="report_taskrunner_sponge_log.xml", + type=str, + 
help="Filename for the JUnit-compatible XML report", +) +argp.add_argument( + "--dry_run", + default=False, + action="store_const", + const=True, + help="Only print what would be run.", +) +argp.add_argument( + "--inner_jobs", default=None, type=int, - help= - 'Number of parallel jobs to use by each target. Passed as build_jobspec(inner_jobs=N) to each target.' + help=( + "Number of parallel jobs to use by each target. Passed as" + " build_jobspec(inner_jobs=N) to each target." + ), ) args = argp.parse_args() @@ -96,13 +106,13 @@ def _create_build_map(): # Among targets selected by -b, filter out those that don't match the filter targets = [t for t in targets if all(f in t.labels for f in args.filter)] -print('Will build %d targets:' % len(targets)) +print("Will build %d targets:" % len(targets)) for target in targets: - print(' %s, labels %s' % (target.name, target.labels)) + print(" %s, labels %s" % (target.name, target.labels)) print() if args.dry_run: - print('--dry_run was used, exiting') + print("--dry_run was used, exiting") sys.exit(1) # Execute pre-build phase @@ -110,31 +120,31 @@ def _create_build_map(): for target in targets: prebuild_jobs += target.pre_build_jobspecs() if prebuild_jobs: - num_failures, _ = jobset.run(prebuild_jobs, - newline_on_success=True, - maxjobs=args.jobs) + num_failures, _ = jobset.run( + prebuild_jobs, newline_on_success=True, maxjobs=args.jobs + ) if num_failures != 0: - jobset.message('FAILED', 'Pre-build phase failed.', do_newline=True) + jobset.message("FAILED", "Pre-build phase failed.", do_newline=True) sys.exit(1) build_jobs = [] for target in targets: build_jobs.append(target.build_jobspec(inner_jobs=args.inner_jobs)) if not build_jobs: - print('Nothing to build.') + print("Nothing to build.") sys.exit(1) -jobset.message('START', 'Building targets.', do_newline=True) -num_failures, resultset = jobset.run(build_jobs, - newline_on_success=True, - maxjobs=args.jobs) -report_utils.render_junit_xml_report(resultset, - args.xml_report, - suite_name='tasks') +jobset.message("START", "Building targets.", do_newline=True) +num_failures, resultset = jobset.run( + build_jobs, newline_on_success=True, maxjobs=args.jobs +) +report_utils.render_junit_xml_report( + resultset, args.xml_report, suite_name="tasks" +) if num_failures == 0: - jobset.message('SUCCESS', - 'All targets built successfully.', - do_newline=True) + jobset.message( + "SUCCESS", "All targets built successfully.", do_newline=True + ) else: - jobset.message('FAILED', 'Failed to build targets.', do_newline=True) + jobset.message("FAILED", "Failed to build targets.", do_newline=True) sys.exit(1) diff --git a/tools/run_tests/xds_k8s_test_driver/README.md b/tools/run_tests/xds_k8s_test_driver/README.md index 5fbe80f9d927a..29a29c372c67f 100644 --- a/tools/run_tests/xds_k8s_test_driver/README.md +++ b/tools/run_tests/xds_k8s_test_driver/README.md @@ -216,7 +216,7 @@ from your dev environment. You need: ### Making changes to the driver 1. Install additional dev packages: `pip install -r requirements-dev.txt` -2. Use `./bin/yapf.sh` and `./bin/isort.sh` helpers to auto-format code. +2. Use `./bin/black.sh` and `./bin/isort.sh` helpers to auto-format code. 
### Updating Python Dependencies diff --git a/tools/run_tests/xds_k8s_test_driver/bin/yapf.sh b/tools/run_tests/xds_k8s_test_driver/bin/black.sh similarity index 82% rename from tools/run_tests/xds_k8s_test_driver/bin/yapf.sh rename to tools/run_tests/xds_k8s_test_driver/bin/black.sh index 0b6d2ce1d67bb..60c077d375a07 100755 --- a/tools/run_tests/xds_k8s_test_driver/bin/yapf.sh +++ b/tools/run_tests/xds_k8s_test_driver/bin/black.sh @@ -17,10 +17,11 @@ set -eo pipefail display_usage() { cat </dev/stderr -A helper to run yapf formatter. +A helper to run black formatter. USAGE: $0 [--diff] --diff: Do not apply changes, only show the diff + --check: Do not apply changes, only print what files will be changed ENVIRONMENT: XDS_K8S_DRIVER_VENV_DIR: the path to python virtual environment directory @@ -28,6 +29,7 @@ ENVIRONMENT: EXAMPLES: $0 $0 --diff +$0 --check EOF exit 1 } @@ -48,11 +50,11 @@ source "${XDS_K8S_DRIVER_DIR}/bin/ensure_venv.sh" if [[ "$1" == "--diff" ]]; then readonly MODE="--diff" +elif [[ "$1" == "--check" ]]; then + readonly MODE="--check" else - readonly MODE="--in-place" - readonly VERBOSE="--verbose" # print out file names while processing + readonly MODE="" fi -exec python -m yapf "${MODE}" ${VERBOSE:-} \ - --parallel --recursive --style=../../../setup.cfg \ - framework bin tests +# shellcheck disable=SC2086 +exec python -m black --config=../../../black.toml ${MODE} . diff --git a/tools/run_tests/xds_k8s_test_driver/bin/cleanup/cleanup.py b/tools/run_tests/xds_k8s_test_driver/bin/cleanup/cleanup.py index b1cfde5d55aac..7a2581a4cea4a 100755 --- a/tools/run_tests/xds_k8s_test_driver/bin/cleanup/cleanup.py +++ b/tools/run_tests/xds_k8s_test_driver/bin/cleanup/cleanup.py @@ -50,53 +50,68 @@ _KubernetesClientRunner = k8s_xds_client_runner.KubernetesClientRunner _KubernetesServerRunner = k8s_xds_server_runner.KubernetesServerRunner -GCLOUD = os.environ.get('GCLOUD', 'gcloud') +GCLOUD = os.environ.get("GCLOUD", "gcloud") GCLOUD_CMD_TIMEOUT_S = datetime.timedelta(seconds=5).total_seconds() -ZONE = 'us-central1-a' -SECONDARY_ZONE = 'us-west1-b' +ZONE = "us-central1-a" +SECONDARY_ZONE = "us-west1-b" -PSM_SECURITY_PREFIX = 'psm-interop' # Prefix for gke resources to delete. -URL_MAP_TEST_PREFIX = 'interop-psm-url-map' # Prefix for url-map test resources to delete. +PSM_SECURITY_PREFIX = "psm-interop" # Prefix for gke resources to delete. +URL_MAP_TEST_PREFIX = ( # Prefix for url-map test resources to delete. + "interop-psm-url-map" +) KEEP_PERIOD_HOURS = flags.DEFINE_integer( "keep_hours", default=168, - help= - "number of hours for a resource to keep. Resources older than this will be deleted. Default is 168 (7 days)" + help=( + "number of hours for a resource to keep. Resources older than this will" + " be deleted. 
Default is 168 (7 days)" + ), ) DRY_RUN = flags.DEFINE_bool( "dry_run", default=False, - help="dry run, print resources but do not perform deletion") + help="dry run, print resources but do not perform deletion", +) TD_RESOURCE_PREFIXES = flags.DEFINE_list( "td_resource_prefixes", default=[PSM_SECURITY_PREFIX], - help= - "a comma-separated list of prefixes for which the leaked TD resources will be deleted", + help=( + "a comma-separated list of prefixes for which the leaked TD resources" + " will be deleted" + ), ) SERVER_PREFIXES = flags.DEFINE_list( "server_prefixes", default=[PSM_SECURITY_PREFIX], - help= - "a comma-separated list of prefixes for which the leaked servers will be deleted", + help=( + "a comma-separated list of prefixes for which the leaked servers will" + " be deleted" + ), ) CLIENT_PREFIXES = flags.DEFINE_list( "client_prefixes", default=[PSM_SECURITY_PREFIX, URL_MAP_TEST_PREFIX], - help= - "a comma-separated list of prefixes for which the leaked clients will be deleted", + help=( + "a comma-separated list of prefixes for which the leaked clients will" + " be deleted" + ), ) def load_keep_config() -> None: global KEEP_CONFIG json_path = os.path.realpath( - os.path.join(os.path.dirname(os.path.abspath(__file__)), - 'keep_xds_interop_resources.json')) - with open(json_path, 'r') as f: + os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "keep_xds_interop_resources.json", + ) + ) + with open(json_path, "r") as f: KEEP_CONFIG = json.load(f) - logging.debug('Resource keep config loaded: %s', - json.dumps(KEEP_CONFIG, indent=2)) + logging.debug( + "Resource keep config loaded: %s", json.dumps(KEEP_CONFIG, indent=2) + ) def is_marked_as_keep_gce(suffix: str) -> bool: @@ -110,71 +125,154 @@ def is_marked_as_keep_gke(suffix: str) -> bool: @functools.lru_cache() def get_expire_timestamp() -> datetime.datetime: return datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta( - hours=KEEP_PERIOD_HOURS.value) + hours=KEEP_PERIOD_HOURS.value + ) def exec_gcloud(project: str, *cmds: List[str]) -> Json: - cmds = [GCLOUD, '--project', project, '--quiet'] + list(cmds) - if 'list' in cmds: + cmds = [GCLOUD, "--project", project, "--quiet"] + list(cmds) + if "list" in cmds: # Add arguments to shape the list output - cmds.extend([ - '--format', 'json', '--filter', - f'creationTimestamp <= {get_expire_timestamp().isoformat()}' - ]) + cmds.extend( + [ + "--format", + "json", + "--filter", + f"creationTimestamp <= {get_expire_timestamp().isoformat()}", + ] + ) # Executing the gcloud command - logging.debug('Executing: %s', " ".join(cmds)) - proc = subprocess.Popen(cmds, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) + logging.debug("Executing: %s", " ".join(cmds)) + proc = subprocess.Popen( + cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) # NOTE(lidiz) the gcloud subprocess won't return unless its output is read stdout = proc.stdout.read() stderr = proc.stderr.read() try: returncode = proc.wait(timeout=GCLOUD_CMD_TIMEOUT_S) except subprocess.TimeoutExpired: - logging.error('> Timeout executing cmd [%s]', " ".join(cmds)) + logging.error("> Timeout executing cmd [%s]", " ".join(cmds)) return None if returncode: - logging.error('> Failed to execute cmd [%s], returned %d, stderr: %s', - " ".join(cmds), returncode, stderr) + logging.error( + "> Failed to execute cmd [%s], returned %d, stderr: %s", + " ".join(cmds), + returncode, + stderr, + ) return None if stdout: return json.loads(stdout) return None -def remove_relative_resources_run_xds_tests(project: str, 
network: str, - prefix: str, suffix: str): +def remove_relative_resources_run_xds_tests( + project: str, network: str, prefix: str, suffix: str +): """Removing GCP resources created by run_xds_tests.py.""" - logging.info('----- Removing run_xds_tests.py resources with suffix [%s]', - suffix) - exec_gcloud(project, 'compute', 'forwarding-rules', 'delete', - f'test-forwarding-rule{suffix}', '--global') - exec_gcloud(project, 'compute', 'target-http-proxies', 'delete', - f'test-target-proxy{suffix}') - exec_gcloud(project, 'alpha', 'compute', 'target-grpc-proxies', 'delete', - f'test-target-proxy{suffix}') - exec_gcloud(project, 'compute', 'url-maps', 'delete', f'test-map{suffix}') - exec_gcloud(project, 'compute', 'backend-services', 'delete', - f'test-backend-service{suffix}', '--global') - exec_gcloud(project, 'compute', 'backend-services', 'delete', - f'test-backend-service-alternate{suffix}', '--global') - exec_gcloud(project, 'compute', 'backend-services', 'delete', - f'test-backend-service-extra{suffix}', '--global') - exec_gcloud(project, 'compute', 'backend-services', 'delete', - f'test-backend-service-more-extra{suffix}', '--global') - exec_gcloud(project, 'compute', 'firewall-rules', 'delete', - f'test-fw-rule{suffix}') - exec_gcloud(project, 'compute', 'health-checks', 'delete', - f'test-hc{suffix}') - exec_gcloud(project, 'compute', 'instance-groups', 'managed', 'delete', - f'test-ig{suffix}', '--zone', ZONE) - exec_gcloud(project, 'compute', 'instance-groups', 'managed', 'delete', - f'test-ig-same-zone{suffix}', '--zone', ZONE) - exec_gcloud(project, 'compute', 'instance-groups', 'managed', 'delete', - f'test-ig-secondary-zone{suffix}', '--zone', SECONDARY_ZONE) - exec_gcloud(project, 'compute', 'instance-templates', 'delete', - f'test-template{suffix}') + logging.info( + "----- Removing run_xds_tests.py resources with suffix [%s]", suffix + ) + exec_gcloud( + project, + "compute", + "forwarding-rules", + "delete", + f"test-forwarding-rule{suffix}", + "--global", + ) + exec_gcloud( + project, + "compute", + "target-http-proxies", + "delete", + f"test-target-proxy{suffix}", + ) + exec_gcloud( + project, + "alpha", + "compute", + "target-grpc-proxies", + "delete", + f"test-target-proxy{suffix}", + ) + exec_gcloud(project, "compute", "url-maps", "delete", f"test-map{suffix}") + exec_gcloud( + project, + "compute", + "backend-services", + "delete", + f"test-backend-service{suffix}", + "--global", + ) + exec_gcloud( + project, + "compute", + "backend-services", + "delete", + f"test-backend-service-alternate{suffix}", + "--global", + ) + exec_gcloud( + project, + "compute", + "backend-services", + "delete", + f"test-backend-service-extra{suffix}", + "--global", + ) + exec_gcloud( + project, + "compute", + "backend-services", + "delete", + f"test-backend-service-more-extra{suffix}", + "--global", + ) + exec_gcloud( + project, "compute", "firewall-rules", "delete", f"test-fw-rule{suffix}" + ) + exec_gcloud( + project, "compute", "health-checks", "delete", f"test-hc{suffix}" + ) + exec_gcloud( + project, + "compute", + "instance-groups", + "managed", + "delete", + f"test-ig{suffix}", + "--zone", + ZONE, + ) + exec_gcloud( + project, + "compute", + "instance-groups", + "managed", + "delete", + f"test-ig-same-zone{suffix}", + "--zone", + ZONE, + ) + exec_gcloud( + project, + "compute", + "instance-groups", + "managed", + "delete", + f"test-ig-secondary-zone{suffix}", + "--zone", + SECONDARY_ZONE, + ) + exec_gcloud( + project, + "compute", + "instance-templates", + "delete", + 
f"test-template{suffix}", + ) # cleanup_td creates TrafficDirectorManager (and its varients for security and @@ -190,13 +288,15 @@ def cleanup_td_for_gke(project, network, resource_prefix, resource_suffix): project=project, network=network, resource_prefix=resource_prefix, - resource_suffix=resource_suffix) + resource_suffix=resource_suffix, + ) security_td = traffic_director.TrafficDirectorSecureManager( gcp_api_manager, project=project, network=network, resource_prefix=resource_prefix, - resource_suffix=resource_suffix) + resource_suffix=resource_suffix, + ) # TODO: cleanup appnet resources. # appnet_td = traffic_director.TrafficDirectorAppNetManager( # gcp_api_manager, @@ -205,16 +305,25 @@ def cleanup_td_for_gke(project, network, resource_prefix, resource_suffix): # resource_prefix=resource_prefix, # resource_suffix=resource_suffix) - logger.info('----- Removing traffic director for gke, prefix %s, suffix %s', - resource_prefix, resource_suffix) + logger.info( + "----- Removing traffic director for gke, prefix %s, suffix %s", + resource_prefix, + resource_suffix, + ) security_td.cleanup(force=True) # appnet_td.cleanup(force=True) plain_td.cleanup(force=True) # cleanup_client creates a client runner, and calls its cleanup() method. -def cleanup_client(project, network, k8s_api_manager, resource_prefix, - resource_suffix, gcp_service_account): +def cleanup_client( + project, + network, + k8s_api_manager, + resource_prefix, + resource_suffix, + gcp_service_account, +): runner_kwargs = dict( deployment_name=xds_flags.CLIENT_NAME.value, image_name=xds_k8s_flags.CLIENT_IMAGE.value, @@ -224,21 +333,30 @@ def cleanup_client(project, network, k8s_api_manager, resource_prefix, gcp_service_account=gcp_service_account, xds_server_uri=xds_flags.XDS_SERVER_URI.value, network=network, - stats_port=xds_flags.CLIENT_PORT.value) + stats_port=xds_flags.CLIENT_PORT.value, + ) client_namespace = _KubernetesClientRunner.make_namespace_name( - resource_prefix, resource_suffix) + resource_prefix, resource_suffix + ) client_runner = _KubernetesClientRunner( k8s.KubernetesNamespace(k8s_api_manager, client_namespace), - **runner_kwargs) + **runner_kwargs, + ) - logger.info('Cleanup client') + logger.info("Cleanup client") client_runner.cleanup(force=True, force_namespace=True) # cleanup_server creates a server runner, and calls its cleanup() method. 
-def cleanup_server(project, network, k8s_api_manager, resource_prefix, - resource_suffix, gcp_service_account): +def cleanup_server( + project, + network, + k8s_api_manager, + resource_prefix, + resource_suffix, + gcp_service_account, +): runner_kwargs = dict( deployment_name=xds_flags.SERVER_NAME.value, image_name=xds_k8s_flags.SERVER_IMAGE.value, @@ -246,70 +364,90 @@ def cleanup_server(project, network, k8s_api_manager, resource_prefix, gcp_project=project, gcp_api_manager=gcp.api.GcpApiManager(), gcp_service_account=gcp_service_account, - network=network) + network=network, + ) server_namespace = _KubernetesServerRunner.make_namespace_name( - resource_prefix, resource_suffix) + resource_prefix, resource_suffix + ) server_runner = _KubernetesServerRunner( k8s.KubernetesNamespace(k8s_api_manager, server_namespace), - **runner_kwargs) + **runner_kwargs, + ) - logger.info('Cleanup server') + logger.info("Cleanup server") server_runner.cleanup(force=True, force_namespace=True) -def delete_leaked_td_resources(dry_run, td_resource_rules, project, network, - resources): +def delete_leaked_td_resources( + dry_run, td_resource_rules, project, network, resources +): for resource in resources: - logger.info('-----') - logger.info('----- Cleaning up resource %s', resource['name']) + logger.info("-----") + logger.info("----- Cleaning up resource %s", resource["name"]) if dry_run: # Skip deletion for dry-runs - logging.info('----- Skipped [Dry Run]: %s', resource['name']) + logging.info("----- Skipped [Dry Run]: %s", resource["name"]) continue matched = False - for (regex, resource_prefix, keep, remove) in td_resource_rules: - result = re.search(regex, resource['name']) + for regex, resource_prefix, keep, remove in td_resource_rules: + result = re.search(regex, resource["name"]) if result is not None: matched = True if keep(result.group(1)): - logging.info('Skipped [keep]:') + logging.info("Skipped [keep]:") break # break inner loop, continue outer loop remove(project, network, resource_prefix, result.group(1)) break if not matched: logging.info( - '----- Skipped [does not matching resource name templates]') - - -def delete_k8s_resources(dry_run, k8s_resource_rules, project, network, - k8s_api_manager, gcp_service_account, namespaces): + "----- Skipped [does not matching resource name templates]" + ) + + +def delete_k8s_resources( + dry_run, + k8s_resource_rules, + project, + network, + k8s_api_manager, + gcp_service_account, + namespaces, +): for ns in namespaces: - logger.info('-----') - logger.info('----- Cleaning up k8s namespaces %s', ns.metadata.name) + logger.info("-----") + logger.info("----- Cleaning up k8s namespaces %s", ns.metadata.name) if ns.metadata.creation_timestamp <= get_expire_timestamp(): if dry_run: # Skip deletion for dry-runs - logging.info('----- Skipped [Dry Run]: %s', ns.metadata.name) + logging.info("----- Skipped [Dry Run]: %s", ns.metadata.name) continue matched = False - for (regex, resource_prefix, remove) in k8s_resource_rules: + for regex, resource_prefix, remove in k8s_resource_rules: result = re.search(regex, ns.metadata.name) if result is not None: matched = True - remove(project, network, k8s_api_manager, resource_prefix, - result.group(1), gcp_service_account) + remove( + project, + network, + k8s_api_manager, + resource_prefix, + result.group(1), + gcp_service_account, + ) break if not matched: logging.info( - '----- Skipped [does not matching resource name templates]') + "----- Skipped [does not matching resource name templates]" + ) else: - 
logging.info('----- Skipped [resource is within expiry date]') + logging.info("----- Skipped [resource is within expiry date]") -def find_and_remove_leaked_k8s_resources(dry_run, project, network, - gcp_service_account): +def find_and_remove_leaked_k8s_resources( + dry_run, project, network, gcp_service_account +): k8s_resource_rules = [ # items in each tuple, in order # - regex to match @@ -318,22 +456,31 @@ def find_and_remove_leaked_k8s_resources(dry_run, project, network, ] for prefix in CLIENT_PREFIXES.value: k8s_resource_rules.append( - (f'{prefix}-client-(.*)', prefix, cleanup_client),) + (f"{prefix}-client-(.*)", prefix, cleanup_client), + ) for prefix in SERVER_PREFIXES.value: k8s_resource_rules.append( - (f'{prefix}-server-(.*)', prefix, cleanup_server),) + (f"{prefix}-server-(.*)", prefix, cleanup_server), + ) # Delete leaked k8s namespaces, those usually mean there are leaked testing # client/servers from the gke framework. k8s_api_manager = k8s.KubernetesApiManager(xds_k8s_flags.KUBE_CONTEXT.value) nss = k8s_api_manager.core.list_namespace() - delete_k8s_resources(dry_run, k8s_resource_rules, project, network, - k8s_api_manager, gcp_service_account, nss.items) + delete_k8s_resources( + dry_run, + k8s_resource_rules, + project, + network, + k8s_api_manager, + gcp_service_account, + nss.items, + ) def main(argv): if len(argv) > 1: - raise app.UsageError('Too many command-line arguments.') + raise app.UsageError("Too many command-line arguments.") load_keep_config() # Must be called before KubernetesApiManager or GcpApiManager init. @@ -350,14 +497,28 @@ def main(argv): # - prefix of the resource (only used by gke resources) # - function to check of the resource should be kept # - function to delete the resource - (r'test-hc(.*)', '', is_marked_as_keep_gce, - remove_relative_resources_run_xds_tests), - (r'test-template(.*)', '', is_marked_as_keep_gce, - remove_relative_resources_run_xds_tests), + ( + r"test-hc(.*)", + "", + is_marked_as_keep_gce, + remove_relative_resources_run_xds_tests, + ), + ( + r"test-template(.*)", + "", + is_marked_as_keep_gce, + remove_relative_resources_run_xds_tests, + ), ] for prefix in TD_RESOURCE_PREFIXES.value: - td_resource_rules.append((f'{prefix}-health-check-(.*)', prefix, - is_marked_as_keep_gke, cleanup_td_for_gke),) + td_resource_rules.append( + ( + f"{prefix}-health-check-(.*)", + prefix, + is_marked_as_keep_gke, + cleanup_td_for_gke, + ), + ) # List resources older than KEEP_PERIOD. We only list health-checks and # instance templates because these are leaves in the resource dependency tree. @@ -369,25 +530,31 @@ def main(argv): # forwarding-rule. compute = gcp.compute.ComputeV1(gcp.api.GcpApiManager(), project) leakedHealthChecks = [] - for item in compute.list_health_check()['items']: - if dateutil.parser.isoparse( - item['creationTimestamp']) <= get_expire_timestamp(): + for item in compute.list_health_check()["items"]: + if ( + dateutil.parser.isoparse(item["creationTimestamp"]) + <= get_expire_timestamp() + ): leakedHealthChecks.append(item) - delete_leaked_td_resources(dry_run, td_resource_rules, project, network, - leakedHealthChecks) + delete_leaked_td_resources( + dry_run, td_resource_rules, project, network, leakedHealthChecks + ) # Delete leaked instance templates, those usually mean there are leaked VMs # from the gce framework. Also note that this is only needed for the gce # resources. 
- leakedInstanceTemplates = exec_gcloud(project, 'compute', - 'instance-templates', 'list') - delete_leaked_td_resources(dry_run, td_resource_rules, project, network, - leakedInstanceTemplates) + leakedInstanceTemplates = exec_gcloud( + project, "compute", "instance-templates", "list" + ) + delete_leaked_td_resources( + dry_run, td_resource_rules, project, network, leakedInstanceTemplates + ) - find_and_remove_leaked_k8s_resources(dry_run, project, network, - gcp_service_account) + find_and_remove_leaked_k8s_resources( + dry_run, project, network, gcp_service_account + ) -if __name__ == '__main__': +if __name__ == "__main__": app.run(main) diff --git a/tools/run_tests/xds_k8s_test_driver/bin/cleanup/namespace.py b/tools/run_tests/xds_k8s_test_driver/bin/cleanup/namespace.py index 3b91c17eb0269..27c1aee92c9fc 100644 --- a/tools/run_tests/xds_k8s_test_driver/bin/cleanup/namespace.py +++ b/tools/run_tests/xds_k8s_test_driver/bin/cleanup/namespace.py @@ -22,7 +22,7 @@ def main(argv): if len(argv) > 1: - raise app.UsageError('Too many command-line arguments.') + raise app.UsageError("Too many command-line arguments.") cleanup.load_keep_config() # Must be called before KubernetesApiManager or GcpApiManager init. @@ -33,9 +33,10 @@ def main(argv): gcp_service_account: str = xds_k8s_flags.GCP_SERVICE_ACCOUNT.value dry_run: bool = cleanup.DRY_RUN.value - cleanup.find_and_remove_leaked_k8s_resources(dry_run, project, network, - gcp_service_account) + cleanup.find_and_remove_leaked_k8s_resources( + dry_run, project, network, gcp_service_account + ) -if __name__ == '__main__': +if __name__ == "__main__": app.run(main) diff --git a/tools/run_tests/xds_k8s_test_driver/bin/isort.sh b/tools/run_tests/xds_k8s_test_driver/bin/isort.sh index 5fe812dc13133..3acd70092f141 100755 --- a/tools/run_tests/xds_k8s_test_driver/bin/isort.sh +++ b/tools/run_tests/xds_k8s_test_driver/bin/isort.sh @@ -55,7 +55,6 @@ fi # typing is the only module allowed to put imports on the same line: # https://google.github.io/styleguide/pyguide.html#313-imports-formatting exec python -m isort "${MODE}" \ - --force-sort-within-sections \ - --force-single-line-imports --single-line-exclusions=typing \ + --settings-path=../../../black.toml \ framework bin tests diff --git a/tools/run_tests/xds_k8s_test_driver/bin/lib/common.py b/tools/run_tests/xds_k8s_test_driver/bin/lib/common.py index edf4c58e9daf4..6578417fe61cc 100755 --- a/tools/run_tests/xds_k8s_test_driver/bin/lib/common.py +++ b/tools/run_tests/xds_k8s_test_driver/bin/lib/common.py @@ -37,17 +37,21 @@ def make_client_namespace( - k8s_api_manager: k8s.KubernetesApiManager) -> k8s.KubernetesNamespace: + k8s_api_manager: k8s.KubernetesApiManager, +) -> k8s.KubernetesNamespace: namespace_name: str = KubernetesClientRunner.make_namespace_name( - xds_flags.RESOURCE_PREFIX.value, xds_flags.RESOURCE_SUFFIX.value) + xds_flags.RESOURCE_PREFIX.value, xds_flags.RESOURCE_SUFFIX.value + ) return k8s.KubernetesNamespace(k8s_api_manager, namespace_name) -def make_client_runner(namespace: k8s.KubernetesNamespace, - gcp_api_manager: gcp.api.GcpApiManager, - port_forwarding: bool = False, - reuse_namespace: bool = True, - secure: bool = False) -> KubernetesClientRunner: +def make_client_runner( + namespace: k8s.KubernetesNamespace, + gcp_api_manager: gcp.api.GcpApiManager, + port_forwarding: bool = False, + reuse_namespace: bool = True, + secure: bool = False, +) -> KubernetesClientRunner: # KubernetesClientRunner arguments. 
runner_kwargs = dict( deployment_name=xds_flags.CLIENT_NAME.value, @@ -60,27 +64,33 @@ def make_client_runner(namespace: k8s.KubernetesNamespace, network=xds_flags.NETWORK.value, stats_port=xds_flags.CLIENT_PORT.value, reuse_namespace=reuse_namespace, - debug_use_port_forwarding=port_forwarding) + debug_use_port_forwarding=port_forwarding, + ) if secure: runner_kwargs.update( - deployment_template='client-secure.deployment.yaml') + deployment_template="client-secure.deployment.yaml" + ) return KubernetesClientRunner(namespace, **runner_kwargs) def make_server_namespace( - k8s_api_manager: k8s.KubernetesApiManager) -> k8s.KubernetesNamespace: + k8s_api_manager: k8s.KubernetesApiManager, +) -> k8s.KubernetesNamespace: namespace_name: str = KubernetesServerRunner.make_namespace_name( - xds_flags.RESOURCE_PREFIX.value, xds_flags.RESOURCE_SUFFIX.value) + xds_flags.RESOURCE_PREFIX.value, xds_flags.RESOURCE_SUFFIX.value + ) return k8s.KubernetesNamespace(k8s_api_manager, namespace_name) -def make_server_runner(namespace: k8s.KubernetesNamespace, - gcp_api_manager: gcp.api.GcpApiManager, - port_forwarding: bool = False, - reuse_namespace: bool = True, - reuse_service: bool = False, - secure: bool = False) -> KubernetesServerRunner: +def make_server_runner( + namespace: k8s.KubernetesNamespace, + gcp_api_manager: gcp.api.GcpApiManager, + port_forwarding: bool = False, + reuse_namespace: bool = True, + reuse_service: bool = False, + secure: bool = False, +) -> KubernetesServerRunner: # KubernetesServerRunner arguments. runner_kwargs = dict( deployment_name=xds_flags.SERVER_NAME.value, @@ -93,10 +103,11 @@ def make_server_runner(namespace: k8s.KubernetesNamespace, network=xds_flags.NETWORK.value, reuse_namespace=reuse_namespace, reuse_service=reuse_service, - debug_use_port_forwarding=port_forwarding) + debug_use_port_forwarding=port_forwarding, + ) if secure: - runner_kwargs['deployment_template'] = 'server-secure.deployment.yaml' + runner_kwargs["deployment_template"] = "server-secure.deployment.yaml" return KubernetesServerRunner(namespace, **runner_kwargs) @@ -108,49 +119,59 @@ def _ensure_atexit(signum, frame): # Pylint is wrong about "Module 'signal' has no 'Signals' member": # https://docs.python.org/3/library/signal.html#signal.Signals sig = signal.Signals(signum) # pylint: disable=no-member - logger.warning('Caught %r, initiating graceful shutdown...\n', sig) + logger.warning("Caught %r, initiating graceful shutdown...\n", sig) sys.exit(1) -def _graceful_exit(server_runner: KubernetesServerRunner, - client_runner: KubernetesClientRunner): +def _graceful_exit( + server_runner: KubernetesServerRunner, client_runner: KubernetesClientRunner +): """Stop port forwarding processes.""" client_runner.stop_pod_dependencies() server_runner.stop_pod_dependencies() -def register_graceful_exit(server_runner: KubernetesServerRunner, - client_runner: KubernetesClientRunner): +def register_graceful_exit( + server_runner: KubernetesServerRunner, client_runner: KubernetesClientRunner +): atexit.register(_graceful_exit, server_runner, client_runner) for signum in (signal.SIGTERM, signal.SIGHUP, signal.SIGINT): signal.signal(signum, _ensure_atexit) -def get_client_pod(client_runner: KubernetesClientRunner, - deployment_name: str) -> k8s.V1Pod: +def get_client_pod( + client_runner: KubernetesClientRunner, deployment_name: str +) -> k8s.V1Pod: client_deployment: k8s.V1Deployment client_deployment = client_runner.k8s_namespace.get_deployment( - deployment_name) + deployment_name + ) client_pod_name: str = 
client_runner._wait_deployment_pod_count( - client_deployment)[0] + client_deployment + )[0] return client_runner._wait_pod_started(client_pod_name) -def get_server_pod(server_runner: KubernetesServerRunner, - deployment_name: str) -> k8s.V1Pod: +def get_server_pod( + server_runner: KubernetesServerRunner, deployment_name: str +) -> k8s.V1Pod: server_deployment: k8s.V1Deployment server_deployment = server_runner.k8s_namespace.get_deployment( - deployment_name) + deployment_name + ) server_pod_name: str = server_runner._wait_deployment_pod_count( - server_deployment)[0] + server_deployment + )[0] return server_runner._wait_pod_started(server_pod_name) -def get_test_server_for_pod(server_runner: KubernetesServerRunner, - server_pod: k8s.V1Pod, **kwargs) -> _XdsTestServer: +def get_test_server_for_pod( + server_runner: KubernetesServerRunner, server_pod: k8s.V1Pod, **kwargs +) -> _XdsTestServer: return server_runner._xds_test_server_for_pod(server_pod, **kwargs) -def get_test_client_for_pod(client_runner: KubernetesClientRunner, - client_pod: k8s.V1Pod, **kwargs) -> _XdsTestClient: +def get_test_client_for_pod( + client_runner: KubernetesClientRunner, client_pod: k8s.V1Pod, **kwargs +) -> _XdsTestClient: return client_runner._xds_test_client_for_pod(client_pod, **kwargs) diff --git a/tools/run_tests/xds_k8s_test_driver/bin/run_channelz.py b/tools/run_tests/xds_k8s_test_driver/bin/run_channelz.py index 79747be06a273..e69e7f2302a0e 100755 --- a/tools/run_tests/xds_k8s_test_driver/bin/run_channelz.py +++ b/tools/run_tests/xds_k8s_test_driver/bin/run_channelz.py @@ -44,21 +44,30 @@ from framework.test_app import server_app # Flags -_SECURITY = flags.DEFINE_enum('security', - default=None, - enum_values=[ - 'mtls', 'tls', 'plaintext', 'mtls_error', - 'server_authz_error' - ], - help='Show info for a security setup') +_SECURITY = flags.DEFINE_enum( + "security", + default=None, + enum_values=[ + "mtls", + "tls", + "plaintext", + "mtls_error", + "server_authz_error", + ], + help="Show info for a security setup", +) flags.adopt_module_key_flags(xds_flags) flags.adopt_module_key_flags(xds_k8s_flags) # Running outside of a test suite, so require explicit resource_suffix. flags.mark_flag_as_required(xds_flags.RESOURCE_SUFFIX.name) -flags.register_validator(xds_flags.SERVER_XDS_PORT.name, - lambda val: val > 0, - message="Run outside of a test suite, must provide" - " the exact port value (must be greater than 0).") +flags.register_validator( + xds_flags.SERVER_XDS_PORT.name, + lambda val: val > 0, + message=( + "Run outside of a test suite, must provide" + " the exact port value (must be greater than 0)." + ), +) logger = logging.get_absl_logger() @@ -72,14 +81,16 @@ def debug_cert(cert): if not cert: - return '' + return "" sha1 = hashlib.sha1(cert) - return f'sha1={sha1.hexdigest()}, len={len(cert)}' + return f"sha1={sha1.hexdigest()}, len={len(cert)}" def debug_sock_tls(tls): - return (f'local: {debug_cert(tls.local_certificate)}\n' - f'remote: {debug_cert(tls.remote_certificate)}') + return ( + f"local: {debug_cert(tls.local_certificate)}\n" + f"remote: {debug_cert(tls.remote_certificate)}" + ) def get_deployment_pods(k8s_ns, deployment_name): @@ -98,38 +109,46 @@ def debug_security_setup_negative(test_client): # Client side. 
client_correct_setup = True channel: _Channel = test_client.wait_for_server_channel_state( - state=_ChannelState.TRANSIENT_FAILURE) + state=_ChannelState.TRANSIENT_FAILURE + ) try: subchannel, *subchannels = list( - test_client.channelz.list_channel_subchannels(channel)) + test_client.channelz.list_channel_subchannels(channel) + ) except ValueError: - print("Client setup fail: subchannel not found. " - "Common causes: test client didn't connect to TD; " - "test client exhausted retries, and closed all subchannels.") + print( + "Client setup fail: subchannel not found. " + "Common causes: test client didn't connect to TD; " + "test client exhausted retries, and closed all subchannels." + ) return # Client must have exactly one subchannel. - logger.debug('Found subchannel, %s', subchannel) + logger.debug("Found subchannel, %s", subchannel) if subchannels: client_correct_setup = False - print(f'Unexpected subchannels {subchannels}') + print(f"Unexpected subchannels {subchannels}") subchannel_state: _ChannelState = subchannel.data.state.state if subchannel_state is not _ChannelState.TRANSIENT_FAILURE: client_correct_setup = False - print('Subchannel expected to be in ' - 'TRANSIENT_FAILURE, same as its channel') + print( + "Subchannel expected to be in " + "TRANSIENT_FAILURE, same as its channel" + ) # Client subchannel must have no sockets. sockets = list(test_client.channelz.list_subchannels_sockets(subchannel)) if sockets: client_correct_setup = False - print(f'Unexpected subchannel sockets {sockets}') + print(f"Unexpected subchannel sockets {sockets}") # Results. if client_correct_setup: - print('Client setup pass: the channel ' - 'to the server has exactly one subchannel ' - 'in TRANSIENT_FAILURE, and no sockets') + print( + "Client setup pass: the channel " + "to the server has exactly one subchannel " + "in TRANSIENT_FAILURE, and no sockets" + ) def debug_security_setup_positive(test_client, test_server): @@ -137,26 +156,27 @@ def debug_security_setup_positive(test_client, test_server): test_client.wait_for_active_server_channel() client_sock: _Socket = test_client.get_active_server_channel_socket() server_sock: _Socket = test_server.get_server_socket_matching_client( - client_sock) + client_sock + ) server_tls = server_sock.security.tls client_tls = client_sock.security.tls - print(f'\nServer certs:\n{debug_sock_tls(server_tls)}') - print(f'\nClient certs:\n{debug_sock_tls(client_tls)}') + print(f"\nServer certs:\n{debug_sock_tls(server_tls)}") + print(f"\nClient certs:\n{debug_sock_tls(client_tls)}") print() if server_tls.local_certificate: eq = server_tls.local_certificate == client_tls.remote_certificate - print(f'(TLS) Server local matches client remote: {eq}') + print(f"(TLS) Server local matches client remote: {eq}") else: - print('(TLS) Not detected') + print("(TLS) Not detected") if server_tls.remote_certificate: eq = server_tls.remote_certificate == client_tls.local_certificate - print(f'(mTLS) Server remote matches client local: {eq}') + print(f"(mTLS) Server remote matches client local: {eq}") else: - print('(mTLS) Not detected') + print("(mTLS) Not detected") def debug_basic_setup(test_client, test_server): @@ -164,15 +184,16 @@ def debug_basic_setup(test_client, test_server): test_client.wait_for_active_server_channel() client_sock: _Socket = test_client.get_active_server_channel_socket() server_sock: _Socket = test_server.get_server_socket_matching_client( - client_sock) + client_sock + ) - logger.debug('Client socket: %s\n', client_sock) - logger.debug('Matching server 
socket: %s\n', server_sock) + logger.debug("Client socket: %s\n", client_sock) + logger.debug("Matching server socket: %s\n", server_sock) def main(argv): if len(argv) > 1: - raise app.UsageError('Too many command-line arguments.') + raise app.UsageError("Too many command-line arguments.") # Must be called before KubernetesApiManager or GcpApiManager init. xds_flags.set_socket_default_timeout_from_flag() @@ -191,10 +212,12 @@ def main(argv): server_namespace, gcp_api_manager, port_forwarding=should_port_forward, - secure=is_secure) + secure=is_secure, + ) # Find server pod. - server_pod: k8s.V1Pod = common.get_server_pod(server_runner, - xds_flags.SERVER_NAME.value) + server_pod: k8s.V1Pod = common.get_server_pod( + server_runner, xds_flags.SERVER_NAME.value + ) # Client client_namespace = common.make_client_namespace(k8s_api_manager) @@ -202,10 +225,12 @@ def main(argv): client_namespace, gcp_api_manager, port_forwarding=should_port_forward, - secure=is_secure) + secure=is_secure, + ) # Find client pod. - client_pod: k8s.V1Pod = common.get_client_pod(client_runner, - xds_flags.CLIENT_NAME.value) + client_pod: k8s.V1Pod = common.get_client_pod( + client_runner, xds_flags.CLIENT_NAME.value + ) # Ensure port forwarding stopped. common.register_graceful_exit(server_runner, client_runner) @@ -215,24 +240,27 @@ def main(argv): server_runner, server_pod, test_port=xds_flags.SERVER_PORT.value, - secure_mode=is_secure) - test_server.set_xds_address(xds_flags.SERVER_XDS_HOST.value, - xds_flags.SERVER_XDS_PORT.value) + secure_mode=is_secure, + ) + test_server.set_xds_address( + xds_flags.SERVER_XDS_HOST.value, xds_flags.SERVER_XDS_PORT.value + ) # Create client app for the client pod. test_client: _XdsTestClient = common.get_test_client_for_pod( - client_runner, client_pod, server_target=test_server.xds_uri) + client_runner, client_pod, server_target=test_server.xds_uri + ) with test_client, test_server: - if _SECURITY.value in ('mtls', 'tls', 'plaintext'): + if _SECURITY.value in ("mtls", "tls", "plaintext"): debug_security_setup_positive(test_client, test_server) - elif _SECURITY.value in ('mtls_error', 'server_authz_error'): + elif _SECURITY.value in ("mtls_error", "server_authz_error"): debug_security_setup_negative(test_client) else: debug_basic_setup(test_client, test_server) - logger.info('SUCCESS!') + logger.info("SUCCESS!") -if __name__ == '__main__': +if __name__ == "__main__": app.run(main) diff --git a/tools/run_tests/xds_k8s_test_driver/bin/run_ping_pong.py b/tools/run_tests/xds_k8s_test_driver/bin/run_ping_pong.py index 96caefdc323f6..ca38d59343d9b 100755 --- a/tools/run_tests/xds_k8s_test_driver/bin/run_ping_pong.py +++ b/tools/run_tests/xds_k8s_test_driver/bin/run_ping_pong.py @@ -31,21 +31,30 @@ _SECURE = flags.DEFINE_bool( "secure", default=False, - help="Set to True if the the client/server were started " - "with the PSM security enabled.") -_NUM_RPCS = flags.DEFINE_integer("num_rpcs", - default=100, - lower_bound=1, - upper_bound=10_000, - help="The number of RPCs to check.") + help=( + "Set to True if the the client/server were started " + "with the PSM security enabled." + ), +) +_NUM_RPCS = flags.DEFINE_integer( + "num_rpcs", + default=100, + lower_bound=1, + upper_bound=10_000, + help="The number of RPCs to check.", +) flags.adopt_module_key_flags(xds_flags) flags.adopt_module_key_flags(xds_k8s_flags) # Running outside of a test suite, so require explicit resource_suffix. 
flags.mark_flag_as_required(xds_flags.RESOURCE_SUFFIX.name) -flags.register_validator(xds_flags.SERVER_XDS_PORT.name, - lambda val: val > 0, - message="Run outside of a test suite, must provide" - " the exact port value (must be greater than 0).") +flags.register_validator( + xds_flags.SERVER_XDS_PORT.name, + lambda val: val > 0, + message=( + "Run outside of a test suite, must provide" + " the exact port value (must be greater than 0)." + ), +) logger = logging.get_absl_logger() @@ -58,13 +67,16 @@ LoadBalancerStatsResponse = grpc_testing.LoadBalancerStatsResponse -def get_client_rpc_stats(test_client: _XdsTestClient, - num_rpcs: int) -> LoadBalancerStatsResponse: +def get_client_rpc_stats( + test_client: _XdsTestClient, num_rpcs: int +) -> LoadBalancerStatsResponse: lb_stats = test_client.get_load_balancer_stats(num_rpcs=num_rpcs) hl = framework.helpers.highlighter.HighlighterYaml() - logger.info('[%s] Received LoadBalancerStatsResponse:\n%s', - test_client.hostname, - hl.highlight(helpers_grpc.lb_stats_pretty(lb_stats))) + logger.info( + "[%s] Received LoadBalancerStatsResponse:\n%s", + test_client.hostname, + hl.highlight(helpers_grpc.lb_stats_pretty(lb_stats)), + ) return lb_stats @@ -74,17 +86,19 @@ def run_ping_pong(test_client: _XdsTestClient, num_rpcs: int): for backend, rpcs_count in lb_stats.rpcs_by_peer.items(): if int(rpcs_count) < 1: raise AssertionError( - f'Backend {backend} did not receive a single RPC') + f"Backend {backend} did not receive a single RPC" + ) failed = int(lb_stats.num_failures) if int(lb_stats.num_failures) > 0: raise AssertionError( - f'Expected all RPCs to succeed: {failed} of {num_rpcs} failed') + f"Expected all RPCs to succeed: {failed} of {num_rpcs} failed" + ) def main(argv): if len(argv) > 1: - raise app.UsageError('Too many command-line arguments.') + raise app.UsageError("Too many command-line arguments.") # Must be called before KubernetesApiManager or GcpApiManager init. xds_flags.set_socket_default_timeout_from_flag() @@ -103,10 +117,12 @@ def main(argv): server_namespace, gcp_api_manager, port_forwarding=should_port_forward, - secure=is_secure) + secure=is_secure, + ) # Find server pod. - server_pod: k8s.V1Pod = common.get_server_pod(server_runner, - xds_flags.SERVER_NAME.value) + server_pod: k8s.V1Pod = common.get_server_pod( + server_runner, xds_flags.SERVER_NAME.value + ) # Client client_namespace = common.make_client_namespace(k8s_api_manager) @@ -114,10 +130,12 @@ def main(argv): client_namespace, gcp_api_manager, port_forwarding=should_port_forward, - secure=is_secure) + secure=is_secure, + ) # Find client pod. - client_pod: k8s.V1Pod = common.get_client_pod(client_runner, - xds_flags.CLIENT_NAME.value) + client_pod: k8s.V1Pod = common.get_client_pod( + client_runner, xds_flags.CLIENT_NAME.value + ) # Ensure port forwarding stopped. common.register_graceful_exit(server_runner, client_runner) @@ -127,19 +145,22 @@ def main(argv): server_runner, server_pod, test_port=xds_flags.SERVER_PORT.value, - secure_mode=is_secure) - test_server.set_xds_address(xds_flags.SERVER_XDS_HOST.value, - xds_flags.SERVER_XDS_PORT.value) + secure_mode=is_secure, + ) + test_server.set_xds_address( + xds_flags.SERVER_XDS_HOST.value, xds_flags.SERVER_XDS_PORT.value + ) # Create client app for the client pod. 
test_client: _XdsTestClient = common.get_test_client_for_pod( - client_runner, client_pod, server_target=test_server.xds_uri) + client_runner, client_pod, server_target=test_server.xds_uri + ) with test_client, test_server: run_ping_pong(test_client, _NUM_RPCS.value) - logger.info('SUCCESS!') + logger.info("SUCCESS!") -if __name__ == '__main__': +if __name__ == "__main__": app.run(main) diff --git a/tools/run_tests/xds_k8s_test_driver/bin/run_td_setup.py b/tools/run_tests/xds_k8s_test_driver/bin/run_td_setup.py index 88e617ae98385..a3899c5bff990 100755 --- a/tools/run_tests/xds_k8s_test_driver/bin/run_td_setup.py +++ b/tools/run_tests/xds_k8s_test_driver/bin/run_td_setup.py @@ -45,31 +45,46 @@ logger = logging.getLogger(__name__) # Flags -_CMD = flags.DEFINE_enum('cmd', - default='create', - enum_values=[ - 'cycle', 'create', 'cleanup', 'backends-add', - 'backends-cleanup', 'unused-xds-port' - ], - help='Command') -_SECURITY = flags.DEFINE_enum('security', - default=None, - enum_values=[ - 'mtls', 'tls', 'plaintext', 'mtls_error', - 'server_authz_error' - ], - help='Configure TD with security') +_CMD = flags.DEFINE_enum( + "cmd", + default="create", + enum_values=[ + "cycle", + "create", + "cleanup", + "backends-add", + "backends-cleanup", + "unused-xds-port", + ], + help="Command", +) +_SECURITY = flags.DEFINE_enum( + "security", + default=None, + enum_values=[ + "mtls", + "tls", + "plaintext", + "mtls_error", + "server_authz_error", + ], + help="Configure TD with security", +) flags.adopt_module_key_flags(xds_flags) flags.adopt_module_key_flags(xds_k8s_flags) # Running outside of a test suite, so require explicit resource_suffix. flags.mark_flag_as_required(xds_flags.RESOURCE_SUFFIX.name) -@flags.multi_flags_validator((xds_flags.SERVER_XDS_PORT.name, _CMD.name), - message="Run outside of a test suite, must provide" - " the exact port value (must be greater than 0).") +@flags.multi_flags_validator( + (xds_flags.SERVER_XDS_PORT.name, _CMD.name), + message=( + "Run outside of a test suite, must provide" + " the exact port value (must be greater than 0)." + ), +) def _check_server_xds_port_flag(flags_dict): - if flags_dict[_CMD.name] not in ('create', 'cycle'): + if flags_dict[_CMD.name] not in ("create", "cycle"): return True return flags_dict[xds_flags.SERVER_XDS_PORT.name] > 0 @@ -78,9 +93,11 @@ def _check_server_xds_port_flag(flags_dict): _KubernetesServerRunner = k8s_xds_server_runner.KubernetesServerRunner -def main(argv): # pylint: disable=too-many-locals,too-many-branches,too-many-statements +def main( + argv, +): # pylint: disable=too-many-locals,too-many-branches,too-many-statements if len(argv) > 1: - raise app.UsageError('Too many command-line arguments.') + raise app.UsageError("Too many command-line arguments.") # Must be called before KubernetesApiManager or GcpApiManager init. 
xds_flags.set_socket_default_timeout_from_flag() @@ -102,7 +119,8 @@ def main(argv): # pylint: disable=too-many-locals,too-many-branches,too-many-st server_xds_host = xds_flags.SERVER_XDS_HOST.value server_xds_port = xds_flags.SERVER_XDS_PORT.value server_namespace = _KubernetesServerRunner.make_namespace_name( - resource_prefix, resource_suffix) + resource_prefix, resource_suffix + ) gcp_api_manager = gcp.api.GcpApiManager() @@ -112,141 +130,181 @@ def main(argv): # pylint: disable=too-many-locals,too-many-branches,too-many-st project=project, network=network, resource_prefix=resource_prefix, - resource_suffix=resource_suffix) + resource_suffix=resource_suffix, + ) else: td = traffic_director.TrafficDirectorSecureManager( gcp_api_manager, project=project, network=network, resource_prefix=resource_prefix, - resource_suffix=resource_suffix) + resource_suffix=resource_suffix, + ) if server_maintenance_port is None: - server_maintenance_port = \ + server_maintenance_port = ( _KubernetesServerRunner.DEFAULT_SECURE_MODE_MAINTENANCE_PORT + ) try: - if command in ('create', 'cycle'): - logger.info('Create mode') + if command in ("create", "cycle"): + logger.info("Create mode") if security_mode is None: - logger.info('No security') - td.setup_for_grpc(server_xds_host, - server_xds_port, - health_check_port=server_maintenance_port) - - elif security_mode == 'mtls': - logger.info('Setting up mtls') - td.setup_for_grpc(server_xds_host, - server_xds_port, - health_check_port=server_maintenance_port) - td.setup_server_security(server_namespace=server_namespace, - server_name=server_name, - server_port=server_port, - tls=True, - mtls=True) - td.setup_client_security(server_namespace=server_namespace, - server_name=server_name, - tls=True, - mtls=True) - - elif security_mode == 'tls': - logger.info('Setting up tls') - td.setup_for_grpc(server_xds_host, - server_xds_port, - health_check_port=server_maintenance_port) - td.setup_server_security(server_namespace=server_namespace, - server_name=server_name, - server_port=server_port, - tls=True, - mtls=False) - td.setup_client_security(server_namespace=server_namespace, - server_name=server_name, - tls=True, - mtls=False) - - elif security_mode == 'plaintext': - logger.info('Setting up plaintext') - td.setup_for_grpc(server_xds_host, - server_xds_port, - health_check_port=server_maintenance_port) - td.setup_server_security(server_namespace=server_namespace, - server_name=server_name, - server_port=server_port, - tls=False, - mtls=False) - td.setup_client_security(server_namespace=server_namespace, - server_name=server_name, - tls=False, - mtls=False) - - elif security_mode == 'mtls_error': + logger.info("No security") + td.setup_for_grpc( + server_xds_host, + server_xds_port, + health_check_port=server_maintenance_port, + ) + + elif security_mode == "mtls": + logger.info("Setting up mtls") + td.setup_for_grpc( + server_xds_host, + server_xds_port, + health_check_port=server_maintenance_port, + ) + td.setup_server_security( + server_namespace=server_namespace, + server_name=server_name, + server_port=server_port, + tls=True, + mtls=True, + ) + td.setup_client_security( + server_namespace=server_namespace, + server_name=server_name, + tls=True, + mtls=True, + ) + + elif security_mode == "tls": + logger.info("Setting up tls") + td.setup_for_grpc( + server_xds_host, + server_xds_port, + health_check_port=server_maintenance_port, + ) + td.setup_server_security( + server_namespace=server_namespace, + server_name=server_name, + server_port=server_port, + tls=True, + 
mtls=False, + ) + td.setup_client_security( + server_namespace=server_namespace, + server_name=server_name, + tls=True, + mtls=False, + ) + + elif security_mode == "plaintext": + logger.info("Setting up plaintext") + td.setup_for_grpc( + server_xds_host, + server_xds_port, + health_check_port=server_maintenance_port, + ) + td.setup_server_security( + server_namespace=server_namespace, + server_name=server_name, + server_port=server_port, + tls=False, + mtls=False, + ) + td.setup_client_security( + server_namespace=server_namespace, + server_name=server_name, + tls=False, + mtls=False, + ) + + elif security_mode == "mtls_error": # Error case: server expects client mTLS cert, # but client configured only for TLS - logger.info('Setting up mtls_error') - td.setup_for_grpc(server_xds_host, - server_xds_port, - health_check_port=server_maintenance_port) - td.setup_server_security(server_namespace=server_namespace, - server_name=server_name, - server_port=server_port, - tls=True, - mtls=True) - td.setup_client_security(server_namespace=server_namespace, - server_name=server_name, - tls=True, - mtls=False) - - elif security_mode == 'server_authz_error': + logger.info("Setting up mtls_error") + td.setup_for_grpc( + server_xds_host, + server_xds_port, + health_check_port=server_maintenance_port, + ) + td.setup_server_security( + server_namespace=server_namespace, + server_name=server_name, + server_port=server_port, + tls=True, + mtls=True, + ) + td.setup_client_security( + server_namespace=server_namespace, + server_name=server_name, + tls=True, + mtls=False, + ) + + elif security_mode == "server_authz_error": # Error case: client does not authorize server # because of mismatched SAN name. - logger.info('Setting up mtls_error') - td.setup_for_grpc(server_xds_host, - server_xds_port, - health_check_port=server_maintenance_port) + logger.info("Setting up mtls_error") + td.setup_for_grpc( + server_xds_host, + server_xds_port, + health_check_port=server_maintenance_port, + ) # Regular TLS setup, but with client policy configured using # intentionality incorrect server_namespace. 
- td.setup_server_security(server_namespace=server_namespace, - server_name=server_name, - server_port=server_port, - tls=True, - mtls=False) + td.setup_server_security( + server_namespace=server_namespace, + server_name=server_name, + server_port=server_port, + tls=True, + mtls=False, + ) td.setup_client_security( - server_namespace=f'incorrect-namespace-{rand.rand_string()}', + server_namespace=( + f"incorrect-namespace-{rand.rand_string()}" + ), server_name=server_name, tls=True, - mtls=False) + mtls=False, + ) - logger.info('Works!') + logger.info("Works!") except Exception: # noqa pylint: disable=broad-except - logger.exception('Got error during creation') + logger.exception("Got error during creation") - if command in ('cleanup', 'cycle'): - logger.info('Cleaning up') + if command in ("cleanup", "cycle"): + logger.info("Cleaning up") td.cleanup(force=True) - if command == 'backends-add': - logger.info('Adding backends') + if command == "backends-add": + logger.info("Adding backends") k8s_api_manager = k8s.KubernetesApiManager( - xds_k8s_flags.KUBE_CONTEXT.value) - k8s_namespace = k8s.KubernetesNamespace(k8s_api_manager, - server_namespace) + xds_k8s_flags.KUBE_CONTEXT.value + ) + k8s_namespace = k8s.KubernetesNamespace( + k8s_api_manager, server_namespace + ) neg_name, neg_zones = k8s_namespace.get_service_neg( - server_name, server_port) + server_name, server_port + ) td.load_backend_service() td.backend_service_add_neg_backends(neg_name, neg_zones) td.wait_for_backends_healthy_status() - elif command == 'backends-cleanup': + elif command == "backends-cleanup": td.load_backend_service() td.backend_service_remove_all_backends() - elif command == 'unused-xds-port': + elif command == "unused-xds-port": try: unused_xds_port = td.find_unused_forwarding_rule_port() - logger.info('Found unused forwarding rule port: %s', - unused_xds_port) + logger.info( + "Found unused forwarding rule port: %s", unused_xds_port + ) except Exception: # noqa pylint: disable=broad-except logger.exception("Couldn't find unused forwarding rule port") -if __name__ == '__main__': +if __name__ == "__main__": app.run(main) diff --git a/tools/run_tests/xds_k8s_test_driver/bin/run_test_client.py b/tools/run_tests/xds_k8s_test_driver/bin/run_test_client.py index c1b80cf711d65..2981739206cd3 100755 --- a/tools/run_tests/xds_k8s_test_driver/bin/run_test_client.py +++ b/tools/run_tests/xds_k8s_test_driver/bin/run_test_client.py @@ -25,43 +25,50 @@ logger = logging.getLogger(__name__) # Flags -_CMD = flags.DEFINE_enum('cmd', - default='run', - enum_values=['run', 'cleanup'], - help='Command') -_SECURE = flags.DEFINE_bool("secure", - default=False, - help="Run client in the secure mode") -_QPS = flags.DEFINE_integer('qps', default=25, help='Queries per second') -_PRINT_RESPONSE = flags.DEFINE_bool("print_response", - default=False, - help="Client prints responses") +_CMD = flags.DEFINE_enum( + "cmd", default="run", enum_values=["run", "cleanup"], help="Command" +) +_SECURE = flags.DEFINE_bool( + "secure", default=False, help="Run client in the secure mode" +) +_QPS = flags.DEFINE_integer("qps", default=25, help="Queries per second") +_PRINT_RESPONSE = flags.DEFINE_bool( + "print_response", default=False, help="Client prints responses" +) _FOLLOW = flags.DEFINE_bool( "follow", default=False, - help= - "Follow pod logs. Requires --collect_app_logs or --debug_use_port_forwarding" + help=( + "Follow pod logs. 
Requires --collect_app_logs or" + " --debug_use_port_forwarding" + ), ) _CONFIG_MESH = flags.DEFINE_bool( "config_mesh", default=None, - help="Optional. Supplied to bootstrap generator to indicate AppNet mesh.") -_REUSE_NAMESPACE = flags.DEFINE_bool("reuse_namespace", - default=True, - help="Use existing namespace if exists") + help="Optional. Supplied to bootstrap generator to indicate AppNet mesh.", +) +_REUSE_NAMESPACE = flags.DEFINE_bool( + "reuse_namespace", default=True, help="Use existing namespace if exists" +) _CLEANUP_NAMESPACE = flags.DEFINE_bool( "cleanup_namespace", default=False, - help="Delete namespace during resource cleanup") + help="Delete namespace during resource cleanup", +) flags.adopt_module_key_flags(xds_flags) flags.adopt_module_key_flags(xds_k8s_flags) # Running outside of a test suite, so require explicit resource_suffix. flags.mark_flag_as_required(xds_flags.RESOURCE_SUFFIX.name) -@flags.multi_flags_validator((xds_flags.SERVER_XDS_PORT.name, _CMD.name), - message="Run outside of a test suite, must provide" - " the exact port value (must be greater than 0).") +@flags.multi_flags_validator( + (xds_flags.SERVER_XDS_PORT.name, _CMD.name), + message=( + "Run outside of a test suite, must provide" + " the exact port value (must be greater than 0)." + ), +) def _check_server_xds_port_flag(flags_dict): if flags_dict[_CMD.name] == "cleanup": return True @@ -69,10 +76,9 @@ def _check_server_xds_port_flag(flags_dict): def _make_sigint_handler(client_runner: common.KubernetesClientRunner): - def sigint_handler(sig, frame): del sig, frame - print('Caught Ctrl+C. Shutting down the logs') + print("Caught Ctrl+C. Shutting down the logs") client_runner.stop_pod_dependencies(log_drain_sec=3) return sigint_handler @@ -80,15 +86,16 @@ def sigint_handler(sig, frame): def main(argv): if len(argv) > 1: - raise app.UsageError('Too many command-line arguments.') + raise app.UsageError("Too many command-line arguments.") # Must be called before KubernetesApiManager or GcpApiManager init. xds_flags.set_socket_default_timeout_from_flag() # Log following and port forwarding. should_follow_logs = _FOLLOW.value and xds_flags.COLLECT_APP_LOGS.value - should_port_forward = (should_follow_logs and - xds_k8s_flags.DEBUG_USE_PORT_FORWARDING.value) + should_port_forward = ( + should_follow_logs and xds_k8s_flags.DEBUG_USE_PORT_FORWARDING.value + ) # Setup. gcp_api_manager = gcp.api.GcpApiManager() @@ -99,31 +106,34 @@ def main(argv): gcp_api_manager, reuse_namespace=_REUSE_NAMESPACE.value, secure=_SECURE.value, - port_forwarding=should_port_forward) + port_forwarding=should_port_forward, + ) # Server target server_xds_host = xds_flags.SERVER_XDS_HOST.value server_xds_port = xds_flags.SERVER_XDS_PORT.value - if _CMD.value == 'run': - logger.info('Run client, secure_mode=%s', _SECURE.value) + if _CMD.value == "run": + logger.info("Run client, secure_mode=%s", _SECURE.value) client_runner.run( - server_target=f'xds:///{server_xds_host}:{server_xds_port}', + server_target=f"xds:///{server_xds_host}:{server_xds_port}", qps=_QPS.value, print_response=_PRINT_RESPONSE.value, secure_mode=_SECURE.value, config_mesh=_CONFIG_MESH.value, - log_to_stdout=_FOLLOW.value) + log_to_stdout=_FOLLOW.value, + ) if should_follow_logs: - print('Following pod logs. Press Ctrl+C top stop') + print("Following pod logs. 
Press Ctrl+C top stop") signal.signal(signal.SIGINT, _make_sigint_handler(client_runner)) signal.pause() - elif _CMD.value == 'cleanup': - logger.info('Cleanup client') - client_runner.cleanup(force=True, - force_namespace=_CLEANUP_NAMESPACE.value) + elif _CMD.value == "cleanup": + logger.info("Cleanup client") + client_runner.cleanup( + force=True, force_namespace=_CLEANUP_NAMESPACE.value + ) -if __name__ == '__main__': +if __name__ == "__main__": app.run(main) diff --git a/tools/run_tests/xds_k8s_test_driver/bin/run_test_server.py b/tools/run_tests/xds_k8s_test_driver/bin/run_test_server.py index 1cc0d89f53455..c24720d5c5bcc 100755 --- a/tools/run_tests/xds_k8s_test_driver/bin/run_test_server.py +++ b/tools/run_tests/xds_k8s_test_driver/bin/run_test_server.py @@ -25,27 +25,26 @@ logger = logging.getLogger(__name__) # Flags -_CMD = flags.DEFINE_enum('cmd', - default='run', - enum_values=['run', 'cleanup'], - help='Command') -_SECURE = flags.DEFINE_bool("secure", - default=False, - help="Run server in the secure mode") -_REUSE_NAMESPACE = flags.DEFINE_bool("reuse_namespace", - default=True, - help="Use existing namespace if exists") -_REUSE_SERVICE = flags.DEFINE_bool("reuse_service", - default=False, - help="Use existing service if exists") -_FOLLOW = flags.DEFINE_bool("follow", - default=False, - help="Follow pod logs. " - "Requires --collect_app_logs") +_CMD = flags.DEFINE_enum( + "cmd", default="run", enum_values=["run", "cleanup"], help="Command" +) +_SECURE = flags.DEFINE_bool( + "secure", default=False, help="Run server in the secure mode" +) +_REUSE_NAMESPACE = flags.DEFINE_bool( + "reuse_namespace", default=True, help="Use existing namespace if exists" +) +_REUSE_SERVICE = flags.DEFINE_bool( + "reuse_service", default=False, help="Use existing service if exists" +) +_FOLLOW = flags.DEFINE_bool( + "follow", default=False, help="Follow pod logs. Requires --collect_app_logs" +) _CLEANUP_NAMESPACE = flags.DEFINE_bool( "cleanup_namespace", default=False, - help="Delete namespace during resource cleanup") + help="Delete namespace during resource cleanup", +) flags.adopt_module_key_flags(xds_flags) flags.adopt_module_key_flags(xds_k8s_flags) # Running outside of a test suite, so require explicit resource_suffix. @@ -53,10 +52,9 @@ def _make_sigint_handler(server_runner: common.KubernetesServerRunner): - def sigint_handler(sig, frame): del sig, frame - print('Caught Ctrl+C. Shutting down the logs') + print("Caught Ctrl+C. Shutting down the logs") server_runner.stop_pod_dependencies(log_drain_sec=3) return sigint_handler @@ -64,14 +62,15 @@ def sigint_handler(sig, frame): def main(argv): if len(argv) > 1: - raise app.UsageError('Too many command-line arguments.') + raise app.UsageError("Too many command-line arguments.") # Must be called before KubernetesApiManager or GcpApiManager init. xds_flags.set_socket_default_timeout_from_flag() should_follow_logs = _FOLLOW.value and xds_flags.COLLECT_APP_LOGS.value - should_port_forward = (should_follow_logs and - xds_k8s_flags.DEBUG_USE_PORT_FORWARDING.value) + should_port_forward = ( + should_follow_logs and xds_k8s_flags.DEBUG_USE_PORT_FORWARDING.value + ) # Setup. 
gcp_api_manager = gcp.api.GcpApiManager() @@ -83,25 +82,28 @@ def main(argv): reuse_namespace=_REUSE_NAMESPACE.value, reuse_service=_REUSE_SERVICE.value, secure=_SECURE.value, - port_forwarding=should_port_forward) + port_forwarding=should_port_forward, + ) - if _CMD.value == 'run': - logger.info('Run server, secure_mode=%s', _SECURE.value) + if _CMD.value == "run": + logger.info("Run server, secure_mode=%s", _SECURE.value) server_runner.run( test_port=xds_flags.SERVER_PORT.value, maintenance_port=xds_flags.SERVER_MAINTENANCE_PORT.value, secure_mode=_SECURE.value, - log_to_stdout=_FOLLOW.value) + log_to_stdout=_FOLLOW.value, + ) if should_follow_logs: - print('Following pod logs. Press Ctrl+C top stop') + print("Following pod logs. Press Ctrl+C top stop") signal.signal(signal.SIGINT, _make_sigint_handler(server_runner)) signal.pause() - elif _CMD.value == 'cleanup': - logger.info('Cleanup server') - server_runner.cleanup(force=True, - force_namespace=_CLEANUP_NAMESPACE.value) + elif _CMD.value == "cleanup": + logger.info("Cleanup server") + server_runner.cleanup( + force=True, force_namespace=_CLEANUP_NAMESPACE.value + ) -if __name__ == '__main__': +if __name__ == "__main__": app.run(main) diff --git a/tools/run_tests/xds_k8s_test_driver/framework/bootstrap_generator_testcase.py b/tools/run_tests/xds_k8s_test_driver/framework/bootstrap_generator_testcase.py index 4622428d910a6..96ff1864fb394 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/bootstrap_generator_testcase.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/bootstrap_generator_testcase.py @@ -42,8 +42,9 @@ def setUpClass(cls): """ super().setUpClass() if cls.server_maintenance_port is None: - cls.server_maintenance_port = \ + cls.server_maintenance_port = ( KubernetesServerRunner.DEFAULT_MAINTENANCE_PORT + ) # Bootstrap generator tests are run as parameterized tests which only # perform steps specific to the parameterized version of the bootstrap @@ -53,22 +54,28 @@ def setUpClass(cls): # side variants of the bootstrap generator test. if cls.resource_suffix_randomize: cls.resource_suffix = helpers_rand.random_resource_suffix() - logger.info('Test run resource prefix: %s, suffix: %s', - cls.resource_prefix, cls.resource_suffix) + logger.info( + "Test run resource prefix: %s, suffix: %s", + cls.resource_prefix, + cls.resource_suffix, + ) # TD Manager cls.td = cls.initTrafficDirectorManager() # Test namespaces for client and server. cls.server_namespace = KubernetesServerRunner.make_namespace_name( - cls.resource_prefix, cls.resource_suffix) + cls.resource_prefix, cls.resource_suffix + ) cls.client_namespace = KubernetesClientRunner.make_namespace_name( - cls.resource_prefix, cls.resource_suffix) + cls.resource_prefix, cls.resource_suffix + ) # Ensures the firewall exist if cls.ensure_firewall: cls.td.create_firewall_rule( - allowed_ports=cls.firewall_allowed_ports) + allowed_ports=cls.firewall_allowed_ports + ) # Randomize xds port, when it's set to 0 if cls.server_xds_port == 0: @@ -78,12 +85,14 @@ def setUpClass(cls): # forwarding rule. This check is better than nothing, # but we should find a better approach. cls.server_xds_port = cls.td.find_unused_forwarding_rule_port() - logger.info('Found unused xds port: %s', cls.server_xds_port) + logger.info("Found unused xds port: %s", cls.server_xds_port) # Common TD resources across client and server tests. 
- cls.td.setup_for_grpc(cls.server_xds_host, - cls.server_xds_port, - health_check_port=cls.server_maintenance_port) + cls.td.setup_for_grpc( + cls.server_xds_host, + cls.server_xds_port, + health_check_port=cls.server_maintenance_port, + ) @classmethod def tearDownClass(cls): @@ -98,13 +107,13 @@ def initTrafficDirectorManager(cls) -> TrafficDirectorManager: resource_prefix=cls.resource_prefix, resource_suffix=cls.resource_suffix, network=cls.network, - compute_api_version=cls.compute_api_version) + compute_api_version=cls.compute_api_version, + ) @classmethod def initKubernetesServerRunner( - cls, - *, - td_bootstrap_image: Optional[str] = None) -> KubernetesServerRunner: + cls, *, td_bootstrap_image: Optional[str] = None + ) -> KubernetesServerRunner: if not td_bootstrap_image: td_bootstrap_image = cls.td_bootstrap_image return KubernetesServerRunner( @@ -118,31 +127,37 @@ def initKubernetesServerRunner( xds_server_uri=cls.xds_server_uri, network=cls.network, debug_use_port_forwarding=cls.debug_use_port_forwarding, - enable_workload_identity=cls.enable_workload_identity) + enable_workload_identity=cls.enable_workload_identity, + ) @staticmethod - def startTestServer(server_runner, - port, - maintenance_port, - xds_host, - xds_port, - replica_count=1, - **kwargs) -> XdsTestServer: - test_server = server_runner.run(replica_count=replica_count, - test_port=port, - maintenance_port=maintenance_port, - **kwargs)[0] + def startTestServer( + server_runner, + port, + maintenance_port, + xds_host, + xds_port, + replica_count=1, + **kwargs, + ) -> XdsTestServer: + test_server = server_runner.run( + replica_count=replica_count, + test_port=port, + maintenance_port=maintenance_port, + **kwargs, + )[0] test_server.set_xds_address(xds_host, xds_port) return test_server def initKubernetesClientRunner( - self, - td_bootstrap_image: Optional[str] = None) -> KubernetesClientRunner: + self, td_bootstrap_image: Optional[str] = None + ) -> KubernetesClientRunner: if not td_bootstrap_image: td_bootstrap_image = self.td_bootstrap_image return KubernetesClientRunner( - k8s.KubernetesNamespace(self.k8s_api_manager, - self.client_namespace), + k8s.KubernetesNamespace( + self.k8s_api_manager, self.client_namespace + ), deployment_name=self.client_name, image_name=self.client_image, td_bootstrap_image=td_bootstrap_image, @@ -154,11 +169,14 @@ def initKubernetesClientRunner( debug_use_port_forwarding=self.debug_use_port_forwarding, enable_workload_identity=self.enable_workload_identity, stats_port=self.client_port, - reuse_namespace=self.server_namespace == self.client_namespace) - - def startTestClient(self, test_server: XdsTestServer, - **kwargs) -> XdsTestClient: - test_client = self.client_runner.run(server_target=test_server.xds_uri, - **kwargs) + reuse_namespace=self.server_namespace == self.client_namespace, + ) + + def startTestClient( + self, test_server: XdsTestServer, **kwargs + ) -> XdsTestClient: + test_client = self.client_runner.run( + server_target=test_server.xds_uri, **kwargs + ) test_client.wait_for_active_server_channel() return test_client diff --git a/tools/run_tests/xds_k8s_test_driver/framework/helpers/datetime.py b/tools/run_tests/xds_k8s_test_driver/framework/helpers/datetime.py index 2833af0a29122..08d06a1349f3f 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/helpers/datetime.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/helpers/datetime.py @@ -16,7 +16,7 @@ import re from typing import Pattern -RE_ZERO_OFFSET: Pattern[str] = re.compile(r'[+\-]00:?00$') 
+RE_ZERO_OFFSET: Pattern[str] = re.compile(r"[+\-]00:?00$") def utc_now() -> datetime.datetime: @@ -26,7 +26,7 @@ def utc_now() -> datetime.datetime: def shorten_utc_zone(utc_datetime_str: str) -> str: """Replace ±00:00 timezone designator with Z (zero offset AKA Zulu time).""" - return RE_ZERO_OFFSET.sub('Z', utc_datetime_str) + return RE_ZERO_OFFSET.sub("Z", utc_datetime_str) def iso8601_utc_time(time: datetime.datetime = None) -> str: @@ -47,4 +47,4 @@ def datetime_suffix(*, seconds: bool = False) -> str: Hours and minutes are joined together for better readability, so time is visually distinct from dash-separated date. """ - return utc_now().strftime('%Y%m%d-%H%M' + ('%S' if seconds else '')) + return utc_now().strftime("%Y%m%d-%H%M" + ("%S" if seconds else "")) diff --git a/tools/run_tests/xds_k8s_test_driver/framework/helpers/grpc.py b/tools/run_tests/xds_k8s_test_driver/framework/helpers/grpc.py index 10a99dc47c164..efa0d51c4a7f9 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/helpers/grpc.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/helpers/grpc.py @@ -67,8 +67,8 @@ def rpcs_completed(self): @staticmethod def from_response( - method_name: str, - method_stats: grpc_testing.MethodStats) -> "PrettyStatsPerMethod": + method_name: str, method_stats: grpc_testing.MethodStats + ) -> "PrettyStatsPerMethod": stats: Dict[str, int] = dict() for status_int, count in method_stats.result.items(): status: Optional[grpc.StatusCode] = status_from_int(status_int) @@ -103,7 +103,8 @@ def accumulated_stats_pretty( result: List[Dict] = [] for method_name, method_stats in accumulated_stats.stats_per_method.items(): pretty_stats = PrettyStatsPerMethod.from_response( - method_name, method_stats) + method_name, method_stats + ) # Skip methods with no RPCs reported when ignore_empty is True. if ignore_empty and not pretty_stats.rpcs_started: continue @@ -132,7 +133,8 @@ class PrettyLoadBalancerStats: @staticmethod def _parse_rpcs_by_peer( - rpcs_by_peer: grpc_testing.RpcsByPeer) -> "RpcsByPeer": + rpcs_by_peer: grpc_testing.RpcsByPeer, + ) -> "RpcsByPeer": result = dict() for peer, count in rpcs_by_peer.items(): result[peer] = count @@ -146,7 +148,8 @@ def from_response( for method_name, stats in lb_stats.rpcs_by_method.items(): if stats: rpcs_by_method[method_name] = cls._parse_rpcs_by_peer( - stats.rpcs_by_peer) + stats.rpcs_by_peer + ) return PrettyLoadBalancerStats( num_failures=lb_stats.num_failures, rpcs_by_peer=cls._parse_rpcs_by_peer(lb_stats.rpcs_by_peer), diff --git a/tools/run_tests/xds_k8s_test_driver/framework/helpers/highlighter.py b/tools/run_tests/xds_k8s_test_driver/framework/helpers/highlighter.py index 0c9f0112614ff..ef6b74c72f762 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/helpers/highlighter.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/helpers/highlighter.py @@ -32,18 +32,21 @@ import pygments.styles # The style for terminals supporting 8/16 colors. -STYLE_ANSI_16 = 'ansi16' +STYLE_ANSI_16 = "ansi16" # Join with pygments styles for terminals supporting 88/256 colors. ALL_COLOR_STYLES = [STYLE_ANSI_16] + list(pygments.styles.get_all_styles()) # Flags. -COLOR = flags.DEFINE_bool("color", default=True, help='Colorize the output') +COLOR = flags.DEFINE_bool("color", default=True, help="Colorize the output") COLOR_STYLE = flags.DEFINE_enum( "color_style", - default='material', + default="material", enum_values=ALL_COLOR_STYLES, - help=('Color styles for terminals supporting 256 colors. 
' - f'Use {STYLE_ANSI_16} style for terminals supporting 8/16 colors')) + help=( + "Color styles for terminals supporting 256 colors. " + f"Use {STYLE_ANSI_16} style for terminals supporting 8/16 colors" + ), +) logger = logging.getLogger(__name__) @@ -62,19 +65,23 @@ class Highlighter: color: bool color_style: Optional[str] = None - def __init__(self, - *, - lexer: Lexer, - color: Optional[bool] = None, - color_style: Optional[str] = None): + def __init__( + self, + *, + lexer: Lexer, + color: Optional[bool] = None, + color_style: Optional[str] = None, + ): self.lexer = lexer self.color = color if color is not None else COLOR.value if self.color: color_style = color_style if color_style else COLOR_STYLE.value if color_style not in ALL_COLOR_STYLES: - raise ValueError(f'Unrecognized color style {color_style}, ' - f'valid styles: {ALL_COLOR_STYLES}') + raise ValueError( + f"Unrecognized color style {color_style}, " + f"valid styles: {ALL_COLOR_STYLES}" + ) if color_style == STYLE_ANSI_16: # 8/16 colors support only. self.formatter = TerminalFormatter() @@ -89,11 +96,11 @@ def highlight(self, code: str) -> str: class HighlighterYaml(Highlighter): - - def __init__(self, - *, - color: Optional[bool] = None, - color_style: Optional[str] = None): - super().__init__(lexer=YamlLexer(encoding='utf-8'), - color=color, - color_style=color_style) + def __init__( + self, *, color: Optional[bool] = None, color_style: Optional[str] = None + ): + super().__init__( + lexer=YamlLexer(encoding="utf-8"), + color=color, + color_style=color_style, + ) diff --git a/tools/run_tests/xds_k8s_test_driver/framework/helpers/logs.py b/tools/run_tests/xds_k8s_test_driver/framework/helpers/logs.py index 4c20d415c805c..4678737e2c209 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/helpers/logs.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/helpers/logs.py @@ -21,7 +21,7 @@ def _ensure_flags_parsed() -> None: if not flags.FLAGS.is_parsed(): - raise flags.UnparsedFlagAccessError('Must initialize absl flags first.') + raise flags.UnparsedFlagAccessError("Must initialize absl flags first.") @functools.lru_cache(None) @@ -35,9 +35,9 @@ def log_get_root_dir() -> pathlib.Path: def log_dir_mkdir(name: str) -> pathlib.Path: """Creates and returns a subdir with the given name in the log folder.""" if len(pathlib.Path(name).parts) != 1: - raise ValueError(f'Dir name must be a single component; got: {name}') + raise ValueError(f"Dir name must be a single component; got: {name}") if ".." in name: - raise ValueError(f'Dir name must not be above the log root.') + raise ValueError(f"Dir name must not be above the log root.") log_subdir = log_get_root_dir() / name if log_subdir.exists() and log_subdir.is_dir(): logging.debug("Using existing log subdir: %s", log_subdir) diff --git a/tools/run_tests/xds_k8s_test_driver/framework/helpers/rand.py b/tools/run_tests/xds_k8s_test_driver/framework/helpers/rand.py index dedf2143603bc..6ed7995468859 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/helpers/rand.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/helpers/rand.py @@ -27,12 +27,12 @@ def rand_string(length: int = 8, *, lowercase: bool = False) -> str: """Return random alphanumeric string of given length. - Space for default arguments: alphabet^length - lowercase and uppercase = (26*2 + 10)^8 = 2.18e14 = 218 trillion. - lowercase only = (26 + 10)^8 = 2.8e12 = 2.8 trillion. - """ + Space for default arguments: alphabet^length + lowercase and uppercase = (26*2 + 10)^8 = 2.18e14 = 218 trillion. 
+ lowercase only = (26 + 10)^8 = 2.8e12 = 2.8 trillion. + """ alphabet = ALPHANUM_LOWERCASE if lowercase else ALPHANUM - return ''.join(random.choices(population=alphabet, k=length)) + return "".join(random.choices(population=alphabet, k=length)) def random_resource_suffix() -> str: @@ -46,4 +46,4 @@ def random_resource_suffix() -> str: # produce a collision: math.sqrt(math.pi/2 * (26+10)**5) ≈ 9745. # https://en.wikipedia.org/wiki/Birthday_attack#Mathematics unique_hash: str = rand_string(5, lowercase=True) - return f'{datetime_suffix}-{unique_hash}' + return f"{datetime_suffix}-{unique_hash}" diff --git a/tools/run_tests/xds_k8s_test_driver/framework/helpers/retryers.py b/tools/run_tests/xds_k8s_test_driver/framework/helpers/retryers.py index e29f0371a0e30..15990b79e1e09 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/helpers/retryers.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/helpers/retryers.py @@ -40,10 +40,10 @@ def _build_retry_conditions( - *, - retry_on_exceptions: Optional[_ExceptionClasses] = None, - check_result: Optional[CheckResultFn] = None) -> List[retry_base]: - + *, + retry_on_exceptions: Optional[_ExceptionClasses] = None, + check_result: Optional[CheckResultFn] = None, +) -> List[retry_base]: # Retry on all exceptions by default if retry_on_exceptions is None: retry_on_exceptions = (Exception,) @@ -59,45 +59,53 @@ def _build_retry_conditions( def exponential_retryer_with_timeout( - *, - wait_min: timedelta, - wait_max: timedelta, - timeout: timedelta, - retry_on_exceptions: Optional[_ExceptionClasses] = None, - check_result: Optional[CheckResultFn] = None, - logger: Optional[logging.Logger] = None, - log_level: Optional[int] = logging.DEBUG) -> Retrying: + *, + wait_min: timedelta, + wait_max: timedelta, + timeout: timedelta, + retry_on_exceptions: Optional[_ExceptionClasses] = None, + check_result: Optional[CheckResultFn] = None, + logger: Optional[logging.Logger] = None, + log_level: Optional[int] = logging.DEBUG, +) -> Retrying: if logger is None: logger = retryers_logger if log_level is None: log_level = logging.DEBUG retry_conditions = _build_retry_conditions( - retry_on_exceptions=retry_on_exceptions, check_result=check_result) - retry_error_callback = _on_error_callback(timeout=timeout, - check_result=check_result) - return Retrying(retry=tenacity.retry_any(*retry_conditions), - wait=wait.wait_exponential(min=wait_min.total_seconds(), - max=wait_max.total_seconds()), - stop=stop.stop_after_delay(timeout.total_seconds()), - before_sleep=_before_sleep_log(logger, log_level), - retry_error_callback=retry_error_callback) - - -def constant_retryer(*, - wait_fixed: timedelta, - attempts: int = 0, - timeout: Optional[timedelta] = None, - retry_on_exceptions: Optional[_ExceptionClasses] = None, - check_result: Optional[CheckResultFn] = None, - logger: Optional[logging.Logger] = None, - log_level: Optional[int] = logging.DEBUG) -> Retrying: + retry_on_exceptions=retry_on_exceptions, check_result=check_result + ) + retry_error_callback = _on_error_callback( + timeout=timeout, check_result=check_result + ) + return Retrying( + retry=tenacity.retry_any(*retry_conditions), + wait=wait.wait_exponential( + min=wait_min.total_seconds(), max=wait_max.total_seconds() + ), + stop=stop.stop_after_delay(timeout.total_seconds()), + before_sleep=_before_sleep_log(logger, log_level), + retry_error_callback=retry_error_callback, + ) + + +def constant_retryer( + *, + wait_fixed: timedelta, + attempts: int = 0, + timeout: Optional[timedelta] = None, + 
retry_on_exceptions: Optional[_ExceptionClasses] = None, + check_result: Optional[CheckResultFn] = None, + logger: Optional[logging.Logger] = None, + log_level: Optional[int] = logging.DEBUG, +) -> Retrying: if logger is None: logger = retryers_logger if log_level is None: log_level = logging.DEBUG if attempts < 1 and timeout is None: - raise ValueError('The number of attempts or the timeout must be set') + raise ValueError("The number of attempts or the timeout must be set") stops = [] if attempts > 0: stops.append(stop.stop_after_attempt(attempts)) @@ -105,36 +113,44 @@ def constant_retryer(*, stops.append(stop.stop_after_delay(timeout.total_seconds())) retry_conditions = _build_retry_conditions( - retry_on_exceptions=retry_on_exceptions, check_result=check_result) - retry_error_callback = _on_error_callback(timeout=timeout, - attempts=attempts, - check_result=check_result) - return Retrying(retry=tenacity.retry_any(*retry_conditions), - wait=wait.wait_fixed(wait_fixed.total_seconds()), - stop=stop.stop_any(*stops), - before_sleep=_before_sleep_log(logger, log_level), - retry_error_callback=retry_error_callback) - - -def _on_error_callback(*, - timeout: Optional[timedelta] = None, - attempts: int = 0, - check_result: Optional[CheckResultFn] = None): + retry_on_exceptions=retry_on_exceptions, check_result=check_result + ) + retry_error_callback = _on_error_callback( + timeout=timeout, attempts=attempts, check_result=check_result + ) + return Retrying( + retry=tenacity.retry_any(*retry_conditions), + wait=wait.wait_fixed(wait_fixed.total_seconds()), + stop=stop.stop_any(*stops), + before_sleep=_before_sleep_log(logger, log_level), + retry_error_callback=retry_error_callback, + ) + + +def _on_error_callback( + *, + timeout: Optional[timedelta] = None, + attempts: int = 0, + check_result: Optional[CheckResultFn] = None, +): """A helper to propagate the initial state to the RetryError, so that it can assemble a helpful message containing timeout/number of attempts. """ def error_handler(retry_state: tenacity.RetryCallState): - raise RetryError(retry_state, - timeout=timeout, - attempts=attempts, - check_result=check_result) + raise RetryError( + retry_state, + timeout=timeout, + attempts=attempts, + check_result=check_result, + ) return error_handler -def _safe_check_result(check_result: CheckResultFn, - retry_on_exceptions: _ExceptionClasses) -> CheckResultFn: +def _safe_check_result( + check_result: CheckResultFn, retry_on_exceptions: _ExceptionClasses +) -> CheckResultFn: """Wraps check_result callback to catch and handle retry_on_exceptions. Normally tenacity doesn't retry when retry_if_result/retry_if_not_result @@ -150,11 +166,14 @@ def _check_result_wrapped(result): return check_result(result) except retry_on_exceptions: retryers_logger.warning( - "Result check callback %s raised an exception." - "This shouldn't happen, please handle any exceptions and " - "return return a boolean.", + ( + "Result check callback %s raised an exception." + "This shouldn't happen, please handle any exceptions and " + "return a boolean."
+ ), tenacity_utils.get_callback_name(check_result), - exc_info=True) + exc_info=True, + ) return False return _check_result_wrapped @@ -168,60 +187,65 @@ def _before_sleep_log(logger, log_level, exc_info=False): def log_it(retry_state): if retry_state.outcome.failed: ex = retry_state.outcome.exception() - verb, value = 'raised', '%s: %s' % (type(ex).__name__, ex) + verb, value = "raised", "%s: %s" % (type(ex).__name__, ex) if exc_info: local_exc_info = tenacity_compat.get_exc_info_from_future( - retry_state.outcome) + retry_state.outcome + ) else: local_exc_info = False else: local_exc_info = False # exc_info does not apply when no exception result = retry_state.outcome.result() if isinstance(result, (int, bool, str)): - verb, value = 'returned', result + verb, value = "returned", result else: - verb, value = 'returned type', type(result) - - logger.log(log_level, - "Retrying %s in %s seconds as it %s %s.", - tenacity_utils.get_callback_name(retry_state.fn), - getattr(retry_state.next_action, 'sleep'), - verb, - value, - exc_info=local_exc_info) + verb, value = "returned type", type(result) + + logger.log( + log_level, + "Retrying %s in %s seconds as it %s %s.", + tenacity_utils.get_callback_name(retry_state.fn), + getattr(retry_state.next_action, "sleep"), + verb, + value, + exc_info=local_exc_info, + ) return log_it class RetryError(tenacity.RetryError): - - def __init__(self, - retry_state, - *, - timeout: Optional[timedelta] = None, - attempts: int = 0, - check_result: Optional[CheckResultFn] = None): + def __init__( + self, + retry_state, + *, + timeout: Optional[timedelta] = None, + attempts: int = 0, + check_result: Optional[CheckResultFn] = None, + ): super().__init__(retry_state.outcome) callback_name = tenacity_utils.get_callback_name(retry_state.fn) - self.message = f'Retry error calling {callback_name}:' + self.message = f"Retry error calling {callback_name}:" if timeout: - self.message += f' timeout {timeout} (h:mm:ss) exceeded' + self.message += f" timeout {timeout} (h:mm:ss) exceeded" if attempts: - self.message += ' or' + self.message += " or" if attempts: - self.message += f' {attempts} attempts exhausted' + self.message += f" {attempts} attempts exhausted" - self.message += '.' + self.message += "." if retry_state.outcome.failed: ex = retry_state.outcome.exception() - self.message += f' Last exception: {type(ex).__name__}: {ex}' + self.message += f" Last exception: {type(ex).__name__}: {ex}" elif check_result: - self.message += ' Check result callback returned False.' + self.message += " Check result callback returned False." def result(self, *, default=None): - return default if self.last_attempt.failed else self.last_attempt.result( + return ( + default if self.last_attempt.failed else self.last_attempt.result() ) def __str__(self): diff --git a/tools/run_tests/xds_k8s_test_driver/framework/helpers/skips.py b/tools/run_tests/xds_k8s_test_driver/framework/helpers/skips.py index b88e4814e4c5d..9df9f73289e4c 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/helpers/skips.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/helpers/skips.py @@ -52,6 +52,7 @@ class TestConfig: TODO(sergiitk): rename to LangSpec and rename skips.py to lang.py. """ + client_lang: Lang server_lang: Lang version: Optional[str] @@ -74,29 +75,32 @@ def version_gte(self, another: str) -> bool: 3) Unspecified version (self.version is None) is treated as "master". 
""" - if self.version in ('master', 'dev', 'dev-master', None): + if self.version in ("master", "dev", "dev-master", None): return True - if another == 'master': + if another == "master": return False return self._parse_version(self.version) >= self._parse_version(another) def __str__(self): - return (f"TestConfig(client_lang='{self.client_lang}', " - f"server_lang='{self.server_lang}', version={self.version!r})") + return ( + f"TestConfig(client_lang='{self.client_lang}', " + f"server_lang='{self.server_lang}', version={self.version!r})" + ) @staticmethod def _parse_version(version: str) -> pkg_version.Version: - if version.startswith('dev-'): + if version.startswith("dev-"): # Treat "dev-VERSION" as "VERSION". version = version[4:] - if version.endswith('.x'): + if version.endswith(".x"): version = version[:-2] return pkg_version.Version(version) def _get_lang(image_name: str) -> Lang: return Lang.from_string( - re.search(r'/(\w+)-(client|server):', image_name).group(1)) + re.search(r"/(\w+)-(client|server):", image_name).group(1) + ) def evaluate_test_config(check: Callable[[TestConfig], bool]) -> TestConfig: @@ -111,10 +115,11 @@ def evaluate_test_config(check: Callable[[TestConfig], bool]) -> TestConfig: test_config = TestConfig( client_lang=_get_lang(xds_k8s_flags.CLIENT_IMAGE.value), server_lang=_get_lang(xds_k8s_flags.SERVER_IMAGE.value), - version=xds_flags.TESTING_VERSION.value) + version=xds_flags.TESTING_VERSION.value, + ) if not check(test_config): - logger.info('Skipping %s', test_config) - raise unittest.SkipTest(f'Unsupported test config: {test_config}') + logger.info("Skipping %s", test_config) + raise unittest.SkipTest(f"Unsupported test config: {test_config}") - logger.info('Detected language and version: %s', test_config) + logger.info("Detected language and version: %s", test_config) return test_config diff --git a/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/gcp/api.py b/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/gcp/api.py index 77fdf07ac13ed..a9909a30dce5f 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/gcp/api.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/gcp/api.py @@ -37,21 +37,31 @@ PRIVATE_API_KEY_SECRET_NAME = flags.DEFINE_string( "private_api_key_secret_name", default=None, - help="Load Private API access key from the latest version of the secret " - "with the given name, in the format projects/*/secrets/*") -V1_DISCOVERY_URI = flags.DEFINE_string("v1_discovery_uri", - default=discovery.V1_DISCOVERY_URI, - help="Override v1 Discovery URI") -V2_DISCOVERY_URI = flags.DEFINE_string("v2_discovery_uri", - default=discovery.V2_DISCOVERY_URI, - help="Override v2 Discovery URI") + help=( + "Load Private API access key from the latest version of the secret " + "with the given name, in the format projects/*/secrets/*" + ), +) +V1_DISCOVERY_URI = flags.DEFINE_string( + "v1_discovery_uri", + default=discovery.V1_DISCOVERY_URI, + help="Override v1 Discovery URI", +) +V2_DISCOVERY_URI = flags.DEFINE_string( + "v2_discovery_uri", + default=discovery.V2_DISCOVERY_URI, + help="Override v2 Discovery URI", +) COMPUTE_V1_DISCOVERY_FILE = flags.DEFINE_string( "compute_v1_discovery_file", default=None, - help="Load compute v1 from discovery file") -GCP_UI_URL = flags.DEFINE_string("gcp_ui_url", - default="console.cloud.google.com", - help="Override GCP UI URL.") + help="Load compute v1 from discovery file", +) +GCP_UI_URL = flags.DEFINE_string( + "gcp_ui_url", + default="console.cloud.google.com", 
+ help="Override GCP UI URL.", +) # Type aliases _HttpError = googleapiclient.errors.HttpError @@ -62,20 +72,23 @@ class GcpApiManager: - - def __init__(self, - *, - v1_discovery_uri=None, - v2_discovery_uri=None, - compute_v1_discovery_file=None, - private_api_key_secret_name=None, - gcp_ui_url=None): + def __init__( + self, + *, + v1_discovery_uri=None, + v2_discovery_uri=None, + compute_v1_discovery_file=None, + private_api_key_secret_name=None, + gcp_ui_url=None, + ): self.v1_discovery_uri = v1_discovery_uri or V1_DISCOVERY_URI.value self.v2_discovery_uri = v2_discovery_uri or V2_DISCOVERY_URI.value - self.compute_v1_discovery_file = (compute_v1_discovery_file or - COMPUTE_V1_DISCOVERY_FILE.value) - self.private_api_key_secret_name = (private_api_key_secret_name or - PRIVATE_API_KEY_SECRET_NAME.value) + self.compute_v1_discovery_file = ( + compute_v1_discovery_file or COMPUTE_V1_DISCOVERY_FILE.value + ) + self.private_api_key_secret_name = ( + private_api_key_secret_name or PRIVATE_API_KEY_SECRET_NAME.value + ) self.gcp_ui_url = gcp_ui_url or GCP_UI_URL.value # TODO(sergiitk): add options to pass google Credentials self._exit_stack = contextlib.ExitStack() @@ -97,65 +110,70 @@ def private_api_key(self): https://console.cloud.google.com/security/secret-manager """ if not self.private_api_key_secret_name: - raise ValueError('private_api_key_secret_name must be set to ' - 'access private_api_key.') + raise ValueError( + "private_api_key_secret_name must be set to " + "access private_api_key." + ) - secrets_api = self.secrets('v1') + secrets_api = self.secrets("v1") version_resource_path = secrets_api.secret_version_path( **secrets_api.parse_secret_path(self.private_api_key_secret_name), - secret_version='latest') + secret_version="latest", + ) secret: secretmanager_v1.AccessSecretVersionResponse secret = secrets_api.access_secret_version(name=version_resource_path) return secret.payload.data.decode() @functools.lru_cache(None) def compute(self, version): - api_name = 'compute' - if version == 'v1': + api_name = "compute" + if version == "v1": if self.compute_v1_discovery_file: return self._build_from_file(self.compute_v1_discovery_file) else: return self._build_from_discovery_v1(api_name, version) - elif version == 'v1alpha': - return self._build_from_discovery_v1(api_name, 'alpha') + elif version == "v1alpha": + return self._build_from_discovery_v1(api_name, "alpha") - raise NotImplementedError(f'Compute {version} not supported') + raise NotImplementedError(f"Compute {version} not supported") @functools.lru_cache(None) def networksecurity(self, version): - api_name = 'networksecurity' - if version == 'v1alpha1': + api_name = "networksecurity" + if version == "v1alpha1": return self._build_from_discovery_v2( api_name, version, api_key=self.private_api_key, - visibility_labels=['NETWORKSECURITY_ALPHA']) - elif version == 'v1beta1': + visibility_labels=["NETWORKSECURITY_ALPHA"], + ) + elif version == "v1beta1": return self._build_from_discovery_v2(api_name, version) - raise NotImplementedError(f'Network Security {version} not supported') + raise NotImplementedError(f"Network Security {version} not supported") @functools.lru_cache(None) def networkservices(self, version): - api_name = 'networkservices' - if version == 'v1alpha1': + api_name = "networkservices" + if version == "v1alpha1": return self._build_from_discovery_v2( api_name, version, api_key=self.private_api_key, - visibility_labels=['NETWORKSERVICES_ALPHA']) - elif version == 'v1beta1': + 
visibility_labels=["NETWORKSERVICES_ALPHA"], + ) + elif version == "v1beta1": return self._build_from_discovery_v2(api_name, version) - raise NotImplementedError(f'Network Services {version} not supported') + raise NotImplementedError(f"Network Services {version} not supported") @staticmethod @functools.lru_cache(None) def secrets(version: str): - if version == 'v1': + if version == "v1": return secretmanager_v1.SecretManagerServiceClient() - raise NotImplementedError(f'Secret Manager {version} not supported') + raise NotImplementedError(f"Secret Manager {version} not supported") @functools.lru_cache(None) def iam(self, version: str) -> discovery.Resource: @@ -164,48 +182,54 @@ def iam(self, version: str) -> discovery.Resource: https://cloud.google.com/iam/docs/reference/rest https://googleapis.github.io/google-api-python-client/docs/dyn/iam_v1.html """ - api_name = 'iam' - if version == 'v1': + api_name = "iam" + if version == "v1": return self._build_from_discovery_v1(api_name, version) raise NotImplementedError( - f'Identity and Access Management (IAM) {version} not supported') + f"Identity and Access Management (IAM) {version} not supported" + ) def _build_from_discovery_v1(self, api_name, version): - api = discovery.build(api_name, - version, - cache_discovery=False, - discoveryServiceUrl=self.v1_discovery_uri) + api = discovery.build( + api_name, + version, + cache_discovery=False, + discoveryServiceUrl=self.v1_discovery_uri, + ) self._exit_stack.enter_context(api) return api - def _build_from_discovery_v2(self, - api_name, - version, - *, - api_key: Optional[str] = None, - visibility_labels: Optional[List] = None): + def _build_from_discovery_v2( + self, + api_name, + version, + *, + api_key: Optional[str] = None, + visibility_labels: Optional[List] = None, + ): params = {} if api_key: - params['key'] = api_key + params["key"] = api_key if visibility_labels: # Dash-separated list of labels. 
- params['labels'] = '_'.join(visibility_labels) + params["labels"] = "_".join(visibility_labels) - params_str = '' + params_str = "" if params: - params_str = '&' + ('&'.join(f'{k}={v}' for k, v in params.items())) + params_str = "&" + "&".join(f"{k}={v}" for k, v in params.items()) api = discovery.build( api_name, version, cache_discovery=False, - discoveryServiceUrl=f'{self.v2_discovery_uri}{params_str}') + discoveryServiceUrl=f"{self.v2_discovery_uri}{params_str}", + ) self._exit_stack.enter_context(api) return api def _build_from_file(self, discovery_file): - with open(discovery_file, 'r') as f: + with open(discovery_file, "r") as f: api = discovery.build_from_document(f.read()) self._exit_stack.enter_context(api) return api @@ -217,6 +241,7 @@ class Error(Exception): class ResponseError(Error): """The response was not a 2xx.""" + reason: str uri: str error_details: Optional[str] @@ -238,12 +263,15 @@ def __init__(self, cause: _HttpError): super().__init__() def __repr__(self): - return (f'') + return ( + f"' + ) class TransportError(Error): """A transport error has occurred.""" + cause: _HttpLib2Error def __init__(self, cause: _HttpLib2Error): @@ -251,7 +279,7 @@ def __init__(self, cause: _HttpLib2Error): super().__init__() def __repr__(self): - return f'' + return f"" class OperationError(Error): @@ -262,6 +290,7 @@ class OperationError(Error): https://cloud.google.com/apis/design/design_patterns#long_running_operations https://github.com/googleapis/googleapis/blob/master/google/longrunning/operations.proto """ + api_name: str name: str metadata: Any @@ -274,11 +303,11 @@ def __init__(self, api_name: str, response: dict): # Operation.metadata field is Any specific to the API. It may not be # present in the default descriptor pool, and that's expected. # To avoid json_format.ParseError, handle it separately. - self.metadata = response.pop('metadata', {}) + self.metadata = response.pop("metadata", {}) # Must be after removing metadata field. operation: Operation = self._parse_operation_response(response) - self.name = operation.name or 'unknown' + self.name = operation.name or "unknown" self.code_name = code_pb2.Code.Name(operation.error.code) self.error = operation.error super().__init__() @@ -290,37 +319,45 @@ def _parse_operation_response(operation_response: dict) -> Operation: operation_response, Operation(), ignore_unknown_fields=True, - descriptor_pool=error_details_pb2.DESCRIPTOR.pool) + descriptor_pool=error_details_pb2.DESCRIPTOR.pool, + ) except (json_format.Error, TypeError) as e: # Swallow parsing errors if any. Building correct OperationError() # is more important than losing debug information. Details still # can be extracted from the warning. 
logger.warning( - ("Can't parse response while processing OperationError: '%r', " - "error %r"), operation_response, e) + ( + "Can't parse response while processing OperationError:" + " '%r', error %r" + ), + operation_response, + e, + ) return Operation() def __str__(self): - indent_l1 = ' ' * 2 + indent_l1 = " " * 2 indent_l2 = indent_l1 * 2 - result = (f'{self.api_name} operation "{self.name}" failed.\n' - f'{indent_l1}code: {self.error.code} ({self.code_name})\n' - f'{indent_l1}message: "{self.error.message}"') + result = ( + f'{self.api_name} operation "{self.name}" failed.\n' + f"{indent_l1}code: {self.error.code} ({self.code_name})\n" + f'{indent_l1}message: "{self.error.message}"' + ) if self.error.details: - result += f'\n{indent_l1}details: [\n' + result += f"\n{indent_l1}details: [\n" for any_error in self.error.details: error_str = json_format.MessageToJson(any_error) for line in error_str.splitlines(): - result += indent_l2 + line + '\n' - result += f'{indent_l1}]' + result += indent_l2 + line + "\n" + result += f"{indent_l1}]" if self.metadata: - result += f'\n metadata: \n' + result += f"\n metadata: \n" metadata_str = json.dumps(self.metadata, indent=2) for line in metadata_str.splitlines(): - result += indent_l2 + line + '\n' + result += indent_l2 + line + "\n" result = result.rstrip() return result @@ -340,10 +377,11 @@ def __init__(self, api: discovery.Resource, project: str): # TODO(sergiitk): in upcoming GCP refactoring, differentiate between # _execute for LRO (Long Running Operations), and immediate operations. def _execute( - self, - request: HttpRequest, - *, - num_retries: Optional[int] = _GCP_API_RETRIES) -> Dict[str, Any]: + self, + request: HttpRequest, + *, + num_retries: Optional[int] = _GCP_API_RETRIES, + ) -> Dict[str, Any]: """Execute the immediate request. 
Returns: @@ -368,38 +406,47 @@ def resource_pretty_format(self, body: dict) -> str: return self._highlighter.highlight(yaml_out) @staticmethod - def wait_for_operation(operation_request, - test_success_fn, - timeout_sec=_WAIT_FOR_OPERATION_SEC, - wait_sec=_WAIT_FIXED_SEC): + def wait_for_operation( + operation_request, + test_success_fn, + timeout_sec=_WAIT_FOR_OPERATION_SEC, + wait_sec=_WAIT_FIXED_SEC, + ): retryer = tenacity.Retrying( - retry=(tenacity.retry_if_not_result(test_success_fn) | - tenacity.retry_if_exception_type()), + retry=( + tenacity.retry_if_not_result(test_success_fn) + | tenacity.retry_if_exception_type() + ), wait=tenacity.wait_fixed(wait_sec), stop=tenacity.stop_after_delay(timeout_sec), after=tenacity.after_log(logger, logging.DEBUG), - reraise=True) + reraise=True, + ) return retryer(operation_request.execute) class GcpStandardCloudApiResource(GcpProjectApiResource, metaclass=abc.ABCMeta): - GLOBAL_LOCATION = 'global' + GLOBAL_LOCATION = "global" def parent(self, location: Optional[str] = GLOBAL_LOCATION): if location is None: location = self.GLOBAL_LOCATION - return f'projects/{self.project}/locations/{location}' + return f"projects/{self.project}/locations/{location}" def resource_full_name(self, name, collection_name): - return f'{self.parent()}/{collection_name}/{name}' - - def _create_resource(self, collection: discovery.Resource, body: dict, - **kwargs): - logger.info("Creating %s resource:\n%s", self.api_name, - self.resource_pretty_format(body)) - create_req = collection.create(parent=self.parent(), - body=body, - **kwargs) + return f"{self.parent()}/{collection_name}/{name}" + + def _create_resource( + self, collection: discovery.Resource, body: dict, **kwargs + ): + logger.info( + "Creating %s resource:\n%s", + self.api_name, + self.resource_pretty_format(body), + ) + create_req = collection.create( + parent=self.parent(), body=body, **kwargs + ) self._execute(create_req) @property @@ -414,45 +461,56 @@ def api_version(self) -> str: def _get_resource(self, collection: discovery.Resource, full_name): resource = collection.get(name=full_name).execute() - logger.info('Loaded %s:\n%s', full_name, - self.resource_pretty_format(resource)) + logger.info( + "Loaded %s:\n%s", full_name, self.resource_pretty_format(resource) + ) return resource - def _delete_resource(self, collection: discovery.Resource, - full_name: str) -> bool: + def _delete_resource( + self, collection: discovery.Resource, full_name: str + ) -> bool: logger.debug("Deleting %s", full_name) try: self._execute(collection.delete(name=full_name)) return True except _HttpError as error: if error.resp and error.resp.status == 404: - logger.info('%s not deleted since it does not exist', full_name) + logger.info("%s not deleted since it does not exist", full_name) else: - logger.warning('Failed to delete %s, %r', full_name, error) + logger.warning("Failed to delete %s, %r", full_name, error) return False # TODO(sergiitk): Use ResponseError and TransportError def _execute( # pylint: disable=arguments-differ - self, - request: HttpRequest, - timeout_sec: int = GcpProjectApiResource._WAIT_FOR_OPERATION_SEC): + self, + request: HttpRequest, + timeout_sec: int = GcpProjectApiResource._WAIT_FOR_OPERATION_SEC, + ): operation = request.execute(num_retries=self._GCP_API_RETRIES) - logger.debug('Operation %s', operation) - self._wait(operation['name'], timeout_sec) - - def _wait(self, - operation_id: str, - timeout_sec: int = GcpProjectApiResource._WAIT_FOR_OPERATION_SEC): - logger.info('Waiting %s sec for %s 
operation id: %s', timeout_sec, - self.api_name, operation_id) - - op_request = self.api.projects().locations().operations().get( - name=operation_id) + logger.debug("Operation %s", operation) + self._wait(operation["name"], timeout_sec) + + def _wait( + self, + operation_id: str, + timeout_sec: int = GcpProjectApiResource._WAIT_FOR_OPERATION_SEC, + ): + logger.info( + "Waiting %s sec for %s operation id: %s", + timeout_sec, + self.api_name, + operation_id, + ) + + op_request = ( + self.api.projects().locations().operations().get(name=operation_id) + ) operation = self.wait_for_operation( operation_request=op_request, - test_success_fn=lambda result: result['done'], - timeout_sec=timeout_sec) + test_success_fn=lambda result: result["done"], + timeout_sec=timeout_sec, + ) - logger.debug('Completed operation: %s', operation) - if 'error' in operation: + logger.debug("Completed operation: %s", operation) + if "error" in operation: raise OperationError(self.api_name, operation) diff --git a/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/gcp/compute.py b/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/gcp/compute.py index 472f2644d23d5..5492cb5198605 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/gcp/compute.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/gcp/compute.py @@ -26,7 +26,9 @@ logger = logging.getLogger(__name__) -class ComputeV1(gcp.api.GcpProjectApiResource): # pylint: disable=too-many-public-methods +class ComputeV1( + gcp.api.GcpProjectApiResource +): # pylint: disable=too-many-public-methods # TODO(sergiitk): move someplace better _WAIT_FOR_BACKEND_SEC = 60 * 10 _WAIT_FOR_BACKEND_SLEEP_SEC = 4 @@ -41,10 +43,12 @@ class GcpResource: class ZonalGcpResource(GcpResource): zone: str - def __init__(self, - api_manager: gcp.api.GcpApiManager, - project: str, - version: str = 'v1'): + def __init__( + self, + api_manager: gcp.api.GcpApiManager, + project: str, + version: str = "v1", + ): super().__init__(api_manager.compute(version), project) class HealthCheckProtocol(enum.Enum): @@ -55,361 +59,445 @@ class BackendServiceProtocol(enum.Enum): HTTP2 = enum.auto() GRPC = enum.auto() - def create_health_check(self, - name: str, - protocol: HealthCheckProtocol, - *, - port: Optional[int] = None) -> 'GcpResource': + def create_health_check( + self, + name: str, + protocol: HealthCheckProtocol, + *, + port: Optional[int] = None, + ) -> "GcpResource": if protocol is self.HealthCheckProtocol.TCP: - health_check_field = 'tcpHealthCheck' + health_check_field = "tcpHealthCheck" elif protocol is self.HealthCheckProtocol.GRPC: - health_check_field = 'grpcHealthCheck' + health_check_field = "grpcHealthCheck" else: - raise TypeError(f'Unexpected Health Check protocol: {protocol}') + raise TypeError(f"Unexpected Health Check protocol: {protocol}") health_check_settings = {} if port is None: - health_check_settings['portSpecification'] = 'USE_SERVING_PORT' + health_check_settings["portSpecification"] = "USE_SERVING_PORT" else: - health_check_settings['portSpecification'] = 'USE_FIXED_PORT' - health_check_settings['port'] = port + health_check_settings["portSpecification"] = "USE_FIXED_PORT" + health_check_settings["port"] = port return self._insert_resource( - self.api.healthChecks(), { - 'name': name, - 'type': protocol.name, + self.api.healthChecks(), + { + "name": name, + "type": protocol.name, health_check_field: health_check_settings, - }) + }, + ) def list_health_check(self): return 
self._list_resource(self.api.healthChecks()) def delete_health_check(self, name): - self._delete_resource(self.api.healthChecks(), 'healthCheck', name) + self._delete_resource(self.api.healthChecks(), "healthCheck", name) - def create_firewall_rule(self, name: str, network_url: str, - source_ranges: List[str], - ports: List[str]) -> Optional['GcpResource']: + def create_firewall_rule( + self, + name: str, + network_url: str, + source_ranges: List[str], + ports: List[str], + ) -> Optional["GcpResource"]: try: return self._insert_resource( - self.api.firewalls(), { - "allowed": [{ - "IPProtocol": "tcp", - "ports": ports - }], + self.api.firewalls(), + { + "allowed": [{"IPProtocol": "tcp", "ports": ports}], "direction": "INGRESS", "name": name, "network": network_url, "priority": 1000, "sourceRanges": source_ranges, - "targetTags": ["allow-health-checks"] - }) + "targetTags": ["allow-health-checks"], + }, + ) except googleapiclient.errors.HttpError as http_error: # TODO(lidiz) use status_code() when we upgrade googleapiclient if http_error.resp.status == 409: - logger.debug('Firewall rule %s already existed', name) + logger.debug("Firewall rule %s already existed", name) return None else: raise def delete_firewall_rule(self, name): - self._delete_resource(self.api.firewalls(), 'firewall', name) + self._delete_resource(self.api.firewalls(), "firewall", name) def create_backend_service_traffic_director( - self, - name: str, - health_check: 'GcpResource', - affinity_header: Optional[str] = None, - protocol: Optional[BackendServiceProtocol] = None, - subset_size: Optional[int] = None, - locality_lb_policies: Optional[List[dict]] = None, - outlier_detection: Optional[dict] = None) -> 'GcpResource': + self, + name: str, + health_check: "GcpResource", + affinity_header: Optional[str] = None, + protocol: Optional[BackendServiceProtocol] = None, + subset_size: Optional[int] = None, + locality_lb_policies: Optional[List[dict]] = None, + outlier_detection: Optional[dict] = None, + ) -> "GcpResource": if not isinstance(protocol, self.BackendServiceProtocol): - raise TypeError(f'Unexpected Backend Service protocol: {protocol}') + raise TypeError(f"Unexpected Backend Service protocol: {protocol}") body = { - 'name': name, - 'loadBalancingScheme': 'INTERNAL_SELF_MANAGED', # Traffic Director - 'healthChecks': [health_check.url], - 'protocol': protocol.name, + "name": name, + "loadBalancingScheme": "INTERNAL_SELF_MANAGED", # Traffic Director + "healthChecks": [health_check.url], + "protocol": protocol.name, } # If affinity header is specified, config the backend service to support # affinity, and set affinity header to the one given. 
if affinity_header: - body['sessionAffinity'] = 'HEADER_FIELD' - body['localityLbPolicy'] = 'RING_HASH' - body['consistentHash'] = { - 'httpHeaderName': affinity_header, + body["sessionAffinity"] = "HEADER_FIELD" + body["localityLbPolicy"] = "RING_HASH" + body["consistentHash"] = { + "httpHeaderName": affinity_header, } if subset_size: - body['subsetting'] = { - 'policy': 'CONSISTENT_HASH_SUBSETTING', - 'subsetSize': subset_size + body["subsetting"] = { + "policy": "CONSISTENT_HASH_SUBSETTING", + "subsetSize": subset_size, } if locality_lb_policies: - body['localityLbPolicies'] = locality_lb_policies + body["localityLbPolicies"] = locality_lb_policies if outlier_detection: - body['outlierDetection'] = outlier_detection + body["outlierDetection"] = outlier_detection return self._insert_resource(self.api.backendServices(), body) - def get_backend_service_traffic_director(self, name: str) -> 'GcpResource': - return self._get_resource(self.api.backendServices(), - backendService=name) + def get_backend_service_traffic_director(self, name: str) -> "GcpResource": + return self._get_resource( + self.api.backendServices(), backendService=name + ) def patch_backend_service(self, backend_service, body, **kwargs): - self._patch_resource(collection=self.api.backendServices(), - backendService=backend_service.name, - body=body, - **kwargs) + self._patch_resource( + collection=self.api.backendServices(), + backendService=backend_service.name, + body=body, + **kwargs, + ) def backend_service_patch_backends( - self, - backend_service, - backends, - max_rate_per_endpoint: Optional[int] = None): + self, + backend_service, + backends, + max_rate_per_endpoint: Optional[int] = None, + ): if max_rate_per_endpoint is None: max_rate_per_endpoint = 5 - backend_list = [{ - 'group': backend.url, - 'balancingMode': 'RATE', - 'maxRatePerEndpoint': max_rate_per_endpoint - } for backend in backends] + backend_list = [ + { + "group": backend.url, + "balancingMode": "RATE", + "maxRatePerEndpoint": max_rate_per_endpoint, + } + for backend in backends + ] - self._patch_resource(collection=self.api.backendServices(), - body={'backends': backend_list}, - backendService=backend_service.name) + self._patch_resource( + collection=self.api.backendServices(), + body={"backends": backend_list}, + backendService=backend_service.name, + ) def backend_service_remove_all_backends(self, backend_service): - self._patch_resource(collection=self.api.backendServices(), - body={'backends': []}, - backendService=backend_service.name) + self._patch_resource( + collection=self.api.backendServices(), + body={"backends": []}, + backendService=backend_service.name, + ) def delete_backend_service(self, name): - self._delete_resource(self.api.backendServices(), 'backendService', - name) + self._delete_resource( + self.api.backendServices(), "backendService", name + ) def create_url_map( self, name: str, matcher_name: str, src_hosts, - dst_default_backend_service: 'GcpResource', - dst_host_rule_match_backend_service: Optional['GcpResource'] = None, - ) -> 'GcpResource': + dst_default_backend_service: "GcpResource", + dst_host_rule_match_backend_service: Optional["GcpResource"] = None, + ) -> "GcpResource": if dst_host_rule_match_backend_service is None: dst_host_rule_match_backend_service = dst_default_backend_service return self._insert_resource( - self.api.urlMaps(), { - 'name': - name, - 'defaultService': - dst_default_backend_service.url, - 'hostRules': [{ - 'hosts': src_hosts, - 'pathMatcher': matcher_name, - }], - 'pathMatchers': [{ - 'name': 
matcher_name, - 'defaultService': dst_host_rule_match_backend_service.url, - }], - }) - - def create_url_map_with_content(self, url_map_body: Any) -> 'GcpResource': + self.api.urlMaps(), + { + "name": name, + "defaultService": dst_default_backend_service.url, + "hostRules": [ + { + "hosts": src_hosts, + "pathMatcher": matcher_name, + } + ], + "pathMatchers": [ + { + "name": matcher_name, + "defaultService": dst_host_rule_match_backend_service.url, + } + ], + }, + ) + + def create_url_map_with_content(self, url_map_body: Any) -> "GcpResource": return self._insert_resource(self.api.urlMaps(), url_map_body) - def patch_url_map(self, url_map: 'GcpResource', body, **kwargs): - self._patch_resource(collection=self.api.urlMaps(), - urlMap=url_map.name, - body=body, - **kwargs) + def patch_url_map(self, url_map: "GcpResource", body, **kwargs): + self._patch_resource( + collection=self.api.urlMaps(), + urlMap=url_map.name, + body=body, + **kwargs, + ) def delete_url_map(self, name): - self._delete_resource(self.api.urlMaps(), 'urlMap', name) + self._delete_resource(self.api.urlMaps(), "urlMap", name) def create_target_grpc_proxy( self, name: str, - url_map: 'GcpResource', + url_map: "GcpResource", validate_for_proxyless: bool = True, - ) -> 'GcpResource': + ) -> "GcpResource": return self._insert_resource( - self.api.targetGrpcProxies(), { - 'name': name, - 'url_map': url_map.url, - 'validate_for_proxyless': validate_for_proxyless, - }) + self.api.targetGrpcProxies(), + { + "name": name, + "url_map": url_map.url, + "validate_for_proxyless": validate_for_proxyless, + }, + ) def delete_target_grpc_proxy(self, name): - self._delete_resource(self.api.targetGrpcProxies(), 'targetGrpcProxy', - name) + self._delete_resource( + self.api.targetGrpcProxies(), "targetGrpcProxy", name + ) def create_target_http_proxy( self, name: str, - url_map: 'GcpResource', - ) -> 'GcpResource': - return self._insert_resource(self.api.targetHttpProxies(), { - 'name': name, - 'url_map': url_map.url, - }) + url_map: "GcpResource", + ) -> "GcpResource": + return self._insert_resource( + self.api.targetHttpProxies(), + { + "name": name, + "url_map": url_map.url, + }, + ) def delete_target_http_proxy(self, name): - self._delete_resource(self.api.targetHttpProxies(), 'targetHttpProxy', - name) - - def create_forwarding_rule(self, - name: str, - src_port: int, - target_proxy: 'GcpResource', - network_url: str, - *, - ip_address: str = '0.0.0.0') -> 'GcpResource': + self._delete_resource( + self.api.targetHttpProxies(), "targetHttpProxy", name + ) + + def create_forwarding_rule( + self, + name: str, + src_port: int, + target_proxy: "GcpResource", + network_url: str, + *, + ip_address: str = "0.0.0.0", + ) -> "GcpResource": return self._insert_resource( self.api.globalForwardingRules(), { - 'name': name, - 'loadBalancingScheme': - 'INTERNAL_SELF_MANAGED', # Traffic Director - 'portRange': src_port, - 'IPAddress': ip_address, - 'network': network_url, - 'target': target_proxy.url, - }) + "name": name, + "loadBalancingScheme": "INTERNAL_SELF_MANAGED", # Traffic Director + "portRange": src_port, + "IPAddress": ip_address, + "network": network_url, + "target": target_proxy.url, + }, + ) def exists_forwarding_rule(self, src_port) -> bool: # TODO(sergiitk): Better approach for confirming the port is available. # It's possible a rule allocates actual port range, e.g 8000-9000, # and this wouldn't catch it. For now, we assume there's no # port ranges used in the project. 
- filter_str = (f'(portRange eq "{src_port}-{src_port}") ' - f'(IPAddress eq "0.0.0.0")' - f'(loadBalancingScheme eq "INTERNAL_SELF_MANAGED")') - return self._exists_resource(self.api.globalForwardingRules(), - filter=filter_str) + filter_str = ( + f'(portRange eq "{src_port}-{src_port}") ' + '(IPAddress eq "0.0.0.0")' + '(loadBalancingScheme eq "INTERNAL_SELF_MANAGED")' + ) + return self._exists_resource( + self.api.globalForwardingRules(), resource_filter=filter_str + ) def delete_forwarding_rule(self, name): - self._delete_resource(self.api.globalForwardingRules(), - 'forwardingRule', name) - - def wait_for_network_endpoint_group(self, - name: str, - zone: str, - *, - timeout_sec=_WAIT_FOR_BACKEND_SEC, - wait_sec=_WAIT_FOR_BACKEND_SLEEP_SEC): + self._delete_resource( + self.api.globalForwardingRules(), "forwardingRule", name + ) + + def wait_for_network_endpoint_group( + self, + name: str, + zone: str, + *, + timeout_sec=_WAIT_FOR_BACKEND_SEC, + wait_sec=_WAIT_FOR_BACKEND_SLEEP_SEC, + ): retryer = retryers.constant_retryer( wait_fixed=datetime.timedelta(seconds=wait_sec), timeout=datetime.timedelta(seconds=timeout_sec), - check_result=lambda neg: neg and neg.get('size', 0) > 0) + check_result=lambda neg: neg and neg.get("size", 0) > 0, + ) network_endpoint_group = retryer( - self._retry_network_endpoint_group_ready, name, zone) + self._retry_network_endpoint_group_ready, name, zone + ) # TODO(sergiitk): dataclass - return self.ZonalGcpResource(network_endpoint_group['name'], - network_endpoint_group['selfLink'], zone) + return self.ZonalGcpResource( + network_endpoint_group["name"], + network_endpoint_group["selfLink"], + zone, + ) def _retry_network_endpoint_group_ready(self, name: str, zone: str): try: neg = self.get_network_endpoint_group(name, zone) logger.debug( - 'Waiting for endpoints: NEG %s in zone %s, ' - 'current count %s', neg['name'], zone, neg.get('size')) + "Waiting for endpoints: NEG %s in zone %s, current count %s", + neg["name"], + zone, + neg.get("size"), + ) except googleapiclient.errors.HttpError as error: # noinspection PyProtectedMember reason = error._get_reason() - logger.debug('Retrying NEG load, got %s, details %s', - error.resp.status, reason) + logger.debug( + "Retrying NEG load, got %s, details %s", + error.resp.status, + reason, + ) raise return neg def get_network_endpoint_group(self, name, zone): - neg = self.api.networkEndpointGroups().get(project=self.project, - networkEndpointGroup=name, - zone=zone).execute() + neg = ( + self.api.networkEndpointGroups() + .get(project=self.project, networkEndpointGroup=name, zone=zone) + .execute() + ) # TODO(sergiitk): dataclass return neg def wait_for_backends_healthy_status( - self, - backend_service: GcpResource, - backends: Set[ZonalGcpResource], - *, - timeout_sec: int = _WAIT_FOR_BACKEND_SEC, - wait_sec: int = _WAIT_FOR_BACKEND_SLEEP_SEC): + self, + backend_service: GcpResource, + backends: Set[ZonalGcpResource], + *, + timeout_sec: int = _WAIT_FOR_BACKEND_SEC, + wait_sec: int = _WAIT_FOR_BACKEND_SLEEP_SEC, + ): retryer = retryers.constant_retryer( wait_fixed=datetime.timedelta(seconds=wait_sec), timeout=datetime.timedelta(seconds=timeout_sec), - check_result=lambda result: result) + check_result=lambda result: result, + ) pending = set(backends) retryer(self._retry_backends_health, backend_service, pending) - def _retry_backends_health(self, backend_service: GcpResource, - pending: Set[ZonalGcpResource]): + def _retry_backends_health( + self, backend_service: GcpResource, pending: Set[ZonalGcpResource] + ): 
for backend in pending: result = self.get_backend_service_backend_health( - backend_service, backend) - if 'healthStatus' not in result: - logger.debug('Waiting for instances: backend %s, zone %s', - backend.name, backend.zone) + backend_service, backend + ) + if "healthStatus" not in result: + logger.debug( + "Waiting for instances: backend %s, zone %s", + backend.name, + backend.zone, + ) continue backend_healthy = True - for instance in result['healthStatus']: - logger.debug('Backend %s in zone %s: instance %s:%s health: %s', - backend.name, backend.zone, instance['ipAddress'], - instance['port'], instance['healthState']) - if instance['healthState'] != 'HEALTHY': + for instance in result["healthStatus"]: + logger.debug( + "Backend %s in zone %s: instance %s:%s health: %s", + backend.name, + backend.zone, + instance["ipAddress"], + instance["port"], + instance["healthState"], + ) + if instance["healthState"] != "HEALTHY": backend_healthy = False if backend_healthy: - logger.info('Backend %s in zone %s reported healthy', - backend.name, backend.zone) + logger.info( + "Backend %s in zone %s reported healthy", + backend.name, + backend.zone, + ) pending.remove(backend) return not pending def get_backend_service_backend_health(self, backend_service, backend): - return self.api.backendServices().getHealth( - project=self.project, - backendService=backend_service.name, - body={ - "group": backend.url - }).execute() - - def _get_resource(self, collection: discovery.Resource, - **kwargs) -> 'GcpResource': + return ( + self.api.backendServices() + .getHealth( + project=self.project, + backendService=backend_service.name, + body={"group": backend.url}, + ) + .execute() + ) + + def _get_resource( + self, collection: discovery.Resource, **kwargs + ) -> "GcpResource": resp = collection.get(project=self.project, **kwargs).execute() - logger.info('Loaded compute resource:\n%s', - self.resource_pretty_format(resp)) - return self.GcpResource(resp['name'], resp['selfLink']) + logger.info( + "Loaded compute resource:\n%s", self.resource_pretty_format(resp) + ) + return self.GcpResource(resp["name"], resp["selfLink"]) def _exists_resource( - self, collection: discovery.Resource, filter: str) -> bool: # pylint: disable=redefined-builtin + self, collection: discovery.Resource, resource_filter: str + ) -> bool: resp = collection.list( - project=self.project, filter=filter, - maxResults=1).execute(num_retries=self._GCP_API_RETRIES) - if 'kind' not in resp: + project=self.project, filter=resource_filter, maxResults=1 + ).execute(num_retries=self._GCP_API_RETRIES) + if "kind" not in resp: # TODO(sergiitk): better error raise ValueError('List response "kind" is missing') - return 'items' in resp and resp['items'] - - def _insert_resource(self, collection: discovery.Resource, - body: Dict[str, Any]) -> 'GcpResource': - logger.info('Creating compute resource:\n%s', - self.resource_pretty_format(body)) + return "items" in resp and resp["items"] + + def _insert_resource( + self, collection: discovery.Resource, body: Dict[str, Any] + ) -> "GcpResource": + logger.info( + "Creating compute resource:\n%s", self.resource_pretty_format(body) + ) resp = self._execute(collection.insert(project=self.project, body=body)) - return self.GcpResource(body['name'], resp['targetLink']) + return self.GcpResource(body["name"], resp["targetLink"]) def _patch_resource(self, collection, body, **kwargs): - logger.info('Patching compute resource:\n%s', - self.resource_pretty_format(body)) + logger.info( + "Patching compute resource:\n%s", 
self.resource_pretty_format(body) + ) self._execute( - collection.patch(project=self.project, body=body, **kwargs)) + collection.patch(project=self.project, body=body, **kwargs) + ) def _list_resource(self, collection: discovery.Resource): return collection.list(project=self.project).execute( - num_retries=self._GCP_API_RETRIES) + num_retries=self._GCP_API_RETRIES + ) - def _delete_resource(self, collection: discovery.Resource, - resource_type: str, resource_name: str) -> bool: + def _delete_resource( + self, + collection: discovery.Resource, + resource_type: str, + resource_name: str, + ) -> bool: try: params = {"project": self.project, resource_type: resource_name} self._execute(collection.delete(**params)) @@ -418,43 +506,53 @@ def _delete_resource(self, collection: discovery.Resource, if error.resp and error.resp.status == 404: logger.info( 'Resource %s "%s" not deleted since it does not exist', - resource_type, resource_name) + resource_type, + resource_name, + ) else: - logger.warning('Failed to delete %s "%s", %r', resource_type, - resource_name, error) + logger.warning( + 'Failed to delete %s "%s", %r', + resource_type, + resource_name, + error, + ) return False @staticmethod def _operation_status_done(operation): - return 'status' in operation and operation['status'] == 'DONE' + return "status" in operation and operation["status"] == "DONE" def _execute( # pylint: disable=arguments-differ - self, - request, - *, - timeout_sec=_WAIT_FOR_OPERATION_SEC): + self, request, *, timeout_sec=_WAIT_FOR_OPERATION_SEC + ): operation = request.execute(num_retries=self._GCP_API_RETRIES) - logger.debug('Operation %s', operation) - return self._wait(operation['name'], timeout_sec) - - def _wait(self, - operation_id: str, - timeout_sec: int = _WAIT_FOR_OPERATION_SEC) -> dict: - logger.info('Waiting %s sec for compute operation id: %s', timeout_sec, - operation_id) + logger.debug("Operation %s", operation) + return self._wait(operation["name"], timeout_sec) + + def _wait( + self, operation_id: str, timeout_sec: int = _WAIT_FOR_OPERATION_SEC + ) -> dict: + logger.info( + "Waiting %s sec for compute operation id: %s", + timeout_sec, + operation_id, + ) # TODO(sergiitk) try using wait() here # https://googleapis.github.io/google-api-python-client/docs/dyn/compute_v1.globalOperations.html#wait - op_request = self.api.globalOperations().get(project=self.project, - operation=operation_id) + op_request = self.api.globalOperations().get( + project=self.project, operation=operation_id + ) operation = self.wait_for_operation( operation_request=op_request, test_success_fn=self._operation_status_done, - timeout_sec=timeout_sec) + timeout_sec=timeout_sec, + ) - logger.debug('Completed operation: %s', operation) - if 'error' in operation: + logger.debug("Completed operation: %s", operation) + if "error" in operation: # This shouldn't normally happen: gcp library raises on errors. 
- raise Exception(f'Compute operation {operation_id} ' - f'failed: {operation}') + raise Exception( + f"Compute operation {operation_id} failed: {operation}" + ) return operation diff --git a/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/gcp/iam.py b/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/gcp/iam.py index acd319519d997..fc153da08d747 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/gcp/iam.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/gcp/iam.py @@ -36,24 +36,26 @@ class EtagConflict(gcp.api.Error): def handle_etag_conflict(func): - def wrap_retry_on_etag_conflict(*args, **kwargs): retryer = retryers.exponential_retryer_with_timeout( retry_on_exceptions=(EtagConflict, gcp.api.TransportError), wait_min=_timedelta(seconds=1), wait_max=_timedelta(seconds=10), - timeout=_timedelta(minutes=2)) + timeout=_timedelta(minutes=2), + ) return retryer(func, *args, **kwargs) return wrap_retry_on_etag_conflict -def _replace_binding(policy: 'Policy', binding: 'Policy.Binding', - new_binding: 'Policy.Binding') -> 'Policy': +def _replace_binding( + policy: "Policy", binding: "Policy.Binding", new_binding: "Policy.Binding" +) -> "Policy": new_bindings = set(policy.bindings) new_bindings.discard(binding) new_bindings.add(new_binding) - return dataclasses.replace(policy, bindings=frozenset(new_bindings)) # pylint: disable=too-many-function-args + # pylint: disable=too-many-function-args # No idea why pylint is like that. + return dataclasses.replace(policy, bindings=frozenset(new_bindings)) @dataclasses.dataclass(frozen=True) @@ -63,25 +65,28 @@ class ServiceAccount: https://cloud.google.com/iam/docs/reference/rest/v1/projects.serviceAccounts Note: "etag" field is skipped because it's deprecated """ + name: str projectId: str uniqueId: str email: str oauth2ClientId: str - displayName: str = '' - description: str = '' + displayName: str = "" + description: str = "" disabled: bool = False @classmethod - def from_response(cls, response: Dict[str, Any]) -> 'ServiceAccount': - return cls(name=response['name'], - projectId=response['projectId'], - uniqueId=response['uniqueId'], - email=response['email'], - oauth2ClientId=response['oauth2ClientId'], - description=response.get('description', ''), - displayName=response.get('displayName', ''), - disabled=response.get('disabled', False)) + def from_response(cls, response: Dict[str, Any]) -> "ServiceAccount": + return cls( + name=response["name"], + projectId=response["projectId"], + uniqueId=response["uniqueId"], + email=response["email"], + oauth2ClientId=response["oauth2ClientId"], + description=response.get("description", ""), + displayName=response.get("displayName", ""), + disabled=response.get("disabled", False), + ) def as_dict(self) -> Dict[str, Any]: return dataclasses.asdict(self) @@ -94,13 +99,14 @@ class Expr: https://cloud.google.com/iam/docs/reference/rest/v1/Expr """ + expression: str - title: str = '' - description: str = '' - location: str = '' + title: str = "" + description: str = "" + location: str = "" @classmethod - def from_response(cls, response: Dict[str, Any]) -> 'Expr': + def from_response(cls, response: Dict[str, Any]) -> "Expr": return cls(**response) def as_dict(self) -> Dict[str, Any]: @@ -122,28 +128,29 @@ class Binding: https://cloud.google.com/iam/docs/reference/rest/v1/Policy#binding """ + role: str members: FrozenSet[str] condition: Optional[Expr] = None @classmethod - def from_response(cls, response: Dict[str, Any]) -> 'Policy.Binding': + def 
from_response(cls, response: Dict[str, Any]) -> "Policy.Binding": fields = { - 'role': response['role'], - 'members': frozenset(response.get('members', [])), + "role": response["role"], + "members": frozenset(response.get("members", [])), } - if 'condition' in response: - fields['condition'] = Expr.from_response(response['condition']) + if "condition" in response: + fields["condition"] = Expr.from_response(response["condition"]) return cls(**fields) def as_dict(self) -> Dict[str, Any]: result = { - 'role': self.role, - 'members': list(self.members), + "role": self.role, + "members": list(self.members), } if self.condition is not None: - result['condition'] = self.condition.as_dict() + result["condition"] = self.condition.as_dict() return result bindings: FrozenSet[Binding] @@ -152,28 +159,33 @@ def as_dict(self) -> Dict[str, Any]: @functools.lru_cache(maxsize=128) def find_binding_for_role( - self, - role: str, - condition: Optional[Expr] = None) -> Optional['Policy.Binding']: - results = (binding for binding in self.bindings - if binding.role == role and binding.condition == condition) + self, role: str, condition: Optional[Expr] = None + ) -> Optional["Policy.Binding"]: + results = ( + binding + for binding in self.bindings + if binding.role == role and binding.condition == condition + ) return next(results, None) @classmethod - def from_response(cls, response: Dict[str, Any]) -> 'Policy': + def from_response(cls, response: Dict[str, Any]) -> "Policy": bindings = frozenset( - cls.Binding.from_response(b) for b in response.get('bindings', [])) - return cls(bindings=bindings, - etag=response['etag'], - version=response.get('version')) + cls.Binding.from_response(b) for b in response.get("bindings", []) + ) + return cls( + bindings=bindings, + etag=response["etag"], + version=response.get("version"), + ) def as_dict(self) -> Dict[str, Any]: result = { - 'bindings': [binding.as_dict() for binding in self.bindings], - 'etag': self.etag, + "bindings": [binding.as_dict() for binding in self.bindings], + "etag": self.etag, } if self.version is not None: - result['version'] = self.version + result["version"] = self.version return result @@ -183,6 +195,7 @@ class IamV1(gcp.api.GcpProjectApiResource): https://cloud.google.com/iam/docs/reference/rest """ + _service_accounts: gcp.api.discovery.Resource # Operations that affect conditional role bindings must specify version 3. 
@@ -192,7 +205,7 @@ class IamV1(gcp.api.GcpProjectApiResource): POLICY_VERSION: int = 3 def __init__(self, api_manager: gcp.api.GcpApiManager, project: str): - super().__init__(api_manager.iam('v1'), project) + super().__init__(api_manager.iam("v1"), project) # Shortcut to projects/*/serviceAccounts/ endpoints self._service_accounts = self.api.projects().serviceAccounts() @@ -209,39 +222,48 @@ def service_account_resource_name(self, account) -> str: Args: account: The ACCOUNT value """ - return f'projects/{self.project}/serviceAccounts/{account}' + return f"projects/{self.project}/serviceAccounts/{account}" def get_service_account(self, account: str) -> ServiceAccount: resource_name = self.service_account_resource_name(account) request: _HttpRequest = self._service_accounts.get(name=resource_name) response: Dict[str, Any] = self._execute(request) - logger.debug('Loaded Service Account:\n%s', - self.resource_pretty_format(response)) + logger.debug( + "Loaded Service Account:\n%s", self.resource_pretty_format(response) + ) return ServiceAccount.from_response(response) def get_service_account_iam_policy(self, account: str) -> Policy: resource_name = self.service_account_resource_name(account) request: _HttpRequest = self._service_accounts.getIamPolicy( resource=resource_name, - options_requestedPolicyVersion=self.POLICY_VERSION) + options_requestedPolicyVersion=self.POLICY_VERSION, + ) response: Dict[str, Any] = self._execute(request) - logger.debug('Loaded Service Account Policy:\n%s', - self.resource_pretty_format(response)) + logger.debug( + "Loaded Service Account Policy:\n%s", + self.resource_pretty_format(response), + ) return Policy.from_response(response) - def set_service_account_iam_policy(self, account: str, - policy: Policy) -> Policy: + def set_service_account_iam_policy( + self, account: str, policy: Policy + ) -> Policy: """Sets the IAM policy that is attached to a service account. https://cloud.google.com/iam/docs/reference/rest/v1/projects.serviceAccounts/setIamPolicy """ resource_name = self.service_account_resource_name(account) - body = {'policy': policy.as_dict()} - logger.debug('Updating Service Account %s policy:\n%s', account, - self.resource_pretty_format(body)) + body = {"policy": policy.as_dict()} + logger.debug( + "Updating Service Account %s policy:\n%s", + account, + self.resource_pretty_format(body), + ) try: request: _HttpRequest = self._service_accounts.setIamPolicy( - resource=resource_name, body=body) + resource=resource_name, body=body + ) response: Dict[str, Any] = self._execute(request) return Policy.from_response(response) except gcp.api.ResponseError as error: @@ -252,8 +274,9 @@ def set_service_account_iam_policy(self, account: str, raise @handle_etag_conflict - def add_service_account_iam_policy_binding(self, account: str, role: str, - member: str) -> None: + def add_service_account_iam_policy_binding( + self, account: str, role: str, member: str + ) -> None: """Add an IAM policy binding to an IAM service account. 
See for details on updating policy bindings: @@ -262,27 +285,39 @@ def add_service_account_iam_policy_binding(self, account: str, role: str, policy: Policy = self.get_service_account_iam_policy(account) binding: Optional[Policy.Binding] = policy.find_binding_for_role(role) if binding and member in binding.members: - logger.debug('Member %s already has role %s for Service Account %s', - member, role, account) + logger.debug( + "Member %s already has role %s for Service Account %s", + member, + role, + account, + ) return if binding is None: updated_binding = Policy.Binding(role, frozenset([member])) else: updated_members: FrozenSet[str] = binding.members.union({member}) - updated_binding: Policy.Binding = dataclasses.replace( # pylint: disable=too-many-function-args - binding, - members=updated_members) - - updated_policy: Policy = _replace_binding(policy, binding, - updated_binding) + updated_binding: Policy.Binding = ( + dataclasses.replace( # pylint: disable=too-many-function-args + binding, members=updated_members + ) + ) + + updated_policy: Policy = _replace_binding( + policy, binding, updated_binding + ) self.set_service_account_iam_policy(account, updated_policy) - logger.debug('Role %s granted to member %s for Service Account %s', - role, member, account) + logger.debug( + "Role %s granted to member %s for Service Account %s", + role, + member, + account, + ) @handle_etag_conflict - def remove_service_account_iam_policy_binding(self, account: str, role: str, - member: str) -> None: + def remove_service_account_iam_policy_binding( + self, account: str, role: str, member: str + ) -> None: """Remove an IAM policy binding from the IAM policy of a service account. @@ -293,21 +328,34 @@ def remove_service_account_iam_policy_binding(self, account: str, role: str, binding: Optional[Policy.Binding] = policy.find_binding_for_role(role) if binding is None: - logger.debug('Noop: Service Account %s has no bindings for role %s', - account, role) + logger.debug( + "Noop: Service Account %s has no bindings for role %s", + account, + role, + ) return if member not in binding.members: logger.debug( - 'Noop: Service Account %s binding for role %s has no member %s', - account, role, member) + "Noop: Service Account %s binding for role %s has no member %s", + account, + role, + member, + ) return updated_members: FrozenSet[str] = binding.members.difference({member}) - updated_binding: Policy.Binding = dataclasses.replace( # pylint: disable=too-many-function-args - binding, - members=updated_members) - updated_policy: Policy = _replace_binding(policy, binding, - updated_binding) + updated_binding: Policy.Binding = ( + dataclasses.replace( # pylint: disable=too-many-function-args + binding, members=updated_members + ) + ) + updated_policy: Policy = _replace_binding( + policy, binding, updated_binding + ) self.set_service_account_iam_policy(account, updated_policy) - logger.debug('Role %s revoked from member %s for Service Account %s', - role, member, account) + logger.debug( + "Role %s revoked from member %s for Service Account %s", + role, + member, + account, + ) diff --git a/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/gcp/network_security.py b/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/gcp/network_security.py index 1c2d7dbc7e352..b7656bd45c96b 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/gcp/network_security.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/gcp/network_security.py @@ -38,14 +38,17 @@ class 
ServerTlsPolicy: create_time: str @classmethod - def from_response(cls, name: str, response: Dict[str, - Any]) -> 'ServerTlsPolicy': - return cls(name=name, - url=response['name'], - server_certificate=response.get('serverCertificate', {}), - mtls_policy=response.get('mtlsPolicy', {}), - create_time=response['createTime'], - update_time=response['updateTime']) + def from_response( + cls, name: str, response: Dict[str, Any] + ) -> "ServerTlsPolicy": + return cls( + name=name, + url=response["name"], + server_certificate=response.get("serverCertificate", {}), + mtls_policy=response.get("mtlsPolicy", {}), + create_time=response["createTime"], + update_time=response["updateTime"], + ) @dataclasses.dataclass(frozen=True) @@ -58,14 +61,17 @@ class ClientTlsPolicy: create_time: str @classmethod - def from_response(cls, name: str, response: Dict[str, - Any]) -> 'ClientTlsPolicy': - return cls(name=name, - url=response['name'], - client_certificate=response.get('clientCertificate', {}), - server_validation_ca=response.get('serverValidationCa', []), - create_time=response['createTime'], - update_time=response['updateTime']) + def from_response( + cls, name: str, response: Dict[str, Any] + ) -> "ClientTlsPolicy": + return cls( + name=name, + url=response["name"], + client_certificate=response.get("clientCertificate", {}), + server_validation_ca=response.get("serverValidationCa", []), + create_time=response["createTime"], + update_time=response["updateTime"], + ) @dataclasses.dataclass(frozen=True) @@ -78,18 +84,22 @@ class AuthorizationPolicy: rules: list @classmethod - def from_response(cls, name: str, - response: Dict[str, Any]) -> 'AuthorizationPolicy': - return cls(name=name, - url=response['name'], - create_time=response['createTime'], - update_time=response['updateTime'], - action=response['action'], - rules=response.get('rules', [])) - - -class _NetworkSecurityBase(gcp.api.GcpStandardCloudApiResource, - metaclass=abc.ABCMeta): + def from_response( + cls, name: str, response: Dict[str, Any] + ) -> "AuthorizationPolicy": + return cls( + name=name, + url=response["name"], + create_time=response["createTime"], + update_time=response["updateTime"], + action=response["action"], + rules=response.get("rules", []), + ) + + +class _NetworkSecurityBase( + gcp.api.GcpStandardCloudApiResource, metaclass=abc.ABCMeta +): """Base class for NetworkSecurity APIs.""" # TODO(https://github.com/grpc/grpc/issues/29532) remove pylint disable @@ -102,9 +112,11 @@ def __init__(self, api_manager: gcp.api.GcpApiManager, project: str): @property def api_name(self) -> str: - return 'networksecurity' + return "networksecurity" - def _execute(self, *args, **kwargs): # pylint: disable=signature-differs,arguments-differ + def _execute( + self, *args, **kwargs + ): # pylint: disable=signature-differs,arguments-differ # Workaround TD bug: throttled operations are reported as internal. 
# Ref b/175345578 retryer = tenacity.Retrying( @@ -112,76 +124,88 @@ def _execute(self, *args, **kwargs): # pylint: disable=signature-differs,argume wait=tenacity.wait_fixed(10), stop=tenacity.stop_after_delay(5 * 60), before_sleep=tenacity.before_sleep_log(logger, logging.DEBUG), - reraise=True) + reraise=True, + ) retryer(super()._execute, *args, **kwargs) @staticmethod def _operation_internal_error(exception): - return (isinstance(exception, gcp.api.OperationError) and - exception.error.code == code_pb2.INTERNAL) + return ( + isinstance(exception, gcp.api.OperationError) + and exception.error.code == code_pb2.INTERNAL + ) class NetworkSecurityV1Beta1(_NetworkSecurityBase): """NetworkSecurity API v1beta1.""" - SERVER_TLS_POLICIES = 'serverTlsPolicies' - CLIENT_TLS_POLICIES = 'clientTlsPolicies' - AUTHZ_POLICIES = 'authorizationPolicies' + SERVER_TLS_POLICIES = "serverTlsPolicies" + CLIENT_TLS_POLICIES = "clientTlsPolicies" + AUTHZ_POLICIES = "authorizationPolicies" @property def api_version(self) -> str: - return 'v1beta1' + return "v1beta1" def create_server_tls_policy(self, name: str, body: dict) -> GcpResource: return self._create_resource( collection=self._api_locations.serverTlsPolicies(), body=body, - serverTlsPolicyId=name) + serverTlsPolicyId=name, + ) def get_server_tls_policy(self, name: str) -> ServerTlsPolicy: response = self._get_resource( collection=self._api_locations.serverTlsPolicies(), - full_name=self.resource_full_name(name, self.SERVER_TLS_POLICIES)) + full_name=self.resource_full_name(name, self.SERVER_TLS_POLICIES), + ) return ServerTlsPolicy.from_response(name, response) def delete_server_tls_policy(self, name: str) -> bool: return self._delete_resource( collection=self._api_locations.serverTlsPolicies(), - full_name=self.resource_full_name(name, self.SERVER_TLS_POLICIES)) + full_name=self.resource_full_name(name, self.SERVER_TLS_POLICIES), + ) def create_client_tls_policy(self, name: str, body: dict) -> GcpResource: return self._create_resource( collection=self._api_locations.clientTlsPolicies(), body=body, - clientTlsPolicyId=name) + clientTlsPolicyId=name, + ) def get_client_tls_policy(self, name: str) -> ClientTlsPolicy: response = self._get_resource( collection=self._api_locations.clientTlsPolicies(), - full_name=self.resource_full_name(name, self.CLIENT_TLS_POLICIES)) + full_name=self.resource_full_name(name, self.CLIENT_TLS_POLICIES), + ) return ClientTlsPolicy.from_response(name, response) def delete_client_tls_policy(self, name: str) -> bool: return self._delete_resource( collection=self._api_locations.clientTlsPolicies(), - full_name=self.resource_full_name(name, self.CLIENT_TLS_POLICIES)) + full_name=self.resource_full_name(name, self.CLIENT_TLS_POLICIES), + ) def create_authz_policy(self, name: str, body: dict) -> GcpResource: return self._create_resource( collection=self._api_locations.authorizationPolicies(), body=body, - authorizationPolicyId=name) + authorizationPolicyId=name, + ) def get_authz_policy(self, name: str) -> ClientTlsPolicy: response = self._get_resource( collection=self._api_locations.authorizationPolicies(), - full_name=self.resource_full_name(name, self.AUTHZ_POLICIES)) + full_name=self.resource_full_name(name, self.AUTHZ_POLICIES), + ) return ClientTlsPolicy.from_response(name, response) def delete_authz_policy(self, name: str) -> bool: return self._delete_resource( collection=self._api_locations.authorizationPolicies(), - full_name=self.resource_full_name(name, self.AUTHZ_POLICIES)) + full_name=self.resource_full_name(name, 
self.AUTHZ_POLICIES), + ) class NetworkSecurityV1Alpha1(NetworkSecurityV1Beta1): @@ -194,4 +218,4 @@ class NetworkSecurityV1Alpha1(NetworkSecurityV1Beta1): @property def api_version(self) -> str: - return 'v1alpha1' + return "v1alpha1" diff --git a/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/gcp/network_services.py b/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/gcp/network_services.py index f080aaa62c32c..06602d58b0af9 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/gcp/network_services.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/gcp/network_services.py @@ -41,28 +41,30 @@ class EndpointPolicy: server_tls_policy: Optional[str] = None @classmethod - def from_response(cls, name: str, response: Dict[str, - Any]) -> 'EndpointPolicy': - return cls(name=name, - url=response['name'], - type=response['type'], - server_tls_policy=response.get('serverTlsPolicy', None), - traffic_port_selector=response['trafficPortSelector'], - endpoint_matcher=response['endpointMatcher'], - http_filters=response.get('httpFilters', None), - update_time=response['updateTime'], - create_time=response['createTime']) + def from_response( + cls, name: str, response: Dict[str, Any] + ) -> "EndpointPolicy": + return cls( + name=name, + url=response["name"], + type=response["type"], + server_tls_policy=response.get("serverTlsPolicy", None), + traffic_port_selector=response["trafficPortSelector"], + endpoint_matcher=response["endpointMatcher"], + http_filters=response.get("httpFilters", None), + update_time=response["updateTime"], + create_time=response["createTime"], + ) @dataclasses.dataclass(frozen=True) class Mesh: - name: str url: str routes: Optional[List[str]] @classmethod - def from_response(cls, name: str, d: Dict[str, Any]) -> 'Mesh': + def from_response(cls, name: str, d: Dict[str, Any]) -> "Mesh": return cls( name=name, url=d["name"], @@ -72,7 +74,6 @@ def from_response(cls, name: str, d: Dict[str, Any]) -> 'Mesh': @dataclasses.dataclass(frozen=True) class GrpcRoute: - @dataclasses.dataclass(frozen=True) class MethodMatch: type: Optional[str] @@ -81,7 +82,7 @@ class MethodMatch: case_sensitive: Optional[bool] @classmethod - def from_response(cls, d: Dict[str, Any]) -> 'GrpcRoute.MethodMatch': + def from_response(cls, d: Dict[str, Any]) -> "GrpcRoute.MethodMatch": return cls( type=d.get("type"), grpc_service=d.get("grpcService"), @@ -96,7 +97,7 @@ class HeaderMatch: value: str @classmethod - def from_response(cls, d: Dict[str, Any]) -> 'GrpcRoute.HeaderMatch': + def from_response(cls, d: Dict[str, Any]) -> "GrpcRoute.HeaderMatch": return cls( type=d.get("type"), key=d["key"], @@ -105,17 +106,20 @@ def from_response(cls, d: Dict[str, Any]) -> 'GrpcRoute.HeaderMatch': @dataclasses.dataclass(frozen=True) class RouteMatch: - method: Optional['GrpcRoute.MethodMatch'] - headers: Tuple['GrpcRoute.HeaderMatch'] + method: Optional["GrpcRoute.MethodMatch"] + headers: Tuple["GrpcRoute.HeaderMatch"] @classmethod - def from_response(cls, d: Dict[str, Any]) -> 'GrpcRoute.RouteMatch': + def from_response(cls, d: Dict[str, Any]) -> "GrpcRoute.RouteMatch": return cls( method=GrpcRoute.MethodMatch.from_response(d["method"]) - if "method" in d else None, + if "method" in d + else None, headers=tuple( - GrpcRoute.HeaderMatch.from_response(h) - for h in d["headers"]) if "headers" in d else (), + GrpcRoute.HeaderMatch.from_response(h) for h in d["headers"] + ) + if "headers" in d + else (), ) @dataclasses.dataclass(frozen=True) @@ -124,7 +128,7 @@ class 
Destination: weight: Optional[int] @classmethod - def from_response(cls, d: Dict[str, Any]) -> 'GrpcRoute.Destination': + def from_response(cls, d: Dict[str, Any]) -> "GrpcRoute.Destination": return cls( service_name=d["serviceName"], weight=d.get("weight"), @@ -132,25 +136,30 @@ def from_response(cls, d: Dict[str, Any]) -> 'GrpcRoute.Destination': @dataclasses.dataclass(frozen=True) class RouteAction: - @classmethod - def from_response(cls, d: Dict[str, Any]) -> 'GrpcRoute.RouteAction': - destinations = [ - GrpcRoute.Destination.from_response(dest) - for dest in d["destinations"] - ] if "destinations" in d else [] + def from_response(cls, d: Dict[str, Any]) -> "GrpcRoute.RouteAction": + destinations = ( + [ + GrpcRoute.Destination.from_response(dest) + for dest in d["destinations"] + ] + if "destinations" in d + else [] + ) return cls(destinations=destinations) @dataclasses.dataclass(frozen=True) class RouteRule: - matches: List['GrpcRoute.RouteMatch'] - action: 'GrpcRoute.RouteAction' + matches: List["GrpcRoute.RouteMatch"] + action: "GrpcRoute.RouteAction" @classmethod - def from_response(cls, d: Dict[str, Any]) -> 'GrpcRoute.RouteRule': - matches = [ - GrpcRoute.RouteMatch.from_response(m) for m in d["matches"] - ] if "matches" in d else [] + def from_response(cls, d: Dict[str, Any]) -> "GrpcRoute.RouteRule": + matches = ( + [GrpcRoute.RouteMatch.from_response(m) for m in d["matches"]] + if "matches" in d + else [] + ) return cls( matches=matches, action=GrpcRoute.RouteAction.from_response(d["action"]), @@ -159,12 +168,13 @@ def from_response(cls, d: Dict[str, Any]) -> 'GrpcRoute.RouteRule': name: str url: str hostnames: Tuple[str] - rules: Tuple['GrpcRoute.RouteRule'] + rules: Tuple["GrpcRoute.RouteRule"] meshes: Optional[Tuple[str]] @classmethod - def from_response(cls, name: str, d: Dict[str, - Any]) -> 'GrpcRoute.RouteRule': + def from_response( + cls, name: str, d: Dict[str, Any] + ) -> "GrpcRoute.RouteRule": return cls( name=name, url=d["name"], @@ -174,8 +184,9 @@ def from_response(cls, name: str, d: Dict[str, ) -class _NetworkServicesBase(gcp.api.GcpStandardCloudApiResource, - metaclass=abc.ABCMeta): +class _NetworkServicesBase( + gcp.api.GcpStandardCloudApiResource, metaclass=abc.ABCMeta +): """Base class for NetworkServices APIs.""" # TODO(https://github.com/grpc/grpc/issues/29532) remove pylint disable @@ -188,9 +199,11 @@ def __init__(self, api_manager: gcp.api.GcpApiManager, project: str): @property def api_name(self) -> str: - return 'networkservices' + return "networkservices" - def _execute(self, *args, **kwargs): # pylint: disable=signature-differs,arguments-differ + def _execute( + self, *args, **kwargs + ): # pylint: disable=signature-differs,arguments-differ # Workaround TD bug: throttled operations are reported as internal. 
# Ref b/175345578 retryer = tenacity.Retrying( @@ -198,39 +211,46 @@ def _execute(self, *args, **kwargs): # pylint: disable=signature-differs,argume wait=tenacity.wait_fixed(10), stop=tenacity.stop_after_delay(5 * 60), before_sleep=tenacity.before_sleep_log(logger, logging.DEBUG), - reraise=True) + reraise=True, + ) retryer(super()._execute, *args, **kwargs) @staticmethod def _operation_internal_error(exception): - return (isinstance(exception, gcp.api.OperationError) and - exception.error.code == code_pb2.INTERNAL) + return ( + isinstance(exception, gcp.api.OperationError) + and exception.error.code == code_pb2.INTERNAL + ) class NetworkServicesV1Beta1(_NetworkServicesBase): """NetworkServices API v1beta1.""" - ENDPOINT_POLICIES = 'endpointPolicies' + + ENDPOINT_POLICIES = "endpointPolicies" @property def api_version(self) -> str: - return 'v1beta1' + return "v1beta1" def create_endpoint_policy(self, name, body: dict) -> GcpResource: return self._create_resource( collection=self._api_locations.endpointPolicies(), body=body, - endpointPolicyId=name) + endpointPolicyId=name, + ) def get_endpoint_policy(self, name: str) -> EndpointPolicy: response = self._get_resource( collection=self._api_locations.endpointPolicies(), - full_name=self.resource_full_name(name, self.ENDPOINT_POLICIES)) + full_name=self.resource_full_name(name, self.ENDPOINT_POLICIES), + ) return EndpointPolicy.from_response(name, response) def delete_endpoint_policy(self, name: str) -> bool: return self._delete_resource( collection=self._api_locations.endpointPolicies(), - full_name=self.resource_full_name(name, self.ENDPOINT_POLICIES)) + full_name=self.resource_full_name(name, self.ENDPOINT_POLICIES), + ) class NetworkServicesV1Alpha1(NetworkServicesV1Beta1): @@ -241,42 +261,47 @@ class NetworkServicesV1Alpha1(NetworkServicesV1Beta1): v1alpha1 class can always override and reimplement incompatible methods. 
""" - GRPC_ROUTES = 'grpcRoutes' - MESHES = 'meshes' + GRPC_ROUTES = "grpcRoutes" + MESHES = "meshes" @property def api_version(self) -> str: - return 'v1alpha1' + return "v1alpha1" def create_mesh(self, name: str, body: dict) -> GcpResource: - return self._create_resource(collection=self._api_locations.meshes(), - body=body, - meshId=name) + return self._create_resource( + collection=self._api_locations.meshes(), body=body, meshId=name + ) def get_mesh(self, name: str) -> Mesh: full_name = self.resource_full_name(name, self.MESHES) - result = self._get_resource(collection=self._api_locations.meshes(), - full_name=full_name) + result = self._get_resource( + collection=self._api_locations.meshes(), full_name=full_name + ) return Mesh.from_response(name, result) def delete_mesh(self, name: str) -> bool: - return self._delete_resource(collection=self._api_locations.meshes(), - full_name=self.resource_full_name( - name, self.MESHES)) + return self._delete_resource( + collection=self._api_locations.meshes(), + full_name=self.resource_full_name(name, self.MESHES), + ) def create_grpc_route(self, name: str, body: dict) -> GcpResource: return self._create_resource( collection=self._api_locations.grpcRoutes(), body=body, - grpcRouteId=name) + grpcRouteId=name, + ) def get_grpc_route(self, name: str) -> GrpcRoute: full_name = self.resource_full_name(name, self.GRPC_ROUTES) - result = self._get_resource(collection=self._api_locations.grpcRoutes(), - full_name=full_name) + result = self._get_resource( + collection=self._api_locations.grpcRoutes(), full_name=full_name + ) return GrpcRoute.from_response(name, result) def delete_grpc_route(self, name: str) -> bool: return self._delete_resource( collection=self._api_locations.grpcRoutes(), - full_name=self.resource_full_name(name, self.GRPC_ROUTES)) + full_name=self.resource_full_name(name, self.GRPC_ROUTES), + ) diff --git a/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/k8s.py b/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/k8s.py index 02443a742a8b4..2b39c3e98c9ea 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/k8s.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/k8s.py @@ -50,8 +50,11 @@ _ApiException = client.ApiException _FailToCreateError = utils.FailToCreateError -_RETRY_ON_EXCEPTIONS = (urllib3.exceptions.HTTPError, _ApiException, - _FailToCreateError) +_RETRY_ON_EXCEPTIONS = ( + urllib3.exceptions.HTTPError, + _ApiException, + _FailToCreateError, +) def _server_restart_retryer() -> retryers.Retrying: @@ -59,7 +62,8 @@ def _server_restart_retryer() -> retryers.Retrying: retry_on_exceptions=_RETRY_ON_EXCEPTIONS, wait_min=_timedelta(seconds=1), wait_max=_timedelta(seconds=10), - timeout=_timedelta(minutes=3)) + timeout=_timedelta(minutes=3), + ) def _too_many_requests_retryer() -> retryers.Retrying: @@ -67,17 +71,20 @@ def _too_many_requests_retryer() -> retryers.Retrying: retry_on_exceptions=_RETRY_ON_EXCEPTIONS, wait_min=_timedelta(seconds=10), wait_max=_timedelta(seconds=30), - timeout=_timedelta(minutes=3)) + timeout=_timedelta(minutes=3), + ) def _quick_recovery_retryer() -> retryers.Retrying: - return retryers.constant_retryer(wait_fixed=_timedelta(seconds=1), - attempts=3, - retry_on_exceptions=_RETRY_ON_EXCEPTIONS) + return retryers.constant_retryer( + wait_fixed=_timedelta(seconds=1), + attempts=3, + retry_on_exceptions=_RETRY_ON_EXCEPTIONS, + ) def label_dict_to_selector(labels: dict) -> str: - return ','.join(f'{k}=={v}' for k, v in labels.items()) + return 
",".join(f"{k}=={v}" for k, v in labels.items()) class NotFound(Exception): @@ -117,9 +124,13 @@ def reload(self): @staticmethod def _new_client_from_context(context: str) -> ApiClient: client_instance = kubernetes.config.new_client_from_config( - context=context) - logger.info('Using kubernetes context "%s", active host: %s', context, - client_instance.configuration.host) + context=context + ) + logger.info( + 'Using kubernetes context "%s", active host: %s', + context, + client_instance.configuration.host, + ) # TODO(sergiitk): fine-tune if we see the total wait unreasonably long. client_instance.configuration.retries = 10 return client_instance @@ -130,7 +141,7 @@ class KubernetesNamespace: # pylint: disable=too-many-public-methods _api: KubernetesApiManager _name: str - NEG_STATUS_META = 'cloud.google.com/neg-status' + NEG_STATUS_META = "cloud.google.com/neg-status" DELETE_GRACE_PERIOD_SEC: int = 5 WAIT_SHORT_TIMEOUT_SEC: int = 60 WAIT_SHORT_SLEEP_SEC: int = 1 @@ -150,13 +161,13 @@ def name(self): return self._name def _refresh_auth(self): - logger.info('Reloading k8s api client to refresh the auth.') + logger.info("Reloading k8s api client to refresh the auth.") self._api.reload() def _apply_manifest(self, manifest): - return utils.create_from_dict(self._api.client, - manifest, - namespace=self.name) + return utils.create_from_dict( + self._api.client, manifest, namespace=self.name + ) def _get_resource(self, method: Callable[[Any], object], *args, **kwargs): try: @@ -199,7 +210,7 @@ def _handle_exception(self, err: Exception) -> Optional[retryers.Retrying]: # without response') # - ConnectionResetError(104, 'Connection reset by peer') if isinstance(err, urllib3.exceptions.ProtocolError): - if 'connection aborted' in str(err).lower(): + if "connection aborted" in str(err).lower(): return _server_restart_retryer() else: # To cover other cases we didn't account for, and haven't @@ -221,17 +232,22 @@ def _handle_exception(self, err: Exception) -> Optional[retryers.Retrying]: return None def _handle_api_exception( - self, err: _ApiException) -> Optional[retryers.Retrying]: + self, err: _ApiException + ) -> Optional[retryers.Retrying]: # TODO(sergiitk): replace returns with match/case when we use to py3.10. # pylint: disable=too-many-return-statements # TODO(sergiitk): can I chain the retryers? logger.debug( - 'Handling k8s.ApiException: status=%s reason=%s body=%s headers=%s', - err.status, err.reason, err.body, err.headers) + "Handling k8s.ApiException: status=%s reason=%s body=%s headers=%s", + err.status, + err.reason, + err.body, + err.headers, + ) code: int = err.status - body = err.body.lower() if err.body else '' + body = err.body.lower() if err.body else "" # 401 Unauthorized: token might be expired, attempt auth refresh. if code == 401: @@ -240,8 +256,10 @@ def _handle_api_exception( # 404 Not Found. Make it easier for the caller to handle 404s. if code == 404: - raise NotFound('Kubernetes API returned 404 Not Found: ' - f'{self._status_message_or_body(body)}') from err + raise NotFound( + "Kubernetes API returned 404 Not Found: " + f"{self._status_message_or_body(body)}" + ) from err # 409 Conflict # "Operation cannot be fulfilled on resourcequotas "foo": the object @@ -259,7 +277,7 @@ def _handle_api_exception( if code == 500: # Observed when using `kubectl proxy`. 
# "dial tcp 127.0.0.1:8080: connect: connection refused" - if 'connection refused' in body: + if "connection refused" in body: return _server_restart_retryer() # Known 500 errors that should be treated as 429: @@ -268,8 +286,10 @@ def _handle_api_exception( # to try again later # - Internal Server Error: "/api/v1/namespaces/foo/services": # the server is currently unable to handle the request - if ('too many requests' in body or - 'currently unable to handle the request' in body): + if ( + "too many requests" in body + or "currently unable to handle the request" in body + ): return _too_many_requests_retryer() # In other cases, just retry a few times in case the server @@ -286,7 +306,7 @@ def _handle_api_exception( @classmethod def _status_message_or_body(cls, body: str) -> str: try: - return str(json.loads(body)['message']) + return str(json.loads(body)["message"]) except (KeyError, ValueError): return body @@ -294,199 +314,259 @@ def create_single_resource(self, manifest): return self._execute(self._apply_manifest, manifest) def get_service(self, name) -> V1Service: - return self._get_resource(self._api.core.read_namespaced_service, name, - self.name) + return self._get_resource( + self._api.core.read_namespaced_service, name, self.name + ) def get_service_account(self, name) -> V1Service: return self._get_resource( - self._api.core.read_namespaced_service_account, name, self.name) - - def delete_service(self, - name, - grace_period_seconds=DELETE_GRACE_PERIOD_SEC): - self._execute(self._api.core.delete_namespaced_service, - name=name, - namespace=self.name, - body=client.V1DeleteOptions( - propagation_policy='Foreground', - grace_period_seconds=grace_period_seconds)) - - def delete_service_account(self, - name, - grace_period_seconds=DELETE_GRACE_PERIOD_SEC): - self._execute(self._api.core.delete_namespaced_service_account, - name=name, - namespace=self.name, - body=client.V1DeleteOptions( - propagation_policy='Foreground', - grace_period_seconds=grace_period_seconds)) + self._api.core.read_namespaced_service_account, name, self.name + ) + + def delete_service( + self, name, grace_period_seconds=DELETE_GRACE_PERIOD_SEC + ): + self._execute( + self._api.core.delete_namespaced_service, + name=name, + namespace=self.name, + body=client.V1DeleteOptions( + propagation_policy="Foreground", + grace_period_seconds=grace_period_seconds, + ), + ) + + def delete_service_account( + self, name, grace_period_seconds=DELETE_GRACE_PERIOD_SEC + ): + self._execute( + self._api.core.delete_namespaced_service_account, + name=name, + namespace=self.name, + body=client.V1DeleteOptions( + propagation_policy="Foreground", + grace_period_seconds=grace_period_seconds, + ), + ) def get(self) -> V1Namespace: return self._get_resource(self._api.core.read_namespace, self.name) def delete(self, grace_period_seconds=DELETE_GRACE_PERIOD_SEC): - self._execute(self._api.core.delete_namespace, - name=self.name, - body=client.V1DeleteOptions( - propagation_policy='Foreground', - grace_period_seconds=grace_period_seconds)) - - def wait_for_service_deleted(self, - name: str, - timeout_sec: int = WAIT_SHORT_TIMEOUT_SEC, - wait_sec: int = WAIT_SHORT_SLEEP_SEC) -> None: + self._execute( + self._api.core.delete_namespace, + name=self.name, + body=client.V1DeleteOptions( + propagation_policy="Foreground", + grace_period_seconds=grace_period_seconds, + ), + ) + + def wait_for_service_deleted( + self, + name: str, + timeout_sec: int = WAIT_SHORT_TIMEOUT_SEC, + wait_sec: int = WAIT_SHORT_SLEEP_SEC, + ) -> None: retryer = 
retryers.constant_retryer( wait_fixed=_timedelta(seconds=wait_sec), timeout=_timedelta(seconds=timeout_sec), - check_result=lambda service: service is None) + check_result=lambda service: service is None, + ) retryer(self.get_service, name) def wait_for_service_account_deleted( - self, - name: str, - timeout_sec: int = WAIT_SHORT_TIMEOUT_SEC, - wait_sec: int = WAIT_SHORT_SLEEP_SEC) -> None: + self, + name: str, + timeout_sec: int = WAIT_SHORT_TIMEOUT_SEC, + wait_sec: int = WAIT_SHORT_SLEEP_SEC, + ) -> None: retryer = retryers.constant_retryer( wait_fixed=_timedelta(seconds=wait_sec), timeout=_timedelta(seconds=timeout_sec), - check_result=lambda service_account: service_account is None) + check_result=lambda service_account: service_account is None, + ) retryer(self.get_service_account, name) - def wait_for_namespace_deleted(self, - timeout_sec: int = WAIT_LONG_TIMEOUT_SEC, - wait_sec: int = WAIT_LONG_SLEEP_SEC) -> None: + def wait_for_namespace_deleted( + self, + timeout_sec: int = WAIT_LONG_TIMEOUT_SEC, + wait_sec: int = WAIT_LONG_SLEEP_SEC, + ) -> None: retryer = retryers.constant_retryer( wait_fixed=_timedelta(seconds=wait_sec), timeout=_timedelta(seconds=timeout_sec), - check_result=lambda namespace: namespace is None) + check_result=lambda namespace: namespace is None, + ) retryer(self.get) - def wait_for_service_neg(self, - name: str, - timeout_sec: int = WAIT_SHORT_TIMEOUT_SEC, - wait_sec: int = WAIT_SHORT_SLEEP_SEC) -> None: + def wait_for_service_neg( + self, + name: str, + timeout_sec: int = WAIT_SHORT_TIMEOUT_SEC, + wait_sec: int = WAIT_SHORT_SLEEP_SEC, + ) -> None: timeout = _timedelta(seconds=timeout_sec) retryer = retryers.constant_retryer( wait_fixed=_timedelta(seconds=wait_sec), timeout=timeout, - check_result=self._check_service_neg_annotation) + check_result=self._check_service_neg_annotation, + ) try: retryer(self.get_service, name) except retryers.RetryError as e: logger.error( - 'Timeout %s (h:mm:ss) waiting for service %s to report NEG ' - 'status. Last service status:\n%s', timeout, name, - self._pretty_format_status(e.result())) + ( + "Timeout %s (h:mm:ss) waiting for service %s to report NEG " + "status. 
Last service status:\n%s" + ), + timeout, + name, + self._pretty_format_status(e.result()), + ) raise - def get_service_neg(self, service_name: str, - service_port: int) -> Tuple[str, List[str]]: + def get_service_neg( + self, service_name: str, service_port: int + ) -> Tuple[str, List[str]]: service = self.get_service(service_name) neg_info: dict = json.loads( - service.metadata.annotations[self.NEG_STATUS_META]) - neg_name: str = neg_info['network_endpoint_groups'][str(service_port)] - neg_zones: List[str] = neg_info['zones'] + service.metadata.annotations[self.NEG_STATUS_META] + ) + neg_name: str = neg_info["network_endpoint_groups"][str(service_port)] + neg_zones: List[str] = neg_info["zones"] return neg_name, neg_zones def get_deployment(self, name) -> V1Deployment: - return self._get_resource(self._api.apps.read_namespaced_deployment, - name, self.name) + return self._get_resource( + self._api.apps.read_namespaced_deployment, name, self.name + ) def delete_deployment( - self, - name: str, - grace_period_seconds: int = DELETE_GRACE_PERIOD_SEC) -> None: - self._execute(self._api.apps.delete_namespaced_deployment, - name=name, - namespace=self.name, - body=client.V1DeleteOptions( - propagation_policy='Foreground', - grace_period_seconds=grace_period_seconds)) + self, name: str, grace_period_seconds: int = DELETE_GRACE_PERIOD_SEC + ) -> None: + self._execute( + self._api.apps.delete_namespaced_deployment, + name=name, + namespace=self.name, + body=client.V1DeleteOptions( + propagation_policy="Foreground", + grace_period_seconds=grace_period_seconds, + ), + ) def list_deployment_pods(self, deployment: V1Deployment) -> List[V1Pod]: # V1LabelSelector.match_expressions not supported at the moment return self.list_pods_with_labels(deployment.spec.selector.match_labels) def wait_for_deployment_available_replicas( - self, - name: str, - count: int = 1, - timeout_sec: int = WAIT_MEDIUM_TIMEOUT_SEC, - wait_sec: int = WAIT_SHORT_SLEEP_SEC) -> None: + self, + name: str, + count: int = 1, + timeout_sec: int = WAIT_MEDIUM_TIMEOUT_SEC, + wait_sec: int = WAIT_SHORT_SLEEP_SEC, + ) -> None: timeout = _timedelta(seconds=timeout_sec) retryer = retryers.constant_retryer( wait_fixed=_timedelta(seconds=wait_sec), timeout=timeout, - check_result=lambda depl: self._replicas_available(depl, count)) + check_result=lambda depl: self._replicas_available(depl, count), + ) try: retryer(self.get_deployment, name) except retryers.RetryError as e: logger.error( - 'Timeout %s (h:mm:ss) waiting for deployment %s to report %i ' - 'replicas available. Last status:\n%s', timeout, name, count, - self._pretty_format_status(e.result())) + ( + "Timeout %s (h:mm:ss) waiting for deployment %s to report" + " %i replicas available. 
Last status:\n%s" + ), + timeout, + name, + count, + self._pretty_format_status(e.result()), + ) raise def wait_for_deployment_replica_count( - self, - deployment: V1Deployment, - count: int = 1, - *, - timeout_sec: int = WAIT_MEDIUM_TIMEOUT_SEC, - wait_sec: int = WAIT_SHORT_SLEEP_SEC) -> None: + self, + deployment: V1Deployment, + count: int = 1, + *, + timeout_sec: int = WAIT_MEDIUM_TIMEOUT_SEC, + wait_sec: int = WAIT_SHORT_SLEEP_SEC, + ) -> None: timeout = _timedelta(seconds=timeout_sec) retryer = retryers.constant_retryer( wait_fixed=_timedelta(seconds=wait_sec), timeout=timeout, - check_result=lambda pods: len(pods) == count) + check_result=lambda pods: len(pods) == count, + ) try: retryer(self.list_deployment_pods, deployment) except retryers.RetryError as e: result = e.result(default=[]) logger.error( - 'Timeout %s (h:mm:ss) waiting for pod count %i, got: %i. ' - 'Pod statuses:\n%s', timeout, count, len(result), - self._pretty_format_statuses(result)) + ( + "Timeout %s (h:mm:ss) waiting for pod count %i, got: %i. " + "Pod statuses:\n%s" + ), + timeout, + count, + len(result), + self._pretty_format_statuses(result), + ) raise def wait_for_deployment_deleted( - self, - deployment_name: str, - timeout_sec: int = WAIT_MEDIUM_TIMEOUT_SEC, - wait_sec: int = WAIT_MEDIUM_SLEEP_SEC) -> None: + self, + deployment_name: str, + timeout_sec: int = WAIT_MEDIUM_TIMEOUT_SEC, + wait_sec: int = WAIT_MEDIUM_SLEEP_SEC, + ) -> None: retryer = retryers.constant_retryer( wait_fixed=_timedelta(seconds=wait_sec), timeout=_timedelta(seconds=timeout_sec), - check_result=lambda deployment: deployment is None) + check_result=lambda deployment: deployment is None, + ) retryer(self.get_deployment, deployment_name) def list_pods_with_labels(self, labels: dict) -> List[V1Pod]: pod_list: V1PodList = self._execute( self._api.core.list_namespaced_pod, self.name, - label_selector=label_dict_to_selector(labels)) + label_selector=label_dict_to_selector(labels), + ) return pod_list.items def get_pod(self, name: str) -> V1Pod: - return self._get_resource(self._api.core.read_namespaced_pod, name, - self.name) + return self._get_resource( + self._api.core.read_namespaced_pod, name, self.name + ) - def wait_for_pod_started(self, - pod_name: str, - timeout_sec: int = WAIT_POD_START_TIMEOUT_SEC, - wait_sec: int = WAIT_SHORT_SLEEP_SEC) -> None: + def wait_for_pod_started( + self, + pod_name: str, + timeout_sec: int = WAIT_POD_START_TIMEOUT_SEC, + wait_sec: int = WAIT_SHORT_SLEEP_SEC, + ) -> None: timeout = _timedelta(seconds=timeout_sec) retryer = retryers.constant_retryer( wait_fixed=_timedelta(seconds=wait_sec), timeout=timeout, - check_result=self._pod_started) + check_result=self._pod_started, + ) try: retryer(self.get_pod, pod_name) except retryers.RetryError as e: logger.error( - 'Timeout %s (h:mm:ss) waiting for pod %s to start. ' - 'Pod status:\n%s', timeout, pod_name, - self._pretty_format_status(e.result())) + ( + "Timeout %s (h:mm:ss) waiting for pod %s to start. 
" + "Pod status:\n%s" + ), + timeout, + pod_name, + self._pretty_format_status(e.result()), + ) raise def port_forward_pod( @@ -496,20 +576,26 @@ def port_forward_pod( local_port: Optional[int] = None, local_address: Optional[str] = None, ) -> k8s_port_forwarder.PortForwarder: - pf = k8s_port_forwarder.PortForwarder(self._api.context, self.name, - f"pod/{pod.metadata.name}", - remote_port, local_port, - local_address) + pf = k8s_port_forwarder.PortForwarder( + self._api.context, + self.name, + f"pod/{pod.metadata.name}", + remote_port, + local_port, + local_address, + ) pf.connect() return pf - def pod_start_logging(self, - *, - pod_name: str, - log_path: pathlib.Path, - log_stop_event: threading.Event, - log_to_stdout: bool = False, - log_timestamps: bool = False) -> PodLogCollector: + def pod_start_logging( + self, + *, + pod_name: str, + log_path: pathlib.Path, + log_stop_event: threading.Event, + log_to_stdout: bool = False, + log_timestamps: bool = False, + ) -> PodLogCollector: pod_log_collector = PodLogCollector( pod_name=pod_name, namespace_name=self.name, @@ -517,40 +603,43 @@ def pod_start_logging(self, stop_event=log_stop_event, log_path=log_path, log_to_stdout=log_to_stdout, - log_timestamps=log_timestamps) + log_timestamps=log_timestamps, + ) pod_log_collector.start() return pod_log_collector - def _pretty_format_statuses(self, - k8s_objects: List[Optional[object]]) -> str: - return '\n'.join( - self._pretty_format_status(k8s_object) - for k8s_object in k8s_objects) + def _pretty_format_statuses( + self, k8s_objects: List[Optional[object]] + ) -> str: + return "\n".join( + self._pretty_format_status(k8s_object) for k8s_object in k8s_objects + ) def _pretty_format_status(self, k8s_object: Optional[object]) -> str: if k8s_object is None: - return 'No data' + return "No data" # Parse the name if present. - if hasattr(k8s_object, 'metadata') and hasattr(k8s_object.metadata, - 'name'): + if hasattr(k8s_object, "metadata") and hasattr( + k8s_object.metadata, "name" + ): name = k8s_object.metadata.name else: - name = 'Can\'t parse resource name' + name = "Can't parse resource name" # Pretty-print the status if present. - if hasattr(k8s_object, 'status'): + if hasattr(k8s_object, "status"): try: status = self._pretty_format(k8s_object.status.to_dict()) except Exception as e: # pylint: disable=broad-except # Catching all exceptions because not printing the status # isn't as important as the system under test. - status = f'Can\'t parse resource status: {e}' + status = f"Can't parse resource status: {e}" else: - status = 'Can\'t parse resource status' + status = "Can't parse resource status" # Return the name of k8s object, and its pretty-printed status. 
- return f'{name}:\n{status}\n' + return f"{name}:\n{status}\n" def _pretty_format(self, data: dict) -> str: """Return a string with pretty-printed yaml data from a python dict.""" @@ -558,18 +647,25 @@ def _pretty_format(self, data: dict) -> str: return self._highlighter.highlight(yaml_out) @classmethod - def _check_service_neg_annotation(cls, - service: Optional[V1Service]) -> bool: - return (isinstance(service, V1Service) and - cls.NEG_STATUS_META in service.metadata.annotations) + def _check_service_neg_annotation( + cls, service: Optional[V1Service] + ) -> bool: + return ( + isinstance(service, V1Service) + and cls.NEG_STATUS_META in service.metadata.annotations + ) @classmethod def _pod_started(cls, pod: V1Pod) -> bool: - return (isinstance(pod, V1Pod) and - pod.status.phase not in ('Pending', 'Unknown')) + return isinstance(pod, V1Pod) and pod.status.phase not in ( + "Pending", + "Unknown", + ) @classmethod def _replicas_available(cls, deployment: V1Deployment, count: int) -> bool: - return (isinstance(deployment, V1Deployment) and - deployment.status.available_replicas is not None and - deployment.status.available_replicas >= count) + return ( + isinstance(deployment, V1Deployment) + and deployment.status.available_replicas is not None + and deployment.status.available_replicas >= count + ) diff --git a/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/k8s_internal/k8s_log_collector.py b/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/k8s_internal/k8s_log_collector.py index 70976ab37a2ad..4f0c287633f34 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/k8s_internal/k8s_log_collector.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/k8s_internal/k8s_log_collector.py @@ -25,6 +25,7 @@ class PodLogCollector(threading.Thread): """A thread that streams logs from the remote pod to a local file.""" + pod_name: str namespace_name: str stop_event: threading.Event @@ -37,16 +38,18 @@ class PodLogCollector(threading.Thread): _watcher: Optional[watch.Watch] _read_pod_log_fn: Callable[..., Any] - def __init__(self, - *, - pod_name: str, - namespace_name: str, - read_pod_log_fn: Callable[..., Any], - stop_event: threading.Event, - log_path: pathlib.Path, - log_to_stdout: bool = False, - log_timestamps: bool = False, - error_backoff_sec: int = 5): + def __init__( + self, + *, + pod_name: str, + namespace_name: str, + read_pod_log_fn: Callable[..., Any], + stop_event: threading.Event, + log_path: pathlib.Path, + log_to_stdout: bool = False, + log_timestamps: bool = False, + error_backoff_sec: int = 5, + ): self.pod_name = pod_name self.namespace_name = namespace_name self.stop_event = stop_event @@ -61,16 +64,18 @@ def __init__(self, self._read_pod_log_fn = read_pod_log_fn self._out_stream = None self._watcher = None - super().__init__(name=f'pod-log-{pod_name}', daemon=True) + super().__init__(name=f"pod-log-{pod_name}", daemon=True) def run(self): - logger.info('Starting log collection thread %i for %s', self.ident, - self.pod_name) + logger.info( + "Starting log collection thread %i for %s", + self.ident, + self.pod_name, + ) try: - self._out_stream = open(self.log_path, - 'w', - errors='ignore', - encoding="utf-8") + self._out_stream = open( + self.log_path, "w", errors="ignore", encoding="utf-8" + ) while not self.stop_event.is_set(): self._stream_log() finally: @@ -87,8 +92,10 @@ def _stop(self): self._watcher.stop() self._watcher = None if self._out_stream is not None: - self._write(f'Finished log collection for pod 
{self.pod_name}', - force_flush=True) + self._write( + f"Finished log collection for pod {self.pod_name}", + force_flush=True, + ) self._out_stream.close() self._out_stream = None self.drain_event.set() @@ -99,10 +106,13 @@ def _stream_log(self): except client.ApiException as e: self._write(f"Exception fetching logs: {e}") self._write( - f'Restarting log fetching in {self.error_backoff_sec} sec. ' - f'Will attempt to read from the beginning, but log ' - f'truncation may occur.', - force_flush=True) + ( + f"Restarting log fetching in {self.error_backoff_sec} sec. " + "Will attempt to read from the beginning, but log " + "truncation may occur." + ), + force_flush=True, + ) finally: # Instead of time.sleep(), we're waiting on the stop event # in case it gets set earlier. @@ -110,11 +120,13 @@ def _stream_log(self): def _restart_stream(self): self._watcher = watch.Watch() - for msg in self._watcher.stream(self._read_pod_log_fn, - name=self.pod_name, - namespace=self.namespace_name, - timestamps=self.log_timestamps, - follow=True): + for msg in self._watcher.stream( + self._read_pod_log_fn, + name=self.pod_name, + namespace=self.namespace_name, + timestamps=self.log_timestamps, + follow=True, + ): self._write(msg) # Every message check if a stop is requested. if self.stop_event.is_set(): diff --git a/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/k8s_internal/k8s_port_forwarder.py b/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/k8s_internal/k8s_port_forwarder.py index a2c9e68424601..d43c9a4f0af60 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/k8s_internal/k8s_port_forwarder.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/k8s_internal/k8s_port_forwarder.py @@ -25,15 +25,17 @@ class PortForwardingError(Exception): class PortForwarder: - PORT_FORWARD_LOCAL_ADDRESS: str = '127.0.0.1' + PORT_FORWARD_LOCAL_ADDRESS: str = "127.0.0.1" - def __init__(self, - context: str, - namespace: str, - destination: str, - remote_port: int, - local_port: Optional[int] = None, - local_address: Optional[str] = None): + def __init__( + self, + context: str, + namespace: str, + destination: str, + remote_port: int, + local_port: Optional[int] = None, + local_address: Optional[str] = None, + ): self.context = context self.namespace = namespace self.destination = destination @@ -48,24 +50,36 @@ def connect(self) -> None: else: port_mapping = f":{self.remote_port}" cmd = [ - "kubectl", "--context", self.context, "--namespace", self.namespace, - "port-forward", "--address", self.local_address, self.destination, - port_mapping + "kubectl", + "--context", + self.context, + "--namespace", + self.namespace, + "port-forward", + "--address", + self.local_address, + self.destination, + port_mapping, ] - logger.debug('Executing port forwarding subprocess cmd: %s', - ' '.join(cmd)) - self.subprocess = subprocess.Popen(cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - universal_newlines=True) + logger.debug( + "Executing port forwarding subprocess cmd: %s", " ".join(cmd) + ) + self.subprocess = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + universal_newlines=True, + ) # Wait for stdout line indicating successful start. 
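(Stepping back to the log collector above: its _restart_stream loop is the standard kubernetes-client watch idiom. A minimal standalone version — pod and namespace names are placeholders — looks like this:)

    from kubernetes import client, config, watch

    config.load_kube_config()
    core_api = client.CoreV1Api()

    watcher = watch.Watch()
    for line in watcher.stream(
        core_api.read_namespaced_pod_log,
        name="psm-grpc-server",  # placeholder pod name
        namespace="default",  # placeholder namespace
        follow=True,
    ):
        print(line)
        # A real collector checks a stop event here and calls watcher.stop().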
if self.local_port: local_port_expected = ( f"Forwarding from {self.local_address}:{self.local_port}" - f" -> {self.remote_port}") + f" -> {self.remote_port}" + ) else: local_port_re = re.compile( - f"Forwarding from {self.local_address}:([0-9]+) -> {self.remote_port}" + f"Forwarding from {self.local_address}:([0-9]+) ->" + f" {self.remote_port}" ) try: while True: @@ -79,8 +93,9 @@ def connect(self) -> None: for error in self.subprocess.stdout.readlines() ] raise PortForwardingError( - 'Error forwarding port, kubectl return ' - f'code {return_code}, output {errors}') + "Error forwarding port, kubectl return " + f"code {return_code}, output {errors}" + ) # If there is no output, and the subprocess is not exiting, # continue waiting for the log line. continue @@ -89,13 +104,13 @@ def connect(self) -> None: if self.local_port: if output != local_port_expected: raise PortForwardingError( - f'Error forwarding port, unexpected output {output}' + f"Error forwarding port, unexpected output {output}" ) else: groups = local_port_re.search(output) if groups is None: raise PortForwardingError( - f'Error forwarding port, unexpected output {output}' + f"Error forwarding port, unexpected output {output}" ) # Update local port to the randomly picked one self.local_port = int(groups[1]) @@ -108,10 +123,11 @@ def connect(self) -> None: def close(self) -> None: if self.subprocess is not None: - logger.info('Shutting down port forwarding, pid %s', - self.subprocess.pid) + logger.info( + "Shutting down port forwarding, pid %s", self.subprocess.pid + ) self.subprocess.kill() stdout, _ = self.subprocess.communicate(timeout=5) - logger.info('Port forwarding stopped') - logger.debug('Port forwarding remaining stdout: %s', stdout) + logger.info("Port forwarding stopped") + logger.debug("Port forwarding remaining stdout: %s", stdout) self.subprocess = None diff --git a/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/traffic_director.py b/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/traffic_director.py index fbb791aa15bbd..2bc9d11716b22 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/traffic_director.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/traffic_director.py @@ -45,7 +45,7 @@ Mesh = gcp.network_services.Mesh # Testing metadata consts -TEST_AFFINITY_METADATA_KEY = 'xds_md' +TEST_AFFINITY_METADATA_KEY = "xds_md" class TrafficDirectorManager: # pylint: disable=too-many-public-methods @@ -73,13 +73,13 @@ def __init__( *, resource_prefix: str, resource_suffix: str, - network: str = 'default', - compute_api_version: str = 'v1', + network: str = "default", + compute_api_version: str = "v1", ): # API - self.compute = _ComputeV1(gcp_api_manager, - project, - version=compute_api_version) + self.compute = _ComputeV1( + gcp_api_manager, project, version=compute_api_version + ) # Settings self.project: str = project @@ -105,34 +105,39 @@ def __init__( self.alternative_backend_service: Optional[GcpResource] = None # TODO(sergiitk): remove this flag once backend service resource loaded self.alternative_backend_service_protocol: Optional[ - BackendServiceProtocol] = None + BackendServiceProtocol + ] = None self.alternative_backends: Set[ZonalGcpResource] = set() self.affinity_backend_service: Optional[GcpResource] = None # TODO(sergiitk): remove this flag once backend service resource loaded self.affinity_backend_service_protocol: Optional[ - BackendServiceProtocol] = None + BackendServiceProtocol + ] = None self.affinity_backends: 
Set[ZonalGcpResource] = set() @property def network_url(self): - return f'global/networks/{self.network}' + return f"global/networks/{self.network}" def setup_for_grpc( - self, - service_host, - service_port, - *, - backend_protocol: Optional[BackendServiceProtocol] = _BackendGRPC, - health_check_port: Optional[int] = None): - self.setup_backend_for_grpc(protocol=backend_protocol, - health_check_port=health_check_port) + self, + service_host, + service_port, + *, + backend_protocol: Optional[BackendServiceProtocol] = _BackendGRPC, + health_check_port: Optional[int] = None, + ): + self.setup_backend_for_grpc( + protocol=backend_protocol, health_check_port=health_check_port + ) self.setup_routing_rule_map_for_grpc(service_host, service_port) def setup_backend_for_grpc( - self, - *, - protocol: Optional[BackendServiceProtocol] = _BackendGRPC, - health_check_port: Optional[int] = None): + self, + *, + protocol: Optional[BackendServiceProtocol] = _BackendGRPC, + health_check_port: Optional[int] = None, + ): self.create_health_check(port=health_check_port) self.create_backend_service(protocol) @@ -162,16 +167,19 @@ def make_resource_name(self, name: str) -> str: # Avoid trailing dash when the suffix is empty. if self.resource_suffix: parts.append(self.resource_suffix) - return '-'.join(parts) + return "-".join(parts) def create_health_check( - self, - *, - protocol: Optional[HealthCheckProtocol] = _HealthCheckGRPC, - port: Optional[int] = None): + self, + *, + protocol: Optional[HealthCheckProtocol] = _HealthCheckGRPC, + port: Optional[int] = None, + ): if self.health_check: - raise ValueError(f'Health check {self.health_check.name} ' - 'already created, delete it first') + raise ValueError( + f"Health check {self.health_check.name} " + "already created, delete it first" + ) if protocol is None: protocol = _HealthCheckGRPC @@ -192,12 +200,13 @@ def delete_health_check(self, force=False): self.health_check = None def create_backend_service( - self, - protocol: Optional[BackendServiceProtocol] = _BackendGRPC, - subset_size: Optional[int] = None, - affinity_header: Optional[str] = None, - locality_lb_policies: Optional[List[dict]] = None, - outlier_detection: Optional[dict] = None): + self, + protocol: Optional[BackendServiceProtocol] = _BackendGRPC, + subset_size: Optional[int] = None, + affinity_header: Optional[str] = None, + locality_lb_policies: Optional[List[dict]] = None, + outlier_detection: Optional[dict] = None, + ): if protocol is None: protocol = _BackendGRPC @@ -210,7 +219,8 @@ def create_backend_service( subset_size=subset_size, affinity_header=affinity_header, locality_lb_policies=locality_lb_policies, - outlier_detection=outlier_detection) + outlier_detection=outlier_detection, + ) self.backend_service = resource self.backend_service_protocol = protocol @@ -230,57 +240,69 @@ def delete_backend_service(self, force=False): self.compute.delete_backend_service(name) self.backend_service = None - def backend_service_add_neg_backends(self, - name, - zones, - max_rate_per_endpoint: Optional[ - int] = None): - logger.info('Waiting for Network Endpoint Groups to load endpoints.') + def backend_service_add_neg_backends( + self, name, zones, max_rate_per_endpoint: Optional[int] = None + ): + logger.info("Waiting for Network Endpoint Groups to load endpoints.") for zone in zones: backend = self.compute.wait_for_network_endpoint_group(name, zone) - logger.info('Loaded NEG "%s" in zone %s', backend.name, - backend.zone) + logger.info( + 'Loaded NEG "%s" in zone %s', backend.name, backend.zone + ) 
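(The make_resource_name helper above simply dash-joins the configured prefix, the base name, and an optional suffix. A self-contained sketch — assuming parts starts as [resource_prefix, name], with made-up values — behaves like this:)

    def make_resource_name(resource_prefix: str, resource_suffix: str, name: str) -> str:
        parts = [resource_prefix, name]
        # Avoid a trailing dash when the suffix is empty.
        if resource_suffix:
            parts.append(resource_suffix)
        return "-".join(parts)

    assert make_resource_name("psm-interop", "20230609", "backend-service") == (
        "psm-interop-backend-service-20230609"
    )
    assert make_resource_name("psm-interop", "", "backend-service") == (
        "psm-interop-backend-service"
    )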
self.backends.add(backend) self.backend_service_patch_backends(max_rate_per_endpoint) def backend_service_remove_neg_backends(self, name, zones): - logger.info('Waiting for Network Endpoint Groups to load endpoints.') + logger.info("Waiting for Network Endpoint Groups to load endpoints.") for zone in zones: backend = self.compute.wait_for_network_endpoint_group(name, zone) - logger.info('Loaded NEG "%s" in zone %s', backend.name, - backend.zone) + logger.info( + 'Loaded NEG "%s" in zone %s', backend.name, backend.zone + ) self.backends.remove(backend) self.backend_service_patch_backends() def backend_service_patch_backends( - self, max_rate_per_endpoint: Optional[int] = None): - logging.info('Adding backends to Backend Service %s: %r', - self.backend_service.name, self.backends) - self.compute.backend_service_patch_backends(self.backend_service, - self.backends, - max_rate_per_endpoint) + self, max_rate_per_endpoint: Optional[int] = None + ): + logging.info( + "Adding backends to Backend Service %s: %r", + self.backend_service.name, + self.backends, + ) + self.compute.backend_service_patch_backends( + self.backend_service, self.backends, max_rate_per_endpoint + ) def backend_service_remove_all_backends(self): - logging.info('Removing backends from Backend Service %s', - self.backend_service.name) + logging.info( + "Removing backends from Backend Service %s", + self.backend_service.name, + ) self.compute.backend_service_remove_all_backends(self.backend_service) def wait_for_backends_healthy_status(self): logger.debug( "Waiting for Backend Service %s to report all backends healthy %r", - self.backend_service, self.backends) - self.compute.wait_for_backends_healthy_status(self.backend_service, - self.backends) + self.backend_service, + self.backends, + ) + self.compute.wait_for_backends_healthy_status( + self.backend_service, self.backends + ) def create_alternative_backend_service( - self, protocol: Optional[BackendServiceProtocol] = _BackendGRPC): + self, protocol: Optional[BackendServiceProtocol] = _BackendGRPC + ): if protocol is None: protocol = _BackendGRPC name = self.make_resource_name(self.ALTERNATIVE_BACKEND_SERVICE_NAME) - logger.info('Creating %s Alternative Backend Service "%s"', - protocol.name, name) + logger.info( + 'Creating %s Alternative Backend Service "%s"', protocol.name, name + ) resource = self.compute.create_backend_service_traffic_director( - name, health_check=self.health_check, protocol=protocol) + name, health_check=self.health_check, protocol=protocol + ) self.alternative_backend_service = resource self.alternative_backend_service_protocol = protocol @@ -292,7 +314,8 @@ def load_alternative_backend_service(self): def delete_alternative_backend_service(self, force=False): if force: name = self.make_resource_name( - self.ALTERNATIVE_BACKEND_SERVICE_NAME) + self.ALTERNATIVE_BACKEND_SERVICE_NAME + ) elif self.alternative_backend_service: name = self.alternative_backend_service.name else: @@ -302,46 +325,59 @@ def delete_alternative_backend_service(self, force=False): self.alternative_backend_service = None def alternative_backend_service_add_neg_backends(self, name, zones): - logger.info('Waiting for Network Endpoint Groups to load endpoints.') + logger.info("Waiting for Network Endpoint Groups to load endpoints.") for zone in zones: backend = self.compute.wait_for_network_endpoint_group(name, zone) - logger.info('Loaded NEG "%s" in zone %s', backend.name, - backend.zone) + logger.info( + 'Loaded NEG "%s" in zone %s', backend.name, backend.zone + ) 
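(The NEG name and zones handed to these *_add_neg_backends helpers come from the cloud.google.com/neg-status service annotation parsed by get_service_neg in k8s.py. A sketch with a made-up annotation payload in the same shape:)

    import json

    neg_status_annotation = json.dumps({
        "network_endpoint_groups": {"8080": "example-neg-8080"},
        "zones": ["us-central1-a", "us-central1-b"],
    })

    neg_info = json.loads(neg_status_annotation)
    neg_name = neg_info["network_endpoint_groups"]["8080"]
    neg_zones = neg_info["zones"]
    assert neg_name == "example-neg-8080"
    assert neg_zones == ["us-central1-a", "us-central1-b"]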
self.alternative_backends.add(backend) self.alternative_backend_service_patch_backends() def alternative_backend_service_patch_backends(self): - logging.info('Adding backends to Backend Service %s: %r', - self.alternative_backend_service.name, - self.alternative_backends) + logging.info( + "Adding backends to Backend Service %s: %r", + self.alternative_backend_service.name, + self.alternative_backends, + ) self.compute.backend_service_patch_backends( - self.alternative_backend_service, self.alternative_backends) + self.alternative_backend_service, self.alternative_backends + ) def alternative_backend_service_remove_all_backends(self): - logging.info('Removing backends from Backend Service %s', - self.alternative_backend_service.name) + logging.info( + "Removing backends from Backend Service %s", + self.alternative_backend_service.name, + ) self.compute.backend_service_remove_all_backends( - self.alternative_backend_service) + self.alternative_backend_service + ) def wait_for_alternative_backends_healthy_status(self): logger.debug( "Waiting for Backend Service %s to report all backends healthy %r", - self.alternative_backend_service, self.alternative_backends) + self.alternative_backend_service, + self.alternative_backends, + ) self.compute.wait_for_backends_healthy_status( - self.alternative_backend_service, self.alternative_backends) + self.alternative_backend_service, self.alternative_backends + ) def create_affinity_backend_service( - self, protocol: Optional[BackendServiceProtocol] = _BackendGRPC): + self, protocol: Optional[BackendServiceProtocol] = _BackendGRPC + ): if protocol is None: protocol = _BackendGRPC name = self.make_resource_name(self.AFFINITY_BACKEND_SERVICE_NAME) - logger.info('Creating %s Affinity Backend Service "%s"', protocol.name, - name) + logger.info( + 'Creating %s Affinity Backend Service "%s"', protocol.name, name + ) resource = self.compute.create_backend_service_traffic_director( name, health_check=self.health_check, protocol=protocol, - affinity_header=TEST_AFFINITY_METADATA_KEY) + affinity_header=TEST_AFFINITY_METADATA_KEY, + ) self.affinity_backend_service = resource self.affinity_backend_service_protocol = protocol @@ -362,32 +398,43 @@ def delete_affinity_backend_service(self, force=False): self.affinity_backend_service = None def affinity_backend_service_add_neg_backends(self, name, zones): - logger.info('Waiting for Network Endpoint Groups to load endpoints.') + logger.info("Waiting for Network Endpoint Groups to load endpoints.") for zone in zones: backend = self.compute.wait_for_network_endpoint_group(name, zone) - logger.info('Loaded NEG "%s" in zone %s', backend.name, - backend.zone) + logger.info( + 'Loaded NEG "%s" in zone %s', backend.name, backend.zone + ) self.affinity_backends.add(backend) self.affinity_backend_service_patch_backends() def affinity_backend_service_patch_backends(self): - logging.info('Adding backends to Backend Service %s: %r', - self.affinity_backend_service.name, self.affinity_backends) + logging.info( + "Adding backends to Backend Service %s: %r", + self.affinity_backend_service.name, + self.affinity_backends, + ) self.compute.backend_service_patch_backends( - self.affinity_backend_service, self.affinity_backends) + self.affinity_backend_service, self.affinity_backends + ) def affinity_backend_service_remove_all_backends(self): - logging.info('Removing backends from Backend Service %s', - self.affinity_backend_service.name) + logging.info( + "Removing backends from Backend Service %s", + self.affinity_backend_service.name, 
+ ) self.compute.backend_service_remove_all_backends( - self.affinity_backend_service) + self.affinity_backend_service + ) def wait_for_affinity_backends_healthy_status(self): logger.debug( "Waiting for Backend Service %s to report all backends healthy %r", - self.affinity_backend_service, self.affinity_backends) + self.affinity_backend_service, + self.affinity_backends, + ) self.compute.wait_for_backends_healthy_status( - self.affinity_backend_service, self.affinity_backends) + self.affinity_backend_service, self.affinity_backends + ) @staticmethod def _generate_url_map_body( @@ -400,46 +447,61 @@ def _generate_url_map_body( if dst_host_rule_match_backend_service is None: dst_host_rule_match_backend_service = dst_default_backend_service return { - 'name': - name, - 'defaultService': - dst_default_backend_service.url, - 'hostRules': [{ - 'hosts': src_hosts, - 'pathMatcher': matcher_name, - }], - 'pathMatchers': [{ - 'name': matcher_name, - 'defaultService': dst_host_rule_match_backend_service.url, - }], + "name": name, + "defaultService": dst_default_backend_service.url, + "hostRules": [ + { + "hosts": src_hosts, + "pathMatcher": matcher_name, + } + ], + "pathMatchers": [ + { + "name": matcher_name, + "defaultService": dst_host_rule_match_backend_service.url, + } + ], } def create_url_map(self, src_host: str, src_port: int) -> GcpResource: - src_address = f'{src_host}:{src_port}' + src_address = f"{src_host}:{src_port}" name = self.make_resource_name(self.URL_MAP_NAME) matcher_name = self.make_resource_name(self.URL_MAP_PATH_MATCHER_NAME) - logger.info('Creating URL map "%s": %s -> %s', name, src_address, - self.backend_service.name) + logger.info( + 'Creating URL map "%s": %s -> %s', + name, + src_address, + self.backend_service.name, + ) resource = self.compute.create_url_map_with_content( - self._generate_url_map_body(name, matcher_name, [src_address], - self.backend_service)) + self._generate_url_map_body( + name, matcher_name, [src_address], self.backend_service + ) + ) self.url_map = resource return resource - def patch_url_map(self, src_host: str, src_port: int, - backend_service: GcpResource): - src_address = f'{src_host}:{src_port}' + def patch_url_map( + self, src_host: str, src_port: int, backend_service: GcpResource + ): + src_address = f"{src_host}:{src_port}" name = self.make_resource_name(self.URL_MAP_NAME) matcher_name = self.make_resource_name(self.URL_MAP_PATH_MATCHER_NAME) - logger.info('Patching URL map "%s": %s -> %s', name, src_address, - backend_service.name) + logger.info( + 'Patching URL map "%s": %s -> %s', + name, + src_address, + backend_service.name, + ) self.compute.patch_url_map( self.url_map, - self._generate_url_map_body(name, matcher_name, [src_address], - backend_service)) + self._generate_url_map_body( + name, matcher_name, [src_address], backend_service + ), + ) def create_url_map_with_content(self, url_map_body: Any) -> GcpResource: - logger.info('Creating URL map: %s', url_map_body) + logger.info("Creating URL map: %s", url_map_body) resource = self.compute.create_url_map_with_content(url_map_body) self.url_map = resource return resource @@ -456,20 +518,27 @@ def delete_url_map(self, force=False): self.url_map = None def create_alternative_url_map( - self, - src_host: str, - src_port: int, - backend_service: Optional[GcpResource] = None) -> GcpResource: + self, + src_host: str, + src_port: int, + backend_service: Optional[GcpResource] = None, + ) -> GcpResource: name = self.make_resource_name(self.ALTERNATIVE_URL_MAP_NAME) - src_address = 
f'{src_host}:{src_port}' + src_address = f"{src_host}:{src_port}" matcher_name = self.make_resource_name(self.URL_MAP_PATH_MATCHER_NAME) if backend_service is None: backend_service = self.alternative_backend_service - logger.info('Creating alternative URL map "%s": %s -> %s', name, - src_address, backend_service.name) + logger.info( + 'Creating alternative URL map "%s": %s -> %s', + name, + src_address, + backend_service.name, + ) resource = self.compute.create_url_map_with_content( - self._generate_url_map_body(name, matcher_name, [src_address], - backend_service)) + self._generate_url_map_body( + name, matcher_name, [src_address], backend_service + ) + ) self.alternative_url_map = resource return resource @@ -487,18 +556,22 @@ def delete_alternative_url_map(self, force=False): def create_target_proxy(self): name = self.make_resource_name(self.TARGET_PROXY_NAME) if self.backend_service_protocol is BackendServiceProtocol.GRPC: - target_proxy_type = 'GRPC' + target_proxy_type = "GRPC" create_proxy_fn = self.compute.create_target_grpc_proxy self.target_proxy_is_http = False elif self.backend_service_protocol is BackendServiceProtocol.HTTP2: - target_proxy_type = 'HTTP' + target_proxy_type = "HTTP" create_proxy_fn = self.compute.create_target_http_proxy self.target_proxy_is_http = True else: - raise TypeError('Unexpected backend service protocol') + raise TypeError("Unexpected backend service protocol") - logger.info('Creating target %s proxy "%s" to URL map %s', name, - target_proxy_type, self.url_map.name) + logger.info( + 'Creating target %s proxy "%s" to URL map %s', + name, + target_proxy_type, + self.url_map.name, + ) self.target_proxy = create_proxy_fn(name, self.url_map) def delete_target_grpc_proxy(self, force=False): @@ -530,11 +603,16 @@ def create_alternative_target_proxy(self): if self.backend_service_protocol is BackendServiceProtocol.GRPC: logger.info( 'Creating alternative target GRPC proxy "%s" to URL map %s', - name, self.alternative_url_map.name) - self.alternative_target_proxy = self.compute.create_target_grpc_proxy( - name, self.alternative_url_map, False) + name, + self.alternative_url_map.name, + ) + self.alternative_target_proxy = ( + self.compute.create_target_grpc_proxy( + name, self.alternative_url_map, False + ) + ) else: - raise TypeError('Unexpected backend service protocol') + raise TypeError("Unexpected backend service protocol") def delete_alternative_target_grpc_proxy(self, force=False): if force: @@ -548,11 +626,12 @@ def delete_alternative_target_grpc_proxy(self, force=False): self.alternative_target_proxy = None def find_unused_forwarding_rule_port( - self, - *, - lo: int = 1024, # To avoid confusion, skip well-known ports. - hi: int = 65535, - attempts: int = 25) -> int: + self, + *, + lo: int = 1024, # To avoid confusion, skip well-known ports. 
+ hi: int = 65535, + attempts: int = 25, + ) -> int: for _ in range(attempts): src_port = random.randint(lo, hi) if not self.compute.exists_forwarding_rule(src_port): @@ -565,10 +644,14 @@ def create_forwarding_rule(self, src_port: int): src_port = int(src_port) logging.info( 'Creating forwarding rule "%s" in network "%s": 0.0.0.0:%s -> %s', - name, self.network, src_port, self.target_proxy.url) - resource = self.compute.create_forwarding_rule(name, src_port, - self.target_proxy, - self.network_url) + name, + self.network, + src_port, + self.target_proxy.url, + ) + resource = self.compute.create_forwarding_rule( + name, src_port, self.target_proxy, self.network_url + ) self.forwarding_rule = resource return resource @@ -583,28 +666,37 @@ def delete_forwarding_rule(self, force=False): self.compute.delete_forwarding_rule(name) self.forwarding_rule = None - def create_alternative_forwarding_rule(self, - src_port: int, - ip_address='0.0.0.0'): + def create_alternative_forwarding_rule( + self, src_port: int, ip_address="0.0.0.0" + ): name = self.make_resource_name(self.ALTERNATIVE_FORWARDING_RULE_NAME) src_port = int(src_port) logging.info( - 'Creating alternative forwarding rule "%s" in network "%s": %s:%s -> %s', - name, self.network, ip_address, src_port, - self.alternative_target_proxy.url) + ( + 'Creating alternative forwarding rule "%s" in network "%s":' + " %s:%s -> %s" + ), + name, + self.network, + ip_address, + src_port, + self.alternative_target_proxy.url, + ) resource = self.compute.create_forwarding_rule( name, src_port, self.alternative_target_proxy, self.network_url, - ip_address=ip_address) + ip_address=ip_address, + ) self.alternative_forwarding_rule = resource return resource def delete_alternative_forwarding_rule(self, force=False): if force: name = self.make_resource_name( - self.ALTERNATIVE_FORWARDING_RULE_NAME) + self.ALTERNATIVE_FORWARDING_RULE_NAME + ) elif self.alternative_forwarding_rule: name = self.alternative_forwarding_rule.name else: @@ -617,10 +709,16 @@ def create_firewall_rule(self, allowed_ports: List[str]): name = self.make_resource_name(self.FIREWALL_RULE_NAME) logging.info( 'Creating firewall rule "%s" in network "%s" with allowed ports %s', - name, self.network, allowed_ports) + name, + self.network, + allowed_ports, + ) resource = self.compute.create_firewall_rule( - name, self.network_url, xds_flags.FIREWALL_SOURCE_RANGE.value, - allowed_ports) + name, + self.network_url, + xds_flags.FIREWALL_SOURCE_RANGE.value, + allowed_ports, + ) self.firewall_rule = resource def delete_firewall_rule(self, force=False): @@ -637,26 +735,29 @@ def delete_firewall_rule(self, force=False): class TrafficDirectorAppNetManager(TrafficDirectorManager): - GRPC_ROUTE_NAME = "grpc-route" MESH_NAME = "mesh" netsvc: _NetworkServicesV1Alpha1 - def __init__(self, - gcp_api_manager: gcp.api.GcpApiManager, - project: str, - *, - resource_prefix: str, - resource_suffix: Optional[str] = None, - network: str = 'default', - compute_api_version: str = 'v1'): - super().__init__(gcp_api_manager, - project, - resource_prefix=resource_prefix, - resource_suffix=resource_suffix, - network=network, - compute_api_version=compute_api_version) + def __init__( + self, + gcp_api_manager: gcp.api.GcpApiManager, + project: str, + *, + resource_prefix: str, + resource_suffix: Optional[str] = None, + network: str = "default", + compute_api_version: str = "v1", + ): + super().__init__( + gcp_api_manager, + project, + resource_prefix=resource_prefix, + resource_suffix=resource_suffix, + network=network, + 
compute_api_version=compute_api_version, + ) # API self.netsvc = _NetworkServicesV1Alpha1(gcp_api_manager, project) @@ -682,25 +783,21 @@ def delete_mesh(self, force=False): name = self.mesh.name else: return - logger.info('Deleting Mesh %s', name) + logger.info("Deleting Mesh %s", name) self.netsvc.delete_mesh(name) self.mesh = None def create_grpc_route(self, src_host: str, src_port: int) -> GcpResource: - host = f'{src_host}:{src_port}' - service_name = self.netsvc.resource_full_name(self.backend_service.name, - "backendServices") + host = f"{src_host}:{src_port}" + service_name = self.netsvc.resource_full_name( + self.backend_service.name, "backendServices" + ) body = { "meshes": [self.mesh.url], - "hostnames": - host, - "rules": [{ - "action": { - "destinations": [{ - "serviceName": service_name - }] - } - }], + "hostnames": host, + "rules": [ + {"action": {"destinations": [{"serviceName": service_name}]}} + ], } name = self.make_resource_name(self.GRPC_ROUTE_NAME) logger.info("Creating GrpcRoute %s", name) @@ -724,7 +821,7 @@ def delete_grpc_route(self, force=False): name = self.grpc_route.name else: return - logger.info('Deleting GrpcRoute %s', name) + logger.info("Deleting GrpcRoute %s", name) self.netsvc.delete_grpc_route(name) self.grpc_route = None @@ -751,15 +848,17 @@ def __init__( *, resource_prefix: str, resource_suffix: Optional[str] = None, - network: str = 'default', - compute_api_version: str = 'v1', + network: str = "default", + compute_api_version: str = "v1", ): - super().__init__(gcp_api_manager, - project, - resource_prefix=resource_prefix, - resource_suffix=resource_suffix, - network=network, - compute_api_version=compute_api_version) + super().__init__( + gcp_api_manager, + project, + resource_prefix=resource_prefix, + resource_suffix=resource_suffix, + network=network, + compute_api_version=compute_api_version, + ) # API self.netsec = _NetworkSecurityV1Beta1(gcp_api_manager, project) @@ -771,27 +870,23 @@ def __init__( self.authz_policy: Optional[AuthorizationPolicy] = None self.endpoint_policy: Optional[EndpointPolicy] = None - def setup_server_security(self, - *, - server_namespace, - server_name, - server_port, - tls=True, - mtls=True): + def setup_server_security( + self, *, server_namespace, server_name, server_port, tls=True, mtls=True + ): self.create_server_tls_policy(tls=tls, mtls=mtls) - self.create_endpoint_policy(server_namespace=server_namespace, - server_name=server_name, - server_port=server_port) - - def setup_client_security(self, - *, - server_namespace, - server_name, - tls=True, - mtls=True): + self.create_endpoint_policy( + server_namespace=server_namespace, + server_name=server_name, + server_port=server_port, + ) + + def setup_client_security( + self, *, server_namespace, server_name, tls=True, mtls=True + ): self.create_client_tls_policy(tls=tls, mtls=mtls) - self.backend_service_apply_client_mtls_policy(server_namespace, - server_name) + self.backend_service_apply_client_mtls_policy( + server_namespace, server_name + ) def cleanup(self, *, force=False): # Cleanup in the reverse order of creation @@ -803,11 +898,15 @@ def cleanup(self, *, force=False): def create_server_tls_policy(self, *, tls, mtls): name = self.make_resource_name(self.SERVER_TLS_POLICY_NAME) - logger.info('Creating Server TLS Policy %s', name) + logger.info("Creating Server TLS Policy %s", name) if not tls and not mtls: logger.warning( - 'Server TLS Policy %s neither TLS, nor mTLS ' - 'policy. 
Skipping creation', name) + ( + "Server TLS Policy %s neither TLS, nor mTLS " + "policy. Skipping creation" + ), + name, + ) return certificate_provider = self._get_certificate_provider() @@ -821,7 +920,7 @@ def create_server_tls_policy(self, *, tls, mtls): self.netsec.create_server_tls_policy(name, policy) self.server_tls_policy = self.netsec.get_server_tls_policy(name) - logger.debug('Server TLS Policy loaded: %r', self.server_tls_policy) + logger.debug("Server TLS Policy loaded: %r", self.server_tls_policy) def delete_server_tls_policy(self, force=False): if force: @@ -830,13 +929,13 @@ def delete_server_tls_policy(self, force=False): name = self.server_tls_policy.name else: return - logger.info('Deleting Server TLS Policy %s', name) + logger.info("Deleting Server TLS Policy %s", name) self.netsec.delete_server_tls_policy(name) self.server_tls_policy = None def create_authz_policy(self, *, action: str, rules: list): name = self.make_resource_name(self.AUTHZ_POLICY_NAME) - logger.info('Creating Authz Policy %s', name) + logger.info("Creating Authz Policy %s", name) policy = { "action": action, "rules": rules, @@ -844,7 +943,7 @@ def create_authz_policy(self, *, action: str, rules: list): self.netsec.create_authz_policy(name, policy) self.authz_policy = self.netsec.get_authz_policy(name) - logger.debug('Authz Policy loaded: %r', self.authz_policy) + logger.debug("Authz Policy loaded: %r", self.authz_policy) def delete_authz_policy(self, force=False): if force: @@ -853,18 +952,21 @@ def delete_authz_policy(self, force=False): name = self.authz_policy.name else: return - logger.info('Deleting Authz Policy %s', name) + logger.info("Deleting Authz Policy %s", name) self.netsec.delete_authz_policy(name) self.authz_policy = None - def create_endpoint_policy(self, *, server_namespace: str, server_name: str, - server_port: int) -> None: + def create_endpoint_policy( + self, *, server_namespace: str, server_name: str, server_port: int + ) -> None: name = self.make_resource_name(self.ENDPOINT_POLICY) - logger.info('Creating Endpoint Policy %s', name) - endpoint_matcher_labels = [{ - "labelName": "app", - "labelValue": f"{server_namespace}-{server_name}" - }] + logger.info("Creating Endpoint Policy %s", name) + endpoint_matcher_labels = [ + { + "labelName": "app", + "labelValue": f"{server_namespace}-{server_name}", + } + ] port_selector = {"ports": [str(server_port)]} label_matcher_all = { "metadataLabelMatchCriteria": "MATCH_ALL", @@ -881,14 +983,18 @@ def create_endpoint_policy(self, *, server_namespace: str, server_name: str, config["serverTlsPolicy"] = self.server_tls_policy.name else: logger.warning( - 'Creating Endpoint Policy %s with ' - 'no Server TLS policy attached', name) + ( + "Creating Endpoint Policy %s with " + "no Server TLS policy attached" + ), + name, + ) if self.authz_policy: config["authorizationPolicy"] = self.authz_policy.name self.netsvc.create_endpoint_policy(name, config) self.endpoint_policy = self.netsvc.get_endpoint_policy(name) - logger.debug('Loaded Endpoint Policy: %r', self.endpoint_policy) + logger.debug("Loaded Endpoint Policy: %r", self.endpoint_policy) def delete_endpoint_policy(self, force: bool = False) -> None: if force: @@ -897,17 +1003,21 @@ def delete_endpoint_policy(self, force: bool = False) -> None: name = self.endpoint_policy.name else: return - logger.info('Deleting Endpoint Policy %s', name) + logger.info("Deleting Endpoint Policy %s", name) self.netsvc.delete_endpoint_policy(name) self.endpoint_policy = None def create_client_tls_policy(self, *, 
tls, mtls): name = self.make_resource_name(self.CLIENT_TLS_POLICY_NAME) - logger.info('Creating Client TLS Policy %s', name) + logger.info("Creating Client TLS Policy %s", name) if not tls and not mtls: logger.warning( - 'Client TLS Policy %s neither TLS, nor mTLS ' - 'policy. Skipping creation', name) + ( + "Client TLS Policy %s neither TLS, nor mTLS " + "policy. Skipping creation" + ), + name, + ) return certificate_provider = self._get_certificate_provider() @@ -919,7 +1029,7 @@ def create_client_tls_policy(self, *, tls, mtls): self.netsec.create_client_tls_policy(name, policy) self.client_tls_policy = self.netsec.get_client_tls_policy(name) - logger.debug('Client TLS Policy loaded: %r', self.client_tls_policy) + logger.debug("Client TLS Policy loaded: %r", self.client_tls_policy) def delete_client_tls_policy(self, force=False): if force: @@ -928,7 +1038,7 @@ def delete_client_tls_policy(self, force=False): name = self.client_tls_policy.name else: return - logger.info('Deleting Client TLS Policy %s', name) + logger.info("Deleting Client TLS Policy %s", name) self.netsec.delete_client_tls_policy(name) self.client_tls_policy = None @@ -939,25 +1049,34 @@ def backend_service_apply_client_mtls_policy( ): if not self.client_tls_policy: logger.warning( - 'Client TLS policy not created, ' - 'skipping attaching to Backend Service %s', - self.backend_service.name) + ( + "Client TLS policy not created, " + "skipping attaching to Backend Service %s" + ), + self.backend_service.name, + ) return - server_spiffe = (f'spiffe://{self.project}.svc.id.goog/' - f'ns/{server_namespace}/sa/{server_name}') + server_spiffe = ( + f"spiffe://{self.project}.svc.id.goog/" + f"ns/{server_namespace}/sa/{server_name}" + ) logging.info( - 'Adding Client TLS Policy to Backend Service %s: %s, ' - 'server %s', self.backend_service.name, self.client_tls_policy.url, - server_spiffe) + "Adding Client TLS Policy to Backend Service %s: %s, server %s", + self.backend_service.name, + self.client_tls_policy.url, + server_spiffe, + ) self.compute.patch_backend_service( - self.backend_service, { - 'securitySettings': { - 'clientTlsPolicy': self.client_tls_policy.url, - 'subjectAltNames': [server_spiffe] + self.backend_service, + { + "securitySettings": { + "clientTlsPolicy": self.client_tls_policy.url, + "subjectAltNames": [server_spiffe], } - }) + }, + ) @classmethod def _get_certificate_provider(cls): diff --git a/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc.py b/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc.py index 65bd61430b652..d58ab5b8cd1b8 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc.py @@ -36,24 +36,28 @@ class GrpcClientHelper: # or port forwarding, this still is set to a useful name. 
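(For context on backend_service_apply_client_mtls_policy above: the subject alternative name it installs is the workload's SPIFFE ID, built from the GCP project plus the server's namespace and name, the latter doubling as the service account segment. All values below are made up:)

    project = "my-gcp-project"
    server_namespace = "psm-interop-server-20230609"
    server_name = "psm-grpc-server"

    server_spiffe = (
        f"spiffe://{project}.svc.id.goog/"
        f"ns/{server_namespace}/sa/{server_name}"
    )
    assert server_spiffe == (
        "spiffe://my-gcp-project.svc.id.goog/"
        "ns/psm-interop-server-20230609/sa/psm-grpc-server"
    )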
log_target: str - def __init__(self, - channel: grpc.Channel, - stub_class: Any, - *, - log_target: Optional[str] = ''): + def __init__( + self, + channel: grpc.Channel, + stub_class: Any, + *, + log_target: Optional[str] = "", + ): self.channel = channel self.stub = stub_class(channel) - self.log_service_name = re.sub('Stub$', '', - self.stub.__class__.__name__) - self.log_target = log_target or '' + self.log_service_name = re.sub( + "Stub$", "", self.stub.__class__.__name__ + ) + self.log_target = log_target or "" def call_unary_with_deadline( - self, - *, - rpc: str, - req: Message, - deadline_sec: Optional[int] = DEFAULT_RPC_DEADLINE_SEC, - log_level: Optional[int] = logging.DEBUG) -> Message: + self, + *, + rpc: str, + req: Message, + deadline_sec: Optional[int] = DEFAULT_RPC_DEADLINE_SEC, + log_level: Optional[int] = logging.DEBUG, + ) -> Message: if deadline_sec is None: deadline_sec = self.DEFAULT_RPC_DEADLINE_SEC @@ -65,11 +69,16 @@ def call_unary_with_deadline( return rpc_callable(req, **call_kwargs) def _log_rpc_request(self, rpc, req, call_kwargs, log_level=logging.DEBUG): - logger.log(logging.DEBUG if log_level is None else log_level, - '[%s] >> RPC %s.%s(request=%s(%r), %s)', self.log_target, - self.log_service_name, rpc, req.__class__.__name__, - json_format.MessageToDict(req), - ', '.join({f'{k}={v}' for k, v in call_kwargs.items()})) + logger.log( + logging.DEBUG if log_level is None else log_level, + "[%s] >> RPC %s.%s(request=%s(%r), %s)", + self.log_target, + self.log_service_name, + rpc, + req.__class__.__name__, + json_format.MessageToDict(req), + ", ".join({f"{k}={v}" for k, v in call_kwargs.items()}), + ) class GrpcApp: @@ -89,7 +98,7 @@ def __init__(self, rpc_host): def _make_channel(self, port) -> grpc.Channel: if port not in self.channels: - target = f'{self.rpc_host}:{port}' + target = f"{self.rpc_host}:{port}" self.channels[port] = grpc.insecure_channel(target) return self.channels[port] diff --git a/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_channelz.py b/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_channelz.py index 92c8f79314fdd..29de0ca195052 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_channelz.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_channelz.py @@ -57,17 +57,16 @@ class ChannelzServiceClient(framework.rpc.grpc.GrpcClientHelper): stub: channelz_pb2_grpc.ChannelzStub - def __init__(self, - channel: grpc.Channel, - *, - log_target: Optional[str] = ''): - super().__init__(channel, - channelz_pb2_grpc.ChannelzStub, - log_target=log_target) + def __init__( + self, channel: grpc.Channel, *, log_target: Optional[str] = "" + ): + super().__init__( + channel, channelz_pb2_grpc.ChannelzStub, log_target=log_target + ) @staticmethod def is_sock_tcpip_address(address: Address): - return address.WhichOneof('address') == 'tcpip_address' + return address.WhichOneof("address") == "tcpip_address" @staticmethod def is_ipv4(tcpip_address: Address.TcpIpAddress): @@ -83,18 +82,21 @@ def sock_address_to_str(cls, address: Address): ip = ipaddress.IPv4Address(tcpip_address.ip_address) else: ip = ipaddress.IPv6Address(tcpip_address.ip_address) - return f'{ip}:{tcpip_address.port}' + return f"{ip}:{tcpip_address.port}" else: - raise NotImplementedError('Only tcpip_address implemented') + raise NotImplementedError("Only tcpip_address implemented") @classmethod def sock_addresses_pretty(cls, socket: Socket): - return (f'local={cls.sock_address_to_str(socket.local)}, ' - 
f'remote={cls.sock_address_to_str(socket.remote)}') + return ( + f"local={cls.sock_address_to_str(socket.local)}, " + f"remote={cls.sock_address_to_str(socket.remote)}" + ) @staticmethod - def find_server_socket_matching_client(server_sockets: Iterator[Socket], - client_socket: Socket) -> Socket: + def find_server_socket_matching_client( + server_sockets: Iterator[Socket], client_socket: Socket + ) -> Socket: for server_socket in server_sockets: if server_socket.remote == client_socket.local: return server_socket @@ -102,35 +104,43 @@ def find_server_socket_matching_client(server_sockets: Iterator[Socket], @staticmethod def channel_repr(channel: Channel) -> str: - result = f'<Channel channel_id={channel.ref.channel_id}' - result += f' target={channel.data.target}' - result += f' state={ChannelState.Name(channel.data.state.state)}>' + result = f"<Channel channel_id={channel.ref.channel_id}" + result += f" target={channel.data.target}" + result += f" state={ChannelState.Name(channel.data.state.state)}>" return result @staticmethod def subchannel_repr(subchannel: Subchannel) -> str: - result = f'<Subchannel subchannel_id={subchannel.ref.subchannel_id}' - result += f' target={subchannel.data.target}' - result += f' state={ChannelState.Name(subchannel.data.state.state)}>' + result = f"<Subchannel subchannel_id={subchannel.ref.subchannel_id}" + result += f" target={subchannel.data.target}" + result += f" state={ChannelState.Name(subchannel.data.state.state)}>" return result - def find_channels_for_target(self, target: str, - **kwargs) -> Iterator[Channel]: - return (channel for channel in self.list_channels(**kwargs) - if channel.data.target == target) - - def find_server_listening_on_port(self, port: int, - **kwargs) -> Optional[Server]: + def find_channels_for_target( + self, target: str, **kwargs + ) -> Iterator[Channel]: + return ( + channel + for channel in self.list_channels(**kwargs) + if channel.data.target == target + ) + + def find_server_listening_on_port( + self, port: int, **kwargs + ) -> Optional[Server]: for server in self.list_servers(**kwargs): listen_socket_ref: SocketRef for listen_socket_ref in server.listen_socket: - listen_socket = self.get_socket(listen_socket_ref.socket_id, - **kwargs) + listen_socket = self.get_socket( + listen_socket_ref.socket_id, **kwargs + ) listen_address: Address = listen_socket.local - if (self.is_sock_tcpip_address(listen_address) and - listen_address.tcpip_address.port == port): + if ( + self.is_sock_tcpip_address(listen_address) + and listen_address.tcpip_address.port == port + ): return server return None @@ -148,9 +158,10 @@ def list_channels(self, **kwargs) -> Iterator[Channel]: # value by adding 1 to the highest seen result ID. start += 1 response = self.call_unary_with_deadline( - rpc='GetTopChannels', + rpc="GetTopChannels", req=_GetTopChannelsRequest(start_channel_id=start), - **kwargs) + **kwargs, + ) for channel in response.channel: start = max(start, channel.ref.channel_id) yield channel @@ -164,9 +175,10 @@ def list_servers(self, **kwargs) -> Iterator[Server]: # value by adding 1 to the highest seen result ID. start += 1 response = self.call_unary_with_deadline( - rpc='GetServers', + rpc="GetServers", req=_GetServersRequest(start_server_id=start), - **kwargs) + **kwargs, + ) for server in response.server: start = max(start, server.ref.server_id) yield server @@ -183,30 +195,35 @@ def list_server_sockets(self, server: Server, **kwargs) -> Iterator[Socket]: # value by adding 1 to the highest seen result ID.
start += 1 response = self.call_unary_with_deadline( - rpc='GetServerSockets', - req=_GetServerSocketsRequest(server_id=server.ref.server_id, - start_socket_id=start), - **kwargs) + rpc="GetServerSockets", + req=_GetServerSocketsRequest( + server_id=server.ref.server_id, start_socket_id=start + ), + **kwargs, + ) socket_ref: SocketRef for socket_ref in response.socket_ref: start = max(start, socket_ref.socket_id) # Yield actual socket yield self.get_socket(socket_ref.socket_id, **kwargs) - def list_channel_sockets(self, channel: Channel, - **kwargs) -> Iterator[Socket]: + def list_channel_sockets( + self, channel: Channel, **kwargs + ) -> Iterator[Socket]: """List all sockets of all subchannels of a given channel.""" for subchannel in self.list_channel_subchannels(channel, **kwargs): yield from self.list_subchannels_sockets(subchannel, **kwargs) - def list_channel_subchannels(self, channel: Channel, - **kwargs) -> Iterator[Subchannel]: + def list_channel_subchannels( + self, channel: Channel, **kwargs + ) -> Iterator[Subchannel]: """List all subchannels of a given channel.""" for subchannel_ref in channel.subchannel_ref: yield self.get_subchannel(subchannel_ref.subchannel_id, **kwargs) - def list_subchannels_sockets(self, subchannel: Subchannel, - **kwargs) -> Iterator[Socket]: + def list_subchannels_sockets( + self, subchannel: Subchannel, **kwargs + ) -> Iterator[Socket]: """List all sockets of a given subchannel.""" for socket_ref in subchannel.socket_ref: yield self.get_socket(socket_ref.socket_id, **kwargs) @@ -214,15 +231,17 @@ def list_subchannels_sockets(self, subchannel: Subchannel, def get_subchannel(self, subchannel_id, **kwargs) -> Subchannel: """Return a single Subchannel, otherwise raises RpcError.""" response: _GetSubchannelResponse = self.call_unary_with_deadline( - rpc='GetSubchannel', + rpc="GetSubchannel", req=_GetSubchannelRequest(subchannel_id=subchannel_id), - **kwargs) + **kwargs, + ) return response.subchannel def get_socket(self, socket_id, **kwargs) -> Socket: """Return a single Socket, otherwise raises RpcError.""" response: _GetSocketResponse = self.call_unary_with_deadline( - rpc='GetSocket', + rpc="GetSocket", req=_GetSocketRequest(socket_id=socket_id), - **kwargs) + **kwargs, + ) return response.socket diff --git a/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_csds.py b/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_csds.py index 947f88d8fb540..a3895fe89ef8e 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_csds.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_csds.py @@ -25,8 +25,10 @@ from envoy.extensions.filters.common.fault.v3 import fault_pb2 as _ from envoy.extensions.filters.http.fault.v3 import fault_pb2 as _ from envoy.extensions.filters.http.router.v3 import router_pb2 as _ -from envoy.extensions.filters.network.http_connection_manager.v3 import \ - http_connection_manager_pb2 as _ +from envoy.extensions.filters.network.http_connection_manager.v3 import ( + http_connection_manager_pb2 as _, +) + # pylint: enable=unused-import from envoy.service.status.v3 import csds_pb2 from envoy.service.status.v3 import csds_pb2_grpc @@ -44,21 +46,23 @@ class CsdsClient(framework.rpc.grpc.GrpcClientHelper): stub: csds_pb2_grpc.ClientStatusDiscoveryServiceStub - def __init__(self, - channel: grpc.Channel, - *, - log_target: Optional[str] = ''): - super().__init__(channel, - csds_pb2_grpc.ClientStatusDiscoveryServiceStub, - log_target=log_target) + def __init__( + self, channel: grpc.Channel, *, log_target: 
Optional[str] = "" + ): + super().__init__( + channel, + csds_pb2_grpc.ClientStatusDiscoveryServiceStub, + log_target=log_target, + ) def fetch_client_status(self, **kwargs) -> Optional[ClientConfig]: """Fetches the active xDS configurations.""" - response = self.call_unary_with_deadline(rpc='FetchClientStatus', - req=_ClientStatusRequest(), - **kwargs) + response = self.call_unary_with_deadline( + rpc="FetchClientStatus", req=_ClientStatusRequest(), **kwargs + ) if len(response.config) != 1: - logger.debug('Unexpected number of client configs: %s', - len(response.config)) + logger.debug( + "Unexpected number of client configs: %s", len(response.config) + ) return None return response.config[0] diff --git a/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_testing.py b/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_testing.py index ead41443bfea4..be50a79be892f 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_testing.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_testing.py @@ -30,8 +30,12 @@ # Type aliases _LoadBalancerStatsRequest = messages_pb2.LoadBalancerStatsRequest LoadBalancerStatsResponse = messages_pb2.LoadBalancerStatsResponse -_LoadBalancerAccumulatedStatsRequest = messages_pb2.LoadBalancerAccumulatedStatsRequest -LoadBalancerAccumulatedStatsResponse = messages_pb2.LoadBalancerAccumulatedStatsResponse +_LoadBalancerAccumulatedStatsRequest = ( + messages_pb2.LoadBalancerAccumulatedStatsRequest +) +LoadBalancerAccumulatedStatsResponse = ( + messages_pb2.LoadBalancerAccumulatedStatsResponse +) MethodStats = messages_pb2.LoadBalancerAccumulatedStatsResponse.MethodStats RpcsByPeer = messages_pb2.LoadBalancerStatsResponse.RpcsByPeer @@ -41,13 +45,14 @@ class LoadBalancerStatsServiceClient(framework.rpc.grpc.GrpcClientHelper): STATS_PARTIAL_RESULTS_TIMEOUT_SEC = 1200 STATS_ACCUMULATED_RESULTS_TIMEOUT_SEC = 600 - def __init__(self, - channel: grpc.Channel, - *, - log_target: Optional[str] = ''): - super().__init__(channel, - test_pb2_grpc.LoadBalancerStatsServiceStub, - log_target=log_target) + def __init__( + self, channel: grpc.Channel, *, log_target: Optional[str] = "" + ): + super().__init__( + channel, + test_pb2_grpc.LoadBalancerStatsServiceStub, + log_target=log_target, + ) def get_client_stats( self, @@ -58,40 +63,43 @@ def get_client_stats( if timeout_sec is None: timeout_sec = self.STATS_PARTIAL_RESULTS_TIMEOUT_SEC - return self.call_unary_with_deadline(rpc='GetClientStats', - req=_LoadBalancerStatsRequest( - num_rpcs=num_rpcs, - timeout_sec=timeout_sec), - deadline_sec=timeout_sec, - log_level=logging.INFO) + return self.call_unary_with_deadline( + rpc="GetClientStats", + req=_LoadBalancerStatsRequest( + num_rpcs=num_rpcs, timeout_sec=timeout_sec + ), + deadline_sec=timeout_sec, + log_level=logging.INFO, + ) def get_client_accumulated_stats( - self, - *, - timeout_sec: Optional[int] = None + self, *, timeout_sec: Optional[int] = None ) -> LoadBalancerAccumulatedStatsResponse: if timeout_sec is None: timeout_sec = self.STATS_ACCUMULATED_RESULTS_TIMEOUT_SEC return self.call_unary_with_deadline( - rpc='GetClientAccumulatedStats', + rpc="GetClientAccumulatedStats", req=_LoadBalancerAccumulatedStatsRequest(), deadline_sec=timeout_sec, - log_level=logging.INFO) + log_level=logging.INFO, + ) -class XdsUpdateClientConfigureServiceClient(framework.rpc.grpc.GrpcClientHelper - ): +class XdsUpdateClientConfigureServiceClient( + framework.rpc.grpc.GrpcClientHelper +): stub: test_pb2_grpc.XdsUpdateClientConfigureServiceStub 
CONFIGURE_TIMEOUT_SEC: int = 5 - def __init__(self, - channel: grpc.Channel, - *, - log_target: Optional[str] = ''): - super().__init__(channel, - test_pb2_grpc.XdsUpdateClientConfigureServiceStub, - log_target=log_target) + def __init__( + self, channel: grpc.Channel, *, log_target: Optional[str] = "" + ): + super().__init__( + channel, + test_pb2_grpc.XdsUpdateClientConfigureServiceStub, + log_target=log_target, + ) def configure( self, @@ -104,54 +112,62 @@ def configure( request = messages_pb2.ClientConfigureRequest() for rpc_type in rpc_types: request.types.append( - messages_pb2.ClientConfigureRequest.RpcType.Value(rpc_type)) + messages_pb2.ClientConfigureRequest.RpcType.Value(rpc_type) + ) if metadata: for entry in metadata: request.metadata.append( messages_pb2.ClientConfigureRequest.Metadata( type=messages_pb2.ClientConfigureRequest.RpcType.Value( - entry[0]), + entry[0] + ), key=entry[1], value=entry[2], - )) + ) + ) if app_timeout: request.timeout_sec = app_timeout # Configure's response is empty - self.call_unary_with_deadline(rpc='Configure', - req=request, - deadline_sec=timeout_sec, - log_level=logging.INFO) + self.call_unary_with_deadline( + rpc="Configure", + req=request, + deadline_sec=timeout_sec, + log_level=logging.INFO, + ) class XdsUpdateHealthServiceClient(framework.rpc.grpc.GrpcClientHelper): stub: test_pb2_grpc.XdsUpdateHealthServiceStub - def __init__(self, channel: grpc.Channel, log_target: Optional[str] = ''): - super().__init__(channel, - test_pb2_grpc.XdsUpdateHealthServiceStub, - log_target=log_target) + def __init__(self, channel: grpc.Channel, log_target: Optional[str] = ""): + super().__init__( + channel, + test_pb2_grpc.XdsUpdateHealthServiceStub, + log_target=log_target, + ) def set_serving(self): - self.call_unary_with_deadline(rpc='SetServing', - req=empty_pb2.Empty(), - log_level=logging.INFO) + self.call_unary_with_deadline( + rpc="SetServing", req=empty_pb2.Empty(), log_level=logging.INFO + ) def set_not_serving(self): - self.call_unary_with_deadline(rpc='SetNotServing', - req=empty_pb2.Empty(), - log_level=logging.INFO) + self.call_unary_with_deadline( + rpc="SetNotServing", req=empty_pb2.Empty(), log_level=logging.INFO + ) class HealthClient(framework.rpc.grpc.GrpcClientHelper): stub: health_pb2_grpc.HealthStub - def __init__(self, channel: grpc.Channel, log_target: Optional[str] = ''): - super().__init__(channel, - health_pb2_grpc.HealthStub, - log_target=log_target) + def __init__(self, channel: grpc.Channel, log_target: Optional[str] = ""): + super().__init__( + channel, health_pb2_grpc.HealthStub, log_target=log_target + ) def check_health(self): return self.call_unary_with_deadline( - rpc='Check', + rpc="Check", req=health_pb2.HealthCheckRequest(), - log_level=logging.INFO) + log_level=logging.INFO, + ) diff --git a/tools/run_tests/xds_k8s_test_driver/framework/test_app/client_app.py b/tools/run_tests/xds_k8s_test_driver/framework/test_app/client_app.py index 38c8d9a90b2e3..073f9de61be7d 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/test_app/client_app.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/test_app/client_app.py @@ -30,7 +30,9 @@ # Type aliases _timedelta = datetime.timedelta _LoadBalancerStatsServiceClient = grpc_testing.LoadBalancerStatsServiceClient -_XdsUpdateClientConfigureServiceClient = grpc_testing.XdsUpdateClientConfigureServiceClient +_XdsUpdateClientConfigureServiceClient = ( + grpc_testing.XdsUpdateClientConfigureServiceClient +) _ChannelzServiceClient = grpc_channelz.ChannelzServiceClient 
_ChannelzChannel = grpc_channelz.Channel _ChannelzChannelState = grpc_channelz.ChannelState @@ -44,17 +46,20 @@ class XdsTestClient(framework.rpc.grpc.GrpcApp): Represents RPC services implemented in Client component of the xds test app. https://github.com/grpc/grpc/blob/master/doc/xds-test-descriptions.md#client """ + # A unique string identifying each client replica. Used in logging. hostname: str - def __init__(self, - *, - ip: str, - rpc_port: int, - server_target: str, - hostname: str, - rpc_host: Optional[str] = None, - maintenance_port: Optional[int] = None): + def __init__( + self, + *, + ip: str, + rpc_port: int, + server_target: str, + hostname: str, + rpc_host: Optional[str] = None, + maintenance_port: Optional[int] = None, + ): super().__init__(rpc_host=(rpc_host or ip)) self.ip = ip self.rpc_port = rpc_port @@ -67,28 +72,32 @@ def __init__(self, def load_balancer_stats(self) -> _LoadBalancerStatsServiceClient: return _LoadBalancerStatsServiceClient( self._make_channel(self.rpc_port), - log_target=f'{self.hostname}:{self.rpc_port}') + log_target=f"{self.hostname}:{self.rpc_port}", + ) @property @functools.lru_cache(None) def update_config(self): return _XdsUpdateClientConfigureServiceClient( self._make_channel(self.rpc_port), - log_target=f'{self.hostname}:{self.rpc_port}') + log_target=f"{self.hostname}:{self.rpc_port}", + ) @property @functools.lru_cache(None) def channelz(self) -> _ChannelzServiceClient: return _ChannelzServiceClient( self._make_channel(self.maintenance_port), - log_target=f'{self.hostname}:{self.maintenance_port}') + log_target=f"{self.hostname}:{self.maintenance_port}", + ) @property @functools.lru_cache(None) def csds(self) -> _CsdsClient: return _CsdsClient( self._make_channel(self.maintenance_port), - log_target=f'{self.hostname}:{self.maintenance_port}') + log_target=f"{self.hostname}:{self.maintenance_port}", + ) def get_load_balancer_stats( self, @@ -100,7 +109,8 @@ def get_load_balancer_stats( Shortcut to LoadBalancerStatsServiceClient.get_client_stats() """ return self.load_balancer_stats.get_client_stats( - num_rpcs=num_rpcs, timeout_sec=timeout_sec) + num_rpcs=num_rpcs, timeout_sec=timeout_sec + ) def get_load_balancer_accumulated_stats( self, @@ -109,7 +119,8 @@ def get_load_balancer_accumulated_stats( ) -> grpc_testing.LoadBalancerAccumulatedStatsResponse: """Shortcut to LoadBalancerStatsServiceClient.get_client_accumulated_stats()""" return self.load_balancer_stats.get_client_accumulated_stats( - timeout_sec=timeout_sec) + timeout_sec=timeout_sec + ) def wait_for_active_server_channel(self) -> _ChannelzChannel: """Wait for the channel to the server to transition to READY. @@ -121,33 +132,47 @@ def wait_for_active_server_channel(self) -> _ChannelzChannel: def get_active_server_channel_socket(self) -> _ChannelzSocket: channel = self.find_server_channel_with_state( - _ChannelzChannelState.READY) + _ChannelzChannelState.READY + ) # Get the first subchannel of the active channel to the server. 
logger.debug( - '[%s] Retrieving client -> server socket, ' - 'channel_id: %s, subchannel: %s', self.hostname, - channel.ref.channel_id, channel.subchannel_ref[0].name) + ( + "[%s] Retrieving client -> server socket, " + "channel_id: %s, subchannel: %s" + ), + self.hostname, + channel.ref.channel_id, + channel.subchannel_ref[0].name, + ) subchannel, *subchannels = list( - self.channelz.list_channel_subchannels(channel)) + self.channelz.list_channel_subchannels(channel) + ) if subchannels: - logger.warning('[%s] Unexpected subchannels: %r', self.hostname, - subchannels) + logger.warning( + "[%s] Unexpected subchannels: %r", self.hostname, subchannels + ) # Get the first socket of the subchannel socket, *sockets = list( - self.channelz.list_subchannels_sockets(subchannel)) + self.channelz.list_subchannels_sockets(subchannel) + ) if sockets: - logger.warning('[%s] Unexpected sockets: %r', self.hostname, - subchannels) - logger.debug('[%s] Found client -> server socket: %s', self.hostname, - socket.ref.name) + logger.warning( + "[%s] Unexpected sockets: %r", self.hostname, subchannels + ) + logger.debug( + "[%s] Found client -> server socket: %s", + self.hostname, + socket.ref.name, + ) return socket def wait_for_server_channel_state( - self, - state: _ChannelzChannelState, - *, - timeout: Optional[_timedelta] = None, - rpc_deadline: Optional[_timedelta] = None) -> _ChannelzChannel: + self, + state: _ChannelzChannelState, + *, + timeout: Optional[_timedelta] = None, + rpc_deadline: Optional[_timedelta] = None, + ) -> _ChannelzChannel: # When polling for a state, prefer smaller wait times to avoid # exhausting all allowed time on a single long RPC. if rpc_deadline is None: @@ -157,44 +182,61 @@ def wait_for_server_channel_state( retryer = retryers.exponential_retryer_with_timeout( wait_min=_timedelta(seconds=10), wait_max=_timedelta(seconds=25), - timeout=_timedelta(minutes=5) if timeout is None else timeout) - - logger.info('[%s] Waiting to report a %s channel to %s', self.hostname, - _ChannelzChannelState.Name(state), self.server_target) - channel = retryer(self.find_server_channel_with_state, - state, - rpc_deadline=rpc_deadline) - logger.info('[%s] Channel to %s transitioned to state %s: %s', - self.hostname, self.server_target, - _ChannelzChannelState.Name(state), - _ChannelzServiceClient.channel_repr(channel)) + timeout=_timedelta(minutes=5) if timeout is None else timeout, + ) + + logger.info( + "[%s] Waiting to report a %s channel to %s", + self.hostname, + _ChannelzChannelState.Name(state), + self.server_target, + ) + channel = retryer( + self.find_server_channel_with_state, + state, + rpc_deadline=rpc_deadline, + ) + logger.info( + "[%s] Channel to %s transitioned to state %s: %s", + self.hostname, + self.server_target, + _ChannelzChannelState.Name(state), + _ChannelzServiceClient.channel_repr(channel), + ) return channel def find_server_channel_with_state( - self, - state: _ChannelzChannelState, - *, - rpc_deadline: Optional[_timedelta] = None, - check_subchannel=True) -> _ChannelzChannel: + self, + state: _ChannelzChannelState, + *, + rpc_deadline: Optional[_timedelta] = None, + check_subchannel=True, + ) -> _ChannelzChannel: rpc_params = {} if rpc_deadline is not None: - rpc_params['deadline_sec'] = rpc_deadline.total_seconds() + rpc_params["deadline_sec"] = rpc_deadline.total_seconds() for channel in self.get_server_channels(**rpc_params): channel_state: _ChannelzChannelState = channel.data.state.state - logger.info('[%s] Server channel: %s', self.hostname, - 
_ChannelzServiceClient.channel_repr(channel)) + logger.info( + "[%s] Server channel: %s", + self.hostname, + _ChannelzServiceClient.channel_repr(channel), + ) if channel_state is state: if check_subchannel: # When requested, check if the channel has at least # one subchannel in the requested state. try: subchannel = self.find_subchannel_with_state( - channel, state, **rpc_params) + channel, state, **rpc_params + ) logger.info( - '[%s] Found subchannel in state %s: %s', - self.hostname, _ChannelzChannelState.Name(state), - _ChannelzServiceClient.subchannel_repr(subchannel)) + "[%s] Found subchannel in state %s: %s", + self.hostname, + _ChannelzChannelState.Name(state), + _ChannelzServiceClient.subchannel_repr(subchannel), + ) except self.NotFound as e: # Otherwise, keep searching. logger.info(e.message) @@ -202,32 +244,39 @@ def find_server_channel_with_state( return channel raise self.NotFound( - f'[{self.hostname}] Client has no ' - f'{_ChannelzChannelState.Name(state)} channel with the server') + f"[{self.hostname}] Client has no " + f"{_ChannelzChannelState.Name(state)} channel with the server" + ) def get_server_channels(self, **kwargs) -> Iterable[_ChannelzChannel]: - return self.channelz.find_channels_for_target(self.server_target, - **kwargs) + return self.channelz.find_channels_for_target( + self.server_target, **kwargs + ) - def find_subchannel_with_state(self, channel: _ChannelzChannel, - state: _ChannelzChannelState, - **kwargs) -> _ChannelzSubchannel: + def find_subchannel_with_state( + self, channel: _ChannelzChannel, state: _ChannelzChannelState, **kwargs + ) -> _ChannelzSubchannel: subchannels = self.channelz.list_channel_subchannels(channel, **kwargs) for subchannel in subchannels: if subchannel.data.state.state is state: return subchannel - raise self.NotFound(f'[{self.hostname}] Not found ' - f'a {_ChannelzChannelState.Name(state)} subchannel ' - f'for channel_id {channel.ref.channel_id}') + raise self.NotFound( + f"[{self.hostname}] Not found " + f"a {_ChannelzChannelState.Name(state)} subchannel " + f"for channel_id {channel.ref.channel_id}" + ) - def find_subchannels_with_state(self, state: _ChannelzChannelState, - **kwargs) -> List[_ChannelzSubchannel]: + def find_subchannels_with_state( + self, state: _ChannelzChannelState, **kwargs + ) -> List[_ChannelzSubchannel]: subchannels = [] for channel in self.channelz.find_channels_for_target( - self.server_target, **kwargs): + self.server_target, **kwargs + ): for subchannel in self.channelz.list_channel_subchannels( - channel, **kwargs): + channel, **kwargs + ): if subchannel.data.state.state is state: subchannels.append(subchannel) return subchannels diff --git a/tools/run_tests/xds_k8s_test_driver/framework/test_app/runners/base_runner.py b/tools/run_tests/xds_k8s_test_driver/framework/test_app/runners/base_runner.py index 02e581c9b3ff8..096712ccab75f 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/test_app/runners/base_runner.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/test_app/runners/base_runner.py @@ -28,7 +28,7 @@ from framework.helpers import logs flags.adopt_module_key_flags(logs) -_LOGS_SUBDIR = 'test_app_logs' +_LOGS_SUBDIR = "test_app_logs" class RunnerError(Exception): @@ -53,13 +53,13 @@ def should_collect_logs(self) -> bool: @functools.lru_cache(None) def logs_subdir(self) -> pathlib.Path: if not self.should_collect_logs: - raise FileNotFoundError('Log collection is not enabled.') + raise FileNotFoundError("Log collection is not enabled.") return self._logs_subdir @property def 
log_stop_event(self) -> threading.Event: if not self.should_collect_logs: - raise ValueError('Log collection is not enabled.') + raise ValueError("Log collection is not enabled.") return self._log_stop_event def maybe_stop_logging(self): @@ -76,28 +76,30 @@ def cleanup(self, *, force=False): @classmethod def _logs_explorer_link_from_params( - cls, - *, - gcp_ui_url: str, - gcp_project: str, - query: Dict[str, str], - request: Optional[Dict[str, str]] = None) -> str: - req_merged = {'query': cls._logs_explorer_query(query)} + cls, + *, + gcp_ui_url: str, + gcp_project: str, + query: Dict[str, str], + request: Optional[Dict[str, str]] = None, + ) -> str: + req_merged = {"query": cls._logs_explorer_query(query)} if request is not None: req_merged.update(request) req = cls._logs_explorer_request(req_merged) - return f'https://{gcp_ui_url}/logs/query;{req}?project={gcp_project}' + return f"https://{gcp_ui_url}/logs/query;{req}?project={gcp_project}" @classmethod def _logs_explorer_query(cls, query: Dict[str, str]) -> str: - return '\n'.join(f'{k}="{v}"' for k, v in query.items()) + return "\n".join(f'{k}="{v}"' for k, v in query.items()) @classmethod def _logs_explorer_request(cls, req: Dict[str, str]) -> str: - return ';'.join( - f'{k}={cls._logs_explorer_quote(v)}' for k, v in req.items()) + return ";".join( + f"{k}={cls._logs_explorer_quote(v)}" for k, v in req.items() + ) @classmethod def _logs_explorer_quote(cls, value: str) -> str: - return urllib.parse.quote_plus(value, safe=':') + return urllib.parse.quote_plus(value, safe=":") diff --git a/tools/run_tests/xds_k8s_test_driver/framework/test_app/runners/k8s/k8s_base_runner.py b/tools/run_tests/xds_k8s_test_driver/framework/test_app/runners/k8s/k8s_base_runner.py index f877656650fa6..6f302cbce05c5 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/test_app/runners/k8s/k8s_base_runner.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/test_app/runners/k8s/k8s_base_runner.py @@ -55,9 +55,9 @@ class KubernetesBaseRunner(base_runner.BaseRunner, metaclass=ABCMeta): # Pylint wants abstract classes to override abstract methods. # pylint: disable=abstract-method - TEMPLATE_DIR_NAME = 'kubernetes-manifests' - TEMPLATE_DIR_RELATIVE_PATH = f'../../../../{TEMPLATE_DIR_NAME}' - ROLE_WORKLOAD_IDENTITY_USER = 'roles/iam.workloadIdentityUser' + TEMPLATE_DIR_NAME = "kubernetes-manifests" + TEMPLATE_DIR_RELATIVE_PATH = f"../../../../{TEMPLATE_DIR_NAME}" + ROLE_WORKLOAD_IDENTITY_USER = "roles/iam.workloadIdentityUser" pod_port_forwarders: List[k8s.PortForwarder] pod_log_collectors: List[k8s.PodLogCollector] @@ -70,7 +70,7 @@ class KubernetesBaseRunner(base_runner.BaseRunner, metaclass=ABCMeta): gcp_ui_url: str # Fields with default values. - namespace_template: str = 'namespace.yaml' + namespace_template: str = "namespace.yaml" reuse_namespace: bool = False # Mutable state. Describes the current run. @@ -84,16 +84,18 @@ class KubernetesBaseRunner(base_runner.BaseRunner, metaclass=ABCMeta): # The history of all runs performed by this runner. 
run_history: List[RunHistory] - def __init__(self, - k8s_namespace: k8s.KubernetesNamespace, - *, - deployment_name: str, - image_name: str, - gcp_project: str, - gcp_service_account: str, - gcp_ui_url: str, - namespace_template: Optional[str] = 'namespace.yaml', - reuse_namespace: bool = False): + def __init__( + self, + k8s_namespace: k8s.KubernetesNamespace, + *, + deployment_name: str, + image_name: str, + gcp_project: str, + gcp_service_account: str, + gcp_ui_url: str, + namespace_template: Optional[str] = "namespace.yaml", + reuse_namespace: bool = False, + ): super().__init__() # Required fields. @@ -124,11 +126,13 @@ def run(self, **kwargs): if self.time_start_completed: raise RuntimeError( f"Deployment {self.deployment_name}: has already been" - f" started at {self.time_start_completed.isoformat()}") + f" started at {self.time_start_completed.isoformat()}" + ) else: raise RuntimeError( f"Deployment {self.deployment_name}: start has already been" - f" requested at {self.time_start_requested.isoformat()}") + f" requested at {self.time_start_requested.isoformat()}" + ) self._reset_state() self.time_start_requested = _datetime.now() @@ -138,7 +142,8 @@ def run(self, **kwargs): self.namespace = self._reuse_namespace() if not self.namespace: self.namespace = self._create_namespace( - self.namespace_template, namespace_name=self.k8s_namespace.name) + self.namespace_template, namespace_name=self.k8s_namespace.name + ) def _start_completed(self): self.time_start_completed = _datetime.now() @@ -159,12 +164,14 @@ def _reset_state(self): if self.pod_port_forwarders: logger.warning( "Port forwarders weren't cleaned up from the past run: %s", - len(self.pod_port_forwarders)) + len(self.pod_port_forwarders), + ) if self.pod_log_collectors: logger.warning( "Pod log collectors weren't cleaned up from the past run: %s", - len(self.pod_log_collectors)) + len(self.pod_log_collectors), + ) self.namespace = None self.deployment = None @@ -192,8 +199,11 @@ def stop_pod_dependencies(self, *, log_drain_sec: int = 0): for pod_log_collector in self.pod_log_collectors: if log_drain_sec > 0 and not pod_log_collector.drain_event.is_set(): - logger.info("Draining logs for %s, timeout %i sec", - pod_log_collector.pod_name, log_drain_sec) + logger.info( + "Draining logs for %s, timeout %i sec", + pod_log_collector.pod_name, + log_drain_sec, + ) # The close will happen normally at the next message. 
pod_log_collector.drain_event.wait(timeout=log_drain_sec) # Note this will be called from the main thread and may cause @@ -207,10 +217,12 @@ def get_pod_restarts(self, deployment: k8s.V1Deployment) -> int: return 0 total_restart: int = 0 pods: List[k8s.V1Pod] = self.k8s_namespace.list_deployment_pods( - deployment) + deployment + ) for pod in pods: - total_restart += sum(status.restart_count - for status in pod.status.container_statuses) + total_restart += sum( + status.restart_count for status in pod.status.container_statuses + ) return total_restart @classmethod @@ -233,8 +245,9 @@ def _manifests_from_str(cls, document): @classmethod def _template_file_from_name(cls, template_name): - templates_path = (pathlib.Path(__file__).parent / - cls.TEMPLATE_DIR_RELATIVE_PATH) + templates_path = ( + pathlib.Path(__file__).parent / cls.TEMPLATE_DIR_RELATIVE_PATH + ) return templates_path.joinpath(template_name).resolve() def _create_from_template(self, template_name, **kwargs) -> object: @@ -242,23 +255,31 @@ def _create_from_template(self, template_name, **kwargs) -> object: logger.debug("Loading k8s manifest template: %s", template_file) yaml_doc = self._render_template(template_file, **kwargs) - logger.info("Rendered template %s/%s:\n%s", self.TEMPLATE_DIR_NAME, - template_name, self._highlighter.highlight(yaml_doc)) + logger.info( + "Rendered template %s/%s:\n%s", + self.TEMPLATE_DIR_NAME, + template_name, + self._highlighter.highlight(yaml_doc), + ) manifests = self._manifests_from_str(yaml_doc) manifest = next(manifests) # Error out on multi-document yaml if next(manifests, False): - raise _RunnerError('Exactly one document expected in manifest ' - f'{template_file}') + raise _RunnerError( + f"Exactly one document expected in manifest {template_file}" + ) k8s_objects = self.k8s_namespace.create_single_resource(manifest) if len(k8s_objects) != 1: - raise _RunnerError('Expected exactly one object must created from ' - f'manifest {template_file}') + raise _RunnerError( + "Expected exactly one object must created from " + f"manifest {template_file}" + ) - logger.info('%s %s created', k8s_objects[0].kind, - k8s_objects[0].metadata.name) + logger.info( + "%s %s created", k8s_objects[0].kind, k8s_objects[0].metadata.name + ) return k8s_objects[0] def _reuse_deployment(self, deployment_name) -> k8s.V1Deployment: @@ -277,68 +298,101 @@ def _reuse_namespace(self) -> k8s.V1Namespace: def _create_namespace(self, template, **kwargs) -> k8s.V1Namespace: namespace = self._create_from_template(template, **kwargs) if not isinstance(namespace, k8s.V1Namespace): - raise _RunnerError('Expected V1Namespace to be created ' - f'from manifest {template}') - if namespace.metadata.name != kwargs['namespace_name']: - raise _RunnerError('V1Namespace created with unexpected name: ' - f'{namespace.metadata.name}') - logger.debug('V1Namespace %s created at %s', - namespace.metadata.self_link, - namespace.metadata.creation_timestamp) + raise _RunnerError( + f"Expected V1Namespace to be created from manifest {template}" + ) + if namespace.metadata.name != kwargs["namespace_name"]: + raise _RunnerError( + "V1Namespace created with unexpected name: " + f"{namespace.metadata.name}" + ) + logger.debug( + "V1Namespace %s created at %s", + namespace.metadata.self_link, + namespace.metadata.creation_timestamp, + ) return namespace @classmethod - def _get_workload_identity_member_name(cls, project, namespace_name, - service_account_name): + def _get_workload_identity_member_name( + cls, project, namespace_name, service_account_name 
+ ): """ Returns workload identity member name used to authenticate Kubernetes service accounts. https://cloud.google.com/kubernetes-engine/docs/how-to/workload-identity """ - return (f'serviceAccount:{project}.svc.id.goog' - f'[{namespace_name}/{service_account_name}]') - - def _grant_workload_identity_user(self, *, gcp_iam, gcp_service_account, - service_account_name): + return ( + f"serviceAccount:{project}.svc.id.goog" + f"[{namespace_name}/{service_account_name}]" + ) + + def _grant_workload_identity_user( + self, *, gcp_iam, gcp_service_account, service_account_name + ): workload_identity_member = self._get_workload_identity_member_name( - gcp_iam.project, self.k8s_namespace.name, service_account_name) - logger.info('Granting %s to %s for GCP Service Account %s', - self.ROLE_WORKLOAD_IDENTITY_USER, workload_identity_member, - gcp_service_account) + gcp_iam.project, self.k8s_namespace.name, service_account_name + ) + logger.info( + "Granting %s to %s for GCP Service Account %s", + self.ROLE_WORKLOAD_IDENTITY_USER, + workload_identity_member, + gcp_service_account, + ) gcp_iam.add_service_account_iam_policy_binding( - gcp_service_account, self.ROLE_WORKLOAD_IDENTITY_USER, - workload_identity_member) - - def _revoke_workload_identity_user(self, *, gcp_iam, gcp_service_account, - service_account_name): + gcp_service_account, + self.ROLE_WORKLOAD_IDENTITY_USER, + workload_identity_member, + ) + + def _revoke_workload_identity_user( + self, *, gcp_iam, gcp_service_account, service_account_name + ): workload_identity_member = self._get_workload_identity_member_name( - gcp_iam.project, self.k8s_namespace.name, service_account_name) - logger.info('Revoking %s from %s for GCP Service Account %s', - self.ROLE_WORKLOAD_IDENTITY_USER, workload_identity_member, - gcp_service_account) + gcp_iam.project, self.k8s_namespace.name, service_account_name + ) + logger.info( + "Revoking %s from %s for GCP Service Account %s", + self.ROLE_WORKLOAD_IDENTITY_USER, + workload_identity_member, + gcp_service_account, + ) try: gcp_iam.remove_service_account_iam_policy_binding( - gcp_service_account, self.ROLE_WORKLOAD_IDENTITY_USER, - workload_identity_member) + gcp_service_account, + self.ROLE_WORKLOAD_IDENTITY_USER, + workload_identity_member, + ) except gcp.api.Error as error: - logger.warning('Failed %s from %s for Service Account %s: %r', - self.ROLE_WORKLOAD_IDENTITY_USER, - workload_identity_member, gcp_service_account, error) + logger.warning( + "Failed %s from %s for Service Account %s: %r", + self.ROLE_WORKLOAD_IDENTITY_USER, + workload_identity_member, + gcp_service_account, + error, + ) - def _create_service_account(self, template, - **kwargs) -> k8s.V1ServiceAccount: + def _create_service_account( + self, template, **kwargs + ) -> k8s.V1ServiceAccount: resource = self._create_from_template(template, **kwargs) if not isinstance(resource, k8s.V1ServiceAccount): - raise _RunnerError('Expected V1ServiceAccount to be created ' - f'from manifest {template}') - if resource.metadata.name != kwargs['service_account_name']: - raise _RunnerError('V1ServiceAccount created with unexpected name: ' - f'{resource.metadata.name}') - logger.debug('V1ServiceAccount %s created at %s', - resource.metadata.self_link, - resource.metadata.creation_timestamp) + raise _RunnerError( + "Expected V1ServiceAccount to be created " + f"from manifest {template}" + ) + if resource.metadata.name != kwargs["service_account_name"]: + raise _RunnerError( + "V1ServiceAccount created with unexpected name: " + f"{resource.metadata.name}" + 
) + logger.debug( + "V1ServiceAccount %s created at %s", + resource.metadata.self_link, + resource.metadata.creation_timestamp, + ) return resource def _create_deployment(self, template, **kwargs) -> k8s.V1Deployment: @@ -346,14 +400,15 @@ def _create_deployment(self, template, **kwargs) -> k8s.V1Deployment: # the rest of the _create_* methods, which pass kwargs as-is # to _create_from_template(), so that the kwargs dict is unpacked into # template variables and their values. - if 'deployment_name' not in kwargs: - raise TypeError('Missing required keyword-only argument: ' - 'deployment_name') + if "deployment_name" not in kwargs: + raise TypeError( + "Missing required keyword-only argument: deployment_name" + ) # Automatically apply random deployment_id to use in the matchLabels # to prevent selecting pods in the same namespace belonging to # a different deployment. - if 'deployment_id' not in kwargs: + if "deployment_id" not in kwargs: rand_id: str = framework.helpers.rand.rand_string(lowercase=True) # Fun edge case: when rand_string() happen to generate numbers only, # yaml interprets deployment_id label value as an integer, @@ -363,136 +418,167 @@ def _create_deployment(self, template, **kwargs) -> k8s.V1Deployment: # Prepending deployment name forces deployment_id into a string, # as well as it's just a better description. self.deployment_id = f'{kwargs["deployment_name"]}-{rand_id}' - kwargs['deployment_id'] = self.deployment_id + kwargs["deployment_id"] = self.deployment_id else: - self.deployment_id = kwargs['deployment_id'] + self.deployment_id = kwargs["deployment_id"] deployment = self._create_from_template(template, **kwargs) if not isinstance(deployment, k8s.V1Deployment): - raise _RunnerError('Expected V1Deployment to be created ' - f'from manifest {template}') - if deployment.metadata.name != kwargs['deployment_name']: - raise _RunnerError('V1Deployment created with unexpected name: ' - f'{deployment.metadata.name}') - logger.debug('V1Deployment %s created at %s', - deployment.metadata.self_link, - deployment.metadata.creation_timestamp) + raise _RunnerError( + f"Expected V1Deployment to be created from manifest {template}" + ) + if deployment.metadata.name != kwargs["deployment_name"]: + raise _RunnerError( + "V1Deployment created with unexpected name: " + f"{deployment.metadata.name}" + ) + logger.debug( + "V1Deployment %s created at %s", + deployment.metadata.self_link, + deployment.metadata.creation_timestamp, + ) return deployment def _create_service(self, template, **kwargs) -> k8s.V1Service: service = self._create_from_template(template, **kwargs) if not isinstance(service, k8s.V1Service): - raise _RunnerError('Expected V1Service to be created ' - f'from manifest {template}') - if service.metadata.name != kwargs['service_name']: - raise _RunnerError('V1Service created with unexpected name: ' - f'{service.metadata.name}') - logger.debug('V1Service %s created at %s', service.metadata.self_link, - service.metadata.creation_timestamp) + raise _RunnerError( + f"Expected V1Service to be created from manifest {template}" + ) + if service.metadata.name != kwargs["service_name"]: + raise _RunnerError( + "V1Service created with unexpected name: " + f"{service.metadata.name}" + ) + logger.debug( + "V1Service %s created at %s", + service.metadata.self_link, + service.metadata.creation_timestamp, + ) return service def _delete_deployment(self, name, wait_for_deletion=True): self.stop_pod_dependencies() - logger.info('Deleting deployment %s', name) + logger.info("Deleting 
deployment %s", name) try: self.k8s_namespace.delete_deployment(name) except (retryers.RetryError, k8s.NotFound) as e: - logger.info('Deployment %s deletion failed: %s', name, e) + logger.info("Deployment %s deletion failed: %s", name, e) return if wait_for_deletion: self.k8s_namespace.wait_for_deployment_deleted(name) - logger.debug('Deployment %s deleted', name) + logger.debug("Deployment %s deleted", name) def _delete_service(self, name, wait_for_deletion=True): - logger.info('Deleting service %s', name) + logger.info("Deleting service %s", name) try: self.k8s_namespace.delete_service(name) except (retryers.RetryError, k8s.NotFound) as e: - logger.info('Service %s deletion failed: %s', name, e) + logger.info("Service %s deletion failed: %s", name, e) return if wait_for_deletion: self.k8s_namespace.wait_for_service_deleted(name) - logger.debug('Service %s deleted', name) + logger.debug("Service %s deleted", name) def _delete_service_account(self, name, wait_for_deletion=True): - logger.info('Deleting service account %s', name) + logger.info("Deleting service account %s", name) try: self.k8s_namespace.delete_service_account(name) except (retryers.RetryError, k8s.NotFound) as e: - logger.info('Service account %s deletion failed: %s', name, e) + logger.info("Service account %s deletion failed: %s", name, e) return if wait_for_deletion: self.k8s_namespace.wait_for_service_account_deleted(name) - logger.debug('Service account %s deleted', name) + logger.debug("Service account %s deleted", name) def delete_namespace(self, wait_for_deletion=True): - logger.info('Deleting namespace %s', self.k8s_namespace.name) + logger.info("Deleting namespace %s", self.k8s_namespace.name) try: self.k8s_namespace.delete() except (retryers.RetryError, k8s.NotFound) as e: - logger.info('Namespace %s deletion failed: %s', - self.k8s_namespace.name, e) + logger.info( + "Namespace %s deletion failed: %s", self.k8s_namespace.name, e + ) return if wait_for_deletion: self.k8s_namespace.wait_for_namespace_deleted() - logger.debug('Namespace %s deleted', self.k8s_namespace.name) + logger.debug("Namespace %s deleted", self.k8s_namespace.name) def _wait_deployment_with_available_replicas(self, name, count=1, **kwargs): logger.info( - 'Waiting for deployment %s to report %s ' - 'available replica(s)', name, count) + "Waiting for deployment %s to report %s available replica(s)", + name, + count, + ) self.k8s_namespace.wait_for_deployment_available_replicas( - name, count, **kwargs) + name, count, **kwargs + ) deployment = self.k8s_namespace.get_deployment(name) - logger.info('Deployment %s has %i replicas available', - deployment.metadata.name, - deployment.status.available_replicas) - - def _wait_deployment_pod_count(self, - deployment: k8s.V1Deployment, - count: int = 1, - **kwargs) -> List[str]: - logger.info('Waiting for deployment %s to initialize %s pod(s)', - deployment.metadata.name, count) + logger.info( + "Deployment %s has %i replicas available", + deployment.metadata.name, + deployment.status.available_replicas, + ) + + def _wait_deployment_pod_count( + self, deployment: k8s.V1Deployment, count: int = 1, **kwargs + ) -> List[str]: + logger.info( + "Waiting for deployment %s to initialize %s pod(s)", + deployment.metadata.name, + count, + ) self.k8s_namespace.wait_for_deployment_replica_count( - deployment, count, **kwargs) + deployment, count, **kwargs + ) pods = self.k8s_namespace.list_deployment_pods(deployment) pod_names = [pod.metadata.name for pod in pods] - logger.info('Deployment %s initialized %i pod(s): 
%s', - deployment.metadata.name, count, pod_names) + logger.info( + "Deployment %s initialized %i pod(s): %s", + deployment.metadata.name, + count, + pod_names, + ) # Pods may not be started yet, just return the names. return pod_names def _wait_pod_started(self, name, **kwargs) -> k8s.V1Pod: - logger.info('Waiting for pod %s to start', name) + logger.info("Waiting for pod %s to start", name) self.k8s_namespace.wait_for_pod_started(name, **kwargs) pod = self.k8s_namespace.get_pod(name) - logger.info('Pod %s ready, IP: %s', pod.metadata.name, - pod.status.pod_ip) + logger.info( + "Pod %s ready, IP: %s", pod.metadata.name, pod.status.pod_ip + ) return pod - def _start_port_forwarding_pod(self, pod: k8s.V1Pod, - remote_port: int) -> k8s.PortForwarder: - logger.info('LOCAL DEV MODE: Enabling port forwarding to %s:%s', - pod.status.pod_ip, remote_port) + def _start_port_forwarding_pod( + self, pod: k8s.V1Pod, remote_port: int + ) -> k8s.PortForwarder: + logger.info( + "LOCAL DEV MODE: Enabling port forwarding to %s:%s", + pod.status.pod_ip, + remote_port, + ) port_forwarder = self.k8s_namespace.port_forward_pod(pod, remote_port) self.pod_port_forwarders.append(port_forwarder) return port_forwarder - def _start_logging_pod(self, - pod: k8s.V1Pod, - *, - log_to_stdout: bool = False) -> k8s.PodLogCollector: + def _start_logging_pod( + self, pod: k8s.V1Pod, *, log_to_stdout: bool = False + ) -> k8s.PodLogCollector: pod_name = pod.metadata.name - logfile_name = f'{self.k8s_namespace.name}_{pod_name}.log' + logfile_name = f"{self.k8s_namespace.name}_{pod_name}.log" log_path = self.logs_subdir / logfile_name - logger.info('Enabling log collection from pod %s to %s', pod_name, - log_path.relative_to(self.logs_subdir.parent.parent)) + logger.info( + "Enabling log collection from pod %s to %s", + pod_name, + log_path.relative_to(self.logs_subdir.parent.parent), + ) pod_log_collector = self.k8s_namespace.pod_start_logging( pod_name=pod_name, log_path=log_path, @@ -501,24 +587,29 @@ def _start_logging_pod(self, # Timestamps are enabled because not all language implementations # include them. # TODO(sergiitk): Make this setting language-specific. - log_timestamps=True) + log_timestamps=True, + ) self.pod_log_collectors.append(pod_log_collector) return pod_log_collector def _wait_service_neg(self, name, service_port, **kwargs): - logger.info('Waiting for NEG for service %s', name) + logger.info("Waiting for NEG for service %s", name) self.k8s_namespace.wait_for_service_neg(name, **kwargs) neg_name, neg_zones = self.k8s_namespace.get_service_neg( - name, service_port) - logger.info("Service %s: detected NEG=%s in zones=%s", name, neg_name, - neg_zones) + name, service_port + ) + logger.info( + "Service %s: detected NEG=%s in zones=%s", name, neg_name, neg_zones + ) def logs_explorer_link(self): """Prints GCP Logs Explorer link to all runs of the deployment.""" - self._logs_explorer_link(deployment_name=self.deployment_name, - namespace_name=self.k8s_namespace.name, - gcp_project=self.gcp_project, - gcp_ui_url=self.gcp_ui_url) + self._logs_explorer_link( + deployment_name=self.deployment_name, + namespace_name=self.k8s_namespace.name, + gcp_project=self.gcp_project, + gcp_ui_url=self.gcp_ui_url, + ) def logs_explorer_run_history_links(self): """Prints a separate GCP Logs Explorer link for each run *completed* by @@ -527,27 +618,31 @@ def logs_explorer_run_history_links(self): This excludes the current run, if it hasn't been completed. 
""" if not self.run_history: - logger.info('No completed deployments of %s', self.deployment_name) + logger.info("No completed deployments of %s", self.deployment_name) return for run in self.run_history: - self._logs_explorer_link(deployment_name=self.deployment_name, - namespace_name=self.k8s_namespace.name, - gcp_project=self.gcp_project, - gcp_ui_url=self.gcp_ui_url, - deployment_id=run.deployment_id, - start_time=run.time_start_requested, - end_time=run.time_stopped) + self._logs_explorer_link( + deployment_name=self.deployment_name, + namespace_name=self.k8s_namespace.name, + gcp_project=self.gcp_project, + gcp_ui_url=self.gcp_ui_url, + deployment_id=run.deployment_id, + start_time=run.time_start_requested, + end_time=run.time_stopped, + ) @classmethod - def _logs_explorer_link(cls, - *, - deployment_name: str, - namespace_name: str, - gcp_project: str, - gcp_ui_url: str, - deployment_id: Optional[str] = None, - start_time: Optional[_datetime] = None, - end_time: Optional[_datetime] = None): + def _logs_explorer_link( + cls, + *, + deployment_name: str, + namespace_name: str, + gcp_project: str, + gcp_ui_url: str, + deployment_id: Optional[str] = None, + start_time: Optional[_datetime] = None, + end_time: Optional[_datetime] = None, + ): """Output the link to test server/client logs in GCP Logs Explorer.""" if not start_time: start_time = _datetime.now() @@ -556,31 +651,34 @@ def _logs_explorer_link(cls, logs_start = _helper_datetime.iso8601_utc_time(start_time) logs_end = _helper_datetime.iso8601_utc_time(end_time) - request = {'timeRange': f'{logs_start}/{logs_end}'} + request = {"timeRange": f"{logs_start}/{logs_end}"} query = { - 'resource.type': 'k8s_container', - 'resource.labels.project_id': gcp_project, - 'resource.labels.container_name': deployment_name, - 'resource.labels.namespace_name': namespace_name, + "resource.type": "k8s_container", + "resource.labels.project_id": gcp_project, + "resource.labels.container_name": deployment_name, + "resource.labels.namespace_name": namespace_name, } if deployment_id: query['labels."k8s-pod/deployment_id"'] = deployment_id - link = cls._logs_explorer_link_from_params(gcp_ui_url=gcp_ui_url, - gcp_project=gcp_project, - query=query, - request=request) + link = cls._logs_explorer_link_from_params( + gcp_ui_url=gcp_ui_url, + gcp_project=gcp_project, + query=query, + request=request, + ) link_to = deployment_id if deployment_id else deployment_name # A whitespace at the end to indicate the end of the url. logger.info("GCP Logs Explorer link to %s:\n%s ", link_to, link) @classmethod - def _make_namespace_name(cls, resource_prefix: str, resource_suffix: str, - name: str) -> str: + def _make_namespace_name( + cls, resource_prefix: str, resource_suffix: str, name: str + ) -> str: """A helper to make consistent test app kubernetes namespace name for given resource prefix and suffix.""" parts = [resource_prefix, name] # Avoid trailing dash when the suffix is empty. 
if resource_suffix: parts.append(resource_suffix) - return '-'.join(parts) + return "-".join(parts) diff --git a/tools/run_tests/xds_k8s_test_driver/framework/test_app/runners/k8s/k8s_xds_client_runner.py b/tools/run_tests/xds_k8s_test_driver/framework/test_app/runners/k8s/k8s_xds_client_runner.py index cc97125d1d3ee..d938fd04b3ccc 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/test_app/runners/k8s/k8s_xds_client_runner.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/test_app/runners/k8s/k8s_xds_client_runner.py @@ -26,7 +26,6 @@ class KubernetesClientRunner(k8s_base_runner.KubernetesBaseRunner): - # Required fields. xds_server_uri: str stats_port: int @@ -42,33 +41,36 @@ class KubernetesClientRunner(k8s_base_runner.KubernetesBaseRunner): gcp_iam: Optional[gcp.iam.IamV1] = None def __init__( # pylint: disable=too-many-locals - self, - k8s_namespace: k8s.KubernetesNamespace, - *, - deployment_name: str, - image_name: str, - td_bootstrap_image: str, - network='default', - xds_server_uri: Optional[str] = None, - gcp_api_manager: gcp.api.GcpApiManager, - gcp_project: str, - gcp_service_account: str, - service_account_name: Optional[str] = None, - stats_port: int = 8079, - deployment_template: str = 'client.deployment.yaml', - service_account_template: str = 'service-account.yaml', - reuse_namespace: bool = False, - namespace_template: Optional[str] = None, - debug_use_port_forwarding: bool = False, - enable_workload_identity: bool = True): - super().__init__(k8s_namespace, - deployment_name=deployment_name, - image_name=image_name, - gcp_project=gcp_project, - gcp_service_account=gcp_service_account, - gcp_ui_url=gcp_api_manager.gcp_ui_url, - namespace_template=namespace_template, - reuse_namespace=reuse_namespace) + self, + k8s_namespace: k8s.KubernetesNamespace, + *, + deployment_name: str, + image_name: str, + td_bootstrap_image: str, + network="default", + xds_server_uri: Optional[str] = None, + gcp_api_manager: gcp.api.GcpApiManager, + gcp_project: str, + gcp_service_account: str, + service_account_name: Optional[str] = None, + stats_port: int = 8079, + deployment_template: str = "client.deployment.yaml", + service_account_template: str = "service-account.yaml", + reuse_namespace: bool = False, + namespace_template: Optional[str] = None, + debug_use_port_forwarding: bool = False, + enable_workload_identity: bool = True, + ): + super().__init__( + k8s_namespace, + deployment_name=deployment_name, + image_name=image_name, + gcp_project=gcp_project, + gcp_service_account=gcp_service_account, + gcp_ui_url=gcp_api_manager.gcp_ui_url, + namespace_template=namespace_template, + reuse_namespace=reuse_namespace, + ) # Settings self.stats_port = stats_port @@ -91,21 +93,32 @@ def __init__( # pylint: disable=too-many-locals self.gcp_iam = gcp.iam.IamV1(gcp_api_manager, gcp_project) def run( # pylint: disable=arguments-differ - self, - *, - server_target, - rpc='UnaryCall', - qps=25, - metadata='', - secure_mode=False, - config_mesh=None, - print_response=False, - log_to_stdout: bool = False) -> XdsTestClient: + self, + *, + server_target, + rpc="UnaryCall", + qps=25, + metadata="", + secure_mode=False, + config_mesh=None, + print_response=False, + log_to_stdout: bool = False, + ) -> XdsTestClient: logger.info( - 'Deploying xDS test client "%s" to k8s namespace %s: ' - 'server_target=%s rpc=%s qps=%s metadata=%r secure_mode=%s ' - 'print_response=%s', self.deployment_name, self.k8s_namespace.name, - server_target, rpc, qps, metadata, secure_mode, print_response) + ( + 'Deploying 
xDS test client "%s" to k8s namespace %s: ' + "server_target=%s rpc=%s qps=%s metadata=%r secure_mode=%s " + "print_response=%s" + ), + self.deployment_name, + self.k8s_namespace.name, + server_target, + rpc, + qps, + metadata, + secure_mode, + print_response, + ) super().run() if self.enable_workload_identity: @@ -114,14 +127,16 @@ def run( # pylint: disable=arguments-differ self._grant_workload_identity_user( gcp_iam=self.gcp_iam, gcp_service_account=self.gcp_service_account, - service_account_name=self.service_account_name) + service_account_name=self.service_account_name, + ) # Create service account self.service_account = self._create_service_account( self.service_account_template, service_account_name=self.service_account_name, namespace_name=self.k8s_namespace.name, - gcp_service_account=self.gcp_service_account) + gcp_service_account=self.gcp_service_account, + ) # Always create a new deployment self.deployment = self._create_deployment( @@ -140,7 +155,8 @@ def run( # pylint: disable=arguments-differ metadata=metadata, secure_mode=secure_mode, config_mesh=config_mesh, - print_response=print_response) + print_response=print_response, + ) # Load test client pod. We need only one client at the moment pod_name = self._wait_deployment_pod_count(self.deployment)[0] @@ -154,19 +170,22 @@ def run( # pylint: disable=arguments-differ return self._xds_test_client_for_pod(pod, server_target=server_target) - def _xds_test_client_for_pod(self, pod: k8s.V1Pod, *, - server_target: str) -> XdsTestClient: + def _xds_test_client_for_pod( + self, pod: k8s.V1Pod, *, server_target: str + ) -> XdsTestClient: if self.debug_use_port_forwarding: pf = self._start_port_forwarding_pod(pod, self.stats_port) rpc_port, rpc_host = pf.local_port, pf.local_address else: rpc_port, rpc_host = self.stats_port, None - return XdsTestClient(ip=pod.status.pod_ip, - rpc_port=rpc_port, - server_target=server_target, - hostname=pod.metadata.name, - rpc_host=rpc_host) + return XdsTestClient( + ip=pod.status.pod_ip, + rpc_port=rpc_port, + server_target=server_target, + hostname=pod.metadata.name, + rpc_host=rpc_host, + ) # pylint: disable=arguments-differ def cleanup(self, *, force=False, force_namespace=False): @@ -175,12 +194,14 @@ def cleanup(self, *, force=False, force_namespace=False): if self.deployment or force: self._delete_deployment(self.deployment_name) self.deployment = None - if (self.enable_workload_identity and - (self.service_account or force)): + if self.enable_workload_identity and ( + self.service_account or force + ): self._revoke_workload_identity_user( gcp_iam=self.gcp_iam, gcp_service_account=self.gcp_service_account, - service_account_name=self.service_account_name) + service_account_name=self.service_account_name, + ) self._delete_service_account(self.service_account_name) self.service_account = None self._cleanup_namespace(force=force_namespace and force) @@ -190,10 +211,9 @@ def cleanup(self, *, force=False, force_namespace=False): # pylint: enable=arguments-differ @classmethod - def make_namespace_name(cls, - resource_prefix: str, - resource_suffix: str, - name: str = 'client') -> str: + def make_namespace_name( + cls, resource_prefix: str, resource_suffix: str, name: str = "client" + ) -> str: """A helper to make consistent XdsTestClient kubernetes namespace name for given resource prefix and suffix. 
diff --git a/tools/run_tests/xds_k8s_test_driver/framework/test_app/runners/k8s/k8s_xds_server_runner.py b/tools/run_tests/xds_k8s_test_driver/framework/test_app/runners/k8s/k8s_xds_server_runner.py index b6de2f85e0b92..48f9799233e5f 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/test_app/runners/k8s/k8s_xds_server_runner.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/test_app/runners/k8s/k8s_xds_server_runner.py @@ -51,36 +51,39 @@ class KubernetesServerRunner(k8s_base_runner.KubernetesBaseRunner): service: Optional[k8s.V1Service] = None def __init__( # pylint: disable=too-many-locals - self, - k8s_namespace: k8s.KubernetesNamespace, - *, - deployment_name: str, - image_name: str, - td_bootstrap_image: str, - network: str = 'default', - xds_server_uri: Optional[str] = None, - gcp_api_manager: gcp.api.GcpApiManager, - gcp_project: str, - gcp_service_account: str, - service_account_name: Optional[str] = None, - service_name: Optional[str] = None, - neg_name: Optional[str] = None, - deployment_template: str = 'server.deployment.yaml', - service_account_template: str = 'service-account.yaml', - service_template: str = 'server.service.yaml', - reuse_service: bool = False, - reuse_namespace: bool = False, - namespace_template: Optional[str] = None, - debug_use_port_forwarding: bool = False, - enable_workload_identity: bool = True): - super().__init__(k8s_namespace, - deployment_name=deployment_name, - image_name=image_name, - gcp_project=gcp_project, - gcp_service_account=gcp_service_account, - gcp_ui_url=gcp_api_manager.gcp_ui_url, - namespace_template=namespace_template, - reuse_namespace=reuse_namespace) + self, + k8s_namespace: k8s.KubernetesNamespace, + *, + deployment_name: str, + image_name: str, + td_bootstrap_image: str, + network: str = "default", + xds_server_uri: Optional[str] = None, + gcp_api_manager: gcp.api.GcpApiManager, + gcp_project: str, + gcp_service_account: str, + service_account_name: Optional[str] = None, + service_name: Optional[str] = None, + neg_name: Optional[str] = None, + deployment_template: str = "server.deployment.yaml", + service_account_template: str = "service-account.yaml", + service_template: str = "server.service.yaml", + reuse_service: bool = False, + reuse_namespace: bool = False, + namespace_template: Optional[str] = None, + debug_use_port_forwarding: bool = False, + enable_workload_identity: bool = True, + ): + super().__init__( + k8s_namespace, + deployment_name=deployment_name, + image_name=image_name, + gcp_project=gcp_project, + gcp_service_account=gcp_service_account, + gcp_ui_url=gcp_api_manager.gcp_ui_url, + namespace_template=namespace_template, + reuse_namespace=reuse_namespace, + ) # Settings self.deployment_template = deployment_template @@ -90,8 +93,9 @@ def __init__( # pylint: disable=too-many-locals self.enable_workload_identity = enable_workload_identity self.debug_use_port_forwarding = debug_use_port_forwarding # GCP Network Endpoint Group. - self.gcp_neg_name = neg_name or (f'{self.k8s_namespace.name}-' - f'{self.service_name}') + self.gcp_neg_name = neg_name or ( + f"{self.k8s_namespace.name}-{self.service_name}" + ) # Used by the TD bootstrap generator. 
self.td_bootstrap_image = td_bootstrap_image @@ -108,13 +112,14 @@ def __init__( # pylint: disable=too-many-locals self.gcp_iam = gcp.iam.IamV1(gcp_api_manager, gcp_project) def run( # pylint: disable=arguments-differ,too-many-branches - self, - *, - test_port: int = DEFAULT_TEST_PORT, - maintenance_port: Optional[int] = None, - secure_mode: bool = False, - replica_count: int = 1, - log_to_stdout: bool = False) -> List[XdsTestServer]: + self, + *, + test_port: int = DEFAULT_TEST_PORT, + maintenance_port: Optional[int] = None, + secure_mode: bool = False, + replica_count: int = 1, + log_to_stdout: bool = False, + ) -> List[XdsTestServer]: if not maintenance_port: maintenance_port = self._get_default_maintenance_port(secure_mode) @@ -123,21 +128,32 @@ def run( # pylint: disable=arguments-differ,too-many-branches # maintenance services can be reached independently of the security # configuration under test. if secure_mode and maintenance_port == test_port: - raise ValueError('port and maintenance_port must be different ' - 'when running test server in secure mode') + raise ValueError( + "port and maintenance_port must be different " + "when running test server in secure mode" + ) # To avoid bugs with comparing wrong types. - if not (isinstance(test_port, int) and - isinstance(maintenance_port, int)): - raise TypeError('Port numbers must be integer') + if not ( + isinstance(test_port, int) and isinstance(maintenance_port, int) + ): + raise TypeError("Port numbers must be integer") if secure_mode and not self.enable_workload_identity: - raise ValueError('Secure mode requires Workload Identity enabled.') + raise ValueError("Secure mode requires Workload Identity enabled.") logger.info( - 'Deploying xDS test server "%s" to k8s namespace %s: test_port=%s ' - 'maintenance_port=%s secure_mode=%s replica_count=%s', - self.deployment_name, self.k8s_namespace.name, test_port, - maintenance_port, secure_mode, replica_count) + ( + 'Deploying xDS test server "%s" to k8s namespace %s:' + " test_port=%s maintenance_port=%s secure_mode=%s" + " replica_count=%s" + ), + self.deployment_name, + self.k8s_namespace.name, + test_port, + maintenance_port, + secure_mode, + replica_count, + ) super().run() # Reuse existing if requested, create a new deployment when missing. 
@@ -151,7 +167,8 @@ def run( # pylint: disable=arguments-differ,too-many-branches namespace_name=self.k8s_namespace.name, deployment_name=self.deployment_name, neg_name=self.gcp_neg_name, - test_port=test_port) + test_port=test_port, + ) self._wait_service_neg(self.service_name, test_port) if self.enable_workload_identity: @@ -160,14 +177,16 @@ def run( # pylint: disable=arguments-differ,too-many-branches self._grant_workload_identity_user( gcp_iam=self.gcp_iam, gcp_service_account=self.gcp_service_account, - service_account_name=self.service_account_name) + service_account_name=self.service_account_name, + ) # Create service account self.service_account = self._create_service_account( self.service_account_template, service_account_name=self.service_account_name, namespace_name=self.k8s_namespace.name, - gcp_service_account=self.gcp_service_account) + gcp_service_account=self.gcp_service_account, + ) # Always create a new deployment self.deployment = self._create_deployment( @@ -182,10 +201,12 @@ def run( # pylint: disable=arguments-differ,too-many-branches replica_count=replica_count, test_port=test_port, maintenance_port=maintenance_port, - secure_mode=secure_mode) + secure_mode=secure_mode, + ) - pod_names = self._wait_deployment_pod_count(self.deployment, - replica_count) + pod_names = self._wait_deployment_pod_count( + self.deployment, replica_count + ) pods = [] for pod_name in pod_names: pod = self._wait_pod_started(pod_name) @@ -194,17 +215,21 @@ def run( # pylint: disable=arguments-differ,too-many-branches self._start_logging_pod(pod, log_to_stdout=log_to_stdout) # Verify the deployment reports all pods started as well. - self._wait_deployment_with_available_replicas(self.deployment_name, - replica_count) + self._wait_deployment_with_available_replicas( + self.deployment_name, replica_count + ) self._start_completed() servers: List[XdsTestServer] = [] for pod in pods: servers.append( - self._xds_test_server_for_pod(pod, - test_port=test_port, - maintenance_port=maintenance_port, - secure_mode=secure_mode)) + self._xds_test_server_for_pod( + pod, + test_port=test_port, + maintenance_port=maintenance_port, + secure_mode=secure_mode, + ) + ) return servers def _get_default_maintenance_port(self, secure_mode: bool) -> int: @@ -214,12 +239,14 @@ def _get_default_maintenance_port(self, secure_mode: bool) -> int: maintenance_port = self.DEFAULT_SECURE_MODE_MAINTENANCE_PORT return maintenance_port - def _xds_test_server_for_pod(self, - pod: k8s.V1Pod, - *, - test_port: int = DEFAULT_TEST_PORT, - maintenance_port: Optional[int] = None, - secure_mode: bool = False) -> XdsTestServer: + def _xds_test_server_for_pod( + self, + pod: k8s.V1Pod, + *, + test_port: int = DEFAULT_TEST_PORT, + maintenance_port: Optional[int] = None, + secure_mode: bool = False, + ) -> XdsTestServer: if maintenance_port is None: maintenance_port = self._get_default_maintenance_port(secure_mode) @@ -229,12 +256,14 @@ def _xds_test_server_for_pod(self, else: rpc_port, rpc_host = maintenance_port, None - return XdsTestServer(ip=pod.status.pod_ip, - rpc_port=test_port, - hostname=pod.metadata.name, - maintenance_port=rpc_port, - secure_mode=secure_mode, - rpc_host=rpc_host) + return XdsTestServer( + ip=pod.status.pod_ip, + rpc_port=test_port, + hostname=pod.metadata.name, + maintenance_port=rpc_port, + secure_mode=secure_mode, + rpc_host=rpc_host, + ) # pylint: disable=arguments-differ def cleanup(self, *, force=False, force_namespace=False): @@ -246,12 +275,14 @@ def cleanup(self, *, force=False, force_namespace=False): 
if (self.service and not self.reuse_service) or force: self._delete_service(self.service_name) self.service = None - if (self.enable_workload_identity and - (self.service_account or force)): + if self.enable_workload_identity and ( + self.service_account or force + ): self._revoke_workload_identity_user( gcp_iam=self.gcp_iam, gcp_service_account=self.gcp_service_account, - service_account_name=self.service_account_name) + service_account_name=self.service_account_name, + ) self._delete_service_account(self.service_account_name) self.service_account = None self._cleanup_namespace(force=(force_namespace and force)) @@ -261,10 +292,9 @@ def cleanup(self, *, force=False, force_namespace=False): # pylint: enable=arguments-differ @classmethod - def make_namespace_name(cls, - resource_prefix: str, - resource_suffix: str, - name: str = 'server') -> str: + def make_namespace_name( + cls, resource_prefix: str, resource_suffix: str, name: str = "server" + ) -> str: """A helper to make consistent XdsTestServer kubernetes namespace name for given resource prefix and suffix. diff --git a/tools/run_tests/xds_k8s_test_driver/framework/test_app/server_app.py b/tools/run_tests/xds_k8s_test_driver/framework/test_app/server_app.py index b13661b694031..57444b73c7db8 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/test_app/server_app.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/test_app/server_app.py @@ -35,21 +35,24 @@ class XdsTestServer(framework.rpc.grpc.GrpcApp): Represents RPC services implemented in Server component of the xDS test app. https://github.com/grpc/grpc/blob/master/doc/xds-test-descriptions.md#server """ + # A unique host name identifying each server replica. # Server implementation must return this in the SimpleResponse.hostname, # which client uses as the key in rpcs_by_peer map. 
hostname: str - def __init__(self, - *, - ip: str, - rpc_port: int, - hostname: str, - maintenance_port: Optional[int] = None, - secure_mode: Optional[bool] = False, - xds_host: Optional[str] = None, - xds_port: Optional[int] = None, - rpc_host: Optional[str] = None): + def __init__( + self, + *, + ip: str, + rpc_port: int, + hostname: str, + maintenance_port: Optional[int] = None, + secure_mode: Optional[bool] = False, + xds_host: Optional[str] = None, + xds_port: Optional[int] = None, + rpc_host: Optional[str] = None, + ): super().__init__(rpc_host=(rpc_host or ip)) self.ip = ip self.rpc_port = rpc_port @@ -63,34 +66,44 @@ def __init__(self, def channelz(self) -> _ChannelzServiceClient: return _ChannelzServiceClient( self._make_channel(self.maintenance_port), - log_target=f'{self.hostname}:{self.maintenance_port}') + log_target=f"{self.hostname}:{self.maintenance_port}", + ) @property @functools.lru_cache(None) def update_health_service_client(self) -> _XdsUpdateHealthServiceClient: return _XdsUpdateHealthServiceClient( self._make_channel(self.maintenance_port), - log_target=f'{self.hostname}:{self.maintenance_port}') + log_target=f"{self.hostname}:{self.maintenance_port}", + ) @property @functools.lru_cache(None) def health_client(self) -> _HealthClient: return _HealthClient( self._make_channel(self.maintenance_port), - log_target=f'{self.hostname}:{self.maintenance_port}') + log_target=f"{self.hostname}:{self.maintenance_port}", + ) def set_serving(self): - logger.info('[%s] >> Setting health status to SERVING', self.hostname) + logger.info("[%s] >> Setting health status to SERVING", self.hostname) self.update_health_service_client.set_serving() - logger.info('[%s] << Health status %s', self.hostname, - self.health_client.check_health()) + logger.info( + "[%s] << Health status %s", + self.hostname, + self.health_client.check_health(), + ) def set_not_serving(self): - logger.info('[%s] >> Setting health status to NOT_SERVING', - self.hostname) + logger.info( + "[%s] >> Setting health status to NOT_SERVING", self.hostname + ) self.update_health_service_client.set_not_serving() - logger.info('[%s] << Health status %s', self.hostname, - self.health_client.check_health()) + logger.info( + "[%s] << Health status %s", + self.hostname, + self.health_client.check_health(), + ) def set_xds_address(self, xds_host, xds_port: Optional[int] = None): self.xds_host, self.xds_port = xds_host, xds_port @@ -98,16 +111,16 @@ def set_xds_address(self, xds_host, xds_port: Optional[int] = None): @property def xds_address(self) -> str: if not self.xds_host: - return '' + return "" if not self.xds_port: return self.xds_host - return f'{self.xds_host}:{self.xds_port}' + return f"{self.xds_host}:{self.xds_port}" @property def xds_uri(self) -> str: if not self.xds_host: - return '' - return f'xds:///{self.xds_address}' + return "" + return f"xds:///{self.xds_address}" def get_test_server(self) -> grpc_channelz.Server: """Return channelz representation of a server running TestService. 
@@ -117,8 +130,10 @@ def get_test_server(self) -> grpc_channelz.Server: """ server = self.channelz.find_server_listening_on_port(self.rpc_port) if not server: - raise self.NotFound(f'[{self.hostname}] Server' - f'listening on port {self.rpc_port} not found') + raise self.NotFound( + f"[{self.hostname}] Server" + f"listening on port {self.rpc_port} not found" + ) return server def get_test_server_sockets(self) -> Iterator[grpc_channelz.Socket]: @@ -130,8 +145,9 @@ def get_test_server_sockets(self) -> Iterator[grpc_channelz.Socket]: server = self.get_test_server() return self.channelz.list_server_sockets(server) - def get_server_socket_matching_client(self, - client_socket: grpc_channelz.Socket): + def get_server_socket_matching_client( + self, client_socket: grpc_channelz.Socket + ): """Find test server socket that matches given test client socket. Sockets are matched using TCP endpoints (ip:port), further on "address". @@ -139,21 +155,26 @@ def get_server_socket_matching_client(self, Raises: GrpcApp.NotFound: Server socket matching client socket not found. - """ + """ client_local = self.channelz.sock_address_to_str(client_socket.local) logger.debug( - '[%s] Looking for a server socket connected ' - 'to the client %s', self.hostname, client_local) + "[%s] Looking for a server socket connected to the client %s", + self.hostname, + client_local, + ) server_socket = self.channelz.find_server_socket_matching_client( - self.get_test_server_sockets(), client_socket) + self.get_test_server_sockets(), client_socket + ) if not server_socket: - raise self.NotFound(f'[{self.hostname}] Socket ' - f'to client {client_local} not found') + raise self.NotFound( + f"[{self.hostname}] Socket to client {client_local} not found" + ) logger.info( - '[%s] Found matching socket pair: ' - 'server(%s) <-> client(%s)', self.hostname, + "[%s] Found matching socket pair: server(%s) <-> client(%s)", + self.hostname, self.channelz.sock_addresses_pretty(server_socket), - self.channelz.sock_addresses_pretty(client_socket)) + self.channelz.sock_addresses_pretty(client_socket), + ) return server_socket diff --git a/tools/run_tests/xds_k8s_test_driver/framework/xds_flags.py b/tools/run_tests/xds_k8s_test_driver/framework/xds_flags.py index e9b720c17dd5d..fd33407d5e84a 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/xds_flags.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/xds_flags.py @@ -19,75 +19,92 @@ from framework.helpers import highlighter # GCP -PROJECT = flags.DEFINE_string("project", - default=None, - help="(required) GCP Project ID.") +PROJECT = flags.DEFINE_string( + "project", default=None, help="(required) GCP Project ID." +) RESOURCE_PREFIX = flags.DEFINE_string( "resource_prefix", default=None, - help=("(required) The prefix used to name GCP resources.\n" - "Together with `resource_suffix` used to create unique " - "resource names.")) + help=( + "(required) The prefix used to name GCP resources.\n" + "Together with `resource_suffix` used to create unique " + "resource names." 
+ ), +) RESOURCE_SUFFIX = flags.DEFINE_string( "resource_suffix", default=None, - help=("The suffix used to name GCP resources.\n" - "Together with `resource_prefix` used to create unique " - "resource names.\n" - "(default: test suite will generate a random suffix, based on suite " - "resource management preferences)")) -NETWORK = flags.DEFINE_string("network", - default="default", - help="GCP Network ID") + help=( + "The suffix used to name GCP resources.\n" + "Together with `resource_prefix` used to create unique " + "resource names.\n" + "(default: test suite will generate a random suffix, based on suite " + "resource management preferences)" + ), +) +NETWORK = flags.DEFINE_string( + "network", default="default", help="GCP Network ID" +) COMPUTE_API_VERSION = flags.DEFINE_string( "compute_api_version", - default='v1', - help="The version of the GCP Compute API, e.g., v1, v1alpha") + default="v1", + help="The version of the GCP Compute API, e.g., v1, v1alpha", +) # Mirrors --xds-server-uri argument of Traffic Director gRPC Bootstrap XDS_SERVER_URI = flags.DEFINE_string( - "xds_server_uri", - default=None, - help="Override Traffic Director server URI.") + "xds_server_uri", default=None, help="Override Traffic Director server URI." +) ENSURE_FIREWALL = flags.DEFINE_bool( "ensure_firewall", default=False, - help="Ensure the allow-health-check firewall exists before each test case") + help="Ensure the allow-health-check firewall exists before each test case", +) FIREWALL_SOURCE_RANGE = flags.DEFINE_list( "firewall_source_range", - default=['35.191.0.0/16', '130.211.0.0/22'], - help="Update the source range of the firewall rule.") + default=["35.191.0.0/16", "130.211.0.0/22"], + help="Update the source range of the firewall rule.", +) FIREWALL_ALLOWED_PORTS = flags.DEFINE_list( "firewall_allowed_ports", - default=['8080-8100'], - help="Update the allowed ports of the firewall rule.") + default=["8080-8100"], + help="Update the allowed ports of the firewall rule.", +) # Test server SERVER_NAME = flags.DEFINE_string( "server_name", default="psm-grpc-server", - help="The name to use for test server deployments.") + help="The name to use for test server deployments.", +) SERVER_PORT = flags.DEFINE_integer( "server_port", default=8080, lower_bound=1, upper_bound=65535, - help="Server test port.\nMust be within --firewall_allowed_ports.") + help="Server test port.\nMust be within --firewall_allowed_ports.", +) SERVER_MAINTENANCE_PORT = flags.DEFINE_integer( "server_maintenance_port", default=None, lower_bound=1, upper_bound=65535, - help=("Server port running maintenance services: Channelz, CSDS, Health, " - "XdsUpdateHealth, and ProtoReflection (optional).\n" - "Must be within --firewall_allowed_ports.\n" - "(default: the port is chosen automatically based on " - "the security configuration)")) + help=( + "Server port running maintenance services: Channelz, CSDS, Health, " + "XdsUpdateHealth, and ProtoReflection (optional).\n" + "Must be within --firewall_allowed_ports.\n" + "(default: the port is chosen automatically based on " + "the security configuration)" + ), +) SERVER_XDS_HOST = flags.DEFINE_string( "server_xds_host", default="xds-test-server", - help=("The xDS hostname of the test server.\n" - "Together with `server_xds_port` makes test server target URI, " - "xds:///hostname:port")) + help=( + "The xDS hostname of the test server.\n" + "Together with `server_xds_port` makes test server target URI, " + "xds:///hostname:port" + ), +) # Note: port 0 known to represent a request for 
dynamically-allocated port # https://en.wikipedia.org/wiki/List_of_TCP_and_UDP_port_numbers#Well-known_ports SERVER_XDS_PORT = flags.DEFINE_integer( @@ -95,17 +112,21 @@ default=8080, lower_bound=0, upper_bound=65535, - help=("The xDS port of the test server.\n" - "Together with `server_xds_host` makes test server target URI, " - "xds:///hostname:port\n" - "Must be unique within a GCP project.\n" - "Set to 0 to select any unused port.")) + help=( + "The xDS port of the test server.\n" + "Together with `server_xds_host` makes test server target URI, " + "xds:///hostname:port\n" + "Must be unique within a GCP project.\n" + "Set to 0 to select any unused port." + ), +) # Test client CLIENT_NAME = flags.DEFINE_string( "client_name", default="psm-grpc-client", - help="The name to use for test client deployments") + help="The name to use for test client deployments", +) CLIENT_PORT = flags.DEFINE_integer( "client_port", default=8079, @@ -114,33 +135,43 @@ help=( "The port test client uses to run gRPC services: Channelz, CSDS, " "XdsStats, XdsUpdateClientConfigure, and ProtoReflection (optional).\n" - "Doesn't have to be within --firewall_allowed_ports.")) + "Doesn't have to be within --firewall_allowed_ports." + ), +) # Testing metadata TESTING_VERSION = flags.DEFINE_string( "testing_version", default=None, - help="The testing gRPC version branch name. Like master, dev, v1.55.x") + help="The testing gRPC version branch name. Like master, dev, v1.55.x", +) FORCE_CLEANUP = flags.DEFINE_bool( "force_cleanup", default=False, - help="Force resource cleanup, even if not created by this test run") + help="Force resource cleanup, even if not created by this test run", +) COLLECT_APP_LOGS = flags.DEFINE_bool( - 'collect_app_logs', + "collect_app_logs", default=False, - help=('Collect the logs of the xDS Test Client and Server\n' - f'into the test_app_logs/ directory under the log directory.\n' - 'See --log_dir description for configuring the log directory.')) + help=( + f"Collect the logs of the xDS Test Client and Server\n" + f"into the test_app_logs/ directory under the log directory.\n" + f"See --log_dir description for configuring the log directory." + ), +) # Needed to configure urllib3 socket timeout, which is infinity by default. SOCKET_DEFAULT_TIMEOUT = flags.DEFINE_float( "socket_default_timeout", default=60, lower_bound=0, - help=("Set the default timeout in seconds on blocking socket operations.\n" - "If zero is given, the new sockets have no timeout. ")) + help=( + "Set the default timeout in seconds on blocking socket operations.\n" + "If zero is given, the new sockets have no timeout. 
" + ), +) def set_socket_default_timeout_from_flag() -> None: @@ -162,7 +193,9 @@ def set_socket_default_timeout_from_flag() -> None: flags.adopt_module_key_flags(highlighter) -flags.mark_flags_as_required([ - "project", - "resource_prefix", -]) +flags.mark_flags_as_required( + [ + "project", + "resource_prefix", + ] +) diff --git a/tools/run_tests/xds_k8s_test_driver/framework/xds_k8s_flags.py b/tools/run_tests/xds_k8s_test_driver/framework/xds_k8s_flags.py index 5e50af1ade328..3bd2e10d4f0b4 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/xds_k8s_flags.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/xds_k8s_flags.py @@ -14,47 +14,57 @@ from absl import flags # GCP -KUBE_CONTEXT = flags.DEFINE_string("kube_context", - default=None, - help="Kubectl context to use") +KUBE_CONTEXT = flags.DEFINE_string( + "kube_context", default=None, help="Kubectl context to use" +) SECONDARY_KUBE_CONTEXT = flags.DEFINE_string( "secondary_kube_context", default=None, - help="Secondary kubectl context to use for cluster in another region") + help="Secondary kubectl context to use for cluster in another region", +) GCP_SERVICE_ACCOUNT = flags.DEFINE_string( "gcp_service_account", default=None, - help="GCP Service account for GKE workloads to impersonate") + help="GCP Service account for GKE workloads to impersonate", +) TD_BOOTSTRAP_IMAGE = flags.DEFINE_string( "td_bootstrap_image", default=None, - help="Traffic Director gRPC Bootstrap Docker image") + help="Traffic Director gRPC Bootstrap Docker image", +) # Test app -SERVER_IMAGE = flags.DEFINE_string("server_image", - default=None, - help="Server Docker image name") +SERVER_IMAGE = flags.DEFINE_string( + "server_image", default=None, help="Server Docker image name" +) SERVER_IMAGE_CANONICAL = flags.DEFINE_string( "server_image_canonical", default=None, - help=("The canonical implementation of the xDS test server.\n" - "Can be used in tests where language-specific xDS test server" - "does not exist, or missing a feature required for the test.")) -CLIENT_IMAGE = flags.DEFINE_string("client_image", - default=None, - help="Client Docker image name") + help=( + "The canonical implementation of the xDS test server.\n" + "Can be used in tests where language-specific xDS test server" + "does not exist, or missing a feature required for the test." 
+ ), +) +CLIENT_IMAGE = flags.DEFINE_string( + "client_image", default=None, help="Client Docker image name" +) DEBUG_USE_PORT_FORWARDING = flags.DEFINE_bool( "debug_use_port_forwarding", default=False, - help="Development only: use kubectl port-forward to connect to test app") + help="Development only: use kubectl port-forward to connect to test app", +) ENABLE_WORKLOAD_IDENTITY = flags.DEFINE_bool( "enable_workload_identity", default=True, - help="Enable the WorkloadIdentity feature") + help="Enable the WorkloadIdentity feature", +) -flags.mark_flags_as_required([ - "kube_context", - "td_bootstrap_image", - "server_image", - "client_image", -]) +flags.mark_flags_as_required( + [ + "kube_context", + "td_bootstrap_image", + "server_image", + "client_image", + ] +) diff --git a/tools/run_tests/xds_k8s_test_driver/framework/xds_k8s_testcase.py b/tools/run_tests/xds_k8s_test_driver/framework/xds_k8s_testcase.py index 971ff7248a079..bf8e312ce5125 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/xds_k8s_testcase.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/xds_k8s_testcase.py @@ -53,7 +53,8 @@ _CHECK_LOCAL_CERTS = flags.DEFINE_bool( "check_local_certs", default=True, - help="Security Tests also check the value of local certs") + help="Security Tests also check the value of local certs", +) flags.adopt_module_key_flags(xds_flags) flags.adopt_module_key_flags(xds_k8s_flags) @@ -66,7 +67,9 @@ KubernetesServerRunner = k8s_xds_server_runner.KubernetesServerRunner KubernetesClientRunner = k8s_xds_client_runner.KubernetesClientRunner _LoadBalancerStatsResponse = grpc_testing.LoadBalancerStatsResponse -_LoadBalancerAccumulatedStatsResponse = grpc_testing.LoadBalancerAccumulatedStatsResponse +_LoadBalancerAccumulatedStatsResponse = ( + grpc_testing.LoadBalancerAccumulatedStatsResponse +) _ChannelState = grpc_channelz.ChannelState _timedelta = datetime.timedelta ClientConfig = grpc_csds.ClientConfig @@ -94,7 +97,7 @@ class XdsKubernetesBaseTestCase(absltest.TestCase): network: str project: str resource_prefix: str - resource_suffix: str = '' + resource_suffix: str = "" # Whether to randomize resources names for each test by appending a # unique suffix. resource_suffix_randomize: bool = True @@ -124,8 +127,8 @@ def setUpClass(cls): """Hook method for setting up class fixture before running tests in the class. """ - logger.info('----- Testing %s -----', cls.__name__) - logger.info('Logs timezone: %s', time.localtime().tm_zone) + logger.info("----- Testing %s -----", cls.__name__) + logger.info("Logs timezone: %s", time.localtime().tm_zone) # Raises unittest.SkipTest if given client/server/version does not # support current test case. 
@@ -165,17 +168,21 @@ def setUpClass(cls): # Test suite settings cls.force_cleanup = xds_flags.FORCE_CLEANUP.value - cls.debug_use_port_forwarding = \ + cls.debug_use_port_forwarding = ( xds_k8s_flags.DEBUG_USE_PORT_FORWARDING.value - cls.enable_workload_identity = \ + ) + cls.enable_workload_identity = ( xds_k8s_flags.ENABLE_WORKLOAD_IDENTITY.value + ) cls.check_local_certs = _CHECK_LOCAL_CERTS.value # Resource managers cls.k8s_api_manager = k8s.KubernetesApiManager( - xds_k8s_flags.KUBE_CONTEXT.value) + xds_k8s_flags.KUBE_CONTEXT.value + ) cls.secondary_k8s_api_manager = k8s.KubernetesApiManager( - xds_k8s_flags.SECONDARY_KUBE_CONTEXT.value) + xds_k8s_flags.SECONDARY_KUBE_CONTEXT.value + ) cls.gcp_api_manager = gcp.api.GcpApiManager() # Other @@ -183,13 +190,15 @@ def setUpClass(cls): @classmethod def _pretty_accumulated_stats( - cls, - accumulated_stats: _LoadBalancerAccumulatedStatsResponse, - *, - ignore_empty: bool = False, - highlight: bool = True) -> str: + cls, + accumulated_stats: _LoadBalancerAccumulatedStatsResponse, + *, + ignore_empty: bool = False, + highlight: bool = True, + ) -> str: stats_yaml = helpers_grpc.accumulated_stats_pretty( - accumulated_stats, ignore_empty=ignore_empty) + accumulated_stats, ignore_empty=ignore_empty + ) if not highlight: return stats_yaml return cls.yaml_highlighter.highlight(stats_yaml) @@ -206,12 +215,14 @@ def tearDownClass(cls): cls.gcp_api_manager.close() def setUp(self): - self._prev_sigint_handler = signal.signal(signal.SIGINT, - self.handle_sigint) - - def handle_sigint(self, signalnum: _SignalNum, - frame: Optional[FrameType]) -> None: - logger.info('Caught Ctrl+C, cleaning up...') + self._prev_sigint_handler = signal.signal( + signal.SIGINT, self.handle_sigint + ) + + def handle_sigint( + self, signalnum: _SignalNum, frame: Optional[FrameType] + ) -> None: + logger.info("Caught Ctrl+C, cleaning up...") self._handling_sigint = True # Force resource cleanup by their name. Addresses the case where ctrl-c # is pressed while waiting for the resource creation. 
@@ -225,32 +236,38 @@ def handle_sigint(self, signalnum: _SignalNum, @contextlib.contextmanager def subTest(self, msg, **params): # noqa pylint: disable=signature-differs - logger.info('--- Starting subTest %s.%s ---', self.id(), msg) + logger.info("--- Starting subTest %s.%s ---", self.id(), msg) try: yield super().subTest(msg, **params) finally: if not self._handling_sigint: - logger.info('--- Finished subTest %s.%s ---', self.id(), msg) + logger.info("--- Finished subTest %s.%s ---", self.id(), msg) def setupTrafficDirectorGrpc(self): - self.td.setup_for_grpc(self.server_xds_host, - self.server_xds_port, - health_check_port=self.server_maintenance_port) - - def setupServerBackends(self, - *, - wait_for_healthy_status=True, - server_runner=None, - max_rate_per_endpoint: Optional[int] = None): + self.td.setup_for_grpc( + self.server_xds_host, + self.server_xds_port, + health_check_port=self.server_maintenance_port, + ) + + def setupServerBackends( + self, + *, + wait_for_healthy_status=True, + server_runner=None, + max_rate_per_endpoint: Optional[int] = None, + ): if server_runner is None: server_runner = self.server_runner # Load Backends neg_name, neg_zones = server_runner.k8s_namespace.get_service_neg( - server_runner.service_name, self.server_port) + server_runner.service_name, self.server_port + ) # Add backends to the Backend Service self.td.backend_service_add_neg_backends( - neg_name, neg_zones, max_rate_per_endpoint=max_rate_per_endpoint) + neg_name, neg_zones, max_rate_per_endpoint=max_rate_per_endpoint + ) if wait_for_healthy_status: self.td.wait_for_backends_healthy_status() @@ -259,26 +276,28 @@ def removeServerBackends(self, *, server_runner=None): server_runner = self.server_runner # Load Backends neg_name, neg_zones = server_runner.k8s_namespace.get_service_neg( - server_runner.service_name, self.server_port) + server_runner.service_name, self.server_port + ) # Remove backends from the Backend Service self.td.backend_service_remove_neg_backends(neg_name, neg_zones) - def assertSuccessfulRpcs(self, - test_client: XdsTestClient, - num_rpcs: int = 100): + def assertSuccessfulRpcs( + self, test_client: XdsTestClient, num_rpcs: int = 100 + ): lb_stats = self.getClientRpcStats(test_client, num_rpcs) self.assertAllBackendsReceivedRpcs(lb_stats) failed = int(lb_stats.num_failures) self.assertLessEqual( failed, 0, - msg=f'Expected all RPCs to succeed: {failed} of {num_rpcs} failed') + msg=f"Expected all RPCs to succeed: {failed} of {num_rpcs} failed", + ) @staticmethod def diffAccumulatedStatsPerMethod( before: _LoadBalancerAccumulatedStatsResponse, - after: _LoadBalancerAccumulatedStatsResponse + after: _LoadBalancerAccumulatedStatsResponse, ) -> _LoadBalancerAccumulatedStatsResponse: """Only diffs stats_per_method, as the other fields are deprecated.""" diff = _LoadBalancerAccumulatedStatsResponse() @@ -289,20 +308,24 @@ def diffAccumulatedStatsPerMethod( raise AssertionError("Diff of count shouldn't be negative") if count > 0: diff.stats_per_method[method].result[status] = count - rpcs_started = (method_stats.rpcs_started - - before.stats_per_method[method].rpcs_started) + rpcs_started = ( + method_stats.rpcs_started + - before.stats_per_method[method].rpcs_started + ) if rpcs_started < 0: raise AssertionError("Diff of count shouldn't be negative") diff.stats_per_method[method].rpcs_started = rpcs_started return diff - def assertRpcStatusCodes(self, - test_client: XdsTestClient, - *, - expected_status: grpc.StatusCode, - duration: _timedelta, - method: str, - stray_rpc_limit: 
int = 0) -> None: + def assertRpcStatusCodes( + self, + test_client: XdsTestClient, + *, + expected_status: grpc.StatusCode, + duration: _timedelta, + method: str, + stray_rpc_limit: int = 0, + ) -> None: """Assert all RPCs for a method are completing with a certain status.""" # pylint: disable=too-many-locals expected_status_int: int = expected_status.value[0] @@ -311,30 +334,45 @@ def assertRpcStatusCodes(self, # Sending with pre-set QPS for a period of time before_stats = test_client.get_load_balancer_accumulated_stats() logging.debug( - '[%s] << LoadBalancerAccumulatedStatsResponse initial measurement:' - '\n%s', test_client.hostname, - self._pretty_accumulated_stats(before_stats)) + ( + "[%s] << LoadBalancerAccumulatedStatsResponse initial" + " measurement:\n%s" + ), + test_client.hostname, + self._pretty_accumulated_stats(before_stats), + ) time.sleep(duration.total_seconds()) after_stats = test_client.get_load_balancer_accumulated_stats() logging.debug( - '[%s] << LoadBalancerAccumulatedStatsResponse after %s seconds:' - '\n%s', test_client.hostname, duration.total_seconds(), - self._pretty_accumulated_stats(after_stats)) - - diff_stats = self.diffAccumulatedStatsPerMethod(before_stats, - after_stats) + ( + "[%s] << LoadBalancerAccumulatedStatsResponse after %s seconds:" + "\n%s" + ), + test_client.hostname, + duration.total_seconds(), + self._pretty_accumulated_stats(after_stats), + ) + + diff_stats = self.diffAccumulatedStatsPerMethod( + before_stats, after_stats + ) logger.info( - '[%s] << Received accumulated stats difference.' - ' Expecting RPCs with status %s for method %s:\n%s', - test_client.hostname, expected_status_fmt, method, - self._pretty_accumulated_stats(diff_stats, ignore_empty=True)) + ( + "[%s] << Received accumulated stats difference." + " Expecting RPCs with status %s for method %s:\n%s" + ), + test_client.hostname, + expected_status_fmt, + method, + self._pretty_accumulated_stats(diff_stats, ignore_empty=True), + ) # Used in stack traces. Don't highlight for better compatibility. - diff_stats_fmt: str = self._pretty_accumulated_stats(diff_stats, - ignore_empty=True, - highlight=False) + diff_stats_fmt: str = self._pretty_accumulated_stats( + diff_stats, ignore_empty=True, highlight=False + ) # 1. Verify the completed RPCs of the given method has no statuses # other than the expected_status, @@ -342,136 +380,186 @@ def assertRpcStatusCodes(self, for found_status_int, count in stats.result.items(): found_status = helpers_grpc.status_from_int(found_status_int) if found_status != expected_status and count > stray_rpc_limit: - self.fail(f"Expected only status {expected_status_fmt}," - " but found status" - f" {helpers_grpc.status_pretty(found_status)}" - f" for method {method}." - f"\nDiff stats:\n{diff_stats_fmt}") + self.fail( + f"Expected only status {expected_status_fmt}," + " but found status" + f" {helpers_grpc.status_pretty(found_status)}" + f" for method {method}." + f"\nDiff stats:\n{diff_stats_fmt}" + ) # 2. Verify there are completed RPCs of the given method with # the expected_status. - self.assertGreater(stats.result[expected_status_int], - 0, - msg=("Expected non-zero completed RPCs with status" - f" {expected_status_fmt} for method {method}." 
- f"\nDiff stats:\n{diff_stats_fmt}")) - - def assertRpcsEventuallyGoToGivenServers(self, - test_client: XdsTestClient, - servers: List[XdsTestServer], - num_rpcs: int = 100): + self.assertGreater( + stats.result[expected_status_int], + 0, + msg=( + "Expected non-zero completed RPCs with status" + f" {expected_status_fmt} for method {method}." + f"\nDiff stats:\n{diff_stats_fmt}" + ), + ) + + def assertRpcsEventuallyGoToGivenServers( + self, + test_client: XdsTestClient, + servers: List[XdsTestServer], + num_rpcs: int = 100, + ): retryer = retryers.constant_retryer( wait_fixed=datetime.timedelta(seconds=1), timeout=datetime.timedelta(seconds=_TD_CONFIG_MAX_WAIT_SEC), - log_level=logging.INFO) + log_level=logging.INFO, + ) try: - retryer(self._assertRpcsEventuallyGoToGivenServers, test_client, - servers, num_rpcs) + retryer( + self._assertRpcsEventuallyGoToGivenServers, + test_client, + servers, + num_rpcs, + ) except retryers.RetryError as retry_error: logger.exception( - 'Rpcs did not go to expected servers before timeout %s', - _TD_CONFIG_MAX_WAIT_SEC) + "Rpcs did not go to expected servers before timeout %s", + _TD_CONFIG_MAX_WAIT_SEC, + ) raise retry_error - def _assertRpcsEventuallyGoToGivenServers(self, test_client: XdsTestClient, - servers: List[XdsTestServer], - num_rpcs: int): + def _assertRpcsEventuallyGoToGivenServers( + self, + test_client: XdsTestClient, + servers: List[XdsTestServer], + num_rpcs: int, + ): server_hostnames = [server.hostname for server in servers] - logger.info('Verifying RPCs go to servers %s', server_hostnames) + logger.info("Verifying RPCs go to servers %s", server_hostnames) lb_stats = self.getClientRpcStats(test_client, num_rpcs) failed = int(lb_stats.num_failures) self.assertLessEqual( failed, 0, - msg=f'Expected all RPCs to succeed: {failed} of {num_rpcs} failed') + msg=f"Expected all RPCs to succeed: {failed} of {num_rpcs} failed", + ) for server_hostname in server_hostnames: - self.assertIn(server_hostname, lb_stats.rpcs_by_peer, - f'Server {server_hostname} did not receive RPCs') + self.assertIn( + server_hostname, + lb_stats.rpcs_by_peer, + f"Server {server_hostname} did not receive RPCs", + ) for server_hostname in lb_stats.rpcs_by_peer.keys(): - self.assertIn(server_hostname, server_hostnames, - f'Unexpected server {server_hostname} received RPCs') + self.assertIn( + server_hostname, + server_hostnames, + f"Unexpected server {server_hostname} received RPCs", + ) def assertXdsConfigExists(self, test_client: XdsTestClient): config = test_client.csds.fetch_client_status(log_level=logging.INFO) self.assertIsNotNone(config) seen = set() - want = frozenset([ - 'listener_config', - 'cluster_config', - 'route_config', - 'endpoint_config', - ]) + want = frozenset( + [ + "listener_config", + "cluster_config", + "route_config", + "endpoint_config", + ] + ) for xds_config in config.xds_config: - seen.add(xds_config.WhichOneof('per_xds_config')) + seen.add(xds_config.WhichOneof("per_xds_config")) for generic_xds_config in config.generic_xds_configs: - if re.search(r'\.Listener$', generic_xds_config.type_url): - seen.add('listener_config') - elif re.search(r'\.RouteConfiguration$', - generic_xds_config.type_url): - seen.add('route_config') - elif re.search(r'\.Cluster$', generic_xds_config.type_url): - seen.add('cluster_config') - elif re.search(r'\.ClusterLoadAssignment$', - generic_xds_config.type_url): - seen.add('endpoint_config') - logger.debug('Received xDS config dump: %s', - json_format.MessageToJson(config, indent=2)) + if re.search(r"\.Listener$", 
generic_xds_config.type_url): + seen.add("listener_config") + elif re.search( + r"\.RouteConfiguration$", generic_xds_config.type_url + ): + seen.add("route_config") + elif re.search(r"\.Cluster$", generic_xds_config.type_url): + seen.add("cluster_config") + elif re.search( + r"\.ClusterLoadAssignment$", generic_xds_config.type_url + ): + seen.add("endpoint_config") + logger.debug( + "Received xDS config dump: %s", + json_format.MessageToJson(config, indent=2), + ) self.assertSameElements(want, seen) def assertRouteConfigUpdateTrafficHandoff( - self, test_client: XdsTestClient, - previous_route_config_version: str, retry_wait_second: int, - timeout_second: int): + self, + test_client: XdsTestClient, + previous_route_config_version: str, + retry_wait_second: int, + timeout_second: int, + ): retryer = retryers.constant_retryer( wait_fixed=datetime.timedelta(seconds=retry_wait_second), timeout=datetime.timedelta(seconds=timeout_second), retry_on_exceptions=(TdPropagationRetryableError,), logger=logger, - log_level=logging.INFO) + log_level=logging.INFO, + ) try: for attempt in retryer: with attempt: self.assertSuccessfulRpcs(test_client) raw_config = test_client.csds.fetch_client_status( - log_level=logging.INFO) + log_level=logging.INFO + ) dumped_config = xds_url_map_testcase.DumpedXdsConfig( - json_format.MessageToDict(raw_config)) + json_format.MessageToDict(raw_config) + ) route_config_version = dumped_config.rds_version if previous_route_config_version == route_config_version: logger.info( - 'Routing config not propagated yet. Retrying.') + "Routing config not propagated yet. Retrying." + ) raise TdPropagationRetryableError( "CSDS not get updated routing config corresponding" - " to the second set of url maps") + " to the second set of url maps" + ) else: self.assertSuccessfulRpcs(test_client) logger.info( - ('[SUCCESS] Confirmed successful RPC with the ' - 'updated routing config, version=%s'), - route_config_version) + ( + "[SUCCESS] Confirmed successful RPC with the " + "updated routing config, version=%s" + ), + route_config_version, + ) except retryers.RetryError as retry_error: logger.info( - ('Retry exhausted. TD routing config propagation failed after ' - 'timeout %ds. Last seen client config dump: %s'), - timeout_second, dumped_config) + ( + "Retry exhausted. TD routing config propagation failed" + " after timeout %ds. 
Last seen client config dump: %s" + ), + timeout_second, + dumped_config, + ) raise retry_error - def assertFailedRpcs(self, - test_client: XdsTestClient, - num_rpcs: Optional[int] = 100): + def assertFailedRpcs( + self, test_client: XdsTestClient, num_rpcs: Optional[int] = 100 + ): lb_stats = self.getClientRpcStats(test_client, num_rpcs) failed = int(lb_stats.num_failures) self.assertEqual( failed, num_rpcs, - msg=f'Expected all RPCs to fail: {failed} of {num_rpcs} failed') + msg=f"Expected all RPCs to fail: {failed} of {num_rpcs} failed", + ) @classmethod - def getClientRpcStats(cls, test_client: XdsTestClient, - num_rpcs: int) -> _LoadBalancerStatsResponse: + def getClientRpcStats( + cls, test_client: XdsTestClient, num_rpcs: int + ) -> _LoadBalancerStatsResponse: lb_stats = test_client.get_load_balancer_stats(num_rpcs=num_rpcs) - logger.info('[%s] << Received LoadBalancerStatsResponse:\n%s', - test_client.hostname, cls._pretty_lb_stats(lb_stats)) + logger.info( + "[%s] << Received LoadBalancerStatsResponse:\n%s", + test_client.hostname, + cls._pretty_lb_stats(lb_stats), + ) return lb_stats def assertAllBackendsReceivedRpcs(self, lb_stats): @@ -480,11 +568,13 @@ def assertAllBackendsReceivedRpcs(self, lb_stats): self.assertGreater( int(rpcs_count), 0, - msg=f'Backend {backend} did not receive a single RPC') + msg=f"Backend {backend} did not receive a single RPC", + ) -class IsolatedXdsKubernetesTestCase(XdsKubernetesBaseTestCase, - metaclass=abc.ABCMeta): +class IsolatedXdsKubernetesTestCase( + XdsKubernetesBaseTestCase, metaclass=abc.ABCMeta +): """Isolated test case. Base class for tests cases where infra resources are created before @@ -497,26 +587,32 @@ def setUp(self): if self.resource_suffix_randomize: self.resource_suffix = helpers_rand.random_resource_suffix() - logger.info('Test run resource prefix: %s, suffix: %s', - self.resource_prefix, self.resource_suffix) + logger.info( + "Test run resource prefix: %s, suffix: %s", + self.resource_prefix, + self.resource_suffix, + ) # TD Manager self.td = self.initTrafficDirectorManager() # Test Server runner self.server_namespace = KubernetesServerRunner.make_namespace_name( - self.resource_prefix, self.resource_suffix) + self.resource_prefix, self.resource_suffix + ) self.server_runner = self.initKubernetesServerRunner() # Test Client runner self.client_namespace = KubernetesClientRunner.make_namespace_name( - self.resource_prefix, self.resource_suffix) + self.resource_prefix, self.resource_suffix + ) self.client_runner = self.initKubernetesClientRunner() # Ensures the firewall exist if self.ensure_firewall: self.td.create_firewall_rule( - allowed_ports=self.firewall_allowed_ports) + allowed_ports=self.firewall_allowed_ports + ) # Randomize xds port, when it's set to 0 if self.server_xds_port == 0: @@ -526,7 +622,7 @@ def setUp(self): # forwarding rule. This check is better than nothing, # but we should find a better approach. 
self.server_xds_port = self.td.find_unused_forwarding_rule_port() - logger.info('Found unused xds port: %s', self.server_xds_port) + logger.info("Found unused xds port: %s", self.server_xds_port) @abc.abstractmethod def initTrafficDirectorManager(self) -> TrafficDirectorManager: @@ -541,27 +637,31 @@ def initKubernetesClientRunner(self) -> KubernetesClientRunner: raise NotImplementedError def tearDown(self): - logger.info('----- TestMethod %s teardown -----', self.id()) - logger.debug('Getting pods restart times') + logger.info("----- TestMethod %s teardown -----", self.id()) + logger.debug("Getting pods restart times") client_restarts: int = 0 server_restarts: int = 0 try: client_restarts = self.client_runner.get_pod_restarts( - self.client_runner.deployment) + self.client_runner.deployment + ) server_restarts = self.server_runner.get_pod_restarts( - self.server_runner.deployment) + self.server_runner.deployment + ) except (retryers.RetryError, k8s.NotFound) as e: logger.exception(e) - retryer = retryers.constant_retryer(wait_fixed=_timedelta(seconds=10), - attempts=3, - log_level=logging.INFO) + retryer = retryers.constant_retryer( + wait_fixed=_timedelta(seconds=10), + attempts=3, + log_level=logging.INFO, + ) try: retryer(self.cleanup) except retryers.RetryError: - logger.exception('Got error during teardown') + logger.exception("Got error during teardown") finally: - logger.info('----- Test client/server logs -----') + logger.info("----- Test client/server logs -----") self.client_runner.logs_explorer_run_history_links() self.server_runner.logs_explorer_run_history_links() @@ -569,25 +669,28 @@ def tearDown(self): self.assertEqual( client_restarts, 0, - msg= - ('Client pods unexpectedly restarted' - f' {client_restarts} times during test.' - ' In most cases, this is caused by the test client app crash.' - )) + msg=( + "Client pods unexpectedly restarted" + f" {client_restarts} times during test. In most cases, this" + " is caused by the test client app crash." + ), + ) self.assertEqual( server_restarts, 0, - msg= - ('Server pods unexpectedly restarted' - f' {server_restarts} times during test.' - ' In most cases, this is caused by the test client app crash.' - )) + msg=( + "Server pods unexpectedly restarted" + f" {server_restarts} times during test. In most cases, this" + " is caused by the test client app crash." 
+ ), + ) def cleanup(self): self.td.cleanup(force=self.force_cleanup) self.client_runner.cleanup(force=self.force_cleanup) - self.server_runner.cleanup(force=self.force_cleanup, - force_namespace=self.force_cleanup) + self.server_runner.cleanup( + force=self.force_cleanup, force_namespace=self.force_cleanup + ) class RegularXdsKubernetesTestCase(IsolatedXdsKubernetesTestCase): @@ -600,8 +703,9 @@ def setUpClass(cls): """ super().setUpClass() if cls.server_maintenance_port is None: - cls.server_maintenance_port = \ + cls.server_maintenance_port = ( KubernetesServerRunner.DEFAULT_MAINTENANCE_PORT + ) def initTrafficDirectorManager(self) -> TrafficDirectorManager: return TrafficDirectorManager( @@ -610,12 +714,14 @@ def initTrafficDirectorManager(self) -> TrafficDirectorManager: resource_prefix=self.resource_prefix, resource_suffix=self.resource_suffix, network=self.network, - compute_api_version=self.compute_api_version) + compute_api_version=self.compute_api_version, + ) def initKubernetesServerRunner(self) -> KubernetesServerRunner: return KubernetesServerRunner( - k8s.KubernetesNamespace(self.k8s_api_manager, - self.server_namespace), + k8s.KubernetesNamespace( + self.k8s_api_manager, self.server_namespace + ), deployment_name=self.server_name, image_name=self.server_image, td_bootstrap_image=self.td_bootstrap_image, @@ -625,12 +731,14 @@ def initKubernetesServerRunner(self) -> KubernetesServerRunner: xds_server_uri=self.xds_server_uri, network=self.network, debug_use_port_forwarding=self.debug_use_port_forwarding, - enable_workload_identity=self.enable_workload_identity) + enable_workload_identity=self.enable_workload_identity, + ) def initKubernetesClientRunner(self) -> KubernetesClientRunner: return KubernetesClientRunner( - k8s.KubernetesNamespace(self.k8s_api_manager, - self.client_namespace), + k8s.KubernetesNamespace( + self.k8s_api_manager, self.client_namespace + ), deployment_name=self.client_name, image_name=self.client_image, td_bootstrap_image=self.td_bootstrap_image, @@ -642,28 +750,32 @@ def initKubernetesClientRunner(self) -> KubernetesClientRunner: debug_use_port_forwarding=self.debug_use_port_forwarding, enable_workload_identity=self.enable_workload_identity, stats_port=self.client_port, - reuse_namespace=self.server_namespace == self.client_namespace) + reuse_namespace=self.server_namespace == self.client_namespace, + ) - def startTestServers(self, - replica_count=1, - server_runner=None, - **kwargs) -> List[XdsTestServer]: + def startTestServers( + self, replica_count=1, server_runner=None, **kwargs + ) -> List[XdsTestServer]: if server_runner is None: server_runner = self.server_runner test_servers = server_runner.run( replica_count=replica_count, test_port=self.server_port, maintenance_port=self.server_maintenance_port, - **kwargs) + **kwargs, + ) for test_server in test_servers: - test_server.set_xds_address(self.server_xds_host, - self.server_xds_port) + test_server.set_xds_address( + self.server_xds_host, self.server_xds_port + ) return test_servers - def startTestClient(self, test_server: XdsTestServer, - **kwargs) -> XdsTestClient: - test_client = self.client_runner.run(server_target=test_server.xds_uri, - **kwargs) + def startTestClient( + self, test_server: XdsTestServer, **kwargs + ) -> XdsTestClient: + test_client = self.client_runner.run( + server_target=test_server.xds_uri, **kwargs + ) test_client.wait_for_active_server_channel() return test_client @@ -678,11 +790,13 @@ def initTrafficDirectorManager(self) -> TrafficDirectorAppNetManager: 
resource_prefix=self.resource_prefix, resource_suffix=self.resource_suffix, network=self.network, - compute_api_version=self.compute_api_version) + compute_api_version=self.compute_api_version, + ) class SecurityXdsKubernetesTestCase(IsolatedXdsKubernetesTestCase): """Test case base class for testing PSM security features in isolation.""" + td: TrafficDirectorSecureManager class SecurityMode(enum.Enum): @@ -702,8 +816,9 @@ def setUpClass(cls): # Health Checks and Channelz tests available. # When not provided, use explicit numeric port value, so # Backend Health Checks are created on a fixed port. - cls.server_maintenance_port = \ + cls.server_maintenance_port = ( KubernetesServerRunner.DEFAULT_SECURE_MODE_MAINTENANCE_PORT + ) def initTrafficDirectorManager(self) -> TrafficDirectorSecureManager: return TrafficDirectorSecureManager( @@ -712,12 +827,14 @@ def initTrafficDirectorManager(self) -> TrafficDirectorSecureManager: resource_prefix=self.resource_prefix, resource_suffix=self.resource_suffix, network=self.network, - compute_api_version=self.compute_api_version) + compute_api_version=self.compute_api_version, + ) def initKubernetesServerRunner(self) -> KubernetesServerRunner: return KubernetesServerRunner( - k8s.KubernetesNamespace(self.k8s_api_manager, - self.server_namespace), + k8s.KubernetesNamespace( + self.k8s_api_manager, self.server_namespace + ), deployment_name=self.server_name, image_name=self.server_image, td_bootstrap_image=self.td_bootstrap_image, @@ -726,13 +843,15 @@ def initKubernetesServerRunner(self) -> KubernetesServerRunner: gcp_service_account=self.gcp_service_account, network=self.network, xds_server_uri=self.xds_server_uri, - deployment_template='server-secure.deployment.yaml', - debug_use_port_forwarding=self.debug_use_port_forwarding) + deployment_template="server-secure.deployment.yaml", + debug_use_port_forwarding=self.debug_use_port_forwarding, + ) def initKubernetesClientRunner(self) -> KubernetesClientRunner: return KubernetesClientRunner( - k8s.KubernetesNamespace(self.k8s_api_manager, - self.client_namespace), + k8s.KubernetesNamespace( + self.k8s_api_manager, self.client_namespace + ), deployment_name=self.client_name, image_name=self.client_image, td_bootstrap_image=self.td_bootstrap_image, @@ -741,10 +860,11 @@ def initKubernetesClientRunner(self) -> KubernetesClientRunner: gcp_service_account=self.gcp_service_account, xds_server_uri=self.xds_server_uri, network=self.network, - deployment_template='client-secure.deployment.yaml', + deployment_template="client-secure.deployment.yaml", stats_port=self.client_port, reuse_namespace=self.server_namespace == self.client_namespace, - debug_use_port_forwarding=self.debug_use_port_forwarding) + debug_use_port_forwarding=self.debug_use_port_forwarding, + ) def startSecureTestServer(self, replica_count=1, **kwargs) -> XdsTestServer: test_server = self.server_runner.run( @@ -752,43 +872,55 @@ def startSecureTestServer(self, replica_count=1, **kwargs) -> XdsTestServer: test_port=self.server_port, maintenance_port=self.server_maintenance_port, secure_mode=True, - **kwargs)[0] + **kwargs, + )[0] test_server.set_xds_address(self.server_xds_host, self.server_xds_port) return test_server - def setupSecurityPolicies(self, *, server_tls, server_mtls, client_tls, - client_mtls): - self.td.setup_client_security(server_namespace=self.server_namespace, - server_name=self.server_name, - tls=client_tls, - mtls=client_mtls) - self.td.setup_server_security(server_namespace=self.server_namespace, - server_name=self.server_name, - 
server_port=self.server_port, - tls=server_tls, - mtls=server_mtls) - - def startSecureTestClient(self, - test_server: XdsTestServer, - *, - wait_for_active_server_channel=True, - **kwargs) -> XdsTestClient: - test_client = self.client_runner.run(server_target=test_server.xds_uri, - secure_mode=True, - **kwargs) + def setupSecurityPolicies( + self, *, server_tls, server_mtls, client_tls, client_mtls + ): + self.td.setup_client_security( + server_namespace=self.server_namespace, + server_name=self.server_name, + tls=client_tls, + mtls=client_mtls, + ) + self.td.setup_server_security( + server_namespace=self.server_namespace, + server_name=self.server_name, + server_port=self.server_port, + tls=server_tls, + mtls=server_mtls, + ) + + def startSecureTestClient( + self, + test_server: XdsTestServer, + *, + wait_for_active_server_channel=True, + **kwargs, + ) -> XdsTestClient: + test_client = self.client_runner.run( + server_target=test_server.xds_uri, secure_mode=True, **kwargs + ) if wait_for_active_server_channel: test_client.wait_for_active_server_channel() return test_client - def assertTestAppSecurity(self, mode: SecurityMode, - test_client: XdsTestClient, - test_server: XdsTestServer): + def assertTestAppSecurity( + self, + mode: SecurityMode, + test_client: XdsTestClient, + test_server: XdsTestServer, + ): client_socket, server_socket = self.getConnectedSockets( - test_client, test_server) + test_client, test_server + ) server_security: grpc_channelz.Security = server_socket.security client_security: grpc_channelz.Security = client_socket.security - logger.info('Server certs: %s', self.debug_sock_certs(server_security)) - logger.info('Client certs: %s', self.debug_sock_certs(client_security)) + logger.info("Server certs: %s", self.debug_sock_certs(server_security)) + logger.info("Client certs: %s", self.debug_sock_certs(client_security)) if mode is self.SecurityMode.MTLS: self.assertSecurityMtls(client_security, server_security) @@ -797,100 +929,144 @@ def assertTestAppSecurity(self, mode: SecurityMode, elif mode is self.SecurityMode.PLAINTEXT: self.assertSecurityPlaintext(client_security, server_security) else: - raise TypeError('Incorrect security mode') - - def assertSecurityMtls(self, client_security: grpc_channelz.Security, - server_security: grpc_channelz.Security): - self.assertEqual(client_security.WhichOneof('model'), - 'tls', - msg='(mTLS) Client socket security model must be TLS') - self.assertEqual(server_security.WhichOneof('model'), - 'tls', - msg='(mTLS) Server socket security model must be TLS') + raise TypeError("Incorrect security mode") + + def assertSecurityMtls( + self, + client_security: grpc_channelz.Security, + server_security: grpc_channelz.Security, + ): + self.assertEqual( + client_security.WhichOneof("model"), + "tls", + msg="(mTLS) Client socket security model must be TLS", + ) + self.assertEqual( + server_security.WhichOneof("model"), + "tls", + msg="(mTLS) Server socket security model must be TLS", + ) server_tls, client_tls = server_security.tls, client_security.tls # Confirm regular TLS: server local cert == client remote cert - self.assertNotEmpty(client_tls.remote_certificate, - msg="(mTLS) Client remote certificate is missing") + self.assertNotEmpty( + client_tls.remote_certificate, + msg="(mTLS) Client remote certificate is missing", + ) if self.check_local_certs: self.assertNotEmpty( server_tls.local_certificate, - msg="(mTLS) Server local certificate is missing") + msg="(mTLS) Server local certificate is missing", + ) self.assertEqual( 
server_tls.local_certificate, client_tls.remote_certificate, - msg="(mTLS) Server local certificate must match client's " - "remote certificate") + msg=( + "(mTLS) Server local certificate must match client's " + "remote certificate" + ), + ) # mTLS: server remote cert == client local cert - self.assertNotEmpty(server_tls.remote_certificate, - msg="(mTLS) Server remote certificate is missing") + self.assertNotEmpty( + server_tls.remote_certificate, + msg="(mTLS) Server remote certificate is missing", + ) if self.check_local_certs: self.assertNotEmpty( client_tls.local_certificate, - msg="(mTLS) Client local certificate is missing") + msg="(mTLS) Client local certificate is missing", + ) self.assertEqual( server_tls.remote_certificate, client_tls.local_certificate, - msg="(mTLS) Server remote certificate must match client's " - "local certificate") - - def assertSecurityTls(self, client_security: grpc_channelz.Security, - server_security: grpc_channelz.Security): - self.assertEqual(client_security.WhichOneof('model'), - 'tls', - msg='(TLS) Client socket security model must be TLS') - self.assertEqual(server_security.WhichOneof('model'), - 'tls', - msg='(TLS) Server socket security model must be TLS') + msg=( + "(mTLS) Server remote certificate must match client's " + "local certificate" + ), + ) + + def assertSecurityTls( + self, + client_security: grpc_channelz.Security, + server_security: grpc_channelz.Security, + ): + self.assertEqual( + client_security.WhichOneof("model"), + "tls", + msg="(TLS) Client socket security model must be TLS", + ) + self.assertEqual( + server_security.WhichOneof("model"), + "tls", + msg="(TLS) Server socket security model must be TLS", + ) server_tls, client_tls = server_security.tls, client_security.tls # Regular TLS: server local cert == client remote cert - self.assertNotEmpty(client_tls.remote_certificate, - msg="(TLS) Client remote certificate is missing") + self.assertNotEmpty( + client_tls.remote_certificate, + msg="(TLS) Client remote certificate is missing", + ) if self.check_local_certs: - self.assertNotEmpty(server_tls.local_certificate, - msg="(TLS) Server local certificate is missing") + self.assertNotEmpty( + server_tls.local_certificate, + msg="(TLS) Server local certificate is missing", + ) self.assertEqual( server_tls.local_certificate, client_tls.remote_certificate, - msg="(TLS) Server local certificate must match client " - "remote certificate") + msg=( + "(TLS) Server local certificate must match client " + "remote certificate" + ), + ) # mTLS must not be used self.assertEmpty( server_tls.remote_certificate, - msg="(TLS) Server remote certificate must be empty in TLS mode. " - "Is server security incorrectly configured for mTLS?") + msg=( + "(TLS) Server remote certificate must be empty in TLS mode. " + "Is server security incorrectly configured for mTLS?" + ), + ) self.assertEmpty( client_tls.local_certificate, - msg="(TLS) Client local certificate must be empty in TLS mode. " - "Is client security incorrectly configured for mTLS?") + msg=( + "(TLS) Client local certificate must be empty in TLS mode. " + "Is client security incorrectly configured for mTLS?" 
+ ), + ) def assertSecurityPlaintext(self, client_security, server_security): server_tls, client_tls = server_security.tls, client_security.tls # Not TLS self.assertEmpty( server_tls.local_certificate, - msg="(Plaintext) Server local certificate must be empty.") + msg="(Plaintext) Server local certificate must be empty.", + ) self.assertEmpty( client_tls.local_certificate, - msg="(Plaintext) Client local certificate must be empty.") + msg="(Plaintext) Client local certificate must be empty.", + ) # Not mTLS self.assertEmpty( server_tls.remote_certificate, - msg="(Plaintext) Server remote certificate must be empty.") + msg="(Plaintext) Server remote certificate must be empty.", + ) self.assertEmpty( client_tls.local_certificate, - msg="(Plaintext) Client local certificate must be empty.") + msg="(Plaintext) Client local certificate must be empty.", + ) def assertClientCannotReachServerRepeatedly( - self, - test_client: XdsTestClient, - *, - times: Optional[int] = None, - delay: Optional[_timedelta] = None): + self, + test_client: XdsTestClient, + *, + times: Optional[int] = None, + delay: Optional[_timedelta] = None, + ): """ Asserts that the client repeatedly cannot reach the server. @@ -916,8 +1092,11 @@ def assertClientCannotReachServerRepeatedly( for i in range(1, times + 1): self.assertClientCannotReachServer(test_client) if i < times: - logger.info('Check %s passed, waiting %s before the next check', - i, delay) + logger.info( + "Check %s passed, waiting %s before the next check", + i, + delay, + ) time.sleep(delay.total_seconds()) def assertClientCannotReachServer(self, test_client: XdsTestClient): @@ -926,13 +1105,19 @@ def assertClientCannotReachServer(self, test_client: XdsTestClient): def assertClientChannelFailed(self, test_client: XdsTestClient): channel = test_client.wait_for_server_channel_state( - state=_ChannelState.TRANSIENT_FAILURE) + state=_ChannelState.TRANSIENT_FAILURE + ) subchannels = list( - test_client.channelz.list_channel_subchannels(channel)) - self.assertLen(subchannels, - 1, - msg="Client channel must have exactly one subchannel " - "in state TRANSIENT_FAILURE.") + test_client.channelz.list_channel_subchannels(channel) + ) + self.assertLen( + subchannels, + 1, + msg=( + "Client channel must have exactly one subchannel " + "in state TRANSIENT_FAILURE." 
+ ), + ) @staticmethod def getConnectedSockets( @@ -944,15 +1129,17 @@ def getConnectedSockets( @classmethod def debug_sock_certs(cls, security: grpc_channelz.Security): - if security.WhichOneof('model') == 'other': - return f'other: <{security.other.name}={security.other.value}>' + if security.WhichOneof("model") == "other": + return f"other: <{security.other.name}={security.other.value}>" - return (f'local: <{cls.debug_cert(security.tls.local_certificate)}>, ' - f'remote: <{cls.debug_cert(security.tls.remote_certificate)}>') + return ( + f"local: <{cls.debug_cert(security.tls.local_certificate)}>, " + f"remote: <{cls.debug_cert(security.tls.remote_certificate)}>" + ) @staticmethod def debug_cert(cert): if not cert: - return 'missing' + return "missing" sha1 = hashlib.sha1(cert) - return f'sha1={sha1.hexdigest()}, len={len(cert)}' + return f"sha1={sha1.hexdigest()}, len={len(cert)}" diff --git a/tools/run_tests/xds_k8s_test_driver/framework/xds_url_map_test_resources.py b/tools/run_tests/xds_k8s_test_driver/framework/xds_url_map_test_resources.py index 180c7ca52e11b..9313981323396 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/xds_url_map_test_resources.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/xds_url_map_test_resources.py @@ -32,10 +32,12 @@ flags.adopt_module_key_flags(xds_flags) flags.adopt_module_key_flags(xds_k8s_flags) -STRATEGY = flags.DEFINE_enum('strategy', - default='reuse', - enum_values=['create', 'keep', 'reuse'], - help='Strategy of GCP resources management') +STRATEGY = flags.DEFINE_enum( + "strategy", + default="reuse", + enum_values=["create", "keep", "reuse"], + help="Strategy of GCP resources management", +) # Type alias _KubernetesServerRunner = k8s_xds_server_runner.KubernetesServerRunner @@ -59,16 +61,21 @@ def __init__(self, url_map_name: str): def get_map(self) -> UrlMapType: return self._map - def apply_change(self, test_case: 'XdsUrlMapTestCase') -> None: - logging.info('Apply urlMap change for test case: %s.%s', - test_case.short_module_name, test_case.__name__) + def apply_change(self, test_case: "XdsUrlMapTestCase") -> None: + logging.info( + "Apply urlMap change for test case: %s.%s", + test_case.short_module_name, + test_case.__name__, + ) url_map_parts = test_case.url_map_change( - *self._get_test_case_url_map(test_case)) + *self._get_test_case_url_map(test_case) + ) self._set_test_case_url_map(*url_map_parts) @staticmethod def _get_test_case_url_map( - test_case: 'XdsUrlMapTestCase') -> Tuple[HostRule, PathMatcher]: + test_case: "XdsUrlMapTestCase", + ) -> Tuple[HostRule, PathMatcher]: host_rule = { "hosts": [test_case.hostname()], "pathMatcher": test_case.path_matcher_name(), @@ -79,8 +86,9 @@ def _get_test_case_url_map( } return host_rule, path_matcher - def _set_test_case_url_map(self, host_rule: HostRule, - path_matcher: PathMatcher) -> None: + def _set_test_case_url_map( + self, host_rule: HostRule, path_matcher: PathMatcher + ) -> None: self._map["hostRules"].append(host_rule) self._map["pathMatchers"].append(path_matcher) @@ -99,7 +107,7 @@ class is about to be instantiated. 
for key, value in inspect.getmembers(flag_module): if isinstance(value, flags.FlagHolder): res[key.lower()] = value.value - res['strategy'] = STRATEGY.value + res["strategy"] = STRATEGY.value return res @@ -141,13 +149,17 @@ def __init__(self, absl_flags: Mapping[str, Any] = None): for key in absl_flags: setattr(self, key, absl_flags[key]) # Pick a client_namespace_suffix if not set - if getattr(self, 'resource_suffix', None) is None: + if getattr(self, "resource_suffix", None) is None: self.resource_suffix = "" else: raise NotImplementedError( - 'Predefined resource_suffix is not supported for UrlMap tests') - logging.info('GcpResourceManager: resource prefix=%s, suffix=%s', - self.resource_prefix, self.resource_suffix) + "Predefined resource_suffix is not supported for UrlMap tests" + ) + logging.info( + "GcpResourceManager: resource prefix=%s, suffix=%s", + self.resource_prefix, + self.resource_suffix, + ) # Must be called before KubernetesApiManager or GcpApiManager init. xds_flags.set_socket_default_timeout_from_flag() @@ -164,8 +176,9 @@ def __init__(self, absl_flags: Mapping[str, Any] = None): compute_api_version=self.compute_api_version, ) # Kubernetes namespace - self.k8s_namespace = k8s.KubernetesNamespace(self.k8s_api_manager, - self.resource_prefix) + self.k8s_namespace = k8s.KubernetesNamespace( + self.k8s_api_manager, self.resource_prefix + ) # Kubernetes Test Servers self.test_server_runner = _KubernetesServerRunner( self.k8s_namespace, @@ -177,10 +190,11 @@ def __init__(self, absl_flags: Mapping[str, Any] = None): td_bootstrap_image=self.td_bootstrap_image, xds_server_uri=self.xds_server_uri, network=self.network, - enable_workload_identity=self.enable_workload_identity) + enable_workload_identity=self.enable_workload_identity, + ) self.test_server_alternative_runner = _KubernetesServerRunner( self.k8s_namespace, - deployment_name=self.server_name + '-alternative', + deployment_name=self.server_name + "-alternative", image_name=self.server_image, gcp_project=self.project, gcp_api_manager=self.gcp_api_manager, @@ -189,10 +203,11 @@ def __init__(self, absl_flags: Mapping[str, Any] = None): xds_server_uri=self.xds_server_uri, network=self.network, enable_workload_identity=self.enable_workload_identity, - reuse_namespace=True) + reuse_namespace=True, + ) self.test_server_affinity_runner = _KubernetesServerRunner( self.k8s_namespace, - deployment_name=self.server_name + '-affinity', + deployment_name=self.server_name + "-affinity", image_name=self.server_image, gcp_project=self.project, gcp_api_manager=self.gcp_api_manager, @@ -201,20 +216,25 @@ def __init__(self, absl_flags: Mapping[str, Any] = None): xds_server_uri=self.xds_server_uri, network=self.network, enable_workload_identity=self.enable_workload_identity, - reuse_namespace=True) - logging.info('Strategy of GCP resources management: %s', self.strategy) + reuse_namespace=True, + ) + logging.info("Strategy of GCP resources management: %s", self.strategy) def create_test_client_runner(self): if self.resource_suffix: client_namespace_suffix = self.resource_suffix else: - client_namespace_suffix = framework.helpers.rand.random_resource_suffix( + client_namespace_suffix = ( + framework.helpers.rand.random_resource_suffix() ) - logging.info('GcpResourceManager: client_namespace_suffix=%s', - client_namespace_suffix) + logging.info( + "GcpResourceManager: client_namespace_suffix=%s", + client_namespace_suffix, + ) # Kubernetes Test Client namespace_name = _KubernetesClientRunner.make_namespace_name( - self.resource_prefix, 
client_namespace_suffix) + self.resource_prefix, client_namespace_suffix + ) return _KubernetesClientRunner( k8s.KubernetesNamespace(self.k8s_api_manager, namespace_name), deployment_name=self.client_name, @@ -227,27 +247,31 @@ def create_test_client_runner(self): network=self.network, debug_use_port_forwarding=self.debug_use_port_forwarding, enable_workload_identity=self.enable_workload_identity, - stats_port=self.client_port) + stats_port=self.client_port, + ) def _pre_cleanup(self): # Cleanup existing debris - logging.info('GcpResourceManager: pre clean-up') + logging.info("GcpResourceManager: pre clean-up") self.td.cleanup(force=True) self.test_server_runner.delete_namespace() - def setup(self, test_case_classes: Iterable['XdsUrlMapTestCase']) -> None: - if self.strategy not in ['create', 'keep']: - logging.info('GcpResourceManager: skipping setup for strategy [%s]', - self.strategy) + def setup(self, test_case_classes: Iterable["XdsUrlMapTestCase"]) -> None: + if self.strategy not in ["create", "keep"]: + logging.info( + "GcpResourceManager: skipping setup for strategy [%s]", + self.strategy, + ) return # Clean up debris from previous runs self._pre_cleanup() # Start creating GCP resources - logging.info('GcpResourceManager: start setup') + logging.info("GcpResourceManager: start setup") # Firewall if self.ensure_firewall: self.td.create_firewall_rule( - allowed_ports=self.firewall_allowed_ports) + allowed_ports=self.firewall_allowed_ports + ) # Health Checks self.td.create_health_check() # Backend Services @@ -256,7 +280,8 @@ def setup(self, test_case_classes: Iterable['XdsUrlMapTestCase']) -> None: self.td.create_affinity_backend_service() # Construct UrlMap from test classes aggregator = _UrlMapChangeAggregator( - url_map_name=self.td.make_resource_name(self.td.URL_MAP_NAME)) + url_map_name=self.td.make_resource_name(self.td.URL_MAP_NAME) + ) for test_case_class in test_case_classes: aggregator.apply_change(test_case_class) final_url_map = aggregator.get_map() @@ -269,53 +294,67 @@ def setup(self, test_case_classes: Iterable['XdsUrlMapTestCase']) -> None: # Kubernetes Test Server self.test_server_runner.run( test_port=self.server_port, - maintenance_port=self.server_maintenance_port) + maintenance_port=self.server_maintenance_port, + ) # Kubernetes Test Server Alternative self.test_server_alternative_runner.run( test_port=self.server_port, - maintenance_port=self.server_maintenance_port) + maintenance_port=self.server_maintenance_port, + ) # Kubernetes Test Server Affinity. 3 endpoints to test that only the # picked sub-channel is connected. 
self.test_server_affinity_runner.run( test_port=self.server_port, maintenance_port=self.server_maintenance_port, - replica_count=3) + replica_count=3, + ) # Add backend to default backend service neg_name, neg_zones = self.k8s_namespace.get_service_neg( - self.test_server_runner.service_name, self.server_port) + self.test_server_runner.service_name, self.server_port + ) self.td.backend_service_add_neg_backends(neg_name, neg_zones) # Add backend to alternative backend service neg_name_alt, neg_zones_alt = self.k8s_namespace.get_service_neg( - self.test_server_alternative_runner.service_name, self.server_port) + self.test_server_alternative_runner.service_name, self.server_port + ) self.td.alternative_backend_service_add_neg_backends( - neg_name_alt, neg_zones_alt) + neg_name_alt, neg_zones_alt + ) # Add backend to affinity backend service - neg_name_affinity, neg_zones_affinity = self.k8s_namespace.get_service_neg( - self.test_server_affinity_runner.service_name, self.server_port) + ( + neg_name_affinity, + neg_zones_affinity, + ) = self.k8s_namespace.get_service_neg( + self.test_server_affinity_runner.service_name, self.server_port + ) self.td.affinity_backend_service_add_neg_backends( - neg_name_affinity, neg_zones_affinity) + neg_name_affinity, neg_zones_affinity + ) # Wait for healthy backends self.td.wait_for_backends_healthy_status() self.td.wait_for_alternative_backends_healthy_status() self.td.wait_for_affinity_backends_healthy_status() def cleanup(self) -> None: - if self.strategy not in ['create']: + if self.strategy not in ["create"]: logging.info( - 'GcpResourceManager: skipping tear down for strategy [%s]', - self.strategy) + "GcpResourceManager: skipping tear down for strategy [%s]", + self.strategy, + ) return - logging.info('GcpResourceManager: start tear down') - if hasattr(self, 'td'): + logging.info("GcpResourceManager: start tear down") + if hasattr(self, "td"): self.td.cleanup(force=True) - if hasattr(self, 'test_server_runner'): + if hasattr(self, "test_server_runner"): self.test_server_runner.cleanup(force=True) - if hasattr(self, 'test_server_alternative_runner'): - self.test_server_alternative_runner.cleanup(force=True, - force_namespace=True) - if hasattr(self, 'test_server_affinity_runner'): - self.test_server_affinity_runner.cleanup(force=True, - force_namespace=True) + if hasattr(self, "test_server_alternative_runner"): + self.test_server_alternative_runner.cleanup( + force=True, force_namespace=True + ) + if hasattr(self, "test_server_affinity_runner"): + self.test_server_affinity_runner.cleanup( + force=True, force_namespace=True + ) @functools.lru_cache(None) def default_backend_service(self) -> str: diff --git a/tools/run_tests/xds_k8s_test_driver/framework/xds_url_map_testcase.py b/tools/run_tests/xds_k8s_test_driver/framework/xds_url_map_testcase.py index 0220d7e2a0438..efa513ce14b4c 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/xds_url_map_testcase.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/xds_url_map_testcase.py @@ -44,14 +44,14 @@ flags.adopt_module_key_flags(xds_url_map_test_resources) # Define urlMap specific flags -QPS = flags.DEFINE_integer('qps', default=25, help='The QPS client is sending') +QPS = flags.DEFINE_integer("qps", default=25, help="The QPS client is sending") # Test configs _URL_MAP_PROPAGATE_TIMEOUT_SEC = 600 # With the per-run IAM change, the first xDS response has a several minutes # delay. We want to increase the interval, reduce the log spam. 
_URL_MAP_PROPAGATE_CHECK_INTERVAL_SEC = 15 -URL_MAP_TESTCASE_FILE_SUFFIX = '_test.py' +URL_MAP_TESTCASE_FILE_SUFFIX = "_test.py" _CLIENT_CONFIGURE_WAIT_SEC = 2 # Type aliases @@ -64,14 +64,15 @@ _timedelta = datetime.timedelta # ProtoBuf translatable RpcType enums -RpcTypeUnaryCall = 'UNARY_CALL' -RpcTypeEmptyCall = 'EMPTY_CALL' +RpcTypeUnaryCall = "UNARY_CALL" +RpcTypeEmptyCall = "EMPTY_CALL" -def _split_camel(s: str, delimiter: str = '-') -> str: +def _split_camel(s: str, delimiter: str = "-") -> str: """Turn camel case name to snake-case-like name.""" - return ''.join(delimiter + c.lower() if c.isupper() else c - for c in s).lstrip(delimiter) + return "".join( + delimiter + c.lower() if c.isupper() else c for c in s + ).lstrip(delimiter) class DumpedXdsConfig(dict): @@ -89,59 +90,75 @@ def __init__(self, xds_json: JsonType): # pylint: disable=too-many-branches self.cds = [] self.eds = [] self.endpoints = [] - for xds_config in self.get('xdsConfig', []): + for xds_config in self.get("xdsConfig", []): try: - if 'listenerConfig' in xds_config: - self.lds = xds_config['listenerConfig']['dynamicListeners'][ - 0]['activeState']['listener'] - elif 'routeConfig' in xds_config: - self.rds = xds_config['routeConfig']['dynamicRouteConfigs'][ - 0]['routeConfig'] - self.rds_version = xds_config['routeConfig'][ - 'dynamicRouteConfigs'][0]['versionInfo'] - elif 'clusterConfig' in xds_config: - for cluster in xds_config['clusterConfig'][ - 'dynamicActiveClusters']: - self.cds.append(cluster['cluster']) - elif 'endpointConfig' in xds_config: - for endpoint in xds_config['endpointConfig'][ - 'dynamicEndpointConfigs']: - self.eds.append(endpoint['endpointConfig']) + if "listenerConfig" in xds_config: + self.lds = xds_config["listenerConfig"]["dynamicListeners"][ + 0 + ]["activeState"]["listener"] + elif "routeConfig" in xds_config: + self.rds = xds_config["routeConfig"]["dynamicRouteConfigs"][ + 0 + ]["routeConfig"] + self.rds_version = xds_config["routeConfig"][ + "dynamicRouteConfigs" + ][0]["versionInfo"] + elif "clusterConfig" in xds_config: + for cluster in xds_config["clusterConfig"][ + "dynamicActiveClusters" + ]: + self.cds.append(cluster["cluster"]) + elif "endpointConfig" in xds_config: + for endpoint in xds_config["endpointConfig"][ + "dynamicEndpointConfigs" + ]: + self.eds.append(endpoint["endpointConfig"]) # TODO(lidiz) reduce the catch to LookupError except Exception as e: # pylint: disable=broad-except - logging.debug('Parsing dumped xDS config failed with %s: %s', - type(e), e) - for generic_xds_config in self.get('genericXdsConfigs', []): + logging.debug( + "Parsing dumped xDS config failed with %s: %s", type(e), e + ) + for generic_xds_config in self.get("genericXdsConfigs", []): try: - if re.search(r'\.Listener$', generic_xds_config['typeUrl']): + if re.search(r"\.Listener$", generic_xds_config["typeUrl"]): self.lds = generic_xds_config["xdsConfig"] - elif re.search(r'\.RouteConfiguration$', - generic_xds_config['typeUrl']): + elif re.search( + r"\.RouteConfiguration$", generic_xds_config["typeUrl"] + ): self.rds = generic_xds_config["xdsConfig"] self.rds_version = generic_xds_config["versionInfo"] - elif re.search(r'\.Cluster$', generic_xds_config['typeUrl']): + elif re.search(r"\.Cluster$", generic_xds_config["typeUrl"]): self.cds.append(generic_xds_config["xdsConfig"]) - elif re.search(r'\.ClusterLoadAssignment$', - generic_xds_config['typeUrl']): + elif re.search( + r"\.ClusterLoadAssignment$", generic_xds_config["typeUrl"] + ): self.eds.append(generic_xds_config["xdsConfig"]) # 
TODO(lidiz) reduce the catch to LookupError except Exception as e: # pylint: disable=broad-except - logging.debug('Parsing dumped xDS config failed with %s: %s', - type(e), e) + logging.debug( + "Parsing dumped xDS config failed with %s: %s", type(e), e + ) for endpoint_config in self.eds: - for endpoint in endpoint_config.get('endpoints', {}): - for lb_endpoint in endpoint.get('lbEndpoints', {}): + for endpoint in endpoint_config.get("endpoints", {}): + for lb_endpoint in endpoint.get("lbEndpoints", {}): try: - if lb_endpoint['healthStatus'] == 'HEALTHY': + if lb_endpoint["healthStatus"] == "HEALTHY": self.endpoints.append( - '%s:%s' % (lb_endpoint['endpoint']['address'] - ['socketAddress']['address'], - lb_endpoint['endpoint']['address'] - ['socketAddress']['portValue'])) + "%s:%s" + % ( + lb_endpoint["endpoint"]["address"][ + "socketAddress" + ]["address"], + lb_endpoint["endpoint"]["address"][ + "socketAddress" + ]["portValue"], + ) + ) # TODO(lidiz) reduce the catch to LookupError except Exception as e: # pylint: disable=broad-except - logging.debug('Parse endpoint failed with %s: %s', - type(e), e) + logging.debug( + "Parse endpoint failed with %s: %s", type(e), e + ) def __str__(self) -> str: return json.dumps(self, indent=2) @@ -152,6 +169,7 @@ class RpcDistributionStats: Feel free to add more pre-compute fields. """ + num_failures: int num_oks: int default_service_rpc_count: int @@ -162,7 +180,7 @@ class RpcDistributionStats: empty_call_alternative_service_rpc_count: int def __init__(self, json_lb_stats: JsonType): - self.num_failures = json_lb_stats.get('numFailures', 0) + self.num_failures = json_lb_stats.get("numFailures", 0) self.num_peers = 0 self.num_oks = 0 @@ -174,25 +192,31 @@ def __init__(self, json_lb_stats: JsonType): self.empty_call_alternative_service_rpc_count = 0 self.raw = json_lb_stats - if 'rpcsByPeer' in json_lb_stats: - self.num_peers = len(json_lb_stats['rpcsByPeer']) - if 'rpcsByMethod' in json_lb_stats: - for rpc_type in json_lb_stats['rpcsByMethod']: - for peer in json_lb_stats['rpcsByMethod'][rpc_type][ - 'rpcsByPeer']: - count = json_lb_stats['rpcsByMethod'][rpc_type][ - 'rpcsByPeer'][peer] + if "rpcsByPeer" in json_lb_stats: + self.num_peers = len(json_lb_stats["rpcsByPeer"]) + if "rpcsByMethod" in json_lb_stats: + for rpc_type in json_lb_stats["rpcsByMethod"]: + for peer in json_lb_stats["rpcsByMethod"][rpc_type][ + "rpcsByPeer" + ]: + count = json_lb_stats["rpcsByMethod"][rpc_type][ + "rpcsByPeer" + ][peer] self.num_oks += count - if rpc_type == 'UnaryCall': - if 'alternative' in peer: - self.unary_call_alternative_service_rpc_count = count + if rpc_type == "UnaryCall": + if "alternative" in peer: + self.unary_call_alternative_service_rpc_count = ( + count + ) self.alternative_service_rpc_count += count else: self.unary_call_default_service_rpc_count = count self.default_service_rpc_count += count else: - if 'alternative' in peer: - self.empty_call_alternative_service_rpc_count = count + if "alternative" in peer: + self.empty_call_alternative_service_rpc_count = ( + count + ) self.alternative_service_rpc_count += count else: self.empty_call_default_service_rpc_count = count @@ -202,6 +226,7 @@ def __init__(self, json_lb_stats: JsonType): @dataclass class ExpectedResult: """Describes the expected result of assertRpcStatusCode method below.""" + rpc_type: str = RpcTypeUnaryCall status_code: grpc.StatusCode = grpc.StatusCode.OK ratio: float = 1 @@ -218,26 +243,28 @@ class _MetaXdsUrlMapTestCase(type): _started_test_cases = set() _finished_test_cases = 
set() - def __new__(cls, name: str, bases: Iterable[Any], - attrs: Mapping[str, Any]) -> Any: + def __new__( + cls, name: str, bases: Iterable[Any], attrs: Mapping[str, Any] + ) -> Any: # Hand over the tracking objects - attrs['test_case_classes'] = cls._test_case_classes - attrs['test_case_names'] = cls._test_case_names - attrs['started_test_cases'] = cls._started_test_cases - attrs['finished_test_cases'] = cls._finished_test_cases + attrs["test_case_classes"] = cls._test_case_classes + attrs["test_case_names"] = cls._test_case_names + attrs["started_test_cases"] = cls._started_test_cases + attrs["finished_test_cases"] = cls._finished_test_cases # Handle the test name reflection - module_name = os.path.split( - sys.modules[attrs['__module__']].__file__)[-1] + module_name = os.path.split(sys.modules[attrs["__module__"]].__file__)[ + -1 + ] if module_name.endswith(URL_MAP_TESTCASE_FILE_SUFFIX): - module_name = module_name.replace(URL_MAP_TESTCASE_FILE_SUFFIX, '') - attrs['short_module_name'] = module_name.replace('_', '-') + module_name = module_name.replace(URL_MAP_TESTCASE_FILE_SUFFIX, "") + attrs["short_module_name"] = module_name.replace("_", "-") # Create the class and track new_class = type.__new__(cls, name, bases, attrs) - if name.startswith('Test'): + if name.startswith("Test"): cls._test_case_names.add(name) cls._test_case_classes.append(new_class) else: - logging.debug('Skipping test case class: %s', name) + logging.debug("Skipping test case class: %s", name) return new_class @@ -288,8 +315,8 @@ def client_init_config(rpc: str, metadata: str) -> Tuple[str, str]: @staticmethod @abc.abstractmethod def url_map_change( - host_rule: HostRule, - path_matcher: PathMatcher) -> Tuple[HostRule, PathMatcher]: + host_rule: HostRule, path_matcher: PathMatcher + ) -> Tuple[HostRule, PathMatcher]: """Updates the dedicated urlMap components for this test case. Each test case will have a dedicated HostRule, where the hostname is @@ -327,8 +354,11 @@ def rpc_distribution_validate(self, test_client: XdsTestClient) -> None: @classmethod def hostname(cls): - return "%s.%s:%s" % (cls.short_module_name, _split_camel( - cls.__name__), GcpResourceManager().server_xds_port) + return "%s.%s:%s" % ( + cls.short_module_name, + _split_camel(cls.__name__), + GcpResourceManager().server_xds_port, + ) @classmethod def path_matcher_name(cls): @@ -337,8 +367,8 @@ def path_matcher_name(cls): @classmethod def setUpClass(cls): - logging.info('----- Testing %s -----', cls.__name__) - logging.info('Logs timezone: %s', time.localtime().tm_zone) + logging.info("----- Testing %s -----", cls.__name__) + logging.info("Logs timezone: %s", time.localtime().tm_zone) # Raises unittest.SkipTest if given client/server/version does not # support current test case. @@ -355,27 +385,31 @@ def setUpClass(cls): # Create the test case's own client runner with it's own namespace, # enables concurrent running with other test cases. - cls.test_client_runner = GcpResourceManager().create_test_client_runner( + cls.test_client_runner = ( + GcpResourceManager().create_test_client_runner() ) # Start the client, and allow the test to override the initial RPC config. 
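A minimal sketch (not from this patch) of how a concrete URL-map test case plugs into the hooks exercised by the lines that follow: `client_init_config` picks the RPC types and metadata the client starts with (the call below passes the defaults `rpc="UnaryCall,EmptyCall"`, empty metadata), `url_map_change` edits the per-test HostRule/PathMatcher, and the two validate methods use `assertNumEndpoints` and `configure_and_send`, which this patch also reformats further down. The class name, header name, and header value here are hypothetical.

    from typing import Tuple

    from framework import xds_url_map_testcase


    class TestHeaderBasedRouting(xds_url_map_testcase.XdsUrlMapTestCase):
        """Hypothetical test case sketching the per-test hooks."""

        @staticmethod
        def client_init_config(rpc: str, metadata: str) -> Tuple[str, str]:
            # Start the client with EmptyCall only, tagged with a header the
            # url map could match on (header name/value are illustrative).
            return "EmptyCall", "EmptyCall:xds_md:exact_match"

        @staticmethod
        def url_map_change(host_rule, path_matcher) -> Tuple[dict, dict]:
            # Each test case receives a dedicated HostRule/PathMatcher to
            # mutate; returning them unchanged keeps the default routing.
            return host_rule, path_matcher

        def xds_config_validate(self, xds_config) -> None:
            # Expect a single healthy endpoint in the dumped EDS config.
            self.assertNumEndpoints(xds_config, 1)

        def rpc_distribution_validate(self, test_client) -> None:
            stats = self.configure_and_send(
                test_client, rpc_types=["EMPTY_CALL"], num_rpcs=100
            )
            self.assertEqual(stats.num_failures, 0)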
- rpc, metadata = cls.client_init_config(rpc="UnaryCall,EmptyCall", - metadata="") + rpc, metadata = cls.client_init_config( + rpc="UnaryCall,EmptyCall", metadata="" + ) cls.test_client = cls.test_client_runner.run( - server_target=f'xds:///{cls.hostname()}', + server_target=f"xds:///{cls.hostname()}", rpc=rpc, metadata=metadata, qps=QPS.value, - print_response=True) + print_response=True, + ) @classmethod def cleanupAfterTests(cls): - logging.info('----- TestCase %s teardown -----', cls.__name__) + logging.info("----- TestCase %s teardown -----", cls.__name__) client_restarts: int = 0 if cls.test_client_runner: try: - logging.debug('Getting pods restart times') + logging.debug("Getting pods restart times") client_restarts = cls.test_client_runner.get_pod_restarts( - cls.test_client_runner.deployment) + cls.test_client_runner.deployment + ) except (retryers.RetryError, k8s.NotFound) as e: logging.exception(e) @@ -386,23 +420,26 @@ def cleanupAfterTests(cls): # Graceful cleanup: try three times, and don't fail the test on # a cleanup failure. - retryer = retryers.constant_retryer(wait_fixed=_timedelta(seconds=10), - attempts=3, - log_level=logging.INFO) + retryer = retryers.constant_retryer( + wait_fixed=_timedelta(seconds=10), + attempts=3, + log_level=logging.INFO, + ) try: retryer(cls._cleanup, cleanup_all) except retryers.RetryError: - logging.exception('Got error during teardown') + logging.exception("Got error during teardown") finally: - if hasattr(cls, 'test_client_runner') and cls.test_client_runner: - logging.info('----- Test client logs -----') + if hasattr(cls, "test_client_runner") and cls.test_client_runner: + logging.info("----- Test client logs -----") cls.test_client_runner.logs_explorer_run_history_links() # Fail if any of the pods restarted. error_msg = ( - 'Client pods unexpectedly restarted' - f' {client_restarts} times during test.' - ' In most cases, this is caused by the test client app crash.') + "Client pods unexpectedly restarted" + f" {client_restarts} times during test." + " In most cases, this is caused by the test client app crash." + ) assert client_restarts == 0, error_msg @classmethod @@ -415,13 +452,16 @@ def _cleanup(cls, cleanup_all: bool = False): def _fetch_and_check_xds_config(self): # TODO(lidiz) find another way to store last seen xDS config # Cleanup state for this attempt - self._xds_json_config = None # pylint: disable=attribute-defined-outside-init + # pylint: disable=attribute-defined-outside-init + self._xds_json_config = None # Fetch client config config = self.test_client.csds.fetch_client_status( - log_level=logging.INFO) + log_level=logging.INFO + ) self.assertIsNotNone(config) # Found client config, test it. - self._xds_json_config = json_format.MessageToDict(config) # pylint: disable=attribute-defined-outside-init + self._xds_json_config = json_format.MessageToDict(config) + # pylint: enable=attribute-defined-outside-init # Execute the child class provided validation logic self.xds_config_validate(DumpedXdsConfig(self._xds_json_config)) @@ -432,71 +472,94 @@ def run(self, result: unittest.TestResult = None) -> None: and yields clearer signal. 
""" if result.failures or result.errors: - logging.info('Aborting %s', self.__class__.__name__) + logging.info("Aborting %s", self.__class__.__name__) else: super().run(result) def test_client_config(self): retryer = retryers.constant_retryer( wait_fixed=datetime.timedelta( - seconds=_URL_MAP_PROPAGATE_CHECK_INTERVAL_SEC), + seconds=_URL_MAP_PROPAGATE_CHECK_INTERVAL_SEC + ), timeout=datetime.timedelta(seconds=_URL_MAP_PROPAGATE_TIMEOUT_SEC), logger=logging, - log_level=logging.INFO) + log_level=logging.INFO, + ) try: retryer(self._fetch_and_check_xds_config) finally: logging.info( - 'latest xDS config:\n%s', + "latest xDS config:\n%s", GcpResourceManager().td.compute.resource_pretty_format( - self._xds_json_config)) + self._xds_json_config + ), + ) def test_rpc_distribution(self): self.rpc_distribution_validate(self.test_client) @classmethod - def configure_and_send(cls, - test_client: XdsTestClient, - *, - rpc_types: Iterable[str], - metadata: Optional[Iterable[Tuple[str, str, - str]]] = None, - app_timeout: Optional[int] = None, - num_rpcs: int) -> RpcDistributionStats: - test_client.update_config.configure(rpc_types=rpc_types, - metadata=metadata, - app_timeout=app_timeout) + def configure_and_send( + cls, + test_client: XdsTestClient, + *, + rpc_types: Iterable[str], + metadata: Optional[Iterable[Tuple[str, str, str]]] = None, + app_timeout: Optional[int] = None, + num_rpcs: int, + ) -> RpcDistributionStats: + test_client.update_config.configure( + rpc_types=rpc_types, metadata=metadata, app_timeout=app_timeout + ) # Configure RPC might race with get stats RPC on slower machines. time.sleep(_CLIENT_CONFIGURE_WAIT_SEC) lb_stats = test_client.get_load_balancer_stats(num_rpcs=num_rpcs) - logging.info('[%s] << Received LoadBalancerStatsResponse:\n%s', - test_client.hostname, - helpers_grpc.lb_stats_pretty(lb_stats)) + logging.info( + "[%s] << Received LoadBalancerStatsResponse:\n%s", + test_client.hostname, + helpers_grpc.lb_stats_pretty(lb_stats), + ) return RpcDistributionStats(json_format.MessageToDict(lb_stats)) def assertNumEndpoints(self, xds_config: DumpedXdsConfig, k: int) -> None: self.assertLen( - xds_config.endpoints, k, - f'insufficient endpoints in EDS: want={k} seen={xds_config.endpoints}' + xds_config.endpoints, + k, + ( + "insufficient endpoints in EDS:" + f" want={k} seen={xds_config.endpoints}" + ), ) def assertRpcStatusCode( # pylint: disable=too-many-locals - self, test_client: XdsTestClient, *, - expected: Iterable[ExpectedResult], length: int, - tolerance: float) -> None: + self, + test_client: XdsTestClient, + *, + expected: Iterable[ExpectedResult], + length: int, + tolerance: float, + ) -> None: """Assert the distribution of RPC statuses over a period of time.""" # Sending with pre-set QPS for a period of time before_stats = test_client.get_load_balancer_accumulated_stats() logging.info( - 'Received LoadBalancerAccumulatedStatsResponse from test client %s: before:\n%s', + ( + "Received LoadBalancerAccumulatedStatsResponse from test client" + " %s: before:\n%s" + ), test_client.hostname, - helpers_grpc.accumulated_stats_pretty(before_stats)) + helpers_grpc.accumulated_stats_pretty(before_stats), + ) time.sleep(length) after_stats = test_client.get_load_balancer_accumulated_stats() logging.info( - 'Received LoadBalancerAccumulatedStatsResponse from test client %s: after: \n%s', + ( + "Received LoadBalancerAccumulatedStatsResponse from test client" + " %s: after: \n%s" + ), test_client.hostname, - helpers_grpc.accumulated_stats_pretty(after_stats)) + 
helpers_grpc.accumulated_stats_pretty(after_stats), + ) # Validate the diff for expected_result in expected: @@ -511,21 +574,29 @@ def assertRpcStatusCode( # pylint: disable=too-many-locals seen = seen_after - seen_before # Compute total number of RPC started stats_per_method_after = after_stats.stats_per_method.get( - rpc, {}).result.items() + rpc, {} + ).result.items() total_after = sum( - x[1] for x in stats_per_method_after) # (status_code, count) + x[1] for x in stats_per_method_after + ) # (status_code, count) stats_per_method_before = before_stats.stats_per_method.get( - rpc, {}).result.items() + rpc, {} + ).result.items() total_before = sum( - x[1] for x in stats_per_method_before) # (status_code, count) + x[1] for x in stats_per_method_before + ) # (status_code, count) total = total_after - total_before # Compute and validate the number want = total * expected_result.ratio diff_ratio = abs(seen - want) / total self.assertLessEqual( - diff_ratio, tolerance, - (f'Expect rpc [{rpc}] to return ' - f'[{expected_result.status_code}] at ' - f'{expected_result.ratio:.2f} ratio: ' - f'seen={seen} want={want} total={total} ' - f'diff_ratio={diff_ratio:.4f} > {tolerance:.2f}')) + diff_ratio, + tolerance, + ( + f"Expect rpc [{rpc}] to return " + f"[{expected_result.status_code}] at " + f"{expected_result.ratio:.2f} ratio: " + f"seen={seen} want={want} total={total} " + f"diff_ratio={diff_ratio:.4f} > {tolerance:.2f}" + ), + ) diff --git a/tools/run_tests/xds_k8s_test_driver/requirements-dev.txt b/tools/run_tests/xds_k8s_test_driver/requirements-dev.txt index bf47c5a06a747..e8d336c38a3ab 100644 --- a/tools/run_tests/xds_k8s_test_driver/requirements-dev.txt +++ b/tools/run_tests/xds_k8s_test_driver/requirements-dev.txt @@ -1,4 +1,6 @@ -r requirements.lock -yapf==0.30.0 # Mirrors yapf version set in https://github.com/grpc/grpc/blob/master/tools/distrib/yapf_code.sh +# Mirrors black version set in +# https://github.com/grpc/grpc/blob/master/tools/distrib/black_code.sh +black==23.3.0 isort~=5.9 # TODO(https://github.com/grpc/grpc/pull/25872): mypy diff --git a/tools/run_tests/xds_k8s_test_driver/tests/affinity_test.py b/tools/run_tests/xds_k8s_test_driver/tests/affinity_test.py index 653bcddccac77..c09845d983e65 100644 --- a/tools/run_tests/xds_k8s_test_driver/tests/affinity_test.py +++ b/tools/run_tests/xds_k8s_test_driver/tests/affinity_test.py @@ -35,7 +35,7 @@ _Lang = skips.Lang # Testing consts -_TEST_AFFINITY_METADATA_KEY = 'xds_md' +_TEST_AFFINITY_METADATA_KEY = "xds_md" _TD_PROPAGATE_CHECK_INTERVAL_SEC = 10 _TD_PROPAGATE_TIMEOUT = 600 _REPLICA_COUNT = 3 @@ -43,7 +43,6 @@ class AffinityTest(xds_k8s_testcase.RegularXdsKubernetesTestCase): - @classmethod def setUpClass(cls): super().setUpClass() @@ -56,9 +55,9 @@ def setUpClass(cls): @staticmethod def is_supported(config: skips.TestConfig) -> bool: if config.client_lang in _Lang.CPP | _Lang.JAVA: - return config.version_gte('v1.40.x') + return config.version_gte("v1.40.x") elif config.client_lang == _Lang.GO: - return config.version_gte('v1.41.x') + return config.version_gte("v1.41.x") elif config.client_lang == _Lang.PYTHON: # TODO(https://github.com/grpc/grpc/issues/27430): supported after # the issue is fixed. 
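A tiny worked example (not from this patch) of the tolerance arithmetic in the `assertRpcStatusCode` hunk above; the counts are made up:

    # Suppose the sampling window saw 1000 UNARY_CALL RPCs started
    # (total = total_after - total_before) and 240 of them finished with
    # PERMISSION_DENIED (seen = seen_after - seen_before), with an expected
    # ratio of 0.25 and a tolerance of 0.05.
    total = 1000
    seen = 240
    expected_ratio = 0.25
    tolerance = 0.05

    want = total * expected_ratio          # 250.0
    diff_ratio = abs(seen - want) / total  # 10 / 1000 = 0.01
    assert diff_ratio <= tolerance         # 0.01 <= 0.05: the check passes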
@@ -68,126 +67,146 @@ def is_supported(config: skips.TestConfig) -> bool: return True def test_affinity(self) -> None: # pylint: disable=too-many-statements - - with self.subTest('00_create_health_check'): + with self.subTest("00_create_health_check"): self.td.create_health_check() - with self.subTest('01_create_backend_services'): + with self.subTest("01_create_backend_services"): self.td.create_backend_service( - affinity_header=_TEST_AFFINITY_METADATA_KEY) + affinity_header=_TEST_AFFINITY_METADATA_KEY + ) - with self.subTest('02_create_url_map'): + with self.subTest("02_create_url_map"): self.td.create_url_map(self.server_xds_host, self.server_xds_port) - with self.subTest('03_create_target_proxy'): + with self.subTest("03_create_target_proxy"): self.td.create_target_proxy() - with self.subTest('04_create_forwarding_rule'): + with self.subTest("04_create_forwarding_rule"): self.td.create_forwarding_rule(self.server_xds_port) test_servers: List[_XdsTestServer] - with self.subTest('05_start_test_servers'): + with self.subTest("05_start_test_servers"): test_servers = self.startTestServers(replica_count=_REPLICA_COUNT) - with self.subTest('06_add_server_backends_to_backend_services'): + with self.subTest("06_add_server_backends_to_backend_services"): self.setupServerBackends() test_client: _XdsTestClient - with self.subTest('07_start_test_client'): - test_client = self.startTestClient(test_servers[0], - rpc='EmptyCall', - metadata='EmptyCall:%s:123' % - _TEST_AFFINITY_METADATA_KEY) + with self.subTest("07_start_test_client"): + test_client = self.startTestClient( + test_servers[0], + rpc="EmptyCall", + metadata="EmptyCall:%s:123" % _TEST_AFFINITY_METADATA_KEY, + ) # Validate the number of received endpoints and affinity configs. config = test_client.csds.fetch_client_status( - log_level=logging.INFO) + log_level=logging.INFO + ) self.assertIsNotNone(config) json_config = json_format.MessageToDict(config) parsed = xds_url_map_testcase.DumpedXdsConfig(json_config) - logging.info('Client received CSDS response: %s', parsed) + logging.info("Client received CSDS response: %s", parsed) self.assertLen(parsed.endpoints, _REPLICA_COUNT) self.assertEqual( - parsed.rds['virtualHosts'][0]['routes'][0]['route'] - ['hashPolicy'][0]['header']['headerName'], - _TEST_AFFINITY_METADATA_KEY) - self.assertEqual(parsed.cds[0]['lbPolicy'], 'RING_HASH') + parsed.rds["virtualHosts"][0]["routes"][0]["route"][ + "hashPolicy" + ][0]["header"]["headerName"], + _TEST_AFFINITY_METADATA_KEY, + ) + self.assertEqual(parsed.cds[0]["lbPolicy"], "RING_HASH") - with self.subTest('08_test_client_xds_config_exists'): + with self.subTest("08_test_client_xds_config_exists"): self.assertXdsConfigExists(test_client) - with self.subTest('09_test_server_received_rpcs_from_test_client'): + with self.subTest("09_test_server_received_rpcs_from_test_client"): self.assertSuccessfulRpcs(test_client) - with self.subTest('10_first_100_affinity_rpcs_pick_same_backend'): + with self.subTest("10_first_100_affinity_rpcs_pick_same_backend"): rpc_stats = self.getClientRpcStats(test_client, _RPC_COUNT) json_lb_stats = json_format.MessageToDict(rpc_stats) rpc_distribution = xds_url_map_testcase.RpcDistributionStats( - json_lb_stats) + json_lb_stats + ) self.assertEqual(1, rpc_distribution.num_peers) # Check subchannel states. # One should be READY. 
ready_channels = test_client.find_subchannels_with_state( - _ChannelzChannelState.READY) + _ChannelzChannelState.READY + ) self.assertLen( ready_channels, 1, - msg=('(AffinityTest) The client expected to have one READY' - ' subchannel to one of the test servers. Found' - f' {len(ready_channels)} instead.'), + msg=( + "(AffinityTest) The client expected to have one READY" + " subchannel to one of the test servers. Found" + f" {len(ready_channels)} instead." + ), ) # The rest should be IDLE. expected_idle_channels = _REPLICA_COUNT - 1 idle_channels = test_client.find_subchannels_with_state( - _ChannelzChannelState.IDLE) + _ChannelzChannelState.IDLE + ) self.assertLen( idle_channels, expected_idle_channels, - msg=('(AffinityTest) The client expected to have IDLE' - f' subchannels to {expected_idle_channels} of the test' - f' servers. Found {len(idle_channels)} instead.'), + msg=( + "(AffinityTest) The client expected to have IDLE" + f" subchannels to {expected_idle_channels} of the test" + f" servers. Found {len(idle_channels)} instead." + ), ) # Remember the backend inuse, and turn it down later. first_backend_inuse = list( - rpc_distribution.raw['rpcsByPeer'].keys())[0] + rpc_distribution.raw["rpcsByPeer"].keys() + )[0] - with self.subTest('11_turn_down_server_in_use'): + with self.subTest("11_turn_down_server_in_use"): for server in test_servers: if server.hostname == first_backend_inuse: server.set_not_serving() - with self.subTest('12_wait_for_unhealth_status_propagation'): + with self.subTest("12_wait_for_unhealth_status_propagation"): deadline = time.time() + _TD_PROPAGATE_TIMEOUT parsed = None try: while time.time() < deadline: config = test_client.csds.fetch_client_status( - log_level=logging.INFO) + log_level=logging.INFO + ) self.assertIsNotNone(config) json_config = json_format.MessageToDict(config) parsed = xds_url_map_testcase.DumpedXdsConfig(json_config) if len(parsed.endpoints) == _REPLICA_COUNT - 1: break logging.info( - 'CSDS got unexpected endpoints, will retry after %d seconds', - _TD_PROPAGATE_CHECK_INTERVAL_SEC) + ( + "CSDS got unexpected endpoints, will retry after %d" + " seconds" + ), + _TD_PROPAGATE_CHECK_INTERVAL_SEC, + ) time.sleep(_TD_PROPAGATE_CHECK_INTERVAL_SEC) else: self.fail( - 'unhealthy status did not propagate after 600 seconds') + "unhealthy status did not propagate after 600 seconds" + ) finally: - logging.info('Client received CSDS response: %s', parsed) + logging.info("Client received CSDS response: %s", parsed) - with self.subTest('12_next_100_affinity_rpcs_pick_different_backend'): + with self.subTest("12_next_100_affinity_rpcs_pick_different_backend"): rpc_stats = self.getClientRpcStats(test_client, _RPC_COUNT) json_lb_stats = json_format.MessageToDict(rpc_stats) rpc_distribution = xds_url_map_testcase.RpcDistributionStats( - json_lb_stats) + json_lb_stats + ) self.assertEqual(1, rpc_distribution.num_peers) - new_backend_inuse = list( - rpc_distribution.raw['rpcsByPeer'].keys())[0] + new_backend_inuse = list(rpc_distribution.raw["rpcsByPeer"].keys())[ + 0 + ] self.assertNotEqual(new_backend_inuse, first_backend_inuse) -if __name__ == '__main__': +if __name__ == "__main__": absltest.main(failfast=True) diff --git a/tools/run_tests/xds_k8s_test_driver/tests/api_listener_test.py b/tools/run_tests/xds_k8s_test_driver/tests/api_listener_test.py index 4e76b03fc2fed..039bf10d1e2a7 100644 --- a/tools/run_tests/xds_k8s_test_driver/tests/api_listener_test.py +++ b/tools/run_tests/xds_k8s_test_driver/tests/api_listener_test.py @@ -34,90 +34,101 @@ class 
ApiListenerTest(xds_k8s_testcase.RegularXdsKubernetesTestCase): - @staticmethod def is_supported(config: skips.TestConfig) -> bool: if config.client_lang == _Lang.PYTHON: # gRPC Python versions prior to v1.43.x don't support handling empty # RDS update. - return config.version_gte('v1.43.x') + return config.version_gte("v1.43.x") return True def test_api_listener(self) -> None: - with self.subTest('00_create_health_check'): + with self.subTest("00_create_health_check"): self.td.create_health_check() - with self.subTest('01_create_backend_services'): + with self.subTest("01_create_backend_services"): self.td.create_backend_service() - with self.subTest('02_create_default_url_map'): + with self.subTest("02_create_default_url_map"): self.td.create_url_map(self.server_xds_host, self.server_xds_port) - with self.subTest('03_create_default_target_proxy'): + with self.subTest("03_create_default_target_proxy"): self.td.create_target_proxy() - with self.subTest('04_create_default_forwarding_rule'): + with self.subTest("04_create_default_forwarding_rule"): self.td.create_forwarding_rule(self.server_xds_port) test_server: _XdsTestServer - with self.subTest('05_start_test_server'): + with self.subTest("05_start_test_server"): test_server = self.startTestServers()[0] - with self.subTest('06_add_server_backends_to_backend_services'): + with self.subTest("06_add_server_backends_to_backend_services"): self.setupServerBackends() test_client: _XdsTestClient - with self.subTest('07_start_test_client'): + with self.subTest("07_start_test_client"): test_client = self.startTestClient(test_server) - with self.subTest('08_test_client_xds_config_exists'): + with self.subTest("08_test_client_xds_config_exists"): self.assertXdsConfigExists(test_client) - with self.subTest('09_test_server_received_rpcs'): + with self.subTest("09_test_server_received_rpcs"): self.assertSuccessfulRpcs(test_client) - with self.subTest('10_create_alternate_url_map'): - self.td.create_alternative_url_map(self.server_xds_host, - self.server_xds_port, - self.td.backend_service) + with self.subTest("10_create_alternate_url_map"): + self.td.create_alternative_url_map( + self.server_xds_host, + self.server_xds_port, + self.td.backend_service, + ) # Create alternate target proxy pointing to alternate url_map with the same # host name in host rule. The port is fixed because they point to the same backend service. # Therefore we have to choose a non-`0.0.0.0` ip because ip:port needs to be unique. # We also have to set validate_for_proxyless=false because requires `0.0.0.0` ip. # See https://github.com/grpc/grpc-java/issues/8009 - with self.subTest('11_create_alternate_target_proxy'): + with self.subTest("11_create_alternate_target_proxy"): self.td.create_alternative_target_proxy() # Create a second suite of map+tp+fr with the same host name in host rule. # We set fr ip_address to be different from `0.0.0.0` and then set # validate_for_proxyless=false because ip:port needs to be unique. 
- with self.subTest('12_create_alternate_forwarding_rule'): - self.td.create_alternative_forwarding_rule(self.server_xds_port, - ip_address='10.10.10.10') + with self.subTest("12_create_alternate_forwarding_rule"): + self.td.create_alternative_forwarding_rule( + self.server_xds_port, ip_address="10.10.10.10" + ) - with self.subTest('13_test_server_received_rpcs_with_two_url_maps'): + with self.subTest("13_test_server_received_rpcs_with_two_url_maps"): self.assertSuccessfulRpcs(test_client) raw_config = test_client.csds.fetch_client_status( - log_level=logging.INFO) + log_level=logging.INFO + ) dumped_config = _DumpedXdsConfig( - json_format.MessageToDict(raw_config)) + json_format.MessageToDict(raw_config) + ) previous_route_config_version = dumped_config.rds_version - logger.info(('received client config from CSDS with two url maps, ' - 'dump config: %s, rds version: %s'), dumped_config, - previous_route_config_version) - - with self.subTest('14_delete_one_url_map_target_proxy_forwarding_rule'): + logger.info( + ( + "received client config from CSDS with two url maps, " + "dump config: %s, rds version: %s" + ), + dumped_config, + previous_route_config_version, + ) + + with self.subTest("14_delete_one_url_map_target_proxy_forwarding_rule"): self.td.delete_forwarding_rule() self.td.delete_target_grpc_proxy() self.td.delete_url_map() - with self.subTest('15_test_server_continues_to_receive_rpcs'): + with self.subTest("15_test_server_continues_to_receive_rpcs"): self.assertRouteConfigUpdateTrafficHandoff( - test_client, previous_route_config_version, + test_client, + previous_route_config_version, _TD_CONFIG_RETRY_WAIT_SEC, - xds_k8s_testcase._TD_CONFIG_MAX_WAIT_SEC) + xds_k8s_testcase._TD_CONFIG_MAX_WAIT_SEC, + ) -if __name__ == '__main__': +if __name__ == "__main__": absltest.main(failfast=True) diff --git a/tools/run_tests/xds_k8s_test_driver/tests/app_net_test.py b/tools/run_tests/xds_k8s_test_driver/tests/app_net_test.py index 1b1e6ead2af1f..c399cbcb8d815 100644 --- a/tools/run_tests/xds_k8s_test_driver/tests/app_net_test.py +++ b/tools/run_tests/xds_k8s_test_driver/tests/app_net_test.py @@ -26,39 +26,40 @@ class AppNetTest(xds_k8s_testcase.AppNetXdsKubernetesTestCase): - def test_ping_pong(self): - with self.subTest('0_create_health_check'): + with self.subTest("0_create_health_check"): self.td.create_health_check() - with self.subTest('1_create_backend_service'): + with self.subTest("1_create_backend_service"): self.td.create_backend_service() - with self.subTest('2_create_mesh'): + with self.subTest("2_create_mesh"): self.td.create_mesh() - with self.subTest('3_create_grpc_route'): - self.td.create_grpc_route(self.server_xds_host, - self.server_xds_port) + with self.subTest("3_create_grpc_route"): + self.td.create_grpc_route( + self.server_xds_host, self.server_xds_port + ) test_server: _XdsTestServer - with self.subTest('4_start_test_server'): + with self.subTest("4_start_test_server"): test_server = self.startTestServers(replica_count=1)[0] - with self.subTest('5_setup_server_backends'): + with self.subTest("5_setup_server_backends"): self.setupServerBackends() test_client: _XdsTestClient - with self.subTest('6_start_test_client'): - test_client = self.startTestClient(test_server, - config_mesh=self.td.mesh.name) + with self.subTest("6_start_test_client"): + test_client = self.startTestClient( + test_server, config_mesh=self.td.mesh.name + ) - with self.subTest('7_assert_xds_config_exists'): + with self.subTest("7_assert_xds_config_exists"): self.assertXdsConfigExists(test_client) - 
with self.subTest('8_assert_successful_rpcs'): + with self.subTest("8_assert_successful_rpcs"): self.assertSuccessfulRpcs(test_client) -if __name__ == '__main__': +if __name__ == "__main__": absltest.main(failfast=True) diff --git a/tools/run_tests/xds_k8s_test_driver/tests/authz_test.py b/tools/run_tests/xds_k8s_test_driver/tests/authz_test.py index f39e1acda55b1..e3c3bff896bef 100644 --- a/tools/run_tests/xds_k8s_test_driver/tests/authz_test.py +++ b/tools/run_tests/xds_k8s_test_driver/tests/authz_test.py @@ -43,8 +43,8 @@ class AuthzTest(xds_k8s_testcase.SecurityXdsKubernetesTestCase): RPC_TYPE_CYCLE = { - 'UNARY_CALL': 'EMPTY_CALL', - 'EMPTY_CALL': 'UNARY_CALL', + "UNARY_CALL": "EMPTY_CALL", + "EMPTY_CALL": "UNARY_CALL", } @staticmethod @@ -52,9 +52,9 @@ def is_supported(config: skips.TestConfig) -> bool: # Per "Authorization (RBAC)" in # https://github.com/grpc/grpc/blob/master/doc/grpc_xds_features.md if config.client_lang in _Lang.CPP | _Lang.PYTHON: - return config.version_gte('v1.47.x') + return config.version_gte("v1.47.x") elif config.client_lang in _Lang.GO | _Lang.JAVA: - return config.version_gte('v1.42.x') + return config.version_gte("v1.42.x") elif config.client_lang == _Lang.NODE: return False return True @@ -86,25 +86,30 @@ def authz_rules(self): }, }, { - "destinations": [{ - "hosts": [f"{self.server_xds_host}:{self.server_xds_port}"], - "ports": [self.server_port], - "httpHeaderMatch": { - "headerName": "test", - "regexMatch": "host-match1", + "destinations": [ + { + "hosts": [ + f"{self.server_xds_host}:{self.server_xds_port}" + ], + "ports": [self.server_port], + "httpHeaderMatch": { + "headerName": "test", + "regexMatch": "host-match1", + }, }, - }, { - "hosts": [ - f"a-not-it.com:{self.server_xds_port}", - f"{self.server_xds_host}:{self.server_xds_port}", - "z-not-it.com:1", - ], - "ports": [1, self.server_port, 65535], - "httpHeaderMatch": { - "headerName": "test", - "regexMatch": "host-match2", + { + "hosts": [ + f"a-not-it.com:{self.server_xds_port}", + f"{self.server_xds_host}:{self.server_xds_port}", + "z-not-it.com:1", + ], + "ports": [1, self.server_port, 65535], + "httpHeaderMatch": { + "headerName": "test", + "regexMatch": "host-match2", + }, }, - }], + ], }, { "destinations": { @@ -144,17 +149,22 @@ def authz_rules(self): # }, # }, { - "sources": [{ - "principals": [ - f"spiffe://{self.project}.svc.id.goog/not/the/client", - ], - }, { - "principals": [ - f"spiffe://{self.project}.svc.id.goog/not/the/client", - f"spiffe://{self.project}.svc.id.goog/ns/" - f"{self.client_namespace}/sa/{self.client_name}", - ], - }], + "sources": [ + { + "principals": [ + f"spiffe://{self.project}.svc.id.goog/not/the/client", + ], + }, + { + "principals": [ + f"spiffe://{self.project}.svc.id.goog/not/the/client", + ( + f"spiffe://{self.project}.svc.id.goog/ns/" + f"{self.client_namespace}/sa/{self.client_name}" + ), + ], + }, + ], "destinations": { "hosts": [f"*:{self.server_xds_port}"], "ports": [self.server_port], @@ -181,9 +191,12 @@ def authz_rules(self): }, ] - def configure_and_assert(self, test_client: _XdsTestClient, - test_metadata_val: Optional[str], - status_code: grpc.StatusCode) -> None: + def configure_and_assert( + self, + test_client: _XdsTestClient, + test_metadata_val: Optional[str], + status_code: grpc.StatusCode, + ) -> None: # Swap method type every sub-test to avoid mixing results rpc_type = self.next_rpc_type if rpc_type is None: @@ -197,64 +210,89 @@ def configure_and_assert(self, test_client: _XdsTestClient, metadata = None if test_metadata_val is not 
None: metadata = ((rpc_type, "test", test_metadata_val),) - test_client.update_config.configure(rpc_types=[rpc_type], - metadata=metadata) + test_client.update_config.configure( + rpc_types=[rpc_type], metadata=metadata + ) # b/228743575 Python has as race. Give us time to fix it. stray_rpc_limit = 1 if self.lang_spec.client_lang == _Lang.PYTHON else 0 - self.assertRpcStatusCodes(test_client, - expected_status=status_code, - duration=_SAMPLE_DURATION, - method=rpc_type, - stray_rpc_limit=stray_rpc_limit) + self.assertRpcStatusCodes( + test_client, + expected_status=status_code, + duration=_SAMPLE_DURATION, + method=rpc_type, + stray_rpc_limit=stray_rpc_limit, + ) def test_plaintext_allow(self) -> None: self.setupTrafficDirectorGrpc() - self.td.create_authz_policy(action='ALLOW', rules=self.authz_rules()) - self.setupSecurityPolicies(server_tls=False, - server_mtls=False, - client_tls=False, - client_mtls=False) + self.td.create_authz_policy(action="ALLOW", rules=self.authz_rules()) + self.setupSecurityPolicies( + server_tls=False, + server_mtls=False, + client_tls=False, + client_mtls=False, + ) test_server: _XdsTestServer = self.startSecureTestServer() self.setupServerBackends() test_client: _XdsTestClient = self.startSecureTestClient(test_server) time.sleep(_SETTLE_DURATION.total_seconds()) - with self.subTest('01_host_wildcard'): - self.configure_and_assert(test_client, 'host-wildcard', - grpc.StatusCode.OK) - - with self.subTest('02_no_match'): - self.configure_and_assert(test_client, 'no-such-rule', - grpc.StatusCode.PERMISSION_DENIED) - self.configure_and_assert(test_client, None, - grpc.StatusCode.PERMISSION_DENIED) - - with self.subTest('03_header_regex'): - self.configure_and_assert(test_client, 'header-regex-a', - grpc.StatusCode.OK) - self.configure_and_assert(test_client, 'header-regex-aa', - grpc.StatusCode.OK) - self.configure_and_assert(test_client, 'header-regex-', - grpc.StatusCode.PERMISSION_DENIED) - self.configure_and_assert(test_client, 'header-regex-ab', - grpc.StatusCode.PERMISSION_DENIED) - self.configure_and_assert(test_client, 'aheader-regex-a', - grpc.StatusCode.PERMISSION_DENIED) - - with self.subTest('04_host_match'): - self.configure_and_assert(test_client, 'host-match1', - grpc.StatusCode.OK) - self.configure_and_assert(test_client, 'host-match2', - grpc.StatusCode.OK) - - with self.subTest('05_never_match_host'): - self.configure_and_assert(test_client, 'never-match-host', - grpc.StatusCode.PERMISSION_DENIED) - - with self.subTest('06_never_match_port'): - self.configure_and_assert(test_client, 'never-match-port', - grpc.StatusCode.PERMISSION_DENIED) + with self.subTest("01_host_wildcard"): + self.configure_and_assert( + test_client, "host-wildcard", grpc.StatusCode.OK + ) + + with self.subTest("02_no_match"): + self.configure_and_assert( + test_client, "no-such-rule", grpc.StatusCode.PERMISSION_DENIED + ) + self.configure_and_assert( + test_client, None, grpc.StatusCode.PERMISSION_DENIED + ) + + with self.subTest("03_header_regex"): + self.configure_and_assert( + test_client, "header-regex-a", grpc.StatusCode.OK + ) + self.configure_and_assert( + test_client, "header-regex-aa", grpc.StatusCode.OK + ) + self.configure_and_assert( + test_client, "header-regex-", grpc.StatusCode.PERMISSION_DENIED + ) + self.configure_and_assert( + test_client, + "header-regex-ab", + grpc.StatusCode.PERMISSION_DENIED, + ) + self.configure_and_assert( + test_client, + "aheader-regex-a", + grpc.StatusCode.PERMISSION_DENIED, + ) + + with self.subTest("04_host_match"): + 
self.configure_and_assert( + test_client, "host-match1", grpc.StatusCode.OK + ) + self.configure_and_assert( + test_client, "host-match2", grpc.StatusCode.OK + ) + + with self.subTest("05_never_match_host"): + self.configure_and_assert( + test_client, + "never-match-host", + grpc.StatusCode.PERMISSION_DENIED, + ) + + with self.subTest("06_never_match_port"): + self.configure_and_assert( + test_client, + "never-match-port", + grpc.StatusCode.PERMISSION_DENIED, + ) # b/202058316 # with self.subTest('07_principal_present'): @@ -263,24 +301,28 @@ def test_plaintext_allow(self) -> None: def test_tls_allow(self) -> None: self.setupTrafficDirectorGrpc() - self.td.create_authz_policy(action='ALLOW', rules=self.authz_rules()) - self.setupSecurityPolicies(server_tls=True, - server_mtls=False, - client_tls=True, - client_mtls=False) + self.td.create_authz_policy(action="ALLOW", rules=self.authz_rules()) + self.setupSecurityPolicies( + server_tls=True, + server_mtls=False, + client_tls=True, + client_mtls=False, + ) test_server: _XdsTestServer = self.startSecureTestServer() self.setupServerBackends() test_client: _XdsTestClient = self.startSecureTestClient(test_server) time.sleep(_SETTLE_DURATION.total_seconds()) - with self.subTest('01_host_wildcard'): - self.configure_and_assert(test_client, 'host-wildcard', - grpc.StatusCode.OK) + with self.subTest("01_host_wildcard"): + self.configure_and_assert( + test_client, "host-wildcard", grpc.StatusCode.OK + ) - with self.subTest('02_no_match'): - self.configure_and_assert(test_client, None, - grpc.StatusCode.PERMISSION_DENIED) + with self.subTest("02_no_match"): + self.configure_and_assert( + test_client, None, grpc.StatusCode.PERMISSION_DENIED + ) # b/202058316 # with self.subTest('03_principal_present'): @@ -289,58 +331,66 @@ def test_tls_allow(self) -> None: def test_mtls_allow(self) -> None: self.setupTrafficDirectorGrpc() - self.td.create_authz_policy(action='ALLOW', rules=self.authz_rules()) - self.setupSecurityPolicies(server_tls=True, - server_mtls=True, - client_tls=True, - client_mtls=True) + self.td.create_authz_policy(action="ALLOW", rules=self.authz_rules()) + self.setupSecurityPolicies( + server_tls=True, server_mtls=True, client_tls=True, client_mtls=True + ) test_server: _XdsTestServer = self.startSecureTestServer() self.setupServerBackends() test_client: _XdsTestClient = self.startSecureTestClient(test_server) time.sleep(_SETTLE_DURATION.total_seconds()) - with self.subTest('01_host_wildcard'): - self.configure_and_assert(test_client, 'host-wildcard', - grpc.StatusCode.OK) + with self.subTest("01_host_wildcard"): + self.configure_and_assert( + test_client, "host-wildcard", grpc.StatusCode.OK + ) - with self.subTest('02_no_match'): - self.configure_and_assert(test_client, None, - grpc.StatusCode.PERMISSION_DENIED) + with self.subTest("02_no_match"): + self.configure_and_assert( + test_client, None, grpc.StatusCode.PERMISSION_DENIED + ) # b/202058316 # with self.subTest('03_principal_present'): # self.configure_and_assert(test_client, 'principal-present', # grpc.StatusCode.OK) - with self.subTest('04_match_principal'): - self.configure_and_assert(test_client, 'match-principal', - grpc.StatusCode.OK) + with self.subTest("04_match_principal"): + self.configure_and_assert( + test_client, "match-principal", grpc.StatusCode.OK + ) - with self.subTest('05_never_match_principal'): - self.configure_and_assert(test_client, 'never-match-principal', - grpc.StatusCode.PERMISSION_DENIED) + with self.subTest("05_never_match_principal"): + 
self.configure_and_assert( + test_client, + "never-match-principal", + grpc.StatusCode.PERMISSION_DENIED, + ) def test_plaintext_deny(self) -> None: self.setupTrafficDirectorGrpc() - self.td.create_authz_policy(action='DENY', rules=self.authz_rules()) - self.setupSecurityPolicies(server_tls=False, - server_mtls=False, - client_tls=False, - client_mtls=False) + self.td.create_authz_policy(action="DENY", rules=self.authz_rules()) + self.setupSecurityPolicies( + server_tls=False, + server_mtls=False, + client_tls=False, + client_mtls=False, + ) test_server: _XdsTestServer = self.startSecureTestServer() self.setupServerBackends() test_client: _XdsTestClient = self.startSecureTestClient(test_server) time.sleep(_SETTLE_DURATION.total_seconds()) - with self.subTest('01_host_wildcard'): - self.configure_and_assert(test_client, 'host-wildcard', - grpc.StatusCode.PERMISSION_DENIED) + with self.subTest("01_host_wildcard"): + self.configure_and_assert( + test_client, "host-wildcard", grpc.StatusCode.PERMISSION_DENIED + ) - with self.subTest('02_no_match'): + with self.subTest("02_no_match"): self.configure_and_assert(test_client, None, grpc.StatusCode.OK) -if __name__ == '__main__': +if __name__ == "__main__": absltest.main() diff --git a/tools/run_tests/xds_k8s_test_driver/tests/baseline_test.py b/tools/run_tests/xds_k8s_test_driver/tests/baseline_test.py index 9e533e732bad3..bf1bcf31acc44 100644 --- a/tools/run_tests/xds_k8s_test_driver/tests/baseline_test.py +++ b/tools/run_tests/xds_k8s_test_driver/tests/baseline_test.py @@ -27,38 +27,37 @@ class BaselineTest(xds_k8s_testcase.RegularXdsKubernetesTestCase): - def test_traffic_director_grpc_setup(self): - with self.subTest('0_create_health_check'): + with self.subTest("0_create_health_check"): self.td.create_health_check() - with self.subTest('1_create_backend_service'): + with self.subTest("1_create_backend_service"): self.td.create_backend_service() - with self.subTest('2_create_url_map'): + with self.subTest("2_create_url_map"): self.td.create_url_map(self.server_xds_host, self.server_xds_port) - with self.subTest('3_create_target_proxy'): + with self.subTest("3_create_target_proxy"): self.td.create_target_proxy() - with self.subTest('4_create_forwarding_rule'): + with self.subTest("4_create_forwarding_rule"): self.td.create_forwarding_rule(self.server_xds_port) - with self.subTest('5_start_test_server'): + with self.subTest("5_start_test_server"): test_server: _XdsTestServer = self.startTestServers()[0] - with self.subTest('6_add_server_backends_to_backend_service'): + with self.subTest("6_add_server_backends_to_backend_service"): self.setupServerBackends() - with self.subTest('7_start_test_client'): + with self.subTest("7_start_test_client"): test_client: _XdsTestClient = self.startTestClient(test_server) - with self.subTest('8_test_client_xds_config_exists'): + with self.subTest("8_test_client_xds_config_exists"): self.assertXdsConfigExists(test_client) - with self.subTest('9_test_server_received_rpcs_from_test_client'): + with self.subTest("9_test_server_received_rpcs_from_test_client"): self.assertSuccessfulRpcs(test_client) -if __name__ == '__main__': +if __name__ == "__main__": absltest.main(failfast=True) diff --git a/tools/run_tests/xds_k8s_test_driver/tests/bootstrap_generator_test.py b/tools/run_tests/xds_k8s_test_driver/tests/bootstrap_generator_test.py index 3fe80085b79b8..0c290a829002a 100644 --- a/tools/run_tests/xds_k8s_test_driver/tests/bootstrap_generator_test.py +++ 
b/tools/run_tests/xds_k8s_test_driver/tests/bootstrap_generator_test.py @@ -47,24 +47,20 @@ def bootstrap_version_testcases() -> List: return ( dict( - version='v0.14.0', - image= - 'gcr.io/grpc-testing/td-grpc-bootstrap:d6baaf7b0e0c63054ac4d9bedc09021ff261d599' + version="v0.14.0", + image="gcr.io/grpc-testing/td-grpc-bootstrap:d6baaf7b0e0c63054ac4d9bedc09021ff261d599", ), dict( - version='v0.13.0', - image= - 'gcr.io/grpc-testing/td-grpc-bootstrap:203db6ce70452996f4183c30dd4c5ecaada168b0' + version="v0.13.0", + image="gcr.io/grpc-testing/td-grpc-bootstrap:203db6ce70452996f4183c30dd4c5ecaada168b0", ), dict( - version='v0.12.0', - image= - 'gcr.io/grpc-testing/td-grpc-bootstrap:8765051ef3b742bc5cd20f16de078ae7547f2ba2' + version="v0.12.0", + image="gcr.io/grpc-testing/td-grpc-bootstrap:8765051ef3b742bc5cd20f16de078ae7547f2ba2", ), dict( - version='v0.11.0', - image= - 'gcr.io/grpc-testing/td-grpc-bootstrap:b96f7a73314668aee83cbf86ab1e40135a0542fc' + version="v0.11.0", + image="gcr.io/grpc-testing/td-grpc-bootstrap:b96f7a73314668aee83cbf86ab1e40135a0542fc", ), # v0.10.0 uses v2 xDS transport protocol by default. TD only supports v3 # and we can force the bootstrap generator to emit config with v3 @@ -80,8 +76,9 @@ def bootstrap_version_testcases() -> List: # TODO: Reuse service account and namespaces for significant improvements in # running time. class BootstrapGeneratorClientTest( - bootstrap_generator_testcase.BootstrapGeneratorBaseTest, - parameterized.TestCase): + bootstrap_generator_testcase.BootstrapGeneratorBaseTest, + parameterized.TestCase, +): client_runner: KubernetesClientRunner server_runner: KubernetesServerRunner test_server: XdsTestServer @@ -103,11 +100,13 @@ def setUpClass(cls): port=cls.server_port, maintenance_port=cls.server_maintenance_port, xds_host=cls.server_xds_host, - xds_port=cls.server_xds_port) + xds_port=cls.server_xds_port, + ) # Load backends. neg_name, neg_zones = cls.server_runner.k8s_namespace.get_service_neg( - cls.server_runner.service_name, cls.server_port) + cls.server_runner.service_name, cls.server_port + ) # Add backends to the Backend Service. cls.td.backend_service_add_neg_backends(neg_name, neg_zones) @@ -118,35 +117,41 @@ def tearDownClass(cls): # Remove backends from the Backend Service before closing the server # runner. 
neg_name, neg_zones = cls.server_runner.k8s_namespace.get_service_neg( - cls.server_runner.service_name, cls.server_port) + cls.server_runner.service_name, cls.server_port + ) cls.td.backend_service_remove_neg_backends(neg_name, neg_zones) cls.server_runner.cleanup(force=cls.force_cleanup) super().tearDownClass() def tearDown(self): - logger.info('----- TestMethod %s teardown -----', self.id()) - retryer = retryers.constant_retryer(wait_fixed=_timedelta(seconds=10), - attempts=3, - log_level=logging.INFO) + logger.info("----- TestMethod %s teardown -----", self.id()) + retryer = retryers.constant_retryer( + wait_fixed=_timedelta(seconds=10), + attempts=3, + log_level=logging.INFO, + ) try: retryer(self._cleanup) except retryers.RetryError: - logger.exception('Got error during teardown') + logger.exception("Got error during teardown") super().tearDown() def _cleanup(self): self.client_runner.cleanup(force=self.force_cleanup) @parameterized.parameters( - (t["version"], t["image"]) for t in bootstrap_version_testcases()) + (t["version"], t["image"]) for t in bootstrap_version_testcases() + ) def test_baseline_in_client_with_bootstrap_version(self, version, image): """Runs the baseline test for multiple versions of the bootstrap generator on the client. """ - logger.info('----- testing bootstrap generator version %s -----', - version) + logger.info( + "----- testing bootstrap generator version %s -----", version + ) self.client_runner = self.initKubernetesClientRunner( - td_bootstrap_image=image) + td_bootstrap_image=image + ) test_client: XdsTestClient = self.startTestClient(self.test_server) self.assertXdsConfigExists(test_client) self.assertSuccessfulRpcs(test_client) @@ -156,21 +161,24 @@ def test_baseline_in_client_with_bootstrap_version(self, version, image): # corresponding runners, by suffixing the version of the bootstrap generator # being tested. Then, run these in parallel. class BootstrapGeneratorServerTest( - bootstrap_generator_testcase.BootstrapGeneratorBaseTest, - parameterized.TestCase): + bootstrap_generator_testcase.BootstrapGeneratorBaseTest, + parameterized.TestCase, +): client_runner: KubernetesClientRunner server_runner: KubernetesServerRunner test_server: XdsTestServer def tearDown(self): - logger.info('----- TestMethod %s teardown -----', self.id()) - retryer = retryers.constant_retryer(wait_fixed=_timedelta(seconds=10), - attempts=3, - log_level=logging.INFO) + logger.info("----- TestMethod %s teardown -----", self.id()) + retryer = retryers.constant_retryer( + wait_fixed=_timedelta(seconds=10), + attempts=3, + log_level=logging.INFO, + ) try: retryer(self._cleanup) except retryers.RetryError: - logger.exception('Got error during teardown') + logger.exception("Got error during teardown") super().tearDown() def _cleanup(self): @@ -179,25 +187,30 @@ def _cleanup(self): self.server_runner.cleanup(force=self.force_cleanup) @parameterized.parameters( - (t["version"], t["image"]) for t in bootstrap_version_testcases()) + (t["version"], t["image"]) for t in bootstrap_version_testcases() + ) def test_baseline_in_server_with_bootstrap_version(self, version, image): """Runs the baseline test for multiple versions of the bootstrap generator on the server. 
""" - logger.info('----- Testing bootstrap generator version %s -----', - version) + logger.info( + "----- Testing bootstrap generator version %s -----", version + ) self.server_runner = self.initKubernetesServerRunner( - td_bootstrap_image=image) + td_bootstrap_image=image + ) self.test_server = self.startTestServer( server_runner=self.server_runner, port=self.server_port, maintenance_port=self.server_maintenance_port, xds_host=self.server_xds_host, - xds_port=self.server_xds_port) + xds_port=self.server_xds_port, + ) # Load backends. neg_name, neg_zones = self.server_runner.k8s_namespace.get_service_neg( - self.server_runner.service_name, self.server_port) + self.server_runner.service_name, self.server_port + ) # Add backends to the Backend Service. self.td.backend_service_add_neg_backends(neg_name, neg_zones) @@ -209,5 +222,5 @@ def test_baseline_in_server_with_bootstrap_version(self, version, image): self.assertSuccessfulRpcs(test_client) -if __name__ == '__main__': +if __name__ == "__main__": absltest.main() diff --git a/tools/run_tests/xds_k8s_test_driver/tests/change_backend_service_test.py b/tools/run_tests/xds_k8s_test_driver/tests/change_backend_service_test.py index d7814c9136b6e..f7af2326bbc9c 100644 --- a/tools/run_tests/xds_k8s_test_driver/tests/change_backend_service_test.py +++ b/tools/run_tests/xds_k8s_test_driver/tests/change_backend_service_test.py @@ -31,14 +31,14 @@ class ChangeBackendServiceTest(xds_k8s_testcase.RegularXdsKubernetesTestCase): - def setUp(self): super().setUp() self.alternate_k8s_namespace = k8s.KubernetesNamespace( - self.k8s_api_manager, self.server_namespace) + self.k8s_api_manager, self.server_namespace + ) self.alternate_server_runner = _KubernetesServerRunner( self.alternate_k8s_namespace, - deployment_name=self.server_name + '-alt', + deployment_name=self.server_name + "-alt", image_name=self.server_image, gcp_service_account=self.gcp_service_account, td_bootstrap_image=self.td_bootstrap_image, @@ -47,62 +47,74 @@ def setUp(self): xds_server_uri=self.xds_server_uri, network=self.network, debug_use_port_forwarding=self.debug_use_port_forwarding, - reuse_namespace=True) + reuse_namespace=True, + ) def cleanup(self): super().cleanup() - if hasattr(self, 'alternate_server_runner'): + if hasattr(self, "alternate_server_runner"): self.alternate_server_runner.cleanup( - force=self.force_cleanup, force_namespace=self.force_cleanup) + force=self.force_cleanup, force_namespace=self.force_cleanup + ) def test_change_backend_service(self) -> None: - with self.subTest('00_create_health_check'): + with self.subTest("00_create_health_check"): self.td.create_health_check() - with self.subTest('01_create_backend_services'): + with self.subTest("01_create_backend_services"): self.td.create_backend_service() self.td.create_alternative_backend_service() - with self.subTest('02_create_url_map'): + with self.subTest("02_create_url_map"): self.td.create_url_map(self.server_xds_host, self.server_xds_port) - with self.subTest('03_create_target_proxy'): + with self.subTest("03_create_target_proxy"): self.td.create_target_proxy() - with self.subTest('04_create_forwarding_rule'): + with self.subTest("04_create_forwarding_rule"): self.td.create_forwarding_rule(self.server_xds_port) default_test_servers: List[_XdsTestServer] same_zone_test_servers: List[_XdsTestServer] - with self.subTest('05_start_test_servers'): + with self.subTest("05_start_test_servers"): default_test_servers = self.startTestServers() same_zone_test_servers = self.startTestServers( - 
server_runner=self.alternate_server_runner) + server_runner=self.alternate_server_runner + ) - with self.subTest('06_add_server_backends_to_backend_services'): + with self.subTest("06_add_server_backends_to_backend_services"): self.setupServerBackends() # Add backend to alternative backend service - neg_name_alt, neg_zones_alt = self.alternate_k8s_namespace.get_service_neg( - self.alternate_server_runner.service_name, self.server_port) + ( + neg_name_alt, + neg_zones_alt, + ) = self.alternate_k8s_namespace.get_service_neg( + self.alternate_server_runner.service_name, self.server_port + ) self.td.alternative_backend_service_add_neg_backends( - neg_name_alt, neg_zones_alt) + neg_name_alt, neg_zones_alt + ) test_client: _XdsTestClient - with self.subTest('07_start_test_client'): + with self.subTest("07_start_test_client"): test_client = self.startTestClient(default_test_servers[0]) - with self.subTest('08_test_client_xds_config_exists'): + with self.subTest("08_test_client_xds_config_exists"): self.assertXdsConfigExists(test_client) - with self.subTest('09_test_server_received_rpcs_from_test_client'): + with self.subTest("09_test_server_received_rpcs_from_test_client"): self.assertSuccessfulRpcs(test_client) - with self.subTest('10_change_backend_service'): - self.td.patch_url_map(self.server_xds_host, self.server_xds_port, - self.td.alternative_backend_service) - self.assertRpcsEventuallyGoToGivenServers(test_client, - same_zone_test_servers) + with self.subTest("10_change_backend_service"): + self.td.patch_url_map( + self.server_xds_host, + self.server_xds_port, + self.td.alternative_backend_service, + ) + self.assertRpcsEventuallyGoToGivenServers( + test_client, same_zone_test_servers + ) -if __name__ == '__main__': +if __name__ == "__main__": absltest.main(failfast=True) diff --git a/tools/run_tests/xds_k8s_test_driver/tests/custom_lb_test.py b/tools/run_tests/xds_k8s_test_driver/tests/custom_lb_test.py index d19f306ef63a4..6287d9c0e6b1b 100644 --- a/tools/run_tests/xds_k8s_test_driver/tests/custom_lb_test.py +++ b/tools/run_tests/xds_k8s_test_driver/tests/custom_lb_test.py @@ -34,7 +34,6 @@ class CustomLbTest(xds_k8s_testcase.RegularXdsKubernetesTestCase): - @classmethod def setUpClass(cls): """Force the java test server for languages not yet supporting @@ -55,15 +54,15 @@ def setUpClass(cls): @staticmethod def is_supported(config: skips.TestConfig) -> bool: if config.client_lang == _Lang.JAVA: - return config.version_gte('v1.47.x') + return config.version_gte("v1.47.x") if config.client_lang == _Lang.CPP: - return config.version_gte('v1.55.x') + return config.version_gte("v1.55.x") if config.client_lang == _Lang.GO: - return config.version_gte('v1.56.x') + return config.version_gte("v1.56.x") return False def test_custom_lb_config(self): - with self.subTest('0_create_health_check'): + with self.subTest("0_create_health_check"): self.td.create_health_check() # Configures a custom, test LB on the client to instruct the servers @@ -72,49 +71,57 @@ def test_custom_lb_config(self): # The first policy in the list is a non-existent one to verify that # the gRPC client can gracefully move down the list to the valid one # once it determines the first one is not available. 
- with self.subTest('1_create_backend_service'): - self.td.create_backend_service(locality_lb_policies=[{ - 'customPolicy': { - 'name': 'test.ThisLoadBalancerDoesNotExist', - 'data': '{ "foo": "bar" }' - }, - }, { - 'customPolicy': { - 'name': - 'test.RpcBehaviorLoadBalancer', - 'data': - f'{{ "rpcBehavior": "error-code-{_EXPECTED_STATUS.value[0]}" }}' - } - }]) - - with self.subTest('2_create_url_map'): + with self.subTest("1_create_backend_service"): + self.td.create_backend_service( + locality_lb_policies=[ + { + "customPolicy": { + "name": "test.ThisLoadBalancerDoesNotExist", + "data": '{ "foo": "bar" }', + }, + }, + { + "customPolicy": { + "name": "test.RpcBehaviorLoadBalancer", + "data": ( + '{ "rpcBehavior":' + f' "error-code-{_EXPECTED_STATUS.value[0]}" }}' + ), + } + }, + ] + ) + + with self.subTest("2_create_url_map"): self.td.create_url_map(self.server_xds_host, self.server_xds_port) - with self.subTest('3_create_target_proxy'): + with self.subTest("3_create_target_proxy"): self.td.create_target_proxy() - with self.subTest('4_create_forwarding_rule'): + with self.subTest("4_create_forwarding_rule"): self.td.create_forwarding_rule(self.server_xds_port) - with self.subTest('5_start_test_server'): + with self.subTest("5_start_test_server"): test_server: _XdsTestServer = self.startTestServers()[0] - with self.subTest('6_add_server_backends_to_backend_service'): + with self.subTest("6_add_server_backends_to_backend_service"): self.setupServerBackends() - with self.subTest('7_start_test_client'): + with self.subTest("7_start_test_client"): test_client: _XdsTestClient = self.startTestClient(test_server) - with self.subTest('8_test_client_xds_config_exists'): + with self.subTest("8_test_client_xds_config_exists"): self.assertXdsConfigExists(test_client) # Verify status codes from the servers have the configured one. - with self.subTest('9_test_server_returned_configured_status_code'): - self.assertRpcStatusCodes(test_client, - expected_status=_EXPECTED_STATUS, - duration=datetime.timedelta(seconds=10), - method='UNARY_CALL') + with self.subTest("9_test_server_returned_configured_status_code"): + self.assertRpcStatusCodes( + test_client, + expected_status=_EXPECTED_STATUS, + duration=datetime.timedelta(seconds=10), + method="UNARY_CALL", + ) -if __name__ == '__main__': +if __name__ == "__main__": absltest.main(failfast=True) diff --git a/tools/run_tests/xds_k8s_test_driver/tests/failover_test.py b/tools/run_tests/xds_k8s_test_driver/tests/failover_test.py index 1d4a99534bd90..5f41707ceeed8 100644 --- a/tools/run_tests/xds_k8s_test_driver/tests/failover_test.py +++ b/tools/run_tests/xds_k8s_test_driver/tests/failover_test.py @@ -49,9 +49,10 @@ def setUpClass(cls): def setUp(self): super().setUp() self.secondary_server_runner = _KubernetesServerRunner( - k8s.KubernetesNamespace(self.secondary_k8s_api_manager, - self.server_namespace), - deployment_name=self.server_name + '-alt', + k8s.KubernetesNamespace( + self.secondary_k8s_api_manager, self.server_namespace + ), + deployment_name=self.server_name + "-alt", image_name=self.server_image, gcp_service_account=self.gcp_service_account, td_bootstrap_image=self.td_bootstrap_image, @@ -62,81 +63,93 @@ def setUp(self): debug_use_port_forwarding=self.debug_use_port_forwarding, # This runner's namespace created in the secondary cluster, # so it's not reused and must be cleaned up. 
- reuse_namespace=False) + reuse_namespace=False, + ) def cleanup(self): super().cleanup() - if hasattr(self, 'secondary_server_runner'): + if hasattr(self, "secondary_server_runner"): self.secondary_server_runner.cleanup( - force=self.force_cleanup, force_namespace=self.force_cleanup) + force=self.force_cleanup, force_namespace=self.force_cleanup + ) def test_failover(self) -> None: - with self.subTest('00_create_health_check'): + with self.subTest("00_create_health_check"): self.td.create_health_check() - with self.subTest('01_create_backend_services'): + with self.subTest("01_create_backend_services"): self.td.create_backend_service() - with self.subTest('02_create_url_map'): + with self.subTest("02_create_url_map"): self.td.create_url_map(self.server_xds_host, self.server_xds_port) - with self.subTest('03_create_target_proxy'): + with self.subTest("03_create_target_proxy"): self.td.create_target_proxy() - with self.subTest('04_create_forwarding_rule'): + with self.subTest("04_create_forwarding_rule"): self.td.create_forwarding_rule(self.server_xds_port) default_test_servers: List[_XdsTestServer] alternate_test_servers: List[_XdsTestServer] - with self.subTest('05_start_test_servers'): + with self.subTest("05_start_test_servers"): default_test_servers = self.startTestServers( - replica_count=self.REPLICA_COUNT) + replica_count=self.REPLICA_COUNT + ) alternate_test_servers = self.startTestServers( - server_runner=self.secondary_server_runner) + server_runner=self.secondary_server_runner + ) - with self.subTest('06_add_server_backends_to_backend_services'): + with self.subTest("06_add_server_backends_to_backend_services"): self.setupServerBackends( - max_rate_per_endpoint=self.MAX_RATE_PER_ENDPOINT) + max_rate_per_endpoint=self.MAX_RATE_PER_ENDPOINT + ) self.setupServerBackends( server_runner=self.secondary_server_runner, - max_rate_per_endpoint=self.MAX_RATE_PER_ENDPOINT) + max_rate_per_endpoint=self.MAX_RATE_PER_ENDPOINT, + ) test_client: _XdsTestClient - with self.subTest('07_start_test_client'): + with self.subTest("07_start_test_client"): test_client = self.startTestClient(default_test_servers[0]) - with self.subTest('08_test_client_xds_config_exists'): + with self.subTest("08_test_client_xds_config_exists"): self.assertXdsConfigExists(test_client) - with self.subTest('09_primary_locality_receives_requests'): - self.assertRpcsEventuallyGoToGivenServers(test_client, - default_test_servers) + with self.subTest("09_primary_locality_receives_requests"): + self.assertRpcsEventuallyGoToGivenServers( + test_client, default_test_servers + ) with self.subTest( - '10_secondary_locality_receives_no_requests_on_partial_primary_failure' + "10_secondary_locality_receives_no_requests_on_partial_primary_failure" ): default_test_servers[0].set_not_serving() - self.assertRpcsEventuallyGoToGivenServers(test_client, - default_test_servers[1:]) + self.assertRpcsEventuallyGoToGivenServers( + test_client, default_test_servers[1:] + ) - with self.subTest('11_gentle_failover'): + with self.subTest("11_gentle_failover"): default_test_servers[1].set_not_serving() self.assertRpcsEventuallyGoToGivenServers( - test_client, default_test_servers[2:] + alternate_test_servers) + test_client, default_test_servers[2:] + alternate_test_servers + ) with self.subTest( - '12_secondary_locality_receives_requests_on_primary_failure'): + "12_secondary_locality_receives_requests_on_primary_failure" + ): default_test_servers[2].set_not_serving() - self.assertRpcsEventuallyGoToGivenServers(test_client, - alternate_test_servers) 
+ self.assertRpcsEventuallyGoToGivenServers( + test_client, alternate_test_servers + ) - with self.subTest('13_traffic_resumes_to_healthy_backends'): + with self.subTest("13_traffic_resumes_to_healthy_backends"): for i in range(self.REPLICA_COUNT): default_test_servers[i].set_serving() - self.assertRpcsEventuallyGoToGivenServers(test_client, - default_test_servers) + self.assertRpcsEventuallyGoToGivenServers( + test_client, default_test_servers + ) -if __name__ == '__main__': +if __name__ == "__main__": absltest.main(failfast=True) diff --git a/tools/run_tests/xds_k8s_test_driver/tests/outlier_detection_test.py b/tools/run_tests/xds_k8s_test_driver/tests/outlier_detection_test.py index d69e3412b6d2d..48f7e0ace0e10 100644 --- a/tools/run_tests/xds_k8s_test_driver/tests/outlier_detection_test.py +++ b/tools/run_tests/xds_k8s_test_driver/tests/outlier_detection_test.py @@ -24,7 +24,7 @@ logger = logging.getLogger(__name__) flags.adopt_module_key_flags(xds_k8s_testcase) -flags.mark_flag_as_required('server_image_canonical') +flags.mark_flag_as_required("server_image_canonical") # Type aliases RpcTypeUnaryCall = xds_url_map_testcase.RpcTypeUnaryCall @@ -63,71 +63,74 @@ def setUpClass(cls): @staticmethod def is_supported(config: skips.TestConfig) -> bool: if config.client_lang in _Lang.CPP | _Lang.PYTHON: - return config.version_gte('v1.48.x') + return config.version_gte("v1.48.x") if config.client_lang == _Lang.JAVA: - return config.version_gte('v1.49.x') + return config.version_gte("v1.49.x") if config.client_lang == _Lang.NODE: - return config.version_gte('v1.6.x') + return config.version_gte("v1.6.x") if config.client_lang == _Lang.GO: # TODO(zasweq): Update when the feature makes in a version branch. - return config.version_gte('master') + return config.version_gte("master") return False def test_outlier_detection(self) -> None: - - with self.subTest('00_create_health_check'): + with self.subTest("00_create_health_check"): self.td.create_health_check() - with self.subTest('01_create_backend_service'): + with self.subTest("01_create_backend_service"): self.td.create_backend_service( outlier_detection={ - 'interval': { - 'seconds': 2, - 'nanos': 0 - }, - 'successRateRequestVolume': 20 - }) - - with self.subTest('02_create_url_map'): + "interval": {"seconds": 2, "nanos": 0}, + "successRateRequestVolume": 20, + } + ) + + with self.subTest("02_create_url_map"): self.td.create_url_map(self.server_xds_host, self.server_xds_port) - with self.subTest('03_create_target_proxy'): + with self.subTest("03_create_target_proxy"): self.td.create_target_proxy() - with self.subTest('04_create_forwarding_rule'): + with self.subTest("04_create_forwarding_rule"): self.td.create_forwarding_rule(self.server_xds_port) test_servers: List[_XdsTestServer] - with self.subTest('05_start_test_servers'): + with self.subTest("05_start_test_servers"): test_servers = self.startTestServers(replica_count=_REPLICA_COUNT) - with self.subTest('06_add_server_backends_to_backend_services'): + with self.subTest("06_add_server_backends_to_backend_services"): self.setupServerBackends() test_client: _XdsTestClient - with self.subTest('07_start_test_client'): + with self.subTest("07_start_test_client"): test_client = self.startTestClient(test_servers[0], qps=_QPS) - with self.subTest('08_test_client_xds_config_exists'): + with self.subTest("08_test_client_xds_config_exists"): self.assertXdsConfigExists(test_client) - with self.subTest('09_test_servers_received_rpcs_from_test_client'): + with 
self.subTest("09_test_servers_received_rpcs_from_test_client"): self.assertRpcsEventuallyGoToGivenServers(test_client, test_servers) rpc_types = (RpcTypeUnaryCall,) - with self.subTest('10_chosen_server_removed_by_outlier_detection'): + with self.subTest("10_chosen_server_removed_by_outlier_detection"): test_client.update_config.configure( rpc_types=rpc_types, metadata=( - (RpcTypeUnaryCall, 'rpc-behavior', - f'hostname={test_servers[0].hostname} error-code-2'),)) - self.assertRpcsEventuallyGoToGivenServers(test_client, - test_servers[1:]) - - with self.subTest('11_ejected_server_returned_after_failures_stopped'): + ( + RpcTypeUnaryCall, + "rpc-behavior", + f"hostname={test_servers[0].hostname} error-code-2", + ), + ), + ) + self.assertRpcsEventuallyGoToGivenServers( + test_client, test_servers[1:] + ) + + with self.subTest("11_ejected_server_returned_after_failures_stopped"): test_client.update_config.configure(rpc_types=rpc_types) self.assertRpcsEventuallyGoToGivenServers(test_client, test_servers) -if __name__ == '__main__': +if __name__ == "__main__": absltest.main(failfast=True) diff --git a/tools/run_tests/xds_k8s_test_driver/tests/remove_neg_test.py b/tools/run_tests/xds_k8s_test_driver/tests/remove_neg_test.py index 8b1a5a4b4d49a..cd690b187a7ce 100644 --- a/tools/run_tests/xds_k8s_test_driver/tests/remove_neg_test.py +++ b/tools/run_tests/xds_k8s_test_driver/tests/remove_neg_test.py @@ -31,13 +31,13 @@ class RemoveNegTest(xds_k8s_testcase.RegularXdsKubernetesTestCase): - def setUp(self): super().setUp() self.alternate_server_runner = _KubernetesServerRunner( - k8s.KubernetesNamespace(self.k8s_api_manager, - self.server_namespace), - deployment_name=self.server_name + '-alt', + k8s.KubernetesNamespace( + self.k8s_api_manager, self.server_namespace + ), + deployment_name=self.server_name + "-alt", image_name=self.server_image, gcp_service_account=self.gcp_service_account, td_bootstrap_image=self.td_bootstrap_image, @@ -46,59 +46,65 @@ def setUp(self): xds_server_uri=self.xds_server_uri, network=self.network, debug_use_port_forwarding=self.debug_use_port_forwarding, - reuse_namespace=True) + reuse_namespace=True, + ) def cleanup(self): super().cleanup() - if hasattr(self, 'alternate_server_runner'): + if hasattr(self, "alternate_server_runner"): self.alternate_server_runner.cleanup( - force=self.force_cleanup, force_namespace=self.force_cleanup) + force=self.force_cleanup, force_namespace=self.force_cleanup + ) def test_remove_neg(self) -> None: - with self.subTest('00_create_health_check'): + with self.subTest("00_create_health_check"): self.td.create_health_check() - with self.subTest('01_create_backend_services'): + with self.subTest("01_create_backend_services"): self.td.create_backend_service() - with self.subTest('02_create_url_map'): + with self.subTest("02_create_url_map"): self.td.create_url_map(self.server_xds_host, self.server_xds_port) - with self.subTest('03_create_target_proxy'): + with self.subTest("03_create_target_proxy"): self.td.create_target_proxy() - with self.subTest('04_create_forwarding_rule'): + with self.subTest("04_create_forwarding_rule"): self.td.create_forwarding_rule(self.server_xds_port) default_test_servers: List[_XdsTestServer] same_zone_test_servers: List[_XdsTestServer] - with self.subTest('05_start_test_servers'): + with self.subTest("05_start_test_servers"): default_test_servers = self.startTestServers() same_zone_test_servers = self.startTestServers( - server_runner=self.alternate_server_runner) + server_runner=self.alternate_server_runner + ) - 
with self.subTest('06_add_server_backends_to_backend_services'): + with self.subTest("06_add_server_backends_to_backend_services"): self.setupServerBackends() self.setupServerBackends(server_runner=self.alternate_server_runner) test_client: _XdsTestClient - with self.subTest('07_start_test_client'): + with self.subTest("07_start_test_client"): test_client = self.startTestClient(default_test_servers[0]) - with self.subTest('08_test_client_xds_config_exists'): + with self.subTest("08_test_client_xds_config_exists"): self.assertXdsConfigExists(test_client) - with self.subTest('09_test_server_received_rpcs_from_test_client'): + with self.subTest("09_test_server_received_rpcs_from_test_client"): self.assertSuccessfulRpcs(test_client) - with self.subTest('10_remove_neg'): + with self.subTest("10_remove_neg"): self.assertRpcsEventuallyGoToGivenServers( - test_client, default_test_servers + same_zone_test_servers) + test_client, default_test_servers + same_zone_test_servers + ) self.removeServerBackends( - server_runner=self.alternate_server_runner) - self.assertRpcsEventuallyGoToGivenServers(test_client, - default_test_servers) + server_runner=self.alternate_server_runner + ) + self.assertRpcsEventuallyGoToGivenServers( + test_client, default_test_servers + ) -if __name__ == '__main__': +if __name__ == "__main__": absltest.main(failfast=True) diff --git a/tools/run_tests/xds_k8s_test_driver/tests/round_robin_test.py b/tools/run_tests/xds_k8s_test_driver/tests/round_robin_test.py index c8449eb91495c..6d3144f065bf2 100644 --- a/tools/run_tests/xds_k8s_test_driver/tests/round_robin_test.py +++ b/tools/run_tests/xds_k8s_test_driver/tests/round_robin_test.py @@ -28,59 +28,65 @@ class RoundRobinTest(xds_k8s_testcase.RegularXdsKubernetesTestCase): - def test_round_robin(self) -> None: REPLICA_COUNT = 2 - with self.subTest('00_create_health_check'): + with self.subTest("00_create_health_check"): self.td.create_health_check() - with self.subTest('01_create_backend_services'): + with self.subTest("01_create_backend_services"): self.td.create_backend_service() - with self.subTest('02_create_url_map'): + with self.subTest("02_create_url_map"): self.td.create_url_map(self.server_xds_host, self.server_xds_port) - with self.subTest('03_create_target_proxy'): + with self.subTest("03_create_target_proxy"): self.td.create_target_proxy() - with self.subTest('04_create_forwarding_rule'): + with self.subTest("04_create_forwarding_rule"): self.td.create_forwarding_rule(self.server_xds_port) test_servers: List[_XdsTestServer] - with self.subTest('05_start_test_servers'): + with self.subTest("05_start_test_servers"): test_servers = self.startTestServers(replica_count=REPLICA_COUNT) - with self.subTest('06_add_server_backends_to_backend_services'): + with self.subTest("06_add_server_backends_to_backend_services"): self.setupServerBackends() test_client: _XdsTestClient - with self.subTest('07_start_test_client'): + with self.subTest("07_start_test_client"): test_client = self.startTestClient(test_servers[0]) - with self.subTest('08_test_client_xds_config_exists'): + with self.subTest("08_test_client_xds_config_exists"): self.assertXdsConfigExists(test_client) - with self.subTest('09_test_server_received_rpcs_from_test_client'): + with self.subTest("09_test_server_received_rpcs_from_test_client"): self.assertSuccessfulRpcs(test_client) - with self.subTest('10_round_robin'): + with self.subTest("10_round_robin"): num_rpcs = 100 expected_rpcs_per_replica = num_rpcs / REPLICA_COUNT - rpcs_by_peer = 
self.getClientRpcStats(test_client, - num_rpcs).rpcs_by_peer + rpcs_by_peer = self.getClientRpcStats( + test_client, num_rpcs + ).rpcs_by_peer total_requests_received = sum(rpcs_by_peer[x] for x in rpcs_by_peer) - self.assertEqual(total_requests_received, num_rpcs, - 'Wrong number of RPCS') + self.assertEqual( + total_requests_received, num_rpcs, "Wrong number of RPCS" + ) for server in test_servers: hostname = server.hostname - self.assertIn(hostname, rpcs_by_peer, - f'Server {hostname} did not receive RPCs') + self.assertIn( + hostname, + rpcs_by_peer, + f"Server {hostname} did not receive RPCs", + ) self.assertLessEqual( - abs(rpcs_by_peer[hostname] - expected_rpcs_per_replica), 1, - f'Wrong number of RPCs for server {hostname}') + abs(rpcs_by_peer[hostname] - expected_rpcs_per_replica), + 1, + f"Wrong number of RPCs for server {hostname}", + ) -if __name__ == '__main__': +if __name__ == "__main__": absltest.main(failfast=True) diff --git a/tools/run_tests/xds_k8s_test_driver/tests/security_test.py b/tools/run_tests/xds_k8s_test_driver/tests/security_test.py index 4192587e2727c..f087a9595f9ca 100644 --- a/tools/run_tests/xds_k8s_test_driver/tests/security_test.py +++ b/tools/run_tests/xds_k8s_test_driver/tests/security_test.py @@ -31,14 +31,14 @@ class SecurityTest(xds_k8s_testcase.SecurityXdsKubernetesTestCase): - @staticmethod def is_supported(config: skips.TestConfig) -> bool: - if config.client_lang in (_Lang.CPP | _Lang.GO | _Lang.JAVA | - _Lang.PYTHON): + if config.client_lang in ( + _Lang.CPP | _Lang.GO | _Lang.JAVA | _Lang.PYTHON + ): # Versions prior to v1.41.x don't support PSM Security. # https://github.com/grpc/grpc/blob/master/doc/grpc_xds_features.md - return config.version_gte('v1.41.x') + return config.version_gte("v1.41.x") elif config.client_lang == _Lang.NODE: return False return True @@ -49,10 +49,9 @@ def test_mtls(self): Both client and server configured to use TLS and mTLS. """ self.setupTrafficDirectorGrpc() - self.setupSecurityPolicies(server_tls=True, - server_mtls=True, - client_tls=True, - client_mtls=True) + self.setupSecurityPolicies( + server_tls=True, server_mtls=True, client_tls=True, client_mtls=True + ) test_server: _XdsTestServer = self.startSecureTestServer() self.setupServerBackends() @@ -60,7 +59,7 @@ def test_mtls(self): self.assertTestAppSecurity(_SecurityMode.MTLS, test_client, test_server) self.assertSuccessfulRpcs(test_client) - logger.info('[SUCCESS] mTLS security mode confirmed.') + logger.info("[SUCCESS] mTLS security mode confirmed.") def test_tls(self): """TLS test. @@ -68,10 +67,12 @@ def test_tls(self): Both client and server configured to use TLS and not use mTLS. """ self.setupTrafficDirectorGrpc() - self.setupSecurityPolicies(server_tls=True, - server_mtls=False, - client_tls=True, - client_mtls=False) + self.setupSecurityPolicies( + server_tls=True, + server_mtls=False, + client_tls=True, + client_mtls=False, + ) test_server: _XdsTestServer = self.startSecureTestServer() self.setupServerBackends() @@ -79,7 +80,7 @@ def test_tls(self): self.assertTestAppSecurity(_SecurityMode.TLS, test_client, test_server) self.assertSuccessfulRpcs(test_client) - logger.info('[SUCCESS] TLS security mode confirmed.') + logger.info("[SUCCESS] TLS security mode confirmed.") def test_plaintext_fallback(self): """Plain-text fallback test. @@ -88,19 +89,22 @@ def test_plaintext_fallback(self): fallback to plaintext based on fallback-credentials. 
""" self.setupTrafficDirectorGrpc() - self.setupSecurityPolicies(server_tls=False, - server_mtls=False, - client_tls=False, - client_mtls=False) + self.setupSecurityPolicies( + server_tls=False, + server_mtls=False, + client_tls=False, + client_mtls=False, + ) test_server: _XdsTestServer = self.startSecureTestServer() self.setupServerBackends() test_client: _XdsTestClient = self.startSecureTestClient(test_server) - self.assertTestAppSecurity(_SecurityMode.PLAINTEXT, test_client, - test_server) + self.assertTestAppSecurity( + _SecurityMode.PLAINTEXT, test_client, test_server + ) self.assertSuccessfulRpcs(test_client) - logger.info('[SUCCESS] Plaintext security mode confirmed.') + logger.info("[SUCCESS] Plaintext security mode confirmed.") def test_mtls_error(self): """Negative test: mTLS Error. @@ -123,7 +127,8 @@ def test_mtls_error(self): """ # Create backend service self.td.setup_backend_for_grpc( - health_check_port=self.server_maintenance_port) + health_check_port=self.server_maintenance_port + ) # Start server and attach its NEGs to the backend service, but # until they become healthy. @@ -131,26 +136,31 @@ def test_mtls_error(self): self.setupServerBackends(wait_for_healthy_status=False) # Setup policies and attach them. - self.setupSecurityPolicies(server_tls=True, - server_mtls=True, - client_tls=True, - client_mtls=False) + self.setupSecurityPolicies( + server_tls=True, + server_mtls=True, + client_tls=True, + client_mtls=False, + ) # Create the routing rule map. - self.td.setup_routing_rule_map_for_grpc(self.server_xds_host, - self.server_xds_port) + self.td.setup_routing_rule_map_for_grpc( + self.server_xds_host, self.server_xds_port + ) # Now that TD setup is complete, Backend Service can be populated # with healthy backends (NEGs). self.td.wait_for_backends_healthy_status() # Start the client, but don't wait for it to report a healthy channel. test_client: _XdsTestClient = self.startSecureTestClient( - test_server, wait_for_active_server_channel=False) + test_server, wait_for_active_server_channel=False + ) self.assertClientCannotReachServerRepeatedly(test_client) logger.info( "[SUCCESS] Client's connectivity state is consistent with a mTLS " - "error caused by not presenting mTLS certificate to the server.") + "error caused by not presenting mTLS certificate to the server." + ) def test_server_authz_error(self): """Negative test: AuthZ error. @@ -160,7 +170,8 @@ def test_server_authz_error(self): """ # Create backend service self.td.setup_backend_for_grpc( - health_check_port=self.server_maintenance_port) + health_check_port=self.server_maintenance_port + ) # Start server and attach its NEGs to the backend service, but # until they become healthy. @@ -169,32 +180,40 @@ def test_server_authz_error(self): # Regular TLS setup, but with client policy configured using # intentionality incorrect server_namespace. 
- self.td.setup_server_security(server_namespace=self.server_namespace, - server_name=self.server_name, - server_port=self.server_port, - tls=True, - mtls=False) - incorrect_namespace = f'incorrect-namespace-{rand.rand_string()}' - self.td.setup_client_security(server_namespace=incorrect_namespace, - server_name=self.server_name, - tls=True, - mtls=False) + self.td.setup_server_security( + server_namespace=self.server_namespace, + server_name=self.server_name, + server_port=self.server_port, + tls=True, + mtls=False, + ) + incorrect_namespace = f"incorrect-namespace-{rand.rand_string()}" + self.td.setup_client_security( + server_namespace=incorrect_namespace, + server_name=self.server_name, + tls=True, + mtls=False, + ) # Create the routing rule map. - self.td.setup_routing_rule_map_for_grpc(self.server_xds_host, - self.server_xds_port) + self.td.setup_routing_rule_map_for_grpc( + self.server_xds_host, self.server_xds_port + ) # Now that TD setup is complete, Backend Service can be populated # with healthy backends (NEGs). self.td.wait_for_backends_healthy_status() # Start the client, but don't wait for it to report a healthy channel. test_client: _XdsTestClient = self.startSecureTestClient( - test_server, wait_for_active_server_channel=False) + test_server, wait_for_active_server_channel=False + ) self.assertClientCannotReachServerRepeatedly(test_client) - logger.info("[SUCCESS] Client's connectivity state is consistent with " - "AuthZ error caused by server presenting incorrect SAN.") + logger.info( + "[SUCCESS] Client's connectivity state is consistent with " + "AuthZ error caused by server presenting incorrect SAN." + ) -if __name__ == '__main__': +if __name__ == "__main__": absltest.main() diff --git a/tools/run_tests/xds_k8s_test_driver/tests/subsetting_test.py b/tools/run_tests/xds_k8s_test_driver/tests/subsetting_test.py index 51379175b51f1..e3dcae0a3aeab 100644 --- a/tools/run_tests/xds_k8s_test_driver/tests/subsetting_test.py +++ b/tools/run_tests/xds_k8s_test_driver/tests/subsetting_test.py @@ -36,39 +36,38 @@ class SubsettingTest(xds_k8s_testcase.RegularXdsKubernetesTestCase): - @staticmethod def is_supported(config: skips.TestConfig) -> bool: # Subsetting is an experimental feature where most work is done on the # server-side. We limit it to only run on master branch to save # resources. 
- return config.version_gte('master') + return config.version_gte("master") def test_subsetting_basic(self) -> None: - with self.subTest('00_create_health_check'): + with self.subTest("00_create_health_check"): self.td.create_health_check() - with self.subTest('01_create_backend_services'): + with self.subTest("01_create_backend_services"): self.td.create_backend_service(subset_size=_SUBSET_SIZE) - with self.subTest('02_create_url_map'): + with self.subTest("02_create_url_map"): self.td.create_url_map(self.server_xds_host, self.server_xds_port) - with self.subTest('03_create_target_proxy'): + with self.subTest("03_create_target_proxy"): self.td.create_target_proxy() - with self.subTest('04_create_forwarding_rule'): + with self.subTest("04_create_forwarding_rule"): self.td.create_forwarding_rule(self.server_xds_port) test_servers: List[_XdsTestServer] - with self.subTest('05_start_test_servers'): + with self.subTest("05_start_test_servers"): test_servers = self.startTestServers(replica_count=_NUM_BACKENDS) - with self.subTest('06_add_server_backends_to_backend_services'): + with self.subTest("06_add_server_backends_to_backend_services"): self.setupServerBackends() rpc_distribution = collections.defaultdict(int) - with self.subTest('07_start_test_client'): + with self.subTest("07_start_test_client"): for i in range(_NUM_CLIENTS): # Clean created client pods if there is any. if self.client_runner.time_start_requested: @@ -77,37 +76,53 @@ def test_subsetting_basic(self) -> None: # Create a test client test_client: _XdsTestClient = self.startTestClient( - test_servers[0]) + test_servers[0] + ) # Validate the number of received endpoints config = test_client.csds.fetch_client_status( - log_level=logging.INFO) + log_level=logging.INFO + ) self.assertIsNotNone(config) json_config = json_format.MessageToDict(config) parsed = xds_url_map_testcase.DumpedXdsConfig(json_config) - logging.info('Client %d received endpoints (len=%s): %s', i, - len(parsed.endpoints), parsed.endpoints) + logging.info( + "Client %d received endpoints (len=%s): %s", + i, + len(parsed.endpoints), + parsed.endpoints, + ) self.assertLen(parsed.endpoints, _SUBSET_SIZE) # Record RPC stats - lb_stats = self.getClientRpcStats(test_client, - _NUM_BACKENDS * 25) + lb_stats = self.getClientRpcStats( + test_client, _NUM_BACKENDS * 25 + ) for key, value in lb_stats.rpcs_by_peer.items(): rpc_distribution[key] += value - with self.subTest('08_log_rpc_distribution'): - server_entries = sorted(rpc_distribution.items(), - key=lambda x: -x[1]) + with self.subTest("08_log_rpc_distribution"): + server_entries = sorted( + rpc_distribution.items(), key=lambda x: -x[1] + ) # Validate if clients are receiving different sets of backends (3 # client received a total of 4 unique backends == FAIL, a total of 5 # unique backends == PASS) self.assertGreater(len(server_entries), _SUBSET_SIZE) - logging.info('RPC distribution (len=%s): %s', len(server_entries), - server_entries) + logging.info( + "RPC distribution (len=%s): %s", + len(server_entries), + server_entries, + ) peak = server_entries[0][1] - mean = sum(map(lambda x: x[1], - server_entries)) / len(server_entries) - logging.info('Peak=%d Mean=%.1f Peak-to-Mean-Ratio=%.2f', peak, - mean, peak / mean) - - -if __name__ == '__main__': + mean = sum(map(lambda x: x[1], server_entries)) / len( + server_entries + ) + logging.info( + "Peak=%d Mean=%.1f Peak-to-Mean-Ratio=%.2f", + peak, + mean, + peak / mean, + ) + + +if __name__ == "__main__": absltest.main(failfast=True) diff --git 
a/tools/run_tests/xds_k8s_test_driver/tests/url_map/__main__.py b/tools/run_tests/xds_k8s_test_driver/tests/url_map/__main__.py index c89f6ec3b48ee..ff8c4855c2c57 100644 --- a/tools/run_tests/xds_k8s_test_driver/tests/url_map/__main__.py +++ b/tools/run_tests/xds_k8s_test_driver/tests/url_map/__main__.py @@ -22,10 +22,11 @@ def load_tests(loader: absltest.TestLoader, unused_tests, unused_pattern): - return loader.discover(_TEST_CASE_FOLDER, - pattern='*' + - xds_url_map_testcase.URL_MAP_TESTCASE_FILE_SUFFIX) + return loader.discover( + _TEST_CASE_FOLDER, + pattern="*" + xds_url_map_testcase.URL_MAP_TESTCASE_FILE_SUFFIX, + ) -if __name__ == '__main__': +if __name__ == "__main__": absltest.main() diff --git a/tools/run_tests/xds_k8s_test_driver/tests/url_map/affinity_test.py b/tools/run_tests/xds_k8s_test_driver/tests/url_map/affinity_test.py index 1bc9928bfcd2f..4719e8a68bda4 100644 --- a/tools/run_tests/xds_k8s_test_driver/tests/url_map/affinity_test.py +++ b/tools/run_tests/xds_k8s_test_driver/tests/url_map/affinity_test.py @@ -38,16 +38,19 @@ _NUM_RPCS = 150 _TEST_METADATA_KEY = traffic_director.TEST_AFFINITY_METADATA_KEY -_TEST_METADATA_VALUE_UNARY = 'unary_yranu' -_TEST_METADATA_VALUE_EMPTY = 'empty_ytpme' -_TEST_METADATA_NUMERIC_KEY = 'xds_md_numeric' -_TEST_METADATA_NUMERIC_VALUE = '159' +_TEST_METADATA_VALUE_UNARY = "unary_yranu" +_TEST_METADATA_VALUE_EMPTY = "empty_ytpme" +_TEST_METADATA_NUMERIC_KEY = "xds_md_numeric" +_TEST_METADATA_NUMERIC_VALUE = "159" _TEST_METADATA = ( (RpcTypeUnaryCall, _TEST_METADATA_KEY, _TEST_METADATA_VALUE_UNARY), (RpcTypeEmptyCall, _TEST_METADATA_KEY, _TEST_METADATA_VALUE_EMPTY), - (RpcTypeUnaryCall, _TEST_METADATA_NUMERIC_KEY, - _TEST_METADATA_NUMERIC_VALUE), + ( + RpcTypeUnaryCall, + _TEST_METADATA_NUMERIC_KEY, + _TEST_METADATA_NUMERIC_VALUE, + ), ) _ChannelzChannelState = grpc_channelz.ChannelState @@ -57,9 +60,9 @@ def _is_supported(config: skips.TestConfig) -> bool: # Per "Ring hash" in # https://github.com/grpc/grpc/blob/master/doc/grpc_xds_features.md if config.client_lang in _Lang.CPP | _Lang.JAVA: - return config.version_gte('v1.40.x') + return config.version_gte("v1.40.x") elif config.client_lang == _Lang.GO: - return config.version_gte('v1.41.x') + return config.version_gte("v1.41.x") elif config.client_lang == _Lang.PYTHON: # TODO(https://github.com/grpc/grpc/issues/27430): supported after # the issue is fixed. @@ -70,7 +73,6 @@ def _is_supported(config: skips.TestConfig) -> bool: class TestHeaderBasedAffinity(xds_url_map_testcase.XdsUrlMapTestCase): - @staticmethod def is_supported(config: skips.TestConfig) -> bool: return _is_supported(config) @@ -82,37 +84,46 @@ def client_init_config(rpc: str, metadata: str): # backends (behavior of RING_HASH). This is necessary to only one # sub-channel is picked and used from the beginning, thus the channel # will only create this one sub-channel. - return 'EmptyCall', 'EmptyCall:%s:%s' % (_TEST_METADATA_KEY, - _TEST_METADATA_VALUE_EMPTY) + return "EmptyCall", "EmptyCall:%s:%s" % ( + _TEST_METADATA_KEY, + _TEST_METADATA_VALUE_EMPTY, + ) @staticmethod def url_map_change( - host_rule: HostRule, - path_matcher: PathMatcher) -> Tuple[HostRule, PathMatcher]: + host_rule: HostRule, path_matcher: PathMatcher + ) -> Tuple[HostRule, PathMatcher]: # Update default service to the affinity service. 
- path_matcher["defaultService"] = GcpResourceManager( - ).affinity_backend_service() + path_matcher[ + "defaultService" + ] = GcpResourceManager().affinity_backend_service() return host_rule, path_matcher def xds_config_validate(self, xds_config: DumpedXdsConfig): # 3 endpoints in the affinity backend service. self.assertNumEndpoints(xds_config, 3) self.assertEqual( - xds_config.rds['virtualHosts'][0]['routes'][0]['route'] - ['hashPolicy'][0]['header']['headerName'], _TEST_METADATA_KEY) - self.assertEqual(xds_config.cds[0]['lbPolicy'], 'RING_HASH') + xds_config.rds["virtualHosts"][0]["routes"][0]["route"][ + "hashPolicy" + ][0]["header"]["headerName"], + _TEST_METADATA_KEY, + ) + self.assertEqual(xds_config.cds[0]["lbPolicy"], "RING_HASH") def rpc_distribution_validate(self, test_client: XdsTestClient): - rpc_distribution = self.configure_and_send(test_client, - rpc_types=[RpcTypeEmptyCall], - metadata=_TEST_METADATA, - num_rpcs=_NUM_RPCS) + rpc_distribution = self.configure_and_send( + test_client, + rpc_types=[RpcTypeEmptyCall], + metadata=_TEST_METADATA, + num_rpcs=_NUM_RPCS, + ) # Only one backend should receive traffic, even though there are 3 # backends. self.assertEqual(1, rpc_distribution.num_peers) self.assertLen( test_client.find_subchannels_with_state( - _ChannelzChannelState.READY), + _ChannelzChannelState.READY + ), 1, ) self.assertLen( @@ -124,18 +135,20 @@ def rpc_distribution_validate(self, test_client: XdsTestClient): rpc_distribution = self.configure_and_send( test_client, rpc_types=[RpcTypeEmptyCall, RpcTypeUnaryCall], - num_rpcs=_NUM_RPCS) + num_rpcs=_NUM_RPCS, + ) self.assertEqual(3, rpc_distribution.num_peers) self.assertLen( test_client.find_subchannels_with_state( - _ChannelzChannelState.READY), + _ChannelzChannelState.READY + ), 3, ) class TestHeaderBasedAffinityMultipleHeaders( - xds_url_map_testcase.XdsUrlMapTestCase): - + xds_url_map_testcase.XdsUrlMapTestCase +): @staticmethod def is_supported(config: skips.TestConfig) -> bool: return _is_supported(config) @@ -147,45 +160,57 @@ def client_init_config(rpc: str, metadata: str): # backends (behavior of RING_HASH). This is necessary to only one # sub-channel is picked and used from the beginning, thus the channel # will only create this one sub-channel. - return 'EmptyCall', 'EmptyCall:%s:%s' % (_TEST_METADATA_KEY, - _TEST_METADATA_VALUE_EMPTY) + return "EmptyCall", "EmptyCall:%s:%s" % ( + _TEST_METADATA_KEY, + _TEST_METADATA_VALUE_EMPTY, + ) @staticmethod def url_map_change( - host_rule: HostRule, - path_matcher: PathMatcher) -> Tuple[HostRule, PathMatcher]: + host_rule: HostRule, path_matcher: PathMatcher + ) -> Tuple[HostRule, PathMatcher]: # Update default service to the affinity service. - path_matcher["defaultService"] = GcpResourceManager( - ).affinity_backend_service() + path_matcher[ + "defaultService" + ] = GcpResourceManager().affinity_backend_service() return host_rule, path_matcher def xds_config_validate(self, xds_config: DumpedXdsConfig): # 3 endpoints in the affinity backend service. 
self.assertNumEndpoints(xds_config, 3) self.assertEqual( - xds_config.rds['virtualHosts'][0]['routes'][0]['route'] - ['hashPolicy'][0]['header']['headerName'], _TEST_METADATA_KEY) - self.assertEqual(xds_config.cds[0]['lbPolicy'], 'RING_HASH') + xds_config.rds["virtualHosts"][0]["routes"][0]["route"][ + "hashPolicy" + ][0]["header"]["headerName"], + _TEST_METADATA_KEY, + ) + self.assertEqual(xds_config.cds[0]["lbPolicy"], "RING_HASH") def rpc_distribution_validate(self, test_client: XdsTestClient): - rpc_distribution = self.configure_and_send(test_client, - rpc_types=[RpcTypeEmptyCall], - metadata=_TEST_METADATA, - num_rpcs=_NUM_RPCS) + rpc_distribution = self.configure_and_send( + test_client, + rpc_types=[RpcTypeEmptyCall], + metadata=_TEST_METADATA, + num_rpcs=_NUM_RPCS, + ) # Only one backend should receive traffic, even though there are 3 # backends. self.assertEqual(1, rpc_distribution.num_peers) self.assertLen( test_client.find_subchannels_with_state( - _ChannelzChannelState.READY), + _ChannelzChannelState.READY + ), 1, ) self.assertLen( test_client.find_subchannels_with_state(_ChannelzChannelState.IDLE), 2, ) - empty_call_peer = list(rpc_distribution.raw['rpcsByMethod']['EmptyCall'] - ['rpcsByPeer'].keys())[0] + empty_call_peer = list( + rpc_distribution.raw["rpcsByMethod"]["EmptyCall"][ + "rpcsByPeer" + ].keys() + )[0] # Send RPCs with a different metadata value, try different values to # verify that the client will pick a different backend. # @@ -201,27 +226,38 @@ def rpc_distribution_validate(self, test_client: XdsTestClient): different_peer_picked = False for i in range(30): new_metadata = ( - (RpcTypeEmptyCall, _TEST_METADATA_KEY, - _TEST_METADATA_VALUE_EMPTY), + ( + RpcTypeEmptyCall, + _TEST_METADATA_KEY, + _TEST_METADATA_VALUE_EMPTY, + ), (RpcTypeUnaryCall, _TEST_METADATA_KEY, str(i)), ) rpc_distribution = self.configure_and_send( test_client, rpc_types=[RpcTypeEmptyCall, RpcTypeUnaryCall], metadata=new_metadata, - num_rpcs=_NUM_RPCS) - unary_call_peer = list(rpc_distribution.raw['rpcsByMethod'] - ['UnaryCall']['rpcsByPeer'].keys())[0] + num_rpcs=_NUM_RPCS, + ) + unary_call_peer = list( + rpc_distribution.raw["rpcsByMethod"]["UnaryCall"][ + "rpcsByPeer" + ].keys() + )[0] if unary_call_peer != empty_call_peer: different_peer_picked = True break self.assertTrue( different_peer_picked, - ("the same endpoint was picked for all the headers, expect a " - "different endpoint to be picked")) + ( + "the same endpoint was picked for all the headers, expect a " + "different endpoint to be picked" + ), + ) self.assertLen( test_client.find_subchannels_with_state( - _ChannelzChannelState.READY), + _ChannelzChannelState.READY + ), 2, ) self.assertLen( @@ -234,5 +270,5 @@ def rpc_distribution_validate(self, test_client: XdsTestClient): # 1. 
based on the basic test, turn down the backend in use, then verify that all # RPCs are sent to another backend -if __name__ == '__main__': +if __name__ == "__main__": absltest.main() diff --git a/tools/run_tests/xds_k8s_test_driver/tests/url_map/csds_test.py b/tools/run_tests/xds_k8s_test_driver/tests/url_map/csds_test.py index 63f60b3c39858..98bfac57c3b56 100644 --- a/tools/run_tests/xds_k8s_test_driver/tests/url_map/csds_test.py +++ b/tools/run_tests/xds_k8s_test_driver/tests/url_map/csds_test.py @@ -38,41 +38,42 @@ class TestBasicCsds(xds_url_map_testcase.XdsUrlMapTestCase): - @staticmethod def is_supported(config: skips.TestConfig) -> bool: if config.client_lang == _Lang.NODE: - return config.version_gte('v1.5.x') + return config.version_gte("v1.5.x") return True @staticmethod def url_map_change( - host_rule: HostRule, - path_matcher: PathMatcher) -> Tuple[HostRule, PathMatcher]: + host_rule: HostRule, path_matcher: PathMatcher + ) -> Tuple[HostRule, PathMatcher]: return host_rule, path_matcher def xds_config_validate(self, xds_config: DumpedXdsConfig): # Validate Endpoint Configs self.assertNumEndpoints(xds_config, 1) # Validate Node - self.assertEqual(self.test_client.ip, - xds_config['node']['metadata']['INSTANCE_IP']) + self.assertEqual( + self.test_client.ip, xds_config["node"]["metadata"]["INSTANCE_IP"] + ) # Validate Listeners self.assertIsNotNone(xds_config.lds) - self.assertEqual(self.hostname(), xds_config.lds['name']) + self.assertEqual(self.hostname(), xds_config.lds["name"]) # Validate Route Configs - self.assertTrue(xds_config.rds['virtualHosts']) + self.assertTrue(xds_config.rds["virtualHosts"]) # Validate Clusters self.assertEqual(1, len(xds_config.cds)) - self.assertEqual('EDS', xds_config.cds[0]['type']) + self.assertEqual("EDS", xds_config.cds[0]["type"]) def rpc_distribution_validate(self, test_client: XdsTestClient): rpc_distribution = self.configure_and_send( test_client, rpc_types=[RpcTypeUnaryCall, RpcTypeEmptyCall], - num_rpcs=_NUM_RPCS) + num_rpcs=_NUM_RPCS, + ) self.assertEqual(_NUM_RPCS, rpc_distribution.num_oks) -if __name__ == '__main__': +if __name__ == "__main__": absltest.main() diff --git a/tools/run_tests/xds_k8s_test_driver/tests/url_map/fault_injection_test.py b/tools/run_tests/xds_k8s_test_driver/tests/url_map/fault_injection_test.py index 4a43332efb2f6..11d0a36476703 100644 --- a/tools/run_tests/xds_k8s_test_driver/tests/url_map/fault_injection_test.py +++ b/tools/run_tests/xds_k8s_test_driver/tests/url_map/fault_injection_test.py @@ -55,34 +55,34 @@ _BACKLOG_WAIT_TIME_SEC = 20 -def _build_fault_injection_route_rule(abort_percentage: int = 0, - delay_percentage: int = 0): +def _build_fault_injection_route_rule( + abort_percentage: int = 0, delay_percentage: int = 0 +): return { - 'priority': 0, - 'matchRules': [{ - 'fullPathMatch': '/grpc.testing.TestService/UnaryCall' - }], - 'service': GcpResourceManager().default_backend_service(), - 'routeAction': { - 'faultInjectionPolicy': { - 'abort': { - 'httpStatus': 401, - 'percentage': abort_percentage, + "priority": 0, + "matchRules": [ + {"fullPathMatch": "/grpc.testing.TestService/UnaryCall"} + ], + "service": GcpResourceManager().default_backend_service(), + "routeAction": { + "faultInjectionPolicy": { + "abort": { + "httpStatus": 401, + "percentage": abort_percentage, + }, + "delay": { + "fixedDelay": {"seconds": "20"}, + "percentage": delay_percentage, }, - 'delay': { - 'fixedDelay': { - 'seconds': '20' - }, - 'percentage': delay_percentage, - } } }, } -def _wait_until_backlog_cleared(test_client: 
XdsTestClient, - timeout: int = _BACKLOG_WAIT_TIME_SEC): - """ Wait until the completed RPC is close to started RPC. +def _wait_until_backlog_cleared( + test_client: XdsTestClient, timeout: int = _BACKLOG_WAIT_TIME_SEC +): + """Wait until the completed RPC is close to started RPC. For delay injected test cases, there might be a backlog of RPCs due to slow initialization of the client. E.g., if initialization took 20s and qps is @@ -90,7 +90,7 @@ def _wait_until_backlog_cleared(test_client: XdsTestClient, fine, because RPCs will fail immediately. But for delay injected test cases, the RPC might linger much longer and affect the stability of test results. """ - logger.info('Waiting for RPC backlog to clear for %d seconds', timeout) + logger.info("Waiting for RPC backlog to clear for %d seconds", timeout) deadline = time.time() + timeout while time.time() < deadline: stats = test_client.get_load_balancer_accumulated_stats() @@ -98,74 +98,91 @@ def _wait_until_backlog_cleared(test_client: XdsTestClient, for rpc_type in [RpcTypeUnaryCall, RpcTypeEmptyCall]: started = stats.num_rpcs_started_by_method.get(rpc_type, 0) completed = stats.num_rpcs_succeeded_by_method.get( - rpc_type, 0) + stats.num_rpcs_failed_by_method.get(rpc_type, 0) + rpc_type, 0 + ) + stats.num_rpcs_failed_by_method.get(rpc_type, 0) # We consider the backlog is healthy, if the diff between started # RPCs and completed RPCs is less than 1.5 QPS. if abs(started - completed) > xds_url_map_testcase.QPS.value * 1.1: logger.info( - 'RPC backlog exist: rpc_type=%s started=%s completed=%s', - rpc_type, started, completed) + "RPC backlog exist: rpc_type=%s started=%s completed=%s", + rpc_type, + started, + completed, + ) time.sleep(_DELAY_CASE_APPLICATION_TIMEOUT_SEC) ok = False else: logger.info( - 'RPC backlog clear: rpc_type=%s started=%s completed=%s', - rpc_type, started, completed) + "RPC backlog clear: rpc_type=%s started=%s completed=%s", + rpc_type, + started, + completed, + ) if ok: # Both backlog of both types of RPCs is clear, success, return. 
return - raise RuntimeError('failed to clear RPC backlog in %s seconds' % timeout) + raise RuntimeError("failed to clear RPC backlog in %s seconds" % timeout) def _is_supported(config: skips.TestConfig) -> bool: if config.client_lang == _Lang.NODE: - return config.version_gte('v1.4.x') + return config.version_gte("v1.4.x") return True class TestZeroPercentFaultInjection(xds_url_map_testcase.XdsUrlMapTestCase): - @staticmethod def is_supported(config: skips.TestConfig) -> bool: return _is_supported(config) @staticmethod def url_map_change( - host_rule: HostRule, - path_matcher: PathMatcher) -> Tuple[HostRule, PathMatcher]: + host_rule: HostRule, path_matcher: PathMatcher + ) -> Tuple[HostRule, PathMatcher]: path_matcher["routeRules"] = [ - _build_fault_injection_route_rule(abort_percentage=0, - delay_percentage=0) + _build_fault_injection_route_rule( + abort_percentage=0, delay_percentage=0 + ) ] return host_rule, path_matcher def xds_config_validate(self, xds_config: DumpedXdsConfig): self.assertNumEndpoints(xds_config, 1) - filter_config = xds_config.rds['virtualHosts'][0]['routes'][0][ - 'typedPerFilterConfig']['envoy.filters.http.fault'] - self.assertEqual('20s', filter_config['delay']['fixedDelay']) + filter_config = xds_config.rds["virtualHosts"][0]["routes"][0][ + "typedPerFilterConfig" + ]["envoy.filters.http.fault"] + self.assertEqual("20s", filter_config["delay"]["fixedDelay"]) + self.assertEqual( + 0, filter_config["delay"]["percentage"].get("numerator", 0) + ) + self.assertEqual( + "MILLION", filter_config["delay"]["percentage"]["denominator"] + ) + self.assertEqual(401, filter_config["abort"]["httpStatus"]) self.assertEqual( - 0, filter_config['delay']['percentage'].get('numerator', 0)) - self.assertEqual('MILLION', - filter_config['delay']['percentage']['denominator']) - self.assertEqual(401, filter_config['abort']['httpStatus']) + 0, filter_config["abort"]["percentage"].get("numerator", 0) + ) self.assertEqual( - 0, filter_config['abort']['percentage'].get('numerator', 0)) - self.assertEqual('MILLION', - filter_config['abort']['percentage']['denominator']) + "MILLION", filter_config["abort"]["percentage"]["denominator"] + ) def rpc_distribution_validate(self, test_client: XdsTestClient): - self.configure_and_send(test_client, - rpc_types=(RpcTypeUnaryCall,), - num_rpcs=_NUM_RPCS) - self.assertRpcStatusCode(test_client, - expected=(ExpectedResult( - rpc_type=RpcTypeUnaryCall, - status_code=grpc.StatusCode.OK, - ratio=1),), - length=_LENGTH_OF_RPC_SENDING_SEC, - tolerance=_NON_RANDOM_ERROR_TOLERANCE) + self.configure_and_send( + test_client, rpc_types=(RpcTypeUnaryCall,), num_rpcs=_NUM_RPCS + ) + self.assertRpcStatusCode( + test_client, + expected=( + ExpectedResult( + rpc_type=RpcTypeUnaryCall, + status_code=grpc.StatusCode.OK, + ratio=1, + ), + ), + length=_LENGTH_OF_RPC_SENDING_SEC, + tolerance=_NON_RANDOM_ERROR_TOLERANCE, + ) class TestNonMatchingFaultInjection(xds_url_map_testcase.XdsUrlMapTestCase): @@ -181,15 +198,16 @@ def client_init_config(rpc: str, metadata: str): # 20s injected). The purpose of this test is examining the un-injected # traffic is not impacted, so it's fine to just send un-injected # traffic. 
- return 'EmptyCall', metadata + return "EmptyCall", metadata @staticmethod def url_map_change( - host_rule: HostRule, - path_matcher: PathMatcher) -> Tuple[HostRule, PathMatcher]: + host_rule: HostRule, path_matcher: PathMatcher + ) -> Tuple[HostRule, PathMatcher]: path_matcher["routeRules"] = [ - _build_fault_injection_route_rule(abort_percentage=100, - delay_percentage=100) + _build_fault_injection_route_rule( + abort_percentage=100, delay_percentage=100 + ) ] return host_rule, path_matcher @@ -198,199 +216,244 @@ def xds_config_validate(self, xds_config: DumpedXdsConfig): # The first route rule for UNARY_CALL is fault injected self.assertEqual( "/grpc.testing.TestService/UnaryCall", - xds_config.rds['virtualHosts'][0]['routes'][0]['match']['path']) - filter_config = xds_config.rds['virtualHosts'][0]['routes'][0][ - 'typedPerFilterConfig']['envoy.filters.http.fault'] - self.assertEqual('20s', filter_config['delay']['fixedDelay']) - self.assertEqual(1000000, - filter_config['delay']['percentage']['numerator']) - self.assertEqual('MILLION', - filter_config['delay']['percentage']['denominator']) - self.assertEqual(401, filter_config['abort']['httpStatus']) - self.assertEqual(1000000, - filter_config['abort']['percentage']['numerator']) - self.assertEqual('MILLION', - filter_config['abort']['percentage']['denominator']) + xds_config.rds["virtualHosts"][0]["routes"][0]["match"]["path"], + ) + filter_config = xds_config.rds["virtualHosts"][0]["routes"][0][ + "typedPerFilterConfig" + ]["envoy.filters.http.fault"] + self.assertEqual("20s", filter_config["delay"]["fixedDelay"]) + self.assertEqual( + 1000000, filter_config["delay"]["percentage"]["numerator"] + ) + self.assertEqual( + "MILLION", filter_config["delay"]["percentage"]["denominator"] + ) + self.assertEqual(401, filter_config["abort"]["httpStatus"]) + self.assertEqual( + 1000000, filter_config["abort"]["percentage"]["numerator"] + ) + self.assertEqual( + "MILLION", filter_config["abort"]["percentage"]["denominator"] + ) # The second route rule for all other RPCs is untouched self.assertNotIn( - 'envoy.filters.http.fault', - xds_config.rds['virtualHosts'][0]['routes'][1].get( - 'typedPerFilterConfig', {})) + "envoy.filters.http.fault", + xds_config.rds["virtualHosts"][0]["routes"][1].get( + "typedPerFilterConfig", {} + ), + ) def rpc_distribution_validate(self, test_client: XdsTestClient): - self.assertRpcStatusCode(test_client, - expected=(ExpectedResult( - rpc_type=RpcTypeEmptyCall, - status_code=grpc.StatusCode.OK, - ratio=1),), - length=_LENGTH_OF_RPC_SENDING_SEC, - tolerance=_NON_RANDOM_ERROR_TOLERANCE) + self.assertRpcStatusCode( + test_client, + expected=( + ExpectedResult( + rpc_type=RpcTypeEmptyCall, + status_code=grpc.StatusCode.OK, + ratio=1, + ), + ), + length=_LENGTH_OF_RPC_SENDING_SEC, + tolerance=_NON_RANDOM_ERROR_TOLERANCE, + ) -@absltest.skip('20% RPC might pass immediately, reason unknown') +@absltest.skip("20% RPC might pass immediately, reason unknown") class TestAlwaysDelay(xds_url_map_testcase.XdsUrlMapTestCase): - @staticmethod def is_supported(config: skips.TestConfig) -> bool: return _is_supported(config) @staticmethod def url_map_change( - host_rule: HostRule, - path_matcher: PathMatcher) -> Tuple[HostRule, PathMatcher]: + host_rule: HostRule, path_matcher: PathMatcher + ) -> Tuple[HostRule, PathMatcher]: path_matcher["routeRules"] = [ - _build_fault_injection_route_rule(abort_percentage=0, - delay_percentage=100) + _build_fault_injection_route_rule( + abort_percentage=0, delay_percentage=100 + ) ] return 
host_rule, path_matcher def xds_config_validate(self, xds_config: DumpedXdsConfig): self.assertNumEndpoints(xds_config, 1) - filter_config = xds_config.rds['virtualHosts'][0]['routes'][0][ - 'typedPerFilterConfig']['envoy.filters.http.fault'] - self.assertEqual('20s', filter_config['delay']['fixedDelay']) - self.assertEqual(1000000, - filter_config['delay']['percentage']['numerator']) - self.assertEqual('MILLION', - filter_config['delay']['percentage']['denominator']) + filter_config = xds_config.rds["virtualHosts"][0]["routes"][0][ + "typedPerFilterConfig" + ]["envoy.filters.http.fault"] + self.assertEqual("20s", filter_config["delay"]["fixedDelay"]) + self.assertEqual( + 1000000, filter_config["delay"]["percentage"]["numerator"] + ) + self.assertEqual( + "MILLION", filter_config["delay"]["percentage"]["denominator"] + ) def rpc_distribution_validate(self, test_client: XdsTestClient): - self.configure_and_send(test_client, - rpc_types=(RpcTypeUnaryCall,), - num_rpcs=_NUM_RPCS, - app_timeout=_DELAY_CASE_APPLICATION_TIMEOUT_SEC) + self.configure_and_send( + test_client, + rpc_types=(RpcTypeUnaryCall,), + num_rpcs=_NUM_RPCS, + app_timeout=_DELAY_CASE_APPLICATION_TIMEOUT_SEC, + ) _wait_until_backlog_cleared(test_client) self.assertRpcStatusCode( test_client, - expected=(ExpectedResult( - rpc_type=RpcTypeUnaryCall, - status_code=grpc.StatusCode.DEADLINE_EXCEEDED, - ratio=1),), + expected=( + ExpectedResult( + rpc_type=RpcTypeUnaryCall, + status_code=grpc.StatusCode.DEADLINE_EXCEEDED, + ratio=1, + ), + ), length=_LENGTH_OF_RPC_SENDING_SEC, - tolerance=_NON_RANDOM_ERROR_TOLERANCE) + tolerance=_NON_RANDOM_ERROR_TOLERANCE, + ) class TestAlwaysAbort(xds_url_map_testcase.XdsUrlMapTestCase): - @staticmethod def is_supported(config: skips.TestConfig) -> bool: return _is_supported(config) @staticmethod def url_map_change( - host_rule: HostRule, - path_matcher: PathMatcher) -> Tuple[HostRule, PathMatcher]: + host_rule: HostRule, path_matcher: PathMatcher + ) -> Tuple[HostRule, PathMatcher]: path_matcher["routeRules"] = [ - _build_fault_injection_route_rule(abort_percentage=100, - delay_percentage=0) + _build_fault_injection_route_rule( + abort_percentage=100, delay_percentage=0 + ) ] return host_rule, path_matcher def xds_config_validate(self, xds_config: DumpedXdsConfig): self.assertNumEndpoints(xds_config, 1) - filter_config = xds_config.rds['virtualHosts'][0]['routes'][0][ - 'typedPerFilterConfig']['envoy.filters.http.fault'] - self.assertEqual(401, filter_config['abort']['httpStatus']) - self.assertEqual(1000000, - filter_config['abort']['percentage']['numerator']) - self.assertEqual('MILLION', - filter_config['abort']['percentage']['denominator']) + filter_config = xds_config.rds["virtualHosts"][0]["routes"][0][ + "typedPerFilterConfig" + ]["envoy.filters.http.fault"] + self.assertEqual(401, filter_config["abort"]["httpStatus"]) + self.assertEqual( + 1000000, filter_config["abort"]["percentage"]["numerator"] + ) + self.assertEqual( + "MILLION", filter_config["abort"]["percentage"]["denominator"] + ) def rpc_distribution_validate(self, test_client: XdsTestClient): - self.configure_and_send(test_client, - rpc_types=(RpcTypeUnaryCall,), - num_rpcs=_NUM_RPCS) + self.configure_and_send( + test_client, rpc_types=(RpcTypeUnaryCall,), num_rpcs=_NUM_RPCS + ) self.assertRpcStatusCode( test_client, - expected=(ExpectedResult( - rpc_type=RpcTypeUnaryCall, - status_code=grpc.StatusCode.UNAUTHENTICATED, - ratio=1),), + expected=( + ExpectedResult( + rpc_type=RpcTypeUnaryCall, + 
status_code=grpc.StatusCode.UNAUTHENTICATED, + ratio=1, + ), + ), length=_LENGTH_OF_RPC_SENDING_SEC, - tolerance=_NON_RANDOM_ERROR_TOLERANCE) + tolerance=_NON_RANDOM_ERROR_TOLERANCE, + ) class TestDelayHalf(xds_url_map_testcase.XdsUrlMapTestCase): - @staticmethod def is_supported(config: skips.TestConfig) -> bool: return _is_supported(config) @staticmethod def url_map_change( - host_rule: HostRule, - path_matcher: PathMatcher) -> Tuple[HostRule, PathMatcher]: + host_rule: HostRule, path_matcher: PathMatcher + ) -> Tuple[HostRule, PathMatcher]: path_matcher["routeRules"] = [ - _build_fault_injection_route_rule(abort_percentage=0, - delay_percentage=50) + _build_fault_injection_route_rule( + abort_percentage=0, delay_percentage=50 + ) ] return host_rule, path_matcher def xds_config_validate(self, xds_config: DumpedXdsConfig): self.assertNumEndpoints(xds_config, 1) - filter_config = xds_config.rds['virtualHosts'][0]['routes'][0][ - 'typedPerFilterConfig']['envoy.filters.http.fault'] - self.assertEqual('20s', filter_config['delay']['fixedDelay']) - self.assertEqual(500000, - filter_config['delay']['percentage']['numerator']) - self.assertEqual('MILLION', - filter_config['delay']['percentage']['denominator']) + filter_config = xds_config.rds["virtualHosts"][0]["routes"][0][ + "typedPerFilterConfig" + ]["envoy.filters.http.fault"] + self.assertEqual("20s", filter_config["delay"]["fixedDelay"]) + self.assertEqual( + 500000, filter_config["delay"]["percentage"]["numerator"] + ) + self.assertEqual( + "MILLION", filter_config["delay"]["percentage"]["denominator"] + ) def rpc_distribution_validate(self, test_client: XdsTestClient): - self.configure_and_send(test_client, - rpc_types=(RpcTypeUnaryCall,), - num_rpcs=_NUM_RPCS, - app_timeout=_DELAY_CASE_APPLICATION_TIMEOUT_SEC) + self.configure_and_send( + test_client, + rpc_types=(RpcTypeUnaryCall,), + num_rpcs=_NUM_RPCS, + app_timeout=_DELAY_CASE_APPLICATION_TIMEOUT_SEC, + ) _wait_until_backlog_cleared(test_client) self.assertRpcStatusCode( test_client, - expected=(ExpectedResult( - rpc_type=RpcTypeUnaryCall, - status_code=grpc.StatusCode.DEADLINE_EXCEEDED, - ratio=0.5),), + expected=( + ExpectedResult( + rpc_type=RpcTypeUnaryCall, + status_code=grpc.StatusCode.DEADLINE_EXCEEDED, + ratio=0.5, + ), + ), length=_LENGTH_OF_RPC_SENDING_SEC, - tolerance=_ERROR_TOLERANCE) + tolerance=_ERROR_TOLERANCE, + ) class TestAbortHalf(xds_url_map_testcase.XdsUrlMapTestCase): - @staticmethod def is_supported(config: skips.TestConfig) -> bool: return _is_supported(config) @staticmethod def url_map_change( - host_rule: HostRule, - path_matcher: PathMatcher) -> Tuple[HostRule, PathMatcher]: + host_rule: HostRule, path_matcher: PathMatcher + ) -> Tuple[HostRule, PathMatcher]: path_matcher["routeRules"] = [ - _build_fault_injection_route_rule(abort_percentage=50, - delay_percentage=0) + _build_fault_injection_route_rule( + abort_percentage=50, delay_percentage=0 + ) ] return host_rule, path_matcher def xds_config_validate(self, xds_config: DumpedXdsConfig): self.assertNumEndpoints(xds_config, 1) - filter_config = xds_config.rds['virtualHosts'][0]['routes'][0][ - 'typedPerFilterConfig']['envoy.filters.http.fault'] - self.assertEqual(401, filter_config['abort']['httpStatus']) - self.assertEqual(500000, - filter_config['abort']['percentage']['numerator']) - self.assertEqual('MILLION', - filter_config['abort']['percentage']['denominator']) + filter_config = xds_config.rds["virtualHosts"][0]["routes"][0][ + "typedPerFilterConfig" + ]["envoy.filters.http.fault"] + 
self.assertEqual(401, filter_config["abort"]["httpStatus"]) + self.assertEqual( + 500000, filter_config["abort"]["percentage"]["numerator"] + ) + self.assertEqual( + "MILLION", filter_config["abort"]["percentage"]["denominator"] + ) def rpc_distribution_validate(self, test_client: XdsTestClient): - self.configure_and_send(test_client, - rpc_types=(RpcTypeUnaryCall,), - num_rpcs=_NUM_RPCS) + self.configure_and_send( + test_client, rpc_types=(RpcTypeUnaryCall,), num_rpcs=_NUM_RPCS + ) self.assertRpcStatusCode( test_client, - expected=(ExpectedResult( - rpc_type=RpcTypeUnaryCall, - status_code=grpc.StatusCode.UNAUTHENTICATED, - ratio=0.5),), + expected=( + ExpectedResult( + rpc_type=RpcTypeUnaryCall, + status_code=grpc.StatusCode.UNAUTHENTICATED, + ratio=0.5, + ), + ), length=_LENGTH_OF_RPC_SENDING_SEC, - tolerance=_ERROR_TOLERANCE) + tolerance=_ERROR_TOLERANCE, + ) -if __name__ == '__main__': +if __name__ == "__main__": absltest.main() diff --git a/tools/run_tests/xds_k8s_test_driver/tests/url_map/header_matching_test.py b/tools/run_tests/xds_k8s_test_driver/tests/url_map/header_matching_test.py index dada10667508c..1f559d735e1ef 100644 --- a/tools/run_tests/xds_k8s_test_driver/tests/url_map/header_matching_test.py +++ b/tools/run_tests/xds_k8s_test_driver/tests/url_map/header_matching_test.py @@ -35,361 +35,454 @@ flags.adopt_module_key_flags(xds_url_map_testcase) _NUM_RPCS = 150 -_TEST_METADATA_KEY = 'xds_md' -_TEST_METADATA_VALUE_UNARY = 'unary_yranu' -_TEST_METADATA_VALUE_EMPTY = 'empty_ytpme' -_TEST_METADATA_NUMERIC_KEY = 'xds_md_numeric' -_TEST_METADATA_NUMERIC_VALUE = '159' +_TEST_METADATA_KEY = "xds_md" +_TEST_METADATA_VALUE_UNARY = "unary_yranu" +_TEST_METADATA_VALUE_EMPTY = "empty_ytpme" +_TEST_METADATA_NUMERIC_KEY = "xds_md_numeric" +_TEST_METADATA_NUMERIC_VALUE = "159" _TEST_METADATA = ( (RpcTypeUnaryCall, _TEST_METADATA_KEY, _TEST_METADATA_VALUE_UNARY), (RpcTypeEmptyCall, _TEST_METADATA_KEY, _TEST_METADATA_VALUE_EMPTY), - (RpcTypeUnaryCall, _TEST_METADATA_NUMERIC_KEY, - _TEST_METADATA_NUMERIC_VALUE), + ( + RpcTypeUnaryCall, + _TEST_METADATA_NUMERIC_KEY, + _TEST_METADATA_NUMERIC_VALUE, + ), ) def _is_supported(config: skips.TestConfig) -> bool: if config.client_lang == _Lang.NODE: - return config.version_gte('v1.3.x') + return config.version_gte("v1.3.x") return True class TestExactMatch(xds_url_map_testcase.XdsUrlMapTestCase): - @staticmethod def is_supported(config: skips.TestConfig) -> bool: return _is_supported(config) @staticmethod def url_map_change( - host_rule: HostRule, - path_matcher: PathMatcher) -> Tuple[HostRule, PathMatcher]: - path_matcher["routeRules"] = [{ - 'priority': 0, - # Header ExactMatch -> alternate_backend_service. - # EmptyCall is sent with the metadata. - 'matchRules': [{ - 'prefixMatch': - '/', - 'headerMatches': [{ - 'headerName': _TEST_METADATA_KEY, - 'exactMatch': _TEST_METADATA_VALUE_EMPTY - }] - }], - 'service': GcpResourceManager().alternative_backend_service() - }] + host_rule: HostRule, path_matcher: PathMatcher + ) -> Tuple[HostRule, PathMatcher]: + path_matcher["routeRules"] = [ + { + "priority": 0, + # Header ExactMatch -> alternate_backend_service. + # EmptyCall is sent with the metadata. 
+ "matchRules": [ + { + "prefixMatch": "/", + "headerMatches": [ + { + "headerName": _TEST_METADATA_KEY, + "exactMatch": _TEST_METADATA_VALUE_EMPTY, + } + ], + } + ], + "service": GcpResourceManager().alternative_backend_service(), + } + ] return host_rule, path_matcher def xds_config_validate(self, xds_config: DumpedXdsConfig): self.assertNumEndpoints(xds_config, 2) self.assertEqual( - xds_config.rds['virtualHosts'][0]['routes'][0]['match']['headers'] - [0]['name'], _TEST_METADATA_KEY) + xds_config.rds["virtualHosts"][0]["routes"][0]["match"]["headers"][ + 0 + ]["name"], + _TEST_METADATA_KEY, + ) self.assertEqual( - xds_config.rds['virtualHosts'][0]['routes'][0]['match']['headers'] - [0]['exactMatch'], _TEST_METADATA_VALUE_EMPTY) + xds_config.rds["virtualHosts"][0]["routes"][0]["match"]["headers"][ + 0 + ]["exactMatch"], + _TEST_METADATA_VALUE_EMPTY, + ) def rpc_distribution_validate(self, test_client: XdsTestClient): - rpc_distribution = self.configure_and_send(test_client, - rpc_types=[RpcTypeEmptyCall], - metadata=_TEST_METADATA, - num_rpcs=_NUM_RPCS) + rpc_distribution = self.configure_and_send( + test_client, + rpc_types=[RpcTypeEmptyCall], + metadata=_TEST_METADATA, + num_rpcs=_NUM_RPCS, + ) self.assertEqual( - _NUM_RPCS, - rpc_distribution.empty_call_alternative_service_rpc_count) + _NUM_RPCS, rpc_distribution.empty_call_alternative_service_rpc_count + ) -@absltest.skip('the xDS config is good, but distribution is wrong.') +@absltest.skip("the xDS config is good, but distribution is wrong.") class TestPrefixMatch(xds_url_map_testcase.XdsUrlMapTestCase): - @staticmethod def is_supported(config: skips.TestConfig) -> bool: return _is_supported(config) @staticmethod def url_map_change( - host_rule: HostRule, - path_matcher: PathMatcher) -> Tuple[HostRule, PathMatcher]: - path_matcher["routeRules"] = [{ - 'priority': 0, - # Header PrefixMatch -> alternate_backend_service. - # UnaryCall is sent with the metadata. - 'matchRules': [{ - 'prefixMatch': - '/', - 'headerMatches': [{ - 'headerName': _TEST_METADATA_KEY, - 'prefixMatch': _TEST_METADATA_VALUE_UNARY[:2] - }] - }], - 'service': GcpResourceManager().alternative_backend_service() - }] + host_rule: HostRule, path_matcher: PathMatcher + ) -> Tuple[HostRule, PathMatcher]: + path_matcher["routeRules"] = [ + { + "priority": 0, + # Header PrefixMatch -> alternate_backend_service. + # UnaryCall is sent with the metadata. 
+ "matchRules": [ + { + "prefixMatch": "/", + "headerMatches": [ + { + "headerName": _TEST_METADATA_KEY, + "prefixMatch": _TEST_METADATA_VALUE_UNARY[:2], + } + ], + } + ], + "service": GcpResourceManager().alternative_backend_service(), + } + ] return host_rule, path_matcher def xds_config_validate(self, xds_config: DumpedXdsConfig): self.assertNumEndpoints(xds_config, 2) self.assertEqual( - xds_config.rds['virtualHosts'][0]['routes'][0]['match']['headers'] - [0]['name'], _TEST_METADATA_KEY) + xds_config.rds["virtualHosts"][0]["routes"][0]["match"]["headers"][ + 0 + ]["name"], + _TEST_METADATA_KEY, + ) self.assertEqual( - xds_config.rds['virtualHosts'][0]['routes'][0]['match']['headers'] - [0]['prefixMatch'], _TEST_METADATA_VALUE_UNARY[:2]) + xds_config.rds["virtualHosts"][0]["routes"][0]["match"]["headers"][ + 0 + ]["prefixMatch"], + _TEST_METADATA_VALUE_UNARY[:2], + ) def rpc_distribution_validate(self, test_client: XdsTestClient): rpc_distribution = self.configure_and_send( test_client, rpc_types=(RpcTypeUnaryCall,), metadata=_TEST_METADATA, - num_rpcs=_NUM_RPCS) + num_rpcs=_NUM_RPCS, + ) self.assertEqual( - _NUM_RPCS, - rpc_distribution.unary_call_alternative_service_rpc_count) + _NUM_RPCS, rpc_distribution.unary_call_alternative_service_rpc_count + ) class TestSuffixMatch(xds_url_map_testcase.XdsUrlMapTestCase): - @staticmethod def is_supported(config: skips.TestConfig) -> bool: return _is_supported(config) @staticmethod def url_map_change( - host_rule: HostRule, - path_matcher: PathMatcher) -> Tuple[HostRule, PathMatcher]: - path_matcher["routeRules"] = [{ - 'priority': 0, - # Header SuffixMatch -> alternate_backend_service. - # EmptyCall is sent with the metadata. - 'matchRules': [{ - 'prefixMatch': - '/', - 'headerMatches': [{ - 'headerName': _TEST_METADATA_KEY, - 'suffixMatch': _TEST_METADATA_VALUE_EMPTY[-2:] - }] - }], - 'service': GcpResourceManager().alternative_backend_service() - }] + host_rule: HostRule, path_matcher: PathMatcher + ) -> Tuple[HostRule, PathMatcher]: + path_matcher["routeRules"] = [ + { + "priority": 0, + # Header SuffixMatch -> alternate_backend_service. + # EmptyCall is sent with the metadata. 
+ "matchRules": [ + { + "prefixMatch": "/", + "headerMatches": [ + { + "headerName": _TEST_METADATA_KEY, + "suffixMatch": _TEST_METADATA_VALUE_EMPTY[-2:], + } + ], + } + ], + "service": GcpResourceManager().alternative_backend_service(), + } + ] return host_rule, path_matcher def xds_config_validate(self, xds_config: DumpedXdsConfig): self.assertNumEndpoints(xds_config, 2) self.assertEqual( - xds_config.rds['virtualHosts'][0]['routes'][0]['match']['headers'] - [0]['name'], _TEST_METADATA_KEY) + xds_config.rds["virtualHosts"][0]["routes"][0]["match"]["headers"][ + 0 + ]["name"], + _TEST_METADATA_KEY, + ) self.assertEqual( - xds_config.rds['virtualHosts'][0]['routes'][0]['match']['headers'] - [0]['suffixMatch'], _TEST_METADATA_VALUE_EMPTY[-2:]) + xds_config.rds["virtualHosts"][0]["routes"][0]["match"]["headers"][ + 0 + ]["suffixMatch"], + _TEST_METADATA_VALUE_EMPTY[-2:], + ) def rpc_distribution_validate(self, test_client: XdsTestClient): - rpc_distribution = self.configure_and_send(test_client, - rpc_types=[RpcTypeEmptyCall], - metadata=_TEST_METADATA, - num_rpcs=_NUM_RPCS) + rpc_distribution = self.configure_and_send( + test_client, + rpc_types=[RpcTypeEmptyCall], + metadata=_TEST_METADATA, + num_rpcs=_NUM_RPCS, + ) self.assertEqual( - _NUM_RPCS, - rpc_distribution.empty_call_alternative_service_rpc_count) + _NUM_RPCS, rpc_distribution.empty_call_alternative_service_rpc_count + ) class TestPresentMatch(xds_url_map_testcase.XdsUrlMapTestCase): - @staticmethod def is_supported(config: skips.TestConfig) -> bool: return _is_supported(config) @staticmethod def url_map_change( - host_rule: HostRule, - path_matcher: PathMatcher) -> Tuple[HostRule, PathMatcher]: - path_matcher["routeRules"] = [{ - 'priority': 0, - # Header 'xds_md_numeric' present -> alternate_backend_service. - # UnaryCall is sent with the metadata, so will be sent to alternative. - 'matchRules': [{ - 'prefixMatch': - '/', - 'headerMatches': [{ - 'headerName': _TEST_METADATA_NUMERIC_KEY, - 'presentMatch': True - }] - }], - 'service': GcpResourceManager().alternative_backend_service() - }] + host_rule: HostRule, path_matcher: PathMatcher + ) -> Tuple[HostRule, PathMatcher]: + path_matcher["routeRules"] = [ + { + "priority": 0, + # Header 'xds_md_numeric' present -> alternate_backend_service. + # UnaryCall is sent with the metadata, so will be sent to alternative. 
+ "matchRules": [ + { + "prefixMatch": "/", + "headerMatches": [ + { + "headerName": _TEST_METADATA_NUMERIC_KEY, + "presentMatch": True, + } + ], + } + ], + "service": GcpResourceManager().alternative_backend_service(), + } + ] return host_rule, path_matcher def xds_config_validate(self, xds_config: DumpedXdsConfig): self.assertNumEndpoints(xds_config, 2) self.assertEqual( - xds_config.rds['virtualHosts'][0]['routes'][0]['match']['headers'] - [0]['name'], _TEST_METADATA_NUMERIC_KEY) + xds_config.rds["virtualHosts"][0]["routes"][0]["match"]["headers"][ + 0 + ]["name"], + _TEST_METADATA_NUMERIC_KEY, + ) self.assertEqual( - xds_config.rds['virtualHosts'][0]['routes'][0]['match']['headers'] - [0]['presentMatch'], True) + xds_config.rds["virtualHosts"][0]["routes"][0]["match"]["headers"][ + 0 + ]["presentMatch"], + True, + ) def rpc_distribution_validate(self, test_client: XdsTestClient): rpc_distribution = self.configure_and_send( test_client, rpc_types=(RpcTypeUnaryCall,), metadata=_TEST_METADATA, - num_rpcs=_NUM_RPCS) + num_rpcs=_NUM_RPCS, + ) self.assertEqual( - _NUM_RPCS, - rpc_distribution.unary_call_alternative_service_rpc_count) + _NUM_RPCS, rpc_distribution.unary_call_alternative_service_rpc_count + ) class TestInvertMatch(xds_url_map_testcase.XdsUrlMapTestCase): - @staticmethod def is_supported(config: skips.TestConfig) -> bool: return _is_supported(config) @staticmethod def url_map_change( - host_rule: HostRule, - path_matcher: PathMatcher) -> Tuple[HostRule, PathMatcher]: - path_matcher["routeRules"] = [{ - 'priority': 0, - # Header invert ExactMatch -> alternate_backend_service. - # UnaryCall is sent with the metadata, so will be sent to - # default. EmptyCall will be sent to alternative. - 'matchRules': [{ - 'prefixMatch': - '/', - 'headerMatches': [{ - 'headerName': _TEST_METADATA_KEY, - 'exactMatch': _TEST_METADATA_VALUE_UNARY, - 'invertMatch': True - }] - }], - 'service': GcpResourceManager().alternative_backend_service() - }] + host_rule: HostRule, path_matcher: PathMatcher + ) -> Tuple[HostRule, PathMatcher]: + path_matcher["routeRules"] = [ + { + "priority": 0, + # Header invert ExactMatch -> alternate_backend_service. + # UnaryCall is sent with the metadata, so will be sent to + # default. EmptyCall will be sent to alternative. 
+ "matchRules": [ + { + "prefixMatch": "/", + "headerMatches": [ + { + "headerName": _TEST_METADATA_KEY, + "exactMatch": _TEST_METADATA_VALUE_UNARY, + "invertMatch": True, + } + ], + } + ], + "service": GcpResourceManager().alternative_backend_service(), + } + ] return host_rule, path_matcher def xds_config_validate(self, xds_config: DumpedXdsConfig): self.assertNumEndpoints(xds_config, 2) self.assertEqual( - xds_config.rds['virtualHosts'][0]['routes'][0]['match']['headers'] - [0]['name'], _TEST_METADATA_KEY) + xds_config.rds["virtualHosts"][0]["routes"][0]["match"]["headers"][ + 0 + ]["name"], + _TEST_METADATA_KEY, + ) self.assertEqual( - xds_config.rds['virtualHosts'][0]['routes'][0]['match']['headers'] - [0]['invertMatch'], True) + xds_config.rds["virtualHosts"][0]["routes"][0]["match"]["headers"][ + 0 + ]["invertMatch"], + True, + ) def rpc_distribution_validate(self, test_client: XdsTestClient): rpc_distribution = self.configure_and_send( test_client, rpc_types=[RpcTypeUnaryCall, RpcTypeEmptyCall], metadata=_TEST_METADATA, - num_rpcs=_NUM_RPCS) + num_rpcs=_NUM_RPCS, + ) self.assertEqual(_NUM_RPCS, rpc_distribution.num_oks) self.assertEqual( - 0, rpc_distribution.unary_call_alternative_service_rpc_count) - self.assertEqual(0, - rpc_distribution.empty_call_default_service_rpc_count) + 0, rpc_distribution.unary_call_alternative_service_rpc_count + ) + self.assertEqual( + 0, rpc_distribution.empty_call_default_service_rpc_count + ) class TestRangeMatch(xds_url_map_testcase.XdsUrlMapTestCase): - @staticmethod def is_supported(config: skips.TestConfig) -> bool: return _is_supported(config) @staticmethod def url_map_change( - host_rule: HostRule, - path_matcher: PathMatcher) -> Tuple[HostRule, PathMatcher]: - path_matcher["routeRules"] = [{ - 'priority': 0, - # Header 'xds_md_numeric' range [100,200] -> alternate_backend_service. - # UnaryCall is sent with the metadata in range. - 'matchRules': [{ - 'prefixMatch': - '/', - 'headerMatches': [{ - 'headerName': _TEST_METADATA_NUMERIC_KEY, - 'rangeMatch': { - 'rangeStart': '100', - 'rangeEnd': '200' + host_rule: HostRule, path_matcher: PathMatcher + ) -> Tuple[HostRule, PathMatcher]: + path_matcher["routeRules"] = [ + { + "priority": 0, + # Header 'xds_md_numeric' range [100,200] -> alternate_backend_service. + # UnaryCall is sent with the metadata in range. 
+ "matchRules": [ + { + "prefixMatch": "/", + "headerMatches": [ + { + "headerName": _TEST_METADATA_NUMERIC_KEY, + "rangeMatch": { + "rangeStart": "100", + "rangeEnd": "200", + }, + } + ], } - }] - }], - 'service': GcpResourceManager().alternative_backend_service() - }] + ], + "service": GcpResourceManager().alternative_backend_service(), + } + ] return host_rule, path_matcher def xds_config_validate(self, xds_config: DumpedXdsConfig): self.assertNumEndpoints(xds_config, 2) self.assertEqual( - xds_config.rds['virtualHosts'][0]['routes'][0]['match']['headers'] - [0]['name'], _TEST_METADATA_NUMERIC_KEY) + xds_config.rds["virtualHosts"][0]["routes"][0]["match"]["headers"][ + 0 + ]["name"], + _TEST_METADATA_NUMERIC_KEY, + ) self.assertEqual( - xds_config.rds['virtualHosts'][0]['routes'][0]['match']['headers'] - [0]['rangeMatch']['start'], '100') + xds_config.rds["virtualHosts"][0]["routes"][0]["match"]["headers"][ + 0 + ]["rangeMatch"]["start"], + "100", + ) self.assertEqual( - xds_config.rds['virtualHosts'][0]['routes'][0]['match']['headers'] - [0]['rangeMatch']['end'], '200') + xds_config.rds["virtualHosts"][0]["routes"][0]["match"]["headers"][ + 0 + ]["rangeMatch"]["end"], + "200", + ) def rpc_distribution_validate(self, test_client: XdsTestClient): rpc_distribution = self.configure_and_send( test_client, rpc_types=[RpcTypeUnaryCall, RpcTypeEmptyCall], metadata=_TEST_METADATA, - num_rpcs=_NUM_RPCS) + num_rpcs=_NUM_RPCS, + ) self.assertEqual(_NUM_RPCS, rpc_distribution.num_oks) - self.assertEqual(0, - rpc_distribution.unary_call_default_service_rpc_count) self.assertEqual( - 0, rpc_distribution.empty_call_alternative_service_rpc_count) + 0, rpc_distribution.unary_call_default_service_rpc_count + ) + self.assertEqual( + 0, rpc_distribution.empty_call_alternative_service_rpc_count + ) class TestRegexMatch(xds_url_map_testcase.XdsUrlMapTestCase): - @staticmethod def is_supported(config: skips.TestConfig) -> bool: return _is_supported(config) @staticmethod def url_map_change( - host_rule: HostRule, - path_matcher: PathMatcher) -> Tuple[HostRule, PathMatcher]: - path_matcher["routeRules"] = [{ - 'priority': 0, - # Header RegexMatch -> alternate_backend_service. - # EmptyCall is sent with the metadata. - 'matchRules': [{ - 'prefixMatch': - '/', - 'headerMatches': [{ - 'headerName': - _TEST_METADATA_KEY, - 'regexMatch': - "^%s.*%s$" % (_TEST_METADATA_VALUE_EMPTY[:2], - _TEST_METADATA_VALUE_EMPTY[-2:]) - }] - }], - 'service': GcpResourceManager().alternative_backend_service() - }], + host_rule: HostRule, path_matcher: PathMatcher + ) -> Tuple[HostRule, PathMatcher]: + path_matcher["routeRules"] = ( + [ + { + "priority": 0, + # Header RegexMatch -> alternate_backend_service. + # EmptyCall is sent with the metadata. 
+ "matchRules": [ + { + "prefixMatch": "/", + "headerMatches": [ + { + "headerName": _TEST_METADATA_KEY, + "regexMatch": "^%s.*%s$" + % ( + _TEST_METADATA_VALUE_EMPTY[:2], + _TEST_METADATA_VALUE_EMPTY[-2:], + ), + } + ], + } + ], + "service": GcpResourceManager().alternative_backend_service(), + } + ], + ) return host_rule, path_matcher def xds_config_validate(self, xds_config: DumpedXdsConfig): self.assertNumEndpoints(xds_config, 2) self.assertEqual( - xds_config.rds['virtualHosts'][0]['routes'][0]['match']['headers'] - [0]['name'], _TEST_METADATA_KEY) + xds_config.rds["virtualHosts"][0]["routes"][0]["match"]["headers"][ + 0 + ]["name"], + _TEST_METADATA_KEY, + ) self.assertEqual( - xds_config.rds['virtualHosts'][0]['routes'][0]['match']['headers'] - [0]['safeRegexMatch']['regex'], "^%s.*%s$" % - (_TEST_METADATA_VALUE_EMPTY[:2], _TEST_METADATA_VALUE_EMPTY[-2:])) + xds_config.rds["virtualHosts"][0]["routes"][0]["match"]["headers"][ + 0 + ]["safeRegexMatch"]["regex"], + "^%s.*%s$" + % (_TEST_METADATA_VALUE_EMPTY[:2], _TEST_METADATA_VALUE_EMPTY[-2:]), + ) def rpc_distribution_validate(self, test_client: XdsTestClient): - rpc_distribution = self.configure_and_send(test_client, - rpc_types=[RpcTypeEmptyCall], - metadata=_TEST_METADATA, - num_rpcs=_NUM_RPCS) + rpc_distribution = self.configure_and_send( + test_client, + rpc_types=[RpcTypeEmptyCall], + metadata=_TEST_METADATA, + num_rpcs=_NUM_RPCS, + ) self.assertEqual( - _NUM_RPCS, - rpc_distribution.empty_call_alternative_service_rpc_count) + _NUM_RPCS, rpc_distribution.empty_call_alternative_service_rpc_count + ) -if __name__ == '__main__': +if __name__ == "__main__": absltest.main() diff --git a/tools/run_tests/xds_k8s_test_driver/tests/url_map/metadata_filter_test.py b/tools/run_tests/xds_k8s_test_driver/tests/url_map/metadata_filter_test.py index c9e054e09e9ea..eb131d626b036 100644 --- a/tools/run_tests/xds_k8s_test_driver/tests/url_map/metadata_filter_test.py +++ b/tools/run_tests/xds_k8s_test_driver/tests/url_map/metadata_filter_test.py @@ -33,19 +33,19 @@ flags.adopt_module_key_flags(xds_url_map_testcase) _NUM_RPCS = 150 -_TEST_METADATA_KEY = 'xds_md' -_TEST_METADATA_VALUE_EMPTY = 'empty_ytpme' -_TEST_METADATA = ((RpcTypeEmptyCall, _TEST_METADATA_KEY, - _TEST_METADATA_VALUE_EMPTY),) -match_labels = [{ - 'name': 'TRAFFICDIRECTOR_NETWORK_NAME', - 'value': 'default-vpc' -}] -not_match_labels = [{'name': 'fake', 'value': 'fail'}] +_TEST_METADATA_KEY = "xds_md" +_TEST_METADATA_VALUE_EMPTY = "empty_ytpme" +_TEST_METADATA = ( + (RpcTypeEmptyCall, _TEST_METADATA_KEY, _TEST_METADATA_VALUE_EMPTY), +) +match_labels = [ + {"name": "TRAFFICDIRECTOR_NETWORK_NAME", "value": "default-vpc"} +] +not_match_labels = [{"name": "fake", "value": "fail"}] class TestMetadataFilterMatchAll(xds_url_map_testcase.XdsUrlMapTestCase): - """" The test url-map has two routeRules: the higher priority routes to + """ " The test url-map has two routeRules: the higher priority routes to the default backends, but is supposed to be filtered out by TD because of non-matching metadata filters. The lower priority routes to alternative backends and metadata filter matches. 
Thus, it verifies that TD evaluates @@ -53,222 +53,286 @@ class TestMetadataFilterMatchAll(xds_url_map_testcase.XdsUrlMapTestCase): @staticmethod def url_map_change( - host_rule: HostRule, - path_matcher: PathMatcher) -> Tuple[HostRule, PathMatcher]: - path_matcher["routeRules"] = [{ - 'priority': 0, - 'matchRules': [{ - 'prefixMatch': - '/', - 'metadataFilters': [{ - 'filterMatchCriteria': 'MATCH_ALL', - 'filterLabels': not_match_labels - }] - }], - 'service': GcpResourceManager().default_backend_service() - }, { - 'priority': 1, - 'matchRules': [{ - 'prefixMatch': - '/grpc.testing.TestService/Empty', - 'headerMatches': [{ - 'headerName': _TEST_METADATA_KEY, - 'exactMatch': _TEST_METADATA_VALUE_EMPTY - }], - 'metadataFilters': [{ - 'filterMatchCriteria': 'MATCH_ALL', - 'filterLabels': match_labels - }] - }], - 'service': GcpResourceManager().alternative_backend_service() - }] + host_rule: HostRule, path_matcher: PathMatcher + ) -> Tuple[HostRule, PathMatcher]: + path_matcher["routeRules"] = [ + { + "priority": 0, + "matchRules": [ + { + "prefixMatch": "/", + "metadataFilters": [ + { + "filterMatchCriteria": "MATCH_ALL", + "filterLabels": not_match_labels, + } + ], + } + ], + "service": GcpResourceManager().default_backend_service(), + }, + { + "priority": 1, + "matchRules": [ + { + "prefixMatch": "/grpc.testing.TestService/Empty", + "headerMatches": [ + { + "headerName": _TEST_METADATA_KEY, + "exactMatch": _TEST_METADATA_VALUE_EMPTY, + } + ], + "metadataFilters": [ + { + "filterMatchCriteria": "MATCH_ALL", + "filterLabels": match_labels, + } + ], + } + ], + "service": GcpResourceManager().alternative_backend_service(), + }, + ] return host_rule, path_matcher def xds_config_validate(self, xds_config: DumpedXdsConfig): self.assertNumEndpoints(xds_config, 2) - self.assertEqual(len(xds_config.rds['virtualHosts'][0]['routes']), 2) + self.assertEqual(len(xds_config.rds["virtualHosts"][0]["routes"]), 2) self.assertEqual( - xds_config.rds['virtualHosts'][0]['routes'][0]['match']['prefix'], - "/grpc.testing.TestService/Empty") + xds_config.rds["virtualHosts"][0]["routes"][0]["match"]["prefix"], + "/grpc.testing.TestService/Empty", + ) self.assertEqual( - xds_config.rds['virtualHosts'][0]['routes'][0]['match']['headers'] - [0]['name'], _TEST_METADATA_KEY) + xds_config.rds["virtualHosts"][0]["routes"][0]["match"]["headers"][ + 0 + ]["name"], + _TEST_METADATA_KEY, + ) self.assertEqual( - xds_config.rds['virtualHosts'][0]['routes'][0]['match']['headers'] - [0]['exactMatch'], _TEST_METADATA_VALUE_EMPTY) + xds_config.rds["virtualHosts"][0]["routes"][0]["match"]["headers"][ + 0 + ]["exactMatch"], + _TEST_METADATA_VALUE_EMPTY, + ) self.assertEqual( - xds_config.rds['virtualHosts'][0]['routes'][1]['match']['prefix'], - "") + xds_config.rds["virtualHosts"][0]["routes"][1]["match"]["prefix"], + "", + ) def rpc_distribution_validate(self, test_client: XdsTestClient): - rpc_distribution = self.configure_and_send(test_client, - rpc_types=[RpcTypeEmptyCall], - metadata=_TEST_METADATA, - num_rpcs=_NUM_RPCS) + rpc_distribution = self.configure_and_send( + test_client, + rpc_types=[RpcTypeEmptyCall], + metadata=_TEST_METADATA, + num_rpcs=_NUM_RPCS, + ) self.assertEqual( - _NUM_RPCS, - rpc_distribution.empty_call_alternative_service_rpc_count) + _NUM_RPCS, rpc_distribution.empty_call_alternative_service_rpc_count + ) class TestMetadataFilterMatchAny(xds_url_map_testcase.XdsUrlMapTestCase): - @staticmethod def url_map_change( - host_rule: HostRule, - path_matcher: PathMatcher) -> Tuple[HostRule, PathMatcher]: - 
path_matcher["routeRules"] = [{ - 'priority': 0, - 'matchRules': [{ - 'prefixMatch': - '/', - 'metadataFilters': [{ - 'filterMatchCriteria': 'MATCH_ANY', - 'filterLabels': not_match_labels - }] - }], - 'service': GcpResourceManager().default_backend_service() - }, { - 'priority': 1, - 'matchRules': [{ - 'prefixMatch': - '/grpc.testing.TestService/Unary', - 'metadataFilters': [{ - 'filterMatchCriteria': 'MATCH_ANY', - 'filterLabels': not_match_labels + match_labels - }] - }], - 'service': GcpResourceManager().alternative_backend_service() - }] + host_rule: HostRule, path_matcher: PathMatcher + ) -> Tuple[HostRule, PathMatcher]: + path_matcher["routeRules"] = [ + { + "priority": 0, + "matchRules": [ + { + "prefixMatch": "/", + "metadataFilters": [ + { + "filterMatchCriteria": "MATCH_ANY", + "filterLabels": not_match_labels, + } + ], + } + ], + "service": GcpResourceManager().default_backend_service(), + }, + { + "priority": 1, + "matchRules": [ + { + "prefixMatch": "/grpc.testing.TestService/Unary", + "metadataFilters": [ + { + "filterMatchCriteria": "MATCH_ANY", + "filterLabels": not_match_labels + match_labels, + } + ], + } + ], + "service": GcpResourceManager().alternative_backend_service(), + }, + ] return host_rule, path_matcher def xds_config_validate(self, xds_config: DumpedXdsConfig): self.assertNumEndpoints(xds_config, 2) - self.assertEqual(len(xds_config.rds['virtualHosts'][0]['routes']), 2) + self.assertEqual(len(xds_config.rds["virtualHosts"][0]["routes"]), 2) self.assertEqual( - xds_config.rds['virtualHosts'][0]['routes'][0]['match']['prefix'], - "/grpc.testing.TestService/Unary") + xds_config.rds["virtualHosts"][0]["routes"][0]["match"]["prefix"], + "/grpc.testing.TestService/Unary", + ) self.assertEqual( - xds_config.rds['virtualHosts'][0]['routes'][1]['match']['prefix'], - "") + xds_config.rds["virtualHosts"][0]["routes"][1]["match"]["prefix"], + "", + ) def rpc_distribution_validate(self, test_client: XdsTestClient): rpc_distribution = self.configure_and_send( - test_client, rpc_types=(RpcTypeUnaryCall,), num_rpcs=_NUM_RPCS) + test_client, rpc_types=(RpcTypeUnaryCall,), num_rpcs=_NUM_RPCS + ) self.assertEqual( - _NUM_RPCS, - rpc_distribution.unary_call_alternative_service_rpc_count) + _NUM_RPCS, rpc_distribution.unary_call_alternative_service_rpc_count + ) class TestMetadataFilterMatchAnyAndAll(xds_url_map_testcase.XdsUrlMapTestCase): - @staticmethod def url_map_change( - host_rule: HostRule, - path_matcher: PathMatcher) -> Tuple[HostRule, PathMatcher]: - path_matcher["routeRules"] = [{ - 'priority': 0, - 'matchRules': [{ - 'prefixMatch': - '/', - 'metadataFilters': [{ - 'filterMatchCriteria': 'MATCH_ALL', - 'filterLabels': not_match_labels + match_labels - }] - }], - 'service': GcpResourceManager().default_backend_service() - }, { - 'priority': 1, - 'matchRules': [{ - 'prefixMatch': - '/grpc.testing.TestService/Unary', - 'metadataFilters': [{ - 'filterMatchCriteria': 'MATCH_ANY', - 'filterLabels': not_match_labels + match_labels - }] - }], - 'service': GcpResourceManager().alternative_backend_service() - }] + host_rule: HostRule, path_matcher: PathMatcher + ) -> Tuple[HostRule, PathMatcher]: + path_matcher["routeRules"] = [ + { + "priority": 0, + "matchRules": [ + { + "prefixMatch": "/", + "metadataFilters": [ + { + "filterMatchCriteria": "MATCH_ALL", + "filterLabels": not_match_labels + match_labels, + } + ], + } + ], + "service": GcpResourceManager().default_backend_service(), + }, + { + "priority": 1, + "matchRules": [ + { + "prefixMatch": 
"/grpc.testing.TestService/Unary", + "metadataFilters": [ + { + "filterMatchCriteria": "MATCH_ANY", + "filterLabels": not_match_labels + match_labels, + } + ], + } + ], + "service": GcpResourceManager().alternative_backend_service(), + }, + ] return host_rule, path_matcher def xds_config_validate(self, xds_config: DumpedXdsConfig): self.assertNumEndpoints(xds_config, 2) - self.assertEqual(len(xds_config.rds['virtualHosts'][0]['routes']), 2) + self.assertEqual(len(xds_config.rds["virtualHosts"][0]["routes"]), 2) self.assertEqual( - xds_config.rds['virtualHosts'][0]['routes'][0]['match']['prefix'], - "/grpc.testing.TestService/Unary") + xds_config.rds["virtualHosts"][0]["routes"][0]["match"]["prefix"], + "/grpc.testing.TestService/Unary", + ) self.assertEqual( - xds_config.rds['virtualHosts'][0]['routes'][1]['match']['prefix'], - "") + xds_config.rds["virtualHosts"][0]["routes"][1]["match"]["prefix"], + "", + ) def rpc_distribution_validate(self, test_client: XdsTestClient): rpc_distribution = self.configure_and_send( - test_client, rpc_types=(RpcTypeUnaryCall,), num_rpcs=_NUM_RPCS) + test_client, rpc_types=(RpcTypeUnaryCall,), num_rpcs=_NUM_RPCS + ) self.assertEqual( - _NUM_RPCS, - rpc_distribution.unary_call_alternative_service_rpc_count) + _NUM_RPCS, rpc_distribution.unary_call_alternative_service_rpc_count + ) class TestMetadataFilterMatchMultipleRules( - xds_url_map_testcase.XdsUrlMapTestCase): - + xds_url_map_testcase.XdsUrlMapTestCase +): @staticmethod def url_map_change( - host_rule: HostRule, - path_matcher: PathMatcher) -> Tuple[HostRule, PathMatcher]: - path_matcher["routeRules"] = [{ - 'priority': 0, - 'matchRules': [{ - 'prefixMatch': - '/', - 'headerMatches': [{ - 'headerName': _TEST_METADATA_KEY, - 'exactMatch': _TEST_METADATA_VALUE_EMPTY - }], - 'metadataFilters': [{ - 'filterMatchCriteria': 'MATCH_ANY', - 'filterLabels': match_labels - }] - }], - 'service': GcpResourceManager().alternative_backend_service() - }, { - 'priority': 1, - 'matchRules': [{ - 'prefixMatch': - '/grpc.testing.TestService/Unary', - 'metadataFilters': [{ - 'filterMatchCriteria': 'MATCH_ALL', - 'filterLabels': match_labels - }] - }], - 'service': GcpResourceManager().default_backend_service() - }] + host_rule: HostRule, path_matcher: PathMatcher + ) -> Tuple[HostRule, PathMatcher]: + path_matcher["routeRules"] = [ + { + "priority": 0, + "matchRules": [ + { + "prefixMatch": "/", + "headerMatches": [ + { + "headerName": _TEST_METADATA_KEY, + "exactMatch": _TEST_METADATA_VALUE_EMPTY, + } + ], + "metadataFilters": [ + { + "filterMatchCriteria": "MATCH_ANY", + "filterLabels": match_labels, + } + ], + } + ], + "service": GcpResourceManager().alternative_backend_service(), + }, + { + "priority": 1, + "matchRules": [ + { + "prefixMatch": "/grpc.testing.TestService/Unary", + "metadataFilters": [ + { + "filterMatchCriteria": "MATCH_ALL", + "filterLabels": match_labels, + } + ], + } + ], + "service": GcpResourceManager().default_backend_service(), + }, + ] return host_rule, path_matcher def xds_config_validate(self, xds_config: DumpedXdsConfig): self.assertNumEndpoints(xds_config, 2) - self.assertEqual(len(xds_config.rds['virtualHosts'][0]['routes']), 3) + self.assertEqual(len(xds_config.rds["virtualHosts"][0]["routes"]), 3) self.assertEqual( - xds_config.rds['virtualHosts'][0]['routes'][0]['match']['headers'] - [0]['name'], _TEST_METADATA_KEY) + xds_config.rds["virtualHosts"][0]["routes"][0]["match"]["headers"][ + 0 + ]["name"], + _TEST_METADATA_KEY, + ) self.assertEqual( - 
xds_config.rds['virtualHosts'][0]['routes'][0]['match']['headers'] - [0]['exactMatch'], _TEST_METADATA_VALUE_EMPTY) + xds_config.rds["virtualHosts"][0]["routes"][0]["match"]["headers"][ + 0 + ]["exactMatch"], + _TEST_METADATA_VALUE_EMPTY, + ) self.assertEqual( - xds_config.rds['virtualHosts'][0]['routes'][1]['match']['prefix'], - "/grpc.testing.TestService/Unary") + xds_config.rds["virtualHosts"][0]["routes"][1]["match"]["prefix"], + "/grpc.testing.TestService/Unary", + ) self.assertEqual( - xds_config.rds['virtualHosts'][0]['routes'][2]['match']['prefix'], - "") + xds_config.rds["virtualHosts"][0]["routes"][2]["match"]["prefix"], + "", + ) def rpc_distribution_validate(self, test_client: XdsTestClient): - rpc_distribution = self.configure_and_send(test_client, - rpc_types=[RpcTypeEmptyCall], - metadata=_TEST_METADATA, - num_rpcs=_NUM_RPCS) + rpc_distribution = self.configure_and_send( + test_client, + rpc_types=[RpcTypeEmptyCall], + metadata=_TEST_METADATA, + num_rpcs=_NUM_RPCS, + ) self.assertEqual( - _NUM_RPCS, - rpc_distribution.empty_call_alternative_service_rpc_count) + _NUM_RPCS, rpc_distribution.empty_call_alternative_service_rpc_count + ) -if __name__ == '__main__': +if __name__ == "__main__": absltest.main() diff --git a/tools/run_tests/xds_k8s_test_driver/tests/url_map/path_matching_test.py b/tools/run_tests/xds_k8s_test_driver/tests/url_map/path_matching_test.py index c33064c0a9cb1..3d92859a4de88 100644 --- a/tools/run_tests/xds_k8s_test_driver/tests/url_map/path_matching_test.py +++ b/tools/run_tests/xds_k8s_test_driver/tests/url_map/path_matching_test.py @@ -39,77 +39,82 @@ def _is_supported(config: skips.TestConfig) -> bool: if config.client_lang == _Lang.NODE: - return config.version_gte('v1.3.x') + return config.version_gte("v1.3.x") return True class TestFullPathMatchEmptyCall(xds_url_map_testcase.XdsUrlMapTestCase): - @staticmethod def is_supported(config: skips.TestConfig) -> bool: return _is_supported(config) @staticmethod def url_map_change( - host_rule: HostRule, - path_matcher: PathMatcher) -> Tuple[HostRule, PathMatcher]: - path_matcher["routeRules"] = [{ - 'priority': 0, - # FullPath EmptyCall -> alternate_backend_service. - 'matchRules': [{ - 'fullPathMatch': '/grpc.testing.TestService/EmptyCall' - }], - 'service': GcpResourceManager().alternative_backend_service() - }] + host_rule: HostRule, path_matcher: PathMatcher + ) -> Tuple[HostRule, PathMatcher]: + path_matcher["routeRules"] = [ + { + "priority": 0, + # FullPath EmptyCall -> alternate_backend_service. 
+ "matchRules": [ + {"fullPathMatch": "/grpc.testing.TestService/EmptyCall"} + ], + "service": GcpResourceManager().alternative_backend_service(), + } + ] return host_rule, path_matcher def xds_config_validate(self, xds_config: DumpedXdsConfig): self.assertNumEndpoints(xds_config, 2) self.assertEqual( - xds_config.rds['virtualHosts'][0]['routes'][0]['match']['path'], - "/grpc.testing.TestService/EmptyCall") + xds_config.rds["virtualHosts"][0]["routes"][0]["match"]["path"], + "/grpc.testing.TestService/EmptyCall", + ) def rpc_distribution_validate(self, test_client: XdsTestClient): - rpc_distribution = self.configure_and_send(test_client, - rpc_types=[RpcTypeEmptyCall], - num_rpcs=_NUM_RPCS) + rpc_distribution = self.configure_and_send( + test_client, rpc_types=[RpcTypeEmptyCall], num_rpcs=_NUM_RPCS + ) self.assertEqual( - _NUM_RPCS, - rpc_distribution.empty_call_alternative_service_rpc_count) + _NUM_RPCS, rpc_distribution.empty_call_alternative_service_rpc_count + ) class TestFullPathMatchUnaryCall(xds_url_map_testcase.XdsUrlMapTestCase): - @staticmethod def is_supported(config: skips.TestConfig) -> bool: return _is_supported(config) @staticmethod def url_map_change( - host_rule: HostRule, - path_matcher: PathMatcher) -> Tuple[HostRule, PathMatcher]: - path_matcher["routeRules"] = [{ - 'priority': 0, - # FullPath EmptyCall -> alternate_backend_service. - 'matchRules': [{ - 'fullPathMatch': '/grpc.testing.TestService/UnaryCall' - }], - 'service': GcpResourceManager().alternative_backend_service() - }] + host_rule: HostRule, path_matcher: PathMatcher + ) -> Tuple[HostRule, PathMatcher]: + path_matcher["routeRules"] = [ + { + "priority": 0, + # FullPath EmptyCall -> alternate_backend_service. + "matchRules": [ + {"fullPathMatch": "/grpc.testing.TestService/UnaryCall"} + ], + "service": GcpResourceManager().alternative_backend_service(), + } + ] return host_rule, path_matcher def xds_config_validate(self, xds_config: DumpedXdsConfig): self.assertNumEndpoints(xds_config, 2) self.assertEqual( - xds_config.rds['virtualHosts'][0]['routes'][0]['match']['path'], - "/grpc.testing.TestService/UnaryCall") + xds_config.rds["virtualHosts"][0]["routes"][0]["match"]["path"], + "/grpc.testing.TestService/UnaryCall", + ) def rpc_distribution_validate(self, test_client: XdsTestClient): rpc_distribution = self.configure_and_send( - test_client, rpc_types=(RpcTypeUnaryCall,), num_rpcs=_NUM_RPCS) + test_client, rpc_types=(RpcTypeUnaryCall,), num_rpcs=_NUM_RPCS + ) self.assertEqual( - _NUM_RPCS, - rpc_distribution.unary_call_alternative_service_rpc_count) + _NUM_RPCS, rpc_distribution.unary_call_alternative_service_rpc_count + ) class TestTwoRoutesAndPrefixMatch(xds_url_map_testcase.XdsUrlMapTestCase): @@ -124,123 +129,143 @@ def is_supported(config: skips.TestConfig) -> bool: @staticmethod def url_map_change( - host_rule: HostRule, - path_matcher: PathMatcher) -> Tuple[HostRule, PathMatcher]: + host_rule: HostRule, path_matcher: PathMatcher + ) -> Tuple[HostRule, PathMatcher]: path_matcher["routeRules"] = [ { - 'priority': 0, + "priority": 0, # Prefix UnaryCall -> default_backend_service. - 'matchRules': [{ - 'prefixMatch': '/grpc.testing.TestService/Unary' - }], - 'service': GcpResourceManager().default_backend_service() + "matchRules": [ + {"prefixMatch": "/grpc.testing.TestService/Unary"} + ], + "service": GcpResourceManager().default_backend_service(), }, { - 'priority': 1, + "priority": 1, # FullPath EmptyCall -> alternate_backend_service. 
- 'matchRules': [{ - 'fullPathMatch': '/grpc.testing.TestService/EmptyCall' - }], - 'service': GcpResourceManager().alternative_backend_service() - } + "matchRules": [ + {"fullPathMatch": "/grpc.testing.TestService/EmptyCall"} + ], + "service": GcpResourceManager().alternative_backend_service(), + }, ] return host_rule, path_matcher def xds_config_validate(self, xds_config: DumpedXdsConfig): self.assertNumEndpoints(xds_config, 2) self.assertEqual( - xds_config.rds['virtualHosts'][0]['routes'][0]['match']['prefix'], - "/grpc.testing.TestService/Unary") + xds_config.rds["virtualHosts"][0]["routes"][0]["match"]["prefix"], + "/grpc.testing.TestService/Unary", + ) self.assertEqual( - xds_config.rds['virtualHosts'][0]['routes'][1]['match']['path'], - "/grpc.testing.TestService/EmptyCall") + xds_config.rds["virtualHosts"][0]["routes"][1]["match"]["path"], + "/grpc.testing.TestService/EmptyCall", + ) def rpc_distribution_validate(self, test_client: XdsTestClient): rpc_distribution = self.configure_and_send( test_client, rpc_types=[RpcTypeUnaryCall, RpcTypeEmptyCall], - num_rpcs=_NUM_RPCS) + num_rpcs=_NUM_RPCS, + ) self.assertEqual(0, rpc_distribution.num_failures) self.assertEqual( - 0, rpc_distribution.unary_call_alternative_service_rpc_count) - self.assertEqual(0, - rpc_distribution.empty_call_default_service_rpc_count) + 0, rpc_distribution.unary_call_alternative_service_rpc_count + ) + self.assertEqual( + 0, rpc_distribution.empty_call_default_service_rpc_count + ) class TestRegexMatch(xds_url_map_testcase.XdsUrlMapTestCase): - @staticmethod def is_supported(config: skips.TestConfig) -> bool: return _is_supported(config) @staticmethod def url_map_change( - host_rule: HostRule, - path_matcher: PathMatcher) -> Tuple[HostRule, PathMatcher]: - path_matcher["routeRules"] = [{ - 'priority': 0, - # Regex UnaryCall -> alternate_backend_service. - 'matchRules': [{ - 'regexMatch': - r'^\/.*\/UnaryCall$' # Unary methods with any services. - }], - 'service': GcpResourceManager().alternative_backend_service() - }] + host_rule: HostRule, path_matcher: PathMatcher + ) -> Tuple[HostRule, PathMatcher]: + path_matcher["routeRules"] = [ + { + "priority": 0, + # Regex UnaryCall -> alternate_backend_service. + "matchRules": [ + { + "regexMatch": ( # Unary methods with any services. 
+ r"^\/.*\/UnaryCall$" + ) + } + ], + "service": GcpResourceManager().alternative_backend_service(), + } + ] return host_rule, path_matcher def xds_config_validate(self, xds_config: DumpedXdsConfig): self.assertNumEndpoints(xds_config, 2) self.assertEqual( - xds_config.rds['virtualHosts'][0]['routes'][0]['match']['safeRegex'] - ['regex'], r'^\/.*\/UnaryCall$') + xds_config.rds["virtualHosts"][0]["routes"][0]["match"][ + "safeRegex" + ]["regex"], + r"^\/.*\/UnaryCall$", + ) def rpc_distribution_validate(self, test_client: XdsTestClient): rpc_distribution = self.configure_and_send( - test_client, rpc_types=(RpcTypeUnaryCall,), num_rpcs=_NUM_RPCS) + test_client, rpc_types=(RpcTypeUnaryCall,), num_rpcs=_NUM_RPCS + ) self.assertEqual( - _NUM_RPCS, - rpc_distribution.unary_call_alternative_service_rpc_count) + _NUM_RPCS, rpc_distribution.unary_call_alternative_service_rpc_count + ) class TestCaseInsensitiveMatch(xds_url_map_testcase.XdsUrlMapTestCase): - @staticmethod def is_supported(config: skips.TestConfig) -> bool: return _is_supported(config) @staticmethod def url_map_change( - host_rule: HostRule, - path_matcher: PathMatcher) -> Tuple[HostRule, PathMatcher]: - path_matcher["routeRules"] = [{ - 'priority': 0, - # ignoreCase EmptyCall -> alternate_backend_service. - 'matchRules': [{ - # Case insensitive matching. - 'fullPathMatch': '/gRpC.tEsTinG.tEstseRvice/empTycaLl', - 'ignoreCase': True, - }], - 'service': GcpResourceManager().alternative_backend_service() - }] + host_rule: HostRule, path_matcher: PathMatcher + ) -> Tuple[HostRule, PathMatcher]: + path_matcher["routeRules"] = [ + { + "priority": 0, + # ignoreCase EmptyCall -> alternate_backend_service. + "matchRules": [ + { + # Case insensitive matching. + "fullPathMatch": "/gRpC.tEsTinG.tEstseRvice/empTycaLl", + "ignoreCase": True, + } + ], + "service": GcpResourceManager().alternative_backend_service(), + } + ] return host_rule, path_matcher def xds_config_validate(self, xds_config: DumpedXdsConfig): self.assertNumEndpoints(xds_config, 2) self.assertEqual( - xds_config.rds['virtualHosts'][0]['routes'][0]['match']['path'], - '/gRpC.tEsTinG.tEstseRvice/empTycaLl') + xds_config.rds["virtualHosts"][0]["routes"][0]["match"]["path"], + "/gRpC.tEsTinG.tEstseRvice/empTycaLl", + ) self.assertEqual( - xds_config.rds['virtualHosts'][0]['routes'][0]['match'] - ['caseSensitive'], False) + xds_config.rds["virtualHosts"][0]["routes"][0]["match"][ + "caseSensitive" + ], + False, + ) def rpc_distribution_validate(self, test_client: XdsTestClient): - rpc_distribution = self.configure_and_send(test_client, - rpc_types=[RpcTypeEmptyCall], - num_rpcs=_NUM_RPCS) + rpc_distribution = self.configure_and_send( + test_client, rpc_types=[RpcTypeEmptyCall], num_rpcs=_NUM_RPCS + ) self.assertEqual( - _NUM_RPCS, - rpc_distribution.empty_call_alternative_service_rpc_count) + _NUM_RPCS, rpc_distribution.empty_call_alternative_service_rpc_count + ) -if __name__ == '__main__': +if __name__ == "__main__": absltest.main() diff --git a/tools/run_tests/xds_k8s_test_driver/tests/url_map/retry_test.py b/tools/run_tests/xds_k8s_test_driver/tests/url_map/retry_test.py index eda729e60c72c..74574bde6c085 100644 --- a/tools/run_tests/xds_k8s_test_driver/tests/url_map/retry_test.py +++ b/tools/run_tests/xds_k8s_test_driver/tests/url_map/retry_test.py @@ -44,20 +44,20 @@ # SLEEP_DURATION number of RPC is finished. The final completed RPC might be # slightly more or less. 
_NON_RANDOM_ERROR_TOLERANCE = 0.01 -_RPC_BEHAVIOR_HEADER_NAME = 'rpc-behavior' +_RPC_BEHAVIOR_HEADER_NAME = "rpc-behavior" def _build_retry_route_rule(retryConditions, num_retries): return { - 'priority': 0, - 'matchRules': [{ - 'fullPathMatch': '/grpc.testing.TestService/UnaryCall' - }], - 'service': GcpResourceManager().default_backend_service(), - 'routeAction': { - 'retryPolicy': { - 'retryConditions': retryConditions, - 'numRetries': num_retries, + "priority": 0, + "matchRules": [ + {"fullPathMatch": "/grpc.testing.TestService/UnaryCall"} + ], + "service": GcpResourceManager().default_backend_service(), + "routeAction": { + "retryPolicy": { + "retryConditions": retryConditions, + "numRetries": num_retries, } }, } @@ -67,95 +67,115 @@ def _is_supported(config: skips.TestConfig) -> bool: # Per "Retry" in # https://github.com/grpc/grpc/blob/master/doc/grpc_xds_features.md if config.client_lang in _Lang.CPP | _Lang.JAVA | _Lang.PYTHON: - return config.version_gte('v1.40.x') + return config.version_gte("v1.40.x") elif config.client_lang == _Lang.GO: - return config.version_gte('v1.41.x') + return config.version_gte("v1.41.x") elif config.client_lang == _Lang.NODE: - return config.version_gte('v1.8.x') + return config.version_gte("v1.8.x") return True class TestRetryUpTo3AttemptsAndFail(xds_url_map_testcase.XdsUrlMapTestCase): - @staticmethod def is_supported(config: skips.TestConfig) -> bool: return _is_supported(config) @staticmethod def url_map_change( - host_rule: HostRule, - path_matcher: PathMatcher) -> Tuple[HostRule, PathMatcher]: + host_rule: HostRule, path_matcher: PathMatcher + ) -> Tuple[HostRule, PathMatcher]: path_matcher["routeRules"] = [ - _build_retry_route_rule(retryConditions=["unavailable"], - num_retries=3) + _build_retry_route_rule( + retryConditions=["unavailable"], num_retries=3 + ) ] return host_rule, path_matcher def xds_config_validate(self, xds_config: DumpedXdsConfig): self.assertNumEndpoints(xds_config, 1) - retry_config = xds_config.rds['virtualHosts'][0]['routes'][0]['route'][ - 'retryPolicy'] - self.assertEqual(3, retry_config['numRetries']) - self.assertEqual('unavailable', retry_config['retryOn']) + retry_config = xds_config.rds["virtualHosts"][0]["routes"][0]["route"][ + "retryPolicy" + ] + self.assertEqual(3, retry_config["numRetries"]) + self.assertEqual("unavailable", retry_config["retryOn"]) def rpc_distribution_validate(self, test_client: XdsTestClient): - self.configure_and_send(test_client, - rpc_types=(RpcTypeUnaryCall,), - metadata=[ - (RpcTypeUnaryCall, - _RPC_BEHAVIOR_HEADER_NAME, - 'succeed-on-retry-attempt-4,error-code-14') - ], - num_rpcs=_NUM_RPCS) - self.assertRpcStatusCode(test_client, - expected=(ExpectedResult( - rpc_type=RpcTypeUnaryCall, - status_code=grpc.StatusCode.UNAVAILABLE, - ratio=1),), - length=_LENGTH_OF_RPC_SENDING_SEC, - tolerance=_NON_RANDOM_ERROR_TOLERANCE) + self.configure_and_send( + test_client, + rpc_types=(RpcTypeUnaryCall,), + metadata=[ + ( + RpcTypeUnaryCall, + _RPC_BEHAVIOR_HEADER_NAME, + "succeed-on-retry-attempt-4,error-code-14", + ) + ], + num_rpcs=_NUM_RPCS, + ) + self.assertRpcStatusCode( + test_client, + expected=( + ExpectedResult( + rpc_type=RpcTypeUnaryCall, + status_code=grpc.StatusCode.UNAVAILABLE, + ratio=1, + ), + ), + length=_LENGTH_OF_RPC_SENDING_SEC, + tolerance=_NON_RANDOM_ERROR_TOLERANCE, + ) class TestRetryUpTo4AttemptsAndSucceed(xds_url_map_testcase.XdsUrlMapTestCase): - @staticmethod def is_supported(config: skips.TestConfig) -> bool: return _is_supported(config) @staticmethod def 
url_map_change( - host_rule: HostRule, - path_matcher: PathMatcher) -> Tuple[HostRule, PathMatcher]: + host_rule: HostRule, path_matcher: PathMatcher + ) -> Tuple[HostRule, PathMatcher]: path_matcher["routeRules"] = [ - _build_retry_route_rule(retryConditions=["unavailable"], - num_retries=4) + _build_retry_route_rule( + retryConditions=["unavailable"], num_retries=4 + ) ] return host_rule, path_matcher def xds_config_validate(self, xds_config: DumpedXdsConfig): self.assertNumEndpoints(xds_config, 1) - retry_config = xds_config.rds['virtualHosts'][0]['routes'][0]['route'][ - 'retryPolicy'] - self.assertEqual(4, retry_config['numRetries']) - self.assertEqual('unavailable', retry_config['retryOn']) + retry_config = xds_config.rds["virtualHosts"][0]["routes"][0]["route"][ + "retryPolicy" + ] + self.assertEqual(4, retry_config["numRetries"]) + self.assertEqual("unavailable", retry_config["retryOn"]) def rpc_distribution_validate(self, test_client: XdsTestClient): - self.configure_and_send(test_client, - rpc_types=(RpcTypeUnaryCall,), - metadata=[ - (RpcTypeUnaryCall, - _RPC_BEHAVIOR_HEADER_NAME, - 'succeed-on-retry-attempt-4,error-code-14') - ], - num_rpcs=_NUM_RPCS) - self.assertRpcStatusCode(test_client, - expected=(ExpectedResult( - rpc_type=RpcTypeUnaryCall, - status_code=grpc.StatusCode.OK, - ratio=1),), - length=_LENGTH_OF_RPC_SENDING_SEC, - tolerance=_NON_RANDOM_ERROR_TOLERANCE) - - -if __name__ == '__main__': + self.configure_and_send( + test_client, + rpc_types=(RpcTypeUnaryCall,), + metadata=[ + ( + RpcTypeUnaryCall, + _RPC_BEHAVIOR_HEADER_NAME, + "succeed-on-retry-attempt-4,error-code-14", + ) + ], + num_rpcs=_NUM_RPCS, + ) + self.assertRpcStatusCode( + test_client, + expected=( + ExpectedResult( + rpc_type=RpcTypeUnaryCall, + status_code=grpc.StatusCode.OK, + ratio=1, + ), + ), + length=_LENGTH_OF_RPC_SENDING_SEC, + tolerance=_NON_RANDOM_ERROR_TOLERANCE, + ) + + +if __name__ == "__main__": absltest.main() diff --git a/tools/run_tests/xds_k8s_test_driver/tests/url_map/timeout_test.py b/tools/run_tests/xds_k8s_test_driver/tests/url_map/timeout_test.py index 8e2c42682c401..f076e85464f18 100644 --- a/tools/run_tests/xds_k8s_test_driver/tests/url_map/timeout_test.py +++ b/tools/run_tests/xds_k8s_test_driver/tests/url_map/timeout_test.py @@ -45,48 +45,54 @@ class _BaseXdsTimeOutTestCase(XdsUrlMapTestCase): - @staticmethod def url_map_change( - host_rule: HostRule, - path_matcher: PathMatcher) -> Tuple[HostRule, PathMatcher]: - path_matcher['routeRules'] = [{ - 'priority': 0, - 'matchRules': [{ - 'fullPathMatch': '/grpc.testing.TestService/UnaryCall' - }], - 'service': GcpResourceManager().default_backend_service(), - 'routeAction': { - 'maxStreamDuration': { - 'seconds': 3, + host_rule: HostRule, path_matcher: PathMatcher + ) -> Tuple[HostRule, PathMatcher]: + path_matcher["routeRules"] = [ + { + "priority": 0, + "matchRules": [ + {"fullPathMatch": "/grpc.testing.TestService/UnaryCall"} + ], + "service": GcpResourceManager().default_backend_service(), + "routeAction": { + "maxStreamDuration": { + "seconds": 3, + }, }, - }, - }] + } + ] return host_rule, path_matcher def xds_config_validate(self, xds_config: DumpedXdsConfig): self.assertNumEndpoints(xds_config, 1) self.assertEqual( - xds_config.rds['virtualHosts'][0]['routes'][0]['route'] - ['maxStreamDuration']['maxStreamDuration'], '3s') + xds_config.rds["virtualHosts"][0]["routes"][0]["route"][ + "maxStreamDuration" + ]["maxStreamDuration"], + "3s", + ) self.assertEqual( - xds_config.rds['virtualHosts'][0]['routes'][0]['route'] - 
['maxStreamDuration']['grpcTimeoutHeaderMax'], '3s') + xds_config.rds["virtualHosts"][0]["routes"][0]["route"][ + "maxStreamDuration" + ]["grpcTimeoutHeaderMax"], + "3s", + ) def rpc_distribution_validate(self, unused_test_client): raise NotImplementedError() class TestTimeoutInRouteRule(_BaseXdsTimeOutTestCase): - @staticmethod def is_supported(config: skips.TestConfig) -> bool: # TODO(lidiz) either add support for rpc-behavior to other languages, or we # should always use Java server as backend. - if config.server_lang != 'java': + if config.server_lang != "java": return False if config.client_lang == skips.Lang.NODE: - return config.version_gte('v1.4.x') + return config.version_gte("v1.4.x") return True def rpc_distribution_validate(self, test_client: XdsTestClient): @@ -96,32 +102,36 @@ def rpc_distribution_validate(self, test_client: XdsTestClient): # UnaryCall and EmptyCall both sleep-4. # UnaryCall timeouts, EmptyCall succeeds. metadata=( - (RpcTypeUnaryCall, 'rpc-behavior', 'sleep-4'), - (RpcTypeEmptyCall, 'rpc-behavior', 'sleep-4'), + (RpcTypeUnaryCall, "rpc-behavior", "sleep-4"), + (RpcTypeEmptyCall, "rpc-behavior", "sleep-4"), ), - num_rpcs=_NUM_RPCS) + num_rpcs=_NUM_RPCS, + ) self.assertRpcStatusCode( test_client, expected=( - ExpectedResult(rpc_type=RpcTypeUnaryCall, - status_code=grpc.StatusCode.DEADLINE_EXCEEDED), - ExpectedResult(rpc_type=RpcTypeEmptyCall, - status_code=grpc.StatusCode.OK), + ExpectedResult( + rpc_type=RpcTypeUnaryCall, + status_code=grpc.StatusCode.DEADLINE_EXCEEDED, + ), + ExpectedResult( + rpc_type=RpcTypeEmptyCall, status_code=grpc.StatusCode.OK + ), ), length=_LENGTH_OF_RPC_SENDING_SEC, - tolerance=_ERROR_TOLERANCE) + tolerance=_ERROR_TOLERANCE, + ) class TestTimeoutInApplication(_BaseXdsTimeOutTestCase): - @staticmethod def is_supported(config: skips.TestConfig) -> bool: # TODO(lidiz) either add support for rpc-behavior to other languages, or we # should always use Java server as backend. - if config.server_lang != 'java': + if config.server_lang != "java": return False if config.client_lang == skips.Lang.NODE: - return config.version_gte('v1.4.x') + return config.version_gte("v1.4.x") return True def rpc_distribution_validate(self, test_client: XdsTestClient): @@ -129,24 +139,28 @@ def rpc_distribution_validate(self, test_client: XdsTestClient): test_client, rpc_types=(RpcTypeUnaryCall,), # UnaryCall only with sleep-2; timeout=1s; calls timeout. - metadata=((RpcTypeUnaryCall, 'rpc-behavior', 'sleep-2'),), + metadata=((RpcTypeUnaryCall, "rpc-behavior", "sleep-2"),), app_timeout=1, - num_rpcs=_NUM_RPCS) + num_rpcs=_NUM_RPCS, + ) self.assertRpcStatusCode( test_client, - expected=(ExpectedResult( - rpc_type=RpcTypeUnaryCall, - status_code=grpc.StatusCode.DEADLINE_EXCEEDED),), + expected=( + ExpectedResult( + rpc_type=RpcTypeUnaryCall, + status_code=grpc.StatusCode.DEADLINE_EXCEEDED, + ), + ), length=_LENGTH_OF_RPC_SENDING_SEC, - tolerance=_ERROR_TOLERANCE) + tolerance=_ERROR_TOLERANCE, + ) class TestTimeoutNotExceeded(_BaseXdsTimeOutTestCase): - @staticmethod def is_supported(config: skips.TestConfig) -> bool: if config.client_lang == skips.Lang.NODE: - return config.version_gte('v1.4.x') + return config.version_gte("v1.4.x") return True def rpc_distribution_validate(self, test_client: XdsTestClient): @@ -154,19 +168,26 @@ def rpc_distribution_validate(self, test_client: XdsTestClient): test_client, # UnaryCall only with no sleep; calls succeed. 
             rpc_types=(RpcTypeUnaryCall,),
-            num_rpcs=_NUM_RPCS)
-        self.assertRpcStatusCode(test_client,
-                                 expected=(ExpectedResult(
-                                     rpc_type=RpcTypeUnaryCall,
-                                     status_code=grpc.StatusCode.OK),),
-                                 length=_LENGTH_OF_RPC_SENDING_SEC,
-                                 tolerance=_ERROR_TOLERANCE)
+            num_rpcs=_NUM_RPCS,
+        )
+        self.assertRpcStatusCode(
+            test_client,
+            expected=(
+                ExpectedResult(
+                    rpc_type=RpcTypeUnaryCall, status_code=grpc.StatusCode.OK
+                ),
+            ),
+            length=_LENGTH_OF_RPC_SENDING_SEC,
+            tolerance=_ERROR_TOLERANCE,
+        )


 def load_tests(loader: absltest.TestLoader, unused_tests, unused_pattern):
     suite = unittest.TestSuite()
     test_cases = [
-        TestTimeoutInRouteRule, TestTimeoutInApplication, TestTimeoutNotExceeded
+        TestTimeoutInRouteRule,
+        TestTimeoutInApplication,
+        TestTimeoutNotExceeded,
     ]
     for test_class in test_cases:
         tests = loader.loadTestsFromTestCase(test_class)
@@ -174,5 +195,5 @@ def load_tests(loader: absltest.TestLoader, unused_tests, unused_pattern):
     return suite


-if __name__ == '__main__':
+if __name__ == "__main__":
     absltest.main()
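
Note for reviewers who want to spot-check a hunk locally: the rewrapping seen throughout these url_map tests (double-quoted strings, exploded nested dict literals, magic trailing commas) is plain black output. Below is a minimal sketch, assuming the black package is installed and an 80-column line length; the project's actual settings live in its black configuration and may differ, and the input snippet is an illustrative route rule, not code taken from this patch.

    import black

    # Illustrative yapf-era input: a nested route rule similar to the ones
    # built in these tests. Any syntactically valid Python source works here.
    SRC = """\
    path_matcher["routeRules"] = [{
        'priority': 0,
        'matchRules': [{
            'fullPathMatch': '/grpc.testing.TestService/UnaryCall'
        }],
        'service': default_backend_service()
    }]
    """

    # format_str() reformats a source string in memory; Mode(line_length=80)
    # is an assumed approximation of the configuration used for this migration.
    print(black.format_str(SRC, mode=black.Mode(line_length=80)))

Running `black --check --diff` against any of the files above after this patch is applied should report no further changes.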