From f43b7966e9c14f83fba203f532bc9d93e504d900 Mon Sep 17 00:00:00 2001
From: Joongi Kim
Date: Tue, 31 May 2022 23:53:18 +0900
Subject: [PATCH] Monorepo for server-side components (#417)

* We use Pants (www.pantsbuild.org) as our build toolchain.
  - Details will follow in subsequent documentation updates.
* Add a custom setup generator plugin for Pants (see the setup-kwargs sketch below).
  - It single-sources the version number from VERSION.
  - It takes the description, license, and name from kwargs.
  - It validates that the package name starts with "backend.ai-".
  - It takes the long description from the package-specific README.
* Add a custom platform-specific dependency selector for Pants.
  - It provides the `platform_resources()` target with a per-platform dependency map. (Thanks to Andreas Stenius)
* Move the mypy/pytest configs to the root `pyproject.toml`.
  - flake8 does not support pyproject.toml yet (PyCQA/flake8#234), so its config stays in `.flake8`.
  - Explicitly add `setuptools` as a requirement of flake8 because flake8 uses `pkg_resources` to detect its own plugin entrypoints.
* Implement entrypoint scanning with BUILD files for the CLI.
* Our new plugin subsystem (`ai.backend.plugin`) uses `importlib`-based entrypoints (see the entrypoint-scanning sketch below).
  - This removes the runtime dependency on setuptools.
* Update .gitignore.
* Add `./py` and `./backend.ai` shortcuts to run commands in the exported venv.
  - Finally, implement a truly unified CLI via `./backend.ai`!
* Rewrite the GitHub Actions workflows.
* Notable non-trivial Pants configs:
  - Set `[GLOBAL].local_execution_root_dir` to a non-tmp directory because Snap-based Docker cannot access it!
  - `aiosqlite` is not explicitly imported but is referenced via a SQLAlchemy database URL scheme, so declare a manual dependency in the manager tests.
  - Add the `[pytest].execution_slot_var` config for test parallelization (see the execution-slot sketch below).
* Mark the storage-proxy tests that require external dependencies as "integration".
  - Integration tests are skipped!
* Update the test fixtures (see the etcd fixture sketch below).
  - Spawn a single-node etcd container with OS-assigned port numbers.
  - Self-bootstrap the db containers for isolated and parallel testing.
* Reorganize packages:
  - ai.backend.helpers -> backend.ai-kernel-helper
  - ai.backend.kernel -> backend.ai-kernel
  - ai.backend.runner -> backend.ai-kernel-binary
* Build packages for Python 3.10 only.
  - Without the upper bound, it tries to use Python 3.11-dev if available, but aiohttp fails to build there. (aio-libs/aiohttp#6600)
* Rewrite scripts/install-dev and scripts/delete-dev.
  - Change the container volume path for the halfstack containers to "./volumes".
  - When Docker is installed via Snap, it must be 20.10.15 or later to have a working `docker compose` (v2) plugin with `sudo`.
  - Remove the auto-install routine and just show the guides instead.
  - Now we support and use docker-compose v2 only.
* Import the `backend.ai-common` source (c864ccbe1)
* Import the `backend.ai-agent` source (98aeeb98)
* Import the `backend.ai-manager` source (85d16f0)
* Import the `backend.ai-client-py` source (b6d03cc)
* Import the `backend.ai-webserver` source (81506cc)
* Import the `backend.ai-storage-proxy` source (8019533)
* Import the `backend.ai-tester` source (ab85fab5c)

Co-authored-by: Andreas Stenius
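
The setup generator plugin itself lives under `tools/pants-plugins/setupgen/` (see the file list below). As a rough illustration of the checks described above, here is a minimal, hypothetical sketch in plain Python; the function name and argument shapes are assumptions, not the actual Pants plugin API:

```python
# Hypothetical sketch of the checks performed by the setupgen plugin.
# This is NOT the actual plugin code (see tools/pants-plugins/setupgen/register.py).
from pathlib import Path


def generate_setup_kwargs(build_kwargs: dict, package_dir: Path) -> dict:
    """Merge BUILD-provided kwargs with single-sourced metadata."""
    name = build_kwargs["name"]
    # Package names must follow the "backend.ai-" prefix convention.
    if not name.startswith("backend.ai-"):
        raise ValueError(f"invalid package name: {name!r}")
    # The version number is single-sourced from the top-level VERSION file.
    version = Path("VERSION").read_text().strip()
    # The long description comes from the package-specific README.
    long_description = (package_dir / "README.md").read_text()
    return {
        **build_kwargs,  # description, license, name, ... from the BUILD target
        "version": version,
        "long_description": long_description,
    }
```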
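
For the `importlib`-based entrypoints, the actual implementation is in `src/ai/backend/plugin/entrypoint.py` and, per the notes above, also covers scanning BUILD files for the CLI. The snippet below is only a minimal sketch of the standard-library mechanism it builds on, with a made-up entrypoint group name:

```python
# Minimal sketch of importlib-based plugin discovery (standard library only).
# The group name "backendai_cli_v10" is made up for illustration.
from importlib.metadata import entry_points


def load_plugins(group: str = "backendai_cli_v10"):
    """Yield (name, loaded object) pairs for every installed entrypoint in a group."""
    for ep in entry_points(group=group):  # selectable interface, Python 3.10+
        yield ep.name, ep.load()


if __name__ == "__main__":
    for name, obj in load_plugins():
        print(f"{name}: {obj!r}")
```

Because this relies only on `importlib.metadata`, nothing at runtime needs `pkg_resources` or setuptools.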
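
The fixture that spawns a single-node etcd container with an OS-assigned port presumably lives under `src/ai/backend/testutils/` (e.g. `bootstrap.py` in the file list). As a rough sketch of the idea only, assuming a plain Docker CLI and an arbitrary etcd image tag:

```python
# Sketch only: spawn a throw-away single-node etcd container and let Docker
# pick an ephemeral (OS-assigned) host port, so parallel test runs cannot collide.
# The image tag and etcd flags are illustrative assumptions.
import subprocess


def spawn_etcd(image: str = "quay.io/coreos/etcd:v3.5.4") -> tuple[str, int]:
    """Start etcd with a random host port and return (container_id, host_port)."""
    cid = subprocess.check_output(
        [
            "docker", "run", "-d",
            "-p", "127.0.0.1::2379",  # empty host port -> ephemeral, OS-assigned
            image,
            "/usr/local/bin/etcd",
            "--listen-client-urls", "http://0.0.0.0:2379",
            "--advertise-client-urls", "http://0.0.0.0:2379",
        ],
        encoding="utf-8",
    ).strip()
    # Ask Docker which host port was actually bound to the client port.
    out = subprocess.check_output(["docker", "port", cid, "2379/tcp"], encoding="utf-8")
    host_port = int(out.splitlines()[0].rsplit(":", 1)[-1])
    return cid, host_port
```

A test fixture can then yield the discovered port and remove the container on teardown.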
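
`[pytest].execution_slot_var` makes Pants expose the concurrency slot id of each test process through an environment variable of that name. The exact variable name configured in `pants.toml` is not shown here, so the sketch below uses a placeholder:

```python
# Sketch of per-slot test isolation using the Pants execution-slot variable.
# "BACKEND_TEST_EXEC_SLOT" is a placeholder for whatever name is configured
# via [pytest].execution_slot_var in pants.toml.
import os

import pytest


@pytest.fixture(scope="session")
def test_db_name() -> str:
    """Return a database name that is unique to the current execution slot."""
    slot = os.environ.get("BACKEND_TEST_EXEC_SLOT", "0")
    return f"testing_{slot}"
```

Self-bootstrapped db containers can use the same slot id to pick non-conflicting container names and ports.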
---
 .editorconfig | 9 +- .flake8 | 8 + .gitattributes | 66 + .github/workflows/default.yml | 245 +- .github/workflows/timeline-check.yml | 29 + .gitignore | 31 +- .gitmodules | 3 + BUILD | 26 + BUILD_ROOT | 1 + MIGRATION.md | 51 + README.md | 216 +- VERSION | 1 + backend.ai | 2 + changes/417.misc | 1 + changes/template.md | 44 + configs/agent/ci.toml | 65 + configs/agent/halfstack.toml | 83 + configs/agent/sample.toml | 270 ++ configs/manager/ci.toml | 67 + configs/manager/halfstack.alembic.ini | 74 + configs/manager/halfstack.toml | 67 + configs/manager/sample.etcd.config.json | 50 + .../manager/sample.etcd.redis-sentinel.json | 4 + configs/manager/sample.etcd.redis-single.json | 3 + configs/manager/sample.etcd.volumes.json | 15 + configs/manager/sample.toml | 163 + configs/storage-proxy/sample.toml | 144 + .../ssl/manager-api-selfsigned.cert.pem | 30 + .../ssl/manager-api-selfsigned.key.pem | 52 + configs/webserver/sample.conf | 106 + docker-compose.halfstack-2203.yml | 6 +- docker/linuxkit-nsenter/Dockerfile | 4 + docker/socket-relay/Dockerfile | 8 + fixtures/manager/example-keypairs.json | 172 + .../manager/example-resource-presets.json | 33 + .../manager/example-session-templates.json | 71 + pants | 395 ++ pants.ci.toml | 6 + pants.toml | 68 + plugins/.gitignore | 4 + plugins/README.md | 11 + py | 14 + pyproject.toml | 55 + python-kernel.lock | 563 +++ python.lock | 3915 +++++++++++++++++ requirements.txt | 79 + requirements/build.txt | 1 - requirements/dev.txt | 1 - requirements/docs.txt | 2 - requirements/lint.txt | 1 - requirements/main.txt | 1 - requirements/test.txt | 1 - requirements/typecheck.txt | 1 - scripts/agent/build-dropbear.sh | 79 + scripts/agent/build-krunner-extractor.sh | 6 + scripts/agent/build-sftpserver.sh | 136 + scripts/agent/build-socket-relay.sh | 5 + scripts/agent/build-suexec.sh | 83 + scripts/agent/build-tmux.sh | 92 + scripts/agent/ci/deploy.sh | 6 + scripts/agent/ci/install-manager.sh | 17 + scripts/agent/ci/overwrite-python.sh | 15 + scripts/agent/ci/run-manager.sh | 21 + .../deploy-static/deploy_static_files.py | 34 + scripts/agent/eks.yml | 33 + scripts/agent/run-with-halfstack.sh | 5 + scripts/agent/update-metadata-iptables.sh | 20 + scripts/check-docker.py | 109 + scripts/delete-dev.sh | 51 +- scripts/install-dev.sh | 470 +- scripts/install-plugin.sh | 8 + scripts/reinstall-plugins.sh | 8 + scripts/specialize.py | 36 + setup.cfg | 86 - setup.py | 3 - src/ai/backend/agent/BUILD | 71 + src/ai/backend/agent/README.k8s.md | 208 + src/ai/backend/agent/README.md | 291 ++ src/ai/backend/agent/VERSION | 1 + src/ai/backend/agent/__init__.py | 3 + src/ai/backend/agent/agent.py | 1807 ++++++++ src/ai/backend/agent/cli.py | 7 + src/ai/backend/agent/config.py | 102 + src/ai/backend/agent/docker/__init__.py | 8 + src/ai/backend/agent/docker/agent.py | 1367 ++++++ .../backendai-socket-relay.img.aarch64.tar.gz | 3 + .../backendai-socket-relay.img.x86_64.tar.gz | 3 + .../docker/backendai-socket-relay.version.txt | 1 + src/ai/backend/agent/docker/files.py | 64 + 
src/ai/backend/agent/docker/intrinsic.py | 606 +++ src/ai/backend/agent/docker/kernel.py | 413 ++ ...linuxkit-metadata-proxy-worker.aarch64.bin | 3 + .../linuxkit-metadata-proxy-worker.x86_64.bin | 3 + .../linuxkit-nsenter.img.aarch64.tar.gz | 3 + .../docker/linuxkit-nsenter.img.x86_64.tar.gz | 3 + .../agent/docker/linuxkit-nsenter.version.txt | 1 + .../backend/agent/docker/metadata/server.py | 64 + src/ai/backend/agent/docker/resources.py | 98 + src/ai/backend/agent/docker/utils.py | 162 + src/ai/backend/agent/exception.py | 53 + src/ai/backend/agent/fs.py | 42 + src/ai/backend/agent/kernel.py | 920 ++++ src/ai/backend/agent/kubernetes/__init__.py | 8 + src/ai/backend/agent/kubernetes/agent.py | 1000 +++++ .../backendai-socket-relay.img.tar.gz | 3 + .../backendai-socket-relay.version.txt | 1 + src/ai/backend/agent/kubernetes/files.py | 64 + src/ai/backend/agent/kubernetes/intrinsic.py | 336 ++ src/ai/backend/agent/kubernetes/kernel.py | 448 ++ .../backend/agent/kubernetes/kube_object.py | 201 + src/ai/backend/agent/kubernetes/resources.py | 98 + src/ai/backend/agent/kubernetes/utils.py | 140 + .../agent/linuxkit-metadata-proxy/main.go | 84 + src/ai/backend/agent/proxy.py | 73 + src/ai/backend/agent/py.typed | 1 + src/ai/backend/agent/resources.py | 1001 +++++ src/ai/backend/agent/server.py | 795 ++++ src/ai/backend/agent/stats.py | 428 ++ src/ai/backend/agent/types.py | 83 + src/ai/backend/agent/utils.py | 336 ++ src/ai/backend/agent/vendor/__init__.py | 0 src/ai/backend/agent/vendor/linux.py | 57 + src/ai/backend/agent/watcher.py | 384 ++ src/ai/backend/cli/BUILD | 48 + src/ai/backend/cli/README.md | 23 + src/ai/backend/cli/VERSION | 1 + src/ai/backend/cli/__init__.py | 1 + src/ai/backend/cli/__main__.py | 9 + src/ai/backend/cli/extensions.py | 144 + src/ai/backend/cli/interaction.py | 150 + src/ai/backend/cli/loader.py | 23 + src/ai/backend/cli/main.py | 14 + src/ai/backend/cli/py.typed | 1 + src/ai/backend/client/BUILD | 48 + src/ai/backend/client/README.rst | 218 + src/ai/backend/client/VERSION | 1 + src/ai/backend/client/__init__.py | 15 + src/ai/backend/client/auth.py | 73 + src/ai/backend/client/cli/__init__.py | 9 + src/ai/backend/client/cli/__main__.py | 5 + src/ai/backend/client/cli/admin/__init__.py | 27 + src/ai/backend/client/cli/admin/agent.py | 194 + src/ai/backend/client/cli/admin/domain.py | 228 + src/ai/backend/client/cli/admin/etcd.py | 85 + src/ai/backend/client/cli/admin/group.py | 322 ++ src/ai/backend/client/cli/admin/image.py | 120 + src/ai/backend/client/cli/admin/keypair.py | 293 ++ src/ai/backend/client/cli/admin/license.py | 38 + src/ai/backend/client/cli/admin/manager.py | 207 + src/ai/backend/client/cli/admin/resource.py | 144 + .../client/cli/admin/resource_policy.py | 220 + .../backend/client/cli/admin/scaling_group.py | 285 ++ src/ai/backend/client/cli/admin/session.py | 264 ++ src/ai/backend/client/cli/admin/storage.py | 85 + src/ai/backend/client/cli/admin/user.py | 284 ++ src/ai/backend/client/cli/admin/vfolder.py | 205 + src/ai/backend/client/cli/announcement.py | 45 + src/ai/backend/client/cli/app.py | 344 ++ src/ai/backend/client/cli/config.py | 171 + src/ai/backend/client/cli/dotfile.py | 190 + src/ai/backend/client/cli/logs.py | 29 + src/ai/backend/client/cli/main.py | 45 + src/ai/backend/client/cli/pagination.py | 115 + src/ai/backend/client/cli/params.py | 168 + src/ai/backend/client/cli/pretty.py | 181 + src/ai/backend/client/cli/proxy.py | 233 + src/ai/backend/client/cli/run.py | 726 +++ src/ai/backend/client/cli/server_log.py | 43 + 
src/ai/backend/client/cli/session.py | 863 ++++ src/ai/backend/client/cli/session_template.py | 156 + src/ai/backend/client/cli/ssh.py | 60 + src/ai/backend/client/cli/types.py | 22 + src/ai/backend/client/cli/vfolder.py | 681 +++ src/ai/backend/client/compat.py | 98 + src/ai/backend/client/config.py | 388 ++ src/ai/backend/client/exceptions.py | 68 + src/ai/backend/client/func/__init__.py | 0 src/ai/backend/client/func/admin.py | 75 + src/ai/backend/client/func/agent.py | 174 + src/ai/backend/client/func/auth.py | 60 + src/ai/backend/client/func/base.py | 65 + src/ai/backend/client/func/bgtask.py | 35 + src/ai/backend/client/func/domain.py | 193 + src/ai/backend/client/func/dotfile.py | 143 + src/ai/backend/client/func/etcd.py | 71 + src/ai/backend/client/func/group.py | 276 ++ src/ai/backend/client/func/image.py | 118 + src/ai/backend/client/func/keypair.py | 265 ++ .../client/func/keypair_resource_policy.py | 184 + src/ai/backend/client/func/manager.py | 105 + src/ai/backend/client/func/resource.py | 120 + src/ai/backend/client/func/scaling_group.py | 312 ++ src/ai/backend/client/func/server_log.py | 43 + src/ai/backend/client/func/session.py | 993 +++++ .../backend/client/func/session_template.py | 82 + src/ai/backend/client/func/storage.py | 86 + src/ai/backend/client/func/system.py | 37 + src/ai/backend/client/func/user.py | 359 ++ src/ai/backend/client/func/vfolder.py | 473 ++ src/ai/backend/client/helper.py | 5 + src/ai/backend/client/load_balancing.py | 70 + src/ai/backend/client/output/__init__.py | 18 + src/ai/backend/client/output/console.py | 213 + src/ai/backend/client/output/fields.py | 262 ++ src/ai/backend/client/output/formatters.py | 305 ++ src/ai/backend/client/output/json.py | 192 + src/ai/backend/client/output/types.py | 204 + src/ai/backend/client/pagination.py | 90 + src/ai/backend/client/py.typed | 1 + src/ai/backend/client/request.py | 885 ++++ src/ai/backend/client/session.py | 490 +++ src/ai/backend/client/test_utils.py | 91 + src/ai/backend/client/types.py | 28 + src/ai/backend/client/utils.py | 51 + src/ai/backend/client/versioning.py | 52 + src/ai/backend/common/BUILD | 35 + src/ai/backend/common/README.md | 39 + src/ai/backend/common/VERSION | 1 + src/ai/backend/common/__init__.py | 3 + src/ai/backend/common/argparse.py | 101 + src/ai/backend/common/asyncio.py | 124 + src/ai/backend/common/bgtask.py | 284 ++ src/ai/backend/common/cli.py | 101 + src/ai/backend/common/config.py | 163 + src/ai/backend/common/distributed.py | 75 + src/ai/backend/common/docker.py | 399 ++ src/ai/backend/common/enum_extension.py | 57 + src/ai/backend/common/enum_extension.pyi | 22 + src/ai/backend/common/etcd.py | 482 ++ src/ai/backend/common/events.py | 880 ++++ src/ai/backend/common/exception.py | 50 + src/ai/backend/common/files.py | 69 + src/ai/backend/common/identity.py | 245 ++ src/ai/backend/common/json.py | 12 + src/ai/backend/common/lock.py | 157 + src/ai/backend/common/logging.py | 509 +++ src/ai/backend/common/logging_utils.py | 24 + src/ai/backend/common/msgpack.py | 17 + src/ai/backend/common/networking.py | 59 + src/ai/backend/common/plugin/__init__.py | 177 + src/ai/backend/common/plugin/hook.py | 188 + src/ai/backend/common/plugin/monitor.py | 96 + src/ai/backend/common/plugin/py.typed | 1 + src/ai/backend/common/py.typed | 1 + src/ai/backend/common/redis.py | 500 +++ src/ai/backend/common/sd_notify.py | 119 + src/ai/backend/common/service_ports.py | 60 + src/ai/backend/common/testutils.py | 49 + src/ai/backend/common/types.py | 848 ++++ 
src/ai/backend/common/utils.py | 295 ++ src/ai/backend/common/validators.py | 616 +++ src/ai/backend/helpers/BUILD | 31 + src/ai/backend/helpers/README.md | 1 + src/ai/backend/helpers/VERSION | 1 + src/ai/backend/helpers/__init__.py | 9 + src/ai/backend/helpers/package.py | 44 + src/ai/backend/kernel/BUILD | 42 + src/ai/backend/kernel/README.md | 1 + src/ai/backend/kernel/VERSION | 1 + src/ai/backend/kernel/__init__.py | 30 + src/ai/backend/kernel/__main__.py | 50 + src/ai/backend/kernel/app/__init__.py | 57 + src/ai/backend/kernel/base.py | 841 ++++ src/ai/backend/kernel/c/__init__.py | 90 + src/ai/backend/kernel/compat.py | 103 + src/ai/backend/kernel/cpp/__init__.py | 87 + src/ai/backend/kernel/exception.py | 16 + src/ai/backend/kernel/git/__init__.py | 135 + src/ai/backend/kernel/golang/__init__.py | 72 + src/ai/backend/kernel/haskell/__init__.py | 64 + src/ai/backend/kernel/intrinsic.py | 144 + src/ai/backend/kernel/java/LablupPatches.java | 57 + src/ai/backend/kernel/java/__init__.py | 110 + src/ai/backend/kernel/julia/__init__.py | 75 + src/ai/backend/kernel/jupyter_client.py | 63 + src/ai/backend/kernel/logging.py | 52 + src/ai/backend/kernel/lua/__init__.py | 59 + src/ai/backend/kernel/nodejs/__init__.py | 63 + src/ai/backend/kernel/octave/__init__.py | 60 + src/ai/backend/kernel/php/__init__.py | 61 + src/ai/backend/kernel/python/__init__.py | 142 + .../backend/kernel/python/drawing/__init__.py | 7 + .../backend/kernel/python/drawing/canvas.py | 212 + src/ai/backend/kernel/python/drawing/color.py | 63 + .../backend/kernel/python/drawing/encoding.py | 12 + .../backend/kernel/python/drawing/turtle.py | 115 + src/ai/backend/kernel/python/sitecustomize.py | 46 + src/ai/backend/kernel/python/types.py | 28 + src/ai/backend/kernel/r/__init__.py | 69 + src/ai/backend/kernel/r_server_ms/__init__.py | 118 + src/ai/backend/kernel/requirements.txt | 8 + src/ai/backend/kernel/rust/__init__.py | 76 + src/ai/backend/kernel/scheme/__init__.py | 52 + src/ai/backend/kernel/service.py | 154 + src/ai/backend/kernel/service_actions.py | 72 + src/ai/backend/kernel/terminal.py | 206 + src/ai/backend/kernel/test_utils.py | 33 + src/ai/backend/kernel/utils.py | 57 + src/ai/backend/kernel/vendor/__init__.py | 0 .../kernel/vendor/aws_polly/__init__.py | 97 + .../backend/kernel/vendor/aws_polly/inproc.py | 70 + src/ai/backend/kernel/vendor/h2o/__init__.py | 98 + src/ai/backend/manager/BUILD | 75 + src/ai/backend/manager/README.md | 263 ++ src/ai/backend/manager/VERSION | 1 + src/ai/backend/manager/__init__.py | 3 + src/ai/backend/manager/api/__init__.py | 8 + src/ai/backend/manager/api/admin.py | 173 + src/ai/backend/manager/api/auth.py | 981 +++++ .../backend/manager/api/cluster_template.py | 399 ++ src/ai/backend/manager/api/context.py | 52 + src/ai/backend/manager/api/domainconfig.py | 196 + src/ai/backend/manager/api/etcd.py | 154 + src/ai/backend/manager/api/events.py | 417 ++ src/ai/backend/manager/api/exceptions.py | 418 ++ src/ai/backend/manager/api/groupconfig.py | 283 ++ src/ai/backend/manager/api/image.py | 472 ++ src/ai/backend/manager/api/logs.py | 287 ++ src/ai/backend/manager/api/manager.py | 272 ++ src/ai/backend/manager/api/py.typed | 1 + src/ai/backend/manager/api/ratelimit.py | 116 + src/ai/backend/manager/api/resource.py | 796 ++++ src/ai/backend/manager/api/scaling_group.py | 125 + src/ai/backend/manager/api/session.py | 2278 ++++++++++ .../backend/manager/api/session_template.py | 342 ++ src/ai/backend/manager/api/stream.py | 743 ++++ src/ai/backend/manager/api/types.py | 36 + 
src/ai/backend/manager/api/userconfig.py | 241 + src/ai/backend/manager/api/utils.py | 372 ++ src/ai/backend/manager/api/vfolder.py | 2380 ++++++++++ src/ai/backend/manager/api/wsproxy.py | 257 ++ src/ai/backend/manager/cli/__init__.py | 0 src/ai/backend/manager/cli/__main__.py | 273 ++ src/ai/backend/manager/cli/context.py | 86 + src/ai/backend/manager/cli/dbschema.py | 94 + src/ai/backend/manager/cli/etcd.py | 317 ++ src/ai/backend/manager/cli/fixture.py | 67 + src/ai/backend/manager/cli/gql.py | 33 + src/ai/backend/manager/cli/image.py | 112 + src/ai/backend/manager/cli/image_impl.py | 152 + src/ai/backend/manager/config.py | 594 +++ .../manager/container_registry/__init__.py | 29 + .../manager/container_registry/base.py | 315 ++ .../manager/container_registry/docker.py | 92 + .../manager/container_registry/harbor.py | 109 + src/ai/backend/manager/defs.py | 51 + src/ai/backend/manager/exceptions.py | 104 + src/ai/backend/manager/idle.py | 637 +++ src/ai/backend/manager/models/__init__.py | 52 + src/ai/backend/manager/models/agent.py | 398 ++ src/ai/backend/manager/models/alembic/README | 1 + src/ai/backend/manager/models/alembic/env.py | 79 + .../manager/models/alembic/script.py.mako | 24 + ...4_add_idle_timeout_to_keypair_resource_.py | 28 + .../versions/015d84d5a5ef_add_image_table.py | 59 + ...62e50e90e0_add_ssh_keypair_into_keypair.py | 69 + .../02950808ca3d_add_agent_version.py | 38 + .../06184d82a211_add_session_creation_id.py | 24 + .../0c5733f80e4d_index_kernel_timestamps.py | 46 + ...eplace_is_active_to_status_and_its_info.py | 66 + .../0e558d06e0e3_add_service_ports.py | 24 + .../versions/0f3bc98edaa0_more_status.py | 71 + .../models/alembic/versions/0f7a4b643940_.py | 24 + ...eed5_enlarge_kernels_lang_column_length.py | 32 + .../11146ba02235_change_char_col_to_str.py | 29 + ...852ff9872_add_vfolder_permissions_table.py | 39 + ..._add_clusterized_column_to_agents_table.py | 27 + ...31583e20_add_dotfile_column_to_keypairs.py | 26 + ...a31ea8e3_add_inviter_field_for_vfolder_.py | 28 + ...2b6dcbc159_add_internal_data_to_kernels.py | 28 + ...5c12b_add_total_resource_slots_to_group.py | 49 + ...dd_allowed_docker_registries_in_domains.py | 46 + .../versions/250e8656cf45_add_status_data.py | 24 + ...0fa1_add_dotfiles_to_domains_and_groups.py | 36 + ...87e764_create_vfolder_invitations_table.py | 39 + ...82340fa30e_add_mounts_info_in_kernel_db.py | 28 + ...a059_convert_lang_to_image_and_registry.py | 31 + ...2fa4f88f61_add_tpu_slot_on_kernel_model.py | 32 + .../3bb80d1887d6_add_preopen_ports.py | 28 + .../models/alembic/versions/3cf19d906e71_.py | 24 + .../alembic/versions/3f1dafab60b2_merge.py | 24 + .../versions/405aa2c39458_job_queue.py | 192 + .../4545f5c948b3_add_io_scratch_size_stats.py | 25 + ...ab2dfefba9_reindex_kernel_updated_order.py | 30 + .../4b7b650bc30e_add_creator_in_vfolders.py | 28 + .../versions/4b8a66fb8d82_revamp_keypairs.py | 30 + .../versions/4cc87e7fbfdf_stats_refactor.py | 216 + ...164749de4_add_cancelled_to_kernelstatus.py | 92 + ...518ecf41f567_add_index_for_cluster_role.py | 28 + ...d79aa21_add_logs_column_on_kernel_table.py | 24 + .../529113b08c2c_add_vfolder_type_column.py | 128 + ...a49c8_update_cluster_columns_in_kernels.py | 55 + .../versions/57b523dec0e8_add_tpu_slots.py | 34 + ...03287_rename_clone_allowed_to_cloneable.py | 22 + ...45f28d2cac_add_resource_opts_in_kernels.py | 30 + ...e6043455e_add_user_group_ids_in_vfolder.py | 162 + .../alembic/versions/5de06da3c2b5_init.py | 76 + ...0_add_unmanaged_path_column_to_vfolders.py | 26 + 
...77d2_add_coordinator_address_column_on_.py | 24 + .../models/alembic/versions/65c4a109bbc7_.py | 22 + ...erge_user_s_first__last_name_into_full_.py | 34 + ...7_vfolder_invitation_state_to_enum_type.py | 45 + .../versions/7a82e0c70122_add_group_model.py | 54 + ...1d81c3204_add_vfolder_mounts_to_kernels.py | 29 + .../7ea324d0535b_vfolder_and_kernel.py | 114 + .../80176413d8aa_keypairs_get_is_admin.py | 30 + .../versions/819c2b3830a9_add_user_model.py | 136 + .../81c264528f20_add_max_session_lifetime.py | 24 + ...4bd902b1bc_change_kernel_identification.py | 67 + ...9d0a7e22b_add_scheduled_to_kernelstatus.py | 96 + .../8e660aa31fe3_add_resource_presets.py | 98 + ...80bc9_add_architecture_column_on_agents.py | 34 + .../versions/93e9d31d40bf_agent_add_region.py | 41 + .../alembic/versions/97f6c80c8aa5_merge.py | 24 + .../9a91532c8534_add_scaling_group.py | 131 + ...2a_allow_kernels_scaling_group_nullable.py | 34 + ...2_add_attached_devices_field_in_kernels.py | 30 + ...1b1ae70d_add_scheduable_field_to_agents.py | 34 + .../a1fd4e7b7782_enumerate_vfolder_perms.py | 49 + .../alembic/versions/a7ca9f175d5f_merge.py | 24 + ...bc74594aa6_add_partial_index_to_kernels.py | 31 + .../versions/bae1a7326e8a_add_domain_model.py | 79 + .../versions/bf4bae8f942e_add_kernel_host.py | 24 + .../c092dabf3ee5_add_batch_session.py | 43 + .../models/alembic/versions/c1409ad0e8da_.py | 24 + .../c3e74dcf1808_add_environ_to_kernels.py | 23 + ...dd_allowed_vfolder_hosts_to_domain_and_.py | 42 + ...d_add_shared_memory_to_resource_presets.py | 28 + ...add_domain_group_user_fields_to_kernels.py | 125 + ...ce209920f654_create_task_template_table.py | 49 + .../d2aafa234374_create_error_logs_table.py | 44 + ...bacd085c_add_mount_map_column_to_kernel.py | 26 + ...3fc5d6109_add_clone_allowed_to_vfolders.py | 30 + ...f5ec9ef3_convert_cpu_gpu_slots_to_float.py | 75 + .../d582942886ad_add_tag_to_kernels.py | 28 + ...89e7514_remove_keypair_concurrency_used.py | 26 + ...36b5_update_for_multicontainer_sessions.py | 104 + .../models/alembic/versions/d643752544de_.py | 24 + ...727b5da20e6_add_callback_url_to_kernels.py | 26 + ...520049_add_starts_at_field_into_kernels.py | 28 + ...bc1e053b880_add_keypair_resource_policy.py | 97 + .../dc9b66466e43_remove_clusterized.py | 24 + ...8ed5fcfedf_add_superadmin_role_for_user.py | 78 + ...3d_add_modified_at_to_users_and_kernels.py | 78 + ..._rename_kernel_dependencies_to_session_.py | 51 + .../versions/e7371ca5797a_rename_mem_stats.py | 24 + ...476f39_add_bootstrap_script_to_keypairs.py | 31 + .../eec98e65902a_merge_with_vfolder_clone.py | 24 + .../f0f4ee907155_dynamic_resource_slots.py | 210 + ...30eccf202_add_kernels_uuid_prefix_index.py | 31 + .../versions/f8a71c3bffa2_stringify_userid.py | 79 + ...add_state_column_to_vfolder_invitations.py | 30 + .../models/alembic/versions/ff4bfca66bf8_.py | 22 + src/ai/backend/manager/models/base.py | 845 ++++ src/ai/backend/manager/models/domain.py | 381 ++ src/ai/backend/manager/models/dotfile.py | 84 + src/ai/backend/manager/models/error_logs.py | 26 + src/ai/backend/manager/models/gql.py | 1294 ++++++ src/ai/backend/manager/models/group.py | 665 +++ src/ai/backend/manager/models/image.py | 910 ++++ src/ai/backend/manager/models/kernel.py | 1542 +++++++ src/ai/backend/manager/models/keypair.py | 616 +++ .../manager/models/minilang/__init__.py | 8 + .../manager/models/minilang/ordering.py | 82 + .../manager/models/minilang/queryfilter.py | 196 + .../backend/manager/models/resource_policy.py | 334 ++ .../backend/manager/models/resource_preset.py | 195 + 
.../backend/manager/models/scaling_group.py | 707 +++ .../manager/models/session_template.py | 246 ++ src/ai/backend/manager/models/storage.py | 303 ++ src/ai/backend/manager/models/user.py | 1151 +++++ src/ai/backend/manager/models/utils.py | 306 ++ src/ai/backend/manager/models/vfolder.py | 823 ++++ src/ai/backend/manager/pglock.py | 29 + src/ai/backend/manager/plugin/__init__.py | 0 .../backend/manager/plugin/error_monitor.py | 120 + src/ai/backend/manager/plugin/exceptions.py | 15 + src/ai/backend/manager/plugin/webapp.py | 31 + src/ai/backend/manager/py.typed | 1 + src/ai/backend/manager/registry.py | 2827 ++++++++++++ src/ai/backend/manager/scheduler/__init__.py | 0 .../backend/manager/scheduler/dispatcher.py | 1019 +++++ src/ai/backend/manager/scheduler/drf.py | 146 + src/ai/backend/manager/scheduler/fifo.py | 166 + src/ai/backend/manager/scheduler/mof.py | 71 + .../backend/manager/scheduler/predicates.py | 299 ++ src/ai/backend/manager/scheduler/types.py | 447 ++ src/ai/backend/manager/server.py | 759 ++++ src/ai/backend/manager/types.py | 48 + src/ai/backend/meta/BUILD | 3 + src/ai/backend/meta/__init__.py | 5 - src/ai/backend/plugin/BUILD | 34 + src/ai/backend/plugin/README.md | 7 + src/ai/backend/plugin/VERSION | 1 + src/ai/backend/plugin/__init__.py | 3 + src/ai/backend/plugin/entrypoint.py | 99 + src/ai/backend/plugin/py.typed | 1 + src/ai/backend/runner/.bash_profile | 6 + src/ai/backend/runner/.bashrc | 7 + src/ai/backend/runner/.dockerignore | 3 + src/ai/backend/runner/.tmux.conf | 82 + src/ai/backend/runner/.vimrc | 10 + src/ai/backend/runner/BUILD | 77 + .../DO_NOT_STORE_PERSISTENT_FILES_HERE.md | 9 + src/ai/backend/runner/README.md | 2 + src/ai/backend/runner/VERSION | 1 + src/ai/backend/runner/__init__.py | 0 .../backend/runner/dropbear.glibc.aarch64.bin | 3 + .../backend/runner/dropbear.glibc.x86_64.bin | 3 + .../backend/runner/dropbear.musl.aarch64.bin | 3 + .../backend/runner/dropbear.musl.x86_64.bin | 3 + .../runner/dropbearconvert.glibc.aarch64.bin | 3 + .../runner/dropbearconvert.glibc.x86_64.bin | 3 + .../runner/dropbearconvert.musl.aarch64.bin | 3 + .../runner/dropbearconvert.musl.x86_64.bin | 3 + .../runner/dropbearkey.glibc.aarch64.bin | 3 + .../runner/dropbearkey.glibc.x86_64.bin | 3 + .../runner/dropbearkey.musl.aarch64.bin | 3 + .../runner/dropbearkey.musl.x86_64.bin | 3 + src/ai/backend/runner/entrypoint.sh | 89 + src/ai/backend/runner/extract_dotfiles.py | 30 + src/ai/backend/runner/jail.alpine3.8.bin | 3 + src/ai/backend/runner/jail.ubuntu16.04.bin | 3 + src/ai/backend/runner/jupyter-custom.css | 275 ++ .../runner/krunner-extractor.dockerfile | 5 + .../krunner-extractor.img.aarch64.tar.xz | 3 + .../krunner-extractor.img.x86_64.tar.xz | 3 + src/ai/backend/runner/krunner-extractor.sh | 4 + .../runner/libbaihook.alpine3.8.aarch64.so | 3 + .../runner/libbaihook.alpine3.8.x86_64.so | 3 + .../runner/libbaihook.centos7.6.aarch64.so | 3 + .../runner/libbaihook.centos7.6.x86_64.so | 3 + .../runner/libbaihook.ubuntu18.04.x86_64.so | 3 + .../runner/libbaihook.ubuntu20.04.aarch64.so | 3 + .../runner/libbaihook.ubuntu20.04.x86_64.so | 3 + src/ai/backend/runner/logo.svg | 1 + src/ai/backend/runner/requirements.txt | 1 + src/ai/backend/runner/roboto-italic.ttf | 3 + src/ai/backend/runner/roboto.ttf | 3 + .../backend/runner/scp.alpine3.8.aarch64.bin | 3 + .../backend/runner/scp.alpine3.8.x86_64.bin | 3 + .../backend/runner/scp.centos7.6.aarch64.bin | 3 + .../backend/runner/scp.centos7.6.x86_64.bin | 3 + .../runner/scp.ubuntu16.04.aarch64.bin | 3 + 
.../backend/runner/scp.ubuntu16.04.x86_64.bin | 3 + .../runner/scp.ubuntu18.04.aarch64.bin | 3 + .../backend/runner/scp.ubuntu18.04.x86_64.bin | 3 + .../runner/scp.ubuntu20.04.aarch64.bin | 3 + .../backend/runner/scp.ubuntu20.04.x86_64.bin | 3 + .../runner/sftp-server.alpine3.8.aarch64.bin | 3 + .../runner/sftp-server.alpine3.8.x86_64.bin | 3 + .../runner/sftp-server.centos7.6.aarch64.bin | 3 + .../runner/sftp-server.centos7.6.x86_64.bin | 3 + .../sftp-server.ubuntu16.04.aarch64.bin | 3 + .../runner/sftp-server.ubuntu16.04.x86_64.bin | 3 + .../sftp-server.ubuntu18.04.aarch64.bin | 3 + .../runner/sftp-server.ubuntu18.04.x86_64.bin | 3 + .../sftp-server.ubuntu20.04.aarch64.bin | 3 + .../runner/sftp-server.ubuntu20.04.x86_64.bin | 3 + .../runner/su-exec.alpine3.8.aarch64.bin | 3 + .../runner/su-exec.alpine3.8.x86_64.bin | 3 + .../runner/su-exec.centos7.6.aarch64.bin | 3 + .../runner/su-exec.centos7.6.x86_64.bin | 3 + .../runner/su-exec.ubuntu16.04.aarch64.bin | 3 + .../runner/su-exec.ubuntu16.04.x86_64.bin | 3 + .../runner/su-exec.ubuntu18.04.aarch64.bin | 3 + .../runner/su-exec.ubuntu18.04.x86_64.bin | 3 + .../runner/su-exec.ubuntu20.04.aarch64.bin | 3 + .../runner/su-exec.ubuntu20.04.x86_64.bin | 3 + .../runner/terminfo.alpine3.8/s/screen | Bin 0 -> 1660 bytes .../terminfo.alpine3.8/s/screen-256color | Bin 0 -> 1995 bytes .../s/screen.xterm-256color | Bin 0 -> 3459 bytes .../backend/runner/terminfo.alpine3.8/x/xterm | Bin 0 -> 3617 bytes .../terminfo.alpine3.8/x/xterm+256color | Bin 0 -> 1090 bytes .../terminfo.alpine3.8/x/xterm-256color | Bin 0 -> 3713 bytes src/ai/backend/runner/tmux.glibc.aarch64.bin | 3 + src/ai/backend/runner/tmux.glibc.x86_64.bin | 3 + src/ai/backend/runner/tmux.musl.aarch64.bin | 3 + src/ai/backend/runner/tmux.musl.x86_64.bin | 3 + src/ai/backend/storage/BUILD | 44 + src/ai/backend/storage/README.md | 162 + src/ai/backend/storage/VERSION | 1 + src/ai/backend/storage/__init__.py | 3 + src/ai/backend/storage/abc.py | 251 ++ src/ai/backend/storage/api/__init__.py | 0 src/ai/backend/storage/api/client.py | 348 ++ src/ai/backend/storage/api/manager.py | 658 +++ src/ai/backend/storage/config.py | 84 + src/ai/backend/storage/context.py | 66 + src/ai/backend/storage/exception.py | 50 + src/ai/backend/storage/filelock.py | 50 + src/ai/backend/storage/netapp/__init__.py | 394 ++ src/ai/backend/storage/netapp/netappclient.py | 222 + src/ai/backend/storage/netapp/quotamanager.py | 165 + .../backend/storage/purestorage/__init__.py | 267 ++ src/ai/backend/storage/purestorage/purity.py | 151 + src/ai/backend/storage/py.typed | 1 + src/ai/backend/storage/server.py | 222 + src/ai/backend/storage/types.py | 104 + src/ai/backend/storage/utils.py | 129 + src/ai/backend/storage/vfs/__init__.py | 482 ++ src/ai/backend/storage/xfs/__init__.py | 255 ++ src/ai/backend/test/BUILD | 49 + src/ai/backend/test/VERSION | 1 + src/ai/backend/test/__init__.py | 3 + src/ai/backend/test/cli/__init__.py | 0 src/ai/backend/test/cli/__main__.py | 56 + src/ai/backend/test/cli/context.py | 6 + src/ai/backend/test/cli/utils.py | 30 + .../backend/test/cli_integration/__init__.py | 12 + .../test/cli_integration/admin/__init__.py | 9 + .../test/cli_integration/admin/test_domain.py | 102 + .../test/cli_integration/admin/test_group.py | 26 + .../test/cli_integration/admin/test_image.py | 22 + .../cli_integration/admin/test_keypair.py | 179 + .../admin/test_keypair_resource_policy.py | 116 + .../admin/test_scaling_group.py | 101 + .../cli_integration/admin/test_storage.py | 30 + 
.../test/cli_integration/admin/test_user.py | 206 + .../backend/test/cli_integration/conftest.py | 72 + .../test/cli_integration/user/__init__.py | 0 .../test/cli_integration/user/test_vfolder.py | 102 + src/ai/backend/test/py.typed | 1 + src/ai/backend/test/utils/__init__.py | 0 src/ai/backend/test/utils/cli.py | 39 + src/ai/backend/testutils/BUILD | 6 + src/ai/backend/testutils/__init__.py | 0 src/ai/backend/testutils/bootstrap.py | 113 + src/ai/backend/testutils/pants.py | 9 + src/ai/backend/web/BUILD | 45 + src/ai/backend/web/README.md | 64 + src/ai/backend/web/VERSION | 1 + src/ai/backend/web/__init__.py | 5 + src/ai/backend/web/auth.py | 66 + src/ai/backend/web/logging.py | 24 + src/ai/backend/web/proxy.py | 316 ++ src/ai/backend/web/py.typed | 1 + src/ai/backend/web/server.py | 694 +++ src/ai/backend/web/static | 1 + stubs/trafaret/BUILD | 3 + stubs/trafaret/__init__.pyi | 76 + stubs/trafaret/base.pyi | 94 + stubs/trafaret/constructor.pyi | 12 + stubs/trafaret/dataerror.pyi | 11 + stubs/trafaret/internet.pyi | 13 + stubs/trafaret/keys.pyi | 15 + stubs/trafaret/lib.pyi | 1 + stubs/trafaret/numeric.pyi | 15 + stubs/trafaret/regexp.pyi | 9 + tests/agent/BUILD | 11 + tests/agent/__init__.py | 0 tests/agent/conftest.py | 164 + tests/agent/docker/BUILD | 6 + tests/agent/docker/test_agent.py | 211 + tests/agent/test_agent.py | 64 + tests/agent/test_alloc_map.py | 874 ++++ tests/agent/test_files.py | 105 + tests/agent/test_kernel.py | 145 + tests/agent/test_resources.py | 91 + tests/agent/test_server.py | 370 ++ tests/agent/test_stats.py | 5 + tests/agent/test_utils.py | 90 + tests/common/BUILD | 44 + tests/common/__init__.py | 0 tests/common/conftest.py | 150 + tests/common/redis/.gitignore | 1 + tests/common/redis/__init__.py | 0 tests/common/redis/conftest.py | 37 + tests/common/redis/docker.py | 217 + tests/common/redis/native.py | 185 + tests/common/redis/redis-cluster.yml | 90 + tests/common/redis/redis-sentinel.dockerfile | 9 + tests/common/redis/sentinel.conf | 7 + tests/common/redis/test_connect.py | 114 + tests/common/redis/test_list.py | 265 ++ tests/common/redis/test_pipeline.py | 134 + tests/common/redis/test_pubsub.py | 256 ++ tests/common/redis/test_stream.py | 395 ++ tests/common/redis/types.py | 57 + tests/common/redis/utils.py | 113 + tests/common/test_argparse.py | 145 + tests/common/test_config.py | 89 + tests/common/test_distributed.py | 303 ++ tests/common/test_docker.py | 292 ++ tests/common/test_etcd.py | 297 ++ tests/common/test_events.py | 179 + tests/common/test_identity.py | 199 + tests/common/test_json.py | 19 + tests/common/test_logging.py | 81 + tests/common/test_msgpack.py | 29 + tests/common/test_plugin.py | 294 ++ tests/common/test_service_ports.py | 82 + tests/common/test_types.py | 288 ++ tests/common/test_utils.py | 329 ++ tests/common/test_validators.py | 469 ++ tests/manager/BUILD | 24 + tests/manager/__init__.py | 0 tests/manager/api/BUILD | 6 + tests/manager/api/test_auth.py | 126 + tests/manager/api/test_bgtask.py | 124 + tests/manager/api/test_config.py | 21 + tests/manager/api/test_exceptions.py | 52 + tests/manager/api/test_middlewares.py | 134 + tests/manager/api/test_ratelimit.py | 67 + tests/manager/api/test_utils.py | 85 + tests/manager/conftest.py | 714 +++ tests/manager/fixtures/example-keypairs.json | 172 + .../fixtures/example-resource-presets.json | 33 + .../fixtures/example-session-templates.json | 71 + tests/manager/model_factory.py | 194 + tests/manager/models/BUILD | 6 + tests/manager/models/test_dbutils.py | 47 + 
tests/manager/sample-ssl-cert/sample.crt | 15 + tests/manager/sample-ssl-cert/sample.csr | 11 + tests/manager/sample-ssl-cert/sample.key | 15 + tests/manager/test_advisory_lock.py | 68 + tests/manager/test_image.py | 95 + tests/manager/test_predicates.py | 162 + tests/manager/test_queryfilter.py | 301 ++ tests/manager/test_queryorder.py | 115 + tests/manager/test_registry.py | 157 + tests/manager/test_scheduler.py | 1110 +++++ tests/plugin/BUILD | 10 + tests/plugin/test_entrypoint.py | 63 + tests/storage-proxy/BUILD | 10 + tests/storage-proxy/conftest.py | 17 + tests/storage-proxy/test_netapp.py | 73 + tests/storage-proxy/test_purestorage.py | 89 + tests/storage-proxy/test_vfs.py | 123 + tests/storage-proxy/test_xfs.py | 267 ++ tests/test_dummy.py | 8 - tests/webserver/BUILD | 10 + tests/webserver/conftest.py | 0 tests/webserver/test_auth.py | 99 + tools/flake8.lock | 208 + tools/mypy.lock | 154 + tools/pants-linux-aarch64.patch | 68 + tools/pants-local | 28 + tools/pants-plugins/platform_resources/BUILD | 1 + .../platform_resources/register.py | 102 + tools/pants-plugins/setupgen/BUILD | 1 + tools/pants-plugins/setupgen/register.py | 130 + tools/pytest.lock | 955 ++++ 743 files changed, 105502 insertions(+), 599 deletions(-) create mode 100644 .flake8 create mode 100644 .gitattributes create mode 100644 .github/workflows/timeline-check.yml create mode 100644 .gitmodules create mode 100644 BUILD create mode 100644 BUILD_ROOT create mode 100644 MIGRATION.md create mode 100644 VERSION create mode 100755 backend.ai create mode 100644 changes/417.misc create mode 100644 changes/template.md create mode 100644 configs/agent/ci.toml create mode 100644 configs/agent/halfstack.toml create mode 100644 configs/agent/sample.toml create mode 100644 configs/manager/ci.toml create mode 100644 configs/manager/halfstack.alembic.ini create mode 100644 configs/manager/halfstack.toml create mode 100644 configs/manager/sample.etcd.config.json create mode 100644 configs/manager/sample.etcd.redis-sentinel.json create mode 100644 configs/manager/sample.etcd.redis-single.json create mode 100644 configs/manager/sample.etcd.volumes.json create mode 100644 configs/manager/sample.toml create mode 100644 configs/storage-proxy/sample.toml create mode 100644 configs/storage-proxy/ssl/manager-api-selfsigned.cert.pem create mode 100644 configs/storage-proxy/ssl/manager-api-selfsigned.key.pem create mode 100644 configs/webserver/sample.conf create mode 100644 docker/linuxkit-nsenter/Dockerfile create mode 100644 docker/socket-relay/Dockerfile create mode 100644 fixtures/manager/example-keypairs.json create mode 100644 fixtures/manager/example-resource-presets.json create mode 100755 fixtures/manager/example-session-templates.json create mode 100755 pants create mode 100644 pants.ci.toml create mode 100644 pants.toml create mode 100644 plugins/.gitignore create mode 100644 plugins/README.md create mode 100755 py create mode 100644 pyproject.toml create mode 100644 python-kernel.lock create mode 100644 python.lock create mode 100644 requirements.txt delete mode 100644 requirements/build.txt delete mode 100644 requirements/dev.txt delete mode 100644 requirements/docs.txt delete mode 100644 requirements/lint.txt delete mode 100644 requirements/main.txt delete mode 100644 requirements/test.txt delete mode 100644 requirements/typecheck.txt create mode 100755 scripts/agent/build-dropbear.sh create mode 100755 scripts/agent/build-krunner-extractor.sh create mode 100755 scripts/agent/build-sftpserver.sh create mode 100755 
scripts/agent/build-socket-relay.sh create mode 100755 scripts/agent/build-suexec.sh create mode 100755 scripts/agent/build-tmux.sh create mode 100755 scripts/agent/ci/deploy.sh create mode 100755 scripts/agent/ci/install-manager.sh create mode 100755 scripts/agent/ci/overwrite-python.sh create mode 100755 scripts/agent/ci/run-manager.sh create mode 100644 scripts/agent/deploy-static/deploy_static_files.py create mode 100644 scripts/agent/eks.yml create mode 100755 scripts/agent/run-with-halfstack.sh create mode 100755 scripts/agent/update-metadata-iptables.sh create mode 100644 scripts/check-docker.py create mode 100755 scripts/install-plugin.sh create mode 100755 scripts/reinstall-plugins.sh create mode 100644 scripts/specialize.py delete mode 100644 setup.cfg delete mode 100644 setup.py create mode 100644 src/ai/backend/agent/BUILD create mode 100644 src/ai/backend/agent/README.k8s.md create mode 100644 src/ai/backend/agent/README.md create mode 120000 src/ai/backend/agent/VERSION create mode 100644 src/ai/backend/agent/__init__.py create mode 100644 src/ai/backend/agent/agent.py create mode 100644 src/ai/backend/agent/cli.py create mode 100644 src/ai/backend/agent/config.py create mode 100644 src/ai/backend/agent/docker/__init__.py create mode 100644 src/ai/backend/agent/docker/agent.py create mode 100644 src/ai/backend/agent/docker/backendai-socket-relay.img.aarch64.tar.gz create mode 100644 src/ai/backend/agent/docker/backendai-socket-relay.img.x86_64.tar.gz create mode 100644 src/ai/backend/agent/docker/backendai-socket-relay.version.txt create mode 100644 src/ai/backend/agent/docker/files.py create mode 100644 src/ai/backend/agent/docker/intrinsic.py create mode 100644 src/ai/backend/agent/docker/kernel.py create mode 100755 src/ai/backend/agent/docker/linuxkit-metadata-proxy-worker.aarch64.bin create mode 100755 src/ai/backend/agent/docker/linuxkit-metadata-proxy-worker.x86_64.bin create mode 100644 src/ai/backend/agent/docker/linuxkit-nsenter.img.aarch64.tar.gz create mode 100644 src/ai/backend/agent/docker/linuxkit-nsenter.img.x86_64.tar.gz create mode 100644 src/ai/backend/agent/docker/linuxkit-nsenter.version.txt create mode 100644 src/ai/backend/agent/docker/metadata/server.py create mode 100644 src/ai/backend/agent/docker/resources.py create mode 100644 src/ai/backend/agent/docker/utils.py create mode 100644 src/ai/backend/agent/exception.py create mode 100644 src/ai/backend/agent/fs.py create mode 100644 src/ai/backend/agent/kernel.py create mode 100644 src/ai/backend/agent/kubernetes/__init__.py create mode 100644 src/ai/backend/agent/kubernetes/agent.py create mode 100644 src/ai/backend/agent/kubernetes/backendai-socket-relay.img.tar.gz create mode 100644 src/ai/backend/agent/kubernetes/backendai-socket-relay.version.txt create mode 100644 src/ai/backend/agent/kubernetes/files.py create mode 100644 src/ai/backend/agent/kubernetes/intrinsic.py create mode 100644 src/ai/backend/agent/kubernetes/kernel.py create mode 100644 src/ai/backend/agent/kubernetes/kube_object.py create mode 100644 src/ai/backend/agent/kubernetes/resources.py create mode 100644 src/ai/backend/agent/kubernetes/utils.py create mode 100644 src/ai/backend/agent/linuxkit-metadata-proxy/main.go create mode 100644 src/ai/backend/agent/proxy.py create mode 100644 src/ai/backend/agent/py.typed create mode 100644 src/ai/backend/agent/resources.py create mode 100644 src/ai/backend/agent/server.py create mode 100644 src/ai/backend/agent/stats.py create mode 100644 src/ai/backend/agent/types.py create mode 100644 
src/ai/backend/agent/utils.py create mode 100644 src/ai/backend/agent/vendor/__init__.py create mode 100644 src/ai/backend/agent/vendor/linux.py create mode 100644 src/ai/backend/agent/watcher.py create mode 100644 src/ai/backend/cli/BUILD create mode 100644 src/ai/backend/cli/README.md create mode 120000 src/ai/backend/cli/VERSION create mode 100644 src/ai/backend/cli/__init__.py create mode 100644 src/ai/backend/cli/__main__.py create mode 100644 src/ai/backend/cli/extensions.py create mode 100644 src/ai/backend/cli/interaction.py create mode 100644 src/ai/backend/cli/loader.py create mode 100644 src/ai/backend/cli/main.py create mode 100644 src/ai/backend/cli/py.typed create mode 100644 src/ai/backend/client/BUILD create mode 100644 src/ai/backend/client/README.rst create mode 120000 src/ai/backend/client/VERSION create mode 100644 src/ai/backend/client/__init__.py create mode 100644 src/ai/backend/client/auth.py create mode 100644 src/ai/backend/client/cli/__init__.py create mode 100644 src/ai/backend/client/cli/__main__.py create mode 100644 src/ai/backend/client/cli/admin/__init__.py create mode 100644 src/ai/backend/client/cli/admin/agent.py create mode 100644 src/ai/backend/client/cli/admin/domain.py create mode 100644 src/ai/backend/client/cli/admin/etcd.py create mode 100644 src/ai/backend/client/cli/admin/group.py create mode 100644 src/ai/backend/client/cli/admin/image.py create mode 100644 src/ai/backend/client/cli/admin/keypair.py create mode 100644 src/ai/backend/client/cli/admin/license.py create mode 100644 src/ai/backend/client/cli/admin/manager.py create mode 100644 src/ai/backend/client/cli/admin/resource.py create mode 100644 src/ai/backend/client/cli/admin/resource_policy.py create mode 100644 src/ai/backend/client/cli/admin/scaling_group.py create mode 100644 src/ai/backend/client/cli/admin/session.py create mode 100644 src/ai/backend/client/cli/admin/storage.py create mode 100644 src/ai/backend/client/cli/admin/user.py create mode 100644 src/ai/backend/client/cli/admin/vfolder.py create mode 100644 src/ai/backend/client/cli/announcement.py create mode 100644 src/ai/backend/client/cli/app.py create mode 100644 src/ai/backend/client/cli/config.py create mode 100644 src/ai/backend/client/cli/dotfile.py create mode 100644 src/ai/backend/client/cli/logs.py create mode 100644 src/ai/backend/client/cli/main.py create mode 100644 src/ai/backend/client/cli/pagination.py create mode 100644 src/ai/backend/client/cli/params.py create mode 100644 src/ai/backend/client/cli/pretty.py create mode 100644 src/ai/backend/client/cli/proxy.py create mode 100644 src/ai/backend/client/cli/run.py create mode 100644 src/ai/backend/client/cli/server_log.py create mode 100644 src/ai/backend/client/cli/session.py create mode 100644 src/ai/backend/client/cli/session_template.py create mode 100644 src/ai/backend/client/cli/ssh.py create mode 100644 src/ai/backend/client/cli/types.py create mode 100644 src/ai/backend/client/cli/vfolder.py create mode 100644 src/ai/backend/client/compat.py create mode 100644 src/ai/backend/client/config.py create mode 100644 src/ai/backend/client/exceptions.py create mode 100644 src/ai/backend/client/func/__init__.py create mode 100644 src/ai/backend/client/func/admin.py create mode 100644 src/ai/backend/client/func/agent.py create mode 100644 src/ai/backend/client/func/auth.py create mode 100644 src/ai/backend/client/func/base.py create mode 100644 src/ai/backend/client/func/bgtask.py create mode 100644 src/ai/backend/client/func/domain.py create mode 100644 
src/ai/backend/client/func/dotfile.py create mode 100644 src/ai/backend/client/func/etcd.py create mode 100644 src/ai/backend/client/func/group.py create mode 100644 src/ai/backend/client/func/image.py create mode 100644 src/ai/backend/client/func/keypair.py create mode 100644 src/ai/backend/client/func/keypair_resource_policy.py create mode 100644 src/ai/backend/client/func/manager.py create mode 100644 src/ai/backend/client/func/resource.py create mode 100644 src/ai/backend/client/func/scaling_group.py create mode 100644 src/ai/backend/client/func/server_log.py create mode 100644 src/ai/backend/client/func/session.py create mode 100644 src/ai/backend/client/func/session_template.py create mode 100644 src/ai/backend/client/func/storage.py create mode 100644 src/ai/backend/client/func/system.py create mode 100644 src/ai/backend/client/func/user.py create mode 100644 src/ai/backend/client/func/vfolder.py create mode 100644 src/ai/backend/client/helper.py create mode 100644 src/ai/backend/client/load_balancing.py create mode 100644 src/ai/backend/client/output/__init__.py create mode 100644 src/ai/backend/client/output/console.py create mode 100644 src/ai/backend/client/output/fields.py create mode 100644 src/ai/backend/client/output/formatters.py create mode 100644 src/ai/backend/client/output/json.py create mode 100644 src/ai/backend/client/output/types.py create mode 100644 src/ai/backend/client/pagination.py create mode 100644 src/ai/backend/client/py.typed create mode 100644 src/ai/backend/client/request.py create mode 100644 src/ai/backend/client/session.py create mode 100644 src/ai/backend/client/test_utils.py create mode 100644 src/ai/backend/client/types.py create mode 100644 src/ai/backend/client/utils.py create mode 100644 src/ai/backend/client/versioning.py create mode 100644 src/ai/backend/common/BUILD create mode 100644 src/ai/backend/common/README.md create mode 120000 src/ai/backend/common/VERSION create mode 100644 src/ai/backend/common/__init__.py create mode 100644 src/ai/backend/common/argparse.py create mode 100644 src/ai/backend/common/asyncio.py create mode 100644 src/ai/backend/common/bgtask.py create mode 100644 src/ai/backend/common/cli.py create mode 100644 src/ai/backend/common/config.py create mode 100644 src/ai/backend/common/distributed.py create mode 100644 src/ai/backend/common/docker.py create mode 100644 src/ai/backend/common/enum_extension.py create mode 100644 src/ai/backend/common/enum_extension.pyi create mode 100644 src/ai/backend/common/etcd.py create mode 100644 src/ai/backend/common/events.py create mode 100644 src/ai/backend/common/exception.py create mode 100644 src/ai/backend/common/files.py create mode 100644 src/ai/backend/common/identity.py create mode 100644 src/ai/backend/common/json.py create mode 100644 src/ai/backend/common/lock.py create mode 100644 src/ai/backend/common/logging.py create mode 100644 src/ai/backend/common/logging_utils.py create mode 100644 src/ai/backend/common/msgpack.py create mode 100644 src/ai/backend/common/networking.py create mode 100644 src/ai/backend/common/plugin/__init__.py create mode 100644 src/ai/backend/common/plugin/hook.py create mode 100644 src/ai/backend/common/plugin/monitor.py create mode 100644 src/ai/backend/common/plugin/py.typed create mode 100644 src/ai/backend/common/py.typed create mode 100644 src/ai/backend/common/redis.py create mode 100644 src/ai/backend/common/sd_notify.py create mode 100644 src/ai/backend/common/service_ports.py create mode 100644 src/ai/backend/common/testutils.py 
create mode 100644 src/ai/backend/common/types.py create mode 100644 src/ai/backend/common/utils.py create mode 100644 src/ai/backend/common/validators.py create mode 100644 src/ai/backend/helpers/BUILD create mode 100644 src/ai/backend/helpers/README.md create mode 120000 src/ai/backend/helpers/VERSION create mode 100644 src/ai/backend/helpers/__init__.py create mode 100644 src/ai/backend/helpers/package.py create mode 100644 src/ai/backend/kernel/BUILD create mode 100644 src/ai/backend/kernel/README.md create mode 120000 src/ai/backend/kernel/VERSION create mode 100644 src/ai/backend/kernel/__init__.py create mode 100644 src/ai/backend/kernel/__main__.py create mode 100644 src/ai/backend/kernel/app/__init__.py create mode 100644 src/ai/backend/kernel/base.py create mode 100644 src/ai/backend/kernel/c/__init__.py create mode 100644 src/ai/backend/kernel/compat.py create mode 100644 src/ai/backend/kernel/cpp/__init__.py create mode 100644 src/ai/backend/kernel/exception.py create mode 100644 src/ai/backend/kernel/git/__init__.py create mode 100644 src/ai/backend/kernel/golang/__init__.py create mode 100644 src/ai/backend/kernel/haskell/__init__.py create mode 100644 src/ai/backend/kernel/intrinsic.py create mode 100644 src/ai/backend/kernel/java/LablupPatches.java create mode 100644 src/ai/backend/kernel/java/__init__.py create mode 100644 src/ai/backend/kernel/julia/__init__.py create mode 100644 src/ai/backend/kernel/jupyter_client.py create mode 100644 src/ai/backend/kernel/logging.py create mode 100644 src/ai/backend/kernel/lua/__init__.py create mode 100644 src/ai/backend/kernel/nodejs/__init__.py create mode 100644 src/ai/backend/kernel/octave/__init__.py create mode 100644 src/ai/backend/kernel/php/__init__.py create mode 100644 src/ai/backend/kernel/python/__init__.py create mode 100644 src/ai/backend/kernel/python/drawing/__init__.py create mode 100644 src/ai/backend/kernel/python/drawing/canvas.py create mode 100644 src/ai/backend/kernel/python/drawing/color.py create mode 100644 src/ai/backend/kernel/python/drawing/encoding.py create mode 100644 src/ai/backend/kernel/python/drawing/turtle.py create mode 100644 src/ai/backend/kernel/python/sitecustomize.py create mode 100644 src/ai/backend/kernel/python/types.py create mode 100644 src/ai/backend/kernel/r/__init__.py create mode 100644 src/ai/backend/kernel/r_server_ms/__init__.py create mode 100644 src/ai/backend/kernel/requirements.txt create mode 100644 src/ai/backend/kernel/rust/__init__.py create mode 100644 src/ai/backend/kernel/scheme/__init__.py create mode 100644 src/ai/backend/kernel/service.py create mode 100644 src/ai/backend/kernel/service_actions.py create mode 100644 src/ai/backend/kernel/terminal.py create mode 100644 src/ai/backend/kernel/test_utils.py create mode 100644 src/ai/backend/kernel/utils.py create mode 100644 src/ai/backend/kernel/vendor/__init__.py create mode 100644 src/ai/backend/kernel/vendor/aws_polly/__init__.py create mode 100644 src/ai/backend/kernel/vendor/aws_polly/inproc.py create mode 100644 src/ai/backend/kernel/vendor/h2o/__init__.py create mode 100644 src/ai/backend/manager/BUILD create mode 100644 src/ai/backend/manager/README.md create mode 120000 src/ai/backend/manager/VERSION create mode 100644 src/ai/backend/manager/__init__.py create mode 100644 src/ai/backend/manager/api/__init__.py create mode 100644 src/ai/backend/manager/api/admin.py create mode 100644 src/ai/backend/manager/api/auth.py create mode 100644 src/ai/backend/manager/api/cluster_template.py create mode 100644 
src/ai/backend/manager/api/context.py create mode 100644 src/ai/backend/manager/api/domainconfig.py create mode 100644 src/ai/backend/manager/api/etcd.py create mode 100644 src/ai/backend/manager/api/events.py create mode 100644 src/ai/backend/manager/api/exceptions.py create mode 100644 src/ai/backend/manager/api/groupconfig.py create mode 100644 src/ai/backend/manager/api/image.py create mode 100644 src/ai/backend/manager/api/logs.py create mode 100644 src/ai/backend/manager/api/manager.py create mode 100644 src/ai/backend/manager/api/py.typed create mode 100644 src/ai/backend/manager/api/ratelimit.py create mode 100644 src/ai/backend/manager/api/resource.py create mode 100644 src/ai/backend/manager/api/scaling_group.py create mode 100644 src/ai/backend/manager/api/session.py create mode 100644 src/ai/backend/manager/api/session_template.py create mode 100644 src/ai/backend/manager/api/stream.py create mode 100644 src/ai/backend/manager/api/types.py create mode 100644 src/ai/backend/manager/api/userconfig.py create mode 100644 src/ai/backend/manager/api/utils.py create mode 100644 src/ai/backend/manager/api/vfolder.py create mode 100644 src/ai/backend/manager/api/wsproxy.py create mode 100644 src/ai/backend/manager/cli/__init__.py create mode 100644 src/ai/backend/manager/cli/__main__.py create mode 100644 src/ai/backend/manager/cli/context.py create mode 100644 src/ai/backend/manager/cli/dbschema.py create mode 100644 src/ai/backend/manager/cli/etcd.py create mode 100644 src/ai/backend/manager/cli/fixture.py create mode 100644 src/ai/backend/manager/cli/gql.py create mode 100644 src/ai/backend/manager/cli/image.py create mode 100644 src/ai/backend/manager/cli/image_impl.py create mode 100644 src/ai/backend/manager/config.py create mode 100644 src/ai/backend/manager/container_registry/__init__.py create mode 100644 src/ai/backend/manager/container_registry/base.py create mode 100644 src/ai/backend/manager/container_registry/docker.py create mode 100644 src/ai/backend/manager/container_registry/harbor.py create mode 100644 src/ai/backend/manager/defs.py create mode 100644 src/ai/backend/manager/exceptions.py create mode 100644 src/ai/backend/manager/idle.py create mode 100644 src/ai/backend/manager/models/__init__.py create mode 100644 src/ai/backend/manager/models/agent.py create mode 100644 src/ai/backend/manager/models/alembic/README create mode 100644 src/ai/backend/manager/models/alembic/env.py create mode 100644 src/ai/backend/manager/models/alembic/script.py.mako create mode 100644 src/ai/backend/manager/models/alembic/versions/01456c812164_add_idle_timeout_to_keypair_resource_.py create mode 100644 src/ai/backend/manager/models/alembic/versions/015d84d5a5ef_add_image_table.py create mode 100644 src/ai/backend/manager/models/alembic/versions/0262e50e90e0_add_ssh_keypair_into_keypair.py create mode 100644 src/ai/backend/manager/models/alembic/versions/02950808ca3d_add_agent_version.py create mode 100644 src/ai/backend/manager/models/alembic/versions/06184d82a211_add_session_creation_id.py create mode 100644 src/ai/backend/manager/models/alembic/versions/0c5733f80e4d_index_kernel_timestamps.py create mode 100644 src/ai/backend/manager/models/alembic/versions/0d553d59f369_users_replace_is_active_to_status_and_its_info.py create mode 100644 src/ai/backend/manager/models/alembic/versions/0e558d06e0e3_add_service_ports.py create mode 100644 src/ai/backend/manager/models/alembic/versions/0f3bc98edaa0_more_status.py create mode 100644 
src/ai/backend/manager/models/alembic/versions/0f7a4b643940_.py create mode 100644 src/ai/backend/manager/models/alembic/versions/10e39a34eed5_enlarge_kernels_lang_column_length.py create mode 100644 src/ai/backend/manager/models/alembic/versions/11146ba02235_change_char_col_to_str.py create mode 100644 src/ai/backend/manager/models/alembic/versions/185852ff9872_add_vfolder_permissions_table.py create mode 100644 src/ai/backend/manager/models/alembic/versions/1e673659b283_add_clusterized_column_to_agents_table.py create mode 100644 src/ai/backend/manager/models/alembic/versions/1e8531583e20_add_dotfile_column_to_keypairs.py create mode 100644 src/ai/backend/manager/models/alembic/versions/1fa6a31ea8e3_add_inviter_field_for_vfolder_.py create mode 100644 src/ai/backend/manager/models/alembic/versions/202b6dcbc159_add_internal_data_to_kernels.py create mode 100644 src/ai/backend/manager/models/alembic/versions/22964745c12b_add_total_resource_slots_to_group.py create mode 100644 src/ai/backend/manager/models/alembic/versions/22e52d03fc61_add_allowed_docker_registries_in_domains.py create mode 100644 src/ai/backend/manager/models/alembic/versions/250e8656cf45_add_status_data.py create mode 100644 src/ai/backend/manager/models/alembic/versions/25e903510fa1_add_dotfiles_to_domains_and_groups.py create mode 100644 src/ai/backend/manager/models/alembic/versions/26d0c387e764_create_vfolder_invitations_table.py create mode 100644 src/ai/backend/manager/models/alembic/versions/2a82340fa30e_add_mounts_info_in_kernel_db.py create mode 100644 src/ai/backend/manager/models/alembic/versions/2b0931e4a059_convert_lang_to_image_and_registry.py create mode 100644 src/ai/backend/manager/models/alembic/versions/352fa4f88f61_add_tpu_slot_on_kernel_model.py create mode 100644 src/ai/backend/manager/models/alembic/versions/3bb80d1887d6_add_preopen_ports.py create mode 100644 src/ai/backend/manager/models/alembic/versions/3cf19d906e71_.py create mode 100644 src/ai/backend/manager/models/alembic/versions/3f1dafab60b2_merge.py create mode 100644 src/ai/backend/manager/models/alembic/versions/405aa2c39458_job_queue.py create mode 100644 src/ai/backend/manager/models/alembic/versions/4545f5c948b3_add_io_scratch_size_stats.py create mode 100644 src/ai/backend/manager/models/alembic/versions/48ab2dfefba9_reindex_kernel_updated_order.py create mode 100644 src/ai/backend/manager/models/alembic/versions/4b7b650bc30e_add_creator_in_vfolders.py create mode 100644 src/ai/backend/manager/models/alembic/versions/4b8a66fb8d82_revamp_keypairs.py create mode 100644 src/ai/backend/manager/models/alembic/versions/4cc87e7fbfdf_stats_refactor.py create mode 100644 src/ai/backend/manager/models/alembic/versions/513164749de4_add_cancelled_to_kernelstatus.py create mode 100644 src/ai/backend/manager/models/alembic/versions/518ecf41f567_add_index_for_cluster_role.py create mode 100644 src/ai/backend/manager/models/alembic/versions/51dddd79aa21_add_logs_column_on_kernel_table.py create mode 100644 src/ai/backend/manager/models/alembic/versions/529113b08c2c_add_vfolder_type_column.py create mode 100644 src/ai/backend/manager/models/alembic/versions/548cc8aa49c8_update_cluster_columns_in_kernels.py create mode 100644 src/ai/backend/manager/models/alembic/versions/57b523dec0e8_add_tpu_slots.py create mode 100644 src/ai/backend/manager/models/alembic/versions/57e717103287_rename_clone_allowed_to_cloneable.py create mode 100644 src/ai/backend/manager/models/alembic/versions/5b45f28d2cac_add_resource_opts_in_kernels.py create mode 100644 
src/ai/backend/manager/models/alembic/versions/5d8e6043455e_add_user_group_ids_in_vfolder.py create mode 100644 src/ai/backend/manager/models/alembic/versions/5de06da3c2b5_init.py create mode 100644 src/ai/backend/manager/models/alembic/versions/5e88398bc340_add_unmanaged_path_column_to_vfolders.py create mode 100644 src/ai/backend/manager/models/alembic/versions/60a1effa77d2_add_coordinator_address_column_on_.py create mode 100644 src/ai/backend/manager/models/alembic/versions/65c4a109bbc7_.py create mode 100644 src/ai/backend/manager/models/alembic/versions/6f1c1b83870a_merge_user_s_first__last_name_into_full_.py create mode 100644 src/ai/backend/manager/models/alembic/versions/6f5fe19894b7_vfolder_invitation_state_to_enum_type.py create mode 100644 src/ai/backend/manager/models/alembic/versions/7a82e0c70122_add_group_model.py create mode 100644 src/ai/backend/manager/models/alembic/versions/7dd1d81c3204_add_vfolder_mounts_to_kernels.py create mode 100644 src/ai/backend/manager/models/alembic/versions/7ea324d0535b_vfolder_and_kernel.py create mode 100644 src/ai/backend/manager/models/alembic/versions/80176413d8aa_keypairs_get_is_admin.py create mode 100644 src/ai/backend/manager/models/alembic/versions/819c2b3830a9_add_user_model.py create mode 100644 src/ai/backend/manager/models/alembic/versions/81c264528f20_add_max_session_lifetime.py create mode 100644 src/ai/backend/manager/models/alembic/versions/854bd902b1bc_change_kernel_identification.py create mode 100644 src/ai/backend/manager/models/alembic/versions/8679d0a7e22b_add_scheduled_to_kernelstatus.py create mode 100644 src/ai/backend/manager/models/alembic/versions/8e660aa31fe3_add_resource_presets.py create mode 100644 src/ai/backend/manager/models/alembic/versions/911023380bc9_add_architecture_column_on_agents.py create mode 100644 src/ai/backend/manager/models/alembic/versions/93e9d31d40bf_agent_add_region.py create mode 100644 src/ai/backend/manager/models/alembic/versions/97f6c80c8aa5_merge.py create mode 100644 src/ai/backend/manager/models/alembic/versions/9a91532c8534_add_scaling_group.py create mode 100644 src/ai/backend/manager/models/alembic/versions/9bd986a75a2a_allow_kernels_scaling_group_nullable.py create mode 100644 src/ai/backend/manager/models/alembic/versions/9c89b9011872_add_attached_devices_field_in_kernels.py create mode 100644 src/ai/backend/manager/models/alembic/versions/9cd61b1ae70d_add_scheduable_field_to_agents.py create mode 100644 src/ai/backend/manager/models/alembic/versions/a1fd4e7b7782_enumerate_vfolder_perms.py create mode 100644 src/ai/backend/manager/models/alembic/versions/a7ca9f175d5f_merge.py create mode 100644 src/ai/backend/manager/models/alembic/versions/babc74594aa6_add_partial_index_to_kernels.py create mode 100644 src/ai/backend/manager/models/alembic/versions/bae1a7326e8a_add_domain_model.py create mode 100644 src/ai/backend/manager/models/alembic/versions/bf4bae8f942e_add_kernel_host.py create mode 100644 src/ai/backend/manager/models/alembic/versions/c092dabf3ee5_add_batch_session.py create mode 100644 src/ai/backend/manager/models/alembic/versions/c1409ad0e8da_.py create mode 100644 src/ai/backend/manager/models/alembic/versions/c3e74dcf1808_add_environ_to_kernels.py create mode 100644 src/ai/backend/manager/models/alembic/versions/c401d78cc7b9_add_allowed_vfolder_hosts_to_domain_and_.py create mode 100644 src/ai/backend/manager/models/alembic/versions/c481d3dc6c7d_add_shared_memory_to_resource_presets.py create mode 100644 
src/ai/backend/manager/models/alembic/versions/c5e4e764f9e3_add_domain_group_user_fields_to_kernels.py create mode 100644 src/ai/backend/manager/models/alembic/versions/ce209920f654_create_task_template_table.py create mode 100644 src/ai/backend/manager/models/alembic/versions/d2aafa234374_create_error_logs_table.py create mode 100644 src/ai/backend/manager/models/alembic/versions/d452bacd085c_add_mount_map_column_to_kernel.py create mode 100644 src/ai/backend/manager/models/alembic/versions/d463fc5d6109_add_clone_allowed_to_vfolders.py create mode 100644 src/ai/backend/manager/models/alembic/versions/d52bf5ec9ef3_convert_cpu_gpu_slots_to_float.py create mode 100644 src/ai/backend/manager/models/alembic/versions/d582942886ad_add_tag_to_kernels.py create mode 100644 src/ai/backend/manager/models/alembic/versions/d59ff89e7514_remove_keypair_concurrency_used.py create mode 100644 src/ai/backend/manager/models/alembic/versions/d5cc54fd36b5_update_for_multicontainer_sessions.py create mode 100644 src/ai/backend/manager/models/alembic/versions/d643752544de_.py create mode 100644 src/ai/backend/manager/models/alembic/versions/d727b5da20e6_add_callback_url_to_kernels.py create mode 100644 src/ai/backend/manager/models/alembic/versions/da24ff520049_add_starts_at_field_into_kernels.py create mode 100644 src/ai/backend/manager/models/alembic/versions/dbc1e053b880_add_keypair_resource_policy.py create mode 100644 src/ai/backend/manager/models/alembic/versions/dc9b66466e43_remove_clusterized.py create mode 100644 src/ai/backend/manager/models/alembic/versions/e18ed5fcfedf_add_superadmin_role_for_user.py create mode 100644 src/ai/backend/manager/models/alembic/versions/e35332f8d23d_add_modified_at_to_users_and_kernels.py create mode 100644 src/ai/backend/manager/models/alembic/versions/e421c02cf9e4_rename_kernel_dependencies_to_session_.py create mode 100644 src/ai/backend/manager/models/alembic/versions/e7371ca5797a_rename_mem_stats.py create mode 100644 src/ai/backend/manager/models/alembic/versions/ed666f476f39_add_bootstrap_script_to_keypairs.py create mode 100644 src/ai/backend/manager/models/alembic/versions/eec98e65902a_merge_with_vfolder_clone.py create mode 100644 src/ai/backend/manager/models/alembic/versions/f0f4ee907155_dynamic_resource_slots.py create mode 100644 src/ai/backend/manager/models/alembic/versions/f5530eccf202_add_kernels_uuid_prefix_index.py create mode 100644 src/ai/backend/manager/models/alembic/versions/f8a71c3bffa2_stringify_userid.py create mode 100644 src/ai/backend/manager/models/alembic/versions/f9971fbb34d9_add_state_column_to_vfolder_invitations.py create mode 100644 src/ai/backend/manager/models/alembic/versions/ff4bfca66bf8_.py create mode 100644 src/ai/backend/manager/models/base.py create mode 100644 src/ai/backend/manager/models/domain.py create mode 100644 src/ai/backend/manager/models/dotfile.py create mode 100644 src/ai/backend/manager/models/error_logs.py create mode 100644 src/ai/backend/manager/models/gql.py create mode 100644 src/ai/backend/manager/models/group.py create mode 100644 src/ai/backend/manager/models/image.py create mode 100644 src/ai/backend/manager/models/kernel.py create mode 100644 src/ai/backend/manager/models/keypair.py create mode 100644 src/ai/backend/manager/models/minilang/__init__.py create mode 100644 src/ai/backend/manager/models/minilang/ordering.py create mode 100644 src/ai/backend/manager/models/minilang/queryfilter.py create mode 100644 src/ai/backend/manager/models/resource_policy.py create mode 100644 
src/ai/backend/manager/models/resource_preset.py create mode 100644 src/ai/backend/manager/models/scaling_group.py create mode 100644 src/ai/backend/manager/models/session_template.py create mode 100644 src/ai/backend/manager/models/storage.py create mode 100644 src/ai/backend/manager/models/user.py create mode 100644 src/ai/backend/manager/models/utils.py create mode 100644 src/ai/backend/manager/models/vfolder.py create mode 100644 src/ai/backend/manager/pglock.py create mode 100644 src/ai/backend/manager/plugin/__init__.py create mode 100644 src/ai/backend/manager/plugin/error_monitor.py create mode 100644 src/ai/backend/manager/plugin/exceptions.py create mode 100644 src/ai/backend/manager/plugin/webapp.py create mode 100644 src/ai/backend/manager/py.typed create mode 100644 src/ai/backend/manager/registry.py create mode 100644 src/ai/backend/manager/scheduler/__init__.py create mode 100644 src/ai/backend/manager/scheduler/dispatcher.py create mode 100644 src/ai/backend/manager/scheduler/drf.py create mode 100644 src/ai/backend/manager/scheduler/fifo.py create mode 100644 src/ai/backend/manager/scheduler/mof.py create mode 100644 src/ai/backend/manager/scheduler/predicates.py create mode 100644 src/ai/backend/manager/scheduler/types.py create mode 100644 src/ai/backend/manager/server.py create mode 100644 src/ai/backend/manager/types.py create mode 100644 src/ai/backend/meta/BUILD create mode 100644 src/ai/backend/plugin/BUILD create mode 100644 src/ai/backend/plugin/README.md create mode 120000 src/ai/backend/plugin/VERSION create mode 100644 src/ai/backend/plugin/__init__.py create mode 100644 src/ai/backend/plugin/entrypoint.py create mode 100644 src/ai/backend/plugin/py.typed create mode 100644 src/ai/backend/runner/.bash_profile create mode 100644 src/ai/backend/runner/.bashrc create mode 100644 src/ai/backend/runner/.dockerignore create mode 100644 src/ai/backend/runner/.tmux.conf create mode 100644 src/ai/backend/runner/.vimrc create mode 100644 src/ai/backend/runner/BUILD create mode 100644 src/ai/backend/runner/DO_NOT_STORE_PERSISTENT_FILES_HERE.md create mode 100644 src/ai/backend/runner/README.md create mode 120000 src/ai/backend/runner/VERSION create mode 100644 src/ai/backend/runner/__init__.py create mode 100755 src/ai/backend/runner/dropbear.glibc.aarch64.bin create mode 100755 src/ai/backend/runner/dropbear.glibc.x86_64.bin create mode 100755 src/ai/backend/runner/dropbear.musl.aarch64.bin create mode 100755 src/ai/backend/runner/dropbear.musl.x86_64.bin create mode 100755 src/ai/backend/runner/dropbearconvert.glibc.aarch64.bin create mode 100755 src/ai/backend/runner/dropbearconvert.glibc.x86_64.bin create mode 100755 src/ai/backend/runner/dropbearconvert.musl.aarch64.bin create mode 100755 src/ai/backend/runner/dropbearconvert.musl.x86_64.bin create mode 100755 src/ai/backend/runner/dropbearkey.glibc.aarch64.bin create mode 100755 src/ai/backend/runner/dropbearkey.glibc.x86_64.bin create mode 100755 src/ai/backend/runner/dropbearkey.musl.aarch64.bin create mode 100755 src/ai/backend/runner/dropbearkey.musl.x86_64.bin create mode 100755 src/ai/backend/runner/entrypoint.sh create mode 100644 src/ai/backend/runner/extract_dotfiles.py create mode 100755 src/ai/backend/runner/jail.alpine3.8.bin create mode 100755 src/ai/backend/runner/jail.ubuntu16.04.bin create mode 100644 src/ai/backend/runner/jupyter-custom.css create mode 100644 src/ai/backend/runner/krunner-extractor.dockerfile create mode 100644 src/ai/backend/runner/krunner-extractor.img.aarch64.tar.xz create mode 
100644 src/ai/backend/runner/krunner-extractor.img.x86_64.tar.xz create mode 100755 src/ai/backend/runner/krunner-extractor.sh create mode 100755 src/ai/backend/runner/libbaihook.alpine3.8.aarch64.so create mode 100755 src/ai/backend/runner/libbaihook.alpine3.8.x86_64.so create mode 100755 src/ai/backend/runner/libbaihook.centos7.6.aarch64.so create mode 100755 src/ai/backend/runner/libbaihook.centos7.6.x86_64.so create mode 100755 src/ai/backend/runner/libbaihook.ubuntu18.04.x86_64.so create mode 100755 src/ai/backend/runner/libbaihook.ubuntu20.04.aarch64.so create mode 100755 src/ai/backend/runner/libbaihook.ubuntu20.04.x86_64.so create mode 100644 src/ai/backend/runner/logo.svg create mode 100644 src/ai/backend/runner/requirements.txt create mode 100644 src/ai/backend/runner/roboto-italic.ttf create mode 100644 src/ai/backend/runner/roboto.ttf create mode 100755 src/ai/backend/runner/scp.alpine3.8.aarch64.bin create mode 100755 src/ai/backend/runner/scp.alpine3.8.x86_64.bin create mode 100755 src/ai/backend/runner/scp.centos7.6.aarch64.bin create mode 100755 src/ai/backend/runner/scp.centos7.6.x86_64.bin create mode 100755 src/ai/backend/runner/scp.ubuntu16.04.aarch64.bin create mode 100755 src/ai/backend/runner/scp.ubuntu16.04.x86_64.bin create mode 100755 src/ai/backend/runner/scp.ubuntu18.04.aarch64.bin create mode 100755 src/ai/backend/runner/scp.ubuntu18.04.x86_64.bin create mode 100755 src/ai/backend/runner/scp.ubuntu20.04.aarch64.bin create mode 100755 src/ai/backend/runner/scp.ubuntu20.04.x86_64.bin create mode 100755 src/ai/backend/runner/sftp-server.alpine3.8.aarch64.bin create mode 100755 src/ai/backend/runner/sftp-server.alpine3.8.x86_64.bin create mode 100755 src/ai/backend/runner/sftp-server.centos7.6.aarch64.bin create mode 100755 src/ai/backend/runner/sftp-server.centos7.6.x86_64.bin create mode 100755 src/ai/backend/runner/sftp-server.ubuntu16.04.aarch64.bin create mode 100755 src/ai/backend/runner/sftp-server.ubuntu16.04.x86_64.bin create mode 100755 src/ai/backend/runner/sftp-server.ubuntu18.04.aarch64.bin create mode 100755 src/ai/backend/runner/sftp-server.ubuntu18.04.x86_64.bin create mode 100755 src/ai/backend/runner/sftp-server.ubuntu20.04.aarch64.bin create mode 100755 src/ai/backend/runner/sftp-server.ubuntu20.04.x86_64.bin create mode 100755 src/ai/backend/runner/su-exec.alpine3.8.aarch64.bin create mode 100755 src/ai/backend/runner/su-exec.alpine3.8.x86_64.bin create mode 100755 src/ai/backend/runner/su-exec.centos7.6.aarch64.bin create mode 100755 src/ai/backend/runner/su-exec.centos7.6.x86_64.bin create mode 100755 src/ai/backend/runner/su-exec.ubuntu16.04.aarch64.bin create mode 100755 src/ai/backend/runner/su-exec.ubuntu16.04.x86_64.bin create mode 100755 src/ai/backend/runner/su-exec.ubuntu18.04.aarch64.bin create mode 100755 src/ai/backend/runner/su-exec.ubuntu18.04.x86_64.bin create mode 100755 src/ai/backend/runner/su-exec.ubuntu20.04.aarch64.bin create mode 100755 src/ai/backend/runner/su-exec.ubuntu20.04.x86_64.bin create mode 100644 src/ai/backend/runner/terminfo.alpine3.8/s/screen create mode 100644 src/ai/backend/runner/terminfo.alpine3.8/s/screen-256color create mode 100644 src/ai/backend/runner/terminfo.alpine3.8/s/screen.xterm-256color create mode 100644 src/ai/backend/runner/terminfo.alpine3.8/x/xterm create mode 100644 src/ai/backend/runner/terminfo.alpine3.8/x/xterm+256color create mode 100644 src/ai/backend/runner/terminfo.alpine3.8/x/xterm-256color create mode 100755 src/ai/backend/runner/tmux.glibc.aarch64.bin create mode 100755 
src/ai/backend/runner/tmux.glibc.x86_64.bin create mode 100755 src/ai/backend/runner/tmux.musl.aarch64.bin create mode 100755 src/ai/backend/runner/tmux.musl.x86_64.bin create mode 100644 src/ai/backend/storage/BUILD create mode 100644 src/ai/backend/storage/README.md create mode 120000 src/ai/backend/storage/VERSION create mode 100644 src/ai/backend/storage/__init__.py create mode 100644 src/ai/backend/storage/abc.py create mode 100644 src/ai/backend/storage/api/__init__.py create mode 100644 src/ai/backend/storage/api/client.py create mode 100644 src/ai/backend/storage/api/manager.py create mode 100644 src/ai/backend/storage/config.py create mode 100644 src/ai/backend/storage/context.py create mode 100644 src/ai/backend/storage/exception.py create mode 100644 src/ai/backend/storage/filelock.py create mode 100644 src/ai/backend/storage/netapp/__init__.py create mode 100644 src/ai/backend/storage/netapp/netappclient.py create mode 100644 src/ai/backend/storage/netapp/quotamanager.py create mode 100644 src/ai/backend/storage/purestorage/__init__.py create mode 100644 src/ai/backend/storage/purestorage/purity.py create mode 100644 src/ai/backend/storage/py.typed create mode 100644 src/ai/backend/storage/server.py create mode 100644 src/ai/backend/storage/types.py create mode 100644 src/ai/backend/storage/utils.py create mode 100644 src/ai/backend/storage/vfs/__init__.py create mode 100644 src/ai/backend/storage/xfs/__init__.py create mode 100644 src/ai/backend/test/BUILD create mode 120000 src/ai/backend/test/VERSION create mode 100644 src/ai/backend/test/__init__.py create mode 100644 src/ai/backend/test/cli/__init__.py create mode 100644 src/ai/backend/test/cli/__main__.py create mode 100644 src/ai/backend/test/cli/context.py create mode 100644 src/ai/backend/test/cli/utils.py create mode 100644 src/ai/backend/test/cli_integration/__init__.py create mode 100644 src/ai/backend/test/cli_integration/admin/__init__.py create mode 100644 src/ai/backend/test/cli_integration/admin/test_domain.py create mode 100644 src/ai/backend/test/cli_integration/admin/test_group.py create mode 100644 src/ai/backend/test/cli_integration/admin/test_image.py create mode 100644 src/ai/backend/test/cli_integration/admin/test_keypair.py create mode 100644 src/ai/backend/test/cli_integration/admin/test_keypair_resource_policy.py create mode 100644 src/ai/backend/test/cli_integration/admin/test_scaling_group.py create mode 100644 src/ai/backend/test/cli_integration/admin/test_storage.py create mode 100644 src/ai/backend/test/cli_integration/admin/test_user.py create mode 100644 src/ai/backend/test/cli_integration/conftest.py create mode 100644 src/ai/backend/test/cli_integration/user/__init__.py create mode 100644 src/ai/backend/test/cli_integration/user/test_vfolder.py create mode 100644 src/ai/backend/test/py.typed create mode 100644 src/ai/backend/test/utils/__init__.py create mode 100644 src/ai/backend/test/utils/cli.py create mode 100644 src/ai/backend/testutils/BUILD create mode 100644 src/ai/backend/testutils/__init__.py create mode 100644 src/ai/backend/testutils/bootstrap.py create mode 100644 src/ai/backend/testutils/pants.py create mode 100644 src/ai/backend/web/BUILD create mode 100644 src/ai/backend/web/README.md create mode 120000 src/ai/backend/web/VERSION create mode 100644 src/ai/backend/web/__init__.py create mode 100644 src/ai/backend/web/auth.py create mode 100644 src/ai/backend/web/logging.py create mode 100644 src/ai/backend/web/proxy.py create mode 100644 src/ai/backend/web/py.typed create mode 
100644 src/ai/backend/web/server.py create mode 160000 src/ai/backend/web/static create mode 100644 stubs/trafaret/BUILD create mode 100644 stubs/trafaret/__init__.pyi create mode 100644 stubs/trafaret/base.pyi create mode 100644 stubs/trafaret/constructor.pyi create mode 100644 stubs/trafaret/dataerror.pyi create mode 100644 stubs/trafaret/internet.pyi create mode 100644 stubs/trafaret/keys.pyi create mode 100644 stubs/trafaret/lib.pyi create mode 100644 stubs/trafaret/numeric.pyi create mode 100644 stubs/trafaret/regexp.pyi create mode 100644 tests/agent/BUILD create mode 100644 tests/agent/__init__.py create mode 100644 tests/agent/conftest.py create mode 100644 tests/agent/docker/BUILD create mode 100644 tests/agent/docker/test_agent.py create mode 100644 tests/agent/test_agent.py create mode 100644 tests/agent/test_alloc_map.py create mode 100644 tests/agent/test_files.py create mode 100644 tests/agent/test_kernel.py create mode 100644 tests/agent/test_resources.py create mode 100644 tests/agent/test_server.py create mode 100644 tests/agent/test_stats.py create mode 100644 tests/agent/test_utils.py create mode 100644 tests/common/BUILD create mode 100644 tests/common/__init__.py create mode 100644 tests/common/conftest.py create mode 100644 tests/common/redis/.gitignore create mode 100644 tests/common/redis/__init__.py create mode 100644 tests/common/redis/conftest.py create mode 100644 tests/common/redis/docker.py create mode 100644 tests/common/redis/native.py create mode 100644 tests/common/redis/redis-cluster.yml create mode 100644 tests/common/redis/redis-sentinel.dockerfile create mode 100644 tests/common/redis/sentinel.conf create mode 100644 tests/common/redis/test_connect.py create mode 100644 tests/common/redis/test_list.py create mode 100644 tests/common/redis/test_pipeline.py create mode 100644 tests/common/redis/test_pubsub.py create mode 100644 tests/common/redis/test_stream.py create mode 100644 tests/common/redis/types.py create mode 100644 tests/common/redis/utils.py create mode 100644 tests/common/test_argparse.py create mode 100644 tests/common/test_config.py create mode 100644 tests/common/test_distributed.py create mode 100644 tests/common/test_docker.py create mode 100644 tests/common/test_etcd.py create mode 100644 tests/common/test_events.py create mode 100644 tests/common/test_identity.py create mode 100644 tests/common/test_json.py create mode 100644 tests/common/test_logging.py create mode 100644 tests/common/test_msgpack.py create mode 100644 tests/common/test_plugin.py create mode 100644 tests/common/test_service_ports.py create mode 100644 tests/common/test_types.py create mode 100644 tests/common/test_utils.py create mode 100644 tests/common/test_validators.py create mode 100644 tests/manager/BUILD create mode 100644 tests/manager/__init__.py create mode 100644 tests/manager/api/BUILD create mode 100644 tests/manager/api/test_auth.py create mode 100644 tests/manager/api/test_bgtask.py create mode 100644 tests/manager/api/test_config.py create mode 100644 tests/manager/api/test_exceptions.py create mode 100644 tests/manager/api/test_middlewares.py create mode 100644 tests/manager/api/test_ratelimit.py create mode 100644 tests/manager/api/test_utils.py create mode 100644 tests/manager/conftest.py create mode 100644 tests/manager/fixtures/example-keypairs.json create mode 100644 tests/manager/fixtures/example-resource-presets.json create mode 100755 tests/manager/fixtures/example-session-templates.json create mode 100644 tests/manager/model_factory.py 
create mode 100644 tests/manager/models/BUILD create mode 100644 tests/manager/models/test_dbutils.py create mode 100644 tests/manager/sample-ssl-cert/sample.crt create mode 100644 tests/manager/sample-ssl-cert/sample.csr create mode 100644 tests/manager/sample-ssl-cert/sample.key create mode 100644 tests/manager/test_advisory_lock.py create mode 100644 tests/manager/test_image.py create mode 100644 tests/manager/test_predicates.py create mode 100644 tests/manager/test_queryfilter.py create mode 100644 tests/manager/test_queryorder.py create mode 100644 tests/manager/test_registry.py create mode 100644 tests/manager/test_scheduler.py create mode 100644 tests/plugin/BUILD create mode 100644 tests/plugin/test_entrypoint.py create mode 100644 tests/storage-proxy/BUILD create mode 100644 tests/storage-proxy/conftest.py create mode 100644 tests/storage-proxy/test_netapp.py create mode 100644 tests/storage-proxy/test_purestorage.py create mode 100644 tests/storage-proxy/test_vfs.py create mode 100644 tests/storage-proxy/test_xfs.py delete mode 100644 tests/test_dummy.py create mode 100644 tests/webserver/BUILD create mode 100644 tests/webserver/conftest.py create mode 100644 tests/webserver/test_auth.py create mode 100644 tools/flake8.lock create mode 100644 tools/mypy.lock create mode 100644 tools/pants-linux-aarch64.patch create mode 100755 tools/pants-local create mode 100644 tools/pants-plugins/platform_resources/BUILD create mode 100644 tools/pants-plugins/platform_resources/register.py create mode 100644 tools/pants-plugins/setupgen/BUILD create mode 100644 tools/pants-plugins/setupgen/register.py create mode 100644 tools/pytest.lock diff --git a/.editorconfig b/.editorconfig index 1b797cce6e..2fa85e787f 100644 --- a/.editorconfig +++ b/.editorconfig @@ -7,7 +7,7 @@ trim_trailing_whitespace = true charset = utf-8 [*.{py,md}] -max_line_length = 85 +max_line_length = 105 indent_style = space indent_size = 4 @@ -19,3 +19,10 @@ indent_size = 3 [.travis.yml] indent_style = space indent_size = 2 + +[BUILD] +indent_style = space +indent_size = 4 + +[VERSION] +insert_final_newline = false diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000000..cd80cb9fd1 --- /dev/null +++ b/.flake8 @@ -0,0 +1,8 @@ +[flake8] +# ref: http://pep8.readthedocs.io/en/latest/intro.html#error-codes +ignore = E126,E127,E128,E129,E722,E731,E221,E241,E401,W503,W504,N801,N802 +max-line-length = 125 +builtins = _ +exclude = .git,.cache,.idea,.egg,__pycache__,venv,dist,build,docs,src/ai/backend/manager/models/alembic/**,*.pyi + +# vim: ft=dosini diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000..ede2bec008 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,66 @@ +src/ai/backend/runner/scp.alpine3.8.x86_64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/scp.ubuntu20.04.x86_64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/sftp-server.alpine3.8.x86_64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/sftp-server.centos7.6.aarch64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/sftp-server.centos7.6.x86_64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/sftp-server.ubuntu16.04.x86_64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/tmux.glibc.x86_64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/dropbear.musl.x86_64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/jail.alpine3.8.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/su-exec.alpine3.8.aarch64.bin 
filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/dropbearconvert.musl.aarch64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/dropbear.glibc.aarch64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/dropbear.musl.aarch64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/scp.centos7.6.aarch64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/scp.centos7.6.x86_64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/sftp-server.ubuntu18.04.aarch64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/su-exec.ubuntu16.04.aarch64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/tmux.musl.aarch64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/dropbearkey.glibc.aarch64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/dropbearkey.glibc.x86_64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/dropbearkey.musl.x86_64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/scp.ubuntu16.04.x86_64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/sftp-server.ubuntu16.04.aarch64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/su-exec.alpine3.8.x86_64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/su-exec.ubuntu20.04.aarch64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/su-exec.ubuntu20.04.x86_64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/dropbearconvert.glibc.aarch64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/dropbear.glibc.x86_64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/scp.ubuntu18.04.aarch64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/scp.ubuntu18.04.x86_64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/sftp-server.alpine3.8.aarch64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/sftp-server.ubuntu20.04.aarch64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/su-exec.ubuntu18.04.aarch64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/jail.ubuntu16.04.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/sftp-server.ubuntu18.04.x86_64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/su-exec.centos7.6.aarch64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/su-exec.centos7.6.x86_64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/su-exec.ubuntu18.04.x86_64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/tmux.glibc.aarch64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/dropbearconvert.glibc.x86_64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/dropbearkey.musl.aarch64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/tmux.musl.x86_64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/dropbearconvert.musl.x86_64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/scp.alpine3.8.aarch64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/scp.ubuntu16.04.aarch64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/scp.ubuntu20.04.aarch64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/sftp-server.ubuntu20.04.x86_64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/su-exec.ubuntu16.04.x86_64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/krunner-extractor.img.aarch64.tar.xz filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/krunner-extractor.img.x86_64.tar.xz 
filter=lfs diff=lfs merge=lfs -text +src/ai/backend/agent/docker/backendai-socket-relay.img.x86_64.tar.gz filter=lfs diff=lfs merge=lfs -text +src/ai/backend/agent/docker/linuxkit-nsenter.img.aarch64.tar.gz filter=lfs diff=lfs merge=lfs -text +src/ai/backend/agent/docker/linuxkit-nsenter.img.x86_64.tar.gz filter=lfs diff=lfs merge=lfs -text +src/ai/backend/agent/docker/backendai-socket-relay.img.aarch64.tar.gz filter=lfs diff=lfs merge=lfs -text +src/ai/backend/agent/docker/linuxkit-metadata-proxy-worker.aarch64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/agent/docker/linuxkit-metadata-proxy-worker.x86_64.bin filter=lfs diff=lfs merge=lfs -text +src/ai/backend/agent/kubernetes/backendai-socket-relay.img.tar.gz filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/libbaihook.alpine3.8.aarch64.so filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/libbaihook.alpine3.8.x86_64.so filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/libbaihook.centos7.6.aarch64.so filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/libbaihook.centos7.6.x86_64.so filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/libbaihook.ubuntu18.04.x86_64.so filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/libbaihook.ubuntu20.04.aarch64.so filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/libbaihook.ubuntu20.04.x86_64.so filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/roboto-italic.ttf filter=lfs diff=lfs merge=lfs -text +src/ai/backend/runner/roboto.ttf filter=lfs diff=lfs merge=lfs -text diff --git a/.github/workflows/default.yml b/.github/workflows/default.yml index 6396544bc5..a0fa043e98 100644 --- a/.github/workflows/default.yml +++ b/.github/workflows/default.yml @@ -6,87 +6,150 @@ jobs: lint: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 + with: + fetch-depth: 2 + - uses: actions/cache@v2 + id: cache + with: + # pants-specific cache + path: | + ~/.cache/pants/setup + ~/.cache/pants/lmdb_store + ~/.cache/pants/named_caches + key: ${{ runner.os }}- - name: Set up Python uses: actions/setup-python@v2 with: - python-version: "3.9" - - name: Cache pip packages - uses: actions/cache@v2 - with: - path: ~/.cache/pip - key: lint-flake8-${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('setup.cfg') }} - restore-keys: | - lint-flake8-${{ runner.os }}-pip-${{ matrix.python-version }} - lint-flake8-${{ runner.os }}-pip- - - name: Install dependencies - env: - REQUIREMENTS_FILE: lint + python-version: "3.10" + cache: pip + - name: Bootstrap Pants run: | - python -m pip install -U pip setuptools wheel - python -m pip install -U -r requirements/${REQUIREMENTS_FILE}.txt - - name: Lint with flake8 + mkdir .tmp + ./pants --no-verify-config version + - name: Check BUILD files + run: ./pants tailor --check update-build-files --check + - name: Lint run: | if [ "$GITHUB_EVENT_NAME" == "pull_request" -a -n "$GITHUB_HEAD_REF" ]; then echo "(skipping matchers for pull request from local branches)" else echo "::add-matcher::.github/workflows/flake8-matcher.json" fi - python -m flake8 src/ai/backend tests - + if [ -n "$GITHUB_BASE_REF" ]; then + BASE_REF="origin/${GITHUB_BASE_REF}" + git fetch --no-tags --depth=1 origin "$GITHUB_BASE_REF" + else + BASE_REF="HEAD~1" + fi + ./pants lint --changed-since=$BASE_REF --changed-dependees=transitive + - name: Upload pants log + uses: actions/upload-artifact@v2 + with: + name: pants.lint.log + path: .pants.d/pants.log + if: always() # We want the log even on failures. 
+ + typecheck: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 + with: + fetch-depth: 2 + - uses: actions/cache@v2 + id: cache + with: + # pants-specific cache + path: | + ~/.cache/pants/setup + ~/.cache/pants/lmdb_store + ~/.cache/pants/named_caches + key: ${{ runner.os }}- - name: Set up Python uses: actions/setup-python@v2 with: - python-version: "3.9" - - name: Cache pip packages - uses: actions/cache@v2 - with: - path: ~/.cache/pip - key: typecheck-mypy-${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('setup.cfg') }} - restore-keys: | - typecheck-mypy-${{ runner.os }}-pip-${{ matrix.python-version }} - typecheck-mypy-${{ runner.os }}-pip- - - name: Install dependencies - env: - REQUIREMENTS_FILE: typecheck + python-version: "3.10" + cache: pip + - name: Bootstrap Pants run: | - python -m pip install -U pip setuptools wheel - python -m pip install -U -r requirements/${REQUIREMENTS_FILE}.txt - - name: Type check with mypy + mkdir .tmp + ./pants --no-verify-config version + - name: Check BUILD files + run: ./pants tailor --check update-build-files --check + - name: Typecheck run: | if [ "$GITHUB_EVENT_NAME" == "pull_request" -a -n "$GITHUB_HEAD_REF" ]; then echo "(skipping matchers for pull request from local branches)" else echo "::add-matcher::.github/workflows/mypy-matcher.json" fi - python -m mypy --no-color-output src/ai/backend + if [ -n "$GITHUB_BASE_REF" ]; then + BASE_REF="origin/${GITHUB_BASE_REF}" + git fetch --no-tags --depth=1 origin "$GITHUB_BASE_REF" + else + BASE_REF="HEAD~1" + fi + ./pants check --changed-since=$BASE_REF --changed-dependees=transitive + - name: Upload pants log + uses: actions/upload-artifact@v2 + with: + name: pants.check.log + path: .pants.d/pants.log + if: always() # We want the log even on failures. 
test: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - name: Set up Python - uses: actions/setup-python@v2 + - uses: actions/checkout@v3 + with: + fetch-depth: 2 + submodules: true + - uses: actions/cache@v2 + id: cache with: - python-version: "3.9" - - name: Cache pip packages + # pants-specific cache + path: | + ~/.cache/pants/setup + ~/.cache/pants/lmdb_store + ~/.cache/pants/named_caches + key: ${{ runner.os }}- + - name: Create LFS file hash list + run: git lfs ls-files -l | cut -d ' ' -f1 | sort > .lfs-assets-id + - name: Restore LFS cache uses: actions/cache@v2 + id: lfs-cache with: - path: ~/.cache/pip - key: typecheck-mypy-${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('setup.cfg') }} - - name: Install dependencies - env: - REQUIREMENTS_FILE: test + path: .git/lfs + key: lfs-${{ hashFiles('.lfs-assets-id') }} + - name: Git LFS Pull + run: git lfs pull + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: "3.10" + cache: pip + - name: Bootstrap Pants run: | - python -m pip install -U pip setuptools wheel - python -m pip install -U -r requirements/${REQUIREMENTS_FILE}.txt - - name: Test with pytest + mkdir .tmp + ./pants --no-verify-config version + - name: Check BUILD files + run: ./pants tailor --check update-build-files --check + - name: Test run: | - python -m pytest tests + if [ -n "$GITHUB_BASE_REF" ]; then + BASE_REF="origin/${GITHUB_BASE_REF}" + git fetch --no-tags --depth=1 origin "$GITHUB_BASE_REF" + else + BASE_REF="HEAD~1" + fi + ./pants test --changed-since=$BASE_REF --changed-dependees=transitive + - name: Upload pants log + uses: actions/upload-artifact@v2 + with: + name: pants.test.log + path: .pants.d/pants.log + if: always() # We want the log even on failures. deploy-to-pypi: needs: [lint, typecheck, test] @@ -94,28 +157,78 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - name: Set up Python - uses: actions/setup-python@v2 with: - python-version: "3.9" - - name: Cache pip packages + submodules: true + - uses: actions/cache@v2 + id: cache + with: + # pants-specific cache + path: | + ~/.cache/pants/setup + ~/.cache/pants/lmdb_store + ~/.cache/pants/named_caches + key: ${{ runner.os }}- + - name: Create LFS file hash list + run: git lfs ls-files -l | cut -d ' ' -f1 | sort > .lfs-assets-id + - name: Restore LFS cache uses: actions/cache@v2 + id: lfs-cache with: - path: ~/.cache/pip - key: test-pytest-${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('setup.py') }} - restore-keys: | - test-pytest-${{ runner.os }}-pip-${{ matrix.python-version }} - test-pytest-${{ runner.os }}-pip- - - name: Install dependencies - env: - REQUIREMENTS_FILE: build + path: .git/lfs + key: lfs-${{ hashFiles('.lfs-assets-id') }} + - name: Git LFS Pull + run: git lfs pull + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: "3.10" + cache: pip + - name: Bootstrap Pants + run: | + mkdir .tmp + ./pants --no-verify-config version + - name: Install local dependencies for packaging and publishing + run: | + pip install -U 'twine~=4.0' 'packaging>=21.3' + - name: Build packages run: | - python -m pip install -U pip setuptools wheel - python -m pip install -U -r requirements/${REQUIREMENTS_FILE}.txt - - name: Build and publish + # Normalize the package version + PKGVER=$(python -c "import packaging.version,pathlib; print(str(packaging.version.Version(pathlib.Path('VERSION').read_text())))") + # Build non-platform-specific wheels + ./pants 
--platform-specific-resources-target=linux_x86_64 --tag="wheel" --tag="-platform-specific" package '::' + # Build x86_64 wheels + MANYLINUX_PTAG=manylinux2014_x86_64 + MACOS_PTAG=macosx_11_0_x86_64 + ./pants --platform-specific-resources-target=linux_x86_64 --tag="wheel" --tag="+platform-specific" package '::' + for pkgname in "kernel_binary"; do + mv "dist/backend.ai_${pkgname}-${PKGVER}-py3-none-any.whl" \ + "dist/backend.ai_${pkgname}-${PKGVER}-py3-none-${MANYLINUX_PTAG}.${MACOS_PTAG}.whl" + done + # Build arm64 wheels + MANYLINUX_PTAG=manylinux2014_aarch64 + MACOS_PTAG=macosx_11_0_arm64 + ./pants --platform-specific-resources-target=linux_arm64 --tag="wheel" --tag="+platform-specific" package '::' + for pkgname in "kernel_binary"; do + mv "dist/backend.ai_${pkgname}-${PKGVER}-py3-none-any.whl" \ + "dist/backend.ai_${pkgname}-${PKGVER}-py3-none-${MANYLINUX_PTAG}.${MACOS_PTAG}.whl" + done + ls -lh dist + - name: Publish env: TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} + # We don't use `./pants publish ::` because we manually rename the + # wheels after building them to add arch-specific tags. run: | - python setup.py sdist bdist_wheel - twine upload dist/* \ No newline at end of file + twine upload dist/*.whl + - name: Upload artifacts + uses: actions/upload-artifact@v2 + with: + name: dist + path: dist + - name: Upload pants log + uses: actions/upload-artifact@v2 + with: + name: pants.deploy.log + path: .pants.d/pants.log + if: always() # We want the log even on failures. diff --git a/.github/workflows/timeline-check.yml b/.github/workflows/timeline-check.yml new file mode 100644 index 0000000000..b1e9e96d56 --- /dev/null +++ b/.github/workflows/timeline-check.yml @@ -0,0 +1,29 @@ +name: timeline-check + +on: [pull_request] + +jobs: + towncrier: + runs-on: ubuntu-latest + steps: + - name: Sparse-checkout + uses: lablup/sparse-checkout@v1 + with: + patterns: | + changes + scripts + packages + - name: Set up Python + uses: actions/setup-python@v3 + with: + python-version: "3.10" + - name: Install dependencies + run: | + python -m pip install -U pip setuptools + python -m pip install -U towncrier~=21.9 + - name: Check existence of news fragment + run: | + git fetch --no-tags origin +refs/heads/${BASE_BRANCH}:refs/remotes/origin/${BASE_BRANCH} + python -m towncrier.check --compare-with=origin/${BASE_BRANCH} + env: + BASE_BRANCH: ${{ github.base_ref }} diff --git a/.gitignore b/.gitignore index d3de4ad297..4cbdf5fe38 100644 --- a/.gitignore +++ b/.gitignore @@ -3,9 +3,6 @@ __pycache__/ *.py[cod] *$py.class -# C extensions -*.so - # Distribution / packaging .Python env/ @@ -95,5 +92,29 @@ ENV/ # IDE/vim .idea .*.swp - -tmp/ +.vscode + +# Pants +/.pants.d/ +/.pids +/.pants.workdir.file_lock* +/.pants.rc +/pants-local +/tools/pants-src + +# Local configurations +/manager.toml +/agent.toml +/storage-proxy.toml +/webserver.conf +/.pants.env +/docker-compose.halfstack.current.yml +/alembic.ini +/dev.etcd.volumes.json +/env-*.sh + +# Local temp +/.tmp/ +/tmp/ +/volumes/ +/vfroot/ diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000000..2bb02b82a6 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "src/ai/backend/web/static"] + path = src/ai/backend/web/static + url = https://github.com/lablup/backend.ai-app diff --git a/BUILD b/BUILD new file mode 100644 index 0000000000..90e7d2f93a --- /dev/null +++ b/BUILD @@ -0,0 +1,26 @@ +python_requirements( + name="reqs", + source="requirements.txt", + module_mapping={ +
"aiodataloader-ng": ["aiodataloader"], + "attrs": ["attr", "attrs"], + "python-dateutil": ["dateutil"], + "python-json-logger": ["pythonjsonlogger"], + "python-snappy": ["snappy"], + "pyzmq": ["zmq"], + "PyYAML": ["yaml"], + "typing-extensions": ["typing_extensions"], + "more-itertools": ["more_itertools"], + "zipstream-new": ["zipstream"], + }, + type_stubs_module_mapping={ + "types-aiofiles": ["aiofiles"], + "types-click": ["click"], + "types-cachetools": ["cachetools"], + "types-Jinja2": ["Jinja2"], + "types-PyYAML": ["yaml"], + "types-python-dateutil": ["dateutil"], + "types-six": ["six"], + "types-tabulate": ["tabulate"], + }, +) diff --git a/BUILD_ROOT b/BUILD_ROOT new file mode 100644 index 0000000000..dae01c95cc --- /dev/null +++ b/BUILD_ROOT @@ -0,0 +1 @@ +# a placeholder to designate the repository root path diff --git a/MIGRATION.md b/MIGRATION.md new file mode 100644 index 0000000000..94d0730d48 --- /dev/null +++ b/MIGRATION.md @@ -0,0 +1,51 @@ +Backend.AI Migration Guide +========================== + +## General + +* The migration should be done while the managers and agents are shut down. +* This guide only describes additional steps to follow other than the code/package upgrades. + +## 21.09 to 22.03 + +* `alembic upgrade head` is required to migrate the PostgreSQL database schema. + - The `keypairs.concurrency_used` column is dropped; the manager now uses Redis to keep track of it. + - The `kernels.last_stat` column is still there but it will get updated only when the kernels terminate. + There is a backup option to restore prior behavior of periodic sync: `debug.periodic-sync-stats` in + `manager.toml`, though. + +* The Redis container used with the manager should be reconfigured to use a persistent database. + In an HA setup, it is recommended to enable AOF by `appendonly yes` in the Redis configuration to make it + recoverable after hardware failures. + + Consult [the official doc](https://redis.io/docs/manual/persistence/) for more details. + + - FYI: The Docker official image uses `/data` as the directory to store RDB/AOF files. It may be + configured to use an explicit bind-mount of a host directory. If not configured, by default it will + create an anonymous volume and mount it. + +* The image metadata database is migrated from etcd to PostgreSQL while the registry configuration is + still inside the etcd. + + Run `backend.ai mgr image rescan` in the manager venv or `backend.ai admin image rescan` from clients + with the superadmin privilege to resync the image database. The old etcd image database will no longer + be used. + +* The manager now has a replaceable distributed lock backend, configured by the key `manager.distributed-lock` in + `manager.toml`. **The new default is "etcd".** "filelock" is suitable for single-node manager deployments + as it relies on POSIX file-level advisory locks. Change this value to "pg_advisory" to restore the behavior + of previous versions. "redlock" is not currently supported as aioredis v2 has a limited implementation. + +* (TODO) storage-proxy related stuff + +* Configure an explicit cron job to execute `backend.ai mgr clear-history -r {retention}` which trims old + sessions' execution records from the PostgreSQL and Redis databases to avoid indefinite growth of disk + and memory usage of the manager. + + The retention argument may be given as human-readable duration expressions, such as `30m`, `6h`, `3d`, + `2w`, `3mo`, and `1yr`. If there is no unit suffix, the value is interpreted as seconds.
+ It is recommended to schedule this command once a day. + +## 21.03 to 21.09 + +* `alembic upgrade head` is required to migrate the PostgreSQL database schema. diff --git a/README.md b/README.md index 4671093993..bb1e957c28 100644 --- a/README.md +++ b/README.md @@ -13,13 +13,62 @@ computation sessions on-demand or in batches with customizable job schedulers. All its functions are exposed as REST/GraphQL/WebSocket APIs. -Server-side Components ----------------------- - -If you want to run a Backend.AI cluster on your own, you need to install and -configure the following server-side components. -All server-side components are licensed under LGPLv3 to promote non-proprietary open -innovation in the open-source community. +Contents in This Repository +--------------------------- + +This repository contains all open-source server-side components and the client SDK for Python +as a reference implementation of API clients. + +## Directory Structure + +* `src/ai/backend/`: Source codes + - `manager/`: Manager + - `manager/api`: Manager API handlers + - `agent/`: Agent + - `agent/docker/`: Agent's Docker backend + - `agent/k8s/`: Agent's Kubernetes backend + - `kernel/`: Agent's kernel runner counterpart + - `runner/`: Agent's in-kernel prebuilt binaries + - `helpers/`: Agent's in-kernel helper package + - `common/`: Shared utilities + - `client/`: Client SDK + - `cli/`: Unified CLI for all components + - `storage/`: Storage proxy + - `storage/api`: Storage proxy's manager-facing and client-facing APIs + - `web/`: Web UI server + - `plugin/`: Plugin subsystem + - `test/`: Integration test suite + - `testutils/`: Shared utilities used by unit tests + - `meta/`: Legacy meta package +* `docs/`: Unified documentation +* `tests/` + - `manager/`, `agent/`, ...: Per-component unit tests +* `configs` + - `manager/`, `agent/`, ...: Per-component sample configurations +* `fixtures/` + - `manager/`, ...: Per-component fixtures for development setup and tests +* `plugins/`: A directory to place plugins such as accelerators, monitors, etc. +* `scripts/`: Scripts to assist development workflows + - `install-dev.sh`: The single-node development setup script from the working copy +* `stubs/`: Type annotation stub packages written by us +* `tools/`: A directory to host Pants-related tooling +* `dist/`: A directory to put build artifacts (.whl files) and Pants-exported virtualenvs +* `changes/`: News fragments for towncrier +* `pants.toml`: The Pants configuration +* `pyproject.toml`: Tooling configuration (towncrier, pytest, mypy) +* `BUILD`: The root build config file +* `**/BUILD`: Per-directory build config files +* `BUILD_ROOT`: An indicator to mark the build root directory for Pants +* `requirements.txt`: The unified requirements file +* `*.lock`, `tools/*.lock`: The dependency lock files +* `docker-compose.*.yml`: Per-version recommended halfstack container configs +* `README.md`: This file +* `MIGRATION.md`: The migration guide for updating between major releases +* `VERSION`: The unified version declaration + +Server-side components are licensed under LGPLv3 to promote non-proprietary open +innovation in the open-source community while other shared libraries and client SDKs +are distributed under the MIT license. There is no obligation to open your service/system codes if you just run the server-side components as-is (e.g., just run as daemons or import the components @@ -27,17 +76,59 @@ without modification in your codes). 
Please contact us (contact-at-lablup-com) for commercial consulting and more licensing details/options about individual use-cases. -For details about server installation and configuration, please visit [our -documentation](http://docs.backend.ai). -### Manager with API Gateway +Getting Started +--------------- + +### Installation for Single-node Development + +Run `scripts/install-dev.sh` after cloning this repository. + +This script checks availability of all required dependencies such as Docker and bootstraps a development +setup. Note that it requires `sudo` and a modern Python installed on the host system, which must be Linux +(Debian/RHEL-like) or macOS. + +### Installation for Multi-node Tests & Production + +Please consult [our documentation](http://docs.backend.ai) for community-supported materials. +Contact the sales team (contact@lablup.com) for professional paid support and deployment options. + +### Accessing Compute Sessions (aka Kernels) + +Backend.AI provides websocket tunneling into individual computation sessions (containers), +so that users can use their browsers and client CLI to access in-container applications directly +in a secure way. + +* Jupyter: data scientists' favorite tool + * Most container images have intrinsic Jupyter and JupyterLab support. +* Web-based terminal + * All container sessions have intrinsic ttyd support. +* SSH + * All container sessions have intrinsic SSH/SFTP/SCP support with auto-generated per-user SSH keypair. + PyCharm and other IDEs can use on-demand sessions using SSH remote interpreters. +* VSCode (coming soon) + * Most container sessions have intrinsic web-based VSCode support. + +### Working with Storage + +Backend.AI provides an abstraction layer on top of existing network-based storages +(e.g., NFS/SMB), called vfolders (virtual folders). +Each vfolder works like a cloud storage that can be mounted into any computation +session and shared between users and user groups with differentiated privileges. + + +Major Components +---------------- + +### Manager It routes external API requests from front-end services to individual agents. It also monitors and scales the cluster of multiple agents (a few tens to hundreds). -* https://github.com/lablup/backend.ai-manager - * Package namespace: `ai.backend.gateway` and `ai.backend.manager` - * Plugin interfaces +* `src/ai/backend/manager` + * [README](https://github.com/lablup/backend.ai/blob/main/src/ai/backend/manager/README.md) + * Legacy per-pkg repo: https://github.com/lablup/backend.ai-manager + * Available plugin interfaces - `backendai_scheduler_v10` - `backendai_hook_v10` - `backendai_webapp_v10` @@ -51,46 +142,23 @@ REPL daemons (kernels) run. Each agent on a new EC2 instance self-registers itself to the instance registry via heartbeats.
-* https://github.com/lablup/backend.ai-agent - * Package namespace: `ai.backend.agent` - * Plugin interfaces +* `src/ai/backend/agent` + * [README](https://github.com/lablup/backend.ai/blob/main/src/ai/backend/agent/README.md) + * Legacy per-pkg repo: https://github.com/lablup/backend.ai-agent + * Available plugin interfaces - `backendai_accelerator_v12` - `backendai_monitor_stats_v10` - `backendai_monitor_error_v10` - - `backendai_krunner_v10` -* https://github.com/lablup/backend.ai-accelerator-cuda (CUDA accelerator plugin) - * Package namespace: `ai.backend.acceelrator.cuda` -* https://github.com/lablup/backend.ai-accelerator-cuda-mock (CUDA mockup plugin) - * Package namespace: `ai.backend.acceelrator.cuda` - * This emulates the presence of CUDA devices without actual CUDA devices, - so that developers can work on CUDA integration without real GPUs. -* https://github.com/lablup/backend.ai-accelerator-rocm (ROCM accelerator plugin) - * Package namespace: `ai.backend.acceelrator.rocm` - -### Server-side common plugins (for both manager and agents) - -* https://github.com/lablup/backend.ai-stats-monitor - - Statistics collector based on the Datadog API - - Package namespace: `ai.backend.monitor.stats` -* https://github.com/lablup/backend.ai-error-monitor - - Exception collector based on the Sentry API - - Package namespace: `ai.backend.monitor.error` ### Kernels -A set of small ZeroMQ-based REPL daemons in various programming languages and -configurations. - -* https://github.com/lablup/backend.ai-kernel-runner - * Package namespace: `ai.backend.kernel` - * A common interface for the agent to deal with various language runtimes * https://github.com/lablup/backend.ai-kernels - * Runtime-specific recipes to build the Docker images (Dockerfile) + - Computing environment recipes (Dockerfile) to build the container images to execute + on top of the Backend.AI platform ### Jail -A programmable sandbox implemented using ptrace-based sytem call filtering written in -Go. +A programmable sandbox implemented using ptrace-based system call filtering written in Go. * https://github.com/lablup/backend.ai-jail @@ -101,17 +169,6 @@ with agents). * https://github.com/lablup/backend.ai-hook -### Commons - -A collection of utility modules commonly shared throughout Backend.AI projects. - -* Package namespaces: `ai.backend.common` -* https://github.com/lablup/backend.ai-common - - -Client-side Components ---------------------- - ### Client SDK Libraries We offer client SDKs in popular programming languages. @@ -131,6 +188,30 @@ commercial and non-commercial software products and services. * `composer require lablup/backend.ai-client` * https://github.com/lablup/backend.ai-client-php + +Plugins +------- + +* `backendai_accelerator_v12` + - [`ai.backend.accelerator.cuda`](https://github.com/lablup/backend.ai-accelerator-cuda): CUDA accelerator plugin + - [`ai.backend.accelerator.cuda` (mock)](https://github.com/lablup/backend.ai-accelerator-cuda-mock): CUDA mockup plugin + - This emulates the presence of CUDA devices without actual CUDA devices, + so that developers can work on CUDA integration without real GPUs. + - [`ai.backend.accelerator.rocm`](https://github.com/lablup/backend.ai-accelerator-rocm): ROCM accelerator plugin + - More available in the enterprise edition!
+* `backendai_monitor_stats_v10` + - [`ai.backend.monitor.stats`](https://github.com/lablup/backend.ai-stats-monitor) + - Statistics collector based on the Datadog API +* `backendai_monitor_error_v10` + - [`ai.backend.monitor.error`](https://github.com/lablup/backend.ai-error-monitor) + - Exception collector based on the Sentry API + + +Legacy Components +----------------- + +These components still exist but are no longer actively maintained. + ### Media The front-end support libraries to handle multi-media outputs (e.g., SVG plots, @@ -141,25 +222,8 @@ animated vector graphics) the Javascript part in the front-end. * https://github.com/lablup/backend.ai-media -Interacting with computation sessions -------------------------------------- - -Backend.AI provides websocket tunneling into individual computation sessions (containers), -so that users can use their browsers and client CLI to access in-container applications directly -in a secure way. - -* Jupyter Kernel: data scientists' favorite tool - * Most container sessions have intrinsic Jupyter and JupyterLab support. -* Web-based terminal - * All container sessions have intrinsic ttyd support. -* SSH - * All container sessions have intrinsic SSH/SFTP/SCP support with auto-generated per-user SSH keypair. - PyCharm and other IDEs can use on-demand sessions using SSH remote interpreters. -* VSCode (coming soon) - * Most container sessions have intrinsic web-based VSCode support. -Integrations with IDEs and Editors ----------------------------------- +### IDE and Editor Extensions * Visual Studio Code Extension * Search “Live Code Runner” among VSCode extensions. @@ -168,13 +232,9 @@ Integrations with IDEs and Editors * Search “Live Code Runner” among Atom plugins. * https://github.com/lablup/atom-live-code-runner -Storage management ------------------- +We now recommend using in-kernel applications such as Jupyter Lab, Visual Studio Code Server, +or native SSH connection to kernels via our client SDK or desktop apps. -Backend.AI provides an abstraction layer on top of existing network-based storages -(e.g., NFS/SMB), called vfolders (virtual folders). -Each vfolder works like a cloud storage that can be mounted into any computation -sessions and shared between users and user groups with differentiated privileges. License ------- diff --git a/VERSION b/VERSION new file mode 100644 index 0000000000..37d67c63f6 --- /dev/null +++ b/VERSION @@ -0,0 +1 @@ +22.06.0.dev0 \ No newline at end of file diff --git a/backend.ai b/backend.ai new file mode 100755 index 0000000000..b9f274f25d --- /dev/null +++ b/backend.ai @@ -0,0 +1,2 @@ +#! 
/bin/bash +./py -m ai.backend.cli "$@" diff --git a/changes/417.misc b/changes/417.misc new file mode 100644 index 0000000000..38fbe6eb0e --- /dev/null +++ b/changes/417.misc @@ -0,0 +1 @@ +Migrate to a semi-mono repository that contains all first-party server-side components with automated dependency management via Pants diff --git a/changes/template.md b/changes/template.md new file mode 100644 index 0000000000..db193a0e03 --- /dev/null +++ b/changes/template.md @@ -0,0 +1,44 @@ +{%- if top_line -%} +{{ top_line }} +{%- elif versiondata.name -%} +{{ versiondata.name }} {{ versiondata.version }} ({{ versiondata.date }}) +{%- else -%} +{{ versiondata.version }} ({{ versiondata.date }}) +{%- endif -%} +{%- for section, _ in sections.items() -%} + {%- if section -%} +### {{ section }}{%- endif -%} + {%- if sections[section] -%} + {%- for category, val in definitions.items() if category in sections[section] %} + + +### {{ definitions[category]['name'] }} + + {%- if definitions[category]['showcontent'] %} + {%- for text, values in sections[section][category].items() %} + {%- if values[0].endswith("/0)") %} + +* {{ definitions[category]['name'] }} without explicit PR/issue numbers + {{ text }} + {%- else %} + +* {{ text }} {{ values|join(',\n ') }} + {%- endif %} + + {%- endfor %} + {%- else %} + +* {{ sections[section][category]['']|join(', ') }} + {%- endif %} + {%- if sections[section][category]|length == 0 %} + +No significant changes. + {%- else %} + {%- endif %} + + {%- endfor %} + {%- else %} + +No significant changes. + {%- endif %} +{%- endfor %} diff --git a/configs/agent/ci.toml b/configs/agent/ci.toml new file mode 100644 index 0000000000..4250835a4a --- /dev/null +++ b/configs/agent/ci.toml @@ -0,0 +1,65 @@ +[etcd] +namespace = "local" +addr = { host = "127.0.0.1", port = 2379 } +user = "" +password = "" + + +[agent] +rpc-listen-addr = { host = "127.0.0.1", port = 6001 } +agent-sock-port = 6007 +id = "i-travis" +scaling-group = "default" +pid-file = "./agent.pid" +event-loop = "asyncio" + + +[container] +port-range = [30000, 31000] +kernel-uid = -1 +bind-host = "127.0.0.1" +sandbox-type = "docker" +jail-args = [] +scratch-type = "hostdir" +scratch-root = "/tmp/scratches" +scratch-size = "1G" + + +[watcher] +service-addr = { host = "127.0.0.1", port = 6009 } +ssl-enabled = false +#ssl-cert = "" +#ssl-key = "" +target-service = "backendai-agent.service" +soft-reset-available = false + + +[logging] +level = "INFO" +drivers = ["console"] + +[logging.pkg-ns] +"" = "WARNING" +"aiodocker" = "INFO" +"aiotools" = "INFO" +"aiohttp" = "INFO" +"ai.backend" = "INFO" + +[logging.console] +colored = true +format = "simple" + + +[resource] +reserved-cpu = 1 +reserved-mem = "1G" +reserved-disk = "8G" + + +[debug] +debug = false +skip-container-deletion = false + + +[license] +addr = { host = "127.0.0.1", port = 6099 } diff --git a/configs/agent/halfstack.toml b/configs/agent/halfstack.toml new file mode 100644 index 0000000000..8b3cef1a82 --- /dev/null +++ b/configs/agent/halfstack.toml @@ -0,0 +1,83 @@ +[etcd] +namespace = "local" +addr = { host = "127.0.0.1", port = 8120 } +user = "" +password = "" + + +[agent] +mode = "docker" +rpc-listen-addr = { host = "127.0.0.1", port = 6001 } +agent-sock-port = 6007 +# id = "i-something-special" +scaling-group = "default" +pid-file = "./agent.pid" +event-loop = "uvloop" + + +[container] +port-range = [30000, 31000] +kernel-uid = -1 +kernel-gid = -1 +bind-host = "127.0.0.1" +sandbox-type = "docker" +scratch-type = "hostdir" +scratch-root = "./scratches" 
+scratch-size = "1G" + + +[watcher] +service-addr = { host = "127.0.0.1", port = 6009 } +ssl-enabled = false +#ssl-cert = "" +#ssl-key = "" +target-service = "backendai-agent.service" +soft-reset-available = false + + +[logging] +level = "INFO" +drivers = ["console"] + +[logging.pkg-ns] +"" = "WARNING" +"aiodocker" = "INFO" +"aiotools" = "INFO" +"aiohttp" = "INFO" +"ai.backend" = "INFO" + +[logging.console] +colored = true +format = "verbose" + +[logging.file] +path = "./logs" +filename = "agent.log" +rotation-size = "10M" + +[logging.logstash] +endpoint = { host = "localhost", port = 9300 } +protocol = "tcp" +ssl-enabled = true +ssl-verify = false + + +[resource] +reserved-cpu = 1 +reserved-mem = "1G" +reserved-disk = "8G" + + +[debug] +enabled = true +skip-container-deletion = false + +[debug.coredump] +enabled = false +path = "./coredumps" +backup-count = 10 +size-limit = "64M" + + +[license] +addr = { host = "127.0.0.1", port = 6099 } diff --git a/configs/agent/sample.toml b/configs/agent/sample.toml new file mode 100644 index 0000000000..3d387667c8 --- /dev/null +++ b/configs/agent/sample.toml @@ -0,0 +1,270 @@ +[etcd] +namespace = "local" # env: BACKEND_NAMESPACE +addr = { host = "127.0.0.1", port = 2379 } # env: BACKEND_ETCD_ADDR (host:port) +user = "" # env: BACKEND_ETCD_USER +password = "" # env: BACKEND_ETCD_PASSWORD + + +[agent] +# Agent mode; required +# One of: "kubernetes", "docker" +mode = "docker" + +# Change the reported host/address of the agent. +# The manager will use this value to connect to the agent. +# If host is an empty string, the agent tries to auto-detect it with a fallback to "127.0.0.1". +# For mobile environments such as developer laptops which roam around different networks, +# it is HIGHLY RECOMMENDED to set this to "127.0.0.1" manually. +rpc-listen-addr = { host = "", port = 6001 } +# env: BACKEND_AGENT_HOST_OVERRIDE +# env: BACKEND_AGENT_PORT + +# The port number of agent socket which provides access to host-side information +# to the containers (such as PID conversion). +agent-sock-port = 6007 + +# The base directory to put domain sockets for IPC. +# Normally you don't have to change it. +# NOTE: If Docker is installed via Snap (https://snapcraft.io/docker), +# you must change this to a directory under your *home* directory. +# ipc-base-path = "/tmp/backend.ai/ipc" + +# Override the name of this agent. +# If empty or unspecified, the agent builds this from the hostname by prefixing it with "i-", +# like "i-hostname". The "i-" prefix is not mandatory, though. +# This affects the per-node configuration scope. +# id = "i-something-special" + +# Set the scaling group of this agent. +# This affects the per-sgroup configuration scope. +scaling-group = "default" + +# Create a PID file so that daemon managers could keep track of us. +# If set to an empty string, it does NOT create the PID file. +# pid-file = "./agent.pid" # env: BACKEND_PID_FILE + +# One of: "asyncio", "uvloop" +# This changes the event loop backend. +# uvloop is a fast libuv-based implementation but sometimes has +# compatibility issues. +# event-loop = "uvloop" + +# A boolean flag to check and wait until at least one manager instances are running +# when the agent starts up. [default: false] +# skip-manager-detection = false + + +[container] +# The port range to expose public service ports. +# If too small, this may limit the maximum number of containers +# allowed to run simultaneously. 
+port-range = [30000, 31000] # env: BACKEND_CONTAINER_PORT_RANGE
+
+# The UID/GID to be set to the container's main process.
+# If not specified, it uses the same user which the agent runs as.
+# These configurations can be replaced with the configurations in etcd
+# (config/container/kernel-uid, config/container/kernel-gid).
+kernel-uid = -1
+kernel-gid = -1
+
+# Change the reported host/address of the containers.
+# The manager will use this value to connect to containers.
+# If empty or unspecified, the agent tries to auto-detect it with a fallback to
+# the "agent.rpc-listen-addr.host" value. When auto-detecting, it uses the etcd's
+# "config/network/subnet/container" key to limit the candidate IP addresses bound
+# to the current host.
+# For mobile environments such as developer laptops which roam around different networks,
+# it is HIGHLY RECOMMENDED to set this to "127.0.0.1" manually.
+# bind-host = "127.0.0.1" # env: BACKEND_KERNEL_HOST_OVERRIDE
+
+# Alias string reported to the manager as the "kernel-host" value.
+# Useful when wsproxy can't access the kernel with the bind-host IP.
+# Optional, defaults to the "bind-host" value when not specified.
+# advertised-host = ""
+
+# One of: "docker", "cgroup"
+# "docker" uses the Docker API to retrieve container statistics.
+# "cgroup" makes the agent control the creation/destruction of container cgroups so
+# that it can safely retrieve the last-moment statistics even when containers die
+# unexpectedly. But this requires the agent to be run as root.
+stats-type = "docker"
+
+# One of: "docker", "jail".
+# "docker" uses Docker's default apparmor and seccomp profiles.
+# "jail" uses Backend.AI Jail to programmatically filter syscalls.
+sandbox-type = "docker"
+
+# Only meaningful when sandbox-type = "jail"
+# Additional arguments passed to the jail executable in containers.
+jail-args = []
+
+# One of: "hostdir", "memory", "k8s-nfs"
+# "hostdir": creates an empty host directory and mounts it as /home/work of containers.
+# "memory": creates an in-memory tmpfs and mounts it as /home/work of containers. (only supported in Linux)
+# "k8s-nfs": creates Kubernetes PV/PVC and mounts it when creating Pods. (only supported in Kubernetes mode)
+scratch-type = "hostdir" # env: BACKEND_SANDBOX_TYPE
+
+# Only meaningful when scratch-type is "hostdir" or "k8s-nfs".
+# "hostdir": If it does not exist, it is auto-created.
+# "k8s-nfs": The source NFS device should be mounted to this location.
+scratch-root = "./scratches" # env: BACKEND_SCRATCH_ROOT
+
+# Limit the maximum size of the scratch space.
+# If set to zero, it is unlimited.
+scratch-size = "1G"
+
+# Enable legacy swarm mode.
+# This should be true to let this agent handle multi-container sessions.
+swarm-enabled = false
+
+# Only meaningful when scratch-type is "k8s-nfs".
+# The mount point of the source NFS disk should match the scratch-root folder's mount point.
+scratch-nfs-address = ""
+
+# Only meaningful when scratch-type is "k8s-nfs".
+scratch-nfs-options = ""
+
+
+[watcher]
+# The address to accept the watcher API requests.
+service-addr = { host = "127.0.0.1", port = 6009 }
+# env: BACKEND_WATCHER_SERVICE_IP
+# env: BACKEND_WATCHER_SERVICE_PORT
+
+# SSL configuration for the watcher's HTTP endpoint.
+ssl-enabled = false
+ssl-cert = ""
+ssl-key = ""
+
+# The target systemd service name to watch and control.
+target-service = "backendai-agent.service"
+
+# If "reload" is supported, set true.
+soft-reset-available = false + + +[logging] +# One of: "NOTSET", "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL" +# Set the global logging level. +level = "INFO" + +# Multi-choice of: "console", "logstash", "file" +# For each choice, there must be a "logging." section +# in this config file as exemplified below. +drivers = ["console", "file"] + + +[logging.console] +# If set true, use ANSI colors if the console is a terminal. +# If set false, always disable the colored output in console logs. +colored = true + +# One of: "simple", "verbose" +format = "simple" + + +[logging.file] +# The log file path and filename pattern. +# All messages are wrapped in single-line JSON objects. +# Rotated logs may have additional suffixes. +# For production, "/var/log/backend.ai" is recommended. +path = "./logs" +filename = "agent.log" + +# The maximum number of rotated logs. +backup-count = 5 + +# The log file size to begin rotation. +rotation-size = "10M" + + +[logging.logstash] +# The endpoint to publish logstash records. +endpoint = { host = "localhost", port = 9300 } + +# One of: "zmq.push", "zmq.pub", "tcp", "udp" +protocol = "tcp" + +# SSL configs when protocol = "tcp" +ssl-enabled = true +ssl-verify = true + + +# Specify additional package namespaces to include in the logs +# and their individual log levels. +# Note that the actual logging level applied is the conjunction of the global logging level and the +# logging levels specified here for each namespace. +[logging.pkg-ns] +"" = "WARNING" +"aiodocker" = "INFO" +"aiotools" = "INFO" +"aiohttp" = "INFO" +"ai.backend" = "INFO" + + +[resource] +# The amount of CPU cores reserved for the agent and the OS. +# This will be subtracted from the resource capacity reported to the manager. +reserved-cpu = 1 + +# The amount of memory reserved for the agent and the OS. +# This will be subtracted from the resource capacity reported to the manager. +reserved-mem = "1G" + +# The amount of disk space reserved for the agent and the OS. +# This will be subtracted from the resource capacity reported to the manager. +reserved-disk = "8G" + + +[debug] +# Enable or disable the debug-level logging. +enabled = false + +# If set true, it does not actually delete the containers after they terminate or are terminated +# so that developers can inspect the container logs. +# This is useful for debugging errors that make containers to terminate immediately after kernel +# launches, due to bugs in initialization steps such as jail. +skip-container-deletion = false + +# Include debug-level logs for internal events. +log-events = false + +# Include debug-level logs for detailed kernel creation configs and their resource spec. +log-kernel-config = false + +# Include debug-level logs for allocation maps. +log-alloc-map = false + +# Include debug-level logs for statistics. +log-stats = false + +# Include debug-level logs for heartbeats +log-heartbeats = false + +# Include debug-level logs for docker event stream. +log-docker-events = false + +[debug.coredump] +# If set true, enable coredumps in containers. Only supported in Linux. +# (This option is not related to the agent itself.) +# IMPORTANT: You must set /proc/sys/kernel/core_pattern to an absolute path which is available +# inside both the host and containers. +# If the system's core_pattern is set to a pipe (e.g., appport) or a relative path, +# the agent will report a configuration error upon startup. +enabled = false + +# Set a host directory to store coredumps, which will be auto-created. 
+# Inside the directory, coredumps are saved at container-ID-prefixed directories. +# It will be mounted as the parent directory of /proc/sys/kernel/core_pattern +path = "/var/crash/backend.ai" + +# Set the maximum number of recent container coredumps in the coredump directory. +# Oldest coredumps are deleted if there is more than this number of coredumps. +backup-count = 10 + +# Set the maximum size of coredumps from containers. +size-limit = "64M" + + +[license] +addr = { host = "127.0.0.1", port = 6099 } diff --git a/configs/manager/ci.toml b/configs/manager/ci.toml new file mode 100644 index 0000000000..48f91e30f0 --- /dev/null +++ b/configs/manager/ci.toml @@ -0,0 +1,67 @@ +[etcd] +namespace = "local" +addr = { host = "127.0.0.1", port = 2379 } +user = "" +password = "" + + +[db] +type = "postgresql" +addr = { host = "localhost", port = 5432 } +name = "testing_db_XXX" # auto-generated for every test run +user = "lablup" +password = "develove" + + +[manager] +num-proc = 2 +service-addr = { host = "0.0.0.0", port = 8080 } +ssl-enabled = false +#ssl-cert = "/etc/backend.ai/ssl/apiserver-fullchain.pem" +#ssl-privkey = "/etc/backend.ai/ssl/apiserver-privkey.pem" + +heartbeat-timeout = 5.0 +id = "i-travis" +pid-file = "./manager.pid" +disabled-plugins = [] + +importer-image = "lablup/importer:manylinux2010" + +event-loop = "asyncio" + + +[docker-registry] +ssl-verify = true + + +[logging] +level = "INFO" +drivers = ["console"] + +[logging.pkg-ns] +"" = "WARNING" +"aiotools" = "INFO" +"aiohttp" = "INFO" +"ai.backend" = "INFO" +"alembic" = "INFO" +"sqlalchemy" = "WARNING" + +[logging.console] +colored = true +format = "verbose" + +[logging.file] +path = "./logs" +filename = "manager.log" +backup-count = 5 +rotation-size = "10M" + +[logging.logstash] +endpoint = { host = "localhost", port = 9300 } +protocol = "tcp" +ssl-enabled = true +ssl-verify = true + + +[debug] +enabled = false diff --git a/configs/manager/halfstack.alembic.ini b/configs/manager/halfstack.alembic.ini new file mode 100644 index 0000000000..6c73f25c29 --- /dev/null +++ b/configs/manager/halfstack.alembic.ini @@ -0,0 +1,74 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts +script_location = ai.backend.manager.models:alembic + +# template used to generate migration files +# file_template = %%(rev)s_%%(slug)s + +# timezone to use when rendering the date +# within the migration file as well as the filename. +# string value is passed to dateutil.tz.gettz() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the +# "slug" field +#truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; this defaults +# to alembic/versions. 
When using multiple version +# directories, initial revisions must be specified with --version-path +# version_locations = %(here)s/bar %(here)s/bat alembic/versions + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +sqlalchemy.url = postgresql://postgres:develove@localhost:8100/backend + + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/configs/manager/halfstack.toml b/configs/manager/halfstack.toml new file mode 100644 index 0000000000..676d452372 --- /dev/null +++ b/configs/manager/halfstack.toml @@ -0,0 +1,67 @@ +[etcd] +namespace = "local" +addr = { host = "127.0.0.1", port = 8120 } +user = "" +password = "" + + +[db] +type = "postgresql" +addr = { host = "localhost", port = 8100 } +name = "backend" +user = "postgres" +password = "develove" + + +[manager] +num-proc = 4 +service-addr = { host = "127.0.0.1", port = 8081 } +#user = "nobody" +#group = "nobody" +ssl-enabled = false +#ssl-cert = "/etc/backend.ai/ssl/apiserver-fullchain.pem" # env: BACKNED_SSL_CERT +#ssl-privkey = "/etc/backend.ai/ssl/apiserver-privkey.pem" # env: BACKNED_SSL_KEY + +heartbeat-timeout = 10.0 +#id = "" +pid-file = "./manager.pid" # env: BACKEND_PID_FILE +disabled-plugins = [] + +hide-agents = true + + +[docker-registry] +ssl-verify = false + + +[logging] +level = "INFO" +drivers = ["console"] + +[logging.pkg-ns] +"" = "WARNING" +"aiotools" = "INFO" +"aiohttp" = "INFO" +"ai.backend" = "INFO" +"alembic" = "INFO" +"sqlalchemy" = "WARNING" + +[logging.console] +colored = true +format = "verbose" + +[logging.file] +path = "./logs" +filename = "manager.log" +backup-count = 5 +rotation-size = "10M" + +[logging.logstash] +endpoint = { host = "localhost", port = 9300 } +protocol = "tcp" +ssl-enabled = true +ssl-verify = true + + +[debug] +enabled = false diff --git a/configs/manager/sample.etcd.config.json b/configs/manager/sample.etcd.config.json new file mode 100644 index 0000000000..a07cb72be6 --- /dev/null +++ b/configs/manager/sample.etcd.config.json @@ -0,0 +1,50 @@ +{ + "system": { + "timezone": "UTC" + }, + "redis": { + "addr": "127.0.0.1:6379", + "password": "REDIS_PASSWORD" + }, + "docker": { + "registry": { + "cr.backend.ai": { + "": "https://cr.backend.ai", + "type": "harbor2", + "project": "stable,ngc" + } + } + }, + "idle": { + "enabled": "timeout", + "app-streaming-packet-timeout": "5m", + "checkers": { + "timeout": { + "threshold": "10m" + } + } + }, + "network": { + "subnet": { + "agent": "0.0.0.0/0", + "container": "0.0.0.0/0" + }, + "overlay": { + "mtu": 1500 + } + }, + "watcher": { + "token": "some-random-long-string" + }, + "plugins": { + "accelerator": { + "cuda": { + } + }, + "scheduler": { + "fifo": { + "num_retries_to_skip": 3 + } + } + } +} diff --git a/configs/manager/sample.etcd.redis-sentinel.json b/configs/manager/sample.etcd.redis-sentinel.json new file mode 100644 index 0000000000..bb27afc897 --- /dev/null +++ b/configs/manager/sample.etcd.redis-sentinel.json @@ -0,0 +1,4 @@ +{ + "sentinel": 
"127.0.0.1:8217,127.0.0.1:8218,127.0.0.1:8219", + "service_name": "manager" +} diff --git a/configs/manager/sample.etcd.redis-single.json b/configs/manager/sample.etcd.redis-single.json new file mode 100644 index 0000000000..7f9bb77a77 --- /dev/null +++ b/configs/manager/sample.etcd.redis-single.json @@ -0,0 +1,3 @@ +{ + "addr": "127.0.0.1:8111" +} diff --git a/configs/manager/sample.etcd.volumes.json b/configs/manager/sample.etcd.volumes.json new file mode 100644 index 0000000000..48ca9a86e3 --- /dev/null +++ b/configs/manager/sample.etcd.volumes.json @@ -0,0 +1,15 @@ +{ + "_types": { + "group": "", + "user": "" + }, + "default_host": "local:volume1", + "proxies": { + "local": { + "client_api": "http://client-accessible-hostname:6021", + "manager_api": "https://127.0.0.1:6022", + "secret": "some-secret-shared-with-storage-proxy", + "ssl_verify": "false" + } + } +} diff --git a/configs/manager/sample.toml b/configs/manager/sample.toml new file mode 100644 index 0000000000..3a24dd4517 --- /dev/null +++ b/configs/manager/sample.toml @@ -0,0 +1,163 @@ +[etcd] +namespace = "local" # env: BACKEND_NAMESPACE +addr = { host = "127.0.0.1", port = 2379 } # env: BACKEND_ETCD_ADDR (host:port) +user = "manager" # env: BACKEND_ETCD_USER +password = "ETCD_PASSWORD" # env: BACKEND_ETCD_PASSWORD + + +[db] +# One of: "postgresql" +# Currently we only support PostgreSQL. +type = "postgresql" + +# Address to access the database. +# NOTE: It is RECOMMENDED to use domain names in cloud setups because the IP addresses +# may change upon automatic upgrade or migration. +addr = { host = "localhost", port = 5432 } # env: BACKEND_DB_ADDR + +# Database name. +name = "backend" # env: BACKEND_DB_NAME + +# Database account credentials. +user = "postgres" # env: BACKEND_DB_USER +password = "DB_PASSWORD" # env: BACKEND_DB_PASSWORD + + +# NOTE: Redis settings are configured in etcd as it is shared by both the manager and agents. + + +[manager] +# The number of worker processes to handle API requests and event messages. +# If set zero, it uses the number of CPU cores in the system. +num-proc = 4 # env: BACKEND_MANAGER_NPROC + +# An arbitrary string used to salt when generating secret tokens (e.g., JWT) +# Currently only some plugins use this configuration in the manager. +secret = "XXXXXXXXXXXXXX" + +# Specify the user/group used for the manager daemon, +# to which the manager changes after reading the daemon configuration and SSL certifiactes. +# If not specified, it uses the file permission of ai/backend/manager/server.py +# This config is effective only when the manager daemon is started as the root user. +# Note that the vfolder (storage) permission must match with this. +user = "nobody" +group = "nobody" + +# Set the service hostname/port to accept API requests. +service-addr = { host = "0.0.0.0", port = 8080 } +# env: BACKEND_SERVICE_IP, BACKEND_SERVICE_PORT + +# Set the SSL certificate chain and the private keys used for serving the API requests. +ssl-enabled = false +#ssl-cert = "/etc/backend.ai/ssl/apiserver-fullchain.pem" # env: BACKNED_SSL_CERT +#ssl-privkey = "/etc/backend.ai/ssl/apiserver-privkey.pem" # env: BACKNED_SSL_KEY + +# Set the timeout for agent heartbeats in seconds. +heartbeat-timeout = 30.0 + +# Override the name of this manager node. +# If empty or unspecified, the agent builds this from the hostname by prefixing it with "i-", +# like "i-hostname". The "i-" prefix is not mandatory, though. 
+# Explicit configuration may be required if the hostname changes frequently, +# to handle the event bus messages consistently. +# This affects the per-node configuration scope. +# id = "" + +# Create a PID file so that daemon managers could keep track of us. +# If set to an empty string, it does NOT create the PID file. +# pid-file = "./manager.pid" # env: BACKEND_PID_FILE + +# The list of black-listed manager plugins. +disabled-plugins = [] + +# Hide agent and container IDs from GraphQL results unless the API requester is super-admin. +hide-agents = false + +# One of: "asyncio", "uvloop" +# This changes the event loop backend. +# uvloop is a fast libuv-based implementation but sometimes has +# compatibility issues. +# event-loop = "asyncio" + +# One of: "filelock", "pg_advisory", "redlock", "etcd" +# Choose the implementation of distributed lock. +# "filelock" is the simplest one when the manager is deployed on only one node. +# "pg_advisory" uses PostgreSQL's session-level advisory lock. +# "redlock" uses Redis-based distributed lock (Redlock) -- currently not supported. +# "etcd" uses etcd-based distributed lock via etcetra. +# distributed-lock = "pg_advisory" + +# The Docker image name that is used for importing external Docker images. +# You need to change this if your are at offline environments so that the manager +# uses an importer image from a private registry. +importer-image = "lablup/importer:manylinux2010" + + +[docker-registry] +# Enable or disable SSL certificate verification when accessing Docker registries. +ssl-verify = false # env: BACKEND_SKIP_SSLCERT_VALIDATION + + +[logging] +# One of: "NOTSET", "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL" +# Set the global logging level. +level = "INFO" + +# Multi-choice of: "console", "logstash", "file" +# For each choice, there must be a "logging." section +# in this config file as exemplified below. +drivers = ["console", "file"] + + +[logging.console] +# If set true, use ANSI colors if the console is a terminal. +# If set false, always disable the colored output in console logs. +colored = true + +# One of: "simple", "verbose" +format = "verbose" + + +[logging.file] +# The log file path and filename pattern. +# All messages are wrapped in single-line JSON objects. +# Rotated logs may have additional suffixes. +# For production, "/var/log/backend.ai" is recommended. +path = "./logs" +filename = "manager.log" + +# The maximum number of rotated logs. +backup-count = 5 + +# The log file size to begin rotation. +rotation-size = "10M" + + +[logging.logstash] +# The endpoint to publish logstash records. +endpoint = { host = "localhost", port = 9300 } + +# One of: "zmq.push", "zmq.pub", "tcp", "udp" +protocol = "tcp" + +# SSL configs when protocol = "tcp" +ssl-enabled = true +ssl-verify = true + + +# Specify additional package namespaces to include in the logs +# and their individual log levels. +# Note that the actual logging level applied is the conjunction of the global logging level and the +# logging levels specified here for each namespace. +[logging.pkg-ns] +"" = "WARNING" +"aiotools" = "INFO" +"aiohttp" = "INFO" +"ai.backend" = "INFO" +"alembic" = "INFO" +"sqlalchemy" = "WARNING" + + +[debug] +enabled = false +periodic-sync-stats = false # periodically sync container stat from Redis to the kernels.last_stat column. 
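The `# env: ...` annotations in the sample configurations above mark TOML keys that are meant to be overridable through environment variables (e.g., `BACKEND_DB_ADDR`, `BACKEND_DB_NAME`). The actual loading and validation logic lives in the `ai.backend.*` config modules added by this patch; the snippet below is only a minimal illustrative sketch of that idea, assuming a simple `host:port` format for `BACKEND_DB_ADDR` and using only the standard library (`tomllib` requires Python 3.11+; on the CPython 3.10 interpreter pinned in `pants.toml`, the `tomli` backport provides the same API).

```python
import os
import tomllib  # Python 3.11+; on 3.10 use: import tomli as tomllib


def load_manager_config(path: str) -> dict:
    """Read a manager TOML config and apply a couple of env overrides (sketch only)."""
    with open(path, "rb") as fp:
        cfg = tomllib.load(fp)

    # Illustrative overrides mirroring the "# env: ..." hints in the sample config.
    if addr := os.environ.get("BACKEND_DB_ADDR"):  # assumed "host:port" format
        host, _, port = addr.partition(":")
        cfg["db"]["addr"] = {"host": host, "port": int(port or 5432)}
    if name := os.environ.get("BACKEND_DB_NAME"):
        cfg["db"]["name"] = name

    return cfg


if __name__ == "__main__":
    cfg = load_manager_config("configs/manager/halfstack.toml")
    print(cfg["db"]["addr"], cfg["manager"]["service-addr"])
```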
diff --git a/configs/storage-proxy/sample.toml b/configs/storage-proxy/sample.toml new file mode 100644 index 0000000000..82bee8807b --- /dev/null +++ b/configs/storage-proxy/sample.toml @@ -0,0 +1,144 @@ +[etcd] +namespace = "local" # env: BACKEND_NAMESPACE +addr = { host = "127.0.0.1", port = 2379 } # env: BACKEND_ETCD_ADDR (host:port) +user = "" # env: BACKEND_ETCD_USER +password = "" # env: BACKEND_ETCD_PASSWORD + + +[storage-proxy] +# An identifier of this storage proxy, which must be unique in a cluster. +node-id = "i-storage-proxy-01" +num-proc = 4 +# The PID file for systemd or other daemon managers to keep track of. +# pid-file = "./storage-proxy.pid" +event-loop = "uvloop" + +# The maximum number of directory entries to return upon +# a scandir operation to prevent excessive loads/delays on filesystems. +# Settings it zero means no limit. +scandir-limit = 1000 + +# The maximum allowed size of a single upload session. +max-upload-size = "100g" + +# Used to generate JWT tokens for download/upload sessions +secret = "some-secret-private-for-storage-proxy" + +# The download/upload session tokens are valid for: +session-expire = "1d" + +# When executed as root (e.g., to bind under-1023 ports) +# it is recommended to set UID/GID to lower the privilege after port binding. +# If not specified, it defaults to the owner UID/GID of the "server.py" file +# of the installed package. +# user = 1000 +# group = 1000 + + +[api.client] +# Client-facing API +service-addr = { host = "0.0.0.0", port = 6021 } +ssl-enabled = false + + +[api.manager] +# Manager-facing API +# Recommended to have SSL and bind on a private IP only accessible by managers +service-addr = { host = "127.0.0.1", port = 6022 } +ssl-enabled = true +ssl-cert = "configs/storage-proxy/ssl/manager-api-selfsigned.cert.pem" +ssl-privkey = "configs/storage-proxy/ssl/manager-api-selfsigned.key.pem" + +# Used to authenticate managers +secret = "some-secret-shared-with-manager" + + +[debug] +# Enable the debug mode by overriding the global loglevel and "ai.backend" loglevel. +enabled = false + + +[logging] +# One of: "NOTSET", "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL" +# Set the global logging level. +level = "INFO" + +# Multi-choice of: "console", "logstash", "file" +# For each choice, there must be a "logging." section +# in this config file as exemplified below. +drivers = ["console"] + +[logging.pkg-ns] +"" = "WARNING" +"aiotools" = "INFO" +"aiohttp" = "INFO" +"ai.backend" = "INFO" + +[logging.console] +# If set true, use ANSI colors if the console is a terminal. +# If set false, always disable the colored output in console logs. +colored = true + +# One of: "simple", "verbose" +format = "simple" + + +[volume] +# volume section may define one or more named subsections with +# backend-specific configurations. +# It is your job to prepare the "/vfroot" directory and its inner mount +# points from actual storage devices or using local partitions. + + +[volume.local] +# The default, generic filesystem. +# It uses the standard syscalls to perform filesystem operations +# such as scanning the file/directory metadata. +# This does *NOT* support per-directory quota. +backend = "vfs" +path = "/vfroot/vfs" + + +[volume.fastlocal] +# An extended version for XFS, which supports per-directory quota +# based on xfs projects. 
+backend = "xfs" +path = "/vfroot/xfs" + + +[volume.mypure] +# An extended version for PureStorage FlashBlade nodes, which uses +# RapidFile Tools to perform filesystem metadata queries and the +# FB REST APIs to configure per-directory quota. +backend = "purestorage" +path = "/vfroot/fb1" + +[volume.mypure.options] +purity_endpoint = "https://pure01.example.com" +purity_ssl_verify = false +purity_api_token = "T-11111111-2222-3333-4444-000000000000" +purity_api_version = "1.8" +purity_fs_name = "FB-NFS1" # the name of filesystem used by the filesystem API + + +[volume.myceph] +# An extended version for CephFS, which supports extended inode attributes +# for per-directory quota and fast metadata queries. +backend = "cephfs" +path = "/vfroot/ceph-fuse" + +[volume.netapp] +backend = "netapp" +path = "/vfroot/netapp" + +[volume.netapp.options] +netapp_endpoint = "https://netapp.example.com" +netapp_admin = "signed-in-id" +netapp_password = "signed-in-pw" +netapp_svm = "svm-name" +netapp_volume_name = "netapp-volume-name" +netapp_qtree_name = "bai_qtree" # default qtree name for backend.ai exclusive +netapp_xcp_container_name = "netapp-xcp" # only required when xcp activated by container +netapp_xcp_hostname = "xcp-hostname" +# default xcp catalog path goes to the directory named "catalog" of the first NetApp volume +netapp_xcp_catalog_path = "path for xcp-catalog" # Hint: execute command cat /opt/NetApp/xFiles/xcp/xcp.ini and see nfs mount path diff --git a/configs/storage-proxy/ssl/manager-api-selfsigned.cert.pem b/configs/storage-proxy/ssl/manager-api-selfsigned.cert.pem new file mode 100644 index 0000000000..8c5cd2aad0 --- /dev/null +++ b/configs/storage-proxy/ssl/manager-api-selfsigned.cert.pem @@ -0,0 +1,30 @@ +-----BEGIN CERTIFICATE----- +MIIFHTCCAwWgAwIBAgIUORKCw1L6FMO+MktzKPnH35L8bQkwDQYJKoZIhvcNAQEL +BQAwHjEcMBoGA1UEAwwTc3RvcmFnZS1wcm94eS5sb2NhbDAeFw0yMDA4MTUwNjEx +MjlaFw0yMTA4MTUwNjExMjlaMB4xHDAaBgNVBAMME3N0b3JhZ2UtcHJveHkubG9j +YWwwggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQDVAyQu+0afCEaGfPFt +MM7t8MDrpI/FiWPHp1LfQ6zFS/Y9QqSdkP+nKW6yuuBLVCJDbgnMfiyXkebzLB8z +UL6tgVWH3+erNKwb0obX7z6A09pqK+LtWH238ekJx+qr7UgrO9fmfz7Y6w8NpBDf +hall8KDJk1FJqCXX0DOLenUbZPyV2dW1cMoa2AOfYNG7gvW3fVQjU/MGmTA1Wl7D +aHiW9VvSVf7C5tjiHQuXTgCSjBi5/+s7Azs5cRbqgn82ssLZNdllYHJjnaWa85vT +8Ww0kHSY3bCkBFveQs/CiEElv88IkpBgL5MR54O7dxLUouF2excw8gXOo3EdACWw +k9pk0lR5aOe4exUEIr8JCRTMvHqEfzXG511T9NrOmxTFCEVzKDGMacYqKHugkIjZ +V6DZidC1XOYk5bAl6jJSD+vnKksxUzbrhBdza31iNW4B7l8UzHrv+9/o7KEFUe+j +ypLSB3OTmJB5+wr5NhzGGiJy2DQ3yfTWI10qBGRtKaYm0d2dajNmpEyyO+mB3S0E +/WhRYuFhDS6j51sayWUYWn3/GCBrxgV31fIO6WHn1O4Eanevd6dPf/qCQnKlyzyU +uohJSGspi31yBxUwR6hzQE+5E6LbtiO+uzWFnhD6lbNahsgpwLayM9hFsL5o78Dq +AzojCteiukr6c8Vyva04F0YRgwIDAQABo1MwUTAdBgNVHQ4EFgQU0OP5rGBBbWCC +FAIbDegWUBahLlcwHwYDVR0jBBgwFoAU0OP5rGBBbWCCFAIbDegWUBahLlcwDwYD +VR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAgEAEmNtpv/RICBTN4MtyhVO +yPHHMXr9SrpfaeflClEW4oILg71WuyzWkAwQuBHEQCCp2ybDGXAlmLtvGlIiojD0 +UPiQbfTDTxh1trMKpLpYanlBsV8k4RupVLe1sB00e2hTq4Igxzfcm+4lnclTGImq +fVLOaJs5JmHmEzkix3c9mUj6lfIibHKDZQ/62mf6W1/BUk2u7mLY62RBtl5Dlnft +0FcLcz4NG17tMAi0D1KSATstliqTO6uc7ZOxUGBJ/ovhvVFlYzwbL6uve1dsdYnr +zFvIhXlsKHZeGgrb+mE/XDQ5PmjIyWy4WR38+OB2vDZ0+OaQyXqG7CXwl7Rubfi6 +m8F+gtREy/FZLtBkoqMVQ3lNivbxCY0iCSQdL2npF/ZqfmMcTtMiUC8zBsQ9U7VG +MPD9B0jzZ7v+gRIQhDrO99/6hpLpqthGHKAyhn0PqFFuBip9jC7JaFS6Vul8eio/ +BkAkIAC13cisrBPbjlxZJMZ/J8QyBoIwAyC7ulohADmpQxMczeyNN17J8N9O/xJ6 +U9tO6bY6elOWQWTN1WhTB+K0+UByNLdiRjwkBKK5KZc897l59BvIsa8y//ucwdiz +OW+RmiO2aVvND44fJ4zRQWBlciQKxmbF3Mm9AlrOK5qpPswMMVVvfw26dsPRguVR 
+psbTWD/iXTIaFf/manP0YpA= +-----END CERTIFICATE----- diff --git a/configs/storage-proxy/ssl/manager-api-selfsigned.key.pem b/configs/storage-proxy/ssl/manager-api-selfsigned.key.pem new file mode 100644 index 0000000000..5b222d9fa0 --- /dev/null +++ b/configs/storage-proxy/ssl/manager-api-selfsigned.key.pem @@ -0,0 +1,52 @@ +-----BEGIN PRIVATE KEY----- +MIIJRAIBADANBgkqhkiG9w0BAQEFAASCCS4wggkqAgEAAoICAQDVAyQu+0afCEaG +fPFtMM7t8MDrpI/FiWPHp1LfQ6zFS/Y9QqSdkP+nKW6yuuBLVCJDbgnMfiyXkebz +LB8zUL6tgVWH3+erNKwb0obX7z6A09pqK+LtWH238ekJx+qr7UgrO9fmfz7Y6w8N +pBDfhall8KDJk1FJqCXX0DOLenUbZPyV2dW1cMoa2AOfYNG7gvW3fVQjU/MGmTA1 +Wl7DaHiW9VvSVf7C5tjiHQuXTgCSjBi5/+s7Azs5cRbqgn82ssLZNdllYHJjnaWa +85vT8Ww0kHSY3bCkBFveQs/CiEElv88IkpBgL5MR54O7dxLUouF2excw8gXOo3Ed +ACWwk9pk0lR5aOe4exUEIr8JCRTMvHqEfzXG511T9NrOmxTFCEVzKDGMacYqKHug +kIjZV6DZidC1XOYk5bAl6jJSD+vnKksxUzbrhBdza31iNW4B7l8UzHrv+9/o7KEF +Ue+jypLSB3OTmJB5+wr5NhzGGiJy2DQ3yfTWI10qBGRtKaYm0d2dajNmpEyyO+mB +3S0E/WhRYuFhDS6j51sayWUYWn3/GCBrxgV31fIO6WHn1O4Eanevd6dPf/qCQnKl +yzyUuohJSGspi31yBxUwR6hzQE+5E6LbtiO+uzWFnhD6lbNahsgpwLayM9hFsL5o +78DqAzojCteiukr6c8Vyva04F0YRgwIDAQABAoICAE2c6mcv6Shy9Hd2OOVnHk62 +JST2/eekyrVpcxmkZ+QvbFYf3SkINw5qW0pGKwlna2CUTH+1DXxgjfzWe7vP1NoV +QNvUKm5IL9mnWLh/FJAOIQwEV5fRYWVPgHCu5gOk3mHaWS1D+dtBsmdu+zLmWbWp ++nvX8Itc+ATteIp+oQLNRfc3utV0dj9Xq+I7fc/LxMoJoUmKAgfKczVNay/Z9e5T +EhTOfpf3Utj4akvEARNkdnH3HHeREtg4K3hg+gctFS/dnguBG8zOGRQfZQzRb3Kr +m8BDLmRkrjCKuXaQ/OPLQp1GAdL6IqUobOg2V6cgffQUn3uXA6SqCsHdVt5C0RS2 +ybrAsW+3jUsBAEf2YhA6vNJbRmMiSuTuWPEjyfrCnyPA5NsVKMus9NWAHFCla7Wh +00gB5oGt5fPAFzSSxb6ETD/8VVioCNZtApcMcrGLvv71FxgyweEgHk+veCgjvDUI +puw43N2sTP5kZp71rQEsZpxBJ+LqkfBapBp/HLBevk7F/bPsO1GPdB+lz0mHDR4w +JHq7zhj7vm5nfhu7+VYiLISjlo6PRsgElDsJYIHPhP7FfE7SamLMwtZkniYx9+0h +gsri2SV9Z/f+TbYwGWJYVyWVsdUpAhB6wEE8YtY2csIl+v0bRhcnEVhnFjGWM+Qk +UlrihEki8eniUfBFWVHxAoIBAQD4hjWM7Ht65a7skawFfknfjX4aSIVPRw7QwUgY +/2DTG7iVud11flNl9HHdgMafj9tMJcQsLt7fes5xfR21m6TysMOFnhT53uOrSSkW +SsqtVZmghZ47/5A194/ey7FtR8pcfnuNm+3x6xykK4+ULo3LwBCMw8oRf+JECJ5A +pUqXWgD+6Ovs9qCb5oK8PWE6v92EQD7jP/tUPPDxZ7MtO5DoZx44AdfojI+EsAQJ +pXKodURjNPB0Ap2jaFLYlFaaRsGvu6Zw1vA/hNKlJQlc5wUcSybJ3y7aXXQwEz55 +psQ4LvCSG9BOETBArR73dox07ouXL+Esrs/Nb7FRfbl4KZupAoIBAQDba3fGc8GV +fgGY/KuUO0gzjHkjFTf7JWfELs0CrWJFgKB8/4Pg0uYrD3AV6SK/0wmcHkGHhtpB +WQjjYSLg1Wrm4PUdHqd2D4oEUAK2rXsiBuPQ0ib8dVKoBgNJhjeOnYoARfjsQAwa +gZ3LyoSe2Lrj9AMiWGxu9DO1SjmL/QncrdsW1rfNsnDWkqR0G61BW/fstZRFpanU +qOuL44CT/BG2+upuCvpAS87BQ/7xcYKAggjZ/GAyY5+LzO9/ZJjSbthXFhrpznWd +x8bgjqpErAN8FMyuhdIR2JPT5PhXp9PRxwxMFeQXrjUE8NfPoZI+cTe2p4oYpSo9 +aSlxdkb03B9LAoIBAQDuRTY4S05D0Mv50dwcVC0dYPnC1z7AeD8TFAw8szOwkwZQ +vqc9i4UH3eoBVQKZXoIBKsA/IBzcJFCjbDI8uOHXMHP0ulAgFHsw8G6tCb3xm30z +8od3vJkVtKlCEQal4Et6jGWGqjXiV+jY2U7J4ixeaWE0pE4qovJbCew3zIGMRGVq +AASZ5warEcDMGwkKG4OU1Ue19tTOubLDsAmQV2ih+KN3TQUk1waOT3c/jFk2e1+6 +wbm7b/qU/WNYdDfnp+jwxDdaPiiOrZiLbsvgPzmeG8svhoPUJf9MTFb8qU+9Efvi +CTqYblBV4eUrmoQlY8N/iw8XGO7wZqKZn7FfLH0pAoIBAQDVsB7QzZEdFr9Vj9VF +okpZsZeT3ClN1IrvG6kaz2KiT3d85Jc50ArKqtk87XSjHlMAkNK+u005UQ1/6+y8 +y/u1aCYuutjZ+J4wPv+1itQdOlqJ1vCS82uRKXHwG99la/Wm+H5JDsL62XqkqtFn +pai1McIPo8/OatMk3mmW9NKy8ToqTuhoUjzkK6IvVXjw6zFTfS8ueP0hl3T2IpTp +ChbyfgDrNJOtJGbx/1d9Kk+u+XTHXqsMx8rsqqQgDAdtAPh+L4/8Xc5b3+DGdwpJ +oMAwCk2gNcF1Edg+B4L4UwDODyzhuHwPt9/4tNloY3D6kOZyY03xXID7l7v9vPOz +qeBrAoIBAQCsEUGimITSLt2XyB892vcIjMPWBsKdu4U9TUjWmijwod+sRlutDR/O +mIe6nVAFG2Z6qTIlHbCQ8KDzEwp34ZzPS+Abjs3iA0K8g5V6giybP5avblyWZN79 +MbRclrfL78kCLM3dwVqwQtLMhuy1ivABk4lzMLh1RCZqHH7hcokkDbwcOXtiug7F +4SD8imq+zagAV14MZKQm9pP6oMK//mYJNqv91CmftjHaE13qxSABsQB5jBkBglyq 
+tYuPqS6laAC4gbGtum4UGIJd7IzeOjr+hSYbLPTd/iysBITrEERg9QFVk7tdhD0k +DSjFozn22Pk72SaIi7w0Bf5CLjSkrNkV +-----END PRIVATE KEY----- diff --git a/configs/webserver/sample.conf b/configs/webserver/sample.conf new file mode 100644 index 0000000000..d63fc8a2ef --- /dev/null +++ b/configs/webserver/sample.conf @@ -0,0 +1,106 @@ +[service] +ip = "0.0.0.0" +port = 8080 +wsproxy.url = "" + +# If you need ssl-enabled server, fill these configurations. +#ssl-enabled = true +#ssl-cert = +#ssl-privkey = + +# Set or enable it when using nginx proxy +#force-endpoint-protocol = "https" + +# "webui" mode is for serving "backend.ai-webui" PWA, +# where non-existent URLs are fall-back to "index.html" +# and "config.ini" is generated on the fly. +# "static" mode is for serving the static directory as-is, +# without any fallback or hooks. +mode = "webui" +# Enable signup feature support. +enable_signup = false +# Let anonymous user can request an email with a link to change password. +allow_anonymous_change_password = false +# Allow users to see user's current project resource monitor +allow_project_resource_monitor = false +# Allow users to change signin mode between ID/Password and IAM mode +allow_change_signin_mode = false +# Allow users to use the specific environment image by typing the exact name. +allow_manual_image_name_for_session = false +# Allow users can sign up without confirmation such as token or email confirmation. +allow_signup_without_confirmation = false +# Debug mode for Web-UI (not Webserver's debug flag) +webui_debug = false +# Enable masking user information +mask_user_info = false + +[resources] +# Display "Open port to public" checkbox in the app launcher. +# If checked, the app will be accessible by anyone who has network to the URL. +open_port_to_public = false +# Maximum CPU cores allowed per container (int) +max_cpu_cores_per_container = 64 +# Maximum memory allowed per container (int) +max_memory_per_container = 64 +# Maximum CUDA devices allowed per container (int) +max_cuda_devices_per_container = 16 +# Maximum CUDA fGPUs allowed per container (int) +max_cuda_shares_per_container = 16 +# Maximum shared memory allowed per container (float) +max_shm_per_container = 2 +# Maximum per-file upload size (bytes) +max_file_upload_size = 4294967296 + +[environments] +# Comma-separated string +# Image name should contain the repository (registry path and image name) part of the full image URL, excluding the protocol and tag +# e.g. cr.backend.ai/stable/python +# You should pick default_environment in ui section too. +#allowlist = "" + +[plugin] +# Comma-separated string +# Should be same as plugin file in web UI plugin directory. 
+#page = "" + +[ui] +brand = "Lablup Cloud" +# Default environment to show on session launcher +# default_environment = 'index.docker.io/lablup/python-tensorflow' +# Default environment to import GitHub repositories / notebooks +# default_import_environment = 'index.docker.io/lablup/python:3.6-ubuntu18.04' +# Comma-separated sidebar menu pages +#menu_blocklist = "statistics" + +[api] +domain = "default" +endpoint = "https://api.backend.ai" +# endpoint = "https://api.backend.ai,https://alt.backend.ai" # for HA manager endpoints +text = "Backend.AI Cloud" +ssl-verify = true +# Cookie key to be used for token-based login +auth_token_name = 'sToken' + +[session] +redis.host = "localhost" +redis.port = 6379 +# redis.db = 0 +# redis.password = "mysecret" +max_age = 604800 # 1 week +flush_on_startup = false +# Time to block login when an email consecutively fails to login +login_block_time = 1200 # 20 min (in sec) +# Number of allowed consecutive failed logins. If this user fails +# to login consecutively over this number, login with the account +# is blocked for ``block_time``. +login_allowed_fail_count = 10 +# Auto logout when user closes all Backend.AI web UI tab / window +#auto_logout = false + +# Add a manually configured license information shown in the UI. +#[license] +#edition = "Open Source" +#valid_since = "" +#valid_until = "" + +# vim: ft=toml diff --git a/docker-compose.halfstack-2203.yml b/docker-compose.halfstack-2203.yml index ea79b2cc66..14124fd1a6 100644 --- a/docker-compose.halfstack-2203.yml +++ b/docker-compose.halfstack-2203.yml @@ -14,7 +14,7 @@ services: - POSTGRES_PASSWORD=develove - POSTGRES_DB=backend volumes: - - "./tmp/backend.ai-halfstack/${DATADIR_PREFIX:-.}/postgres-data:/var/lib/postgresql/data:rw" + - "./volumes/${DATADIR_PREFIX:-.}/postgres-data:/var/lib/postgresql/data:rw" healthcheck: test: ["CMD", "pg_isready", "-U", "postgres"] interval: 5s @@ -29,7 +29,7 @@ services: ports: - "8110:6379" volumes: - - "./tmp/backend.ai-halfstack/${DATADIR_PREFIX:-.}/redis-data:/data:rw" + - "./volumes/${DATADIR_PREFIX:-.}/redis-data:/data:rw" command: > redis-server --appendonly yes @@ -43,7 +43,7 @@ services: image: quay.io/coreos/etcd:v3.5.4 restart: unless-stopped volumes: - - "./tmp/backend.ai-halfstack/${DATADIR_PREFIX:-.}/etcd-data:/etcd-data:rw" + - "./volumes/${DATADIR_PREFIX:-.}/etcd-data:/etcd-data:rw" networks: - half ports: diff --git a/docker/linuxkit-nsenter/Dockerfile b/docker/linuxkit-nsenter/Dockerfile new file mode 100644 index 0000000000..dab9d82836 --- /dev/null +++ b/docker/linuxkit-nsenter/Dockerfile @@ -0,0 +1,4 @@ +FROM justincormack/nsenter1 +LABEL ai.backend.system=1 \ + ai.backend.version=1 +ENTRYPOINT ["/usr/bin/nsenter1"] diff --git a/docker/socket-relay/Dockerfile b/docker/socket-relay/Dockerfile new file mode 100644 index 0000000000..6bb13b7d63 --- /dev/null +++ b/docker/socket-relay/Dockerfile @@ -0,0 +1,8 @@ +FROM alpine:3.12 + +RUN apk add --no-cache socat # 1.7.3.4-r1 + +ENTRYPOINT ["/usr/bin/socat"] + +LABEL ai.backend.system=1 \ + ai.backend.version=1 diff --git a/fixtures/manager/example-keypairs.json b/fixtures/manager/example-keypairs.json new file mode 100644 index 0000000000..5d76bc821c --- /dev/null +++ b/fixtures/manager/example-keypairs.json @@ -0,0 +1,172 @@ +{ + "domains": [ + { + "name": "default", + "description": "The default domain", + "is_active": true, + "total_resource_slots": {}, + "allowed_vfolder_hosts": {}, + "allowed_docker_registries": ["cr.backend.ai"] + } + ], + "groups": [ + { + "id": 
"2de2b969-1d04-48a6-af16-0bc8adb3c831", + "name": "default", + "description": "The default user group", + "is_active": true, + "domain_name": "default", + "total_resource_slots": {}, + "allowed_vfolder_hosts": {} + } + ], + "scaling_groups": [ + { + "name": "default", + "description": "The default agent scaling group", + "is_active": true, + "driver": "static", + "driver_opts": {}, + "scheduler": "fifo", + "scheduler_opts": {} + } + ], + "sgroups_for_domains": [ + { + "scaling_group": "default", + "domain": "default" + } + ], + "users": [ + { + "uuid": "f38dea23-50fa-42a0-b5ae-338f5f4693f4", + "username": "admin", + "email": "admin@lablup.com", + "password": "wJalrXUt", + "need_password_change": false, + "full_name": "Admin Lablup", + "description": "Lablup's Admin Account", + "status": "active", + "status_info": "admin-requested", + "domain_name": "default", + "role": "superadmin" + }, + { + "uuid": "4f13d193-f646-425a-a340-270c4d2b9860", + "username": "domain-admin", + "email": "domain-admin@lablup.com", + "password": "cWbsM_vB", + "need_password_change": false, + "full_name": "Default Domain Admin Lablup", + "description": "Lablup's Default Domain Admin Account", + "status": "active", + "status_info": "admin-requested", + "domain_name": "default", + "role": "admin" + }, + { + "uuid": "dfa9da54-4b28-432f-be29-c0d680c7a412", + "username": "user", + "email": "user@lablup.com", + "password": "C8qnIo29", + "need_password_change": false, + "full_name": "User Lablup", + "description": "Lablup's User Account", + "status": "active", + "status_info": "admin-requested", + "domain_name": "default", + "role": "user" + }, + { + "uuid": "2e10157d-20ca-4bd0-9806-3f909cbcd0e6", + "username": "monitor", + "email": "monitor@lablup.com", + "password": "7tuEwF1J", + "need_password_change": false, + "full_name": "Monitor Lablup", + "description": "Lablup's Monitor Account", + "status": "active", + "status_info": "admin-requested", + "domain_name": "default", + "role": "monitor" + } + ], + "association_groups_users": [ + { + "group_id": "2de2b969-1d04-48a6-af16-0bc8adb3c831", + "user_id": "f38dea23-50fa-42a0-b5ae-338f5f4693f4" + }, + { + "group_id": "2de2b969-1d04-48a6-af16-0bc8adb3c831", + "user_id": "4f13d193-f646-425a-a340-270c4d2b9860" + }, + { + "group_id": "2de2b969-1d04-48a6-af16-0bc8adb3c831", + "user_id": "dfa9da54-4b28-432f-be29-c0d680c7a412" + }, + { + "group_id": "2de2b969-1d04-48a6-af16-0bc8adb3c831", + "user_id": "2e10157d-20ca-4bd0-9806-3f909cbcd0e6" + } + ], + "keypair_resource_policies": [ + { + "name": "default", + "default_for_unspecified": "UNLIMITED", + "total_resource_slots": {}, + "max_session_lifetime": 0, + "max_concurrent_sessions": 5, + "max_containers_per_session": 1, + "max_vfolder_count": 10, + "max_vfolder_size": 0, + "idle_timeout": 3600, + "allowed_vfolder_hosts": ["local:volume1"] + } + ], + "keypairs": [ + { + "user_id": "admin@lablup.com", + "access_key": "AKIAIOSFODNN7EXAMPLE", + "secret_key": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + "is_active": true, + "resource_policy": "default", + "rate_limit": 30000, + "num_queries": 0, + "is_admin": true, + "user": "f38dea23-50fa-42a0-b5ae-338f5f4693f4" + }, + { + "user_id": "domain-admin@lablup.com", + "access_key": "AKIAHUKCHDEZGEXAMPLE", + "secret_key": "cWbsM_vBB4CzTW7JdORRMx8SjGI3-wEXAMPLEKEY", + "is_active": true, + "resource_policy": "default", + "rate_limit": 30000, + "num_queries": 0, + "is_admin": true, + "user": "4f13d193-f646-425a-a340-270c4d2b9860" + }, + { + "user_id": "user@lablup.com", + "access_key": 
"AKIANABBDUSEREXAMPLE", + "secret_key": "C8qnIo29EZvXkPK_MXcuAakYTy4NYrxwmCEyNPlf", + "is_active": true, + "resource_policy": "default", + "rate_limit": 30000, + "num_queries": 0, + "is_admin": false, + "user": "dfa9da54-4b28-432f-be29-c0d680c7a412" + }, + { + "user_id": "monitor@lablup.com", + "access_key": "AKIANAMONITOREXAMPLE", + "secret_key": "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", + "is_active": true, + "resource_policy": "default", + "rate_limit": 30000, + "num_queries": 0, + "is_admin": false, + "user": "2e10157d-20ca-4bd0-9806-3f909cbcd0e6" + } + ] +} diff --git a/fixtures/manager/example-resource-presets.json b/fixtures/manager/example-resource-presets.json new file mode 100644 index 0000000000..63c4169013 --- /dev/null +++ b/fixtures/manager/example-resource-presets.json @@ -0,0 +1,33 @@ +{ + "resource_presets": [ + { + "name": "01-small", + "resource_slots": { + "cpu": "8", + "mem": "34359738368", + "cuda.device": "1", + "cuda.shares": "0.5" + } + }, + { + "name": "02-medium", + "resource_slots": { + "cpu": "24", + "mem": "171798691840", + "cuda.device": "2", + "cuda.shares": "2.0" + }, + "shared_memory": "1073741824" + }, + { + "name": "03-large", + "resource_slots": { + "cpu": "64", + "mem": "343597383680", + "cuda.device": "4", + "cuda.shares": "4.0" + }, + "shared_memory": "2147483648" + } + ] +} diff --git a/fixtures/manager/example-session-templates.json b/fixtures/manager/example-session-templates.json new file mode 100755 index 0000000000..4e69cca076 --- /dev/null +++ b/fixtures/manager/example-session-templates.json @@ -0,0 +1,71 @@ +{ + "session_templates": [ + { + "id": "c1b8441a-ba46-4a83-8727-de6645f521b4", + "is_active": true, + "domain_name": "default", + "group_id": "2de2b969-1d04-48a6-af16-0bc8adb3c831", + "user_uuid": "f38dea23-50fa-42a0-b5ae-338f5f4693f4", + "type": "TASK", + "name": "jupyter", + "template": { + "api_version": "6", + "kind": "task_template", + "metadata": { + "name": "cr.backend.ai/testing/ngc-pytorch", + "tag": "20.11-py3" + }, + "spec": { + "session_type": "interactive", + "kernel": { + "image": "cr.backend.ai/testing/ngc-pytorch:20.11-py3", + "environ": {}, + "run": null, + "git": null + }, + "scaling_group": "default", + "mounts": { + }, + "resources": { + "cpu": "2", + "mem": "4g", + "cuda.shares": "0.2" + } + } + } + }, + { + "id": "59062449-4f57-4434-975d-add2a593438c", + "is_active": true, + "domain_name": "default", + "group_id": "2de2b969-1d04-48a6-af16-0bc8adb3c831", + "user_uuid": "f38dea23-50fa-42a0-b5ae-338f5f4693f4", + "type": "TASK", + "name": "rstudio", + "template": { + "api_version": "6", + "kind": "task_template", + "metadata": { + "name": "cr.backend.ai/cloud/r-base", + "tag": "4.0" + }, + "spec": { + "session_type": "interactive", + "kernel": { + "image": "cr.backend.ai/cloud/r-base:4.0", + "environ": {}, + "run": null, + "git": null + }, + "scaling_group": "default", + "mounts": { + }, + "resources": { + "cpu": "1", + "mem": "2g" + } + } + } + } + ] +} diff --git a/pants b/pants new file mode 100755 index 0000000000..5c84ca4e97 --- /dev/null +++ b/pants @@ -0,0 +1,395 @@ +#!/usr/bin/env bash +# Copyright 2020 Pants project contributors (see CONTRIBUTORS.md). +# Licensed under the Apache License, Version 2.0 (see LICENSE). + +# =============================== NOTE =============================== +# This ./pants bootstrap script comes from the pantsbuild/setup +# project. It is intended to be checked into your code repository so +# that other developers have the same setup. 
+# +# Learn more here: https://www.pantsbuild.org/docs/installation +# ==================================================================== + +set -eou pipefail + +# NOTE: To use an unreleased version of Pants from the pantsbuild/pants main branch, +# locate the main branch SHA, set PANTS_SHA= in the environment, and run this script as usual. +# +# E.g., PANTS_SHA=725fdaf504237190f6787dda3d72c39010a4c574 ./pants --version + +PYTHON_BIN_NAME="${PYTHON:-unspecified}" + +# Set this to specify a non-standard location for this script to read the Pants version from. +# NB: This will *not* cause Pants itself to use this location as a config file. +# You can use PANTS_CONFIG_FILES or --pants-config-files to do so. +PANTS_TOML=${PANTS_TOML:-pants.toml} + +PANTS_BIN_NAME="${PANTS_BIN_NAME:-$0}" + +PANTS_SETUP_CACHE="${PANTS_SETUP_CACHE:-${XDG_CACHE_HOME:-$HOME/.cache}/pants/setup}" +# If given a relative path, we fix it to be absolute. +if [[ "$PANTS_SETUP_CACHE" != /* ]]; then + PANTS_SETUP_CACHE="${PWD}/${PANTS_SETUP_CACHE}" +fi + +PANTS_BOOTSTRAP="${PANTS_SETUP_CACHE}/bootstrap-$(uname -s)-$(uname -m)" + +_PEX_VERSION=2.1.62 +_PEX_URL="https://github.com/pantsbuild/pex/releases/download/v${_PEX_VERSION}/pex" +_PEX_EXPECTED_SHA256="56668b1ca147bd63141e586ffee97c7cc51ce8e6eac6c9b7a4bf1215b94396e5" + +VIRTUALENV_VERSION=20.4.7 +VIRTUALENV_REQUIREMENTS=$( +cat << EOF +virtualenv==${VIRTUALENV_VERSION} --hash sha256:2b0126166ea7c9c3661f5b8e06773d28f83322de7a3ff7d06f0aed18c9de6a76 +filelock==3.0.12 --hash sha256:929b7d63ec5b7d6b71b0fa5ac14e030b3f70b75747cef1b10da9b879fef15836 +six==1.16.0 --hash sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 +distlib==0.3.2 --hash sha256:23e223426b28491b1ced97dc3bbe183027419dfc7982b4fa2f05d5f3ff10711c +appdirs==1.4.4 --hash sha256:a841dacd6b99318a741b166adb07e19ee71a274450e68237b4650ca1055ab128 +importlib-resources==5.1.4; python_version < "3.7" --hash sha256:e962bff7440364183203d179d7ae9ad90cb1f2b74dcb84300e88ecc42dca3351 +importlib-metadata==4.5.0; python_version < "3.8" --hash sha256:833b26fb89d5de469b24a390e9df088d4e52e4ba33b01dc5e0e4f41b81a16c00 +zipp==3.4.1; python_version < "3.10" --hash sha256:51cb66cc54621609dd593d1787f286ee42a5c0adbb4b29abea5a63edc3e03098 +typing-extensions==3.10.0.0; python_version < "3.8" --hash sha256:779383f6086d90c99ae41cf0ff39aac8a7937a9283ce0a414e5dd782f4c94a84 +EOF +) + +COLOR_RED="\x1b[31m" +COLOR_GREEN="\x1b[32m" +COLOR_YELLOW="\x1b[33m" +COLOR_RESET="\x1b[0m" + +function log() { + echo -e "$@" 1>&2 +} + +function die() { + (($# > 0)) && log "${COLOR_RED}$*${COLOR_RESET}" + exit 1 +} + +function green() { + (($# > 0)) && log "${COLOR_GREEN}$*${COLOR_RESET}" +} + +function warn() { + (($# > 0)) && log "${COLOR_YELLOW}$*${COLOR_RESET}" +} + +function tempdir { + mkdir -p "$1" + mktemp -d "$1"/pants.XXXXXX +} + +function get_exe_path_or_die { + local exe="$1" + if ! command -v "${exe}"; then + die "Could not find ${exe}. Please ensure ${exe} is on your PATH." 
+ fi +} + +function get_pants_config_string_value { + local config_key="$1" + local optional_space="[[:space:]]*" + local prefix="^${config_key}${optional_space}=${optional_space}" + local raw_value + raw_value="$(sed -ne "/${prefix}/ s|${prefix}||p" "${PANTS_TOML}")" + local optional_suffix="${optional_space}(#.*)?$" + echo "${raw_value}" \ + | sed -E \ + -e "s|^'([^']*)'${optional_suffix}|\1|" \ + -e 's|^"([^"]*)"'"${optional_suffix}"'$|\1|' \ + && return 0 + return 0 +} + +function get_python_major_minor_version { + local python_exe="$1" + "$python_exe" </dev/null 2>&1; then + continue + fi + if [[ -n "$(check_python_exe_compatible_version "${interpreter_path}")" ]]; then + echo "${interpreter_path}" && return 0 + fi + done +} + +function determine_python_exe { + local pants_version="$1" + set_supported_python_versions "${pants_version}" + local requirement_str="For \`pants_version = \"${pants_version}\"\`, Pants requires Python ${supported_message} to run." + + local python_exe + if [[ "${PYTHON_BIN_NAME}" != 'unspecified' ]]; then + python_exe="$(get_exe_path_or_die "${PYTHON_BIN_NAME}")" || exit 1 + if [[ -z "$(check_python_exe_compatible_version "${python_exe}")" ]]; then + die "Invalid Python interpreter version for ${python_exe}. ${requirement_str}" + fi + else + python_exe="$(determine_default_python_exe)" + if [[ -z "${python_exe}" ]]; then + die "No valid Python interpreter found. ${requirement_str} Please check that a valid interpreter is installed and on your \$PATH." + fi + fi + echo "${python_exe}" +} + +function compute_sha256 { + local python="$1" + local path="$2" + + "$python" <&2 || exit 1 + fi + echo "${bootstrapped}" +} + +function scrub_env_vars { + # Ensure the virtualenv PEX runs as shrink-wrapped. + # See: https://github.com/pantsbuild/setup/issues/105 + if [[ -n "${!PEX_@}" ]]; then + warn "Scrubbing ${!PEX_@}" + unset "${!PEX_@}" + fi + # Also ensure pip doesn't think packages on PYTHONPATH + # are already installed. + if [ -n "${PYTHONPATH:-}" ]; then + warn "Scrubbing PYTHONPATH" + unset PYTHONPATH + fi +} + +function bootstrap_virtualenv { + local python="$1" + local bootstrapped="${PANTS_BOOTSTRAP}/virtualenv-${VIRTUALENV_VERSION}/virtualenv.pex" + if [[ ! -f "${bootstrapped}" ]]; then + ( + green "Creating the virtualenv PEX." + pex_path="$(bootstrap_pex "${python}")" || exit 1 + mkdir -p "${PANTS_BOOTSTRAP}" + local staging_dir + staging_dir=$(tempdir "${PANTS_BOOTSTRAP}") + cd "${staging_dir}" + echo "${VIRTUALENV_REQUIREMENTS}" > requirements.txt + ( + scrub_env_vars + "${python}" "${pex_path}" -r requirements.txt -c virtualenv -o virtualenv.pex + ) + mkdir -p "$(dirname "${bootstrapped}")" + mv -f "${staging_dir}/virtualenv.pex" "${bootstrapped}" + rm -rf "${staging_dir}" + ) 1>&2 || exit 1 + fi + echo "${bootstrapped}" +} + +function find_links_url { + local pants_version="$1" + local pants_sha="$2" + echo -n "https://binaries.pantsbuild.org/wheels/pantsbuild.pants/${pants_sha}/${pants_version/+/%2B}/index.html" +} + +function get_version_for_sha { + local sha="$1" + + # Retrieve the Pants version associated with this commit. + local pants_version + pants_version="$(curl --proto "=https" \ + --tlsv1.2 \ + --fail \ + --silent \ + --location \ + "https://raw.githubusercontent.com/pantsbuild/pants/${sha}/src/python/pants/VERSION")" + + # Construct the version as the release version from src/python/pants/VERSION, plus the string `+gitXXXXXXXX`, + # where the XXXXXXXX is the first 8 characters of the SHA. 
+ echo "${pants_version}+git${sha:0:8}" +} + +function bootstrap_pants { + local pants_version="$1" + local python="$2" + local pants_sha="${3:-}" + + local pants_requirement="pantsbuild.pants==${pants_version}" + local maybe_find_links + if [[ -z "${pants_sha}" ]]; then + maybe_find_links="" + else + maybe_find_links="--find-links=$(find_links_url "${pants_version}" "${pants_sha}")" + fi + local python_major_minor_version + python_major_minor_version="$(get_python_major_minor_version "${python}")" + local target_folder_name="${pants_version}_py${python_major_minor_version}" + local bootstrapped="${PANTS_BOOTSTRAP}/${target_folder_name}" + + if [[ ! -d "${bootstrapped}" ]]; then + ( + green "Bootstrapping Pants using ${python}" + local staging_dir + staging_dir=$(tempdir "${PANTS_BOOTSTRAP}") + local virtualenv_path + virtualenv_path="$(bootstrap_virtualenv "${python}")" || exit 1 + green "Installing ${pants_requirement} into a virtual environment at ${bootstrapped}" + ( + scrub_env_vars + # shellcheck disable=SC2086 + "${python}" "${virtualenv_path}" --quiet --no-download "${staging_dir}/install" && \ + # Grab the latest pip, but don't advance setuptools past 58 which drops support for the + # `setup` kwarg `use_2to3` which Pants 1.x sdist dependencies (pystache) use. + "${staging_dir}/install/bin/pip" install --quiet -U pip "setuptools<58" && \ + "${staging_dir}/install/bin/pip" install ${maybe_find_links} --quiet --progress-bar off "${pants_requirement}" + ) && \ + ln -s "${staging_dir}/install" "${staging_dir}/${target_folder_name}" && \ + mv "${staging_dir}/${target_folder_name}" "${bootstrapped}" && \ + green "New virtual environment successfully created at ${bootstrapped}." + ) 1>&2 || exit 1 + fi + echo "${bootstrapped}" +} + +# Ensure we operate from the context of the ./pants buildroot. 
+cd "$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd -P)" +pants_version="$(determine_pants_version)" +python="$(determine_python_exe "${pants_version}")" +pants_dir="$(bootstrap_pants "${pants_version}" "${python}" "${PANTS_SHA:-}")" || exit 1 + +pants_python="${pants_dir}/bin/python" +pants_binary="${pants_dir}/bin/pants" +pants_extra_args="" +if [[ -n "${PANTS_SHA:-}" ]]; then + pants_extra_args="${pants_extra_args} --python-repos-repos=$(find_links_url "$pants_version" "$PANTS_SHA")" +fi + +# shellcheck disable=SC2086 +exec "${pants_python}" "${pants_binary}" ${pants_extra_args} \ + --pants-bin-name="${PANTS_BIN_NAME}" --pants-version=${pants_version} "$@" diff --git a/pants.ci.toml b/pants.ci.toml new file mode 100644 index 0000000000..cf98e4ab5a --- /dev/null +++ b/pants.ci.toml @@ -0,0 +1,6 @@ +[GLOBAL] +dynamic_ui = false +colors = true + +[pytest] +args = ["-vv", "--no-header"] diff --git a/pants.toml b/pants.toml new file mode 100644 index 0000000000..780a6e0837 --- /dev/null +++ b/pants.toml @@ -0,0 +1,68 @@ +[GLOBAL] +pants_version = "2.11.0" +pythonpath = ["%(buildroot)s/tools/pants-plugins"] +local_execution_root_dir="%(buildroot)s/.tmp" +backend_packages = [ + "pants.backend.python", + "pants.backend.python.lint.flake8", + "pants.backend.python.typecheck.mypy", + "pants.backend.experimental.python", + "setupgen", + "platform_resources", +] +pants_ignore = [ + "scripts", + "plugins", + "docs", # TODO: docs build config +] + +[anonymous-telemetry] +enabled = false + +[source] +root_patterns = [ + "/", + "/src", + "/stubs", + "/tests", + "/tools/pants-plugins", +] + +[python] +enable_resolves = true +interpreter_constraints = ["CPython==3.10.4"] +lockfile_generator = "pex" + +[python.resolves] +python-default = "python.lock" +python-kernel = "python-kernel.lock" + +# [setup-py-generation] +# first_party_depenency_version_scheme = "exact" + +[flake8] +version = "flake8>=4.0" +extra_requirements.add = [ + "flake8-commas>=2.1", + "setuptools>=60.0", +] +lockfile = "tools/flake8.lock" + +[pytest] +version = "pytest>=7.0" +extra_requirements.add = [ + "pytest-asyncio>=0.18", + "pytest-aiohttp>=1.0.4", + "pytest-dependency>=0.5.1", + "pytest-mock>=3.5.0", + "aioresponses>=0.7.3", +] +args = ["-v", "-m", "'not integration'"] +lockfile = "tools/pytest.lock" +execution_slot_var = "BACKEND_TEST_EXEC_SLOT" + +[mypy] +version = "mypy>=0.950" +extra_requirements.add = [ +] +lockfile = "tools/mypy.lock" diff --git a/plugins/.gitignore b/plugins/.gitignore new file mode 100644 index 0000000000..1345720ad0 --- /dev/null +++ b/plugins/.gitignore @@ -0,0 +1,4 @@ +# This directory is reserved for cloning plugin repositories. +/* +!.gitignore +!README.md diff --git a/plugins/README.md b/plugins/README.md new file mode 100644 index 0000000000..606b4d8091 --- /dev/null +++ b/plugins/README.md @@ -0,0 +1,11 @@ +Plugin Development Flow +----------------------- + +Run `./scripts/install-plugin.sh {github-owner}/{repo-name}`. +(Example: `./scripts/install-plugin.sh lablup/backend.ai-accelerator-cuda-mock`) + +The plugin code will be cloned into `./plugins/{repo-name}` and it will be installed +as an editable package inside the Pants exported unified virtualenv. + +Note that whenever you run `./pants export ::` again, you need to run +`./scripts/reinstall-plugins.sh` again to redo editable installation. diff --git a/py b/py new file mode 100755 index 0000000000..7c084d73cd --- /dev/null +++ b/py @@ -0,0 +1,14 @@ +#! /bin/bash +if [ ! 
-d dist/export/python/virtualenvs/python-default ]; then + >&2 echo "The exported virtualenv does not exist." + >&2 echo "Please run './pants export ::' first and try again." + exit 1 +fi +PYTHON_VERSION=$(cat pants.toml | python -c 'import sys,re;m=re.search("CPython==([^\"]+)", sys.stdin.read());print(m.group(1) if m else sys.exit(1))') +if [ $? -ne 0 ]; then + >&2 echo "Could not read the target CPython interpreter version from pants.toml" + exit 1 +fi +LOCKSET=${LOCKSET:-python-default/$PYTHON_VERSION} +source dist/export/python/virtualenvs/$LOCKSET/bin/activate +PYTHONPATH=src python "$@" diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000..a765a9cd63 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,55 @@ +[tool.towncrier] +package = "ai.backend.manager" +filename = "CHANGELOG.md" +directory = "changes/" +title_format = "## {version} ({project_date})" +template = "changes/template.md" +start_string = "\n" +issue_format = "([#{issue}](https://github.com/lablup/backend.ai-manager/issues/{issue}))" +underlines = ["", "", ""] + +[[tool.towncrier.type]] + directory = "breaking" + name = "Breaking Changes" + showcontent = true + +[[tool.towncrier.type]] + directory = "feature" + name = "Features" + showcontent = true + +[[tool.towncrier.type]] + directory = "deprecation" + name = "Deprecations" + showcontent = true + +[[tool.towncrier.type]] + directory = "fix" + name = "Fixes" + showcontent = true + +[[tool.towncrier.type]] + directory = "doc" + name = "Documentation Changes" + showcontent = true + +[[tool.towncrier.type]] + directory = "misc" + name = "Miscellaneous" + showcontent = true + +[tool.pytest.ini_options] +testpaths = "tests" +markers = [ + "integration: Test cases that spawn Dockerized kernel sessions", +] +filterwarnings = [ + "ignore::DeprecationWarning:etcd3.*:", +] +asyncio_mode = "auto" + +[tool.mypy] +ignore_missing_imports = true +mypy_path = "stubs:src" +namespace_packages = true +explicit_package_bases = true diff --git a/python-kernel.lock b/python-kernel.lock new file mode 100644 index 0000000000..6cff217b17 --- /dev/null +++ b/python-kernel.lock @@ -0,0 +1,563 @@ +// This lockfile was autogenerated by Pants. 
To regenerate, run: +// +// ./pants generate-lockfiles --resolve=python-kernel +// +// --- BEGIN PANTS LOCKFILE METADATA: DO NOT EDIT OR REMOVE --- +// { +// "version": 2, +// "valid_for_interpreter_constraints": [ +// "CPython==3.10.4" +// ], +// "generated_with_requirements": [ +// "async_timeout~=3.0", +// "attrs~=21.2", +// "janus~=0.6.1", +// "jupyter-client~=6.1", +// "msgpack~=1.0", +// "pyzmq~=22.2", +// "uvloop~=0.16" +// ] +// } +// --- END PANTS LOCKFILE METADATA --- + +{ + "allow_builds": true, + "allow_prereleases": false, + "allow_wheels": true, + "build_isolation": true, + "constraints": [], + "locked_resolves": [ + { + "locked_requirements": [ + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "4291ca197d287d274d0b6cb5d6f8f8f82d434ed288f962539ff18cc9012f9ea3", + "url": "https://files.pythonhosted.org/packages/e1/1e/5a4441be21b0726c4464f3f23c8b19628372f606755a9d2e46c187e65ec4/async_timeout-3.0.1-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "0c3c816a028d47f659d6ff5c745cb2acf1f966da1fe5c19c77a70282b25f4c5f", + "url": "https://files.pythonhosted.org/packages/a1/78/aae1545aba6e87e23ecab8d212b58bb70e72164b67eb090b81bb17ad38e3/async-timeout-3.0.1.tar.gz" + } + ], + "project_name": "async-timeout", + "requires_dists": [], + "requires_python": ">=3.5.3", + "version": "3.0.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "2d27e3784d7a565d36ab851fe94887c5eccd6a463168875832a1be79c82828b4", + "url": "https://files.pythonhosted.org/packages/be/be/7abce643bfdf8ca01c48afa2ddf8308c2308b0c3b239a44e57d020afa0ef/attrs-21.4.0-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "626ba8234211db98e869df76230a137c4c40a12d72445c45d5f5b716f076e2fd", + "url": "https://files.pythonhosted.org/packages/d7/77/ebb15fc26d0f815839ecd897b919ed6d85c050feeb83e100e020df9153d2/attrs-21.4.0.tar.gz" + } + ], + "project_name": "attrs", + "requires_dists": [ + "cloudpickle; platform_python_implementation == \"CPython\" and extra == \"dev\"", + "cloudpickle; platform_python_implementation == \"CPython\" and extra == \"tests\"", + "cloudpickle; platform_python_implementation == \"CPython\" and extra == \"tests_no_zope\"", + "coverage[toml]>=5.0.2; extra == \"dev\"", + "coverage[toml]>=5.0.2; extra == \"tests\"", + "coverage[toml]>=5.0.2; extra == \"tests_no_zope\"", + "furo; extra == \"dev\"", + "furo; extra == \"docs\"", + "hypothesis; extra == \"dev\"", + "hypothesis; extra == \"tests\"", + "hypothesis; extra == \"tests_no_zope\"", + "mypy; extra == \"dev\"", + "mypy; extra == \"tests\"", + "mypy; extra == \"tests_no_zope\"", + "pre-commit; extra == \"dev\"", + "pympler; extra == \"dev\"", + "pympler; extra == \"tests\"", + "pympler; extra == \"tests_no_zope\"", + "pytest-mypy-plugins; extra == \"dev\"", + "pytest-mypy-plugins; extra == \"tests\"", + "pytest-mypy-plugins; extra == \"tests_no_zope\"", + "pytest>=4.3.0; extra == \"dev\"", + "pytest>=4.3.0; extra == \"tests\"", + "pytest>=4.3.0; extra == \"tests_no_zope\"", + "six; extra == \"dev\"", + "six; extra == \"tests\"", + "six; extra == \"tests_no_zope\"", + "sphinx-notfound-page; extra == \"dev\"", + "sphinx-notfound-page; extra == \"docs\"", + "sphinx; extra == \"dev\"", + "sphinx; extra == \"docs\"", + "zope.interface; extra == \"dev\"", + "zope.interface; extra == \"docs\"", + "zope.interface; extra == \"tests\"" + ], + "requires_python": "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7", + "version": "21.4" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": 
"5e069f72d497312b24fcc02073d70cb989045d1c91cbd53979366077959933e0", + "url": "https://files.pythonhosted.org/packages/c9/06/3dc78a8537fba6d442d45a2d9c0d71679d2bfc88e82452780874256cf883/cffi-1.15.0-cp310-cp310-win_amd64.whl" + }, + { + "algorithm": "sha256", + "hash": "920f0d66a896c2d99f0adbb391f990a84091179542c205fa53ce5787aff87954", + "url": "https://files.pythonhosted.org/packages/00/9e/92de7e1217ccc3d5f352ba21e52398372525765b2e0c4530e6eb2ba9282a/cffi-1.15.0.tar.gz" + }, + { + "algorithm": "sha256", + "hash": "c21c9e3896c23007803a875460fb786118f0cdd4434359577ea25eb556e34c55", + "url": "https://files.pythonhosted.org/packages/2a/fb/7f52b10940eb31b32410fe016cad5b379961be0eac1d40ba4afa1e7819ad/cffi-1.15.0-cp310-cp310-win32.whl" + }, + { + "algorithm": "sha256", + "hash": "abb9a20a72ac4e0fdb50dae135ba5e77880518e742077ced47eb1499e29a443c", + "url": "https://files.pythonhosted.org/packages/6d/cc/e45ad6277cd0675c4fbfc25cacc29f5760b9a4c150fd902b18abf85e4672/cffi-1.15.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" + }, + { + "algorithm": "sha256", + "hash": "91ec59c33514b7c7559a6acda53bbfe1b283949c34fe7440bcf917f96ac0723e", + "url": "https://files.pythonhosted.org/packages/7f/96/126f39cbb6d7c87cb7de2e5f74a2a707c282c357055be641d3ae6075d0a5/cffi-1.15.0-cp310-cp310-macosx_11_0_arm64.whl" + }, + { + "algorithm": "sha256", + "hash": "f54a64f8b0c8ff0b64d18aa76675262e1700f3995182267998c31ae974fbc382", + "url": "https://files.pythonhosted.org/packages/ac/40/9cf45d01320987075d3156e96741b5de2395005070b571c0c9498093b905/cffi-1.15.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl" + }, + { + "algorithm": "sha256", + "hash": "00c878c90cb53ccfaae6b8bc18ad05d2036553e6d9d1d9dbcf323bbe83854ca3", + "url": "https://files.pythonhosted.org/packages/ae/27/a99335833b6c4d356bdeaadd87d0e9e83969761513dba6dc2a8123d95ca1/cffi-1.15.0-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "a5263e363c27b653a90078143adb3d076c1a748ec9ecc78ea2fb916f9b861962", + "url": "https://files.pythonhosted.org/packages/bb/7d/8e2ef3d009d801e02e18fb995c06ad788b5ed42c534c9c737260d54ddec7/cffi-1.15.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl" + }, + { + "algorithm": "sha256", + "hash": "f5c7150ad32ba43a07c4479f40241756145a1f03b43480e058cfd862bf5041c7", + "url": "https://files.pythonhosted.org/packages/c3/54/4587212a3a2340d41a40a903c92ce3590f78ca75a56fd608e6889cba98d1/cffi-1.15.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl" + }, + { + "algorithm": "sha256", + "hash": "0104fb5ae2391d46a4cb082abdd5c69ea4eab79d8d44eaaf79f1b1fd806ee4c2", + "url": "https://files.pythonhosted.org/packages/f0/00/3003a6f8c20bc349cc7c307432dcbd3711135a1ce12b763e7b09726674f0/cffi-1.15.0-cp310-cp310-macosx_10_9_x86_64.whl" + } + ], + "project_name": "cffi", + "requires_dists": [ + "pycparser" + ], + "requires_python": null, + "version": "1.15" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "a719e6b0bed81a5d64fef0c186ee44b3077812b03e6f86deccab13e37ade33ec", + "url": "https://files.pythonhosted.org/packages/04/51/fc88d660980c7087310c65cbe3005a3eb0907e64eb9b217c2d44e91faf47/janus-0.6.2-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "127edc891f9e13420dd12f230d5113fa3de7f93662b81acfaf845989edf5eebf", + "url": "https://files.pythonhosted.org/packages/fc/8e/92fcb2ff18797959cde050b7c96a713999f73feefde809bfdf18b5901174/janus-0.6.2.tar.gz" + } + ], + "project_name": "janus", + "requires_dists": [], + "requires_python": ">=3.6", + "version": 
"0.6.2" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "e053a2c44b6fa597feebe2b3ecb5eea3e03d1d91cc94351a52931ee1426aecfc", + "url": "https://files.pythonhosted.org/packages/77/e8/c3cf72a32a697256608d5fa96360c431adec6e1c6709ba7f13f99ff5ee04/jupyter_client-6.1.12-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "c4bca1d0846186ca8be97f4d2fa6d2bae889cce4892a167ffa1ba6bd1f73e782", + "url": "https://files.pythonhosted.org/packages/de/05/6b1809dbe46e21c4018721c14a989a150ff73b4ecf631fe6e22d02cac579/jupyter_client-6.1.12.tar.gz" + } + ], + "project_name": "jupyter-client", + "requires_dists": [ + "async-generator; extra == \"test\"", + "ipykernel; extra == \"test\"", + "ipython; extra == \"test\"", + "jedi<0.18; python_version <= \"3.6\" and extra == \"test\"", + "jupyter-core>=4.6.0", + "mock; extra == \"test\"", + "pytest-asyncio; extra == \"test\"", + "pytest-timeout; extra == \"test\"", + "pytest; extra == \"test\"", + "python-dateutil>=2.1", + "pyzmq>=13", + "sphinx-rtd-theme; extra == \"doc\"", + "sphinx>=1.3.6; extra == \"doc\"", + "sphinxcontrib-github-alt; extra == \"doc\"", + "tornado>=4.1", + "traitlets" + ], + "requires_python": ">=3.5", + "version": "6.1.12" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "e7f5212177af7ab34179690140f188aa9bf3d322d8155ed972cbded19f55b6f3", + "url": "https://files.pythonhosted.org/packages/34/7d/8e442c0637a648c0136f686e015dc2f547f1a19f2690b183aa340a6762bc/jupyter_core-4.10.0-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "a6de44b16b7b31d7271130c71a6792c4040f077011961138afed5e5e73181aec", + "url": "https://files.pythonhosted.org/packages/91/5d/746dd5b904854043f99e72a22c69a2e9b3eb0ade2adc2b288e666ffa816f/jupyter_core-4.10.0.tar.gz" + } + ], + "project_name": "jupyter-core", + "requires_dists": [ + "ipykernel; extra == \"test\"", + "pre-commit; extra == \"test\"", + "pytest-cov; extra == \"test\"", + "pytest-timeout; extra == \"test\"", + "pytest; extra == \"test\"", + "pywin32>=1.0; sys_platform == \"win32\" and platform_python_implementation != \"PyPy\"", + "traitlets" + ], + "requires_python": ">=3.7", + "version": "4.10" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "c1ba333b4024c17c7591f0f372e2daa3c31db495a9b2af3cf664aef3c14354f7", + "url": "https://files.pythonhosted.org/packages/26/71/5fbd40e87fabaf6f60c2fa8934d93ec1df542b7f978a080ce99f6734934d/msgpack-1.0.3-cp310-cp310-win_amd64.whl" + }, + { + "algorithm": "sha256", + "hash": "36a64a10b16c2ab31dcd5f32d9787ed41fe68ab23dd66957ca2826c7f10d0b85", + "url": "https://files.pythonhosted.org/packages/06/e5/da31b9be6bed416c29906e0f9eff66af3e08f0b6e11caa7858d649e8ca1f/msgpack-1.0.3-cp310-cp310-win32.whl" + }, + { + "algorithm": "sha256", + "hash": "1c58cdec1cb5fcea8c2f1771d7b5fec79307d056874f746690bd2bdd609ab147", + "url": "https://files.pythonhosted.org/packages/1b/18/61b7462849c31fafd7c7d05a2ae896d495a1c1bf7f25788a4a8af9439153/msgpack-1.0.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "96acc674bb9c9be63fa8b6dabc3248fdc575c4adc005c440ad02f87ca7edd079", + "url": "https://files.pythonhosted.org/packages/4f/e9/837b5c2209d41ddaf99cc7247598191d6f9f776c017b95abb5ada761ef93/msgpack-1.0.3-cp310-cp310-macosx_10_9_universal2.whl" + }, + { + "algorithm": "sha256", + "hash": "51fdc7fb93615286428ee7758cecc2f374d5ff363bdd884c7ea622a7a327a81e", + "url": "https://files.pythonhosted.org/packages/61/3c/2206f39880d38ca7ad8ac1b28d2d5ca81632d163b2d68ef90e46409ca057/msgpack-1.0.3.tar.gz" + 
}, + { + "algorithm": "sha256", + "hash": "2f97c0f35b3b096a330bb4a1a9247d0bd7e1f3a2eba7ab69795501504b1c2c39", + "url": "https://files.pythonhosted.org/packages/a5/36/3734c798885a93c6e8fe4422184ad089c6e2e44c18d2b18f09cc029c02b8/msgpack-1.0.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl" + }, + { + "algorithm": "sha256", + "hash": "2c3ca57c96c8e69c1a0d2926a6acf2d9a522b41dc4253a8945c4c6cd4981a4e3", + "url": "https://files.pythonhosted.org/packages/b9/f4/4d2ee26409739c1a4b1dc3b8e4c50dedd9d5054d1ab5fec9830c42ebb3b6/msgpack-1.0.3-cp310-cp310-macosx_10_9_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "b0a792c091bac433dfe0a70ac17fc2087d4595ab835b47b89defc8bbabcf5c73", + "url": "https://files.pythonhosted.org/packages/bd/c5/e69b0e5f216191b09261957a75a78078aa2bc90a7138e5186eb641e45d9f/msgpack-1.0.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" + } + ], + "project_name": "msgpack", + "requires_dists": [], + "requires_python": null, + "version": "1.0.3" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378", + "url": "https://files.pythonhosted.org/packages/f6/f0/10642828a8dfb741e5f3fbaac830550a518a775c7fff6f04a007259b0548/py-1.11.0-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719", + "url": "https://files.pythonhosted.org/packages/98/ff/fec109ceb715d2a6b4c4a85a61af3b40c723a961e8828319fbcb15b868dc/py-1.11.0.tar.gz" + } + ], + "project_name": "py", + "requires_dists": [], + "requires_python": "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7", + "version": "1.11" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9", + "url": "https://files.pythonhosted.org/packages/62/d5/5f610ebe421e85889f2e55e33b7f9a6795bd982198517d912eb1c76e1a53/pycparser-2.21-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206", + "url": "https://files.pythonhosted.org/packages/5e/0b/95d387f5f4433cb0f53ff7ad859bd2c6051051cebbb564f139a999ab46de/pycparser-2.21.tar.gz" + } + ], + "project_name": "pycparser", + "requires_dists": [], + "requires_python": "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7", + "version": "2.21" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9", + "url": "https://files.pythonhosted.org/packages/36/7a/87837f39d0296e723bb9b62bbb257d0355c7f6128853c78955f57342a56d/python_dateutil-2.8.2-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86", + "url": "https://files.pythonhosted.org/packages/4c/c4/13b4776ea2d76c115c1d1b84579f3764ee6d57204f6be27119f13a61d0a9/python-dateutil-2.8.2.tar.gz" + } + ], + "project_name": "python-dateutil", + "requires_dists": [ + "six>=1.5" + ], + "requires_python": "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7", + "version": "2.8.2" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "d3ee45adff48e0551d1aa60d2ec066fec006083b791f5c3527c40cd8aefac71f", + "url": "https://files.pythonhosted.org/packages/75/91/fa2a9d3861184df4c2dc57c9a29e6e856f6bbe3702acccf169329f9b6eae/pywin32-304-cp310-cp310-win_arm64.whl" + }, + { + "algorithm": "sha256", + "hash": "3c7bacf5e24298c86314f03fa20e16558a4e4138fc34615d7de4070c23e65af3", + "url": 
"https://files.pythonhosted.org/packages/05/6b/9f8421a9a2ab5f33cbb9fd2f282ac971e584f6a83d44f6672bd17f1d68b2/pywin32-304-cp310-cp310-win32.whl" + }, + { + "algorithm": "sha256", + "hash": "4f32145913a2447736dad62495199a8e280a77a0ca662daa2332acf849f0be48", + "url": "https://files.pythonhosted.org/packages/14/07/9a2bd2cdcdeecd013ed83173209f1c984662ef05922ef6fe5f0fb9cc120e/pywin32-304-cp310-cp310-win_amd64.whl" + } + ], + "project_name": "pywin32", + "requires_dists": [], + "requires_python": null, + "version": "304" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "7661fc1d5cb73481cf710a1418a4e1e301ed7d5d924f91c67ba84b2a1b89defd", + "url": "https://files.pythonhosted.org/packages/c5/e1/76c0e9fb596f613ec1e52b00720da310b5c422b67db80647973ca603b415/pyzmq-22.3.0-cp310-cp310-win_amd64.whl" + }, + { + "algorithm": "sha256", + "hash": "2841997a0d85b998cbafecb4183caf51fd19c4357075dfd33eb7efea57e4c149", + "url": "https://files.pythonhosted.org/packages/2a/d6/e76c2740943bc2c12b4989f8bea2c036298b0255005e11e2da34a15df459/pyzmq-22.3.0-cp310-cp310-macosx_10_9_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "f89468059ebc519a7acde1ee50b779019535db8dcf9b8c162ef669257fef7a93", + "url": "https://files.pythonhosted.org/packages/2d/20/71b3770dc9165a40cc2cad54d6a698593f942c8d41463c5382949f4f8bc0/pyzmq-22.3.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl" + }, + { + "algorithm": "sha256", + "hash": "76c532fd68b93998aab92356be280deec5de8f8fe59cd28763d2cc8a58747b7f", + "url": "https://files.pythonhosted.org/packages/35/84/1cd4efede2c34d0d36642c071eee50c2baf5574ab245c4fcd805a3a006f1/pyzmq-22.3.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" + }, + { + "algorithm": "sha256", + "hash": "6b217b8f9dfb6628f74b94bdaf9f7408708cb02167d644edca33f38746ca12dd", + "url": "https://files.pythonhosted.org/packages/56/29/4ebcdd3956ba019b84e3671c5d8fcadc4d8e33651801acd1f43583077060/pyzmq-22.3.0-cp310-cp310-macosx_10_15_universal2.whl" + }, + { + "algorithm": "sha256", + "hash": "902319cfe23366595d3fa769b5b751e6ee6750a0a64c5d9f757d624b2ac3519e", + "url": "https://files.pythonhosted.org/packages/5b/5b/efd46c447d2cf3abadf1ee0bf891ac7677bbb4231166b809eff7ab7d8bdc/pyzmq-22.3.0-cp310-cp310-musllinux_1_1_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "8eddc033e716f8c91c6a2112f0a8ebc5e00532b4a6ae1eb0ccc48e027f9c671c", + "url": "https://files.pythonhosted.org/packages/6c/95/d37e7db364d7f569e71068882b1848800f221c58026670e93a4c6d50efe7/pyzmq-22.3.0.tar.gz" + }, + { + "algorithm": "sha256", + "hash": "67db33bea0a29d03e6eeec55a8190e033318cee3cbc732ba8fd939617cbf762d", + "url": "https://files.pythonhosted.org/packages/bb/e5/86e2e22f513c1fa388cd9d2fc2efead5b4677f934e66a5fbc95c34348ba3/pyzmq-22.3.0-cp310-cp310-win32.whl" + }, + { + "algorithm": "sha256", + "hash": "ea12133df25e3a6918718fbb9a510c6ee5d3fdd5a346320421aac3882f4feeea", + "url": "https://files.pythonhosted.org/packages/e6/3d/f0a39a4e94b4db31561a3431d035ddf53b52fc8c0dc09cfc704b24d58f0f/pyzmq-22.3.0-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "f907c7359ce8bf7f7e63c82f75ad0223384105f5126f313400b7e8004d9b33c3", + "url": "https://files.pythonhosted.org/packages/f4/a4/877dd3866c1b76f30da51116e45f3702eb2ab45f8b2638919e8d9237fdb6/pyzmq-22.3.0-cp310-cp310-musllinux_1_1_i686.whl" + } + ], + "project_name": "pyzmq", + "requires_dists": [ + "cffi; implementation_name == \"pypy\"", + "py; implementation_name == \"pypy\"" + ], + "requires_python": ">=3.6", + "version": "22.3" + }, + { + 
"artifacts": [ + { + "algorithm": "sha256", + "hash": "8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254", + "url": "https://files.pythonhosted.org/packages/d9/5a/e7c31adbe875f2abbb91bd84cf2dc52d792b5a01506781dbcf25c91daf11/six-1.16.0-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926", + "url": "https://files.pythonhosted.org/packages/71/39/171f1c67cd00715f190ba0b100d606d440a28c93c7714febeca8b79af85e/six-1.16.0.tar.gz" + } + ], + "project_name": "six", + "requires_dists": [], + "requires_python": "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7", + "version": "1.16" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "33c6e81d7bd55b468d2e793517c909b139960b6c790a60b7991b9b6b76fb9791", + "url": "https://files.pythonhosted.org/packages/cf/44/cc9590db23758ee7906d40cacff06c02a21c2a6166602e095a56cbf2f6f6/tornado-6.1.tar.gz" + } + ], + "project_name": "tornado", + "requires_dists": [], + "requires_python": ">=3.5", + "version": "6.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "f44b708d33d98b0addb40c29d148a761f44af740603a8fd0e2f8b5b27cf0f087", + "url": "https://files.pythonhosted.org/packages/84/c5/6a23a8f6acc43150fdc6cfb3bda1cad1f0dbaec0e4f75df77d9c8a62320d/traitlets-5.2.1.post0-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "70815ecb20ec619d1af28910ade523383be13754283aef90528eb3d47b77c5db", + "url": "https://files.pythonhosted.org/packages/9b/26/5e5f9002f939d54663d244a260d0453b2baf4f767697da5968aa474f04e7/traitlets-5.2.1.post0.tar.gz" + } + ], + "project_name": "traitlets", + "requires_dists": [ + "pre-commit; extra == \"test\"", + "pytest; extra == \"test\"" + ], + "requires_python": ">=3.7", + "version": "5.2.1.post0" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "772206116b9b57cd625c8a88f2413df2fcfd0b496eb188b82a43bed7af2c2ec9", + "url": "https://files.pythonhosted.org/packages/12/1c/4c270b22f68a75bedf795aadc40370c4ff9e910a5e1aff327c24aaae6a99/uvloop-0.16.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" + }, + { + "algorithm": "sha256", + "hash": "6224f1401025b748ffecb7a6e2652b17768f30b1a6a3f7b44660e5b5b690b12d", + "url": "https://files.pythonhosted.org/packages/2a/07/75074f9789d5f8811bc77230a84ddbb7586e555e84f59d75d2968ef5c4a0/uvloop-0.16.0-cp310-cp310-macosx_10_9_universal2.whl" + }, + { + "algorithm": "sha256", + "hash": "f74bc20c7b67d1c27c72601c78cf95be99d5c2cdd4514502b4f3eb0933ff1228", + "url": "https://files.pythonhosted.org/packages/ab/d9/22bbffa8f8d7e075ccdb29e8134107adfb4710feb10039f9d357db8b589c/uvloop-0.16.0.tar.gz" + }, + { + "algorithm": "sha256", + "hash": "bd53f7f5db562f37cd64a3af5012df8cac2c464c97e732ed556800129505bd64", + "url": "https://files.pythonhosted.org/packages/b9/00/14dffb56943092c2b5821d288dc23ff36dff9ad3b8aad3547c71b171cf3b/uvloop-0.16.0-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "30ba9dcbd0965f5c812b7c2112a1ddf60cf904c1c160f398e7eed3a6b82dcd9c", + "url": "https://files.pythonhosted.org/packages/da/ea/56fecce56844e308b6ff3c5b55a372a6a4da2c6fe8ee35a272459f534d9c/uvloop-0.16.0-cp310-cp310-macosx_10_9_x86_64.whl" + } + ], + "project_name": "uvloop", + "requires_dists": [ + "Cython<0.30.0,>=0.29.24; extra == \"dev\"", + "Sphinx~=4.1.2; extra == \"dev\"", + "Sphinx~=4.1.2; extra == \"docs\"", + "aiohttp; extra == \"dev\"", + "aiohttp; extra == \"test\"", + "flake8~=3.9.2; extra == \"dev\"", + "flake8~=3.9.2; extra == \"test\"", + "mypy>=0.800; extra 
== \"dev\"", + "mypy>=0.800; extra == \"test\"", + "psutil; extra == \"dev\"", + "psutil; extra == \"test\"", + "pyOpenSSL~=19.0.0; extra == \"dev\"", + "pyOpenSSL~=19.0.0; extra == \"test\"", + "pycodestyle~=2.7.0; extra == \"dev\"", + "pycodestyle~=2.7.0; extra == \"test\"", + "pytest>=3.6.0; extra == \"dev\"", + "sphinx-rtd-theme~=0.5.2; extra == \"dev\"", + "sphinx-rtd-theme~=0.5.2; extra == \"docs\"", + "sphinxcontrib-asyncio~=0.3.0; extra == \"dev\"", + "sphinxcontrib-asyncio~=0.3.0; extra == \"docs\"" + ], + "requires_python": ">=3.7", + "version": "0.16" + } + ], + "platform_tag": [ + "cp310", + "cp310", + "manylinux_2_31_aarch64" + ] + } + ], + "path_mappings": {}, + "pex_version": "2.1.84", + "prefer_older_binary": false, + "requirements": [ + "async_timeout~=3.0", + "attrs~=21.2", + "janus~=0.6.1", + "jupyter-client~=6.1", + "msgpack~=1.0", + "pyzmq~=22.2", + "uvloop~=0.16" + ], + "requires_python": [ + "==3.10.4" + ], + "resolver_version": "pip-2020-resolver", + "style": "universal", + "transitive": true, + "use_pep517": null +} \ No newline at end of file diff --git a/python.lock b/python.lock new file mode 100644 index 0000000000..aa7dbea32d --- /dev/null +++ b/python.lock @@ -0,0 +1,3915 @@ +// This lockfile was autogenerated by Pants. To regenerate, run: +// +// ./pants generate-lockfiles --resolve=python-default +// +// --- BEGIN PANTS LOCKFILE METADATA: DO NOT EDIT OR REMOVE --- +// { +// "version": 2, +// "valid_for_interpreter_constraints": [ +// "CPython==3.10.4" +// ], +// "generated_with_requirements": [ +// "Jinja2~=3.0.1", +// "PyJWT~=2.0", +// "PyYAML~=5.4.1", +// "SQLAlchemy[postgresql_asyncpg]~=1.4.29", +// "aiodataloader-ng~=0.2.1", +// "aiodns>=3.0", +// "aiodocker~=0.21.0", +// "aiofiles~=0.8.0", +// "aiohttp_cors~=0.7", +// "aiohttp_session[aioredis]~=2.11", +// "aiohttp_sse>=2.0", +// "aiohttp~=3.8.1", +// "aiomonitor~=0.4.5", +// "aioredis[hiredis]~=2.0.1", +// "aiosqlite~=0.17.0", +// "aiotools~=1.5.9", +// "aiotusclient~=0.1.4", +// "alembic~=1.7.7", +// "appdirs~=1.4.4", +// "async_timeout~=4.0", +// "asyncudp>=0.4", +// "attrs>=20.3", +// "backend.ai-krunner-alpine~=3.3", +// "backend.ai-krunner-static-gnu~=2.0", +// "cachetools~=4.1.1", +// "callosum~=0.9.10", +// "click>=7.1.2", +// "colorama>=0.4.4", +// "coloredlogs~=15.0", +// "cryptography>=2.8", +// "etcetra~=0.1.6", +// "graphene~=2.1.9", +// "humanize>=3.1.0", +// "janus>=0.6.1", +// "kubernetes-asyncio~=9.1.0", +// "kubernetes~=10.0.0", +// "lark-parser~=0.11.3", +// "more-itertools~=8.12.0", +// "msgpack>=1.0.0", +// "netifaces~=0.11.0", +// "packaging>=21.3", +// "passlib[bcrypt]>=1.7.4", +// "pexpect~=4.8", +// "psutil~=5.8.0", +// "psycopg2-binary>=2.8.4", +// "pytest-dependency>=0.5.1", +// "pytest~=7.1", +// "python-dateutil>=2.8", +// "python-json-logger>=2.0.1", +// "python-snappy~=0.6.0", +// "pyzmq~=22.1.0", +// "redis[hiredis]~=4.3.1", +// "rich~=12.2", +// "setproctitle~=1.2.2", +// "tabulate~=0.8.9", +// "tblib~=1.7", +// "tenacity>=8.0", +// "toml>=0.10.2; python_version <= \"3.11\"", +// "tomlkit~=0.8.0", +// "tqdm>=4.61", +// "trafaret~=2.1", +// "typeguard~=2.10", +// "types-Jinja2", +// "types-PyYAML", +// "types-aiofiles", +// "types-cachetools", +// "types-click", +// "types-python-dateutil", +// "types-setuptools", +// "types-six", +// "types-tabulate", +// "types-toml", +// "typing_extensions~=4.1.1", +// "uvloop>=0.16", +// "yarl>=1.7", +// "zipstream-new~=1.1.8" +// ] +// } +// --- END PANTS LOCKFILE METADATA --- + +{ + "allow_builds": true, + "allow_prereleases": 
false, + "allow_wheels": true, + "build_isolation": true, + "constraints": [], + "locked_resolves": [ + { + "locked_requirements": [ + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "0c0acb66b4e72d6606c9fa14ac9eb001a222c37c885b8fbdf65f41824cfa855f", + "url": "https://files.pythonhosted.org/packages/c5/16/a9de89d859eeaa0f9fdf5675a97ea1416265e1ebc039c7b9ea2c9e5433ac/aioconsole-0.4.1.tar.gz" + } + ], + "project_name": "aioconsole", + "requires_dists": [], + "requires_python": ">=3.7", + "version": "0.4.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "c705656ae8cab12f8d313ddac4d68d6036bafe403f172c9687ece027cde21acc", + "url": "https://files.pythonhosted.org/packages/d4/f7/9b837b7893d2db59c21df28434f31637a37884ce6b9e29b92d4dd1e68435/aiodataloader_ng-0.2.1-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "f655efa53d72f9887617443161e7920eaa10b79ba6710326223bf9de395515ec", + "url": "https://files.pythonhosted.org/packages/e0/09/9d11e527dcd81e770526a32c95e207fa6ef0aeb92de6b31b418573ebfd0b/aiodataloader-ng-0.2.1.tar.gz" + } + ], + "project_name": "aiodataloader-ng", + "requires_dists": [ + "coveralls; extra == \"test\"", + "flake8>=3.9.0; extra == \"lint\"", + "mock; extra == \"test\"", + "mypy>=0.930; extra == \"typecheck\"", + "pytest-asyncio; extra == \"test\"", + "pytest-cov; extra == \"test\"", + "pytest>=6.2.5; extra == \"test\"", + "twine>=3.7.1; extra == \"build\"", + "wheel>=0.37.1; extra == \"build\"" + ], + "requires_python": null, + "version": "0.2.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "2b19bc5f97e5c936638d28e665923c093d8af2bf3aa88d35c43417fa25d136a2", + "url": "https://files.pythonhosted.org/packages/ab/72/991ee33a517df69c6cd6f3486cfe9b6329557cb55acaa8cefac33c2aa4d2/aiodns-3.0.0-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "946bdfabe743fceeeb093c8a010f5d1645f708a241be849e17edfb0e49e08cd6", + "url": "https://files.pythonhosted.org/packages/27/79/df72e25df0fdd9bf5a5ab068539731d27c5f2ae5654621ae0c92ceca94cf/aiodns-3.0.0.tar.gz" + } + ], + "project_name": "aiodns", + "requires_dists": [ + "pycares>=4.0.0" + ], + "requires_python": null, + "version": "3" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "6fe00135bb7dc40a407669d3157ecdfd856f3737d939df54f40a479d40cf7bdc", + "url": "https://files.pythonhosted.org/packages/7e/86/97638ef9d0e54a86d389ded8ccf27cc1ecabf7ce27ae873636a5c1e46d89/aiodocker-0.21.0-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "1f2e6db6377195962bb676d4822f6e3a0c525e1b5d60b8ebbab68230bff3d227", + "url": "https://files.pythonhosted.org/packages/6f/f5/5fb3a17fcdd31d3cce9afa82c306da869e2b36c5ca1477224396e5e1f31b/aiodocker-0.21.0.tar.gz" + } + ], + "project_name": "aiodocker", + "requires_dists": [ + "aiohttp>=3.6", + "typing-extensions>=3.6.5" + ], + "requires_python": ">=3.6", + "version": "0.21" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "7a973fc22b29e9962d0897805ace5856e6a566ab1f0c8e5c91ff6c866519c937", + "url": "https://files.pythonhosted.org/packages/ca/e4/b78d049f7cc7ed053ddbfdd59b2dcc7bd387458e2c2869b602975685d65e/aiofiles-0.8.0-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "8334f23235248a3b2e83b2c3a78a22674f39969b96397126cc93664d9a901e59", + "url": "https://files.pythonhosted.org/packages/10/ca/c416cfacf6a47e1400dad56eab85aa86c92c6fbe58447d12035e434f0d5c/aiofiles-0.8.0.tar.gz" + } + ], + "project_name": "aiofiles", + "requires_dists": [], + "requires_python": "<4.0,>=3.6", + "version": "0.8" + }, + { + 
"artifacts": [ + { + "algorithm": "sha256", + "hash": "713ac174a629d39b7c6a3aa757b337599798da4c1157114a314e4e391cd28e32", + "url": "https://files.pythonhosted.org/packages/2e/4f/119a8efad036d1f766ad736864a6dbfc8db9596e74ce9820f8c1282a240b/aiohttp-3.8.1-cp310-cp310-win_amd64.whl" + }, + { + "algorithm": "sha256", + "hash": "12de6add4038df8f72fac606dff775791a60f113a725c960f2bab01d8b8e6b15", + "url": "https://files.pythonhosted.org/packages/48/08/c3efb449dea5f38292804e4fbf8eaef1b3f168535a4163cc3fce3f9b4915/aiohttp-3.8.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" + }, + { + "algorithm": "sha256", + "hash": "97ef77eb6b044134c0b3a96e16abcb05ecce892965a2124c566af0fd60f717e2", + "url": "https://files.pythonhosted.org/packages/4f/c6/a8ce9fc6bbf9c0dbdaa631bcb8f9da5b532fd22ead50ef7390976fc9bf0d/aiohttp-3.8.1-cp310-cp310-musllinux_1_1_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "fc5471e1a54de15ef71c1bc6ebe80d4dc681ea600e68bfd1cbce40427f0b7578", + "url": "https://files.pythonhosted.org/packages/5a/86/5f63de7a202550269a617a5d57859a2961f3396ecd1739a70b92224766bc/aiohttp-3.8.1.tar.gz" + }, + { + "algorithm": "sha256", + "hash": "c2aef4703f1f2ddc6df17519885dbfa3514929149d3ff900b73f45998f2532fa", + "url": "https://files.pythonhosted.org/packages/75/86/c55c7b6b9d0d9e25b1d721e204424f154bd72bb172d2056f0f9f06c50254/aiohttp-3.8.1-cp310-cp310-win32.whl" + }, + { + "algorithm": "sha256", + "hash": "61bfc23df345d8c9716d03717c2ed5e27374e0fe6f659ea64edcd27b4b044cf7", + "url": "https://files.pythonhosted.org/packages/76/3d/8f64ed6d429f9feeefc52b551f4ba5554d2f7a6f46d92c080f4ae48e0478/aiohttp-3.8.1-cp310-cp310-musllinux_1_1_i686.whl" + }, + { + "algorithm": "sha256", + "hash": "7dadf3c307b31e0e61689cbf9e06be7a867c563d5a63ce9dca578f956609abf8", + "url": "https://files.pythonhosted.org/packages/7e/9f/3cd2502f3cab61eccd7c20f5ab67447cf891ad8613282141955df1b7fb98/aiohttp-3.8.1-cp310-cp310-macosx_10_9_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "099ebd2c37ac74cce10a3527d2b49af80243e2a4fa39e7bce41617fbc35fa3c1", + "url": "https://files.pythonhosted.org/packages/80/a3/9403173d3a6ba5893a4e0a1816b211da7ba0cb7c00c9ac0279ec2dbbf576/aiohttp-3.8.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl" + }, + { + "algorithm": "sha256", + "hash": "eaba923151d9deea315be1f3e2b31cc39a6d1d2f682f942905951f4e40200922", + "url": "https://files.pythonhosted.org/packages/85/e6/d52a342bf22b5b5c759a94af340836490bcbffd288d4a65494234d8298f7/aiohttp-3.8.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl" + }, + { + "algorithm": "sha256", + "hash": "a79004bb58748f31ae1cbe9fa891054baaa46fb106c2dc7af9f8e3304dc30316", + "url": "https://files.pythonhosted.org/packages/a6/7f/4c202b0fd3c33029e45bb0d06eaac2886be4427763cc9589774fb39b5da7/aiohttp-3.8.1-cp310-cp310-macosx_11_0_arm64.whl" + }, + { + "algorithm": "sha256", + "hash": "6f0d5f33feb5f69ddd57a4a4bd3d56c719a141080b445cbf18f238973c5c9923", + "url": "https://files.pythonhosted.org/packages/b1/bd/e412cb6cd12b7a86966239a97ed0391e1ad5ac6f8a749caddc49e18264ec/aiohttp-3.8.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl" + }, + { + "algorithm": "sha256", + "hash": "fa0ffcace9b3aa34d205d8130f7873fcfefcb6a4dd3dd705b0dab69af6712642", + "url": "https://files.pythonhosted.org/packages/c0/6d/f5423a7c899c538e2cff2e713f9eb2c51b02fad909ec8e8b1c3ed713049a/aiohttp-3.8.1-cp310-cp310-musllinux_1_1_aarch64.whl" + }, + { + "algorithm": "sha256", + "hash": "01d7bdb774a9acc838e6b8f1d114f45303841b89b95984cbb7d80ea41172a9e3", + 
"url": "https://files.pythonhosted.org/packages/cc/28/c95a0694da3082cb76808799017b02db6c10ec8687ee1ac5edad091ab070/aiohttp-3.8.1-cp310-cp310-musllinux_1_1_s390x.whl" + }, + { + "algorithm": "sha256", + "hash": "1ed0b6477896559f17b9eaeb6d38e07f7f9ffe40b9f0f9627ae8b9926ae260a8", + "url": "https://files.pythonhosted.org/packages/e3/3a/720635a98bb0eef9179d12ee3ccca659d1fcccfbafaacdf42ed5536a0861/aiohttp-3.8.1-cp310-cp310-macosx_10_9_universal2.whl" + }, + { + "algorithm": "sha256", + "hash": "31560d268ff62143e92423ef183680b9829b1b482c011713ae941997921eebc8", + "url": "https://files.pythonhosted.org/packages/f3/0d/a035862f8a11b6cba4220b0c1201443fa6f5151137889e2dfe1cc983e58e/aiohttp-3.8.1-cp310-cp310-musllinux_1_1_ppc64le.whl" + }, + { + "algorithm": "sha256", + "hash": "2e5d962cf7e1d426aa0e528a7e198658cdc8aa4fe87f781d039ad75dcd52c516", + "url": "https://files.pythonhosted.org/packages/f4/2d/07e3ba718571e79509f88a791611a3e156e8915ed9a19116547806bce8fa/aiohttp-3.8.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl" + } + ], + "project_name": "aiohttp", + "requires_dists": [ + "Brotli; extra == \"speedups\"", + "aiodns; extra == \"speedups\"", + "aiosignal>=1.1.2", + "async-timeout<5.0,>=4.0.0a3", + "asynctest==0.13.0; python_version < \"3.8\"", + "attrs>=17.3.0", + "cchardet; extra == \"speedups\"", + "charset-normalizer<3.0,>=2.0", + "frozenlist>=1.1.1", + "idna-ssl>=1.0; python_version < \"3.7\"", + "multidict<7.0,>=4.5", + "typing-extensions>=3.7.4; python_version < \"3.8\"", + "yarl<2.0,>=1.0" + ], + "requires_python": ">=3.6", + "version": "3.8.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "0451ba59fdf6909d0e2cd21e4c0a43752bc0703d33fc78ae94d9d9321710193e", + "url": "https://files.pythonhosted.org/packages/13/e7/e436a0c0eb5127d8b491a9b83ecd2391c6ff7dcd5548dfaec2080a2340fd/aiohttp_cors-0.7.0-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "4d39c6d7100fd9764ed1caf8cebf0eb01bf5e3f24e2e073fda6234bc48b19f5d", + "url": "https://files.pythonhosted.org/packages/44/9e/6cdce7c3f346d8fd487adf68761728ad8cd5fbc296a7b07b92518350d31f/aiohttp-cors-0.7.0.tar.gz" + } + ], + "project_name": "aiohttp-cors", + "requires_dists": [ + "aiohttp>=1.1", + "typing; python_version < \"3.5\"" + ], + "requires_python": null, + "version": "0.7" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "d5451b72ba748d72ef01f55c3af45b63be2d40b2bdadf47407007e1270343384", + "url": "https://files.pythonhosted.org/packages/13/e3/affe80cf6d906f5a8ca7d084b780c340d90f7649c0683ed1e65dd88c5bdf/aiohttp_session-2.11.0-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "0fdf39600f6a05c4943ef6c7b099071ca9854413111a26761626244be5015dc4", + "url": "https://files.pythonhosted.org/packages/3b/a7/0b97b9a2e3a553a86a6703f86b0e9b1afb2b262849700e8f80015c0f643f/aiohttp-session-2.11.0.tar.gz" + } + ], + "project_name": "aiohttp-session", + "requires_dists": [ + "aiohttp>=3.8", + "aiomcache>=0.5.2; extra == \"aiomcache\"", + "aioredis>=2.0.0; extra == \"aioredis\"", + "cryptography; extra == \"pycrypto\"", + "cryptography; extra == \"secure\"", + "pynacl; extra == \"pynacl\"", + "typing-extensions>=3.7.4; python_version < \"3.8\"" + ], + "requires_python": ">=3.7", + "version": "2.11" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "b31cbf99142ca0c93cfc180e3f32fcdca0bcab35367e1380bc70dfaaf88cdc60", + "url": 
"https://files.pythonhosted.org/packages/fe/0c/e1db340ac8d75af3e51236696077e399cc3dfe8ea25d2a524066bff196a9/aiohttp_sse-2.1.0-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "dfe8f7271ab4470891fa1bfa1913d6889b3d19015dd3d3a4cab949e66971bbca", + "url": "https://files.pythonhosted.org/packages/2f/3f/cc4f5a3fe6cb50ad5b9d26bb7738c5da1f61645b517d4230df2fc32d89f0/aiohttp-sse-2.1.0.tar.gz" + } + ], + "project_name": "aiohttp-sse", + "requires_dists": [ + "aiohttp>=3.0" + ], + "requires_python": ">=3.7", + "version": "2.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "5c7ac38b2ee59cbad87162ef5c45d72c7b57b94c8a93e7f462184bf10f9ebccd", + "url": "https://files.pythonhosted.org/packages/94/6d/95eb718f545c6d09aa19fa08b28b616203a77a68755b09c6f6b1202773b2/aiomonitor-0.4.5-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "6232c1ab14bf06cd7217845801c27340032f74e283bdaf32d01cdd3b7c673d0e", + "url": "https://files.pythonhosted.org/packages/98/76/b62e9fbe267287527fb6f4b6774394d4f00650195774173bb0055a99ab3d/aiomonitor-0.4.5.tar.gz" + } + ], + "project_name": "aiomonitor", + "requires_dists": [ + "aioconsole", + "terminaltables" + ], + "requires_python": null, + "version": "0.4.5" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "9ac0d0b3b485d293b8ca1987e6de8658d7dafcca1cddfcd1d506cae8cdebfdd6", + "url": "https://files.pythonhosted.org/packages/9b/a9/0da089c3ae7a31cbcd2dcf0214f6f571e1295d292b6139e2bac68ec081d0/aioredis-2.0.1-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "eaa51aaf993f2d71f54b70527c440437ba65340588afeb786cd87c55c89cd98e", + "url": "https://files.pythonhosted.org/packages/2e/cf/9eb144a0b05809ffc5d29045c4b51039000ea275bc1268d0351c9e7dfc06/aioredis-2.0.1.tar.gz" + } + ], + "project_name": "aioredis", + "requires_dists": [ + "async-timeout", + "hiredis>=1.0; implementation_name == \"cpython\" and extra == \"hiredis\"", + "typing-extensions" + ], + "requires_python": ">=3.6", + "version": "2.0.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "26e62109036cd181df6e6ad646f91f0dcfd05fe16d0cb924138ff2ab75d64e3a", + "url": "https://files.pythonhosted.org/packages/3b/87/fe94898f2d44a93a35d5aa74671ed28094d80753a1113d68b799fab6dc22/aiosignal-1.2.0-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "78ed67db6c7b7ced4f98e495e572106d5c432a93e1ddd1bf475e1dc05f5b7df2", + "url": "https://files.pythonhosted.org/packages/27/6b/a89fbcfae70cf53f066ec22591938296889d3cc58fec1e1c393b10e8d71d/aiosignal-1.2.0.tar.gz" + } + ], + "project_name": "aiosignal", + "requires_dists": [ + "frozenlist>=1.1.0" + ], + "requires_python": ">=3.6", + "version": "1.2" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "6c49dc6d3405929b1d08eeccc72306d3677503cc5e5e43771efc1e00232e8231", + "url": "https://files.pythonhosted.org/packages/a0/48/77c0092f716c4bf9460dca44f5120f70b8f71f14a12f40d22551a7152719/aiosqlite-0.17.0-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "f0e6acc24bc4864149267ac82fb46dfb3be4455f99fe21df82609cc6e6baee51", + "url": "https://files.pythonhosted.org/packages/40/e0/ad1edd74311831ca71b32a5b83352b490d78d11a90a1cde04e1b6830e018/aiosqlite-0.17.0.tar.gz" + } + ], + "project_name": "aiosqlite", + "requires_dists": [ + "typing_extensions>=3.7.2" + ], + "requires_python": ">=3.6", + "version": "0.17" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "3a9946f047cfb33e1fca0d39845c3d6d10aeda69a6602c06a8f7dcb67c12171d", + "url": 
"https://files.pythonhosted.org/packages/e0/6b/8b7b1c5dfcb1ab49ac0baa0e59f1edb437a05b3fe618b5c5ec931e15ec56/aiotools-1.5.9-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "44e7b5f50cef692e177ad37231f0d6fa058465ea93c8759da7ee06a2b5cfac97", + "url": "https://files.pythonhosted.org/packages/83/b9/915ad8fa49f8834841bd7ca387db3c374ba8d953e949ef2f185e86ffd652/aiotools-1.5.9.tar.gz" + } + ], + "project_name": "aiotools", + "requires_dists": [ + "codecov; extra == \"test\"", + "flake8>=4.0.1; extra == \"lint\"", + "mypy>=0.920; extra == \"typecheck\"", + "pytest-asyncio~=0.16.0; extra == \"test\"", + "pytest-cov; extra == \"test\"", + "pytest-mock; extra == \"test\"", + "pytest~=6.2.5; extra == \"test\"", + "setuptools>=51.2.0; extra == \"build\"", + "sphinx-rtd-theme~=1.0; extra == \"docs\"", + "sphinx~=4.3; extra == \"docs\"", + "towncrier~=21.3; extra == \"build\"", + "twine>=3.8.0; extra == \"build\"", + "typing-extensions~=3.7; python_version < \"3.8\"", + "wheel>=0.37.0; extra == \"build\"" + ], + "requires_python": ">=3.6", + "version": "1.5.9" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "69dad8a1bfb66c811a537014bda44e06188621df36e75566b97c4aaf373a389b", + "url": "https://files.pythonhosted.org/packages/7b/d2/004c9f673df74ff0455f381e898ad522b5fbc05a2b7b99420f1998702a58/aiotusclient-0.1.4-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "4e23704420b5273b0c417c92a8bae605f90fc8f9364caa91a80f2586c7d8f952", + "url": "https://files.pythonhosted.org/packages/7c/9e/1a2fa86e00168a035ce532b1e595298b397eed6de9e757470f1e2bf53b8d/aiotusclient-0.1.4.tar.gz" + } + ], + "project_name": "aiotusclient", + "requires_dists": [ + "aiohttp>=3.6.2", + "codecov>=2.1.8; extra == \"test\"", + "flake8>=3.8.3; extra == \"lint\"", + "mypy>=0.782; extra == \"typecheck\"", + "pytest-asyncio>=0.14.0; extra == \"test\"", + "pytest-cov>=2.10.0; extra == \"test\"", + "pytest-mock>=3.2.0; extra == \"test\"", + "pytest>=6.0.1; extra == \"test\"", + "tqdm>=4.42", + "twine>=3.2.0; extra == \"build\"", + "wheel>=0.34.2; extra == \"build\"" + ], + "requires_python": ">=3.7", + "version": "0.1.4" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "29be0856ec7591c39f4e1cb10f198045d890e6e2274cf8da80cb5e721a09642b", + "url": "https://files.pythonhosted.org/packages/b3/e2/8d48220731b7279911c43e95cd182961a703b939de6822b00de3ea0d3159/alembic-1.7.7-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "4961248173ead7ce8a21efb3de378f13b8398e6630fab0eb258dc74a8af24c58", + "url": "https://files.pythonhosted.org/packages/30/b9/5526b43a4c54d177ab14af0af4b5c31d73db33d1ad3e30976d3b023e0594/alembic-1.7.7.tar.gz" + } + ], + "project_name": "alembic", + "requires_dists": [ + "Mako", + "SQLAlchemy>=1.3.0", + "importlib-metadata; python_version < \"3.9\"", + "importlib-resources; python_version < \"3.9\"", + "python-dateutil; extra == \"tz\"" + ], + "requires_python": ">=3.6", + "version": "1.7.7" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "d10a4bf949f619f719b227ef5386e31f49a2b6d453004b21f02661ccc8670c7b", + "url": "https://files.pythonhosted.org/packages/45/a4/b4fcadbdab46c2ec2d2f6f8b4ab3f64fd0040789ac7f065eba82119cd602/aniso8601-7.0.0-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "513d2b6637b7853806ae79ffaca6f3e8754bdd547048f5ccc1420aec4b714f1e", + "url": "https://files.pythonhosted.org/packages/7f/39/0da0982a3a42fd896beaa07425692fb3100a9d0e40723783efc20f1dec7c/aniso8601-7.0.0.tar.gz" + } + ], + "project_name": "aniso8601", + 
"requires_dists": [], + "requires_python": null, + "version": "7" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "a841dacd6b99318a741b166adb07e19ee71a274450e68237b4650ca1055ab128", + "url": "https://files.pythonhosted.org/packages/3b/00/2344469e2084fb287c2e0b57b72910309874c3245463acd6cf5e3db69324/appdirs-1.4.4-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "7d5d0167b2b1ba821647616af46a749d1c653740dd0d2415100fe26e27afdf41", + "url": "https://files.pythonhosted.org/packages/d7/d8/05696357e0311f5b5c316d7b95f46c669dd9c15aaeecbb48c7d0aeb88c40/appdirs-1.4.4.tar.gz" + } + ], + "project_name": "appdirs", + "requires_dists": [], + "requires_python": null, + "version": "1.4.4" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "8ca1e4fcf50d07413d66d1a5e416e42cfdf5851c981d679a09851a6853383b3c", + "url": "https://files.pythonhosted.org/packages/d6/c1/8991e7c5385b897b8c020cdaad718c5b087a6626d1d11a23e1ea87e325a7/async_timeout-4.0.2-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "2163e1640ddb52b7a8c80d0a67a08587e5d245cc9c553a74a847056bc2976b15", + "url": "https://files.pythonhosted.org/packages/54/6e/9678f7b2993537452710ffb1750c62d2c26df438aa621ad5fa9d1507a43a/async-timeout-4.0.2.tar.gz" + } + ], + "project_name": "async-timeout", + "requires_dists": [ + "typing-extensions>=3.6.5; python_version < \"3.8\"" + ], + "requires_python": ">=3.6", + "version": "4.0.2" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "56d88d7ef4341412cd9c68efba323a4519c916979ba91b95d4c08799d2ff0c09", + "url": "https://files.pythonhosted.org/packages/a2/5e/403d7e4e206b9ca031d47875b6d88ae8442ce1c90bf6a55c6b0cf3f4ddfc/asyncpg-0.25.0-cp310-cp310-win_amd64.whl" + }, + { + "algorithm": "sha256", + "hash": "43cde84e996a3afe75f325a68300093425c2f47d340c0fc8912765cf24a1c095", + "url": "https://files.pythonhosted.org/packages/00/b4/1f4699643c090db6ade5991272d0c6c75c570b5e46632bad71e8cbb0bd5d/asyncpg-0.25.0-cp310-cp310-win32.whl" + }, + { + "algorithm": "sha256", + "hash": "bf5e3408a14a17d480f36ebaf0401a12ff6ae5457fdf45e4e2775c51cc9517d3", + "url": "https://files.pythonhosted.org/packages/2e/70/8305c9891742a4e767b6257708195ed9a72dec90e0169a85d930bc845f0d/asyncpg-0.25.0-cp310-cp310-macosx_10_9_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "63f8e6a69733b285497c2855464a34de657f2cccd25aeaeeb5071872e9382540", + "url": "https://files.pythonhosted.org/packages/38/80/4c03e190c86c78a5a85b440a5f719dd42c388c5976ab327c6358f5b86514/asyncpg-0.25.0.tar.gz" + }, + { + "algorithm": "sha256", + "hash": "1a70783f6ffa34cc7dd2de20a873181414a34fd35a4a208a1f1a7f9f695e4ec4", + "url": "https://files.pythonhosted.org/packages/48/ba/572ade0c41bc4bfd4804b7d35ebd49a94293f37aded6acef528d25a11a0f/asyncpg-0.25.0-cp310-cp310-musllinux_1_1_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "2bc197fc4aca2fd24f60241057998124012469d2e414aed3f992579db0c88e3a", + "url": "https://files.pythonhosted.org/packages/b4/00/10f2add0c7ca3963c7b273abd2299b9c2190967ee56758a1a233049deca7/asyncpg-0.25.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl" + } + ], + "project_name": "asyncpg", + "requires_dists": [ + "Cython<0.30.0,>=0.29.24; extra == \"dev\"", + "Sphinx~=4.1.2; extra == \"dev\"", + "Sphinx~=4.1.2; extra == \"docs\"", + "flake8~=3.9.2; extra == \"dev\"", + "flake8~=3.9.2; extra == \"test\"", + "pycodestyle~=2.7.0; extra == \"dev\"", + "pycodestyle~=2.7.0; extra == \"test\"", + "pytest>=6.0; extra == \"dev\"", + 
"sphinx-rtd-theme~=0.5.2; extra == \"dev\"", + "sphinx-rtd-theme~=0.5.2; extra == \"docs\"", + "sphinxcontrib-asyncio~=0.3.0; extra == \"dev\"", + "sphinxcontrib-asyncio~=0.3.0; extra == \"docs\"", + "typing-extensions>=3.7.4.3; python_version < \"3.8\"", + "uvloop>=0.15.3; (platform_system != \"Windows\" and python_version >= \"3.7\") and extra == \"dev\"", + "uvloop>=0.15.3; (platform_system != \"Windows\" and python_version >= \"3.7\") and extra == \"test\"" + ], + "requires_python": ">=3.6.0", + "version": "0.25" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "c0fd6129bed9cd95eeb0815235ed86920f9a24d60d87eae181123588ddd4cd2b", + "url": "https://files.pythonhosted.org/packages/f8/d6/cdb2c5477b106cf49c7c1ab26c61b072ee3022355074fbfdfbfb642437ac/asyncudp-0.6.0.tar.gz" + } + ], + "project_name": "asyncudp", + "requires_dists": [], + "requires_python": null, + "version": "0.6" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "6d1784dea7c0c8d4a5172b6c620f40b6e4cbfdf96d783691f2e1302a7b88e197", + "url": "https://files.pythonhosted.org/packages/2c/a0/da5f49008ec6e9a658dbf5d7310a4debd397bce0b4db03cf8a410066bb87/atomicwrites-1.4.0-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "ae70396ad1a434f9c7046fd2dd196fc04b12f9e91ffb859164193be8b6168a7a", + "url": "https://files.pythonhosted.org/packages/55/8d/74a75635f2c3c914ab5b3850112fd4b0c8039975ecb320e4449aa363ba54/atomicwrites-1.4.0.tar.gz" + } + ], + "project_name": "atomicwrites", + "requires_dists": [], + "requires_python": "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7", + "version": "1.4" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "2d27e3784d7a565d36ab851fe94887c5eccd6a463168875832a1be79c82828b4", + "url": "https://files.pythonhosted.org/packages/be/be/7abce643bfdf8ca01c48afa2ddf8308c2308b0c3b239a44e57d020afa0ef/attrs-21.4.0-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "626ba8234211db98e869df76230a137c4c40a12d72445c45d5f5b716f076e2fd", + "url": "https://files.pythonhosted.org/packages/d7/77/ebb15fc26d0f815839ecd897b919ed6d85c050feeb83e100e020df9153d2/attrs-21.4.0.tar.gz" + } + ], + "project_name": "attrs", + "requires_dists": [ + "cloudpickle; platform_python_implementation == \"CPython\" and extra == \"dev\"", + "cloudpickle; platform_python_implementation == \"CPython\" and extra == \"tests\"", + "cloudpickle; platform_python_implementation == \"CPython\" and extra == \"tests_no_zope\"", + "coverage[toml]>=5.0.2; extra == \"dev\"", + "coverage[toml]>=5.0.2; extra == \"tests\"", + "coverage[toml]>=5.0.2; extra == \"tests_no_zope\"", + "furo; extra == \"dev\"", + "furo; extra == \"docs\"", + "hypothesis; extra == \"dev\"", + "hypothesis; extra == \"tests\"", + "hypothesis; extra == \"tests_no_zope\"", + "mypy; extra == \"dev\"", + "mypy; extra == \"tests\"", + "mypy; extra == \"tests_no_zope\"", + "pre-commit; extra == \"dev\"", + "pympler; extra == \"dev\"", + "pympler; extra == \"tests\"", + "pympler; extra == \"tests_no_zope\"", + "pytest-mypy-plugins; extra == \"dev\"", + "pytest-mypy-plugins; extra == \"tests\"", + "pytest-mypy-plugins; extra == \"tests_no_zope\"", + "pytest>=4.3.0; extra == \"dev\"", + "pytest>=4.3.0; extra == \"tests\"", + "pytest>=4.3.0; extra == \"tests_no_zope\"", + "six; extra == \"dev\"", + "six; extra == \"tests\"", + "six; extra == \"tests_no_zope\"", + "sphinx-notfound-page; extra == \"dev\"", + "sphinx-notfound-page; extra == \"docs\"", + "sphinx; extra == \"dev\"", + "sphinx; extra == \"docs\"", + "zope.interface; extra == 
\"dev\"", + "zope.interface; extra == \"docs\"", + "zope.interface; extra == \"tests\"" + ], + "requires_python": "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7", + "version": "21.4" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "b41b340bc3eabd0413c5116e5b55c808d5d85273a5f066402649bbd1aacd5e30", + "url": "https://files.pythonhosted.org/packages/73/b3/ad66998041f559e24d8c64d90d6d4f67543bb5d89c97646701bb66e48b1f/backend.ai_krunner_alpine-3.3.1-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "9fc7bd39c6075a9c172539152b3712d203fcb57eaf22262ac97f03688a9af492", + "url": "https://files.pythonhosted.org/packages/88/f1/c1125679d2690c933c47d57b2dc35c1e75eeddb1e8ead161aea4677571de/backend.ai-krunner-alpine-3.3.1.tar.gz" + } + ], + "project_name": "backend-ai-krunner-alpine", + "requires_dists": [ + "Click>=7.1", + "codecov; extra == \"test\"", + "flake8>=3.7.9; extra == \"test\"", + "pytest~=5.4.1; extra == \"test\"", + "twine~=3.0; extra == \"build\"", + "wheel>=0.34.2; extra == \"build\"" + ], + "requires_python": ">=3.6", + "version": "3.3.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "2508b4f0bae09f2d7b1bd68d3100418551e3f2f0d4d8ed0c089d84afc31d7f4c", + "url": "https://files.pythonhosted.org/packages/8d/af/182be6fd62adbacd8f8ca0e8763241247f632a18d486f885ab8f8dc985be/backend.ai_krunner_static_gnu-2.0.1-py3-none-manylinux2014_x86_64.macosx_11_0_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "90f5f7346cd8d6c904b190f125aed2bb7c09cf52b91a5ad1b57b3faf31faa9a4", + "url": "https://files.pythonhosted.org/packages/16/c9/912e963a29854920a1018d7b37efd5bb31a3e66b1d672e156ca6148a026e/backend.ai-krunner-static-gnu-2.0.1.tar.gz" + }, + { + "algorithm": "sha256", + "hash": "8e00553ab9e09e5b2b4cd045783e9d05be19eb94e1d53f268c2d11e910a3b319", + "url": "https://files.pythonhosted.org/packages/27/24/a931130c36e7c79fa78772e6264d6fa13f9b11a1213b86c166497ffc3b77/backend.ai_krunner_static_gnu-2.0.1-py3-none-manylinux2014_aarch64.macosx_11_0_arm64.whl" + } + ], + "project_name": "backend-ai-krunner-static-gnu", + "requires_dists": [ + "codecov; extra == \"test\"", + "flake8>=3.7.9; extra == \"test\"", + "pytest~=6.2.1; extra == \"test\"", + "twine~=3.0; extra == \"build\"", + "wheel>=0.34.2; extra == \"build\"" + ], + "requires_python": ">=3.6", + "version": "2.0.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "7ff2069240c6bbe49109fe84ca80508773a904f5a8cb960e02a977f7f519b129", + "url": "https://files.pythonhosted.org/packages/f5/37/7cd297ff571c4d86371ff024c0e008b37b59e895b28f69444a9b6f94ca1a/bcrypt-3.2.2-cp36-abi3-win_amd64.whl" + }, + { + "algorithm": "sha256", + "hash": "cd43303d6b8a165c29ec6756afd169faba9396a9472cdff753fe9f19b96ce2fa", + "url": "https://files.pythonhosted.org/packages/18/76/057b0637c880e6cb0abdc8a867d080376ddca6ed7d05b7738f589cc5c1a8/bcrypt-3.2.2-cp36-abi3-musllinux_1_1_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "2b02d6bfc6336d1094276f3f588aa1225a598e27f8e3388f4db9948cb707b521", + "url": "https://files.pythonhosted.org/packages/3e/df/289db4f31b303de6addb0897c8b5c01b23bd4b8c511ac80a32b08658847c/bcrypt-3.2.2-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "a2c46100e315c3a5b90fdc53e429c006c5f962529bc27e1dfd656292c20ccc40", + "url": "https://files.pythonhosted.org/packages/40/8f/b67b42faa2e4d944b145b1a402fc08db0af8fe2dfa92418c674b5a302496/bcrypt-3.2.2-cp36-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl" + }, + { + 
"algorithm": "sha256", + "hash": "88273d806ab3a50d06bc6a2fc7c87d737dd669b76ad955f449c43095389bc8fb", + "url": "https://files.pythonhosted.org/packages/61/3d/dce83194830183aa700cab07c89822471d21663a86a0b305d1e5c7b02810/bcrypt-3.2.2-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" + }, + { + "algorithm": "sha256", + "hash": "6d2cb9d969bfca5bc08e45864137276e4c3d3d7de2b162171def3d188bf9d34a", + "url": "https://files.pythonhosted.org/packages/86/1b/f4d7425dfc6cd0e405b48ee484df6d80fb39e05f25963dbfcc2c511e8341/bcrypt-3.2.2-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "61bae49580dce88095d669226d5076d0b9d927754cedbdf76c6c9f5099ad6f26", + "url": "https://files.pythonhosted.org/packages/8c/b3/1257f7d64ee0aa0eb4fb1de5da8c2647a57db7b737da1f2342ac1889d3b8/bcrypt-3.2.2-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl" + }, + { + "algorithm": "sha256", + "hash": "7180d98a96f00b1050e93f5b0f556e658605dd9f524d0b0e68ae7944673f525e", + "url": "https://files.pythonhosted.org/packages/a0/c2/05354b1d4351d2e686a32296cc9dd1e63f9909a580636df0f7b06d774600/bcrypt-3.2.2-cp36-abi3-macosx_10_10_universal2.whl" + }, + { + "algorithm": "sha256", + "hash": "433c410c2177057705da2a9f2cd01dd157493b2a7ac14c8593a16b3dab6b6bfb", + "url": "https://files.pythonhosted.org/packages/e8/36/edc85ab295ceff724506252b774155eff8a238f13730c8b13badd33ef866/bcrypt-3.2.2.tar.gz" + }, + { + "algorithm": "sha256", + "hash": "4e029cef560967fb0cf4a802bcf4d562d3d6b4b1bf81de5ec1abbe0f1adb027e", + "url": "https://files.pythonhosted.org/packages/f1/64/cd93e2c3e28a5fa8bcf6753d5cc5e858e4da08bf51404a0adb6a412532de/bcrypt-3.2.2-cp36-abi3-win32.whl" + }, + { + "algorithm": "sha256", + "hash": "7d9ba2e41e330d2af4af6b1b6ec9e6128e91343d0b4afb9282e54e5508f31baa", + "url": "https://files.pythonhosted.org/packages/fc/9a/e1867f0b27a3f4ce90e21dd7f322f0e15d4aac2434d3b938dcf765e47c6b/bcrypt-3.2.2-cp36-abi3-musllinux_1_1_aarch64.whl" + } + ], + "project_name": "bcrypt", + "requires_dists": [ + "cffi>=1.1", + "mypy; extra == \"typecheck\"", + "pytest!=3.3.0,>=3.2.1; extra == \"tests\"" + ], + "requires_python": ">=3.6", + "version": "3.2.2" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "513d4ff98dd27f85743a8dc0e92f55ddb1b49e060c2d5961512855cda2c01a98", + "url": "https://files.pythonhosted.org/packages/cd/5c/f3aa86b6d5482f3051b433c7616668a9b96fbe49a622210e2c9781938a5c/cachetools-4.1.1-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "bbaa39c3dede00175df2dc2b03d0cf18dd2d32a7de7beb68072d13043c9edb20", + "url": "https://files.pythonhosted.org/packages/fc/c8/0b52cf3132b4b85c9e83faa3e4d375575afeb3a1710c40b2b2cd2a3e5635/cachetools-4.1.1.tar.gz" + } + ], + "project_name": "cachetools", + "requires_dists": [], + "requires_python": "~=3.5", + "version": "4.1.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "3f17e34abc01ae71fef9d30bbea86a1e8eb155e24f23ae8121196a59964c6f04", + "url": "https://files.pythonhosted.org/packages/3b/bf/54362d591b09948f707e4066baade96e7014b8f7af1d1672dcc7e05d6fdd/callosum-0.9.10-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "f8b48397a768511cdf8c290b727e4adc6300831d3f82679b8b1b55278fce634c", + "url": "https://files.pythonhosted.org/packages/18/6e/a5fdcd76e6aaa52f13721046470139ca1b8beab8f3a5feed8800b1181f17/callosum-0.9.10.tar.gz" + } + ], + "project_name": "callosum", + "requires_dists": [ + "Click>=8.0; extra == \"test\"", + "aioredis<2.0,>=1.3.0; extra == \"redis\"", 
+ "aiotools>=0.9.1", + "async-timeout>=3.0.1", + "attrs>=21.2.0", + "codecov; extra == \"test\"", + "flake8>=3.9; extra == \"lint\"", + "msgpack>=1.0.0", + "mypy>=0.910; extra == \"typecheck\"", + "pytest-asyncio>=0.15.0; extra == \"test\"", + "pytest-cov; extra == \"test\"", + "pytest-mock; extra == \"test\"", + "pytest~=6.2; extra == \"test\"", + "python-dateutil>=2.8.1", + "python-snappy>=0.5.4; extra == \"snappy\"", + "pyzmq>=19.0.0; extra == \"zeromq\"", + "sphinx-autodoc-typehints; extra == \"docs\"", + "sphinx; extra == \"docs\"", + "temporenc>=0.1", + "thriftpy2>=0.4.9; extra == \"thrift\"", + "towncrier>=19.2.0; extra == \"build\"", + "twine>=3.1.0; extra == \"build\"", + "types-python-dateutil; extra == \"typecheck\"", + "wheel>=0.33.6; extra == \"build\"" + ], + "requires_python": ">=3.8", + "version": "0.9.10" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "f1d53542ee8cbedbe2118b5686372fb33c297fcd6379b050cca0ef13a597382a", + "url": "https://files.pythonhosted.org/packages/11/dd/e015f3780f42dd9af62cf0107b44ea1298926627ecd70c17b0e484e95bcd/certifi-2022.5.18.1-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "9c5705e395cd70084351dd8ad5c41e65655e08ce46f2ec9cf6c2c08390f71eb7", + "url": "https://files.pythonhosted.org/packages/07/10/75277f313d13a2b74fc56e29239d5c840c2bf09f17bf25c02b35558812c6/certifi-2022.5.18.1.tar.gz" + } + ], + "project_name": "certifi", + "requires_dists": [], + "requires_python": ">=3.6", + "version": "2022.5.18.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "5e069f72d497312b24fcc02073d70cb989045d1c91cbd53979366077959933e0", + "url": "https://files.pythonhosted.org/packages/c9/06/3dc78a8537fba6d442d45a2d9c0d71679d2bfc88e82452780874256cf883/cffi-1.15.0-cp310-cp310-win_amd64.whl" + }, + { + "algorithm": "sha256", + "hash": "920f0d66a896c2d99f0adbb391f990a84091179542c205fa53ce5787aff87954", + "url": "https://files.pythonhosted.org/packages/00/9e/92de7e1217ccc3d5f352ba21e52398372525765b2e0c4530e6eb2ba9282a/cffi-1.15.0.tar.gz" + }, + { + "algorithm": "sha256", + "hash": "c21c9e3896c23007803a875460fb786118f0cdd4434359577ea25eb556e34c55", + "url": "https://files.pythonhosted.org/packages/2a/fb/7f52b10940eb31b32410fe016cad5b379961be0eac1d40ba4afa1e7819ad/cffi-1.15.0-cp310-cp310-win32.whl" + }, + { + "algorithm": "sha256", + "hash": "abb9a20a72ac4e0fdb50dae135ba5e77880518e742077ced47eb1499e29a443c", + "url": "https://files.pythonhosted.org/packages/6d/cc/e45ad6277cd0675c4fbfc25cacc29f5760b9a4c150fd902b18abf85e4672/cffi-1.15.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" + }, + { + "algorithm": "sha256", + "hash": "91ec59c33514b7c7559a6acda53bbfe1b283949c34fe7440bcf917f96ac0723e", + "url": "https://files.pythonhosted.org/packages/7f/96/126f39cbb6d7c87cb7de2e5f74a2a707c282c357055be641d3ae6075d0a5/cffi-1.15.0-cp310-cp310-macosx_11_0_arm64.whl" + }, + { + "algorithm": "sha256", + "hash": "f54a64f8b0c8ff0b64d18aa76675262e1700f3995182267998c31ae974fbc382", + "url": "https://files.pythonhosted.org/packages/ac/40/9cf45d01320987075d3156e96741b5de2395005070b571c0c9498093b905/cffi-1.15.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl" + }, + { + "algorithm": "sha256", + "hash": "00c878c90cb53ccfaae6b8bc18ad05d2036553e6d9d1d9dbcf323bbe83854ca3", + "url": "https://files.pythonhosted.org/packages/ae/27/a99335833b6c4d356bdeaadd87d0e9e83969761513dba6dc2a8123d95ca1/cffi-1.15.0-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": 
"a5263e363c27b653a90078143adb3d076c1a748ec9ecc78ea2fb916f9b861962", + "url": "https://files.pythonhosted.org/packages/bb/7d/8e2ef3d009d801e02e18fb995c06ad788b5ed42c534c9c737260d54ddec7/cffi-1.15.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl" + }, + { + "algorithm": "sha256", + "hash": "f5c7150ad32ba43a07c4479f40241756145a1f03b43480e058cfd862bf5041c7", + "url": "https://files.pythonhosted.org/packages/c3/54/4587212a3a2340d41a40a903c92ce3590f78ca75a56fd608e6889cba98d1/cffi-1.15.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl" + }, + { + "algorithm": "sha256", + "hash": "0104fb5ae2391d46a4cb082abdd5c69ea4eab79d8d44eaaf79f1b1fd806ee4c2", + "url": "https://files.pythonhosted.org/packages/f0/00/3003a6f8c20bc349cc7c307432dcbd3711135a1ce12b763e7b09726674f0/cffi-1.15.0-cp310-cp310-macosx_10_9_x86_64.whl" + } + ], + "project_name": "cffi", + "requires_dists": [ + "pycparser" + ], + "requires_python": null, + "version": "1.15" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "6881edbebdb17b39b4eaaa821b438bf6eddffb4468cf344f09f89def34a8b1df", + "url": "https://files.pythonhosted.org/packages/06/b3/24afc8868eba069a7f03650ac750a778862dc34941a4bebeb58706715726/charset_normalizer-2.0.12-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "2857e29ff0d34db842cd7ca3230549d1a697f96ee6d3fb071cfa6c7393832597", + "url": "https://files.pythonhosted.org/packages/56/31/7bcaf657fafb3c6db8c787a865434290b726653c912085fbd371e9b92e1c/charset-normalizer-2.0.12.tar.gz" + } + ], + "project_name": "charset-normalizer", + "requires_dists": [ + "unicodedata2; extra == \"unicode_backport\"" + ], + "requires_python": ">=3.5.0", + "version": "2.0.12" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "bb4d8133cb15a609f44e8213d9b391b0809795062913b383c62be0ee95b1db48", + "url": "https://files.pythonhosted.org/packages/c2/f1/df59e28c642d583f7dacffb1e0965d0e00b218e0186d7858ac5233dce840/click-8.1.3-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "7682dc8afb30297001674575ea00d1814d808d6a36af415a82bd481d37ba7b8e", + "url": "https://files.pythonhosted.org/packages/59/87/84326af34517fca8c58418d148f2403df25303e02736832403587318e9e8/click-8.1.3.tar.gz" + } + ], + "project_name": "click", + "requires_dists": [ + "colorama; platform_system == \"Windows\"", + "importlib-metadata; python_version < \"3.8\"" + ], + "requires_python": ">=3.7", + "version": "8.1.3" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2", + "url": "https://files.pythonhosted.org/packages/44/98/5b86278fbbf250d239ae0ecb724f8572af1c91f4a11edf4d36a206189440/colorama-0.4.4-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b", + "url": "https://files.pythonhosted.org/packages/1f/bb/5d3246097ab77fa083a61bd8d3d527b7ae063c7d8e8671b1cf8c4ec10cbe/colorama-0.4.4.tar.gz" + } + ], + "project_name": "colorama", + "requires_dists": [], + "requires_python": "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7", + "version": "0.4.4" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "612ee75c546f53e92e70049c9dbfcc18c935a2b9a53b66085ce9ef6a6e5c0934", + "url": "https://files.pythonhosted.org/packages/a7/06/3d6badcf13db419e25b07041d9c7b4a2c331d3f4e7134445ec5df57714cd/coloredlogs-15.0.1-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "7c991aa71a4577af2f82600d8f8f3a89f936baeaf9b50a9c197da014e5bf16b0", + "url": 
"https://files.pythonhosted.org/packages/cc/c7/eed8f27100517e8c0e6b923d5f0845d0cb99763da6fdee00478f91db7325/coloredlogs-15.0.1.tar.gz" + } + ], + "project_name": "coloredlogs", + "requires_dists": [ + "capturer>=2.4; extra == \"cron\"", + "humanfriendly>=9.1" + ], + "requires_python": "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7", + "version": "15.0.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "da2f38c92590f83de410ba1a3cbceafbc74fee9def35f9251ba9a971d6d66fd9", + "url": "https://files.pythonhosted.org/packages/b1/92/dfd892312d822f36c55366118b95d914e5f16de11044a27cf10a7d71bbbf/commonmark-0.9.1-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "452f9dc859be7f06631ddcb328b6919c67984aca654e5fefb3914d54691aed60", + "url": "https://files.pythonhosted.org/packages/60/48/a60f593447e8f0894ebb7f6e6c1f25dafc5e89c5879fdc9360ae93ff83f0/commonmark-0.9.1.tar.gz" + } + ], + "project_name": "commonmark", + "requires_dists": [ + "flake8==3.7.8; extra == \"test\"", + "future>=0.14.0; python_version < \"3\"", + "hypothesis==3.55.3; extra == \"test\"" + ], + "requires_python": null, + "version": "0.9.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "471e0d70201c069f74c837983189949aa0d24bb2d751b57e26e3761f2f782b8d", + "url": "https://files.pythonhosted.org/packages/f6/51/640fe2a25b774aefcd49b101c850f36e8e4ac164dc5c281b3dfa50c01da7/cryptography-37.0.2-cp36-abi3-win_amd64.whl" + }, + { + "algorithm": "sha256", + "hash": "59b281eab51e1b6b6afa525af2bd93c16d49358404f814fe2c2410058623928c", + "url": "https://files.pythonhosted.org/packages/06/01/2a237fae9ea9a7aecc182cd09348c4eb4c5d8a9ef3a50d1f2a60a1004603/cryptography-37.0.2-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "31fe38d14d2e5f787e0aecef831457da6cec68e0bb09a35835b0b44ae8b988fe", + "url": "https://files.pythonhosted.org/packages/1d/63/eb9ee3c63cebf6bac454617085376b7e2cdc1ae022e55fbc1d0194d4eae4/cryptography-37.0.2-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl" + }, + { + "algorithm": "sha256", + "hash": "2bd1096476aaac820426239ab534b636c77d71af66c547b9ddcd76eb9c79e004", + "url": "https://files.pythonhosted.org/packages/45/10/de0bdaaf4410dd046404e38d57bfe8a567aa94c8b7b6cf858d759112a947/cryptography-37.0.2-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "f224ad253cc9cea7568f49077007d2263efa57396a2f2f78114066fd54b5c68e", + "url": "https://files.pythonhosted.org/packages/51/05/bb2b681f6a77276fc423d04187c39dafdb65b799c8d87b62ca82659f9ead/cryptography-37.0.2.tar.gz" + }, + { + "algorithm": "sha256", + "hash": "0cc20f655157d4cfc7bada909dc5cc228211b075ba8407c46467f63597c78178", + "url": "https://files.pythonhosted.org/packages/55/ba/2268399be15f1542a3bacf6e60fdaf4fea0b18e5190e87b97075e03cb155/cryptography-37.0.2-cp36-abi3-manylinux_2_24_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "ef15c2df7656763b4ff20a9bc4381d8352e6640cfeb95c2972c38ef508e75181", + "url": "https://files.pythonhosted.org/packages/80/e2/89a180c6dc1c3fe33f7f8965da6401cf0b31f440f4e59e9b024b6f54eb0c/cryptography-37.0.2-cp36-abi3-macosx_10_10_universal2.whl" + }, + { + "algorithm": "sha256", + "hash": "093cb351031656d3ee2f4fa1be579a8c69c754cf874206be1d4cf3b542042804", + "url": "https://files.pythonhosted.org/packages/8e/38/055c75d4f6180aa3525eabaa5a0eabadd174594c7d5eeac6741db663dcd5/cryptography-37.0.2-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" + }, + { + "algorithm": "sha256", 
+ "hash": "731c8abd27693323b348518ed0e0705713a36d79fdbd969ad968fbef0979a7e0", + "url": "https://files.pythonhosted.org/packages/a1/09/51b3b56ec18f1eb395aa12c65e154f8582a08f4af458d4890b80a9f40acd/cryptography-37.0.2-cp36-abi3-win32.whl" + }, + { + "algorithm": "sha256", + "hash": "46f4c544f6557a2fefa7ac8ac7d1b17bf9b647bd20b16decc8fbcab7117fbc15", + "url": "https://files.pythonhosted.org/packages/ac/96/358a0b767bdd40ee51f0843ee87e614f9f3c1754a2247a26eb0d40e80ded/cryptography-37.0.2-cp36-abi3-musllinux_1_1_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "3c81599befb4d4f3d7648ed3217e00d21a9341a9a688ecdd615ff72ffbed7336", + "url": "https://files.pythonhosted.org/packages/b4/b7/b39f5812f3fc787be8a1bad7fd9bcf39cfa9b058bb3f3c0bc1b7659e9d77/cryptography-37.0.2-cp36-abi3-macosx_10_10_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "f8ec91983e638a9bcd75b39f1396e5c0dc2330cbd9ce4accefe68717e6779e0a", + "url": "https://files.pythonhosted.org/packages/c9/d2/aac40c7a55192c15f2845565ee769f1627f6cfb73fc73b0a250f8b787f41/cryptography-37.0.2-cp36-abi3-musllinux_1_1_aarch64.whl" + } + ], + "project_name": "cryptography", + "requires_dists": [ + "bcrypt>=3.1.5; extra == \"ssh\"", + "black; extra == \"pep8test\"", + "cffi>=1.12", + "flake8-import-order; extra == \"pep8test\"", + "flake8; extra == \"pep8test\"", + "hypothesis!=3.79.2,>=1.11.4; extra == \"test\"", + "iso8601; extra == \"test\"", + "pep8-naming; extra == \"pep8test\"", + "pretend; extra == \"test\"", + "pyenchant>=1.6.11; extra == \"docstest\"", + "pytest-benchmark; extra == \"test\"", + "pytest-cov; extra == \"test\"", + "pytest-subtests; extra == \"test\"", + "pytest-xdist; extra == \"test\"", + "pytest>=6.2.0; extra == \"test\"", + "pytz; extra == \"test\"", + "setuptools-rust>=0.11.4; extra == \"sdist\"", + "sphinx!=1.8.0,!=3.1.0,!=3.1.1,>=1.6.5; extra == \"docs\"", + "sphinx-rtd-theme; extra == \"docs\"", + "sphinxcontrib-spelling>=4.0.1; extra == \"docstest\"", + "twine>=1.12.0; extra == \"docstest\"" + ], + "requires_python": ">=3.6", + "version": "37.0.2" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "64756e3e14c8c5eea9795d93c524551432a0be75629f8f29e67ab8caf076c76d", + "url": "https://files.pythonhosted.org/packages/51/6a/c3a0408646408f7283b7bc550c30a32cc791181ec4618592eec13e066ce3/Deprecated-1.2.13-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "43ac5335da90c31c24ba028af536a91d41d53f9e6901ddb021bcc572ce44e38d", + "url": "https://files.pythonhosted.org/packages/c8/d1/e412abc2a358a6b9334250629565fe12697ca1cdee4826239eddf944ddd0/Deprecated-1.2.13.tar.gz" + } + ], + "project_name": "deprecated", + "requires_dists": [ + "PyTest-Cov; python_version >= \"3.6\" and extra == \"dev\"", + "PyTest-Cov<2.6; python_version < \"3.6\" and extra == \"dev\"", + "PyTest; python_version >= \"3.6\" and extra == \"dev\"", + "PyTest<5; python_version < \"3.6\" and extra == \"dev\"", + "bump2version<1; extra == \"dev\"", + "configparser<5; python_version < \"3\" and extra == \"dev\"", + "importlib-metadata<3; python_version < \"3\" and extra == \"dev\"", + "importlib-resources<4; python_version < \"3\" and extra == \"dev\"", + "sphinx<2; extra == \"dev\"", + "sphinxcontrib-websupport<2; python_version < \"3\" and extra == \"dev\"", + "tox; extra == \"dev\"", + "wrapt<2,>=1.10", + "zipp<2; python_version < \"3\" and extra == \"dev\"" + ], + "requires_python": "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7", + "version": "1.2.13" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": 
"78f81768dbbd498091ecf09291b8bd8ef1a9a68791f121d1fbab7b1f3220d862", + "url": "https://files.pythonhosted.org/packages/94/69/425848749851fb35b1aded3e89a7a0a5adb08d576ea2ca79b9678ed4546e/etcetra-0.1.6-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "ceaaf1ce8fa72b943e1fb69704c83cc631085187b6dc19be6eae02099e9155bb", + "url": "https://files.pythonhosted.org/packages/7e/f8/2dc3d6796461f27f06ac3a10adfeca52a7565d0cdbf28093fb398715b9e9/etcetra-0.1.6.tar.gz" + } + ], + "project_name": "etcetra", + "requires_dists": [ + "async-timeout~=4.0.0", + "codecov; extra == \"test\"", + "flake8-commas>=2.1; extra == \"lint\"", + "flake8>=4.0.1; extra == \"lint\"", + "grpcio-tools~=1.44.0", + "grpcio~=1.44.0", + "mypy>=0.930; extra == \"typecheck\"", + "pytest-asyncio~=0.16.0; extra == \"test\"", + "pytest-cov>=2.11; extra == \"test\"", + "pytest-mock>=3.5.0; extra == \"test\"", + "pytest~=6.2.5; extra == \"test\"", + "twine>=3.4.1; extra == \"build\"", + "types-protobuf; extra == \"typecheck\"", + "types-setuptools; extra == \"typecheck\"", + "wheel>=0.36.2; extra == \"build\"" + ], + "requires_python": ">=3.9", + "version": "0.1.6" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "f7353ba3367473d1d616ee727945f439e027f0bb16ac1a750219a8344d1d5d3c", + "url": "https://files.pythonhosted.org/packages/a0/fa/7e6e4cbd0911966ca52846deee74b6ef9b138c45765bdb0f7242f14688e4/frozenlist-1.3.0-cp310-cp310-win_amd64.whl" + }, + { + "algorithm": "sha256", + "hash": "6eb275c6385dd72594758cbe96c07cdb9bd6becf84235f4a594bdf21e3596c9d", + "url": "https://files.pythonhosted.org/packages/0e/36/c4659bee33cab5ed22b7df23bafc3841a269793ca8e5527822f3fe41b568/frozenlist-1.3.0-cp310-cp310-musllinux_1_1_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "754728d65f1acc61e0f4df784456106e35afb7bf39cfe37227ab00436fb38676", + "url": "https://files.pythonhosted.org/packages/14/36/9a396760b7d1a48efe3520e994064401b36dfa9286e5b5e5bfb5bde16db7/frozenlist-1.3.0-cp310-cp310-musllinux_1_1_s390x.whl" + }, + { + "algorithm": "sha256", + "hash": "e30b2f9683812eb30cf3f0a8e9f79f8d590a7999f731cf39f9105a7c4a39489d", + "url": "https://files.pythonhosted.org/packages/24/1c/076b1a5a0b8b4af0bae5f999eaf0e3deaa25eb08fe195cdc3e628e41c279/frozenlist-1.3.0-cp310-cp310-win32.whl" + }, + { + "algorithm": "sha256", + "hash": "4a44ebbf601d7bac77976d429e9bdb5a4614f9f4027777f9e54fd765196e9d3b", + "url": "https://files.pythonhosted.org/packages/29/03/a300b151ecb1cf78c4fe404978ffbdb719eed810a1606e6afc8ae8f16837/frozenlist-1.3.0-cp310-cp310-macosx_10_9_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "45334234ec30fc4ea677f43171b18a27505bfb2dba9aca4398a62692c0ea8868", + "url": "https://files.pythonhosted.org/packages/32/61/b322998b806633b7df19d614916600d00439099dbb030a623eeb0694304e/frozenlist-1.3.0-cp310-cp310-macosx_11_0_arm64.whl" + }, + { + "algorithm": "sha256", + "hash": "04cb491c4b1c051734d41ea2552fde292f5f3a9c911363f74f39c23659c4af78", + "url": "https://files.pythonhosted.org/packages/3f/9e/991076d645ddfff334ace95b9386daef81cc144676c7f0057938f29ffa48/frozenlist-1.3.0-cp310-cp310-musllinux_1_1_i686.whl" + }, + { + "algorithm": "sha256", + "hash": "436496321dad302b8b27ca955364a439ed1f0999311c393dccb243e451ff66aa", + "url": "https://files.pythonhosted.org/packages/49/22/cb44c4c4671c55fc2ecf0727496d466390315f705ec3f0b0c7aeb5658a50/frozenlist-1.3.0-cp310-cp310-musllinux_1_1_ppc64le.whl" + }, + { + "algorithm": "sha256", + "hash": "b9e3e9e365991f8cc5f5edc1fd65b58b41d0514a6a7ad95ef5c7f34eb49b3d3e", + "url": 
"https://files.pythonhosted.org/packages/4c/4e/0a153040dc966105dc99ccb597358d30a9bbda4a13aa753d0f382eced4fb/frozenlist-1.3.0-cp310-cp310-musllinux_1_1_aarch64.whl" + }, + { + "algorithm": "sha256", + "hash": "6a202458d1298ced3768f5a7d44301e7c86defac162ace0ab7434c2e961166e8", + "url": "https://files.pythonhosted.org/packages/5d/98/10edca86eb789469648049d0f8ea0b5bd74f5a3e11064ae620095db8595e/frozenlist-1.3.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "691ddf6dc50480ce49f68441f1d16a4c3325887453837036e0fb94736eae1e58", + "url": "https://files.pythonhosted.org/packages/71/46/d96b08a7f84bf77a7e4a5238bfabd7a1c34b2c1617476c69445668de7923/frozenlist-1.3.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl" + }, + { + "algorithm": "sha256", + "hash": "47be22dc27ed933d55ee55845d34a3e4e9f6fee93039e7f8ebadb0c2f60d403f", + "url": "https://files.pythonhosted.org/packages/79/58/3a0a77a6be2c368f8e52f4aeba0016bb3a040c9a43553b901bc0e969f54f/frozenlist-1.3.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" + }, + { + "algorithm": "sha256", + "hash": "bde99812f237f79eaf3f04ebffd74f6718bbd216101b35ac7955c2d47c17da02", + "url": "https://files.pythonhosted.org/packages/b3/ac/ac631cdb022ddcf199305c03e45b3234aaab79e00663c4d96dacc39013d9/frozenlist-1.3.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl" + }, + { + "algorithm": "sha256", + "hash": "03a7dd1bfce30216a3f51a84e6dd0e4a573d23ca50f0346634916ff105ba6e6b", + "url": "https://files.pythonhosted.org/packages/cd/e5/c813ed0b4efa409ba74eb001f552243d4cb8d180723745f04a92340cc3fe/frozenlist-1.3.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl" + }, + { + "algorithm": "sha256", + "hash": "d2257aaba9660f78c7b1d8fea963b68f3feffb1a9d5d05a18401ca9eb3e8d0a3", + "url": "https://files.pythonhosted.org/packages/e8/28/da4e60e30dad3638570db89f9d6be26ae1f3e183607629b48cd5e35b1c81/frozenlist-1.3.0-cp310-cp310-macosx_10_9_universal2.whl" + }, + { + "algorithm": "sha256", + "hash": "ce6f2ba0edb7b0c1d8976565298ad2deba6f8064d2bebb6ffce2ca896eb35b0b", + "url": "https://files.pythonhosted.org/packages/f4/f7/8dfeb76d2a52bcea2b0718427af954ffec98be1d34cd8f282034b3e36829/frozenlist-1.3.0.tar.gz" + } + ], + "project_name": "frozenlist", + "requires_dists": [], + "requires_python": ">=3.7", + "version": "1.3" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "349ac49b18b01019453cc99c11c92ed772739778c92f184002b7ab3a5b7ac77d", + "url": "https://files.pythonhosted.org/packages/fc/04/ea9a945a6fbd7f7f977fcf7e300a715f1635939e5daf141b95068abaa5ec/google_auth-2.6.6-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "1ba4938e032b73deb51e59c4656a00e0939cf0b1112575099f136babb4563312", + "url": "https://files.pythonhosted.org/packages/fd/8c/3c24a436775d6582effe4ecaf33b2562e6a7f0cbc647a293c764c5eac9ee/google-auth-2.6.6.tar.gz" + } + ], + "project_name": "google-auth", + "requires_dists": [ + "aiohttp<4.0.0dev,>=3.6.2; python_version >= \"3.6\" and extra == \"aiohttp\"", + "cachetools<6.0,>=2.0.0", + "enum34>=1.1.10; python_version < \"3.4\"", + "pyasn1-modules>=0.2.1", + "pyopenssl>=20.0.0; extra == \"pyopenssl\"", + "pyu2f>=0.1.5; extra == \"reauth\"", + "requests<3.0.0dev,>=2.20.0; extra == \"aiohttp\"", + "rsa<4.6; python_version < \"3.6\"", + "rsa<5,>=3.1.4; python_version >= \"3.6\"", + "six>=1.9.0" + ], + "requires_python": "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7", + "version": 
"2.6.6" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "3d446eb1237c551052bc31155cf1a3a607053e4f58c9172b83a1b597beaa0868", + "url": "https://files.pythonhosted.org/packages/ef/a2/b3e68706bf45abc2f9d70f099a4b4ca6305779577f4a03458d78fb39cd42/graphene-2.1.9-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "b9f2850e064eebfee9a3ef4a1f8aa0742848d97652173ab44c82cc8a62b9ed93", + "url": "https://files.pythonhosted.org/packages/0a/9d/5a8890c7d14adbeda55e2d5f28120b4be2a7bfa0131674c340db1c162072/graphene-2.1.9.tar.gz" + } + ], + "project_name": "graphene", + "requires_dists": [ + "aniso8601<=7,>=3", + "coveralls; extra == \"test\"", + "fastdiff==0.2.0; extra == \"test\"", + "graphene-django; extra == \"django\"", + "graphene-sqlalchemy; extra == \"sqlalchemy\"", + "graphql-core<3,>=2.1", + "graphql-relay<3,>=2", + "iso8601; extra == \"test\"", + "mock; extra == \"test\"", + "promise; extra == \"test\"", + "pytest-benchmark; extra == \"test\"", + "pytest-cov; extra == \"test\"", + "pytest-mock; extra == \"test\"", + "pytest; extra == \"test\"", + "pytz; extra == \"test\"", + "six; extra == \"test\"", + "six<2,>=1.10.0", + "snapshottest; extra == \"test\"" + ], + "requires_python": null, + "version": "2.1.9" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "44c9bac4514e5e30c5a595fac8e3c76c1975cae14db215e8174c7fe995825bad", + "url": "https://files.pythonhosted.org/packages/11/71/d51beba3d8986fa6d8670ec7bcba989ad6e852d5ae99d95633e5dacc53e7/graphql_core-2.3.2-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "aac46a9ac524c9855910c14c48fc5d60474def7f99fd10245e76608eba7af746", + "url": "https://files.pythonhosted.org/packages/88/a2/dd91d55a6f6dd88c4d3c284d387c94f1f933fedec43a86a4422940b9de18/graphql-core-2.3.2.tar.gz" + } + ], + "project_name": "graphql-core", + "requires_dists": [ + "coveralls==1.11.1; extra == \"test\"", + "cython==0.29.17; extra == \"test\"", + "gevent==1.5.0; extra == \"test\"", + "gevent>=1.1; extra == \"gevent\"", + "promise<3,>=2.3", + "pyannotate==1.2.0; extra == \"test\"", + "pytest-benchmark==3.2.3; extra == \"test\"", + "pytest-cov==2.8.1; extra == \"test\"", + "pytest-django==3.9.0; extra == \"test\"", + "pytest-mock==2.0.0; extra == \"test\"", + "pytest==4.6.10; extra == \"test\"", + "rx<2,>=1.6", + "six==1.14.0; extra == \"test\"", + "six>=1.10.0" + ], + "requires_python": null, + "version": "2.3.2" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "ac514cb86db9a43014d7e73511d521137ac12cf0101b2eaa5f0a3da2e10d913d", + "url": "https://files.pythonhosted.org/packages/94/48/6022ea2e89cb936c3b933a0409c6e29bf8a68c050fe87d97f98aff6e5e9e/graphql_relay-2.0.1-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "870b6b5304123a38a0b215a79eace021acce5a466bf40cd39fa18cb8528afabb", + "url": "https://files.pythonhosted.org/packages/16/59/afbf1ce02631910ff0be06e5e057cc9e2806192d9b9c8d6671ff39e4abe2/graphql-relay-2.0.1.tar.gz" + } + ], + "project_name": "graphql-relay", + "requires_dists": [ + "graphql-core<3,>=2.2", + "promise<3,>=2.2", + "six>=1.12" + ], + "requires_python": null, + "version": "2.0.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "14d4f3cd4e8b524ae9b8aa567858beed70c392fdec26dbdb0a8a418392e71708", + "url": "https://files.pythonhosted.org/packages/07/59/f4656193ac084b7134dcf5f5d4d5c4d3a154d222202eecf00b367d367e90/greenlet-1.1.2-cp310-cp310-win_amd64.whl" + }, + { + "algorithm": "sha256", + "hash": "e30f5ea4ae2346e62cedde8794a56858a67b878dd79f7df76a0767e356b1744a", + 
"url": "https://files.pythonhosted.org/packages/0c/10/754e21b5bea89d0e73f99d60c83754df7cc64db74f47d98ab187669ce341/greenlet-1.1.2.tar.gz" + }, + { + "algorithm": "sha256", + "hash": "8639cadfda96737427330a094476d4c7a56ac03de7265622fcf4cfe57c8ae18d", + "url": "https://files.pythonhosted.org/packages/47/20/433693ac90ae70c8577bf4896951859e8d293e59df5818073033f181ee7b/greenlet-1.1.2-cp310-cp310-macosx_10_14_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "abb7a75ed8b968f3061327c433a0fbd17b729947b400747c334a9c29a9af6c58", + "url": "https://files.pythonhosted.org/packages/4e/39/71870a73ac498c5af423c566bb9ac6c3a3b2147acfa35ed44a9ccc41fab1/greenlet-1.1.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "b336501a05e13b616ef81ce329c0e09ac5ed8c732d9ba7e3e983fcc1a9e86965", + "url": "https://files.pythonhosted.org/packages/b2/c2/141916d37869e817cf18a590de62e8bc732e602ed6d2786c09f4b365f2cf/greenlet-1.1.2-cp310-cp310-musllinux_1_1_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "97e5306482182170ade15c4b0d8386ded995a07d7cc2ca8f27958d34d6736497", + "url": "https://files.pythonhosted.org/packages/b7/55/0e1a2c02f043b9fc698e70c6ef247b233c82791a78466929352fefcae5e8/greenlet-1.1.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" + }, + { + "algorithm": "sha256", + "hash": "e6a36bb9474218c7a5b27ae476035497a6990e21d04c279884eb10d9b290f1b1", + "url": "https://files.pythonhosted.org/packages/d0/bf/e6d86812a6d81536afbb4ad0a7a9793fd2d31268a6dabbc9d82534eab5fa/greenlet-1.1.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl" + } + ], + "project_name": "greenlet", + "requires_dists": [ + "Sphinx; extra == \"docs\"" + ], + "requires_python": "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7", + "version": "1.1.2" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "4ee51964edfd0a1293a95bb0d72d134ecf889379d90d2612cbf663623ce832b4", + "url": "https://files.pythonhosted.org/packages/61/c5/abf1e430561246dc0517e8ad92ad60e42a17c0eab7a62bb9d3f6334a8ab6/grpcio-1.44.0-cp310-cp310-win_amd64.whl" + }, + { + "algorithm": "sha256", + "hash": "11f811c0fffd84fca747fbc742464575e5eb130fd4fb4d6012ccc34febd001db", + "url": "https://files.pythonhosted.org/packages/42/98/0e5c94596eb48c096b91d24dd0cffb8106bbb018294394cd1fe016e18f91/grpcio-1.44.0-cp310-cp310-linux_armv7l.whl" + }, + { + "algorithm": "sha256", + "hash": "9a86a91201f8345502ea81dee0a55ae13add5fafadf109b17acd858fe8239651", + "url": "https://files.pythonhosted.org/packages/53/73/972df49733bc60f77452efcbb7ba087fcfe89b432f7d0c9fad68544d56ef/grpcio-1.44.0-cp310-cp310-macosx_10_10_universal2.whl" + }, + { + "algorithm": "sha256", + "hash": "4bae1c99896045d3062ab95478411c8d5a52cb84b91a1517312629fa6cfeb50e", + "url": "https://files.pythonhosted.org/packages/65/75/8b706e1170e2c7b6242b1675259e47986bb4fc490f29387989a965972e6e/grpcio-1.44.0.tar.gz" + }, + { + "algorithm": "sha256", + "hash": "d1e22d3a510438b7f3365c0071b810672d09febac6e8ca8a47eab657ae5f347b", + "url": "https://files.pythonhosted.org/packages/7c/d4/5c55043df5f05fc780359fc5fe9e77f0b0709d01839ec4bd37438ec68e56/grpcio-1.44.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "41036a574cab3468f24d41d6ed2b52588fb85ed60f8feaa925d7e424a250740b", + "url": "https://files.pythonhosted.org/packages/83/55/f1754f8bd4024d078ae87ebb381bea7be7fceb306d23536c62b3a66cd23e/grpcio-1.44.0-cp310-cp310-win32.whl" + }, + { + "algorithm": "sha256", + "hash": 
"5f3c54ebb5d9633a557335c01d88d3d4928e9b1b131692283b6184da1edbec0b", + "url": "https://files.pythonhosted.org/packages/b5/d8/3f8c773c1a26caa50a441b49324fcbda007c4d0e839b1037266217fb6526/grpcio-1.44.0-cp310-cp310-manylinux_2_17_aarch64.whl" + }, + { + "algorithm": "sha256", + "hash": "3d47553b8e86ab1e59b0185ba6491a187f94a0239f414c8fc867a22b0405b798", + "url": "https://files.pythonhosted.org/packages/e3/f8/9d335b4cd261771752332cfdb9452b31fa0c81dc9a7b3915611c7290032d/grpcio-1.44.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl" + } + ], + "project_name": "grpcio", + "requires_dists": [ + "enum34>=1.0.4; python_version < \"3.4\"", + "futures>=2.2.0; python_version < \"3.2\"", + "grpcio-tools>=1.44.0; extra == \"protobuf\"", + "six>=1.5.2" + ], + "requires_python": ">=3.6", + "version": "1.44" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "3e16260dfe6e997330473863e01466b0992369ae2337a0249b390b4651cff424", + "url": "https://files.pythonhosted.org/packages/ff/92/1d3071b330af8405b0df7945eb4162546a0ef9a7c8fd1a68baba53d994b3/grpcio_tools-1.44.0-cp310-cp310-win_amd64.whl" + }, + { + "algorithm": "sha256", + "hash": "65c2fe3cdc5425180f01dd303e28d4f363d38f4c2e3a7e1a87caedd5417e23bb", + "url": "https://files.pythonhosted.org/packages/38/45/e0d068847b0f2a0b2921ef557c055d9c24a08860701e80d07a677a2b1221/grpcio_tools-1.44.0-cp310-cp310-manylinux_2_17_aarch64.whl" + }, + { + "algorithm": "sha256", + "hash": "121c9765cee8636201cf0d4e80bc7b509813194919bccdb66e9671c4ece6dac3", + "url": "https://files.pythonhosted.org/packages/3c/08/0461a640d172067e4e502458a22168cf67ebf492b5ca98d0fa4a5c23fc35/grpcio_tools-1.44.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "be37f458ea510c9a8f1caabbc2b258d12e55d189a567f5edcace90f27dc0efbf", + "url": "https://files.pythonhosted.org/packages/53/3e/bdb69af20f03ce1ad54a65625302b137d3a040958f214cee5efa3ca0b0c4/grpcio-tools-1.44.0.tar.gz" + }, + { + "algorithm": "sha256", + "hash": "9f58529e24f613019a85c258a274d441d89e0cad8cf7fca21ef3807ba5840c5d", + "url": "https://files.pythonhosted.org/packages/a7/fe/f5505fba18ac163e102f4e27fafe7303adedd2a193e618d39392d404cf9a/grpcio_tools-1.44.0-cp310-cp310-linux_armv7l.whl" + }, + { + "algorithm": "sha256", + "hash": "5caef118deb8cdee1978fd3d8e388a9b256cd8d34e4a8895731ac0e86fa5e47c", + "url": "https://files.pythonhosted.org/packages/c7/41/da1269cc98dfac2eb7c5f3ac7438ea5a603d49f5a0f99b2ecf159ca25f8b/grpcio_tools-1.44.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl" + }, + { + "algorithm": "sha256", + "hash": "90d1fac188bac838c4169eb3b67197887fa0572ea8a90519a20cddb080800549", + "url": "https://files.pythonhosted.org/packages/d6/42/40375915cf7396b1d68fb87357521984537459902de340a0bfb0c8bfbd73/grpcio_tools-1.44.0-cp310-cp310-win32.whl" + }, + { + "algorithm": "sha256", + "hash": "1d120082236f8d2877f8a19366476b82c3562423b877b7c471a142432e31c2c4", + "url": "https://files.pythonhosted.org/packages/ec/eb/d962c9aa400ef53448152db006f3cf5f5db9d528fbe1ecafd55b7c37ead7/grpcio_tools-1.44.0-cp310-cp310-macosx_10_10_universal2.whl" + } + ], + "project_name": "grpcio-tools", + "requires_dists": [ + "grpcio>=1.44.0", + "protobuf<4.0dev,>=3.5.0.post1", + "setuptools" + ], + "requires_python": ">=3.6", + "version": "1.44" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "81d6d8e39695f2c37954d1011c0480ef7cf444d4e3ae24bc5e89ee5de360139a", + "url": 
"https://files.pythonhosted.org/packages/0c/39/eae11344d69ba435ec13d6bcc1a9eea3d2278324506fcd0e52d1ed8958c8/hiredis-2.0.0.tar.gz" + } + ], + "project_name": "hiredis", + "requires_dists": [], + "requires_python": ">=3.6", + "version": "2" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "1697e1a8a8f550fd43c2865cd84542fc175a61dcb779b6fee18cf6b6ccba1477", + "url": "https://files.pythonhosted.org/packages/f0/0f/310fb31e39e2d734ccaa2c0fb981ee41f7bd5056ce9bc29b2248bd569169/humanfriendly-10.0-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "6b0b831ce8f15f7300721aa49829fc4e83921a9a301cc7f606be6686a2288ddc", + "url": "https://files.pythonhosted.org/packages/cc/3f/2c29224acb2e2df4d2046e4c73ee2662023c58ff5b113c4c1adac0886c43/humanfriendly-10.0.tar.gz" + } + ], + "project_name": "humanfriendly", + "requires_dists": [ + "monotonic; python_version == \"2.7\"", + "pyreadline3; sys_platform == \"win32\" and python_version >= \"3.8\"", + "pyreadline; sys_platform == \"win32\" and python_version < \"3.8\"" + ], + "requires_python": "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7", + "version": "10" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "953b393f5bd67e19d47a4c0fd20c3a3537853967b307e49729c4755d3551753c", + "url": "https://files.pythonhosted.org/packages/00/76/b31b3c65463c990af4e63d3dafb149e34bbb4902b01670d4cb153bf1cbc2/humanize-4.1.0-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "3a119b242ec872c029d8b7bf8435a61a5798f124b244a08013aec5617302f80e", + "url": "https://files.pythonhosted.org/packages/bb/68/c8be852a42c3b0364ad256a8cb41ab619d445b812aa16f94c9d16b042d74/humanize-4.1.0.tar.gz" + } + ], + "project_name": "humanize", + "requires_dists": [ + "freezegun; extra == \"tests\"", + "importlib-metadata; python_version < \"3.8\"", + "pytest-cov; extra == \"tests\"", + "pytest; extra == \"tests\"" + ], + "requires_python": ">=3.7", + "version": "4.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "84d9dd047ffa80596e0f246e2eab0b391788b0503584e8945f2368256d2735ff", + "url": "https://files.pythonhosted.org/packages/04/a2/d918dcd22354d8958fe113e1a3630137e0fc8b44859ade3063982eacd2a4/idna-3.3-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "9d643ff0a55b762d5cdb124b8eaa99c66322e2157b69160bc32796e824360e6d", + "url": "https://files.pythonhosted.org/packages/62/08/e3fc7c8161090f742f504f40b1bccbfc544d4a4e09eb774bf40aafce5436/idna-3.3.tar.gz" + } + ], + "project_name": "idna", + "requires_dists": [], + "requires_python": ">=3.5", + "version": "3.3" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3", + "url": "https://files.pythonhosted.org/packages/9b/dd/b3c12c6d707058fa947864b67f0c4e0c39ef8610988d7baea9578f3c48f3/iniconfig-1.1.1-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32", + "url": "https://files.pythonhosted.org/packages/23/a2/97899f6bd0e873fed3a7e67ae8d3a08b21799430fb4da15cfedf10d6e2c2/iniconfig-1.1.1.tar.gz" + } + ], + "project_name": "iniconfig", + "requires_dists": [], + "requires_python": null, + "version": "1.1.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "2596ea5482711c1ee3ef2df6c290aaf370a13c55a007826e8f7c32d696d1d00a", + "url": "https://files.pythonhosted.org/packages/c1/84/7bfe436fa6a4943eecb17c2cca9c84215299684575376d664ea6bf294439/janus-1.0.0-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": 
"df976f2cdcfb034b147a2d51edfc34ff6bfb12d4e2643d3ad0e10de058cb1612", + "url": "https://files.pythonhosted.org/packages/b8/a8/facab7275d7d3d2032f375843fe46fad1cfa604a108b5a238638d4615bdc/janus-1.0.0.tar.gz" + } + ], + "project_name": "janus", + "requires_dists": [ + "typing-extensions>=3.7.4.3" + ], + "requires_python": ">=3.7", + "version": "1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "077ce6014f7b40d03b47d1f1ca4b0fc8328a692bd284016f806ed0eaca390ad8", + "url": "https://files.pythonhosted.org/packages/20/9a/e5d9ec41927401e41aea8af6d16e78b5e612bca4699d417f646a9610a076/Jinja2-3.0.3-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "611bb273cd68f3b993fabdc4064fc858c5b47a973cb5aa7999ec1ba405c87cd7", + "url": "https://files.pythonhosted.org/packages/91/a5/429efc6246119e1e3fbf562c00187d04e83e54619249eb732bb423efa6c6/Jinja2-3.0.3.tar.gz" + } + ], + "project_name": "jinja2", + "requires_dists": [ + "Babel>=2.7; extra == \"i18n\"", + "MarkupSafe>=2.0" + ], + "requires_python": ">=3.6", + "version": "3.0.3" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "a6dee02a1b39ea4bb9c4c2cc415ea0ada33d8ea0a920f7d4fb6d166989dcac01", + "url": "https://files.pythonhosted.org/packages/6e/fc/2cab119f679648b348b8940de0dd744a1f0ee99c690aa2ef6072f050816c/kubernetes-10.0.1-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "3770a496663396ad1def665eeadb947b3f45217a08b64b10c01a57e981ac8592", + "url": "https://files.pythonhosted.org/packages/db/4e/af5af9e1cf3d6c9d001f0fcf1a0efc29a02c078da97a5fc9d7b0d17e631e/kubernetes-10.0.1.tar.gz" + } + ], + "project_name": "kubernetes", + "requires_dists": [ + "adal>=1.0.2; extra == \"adal\"", + "certifi>=14.05.14", + "google-auth>=1.0.1", + "ipaddress>=1.0.17; python_version == \"2.7\"", + "python-dateutil>=2.5.3", + "pyyaml>=3.12", + "requests", + "requests-oauthlib", + "setuptools>=21.0.0", + "six>=1.9.0", + "urllib3>=1.24.2", + "websocket-client!=0.40.0,!=0.41.*,!=0.42.*,>=0.32.0" + ], + "requires_python": null, + "version": "10.0.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "18fc53caec93156cdf761ccf1df51394ed0cf28002ff44238db82eac1f5519a6", + "url": "https://files.pythonhosted.org/packages/8b/3a/16ab267c7d4439e0b37ff66d246058500258bf93c00c0573d864fccb03b5/kubernetes_asyncio-9.1.0.tar.gz" + } + ], + "project_name": "kubernetes-asyncio", + "requires_dists": [ + "aiohttp>=2.3.10", + "certifi>=14.05.14", + "python-dateutil>=2.5.3", + "pyyaml>=3.12", + "setuptools>=21.0.0", + "six>=1.9.0", + "urllib3>=1.23" + ], + "requires_python": null, + "version": "9.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "e29ca814a98bb0f81674617d878e5f611cb993c19ea47f22c80da3569425f9bd", + "url": "https://files.pythonhosted.org/packages/cd/55/8951788003c9f65dfcf9c51d7ba7952c6c5ae72685aa3cf9c66e925bd538/lark-parser-0.11.3.tar.gz" + } + ], + "project_name": "lark-parser", + "requires_dists": [ + "atomicwrites; extra == \"atomic_cache\"", + "js2py; extra == \"nearley\"", + "regex; extra == \"regex\"" + ], + "requires_python": null, + "version": "0.11.3" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "23aab11fdbbb0f1051b93793a58323ff937e98e34aece1c4219675122e57e4ba", + "url": "https://files.pythonhosted.org/packages/6e/01/45ab9f723a93e0ca75fba4d2c266bb041120cb4215eab94f7c78743ac7ed/Mako-1.2.0-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "9a7c7e922b87db3686210cf49d5d767033a41d4010b284e747682c92bddd8b39", + "url": 
"https://files.pythonhosted.org/packages/50/ec/1d687348f0954bda388bfd1330c158ba8d7dea4044fc160e74e080babdb9/Mako-1.2.0.tar.gz" + } + ], + "project_name": "mako", + "requires_dists": [ + "Babel; extra == \"babel\"", + "MarkupSafe>=0.9.2", + "importlib-metadata; python_version < \"3.8\"", + "lingua; extra == \"lingua\"", + "pytest; extra == \"testing\"" + ], + "requires_python": ">=3.7", + "version": "1.2" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "dda30ba7e87fbbb7eab1ec9f58678558fd9a6b8b853530e176eabd064da81417", + "url": "https://files.pythonhosted.org/packages/3d/4b/15e5b9d40c4b58e97ebcb8ed5845a215fa5b7cf49a7f1cc7908f8db9cf46/MarkupSafe-2.1.1-cp310-cp310-win_amd64.whl" + }, + { + "algorithm": "sha256", + "hash": "7f91197cc9e48f989d12e4e6fbc46495c446636dfc81b9ccf50bb0ec74b91d4b", + "url": "https://files.pythonhosted.org/packages/1d/97/2288fe498044284f39ab8950703e88abbac2abbdf65524d576157af70556/MarkupSafe-2.1.1.tar.gz" + }, + { + "algorithm": "sha256", + "hash": "694deca8d702d5db21ec83983ce0bb4b26a578e71fbdbd4fdcd387daa90e4d5e", + "url": "https://files.pythonhosted.org/packages/5c/1a/ac3a2b2a4ef1196c15dd8a143fc28eddeb6e6871d6d1de64dc44ef7f59b6/MarkupSafe-2.1.1-cp310-cp310-musllinux_1_1_i686.whl" + }, + { + "algorithm": "sha256", + "hash": "4a33dea2b688b3190ee12bd7cfa29d39c9ed176bda40bfa11099a3ce5d3a7ac6", + "url": "https://files.pythonhosted.org/packages/5e/3d/0a7df21deca52e20de81f8a895ac29df68944588c0030be9aa1e6c07877c/MarkupSafe-2.1.1-cp310-cp310-win32.whl" + }, + { + "algorithm": "sha256", + "hash": "f121a1420d4e173a5d96e47e9a0c0dcff965afdf1626d28de1460815f7c4ee7a", + "url": "https://files.pythonhosted.org/packages/8c/96/7e608e1a942232cb8c81ca24093e71e07e2bacbeb2dad62a0f82da28ed54/MarkupSafe-2.1.1-cp310-cp310-macosx_10_9_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "10c1bfff05d95783da83491be968e8fe789263689c02724e0c691933c52994f5", + "url": "https://files.pythonhosted.org/packages/9e/82/2e089c6f34e77c073aa5a67040d368aac0dfb9b8ccbb46d381452c26fc33/MarkupSafe-2.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "b7bd98b796e2b6553da7225aeb61f447f80a1ca64f41d83612e6139ca5213aa4", + "url": "https://files.pythonhosted.org/packages/a3/47/9dcc08eff8ab94f1e50f59f9cd322b710ef5db7e8590fdd8df924406fc9c/MarkupSafe-2.1.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl" + }, + { + "algorithm": "sha256", + "hash": "b09bf97215625a311f669476f44b8b318b075847b49316d3e28c08e41a7a573f", + "url": "https://files.pythonhosted.org/packages/ad/fa/292a72cddad41e3c06227b446a0af53ff642a40755fc5bd695f439c35ba8/MarkupSafe-2.1.1-cp310-cp310-musllinux_1_1_aarch64.whl" + }, + { + "algorithm": "sha256", + "hash": "86b1f75c4e7c2ac2ccdaec2b9022845dbb81880ca318bb7a0a01fbf7813e3812", + "url": "https://files.pythonhosted.org/packages/d9/60/94e9de017674f88a514804e2924bdede9a642aba179d2045214719d6ec76/MarkupSafe-2.1.1-cp310-cp310-macosx_10_9_universal2.whl" + }, + { + "algorithm": "sha256", + "hash": "efc1913fd2ca4f334418481c7e595c00aad186563bbc1ec76067848c7ca0a933", + "url": "https://files.pythonhosted.org/packages/fc/e4/78c7607352dd574d524daad079f855757d406d36b919b1864a5a07978390/MarkupSafe-2.1.1-cp310-cp310-musllinux_1_1_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "a49907dd8420c5685cfa064a1335b6754b74541bbb3706c259c02ed65b644b3e", + "url": 
"https://files.pythonhosted.org/packages/ff/3a/42262a3aa6415befee33b275b31afbcef4f7f8d2f4380061b226c692ee2a/MarkupSafe-2.1.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" + } + ], + "project_name": "markupsafe", + "requires_dists": [], + "requires_python": ">=3.7", + "version": "2.1.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "43e6dd9942dffd72661a2c4ef383ad7da1e6a3e968a927ad7a6083ab410a688b", + "url": "https://files.pythonhosted.org/packages/e5/c3/48e2c81038f57e8caab9a6e6fb6c2fc23536c59b092abefc447e6b5d1903/more_itertools-8.12.0-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "7dc6ad46f05f545f900dd59e8dfb4e84a4827b97b3cfecb175ea0c7d247f6064", + "url": "https://files.pythonhosted.org/packages/dc/b5/c216ffeace7b89b7387fe08e1b39a07c6da38ea82c60e2e630dd5883813b/more-itertools-8.12.0.tar.gz" + } + ], + "project_name": "more-itertools", + "requires_dists": [], + "requires_python": ">=3.5", + "version": "8.12" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "c1ba333b4024c17c7591f0f372e2daa3c31db495a9b2af3cf664aef3c14354f7", + "url": "https://files.pythonhosted.org/packages/26/71/5fbd40e87fabaf6f60c2fa8934d93ec1df542b7f978a080ce99f6734934d/msgpack-1.0.3-cp310-cp310-win_amd64.whl" + }, + { + "algorithm": "sha256", + "hash": "36a64a10b16c2ab31dcd5f32d9787ed41fe68ab23dd66957ca2826c7f10d0b85", + "url": "https://files.pythonhosted.org/packages/06/e5/da31b9be6bed416c29906e0f9eff66af3e08f0b6e11caa7858d649e8ca1f/msgpack-1.0.3-cp310-cp310-win32.whl" + }, + { + "algorithm": "sha256", + "hash": "1c58cdec1cb5fcea8c2f1771d7b5fec79307d056874f746690bd2bdd609ab147", + "url": "https://files.pythonhosted.org/packages/1b/18/61b7462849c31fafd7c7d05a2ae896d495a1c1bf7f25788a4a8af9439153/msgpack-1.0.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "96acc674bb9c9be63fa8b6dabc3248fdc575c4adc005c440ad02f87ca7edd079", + "url": "https://files.pythonhosted.org/packages/4f/e9/837b5c2209d41ddaf99cc7247598191d6f9f776c017b95abb5ada761ef93/msgpack-1.0.3-cp310-cp310-macosx_10_9_universal2.whl" + }, + { + "algorithm": "sha256", + "hash": "51fdc7fb93615286428ee7758cecc2f374d5ff363bdd884c7ea622a7a327a81e", + "url": "https://files.pythonhosted.org/packages/61/3c/2206f39880d38ca7ad8ac1b28d2d5ca81632d163b2d68ef90e46409ca057/msgpack-1.0.3.tar.gz" + }, + { + "algorithm": "sha256", + "hash": "2f97c0f35b3b096a330bb4a1a9247d0bd7e1f3a2eba7ab69795501504b1c2c39", + "url": "https://files.pythonhosted.org/packages/a5/36/3734c798885a93c6e8fe4422184ad089c6e2e44c18d2b18f09cc029c02b8/msgpack-1.0.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl" + }, + { + "algorithm": "sha256", + "hash": "2c3ca57c96c8e69c1a0d2926a6acf2d9a522b41dc4253a8945c4c6cd4981a4e3", + "url": "https://files.pythonhosted.org/packages/b9/f4/4d2ee26409739c1a4b1dc3b8e4c50dedd9d5054d1ab5fec9830c42ebb3b6/msgpack-1.0.3-cp310-cp310-macosx_10_9_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "b0a792c091bac433dfe0a70ac17fc2087d4595ab835b47b89defc8bbabcf5c73", + "url": "https://files.pythonhosted.org/packages/bd/c5/e69b0e5f216191b09261957a75a78078aa2bc90a7138e5186eb641e45d9f/msgpack-1.0.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" + } + ], + "project_name": "msgpack", + "requires_dists": [], + "requires_python": null, + "version": "1.0.3" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "8cbf0132f3de7cc6c6ce00147cc78e6439ea736cee6bca4f068bcf892b0fd658", + "url": 
"https://files.pythonhosted.org/packages/72/e4/9ea1c573503ddf11ea56c48e9af49660fbd45a13ceb394a48e437c32eba9/multidict-6.0.2-cp310-cp310-win_amd64.whl" + }, + { + "algorithm": "sha256", + "hash": "041b81a5f6b38244b34dc18c7b6aba91f9cdaf854d9a39e5ff0b58e2b5773b9c", + "url": "https://files.pythonhosted.org/packages/14/7b/d11a6dec8996ca054e727f7d3b1578753b44ba9e378c9449404aef076b47/multidict-6.0.2-cp310-cp310-macosx_11_0_arm64.whl" + }, + { + "algorithm": "sha256", + "hash": "0b9e95a740109c6047602f4db4da9949e6c5945cefbad34a1299775ddc9a62e2", + "url": "https://files.pythonhosted.org/packages/1d/35/0ea9ce0cc0aeb3b4c898595d807ac80ebbd295efefabc80c4f6c6bee8106/multidict-6.0.2-cp310-cp310-macosx_10_9_universal2.whl" + }, + { + "algorithm": "sha256", + "hash": "fcb91630817aa8b9bc4a74023e4198480587269c272c58b3279875ed7235c293", + "url": "https://files.pythonhosted.org/packages/23/31/c8736506ae534e20c8f0b1b090bc2ad89349d96e5e7c5928464c6c876599/multidict-6.0.2-cp310-cp310-win32.whl" + }, + { + "algorithm": "sha256", + "hash": "50bd442726e288e884f7be9071016c15a8742eb689a593a0cac49ea093eef0a7", + "url": "https://files.pythonhosted.org/packages/2a/c2/0f63e839b93a68dd2bcfbf30cc35dbdb4b172ad0078e32176628ec7d91d5/multidict-6.0.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl" + }, + { + "algorithm": "sha256", + "hash": "0556a1d4ea2d949efe5fd76a09b4a82e3a4a30700553a6725535098d8d9fb672", + "url": "https://files.pythonhosted.org/packages/31/b1/eb1a8cdb3bb177929dfee9543c0fd8074768c9e4431c7b3da7d01a3c66d8/multidict-6.0.2-cp310-cp310-musllinux_1_1_i686.whl" + }, + { + "algorithm": "sha256", + "hash": "8064b7c6f0af936a741ea1efd18690bacfbae4078c0c385d7c3f611d11f0cf87", + "url": "https://files.pythonhosted.org/packages/3f/44/83e4bd573cc80c41896394129f162b69fe1ed9fd7a99ca4153740e20349c/multidict-6.0.2-cp310-cp310-musllinux_1_1_s390x.whl" + }, + { + "algorithm": "sha256", + "hash": "2d36e929d7f6a16d4eb11b250719c39560dd70545356365b494249e2186bc389", + "url": "https://files.pythonhosted.org/packages/69/d7/c49e9ca438846658191905f5df53a895738b478cdca98580f092b557802c/multidict-6.0.2-cp310-cp310-musllinux_1_1_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "f4f052ee022928d34fe1f4d2bc743f32609fb79ed9c49a1710a5ad6b2198db20", + "url": "https://files.pythonhosted.org/packages/7e/21/73f8a51219fd9b4b04badcc7933ce5f5344ab33308492755220524bc4faf/multidict-6.0.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl" + }, + { + "algorithm": "sha256", + "hash": "47e6a7e923e9cada7c139531feac59448f1f47727a79076c0b1ee80274cd8eee", + "url": "https://files.pythonhosted.org/packages/9b/a4/a8d3c6bb884d97fd1e9d37c5c9a8c46de799d7465e455b617f33dfbb52ba/multidict-6.0.2-cp310-cp310-musllinux_1_1_aarch64.whl" + }, + { + "algorithm": "sha256", + "hash": "3368bf2398b0e0fcbf46d85795adc4c259299fec50c1416d0f77c0a843a3eed9", + "url": "https://files.pythonhosted.org/packages/bf/b9/b8c9845853b7086476201ff18bcff5a169e945c5d8397e234ba4453a38d4/multidict-6.0.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl" + }, + { + "algorithm": "sha256", + "hash": "626fe10ac87851f4cffecee161fc6f8f9853f0f6f1035b59337a51d29ff3b4f9", + "url": "https://files.pythonhosted.org/packages/ce/b3/7b2ed0a1fca198da0e6354ccd0358757c12b56f204c179271cf81a7372ae/multidict-6.0.2-cp310-cp310-musllinux_1_1_ppc64le.whl" + }, + { + "algorithm": "sha256", + "hash": "ac0e27844758d7177989ce406acc6a83c16ed4524ebc363c1f748cba184d89d3", + "url": 
"https://files.pythonhosted.org/packages/d2/67/ef1ef8f3539642d90c77bc7c86cc7283297cd2ab100b45d7541476ef641e/multidict-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "225383a6603c086e6cef0f2f05564acb4f4d5f019a4e3e983f572b8530f70c88", + "url": "https://files.pythonhosted.org/packages/df/93/34efbfa7aa778b04b365960f52f7071d7942ce386572aac8940ae032dd48/multidict-6.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "5fdda29a3c7e76a064f2477c9aab1ba96fd94e02e386f1e665bca1807fc5386f", + "url": "https://files.pythonhosted.org/packages/ee/a1/a7cc44b7ed84e430c2c176420ffa432a74a2432f7df4f71988365fa8772a/multidict-6.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" + }, + { + "algorithm": "sha256", + "hash": "5ff3bd75f38e4c43f1f470f2df7a4d430b821c4ce22be384e1459cb57d6bb013", + "url": "https://files.pythonhosted.org/packages/fa/a7/71c253cdb8a1528802bac7503bf82fe674367e4055b09c28846fdfa4ab90/multidict-6.0.2.tar.gz" + } + ], + "project_name": "multidict", + "requires_dists": [], + "requires_python": ">=3.7", + "version": "6.0.2" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "043a79146eb2907edf439899f262b3dfe41717d34124298ed281139a8b93ca32", + "url": "https://files.pythonhosted.org/packages/a6/91/86a6eac449ddfae239e93ffc1918cf33fd9bab35c04d1e963b311e347a73/netifaces-0.11.0.tar.gz" + } + ], + "project_name": "netifaces", + "requires_dists": [], + "requires_python": null, + "version": "0.11" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "6db33440354787f9b7f3a6dbd4febf5d0f93758354060e802f6c06cb493022fe", + "url": "https://files.pythonhosted.org/packages/1d/46/5ee2475e1b46a26ca0fa10d3c1d479577fde6ee289f8c6aa6d7ec33e31fd/oauthlib-3.2.0-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "23a8208d75b902797ea29fd31fa80a15ed9dc2c6c16fe73f5d346f83f6fa27a2", + "url": "https://files.pythonhosted.org/packages/6e/7e/a43cec8b2df28b6494a865324f0ac4be213cb2edcf1e2a717547a93279b0/oauthlib-3.2.0.tar.gz" + } + ], + "project_name": "oauthlib", + "requires_dists": [ + "blinker>=1.4.0; extra == \"signals\"", + "cryptography>=3.0.0; extra == \"rsa\"", + "cryptography>=3.0.0; extra == \"signedtoken\"", + "pyjwt<3,>=2.0.0; extra == \"signedtoken\"" + ], + "requires_python": ">=3.6", + "version": "3.2" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "ef103e05f519cdc783ae24ea4e2e0f508a9c99b2d4969652eed6a2e1ea5bd522", + "url": "https://files.pythonhosted.org/packages/05/8e/8de486cbd03baba4deef4142bd643a3e7bbe954a784dc1bb17142572d127/packaging-21.3-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb", + "url": "https://files.pythonhosted.org/packages/df/9e/d1a7217f69310c1db8fdf8ab396229f55a699ce34a203691794c5d1cad0c/packaging-21.3.tar.gz" + } + ], + "project_name": "packaging", + "requires_dists": [ + "pyparsing!=3.0.5,>=2.0.2" + ], + "requires_python": ">=3.6", + "version": "21.3" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "aa6bca462b8d8bda89c70b382f0c298a20b5560af6cbfa2dce410c0a2fb669f1", + "url": "https://files.pythonhosted.org/packages/3b/a4/ab6b7589382ca3df236e03faa71deac88cae040af60c071a78d254a62172/passlib-1.7.4-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "defd50f72b65c5402ab2c573830a6978e5f202ad0d984793c8dde2c4152ebe04", + "url": 
"https://files.pythonhosted.org/packages/b6/06/9da9ee59a67fae7761aab3ccc84fa4f3f33f125b370f1ccdb915bf967c11/passlib-1.7.4.tar.gz" + } + ], + "project_name": "passlib", + "requires_dists": [ + "argon2-cffi>=18.2.0; extra == \"argon2\"", + "bcrypt>=3.1.0; extra == \"bcrypt\"", + "cloud-sptheme>=1.10.1; extra == \"build_docs\"", + "cryptography; extra == \"totp\"", + "sphinx>=1.6; extra == \"build_docs\"", + "sphinxcontrib-fulltoc>=1.2.0; extra == \"build_docs\"" + ], + "requires_python": null, + "version": "1.7.4" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "0b48a55dcb3c05f3329815901ea4fc1537514d6ba867a152b581d69ae3710937", + "url": "https://files.pythonhosted.org/packages/39/7b/88dbb785881c28a102619d46423cb853b46dbccc70d3ac362d99773a78ce/pexpect-4.8.0-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "fc65a43959d153d0114afe13997d439c22823a27cefceb5ff35c2178c6784c0c", + "url": "https://files.pythonhosted.org/packages/e5/9b/ff402e0e930e70467a7178abb7c128709a30dfb22d8777c043e501bc1b10/pexpect-4.8.0.tar.gz" + } + ], + "project_name": "pexpect", + "requires_dists": [ + "ptyprocess>=0.5" + ], + "requires_python": null, + "version": "4.8" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3", + "url": "https://files.pythonhosted.org/packages/9e/01/f38e2ff29715251cf25532b9082a1589ab7e4f571ced434f98d0139336dc/pluggy-1.0.0-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159", + "url": "https://files.pythonhosted.org/packages/a1/16/db2d7de3474b6e37cbb9c008965ee63835bba517e22cdb8c35b5116b5ce1/pluggy-1.0.0.tar.gz" + } + ], + "project_name": "pluggy", + "requires_dists": [ + "importlib-metadata>=0.12; python_version < \"3.8\"", + "pre-commit; extra == \"dev\"", + "pytest-benchmark; extra == \"testing\"", + "pytest; extra == \"testing\"", + "tox; extra == \"dev\"" + ], + "requires_python": ">=3.6", + "version": "1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "dfd18337c523ba4b6a58801c164c1904a9d4d1b1747c7d5dbf45b693a49d93d0", + "url": "https://files.pythonhosted.org/packages/cf/9c/fb5d48abfe5d791cd496e4242ebcf87a4bb2e0c3dcd6e0ae68c11426a528/promise-2.3.tar.gz" + } + ], + "project_name": "promise", + "requires_dists": [ + "coveralls; extra == \"test\"", + "futures; extra == \"test\"", + "mock; extra == \"test\"", + "pytest-asyncio; extra == \"test\"", + "pytest-benchmark; extra == \"test\"", + "pytest-cov; extra == \"test\"", + "pytest>=2.7.3; extra == \"test\"", + "six", + "typing>=3.6.4; python_version < \"3.5\"" + ], + "requires_python": null, + "version": "2.3" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "adfc6cf69c7f8c50fd24c793964eef18f0ac321315439d94945820612849c388", + "url": "https://files.pythonhosted.org/packages/ef/c8/2e7f7feaf804b7206e6cc8fa3f0f49834a78f7cb127813d2c45e42d5f7bf/protobuf-3.20.1-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "adc31566d027f45efe3f44eeb5b1f329da43891634d61c75a5944e9be6dd42c9", + "url": "https://files.pythonhosted.org/packages/19/96/1283259c25bc48a6df98fa096f66fc568b40137b93806ef5ff66a2d166b1/protobuf-3.20.1.tar.gz" + }, + { + "algorithm": "sha256", + "hash": "cd68be2559e2a3b84f517fb029ee611546f7812b1fdd0aa2ecc9bc6ec0e4fdde", + "url": "https://files.pythonhosted.org/packages/4c/be/bdd22d86d24e5b8b08673d80be70d1a72c255f85152ff09b28490904092a/protobuf-3.20.1-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl" 
+ }, + { + "algorithm": "sha256", + "hash": "ff8d8fa42675249bb456f5db06c00de6c2f4c27a065955917b28c4f15978b9c3", + "url": "https://files.pythonhosted.org/packages/70/75/df318e565cf126a9464b9220ef6adfecb44fb7c68df140bc5680d0ed05c3/protobuf-3.20.1-cp310-cp310-manylinux2014_aarch64.whl" + }, + { + "algorithm": "sha256", + "hash": "9016d01c91e8e625141d24ec1b20fed584703e527d28512aa8c8707f105a683c", + "url": "https://files.pythonhosted.org/packages/89/1a/b4d72e1d7134ffac2156d1dfc3b9ddb21d1664ff392e1e5fe2882a117f81/protobuf-3.20.1-cp310-cp310-win32.whl" + }, + { + "algorithm": "sha256", + "hash": "3cc797c9d15d7689ed507b165cd05913acb992d78b379f6014e013f9ecb20996", + "url": "https://files.pythonhosted.org/packages/bd/ca/0d522203bedd17a8c53cb869e1dfd7ac9140c66b76b3cbca25bf601448b2/protobuf-3.20.1-cp310-cp310-macosx_10_9_universal2.whl" + }, + { + "algorithm": "sha256", + "hash": "32ca378605b41fd180dfe4e14d3226386d8d1b002ab31c969c366549e66a2bb7", + "url": "https://files.pythonhosted.org/packages/c0/9c/bb88091287418ae1cf8af2bb9ed9710748a562b9abc227e4884d687a8650/protobuf-3.20.1-cp310-cp310-win_amd64.whl" + } + ], + "project_name": "protobuf", + "requires_dists": [], + "requires_python": ">=3.7", + "version": "3.20.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "0c9ccb99ab76025f2f0bbecf341d4656e9c1351db8cc8a03ccd62e318ab4b5c6", + "url": "https://files.pythonhosted.org/packages/e1/b0/7276de53321c12981717490516b7e612364f2cb372ee8901bd4a66a000d7/psutil-5.8.0.tar.gz" + } + ], + "project_name": "psutil", + "requires_dists": [ + "enum34; python_version <= \"3.4\" and extra == \"test\"", + "ipaddress; python_version < \"3.0\" and extra == \"test\"", + "mock; python_version < \"3.0\" and extra == \"test\"", + "pywin32; sys_platform == \"win32\" and extra == \"test\"", + "unittest2; python_version < \"3.0\" and extra == \"test\"", + "wmi; sys_platform == \"win32\" and extra == \"test\"" + ], + "requires_python": "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.6", + "version": "5.8" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "8b344adbb9a862de0c635f4f0425b7958bf5a4b927c8594e6e8d261775796d53", + "url": "https://files.pythonhosted.org/packages/44/4a/6b17a2907d1fd0e891f61784d916c39945a85d855badd609c87bc2d9021e/psycopg2_binary-2.9.3-cp310-cp310-win_amd64.whl" + }, + { + "algorithm": "sha256", + "hash": "a9e1f75f96ea388fbcef36c70640c4efbe4650658f3d6a2967b4cc70e907352e", + "url": "https://files.pythonhosted.org/packages/23/db/2383e85ceff06a2279001c027bb75406baf53d94a75c7648cb5d3b2a23d1/psycopg2_binary-2.9.3-cp310-cp310-musllinux_1_1_ppc64le.whl" + }, + { + "algorithm": "sha256", + "hash": "6e82d38390a03da28c7985b394ec3f56873174e2c88130e6966cb1c946508e65", + "url": "https://files.pythonhosted.org/packages/2b/20/2f1fc936f8ee4828b348aba3efacab2731995b21da0a955a25a398c0b57b/psycopg2_binary-2.9.3-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl" + }, + { + "algorithm": "sha256", + "hash": "090f3348c0ab2cceb6dfbe6bf721ef61262ddf518cd6cc6ecc7d334996d64efa", + "url": "https://files.pythonhosted.org/packages/46/9f/536f052c80e71d37edcb8902bef319c1f8d6e7f4ba49d4a999b6cea87589/psycopg2_binary-2.9.3-cp310-cp310-musllinux_1_1_i686.whl" + }, + { + "algorithm": "sha256", + "hash": "3a79d622f5206d695d7824cbf609a4f5b88ea6d6dab5f7c147fc6d333a8787e4", + "url": "https://files.pythonhosted.org/packages/62/cf/9e510ea668a22be01e3ba25deff03b9fa9a34598e6eabc714e67e868e0fb/psycopg2_binary-2.9.3-cp310-cp310-musllinux_1_1_aarch64.whl" + }, + { + "algorithm": "sha256", + "hash": 
"083a55275f09a62b8ca4902dd11f4b33075b743cf0d360419e2051a8a5d5ff76", + "url": "https://files.pythonhosted.org/packages/7f/07/71dd915057d7ce28bc0167d3dff17166a821913085556df3fedf7968897d/psycopg2_binary-2.9.3-cp310-cp310-manylinux_2_24_aarch64.whl" + }, + { + "algorithm": "sha256", + "hash": "539b28661b71da7c0e428692438efbcd048ca21ea81af618d845e06ebfd29478", + "url": "https://files.pythonhosted.org/packages/83/8f/748aa34614899181c5e420850281c18efec93f260a013d568020b38320e3/psycopg2_binary-2.9.3-cp310-cp310-macosx_10_14_x86_64.macosx_10_9_intel.macosx_10_9_x86_64.macosx_10_10_intel.macosx_10_10_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "0a29729145aaaf1ad8bafe663131890e2111f13416b60e460dae0a96af5905c9", + "url": "https://files.pythonhosted.org/packages/97/87/a73b2f93009bf66fc9b5a9aa1b8bdf94e462657bcb0a99c259d88683d217/psycopg2_binary-2.9.3-cp310-cp310-manylinux_2_24_ppc64le.whl" + }, + { + "algorithm": "sha256", + "hash": "57804fc02ca3ce0dbfbef35c4b3a4a774da66d66ea20f4bda601294ad2ea6092", + "url": "https://files.pythonhosted.org/packages/ac/84/d01b8a9aebeae783b84f8ee09d07ee861da2f8e260772ef7f3878549bf17/psycopg2_binary-2.9.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "c3ae8e75eb7160851e59adc77b3a19a976e50622e44fd4fd47b8b18208189d42", + "url": "https://files.pythonhosted.org/packages/cd/7f/05c6036e6482b7cddf3a12904344655defd7fc16b008d7f17f28d25d2f2a/psycopg2_binary-2.9.3-cp310-cp310-musllinux_1_1_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "761df5313dc15da1502b21453642d7599d26be88bff659382f8f9747c7ebea4e", + "url": "https://files.pythonhosted.org/packages/d7/1c/8d042630c5ff3c3e6d93c992bd7ecf516d577803b96781c6caa649bbf6e5/psycopg2-binary-2.9.3.tar.gz" + }, + { + "algorithm": "sha256", + "hash": "7b1e9b80afca7b7a386ef087db614faebbf8839b7f4db5eb107d0f1a53225029", + "url": "https://files.pythonhosted.org/packages/f6/56/4c1186774f1dd75b1492e3fabc8b5c57d213ebc412e7d38de2813918bad4/psycopg2_binary-2.9.3-cp310-cp310-win32.whl" + } + ], + "project_name": "psycopg2-binary", + "requires_dists": [], + "requires_python": ">=3.6", + "version": "2.9.3" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35", + "url": "https://files.pythonhosted.org/packages/22/a6/858897256d0deac81a172289110f31629fc4cee19b6f01283303e18c8db3/ptyprocess-0.7.0-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "5c5d0a3b48ceee0b48485e0c26037c0acd7d29765ca3fbb5cb3831d347423220", + "url": "https://files.pythonhosted.org/packages/20/e5/16ff212c1e452235a90aeb09066144d0c5a6a8c0834397e03f5224495c4e/ptyprocess-0.7.0.tar.gz" + } + ], + "project_name": "ptyprocess", + "requires_dists": [], + "requires_python": null, + "version": "0.7" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378", + "url": "https://files.pythonhosted.org/packages/f6/f0/10642828a8dfb741e5f3fbaac830550a518a775c7fff6f04a007259b0548/py-1.11.0-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719", + "url": "https://files.pythonhosted.org/packages/98/ff/fec109ceb715d2a6b4c4a85a61af3b40c723a961e8828319fbcb15b868dc/py-1.11.0.tar.gz" + } + ], + "project_name": "py", + "requires_dists": [], + "requires_python": "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7", + "version": "1.11" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": 
"39c7e2ec30515947ff4e87fb6f456dfc6e84857d34be479c9d4a4ba4bf46aa5d", + "url": "https://files.pythonhosted.org/packages/62/1e/a94a8d635fa3ce4cfc7f506003548d0a2447ae76fd5ca53932970fe3053f/pyasn1-0.4.8-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "aef77c9fb94a3ac588e87841208bdec464471d9871bd5050a287cc9a475cd0ba", + "url": "https://files.pythonhosted.org/packages/a4/db/fffec68299e6d7bad3d504147f9094830b704527a7fc098b721d38cc7fa7/pyasn1-0.4.8.tar.gz" + } + ], + "project_name": "pyasn1", + "requires_dists": [], + "requires_python": null, + "version": "0.4.8" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "a50b808ffeb97cb3601dd25981f6b016cbb3d31fbf57a8b8a87428e6158d0c74", + "url": "https://files.pythonhosted.org/packages/95/de/214830a981892a3e286c3794f41ae67a4495df1108c3da8a9f62159b9a9d/pyasn1_modules-0.2.8-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "905f84c712230b2c592c19470d3ca8d552de726050d1d1716282a1f6146be65e", + "url": "https://files.pythonhosted.org/packages/88/87/72eb9ccf8a58021c542de2588a867dbefc7556e14b2866d1e40e9e2b587e/pyasn1-modules-0.2.8.tar.gz" + } + ], + "project_name": "pyasn1-modules", + "requires_dists": [ + "pyasn1<0.5.0,>=0.4.6" + ], + "requires_python": null, + "version": "0.2.8" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "27a6f09dbfb69bb79609724c0f90dfaa7c215876a7cd9f12d585574d1f922112", + "url": "https://files.pythonhosted.org/packages/b7/ec/9f7f76cd5897643d84064bdf4ae117d2157ab902ab1f404860558df577a5/pycares-4.1.2-cp310-cp310-win_amd64.whl" + }, + { + "algorithm": "sha256", + "hash": "ad7b28e1b6bc68edd3d678373fa3af84e39d287090434f25055d21b4716b2fc6", + "url": "https://files.pythonhosted.org/packages/0c/4e/8b68387b0ab0285d3bce1aa95f23e53d4d26d49743994f3bb4668f82c556/pycares-4.1.2-cp310-cp310-win32.whl" + }, + { + "algorithm": "sha256", + "hash": "6831e963a910b0a8cbdd2750ffcdf5f2bb0edb3f53ca69ff18484de2cc3807c4", + "url": "https://files.pythonhosted.org/packages/2a/96/1060c9aa3b6e88760c84806be30ae7edcf981c9325fd6998eea3e48045d8/pycares-4.1.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "71b99b9e041ae3356b859822c511f286f84c8889ec9ed1fbf6ac30fb4da13e4c", + "url": "https://files.pythonhosted.org/packages/30/cd/4b0288a95d733ac2ff099fa2a9b27213e0eb9f102d983f63436681df5093/pycares-4.1.2-cp310-cp310-macosx_10_9_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "03490be0e7b51a0c8073f877bec347eff31003f64f57d9518d419d9369452837", + "url": "https://files.pythonhosted.org/packages/83/61/17bd0cfb9c4dc8c3738484d604b50d47c78fe4fcfe0ca2c58a61a106f578/pycares-4.1.2.tar.gz" + }, + { + "algorithm": "sha256", + "hash": "b0e50ddc78252f2e2b6b5f2c73e5b2449dfb6bea7a5a0e21dfd1e2bcc9e17382", + "url": "https://files.pythonhosted.org/packages/ae/b0/36c14737c61825279991700eb9b6c51a443868ac304766ca129c31b4817e/pycares-4.1.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl" + }, + { + "algorithm": "sha256", + "hash": "c000942f5fc64e6e046aa61aa53b629b576ba11607d108909727c3c8f211a157", + "url": "https://files.pythonhosted.org/packages/db/bb/01d52575901ba0fb2b2af6990cab48aa10a6298cbcb5400d19398b42e11b/pycares-4.1.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" + } + ], + "project_name": "pycares", + "requires_dists": [ + "cffi>=1.5.0", + "idna>=2.1; extra == \"idna\"" + ], + "requires_python": null, + "version": "4.1.2" + }, + { + "artifacts": [ + { + "algorithm": 
"sha256", + "hash": "8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9", + "url": "https://files.pythonhosted.org/packages/62/d5/5f610ebe421e85889f2e55e33b7f9a6795bd982198517d912eb1c76e1a53/pycparser-2.21-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206", + "url": "https://files.pythonhosted.org/packages/5e/0b/95d387f5f4433cb0f53ff7ad859bd2c6051051cebbb564f139a999ab46de/pycparser-2.21.tar.gz" + } + ], + "project_name": "pycparser", + "requires_dists": [], + "requires_python": "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7", + "version": "2.21" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "dc9c10fb40944260f6ed4c688ece0cd2048414940f1cea51b8b226318411c519", + "url": "https://files.pythonhosted.org/packages/5c/8e/1d9017950034297fffa336c72e693a5b51bbf85141b24a763882cf1977b5/Pygments-2.12.0-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "5eb116118f9612ff1ee89ac96437bb6b49e8f04d8a13b514ba26f620208e26eb", + "url": "https://files.pythonhosted.org/packages/59/0f/eb10576eb73b5857bc22610cdfc59e424ced4004fe7132c8f2af2cc168d3/Pygments-2.12.0.tar.gz" + } + ], + "project_name": "pygments", + "requires_dists": [], + "requires_python": ">=3.6", + "version": "2.12" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "72d1d253f32dbd4f5c88eaf1fdc62f3a19f676ccbadb9dbc5d07e951b2b26daf", + "url": "https://files.pythonhosted.org/packages/1c/fb/b82e9601b00d88cf8bbee1f39b855ae773f9d5bcbcedb3801b2f72460696/PyJWT-2.4.0-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "d42908208c699b3b973cbeb01a969ba6a96c821eefb1c5bfe4c390c01d67abba", + "url": "https://files.pythonhosted.org/packages/d8/6b/6287745054dbcccf75903630346be77d4715c594402cec7c2518032416c2/PyJWT-2.4.0.tar.gz" + } + ], + "project_name": "pyjwt", + "requires_dists": [ + "coverage[toml]==5.0.4; extra == \"dev\"", + "coverage[toml]==5.0.4; extra == \"tests\"", + "cryptography>=3.3.1; extra == \"crypto\"", + "cryptography>=3.3.1; extra == \"dev\"", + "mypy; extra == \"dev\"", + "pre-commit; extra == \"dev\"", + "pytest<7.0.0,>=6.0.0; extra == \"dev\"", + "pytest<7.0.0,>=6.0.0; extra == \"tests\"", + "sphinx-rtd-theme; extra == \"dev\"", + "sphinx-rtd-theme; extra == \"docs\"", + "sphinx; extra == \"dev\"", + "sphinx; extra == \"docs\"", + "zope.interface; extra == \"dev\"", + "zope.interface; extra == \"docs\"" + ], + "requires_python": ">=3.6", + "version": "2.4" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "5026bae9a10eeaefb61dab2f09052b9f4307d44aee4eda64b309723d8d206bbc", + "url": "https://files.pythonhosted.org/packages/6c/10/a7d0fa5baea8fe7b50f448ab742f26f52b80bfca85ac2be9d35cdd9a3246/pyparsing-3.0.9-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "2b020ecf7d21b687f219b71ecad3631f644a47f01403fa1d1036b0c6416d70fb", + "url": "https://files.pythonhosted.org/packages/71/22/207523d16464c40a0310d2d4d8926daffa00ac1f5b1576170a32db749636/pyparsing-3.0.9.tar.gz" + } + ], + "project_name": "pyparsing", + "requires_dists": [ + "jinja2; extra == \"diagrams\"", + "railroad-diagrams; extra == \"diagrams\"" + ], + "requires_python": ">=3.6.8", + "version": "3.0.9" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "b0efb6516fd4fb07b45949053826a62fa4cb353db5be2bbb4a7aa1fdd1e345fb", + "url": "https://files.pythonhosted.org/packages/56/fc/a3c13ded7b3057680c8ae95a9b6cc83e63657c38e0005c400a5d018a33a7/pyreadline3-3.4.1-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": 
"6f3d1f7b8a31ba32b73917cefc1f28cc660562f39aea8646d30bd6eff21f7bae", + "url": "https://files.pythonhosted.org/packages/d7/86/3d61a61f36a0067874a00cb4dceb9028d34b6060e47828f7fc86fb9f7ee9/pyreadline3-3.4.1.tar.gz" + } + ], + "project_name": "pyreadline3", + "requires_dists": [], + "requires_python": null, + "version": "3.4.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "13d0e3ccfc2b6e26be000cb6568c832ba67ba32e719443bfe725814d3c42433c", + "url": "https://files.pythonhosted.org/packages/fb/d0/bae533985f2338c5d02184b4a7083b819f6b3fc101da792e0d96e6e5299d/pytest-7.1.2-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "a06a0425453864a270bc45e71f783330a7428defb4230fb5e6a731fde06ecd45", + "url": "https://files.pythonhosted.org/packages/4e/1f/34657c6ac56f3c58df650ba41f8ffb2620281ead8e11bcdc7db63cf72a78/pytest-7.1.2.tar.gz" + } + ], + "project_name": "pytest", + "requires_dists": [ + "argcomplete; extra == \"testing\"", + "atomicwrites>=1.0; sys_platform == \"win32\"", + "attrs>=19.2.0", + "colorama; sys_platform == \"win32\"", + "hypothesis>=3.56; extra == \"testing\"", + "importlib-metadata>=0.12; python_version < \"3.8\"", + "iniconfig", + "mock; extra == \"testing\"", + "nose; extra == \"testing\"", + "packaging", + "pluggy<2.0,>=0.12", + "py>=1.8.2", + "pygments>=2.7.2; extra == \"testing\"", + "requests; extra == \"testing\"", + "tomli>=1.0.0", + "xmlschema; extra == \"testing\"" + ], + "requires_python": ">=3.7", + "version": "7.1.2" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "c2a892906192663f85030a6ab91304e508e546cddfe557d692d61ec57a1d946b", + "url": "https://files.pythonhosted.org/packages/69/6d/cfd6d654877f75e0368e4040f1cf0350dd9f427b578bf7b685af629f8167/pytest-dependency-0.5.1.tar.gz" + } + ], + "project_name": "pytest-dependency", + "requires_dists": [ + "pytest>=3.6.0" + ], + "requires_python": null, + "version": "0.5.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9", + "url": "https://files.pythonhosted.org/packages/36/7a/87837f39d0296e723bb9b62bbb257d0355c7f6128853c78955f57342a56d/python_dateutil-2.8.2-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86", + "url": "https://files.pythonhosted.org/packages/4c/c4/13b4776ea2d76c115c1d1b84579f3764ee6d57204f6be27119f13a61d0a9/python-dateutil-2.8.2.tar.gz" + } + ], + "project_name": "python-dateutil", + "requires_dists": [ + "six>=1.5" + ], + "requires_python": "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7", + "version": "2.8.2" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "99310d148f054e858cd5f4258794ed6777e7ad2c3fd7e1c1b527f1cba4d08420", + "url": "https://files.pythonhosted.org/packages/4e/68/3eebcb5becdc90e43525164f1a1951ebf1fd9df1295418a077c168e18410/python_json_logger-2.0.2-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "202a4f29901a4b8002a6d1b958407eeb2dd1d83c18b18b816f5b64476dde9096", + "url": "https://files.pythonhosted.org/packages/6f/a6/b47b6c0211e858c711d8dfdf34557a9da17579892efb71df9dbf983ba724/python-json-logger-2.0.2.tar.gz" + } + ], + "project_name": "python-json-logger", + "requires_dists": [], + "requires_python": ">=3.5", + "version": "2.0.2" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "2aaaf618c68d8c9daebc23a20436bd01b09ee70d7fbf7072b7f38b06d2fab539", + "url": 
"https://files.pythonhosted.org/packages/0a/ec/473cb25378f07db4c051343cc1561cf1dd1fa8ecbcb995dcc461cc083529/python_snappy-0.6.1-cp310-cp310-win_amd64.whl" + }, + { + "algorithm": "sha256", + "hash": "b7f920eaf46ebf41bd26f9df51c160d40f9e00b7b48471c3438cb8d027f7fb9b", + "url": "https://files.pythonhosted.org/packages/32/d2/52ca2c822787425e31f742c038bba7f05fe6ca98d605c584ec285204215a/python_snappy-0.6.1-cp310-cp310-macosx_10_9_universal2.whl" + }, + { + "algorithm": "sha256", + "hash": "cb18d9cd7b3f35a2f5af47bb8ed6a5bdbf4f3ddee37f3daade4ab7864c292f5b", + "url": "https://files.pythonhosted.org/packages/4b/57/02864670fc97cc6691846a4a56a69d9e4b0e30951547c98d825ba68bdb22/python_snappy-0.6.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" + }, + { + "algorithm": "sha256", + "hash": "b6a107ab06206acc5359d4c5632bd9b22d448702a79b3169b0c62e0fb808bb2a", + "url": "https://files.pythonhosted.org/packages/98/7a/44a24bad98335b2c72e4cadcdecf79f50197d1bab9f22f863a274f104b96/python-snappy-0.6.1.tar.gz" + }, + { + "algorithm": "sha256", + "hash": "4ec533a8c1f8df797bded662ec3e494d225b37855bb63eb0d75464a07947477c", + "url": "https://files.pythonhosted.org/packages/9e/48/94d6a98d9fdecd4b747225376aa1e5fbc2d1e0ab8ca141242753185e2d26/python_snappy-0.6.1-cp310-cp310-macosx_10_9_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "8277d1f6282463c40761f802b742f833f9f2449fcdbb20a96579aa05c8feb614", + "url": "https://files.pythonhosted.org/packages/ae/30/37bfbad510e0a323fc159b88cf8d43fb25b9719187d796224d33f9782963/python_snappy-0.6.1-cp310-cp310-win32.whl" + }, + { + "algorithm": "sha256", + "hash": "b265cde49774752aec9ca7f5d272e3f98718164afc85521622a8a5394158a2b5", + "url": "https://files.pythonhosted.org/packages/b3/9f/81050aafd77fd1337e0da83fc5e300633618c2f50874e8b2162a2e9923b0/python_snappy-0.6.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl" + }, + { + "algorithm": "sha256", + "hash": "d017775851a778ec9cc32651c4464079d06d927303c2dde9ae9830ccf6fe94e1", + "url": "https://files.pythonhosted.org/packages/b3/db/2e1f3d55aec9b4ae83ed9781dafcae08ce2829b72afc541d9235778a9f1f/python_snappy-0.6.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl" + }, + { + "algorithm": "sha256", + "hash": "8d0c019ee7dcf2c60e240877107cddbd95a5b1081787579bf179938392d66480", + "url": "https://files.pythonhosted.org/packages/f4/fc/fd274c00c7776a17c8740281ec0790bf83dec503ab5d42ef2e02aa93bdec/python_snappy-0.6.1-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "6f8bf4708a11b47517baf962f9a02196478bbb10fdb9582add4aa1459fa82380", + "url": "https://files.pythonhosted.org/packages/f5/6d/1d5c0de6fe8ef83c063ba9841ddfa77e777aafa1ce8a74b6aa7a49780f2a/python_snappy-0.6.1-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl" + } + ], + "project_name": "python-snappy", + "requires_dists": [], + "requires_python": null, + "version": "0.6.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "607774cbba28732bfa802b54baa7484215f530991055bb562efbed5b2f20a45e", + "url": "https://files.pythonhosted.org/packages/a0/a4/d63f2d7597e1a4b55aa3b4d6c5b029991d3b824b5bd331af8d4ab1ed687d/PyYAML-5.4.1.tar.gz" + } + ], + "project_name": "pyyaml", + "requires_dists": [], + "requires_python": "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7", + "version": "5.4.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "7040d6dd85ea65703904d023d7f57fab793d7ffee9ba9e14f3b897f34ff2415d", + "url": 
"https://files.pythonhosted.org/packages/99/3b/69360102db726741053d1446cbe9f7f06df7e2a6d5b805ee71841abf1cdc/pyzmq-22.1.0.tar.gz" + } + ], + "project_name": "pyzmq", + "requires_dists": [ + "cffi; implementation_name == \"pypy\"", + "py; implementation_name == \"pypy\"" + ], + "requires_python": ">=3.6", + "version": "22.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "84316970995a7adb907a56754d2b92d88fc2d252963dc5ac34c88f0f1a22c25d", + "url": "https://files.pythonhosted.org/packages/c3/d8/46f3c0dadb5499031282b43d9c280a3c3d70107b6b0217122a92476fb5b0/redis-4.3.1-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "94b617b4cd296e94991146f66fc5559756fbefe9493604f0312e4d3298ac63e9", + "url": "https://files.pythonhosted.org/packages/56/78/c8c819080f6fbdd72e314756e1cb259199814b64d821dd396810dedce247/redis-4.3.1.tar.gz" + } + ], + "project_name": "redis", + "requires_dists": [ + "async-timeout>=4.0.2", + "cryptography>=36.0.1; extra == \"ocsp\"", + "deprecated>=1.2.3", + "hiredis>=1.0.0; extra == \"hiredis\"", + "importlib-metadata>=1.0; python_version < \"3.8\"", + "packaging>=20.4", + "pyopenssl==20.0.1; extra == \"ocsp\"", + "requests>=2.26.0; extra == \"ocsp\"", + "typing-extensions; python_version < \"3.8\"" + ], + "requires_python": ">=3.6", + "version": "4.3.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "f22fa1e554c9ddfd16e6e41ac79759e17be9e492b3587efa038054674760e72d", + "url": "https://files.pythonhosted.org/packages/2d/61/08076519c80041bc0ffa1a8af0cbd3bf3e2b62af10435d269a9d0f40564d/requests-2.27.1-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "68d7c56fd5a8999887728ef304a6d12edc7be74f1cfa47714fc8b414525c9a61", + "url": "https://files.pythonhosted.org/packages/60/f3/26ff3767f099b73e0efa138a9998da67890793bfa475d8278f84a30fec77/requests-2.27.1.tar.gz" + } + ], + "project_name": "requests", + "requires_dists": [ + "PySocks!=1.5.7,>=1.5.6; extra == \"socks\"", + "certifi>=2017.4.17", + "chardet<5,>=3.0.2; extra == \"use_chardet_on_py3\"", + "chardet<5,>=3.0.2; python_version < \"3\"", + "charset-normalizer~=2.0.0; python_version >= \"3\"", + "idna<3,>=2.5; python_version < \"3\"", + "idna<4,>=2.5; python_version >= \"3\"", + "urllib3<1.27,>=1.21.1", + "win-inet-pton; (sys_platform == \"win32\" and python_version == \"2.7\") and extra == \"socks\"" + ], + "requires_python": "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7", + "version": "2.27.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "2577c501a2fb8d05a304c09d090d6e47c306fef15809d102b327cf8364bddab5", + "url": "https://files.pythonhosted.org/packages/6f/bb/5deac77a9af870143c684ab46a7934038a53eb4aa975bc0687ed6ca2c610/requests_oauthlib-1.3.1-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "75beac4a47881eeb94d5ea5d6ad31ef88856affe2332b9aafb52c6452ccf0d7a", + "url": "https://files.pythonhosted.org/packages/95/52/531ef197b426646f26b53815a7d2a67cb7a331ef098bb276db26a68ac49f/requests-oauthlib-1.3.1.tar.gz" + } + ], + "project_name": "requests-oauthlib", + "requires_dists": [ + "oauthlib>=3.0.0", + "oauthlib[signedtoken]>=3.0.0; extra == \"rsa\"", + "requests>=2.0.0" + ], + "requires_python": "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7", + "version": "1.3.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "d2bbd99c320a2532ac71ff6a3164867884357da3e3301f0240090c5d2fdac7ec", + "url": "https://files.pythonhosted.org/packages/13/3f/1996db12d23733e2834b9c2b094cc59c0d1ab943fedafcdb34b5c0da9ebf/rich-12.4.4-py3-none-any.whl" + }, 
+ { + "algorithm": "sha256", + "hash": "4c586de507202505346f3e32d1363eb9ed6932f0c2f63184dea88983ff4971e2", + "url": "https://files.pythonhosted.org/packages/f5/f3/f87be42279b5cfba09f7f29e2f4a77063ccf5d9075042981e2cf48752d51/rich-12.4.4.tar.gz" + } + ], + "project_name": "rich", + "requires_dists": [ + "commonmark<0.10.0,>=0.9.0", + "dataclasses<0.9,>=0.7; python_version < \"3.7\"", + "ipywidgets<8.0.0,>=7.5.1; extra == \"jupyter\"", + "pygments<3.0.0,>=2.6.0", + "typing-extensions<5.0,>=4.0.0; python_version < \"3.9\"" + ], + "requires_python": "<4.0.0,>=3.6.3", + "version": "12.4.4" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "95c5d300c4e879ee69708c428ba566c59478fd653cc3a22243eeb8ed846950bb", + "url": "https://files.pythonhosted.org/packages/30/ab/8fd9e88e6fa5ec41afca995938bbefb72195278e0cfc5bd76a4f29b23fb2/rsa-4.8-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "5c6bd9dc7a543b7fe4304a631f8a8a3b674e2bbfc49c2ae96200cdbe55df6b17", + "url": "https://files.pythonhosted.org/packages/8c/ee/4022542e0fed77dd6ddade38e1e4dea3299f873b7fd4e6d78319953b0f83/rsa-4.8.tar.gz" + } + ], + "project_name": "rsa", + "requires_dists": [ + "pyasn1>=0.1.3" + ], + "requires_python": "<4,>=3.6", + "version": "4.8" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "7357592bc7e881a95e0c2013b73326f704953301ab551fbc8133a6fadab84105", + "url": "https://files.pythonhosted.org/packages/33/0f/5ef4ac78e2a538cc1b054eb86285fe0bf7a5dbaeaac2c584757c300515e2/Rx-1.6.1-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "13a1d8d9e252625c173dc795471e614eadfe1cf40ffc684e08b8fff0d9748c23", + "url": "https://files.pythonhosted.org/packages/25/d7/9bc30242d9af6a9e9bf65b007c56e17b7dc9c13f86e440b885969b3bbdcf/Rx-1.6.1.tar.gz" + } + ], + "project_name": "rx", + "requires_dists": [], + "requires_python": null, + "version": "1.6.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "3b1883ccdbee624386dc046cfbcd80c4e75e24c478f35627984a79892e088b88", + "url": "https://files.pythonhosted.org/packages/c1/fb/2481ea59a072b51f322460d753f846917ce7f40ecd4574db08c5ad263165/setproctitle-1.2.3-cp310-cp310-win_amd64.whl" + }, + { + "algorithm": "sha256", + "hash": "791bed39e4ecbdd008b64999a60c9cc560d17b3836ca0c27cd4708e8e1bcf495", + "url": "https://files.pythonhosted.org/packages/06/26/a80da97eafde871482b8cf9a32c35d95cdf5c62109a36641b46f01ee8700/setproctitle-1.2.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl" + }, + { + "algorithm": "sha256", + "hash": "eb82a49aaf440232c762539ab3737b5174d31aba0141fd4bf4d8739c28d18624", + "url": "https://files.pythonhosted.org/packages/22/65/7420368ac4b5264dc2c74e2cfb14d392a726dae673e0730296849d0212d3/setproctitle-1.2.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl" + }, + { + "algorithm": "sha256", + "hash": "501c084cf3df7d848e91c97d4f8c44d799ba545858a79c6960326ce6f285b4e4", + "url": "https://files.pythonhosted.org/packages/2c/6e/c5dd9b87ac71e19d5176091bd21f59df63be4bb478be743a67ea2934cc6e/setproctitle-1.2.3-cp310-cp310-musllinux_1_1_i686.whl" + }, + { + "algorithm": "sha256", + "hash": "47f97f591ea2335b7d35f5e9ad7d806385338182dc6de5732d091e9c70ed1cc0", + "url": "https://files.pythonhosted.org/packages/44/07/0efb26a71e63c82d3a96abb3be4812cdf032106d5f28076a6fbeca2fc440/setproctitle-1.2.3-cp310-cp310-musllinux_1_1_aarch64.whl" + }, + { + "algorithm": "sha256", + "hash": "d8e4da68d4d4ba46d4c5db6ae5eb61b11de9c520f25ae8334570f4d0018a8611", + "url": 
"https://files.pythonhosted.org/packages/4a/9c/a5408822809cae73486cc6a9048178355e584beeb8d41bf59e53abccf462/setproctitle-1.2.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "71d00ef63a1f78e13c236895badac77b6c8503377467b9c1a4f81fe729d16e03", + "url": "https://files.pythonhosted.org/packages/63/7e/e7f272e104d9d76caf71de17f3051452335ebe442b4bf9049d19e84faada/setproctitle-1.2.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" + }, + { + "algorithm": "sha256", + "hash": "ecf28b1c07a799d76f4326e508157b71aeda07b84b90368ea451c0710dbd32c0", + "url": "https://files.pythonhosted.org/packages/78/9a/cf6bf4c472b59aef3f3c0184233eeea8938d3366bcdd93d525261b1b9e0a/setproctitle-1.2.3.tar.gz" + }, + { + "algorithm": "sha256", + "hash": "a39b30d7400c0d50941fe19e1fe0b7d35676186fec4d9c010129ac91b883fd26", + "url": "https://files.pythonhosted.org/packages/9e/13/7569553060fe4573bda5ba35f7973cd8b39f7bb9f1828bd7112638e51494/setproctitle-1.2.3-cp310-cp310-musllinux_1_1_ppc64le.whl" + }, + { + "algorithm": "sha256", + "hash": "b213376fc779c0e1a4b60008f3fd03f74e9baa9665db37fa6646e98d31baa6d8", + "url": "https://files.pythonhosted.org/packages/b2/39/43dc8b0bc9b721964f3ca726e13905278c4e4d776aa1b81b070f7ac26aa1/setproctitle-1.2.3-cp310-cp310-musllinux_1_1_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "0a668acec8b61a971de54bc4c733869ea7b0eb1348eae5a32b9477f788908e5c", + "url": "https://files.pythonhosted.org/packages/e7/88/23cfd59e08d8197af89e2e2f13204462221a30404c8f37c6329c28f3b681/setproctitle-1.2.3-cp310-cp310-macosx_10_9_universal2.whl" + }, + { + "algorithm": "sha256", + "hash": "e24fa9251cc22ddb88ef183070063fdca826c9636381f1c4fb9d2a1dccb7c2a4", + "url": "https://files.pythonhosted.org/packages/ec/76/a357f5f2acc4efb96c48031457b35c073917bcde2e39ff0b09ed74d8ee7e/setproctitle-1.2.3-cp310-cp310-win32.whl" + }, + { + "algorithm": "sha256", + "hash": "52265182fe5ac237d179d8e949248d307882a2e6ec7f189c8dac1c9d1b3631fa", + "url": "https://files.pythonhosted.org/packages/f9/6e/0b3693721e3a81a19e395e903b0877e0dd8cd3f94cec5646eb80d4ac833e/setproctitle-1.2.3-cp310-cp310-macosx_10_9_x86_64.whl" + } + ], + "project_name": "setproctitle", + "requires_dists": [ + "pytest; extra == \"test\"" + ], + "requires_python": ">=3.6", + "version": "1.2.3" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "68e45d17c9281ba25dc0104eadd2647172b3472d9e01f911efa57965e8d51a36", + "url": "https://files.pythonhosted.org/packages/e9/1c/ec080fde54ab30a738c92f794eab7f5d2f354f2b619ee95b2efe353e0766/setuptools-62.3.2-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "a43bdedf853c670e5fed28e5623403bad2f73cf02f9a2774e91def6bda8265a7", + "url": "https://files.pythonhosted.org/packages/4a/25/ec29a23ef38b9456f9965c57a9e1221e6c246d87abbf2a31158799bca201/setuptools-62.3.2.tar.gz" + } + ], + "project_name": "setuptools", + "requires_dists": [ + "build[virtualenv]; extra == \"testing\"", + "build[virtualenv]; extra == \"testing-integration\"", + "filelock>=3.4.0; extra == \"testing\"", + "filelock>=3.4.0; extra == \"testing-integration\"", + "flake8-2020; extra == \"testing\"", + "furo; extra == \"docs\"", + "ini2toml[lite]>=0.9; extra == \"testing\"", + "jaraco.envs>=2.2; extra == \"testing\"", + "jaraco.envs>=2.2; extra == \"testing-integration\"", + "jaraco.packaging>=9; extra == \"docs\"", + "jaraco.path>=3.2.0; extra == \"testing\"", + "jaraco.path>=3.2.0; extra == \"testing-integration\"", + "jaraco.tidelift>=1.4; 
extra == \"docs\"", + "mock; extra == \"testing\"", + "pip-run>=8.8; extra == \"testing\"", + "pip>=19.1; extra == \"testing\"", + "pygments-github-lexers==0.0.5; extra == \"docs\"", + "pytest-black>=0.3.7; platform_python_implementation != \"PyPy\" and extra == \"testing\"", + "pytest-checkdocs>=2.4; extra == \"testing\"", + "pytest-cov; platform_python_implementation != \"PyPy\" and extra == \"testing\"", + "pytest-enabler; extra == \"testing-integration\"", + "pytest-enabler>=1.0.1; extra == \"testing\"", + "pytest-flake8; extra == \"testing\"", + "pytest-mypy>=0.9.1; platform_python_implementation != \"PyPy\" and extra == \"testing\"", + "pytest-perf; extra == \"testing\"", + "pytest-xdist; extra == \"testing\"", + "pytest-xdist; extra == \"testing-integration\"", + "pytest; extra == \"testing-integration\"", + "pytest>=6; extra == \"testing\"", + "rst.linker>=1.9; extra == \"docs\"", + "sphinx-favicon; extra == \"docs\"", + "sphinx-inline-tabs; extra == \"docs\"", + "sphinx-reredirects; extra == \"docs\"", + "sphinx; extra == \"docs\"", + "sphinxcontrib-towncrier; extra == \"docs\"", + "tomli-w>=1.0.0; extra == \"testing\"", + "tomli; extra == \"testing-integration\"", + "virtualenv>=13.0.0; extra == \"testing\"", + "virtualenv>=13.0.0; extra == \"testing-integration\"", + "wheel; extra == \"testing\"", + "wheel; extra == \"testing-integration\"" + ], + "requires_python": ">=3.7", + "version": "62.3.2" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254", + "url": "https://files.pythonhosted.org/packages/d9/5a/e7c31adbe875f2abbb91bd84cf2dc52d792b5a01506781dbcf25c91daf11/six-1.16.0-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926", + "url": "https://files.pythonhosted.org/packages/71/39/171f1c67cd00715f190ba0b100d606d440a28c93c7714febeca8b79af85e/six-1.16.0.tar.gz" + } + ], + "project_name": "six", + "requires_dists": [], + "requires_python": "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7", + "version": "1.16" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "64d796e9af522162f7f2bf7a3c5531a0a550764c426782797bbeed809d0646c5", + "url": "https://files.pythonhosted.org/packages/46/be/1fe89630d6bcd239c702117a5c7be7f1403137b8dd5fb451533995d73b58/SQLAlchemy-1.4.36-cp310-cp310-win_amd64.whl" + }, + { + "algorithm": "sha256", + "hash": "8d07fe2de0325d06e7e73281e9a9b5e259fbd7cbfbe398a0433cbb0082ad8fa7", + "url": "https://files.pythonhosted.org/packages/3e/2c/fcb7508e5e40c42eb00516c7c1a936afae7af95b2de0e4680a60924fff7f/SQLAlchemy-1.4.36-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "f0394a3acfb8925db178f7728adb38c027ed7e303665b225906bfa8099dc1ce8", + "url": "https://files.pythonhosted.org/packages/82/39/cab0562a7e580004b513856bf73af789a8b9aa810b3e2bb25b9ab74f720f/SQLAlchemy-1.4.36-cp310-cp310-macosx_10_15_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "09c606d8238feae2f360b8742ffbe67741937eb0a05b57f536948d198a3def96", + "url": "https://files.pythonhosted.org/packages/9c/6b/81d2d3e3020f9105570a7e8730815134223802f13d1fd122ee5c813cf1d2/SQLAlchemy-1.4.36-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" + }, + { + "algorithm": "sha256", + "hash": "be094460930087e50fd08297db9d7aadaed8408ad896baf758e9190c335632da", + "url": 
"https://files.pythonhosted.org/packages/b5/b1/3ea004b6e6a30f5098c20bfb3153343acabf5a6e0f1a77e1ab941165f3db/SQLAlchemy-1.4.36-cp310-cp310-win32.whl" + }, + { + "algorithm": "sha256", + "hash": "5041474dcab7973baa91ec1f3112049a9dd4652898d6a95a6a895ff5c58beb6b", + "url": "https://files.pythonhosted.org/packages/d4/c8/1496a0fb6b853eeb2fbd82f0cdd3b5c1ca0a7d46146dc91ce3b65a1616bf/SQLAlchemy-1.4.36-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "64678ac321d64a45901ef2e24725ec5e783f1f4a588305e196431447e7ace243", + "url": "https://files.pythonhosted.org/packages/fb/b0/53e540c9fad14ac2da8a15ae95d707b167f64f62d85d4f506b0335dfd66d/SQLAlchemy-1.4.36.tar.gz" + } + ], + "project_name": "sqlalchemy", + "requires_dists": [ + "aiomysql; python_version >= \"3\" and extra == \"aiomysql\"", + "aiosqlite; python_version >= \"3\" and extra == \"aiosqlite\"", + "asyncmy!=0.2.4,>=0.2.3; python_version >= \"3\" and extra == \"asyncmy\"", + "asyncpg; python_version >= \"3\" and extra == \"postgresql_asyncpg\"", + "cx-oracle<8,>=7; python_version < \"3\" and extra == \"oracle\"", + "cx-oracle>=7; python_version >= \"3\" and extra == \"oracle\"", + "greenlet!=0.4.17; python_version >= \"3\" and (platform_machine == \"aarch64\" or (platform_machine == \"ppc64le\" or (platform_machine == \"x86_64\" or (platform_machine == \"amd64\" or (platform_machine == \"AMD64\" or (platform_machine == \"win32\" or platform_machine == \"WIN32\"))))))", + "greenlet!=0.4.17; python_version >= \"3\" and extra == \"aiomysql\"", + "greenlet!=0.4.17; python_version >= \"3\" and extra == \"aiosqlite\"", + "greenlet!=0.4.17; python_version >= \"3\" and extra == \"asyncio\"", + "greenlet!=0.4.17; python_version >= \"3\" and extra == \"asyncmy\"", + "greenlet!=0.4.17; python_version >= \"3\" and extra == \"postgresql_asyncpg\"", + "importlib-metadata; python_version < \"3.8\"", + "mariadb>=1.0.1; python_version >= \"3\" and extra == \"mariadb_connector\"", + "mypy>=0.910; python_version >= \"3\" and extra == \"mypy\"", + "mysql-connector-python; extra == \"mysql_connector\"", + "mysqlclient<2,>=1.4.0; python_version < \"3\" and extra == \"mysql\"", + "mysqlclient>=1.4.0; python_version >= \"3\" and extra == \"mysql\"", + "pg8000>=1.16.6; extra == \"postgresql_pg8000\"", + "psycopg2-binary; extra == \"postgresql_psycopg2binary\"", + "psycopg2>=2.7; extra == \"postgresql\"", + "psycopg2cffi; extra == \"postgresql_psycopg2cffi\"", + "pymssql; extra == \"mssql_pymssql\"", + "pymysql; python_version >= \"3\" and extra == \"pymysql\"", + "pymysql<1; python_version < \"3\" and extra == \"pymysql\"", + "pyodbc; extra == \"mssql\"", + "pyodbc; extra == \"mssql_pyodbc\"", + "sqlalchemy2-stubs; extra == \"mypy\"", + "sqlcipher3-binary; python_version >= \"3\" and extra == \"sqlcipher\"", + "typing-extensions!=3.10.0.1; extra == \"aiosqlite\"" + ], + "requires_python": "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7", + "version": "1.4.36" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "d7c013fe7abbc5e491394e10fa845f8f32fe54f8dc60c6622c6cf482d25d47e4", + "url": "https://files.pythonhosted.org/packages/ca/80/7c0cad11bd99985cfe7c09427ee0b4f9bd6b048bd13d4ffb32c6db237dfb/tabulate-0.8.9-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "eb1d13f25760052e8931f2ef80aaf6045a6cceb47514db8beab24cded16f13a7", + "url": 
"https://files.pythonhosted.org/packages/ae/3d/9d7576d94007eaf3bb685acbaaec66ff4cdeb0b18f1bf1f17edbeebffb0a/tabulate-0.8.9.tar.gz" + } + ], + "project_name": "tabulate", + "requires_dists": [ + "wcwidth; extra == \"widechars\"" + ], + "requires_python": null, + "version": "0.8.9" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "289fa7359e580950e7d9743eab36b0691f0310fce64dee7d9c31065b8f723e23", + "url": "https://files.pythonhosted.org/packages/f8/cd/2fad4add11c8837e72f50a30e2bda30e67a10d70462f826b291443a55c7d/tblib-1.7.0-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "059bd77306ea7b419d4f76016aef6d7027cc8a0785579b5aad198803435f882c", + "url": "https://files.pythonhosted.org/packages/d3/41/901ef2e81d7b1e834b9870d416cb09479e175a2be1c4aa1a9dcd0a555293/tblib-1.7.0.tar.gz" + } + ], + "project_name": "tblib", + "requires_dists": [], + "requires_python": "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7", + "version": "1.7" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "b07e8ea0684e73ded0cd4fa3622ca00477ee85cf32ea686f38db06c2e8e17bda", + "url": "https://files.pythonhosted.org/packages/1b/cc/8cc2406d9b022cb0379f983560aee13e874f426477d5cbdcfcf46423eb08/temporenc-0.1.0.tar.gz" + } + ], + "project_name": "temporenc", + "requires_dists": [], + "requires_python": null, + "version": "0.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "f78f4ea81b0fabc06728c11dc2a8c01277bfc5181b321a4770471902e3eb844a", + "url": "https://files.pythonhosted.org/packages/f2/a5/f86bc8d67c979020438c8559cc70cfe3a1643fd160d35e09c9cca6a09189/tenacity-8.0.1-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "43242a20e3e73291a28bcbcacfd6e000b02d3857a9a9fff56b297a27afdc932f", + "url": "https://files.pythonhosted.org/packages/2c/f5/04748914f5c78f7418b803226bd56cdddd70ac369b936b3e24f5158017f1/tenacity-8.0.1.tar.gz" + } + ], + "project_name": "tenacity", + "requires_dists": [ + "reno; extra == \"doc\"", + "sphinx; extra == \"doc\"", + "tornado>=4.5; extra == \"doc\"" + ], + "requires_python": ">=3.6", + "version": "8.0.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "e4fdc4179c9e4aab5f674d80f09d76fa436b96fdc698a8505e0a36bf0804a874", + "url": "https://files.pythonhosted.org/packages/c4/fb/ea621e0a19733e01fe4005d46087d383693c0f4a8f824b47d8d4122c87e0/terminaltables-3.1.10-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "ba6eca5cb5ba02bba4c9f4f985af80c54ec3dccf94cfcd190154386255e47543", + "url": "https://files.pythonhosted.org/packages/f5/fc/0b73d782f5ab7feba8d007573a3773c58255f223c5940a7b7085f02153c3/terminaltables-3.1.10.tar.gz" + } + ], + "project_name": "terminaltables", + "requires_dists": [], + "requires_python": ">=2.6", + "version": "3.1.10" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b", + "url": "https://files.pythonhosted.org/packages/44/6f/7120676b6d73228c96e17f1f794d8ab046fc910d781c8d151120c3f1569e/toml-0.10.2-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f", + "url": "https://files.pythonhosted.org/packages/be/ba/1f744cdc819428fc6b5084ec34d9b30660f6f9daaf70eead706e3203ec3c/toml-0.10.2.tar.gz" + } + ], + "project_name": "toml", + "requires_dists": [], + "requires_python": "!=3.0.*,!=3.1.*,!=3.2.*,>=2.6", + "version": "0.10.2" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": 
"939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc", + "url": "https://files.pythonhosted.org/packages/97/75/10a9ebee3fd790d20926a90a2547f0bf78f371b2f13aa822c759680ca7b9/tomli-2.0.1-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f", + "url": "https://files.pythonhosted.org/packages/c0/3f/d7af728f075fb08564c5949a9c95e44352e23dee646869fa104a3b2060a3/tomli-2.0.1.tar.gz" + } + ], + "project_name": "tomli", + "requires_dists": [], + "requires_python": ">=3.7", + "version": "2.0.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "b824e3466f1d475b2b5f1c392954c6cb7ea04d64354ff7300dc7c14257dc85db", + "url": "https://files.pythonhosted.org/packages/2d/36/b17811aa7c17609eaa68a91e15e6b2e56bf4d5d5a3c43d53c2b46728e6b2/tomlkit-0.8.0-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "29e84a855712dfe0e88a48f6d05c21118dbafb283bb2eed614d46f80deb8e9a1", + "url": "https://files.pythonhosted.org/packages/0f/96/ee6ba35c61186fbf084cb3077374d50eef36ab59cb8c6513317caa190935/tomlkit-0.8.0.tar.gz" + } + ], + "project_name": "tomlkit", + "requires_dists": [], + "requires_python": "<4.0,>=3.6", + "version": "0.8" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "74a2cdefe14d11442cedf3ba4e21a3b84ff9a2dbdc6cfae2c34addb2a14a5ea6", + "url": "https://files.pythonhosted.org/packages/8a/c4/d15f1e627fff25443ded77ea70a7b5532d6371498f9285d44d62587e209c/tqdm-4.64.0-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "40be55d30e200777a307a7585aee69e4eabb46b4ec6a4b4a5f2d9f11e7d5408d", + "url": "https://files.pythonhosted.org/packages/98/2a/838de32e09bd511cf69fe4ae13ffc748ac143449bfc24bb3fd172d53a84f/tqdm-4.64.0.tar.gz" + } + ], + "project_name": "tqdm", + "requires_dists": [ + "colorama; platform_system == \"Windows\"", + "importlib-resources; python_version < \"3.7\"", + "ipywidgets>=6; extra == \"notebook\"", + "py-make>=0.1.0; extra == \"dev\"", + "requests; extra == \"telegram\"", + "slack-sdk; extra == \"slack\"", + "twine; extra == \"dev\"", + "wheel; extra == \"dev\"" + ], + "requires_python": "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7", + "version": "4.64" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "1966f432586797aed663edd54cbc201fd7ba59eed1638f1a7a33f17977b3a569", + "url": "https://files.pythonhosted.org/packages/64/32/17b47745df926eff2e5b89d79838337de88258f65a936f70416a15190142/trafaret-2.1.1-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "d9d00800318fbd343fdfb3353e947b2ebb5557159c844696c5ac24846f76d41c", + "url": "https://files.pythonhosted.org/packages/c9/ed/aac034e566f8846aee6472dcc90da6011a0b1829e3ffc768407df519a3b0/trafaret-2.1.1.tar.gz" + } + ], + "project_name": "trafaret", + "requires_dists": [ + "pymongo>=2.4.1; extra == \"objectid\"", + "python-dateutil>=1.5; extra == \"rfc3339\"" + ], + "requires_python": null, + "version": "2.1.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "5e3e3be01e887e7eafae5af63d1f36c849aaa94e3a0112097312aabfa16284f1", + "url": "https://files.pythonhosted.org/packages/9a/bb/d43e5c75054e53efce310e79d63df0ac3f25e34c926be5dffb7d283fb2a8/typeguard-2.13.3-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "00edaa8da3a133674796cf5ea87d9f4b4c367d77476e185e80251cc13dfbb8c4", + "url": "https://files.pythonhosted.org/packages/3a/38/c61bfcf62a7b572b5e9363a802ff92559cb427ee963048e1442e3aef7490/typeguard-2.13.3.tar.gz" + } + ], + "project_name": "typeguard", + "requires_dists": [ + 
"mypy; platform_python_implementation != \"PyPy\" and extra == \"test\"", + "pytest; extra == \"test\"", + "sphinx-autodoc-typehints>=1.2.0; extra == \"doc\"", + "sphinx-rtd-theme; extra == \"doc\"", + "typing-extensions; extra == \"test\"" + ], + "requires_python": ">=3.5.3", + "version": "2.13.3" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "64846425bc267beec45ddf50b14fb532e6612b9d74c976761cc5e1c47a88a06b", + "url": "https://files.pythonhosted.org/packages/9b/67/f0c15faead0f0232da14c24b03700cf7d896f4d138591e2b3fcfa8df1d3c/types_aiofiles-0.8.8-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "62942901b44dd8fb14b5581e8389fa0fe5c03b8aef7a1f86ac08e7b550d7dc11", + "url": "https://files.pythonhosted.org/packages/2f/35/bdbf0cf795eea02b95e1b7d10aa2f58f6b72762162884c0ba32b2f323dc6/types-aiofiles-0.8.8.tar.gz" + } + ], + "project_name": "types-aiofiles", + "requires_dists": [], + "requires_python": null, + "version": "0.8.8" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "4915e40fce8df222aa9ef118cdbc3e8eef55caefe32b6ded2cd2dbe21a2e2398", + "url": "https://files.pythonhosted.org/packages/fd/47/42a38cf1b57075061b32da942d1401a0e8da07f67f0227c5d235e9287b62/types_cachetools-5.0.1-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "2b9e0b1ace1a0de67bc3ced9c4aa69fff882bac8bfeb0f2aefaf0f99e4571c6b", + "url": "https://files.pythonhosted.org/packages/ae/b0/6874d82294165cec3ac31767b5b4005325e3f9f2f699316ce6a41a52b1bc/types-cachetools-5.0.1.tar.gz" + } + ], + "project_name": "types-cachetools", + "requires_dists": [], + "requires_python": null, + "version": "5.0.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "8cb030a669e2e927461be9827375f83c16b8178c365852c060a34e24871e7e81", + "url": "https://files.pythonhosted.org/packages/ee/ad/607454a5f991c5b3e14693a7113926758f889138371058a5f72f567fa131/types_click-7.1.8-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "b6604968be6401dc516311ca50708a0a28baa7a0cb840efd7412f0dbbff4e092", + "url": "https://files.pythonhosted.org/packages/00/ff/0e6a56108d45c80c61cdd4743312d0304d8192482aea4cce96c554aaa90d/types-click-7.1.8.tar.gz" + } + ], + "project_name": "types-click", + "requires_dists": [], + "requires_python": null, + "version": "7.1.8" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "60a1e21e8296979db32f9374d8a239af4cb541ff66447bb915d8ad398f9c63b2", + "url": "https://files.pythonhosted.org/packages/b7/b0/e79d84748f1d34304f13191424348a719c3febaa3493835370fe9528e1e6/types_Jinja2-2.11.9-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "dbdc74a40aba7aed520b7e4d89e8f0fe4286518494208b35123bcf084d4b8c81", + "url": "https://files.pythonhosted.org/packages/46/c4/b82309bfed8195de7997672deac301bd6f5bd5cbb6a3e392b7fe780d7852/types-Jinja2-2.11.9.tar.gz" + } + ], + "project_name": "types-jinja2", + "requires_dists": [ + "types-MarkupSafe" + ], + "requires_python": null, + "version": "2.11.9" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "ca2bee0f4faafc45250602567ef38d533e877d2ddca13003b319c551ff5b3cc5", + "url": "https://files.pythonhosted.org/packages/bc/d6/b8effb1c48539260a5eb4196afc55efac4ea1684a4991977555eb266b2ef/types_MarkupSafe-1.1.10-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "85b3a872683d02aea3a5ac2a8ef590193c344092032f58457287fbf8e06711b1", + "url": "https://files.pythonhosted.org/packages/39/31/b5f059142d058aec41e913d8e0eff0a967e7bc46f9a2ba2f31bc11cff059/types-MarkupSafe-1.1.10.tar.gz" + } + ], + "project_name": 
"types-markupsafe", + "requires_dists": [], + "requires_python": null, + "version": "1.1.10" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "0be7435b4d382d1cd00b8c55a8a90f4e515aaad8a96f8f0bc20c22df046792e5", + "url": "https://files.pythonhosted.org/packages/ae/f4/4fdaa07d4ec2a722b9820d11722afb1b289e5a697ee92d2d4766b1310e97/types_python_dateutil-2.8.17-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "6c54265a221681dd87f61df6743bd5eab060cf1b4086ff65c1a8fd763ed6370e", + "url": "https://files.pythonhosted.org/packages/ea/43/f675eb9c13cb784ee18f44b99a9ec5ee73aa4d1a6aff230699c7fc76ad03/types-python-dateutil-2.8.17.tar.gz" + } + ], + "project_name": "types-python-dateutil", + "requires_dists": [], + "requires_python": null, + "version": "2.8.17" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "7b273a34f32af9910cf9405728c9d2ad3afc4be63e4048091a1a73d76681fe67", + "url": "https://files.pythonhosted.org/packages/c6/e0/6e3e8e3af769206cb4f3f9e90e9e72e9df3bf921602d63ac117dcd7f1c30/types_PyYAML-6.0.7-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "59480cf44595d836aaae050f35e3c39f197f3a833679ef3978d97aa9f2fb7def", + "url": "https://files.pythonhosted.org/packages/42/6d/3d9bcc1ca2634492fa92bd311d2e1fede17ce9377e54bf11560ccf5305ca/types-PyYAML-6.0.7.tar.gz" + } + ], + "project_name": "types-pyyaml", + "requires_dists": [], + "requires_python": null, + "version": "6.0.7" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "9c7cdaf0d55113e24ac17103bde2d434472abf1dbf444238e989fe4e798ffa26", + "url": "https://files.pythonhosted.org/packages/58/68/dda470233bc56db72b51fb20de9ead5faa138595debbc861256ec20740b6/types_setuptools-57.4.17-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "9d556fcaf6808a1cead4aaa41e5c07a61f0152a875811e1239738eba4e0b7b16", + "url": "https://files.pythonhosted.org/packages/ba/97/1e6f2106412d038b2ca4b18bd652a15a50096c6829345d9d0076ad9555a5/types-setuptools-57.4.17.tar.gz" + } + ], + "project_name": "types-setuptools", + "requires_dists": [], + "requires_python": null, + "version": "57.4.17" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "18f6856a7df44fc7a292c2d73093908333e5f7cb858667b8cbefc8ed1e91942e", + "url": "https://files.pythonhosted.org/packages/13/d4/ee6b9044aac19e9a3e0463acc0b6c062db54d70294d1885de492c0c399ad/types_six-1.16.15-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "d244f0537dab0d0570a5bc6f8a60c4da7f0546d960a8677520e6bff214a64fb8", + "url": "https://files.pythonhosted.org/packages/1c/69/1cb7e70583221b1f42147d5d59e43a230e84cebf25e774021a12dfb713fa/types-six-1.16.15.tar.gz" + } + ], + "project_name": "types-six", + "requires_dists": [], + "requires_python": null, + "version": "1.16.15" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "7971ed0cd40454eb18d82c01e2f18bcd09ca23cc9eb901c62d2b04e5d1f57f84", + "url": "https://files.pythonhosted.org/packages/ca/2b/54cbdc5c90689cfd7523082dbe88d535b1ab9e35ca23221ce839dcc51a94/types_tabulate-0.8.9-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "2fc3fa4fe1853ac987cf50e8d4599e3fe446dd53064fe86a46a407a98e9fc04f", + "url": "https://files.pythonhosted.org/packages/26/4d/f5e4faeaa4f042f4afde95f2b4025a3edc863924410ef76d673110fcbeac/types-tabulate-0.8.9.tar.gz" + } + ], + "project_name": "types-tabulate", + "requires_dists": [], + "requires_python": null, + "version": "0.8.9" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": 
"05a8da4bfde2f1ee60e90c7071c063b461f74c63a9c3c1099470c08d6fa58615", + "url": "https://files.pythonhosted.org/packages/77/02/8bb35dc27ea84f9ac5f4bf70a51ae5626734843ac73a117fd88888551069/types_toml-0.10.7-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "a567fe2614b177d537ad99a661adc9bfc8c55a46f95e66370a4ed2dd171335f9", + "url": "https://files.pythonhosted.org/packages/c5/da/a5fb5c4eb663a1cd2d0c8ef619c42d51e6b8f55e155341e7b39b8c6c67b4/types-toml-0.10.7.tar.gz" + } + ], + "project_name": "types-toml", + "requires_dists": [], + "requires_python": null, + "version": "0.10.7" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "21c85e0fe4b9a155d0799430b0ad741cdce7e359660ccbd8b530613e8df88ce2", + "url": "https://files.pythonhosted.org/packages/45/6b/44f7f8f1e110027cf88956b59f2fad776cca7e1704396d043f89effd3a0e/typing_extensions-4.1.1-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "1a9462dcc3347a79b1f1c0271fbe79e844580bb598bafa1ed208b94da3cdcd42", + "url": "https://files.pythonhosted.org/packages/b1/5a/8b5fbb891ef3f81fc923bf3cb4a578c0abf9471eb50ce0f51c74212182ab/typing_extensions-4.1.1.tar.gz" + } + ], + "project_name": "typing-extensions", + "requires_dists": [], + "requires_python": ">=3.6", + "version": "4.1.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "44ece4d53fb1706f667c9bd1c648f5469a2ec925fcf3a776667042d645472c14", + "url": "https://files.pythonhosted.org/packages/ec/03/062e6444ce4baf1eac17a6a0ebfe36bb1ad05e1df0e20b110de59c278498/urllib3-1.26.9-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "aabaf16477806a5e1dd19aa41f8c2b7950dd3c746362d7e3223dbe6de6ac448e", + "url": "https://files.pythonhosted.org/packages/1b/a5/4eab74853625505725cefdf168f48661b2cd04e7843ab836f3f63abf81da/urllib3-1.26.9.tar.gz" + } + ], + "project_name": "urllib3", + "requires_dists": [ + "PySocks!=1.5.7,<2.0,>=1.5.6; extra == \"socks\"", + "brotli>=1.0.9; ((os_name != \"nt\" or python_version >= \"3\") and platform_python_implementation == \"CPython\") and extra == \"brotli\"", + "brotlicffi>=0.8.0; ((os_name != \"nt\" or python_version >= \"3\") and platform_python_implementation != \"CPython\") and extra == \"brotli\"", + "brotlipy>=0.6.0; (os_name == \"nt\" and python_version < \"3\") and extra == \"brotli\"", + "certifi; extra == \"secure\"", + "cryptography>=1.3.4; extra == \"secure\"", + "idna>=2.0.0; extra == \"secure\"", + "ipaddress; python_version == \"2.7\" and extra == \"secure\"", + "pyOpenSSL>=0.14; extra == \"secure\"" + ], + "requires_python": "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,<4,>=2.7", + "version": "1.26.9" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "772206116b9b57cd625c8a88f2413df2fcfd0b496eb188b82a43bed7af2c2ec9", + "url": "https://files.pythonhosted.org/packages/12/1c/4c270b22f68a75bedf795aadc40370c4ff9e910a5e1aff327c24aaae6a99/uvloop-0.16.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" + }, + { + "algorithm": "sha256", + "hash": "6224f1401025b748ffecb7a6e2652b17768f30b1a6a3f7b44660e5b5b690b12d", + "url": "https://files.pythonhosted.org/packages/2a/07/75074f9789d5f8811bc77230a84ddbb7586e555e84f59d75d2968ef5c4a0/uvloop-0.16.0-cp310-cp310-macosx_10_9_universal2.whl" + }, + { + "algorithm": "sha256", + "hash": "f74bc20c7b67d1c27c72601c78cf95be99d5c2cdd4514502b4f3eb0933ff1228", + "url": "https://files.pythonhosted.org/packages/ab/d9/22bbffa8f8d7e075ccdb29e8134107adfb4710feb10039f9d357db8b589c/uvloop-0.16.0.tar.gz" + }, + { + "algorithm": "sha256", + "hash": 
"bd53f7f5db562f37cd64a3af5012df8cac2c464c97e732ed556800129505bd64", + "url": "https://files.pythonhosted.org/packages/b9/00/14dffb56943092c2b5821d288dc23ff36dff9ad3b8aad3547c71b171cf3b/uvloop-0.16.0-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "30ba9dcbd0965f5c812b7c2112a1ddf60cf904c1c160f398e7eed3a6b82dcd9c", + "url": "https://files.pythonhosted.org/packages/da/ea/56fecce56844e308b6ff3c5b55a372a6a4da2c6fe8ee35a272459f534d9c/uvloop-0.16.0-cp310-cp310-macosx_10_9_x86_64.whl" + } + ], + "project_name": "uvloop", + "requires_dists": [ + "Cython<0.30.0,>=0.29.24; extra == \"dev\"", + "Sphinx~=4.1.2; extra == \"dev\"", + "Sphinx~=4.1.2; extra == \"docs\"", + "aiohttp; extra == \"dev\"", + "aiohttp; extra == \"test\"", + "flake8~=3.9.2; extra == \"dev\"", + "flake8~=3.9.2; extra == \"test\"", + "mypy>=0.800; extra == \"dev\"", + "mypy>=0.800; extra == \"test\"", + "psutil; extra == \"dev\"", + "psutil; extra == \"test\"", + "pyOpenSSL~=19.0.0; extra == \"dev\"", + "pyOpenSSL~=19.0.0; extra == \"test\"", + "pycodestyle~=2.7.0; extra == \"dev\"", + "pycodestyle~=2.7.0; extra == \"test\"", + "pytest>=3.6.0; extra == \"dev\"", + "sphinx-rtd-theme~=0.5.2; extra == \"dev\"", + "sphinx-rtd-theme~=0.5.2; extra == \"docs\"", + "sphinxcontrib-asyncio~=0.3.0; extra == \"dev\"", + "sphinxcontrib-asyncio~=0.3.0; extra == \"docs\"" + ], + "requires_python": ">=3.7", + "version": "0.16" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "722b171be00f2b90e1d4fb2f2b53146a536ca38db1da8ff49c972a4e1365d0ef", + "url": "https://files.pythonhosted.org/packages/a1/9e/8ddb04ef21ea3dfe3924b884dc11fa785df662af23e049ec2d62eaba707d/websocket_client-1.3.2-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "50b21db0058f7a953d67cc0445be4b948d7fc196ecbeb8083d68d94628e4abf6", + "url": "https://files.pythonhosted.org/packages/7c/de/9f5354b4b37df453b7d664f587124c70a75c81805095d491d39f5b591818/websocket-client-1.3.2.tar.gz" + } + ], + "project_name": "websocket-client", + "requires_dists": [ + "Sphinx>=3.4; extra == \"docs\"", + "python-socks; extra == \"optional\"", + "sphinx-rtd-theme>=0.5; extra == \"docs\"", + "websockets; extra == \"test\"", + "wsaccel; extra == \"optional\"" + ], + "requires_python": ">=3.7", + "version": "1.3.2" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "6d323e1554b3d22cfc03cd3243b5bb815a51f5249fdcbb86fda4bf62bab9e164", + "url": "https://files.pythonhosted.org/packages/c0/1e/e5a5ac09e92fd112d50e1793e5b9982dc9e510311ed89dacd2e801f82967/wrapt-1.14.1-cp310-cp310-win_amd64.whl" + }, + { + "algorithm": "sha256", + "hash": "4fcc4649dc762cddacd193e6b55bc02edca674067f5f98166d7713b193932b7f", + "url": "https://files.pythonhosted.org/packages/07/06/2b4aaaa4403f766c938f9780c700d7399726bce3dfd94f5a57c4e5b9dc68/wrapt-1.14.1-cp310-cp310-musllinux_1_1_aarch64.whl" + }, + { + "algorithm": "sha256", + "hash": "380a85cf89e0e69b7cfbe2ea9f765f004ff419f34194018a6827ac0e3edfed4d", + "url": "https://files.pythonhosted.org/packages/11/eb/e06e77394d6cf09977d92bff310cb0392930c08a338f99af6066a5a98f92/wrapt-1.14.1.tar.gz" + }, + { + "algorithm": "sha256", + "hash": "07f7a7d0f388028b2df1d916e94bbb40624c59b48ecc6cbc232546706fac74c2", + "url": "https://files.pythonhosted.org/packages/39/4d/34599a47c8a41b3ea4986e14f728c293a8a96cd6c23663fe33657c607d34/wrapt-1.14.1-cp310-cp310-macosx_11_0_arm64.whl" + }, + { + "algorithm": "sha256", + "hash": "8ad85f7f4e20964db4daadcab70b47ab05c7c1cf2a7c1e51087bfaa83831854c", + "url": 
"https://files.pythonhosted.org/packages/40/f4/7be7124a06c14b92be53912f93c8dc84247f1cb93b4003bed460a430d1de/wrapt-1.14.1-cp310-cp310-musllinux_1_1_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "a9a52172be0b5aae932bef82a79ec0a0ce87288c7d132946d645eba03f0ad8a8", + "url": "https://files.pythonhosted.org/packages/4f/83/2669bf2cb4cc2b346c40799478d29749ccd17078cb4f69b4a9f95921ff6d/wrapt-1.14.1-cp310-cp310-win32.whl" + }, + { + "algorithm": "sha256", + "hash": "02b41b633c6261feff8ddd8d11c711df6842aba629fdd3da10249a53211a72c4", + "url": "https://files.pythonhosted.org/packages/50/d5/bf619c4d204fe8888460f65222b465c7ecfa43590fdb31864fe0e266da29/wrapt-1.14.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" + }, + { + "algorithm": "sha256", + "hash": "2fe803deacd09a233e4762a1adcea5db5d31e6be577a43352936179d14d90069", + "url": "https://files.pythonhosted.org/packages/94/56/fd707fb8e1ea86e72503d823549fb002a0f16cb4909619748996daeb3a82/wrapt-1.14.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl" + }, + { + "algorithm": "sha256", + "hash": "11871514607b15cfeb87c547a49bca19fde402f32e2b1c24a632506c0a756656", + "url": "https://files.pythonhosted.org/packages/cd/ec/383d9552df0641e9915454b03139571e0c6e055f5d414d8f3d04f3892f38/wrapt-1.14.1-cp310-cp310-musllinux_1_1_i686.whl" + }, + { + "algorithm": "sha256", + "hash": "80bb5c256f1415f747011dc3604b59bc1f91c6e7150bd7db03b19170ee06b320", + "url": "https://files.pythonhosted.org/packages/f7/92/121147bb2f9ed1aa35a8780c636d5da9c167545f97737f0860b4c6c92086/wrapt-1.14.1-cp310-cp310-macosx_10_9_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "257fd78c513e0fb5cdbe058c27a0624c9884e735bbd131935fd49e9fe719d310", + "url": "https://files.pythonhosted.org/packages/fd/70/8a133c88a394394dd57159083b86a564247399440b63f2da0ad727593570/wrapt-1.14.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl" + } + ], + "project_name": "wrapt", + "requires_dists": [], + "requires_python": "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7", + "version": "1.14.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "c9c6d927e098c2d360695f2e9d38870b2e92e0919be07dbe339aefa32a090265", + "url": "https://files.pythonhosted.org/packages/7c/ad/bf6dfc6521394aa7d0b3ecbdf5e2b272fd1e79d585107869e75f0e283245/yarl-1.7.2-cp310-cp310-win_amd64.whl" + }, + { + "algorithm": "sha256", + "hash": "cff3ba513db55cc6a35076f32c4cdc27032bd075c9faef31fec749e64b45d26c", + "url": "https://files.pythonhosted.org/packages/1a/09/a9b4fc484f562297158ad03f6db123f9e1f39424a969599ca0b6cbe5367f/yarl-1.7.2-cp310-cp310-win32.whl" + }, + { + "algorithm": "sha256", + "hash": "167ab7f64e409e9bdd99333fe8c67b5574a1f0495dcfd905bc7454e766729b9e", + "url": "https://files.pythonhosted.org/packages/48/2d/3992de6e80cacc12b51f3cb690590a5a834f9ac2022c88e9ac0d3b293c77/yarl-1.7.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "f2a8508f7350512434e41065684076f640ecce176d262a7d54f0da41d99c5a95", + "url": "https://files.pythonhosted.org/packages/4e/a5/edfa475dc2138da03cc7561b4fbfb26c2bb18c1f41a99333adb28a9a90e5/yarl-1.7.2-cp310-cp310-macosx_10_9_universal2.whl" + }, + { + "algorithm": "sha256", + "hash": "6152224d0a1eb254f97df3997d79dadd8bb2c1a02ef283dbb34b97d4f8492d23", + "url": 
"https://files.pythonhosted.org/packages/69/4d/a64f3371ff9e599aa738699a539d6391cea226299b28a922900b3e5a2bd1/yarl-1.7.2-cp310-cp310-musllinux_1_1_i686.whl" + }, + { + "algorithm": "sha256", + "hash": "95a1873b6c0dd1c437fb3bb4a4aaa699a48c218ac7ca1e74b0bee0ab16c7d60d", + "url": "https://files.pythonhosted.org/packages/90/6c/23b7bba775522b819b2b6616aa83fd1f4577fea3e7c6ed0a862df1aeb855/yarl-1.7.2-cp310-cp310-musllinux_1_1_aarch64.whl" + }, + { + "algorithm": "sha256", + "hash": "a1d0894f238763717bdcfea74558c94e3bc34aeacd3351d769460c1a586a8b05", + "url": "https://files.pythonhosted.org/packages/94/d3/434dca72103d1280dd3e1281f501fb5e6ad0eb6c18ae92ca8d43fb8c2fa7/yarl-1.7.2-cp310-cp310-macosx_11_0_arm64.whl" + }, + { + "algorithm": "sha256", + "hash": "1d3d5ad8ea96bd6d643d80c7b8d5977b4e2fb1bab6c9da7322616fd26203d125", + "url": "https://files.pythonhosted.org/packages/a9/3a/19cb4d33a7b3e81d2a3663803c59a7365bf4694077823c3d1ff2f82a2481/yarl-1.7.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl" + }, + { + "algorithm": "sha256", + "hash": "da6df107b9ccfe52d3a48165e48d72db0eca3e3029b5b8cb4fe6ee3cb870ba8b", + "url": "https://files.pythonhosted.org/packages/b8/43/bd158143b6facbd309fd0b10a21b9546f455db6f851be6911e6b25c40c47/yarl-1.7.2-cp310-cp310-macosx_10_9_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "f44477ae29025d8ea87ec308539f95963ffdc31a82f42ca9deecf2d505242e72", + "url": "https://files.pythonhosted.org/packages/bc/4a/a6f020c4be2654bf8d375731fcacfdcfd1d2f5fd0c48c8dfebb6ec14a84b/yarl-1.7.2-cp310-cp310-musllinux_1_1_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "1ca56f002eaf7998b5fcf73b2421790da9d2586331805f38acd9997743114e98", + "url": "https://files.pythonhosted.org/packages/d0/5f/0410c8c038e626b8732db53bf7ca2b5deb2b1ac8b4a4659763890a61a43c/yarl-1.7.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl" + }, + { + "algorithm": "sha256", + "hash": "5bb7d54b8f61ba6eee541fba4b83d22b8a046b4ef4d8eb7f15a7e35db2e1e245", + "url": "https://files.pythonhosted.org/packages/d8/71/c3b593ccef94111a41aed0cf068be3a5f0e331eb1ff9ea538d21b523e6f4/yarl-1.7.2-cp310-cp310-musllinux_1_1_ppc64le.whl" + }, + { + "algorithm": "sha256", + "hash": "dfe4b95b7e00c6635a72e2d00b478e8a28bfb122dc76349a06e20792eb53a523", + "url": "https://files.pythonhosted.org/packages/db/c7/6f0ae227ea247012055daf4856a8cd85d690f0b18480c54da0b919d2beba/yarl-1.7.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" + }, + { + "algorithm": "sha256", + "hash": "c145ab54702334c42237a6c6c4cc08703b6aa9b94e2f227ceb3d477d20c36c63", + "url": "https://files.pythonhosted.org/packages/e8/ce/920cebfb0fef407eae4d21b37be949d9c4e47671bb9d7271dd8203cd55d8/yarl-1.7.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl" + }, + { + "algorithm": "sha256", + "hash": "9c1f083e7e71b2dd01f7cd7434a5f88c15213194df38bc29b388ccdf1492b739", + "url": "https://files.pythonhosted.org/packages/f2/0b/b897521eb6367f97f452bb6313d99e3653f93e5e62b53c60c865c4bc23b0/yarl-1.7.2-cp310-cp310-musllinux_1_1_s390x.whl" + }, + { + "algorithm": "sha256", + "hash": "45399b46d60c253327a460e99856752009fcee5f5d3c80b2f7c0cae1c38d56dd", + "url": "https://files.pythonhosted.org/packages/f6/da/46d1b3d69a9a0835dabf9d59c7eb0f1600599edd421a4c5a15ab09f527e0/yarl-1.7.2.tar.gz" + } + ], + "project_name": "yarl", + "requires_dists": [ + "idna>=2.0", + "multidict>=4.0", + "typing-extensions>=3.7.4; python_version < \"3.8\"" + ], + "requires_python": ">=3.6", + "version": "1.7.2" + }, + { + "artifacts": [ + { + "algorithm": 
"sha256", + "hash": "0662eb3ebe764fa168a5883cd8819ef83b94bd9e39955537188459d2264a7f60", + "url": "https://files.pythonhosted.org/packages/81/f3/d7b4c8c9b6657ff0db27b739894ed0665fa8f3c78a7452bf74d6447f6865/zipstream_new-1.1.8-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "b031fe181b94e51678389d26b174bc76382605a078d7d5d8f5beae083f111c76", + "url": "https://files.pythonhosted.org/packages/e5/f3/1b5228576f215b200c7e922a280a92e4494df33baae6e0280a6f45371f13/zipstream-new-1.1.8.tar.gz" + } + ], + "project_name": "zipstream-new", + "requires_dists": [], + "requires_python": null, + "version": "1.1.8" + } + ], + "platform_tag": [ + "cp310", + "cp310", + "manylinux_2_31_aarch64" + ] + } + ], + "path_mappings": {}, + "pex_version": "2.1.84", + "prefer_older_binary": false, + "requirements": [ + "Jinja2~=3.0.1", + "PyJWT~=2.0", + "PyYAML~=5.4.1", + "SQLAlchemy[postgresql_asyncpg]~=1.4.29", + "aiodataloader-ng~=0.2.1", + "aiodns>=3.0", + "aiodocker~=0.21.0", + "aiofiles~=0.8.0", + "aiohttp_cors~=0.7", + "aiohttp_session[aioredis]~=2.11", + "aiohttp_sse>=2.0", + "aiohttp~=3.8.1", + "aiomonitor~=0.4.5", + "aioredis[hiredis]~=2.0.1", + "aiosqlite~=0.17.0", + "aiotools~=1.5.9", + "aiotusclient~=0.1.4", + "alembic~=1.7.7", + "appdirs~=1.4.4", + "async_timeout~=4.0", + "asyncudp>=0.4", + "attrs>=20.3", + "backend.ai-krunner-alpine~=3.3", + "backend.ai-krunner-static-gnu~=2.0", + "cachetools~=4.1.1", + "callosum~=0.9.10", + "click>=7.1.2", + "colorama>=0.4.4", + "coloredlogs~=15.0", + "cryptography>=2.8", + "etcetra~=0.1.6", + "graphene~=2.1.9", + "humanize>=3.1.0", + "janus>=0.6.1", + "kubernetes-asyncio~=9.1.0", + "kubernetes~=10.0.0", + "lark-parser~=0.11.3", + "more-itertools~=8.12.0", + "msgpack>=1.0.0", + "netifaces~=0.11.0", + "packaging>=21.3", + "passlib[bcrypt]>=1.7.4", + "pexpect~=4.8", + "psutil~=5.8.0", + "psycopg2-binary>=2.8.4", + "pytest-dependency>=0.5.1", + "pytest~=7.1", + "python-dateutil>=2.8", + "python-json-logger>=2.0.1", + "python-snappy~=0.6.0", + "pyzmq~=22.1.0", + "redis[hiredis]~=4.3.1", + "rich~=12.2", + "setproctitle~=1.2.2", + "tabulate~=0.8.9", + "tblib~=1.7", + "tenacity>=8.0", + "toml>=0.10.2; python_version <= \"3.11\"", + "tomlkit~=0.8.0", + "tqdm>=4.61", + "trafaret~=2.1", + "typeguard~=2.10", + "types-Jinja2", + "types-PyYAML", + "types-aiofiles", + "types-cachetools", + "types-click", + "types-python-dateutil", + "types-setuptools", + "types-six", + "types-tabulate", + "types-toml", + "typing_extensions~=4.1.1", + "uvloop>=0.16", + "yarl>=1.7", + "zipstream-new~=1.1.8" + ], + "requires_python": [ + "==3.10.4" + ], + "resolver_version": "pip-2020-resolver", + "style": "universal", + "transitive": true, + "use_pep517": null +} \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000..0525ad6576 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,79 @@ +aiodataloader-ng~=0.2.1 +aiodocker~=0.21.0 +aiofiles~=0.8.0 +aiohttp~=3.8.1 +aiohttp_cors~=0.7 +aiohttp_sse>=2.0 +aiohttp_session[aioredis]~=2.11 +aiodns>=3.0 +aiomonitor~=0.4.5 +aioredis[hiredis]~=2.0.1 +aiosqlite~=0.17.0 +aiotools~=1.5.9 +aiotusclient~=0.1.4 +alembic~=1.7.7 +appdirs~=1.4.4 +async_timeout~=4.0 +asyncudp>=0.4 +attrs>=20.3 +cachetools~=4.1.1 +callosum~=0.9.10 +click>=7.1.2 +coloredlogs~=15.0 +colorama>=0.4.4 +cryptography>=2.8 +etcetra~=0.1.6 +humanize>=3.1.0 +graphene~=2.1.9 +janus>=0.6.1 +Jinja2~=3.0.1 +kubernetes~=10.0.0 +kubernetes-asyncio~=9.1.0 +lark-parser~=0.11.3 +more-itertools~=8.12.0 +msgpack>=1.0.0 +netifaces~=0.11.0 
+passlib[bcrypt]>=1.7.4 +pexpect~=4.8 +psutil~=5.8.0 +psycopg2-binary>=2.8.4 +pytest~=7.1 +pytest-dependency>=0.5.1 +python-dateutil>=2.8 +python-snappy~=0.6.0 +python-json-logger>=2.0.1 +pyzmq~=22.1.0 +PyJWT~=2.0 +PyYAML~=5.4.1 +packaging>=21.3 +redis[hiredis]~=4.3.1 +rich~=12.2 +SQLAlchemy[postgresql_asyncpg]~=1.4.29 +setproctitle~=1.2.2 +tabulate~=0.8.9 +tblib~=1.7 +tenacity>=8.0 +toml>=0.10.2 ; python_version<='3.11' +tomlkit~=0.8.0 +tqdm>=4.61 +trafaret~=2.1 +typeguard~=2.10 +typing_extensions~=4.1.1 +uvloop>=0.16 +yarl>=1.7 +zipstream-new~=1.1.8 + +# type stubs +types-aiofiles +types-click +types-cachetools +types-Jinja2 +types-PyYAML +types-python-dateutil +types-setuptools +types-six +types-tabulate +types-toml + +backend.ai-krunner-alpine~=3.3 +backend.ai-krunner-static-gnu~=2.0 diff --git a/requirements/build.txt b/requirements/build.txt deleted file mode 100644 index b7a3c3f3d1..0000000000 --- a/requirements/build.txt +++ /dev/null @@ -1 +0,0 @@ --e .[build] diff --git a/requirements/dev.txt b/requirements/dev.txt deleted file mode 100644 index 12b4028a7d..0000000000 --- a/requirements/dev.txt +++ /dev/null @@ -1 +0,0 @@ --e .[dev,build,test,lint,typecheck,docs] diff --git a/requirements/docs.txt b/requirements/docs.txt deleted file mode 100644 index a6b9c53b9b..0000000000 --- a/requirements/docs.txt +++ /dev/null @@ -1,2 +0,0 @@ -git+https://github.com/lablup/pygments-graphql-lexer --e .[docs] diff --git a/requirements/lint.txt b/requirements/lint.txt deleted file mode 100644 index 1d4464c246..0000000000 --- a/requirements/lint.txt +++ /dev/null @@ -1 +0,0 @@ --e .[lint] diff --git a/requirements/main.txt b/requirements/main.txt deleted file mode 100644 index d6e1198b1a..0000000000 --- a/requirements/main.txt +++ /dev/null @@ -1 +0,0 @@ --e . diff --git a/requirements/test.txt b/requirements/test.txt deleted file mode 100644 index a4af5aa107..0000000000 --- a/requirements/test.txt +++ /dev/null @@ -1 +0,0 @@ --e .[test] diff --git a/requirements/typecheck.txt b/requirements/typecheck.txt deleted file mode 100644 index 2059180232..0000000000 --- a/requirements/typecheck.txt +++ /dev/null @@ -1 +0,0 @@ --e .[typecheck] diff --git a/scripts/agent/build-dropbear.sh b/scripts/agent/build-dropbear.sh new file mode 100755 index 0000000000..08475e99c5 --- /dev/null +++ b/scripts/agent/build-dropbear.sh @@ -0,0 +1,79 @@ +#! /bin/bash +set -e + +arch=$(uname -m) +distros=("glibc" "musl") + +glibc_builder_dockerfile=$(cat <<'EOF' +FROM ubuntu:20.04 +RUN apt-get update +RUN apt-get install -y make gcc +RUN apt-get install -y autoconf automake zlib1g-dev +EOF +) + +musl_builder_dockerfile=$(cat <<'EOF' +FROM alpine:3.8 +RUN apk add --no-cache make gcc musl-dev +RUN apk add --no-cache autoconf automake zlib-dev +EOF +) + +build_script=$(cat <<'EOF' +#! /bin/sh +set -e +cd dropbear +autoreconf +./configure --enable-static --prefix=/opt/kernel + +# Improve SFTP up/download throughputs. 
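+# The sed edits below raise DEFAULT_RECV_WINDOW to 2 MiB (2097152 bytes) and the
+# RECV/TRANS_MAX_PAYLOAD_LEN limits to 2.5 MiB (2621440 bytes), and patch
+# svr-chansession.c so the server inherits PATH from the environment instead of
+# dropbear's compiled-in DEFAULT_PATH.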
+# FIXME: Temporarily falling back to the default to avoid PyCharm compatibility issue
+sed -i 's/\(DEFAULT_RECV_WINDOW\) [0-9][0-9]*/\1 2097152/' default_options.h
+sed -i 's/\(RECV_MAX_PAYLOAD_LEN\) [0-9][0-9]*/\1 2621440/' default_options.h
+sed -i 's/\(TRANS_MAX_PAYLOAD_LEN\) [0-9][0-9]*/\1 2621440/' default_options.h
+sed -i 's/DEFAULT_PATH/getenv("PATH")/' svr-chansession.c
+
+# Disable clearing environment variables for new pty sessions and remote commands
+sed -i 's%/\* *#define \+DEBUG_VALGRIND *\*/%#define DEBUG_VALGRIND%' debug.h
+
+make
+cp dropbear ../dropbear.$X_DISTRO.$X_ARCH.bin
+cp dropbearkey ../dropbearkey.$X_DISTRO.$X_ARCH.bin
+cp dropbearconvert ../dropbearconvert.$X_DISTRO.$X_ARCH.bin
+make clean
+EOF
+)
+
+SCRIPT_DIR=$(cd `dirname "${BASH_SOURCE[0]}"` && pwd)
+temp_dir=$(mktemp -d -t dropbear-build.XXXXX)
+echo "Using temp directory: $temp_dir"
+echo "$build_script" > "$temp_dir/build.sh"
+chmod +x $temp_dir/*.sh
+echo "$glibc_builder_dockerfile" > "$SCRIPT_DIR/dropbear-builder.glibc.dockerfile"
+echo "$musl_builder_dockerfile" > "$SCRIPT_DIR/dropbear-builder.musl.dockerfile"
+
+for distro in "${distros[@]}"; do
+ docker build -t dropbear-builder:$distro \
+ -f $SCRIPT_DIR/dropbear-builder.$distro.dockerfile $SCRIPT_DIR
+done
+
+cd "$temp_dir"
+git clone -c advice.detachedHead=false --branch "DROPBEAR_2020.81" https://github.com/mkj/dropbear dropbear
+
+for distro in "${distros[@]}"; do
+ docker run --rm -it \
+ -e X_DISTRO=$distro \
+ -e X_ARCH=$arch \
+ -u $(id -u):$(id -g) \
+ -w /workspace \
+ -v $temp_dir:/workspace \
+ dropbear-builder:$distro \
+ /workspace/build.sh
+done
+
+ls -l .
+cp dropbear.*.bin $SCRIPT_DIR/../src/ai/backend/runner
+cp dropbearkey.*.bin $SCRIPT_DIR/../src/ai/backend/runner
+cp dropbearconvert.*.bin $SCRIPT_DIR/../src/ai/backend/runner
+
+rm -rf "$temp_dir"
diff --git a/scripts/agent/build-krunner-extractor.sh b/scripts/agent/build-krunner-extractor.sh
new file mode 100755
index 0000000000..67baef96d6
--- /dev/null
+++ b/scripts/agent/build-krunner-extractor.sh
@@ -0,0 +1,6 @@
+#! /bin/sh
+
+# IMPORTANT: this must be executed at the repository root.
+
+docker build -f src/ai/backend/runner/krunner-extractor.dockerfile -t backendai-krunner-extractor:latest src/ai/backend/runner
+docker save backendai-krunner-extractor:latest | xz > src/ai/backend/runner/krunner-extractor.img.tar.xz
diff --git a/scripts/agent/build-sftpserver.sh b/scripts/agent/build-sftpserver.sh
new file mode 100755
index 0000000000..8d817a040c
--- /dev/null
+++ b/scripts/agent/build-sftpserver.sh
@@ -0,0 +1,136 @@
+#!
/bin/bash +set -e + +arch=$(uname -m) +distros=("ubuntu16.04" "ubuntu18.04" "ubuntu20.04" "centos7.6" "alpine3.8") + +static_libs_dockerfile_part=$(cat <<'EOF' +ENV ZLIB_VER=1.2.11 \ + SSL_VER=1.1.1i + +RUN wget https://www.zlib.net/zlib-${ZLIB_VER}.tar.gz -O /root/zlib-${ZLIB_VER}.tar.gz && \ + wget https://www.openssl.org/source/openssl-${SSL_VER}.tar.gz -O /root/openssl-${SSL_VER}.tar.gz + +RUN cd /root && \ + tar xzvf zlib-${ZLIB_VER}.tar.gz && \ + tar xzvf openssl-${SSL_VER}.tar.gz + +RUN echo "BUILD: zlib" && \ + cd /root/zlib-${ZLIB_VER} && \ + ./configure --prefix=/usr/local --static && \ + make && \ + make install + +RUN echo "BUILD: OpenSSL" && \ + cd /root/openssl-${SSL_VER} && \ + ./config --prefix=/usr no-shared --openssldir=/usr/local/openssl && \ + make && \ + make install +EOF +) + +ubuntu1604_builder_dockerfile=$(cat <<'EOF' +FROM ubuntu:16.04 +RUN apt-get update +RUN apt-get install -y make gcc +RUN apt-get install -y autoconf +RUN apt-get install -y wget +# below required for sys/mman.h +RUN apt-get install -y libc6-dev +EOF +) + +ubuntu1804_builder_dockerfile=$(cat <<'EOF' +FROM ubuntu:18.04 +RUN apt-get update +RUN apt-get install -y make gcc +RUN apt-get install -y autoconf +RUN apt-get install -y wget +# below required for sys/mman.h +RUN apt-get install -y libc6-dev +EOF +) + +ubuntu2004_builder_dockerfile=$(cat <<'EOF' +FROM ubuntu:20.04 +RUN apt-get update +RUN apt-get install -y make gcc +RUN apt-get install -y autoconf +RUN apt-get install -y wget +# below required for sys/mman.h +RUN apt-get install -y libc6-dev +EOF +) + +centos_builder_dockerfile=$(cat <<'EOF' +FROM centos:7 +RUN yum install -y make gcc +RUN yum install -y autoconf +RUN yum install -y wget +EOF +) + +alpine_builder_dockerfile=$(cat <<'EOF' +FROM alpine:3.8 +RUN apk add --no-cache make gcc musl-dev +RUN apk add --no-cache autoconf +RUN apk add --no-cache wget +# below required for sys/mman.h +RUN apk add --no-cache linux-headers +EOF +) + +build_script=$(cat <<'EOF' +#! 
/bin/sh +echo "BUILD: OpenSSH" +cd /workspace/openssh-portable +autoreconf +export LDFLAGS="-L/root/zlib-${ZLIB_VER} -L/root/openssl-${SSL_VER} -pthread" +export LIBS="-ldl" +sed -i "s/-lcrypto/-l:libcrypto.a/" ./configure +sed -i "s/-lz/-l:libz.a/" ./configure +./configure --prefix=/usr +sed -i 's/^# \?define SFTP_MAX_MSG_LENGTH[ \t]*.*/#define SFTP_MAX_MSG_LENGTH 5242880/g' sftp-common.h +make sftp-server scp +cp sftp-server ../sftp-server.$X_DISTRO.$X_ARCH.bin +cp scp ../scp.$X_DISTRO.$X_ARCH.bin +make clean +EOF +) + +SCRIPT_DIR=$(cd `dirname "${BASH_SOURCE[0]}"` && pwd) +temp_dir=$(mktemp -d -t sftpserver-build.XXXXX) +echo "Using temp directory: $temp_dir" +echo "$build_script" > "$temp_dir/build.sh" +chmod +x $temp_dir/*.sh +echo -e "$ubuntu1604_builder_dockerfile\n$static_libs_dockerfile_part" > "$SCRIPT_DIR/sftpserver-builder.ubuntu16.04.dockerfile" +echo -e "$ubuntu1804_builder_dockerfile\n$static_libs_dockerfile_part" > "$SCRIPT_DIR/sftpserver-builder.ubuntu18.04.dockerfile" +echo -e "$ubuntu2004_builder_dockerfile\n$static_libs_dockerfile_part" > "$SCRIPT_DIR/sftpserver-builder.ubuntu20.04.dockerfile" +echo -e "$centos_builder_dockerfile\n$static_libs_dockerfile_part" > "$SCRIPT_DIR/sftpserver-builder.centos7.6.dockerfile" +echo -e "$alpine_builder_dockerfile\n$static_libs_dockerfile_part" > "$SCRIPT_DIR/sftpserver-builder.alpine3.8.dockerfile" + +for distro in "${distros[@]}"; do + docker build -t sftpserver-builder:$distro \ + -f $SCRIPT_DIR/sftpserver-builder.$distro.dockerfile $SCRIPT_DIR +done + +cd "$temp_dir" +git clone -c advice.detachedHead=false --branch "V_8_1_P1" https://github.com/openssh/openssh-portable openssh-portable + +for distro in "${distros[@]}"; do + docker run --rm -it \ + -e X_DISTRO=$distro \ + -e X_ARCH=$arch \ + -u $(id -u):$(id -g) \ + -w /workspace \ + -v $temp_dir:/workspace \ + sftpserver-builder:$distro \ + /workspace/build.sh +done + +ls -l . +cp sftp-server.*.bin $SCRIPT_DIR/../src/ai/backend/runner +cp scp.*.bin $SCRIPT_DIR/../src/ai/backend/runner + +cd $SCRIPT_DIR/.. +rm -rf "$temp_dir" diff --git a/scripts/agent/build-socket-relay.sh b/scripts/agent/build-socket-relay.sh new file mode 100755 index 0000000000..46ed971a39 --- /dev/null +++ b/scripts/agent/build-socket-relay.sh @@ -0,0 +1,5 @@ +#! /bin/bash +IMG="backendai-socket-relay:latest" +docker build -t "$IMG" docker/socket-relay +docker image inspect "$IMG" | jq -r '.[0].ContainerConfig.Labels."ai.backend.version"' > src/ai/backend/agent/docker/backendai-socket-relay.version.txt +docker save "$IMG" | gzip > src/ai/backend/agent/docker/backendai-socket-relay.img.tar.gz diff --git a/scripts/agent/build-suexec.sh b/scripts/agent/build-suexec.sh new file mode 100755 index 0000000000..e67c80a4c9 --- /dev/null +++ b/scripts/agent/build-suexec.sh @@ -0,0 +1,83 @@ +#! /bin/bash +set -e + +arch=$(uname -m) +distros=("ubuntu16.04" "ubuntu18.04" "ubuntu20.04" "centos7.6" "alpine3.8") + +ubuntu1604_builder_dockerfile=$(cat <<'EOF' +FROM ubuntu:16.04 +RUN apt-get update +RUN apt-get install -y make gcc +EOF +) + +ubuntu1804_builder_dockerfile=$(cat <<'EOF' +FROM ubuntu:18.04 +RUN apt-get update +RUN apt-get install -y make gcc +EOF +) + +ubuntu2004_builder_dockerfile=$(cat <<'EOF' +FROM ubuntu:20.04 +RUN apt-get update +RUN apt-get install -y make gcc +EOF +) + +centos_builder_dockerfile=$(cat <<'EOF' +FROM centos:7 +RUN yum install -y make gcc +EOF +) + +alpine_builder_dockerfile=$(cat <<'EOF' +FROM alpine:3.8 +RUN apk add --no-cache make gcc musl-dev +EOF +) + +build_script=$(cat <<'EOF' +#! 
/bin/sh +set -e +cd su-exec +make +cp su-exec ../su-exec.$X_DISTRO.$X_ARCH.bin +make clean +EOF +) + +SCRIPT_DIR=$(cd `dirname "${BASH_SOURCE[0]}"` && pwd) +temp_dir=$(mktemp -d -t suexec-build.XXXXX) +echo "Using temp directory: $temp_dir" +echo "$build_script" > "$temp_dir/build.sh" +chmod +x $temp_dir/*.sh +echo "$ubuntu1604_builder_dockerfile" > "$SCRIPT_DIR/suexec-builder.ubuntu16.04.dockerfile" +echo "$ubuntu1804_builder_dockerfile" > "$SCRIPT_DIR/suexec-builder.ubuntu18.04.dockerfile" +echo "$ubuntu2004_builder_dockerfile" > "$SCRIPT_DIR/suexec-builder.ubuntu20.04.dockerfile" +echo "$centos_builder_dockerfile" > "$SCRIPT_DIR/suexec-builder.centos7.6.dockerfile" +echo "$alpine_builder_dockerfile" > "$SCRIPT_DIR/suexec-builder.alpine3.8.dockerfile" + +for distro in "${distros[@]}"; do + docker build -t suexec-builder:$distro \ + -f $SCRIPT_DIR/suexec-builder.$distro.dockerfile $SCRIPT_DIR +done + +cd "$temp_dir" +git clone -c advice.detachedHead=false https://github.com/ncopa/su-exec su-exec + +for distro in "${distros[@]}"; do + docker run --rm -it \ + -e X_DISTRO=$distro \ + -e X_ARCH=$arch \ + -u $(id -u):$(id -g) \ + -w /workspace \ + -v $temp_dir:/workspace \ + suexec-builder:$distro \ + /workspace/build.sh +done + +ls -l . +cp su-exec.*.bin $SCRIPT_DIR/../src/ai/backend/runner + +rm -rf "$temp_dir" diff --git a/scripts/agent/build-tmux.sh b/scripts/agent/build-tmux.sh new file mode 100755 index 0000000000..aa1330789b --- /dev/null +++ b/scripts/agent/build-tmux.sh @@ -0,0 +1,92 @@ +#! /bin/bash +set -e + +arch=$(uname -m) +distros=("glibc" "musl") + +glibc_builder_dockerfile=$(cat <<'EOF' +FROM ubuntu:20.04 +ENV DEBIAN_FRONTEND=noninteractive +RUN apt-get update +RUN apt-get install -y make gcc g++ bison flex +RUN apt-get install -y pkg-config +EOF +) + +musl_builder_dockerfile=$(cat <<'EOF' +FROM alpine:3.8 +RUN apk add --no-cache make gcc g++ musl-dev file bison flex +RUN apk add --no-cache pkgconfig +EOF +) + +build_script=$(cat <<'EOF' +#! /bin/sh +set -e +TARGETDIR=$PWD/build +mkdir -p $TARGETDIR + +cd libevent-2.0.22-stable +./configure --prefix=$TARGETDIR --disable-openssl --enable-shared=no --enable-static=yes --with-pic && make && make install +make clean +cd .. +cd ncurses-6.0 + +CPPFLAGS="-P" ./configure --prefix $TARGETDIR \ + --with-default-terminfo-dir=/usr/share/terminfo \ + --with-terminfo-dirs="/etc/terminfo:/lib/terminfo:/usr/share/terminfo" \ + --enable-pc-files \ + --with-pkg-config-libdir=$TARGETDIR/lib/pkgconfig \ +&& make && make install +make clean +cd .. +cd tmux-3.0a +PKG_CONFIG_PATH=$TARGETDIR/lib/pkgconfig \ + CFLAGS="-I$TARGETDIR/include/event2 -I$TARGETDIR/include/ncurses" \ + LDFLAGS="-L$TARGETDIR/lib" \ + ./configure --enable-static --prefix=$TARGETDIR && make && make install +make clean +cd .. 
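+# Copy the statically built tmux binary out of the temporary build prefix,
+# naming it with the libc flavor and CPU architecture passed in via X_DISTRO
+# and X_ARCH, then discard the intermediate prefix.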
+ +cp $TARGETDIR/bin/tmux tmux.$X_DISTRO.$X_ARCH.bin +rm -rf $TARGETDIR + +EOF +) + +SCRIPT_DIR=$(cd `dirname "${BASH_SOURCE[0]}"` && pwd) +temp_dir=$(mktemp -d -t tmux-build.XXXXX) +echo "Using temp directory: $temp_dir" +echo "$build_script" > "$temp_dir/build.sh" +chmod +x $temp_dir/*.sh +echo "$glibc_builder_dockerfile" > "$SCRIPT_DIR/tmux-builder.glibc.dockerfile" +echo "$musl_builder_dockerfile" > "$SCRIPT_DIR/tmux-builder.musl.dockerfile" + +for distro in "${distros[@]}"; do + docker build -t tmux-builder:$distro \ + -f $SCRIPT_DIR/tmux-builder.$distro.dockerfile $SCRIPT_DIR +done + +cd "$temp_dir" + +curl -LO https://github.com/libevent/libevent/releases/download/release-2.0.22-stable/libevent-2.0.22-stable.tar.gz +tar -zxvf libevent-2.0.22-stable.tar.gz +curl -LO https://mirror.yongbok.net/gnu/ncurses/ncurses-6.0.tar.gz +tar zxvf ncurses-6.0.tar.gz +curl -LO https://github.com/tmux/tmux/releases/download/3.0a/tmux-3.0a.tar.gz +tar zxvf tmux-3.0a.tar.gz + +for distro in "${distros[@]}"; do + docker run --rm -it \ + -e X_DISTRO=$distro \ + -e X_ARCH=$arch \ + -w /workspace \ + -v $temp_dir:/workspace \ + tmux-builder:$distro \ + /workspace/build.sh +done + +ls -l . +cp tmux.*.bin $SCRIPT_DIR/../src/ai/backend/runner + +rm -rf "$temp_dir" diff --git a/scripts/agent/ci/deploy.sh b/scripts/agent/ci/deploy.sh new file mode 100755 index 0000000000..5393a78fec --- /dev/null +++ b/scripts/agent/ci/deploy.sh @@ -0,0 +1,6 @@ +#! /bin/bash +set -ev + +pip install --user -U twine setuptools wheel +python setup.py sdist bdist_wheel +twine upload dist/* diff --git a/scripts/agent/ci/install-manager.sh b/scripts/agent/ci/install-manager.sh new file mode 100755 index 0000000000..f2e3ffd95c --- /dev/null +++ b/scripts/agent/ci/install-manager.sh @@ -0,0 +1,17 @@ +#! /bin/bash +set -ev + +cd ${HOME}/build/lablup +git clone https://github.com/lablup/backend.ai-manager.git +cd backend.ai-manager +python -m venv --system-site-packages ~/virtualenv/manager +set +v +source ~/virtualenv/manager/bin/activate +set -v + +pip install -U pip setuptools +sed -i'' -e "s/{BRANCH}/$BRANCH/g" requirements-ci.txt +pip install -U --upgrade-strategy=eager -r requirements-ci.txt +psql -c 'CREATE DATABASE testing;' -U postgres +cp alembic.ini.sample alembic.ini +sed -i'' -e 's!^sqlalchemy.url = .*$!sqlalchemy.url = postgresql://postgres@localhost:5432/testing!' alembic.ini diff --git a/scripts/agent/ci/overwrite-python.sh b/scripts/agent/ci/overwrite-python.sh new file mode 100755 index 0000000000..cdf4aa3b0e --- /dev/null +++ b/scripts/agent/ci/overwrite-python.sh @@ -0,0 +1,15 @@ +#! /bin/bash +set -ev + +PYTHON_MAJOR=$(echo "$PYTHON_VERSION" | awk -F \. '{print $1}') +docker create --name pybin python:$PYTHON_VERSION +sudo docker cp "pybin:/usr/local/bin" /usr/local +sudo docker cp "pybin:/usr/local/lib/python${PYTHON_VERSION}" "/usr/local/lib/python${PYTHON_VERSION}" +sudo docker cp "pybin:/usr/local/lib/libpython${PYTHON_VERSION}m.so" /usr/local/lib +sudo docker cp "pybin:/usr/local/lib/libpython${PYTHON_VERSION}m.so.1.0" /usr/local/lib +sudo docker cp "pybin:/usr/local/lib/libpython${PYTHON_MAJOR}.so" /usr/local/lib +sudo docker cp "pybin:/usr/local/include/python${PYTHON_VERSION}m" /usr/local/include +docker rm pybin +sudo ldconfig +rm -rf ~/virtualenv/python* +python --version diff --git a/scripts/agent/ci/run-manager.sh b/scripts/agent/ci/run-manager.sh new file mode 100755 index 0000000000..af1caf0f74 --- /dev/null +++ b/scripts/agent/ci/run-manager.sh @@ -0,0 +1,21 @@ +#! 
/bin/bash +set -e +source ~/virtualenv/manager/bin/activate + +set -v +cd ${HOME}/build/lablup/backend.ai-manager +python -c 'import sys; print(sys.prefix)' +python -m ai.backend.manager.cli schema oneshot head +python -m ai.backend.manager.cli fixture populate example_keypair +python -m ai.backend.manager.cli etcd put volumes/_mount /tmp/vfolders +python -m ai.backend.gateway.server \ + --etcd-addr ${BACKEND_ETCD_ADDR} \ + --namespace ${BACKEND_NAMESPACE} \ + --redis-addr ${BACKEND_REDIS_ADDR} \ + --db-addr ${BACKEND_DB_ADDR} \ + --db-name ${BACKEND_DB_NAME} \ + --db-user ${BACKEND_DB_USER} \ + --db-password "${BACKEND_DB_PASSWORD}" \ + --service-ip 127.0.0.1 \ + --service-port 5001 \ + --events-port 5002 & diff --git a/scripts/agent/deploy-static/deploy_static_files.py b/scripts/agent/deploy-static/deploy_static_files.py new file mode 100644 index 0000000000..5e72ecd737 --- /dev/null +++ b/scripts/agent/deploy-static/deploy_static_files.py @@ -0,0 +1,34 @@ +import paramiko +import sys +import os +import subprocess + +if len(sys.argv) == 1: + print('Usage: python deploy_static_files.py ') + exit(1) + +STATIC_FILE = 'https://backend-ai-k8s-agent-static.s3.ap-northeast-2.amazonaws.com/bai-static.tar.gz' + +ssh = paramiko.SSHClient() +ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + +for ip in sys.argv[1:]: + ssh.connect(ip) + + # Get cwd and username + stdin, stdout, stderr = ssh.exec_command('pwd') + pwd = stdout.readlines()[0].strip().replace('\n', '') + print(f'pwd: {pwd}') + stdin, stdout, stderr = ssh.exec_command('whoami') + whoami = stdout.readlines()[0].strip().replace('\n', '') + print(f'whoami: {whoami}') + + # delete old static files + stdin, stdout, stderr = ssh.exec_command(f'sudo rm -rf /opt/backend.ai && rm -rf {pwd}/bai*') + + print(''.join(stdout.readlines())) + + # Extract to /opt/backend.ai + stdin, stdout, stderr = ssh.exec_command(f'wget https://backend-ai-k8s-agent-static.s3.ap-northeast-2.amazonaws.com/bai-static.tar.gz && tar xvf {pwd}/bai-static.tar.gz && sudo mv {pwd}/backend.ai /opt && sudo chown {whoami}:{whoami} /opt/backend.ai') + + print(''.join(stdout.readlines())) \ No newline at end of file diff --git a/scripts/agent/eks.yml b/scripts/agent/eks.yml new file mode 100644 index 0000000000..89512cee9b --- /dev/null +++ b/scripts/agent/eks.yml @@ -0,0 +1,33 @@ +--- +apiVersion: eksctl.io/v1alpha5 +kind: ClusterConfig + +metadata: + name: bai-agent-k8s-test + region: ap-northeast-2 + +vpc: + id: "vpc-0b13efed1cef527a6" # (optional, must match VPC ID used for each subnet below) + cidr: "192.168.0.0/16" # (optional, must match CIDR used by the given VPC) + subnets: + # must provide 'private' and/or 'public' subnets by availibility zone as shown + private: + ap-northeast-2a: + id: "subnet-05d21f68fee1e6605" + cidr: "192.168.2.0/24" # (optional, must match CIDR used by the given subnet) + ap-northeast-2b: + id: "subnet-02d9a82da50e55d16" + cidr: "192.168.3.0/24" + public: + ap-northeast-2a: + id: "subnet-0c7205f74540a9dda" + cidr: "192.168.1.0/24" + ap-northeast-2b: + id: "subnet-01e0fa8018fb2ab56" + cidr: "192.168.4.0/24" +nodeGroups: + - name: ng-1 + instanceType: t3.medium + desiredCapacity: 2 + privateNetworking: true # if only 'Private' subnets are given, this must be enabled + diff --git a/scripts/agent/run-with-halfstack.sh b/scripts/agent/run-with-halfstack.sh new file mode 100755 index 0000000000..19cc12359a --- /dev/null +++ b/scripts/agent/run-with-halfstack.sh @@ -0,0 +1,5 @@ +#! 
/bin/sh + +export BACKEND_ETCD_ADDR=localhost:8120 + +exec "$@" diff --git a/scripts/agent/update-metadata-iptables.sh b/scripts/agent/update-metadata-iptables.sh new file mode 100755 index 0000000000..d2e1affd0d --- /dev/null +++ b/scripts/agent/update-metadata-iptables.sh @@ -0,0 +1,20 @@ +#! /bin/bash + +IPTABLES_REGEX='REDIRECT\s+tcp\s+\-\-\s+anywhere\s+169\.254\.169\.254\s+tcp dpt:http redir ports 40128' +if [ $(id -u) -ne 0 ]; then + echo "Please run as root." + exit +fi +if [[ $(sudo iptables -t nat -L PREROUTING) =~ $IPTABLES_REGEX ]]; then + echo "iptables rule already set, skipping" +else + sudo iptables -t nat \ + -I PREROUTING \ + -p tcp \ + -d 169.254.169.254 \ + --dport 80 \ + -j REDIRECT \ + --to-ports 40128 \ + -i docker0 + echo "iptables rule updated" +fi diff --git a/scripts/check-docker.py b/scripts/check-docker.py new file mode 100644 index 0000000000..560f5798bd --- /dev/null +++ b/scripts/check-docker.py @@ -0,0 +1,109 @@ +import functools +import json +import re +import subprocess +import sys +from pathlib import Path +from urllib.parse import quote + +import requests +import requests_unixsocket + + +log = functools.partial(print, file=sys.stderr) +run = subprocess.run + + +def parse_version(expr): + result = [] + for part in expr.split('.'): + try: + result.append(int(part)) + except ValueError: + result.append(part) + return tuple(result) + + +def detect_snap_docker(): + if not Path('/run/snapd.socket').is_socket(): + return None + with requests.get("http+unix://%2Frun%2Fsnapd.socket/v2/snaps?names=docker") as r: + if r.status_code != 200: + raise RuntimeError("Failed to query Snapd package information") + response_data = r.json() + for pkg_data in response_data['result']: + if pkg_data['name'] == 'docker': + return pkg_data['version'] + + +def detect_system_docker(): + sock_paths = [ + Path('/run/docker.sock'), # Linux default + Path('/var/run/docker.sock'), # macOS default + ] + for sock_path in sock_paths: + if sock_path.is_socket(): + break + else: + return None + encoded_sock_path = quote(bytes(sock_path), safe='') + with requests.get(f"http+unix://{encoded_sock_path}/version") as r: + if r.status_code != 200: + raise RuntimeError("Failed to query the Docker daemon API") + response_data = r.json() + return response_data['Version'] + + +def fail_with_snap_docker_refresh_request(): + log("Please install Docker 20.10.15 or later from the Snap package index.") + log("Instructions: `sudo snap refresh docker --edge`") + sys.exit(1) + + +def fail_with_system_docker_install_request(): + log("Please install Docker for your system.") + log("Instructions: https://docs.docker.com/install/") + sys.exit(1) + + +def fail_with_compose_install_request(): + log("Please install docker-compose v2 or later.") + log("Instructions: https://docs.docker.com/compose/install/") + sys.exit(1) + + +def main(): + requests_unixsocket.monkeypatch() + + docker_version = detect_snap_docker() + if docker_version is not None: + log(f"Detected Docker installation: Snap package ({docker_version})") + if parse_version(docker_version) < (20, 10, 15): + fail_with_snap_docker_refresh_request() + else: + docker_version = detect_system_docker() + if docker_version is not None: + log(f"Detected Docker installation: System package ({docker_version})") + else: + fail_with_system_docker_install_request() + + try: + proc = run(['docker', 'compose', 'version'], capture_output=True, check=True) + except subprocess.CalledProcessError as e: + fail_with_compose_install_request() + else: + m = 
re.search(r'\d+\.\d+\.\d+', proc.stdout.decode()) + if m is None: + log("Failed to retrieve the docker-compose version!") + sys.exit(1) + else: + compose_version = m.group(0) + log(f"Detected docker-compose installation ({compose_version})") + if parse_version(compose_version) < (2, 0, 0): + fail_with_compose_install_request() + + # Now we can proceed with the given docker & docker-compose installation. + + +if __name__ == '__main__': + main() diff --git a/scripts/delete-dev.sh b/scripts/delete-dev.sh index d3059ef637..edb650a846 100755 --- a/scripts/delete-dev.sh +++ b/scripts/delete-dev.sh @@ -31,13 +31,6 @@ usage() { echo "${LWHITE}OPTIONS${NC}" echo " ${LWHITE}-h, --help${NC} Show this help and exit" echo "" - echo " ${LWHITE}-e, --env ENVID${NC} Set the target environment ID (required)" - echo "" - echo " ${LWHITE}--install-path PATH${NC} Set the target directory when installed in a" - echo " non-default locatin (default: ./backend.ai-dev)" - echo "" - echo " ${LWHITE}--skip-venvs${NC} Skip removal of virtualenvs (default: false)" - echo "" echo " ${LWHITE}--skip-containers${NC} Skip removal of docker resources (default: false)" echo "" echo " ${LWHITE}--skip-source${NC} Skip removal of the install path (default: false)" @@ -87,22 +80,15 @@ else exit 1 fi -ENV_ID="" INSTALL_PATH="./backend.ai-dev" REMOVE_VENVS=1 REMOVE_CONTAINERS=1 -REMOVE_SOURCE=1 while [ $# -gt 0 ]; do case $1 in -h | --help) usage; exit 1 ;; - -e | --env) ENV_ID=$2; shift ;; - --env=*) ENV_ID="${1#*=}" ;; - --install-path) INSTALL_PATH=$2; shift ;; - --install-path=*) INSTALL_PATH="${1#*=}" ;; --skip-venvs) REMOVE_VENVS=0 ;; --skip-containers) REMOVE_CONTAINERS=0 ;; - --skip-source) REMOVE_SOURCE=0 ;; *) echo "Unknown option: $1" echo "Run '$0 --help' for usage." @@ -110,21 +96,10 @@ while [ $# -gt 0 ]; do esac shift 1 done -if [ -z "$ENV_ID" ]; then - echo "You must specify the environment ID (-e/--env option)" - exit 1 -fi -INSTALL_PATH=$(readlinkf "$INSTALL_PATH") if [ $REMOVE_VENVS -eq 1 ]; then - echo "Removing Python virtual environments..." - pyenv uninstall -f "venv-${ENV_ID}-agent" - pyenv uninstall -f "venv-${ENV_ID}-client" - pyenv uninstall -f "venv-${ENV_ID}-common" - pyenv uninstall -f "venv-${ENV_ID}-manager" - pyenv uninstall -f "venv-${ENV_ID}-webserver" - pyenv uninstall -f "venv-${ENV_ID}-storage-proxy" - pyenv uninstall -f "venv-${ENV_ID}-tester" + echo "Removing the unified and temporary venvs..." + rm -rf dist/export pyenv uninstall -f "tmp-grpcio-build" else echo "Skipped removal of Python virtual environments." @@ -132,28 +107,10 @@ fi if [ $REMOVE_CONTAINERS -eq 1 ]; then echo "Removing Docker containers..." - cd "${INSTALL_PATH}/backend.ai" - $docker_sudo $DOCKER_COMPOSE -p "${ENV_ID}" -f "docker-compose.halfstack.${ENV_ID}.yml" down - rm "docker-compose.halfstack.${ENV_ID}.yml" + $docker_sudo $DOCKER_COMPOSE -f "docker-compose.halfstack.current.yml" down + rm "docker-compose.halfstack.current.yml" else echo "Skipped removal of Docker containers." fi -if [ $REMOVE_SOURCE -eq 1 ]; then - echo "Removing cloned source files..." 
- $sudo rm -rf "${INSTALL_PATH}/manager" - $sudo rm -rf "${INSTALL_PATH}/agent" - $sudo rm -rf "${INSTALL_PATH}/common" - $sudo rm -rf "${INSTALL_PATH}/client-py" - $sudo rm -rf "${INSTALL_PATH}/webserver" - $sudo rm -rf "${INSTALL_PATH}/storage-proxy" - $sudo rm -rf "${INSTALL_PATH}/backend.ai" - $sudo rm -rf "${INSTALL_PATH}/vfolder" - $sudo rm -rf "${INSTALL_PATH}/accel-cuda" - $sudo rm -rf "${INSTALL_PATH}/tester" - $sudo rm -rf "${INSTALL_PATH}/wheelhouse" - echo "Please remove ${INSTALL_PATH} by yourself." -else - echo "Skipped removal of cloned source files." -fi echo "Done." diff --git a/scripts/install-dev.sh b/scripts/install-dev.sh index 2c4f767bd9..ee26e5a06e 100755 --- a/scripts/install-dev.sh +++ b/scripts/install-dev.sh @@ -39,61 +39,65 @@ trim() { usage() { echo "${GREEN}Backend.AI Development Setup${NC}: ${CYAN}Auto-installer Tool${NC}" echo "" - echo "${LWHITE}USAGE${NC}" - echo " $0 ${LWHITE}[OPTIONS]${NC}" + echo "Installs the single-node development setup of Backend.AI from this" + echo "semi-mono repository for the server-side components." echo "" - echo "${LWHITE}OPTIONS${NC}" - echo " ${LWHITE}-h, --help${NC} Show this help message and exit" + echo "Changes in 22.06 or later:" echo "" - echo " ${LWHITE}-e, --env ENVID${NC}" - echo " Manually override the environment ID to use" - echo " (default: random-generated)" + echo "* Deprecated '-e/--env', '--install-path', '--python-version' options" + echo " as they are now deprecated because the working-copy directory" + echo " becomes the target installation path and identifies the" + echo " installation". + echo "* '--server-branch' and '--client-branch' is now merged into a single" + echo " '--branch' option." echo "" - echo " ${LWHITE}--python-version VERSION${NC}" - echo " Set the Python version to install via pyenv" - echo " (default: 3.9.10)" - echo "" - echo " ${LWHITE}--install-path PATH${NC} Set the target directory" - echo " (default: ./backend.ai-dev)" + echo "${LWHITE}USAGE${NC}" + echo " $0 ${LWHITE}[OPTIONS]${NC}" echo "" - echo " ${LWHITE}--server-branch NAME${NC}" - echo " The branch of git clones for server components" - echo " (default: main)" + echo "${LWHITE}OPTIONS${NC}" + echo " ${LWHITE}-h, --help${NC}" + echo " Show this help message and exit" echo "" - echo " ${LWHITE}--client-branch NAME${NC}" - echo " The branch of git clones for client components" - echo " (default: main)" + echo " ${LWHITE}--branch NAME${NC}" + echo " The branch of git clones for server components" + echo " (default: main)" echo "" - echo " ${LWHITE}--enable-cuda${NC} Install CUDA accelerator plugin and pull a" - echo " TenosrFlow CUDA kernel for testing/demo." - echo " (default: false)" + echo " ${LWHITE}--enable-cuda${NC}" + echo " Install CUDA accelerator plugin and pull a" + echo " TenosrFlow CUDA kernel for testing/demo." + echo " (default: false)" echo "" - echo " ${LWHITE}--cuda-branch NAME${NC} The branch of git clone for the CUDA accelerator " - echo " plugin; only valid if ${LWHITE}--enable-cuda${NC} is specified." - echo " If set as ${LWHITE}\"mock\"${NC}, it will install the mockup version " - echo " plugin so that you may develop and test CUDA integration " - echo " features without real GPUs." - echo " (default: main)" + echo " ${LWHITE}--cuda-branch NAME${NC}" + echo " The branch of git clone for the CUDA accelerator " + echo " plugin; only valid if ${LWHITE}--enable-cuda${NC} is specified." 
+ echo " If set as ${LWHITE}\"mock\"${NC}, it will install the mockup version " + echo " plugin so that you may develop and test CUDA integration " + echo " features without real GPUs." + echo " (default: main)" echo "" - echo " ${LWHITE}--postgres-port PORT${NC} The port to bind the PostgreSQL container service." - echo " (default: 8100)" + echo " ${LWHITE}--postgres-port PORT${NC}" + echo " The port to bind the PostgreSQL container service." + echo " (default: 8100)" echo "" - echo " ${LWHITE}--redis-port PORT${NC} The port to bind the Redis container service." - echo " (default: 8110)" + echo " ${LWHITE}--redis-port PORT${NC}" + echo " The port to bind the Redis container service." + echo " (default: 8110)" echo "" - echo " ${LWHITE}--etcd-port PORT${NC} The port to bind the etcd container service." - echo " (default: 8120)" + echo " ${LWHITE}--etcd-port PORT${NC}" + echo " The port to bind the etcd container service." + echo " (default: 8120)" echo "" - echo " ${LWHITE}--manager-port PORT${NC} The port to expose the manager API service." - echo " (default: 8081)" + echo " ${LWHITE}--manager-port PORT${NC}" + echo " The port to expose the manager API service." + echo " (default: 8081)" echo "" echo " ${LWHITE}--agent-rpc-port PORT${NC}" - echo " The port for the manager-to-agent RPC calls." - echo " (default: 6001)" + echo " The port for the manager-to-agent RPC calls." + echo " (default: 6001)" echo "" echo " ${LWHITE}--agent-watcher-port PORT${NC}" - echo " The port for the agent's watcher service." - echo " (default: 6009)" + echo " The port for the agent's watcher service." + echo " (default: 6009)" } show_error() { @@ -175,12 +179,12 @@ else exit 1 fi -ROOT_PATH=$(pwd) -ENV_ID="" -PYTHON_VERSION="3.10.4" -SERVER_BRANCH="main" -CLIENT_BRANCH="main" -INSTALL_PATH="./backend.ai-dev" +ROOT_PATH="$(pwd)" +PLUGIN_PATH="${ROOT_PATH}/plugins" +HALFSTACK_VOLUME_PATH="${ROOT_PATH}/volumes" +PANTS_VERSION=$(cat pants.toml | $bpython -c 'import sys,re;m=re.search("pants_version = \"([^\"]+)\"", sys.stdin.read());print(m.group(1) if m else sys.exit(1))') +PYTHON_VERSION=$(cat pants.toml | $bpython -c 'import sys,re;m=re.search("CPython==([^\"]+)", sys.stdin.read());print(m.group(1) if m else sys.exit(1))') +MONO_BRANCH="main" DOWNLOAD_BIG_IMAGES=0 ENABLE_CUDA=0 CUDA_BRANCH="main" @@ -191,7 +195,7 @@ CUDA_BRANCH="main" # WEBSERVER_PORT="8080" # AGENT_RPC_PORT="6001" # AGENT_WATCHER_PORT="6009" -# VFOLDER_REL_PATH="vfolder/local" +# VFOLDER_REL_PATH="vfroot/local" # LOCAL_STORAGE_PROXY="local" # LOCAL_STORAGE_VOLUME="volume1" @@ -202,7 +206,7 @@ MANAGER_PORT="8091" WEBSERVER_PORT="8090" AGENT_RPC_PORT="6011" AGENT_WATCHER_PORT="6019" -VFOLDER_REL_PATH="vfolder/local" +VFOLDER_REL_PATH="vfroot/local" LOCAL_STORAGE_PROXY="local" # MUST be one of the real storage volumes LOCAL_STORAGE_VOLUME="volume1" @@ -210,16 +214,10 @@ LOCAL_STORAGE_VOLUME="volume1" while [ $# -gt 0 ]; do case $1 in -h | --help) usage; exit 1 ;; - -e | --env) ENV_ID=$2; shift ;; - --env=*) ENV_ID="${1#*=}" ;; --python-version) PYTHON_VERSION=$2; shift ;; --python-version=*) PYTHON_VERSION="${1#*=}" ;; - --install-path) INSTALL_PATH=$2; shift ;; - --install-path=*) INSTALL_PATH="${1#*=}" ;; - --server-branch) SERVER_BRANCH=$2; shift ;; - --server-branch=*) SERVER_BRANCH="${1#*=}" ;; - --client-branch) CLIENT_BRANCH=$2; shift ;; - --client-branch=*) CLIENT_BRANCH="${1#*=}" ;; + --branch) MONO_BRANCH=$2; shift ;; + --branch=*) MONO_BRANCH="${1#*=}" ;; --enable-cuda) ENABLE_CUDA=1 ;; --download-big-images) DOWNLOAD_BIG_IMAGES=1 ;; 
--cuda-branch) CUDA_BRANCH=$2; shift ;; @@ -245,7 +243,6 @@ while [ $# -gt 0 ]; do esac shift done -INSTALL_PATH=$(readlinkf "$INSTALL_PATH") install_brew() { case $DISTRO in @@ -332,49 +329,6 @@ install_system_pkg() { esac } -install_docker() { - show_info "Install docker" - case $DISTRO in - Debian) - $sudo apt-get install -y lxcfs - $sudo curl -fsSL https://get.docker.io | bash - $sudo usermod -aG docker $(whoami) - ;; - RedHat) - $sudo curl -fsSL https://get.docker.io | bash - $sudo usermod -aG docker $(whoami) - ;; - Darwin) - show_info "Please install the latest version of docker and try again." - show_info "It should have been installed with Docker Desktop for Mac or Docker Toolbox." - show_info " - Instructions: https://docs.docker.com/install/" - show_info " - Download: https://download.docker.com/mac/stable/Docker.dmg" - exit 1 - ;; - esac -} - -install_docker_compose() { - show_info "Install docker-compose" - case $DISTRO in - Debian) - $sudo curl -L "https://github.com/docker/compose/releases/download/v2.2.3/docker-compose-$(uname -s | tr '[:upper:]' '[:lower:]' | sed -e 's/_.*//')-$(uname -m)" -o /usr/local/bin/docker-compose - $sudo chmod +x /usr/local/bin/docker-compose - ;; - RedHat) - $sudo curl -L "https://github.com/docker/compose/releases/download/v2.2.3/docker-compose-$(uname -s | tr '[:upper:]' '[:lower:]' | sed -e 's/_.*//')-$(uname -m)" -o /usr/local/bin/docker-compose - $sudo chmod +x /usr/local/bin/docker-compose - ;; - Darwin) - show_info "Please install the latest version of docker-compose and try again." - show_info "It should have been installed with Docker Desktop for Mac or Docker Toolbox." - show_info " - Instructions: https://docs.docker.com/compose/install/" - show_info " - Download: https://download.docker.com/mac/stable/Docker.dmg" - exit 1 - ;; - esac -} - set_brew_python_build_flags() { local _prefix_openssl="$(brew --prefix openssl)" local _prefix_sqlite3="$(brew --prefix sqlite3)" @@ -436,33 +390,64 @@ check_python() { pyenv shell --unset } +bootstrap_pants() { + set -e + mkdir -p .tmp + if [ -f '.pants.env' -a -f './pants-local' ]; then + echo "It seems that you have an already locally bootstrapped Pants." + echo "The installer will keep using it." + echo "If you want to reset it, delete ./.pants.env and ./pants-local files." + ./pants-local version + PANTS="./pants-local" + return + fi + set +e + PANTS="./pants" + ./pants version + # Note that pants 2.11 requires Python 3.9 (not Python 3.10!) to work properly. + if [ $? -eq 1 ]; then + show_info "Downloading and building Pants for the current setup" + _PYENV_PYVER=$(pyenv versions --bare | grep '^3\.9\.' | grep -v '/envs/' | sort -t. -k1,1r -k 2,2nr -k 3,3nr | head -n 1) + if [ -z "$_PYENV_PYVER" ]; then + echo "No Python 3.9 available via pyenv!" + echo "Please install Python 3.9 using pyenv," + echo "or add 'PY=' in ./.pants.env to " + echo "manually set the Pants-compatible interpreter path." + exit 1 + else + echo "Chosen Python $_PYENV_PYVER (from pyenv) as the local Pants interpreter" + fi + echo "PY=\$(pyenv prefix $_PYENV_PYVER)/bin/python" >> "$ROOT_PATH/.pants.env" + set -e + # FIXME: The branch name uses the "MAJOR.MINOR.x" format. Until Pants is officially released with Linux arm64 support, + # we need to fallback to the main branch for custom patches. + ## local PANTS_CLONE_VERSION="$(echo $PANTS_VERSION | cut -d. -f1).$(echo $PANTS_VERSION | cut -d. 
-f2).x" + local PANTS_CLONE_VERSION="main" + git clone --branch=$PANTS_CLONE_VERSION --depth=1 https://github.com/pantsbuild/pants tools/pants-src + cd tools/pants-src + if [ "$(uname -p)" = "arm" -a "$DISTRO" != "Darwin" ]; then + git apply ../pants-linux-aarch64.patch + fi + cd ../.. + ln -s tools/pants-local + ./pants-local version + PANTS="./pants-local" + fi + set +e +} + # BEGIN! echo " " echo "${LGREEN}Backend.AI one-line installer for developers${NC}" -# NOTE: docker-compose enforces lower-cased project names -if [ -z "${ENV_ID}" ]; then - ENV_ID=$(LC_ALL=C tr -dc 'a-z0-9' < /dev/urandom | head -c 8) -fi -show_note "Your environment ID is ${YELLOW}${ENV_ID}${NC}." - # Check prerequisites show_info "Checking prerequisites and script dependencies..." install_script_deps -if ! type "docker" >/dev/null 2>&1; then - show_warning "docker is not available; trying to install it automatically..." - install_docker -fi -docker compose version >/dev/null 2>&1 -if [ $? -eq 0 ]; then - DOCKER_COMPOSE="docker compose" -else - if ! type "docker-compose" >/dev/null 2>&1; then - show_warning "docker-compose is not available; trying to install it automatically..." - install_docker_compose - fi - DOCKER_COMPOSE="docker-compose" +$bpython -m pip --disable-pip-version-check install -q requests requests_unixsocket +$bpython scripts/check-docker.py +if [ $? -ne 0 ]; then + exit 1 fi if [ "$DISTRO" = "Darwin" ]; then echo "validating Docker Desktop mount permissions..." @@ -474,9 +459,9 @@ if [ "$DISTRO" = "Darwin" ]; then show_error "You must allow mount of '$HOME/.pyenv' in the File Sharing preference of the Docker Desktop app." exit 1 fi - docker run --rm -v "$INSTALL_PATH:/root/vol" alpine:3.8 ls /root/vol > /dev/null 2>&1 + docker run --rm -v "$ROOT_PATH:/root/vol" alpine:3.8 ls /root/vol > /dev/null 2>&1 if [ $? -ne 0 ]; then - show_error "You must allow mount of '$INSTALL_PATH' in the File Sharing preference of the Docker Desktop app." + show_error "You must allow mount of '$ROOT_PATH' in the File Sharing preference of the Docker Desktop app." exit 1 fi echo "${REWRITELN}validating Docker Desktop mount permissions: ok" @@ -522,20 +507,13 @@ install_python show_info "Checking Python features..." check_python +show_info "Bootstrapping the Pants build system..." +bootstrap_pants + set -e -show_info "Creating virtualenv on pyenv..." -pyenv virtualenv "${PYTHON_VERSION}" "venv-${ENV_ID}-manager" -pyenv virtualenv "${PYTHON_VERSION}" "venv-${ENV_ID}-agent" -pyenv virtualenv "${PYTHON_VERSION}" "venv-${ENV_ID}-common" -pyenv virtualenv "${PYTHON_VERSION}" "venv-${ENV_ID}-client" -pyenv virtualenv "${PYTHON_VERSION}" "venv-${ENV_ID}-storage-proxy" -pyenv virtualenv "${PYTHON_VERSION}" "venv-${ENV_ID}-webserver" -pyenv virtualenv "${PYTHON_VERSION}" "venv-${ENV_ID}-tester" # Make directories -show_info "Creating the install directory..." -mkdir -p "${INSTALL_PATH}" -cd "${INSTALL_PATH}" +show_info "Using the current working-copy directory as the installation path..." mkdir -p ./wheelhouse if [ "$DISTRO" = "Darwin" -a "$(uname -p)" = "arm" ]; then @@ -568,37 +546,15 @@ fi # Install postgresql, etcd packages via docker show_info "Launching the docker compose \"halfstack\"..." 
-git clone --branch "${SERVER_BRANCH}" https://github.com/lablup/backend.ai - -cd backend.ai -cp "docker-compose.halfstack-${SERVER_BRANCH//.}.yml" "docker-compose.halfstack.${ENV_ID}.yml" -sed_inplace "s/8100:5432/${POSTGRES_PORT}:5432/" "docker-compose.halfstack.${ENV_ID}.yml" -sed_inplace "s/8110:6379/${REDIS_PORT}:6379/" "docker-compose.halfstack.${ENV_ID}.yml" -sed_inplace "s/8120:2379/${ETCD_PORT}:2379/" "docker-compose.halfstack.${ENV_ID}.yml" -mkdir -p tmp/backend.ai-halfstack/postgres-data -mkdir -p tmp/backend.ai-halfstack/etcd-data -$docker_sudo $DOCKER_COMPOSE -f "docker-compose.halfstack.${ENV_ID}.yml" -p "${ENV_ID}" up -d -$docker_sudo docker ps | grep "${ENV_ID}" # You should see three containers here. - -# Clone source codes -show_info "Cloning backend.ai source codes..." -cd "${INSTALL_PATH}" -git clone --branch "${SERVER_BRANCH}" https://github.com/lablup/backend.ai-manager manager -git clone --branch "${SERVER_BRANCH}" https://github.com/lablup/backend.ai-agent agent -git clone --branch "${SERVER_BRANCH}" https://github.com/lablup/backend.ai-common common -git clone --branch "${SERVER_BRANCH}" https://github.com/lablup/backend.ai-storage-proxy storage-proxy -git clone --branch "${SERVER_BRANCH}" --recurse-submodules https://github.com/lablup/backend.ai-webserver webserver -git clone --branch "${CLIENT_BRANCH}" https://github.com/lablup/backend.ai-client-py client-py -git clone --branch "${CLIENT_BRANCH}" https://github.com/lablup/backend.ai-test.git tester - -if [ $ENABLE_CUDA -eq 1 ]; then - if [ "$CUDA_BRANCH" == "mock" ]; then - git clone https://github.com/lablup/backend.ai-accelerator-cuda-mock accel-cuda - cp accel-cuda/configs/sample-mig.toml agent/cuda-mock.toml - else - git clone --branch "${CUDA_BRANCH}" https://github.com/lablup/backend.ai-accelerator-cuda accel-cuda - fi -fi +mkdir -p "$HALFSTACK_VOLUME_PATH" +cp "docker-compose.halfstack-${MONO_BRANCH//.}.yml" "docker-compose.halfstack.current.yml" +sed_inplace "s/8100:5432/${POSTGRES_PORT}:5432/" "docker-compose.halfstack.current.yml" +sed_inplace "s/8110:6379/${REDIS_PORT}:6379/" "docker-compose.halfstack.current.yml" +sed_inplace "s/8120:2379/${ETCD_PORT}:2379/" "docker-compose.halfstack.current.yml" +mkdir -p "${HALFSTACK_VOLUME_PATH}/postgres-data" +mkdir -p "${HALFSTACK_VOLUME_PATH}/etcd-data" +$docker_sudo docker compose -f "docker-compose.halfstack.current.yml" up -d +$docker_sudo docker compose -f "docker-compose.halfstack.current.yml" ps # You should see three containers here. check_snappy() { pip download python-snappy @@ -610,73 +566,40 @@ check_snappy() { rm -f $pkgfile } -show_info "Install packages on virtual environments..." -cd "${INSTALL_PATH}/manager" -pyenv local "venv-${ENV_ID}-manager" -pip install -U -q pip setuptools wheel +show_info "Creating the unified virtualenv for IDEs..." check_snappy -pip install -U --find-links=../wheelhouse -e ../common -r requirements/dev.txt +$PANTS export '::' -cd "${INSTALL_PATH}/agent" -pyenv local "venv-${ENV_ID}-agent" -pip install -U -q pip setuptools wheel -pip install -U --find-links=../wheelhouse -e ../common -r requirements/dev.txt -if [[ "$OSTYPE" == "linux-gnu" ]]; then - $sudo setcap cap_sys_ptrace,cap_sys_admin,cap_dac_override+eip $(readlinkf $(pyenv which python)) -fi if [ $ENABLE_CUDA -eq 1 ]; then - cd "${INSTALL_PATH}/accel-cuda" - pyenv local "venv-${ENV_ID}-agent" # share the agent's venv - pip install -U -e . 
+ if [ "$CUDA_BRANCH" == "mock" ]; then + PLUGIN_BRANCH=$CUDA_BRANCH scripts/install-plugin.sh "lablup/backend.ai-accelerator-cuda-mock" + cp "${PLUGIN_PATH}/backend.ai-accelerator-cuda-mock/configs/sample-mig.toml" cuda-mock.toml + else + PLUGIN_BRANCH=$CUDA_BRANCH scripts/install-plugin.sh "lablup/backend.ai-accelerator-cuda" + fi fi -cd "${INSTALL_PATH}/common" -pyenv local "venv-${ENV_ID}-common" -pip install -U -q pip setuptools wheel -pip install -U --find-links=../wheelhouse -r requirements/dev.txt - -cd "${INSTALL_PATH}/storage-proxy" -pyenv local "venv-${ENV_ID}-storage-proxy" -pip install -U -q pip setuptools wheel -pip install -U --find-links=../wheelhouse -e ../common -r requirements/dev.txt - -cd "${INSTALL_PATH}/webserver" -pyenv local "venv-${ENV_ID}-webserver" -pip install -U -q pip setuptools wheel -pip install -U --find-links=../wheelhouse -e ../client-py -r requirements/dev.txt - -cd "${INSTALL_PATH}/tester" -pyenv local "venv-${ENV_ID}-tester" -pip install -U -q pip setuptools wheel -pip install -U --find-links=../wheelhouse -r requirements/dev.txt - # Copy default configurations show_info "Copy default configuration files to manager / agent root..." -cd "${INSTALL_PATH}/manager" -pyenv local "venv-${ENV_ID}-manager" -cp config/halfstack.toml ./manager.toml +cp configs/manager/halfstack.toml ./manager.toml sed_inplace "s/num-proc = .*/num-proc = 1/" ./manager.toml sed_inplace "s/port = 8120/port = ${ETCD_PORT}/" ./manager.toml sed_inplace "s/port = 8100/port = ${POSTGRES_PORT}/" ./manager.toml sed_inplace "s/port = 8081/port = ${MANAGER_PORT}/" ./manager.toml -cp config/halfstack.alembic.ini ./alembic.ini +cp configs/manager/halfstack.alembic.ini ./alembic.ini sed_inplace "s/localhost:8100/localhost:${POSTGRES_PORT}/" ./alembic.ini -python -m ai.backend.manager.cli etcd put config/redis/addr "127.0.0.1:${REDIS_PORT}" -cp config/sample.etcd.volumes.json ./dev.etcd.volumes.json +./backend.ai mgr etcd put config/redis/addr "127.0.0.1:${REDIS_PORT}" +cp configs/manager/sample.etcd.volumes.json ./dev.etcd.volumes.json MANAGER_AUTH_KEY=$(python -c 'import secrets; print(secrets.token_hex(32), end="")') sed_inplace "s/\"secret\": \"some-secret-shared-with-storage-proxy\"/\"secret\": \"${MANAGER_AUTH_KEY}\"/" ./dev.etcd.volumes.json sed_inplace "s/\"default_host\": .*$/\"default_host\": \"${LOCAL_STORAGE_PROXY}:${LOCAL_STORAGE_VOLUME}\",/" ./dev.etcd.volumes.json -cd "${INSTALL_PATH}/agent" -pyenv local "venv-${ENV_ID}-agent" -cp config/halfstack.toml ./agent.toml +cp configs/agent/halfstack.toml ./agent.toml sed_inplace "s/port = 8120/port = ${ETCD_PORT}/" ./agent.toml sed_inplace "s/port = 6001/port = ${AGENT_RPC_PORT}/" ./agent.toml sed_inplace "s/port = 6009/port = ${AGENT_WATCHER_PORT}/" ./agent.toml -cd "${INSTALL_PATH}/storage-proxy" -pyenv local "venv-${ENV_ID}-storage-proxy" -cp config/sample.toml ./storage-proxy.toml +cp configs/storage-proxy/sample.toml ./storage-proxy.toml STORAGE_PROXY_RANDOM_KEY=$(python -c 'import secrets; print(secrets.token_hex(32), end="")') sed_inplace "s/secret = \"some-secret-private-for-storage-proxy\"/secret = \"${STORAGE_PROXY_RANDOM_KEY}\"/" ./storage-proxy.toml sed_inplace "s/secret = \"some-secret-shared-with-manager\"/secret = \"${MANAGER_AUTH_KEY}\"/" ./storage-proxy.toml @@ -688,75 +611,61 @@ sed_inplace "s/^purity/# purity/" ./storage-proxy.toml sed_inplace "s/^netapp_/# netapp_/" ./storage-proxy.toml # add LOCAL_STORAGE_VOLUME vfs volume -echo "\n[volume.${LOCAL_STORAGE_VOLUME}]\nbackend = \"vfs\"\npath = 
\"${INSTALL_PATH}/${VFOLDER_REL_PATH}\"" >> ./storage-proxy.toml +echo "\n[volume.${LOCAL_STORAGE_VOLUME}]\nbackend = \"vfs\"\npath = \"${ROOT_PATH}/${VFOLDER_REL_PATH}\"" >> ./storage-proxy.toml -cd "${INSTALL_PATH}/webserver" -pyenv local "venv-${ENV_ID}-webserver" -cp webserver.sample.conf ./webserver.conf +cp configs/webserver/sample.conf ./webserver.conf sed_inplace "s/^port = 8080$/port = ${WEBSERVER_PORT}/" ./webserver.conf sed_inplace "s/https:\/\/api.backend.ai/http:\/\/127.0.0.1:${MANAGER_PORT}/" ./webserver.conf sed_inplace "s/ssl-verify = true/ssl-verify = false/" ./webserver.conf sed_inplace "s/redis.port = 6379/redis.port = ${REDIS_PORT}/" ./webserver.conf -cd "${INSTALL_PATH}/tester" -pyenv local "venv-${ENV_ID}-tester" -cp sample-env-tester.sh ./env-tester-admin.sh -cp sample-env-tester.sh ./env-tester-user.sh +# TODO +## cp configs/testers/sample-env-tester.sh ./env-tester-admin.sh +## cp configs/testers/sample-env-tester.sh ./env-tester-user.sh # DB schema show_info "Setting up databases..." -cd "${INSTALL_PATH}/manager" -python -m ai.backend.manager.cli schema oneshot -python -m ai.backend.manager.cli fixture populate fixtures/example-keypairs.json -python -m ai.backend.manager.cli fixture populate fixtures/example-resource-presets.json +./backend.ai mgr schema oneshot +./backend.ai mgr fixture populate fixtures/manager/example-keypairs.json +./backend.ai mgr fixture populate fixtures/manager/example-resource-presets.json # Docker registry setup show_info "Configuring the Lablup's official image registry..." -cd "${INSTALL_PATH}/manager" -python -m ai.backend.manager.cli etcd put config/docker/registry/cr.backend.ai "https://cr.backend.ai" -python -m ai.backend.manager.cli etcd put config/docker/registry/cr.backend.ai/type "harbor2" +./backend.ai mgr etcd put config/docker/registry/cr.backend.ai "https://cr.backend.ai" +./backend.ai mgr etcd put config/docker/registry/cr.backend.ai/type "harbor2" if [ "$(uname -p)" = "arm" ]; then - python -m ai.backend.manager.cli etcd put config/docker/registry/cr.backend.ai/project "stable,community,multiarch" + ./backend.ai mgr etcd put config/docker/registry/cr.backend.ai/project "stable,community,multiarch" else - python -m ai.backend.manager.cli etcd put config/docker/registry/cr.backend.ai/project "stable,community" + ./backend.ai mgr etcd put config/docker/registry/cr.backend.ai/project "stable,community" fi # Scan the container image registry show_info "Scanning the image registry..." -python -m ai.backend.manager.cli etcd rescan-images cr.backend.ai +./backend.ai mgr etcd rescan-images cr.backend.ai if [ "$(uname -p)" = "arm" ]; then - python -m ai.backend.manager.cli etcd alias python "cr.backend.ai/multiarch/python:3.9-ubuntu20.04" aarch64 + ./backend.ai mgr etcd alias python "cr.backend.ai/multiarch/python:3.9-ubuntu20.04" aarch64 else - python -m ai.backend.manager.cli etcd alias python "cr.backend.ai/stable/python:3.9-ubuntu20.04" x86_64 + ./backend.ai mgr etcd alias python "cr.backend.ai/stable/python:3.9-ubuntu20.04" x86_64 fi # Virtual folder setup show_info "Setting up virtual folder..." 
-mkdir -p "${INSTALL_PATH}/${VFOLDER_REL_PATH}" -cd "${INSTALL_PATH}/manager" -python -m ai.backend.manager.cli etcd put-json volumes "./dev.etcd.volumes.json" -cd "${INSTALL_PATH}/agent" +mkdir -p "${ROOT_PATH}/${VFOLDER_REL_PATH}" +./backend.ai mgr etcd put-json volumes "./dev.etcd.volumes.json" mkdir -p scratches -POSTGRES_CONTAINER_ID=$($docker_sudo docker ps | grep "${ENV_ID}[-_]backendai-half-db[-_]1" | awk '{print $1}') +POSTGRES_CONTAINER_ID=$($docker_sudo docker compose -f "docker-compose.halfstack.current.yml" ps | grep "[-_]backendai-half-db[-_]1" | awk '{print $1}') $docker_sudo docker exec -it $POSTGRES_CONTAINER_ID psql postgres://postgres:develove@localhost:5432/backend database -c "update domains set allowed_vfolder_hosts = '{${LOCAL_STORAGE_PROXY}:${LOCAL_STORAGE_VOLUME}}';" $docker_sudo docker exec -it $POSTGRES_CONTAINER_ID psql postgres://postgres:develove@localhost:5432/backend database -c "update groups set allowed_vfolder_hosts = '{${LOCAL_STORAGE_PROXY}:${LOCAL_STORAGE_VOLUME}}';" $docker_sudo docker exec -it $POSTGRES_CONTAINER_ID psql postgres://postgres:develove@localhost:5432/backend database -c "update keypair_resource_policies set allowed_vfolder_hosts = '{${LOCAL_STORAGE_PROXY}:${LOCAL_STORAGE_VOLUME}}';" $docker_sudo docker exec -it $POSTGRES_CONTAINER_ID psql postgres://postgres:develove@localhost:5432/backend database -c "update vfolders set host = '${LOCAL_STORAGE_PROXY}:${LOCAL_STORAGE_VOLUME}' where host='${LOCAL_STORAGE_VOLUME}';" -show_info "Installing Python client SDK/CLI source..." -# Install python client package -cd "${INSTALL_PATH}/client-py" -pyenv local "venv-${ENV_ID}-client" -pip install -U -q pip setuptools wheel -pip install -U -r requirements/dev.txt - # Client backend endpoint configuration shell script CLIENT_ADMIN_CONF_FOR_API="env-local-admin-api.sh" CLIENT_ADMIN_CONF_FOR_SESSION="env-local-admin-session.sh" echo "# Directly access to the manager using API keypair (admin)" >> "${CLIENT_ADMIN_CONF_FOR_API}" echo "export BACKEND_ENDPOINT=http://127.0.0.1:${MANAGER_PORT}/" >> "${CLIENT_ADMIN_CONF_FOR_API}" -echo "export BACKEND_ACCESS_KEY=$(cat ../manager/fixtures/example-keypairs.json | jq -r '.keypairs[] | select(.user_id=="admin@lablup.com") | .access_key')" >> "${CLIENT_ADMIN_CONF_FOR_API}" -echo "export BACKEND_SECRET_KEY=$(cat ../manager/fixtures/example-keypairs.json | jq -r '.keypairs[] | select(.user_id=="admin@lablup.com") | .secret_key')" >> "${CLIENT_ADMIN_CONF_FOR_API}" +echo "export BACKEND_ACCESS_KEY=$(cat fixtures/manager/example-keypairs.json | jq -r '.keypairs[] | select(.user_id=="admin@lablup.com") | .access_key')" >> "${CLIENT_ADMIN_CONF_FOR_API}" +echo "export BACKEND_SECRET_KEY=$(cat fixtures/manager/example-keypairs.json | jq -r '.keypairs[] | select(.user_id=="admin@lablup.com") | .secret_key')" >> "${CLIENT_ADMIN_CONF_FOR_API}" echo "export BACKEND_ENDPOINT_TYPE=api" >> "${CLIENT_ADMIN_CONF_FOR_API}" chmod +x "${CLIENT_ADMIN_CONF_FOR_API}" echo "# Indirectly access to the manager via the web server a using cookie-based login session (admin)" >> "${CLIENT_ADMIN_CONF_FOR_SESSION}" @@ -765,15 +674,15 @@ echo "unset BACKEND_ACCESS_KEY" >> "${CLIENT_ADMIN_CONF_FOR_SESSION}" echo "unset BACKEND_SECRET_KEY" >> "${CLIENT_ADMIN_CONF_FOR_SESSION}" echo "export BACKEND_ENDPOINT_TYPE=session" >> "${CLIENT_ADMIN_CONF_FOR_SESSION}" echo "echo 'Run backend.ai login to make an active session.'" >> "${CLIENT_ADMIN_CONF_FOR_SESSION}" -echo "echo 'Username: $(cat ../manager/fixtures/example-keypairs.json | jq -r '.users[] | 
select(.username=="admin") | .email')'" >> "${CLIENT_ADMIN_CONF_FOR_SESSION}" -echo "echo 'Password: $(cat ../manager/fixtures/example-keypairs.json | jq -r '.users[] | select(.username=="admin") | .password')'" >> "${CLIENT_ADMIN_CONF_FOR_SESSION}" +echo "echo 'Username: $(cat fixtures/manager/example-keypairs.json | jq -r '.users[] | select(.username=="admin") | .email')'" >> "${CLIENT_ADMIN_CONF_FOR_SESSION}" +echo "echo 'Password: $(cat fixtures/manager/example-keypairs.json | jq -r '.users[] | select(.username=="admin") | .password')'" >> "${CLIENT_ADMIN_CONF_FOR_SESSION}" chmod +x "${CLIENT_ADMIN_CONF_FOR_SESSION}" CLIENT_DOMAINADMIN_CONF_FOR_API="env-local-domainadmin-api.sh" CLIENT_DOMAINADMIN_CONF_FOR_SESSION="env-local-domainadmin-session.sh" echo "# Directly access to the manager using API keypair (admin)" >> "${CLIENT_DOMAINADMIN_CONF_FOR_API}" echo "export BACKEND_ENDPOINT=http://127.0.0.1:${MANAGER_PORT}/" >> "${CLIENT_DOMAINADMIN_CONF_FOR_API}" -echo "export BACKEND_ACCESS_KEY=$(cat ../manager/fixtures/example-keypairs.json | jq -r '.keypairs[] | select(.user_id=="domain-admin@lablup.com") | .access_key')" >> "${CLIENT_DOMAINADMIN_CONF_FOR_API}" -echo "export BACKEND_SECRET_KEY=$(cat ../manager/fixtures/example-keypairs.json | jq -r '.keypairs[] | select(.user_id=="domain-admin@lablup.com") | .secret_key')" >> "${CLIENT_DOMAINADMIN_CONF_FOR_API}" +echo "export BACKEND_ACCESS_KEY=$(cat fixtures/manager/example-keypairs.json | jq -r '.keypairs[] | select(.user_id=="domain-admin@lablup.com") | .access_key')" >> "${CLIENT_DOMAINADMIN_CONF_FOR_API}" +echo "export BACKEND_SECRET_KEY=$(cat fixtures/manager/example-keypairs.json | jq -r '.keypairs[] | select(.user_id=="domain-admin@lablup.com") | .secret_key')" >> "${CLIENT_DOMAINADMIN_CONF_FOR_API}" echo "export BACKEND_ENDPOINT_TYPE=api" >> "${CLIENT_DOMAINADMIN_CONF_FOR_API}" chmod +x "${CLIENT_DOMAINADMIN_CONF_FOR_API}" echo "# Indirectly access to the manager via the web server a using cookie-based login session (admin)" >> "${CLIENT_DOMAINADMIN_CONF_FOR_SESSION}" @@ -782,15 +691,15 @@ echo "unset BACKEND_ACCESS_KEY" >> "${CLIENT_DOMAINADMIN_CONF_FOR_SESSION}" echo "unset BACKEND_SECRET_KEY" >> "${CLIENT_DOMAINADMIN_CONF_FOR_SESSION}" echo "export BACKEND_ENDPOINT_TYPE=session" >> "${CLIENT_DOMAINADMIN_CONF_FOR_SESSION}" echo "echo 'Run backend.ai login to make an active session.'" >> "${CLIENT_DOMAINADMIN_CONF_FOR_SESSION}" -echo "echo 'Username: $(cat ../manager/fixtures/example-keypairs.json | jq -r '.users[] | select(.username=="domain-admin") | .email')'" >> "${CLIENT_DOMAINADMIN_CONF_FOR_SESSION}" -echo "echo 'Password: $(cat ../manager/fixtures/example-keypairs.json | jq -r '.users[] | select(.username=="domain-admin") | .password')'" >> "${CLIENT_DOMAINADMIN_CONF_FOR_SESSION}" +echo "echo 'Username: $(cat fixtures/manager/example-keypairs.json | jq -r '.users[] | select(.username=="domain-admin") | .email')'" >> "${CLIENT_DOMAINADMIN_CONF_FOR_SESSION}" +echo "echo 'Password: $(cat fixtures/manager/example-keypairs.json | jq -r '.users[] | select(.username=="domain-admin") | .password')'" >> "${CLIENT_DOMAINADMIN_CONF_FOR_SESSION}" chmod +x "${CLIENT_DOMAINADMIN_CONF_FOR_SESSION}" CLIENT_USER_CONF_FOR_API="env-local-user-api.sh" CLIENT_USER_CONF_FOR_SESSION="env-local-user-session.sh" echo "# Directly access to the manager using API keypair (user)" >> "${CLIENT_USER_CONF_FOR_API}" echo "export BACKEND_ENDPOINT=http://127.0.0.1:${MANAGER_PORT}/" >> "${CLIENT_USER_CONF_FOR_API}" -echo "export BACKEND_ACCESS_KEY=$(cat 
../manager/fixtures/example-keypairs.json | jq -r '.keypairs[] | select(.user_id=="user@lablup.com") | .access_key')" >> "${CLIENT_USER_CONF_FOR_API}" -echo "export BACKEND_SECRET_KEY=$(cat ../manager/fixtures/example-keypairs.json | jq -r '.keypairs[] | select(.user_id=="user@lablup.com") | .secret_key')" >> "${CLIENT_USER_CONF_FOR_API}" +echo "export BACKEND_ACCESS_KEY=$(cat fixtures/manager/example-keypairs.json | jq -r '.keypairs[] | select(.user_id=="user@lablup.com") | .access_key')" >> "${CLIENT_USER_CONF_FOR_API}" +echo "export BACKEND_SECRET_KEY=$(cat fixtures/manager/example-keypairs.json | jq -r '.keypairs[] | select(.user_id=="user@lablup.com") | .secret_key')" >> "${CLIENT_USER_CONF_FOR_API}" echo "export BACKEND_ENDPOINT_TYPE=api" >> "${CLIENT_USER_CONF_FOR_API}" chmod +x "${CLIENT_USER_CONF_FOR_API}" echo "# Indirectly access to the manager via the web server a using cookie-based login session (user)" >> "${CLIENT_USER_CONF_FOR_SESSION}" @@ -799,18 +708,15 @@ echo "unset BACKEND_ACCESS_KEY" >> "${CLIENT_USER_CONF_FOR_SESSION}" echo "unset BACKEND_SECRET_KEY" >> "${CLIENT_USER_CONF_FOR_SESSION}" echo "export BACKEND_ENDPOINT_TYPE=session" >> "${CLIENT_USER_CONF_FOR_SESSION}" echo "echo 'Run backend.ai login to make an active session.'" >> "${CLIENT_USER_CONF_FOR_SESSION}" -echo "echo 'Username: $(cat ../manager/fixtures/example-keypairs.json | jq -r '.users[] | select(.username=="user") | .email')'" >> "${CLIENT_USER_CONF_FOR_SESSION}" -echo "echo 'Password: $(cat ../manager/fixtures/example-keypairs.json | jq -r '.users[] | select(.username=="user") | .password')'" >> "${CLIENT_USER_CONF_FOR_SESSION}" +echo "echo 'Username: $(cat fixtures/manager/example-keypairs.json | jq -r '.users[] | select(.username=="user") | .email')'" >> "${CLIENT_USER_CONF_FOR_SESSION}" +echo "echo 'Password: $(cat fixtures/manager/example-keypairs.json | jq -r '.users[] | select(.username=="user") | .password')'" >> "${CLIENT_USER_CONF_FOR_SESSION}" chmod +x "${CLIENT_USER_CONF_FOR_SESSION}" -# Update tester env script -cd "${INSTALL_PATH}/tester" -VENV_PATH="$(pyenv root)/versions/venv-${ENV_ID}-client" -sed_inplace "s@export BACKENDAI_TEST_CLIENT_VENV=/home/user/.pyenv/versions/venv-dev-client@export BACKENDAI_TEST_CLIENT_VENV=${VENV_PATH}@" ./env-tester-admin.sh -sed_inplace "s@export BACKENDAI_TEST_CLIENT_ENV=/home/user/bai-dev/client-py/my-backend-session.sh@export BACKENDAI_TEST_CLIENT_ENV=${INSTALL_PATH}/client-py/${CLIENT_ADMIN_CONF_FOR_API}@" ./env-tester-admin.sh -sed_inplace "s@export BACKENDAI_TEST_CLIENT_VENV=/home/user/.pyenv/versions/venv-dev-client@export BACKENDAI_TEST_CLIENT_VENV=${VENV_PATH}@" ./env-tester-user.sh -sed_inplace "s@export BACKENDAI_TEST_CLIENT_ENV=/home/user/bai-dev/client-py/my-backend-session.sh@export BACKENDAI_TEST_CLIENT_ENV=${INSTALL_PATH}/client-py/${CLIENT_USER_CONF_FOR_API}@" ./env-tester-user.sh -cd "${INSTALL_PATH}/client-py" +# TODO: Update tester env script +## sed_inplace "s@export BACKENDAI_TEST_CLIENT_VENV=/home/user/.pyenv/versions/venv-dev-client@export BACKENDAI_TEST_CLIENT_VENV=${VENV_PATH}@" ./env-tester-admin.sh +## sed_inplace "s@export BACKENDAI_TEST_CLIENT_ENV=/home/user/bai-dev/client-py/my-backend-session.sh@export BACKENDAI_TEST_CLIENT_ENV=${INSTALL_PATH}/client-py/${CLIENT_ADMIN_CONF_FOR_API}@" ./env-tester-admin.sh +## sed_inplace "s@export BACKENDAI_TEST_CLIENT_VENV=/home/user/.pyenv/versions/venv-dev-client@export BACKENDAI_TEST_CLIENT_VENV=${VENV_PATH}@" ./env-tester-user.sh +## sed_inplace "s@export 
BACKENDAI_TEST_CLIENT_ENV=/home/user/bai-dev/client-py/my-backend-session.sh@export BACKENDAI_TEST_CLIENT_ENV=${INSTALL_PATH}/client-py/${CLIENT_USER_CONF_FOR_API}@" ./env-tester-user.sh show_info "Pre-pulling frequently used kernel images..." echo "NOTE: Other images will be downloaded from the docker registry when requested.\n" @@ -824,16 +730,8 @@ else fi fi -DELETE_OPTS='' -if [ ! "$INSTALL_PATH" = $(readlinkf "./backend.ai-dev") ]; then - DELETE_OPTS+=" --install-path=${INSTALL_PATH}" -fi -DELETE_OPTS=$(trim "$DELETE_OPTS") - -cd "${INSTALL_PATH}" show_info "Installation finished." show_note "Check out the default API keypairs and account credentials for local development and testing:" -echo "> ${WHITE}cd client-py${NC}" echo "> ${WHITE}cat env-local-admin-api.sh${NC}" echo "> ${WHITE}cat env-local-admin-session.sh${NC}" echo "> ${WHITE}cat env-local-domainadmin-api.sh${NC}" @@ -841,37 +739,31 @@ echo "> ${WHITE}cat env-local-domainadmin-session.sh${NC}" echo "> ${WHITE}cat env-local-user-api.sh${NC}" echo "> ${WHITE}cat env-local-user-session.sh${NC}" show_note "To apply the client config, source one of the configs like:" -echo "> ${WHITE}cd client-py${NC}" echo "> ${WHITE}source env-local-user-session.sh${NC}" echo " " show_important_note "You should change your default admin API keypairs for production environment!" show_note "How to run Backend.AI manager:" -echo "> ${WHITE}cd ${INSTALL_PATH}/manager${NC}" -echo "> ${WHITE}python -m ai.backend.manager.server --debug${NC}" +echo "> ${WHITE}./backend.ai mgr start-server --debug${NC}" show_note "How to run Backend.AI agent:" -echo "> ${WHITE}cd ${INSTALL_PATH}/agent${NC}" -echo "> ${WHITE}python -m ai.backend.agent.server --debug${NC}" +echo "> ${WHITE}./backend.ai ag start-server --debug${NC}" show_note "How to run Backend.AI storage-proxy:" -echo "> ${WHITE}cd ${INSTALL_PATH}/storage-proxy${NC}" -echo "> ${WHITE}python -m ai.backend.storage.server${NC}" +echo "> ${WHITE}./py -m ai.backend.storage.server${NC}" show_note "How to run Backend.AI web server (for ID/Password login):" -echo "> ${WHITE}cd ${INSTALL_PATH}/webserver${NC}" -echo "> ${WHITE}python -m ai.backend.web.server${NC}" +echo "> ${WHITE}./py -m ai.backend.web.server${NC}" show_note "How to run your first code:" -echo "> ${WHITE}cd ${INSTALL_PATH}/client-py${NC}" -echo "> ${WHITE}backend.ai --help${NC}" +echo "> ${WHITE}./backend.ai --help${NC}" echo "> ${WHITE}source env-local-admin-api.sh${NC}" -echo "> ${WHITE}backend.ai run python -c \"print('Hello World\\!')\"${NC}" +echo "> ${WHITE}./backend.ai run python -c \"print('Hello World\\!')\"${NC}" echo " " echo "${GREEN}Development environment is now ready.${NC}" -show_note "Reminder: Your environment ID is ${YELLOW}${ENV_ID}${NC}." -echo " * When using docker-compose, do:" -echo " > ${WHITE}cd ${INSTALL_PATH}/backend.ai${NC}" +show_note "How to run docker-compose:" if [ ! 
-z "$docker_sudo" ]; then - echo " > ${WHITE}${docker_sudo} docker-compose -p ${ENV_ID} -f docker-compose.halfstack.${ENV_ID}.yml up -d ...${NC}" + echo " > ${WHITE}${docker_sudo} docker compose -f docker-compose.halfstack.current.yml up -d ...${NC}" else - echo " > ${WHITE}docker-compose -p ${ENV_ID} -f docker-compose.halfstack.${ENV_ID}.yml up -d ...${NC}" + echo " > ${WHITE}docker compose -f docker-compose.halfstack.current.yml up -d ...${NC}" fi -echo " * To delete this development environment, run:" -echo " > ${WHITE}$(dirname $0)/delete-dev.sh --env ${ENV_ID} ${DELETE_OPTS}${NC}" +show_note "How to reset this setup:" +echo " > ${WHITE}$(dirname $0)/delete-dev.sh${NC}" echo " " + +# vim: tw=0 diff --git a/scripts/install-plugin.sh b/scripts/install-plugin.sh new file mode 100755 index 0000000000..6ea62f263a --- /dev/null +++ b/scripts/install-plugin.sh @@ -0,0 +1,8 @@ +#! /bin/bash + +mkdir -p ./plugins +PLUGIN_BRANCH=${PLUGIN_BRANCH:-main} +PLUGIN_OWNER=$(echo $1 | cut -d / -f 1) +PLUGIN_REPO=$(echo $1 | cut -d / -f 2) +git clone "https://github.com/$1" "./plugins/$PLUGIN_REPO" +./py -m pip install -e "./plugins/$PLUGIN_REPO" diff --git a/scripts/reinstall-plugins.sh b/scripts/reinstall-plugins.sh new file mode 100755 index 0000000000..90f7558d36 --- /dev/null +++ b/scripts/reinstall-plugins.sh @@ -0,0 +1,8 @@ +#! /bin/bash + +PY=${PY:-$(python --version|awk '{ print $2 }')} +source "dist/export/python/virtualenvs/python-default/$PY/bin/activate" + +for repo_path in $(ls -d ./plugins/*/); do + pip install -e "$repo_path" +done diff --git a/scripts/specialize.py b/scripts/specialize.py new file mode 100644 index 0000000000..fab6f58de5 --- /dev/null +++ b/scripts/specialize.py @@ -0,0 +1,36 @@ +import argparse +import shutil +from pathlib import Path + + +def _copy_if_exists(src: Path, dst: Path) -> None: + if src.exists(): + shutil.copyfile(src, dst) + + +def populate_setup_cfg(pkg_name: str) -> None: + pkg_root = Path('packages') / pkg_name + root = Path('.') + shutil.copyfile(pkg_root / 'setup.cfg', root / 'setup.cfg') + + +def populate_manifest(pkg_name: str) -> None: + pkg_root = Path('packages') / pkg_name + root = Path('.') + _copy_if_exists(pkg_root / 'MANIFEST.in', root / 'MANIFEST.in') + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("pkg_name") + return parser.parse_args() + + +def main() -> None: + args = parse_args() + populate_setup_cfg(args.pkg_name) + populate_manifest(args.pkg_name) + + +if __name__ == "__main__": + main() diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index d5f7c5eecb..0000000000 --- a/setup.cfg +++ /dev/null @@ -1,86 +0,0 @@ -[metadata] -name = backend.ai -version = attr: ai.backend.meta.__version__ -description = Lablup Backend.AI Meta-package -long_description = file: README.md -long_description_content_type = text/markdown -url = https://backend.ai -author = Lablup Inc. 
-author_email = devops@lablup.com -license = LGPLv3 -classifiers = - Development Status :: 5 - Production/Stable - License :: OSI Approved :: GNU Lesser General Public License v3 or later (LGPLv3+) - Intended Audience :: Developers - Intended Audience :: Science/Research - Intended Audience :: Information Technology - Intended Audience :: System Administrators - Intended Audience :: Healthcare Industry - Intended Audience :: Financial and Insurance Industry - Programming Language :: Python - Programming Language :: Python :: 3 - Programming Language :: Python :: 3 :: Only - Programming Language :: Python :: 3.9 - Operating System :: POSIX - Operating System :: MacOS :: MacOS X - Environment :: Console - Environment :: No Input/Output (Daemon) - Topic :: Scientific/Engineering - Topic :: Scientific/Engineering :: Artificial Intelligence - Topic :: Software Development - Topic :: System :: Clustering - Topic :: System :: Distributed Computing - Framework :: AsyncIO -project_urls = - Documentation = https://docs.backend.ai - Source = https://github.com/lablup/backend.ai - Tracker = https://github.com/lablup/backend.ai/issues - -[options] -zip_safe = False -package_dir = - = src -packages = find_namespace: -python_requires = >=3.9 -setup_requires = - setuptools>=60.0.0 -install_requires = - backend.ai-client~=22.3.0,<22.4.0 - -[options.packages.find] -where = src -include = ai.backend.meta - -[options.extras_require] -manager = - backend.ai-manager>=21.3.0,<22.4.0 -agent = - backend.ai-agent>=22.3.0,<22.4.0 -build = - wheel>=0.37.0 - twine>=3.5.0 - setuptools>=60.0.0 -dev = -test = - pytest>=6.2.5 -lint = - flake8>=4.0.1 -typecheck = - mypy>=0.930 -docs = - sphinx~=3.3 - sphinx-rtd-theme>=1.0.0 - pygments~=2.6 - -[bdist_wheel] -universal = false - -[flake8] -# ref: http://pep8.readthedocs.io/en/latest/intro.html#error-codes -ignore = E731,E221,E241,E126,E127,E129,E401,N801,N802 -max-line-length = 105 -builtins = _ -exclude = .git,.cache,.idea,.tox,.eggs,*.egg,__pycache__,venv,build,docs,docker - -[tool:pytest] -testpaths = tests diff --git a/setup.py b/setup.py deleted file mode 100644 index 606849326a..0000000000 --- a/setup.py +++ /dev/null @@ -1,3 +0,0 @@ -from setuptools import setup - -setup() diff --git a/src/ai/backend/agent/BUILD b/src/ai/backend/agent/BUILD new file mode 100644 index 0000000000..42ddaaa1e3 --- /dev/null +++ b/src/ai/backend/agent/BUILD @@ -0,0 +1,71 @@ +python_sources( + name="service", + sources=["**/*.py"], + dependencies=[ + "src/ai/backend/cli:lib", + "src/ai/backend/common:lib", + "src/ai/backend/runner:lib", + "src/ai/backend/helpers:lib", + "src/ai/backend/kernel:lib", + ":resources", + "//:reqs#backend.ai-krunner-static-gnu", + ], +) + +pex_binary( + name="server", + dependencies=[ + ":service", + ], + entry_point="server.py", +) + +pex_binary( + name="watcher", + entry_point="watcher.py", +) + +python_requirement( + name="kernel-support", + requirements=[], +) + +python_distribution( + name="dist", + dependencies=[ + ":service", + ":kernel-support", + "!!stubs/trafaret:stubs", + ], + provides=python_artifact( + name="backend.ai-agent", + description="Backend.AI Agent", + license="LGPLv3", + ), + entry_points={ + "backendai_cli_v10": { + "ag": "ai.backend.agent.cli:main", + "ag.start-server": "ai.backend.agent.server:main", + }, + }, + generate_setup=True, + tags=["wheel"], +) + +resource(name="version", source="VERSION") + +resources( + name="resources", + dependencies=[ + ":version", + ], + sources=[ + "**/py.typed", + "docker/*.txt", + "docker/*.tar.gz", + 
"docker/*.bin", + "kubernetes/*.txt", + "kubernetes/*.tar.gz", + # "kubernetes/*.bin", # no matching files yet + ], +) diff --git a/src/ai/backend/agent/README.k8s.md b/src/ai/backend/agent/README.k8s.md new file mode 100644 index 0000000000..99d9c3fb0f --- /dev/null +++ b/src/ai/backend/agent/README.k8s.md @@ -0,0 +1,208 @@ +# Backend.AI Agent with K8s + +The Backend.AI Agent is a small daemon that does: + +* Reports the status and available resource slots of a worker to the manager +* Routes code execution requests to the designated kernel container +* Manages the lifecycle of kernel containers (create/monitor/destroy them) + +## Package Structure + +* `ai.backend` + - `agent`: The agent package + - `server`: The agent daemon which communicates with the manager and the Docker daemon + + +## Installation + +First, you need **a working manager installation**. +For the detailed instructions on installing the manager, please refer +[the manager's README](https://github.com/lablup/backend.ai-manager/blob/master/README.md) +and come back here again. + +### For development + +#### Prequisites + +* `libsnappy-dev` or `snappy-devel` system package depending on your distro +* Python 3.6 or higher with [pyenv](https://github.com/pyenv/pyenv) +and [pyenv-virtualenv](https://github.com/pyenv/pyenv-virtualenv) (optional but recommneded) +* Docker 18.03 or later with docker-compose (18.09 or later is recommended) +* Properly configured Kubeconfig file - should be located at `$HOME/.kube/config` or `$KUBECONFIG` path +* [Git LFS](https://git-lfs.github.com/) installed and configured +* [Backend.AI K8s controller](https://github.com/lablup/backend.ai-k8s-controller) installed and running without any error + +#### One-shot installation + +```sh +curl https://raw.githubusercontent.com/lablup/backend.ai-agent/feature/kube-integration/install-halfstack.sh | sudo bash - +``` +This script automatically installs required components to run halfstack on a single-node cluster. This script is only for debian based distributions. + +#### Common steps + +Next, prepare the source clone of the agent and install from it as follows. + +```console +$ git clone https://github.com/lablup/backend.ai-agent agent +$ cd agent +$ git checkout feature/kube-integration +$ git lfs pull +$ pyenv virtualenv venv-agent +$ pyenv local venv-agent +$ pip install -U pip setuptools +$ pip install -U -r requirements-dev.txt +``` + +From now on, let's assume all shell commands are executed inside the virtualenv. + +### Halfstack (single-node development & testing) + +With the halfstack, you can run the agent simply. +Note that you need a working manager running with the halfstack already! + +#### Recommended directory structure + +* `backend.ai-dev` + - `manager` (git clone from [the manager repo](https://github.com/lablup/backend.ai-manager)) + - `agent` (git clone from here, with branch `feature/kube-integration`) + - `common` (git clone from [the common repo](https://github.com/lablup/backend.ai-common)) + +Install `backend.ai-common` as an editable package in the agent (and the manager) virtualenvs +to keep the codebase up-to-date. 
+ +```console +$ cd agent +$ pip install -U -e ../common +``` + +#### Steps + +```console +$ mkdir -p "./scratches" +$ cp config/halfstack.toml ./agent.toml +``` + +Then, run it (for debugging, append a `--debug` flag): + +```console +$ python -m ai.backend.agent.server +``` + +To run the agent-watcher: + +```console +$ python -m ai.backend.agent.watcher +``` + +The watcher shares the same configuration TOML file with the agent. +Note that the watcher is only meaningful if the agent is installed as a systemd service +named `backendai-agent.service`. + +To run tests: + +```console +$ python -m flake8 src tests +$ python -m pytest -m 'not integration' tests +``` + + + + +## Deployment + +### Configuration + +Put a TOML-formatted agent configuration (see the sample in `config/sample.toml`) +in one of the following locations: + + * `agent.toml` (current working directory) + * `~/.config/backend.ai/agent.toml` (user-config directory) + * `/etc/backend.ai/agent.toml` (system-config directory) + +Only the first found one is used by the daemon. + +The agent reads most other configurations from the etcd v3 server where the cluster +administrator or the Backend.AI manager stores all the necessary settings. + +The etcd address and namespace must match those of the manager to make the agent +paired and activated. +By specifying distinguished namespaces, you may share a single etcd cluster with multiple +separate Backend.AI clusters. + +By default the agent uses the `/var/cache/scratches` directory for making temporary +home directories used by kernel containers (the `/home/work` volume mounted in +containers). Note that the directory must exist beforehand and the agent-running +user must have ownership of it. You can change the location via the +`scratch-root` option in `agent.toml`. + +### Setting up NFS-based vFolder for Backend.AI + +When you deploy the Backend.AI Agent to a K8s cluster on a cloud provider (EKS, GKE, AKS, ...), you can provide NFS connection info directly to the Backend.AI Agent. All vFolder mounts will then be set up and managed by K8s PersistentVolume and PersistentVolumeClaim objects. To set up: + +1. Add NFS connection information to `agent.toml`. Check `config/sample.toml` for details. +2. Mount the NFS volume at the Backend.AI manager's mount target provided when setting up the manager (if you have deployed the halfstack with the one-shot installation script, just mount it at `$HOME/vfroot`). + +### Setting up CUDA support (experimental) +If the [NVIDIA Device Plugin for K8s](https://github.com/NVIDIA/k8s-device-plugin) is installed and running without any error, the Backend.AI Agent automatically inspects the available GPUs and reports them to the manager. + +### Running from a command line + +The minimal command to execute: + +```sh +python -m ai.backend.agent.server +``` + +For more arguments and options, run the command with the `--help` option. + +### Example config for agent server/instances + +`/etc/supervisor/conf.d/agent.conf`: + +```dosini +[program:backend.ai-agent] +user = user +stopsignal = TERM +stopasgroup = true +command = /home/user/run-agent.sh +``` + +`/home/user/run-agent.sh`: + +```sh +#!/bin/sh +source /home/user/venv-agent/bin/activate +exec python -m ai.backend.agent.server +``` + +### Networking + +The manager and agent should run in the same local network or different +networks reachable via VPNs, whereas the manager's API service must be exposed to +the public network or another private network that users have access to.
+ +The manager must be able to access TCP ports 6001, 6009, and 30000 to 31000 of the agents in default +configurations. You can of course change those port numbers and ranges in the configuration. + +| Manager-to-Agent TCP Ports | Usage | +|:--------------------------:|-------| +| 6001 | ZeroMQ-based RPC calls from managers to agents | +| 6009 | HTTP watcher API | +| 30000-31000 | Port pool for in-container services | + +The operation of agent itself does not require both incoming/outgoing access to +the public Internet, but if the user's computation programs need the Internet, the docker containers +should be able to access the public Internet (maybe via some corporate firewalls). + +| Agent-to-X TCP Ports | Usage | +|:------------------------:|-------| +| manager:5002 | ZeroMQ-based event push from agents to the manager | +| etcd:2379 | etcd API access | +| redis:6379 | Redis API access | +| docker-registry:{80,443} | HTTP watcher API | +| (Other hosts) | Depending on user program requirements | + +The agent and the K8s cluster should run in the same local network or VPC, and should be +reachable from each other with appropriate security group rules if the cluster is provisioned by a cloud provider. +The agent uses a NodePort service to communicate with pods, so NodePort access should be configured properly in K8s. diff --git a/src/ai/backend/agent/README.md b/src/ai/backend/agent/README.md new file mode 100644 index 0000000000..22134a4921 --- /dev/null +++ b/src/ai/backend/agent/README.md @@ -0,0 +1,291 @@ +# Backend.AI Agent + +The Backend.AI Agent is a small daemon that does: + +* Reports the status and available resource slots of a worker to the manager +* Routes code execution requests to the designated kernel container +* Manages the lifecycle of kernel containers (create/monitor/destroy them) + +## Package Structure + +* `ai.backend` + - `agent`: The agent package + - `docker`: A docker-based backend implementation for the kernel lifecycle interface. + - `server`: The agent daemon which communicates with the manager and the Docker daemon + - `watcher`: A side-by-side daemon which provides a separate HTTP endpoint for accessing the status + information of the agent daemon and manipulation of the agent's systemd service + - `helpers`: A utility package that is available as `ai.backend.helpers` *inside* Python-based containers + - `kernel`: Language-specific runtimes (mostly ipykernel client adaptor) which run *inside* containers + - `runner`: Auxiliary components (usually self-contained binaries) mounted *inside* containers + + +## Installation + +Please visit [the installation guides](https://github.com/lablup/backend.ai/wiki).
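+ +As a quick orientation (a minimal sketch, not a substitute for the guides above; it assumes a monorepo development checkout has been prepared so that the `./backend.ai` wrapper at the repository root works), the agent can be launched together with a manager like this: + +```console +$ ./backend.ai mgr start-server --debug   # the agent needs a running manager +$ ./backend.ai ag start-server --debug    # run the agent with debug logging +```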
+ + +### Kernel/system configuration + +#### Recommended kernel parameters in the bootloader (e.g., Grub): + +``` +cgroup_enable=memory swapaccount=1 +``` + +#### Recommended resource limits: + +**`/etc/security/limits.conf`** +``` +root hard nofile 512000 +root soft nofile 512000 +root hard nproc 65536 +root soft nproc 65536 +user hard nofile 512000 +user soft nofile 512000 +user hard nproc 65536 +user soft nproc 65536 +``` + +**sysctl** +``` +fs.file-max=2048000 +fs.inotify.max_user_watches=524288 +net.core.somaxconn=1024 +net.ipv4.tcp_max_syn_backlog=1024 +net.ipv4.tcp_slow_start_after_idle=0 +net.ipv4.tcp_fin_timeout=10 +net.ipv4.tcp_window_scaling=1 +net.ipv4.tcp_tw_reuse=1 +net.ipv4.tcp_early_retrans=1 +net.ipv4.ip_local_port_range=40000 65000 +net.core.rmem_max=16777216 +net.core.wmem_max=16777216 +net.ipv4.tcp_rmem=4096 12582912 16777216 +net.ipv4.tcp_wmem=4096 12582912 16777216 +net.netfilter.nf_conntrack_max=10485760 +net.netfilter.nf_conntrack_tcp_timeout_established=432000 +net.netfilter.nf_conntrack_tcp_timeout_close_wait=10 +net.netfilter.nf_conntrack_tcp_timeout_fin_wait=10 +net.netfilter.nf_conntrack_tcp_timeout_time_wait=10 +``` + +The `ip_local_port_range` should not overlap with the container port range pool +(default: 30000 to 31000). + +To apply netfilter settings at boot time, you may need to add `nf_conntrack` to `/etc/modules` +so that `sysctl` could set the `net.netfilter.nf_conntrack_*` values. + + +### For development + +#### Prerequisites + +* `libsnappy-dev` or `snappy-devel` system package depending on your distro +* Python 3.6 or higher with [pyenv](https://github.com/pyenv/pyenv) +and [pyenv-virtualenv](https://github.com/pyenv/pyenv-virtualenv) (optional but recommended) +* Docker 18.03 or later with docker-compose (18.09 or later is recommended) + +First, you need **a working manager installation**. +For the detailed instructions on installing the manager, please refer to +[the manager's README](https://github.com/lablup/backend.ai-manager/blob/master/README.md) +and come back here again. + +#### Preparing working copy + +Install and activate [`git-lfs`](https://git-lfs.github.com/) to work with pre-built binaries in +`src/ai/backend/runner`. + +```console +$ git lfs install +``` + +Next, prepare the source clone of the agent and install from it as follows. +`pyenv` is just a recommendation; you may use other virtualenv management tools. + +```console +$ git clone https://github.com/lablup/backend.ai-agent agent +$ cd agent +$ pyenv virtualenv venv-agent +$ pyenv local venv-agent +$ pip install -U pip setuptools +$ pip install -U -r requirements/dev.txt +``` + +### Linting + +We use `flake8` and `mypy` to statically check our code styles and type consistency. +Enable those linters in your favorite IDE or editor. + +### Halfstack (single-node development & testing) + +With the halfstack, you can run the agent simply. +Note that you need a working manager running with the halfstack already! + +#### Recommended directory structure + +* `backend.ai-dev` + - `manager` (git clone from [the manager repo](https://github.com/lablup/backend.ai-manager)) + - `agent` (git clone from here) + - `common` (git clone from [the common repo](https://github.com/lablup/backend.ai-common)) + +Install `backend.ai-common` as an editable package in the agent (and the manager) virtualenvs +to keep the codebase up-to-date.
+ +```console +$ cd agent +$ pip install -U -e ../common +``` + +#### Steps + +```console +$ mkdir -p "./scratches" +$ cp config/halfstack.toml ./agent.toml +``` + +If you're running the agent under Linux, make sure you've set the appropriate iptables rules +before starting the agent. This can be done by executing the script `scripts/update-metadata-iptables.sh` +before each agent start. + +Then, run it (for debugging, append a `--debug` flag): + +```console +$ python -m ai.backend.agent.server +``` + +To run the agent-watcher: + +```console +$ python -m ai.backend.agent.watcher +``` + +The watcher shares the same configuration TOML file with the agent. +Note that the watcher is only meaningful if the agent is installed as a systemd service +named `backendai-agent.service`. + +To run tests: + +```console +$ python -m flake8 src tests +$ python -m pytest -m 'not integration' tests +``` + + +## Deployment + +### Configuration + +Put a TOML-formatted agent configuration (see the sample in `config/sample.toml`) +in one of the following locations: + + * `agent.toml` (current working directory) + * `~/.config/backend.ai/agent.toml` (user-config directory) + * `/etc/backend.ai/agent.toml` (system-config directory) + +Only the first found one is used by the daemon. + +The agent reads most other configurations from the etcd v3 server where the cluster +administrator or the Backend.AI manager stores all the necessary settings. + +The etcd address and namespace must match those of the manager to make the agent +paired and activated. +By specifying distinguished namespaces, you may share a single etcd cluster with multiple +separate Backend.AI clusters. + +By default the agent uses the `/var/cache/scratches` directory for making temporary +home directories used by kernel containers (the `/home/work` volume mounted in +containers). Note that the directory must exist beforehand and the agent-running +user must have ownership of it. You can change the location via the +`scratch-root` option in `agent.toml`. + +### Running from a command line + +The minimal command to execute: + +```sh +python -m ai.backend.agent.server +python -m ai.backend.agent.watcher +``` + +For more arguments and options, run the command with the `--help` option. + +### Example config for systemd + +`/etc/systemd/system/backendai-agent.service`: + +```dosini +[Unit] +Description=Backend.AI Agent +Requires=docker.service +After=network.target remote-fs.target docker.service + +[Service] +Type=simple +User=root +Group=root +Environment=HOME=/home/user +ExecStart=/home/user/backend.ai/agent/run-agent.sh +WorkingDirectory=/home/user/backend.ai/agent +KillMode=process +KillSignal=SIGTERM +PrivateTmp=false +Restart=on-failure +RestartSec=5 + +[Install] +WantedBy=multi-user.target +``` + +`/home/user/backend.ai/agent/run-agent.sh`: + +```sh +#! /bin/sh +if [ -z "$PYENV_ROOT" ]; then + export PYENV_ROOT="$HOME/.pyenv" + export PATH="$PYENV_ROOT/bin:$PATH" +fi +eval "$(pyenv init -)" +eval "$(pyenv virtualenv-init -)" + +cd /home/user/backend.ai/agent +if [ "$#" -eq 0 ]; then + sh /home/user/backend.ai/agent/scripts/update-metadata-iptables.sh + exec python -m ai.backend.agent.server +else + exec "$@" +fi +``` + +### Networking + +The manager and agent should run in the same local network or different +networks reachable via VPNs, whereas the manager's API service must be exposed to +the public network or another private network that users have access to. + +The manager must be able to access TCP ports 6001, 6009, and 30000 to 31000 of the agents in default +configurations.
You can of course change those port numbers and ranges in the configuration. + +| Manager-to-Agent TCP Ports | Usage | +|:--------------------------:|-------| +| 6001 | ZeroMQ-based RPC calls from managers to agents | +| 6009 | HTTP watcher API | +| 30000-31000 | Port pool for in-container services | + +The operation of agent itself does not require both incoming/outgoing access to +the public Internet, but if the user's computation programs need the Internet, the docker containers +should be able to access the public Internet (maybe via some corporate firewalls). + +| Agent-to-X TCP Ports | Usage | +|:------------------------:|-------| +| manager:5002 | ZeroMQ-based event push from agents to the manager | +| etcd:2379 | etcd API access | +| redis:6379 | Redis API access | +| docker-registry:{80,443} | HTTP watcher API | +| (Other hosts) | Depending on user program requirements | + + +LICENSES +-------- + +[GNU Lesser General Public License](https://github.com/lablup/backend.ai-agent/blob/master/LICENSE) +[Dependencies](https://github.com/lablup/backend.ai-manager/blob/agent/DEPENDENCIES.md) diff --git a/src/ai/backend/agent/VERSION b/src/ai/backend/agent/VERSION new file mode 120000 index 0000000000..a4e948506b --- /dev/null +++ b/src/ai/backend/agent/VERSION @@ -0,0 +1 @@ +../../../../VERSION \ No newline at end of file diff --git a/src/ai/backend/agent/__init__.py b/src/ai/backend/agent/__init__.py new file mode 100644 index 0000000000..17b3552989 --- /dev/null +++ b/src/ai/backend/agent/__init__.py @@ -0,0 +1,3 @@ +from pathlib import Path + +__version__ = (Path(__file__).parent / 'VERSION').read_text().strip() diff --git a/src/ai/backend/agent/agent.py b/src/ai/backend/agent/agent.py new file mode 100644 index 0000000000..40eb717ac7 --- /dev/null +++ b/src/ai/backend/agent/agent.py @@ -0,0 +1,1807 @@ +from __future__ import annotations + +from abc import ABCMeta, abstractmethod +import asyncio +from collections import defaultdict +from decimal import Decimal +from io import BytesIO, SEEK_END +import json +import logging +from pathlib import Path +import pickle +import pkg_resources +import re +import signal +import sys +import traceback +from types import TracebackType +from typing import ( + Any, + AsyncIterator, + Awaitable, + Callable, + Collection, + Dict, + FrozenSet, + Generic, + Optional, + List, + Literal, + Mapping, + MutableMapping, + MutableSequence, + Sequence, + Set, + Tuple, + Type, + TypeVar, + Union, + TYPE_CHECKING, + cast, +) +import weakref + +import aioredis +import aiotools +from async_timeout import timeout +import attr +from cachetools import cached, LRUCache +import snappy +from tenacity import ( + AsyncRetrying, + stop_after_attempt, + stop_after_delay, + retry_if_exception_type, + wait_fixed, +) +import time +import zmq, zmq.asyncio + +from ai.backend.common import msgpack, redis +from ai.backend.common.docker import ( + ImageRef, + MIN_KERNELSPEC, + MAX_KERNELSPEC, +) +from ai.backend.common.logging import BraceStyleAdapter, pretty +from ai.backend.common.types import ( + AutoPullBehavior, + ContainerId, + KernelId, + SessionId, + DeviceName, + SlotName, + HardwareMetadata, + ImageRegistry, + ClusterInfo, + KernelCreationConfig, + KernelCreationResult, + MountTypes, + MountPermission, + Sentinel, + ServicePortProtocols, + VFolderMount, + aobject, +) +from ai.backend.common.events import ( + EventProducer, + AbstractEvent, + AgentErrorEvent, + AgentHeartbeatEvent, + AgentStartedEvent, + AgentTerminatedEvent, + DoSyncKernelLogsEvent, + DoSyncKernelStatsEvent, 
+ ExecutionCancelledEvent, + ExecutionFinishedEvent, + ExecutionStartedEvent, + ExecutionTimeoutEvent, + KernelCreatingEvent, + KernelPreparingEvent, + KernelPullingEvent, + KernelStartedEvent, + KernelTerminatedEvent, + SessionFailureEvent, + SessionSuccessEvent, +) +from ai.backend.common.utils import cancel_tasks, current_loop +from ai.backend.common.plugin.monitor import ErrorPluginContext, StatsPluginContext +from ai.backend.common.service_ports import parse_service_ports +from . import __version__ as VERSION +from .exception import AgentError, ResourceError +from .kernel import ( + AbstractKernel, + KernelFeatures, + match_distro_data, +) +from . import resources as resources_mod +from .resources import ( + AbstractComputeDevice, + AbstractComputePlugin, + AbstractAllocMap, + KernelResourceSpec, + Mount, +) +from .stats import ( + StatContext, StatModes, +) +from .types import ( + Container, + ContainerStatus, + ContainerLifecycleEvent, + LifecycleEvent, +) +from .utils import ( + generate_local_instance_id, + get_arch_name, +) + +if TYPE_CHECKING: + from ai.backend.common.etcd import AsyncEtcd + +log = BraceStyleAdapter(logging.getLogger('ai.backend.agent.agent')) + +_sentinel = Sentinel.TOKEN + +ACTIVE_STATUS_SET = frozenset([ + ContainerStatus.RUNNING, + ContainerStatus.RESTARTING, + ContainerStatus.PAUSED, +]) + +DEAD_STATUS_SET = frozenset([ + ContainerStatus.EXITED, + ContainerStatus.DEAD, + ContainerStatus.REMOVING, +]) + + +KernelObjectType = TypeVar('KernelObjectType', bound=AbstractKernel) + + +class AbstractKernelCreationContext(aobject, Generic[KernelObjectType]): + kspec_version: int + kernel_id: KernelId + kernel_config: KernelCreationConfig + local_config: Mapping[str, Any] + kernel_features: FrozenSet[str] + image_ref: ImageRef + internal_data: Mapping[str, Any] + restarting: bool + cancellation_handlers: Sequence[Callable[[], Awaitable[None]]] = [] + _rx_distro = re.compile(r"\.([a-z-]+\d+\.\d+)\.") + + def __init__( + self, + kernel_id: KernelId, + kernel_config: KernelCreationConfig, + local_config: Mapping[str, Any], + computers: MutableMapping[str, ComputerContext], + restarting: bool = False, + ) -> None: + self.image_labels = kernel_config['image']['labels'] + self.kspec_version = int(self.image_labels.get('ai.backend.kernelspec', '1')) + self.kernel_features = frozenset(self.image_labels.get('ai.backend.features', '').split()) + self.kernel_id = kernel_id + self.kernel_config = kernel_config + self.image_ref = ImageRef( + kernel_config['image']['canonical'], + known_registries=[kernel_config['image']['registry']['name']], + architecture=kernel_config['image'].get('architecture', get_arch_name()), + ) + self.internal_data = kernel_config['internal_data'] or {} + self.computers = computers + self.restarting = restarting + self.local_config = local_config + + @abstractmethod + async def get_extra_envs(self) -> Mapping[str, str]: + return {} + + @abstractmethod + async def prepare_resource_spec( + self, + ) -> Tuple[KernelResourceSpec, Optional[Mapping[str, Any]]]: + raise NotImplementedError + + @abstractmethod + async def prepare_scratch(self) -> None: + pass + + @abstractmethod + async def get_intrinsic_mounts(self) -> Sequence[Mount]: + return [] + + @abstractmethod + async def apply_network(self, cluster_info: ClusterInfo) -> None: + """ + Apply the given cluster network information to the deployment. 
+ """ + raise NotImplementedError + + @abstractmethod + async def install_ssh_keypair(self, cluster_info: ClusterInfo) -> None: + """ + Install the ssh keypair inside the kernel from cluster_info. + """ + raise NotImplementedError + + @abstractmethod + async def process_mounts(self, mounts: Sequence[Mount]): + raise NotImplementedError + + @abstractmethod + async def apply_accelerator_allocation(self, computer, device_alloc) -> None: + raise NotImplementedError + + @abstractmethod + def resolve_krunner_filepath(self, filename) -> Path: + """ + Return matching krunner path object for given filename. + """ + raise NotImplementedError + + @abstractmethod + def get_runner_mount( + self, + type: MountTypes, + src: Union[str, Path], + target: Union[str, Path], + perm: Literal['ro', 'rw'] = 'ro', + opts: Mapping[str, Any] = None, + ): + """ + Return mount object to mount target krunner file/folder/volume. + """ + raise NotImplementedError + + @abstractmethod + async def spawn( + self, + resource_spec: KernelResourceSpec, + environ: Mapping[str, str], + service_ports, + ) -> KernelObjectType: + raise NotImplementedError + + @abstractmethod + async def start_container( + self, + kernel_obj: AbstractKernel, + cmdargs: List[str], + resource_opts, + preopen_ports, + ) -> Mapping[str, Any]: + raise NotImplementedError + + @cached( + cache=LRUCache(maxsize=32), # type: ignore + key=lambda self: ( + self.image_ref, + self.kernel_config['image']['labels'].get('ai.backend.base-distro', 'ubuntu16.04'), + ), + ) + def get_krunner_info(self) -> Tuple[str, str, str, str, str]: + image_labels = self.kernel_config['image']['labels'] + distro = image_labels.get('ai.backend.base-distro', 'ubuntu16.04') + matched_distro, krunner_volume = match_distro_data( + self.local_config['container']['krunner-volumes'], distro) + matched_libc_style = 'glibc' + if distro.startswith('alpine'): + matched_libc_style = 'musl' + krunner_pyver = '3.6' # fallback + if m := re.search(r'^([a-z-]+)(\d+(\.\d+)*)?$', matched_distro): + matched_distro_pkgname = m.group(1).replace('-', '_') + try: + krunner_pyver = Path(pkg_resources.resource_filename( + f'ai.backend.krunner.{matched_distro_pkgname}', + f'krunner-python.{matched_distro}.txt', + )).read_text().strip() + except FileNotFoundError: + pass + log.debug('selected krunner: {}', matched_distro) + log.debug('selected libc style: {}', matched_libc_style) + log.debug('krunner volume: {}', krunner_volume) + log.debug('krunner python: {}', krunner_pyver) + arch = get_arch_name() + return arch, matched_distro, matched_libc_style, krunner_volume, krunner_pyver + + async def mount_vfolders( + self, + vfolders: Sequence[VFolderMount], + resource_spec: KernelResourceSpec, + ) -> None: + for vfolder in vfolders: + if self.internal_data.get('prevent_vfolder_mounts', False): + # Only allow mount of ".logs" directory to prevent expose + # internal-only information, such as Docker credentials to user's ".docker" vfolder + # in image importer kernels. + if vfolder.name != '.logs': + continue + mount = Mount( + MountTypes.BIND, + Path(vfolder.host_path), + Path(vfolder.kernel_path), + vfolder.mount_perm, + ) + resource_spec.mounts.append(mount) + + async def mount_krunner( + self, + resource_spec: KernelResourceSpec, + environ: MutableMapping[str, str], + ) -> None: + + def _mount( + type, src, dst, + ): + resource_spec.mounts.append( + self.get_runner_mount( + type, src, dst, + MountPermission('ro'), + ), + ) + + # Inject Backend.AI kernel runner dependencies. 
+ image_labels = self.kernel_config['image']['labels'] + distro = image_labels.get('ai.backend.base-distro', 'ubuntu16.04') + + arch, matched_distro, matched_libc_style, krunner_volume, krunner_pyver = \ + self.get_krunner_info() + artifact_path = Path(pkg_resources.resource_filename( + 'ai.backend.agent', '../runner')) + + def find_artifacts(pattern: str) -> Mapping[str, str]: + artifacts = {} + for p in artifact_path.glob(pattern): + m = self._rx_distro.search(p.name) + if m is not None: + artifacts[m.group(1)] = p.name + return artifacts + + suexec_candidates = find_artifacts(f"su-exec.*.{arch}.bin") + _, suexec_candidate = match_distro_data(suexec_candidates, distro) + suexec_path = self.resolve_krunner_filepath('runner/' + suexec_candidate) + + hook_candidates = find_artifacts(f"libbaihook.*.{arch}.so") + _, hook_candidate = match_distro_data(hook_candidates, distro) + hook_path = self.resolve_krunner_filepath('runner/' + hook_candidate) + + sftp_server_candidates = find_artifacts(f"sftp-server.*.{arch}.bin") + _, sftp_server_candidate = match_distro_data(sftp_server_candidates, distro) + sftp_server_path = self.resolve_krunner_filepath('runner/' + sftp_server_candidate) + + scp_candidates = find_artifacts(f"scp.*.{arch}.bin") + _, scp_candidate = match_distro_data(scp_candidates, distro) + scp_path = self.resolve_krunner_filepath('runner/' + scp_candidate) + + jail_path: Optional[Path] + if self.local_config['container']['sandbox-type'] == 'jail': + jail_candidates = find_artifacts(f"jail.*.{arch}.bin") + _, jail_candidate = match_distro_data(jail_candidates, distro) + jail_path = self.resolve_krunner_filepath('runner/' + jail_candidate) + else: + jail_path = None + + kernel_pkg_path = self.resolve_krunner_filepath('kernel') + helpers_pkg_path = self.resolve_krunner_filepath('helpers') + dropbear_path = self.resolve_krunner_filepath(f'runner/dropbear.{matched_libc_style}.{arch}.bin') + dropbearconv_path = \ + self.resolve_krunner_filepath(f'runner/dropbearconvert.{matched_libc_style}.{arch}.bin') + dropbearkey_path = \ + self.resolve_krunner_filepath(f'runner/dropbearkey.{matched_libc_style}.{arch}.bin') + tmux_path = self.resolve_krunner_filepath(f'runner/tmux.{matched_libc_style}.{arch}.bin') + dotfile_extractor_path = self.resolve_krunner_filepath('runner/extract_dotfiles.py') + persistent_files_warning_doc_path = \ + self.resolve_krunner_filepath('runner/DO_NOT_STORE_PERSISTENT_FILES_HERE.md') + entrypoint_sh_path = self.resolve_krunner_filepath('runner/entrypoint.sh') + + if matched_libc_style == 'musl': + terminfo_path = self.resolve_krunner_filepath('runner/terminfo.alpine3.8') + _mount(MountTypes.BIND, terminfo_path, '/home/work/.terminfo') + + _mount(MountTypes.BIND, dotfile_extractor_path, '/opt/kernel/extract_dotfiles.py') + _mount(MountTypes.BIND, entrypoint_sh_path, '/opt/kernel/entrypoint.sh') + _mount(MountTypes.BIND, suexec_path, '/opt/kernel/su-exec') + if jail_path is not None: + _mount(MountTypes.BIND, jail_path, '/opt/kernel/jail') + _mount(MountTypes.BIND, hook_path, '/opt/kernel/libbaihook.so') + _mount(MountTypes.BIND, dropbear_path, '/opt/kernel/dropbear') + _mount(MountTypes.BIND, dropbearconv_path, '/opt/kernel/dropbearconvert') + _mount(MountTypes.BIND, dropbearkey_path, '/opt/kernel/dropbearkey') + _mount(MountTypes.BIND, tmux_path, '/opt/kernel/tmux') + _mount(MountTypes.BIND, sftp_server_path, '/usr/libexec/sftp-server') + _mount(MountTypes.BIND, scp_path, '/usr/bin/scp') + _mount(MountTypes.BIND, persistent_files_warning_doc_path, + 
'/home/work/DO_NOT_STORE_PERSISTENT_FILES_HERE.md') + + _mount(MountTypes.VOLUME, krunner_volume, '/opt/backend.ai') + pylib_path = f'/opt/backend.ai/lib/python{krunner_pyver}/site-packages/' + _mount(MountTypes.BIND, kernel_pkg_path, + pylib_path + 'ai/backend/kernel') + _mount(MountTypes.BIND, helpers_pkg_path, + pylib_path + 'ai/backend/helpers') + environ['LD_PRELOAD'] = '/opt/kernel/libbaihook.so' + + # Inject ComputeDevice-specific env-varibles and hooks + already_injected_hooks: Set[Path] = set() + for dev_type, device_alloc in resource_spec.allocations.items(): + computer_set = self.computers[dev_type] + await self.apply_accelerator_allocation( + computer_set.instance, device_alloc, + ) + alloc_sum = Decimal(0) + for dev_id, per_dev_alloc in device_alloc.items(): + alloc_sum += sum(per_dev_alloc.values()) + if alloc_sum > 0: + hook_paths = await computer_set.instance.get_hooks(distro, arch) + if hook_paths: + log.debug('accelerator {} provides hooks: {}', + type(computer_set.instance).__name__, + ', '.join(map(str, hook_paths))) + for hook_path in map(lambda p: Path(p).absolute(), hook_paths): + if hook_path in already_injected_hooks: + continue + container_hook_path = f"/opt/kernel/{hook_path.name}" + _mount(MountTypes.BIND, hook_path, container_hook_path) + environ['LD_PRELOAD'] += ':' + container_hook_path + already_injected_hooks.add(hook_path) + + +KernelCreationContextType = TypeVar('KernelCreationContextType', bound=AbstractKernelCreationContext) + + +@attr.s(auto_attribs=True, slots=True) +class RestartTracker: + request_lock: asyncio.Lock + destroy_event: asyncio.Event + done_event: asyncio.Event + + +@attr.s(auto_attribs=True, slots=True) +class ComputerContext: + instance: AbstractComputePlugin + devices: Collection[AbstractComputeDevice] + alloc_map: AbstractAllocMap + + +class AbstractAgent(aobject, Generic[KernelObjectType, KernelCreationContextType], metaclass=ABCMeta): + + loop: asyncio.AbstractEventLoop + local_config: Mapping[str, Any] + etcd: AsyncEtcd + local_instance_id: str + kernel_registry: MutableMapping[KernelId, AbstractKernel] + computers: MutableMapping[str, ComputerContext] + images: Mapping[str, str] + port_pool: Set[int] + + redis: aioredis.Redis + zmq_ctx: zmq.asyncio.Context + + restarting_kernels: MutableMapping[KernelId, RestartTracker] + terminating_kernels: Set[KernelId] + timer_tasks: MutableSequence[asyncio.Task] + container_lifecycle_queue: asyncio.Queue[ContainerLifecycleEvent | Sentinel] + + stat_ctx: StatContext + stat_sync_sockpath: Path + stat_sync_task: asyncio.Task + + stats_monitor: StatsPluginContext # unused currently + error_monitor: ErrorPluginContext # unused in favor of produce_error_event() + + _pending_creation_tasks: Dict[KernelId, Set[asyncio.Task]] + _ongoing_exec_batch_tasks: weakref.WeakSet[asyncio.Task] + _ongoing_destruction_tasks: weakref.WeakValueDictionary[KernelId, asyncio.Task] + + def __init__( + self, + etcd: AsyncEtcd, + local_config: Mapping[str, Any], + *, + stats_monitor: StatsPluginContext, + error_monitor: ErrorPluginContext, + skip_initial_scan: bool = False, + ) -> None: + self._skip_initial_scan = skip_initial_scan + self.loop = current_loop() + self.etcd = etcd + self.local_config = local_config + self.local_instance_id = generate_local_instance_id(__file__) + self.kernel_registry = {} + self.computers = {} + self.images = {} # repoTag -> digest + self.restarting_kernels = {} + self.terminating_kernels = set() + self.stat_ctx = StatContext( + self, 
mode=StatModes(local_config['container']['stats-type']), + ) + self.timer_tasks = [] + self.port_pool = set(range( + local_config['container']['port-range'][0], + local_config['container']['port-range'][1] + 1, + )) + self.stats_monitor = stats_monitor + self.error_monitor = error_monitor + self._pending_creation_tasks = defaultdict(set) + self._ongoing_exec_batch_tasks = weakref.WeakSet() + self._ongoing_destruction_tasks = weakref.WeakValueDictionary() + + async def __ainit__(self) -> None: + """ + An implementation of AbstractAgent would define its own ``__ainit__()`` method. + It must call this super method in an appropriate order, only once. + """ + self.resource_lock = asyncio.Lock() + self.registry_lock = asyncio.Lock() + self.container_lifecycle_queue = asyncio.Queue() + + self.event_producer = await EventProducer.new( + self.local_config['redis'], + db=4, + log_events=self.local_config['debug']['log-events'], + ) + self.redis_stream_pool = redis.get_redis_object(self.local_config['redis'], db=4) + self.redis_stat_pool = redis.get_redis_object(self.local_config['redis'], db=0) + + self.zmq_ctx = zmq.asyncio.Context() + + resources_mod.log_alloc_map = self.local_config['debug']['log-alloc-map'] + computers, self.slots = await self.detect_resources() + for name, computer in computers.items(): + devices = await computer.list_devices() + alloc_map = await computer.create_alloc_map() + self.computers[name] = ComputerContext(computer, devices, alloc_map) + + if not self._skip_initial_scan: + self.images = await self.scan_images() + self.timer_tasks.append(aiotools.create_timer(self._scan_images_wrapper, 20.0)) + await self.scan_running_kernels() + + # Prepare stat collector tasks. + self.timer_tasks.append(aiotools.create_timer(self.collect_node_stat, 5.0)) + self.timer_tasks.append(aiotools.create_timer(self.collect_container_stat, 5.0)) + + # Prepare heartbeats. + self.timer_tasks.append(aiotools.create_timer(self.heartbeat, 3.0)) + + # Prepare auto-cleaning of idle kernels. + self.timer_tasks.append(aiotools.create_timer(self.sync_container_lifecycles, 10.0)) + + loop = current_loop() + self.last_registry_written_time = time.monotonic() + self.container_lifecycle_handler = loop.create_task(self.process_lifecycle_events()) + + # Notify the gateway. + await self.produce_event(AgentStartedEvent(reason="self-started")) + + async def shutdown(self, stop_signal: signal.Signals) -> None: + """ + An implementation of AbstractAgent would define its own ``shutdown()`` method. + It must call this super method in an appropriate order, only once. + """ + await cancel_tasks(self._ongoing_exec_batch_tasks) + + async with self.registry_lock: + # Close all pending kernel runners. + for kernel_obj in self.kernel_registry.values(): + if kernel_obj.runner is not None: + await kernel_obj.runner.close() + await kernel_obj.close() + if stop_signal == signal.SIGTERM: + await self.clean_all_kernels(blocking=True) + + # Stop timers. + cancel_results = await cancel_tasks(self.timer_tasks) + for result in cancel_results: + if isinstance(result, Exception): + log.error('timer cancellation error: {}', result) + + # Stop lifecycle event handler. + await self.container_lifecycle_queue.put(_sentinel) + await self.container_lifecycle_handler + + # Notify the gateway. + await self.produce_event(AgentTerminatedEvent(reason="shutdown")) + + # Shut down the event dispatcher and Redis connection pools. 
+ await self.event_producer.close() + await self.redis_stream_pool.close() + await self.redis_stat_pool.close() + + self.zmq_ctx.term() + + async def produce_event(self, event: AbstractEvent) -> None: + """ + Send an event to the manager(s). + """ + if self.local_config['debug']['log-heartbeats']: + _log = log.debug if isinstance(event, AgentHeartbeatEvent) else log.info + else: + _log = (lambda *args: None) if isinstance(event, AgentHeartbeatEvent) else log.info + if self.local_config['debug']['log-events']: + _log('produce_event({0})', event) + if isinstance(event, KernelTerminatedEvent): + pending_creation_tasks = self._pending_creation_tasks.get(event.kernel_id, None) + if pending_creation_tasks is not None: + for t in set(pending_creation_tasks): + if not t.done() and not t.cancelled(): + t.cancel() + try: + await t + except asyncio.CancelledError: + continue + await self.event_producer.produce_event(event, source=self.local_config['agent']['id']) + + async def produce_error_event( + self, + exc_info: Tuple[Type[BaseException], BaseException, TracebackType] = None, + ) -> None: + exc_type, exc, tb = sys.exc_info() if exc_info is None else exc_info + pretty_message = ''.join(traceback.format_exception_only(exc_type, exc)).strip() + pretty_tb = ''.join(traceback.format_tb(tb)).strip() + await self.produce_event(AgentErrorEvent(pretty_message, pretty_tb)) + + async def heartbeat(self, interval: float): + """ + Send my status information and available kernel images to the manager(s). + """ + res_slots = {} + try: + for cctx in self.computers.values(): + for slot_key, slot_type in cctx.instance.slot_types: + res_slots[slot_key] = ( + slot_type, + str(self.slots.get(slot_key, 0)), + ) + agent_info = { + 'ip': str(self.local_config['agent']['rpc-listen-addr'].host), + 'region': self.local_config['agent']['region'], + 'scaling_group': self.local_config['agent']['scaling-group'], + 'addr': f"tcp://{self.local_config['agent']['rpc-listen-addr']}", + 'resource_slots': res_slots, + 'version': VERSION, + 'compute_plugins': { + key: { + 'version': computer.instance.get_version(), + **(await computer.instance.extra_info()), + } + for key, computer in self.computers.items() + }, + 'images': snappy.compress(msgpack.packb([ + (repo_tag, digest) for repo_tag, digest in self.images.items() + ])), + 'architecture': get_arch_name(), + } + await self.produce_event(AgentHeartbeatEvent(agent_info)) + except asyncio.TimeoutError: + log.warning('event dispatch timeout: instance_heartbeat') + except Exception: + log.exception('instance_heartbeat failure') + await self.produce_error_event() + + async def collect_logs( + self, + kernel_id: KernelId, + container_id: str, + async_log_iterator: AsyncIterator[bytes], + ) -> None: + chunk_size = self.local_config['agent']['container-logs']['chunk-size'] + log_key = f'containerlog.{container_id}' + log_length = 0 + chunk_buffer = BytesIO() + chunk_length = 0 + try: + async with aiotools.aclosing(async_log_iterator): + async for fragment in async_log_iterator: + fragment_length = len(fragment) + chunk_buffer.write(fragment) + chunk_length += fragment_length + log_length += fragment_length + while chunk_length >= chunk_size: + cb = chunk_buffer.getbuffer() + stored_chunk = bytes(cb[:chunk_size]) + await redis.execute( + self.redis_stream_pool, + lambda r: r.rpush( + log_key, stored_chunk), + ) + remaining = cb[chunk_size:] + chunk_length = len(remaining) + next_chunk_buffer = BytesIO(remaining) + next_chunk_buffer.seek(0, SEEK_END) + del remaining, cb + 
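                        # Rotate buffers: close the filled one and continue appending to
                        # the BytesIO that already holds the bytes left over after the
                        # full chunk_size block was pushed to Redis.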
chunk_buffer.close() + chunk_buffer = next_chunk_buffer + assert chunk_length < chunk_size + if chunk_length > 0: + await redis.execute( + self.redis_stream_pool, + lambda r: r.rpush( + log_key, chunk_buffer.getvalue()), + ) + finally: + chunk_buffer.close() + # Keep the log for at most one hour in Redis. + # This is just a safety measure to prevent memory leak in Redis + # for cases when the event delivery has failed or processing + # the log data has failed. + await redis.execute( + self.redis_stream_pool, + lambda r: r.expire(log_key, 3600), + ) + await self.produce_event(DoSyncKernelLogsEvent(kernel_id, container_id)) + + async def collect_node_stat(self, interval: float): + if self.local_config['debug']['log-stats']: + log.debug('collecting node statistics') + try: + await self.stat_ctx.collect_node_stat() + except asyncio.CancelledError: + pass + except Exception: + log.exception('unhandled exception while syncing node stats') + await self.produce_error_event() + + async def collect_container_stat(self, interval: float): + if self.local_config['debug']['log-stats']: + log.debug('collecting container statistics') + try: + updated_kernel_ids = [] + container_ids = [] + async with self.registry_lock: + for kernel_id, kernel_obj in [*self.kernel_registry.items()]: + if not kernel_obj.stats_enabled: + continue + updated_kernel_ids.append(kernel_id) + container_ids.append(kernel_obj['container_id']) + await self.stat_ctx.collect_container_stat(container_ids) + # Let the manager store the statistics in the persistent database. + if updated_kernel_ids: + await self.produce_event(DoSyncKernelStatsEvent(updated_kernel_ids)) + except asyncio.CancelledError: + pass + except Exception: + log.exception('unhandled exception while syncing container stats') + await self.produce_error_event() + + async def _handle_start_event(self, ev: ContainerLifecycleEvent) -> None: + async with self.registry_lock: + kernel_obj = self.kernel_registry.get(ev.kernel_id) + if kernel_obj is not None: + kernel_obj.stats_enabled = True + + async def _handle_destroy_event(self, ev: ContainerLifecycleEvent) -> None: + try: + current_task = asyncio.current_task() + assert current_task is not None + if ev.kernel_id not in self._ongoing_destruction_tasks: + self._ongoing_destruction_tasks[ev.kernel_id] = current_task + self.terminating_kernels.add(ev.kernel_id) + async with self.registry_lock: + kernel_obj = self.kernel_registry.get(ev.kernel_id) + if kernel_obj is None: + log.warning('destroy_kernel(k:{0}) kernel missing (already dead?)', + ev.kernel_id) + if ev.container_id is None: + await self.rescan_resource_usage() + if not ev.suppress_events: + await self.produce_event( + KernelTerminatedEvent(ev.kernel_id, "already-terminated"), + ) + if ev.done_future is not None: + ev.done_future.set_result(None) + return + else: + await self.container_lifecycle_queue.put( + ContainerLifecycleEvent( + ev.kernel_id, + ev.container_id, + LifecycleEvent.CLEAN, + ev.reason, + suppress_events=ev.suppress_events, + done_future=ev.done_future, + ), + ) + else: + kernel_obj.stats_enabled = False + kernel_obj.termination_reason = ev.reason + if kernel_obj.runner is not None: + await kernel_obj.runner.close() + kernel_obj.clean_event = ev.done_future + try: + await self.destroy_kernel(ev.kernel_id, ev.container_id) + except Exception as e: + if ev.done_future is not None: + ev.done_future.set_exception(e) + raise + except asyncio.CancelledError: + pass + except Exception: + log.exception('unhandled exception while processing DESTROY event') + 
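            # Report the traceback to the manager as well (as an AgentErrorEvent).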
await self.produce_error_event() + + async def _handle_clean_event(self, ev: ContainerLifecycleEvent) -> None: + destruction_task = self._ongoing_destruction_tasks.get(ev.kernel_id, None) + if destruction_task is not None and not destruction_task.done(): + # let the destruction task finish first + await destruction_task + del destruction_task + async with self.registry_lock: + try: + kernel_obj = self.kernel_registry.get(ev.kernel_id) + if kernel_obj is not None and kernel_obj.runner is not None: + await kernel_obj.runner.close() + await self.clean_kernel( + ev.kernel_id, + ev.container_id, + ev.kernel_id in self.restarting_kernels, + ) + except Exception as e: + if ev.done_future is not None: + ev.done_future.set_exception(e) + await self.produce_error_event() + finally: + if ev.kernel_id in self.restarting_kernels: + # Don't forget as we are restarting it. + kernel_obj = self.kernel_registry.get(ev.kernel_id, None) + else: + # Forget as we are done with this kernel. + kernel_obj = self.kernel_registry.pop(ev.kernel_id, None) + try: + if kernel_obj is not None: + # Restore used ports to the port pool. + port_range = self.local_config['container']['port-range'] + # Exclude out-of-range ports, because when the agent restarts + # with a different port range, existing containers' host ports + # may not belong to the new port range. + restored_ports = [*filter( + lambda p: port_range[0] <= p <= port_range[1], + kernel_obj['host_ports'], + )] + self.port_pool.update(restored_ports) + await kernel_obj.close() + finally: + self.terminating_kernels.discard(ev.kernel_id) + if restart_tracker := self.restarting_kernels.get(ev.kernel_id, None): + restart_tracker.destroy_event.set() + else: + await self.rescan_resource_usage() + if not ev.suppress_events: + await self.produce_event( + KernelTerminatedEvent(ev.kernel_id, ev.reason), + ) + # Notify cleanup waiters after all state updates. + if kernel_obj is not None and kernel_obj.clean_event is not None: + kernel_obj.clean_event.set_result(None) + if ev.done_future is not None and not ev.done_future.done(): + ev.done_future.set_result(None) + + async def process_lifecycle_events(self) -> None: + + async def lifecycle_task_exception_handler( + exc_type: Type[Exception], exc_obj: Exception, tb: TracebackType, + ) -> None: + log.exception("unexpected error in lifecycle task", exc_info=exc_obj) + + async with aiotools.PersistentTaskGroup( + exception_handler=lifecycle_task_exception_handler, + ) as tg: + ipc_base_path = self.local_config['agent']['ipc-base-path'] + while True: + ev = await self.container_lifecycle_queue.get() + now = time.monotonic() + if now > self.last_registry_written_time + 60 or isinstance(ev, Sentinel): + self.last_registry_written_time = now + with open(ipc_base_path / f'last_registry.{self.local_instance_id}.dat', 'wb') as f: + pickle.dump(self.kernel_registry, f) + log.debug(f'saved last_registry.{self.local_instance_id}.dat') + if isinstance(ev, Sentinel): + return + # attr currently does not support customizing getstate/setstate dunder methods + # until the next release. 
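                # The snapshot (last_registry.<instance_id>.dat) is rewritten at most once
                # per minute and once more when the shutdown sentinel arrives; on the next
                # start, scan_running_kernels() unpickles it to rebuild self.kernel_registry
                # for containers that survived the agent restart.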
+ if self.local_config['debug']['log-events']: + log.info(f'lifecycle event: {ev!r}') + try: + if ev.event == LifecycleEvent.START: + tg.create_task(self._handle_start_event(ev)) + elif ev.event == LifecycleEvent.DESTROY: + tg.create_task(self._handle_destroy_event(ev)) + elif ev.event == LifecycleEvent.CLEAN: + tg.create_task(self._handle_clean_event(ev)) + else: + log.warning('unsupported lifecycle event: {!r}', ev) + except Exception: + log.exception( + 'unexpected error in process_lifecycle_events(): {!r}, continuing...', ev, + ) + finally: + self.container_lifecycle_queue.task_done() + + async def inject_container_lifecycle_event( + self, + kernel_id: KernelId, + event: LifecycleEvent, + reason: str, + *, + container_id: ContainerId = None, + exit_code: int = None, + done_future: asyncio.Future = None, + suppress_events: bool = False, + ) -> None: + try: + kernel_obj = self.kernel_registry[kernel_id] + if kernel_obj.termination_reason: + reason = kernel_obj.termination_reason + if container_id is not None: + if event == LifecycleEvent.START: + # Update the container ID (for restarted kernels). + # This will be overwritten by create_kernel() soon, but + # updating here improves consistency of kernel_id to container_id + # mapping earlier. + kernel_obj['container_id'] = container_id + elif container_id != kernel_obj['container_id']: + # This should not happen! + log.warning( + "container id mismatch for kernel_obj (k:{}, c:{}) with event (e:{}, c:{})", + kernel_id, kernel_obj['container_id'], + event.name, container_id, + ) + container_id = kernel_obj['container_id'] + except KeyError: + if event == LifecycleEvent.START: + # When creating a new kernel, the kernel_registry is not populated yet + # during creation of actual containers. + # The Docker daemon may publish the container creation event before + # returning the API and our async handlers may deliver the event earlier. + # In such cases, it is safe to ignore the missing kernel_regisry item. + pass + else: + log.warning( + "injecting lifecycle event (e:{}) for unknown kernel (k:{})", + event.name, kernel_id, + ) + await self.container_lifecycle_queue.put( + ContainerLifecycleEvent( + kernel_id, + container_id, + event, + reason, + done_future, + exit_code, + suppress_events, + ), + ) + + @abstractmethod + async def enumerate_containers( + self, + status_filter: FrozenSet[ContainerStatus] = ACTIVE_STATUS_SET, + ) -> Sequence[Tuple[KernelId, Container]]: + """ + Enumerate the containers with the given status filter. + """ + + async def rescan_resource_usage(self) -> None: + async with self.resource_lock: + for computer_set in self.computers.values(): + computer_set.alloc_map.clear() + for kernel_id, container in (await self.enumerate_containers()): + for computer_set in self.computers.values(): + try: + await computer_set.instance.restore_from_container( + container, + computer_set.alloc_map, + ) + except Exception: + log.warning( + "rescan_resoucre_usage(k:{}): " + "failed to read kernel resource info; " + "maybe already terminated", + kernel_id, + ) + + async def sync_container_lifecycles(self, interval: float) -> None: + """ + Periodically synchronize the alive/known container sets, + for cases when we miss the container lifecycle events from the underlying implementation APIs + due to the agent restarts or crashes. 
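        Three checks run while holding ``resource_lock``: containers found in a dead
        state and registered kernels whose containers have disappeared are enqueued as
        CLEAN events with reason ``self-terminated``, while running containers unknown
        to this agent are enqueued as DESTROY events with ``terminated-unknown-container``.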
+ """ + known_kernels: Dict[KernelId, ContainerId] = {} + alive_kernels: Dict[KernelId, ContainerId] = {} + terminated_kernels = {} + + async with self.resource_lock: + try: + # Check if: there are dead containers + for kernel_id, container in (await self.enumerate_containers(DEAD_STATUS_SET)): + if kernel_id in self.restarting_kernels or kernel_id in self.terminating_kernels: + continue + log.info('detected dead container during lifeycle sync (k:{}, c:{})', + kernel_id, container.id) + terminated_kernels[kernel_id] = ContainerLifecycleEvent( + kernel_id, + known_kernels[kernel_id], + LifecycleEvent.CLEAN, + 'self-terminated', + ) + for kernel_id, container in (await self.enumerate_containers(ACTIVE_STATUS_SET)): + alive_kernels[kernel_id] = container.id + for kernel_id, kernel_obj in self.kernel_registry.items(): + known_kernels[kernel_id] = kernel_obj['container_id'] + # Check if: kernel_registry has the container but it's gone. + for kernel_id in (known_kernels.keys() - alive_kernels.keys()): + if kernel_id in self.restarting_kernels or kernel_id in self.terminating_kernels: + continue + terminated_kernels[kernel_id] = ContainerLifecycleEvent( + kernel_id, + known_kernels[kernel_id], + LifecycleEvent.CLEAN, + 'self-terminated', + ) + # Check if: there are containers not spawned by me. + for kernel_id in (alive_kernels.keys() - known_kernels.keys()): + if kernel_id in self.restarting_kernels: + continue + terminated_kernels[kernel_id] = ContainerLifecycleEvent( + kernel_id, + alive_kernels[kernel_id], + LifecycleEvent.DESTROY, + 'terminated-unknown-container', + ) + finally: + # Enqueue the events. + for kernel_id, ev in terminated_kernels.items(): + await self.container_lifecycle_queue.put(ev) + + async def clean_all_kernels(self, blocking: bool = False) -> None: + kernel_ids = [*self.kernel_registry.keys()] + clean_events = {} + loop = asyncio.get_running_loop() + if blocking: + for kernel_id in kernel_ids: + clean_events[kernel_id] = loop.create_future() + for kernel_id in kernel_ids: + await self.inject_container_lifecycle_event( + kernel_id, + LifecycleEvent.DESTROY, + 'agent-termination', + done_future=clean_events[kernel_id] if blocking else None, + ) + if blocking: + waiters = [clean_events[kernel_id] for kernel_id in kernel_ids] + await asyncio.gather(*waiters) + + @abstractmethod + async def detect_resources(self) -> Tuple[ + Mapping[DeviceName, AbstractComputePlugin], + Mapping[SlotName, Decimal], + ]: + """ + Scan and define the amount of available resource slots in this node. + """ + + async def gather_hwinfo(self) -> Mapping[str, HardwareMetadata]: + """ + Collect the hardware metadata from the compute plugins. + """ + hwinfo: Dict[str, HardwareMetadata] = {} + tasks = [] + + async def _get( + key: str, plugin: AbstractComputePlugin, + ) -> Tuple[str, Union[Exception, HardwareMetadata]]: + try: + result = await plugin.get_node_hwinfo() + return key, result + except Exception as e: + return key, e + + for key, plugin in self.computers.items(): + tasks.append(_get(key, plugin.instance)) + results = await asyncio.gather(*tasks, return_exceptions=True) + for key, result in results: + if isinstance(result, NotImplementedError): + continue + elif isinstance(result, Exception): + hwinfo[key] = { + 'status': "unavailable", + 'status_info': str(result), + 'metadata': {}, + } + else: + hwinfo[key] = result + return hwinfo + + @abstractmethod + async def scan_images(self) -> Mapping[str, str]: + """ + Scan the available kernel images/templates and update ``self.images``. 
+ This is called periodically to keep the image list up-to-date and allow + manual image addition and deletions by admins. + """ + + async def _scan_images_wrapper(self, interval: float) -> None: + self.images = await self.scan_images() + + @abstractmethod + async def pull_image(self, image_ref: ImageRef, registry_conf: ImageRegistry) -> None: + ''' + Pull the given image from the given registry. + ''' + + @abstractmethod + async def check_image(self, image_ref: ImageRef, image_id: str, auto_pull: AutoPullBehavior) -> bool: + ''' + Check the availability of the image and return a boolean flag that indicates whether + the agent should try pulling the image from a registry. + ''' + return False + + async def scan_running_kernels(self) -> None: + """ + Scan currently running kernels and recreate the kernel objects in + ``self.kernel_registry`` if any missing. + """ + try: + ipc_base_path = self.local_config['agent']['ipc-base-path'] + with open(ipc_base_path / f'last_registry.{self.local_instance_id}.dat', 'rb') as f: + self.kernel_registry = pickle.load(f) + for kernel_obj in self.kernel_registry.values(): + kernel_obj.agent_config = self.local_config + if kernel_obj.runner is not None: + await kernel_obj.runner.__ainit__() + except FileNotFoundError: + pass + async with self.resource_lock: + for kernel_id, container in (await self.enumerate_containers( + ACTIVE_STATUS_SET | DEAD_STATUS_SET, + )): + if container.status in ACTIVE_STATUS_SET: + kernelspec = int(container.labels.get('ai.backend.kernelspec', '1')) + if not (MIN_KERNELSPEC <= kernelspec <= MAX_KERNELSPEC): + continue + # Consume the port pool. + for p in container.ports: + if p.host_port is not None: + self.port_pool.discard(p.host_port) + # Restore compute resources. + for computer_set in self.computers.values(): + await computer_set.instance.restore_from_container( + container, + computer_set.alloc_map, + ) + await self.inject_container_lifecycle_event( + kernel_id, + LifecycleEvent.START, + 'resuming-agent-operation', + container_id=container.id, + ) + elif container.status in DEAD_STATUS_SET: + log.info('detected dead container while agent is down (k:{}, c:{})', + kernel_id, container.id) + await self.inject_container_lifecycle_event( + kernel_id, + LifecycleEvent.CLEAN, + 'self-terminated', + container_id=container.id, + ) + + log.info('starting with resource allocations') + for computer_name, computer_ctx in self.computers.items(): + log.info('{}: {!r}', + computer_name, + dict(computer_ctx.alloc_map.allocations)) + + @abstractmethod + async def init_kernel_context( + self, + kernel_id: KernelId, + kernel_config: KernelCreationConfig, + *, + restarting: bool = False, + ) -> AbstractKernelCreationContext: + raise NotImplementedError + + async def execute_batch( + self, + kernel_id: KernelId, + startup_command: str, + ) -> None: + kernel_obj = self.kernel_registry.get(kernel_id, None) + if kernel_obj is None: + log.warning('execute_batch(k:{}): no such kernel', kernel_id) + return + log.debug('execute_batch(k:{}): executing {!r}', kernel_id, (startup_command or '')[:60]) + mode: Literal['batch', 'continue'] = 'batch' + opts = { + 'exec': startup_command, + } + try: + while True: + try: + result = await self.execute( + kernel_id, + 'batch-job', # a reserved run ID + mode, + '', + opts=opts, + flush_timeout=1.0, + api_version=3) + except KeyError: + await self.produce_event( + KernelTerminatedEvent(kernel_id, "self-terminated"), + ) + break + + if result['status'] == 'finished': + if result['exitCode'] == 0: + await 
self.produce_event( + SessionSuccessEvent(SessionId(kernel_id), "task-done", 0), + ) + else: + await self.produce_event( + SessionFailureEvent(SessionId(kernel_id), "task-failed", result['exitCode']), + ) + break + if result['status'] == 'exec-timeout': + await self.produce_event( + SessionFailureEvent(SessionId(kernel_id), "task-timeout", -2), + ) + break + opts = { + 'exec': '', + } + mode = 'continue' + except asyncio.CancelledError: + await self.produce_event( + SessionFailureEvent(SessionId(kernel_id), "task-cancelled", -2), + ) + + async def create_kernel( + self, + creation_id: str, + session_id: SessionId, + kernel_id: KernelId, + kernel_config: KernelCreationConfig, + cluster_info: ClusterInfo, + *, + restarting: bool = False, + ) -> KernelCreationResult: + """ + Create a new kernel. + """ + + if not restarting: + await self.produce_event( + KernelPreparingEvent(kernel_id, creation_id), + ) + + # Initialize the creation context + if self.local_config['debug']['log-kernel-config']: + log.debug('Kernel creation config: {0}', pretty(kernel_config)) + ctx = await self.init_kernel_context( + kernel_id, kernel_config, + restarting=restarting, + ) + environ: MutableMapping[str, str] = {**kernel_config['environ']} + + # Inject Backend.AI-intrinsic env-variables for gosu + if KernelFeatures.UID_MATCH in ctx.kernel_features: + uid = self.local_config['container']['kernel-uid'] + gid = self.local_config['container']['kernel-gid'] + environ['LOCAL_USER_ID'] = str(uid) + environ['LOCAL_GROUP_ID'] = str(gid) + environ.update( + await ctx.get_extra_envs(), + ) + image_labels = kernel_config['image']['labels'] + + agent_architecture = get_arch_name() + if agent_architecture != ctx.image_ref.architecture: + # disable running different architecture's image + raise AgentError( + f'cannot run {ctx.image_ref.architecture} image on {agent_architecture} machine', + ) + + # Check if we need to pull the container image + do_pull = await self.check_image( + ctx.image_ref, + kernel_config['image']['digest'], + AutoPullBehavior(kernel_config.get('auto_pull', 'digest')), + ) + if do_pull: + await self.produce_event( + KernelPullingEvent(kernel_id, creation_id, ctx.image_ref.canonical), + ) + await self.pull_image(ctx.image_ref, kernel_config['image']['registry']) + + if not restarting: + await self.produce_event( + KernelCreatingEvent(kernel_id, creation_id), + ) + + # Get the resource spec from existing kernel scratches + # or create a new resource spec from ctx.kernel_config + resource_spec, resource_opts = await ctx.prepare_resource_spec() + # When creating a new kernel, + # we need to allocate agent resources, prepare the networks, + # adn specify the container mounts. + + # Mount backend-specific intrinsic mounts (e.g., scratch directories) + if not restarting: + resource_spec.mounts.extend( + await ctx.get_intrinsic_mounts(), + ) + + # Realize ComputeDevice (including accelerators) allocations. 
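        # Group the requested slots by their device-name prefix (e.g. a slot such as
        # "cuda.shares" belongs to the "cuda" plugin; the names here are illustrative)
        # and let each plugin's alloc_map perform the allocation under resource_lock,
        # so that concurrent kernel creations cannot oversubscribe the node.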
+ slots = resource_spec.slots + dev_names: Set[DeviceName] = set() + for slot_name in slots.keys(): + dev_name = slot_name.split('.', maxsplit=1)[0] + dev_names.add(DeviceName(dev_name)) + + if not restarting: + async with self.resource_lock: + for dev_name in dev_names: + computer_set = self.computers[dev_name] + device_specific_slots = { + SlotName(slot_name): Decimal(alloc) + for slot_name, alloc in slots.items() + if slot_name.startswith(dev_name) + } + try: + # TODO: support allocate_evenly() + resource_spec.allocations[dev_name] = \ + computer_set.alloc_map.allocate( + device_specific_slots, + context_tag=dev_name) + except ResourceError as e: + log.info( + "resource allocation failed ({}): {} of {}\n" + "(alloc map: {})", + type(e).__name__, device_specific_slots, dev_name, + dict(computer_set.alloc_map.allocations), + ) + raise + + # Prepare scratch spaces and dotfiles inside it. + if not restarting: + await ctx.prepare_scratch() + + # Prepare networking. + await ctx.apply_network(cluster_info) + await ctx.install_ssh_keypair(cluster_info) + + # Mount vfolders and krunner stuffs. + if not restarting: + vfolder_mounts = [VFolderMount.from_json(item) for item in kernel_config['mounts']] + await ctx.mount_vfolders(vfolder_mounts, resource_spec) + await ctx.mount_krunner(resource_spec, environ) + + # Inject Backend.AI-intrinsic env-variables for libbaihook and gosu + label_envs_corecount = image_labels.get('ai.backend.envs.corecount', '') + envs_corecount = label_envs_corecount.split(',') if label_envs_corecount else [] + cpu_core_count = len(resource_spec.allocations[DeviceName('cpu')][SlotName('cpu')]) + environ.update({k: str(cpu_core_count) for k in envs_corecount if k not in environ}) + + # Realize mounts. + await ctx.process_mounts(resource_spec.mounts) + + # Get attached devices information (including model_name). 
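        # Each compute plugin reports the concrete devices backing its allocation; the
        # result is returned to the manager as "attached_devices" in the kernel
        # creation result below.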
+ attached_devices = {} + for dev_name, device_alloc in resource_spec.allocations.items(): + computer_set = self.computers[dev_name] + devices = await computer_set.instance.get_attached_devices(device_alloc) + attached_devices[dev_name] = devices + + exposed_ports = [2000, 2001] + service_ports = [] + port_map = {} + preopen_ports = ctx.kernel_config.get('preopen_ports') + if preopen_ports is None: + preopen_ports = [] + + if ctx.kernel_config['cluster_role'] in ('main', 'master'): + for sport in parse_service_ports(image_labels.get('ai.backend.service-ports', '')): + port_map[sport['name']] = sport + port_map['sshd'] = { + 'name': 'sshd', + 'protocol': ServicePortProtocols('tcp'), + 'container_ports': (2200,), + 'host_ports': (None,), + } + port_map['ttyd'] = { + 'name': 'ttyd', + 'protocol': ServicePortProtocols('http'), + 'container_ports': (7681,), + 'host_ports': (None,), + } + for port_no in preopen_ports: + sport = { + 'name': str(port_no), + 'protocol': ServicePortProtocols('preopen'), + 'container_ports': (port_no,), + 'host_ports': (None,), + } + service_ports.append(sport) + for cport in sport['container_ports']: + exposed_ports.append(cport) + for sport in port_map.values(): + service_ports.append(sport) + for cport in sport['container_ports']: + exposed_ports.append(cport) + log.debug('exposed ports: {!r}', exposed_ports) + + runtime_type = image_labels.get('ai.backend.runtime-type', 'python') + runtime_path = image_labels.get('ai.backend.runtime-path', None) + cmdargs: List[str] = [] + if self.local_config['container']['sandbox-type'] == 'jail': + cmdargs += [ + "/opt/kernel/jail", + "-policy", "/etc/backend.ai/jail/policy.yml", + ] + if self.local_config['container']['jail-args']: + cmdargs += map(lambda s: s.strip(), self.local_config['container']['jail-args']) + cmdargs += [ + "/opt/backend.ai/bin/python", + "-m", "ai.backend.kernel", runtime_type, + ] + if runtime_path is not None: + cmdargs.append(runtime_path) + + # Store information required for restarts. + # NOTE: kconfig may be updated after restarts. + resource_spec.freeze() + await self.restart_kernel__store_config( + kernel_id, 'kconfig.dat', + pickle.dumps(ctx.kernel_config), + ) + if not restarting: + await self.restart_kernel__store_config( + kernel_id, 'cluster.json', + json.dumps(cluster_info).encode('utf8'), + ) + + if self.local_config['debug']['log-kernel-config']: + log.info('kernel starting with resource spec: \n{0}', + pretty(attr.asdict(resource_spec))) + kernel_obj: KernelObjectType = await ctx.spawn( + resource_spec, + environ, + service_ports, + ) + async with self.registry_lock: + self.kernel_registry[ctx.kernel_id] = kernel_obj + container_data = await ctx.start_container( + kernel_obj, + cmdargs, + resource_opts, + preopen_ports, + ) + async with self.registry_lock: + self.kernel_registry[ctx.kernel_id].data.update(container_data) + await kernel_obj.init() + + current_task = asyncio.current_task() + assert current_task is not None + self._pending_creation_tasks[kernel_id].add(current_task) + try: + async for attempt in AsyncRetrying( + wait=wait_fixed(0.3), + stop=(stop_after_attempt(10) | stop_after_delay(60)), + retry=retry_if_exception_type(zmq.error.ZMQError), + ): + with attempt: + # Wait until bootstrap script is executed. + # - Main kernel runner is executed after bootstrap script, and + # check_status is accessible only after kernel runner is loaded. 
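                    # check_status() is polled every 0.3 s, retrying on ZMQError until
                    # either 10 attempts or 60 seconds have elapsed (see the AsyncRetrying
                    # arguments above), because the runner's control socket appears only
                    # after the bootstrap script has finished.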
+ await kernel_obj.check_status() + # Update the service-ports metadata from the image labels + # with the extended template metadata from the agent and krunner. + live_services = await kernel_obj.get_service_apps() + if live_services['status'] != 'failed': + for live_service in live_services['data']: + for service_port in service_ports: + if live_service['name'] == service_port['name']: + service_port.update(live_service) + break + if self.local_config['debug']['log-kernel-config']: + log.debug('service ports:\n{!r}', pretty(service_ports)) + except asyncio.CancelledError: + log.warning("cancelled waiting of container startup (k:{})", kernel_id) + raise + except Exception: + log.exception("unexpected error while waiting container startup (k:{})", kernel_id) + raise RuntimeError( + "cancelled waiting of container startup due to initialization failure", + ) + finally: + self._pending_creation_tasks[kernel_id].remove(current_task) + if not self._pending_creation_tasks[kernel_id]: + del self._pending_creation_tasks[kernel_id] + + # Finally we are done. + await self.produce_event( + KernelStartedEvent(kernel_id, creation_id), + ) + + if kernel_config['session_type'] == 'batch' and kernel_config['cluster_role'] == 'main': + self._ongoing_exec_batch_tasks.add( + asyncio.create_task( + self.execute_batch(kernel_id, kernel_config['startup_command'] or ""), + ), + ) + + # The startup command for the batch-type sessions will be executed by the manager + # upon firing of the "session_started" event. + + return { + 'id': KernelId(kernel_id), + 'kernel_host': str(kernel_obj['kernel_host']), + 'repl_in_port': kernel_obj['repl_in_port'], + 'repl_out_port': kernel_obj['repl_out_port'], + 'stdin_port': kernel_obj['stdin_port'], # legacy + 'stdout_port': kernel_obj['stdout_port'], # legacy + 'service_ports': service_ports, + 'container_id': kernel_obj['container_id'], + 'resource_spec': resource_spec.to_json_serializable_dict(), + 'attached_devices': attached_devices, + } + + @abstractmethod + async def destroy_kernel( + self, + kernel_id: KernelId, + container_id: Optional[ContainerId], + ) -> None: + """ + Initiate destruction of the kernel. + + Things to do: + * Send SIGTERM to the kernel's main process. + * Send SIGKILL if it's not terminated within a few seconds. + """ + + @abstractmethod + async def clean_kernel( + self, + kernel_id: KernelId, + container_id: Optional[ContainerId], + restarting: bool, + ) -> None: + """ + Clean up kernel-related book-keepers when the underlying + implementation detects an event that the kernel has terminated. + + Things to do: + * Call :meth:`self.collect_logs()` to store the container's console outputs. + * Delete the underlying kernel resource (e.g., container) + * Release host-specific resources used for the kernel (e.g., scratch spaces) + + This method is intended to be called asynchronously by the implementation-specific + event monitoring routine. + + The ``container_id`` may be ``None`` if the container has already gone away. + In such cases, skip container-specific cleanups. + """ + + @abstractmethod + async def create_overlay_network(self, network_name: str) -> None: + """ + Create an overlay network for a multi-node multicontainer session, where containers in different + agents can connect to each other using cluster hostnames without explicit port mapping. + + This is called by the manager before kernel creation. + It may raise :exc:`NotImplementedError` and then the manager + will cancel creation of the session. 
+ """ + + @abstractmethod + async def destroy_overlay_network(self, network_name: str) -> None: + """ + Destroy an overlay network. + + This is called by the manager after kernel destruction. + """ + + @abstractmethod + async def create_local_network(self, network_name: str) -> None: + """ + Create a local bridge network for a single-node multicontainer session, where containers in the + same agent can connect to each other using cluster hostnames without explicit port mapping. + Depending on the backend, this may be an alias to :meth:`create_overlay_network()`. + + This is called by the manager before kernel creation. + It may raise :exc:`NotImplementedError` and then the manager + will cancel creation of the session. + """ + + @abstractmethod + async def destroy_local_network(self, network_name: str) -> None: + """ + Destroy a local bridge network. + Depending on the backend, this may be an alias to :meth:`destroy_overlay_network()`. + + This is called by the manager after kernel destruction. + """ + + @abstractmethod + async def restart_kernel__load_config( + self, + kernel_id: KernelId, + name: str, + ) -> bytes: + """ + Restore the cluster config from a previous launch of the kernel. + """ + pass + + @abstractmethod + async def restart_kernel__store_config( + self, + kernel_id: KernelId, + name: str, + data: bytes, + ) -> None: + """ + Store the cluster config to a kernel-related storage (e.g., scratch space), + so that restarts of this kernel can reuse the configuration. + """ + pass + + async def restart_kernel( + self, + creation_id: str, + session_id: SessionId, + kernel_id: KernelId, + updating_kernel_config: KernelCreationConfig, + ): + tracker = self.restarting_kernels.get(kernel_id) + if tracker is None: + tracker = RestartTracker( + request_lock=asyncio.Lock(), + destroy_event=asyncio.Event(), + done_event=asyncio.Event()) + self.restarting_kernels[kernel_id] = tracker + + existing_kernel_config = pickle.loads( + await self.restart_kernel__load_config(kernel_id, 'kconfig.dat'), + ) + existing_cluster_info = json.loads( + await self.restart_kernel__load_config(kernel_id, 'cluster.json'), + ) + kernel_config = cast( + KernelCreationConfig, + {**existing_kernel_config, **updating_kernel_config}, + ) + async with tracker.request_lock: + tracker.done_event.clear() + await self.inject_container_lifecycle_event( + kernel_id, + LifecycleEvent.DESTROY, + 'restarting', + ) + try: + with timeout(60): + await tracker.destroy_event.wait() + except asyncio.TimeoutError: + log.warning('timeout detected while restarting kernel {0}!', + kernel_id) + self.restarting_kernels.pop(kernel_id, None) + await self.inject_container_lifecycle_event( + kernel_id, + LifecycleEvent.CLEAN, + 'restart-timeout', + ) + raise + else: + try: + await self.create_kernel( + creation_id, + session_id, + kernel_id, + kernel_config, + existing_cluster_info, + restarting=True) + self.restarting_kernels.pop(kernel_id, None) + except Exception: + # TODO: retry / cancel others? 
+ log.exception('restart_kernel(s:{}, k:{}): re-creation failure', + session_id, kernel_id) + tracker.done_event.set() + kernel_obj = self.kernel_registry[kernel_id] + return { + 'container_id': kernel_obj['container_id'], + 'repl_in_port': kernel_obj['repl_in_port'], + 'repl_out_port': kernel_obj['repl_out_port'], + 'stdin_port': kernel_obj['stdin_port'], + 'stdout_port': kernel_obj['stdout_port'], + 'service_ports': kernel_obj.service_ports, + } + + async def execute( + self, + kernel_id: KernelId, + run_id: Optional[str], + mode: Literal['query', 'batch', 'input', 'continue'], + text: str, + *, + opts: Mapping[str, Any], + api_version: int, + flush_timeout: float, + ): + # Wait for the kernel restarting if it's ongoing... + restart_tracker = self.restarting_kernels.get(kernel_id) + if restart_tracker is not None: + await restart_tracker.done_event.wait() + + await self.produce_event( + ExecutionStartedEvent(SessionId(kernel_id)), + ) + try: + kernel_obj = self.kernel_registry[kernel_id] + result = await kernel_obj.execute( + run_id, mode, text, + opts=opts, + flush_timeout=flush_timeout, + api_version=api_version) + except asyncio.CancelledError: + await self.produce_event( + ExecutionCancelledEvent(SessionId(kernel_id)), + ) + raise + except KeyError: + # This situation is handled in the lifecycle management subsystem. + raise RuntimeError(f'The container for kernel {kernel_id} is not found! ' + '(might be terminated--try it again)') from None + + if result['status'] in ('finished', 'exec-timeout'): + log.debug('_execute({0}) {1}', kernel_id, result['status']) + if result['status'] == 'finished': + await self.produce_event( + ExecutionFinishedEvent(SessionId(kernel_id)), + ) + elif result['status'] == 'exec-timeout': + await self.produce_event( + ExecutionTimeoutEvent(SessionId(kernel_id)), + ) + await self.inject_container_lifecycle_event( + kernel_id, + LifecycleEvent.DESTROY, + 'exec-timeout', + ) + return { + **result, + 'files': [], # kept for API backward-compatibility + } + + async def get_completions(self, kernel_id: KernelId, text: str, opts: dict): + return await self.kernel_registry[kernel_id].get_completions(text, opts) + + async def get_logs(self, kernel_id: KernelId): + return await self.kernel_registry[kernel_id].get_logs() + + async def interrupt_kernel(self, kernel_id: KernelId): + return await self.kernel_registry[kernel_id].interrupt_kernel() + + async def start_service(self, kernel_id: KernelId, service: str, opts: dict): + return await self.kernel_registry[kernel_id].start_service(service, opts) + + async def shutdown_service(self, kernel_id: KernelId, service: str): + try: + kernel_obj = self.kernel_registry[kernel_id] + if kernel_obj is not None: + await kernel_obj.shutdown_service(service) + except Exception: + log.exception('unhandled exception while shutting down service app ${}', service) + + async def accept_file(self, kernel_id: KernelId, filename: str, filedata): + return await self.kernel_registry[kernel_id].accept_file(filename, filedata) + + async def download_file(self, kernel_id: KernelId, filepath: str): + return await self.kernel_registry[kernel_id].download_file(filepath) + + async def list_files(self, kernel_id: KernelId, path: str): + return await self.kernel_registry[kernel_id].list_files(path) diff --git a/src/ai/backend/agent/cli.py b/src/ai/backend/agent/cli.py new file mode 100644 index 0000000000..10e3a60de6 --- /dev/null +++ b/src/ai/backend/agent/cli.py @@ -0,0 +1,7 @@ +import click + + +@click.group() +def main(): + '''The root 
entrypoint for unified CLI of agent''' + pass diff --git a/src/ai/backend/agent/config.py b/src/ai/backend/agent/config.py new file mode 100644 index 0000000000..33b1806658 --- /dev/null +++ b/src/ai/backend/agent/config.py @@ -0,0 +1,102 @@ +import os + +import trafaret as t + +from ai.backend.common import config +from ai.backend.common import validators as tx + +from .stats import StatModes +from .types import AgentBackend + + +coredump_defaults = { + 'enabled': False, + 'path': './coredumps', + 'backup-count': 10, + 'size-limit': '64M', +} + +agent_local_config_iv = t.Dict({ + t.Key('agent'): t.Dict({ + tx.AliasedKey(['backend', 'mode']): tx.Enum(AgentBackend), + t.Key('rpc-listen-addr', default=('', 6001)): + tx.HostPortPair(allow_blank_host=True), + t.Key('agent-sock-port', default=6007): t.Int[1024:65535], + t.Key('id', default=None): t.Null | t.String, + t.Key('ipc-base-path', default="/tmp/backend.ai/ipc"): + tx.Path(type='dir', auto_create=True), + t.Key('region', default=None): t.Null | t.String, + t.Key('instance-type', default=None): t.Null | t.String, + t.Key('scaling-group', default='default'): t.String, + t.Key('pid-file', default=os.devnull): + tx.Path(type='file', allow_nonexisting=True, allow_devnull=True), + t.Key('event-loop', default='asyncio'): t.Enum('asyncio', 'uvloop'), + t.Key('skip-manager-detection', default=False): t.ToBool, + t.Key('aiomonitor-port', default=50102): t.Int[1:65535], + }).allow_extra('*'), + t.Key('container'): t.Dict({ + t.Key('kernel-uid', default=-1): tx.UserID, + t.Key('kernel-gid', default=-1): tx.GroupID, + t.Key('bind-host', default=''): t.String(allow_blank=True), + t.Key('advertised-host', default=None): t.Null | t.String(), + t.Key('port-range', default=(30000, 31000)): tx.PortRange, + t.Key('stats-type', default='docker'): + t.Null | t.Enum(*[e.value for e in StatModes]), + t.Key('sandbox-type', default='docker'): t.Enum('docker', 'jail'), + t.Key('jail-args', default=[]): t.List(t.String), + t.Key('scratch-type'): t.Enum('hostdir', 'memory', 'k8s-nfs'), + t.Key('scratch-root', default='./scratches'): + tx.Path(type='dir', auto_create=True), + t.Key('scratch-size', default='0'): tx.BinarySize, + t.Key('scratch-nfs-address', default=None): t.Null | t.String, + t.Key('scratch-nfs-options', default=None): t.Null | t.String, + }).allow_extra('*'), + t.Key('logging'): t.Any, # checked in ai.backend.common.logging + t.Key('resource'): t.Dict({ + t.Key('reserved-cpu', default=1): t.Int, + t.Key('reserved-mem', default="1G"): tx.BinarySize, + t.Key('reserved-disk', default="8G"): tx.BinarySize, + }).allow_extra('*'), + t.Key('debug'): t.Dict({ + t.Key('enabled', default=False): t.Bool, + t.Key('skip-container-deletion', default=False): t.Bool, + t.Key('log-stats', default=False): t.Bool, + t.Key('log-kernel-config', default=False): t.Bool, + t.Key('log-alloc-map', default=False): t.Bool, + t.Key('log-events', default=False): t.Bool, + t.Key('log-heartbeats', default=False): t.Bool, + t.Key('log-docker-events', default=False): t.Bool, + t.Key('coredump', default=coredump_defaults): t.Dict({ + t.Key('enabled', default=coredump_defaults['enabled']): t.Bool, + t.Key('path', default=coredump_defaults['path']): + tx.Path(type='dir', auto_create=True), + t.Key('backup-count', default=coredump_defaults['backup-count']): + t.Int[1:], + t.Key('size-limit', default=coredump_defaults['size-limit']): + tx.BinarySize, + }).allow_extra('*'), + }).allow_extra('*'), +}).merge(config.etcd_config_iv).allow_extra('*') + +docker_extra_config_iv = t.Dict({ + 
t.Key('container'): t.Dict({ + t.Key('swarm-enabled', default=False): t.Bool, + }).allow_extra('*'), +}).allow_extra('*') + +default_container_logs_config = { + 'max-length': '10M', # the maximum tail size + 'chunk-size': '64K', # used when storing logs to Redis as a side-channel to the event bus +} + +agent_etcd_config_iv = t.Dict({ + t.Key('container-logs', default=default_container_logs_config): t.Dict({ + t.Key('max-length', default=default_container_logs_config['max-length']): tx.BinarySize(), + t.Key('chunk-size', default=default_container_logs_config['chunk-size']): tx.BinarySize(), + }).allow_extra('*'), +}).allow_extra('*') + +container_etcd_config_iv = t.Dict({ + t.Key('kernel-uid', optional=True): t.ToInt, + t.Key('kernel-gid', optional=True): t.ToInt, +}).allow_extra('*') diff --git a/src/ai/backend/agent/docker/__init__.py b/src/ai/backend/agent/docker/__init__.py new file mode 100644 index 0000000000..2232d57a28 --- /dev/null +++ b/src/ai/backend/agent/docker/__init__.py @@ -0,0 +1,8 @@ +from typing import Type + +from ..agent import AbstractAgent +from .agent import DockerAgent + + +def get_agent_cls() -> Type[AbstractAgent]: + return DockerAgent diff --git a/src/ai/backend/agent/docker/agent.py b/src/ai/backend/agent/docker/agent.py new file mode 100644 index 0000000000..f0258eb428 --- /dev/null +++ b/src/ai/backend/agent/docker/agent.py @@ -0,0 +1,1367 @@ +from __future__ import annotations + +import asyncio +import base64 +from decimal import Decimal +from functools import partial +from io import StringIO +import json +import logging +import os +from pathlib import Path +from aiohttp import web +import pkg_resources +import secrets +import shutil +import signal +import struct +from subprocess import CalledProcessError +import sys +from typing import ( + Any, + FrozenSet, + Dict, + List, + Literal, + Mapping, + MutableMapping, + Optional, + Sequence, + Set, + Tuple, + TYPE_CHECKING, + Union, +) + +from aiodocker.docker import Docker, DockerContainer +from aiodocker.exceptions import DockerError +import aiotools +from async_timeout import timeout +import zmq + +from ai.backend.common.docker import ( + ImageRef, + MIN_KERNELSPEC, + MAX_KERNELSPEC, +) +from ai.backend.common.exception import ImageNotAvailable +from ai.backend.common.logging import BraceStyleAdapter, pretty +from ai.backend.common.plugin.monitor import ErrorPluginContext, StatsPluginContext +from ai.backend.common.types import ( + AutoPullBehavior, + BinarySize, + ClusterInfo, + ImageRegistry, + KernelCreationConfig, + KernelId, + ContainerId, + DeviceName, + SlotName, + MountPermission, + MountTypes, + ResourceSlot, + Sentinel, + current_resource_slots, +) +from ai.backend.common.utils import AsyncFileWriter, current_loop +from .kernel import DockerKernel, prepare_kernel_metadata_uri_handling +from .metadata.server import create_server as create_metadata_server +from .resources import detect_resources +from .utils import PersistentServiceContainer +from ..exception import UnsupportedResource, InitializationError +from ..fs import create_scratch_filesystem, destroy_scratch_filesystem +from ..kernel import AbstractKernel, KernelFeatures +from ..resources import ( + Mount, + KernelResourceSpec, +) +from ..agent import ( + AbstractAgent, + AbstractKernelCreationContext, + ACTIVE_STATUS_SET, + ComputerContext, +) +from ..proxy import proxy_connection, DomainSocketProxy +from ..resources import ( + AbstractComputePlugin, + known_slot_types, +) +from ..server import ( + get_extra_volumes, +) +from ..types import ( + 
Container, + Port, + ContainerStatus, + LifecycleEvent, +) +from ..utils import ( + closing_async, + update_nested_dict, + get_kernel_id_from_container, + host_pid_to_container_pid, + container_pid_to_host_pid, +) + +if TYPE_CHECKING: + from ai.backend.common.etcd import AsyncEtcd + +log = BraceStyleAdapter(logging.getLogger(__name__)) +eof_sentinel = Sentinel.TOKEN + + +def container_from_docker_container(src: DockerContainer) -> Container: + ports = [] + for private_port, host_ports in src['NetworkSettings']['Ports'].items(): + private_port = int(private_port.split('/')[0]) + if host_ports is None: + host_ip = '127.0.0.1' + host_port = 0 + else: + host_ip = host_ports[0]['HostIp'] + host_port = int(host_ports[0]['HostPort']) + ports.append(Port(host_ip, private_port, host_port)) + return Container( + id=src._id, + status=src['State']['Status'], + image=src['Config']['Image'], + labels=src['Config']['Labels'], + ports=ports, + backend_obj=src, + ) + + +def _DockerError_reduce(self): + return ( + type(self), + (self.status, {'message': self.message}, *self.args), + ) + + +def _DockerContainerError_reduce(self): + return ( + type(self), + (self.status, {'message': self.message}, self.container_id, *self.args), + ) + + +class DockerKernelCreationContext(AbstractKernelCreationContext[DockerKernel]): + + scratch_dir: Path + tmp_dir: Path + config_dir: Path + work_dir: Path + container_configs: List[Mapping[str, Any]] + domain_socket_proxies: List[DomainSocketProxy] + computer_docker_args: Dict[str, Any] + port_pool: Set[int] + agent_sockpath: Path + resource_lock: asyncio.Lock + + def __init__( + self, + kernel_id: KernelId, + kernel_config: KernelCreationConfig, + local_config: Mapping[str, Any], + computers: MutableMapping[str, ComputerContext], + port_pool: Set[int], + agent_sockpath: Path, + resource_lock: asyncio.Lock, + restarting: bool = False, + ) -> None: + super().__init__(kernel_id, kernel_config, local_config, computers, restarting=restarting) + scratch_dir = (self.local_config['container']['scratch-root'] / str(kernel_id)).resolve() + tmp_dir = (self.local_config['container']['scratch-root'] / f'{kernel_id}_tmp').resolve() + + self.scratch_dir = scratch_dir + self.tmp_dir = tmp_dir + self.config_dir = scratch_dir / 'config' + self.work_dir = scratch_dir / 'work' + + self.port_pool = port_pool + self.agent_sockpath = agent_sockpath + self.resource_lock = resource_lock + + self.container_configs = [] + self.domain_socket_proxies = [] + self.computer_docker_args = {} + + def _kernel_resource_spec_read(self, filename): + with open(filename, 'r') as f: + resource_spec = KernelResourceSpec.read_from_file(f) + return resource_spec + + async def get_extra_envs(self) -> Mapping[str, str]: + return {} + + async def prepare_resource_spec(self) -> Tuple[KernelResourceSpec, Optional[Mapping[str, Any]]]: + loop = current_loop() + if self.restarting: + resource_spec = await loop.run_in_executor( + None, + self._kernel_resource_spec_read, + self.config_dir / 'resource.txt') + resource_opts = None + else: + slots = ResourceSlot.from_json(self.kernel_config['resource_slots']) + # Ensure that we have intrinsic slots. + assert SlotName('cpu') in slots + assert SlotName('mem') in slots + # accept unknown slot type with zero values + # but reject if they have non-zero values. 
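                # e.g. a request of {"cpu": "2", "mem": "4g", "tpu.device": "0"} is
                # accepted on a node without a TPU plugin, but a non-zero "tpu.device"
                # raises UnsupportedResource (slot names and values here are illustrative).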
+ for st, sv in slots.items(): + if st not in known_slot_types and sv != Decimal(0): + raise UnsupportedResource(st) + # sanitize the slots + current_resource_slots.set(known_slot_types) + slots = slots.normalize_slots(ignore_unknown=True) + resource_spec = KernelResourceSpec( + container_id='', + allocations={}, + slots={**slots}, # copy + mounts=[], + scratch_disk_size=0, # TODO: implement (#70) + ) + resource_opts = self.kernel_config.get('resource_opts', {}) + return resource_spec, resource_opts + + async def prepare_scratch(self) -> None: + loop = current_loop() + + # Create the scratch, config, and work directories. + if ( + sys.platform.startswith('linux') + and self.local_config['container']['scratch-type'] == 'memory' + ): + await loop.run_in_executor(None, partial(self.tmp_dir.mkdir, exist_ok=True)) + await create_scratch_filesystem(self.scratch_dir, 64) + await create_scratch_filesystem(self.tmp_dir, 64) + else: + await loop.run_in_executor(None, partial(self.scratch_dir.mkdir, exist_ok=True)) + + def _create_scratch_dirs(): + self.config_dir.mkdir(parents=True, exist_ok=True) + self.config_dir.chmod(0o755) + self.work_dir.mkdir(parents=True, exist_ok=True) + self.work_dir.chmod(0o755) + + await loop.run_in_executor(None, _create_scratch_dirs) + + if not self.restarting: + # Since these files are bind-mounted inside a bind-mounted directory, + # we need to touch them first to avoid their "ghost" files are created + # as root in the host-side filesystem, which prevents deletion of scratch + # directories when the agent is running as non-root. + def _clone_dotfiles(): + jupyter_custom_css_path = Path(pkg_resources.resource_filename( + 'ai.backend.runner', 'jupyter-custom.css')) + logo_path = Path(pkg_resources.resource_filename( + 'ai.backend.runner', 'logo.svg')) + font_path = Path(pkg_resources.resource_filename( + 'ai.backend.runner', 'roboto.ttf')) + font_italic_path = Path(pkg_resources.resource_filename( + 'ai.backend.runner', 'roboto-italic.ttf')) + bashrc_path = Path(pkg_resources.resource_filename( + 'ai.backend.runner', '.bashrc')) + bash_profile_path = Path(pkg_resources.resource_filename( + 'ai.backend.runner', '.bash_profile')) + vimrc_path = Path(pkg_resources.resource_filename( + 'ai.backend.runner', '.vimrc')) + tmux_conf_path = Path(pkg_resources.resource_filename( + 'ai.backend.runner', '.tmux.conf')) + jupyter_custom_dir = (self.work_dir / '.jupyter' / 'custom') + jupyter_custom_dir.mkdir(parents=True, exist_ok=True) + shutil.copy(jupyter_custom_css_path.resolve(), jupyter_custom_dir / 'custom.css') + shutil.copy(logo_path.resolve(), jupyter_custom_dir / 'logo.svg') + shutil.copy(font_path.resolve(), jupyter_custom_dir / 'roboto.ttf') + shutil.copy(font_italic_path.resolve(), jupyter_custom_dir / 'roboto-italic.ttf') + shutil.copy(bashrc_path.resolve(), self.work_dir / '.bashrc') + shutil.copy(bash_profile_path.resolve(), self.work_dir / '.bash_profile') + shutil.copy(vimrc_path.resolve(), self.work_dir / '.vimrc') + shutil.copy(tmux_conf_path.resolve(), self.work_dir / '.tmux.conf') + if KernelFeatures.UID_MATCH in self.kernel_features: + uid = self.local_config['container']['kernel-uid'] + gid = self.local_config['container']['kernel-gid'] + if os.geteuid() == 0: # only possible when I am root. 
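                    # Hand the copied dotfiles over to the kernel user so that
                    # UID-matched containers can read and modify them.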
+ os.chown(self.work_dir, uid, gid) + os.chown(self.work_dir / '.jupyter', uid, gid) + os.chown(self.work_dir / '.jupyter' / 'custom', uid, gid) + os.chown(self.work_dir / '.bashrc', uid, gid) + os.chown(self.work_dir / '.bash_profile', uid, gid) + os.chown(self.work_dir / '.vimrc', uid, gid) + os.chown(self.work_dir / '.tmux.conf', uid, gid) + + await loop.run_in_executor(None, _clone_dotfiles) + + async def get_intrinsic_mounts(self) -> Sequence[Mount]: + loop = current_loop() + + # scratch/config/tmp mounts + mounts: List[Mount] = [ + Mount(MountTypes.BIND, self.config_dir, Path("/home/config"), + MountPermission.READ_ONLY), + Mount(MountTypes.BIND, self.work_dir, Path("/home/work"), + MountPermission.READ_WRITE), + ] + if (sys.platform.startswith("linux") and + self.local_config["container"]["scratch-type"] == "memory"): + mounts.append(Mount( + MountTypes.BIND, + self.tmp_dir, + Path("/tmp"), + MountPermission.READ_WRITE, + )) + + # lxcfs mounts + lxcfs_root = Path("/var/lib/lxcfs") + if lxcfs_root.is_dir(): + mounts.extend( + Mount( + MountTypes.BIND, + lxcfs_proc_path, + "/" / lxcfs_proc_path.relative_to(lxcfs_root), + MountPermission.READ_WRITE, + ) + for lxcfs_proc_path in (lxcfs_root / "proc").iterdir() + ) + mounts.extend( + Mount( + MountTypes.BIND, + lxcfs_root / path, + "/" / Path(path), + MountPermission.READ_WRITE, + ) + for path in [ + "sys/devices/system/cpu", + "sys/devices/system/cpu/online", + ] + if Path(lxcfs_root / path).exists() + ) + + # extra mounts + async with closing_async(Docker()) as docker: + extra_mount_list = await get_extra_volumes(docker, self.image_ref.short) + mounts.extend(Mount(MountTypes.VOLUME, v.name, v.container_path, v.mode) + for v in extra_mount_list) + + # debug mounts + if self.local_config['debug']['coredump']['enabled']: + mounts.append(Mount( + MountTypes.BIND, + self.local_config['debug']['coredump']['path'], + self.local_config['debug']['coredump']['core_path'], + MountPermission.READ_WRITE, + )) + + # agent-socket mount + mounts.append(Mount( + MountTypes.BIND, + self.agent_sockpath, + Path('/opt/kernel/agent.sock'), + MountPermission.READ_WRITE, + )) + ipc_base_path = self.local_config['agent']['ipc-base-path'] + + # domain-socket proxy mount + # (used for special service containers such image importer) + for host_sock_path in self.internal_data.get('domain_socket_proxies', []): + await loop.run_in_executor( + None, + partial((ipc_base_path / 'proxy').mkdir, parents=True, exist_ok=True)) + host_proxy_path = ipc_base_path / 'proxy' / f'{secrets.token_hex(12)}.sock' + proxy_server = await asyncio.start_unix_server( + aiotools.apartial(proxy_connection, host_sock_path), + str(host_proxy_path)) + await loop.run_in_executor(None, host_proxy_path.chmod, 0o666) + self.domain_socket_proxies.append(DomainSocketProxy( + Path(host_sock_path), + host_proxy_path, + proxy_server, + )) + mounts.append(Mount( + MountTypes.BIND, + host_proxy_path, + host_sock_path, + MountPermission.READ_WRITE, + )) + + return mounts + + def resolve_krunner_filepath(self, filename) -> Path: + return Path(pkg_resources.resource_filename( + 'ai.backend.runner', '../' + filename, + )).resolve() + + def get_runner_mount( + self, + type: MountTypes, + src: Union[str, Path], + target: Union[str, Path], + perm: Literal['ro', 'rw'] = 'ro', + opts: Mapping[str, Any] = None, + ) -> Mount: + return Mount( + type, Path(src), Path(target), + MountPermission(perm), + opts=opts, + ) + + async def apply_network(self, cluster_info: ClusterInfo) -> None: + if 
cluster_info['network_name'] is not None: + self.container_configs.append({ + 'HostConfig': { + 'NetworkMode': cluster_info['network_name'], + }, + 'NetworkingConfig': { + 'EndpointsConfig': { + cluster_info['network_name']: { + 'Aliases': [self.kernel_config['cluster_hostname']], + }, + }, + }, + }) + # RDMA mounts + ib_root = Path("/dev/infiniband") + if ib_root.is_dir() and (ib_root / "uverbs0").exists(): + self.container_configs.append({ + 'HostConfig': { + 'Devices': [ + { + 'PathOnHost': "/dev/infiniband", + 'PathInContainer': "/dev/infiniband", + 'CgroupPermissions': "rwm", + }, + ], + }, + }) + + async def install_ssh_keypair(self, cluster_info: ClusterInfo) -> None: + sshkey = cluster_info['ssh_keypair'] + if sshkey is None: + return + + def _write_keypair(): + try: + priv_key_path = (self.config_dir / 'ssh' / 'id_cluster') + pub_key_path = (self.config_dir / 'ssh' / 'id_cluster.pub') + priv_key_path.parent.mkdir(parents=True, exist_ok=True) + priv_key_path.write_text(sshkey['private_key']) + pub_key_path.write_text(sshkey['public_key']) + if KernelFeatures.UID_MATCH in self.kernel_features: + uid = self.local_config['container']['kernel-uid'] + gid = self.local_config['container']['kernel-gid'] + if os.geteuid() == 0: # only possible when I am root. + os.chown(str(priv_key_path), uid, gid) + os.chown(str(pub_key_path), uid, gid) + priv_key_path.chmod(0o600) + except Exception: + log.exception('error while writing cluster keypair') + + current_loop().run_in_executor(None, _write_keypair) # ??? + + async def process_mounts(self, mounts: Sequence[Mount]): + def fix_unsupported_perm(folder_perm: MountPermission) -> MountPermission: + if folder_perm == MountPermission.RW_DELETE: + # TODO: enforce readable/writable but not deletable + # (Currently docker's READ_WRITE includes DELETE) + return MountPermission.READ_WRITE + return folder_perm + + container_config = { + 'HostConfig': { + 'Mounts': [ + { + 'Target': str(mount.target), + 'Source': str(mount.source), + 'Type': mount.type.value, + 'ReadOnly': fix_unsupported_perm(mount.permission) == MountPermission.READ_ONLY, + f'{mount.type.value.capitalize()}Options': + mount.opts if mount.opts else {}, + } + for mount in mounts + ], + }, + } + self.container_configs.append(container_config) + + async def apply_accelerator_allocation(self, computer, device_alloc) -> None: + async with closing_async(Docker()) as docker: + update_nested_dict( + self.computer_docker_args, + await computer.generate_docker_args(docker, device_alloc), + ) + + async def spawn( + self, + resource_spec: KernelResourceSpec, + environ: Mapping[str, str], + service_ports, + ) -> DockerKernel: + loop = current_loop() + + if self.restarting: + pass + else: + # Create bootstrap.sh into workdir if needed + if bootstrap := self.kernel_config.get('bootstrap_script'): + + def _write_user_bootstrap_script(): + (self.work_dir / 'bootstrap.sh').write_text(bootstrap) + if KernelFeatures.UID_MATCH in self.kernel_features: + uid = self.local_config['container']['kernel-uid'] + gid = self.local_config['container']['kernel-gid'] + if os.geteuid() == 0: + os.chown(self.work_dir / 'bootstrap.sh', uid, gid) + + await loop.run_in_executor(None, _write_user_bootstrap_script) + + with StringIO() as buf: + for k, v in environ.items(): + buf.write(f'{k}={v}\n') + accel_envs = self.computer_docker_args.get('Env', []) + for env in accel_envs: + buf.write(f'{env}\n') + await loop.run_in_executor( + None, + (self.config_dir / 'environ.txt').write_bytes, + buf.getvalue().encode('utf8'), + ) + + 
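            # environ.txt is written as plain KEY=value lines, e.g.
            #   LOCAL_USER_ID=1000
            #   LD_PRELOAD=/opt/kernel/libbaihook.so
            # (values illustrative); resource.txt below is produced the same way from
            # the frozen resource spec plus each plugin's generated key-value pairs.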
with StringIO() as buf: + resource_spec.write_to_file(buf) + for dev_type, device_alloc in resource_spec.allocations.items(): + computer_self = self.computers[dev_type] + kvpairs = \ + await computer_self.instance.generate_resource_data(device_alloc) + for k, v in kvpairs.items(): + buf.write(f'{k}={v}\n') + await loop.run_in_executor( + None, + (self.config_dir / 'resource.txt').write_bytes, + buf.getvalue().encode('utf8'), + ) + + docker_creds = self.internal_data.get('docker_credentials') + if docker_creds: + await loop.run_in_executor( + None, + (self.config_dir / 'docker-creds.json').write_text, + json.dumps(docker_creds)) + + # TODO: refactor out dotfiles/sshkey initialization to the base agent? + + shutil.copyfile(self.config_dir / 'environ.txt', self.config_dir / 'environ_base.txt') + shutil.copyfile(self.config_dir / 'resource.txt', self.config_dir / 'resource_base.txt') + # Create SSH keypair only if ssh_keypair internal_data exists and + # /home/work/.ssh folder is not mounted. + if self.internal_data.get('ssh_keypair'): + for mount in resource_spec.mounts: + container_path = str(mount).split(':')[1] + if container_path == '/home/work/.ssh': + break + else: + pubkey = self.internal_data['ssh_keypair']['public_key'].encode('ascii') + privkey = self.internal_data['ssh_keypair']['private_key'].encode('ascii') + ssh_dir = self.work_dir / '.ssh' + + def _populate_ssh_config(): + ssh_dir.mkdir(parents=True, exist_ok=True) + ssh_dir.chmod(0o700) + (ssh_dir / 'authorized_keys').write_bytes(pubkey) + (ssh_dir / 'authorized_keys').chmod(0o600) + (self.work_dir / 'id_container').write_bytes(privkey) + (self.work_dir / 'id_container').chmod(0o600) + if KernelFeatures.UID_MATCH in self.kernel_features: + uid = self.local_config['container']['kernel-uid'] + gid = self.local_config['container']['kernel-gid'] + if os.geteuid() == 0: # only possible when I am root. + os.chown(ssh_dir, uid, gid) + os.chown(ssh_dir / 'authorized_keys', uid, gid) + os.chown(self.work_dir / 'id_container', uid, gid) + + await loop.run_in_executor(None, _populate_ssh_config) + + # higher priority dotfiles are stored last to support overwriting + for dotfile in self.internal_data.get('dotfiles', []): + if dotfile['path'].startswith('/'): + if dotfile['path'].startswith('/home/'): + path_arr = dotfile['path'].split('/') + file_path: Path = self.scratch_dir / '/'.join(path_arr[2:]) + else: + file_path = Path(dotfile['path']) + else: + file_path = self.work_dir / dotfile['path'] + file_path.parent.mkdir(parents=True, exist_ok=True) + await loop.run_in_executor( + None, + file_path.write_text, + dotfile['data']) + + tmp = Path(file_path) + while tmp != self.work_dir: + tmp.chmod(int(dotfile['perm'], 8)) + # only possible when I am root. 
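+                    # This walk applies the fix-up to every ancestor directory between
+                    # the dotfile and the work directory, so intermediate directories
+                    # created for nested dotfile paths stay accessible to the kernel user.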
+ if KernelFeatures.UID_MATCH in self.kernel_features and os.geteuid() == 0: + uid = self.local_config['container']['kernel-uid'] + gid = self.local_config['container']['kernel-gid'] + os.chown(tmp, uid, gid) + tmp = tmp.parent + + kernel_obj = DockerKernel( + self.kernel_id, + self.image_ref, + self.kspec_version, + agent_config=self.local_config, + service_ports=service_ports, + resource_spec=resource_spec, + environ=environ, + data={}, + ) + return kernel_obj + + async def start_container( + self, + kernel_obj: AbstractKernel, + cmdargs: List[str], + resource_opts, + preopen_ports, + ) -> Mapping[str, Any]: + loop = current_loop() + resource_spec = kernel_obj.resource_spec + service_ports = kernel_obj.service_ports + environ = kernel_obj.environ + image_labels = self.kernel_config['image']['labels'] + + # PHASE 4: Run! + container_bind_host = self.local_config['container']['bind-host'] + advertised_kernel_host = self.local_config['container'].get('advertised-host') + exposed_ports = [2000, 2001] + for sport in service_ports: + exposed_ports.extend(sport['container_ports']) + if len(exposed_ports) > len(self.port_pool): + raise RuntimeError('Container ports are not sufficiently available.') + host_ports = [] + for eport in exposed_ports: + hport = self.port_pool.pop() + host_ports.append(hport) + + container_log_size = self.local_config['agent']['container-logs']['max-length'] + container_log_file_count = 5 + container_log_file_size = BinarySize(container_log_size // container_log_file_count) + container_config: MutableMapping[str, Any] = { + 'Image': self.image_ref.canonical, + 'Tty': True, + 'OpenStdin': True, + 'Privileged': False, + 'StopSignal': 'SIGINT', + 'ExposedPorts': { + f'{port}/tcp': {} for port in exposed_ports + }, + 'EntryPoint': ["/opt/kernel/entrypoint.sh"], + 'Cmd': cmdargs, + 'Env': [f'{k}={v}' for k, v in environ.items()], + 'WorkingDir': "/home/work", + 'Hostname': self.kernel_config['cluster_hostname'], + 'Labels': { + 'ai.backend.kernel-id': str(self.kernel_id), + 'ai.backend.internal.block-service-ports': + '1' if self.internal_data.get('block_service_ports', False) else '0', + }, + 'HostConfig': { + 'Init': True, + 'PortBindings': { + f'{eport}/tcp': [{'HostPort': str(hport), + 'HostIp': str(container_bind_host)}] + for eport, hport in zip(exposed_ports, host_ports) + }, + 'PublishAllPorts': False, # we manage port mapping manually! 
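+                # Each exposed container port is bound explicitly to a host port taken
+                # from self.port_pool (see the PortBindings map above), so Docker's
+                # automatic port publishing is left disabled.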
+ 'CapAdd': [ + 'IPC_LOCK', # for hugepages and RDMA + ], + 'Ulimits': [ + {"Name": "nofile", "Soft": 1048576, "Hard": 1048576}, + {"Name": "memlock", "Soft": -1, "Hard": -1}, + ], + 'LogConfig': { + 'Type': 'local', # for efficient docker-specific storage + 'Config': { + # these fields must be str + # (ref: https://docs.docker.com/config/containers/logging/local/) + 'max-size': f"{container_log_file_size:s}", + 'max-file': str(container_log_file_count), + 'compress': 'false', + }, + }, + }, + } + # merge all container configs generated during prior preparation steps + for c in self.container_configs: + update_nested_dict(container_config, c) + if self.local_config['container']['sandbox-type'] == 'jail': + update_nested_dict(container_config, { + 'HostConfig': { + 'SecurityOpt': ['seccomp=unconfined', 'apparmor=unconfined'], + }, + }) + + if resource_opts and resource_opts.get('shmem'): + shmem = int(resource_opts.get('shmem', '0')) + self.computer_docker_args['HostConfig']['ShmSize'] = shmem + self.computer_docker_args['HostConfig']['MemorySwap'] -= shmem + self.computer_docker_args['HostConfig']['Memory'] -= shmem + + encoded_preopen_ports = ','.join(f'{port_no}:preopen:{port_no}' for port_no in preopen_ports) + container_config['Labels']['ai.backend.service-ports'] = \ + image_labels['ai.backend.service-ports'] + ',' + encoded_preopen_ports + update_nested_dict(container_config, self.computer_docker_args) + kernel_name = f"kernel.{self.image_ref.name.split('/')[-1]}.{self.kernel_id}" + if self.local_config['debug']['log-kernel-config']: + log.debug('full container config: {!r}', pretty(container_config)) + + # optional local override of docker config + extra_container_opts_name = 'agent-docker-container-opts.json' + for extra_container_opts_file in [ + Path('/etc/backend.ai') / extra_container_opts_name, + Path.home() / '.config' / 'backend.ai' / extra_container_opts_name, + Path.cwd() / extra_container_opts_name, + ]: + if extra_container_opts_file.is_file(): + try: + extra_container_opts = json.loads(extra_container_opts_file.read_bytes()) + update_nested_dict(container_config, extra_container_opts) + except IOError: + pass + + # We are all set! Create and start the container. + async with closing_async(Docker()) as docker: + try: + container = await docker.containers.create( + config=container_config, name=kernel_name) + cid = container._id + + resource_spec.container_id = cid + # Write resource.txt again to update the contaienr id. + with open(self.config_dir / 'resource.txt', 'w') as f: + await loop.run_in_executor(None, resource_spec.write_to_file, f) + async with AsyncFileWriter( + target_filename=self.config_dir / 'resource.txt', + access_mode='a', + ) as writer: + for dev_name, device_alloc in resource_spec.allocations.items(): + computer_ctx = self.computers[dev_name] + kvpairs = \ + await computer_ctx.instance.generate_resource_data(device_alloc) + for k, v in kvpairs.items(): + await writer.write(f'{k}={v}\n') + + await container.start() + except asyncio.CancelledError: + raise + except Exception: + # Oops, we have to restore the allocated resources! 
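+                # i.e. remove the scratch/tmp directories, return the reserved host
+                # ports to self.port_pool, and free the per-device allocations under
+                # the resource lock before re-raising.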
+ if (sys.platform.startswith('linux') and + self.local_config['container']['scratch-type'] == 'memory'): + await destroy_scratch_filesystem(self.scratch_dir) + await destroy_scratch_filesystem(self.tmp_dir) + await loop.run_in_executor(None, shutil.rmtree, self.tmp_dir) + await loop.run_in_executor(None, shutil.rmtree, self.scratch_dir) + self.port_pool.update(host_ports) + async with self.resource_lock: + for dev_name, device_alloc in resource_spec.allocations.items(): + self.computers[dev_name].alloc_map.free(device_alloc) + raise + + ctnr_host_port_map: MutableMapping[int, int] = {} + stdin_port = 0 + stdout_port = 0 + for idx, port in enumerate(exposed_ports): + host_port = int((await container.port(port))[0]['HostPort']) + assert host_port == host_ports[idx] + if port == 2000: # intrinsic + repl_in_port = host_port + elif port == 2001: # intrinsic + repl_out_port = host_port + elif port == 2002: # legacy + stdin_port = host_port + elif port == 2003: # legacy + stdout_port = host_port + else: + ctnr_host_port_map[port] = host_port + for sport in service_ports: + sport['host_ports'] = tuple( + ctnr_host_port_map[cport] for cport in sport['container_ports'] + ) + + return { + 'container_id': container._id, + 'kernel_host': advertised_kernel_host or container_bind_host, + 'repl_in_port': repl_in_port, + 'repl_out_port': repl_out_port, + 'stdin_port': stdin_port, # legacy + 'stdout_port': stdout_port, # legacy + 'host_ports': host_ports, + 'domain_socket_proxies': self.domain_socket_proxies, + 'block_service_ports': self.internal_data.get('block_service_ports', False), + } + + +class DockerAgent(AbstractAgent[DockerKernel, DockerKernelCreationContext]): + + monitor_docker_task: asyncio.Task + agent_sockpath: Path + agent_sock_task: asyncio.Task + scan_images_timer: asyncio.Task + metadata_server_runner: web.AppRunner + docker_ptask_group: aiotools.PersistentTaskGroup + + def __init__( + self, + etcd: AsyncEtcd, + local_config: Mapping[str, Any], + *, + stats_monitor: StatsPluginContext, + error_monitor: ErrorPluginContext, + skip_initial_scan: bool = False, + ) -> None: + super().__init__( + etcd, + local_config, + stats_monitor=stats_monitor, + error_monitor=error_monitor, + skip_initial_scan=skip_initial_scan, + ) + + async def __ainit__(self) -> None: + async with closing_async(Docker()) as docker: + if not self._skip_initial_scan: + docker_version = await docker.version() + log.info('running with Docker {0} with API {1}', + docker_version['Version'], docker_version['ApiVersion']) + await super().__ainit__() + await self.check_swarm_status() + if self.heartbeat_extra_info['swarm_enabled']: + log.info('The Docker Swarm cluster is configured and enabled') + ipc_base_path = self.local_config['agent']['ipc-base-path'] + (ipc_base_path / 'container').mkdir(parents=True, exist_ok=True) + self.agent_sockpath = ipc_base_path / 'container' / f'agent.{self.local_instance_id}.sock' + socket_relay_name = f"backendai-socket-relay.{self.local_instance_id}" + socket_relay_container = PersistentServiceContainer( + 'backendai-socket-relay:latest', + { + 'Cmd': [ + f"UNIX-LISTEN:/ipc/{self.agent_sockpath.name},unlink-early,fork,mode=777", + f"TCP-CONNECT:127.0.0.1:{self.local_config['agent']['agent-sock-port']}", + ], + 'HostConfig': { + 'Mounts': [ + { + 'Type': 'bind', + 'Source': str(ipc_base_path / 'container'), + 'Target': '/ipc', + }, + ], + 'NetworkMode': 'host', + }, + }, + name=socket_relay_name, + ) + await socket_relay_container.ensure_running_latest() + self.agent_sock_task = 
asyncio.create_task(self.handle_agent_socket()) + self.monitor_docker_task = asyncio.create_task(self.monitor_docker_events()) + self.monitor_swarm_task = asyncio.create_task(self.check_swarm_status(as_task=True)) + self.docker_ptask_group = aiotools.PersistentTaskGroup() + + await prepare_kernel_metadata_uri_handling(self.local_config) + metadata_server_runner = web.AppRunner( + await create_metadata_server(self.local_config, self.kernel_registry), + ) + await metadata_server_runner.setup() + site = web.TCPSite(metadata_server_runner, '0.0.0.0', 40128) + await site.start() + self.metadata_server_runner = metadata_server_runner + # For legacy accelerator plugins + self.docker = Docker() + + async def shutdown(self, stop_signal: signal.Signals): + # Stop handling agent sock. + if self.agent_sock_task is not None: + self.agent_sock_task.cancel() + await self.agent_sock_task + if self.docker_ptask_group is not None: + await self.docker_ptask_group.shutdown() + + try: + await super().shutdown(stop_signal) + finally: + # Stop docker event monitoring. + if self.monitor_docker_task is not None: + self.monitor_docker_task.cancel() + await self.monitor_docker_task + + if self.monitor_swarm_task is not None: + self.monitor_swarm_task.cancel() + await self.monitor_swarm_task + + await self.metadata_server_runner.cleanup() + if self.docker: + await self.docker.close() + + async def detect_resources(self) -> Tuple[ + Mapping[DeviceName, AbstractComputePlugin], + Mapping[SlotName, Decimal], + ]: + return await detect_resources(self.etcd, self.local_config) + + async def enumerate_containers( + self, + status_filter: FrozenSet[ContainerStatus] = ACTIVE_STATUS_SET, + ) -> Sequence[Tuple[KernelId, Container]]: + result = [] + fetch_tasks = [] + async with closing_async(Docker()) as docker: + for container in (await docker.containers.list()): + + async def _fetch_container_info(container): + kernel_id = "(unknown)" + try: + kernel_id = await get_kernel_id_from_container(container) + if kernel_id is None: + return + if container['State']['Status'] in status_filter: + await container.show() + result.append( + ( + kernel_id, + container_from_docker_container(container), + ), + ) + except asyncio.CancelledError: + pass + except Exception: + log.exception( + "error while fetching container information (cid:{}, k:{})", + container._id, kernel_id, + ) + + fetch_tasks.append(_fetch_container_info(container)) + + await asyncio.gather(*fetch_tasks, return_exceptions=True) + return result + + async def check_swarm_status(self, as_task=False): + try: + while True: + if as_task: + await asyncio.sleep(30) + try: + swarm_enabled = self.local_config['container'].get('swarm-enabled', False) + if not swarm_enabled: + continue + async with closing_async(Docker()) as docker: + docker_info = await docker.system.info() + if docker_info['Swarm']['LocalNodeState'] == 'inactive': + raise InitializationError( + "The swarm mode is enabled but the node state of " + "the local Docker daemon is inactive.", + ) + except InitializationError as e: + log.exception(str(e)) + swarm_enabled = False + finally: + self.heartbeat_extra_info = { + 'swarm_enabled': swarm_enabled, + } + if not as_task: + return + except asyncio.CancelledError: + pass + + async def scan_images(self) -> Mapping[str, str]: + async with closing_async(Docker()) as docker: + all_images = await docker.images.list() + updated_images = {} + for image in all_images: + if image['RepoTags'] is None: + continue + for repo_tag in image['RepoTags']: + if repo_tag.endswith(''): + 
continue + img_detail = await docker.images.inspect(repo_tag) + labels = img_detail['Config']['Labels'] + if labels is None or 'ai.backend.kernelspec' not in labels: + continue + kernelspec = int(labels['ai.backend.kernelspec']) + if MIN_KERNELSPEC <= kernelspec <= MAX_KERNELSPEC: + updated_images[repo_tag] = img_detail['Id'] + for added_image in (updated_images.keys() - self.images.keys()): + log.debug('found kernel image: {0}', added_image) + for removed_image in (self.images.keys() - updated_images.keys()): + log.debug('removed kernel image: {0}', removed_image) + return updated_images + + async def handle_agent_socket(self): + """ + A simple request-reply socket handler for in-container processes. + For ease of implementation in low-level languages such as C, + it uses a simple C-friendly ZeroMQ-based multipart messaging protocol. + + The agent listens on a local TCP port and there is a socat relay + that proxies this port via a UNIX domain socket mounted inside + actual containers. The reason for this is to avoid inode changes + upon agent restarts by keeping the relay container running persistently, + so that the mounted UNIX socket files don't get to refere a dangling pointer + when the agent is restarted. + + Request message: + The first part is the requested action as string, + The second part and later are arguments. + + Reply message: + The first part is a 32-bit integer (int in C) + (0: success) + (-1: generic unhandled error) + (-2: invalid action) + The second part and later are arguments. + + All strings are UTF-8 encoded. + """ + terminating = False + while True: + agent_sock = self.zmq_ctx.socket(zmq.REP) + try: + agent_sock.bind(f"tcp://127.0.0.1:{self.local_config['agent']['agent-sock-port']}") + while True: + msg = await agent_sock.recv_multipart() + if not msg: + break + try: + if msg[0] == b'host-pid-to-container-pid': + container_id = msg[1].decode() + host_pid = struct.unpack('i', msg[2])[0] + container_pid = await host_pid_to_container_pid( + container_id, host_pid) + reply = [ + struct.pack('i', 0), + struct.pack('i', container_pid), + ] + elif msg[0] == b'container-pid-to-host-pid': + container_id = msg[1].decode() + container_pid = struct.unpack('i', msg[2])[0] + host_pid = await container_pid_to_host_pid( + container_id, container_pid) + reply = [ + struct.pack('i', 0), + struct.pack('i', host_pid), + ] + else: + reply = [struct.pack('i', -2), b'Invalid action'] + except asyncio.CancelledError: + terminating = True + raise + except Exception as e: + log.exception("handle_agent_socket(): internal error") + reply = [struct.pack('i', -1), f'Error: {e}'.encode('utf-8')] + await agent_sock.send_multipart(reply) + except asyncio.CancelledError: + terminating = True + return + except zmq.ZMQError: + log.exception("handle_agent_socket(): zmq error") + finally: + agent_sock.close() + if not terminating: + log.info("handle_agent_socket(): rebinding the socket") + + async def pull_image(self, image_ref: ImageRef, registry_conf: ImageRegistry) -> None: + auth_config = None + reg_user = registry_conf.get('username') + reg_passwd = registry_conf.get('password') + if reg_user and reg_passwd: + encoded_creds = base64.b64encode( + f'{reg_user}:{reg_passwd}'.encode('utf-8')) \ + .decode('ascii') + auth_config = { + 'auth': encoded_creds, + } + log.info('pulling image {} from registry', image_ref.canonical) + async with closing_async(Docker()) as docker: + await docker.images.pull( + image_ref.canonical, + auth=auth_config) + + async def check_image(self, image_ref: ImageRef, 
image_id: str, auto_pull: AutoPullBehavior) -> bool: + try: + async with closing_async(Docker()) as docker: + image_info = await docker.images.inspect(image_ref.canonical) + if auto_pull == AutoPullBehavior.DIGEST: + if image_info['Id'] != image_id: + return True + log.info('found the local up-to-date image for {}', image_ref.canonical) + except DockerError as e: + if e.status == 404: + if auto_pull == AutoPullBehavior.DIGEST: + return True + elif auto_pull == AutoPullBehavior.TAG: + return True + elif auto_pull == AutoPullBehavior.NONE: + raise ImageNotAvailable(image_ref) + else: + raise + return False + + async def init_kernel_context( + self, + kernel_id: KernelId, + kernel_config: KernelCreationConfig, + *, + restarting: bool = False, + ) -> DockerKernelCreationContext: + return DockerKernelCreationContext( + kernel_id, + kernel_config, + self.local_config, + self.computers, + self.port_pool, + self.agent_sockpath, + self.resource_lock, + restarting=restarting, + ) + + async def restart_kernel__load_config( + self, + kernel_id: KernelId, + name: str, + ) -> bytes: + loop = current_loop() + scratch_dir = (self.local_config['container']['scratch-root'] / str(kernel_id)).resolve() + config_dir = scratch_dir / 'config' + return await loop.run_in_executor( + None, + (config_dir / name).read_bytes, + ) + + async def restart_kernel__store_config( + self, + kernel_id: KernelId, + name: str, + data: bytes, + ) -> None: + loop = current_loop() + scratch_dir = (self.local_config['container']['scratch-root'] / str(kernel_id)).resolve() + config_dir = scratch_dir / 'config' + return await loop.run_in_executor( + None, + (config_dir / name).write_bytes, + data, + ) + + async def destroy_kernel( + self, + kernel_id: KernelId, + container_id: Optional[ContainerId], + ) -> None: + if container_id is None: + return + try: + async with closing_async(Docker()) as docker: + container = docker.containers.container(container_id) + # The default timeout of the docker stop API is 10 seconds + # to kill if container does not self-terminate. 
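+                # (If 10 seconds turns out to be too short, a longer grace period could
+                # be requested through the stop API's `t` parameter.)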
+ await container.stop() + except DockerError as e: + if e.status == 409 and 'is not running' in e.message: + # already dead + log.warning('destroy_kernel(k:{0}) already dead', kernel_id) + await self.rescan_resource_usage() + elif e.status == 404: + # missing + log.warning('destroy_kernel(k:{0}) kernel missing, ' + 'forgetting this kernel', kernel_id) + await self.rescan_resource_usage() + else: + log.exception('destroy_kernel(k:{0}) kill error', kernel_id) + await self.error_monitor.capture_exception() + + async def clean_kernel( + self, + kernel_id: KernelId, + container_id: Optional[ContainerId], + restarting: bool, + ) -> None: + loop = current_loop() + async with closing_async(Docker()) as docker: + if container_id is not None: + container = docker.containers.container(container_id) + + async def log_iter(): + it = container.log( + stdout=True, stderr=True, follow=True, + ) + async with aiotools.aclosing(it): + async for line in it: + yield line.encode('utf-8') + + try: + with timeout(60): + await self.collect_logs(kernel_id, container_id, log_iter()) + except asyncio.TimeoutError: + log.warning('timeout for collecting container logs (k:{}, cid:{})', + kernel_id, container_id) + except Exception as e: + log.warning('error while collecting container logs (k:{}, cid:{})', + kernel_id, container_id, exc_info=e) + + kernel_obj = self.kernel_registry.get(kernel_id) + if kernel_obj is not None: + for domain_socket_proxy in kernel_obj.get('domain_socket_proxies', []): + if domain_socket_proxy.proxy_server.is_serving(): + domain_socket_proxy.proxy_server.close() + await domain_socket_proxy.proxy_server.wait_closed() + try: + domain_socket_proxy.host_proxy_path.unlink() + except IOError: + pass + + if not self.local_config['debug']['skip-container-deletion'] and container_id is not None: + container = docker.containers.container(container_id) + try: + with timeout(90): + await container.delete(force=True, v=True) + except DockerError as e: + if e.status == 409 and 'already in progress' in e.message: + return + elif e.status == 404: + return + else: + log.exception( + 'unexpected docker error while deleting container (k:{}, c:{})', + kernel_id, container_id) + except asyncio.TimeoutError: + log.warning('container deletion timeout (k:{}, c:{})', + kernel_id, container_id) + + if not restarting: + scratch_root = self.local_config['container']['scratch-root'] + scratch_dir = scratch_root / str(kernel_id) + tmp_dir = scratch_root / f'{kernel_id}_tmp' + try: + if (sys.platform.startswith('linux') and + self.local_config['container']['scratch-type'] == 'memory'): + await destroy_scratch_filesystem(scratch_dir) + await destroy_scratch_filesystem(tmp_dir) + await loop.run_in_executor(None, shutil.rmtree, tmp_dir) + await loop.run_in_executor(None, shutil.rmtree, scratch_dir) + except CalledProcessError: + pass + except FileNotFoundError: + pass + + async def create_overlay_network(self, network_name: str) -> None: + if not self.heartbeat_extra_info['swarm_enabled']: + raise RuntimeError("This agent has not joined to a swarm cluster.") + async with closing_async(Docker()) as docker: + await docker.networks.create({ + 'Name': network_name, + 'Driver': 'overlay', + 'Attachable': True, + 'Labels': { + 'ai.backend.cluster-network': '1', + }, + }) + + async def destroy_overlay_network(self, network_name: str) -> None: + docker = Docker() + try: + network = await docker.networks.get(network_name) + await network.delete() + finally: + await docker.close() + + async def create_local_network(self, network_name: 
str) -> None: + async with closing_async(Docker()) as docker: + await docker.networks.create({ + 'Name': network_name, + 'Driver': 'bridge', + 'Labels': { + 'ai.backend.cluster-network': '1', + }, + }) + + async def destroy_local_network(self, network_name: str) -> None: + async with closing_async(Docker()) as docker: + network = await docker.networks.get(network_name) + await network.delete() + + async def monitor_docker_events(self): + + async def handle_action_start(kernel_id: KernelId, evdata: Mapping[str, Any]) -> None: + await self.inject_container_lifecycle_event( + kernel_id, + LifecycleEvent.START, + 'new-container-started', + container_id=ContainerId(evdata['Actor']['ID']), + ) + + async def handle_action_die(kernel_id: KernelId, evdata: Mapping[str, Any]) -> None: + # When containers die, we immediately clean up them. + reason = None + kernel_obj = self.kernel_registry.get(kernel_id) + if kernel_obj is not None: + reason = kernel_obj.termination_reason + try: + exit_code = evdata['Actor']['Attributes']['exitCode'] + except KeyError: + exit_code = 255 + await self.inject_container_lifecycle_event( + kernel_id, + LifecycleEvent.CLEAN, + reason or 'self-terminated', + container_id=ContainerId(evdata['Actor']['ID']), + exit_code=exit_code, + ) + + while True: + async with closing_async(Docker()) as docker: + subscriber = docker.events.subscribe(create_task=True) + try: + while True: + try: + # ref: https://docs.docker.com/engine/api/v1.40/#operation/SystemEvents + evdata = await subscriber.get() + if evdata is None: + # Break out to the outermost loop when the connection is closed + log.info( + "monitor_docker_events(): " + "restarting aiodocker event subscriber", + ) + break + if evdata['Type'] != 'container': + # Our interest is the container-related events + continue + container_name = evdata['Actor']['Attributes']['name'] + kernel_id = await get_kernel_id_from_container(container_name) + if kernel_id is None: + continue + if ( + self.local_config['debug']['log-docker-events'] + and evdata['Action'] in ('start', 'die') + ): + log.debug('docker-event: action={}, actor={}', + evdata['Action'], evdata['Actor']) + if evdata['Action'] == 'start': + await asyncio.shield(self.docker_ptask_group.create_task( + handle_action_start(kernel_id, evdata), + )) + elif evdata['Action'] == 'die': + await asyncio.shield(self.docker_ptask_group.create_task( + handle_action_die(kernel_id, evdata), + )) + except asyncio.CancelledError: + # We are shutting down... 
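+                            # (the finally clause below still runs and stops the
+                            #  aiodocker event subscriber via docker_ptask_group)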
+ return + except Exception: + log.exception("monitor_docker_events(): unexpected error") + finally: + await asyncio.shield(self.docker_ptask_group.create_task( + docker.events.stop(), + )) diff --git a/src/ai/backend/agent/docker/backendai-socket-relay.img.aarch64.tar.gz b/src/ai/backend/agent/docker/backendai-socket-relay.img.aarch64.tar.gz new file mode 100644 index 0000000000..e335efcbf7 --- /dev/null +++ b/src/ai/backend/agent/docker/backendai-socket-relay.img.aarch64.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ae9cceb0f1b329c55ca0b46eedef73b2ecf15e074e5b3c5c3e802473da5b845e +size 3161130 diff --git a/src/ai/backend/agent/docker/backendai-socket-relay.img.x86_64.tar.gz b/src/ai/backend/agent/docker/backendai-socket-relay.img.x86_64.tar.gz new file mode 100644 index 0000000000..f86102fe7b --- /dev/null +++ b/src/ai/backend/agent/docker/backendai-socket-relay.img.x86_64.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:632296d4e118f2c09abad66bf51ed5e9383cc1fb35a6eb18f5d04f73c3ad437c +size 3254795 diff --git a/src/ai/backend/agent/docker/backendai-socket-relay.version.txt b/src/ai/backend/agent/docker/backendai-socket-relay.version.txt new file mode 100644 index 0000000000..d00491fd7e --- /dev/null +++ b/src/ai/backend/agent/docker/backendai-socket-relay.version.txt @@ -0,0 +1 @@ +1 diff --git a/src/ai/backend/agent/docker/files.py b/src/ai/backend/agent/docker/files.py new file mode 100644 index 0000000000..78659aa534 --- /dev/null +++ b/src/ai/backend/agent/docker/files.py @@ -0,0 +1,64 @@ +import logging +import os +from pathlib import Path +from typing import Dict + +from ai.backend.common.logging import BraceStyleAdapter + +log = BraceStyleAdapter(logging.getLogger(__name__)) + +# the names of following AWS variables follow boto3 convention. +s3_access_key = os.environ.get('AWS_ACCESS_KEY_ID', 'dummy-access-key') +s3_secret_key = os.environ.get('AWS_SECRET_ACCESS_KEY', 'dummy-secret-key') +s3_region = os.environ.get('AWS_REGION', 'ap-northeast-1') +s3_bucket = os.environ.get('AWS_S3_BUCKET', 'codeonweb') +s3_bucket_path = os.environ.get('AWS_S3_BUCKET_PATH', 'bucket') + +if s3_access_key == 'dummy-access-key': + log.info('Automatic ~/.output file S3 uploads is disabled.') + + +def relpath(path, base): + return Path(path).resolve().relative_to(Path(base).resolve()) + + +def scandir(root: Path, allowed_max_size: int): + ''' + Scans a directory recursively and returns a dictionary of all files and + their last modified time. + ''' + file_stats: Dict[Path, float] = dict() + if not isinstance(root, Path): + root = Path(root) + if not root.exists(): + return file_stats + for entry in os.scandir(root): + # Skip hidden files. + if entry.name.startswith('.'): + continue + if entry.is_file(): + try: + stat = entry.stat() + except PermissionError: + continue + # Skip too large files! 
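+            # (allowed_max_size is compared against st_size, i.e. it is a size in bytes)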
+ if stat.st_size > allowed_max_size: + continue + file_stats[Path(entry.path)] = stat.st_mtime + elif entry.is_dir(): + try: + file_stats.update(scandir(Path(entry.path), allowed_max_size)) + except PermissionError: + pass + return file_stats + + +def diff_file_stats(fs1, fs2): + k2 = set(fs2.keys()) + k1 = set(fs1.keys()) + new_files = k2 - k1 + modified_files = set() + for k in (k2 - new_files): + if fs1[k] < fs2[k]: + modified_files.add(k) + return new_files | modified_files diff --git a/src/ai/backend/agent/docker/intrinsic.py b/src/ai/backend/agent/docker/intrinsic.py new file mode 100644 index 0000000000..7039e2ea53 --- /dev/null +++ b/src/ai/backend/agent/docker/intrinsic.py @@ -0,0 +1,606 @@ +import asyncio +from decimal import Decimal +import logging +import os +from pathlib import Path +import platform +from typing import ( + cast, + Any, + Collection, + Dict, + List, + Mapping, + Optional, + Sequence, +) + +import aiohttp +from aiodocker.docker import Docker, DockerContainer +from aiodocker.exceptions import DockerError +import async_timeout +import psutil + +from ai.backend.common.logging import BraceStyleAdapter +from ai.backend.common.utils import current_loop, nmget +from ai.backend.common.types import ( + DeviceName, DeviceId, + DeviceModelInfo, + SlotName, SlotTypes, + MetricKey, +) +from .agent import Container +from .resources import ( + get_resource_spec_from_container, +) +from .. import __version__ +from ..resources import ( + AbstractAllocMap, DeviceSlotInfo, + DiscretePropertyAllocMap, + AbstractComputeDevice, + AbstractComputePlugin, +) +from ..stats import ( + StatContext, NodeMeasurement, ContainerMeasurement, + StatModes, MetricTypes, Measurement, +) +from ..utils import closing_async, read_sysfs +from ..vendor.linux import libnuma + +log = BraceStyleAdapter(logging.getLogger(__name__)) + + +async def fetch_api_stats(container: DockerContainer) -> Optional[Dict[str, Any]]: + short_cid = container._id[:7] + try: + ret = await container.stats(stream=False) # TODO: cache + except RuntimeError as e: + msg = str(e.args[0]).lower() + if 'event loop is closed' in msg or 'session is closed' in msg: + return None + raise + except (DockerError, aiohttp.ClientError) as e: + log.error( + 'cannot read stats (cid:{}): client error: {!r}.', + short_cid, e, + ) + return None + else: + # aiodocker 0.16 or later returns a list of dict, even when not streaming. + if isinstance(ret, list): + if not ret: + # The API may return an empty result upon container termination. + return None + ret = ret[0] + # The API may return an invalid or empty result upon container termination. + if ret is None or not isinstance(ret, dict): + log.warning( + 'cannot read stats (cid:{}): got an empty result: {}', + short_cid, ret, + ) + return None + if ( + ret['read'].startswith('0001-01-01') or + ret['preread'].startswith('0001-01-01') + ): + return None + return ret + + +# Pseudo-plugins for intrinsic devices (CPU and the main memory) + +class CPUDevice(AbstractComputeDevice): + pass + + +class CPUPlugin(AbstractComputePlugin): + """ + Represents the CPU. 
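+    Each logical core is exposed as one discrete "cpu" slot; the optional
+    BACKEND_CPU_OVERCOMMIT_FACTOR environment variable multiplies the processing
+    units reported per core.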
+ """ + + config_watch_enabled = False + + key = DeviceName('cpu') + slot_types = [ + (SlotName('cpu'), SlotTypes.COUNT), + ] + + async def init(self, context: Any = None) -> None: + pass + + async def cleanup(self) -> None: + pass + + async def update_plugin_config(self, new_plugin_config: Mapping[str, Any]) -> None: + pass + + async def list_devices(self) -> Collection[CPUDevice]: + cores = await libnuma.get_available_cores() + overcommit_factor = int(os.environ.get('BACKEND_CPU_OVERCOMMIT_FACTOR', '1')) + assert 1 <= overcommit_factor <= 10 + return [ + CPUDevice( + device_id=DeviceId(str(core_idx)), + hw_location='root', + numa_node=libnuma.node_of_cpu(core_idx), + memory_size=0, + processing_units=1 * overcommit_factor, + ) + for core_idx in sorted(cores) + ] + + async def available_slots(self) -> Mapping[SlotName, Decimal]: + devices = await self.list_devices() + return { + SlotName('cpu'): Decimal(sum(dev.processing_units for dev in devices)), + } + + def get_version(self) -> str: + return __version__ + + async def extra_info(self) -> Mapping[str, str]: + return { + 'agent_version': __version__, + 'machine': platform.machine(), + 'os_type': platform.system(), + } + + async def gather_node_measures(self, ctx: StatContext) -> Sequence[NodeMeasurement]: + _cstat = psutil.cpu_times(True) + q = Decimal('0.000') + total_cpu_used = cast(Decimal, + sum((Decimal(c.user + c.system) * 1000).quantize(q) for c in _cstat)) + now, raw_interval = ctx.update_timestamp('cpu-node') + interval = Decimal(raw_interval * 1000).quantize(q) + + return [ + NodeMeasurement( + MetricKey('cpu_util'), + MetricTypes.UTILIZATION, + unit_hint='msec', + current_hook=lambda metric: metric.stats.diff, + per_node=Measurement(total_cpu_used, interval), + per_device={ + DeviceId(str(idx)): + Measurement( + (Decimal(c.user + c.system) * 1000).quantize(q), + interval, + ) + for idx, c in enumerate(_cstat) + }, + ), + ] + + async def gather_container_measures( + self, + ctx: StatContext, + container_ids: Sequence[str], + ) -> Sequence[ContainerMeasurement]: + + async def sysfs_impl(container_id): + cpu_prefix = f'/sys/fs/cgroup/cpuacct/docker/{container_id}/' + try: + cpu_used = read_sysfs(cpu_prefix + 'cpuacct.usage', int) / 1e6 + except IOError as e: + log.warning('cannot read stats: sysfs unreadable for container {0}\n{1!r}', + container_id[:7], e) + return None + return cpu_used + + async def api_impl(container_id): + async with closing_async(Docker()) as docker: + container = DockerContainer(docker, id=container_id) + try: + async with async_timeout.timeout(2.0): + ret = await fetch_api_stats(container) + except asyncio.TimeoutError: + return None + if ret is None: + return None + cpu_used = nmget(ret, 'cpu_stats.cpu_usage.total_usage', 0) / 1e6 + return cpu_used + + if ctx.mode == StatModes.CGROUP: + impl = sysfs_impl + elif ctx.mode == StatModes.DOCKER: + impl = api_impl + else: + raise RuntimeError("should not reach here") + + q = Decimal('0.000') + per_container_cpu_used = {} + per_container_cpu_util = {} + tasks = [] + for cid in container_ids: + tasks.append(asyncio.ensure_future(impl(cid))) + results = await asyncio.gather(*tasks) + for cid, cpu_used in zip(container_ids, results): + if cpu_used is None: + continue + per_container_cpu_used[cid] = Measurement(Decimal(cpu_used).quantize(q)) + per_container_cpu_util[cid] = Measurement( + Decimal(cpu_used).quantize(q), + capacity=Decimal(1000), + ) + return [ + ContainerMeasurement( + MetricKey('cpu_util'), + MetricTypes.UTILIZATION, + unit_hint='percent', + 
current_hook=lambda metric: metric.stats.rate, + stats_filter=frozenset({'avg', 'max'}), + per_container=per_container_cpu_util, + ), + ContainerMeasurement( + MetricKey('cpu_used'), + MetricTypes.USAGE, + unit_hint='msec', + per_container=per_container_cpu_used, + ), + ] + + async def create_alloc_map(self) -> AbstractAllocMap: + devices = await self.list_devices() + return DiscretePropertyAllocMap( + device_slots={ + dev.device_id: + DeviceSlotInfo(SlotTypes.COUNT, SlotName('cpu'), Decimal(dev.processing_units)) + for dev in devices + }, + ) + + async def get_hooks(self, distro: str, arch: str) -> Sequence[Path]: + # TODO: move the sysconf hook in libbaihook.so here + return [] + + async def generate_docker_args( + self, + docker: Docker, + device_alloc, + ) -> Mapping[str, Any]: + cores = [*map(int, device_alloc['cpu'].keys())] + sorted_core_ids = [*map(str, sorted(cores))] + return { + 'HostConfig': { + 'CpuPeriod': 100_000, # docker default + 'CpuQuota': int(100_000 * len(cores)), + 'Cpus': ','.join(sorted_core_ids), + 'CpusetCpus': ','.join(sorted_core_ids), + # 'CpusetMems': f'{resource_spec.numa_node}', + }, + } + + async def restore_from_container( + self, + container: Container, + alloc_map: AbstractAllocMap, + ) -> None: + assert isinstance(alloc_map, DiscretePropertyAllocMap) + # Docker does not return the original cpuset.... :( + # We need to read our own records. + resource_spec = await get_resource_spec_from_container(container.backend_obj) + if resource_spec is None: + return + alloc_map.apply_allocation({ + SlotName('cpu'): + resource_spec.allocations[DeviceName('cpu')][SlotName('cpu')], + }) + + async def get_attached_devices( + self, + device_alloc: Mapping[SlotName, + Mapping[DeviceId, Decimal]], + ) -> Sequence[DeviceModelInfo]: + device_ids = [*device_alloc[SlotName('cpu')].keys()] + available_devices = await self.list_devices() + attached_devices: List[DeviceModelInfo] = [] + for device in available_devices: + if device.device_id in device_ids: + attached_devices.append({ + 'device_id': device.device_id, + 'model_name': '', + 'data': {'cores': len(device_ids)}, + }) + return attached_devices + + +class MemoryDevice(AbstractComputeDevice): + pass + + +class MemoryPlugin(AbstractComputePlugin): + """ + Represents the main memory. + + When collecting statistics, it also measures network and I/O usage + in addition to the memory usage. + """ + + config_watch_enabled = False + + key = DeviceName('mem') + slot_types = [ + (SlotName('mem'), SlotTypes.BYTES), + ] + + async def init(self, context: Any = None) -> None: + pass + + async def cleanup(self) -> None: + pass + + async def update_plugin_config(self, new_plugin_config: Mapping[str, Any]) -> None: + pass + + async def list_devices(self) -> Collection[MemoryDevice]: + # TODO: support NUMA? 
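+        # For now, all system memory is reported as a single "root" device whose
+        # size may be scaled up by the optional BACKEND_MEM_OVERCOMMIT_FACTOR
+        # environment variable.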
+ memory_size = psutil.virtual_memory().total + overcommit_factor = int(os.environ.get('BACKEND_MEM_OVERCOMMIT_FACTOR', '1')) + return [MemoryDevice( + device_id=DeviceId('root'), + hw_location='root', + numa_node=0, + memory_size=overcommit_factor * memory_size, + processing_units=0, + )] + + async def available_slots(self) -> Mapping[SlotName, Decimal]: + devices = await self.list_devices() + return { + SlotName('mem'): Decimal(sum(dev.memory_size for dev in devices)), + } + + def get_version(self) -> str: + return __version__ + + async def extra_info(self) -> Mapping[str, str]: + return {} + + async def gather_node_measures(self, ctx: StatContext) -> Sequence[NodeMeasurement]: + _mstat = psutil.virtual_memory() + total_mem_used_bytes = Decimal(_mstat.total - _mstat.available) + total_mem_capacity_bytes = Decimal(_mstat.total) + _nstat = psutil.net_io_counters() + net_rx_bytes = _nstat.bytes_recv + net_tx_bytes = _nstat.bytes_sent + + def get_disk_stat(): + pruned_disk_types = frozenset(['squashfs', 'vfat', 'tmpfs']) + total_disk_usage = Decimal(0) + total_disk_capacity = Decimal(0) + per_disk_stat = {} + for disk_info in psutil.disk_partitions(): + if disk_info.fstype not in pruned_disk_types: + dstat = os.statvfs(disk_info.mountpoint) + disk_usage = Decimal(dstat.f_frsize * (dstat.f_blocks - dstat.f_bavail)) + disk_capacity = Decimal(dstat.f_frsize * dstat.f_blocks) + per_disk_stat[disk_info.device] = Measurement(disk_usage, disk_capacity) + total_disk_usage += disk_usage + total_disk_capacity += disk_capacity + return total_disk_usage, total_disk_capacity, per_disk_stat + + loop = current_loop() + total_disk_usage, total_disk_capacity, per_disk_stat = \ + await loop.run_in_executor(None, get_disk_stat) + return [ + NodeMeasurement( + MetricKey('mem'), + MetricTypes.USAGE, + unit_hint='bytes', + stats_filter=frozenset({'max'}), + per_node=Measurement(total_mem_used_bytes, total_mem_capacity_bytes), + per_device={DeviceId('root'): + Measurement(total_mem_used_bytes, + total_mem_capacity_bytes)}, + ), + NodeMeasurement( + MetricKey('disk'), + MetricTypes.USAGE, + unit_hint='bytes', + per_node=Measurement(total_disk_usage, total_disk_capacity), + per_device=per_disk_stat, + ), + NodeMeasurement( + MetricKey('net_rx'), + MetricTypes.RATE, + unit_hint='bps', + current_hook=lambda metric: metric.stats.rate, + per_node=Measurement(Decimal(net_rx_bytes)), + per_device={DeviceId('node'): Measurement(Decimal(net_rx_bytes))}, + ), + NodeMeasurement( + MetricKey('net_tx'), + MetricTypes.RATE, + unit_hint='bps', + current_hook=lambda metric: metric.stats.rate, + per_node=Measurement(Decimal(net_tx_bytes)), + per_device={DeviceId('node'): Measurement(Decimal(net_tx_bytes))}, + ), + ] + + async def gather_container_measures(self, ctx: StatContext, container_ids: Sequence[str]) \ + -> Sequence[ContainerMeasurement]: + + def get_scratch_size(container_id: str) -> int: + # Temporarily disabled as this function incurs too much delay with + # a large number of files in scratch dirs, causing indefinite accumulation of + # stat collector tasks and slowing down everything. 
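+            # Returning 0 keeps the io_scratch_size metric reported, just always as zero.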
+ return 0 + # for kernel_id, info in ctx.agent.kernel_registry.items(): + # if info['container_id'] == container_id: + # break + # else: + # return 0 + # work_dir = ctx.agent.local_config['container']['scratch-root'] / str(kernel_id) / 'work' + # total_size = 0 + # for path in work_dir.rglob('*'): + # if path.is_symlink(): + # total_size += path.lstat().st_size + # elif path.is_file(): + # total_size += path.stat().st_size + # return total_size + + async def sysfs_impl(container_id): + mem_prefix = f'/sys/fs/cgroup/memory/docker/{container_id}/' + io_prefix = f'/sys/fs/cgroup/blkio/docker/{container_id}/' + try: + mem_cur_bytes = read_sysfs(mem_prefix + 'memory.usage_in_bytes', int) + io_stats = Path(io_prefix + 'blkio.throttle.io_service_bytes').read_text() + # example data: + # 8:0 Read 13918208 + # 8:0 Write 0 + # 8:0 Sync 0 + # 8:0 Async 13918208 + # 8:0 Total 13918208 + # Total 13918208 + io_read_bytes = 0 + io_write_bytes = 0 + for line in io_stats.splitlines(): + if line.startswith('Total '): + continue + dev, op, nbytes = line.strip().split() + if op == 'Read': + io_read_bytes += int(nbytes) + elif op == 'Write': + io_write_bytes += int(nbytes) + except IOError as e: + log.warning('cannot read stats: sysfs unreadable for container {0}\n{1!r}', + container_id[:7], e) + return None + loop = current_loop() + scratch_sz = await loop.run_in_executor( + None, get_scratch_size, container_id) + return mem_cur_bytes, io_read_bytes, io_write_bytes, scratch_sz + + async def api_impl(container_id): + async with closing_async(Docker()) as docker: + container = DockerContainer(docker, id=container_id) + try: + async with async_timeout.timeout(2.0): + ret = await fetch_api_stats(container) + except asyncio.TimeoutError: + return None + if ret is None: + return None + mem_cur_bytes = nmget(ret, 'memory_stats.usage', 0) + io_read_bytes = 0 + io_write_bytes = 0 + for item in nmget(ret, 'blkio_stats.io_service_bytes_recursive', []): + if item['op'] == 'Read': + io_read_bytes += item['value'] + elif item['op'] == 'Write': + io_write_bytes += item['value'] + loop = current_loop() + scratch_sz = await loop.run_in_executor( + None, get_scratch_size, container_id) + return mem_cur_bytes, io_read_bytes, io_write_bytes, scratch_sz + + if ctx.mode == StatModes.CGROUP: + impl = sysfs_impl + elif ctx.mode == StatModes.DOCKER: + impl = api_impl + else: + raise RuntimeError("should not reach here") + + per_container_mem_used_bytes = {} + per_container_io_read_bytes = {} + per_container_io_write_bytes = {} + per_container_io_scratch_size = {} + tasks = [] + for cid in container_ids: + tasks.append(asyncio.ensure_future(impl(cid))) + results = await asyncio.gather(*tasks) + for cid, result in zip(container_ids, results): + if result is None: + continue + per_container_mem_used_bytes[cid] = Measurement( + Decimal(result[0])) + per_container_io_read_bytes[cid] = Measurement( + Decimal(result[1])) + per_container_io_write_bytes[cid] = Measurement( + Decimal(result[2])) + per_container_io_scratch_size[cid] = Measurement( + Decimal(result[3])) + return [ + ContainerMeasurement( + MetricKey('mem'), + MetricTypes.USAGE, + unit_hint='bytes', + stats_filter=frozenset({'max'}), + per_container=per_container_mem_used_bytes, + ), + ContainerMeasurement( + MetricKey('io_read'), + MetricTypes.USAGE, + unit_hint='bytes', + stats_filter=frozenset({'rate'}), + per_container=per_container_io_read_bytes, + ), + ContainerMeasurement( + MetricKey('io_write'), + MetricTypes.USAGE, + unit_hint='bytes', + 
stats_filter=frozenset({'rate'}), + per_container=per_container_io_write_bytes, + ), + ContainerMeasurement( + MetricKey('io_scratch_size'), + MetricTypes.USAGE, + unit_hint='bytes', + stats_filter=frozenset({'max'}), + per_container=per_container_io_scratch_size, + ), + ] + + async def create_alloc_map(self) -> AbstractAllocMap: + devices = await self.list_devices() + return DiscretePropertyAllocMap( + device_slots={ + dev.device_id: + DeviceSlotInfo(SlotTypes.BYTES, SlotName('mem'), Decimal(dev.memory_size)) + for dev in devices + }, + ) + + async def get_hooks(self, distro: str, arch: str) -> Sequence[Path]: + return [] + + async def generate_docker_args( + self, + docker: Docker, + device_alloc, + ) -> Mapping[str, Any]: + memory = sum(device_alloc['mem'].values()) + return { + 'HostConfig': { + 'MemorySwap': int(memory), # prevent using swap! + 'Memory': int(memory), + }, + } + + async def restore_from_container( + self, + container: Container, + alloc_map: AbstractAllocMap, + ) -> None: + assert isinstance(alloc_map, DiscretePropertyAllocMap) + memory_limit = container.backend_obj['HostConfig']['Memory'] + alloc_map.apply_allocation({ + SlotName('mem'): {DeviceId('root'): memory_limit}, + }) + + async def get_attached_devices( + self, + device_alloc: Mapping[SlotName, Mapping[DeviceId, Decimal]], + ) -> Sequence[DeviceModelInfo]: + device_ids = [*device_alloc[SlotName('mem')].keys()] + available_devices = await self.list_devices() + attached_devices: List[DeviceModelInfo] = [] + for device in available_devices: + if device.device_id in device_ids: + attached_devices.append({ + 'device_id': device.device_id, + 'model_name': '', + 'data': {}, + }) + return attached_devices diff --git a/src/ai/backend/agent/docker/kernel.py b/src/ai/backend/agent/docker/kernel.py new file mode 100644 index 0000000000..1de118c60a --- /dev/null +++ b/src/ai/backend/agent/docker/kernel.py @@ -0,0 +1,413 @@ +import asyncio +import logging +import lzma +import os +from pathlib import Path, PurePosixPath +import pkg_resources +import re +import shutil +import subprocess +import textwrap +from typing import ( + Any, Optional, + Mapping, Dict, + FrozenSet, + Sequence, Tuple, +) + +from aiodocker.docker import Docker, DockerVolume +from aiodocker.exceptions import DockerError +from aiotools import TaskGroup + +from ai.backend.common.docker import ImageRef +from ai.backend.common.logging import BraceStyleAdapter +from ai.backend.common.types import KernelId +from ai.backend.common.utils import current_loop + +from ai.backend.agent.docker.utils import PersistentServiceContainer +from ..resources import KernelResourceSpec +from ..kernel import AbstractKernel, AbstractCodeRunner +from ..utils import closing_async, get_arch_name + +log = BraceStyleAdapter(logging.getLogger(__name__)) + + +class DockerKernel(AbstractKernel): + + def __init__( + self, kernel_id: KernelId, image: ImageRef, version: int, *, + agent_config: Mapping[str, Any], + resource_spec: KernelResourceSpec, + service_ports: Any, # TODO: type-annotation + environ: Mapping[str, Any], + data: Dict[str, Any], + ) -> None: + super().__init__( + kernel_id, image, version, + agent_config=agent_config, + resource_spec=resource_spec, + service_ports=service_ports, + data=data, + environ=environ, + ) + + async def close(self) -> None: + pass + + def __getstate__(self): + props = super().__getstate__() + return props + + def __setstate__(self, props): + super().__setstate__(props) + + async def create_code_runner(self, *, + client_features: FrozenSet[str], + 
api_version: int) -> AbstractCodeRunner: + return await DockerCodeRunner.new( + self.kernel_id, + kernel_host=self.data['kernel_host'], + repl_in_port=self.data['repl_in_port'], + repl_out_port=self.data['repl_out_port'], + exec_timeout=0, + client_features=client_features) + + async def get_completions(self, text: str, opts: Mapping[str, Any]): + result = await self.runner.feed_and_get_completion(text, opts) + return {'status': 'finished', 'completions': result} + + async def check_status(self): + result = await self.runner.feed_and_get_status() + return result + + async def get_logs(self): + container_id = self.data['container_id'] + async with closing_async(Docker()) as docker: + container = await docker.containers.get(container_id) + logs = await container.log(stdout=True, stderr=True) + return {'logs': ''.join(logs)} + + async def interrupt_kernel(self): + await self.runner.feed_interrupt() + return {'status': 'finished'} + + async def start_service(self, service: str, opts: Mapping[str, Any]): + if self.data.get('block_service_ports', False): + return { + 'status': 'failed', + 'error': 'operation blocked', + } + for sport in self.service_ports: + if sport['name'] == service: + break + else: + return {'status': 'failed', 'error': 'invalid service name'} + result = await self.runner.feed_start_service({ + 'name': service, + 'port': sport['container_ports'][0], # primary port + 'ports': sport['container_ports'], + 'protocol': sport['protocol'], + 'options': opts, + }) + return result + + async def shutdown_service(self, service: str): + await self.runner.feed_shutdown_service(service) + + async def get_service_apps(self): + result = await self.runner.feed_service_apps() + return result + + async def accept_file(self, filename: str, filedata: bytes): + loop = current_loop() + work_dir = self.agent_config['container']['scratch-root'] / str(self.kernel_id) / 'work' + try: + # create intermediate directories in the path + dest_path = (work_dir / filename).resolve(strict=False) + parent_path = dest_path.parent + except ValueError: # parent_path does not start with work_dir! 
+ raise AssertionError('malformed upload filename and path.') + + def _write_to_disk(): + parent_path.mkdir(parents=True, exist_ok=True) + dest_path.write_bytes(filedata) + + try: + await loop.run_in_executor(None, _write_to_disk) + except FileNotFoundError: + log.error('{0}: writing uploaded file failed: {1} -> {2}', + self.kernel_id, filename, dest_path) + + async def download_file(self, filepath: str): + container_id = self.data['container_id'] + async with closing_async(Docker()) as docker: + container = docker.containers.container(container_id) + home_path = PurePosixPath('/home/work') + try: + abspath = (home_path / filepath) + abspath.relative_to(home_path) + except ValueError: + raise PermissionError('You cannot download files outside /home/work') + try: + with await container.get_archive(str(abspath)) as tarobj: + tarobj.fileobj.seek(0, 2) + fsize = tarobj.fileobj.tell() + if fsize > 1048576: + raise ValueError('too large file') + tarbytes = tarobj.fileobj.getvalue() + except DockerError: + log.warning('Could not found the file: {0}', abspath) + raise FileNotFoundError(f'Could not found the file: {abspath}') + return tarbytes + + async def list_files(self, container_path: str): + container_id = self.data['container_id'] + + # Confine the lookable paths in the home directory + home_path = Path('/home/work') + try: + resolved_path = (home_path / container_path).resolve() + resolved_path.relative_to(home_path) + except ValueError: + raise PermissionError('You cannot list files outside /home/work') + + # Gather individual file information in the target path. + code = textwrap.dedent(''' + import json + import os + import stat + import sys + + files = [] + for f in os.scandir(sys.argv[1]): + fstat = f.stat() + ctime = fstat.st_ctime # TODO: way to get concrete create time? 
+ mtime = fstat.st_mtime + atime = fstat.st_atime + files.append({ + 'mode': stat.filemode(fstat.st_mode), + 'size': fstat.st_size, + 'ctime': ctime, + 'mtime': mtime, + 'atime': atime, + 'filename': f.name, + }) + print(json.dumps(files)) + ''') + proc = await asyncio.create_subprocess_exec( + *[ + 'docker', 'exec', container_id, + '/opt/backend.ai/bin/python', '-c', code, + str(container_path), + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + raw_out, raw_err = await proc.communicate() + out = raw_out.decode('utf-8') + err = raw_err.decode('utf-8') + return {'files': out, 'errors': err, 'abspath': str(container_path)} + + +class DockerCodeRunner(AbstractCodeRunner): + + kernel_host: str + repl_in_port: int + repl_out_port: int + + def __init__(self, kernel_id, *, + kernel_host, repl_in_port, repl_out_port, + exec_timeout=0, client_features=None) -> None: + super().__init__( + kernel_id, + exec_timeout=exec_timeout, + client_features=client_features) + self.kernel_host = kernel_host + self.repl_in_port = repl_in_port + self.repl_out_port = repl_out_port + + async def get_repl_in_addr(self) -> str: + return f'tcp://{self.kernel_host}:{self.repl_in_port}' + + async def get_repl_out_addr(self) -> str: + return f'tcp://{self.kernel_host}:{self.repl_out_port}' + + +async def prepare_krunner_env_impl(distro: str) -> Tuple[str, Optional[str]]: + if distro.startswith('static-'): + distro_name = distro.replace('-', '_') # pkg/mod name use underscores + else: + if (m := re.search(r'^([a-z]+)\d+\.\d+$', distro)) is None: + raise ValueError('Unrecognized "distro[version]" format string.') + distro_name = m.group(1) + docker = Docker() + arch = get_arch_name() + current_version = int(Path( + pkg_resources.resource_filename( + f'ai.backend.krunner.{distro_name}', + f'./krunner-version.{distro}.txt')) + .read_text().strip()) + volume_name = f'backendai-krunner.v{current_version}.{arch}.{distro}' + extractor_image = 'backendai-krunner-extractor:latest' + + try: + for item in (await docker.images.list()): + if item['RepoTags'] is None: + continue + if item['RepoTags'][0] == extractor_image: + break + else: + log.info('preparing the Docker image for krunner extractor...') + extractor_archive = pkg_resources.resource_filename( + 'ai.backend.runner', f'krunner-extractor.img.{arch}.tar.xz') + with lzma.open(extractor_archive, 'rb') as reader: + proc = await asyncio.create_subprocess_exec( + *['docker', 'load'], stdin=reader) + if (await proc.wait() != 0): + raise RuntimeError('loading krunner extractor image has failed!') + + log.info('checking krunner-env for {}...', distro) + do_create = False + try: + vol = DockerVolume(docker, volume_name) + await vol.show() + # Instead of checking the version from txt files inside the volume, + # we check the version via the volume name and its existence. + # This is because: + # - to avoid overwriting of volumes in use. + # - the name comparison is quicker than reading actual files. 
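+        # vol.show() above raises DockerError with status 404 when the volume for
+        # this krunner version does not exist yet, which triggers the creation and
+        # population path handled below.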
+ except DockerError as e: + if e.status == 404: + do_create = True + if do_create: + archive_path = Path(pkg_resources.resource_filename( + f'ai.backend.krunner.{distro_name}', + f'krunner-env.{distro}.{arch}.tar.xz')).resolve() + if not archive_path.exists(): + log.warning("krunner environment for {} ({}) is not supported!", distro, arch) + else: + log.info('populating {} volume version {}', + volume_name, current_version) + await docker.volumes.create({ + 'Name': volume_name, + 'Driver': 'local', + }) + extractor_path = Path(pkg_resources.resource_filename( + 'ai.backend.runner', + 'krunner-extractor.sh')).resolve() + proc = await asyncio.create_subprocess_exec(*[ + 'docker', 'run', '--rm', '-i', + '-v', f'{archive_path}:/root/archive.tar.xz', + '-v', f'{extractor_path}:/root/krunner-extractor.sh', + '-v', f'{volume_name}:/root/volume', + '-e', f'KRUNNER_VERSION={current_version}', + extractor_image, + '/root/krunner-extractor.sh', + ]) + if (await proc.wait() != 0): + raise RuntimeError('extracting krunner environment has failed!') + except Exception: + log.exception('unexpected error') + return distro, None + finally: + await docker.close() + return distro, volume_name + + +async def prepare_krunner_env(local_config: Mapping[str, Any]) -> Mapping[str, Sequence[str]]: + """ + Check if the volume "backendai-krunner.{distro}.{arch}" exists and is up-to-date. + If not, automatically create it and update its content from the packaged pre-built krunner + tar archives. + """ + + all_distros = [] + entry_prefix = 'backendai_krunner_v10' + for entrypoint in pkg_resources.iter_entry_points(entry_prefix): + log.debug('loading krunner pkg: {}', entrypoint.module_name) + plugin = entrypoint.load() + await plugin.init({}) # currently does nothing + provided_versions = Path(pkg_resources.resource_filename( + f'ai.backend.krunner.{entrypoint.name}', + 'versions.txt', + )).read_text().splitlines() + all_distros.extend(provided_versions) + + tasks = [] + async with TaskGroup() as tg: + for distro in all_distros: + tasks.append(tg.create_task(prepare_krunner_env_impl(distro))) + distro_volumes = [t.result() for t in tasks if not t.cancelled()] + result = {} + for distro_name_and_version, volume_name in distro_volumes: + if volume_name is None: + continue + result[distro_name_and_version] = volume_name + return result + + +LinuxKit_IPTABLES_RULE = \ + re.compile(r'DNAT\s+tcp\s+\-\-\s+anywhere\s+169\.254\.169\.254\s+tcp dpt:http to:127\.0\.0\.1:50128') +LinuxKit_CMD_EXEC_PREFIX = [ + 'docker', 'run', '--rm', '-i', + '--privileged', '--pid=host', + 'linuxkit-nsenter:latest', +] + + +async def prepare_kernel_metadata_uri_handling(local_config: Mapping[str, Any]) -> None: + async with closing_async(Docker()) as docker: + kernel_version = (await docker.version())['KernelVersion'] + if 'linuxkit' in kernel_version: + local_config['agent']['docker-mode'] = 'linuxkit' + # Docker Desktop mode + arch = get_arch_name() + proxy_worker_binary = pkg_resources.resource_filename( + 'ai.backend.agent.docker', + f'linuxkit-metadata-proxy-worker.{arch}.bin') + shutil.copyfile(proxy_worker_binary, '/tmp/backend.ai/linuxkit-metadata-proxy') + os.chmod('/tmp/backend.ai/linuxkit-metadata-proxy', 0o755) + # Prepare proxy worker container + proxy_worker_container = PersistentServiceContainer( + 'linuxkit-nsenter:latest', + { + 'Cmd': [ + '/bin/sh', '-c', + 'ctr -n services.linuxkit t kill --exec-id metaproxy docker;' + 'ctr -n services.linuxkit t exec --exec-id metaproxy docker ' + 
'/host_mnt/tmp/backend.ai/linuxkit-metadata-proxy -remote-port 40128', + ], + 'HostConfig': { + 'PidMode': 'host', + 'Privileged': True, + }, + }, + name='linuxkit-nsenter', + ) + await proxy_worker_container.ensure_running_latest() + + # Check if iptables rule is propagated on LinuxKit VM properly + log.info('Checking metadata URL iptables rule ...') + proc = await asyncio.create_subprocess_exec(*( + LinuxKit_CMD_EXEC_PREFIX + + ['/sbin/iptables', '-n', '-t', 'nat', '-L', 'PREROUTING'] + ), stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE) + await proc.wait() + assert proc.stdout is not None + raw_rules = await proc.stdout.read() + rules = raw_rules.decode() + if LinuxKit_IPTABLES_RULE.search(rules) is None: + proc = await asyncio.create_subprocess_exec(*( + LinuxKit_CMD_EXEC_PREFIX + + [ + '/sbin/iptables', '-t', 'nat', '-I', 'PREROUTING', + '-d', '169.254.169.254', '-p', 'tcp', '--dport', '80', + '-j', 'DNAT', '--to-destination', '127.0.0.1:50128', + ] + )) + await proc.wait() + log.info('Inserted the iptables rules.') + else: + log.info('The iptables rule already exists.') + else: + # Linux Mode + local_config['agent']['docker-mode'] = 'native' diff --git a/src/ai/backend/agent/docker/linuxkit-metadata-proxy-worker.aarch64.bin b/src/ai/backend/agent/docker/linuxkit-metadata-proxy-worker.aarch64.bin new file mode 100755 index 0000000000..75f59e76ec --- /dev/null +++ b/src/ai/backend/agent/docker/linuxkit-metadata-proxy-worker.aarch64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6d3d885ca298969b79981b8023c07fe7ba148058a58ec6a07e2fbc6317899668 +size 6558515 diff --git a/src/ai/backend/agent/docker/linuxkit-metadata-proxy-worker.x86_64.bin b/src/ai/backend/agent/docker/linuxkit-metadata-proxy-worker.x86_64.bin new file mode 100755 index 0000000000..522ce4442a --- /dev/null +++ b/src/ai/backend/agent/docker/linuxkit-metadata-proxy-worker.x86_64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4fab02b10a487492942a205702bfcb939f451e9e754f30c85e02309a3c3fb757 +size 6731237 diff --git a/src/ai/backend/agent/docker/linuxkit-nsenter.img.aarch64.tar.gz b/src/ai/backend/agent/docker/linuxkit-nsenter.img.aarch64.tar.gz new file mode 100644 index 0000000000..d5f882ee11 --- /dev/null +++ b/src/ai/backend/agent/docker/linuxkit-nsenter.img.aarch64.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5d2412a7570592e1ee37c0ccb6b71316b95b77aa5fde6046173874472baf2ad9 +size 39823 diff --git a/src/ai/backend/agent/docker/linuxkit-nsenter.img.x86_64.tar.gz b/src/ai/backend/agent/docker/linuxkit-nsenter.img.x86_64.tar.gz new file mode 100644 index 0000000000..689b4cf8e7 --- /dev/null +++ b/src/ai/backend/agent/docker/linuxkit-nsenter.img.x86_64.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:53cb39555fba5c675e17918085214e9546ff4cc183b636d2d999171e8faeaec5 +size 32891 diff --git a/src/ai/backend/agent/docker/linuxkit-nsenter.version.txt b/src/ai/backend/agent/docker/linuxkit-nsenter.version.txt new file mode 100644 index 0000000000..56a6051ca2 --- /dev/null +++ b/src/ai/backend/agent/docker/linuxkit-nsenter.version.txt @@ -0,0 +1 @@ +1 \ No newline at end of file diff --git a/src/ai/backend/agent/docker/metadata/server.py b/src/ai/backend/agent/docker/metadata/server.py new file mode 100644 index 0000000000..d164ed51f9 --- /dev/null +++ b/src/ai/backend/agent/docker/metadata/server.py @@ -0,0 +1,64 @@ +import logging +from typing import Any, Mapping +from uuid import UUID +from 
ai.backend.common.types import KernelId
+
+from aiodocker.docker import Docker
+from aiohttp import web
+
+from ai.backend.agent.utils import closing_async
+from ai.backend.common.logging import BraceStyleAdapter
+
+from ai.backend.agent.docker.kernel import DockerKernel
+
+from ai.backend.agent.kernel import AbstractKernel
+from aiohttp.typedefs import Handler
+
+log = BraceStyleAdapter(logging.getLogger(__name__))
+
+
+@web.middleware
+async def container_resolver(request: web.Request, handler: Handler):
+    if request.headers.get('X-Forwarded-For') is not None and request.app['docker-mode'] == 'linuxkit':
+        container_ip = request.headers['X-Forwarded-For']
+    elif remote_ip := request.remote:
+        container_ip = remote_ip
+    else:
+        return web.Response(status=403)
+    async with closing_async(Docker()) as docker:
+        containers = await docker.containers.list(
+            filters='{"label":["ai.backend.kernel-id"],"network":["bridge"],"status":["running"]}',
+        )
+    target_container = list(filter(
+        lambda x: x['NetworkSettings']['Networks'].get('bridge', {}).get('IPAddress') == container_ip,
+        containers,
+    ))
+
+    if len(target_container) == 0:
+        return web.Response(status=403)
+    request['container-ip'] = container_ip
+    request['container'] = target_container[0]
+    return await handler(request)
+
+
+async def get_metadata(request: web.Request) -> web.Response:
+    kernel: DockerKernel = \
+        request.app['kernel-registry'].get(UUID(request['container']['Labels']['ai.backend.kernel-id']))
+    if kernel is None:
+        return web.Response(status=404)
+    response = dict(kernel.environ)
+    return web.json_response(response)
+
+
+# TODO: Split out the metadata server as a separate backend.ai plugin
+async def create_server(
+    local_config: Mapping[str, Any],
+    kernel_registry: Mapping[KernelId, AbstractKernel],
+) -> web.Application:
+    app = web.Application(
+        middlewares=[container_resolver],
+    )
+    app['docker-mode'] = local_config['agent']['docker-mode']
+    app['kernel-registry'] = kernel_registry
+    app.router.add_route('GET', '/meta-data', get_metadata)
+    return app
diff --git a/src/ai/backend/agent/docker/resources.py b/src/ai/backend/agent/docker/resources.py
new file mode 100644
index 0000000000..8281235024
--- /dev/null
+++ b/src/ai/backend/agent/docker/resources.py
@@ -0,0 +1,98 @@
+from decimal import Decimal
+import logging
+from pathlib import Path
+from typing import (
+    Any, Optional,
+    Mapping, MutableMapping,
+    Tuple,
+)
+
+import aiofiles
+
+from ai.backend.common.etcd import AsyncEtcd
+from ai.backend.common.logging import BraceStyleAdapter
+from ai.backend.common.types import (
+    DeviceName, SlotName,
+)
+from ..exception import InitializationError
+from ..resources import (
+    AbstractComputePlugin, ComputePluginContext, KernelResourceSpec, known_slot_types,
+)
+
+log = BraceStyleAdapter(logging.getLogger(__name__))
+
+
+async def detect_resources(
+    etcd: AsyncEtcd,
+    local_config: Mapping[str, Any],
+) -> Tuple[Mapping[DeviceName, AbstractComputePlugin],
+           Mapping[SlotName, Decimal]]:
+    """
+    Detect the available computing resources of the system.
+    It also loads the accelerator plugins.
+
+    The limit_cpus and limit_gpus options are deprecated.
+    """
+    reserved_slots = {
+        'cpu': local_config['resource']['reserved-cpu'],
+        'mem': local_config['resource']['reserved-mem'],
+        'disk': local_config['resource']['reserved-disk'],
+    }
+    slots: MutableMapping[SlotName, Decimal] = {}
+
+    compute_device_types: MutableMapping[DeviceName, AbstractComputePlugin] = {}
+
+    # Initialize intrinsic plugins by ourselves.
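+    # (Illustrative note: if no accelerator plugin provides the 'cpu'/'mem' slots, the
+    #  intrinsic CPUPlugin/MemoryPlugin below are attached; each advertised slot value
+    #  is then the detected amount minus the matching reserved-* setting, e.g. 8
+    #  detected cores with reserved-cpu=1 yield a 'cpu' slot of 7.)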
+ from .intrinsic import CPUPlugin, MemoryPlugin + compute_plugin_ctx = ComputePluginContext( + etcd, local_config, + ) + await compute_plugin_ctx.init() + if 'cpu' not in compute_plugin_ctx.plugins: + cpu_config = await etcd.get_prefix('config/plugins/cpu') + cpu_plugin = CPUPlugin(cpu_config, local_config) + compute_plugin_ctx.attach_intrinsic_device(cpu_plugin) + if 'mem' not in compute_plugin_ctx.plugins: + memory_config = await etcd.get_prefix('config/plugins/memory') + memory_plugin = MemoryPlugin(memory_config, local_config) + compute_plugin_ctx.attach_intrinsic_device(memory_plugin) + for plugin_name, plugin_instance in compute_plugin_ctx.plugins.items(): + if not all( + (invalid_name := sname, sname.startswith(f'{plugin_instance.key}.'))[1] + for sname, _ in plugin_instance.slot_types + if sname not in {'cpu', 'mem'} + ): + raise InitializationError( + "Slot types defined by an accelerator plugin must be prefixed " + "by the plugin's key.", + invalid_name, # noqa: F821 + plugin_instance.key, + ) + if plugin_instance.key in compute_device_types: + raise InitializationError( + f"A plugin defining the same key '{plugin_instance.key}' already exists. " + "You may need to uninstall it first.") + compute_device_types[plugin_instance.key] = plugin_instance + + for key, computer in compute_device_types.items(): + known_slot_types.update(computer.slot_types) # type: ignore # (only updated here!) + resource_slots = await computer.available_slots() + for sname, sval in resource_slots.items(): + slots[sname] = Decimal(max(0, sval - reserved_slots.get(sname, 0))) + if slots[sname] <= 0 and sname in (SlotName('cpu'), SlotName('mem')): + raise InitializationError( + f"The resource slot '{sname}' is not sufficient (zero or below zero). " + "Try to adjust the reserved resources or use a larger machine.") + + log.info('Resource slots: {!r}', slots) + log.info('Slot types: {!r}', known_slot_types) + return compute_device_types, slots + + +async def get_resource_spec_from_container(container_info) -> Optional[KernelResourceSpec]: + for mount in container_info['HostConfig']['Mounts']: + if mount['Target'] == '/home/config': + async with aiofiles.open(Path(mount['Source']) / 'resource.txt', 'r') as f: # type: ignore + return await KernelResourceSpec.aread_from_file(f) + else: + return None diff --git a/src/ai/backend/agent/docker/utils.py b/src/ai/backend/agent/docker/utils.py new file mode 100644 index 0000000000..0fc8ff0bb9 --- /dev/null +++ b/src/ai/backend/agent/docker/utils.py @@ -0,0 +1,162 @@ +import asyncio +import gzip +import logging +import pkg_resources +import subprocess +from pathlib import Path +from typing import Any, BinaryIO, Mapping, Tuple, cast + +from aiodocker.docker import Docker +from aiodocker.exceptions import DockerError + +from ai.backend.common.logging import BraceStyleAdapter + +from ..exception import InitializationError +from ..utils import closing_async, get_arch_name, update_nested_dict + +log = BraceStyleAdapter(logging.getLogger(__name__)) + + +class PersistentServiceContainer: + + def __init__( + self, + image_ref: str, + container_config: Mapping[str, Any], + *, + name: str = None, + ) -> None: + self.image_ref = image_ref + arch = get_arch_name() + default_container_name = image_ref.split(':')[0].rsplit('/', maxsplit=1)[-1] + if name is None: + self.container_name = default_container_name + else: + self.container_name = name + self.container_config = container_config + self.img_version = int(Path(pkg_resources.resource_filename( + 'ai.backend.agent.docker', + 
f'{default_container_name}.version.txt', + )).read_text()) + self.img_path = Path(pkg_resources.resource_filename( + 'ai.backend.agent.docker', + f'{default_container_name}.img.{arch}.tar.gz', + )) + + async def get_container_version_and_status(self) -> Tuple[int, bool]: + async with closing_async(Docker()) as docker: + try: + c = docker.containers.container(self.container_name) + await c.show() + except DockerError as e: + if e.status == 404: + return 0, False + else: + raise + if c['Config'].get('Labels', {}).get('ai.backend.system', '0') != '1': + raise RuntimeError( + f"An existing container named \"{c['Name'].lstrip('/')}\" is not a system container " + f"spawned by Backend.AI. Please check and remove it.") + return ( + int(c['Config'].get('Labels', {}).get('ai.backend.version', '0')), + c['State']['Status'].lower() == 'running', + ) + + async def get_image_version(self) -> int: + async with closing_async(Docker()) as docker: + try: + img = await docker.images.inspect(self.image_ref) + except DockerError as e: + if e.status == 404: + return 0 + else: + raise + return int((img['Config'].get('Labels') or {}).get('ai.backend.version', '0')) + + async def ensure_running_latest(self) -> None: + image_version = await self.get_image_version() + if image_version == 0: + log.info("PersistentServiceContainer({}): installing...", self.image_ref) + await self.install_latest() + elif image_version < self.img_version: + log.info("PersistentServiceContainer({}): upgrading (v{} -> v{})", + self.image_ref, image_version, self.img_version) + await self.install_latest() + container_version, is_running = await self.get_container_version_and_status() + if container_version == 0 or image_version != container_version or not is_running: + log.info("PersistentServiceContainer({}): recreating...", self.image_ref) + await self.recreate() + if not is_running: + log.info("PersistentServiceContainer({}): starting...", self.image_ref) + await self.start() + + async def install_latest(self) -> None: + with gzip.open(self.img_path, 'rb') as reader: + proc = await asyncio.create_subprocess_exec( + *['docker', 'load'], + stdin=cast(BinaryIO, reader), + stdout=subprocess.DEVNULL, + stderr=subprocess.PIPE, + ) + if (await proc.wait() != 0): + stderr = b'(unavailable)' + if proc.stderr is not None: + stderr = await proc.stderr.read() + raise RuntimeError( + 'loading the image has failed!', + self.image_ref, proc.returncode, stderr, + ) + + async def recreate(self) -> None: + async with closing_async(Docker()) as docker: + try: + c = docker.containers.container(self.container_name) + await c.stop() + await c.delete(force=True) + except DockerError as e: + if e.status == 409 and 'is not running' in e.message: + pass + elif e.status == 404: + pass + else: + raise + container_config = { + 'Image': self.image_ref, + 'Tty': True, + 'Privileged': False, + 'AttachStdin': False, + 'AttachStdout': False, + 'AttachStderr': False, + 'HostConfig': { + 'Init': True, + 'RestartPolicy': { + 'Name': 'unless-stopped', # make it persistent + 'MaximumRetryCount': 0, + }, + }, + } + update_nested_dict(container_config, self.container_config) + try: + await docker.containers.create(config=container_config, name=self.container_name) + except DockerError as e: + err_msg = e.args[1].get("message", "") + if ( + e.args[0] == 400 and + 'bind source path does not exist' in err_msg and + '/tmp/backend.ai/ipc' in err_msg + ): + raise InitializationError( + f"Could not create persistent service container '{self.container_name}' " + f"because it cannot 
access /tmp/backend.ai/ipc directory. "
+                        f"This may occur when Docker is installed with Snap or the agent is configured "
+                        f"to use a private tmp directory. "
+                        f"To resolve, explicitly configure the 'ipc-base-path' option in agent.toml to "
+                        f"indicate a directory under $HOME or a non-virtualized directory.",
+                    )
+                else:
+                    raise
+
+    async def start(self) -> None:
+        async with closing_async(Docker()) as docker:
+            c = docker.containers.container(self.container_name)
+            await c.start()
diff --git a/src/ai/backend/agent/exception.py b/src/ai/backend/agent/exception.py
new file mode 100644
index 0000000000..2b10d5c482
--- /dev/null
+++ b/src/ai/backend/agent/exception.py
@@ -0,0 +1,53 @@
+class InitializationError(Exception):
+    """
+    Errors during agent initialization and compute plugin setup
+    """
+    pass
+
+
+class ResourceError(ValueError):
+    pass
+
+
+class UnsupportedResource(ResourceError):
+    pass
+
+
+class InvalidResourceCombination(ResourceError):
+    pass
+
+
+class InvalidResourceArgument(ResourceError):
+    pass
+
+
+class NotMultipleOfQuantum(InvalidResourceArgument):
+    pass
+
+
+class InsufficientResource(ResourceError):
+    pass
+
+
+class UnsupportedBaseDistroError(RuntimeError):
+    pass
+
+
+class K8sError(Exception):
+    def __init__(self, message):
+        super().__init__(message)
+        self.message = message
+
+
+class AgentError(RuntimeError):
+    '''
+    A dummy exception class to distinguish agent-side errors passed via
+    aiozmq.rpc calls.
+
+    It carries a two-element args tuple: the exception type and the exception
+    arguments from the agent.
+    '''
+
+    def __init__(self, *args, exc_repr: str = None):
+        super().__init__(*args)
+        self.exc_repr = exc_repr
diff --git a/src/ai/backend/agent/fs.py b/src/ai/backend/agent/fs.py
new file mode 100644
index 0000000000..f30ef477ed
--- /dev/null
+++ b/src/ai/backend/agent/fs.py
@@ -0,0 +1,42 @@
+from subprocess import CalledProcessError
+import asyncio
+
+
+async def create_scratch_filesystem(scratch_dir, size):
+    '''
+    Create a scratch directory with a size quota by mounting a tmpfs filesystem.
+
+    :param scratch_dir: The path of the scratch directory.
+
+    :param size: The quota size of the scratch directory.
+        The size must be given in MiB (mebibytes).
+    '''
+
+    proc = await asyncio.create_subprocess_exec(*[
+        'mount',
+        '-t', 'tmpfs',
+        '-o', f'size={size}M',
+        'tmpfs', f'{scratch_dir}',
+    ])
+    exit_code = await proc.wait()
+
+    if exit_code < 0:
+        raise CalledProcessError(proc.returncode, proc.args,
+                                 output=proc.stdout, stderr=proc.stderr)
+
+
+async def destroy_scratch_filesystem(scratch_dir):
+    '''
+    Destroy the scratch directory's size quota by unmounting the tmpfs filesystem.
+
+    :param scratch_dir: The path of the scratch directory.
+ ''' + proc = await asyncio.create_subprocess_exec(*[ + 'umount', + f'{scratch_dir}', + ]) + exit_code = await proc.wait() + + if exit_code < 0: + raise CalledProcessError(proc.returncode, proc.args, + output=proc.stdout, stderr=proc.stderr) diff --git a/src/ai/backend/agent/kernel.py b/src/ai/backend/agent/kernel.py new file mode 100644 index 0000000000..9c126e8707 --- /dev/null +++ b/src/ai/backend/agent/kernel.py @@ -0,0 +1,920 @@ +from __future__ import annotations + +from abc import abstractmethod, ABCMeta +import asyncio +import codecs +from collections import OrderedDict, UserDict +from dataclasses import dataclass +import io +import json +import logging +import math +import re +import secrets +import time +from typing import ( + Any, + Dict, + FrozenSet, + List, + Literal, + Mapping, + Optional, + Set, + Sequence, + Tuple, + TypedDict, + Union, +) + +from async_timeout import timeout +import zmq, zmq.asyncio + +from ai.backend.common import msgpack +from ai.backend.common.asyncio import current_loop +from ai.backend.common.docker import ImageRef +from ai.backend.common.enum_extension import StringSetFlag +from ai.backend.common.types import aobject, KernelId +from ai.backend.common.logging import BraceStyleAdapter +from .exception import UnsupportedBaseDistroError +from .resources import KernelResourceSpec + +log = BraceStyleAdapter(logging.getLogger(__name__)) + +# msg types visible to the API client. +# (excluding control signals such as 'finished' and 'waiting-input' +# since they are passed as separate status field.) +ConsoleItemType = Literal[ + 'stdout', 'stderr', 'media', 'html', 'log', 'completion', +] +outgoing_msg_types: FrozenSet[ConsoleItemType] = frozenset([ + 'stdout', 'stderr', 'media', 'html', 'log', 'completion', +]) +ResultType = Union[ConsoleItemType, Literal[ + 'continued', + 'clean-finished', + 'build-finished', + 'finished', + 'exec-timeout', + 'waiting-input', +]] + + +class KernelFeatures(StringSetFlag): + UID_MATCH = 'uid-match' + USER_INPUT = 'user-input' + BATCH_MODE = 'batch' + QUERY_MODE = 'query' + TTY_MODE = 'tty' + + +class ClientFeatures(StringSetFlag): + INPUT = 'input' + CONTINUATION = 'continuation' + + +# TODO: use Python 3.7 contextvars for per-client feature selection +default_client_features = frozenset({ + ClientFeatures.INPUT.value, + ClientFeatures.CONTINUATION.value, +}) +default_api_version = 4 + + +class RunEvent(Exception): + + data: Any + + def __init__(self, data=None): + super().__init__() + self.data = data + + +class InputRequestPending(RunEvent): + pass + + +class CleanFinished(RunEvent): + pass + + +class BuildFinished(RunEvent): + pass + + +class RunFinished(RunEvent): + pass + + +class ExecTimeout(RunEvent): + pass + + +@dataclass +class ResultRecord: + msg_type: ResultType + data: Optional[str] = None + + +class NextResult(TypedDict, total=False): + runId: Optional[str] + status: ResultType + exitCode: Optional[int] + options: Optional[Mapping[str, Any]] + # v1 + stdout: Optional[str] + stderr: Optional[str] + media: Optional[Sequence[Any]] + html: Optional[Sequence[Any]] + # v2 + console: Optional[Sequence[Any]] + + +class AbstractKernel(UserDict, aobject, metaclass=ABCMeta): + + version: int + agent_config: Mapping[str, Any] + kernel_id: KernelId + image: ImageRef + resource_spec: KernelResourceSpec + service_ports: Any + data: Dict[Any, Any] + last_used: float + termination_reason: Optional[str] + clean_event: Optional[asyncio.Future] + stats_enabled: bool + # FIXME: apply TypedDict to data in Python 3.8 + environ: 
Mapping[str, Any] + + _tasks: Set[asyncio.Task] + + runner: 'AbstractCodeRunner' + + def __init__( + self, kernel_id: KernelId, image: ImageRef, version: int, *, + agent_config: Mapping[str, Any], + resource_spec: KernelResourceSpec, + service_ports: Any, # TODO: type-annotation + data: Dict[Any, Any], + environ: Mapping[str, Any], + ) -> None: + self.agent_config = agent_config + self.kernel_id = kernel_id + self.image = image + self.version = version + self.resource_spec = resource_spec + self.service_ports = service_ports + self.data = data + self.last_used = time.monotonic() + self.termination_reason = None + self.clean_event = None + self.stats_enabled = False + self._tasks = set() + self.environ = environ + + async def init(self) -> None: + log.debug('kernel.init(k:{0}, api-ver:{1}, client-features:{2}): ' + 'starting new runner', + self.kernel_id, default_api_version, default_client_features) + self.runner = await self.create_code_runner( + client_features=default_client_features, + api_version=default_api_version) + + def __getstate__(self) -> Mapping[str, Any]: + props = self.__dict__.copy() + del props['agent_config'] + del props['clean_event'] + del props['_tasks'] + return props + + def __setstate__(self, props) -> None: + self.__dict__.update(props) + # agent_config is set by the pickle.loads() caller. + self.clean_event = None + self._tasks = set() + + @abstractmethod + async def close(self) -> None: + """ + Release internal resources used for interacting with the kernel. + Note that this does NOT terminate the container. + """ + pass + + # We don't have "allocate_slots()" method here because: + # - resource_spec is initialized by allocating slots at computer's alloc_map + # when creating new kernels. + # - restoration from running containers is done by computer's classmethod + # "restore_from_container" + + def release_slots(self, computer_ctxs) -> None: + """ + Release the resource slots occupied by the kernel + to the allocation maps. 
+ """ + for accel_key, accel_alloc in self.resource_spec.allocations.items(): + computer_ctxs[accel_key].alloc_map.free(accel_alloc) + + @abstractmethod + async def create_code_runner( + self, *, + client_features: FrozenSet[str], + api_version: int, + ) -> 'AbstractCodeRunner': + raise NotImplementedError + + @abstractmethod + async def check_status(self): + raise NotImplementedError + + @abstractmethod + async def get_completions(self, text, opts): + raise NotImplementedError + + @abstractmethod + async def get_logs(self): + raise NotImplementedError + + @abstractmethod + async def interrupt_kernel(self): + raise NotImplementedError + + @abstractmethod + async def start_service(self, service, opts): + raise NotImplementedError + + @abstractmethod + async def shutdown_service(self, service): + raise NotImplementedError + + @abstractmethod + async def get_service_apps(self): + raise NotImplementedError + + @abstractmethod + async def accept_file(self, filename, filedata): + raise NotImplementedError + + @abstractmethod + async def download_file(self, filepath): + raise NotImplementedError + + @abstractmethod + async def list_files(self, path: str): + raise NotImplementedError + + async def execute( + self, + run_id: Optional[str], + mode: Literal['batch', 'query', 'input', 'continue'], + text: str, + *, + opts: Mapping[str, Any], + api_version: int, + flush_timeout: float, + ) -> NextResult: + myself = asyncio.current_task() + assert myself is not None + self._tasks.add(myself) + try: + await self.runner.attach_output_queue(run_id) + try: + if mode == 'batch': + await self.runner.feed_batch(opts) + elif mode == 'query': + await self.runner.feed_code(text) + elif mode == 'input': + await self.runner.feed_input(text) + elif mode == 'continue': + pass + except zmq.ZMQError: + # cancel the operation by myself + # since the peer is gone. 
+ raise asyncio.CancelledError + return await self.runner.get_next_result( + api_ver=api_version, + flush_timeout=flush_timeout, + ) + except asyncio.CancelledError: + await self.runner.close() + raise + finally: + self._tasks.remove(myself) + + +_zctx = None + + +class AbstractCodeRunner(aobject, metaclass=ABCMeta): + + kernel_id: KernelId + started_at: float + finished_at: Optional[float] + exec_timeout: float + max_record_size: int + client_features: FrozenSet[str] + + input_sock: zmq.asyncio.Socket + output_sock: zmq.asyncio.Socket + + completion_queue: asyncio.Queue[bytes] + service_queue: asyncio.Queue[bytes] + service_apps_info_queue: asyncio.Queue[bytes] + status_queue: asyncio.Queue[bytes] + output_queue: Optional[asyncio.Queue[ResultRecord]] + current_run_id: Optional[str] + pending_queues: OrderedDict[str, Tuple[asyncio.Event, asyncio.Queue[ResultRecord]]] + + read_task: Optional[asyncio.Task] + status_task: Optional[asyncio.Task] + watchdog_task: Optional[asyncio.Task] + + _closed: bool + + def __init__( + self, + kernel_id: KernelId, + *, + exec_timeout: float = 0, + client_features: FrozenSet[str] = None, + ) -> None: + global _zctx + self.kernel_id = kernel_id + self.started_at = time.monotonic() + self.finished_at = None + if not math.isfinite(exec_timeout) or exec_timeout < 0: + raise ValueError('execution timeout must be a zero or finite positive number.') + self.kernel_id = kernel_id + self.exec_timeout = exec_timeout + self.max_record_size = 10 * (2 ** 20) # 10 MBytes + self.client_features = client_features or frozenset() + if _zctx is None: + _zctx = zmq.asyncio.Context() + self.zctx = _zctx # share the global context + self.input_sock = self.zctx.socket(zmq.PUSH) + self.output_sock = self.zctx.socket(zmq.PULL) + self.completion_queue = asyncio.Queue(maxsize=128) + self.service_queue = asyncio.Queue(maxsize=128) + self.service_apps_info_queue = asyncio.Queue(maxsize=128) + self.status_queue = asyncio.Queue(maxsize=128) + self.output_queue = None + self.pending_queues = OrderedDict() + self.current_run_id = None + self.read_task = None + self.status_task = None + self.watchdog_task = None + self._closed = False + + async def __ainit__(self) -> None: + loop = current_loop() + self.input_sock.connect(await self.get_repl_in_addr()) + self.input_sock.setsockopt(zmq.LINGER, 50) + self.output_sock.connect(await self.get_repl_out_addr()) + self.output_sock.setsockopt(zmq.LINGER, 50) + self.status_task = loop.create_task(self.ping_status()) + self.read_task = loop.create_task(self.read_output()) + if self.exec_timeout > 0: + self.watchdog_task = loop.create_task(self.watchdog()) + else: + self.watchdog_task = None + + def __getstate__(self): + props = self.__dict__.copy() + del props['zctx'] + del props['input_sock'] + del props['output_sock'] + del props['completion_queue'] + del props['service_queue'] + del props['service_apps_info_queue'] + del props['status_queue'] + del props['output_queue'] + del props['pending_queues'] + del props['read_task'] + del props['status_task'] + del props['watchdog_task'] + del props['_closed'] + return props + + def __setstate__(self, props): + global _zctx + self.__dict__.update(props) + if _zctx is None: + _zctx = zmq.asyncio.Context() + self.zctx = _zctx # share the global context + self.input_sock = self.zctx.socket(zmq.PUSH) + self.output_sock = self.zctx.socket(zmq.PULL) + self.completion_queue = asyncio.Queue(maxsize=128) + self.service_queue = asyncio.Queue(maxsize=128) + self.service_apps_info_queue = asyncio.Queue(maxsize=128) + 
self.status_queue = asyncio.Queue(maxsize=128) + self.output_queue = None + self.pending_queues = OrderedDict() + self.read_task = None + self.status_task = None + self.watchdog_task = None + self._closed = False + # __ainit__() is called by the caller. + + @abstractmethod + async def get_repl_in_addr(self) -> str: + raise NotImplementedError + + @abstractmethod + async def get_repl_out_addr(self) -> str: + raise NotImplementedError + + async def close(self) -> None: + if self._closed: + return + self._closed = True + try: + if self.watchdog_task and not self.watchdog_task.done(): + self.watchdog_task.cancel() + await self.watchdog_task + if self.status_task and not self.status_task.done(): + self.status_task.cancel() + await self.status_task + if self.read_task and not self.read_task.done(): + self.read_task.cancel() + await self.read_task + if self.input_sock: + self.input_sock.close() + if self.output_sock: + self.output_sock.close() + # WARNING: + # destroying zmq contexts here with possibility of re-entrance + # may cause deadlocks. + except Exception: + log.exception("AbstractCodeRunner.close(): unexpected error") + + async def ping_status(self): + """ + This is to keep the REPL in/out port mapping in the Linux + kernel's NAT table alive. + """ + try: + while True: + ret = await self.feed_and_get_status() + if ret is None: + break + await asyncio.sleep(10) + except asyncio.CancelledError: + pass + except Exception: + log.exception("AbstractCodeRunner.ping_status(): unexpected error") + + async def feed_batch(self, opts): + if self.input_sock.closed: + raise asyncio.CancelledError + clean_cmd = opts.get('clean', '') + if clean_cmd is None: + clean_cmd = '' + await self.input_sock.send_multipart([ + b'clean', + clean_cmd.encode('utf8'), + ]) + build_cmd = opts.get('build', '') + if build_cmd is None: + build_cmd = '' + await self.input_sock.send_multipart([ + b'build', + build_cmd.encode('utf8'), + ]) + exec_cmd = opts.get('exec', '') + if exec_cmd is None: + exec_cmd = '' + await self.input_sock.send_multipart([ + b'exec', + exec_cmd.encode('utf8'), + ]) + + async def feed_code(self, text: str): + if self.input_sock.closed: + raise asyncio.CancelledError + await self.input_sock.send_multipart([b'code', text.encode('utf8')]) + + async def feed_input(self, text: str): + if self.input_sock.closed: + raise asyncio.CancelledError + await self.input_sock.send_multipart([b'input', text.encode('utf8')]) + + async def feed_interrupt(self): + if self.input_sock.closed: + raise asyncio.CancelledError + await self.input_sock.send_multipart([b'interrupt', b'']) + + async def feed_and_get_status(self): + if self.input_sock.closed: + raise asyncio.CancelledError + await self.input_sock.send_multipart([b'status', b'']) + try: + result = await self.status_queue.get() + self.status_queue.task_done() + return msgpack.unpackb(result) + except asyncio.CancelledError: + return None + + async def feed_and_get_completion(self, code_text, opts): + if self.input_sock.closed: + raise asyncio.CancelledError + payload = { + 'code': code_text, + } + payload.update(opts) + await self.input_sock.send_multipart([ + b'complete', + json.dumps(payload).encode('utf8'), + ]) + try: + result = await self.completion_queue.get() + self.completion_queue.task_done() + return json.loads(result) + except asyncio.CancelledError: + return [] + + async def feed_start_service(self, service_info): + if self.input_sock.closed: + raise asyncio.CancelledError + await self.input_sock.send_multipart([ + b'start-service', + 
json.dumps(service_info).encode('utf8'), + ]) + try: + with timeout(10): + result = await self.service_queue.get() + self.service_queue.task_done() + return json.loads(result) + except asyncio.CancelledError: + return {'status': 'failed', 'error': 'cancelled'} + except asyncio.TimeoutError: + return {'status': 'failed', 'error': 'timeout'} + + async def feed_shutdown_service(self, service_name: str): + if self.input_sock.closed: + raise asyncio.CancelledError + await self.input_sock.send_multipart([ + b'shutdown-service', + json.dumps(service_name).encode('utf8'), + ]) + + async def feed_service_apps(self): + await self.input_sock.send_multipart([ + b'get-apps', + ''.encode('utf8'), + ]) + try: + with timeout(10): + result = await self.service_apps_info_queue.get() + self.service_apps_info_queue.task_done() + return json.loads(result) + except asyncio.CancelledError: + return {'status': 'failed', 'error': 'cancelled'} + except asyncio.TimeoutError: + return {'status': 'failed', 'error': 'timeout'} + + async def watchdog(self) -> None: + try: + await asyncio.sleep(self.exec_timeout) + if self.output_queue is not None: + # TODO: what to do if None? + await self.output_queue.put(ResultRecord('exec-timeout', None)) + except asyncio.CancelledError: + pass + + @staticmethod + def aggregate_console(result: NextResult, records: Sequence[ResultRecord], api_ver: int) -> None: + + if api_ver == 1: + + stdout_items = [] + stderr_items = [] + media_items = [] + html_items = [] + + for rec in records: + if rec.msg_type == 'stdout': + stdout_items.append(rec.data or '') + elif rec.msg_type == 'stderr': + stderr_items.append(rec.data or '') + elif rec.msg_type == 'media' and rec.data is not None: + o = json.loads(rec.data) + media_items.append((o['type'], o['data'])) + elif rec.msg_type == 'html': + html_items.append(rec.data) + + result['stdout'] = ''.join(stdout_items) + result['stderr'] = ''.join(stderr_items) + result['media'] = media_items + result['html'] = html_items + + elif api_ver >= 2: + + console_items: List[Tuple[ConsoleItemType, Union[str, Tuple[str, str]]]] = [] + last_stdout = io.StringIO() + last_stderr = io.StringIO() + + for rec in records: + + if last_stdout.tell() and rec.msg_type != 'stdout': + console_items.append(('stdout', last_stdout.getvalue())) + last_stdout.seek(0) + last_stdout.truncate(0) + if last_stderr.tell() and rec.msg_type != 'stderr': + console_items.append(('stderr', last_stderr.getvalue())) + last_stderr.seek(0) + last_stderr.truncate(0) + + if rec.msg_type == 'stdout': + last_stdout.write(rec.data or '') + elif rec.msg_type == 'stderr': + last_stderr.write(rec.data or '') + elif rec.msg_type == 'media' and rec.data is not None: + o = json.loads(rec.data) + console_items.append(('media', (o['type'], o['data']))) + elif rec.msg_type in outgoing_msg_types: + # FIXME: currently mypy cannot handle dynamic specialization of literals. 
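+                    # (Note: records reaching this branch are the remaining console item
+                    #  types such as 'html', 'log', and 'completion'; they are appended
+                    #  as-is without aggregation.)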
+ console_items.append((rec.msg_type, rec.data)) # type: ignore + + if last_stdout.tell(): + console_items.append(('stdout', last_stdout.getvalue())) + if last_stderr.tell(): + console_items.append(('stderr', last_stderr.getvalue())) + + result['console'] = console_items + last_stdout.close() + last_stderr.close() + + else: + raise AssertionError('Unrecognized API version') + + async def get_next_result(self, api_ver=2, flush_timeout=2.0) -> NextResult: + # Context: per API request + has_continuation = ClientFeatures.CONTINUATION in self.client_features + try: + records = [] + result: NextResult + assert self.output_queue is not None + with timeout(flush_timeout if has_continuation else None): + while True: + rec = await self.output_queue.get() + if rec.msg_type in outgoing_msg_types: + records.append(rec) + self.output_queue.task_done() + if rec.msg_type == 'finished': + data = json.loads(rec.data) if rec.data else {} + raise RunFinished(data) + elif rec.msg_type == 'clean-finished': + data = json.loads(rec.data) if rec.data else {} + raise CleanFinished(data) + elif rec.msg_type == 'build-finished': + data = json.loads(rec.data) if rec.data else {} + raise BuildFinished(data) + elif rec.msg_type == 'waiting-input': + opts = json.loads(rec.data) if rec.data else {} + raise InputRequestPending(opts) + elif rec.msg_type == 'exec-timeout': + raise ExecTimeout + except asyncio.CancelledError: + self.resume_output_queue() + raise + except asyncio.TimeoutError: + result = { + 'runId': self.current_run_id, + 'status': 'continued', + 'exitCode': None, + 'options': None, + } + type(self).aggregate_console(result, records, api_ver) + self.resume_output_queue() + return result + except CleanFinished as e: + result = { + 'runId': self.current_run_id, + 'status': 'clean-finished', + 'exitCode': e.data.get('exitCode'), + 'options': None, + } + type(self).aggregate_console(result, records, api_ver) + self.resume_output_queue() + return result + except BuildFinished as e: + result = { + 'runId': self.current_run_id, + 'status': 'build-finished', + 'exitCode': e.data.get('exitCode'), + 'options': None, + } + type(self).aggregate_console(result, records, api_ver) + self.resume_output_queue() + return result + except RunFinished as e: + result = { + 'runId': self.current_run_id, + 'status': 'finished', + 'exitCode': e.data.get('exitCode'), + 'options': None, + } + type(self).aggregate_console(result, records, api_ver) + self.next_output_queue() + return result + except ExecTimeout: + result = { + 'runId': self.current_run_id, + 'status': 'exec-timeout', + 'exitCode': None, + 'options': None, + } + log.warning('Execution timeout detected on kernel ' + f'{self.kernel_id}') + type(self).aggregate_console(result, records, api_ver) + self.next_output_queue() + return result + except InputRequestPending as e: + result = { + 'runId': self.current_run_id, + 'status': 'waiting-input', + 'exitCode': None, + 'options': e.data, + } + type(self).aggregate_console(result, records, api_ver) + self.resume_output_queue() + return result + except Exception: + log.exception('unexpected error') + raise + + async def attach_output_queue(self, run_id: Optional[str]) -> None: + # Context: per API request + if run_id is None: + run_id = secrets.token_hex(16) + assert run_id is not None + if run_id not in self.pending_queues: + q: asyncio.Queue[ResultRecord] = asyncio.Queue(maxsize=4096) + activated = asyncio.Event() + self.pending_queues[run_id] = (activated, q) + else: + activated, q = self.pending_queues[run_id] + if 
self.output_queue is None: + self.output_queue = q + else: + if self.current_run_id == run_id: + # No need to wait if we are continuing. + pass + else: + # If there is an outstanding ongoning execution, + # wait until it has "finished". + await activated.wait() + activated.clear() + self.current_run_id = run_id + assert self.output_queue is q + + def resume_output_queue(self) -> None: + """ + Use this to conclude get_next_result() when the execution should be + continued from the client. + + At that time, we need to reuse the current run ID and its output queue. + We don't change self.output_queue here so that we can continue to read + outputs while the client sends the continuation request. + """ + if self.current_run_id is None: + return + self.pending_queues.move_to_end(self.current_run_id, last=False) + + def next_output_queue(self) -> None: + """ + Use this to conclude get_next_result() when we have finished a "run". + """ + assert self.current_run_id is not None + self.pending_queues.pop(self.current_run_id, None) + self.current_run_id = None + if len(self.pending_queues) > 0: + # Make the next waiting API request handler to proceed. + _, (activated, q) = self.pending_queues.popitem(last=False) + self.output_queue = q + activated.set() + else: + # If there is no pending request, just ignore all outputs + # from the kernel. + self.output_queue = None + + async def read_output(self) -> None: + # We should use incremental decoder because some kernels may + # send us incomplete UTF-8 byte sequences (e.g., Julia). + decoders = ( + codecs.getincrementaldecoder('utf8')(errors='replace'), + codecs.getincrementaldecoder('utf8')(errors='replace'), + ) + while True: + try: + msg_type, msg_data = await self.output_sock.recv_multipart() + try: + if msg_type == b'status': + await self.status_queue.put(msg_data) + elif msg_type == b'completion': + await self.completion_queue.put(msg_data) + elif msg_type == b'service-result': + await self.service_queue.put(msg_data) + elif msg_type == b'apps-result': + await self.service_apps_info_queue.put(msg_data) + elif msg_type == b'stdout': + if self.output_queue is None: + continue + if len(msg_data) > self.max_record_size: + msg_data = msg_data[:self.max_record_size] + await self.output_queue.put( + ResultRecord( + 'stdout', + decoders[0].decode(msg_data), + )) + elif msg_type == b'stderr': + if self.output_queue is None: + continue + if len(msg_data) > self.max_record_size: + msg_data = msg_data[:self.max_record_size] + await self.output_queue.put( + ResultRecord( + 'stderr', + decoders[1].decode(msg_data), + )) + else: + # Normal outputs should go to the current + # output queue. + if self.output_queue is None: + continue + await self.output_queue.put( + ResultRecord( + msg_type.decode('ascii'), + msg_data.decode('utf8'), + )) + except asyncio.QueueFull: + pass + if msg_type == b'build-finished': + # finalize incremental decoder + decoders[0].decode(b'', True) + decoders[1].decode(b'', True) + elif msg_type == b'finished': + # finalize incremental decoder + decoders[0].decode(b'', True) + decoders[1].decode(b'', True) + self.finished_at = time.monotonic() + except (asyncio.CancelledError, GeneratorExit): + break + except Exception: + log.exception('unexpected error') + break + + +def match_distro_data(data: Mapping[str, Any], distro: str) -> Tuple[str, Any]: + """ + Find the latest or exactly matching entry from krunner_volumes mapping using the given distro + string expression. 
+ + It assumes that the keys of krunner_volumes mapping is a string concatenated with a distro + prefix (e.g., "centos", "ubuntu") and a distro version composed of multiple integer components + joined by single dots (e.g., "1.2.3", "18.04"). + """ + rx_ver_suffix = re.compile(r'(\d+(\.\d+)*)$') + m = rx_ver_suffix.search(distro) + if m is None: + # Assume latest + distro_prefix = distro + distro_ver = None + else: + distro_prefix = distro[:-len(m.group(1))] + distro_ver = m.group(1) + + # Check if there are static-build krunners first. + if distro_prefix == 'alpine': + libc_flavor = 'musl' + else: + libc_flavor = 'gnu' + distro_key = f'static-{libc_flavor}' + if volume := data.get(distro_key): + return distro_key, volume + + # Search through the per-distro versions + match_list = [ + (distro_key, value) + for distro_key, value in data.items() + if distro_key.startswith(distro_prefix) + ] + + def _extract_version(item: Tuple[str, Any]) -> Tuple[int, ...]: + m = rx_ver_suffix.search(item[0]) + if m is not None: + return tuple(map(int, m.group(1).split('.'))) + return (0,) + + match_list = sorted( + match_list, + key=_extract_version, + reverse=True) + if match_list: + if distro_ver is None: + return match_list[0] + for distro_key, value in match_list: + if distro_key == distro: + return (distro_key, value) + return match_list[0] # fallback to the latest of its kind + raise UnsupportedBaseDistroError(distro) diff --git a/src/ai/backend/agent/kubernetes/__init__.py b/src/ai/backend/agent/kubernetes/__init__.py new file mode 100644 index 0000000000..0b3cb2cb91 --- /dev/null +++ b/src/ai/backend/agent/kubernetes/__init__.py @@ -0,0 +1,8 @@ +from typing import Type + +from ..agent import AbstractAgent +from .agent import KubernetesAgent + + +def get_agent_cls() -> Type[AbstractAgent]: + return KubernetesAgent diff --git a/src/ai/backend/agent/kubernetes/agent.py b/src/ai/backend/agent/kubernetes/agent.py new file mode 100644 index 0000000000..29c73c4a8c --- /dev/null +++ b/src/ai/backend/agent/kubernetes/agent.py @@ -0,0 +1,1000 @@ +import functools +from io import StringIO +from pathlib import Path +import uuid + +import aiotools +from ai.backend.common.asyncio import current_loop +from ai.backend.common.docker import ImageRef +import asyncio +from decimal import Decimal +import hashlib +import logging +import pkg_resources +import random +import shutil +import signal +from typing import ( + Any, + FrozenSet, + List, + Literal, + Mapping, + MutableMapping, + Optional, + Sequence, + Tuple, + Union, +) +from ai.backend.common.logging_utils import BraceStyleAdapter + +import cattr +from kubernetes_asyncio import ( + client as kube_client, + config as kube_config, +) + +from .resources import detect_resources +from .kernel import KubernetesKernel +from .kube_object import ( + ConfigMap, HostPathPersistentVolume, KubernetesAbstractVolume, + KubernetesConfigMapVolume, KubernetesHostPathVolume, + KubernetesPVCVolume, KubernetesVolumeMount, NFSPersistentVolume, + PersistentVolumeClaim, Service, +) + +from ..agent import ACTIVE_STATUS_SET, AbstractAgent, AbstractKernelCreationContext, ComputerContext +from ..exception import K8sError, UnsupportedResource +from ..kernel import AbstractKernel, KernelFeatures +from ..resources import ( + AbstractComputePlugin, + KernelResourceSpec, + Mount, + known_slot_types, +) +from ..types import Container, ContainerStatus, Port + +from ai.backend.common.etcd import AsyncEtcd +from ai.backend.common.plugin.monitor import ErrorPluginContext, StatsPluginContext +from 
ai.backend.common.types import ( + AutoPullBehavior, + ClusterInfo, + ContainerId, + DeviceName, + ImageRegistry, + KernelCreationConfig, + KernelId, + MountPermission, + MountTypes, + ResourceSlot, + SlotName, + VFolderMount, + current_resource_slots, +) + +log = BraceStyleAdapter(logging.getLogger('ai.backend.agent.kubernetes.agent')) + + +class KubernetesKernelCreationContext(AbstractKernelCreationContext[KubernetesKernel]): + config_map_name: str + deployment_name: str + scratch_dir: Path + work_dir: Path + config_dir: Path + internal_mounts: List[Mount] = [] + static_pvc_name: str + workers: Mapping[str, Mapping[str, str]] + config_maps: List[ConfigMap] + + volume_mounts: List[KubernetesVolumeMount] + volumes: List[KubernetesAbstractVolume] + + def __init__( + self, + kernel_id: KernelId, + kernel_config: KernelCreationConfig, + local_config: Mapping[str, Any], + computers: MutableMapping[str, ComputerContext], + workers: Mapping[str, Mapping[str, str]], + static_pvc_name: str, + restarting: bool = False, + ) -> None: + super().__init__(kernel_id, kernel_config, local_config, computers, restarting=restarting) + scratch_dir = (self.local_config['container']['scratch-root'] / str(kernel_id)).resolve() + + self.scratch_dir = scratch_dir + self.work_dir = scratch_dir / 'work' + self.config_dir = scratch_dir / 'config' + self.static_pvc_name = static_pvc_name + self.workers = workers + + self.volume_mounts = [] + self.volumes = [ + KubernetesPVCVolume( + name=f'kernel-{self.kernel_id}-scratches', + persistentVolumeClaim={ + 'claimName': self.static_pvc_name, + }, + ), + ] + + self.config_maps = [] + + async def get_extra_envs(self) -> Mapping[str, str]: + return {} + + async def prepare_resource_spec(self) -> Tuple[KernelResourceSpec, Optional[Mapping[str, Any]]]: + loop = current_loop() + if self.restarting: + await kube_config.load_kube_config() + + def _kernel_resource_spec_read(): + with open((self.config_dir / 'resource.txt').resolve(), 'r') as f: + resource_spec = KernelResourceSpec.read_from_file(f) + return resource_spec + + resource_spec = await loop.run_in_executor(None, _kernel_resource_spec_read) + resource_opts = None + else: + slots = ResourceSlot.from_json(self.kernel_config['resource_slots']) + # Ensure that we have intrinsic slots. + assert SlotName('cpu') in slots + assert SlotName('mem') in slots + # accept unknown slot type with zero values + # but reject if they have non-zero values. 
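+                # (Illustrative example: slots like {'cpu': '2', 'mem': '4g',
+                #  'cuda.shares': '0'} pass this check even without a CUDA plugin loaded,
+                #  because the unknown slot value is zero; a non-zero 'cuda.shares' would
+                #  raise UnsupportedResource in the loop below.)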
+ for st, sv in slots.items(): + if st not in known_slot_types and sv != Decimal(0): + raise UnsupportedResource(st) + # sanitize the slots + current_resource_slots.set(known_slot_types) + slots = slots.normalize_slots(ignore_unknown=True) + resource_spec = KernelResourceSpec( + container_id='', + allocations={}, + slots={**slots}, # copy + mounts=[], + scratch_disk_size=0, # TODO: implement (#70) + ) + resource_opts = self.kernel_config.get('resource_opts', {}) + return resource_spec, resource_opts + + async def prepare_scratch(self) -> None: + loop = current_loop() + await kube_config.load_kube_config() + core_api = kube_client.CoreV1Api() + + # Unlike Docker, static files will be mounted directly to blank folder + # Check if NFS PVC for static files exists and bound + nfs_pvc = await core_api.list_persistent_volume_claim_for_all_namespaces( + label_selector='backend.ai/backend-ai-scratch-volume', + ) + if len(nfs_pvc.items) == 0: + raise K8sError('No PVC for backend.ai static files') + pvc = nfs_pvc.items[0] + if pvc.status.phase != 'Bound': + raise K8sError('PVC not Bound') + self.static_pvc_name = pvc.metadata.name + + def _create_scratch_dirs(): + self.work_dir.resolve().mkdir(parents=True, exist_ok=True) + self.config_dir.resolve().mkdir(parents=True, exist_ok=True) + + # Mount scratch directory as PV + # Config files can be mounted via ConfigMap + await loop.run_in_executor(None, _create_scratch_dirs) + + if not self.restarting: + # Since these files are bind-mounted inside a bind-mounted directory, + # we need to touch them first to avoid their "ghost" files are created + # as root in the host-side filesystem, which prevents deletion of scratch + # directories when the agent is running as non-root. + def _clone_dotfiles(): + jupyter_custom_css_path = Path(pkg_resources.resource_filename( + 'ai.backend.runner', 'jupyter-custom.css')) + logo_path = Path(pkg_resources.resource_filename( + 'ai.backend.runner', 'logo.svg')) + font_path = Path(pkg_resources.resource_filename( + 'ai.backend.runner', 'roboto.ttf')) + font_italic_path = Path(pkg_resources.resource_filename( + 'ai.backend.runner', 'roboto-italic.ttf')) + bashrc_path = Path(pkg_resources.resource_filename( + 'ai.backend.runner', '.bashrc')) + bash_profile_path = Path(pkg_resources.resource_filename( + 'ai.backend.runner', '.bash_profile')) + vimrc_path = Path(pkg_resources.resource_filename( + 'ai.backend.runner', '.vimrc')) + tmux_conf_path = Path(pkg_resources.resource_filename( + 'ai.backend.runner', '.tmux.conf')) + jupyter_custom_dir = (self.work_dir / '.jupyter' / 'custom') + jupyter_custom_dir.mkdir(parents=True, exist_ok=True) + shutil.copy(jupyter_custom_css_path.resolve(), jupyter_custom_dir / 'custom.css') + shutil.copy(logo_path.resolve(), jupyter_custom_dir / 'logo.svg') + shutil.copy(font_path.resolve(), jupyter_custom_dir / 'roboto.ttf') + shutil.copy(font_italic_path.resolve(), jupyter_custom_dir / 'roboto-italic.ttf') + shutil.copy(bashrc_path.resolve(), self.work_dir / '.bashrc') + shutil.copy(bash_profile_path.resolve(), self.work_dir / '.bash_profile') + shutil.copy(vimrc_path.resolve(), self.work_dir / '.vimrc') + shutil.copy(tmux_conf_path.resolve(), self.work_dir / '.tmux.conf') + + await loop.run_in_executor(None, _clone_dotfiles) + + async def get_intrinsic_mounts(self) -> Sequence[Mount]: + mounts: List[Mount] = [ + # Mount scratch directory + Mount( + MountTypes.K8S_GENERIC, + Path(str(self.kernel_id)), + Path('/home/work'), + MountPermission.READ_WRITE, + opts={ + 'name': 
f'kernel-{self.kernel_id}-scratches', + }, + ), + ] + + # TODO: Find way to mount extra volumes + + return mounts + + async def apply_network(self, cluster_info: ClusterInfo) -> None: + pass + + async def install_ssh_keypair(self, cluster_info: ClusterInfo) -> None: + sshkey = cluster_info['ssh_keypair'] + if sshkey is None: + return + + await kube_config.load_kube_config() + core_api = kube_client.CoreV1Api() + + # Get hash of public key + enc = hashlib.md5() + enc.update(sshkey['public_key'].encode('ascii')) + hash = enc.digest().decode('utf-8') + + try: + await core_api.read_namespaced_config_map(f'ssh-keypair-{hash}', 'backend-ai') + except: + # Keypair not stored on ConfigMap, create one + cm = ConfigMap('', f'kernel-{self.kernel_id}-ssh-keypair-{hash}') + cm.put('public', sshkey['public_key']) + cm.put('private', sshkey['private_key']) + + self.config_maps.append(cm) + + await self.process_volumes([ + KubernetesConfigMapVolume( + name=f'kernel-{self.kernel_id}-ssh-keypair', + configMap={ + 'name': 'ssh-keypair-hash', + }, + ), + ]) + await self.process_mounts([ + Mount( + MountTypes.K8S_GENERIC, + Path('public'), + Path('/home/config/ssh/id_cluster.pub'), + permission=MountPermission.READ_ONLY, + opts={ + 'name': f'kernel-{self.kernel_id}-ssh-keypair', + }, + ), + Mount( + MountTypes.K8S_GENERIC, + Path('private'), + Path('/home/config/ssh/id_cluster.pub'), + permission=MountPermission.READ_ONLY, + opts={ + 'name': f'kernel-{self.kernel_id}-ssh-keypair', + }, + ), + ]) + + async def process_mounts(self, mounts: Sequence[Mount]): + for i, mount in zip(range(len(mounts)), mounts): + if mount.type == MountTypes.K8S_GENERIC: + name = (mount.opts or {})['name'] + self.volume_mounts.append( + KubernetesVolumeMount( + name=name, + mountPath=mount.target.as_posix(), + subPath=mount.source.as_posix() if mount.source is not None else None, + readOnly=mount.permission == MountPermission.READ_ONLY, + ), + ) + elif mount.type == MountTypes.K8S_HOSTPATH: + name = (mount.opts or {})['name'] + self.volume_mounts.append( + KubernetesVolumeMount( + name=name, + mountPath=mount.target.as_posix(), + subPath=None, + readOnly=mount.permission == MountPermission.READ_ONLY, + ), + ) + else: + log.warn( + 'Mount {}:{} -> Mount type {} it not supported on K8s Agent. Skipping mount', + mount.source, + mount.target, + mount.type, + ) + + def resolve_krunner_filepath(self, filename: str) -> Path: + return Path(filename) + + def get_runner_mount( + self, + type: MountTypes, + src: Union[str, Path], + target: Union[str, Path], + perm: Literal['ro', 'rw'] = 'ro', + opts: Mapping[str, Any] = None, + ) -> Mount: + return Mount( + MountTypes.K8S_GENERIC, + Path(src), + Path(target), + MountPermission(perm), + opts={ + **(opts or {}), + 'name': f'kernel-{self.kernel_id}-scratches', + }, + ) + + async def process_volumes( + self, + volumes: List[KubernetesAbstractVolume], + ) -> None: + self.volumes += volumes + + async def mount_vfolders( + self, + vfolders: Sequence[VFolderMount], + resource_spec: KernelResourceSpec, + ) -> None: + # We can't mount vFolder backed by storage proxy + for idx, vfolder in enumerate(vfolders): + if self.internal_data.get('prevent_vfolder_mounts', False): + # Only allow mount of ".logs" directory to prevent expose + # internal-only information, such as Docker credentials to user's ".docker" vfolder + # in image importer kernels. 
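+                    # (Hypothetical example: with prevent_vfolder_mounts enabled, a user
+                    #  vfolder named "data" is skipped by the check below, while the
+                    #  ".logs" vfolder is still allowed through.)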
+ if vfolder.name != '.logs': + continue + mount = Mount( + MountTypes.K8S_HOSTPATH, + Path(vfolder.host_path), + Path(vfolder.kernel_path), + vfolder.mount_perm, + opts={ + 'name': f'kernel-{self.kernel_id}-hostPath-{idx}', + }, + ) + await self.process_volumes([ + KubernetesHostPathVolume( + name=f'kernel-{self.kernel_id}-hostPath-{idx}', + hostPath={ + 'path': vfolder.host_path.as_posix(), + 'type': 'Directory', + }, + ), + ]) + resource_spec.mounts.append(mount) + + async def apply_accelerator_allocation(self, computer, device_alloc) -> None: + # update_nested_dict( + # self.computer_docker_args, + # await computer.generate_docker_args(self.docker, device_alloc)) + # TODO: add support for accelerator allocation + pass + + async def generate_deployment_object( + self, image: str, environ: Mapping[str, Any], + ports: List[int], command: List[str], + labels: Mapping[str, Any] = {}, + ) -> dict: + return { + 'apiVersion': 'apps/v1', + 'kind': 'Deployment', + 'metadata': { + 'name': f'kernel-{self.kernel_id}', + 'labels': labels, + }, + 'spec': { + 'replicas': 0, + 'selector': {'matchLabels': {'run': f'kernel-{self.kernel_id}'}}, + 'template': { + 'metadata': {'labels': {'run': f'kernel-{self.kernel_id}'}}, + 'spec': { + 'containers': [{ + 'name': f'kernel-{self.kernel_id}-session', + 'image': image, + 'imagePullPolicy': 'IfNotPresent', + 'command': ['sh', '/opt/kernel/entrypoint.sh'], + 'args': command, + 'env': [{'name': k, 'value': v} for k, v in environ.items()], + 'volumeMounts': [cattr.unstructure(v) for v in self.volume_mounts], + 'ports': [{'containerPort': x} for x in ports], + }], + 'volumes': [cattr.unstructure(v) for v in self.volumes], + }, + }, + }, + } + + async def spawn( + self, + resource_spec: KernelResourceSpec, + environ: Mapping[str, str], + service_ports, + ) -> KubernetesKernel: + loop = current_loop() + + if self.restarting: + pass + else: + if bootstrap := self.kernel_config.get('bootstrap_script'): + + def _write_user_bootstrap_script(): + (self.work_dir / 'bootstrap.sh').write_text(bootstrap) + if KernelFeatures.UID_MATCH in self.kernel_features: # UID Match won't work on K8s + # uid = self.local_config['container']['kernel-uid'] + # gid = self.local_config['container']['kernel-gid'] + # if os.geteuid() == 0: + # os.chown(self.work_dir / 'bootstrap.sh', uid, gid) + pass + + await loop.run_in_executor(None, _write_user_bootstrap_script) + + def _write_config(file_name: str, content: str): + (self.config_dir / file_name).write_text(content) + + with StringIO() as buf: + for k, v in environ.items(): + buf.write(f'{k}={v}\n') + # accel_envs = self.computer_docker_args.get('Env', []) + # for env in accel_envs: + # buf.write(f'{env}\n') + await loop.run_in_executor( + None, functools.partial(_write_config, 'environ.txt', buf.getvalue()), + ) + + with StringIO() as buf: + resource_spec.write_to_file(buf) + for dev_type, device_alloc in resource_spec.allocations.items(): + computer_self = self.computers[dev_type] + kvpairs = \ + await computer_self.instance.generate_resource_data(device_alloc) + for k, v in kvpairs.items(): + buf.write(f'{k}={v}\n') + await loop.run_in_executor( + None, functools.partial(_write_config, 'resource.txt', buf.getvalue()), + ) + + docker_creds = self.internal_data.get('docker_credentials') + if docker_creds: + await loop.run_in_executor( + None, functools.partial(_write_config, 'docker-creds.json', docker_creds), + ) + + if keypair := self.internal_data.get('ssh_keypair'): + for mount in resource_spec.mounts: + container_path = 
str(mount).split(':')[1] + if container_path == '/home/work/.ssh': + break + else: + pubkey = keypair['public_key'].encode('ascii') + privkey = keypair['private_key'].encode('ascii') + ssh_config_map = ConfigMap(self.kernel_id, f'kernel-{self.kernel_id}-ssh-config') + ssh_config_map.put('authorized_keys', pubkey) + ssh_config_map.put('id_container', privkey) + await self.process_volumes([ + KubernetesConfigMapVolume( + name='ssh-config', + configMap={ + 'name': f'kernel-{self.kernel_id}-ssh-config', + }, + ), + ]) + await self.process_mounts([ + Mount( + MountTypes.K8S_GENERIC, + Path('authorized_keys'), + Path('/home/work/.ssh/authorized_keys'), + opts={ + 'name': 'ssh-config', + }, + ), + Mount( + MountTypes.K8S_GENERIC, + Path('id_container'), + Path('/home/work/.ssh/id_container'), + opts={ + 'name': 'ssh-config', + }, + ), + ]) + + # higher priority dotfiles are stored last to support overwriting + for dotfile in self.internal_data.get('dotfiles', []): + if dotfile['path'].startswith('/'): + if dotfile['path'].startswith('/home/'): + path_arr = dotfile['path'].split('/') + file_path: Path = self.scratch_dir / '/'.join(path_arr[2:]) + else: + file_path = Path(dotfile['path']) + else: + file_path = self.work_dir / dotfile['path'] + file_path.parent.mkdir(parents=True, exist_ok=True) + await loop.run_in_executor( + None, + file_path.write_text, + dotfile['data']) + + # TODO: Mark shmem feature as unsupported when advertising agent + + kernel_obj = KubernetesKernel( + self.kernel_id, + self.image_ref, + self.kspec_version, + agent_config=self.local_config, + service_ports=service_ports, + resource_spec=resource_spec, + environ=environ, + data={}, + ) + return kernel_obj + + async def start_container( + self, + kernel_obj: AbstractKernel, + cmdargs: List[str], + resource_opts, + preopen_ports, + ) -> Mapping[str, Any]: + image_labels = self.kernel_config['image']['labels'] + service_ports = kernel_obj.service_ports + environ = kernel_obj.environ + + await kube_config.load_kube_config() + core_api = kube_client.CoreV1Api() + apps_api = kube_client.AppsV1Api() + exposed_ports = [2000, 2001] + for sport in service_ports: + exposed_ports.extend(sport['container_ports']) + + encoded_preopen_ports = ','.join(f'{port_no}:preopen:{port_no}' for port_no in preopen_ports) + + service_port_label = image_labels['ai.backend.service-ports'] + if len(encoded_preopen_ports) > 0: + service_port_label += f',{encoded_preopen_ports}' + + deployment = await self.generate_deployment_object( + self.image_ref.canonical, + environ, + exposed_ports, + cmdargs, + labels={ + 'ai.backend.service-ports': service_port_label.replace(':', '-').replace(',', '.'), + }, + ) + + if self.local_config['debug']['log-kernel-config']: + log.debug('Initial container config: {0}', deployment) + + expose_service = Service( + str(self.kernel_id), + f'kernel-{self.kernel_id}-expose', + [ + (port, f'kernel-{self.kernel_id}-svc-{index}') + for index, port in zip(range(len(exposed_ports)), exposed_ports) + ], + ) + + async def rollup( + functions: List[Tuple[Optional[functools.partial], + Optional[functools.partial]]], + ): + rollback_functions: List[Optional[functools.partial]] = [] + + for (rollup_function, future_rollback_function) in functions: + try: + if rollup_function: + await rollup_function() + rollback_functions.append(future_rollback_function) + except Exception as e: + for rollback_function in rollback_functions[::-1]: + if rollback_function: + await rollback_function() + log.exception('Error while rollup: {}', e) + raise + + 
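The `rollup()` helper above consumes a list of `(action, rollback)` pairs: each action is awaited in order and its rollback partner is queued, so that a failure part-way through undoes the already-applied steps in reverse order. A minimal standalone sketch of the same pattern follows; the `create_*`/`delete_*` coroutines are hypothetical placeholders, not real Kubernetes API calls.

```python
import asyncio
import functools
from typing import List, Optional, Tuple

async def create_service() -> None: ...      # placeholder "apply" step
async def delete_service() -> None: ...      # its compensating "undo" step
async def create_deployment() -> None: ...
async def delete_deployment() -> None: ...

async def rollup(
    steps: List[Tuple[Optional[functools.partial], Optional[functools.partial]]],
) -> None:
    applied: List[Optional[functools.partial]] = []
    for action, rollback in steps:
        try:
            if action:
                await action()
            # Only remember the rollback once the action has succeeded.
            applied.append(rollback)
        except Exception:
            # Undo the successfully applied steps in reverse order, then re-raise.
            for undo in reversed(applied):
                if undo:
                    await undo()
            raise

asyncio.run(rollup([
    (functools.partial(create_service), functools.partial(delete_service)),
    (functools.partial(create_deployment), functools.partial(delete_deployment)),
]))
```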
arguments: List[Tuple[Optional[functools.partial], Optional[functools.partial]]] = [] + node_ports = [] + + try: + expose_service_api_response = \ + await core_api.create_namespaced_service('backend-ai', body=expose_service.to_dict()) + except Exception as e: + log.exception('Error while rollup: {}', e) + raise + + node_ports = expose_service_api_response.spec.ports + arguments.append(( + None, + functools.partial(core_api.delete_namespaced_service, expose_service.name, 'backend-ai'), + )) + for cm in self.config_maps: + arguments.append(( + functools.partial( + core_api.create_namespaced_config_map, 'backend-ai', + body=cm.to_dict(), + ), + functools.partial(core_api.delete_namespaced_config_map, cm.name, 'backend-ai'), + )) + + arguments.append(( + functools.partial( + apps_api.create_namespaced_deployment, 'backend-ai', + body=deployment, pretty='pretty-example', + ), + None, + )) + + await rollup(arguments) + + assigned_ports: MutableMapping[str, int] = {} + for port in node_ports: + assigned_ports[port.port] = port.node_port + + ctnr_host_port_map: MutableMapping[int, int] = {} + stdin_port = 0 + stdout_port = 0 + for idx, port in enumerate(exposed_ports): + host_port = assigned_ports[port] + + if port == 2000: # intrinsic + repl_in_port = host_port + elif port == 2001: # intrinsic + repl_out_port = host_port + elif port == 2002: # legacy + stdin_port = host_port + elif port == 2003: # legacy + stdout_port = host_port + else: + ctnr_host_port_map[port] = host_port + for sport in service_ports: + sport['host_ports'] = tuple( + ctnr_host_port_map[cport] for cport in sport['container_ports'] + ) + + target_node_ip = random.choice([x['InternalIP'] for x in self.workers.values()]) + + return { + 'container_id': '', + 'kernel_host': target_node_ip, + 'repl_in_port': repl_in_port, + 'repl_out_port': repl_out_port, + 'stdin_port': stdin_port, # legacy + 'stdout_port': stdout_port, # legacy + 'assigned_ports': assigned_ports, + # 'domain_socket_proxies': self.domain_socket_proxies, + 'block_service_ports': self.internal_data.get('block_service_ports', False), + } + + +class KubernetesAgent( + AbstractAgent[KubernetesKernel, KubernetesKernelCreationContext], +): + workers: MutableMapping[str, MutableMapping[str, str]] = {} + k8s_ptask_group: aiotools.PersistentTaskGroup + + def __init__( + self, + etcd: AsyncEtcd, + local_config: Mapping[str, Any], + *, + stats_monitor: StatsPluginContext, + error_monitor: ErrorPluginContext, + skip_initial_scan: bool = False, + ) -> None: + super().__init__( + etcd, + local_config, + stats_monitor=stats_monitor, + error_monitor=error_monitor, + skip_initial_scan=skip_initial_scan, + ) + + async def __ainit__(self) -> None: + await super().__ainit__() + ipc_base_path = self.local_config['agent']['ipc-base-path'] + self.agent_sockpath = ipc_base_path / 'container' / f'agent.{self.local_instance_id}.sock' + + await self.check_krunner_pv_status() + await self.fetch_workers() + self.k8s_ptask_group = aiotools.PersistentTaskGroup() + # Socket Relay initialization + # Agent socket handler initialization + # K8s event monitor task initialization + + async def check_krunner_pv_status(self): + capacity = format(self.local_config['container']['scratch-size'], 'g')[:-1] + + await kube_config.load_kube_config() + core_api = kube_client.CoreV1Api() + + namespaces = await core_api.list_namespace() + if len(list(filter(lambda ns: ns.metadata.name == 'backend-ai', namespaces.items))) == 0: + await core_api.create_namespace({ + 'apiVersion': 'v1', + 'kind': 'Namespace', + 
'metadata': { + 'name': 'backend-ai', + }, + }) + + pv = await core_api.list_persistent_volume(label_selector='backend.ai/backend-ai-scratch-volume') + + if len(pv.items) == 0: + # PV does not exists; create one + if self.local_config['container']['scratch-type'] == 'k8s-nfs': + new_pv = NFSPersistentVolume( + self.local_config['container']['scratch-nfs-address'], + 'backend-ai-static-pv', capacity, + ) + new_pv.label( + 'backend.ai/backend-ai-scratch-volume', + self.local_config['container']['scratch-nfs-address'], + ) + new_pv.options = [ + x.strip() + for x in self.local_config['container']['scratch-nfs-options'].split(',') + ] + elif self.local_config['container']['scratch-type'] == 'hostdir': + new_pv = HostPathPersistentVolume( + self.local_config['container']['scratch-root'].as_posix(), + 'backend-ai-static-pv', capacity, + ) + new_pv.label('backend.ai/backend-ai-scratch-volume', 'hostPath') + else: + raise NotImplementedError( + f'Scratch type {self.local_config["container"]["scratch-type"]} is not supported', + ) + + try: + await core_api.create_persistent_volume(body=new_pv.to_dict()) + except: + raise + + pvc = await core_api.list_persistent_volume_claim_for_all_namespaces( + label_selector='backend.ai/backend-ai-scratch-volume', + ) + + if len(pvc.items) == 0: + # PV does not exists; create one + new_pvc = PersistentVolumeClaim( + 'backend-ai-static-pvc', 'backend-ai-static-pv', capacity, + ) + if self.local_config['container']['scratch-type'] == 'k8s-nfs': + new_pvc.label( + 'backend.ai/backend-ai-scratch-volume', + self.local_config['container']['scratch-nfs-address'], + ) + else: + new_pvc.label('backend.ai/backend-ai-scratch-volume', 'hostPath') + try: + await core_api.create_namespaced_persistent_volume_claim( + 'backend-ai', body=new_pvc.to_dict(), + ) + except Exception as e: + log.exception('Error: {}', e) + raise + + async def fetch_workers(self): + await kube_config.load_kube_config() + core_api = kube_client.CoreV1Api() + nodes = await core_api.list_node() + for node in nodes.items: + # if 'node-role.kubernetes.io/master' in node.metadata.labels.keys(): + # continue + self.workers[node.metadata.name] = node.status.capacity + for addr in node.status.addresses: + if addr.type == 'InternalIP': + self.workers[node.metadata.name]['InternalIP'] = addr.address + if addr.type == 'ExternalIP': + self.workers[node.metadata.name]['ExternalIP'] = addr.address + + async def shutdown(self, stop_signal: signal.Signals): + # Stop agent socket handler task + + try: + if self.k8s_ptask_group is not None: + await self.k8s_ptask_group.shutdown() + await super().shutdown(stop_signal) + finally: + # Stop k8s event monitoring. + pass + + async def detect_resources(self) -> Tuple[ + Mapping[DeviceName, AbstractComputePlugin], + Mapping[SlotName, Decimal], + ]: + return await detect_resources(self.etcd, self.local_config) + + async def enumerate_containers( + self, + status_filter: FrozenSet[ContainerStatus] = ACTIVE_STATUS_SET, + ) -> Sequence[Tuple[KernelId, Container]]: + await kube_config.load_kube_config() + core_api = kube_client.CoreV1Api() + + result = [] + fetch_tasks = [] + for deployment in (await core_api.list_namespaced_pod('backend-ai')).items: + # Additional check to filter out real worker pods only? + + async def _fetch_container_info(pod: Any): + kernel_id: Union[str, None] = "(unknown)" + try: + kernel_id = await get_kernel_id_from_deployment(pod) + if kernel_id is None: + return + # Is it okay to assume that only one container resides per pod? 
+ if pod['status']['containerStatuses'][0]['stats'].keys()[0] in status_filter: + result.append( + ( + KernelId(uuid.UUID(kernel_id)), + await container_from_pod(pod), + ), + ) + except asyncio.CancelledError: + pass + except Exception: + log.exception( + "error while fetching container information (cid:{}, k:{})", + pod['metadata']['uid'], kernel_id, + ) + + fetch_tasks.append(_fetch_container_info(deployment)) + + await asyncio.gather(*fetch_tasks, return_exceptions=True) + return result + + async def scan_images(self) -> Mapping[str, str]: + # Retrieving image label from registry api is not possible + return {} + + async def handle_agent_socket(self): + # TODO: Add support for remote agent socket mechanism + pass + + async def pull_image(self, image_ref: ImageRef, registry_conf: ImageRegistry) -> None: + # TODO: Add support for appropriate image pulling mechanism on K8s + pass + + async def check_image(self, image_ref: ImageRef, image_id: str, auto_pull: AutoPullBehavior) -> bool: + # TODO: Add support for appropriate image checking mechanism on K8s + # Just mark all images as 'pulled' since we can't manually initiate image pull on each kube node + return True + + async def init_kernel_context( + self, + kernel_id: KernelId, + kernel_config: KernelCreationConfig, + *, + restarting: bool = False, + ) -> KubernetesKernelCreationContext: + return KubernetesKernelCreationContext( + kernel_id, + kernel_config, + self.local_config, + self.computers, + self.workers, + 'backend-ai-static-pvc', + restarting=restarting, + ) + + async def destroy_kernel(self, kernel_id: KernelId, container_id: Optional[ContainerId]) -> None: + await kube_config.load_kube_config() + core_api = kube_client.CoreV1Api() + apps_api = kube_client.AppsV1Api() + + async def force_cleanup(reason='self-terminated'): + await self.send_event('kernel_terminated', + kernel_id, 'self-terminated', + None) + try: + kernel = self.kernel_registry[kernel_id] + except: + log.warning('_destroy_kernel({0}) kernel missing (already dead?)', + kernel_id) + await asyncio.shield(self.k8s_ptask_group.create_task(force_cleanup())) + return None + deployment_name = kernel['deployment_name'] + try: + await core_api.delete_namespaced_service(f'{deployment_name}-service', 'backend-ai') + await core_api.delete_namespaced_service(f'{deployment_name}-nodeport', 'backend-ai') + await apps_api.delete_namespaced_deployment(f'{deployment_name}', 'backend-ai') + except: + log.warning('_destroy({0}) kernel missing (already dead?)', kernel_id) + + async def clean_kernel( + self, + kernel_id: KernelId, + container_id: Optional[ContainerId], + restarting: bool, + ) -> None: + loop = current_loop() + if not restarting: + scratch_dir = self.local_config['container']['scratch-root'] / str(kernel_id) + await loop.run_in_executor(None, shutil.rmtree, str(scratch_dir)) + + async def create_overlay_network(self, network_name: str) -> None: + return await super().create_overlay_network(network_name) + + async def destroy_overlay_network(self, network_name: str) -> None: + return await super().destroy_overlay_network(network_name) + + async def create_local_network(self, network_name: str) -> None: + return await super().create_local_network(network_name) + + async def destroy_local_network(self, network_name: str) -> None: + return await super().destroy_local_network(network_name) + + async def restart_kernel__load_config( + self, + kernel_id: KernelId, + name: str, + ) -> bytes: + loop = current_loop() + scratch_dir = (self.local_config['container']['scratch-root'] 
/ str(kernel_id)).resolve() + config_dir = scratch_dir / 'config' + return await loop.run_in_executor( + None, + (config_dir / name).read_bytes, + ) + + async def restart_kernel__store_config( + self, + kernel_id: KernelId, + name: str, + data: bytes, + ) -> None: + loop = current_loop() + scratch_dir = (self.local_config['container']['scratch-root'] / str(kernel_id)).resolve() + config_dir = scratch_dir / 'config' + return await loop.run_in_executor( + None, + (config_dir / name).write_bytes, + data, + ) + + +async def get_kernel_id_from_deployment(pod: Any) -> Optional[str]: + # TODO: create function which extracts kernel id from pod object + return pod.get('metadata', {}).get('name') + + +async def container_from_pod(pod: Any) -> Container: + status: ContainerStatus = ContainerStatus.RUNNING + phase = pod['status']['phase'] + if phase == 'Pending' or phase == 'Running': + status = ContainerStatus.RUNNING + elif phase == 'Succeeded': + status = ContainerStatus.EXITED + elif phase == 'Failed' or phase == 'Unknown': + status = ContainerStatus.DEAD + + # TODO: Create Container object from K8s Pod definition + return Container( + id=ContainerId(''), + status=status, + image=pod['spec']['containers'][0]['image'], + labels=pod['metadata']['labels'], + ports=[Port( + port['host_ip'], port['container_port'], port['host_port'], + ) for port in pod['spec']['containers'][0]['ports']], + backend_obj=pod, + ) diff --git a/src/ai/backend/agent/kubernetes/backendai-socket-relay.img.tar.gz b/src/ai/backend/agent/kubernetes/backendai-socket-relay.img.tar.gz new file mode 100644 index 0000000000..f86102fe7b --- /dev/null +++ b/src/ai/backend/agent/kubernetes/backendai-socket-relay.img.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:632296d4e118f2c09abad66bf51ed5e9383cc1fb35a6eb18f5d04f73c3ad437c +size 3254795 diff --git a/src/ai/backend/agent/kubernetes/backendai-socket-relay.version.txt b/src/ai/backend/agent/kubernetes/backendai-socket-relay.version.txt new file mode 100644 index 0000000000..d00491fd7e --- /dev/null +++ b/src/ai/backend/agent/kubernetes/backendai-socket-relay.version.txt @@ -0,0 +1 @@ +1 diff --git a/src/ai/backend/agent/kubernetes/files.py b/src/ai/backend/agent/kubernetes/files.py new file mode 100644 index 0000000000..78659aa534 --- /dev/null +++ b/src/ai/backend/agent/kubernetes/files.py @@ -0,0 +1,64 @@ +import logging +import os +from pathlib import Path +from typing import Dict + +from ai.backend.common.logging import BraceStyleAdapter + +log = BraceStyleAdapter(logging.getLogger(__name__)) + +# the names of following AWS variables follow boto3 convention. +s3_access_key = os.environ.get('AWS_ACCESS_KEY_ID', 'dummy-access-key') +s3_secret_key = os.environ.get('AWS_SECRET_ACCESS_KEY', 'dummy-secret-key') +s3_region = os.environ.get('AWS_REGION', 'ap-northeast-1') +s3_bucket = os.environ.get('AWS_S3_BUCKET', 'codeonweb') +s3_bucket_path = os.environ.get('AWS_S3_BUCKET_PATH', 'bucket') + +if s3_access_key == 'dummy-access-key': + log.info('Automatic ~/.output file S3 uploads is disabled.') + + +def relpath(path, base): + return Path(path).resolve().relative_to(Path(base).resolve()) + + +def scandir(root: Path, allowed_max_size: int): + ''' + Scans a directory recursively and returns a dictionary of all files and + their last modified time. + ''' + file_stats: Dict[Path, float] = dict() + if not isinstance(root, Path): + root = Path(root) + if not root.exists(): + return file_stats + for entry in os.scandir(root): + # Skip hidden files. 
+ if entry.name.startswith('.'): + continue + if entry.is_file(): + try: + stat = entry.stat() + except PermissionError: + continue + # Skip too large files! + if stat.st_size > allowed_max_size: + continue + file_stats[Path(entry.path)] = stat.st_mtime + elif entry.is_dir(): + try: + file_stats.update(scandir(Path(entry.path), allowed_max_size)) + except PermissionError: + pass + return file_stats + + +def diff_file_stats(fs1, fs2): + k2 = set(fs2.keys()) + k1 = set(fs1.keys()) + new_files = k2 - k1 + modified_files = set() + for k in (k2 - new_files): + if fs1[k] < fs2[k]: + modified_files.add(k) + return new_files | modified_files diff --git a/src/ai/backend/agent/kubernetes/intrinsic.py b/src/ai/backend/agent/kubernetes/intrinsic.py new file mode 100644 index 0000000000..1629727c47 --- /dev/null +++ b/src/ai/backend/agent/kubernetes/intrinsic.py @@ -0,0 +1,336 @@ +from decimal import Decimal +import logging +import os +from pathlib import Path +import platform +from typing import ( + Any, + Collection, + Dict, + List, + Mapping, + Optional, + Sequence, +) + +import aiohttp +from aiodocker.docker import Docker, DockerContainer +from aiodocker.exceptions import DockerError +from kubernetes_asyncio import client as K8sClient, config as K8sConfig + +from ai.backend.common.logging import BraceStyleAdapter +from ai.backend.common.types import ( + DeviceName, DeviceId, + DeviceModelInfo, + SlotName, SlotTypes, +) +from .agent import Container +from .resources import ( + get_resource_spec_from_container, +) +from .. import __version__ +from ..resources import ( + AbstractAllocMap, DeviceSlotInfo, + DiscretePropertyAllocMap, + AbstractComputeDevice, + AbstractComputePlugin, +) +from ..stats import ( + StatContext, NodeMeasurement, ContainerMeasurement, +) + +log = BraceStyleAdapter(logging.getLogger(__name__)) + + +async def fetch_api_stats(container: DockerContainer) -> Optional[Dict[str, Any]]: + short_cid = container._id[:7] + try: + ret = await container.stats(stream=False) # TODO: cache + except RuntimeError as e: + msg = str(e.args[0]).lower() + if 'event loop is closed' in msg or 'session is closed' in msg: + return None + raise + except (DockerError, aiohttp.ClientError) as e: + log.error( + 'cannot read stats (cid:{}): client error: {!r}.', + short_cid, e, + ) + return None + else: + # aiodocker 0.16 or later returns a list of dict, even when not streaming. + if isinstance(ret, list): + if not ret: + # The API may return an empty result upon container termination. + return None + ret = ret[0] + # The API may return an invalid or empty result upon container termination. + if ret is None or not isinstance(ret, dict): + log.warning( + 'cannot read stats (cid:{}): got an empty result: {}', + short_cid, ret, + ) + return None + if ( + ret['read'].startswith('0001-01-01') or + ret['preread'].startswith('0001-01-01') + ): + return None + return ret + + +# Pseudo-plugins for intrinsic devices (CPU and the main memory) + +class CPUDevice(AbstractComputeDevice): + pass + + +class CPUPlugin(AbstractComputePlugin): + """ + Represents the CPU. 
+ """ + + config_watch_enabled = False + + key = DeviceName('cpu') + slot_types = [ + (SlotName('cpu'), SlotTypes.COUNT), + ] + + async def init(self, context: Any = None) -> None: + pass + + async def cleanup(self) -> None: + pass + + async def update_plugin_config(self, new_plugin_config: Mapping[str, Any]) -> None: + pass + + async def list_devices(self) -> Collection[CPUDevice]: + await K8sConfig.load_kube_config() + core_api = K8sClient.CoreV1Api() + + nodes = (await core_api.list_node()).to_dict()['items'] + overcommit_factor = int(os.environ.get('BACKEND_CPU_OVERCOMMIT_FACTOR', '1')) + assert 1 <= overcommit_factor <= 10 + + return [ + CPUDevice( + device_id=DeviceId(node['metadata']['uid']), + hw_location='root', + numa_node=None, + memory_size=0, + processing_units=int(node['status']['capacity']['cpu']) * overcommit_factor, + ) + for i, node in zip(range(len(nodes)), nodes) + # if 'node-role.kubernetes.io/master' not in node['metadata']['labels'].keys() + ] + + async def available_slots(self) -> Mapping[SlotName, Decimal]: + devices = await self.list_devices() + log.debug('available_slots: {}', devices) + return { + SlotName('cpu'): Decimal(sum(dev.processing_units for dev in devices)), + } + + def get_version(self) -> str: + return __version__ + + async def extra_info(self) -> Mapping[str, str]: + return { + 'agent_version': __version__, + 'machine': platform.machine(), + 'os_type': platform.system(), + } + + async def gather_node_measures(self, ctx: StatContext) -> Sequence[NodeMeasurement]: + # TODO: Create our own k8s metric collector + + return [] + + async def gather_container_measures( + self, + ctx: StatContext, + container_ids: Sequence[str], + ) -> Sequence[ContainerMeasurement]: + # TODO: Implement Kubernetes-specific container metric collection + + return [ + ] + + async def create_alloc_map(self) -> AbstractAllocMap: + devices = await self.list_devices() + return DiscretePropertyAllocMap( + device_slots={ + dev.device_id: + DeviceSlotInfo(SlotTypes.COUNT, SlotName('cpu'), Decimal(dev.processing_units)) + for dev in devices + }, + ) + + async def get_hooks(self, distro: str, arch: str) -> Sequence[Path]: + # TODO: move the sysconf hook in libbaihook.so here + return [] + + async def generate_docker_args( + self, + docker: Docker, + device_alloc, + ) -> Mapping[str, Any]: + # This function might be needed later to apply fine-grained tuning for + # K8s resource allocation + return {} + + async def restore_from_container( + self, + container: Container, + alloc_map: AbstractAllocMap, + ) -> None: + assert isinstance(alloc_map, DiscretePropertyAllocMap) + # Docker does not return the original cpuset.... :( + # We need to read our own records. 
+ resource_spec = await get_resource_spec_from_container(container.backend_obj) + if resource_spec is None: + return + alloc_map.apply_allocation({ + SlotName('cpu'): + resource_spec.allocations[DeviceName('cpu')][SlotName('cpu')], + }) + + async def get_attached_devices( + self, + device_alloc: Mapping[SlotName, + Mapping[DeviceId, Decimal]], + ) -> Sequence[DeviceModelInfo]: + device_ids = [*device_alloc[SlotName('cpu')].keys()] + available_devices = await self.list_devices() + attached_devices: List[DeviceModelInfo] = [] + for device in available_devices: + if device.device_id in device_ids: + attached_devices.append({ + 'device_id': device.device_id, + 'model_name': '', + 'data': {'cores': len(device_ids)}, + }) + return attached_devices + + +class MemoryDevice(AbstractComputeDevice): + pass + + +class MemoryPlugin(AbstractComputePlugin): + """ + Represents the main memory. + + When collecting statistics, it also measures network and I/O usage + in addition to the memory usage. + """ + + config_watch_enabled = False + + key = DeviceName('mem') + slot_types = [ + (SlotName('mem'), SlotTypes.BYTES), + ] + + async def init(self, context: Any = None) -> None: + pass + + async def cleanup(self) -> None: + pass + + async def update_plugin_config(self, new_plugin_config: Mapping[str, Any]) -> None: + pass + + async def list_devices(self) -> Collection[MemoryDevice]: + await K8sConfig.load_kube_config() + core_api = K8sClient.CoreV1Api() + + nodes = (await core_api.list_node()).to_dict()['items'] + overcommit_factor = int(os.environ.get('BACKEND_MEM_OVERCOMMIT_FACTOR', '1')) + assert 1 <= overcommit_factor <= 10 + mem = 0 + for node in nodes: + # if 'node-role.kubernetes.io/master' in node['metadata']['labels'].keys(): + # continue + mem += int(node['status']['capacity']['memory'][:-2]) * 1024 + return [ + MemoryDevice( + device_id=DeviceId('root'), + hw_location='root', + numa_node=0, + memory_size=mem * overcommit_factor, + processing_units=0, + ), + ] + + async def available_slots(self) -> Mapping[SlotName, Decimal]: + devices = await self.list_devices() + return { + SlotName('mem'): Decimal(sum(dev.memory_size for dev in devices)), + } + + def get_version(self) -> str: + return __version__ + + async def extra_info(self) -> Mapping[str, str]: + return {} + + async def gather_node_measures(self, ctx: StatContext) -> Sequence[NodeMeasurement]: + # TODO: Create our own k8s metric collector + return [] + + async def gather_container_measures(self, ctx: StatContext, container_ids: Sequence[str]) \ + -> Sequence[ContainerMeasurement]: + # TODO: Implement Kubernetes-specific container metric collection + return [] + + async def create_alloc_map(self) -> AbstractAllocMap: + devices = await self.list_devices() + return DiscretePropertyAllocMap( + device_slots={ + dev.device_id: + DeviceSlotInfo(SlotTypes.BYTES, SlotName('mem'), Decimal(dev.memory_size)) + for dev in devices + }, + ) + + async def get_hooks(self, distro: str, arch: str) -> Sequence[Path]: + return [] + + async def generate_docker_args( + self, + docker: Docker, + device_alloc, + ) -> Mapping[str, Any]: + # This function might be needed later to apply fine-grained tuning for + # K8s resource allocation + return {} + + async def restore_from_container( + self, + container: Container, + alloc_map: AbstractAllocMap, + ) -> None: + assert isinstance(alloc_map, DiscretePropertyAllocMap) + memory_limit = container.backend_obj['HostConfig']['Memory'] + alloc_map.apply_allocation({ + SlotName('mem'): {DeviceId('root'): memory_limit}, + }) + + 
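`MemoryPlugin.list_devices()` above assumes the node capacity is always reported with a `Ki` suffix (`[:-2]` strips it and `* 1024` converts to bytes). Kubernetes can report quantities with other binary suffixes, so a more defensive conversion could look like the sketch below; the helper name and suffix table are assumptions for illustration, not part of the patch.

```python
# Sketch: convert a Kubernetes memory quantity string (e.g. "16318028Ki",
# "64Gi", or a plain byte count) into an integer number of bytes.
_BINARY_SUFFIXES = {'Ki': 2 ** 10, 'Mi': 2 ** 20, 'Gi': 2 ** 30, 'Ti': 2 ** 40}

def parse_k8s_memory_quantity(quantity: str) -> int:
    for suffix, factor in _BINARY_SUFFIXES.items():
        if quantity.endswith(suffix):
            return int(quantity[:-len(suffix)]) * factor
    return int(quantity)  # a bare integer already means bytes

assert parse_k8s_memory_quantity('16318028Ki') == 16318028 * 1024
assert parse_k8s_memory_quantity('64Gi') == 64 * 2 ** 30
```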
async def get_attached_devices( + self, + device_alloc: Mapping[SlotName, Mapping[DeviceId, Decimal]], + ) -> Sequence[DeviceModelInfo]: + device_ids = [*device_alloc[SlotName('mem')].keys()] + available_devices = await self.list_devices() + attached_devices: List[DeviceModelInfo] = [] + for device in available_devices: + if device.device_id in device_ids: + attached_devices.append({ + 'device_id': device.device_id, + 'model_name': '', + 'data': {}, + }) + return attached_devices diff --git a/src/ai/backend/agent/kubernetes/kernel.py b/src/ai/backend/agent/kubernetes/kernel.py new file mode 100644 index 0000000000..ca8c545a73 --- /dev/null +++ b/src/ai/backend/agent/kubernetes/kernel.py @@ -0,0 +1,448 @@ +import asyncio +import logging +import lzma +from pathlib import Path +import shutil +from ai.backend.agent.utils import get_arch_name +import pkg_resources +import re +import textwrap +from typing import ( + Any, Optional, + Mapping, Dict, + FrozenSet, + Sequence, Tuple, +) + +from kubernetes_asyncio import client as kube_client, config as kube_config, watch +from aiodocker.docker import Docker +from aiotools import TaskGroup + +from ai.backend.common.docker import ImageRef +from ai.backend.common.logging import BraceStyleAdapter +from ai.backend.common.types import KernelId +from ai.backend.common.utils import current_loop +import zmq +from ..resources import KernelResourceSpec +from ..kernel import AbstractKernel, AbstractCodeRunner + +log = BraceStyleAdapter(logging.getLogger(__name__)) + + +class KubernetesKernel(AbstractKernel): + + deployment_name: str + + def __init__( + self, kernel_id: KernelId, image: ImageRef, version: int, *, + agent_config: Mapping[str, Any], + resource_spec: KernelResourceSpec, + service_ports: Any, # TODO: type-annotation + data: Dict[str, Any], + environ: Mapping[str, Any], + ) -> None: + super().__init__( + kernel_id, image, version, + agent_config=agent_config, + resource_spec=resource_spec, + service_ports=service_ports, + data=data, + environ=environ, + ) + + self.deployment_name = f'kernel-{kernel_id}' + + async def close(self) -> None: + await self.scale(0) + + async def create_code_runner(self, *, + client_features: FrozenSet[str], + api_version: int) -> AbstractCodeRunner: + + scale = await self.scale(1) + if scale.to_dict()['spec']['replicas'] == 0: + log.error('Scaling failed! 
Response body: {0}', scale) + raise ValueError('Scaling failed!') + + if scale.to_dict()['status']['replicas'] == 0: + while not await self.is_scaled(): + await asyncio.sleep(0.5) + + # TODO: Find way to detect if kernel runner has started inside container + + runner = await KubernetesCodeRunner.new( + self.kernel_id, + kernel_host=self.data['kernel_host'], + repl_in_port=self.data['repl_in_port'], + repl_out_port=self.data['repl_out_port'], + exec_timeout=0, + client_features=client_features) + + retries = 0 + while True: + try: + await runner.feed_and_get_status() + break + except zmq.error.ZMQError as e: + if retries < 4: + retries += 1 + log.debug('Socket not responding, retrying #{}', retries) + await asyncio.sleep(retries ** 2) + else: + raise e + + return runner + + async def scale(self, num: int): + await kube_config.load_kube_config() + apps_api = kube_client.AppsV1Api() + try: + return await apps_api.replace_namespaced_deployment_scale( + self.deployment_name, 'backend-ai', + body={ + 'apiVersion': 'autoscaling/v1', + 'kind': 'Scale', + 'metadata': { + 'name': self.deployment_name, + 'namespace': 'backend-ai', + }, + 'spec': {'replicas': num}, + 'status': {'replicas': num, 'selector': f'run={self.deployment_name}'}, + }, + ) + except Exception as e: + log.exception('scale failed: {}', e) + + async def is_scaled(self): + await kube_config.load_kube_config() + apps_api = kube_client.AppsV1Api() + core_api = kube_client.CoreV1Api() + scale = await apps_api.read_namespaced_deployment(self.deployment_name, 'backend-ai') + + if scale.to_dict()['status']['replicas'] == 0: + return False + for condition in scale.to_dict()['status']['conditions']: + if not condition['status']: + return False + + pods = await core_api.list_namespaced_pod( + 'backend-ai', + label_selector=f'run=kernel-{self.kernel_id}', + ) + pods = pods.to_dict()['items'] or [] + if len(pods) == 0: + return False + for pod in pods: + containers = pod['status']['container_statuses'] or [] + if len(containers) == 0: + return False + for container in containers: + started = container.get('started') + if not container['ready'] or started is not None and not started: + return False + return True + + async def get_completions(self, text: str, opts: Mapping[str, Any]): + result = await self.runner.feed_and_get_completion(text, opts) + return {'status': 'finished', 'completions': result} + + async def check_status(self): + result = await self.runner.feed_and_get_status() + return result + + async def get_logs(self): + await kube_config.load_kube_config() + core_api = kube_client.CoreV1Api() + + result = await core_api.read_namespaced_pod_log(self.kernel_id, 'backend-ai') + return {'logs': result.data.decode('utf-8')} + + async def interrupt_kernel(self): + await self.runner.feed_interrupt() + return {'status': 'finished'} + + async def start_service(self, service: str, opts: Mapping[str, Any]): + if self.data.get('block_service_ports', False): + return { + 'status': 'failed', + 'error': 'operation blocked', + } + for sport in self.service_ports: + if sport['name'] == service: + break + else: + return {'status': 'failed', 'error': 'invalid service name'} + result = await self.runner.feed_start_service({ + 'name': service, + 'port': sport['container_ports'][0], # primary port + 'ports': sport['container_ports'], + 'protocol': sport['protocol'], + 'options': opts, + }) + return result + + async def shutdown_service(self, service: str): + await self.runner.feed_shutdown_service(service) + + async def get_service_apps(self): + result = 
await self.runner.feed_service_apps() + return result + + async def accept_file(self, filename: str, filedata: bytes): + loop = current_loop() + work_dir = self.agent_config['container']['scratch-root'] / str(self.kernel_id) / 'work' + try: + # create intermediate directories in the path + dest_path = (work_dir / filename).resolve(strict=False) + parent_path = dest_path.parent + except ValueError: # parent_path does not start with work_dir! + raise AssertionError('malformed upload filename and path.') + + def _write_to_disk(): + parent_path.mkdir(parents=True, exist_ok=True) + dest_path.write_bytes(filedata) + + try: + await loop.run_in_executor(None, _write_to_disk) + except FileNotFoundError: + log.error('{0}: writing uploaded file failed: {1} -> {2}', + self.kernel_id, filename, dest_path) + + async def download_file(self, filepath: str): + # TODO: Implement file operations with pure Kubernetes API + await kube_config.load_kube_config() + core_api = kube_client.CoreV1Api() + + home_path = Path('/home/work') + try: + abspath = (home_path / filepath).resolve() + abspath.relative_to(home_path) + except ValueError: + raise PermissionError('You cannot download files outside /home/work') + + async with watch.Watch().stream( + core_api.connect_get_namespaced_pod_exec, + self.kernel_id, 'backend-ai', + command=['tar', 'cf', '-', abspath.resolve()], stderr=True, stdin=True, stdout=True, + tty=False, _preload_content=False, + ) as stream: + async for event in stream: + log.debug('stream: {}', event) + + return None + + async def list_files(self, container_path: str): + # TODO: Implement file operations with pure Kubernetes API + await kube_config.load_kube_config() + core_api = kube_client.CoreV1Api() + + # Confine the lookable paths in the home directory + home_path = Path('/home/work') + try: + resolved_path = (home_path / container_path).resolve() + resolved_path.relative_to(home_path) + except ValueError: + raise PermissionError('You cannot list files outside /home/work') + + # Gather individual file information in the target path. + code = textwrap.dedent(''' + import json + import os + import stat + import sys + + files = [] + for f in os.scandir(sys.argv[1]): + fstat = f.stat() + ctime = fstat.st_ctime # TODO: way to get concrete create time? 
+ mtime = fstat.st_mtime + atime = fstat.st_atime + files.append({ + 'mode': stat.filemode(fstat.st_mode), + 'size': fstat.st_size, + 'ctime': ctime, + 'mtime': mtime, + 'atime': atime, + 'filename': f.name, + }) + print(json.dumps(files)) + ''') + + command = ['/opt/backend.ai/bin/python', '-c', code, str(container_path)] + async with watch.Watch().stream( + core_api.connect_get_namespaced_pod_exec, + self.kernel_id, 'backend-ai', + command=command, stderr=True, stdin=True, stdout=True, + tty=False, _preload_content=False, + ) as stream: + async for event in stream: + log.debug('stream: {}', event) + + return {'files': '', 'errors': '', 'abspath': str(container_path)} + + +class KubernetesCodeRunner(AbstractCodeRunner): + + kernel_host: str + repl_in_port: int + repl_out_port: int + + def __init__(self, kernel_id, *, + kernel_host, repl_in_port, repl_out_port, + exec_timeout=0, client_features=None) -> None: + super().__init__( + kernel_id, + exec_timeout=exec_timeout, + client_features=client_features) + self.kernel_host = kernel_host + self.repl_in_port = repl_in_port + self.repl_out_port = repl_out_port + + async def get_repl_in_addr(self) -> str: + return f'tcp://{self.kernel_host}:{self.repl_in_port}' + + async def get_repl_out_addr(self) -> str: + return f'tcp://{self.kernel_host}:{self.repl_out_port}' + + +async def prepare_krunner_env_impl(distro: str, root_path: str) -> Tuple[str, Optional[str]]: + if distro.startswith('static-'): + distro_name = distro.replace('-', '_') # pkg/mod name use underscores + else: + if (m := re.search(r'^([a-z]+)\d+\.\d+$', distro)) is None: + raise ValueError('Unrecognized "distro[version]" format string.') + distro_name = m.group(1) + docker = Docker() + arch = get_arch_name() + current_version = int(Path( + pkg_resources.resource_filename( + f'ai.backend.krunner.{distro_name}', + f'./krunner-version.{distro}.txt')) + .read_text().strip()) + krunner_folder_name = f'backendai-krunner.v{current_version}.{distro}' + target_path = Path(root_path) / krunner_folder_name + extractor_image = 'backendai-krunner-extractor:latest' + + try: + for item in (await docker.images.list()): + if item['RepoTags'] is None: + continue + if item['RepoTags'][0] == extractor_image: + break + else: + log.info('preparing the Docker image for krunner extractor...') + extractor_archive = pkg_resources.resource_filename( + 'ai.backend.runner', f'krunner-extractor.img.{arch}.tar.xz') + with lzma.open(extractor_archive, 'rb') as reader: + proc = await asyncio.create_subprocess_exec( + *['docker', 'load'], stdin=reader) + if (await proc.wait() != 0): + raise RuntimeError('loading krunner extractor image has failed!') + + log.info('checking krunner-env for {}.{}...', distro, arch) + + if not target_path.exists(): + log.info('populating {} volume version {}', + krunner_folder_name, current_version) + target_path.mkdir(exist_ok=False) + archive_path = Path(pkg_resources.resource_filename( + f'ai.backend.krunner.{distro_name}', + f'krunner-env.{distro}.{arch}.tar.xz')).resolve() + extractor_path = Path(pkg_resources.resource_filename( + 'ai.backend.runner', + 'krunner-extractor.sh')).resolve() + + log.debug('Executing {}', ' '.join([ + 'docker', 'run', '--rm', '-i', + '-v', f'{archive_path}:/root/archive.tar.xz', + '-v', f'{extractor_path}:/root/krunner-extractor.sh', + '-v', f'{target_path.absolute().as_posix()}:/root/volume', + '-e', f'KRUNNER_VERSION={current_version}', + extractor_image, + '/root/krunner-extractor.sh', + ])) + + proc = await asyncio.create_subprocess_exec(*[ + 
'docker', 'run', '--rm', '-i', + '-v', f'{archive_path}:/root/archive.tar.xz', + '-v', f'{extractor_path}:/root/krunner-extractor.sh', + '-v', f'{target_path.absolute().as_posix()}:/root/volume', + '-e', f'KRUNNER_VERSION={current_version}', + extractor_image, + '/root/krunner-extractor.sh', + ]) + if (await proc.wait() != 0): + raise RuntimeError('extracting krunner environment has failed!') + except Exception: + log.exception('unexpected error') + return distro, None + finally: + await docker.close() + return distro, krunner_folder_name + + +async def copy_runner_files(scratch_path: Path) -> None: + artifact_path = Path(pkg_resources.resource_filename('ai.backend.agent', '../runner')) + kernel_path = Path(pkg_resources.resource_filename('ai.backend.agent', '../kernel')) + helpers_path = Path(pkg_resources.resource_filename('ai.backend.agent', '../helpers')) + + destination_path = scratch_path + + if (destination_path / 'runner').exists(): + shutil.rmtree(destination_path / 'runner', ignore_errors=True) + (destination_path / 'runner').mkdir(parents=True) + + target_files = [ + 'entrypoint.sh', + '*.bin', + '*.so', + 'DO_NOT_STORE_PERSISTENT_FILES_HERE.md', + 'extract_dotfiles.py', + ] + + for target_glob in target_files: + for matched_path in artifact_path.glob(target_glob): + shutil.copy(matched_path.resolve(), destination_path / 'runner') + + if (destination_path / 'kernel').exists(): + shutil.rmtree(destination_path / 'kernel', ignore_errors=True) + shutil.copytree(kernel_path.resolve(), destination_path / 'kernel') + + if (destination_path / 'helpers').exists(): + shutil.rmtree(destination_path / 'helpers', ignore_errors=True) + shutil.copytree(helpers_path.resolve(), destination_path / 'helpers') + + +async def prepare_krunner_env(local_config: Mapping[str, Any]) -> Mapping[str, Sequence[str]]: + """ + Check if the volume "backendai-krunner.{distro}.{arch}" exists and is up-to-date. + If not, automatically create it and update its content from the packaged pre-built krunner + tar archives. + """ + + all_distros = [] + entry_prefix = 'backendai_krunner_v10' + for entrypoint in pkg_resources.iter_entry_points(entry_prefix): + log.debug('loading krunner pkg: {}', entrypoint.module_name) + plugin = entrypoint.load() + await plugin.init({}) # currently does nothing + provided_versions = Path(pkg_resources.resource_filename( + f'ai.backend.krunner.{entrypoint.name}', + 'versions.txt', + )).read_text().splitlines() + all_distros.extend(provided_versions) + + scratch_mount = local_config['container']['scratch-root'] + await copy_runner_files(Path(scratch_mount)) + + tasks = [] + async with TaskGroup() as tg: + for distro in all_distros: + tasks.append(tg.create_task(prepare_krunner_env_impl(distro, scratch_mount))) + distro_volumes = [t.result() for t in tasks if not t.cancelled()] + result = {} + for distro_name_and_version, volume_name in distro_volumes: + if volume_name is None: + continue + result[distro_name_and_version] = volume_name + return result diff --git a/src/ai/backend/agent/kubernetes/kube_object.py b/src/ai/backend/agent/kubernetes/kube_object.py new file mode 100644 index 0000000000..10dd62ffe3 --- /dev/null +++ b/src/ai/backend/agent/kubernetes/kube_object.py @@ -0,0 +1,201 @@ +import attr +from typing import Any, Dict, Mapping, Optional + +'''This file contains API templates for Python K8s Client. +Since I don't prefer using templates provided from vanila k8s client, +all API definitions for Backend.AI Agent will use needs to be defined here. 
+All API definitions defined here (especially for use with outside this file) +should implement to_dict() method, which returns complete definition in dictionary. +To pass API body from objects defined here, simply put return value of to_dict() method as a body: +e.g) await k8sCoreApi.create_persistent_volume(body=pv.to_dict())''' + + +class AbstractAPIObject: + pass + + +@attr.s(auto_attribs=True, slots=True) +class KubernetesVolumeMount: + name: str + mountPath: str + subPath: Optional[str] + readOnly: Optional[bool] + + +class KubernetesAbstractVolume: + name: str + + +@attr.s(auto_attribs=True, slots=True) +class KubernetesEmptyDirVolume(KubernetesAbstractVolume): + name: str + emptyDir: Mapping[str, Any] = {} + + +@attr.s(auto_attribs=True, slots=True) +class KubernetesPVCVolume(KubernetesAbstractVolume): + name: str + persistentVolumeClaim: Mapping[str, str] + + +@attr.s(auto_attribs=True, slots=True) +class KubernetesConfigMapVolume(KubernetesAbstractVolume): + name: str + configMap: Mapping[str, str] + + +@attr.s(auto_attribs=True, slots=True) +class KubernetesHostPathVolume(KubernetesAbstractVolume): + name: str + hostPath: Mapping[str, str] + + +class ConfigMap(AbstractAPIObject): + items: Dict[str, str] = {} + + def __init__(self, kernel_id, name: str): + self.name = name + self.labels = {'backend.ai/kernel-id': kernel_id} + + def put(self, key: str, value: str): + self.items[key] = value + + def to_dict(self) -> dict: + return { + 'apiVersion': 'v1', + 'kind': 'ConfigMap', + 'metadata': { + 'name': self.name, + 'labels': self.labels, + }, + 'data': self.items, + } + + +class Service(AbstractAPIObject): + def __init__(self, kernel_id: str, name: str, container_port: list, service_type='NodePort'): + self.name = name + self.deployment_name = f'kernel-{kernel_id}' + self.container_port = container_port + self.service_type = service_type + self.labels = {'run': self.name, 'backend.ai/kernel-id': kernel_id} + + def to_dict(self) -> dict: + base: Dict[str, Any] = { + 'apiVersion': 'v1', + 'kind': 'Service', + 'metadata': { + 'name': self.name, + 'labels': self.labels, + }, + 'spec': { + 'ports': [{'targetPort': x[0], 'port': x[0], 'name': x[1]} for x in self.container_port], + 'selector': {'run': self.deployment_name}, + 'type': '', + }, + } + if self.service_type == 'NodePort': + base['spec']['type'] = 'NodePort' + elif self.service_type == 'LoadBalancer': + base['spec']['type'] = 'LoadBalancer' + return base + + +class NFSPersistentVolume(AbstractAPIObject): + + def __init__(self, server, name, capacity, path='/'): + self.server = server + self.path = path + self.name = name + self.capacity = capacity + self.labels = {} + self.options = [] + + def label(self, k, v): + self.labels[k] = v + + def to_dict(self) -> dict: + return { + 'apiVersion': 'v1', + 'kind': 'PersistentVolume', + 'metadata': { + 'name': self.name, + 'labels': self.labels, + }, + 'spec': { + 'capacity': { + 'storage': self.capacity + 'Gi', + }, + 'accessModes': ['ReadWriteMany'], + 'nfs': { + 'server': self.server, + 'path': self.path, + }, + 'mountOptions': self.options, + }, + } + + +class HostPathPersistentVolume(AbstractAPIObject): + + def __init__(self, path, name, capacity): + self.path = path + self.name = name + self.capacity = capacity + self.labels = {} + self.options = [] + + def label(self, k, v): + self.labels[k] = v + + def to_dict(self) -> dict: + return { + 'apiVersion': 'v1', + 'kind': 'PersistentVolume', + 'metadata': { + 'name': self.name, + 'labels': self.labels, + }, + 'spec': { + 'capacity': { 
+ 'storage': self.capacity + 'Gi', + }, + 'accessModes': ['ReadWriteMany'], + 'hostPath': { + 'path': self.path, + }, + 'mountOptions': self.options, + }, + } + + +class PersistentVolumeClaim(AbstractAPIObject): + + def __init__(self, name, pv_name, capacity): + self.name = name + self.pv_name = pv_name + self.capacity = capacity + self.labels = {} + + def label(self, k, v): + self.labels[k] = v + + def to_dict(self) -> dict: + base = { + 'apiVersion': 'v1', + 'kind': 'PersistentVolumeClaim', + 'metadata': { + 'name': self.name, + 'labels': self.labels, + }, + 'spec': { + 'resources': { + 'requests': { + 'storage': self.capacity + 'Gi', + }, + }, + 'accessModes': ['ReadWriteMany'], + 'storageClassName': '', + }, + } + return base diff --git a/src/ai/backend/agent/kubernetes/resources.py b/src/ai/backend/agent/kubernetes/resources.py new file mode 100644 index 0000000000..8281235024 --- /dev/null +++ b/src/ai/backend/agent/kubernetes/resources.py @@ -0,0 +1,98 @@ +from decimal import Decimal +import logging +from pathlib import Path +from typing import ( + Any, Optional, + Mapping, MutableMapping, + Tuple, +) + +import aiofiles + +from ai.backend.common.etcd import AsyncEtcd +from ai.backend.common.logging import BraceStyleAdapter +from ai.backend.common.types import ( + DeviceName, SlotName, +) +from ..exception import InitializationError +from ..resources import ( + AbstractComputePlugin, ComputePluginContext, KernelResourceSpec, known_slot_types, +) + +log = BraceStyleAdapter(logging.getLogger(__name__)) + + +async def detect_resources( + etcd: AsyncEtcd, + local_config: Mapping[str, Any], +) -> Tuple[Mapping[DeviceName, AbstractComputePlugin], + Mapping[SlotName, Decimal]]: + """ + Detect available computing resource of the system. + It also loads the accelerator plugins. + + limit_cpus, limit_gpus are deprecated. + """ + reserved_slots = { + 'cpu': local_config['resource']['reserved-cpu'], + 'mem': local_config['resource']['reserved-mem'], + 'disk': local_config['resource']['reserved-disk'], + } + slots: MutableMapping[SlotName, Decimal] = {} + + compute_device_types: MutableMapping[DeviceName, AbstractComputePlugin] = {} + + # Initialize intrinsic plugins by ourselves. + from .intrinsic import CPUPlugin, MemoryPlugin + compute_plugin_ctx = ComputePluginContext( + etcd, local_config, + ) + await compute_plugin_ctx.init() + if 'cpu' not in compute_plugin_ctx.plugins: + cpu_config = await etcd.get_prefix('config/plugins/cpu') + cpu_plugin = CPUPlugin(cpu_config, local_config) + compute_plugin_ctx.attach_intrinsic_device(cpu_plugin) + if 'mem' not in compute_plugin_ctx.plugins: + memory_config = await etcd.get_prefix('config/plugins/memory') + memory_plugin = MemoryPlugin(memory_config, local_config) + compute_plugin_ctx.attach_intrinsic_device(memory_plugin) + for plugin_name, plugin_instance in compute_plugin_ctx.plugins.items(): + if not all( + (invalid_name := sname, sname.startswith(f'{plugin_instance.key}.'))[1] + for sname, _ in plugin_instance.slot_types + if sname not in {'cpu', 'mem'} + ): + raise InitializationError( + "Slot types defined by an accelerator plugin must be prefixed " + "by the plugin's key.", + invalid_name, # noqa: F821 + plugin_instance.key, + ) + if plugin_instance.key in compute_device_types: + raise InitializationError( + f"A plugin defining the same key '{plugin_instance.key}' already exists. 
" + "You may need to uninstall it first.") + compute_device_types[plugin_instance.key] = plugin_instance + + for key, computer in compute_device_types.items(): + known_slot_types.update(computer.slot_types) # type: ignore # (only updated here!) + resource_slots = await computer.available_slots() + for sname, sval in resource_slots.items(): + slots[sname] = Decimal(max(0, sval - reserved_slots.get(sname, 0))) + if slots[sname] <= 0 and sname in (SlotName('cpu'), SlotName('mem')): + raise InitializationError( + f"The resource slot '{sname}' is not sufficient (zero or below zero). " + "Try to adjust the reserved resources or use a larger machine.") + + log.info('Resource slots: {!r}', slots) + log.info('Slot types: {!r}', known_slot_types) + return compute_device_types, slots + + +async def get_resource_spec_from_container(container_info) -> Optional[KernelResourceSpec]: + for mount in container_info['HostConfig']['Mounts']: + if mount['Target'] == '/home/config': + async with aiofiles.open(Path(mount['Source']) / 'resource.txt', 'r') as f: # type: ignore + return await KernelResourceSpec.aread_from_file(f) + else: + return None diff --git a/src/ai/backend/agent/kubernetes/utils.py b/src/ai/backend/agent/kubernetes/utils.py new file mode 100644 index 0000000000..e2dd657bb2 --- /dev/null +++ b/src/ai/backend/agent/kubernetes/utils.py @@ -0,0 +1,140 @@ +import asyncio +import gzip +import logging +from pathlib import Path +import pkg_resources +import subprocess +from typing import Any, BinaryIO, Mapping, Tuple, cast + +from aiodocker.docker import Docker +from aiodocker.exceptions import DockerError + +from ai.backend.common.logging import BraceStyleAdapter + +from ..utils import update_nested_dict + +log = BraceStyleAdapter(logging.getLogger(__name__)) + + +class PersistentServiceContainer: + + def __init__( + self, + docker: Docker, + image_ref: str, + container_config: Mapping[str, Any], + *, + name: str = None, + ) -> None: + self.docker = docker + self.image_ref = image_ref + default_container_name = image_ref.split(':')[0].rsplit('/', maxsplit=1)[-1] + if name is None: + self.container_name = default_container_name + else: + self.container_name = name + self.container_config = container_config + self.img_version = int(Path(pkg_resources.resource_filename( + 'ai.backend.agent.docker', + f'{default_container_name}.version.txt', + )).read_text()) + self.img_path = Path(pkg_resources.resource_filename( + 'ai.backend.agent.docker', + f'{default_container_name}.img.tar.gz', + )) + + async def get_container_version_and_status(self) -> Tuple[int, bool]: + try: + c = self.docker.containers.container(self.container_name) + await c.show() + except DockerError as e: + if e.status == 404: + return 0, False + else: + raise + if c['Config'].get('Labels', {}).get('ai.backend.system', '0') != '1': + raise RuntimeError( + f"An existing container named \"{c['Name'].lstrip('/')}\" is not a system container " + f"spawned by Backend.AI. 
Please check and remove it.") + return ( + int(c['Config'].get('Labels', {}).get('ai.backend.version', '0')), + c['State']['Status'].lower() == 'running', + ) + + async def get_image_version(self) -> int: + try: + img = await self.docker.images.inspect(self.image_ref) + except DockerError as e: + if e.status == 404: + return 0 + else: + raise + return int(img['Config'].get('Labels', {}).get('ai.backend.version', '0')) + + async def ensure_running_latest(self) -> None: + image_version = await self.get_image_version() + if image_version == 0: + log.info("PersistentServiceContainer({}): installing...", self.image_ref) + await self.install_latest() + elif image_version < self.img_version: + log.info("PersistentServiceContainer({}): upgrading (v{} -> v{})", + self.image_ref, image_version, self.img_version) + await self.install_latest() + container_version, is_running = await self.get_container_version_and_status() + if container_version == 0 or image_version != container_version or not is_running: + log.info("PersistentServiceContainer({}): recreating...", self.image_ref) + await self.recreate() + if not is_running: + log.info("PersistentServiceContainer({}): starting...", self.image_ref) + await self.start() + + async def install_latest(self) -> None: + with gzip.open(self.img_path, 'rb') as reader: + proc = await asyncio.create_subprocess_exec( + *['docker', 'load'], + stdin=cast(BinaryIO, reader), + stdout=subprocess.DEVNULL, + stderr=subprocess.PIPE, + ) + if (await proc.wait() != 0): + stderr = b'(unavailable)' + if proc.stderr is not None: + stderr = await proc.stderr.read() + raise RuntimeError( + 'loading the image has failed!', + self.image_ref, proc.returncode, stderr, + ) + + async def recreate(self) -> None: + try: + c = self.docker.containers.container(self.container_name) + await c.stop() + await c.delete(force=True) + except DockerError as e: + if e.status == 409 and 'is not running' in e.message: + pass + elif e.status == 404: + pass + else: + raise + container_config = { + 'Image': self.image_ref, + 'Tty': True, + 'Privileged': False, + 'AttachStdin': False, + 'AttachStdout': False, + 'AttachStderr': False, + 'HostConfig': { + 'Init': True, + 'RestartPolicy': { + 'Name': 'unless-stopped', # make it persistent + 'MaximumRetryCount': 0, + }, + }, + } + update_nested_dict(container_config, self.container_config) + await self.docker.containers.create(config=container_config, name=self.container_name) + + async def start(self) -> None: + c = self.docker.containers.container(self.container_name) + await c.start() diff --git a/src/ai/backend/agent/linuxkit-metadata-proxy/main.go b/src/ai/backend/agent/linuxkit-metadata-proxy/main.go new file mode 100644 index 0000000000..bfa5156474 --- /dev/null +++ b/src/ai/backend/agent/linuxkit-metadata-proxy/main.go @@ -0,0 +1,84 @@ +package main + +import ( + "flag" + "io" + "log" + "net" + "net/http" + "net/url" + "strconv" +) + +// Hop-by-hop headers. These are removed when sent to the backend. 
+// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html +var hopHeaders = []string{ + "Connection", + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", // canonicalized version of "TE" + "Trailers", + "Transfer-Encoding", + "Upgrade", +} + +func handleHTTP(w http.ResponseWriter, req *http.Request, remotePort int) { + req.URL = &url.URL{ + Scheme: "http", + Opaque: req.URL.Opaque, + User: req.URL.User, + Host: "host.docker.internal:" + strconv.Itoa(remotePort), + Path: req.URL.Path, + RawPath: req.URL.RawPath, + ForceQuery: req.URL.ForceQuery, + RawQuery: req.URL.RawQuery, + Fragment: req.URL.Fragment, + RawFragment: req.URL.RawFragment, + } + req.Host = "host.docker.internal:" + strconv.Itoa(remotePort) + log.Printf("%s %s\n", req.Method, req.URL) + delHopHeaders(req.Header) + if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { + req.Header.Set("X-Forwarded-For", clientIP) + } + resp, err := http.DefaultTransport.RoundTrip(req) + if err != nil { + http.Error(w, err.Error(), http.StatusServiceUnavailable) + return + } + defer resp.Body.Close() + copyHeader(w.Header(), resp.Header) + w.WriteHeader(resp.StatusCode) + io.Copy(w, resp.Body) +} + +func copyHeader(dst, src http.Header) { + for k, vv := range src { + for _, v := range vv { + dst.Add(k, v) + } + } +} + +func delHopHeaders(header http.Header) { + for _, h := range hopHeaders { + header.Del(h) + } +} + +func main() { + var localPort int + var remotePort int + flag.IntVar(&localPort, "port", 50128, "Target port for proxy to listen") + flag.IntVar(&remotePort, "remote-port", 8000, "Remote metadata server listening port") + flag.Parse() + server := &http.Server{ + Addr: ":" + strconv.Itoa(localPort), + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handleHTTP(w, r, remotePort) + }), + } + log.Printf("Listening on 0.0.0.0:%d -> host.docker.internal:%d\n", localPort, remotePort) + log.Fatal(server.ListenAndServe()) +} diff --git a/src/ai/backend/agent/proxy.py b/src/ai/backend/agent/proxy.py new file mode 100644 index 0000000000..25469faba3 --- /dev/null +++ b/src/ai/backend/agent/proxy.py @@ -0,0 +1,73 @@ +import asyncio +from asyncio import Future +from pathlib import Path +from typing import Set, Tuple + +from ai.backend.common.utils import current_loop +import attr + + +@attr.s(auto_attribs=True, slots=True) +class DomainSocketProxy: + host_sock_path: Path + host_proxy_path: Path + proxy_server: asyncio.AbstractServer + + +async def proxy_connection(upper_sock_path: Path, + down_reader: asyncio.StreamReader, + down_writer: asyncio.StreamWriter) -> None: + + up_reader, up_writer = await asyncio.open_unix_connection(str(upper_sock_path)) + + async def _downstream(): + try: + while True: + data = await up_reader.read(4096) + if not data: + break + down_writer.write(data) + await down_writer.drain() + except asyncio.CancelledError: + pass + finally: + down_writer.close() + await down_writer.wait_closed() + await asyncio.sleep(0) + + async def _upstream(): + try: + while True: + data = await down_reader.read(4096) + if not data: + break + up_writer.write(data) + await up_writer.drain() + except asyncio.CancelledError: + pass + finally: + up_writer.close() + await up_writer.wait_closed() + await asyncio.sleep(0) + + loop = current_loop() + downstream_task = loop.create_task(_downstream()) + upstream_task = loop.create_task(_upstream()) + tasks = [upstream_task, downstream_task] + # Since we cannot determine which side (the server or client) disconnects first, + # await 
until any task that completes first. + # For example, when proxying the docker domain socket, the proxy connections for one-shot + # docker commands are usually disconnected by the client first, but the connections for + # long-running streaming commands are disconnected by the server first when the server-side + # processing finishes. + try: + task_results: Tuple[Set[Future], Set[Future]] = \ + await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + done, pending = task_results + except asyncio.CancelledError: + pass + finally: + # And then cancel all remaining tasks. + for t in pending: + t.cancel() + await t diff --git a/src/ai/backend/agent/py.typed b/src/ai/backend/agent/py.typed new file mode 100644 index 0000000000..5abed26af8 --- /dev/null +++ b/src/ai/backend/agent/py.typed @@ -0,0 +1 @@ +marker diff --git a/src/ai/backend/agent/resources.py b/src/ai/backend/agent/resources.py new file mode 100644 index 0000000000..1accbb2ff0 --- /dev/null +++ b/src/ai/backend/agent/resources.py @@ -0,0 +1,1001 @@ +from __future__ import annotations + +from abc import ABCMeta, abstractmethod +from collections import defaultdict +from decimal import Decimal, ROUND_DOWN +import enum +import fnmatch +import logging +import json +import operator +from pathlib import Path +from typing import ( + Any, + Collection, + Container, + Iterable, + Iterator, + List, + Mapping, + MutableMapping, + Optional, + FrozenSet, + Sequence, + Set, + TextIO, + Tuple, + Type, + TypeVar, + cast, + TYPE_CHECKING, +) + +import attr +import aiodocker + +from ai.backend.common.types import ( + ResourceSlot, SlotName, SlotTypes, + DeviceId, DeviceName, DeviceModelInfo, + MountPermission, MountTypes, + BinarySize, + HardwareMetadata, +) +from ai.backend.common.logging import BraceStyleAdapter +from ai.backend.common.plugin import AbstractPlugin, BasePluginContext +from .exception import ( + InsufficientResource, + InvalidResourceArgument, + InvalidResourceCombination, + NotMultipleOfQuantum, +) +from .stats import StatContext, NodeMeasurement, ContainerMeasurement +from .types import Container as SessionContainer + +if TYPE_CHECKING: + from aiofiles.threadpool.text import AsyncTextIOWrapper + from io import TextIOWrapper + +log = BraceStyleAdapter(logging.getLogger('ai.backend.agent.resources')) + +log_alloc_map: bool = False +known_slot_types: Mapping[SlotName, SlotTypes] = {} + + +def round_down(from_dec: Decimal, with_dec: Decimal): + remainder = from_dec.remainder_near(with_dec) + if remainder < 0: + remainder += with_dec + return from_dec - remainder + + +class AllocationStrategy(enum.Enum): + FILL = 0 + EVENLY = 1 + + +@attr.s(auto_attribs=True, slots=True) +class KernelResourceSpec: + """ + This struct-like object stores the kernel resource allocation information + with serialization and deserialization. + + It allows seamless reconstruction of allocations even when the agent restarts + while kernel containers are running. + """ + + container_id: str + """The container ID to refer inside containers.""" + + slots: Mapping[SlotName, str] + """Stores the original user-requested resource slots.""" + + allocations: MutableMapping[DeviceName, Mapping[SlotName, Mapping[DeviceId, Decimal]]] + """ + Represents the resource allocations for each slot (device) type and devices. + """ + + scratch_disk_size: int + """The size of scratch disk. 
(not implemented yet)""" + + mounts: List['Mount'] = attr.Factory(list) + """The mounted vfolder list.""" + + def freeze(self) -> None: + """Replace the attribute setter to make it immutable.""" + # TODO: implement + pass + + # def _frozen_setattr(self, name, value): + # raise RuntimeError("tried to modify a frozen KernelResourceSpec object") + + # self.mounts = tuple(self.mounts) # type: ignore + # # TODO: wrap slots and allocations with frozendict? + # setattr(self, '__setattr__', _frozen_setattr) # <-- __setattr__ is read-only... :( + + def write_to_string(self) -> str: + mounts_str = ','.join(map(str, self.mounts)) + slots_str = json.dumps({ + k: str(v) for k, v in self.slots.items() + }) + + resource_str = f'CID={self.container_id}\n' + resource_str += f'SCRATCH_SIZE={BinarySize(self.scratch_disk_size):m}\n' + resource_str += f'MOUNTS={mounts_str}\n' + resource_str += f'SLOTS={slots_str}\n' + + for device_name, slots in self.allocations.items(): + for slot_name, per_device_alloc in slots.items(): + if not (slot_name.startswith(f'{device_name}.') or slot_name == device_name): + raise ValueError(f'device_name ({device_name}) must be a prefix of ' + f'slot_name ({slot_name})') + pieces = [] + for dev_id, alloc in per_device_alloc.items(): + if known_slot_types.get(slot_name, 'count') == 'bytes': + pieces.append(f'{dev_id}:{BinarySize(alloc):s}') + else: + pieces.append(f'{dev_id}:{alloc}') + alloc_str = ','.join(pieces) + resource_str += f'{slot_name.upper()}_SHARES={alloc_str}\n' + + return resource_str + + def write_to_file(self, file: TextIO) -> None: + file.write(self.write_to_string()) + + @classmethod + def read_from_string(cls, text: str) -> 'KernelResourceSpec': + kvpairs = {} + for line in text.split('\n'): + if '=' not in line: + continue + key, val = line.strip().split('=', maxsplit=1) + kvpairs[key] = val + allocations = cast( + MutableMapping[ + DeviceName, + MutableMapping[SlotName, Mapping[DeviceId, Decimal]], + ], + defaultdict(lambda: defaultdict(Decimal)), + ) + for key, val in kvpairs.items(): + if key.endswith('_SHARES'): + slot_name = SlotName(key[:-7].lower()) + device_name = DeviceName(slot_name.split('.')[0]) + per_device_alloc: MutableMapping[DeviceId, Decimal] = {} + for entry in val.split(','): + raw_dev_id, _, raw_alloc = entry.partition(':') + if not raw_dev_id or not raw_alloc: + continue + dev_id = DeviceId(raw_dev_id) + try: + if known_slot_types.get(slot_name, 'count') == 'bytes': + alloc = Decimal(BinarySize.from_str(raw_alloc)) + else: + alloc = Decimal(raw_alloc) + except KeyError as e: + log.warning('A previously launched container has ' + 'unknown slot type: {}. 
Ignoring it.', + e.args[0]) + continue + per_device_alloc[dev_id] = alloc + allocations[device_name][slot_name] = per_device_alloc + mounts = [Mount.from_str(m) for m in kvpairs['MOUNTS'].split(',') if m] + return cls( + container_id=kvpairs.get('CID', 'unknown'), + scratch_disk_size=BinarySize.finite_from_str(kvpairs['SCRATCH_SIZE']), + allocations=dict(allocations), + slots=ResourceSlot(json.loads(kvpairs['SLOTS'])), + mounts=mounts, + ) + + @classmethod + def read_from_file(cls, file: TextIOWrapper) -> 'KernelResourceSpec': + text = '\n'.join(file.readlines()) + return cls.read_from_string(text) + + @classmethod + async def aread_from_file(cls, file: AsyncTextIOWrapper) -> 'KernelResourceSpec': + text = '\n'.join(await file.readlines()) # type: ignore + return cls.read_from_string(text) + + def to_json_serializable_dict(self) -> Mapping[str, Any]: + o = attr.asdict(self) + for slot_name, alloc in o['slots'].items(): + if known_slot_types.get(slot_name, 'count') == 'bytes': + o['slots'] = f'{BinarySize(alloc):s}' + else: + o['slots'] = str(alloc) + for dev_name, dev_alloc in o['allocations'].items(): + for slot_name, per_device_alloc in dev_alloc.items(): + for dev_id, alloc in per_device_alloc.items(): + if known_slot_types.get(slot_name, 'count') == 'bytes': + alloc = f'{BinarySize(alloc):s}' + else: + alloc = str(alloc) + o['allocations'][dev_name][slot_name][dev_id] = alloc + o['mounts'] = list(map(str, self.mounts)) + return o + + def to_json(self) -> str: + return json.dumps(self.to_json_serializable_dict()) + + +@attr.s(auto_attribs=True) +class AbstractComputeDevice(): + device_id: DeviceId + hw_location: str # either PCI bus ID or arbitrary string + numa_node: Optional[int] # NUMA node ID (None if not applicable) + memory_size: int # bytes of available per-accelerator memory + processing_units: int # number of processing units (e.g., cores, SMP) + + +class AbstractComputePlugin(AbstractPlugin, metaclass=ABCMeta): + + key: DeviceName = DeviceName('accelerator') + slot_types: Sequence[Tuple[SlotName, SlotTypes]] + exclusive_slot_types: Set[str] + + @abstractmethod + async def list_devices(self) -> Collection[AbstractComputeDevice]: + """ + Return the list of accelerator devices, as read as physically + on the host. + """ + raise NotImplementedError + + @abstractmethod + async def available_slots(self) -> Mapping[SlotName, Decimal]: + """ + Return available slot amounts for each slot key. + """ + raise NotImplementedError + + @abstractmethod + def get_version(self) -> str: + """ + Return the version string of the plugin. + """ + raise NotImplementedError + + @abstractmethod + async def extra_info(self) -> Mapping[str, str]: + """ + Return extra information related to this plugin, + such as the underlying driver version and feature flags. + """ + return {} + + @abstractmethod + async def gather_node_measures(self, ctx: StatContext) -> Sequence[NodeMeasurement]: + """ + Return the system-level and device-level statistic metrics. + + It may return any number of metrics using different statistics key names in the + returning map. + Note that the key must not conflict with other accelerator plugins and must not + contain dots. + """ + raise NotImplementedError + + @abstractmethod + async def gather_container_measures( + self, + ctx: StatContext, + container_ids: Sequence[str], + ) -> Sequence[ContainerMeasurement]: + """ + Return the container-level statistic metrics. 
+ """ + raise NotImplementedError + + @abstractmethod + async def create_alloc_map(self) -> 'AbstractAllocMap': + """ + Create and return an allocation map for this plugin. + """ + raise NotImplementedError + + @abstractmethod + async def get_hooks(self, distro: str, arch: str) -> Sequence[Path]: + """ + Return the library hook paths used by the plugin (optional). + + :param str distro: The target Linux distribution such as "ubuntu16.04" or + "alpine3.8" + :param str arch: The target CPU architecture such as "amd64" + """ + return [] + + @abstractmethod + async def generate_docker_args( + self, + docker: aiodocker.docker.Docker, + device_alloc, + ) -> Mapping[str, Any]: + """ + When starting a new container, generate device-specific options for the + docker container create API as a dictionary, referring the given allocation + map. The agent will merge it with its own options. + """ + return {} + + async def generate_resource_data(self, device_alloc) -> Mapping[str, str]: + """ + Generate extra resource.txt key-value pair sets to be used by the plugin's + own hook libraries in containers. + """ + return {} + + @abstractmethod + async def restore_from_container( + self, + container: SessionContainer, + alloc_map: AbstractAllocMap, + ) -> None: + """ + When the agent restarts, retore the allocation map from the container + metadata dictionary fetched from aiodocker. + """ + pass + + @abstractmethod + async def get_attached_devices( + self, + device_alloc: Mapping[SlotName, Mapping[DeviceId, Decimal]], + ) -> Sequence[DeviceModelInfo]: + """ + Make up container-attached device information with allocated device id. + """ + return [] + + async def get_node_hwinfo(self) -> HardwareMetadata: + raise NotImplementedError + + +class ComputePluginContext(BasePluginContext[AbstractComputePlugin]): + plugin_group = 'backendai_accelerator_v20' + + @classmethod + def discover_plugins( + cls, + plugin_group: str, + blocklist: Container[str] = None, + ) -> Iterator[Tuple[str, Type[AbstractComputePlugin]]]: + scanned_plugins = [*super().discover_plugins(plugin_group, blocklist)] + + def accel_lt_intrinsic(item): + # push back "intrinsic" plugins (if exists) + if item[0] in ('cpu', 'mem'): + return 0 + return -1 + + scanned_plugins.sort(key=accel_lt_intrinsic) + yield from scanned_plugins + + def attach_intrinsic_device(self, plugin: AbstractComputePlugin) -> None: + self.plugins[plugin.key] = plugin + + +@attr.s(auto_attribs=True, slots=True) +class Mount: + type: MountTypes + source: Optional[Path] + target: Path + permission: MountPermission = MountPermission.READ_ONLY + opts: Optional[Mapping[str, Any]] = None + + def __str__(self): + return f'{self.source}:{self.target}:{self.permission.value}' + + @classmethod + def from_str(cls, s): + source, target, perm = s.split(':') + source = Path(source) + type = MountTypes.BIND + if not source.is_absolute(): + if len(source.parts) == 1: + source = str(source) + type = MountTypes.VOLUME + else: + raise ValueError('Mount source must be an absolute path ' + 'if it is not a volume name.', + source) + target = Path(target) + if not target.is_absolute(): + raise ValueError('Mount target must be an absolute path.', target) + perm = MountPermission(perm) + return cls(type, source, target, perm, None) + + +@attr.s(auto_attribs=True) +class DeviceSlotInfo: + slot_type: SlotTypes + slot_name: SlotName + amount: Decimal + + +class AbstractAllocMap(metaclass=ABCMeta): + + device_slots: Mapping[DeviceId, DeviceSlotInfo] + device_mask: FrozenSet[DeviceId] + 
exclusive_slot_types: Iterable[SlotName] + allocations: MutableMapping[SlotName, MutableMapping[DeviceId, Decimal]] + + def __init__( + self, *, + device_slots: Mapping[DeviceId, DeviceSlotInfo] = None, + device_mask: Iterable[DeviceId] = None, + exclusive_slot_types: Iterable[SlotName] = None, + ) -> None: + self.exclusive_slot_types = exclusive_slot_types or {} + self.device_slots = device_slots or {} + self.slot_types = {info.slot_name: info.slot_type for info in self.device_slots.values()} + self.device_mask = frozenset(device_mask) if device_mask is not None else frozenset() + self.allocations = defaultdict(lambda: defaultdict(Decimal)) + for dev_id, dev_slot_info in self.device_slots.items(): + self.allocations[dev_slot_info.slot_name][dev_id] = Decimal(0) + + def clear(self) -> None: + self.allocations.clear() + for dev_id, dev_slot_info in self.device_slots.items(): + self.allocations[dev_slot_info.slot_name][dev_id] = Decimal(0) + + def check_exclusive(self, a: SlotName, b: SlotName) -> bool: + if not self.exclusive_slot_types: + return False + if a == b: + return False + a_in_exclusive_set = (a in self.exclusive_slot_types) + b_in_exclusive_set = (b in self.exclusive_slot_types) + if a_in_exclusive_set and b_in_exclusive_set: + # fast-path for exact match + return True + for t in self.exclusive_slot_types: + if '*' in t: + a_in_exclusive_set = a_in_exclusive_set or fnmatch.fnmatchcase(a, t) + b_in_exclusive_set = b_in_exclusive_set or fnmatch.fnmatchcase(b, t) + return a_in_exclusive_set and b_in_exclusive_set + + def format_current_allocations(self) -> str: + bufs = [] + for slot_name, per_device_alloc in self.allocations.items(): + bufs.append(f"slot[{slot_name}]:") + for device_id, alloc in per_device_alloc.items(): + bufs.append(f" {device_id}: {alloc}") + return "\n".join(bufs) + + @abstractmethod + def allocate( + self, + slots: Mapping[SlotName, Decimal], + *, + context_tag: str = None, + ) -> Mapping[SlotName, Mapping[DeviceId, Decimal]]: + """ + Allocate the given amount of resources. + + For a slot type, there may be multiple different devices which can allocate resources + in the given slot type. An implementation of alloc map finds suitable match from the + remaining capacities of those devices. + + Returns a mapping from each requested slot to the allocations per device. + """ + pass + + @abstractmethod + def apply_allocation( + self, + existing_alloc: Mapping[SlotName, Mapping[DeviceId, Decimal]], + ) -> None: + """ + Apply the given allocation restored from disk or other persistent storages. + """ + pass + + @abstractmethod + def free( + self, + existing_alloc: Mapping[SlotName, Mapping[DeviceId, Decimal]], + ) -> None: + """ + Free the allocated resources using the token returned when the allocation + occurred. + """ + pass + + +def bitmask2set(mask: int) -> FrozenSet[int]: + bpos = 0 + bset = [] + while mask > 0: + if (mask & 1) == 1: + bset.append(bpos) + mask = (mask >> 1) + bpos += 1 + return frozenset(bset) + + +T = TypeVar("T") + + +def distribute(num_items: int, groups: Sequence[T]) -> Mapping[T, int]: + base, extra = divmod(num_items, len(groups)) + return dict(zip( + groups, + ((base + (1 if i < extra else 0)) for i in range(len(groups))), + )) + + +class DiscretePropertyAllocMap(AbstractAllocMap): + """ + An allocation map using discrete property. + The user must pass a "property function" which returns a desired resource + property from the device object. + + e.g., 1.0 means 1 device, 2.0 means 2 devices, etc. 
+ (no fractions allowed) + """ + + def __init__( + self, + *args, + allocation_strategy: AllocationStrategy = AllocationStrategy.EVENLY, + **kwargs, + ) -> None: + self.allocation_strategy = allocation_strategy + self._allocate_impl = { + AllocationStrategy.FILL: self._allocate_by_filling, + AllocationStrategy.EVENLY: self._allocate_evenly, + } + super().__init__(*args, **kwargs) + + def allocate( + self, + slots: Mapping[SlotName, Decimal], + *, + context_tag: str = None, + ) -> Mapping[SlotName, Mapping[DeviceId, Decimal]]: + # prune zero alloc slots + requested_slots = {k: v for k, v in slots.items() if v > 0} + + # check exclusive + for slot_name_a in requested_slots.keys(): + for slot_name_b in requested_slots.keys(): + if self.check_exclusive(slot_name_a, slot_name_b): + raise InvalidResourceCombination( + f"Slots {slot_name_a} and {slot_name_b} cannot be allocated at the same time.") + + # check unique + for slot_name, alloc in requested_slots.items(): + slot_type = self.slot_types.get(slot_name, SlotTypes.COUNT) + if slot_type in (SlotTypes.COUNT, SlotTypes.BYTES): + pass + elif slot_type == SlotTypes.UNIQUE: + if alloc != Decimal(1): + raise InvalidResourceArgument( + f"You may allocate only 1 for the unique-type slot {slot_name}", + ) + + return self._allocate_impl[self.allocation_strategy]( + requested_slots, + context_tag=context_tag, + ) + + def _allocate_by_filling( + self, + requested_slots: Mapping[SlotName, Decimal], + *, + context_tag: str = None, + ) -> Mapping[SlotName, Mapping[DeviceId, Decimal]]: + allocation = {} + for slot_name, alloc in requested_slots.items(): + slot_allocation: MutableMapping[DeviceId, Decimal] = {} + + sorted_dev_allocs = sorted( + self.allocations[slot_name].items(), # k: slot_name, v: per-device alloc + key=lambda pair: self.device_slots[pair[0]].amount - pair[1], + reverse=True) + + if log_alloc_map: + log.debug('DiscretePropertyAllocMap: allocating {} {}', slot_name, alloc) + log.debug('DiscretePropertyAllocMap: current-alloc: {!r}', sorted_dev_allocs) + + total_allocatable = int(0) + remaining_alloc = Decimal(alloc).normalize() + + # fill up starting from the most free devices + for dev_id, current_alloc in sorted_dev_allocs: + current_alloc = self.allocations[slot_name][dev_id] + assert slot_name == self.device_slots[dev_id].slot_name + total_allocatable += int(self.device_slots[dev_id].amount - current_alloc) + if total_allocatable < alloc: + raise InsufficientResource( + 'DiscretePropertyAllocMap: insufficient allocatable amount!', + context_tag, slot_name, str(alloc), str(total_allocatable)) + for dev_id, current_alloc in sorted_dev_allocs: + current_alloc = self.allocations[slot_name][dev_id] + allocatable = (self.device_slots[dev_id].amount - current_alloc) + if allocatable > 0: + allocated = Decimal(min(remaining_alloc, allocatable)) + slot_allocation[dev_id] = allocated + self.allocations[slot_name][dev_id] += allocated + remaining_alloc -= allocated + if remaining_alloc == 0: + break + allocation[slot_name] = slot_allocation + return allocation + + def _allocate_evenly( + self, + requested_slots: Mapping[SlotName, Decimal], + *, + context_tag: str = None, + ) -> Mapping[SlotName, Mapping[DeviceId, Decimal]]: + allocation = {} + + for slot_name, requested_alloc in requested_slots.items(): + new_alloc: MutableMapping[DeviceId, Decimal] = defaultdict(Decimal) + remaining_alloc = int(Decimal(requested_alloc)) + + while remaining_alloc > 0: + # calculate remaining slots per device + total_allocatable = int(sum( + 
self.device_slots[dev_id].amount - current_alloc - new_alloc[dev_id] + for dev_id, current_alloc in self.allocations[slot_name].items() + )) + # if the sum of remaining slot is less than the remaining alloc, fail. + if total_allocatable < remaining_alloc: + raise InsufficientResource( + "DiscretePropertyAllocMap: insufficient allocatable amount!", + context_tag, slot_name, str(requested_alloc), str(total_allocatable), + ) + + # calculate the amount to spread out + nonzero_devs = [ + dev_id + for dev_id, current_alloc in self.allocations[slot_name].items() + if self.device_slots[dev_id].amount - current_alloc - new_alloc[dev_id] > 0 + ] + initial_diffs = distribute(remaining_alloc, nonzero_devs) + diffs = { + dev_id: min( + int(self.device_slots[dev_id].amount - current_alloc - new_alloc[dev_id]), + initial_diffs.get(dev_id, 0), + ) + for dev_id, current_alloc in self.allocations[slot_name].items() + } + + # distribute the remainig alloc to the remaining slots. + sorted_dev_allocs = sorted( + self.allocations[slot_name].items(), # k: slot_name, v: per-device alloc + key=lambda pair: self.device_slots[pair[0]].amount - pair[1], + reverse=True) + for dev_id, current_alloc in sorted_dev_allocs: + diff = diffs[dev_id] + new_alloc[dev_id] += diff + remaining_alloc -= diff + if remaining_alloc == 0: + break + + for dev_id, allocated in new_alloc.items(): + self.allocations[slot_name][dev_id] += allocated + allocation[slot_name] = {k: v for k, v in new_alloc.items() if v > 0} + + return allocation + + def apply_allocation( + self, + existing_alloc: Mapping[SlotName, Mapping[DeviceId, Decimal]], + ) -> None: + for slot_name, per_device_alloc in existing_alloc.items(): + for device_id, alloc in per_device_alloc.items(): + self.allocations[slot_name][device_id] += alloc + + def free( + self, + existing_alloc: Mapping[SlotName, Mapping[DeviceId, Decimal]], + ) -> None: + for slot_name, per_device_alloc in existing_alloc.items(): + for device_id, alloc in per_device_alloc.items(): + self.allocations[slot_name][device_id] -= alloc + + +class FractionAllocMap(AbstractAllocMap): + + def __init__( + self, + *args, + allocation_strategy: AllocationStrategy = AllocationStrategy.EVENLY, + quantum_size: Decimal = Decimal("0.01"), + enforce_physical_continuity: bool = True, + **kwargs, + ) -> None: + self.allocation_strategy = allocation_strategy + self.quantum_size = quantum_size + self.enforce_physical_continuity = enforce_physical_continuity + self._allocate_impl = { + AllocationStrategy.FILL: self._allocate_by_filling, + AllocationStrategy.EVENLY: self._allocate_evenly, + } + super().__init__(*args, **kwargs) + self.digits = Decimal(10) ** -2 # decimal points that is supported by agent + self.powers = Decimal(100) # reciprocal of self.digits + + def allocate( + self, + slots: Mapping[SlotName, Decimal], + *, + context_tag: str = None, + min_memory: Decimal = Decimal("0.01"), + ) -> Mapping[SlotName, Mapping[DeviceId, Decimal]]: + # prune zero alloc slots + requested_slots = {k: v for k, v in slots.items() if v > 0} + + # check exclusive + for slot_name_a in requested_slots.keys(): + for slot_name_b in requested_slots.keys(): + if self.check_exclusive(slot_name_a, slot_name_b): + raise InvalidResourceCombination( + f"Slots {slot_name_a} and {slot_name_b} cannot be allocated at the same time.", + ) + + calculated_alloc_map = self._allocate_impl[self.allocation_strategy]( + requested_slots, + context_tag=context_tag, + min_memory=min_memory, + ) + actual_alloc_map: MutableMapping[SlotName, 
MutableMapping[DeviceId, Decimal]] = {} + for slot_name, alloc in calculated_alloc_map.items(): + actual_alloc: MutableMapping[DeviceId, Decimal] = {} + for dev_id, val in alloc.items(): + self.allocations[slot_name][dev_id] = round_down( + self.allocations[slot_name][dev_id], self.quantum_size) + actual_alloc[dev_id] = round_down(val, self.quantum_size) + if sum(actual_alloc.values()) == 0 and requested_slots[slot_name] > 0: + raise NotMultipleOfQuantum( + f'Requested resource amount for {slot_name} is {requested_slots[slot_name]} ' + 'but actual calculated amount is zero. This can happen if user requests ' + 'resource amount smaller than target device\'s quantum size.', + ) + actual_alloc_map[slot_name] = actual_alloc + + return actual_alloc_map + + def _allocate_by_filling( + self, + requested_slots: Mapping[SlotName, Decimal], + *, + context_tag: str = None, + min_memory: Decimal = Decimal(0.01), + ) -> Mapping[SlotName, Mapping[DeviceId, Decimal]]: + allocation = {} + for slot_name, alloc in requested_slots.items(): + slot_allocation: MutableMapping[DeviceId, Decimal] = {} + + # fill up starting from the most free devices + sorted_dev_allocs = sorted( + self.allocations[slot_name].items(), + key=lambda pair: self.device_slots[pair[0]].amount - pair[1], + reverse=True) + + if log_alloc_map: + log.debug('FractionAllocMap: allocating {} {}', slot_name, alloc) + log.debug('FractionAllocMap: current-alloc: {!r}', sorted_dev_allocs) + + slot_type = self.slot_types.get(slot_name, SlotTypes.COUNT) + if slot_type in (SlotTypes.COUNT, SlotTypes.BYTES): + pass + elif slot_type == SlotTypes.UNIQUE: + if alloc != Decimal(1): + raise InvalidResourceArgument( + f"You may allocate only 1 for the unique-type slot {slot_name}", + ) + total_allocatable = Decimal(0) + remaining_alloc = Decimal(alloc).normalize() + + for dev_id, current_alloc in sorted_dev_allocs: + current_alloc = self.allocations[slot_name][dev_id] + assert slot_name == self.device_slots[dev_id].slot_name + total_allocatable += (self.device_slots[dev_id].amount - + current_alloc) + if total_allocatable < alloc: + raise InsufficientResource( + 'FractionAllocMap: insufficient allocatable amount!', + context_tag, slot_name, str(alloc), str(total_allocatable)) + slot_allocation = {} + for dev_id, current_alloc in sorted_dev_allocs: + current_alloc = self.allocations[slot_name][dev_id] + allocatable = (self.device_slots[dev_id].amount - + current_alloc) + if allocatable > 0: + allocated = min(remaining_alloc, allocatable) + slot_allocation[dev_id] = allocated + self.allocations[slot_name][dev_id] += allocated + remaining_alloc -= allocated + if remaining_alloc <= 0: + break + + allocation[slot_name] = slot_allocation + return allocation + + def _allocate_evenly( + self, + requested_slots: Mapping[SlotName, Decimal], + *, + context_tag: str = None, + min_memory: Decimal = Decimal(0.01), + ) -> Mapping[SlotName, Mapping[DeviceId, Decimal]]: + + # higher value means more even with 0 being the highest value + def measure_evenness(alloc_map: Mapping[DeviceId, Decimal]) \ + -> Decimal: + alloc_arr = sorted([alloc_map[dev_id] for dev_id in alloc_map]) + evenness_score = Decimal(0).quantize(self.digits) + for idx in range(len(alloc_arr) - 1): + evenness_score += abs(alloc_arr[idx + 1] - alloc_arr[idx]) + return -evenness_score + + # higher value means more fragmented + # i.e. 
the number of unusable resources is higher + def measure_fragmentation(allocation: Mapping[DeviceId, Decimal], + min_memory: Decimal): + fragmentation_arr = [self.device_slots[dev_id].amount - allocation[dev_id] + for dev_id in allocation] + return sum(self.digits < v.quantize(self.digits) < min_memory.quantize(self.digits) + for v in fragmentation_arr) + + # evenly distributes remaining_alloc across dev_allocs + def distribute_evenly(dev_allocs: List[Tuple[DeviceId, Decimal]], + remaining_alloc: Decimal, + allocation: MutableMapping[DeviceId, Decimal]): + n_devices = len(dev_allocs) + for dev_id, _ in dev_allocs: + dev_allocation = remaining_alloc / n_devices + dev_allocation = dev_allocation.quantize(self.digits, rounding=ROUND_DOWN) + allocation[dev_id] = dev_allocation + + # need to take care of decimals + remainder = round(remaining_alloc * self.powers - + dev_allocation * n_devices * self.powers) + for idx in range(remainder): + dev_id, _ = dev_allocs[idx] + allocation[dev_id] += self.digits + + # allocates remaining_alloc across multiple devices i.e. dev_allocs + # all devices in dev_allocs are being used + def allocate_across_devices(dev_allocs: List[Tuple[DeviceId, Decimal]], + remaining_alloc: Decimal, slot_name: str) \ + -> MutableMapping[DeviceId, Decimal]: + slot_allocation: MutableMapping[DeviceId, Decimal] = {} + n_devices = len(dev_allocs) + idx = n_devices - 1 # check from the device with smallest allocatable resource + while n_devices > 0: + dev_id, current_alloc = dev_allocs[idx] + allocatable = self.device_slots[dev_id].amount - current_alloc + # if the remaining_alloc can be allocated to evenly among remaining devices + if allocatable >= remaining_alloc / n_devices: + break + slot_allocation[dev_id] = allocatable.quantize(self.digits) + remaining_alloc -= allocatable + idx -= 1 + n_devices -= 1 + + if n_devices > 0: + distribute_evenly(dev_allocs[:n_devices], remaining_alloc, slot_allocation) + + return slot_allocation + + min_memory = min_memory.quantize(self.digits) + allocation = {} + for slot_name, alloc in requested_slots.items(): + slot_allocation: MutableMapping[DeviceId, Decimal] = {} + remaining_alloc = Decimal(alloc).normalize() + sorted_dev_allocs = sorted( + self.allocations[slot_name].items(), + key=lambda pair: self.device_slots[pair[0]].amount - pair[1], + reverse=True) + + # do not consider devices whose remaining resource under min_memory + sorted_dev_allocs = list(filter( + lambda pair: self.device_slots[pair[0]].amount - pair[1] >= min_memory, + sorted_dev_allocs)) + + if log_alloc_map: + log.debug('FractionAllocMap: allocating {} {}', slot_name, alloc) + log.debug('FractionAllocMap: current-alloc: {!r}', sorted_dev_allocs) + + # check if there is enough resource for allocation + total_allocatable = Decimal(0) + for dev_id, current_alloc in sorted_dev_allocs: + current_alloc = self.allocations[slot_name][dev_id] + total_allocatable += (self.device_slots[dev_id].amount - current_alloc) + if total_allocatable.quantize(self.digits) < \ + remaining_alloc.quantize(self.digits): + raise InsufficientResource( + 'FractionAllocMap: insufficient allocatable amount!', + context_tag, slot_name, str(alloc), str(total_allocatable)) + + # allocate resources + if (remaining_alloc <= + self.device_slots[sorted_dev_allocs[0][0]].amount - sorted_dev_allocs[0][1]): + # if remaining_alloc fits in one device + slot_allocation = {} + for dev_id, current_alloc in sorted_dev_allocs[::-1]: + allocatable = (self.device_slots[dev_id].amount - current_alloc) + if remaining_alloc 
<= allocatable: + slot_allocation[dev_id] = remaining_alloc.quantize(self.digits) + break + else: + # need to distribute across devices + # calculate the minimum number of required devices + n_devices, allocated = 0, Decimal(0) + for dev_id, current_alloc in sorted_dev_allocs: + n_devices += 1 + allocated += self.device_slots[dev_id].amount - current_alloc + if allocated.quantize(self.digits) >= \ + remaining_alloc.quantize(self.digits): + break + # need to check from using minimum number of devices to using all devices + # evenness must be non-decreasing with the increase of window size + best_alloc_candidate_arr = [] + for n_dev in range(n_devices, len(sorted_dev_allocs) + 1): + allocatable = sum(map(lambda x: self.device_slots[x[0]].amount - x[1], + sorted_dev_allocs[:n_dev]), start=Decimal(0)) + # choose the best allocation from all possible allocation candidates + alloc_candidate = allocate_across_devices(sorted_dev_allocs[:n_dev], + remaining_alloc, slot_name) + max_evenness = measure_evenness(alloc_candidate) + # three criteria to decide allocation are + # eveness, number of resources used, and amount of fragmentatino + alloc_candidate_arr = [(alloc_candidate, max_evenness, -len(alloc_candidate), + -measure_fragmentation(alloc_candidate, min_memory))] + for idx in range(1, len(sorted_dev_allocs) - n_dev + 1): + # update amount of allocatable space + allocatable -= \ + self.device_slots[sorted_dev_allocs[idx - 1][0]].amount - \ + sorted_dev_allocs[idx - 1][1] + allocatable += \ + self.device_slots[sorted_dev_allocs[idx + n_dev - 1][0]].amount - \ + sorted_dev_allocs[idx + n_dev - 1][1] + # break if not enough resource + if allocatable.quantize(self.digits) < \ + remaining_alloc.quantize(self.digits): + break + alloc_candidate = allocate_across_devices( + sorted_dev_allocs[idx:idx + n_dev], remaining_alloc, slot_name) + # evenness gets worse (or same at best) as the allocatable gets smaller + evenness_score = measure_evenness(alloc_candidate) + if evenness_score < max_evenness: + break + alloc_candidate_arr.append((alloc_candidate, evenness_score, + -len(alloc_candidate), -measure_fragmentation(alloc_candidate, min_memory))) + # since evenness is the same, sort by fragmentation (low is good) + best_alloc_candidate_arr.append( + sorted(alloc_candidate_arr, key=lambda x: x[2])[-1]) + # there is no need to look at more devices if the desired evenness is achieved + if max_evenness.quantize(self.digits) == self.digits: + best_alloc_candidate_arr = best_alloc_candidate_arr[-1:] + break + # choose the best allocation with the three criteria + slot_allocation = sorted(best_alloc_candidate_arr, + key=operator.itemgetter(1, 2, 3))[-1][0] + allocation[slot_name] = slot_allocation + for dev_id, value in slot_allocation.items(): + self.allocations[slot_name][dev_id] += value + return allocation + + def apply_allocation( + self, + existing_alloc: Mapping[SlotName, Mapping[DeviceId, Decimal]], + ) -> None: + for slot_name, per_device_alloc in existing_alloc.items(): + for device_id, alloc in per_device_alloc.items(): + self.allocations[slot_name][device_id] += alloc + + def free( + self, + existing_alloc: Mapping[SlotName, Mapping[DeviceId, Decimal]], + ) -> None: + for slot_name, per_device_alloc in existing_alloc.items(): + for device_id, alloc in per_device_alloc.items(): + self.allocations[slot_name][device_id] -= alloc diff --git a/src/ai/backend/agent/server.py b/src/ai/backend/agent/server.py new file mode 100644 index 0000000000..ccfb2c0989 --- /dev/null +++ b/src/ai/backend/agent/server.py 
@@ -0,0 +1,795 @@ +from __future__ import annotations + +import asyncio +import functools +import importlib +from ipaddress import ip_network, _BaseAddress as BaseIPAddress +import logging, logging.config +import os, os.path +from pathlib import Path +from pprint import pformat, pprint +import shutil +import signal +import sys +from typing import ( + Any, + AsyncGenerator, + Callable, + ClassVar, + Coroutine, + Dict, + Literal, + Mapping, + Sequence, + Set, + Tuple, + cast, +) +from uuid import UUID + +import aiomonitor +import aiotools +from aiotools import aclosing +from callosum.rpc import Peer, RPCMessage +from callosum.ordering import ExitOrderedAsyncScheduler +from callosum.lower.zeromq import ZeroMQAddress, ZeroMQRPCTransport +import click +from etcetra.types import WatchEventType +from setproctitle import setproctitle +import tomlkit +from trafaret.dataerror import DataError as TrafaretDataError + +from ai.backend.common import config, utils, identity, msgpack +from ai.backend.common.etcd import AsyncEtcd, ConfigScopes +from ai.backend.common.logging import Logger, BraceStyleAdapter +from ai.backend.common.plugin.monitor import ErrorPluginContext, StatsPluginContext +from ai.backend.common.types import ( + HardwareMetadata, aobject, + ClusterInfo, + HostPortPair, + KernelId, + KernelCreationConfig, + SessionId, +) +from ai.backend.common.utils import current_loop +from . import __version__ as VERSION +from .agent import AbstractAgent +from .config import ( + agent_local_config_iv, + agent_etcd_config_iv, + docker_extra_config_iv, + container_etcd_config_iv, +) +from .exception import ResourceError +from .types import AgentBackend, VolumeInfo, LifecycleEvent +from .utils import get_subnet_ip + +log = BraceStyleAdapter(logging.getLogger('ai.backend.agent.server')) + +deeplearning_image_keys = { + 'tensorflow', 'caffe', + 'keras', 'torch', + 'mxnet', 'theano', +} + +deeplearning_sample_volume = VolumeInfo( + 'deeplearning-samples', '/home/work/samples', 'ro', +) + +agent_instance: AgentRPCServer + + +async def get_extra_volumes(docker, lang): + avail_volumes = (await docker.volumes.list())['Volumes'] + if not avail_volumes: + return [] + avail_volume_names = set(v['Name'] for v in avail_volumes) + + # deeplearning specialization + # TODO: extract as config + volume_list = [] + for k in deeplearning_image_keys: + if k in lang: + volume_list.append(deeplearning_sample_volume) + break + + # Mount only actually existing volumes + mount_list = [] + for vol in volume_list: + if vol.name in avail_volume_names: + mount_list.append(vol) + else: + log.info('skipped attaching extra volume {0} ' + 'to a kernel based on image {1}', + vol.name, lang) + return mount_list + + +def collect_error(meth: Callable) -> Callable: + @functools.wraps(meth) + async def _inner(self: AgentRPCServer, *args, **kwargs): + try: + return await meth(self, *args, **kwargs) + except Exception: + await self.agent.produce_error_event() + raise + return _inner + + +class RPCFunctionRegistry: + + functions: Set[str] + + def __init__(self) -> None: + self.functions = set() + + def __call__( + self, + meth: Callable[..., Coroutine[None, None, Any]], + ) -> Callable[[AgentRPCServer, RPCMessage], Coroutine[None, None, Any]]: + + @functools.wraps(meth) + async def _inner(self_: AgentRPCServer, request: RPCMessage) -> Any: + try: + if request.body is None: + return await meth(self_) + else: + return await meth( + self_, + *request.body['args'], + **request.body['kwargs'], + ) + except (asyncio.CancelledError, 
asyncio.TimeoutError): + raise + except ResourceError: + # This is an expected scenario. + raise + except Exception: + log.exception('unexpected error') + await self_.error_monitor.capture_exception() + raise + + self.functions.add(meth.__name__) + return _inner + + +class AgentRPCServer(aobject): + rpc_function: ClassVar[RPCFunctionRegistry] = RPCFunctionRegistry() + + loop: asyncio.AbstractEventLoop + agent: AbstractAgent + rpc_server: Peer + rpc_addr: str + agent_addr: str + + _stop_signal: signal.Signals + + def __init__( + self, + etcd: AsyncEtcd, + local_config: Mapping[str, Any], + *, + skip_detect_manager: bool = False, + ) -> None: + self.loop = current_loop() + self.etcd = etcd + self.local_config = local_config + self.skip_detect_manager = skip_detect_manager + self._stop_signal = signal.SIGTERM + + async def __ainit__(self) -> None: + # Start serving requests. + await self.update_status('starting') + + if not self.skip_detect_manager: + await self.detect_manager() + + await self.read_agent_config() + await self.read_agent_config_container() + + self.stats_monitor = StatsPluginContext(self.etcd, self.local_config) + self.error_monitor = ErrorPluginContext(self.etcd, self.local_config) + await self.stats_monitor.init() + await self.error_monitor.init() + + backend = self.local_config['agent']['backend'] + agent_mod = importlib.import_module(f"ai.backend.agent.{backend.value}") + self.agent = await agent_mod.get_agent_cls().new( # type: ignore + self.etcd, + self.local_config, + stats_monitor=self.stats_monitor, + error_monitor=self.error_monitor, + ) + + rpc_addr = self.local_config['agent']['rpc-listen-addr'] + self.rpc_server = Peer( + bind=ZeroMQAddress(f"tcp://{rpc_addr}"), + transport=ZeroMQRPCTransport, + scheduler=ExitOrderedAsyncScheduler(), + serializer=msgpack.packb, + deserializer=msgpack.unpackb, + debug_rpc=self.local_config['debug']['enabled'], + ) + for func_name in self.rpc_function.functions: + self.rpc_server.handle_function(func_name, getattr(self, func_name)) + log.info('started handling RPC requests at {}', rpc_addr) + + await self.etcd.put('ip', rpc_addr.host, scope=ConfigScopes.NODE) + watcher_port = utils.nmget(self.local_config, 'watcher.service-addr.port', None) + if watcher_port is not None: + await self.etcd.put('watcher_port', watcher_port, scope=ConfigScopes.NODE) + + await self.update_status('running') + + async def detect_manager(self): + log.info('detecting the manager...') + manager_instances = await self.etcd.get_prefix('nodes/manager') + if not manager_instances: + log.warning('watching etcd to wait for the manager being available') + async with aclosing(self.etcd.watch_prefix('nodes/manager')) as agen: + async for ev in agen: + if ev.event == WatchEventType.PUT and ev.value == 'up': + break + log.info('detected at least one manager running') + + async def read_agent_config(self): + # Fill up Redis configs from etcd. + self.local_config['redis'] = config.redis_config_iv.check( + await self.etcd.get_prefix('config/redis'), + ) + log.info('configured redis_addr: {0}', self.local_config['redis']['addr']) + + # Fill up vfolder configs from etcd. 
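+        # The "volumes" prefix is expected to provide at least the "mount" and
+        # "fsprefix" keys consumed below, e.g. (illustrative values only):
+        #   volumes/mount    = "/mnt"
+        #   volumes/fsprefix = "vfolder"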
+ self.local_config['vfolder'] = config.vfolder_config_iv.check( + await self.etcd.get_prefix('volumes'), + ) + if self.local_config['vfolder']['mount'] is None: + log.info('assuming use of storage-proxy since vfolder mount path is not configured in etcd') + else: + log.info('configured vfolder mount base: {0}', self.local_config['vfolder']['mount']) + log.info('configured vfolder fs prefix: {0}', self.local_config['vfolder']['fsprefix']) + + # Fill up shared agent configurations from etcd. + agent_etcd_config = agent_etcd_config_iv.check( + await self.etcd.get_prefix('config/agent'), + ) + for k, v in agent_etcd_config.items(): + self.local_config['agent'][k] = v + + async def read_agent_config_container(self): + # Fill up global container configurations from etcd. + try: + container_etcd_config = container_etcd_config_iv.check( + await self.etcd.get_prefix('config/container'), + ) + except TrafaretDataError as etrafa: + log.warning("etcd: container-config error: {}".format(etrafa)) + container_etcd_config = {} + for k, v in container_etcd_config.items(): + self.local_config['container'][k] = v + log.info("etcd: container-config: {}={}".format(k, v)) + + async def __aenter__(self) -> None: + await self.rpc_server.__aenter__() + + def mark_stop_signal(self, stop_signal: signal.Signals) -> None: + self._stop_signal = stop_signal + + async def __aexit__(self, *exc_info) -> None: + # Stop receiving further requests. + await self.rpc_server.__aexit__(*exc_info) + await self.agent.shutdown(self._stop_signal) + await self.stats_monitor.cleanup() + await self.error_monitor.cleanup() + + @collect_error + async def update_status(self, status): + await self.etcd.put('', status, scope=ConfigScopes.NODE) + + @rpc_function + @collect_error + async def update_scaling_group(self, scaling_group): + cfg_src_path = config.find_config_file('agent') + with open(cfg_src_path, 'r') as f: + data = tomlkit.load(f) + data['agent']['scaling-group'] = scaling_group + shutil.copy(cfg_src_path, f"{cfg_src_path}.bak") + with open(cfg_src_path, 'w') as f: + tomlkit.dump(data, f) + self.local_config['agent']['scaling-group'] = scaling_group + log.info('rpc::update_scaling_group()') + + @rpc_function + @collect_error + async def ping(self, msg: str) -> str: + log.debug('rpc::ping()') + return msg + + @rpc_function + @collect_error + async def gather_hwinfo(self) -> Mapping[str, HardwareMetadata]: + log.debug('rpc::gather_hwinfo()') + return await self.agent.gather_hwinfo() + + @rpc_function + @collect_error + async def ping_kernel(self, kernel_id: str): + log.debug('rpc::ping_kernel({0})', kernel_id) + + @rpc_function + @collect_error + async def create_kernels( + self, + creation_id: str, + raw_session_id: str, + raw_kernel_ids: Sequence[str], + raw_configs: Sequence[dict], + raw_cluster_info: dict, + ): + cluster_info = cast(ClusterInfo, raw_cluster_info) + session_id = SessionId(UUID(raw_session_id)) + raw_results = [] + coros = [] + for raw_kernel_id, raw_config in zip(raw_kernel_ids, raw_configs): + log.info('rpc::create_kernel(k:{0}, img:{1})', + raw_kernel_id, raw_config['image']['canonical']) + kernel_id = KernelId(UUID(raw_kernel_id)) + kernel_config = cast(KernelCreationConfig, raw_config) + coros.append(self.agent.create_kernel( + creation_id, + session_id, + kernel_id, + kernel_config, + cluster_info, + )) + results = await asyncio.gather(*coros, return_exceptions=True) + errors = [*filter(lambda item: isinstance(item, Exception), results)] + if errors: + # Raise up the first error. 
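+            # A single failure is re-raised as-is so that the caller sees the
+            # original exception type; multiple failures are aggregated into a
+            # TaskGroupError below so that none of them is silently dropped.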
+ if len(errors) == 1: + raise errors[0] + raise aiotools.TaskGroupError("agent.create_kernels() failed", errors) + raw_results = [ + { + 'id': str(result['id']), + 'kernel_host': result['kernel_host'], + 'repl_in_port': result['repl_in_port'], + 'repl_out_port': result['repl_out_port'], + 'stdin_port': result['stdin_port'], # legacy + 'stdout_port': result['stdout_port'], # legacy + 'service_ports': result['service_ports'], + 'container_id': result['container_id'], + 'resource_spec': result['resource_spec'], + 'attached_devices': result['attached_devices'], + } + for result in results + ] + return raw_results + + @rpc_function + @collect_error + async def destroy_kernel( + self, + kernel_id: str, + reason: str = None, + suppress_events: bool = False, + ): + loop = asyncio.get_running_loop() + done = loop.create_future() + log.info('rpc::destroy_kernel(k:{0})', kernel_id) + await self.agent.inject_container_lifecycle_event( + KernelId(UUID(kernel_id)), + LifecycleEvent.DESTROY, + reason or 'user-requested', + done_future=done, + suppress_events=suppress_events, + ) + return await done + + @rpc_function + @collect_error + async def interrupt_kernel(self, kernel_id: str): + log.info('rpc::interrupt_kernel(k:{0})', kernel_id) + await self.agent.interrupt_kernel(KernelId(UUID(kernel_id))) + + @rpc_function + @collect_error + async def get_completions(self, kernel_id: str, + text: str, opts: dict): + log.debug('rpc::get_completions(k:{0}, ...)', kernel_id) + await self.agent.get_completions(KernelId(UUID(kernel_id)), text, opts) + + @rpc_function + @collect_error + async def get_logs(self, kernel_id: str): + log.info('rpc::get_logs(k:{0})', kernel_id) + return await self.agent.get_logs(KernelId(UUID(kernel_id))) + + @rpc_function + @collect_error + async def restart_kernel( + self, + creation_id: str, + session_id: str, + kernel_id: str, + updated_config: dict, + ): + log.info('rpc::restart_kernel(s:{0}, k:{1})', session_id, kernel_id) + return await self.agent.restart_kernel( + creation_id, + SessionId(UUID(session_id)), + KernelId(UUID(kernel_id)), + cast(KernelCreationConfig, updated_config), + ) + + @rpc_function + @collect_error + async def execute( + self, + kernel_id, # type: str + api_version, # type: int + run_id, # type: str + mode, # type: Literal['query', 'batch', 'continue', 'input'] + code, # type: str + opts, # type: Dict[str, Any] + flush_timeout, # type: float + ): + # type: (...) -> Dict[str, Any] + if mode != 'continue': + log.info('rpc::execute(k:{0}, run-id:{1}, mode:{2}, code:{3!r})', + kernel_id, run_id, mode, + code[:20] + '...' if len(code) > 20 else code) + result = await self.agent.execute( + KernelId(UUID(kernel_id)), + run_id, + mode, + code, + opts=opts, + api_version=api_version, + flush_timeout=flush_timeout, + ) + return result + + @rpc_function + @collect_error + async def execute_batch( + self, + kernel_id, # type: str + startup_command, # type: str + ) -> None: + # DEPRECATED + asyncio.create_task(self.agent.execute_batch( + KernelId(UUID(kernel_id)), + startup_command, + )) + await asyncio.sleep(0) + + @rpc_function + @collect_error + async def start_service( + self, + kernel_id, # type: str + service, # type: str + opts, # type: Dict[str, Any] + ): + # type: (...) 
-> Dict[str, Any] + log.info('rpc::start_service(k:{0}, app:{1})', kernel_id, service) + return await self.agent.start_service(KernelId(UUID(kernel_id)), service, opts) + + @rpc_function + @collect_error + async def shutdown_service( + self, + kernel_id, # type: str + service, # type: str + ): + log.info('rpc::shutdown_service(k:{0}, app:{1})', kernel_id, service) + return await self.agent.shutdown_service(KernelId(UUID(kernel_id)), service) + + @rpc_function + @collect_error + async def upload_file(self, kernel_id: str, filename: str, filedata: bytes): + log.info('rpc::upload_file(k:{0}, fn:{1})', kernel_id, filename) + await self.agent.accept_file(KernelId(UUID(kernel_id)), filename, filedata) + + @rpc_function + @collect_error + async def download_file(self, kernel_id: str, filepath: str): + log.info('rpc::download_file(k:{0}, fn:{1})', kernel_id, filepath) + return await self.agent.download_file(KernelId(UUID(kernel_id)), filepath) + + @rpc_function + @collect_error + async def list_files(self, kernel_id: str, path: str): + log.info('rpc::list_files(k:{0}, fn:{1})', kernel_id, path) + return await self.agent.list_files(KernelId(UUID(kernel_id)), path) + + @rpc_function + @collect_error + async def shutdown_agent(self, terminate_kernels: bool): + # TODO: implement + log.info('rpc::shutdown_agent()') + pass + + @rpc_function + @collect_error + async def create_overlay_network(self, network_name: str) -> None: + log.debug('rpc::create_overlay_network(name:{})', network_name) + return await self.agent.create_overlay_network(network_name) + + @rpc_function + @collect_error + async def destroy_overlay_network(self, network_name: str) -> None: + log.debug('rpc::destroy_overlay_network(name:{})', network_name) + return await self.agent.destroy_overlay_network(network_name) + + @rpc_function + @collect_error + async def create_local_network(self, network_name: str) -> None: + log.debug('rpc::create_local_network(name:{})', network_name) + return await self.agent.create_local_network(network_name) + + @rpc_function + @collect_error + async def destroy_local_network(self, network_name: str) -> None: + log.debug('rpc::destroy_local_network(name:{})', network_name) + return await self.agent.destroy_local_network(network_name) + + @rpc_function + @collect_error + async def reset_agent(self): + log.debug('rpc::reset()') + kernel_ids = tuple(self.agent.kernel_registry.keys()) + tasks = [] + for kernel_id in kernel_ids: + try: + task = asyncio.ensure_future( + self.agent.destroy_kernel(kernel_id, 'agent-reset')) + tasks.append(task) + except Exception: + await self.error_monitor.capture_exception() + log.exception('reset: destroying {0}', kernel_id) + await asyncio.gather(*tasks) + + +@aiotools.server +async def server_main_logwrapper( + loop: asyncio.AbstractEventLoop, + pidx: int, + _args: Tuple[Any, ...], +) -> AsyncGenerator[None, signal.Signals]: + setproctitle(f"backend.ai: agent worker-{pidx}") + log_endpoint = _args[1] + logger = Logger(_args[0]['logging'], is_master=False, log_endpoint=log_endpoint) + with logger: + async with server_main(loop, pidx, _args): + yield + + +@aiotools.server +async def server_main( + loop: asyncio.AbstractEventLoop, + pidx: int, + _args: Tuple[Any, ...], +) -> AsyncGenerator[None, signal.Signals]: + local_config = _args[0] + + log.info('Preparing kernel runner environments...') + kernel_mod = importlib.import_module( + f"ai.backend.agent.{local_config['agent']['backend'].value}.kernel", + ) + krunner_volumes = await kernel_mod.prepare_krunner_env(local_config) # 
type: ignore + # TODO: merge k8s branch: nfs_mount_path = local_config['baistatic']['mounted-at'] + log.info('Kernel runner environments: {}', [*krunner_volumes.keys()]) + local_config['container']['krunner-volumes'] = krunner_volumes + + if not local_config['agent']['id']: + local_config['agent']['id'] = await identity.get_instance_id() + if not local_config['agent']['instance-type']: + local_config['agent']['instance-type'] = await identity.get_instance_type() + + etcd_credentials = None + if local_config['etcd']['user']: + etcd_credentials = { + 'user': local_config['etcd']['user'], + 'password': local_config['etcd']['password'], + } + scope_prefix_map = { + ConfigScopes.GLOBAL: '', + ConfigScopes.SGROUP: f"sgroup/{local_config['agent']['scaling-group']}", + ConfigScopes.NODE: f"nodes/agents/{local_config['agent']['id']}", + } + etcd = AsyncEtcd(local_config['etcd']['addr'], + local_config['etcd']['namespace'], + scope_prefix_map, + credentials=etcd_credentials) + + rpc_addr = local_config['agent']['rpc-listen-addr'] + if not rpc_addr.host: + _subnet_hint = await etcd.get('config/network/subnet/agent') + subnet_hint = None + if _subnet_hint is not None: + subnet_hint = ip_network(_subnet_hint) + log.debug('auto-detecting agent host') + local_config['agent']['rpc-listen-addr'] = HostPortPair( + await identity.get_instance_ip(subnet_hint), + rpc_addr.port, + ) + if 'kernel-host' in local_config['container']: + log.warning("The configuration parameter `container.kernel-host` is deprecated; " + "use `container.bind-host` instead!") + # fallback for legacy configs + local_config['container']['bind-host'] = local_config['container']['kernel-host'] + if not local_config['container']['bind-host']: + log.debug("auto-detecting `container.bind-host` from container subnet config " + "and agent.rpc-listen-addr") + local_config['container']['bind-host'] = await get_subnet_ip( + etcd, 'container', fallback_addr=local_config['agent']['rpc-listen-addr'].host, + ) + log.info('Agent external IP: {}', local_config['agent']['rpc-listen-addr'].host) + log.info('Container external IP: {}', local_config['container']['bind-host']) + if not local_config['agent']['region']: + local_config['agent']['region'] = await identity.get_instance_region() + log.info('Node ID: {0} (machine-type: {1}, host: {2})', + local_config['agent']['id'], + local_config['agent']['instance-type'], + rpc_addr.host) + + # Pre-load compute plugin configurations. + local_config['plugins'] = await etcd.get_prefix_dict('config/plugins/accelerator') + + # Start aiomonitor. + # Port is set by config (default=50002). + monitor = aiomonitor.Monitor( + loop, + port=local_config['agent']['aiomonitor-port'], + console_enabled=False, + ) + monitor.prompt = "monitor (agent) >>> " + monitor.start() + + # Start RPC server. + global agent_instance + agent = await AgentRPCServer.new( + etcd, local_config, + skip_detect_manager=local_config['agent']['skip-manager-detection'], + ) + agent_instance = agent + + # Run! + try: + async with agent: + stop_signal = yield + agent.mark_stop_signal(stop_signal) + finally: + monitor.close() + + +@click.group(invoke_without_command=True) +@click.option('-f', '--config-path', '--config', type=Path, default=None, + help='The config file path. 
' + '(default: ./agent.conf and /etc/backend.ai/agent.conf)') +@click.option('--debug', is_flag=True, + help='Enable the debug mode and override the global log level to DEBUG.') +@click.pass_context +def main( + cli_ctx: click.Context, + config_path: Path, + debug: bool, +) -> int: + + # Determine where to read configuration. + raw_cfg, cfg_src_path = config.read_from_file(config_path, 'agent') + + # Override the read config with environment variables (for legacy). + config.override_with_env(raw_cfg, ('etcd', 'namespace'), 'BACKEND_NAMESPACE') + config.override_with_env(raw_cfg, ('etcd', 'addr'), 'BACKEND_ETCD_ADDR') + config.override_with_env(raw_cfg, ('etcd', 'user'), 'BACKEND_ETCD_USER') + config.override_with_env(raw_cfg, ('etcd', 'password'), 'BACKEND_ETCD_PASSWORD') + config.override_with_env(raw_cfg, ('agent', 'rpc-listen-addr', 'host'), + 'BACKEND_AGENT_HOST_OVERRIDE') + config.override_with_env(raw_cfg, ('agent', 'rpc-listen-addr', 'port'), + 'BACKEND_AGENT_PORT') + config.override_with_env(raw_cfg, ('agent', 'pid-file'), 'BACKEND_PID_FILE') + config.override_with_env(raw_cfg, ('container', 'port-range'), + 'BACKEND_CONTAINER_PORT_RANGE') + config.override_with_env(raw_cfg, ('container', 'bind-host'), + 'BACKEND_BIND_HOST_OVERRIDE') + config.override_with_env(raw_cfg, ('container', 'sandbox-type'), 'BACKEND_SANDBOX_TYPE') + config.override_with_env(raw_cfg, ('container', 'scratch-root'), 'BACKEND_SCRATCH_ROOT') + if debug: + config.override_key(raw_cfg, ('debug', 'enabled'), True) + config.override_key(raw_cfg, ('logging', 'level'), 'DEBUG') + config.override_key(raw_cfg, ('logging', 'pkg-ns', 'ai.backend'), 'DEBUG') + + # Validate and fill configurations + # (allow_extra will make configs to be forward-copmatible) + try: + cfg = config.check(raw_cfg, agent_local_config_iv) + if cfg['agent']['backend'] == AgentBackend.KUBERNETES: + if cfg['container']['scratch-type'] == 'k8s-nfs' and \ + (cfg['container']['scratch-nfs-address'] is None + or cfg['container']['scratch-nfs-options'] is None): + raise ValueError('scratch-nfs-address and scratch-nfs-options are required for k8s-nfs') + if cfg['agent']['backend'] == AgentBackend.DOCKER: + config.check(raw_cfg, docker_extra_config_iv) + if 'debug' in cfg and cfg['debug']['enabled']: + print('== Agent configuration ==') + pprint(cfg) + cfg['_src'] = cfg_src_path + except config.ConfigurationError as e: + print('ConfigurationError: Validation of agent configuration has failed:', file=sys.stderr) + print(pformat(e.invalid_data), file=sys.stderr) + raise click.Abort() + + rpc_host = cfg['agent']['rpc-listen-addr'].host + if (isinstance(rpc_host, BaseIPAddress) and + (rpc_host.is_unspecified or rpc_host.is_link_local)): + print('ConfigurationError: ' + 'Cannot use link-local or unspecified IP address as the RPC listening host.', + file=sys.stderr) + raise click.Abort() + + if os.getuid() != 0 and cfg['container']['stats-type'] == 'cgroup': + print('Cannot use cgroup statistics collection mode unless the agent runs as root.', + file=sys.stderr) + raise click.Abort() + + if cli_ctx.invoked_subcommand is None: + + if cfg['debug']['coredump']['enabled']: + if not sys.platform.startswith('linux'): + print('ConfigurationError: ' + 'Storing container coredumps is only supported in Linux.', + file=sys.stderr) + raise click.Abort() + core_pattern = Path('/proc/sys/kernel/core_pattern').read_text().strip() + if core_pattern.startswith('|') or not core_pattern.startswith('/'): + print('ConfigurationError: ' + '/proc/sys/kernel/core_pattern must be an 
absolute path ' + 'to enable container coredumps.', + file=sys.stderr) + raise click.Abort() + cfg['debug']['coredump']['core_path'] = Path(core_pattern).parent + + cfg['agent']['pid-file'].write_text(str(os.getpid())) + ipc_base_path = cfg['agent']['ipc-base-path'] + log_sockpath = ipc_base_path / f'agent-logger-{os.getpid()}.sock' + log_sockpath.parent.mkdir(parents=True, exist_ok=True) + log_endpoint = f'ipc://{log_sockpath}' + cfg['logging']['endpoint'] = log_endpoint + try: + logger = Logger(cfg['logging'], is_master=True, log_endpoint=log_endpoint) + with logger: + ns = cfg['etcd']['namespace'] + setproctitle(f"backend.ai: agent {ns}") + log.info('Backend.AI Agent {0}', VERSION) + log.info('runtime: {0}', utils.env_info()) + + log_config = logging.getLogger('ai.backend.agent.config') + if debug: + log_config.debug('debug mode enabled.') + + if cfg['agent']['event-loop'] == 'uvloop': + import uvloop + uvloop.install() + log.info('Using uvloop as the event loop backend') + aiotools.start_server( + server_main_logwrapper, + num_workers=1, + args=(cfg, log_endpoint), + wait_timeout=5.0, + ) + log.info('exit.') + finally: + if cfg['agent']['pid-file'].is_file(): + # check is_file() to prevent deleting /dev/null! + cfg['agent']['pid-file'].unlink() + else: + # Click is going to invoke a subcommand. + pass + return 0 + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/src/ai/backend/agent/stats.py b/src/ai/backend/agent/stats.py new file mode 100644 index 0000000000..6c32a687ac --- /dev/null +++ b/src/ai/backend/agent/stats.py @@ -0,0 +1,428 @@ +""" +A module to collect various performance metrics of Docker containers. + +Reference: https://www.datadoghq.com/blog/how-to-collect-docker-metrics/ +""" + +import asyncio +from decimal import Decimal +import enum +import logging +import sys +import time +from typing import ( + Callable, + Dict, + FrozenSet, + List, + Mapping, + MutableMapping, + Optional, + Sequence, + Set, + Tuple, + TYPE_CHECKING, +) +import aioredis + +import attr + +from ai.backend.common import redis +from ai.backend.common.identity import is_containerized +from ai.backend.common.logging import BraceStyleAdapter +from ai.backend.common import msgpack +from ai.backend.common.types import ( + ContainerId, DeviceId, KernelId, + MetricKey, MetricValue, MovingStatValue, +) +from .utils import ( + remove_exponent, +) +if TYPE_CHECKING: + from .agent import AbstractAgent + +__all__ = ( + 'StatContext', + 'StatModes', + 'MetricTypes', + 'NodeMeasurement', + 'ContainerMeasurement', + 'Measurement', +) + +log = BraceStyleAdapter(logging.getLogger('ai.backend.agent.stats')) + + +def check_cgroup_available(): + """ + Check if the host OS provides cgroups. + """ + return (not is_containerized() and sys.platform.startswith('linux')) + + +class StatModes(enum.Enum): + CGROUP = 'cgroup' + DOCKER = 'docker' + + @staticmethod + def get_preferred_mode(): + """ + Returns the most preferred statistics collector type for the host OS. 
+ """ + if check_cgroup_available(): + return StatModes.CGROUP + return StatModes.DOCKER + + +class MetricTypes(enum.Enum): + USAGE = 0 # for instant snapshot (e.g., used memory bytes, used cpu msec) + RATE = 1 # for rate of increase (e.g., I/O bps) + UTILIZATION = 2 # for ratio of resource occupation time per measurement interval (e.g., CPU util) + ACCUMULATED = 3 # for accumulated value (e.g., total number of events) + + +@attr.s(auto_attribs=True, slots=True) +class Measurement: + value: Decimal + capacity: Optional[Decimal] = None + + +@attr.s(auto_attribs=True, slots=True) +class NodeMeasurement: + """ + Collection of per-node and per-agent statistics for a specific metric. + """ + # 2-tuple of Decimals mean raw values for (usage, available) + # Percent values are calculated from them. + key: str + type: MetricTypes + per_node: Measurement + per_device: Mapping[DeviceId, Measurement] = attr.Factory(dict) + unit_hint: Optional[str] = None + stats_filter: FrozenSet[str] = attr.Factory(frozenset) + current_hook: Optional[Callable[['Metric'], Decimal]] = None + + +@attr.s(auto_attribs=True, slots=True) +class ContainerMeasurement: + """ + Collection of per-container statistics for a specific metric. + """ + key: str + type: MetricTypes + per_container: Mapping[str, Measurement] = attr.Factory(dict) + unit_hint: Optional[str] = None + stats_filter: FrozenSet[str] = attr.Factory(frozenset) + current_hook: Optional[Callable[['Metric'], Decimal]] = None + + +class MovingStatistics: + __slots__ = ( + '_sum', '_count', + '_min', '_max', '_last', + ) + _sum: Decimal + _count: int + _min: Decimal + _max: Decimal + _last: List[Tuple[Decimal, float]] + + def __init__(self, initial_value: Decimal = None): + self._last = [] + if initial_value is None: + self._sum = Decimal(0) + self._min = Decimal('inf') + self._max = Decimal('-inf') + self._count = 0 + else: + self._sum = initial_value + self._min = initial_value + self._max = initial_value + self._count = 1 + point = (initial_value, time.perf_counter()) + self._last.append(point) + + def update(self, value: Decimal): + self._sum += value + self._min = min(self._min, value) + self._max = max(self._max, value) + self._count += 1 + point = (value, time.perf_counter()) + self._last.append(point) + # keep only the latest two data points + if len(self._last) > 2: + self._last.pop(0) + + @property + def min(self) -> Decimal: + return self._min + + @property + def max(self) -> Decimal: + return self._max + + @property + def sum(self) -> Decimal: + return self._sum + + @property + def avg(self) -> Decimal: + return self._sum / self._count + + @property + def diff(self) -> Decimal: + if len(self._last) == 2: + return self._last[-1][0] - self._last[-2][0] + return Decimal(0) + + @property + def rate(self) -> Decimal: + if len(self._last) == 2: + return ((self._last[-1][0] - self._last[-2][0]) / + Decimal(self._last[-1][1] - self._last[-2][1])) + return Decimal(0) + + def to_serializable_dict(self) -> MovingStatValue: + q = Decimal('0.000') + return { + 'min': str(remove_exponent(self.min.quantize(q))), + 'max': str(remove_exponent(self.max.quantize(q))), + 'sum': str(remove_exponent(self.sum.quantize(q))), + 'avg': str(remove_exponent(self.avg.quantize(q))), + 'diff': str(remove_exponent(self.diff.quantize(q))), + 'rate': str(remove_exponent(self.rate.quantize(q))), + 'version': 2, + } + + +@attr.s(auto_attribs=True, slots=True) +class Metric: + key: str + type: MetricTypes + stats: MovingStatistics + stats_filter: FrozenSet[str] + current: Decimal + 
capacity: Optional[Decimal] = None + unit_hint: Optional[str] = None + current_hook: Optional[Callable[['Metric'], Decimal]] = None + + def update(self, value: Measurement): + if value.capacity is not None: + self.capacity = value.capacity + self.stats.update(value.value) + self.current = value.value + if self.current_hook is not None: + self.current = self.current_hook(self) + + def to_serializable_dict(self) -> MetricValue: + q = Decimal('0.000') + q_pct = Decimal('0.00') + return { + 'current': str(remove_exponent(self.current.quantize(q))), + 'capacity': (str(remove_exponent(self.capacity.quantize(q))) + if self.capacity is not None else None), + 'pct': ( + str(remove_exponent( + (Decimal(self.current) / Decimal(self.capacity) * 100).quantize(q_pct))) + if (self.capacity is not None and + self.capacity.is_normal() and + self.capacity > 0) + else None), + 'unit_hint': self.unit_hint, + **{f'stats.{k}': v # type: ignore + for k, v in self.stats.to_serializable_dict().items() + if k in self.stats_filter}, + } + + +class StatContext: + + agent: 'AbstractAgent' + mode: StatModes + node_metrics: Mapping[MetricKey, Metric] + device_metrics: Mapping[MetricKey, MutableMapping[DeviceId, Metric]] + kernel_metrics: MutableMapping[KernelId, MutableMapping[MetricKey, Metric]] + + def __init__(self, agent: 'AbstractAgent', mode: StatModes = None, *, + cache_lifespan: int = 120) -> None: + self.agent = agent + self.mode = mode if mode is not None else StatModes.get_preferred_mode() + self.cache_lifespan = cache_lifespan + + self.node_metrics = {} + self.device_metrics = {} + self.kernel_metrics = {} + + self._lock = asyncio.Lock() + self._timestamps: MutableMapping[str, float] = {} + + def update_timestamp(self, timestamp_key: str) -> Tuple[float, float]: + """ + Update the timestamp for the given key and return a pair of the current timestamp and + the interval from the last update of the same key. + + If the last timestamp for the given key does not exist, the interval becomes "NaN". + + Intended to be used by compute plugins. + """ + now = time.perf_counter() + last = self._timestamps.get(timestamp_key, None) + self._timestamps[timestamp_key] = now + if last is None: + return now, float('NaN') + return now, now - last + + async def collect_node_stat(self): + """ + Collect the per-node, per-device, and per-container statistics. + + Intended to be used by the agent. + """ + async with self._lock: + # Here we use asyncio.gather() instead of aiotools.TaskGroup + # to keep methods of other plugins running when a plugin raises an error + # instead of cancelling them. 
+ _tasks = [] + for computer in self.agent.computers.values(): + _tasks.append(computer.instance.gather_node_measures(self)) + results = await asyncio.gather(*_tasks, return_exceptions=True) + for result in results: + if isinstance(result, Exception): + log.error('collect_node_stat(): gather_node_measures() error', + exc_info=result) + continue + for node_measure in result: + metric_key = node_measure.key + # update node metric + if metric_key not in self.node_metrics: + self.node_metrics[metric_key] = Metric( + metric_key, node_measure.type, + current=node_measure.per_node.value, + capacity=node_measure.per_node.capacity, + unit_hint=node_measure.unit_hint, + stats=MovingStatistics(node_measure.per_node.value), + stats_filter=frozenset(node_measure.stats_filter), + current_hook=node_measure.current_hook, + ) + else: + self.node_metrics[metric_key].update(node_measure.per_node) + # update per-device metric + # NOTE: device IDs are defined by each metric keys. + for dev_id, measure in node_measure.per_device.items(): + dev_id = str(dev_id) + if metric_key not in self.device_metrics: + self.device_metrics[metric_key] = {} + if dev_id not in self.device_metrics[metric_key]: + self.device_metrics[metric_key][dev_id] = Metric( + metric_key, node_measure.type, + current=measure.value, + capacity=measure.capacity, + unit_hint=node_measure.unit_hint, + stats=MovingStatistics(measure.value), + stats_filter=frozenset(node_measure.stats_filter), + current_hook=node_measure.current_hook, + ) + else: + self.device_metrics[metric_key][dev_id].update(measure) + + # push to the Redis server + redis_agent_updates = { + 'node': { + key: obj.to_serializable_dict() + for key, obj in self.node_metrics.items() + }, + 'devices': { + metric_key: {dev_id: obj.to_serializable_dict() + for dev_id, obj in per_device.items()} + for metric_key, per_device in self.device_metrics.items() + }, + } + if self.agent.local_config['debug']['log-stats']: + log.debug('stats: node_updates: {0}: {1}', + self.agent.local_config['agent']['id'], redis_agent_updates['node']) + serialized_agent_updates = msgpack.packb(redis_agent_updates) + + async def _pipe_builder(r: aioredis.Redis): + async with r.pipeline() as pipe: + pipe.set(self.agent.local_config['agent']['id'], serialized_agent_updates) + pipe.expire(self.agent.local_config['agent']['id'], self.cache_lifespan) + await pipe.execute() + + await redis.execute(self.agent.redis_stat_pool, _pipe_builder) + + async def collect_container_stat( + self, + container_ids: Sequence[ContainerId], + ) -> None: + """ + Collect the per-container statistics only, + + Intended to be used by the agent and triggered by container cgroup synchronization processes. + """ + async with self._lock: + kernel_id_map: Dict[ContainerId, KernelId] = {} + for kid, info in self.agent.kernel_registry.items(): + cid = info['container_id'] + kernel_id_map[ContainerId(cid)] = kid + unused_kernel_ids = set(self.kernel_metrics.keys()) - set(kernel_id_map.values()) + for unused_kernel_id in unused_kernel_ids: + log.debug('removing kernel_metric for {}', unused_kernel_id) + self.kernel_metrics.pop(unused_kernel_id, None) + + # Here we use asyncio.gather() instead of aiotools.TaskGroup + # to keep methods of other plugins running when a plugin raises an error + # instead of cancelling them. 
+ _tasks = [] + kernel_id = None + for computer in self.agent.computers.values(): + _tasks.append(asyncio.create_task( + computer.instance.gather_container_measures(self, container_ids), + )) + results = await asyncio.gather(*_tasks, return_exceptions=True) + updated_kernel_ids: Set[KernelId] = set() + for result in results: + if isinstance(result, Exception): + log.error('collect_container_stat(): gather_container_measures() error', + exc_info=result) + continue + for ctnr_measure in result: + metric_key = ctnr_measure.key + # update per-container metric + for cid, measure in ctnr_measure.per_container.items(): + try: + kernel_id = kernel_id_map[cid] + except KeyError: + continue + updated_kernel_ids.add(kernel_id) + if kernel_id not in self.kernel_metrics: + self.kernel_metrics[kernel_id] = {} + if metric_key not in self.kernel_metrics[kernel_id]: + self.kernel_metrics[kernel_id][metric_key] = Metric( + metric_key, ctnr_measure.type, + current=measure.value, + capacity=measure.capacity or measure.value, + unit_hint=ctnr_measure.unit_hint, + stats=MovingStatistics(measure.value), + stats_filter=frozenset(ctnr_measure.stats_filter), + current_hook=ctnr_measure.current_hook, + ) + else: + self.kernel_metrics[kernel_id][metric_key].update(measure) + + async def _pipe_builder(r: aioredis.Redis): + async with r.pipeline() as pipe: + for kernel_id in updated_kernel_ids: + metrics = self.kernel_metrics[kernel_id] + serializable_metrics = { + key: obj.to_serializable_dict() + for key, obj in metrics.items() + } + if self.agent.local_config['debug']['log-stats']: + log.debug('kernel_updates: {0}: {1}', + kernel_id, serializable_metrics) + serialized_metrics = msgpack.packb(serializable_metrics) + + pipe.set(str(kernel_id), serialized_metrics) + await pipe.execute() + + await redis.execute(self.agent.redis_stat_pool, _pipe_builder) diff --git a/src/ai/backend/agent/types.py b/src/ai/backend/agent/types.py new file mode 100644 index 0000000000..acc2c70451 --- /dev/null +++ b/src/ai/backend/agent/types.py @@ -0,0 +1,83 @@ +import asyncio +import enum +from typing import ( + Any, Optional, + Mapping, + Sequence, +) + +import attr + +from ai.backend.common.types import ( + ContainerId, + KernelId, +) + + +class AgentBackend(enum.Enum): + # The list of importable backend names under "ai.backend.agent" pkg namespace. 
+ DOCKER = 'docker' + KUBERNETES = 'kubernetes' + + +@attr.s(auto_attribs=True, slots=True) +class VolumeInfo: + name: str # volume name + container_path: str # in-container path as str + mode: str # 'rw', 'ro', 'rwm' + + +@attr.s(auto_attribs=True, slots=True) +class Port: + host: str + private_port: int + host_port: int + + +class ContainerStatus(str, enum.Enum): + RUNNING = 'running' + RESTARTING = 'restarting' + PAUSED = 'paused' + EXITED = 'exited' + DEAD = 'dead' + REMOVING = 'removing' + + +@attr.s(auto_attribs=True, slots=True) +class Container: + id: ContainerId + status: ContainerStatus + image: str + labels: Mapping[str, str] + ports: Sequence[Port] + backend_obj: Any # used to keep the backend-specific data + + +class LifecycleEvent(int, enum.Enum): + DESTROY = 0 + CLEAN = 1 + START = 2 + + +@attr.s(auto_attribs=True, slots=True) +class ContainerLifecycleEvent: + kernel_id: KernelId + container_id: Optional[ContainerId] + event: LifecycleEvent + reason: str + done_future: Optional[asyncio.Future] = None + exit_code: Optional[int] = None + suppress_events: bool = False + + def __str__(self): + if self.container_id: + cid = self.container_id[:13] + else: + cid = 'unknown' + return ( + f"LifecycleEvent(" + f"{self.event.name}, " + f"k:{self.kernel_id}, " + f"c:{cid}, " + f"reason:{self.reason!r})" + ) diff --git a/src/ai/backend/agent/utils.py b/src/ai/backend/agent/utils.py new file mode 100644 index 0000000000..4550a0bc76 --- /dev/null +++ b/src/ai/backend/agent/utils.py @@ -0,0 +1,336 @@ +import asyncio +from decimal import Decimal +import hashlib +import io +import ipaddress +import json +import logging +from pathlib import Path +import platform +import re +from typing import ( + Any, + AsyncContextManager, + Iterable, + List, + Mapping, + MutableMapping, + Optional, + Protocol, + Tuple, + Type, + TypeVar, + Union, + overload, +) +from typing_extensions import Final +from uuid import UUID + +import aiodocker +from aiodocker.docker import DockerContainer +import netifaces +import trafaret as t + +from ai.backend.common import identity +from ai.backend.common.etcd import AsyncEtcd +from ai.backend.common.logging import BraceStyleAdapter +from ai.backend.common.types import ( + PID, HostPID, ContainerPID, KernelId, +) +from ai.backend.common.utils import current_loop + +log = BraceStyleAdapter(logging.getLogger('ai.backend.agent.utils')) + +IPNetwork = Union[ipaddress.IPv4Network, ipaddress.IPv6Network] +IPAddress = Union[ipaddress.IPv4Address, ipaddress.IPv6Address] + +InOtherContainerPID: Final = ContainerPID(PID(-2)) +NotContainerPID: Final = ContainerPID(PID(-1)) +NotHostPID: Final = HostPID(PID(-1)) + + +class SupportsAsyncClose(Protocol): + async def close(self) -> None: + ... + + +_SupportsAsyncCloseT = TypeVar('_SupportsAsyncCloseT', bound=SupportsAsyncClose) + + +class closing_async(AsyncContextManager[_SupportsAsyncCloseT]): + """ + contextlib.closing calls close(), and aiotools.aclosing() calls aclose(). + This context manager calls close() as a coroutine. 
+ """ + + def __init__(self, obj: _SupportsAsyncCloseT) -> None: + self.obj = obj + + async def __aenter__(self) -> _SupportsAsyncCloseT: + return self.obj + + async def __aexit__(self, *exc_info) -> None: + await self.obj.close() + + +def generate_local_instance_id(hint: str) -> str: + return hashlib.md5(hint.encode('utf-8')).hexdigest()[:12] + + +def get_arch_name() -> str: + ret = platform.machine().lower() + aliases = { + "arm64": "aarch64", # macOS with LLVM + "amd64": "x86_64", # Windows/Linux + "x64": "x86_64", # Windows + "x32": "x86", # Windows + "i686": "x86", # Windows + } + return aliases.get(ret, ret) + + +def update_nested_dict(dest: MutableMapping, additions: Mapping) -> None: + for k, v in additions.items(): + if k not in dest: + dest[k] = v + else: + if isinstance(dest[k], MutableMapping): + assert isinstance(v, MutableMapping) + update_nested_dict(dest[k], v) + elif isinstance(dest[k], List): + assert isinstance(v, List) + dest[k].extend(v) + else: + dest[k] = v + + +def numeric_list(s: str) -> List[int]: + return [int(p) for p in s.split()] + + +def remove_exponent(num: Decimal) -> Decimal: + return num.quantize(Decimal(1)) if num == num.to_integral() else num.normalize() + + +@overload +def read_sysfs(path: Union[str, Path], type_: Type[bool], default: bool) -> bool: + ... + + +@overload +def read_sysfs(path: Union[str, Path], type_: Type[int], default: int) -> int: + ... + + +@overload +def read_sysfs(path: Union[str, Path], type_: Type[float], default: float) -> float: + ... + + +@overload +def read_sysfs(path: Union[str, Path], type_: Type[str], default: str) -> str: + ... + + +def read_sysfs(path: Union[str, Path], type_: Type[Any], default: Any = None) -> Any: + def_vals: Mapping[Any, Any] = { + bool: False, + int: 0, + float: 0.0, + str: '', + } + if type_ not in def_vals: + raise TypeError('unsupported conversion type from sysfs content') + if default is None: + default = def_vals[type_] + try: + raw_str = Path(path).read_text().strip() + if type_ is bool: + return t.ToBool().check(raw_str) + else: + return type_(raw_str) + except IOError: + return default + + +async def read_tail(path: Path, nbytes: int) -> bytes: + file_size = path.stat().st_size + + def _read_tail() -> bytes: + with open(path, 'rb') as f: + f.seek(max(file_size - nbytes, 0), io.SEEK_SET) + return f.read(nbytes) + + loop = current_loop() + return await loop.run_in_executor(None, _read_tail) + + +async def get_kernel_id_from_container(val: Union[str, DockerContainer]) -> Optional[KernelId]: + if isinstance(val, DockerContainer): + if 'Name' not in val._container: + await val.show() + name = val['Name'] + elif isinstance(val, str): + name = val + name = name.lstrip('/') + if not name.startswith('kernel.'): + return None + try: + return KernelId(UUID(name.rsplit('.', 2)[-1])) + except (IndexError, ValueError): + return None + + +async def get_subnet_ip(etcd: AsyncEtcd, network: str, fallback_addr: str = '0.0.0.0') -> str: + raw_subnet = await etcd.get(f'config/network/subnet/{network}') + if raw_subnet is None: + addr = fallback_addr + else: + subnet = ipaddress.ip_network(raw_subnet) + if subnet.prefixlen == 0: + addr = fallback_addr + else: + local_ipaddrs = [*identity.fetch_local_ipaddrs(subnet)] + log.debug('get_subnet_ip(): subnet {} candidates: {}', + subnet, local_ipaddrs) + if local_ipaddrs: + addr = str(local_ipaddrs[0]) + else: + addr = fallback_addr + return addr + + +async def host_pid_to_container_pid(container_id: str, host_pid: HostPID) -> ContainerPID: + kernel_ver = 
Path('/proc/version').read_text() + if m := re.match(r'Linux version (\d+)\.(\d+)\..*', kernel_ver): # noqa + kernel_ver_tuple: Tuple[str, str] = m.groups() # type: ignore + if kernel_ver_tuple < ('4', '1'): + # TODO: this should be deprecated when the minimum supported Linux kernel becomes 4.1. + # + # In CentOS 7, NSpid is not accessible since it is only supported by Linux kernel >= 4.1. + # We provide an alternative, although messy, way for older Linux kernels. The logic is + # briefly described below: + # * Obtain information on all the processes inside the target container, + # which contains host PID, by the docker top API (containers/{container_id}/top). + # - Get the COMMAND of the target process (by using host_pid). + # - Filter host processes which have the exact same COMMAND. + # * Obtain information on all the processes inside the target container, + # which contains container PID, by executing the "ps -aux" command from inside the container. + # - Filter container processes which have the exact same COMMAND. + # * Get the index of the target process from the host process table. + # * Use the index to get the target process from the container process table, and get PID. + # - Since docker top and ps -aux both display processes in the order of PID, we + # can safely assume that the order of the processes from both tables is the same. + # + # Example host and container process table: + # + # [ + # ['devops', '15454', '12942', '99', '15:36', 'pts/1', '00:00:08', 'python mnist.py'], + # ... (processes with the same COMMAND) + # ] + # + # [ + # ['work', '227', '121', '4.6', '22408680', '1525428', 'pts/1', 'Rl+', '06:36', '0:08', + # 'python', 'mnist.py'], + # ... (processes with the same COMMAND) + # ] + try: + docker = aiodocker.Docker() + # Get the process table from the host (docker top information). Filter processes which have + # exactly the same COMMAND as the target host process. + result = await docker._query_json(f'containers/{container_id}/top', method='GET') + procs = result['Processes'] + cmd = list(filter(lambda x: str(host_pid) == x[1], procs))[0][7] + host_table = list(filter(lambda x: cmd == x[7], procs)) + + # Get the process table from inside the container (execute the 'ps -aux' command in the container). + # Filter processes which have exactly the same COMMAND as above. + result = await docker._query_json( + f'containers/{container_id}/exec', + method='POST', + data={ + 'AttachStdin': False, + 'AttachStdout': True, + 'AttachStderr': True, + 'Cmd': ['ps', '-aux'], + }, + ) + exec_id = result['Id'] + async with docker._query( + f'exec/{exec_id}/start', + method='POST', + headers={'content-type': 'application/json'}, + data=json.dumps({ + 'Stream': False, # get response immediately + 'Detach': False, + 'Tty': False, + }), + ) as resp: + result = await resp.read() + result = result.decode('latin-1').split('\n') + result = list(map(lambda x: x.split(), result)) + head = result[0] + procs = result[1:] + pid_idx, cmd_idx = head.index('PID'), head.index('COMMAND') + container_table = list( + filter(lambda x: cmd == ' '.join(x[cmd_idx:]) if x else False, procs), + ) + + # When there are multiple processes which have the same COMMAND, just get the index of + # the target host process and apply it to the container table. Since ps and docker top + # both display processes ordered by PID, we can expect those two tables to have the same + # order of processes. 
+ process_idx = None + for idx, p in enumerate(host_table): + if str(host_pid) == p[1]: + process_idx = idx + break + else: + raise IndexError + container_pid = ContainerPID(container_table[process_idx][pid_idx]) + log.debug('host pid {} is mapped to container pid {}', host_pid, container_pid) + return ContainerPID(PID(int(container_pid))) + except asyncio.CancelledError: + raise + except (IndexError, KeyError, aiodocker.exceptions.DockerError): + return NotContainerPID + finally: + await docker.close() + + try: + for p in Path('/sys/fs/cgroup/pids/docker').iterdir(): + if not p.is_dir(): + continue + tasks_path = p / 'tasks' + cgtasks = [*map(int, tasks_path.read_text().splitlines())] + if host_pid not in cgtasks: + continue + if p.name == container_id: + proc_path = Path(f'/proc/{host_pid}/status') + proc_status = {k: v for k, v + in map(lambda l: l.split(':\t'), + proc_path.read_text().splitlines())} + nspids = [*map(lambda pid: ContainerPID(PID(int(pid))), proc_status['NSpid'].split())] + return nspids[1] + return InOtherContainerPID + return NotContainerPID + except (ValueError, KeyError, IOError): + return NotContainerPID + + +async def container_pid_to_host_pid(container_id: str, container_pid: ContainerPID) -> HostPID: + # TODO: implement + return NotHostPID + + +def fetch_local_ipaddrs(cidr: IPNetwork) -> Iterable[IPAddress]: + ifnames = netifaces.interfaces() + proto = netifaces.AF_INET if cidr.version == 4 else netifaces.AF_INET6 + for ifname in ifnames: + addrs = netifaces.ifaddresses(ifname).get(proto, None) + if addrs is None: + continue + for entry in addrs: + addr = ipaddress.ip_address(entry['addr']) + if addr in cidr: + yield addr diff --git a/src/ai/backend/agent/vendor/__init__.py b/src/ai/backend/agent/vendor/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/ai/backend/agent/vendor/linux.py b/src/ai/backend/agent/vendor/linux.py new file mode 100644 index 0000000000..6ab9300cd5 --- /dev/null +++ b/src/ai/backend/agent/vendor/linux.py @@ -0,0 +1,57 @@ +import ctypes, ctypes.util +import os +import sys + +import aiohttp +import aiotools + +_numa_supported = False + +if sys.platform == 'linux': + _libnuma_path = ctypes.util.find_library('numa') + if _libnuma_path: + _libnuma = ctypes.CDLL(_libnuma_path) + _numa_supported = True + + +class libnuma: + + @staticmethod + def node_of_cpu(core): + if _numa_supported: + return int(_libnuma.numa_node_of_cpu(core)) + else: + return 0 + + @staticmethod + def num_nodes(): + if _numa_supported: + return int(_libnuma.numa_num_configured_nodes()) + else: + return 1 + + @staticmethod + @aiotools.lru_cache(maxsize=1) + async def get_available_cores(): + try: + # Try to get the # cores allocated to Docker first. 
+ unix_conn = aiohttp.UnixConnector('/var/run/docker.sock') + async with aiohttp.ClientSession(connector=unix_conn) as sess: + async with sess.get('http://docker/info') as resp: + data = await resp.json() + return {idx for idx in range(data['NCPU'])} + except aiohttp.ClientError: + try: + return os.sched_getaffinity(os.getpid()) + except AttributeError: + return {idx for idx in range(os.cpu_count())} + + @staticmethod + async def get_core_topology(limit_cpus=None): + topo = tuple([] for _ in range(libnuma.num_nodes())) + for c in (await libnuma.get_available_cores()): + if limit_cpus is not None and c not in limit_cpus: + continue + n = libnuma.node_of_cpu(c) + topo[n].append(c) + return topo diff --git a/src/ai/backend/agent/watcher.py b/src/ai/backend/agent/watcher.py new file mode 100644 index 0000000000..3d6874ab7c --- /dev/null +++ b/src/ai/backend/agent/watcher.py @@ -0,0 +1,384 @@ +import asyncio +import logging +import os +from pathlib import Path +from pprint import pprint, pformat +import signal +import ssl +import subprocess +import sys + +import aiofiles +from aiohttp import web +import aiotools +import click +from setproctitle import setproctitle +import trafaret as t + +from ai.backend.common import config, utils, validators as tx +from ai.backend.common.etcd import AsyncEtcd, ConfigScopes +from ai.backend.common.logging import Logger, BraceStyleAdapter +from ai.backend.common.utils import Fstab +from . import __version__ as VERSION + +log = BraceStyleAdapter(logging.getLogger('ai.backend.agent.watcher')) + +shutdown_enabled = False + + +@web.middleware +async def auth_middleware(request, handler): + token = request.headers.get('X-BackendAI-Watcher-Token', None) + if token == request.app['token']: + try: + return (await handler(request)) + except FileNotFoundError as e: + log.info(repr(e)) + message = 'Agent is not loaded with systemctl.' 
+ return web.json_response({'message': message}, status=200) + except Exception as e: + log.exception(repr(e)) + raise + log.info('invalid requested token') + return web.HTTPForbidden() + + +async def handle_status(request: web.Request) -> web.Response: + svc = request.app['config']['watcher']['target-service'] + proc = await asyncio.create_subprocess_exec( + *['sudo', 'systemctl', 'is-active', svc], + stdout=subprocess.PIPE) + if proc.stdout is not None: + status = (await proc.stdout.read()).strip().decode() + else: + status = 'unknown' + await proc.wait() + return web.json_response({ + 'agent-status': status, # maybe also "inactive", "activating" + 'watcher-status': 'active', + }) + + +async def handle_soft_reset(request: web.Request) -> web.Response: + svc = request.app['config']['watcher']['target-service'] + proc = await asyncio.create_subprocess_exec( + *['sudo', 'systemctl', 'reload', svc]) + await proc.wait() + return web.json_response({ + 'result': 'ok', + }) + + +async def handle_hard_reset(request: web.Request) -> web.Response: + svc = request.app['config']['watcher']['target-service'] + proc = await asyncio.create_subprocess_exec( + *['sudo', 'systemctl', 'stop', svc]) + await proc.wait() + proc = await asyncio.create_subprocess_exec( + *['sudo', 'systemctl', 'restart', 'docker.service']) + await proc.wait() + proc = await asyncio.create_subprocess_exec( + *['sudo', 'systemctl', 'start', svc]) + await proc.wait() + return web.json_response({ + 'result': 'ok', + }) + + +async def handle_shutdown(request: web.Request) -> web.Response: + global shutdown_enabled + svc = request.app['config']['watcher']['target-service'] + proc = await asyncio.create_subprocess_exec( + *['sudo', 'systemctl', 'stop', svc]) + await proc.wait() + shutdown_enabled = True + signal.alarm(1) + return web.json_response({ + 'result': 'ok', + }) + + +async def handle_agent_start(request: web.Request) -> web.Response: + svc = request.app['config']['watcher']['target-service'] + proc = await asyncio.create_subprocess_exec( + *['sudo', 'systemctl', 'start', svc]) + await proc.wait() + return web.json_response({ + 'result': 'ok', + }) + + +async def handle_agent_stop(request: web.Request) -> web.Response: + svc = request.app['config']['watcher']['target-service'] + proc = await asyncio.create_subprocess_exec( + *['sudo', 'systemctl', 'stop', svc]) + await proc.wait() + return web.json_response({ + 'result': 'ok', + }) + + +async def handle_agent_restart(request: web.Request) -> web.Response: + svc = request.app['config']['watcher']['target-service'] + proc = await asyncio.create_subprocess_exec( + *['sudo', 'systemctl', 'restart', svc]) + await proc.wait() + return web.json_response({ + 'result': 'ok', + }) + + +async def handle_fstab_detail(request: web.Request) -> web.Response: + log.info('HANDLE_FSTAB_DETAIL') + params = request.query + fstab_path = params.get('fstab_path', '/etc/fstab') + async with aiofiles.open(fstab_path, mode='r') as fp: + content = await fp.read() + return web.Response(text=content) + + +async def handle_list_mounts(request: web.Request) -> web.Response: + log.info('HANDLE_LIST_MOUNT') + config = request.app['config_server'] + mount_prefix = await config.get('volumes/_mount') + if mount_prefix is None: + mount_prefix = '/mnt' + mounts = set() + for p in Path(mount_prefix).iterdir(): + if p.is_dir() and p.is_mount(): + mounts.add(str(p)) + return web.json_response(sorted(mounts)) + + +async def handle_mount(request: web.Request) -> web.Response: + log.info('HANDLE_MOUNT') + params = await 
request.json() + config = request.app['config_server'] + mount_prefix = await config.get('volumes/_mount') + if mount_prefix is None: + mount_prefix = '/mnt' + mountpoint = Path(mount_prefix) / params['name'] + mountpoint.mkdir(exist_ok=True) + if params.get('options', None): + cmd = ['sudo', 'mount', '-t', params['fs_type'], '-o', params['options'], + params['fs_location'], str(mountpoint)] + else: + cmd = ['sudo', 'mount', '-t', params['fs_type'], + params['fs_location'], str(mountpoint)] + proc = await asyncio.create_subprocess_exec(*cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE) + raw_out, raw_err = await proc.communicate() + out = raw_out.decode('utf8') + err = raw_err.decode('utf8') + await proc.wait() + if err: + log.error('Mount error: ' + err) + return web.Response(text=err, status=500) + log.info('Mounted ' + params['name'] + ' on ' + mount_prefix) + if params['edit_fstab']: + fstab_path = params['fstab_path'] if params['fstab_path'] else '/etc/fstab' + # FIXME: Remove ignore if https://github.com/python/typeshed/pull/4650 is released + async with aiofiles.open(fstab_path, mode='r+') as fp: # type: ignore + fstab = Fstab(fp) + await fstab.add(params['fs_location'], str(mountpoint), + params['fs_type'], params['options']) + return web.Response(text=out) + + +async def handle_umount(request: web.Request) -> web.Response: + log.info('HANDLE_UMOUNT') + params = await request.json() + config = request.app['config_server'] + mount_prefix = await config.get('volumes/_mount') + if mount_prefix is None: + mount_prefix = '/mnt' + mountpoint = Path(mount_prefix) / params['name'] + assert Path(mount_prefix) != mountpoint + proc = await asyncio.create_subprocess_exec(*[ + 'sudo', 'umount', str(mountpoint), + ], stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE) + raw_out, raw_err = await proc.communicate() + out = raw_out.decode('utf8') + err = raw_err.decode('utf8') + await proc.wait() + if err: + log.error('Unmount error: ' + err) + return web.Response(text=err, status=500) + log.info('Unmounted ' + params['name'] + ' from ' + mount_prefix) + try: + mountpoint.rmdir() # delete directory if empty + except OSError: + pass + if params['edit_fstab']: + fstab_path = params['fstab_path'] if params['fstab_path'] else '/etc/fstab' + # FIXME: Remove ignore if https://github.com/python/typeshed/pull/4650 is released + async with aiofiles.open(fstab_path, mode='r+') as fp: # type: ignore + fstab = Fstab(fp) + await fstab.remove_by_mountpoint(str(mountpoint)) + return web.Response(text=out) + + +async def init_app(app): + r = app.router.add_route + r('GET', '/', handle_status) + if app['config']['watcher']['soft-reset-available']: + r('POST', '/soft-reset', handle_soft_reset) + r('POST', '/hard-reset', handle_hard_reset) + r('POST', '/shutdown', handle_shutdown) + r('POST', '/agent/start', handle_agent_start) + r('POST', '/agent/stop', handle_agent_stop) + r('POST', '/agent/restart', handle_agent_restart) + r('GET', '/fstab', handle_fstab_detail) + r('GET', '/mounts', handle_list_mounts) + r('POST', '/mounts', handle_mount) + r('DELETE', '/mounts', handle_umount) + + +async def shutdown_app(app): + pass + + +async def prepare_hook(request, response): + response.headers['Server'] = 'BackendAI-AgentWatcher' + + +@aiotools.server +async def watcher_server(loop, pidx, args): + global shutdown_enabled + + app = web.Application() + app['config'] = args[0] + + etcd_credentials = None + if app['config']['etcd']['user']: + etcd_credentials = { + 'user': 
app['config']['etcd']['user'], + 'password': app['config']['etcd']['password'], + } + scope_prefix_map = { + ConfigScopes.GLOBAL: '', + } + etcd = AsyncEtcd(app['config']['etcd']['addr'], + app['config']['etcd']['namespace'], + scope_prefix_map=scope_prefix_map, + credentials=etcd_credentials) + app['config_server'] = etcd + + token = await etcd.get('config/watcher/token') + if token is None: + token = 'insecure' + log.debug('watcher authentication token: {}', token) + app['token'] = token + + app.middlewares.append(auth_middleware) + app.on_shutdown.append(shutdown_app) + app.on_startup.append(init_app) + app.on_response_prepare.append(prepare_hook) + ssl_ctx = None + if app['config']['watcher']['ssl-enabled']: + ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + ssl_ctx.load_cert_chain( + str(app['config']['watcher']['ssl-cert']), + str(app['config']['watcher']['ssl-privkey']), + ) + runner = web.AppRunner(app) + await runner.setup() + watcher_addr = app['config']['watcher']['service-addr'] + site = web.TCPSite( + runner, + str(watcher_addr.host), + watcher_addr.port, + backlog=5, + reuse_port=True, + ssl_context=ssl_ctx, + ) + await site.start() + log.info('started at {}', watcher_addr) + try: + stop_sig = yield + finally: + log.info('shutting down...') + if stop_sig == signal.SIGALRM and shutdown_enabled: + log.warning('shutting down the agent node!') + subprocess.run(['shutdown', '-h', 'now']) + await runner.cleanup() + + +@click.command() +@click.option('-f', '--config-path', '--config', type=Path, default=None, + help='The config file path. (default: ./agent.conf and /etc/backend.ai/agent.conf)') +@click.option('--debug', is_flag=True, + help='Enable the debug mode and override the global log level to DEBUG.') +@click.pass_context +def main(cli_ctx, config_path, debug): + + watcher_config_iv = t.Dict({ + t.Key('watcher'): t.Dict({ + t.Key('service-addr', default=('0.0.0.0', 6009)): tx.HostPortPair, + t.Key('ssl-enabled', default=False): t.Bool, + t.Key('ssl-cert', default=None): t.Null | tx.Path(type='file'), + t.Key('ssl-key', default=None): t.Null | tx.Path(type='file'), + t.Key('target-service', default='backendai-agent.service'): t.String, + t.Key('soft-reset-available', default=False): t.Bool, + }).allow_extra('*'), + t.Key('logging'): t.Any, # checked in ai.backend.common.logging + t.Key('debug'): t.Dict({ + t.Key('enabled', default=False): t.Bool, + }).allow_extra('*'), + }).merge(config.etcd_config_iv).allow_extra('*') + + raw_cfg, cfg_src_path = config.read_from_file(config_path, 'agent') + + config.override_with_env(raw_cfg, ('etcd', 'namespace'), 'BACKEND_NAMESPACE') + config.override_with_env(raw_cfg, ('etcd', 'addr'), 'BACKEND_ETCD_ADDR') + config.override_with_env(raw_cfg, ('etcd', 'user'), 'BACKEND_ETCD_USER') + config.override_with_env(raw_cfg, ('etcd', 'password'), 'BACKEND_ETCD_PASSWORD') + config.override_with_env(raw_cfg, ('watcher', 'service-addr', 'host'), + 'BACKEND_WATCHER_SERVICE_IP') + config.override_with_env(raw_cfg, ('watcher', 'service-addr', 'port'), + 'BACKEND_WATCHER_SERVICE_PORT') + if debug: + config.override_key(raw_cfg, ('debug', 'enabled'), True) + + try: + cfg = config.check(raw_cfg, watcher_config_iv) + if 'debug' in cfg and cfg['debug']['enabled']: + print('== Watcher configuration ==') + pprint(cfg) + cfg['_src'] = cfg_src_path + except config.ConfigurationError as e: + print('Validation of watcher configuration has failed:', file=sys.stderr) + print(pformat(e.invalid_data), file=sys.stderr) + raise click.Abort() + + # Change 
the filename from the logging config's file section. + log_sockpath = Path(f'/tmp/backend.ai/ipc/watcher-logger-{os.getpid()}.sock') + log_sockpath.parent.mkdir(parents=True, exist_ok=True) + log_endpoint = f'ipc://{log_sockpath}' + cfg['logging']['endpoint'] = log_endpoint + logger = Logger(cfg['logging'], is_master=True, log_endpoint=log_endpoint) + if 'file' in cfg['logging']['drivers']: + fn = Path(cfg['logging']['file']['filename']) + cfg['logging']['file']['filename'] = f"{fn.stem}-watcher{fn.suffix}" + + setproctitle(f"backend.ai: watcher {cfg['etcd']['namespace']}") + with logger: + log.info('Backend.AI Agent Watcher {0}', VERSION) + log.info('runtime: {0}', utils.env_info()) + + log_config = logging.getLogger('ai.backend.agent.config') + log_config.debug('debug mode enabled.') + + aiotools.start_server( + watcher_server, + num_workers=1, + args=(cfg, ), + stop_signals={signal.SIGINT, signal.SIGTERM, signal.SIGALRM}, + ) + log.info('exit.') + return 0 + + +if __name__ == '__main__': + main() diff --git a/src/ai/backend/cli/BUILD b/src/ai/backend/cli/BUILD new file mode 100644 index 0000000000..a472fca71b --- /dev/null +++ b/src/ai/backend/cli/BUILD @@ -0,0 +1,48 @@ +python_sources( + name="lib", + dependencies=[ + "src/ai/backend/plugin:lib", + ":resources", + ], + sources=["**/*.py"], +) + +python_distribution( + name="dist", + dependencies=[ + ":lib", + "!!stubs/trafaret:stubs", + ], + provides=python_artifact( + name="backend.ai-cli", + description="Backend.AI Command Line Interface Helper", + license="MIT", + ), + entry_points={ + "console_scripts": { + "backend.ai": "ai.backend.cli.__main__:main", + }, + }, + generate_setup=True, + tags=["wheel"], +) + +pex_binary( + name="cli", + dependencies=[ + ":lib", + ], + entry_point="ai.backend.cli.__main__:main", +) + +resource(name="version", source="VERSION") + +resources( + name="resources", + dependencies=[ + ":version", + ], + sources=[ + "**/py.typed", + ], +) diff --git a/src/ai/backend/cli/README.md b/src/ai/backend/cli/README.md new file mode 100644 index 0000000000..efb86e37b0 --- /dev/null +++ b/src/ai/backend/cli/README.md @@ -0,0 +1,23 @@ +# backend.ai-cli + +Unified command-line interface for Backend.AI + + +## How to adopt in CLI-enabled Backend.AI packages + +An example `setup.cfg` in Backend.AI Manager: +``` +[options.entry_points] +backendai_cli_v10 = + mgr = ai.backend.manager.cli.__main__:main + mgr.start-server = ai.backend.gateway.server:main +``` + +Define your package entry points that returns a Click command group using a +prefix, and add additional entry points that returns a Click command using a +prefix followed by a dot and sub-command name for shortcut access, under the +`backendai_cli_v10` entry point group. + +Then add `backend.ai-cli` to the `install_requires` list. + +You can do the same in `setup.py` as well. 
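For reference, a minimal sketch of the kind of module such an entry point could point to is shown below; the module path and the `status` sub-command are hypothetical examples for illustration, not the actual Manager CLI:

```python
# Hypothetical example of a module registered under the `backendai_cli_v10`
# entry point group with the name "mgr" (e.g., ai/backend/manager/cli/__main__.py).
import click


@click.group()
def main() -> None:
    """Top-level command group exposed to the unified `backend.ai` CLI."""


@main.command()
def status() -> None:
    """An illustrative sub-command; real packages define their own commands."""
    click.echo("manager CLI loaded")


if __name__ == "__main__":
    main()
```

With such an entry point installed, the unified CLI loader should pick up the group and expose it as `backend.ai mgr ...`.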
diff --git a/src/ai/backend/cli/VERSION b/src/ai/backend/cli/VERSION new file mode 120000 index 0000000000..a4e948506b --- /dev/null +++ b/src/ai/backend/cli/VERSION @@ -0,0 +1 @@ +../../../../VERSION \ No newline at end of file diff --git a/src/ai/backend/cli/__init__.py b/src/ai/backend/cli/__init__.py new file mode 100644 index 0000000000..ef7eb44d9a --- /dev/null +++ b/src/ai/backend/cli/__init__.py @@ -0,0 +1 @@ +__version__ = '0.6.0' diff --git a/src/ai/backend/cli/__main__.py b/src/ai/backend/cli/__main__.py new file mode 100644 index 0000000000..885de31552 --- /dev/null +++ b/src/ai/backend/cli/__main__.py @@ -0,0 +1,9 @@ +from .loader import load_entry_points + + +main = load_entry_points() +# main object is called by the console script. + +if __name__ == "__main__": + # Execute right away if the module is directly called from CLI. + main() diff --git a/src/ai/backend/cli/extensions.py b/src/ai/backend/cli/extensions.py new file mode 100644 index 0000000000..dcf790df93 --- /dev/null +++ b/src/ai/backend/cli/extensions.py @@ -0,0 +1,144 @@ +import os +import signal +import sys + +import click +from click.exceptions import ClickException, Abort + + +class InterruptAwareCommandMixin(click.BaseCommand): + """ + Replace the main() method to support proper exit-codes + for interruptions on Windows and POSIX platforms. + Using this, interrupting the command will let the shell + know that the execution is also interrupted instead of + continuing the shell/batch script. + """ + + def main(self, *args, **kwargs): + try: + _interrupted = False + kwargs.pop('standalone_mode', None) + kwargs.pop('prog_name', None) + super().main( + *args, + standalone_mode=False, + prog_name='backend.ai', + **kwargs, + ) + except KeyboardInterrupt: + # For interruptions outside the Click's exception handling block. + print("Interrupted!", end="", file=sys.stderr) + sys.stderr.flush() + _interrupted = True + except Abort as e: + # Click wraps unhandled KeyboardInterrupt with a plain + # sys.exit(1) call and prints "Aborted!" message + # (which would look non-sense to users). + # This is *NOT* what we want. + # Instead of relying on Click, mark the _interrupted + # flag to perform our own exit routines. + if isinstance(e.__context__, KeyboardInterrupt): + print("Interrupted!", end="", file=sys.stderr) + sys.stderr.flush() + _interrupted = True + else: + print("Aborted!", end="", file=sys.stderr) + sys.stderr.flush() + sys.exit(1) + except ClickException as e: + e.show() + sys.exit(e.exit_code) + finally: + if _interrupted: + # Override the exit code when it's interrupted, + # referring https://github.com/python/cpython/pull/11862 + if sys.platform.startswith('win'): + # Use STATUS_CONTROL_C_EXIT to notify cmd.exe + # for interrupted exit + sys.exit(-1073741510) + else: + # Use the default signal handler to set the exit + # code properly for interruption. + signal.signal(signal.SIGINT, signal.SIG_DFL) + os.kill(os.getpid(), signal.SIGINT) + + +class AliasGroupMixin(click.Group): + """ + Enable command aliases. 
+ + ref) https://github.com/click-contrib/click-aliases + """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._commands = {} + self._aliases = {} + + def command(self, *args, **kwargs): + aliases = kwargs.pop('aliases', []) + decorator = super().command(*args, **kwargs) + if not aliases: + return decorator + + def _decorator(f): + cmd = decorator(f) + if aliases: + self._commands[cmd.name] = aliases + for alias in aliases: + self._aliases[alias] = cmd.name + return cmd + return _decorator + + def group(self, *args, **kwargs): + aliases = kwargs.pop('aliases', []) + # keep the same class type + kwargs['cls'] = type(self) + decorator = super().group(*args, **kwargs) + if not aliases: + return decorator + + def _decorator(f): + cmd = decorator(f) + if aliases: + self._commands[cmd.name] = aliases + for alias in aliases: + self._aliases[alias] = cmd.name + return cmd + return _decorator + + def get_command(self, ctx, cmd_name): + if cmd_name in self._aliases: + cmd_name = self._aliases[cmd_name] + command = super().get_command(ctx, cmd_name) + if command: + return command + + def format_commands(self, ctx, formatter): + commands = [] + for subcommand in self.list_commands(ctx): + cmd = self.get_command(ctx, subcommand) + # What is this, the tool lied about a command. Ignore it + if cmd is None: + continue + if cmd.hidden: + continue + if subcommand in self._commands: + aliases = ','.join(sorted(self._commands[subcommand])) + subcommand = '{0} ({1})'.format(subcommand, aliases) + commands.append((subcommand, cmd)) + + # allow for 3 times the default spacing + if len(commands): + limit = formatter.width - 6 - max(len(cmd[0]) for cmd in commands) + rows = [] + for subcommand, cmd in commands: + help = cmd.get_short_help_str(limit) + rows.append((subcommand, help)) + if rows: + with formatter.section('Commands'): + formatter.write_dl(rows) + + +class ExtendedCommandGroup(InterruptAwareCommandMixin, AliasGroupMixin, click.Group): + pass diff --git a/src/ai/backend/cli/interaction.py b/src/ai/backend/cli/interaction.py new file mode 100644 index 0000000000..ef3abac109 --- /dev/null +++ b/src/ai/backend/cli/interaction.py @@ -0,0 +1,150 @@ +import ipaddress +from pathlib import Path +from typing import Optional, Union +from urllib.error import HTTPError +from urllib.request import urlopen + + +Numeric = Union[int, float] + + +def ask_host(prompt: str, default: str = "127.0.0.1", allow_hostname=False) -> str: + while True: + user_reply = input(f"{prompt} (default: {default}): ") + if user_reply == "": + user_reply = default + try: + if allow_hostname: + url = user_reply + if not (user_reply.startswith("http://") or user_reply.startswith("https://")): + url = f"http://{user_reply}" + try: + urlopen(url) + break + except HTTPError: + print("Please input correct URL.") + ipaddress.ip_address(user_reply) + break + except ValueError: + print("Please input correct host.") + return user_reply + + +def convert_str_into_numeric(user_reply: str) -> Numeric: + if user_reply.isdigit(): + return int(user_reply) + return float(user_reply) + + +def ask_number(prompt: str, default: Numeric, min_value: Numeric, max_value: Numeric) -> Numeric: + while True: + user_reply = input(f"{prompt} (default: {default}): ") + try: + if user_reply == "": + return default + if user_reply.isdigit() and min_value <= convert_str_into_numeric(user_reply) <= max_value: + user_reply_numeric = convert_str_into_numeric(user_reply) + break + except ValueError: + print(f"Please input correct number between 
{min_value}~{max_value}.") + return user_reply_numeric + + +def ask_string(prompt: str, default: str = "", use_default: bool = True) -> str: + while True: + if use_default: + user_reply = input(f"{prompt} (default: \"{default}\"): ") + if user_reply == "": + return default + return user_reply + else: + user_reply = input(f"{prompt} (if you don\'t want, just leave empty): ") + return user_reply + + +def ask_string_in_array(prompt: str, choices: list, default: str) -> Optional[str]: + if default and default not in choices: + print("Default value should be in choices args.") + return None + if "" in choices: + choices.remove("") + + if default: + question = f"{prompt} (choices: {'/'.join(choices)}, " \ + f"if left empty, this will use default value: {default}): " + else: + question = f"{prompt} (choices: {'/'.join(choices)}, if left empty, this will remove this key): " + + while True: + user_reply = input(question) + if user_reply == "": + if default: + user_reply = default + else: + return None + break + elif user_reply.lower() in choices: + break + else: + print(f"Please answer in {'/'.join(choices)}.") + return user_reply + + +def ask_path(prompt: str, is_file=True, is_directory=True) -> Path: + if not (is_file or is_directory): + print("One of args(is_file/is_directory) has True value.") + while True: + user_reply = input(f"{prompt}: ") + path = Path(user_reply) + if is_file and path.is_file(): + break + if is_directory and path.is_dir(): + break + + if is_file and is_directory: + print("Please answer a correct file/directory path.") + elif is_file: + print("Please answer a correct file path.") + elif is_directory: + print("Please answer a correct directory path.") + return path + + +def ask_yn(prompt: str = 'Are you sure?', default: str = 'y') -> bool: + if default == 'y': + choices = 'Y/n' + elif default == 'n': + choices = 'y/N' + else: + raise ValueError("default must be given either 'y' or 'n'.") + while True: + user_reply = input("{0} [{1}] ".format(prompt, choices)).lower() + if user_reply == '': + user_reply = default + if user_reply in ('y', 'yes', 'n', 'no'): + break + else: + print("Please answer in y/yes/n/no.") + if user_reply[:1].lower() == 'y': + return True + return False + + +def ask_tf(prompt: str = 'Are you sure?', default: str = 'true') -> bool: + if default == 't': + choices = 'T/f' + elif default == 'f': + choices = 't/F' + else: + raise ValueError("default must be given either 'true' or 'n'.") + while True: + user_reply = input(f"{prompt} [{choices}] ").lower() + if user_reply == '': + user_reply = default + if user_reply in ('t', 'true', 'f', 'false'): + break + else: + print("Please answer in t/true/f/false.") + if user_reply[:1].lower() == 't': + return True + return False diff --git a/src/ai/backend/cli/loader.py b/src/ai/backend/cli/loader.py new file mode 100644 index 0000000000..0ec8b60529 --- /dev/null +++ b/src/ai/backend/cli/loader.py @@ -0,0 +1,23 @@ +import click # noqa: E402 + +from ai.backend.plugin.entrypoint import scan_entrypoints + +from .main import main # noqa: E402 + + +def load_entry_points() -> click.Group: + entry_prefix = 'backendai_cli_v10' + for entrypoint in scan_entrypoints(entry_prefix): + if entrypoint.name == "_": + cmd_group: click.Group = entrypoint.load() + for name, cmd in cmd_group.commands.items(): + main.add_command(cmd, name=name) + else: + prefix, _, subprefix = entrypoint.name.partition(".") + if not subprefix: + subcmd = entrypoint.load() + main.add_command(subcmd, name=prefix) + else: + subcmd = entrypoint.load() + 
main.commands[prefix].add_command(subcmd, name=subprefix) # type: ignore + return main diff --git a/src/ai/backend/cli/main.py b/src/ai/backend/cli/main.py new file mode 100644 index 0000000000..cf5cbe1cb3 --- /dev/null +++ b/src/ai/backend/cli/main.py @@ -0,0 +1,14 @@ +import click + +from .extensions import ExtendedCommandGroup + + +@click.group( + cls=ExtendedCommandGroup, + context_settings={ + 'help_option_names': ['-h', '--help'], + }, +) +def main() -> click.Group: + '''Unified Command Line Interface for Backend.ai''' + pass diff --git a/src/ai/backend/cli/py.typed b/src/ai/backend/cli/py.typed new file mode 100644 index 0000000000..b3a425249b --- /dev/null +++ b/src/ai/backend/cli/py.typed @@ -0,0 +1 @@ +placeholder \ No newline at end of file diff --git a/src/ai/backend/client/BUILD b/src/ai/backend/client/BUILD new file mode 100644 index 0000000000..53db75eee8 --- /dev/null +++ b/src/ai/backend/client/BUILD @@ -0,0 +1,48 @@ +python_sources( + name="lib", + dependencies=[ + "src/ai/backend/cli:lib", + ":resources", + ], + sources=["**/*.py"], +) + +pex_binary( + name="cli", + dependencies=[ + ":lib", + ], + entry_point="cli/__main__.py", +) + +python_distribution( + name="dist", + dependencies=[ + ":lib", + "!!stubs/trafaret:stubs", + ], + provides=python_artifact( + name="backend.ai-client", + description="Backend.AI Client SDK", + license="MIT", + ), + entry_points={ + "backendai_cli_v10": { + "_": "ai.backend.client.cli.main:main", + }, + }, + generate_setup=True, + tags=["wheel"], +) + +resource(name="version", source="VERSION") + +resources( + name="resources", + dependencies=[ + ":version", + ], + sources=[ + "**/py.typed", + ], +) diff --git a/src/ai/backend/client/README.rst b/src/ai/backend/client/README.rst new file mode 100644 index 0000000000..2c227fb936 --- /dev/null +++ b/src/ai/backend/client/README.rst @@ -0,0 +1,218 @@ +Backend.AI Client +================= + +.. image:: https://badge.fury.io/py/backend.ai-client.svg + :target: https://badge.fury.io/py/backend.ai-client + :alt: PyPI version + +.. image:: https://img.shields.io/pypi/pyversions/backend.ai-client.svg + :target: https://pypi.org/project/backend.ai-client/ + :alt: Python Versions + +.. image:: https://readthedocs.org/projects/backendai-client-sdk-for-python/badge/?version=latest + :target: https://client-py.docs.backend.ai/en/latest/?badge=latest + :alt: SDK Documentation + +.. image:: https://travis-ci.com/lablup/backend.ai-client-py.svg?branch=master + :target: https://travis-ci.com/lablup/backend.ai-client-py + :alt: Build Status (Linux) + +.. image:: https://ci.appveyor.com/api/projects/status/5h6r1cmbx2965yn1/branch/master?svg=true + :target: https://ci.appveyor.com/project/lablup/backend.ai-client-py/branch/master + :alt: Build Status (Windows) + +.. image:: https://codecov.io/gh/lablup/backend.ai-client-py/branch/master/graph/badge.svg + :target: https://codecov.io/gh/lablup/backend.ai-client-py + :alt: Code Coverage + +The official API client library for `Backend.AI `_ + + +Usage (KeyPair mode) +-------------------- + +You should set the access key and secret key as environment variables to use the API. +Grab your keypair from `cloud.backend.ai `_ or your cluster +admin. + +On Linux/macOS, create a shell script as ``my-backend-ai.sh`` and run it before using +the ``backend.ai`` command: + +.. code-block:: sh + + export BACKEND_ACCESS_KEY=... + export BACKEND_SECRET_KEY=... 
+ export BACKEND_ENDPOINT=https://my-precious-cluster + export BACKEND_ENDPOINT_TYPE=api +On Windows, create a batch file as ``my-backend-ai.bat`` and run it before using +the ``backend.ai`` command: + +.. code-block:: bat + + chcp 65001 + set PYTHONIOENCODING=UTF-8 + set BACKEND_ACCESS_KEY=... + set BACKEND_SECRET_KEY=... + set BACKEND_ENDPOINT=https://my-precious-cluster + set BACKEND_ENDPOINT_TYPE=api + +Note that you need to switch to the UTF-8 codepage for correct display of +special characters used in the console logs. + + +Usage (Session mode) +-------------------- + +Change ``BACKEND_ENDPOINT_TYPE`` to "session" and set the endpoint to the URL of your console server. + +.. code-block:: sh + + export BACKEND_ENDPOINT=https://my-precious-cluster + export BACKEND_ENDPOINT_TYPE=session + +.. code-block:: console + + $ backend.ai login + User ID: myid@mydomain.com + Password: + ✔ Login succeeded! + + $ backend.ai ... # run any command + + $ backend.ai logout + ✔ Logout done. + +The session expiration timeout is set by the console server. + + +Command-line Interface +---------------------- + +The ``backend.ai`` command is the entry point of all sub-commands. +(Alternatively you can use a verbosely long version: ``python -m ai.backend.client.cli``) + +Highlight: ``run`` command +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The ``run`` command executes a code snippet or code source files on a Backend.AI compute session +created on-the-fly. + +To run the code specified in the command line directly, +use the ``-c`` option to pass the code string (like a shell). + +.. code-block:: console + + $ backend.ai run python:3.6-ubuntu18.04 -c "print('hello world')" + ∙ Client session token: d3694dda6e5a9f1e5c718e07bba291a9 + ✔ Kernel (ID: zuF1OzMIhFknyjUl7Apbvg) is ready. + hello world + +By default, you need to specify the language with a full version tag like +``python:3.6-ubuntu18.04``. Depending on the Backend.AI admin's language +alias settings, this can be shortened to just ``python``. If you want to +know the defined language aliases, contact the admin of the Backend.AI server. + +You can even run C code on-the-fly. (Note that we put a dollar sign before +the single-quoted code argument so that the shell interprets ``'\n'`` as +actual newlines.) + +.. code-block:: console + + $ backend.ai run gcc:gcc6.4-alpine3.8 -c $'#include \nint main() {printf("hello world\\n");}' + ∙ Client session token: abc06ee5e03fce60c51148c6d2dd6126 + ✔ Kernel (ID: d1YXvee-uAJTx4AKYyeksA) is ready. + hello world + +For larger programs, you may upload multiple files and then build & execute +them. Below is a simple example to run `a sample C program +`_. + +.. code-block:: console + + $ git clone https://gist.github.com/achimnol/df464c6a3fe05b21e9b06d5b80e986c5 c-example + Cloning into 'c-example'... + Unpacking objects: 100% (5/5), done. + $ cd c-example + $ backend.ai run gcc:gcc6.4-alpine3.8 main.c mylib.c mylib.h + ∙ Client session token: 1c352a572bc751a81d1f812186093c47 + ✔ Kernel (ID: kJ6CgWR7Tz3_v2WsDHOwLQ) is ready. + ✔ Uploading done. + ✔ Build finished. + myvalue is 42 + your name? LABLUP + hello, LABLUP! + +Please refer to the ``--help`` manual provided by the ``run`` command. + +Highlight: ``start`` and ``app`` command +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +``backend.ai start`` is similar to the ``run`` command in that it creates a new compute session, +but it does not execute anything there. 
+You can subsequently call ``backend.ai run -t ...`` to execute codes snippets +or use ``backend.ai app`` command to start a local proxy to a container service such as Jupyter which +runs inside the compute session. + +.. code-block:: console + + $ backend.ai start -t mysess -r cpu=1 -r mem=2g lablup/python:3.6-ubuntu18.04 + ∙ Session ID mysess is created and ready. + ∙ This session provides the following app services: ipython, jupyter, jupyterlab + $ backend.ai app mysess jupyter + ∙ A local proxy to the application "jupyter" provided by the session "mysess" is available at: http://127.0.0.1:8080 + + +Highlight: ``ps`` and ``rm`` command +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +You can see the list of currently running sessions using your API keypair. + +.. code-block:: console + + $ backend.ai ps + Session ID Lang/runtime Tag Created At Terminated At Status CPU Cores CPU Used (ms) Total Memory (MiB) Used Memory (MiB) GPU Cores + ------------ ------------------------ ----- -------------------------------- --------------- -------- ----------- --------------- -------------------- ------------------- ----------- + 88ee10a027 lablup/python:3.6-ubuntu 2018-12-11T03:53:14.802206+00:00 RUNNING 1 16314 1024 39.2 0 + fce7830826 lablup/python:3.6-ubuntu 2018-12-11T03:50:10.150740+00:00 RUNNING 1 15391 1024 39.2 0 + +If you set ``-t`` option in the ``run`` command, it will be used as the session ID—you may use it to assign a human-readable, easy-to-type alias for your sessions. +These session IDs can be reused after the current session using the same ID terminates. + +To terminate a session, you can use ``terminate`` or ``rm`` command. + +.. code-block:: console + + $ backend.ai rm 5baafb2136029228ca9d873e1f2b4f6a + ✔ Done. + +Highlight: ``proxy`` command +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To use API development tools such as GraphiQL for the admin API, run an insecure +local API proxy. This will attach all the necessary authorization headers to your +vanilla HTTP API requests. + +.. code-block:: console + + $ backend.ai proxy + ∙ Starting an insecure API proxy at http://localhost:8084 + +More commands? +~~~~~~~~~~~~~~ + +Please run ``backend.ai --help`` to see more commands. + + +Troubleshooting (FAQ) +--------------------- + +* There are error reports related to ``simplejson`` with Anaconda on Windows. + This package no longer depends on simplejson since v1.0.5, so you may uninstall it + safely since Python 3.5+ offers almost identical ``json`` module in the standard + library. + + If you really need to keep the ``simplejson`` package, uninstall the existing + simplejson package manually and try reinstallation of it by downloading `a + pre-built binary wheel from here + `_. diff --git a/src/ai/backend/client/VERSION b/src/ai/backend/client/VERSION new file mode 120000 index 0000000000..a4e948506b --- /dev/null +++ b/src/ai/backend/client/VERSION @@ -0,0 +1 @@ +../../../../VERSION \ No newline at end of file diff --git a/src/ai/backend/client/__init__.py b/src/ai/backend/client/__init__.py new file mode 100644 index 0000000000..e997ce4d4c --- /dev/null +++ b/src/ai/backend/client/__init__.py @@ -0,0 +1,15 @@ +from pathlib import Path + +from . import exceptions +from . 
import session + +__all__ = ( + *exceptions.__all__, + *session.__all__, +) + +__version__ = (Path(__file__).parent / 'VERSION').read_text().strip() + + +def get_user_agent(): + return 'Backend.AI Client for Python {0}'.format(__version__) diff --git a/src/ai/backend/client/auth.py b/src/ai/backend/client/auth.py new file mode 100644 index 0000000000..70cec136d2 --- /dev/null +++ b/src/ai/backend/client/auth.py @@ -0,0 +1,73 @@ +from datetime import datetime +import enum +import hashlib +import hmac +from typing import ( + Mapping, + Tuple, +) + +import attr +from yarl import URL + +__all__ = ( + 'AuthToken', + 'AuthTokenTypes', + 'generate_signature', +) + + +class AuthTokenTypes(enum.Enum): + KEYPAIR = 'keypair' + JWT = 'jwt' + + +@attr.s +class AuthToken: + type = attr.ib(default=AuthTokenTypes.KEYPAIR) # type: AuthTokenTypes + content = attr.ib(default=None) # type: str + + +def generate_signature( + *, + method: str, + version: str, + endpoint: URL, + date: datetime, + rel_url: str, + content_type: str, + access_key: str, + secret_key: str, + hash_type: str, +) -> Tuple[Mapping[str, str], str]: + ''' + Generates the API request signature from the given parameters. + ''' + hash_type = hash_type + hostname = endpoint._val.netloc # type: ignore + body_hash = hashlib.new(hash_type, b'').hexdigest() + + sign_str = '{}\n{}\n{}\nhost:{}\ncontent-type:{}\nx-backendai-version:{}\n{}'.format( # noqa + method.upper(), + rel_url, + date.isoformat(), + hostname, + content_type.lower(), + version, + body_hash, + ) + sign_bytes = sign_str.encode() + + sign_key = hmac.new(secret_key.encode(), + date.strftime('%Y%m%d').encode(), hash_type).digest() + sign_key = hmac.new(sign_key, hostname.encode(), hash_type).digest() + + signature = hmac.new(sign_key, sign_bytes, hash_type).hexdigest() + headers = { + 'Authorization': 'BackendAI signMethod=HMAC-{}, credential={}:{}'.format( + hash_type.upper(), + access_key, + signature, + ), + } + return headers, signature diff --git a/src/ai/backend/client/cli/__init__.py b/src/ai/backend/client/cli/__init__.py new file mode 100644 index 0000000000..0ca64d5afa --- /dev/null +++ b/src/ai/backend/client/cli/__init__.py @@ -0,0 +1,9 @@ +from . import main # noqa +from . import config # noqa +from . import session # noqa +from . import session_template # noqa +from . import vfolder # noqa +from . import dotfile # noqa +from . import server_log # noqa +from . import admin # noqa +from . import app, logs, proxy, run # noqa diff --git a/src/ai/backend/client/cli/__main__.py b/src/ai/backend/client/cli/__main__.py new file mode 100644 index 0000000000..bcbfde6d69 --- /dev/null +++ b/src/ai/backend/client/cli/__main__.py @@ -0,0 +1,5 @@ +from .main import main + + +if __name__ == "__main__": + main() diff --git a/src/ai/backend/client/cli/admin/__init__.py b/src/ai/backend/client/cli/admin/__init__.py new file mode 100644 index 0000000000..55ae8db1cd --- /dev/null +++ b/src/ai/backend/client/cli/admin/__init__.py @@ -0,0 +1,27 @@ +from ..main import main + + +@main.group() +def admin(): + """ + Administrative command set + """ + + +from . 
import ( # noqa + agent, + domain, + etcd, + group, + image, + keypair, + manager, + license, + resource, + resource_policy, + scaling_group, + session, + storage, + user, + vfolder, +) diff --git a/src/ai/backend/client/cli/admin/agent.py b/src/ai/backend/client/cli/admin/agent.py new file mode 100644 index 0000000000..af71f22a6a --- /dev/null +++ b/src/ai/backend/client/cli/admin/agent.py @@ -0,0 +1,194 @@ +from __future__ import annotations + +import sys + +import click + +from ai.backend.client.session import Session +from ai.backend.client.output.fields import agent_fields +from ..types import CLIContext +from . import admin + + +@admin.group() +def agent(): + """ + Agent administration commands. + """ + + +@agent.command() +@click.pass_obj +@click.argument('agent_id') +def info(ctx: CLIContext, agent_id: str) -> None: + """ + Show the information about the given agent. + """ + fields = [ + agent_fields['id'], + agent_fields['status'], + agent_fields['region'], + agent_fields['architecture'], + agent_fields['first_contact'], + agent_fields['cpu_cur_pct'], + agent_fields['available_slots'], + agent_fields['occupied_slots'], + agent_fields['hardware_metadata'], + agent_fields['live_stat'], + ] + with Session() as session: + try: + item = session.Agent.detail(agent_id=agent_id, fields=fields) + ctx.output.print_item(item, fields) + except Exception as e: + ctx.output.print_error(e) + sys.exit(1) + + +@agent.command() +@click.pass_obj +@click.option('-s', '--status', type=str, default='ALIVE', + help='Filter agents by the given status.') +@click.option('--scaling-group', '--sgroup', type=str, default=None, + help='Filter agents by the scaling group.') +@click.option('--filter', 'filter_', default=None, + help='Set the query filter expression.') +@click.option('--order', default=None, + help='Set the query ordering expression.') +@click.option('--offset', default=0, + help='The index of the current page start for pagination.') +@click.option('--limit', default=None, + help='The page size for pagination.') +def list( + ctx: CLIContext, + status: str, + scaling_group: str | None, + filter_: str | None, + order: str | None, + offset: int, + limit: int | None, +) -> None: + """ + List agents. + (super-admin privilege required) + """ + fields = [ + agent_fields['id'], + agent_fields['status'], + agent_fields['architecture'], + agent_fields['scaling_group'], + agent_fields['region'], + agent_fields['first_contact'], + agent_fields['cpu_cur_pct'], + agent_fields['mem_cur_bytes'], + agent_fields['available_slots'], + agent_fields['occupied_slots'], + ] + try: + with Session() as session: + fetch_func = lambda pg_offset, pg_size: session.Agent.paginated_list( + status, + scaling_group, + fields=fields, + page_offset=pg_offset, + page_size=pg_size, + filter=filter_, + order=order, + ) + ctx.output.print_paginated_list( + fetch_func, + initial_page_offset=offset, + page_size=limit, + ) + except Exception as e: + ctx.output.print_error(e) + sys.exit(1) + + +@admin.group() +def watcher(): + """ + Agent watcher commands. + + Available only for Linux-based agents. + """ + + +@watcher.command() +@click.pass_obj +@click.argument('agent', type=str) +def status(ctx: CLIContext, agent: str) -> None: + """ + Get agent and watcher status. + (superadmin privilege required) + + \b + AGENT: Agent id. 
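# The `list` command above feeds ctx.output.print_paginated_list() a callback
# taking (page_offset, page_size) and returning one page of rows.  A minimal,
# self-contained sketch of that callback contract -- the fake agent rows and the
# driver loop below are illustrative, not part of the client SDK:

def paginate(fetch_page, page_size=3):
    """Drain an (offset, size) -> rows callback page by page."""
    offset = 0
    while True:
        rows = fetch_page(offset, page_size)
        if not rows:
            break
        yield from rows
        offset += page_size

_fake_agents = [{"id": f"i-agent{i:02d}", "status": "ALIVE"} for i in range(7)]
for row in paginate(lambda off, size: _fake_agents[off:off + size]):
    print(row["id"], row["status"])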
+ """ + with Session() as session: + try: + status = session.AgentWatcher.get_status(agent) + print(status) + except Exception as e: + ctx.output.print_error(e) + sys.exit(1) + + +@watcher.command() +@click.pass_obj +@click.argument('agent', type=str) +def agent_start(ctx: CLIContext, agent: str) -> None: + """ + Start agent service. + (superadmin privilege required) + + \b + AGENT: Agent id. + """ + with Session() as session: + try: + status = session.AgentWatcher.agent_start(agent) + print(status) + except Exception as e: + ctx.output.print_error(e) + sys.exit(1) + + +@watcher.command() +@click.pass_obj +@click.argument('agent', type=str) +def agent_stop(ctx: CLIContext, agent: str) -> None: + """ + Stop agent service. + (superadmin privilege required) + + \b + AGENT: Agent id. + """ + with Session() as session: + try: + status = session.AgentWatcher.agent_stop(agent) + print(status) + except Exception as e: + ctx.output.print_error(e) + sys.exit(1) + + +@watcher.command() +@click.pass_obj +@click.argument('agent', type=str) +def agent_restart(ctx: CLIContext, agent: str) -> None: + """ + Restart agent service. + (superadmin privilege required) + + \b + AGENT: Agent id. + """ + with Session() as session: + try: + status = session.AgentWatcher.agent_restart(agent) + print(status) + except Exception as e: + ctx.output.print_error(e) + sys.exit(1) diff --git a/src/ai/backend/client/cli/admin/domain.py b/src/ai/backend/client/cli/admin/domain.py new file mode 100644 index 0000000000..2da7a25f30 --- /dev/null +++ b/src/ai/backend/client/cli/admin/domain.py @@ -0,0 +1,228 @@ +import sys + +import click + +from ai.backend.cli.interaction import ask_yn +from ai.backend.client.session import Session +from ai.backend.client.func.domain import ( + _default_list_fields, + _default_detail_fields, +) +# from ai.backend.client.output.fields import domain_fields +from . import admin +from ..pretty import print_info + +from ..types import CLIContext + + +@admin.group() +def domain(): + """ + Domain administration commands. + """ + + +@domain.command() +@click.pass_obj +@click.argument('name', type=str) +def info(ctx: CLIContext, name: str) -> None: + """ + Show the information about the given domain. + If name is not give, user's own domain information will be retrieved. + """ + with Session() as session: + try: + item = session.Domain.detail(name=name) + ctx.output.print_item(item, _default_detail_fields) + except Exception as e: + ctx.output.print_error(e) + sys.exit(1) + + +@domain.command() +@click.pass_obj +def list(ctx: CLIContext) -> None: + """ + List and manage domains. 
+ (admin privilege required) + """ + with Session() as session: + try: + items = session.Domain.list() + ctx.output.print_list(items, _default_list_fields) + except Exception as e: + ctx.output.print_error(e) + sys.exit(1) + + +@domain.command() +@click.pass_obj +@click.argument('name', type=str, metavar='NAME') +@click.option('-d', '--description', type=str, default='', + help='Description of new domain') +@click.option('-i', '--inactive', is_flag=True, + help='New domain will be inactive.') +@click.option('--total-resource-slots', type=str, default='{}', + help='Set total resource slots.') +@click.option('--allowed-vfolder-hosts', type=str, multiple=True, + help='Allowed virtual folder hosts.') +@click.option('--allowed-docker-registries', type=str, multiple=True, + help='Allowed docker registries.') +def add(ctx: CLIContext, name, description, inactive, total_resource_slots, + allowed_vfolder_hosts, allowed_docker_registries): + """ + Add a new domain. + + NAME: Name of new domain. + """ + with Session() as session: + try: + data = session.Domain.create( + name, + description=description, + is_active=not inactive, + total_resource_slots=total_resource_slots, + allowed_vfolder_hosts=allowed_vfolder_hosts, + allowed_docker_registries=allowed_docker_registries, + ) + except Exception as e: + ctx.output.print_mutation_error( + e, + item_name='domain', + action_name='add', + ) + sys.exit(1) + if not data['ok']: + ctx.output.print_mutation_error( + msg=data['msg'], + item_name='domain', + action_name='add', + ) + sys.exit(1) + ctx.output.print_mutation_result( + data, + item_name='domain', + ) + + +@domain.command() +@click.pass_obj +@click.argument('name', type=str, metavar='NAME') +@click.option('--new-name', type=str, help='New name of the domain') +@click.option('--description', type=str, help='Description of the domain') +@click.option('--is-active', type=bool, help='Set domain inactive.') +@click.option('--total-resource-slots', type=str, + help='Update total resource slots.') +@click.option('--allowed-vfolder-hosts', type=str, multiple=True, + help='Allowed virtual folder hosts.') +@click.option('--allowed-docker-registries', type=str, multiple=True, + help='Allowed docker registries.') +def update(ctx: CLIContext, name, new_name, description, is_active, total_resource_slots, + allowed_vfolder_hosts, allowed_docker_registries): + """ + Update an existing domain. + + NAME: Name of new domain. + """ + with Session() as session: + try: + data = session.Domain.update( + name, + new_name=new_name, + description=description, + is_active=is_active, + total_resource_slots=total_resource_slots, + allowed_vfolder_hosts=allowed_vfolder_hosts, + allowed_docker_registries=allowed_docker_registries, + ) + except Exception as e: + ctx.output.print_mutation_error( + e, + item_name='domain', + action_name='update', + ) + sys.exit(1) + if not data['ok']: + ctx.output.print_mutation_error( + msg=data['msg'], + item_name='domain', + action_name='update', + ) + sys.exit(1) + ctx.output.print_mutation_result( + data, + extra_info={ + 'name': name, + }, + ) + + +@domain.command() +@click.pass_obj +@click.argument('name', type=str, metavar='NAME') +def delete(ctx: CLIContext, name): + """ + Inactive an existing domain. + + NAME: Name of a domain to inactive. 
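# The add/update/delete mutations above all return the same payload shape -- a
# mapping with an `ok` flag and a human-readable `msg` -- and every handler
# branches on it before printing the result.  A minimal sketch of that
# convention (the sample payloads below are made up):

def mutation_exit_code(data: dict) -> int:
    """Map a {'ok': bool, 'msg': str} mutation result to a process exit code."""
    if not data.get("ok"):
        print(f"mutation failed: {data.get('msg', 'unknown error')}")
        return 1
    return 0

assert mutation_exit_code({"ok": True, "msg": ""}) == 0
assert mutation_exit_code({"ok": False, "msg": "domain not found"}) == 1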
+ """ + with Session() as session: + try: + data = session.Domain.delete(name) + except Exception as e: + ctx.output.print_mutation_error( + e, + item_name='domain', + action_name='deletion', + ) + sys.exit(1) + if not data['ok']: + ctx.output.print_mutation_error( + msg=data['msg'], + item_name='domain', + action_name='deletion', + ) + sys.exit(1) + ctx.output.print_mutation_result( + data, + extra_info={ + 'name': name, + }, + ) + + +@domain.command() +@click.pass_obj +@click.argument('name', type=str, metavar='NAME') +def purge(ctx: CLIContext, name): + """ + Delete an existing domain. + + NAME: Name of a domain to delete. + """ + with Session() as session: + try: + if not ask_yn(): + print_info('Cancelled') + sys.exit(1) + data = session.Domain.purge(name) + except Exception as e: + ctx.output.print_mutation_error( + e, + item_name='domain', + action_name='purge', + ) + sys.exit(1) + if not data['ok']: + ctx.output.print_mutation_error( + msg=data['msg'], + item_name='domain', + action_name='purge', + ) + sys.exit(1) + ctx.output.print_mutation_result( + data, + extra_info={ + 'name': name, + }, + ) diff --git a/src/ai/backend/client/cli/admin/etcd.py b/src/ai/backend/client/cli/admin/etcd.py new file mode 100644 index 0000000000..41c363500f --- /dev/null +++ b/src/ai/backend/client/cli/admin/etcd.py @@ -0,0 +1,85 @@ +import json +import sys + +import click + +from . import admin +from ..pretty import print_pretty, print_error, print_fail +from ...session import Session + + +@admin.group() +def etcd() -> None: + """ + etcd query and manipulation commands. + (admin privilege required) + """ + + +@etcd.command() +@click.argument('key', type=str, metavar='KEY') +@click.option('-p', '--prefix', is_flag=True, default=False, + help='Get all keys prefixed with the given key.') +def get(key, prefix): + """ + Get a ETCD value(s). + + KEY: Name of ETCD key. + """ + with Session() as session: + try: + data = session.EtcdConfig.get(key, prefix) + except Exception as e: + print_error(e) + sys.exit(1) + data = json.dumps(data, indent=2) if data else 'null' + print_pretty(data) + + +@etcd.command() +@click.argument('key', type=str, metavar='KEY') +@click.argument('value', type=str, metavar='VALUE') +def set(key, value): + """ + Set new key and value on ETCD. + + KEY: Name of ETCD key. + VALUE: Value to set. + """ + with Session() as session: + try: + value = json.loads(value) + print_pretty('Value converted to a dictionary.') + except json.JSONDecodeError: + pass + try: + data = session.EtcdConfig.set(key, value) + except Exception as e: + print_error(e) + sys.exit(1) + if data.get('result', False) != 'ok': + print_fail('Unable to set key/value.') + else: + print_pretty('Successfully set key/value.') + + +@etcd.command() +@click.argument('key', type=str, metavar='KEY') +@click.option('-p', '--prefix', is_flag=True, default=False, + help='Delete all keys prefixed with the given key.') +def delete(key, prefix): + """ + Delete key(s) from ETCD. + + KEY: Name of ETCD key. 
+ """ + with Session() as session: + try: + data = session.EtcdConfig.delete(key, prefix) + except Exception as e: + print_error(e) + sys.exit(1) + if data.get('result', False) != 'ok': + print_fail('Unable to delete key/value.') + else: + print_pretty('Successfully deleted key/value.') diff --git a/src/ai/backend/client/cli/admin/group.py b/src/ai/backend/client/cli/admin/group.py new file mode 100644 index 0000000000..b028995b07 --- /dev/null +++ b/src/ai/backend/client/cli/admin/group.py @@ -0,0 +1,322 @@ +import sys +import uuid + +import click + +from ai.backend.cli.interaction import ask_yn +from ai.backend.client.session import Session +from ai.backend.client.func.group import ( + _default_list_fields, + _default_detail_fields, +) +# from ai.backend.client.output.fields import group_fields +from . import admin +from ..pretty import print_info + +from ..types import CLIContext + + +@admin.group() +def group() -> None: + """ + User group (project) administration commands + """ + + +@group.command() +@click.pass_obj +@click.argument('id_or_name', type=str) +def info(ctx: CLIContext, id_or_name: str) -> None: + """ + Show the information about the group(s) having the given name. + Two or more groups in different domains may have the same name, + so this may print out information of multiple groups if queried + by a superadmin. + + When queried with a human-readable name by a super-admin, + it may return multiple results with the same name from + different domains. + + \b + id_or_name: Group ID (UUID) or name. + """ + with Session() as session: + try: + gid = uuid.UUID(id_or_name) + except ValueError: + # interpret as name + try: + item = session.Group.from_name(id_or_name) + ctx.output.print_item(item, _default_detail_fields) + except Exception as e: + ctx.output.print_error(e) + sys.exit(1) + else: + # interpret as UUID + try: + item = session.Group.detail(gid=str(gid)) + ctx.output.print_item(item, _default_detail_fields) + except Exception as e: + ctx.output.print_error(e) + sys.exit(1) + + +@group.command() +@click.pass_obj +@click.option('-d', '--domain-name', type=str, default=None, + help='Domain name to list groups belongs to it.') +def list(ctx: CLIContext, domain_name) -> None: + """ + List groups in the given domain. + (admin privilege required) + """ + with Session() as session: + try: + items = session.Group.list(domain_name=domain_name) + ctx.output.print_list(items, _default_list_fields) + except Exception as e: + ctx.output.print_error(e) + sys.exit(1) + + +@group.command() +@click.pass_obj +@click.argument('domain_name', type=str, metavar='DOMAIN_NAME') +@click.argument('name', type=str, metavar='NAME') +@click.option('-d', '--description', type=str, default='', + help='Description of new group.') +@click.option('-i', '--inactive', is_flag=True, + help='New group will be inactive.') +@click.option('--total-resource-slots', type=str, default='{}', + help='Set total resource slots.') +@click.option('--allowed-vfolder-hosts', type=str, multiple=True, + help='Allowed virtual folder hosts.') +def add(ctx: CLIContext, domain_name, name, description, inactive, total_resource_slots, + allowed_vfolder_hosts): + """ + Add new group. A group must belong to a domain, so DOMAIN_NAME should be provided. + + \b + DOMAIN_NAME: Name of the domain where new group belongs to. + NAME: Name of new group. 
+ """ + with Session() as session: + try: + data = session.Group.create( + domain_name, name, + description=description, + is_active=not inactive, + total_resource_slots=total_resource_slots, + allowed_vfolder_hosts=allowed_vfolder_hosts, + ) + except Exception as e: + ctx.output.print_mutation_error( + e, + item_name='group', + action_name='add', + ) + sys.exit(1) + if not data['ok']: + ctx.output.print_mutation_error( + msg=data['msg'], + item_name='group', + action_name='add', + ) + sys.exit(1) + ctx.output.print_mutation_result( + data, + item_name='group', + ) + + +@group.command() +@click.pass_obj +@click.argument('gid', type=str, metavar='GROUP_ID') +@click.option('-n', '--name', type=str, help='New name of the group') +@click.option('-d', '--description', type=str, help='Description of the group') +@click.option('--is-active', type=bool, help='Set group inactive.') +@click.option('--total-resource-slots', type=str, help='Update total resource slots.') +@click.option('--allowed-vfolder-hosts', type=str, multiple=True, + help='Allowed virtual folder hosts.') +def update(ctx: CLIContext, gid, name, description, is_active, total_resource_slots, + allowed_vfolder_hosts): + """ + Update an existing group. Domain name is not necessary since group ID is unique. + + GROUP_ID: Group ID to update. + """ + with Session() as session: + try: + data = session.Group.update( + gid, + name=name, + description=description, + is_active=is_active, + total_resource_slots=total_resource_slots, + allowed_vfolder_hosts=allowed_vfolder_hosts, + ) + except Exception as e: + ctx.output.print_mutation_error( + e, + item_name='group', + action_name='update', + ) + sys.exit(1) + if not data['ok']: + ctx.output.print_mutation_error( + msg=data['msg'], + item_name='group', + action_name='update', + ) + sys.exit(1) + ctx.output.print_mutation_result( + data, + extra_info={ + 'gid': gid, + }, + ) + + +@group.command() +@click.pass_obj +@click.argument('gid', type=str, metavar='GROUP_ID') +def delete(ctx: CLIContext, gid): + """ + Inactivates the existing group. Does not actually delete it for safety. + + GROUP_ID: Group ID to inactivate. + """ + with Session() as session: + try: + data = session.Group.delete(gid) + except Exception as e: + ctx.output.print_mutation_error( + e, + item_name='group', + action_name='deletion', + ) + sys.exit(1) + if not data['ok']: + ctx.output.print_mutation_error( + msg=data['msg'], + item_name='group', + action_name='deletion', + ) + sys.exit(1) + ctx.output.print_mutation_result( + data, + extra_info={ + 'gid': gid, + }, + ) + + +@group.command() +@click.pass_obj +@click.argument('gid', type=str, metavar='GROUP_ID') +def purge(ctx: CLIContext, gid): + """ + Delete the existing group. This action cannot be undone. + + GROUP_ID: Group ID to inactivate. + """ + with Session() as session: + try: + if not ask_yn(): + print_info('Cancelled') + sys.exit(1) + data = session.Group.purge(gid) + except Exception as e: + ctx.output.print_mutation_error( + e, + item_name='group', + action_name='purge', + ) + sys.exit(1) + if not data['ok']: + ctx.output.print_mutation_error( + msg=data['msg'], + item_name='group', + action_name='purge', + ) + sys.exit(1) + ctx.output.print_mutation_result( + data, + extra_info={ + 'gid': gid, + }, + ) + + +@group.command() +@click.pass_obj +@click.argument('gid', type=str, metavar='GROUP_ID') +@click.argument('user_uuids', type=str, metavar='USER_UUIDS', nargs=-1) +def add_users(ctx: CLIContext, gid, user_uuids): + """ + Add users to a group. 
+ + \b + GROUP_ID: Group ID where users will be belong to. + USER_UUIDS: List of users' uuids to be added to the group. + """ + with Session() as session: + try: + data = session.Group.add_users(gid, user_uuids) + except Exception as e: + ctx.output.print_mutation_error( + e, + item_name='group', + action_name='add_users', + ) + sys.exit(1) + if not data['ok']: + ctx.output.print_mutation_error( + msg=data['msg'], + item_name='group', + action_name='add_users', + ) + sys.exit(1) + ctx.output.print_mutation_result( + data, + extra_info={ + 'gid': gid, + }, + ) + + +@group.command() +@click.pass_obj +@click.argument('gid', type=str, metavar='GROUP_ID') +@click.argument('user_uuids', type=str, metavar='USER_UUIDS', nargs=-1) +def remove_users(ctx: CLIContext, gid, user_uuids): + """ + Remove users from a group. + + \b + GROUP_ID: Group ID where users currently belong to. + USER_UUIDS: List of users' uuids to be removed from the group. + """ + with Session() as session: + try: + data = session.Group.remove_users(gid, user_uuids) + except Exception as e: + ctx.output.print_mutation_error( + e, + item_name='group', + action_name='users_remove', + ) + sys.exit(1) + if not data['ok']: + ctx.output.print_mutation_error( + msg=data['msg'], + item_name='group', + action_name='users_remove', + ) + sys.exit(1) + ctx.output.print_mutation_result( + data, + extra_info={ + 'gid': gid, + }, + ) diff --git a/src/ai/backend/client/cli/admin/image.py b/src/ai/backend/client/cli/admin/image.py new file mode 100644 index 0000000000..219f00093b --- /dev/null +++ b/src/ai/backend/client/cli/admin/image.py @@ -0,0 +1,120 @@ +import json +import sys + +import click +from tqdm import tqdm + +from ai.backend.client.session import Session +from ai.backend.client.func.image import ( + _default_list_fields_admin, +) +# from ai.backend.client.output.fields import image_fields +from . import admin +from ...compat import asyncio_run +from ...session import AsyncSession +from ..pretty import print_done, print_warn, print_fail, print_error + +from ..types import CLIContext + + +@admin.group() +def image() -> None: + """ + Image administration commands. + """ + + +@image.command() +@click.pass_obj +@click.option('--operation', is_flag=True, help='Get operational images only') +def list(ctx: CLIContext, operation: bool) -> None: + """ + Show the list of registered images in this cluster. + """ + with Session() as session: + try: + items = session.Image.list(operation=operation) + ctx.output.print_list(items, _default_list_fields_admin) + except Exception as e: + ctx.output.print_error(e) + sys.exit(1) + + +@image.command() +@click.option('-r', '--registry', type=str, default=None, + help='The name (usually hostname or "lablup") ' + 'of the Docker registry configured.') +def rescan(registry: str) -> None: + """ + Update the kernel image metadata from all configured docker registries. 
+ """ + + async def rescan_images_impl(registry: str) -> None: + async with AsyncSession() as session: + try: + result = await session.Image.rescan_images(registry) + except Exception as e: + print_error(e) + sys.exit(1) + if not result['ok']: + print_fail(f"Failed to begin registry scanning: {result['msg']}") + sys.exit(1) + print_done("Started updating the image metadata from the configured registries.") + bgtask_id = result['task_id'] + bgtask = session.BackgroundTask(bgtask_id) + try: + completion_msg_func = lambda: print_done("Finished registry scanning.") + with tqdm(unit='image') as pbar: + async with bgtask.listen_events() as response: + async for ev in response: + data = json.loads(ev.data) + if ev.event == 'bgtask_updated': + pbar.total = data['total_progress'] + pbar.write(data['message']) + pbar.update(data['current_progress'] - pbar.n) + elif ev.event == 'bgtask_failed': + error_msg = data['message'] + completion_msg_func = \ + lambda: print_fail(f"Error occurred: {error_msg}") + elif ev.event == 'bgtask_cancelled': + completion_msg_func = \ + lambda: print_warn("Registry scanning has been " + "cancelled in the middle.") + finally: + completion_msg_func() + + asyncio_run(rescan_images_impl(registry)) + + +@image.command() +@click.argument('alias', type=str) +@click.argument('target', type=str) +@click.option('--arch', type=str, default=None, help='Set an explicit architecture.') +def alias(alias, target, arch): + """Add an image alias.""" + with Session() as session: + try: + result = session.Image.alias_image(alias, target, arch) + except Exception as e: + print_error(e) + sys.exit(1) + if result['ok']: + print_done(f"An alias has created: {alias} -> {target}") + else: + print_fail("Aliasing has failed: {0}".format(result['msg'])) + + +@image.command() +@click.argument('alias', type=str) +def dealias(alias): + """Remove an image alias.""" + with Session() as session: + try: + result = session.Image.dealias_image(alias) + except Exception as e: + print_error(e) + sys.exit(1) + if result['ok']: + print_done(f"The alias has been removed: {alias}") + else: + print_fail("Dealiasing has failed: {0}".format(result['msg'])) diff --git a/src/ai/backend/client/cli/admin/keypair.py b/src/ai/backend/client/cli/admin/keypair.py new file mode 100644 index 0000000000..14026136d2 --- /dev/null +++ b/src/ai/backend/client/cli/admin/keypair.py @@ -0,0 +1,293 @@ +import sys + +import click + +from ai.backend.client.session import Session +from ai.backend.client.output.fields import keypair_fields +from . import admin +from ..types import CLIContext + + +@admin.group() +def keypair() -> None: + """ + KeyPair administration commands. + """ + + +@keypair.command() +@click.pass_obj +def info(ctx: CLIContext) -> None: + """ + Show the server-side information of the currently configured access key. 
+ """ + fields = [ + keypair_fields['user_id'], + keypair_fields['full_name'], + keypair_fields['access_key'], + keypair_fields['secret_key'], + keypair_fields['is_active'], + keypair_fields['is_admin'], + keypair_fields['created_at'], + keypair_fields['last_used'], + keypair_fields['resource_policy'], + keypair_fields['rate_limit'], + keypair_fields['concurrency_used'], + ] + with Session() as session: + try: + kp = session.KeyPair(session.config.access_key) + item = kp.info(fields=fields) + ctx.output.print_item(item, fields) + except Exception as e: + ctx.output.print_error(e) + sys.exit(1) + + +@keypair.command() +@click.pass_obj +@click.option('-u', '--user-id', type=str, default=None, + help='Show keypairs of this given user. [default: show all]') +@click.option('--is-active', type=bool, default=None, + help='Filter keypairs by activation.') +@click.option('--filter', 'filter_', default=None, + help='Set the query filter expression.') +@click.option('--order', default=None, + help='Set the query ordering expression.') +@click.option('--offset', default=0, + help='The index of the current page start for pagination.') +@click.option('--limit', default=None, + help='The page size for pagination.') +def list(ctx: CLIContext, user_id, is_active, filter_, order, offset, limit) -> None: + """ + List keypairs. + To show all keypairs or other user's, your access key must have the admin + privilege. + (admin privilege required) + """ + fields = [ + keypair_fields['user_id'], + keypair_fields['full_name'], + keypair_fields['access_key'], + keypair_fields['secret_key'], + keypair_fields['is_active'], + keypair_fields['is_admin'], + keypair_fields['created_at'], + keypair_fields['last_used'], + keypair_fields['resource_policy'], + keypair_fields['rate_limit'], + keypair_fields['concurrency_used'], + ] + try: + with Session() as session: + fetch_func = lambda pg_offset, pg_size: session.KeyPair.paginated_list( + is_active, + user_id=user_id, + fields=fields, + page_offset=pg_offset, + page_size=pg_size, + filter=filter_, + order=order, + ) + ctx.output.print_paginated_list( + fetch_func, + initial_page_offset=offset, + page_size=limit, + ) + except Exception as e: + ctx.output.print_error(e) + sys.exit(1) + + +@keypair.command() +@click.pass_obj +@click.argument('user-id', type=str, default=None, metavar='USERID') +@click.argument('resource-policy', type=str, default=None, metavar='RESOURCE_POLICY') +@click.option('-a', '--admin', is_flag=True, + help='Give the admin privilege to the new keypair.') +@click.option('-i', '--inactive', is_flag=True, + help='Create the new keypair in inactive state.') +@click.option('-r', '--rate-limit', type=int, default=5000, + help='Set the API query rate limit.') +def add(ctx: CLIContext, user_id, resource_policy, admin, inactive, rate_limit): + """ + Add a new keypair. + + USER_ID: User ID of a new key pair. + RESOURCE_POLICY: resource policy for new key pair. 
+ """ + with Session() as session: + try: + data = session.KeyPair.create( + user_id, + is_active=not inactive, + is_admin=admin, + resource_policy=resource_policy, + rate_limit=rate_limit) + except Exception as e: + ctx.output.print_mutation_error( + e, + item_name='keypair', + action_name='add', + ) + sys.exit(1) + if not data['ok']: + ctx.output.print_mutation_error( + msg=data['msg'], + item_name='keypair', + action_name='add', + ) + sys.exit(1) + ctx.output.print_mutation_result( + data, + item_name='keypair', + extra_info={ + 'access_key': data['keypair']['access_key'], + 'secret_key': data['keypair']['secret_key'], + }, + ) + + +@keypair.command() +@click.pass_obj +@click.argument('access_key', type=str, default=None, metavar='ACCESSKEY') +@click.option('--resource-policy', type=str, help='Resource policy for the keypair.') +@click.option('--is-admin', type=bool, help='Set admin privilege.') +@click.option('--is-active', type=bool, help='Set key pair active or not.') +@click.option('-r', '--rate-limit', type=int, help='Set the API query rate limit.') +def update(ctx: CLIContext, access_key, resource_policy, is_admin, is_active, rate_limit): + """ + Update an existing keypair. + + ACCESS_KEY: Access key of an existing key pair. + """ + with Session() as session: + try: + data = session.KeyPair.update( + access_key, + is_active=is_active, + is_admin=is_admin, + resource_policy=resource_policy, + rate_limit=rate_limit) + except Exception as e: + ctx.output.print_mutation_error( + e, + item_name='keypair', + action_name='update', + ) + sys.exit(1) + if not data['ok']: + ctx.output.print_mutation_error( + msg=data['msg'], + item_name='keypair', + action_name='update', + ) + sys.exit(1) + ctx.output.print_mutation_result( + data, + extra_info={ + 'access_key': access_key, + }, + ) + + +@keypair.command() +@click.pass_obj +@click.argument('access-key', type=str, metavar='ACCESSKEY') +def delete(ctx: CLIContext, access_key): + """ + Delete an existing keypair. + + ACCESSKEY: ACCESSKEY for a keypair to delete. + """ + with Session() as session: + try: + data = session.KeyPair.delete(access_key) + except Exception as e: + ctx.output.print_mutation_error( + e, + item_name='keypair', + action_name='deletion', + ) + sys.exit(1) + if not data['ok']: + ctx.output.print_mutation_error( + msg=data['msg'], + item_name='keypair', + action_name='deletion', + ) + sys.exit(1) + ctx.output.print_mutation_result( + data, + extra_info={ + 'access_key': access_key, + }, + ) + + +@keypair.command() +@click.pass_obj +@click.argument('access-key', type=str, metavar='ACCESSKEY') +def activate(ctx: CLIContext, access_key): + """ + Activate an inactivated keypair. + + ACCESS_KEY: Access key of an existing key pair. + """ + with Session() as session: + try: + data = session.KeyPair.activate(access_key) + except Exception as e: + ctx.output.print_mutation_error( + e, + item_name='keypair', + action_name='activation', + ) + sys.exit(1) + if not data['ok']: + ctx.output.print_mutation_error( + msg=data['msg'], + item_name='keypair', + action_name='activation', + ) + sys.exit(1) + ctx.output.print_mutation_result( + data, + extra_info={ + 'access_key': access_key, + }, + ) + + +@keypair.command() +@click.pass_obj +@click.argument('access-key', type=str, metavar='ACCESSKEY') +def deactivate(ctx: CLIContext, access_key): + """ + Deactivate an active keypair. + + ACCESS_KEY: Access key of an existing key pair. 
+ """ + with Session() as session: + try: + data = session.KeyPair.deactivate(access_key) + except Exception as e: + ctx.output.print_mutation_error( + e, + item_name='keypair', + action_name='deactivation', + ) + sys.exit(1) + if not data['ok']: + ctx.output.print_mutation_error( + msg=data['msg'], + item_name='keypair', + action_name='deactivation', + ) + sys.exit(1) + ctx.output.print_mutation_result( + data, + extra_info={ + 'access_key': access_key, + }, + ) diff --git a/src/ai/backend/client/cli/admin/license.py b/src/ai/backend/client/cli/admin/license.py new file mode 100644 index 0000000000..31f0498f8e --- /dev/null +++ b/src/ai/backend/client/cli/admin/license.py @@ -0,0 +1,38 @@ +import asyncio +import sys + +from tabulate import tabulate + +from ...session import AsyncSession +from ...request import Request +from ..pretty import print_done, print_error, print_warn +from . import admin + + +@admin.group() +def license() -> None: + """ + License administration commands. + """ + + +@license.command() +def show(): + """ + Show the license information (enterprise editions only). + """ + async def _show_license(): + async with AsyncSession(): + rqst = Request('GET', '/license') + async with rqst.fetch() as resp: + data = await resp.json() + if data['status'] == 'valid': + print_done('Your Backend.AI lincese is valid.') + print(tabulate([(k, v) for k, v in data['certificate'].items()])) + else: + print_warn('Your Backend.AI lincese is valid.') + try: + asyncio.run(_show_license()) + except Exception as e: + print_error(e) + sys.exit(1) diff --git a/src/ai/backend/client/cli/admin/manager.py b/src/ai/backend/client/cli/admin/manager.py new file mode 100644 index 0000000000..19b7141ab7 --- /dev/null +++ b/src/ai/backend/client/cli/admin/manager.py @@ -0,0 +1,207 @@ +import json +from pathlib import Path +import sys +import time + +import appdirs +import click +from tabulate import tabulate + +from ai.backend.cli.interaction import ask_yn + +from . import admin +from ..pretty import print_done, print_error, print_fail, print_info, print_wait +from ..session import Session + + +@admin.group() +def manager(): + """Set of manager control operations.""" + + +@manager.command() +def status(): + """Show the manager's current status.""" + try: + with Session() as session: + resp = session.Manager.status() + print(tabulate([('Status', 'Active Sessions'), + (resp['status'], resp['active_sessions'])], + headers='firstrow')) + except Exception as e: + print_error(e) + sys.exit(1) + + +@manager.command() +@click.option('--wait', is_flag=True, + help='Hold up freezing the manager until ' + 'there are no running sessions in the manager.') +@click.option('--force-kill', is_flag=True, + help='Kill all running sessions immediately and freeze the manager.') +def freeze(wait, force_kill): + """Freeze manager.""" + if wait and force_kill: + print('You cannot use both --wait and --force-kill options ' + 'at the same time.', file=sys.stderr) + return + try: + with Session() as session: + if wait: + while True: + resp = session.Manager.status() + active_sessions_num = resp['active_sessions'] + if active_sessions_num == 0: + break + print_wait('Waiting for all sessions terminated... 
({0} left)' + .format(active_sessions_num)) + time.sleep(3) + print_done('All sessions are terminated.') + + if force_kill: + print_wait('Killing all sessions...') + + session.Manager.freeze(force_kill=force_kill) + + if force_kill: + print_done('All sessions are killed.') + + print('Manager is successfully frozen.') + except Exception as e: + print_error(e) + sys.exit(1) + + +@manager.command() +def unfreeze(): + """Unfreeze manager.""" + try: + with Session() as session: + session.Manager.unfreeze() + print('Manager is successfully unfrozen.') + except Exception as e: + print_error(e) + sys.exit(1) + + +@admin.group() +def announcement(): + """Global announcement related commands""" + + +@announcement.command() +def get(): + """Get current announcement.""" + try: + with Session() as session: + result = session.Manager.get_announcement() + if result.get('enabled', False): + msg = result.get('message') + print(msg) + else: + print('No announcements.') + except Exception as e: + print_error(e) + sys.exit(1) + + +@announcement.command() +@click.option('-m', '--message', default=None, type=click.STRING) +def update(message): + """ + Post new announcement. + + MESSAGE: Announcement message. + """ + try: + with Session() as session: + if message is None: + message = click.edit( + "", + ) + if message is None: + print_info('Cancelled') + sys.exit(1) + session.Manager.update_announcement(enabled=True, message=message) + print_done('Posted new announcement.') + except Exception as e: + print_error(e) + sys.exit(1) + + +@announcement.command() +def delete(): + """Delete current announcement.""" + if not ask_yn(): + print_info('Cancelled.') + sys.exit(1) + try: + with Session() as session: + session.Manager.update_announcement(enabled=False) + print_done('Deleted announcement.') + except Exception as e: + print_error(e) + sys.exit(1) + + +@announcement.command() +def dismiss(): + """Do not show the same announcement again.""" + if not ask_yn(): + print_info('Cancelled.') + sys.exit(1) + try: + local_state_path = Path(appdirs.user_state_dir('backend.ai', 'Lablup')) + with open(local_state_path / 'announcement.json', 'rb') as f: + state = json.load(f) + state['dismissed'] = True + with open(local_state_path / 'announcement.json', 'w') as f: + json.dump(state, f) + print_done('Dismissed the last shown announcement.') + except (IOError, json.JSONDecodeError): + print_fail('No announcements seen yet.') + sys.exit(1) + except Exception as e: + print_error(e) + sys.exit(1) + + +@manager.group() +def scheduler(): + """ + The scheduler operation command group. + """ + pass + + +@scheduler.command() +@click.argument('agent_ids', nargs=-1) +def include_agents(agent_ids): + """ + Include agents in scheduling, meaning that the given agents + will be considered to be ready for creating new session containers. + """ + try: + with Session() as session: + session.Manager.scheduler_op('include-agents', agent_ids) + print_done('The given agents now accepts new sessions.') + except Exception as e: + print_error(e) + sys.exit(1) + + +@scheduler.command() +@click.argument('agent_ids', nargs=-1) +def exclude_agents(agent_ids): + """ + Exclude agents from scheduling, meaning that the given agents + will no longer start new sessions unless they are "included" again, + regardless of their restarts and rejoining events. 
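# `manager freeze --wait` above polls Manager.status() every three seconds until
# `active_sessions` reaches zero.  The polling loop in isolation (get_status is
# a hypothetical stand-in for session.Manager.status()):

import time


def wait_until_drained(get_status, interval: float = 3.0) -> None:
    """Block until the manager reports zero active sessions."""
    while True:
        active = get_status()["active_sessions"]
        if active == 0:
            break
        print(f"Waiting for all sessions to terminate... ({active} left)")
        time.sleep(interval)

# Example with a canned sequence of status payloads:
_seq = iter([{"active_sessions": 2}, {"active_sessions": 1}, {"active_sessions": 0}])
wait_until_drained(lambda: next(_seq), interval=0.0)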
+ """ + try: + with Session() as session: + session.Manager.scheduler_op('exclude-agents', agent_ids) + print_done('The given agents will no longer start new sessions.') + except Exception as e: + print_error(e) + sys.exit(1) diff --git a/src/ai/backend/client/cli/admin/resource.py b/src/ai/backend/client/cli/admin/resource.py new file mode 100644 index 0000000000..480141e0e5 --- /dev/null +++ b/src/ai/backend/client/cli/admin/resource.py @@ -0,0 +1,144 @@ +import click + +from . import admin +from ...session import Session +from ..pretty import print_error + + +@admin.group() +def resource() -> None: + """ + Resource administration commands. + """ + + +@resource.command() +def query_slots(): + """ + Get available resource slots. + """ + with Session() as session: + try: + ret = session.Resource.get_resource_slots() + for key, value in ret.items(): + print(key, '(' + value + ')') + except Exception as e: + print_error(e) + + +@resource.command() +def vfolder_types(): + """ + Get available vfolder types. + """ + with Session() as session: + try: + ret = session.Resource.get_vfolder_types() + for t in ret: + print(t) + except Exception as e: + print_error(e) + + +@resource.command() +def docker_registries(): + """ + Get registered docker registries. + """ + with Session() as session: + try: + ret = session.Resource.get_docker_registries() + for t in ret: + print(t) + except Exception as e: + print_error(e) + + +@resource.command() +def recalculate_usage(): + """ + Re-calculate resource occupation by sessions. + + Sometime, reported allocated resources is deviated from the actual value. + By executing this command, the discrepancy will be corrected with real value. + """ + with Session() as session: + try: + session.Resource.recalculate_usage() + print('Resource allocation is re-calculated.') + except Exception as e: + print_error(e) + + +@resource.command() +@click.argument('month', metavar='MONTH') +@click.argument('groups', metavar='GROUP_IDS', nargs=-1) +def usage_per_month(month, groups): + """ + Get session usage stats of target groups for specific month. + + \b + MONTH: Target month to get usage (yyyymm). + GROUP_IDS: IDs of target groups to get usage (UUID). 
+ """ + with Session() as session: + ret = session.Resource.usage_per_month(month, list(groups)) + for item in ret: + print('Group:', item['g_name'] + ' (' + item['g_id'] + ')') + print(' Domain:', item['domain_name']) + print(' Total Allocated:', item['g_smp'], 'core(s)', '/', + item['g_mem_allocated'], 'mem (bytes)') + print(' Total CPU / Memory Used:', item['g_cpu_used'], '(s)', '/', + item['g_mem_used'], '(bytes)') + print(' Total I/O Read / Write:', item['g_io_read'], '/', + item['g_io_write'], '(bytes)') + print(' GPU Devices:', item['g_device_type']) + print(' Containers (' + str(len(item['c_infos'])) + '):') + for cinfo in item['c_infos']: + print(' Identity:', cinfo['name'], '/', cinfo['access_key']) + print(' Image:', cinfo['image_name']) + print(' Duration:', cinfo['used_days'], 'day(s)', + '(' + cinfo['created_at'] + ' ~ ' + cinfo['terminated_at'] + ')') + print(' Allocated:', cinfo['smp'], 'core(s)', '/', + cinfo['mem_allocated'], 'mem (bytes)') + print(' CPU / Memory Used:', cinfo['io_read'], '/', cinfo['io_write'], '(bytes)') + print(' I/O Read / Write:', cinfo['io_read'], '/', cinfo['io_write'], '(bytes)') + print(' NFS mounted:', cinfo['nfs']) + print(' GPU Device:', cinfo['device_type']) + print(' ----------------------------------------') + print() + + +@resource.command() +@click.argument('group') +@click.argument('start_date') +@click.argument('end_date') +def usage_per_period(group, start_date, end_date): + with Session() as session: + item = session.Resource.usage_per_period(group, start_date, end_date) + if 'g_id' in item: + print('Group:', item['g_name'] + ' (' + item['g_id'] + ')') + print(' Domain:', item['domain_name']) + print(' Total Allocated:', item['g_smp'], 'core(s)', '/', + item['g_mem_allocated'], 'mem (bytes)') + print(' Total CPU / Memory Used:', item['g_cpu_used'], '(s)', '/', + item['g_mem_used'], '(bytes)') + print(' Total I/O Read / Write:', item['g_io_read'], '/', + item['g_io_write'], '(bytes)') + print(' GPU Devices:', item['g_device_type']) + print(' Containers (' + str(len(item['c_infos'])) + '):') + for cinfo in item['c_infos']: + print(' Identity:', cinfo['name'], '/', cinfo['access_key']) + print(' Image:', cinfo['image_name']) + print(' Duration:', cinfo['used_days'], 'day(s)', + '(' + cinfo['created_at'] + ' ~ ' + cinfo['terminated_at'] + ')') + print(' Allocated:', cinfo['smp'], 'core(s)', '/', + cinfo['mem_allocated'], 'mem (bytes)') + print(' CPU / Memory Used:', cinfo['io_read'], '/', cinfo['io_write'], '(bytes)') + print(' I/O Read / Write:', cinfo['io_read'], '/', cinfo['io_write'], '(bytes)') + print(' NFS mounted:', cinfo['nfs']) + print(' GPU Device:', cinfo['device_type']) + print(' ----------------------------------------') + print() + else: + print('No usage information during the period.') diff --git a/src/ai/backend/client/cli/admin/resource_policy.py b/src/ai/backend/client/cli/admin/resource_policy.py new file mode 100644 index 0000000000..ff09775ffc --- /dev/null +++ b/src/ai/backend/client/cli/admin/resource_policy.py @@ -0,0 +1,220 @@ +import sys + +import click + +from ai.backend.cli.interaction import ask_yn +from ai.backend.client.session import Session +from ai.backend.client.func.keypair_resource_policy import ( + _default_list_fields, + _default_detail_fields, +) +# from ai.backend.client.output.fields import keypair_resource_policy_fields +from . 
import admin +from ..pretty import print_info + +from ..types import CLIContext + + +@admin.group() +def keypair_resource_policy() -> None: + """ + KeyPair resource policy administration commands. + """ + + +@keypair_resource_policy.command() +@click.pass_obj +@click.argument('name', type=str) +def info(ctx: CLIContext, name: str) -> None: + """ + Show details about a keypair resource policy. When `name` option is omitted, the + resource policy for the current access_key will be returned. + """ + with Session() as session: + try: + rp = session.KeypairResourcePolicy(session.config.access_key) + item = rp.info(name) + ctx.output.print_item(item, _default_detail_fields) + except Exception as e: + ctx.output.print_error(e) + sys.exit(1) + + +@keypair_resource_policy.command() +@click.pass_obj +def list(ctx): + """ + List and manage keypair resource policies. + (admin privilege required) + """ + with Session() as session: + try: + items = session.KeypairResourcePolicy.list() + ctx.output.print_list(items, _default_list_fields) + except Exception as e: + ctx.output.print_error(e) + sys.exit(1) + + +@keypair_resource_policy.command() +@click.pass_obj +@click.argument('name', type=str, default=None, metavar='NAME') +@click.option('--default-for-unspecified', type=str, default='UNLIMITED', + help='Default behavior for unspecified resources: ' + 'LIMITED, UNLIMITED') +@click.option('--total-resource-slots', type=str, default='{}', + help='Set total resource slots.') +@click.option('--max-concurrent-sessions', type=int, default=30, + help='Number of maximum concurrent sessions.') +@click.option('--max-containers-per-session', type=int, default=1, + help='Number of maximum containers per session.') +@click.option('--max-vfolder-count', type=int, default=10, + help='Number of maximum virtual folders allowed.') +@click.option('--max-vfolder-size', type=int, default=0, + help='Maximum virtual folder size (future plan).') +@click.option('--idle-timeout', type=int, default=1800, + help='The maximum period of time allowed for kernels to wait ' + 'further requests.') +# @click.option('--allowed-vfolder-hosts', type=click.Tuple(str), default=['local'], +# help='Locations to create virtual folders.') +@click.option('--allowed-vfolder-hosts', default=['local'], + help='Locations to create virtual folders.') +def add(ctx: CLIContext, name, default_for_unspecified, total_resource_slots, max_concurrent_sessions, + max_containers_per_session, max_vfolder_count, max_vfolder_size, + idle_timeout, allowed_vfolder_hosts): + """ + Add a new keypair resource policy. + + NAME: NAME of a new keypair resource policy. 
+ """ + with Session() as session: + try: + data = session.KeypairResourcePolicy.create( + name, + default_for_unspecified=default_for_unspecified, + total_resource_slots=total_resource_slots, + max_concurrent_sessions=max_concurrent_sessions, + max_containers_per_session=max_containers_per_session, + max_vfolder_count=max_vfolder_count, + max_vfolder_size=max_vfolder_size, + idle_timeout=idle_timeout, + allowed_vfolder_hosts=allowed_vfolder_hosts, + ) + except Exception as e: + ctx.output.print_mutation_error( + e, + item_name='resource_policy', + action_name='add', + ) + sys.exit(1) + if not data['ok']: + ctx.output.print_mutation_error( + msg=data['msg'], + item_name='resource_policy', + action_name='add', + ) + sys.exit(1) + ctx.output.print_mutation_result( + data, + item_name='resource_policy', + ) + + +@keypair_resource_policy.command() +@click.pass_obj +@click.argument('name', type=str, default=None, metavar='NAME') +@click.option('--default-for-unspecified', type=str, + help='Default behavior for unspecified resources: ' + 'LIMITED, UNLIMITED') +@click.option('--total-resource-slots', type=str, + help='Set total resource slots.') +@click.option('--max-concurrent-sessions', type=int, + help='Number of maximum concurrent sessions.') +@click.option('--max-containers-per-session', type=int, + help='Number of maximum containers per session.') +@click.option('--max-vfolder-count', type=int, + help='Number of maximum virtual folders allowed.') +@click.option('--max-vfolder-size', type=int, + help='Maximum virtual folder size (future plan).') +@click.option('--idle-timeout', type=int, + help='The maximum period of time allowed for kernels to wait ' + 'further requests.') +@click.option('--allowed-vfolder-hosts', help='Locations to create virtual folders.') +def update(ctx: CLIContext, name, default_for_unspecified, total_resource_slots, + max_concurrent_sessions, max_containers_per_session, max_vfolder_count, + max_vfolder_size, idle_timeout, allowed_vfolder_hosts): + """ + Update an existing keypair resource policy. + + NAME: NAME of a keypair resource policy to update. + """ + with Session() as session: + try: + data = session.KeypairResourcePolicy.update( + name, + default_for_unspecified=default_for_unspecified, + total_resource_slots=total_resource_slots, + max_concurrent_sessions=max_concurrent_sessions, + max_containers_per_session=max_containers_per_session, + max_vfolder_count=max_vfolder_count, + max_vfolder_size=max_vfolder_size, + idle_timeout=idle_timeout, + allowed_vfolder_hosts=allowed_vfolder_hosts, + ) + except Exception as e: + ctx.output.print_mutation_error( + e, + item_name='resource_policy', + action_name='update', + ) + sys.exit(1) + if not data['ok']: + ctx.output.print_mutation_error( + msg=data['msg'], + item_name='resource_policy', + action_name='update', + ) + sys.exit(1) + ctx.output.print_mutation_result( + data, + extra_info={ + 'name': name, + }, + ) + + +@keypair_resource_policy.command() +@click.pass_obj +@click.argument('name', type=str, default=None, metavar='NAME') +def delete(ctx: CLIContext, name): + """ + Delete a keypair resource policy. + + NAME: NAME of a keypair resource policy to delete. 
+ """ + with Session() as session: + if not ask_yn(): + print_info('Cancelled.') + sys.exit(1) + try: + data = session.KeypairResourcePolicy.delete(name) + except Exception as e: + ctx.output.print_mutation_error( + e, + item_name='resource_policy', + action_name='deletion', + ) + sys.exit(1) + if not data['ok']: + ctx.output.print_mutation_error( + msg=data['msg'], + item_name='resource_policy', + action_name='deletion', + ) + sys.exit(1) + ctx.output.print_mutation_result( + data, + extra_info={ + 'name': name, + }, + ) diff --git a/src/ai/backend/client/cli/admin/scaling_group.py b/src/ai/backend/client/cli/admin/scaling_group.py new file mode 100644 index 0000000000..adcbeb3e17 --- /dev/null +++ b/src/ai/backend/client/cli/admin/scaling_group.py @@ -0,0 +1,285 @@ +import sys + +import click + +from ai.backend.client.session import Session +from ai.backend.client.func.scaling_group import ( + _default_list_fields, + _default_detail_fields, +) +from ai.backend.client.output.fields import scaling_group_fields +from . import admin +from ..params import JSONParamType +from ..types import CLIContext + + +@admin.group() +def scaling_group() -> None: + """ + Scaling group (resource group) administration commands. + """ + + +@scaling_group.command() +@click.pass_obj +@click.argument('group', type=str, metavar='GROUP_NAME') +def get_available(ctx: CLIContext, group: str) -> None: + with Session() as session: + try: + items = session.ScalingGroup.list_available(group) + ctx.output.print_list(items, [scaling_group_fields['name']]) + except Exception as e: + ctx.output.print_error(e) + sys.exit(1) + + +@scaling_group.command() +@click.pass_obj +@click.argument('name', type=str) +def info(ctx: CLIContext, name: str) -> None: + """ + Show the information about the given scaling group. + (superadmin privilege required) + """ + with Session() as session: + try: + item = session.ScalingGroup.detail(name=name) + ctx.output.print_item(item, _default_detail_fields) + except Exception as e: + ctx.output.print_error(e) + sys.exit(1) + + +@scaling_group.command() +@click.pass_obj +def list(ctx: CLIContext) -> None: + """ + List and manage scaling groups. + (superadmin privilege required) + """ + with Session() as session: + try: + items = session.ScalingGroup.list() + ctx.output.print_list(items, _default_list_fields) + except Exception as e: + ctx.output.print_error(e) + sys.exit(1) + + +@scaling_group.command() +@click.pass_obj +@click.argument('name', type=str, metavar='NAME') +@click.option('-d', '--description', type=str, default='', + help='Description of new scaling group') +@click.option('-i', '--inactive', is_flag=True, + help='New scaling group will be inactive.') +@click.option('--driver', type=str, default='static', + help='Set driver.') +@click.option('--driver-opts', type=JSONParamType(), default='{}', + help='Set driver options as a JSON string.') +@click.option('--scheduler', type=str, default='fifo', + help='Set scheduler.') +@click.option('--scheduler-opts', type=JSONParamType(), default='{}', + help='Set scheduler options as a JSON string.') +def add(ctx: CLIContext, name, description, inactive, + driver, driver_opts, scheduler, scheduler_opts): + """ + Add a new scaling group. + + NAME: Name of new scaling group. 
+ """ + with Session() as session: + try: + data = session.ScalingGroup.create( + name, + description=description, + is_active=not inactive, + driver=driver, + driver_opts=driver_opts, + scheduler=scheduler, + scheduler_opts=scheduler_opts, + ) + except Exception as e: + ctx.output.print_mutation_error( + e, + item_name='scaling_group', + action_name='add', + ) + sys.exit(1) + if not data['ok']: + ctx.output.print_mutation_error( + msg=data['msg'], + item_name='scaling_group', + action_name='add', + ) + sys.exit(1) + ctx.output.print_mutation_result( + data, + item_name='scaling_group', + ) + + +@scaling_group.command() +@click.pass_obj +@click.argument('name', type=str, metavar='NAME') +@click.option('-d', '--description', type=str, default='', + help='Description of new scaling group') +@click.option('-i', '--inactive', is_flag=True, + help='New scaling group will be inactive.') +@click.option('--driver', type=str, default='static', + help='Set driver.') +@click.option('--driver-opts', type=JSONParamType(), default=None, + help='Set driver options as a JSON string.') +@click.option('--scheduler', type=str, default='fifo', + help='Set scheduler.') +@click.option('--scheduler-opts', type=JSONParamType(), default=None, + help='Set scheduler options as a JSON string.') +def update(ctx: CLIContext, name, description, inactive, + driver, driver_opts, scheduler, scheduler_opts): + """ + Update existing scaling group. + + NAME: Name of new scaling group. + """ + with Session() as session: + try: + data = session.ScalingGroup.update( + name, + description=description, + is_active=not inactive, + driver=driver, + driver_opts=driver_opts, + scheduler=scheduler, + scheduler_opts=scheduler_opts, + ) + except Exception as e: + ctx.output.print_mutation_error( + e, + item_name='scaling_group', + action_name='update', + ) + sys.exit(1) + if not data['ok']: + ctx.output.print_mutation_error( + msg=data['msg'], + item_name='scaling_group', + action_name='update', + ) + sys.exit(1) + ctx.output.print_mutation_result( + data, + extra_info={ + 'name': name, + }, + ) + + +@scaling_group.command() +@click.pass_obj +@click.argument('name', type=str, metavar='NAME') +def delete(ctx: CLIContext, name): + """ + Delete an existing scaling group. + + NAME: Name of a scaling group to delete. + """ + with Session() as session: + try: + data = session.ScalingGroup.delete(name) + except Exception as e: + ctx.output.print_mutation_error( + e, + item_name='scaling_group', + action_name='deletion', + ) + sys.exit(1) + if not data['ok']: + ctx.output.print_mutation_error( + msg=data['msg'], + item_name='scaling_group', + action_name='deletion', + ) + sys.exit(1) + ctx.output.print_mutation_result( + data, + extra_info={ + 'name': name, + }, + ) + + +@scaling_group.command() +@click.pass_obj +@click.argument('scaling_group', type=str, metavar='SCALING_GROUP') +@click.argument('domain', type=str, metavar='DOMAIN') +def associate_scaling_group(ctx: CLIContext, scaling_group, domain): + """ + Associate a domain with a scaling_group. + + \b + SCALING_GROUP: The name of a scaling group. + DOMAIN: The name of a domain. 
+ """ + with Session() as session: + try: + data = session.ScalingGroup.associate_domain(scaling_group, domain) + except Exception as e: + ctx.output.print_mutation_error( + e, + item_name='scaling_group', + action_name='scaling_group_association', + ) + sys.exit(1) + if not data['ok']: + ctx.output.print_mutation_error( + msg=data['msg'], + item_name='scaling_group', + action_name='scaling_group_association', + ) + sys.exit(1) + ctx.output.print_mutation_result( + data, + extra_info={ + 'detail_msg': 'Scaling group {} is assocatiated with domain {}.' + .format(scaling_group, domain), + }, + ) + + +@scaling_group.command() +@click.pass_obj +@click.argument('scaling_group', type=str, metavar='SCALING_GROUP') +@click.argument('domain', type=str, metavar='DOMAIN') +def dissociate_scaling_group(ctx: CLIContext, scaling_group, domain): + """ + Dissociate a domain from a scaling_group. + + \b + SCALING_GROUP: The name of a scaling group. + DOMAIN: The name of a domain. + """ + with Session() as session: + try: + data = session.ScalingGroup.dissociate_domain(scaling_group, domain) + except Exception as e: + ctx.output.print_mutation_error( + e, + item_name='scaling_group', + action_name='scaling_group_dissociation', + ) + sys.exit(1) + if not data['ok']: + ctx.output.print_mutation_error( + msg=data['msg'], + item_name='scaling_group', + action_name='scaling_group_dissociation', + ) + sys.exit(1) + ctx.output.print_mutation_result( + data, + item_name='scaling_group', + extra_info={ + 'detail_msg': 'Scaling group {} is dissociated from domain {}.' + .format(scaling_group, domain), + }, + ) diff --git a/src/ai/backend/client/cli/admin/session.py b/src/ai/backend/client/cli/admin/session.py new file mode 100644 index 0000000000..ef437a1002 --- /dev/null +++ b/src/ai/backend/client/cli/admin/session.py @@ -0,0 +1,264 @@ +from __future__ import annotations + +import sys +from typing import ( + Any, + Dict, + List, +) +import uuid + +import click + +from ai.backend.client.session import Session +from ai.backend.client.output.fields import session_fields, session_fields_v5 +from ai.backend.client.output.types import FieldSpec +from . import admin +from ..main import main +from ..pretty import print_fail +from ..session import session as user_session +from ..types import CLIContext + + +SessionItem = Dict[str, Any] + + +@admin.group() +def session() -> None: + """ + Session administration commands. + """ + + +def _list_cmd(name: str = "list", docs: str = None): + + @click.pass_obj + @click.option('-s', '--status', default=None, + type=click.Choice([ + 'PENDING', 'SCHEDULED', + 'PREPARING', 'BUILDING', 'RUNNING', 'RESTARTING', + 'RESIZING', 'SUSPENDED', 'TERMINATING', + 'TERMINATED', 'ERROR', 'CANCELLED', + 'ALL', # special case + ]), + help='Filter by the given status') + @click.option('--access-key', type=str, default=None, + help='Get sessions for a specific access key ' + '(only works if you are a super-admin)') + @click.option('--name-only', is_flag=True, help='Display session names only.') + @click.option('--dead', is_flag=True, + help='Filter only dead sessions. Ignores --status option.') + @click.option('--running', is_flag=True, + help='Filter only scheduled and running sessions. 
Ignores --status option.') + @click.option('--detail', is_flag=True, help='Show more details using more columns.') + @click.option('-f', '--format', default=None, help='Display only specified fields.') + @click.option('--plain', is_flag=True, + help='Display the session list without decorative line drawings and the header.') + @click.option('--filter', 'filter_', default=None, help='Set the query filter expression.') + @click.option('--order', default=None, help='Set the query ordering expression.') + @click.option('--offset', default=0, type=int, + help='The index of the current page start for pagination.') + @click.option('--limit', default=None, type=int, + help='The page size for pagination.') + def list( + ctx: CLIContext, + status: str | None, + access_key: str | None, + name_only: str | None, + dead: bool, + running: bool, + detail: bool, + format: str | None, + plain: bool, + filter_: str | None, + order: str | None, + offset: int, + limit: int | None, + ) -> None: + """ + List and manage compute sessions. + """ + fields: List[FieldSpec] = [] + with Session() as session: + is_admin = session.KeyPair(session.config.access_key).info()['is_admin'] + try: + fields.append(session_fields['name']) + if is_admin: + fields.append(session_fields['access_key']) + except Exception as e: + ctx.output.print_error(e) + sys.exit(1) + if name_only: + pass + elif format is not None: + options = format.split(',') + for opt in options: + if opt not in session_fields: + ctx.output.print_fail(f"There is no such format option: {opt}") + sys.exit(1) + fields = [ + session_fields[opt] for opt in options + ] + else: + if session.api_version[0] >= 6: + fields.append(session_fields['session_id']) + fields.extend([ + session_fields['group_name'], + session_fields['kernel_id'], + session_fields['image'], + session_fields['type'], + session_fields['status'], + session_fields['status_info'], + session_fields['status_changed'], + session_fields['result'], + ]) + if detail: + fields.extend([ + session_fields['tag'], + session_fields['created_at'], + session_fields['occupied_slots'], + ]) + + no_match_name = None + if status is None: + status = ",".join([ + "PENDING", + "SCHEDULED", + "PREPARING", + "PULLING", + "RUNNING", + "RESTARTING", + "TERMINATING", + "RESIZING", + "SUSPENDED", + "ERROR", + ]) + no_match_name = 'active' + if running: + status = ",".join([ + "PREPARING", + "PULLING", + "RUNNING", + ]) + no_match_name = 'running' + if dead: + status = ",".join([ + "CANCELLED", + "TERMINATED", + ]) + no_match_name = 'dead' + if status == 'ALL': + status = ",".join([ + "PENDING", + "SCHEDULED", + "PREPARING", + "PULLING", + "RUNNING", + "RESTARTING", + "TERMINATING", + "RESIZING", + "SUSPENDED", + "ERROR", + "CANCELLED", + "TERMINATED", + ]) + no_match_name = 'in any status' + if no_match_name is None: + no_match_name = status.lower() + + try: + with Session() as session: + fetch_func = lambda pg_offset, pg_size: session.ComputeSession.paginated_list( + status, access_key, + fields=fields, + page_offset=pg_offset, + page_size=pg_size, + filter=filter_, + order=order, + ) + ctx.output.print_paginated_list( + fetch_func, + initial_page_offset=offset, + page_size=limit, + ) + except Exception as e: + ctx.output.print_error(e) + sys.exit(1) + + list.__name__ = name + if docs is not None: + list.__doc__ = docs + return list + + +# Make it available as: +# - backend.ai ps +# - backend.ai admin session list +main.command()(_list_cmd(name="ps", docs="Alias of \"session list\"")) 
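+# (Each _list_cmd() call returns a fresh function object, so the registrations
+#  above and below each get an independent Click command instance with its own
+#  name and docstring.)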
+user_session.command()(_list_cmd(docs="Alias of \"admin session list\"")) +session.command()(_list_cmd()) + + +def _info_cmd(docs: str = None): + + @click.pass_obj + @click.argument('session_id', metavar='SESSID') + def info(ctx: CLIContext, session_id: str) -> None: + """ + Show detailed information for a running compute session. + """ + with Session() as session_: + fields = [ + session_fields['name'], + ] + if session_.api_version[0] >= 6: + fields.append(session_fields['session_id']) + fields.append(session_fields['kernel_id']) + fields.extend([ + session_fields['image'], + session_fields['tag'], + session_fields['created_at'], + session_fields['terminated_at'], + session_fields['status'], + session_fields['status_info'], + session_fields['status_data'], + session_fields['occupied_slots'], + ]) + if session_.api_version[0] >= 6: + fields.append(session_fields['containers']) + else: + fields.append(session_fields_v5['containers']) + fields.append(session_fields['dependencies']) + q = 'query($id: UUID!) {' \ + ' compute_session(id: $id) {' \ + ' $fields' \ + ' }' \ + '}' + try: + uuid.UUID(session_id) + except ValueError: + print_fail("In API v5 or later, the session ID must be given in the UUID format.") + sys.exit(1) + v = {'id': session_id} + q = q.replace('$fields', ' '.join(f.field_ref for f in fields)) + try: + resp = session_.Admin.query(q, v) + except Exception as e: + ctx.output.print_error(e) + sys.exit(1) + if resp['compute_session'] is None: + if session_.api_version[0] < 5: + ctx.output.print_fail('There is no such running compute session.') + else: + ctx.output.print_fail('There is no such compute session.') + sys.exit(1) + ctx.output.print_item(resp['compute_session'], fields) + + if docs is not None: + info.__doc__ = docs + return info + + +main.command()(_info_cmd(docs="Alias of \"session info\"")) +user_session.command()(_info_cmd(docs="Alias of \"admin session info\"")) +session.command()(_info_cmd()) diff --git a/src/ai/backend/client/cli/admin/storage.py b/src/ai/backend/client/cli/admin/storage.py new file mode 100644 index 0000000000..ce9f79913b --- /dev/null +++ b/src/ai/backend/client/cli/admin/storage.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +import sys + +import click + +from ai.backend.client.session import Session +from ai.backend.client.output.fields import storage_fields +from . import admin +from ..types import CLIContext + + +@admin.group() +def storage() -> None: + """ + Storage proxy administration commands. + """ + + +@storage.command() +@click.pass_obj +@click.argument('vfolder_host') +def info(ctx: CLIContext, vfolder_host: str) -> None: + """ + Show the information about the given storage volume. 
+ (super-admin privilege required) + """ + fields = [ + storage_fields['id'], + storage_fields['backend'], + storage_fields['capabilities'], + storage_fields['path'], + storage_fields['fsprefix'], + storage_fields['hardware_metadata'], + storage_fields['performance_metric'], + ] + with Session() as session: + try: + item = session.Storage.detail( + vfolder_host=vfolder_host, + fields=fields, + ) + ctx.output.print_item(item, fields) + except Exception as e: + ctx.output.print_error(e) + sys.exit(1) + + +@storage.command() +@click.pass_obj +@click.option('--filter', 'filter_', default=None, + help='Set the query filter expression.') +@click.option('--order', default=None, + help='Set the query ordering expression.') +@click.option('--offset', default=0, + help='The index of the current page start for pagination.') +@click.option('--limit', default=None, + help='The page size for pagination.') +def list(ctx: CLIContext, filter_, order, offset, limit) -> None: + """ + List storage volumes. + (super-admin privilege required) + """ + fields = [ + storage_fields['id'], + storage_fields['backend'], + storage_fields['capabilities'], + ] + try: + with Session() as session: + fetch_func = lambda pg_offset, pg_size: session.Storage.paginated_list( + fields=fields, + page_offset=pg_offset, + page_size=pg_size, + filter=filter_, + order=order, + ) + ctx.output.print_paginated_list( + fetch_func, + initial_page_offset=offset, + page_size=limit, + ) + except Exception as e: + ctx.output.print_error(e) + sys.exit(1) diff --git a/src/ai/backend/client/cli/admin/user.py b/src/ai/backend/client/cli/admin/user.py new file mode 100644 index 0000000000..2abaf5b9ab --- /dev/null +++ b/src/ai/backend/client/cli/admin/user.py @@ -0,0 +1,284 @@ +from __future__ import annotations + +import sys + +import click + +from ai.backend.cli.interaction import ask_yn +from ai.backend.client.session import Session +from ai.backend.client.output.fields import user_fields +from ..pretty import print_info +from ..types import CLIContext +from . import admin + + +@admin.group() +def user() -> None: + """ + User administration commands. + """ + + +@user.command() +@click.pass_obj +@click.option('-e', '--email', type=str, default=None, + help='Email of a user to display.') +def info(ctx: CLIContext, email: str) -> None: + """ + Show the information about the given user by email. If email is not give, + requester's information will be displayed. 
+ """ + fields = [ + user_fields['uuid'], + user_fields['username'], + user_fields['role'], + user_fields['email'], + user_fields['full_name'], + user_fields['need_password_change'], + user_fields['status'], + user_fields['status_info'], + user_fields['created_at'], + user_fields['domain_name'], + user_fields['groups'], + ] + with Session() as session: + try: + item = session.User.detail(email=email, fields=fields) + ctx.output.print_item(item, fields=fields) + except Exception as e: + ctx.output.print_error(e) + sys.exit(1) + + +@user.command() +@click.pass_obj +@click.option('-s', '--status', type=str, default=None, + help='Filter users in a specific state (active, inactive, deleted, before-verification).') +@click.option('-g', '--group', type=str, default=None, + help='Filter by group ID.') +@click.option('--filter', 'filter_', default=None, + help='Set the query filter expression.') +@click.option('--order', default=None, + help='Set the query ordering expression.') +@click.option('--offset', default=0, + help='The index of the current page start for pagination.') +@click.option('--limit', default=None, + help='The page size for pagination.') +def list(ctx: CLIContext, status, group, filter_, order, offset, limit) -> None: + """ + List users. + (admin privilege required) + """ + fields = [ + user_fields['uuid'], + user_fields['username'], + user_fields['role'], + user_fields['email'], + user_fields['full_name'], + user_fields['need_password_change'], + user_fields['status'], + user_fields['status_info'], + user_fields['created_at'], + user_fields['domain_name'], + user_fields['groups'], + ] + try: + with Session() as session: + fetch_func = lambda pg_offset, pg_size: session.User.paginated_list( + status, group, + fields=fields, + page_offset=pg_offset, + page_size=pg_size, + filter=filter_, + order=order, + ) + ctx.output.print_paginated_list( + fetch_func, + initial_page_offset=offset, + page_size=limit, + ) + except Exception as e: + ctx.output.print_error(e) + sys.exit(1) + + +@user.command() +@click.pass_obj +@click.argument('domain_name', type=str, metavar='DOMAIN_NAME') +@click.argument('email', type=str, metavar='EMAIL') +@click.argument('password', type=str, metavar='PASSWORD') +@click.option('-u', '--username', type=str, default='', help='Username.') +@click.option('-n', '--full-name', type=str, default='', help='Full name.') +@click.option('-r', '--role', type=str, default='user', + help='Role of the user. One of (admin, user, monitor).') +@click.option('-s', '--status', type=str, default='active', + help='Account status. One of (active, inactive, deleted, before-verification).') +@click.option('--need-password-change', is_flag=True, + help='Flag indicate that user needs to change password. ' + 'Useful when admin manually create password.') +@click.option('--description', type=str, default='', help='Description of the user.') +def add(ctx: CLIContext, domain_name, email, password, username, full_name, role, status, + need_password_change, description): + """ + Add new user. A user must belong to a domain, so DOMAIN_NAME should be provided. + + \b + DOMAIN_NAME: Name of the domain where new user belongs to. + EMAIL: Email of new user. + PASSWORD: Password of new user. 
+ """ + with Session() as session: + try: + data = session.User.create( + domain_name, email, password, + username=username, full_name=full_name, role=role, + status=status, + need_password_change=need_password_change, + description=description, + ) + except Exception as e: + ctx.output.print_mutation_error( + e, + item_name='user', + action_name='add', + ) + sys.exit(1) + if not data['ok']: + ctx.output.print_mutation_error( + msg=data['msg'], + item_name='user', + action_name='add', + ) + sys.exit(1) + ctx.output.print_mutation_result( + data, + item_name='user', + ) + + +@user.command() +@click.pass_obj +@click.argument('email', type=str, metavar='EMAIL') +@click.option('-p', '--password', type=str, help='Password.') +@click.option('-u', '--username', type=str, help='Username.') +@click.option('-n', '--full-name', type=str, help='Full name.') +@click.option('-d', '--domain-name', type=str, help='Domain name.') +@click.option('-r', '--role', type=str, default='user', + help='Role of the user. One of (admin, user, monitor).') +@click.option('-s', '--status', type=str, + help='Account status. One of (active, inactive, deleted, before-verification).') +@click.option('--need-password-change', is_flag=True, + help='Flag indicate that user needs to change password. ' + 'Useful when admin manually create password.') +@click.option('--description', type=str, default='', help='Description of the user.') +def update(ctx: CLIContext, email, password, username, full_name, domain_name, role, status, + need_password_change, description): + """ + Update an existing user. + + EMAIL: Email of user to update. + """ + with Session() as session: + try: + data = session.User.update( + email, + password=password, username=username, full_name=full_name, + domain_name=domain_name, + role=role, status=status, need_password_change=need_password_change, + description=description, + ) + except Exception as e: + ctx.output.print_mutation_error( + e, + item_name='user', + action_name='update', + ) + sys.exit(1) + if not data['ok']: + ctx.output.print_mutation_error( + msg=data['msg'], + item_name='user', + action_name='update', + ) + sys.exit(1) + ctx.output.print_mutation_result( + data, + extra_info={ + 'email': email, + }, + ) + + +@user.command() +@click.pass_obj +@click.argument('email', type=str, metavar='EMAIL') +def delete(ctx: CLIContext, email): + """ + Inactivate an existing user. + + EMAIL: Email of user to inactivate. + """ + with Session() as session: + try: + data = session.User.delete(email) + except Exception as e: + ctx.output.print_mutation_error( + e, + item_name='user', + action_name='deletion', + ) + sys.exit(1) + if not data['ok']: + ctx.output.print_mutation_error( + msg=data['msg'], + item_name='user', + action_name='deletion', + ) + sys.exit(1) + ctx.output.print_mutation_result( + data, + extra_info={ + 'email': email, + }, + ) + + +@user.command() +@click.pass_obj +@click.argument('email', type=str, metavar='EMAIL') +@click.option('--purge-shared-vfolders', is_flag=True, default=False, + help='Delete user\'s all virtual folders. ' + 'If False, shared folders will not be deleted ' + 'and migrated the ownership to the requested admin.') +def purge(ctx: CLIContext, email, purge_shared_vfolders): + """ + Delete an existing user. This action cannot be undone. + + NAME: Name of a domain to delete. 
+ """ + with Session() as session: + try: + if not ask_yn(): + print_info('Cancelled') + sys.exit(1) + data = session.User.purge(email, purge_shared_vfolders) + except Exception as e: + ctx.output.print_mutation_error( + e, + item_name='user', + action_name='purge', + ) + sys.exit(1) + if not data['ok']: + ctx.output.print_mutation_error( + msg=data['msg'], + item_name='user', + action_name='purge', + ) + sys.exit(1) + ctx.output.print_mutation_result( + data, + extra_info={ + 'email': email, + }, + ) diff --git a/src/ai/backend/client/cli/admin/vfolder.py b/src/ai/backend/client/cli/admin/vfolder.py new file mode 100644 index 0000000000..e8ea7c04f0 --- /dev/null +++ b/src/ai/backend/client/cli/admin/vfolder.py @@ -0,0 +1,205 @@ +from __future__ import annotations + +import sys + +import click +import humanize +from tabulate import tabulate + +from ai.backend.client.session import Session +from ai.backend.client.func.vfolder import _default_list_fields +from ..pretty import print_error +from ..types import CLIContext +from ..vfolder import vfolder as user_vfolder +from . import admin + + +@admin.group() +def vfolder() -> None: + """ + VFolder administration commands. + """ + + +def _list_cmd(docs: str = None): + + @click.pass_obj + @click.option('-g', '--group', type=str, default=None, + help='Filter by group ID.') + @click.option('--filter', 'filter_', default=None, + help='Set the query filter expression.') + @click.option('--order', default=None, + help='Set the query ordering expression.') + @click.option('--offset', default=0, + help='The index of the current page start for pagination.') + @click.option('--limit', default=None, + help='The page size for pagination.') + def list(ctx: CLIContext, group, filter_, order, offset, limit) -> None: + """ + List virtual folders. + """ + try: + with Session() as session: + fetch_func = lambda pg_offset, pg_size: session.VFolder.paginated_list( + group, + fields=_default_list_fields, + page_offset=pg_offset, + page_size=pg_size, + filter=filter_, + order=order, + ) + ctx.output.print_paginated_list( + fetch_func, + initial_page_offset=offset, + page_size=limit, + ) + except Exception as e: + ctx.output.print_error(e) + sys.exit(1) + + if docs is not None: + list.__doc__ = docs + return list + + +user_vfolder.command()(_list_cmd()) +vfolder.command()(_list_cmd()) + + +@vfolder.command() +def list_hosts(): + """ + List all mounted hosts from virtual folder root. + (superadmin privilege required) + """ + with Session() as session: + try: + resp = session.VFolder.list_all_hosts() + print("Default vfolder host: {}".format(resp['default'])) + print("Mounted hosts: {}".format(', '.join(resp['allowed']))) + except Exception as e: + print_error(e) + sys.exit(1) + + +@vfolder.command() +@click.argument('vfolder_host') +def perf_metric(vfolder_host): + """ + Show the performance statistics of a vfolder host. + (superadmin privilege required) + + A vfolder host consists of a string of the storage proxy name and the volume name + separated by a colon. 
(e.g., "local:volume1") + """ + with Session() as session: + try: + resp = session.VFolder.get_performance_metric(vfolder_host) + print(tabulate( + [(k, humanize.naturalsize(v, binary=True) if 'bytes' in k else f"{v:.2f}") + for k, v in resp['metric'].items()], + headers=('Key', 'Value'), + )) + except Exception as e: + print_error(e) + sys.exit(1) + + +@vfolder.command() +@click.option('-a', '--agent-id', type=str, default=None, + help='Target agent to fetch fstab contents.') +def get_fstab_contents(agent_id): + """ + Get contents of fstab file from a node. + (superadmin privilege required) + + If agent-id is not specified, manager's fstab contents will be returned. + """ + with Session() as session: + try: + resp = session.VFolder.get_fstab_contents(agent_id) + except Exception as e: + print_error(e) + sys.exit(1) + print(resp) + + +@vfolder.command() +def list_mounts(): + """ + List all mounted hosts in virtual folder root. + (superadmin privilege required) + """ + with Session() as session: + try: + resp = session.VFolder.list_mounts() + except Exception as e: + print_error(e) + sys.exit(1) + print('manager') + for k, v in resp['manager'].items(): + print(' ', k, ':', v) + print('\nagents') + for aid, data in resp['agents'].items(): + print(' ', aid) + for k, v in data.items(): + print(' ', k, ':', v) + + +@vfolder.command() +@click.argument('fs-location', type=str) +@click.argument('name', type=str) +@click.option('-o', '--options', type=str, default=None, help='Mount options.') +@click.option('--edit-fstab', is_flag=True, + help='Edit fstab file to mount permanently.') +def mount_host(fs_location, name, options, edit_fstab): + """ + Mount a host in virtual folder root. + (superadmin privilege required) + + \b + FS-LOCATION: Location of file system to be mounted. + NAME: Name of mounted host. + """ + with Session() as session: + try: + resp = session.VFolder.mount_host(name, fs_location, options, edit_fstab) + except Exception as e: + print_error(e) + sys.exit(1) + print('manager') + for k, v in resp['manager'].items(): + print(' ', k, ':', v) + print('agents') + for aid, data in resp['agents'].items(): + print(' ', aid) + for k, v in data.items(): + print(' ', k, ':', v) + + +@vfolder.command() +@click.argument('name', type=str) +@click.option('--edit-fstab', is_flag=True, + help='Edit fstab file to mount permanently.') +def umount_host(name, edit_fstab): + """ + Unmount a host from virtual folder root. + (superadmin privilege required) + + \b + NAME: Name of mounted host. 
+ """ + with Session() as session: + try: + resp = session.VFolder.umount_host(name, edit_fstab) + except Exception as e: + print_error(e) + sys.exit(1) + print('manager') + for k, v in resp['manager'].items(): + print(' ', k, ':', v) + print('agents') + for aid, data in resp['agents'].items(): + print(' ', aid) + for k, v in data.items(): + print(' ', k, ':', v) diff --git a/src/ai/backend/client/cli/announcement.py b/src/ai/backend/client/cli/announcement.py new file mode 100644 index 0000000000..40192f1d2a --- /dev/null +++ b/src/ai/backend/client/cli/announcement.py @@ -0,0 +1,45 @@ +import hashlib +import json +from pathlib import Path + +import appdirs +from rich.console import Console +from rich.markdown import Markdown +from rich.panel import Panel +from rich.style import Style + +_printed_announcement = False + + +def announce(msg: str, only_once: bool = True) -> None: + global _printed_announcement + if only_once and _printed_announcement: + return + local_state_path = Path(appdirs.user_state_dir('backend.ai', 'Lablup')) + local_state_path.mkdir(parents=True, exist_ok=True) + try: + with open(local_state_path / 'announcement.json', 'rb') as f_current: + last_state = json.load(f_current) + except IOError: + last_state = {'hash': '', 'dismissed': False} + + hasher = hashlib.sha256() + hasher.update(msg.encode('utf8')) + msg_hash = hasher.hexdigest() + + if not (last_state['hash'] == msg_hash and last_state['dismissed']): + console = Console(stderr=True) + doc = Markdown(msg) + console.print( + Panel( + doc, + title="Server Announcement", + border_style=Style(color='cyan', bold=True), + width=min(console.size.width, 82), + ), + ) + _printed_announcement = True + + last_state['hash'] = msg_hash + with open(local_state_path / 'announcement.json', 'w') as f_new: + json.dump(last_state, f_new) diff --git a/src/ai/backend/client/cli/app.py b/src/ai/backend/client/cli/app.py new file mode 100644 index 0000000000..cb427bbba8 --- /dev/null +++ b/src/ai/backend/client/cli/app.py @@ -0,0 +1,344 @@ +import asyncio +import json +import shlex +import sys +from typing import ( + Union, Optional, + MutableMapping, Dict, + Sequence, List, +) + +import aiohttp +import click + +from .main import main +from .pretty import print_info, print_warn, print_fail, print_error +from ..config import DEFAULT_CHUNK_SIZE +from ..request import Request +from ..session import AsyncSession +from ..compat import asyncio_run, asyncio_run_forever +from ..versioning import get_naming + + +class WSProxy: + __slots__ = ( + 'api_session', 'session_name', + 'app_name', + 'args', 'envs', + 'reader', 'writer', + ) + + def __init__( + self, + api_session: AsyncSession, + session_name: str, + app_name: str, + args: MutableMapping[str, Union[None, str, List[str]]], + envs: MutableMapping[str, str], + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + ) -> None: + self.api_session = api_session + self.session_name = session_name + self.app_name = app_name + self.args = args + self.envs = envs + self.reader = reader + self.writer = writer + + async def run(self) -> None: + prefix = get_naming(self.api_session.api_version, 'path') + path = f"/stream/{prefix}/{self.session_name}/tcpproxy" + params = {'app': self.app_name} + + if len(self.args.keys()) > 0: + params['arguments'] = json.dumps(self.args) + if len(self.envs.keys()) > 0: + params['envs'] = json.dumps(self.envs) + + api_rqst = Request( + "GET", path, b'', + params=params, + content_type="application/json") + async with api_rqst.connect_websocket() as ws: + + 
async def downstream() -> None: + try: + async for msg in ws: + if msg.type == aiohttp.WSMsgType.ERROR: + await self.write_error(msg) + break + elif msg.type == aiohttp.WSMsgType.CLOSE: + if msg.data != aiohttp.WSCloseCode.OK: + await self.write_error(msg) + break + elif msg.type == aiohttp.WSMsgType.BINARY: + self.writer.write(msg.data) + await self.writer.drain() + except ConnectionResetError: + pass # shutting down + except asyncio.CancelledError: + pass + finally: + self.writer.close() + try: + await self.writer.wait_closed() + except (BrokenPipeError, IOError): + # closed + pass + + down_task = asyncio.create_task(downstream()) + try: + while True: + chunk = await self.reader.read(DEFAULT_CHUNK_SIZE) + if not chunk: + break + await ws.send_bytes(chunk) + except ConnectionResetError: + pass # shutting down + except asyncio.CancelledError: + raise + finally: + if not down_task.done(): + down_task.cancel() + await down_task + + async def write_error(self, msg: aiohttp.WSMessage) -> None: + if isinstance(msg.data, bytes): + error_msg = msg.data.decode('utf8') + else: + error_msg = str(msg.data) + rsp = 'HTTP/1.1 503 Service Unavailable\r\n' \ + 'Connection: Closed\r\n\r\n' \ + 'WebSocket reply: {}'.format(error_msg) + self.writer.write(rsp.encode()) + await self.writer.drain() + + +class ProxyRunnerContext: + + __slots__ = ( + 'session_name', 'app_name', + 'protocol', 'host', 'port', + 'args', 'envs', + 'api_session', 'local_server', + 'exit_code', + ) + + session_name: str + app_name: str + protocol: str + host: str + port: int + args: Dict[str, Union[None, str, List[str]]] + envs: Dict[str, str] + api_session: Optional[AsyncSession] + local_server: Optional[asyncio.AbstractServer] + exit_code: int + + def __init__( + self, + host: str, + port: int, + session_name: str, + app_name: str, + *, + protocol: str = 'tcp', + args: Sequence[str] = None, + envs: Sequence[str] = None, + ) -> None: + self.host = host + self.port = port + self.session_name = session_name + self.app_name = app_name + self.protocol = protocol + + self.api_session = None + self.local_server = None + self.exit_code = 0 + + self.args, self.envs = {}, {} + if args is not None and len(args) > 0: + for argline in args: + tokens = [] + for token in shlex.shlex(argline, + punctuation_chars=True): + kv = token.split('=', maxsplit=1) + if len(kv) == 1: + tokens.append(shlex.split(token)[0]) + else: + tokens.append(kv[0]) + tokens.append(shlex.split(kv[1])[0]) + + if len(tokens) == 1: + self.args[tokens[0]] = None + elif len(tokens) == 2: + self.args[tokens[0]] = tokens[1] + else: + self.args[tokens[0]] = tokens[1:] + if envs is not None and len(envs) > 0: + for envline in envs: + split = envline.strip().split('=', maxsplit=2) + if len(split) == 2: + self.envs[split[0]] = split[1] + else: + self.envs[split[0]] = '' + + async def handle_connection( + self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + ) -> None: + assert self.api_session is not None + p = WSProxy( + self.api_session, + self.session_name, + self.app_name, + self.args, + self.envs, + reader, + writer, + ) + try: + await p.run() + except asyncio.CancelledError: + pass + except Exception as e: + print_error(e) + + async def __aenter__(self) -> None: + self.exit_code = 0 + self.api_session = AsyncSession() + await self.api_session.__aenter__() + + user_url_template = "{protocol}://{host}:{port}" + try: + compute_session = self.api_session.ComputeSession(self.session_name) + all_apps = await compute_session.stream_app_info() + for app_info in 
all_apps: + if app_info['name'] == self.app_name: + if 'url_template' in app_info.keys(): + user_url_template = app_info['url_template'] + break + else: + print_fail(f'The app "{self.app_name}" is not supported by the session.') + self.exit_code = 1 + return + + self.local_server = await asyncio.start_server( + self.handle_connection, self.host, self.port) + user_url = user_url_template.format( + protocol=self.protocol, + host=self.host, + port=self.port, + ) + print_info( + "A local proxy to the application \"{0}\" ".format(self.app_name) + + "provided by the session \"{0}\" ".format(self.session_name) + + "is available at:\n{0}".format(user_url), + ) + if self.host == '0.0.0.0': + print_warn('NOTE: Replace "0.0.0.0" with the actual hostname you use ' + 'to connect with the CLI app proxy.') + except Exception: + await self.api_session.__aexit__(*sys.exc_info()) + raise + + async def __aexit__(self, *exc_info) -> None: + if self.local_server is not None: + print_info("Shutting down....") + self.local_server.close() + await self.local_server.wait_closed() + assert self.api_session is not None + await self.api_session.__aexit__(*exc_info) + assert self.api_session.closed + if self.local_server is not None: + print_info("The local proxy to \"{}\" has terminated." + .format(self.app_name)) + self.local_server = None + + +@main.command() +@click.argument('session_name', type=str, metavar='NAME') +@click.argument('app', type=str) +@click.option('-b', '--bind', type=str, default='127.0.0.1:8080', metavar='[HOST:]PORT', + help='The IP/host address and the port number to bind this proxy.') +@click.option('--arg', type=str, multiple=True, metavar='"--option "', + help='Add additional argument when starting service.') +@click.option('-e', '--env', type=str, multiple=True, metavar='"ENVNAME=envvalue"', + help='Add additional environment variable when starting service.') +def app(session_name, app, bind, arg, env): + """ + Run a local proxy to a service provided by Backend.AI compute sessions. + + The type of proxy depends on the app definition: plain TCP or HTTP. + + \b + SESSID: The compute session ID. + APP: The name of service provided by the given session. + """ + bind_parts = bind.rsplit(':', maxsplit=1) + if len(bind_parts) == 1: + host = '127.0.0.1' + port = int(bind_parts[0]) + elif len(bind_parts) == 2: + host = bind_parts[0] + port = int(bind_parts[1]) + try: + proxy_ctx = ProxyRunnerContext( + host, port, + session_name, app, + protocol='tcp', + args=arg, + envs=env, + ) + asyncio_run_forever(proxy_ctx) + sys.exit(proxy_ctx.exit_code) + except Exception as e: + print_error(e) + sys.exit(1) + + +@main.command() +@click.argument('session_name', type=str, metavar='NAME', nargs=1) +@click.argument('app_name', type=str, metavar='APP', nargs=-1) +@click.option('-l', '--list-names', is_flag=True, + help='Just print all available services.') +def apps(session_name, app_name, list_names): + ''' + List available additional arguments and environment variables when starting service. + + \b + SESSID: The compute session ID. + APP: The name of service provided by the given session. Repeatable. + If none provided, this will print all available services. 
+    '''
+
+    async def print_arguments():
+        apps = []
+        async with AsyncSession() as api_session:
+            compute_session = api_session.ComputeSession(session_name)
+            apps = await compute_session.stream_app_info()
+            if len(app_name) > 0:
+                apps = list(filter(lambda x: x['name'] in app_name, apps))
+        if list_names:
+            print_info('This session provides the following app services: {0}'
+                       .format(', '.join(list(map(lambda x: x['name'], apps)))))
+            return
+        for service in apps:
+            has_arguments = 'allowed_arguments' in service.keys()
+            has_envs = 'allowed_envs' in service.keys()
+
+            if has_arguments or has_envs:
+                print_info('Information for service {0}:'.format(service['name']))
+                if has_arguments:
+                    print('\tAvailable arguments: {0}'.format(service['allowed_arguments']))
+                if has_envs:
+                    print('\tAvailable environment variables: {0}'.format(service['allowed_envs']))
+            else:
+                print_info('Service {0} does not have customizable arguments or '
+                           'environment variables.'.format(service['name']))
+
+    try:
+        asyncio_run(print_arguments())
+    except Exception as e:
+        print_error(e)
diff --git a/src/ai/backend/client/cli/config.py b/src/ai/backend/client/cli/config.py
new file mode 100644
index 0000000000..69929154a1
--- /dev/null
+++ b/src/ai/backend/client/cli/config.py
@@ -0,0 +1,171 @@
+import getpass
+import json
+import sys
+import warnings
+
+import click
+
+from .main import main
+from .pretty import print_done, print_error, print_fail, print_warn
+from .. import __version__
+from ..config import get_config, local_state_path
+from ..exceptions import BackendClientError
+from ..session import Session
+
+
+@main.command()
+def config():
+    '''
+    Shows the current configuration.
+    '''
+    config = get_config()
+    click.echo('API endpoint: {0} (mode: {1})'.format(
+        click.style(str(config.endpoint), bold=True),
+        click.style(str(config.endpoint_type), fg='cyan', bold=True)))
+    click.echo('Client version: {0} (API: {1})'.format(
+        click.style(__version__, bold=True),
+        click.style(config.version, bold=True),
+    ))
+    if sys.stdout.isatty():
+        click.echo('Server version: ...')
+        click.echo('Negotiated API version: ...')
+    else:
+        with Session() as sess:
+            try:
+                versions = sess.System.get_versions()
+            except BackendClientError:
+                click.echo('Server version: (failed to fetch)')
+            else:
+                click.echo('Server version: {0} (API: {1})'.format(
+                    versions.get('manager', 'pre-19.03'),
+                    versions['version'],
+                ))
+            click.echo('Negotiated API version: {0}'.format(sess.api_version))
+    nrows = 1
+    if config.domain:
+        click.echo('Domain name: "{0}"'.format(click.style(config.domain, bold=True)))
+        nrows += 1
+    if config.group:
+        click.echo('Group name: "{0}"'.format(click.style(config.group, bold=True)))
+        nrows += 1
+    if config.is_anonymous:
+        click.echo('Access key: (this is an anonymous session)')
+        nrows += 1
+    elif config.endpoint_type == 'docker':
+        pass
+    elif config.endpoint_type == 'session':
+        if (local_state_path / 'cookie.dat').exists() and \
+           (local_state_path / 'config.json').exists():
+            sess_config = json.loads((local_state_path / 'config.json').read_text())
+            click.echo('Username: "{0}"'.format(click.style(sess_config.get('username', ''), bold=True)))
+            nrows += 1
+    else:
+        click.echo('Access key: "{0}"'.format(click.style(config.access_key, bold=True)))
+        nrows += 1
+        masked_skey = config.secret_key[:6] + ('*' * 24) + config.secret_key[-10:]
+        click.echo('Secret key: "{0}"'.format(click.style(masked_skey, bold=True)))
+        nrows += 1
+    click.echo('Signature hash type: {0}'.format(
+        click.style(config.hash_type, bold=True)))
+    nrows += 1
+
click.echo('Skip SSL certificate validation? {0}'.format( + click.style(str(config.skip_sslcert_validation), bold=True))) + nrows += 1 + if sys.stdout.isatty(): + sys.stdout.flush() + with warnings.catch_warnings(record=True) as captured_warnings, Session() as sess: + click.echo('\u001b[{0}A\u001b[2K'.format(nrows + 1), nl=False) + try: + versions = sess.System.get_versions() + except BackendClientError: + click.echo('Server version: {0}'.format( + click.style('(failed to fetch)', fg='red', bold=True), + )) + else: + click.echo('Server version: {0} (API: {1})'.format( + click.style(versions.get('manager', 'pre-19.03'), bold=True), + click.style(versions['version'], bold=True), + )) + click.echo('\u001b[2K', nl=False) + click.echo('Negotiated API version: {0}'.format( + click.style('v{0[0]}.{0[1]}'.format(sess.api_version), bold=True), + )) + click.echo('\u001b[{0}B'.format(nrows), nl=False) + sys.stdout.flush() + for w in captured_warnings: + warnings.showwarning(w.message, w.category, w.filename, w.lineno, w.line) + + +@main.command() +def login(): + ''' + Log-in to the console API proxy. + It stores the current session cookie in the OS-default + local application data location. + ''' + user_id = input('User ID: ') + password = getpass.getpass() + + config = get_config() + if config.endpoint_type != 'session': + print_warn('To use login, your endpoint type must be "session".') + raise click.Abort() + + with Session() as session: + try: + result = session.Auth.login(user_id, password) + if not result['authenticated']: + print_fail('Login failed.') + sys.exit(1) + print_done('Login succeeded.') + + local_state_path.mkdir(parents=True, exist_ok=True) + session.aiohttp_session.cookie_jar.update_cookies(result['cookies']) + session.aiohttp_session.cookie_jar.save(local_state_path / 'cookie.dat') + (local_state_path / 'config.json').write_text(json.dumps(result.get('config', {}))) + except Exception as e: + print_error(e) + + +@main.command() +def logout(): + ''' + Log-out from the console API proxy and clears the local cookie data. + ''' + config = get_config() + if config.endpoint_type != 'session': + print_warn('To use logout, your endpoint type must be "session".') + raise click.Abort() + + with Session() as session: + try: + session.Auth.logout() + print_done('Logout done.') + try: + (local_state_path / 'cookie.dat').unlink() + (local_state_path / 'config.json').unlink() + except (IOError, PermissionError): + pass + except Exception as e: + print_error(e) + + +@main.command() +@click.argument('old_password', metavar='OLD_PASSWORD') +@click.argument('new_password', metavar='NEW_PASSWORD') +@click.argument('new_password2', metavar='NEW_PASSWORD2') +def update_password(old_password, new_password, new_password2): + ''' + Update user's password. 
+ ''' + config = get_config() + if config.endpoint_type != 'session': + print_warn('To update password, your endpoint type must be "session".') + raise click.Abort() + + with Session() as session: + try: + session.Auth.update_password(old_password, new_password, new_password2) + print_done('Password updated.') + except Exception as e: + print_error(e) diff --git a/src/ai/backend/client/cli/dotfile.py b/src/ai/backend/client/cli/dotfile.py new file mode 100644 index 0000000000..cc9add7dba --- /dev/null +++ b/src/ai/backend/client/cli/dotfile.py @@ -0,0 +1,190 @@ +import sys + +import click +from tabulate import tabulate + +from .main import main +from .pretty import print_info, print_warn, print_error +from ..session import Session + + +@main.group() +def dotfile(): + '''Provides dotfile operations.''' + + +@dotfile.command() +@click.argument('path', metavar='PATH') +@click.option('--perm', 'permission', + help='Linux permission represented in octal number (e.g. 755) ' + 'Defaults to 755 if not specified') +@click.option('-f', '--file', 'dotfile_path', + help='Path to dotfile to upload. ' + 'If not specified, client will try to read file from STDIN. ') +@click.option('-o', '--owner', '--owner-access-key', 'owner_access_key', metavar='ACCESS_KEY', + help='Specify the owner of the target session of user dotfiles.') +@click.option('-d', '--domain', 'domain', metavar='DOMAIN', + help='Specify the domain name of domain dotfiles.') +@click.option('-g', '--group', metavar='GROUP', + help='Sepcify the group name or id of group dotfiles. ' + '(If group name is provided, domain name must be specified with option -d)') +def create(path, permission, dotfile_path, owner_access_key, domain, group): + ''' + Store dotfile to Backend.AI Manager. + Dotfiles will be automatically loaded when creating kernels. + + PATH: Where dotfiles will be created when starting kernel + ''' + + if dotfile_path: + with open(dotfile_path, 'r') as fr: + body = fr.read() + else: + body = '' + for line in sys.stdin: + body += (line + '\n') + with Session() as session: + try: + if not permission: + permission = '755' + dotfile_ = session.Dotfile.create(body, path, permission, + owner_access_key=owner_access_key, + domain=domain, group=group) + print_info(f'Dotfile {dotfile_.path} created and ready') + except Exception as e: + print_error(e) + sys.exit(1) + + +@dotfile.command() +@click.argument('path', metavar='PATH') +@click.option('-o', '--owner', '--owner-access-key', 'owner_access_key', metavar='ACCESS_KEY', + help='Specify the owner of the target session of user dotfiles.') +@click.option('-d', '--domain', 'domain', metavar='DOMAIN', + help='Specify the domain name of domain dotfiles.') +@click.option('-g', '--group', metavar='GROUP', + help='Sepcify the group name or id of group dotfiles. ' + '(If group name is provided, domain name must be specified with option -d)') +def get(path, owner_access_key, domain, group): + ''' + Print dotfile content. 
+    '''
+    with Session() as session:
+        try:
+            dotfile_ = session.Dotfile(path, owner_access_key=owner_access_key,
+                                       domain=domain, group=group)
+            body = dotfile_.get()
+            print(body['data'])
+        except Exception as e:
+            print_error(e)
+            sys.exit(1)
+
+
+@dotfile.command()
+@click.option('-o', '--owner', '--owner-access-key', 'owner_access_key', metavar='ACCESS_KEY',
+              help='Specify the owner of the target session of user dotfiles.')
+@click.option('-d', '--domain', 'domain', metavar='DOMAIN',
+              help='Specify the domain name of domain dotfiles.')
+@click.option('-g', '--group', metavar='GROUP',
+              help='Specify the group name or id of group dotfiles. '
+                   '(If group name is provided, domain name must be specified with option -d)')
+def list(owner_access_key, domain, group):
+    '''
+    List available user/domain/group dotfiles.
+    '''
+    fields = [
+        ('Path', 'path', None),
+        ('Data', 'data', lambda v: v[:30].splitlines()[0]),
+        ('Permission', 'permission', None),
+    ]
+    with Session() as session:
+        try:
+            resp = session.Dotfile.list_dotfiles(owner_access_key=owner_access_key,
+                                                 domain=domain, group=group)
+            if not resp:
+                print('There are no dotfiles created yet.')
+                return
+            rows = (
+                tuple(
+                    item[key] if transform is None else transform(item[key])
+                    for _, key, transform in fields
+                )
+                for item in resp
+            )
+            hdrs = (display_name for display_name, _, _ in fields)
+            print(tabulate(rows, hdrs))
+        except Exception as e:
+            print_error(e)
+            sys.exit(1)
+
+
+@dotfile.command()
+@click.argument('path', metavar='PATH')
+@click.option('--perm', 'permission',
+              help='Linux permission represented in octal number (e.g. 755) '
+                   'Defaults to 755 if not specified')
+@click.option('-f', '--file', 'dotfile_path',
+              help='Path to dotfile to upload. '
+                   'If not specified, client will try to read file from STDIN. ')
+@click.option('-o', '--owner', '--owner-access-key', 'owner_access_key', metavar='ACCESS_KEY',
+              help='Specify the owner of the target session of user dotfiles.')
+@click.option('-d', '--domain', 'domain', metavar='DOMAIN',
+              help='Specify the domain name of domain dotfiles.')
+@click.option('-g', '--group', metavar='GROUP',
+              help='Specify the group name or id of group dotfiles. '
+                   '(If group name is provided, domain name must be specified with option -d)')
+def update(path, permission, dotfile_path, owner_access_key, domain, group):
+    '''
+    Update dotfile stored in Backend.AI Manager.
+    '''
+
+    if dotfile_path:
+        with open(dotfile_path, 'r') as fr:
+            body = fr.read()
+    else:
+        body = ''
+        for line in sys.stdin:
+            body += (line + '\n')
+    with Session() as session:
+        try:
+            if not permission:
+                permission = '755'
+            dotfile_ = session.Dotfile(path, owner_access_key=owner_access_key,
+                                       domain=domain, group=group)
+            dotfile_.update(body, permission)
+            print_info(f'Dotfile {dotfile_.path} updated')
+        except Exception as e:
+            print_error(e)
+            sys.exit(1)
+
+
+@dotfile.command()
+@click.argument('path', metavar='PATH')
+@click.option('-f', '--force', type=bool, is_flag=True,
+              help='Delete dotfile without confirmation.')
+@click.option('-o', '--owner', '--owner-access-key', 'owner_access_key', metavar='ACCESS_KEY',
+              help='Specify the owner of the target session of user dotfiles.')
+@click.option('-d', '--domain', 'domain', metavar='DOMAIN',
+              help='Specify the domain name of domain dotfiles.')
+@click.option('-g', '--group', metavar='GROUP',
+              help='Specify the group name or id of group dotfiles. 
' + '(If group name is provided, domain name must be specified with option -d)') +def delete(path, force, owner_access_key, domain, group): + ''' + Delete dotfile from Backend.AI Manager. + ''' + with Session() as session: + dotfile_ = session.Dotfile(path, owner_access_key=owner_access_key, + domain=domain, group=group) + if not force: + print_warn('Are you sure? (y/[n])') + result = input() + if result.strip() != 'y': + print_info('Aborting.') + exit() + try: + dotfile_.delete() + print_info(f'Dotfile {dotfile_.path} deleted') + except Exception as e: + print_error(e) + sys.exit(1) diff --git a/src/ai/backend/client/cli/logs.py b/src/ai/backend/client/cli/logs.py new file mode 100644 index 0000000000..f7e61fc45a --- /dev/null +++ b/src/ai/backend/client/cli/logs.py @@ -0,0 +1,29 @@ +import sys + +import click + +from .main import main +from .pretty import print_error +from ..compat import asyncio_run +from ..session import AsyncSession + + +@main.command() +@click.argument('task_id', metavar='TASKID') +def task_logs(task_id): + ''' + Shows the output logs of a batch task. + + \b + TASKID: An UUID of a task (or kernel). + ''' + async def _task_logs(): + async with AsyncSession() as session: + async for chunk in session.ComputeSession.get_task_logs(task_id): + print(chunk.decode('utf8', errors='replace'), end='') + + try: + asyncio_run(_task_logs()) + except Exception as e: + print_error(e) + sys.exit(1) diff --git a/src/ai/backend/client/cli/main.py b/src/ai/backend/client/cli/main.py new file mode 100644 index 0000000000..8ed740f84e --- /dev/null +++ b/src/ai/backend/client/cli/main.py @@ -0,0 +1,45 @@ +import warnings + +import click + +from ai.backend.cli.extensions import ExtendedCommandGroup +from ai.backend.client import __version__ +from ai.backend.client.output import get_output_handler +from ai.backend.client.config import APIConfig, set_config +from ai.backend.client.cli.types import CLIContext, OutputMode + + +@click.group( + cls=ExtendedCommandGroup, + context_settings={ + 'help_option_names': ['-h', '--help'], + }, +) +@click.option('--skip-sslcert-validation', + help='Skip SSL certificate validation for all API requests.', + is_flag=True) +@click.option('--output', type=click.Choice(['json', 'console']), default='console', + help='Set the output style of the command results.') +@click.version_option(version=__version__) +@click.pass_context +def main(ctx: click.Context, skip_sslcert_validation: bool, output: str) -> None: + """ + Backend.AI command line interface. 
+ """ + from .announcement import announce + config = APIConfig( + skip_sslcert_validation=skip_sslcert_validation, + announcement_handler=announce, + ) + set_config(config) + + output_mode = OutputMode(output) + cli_ctx = CLIContext( + api_config=config, + output_mode=output_mode, + ) + cli_ctx.output = get_output_handler(cli_ctx, output_mode) + ctx.obj = cli_ctx + + from .pretty import show_warning + warnings.showwarning = show_warning diff --git a/src/ai/backend/client/cli/pagination.py b/src/ai/backend/client/cli/pagination.py new file mode 100644 index 0000000000..577d9d4536 --- /dev/null +++ b/src/ai/backend/client/cli/pagination.py @@ -0,0 +1,115 @@ +import shutil +import sys +from typing import ( + Any, + Callable, + Iterator, + List, + Literal, + MutableMapping, + Sequence, +) + +import click +from tabulate import tabulate + +from ai.backend.client.output.types import FieldSpec +from ..pagination import MAX_PAGE_SIZE + + +def get_preferred_page_size() -> int: + return min(MAX_PAGE_SIZE, shutil.get_terminal_size((80, 20)).lines) + + +_Item = MutableMapping[str, Any] + + +def tabulate_items( + items: Iterator[_Item], + fields: Sequence[FieldSpec], + *, + page_size: int = None, + item_formatter: Callable[[_Item], None] = None, + tablefmt: Literal['simple', 'plain', 'github'] = 'simple', +) -> Iterator[str]: + is_first = True + output_count = 0 + buffered_items: List[_Item] = [] + + # check table header/footer sizes + header_height = 0 + if tablefmt in ('simple', 'github'): + header_height = 2 + assert header_height >= 0 + + def _tabulate_buffer() -> Iterator[str]: + table = tabulate( + [ + [ + f.formatter.format_console(v, f) for f, v in zip(fields, item.values()) + ] for item in buffered_items + ], + headers=( + [] if tablefmt == 'plain' + else [field.humanized_name for field in fields] + ), + tablefmt=tablefmt, + ) + table_rows = table.splitlines() + if is_first: + yield from (row + '\n' for row in table_rows) + else: + # strip the header for continued page outputs + yield from (row + '\n' for row in table_rows[header_height:]) + + # If we iterate until the end of items, pausing the terminal output + # would not have effects for avoiding unnecessary queries for subsequent pages. + # Let's buffer the items and split the formatting per page. + if page_size is None: + table_height = shutil.get_terminal_size((80, 20)).lines + else: + table_height = page_size + page_size = max(table_height - header_height - 1, 10) + for item in items: + if item_formatter is not None: + item_formatter(item) + buffered_items.append(item) + output_count += 1 + if output_count == page_size: + yield from _tabulate_buffer() + buffered_items.clear() + is_first = False + output_count = 0 + page_size = max(table_height - 1, 10) + if output_count > 0: + yield from _tabulate_buffer() + + +def echo_via_pager( + text_generator: Iterator[str], + break_callback: Callable[[], None] = None, +) -> None: + """ + A variant of ``click.echo_via_pager()`` which implements our own simplified pagination. + The key difference is that it holds the generator for each page, so that the generator + won't continue querying the next results unless continued, avoiding server overloads. 
+ """ + # TODO: support PageUp & PageDn by buffering the output + terminal_height = shutil.get_terminal_size((80, 20)).lines + line_count = 0 + for text in text_generator: + line_count += text.count('\n') + click.echo(text, nl=False) + if line_count == terminal_height - 1: + if sys.stdin.isatty() and sys.stdout.isatty(): + click.echo(':', nl=False) + # Pause the terminal so that we don't execute next-page queries indefinitely. + # Since click.pause() ignores KeyboardInterrupt, we just use click.getchar() + # to allow user interruption. + k = click.getchar(echo=False) + if k in ('q', 'Q'): + if break_callback is not None: + break_callback() + break + click.echo('\r', nl=False) + line_count = 0 diff --git a/src/ai/backend/client/cli/params.py b/src/ai/backend/client/cli/params.py new file mode 100644 index 0000000000..ca8ee3dc6e --- /dev/null +++ b/src/ai/backend/client/cli/params.py @@ -0,0 +1,168 @@ +import json +import re +from decimal import Decimal +from typing import ( + Any, + Mapping, + Union, + Optional, +) + +import click + + +class ByteSizeParamType(click.ParamType): + name = "byte" + + _rx_digits = re.compile(r'^(\d+(?:\.\d*)?)([kmgtpe]?)$', re.I) + _scales = { + 'k': 2 ** 10, + 'm': 2 ** 20, + 'g': 2 ** 30, + 't': 2 ** 40, + 'p': 2 ** 50, + 'e': 2 ** 60, + } + + def convert(self, value, param, ctx): + if isinstance(value, int): + return value + if not isinstance(value, str): + self.fail( + f"expected string, got {value!r} of type {type(value).__name__}", + param, ctx, + ) + m = self._rx_digits.search(value) + if m is None: + self.fail(f"{value!r} is not a valid byte-size expression", param, ctx) + size = float(m.group(1)) + unit = m.group(2).lower() + return int(size * self._scales.get(unit, 1)) + + +class ByteSizeParamCheckType(ByteSizeParamType): + name = "byte-check" + + def convert(self, value, param, ctx): + if isinstance(value, int): + return value + if not isinstance(value, str): + self.fail( + f"expected string, got {value!r} of type {type(value).__name__}", + param, ctx, + ) + m = self._rx_digits.search(value) + if m is None: + self.fail(f"{value!r} is not a valid byte-size expression", param, ctx) + return value + + +class CommaSeparatedKVListParamType(click.ParamType): + name = "comma-seperated-KVList-check" + + def convert(self, value: Union[str, Mapping[str, str]], param, ctx) -> Mapping[str, str]: + if isinstance(value, dict): + return value + if not isinstance(value, str): + self.fail( + f"expected string, got {value!r} of type {type(value).__name__}", + param, ctx, + ) + override_map = {} + for assignment in value.split(","): + try: + k, _, v = assignment.partition("=") + if k == '' or v == '': + raise ValueError(f"key or value is empty. key = {k}, value = {v}") + except ValueError: + self.fail( + f"{value!r} is not a valid mapping expression", param, ctx, + ) + else: + override_map[k] = v + return override_map + + +class JSONParamType(click.ParamType): + """ + A JSON string parameter type. + The default value must be given as a valid JSON-parsable string, + not the Python objects. + """ + + name = "json-string" + + def __init__(self) -> None: + super().__init__() + self._parsed = False + + def convert( + self, + value: Optional[str], + param: Optional[click.Parameter], + ctx: Optional[click.Context], + ) -> Any: + if self._parsed: + # Click invokes this method TWICE + # for a default value given as string. 
+ return value + self._parsed = True + if value is None: + return None + try: + return json.loads(value) + except json.JSONDecodeError: + self.fail(f"cannot parse {value!r} as JSON", param, ctx) + return value + + +def drange(start: Decimal, stop: Decimal, num: int): + """ + A simplified version of numpy.linspace with default options + """ + delta = stop - start + step = delta / (num - 1) + yield from (start + step * Decimal(tick) for tick in range(0, num)) + + +class RangeExprOptionType(click.ParamType): + """ + Accepts a range expression which generates a range of values for a variable. + + Linear space range: "linspace:1,2,10" (start, stop, num) as in numpy.linspace + Pythonic range: "range:1,10,2" (start, stop[, step]) as in Python's range + Case range: "case:a,b,c" (comma-separated strings) + """ + _rx_range_key = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$') + name = 'Range Expression' + + def convert(self, arg, param, ctx): + key, value = arg.split('=', maxsplit=1) + assert self._rx_range_key.match(key), 'The key must be a valid slug string.' + try: + if value.startswith('case:'): + return key, value[5:].split(',') + elif value.startswith('linspace:'): + start, stop, num = value[9:].split(',') + return key, tuple(drange(Decimal(start), Decimal(stop), int(num))) + elif value.startswith('range:'): + range_args = map(int, value[6:].split(',')) + return key, tuple(range(*range_args)) + else: + self.fail('Unrecognized range expression type', param, ctx) + except ValueError as e: + self.fail(str(e), param, ctx) + + +class CommaSeparatedListType(click.ParamType): + + name = 'List Expression' + + def convert(self, arg, param, ctx): + try: + if isinstance(arg, int): + return arg + elif isinstance(arg, str): + return arg.split(',') + except ValueError as e: + self.fail(repr(e), param, ctx) diff --git a/src/ai/backend/client/cli/pretty.py b/src/ai/backend/client/cli/pretty.py new file mode 100644 index 0000000000..6e1fb1e3d4 --- /dev/null +++ b/src/ai/backend/client/cli/pretty.py @@ -0,0 +1,181 @@ +import enum +import functools +import sys +import textwrap +import traceback + +from click import echo, style + +from ..exceptions import BackendAPIError + +__all__ = ( + 'PrintStatus', 'print_pretty', 'print_info', 'print_wait', + 'print_done', 'print_warn', 'print_fail', 'print_error', + 'show_warning', +) + + +class PrintStatus(enum.Enum): + NONE = 0 + WAITING = 1 + DONE = 2 + FAILED = 3 + WARNING = 4 + + +def bold(text: str) -> str: + ''' + Wraps the given text with bold enable/disable ANSI sequences. 
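For reference, a minimal sketch (editor's illustration, not part of the patch) of how the parameter types defined in `params.py` above plug into click options; the command and option names here are made up:

    import click

    from ai.backend.client.cli.params import (
        ByteSizeParamType, CommaSeparatedKVListParamType, CommaSeparatedListType,
    )

    @click.command()
    @click.option('--shmem', type=ByteSizeParamType(), default='64m',
                  help='Accepts k/m/g/t/p/e suffixes; "64m" converts to 67108864 bytes.')
    @click.option('--labels', type=CommaSeparatedKVListParamType(), default={},
                  help='Comma-separated key=value pairs, parsed into a dict.')
    @click.option('--agents', type=CommaSeparatedListType(), default=None,
                  help='Comma-separated names, parsed into a list of strings.')
    def demo(shmem, labels, agents):
        # e.g. --shmem 1g --labels env=dev,team=ml --agents agent1,agent2
        click.echo(f'shmem={shmem} labels={labels} agents={agents}')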
+ ''' + return (style(text, bold=True, reset=False) + + style('', bold=False, reset=False)) + + +def underline(text: str) -> str: + return (style(text, underline=True, reset=False) + + style('', underline=False, reset=False)) + + +def inverse(text: str) -> str: + return (style(text, reverse=True, reset=False) + + style('', reverse=False, reset=False)) + + +def italic(text: str) -> str: + return '\x1b[3m' + text + '\x1b[23m' + + +def format_pretty(msg, status=PrintStatus.NONE, colored=True): + if status == PrintStatus.NONE: + indicator = style('\u2219', fg='bright_cyan', reset=False) + elif status == PrintStatus.WAITING: + indicator = style('\u22EF', fg='bright_yellow', reset=False) + elif status == PrintStatus.DONE: + indicator = style('\u2714', fg='bright_green', reset=False) + elif status == PrintStatus.FAILED: + indicator = style('\u2718', fg='bright_red', reset=False) + elif status == PrintStatus.WARNING: + indicator = style('\u2219', fg='yellow', reset=False) + else: + raise ValueError + return style(indicator + textwrap.indent(msg, ' ')[1:], reset=True) + + +format_info = functools.partial(format_pretty, status=PrintStatus.NONE) +format_wait = functools.partial(format_pretty, status=PrintStatus.WAITING) +format_done = functools.partial(format_pretty, status=PrintStatus.DONE) +format_fail = functools.partial(format_pretty, status=PrintStatus.FAILED) +format_warn = functools.partial(format_pretty, status=PrintStatus.WARNING) + + +def print_pretty(msg, *, status=PrintStatus.NONE, file=None): + if file is None: + file = sys.stderr + if status == PrintStatus.NONE: + indicator = style('\u2219', fg='bright_cyan', reset=False) + elif status == PrintStatus.WAITING: + assert '\n' not in msg, 'Waiting message must be a single line.' + indicator = style('\u22EF', fg='bright_yellow', reset=False) + elif status == PrintStatus.DONE: + indicator = style('\u2714', fg='bright_green', reset=False) + elif status == PrintStatus.FAILED: + indicator = style('\u2718', fg='bright_red', reset=False) + elif status == PrintStatus.WARNING: + indicator = style('\u2219', fg='yellow', reset=False) + else: + raise ValueError + echo('\x1b[2K', nl=False, file=file) + text = textwrap.indent(msg, ' ') + text = style(indicator + text[1:], reset=True) + echo('{0}\r'.format(text), nl=False, file=file) + file.flush() + if status != PrintStatus.WAITING: + echo('', file=file) + + +print_info = functools.partial(print_pretty, status=PrintStatus.NONE) +print_wait = functools.partial(print_pretty, status=PrintStatus.WAITING) +print_done = functools.partial(print_pretty, status=PrintStatus.DONE) +print_fail = functools.partial(print_pretty, status=PrintStatus.FAILED) +print_warn = functools.partial(print_pretty, status=PrintStatus.WARNING) + + +def format_error(exc: Exception): + if isinstance(exc, BackendAPIError): + yield '{0}: {1} {2}\n'.format(exc.__class__.__name__, + exc.status, exc.reason) + yield '{0[title]}'.format(exc.data) + if exc.data['type'].endswith('/too-many-sessions-matched'): + matches = exc.data['data'].get('matches', []) + if matches: + yield "\nCandidates (up to 10 recent entries):\n" + for item in matches: + yield f"- {item['id']} ({item['name']}, {item['status']})\n" + elif exc.data['type'].endswith('/session-already-exists'): + existing_session_id = exc.data['data'].get('existingSessionId', None) + if existing_session_id is not None: + yield f"\n- Existing session ID: {existing_session_id}" + elif exc.data['type'].endswith('/invalid-api-params'): + general_error_msg = exc.data.get('msg', None) + if 
general_error_msg is not None: + yield f"\n- {general_error_msg}" + per_field_errors = exc.data.get('data', {}) + if isinstance(per_field_errors, dict): + for k, v in per_field_errors.items(): + yield f"\n- \"{k}\": {v}" + else: + yield f"\n- {per_field_errors}" + else: + if exc.data['type'].endswith('/graphql-error'): + yield "\n\u279c Message:\n" + yield from (f"{err_item['message']}\n" + for err_item in exc.data.get('data', [])) + else: + other_details = exc.data.get('msg', None) + if other_details: + yield '\n\u279c Message: ' + yield str(other_details) + other_data = exc.data.get('data', None) + if other_data: + yield '\n\u279c Data: ' + yield repr(other_data) + agent_details = exc.data.get('agent-details', None) + if agent_details is not None: + yield "\n\u279c This is an agent-side error. " + yield "Check the agent status or ask the administrator for help." + agent_exc = agent_details.get('exception', None) + if agent_exc is not None: + yield '\n\u279c ' + str(agent_exc) + desc = agent_details.get('title', None) + if desc is not None: + yield '\n\u279c ' + str(desc) + content = exc.data.get('content', None) + if content: + yield "\n" + content + else: + args = exc.args if exc.args else [''] + yield f"{exc.__class__.__name__}: {args[0]}\n" + yield "\n".join(map(str, args[1:])) + yield ("*** Traceback ***\n" + + "".join(traceback.format_tb(exc.__traceback__)).strip()) + + +def print_error(exc: Exception, *, file=None): + if file is None: + file = sys.stderr + indicator = style('\u2718', fg='bright_red', reset=False) + if file.isatty(): + echo('\x1b[2K', nl=False, file=file) + text = ''.join(format_error(exc)) + text = textwrap.indent(text, ' ') + text = style(indicator + text[1:], reset=True) + echo('{0}\r'.format(text), nl=False, file=file) + echo('', file=file) + file.flush() + + +def show_warning(message, category, filename, lineno, file=None, line=None): + echo('{0}: {1}'.format( + style(str(category.__name__), fg='yellow', bold=True), + style(str(message), fg='yellow'), + ), file=file) diff --git a/src/ai/backend/client/cli/proxy.py b/src/ai/backend/client/cli/proxy.py new file mode 100644 index 0000000000..305dc30b95 --- /dev/null +++ b/src/ai/backend/client/cli/proxy.py @@ -0,0 +1,233 @@ +from __future__ import annotations + +import asyncio +import json +import re +from typing import ( + Union, + Tuple, + AsyncIterator, +) + +import aiohttp +from aiohttp import web +import click + +from .main import main +from .pretty import print_info, print_error, print_fail +from ..exceptions import BackendAPIError, BackendClientError +from ..request import Request +from ..session import AsyncSession + + +class WebSocketProxy: + __slots__ = ( + 'up_conn', 'down_conn', + 'upstream_buffer', 'upstream_buffer_task', + ) + + upstream_buffer: asyncio.Queue[Tuple[Union[str, bytes], aiohttp.WSMsgType]] + + def __init__(self, up_conn: aiohttp.ClientWebSocketResponse, + down_conn: web.WebSocketResponse): + self.up_conn = up_conn + self.down_conn = down_conn + self.upstream_buffer = asyncio.Queue() + self.upstream_buffer_task = None + + async def proxy(self): + asyncio.ensure_future(self.downstream()) + await self.upstream() + + async def upstream(self): + try: + async for msg in self.down_conn: + if msg.type in (aiohttp.WSMsgType.TEXT, aiohttp.WSMsgType.BINARY): + await self.send(msg.data, msg.type) + elif msg.type == aiohttp.WSMsgType.ERROR: + print_fail("ws connection closed with exception {}" + .format(self.up_conn.exception())) + break + elif msg.type == aiohttp.WSMsgType.CLOSE: + break + # here, 
client gracefully disconnected + except asyncio.CancelledError: + # here, client forcibly disconnected + pass + finally: + await self.close_downstream() + + async def downstream(self): + try: + self.upstream_buffer_task = \ + asyncio.ensure_future(self.consume_upstream_buffer()) + print_info("websocket proxy started") + async for msg in self.up_conn: + if msg.type == aiohttp.WSMsgType.TEXT: + await self.down_conn.send_str(msg.data) + elif msg.type == aiohttp.WSMsgType.BINARY: + await self.down_conn.send_bytes(msg.data) + elif msg.type == aiohttp.WSMsgType.CLOSED: + break + elif msg.type == aiohttp.WSMsgType.ERROR: + break + # here, server gracefully disconnected + except asyncio.CancelledError: + pass + except Exception as e: + print_fail('unexpected error: {}'.format(e)) + finally: + await self.close_upstream() + print_info("websocket proxy terminated") + + async def consume_upstream_buffer(self): + try: + while True: + data, tp = await self.upstream_buffer.get() + if not self.up_conn.closed: + if tp == aiohttp.WSMsgType.BINARY: + await self.up_conn.send_bytes(data) + elif tp == aiohttp.WSMsgType.TEXT: + await self.up_conn.send_str(data) + except asyncio.CancelledError: + pass + + async def send(self, msg: str, tp: aiohttp.WSMsgType): + await self.upstream_buffer.put((msg, tp)) + + async def close_downstream(self): + if not self.down_conn.closed: + await self.down_conn.close() + + async def close_upstream(self): + if not self.upstream_buffer_task.done(): + self.upstream_buffer_task.cancel() + await self.upstream_buffer_task + if not self.up_conn.closed: + await self.up_conn.close() + + +def _translate_headers(upstream_request: Request, client_request: Request) -> None: + for k, v in client_request.headers.items(): + upstream_request.headers[k] = v + api_endpoint = upstream_request.config.endpoint + assert api_endpoint.host is not None + if api_endpoint.is_default_port(): + upstream_request.headers['Host'] = api_endpoint.host + else: + upstream_request.headers['Host'] = f"{api_endpoint.host}:{api_endpoint.port}" + + +async def web_handler(request): + path = re.sub(r'^/?v(\d+)/', '/', request.path) + try: + # We treat all requests and responses as streaming universally + # to be a transparent proxy. + api_rqst = Request( + request.method, path, request.content, + params=request.query, + ) + _translate_headers(api_rqst, request) + if 'Content-Type' in request.headers: + api_rqst.content_type = request.content_type # set for signing + # Uploading request body happens at the entering of the block, + # and downloading response body happens in the read loop inside. 
+ async with api_rqst.fetch() as up_resp: + down_resp = web.StreamResponse() + down_resp.set_status(up_resp.status, up_resp.reason) + down_resp.headers.update(up_resp.headers) + down_resp.headers['Access-Control-Allow-Origin'] = '*' + await down_resp.prepare(request) + while True: + chunk = await up_resp.read(8192) + if not chunk: + break + await down_resp.write(chunk) + return down_resp + except BackendAPIError as e: + return web.Response(body=json.dumps(e.data), + status=e.status, reason=e.reason) + except BackendClientError: + return web.Response( + body="The proxy target server is inaccessible.", + status=502, + reason="Bad Gateway") + except asyncio.CancelledError: + return web.Response( + body="The proxy is being shut down.", + status=503, + reason="Service Unavailable") + except Exception as e: + print_error(e) + return web.Response( + body="Something has gone wrong.", + status=500, + reason="Internal Server Error") + + +async def websocket_handler(request): + path = re.sub(r'^/?v(\d+)/', '/', request.path) + try: + api_rqst = Request( + request.method, path, request.content, + params=request.query, + content_type=request.content_type, + ) + _translate_headers(api_rqst, request) + async with api_rqst.connect_websocket() as up_conn: + down_conn = web.WebSocketResponse() + await down_conn.prepare(request) + web_socket_proxy = WebSocketProxy(up_conn, down_conn) + await web_socket_proxy.proxy() + return down_conn + except BackendAPIError as e: + return web.Response(body=json.dumps(e.data), + status=e.status, reason=e.reason) + except BackendClientError: + return web.Response( + body="The proxy target server is inaccessible.", + status=502, + reason="Bad Gateway") + except asyncio.CancelledError: + return web.Response( + body="The proxy is being shut down.", + status=503, + reason="Service Unavailable") + except Exception as e: + print_error(e) + return web.Response( + body="Something has gone wrong.", + status=500, + reason="Internal Server Error") + + +async def proxy_context(app: web.Application) -> AsyncIterator[None]: + app['client_session'] = AsyncSession() + async with app['client_session']: + yield + + +def create_proxy_app(): + app = web.Application() + app.cleanup_ctx.append(proxy_context) + + app.router.add_route("GET", r'/stream/{path:.*$}', websocket_handler) + app.router.add_route("GET", r'/wsproxy/{path:.*$}', websocket_handler) + app.router.add_route('*', r'/{path:.*$}', web_handler) + return app + + +@main.command(context_settings=dict(allow_extra_args=True)) +@click.option('--bind', type=str, default='localhost', + help='The IP/host address to bind this proxy.') +@click.option('-p', '--port', type=int, default=8084, + help='The TCP port to accept non-encrypted non-authorized ' + 'API requests.') +@click.pass_context +def proxy(ctx, bind, port): + """ + Run a non-encrypted non-authorized API proxy server. + Use this only for development and testing! 
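A quick sketch of what this enables (editor's illustration; the request path is a placeholder): once the proxy is listening, any plain HTTP client can reach the manager API through it without per-request signing, for example:

    import asyncio

    import aiohttp

    async def check_proxy() -> None:
        # Assumes the proxy command above is already running on its default port.
        async with aiohttp.ClientSession() as http:
            # The concrete API path is a placeholder; the proxy forwards any path
            # after stripping an optional /vN/ version prefix.
            async with http.get('http://localhost:8084/...') as resp:
                print(resp.status, await resp.text())

    asyncio.run(check_proxy())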
+ """ + app = create_proxy_app() + web.run_app(app, host=bind, port=port) diff --git a/src/ai/backend/client/cli/run.py b/src/ai/backend/client/cli/run.py new file mode 100644 index 0000000000..faa3b81a6b --- /dev/null +++ b/src/ai/backend/client/cli/run.py @@ -0,0 +1,726 @@ +import asyncio +import collections +from decimal import Decimal +import getpass +import itertools +import json +import secrets +import string +import sys +import traceback +from typing import ( + Mapping, + Optional, + Sequence, + Tuple, +) + +import aiohttp +import click +from humanize import naturalsize +import tabulate as tabulate_mod +from tabulate import tabulate + +from .main import main +from ..config import local_cache_path +from ..compat import asyncio_run, current_loop +from ..exceptions import BackendError +from ..session import AsyncSession +from .pretty import ( + print_info, print_wait, print_done, print_error, print_fail, print_warn, + format_info, +) +from .params import RangeExprOptionType, CommaSeparatedListType + +tabulate_mod.PRESERVE_WHITESPACE = True +range_expr = RangeExprOptionType() +list_expr = CommaSeparatedListType() + + +async def exec_loop(stdout, stderr, compute_session, mode, code, *, opts=None, + vprint_done=print_done, is_multi=False): + """ + Fully streamed asynchronous version of the execute loop. + """ + async with compute_session.stream_execute(code, mode=mode, opts=opts) as stream: + async for result in stream: + if result.type == aiohttp.WSMsgType.TEXT: + result = json.loads(result.data) + else: + # future extension + continue + for rec in result.get('console', []): + if rec[0] == 'stdout': + print(rec[1], end='', file=stdout) + elif rec[0] == 'stderr': + print(rec[1], end='', file=stderr) + else: + print('----- output record (type: {0}) -----'.format(rec[0]), + file=stdout) + print(rec[1], file=stdout) + print('----- end of record -----', file=stdout) + stdout.flush() + files = result.get('files', []) + if files: + print('--- generated files ---', file=stdout) + for item in files: + print('{0}: {1}'.format(item['name'], item['url']), file=stdout) + print('--- end of generated files ---', file=stdout) + if result['status'] == 'clean-finished': + exitCode = result.get('exitCode') + msg = 'Clean finished. (exit code = {0})'.format(exitCode) + if is_multi: + print(msg, file=stderr) + vprint_done(msg) + elif result['status'] == 'build-finished': + exitCode = result.get('exitCode') + msg = 'Build finished. (exit code = {0})'.format(exitCode) + if is_multi: + print(msg, file=stderr) + vprint_done(msg) + elif result['status'] == 'finished': + exitCode = result.get('exitCode') + msg = 'Execution finished. (exit code = {0})'.format(exitCode) + if is_multi: + print(msg, file=stderr) + vprint_done(msg) + break + elif result['status'] == 'waiting-input': + if result['options'].get('is_password', False): + code = getpass.getpass() + else: + code = input() + await stream.send_str(code) + elif result['status'] == 'continued': + pass + + +def exec_loop_sync(stdout, stderr, compute_session, mode, code, *, opts=None, + vprint_done=print_done): + """ + Old synchronous polling version of the execute loop. 
+ """ + opts = opts if opts else {} + run_id = None # use server-assigned run ID + while True: + result = compute_session.execute(run_id, code, mode=mode, opts=opts) + run_id = result['runId'] + opts.clear() # used only once + for rec in result['console']: + if rec[0] == 'stdout': + print(rec[1], end='', file=stdout) + elif rec[0] == 'stderr': + print(rec[1], end='', file=stderr) + else: + print('----- output record (type: {0}) -----'.format(rec[0]), + file=stdout) + print(rec[1], file=stdout) + print('----- end of record -----', file=stdout) + stdout.flush() + files = result.get('files', []) + if files: + print('--- generated files ---', file=stdout) + for item in files: + print('{0}: {1}'.format(item['name'], item['url']), file=stdout) + print('--- end of generated files ---', file=stdout) + if result['status'] == 'clean-finished': + exitCode = result.get('exitCode') + vprint_done('Clean finished. (exit code = {0}'.format(exitCode), + file=stdout) + mode = 'continue' + code = '' + elif result['status'] == 'build-finished': + exitCode = result.get('exitCode') + vprint_done('Build finished. (exit code = {0})'.format(exitCode), + file=stdout) + mode = 'continue' + code = '' + elif result['status'] == 'finished': + exitCode = result.get('exitCode') + vprint_done('Execution finished. (exit code = {0})'.format(exitCode), + file=stdout) + break + elif result['status'] == 'waiting-input': + mode = 'input' + if result['options'].get('is_password', False): + code = getpass.getpass() + else: + code = input() + elif result['status'] == 'continued': + mode = 'continue' + code = '' + + +async def exec_terminal(compute_session, *, + vprint_wait=print_wait, vprint_done=print_done): + # async with compute_session.stream_pty() as stream: ... + raise NotImplementedError + + +def _noop(*args, **kwargs): + pass + + +def format_stats(stats): + formatted = [] + version = stats.pop('version', 1) + stats.pop('status') + if version == 1: + stats.pop('precpu_used', None) + stats.pop('precpu_system_used', None) + stats.pop('cpu_system_used', None) + for key, val in stats.items(): + if key.endswith('_size') or key.endswith('_bytes'): + val = naturalsize(val, binary=True) + elif key == 'cpu_used': + key += '_msec' + val = '{0:,}'.format(int(val)) + else: + val = '{0:,}'.format(int(val)) + formatted.append((key, val)) + elif version == 2: + max_integer_len = 0 + max_fraction_len = 0 + for key, metric in stats.items(): + unit = metric['unit_hint'] + if unit == 'bytes': + val = metric.get('stats.max', metric['current']) + val = naturalsize(val, binary=True) + val, unit = val.rsplit(' ', maxsplit=1) + val = '{:,}'.format(Decimal(val)) + elif unit == 'msec': + val = '{:,}'.format(Decimal(metric['current'])) + unit = 'msec' + elif unit == 'percent': + val = metric['pct'] + unit = '%' + else: + val = metric['current'] + unit = '' + if val is None: + continue + ip, _, fp = val.partition('.') + max_integer_len = max(len(ip), max_integer_len) + max_fraction_len = max(len(fp), max_fraction_len) + formatted.append([key, val, unit]) + fstr_int_only = '{0:>' + str(max_integer_len) + '}' + fstr_float = '{0:>' + str(max_integer_len) + '}.{1:<' + str(max_fraction_len) + '}' + for item in formatted: + ip, _, fp = item[1].partition('.') + if fp == '': + item[1] = fstr_int_only.format(ip) + ' ' * (max_fraction_len + 1) + else: + item[1] = fstr_float.format(ip, fp) + else: + print_warn('Unsupported statistics result version. 
Upgrade your client.') + return tabulate(formatted) + + +def prepare_resource_arg(resources): + if resources: + resources = {k: v for k, v in map(lambda s: s.split('=', 1), resources)} + else: + resources = {} # use the defaults configured in the server + return resources + + +def prepare_env_arg(env): + if env is not None: + envs = {k: v for k, v in map(lambda s: s.split('=', 1), env)} + else: + envs = {} + return envs + + +def prepare_mount_arg( + mount_args: Optional[Sequence[str]], +) -> Tuple[Sequence[str], Mapping[str, str]]: + """ + Parse the list of mount arguments into a list of + vfolder name and in-container mount path pairs. + """ + mounts = set() + mount_map = {} + if mount_args is not None: + for value in mount_args: + if '=' in value: + sp = value.split('=', maxsplit=1) + elif ':' in value: # docker-like volume mount mapping + sp = value.split(':', maxsplit=1) + else: + sp = [value] + mounts.add(sp[0]) + if len(sp) == 2: + mount_map[sp[0]] = sp[1] + return list(mounts), mount_map + + +@main.command() +@click.argument('image', type=str) +@click.argument('files', nargs=-1, type=click.Path()) +@click.option('-t', '--name', '--client-token', metavar='NAME', + help='Specify a human-readable session name. ' + 'If not set, a random hex string is used.') +# job scheduling options +@click.option('--type', metavar='SESSTYPE', + type=click.Choice(['batch', 'interactive']), + default='interactive', + help='Either batch or interactive') +@click.option('--starts-at', metavar='STARTS_AT', type=str, default=None, + help='Let session to be started at a specific or relative time.') +@click.option('--enqueue-only', is_flag=True, + help='Enqueue the session and return immediately without waiting for its startup.') +@click.option('--max-wait', metavar='SECONDS', type=int, default=0, + help='The maximum duration to wait until the session starts.') +@click.option('--no-reuse', is_flag=True, + help='Do not reuse existing sessions but return an error.') +@click.option('--callback-url', metavar='CALLBACK_URL', type=str, default=None, + help="Callback URL which will be called upon sesison lifecycle events.") +# query-mode options +@click.option('-c', '--code', metavar='CODE', + help='The code snippet as a single string') +@click.option('--terminal', is_flag=True, + help='Connect to the terminal-type compute_session.') +# batch-mode options +@click.option('--clean', metavar='CMD', + help='Custom shell command for cleaning up the base directory') +@click.option('--build', metavar='CMD', + help='Custom shell command for building the given files') +@click.option('--exec', metavar='CMD', + help='Custom shell command for executing the given files') +@click.option('--basedir', metavar='PATH', type=click.Path(), default=None, + help='Base directory path of uploaded files. 
' + 'All uploaded files must reside inside this directory.') +# execution environment +@click.option('-e', '--env', metavar='KEY=VAL', type=str, multiple=True, + help='Environment variable (may appear multiple times)') +# extra options +@click.option('--bootstrap-script', metavar='PATH', type=click.File('r'), default=None, + help='A user-defined script to execute on startup.') +@click.option('--rm', is_flag=True, + help='Terminate the session immediately after running ' + 'the given code or files') +@click.option('-s', '--stats', is_flag=True, + help='Show resource usage statistics after termination ' + '(only works if "--rm" is given)') +@click.option('--tag', type=str, default=None, + help='User-defined tag string to annotate sessions.') +@click.option('-q', '--quiet', is_flag=True, + help='Hide execution details but show only the compute_session outputs.') +# experiment support +@click.option('--env-range', metavar='RANGE_EXPR', multiple=True, + type=range_expr, help='Range expression for environment variable.') +@click.option('--build-range', metavar='RANGE_EXPR', multiple=True, + type=range_expr, help='Range expression for execution arguments.') +@click.option('--exec-range', metavar='RANGE_EXPR', multiple=True, type=range_expr, + help='Range expression for execution arguments.') +@click.option('--max-parallel', metavar='NUM', type=int, default=2, + help='The maximum number of parallel sessions.') +# resource spec +@click.option('-v', '--volume', '-m', '--mount', 'mount', + metavar='NAME[=PATH]', type=str, multiple=True, + help='User-owned virtual folder names to mount. ' + 'If path is not provided, virtual folder will be mounted under /home/work. ' + 'When the target path is relative, it is placed under /home/work ' + 'with auto-created parent directories if any. ' + 'Absolute paths are mounted as-is, but it is prohibited to ' + 'override the predefined Linux system directories.') +@click.option('--scaling-group', '--sgroup', type=str, default=None, + help='The scaling group to execute session. If not specified, ' + 'all available scaling groups are included in the scheduling.') +@click.option('-r', '--resources', '--resource', metavar='KEY=VAL', type=str, multiple=True, + help='Set computation resources ' + '(e.g: -r cpu=2 -r mem=256 -r cuda.device=1)') +@click.option('--cluster-size', metavar='NUMBER', type=int, default=1, + help='The size of cluster in number of containers.') +@click.option('--cluster-mode', metavar='MODE', + type=click.Choice(['single-node', 'multi-node']), default='single-node', + help='The mode of clustering.') +@click.option('--resource-opts', metavar='KEY=VAL', type=str, multiple=True, + help='Resource options for creating compute session. ' + '(e.g: shmem=64m)') +# resource grouping +@click.option('-d', '--domain', metavar='DOMAIN_NAME', default=None, + help='Domain name where the session will be spawned. ' + 'If not specified, config\'s domain name will be used.') +@click.option('-g', '--group', metavar='GROUP_NAME', default=None, + help='Group name where the session is spawned. ' + 'User should be a member of the group to execute the code.') +@click.option('--preopen', default=None, type=list_expr, + help='Pre-open service ports') +@click.option('--assign-agent', default=None, type=list_expr, + help='Show mapping list of tuple which mapped containers with agent. ' + 'When user role is Super Admin. 
' + '(e.g., --assign-agent agent_id_1,agent_id_2,...)') +def run(image, files, name, # base args + type, starts_at, enqueue_only, max_wait, no_reuse, # job scheduling options + callback_url, + code, terminal, # query-mode options + clean, build, exec, basedir, # batch-mode options + env, # execution environment + bootstrap_script, rm, stats, tag, quiet, # extra options + env_range, build_range, exec_range, max_parallel, # experiment support + mount, scaling_group, resources, # resource spec + cluster_size, cluster_mode, + resource_opts, + domain, group, preopen, assign_agent, # resource grouping + ): + """ + Run the given code snippet or files in a session. + Depending on the session ID you give (default is random), + it may reuse an existing session or create a new one. + + \b + IMAGE: The name (and version/platform tags appended after a colon) of session + runtime or programming language.') + FILES: The code file(s). Can be added multiple times. + """ + if quiet: + vprint_info = vprint_wait = vprint_done = _noop + else: + vprint_info = print_info + vprint_wait = print_wait + vprint_done = print_done + if files and code: + print('You can run only either source files or command-line ' + 'code snippet.', file=sys.stderr) + sys.exit(1) + if not files and not code: + print('You should provide the command-line code snippet using ' + '"-c" option if run without files.', file=sys.stderr) + sys.exit(1) + + envs = prepare_env_arg(env) + resources = prepare_resource_arg(resources) + resource_opts = prepare_resource_arg(resource_opts) + mount, mount_map = prepare_mount_arg(mount) + + if env_range is None: env_range = [] # noqa + if build_range is None: build_range = [] # noqa + if exec_range is None: exec_range = [] # noqa + + env_ranges = {v: r for v, r in env_range} + build_ranges = {v: r for v, r in build_range} + exec_ranges = {v: r for v, r in exec_range} + + env_var_maps = [dict(zip(env_ranges.keys(), values)) + for values in itertools.product(*env_ranges.values())] + build_var_maps = [dict(zip(build_ranges.keys(), values)) + for values in itertools.product(*build_ranges.values())] + exec_var_maps = [dict(zip(exec_ranges.keys(), values)) + for values in itertools.product(*exec_ranges.values())] + case_set = collections.OrderedDict() + vmaps_product = itertools.product(env_var_maps, build_var_maps, exec_var_maps) + build_template = string.Template(build) + exec_template = string.Template(exec) + env_templates = {k: string.Template(v) for k, v in envs.items()} + + if preopen is None: preopen = [] # noqa + if assign_agent is None: assign_agent = [] # noqa + + preopen_ports = preopen + assigned_agent_list = assign_agent + for env_vmap, build_vmap, exec_vmap in vmaps_product: + interpolated_envs = tuple((k, vt.substitute(env_vmap)) + for k, vt in env_templates.items()) + if build: + interpolated_build = build_template.substitute(build_vmap) + else: + interpolated_build = '*' + if exec: + interpolated_exec = exec_template.substitute(exec_vmap) + else: + interpolated_exec = '*' + case_set[(interpolated_envs, interpolated_build, interpolated_exec)] = 1 + + is_multi = (len(case_set) > 1) + if is_multi: + if max_parallel <= 0: + print('The number maximum parallel sessions must be ' + 'a positive integer.', file=sys.stderr) + sys.exit(1) + if terminal: + print('You cannot run multiple cases with terminal.', file=sys.stderr) + sys.exit(1) + if not quiet: + vprint_info('Running multiple sessions for the following combinations:') + for case in case_set.keys(): + pretty_env = ' '.join('{}={}'.format(item[0], 
item[1]) + for item in case[0]) + print('env = {!r}, build = {!r}, exec = {!r}' + .format(pretty_env, case[1], case[2])) + + def _run_legacy(session, idx, name, envs, + clean_cmd, build_cmd, exec_cmd): + try: + compute_session = session.ComputeSession.get_or_create( + image, + name=name, + type_=type, + enqueue_only=enqueue_only, + max_wait=max_wait, + no_reuse=no_reuse, + cluster_size=cluster_size, + cluster_mode=cluster_mode, + mounts=mount, + mount_map=mount_map, + envs=envs, + resources=resources, + domain_name=domain, + group_name=group, + scaling_group=scaling_group, + tag=tag, + ) + except Exception as e: + print_error(e) + sys.exit(1) + if compute_session.status == 'PENDING': + print_info('Session ID {0} is enqueued for scheduling.' + .format(name)) + return + elif compute_session.status == 'SCHEDULED': + print_info('Session ID {0} is scheduled and about to be started.' + .format(name)) + return + elif compute_session.status == 'RUNNING': + if compute_session.created: + vprint_done( + '[{0}] Session {1} is ready (domain={2}, group={3}).' + .format(idx, compute_session.name, + compute_session.domain, compute_session.group)) + else: + vprint_done('[{0}] Reusing session {1}...'.format(idx, compute_session.name)) + elif compute_session.status == 'TERMINATED': + print_warn('Session ID {0} is already terminated.\n' + 'This may be an error in the compute_session image.' + .format(name)) + return + elif compute_session.status == 'TIMEOUT': + print_info('Session ID {0} is still on the job queue.' + .format(name)) + return + elif compute_session.status in ('ERROR', 'CANCELLED'): + print_fail('Session ID {0} has an error during scheduling/startup or cancelled.' + .format(name)) + return + + try: + if files: + vprint_wait('[{0}] Uploading source files...'.format(idx)) + ret = compute_session.upload(files, basedir=basedir, + show_progress=True) + if ret.status // 100 != 2: + print_fail('[{0}] Uploading source files failed!'.format(idx)) + print('{0}: {1}\n{2}'.format( + ret.status, ret.reason, ret.text())) + return + vprint_done('[{0}] Uploading done.'.format(idx)) + opts = { + 'clean': clean_cmd, + 'build': build_cmd, + 'exec': exec_cmd, + } + if not terminal: + exec_loop_sync(sys.stdout, sys.stderr, compute_session, 'batch', '', + opts=opts, + vprint_done=vprint_done) + if terminal: + raise NotImplementedError('Terminal access is not supported in ' + 'the legacy synchronous mode.') + if code: + exec_loop_sync(sys.stdout, sys.stderr, compute_session, 'query', code, + vprint_done=vprint_done) + vprint_done('[{0}] Execution finished.'.format(idx)) + except Exception as e: + print_error(e) + sys.exit(1) + finally: + if rm: + vprint_wait('[{0}] Cleaning up the session...'.format(idx)) + ret = compute_session.destroy() + vprint_done('[{0}] Cleaned up the session.'.format(idx)) + if stats: + _stats = ret.get('stats', None) if ret else None + if _stats: + print('[{0}] Statistics:\n{1}' + .format(idx, format_stats(_stats))) + else: + print('[{0}] Statistics is not available.'.format(idx)) + + async def _run(session, idx, name, envs, + clean_cmd, build_cmd, exec_cmd, + is_multi=False): + try: + compute_session = await session.ComputeSession.get_or_create( + image, + name=name, + type_=type, + starts_at=starts_at, + enqueue_only=enqueue_only, + max_wait=max_wait, + no_reuse=no_reuse, + callback_url=callback_url, + cluster_size=cluster_size, + cluster_mode=cluster_mode, + mounts=mount, + mount_map=mount_map, + envs=envs, + resources=resources, + resource_opts=resource_opts, + domain_name=domain, + 
group_name=group, + scaling_group=scaling_group, + bootstrap_script=bootstrap_script.read() if bootstrap_script is not None else None, + tag=tag, + preopen_ports=preopen_ports, + assign_agent=assigned_agent_list, + ) + except Exception as e: + print_fail('[{0}] {1}'.format(idx, e)) + return + if compute_session.status == 'PENDING': + print_info('Session ID {0} is enqueued for scheduling.' + .format(name)) + return + elif compute_session.status == 'SCHEDULED': + print_info('Session ID {0} is scheduled and about to be started.' + .format(name)) + return + elif compute_session.status == 'RUNNING': + if compute_session.created: + vprint_done( + '[{0}] Session {1} is ready (domain={2}, group={3}).' + .format(idx, compute_session.name, + compute_session.domain, compute_session.group)) + else: + vprint_done('[{0}] Reusing session {1}...'.format(idx, compute_session.name)) + elif compute_session.status == 'TERMINATED': + print_warn('Session ID {0} is already terminated.\n' + 'This may be an error in the compute_session image.' + .format(name)) + return + elif compute_session.status == 'TIMEOUT': + print_info('Session ID {0} is still on the job queue.' + .format(name)) + return + elif compute_session.status in ('ERROR', 'CANCELLED'): + print_fail('Session ID {0} has an error during scheduling/startup or cancelled.' + .format(name)) + return + + if not is_multi: + stdout = sys.stdout + stderr = sys.stderr + else: + log_dir = local_cache_path / 'client-logs' + log_dir.mkdir(parents=True, exist_ok=True) + stdout = open(log_dir / '{0}.stdout.log'.format(name), + 'w', encoding='utf-8') + stderr = open(log_dir / '{0}.stderr.log'.format(name), + 'w', encoding='utf-8') + + try: + def indexed_vprint_done(msg): + vprint_done('[{0}] '.format(idx) + msg) + if files: + if not is_multi: + vprint_wait('[{0}] Uploading source files...'.format(idx)) + ret = await compute_session.upload(files, basedir=basedir, + show_progress=not is_multi) + if ret.status // 100 != 2: + print_fail('[{0}] Uploading source files failed!'.format(idx)) + print('{0}: {1}\n{2}'.format( + ret.status, ret.reason, ret.text()), file=stderr) + raise RuntimeError('Uploading source files has failed!') + if not is_multi: + vprint_done('[{0}] Uploading done.'.format(idx)) + opts = { + 'clean': clean_cmd, + 'build': build_cmd, + 'exec': exec_cmd, + } + if not terminal: + await exec_loop(stdout, stderr, compute_session, 'batch', '', + opts=opts, + vprint_done=indexed_vprint_done, + is_multi=is_multi) + if terminal: + await exec_terminal(compute_session) + return + if code: + await exec_loop(stdout, stderr, compute_session, 'query', code, + vprint_done=indexed_vprint_done, + is_multi=is_multi) + except BackendError as e: + print_fail('[{0}] {1}'.format(idx, e)) + raise RuntimeError(e) + except Exception as e: + print_fail('[{0}] Execution failed!'.format(idx)) + traceback.print_exc() + raise RuntimeError(e) + finally: + try: + if rm: + if not is_multi: + vprint_wait('[{0}] Cleaning up the session...'.format(idx)) + ret = await compute_session.destroy() + vprint_done('[{0}] Cleaned up the session.'.format(idx)) + if stats: + _stats = ret.get('stats', None) if ret else None + if _stats: + stats_str = format_stats(_stats) + print(format_info('[{0}] Statistics:'.format(idx)) + + '\n{0}'.format(stats_str)) + if is_multi: + print('Statistics:\n{0}'.format(stats_str), + file=stderr) + else: + print_warn('[{0}] Statistics: unavailable.'.format(idx)) + if is_multi: + print('Statistics: unavailable.', file=stderr) + except Exception as e: + print_fail('[{0}] 
Error while printing stats'.format(idx)) + traceback.print_exc() + raise RuntimeError(e) + finally: + if is_multi: + stdout.close() + stderr.close() + + async def _run_cases(): + loop = current_loop() + if name is None: + name_prefix = f'pysdk-{secrets.token_hex(5)}' + else: + name_prefix = name + vprint_info('Session name prefix: {0}'.format(name_prefix)) + if is_multi: + print_info('Check out the stdout/stderr logs stored in ' + '~/.cache/backend.ai/client-logs directory.') + async with AsyncSession() as session: + tasks = [] + # TODO: limit max-parallelism using aiojobs + for idx, case in enumerate(case_set.keys()): + if is_multi: + _name = '{0}-{1}'.format(name_prefix, idx) + else: + _name = name_prefix + envs = dict(case[0]) + clean_cmd = clean if clean else '*' + build_cmd = case[1] + exec_cmd = case[2] + t = loop.create_task( + _run(session, idx, _name, envs, + clean_cmd, build_cmd, exec_cmd, + is_multi=is_multi)) + tasks.append(t) + results = await asyncio.gather(*tasks, return_exceptions=True) + if any(map(lambda r: isinstance(r, Exception), results)): + if is_multi: + print_fail('There were failed cases!') + sys.exit(1) + + try: + asyncio_run(_run_cases()) + except Exception as e: + print_fail('{0}'.format(e)) diff --git a/src/ai/backend/client/cli/server_log.py b/src/ai/backend/client/cli/server_log.py new file mode 100644 index 0000000000..0aec1f0431 --- /dev/null +++ b/src/ai/backend/client/cli/server_log.py @@ -0,0 +1,43 @@ +from datetime import datetime +import sys + +import click + +from .main import main +from .pretty import print_error +from ..session import Session + + +@main.group() +def server_logs(): + """Provides operations related to server logs.""" + + +@server_logs.command() +@click.option('--mark-read', is_flag=True, default=False, + help='Mark read flag for server logs being fetched.') +@click.option('-l', '--page-size', type=int, default=20, + help='Number of logs to fetch (from latest log)') +@click.option('-n', '--page-number', type=int, default=1, + help='Page number to fetch.') +def list(mark_read, page_size, page_number): + """Fetch server (error) logs.""" + with Session() as session: + try: + resp = session.ServerLog.list(mark_read, page_size, page_number) + logs = resp.get('logs') + count = resp.get('count', 0) + if logs is not None: + print('Total log count:', count) + for log in logs: + log_time = datetime.utcfromtimestamp(log['created_at']).strftime('%Y-%m-%d %H:%M:%S') + print('----') + print(log_time, log['severity'].upper(), log['source'], log['user']) + print(log['request_status'], log['request_url']) + print(log['message']) + print(log['traceback']) + else: + print('No logs.') + except Exception as e: + print_error(e) + sys.exit(1) diff --git a/src/ai/backend/client/cli/session.py b/src/ai/backend/client/cli/session.py new file mode 100644 index 0000000000..440c96db02 --- /dev/null +++ b/src/ai/backend/client/cli/session.py @@ -0,0 +1,863 @@ +from __future__ import annotations + +from datetime import datetime +import json +from pathlib import Path +import secrets +import subprocess +import sys +from typing import IO, Literal, Sequence +import uuid + +import click +from humanize import naturalsize +from tabulate import tabulate + +from .main import main +from .pretty import print_wait, print_done, print_error, print_fail, print_info, print_warn +from .ssh import container_ssh_ctx +from .run import format_stats, prepare_env_arg, prepare_resource_arg, prepare_mount_arg +from ..compat import asyncio_run +from ..exceptions import BackendAPIError +from 
..session import Session, AsyncSession +from ..types import Undefined, undefined +from .params import CommaSeparatedListType + +list_expr = CommaSeparatedListType() + + +@main.group() +def session(): + """Set of compute session operations""" + + +def _create_cmd(docs: str = None): + + @click.argument('image') + @click.option('-t', '--name', '--client-token', metavar='NAME', + help='Specify a human-readable session name. ' + 'If not set, a random hex string is used.') + @click.option('-o', '--owner', '--owner-access-key', metavar='ACCESS_KEY', + help='Set the owner of the target session explicitly.') + # job scheduling options + @click.option('--type', metavar='SESSTYPE', + type=click.Choice(['batch', 'interactive']), + default='interactive', + help='Either batch or interactive') + @click.option('--starts-at', metavar='STARTS_AT', type=str, default=None, + help='Let session to be started at a specific or relative time.') + @click.option('-c', '--startup-command', metavar='COMMAND', + help='Set the command to execute for batch-type sessions.') + @click.option('--enqueue-only', is_flag=True, + help='Enqueue the session and return immediately without waiting for its startup.') + @click.option('--max-wait', metavar='SECONDS', type=int, default=0, + help='The maximum duration to wait until the session starts.') + @click.option('--no-reuse', is_flag=True, + help='Do not reuse existing sessions but return an error.') + @click.option('--depends', metavar='SESSION_ID', type=str, multiple=True, + help="Set the list of session ID or names that the newly created session depends on. " + "The session will get scheduled after all of them successfully finish.") + @click.option('--callback-url', metavar='CALLBACK_URL', type=str, default=None, + help="Callback URL which will be called upon sesison lifecycle events.") + # execution environment + @click.option('-e', '--env', metavar='KEY=VAL', type=str, multiple=True, + help='Environment variable (may appear multiple times)') + # extra options + @click.option('--bootstrap-script', metavar='PATH', type=click.File('r'), default=None, + help='A user-defined script to execute on startup.') + @click.option('--tag', type=str, default=None, + help='User-defined tag string to annotate sessions.') + # resource spec + @click.option('-v', '--volume', '-m', '--mount', 'mount', + metavar='NAME[=PATH]', type=str, multiple=True, + help='User-owned virtual folder names to mount. ' + 'If path is not provided, virtual folder will be mounted under /home/work. ' + 'When the target path is relative, it is placed under /home/work ' + 'with auto-created parent directories if any. ' + 'Absolute paths are mounted as-is, but it is prohibited to ' + 'override the predefined Linux system directories.') + @click.option('--scaling-group', '--sgroup', type=str, default=None, + help='The scaling group to execute session. If not specified, ' + 'all available scaling groups are included in the scheduling.') + @click.option('-r', '--resources', metavar='KEY=VAL', type=str, multiple=True, + help='Set computation resources used by the session ' + '(e.g: -r cpu=2 -r mem=256 -r gpu=1).' + '1 slot of cpu/gpu represents 1 core. 
' + 'The unit of mem(ory) is MiB.') + @click.option('--cluster-size', metavar='NUMBER', type=int, default=1, + help='The size of cluster in number of containers.') + @click.option('--cluster-mode', metavar='MODE', + type=click.Choice(['single-node', 'multi-node']), default='single-node', + help='The mode of clustering.') + @click.option('--resource-opts', metavar='KEY=VAL', type=str, multiple=True, + help='Resource options for creating compute session ' + '(e.g: shmem=64m)') + @click.option('--preopen', default=None, type=list_expr, + help='Pre-open service ports') + # resource grouping + @click.option('-d', '--domain', metavar='DOMAIN_NAME', default=None, + help='Domain name where the session will be spawned. ' + 'If not specified, config\'s domain name will be used.') + @click.option('-g', '--group', metavar='GROUP_NAME', default=None, + help='Group name where the session is spawned. ' + 'User should be a member of the group to execute the code.') + @click.option('--assign-agent', default=None, type=list_expr, + help='Show mapping list of tuple which mapped containers with agent. ' + 'When user role is Super Admin. ' + '(e.g., --assign-agent agent_id_1,agent_id_2,...)') + def create( + # base args + image: str, + name: str | None, + owner: str | None, + # job scheduling options + type: Literal['batch', 'interactive'], + starts_at: str | None, + startup_command: str | None, + enqueue_only: bool, + max_wait: bool, + no_reuse: bool, + depends: Sequence[str], + callback_url: str, + # execution environment + env: Sequence[str], + # extra options + bootstrap_script: IO | None, + tag: str | None, + # resource spec + mount: Sequence[str], + scaling_group: str | None, + resources: Sequence[str], + cluster_size: int, + cluster_mode: Literal['single-node', 'multi-node'], + resource_opts: Sequence[str], + preopen: str | None, + assign_agent: str | None, + # resource grouping + domain: str | None, + group: str | None, + ) -> None: + """ + Prepare and start a single compute session without executing codes. + You may use the created session to execute codes using the "run" command + or connect to an application service provided by the session using the "app" + command. + + + \b + IMAGE: The name (and version/platform tags appended after a colon) of session + runtime or programming language. 
+ """ + if name is None: + name = f'pysdk-{secrets.token_hex(5)}' + else: + name = name + + ###### + envs = prepare_env_arg(env) + resources = prepare_resource_arg(resources) + resource_opts = prepare_resource_arg(resource_opts) + mount, mount_map = prepare_mount_arg(mount) + + preopen_ports = preopen + assigned_agent_list = assign_agent + with Session() as session: + try: + compute_session = session.ComputeSession.get_or_create( + image, + name=name, + type_=type, + starts_at=starts_at, + enqueue_only=enqueue_only, + max_wait=max_wait, + no_reuse=no_reuse, + dependencies=depends, + callback_url=callback_url, + cluster_size=cluster_size, + cluster_mode=cluster_mode, + mounts=mount, + mount_map=mount_map, + envs=envs, + startup_command=startup_command, + resources=resources, + resource_opts=resource_opts, + owner_access_key=owner, + domain_name=domain, + group_name=group, + scaling_group=scaling_group, + bootstrap_script=bootstrap_script.read() if bootstrap_script is not None else None, + tag=tag, + preopen_ports=preopen_ports, + assign_agent=assigned_agent_list, + ) + except Exception as e: + print_error(e) + sys.exit(1) + else: + if compute_session.status == 'PENDING': + print_info('Session ID {0} is enqueued for scheduling.' + .format(compute_session.id)) + elif compute_session.status == 'SCHEDULED': + print_info('Session ID {0} is scheduled and about to be started.' + .format(compute_session.id)) + return + elif compute_session.status == 'RUNNING': + if compute_session.created: + print_info('Session ID {0} is created and ready.' + .format(compute_session.id)) + else: + print_info('Session ID {0} is already running and ready.' + .format(compute_session.id)) + if compute_session.service_ports: + print_info('This session provides the following app services: ' + + ', '.join(sport['name'] + for sport in compute_session.service_ports)) + elif compute_session.status == 'TERMINATED': + print_warn('Session ID {0} is already terminated.\n' + 'This may be an error in the compute_session image.' + .format(compute_session.id)) + elif compute_session.status == 'TIMEOUT': + print_info('Session ID {0} is still on the job queue.' + .format(compute_session.id)) + elif compute_session.status in ('ERROR', 'CANCELLED'): + print_fail('Session ID {0} has an error during scheduling/startup or cancelled.' + .format(compute_session.id)) + + if docs is not None: + create.__doc__ = docs + return create + + +main.command(aliases=['start'])(_create_cmd(docs="Alias of \"session create\"")) +session.command()(_create_cmd()) + + +def _create_from_template_cmd(docs: str = None): + + @click.argument('template_id') + @click.option('-t', '--name', '--client-token', metavar='NAME', + default=undefined, + help='Specify a human-readable session name. 
' + 'If not set, a random hex string is used.') + @click.option('-o', '--owner', '--owner-access-key', metavar='ACCESS_KEY', + default=undefined, + help='Set the owner of the target session explicitly.') + # job scheduling options + @click.option('--type', 'type_', metavar='SESSTYPE', + type=click.Choice(['batch', 'interactive', undefined]), # type: ignore + default=undefined, + help='Either batch or interactive') + @click.option('--starts_at', metavar='STARTS_AT', type=str, default=None, + help='Let session to be started at a specific or relative time.') + @click.option('-i', '--image', default=undefined, + help='Set compute_session image to run.') + @click.option('-c', '--startup-command', metavar='COMMAND', default=undefined, + help='Set the command to execute for batch-type sessions.') + @click.option('--enqueue-only', is_flag=True, + help='Enqueue the session and return immediately without waiting for its startup.') + @click.option('--max-wait', metavar='SECONDS', type=int, default=undefined, + help='The maximum duration to wait until the session starts.') + @click.option('--no-reuse', is_flag=True, + help='Do not reuse existing sessions but return an error.') + @click.option('--depends', metavar='SESSION_ID', type=str, multiple=True, + help="Set the list of session ID or names that the newly created session depends on. " + "The session will get scheduled after all of them successfully finish.") + @click.option('--callback-url', metavar='CALLBACK_URL', type=str, default=None, + help="Callback URL which will be called upon sesison lifecycle events.") + # execution environment + @click.option('-e', '--env', metavar='KEY=VAL', type=str, multiple=True, + help='Environment variable (may appear multiple times)') + # extra options + @click.option('--tag', type=str, default=undefined, + help='User-defined tag string to annotate sessions.') + # resource spec + @click.option('-m', '--mount', metavar='NAME[=PATH]', type=str, multiple=True, + help='User-owned virtual folder names to mount. ' + 'When the target path is relative, it is placed under /home/work ' + 'with auto-created parent directories if any. ' + 'Absolute paths are mounted as-is, but it is prohibited to ' + 'override the predefined Linux system directories.') + @click.option('--scaling-group', '--sgroup', type=str, default=undefined, + help='The scaling group to execute session. If not specified, ' + 'all available scaling groups are included in the scheduling.') + @click.option('-r', '--resources', metavar='KEY=VAL', type=str, multiple=True, + help='Set computation resources used by the session ' + '(e.g: -r cpu=2 -r mem=256 -r gpu=1).' + '1 slot of cpu/gpu represents 1 core. ' + 'The unit of mem(ory) is MiB.') + @click.option('--cluster-size', metavar='NUMBER', type=int, default=undefined, + help='The size of cluster in number of containers.') + @click.option('--resource-opts', metavar='KEY=VAL', type=str, multiple=True, + help='Resource options for creating compute session ' + '(e.g: shmem=64m)') + # resource grouping + @click.option('-d', '--domain', metavar='DOMAIN_NAME', default=None, + help='Domain name where the session will be spawned. ' + 'If not specified, config\'s domain name will be used.') + @click.option('-g', '--group', metavar='GROUP_NAME', default=None, + help='Group name where the session is spawned. 
' + 'User should be a member of the group to execute the code.') + # template overrides + @click.option('--no-mount', is_flag=True, + help='If specified, client.py will tell server not to mount ' + 'any vFolders specified at template,') + @click.option('--no-env', is_flag=True, + help='If specified, client.py will tell server not to add ' + 'any environs specified at template,') + @click.option('--no-resource', is_flag=True, + help='If specified, client.py will tell server not to add ' + 'any resource specified at template,') + def create_from_template( + # base args + template_id: str, + name: str | Undefined, + owner: str | Undefined, + # job scheduling options + type_: Literal['batch', 'interactive'] | Undefined, + starts_at: str | None, + image: str | Undefined, + startup_command: str | Undefined, + enqueue_only: bool, + max_wait: int | Undefined, + no_reuse: bool, + depends: Sequence[str], + callback_url: str, + # execution environment + env: Sequence[str], + # extra options + tag: str | Undefined, + # resource spec + mount: Sequence[str], + scaling_group: str | Undefined, + resources: Sequence[str], + cluster_size: int | Undefined, + resource_opts: Sequence[str], + # resource grouping + domain: str | None, + group: str | None, + # template overrides + no_mount: bool, + no_env: bool, + no_resource: bool, + ) -> None: + """ + Prepare and start a single compute session without executing codes. + You may use the created session to execute codes using the "run" command + or connect to an application service provided by the session using the "app" + command. + + \b + IMAGE: The name (and version/platform tags appended after a colon) of session + runtime or programming language. + """ + if name is undefined: + name = f'pysdk-{secrets.token_hex(5)}' + else: + name = name + + envs = prepare_env_arg(env) if len(env) > 0 or no_env else undefined + resources = prepare_resource_arg(resources) if len(resources) > 0 or no_resource else undefined + resource_opts = ( + prepare_resource_arg(resource_opts) + if len(resource_opts) > 0 or no_resource else undefined + ) + prepared_mount, prepared_mount_map = ( + prepare_mount_arg(mount) + if len(mount) > 0 or no_mount else (undefined, undefined) + ) + with Session() as session: + try: + compute_session = session.ComputeSession.create_from_template( + template_id, + image=image, + name=name, + type_=type_, + starts_at=starts_at, + enqueue_only=enqueue_only, + max_wait=max_wait, + no_reuse=no_reuse, + dependencies=depends, + callback_url=callback_url, + cluster_size=cluster_size, + mounts=prepared_mount, + mount_map=prepared_mount_map, + envs=envs, + startup_command=startup_command, + resources=resources, + resource_opts=resource_opts, + owner_access_key=owner, + domain_name=domain, + group_name=group, + scaling_group=scaling_group, + tag=tag, + ) + except Exception as e: + print_error(e) + sys.exit(1) + else: + if compute_session.status == 'PENDING': + print_info('Session ID {0} is enqueued for scheduling.' + .format(name)) + elif compute_session.status == 'SCHEDULED': + print_info('Session ID {0} is scheduled and about to be started.' + .format(name)) + return + elif compute_session.status == 'RUNNING': + if compute_session.created: + print_info('Session ID {0} is created and ready.' + .format(name)) + else: + print_info('Session ID {0} is already running and ready.' 
+ .format(name)) + if compute_session.service_ports: + print_info('This session provides the following app services: ' + + ', '.join(sport['name'] + for sport in compute_session.service_ports)) + elif compute_session.status == 'TERMINATED': + print_warn('Session ID {0} is already terminated.\n' + 'This may be an error in the compute_session image.' + .format(name)) + elif compute_session.status == 'TIMEOUT': + print_info('Session ID {0} is still on the job queue.' + .format(name)) + elif compute_session.status in ('ERROR', 'CANCELLED'): + print_fail('Session ID {0} has an error during scheduling/startup or cancelled.' + .format(name)) + + if docs is not None: + create_from_template.__doc__ = docs + return create_from_template + + +main.command(aliases=['start-from-template'])( + _create_from_template_cmd(docs="Alias of \"session create-from-template\""), +) +session.command()(_create_from_template_cmd()) + + +def _destroy_cmd(docs: str = None): + + @click.argument('session_names', metavar='SESSID', nargs=-1) + @click.option('-f', '--forced', is_flag=True, + help='Force-terminate the errored sessions (only allowed for admins)') + @click.option('-o', '--owner', '--owner-access-key', metavar='ACCESS_KEY', + help='Specify the owner of the target session explicitly.') + @click.option('-s', '--stats', is_flag=True, + help='Show resource usage statistics after termination') + def destroy(session_names, forced, owner, stats): + """ + Terminate and destroy the given session. + + SESSID: session ID given/generated when creating the session. + """ + if len(session_names) == 0: + print_warn('Specify at least one session ID. Check usage with "-h" option.') + sys.exit(1) + print_wait('Terminating the session(s)...') + with Session() as session: + has_failure = False + for session_name in session_names: + try: + compute_session = session.ComputeSession(session_name, owner) + ret = compute_session.destroy(forced=forced) + except BackendAPIError as e: + print_error(e) + if e.status == 404: + print_info( + 'If you are an admin, use "-o" / "--owner" option ' + 'to terminate other user\'s session.') + has_failure = True + except Exception as e: + print_error(e) + has_failure = True + else: + if not has_failure: + print_done('Done.') + if stats: + stats = ret.get('stats', None) if ret else None + if stats: + print(format_stats(stats)) + else: + print('Statistics is not available.') + if has_failure: + sys.exit(1) + + if docs is not None: + destroy.__doc__ = docs + return destroy + + +main.command(aliases=['rm', 'kill'])(_destroy_cmd(docs="Alias of \"session destroy\"")) +session.command(aliases=['rm', 'kill'])(_destroy_cmd()) + + +def _restart_cmd(docs: str = None): + + @click.argument('session_refs', metavar='SESSION_REFS', nargs=-1) + def restart(session_refs): + """ + Restart the compute session. + + \b + SESSION_REF: session ID or name + """ + if len(session_refs) == 0: + print_warn('Specify at least one session ID. 
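For reference, the termination flow above boils down to a couple of SDK calls; a minimal sketch, with the session name illustrative (pass the owner access key as the second argument to ComputeSession() when terminating another user's session, as the command does):

    import sys

    from ai.backend.client.exceptions import BackendAPIError
    from ai.backend.client.session import Session

    with Session() as session:
        try:
            ret = session.ComputeSession('mysess').destroy(forced=False)
        except BackendAPIError as e:
            # e.status == 404 usually means the session belongs to another user.
            print(e.status, e.reason, file=sys.stderr)
            sys.exit(1)
        if ret:
            print(ret.get('stats'))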
Check usage with "-h" option.') + sys.exit(1) + print_wait('Restarting the session(s)...') + with Session() as session: + has_failure = False + for session_ref in session_refs: + try: + compute_session = session.ComputeSession(session_ref) + compute_session.restart() + except BackendAPIError as e: + print_error(e) + if e.status == 404: + print_info( + 'If you are an admin, use "-o" / "--owner" option ' + 'to terminate other user\'s session.') + has_failure = True + except Exception as e: + print_error(e) + has_failure = True + else: + if not has_failure: + print_done('Done.') + if has_failure: + sys.exit(1) + + if docs is not None: + restart.__doc__ = docs + return restart + + +main.command()(_restart_cmd(docs="Alias of \"session restart\"")) +session.command()(_restart_cmd()) + + +@session.command() +@click.argument('session_id', metavar='SESSID') +@click.argument('files', type=click.Path(exists=True), nargs=-1) +def upload(session_id, files): + """ + Upload the files to a compute session's home directory. + If the target directory is in a storage folder mount, the operation is + effectively same to uploading files to the storage folder. + It is recommended to use storage folder commands for large file transfers + to utilize the storage proxy. + + For cluster sessions, the files are only uploaded to the main container. + + \b + SESSID: Session ID or name. + FILES: One or more paths to upload. + """ + if len(files) < 1: + print_warn("Please specify one or more file paths after session ID or name.") + return + with Session() as session: + try: + print_wait('Uploading files...') + kernel = session.ComputeSession(session_id) + kernel.upload(files, show_progress=True) + print_done('Uploaded.') + except Exception as e: + print_error(e) + sys.exit(1) + + +@session.command() +@click.argument('session_id', metavar='SESSID') +@click.argument('files', nargs=-1) +@click.option('--dest', type=Path, default='.', + help='Destination path to store downloaded file(s)') +def download(session_id, files, dest): + """ + Download files from a compute session's home directory. + If the source path is in a storage folder mount, the operation is + effectively same to downloading files from the storage folder. + It is recommended to use storage folder commands for large file transfers + to utilize the storage proxy. + + For cluster sessions, the files are only downloaded from the main container. + + \b + SESSID: Session ID or name. + FILES: One or more paths inside compute session. + """ + if len(files) < 1: + print_warn("Please specify one or more file paths after session ID or name.") + return + with Session() as session: + try: + print_wait('Downloading file(s) from {}...' + .format(session_id)) + kernel = session.ComputeSession(session_id) + kernel.download(files, dest, show_progress=True) + print_done('Downloaded to {}.'.format(dest.resolve())) + except Exception as e: + print_error(e) + sys.exit(1) + + +@session.command() +@click.argument('session_id', metavar='SESSID') +@click.argument('path', metavar='PATH', nargs=1, default='/home/work') +def ls(session_id, path): + """ + List files in a path of a running compute session. + + For cluster sessions, it lists the files of the main container. + + \b + SESSID: Session ID or name. + PATH: Path inside container. 
+ """ + with Session() as session: + try: + print_wait('Retrieving list of files in "{}"...'.format(path)) + kernel = session.ComputeSession(session_id) + result = kernel.list_files(path) + + if 'errors' in result and result['errors']: + print_fail(result['errors']) + sys.exit(1) + + files = json.loads(result['files']) + table = [] + headers = ['File name', 'Size', 'Modified', 'Mode'] + for file in files: + mdt = datetime.fromtimestamp(file['mtime']) + fsize = naturalsize(file['size'], binary=True) + mtime = mdt.strftime('%b %d %Y %H:%M:%S') + row = [file['filename'], fsize, mtime, file['mode']] + table.append(row) + print_done('Retrived.') + print(tabulate(table, headers=headers)) + except Exception as e: + print_error(e) + sys.exit(1) + + +@session.command() +@click.argument('session_id', metavar='SESSID') +def logs(session_id): + ''' + Shows the full console log of a compute session. + + \b + SESSID: Session ID or its alias given when creating the session. + ''' + with Session() as session: + try: + print_wait('Retrieving live container logs...') + kernel = session.ComputeSession(session_id) + result = kernel.get_logs().get('result') + logs = result.get('logs') if 'logs' in result else '' + print(logs) + print_done('End of logs.') + except Exception as e: + print_error(e) + sys.exit(1) + + +@session.command() +@click.argument('session_id', metavar='SESSID') +@click.argument('new_id', metavar='NEWID') +def rename(session_id, new_id): + ''' + Renames session name of running session. + + \b + SESSID: Session ID or its alias given when creating the session. + NEWID: New Session ID to rename to. + ''' + + with Session() as session: + try: + kernel = session.ComputeSession(session_id) + kernel.rename(new_id) + print_done(f'Session renamed to {new_id}.') + except Exception as e: + print_error(e) + sys.exit(1) + + +def _ssh_cmd(docs: str = None): + + @click.argument("session_ref", type=str, metavar='SESSION_REF') + @click.option('-p', '--port', type=int, metavar='PORT', default=9922, + help="the port number for localhost") + @click.pass_context + def ssh(ctx: click.Context, session_ref: str, port: int) -> None: + """Execute the ssh command against the target compute session. + + \b + SESSION_REF: The user-provided name or the unique ID of a running compute session. + + All remaining options and arguments not listed here are passed to the ssh command as-is. 
+ """ + try: + with container_ssh_ctx(session_ref, port) as key_path: + ssh_proc = subprocess.run( + [ + "ssh", + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "NoHostAuthenticationForLocalhost=yes", + "-i", key_path, + "work@localhost", + "-p", str(port), + *ctx.args, + ], + shell=False, + check=False, # be transparent against the main command + ) + sys.exit(ssh_proc.returncode) + except Exception as e: + print_error(e) + + if docs is not None: + ssh.__doc__ = docs + return ssh + + +_ssh_cmd_context_settings = { + "ignore_unknown_options": True, + "allow_extra_args": True, + "allow_interspersed_args": True, +} + +# Make it available as: +# - backend.ai ssh +# - backend.ai session ssh +main.command( + context_settings=_ssh_cmd_context_settings, +)(_ssh_cmd(docs="Alias of \"session ssh\"")) +session.command( + context_settings=_ssh_cmd_context_settings, +)(_ssh_cmd()) + + +def _scp_cmd(docs: str = None): + + @click.argument("session_ref", type=str, metavar='SESSION_REF') + @click.argument("src", type=str, metavar='SRC') + @click.argument("dst", type=str, metavar='DST') + @click.option('-p', '--port', type=str, metavar='PORT', default=9922, + help="the port number for localhost") + @click.option('-r', '--recursive', default=False, is_flag=True, + help="recursive flag option to process directories") + @click.pass_context + def scp( + ctx: click.Context, + session_ref: str, + src: str, + dst: str, + port: int, + recursive: bool, + ) -> None: + """ + Execute the scp command against the target compute session. + + \b + The SRC and DST have the same format with the original scp command, + either a remote path as "work@localhost:path" or a local path. + + SESSION_REF: The user-provided name or the unique ID of a running compute session. + SRC: the source path + DST: the destination path + + All remaining options and arguments not listed here are passed to the ssh command as-is. 
+ + Examples: + + * Uploading a local directory to the session: + + > backend.ai scp mysess -p 9922 -r tmp/ work@localhost:tmp2/ + + * Downloading a directory from the session: + + > backend.ai scp mysess -p 9922 -r work@localhost:tmp2/ tmp/ + """ + recursive_args = [] + if recursive: + recursive_args.append("-r") + try: + with container_ssh_ctx(session_ref, port) as key_path: + scp_proc = subprocess.run( + [ + "scp", + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "NoHostAuthenticationForLocalhost=yes", + "-i", key_path, + "-P", str(port), + *recursive_args, + src, dst, + *ctx.args, + ], + shell=False, + check=False, # be transparent against the main command + ) + sys.exit(scp_proc.returncode) + except Exception as e: + print_error(e) + + if docs is not None: + scp.__doc__ = docs + return scp + + +# Make it available as: +# - backend.ai scp +# - backend.ai session scp +main.command( + context_settings=_ssh_cmd_context_settings, +)(_scp_cmd(docs="Alias of \"session scp\"")) +session.command( + context_settings=_ssh_cmd_context_settings, +)(_scp_cmd()) + + +def _events_cmd(docs: str = None): + + @click.argument('session_name_or_id', metavar='SESSION_ID_OR_NAME') + @click.option('-o', '--owner', '--owner-access-key', 'owner_access_key', metavar='ACCESS_KEY', + help='Specify the owner of the target session explicitly.') + @click.option('--scope', type=click.Choice(['*', 'session', 'kernel']), default='*', + help='Filter the events by kernel-specific ones or session-specific ones.') + def events(session_name_or_id, owner_access_key, scope): + """ + Monitor the lifecycle events of a compute session. + + SESSID: session ID or its alias given when creating the session. + """ + + async def _run_events(): + async with AsyncSession() as session: + try: + session_id = uuid.UUID(session_name_or_id) + compute_session = session.ComputeSession.from_session_id(session_id) + except ValueError: + compute_session = session.ComputeSession(session_name_or_id, owner_access_key) + async with compute_session.listen_events(scope=scope) as response: + async for ev in response: + print(click.style(ev.event, fg='cyan', bold=True), json.loads(ev.data)) + + try: + asyncio_run(_run_events()) + except Exception as e: + print_error(e) + + if docs is not None: + events.__doc__ = docs + return events + + +# Make it available as: +# - backend.ai events +# - backend.ai session events +main.command()(_events_cmd(docs="Alias of \"session events\"")) +session.command()(_events_cmd()) diff --git a/src/ai/backend/client/cli/session_template.py b/src/ai/backend/client/cli/session_template.py new file mode 100644 index 0000000000..ef9314840b --- /dev/null +++ b/src/ai/backend/client/cli/session_template.py @@ -0,0 +1,156 @@ +import sys + +import click +from tabulate import tabulate + +from .main import main +from .pretty import print_info, print_warn, print_error +from ..session import Session + + +@main.group(aliases=['sesstpl']) +def session_template(): + """Set of session template operations""" + + +@session_template.command() +@click.option('-f', '--file', 'template_path', + help='Path to task template file. ' + 'If not specified, client will try to read config from STDIN. ') +@click.option('-d', '--domain', metavar='DOMAIN_NAME', default=None, + help='Domain name where the session will be spawned. ' + 'If not specified, config\'s domain name will be used.') +@click.option('-g', '--group', metavar='GROUP_NAME', default=None, + help='Group name where the session is spawned. 
' + 'User should be a member of the group to execute the code.') +@click.option('-o', '--owner', '--owner-access-key', 'owner_access_key', metavar='ACCESS_KEY', + help='Set the owner of the target session explicitly.') +def create(template_path, domain, group, owner_access_key): + ''' + Store task template to Backend.AI Manager and return template ID. + Template can be used when creating new session. + ''' + + if template_path: + with open(template_path, 'r') as fr: + body = fr.read() + else: + body = '' + for line in sys.stdin: + body += (line + '\n') + with Session() as session: + try: + # TODO: Make user select template type when cluster template is implemented + template = session.SessionTemplate.create(body, + domain_name=domain, + group_name=group, + owner_access_key=owner_access_key) + print_info(f'Task template {template.template_id} created and ready') + except Exception as e: + print_error(e) + sys.exit(1) + + +@session_template.command() +@click.argument('template_id', metavar='TEMPLATEID') +@click.option('-f', '--format', 'template_format', default='yaml', + help='Output format for task template. "yaml" and "json" allowed.') +@click.option('-o', '--owner', '--owner-access-key', 'owner_access_key', metavar='ACCESS_KEY', + help='Specify the owner of the target session explicitly.') +def get(template_id, template_format, owner_access_key): + ''' + Print task template associated with given template ID + ''' + with Session() as session: + try: + template = session.SessionTemplate(template_id, owner_access_key=owner_access_key) + body = template.get(body_format=template_format) + print(body) + except Exception as e: + print_error(e) + sys.exit(1) + + +@session_template.command() +@click.option('-a', '--list-all', is_flag=True, + help='List all virtual folders (superadmin privilege is required).') +def list(list_all): + ''' + List all availabe task templates by user. + ''' + fields = [ + ('Name', 'name'), + ('ID', 'id'), + ('Created At', 'created_at'), + ('Owner', 'is_owner'), + ('Type', 'type'), + ('User', 'user'), + ('Group', 'group'), + ] + with Session() as session: + try: + resp = session.SessionTemplate.list_templates(list_all) + if not resp: + print('There is no task templates created yet.') + return + rows = (tuple(vf[key] for _, key in fields) for vf in resp) + hdrs = (display_name for display_name, _ in fields) + print(tabulate(rows, hdrs)) + except Exception as e: + print_error(e) + sys.exit(1) + + +@session_template.command() +@click.argument('template_id', metavar='TEMPLATEID') +@click.option('-f', '--file', 'template_path', + help='Path to task template file. ' + 'If not specified, client will try to read config from STDIN. ') +@click.option('-o', '--owner', '--owner-access-key', 'owner_access_key', metavar='ACCESS_KEY', + help='Specify the owner of the target session explicitly.') +def update(template_id, template_path, owner_access_key): + ''' + Update task template stored in Backend.AI Manager. 
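The create and get commands above map directly onto the SessionTemplate function class; a minimal sketch mirroring the same keyword usage, with the template file path as a placeholder:

    from ai.backend.client.session import Session

    with open('my-template.yaml') as f:
        body = f.read()
    with Session() as session:
        tpl = session.SessionTemplate.create(
            body, domain_name=None, group_name=None, owner_access_key=None)
        print(tpl.template_id)
        print(session.SessionTemplate(tpl.template_id).get(body_format='yaml'))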
+ ''' + + if template_path: + with open(template_path, 'r') as fr: + body = fr.read() + else: + body = '' + for line in sys.stdin: + body += (line + '\n') + with Session() as session: + try: + template = session.SessionTemplate(template_id, owner_access_key=owner_access_key) + template.put(body) + print_info(f'Task template {template.template_id} updated') + except Exception as e: + print_error(e) + sys.exit(1) + + +@session_template.command() +@click.argument('template_id', metavar='TEMPLATEID') +@click.option('-f', '--force', type=bool, is_flag=True, + help='If specified, delete task template without asking.') +@click.option('-o', '--owner', '--owner-access-key', 'owner_access_key', metavar='ACCESS_KEY', + help='Specify the owner of the target session explicitly.') +def delete(template_id, force, owner_access_key): + ''' + Delete task template from Backend.AI Manager. + ''' + with Session() as session: + template = session.SessionTemplate(template_id, owner_access_key=owner_access_key) + if not force: + print_warn('Are you sure? (y/[n])') + result = input() + if result.strip() != 'y': + print_info('Aborting.') + exit() + try: + template.delete() + print_info(f'Task template {template.template_id} deleted') + except Exception as e: + print_error(e) + sys.exit(1) diff --git a/src/ai/backend/client/cli/ssh.py b/src/ai/backend/client/cli/ssh.py new file mode 100644 index 0000000000..33eba6a6f8 --- /dev/null +++ b/src/ai/backend/client/cli/ssh.py @@ -0,0 +1,60 @@ +import contextlib +import os +from pathlib import Path +import secrets +import signal +import subprocess +import sys +from typing import Iterator, List + +from .pretty import print_info, print_fail + + +@contextlib.contextmanager +def container_ssh_ctx(session_ref: str, port: int) -> Iterator[Path]: + random_id = secrets.token_hex(16) + key_filename = "id_container" + key_path = Path(f"~/.ssh/id_{random_id}").expanduser() + try: + subprocess.run( + ["backend.ai", "session", "download", session_ref, key_filename], + shell=False, + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) + except subprocess.CalledProcessError as e: + print_fail(f"Failed to download the SSH key from the session (exit: {e.returncode}):") + print(e.stdout.decode()) + sys.exit(1) + os.rename(key_filename, key_path) + print_info(f"running a temporary sshd proxy at localhost:{port} ...", file=sys.stderr) + # proxy_proc is a background process + proxy_proc = subprocess.Popen( + [ + "backend.ai", "app", session_ref, + "sshd", "-b", f"127.0.0.1:{port}", + ], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) + assert proxy_proc.stdout is not None + try: + lines: List[bytes] = [] + while True: + line = proxy_proc.stdout.readline(1024) + if not line: + proxy_proc.wait() + print_fail(f"Unexpected early termination of the sshd app command " + f"(exit: {proxy_proc.returncode}):") + print((b"\n".join(lines)).decode()) + sys.exit(1) + if f"127.0.0.1:{port}".encode() in line: + break + lines.append(line) + lines.clear() + yield key_path + finally: + proxy_proc.send_signal(signal.SIGINT) + proxy_proc.wait() + os.unlink(key_path) diff --git a/src/ai/backend/client/cli/types.py b/src/ai/backend/client/cli/types.py new file mode 100644 index 0000000000..dbf41ad554 --- /dev/null +++ b/src/ai/backend/client/cli/types.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +import enum +from typing import TYPE_CHECKING + +import attr + +if TYPE_CHECKING: + from ..config import APIConfig + from ..output import BaseOutputHandler + + +class 
OutputMode(enum.Enum): + CONSOLE = 'console' + JSON = 'json' + + +@attr.define(slots=True) +class CLIContext: + api_config: APIConfig = attr.field() + output_mode: OutputMode = attr.field() + output: BaseOutputHandler = attr.field(default=None) diff --git a/src/ai/backend/client/cli/vfolder.py b/src/ai/backend/client/cli/vfolder.py new file mode 100644 index 0000000000..90988f3a47 --- /dev/null +++ b/src/ai/backend/client/cli/vfolder.py @@ -0,0 +1,681 @@ +from datetime import datetime +import json +from pathlib import Path +import sys + +import click +import humanize +from tabulate import tabulate +from tqdm import tqdm + +from ai.backend.cli.interaction import ask_yn +from ai.backend.client.config import DEFAULT_CHUNK_SIZE, APIConfig +from ai.backend.client.session import Session + +from ..compat import asyncio_run +from ..session import AsyncSession +from .main import main +from .pretty import print_done, print_error, print_fail, print_info, print_wait, print_warn +from .params import ByteSizeParamType, ByteSizeParamCheckType, CommaSeparatedKVListParamType + + +@main.group() +def vfolder(): + """Set of vfolder operations""" + + +@vfolder.command() +def list_hosts(): + '''List the hosts of virtual folders that is accessible to the current user.''' + with Session() as session: + try: + resp = session.VFolder.list_hosts() + print("Default vfolder host: {}".format(resp['default'])) + print("Usable hosts: {}".format(', '.join(resp['allowed']))) + except Exception as e: + print_error(e) + sys.exit(1) + + +@vfolder.command() +def list_allowed_types(): + '''List allowed vfolder types.''' + with Session() as session: + try: + resp = session.VFolder.list_allowed_types() + print(resp) + except Exception as e: + print_error(e) + sys.exit(1) + + +@vfolder.command() +@click.argument('name', type=str) +@click.argument('host', type=str, default=None) +@click.option('-g', '--group', metavar='GROUP', type=str, default=None, + help='Group ID or NAME. Specify this option if you want to create a group folder.') +@click.option('--unmanaged', 'host_path', type=bool, is_flag=True, + help='Treats HOST as a mount point of unmanaged virtual folder. ' + 'This option can only be used by Admin or Superadmin.') +@click.option('-m', '--usage-mode', metavar='USAGE_MODE', type=str, default='general', + help='Purpose of the folder. Normal folders are usually set to "general". ' + 'Available options: "general", "data" (provides data to users), ' + 'and "model" (provides pre-trained models).') +@click.option('-p', '--permission', metavar='PERMISSION', type=str, default='rw', + help='Folder\'s innate permission. ' + 'Group folders can be shared as read-only by setting this option to "ro".' + 'Invited folders override this setting by its own invitation permission.') +@click.option('-q', '--quota', metavar='QUOTA', type=ByteSizeParamCheckType(), default='0', + help='Quota of the virtual folder. ' + '(Use \'m\' for megabytes, \'g\' for gigabytes, and etc.) ' + 'Default is maximum amount possible.') +@click.option('--cloneable', '--allow-clone', type=bool, is_flag=True, + help='Allows the virtual folder to be cloned by users.') +def create(name, host, group, host_path, usage_mode, permission, quota, cloneable): + '''Create a new virtual folder. + + \b + NAME: Name of a virtual folder. + HOST: Name of a virtual folder host in which the virtual folder will be created. 
+ ''' + with Session() as session: + try: + if host_path: + result = session.VFolder.create( + name=name, + unmanaged_path=host, + group=group, + usage_mode=usage_mode, + permission=permission, + quota=quota, + cloneable=cloneable, + ) + else: + result = session.VFolder.create( + name=name, + host=host, + group=group, + usage_mode=usage_mode, + permission=permission, + quota=quota, + cloneable=cloneable, + ) + print('Virtual folder "{0}" is created.'.format(result['name'])) + except Exception as e: + print_error(e) + sys.exit(1) + + +@vfolder.command() +@click.argument('name', type=str) +def delete(name): + '''Delete the given virtual folder. This operation is irreversible! + + NAME: Name of a virtual folder. + ''' + with Session() as session: + try: + session.VFolder(name).delete() + print_done('Deleted.') + except Exception as e: + print_error(e) + sys.exit(1) + + +@vfolder.command() +@click.argument('old_name', type=str) +@click.argument('new_name', type=str) +def rename(old_name, new_name): + '''Rename the given virtual folder. This operation is irreversible! + You cannot change the vfolders that are shared by other users, + and the new name must be unique among all your accessible vfolders + including the shared ones. + + OLD_NAME: The current name of a virtual folder. + NEW_NAME: The new name of a virtual folder. + ''' + with Session() as session: + try: + session.VFolder(old_name).rename(new_name) + print_done('Renamed.') + except Exception as e: + print_error(e) + sys.exit(1) + + +@vfolder.command() +@click.argument('name', type=str) +def info(name): + '''Show the information of the given virtual folder. + + NAME: Name of a virtual folder. + ''' + with Session() as session: + try: + result = session.VFolder(name).info() + print('Virtual folder "{0}" (ID: {1})' + .format(result['name'], result['id'])) + print('- Owner:', result['is_owner']) + print('- Permission:', result['permission']) + print('- Number of files: {0}'.format(result['numFiles'])) + print('- Ownership Type: {0}'.format(result['type'])) + print('- Permission:', result['permission']) + print('- Usage Mode: {0}'.format(result.get('usage_mode', ''))) + print('- Group ID: {0}'.format(result['group'])) + print('- User ID: {0}'.format(result['user'])) + print('- Clone Allowed: {0}'.format(result['cloneable'])) + except Exception as e: + print_error(e) + sys.exit(1) + + +@vfolder.command(context_settings={'show_default': True}) # bug: pallets/click#1565 (fixed in 8.0) +@click.argument('name', type=str) +@click.argument('filenames', type=Path, nargs=-1) +@click.option('-b', '--base-dir', type=Path, default=None, + help='Set the parent directory from where the file is uploaded. ' + '[default: current working directry]') +@click.option('--chunk-size', type=ByteSizeParamType(), + default=humanize.naturalsize(DEFAULT_CHUNK_SIZE, binary=True, gnu=True), + help='Transfer the file with the given chunk size with binary suffixes (e.g., "16m"). ' + 'Set this between 8 to 64 megabytes for high-speed disks (e.g., SSD RAID) ' + 'and networks (e.g., 40 GbE) for the maximum throughput.') +@click.option('--override-storage-proxy', + type=CommaSeparatedKVListParamType(), default=None, + help='Overrides storage proxy address. ' + 'The value must shape like "X1=Y1,X2=Y2...". 
' + 'Each Yn address must at least include the IP address ' + 'or the hostname and may include the protocol part and the port number to replace.') +def upload(name, filenames, base_dir, chunk_size, override_storage_proxy): + ''' + TUS Upload a file to the virtual folder from the current working directory. + The files with the same names will be overwirtten. + + \b + NAME: Name of a virtual folder. + FILENAMES: Paths of the files to be uploaded. + ''' + with Session() as session: + try: + session.VFolder(name).upload( + filenames, + basedir=base_dir, + chunk_size=chunk_size, + show_progress=True, + address_map=override_storage_proxy or APIConfig.DEFAULTS['storage_proxy_address_map'], + ) + print_done('Done.') + except Exception as e: + print_error(e) + sys.exit(1) + + +@vfolder.command(context_settings={'show_default': True}) # bug: pallets/click#1565 (fixed in 8.0) +@click.argument('name', type=str) +@click.argument('filenames', type=Path, nargs=-1) +@click.option('-b', '--base-dir', type=Path, default=None, + help='Set the parent directory from where the file is uploaded. ' + '[default: current working directry]') +@click.option('--chunk-size', type=ByteSizeParamType(), + default=humanize.naturalsize(DEFAULT_CHUNK_SIZE, binary=True, gnu=True), + help='Transfer the file with the given chunk size with binary suffixes (e.g., "16m"). ' + 'Set this between 8 to 64 megabytes for high-speed disks (e.g., SSD RAID) ' + 'and networks (e.g., 40 GbE) for the maximum throughput.') +@click.option('--override-storage-proxy', + type=CommaSeparatedKVListParamType(), default=None, + help='Overrides storage proxy address. ' + 'The value must shape like "X1=Y1,X2=Y2...". ' + 'Each Yn address must at least include the IP address ' + 'or the hostname and may include the protocol part and the port number to replace.') +def download(name, filenames, base_dir, chunk_size, override_storage_proxy): + ''' + Download a file from the virtual folder to the current working directory. + The files with the same names will be overwirtten. + + \b + NAME: Name of a virtual folder. + FILENAMES: Paths of the files to be downloaded inside a vfolder. + ''' + with Session() as session: + try: + session.VFolder(name).download( + filenames, + basedir=base_dir, + chunk_size=chunk_size, + show_progress=True, + address_map=override_storage_proxy or APIConfig.DEFAULTS['storage_proxy_address_map'], + ) + print_done('Done.') + except Exception as e: + print_error(e) + sys.exit(1) + + +@vfolder.command() +@click.argument('name', type=str) +@click.argument('filename', type=Path) +def request_download(name, filename): + ''' + Request JWT-formated download token for later use. + + \b + NAME: Name of a virtual folder. + FILENAME: Path of the file to be downloaded. + ''' + with Session() as session: + try: + response = json.loads(session.VFolder(name).request_download(filename)) + print_done(f'Download token: {response["token"]}') + except Exception as e: + print_error(e) + sys.exit(1) + + +@vfolder.command() +@click.argument('filenames', nargs=-1) +def cp(filenames): + '''An scp-like shortcut for download/upload commands. + + FILENAMES: Paths of the files to operate on. The last one is the target while all + others are the sources. Either source paths or the target path should + be prefixed with ":" like when using the Linux scp + command to indicate if it is a remote path. 
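The upload and download commands above wrap VFolder.upload() and VFolder.download(); a minimal sketch of the same calls from the SDK, with the folder name and file paths as placeholders and only the keywords shown in the commands above:

    from pathlib import Path

    from ai.backend.client.config import DEFAULT_CHUNK_SIZE
    from ai.backend.client.session import Session

    with Session() as session:
        vf = session.VFolder('mydata')
        vf.upload([Path('dataset.tar.gz')],
                  chunk_size=DEFAULT_CHUNK_SIZE, show_progress=True)
        vf.download([Path('results/output.csv')],
                    chunk_size=DEFAULT_CHUNK_SIZE, show_progress=True)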
+ ''' + raise NotImplementedError + + +@vfolder.command() +@click.argument('name', type=str) +@click.argument('path', type=str) +def mkdir(name, path): + '''Create an empty directory in the virtual folder. + + \b + NAME: Name of a virtual folder. + PATH: The name or path of directory. Parent directories are created automatically + if they do not exist. + ''' + with Session() as session: + try: + session.VFolder(name).mkdir(path) + print_done('Done.') + except Exception as e: + print_error(e) + sys.exit(1) + + +@vfolder.command() +@click.argument('name', type=str) +@click.argument('target_path', type=str) +@click.argument('new_name', type=str) +def rename_file(name, target_path, new_name): + ''' + Rename a file or a directory in a virtual folder. + + \b + NAME: Name of a virtual folder. + TARGET_PATH: The target path inside a virtual folder (file or directory). + NEW_NAME: New name of the target (should not contain slash). + ''' + with Session() as session: + try: + session.VFolder(name).rename_file(target_path, new_name) + print_done('Renamed.') + except Exception as e: + print_error(e) + sys.exit(1) + + +@vfolder.command() +@click.argument('name', type=str) +@click.argument('src', type=str) +@click.argument('dst', type=str) +def mv(name, src, dst): + ''' + Move a file or a directory within a virtual folder. + If the destination is a file and already exists, it will be overwritten. + If the destination is a directory, the source file or directory + is moved inside it. + + \b + NAME: Name of a virtual folder. + SRC: The relative path of the source file or directory inside a virtual folder + DST: The relative path of the destination file or directory inside a virtual folder. + ''' + with Session() as session: + try: + session.VFolder(name).move_file(src, dst) + print_done('Moved.') + except Exception as e: + print_error(e) + sys.exit(1) + + +@vfolder.command(aliases=['delete-file']) +@click.argument('name', type=str) +@click.argument('filenames', nargs=-1) +@click.option('-r', '--recursive', is_flag=True, + help='Enable recursive deletion of directories.') +def rm(name, filenames, recursive): + ''' + Delete files in a virtual folder. + If one of the given paths is a directory and the recursive option is enabled, + all its content and the directory itself are recursively deleted. + + This operation is irreversible! + + \b + NAME: Name of a virtual folder. + FILENAMES: Paths of the files to delete. + ''' + with Session() as session: + try: + if not ask_yn(): + print_info('Cancelled') + sys.exit(1) + session.VFolder(name).delete_files( + filenames, + recursive=recursive) + print_done('Done.') + except Exception as e: + print_error(e) + sys.exit(1) + + +@vfolder.command() +@click.argument('name', type=str) +@click.argument('path', metavar='PATH', nargs=1, default='.') +def ls(name, path): + """ + List files in a path of a virtual folder. + + \b + NAME: Name of a virtual folder. + PATH: Path inside vfolder. 
+ """ + with Session() as session: + try: + print_wait('Retrieving list of files in "{}"...'.format(path)) + result = session.VFolder(name).list_files(path) + if 'error_msg' in result and result['error_msg']: + print_fail(result['error_msg']) + return + files = json.loads(result['files']) + table = [] + headers = ['file name', 'size', 'modified', 'mode'] + for file in files: + mdt = datetime.fromtimestamp(file['mtime']) + mtime = mdt.strftime('%b %d %Y %H:%M:%S') + row = [file['filename'], file['size'], mtime, file['mode']] + table.append(row) + print_done('Retrived.') + print(tabulate(table, headers=headers)) + except Exception as e: + print_error(e) + + +@vfolder.command() +@click.argument('name', type=str) +@click.argument('emails', type=str, nargs=-1, required=True) +@click.option('-p', '--perm', metavar='PERMISSION', type=str, default='rw', + help='Permission to give. "ro" (read-only) / "rw" (read-write) / "wd" (write-delete).') +def invite(name, emails, perm): + """Invite other users to access a user-type virtual folder. + + \b + NAME: Name of a virtual folder. + EMAILS: Emails to invite. + """ + with Session() as session: + try: + assert perm in ['rw', 'ro', 'wd'], 'Invalid permission: {}'.format(perm) + result = session.VFolder(name).invite(perm, emails) + invited_ids = result.get('invited_ids', []) + if len(invited_ids) > 0: + print('Invitation sent to:') + for invitee in invited_ids: + print('\t- ' + invitee) + else: + print('No users found. Invitation was not sent.') + except Exception as e: + print_error(e) + sys.exit(1) + + +@vfolder.command() +def invitations(): + """List and manage received invitations. + """ + with Session() as session: + try: + result = session.VFolder.invitations() + invitations = result.get('invitations', []) + if len(invitations) < 1: + print('No invitations.') + return + print('List of invitations (inviter, vfolder id, permission):') + for cnt, inv in enumerate(invitations): + if inv['perm'] == 'rw': + perm = 'read-write' + elif inv['perm'] == 'ro': + perm = 'read-only' + else: + perm = inv['perm'] + print('[{}] {}, {}, {}'.format(cnt + 1, inv['inviter'], + inv['vfolder_id'], perm)) + + selection = input('Choose invitation number to manage: ') + if selection.isdigit(): + selection = int(selection) - 1 + else: + return + if 0 <= selection < len(invitations): + while True: + action = input('Choose action. (a)ccept, (r)eject, (c)ancel: ') + if action.lower() == 'a': + session.VFolder.accept_invitation(invitations[selection]['id']) + msg = ( + 'You can now access vfolder {} ({})'.format( + invitations[selection]['vfolder_name'], + invitations[selection]['id'], + ) + ) + print(msg) + break + elif action.lower() == 'r': + session.VFolder.delete_invitation(invitations[selection]['id']) + msg = ( + 'vfolder invitation rejected: {} ({})'.format( + invitations[selection]['vfolder_name'], + invitations[selection]['id'], + ) + ) + print(msg) + break + elif action.lower() == 'c': + break + except Exception as e: + print_error(e) + sys.exit(1) + + +@vfolder.command() +@click.argument('name', type=str) +@click.argument('emails', type=str, nargs=-1, required=True) +@click.option('-p', '--perm', metavar='PERMISSION', type=str, default='rw', + help='Permission to give. "ro" (read-only) / "rw" (read-write) / "wd" (write-delete).') +def share(name, emails, perm): + """Share a group folder to users with overriding permission. + + \b + NAME: Name of a (group-type) virtual folder. + EMAILS: Emails to share. 
+ """ + with Session() as session: + try: + assert perm in ['rw', 'ro', 'wd'], 'Invalid permission: {}'.format(perm) + result = session.VFolder(name).share(perm, emails) + shared_emails = result.get('shared_emails', []) + if len(shared_emails) > 0: + print('Shared with {} permission to:'.format(perm)) + for _email in shared_emails: + print('\t- ' + _email) + else: + print('No users found. Folder is not shared.') + except Exception as e: + print_error(e) + sys.exit(1) + + +@vfolder.command() +@click.argument('name', type=str) +@click.argument('emails', type=str, nargs=-1, required=True) +def unshare(name, emails): + """Unshare a group folder from users. + + \b + NAME: Name of a (group-type) virtual folder. + EMAILS: Emails to share. + """ + with Session() as session: + try: + result = session.VFolder(name).unshare(emails) + unshared_emails = result.get('unshared_emails', []) + if len(unshared_emails) > 0: + print('Unshared from:') + for _email in unshared_emails: + print('\t- ' + _email) + else: + print('No users found. Folder is not unshared.') + except Exception as e: + print_error(e) + sys.exit(1) + + +@vfolder.command() +@click.argument('name', type=str) +def leave(name): + '''Leave the shared virutal folder. + + NAME: Name of a virtual folder + ''' + with Session() as session: + try: + vfolder_info = session.VFolder(name).info() + if vfolder_info['type'] == 'group': + print('You cannot leave a group virtual folder.') + return + if vfolder_info['is_owner']: + print('You cannot leave a virtual folder you own. Consider using delete instead.') + return + session.VFolder(name).leave() + print('Left the shared virtual folder "{}".'.format(name)) + + except Exception as e: + print_error(e) + sys.exit(1) + + +@vfolder.command() +@click.argument('name', type=str) +@click.argument('target_name', type=str) +@click.argument('target_host', type=str) +@click.option('-m', '--usage-mode', metavar='USAGE_MODE', type=str, default='general', + help='Purpose of the cloned virtual folder. ' + 'Default value is \'general\'.') +@click.option('-p', '--permission', metavar='PERMISSION', type=str, default='rw', + help='Cloned virtual folder\'s permission. ' + 'Default value is \'rw\'.') +def clone(name, target_name, target_host, usage_mode, permission): + """Clone a virtual folder. + + \b + NAME: Name of the virtual folder to clone from. + TARGET_NAME: Name of the virtual folder to clone to. + TARGET_HOST: Name of a virtual folder host to which the virtual folder will be cloned. + """ + with Session() as session: + try: + vfolder_info = session.VFolder(name).info() + if not vfolder_info['cloneable']: + print("Clone is not allowed for this virtual folder. " + "Please update the 'cloneable' option.") + return + result = session.VFolder(name).clone( + target_name, + target_host=target_host, + usage_mode=usage_mode, + permission=permission, + ) + bgtask_id = result.get('bgtask_id') + except Exception as e: + print_error(e) + sys.exit(1) + + async def clone_vfolder_tracker(bgtask_id): + print_wait( + "Cloning the vfolder... 
" + "(This may take a while depending on its size and number of files!)", + ) + async with AsyncSession() as session: + try: + bgtask = session.BackgroundTask(bgtask_id) + completion_msg_func = lambda: print_done("Cloning the vfolder is complete.") + async with bgtask.listen_events() as response: + # TODO: get the unit of progress from response + with tqdm(unit='bytes', disable=True) as pbar: + async for ev in response: + data = json.loads(ev.data) + if ev.event == 'bgtask_updated': + pbar.total = data['total_progress'] + pbar.write(data['message']) + pbar.update(data['current_progress'] - pbar.n) + elif ev.event == 'bgtask_failed': + error_msg = data['message'] + completion_msg_func = \ + lambda: print_fail( + f"Error during the operation: {error_msg}", + ) + elif ev.event == 'bgtask_cancelled': + completion_msg_func = \ + lambda: print_warn( + "The operation has been cancelled in the middle. " + "(This may be due to server shutdown.)", + ) + finally: + completion_msg_func() + + if bgtask_id is None: + print_done("Cloning the vfolder is complete.") + else: + asyncio_run(clone_vfolder_tracker(bgtask_id)) + + +@vfolder.command() +@click.argument('name', type=str) +@click.option('-p', '--permission', type=str, metavar='PERMISSION', + help="Folder's innate permission.") +@click.option('--set-cloneable', type=bool, metavar='BOOLEXPR', + help="A boolean-interpretable string whether a virtual folder can be cloned. " + "If not set, the cloneable property is not changed.") +def update_options(name, permission, set_cloneable): + """Update an existing virtual folder. + + \b + NAME: Name of the virtual folder to update. + """ + with Session() as session: + try: + vfolder_info = session.VFolder(name).info() + if not vfolder_info['is_owner']: + print("You cannot update virtual folder that you do not own.") + return + session.VFolder(name).update_options( + name, + permission=permission, + cloneable=set_cloneable, + ) + print_done("Updated.") + except Exception as e: + print_error(e) + sys.exit(1) diff --git a/src/ai/backend/client/compat.py b/src/ai/backend/client/compat.py new file mode 100644 index 0000000000..dc058c5a9a --- /dev/null +++ b/src/ai/backend/client/compat.py @@ -0,0 +1,98 @@ +""" +A compatibility module for backported codes from Python 3.6+ standard library. 
+""" + +import asyncio + +__all__ = ( + 'current_loop', + 'all_tasks', + 'asyncio_run', + 'asyncio_run_forever', +) + + +if hasattr(asyncio, 'get_running_loop'): # Python 3.7+ + current_loop = asyncio.get_running_loop +else: + current_loop = asyncio.get_event_loop + + +if hasattr(asyncio, 'all_tasks'): # Python 3.7+ + all_tasks = asyncio.all_tasks +else: + all_tasks = asyncio.Task.all_tasks # type: ignore + + +def _cancel_all_tasks(loop): + to_cancel = all_tasks(loop) + if not to_cancel: + return + for task in to_cancel: + task.cancel() + loop.run_until_complete( + asyncio.gather(*to_cancel, return_exceptions=True)) + for task in to_cancel: + if task.cancelled(): + continue + if task.exception() is not None: + loop.call_exception_handler({ + 'message': 'unhandled exception during asyncio.run() shutdown', + 'exception': task.exception(), + 'task': task, + }) + + +def _asyncio_run(coro, *, debug=False): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.set_debug(debug) + try: + return loop.run_until_complete(coro) + finally: + try: + _cancel_all_tasks(loop) + if hasattr(loop, 'shutdown_asyncgens'): # Python 3.6+ + loop.run_until_complete(loop.shutdown_asyncgens()) + finally: + loop.stop() + loop.close() + asyncio.set_event_loop(None) + + +if hasattr(asyncio, 'run'): # Python 3.7+ + asyncio_run = asyncio.run +else: + asyncio_run = _asyncio_run + + +def asyncio_run_forever(server_context, *, debug=False): + """ + A proposed-but-not-implemented asyncio.run_forever() API based on + @vxgmichel's idea. + See discussions on https://github.com/python/asyncio/pull/465 + """ + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.set_debug(debug) + + forever = loop.create_future() + + async def _run_forever(): + async with server_context: + try: + await forever + except asyncio.CancelledError: + pass + + try: + return loop.run_until_complete(_run_forever()) + finally: + try: + _cancel_all_tasks(loop) + if hasattr(loop, 'shutdown_asyncgens'): # Python 3.6+ + loop.run_until_complete(loop.shutdown_asyncgens()) + finally: + loop.stop() + loop.close() + asyncio.set_event_loop(None) diff --git a/src/ai/backend/client/config.py b/src/ai/backend/client/config.py new file mode 100644 index 0000000000..d912ba82fa --- /dev/null +++ b/src/ai/backend/client/config.py @@ -0,0 +1,388 @@ +import enum +import os +from pathlib import Path +import random +import re +from typing import ( + Any, + Callable, + Iterable, + List, + Mapping, + Optional, + Sequence, + Tuple, + TypeVar, + Union, + cast, +) + +import appdirs +from yarl import URL + +__all__ = [ + 'parse_api_version', + 'get_config', + 'set_config', + 'APIConfig', + 'API_VERSION', + 'DEFAULT_CHUNK_SIZE', + 'MAX_INFLIGHT_CHUNKS', +] + + +class Undefined(enum.Enum): + token = object() + + +_config = None +_undefined = Undefined.token + +API_VERSION = (6, '20220315') +MIN_API_VERSION = (5, '20191215') + +DEFAULT_CHUNK_SIZE = 16 * (2**20) # 16 MiB +MAX_INFLIGHT_CHUNKS = 4 + +local_state_path = Path(appdirs.user_state_dir('backend.ai', 'Lablup')) +local_cache_path = Path(appdirs.user_cache_dir('backend.ai', 'Lablup')) + + +def parse_api_version(value: str) -> Tuple[int, str]: + match = re.search(r'^v(?P\d+)\.(?P\d{8})$', value) + if match is not None: + return int(match.group(1)), match.group(2) + raise ValueError('Could not parse the given API version string', value) + + +T = TypeVar('T') + + +def default_clean(v: Union[str, Mapping]) -> T: + return cast(T, v) + + +def get_env( + key: str, + default: Union[str, Mapping, 
Undefined] = _undefined, + *, + clean: Callable[[Any], T] = default_clean, +) -> T: + """ + Retrieves a configuration value from the environment variables. + The given *key* is uppercased and prefixed by ``"BACKEND_"`` and then + ``"SORNA_"`` if the former does not exist. + + :param key: The key name. + :param default: The default value returned when there is no corresponding + environment variable. + :param clean: A single-argument function that is applied to the result of lookup + (in both successes and the default value for failures). + The default is returning the value as-is. + + :returns: The value processed by the *clean* function. + """ + key = key.upper() + raw = os.environ.get('BACKEND_' + key) + if raw is None: + raw = os.environ.get('SORNA_' + key) + if raw is None: + if default is _undefined: + raise KeyError(key) + result = default + else: + result = raw + return clean(result) + + +def bool_env(v: str) -> bool: + v = v.lower() + if v in ('y', 'yes', 't', 'true', '1'): + return True + if v in ('n', 'no', 'f', 'false', '0'): + return False + raise ValueError('Unrecognized value of boolean environment variable', v) + + +def _clean_urls(v: Union[URL, str]) -> List[URL]: + if isinstance(v, URL): + return [v] + urls = [] + if isinstance(v, str): + for entry in v.split(','): + url = URL(entry) + if not url.is_absolute(): + raise ValueError('URL {} is not absolute.'.format(url)) + urls.append(url) + return urls + + +def _clean_tokens(v: str) -> Tuple[str, ...]: + if not v: + return tuple() + return tuple(v.split(',')) + + +def _clean_address_map(v: Union[str, Mapping]) -> Mapping: + if isinstance(v, dict): + return v + if not isinstance(v, str): + raise ValueError( + f'Storage proxy address map has invalid type "{type(v)}", expected str or dict.', + ) + override_map = {} + for assignment in v.split(","): + try: + k, _, v = assignment.partition("=") + if k == '' or v == '': + raise ValueError + except ValueError: + raise ValueError(f"{v} is not a valid mapping expression") + else: + override_map[k] = v + return override_map + + +class APIConfig: + """ + Represents a set of API client configurations. + The access key and secret key are mandatory -- they must be set in either + environment variables or as the explicit arguments. + + :param endpoint: The URL prefix to make API requests via HTTP/HTTPS. + If this is given as ``str`` and contains multiple URLs separated by comma, + the underlying HTTP request-response facility will perform client-side + load balancing and automatic fail-over using them, assuming that all those + URLs indicates a single, same cluster. + The users of the API and CLI will get network connection errors only when + all of the given endpoints fail -- intermittent failures of a subset of endpoints + will be hidden with a little increased latency. + :param endpoint_type: Either ``"api"`` or ``"session"``. + If the endpoint type is ``"api"`` (the default if unspecified), it uses the access key and + secret key in the configuration to access the manager API server directly. + If the endpoint type is ``"session"``, it assumes the endpoint is a Backend.AI console server + which provides cookie-based authentication with username and password. + In the latter, users need to use ``backend.ai login`` and ``backend.ai logout`` to + manage their sign-in status, or the API equivalent in + :meth:`~ai.backend.client.auth.Auth.login` and + :meth:`~ai.backend.client.auth.Auth.logout` methods. + :param version: The API protocol version. 
+ :param user_agent: A custom user-agent string which is sent to the API + server as a ``User-Agent`` HTTP header. + :param access_key: The API access key. If deliberately set to an empty string, the API + requests will be made without signatures (anonymously). + :param secret_key: The API secret key. + :param hash_type: The hash type to generate per-request authentication + signatures. + :param vfolder_mounts: A list of vfolder names (that must belong to the given + access key) to be automatically mounted upon any + :func:`Kernel.get_or_create() + ` calls. + """ + + DEFAULTS: Mapping[str, Union[str, Mapping]] = { + 'endpoint': 'https://api.backend.ai', + 'endpoint_type': 'api', + 'version': f'v{API_VERSION[0]}.{API_VERSION[1]}', + 'hash_type': 'sha256', + 'domain': 'default', + 'group': 'default', + 'storage_proxy_address_map': {}, + 'connection_timeout': '10.0', + 'read_timeout': '0', + } + """ + The default values for config parameterse settable via environment variables + xcept the access and secret keys. + """ + + _endpoints: List[URL] + _group: str + _hash_type: str + _skip_sslcert_validation: bool + + def __init__( + self, *, + endpoint: Union[URL, str] = None, + endpoint_type: str = None, + domain: str = None, + group: str = None, + storage_proxy_address_map: Mapping[str, str] = None, + version: str = None, + user_agent: str = None, + access_key: str = None, + secret_key: str = None, + hash_type: str = None, + vfolder_mounts: Iterable[str] = None, + skip_sslcert_validation: bool = None, + connection_timeout: float = None, + read_timeout: float = None, + announcement_handler: Callable[[str], None] = None, + ) -> None: + from . import get_user_agent + self._endpoints = ( + _clean_urls(endpoint) if endpoint else + get_env('ENDPOINT', self.DEFAULTS['endpoint'], clean=_clean_urls) + ) + random.shuffle(self._endpoints) + self._endpoint_type = endpoint_type if endpoint_type is not None else \ + get_env('ENDPOINT_TYPE', self.DEFAULTS['endpoint_type'], clean=str) + self._domain = domain if domain is not None else \ + get_env('DOMAIN', self.DEFAULTS['domain'], clean=str) + self._group = group if group is not None else \ + get_env('GROUP', self.DEFAULTS['group'], clean=str) + self._storage_proxy_address_map = storage_proxy_address_map \ + if storage_proxy_address_map is not None else \ + get_env( + 'OVERRIDE_STORAGE_PROXY', + self.DEFAULTS['storage_proxy_address_map'], + # The shape of this env var must be like "X1=Y1,X2=Y2" + clean=_clean_address_map, + ) + self._version = version if version is not None else \ + default_clean(self.DEFAULTS['version']) + self._user_agent = user_agent if user_agent is not None else get_user_agent() + if self._endpoint_type == 'api': + self._access_key = access_key if access_key is not None else \ + get_env('ACCESS_KEY', '') + self._secret_key = secret_key if secret_key is not None else \ + get_env('SECRET_KEY', '') + else: + self._access_key = 'dummy' + self._secret_key = 'dummy' + self._hash_type = hash_type.lower() if hash_type is not None else \ + cast(str, self.DEFAULTS['hash_type']) + arg_vfolders = set(vfolder_mounts) if vfolder_mounts else set() + env_vfolders = set(get_env('VFOLDER_MOUNTS', '', clean=_clean_tokens)) + self._vfolder_mounts = [*(arg_vfolders | env_vfolders)] + # prefer the argument flag and fallback to env if the flag is not set. 
+ if skip_sslcert_validation: + self._skip_sslcert_validation = True + else: + self._skip_sslcert_validation = get_env( + 'SKIP_SSLCERT_VALIDATION', 'no', clean=bool_env, + ) + self._connection_timeout = connection_timeout if connection_timeout is not None else \ + get_env('CONNECTION_TIMEOUT', self.DEFAULTS['connection_timeout'], clean=float) + self._read_timeout = read_timeout if read_timeout is not None else \ + get_env('READ_TIMEOUT', self.DEFAULTS['read_timeout'], clean=float) + self._announcement_handler = announcement_handler + + @property + def is_anonymous(self) -> bool: + return self._access_key == '' + + @property + def endpoint(self) -> URL: + """ + The currently active endpoint URL. + This may change if there are multiple configured endpoints + and the current one is not accessible. + """ + return self._endpoints[0] + + @property + def endpoints(self) -> Sequence[URL]: + """All configured endpoint URLs.""" + return self._endpoints + + def rotate_endpoints(self): + if len(self._endpoints) > 1: + item = self._endpoints.pop(0) + self._endpoints.append(item) + + def load_balance_endpoints(self): + pass + + @property + def endpoint_type(self) -> str: + """ + The configured endpoint type. + """ + return self._endpoint_type + + @property + def domain(self) -> str: + """The configured domain.""" + return self._domain + + @property + def group(self) -> str: + """The configured group.""" + return self._group + + @property + def storage_proxy_address_map(self) -> Mapping[str, str]: + """The storage proxy address map for overriding.""" + return self.storage_proxy_address_map + + @property + def user_agent(self) -> str: + """The configured user agent string.""" + return self._user_agent + + @property + def access_key(self) -> str: + """The configured API access key.""" + return self._access_key + + @property + def secret_key(self) -> str: + """The configured API secret key.""" + return self._secret_key + + @property + def version(self) -> str: + """The configured API protocol version.""" + return self._version + + @property + def hash_type(self) -> str: + """The configured hash algorithm for API authentication signatures.""" + return self._hash_type + + @property + def vfolder_mounts(self) -> Sequence[str]: + """The configured auto-mounted vfolder list.""" + return self._vfolder_mounts + + @property + def skip_sslcert_validation(self) -> bool: + """Whether to skip SSL certificate validation for the API gateway.""" + return self._skip_sslcert_validation + + @property + def connection_timeout(self) -> float: + """The maximum allowed duration for making TCP connections to the server.""" + return self._connection_timeout + + @property + def read_timeout(self) -> float: + """The maximum allowed waiting time for the first byte of the response from the server.""" + return self._read_timeout + + @property + def announcement_handler(self) -> Optional[Callable[[str], None]]: + '''The announcement handler to display server-set announcements.''' + return self._announcement_handler + + +def get_config(): + """ + Returns the configuration for the current process. + If there is no explicitly set :class:`APIConfig` instance, + it will generate a new one from the current environment variables + and defaults. + """ + global _config + if _config is None: + _config = APIConfig() + return _config + + +def set_config(conf: APIConfig): + """ + Sets the configuration used throughout the current process. 
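Besides environment variables, the client can be configured programmatically through APIConfig and set_config(); a minimal sketch, with the endpoint and keys as placeholders:

    from ai.backend.client.config import APIConfig, get_config, set_config

    set_config(APIConfig(
        endpoint='https://api.backend.ai',
        access_key='AKIA...',
        secret_key='...',
    ))
    print(get_config().endpoint)  # the currently active endpoint URL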
+ """ + global _config + _config = conf diff --git a/src/ai/backend/client/exceptions.py b/src/ai/backend/client/exceptions.py new file mode 100644 index 0000000000..bcfd1e3a13 --- /dev/null +++ b/src/ai/backend/client/exceptions.py @@ -0,0 +1,68 @@ +from typing import Any +import json + +__all__ = ( + 'BackendError', + 'BackendAPIError', + 'BackendClientError', + 'APIVersionWarning', +) + + +class BackendError(Exception): + '''Exception type to catch all ai.backend-related errors.''' + + def __str__(self): + return repr(self) + + +class BackendAPIError(BackendError): + '''Exceptions returned by the API gateway.''' + + def __init__(self, status: int, reason: str, data: Any): + if isinstance(data, (str, bytes)): + try: + data = json.loads(data) + except json.JSONDecodeError: + data = { + 'type': 'https://api.backend.ai/probs/generic-error', + 'title': 'Generic Error (could not parse error string)', + 'content': data, + } + super().__init__(status, reason, data) + + @property + def status(self) -> int: + return self.args[0] + + @property + def reason(self) -> str: + return self.args[1] + + @property + def data(self) -> Any: + return self.args[2] + + +class BackendAPIVersionError(BackendError): + """ + Exception indicating that the given operation/argument is not supported + in the currently negotiated server API version. + """ + + +class BackendClientError(BackendError): + """ + Exceptions from the client library, such as argument validation + errors and connection failures. + """ + + pass + + +class APIVersionWarning(UserWarning): + """ + The warning generated if the server's API version is higher. + """ + + pass diff --git a/src/ai/backend/client/func/__init__.py b/src/ai/backend/client/func/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/ai/backend/client/func/admin.py b/src/ai/backend/client/func/admin.py new file mode 100644 index 0000000000..fe186d92ec --- /dev/null +++ b/src/ai/backend/client/func/admin.py @@ -0,0 +1,75 @@ +from typing import Any, Mapping, Optional + +from .base import api_function, BaseFunction +from ..exceptions import BackendAPIError +from ..request import Request +from ..session import api_session + +__all__ = ( + 'Admin', +) + + +class Admin(BaseFunction): + """ + Provides the function interface for making admin GrapQL queries. + + .. note:: + + Depending on the privilege of your API access key, you may or may not + have access to querying/mutating server-side resources of other + users. + """ + + @api_function + @classmethod + async def query( + cls, + query: str, + variables: Optional[Mapping[str, Any]] = None, + ) -> Any: + """ + Sends the GraphQL query and returns the response. + + :param query: The GraphQL query string. + :param variables: An optional key-value dictionary + to fill the interpolated template variables + in the query. + + :returns: The object parsed from the response JSON string. + """ + return await cls._query(query, variables) + + @classmethod + async def _query( + cls, + query: str, + variables: Optional[Mapping[str, Any]] = None, + ) -> Any: + """ + Internal async implementation of the query() method, + which may be reused by other functional APIs to make GQL requests. 
+ """ + gql_query = { + 'query': query, + 'variables': variables if variables else {}, + } + if api_session.get().api_version >= (6, '20210815'): + rqst = Request('POST', '/admin/gql') + rqst.set_json(gql_query) + async with rqst.fetch() as resp: + response = await resp.json() + errors = response.get("errors", []) + if errors: + raise BackendAPIError(400, reason="Bad request", data={ + 'type': 'https://api.backend.ai/probs/graphql-error', + 'title': 'GraphQL-generated error', + 'data': errors, + }) + else: + return response["data"] + else: + rqst = Request('POST', '/admin/graphql') + rqst.set_json(gql_query) + async with rqst.fetch() as resp: + return await resp.json() diff --git a/src/ai/backend/client/func/agent.py b/src/ai/backend/client/func/agent.py new file mode 100644 index 0000000000..56b358d589 --- /dev/null +++ b/src/ai/backend/client/func/agent.py @@ -0,0 +1,174 @@ +from __future__ import annotations + +import textwrap +from typing import ( + Sequence, +) + +from ai.backend.client.output.types import ( + FieldSpec, + PaginatedResult, +) +from ai.backend.client.output.fields import agent_fields +from ai.backend.client.request import Request +from ai.backend.client.session import api_session +from ai.backend.client.pagination import generate_paginated_results +from .base import api_function, BaseFunction + +__all__ = ( + 'Agent', + 'AgentWatcher', +) + +_default_list_fields = ( + agent_fields['id'], + agent_fields['status'], + agent_fields['scaling_group'], + agent_fields['available_slots'], + agent_fields['occupied_slots'], +) + +_default_detail_fields = ( + agent_fields['id'], + agent_fields['status'], + agent_fields['scaling_group'], + agent_fields['addr'], + agent_fields['region'], + agent_fields['first_contact'], + agent_fields['cpu_cur_pct'], + agent_fields['mem_cur_bytes'], + agent_fields['available_slots'], + agent_fields['occupied_slots'], +) + + +class Agent(BaseFunction): + """ + Provides a shortcut of :func:`Admin.query() + ` that fetches various agent + information. + + .. note:: + + All methods in this function class require your API access key to + have the *admin* privilege. + """ + + @api_function + @classmethod + async def paginated_list( + cls, + status: str = 'ALIVE', + scaling_group: str = None, + *, + fields: Sequence[FieldSpec] = _default_list_fields, + page_offset: int = 0, + page_size: int = 20, + filter: str = None, + order: str = None, + ) -> PaginatedResult: + """ + Lists the keypairs. + You need an admin privilege for this operation. + """ + return await generate_paginated_results( + 'agent_list', + { + 'status': (status, 'String'), + 'scaling_group': (scaling_group, 'String'), + 'filter': (filter, 'String'), + 'order': (order, 'String'), + }, + fields, + page_size=page_size, + page_offset=page_offset, + ) + + @api_function + @classmethod + async def detail( + cls, + agent_id: str, + fields: Sequence[FieldSpec] = _default_detail_fields, + ) -> Sequence[dict]: + query = textwrap.dedent("""\ + query($agent_id: String!) { + agent(agent_id: $agent_id) {$fields} + } + """) + query = query.replace('$fields', ' '.join(f.field_ref for f in fields)) + variables = {'agent_id': agent_id} + data = await api_session.get().Admin._query(query, variables) + return data['agent'] + + +class AgentWatcher(BaseFunction): + """ + Provides a shortcut of :func:`Admin.query() + ` that manipulate agent status. + + .. note:: + + All methods in this function class require you to + have the *superadmin* privilege. 
+ """ + + @api_function + @classmethod + async def get_status(cls, agent_id: str) -> dict: + """ + Get agent and watcher status. + """ + rqst = Request('GET', '/resource/watcher') + rqst.set_json({'agent_id': agent_id}) + async with rqst.fetch() as resp: + data = await resp.json() + if 'message' in data: + return data['message'] + else: + return data + + @api_function + @classmethod + async def agent_start(cls, agent_id: str) -> dict: + """ + Start agent. + """ + rqst = Request('POST', '/resource/watcher/agent/start') + rqst.set_json({'agent_id': agent_id}) + async with rqst.fetch() as resp: + data = await resp.json() + if 'message' in data: + return data['message'] + else: + return data + + @api_function + @classmethod + async def agent_stop(cls, agent_id: str) -> dict: + """ + Stop agent. + """ + rqst = Request('POST', '/resource/watcher/agent/stop') + rqst.set_json({'agent_id': agent_id}) + async with rqst.fetch() as resp: + data = await resp.json() + if 'message' in data: + return data['message'] + else: + return data + + @api_function + @classmethod + async def agent_restart(cls, agent_id: str) -> dict: + """ + Restart agent. + """ + rqst = Request('POST', '/resource/watcher/agent/restart') + rqst.set_json({'agent_id': agent_id}) + async with rqst.fetch() as resp: + data = await resp.json() + if 'message' in data: + return data['message'] + else: + return data diff --git a/src/ai/backend/client/func/auth.py b/src/ai/backend/client/func/auth.py new file mode 100644 index 0000000000..30a90e58e0 --- /dev/null +++ b/src/ai/backend/client/func/auth.py @@ -0,0 +1,60 @@ +from .base import api_function, BaseFunction +from ..request import Request + +__all__ = ( + 'Auth', +) + + +class Auth(BaseFunction): + """ + Provides the function interface for login session management and authorization. + """ + + @api_function + @classmethod + async def login(cls, user_id: str, password: str) -> dict: + """ + Log-in into the endpoint with the given user ID and password. + It creates a server-side web session and return + a dictionary with ``"authenticated"`` boolean field and + JSON-encoded raw cookie data. + """ + rqst = Request('POST', '/server/login') + rqst.set_json({ + 'username': user_id, + 'password': password, + }) + async with rqst.fetch(anonymous=True) as resp: + data = await resp.json() + data['cookies'] = resp.raw_response.cookies + data['config'] = { + 'username': user_id, + } + return data + + @api_function + @classmethod + async def logout(cls) -> None: + """ + Log-out from the endpoint. + It clears the server-side web session. + """ + rqst = Request('POST', '/server/logout') + async with rqst.fetch() as resp: + resp.raw_response.raise_for_status() + + @api_function + @classmethod + async def update_password(cls, old_password: str, new_password: str, new_password2: str) -> dict: + """ + Update user's password. This API works only for account owner. 
+ """ + rqst = Request('POST', '/auth/update-password') + rqst.set_json({ + 'old_password': old_password, + 'new_password': new_password, + 'new_password2': new_password2, + }) + async with rqst.fetch() as resp: + return await resp.json() diff --git a/src/ai/backend/client/func/base.py b/src/ai/backend/client/func/base.py new file mode 100644 index 0000000000..78b08e3f16 --- /dev/null +++ b/src/ai/backend/client/func/base.py @@ -0,0 +1,65 @@ +import functools +import inspect + +from ..session import api_session, AsyncSession + +__all__ = ( + 'APIFunctionMeta', + 'BaseFunction', + 'api_function', +) + + +def _wrap_method(cls, orig_name, meth): + + @functools.wraps(meth) + def _method(*args, **kwargs): + # We need to keep the original attributes so that they could be correctly + # bound to the class/instance at runtime. + func = getattr(cls, orig_name) + coro = func(*args, **kwargs) + _api_session = api_session.get() + if _api_session is None: + raise RuntimeError( + "API functions must be called " + "inside the context of a valid API session", + ) + if isinstance(_api_session, AsyncSession): + return coro + else: + if inspect.isasyncgen(coro): + return _api_session.worker_thread.execute_generator(coro) + else: + return _api_session.worker_thread.execute(coro) + + return _method + + +def api_function(meth): + """ + Mark the wrapped method as the API function method. + """ + setattr(meth, '_backend_api', True) + return meth + + +class APIFunctionMeta(type): + """ + Converts all methods marked with :func:`api_function` into + session-aware methods that are either plain Python functions + or coroutines. + """ + _async = True + + def __init__(cls, name, bases, attrs, **kwargs): + super().__init__(name, bases, attrs) + for attr_name, attr_value in attrs.items(): + if hasattr(attr_value, '_backend_api'): + orig_name = '_orig_' + attr_name + setattr(cls, orig_name, attr_value) + wrapped = _wrap_method(cls, orig_name, attr_value) + setattr(cls, attr_name, wrapped) + + +class BaseFunction(metaclass=APIFunctionMeta): + pass diff --git a/src/ai/backend/client/func/bgtask.py b/src/ai/backend/client/func/bgtask.py new file mode 100644 index 0000000000..5c32e30aeb --- /dev/null +++ b/src/ai/backend/client/func/bgtask.py @@ -0,0 +1,35 @@ +from typing import Union +from uuid import UUID + +from .base import BaseFunction +from ..request import ( + Request, + SSEContextManager, +) + + +class BackgroundTask(BaseFunction): + """ + Provides server-sent events streaming functions. + """ + + task_id: UUID + + def __init__(self, task_id: Union[UUID, str]) -> None: + self.task_id = task_id if isinstance(task_id, UUID) else UUID(task_id) + + # only supported in AsyncAPISession + def listen_events(self) -> SSEContextManager: + """ + Opens an event stream of the background task updates. + + :returns: a context manager that returns an :class:`SSEResponse` object. 
+ """ + params = { + 'task_id': str(self.task_id), + } + request = Request( + 'GET', '/events/background-task', + params=params, + ) + return request.connect_events() diff --git a/src/ai/backend/client/func/domain.py b/src/ai/backend/client/func/domain.py new file mode 100644 index 0000000000..7aed998c7a --- /dev/null +++ b/src/ai/backend/client/func/domain.py @@ -0,0 +1,193 @@ +import textwrap +from typing import Iterable, Sequence + +from ai.backend.client.output.fields import domain_fields +from ai.backend.client.output.types import FieldSpec +from .base import api_function, BaseFunction +from ..session import api_session + +__all__ = ( + 'Domain', +) + +_default_list_fields = ( + domain_fields['name'], + domain_fields['description'], + domain_fields['is_active'], + domain_fields['created_at'], + domain_fields['total_resource_slots'], + domain_fields['allowed_vfolder_hosts'], + domain_fields['allowed_docker_registries'], + domain_fields['integration_id'], +) + +_default_detail_fields = ( + domain_fields['name'], + domain_fields['description'], + domain_fields['is_active'], + domain_fields['created_at'], + domain_fields['total_resource_slots'], + domain_fields['allowed_vfolder_hosts'], + domain_fields['allowed_docker_registries'], + domain_fields['integration_id'], +) + + +class Domain(BaseFunction): + """ + Provides a shortcut of :func:`Admin.query() + ` that fetches various domain + information. + + .. note:: + + All methods in this function class require your API access key to + have the *admin* privilege. + """ + + @api_function + @classmethod + async def list( + cls, + fields: Sequence[FieldSpec] = _default_list_fields, + ) -> Sequence[dict]: + """ + Fetches the list of domains. + + :param fields: Additional per-domain query fields to fetch. + """ + query = textwrap.dedent("""\ + query { + domains {$fields} + } + """) + query = query.replace('$fields', ' '.join(f.field_ref for f in fields)) + data = await api_session.get().Admin._query(query) + return data['domains'] + + @api_function + @classmethod + async def detail( + cls, + name: str, + fields: Sequence[FieldSpec] = _default_detail_fields, + ) -> dict: + """ + Fetch information of a domain with name. + + :param name: Name of the domain to fetch. + :param fields: Additional per-domain query fields to fetch. + """ + query = textwrap.dedent("""\ + query($name: String) { + domain(name: $name) {$fields} + } + """) + query = query.replace('$fields', ' '.join(f.field_ref for f in fields)) + variables = {'name': name} + data = await api_session.get().Admin._query(query, variables) + return data['domain'] + + @api_function + @classmethod + async def create(cls, name: str, description: str = '', is_active: bool = True, + total_resource_slots: str = None, + allowed_vfolder_hosts: Iterable[str] = None, + allowed_docker_registries: Iterable[str] = None, + integration_id: str = None, + fields: Iterable[str] = None) -> dict: + """ + Creates a new domain with the given options. + You need an admin privilege for this operation. + """ + if fields is None: + fields = ('name',) + query = textwrap.dedent("""\ + mutation($name: String!, $input: DomainInput!) 
{ + create_domain(name: $name, props: $input) { + ok msg domain {$fields} + } + } + """) + query = query.replace('$fields', ' '.join(fields)) + variables = { + 'name': name, + 'input': { + 'description': description, + 'is_active': is_active, + 'total_resource_slots': total_resource_slots, + 'allowed_vfolder_hosts': allowed_vfolder_hosts, + 'allowed_docker_registries': allowed_docker_registries, + 'integration_id': integration_id, + }, + } + data = await api_session.get().Admin._query(query, variables) + return data['create_domain'] + + @api_function + @classmethod + async def update(cls, name: str, new_name: str = None, description: str = None, + is_active: bool = None, total_resource_slots: str = None, + allowed_vfolder_hosts: Iterable[str] = None, + allowed_docker_registries: Iterable[str] = None, + integration_id: str = None, + fields: Iterable[str] = None) -> dict: + """ + Update existing domain. + You need an admin privilege for this operation. + """ + query = textwrap.dedent("""\ + mutation($name: String!, $input: ModifyDomainInput!) { + modify_domain(name: $name, props: $input) { + ok msg + } + } + """) + variables = { + 'name': name, + 'input': { + 'name': new_name, + 'description': description, + 'is_active': is_active, + 'total_resource_slots': total_resource_slots, + 'allowed_vfolder_hosts': allowed_vfolder_hosts, + 'allowed_docker_registries': allowed_docker_registries, + 'integration_id': integration_id, + }, + } + data = await api_session.get().Admin._query(query, variables) + return data['modify_domain'] + + @api_function + @classmethod + async def delete(cls, name: str): + """ + Inactivates an existing domain. + """ + query = textwrap.dedent("""\ + mutation($name: String!) { + delete_domain(name: $name) { + ok msg + } + } + """) + variables = {'name': name} + data = await api_session.get().Admin._query(query, variables) + return data['delete_domain'] + + @api_function + @classmethod + async def purge(cls, name: str): + """ + Deletes an existing domain. + """ + query = textwrap.dedent("""\ + mutation($name: String!) 
{ + purge_domain(name: $name) { + ok msg + } + } + """) + variables = {'name': name} + data = await api_session.get().Admin._query(query, variables) + return data['purge_domain'] diff --git a/src/ai/backend/client/func/dotfile.py b/src/ai/backend/client/func/dotfile.py new file mode 100644 index 0000000000..c0abdeb95b --- /dev/null +++ b/src/ai/backend/client/func/dotfile.py @@ -0,0 +1,143 @@ +from typing import List, Mapping, Optional + +from .base import api_function, BaseFunction +from ..request import Request + + +__all__ = ( + 'Dotfile', +) + + +class Dotfile(BaseFunction): + + @api_function + @classmethod + async def create(cls, + data: str, + path: str, + permission: str, + owner_access_key: str = None, + domain: str = None, + group: str = None, + ) -> 'Dotfile': + body = { + 'data': data, + 'path': path, + 'permission': permission, + } + if group: + body['group'] = group + if domain: + body['domain'] = domain + rqst_endpoint = '/group-config/dotfiles' + elif domain: + body['domain'] = domain + rqst_endpoint = '/domain-config/dotfiles' + else: + if owner_access_key: + body['owner_access_key'] = owner_access_key + rqst_endpoint = '/user-config/dotfiles' + + rqst = Request('POST', rqst_endpoint) + rqst.set_json(body) + async with rqst.fetch() as resp: + await resp.json() + return cls(path, owner_access_key=owner_access_key, group=group, domain=domain) + + @api_function + @classmethod + async def list_dotfiles(cls, + owner_access_key: str = None, + domain: str = None, + group: str = None, + ) -> 'List[Mapping[str, str]]': + params = {} + if group: + params['group'] = group + if domain: + params['domain'] = domain + rqst_endpoint = '/group-config/dotfiles' + elif domain: + params['domain'] = domain + rqst_endpoint = '/domain-config/dotfiles' + else: + if owner_access_key: + params['onwer_access_key'] = owner_access_key + rqst_endpoint = '/user-config/dotfiles' + + rqst = Request('GET', rqst_endpoint, params=params) + async with rqst.fetch() as resp: + return await resp.json() + + def __init__(self, path: str, owner_access_key: Optional[str] = None, + group: str = None, domain: str = None): + self.path = path + self.owner_access_key = owner_access_key + self.group = group + self.domain = domain + + @api_function + async def get(self) -> str: + params = {'path': self.path} + if self.group: + params['group'] = self.group + if self.domain: + params['domain'] = self.domain + rqst_endpoint = '/group-config/dotfiles' + elif self.domain: + params['domain'] = self.domain + rqst_endpoint = '/domain-config/dotfiles' + else: + if self.owner_access_key: + params['owner_access_key'] = self.owner_access_key + rqst_endpoint = '/user-config/dotfiles' + + rqst = Request('GET', rqst_endpoint, params=params) + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + async def update(self, data: str, permission: str): + body = { + 'data': data, + 'path': self.path, + 'permission': permission, + } + if self.group: + body['group'] = self.group + if self.domain: + body['domain'] = self.domain + rqst_endpoint = '/group-config/dotfiles' + elif self.domain: + body['domain'] = self.domain + rqst_endpoint = '/domain-config/dotfiles' + else: + if self.owner_access_key: + body['owner_access_key'] = self.owner_access_key + rqst_endpoint = '/user-config/dotfiles' + + rqst = Request('PATCH', rqst_endpoint) + rqst.set_json(body) + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + async def delete(self): + params = {'path': self.path} + if self.group: + params['group'] 
= self.group + if self.domain: + params['domain'] = self.domain + rqst_endpoint = '/group-config/dotfiles' + elif self.domain: + params['domain'] = self.domain + rqst_endpoint = '/domain-config/dotfiles' + else: + if self.owner_access_key: + params['owner_access_key'] = self.owner_access_key + rqst_endpoint = '/user-config/dotfiles' + + rqst = Request('DELETE', rqst_endpoint, params=params) + async with rqst.fetch() as resp: + return await resp.json() diff --git a/src/ai/backend/client/func/etcd.py b/src/ai/backend/client/func/etcd.py new file mode 100644 index 0000000000..3059148fce --- /dev/null +++ b/src/ai/backend/client/func/etcd.py @@ -0,0 +1,71 @@ +from .base import api_function, BaseFunction +from ..request import Request + +__all__ = ( + 'EtcdConfig', +) + + +class EtcdConfig(BaseFunction): + """ + Provides a way to get or set ETCD configurations. + + .. note:: + + All methods in this function class require your API access key to + have the *superadmin* privilege. + """ + + @api_function + @classmethod + async def get(cls, key: str, prefix: bool = False) -> dict: + """ + Get configuration from ETCD with given key. + + :param key: Name of the key to fetch. + :param prefix: get all keys prefixed with the give key. + """ + rqst = Request('POST', '/config/get') + rqst.set_json({ + 'key': key, + 'prefix': prefix, + }) + async with rqst.fetch() as resp: + data = await resp.json() + return data.get('result', None) + + @api_function + @classmethod + async def set(cls, key: str, value: str) -> dict: + """ + Set configuration into ETCD with given key and value. + + :param key: Name of the key to set. + :param value: Value to set. + """ + rqst = Request('POST', '/config/set') + rqst.set_json({ + 'key': key, + 'value': value, + }) + async with rqst.fetch() as resp: + data = await resp.json() + return data + + @api_function + @classmethod + async def delete(cls, key: str, prefix: bool = False) -> dict: + """ + Delete configuration from ETCD with given key. + + :param key: Name of the key to delete. + :param prefix: delete all keys prefixed with the give key. + """ + rqst = Request('POST', '/config/delete') + rqst.set_json({ + 'key': key, + 'prefix': prefix, + }) + async with rqst.fetch() as resp: + data = await resp.json() + return data diff --git a/src/ai/backend/client/func/group.py b/src/ai/backend/client/func/group.py new file mode 100644 index 0000000000..5cf712afeb --- /dev/null +++ b/src/ai/backend/client/func/group.py @@ -0,0 +1,276 @@ +import textwrap +from typing import Iterable, Sequence + +from ai.backend.client.output.fields import group_fields +from ai.backend.client.output.types import FieldSpec +from .base import api_function, BaseFunction +from ..session import api_session + +__all__ = ( + 'Group', +) + +_default_list_fields = ( + group_fields['id'], + group_fields['name'], + group_fields['is_active'], + group_fields['created_at'], + group_fields['integration_id'], +) +_default_detail_fields = ( + group_fields['id'], + group_fields['name'], + group_fields['description'], + group_fields['is_active'], + group_fields['created_at'], + group_fields['domain_name'], + group_fields['total_resource_slots'], + group_fields['allowed_vfolder_hosts'], + group_fields['integration_id'], +) + + +class Group(BaseFunction): + """ + Provides a shortcut of :func:`Group.query() + ` that fetches various group information. + + .. note:: + + All methods in this function class require your API access key to + have the *admin* privilege. 
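For illustration (not part of this patch), resolving a group by name and listing the groups of a domain with an admin access key; explicit string field names are passed to `from_name()` because its `fields` parameter is typed as an iterable of strings:

    import asyncio
    from ai.backend.client.session import AsyncSession

    async def main():
        async with AsyncSession() as session:
            matches = await session.Group.from_name(
                'default', domain_name='default', fields=['id', 'name'],
            )
            for group in matches:
                print(group['id'], group['name'])
            groups = await session.Group.list(domain_name='default')
            print(len(groups), 'groups in domain "default"')

    asyncio.run(main())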
+ """ + + @api_function + @classmethod + async def from_name( + cls, + name: str, + *, + fields: Iterable[str] = None, + domain_name: str = None, + ) -> Sequence[dict]: + """ + Find the group(s) by its name. + It may return multiple groups when there are groups with the same name + in different domains and it is invoked with a super-admin account + without setting the domain name. + + :param domain_name: Name of domain to get groups from. + :param fields: Per-group query fields to fetch. + """ + if fields is None: + fields = _default_detail_fields + query = textwrap.dedent("""\ + query($name: String!, $domain_name: String) { + groups_by_name(name: $name, domain_name: $domain_name) {$fields} + } + """) + query = query.replace('$fields', ' '.join(fields)) + variables = { + 'name': name, + 'domain_name': domain_name, + } + data = await api_session.get().Admin._query(query, variables) + return data['groups_by_name'] + + @api_function + @classmethod + async def list( + cls, + domain_name: str, + fields: Sequence[FieldSpec] = _default_list_fields, + ) -> Sequence[dict]: + """ + Fetches the list of groups. + + :param domain_name: Name of domain to list groups. + :param fields: Per-group query fields to fetch. + """ + if fields is None: + fields = _default_list_fields + query = textwrap.dedent("""\ + query($domain_name: String) { + groups(domain_name: $domain_name) {$fields} + } + """) + query = query.replace('$fields', ' '.join(f.field_ref for f in fields)) + variables = {'domain_name': domain_name} + data = await api_session.get().Admin._query(query, variables) + return data['groups'] + + @api_function + @classmethod + async def detail( + cls, + gid: str, + fields: Sequence[FieldSpec] = _default_detail_fields, + ) -> dict: + """ + Fetch information of a group with group ID. + + :param gid: ID of the group to fetch. + :param fields: Additional per-group query fields to fetch. + """ + if fields is None: + fields = _default_detail_fields + query = textwrap.dedent("""\ + query($gid: UUID!) { + group(id: $gid) {$fields} + } + """) + query = query.replace('$fields', ' '.join(f.field_ref for f in fields)) + variables = {'gid': gid} + data = await api_session.get().Admin._query(query, variables) + return data['group'] + + @api_function + @classmethod + async def create(cls, domain_name: str, name: str, description: str = '', + is_active: bool = True, total_resource_slots: str = None, + allowed_vfolder_hosts: Iterable[str] = None, + integration_id: str = None, + fields: Iterable[str] = None) -> dict: + """ + Creates a new group with the given options. + You need an admin privilege for this operation. + """ + if fields is None: + fields = ('id', 'domain_name', 'name') + query = textwrap.dedent("""\ + mutation($name: String!, $input: GroupInput!) 
{ + create_group(name: $name, props: $input) { + ok msg group {$fields} + } + } + """) + query = query.replace('$fields', ' '.join(fields)) + variables = { + 'name': name, + 'input': { + 'description': description, + 'is_active': is_active, + 'domain_name': domain_name, + 'total_resource_slots': total_resource_slots, + 'allowed_vfolder_hosts': allowed_vfolder_hosts, + 'integration_id': integration_id, + }, + } + data = await api_session.get().Admin._query(query, variables) + return data['create_group'] + + @api_function + @classmethod + async def update(cls, gid: str, name: str = None, description: str = None, + is_active: bool = None, total_resource_slots: str = None, + allowed_vfolder_hosts: Iterable[str] = None, + integration_id: str = None, + fields: Iterable[str] = None) -> dict: + """ + Update existing group. + You need an admin privilege for this operation. + """ + query = textwrap.dedent("""\ + mutation($gid: UUID!, $input: ModifyGroupInput!) { + modify_group(gid: $gid, props: $input) { + ok msg + } + } + """) + variables = { + 'gid': gid, + 'input': { + 'name': name, + 'description': description, + 'is_active': is_active, + 'total_resource_slots': total_resource_slots, + 'allowed_vfolder_hosts': allowed_vfolder_hosts, + 'integration_id': integration_id, + }, + } + data = await api_session.get().Admin._query(query, variables) + return data['modify_group'] + + @api_function + @classmethod + async def delete(cls, gid: str): + """ + Inactivates the existing group. Does not actually delete it for safety. + """ + query = textwrap.dedent("""\ + mutation($gid: UUID!) { + delete_group(gid: $gid) { + ok msg + } + } + """) + variables = {'gid': gid} + data = await api_session.get().Admin._query(query, variables) + return data['delete_group'] + + @api_function + @classmethod + async def purge(cls, gid: str): + """ + Delete the existing group. This action cannot be undone. + """ + query = textwrap.dedent("""\ + mutation($gid: UUID!) { + purge_group(gid: $gid) { + ok msg + } + } + """) + variables = {'gid': gid} + data = await api_session.get().Admin._query(query, variables) + return data['purge_group'] + + @api_function + @classmethod + async def add_users(cls, gid: str, user_uuids: Iterable[str], + fields: Iterable[str] = None) -> dict: + """ + Add users to a group. + You need an admin privilege for this operation. + """ + query = textwrap.dedent("""\ + mutation($gid: UUID!, $input: ModifyGroupInput!) { + modify_group(gid: $gid, props: $input) { + ok msg + } + } + """) + variables = { + 'gid': gid, + 'input': { + 'user_update_mode': 'add', + 'user_uuids': user_uuids, + }, + } + data = await api_session.get().Admin._query(query, variables) + return data['modify_group'] + + @api_function + @classmethod + async def remove_users(cls, gid: str, user_uuids: Iterable[str], + fields: Iterable[str] = None) -> dict: + """ + Remove users from a group. + You need an admin privilege for this operation. + """ + query = textwrap.dedent("""\ + mutation($gid: UUID!, $input: ModifyGroupInput!) 
{ + modify_group(gid: $gid, props: $input) { + ok msg + } + } + """) + variables = { + 'gid': gid, + 'input': { + 'user_update_mode': 'remove', + 'user_uuids': user_uuids, + }, + } + data = await api_session.get().Admin._query(query, variables) + return data['modify_group'] diff --git a/src/ai/backend/client/func/image.py b/src/ai/backend/client/func/image.py new file mode 100644 index 0000000000..75a8f2e7b0 --- /dev/null +++ b/src/ai/backend/client/func/image.py @@ -0,0 +1,118 @@ +from typing import Optional, Sequence + +from ai.backend.client.output.fields import image_fields +from ai.backend.client.output.types import FieldSpec +from .base import api_function, BaseFunction +from ..request import Request +from ..session import api_session + +__all__ = ( + 'Image', +) + +_default_list_fields_admin = ( + image_fields['name'], + image_fields['registry'], + image_fields['architecture'], + image_fields['tag'], + image_fields['digest'], + image_fields['size_bytes'], + image_fields['aliases'], +) + + +class Image(BaseFunction): + """ + Provides a shortcut of :func:`Admin.query() + ` that fetches the information about + available images. + """ + + @api_function + @classmethod + async def list( + cls, + operation: bool = False, + fields: Sequence[FieldSpec] = _default_list_fields_admin, + ) -> Sequence[dict]: + """ + Fetches the list of registered images in this cluster. + """ + q = 'query($is_operation: Boolean) {' \ + ' images(is_operation: $is_operation) {' \ + ' $fields' \ + ' }' \ + '}' + q = q.replace('$fields', ' '.join(f.field_ref for f in fields)) + variables = { + 'is_operation': operation, + } + data = await api_session.get().Admin._query(q, variables) + return data['images'] + + @api_function + @classmethod + async def rescan_images(cls, registry: str): + q = 'mutation($registry: String) {' \ + ' rescan_images(registry:$registry) {' \ + ' ok msg task_id' \ + ' }' \ + '}' + variables = { + 'registry': registry, + } + data = await api_session.get().Admin._query(q, variables) + return data['rescan_images'] + + @api_function + @classmethod + async def alias_image( + cls, + alias: str, + target: str, + arch: Optional[str] = None, + ) -> dict: + q = 'mutation($alias: String!, $target: String!) {' \ + ' alias_image(alias: $alias, target: $target) {' \ + ' ok msg' \ + ' }' \ + '}' + variables = { + 'alias': alias, + 'target': target, + } + if arch: + variables = {'architecture': arch, **variables} + data = await api_session.get().Admin._query(q, variables) + return data['alias_image'] + + @api_function + @classmethod + async def dealias_image(cls, alias: str) -> dict: + q = 'mutation($alias: String!) 
{' \ + ' dealias_image(alias: $alias) {' \ + ' ok msg' \ + ' }' \ + '}' + variables = { + 'alias': alias, + } + data = await api_session.get().Admin._query(q, variables) + return data['dealias_image'] + + @api_function + @classmethod + async def get_image_import_form(cls) -> dict: + rqst = Request('GET', '/image/import') + async with rqst.fetch() as resp: + data = await resp.json() + return data + + @api_function + @classmethod + async def build(cls, **kwargs) -> dict: + rqst = Request('POST', '/image/import') + rqst.set_json(kwargs) + async with rqst.fetch() as resp: + data = await resp.json() + return data diff --git a/src/ai/backend/client/func/keypair.py b/src/ai/backend/client/func/keypair.py new file mode 100644 index 0000000000..2689a73737 --- /dev/null +++ b/src/ai/backend/client/func/keypair.py @@ -0,0 +1,265 @@ +from typing import ( + Any, + Dict, + Sequence, + Union, +) + +from ai.backend.client.pagination import generate_paginated_results +from ai.backend.client.session import api_session +from ai.backend.client.output.fields import keypair_fields +from ai.backend.client.output.types import FieldSpec, PaginatedResult +from .base import api_function, BaseFunction + +__all__ = ( + 'KeyPair', +) + +_default_list_fields = ( + keypair_fields['user_id'], + keypair_fields['access_key'], + keypair_fields['secret_key'], + keypair_fields['is_active'], + keypair_fields['is_admin'], + keypair_fields['created_at'], +) + +_default_detail_fields = ( + keypair_fields['user_id'], + keypair_fields['access_key'], + keypair_fields['secret_key'], + keypair_fields['is_active'], + keypair_fields['is_admin'], +) + +_default_result_fields = ( + keypair_fields['access_key'], + keypair_fields['secret_key'], +) + + +class KeyPair(BaseFunction): + """ + Provides interactions with keypairs. + """ + + def __init__(self, access_key: str): + self.access_key = access_key + + @api_function + @classmethod + async def create( + cls, + user_id: Union[int, str], + is_active: bool = True, + is_admin: bool = False, + resource_policy: str = None, + rate_limit: int = None, + fields: Sequence[FieldSpec] = _default_result_fields, + ) -> dict: + """ + Creates a new keypair with the given options. + You need an admin privilege for this operation. + """ + uid_type = 'Int!' if isinstance(user_id, int) else 'String!' + q = 'mutation($user_id: {0}, $input: KeyPairInput!) {{'.format(uid_type) + \ + ' create_keypair(user_id: $user_id, props: $input) {' \ + ' ok msg keypair { $fields }' \ + ' }' \ + '}' + q = q.replace('$fields', ' '.join(f.field_ref for f in fields)) + variables = { + 'user_id': user_id, + 'input': { + 'is_active': is_active, + 'is_admin': is_admin, + 'resource_policy': resource_policy, + 'rate_limit': rate_limit, + }, + } + data = await api_session.get().Admin._query(q, variables) + return data['create_keypair'] + + @api_function + @classmethod + async def update( + cls, + access_key: str, + is_active: bool = None, + is_admin: bool = None, + resource_policy: str = None, + rate_limit: int = None, + ) -> dict: + """ + Creates a new keypair with the given options. + You need an admin privilege for this operation. + """ + q = 'mutation($access_key: String!, $input: ModifyKeyPairInput!) 
{' + \ + ' modify_keypair(access_key: $access_key, props: $input) {' \ + ' ok msg' \ + ' }' \ + '}' + variables = { + 'access_key': access_key, + 'input': { + 'is_active': is_active, + 'is_admin': is_admin, + 'resource_policy': resource_policy, + 'rate_limit': rate_limit, + }, + } + data = await api_session.get().Admin._query(q, variables) + return data['modify_keypair'] + + @api_function + @classmethod + async def delete(cls, access_key: str): + """ + Deletes an existing keypair with given ACCESSKEY. + """ + q = 'mutation($access_key: String!) {' \ + ' delete_keypair(access_key: $access_key) {' \ + ' ok msg' \ + ' }' \ + '}' + variables = { + 'access_key': access_key, + } + data = await api_session.get().Admin._query(q, variables) + return data['delete_keypair'] + + @api_function + @classmethod + async def list( + cls, + user_id: Union[int, str] = None, + is_active: bool = None, + fields: Sequence[FieldSpec] = _default_list_fields, + ) -> Sequence[dict]: + """ + Lists the keypairs. + You need an admin privilege for this operation. + """ + if user_id is None: + q = 'query($is_active: Boolean) {' \ + ' keypairs(is_active: $is_active) {' \ + ' $fields' \ + ' }' \ + '}' + else: + uid_type = 'Int!' if isinstance(user_id, int) else 'String!' + q = 'query($email: {0}, $is_active: Boolean) {{'.format(uid_type) + \ + ' keypairs(email: $email, is_active: $is_active) {' \ + ' $fields' \ + ' }' \ + '}' + q = q.replace('$fields', ' '.join(f.field_ref for f in fields)) + variables: Dict[str, Any] = { + 'is_active': is_active, + } + if user_id is not None: + variables['email'] = user_id + data = await api_session.get().Admin._query(q, variables) + return data['keypairs'] + + @api_function + @classmethod + async def paginated_list( + cls, + is_active: bool = None, + domain_name: str = None, + *, + user_id: str = None, + fields: Sequence[FieldSpec] = _default_list_fields, + page_offset: int = 0, + page_size: int = 20, + filter: str = None, + order: str = None, + ) -> PaginatedResult[dict]: + """ + Lists the keypairs. + You need an admin privilege for this operation. + """ + variables = { + 'is_active': (is_active, 'Boolean'), + 'domain_name': (domain_name, 'String'), + 'filter': (filter, 'String'), + 'order': (order, 'String'), + } + if user_id is not None: + variables['email'] = (user_id, 'String') + return await generate_paginated_results( + 'keypair_list', + variables, + fields, + page_offset=page_offset, + page_size=page_size, + ) + + @api_function + async def info(self, fields: Sequence[FieldSpec] = _default_detail_fields) -> dict: + """ + Returns the keypair's information such as resource limits. + + :param fields: Additional per-agent query fields to fetch. + + .. versionadded:: 18.12 + """ + q = 'query {' \ + ' keypair {' \ + ' $fields' \ + ' }' \ + '}' + q = q.replace('$fields', ' '.join(f.field_ref for f in fields)) + data = await api_session.get().Admin._query(q) + return data['keypair'] + + @api_function + @classmethod + async def activate(cls, access_key: str) -> dict: + """ + Activates this keypair. + You need an admin privilege for this operation. + """ + q = 'mutation($access_key: String!, $input: ModifyKeyPairInput!) 
{' + \ + ' modify_keypair(access_key: $access_key, props: $input) {' \ + ' ok msg' \ + ' }' \ + '}' + variables = { + 'access_key': access_key, + 'input': { + 'is_active': True, + 'is_admin': None, + 'resource_policy': None, + 'rate_limit': None, + }, + } + data = await api_session.get().Admin._query(q, variables) + return data['modify_keypair'] + + @api_function + @classmethod + async def deactivate(cls, access_key: str) -> dict: + """ + Deactivates this keypair. + Deactivated keypairs cannot make any API requests + unless activated again by an administrator. + You need an admin privilege for this operation. + """ + q = 'mutation($access_key: String!, $input: ModifyKeyPairInput!) {' + \ + ' modify_keypair(access_key: $access_key, props: $input) {' \ + ' ok msg' \ + ' }' \ + '}' + variables = { + 'access_key': access_key, + 'input': { + 'is_active': False, + 'is_admin': None, + 'resource_policy': None, + 'rate_limit': None, + }, + } + data = await api_session.get().Admin._query(q, variables) + return data['modify_keypair'] diff --git a/src/ai/backend/client/func/keypair_resource_policy.py b/src/ai/backend/client/func/keypair_resource_policy.py new file mode 100644 index 0000000000..214a5c0f0c --- /dev/null +++ b/src/ai/backend/client/func/keypair_resource_policy.py @@ -0,0 +1,184 @@ +from typing import Iterable, Sequence + +from ai.backend.client.output.fields import keypair_resource_policy_fields +from ai.backend.client.output.types import FieldSpec +from .base import api_function, BaseFunction +from ..session import api_session + +__all__ = ( + 'KeypairResourcePolicy' +) + +_default_list_fields = ( + keypair_resource_policy_fields['name'], + keypair_resource_policy_fields['created_at'], + keypair_resource_policy_fields['total_resource_slots'], + keypair_resource_policy_fields['max_concurrent_sessions'], + keypair_resource_policy_fields['max_vfolder_count'], + keypair_resource_policy_fields['max_vfolder_size'], + keypair_resource_policy_fields['idle_timeout'], + keypair_resource_policy_fields['max_containers_per_session'], + keypair_resource_policy_fields['allowed_vfolder_hosts'], +) + +_default_detail_fields = ( + keypair_resource_policy_fields['name'], + keypair_resource_policy_fields['created_at'], + keypair_resource_policy_fields['total_resource_slots'], + keypair_resource_policy_fields['max_concurrent_sessions'], + keypair_resource_policy_fields['max_vfolder_count'], + keypair_resource_policy_fields['max_vfolder_size'], + keypair_resource_policy_fields['idle_timeout'], + keypair_resource_policy_fields['max_containers_per_session'], + keypair_resource_policy_fields['allowed_vfolder_hosts'], +) + + +class KeypairResourcePolicy(BaseFunction): + """ + Provides interactions with keypair resource policy. + """ + + def __init__(self, access_key: str): + self.access_key = access_key + + @api_function + @classmethod + async def create(cls, name: str, + default_for_unspecified: int, + total_resource_slots: int, + max_concurrent_sessions: int, + max_containers_per_session: int, + max_vfolder_count: int, + max_vfolder_size: int, + idle_timeout: int, + allowed_vfolder_hosts: Sequence[str], + fields: Iterable[str] = None) -> dict: + """ + Creates a new keypair resource policy with the given options. + You need an admin privilege for this operation. + """ + if fields is None: + fields = ('name',) + q = 'mutation($name: String!, $input: CreateKeyPairResourcePolicyInput!) 
{' \ + + \ + ' create_keypair_resource_policy(name: $name, props: $input) {' \ + ' ok msg resource_policy { $fields }' \ + ' }' \ + '}' + q = q.replace('$fields', ' '.join(fields)) + variables = { + 'name': name, + 'input': { + 'default_for_unspecified': default_for_unspecified, + 'total_resource_slots': total_resource_slots, + 'max_concurrent_sessions': max_concurrent_sessions, + 'max_containers_per_session': max_containers_per_session, + 'max_vfolder_count': max_vfolder_count, + 'max_vfolder_size': max_vfolder_size, + 'idle_timeout': idle_timeout, + 'allowed_vfolder_hosts': allowed_vfolder_hosts, + }, + } + data = await api_session.get().Admin._query(q, variables) + return data['create_keypair_resource_policy'] + + @api_function + @classmethod + async def update(cls, name: str, + default_for_unspecified: int, + total_resource_slots: int, + max_concurrent_sessions: int, + max_containers_per_session: int, + max_vfolder_count: int, + max_vfolder_size: int, + idle_timeout: int, + allowed_vfolder_hosts: Sequence[str]) -> dict: + """ + Updates an existing keypair resource policy with the given options. + You need an admin privilege for this operation. + """ + q = 'mutation($name: String!, $input: ModifyKeyPairResourcePolicyInput!) {' \ + + \ + ' modify_keypair_resource_policy(name: $name, props: $input) {' \ + ' ok msg' \ + ' }' \ + '}' + variables = { + 'name': name, + 'input': { + 'default_for_unspecified': default_for_unspecified, + 'total_resource_slots': total_resource_slots, + 'max_concurrent_sessions': max_concurrent_sessions, + 'max_containers_per_session': max_containers_per_session, + 'max_vfolder_count': max_vfolder_count, + 'max_vfolder_size': max_vfolder_size, + 'idle_timeout': idle_timeout, + 'allowed_vfolder_hosts': allowed_vfolder_hosts, + }, + } + data = await api_session.get().Admin._query(q, variables) + return data['modify_keypair_resource_policy'] + + @api_function + @classmethod + async def delete(cls, name: str) -> dict: + """ + Deletes an existing keypair resource policy with given name. + You need an admin privilege for this operation. + """ + q = 'mutation($name: String!) {' \ + + \ + ' delete_keypair_resource_policy(name: $name) {' \ + ' ok msg' \ + ' }' \ + '}' + variables = { + 'name': name, + } + data = await api_session.get().Admin._query(q, variables) + return data['delete_keypair_resource_policy'] + + @api_function + @classmethod + async def list( + cls, + fields: Sequence[FieldSpec] = _default_list_fields, + ) -> Sequence[dict]: + ''' + Lists the keypair resource policies. + You need an admin privilege for this operation. + ''' + q = 'query {' \ + ' keypair_resource_policies {' \ + ' $fields' \ + ' }' \ + '}' + q = q.replace('$fields', ' '.join(f.field_ref for f in fields)) + data = await api_session.get().Admin._query(q) + return data['keypair_resource_policies'] + + @api_function + async def info( + self, + name: str, + fields: Sequence[FieldSpec] = _default_detail_fields, + ) -> dict: + """ + Returns the resource policy's information. + + :param fields: Additional per-agent query fields to fetch. + + .. 
versionadded:: 19.03 + """ + q = 'query($name: String) {' \ + ' keypair_resource_policy(name: $name) {' \ + ' $fields' \ + ' }' \ + '}' + q = q.replace('$fields', ' '.join(f.field_ref for f in fields)) + variables = { + 'name': name, + } + data = await api_session.get().Admin._query(q, variables) + return data['keypair_resource_policy'] diff --git a/src/ai/backend/client/func/manager.py b/src/ai/backend/client/func/manager.py new file mode 100644 index 0000000000..792201cd7f --- /dev/null +++ b/src/ai/backend/client/func/manager.py @@ -0,0 +1,105 @@ +from typing import Any + +from .base import api_function, BaseFunction +from ..request import Request + + +class Manager(BaseFunction): + """ + Provides controlling of the gateway/manager servers. + + .. versionadded:: 18.12 + """ + + @api_function + @classmethod + async def status(cls): + """ + Returns the current status of the configured API server. + """ + rqst = Request('GET', '/manager/status') + rqst.set_json({ + 'status': 'running', + }) + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + @classmethod + async def freeze(cls, force_kill: bool = False): + """ + Freezes the configured API server. + Any API clients will no longer be able to create new compute sessions nor + create and modify vfolders/keypairs/etc. + This is used to enter the maintenance mode of the server for unobtrusive + manager and/or agent upgrades. + + :param force_kill: If set ``True``, immediately shuts down all running + compute sessions forcibly. If not set, clients who have running compute + session are still able to interact with them though they cannot create + new compute sessions. + """ + rqst = Request('PUT', '/manager/status') + rqst.set_json({ + 'status': 'frozen', + 'force_kill': force_kill, + }) + async with rqst.fetch(): + pass + + @api_function + @classmethod + async def unfreeze(cls): + """ + Unfreezes the configured API server so that it resumes to normal operation. + """ + rqst = Request('PUT', '/manager/status') + rqst.set_json({ + 'status': 'running', + }) + async with rqst.fetch(): + pass + + @api_function + @classmethod + async def get_announcement(cls): + ''' + Get current announcement. + ''' + rqst = Request('GET', '/manager/announcement') + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + @classmethod + async def update_announcement(cls, enabled: bool = True, message: str = None): + ''' + Update (create / delete) announcement. + + :param enabled: If set ``False``, delete announcement. + :param message: Announcement message. Required if ``enabled`` is True. + ''' + rqst = Request('POST', '/manager/announcement') + rqst.set_json({ + 'enabled': enabled, + 'message': message, + }) + async with rqst.fetch(): + pass + + @api_function + @classmethod + async def scheduler_op(cls, op: str, args: Any): + ''' + Perform a scheduler operation. + + :param op: The name of scheduler operation. + :param args: Arguments specific to the given operation. 
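For illustration (not part of this patch), a typical maintenance flow using the manager controls above, run with an administrator access key:

    import asyncio
    from ai.backend.client.session import AsyncSession

    async def main():
        async with AsyncSession() as session:
            print(await session.Manager.status())
            await session.Manager.freeze(force_kill=False)   # enter maintenance mode
            # ... perform manager/agent upgrades here ...
            await session.Manager.unfreeze()                 # resume normal operation

    asyncio.run(main())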
+ ''' + rqst = Request('POST', '/manager/scheduler/operation') + rqst.set_json({ + 'op': op, + 'args': args, + }) + async with rqst.fetch(): + pass diff --git a/src/ai/backend/client/func/resource.py b/src/ai/backend/client/func/resource.py new file mode 100644 index 0000000000..1b7c7a8fc2 --- /dev/null +++ b/src/ai/backend/client/func/resource.py @@ -0,0 +1,120 @@ +from typing import Sequence + +from .base import api_function, BaseFunction +from ..request import Request + +__all__ = ( + 'Resource' +) + + +class Resource(BaseFunction): + """ + Provides interactions with resource. + """ + + @api_function + @classmethod + async def list(cls): + """ + Lists all resource presets. + """ + rqst = Request('GET', '/resource/presets') + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + @classmethod + async def check_presets(cls): + """ + Lists all resource presets in the current scaling group with additiona + information. + """ + rqst = Request('POST', '/resource/check-presets') + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + @classmethod + async def get_docker_registries(cls): + """ + Lists all registered docker registries. + """ + rqst = Request('GET', '/config/docker-registries') + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + @classmethod + async def usage_per_month(cls, month: str, group_ids: Sequence[str]): + """ + Get usage statistics for groups specified by `group_ids` at specific `month`. + + :param month: The month you want to get the statistics (yyyymm). + :param group_ids: Groups IDs to be included in the result. + """ + rqst = Request('GET', '/resource/usage/month') + rqst.set_json({ + 'month': month, + 'group_ids': group_ids, + }) + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + @classmethod + async def usage_per_period(cls, group_id: str, start_date: str, end_date: str): + """ + Get usage statistics for a group specified by `group_id` for time betweeen + `start_date` and `end_date`. + + :param start_date: start date in string format (yyyymmdd). + :param end_date: end date in string format (yyyymmdd). + :param group_id: Groups ID to list usage statistics. + """ + rqst = Request('GET', '/resource/usage/period') + rqst.set_json({ + 'group_id': group_id, + 'start_date': start_date, + 'end_date': end_date, + }) + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + @classmethod + async def get_resource_slots(cls): + """ + Get supported resource slots of Backend.AI server. 
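For illustration (not part of this patch), fetching the resource presets visible to the current keypair and a monthly usage report; the month value and group IDs are placeholders:

    import asyncio
    from ai.backend.client.session import AsyncSession

    async def main():
        async with AsyncSession() as session:
            presets = await session.Resource.check_presets()
            print(presets)
            usage = await session.Resource.usage_per_month(
                month='202205',
                group_ids=['00000000-0000-0000-0000-000000000000'],  # placeholder group ID
            )
            print(usage)

    asyncio.run(main())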
+ """ + rqst = Request('GET', '/config/resource-slots') + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + @classmethod + async def get_vfolder_types(cls): + rqst = Request('GET', '/config/vfolder-types') + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + @classmethod + async def recalculate_usage(cls): + rqst = Request('POST', '/resource/recalculate-usage') + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + @classmethod + async def user_monthly_stats(cls): + rqst = Request('GET', '/resource/stats/user/month') + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + @classmethod + async def admin_monthly_stats(cls): + rqst = Request('GET', '/resource/stats/admin/month') + async with rqst.fetch() as resp: + return await resp.json() diff --git a/src/ai/backend/client/func/scaling_group.py b/src/ai/backend/client/func/scaling_group.py new file mode 100644 index 0000000000..1c8ba831d5 --- /dev/null +++ b/src/ai/backend/client/func/scaling_group.py @@ -0,0 +1,312 @@ +import json +import textwrap +from typing import Iterable, Mapping, Sequence + +from ai.backend.client.output.fields import scaling_group_fields +from ai.backend.client.output.types import FieldSpec +from .base import api_function, BaseFunction +from ..request import Request +from ..session import api_session + +__all__ = ( + 'ScalingGroup', +) + +_default_list_fields = ( + scaling_group_fields['name'], + scaling_group_fields['description'], + scaling_group_fields['is_active'], + scaling_group_fields['created_at'], + scaling_group_fields['driver'], + scaling_group_fields['scheduler'], +) + +_default_detail_fields = ( + scaling_group_fields['name'], + scaling_group_fields['description'], + scaling_group_fields['is_active'], + scaling_group_fields['created_at'], + scaling_group_fields['driver'], + scaling_group_fields['driver_opts'], + scaling_group_fields['scheduler'], + scaling_group_fields['scheduler_opts'], +) + + +class ScalingGroup(BaseFunction): + """ + Provides getting scaling-group information required for the current user. + + The scaling-group is an opaque server-side configuration which splits the whole + cluster into several partitions, so that server administrators can apply different auto-scaling + policies and operation standards to each partition of agent sets. + """ + + def __init__(self, name: str): + self.name = name + + @api_function + @classmethod + async def list_available(cls, group: str): + """ + List available scaling groups for the current user, + considering the user, the user's domain, and the designated user group. + """ + rqst = Request( + 'GET', '/scaling-groups', + params={'group': group}, + ) + async with rqst.fetch() as resp: + data = await resp.json() + print(data) + return data['scaling_groups'] + + @api_function + @classmethod + async def list( + cls, + fields: Sequence[FieldSpec] = _default_list_fields, + ) -> Sequence[dict]: + """ + List available scaling groups for the current user, + considering the user, the user's domain, and the designated user group. 
+ """ + query = textwrap.dedent("""\ + query($is_active: Boolean) { + scaling_groups(is_active: $is_active) { + $fields + } + } + """) + query = query.replace('$fields', ' '.join(f.field_ref for f in fields)) + variables = {'is_active': None} + data = await api_session.get().Admin._query(query, variables) + return data['scaling_groups'] + + @api_function + @classmethod + async def detail( + cls, + name: str, + fields: Sequence[FieldSpec] = _default_detail_fields, + ) -> dict: + """ + Fetch information of a scaling group by name. + + :param name: Name of the scaling group. + :param fields: Additional per-scaling-group query fields. + """ + query = textwrap.dedent("""\ + query($name: String) { + scaling_group(name: $name) {$fields} + } + """) + query = query.replace('$fields', ' '.join(f.field_ref for f in fields)) + variables = {'name': name} + data = await api_session.get().Admin._query(query, variables) + return data['scaling_group'] + + @api_function + @classmethod + async def create(cls, name: str, description: str = '', is_active: bool = True, + driver: str = None, driver_opts: Mapping[str, str] = None, + scheduler: str = None, scheduler_opts: Mapping[str, str] = None, + fields: Iterable[str] = None) -> dict: + """ + Creates a new scaling group with the given options. + """ + if fields is None: + fields = ('name',) + query = textwrap.dedent("""\ + mutation($name: String!, $input: CreateScalingGroupInput!) { + create_scaling_group(name: $name, props: $input) { + ok msg scaling_group {$fields} + } + } + """) + query = query.replace('$fields', ' '.join(fields)) + variables = { + 'name': name, + 'input': { + 'description': description, + 'is_active': is_active, + 'driver': driver, + 'driver_opts': json.dumps(driver_opts), + 'scheduler': scheduler, + 'scheduler_opts': json.dumps(scheduler_opts), + }, + } + data = await api_session.get().Admin._query(query, variables) + return data['create_scaling_group'] + + @api_function + @classmethod + async def update(cls, name: str, description: str = '', is_active: bool = True, + driver: str = None, driver_opts: Mapping[str, str] = None, + scheduler: str = None, scheduler_opts: Mapping[str, str] = None, + fields: Iterable[str] = None) -> dict: + """ + Update existing scaling group. + """ + if fields is None: + fields = ('name',) + query = textwrap.dedent("""\ + mutation($name: String!, $input: ModifyScalingGroupInput!) { + modify_scaling_group(name: $name, props: $input) { + ok msg + } + } + """) + query = query.replace('$fields', ' '.join(fields)) + variables = { + 'name': name, + 'input': { + 'description': description, + 'is_active': is_active, + 'driver': driver, + 'driver_opts': None if driver_opts is None else json.dumps(driver_opts), + 'scheduler': scheduler, + 'scheduler_opts': None if scheduler_opts is None else json.dumps(scheduler_opts), + }, + } + data = await api_session.get().Admin._query(query, variables) + return data['modify_scaling_group'] + + @api_function + @classmethod + async def delete(cls, name: str): + """ + Deletes an existing scaling group. + """ + query = textwrap.dedent("""\ + mutation($name: String!) { + delete_scaling_group(name: $name) { + ok msg + } + } + """) + variables = {'name': name} + data = await api_session.get().Admin._query(query, variables) + return data['delete_scaling_group'] + + @api_function + @classmethod + async def associate_domain(cls, scaling_group: str, domain: str): + """ + Associate scaling_group with domain. + + :param scaling_group: The name of a scaling group. 
+ :param domain: The name of a domain. + """ + query = textwrap.dedent("""\ + mutation($scaling_group: String!, $domain: String!) { + associate_scaling_group_with_domain( + scaling_group: $scaling_group, domain: $domain) { + ok msg + } + } + """) + variables = {'scaling_group': scaling_group, 'domain': domain} + data = await api_session.get().Admin._query(query, variables) + return data['associate_scaling_group_with_domain'] + + @api_function + @classmethod + async def dissociate_domain(cls, scaling_group: str, domain: str): + """ + Dissociate scaling_group from domain. + + :param scaling_group: The name of a scaling group. + :param domain: The name of a domain. + """ + query = textwrap.dedent("""\ + mutation($scaling_group: String!, $domain: String!) { + disassociate_scaling_group_with_domain( + scaling_group: $scaling_group, domain: $domain) { + ok msg + } + } + """) + variables = {'scaling_group': scaling_group, 'domain': domain} + data = await api_session.get().Admin._query(query, variables) + return data['disassociate_scaling_group_with_domain'] + + @api_function + @classmethod + async def dissociate_all_domain(cls, domain: str): + """ + Dissociate all scaling_groups from domain. + + :param domain: The name of a domain. + """ + query = textwrap.dedent("""\ + mutation($domain: String!) { + disassociate_all_scaling_groups_with_domain(domain: $domain) { + ok msg + } + } + """) + variables = {'domain': domain} + data = await api_session.get().Admin._query(query, variables) + return data['disassociate_all_scaling_groups_with_domain'] + + @api_function + @classmethod + async def associate_group(cls, scaling_group: str, group_id: str): + """ + Associate scaling_group with group. + + :param scaling_group: The name of a scaling group. + :param group_id: The ID of a group. + """ + query = textwrap.dedent("""\ + mutation($scaling_group: String!, $user_group: UUID!) { + associate_scaling_group_with_user_group( + scaling_group: $scaling_group, user_group: $user_group) { + ok msg + } + } + """) + variables = {'scaling_group': scaling_group, 'user_group': group_id} + data = await api_session.get().Admin._query(query, variables) + return data['associate_scaling_group_with_user_group'] + + @api_function + @classmethod + async def dissociate_group(cls, scaling_group: str, group_id: str): + """ + Dissociate scaling_group from group. + + :param scaling_group: The name of a scaling group. + :param group_id: The ID of a group. + """ + query = textwrap.dedent("""\ + mutation($scaling_group: String!, $user_group: String!) { + disassociate_scaling_group_with_user_group( + scaling_group: $scaling_group, user_group: $user_group) { + ok msg + } + } + """) + variables = {'scaling_group': scaling_group, 'user_group': group_id} + data = await api_session.get().Admin._query(query, variables) + return data['disassociate_scaling_group_with_user_group'] + + @api_function + @classmethod + async def dissociate_all_group(cls, group_id: str): + """ + Dissociate all scaling_groups from group. + + :param group_id: The ID of a group. + """ + query = textwrap.dedent("""\ + mutation($group_id: UUID!) 
{ + disassociate_all_scaling_groups_with_group(user_group: $group_id) { + ok msg + } + } + """) + variables = {'group_id': group_id} + data = await api_session.get().Admin._query(query, variables) + return data['disassociate_all_scaling_groups_with_group'] diff --git a/src/ai/backend/client/func/server_log.py b/src/ai/backend/client/func/server_log.py new file mode 100644 index 0000000000..9a0d4b9e75 --- /dev/null +++ b/src/ai/backend/client/func/server_log.py @@ -0,0 +1,43 @@ +from typing import ( + Union, + Sequence, + Mapping, +) + +from .base import api_function, BaseFunction +from ..request import Request + +__all__ = ( + 'ServerLog', +) + + +class ServerLog(BaseFunction): + ''' + Provides a shortcut of :func:`Admin.query() + ` that fetches various server logs. + ''' + + @api_function + @classmethod + async def list( + cls, + mark_read: bool = False, + page_size: int = 20, + page_no: int = 1, + ) -> Sequence[dict]: + ''' + Fetches server (error) logs. + + :param mark_read: Mark read flog for server logs being fetched. + :param page_size: Number of logs to fetch (from latest log). + :param page_no: Page number to fetch. + ''' + params: Mapping[str, Union[str, int]] = { + 'mark_read': str(mark_read), + 'page_size': page_size, + 'page_no': page_no, + } + rqst = Request('GET', '/logs/error', params=params) + async with rqst.fetch() as resp: + return await resp.json() diff --git a/src/ai/backend/client/func/session.py b/src/ai/backend/client/func/session.py new file mode 100644 index 0000000000..756cc92246 --- /dev/null +++ b/src/ai/backend/client/func/session.py @@ -0,0 +1,993 @@ +from __future__ import annotations + +import json +import os +from pathlib import Path +import secrets +import tarfile +import tempfile +from typing import ( + Any, + AsyncIterator, + Dict, + Iterable, + List, + Literal, + Mapping, + Optional, + Sequence, + Union, + cast, +) +from uuid import UUID + +import aiohttp +from aiohttp import hdrs +from tqdm import tqdm + +from ai.backend.client.output.fields import session_fields +from ai.backend.client.output.types import FieldSpec, PaginatedResult +from .base import api_function, BaseFunction +from ..compat import current_loop +from ..config import DEFAULT_CHUNK_SIZE +from ..exceptions import BackendClientError +from ..pagination import generate_paginated_results +from ..request import ( + Request, AttachedFile, + WebSocketResponse, + SSEContextManager, + WebSocketContextManager, +) +from ..session import api_session +from ..utils import ProgressReportingReader +from ..types import Undefined, undefined +from ..versioning import get_naming, get_id_or_name + +__all__ = ( + 'ComputeSession', +) + +_default_list_fields = ( + session_fields['session_id'], + session_fields['image'], + session_fields['type'], + session_fields['status'], + session_fields['status_info'], + session_fields['status_changed'], + session_fields['result'], +) + + +def drop(d: Mapping[str, Any], value_to_drop: Any) -> Mapping[str, Any]: + modified: Dict[str, Any] = {} + for k, v in d.items(): + if isinstance(v, Mapping) or isinstance(v, dict): + modified[k] = drop(v, value_to_drop) + elif v != value_to_drop: + modified[k] = v + return modified + + +class ComputeSession(BaseFunction): + """ + Provides various interactions with compute sessions in Backend.AI. + + The term 'kernel' is now deprecated and we prefer 'compute sessions'. + However, for historical reasons and to avoid confusion with client sessions, we + keep the backward compatibility with the naming of this API function class. 
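+
+    A typical interaction looks like the following sketch. It is illustrative
+    only and assumes that an asynchronous client session is importable as
+    ``AsyncSession`` and exposes this class as an attribute, in the same way
+    ``KeyPair`` is used in :mod:`ai.backend.client.helper`:
+
+    .. code-block:: python
+
+       from ai.backend.client.session import AsyncSession
+
+       async with AsyncSession() as api_sess:
+           sess = await api_sess.ComputeSession.get_or_create('python:3.6-ubuntu')
+           print(sess.id, sess.status)
+           await sess.destroy()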
+ + For multi-container sessions, all methods take effects to the master container + only, except :func:`~ComputeSession.destroy` and :func:`~ComputeSession.restart` methods. + So it is the user's responsibility to distribute uploaded files to multiple + containers using explicit copies or virtual folders which are commonly mounted to + all containers belonging to the same compute session. + """ + + id: Optional[UUID] + name: Optional[str] + owner_access_key: Optional[str] + created: bool + status: str + service_ports: List[str] + domain: str + group: str + + @api_function + @classmethod + async def paginated_list( + cls, + status: str = None, + access_key: str = None, + *, + fields: Sequence[FieldSpec] = _default_list_fields, + page_offset: int = 0, + page_size: int = 20, + filter: str = None, + order: str = None, + ) -> PaginatedResult[dict]: + """ + Fetches the list of users. Domain admins can only get domain users. + + :param is_active: Fetches active or inactive users only if not None. + :param fields: Additional per-user query fields to fetch. + """ + return await generate_paginated_results( + 'compute_session_list', + { + 'status': (status, 'String'), + 'access_key': (access_key, 'String'), + 'filter': (filter, 'String'), + 'order': (order, 'String'), + }, + fields, + page_offset=page_offset, + page_size=page_size, + ) + + @api_function + @classmethod + async def hello(cls) -> str: + rqst = Request('GET', '/') + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + @classmethod + async def get_task_logs( + cls, task_id: str, *, + chunk_size: int = 8192, + ) -> AsyncIterator[bytes]: + prefix = get_naming(api_session.get().api_version, 'path') + rqst = Request('GET', f'/{prefix}/_/logs', params={ + 'taskId': task_id, + }) + async with rqst.fetch() as resp: + while True: + chunk = await resp.raw_response.content.read(chunk_size) + if not chunk: + break + yield chunk + + @api_function + @classmethod + async def get_or_create( + cls, + image: str, *, + name: str = None, + type_: str = 'interactive', + starts_at: str = None, + enqueue_only: bool = False, + max_wait: int = 0, + no_reuse: bool = False, + dependencies: Sequence[str] = None, + callback_url: Optional[str] = None, + mounts: List[str] = None, + mount_map: Mapping[str, str] = None, + envs: Mapping[str, str] = None, + startup_command: str = None, + resources: Mapping[str, int] = None, + resource_opts: Mapping[str, int] = None, + cluster_size: int = 1, + cluster_mode: Literal['single-node', 'multi-node'] = 'single-node', + domain_name: str = None, + group_name: str = None, + bootstrap_script: str = None, + tag: str = None, + scaling_group: str = None, + owner_access_key: str = None, + preopen_ports: List[int] = None, + assign_agent: List[str] = None, + ) -> ComputeSession: + """ + Get-or-creates a compute session. + If *name* is ``None``, it creates a new compute session as long as + the server has enough resources and your API key has remaining quota. + If *name* is a valid string and there is an existing compute session + with the same token and the same *image*, then it returns the :class:`ComputeSession` + instance representing the existing session. + + :param image: The image name and tag for the compute session. + Example: ``python:3.6-ubuntu``. + Check out the full list of available images in your server using (TODO: + new API). + :param name: A client-side (user-defined) identifier to distinguish the session among currently + running sessions. 
+ It may be used to seamlessly reuse the session already created. + + .. versionchanged:: 19.12.0 + + Renamed from ``clientSessionToken``. + :param type_: Either ``"interactive"`` (default) or ``"batch"``. + + .. versionadded:: 19.09.0 + :param enqueue_only: Just enqueue the session creation request and return immediately, + without waiting for its startup. (default: ``false`` to preserve the legacy + behavior) + + .. versionadded:: 19.09.0 + :param max_wait: The time to wait for session startup. If the cluster resource + is being fully utilized, this waiting time can be arbitrarily long due to + job queueing. If the timeout reaches, the returned *status* field becomes + ``"TIMEOUT"``. Still in this case, the session may start in the future. + + .. versionadded:: 19.09.0 + :param no_reuse: Raises an explicit error if a session with the same *image* and + the same *name* already exists instead of returning the information + of it. + + .. versionadded:: 19.09.0 + :param mounts: The list of vfolder names that belongs to the currrent API + access key. + :param mount_map: Mapping which contains custom path to mount vfolder. + Key and value of this map should be vfolder name and custom path. + Defalut mounts or relative paths are under /home/work. + If you want different paths, names should be absolute paths. + The target mount path of vFolders should not overlap with the linux system folders. + vFolders which has a dot(.) prefix in its name are not affected. + :param envs: The environment variables which always bypasses the jail policy. + :param resources: The resource specification. (TODO: details) + :param cluster_size: The number of containers in this compute session. + Must be at least 1. + + .. versionadded:: 19.09.0 + .. versionchanged:: 20.09.0 + :param cluster_mode: Set the clustering mode whether to use distributed + nodes or a single node to spawn multiple containers for the new session. + + .. versionadded:: 20.09.0 + :param tag: An optional string to annotate extra information. + :param owner: An optional access key that owns the created session. (Only + available to administrators) + + :returns: The :class:`ComputeSession` instance. + """ + if name is not None: + assert 4 <= len(name) <= 64, \ + 'Client session token should be 4 to 64 characters long.' + else: + name = f'pysdk-{secrets.token_hex(5)}' + if mounts is None: + mounts = [] + if mount_map is None: + mount_map = {} + if resources is None: + resources = {} + if resource_opts is None: + resource_opts = {} + if domain_name is None: + # Even if config.domain is None, it can be guessed in the manager by user information. 
+ domain_name = api_session.get().config.domain + if group_name is None: + group_name = api_session.get().config.group + + mounts.extend(api_session.get().config.vfolder_mounts) + prefix = get_naming(api_session.get().api_version, 'path') + rqst = Request('POST', f'/{prefix}') + params: Dict[str, Any] = { + 'tag': tag, + get_naming(api_session.get().api_version, 'name_arg'): name, + 'config': { + 'mounts': mounts, + 'environ': envs, + 'resources': resources, + 'resource_opts': resource_opts, + 'scalingGroup': scaling_group, + }, + } + if api_session.get().api_version >= (6, '20220315'): + params['dependencies'] = dependencies + params['callback_url'] = callback_url + if api_session.get().api_version >= (6, '20200815'): + params['clusterSize'] = cluster_size + params['clusterMode'] = cluster_mode + else: + params['config']['clusterSize'] = cluster_size + if api_session.get().api_version >= (5, '20191215'): + params['starts_at'] = starts_at + params['bootstrap_script'] = bootstrap_script + if assign_agent is not None: + params['config'].update({ + 'mount_map': mount_map, + 'preopen_ports': preopen_ports, + 'agentList': assign_agent, + }) + else: + params['config'].update({ + 'mount_map': mount_map, + 'preopen_ports': preopen_ports, + }) + if api_session.get().api_version >= (4, '20190615'): + params.update({ + 'owner_access_key': owner_access_key, + 'domain': domain_name, + 'group': group_name, + 'type': type_, + 'enqueueOnly': enqueue_only, + 'maxWaitSeconds': max_wait, + 'reuseIfExists': not no_reuse, + 'startupCommand': startup_command, + }) + if api_session.get().api_version > (4, '20181215'): + params['image'] = image + else: + params['lang'] = image + rqst.set_json(params) + async with rqst.fetch() as resp: + data = await resp.json() + o = cls(name, owner_access_key) # type: ignore + if api_session.get().api_version[0] >= 5: + o.id = UUID(data['sessionId']) + o.created = data.get('created', True) # True is for legacy + o.status = data.get('status', 'RUNNING') + o.service_ports = data.get('servicePorts', []) + o.domain = domain_name + o.group = group_name + return o + + @api_function + @classmethod + async def create_from_template( + cls, + template_id: str, *, + name: Union[str, Undefined] = undefined, + type_: Union[str, Undefined] = undefined, + starts_at: str = None, + enqueue_only: Union[bool, Undefined] = undefined, + max_wait: Union[int, Undefined] = undefined, + dependencies: Sequence[str] = None, # cannot be stored in templates + no_reuse: Union[bool, Undefined] = undefined, + image: Union[str, Undefined] = undefined, + mounts: Union[List[str], Undefined] = undefined, + mount_map: Union[Mapping[str, str], Undefined] = undefined, + envs: Union[Mapping[str, str], Undefined] = undefined, + startup_command: Union[str, Undefined] = undefined, + resources: Union[Mapping[str, int], Undefined] = undefined, + resource_opts: Union[Mapping[str, int], Undefined] = undefined, + cluster_size: Union[int, Undefined] = undefined, + cluster_mode: Union[Literal['single-node', 'multi-node'], Undefined] = undefined, + domain_name: Union[str, Undefined] = undefined, + group_name: Union[str, Undefined] = undefined, + bootstrap_script: Union[str, Undefined] = undefined, + tag: Union[str, Undefined] = undefined, + scaling_group: Union[str, Undefined] = undefined, + owner_access_key: Union[str, Undefined] = undefined, + ) -> ComputeSession: + """ + Get-or-creates a compute session from template. + All other parameters provided will be overwritten to template, including + vfolder mounts (not appended!). 
+ If *name* is ``None``, it creates a new compute session as long as + the server has enough resources and your API key has remaining quota. + If *name* is a valid string and there is an existing compute session + with the same token and the same *image*, then it returns the :class:`ComputeSession` + instance representing the existing session. + + :param template_id: Task template to apply to compute session. + :param image: The image name and tag for the compute session. + Example: ``python:3.6-ubuntu``. + Check out the full list of available images in your server using (TODO: + new API). + :param name: A client-side (user-defined) identifier to distinguish the session among currently + running sessions. + It may be used to seamlessly reuse the session already created. + + .. versionchanged:: 19.12.0 + + Renamed from ``clientSessionToken``. + :param type_: Either ``"interactive"`` (default) or ``"batch"``. + + .. versionadded:: 19.09.0 + :param enqueue_only: Just enqueue the session creation request and return immediately, + without waiting for its startup. (default: ``false`` to preserve the legacy + behavior) + + .. versionadded:: 19.09.0 + :param max_wait: The time to wait for session startup. If the cluster resource + is being fully utilized, this waiting time can be arbitrarily long due to + job queueing. If the timeout reaches, the returned *status* field becomes + ``"TIMEOUT"``. Still in this case, the session may start in the future. + + .. versionadded:: 19.09.0 + :param no_reuse: Raises an explicit error if a session with the same *image* and + the same *name* already exists instead of returning the information + of it. + + .. versionadded:: 19.09.0 + :param mounts: The list of vfolder names that belongs to the currrent API + access key. + :param mount_map: Mapping which contains custom path to mount vfolder. + Key and value of this map should be vfolder name and custom path. + Defalut mounts or relative paths are under /home/work. + If you want different paths, names should be absolute paths. + The target mount path of vFolders should not overlap with the linux system folders. + vFolders which has a dot(.) prefix in its name are not affected. + :param envs: The environment variables which always bypasses the jail policy. + :param resources: The resource specification. (TODO: details) + :param cluster_size: The number of containers in this compute session. + Must be at least 1. + + .. versionadded:: 19.09.0 + .. versionchanged:: 20.09.0 + :param cluster_mode: Set the clustering mode whether to use distributed + nodes or a single node to spawn multiple containers for the new session. + + .. versionadded:: 20.09.0 + :param tag: An optional string to annotate extra information. + :param owner: An optional access key that owns the created session. (Only + available to administrators) + + :returns: The :class:`ComputeSession` instance. + """ + if name is not undefined: + assert 4 <= len(name) <= 64, \ + 'Client session token should be 4 to 64 characters long.' + else: + name = f'pysdk-{secrets.token_urlsafe(8)}' + + if domain_name is undefined: + # Even if config.domain is None, it can be guessed in the manager by user information. 
+ domain_name = api_session.get().config.domain + if group_name is undefined: + group_name = api_session.get().config.group + if mounts is undefined: + mounts = [] + if api_session.get().config.vfolder_mounts: + mounts.extend(api_session.get().config.vfolder_mounts) + prefix = get_naming(api_session.get().api_version, 'path') + rqst = Request('POST', f'/{prefix}/_/create-from-template') + params: Dict[str, Any] + params = { + 'template_id': template_id, + 'tag': tag, + 'image': image, + 'domain': domain_name, + 'group': group_name, + get_naming(api_session.get().api_version, 'name_arg'): name, + 'bootstrap_script': bootstrap_script, + 'enqueueOnly': enqueue_only, + 'maxWaitSeconds': max_wait, + 'reuseIfExists': not no_reuse, + 'startupCommand': startup_command, + 'owner_access_key': owner_access_key, + 'type': type_, + 'starts_at': starts_at, + 'config': { + 'mounts': mounts, + 'mount_map': mount_map, + 'environ': envs, + 'resources': resources, + 'resource_opts': resource_opts, + 'scalingGroup': scaling_group, + }, + } + if api_session.get().api_version >= (6, '20200815'): + params['clusterSize'] = cluster_size + params['clusterMode'] = cluster_mode + else: + params['config']['clusterSize'] = cluster_size + params = cast(Dict[str, Any], drop(params, undefined)) + rqst.set_json(params) + async with rqst.fetch() as resp: + data = await resp.json() + o = cls(name, owner_access_key if owner_access_key is not undefined else None) + if api_session.get().api_version[0] >= 5: + o.id = UUID(data['sessionId']) + o.created = data.get('created', True) # True is for legacy + o.status = data.get('status', 'RUNNING') + o.service_ports = data.get('servicePorts', []) + o.domain = domain_name + o.group = group_name + return o + + def __init__(self, name: str, owner_access_key: str = None) -> None: + self.id = None + self.name = name + self.owner_access_key = owner_access_key + + @classmethod + def from_session_id(cls, session_id: UUID) -> ComputeSession: + o = cls(None, None) # type: ignore + o.id = session_id + return o + + def get_session_identity_params(self) -> Mapping[str, str]: + if self.id: + identity_params = { + 'sessionId': str(self.id), + } + else: + assert self.name is not None + identity_params = { + 'sessionName': self.name, + } + if self.owner_access_key: + identity_params['owner_access_key'] = self.owner_access_key + return identity_params + + @api_function + async def destroy(self, *, forced: bool = False): + """ + Destroys the compute session. + Since the server literally kills the container(s), all ongoing executions are + forcibly interrupted. + """ + params = {} + if self.owner_access_key is not None: + params['owner_access_key'] = self.owner_access_key + prefix = get_naming(api_session.get().api_version, 'path') + if forced: + params['forced'] = 'true' + rqst = Request( + 'DELETE', f'/{prefix}/{self.name}', + params=params, + ) + async with rqst.fetch() as resp: + if resp.status == 200: + return await resp.json() + + @api_function + async def restart(self): + """ + Restarts the compute session. + The server force-destroys the current running container(s), but keeps their + temporary scratch directories intact. + """ + params = {} + if self.owner_access_key: + params['owner_access_key'] = self.owner_access_key + prefix = get_naming(api_session.get().api_version, 'path') + rqst = Request( + 'PATCH', f'/{prefix}/{self.name}', + params=params, + ) + async with rqst.fetch(): + pass + + @api_function + async def rename(self, new_id): + """ + Renames Session ID of running compute session. 
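+
+        :param new_id: The new client-side (user-defined) identifier for the
+            session; it is sent as the ``name`` field of the rename request.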
+ """ + params = {'name': new_id} + if self.owner_access_key: + params['owner_access_key'] = self.owner_access_key + prefix = get_naming(api_session.get().api_version, 'path') + rqst = Request( + 'POST', f'/{prefix}/{self.name}/rename', + params=params, + ) + async with rqst.fetch(): + pass + + @api_function + async def interrupt(self): + """ + Tries to interrupt the current ongoing code execution. + This may fail without any explicit errors depending on the code being + executed. + """ + params = {} + if self.owner_access_key: + params['owner_access_key'] = self.owner_access_key + prefix = get_naming(api_session.get().api_version, 'path') + rqst = Request( + 'POST', f'/{prefix}/{self.name}/interrupt', + params=params, + ) + async with rqst.fetch(): + pass + + @api_function + async def complete(self, code: str, opts: dict = None) -> Iterable[str]: + """ + Gets the auto-completion candidates from the given code string, + as if a user has pressed the tab key just after the code in + IDEs. + + Depending on the language of the compute session, this feature + may not be supported. Unsupported sessions returns an empty list. + + :param code: An (incomplete) code text. + :param opts: Additional information about the current cursor position, + such as row, col, line and the remainder text. + + :returns: An ordered list of strings. + """ + opts = {} if opts is None else opts + params = {} + if self.owner_access_key: + params['owner_access_key'] = self.owner_access_key + prefix = get_naming(api_session.get().api_version, 'path') + rqst = Request( + 'POST', f'/{prefix}/{self.name}/complete', + params=params, + ) + rqst.set_json({ + 'code': code, + 'options': { + 'row': int(opts.get('row', 0)), + 'col': int(opts.get('col', 0)), + 'line': opts.get('line', ''), + 'post': opts.get('post', ''), + }, + }) + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + async def get_info(self): + """ + Retrieves a brief information about the compute session. + """ + params = {} + if self.owner_access_key: + params['owner_access_key'] = self.owner_access_key + prefix = get_naming(api_session.get().api_version, 'path') + rqst = Request( + 'GET', f'/{prefix}/{self.name}', + params=params, + ) + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + async def get_logs(self): + """ + Retrieves the console log of the compute session container. + """ + params = {} + if self.owner_access_key: + params['owner_access_key'] = self.owner_access_key + prefix = get_naming(api_session.get().api_version, 'path') + rqst = Request( + 'GET', f'/{prefix}/{self.name}/logs', + params=params, + ) + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + async def execute(self, run_id: str = None, + code: str = None, + mode: str = 'query', + opts: dict = None): + """ + Executes a code snippet directly in the compute session or sends a set of + build/clean/execute commands to the compute session. + + For more details about using this API, please refer :doc:`the official API + documentation `. + + :param run_id: A unique identifier for a particular run loop. In the + first call, it may be ``None`` so that the server auto-assigns one. + Subsequent calls must use the returned ``runId`` value to request + continuation or to send user inputs. + :param code: A code snippet as string. In the continuation requests, it + must be an empty string. When sending user inputs, this is where the + user input string is stored. 
+ :param mode: A constant string which is one of ``"query"``, ``"batch"``, + ``"continue"``, and ``"user-input"``. + :param opts: A dict for specifying additional options. Mainly used in the + batch mode to specify build/clean/execution commands. + See :ref:`the API object reference ` + for details. + + :returns: :ref:`An execution result object ` + """ + opts = opts if opts is not None else {} + params = {} + if self.owner_access_key: + params['owner_access_key'] = self.owner_access_key + prefix = get_naming(api_session.get().api_version, 'path') + if mode in {'query', 'continue', 'input'}: + assert code is not None, \ + 'The code argument must be a valid string even when empty.' + rqst = Request( + 'POST', f'/{prefix}/{self.name}', + params=params, + ) + rqst.set_json({ + 'mode': mode, + 'code': code, + 'runId': run_id, + }) + elif mode == 'batch': + rqst = Request( + 'POST', f'/{prefix}/{self.name}', + params=params, + ) + rqst.set_json({ + 'mode': mode, + 'code': code, + 'runId': run_id, + 'options': { + 'clean': opts.get('clean', None), + 'build': opts.get('build', None), + 'buildLog': bool(opts.get('buildLog', False)), + 'exec': opts.get('exec', None), + }, + }) + elif mode == 'complete': + rqst = Request( + 'POST', f'/{prefix}/{self.name}', + params=params, + ) + rqst.set_json({ + 'code': code, + 'options': { + 'row': int(opts.get('row', 0)), + 'col': int(opts.get('col', 0)), + 'line': opts.get('line', ''), + 'post': opts.get('post', ''), + }, + }) + else: + raise BackendClientError('Invalid execution mode: {0}'.format(mode)) + async with rqst.fetch() as resp: + return (await resp.json())['result'] + + @api_function + async def upload(self, files: Sequence[Union[str, Path]], + basedir: Union[str, Path] = None, + show_progress: bool = False): + """ + Uploads the given list of files to the compute session. + You may refer them in the batch-mode execution or from the code + executed in the server afterwards. + + :param files: The list of file paths in the client-side. + If the paths include directories, the location of them in the compute + session is calculated from the relative path to *basedir* and all + intermediate parent directories are automatically created if not exists. + + For example, if a file path is ``/home/user/test/data.txt`` (or + ``test/data.txt``) where *basedir* is ``/home/user`` (or the current + working directory is ``/home/user``), the uploaded file is located at + ``/home/work/test/data.txt`` in the compute session container. + :param basedir: The directory prefix where the files reside. + The default value is the current working directory. + :param show_progress: Displays a progress bar during uploads. + """ + params = {} + if self.owner_access_key: + params['owner_access_key'] = self.owner_access_key + prefix = get_naming(api_session.get().api_version, 'path') + base_path = ( + Path.cwd() if basedir is None + else Path(basedir).resolve() + ) + files = [Path(file).resolve() for file in files] + total_size = 0 + for file_path in files: + total_size += Path(file_path).stat().st_size + tqdm_obj = tqdm(desc='Uploading files', + unit='bytes', unit_scale=True, + total=total_size, + disable=not show_progress) + with tqdm_obj: + attachments = [] + for file_path in files: + try: + attachments.append(AttachedFile( + str(Path(file_path).relative_to(base_path)), + ProgressReportingReader(str(file_path), + tqdm_instance=tqdm_obj), + 'application/octet-stream', + )) + except ValueError: + msg = 'File "{0}" is outside of the base directory "{1}".' 
\ + .format(file_path, base_path) + raise ValueError(msg) from None + + rqst = Request( + 'POST', f'/{prefix}/{self.name}/upload', + params=params, + ) + rqst.attach_files(attachments) + async with rqst.fetch() as resp: + return resp + + @api_function + async def download(self, files: Sequence[Union[str, Path]], + dest: Union[str, Path] = '.', + show_progress: bool = False): + """ + Downloads the given list of files from the compute session. + + :param files: The list of file paths in the compute session. + If they are relative paths, the path is calculated from + ``/home/work`` in the compute session container. + :param dest: The destination directory in the client-side. + :param show_progress: Displays a progress bar during downloads. + """ + params = {} + if self.owner_access_key: + params['owner_access_key'] = self.owner_access_key + prefix = get_naming(api_session.get().api_version, 'path') + rqst = Request( + 'GET', f'/{prefix}/{self.name}/download', + params=params, + ) + rqst.set_json({ + 'files': [*map(str, files)], + }) + file_names = [] + async with rqst.fetch() as resp: + loop = current_loop() + tqdm_obj = tqdm(desc='Downloading files', + unit='bytes', unit_scale=True, + total=resp.content.total_bytes, + disable=not show_progress) + reader = aiohttp.MultipartReader.from_response(resp.raw_response) + with tqdm_obj as pbar: + while True: + part = cast(aiohttp.BodyPartReader, await reader.next()) + if part is None: + break + assert part.headers.get(hdrs.CONTENT_ENCODING, 'identity').lower() == 'identity' + assert part.headers.get(hdrs.CONTENT_TRANSFER_ENCODING, 'binary').lower() in ( + 'binary', '8bit', '7bit', + ) + fp = tempfile.NamedTemporaryFile(suffix='.tar', + delete=False) + while True: + chunk = await part.read_chunk(DEFAULT_CHUNK_SIZE) + if not chunk: + break + await loop.run_in_executor(None, lambda: fp.write(chunk)) + pbar.update(len(chunk)) + fp.close() + with tarfile.open(fp.name) as tarf: + tarf.extractall(path=dest) + file_names.extend(tarf.getnames()) + os.unlink(fp.name) + return {'file_names': file_names} + + @api_function + async def list_files(self, path: Union[str, Path] = '.'): + """ + Gets the list of files in the given path inside the compute session + container. + + :param path: The directory path in the compute session. + """ + params = {} + if self.owner_access_key: + params['owner_access_key'] = self.owner_access_key + prefix = get_naming(api_session.get().api_version, 'path') + rqst = Request( + 'GET', f'/{prefix}/{self.name}/files', + params=params, + ) + rqst.set_json({ + 'path': path, + }) + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + async def stream_app_info(self): + params = {} + if self.owner_access_key: + params['owner_access_key'] = self.owner_access_key + prefix = get_naming(api_session.get().api_version, 'path') + id_or_name = get_id_or_name(api_session.get().api_version, self) + api_rqst = Request( + 'GET', f'/stream/{prefix}/{id_or_name}/apps', + params=params, + ) + async with api_rqst.fetch() as resp: + return await resp.json() + + # only supported in AsyncAPISession + def listen_events(self, scope: Literal['*', 'session', 'kernel'] = '*') -> SSEContextManager: + """ + Opens the stream of the kernel lifecycle events. + Only the master kernel of each session is monitored. + + :returns: a :class:`StreamEvents` object. 
+ """ + if api_session.get().api_version[0] >= 6: + request = Request( + 'GET', '/events/session', + params={ + **self.get_session_identity_params(), + 'scope': scope, + }, + ) + else: + assert self.name is not None + params = { + get_naming(api_session.get().api_version, 'event_name_arg'): self.name, + } + if self.owner_access_key: + params['owner_access_key'] = self.owner_access_key + path = get_naming(api_session.get().api_version, 'session_events_path') + request = Request( + 'GET', path, + params=params, + ) + return request.connect_events() + + stream_events = listen_events # legacy alias + + # only supported in AsyncAPISession + def stream_pty(self) -> WebSocketContextManager: + """ + Opens a pseudo-terminal of the kernel (if supported) streamed via + websockets. + + :returns: a :class:`StreamPty` object. + """ + params = {} + if self.owner_access_key: + params['owner_access_key'] = self.owner_access_key + prefix = get_naming(api_session.get().api_version, 'path') + id_or_name = get_id_or_name(api_session.get().api_version, self) + request = Request( + 'GET', f'/stream/{prefix}/{id_or_name}/pty', + params=params, + ) + return request.connect_websocket(response_cls=StreamPty) + + # only supported in AsyncAPISession + def stream_execute(self, code: str = '', *, + mode: str = 'query', + opts: dict = None) -> WebSocketContextManager: + """ + Executes a code snippet in the streaming mode. + Since the returned websocket represents a run loop, there is no need to + specify *run_id* explicitly. + """ + params = {} + if self.owner_access_key: + params['owner_access_key'] = self.owner_access_key + prefix = get_naming(api_session.get().api_version, 'path') + id_or_name = get_id_or_name(api_session.get().api_version, self) + opts = {} if opts is None else opts + if mode == 'query': + opts = {} + elif mode == 'batch': + opts = { + 'clean': opts.get('clean', None), + 'build': opts.get('build', None), + 'buildLog': bool(opts.get('buildLog', False)), + 'exec': opts.get('exec', None), + } + else: + msg = 'Invalid stream-execution mode: {0}'.format(mode) + raise BackendClientError(msg) + request = Request( + 'GET', f'/stream/{prefix}/{id_or_name}/execute', + params=params, + ) + + async def send_code(ws): + await ws.send_json({ + 'code': code, + 'mode': mode, + 'options': opts, + }) + + return request.connect_websocket(on_enter=send_code) + + +class StreamPty(WebSocketResponse): + """ + A derivative class of :class:`~ai.backend.client.request.WebSocketResponse` which + provides additional functions to control the terminal. 
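+
+    A rough usage sketch follows. It assumes (not verified here) that entering
+    the websocket context manager returned by :func:`ComputeSession.stream_pty`
+    yields an instance of this class, where ``sess`` is a :class:`ComputeSession`:
+
+    .. code-block:: python
+
+       async with sess.stream_pty() as pty:
+           await pty.resize(24, 80)
+           await pty.restart()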
+ """ + + __slots__ = ('ws', ) + + async def resize(self, rows, cols): + await self.ws.send_str(json.dumps({ + 'type': 'resize', + 'rows': rows, + 'cols': cols, + })) + + async def restart(self): + await self.ws.send_str(json.dumps({ + 'type': 'restart', + })) diff --git a/src/ai/backend/client/func/session_template.py b/src/ai/backend/client/func/session_template.py new file mode 100644 index 0000000000..d68e4594b6 --- /dev/null +++ b/src/ai/backend/client/func/session_template.py @@ -0,0 +1,82 @@ +from typing import Any, List, Mapping + +from .base import api_function, BaseFunction +from ..request import Request +from ..session import api_session + +__all__ = ( + 'SessionTemplate', +) + + +class SessionTemplate(BaseFunction): + + @api_function + @classmethod + async def create(cls, + template: str, + domain_name: str = None, + group_name: str = None, + owner_access_key: str = None, + ) -> 'SessionTemplate': + rqst = Request('POST', '/template/session') + if domain_name is None: + # Even if config.domain is None, it can be guessed in the manager by user information. + domain_name = api_session.get().config.domain + if group_name is None: + group_name = api_session.get().config.group + body = { + 'payload': template, + 'group_name': group_name, + 'domain_name': domain_name, + 'owner_access_key': owner_access_key, + } + rqst.set_json(body) + async with rqst.fetch() as resp: + response = await resp.json() + return cls(response['id'], owner_access_key=owner_access_key) + + @api_function + @classmethod + async def list_templates(cls, list_all: bool = False) -> List[Mapping[str, str]]: + rqst = Request('GET', '/template/session') + rqst.set_json({'all': list_all}) + async with rqst.fetch() as resp: + return await resp.json() + + def __init__(self, template_id: str, owner_access_key: str = None): + self.template_id = template_id + self.owner_access_key = owner_access_key + + @api_function + async def get(self, body_format: str = 'yaml') -> str: + params = {'format': body_format} + if self.owner_access_key: + params['owner_access_key'] = self.owner_access_key + rqst = Request('GET', f'/template/session/{self.template_id}', + params=params) + async with rqst.fetch() as resp: + data = await resp.text() + return data + + @api_function + async def put(self, template: str) -> Any: + body = { + 'payload': template, + } + if self.owner_access_key: + body['owner_access_key'] = self.owner_access_key + rqst = Request('PUT', f'/template/session/{self.template_id}') + rqst.set_json(body) + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + async def delete(self) -> Any: + params = {} + if self.owner_access_key: + params['owner_access_key'] = self.owner_access_key + rqst = Request('DELETE', f'/template/session/{self.template_id}', + params=params) + async with rqst.fetch() as resp: + return await resp.json() diff --git a/src/ai/backend/client/func/storage.py b/src/ai/backend/client/func/storage.py new file mode 100644 index 0000000000..0318eb8520 --- /dev/null +++ b/src/ai/backend/client/func/storage.py @@ -0,0 +1,86 @@ +import textwrap +from typing import ( + Sequence, +) + +from ai.backend.client.session import api_session +from ai.backend.client.output.fields import storage_fields +from ai.backend.client.output.types import FieldSpec, PaginatedResult +from ai.backend.client.pagination import generate_paginated_results +from .base import api_function, BaseFunction + +__all__ = ( + 'Storage', +) + +_default_list_fields = ( + storage_fields['id'], + storage_fields['backend'], + 
storage_fields['capabilities'], +) + +_default_detail_fields = ( + storage_fields['id'], + storage_fields['backend'], + storage_fields['path'], + storage_fields['fsprefix'], + storage_fields['capabilities'], + storage_fields['hardware_metadata'], +) + + +class Storage(BaseFunction): + """ + Provides a shortcut of :func:`Admin.query() + ` that fetches various straoge volume + information keyed by vfolder hosts. + + .. note:: + + All methods in this function class require your API access key to + have the *super-admin* privilege. + """ + + @api_function + @classmethod + async def paginated_list( + cls, + status: str = 'ALIVE', + *, + fields: Sequence[FieldSpec] = _default_list_fields, + page_offset: int = 0, + page_size: int = 20, + filter: str = None, + order: str = None, + ) -> PaginatedResult[dict]: + """ + Lists the keypairs. + You need an admin privilege for this operation. + """ + return await generate_paginated_results( + 'storage_volume_list', + { + 'filter': (filter, 'String'), + 'order': (order, 'String'), + }, + fields, + page_offset=page_offset, + page_size=page_size, + ) + + @api_function + @classmethod + async def detail( + cls, + vfolder_host: str, + fields: Sequence[FieldSpec] = _default_detail_fields, + ) -> dict: + query = textwrap.dedent("""\ + query($vfolder_host: String!) { + storage_volume(id: $vfolder_host) {$fields} + } + """) + query = query.replace('$fields', ' '.join(f.field_ref for f in fields)) + variables = {'vfolder_host': vfolder_host} + data = await api_session.get().Admin._query(query, variables) + return data['storage_volume'] diff --git a/src/ai/backend/client/func/system.py b/src/ai/backend/client/func/system.py new file mode 100644 index 0000000000..a64e0b583e --- /dev/null +++ b/src/ai/backend/client/func/system.py @@ -0,0 +1,37 @@ +from typing import Mapping + +from .base import api_function, BaseFunction +from ..request import Request + +__all__ = ( + 'System', +) + + +class System(BaseFunction): + """ + Provides the function interface for the API endpoint's system information. 
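+
+    A minimal sketch, assuming the class is exposed on a configured client
+    session object just like ``KeyPair`` in :mod:`ai.backend.client.helper`:
+
+    .. code-block:: python
+
+       from ai.backend.client.session import Session
+
+       with Session() as api_sess:
+           print(api_sess.System.get_manager_version())
+           print(api_sess.System.get_api_version())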
+ """ + + @api_function + @classmethod + async def get_versions(cls) -> Mapping[str, str]: + rqst = Request('GET', '/') + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + @classmethod + async def get_manager_version(cls) -> str: + rqst = Request('GET', '/') + async with rqst.fetch() as resp: + ret = await resp.json() + return ret['manager'] + + @api_function + @classmethod + async def get_api_version(cls) -> str: + rqst = Request('GET', '/') + async with rqst.fetch() as resp: + ret = await resp.json() + return ret['version'] diff --git a/src/ai/backend/client/func/user.py b/src/ai/backend/client/func/user.py new file mode 100644 index 0000000000..ab3f086728 --- /dev/null +++ b/src/ai/backend/client/func/user.py @@ -0,0 +1,359 @@ +from __future__ import annotations + +import enum +import textwrap +from typing import ( + Iterable, + Sequence, + Union, +) +import uuid + +from ai.backend.client.auth import AuthToken, AuthTokenTypes +from ai.backend.client.request import Request +from ai.backend.client.session import api_session +from ai.backend.client.output.fields import user_fields +from ai.backend.client.output.types import FieldSpec, PaginatedResult +from ai.backend.client.pagination import generate_paginated_results +from .base import api_function, BaseFunction + +__all__ = ( + 'User', + 'UserStatus', + 'UserRole', +) + + +_default_list_fields = ( + user_fields['uuid'], + user_fields['role'], + user_fields['username'], + user_fields['email'], + user_fields['is_active'], + user_fields['created_at'], + user_fields['domain_name'], + user_fields['groups'], +) + +_default_detail_fields = ( + user_fields['uuid'], + user_fields['username'], + user_fields['email'], + user_fields['need_password_change'], + user_fields['status'], + user_fields['status_info'], + user_fields['created_at'], + user_fields['domain_name'], + user_fields['role'], + user_fields['groups'], +) + + +class UserRole(str, enum.Enum): + """ + The role (privilege level) of users. + """ + SUPERADMIN = 'superadmin' + ADMIN = 'admin' + USER = 'user' + MONITOR = 'monitor' + + +class UserStatus(enum.Enum): + """ + The detailed status of users to represent the signup process and account lifecycles. + """ + ACTIVE = 'active' + INACTIVE = 'inactive' + DELETED = 'deleted' + BEFORE_VERIFICATION = 'before-verification' + + +class User(BaseFunction): + """ + Provides interactions with users. + """ + + @api_function + @classmethod + async def authorize(cls, username: str, password: str, *, + token_type: AuthTokenTypes = AuthTokenTypes.KEYPAIR) -> AuthToken: + """ + Authorize the given credentials and get the API authentication token. + This function can be invoked anonymously; i.e., it does not require + access/secret keys in the session config as its purpose is to "get" them. + + Its functionality will be expanded in the future to support multiple types + of authentication methods. + """ + rqst = Request('POST', '/auth/authorize') + rqst.set_json({ + 'type': token_type.value, + 'domain': api_session.get().config.domain, + 'username': username, + 'password': password, + }) + async with rqst.fetch() as resp: + data = await resp.json() + return AuthToken( + type=token_type, + content=data['data'], + ) + + @api_function + @classmethod + async def list( + cls, + status: str = None, + group: str = None, + fields: Sequence[FieldSpec] = _default_list_fields, + ) -> Sequence[dict]: + """ + Fetches the list of users. Domain admins can only get domain users. 
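+
+        For example (an illustrative sketch; it assumes a keypair with the
+        required admin privilege and that the client session exposes this
+        class as an attribute):
+
+        .. code-block:: python
+
+           from ai.backend.client.session import Session
+
+           with Session() as api_sess:
+               users = api_sess.User.list(status='active')
+               emails = [u['email'] for u in users]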
+ + :param status: Fetches users in a specific status + (active, inactive, deleted, before-verification). + :param group: Fetch users in a specific group. + :param fields: Additional per-user query fields to fetch. + """ + query = textwrap.dedent("""\ + query($status: String, $group: UUID) { + users(status: $status, group_id: $group) {$fields} + } + """) + query = query.replace('$fields', ' '.join(f.field_ref for f in fields)) + variables = { + 'status': status, + 'group': group, + } + data = await api_session.get().Admin._query(query, variables) + return data['users'] + + @api_function + @classmethod + async def paginated_list( + cls, + status: str = None, + group: str = None, + *, + fields: Sequence[FieldSpec] = _default_list_fields, + page_offset: int = 0, + page_size: int = 20, + filter: str = None, + order: str = None, + ) -> PaginatedResult[dict]: + """ + Fetches the list of users. Domain admins can only get domain users. + + :param status: Fetches users in a specific status + (active, inactive, deleted, before-verification). + :param group: Fetch users in a specific group. + :param fields: Additional per-user query fields to fetch. + """ + return await generate_paginated_results( + 'user_list', + { + 'status': (status, 'String'), + 'group_id': (group, 'UUID'), + 'filter': (filter, 'String'), + 'order': (order, 'String'), + }, + fields, + page_offset=page_offset, + page_size=page_size, + ) + + @api_function + @classmethod + async def detail( + cls, + email: str = None, + fields: Sequence[FieldSpec] = _default_detail_fields, + ) -> Sequence[dict]: + """ + Fetch information of a user. If email is not specified, + requester's information will be returned. + + :param email: Email of the user to fetch. + :param fields: Additional per-user query fields to fetch. + """ + if email is None: + query = textwrap.dedent("""\ + query { + user {$fields} + } + """) + else: + query = textwrap.dedent("""\ + query($email: String) { + user(email: $email) {$fields} + } + """) + query = query.replace('$fields', ' '.join(f.field_ref for f in fields)) + variables = {'email': email} + data = await api_session.get().Admin._query(query, variables if email is not None else None) + return data['user'] + + @api_function + @classmethod + async def detail_by_uuid( + cls, + user_uuid: Union[str, uuid.UUID] = None, + fields: Sequence[FieldSpec] = _default_detail_fields, + ) -> Sequence[dict]: + """ + Fetch information of a user by user's uuid. If user_uuid is not specified, + requester's information will be returned. + + :param user_uuid: UUID of the user to fetch. + :param fields: Additional per-user query fields to fetch. + """ + if user_uuid is None: + query = textwrap.dedent("""\ + query { + user {$fields} + } + """) + else: + query = textwrap.dedent("""\ + query($user_id: ID) { + user_from_uuid(user_id: $user_id) {$fields} + } + """) + query = query.replace('$fields', ' '.join(f.field_ref for f in fields)) + variables = {'user_id': str(user_uuid)} + data = await api_session.get().Admin._query(query, variables if user_uuid is not None else None) + return data['user_from_uuid'] + + @api_function + @classmethod + async def create( + cls, + domain_name: str, + email: str, + password: str, + username: str = None, + full_name: str = None, + role: UserRole | str = UserRole.USER, + status: UserStatus | str = UserStatus.ACTIVE, + need_password_change: bool = False, + description: str = '', + group_ids: Iterable[str] = None, + fields: Iterable[str] = None, + ) -> dict: + """ + Creates a new user with the given options. 
+ You need an admin privilege for this operation. + """ + if fields is None: + fields = ('domain_name', 'email', 'username', 'uuid') + query = textwrap.dedent("""\ + mutation($email: String!, $input: UserInput!) { + create_user(email: $email, props: $input) { + ok msg user {$fields} + } + } + """) + query = query.replace('$fields', ' '.join(fields)) + variables = { + 'email': email, + 'input': { + 'password': password, + 'username': username, + 'full_name': full_name, + 'role': role.value if isinstance(role, UserRole) else role, + 'status': status.value if isinstance(status, UserStatus) else status, + 'need_password_change': need_password_change, + 'description': description, + 'domain_name': domain_name, + 'group_ids': group_ids, + }, + } + data = await api_session.get().Admin._query(query, variables) + return data['create_user'] + + @api_function + @classmethod + async def update( + cls, + email: str, + password: str = None, username: str = None, + full_name: str = None, + domain_name: str = None, + role: UserRole | str = UserRole.USER, + status: UserStatus | str = UserStatus.ACTIVE, + need_password_change: bool = None, + description: str = None, + group_ids: Iterable[str] = None, + fields: Iterable[str] = None, + ) -> dict: + """ + Update existing user. + You need an admin privilege for this operation. + """ + query = textwrap.dedent("""\ + mutation($email: String!, $input: ModifyUserInput!) { + modify_user(email: $email, props: $input) { + ok msg + } + } + """) + variables = { + 'email': email, + 'input': { + 'password': password, + 'username': username, + 'full_name': full_name, + 'domain_name': domain_name, + 'role': role.value if isinstance(role, UserRole) else role, + 'status': status.value if isinstance(status, UserStatus) else status, + 'need_password_change': need_password_change, + 'description': description, + 'group_ids': group_ids, + }, + } + data = await api_session.get().Admin._query(query, variables) + return data['modify_user'] + + @api_function + @classmethod + async def delete(cls, email: str): + """ + Inactivates an existing user. + """ + query = textwrap.dedent("""\ + mutation($email: String!) { + delete_user(email: $email) { + ok msg + } + } + """) + variables = {'email': email} + data = await api_session.get().Admin._query(query, variables) + return data['delete_user'] + + @api_function + @classmethod + async def purge(cls, email: str, purge_shared_vfolders=False): + """ + Deletes an existing user. + + User's virtual folders are also deleted, except the ones shared with other users. + Shared virtual folder's ownership will be transferred to the requested admin. + To delete shared folders as well, set ``purge_shared_vfolders`` to ``True``. + """ + query = textwrap.dedent("""\ + mutation($email: String!, $input: PurgeUserInput!) 
{ + purge_user(email: $email, props: $input) { + ok msg + } + } + """) + variables = { + 'email': email, + 'input': { + 'purge_shared_vfolders': purge_shared_vfolders, + }, + } + data = await api_session.get().Admin._query(query, variables) + return data['purge_user'] diff --git a/src/ai/backend/client/func/vfolder.py b/src/ai/backend/client/func/vfolder.py new file mode 100644 index 0000000000..0b17046e10 --- /dev/null +++ b/src/ai/backend/client/func/vfolder.py @@ -0,0 +1,473 @@ +import asyncio +from pathlib import Path +from typing import ( + Mapping, + Optional, + Sequence, + Union, +) + +import aiohttp +import janus +from tqdm import tqdm + +from yarl import URL +from aiotusclient import client + +from ai.backend.client.output.fields import vfolder_fields +from ai.backend.client.output.types import FieldSpec, PaginatedResult +from .base import api_function, BaseFunction +from ..compat import current_loop +from ..config import DEFAULT_CHUNK_SIZE, MAX_INFLIGHT_CHUNKS +from ..exceptions import BackendClientError +from ..pagination import generate_paginated_results +from ..request import Request + +__all__ = ( + 'VFolder', +) + +_default_list_fields = ( + vfolder_fields['host'], + vfolder_fields['name'], + vfolder_fields['created_at'], + vfolder_fields['creator'], + vfolder_fields['group_id'], + vfolder_fields['permission'], + vfolder_fields['ownership_type'], +) + + +class VFolder(BaseFunction): + + def __init__(self, name: str): + self.name = name + + @api_function + @classmethod + async def create( + cls, + name: str, + host: str = None, + unmanaged_path: str = None, + group: str = None, + usage_mode: str = 'general', + permission: str = 'rw', + quota: str = '0', + cloneable: bool = False, + ): + rqst = Request('POST', '/folders') + rqst.set_json({ + 'name': name, + 'host': host, + 'unmanaged_path': unmanaged_path, + 'group': group, + 'usage_mode': usage_mode, + 'permission': permission, + 'quota': quota, + 'cloneable': cloneable, + }) + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + @classmethod + async def delete_by_id(cls, oid): + rqst = Request('DELETE', '/folders') + rqst.set_json({'id': oid}) + async with rqst.fetch(): + return {} + + @api_function + @classmethod + async def list(cls, list_all=False): + rqst = Request('GET', '/folders') + rqst.set_json({'all': list_all}) + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + @classmethod + async def paginated_list( + cls, + group: str = None, + *, + fields: Sequence[FieldSpec] = _default_list_fields, + page_offset: int = 0, + page_size: int = 20, + filter: str = None, + order: str = None, + ) -> PaginatedResult[dict]: + """ + Fetches the list of vfolders. Domain admins can only get domain vfolders. + + :param group: Fetch vfolders in a specific group. + :param fields: Additional per-vfolder query fields to fetch. 
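+
+        A sketch of a paginated call (illustrative only; ``api_sess`` stands for
+        an asynchronous client session that exposes this class as an attribute):
+
+        .. code-block:: python
+
+           result = await api_sess.VFolder.paginated_list(page_size=10)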
+ """ + return await generate_paginated_results( + 'vfolder_list', + { + 'group_id': (group, 'UUID'), + 'filter': (filter, 'String'), + 'order': (order, 'String'), + }, + fields, + page_offset=page_offset, + page_size=page_size, + ) + + @api_function + @classmethod + async def list_hosts(cls): + rqst = Request('GET', '/folders/_/hosts') + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + @classmethod + async def list_all_hosts(cls): + rqst = Request('GET', '/folders/_/all_hosts') + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + @classmethod + async def list_allowed_types(cls): + rqst = Request('GET', '/folders/_/allowed_types') + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + async def info(self): + rqst = Request('GET', '/folders/{0}'.format(self.name)) + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + async def delete(self): + rqst = Request('DELETE', '/folders/{0}'.format(self.name)) + async with rqst.fetch(): + return {} + + @api_function + async def rename(self, new_name): + rqst = Request('POST', '/folders/{0}/rename'.format(self.name)) + rqst.set_json({ + 'new_name': new_name, + }) + async with rqst.fetch() as resp: + self.name = new_name + return await resp.text() + + @api_function + async def download( + self, + relative_paths: Sequence[Union[str, Path]], + *, + basedir: Union[str, Path] = None, + chunk_size: int = DEFAULT_CHUNK_SIZE, + show_progress: bool = False, + address_map: Optional[Mapping[str, str]] = None, + ) -> None: + base_path = (Path.cwd() if basedir is None else Path(basedir).resolve()) + for relpath in relative_paths: + file_path = base_path / relpath + rqst = Request('POST', + '/folders/{}/request-download'.format(self.name)) + rqst.set_json({ + 'path': str(relpath), + }) + async with rqst.fetch() as resp: + download_info = await resp.json() + overriden_url = download_info['url'] + if address_map: + if download_info['url'] in address_map: + overriden_url = address_map[download_info['url']] + else: + raise BackendClientError( + 'Overriding storage proxy addresses are given, ' + 'but no url matches with any of them.\n', + ) + + download_url = URL(overriden_url).with_query({ + 'token': download_info['token'], + }) + + def _write_file(file_path: Path, q: janus._SyncQueueProxy[bytes]): + with open(file_path, 'wb') as f: + while True: + chunk = q.get() + if not chunk: + return + f.write(chunk) + q.task_done() + + if show_progress: + print(f"Downloading to {file_path} ...") + async with aiohttp.ClientSession() as client: + # TODO: ranged requests to continue interrupted downloads with automatic retries + async with client.get(download_url, ssl=False) as raw_resp: + size = int(raw_resp.headers['Content-Length']) + if file_path.exists(): + raise RuntimeError('The target file already exists', file_path.name) + q: janus.Queue[bytes] = janus.Queue(MAX_INFLIGHT_CHUNKS) + try: + with tqdm( + total=size, + unit='bytes', + unit_scale=True, + unit_divisor=1024, + disable=not show_progress, + ) as pbar: + loop = current_loop() + writer_fut = loop.run_in_executor(None, _write_file, file_path, q.sync_q) + await asyncio.sleep(0) + while True: + chunk = await raw_resp.content.read(chunk_size) + pbar.update(len(chunk)) + if not chunk: + break + await q.async_q.put(chunk) + finally: + await q.async_q.put(b'') + await writer_fut + q.close() + await q.wait_closed() + + @api_function + async def upload( + self, + files: Sequence[Union[str, Path]], + *, + basedir: 
Union[str, Path] = None, + chunk_size: int = DEFAULT_CHUNK_SIZE, + address_map: Optional[Mapping[str, str]] = None, + show_progress: bool = False, + ) -> None: + base_path = (Path.cwd() if basedir is None else Path(basedir).resolve()) + if basedir: + files = [basedir / Path(file) for file in files] + else: + files = [Path(file).resolve() for file in files] + for file_path in files: + file_size = Path(file_path).stat().st_size + rqst = Request('POST', + '/folders/{}/request-upload'.format(self.name)) + rqst.set_json({ + 'path': "{}".format(str(Path(file_path).relative_to(base_path))), + 'size': int(file_size), + }) + async with rqst.fetch() as resp: + upload_info = await resp.json() + overriden_url = upload_info['url'] + if address_map: + if upload_info['url'] in address_map: + overriden_url = address_map[upload_info['url']] + else: + raise BackendClientError( + 'Overriding storage proxy addresses are given, ' + 'but no url matches with any of them.\n', + ) + upload_url = URL(overriden_url).with_query({ + 'token': upload_info['token'], + }) + tus_client = client.TusClient() + if basedir: + input_file = open(base_path / file_path, "rb") + else: + input_file = open(str(Path(file_path).relative_to(base_path)), "rb") + print(f"Uploading {base_path / file_path} via {upload_info['url']} ...") + # TODO: refactor out the progress bar + uploader = tus_client.async_uploader( + file_stream=input_file, + url=upload_url, + upload_checksum=False, + chunk_size=chunk_size, + ) + return await uploader.upload() + + @api_function + async def mkdir(self, path: Union[str, Path]): + rqst = Request('POST', + '/folders/{}/mkdir'.format(self.name)) + rqst.set_json({ + 'path': path, + }) + async with rqst.fetch() as resp: + return await resp.text() + + @api_function + async def rename_file(self, target_path: str, new_name: str): + rqst = Request('POST', + '/folders/{}/rename-file'.format(self.name)) + rqst.set_json({ + 'target_path': target_path, + 'new_name': new_name, + }) + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + async def move_file(self, src_path: str, dst_path: str): + rqst = Request('POST', + '/folders/{}/move-file'.format(self.name)) + rqst.set_json({ + 'src': src_path, + 'dst': dst_path, + }) + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + async def delete_files(self, + files: Sequence[Union[str, Path]], + recursive: bool = False): + rqst = Request('DELETE', + '/folders/{}/delete-files'.format(self.name)) + rqst.set_json({ + 'files': files, + 'recursive': recursive, + }) + async with rqst.fetch() as resp: + return await resp.text() + + @api_function + async def list_files(self, path: Union[str, Path] = '.'): + rqst = Request('GET', '/folders/{}/files'.format(self.name)) + rqst.set_json({ + 'path': path, + }) + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + async def invite(self, perm: str, emails: Sequence[str]): + rqst = Request('POST', '/folders/{}/invite'.format(self.name)) + rqst.set_json({ + 'perm': perm, 'user_ids': emails, + }) + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + @classmethod + async def invitations(cls): + rqst = Request('GET', '/folders/invitations/list') + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + @classmethod + async def accept_invitation(cls, inv_id: str): + rqst = Request('POST', '/folders/invitations/accept') + rqst.set_json({'inv_id': inv_id}) + async with rqst.fetch() as resp: + return await resp.json() + 
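+    # Illustrative invitation flow (a sketch only; the permission literal and the
+    # response shape of `invitations()` are assumptions, not verified here):
+    #
+    #   with Session() as api_sess:
+    #       api_sess.VFolder('mydata').invite('rw', ['someone@example.com'])
+    #       pending = api_sess.VFolder.invitations()
+    #       api_sess.VFolder.accept_invitation(inv_id=...)  # id taken from `pending`
+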
+ @api_function + @classmethod + async def delete_invitation(cls, inv_id: str): + rqst = Request('DELETE', '/folders/invitations/delete') + rqst.set_json({'inv_id': inv_id}) + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + @classmethod + async def get_fstab_contents(cls, agent_id=None): + rqst = Request('GET', '/folders/_/fstab') + rqst.set_json({ + 'agent_id': agent_id, + }) + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + @classmethod + async def get_performance_metric(cls, folder_host: str): + rqst = Request('GET', '/folders/_/perf-metric') + rqst.set_json({ + 'folder_host': folder_host, + }) + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + @classmethod + async def list_mounts(cls): + rqst = Request('GET', '/folders/_/mounts') + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + @classmethod + async def mount_host(cls, name: str, fs_location: str, options=None, + edit_fstab: bool = False): + rqst = Request('POST', '/folders/_/mounts') + rqst.set_json({ + 'name': name, + 'fs_location': fs_location, + 'options': options, + 'edit_fstab': edit_fstab, + }) + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + @classmethod + async def umount_host(cls, name: str, edit_fstab: bool = False): + rqst = Request('DELETE', '/folders/_/mounts') + rqst.set_json({ + 'name': name, + 'edit_fstab': edit_fstab, + }) + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + async def share(self, perm: str, emails: Sequence[str]): + rqst = Request('POST', '/folders/{}/share'.format(self.name)) + rqst.set_json({ + 'permission': perm, 'emails': emails, + }) + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + async def unshare(self, emails: Sequence[str]): + rqst = Request('DELETE', '/folders/{}/unshare'.format(self.name)) + rqst.set_json({ + 'emails': emails, + }) + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + async def leave(self): + rqst = Request('POST', '/folders/{}/leave'.format(self.name)) + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + async def clone(self, target_name: str, target_host: str = None, + usage_mode: str = 'general', permission: str = 'rw'): + rqst = Request('POST', '/folders/{}/clone'.format(self.name)) + rqst.set_json({ + 'target_name': target_name, + 'target_host': target_host, + 'usage_mode': usage_mode, + 'permission': permission, + }) + async with rqst.fetch() as resp: + return await resp.json() + + @api_function + async def update_options(self, name: str, permission: str = None, + cloneable: bool = None): + rqst = Request('POST', '/folders/{}/update-options'.format(self.name)) + rqst.set_json({ + 'cloneable': cloneable, + 'permission': permission, + }) + async with rqst.fetch() as resp: + return await resp.text() diff --git a/src/ai/backend/client/helper.py b/src/ai/backend/client/helper.py new file mode 100644 index 0000000000..a421dd299e --- /dev/null +++ b/src/ai/backend/client/helper.py @@ -0,0 +1,5 @@ +from .session import Session as SyncSession + + +def is_admin(session: SyncSession) -> bool: + return session.KeyPair(session.config.access_key).info()['is_admin'] diff --git a/src/ai/backend/client/load_balancing.py b/src/ai/backend/client/load_balancing.py new file mode 100644 index 0000000000..89d33d76ad --- /dev/null +++ b/src/ai/backend/client/load_balancing.py @@ -0,0 +1,70 @@ +from __future__ import 
annotations + +from abc import ABCMeta, abstractmethod +from typing import List, Mapping, Tuple, Type + +import attr +from yarl import URL + + +@attr.s(auto_attribs=True, frozen=True) +class LoadBalancerConfig(): + name: str + args: Tuple[str, ...] + + +class LoadBalancer(metaclass=ABCMeta): + + @staticmethod + def load(config: LoadBalancerConfig) -> LoadBalancer: + cls = _cls_map[config.name] + return cls(*config.args) + + @staticmethod + def clean_config(config: str) -> LoadBalancerConfig: + name, _, raw_args = config.partition(':') + args = raw_args.split(',') + return LoadBalancerConfig(name, tuple(args)) + + @abstractmethod + def rotate(self, endpoints: List[URL]) -> None: + raise NotImplementedError + + +class SimpleRRLoadBalancer(LoadBalancer): + """ + Rotates the endpoints upon every request. + """ + + def rotate(self, endpoints: List[URL]) -> None: + if len(endpoints) == 1: + return + item = endpoints.pop(0) + endpoints.append(item) + + +class PeriodicRRLoadBalancer(LoadBalancer): + """ + Rotates the endpoints upon the specified interval. + """ + + def rotate(self, endpoints: List[URL]) -> None: + pass + + +class LowestLatencyLoadBalancer(LoadBalancer): + """ + Change the endpoints with the lowest average latency for last N requests. + """ + + def rotate(self, endpoints: List[URL]) -> None: + pass + + # TODO: we need to collect and allow access to the latency statistics. + + +_cls_map: Mapping[str, Type[LoadBalancer]] = { + 'simple_rr': SimpleRRLoadBalancer, + 'periodic_rr': PeriodicRRLoadBalancer, + 'lowest_latency': LowestLatencyLoadBalancer, +} diff --git a/src/ai/backend/client/output/__init__.py b/src/ai/backend/client/output/__init__.py new file mode 100644 index 0000000000..c066829477 --- /dev/null +++ b/src/ai/backend/client/output/__init__.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ai.backend.client.cli.types import CLIContext, OutputMode + from .types import BaseOutputHandler + + +def get_output_handler(cli_ctx: CLIContext, output_mode: OutputMode) -> BaseOutputHandler: + from ai.backend.client.cli.types import OutputMode + if output_mode == OutputMode.JSON: + from .json import JsonOutputHandler + return JsonOutputHandler(cli_ctx) + elif output_mode == OutputMode.CONSOLE: + from .console import ConsoleOutputHandler + return ConsoleOutputHandler(cli_ctx) + raise RuntimeError("Invalid output handler", output_mode) diff --git a/src/ai/backend/client/output/console.py b/src/ai/backend/client/output/console.py new file mode 100644 index 0000000000..e1ce62a210 --- /dev/null +++ b/src/ai/backend/client/output/console.py @@ -0,0 +1,213 @@ +from __future__ import annotations + +import sys +from typing import ( + Any, + Callable, + Mapping, + Optional, + Sequence, +) + +from tabulate import tabulate + +from ai.backend.client.cli.pretty import print_error, print_fail +from ai.backend.client.cli.pagination import ( + echo_via_pager, + get_preferred_page_size, + tabulate_items, +) +from .types import FieldSpec, PaginatedResult, BaseOutputHandler + + +class NoItems(Exception): + pass + + +class ConsoleOutputHandler(BaseOutputHandler): + + def print_item( + self, + item: Mapping[str, Any] | None, + fields: Sequence[FieldSpec], + ) -> None: + if item is None: + print_fail("No matching entry found.") + return + field_map = {f.field_name: f for f in fields} + print(tabulate( + [ + ( + field_map[k].humanized_name, + field_map[k].formatter.format_console(v, field_map[k]), + ) + for k, v in item.items() + ], + 
headers=('Field', 'Value'), + )) + + def print_items( + self, + items: Sequence[Mapping[str, Any]], + fields: Sequence[FieldSpec], + ) -> None: + field_map = {f.field_name: f for f in fields} + for idx, item in enumerate(items): + if idx > 0: + print("-" * 20) + print(tabulate( + [ + ( + field_map[k].humanized_name, + field_map[k].formatter.format_console(v, field_map[k]), + ) + for k, v in item.items() + ], + headers=('Field', 'Value'), + )) + + def print_list( + self, + items: Sequence[Mapping[str, Any]], + fields: Sequence[FieldSpec], + *, + is_scalar: bool = False, + ) -> None: + if is_scalar: + assert len(fields) == 1 + if sys.stdout.isatty(): + + def infinite_fetch(): + current_offset = 0 + page_size = get_preferred_page_size() + while True: + if len(items) == 0: + raise NoItems + if is_scalar: + yield from map( + lambda v: {fields[0].field_name: v}, + items[current_offset:current_offset + page_size], + ) + else: + yield from items[current_offset:current_offset + page_size] + current_offset += page_size + if current_offset >= len(items): + break + + try: + echo_via_pager( + tabulate_items( + infinite_fetch(), + fields, + ), + ) + except NoItems: + print("No matching items.") + else: + if is_scalar: + for line in tabulate_items( + map(lambda v: {fields[0].field_name: v}, items), # type: ignore + fields, + ): + print(line, end="") + else: + for line in tabulate_items( + items, # type: ignore + fields, + ): + print(line, end="") + + def print_paginated_list( + self, + fetch_func: Callable[[int, int], PaginatedResult], + initial_page_offset: int, + page_size: int = None, + ) -> None: + if sys.stdout.isatty() and page_size is None: + page_size = get_preferred_page_size() + fields: Sequence[FieldSpec] = [] + + def infinite_fetch(): + nonlocal fields + current_offset = initial_page_offset + while True: + result = fetch_func(current_offset, page_size) + if result.total_count == 0: + raise NoItems + current_offset += len(result.items) + if not fields: + fields.extend(result.fields) + yield from result.items + if current_offset >= result.total_count: + break + + try: + echo_via_pager( + tabulate_items( + infinite_fetch(), + fields, + ), + ) + except NoItems: + print("No matching items.") + else: + page_size = page_size or 20 + result = fetch_func(initial_page_offset, page_size) + for line in tabulate_items( + result.items, # type: ignore + result.fields, + ): + print(line, end="") + + def print_mutation_result( + self, + item: Mapping[str, Any], + item_name: Optional[str] = None, + action_name: Optional[str] = None, + extra_info: Mapping = {}, + ) -> None: + t = [ + ['ok', item['ok']], + ['msg', item['msg']], + *[(k, v) for k, v in extra_info.items()], + ] + if action_name is not None: + t += [['Action', action_name]] + if item_name is not None: + t += [(k, v) for k, v in item[item_name].items()] + print(tabulate( + t, headers=('Field', 'Value'), + )) + + def print_mutation_error( + self, + error: Optional[Exception] = None, + msg: str = 'Failed', + item_name: Optional[str] = None, + action_name: Optional[str] = None, + extra_info: Mapping = {}, + ) -> None: + t = [ + ['Message', msg], + ] + if item_name is not None: + t += [['Item', item_name]] + if action_name is not None: + t += [['Action', action_name]] + print(tabulate( + t, headers=('Field', 'Value'), + )) + if error is not None: + print_error(error) + + def print_error( + self, + error: Exception, + ) -> None: + print_error(error) + + def print_fail( + self, + message: str, + ) -> None: + print_fail(message) diff --git 
a/src/ai/backend/client/output/fields.py b/src/ai/backend/client/output/fields.py new file mode 100644 index 0000000000..b7d91db49a --- /dev/null +++ b/src/ai/backend/client/output/fields.py @@ -0,0 +1,262 @@ +from __future__ import annotations + +from .formatters import ( + AgentStatFormatter, + GroupListFormatter, + ContainerListFormatter, + DependencyListFormatter, + SubFieldOutputFormatter, + KernelStatFormatter, + nested_dict_formatter, + mibytes_output_formatter, + resource_slot_formatter, + sizebytes_output_formatter, +) +from .types import ( + FieldSet, + FieldSpec, +) + + +container_fields = FieldSet([ + FieldSpec('id', "Kernel ID", alt_name='kernel_id'), + FieldSpec('cluster_role'), + FieldSpec('cluster_idx'), + FieldSpec('cluster_hostname'), + FieldSpec('session_id', "Session ID"), + FieldSpec('image'), + FieldSpec('registry'), + FieldSpec('status'), + FieldSpec('status_info'), + FieldSpec('status_data', formatter=nested_dict_formatter), + FieldSpec('status_changed'), + FieldSpec('agent'), + FieldSpec('container_id'), + FieldSpec('resource_opts', formatter=nested_dict_formatter), + FieldSpec('occupied_slots', formatter=resource_slot_formatter), + FieldSpec('live_stat', formatter=KernelStatFormatter()), + FieldSpec('last_stat', formatter=KernelStatFormatter()), +]) + + +agent_fields = FieldSet([ + FieldSpec('id'), + FieldSpec('status'), + FieldSpec('status_changed'), + FieldSpec('region'), + FieldSpec('architecture'), + FieldSpec('scaling_group'), + FieldSpec('schedulable'), + FieldSpec('available_slots', formatter=resource_slot_formatter), + FieldSpec('occupied_slots', formatter=resource_slot_formatter), + FieldSpec('addr'), + FieldSpec('first_contact'), + FieldSpec('lost_at'), + FieldSpec('live_stat', formatter=AgentStatFormatter()), + FieldSpec('version'), + FieldSpec('compute_plugins'), + FieldSpec('hardware_metadata', formatter=nested_dict_formatter), + FieldSpec('compute_containers', subfields=container_fields, + formatter=ContainerListFormatter()), + # legacy fields + FieldSpec('cpu_cur_pct', 'CPU Usage (%)'), + FieldSpec('mem_cur_bytes', 'Used Memory (MiB)', formatter=mibytes_output_formatter), +]) + +domain_fields = FieldSet([ + FieldSpec('name'), + FieldSpec('description'), + FieldSpec('is_active'), + FieldSpec('created_at'), + FieldSpec('total_resource_slots', formatter=resource_slot_formatter), + FieldSpec('allowed_vfolder_hosts'), + FieldSpec('allowed_docker_registries'), + FieldSpec('integration_id'), +]) + +group_fields = FieldSet([ + FieldSpec('id'), + FieldSpec('name'), + FieldSpec('description'), + FieldSpec('is_active'), + FieldSpec('created_at'), + FieldSpec('domain_name'), + FieldSpec('total_resource_slots', formatter=resource_slot_formatter), + FieldSpec('allowed_vfolder_hosts'), + FieldSpec('integration_id'), +]) + + +image_fields = FieldSet([ + FieldSpec('name'), + FieldSpec('registry'), + FieldSpec('architecture'), + FieldSpec('tag'), + FieldSpec('digest'), + FieldSpec('size_bytes', formatter=sizebytes_output_formatter), + FieldSpec('aliases'), +]) + + +keypair_fields = FieldSet([ + FieldSpec('user_id', "Email"), + FieldSpec('user_info { full_name }', "Full Name", alt_name='full_name', + formatter=SubFieldOutputFormatter("full_name")), + FieldSpec('access_key'), + FieldSpec('secret_key'), + FieldSpec('is_active'), + FieldSpec('is_admin'), + FieldSpec('created_at'), + FieldSpec('modified_at'), + FieldSpec('last_used'), + FieldSpec('resource_policy'), + FieldSpec('rate_limit'), + FieldSpec('concurrency_used'), + FieldSpec('ssh_public_key'), + 
FieldSpec('ssh_private_key'), + FieldSpec('dotfiles'), + FieldSpec('bootstrap_script'), +]) + + +keypair_resource_policy_fields = FieldSet([ + FieldSpec('name'), + FieldSpec('created_at'), + FieldSpec('total_resource_slots'), + FieldSpec('max_concurrent_sessions'), # formerly concurrency_limit + FieldSpec('max_vfolder_count'), + FieldSpec('max_vfolder_size', formatter=sizebytes_output_formatter), + FieldSpec('idle_timeout'), + FieldSpec('max_containers_per_session'), + FieldSpec('allowed_vfolder_hosts'), +]) + + +scaling_group_fields = FieldSet([ + FieldSpec('name'), + FieldSpec('description'), + FieldSpec('is_active'), + FieldSpec('created_at'), + FieldSpec('driver'), + FieldSpec('driver_opts', formatter=nested_dict_formatter), + FieldSpec('scheduler'), + FieldSpec('scheduler_opts', formatter=nested_dict_formatter), +]) + + +session_fields = FieldSet([ + FieldSpec('id', "Kernel ID", alt_name='kernel_id'), + FieldSpec('tag'), + FieldSpec('name'), + FieldSpec('type'), + FieldSpec('session_id', "Session ID"), + FieldSpec('image'), + FieldSpec('registry'), + FieldSpec('cluster_template'), + FieldSpec('cluster_mode'), + FieldSpec('cluster_size'), + FieldSpec('domain_name'), + FieldSpec('group_name', "Project/Group"), + FieldSpec('group_id'), + FieldSpec('user_email'), + FieldSpec('user_id'), + FieldSpec('access_key', "Owner Access Key"), + FieldSpec('created_user_email'), + FieldSpec('created_user_id'), + FieldSpec('status'), + FieldSpec('status_info'), + FieldSpec('status_data', formatter=nested_dict_formatter), + FieldSpec('status_changed', "Last Updated"), + FieldSpec('created_at'), + FieldSpec('terminated_at'), + FieldSpec('starts_at'), + FieldSpec('startup_command'), + FieldSpec('result'), + FieldSpec('resoucre_opts', formatter=nested_dict_formatter), + FieldSpec('scaling_group'), + FieldSpec('service_ports', formatter=nested_dict_formatter), + FieldSpec('mounts'), + FieldSpec('occupied_slots', formatter=resource_slot_formatter), + FieldSpec( + 'containers', + subfields=container_fields, + formatter=ContainerListFormatter(), + ), + FieldSpec( + 'dependencies { name id }', + formatter=DependencyListFormatter(), + ), +]) + +session_fields_v5 = FieldSet([ + FieldSpec( + 'containers', + subfields=FieldSet([ + FieldSpec('id', "Kernel ID", alt_name='kernel_id'), + FieldSpec('session_id', "Session ID"), + FieldSpec('role'), + FieldSpec('agent'), + FieldSpec('image'), + FieldSpec('status'), + FieldSpec('status_info'), + FieldSpec('status_data', formatter=nested_dict_formatter), + FieldSpec('status_changed'), + FieldSpec('occupied_slots', formatter=resource_slot_formatter), + FieldSpec('live_stat', formatter=KernelStatFormatter()), + FieldSpec('last_stat', formatter=KernelStatFormatter()), + ]), + formatter=ContainerListFormatter(), + ), +]) + + +storage_fields = FieldSet([ + FieldSpec('id'), + FieldSpec('backend'), + FieldSpec('fsprefix'), + FieldSpec('path'), + FieldSpec('capabilities'), + FieldSpec('hardware_metadata', formatter=nested_dict_formatter), + FieldSpec('performance_metric', formatter=nested_dict_formatter), + FieldSpec('usage', formatter=nested_dict_formatter), +]) + + +user_fields = FieldSet([ + FieldSpec('uuid'), + FieldSpec('username'), + FieldSpec('email'), + # password is not queriable! 
+ FieldSpec('need_password_change'), + FieldSpec('full_name'), + FieldSpec('description'), + FieldSpec('is_active'), + FieldSpec('status'), + FieldSpec('status_info'), + FieldSpec('created_at'), + FieldSpec('modified_at'), + FieldSpec('domain_name'), + FieldSpec('role'), + FieldSpec('groups { id name }', formatter=GroupListFormatter()), +]) + + +vfolder_fields = FieldSet([ + FieldSpec('id'), + FieldSpec('host'), + FieldSpec('name'), + FieldSpec('user', alt_name='user_id'), + FieldSpec('group', alt_name='group_id'), + FieldSpec('creator'), + FieldSpec('unmanaged_path'), + FieldSpec('usage_mode'), + FieldSpec('permission'), + FieldSpec('ownership_type'), + FieldSpec('max_files'), + FieldSpec('max_size'), + FieldSpec('created_at'), + FieldSpec('last_used'), + FieldSpec('num_files'), + FieldSpec('cur_size'), + FieldSpec('cloneable'), +]) diff --git a/src/ai/backend/client/output/formatters.py b/src/ai/backend/client/output/formatters.py new file mode 100644 index 0000000000..c4ad02ffd2 --- /dev/null +++ b/src/ai/backend/client/output/formatters.py @@ -0,0 +1,305 @@ +from __future__ import annotations + +import decimal +import json +import textwrap +from typing import ( + Any, + Mapping, + Optional, +) + +import humanize + +from .types import AbstractOutputFormatter, FieldSpec + + +def format_stats(raw_stats: Optional[str], indent='') -> str: + if raw_stats is None: + return "(unavailable)" + stats = json.loads(raw_stats) + text = "\n".join(f"- {k + ': ':18s}{v}" for k, v in stats.items()) + return "\n" + textwrap.indent(text, indent) + + +def format_multiline(value: Any, indent_length: int) -> str: + buf = [] + for idx, line in enumerate(str(value).strip().splitlines()): + if idx == 0: + buf.append(line) + else: + buf.append((" " * indent_length) + line) + return "\n".join(buf) + + +def format_nested_dicts(value: Mapping[str, Mapping[str, Any]]) -> str: + """ + Format a mapping from string keys to sub-mappings. + """ + rows = [] + if not value: + rows.append("(empty)") + else: + for outer_key, outer_value in value.items(): + if isinstance(outer_value, dict): + if outer_value: + rows.append(f"+ {outer_key}") + inner_rows = format_nested_dicts(outer_value) + rows.append(textwrap.indent(inner_rows, prefix=" ")) + else: + rows.append(f"+ {outer_key}: (empty)") + else: + if outer_value is None: + rows.append(f"- {outer_key}: (null)") + else: + rows.append(f"- {outer_key}: {outer_value}") + return "\n".join(rows) + + +def format_value(value: Any) -> str: + if value is None: + return "(null)" + if isinstance(value, (dict, list, set)) and not value: + return "(empty)" + return str(value) + + +class OutputFormatter(AbstractOutputFormatter): + """ + The base implementation of output formats. 
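+
+ For example, with the default rendering rules below, ``None`` is shown as
+ ``(null)``, empty containers as ``(empty)``, and nested dicts and lists are
+ rendered recursively as ``{key: value, ...}`` and ``[...]`` on the console.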
+ """ + + def format_console(self, value: Any, field: FieldSpec) -> str: + if value is None: + return "(null)" + if isinstance(value, (dict, list, set)) and not value: + return "(empty)" + elif isinstance(value, dict): + return "{" \ + + ", ".join(f"{k}: {self.format_console(v, field)}" for k, v in value.items()) \ + + "}" + elif isinstance(value, (list, tuple, set)): + return "[" \ + + ", ".join(self.format_console(v, field) for v in value) \ + + "]" + return str(value) + + def format_json(self, value: Any, field: FieldSpec) -> Any: + if value is None: + return None + if isinstance(value, decimal.Decimal): + return str(value) + elif isinstance(value, dict): + return {k: self.format_json(v, field) for k, v in value.items()} + elif isinstance(value, (list, tuple)): + return [self.format_json(v, field) for v in value] + return value + + +class NestedDictOutputFormatter(OutputFormatter): + + def format_console(self, value: Any, field: FieldSpec) -> str: + if value is None: + return "(null)" + value = json.loads(value) + return format_nested_dicts(value) + + def format_json(self, value: Any, field: FieldSpec) -> Any: + if value is None: + return None + return json.loads(value) + + +class MiBytesOutputFormatter(OutputFormatter): + + def format_console(self, value: Any, field: FieldSpec) -> str: + value = round(value / 2 ** 20, 1) + return super().format_console(value, field) + + def format_json(self, value: Any, field: FieldSpec) -> Any: + value = round(value / 2 ** 20, 1) + return super().format_json(value, field) + + +class SizeBytesOutputFormatter(OutputFormatter): + + def format_console(self, value: Any, field: FieldSpec) -> str: + value = humanize.naturalsize(value, binary=True) + return super().format_console(value, field) + + def format_json(self, value: Any, field: FieldSpec) -> Any: + value = humanize.naturalsize(value, binary=True) + return super().format_json(value, field) + + +class SubFieldOutputFormatter(OutputFormatter): + + def __init__(self, subfield_name: str) -> None: + self._subfield_name = subfield_name + + def format_console(self, value: Any, field: FieldSpec) -> str: + return super().format_console(value[self._subfield_name], field) + + def format_json(self, value: Any, field: FieldSpec) -> Any: + return super().format_json(value[self._subfield_name], field) + + +class ResourceSlotFormatter(OutputFormatter): + + def format_console(self, value: Any, field: FieldSpec) -> str: + value = json.loads(value) + if mem := value.get('mem'): + value['mem'] = humanize.naturalsize(mem, binary=True, gnu=True) + return ", ".join( + f"{k}:{v}" for k, v in value.items() + ) + + def format_json(self, value: Any, field: FieldSpec) -> Any: + return json.loads(value) + + +default_output_formatter = OutputFormatter() +nested_dict_formatter = NestedDictOutputFormatter() +mibytes_output_formatter = MiBytesOutputFormatter() +resource_slot_formatter = ResourceSlotFormatter() +sizebytes_output_formatter = SizeBytesOutputFormatter() + + +class AgentStatFormatter(OutputFormatter): + + def format_console(self, value: Any, field: FieldSpec) -> str: + raw_stats = json.loads(value) + + value_formatters = { + 'bytes': lambda m: "{} / {}".format( + humanize.naturalsize(int(m['current']), binary=True), + humanize.naturalsize(int(m['capacity']), binary=True), + ), + 'Celsius': lambda m: "{:,} C".format( + float(m['current']), + ), + 'bps': lambda m: "{}/s".format( + humanize.naturalsize(float(m['current'])), + ), + 'pct': lambda m: "{} %".format( + m['pct'], + ), + } + + def format_value(metric): + formatter 
= value_formatters.get( + metric['unit_hint'], + lambda m: "{} / {} {}".format( + m['current'], + m['capacity'], + m['unit_hint'], + ), + ) + return formatter(metric) + + bufs = [] + node_metric_bufs = [] + for stat_key, metric in raw_stats['node'].items(): + if stat_key == 'cpu_util': + num_cores = len(raw_stats['devices']['cpu_util']) + if metric['pct'] is None: + node_metric_bufs.append(f"{stat_key}: (calculating...) % ({num_cores} cores)") + else: + node_metric_bufs.append(f"{stat_key}: {metric['pct']} % ({num_cores} cores)") + else: + node_metric_bufs.append(f"{stat_key}: {format_value(metric)}") + bufs.append(", ".join(node_metric_bufs)) + dev_metric_bufs = [] + for stat_key, per_dev_metric in raw_stats['devices'].items(): + dev_metric_bufs.append(f"+ {stat_key}") + if stat_key == 'cpu_util' and len(per_dev_metric) > 8: + dev_metric_bufs.append( + " - (per-core stats hidden for large CPUs with more than 8 cores)", + ) + else: + for dev_id, metric in per_dev_metric.items(): + dev_metric_bufs.append( + f" - {dev_id}: {format_value(metric)}", + ) + bufs.append("\n".join(dev_metric_bufs)) + return '\n'.join(bufs) + + def format_json(self, value: Any, field: FieldSpec) -> Any: + # TODO: improve + return self.format_console(value, field) + + +class GroupListFormatter(OutputFormatter): + + def format_console(self, value: Any, field: FieldSpec) -> str: + return ", ".join(g['name'] for g in value) + + def format_json(self, value: Any, field: FieldSpec) -> Any: + return value + + +class KernelStatFormatter(OutputFormatter): + + def format_console(self, value: Any, field: FieldSpec) -> str: + return format_stats(value) + + def format_json(self, value: Any, field: FieldSpec) -> Any: + return value + + +class NestedObjectFormatter(OutputFormatter): + + def format_json(self, value: Any, field: FieldSpec) -> Any: + assert isinstance(value, list) + return [ + { + f.alt_name: f.formatter.format_json(item[f.field_name], f) + for f in field.subfields.values() + } + for item in value + ] + + +def _fit_multiline_in_cell(text: str, indent: str) -> str: + if '\n' in text: + return '\n' + textwrap.indent(text, indent) + else: + return text + + +class ContainerListFormatter(NestedObjectFormatter): + + def format_console(self, value: Any, field: FieldSpec, indent='') -> str: + assert isinstance(value, list) + if len(value) == 0: + text = "(no sub-containers belonging to the session)" + else: + text = "" + for item in value: + text += f"+ {item['id']}\n" + text += "\n".join( + f" - {f.humanized_name}: " + f"{_fit_multiline_in_cell(f.formatter.format_console(item[f.field_name], f), ' ')}" # noqa + for f in field.subfields.values() + if f.field_name != "id" + ) + return textwrap.indent(text, indent) + + +class DependencyListFormatter(NestedObjectFormatter): + + def format_console(self, value: Any, field: FieldSpec, indent='') -> str: + assert isinstance(value, list) + if len(value) == 0: + text = "(no dependency tasks)" + else: + text = "" + for item in value: + text += f"+ {item['name']} ({item['id']})\n" + text += "\n".join( + f" - {f.humanized_name}: " + f"{_fit_multiline_in_cell(f.formatter.format_console(item[f.field_name], f), ' ')}" # noqa + for f in field.subfields.values() + if f.field_name not in ("id", "name") + ) + return textwrap.indent(text, indent) diff --git a/src/ai/backend/client/output/json.py b/src/ai/backend/client/output/json.py new file mode 100644 index 0000000000..f54d93b6c9 --- /dev/null +++ b/src/ai/backend/client/output/json.py @@ -0,0 +1,192 @@ +from __future__ import annotations + 
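+ # Note: the item/list printers below emit a common JSON envelope of the form
+ #   {"count": ..., "total_count": ..., "items": [...]}
+ # (a single item is wrapped as a one-element "items" list) so that scripted
+ # consumers can parse single-item and list outputs uniformly.
+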
+import json +from typing import ( + Any, + Callable, + Mapping, + Optional, + Sequence, + TypeVar, +) + +from .types import BaseOutputHandler, PaginatedResult, FieldSpec + +_json_opts: Mapping[str, Any] = {"indent": 2} + +T = TypeVar('T') + + +class JsonOutputHandler(BaseOutputHandler): + + def print_item( + self, + item: Mapping[str, Any] | None, + fields: Sequence[FieldSpec], + ) -> None: + if item is None: + print(json.dumps({ + "count": 0, + "total_count": 0, + "items": [], + })) + return + field_map = {f.field_name: f for f in fields} + print(json.dumps( + { + "count": 1, + "total_count": 1, + "items": [ + { + field_map[k].alt_name: field_map[k].formatter.format_json(v, field_map[k]) + for k, v in item.items() + }, + ], + }, + **_json_opts, + )) + + def print_items( + self, + items: Sequence[Mapping[str, Any]], + fields: Sequence[FieldSpec], + ) -> None: + field_map = {f.field_name: f for f in fields} + print(json.dumps( + { + "count": len(items), + "total_count": len(items), + "items": [ + { + field_map[k].alt_name: field_map[k].formatter.format_json(v, field_map[k]) + for k, v in item.items() + } for item in items + ], + }, + **_json_opts, + )) + + def print_list( + self, + items: Sequence[Mapping[str, Any]], + fields: Sequence[FieldSpec], + *, + is_scalar: bool = False, + ) -> None: + if is_scalar: + assert len(fields) == 1 + item_list = [ + { + fields[0].alt_name: fields[0].formatter.format_json(item, fields[0]), + } + for item in items + ] + else: + field_map = {f.field_name: f for f in fields} + item_list = [ + { + field_map[k].alt_name: field_map[k].formatter.format_json(v, field_map[k]) + for k, v in item.items() + } + for item in items + ] + print(json.dumps( + { + "count": len(items), + "total_count": len(items), + "items": item_list, + }, + **_json_opts, + )) + + def print_paginated_list( + self, + fetch_func: Callable[[int, int], PaginatedResult], + initial_page_offset: int, + page_size: int = None, + ) -> None: + page_size = page_size or 20 + result = fetch_func(initial_page_offset, page_size) + field_map = {f.field_name: f for f in result.fields} + print(json.dumps( + { + "count": len(result.items), + "total_count": result.total_count, + "items": [ + { + field_map[k].alt_name: field_map[k].formatter.format_json(v, field_map[k]) + for k, v in item.items() + } + for item in result.items + ], + }, + **_json_opts, + )) + + def print_mutation_result( + self, + item: Mapping[str, Any], + item_name: Optional[str] = None, + action_name: Optional[str] = None, + extra_info: Mapping = {}, + ) -> None: + data = { + 'ok': item.get('ok', False), + 'msg': item.get('msg', 'Failed'), + **extra_info, + } + if item_name is not None and item_name in item: + data = { + **data, + item_name: { + k: v for k, v in item[item_name].items() + }, + } + print(json.dumps( + data, + **_json_opts, + )) + + def print_mutation_error( + self, + error: Optional[Exception] = None, + msg: str = 'Failed', + item_name: Optional[str] = None, + action_name: Optional[str] = None, + extra_info: Mapping = {}, + ) -> None: + data = { + 'ok': False, + 'msg': msg, + 'item_name': item_name, + 'action_name': action_name, + **extra_info, + } + if error is not None: + data['error'] = str(error) + print(json.dumps( + data, + **_json_opts, + )) + + def print_error( + self, + error: Exception, + ) -> None: + print(json.dumps( + { + "error": str(error), + }, + **_json_opts, + )) + + def print_fail( + self, + message: str, + ) -> None: + print(json.dumps( + { + "error": message, + }, + **_json_opts, + )) diff --git 
a/src/ai/backend/client/output/types.py b/src/ai/backend/client/output/types.py new file mode 100644 index 0000000000..7f0299b5d0 --- /dev/null +++ b/src/ai/backend/client/output/types.py @@ -0,0 +1,204 @@ +from __future__ import annotations + +from abc import ABCMeta, abstractmethod +from collections import UserDict +from typing import ( + Any, + Callable, + Generic, + Mapping, + Optional, + Sequence, + TypeVar, + TYPE_CHECKING, +) + +import attr + +if TYPE_CHECKING: + from ai.backend.client.cli.types import CLIContext + + +_predefined_humanized_field_names = { + "id": "ID", + "uuid": "UUID", + "group_id": "Group ID", + "user_id": "User ID", + "resource_policy": "Res.Policy", + "concurrency_used": "Concur.Used", + "fsprefix": "FS Prefix", + "hardware_metadata": "HW Metadata", + "performance_metric": "Perf.Metric", +} + + +def _make_camel_case(name: str) -> str: + return " ".join( + map(lambda s: s[0].upper() + s[1:], name.split("_")), + ) + + +class AbstractOutputFormatter(metaclass=ABCMeta): + """ + The base implementation of output formats. + """ + + @abstractmethod + def format_console(self, value: Any, field: FieldSpec) -> str: + raise NotImplementedError + + @abstractmethod + def format_json(self, value: Any, field: FieldSpec) -> Any: + raise NotImplementedError + + +@attr.define(slots=True, frozen=True) +class FieldSpec: + """ + The specification on how to represent a GraphQL object field + in the functional API and CLI output handlers. + + Attributes: + field_ref: The string to be interpolated inside GraphQL queries. + It may contain sub-fields if the queried field supports. + humanized_name: The string to be shown as the field name by the console formatter. + If not set, it's auto-generated from field_name by camel-casing it and checking + a predefined humanization mapping. + field_name: The exact field name slug. If not set, it's taken from field_ref. + alt_name: The field name slug to refer the field inside a FieldSet object hosting + this FieldSpec instance. + formatter: The formatter instance which provide per-output-type format methods. + (console and json) + subfields: A FieldSet instance to represent sub-fields in the GraphQL schema. + If set, field_ref is Automatically updated to have the braced subfield list + for actual GraphQL queries. + """ + + field_ref: str = attr.field() + humanized_name: str = attr.field() + field_name: str = attr.field() + alt_name: str = attr.field() + formatter: AbstractOutputFormatter = attr.field() + subfields: FieldSet = attr.field(factory=lambda: FieldSet([])) + + def __attrs_post_init__(self) -> None: + if self.subfields: + subfields = " ".join(f.field_ref for f in self.subfields.values()) + object.__setattr__(self, 'field_ref', f"{self.field_name} {{ {subfields} }}") + + @humanized_name.default + def _autogen_humanized_name(self) -> str: + # to handle cases like "groups { id name }", "user_info { full_name }" + field_name = self.field_ref.partition(" ")[0] + if h := _predefined_humanized_field_names.get(field_name): + return h + if field_name.startswith("is_"): + return _make_camel_case(field_name[3:]) + "?" 
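+ # e.g., "is_active" -> "Active?" via the branch above; any other name falls
+ # through to plain camel-casing below, e.g., "status_info" -> "Status Info".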
+ return _make_camel_case(field_name) + + @field_name.default + def _default_field_name(self) -> str: + return self.field_ref.partition(" ")[0] + + @alt_name.default + def _default_alt_name(self) -> str: + return self.field_ref.partition(" ")[0] + + @formatter.default + def _default_formatter(self) -> AbstractOutputFormatter: + from .formatters import default_output_formatter # avoid circular import + return default_output_formatter + + +class FieldSet(UserDict, Mapping[str, FieldSpec]): + + def __init__(self, fields: Sequence[FieldSpec]) -> None: + super().__init__({ + f.alt_name: f for f in fields + }) + + +T = TypeVar('T') + + +@attr.define(slots=True) +class PaginatedResult(Generic[T]): + total_count: int + items: Sequence[T] + fields: Sequence[FieldSpec] + + +class BaseOutputHandler(metaclass=ABCMeta): + + def __init__(self, cli_context: CLIContext) -> None: + self.ctx = cli_context + + @abstractmethod + def print_item( + self, + item: Mapping[str, Any] | None, + fields: Sequence[FieldSpec], + ) -> None: + raise NotImplementedError + + @abstractmethod + def print_items( + self, + items: Sequence[Mapping[str, Any]], + fields: Sequence[FieldSpec], + ) -> None: + raise NotImplementedError + + @abstractmethod + def print_list( + self, + items: Sequence[Mapping[str, Any]], + fields: Sequence[FieldSpec], + *, + is_scalar: bool = False, + ) -> None: + raise NotImplementedError + + @abstractmethod + def print_paginated_list( + self, + fetch_func: Callable[[int, int], PaginatedResult[T]], + initial_page_offset: int, + page_size: int = None, + ) -> None: + raise NotImplementedError + + @abstractmethod + def print_mutation_result( + self, + item: Mapping[str, Any], + item_name: Optional[str] = None, + action_name: Optional[str] = None, + extra_info: Mapping = {}, + ) -> None: + raise NotImplementedError + + @abstractmethod + def print_mutation_error( + self, + error: Optional[Exception] = None, + msg: str = 'Failed', + item_name: Optional[str] = None, + action_name: Optional[str] = None, + extra_info: Mapping = {}, + ) -> None: + raise NotImplementedError + + @abstractmethod + def print_error( + self, + error: Exception, + ) -> None: + raise NotImplementedError + + @abstractmethod + def print_fail( + self, + message: str, + ) -> None: + raise NotImplementedError diff --git a/src/ai/backend/client/pagination.py b/src/ai/backend/client/pagination.py new file mode 100644 index 0000000000..6ca9351079 --- /dev/null +++ b/src/ai/backend/client/pagination.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +import textwrap +from typing import ( + Any, + Dict, + Final, + Sequence, + Tuple, + TypeVar, +) + +from .output.types import FieldSpec, PaginatedResult +from .exceptions import BackendAPIVersionError +from .session import api_session + +MAX_PAGE_SIZE: Final = 100 + +T = TypeVar('T') + + +async def execute_paginated_query( + root_field: str, + variables: Dict[str, Tuple[Any, str]], + fields: Sequence[FieldSpec], + *, + limit: int, + offset: int, +) -> PaginatedResult: + if limit > MAX_PAGE_SIZE: + raise ValueError(f"The page size cannot exceed {MAX_PAGE_SIZE}") + query = ''' + query($limit:Int!, $offset:Int!, $var_decls) { + $root_field( + limit:$limit, offset:$offset, $var_args) { + items { $fields } + total_count + } + }''' + query = query.replace('$root_field', root_field) + query = query.replace('$fields', ' '.join(f.field_ref for f in fields)) + query = query.replace( + '$var_decls', + ', '.join(f'${key}: {value[1]}' + for key, value in variables.items()), + ) + query = query.replace( + 
'$var_args', + ', '.join(f'{key}:${key}' + for key in variables.keys()), + ) + query = textwrap.dedent(query).strip() + var_values = {key: value[0] for key, value in variables.items()} + var_values['limit'] = limit + var_values['offset'] = offset + data = await api_session.get().Admin._query(query, var_values) + return PaginatedResult( + total_count=data[root_field]['total_count'], + items=data[root_field]['items'], + fields=fields, + ) + + +async def generate_paginated_results( + root_field: str, + variables: Dict[str, Tuple[Any, str]], + fields: Sequence[FieldSpec], + *, + page_offset: int, + page_size: int, +) -> PaginatedResult: + if page_size > MAX_PAGE_SIZE: + raise ValueError(f"The page size cannot exceed {MAX_PAGE_SIZE}") + if api_session.get().api_version < (6, '20210815'): + if variables['filter'][0] is not None or variables['order'][0] is not None: + raise BackendAPIVersionError( + "filter and order arguments for paginated lists require v6.20210815 or later.", + ) + # should remove to work with older managers + variables.pop('filter') + variables.pop('order') + offset = page_offset + while True: + limit = page_size + result = await execute_paginated_query( + root_field, variables, fields, + limit=limit, offset=offset, + ) + offset += page_size + return result diff --git a/src/ai/backend/client/py.typed b/src/ai/backend/client/py.typed new file mode 100644 index 0000000000..48cdce8528 --- /dev/null +++ b/src/ai/backend/client/py.typed @@ -0,0 +1 @@ +placeholder diff --git a/src/ai/backend/client/request.py b/src/ai/backend/client/request.py new file mode 100644 index 0000000000..2aaaf7ea1d --- /dev/null +++ b/src/ai/backend/client/request.py @@ -0,0 +1,885 @@ +from __future__ import annotations + +import asyncio +from collections import OrderedDict, namedtuple +from datetime import datetime +from decimal import Decimal +import functools +import io +import logging +import json as modjson +from pathlib import Path +import sys +from typing import ( + Any, Callable, Optional, Union, + Awaitable, AsyncIterator, Type, TypeVar, + Mapping, Sequence, List, + cast, +) + +import aiohttp +from aiohttp.client import _RequestContextManager, _WSRequestContextManager +import aiohttp.web +import appdirs +import attr +from dateutil.tz import tzutc +from multidict import CIMultiDict +from yarl import URL + +from .auth import generate_signature +from .exceptions import BackendClientError, BackendAPIError +from .session import BaseSession, Session as SyncSession, AsyncSession, api_session + +log = logging.getLogger('ai.backend.client.request') + +__all__ = [ + 'Request', + 'BaseResponse', + 'Response', + 'WebSocketResponse', + 'SSEResponse', + 'FetchContextManager', + 'WebSocketContextManager', + 'SSEContextManager', + 'AttachedFile', +] + + +RequestContent = Union[ + bytes, bytearray, str, + aiohttp.StreamReader, + io.IOBase, + None, +] +""" +The type alias for the set of allowed types for request content. +""" + + +AttachedFile = namedtuple('AttachedFile', 'filename stream content_type') +""" +A struct that represents an attached file to the API request. + +:param str filename: The name of file to store. It may include paths + and the server will create parent directories + if required. + +:param Any stream: A file-like object that allows stream-reading bytes. + +:param str content_type: The content type for the stream. For arbitrary + binary data, use "application/octet-stream". 
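+
+ Example (an illustrative sketch; the request path and file name are hypothetical):
+
+ .. code-block:: python3
+
+    rqst = Request('POST', '/some/upload-endpoint')
+    rqst.attach_files([
+        AttachedFile('report.csv', open('report.csv', 'rb'), 'text/csv'),
+    ])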
+""" + + +_T = TypeVar('_T') + + +async def _coro_return(val: _T) -> _T: + return val + + +class ExtendedJSONEncoder(modjson.JSONEncoder): + + def default(self, obj: Any) -> Any: + if isinstance(obj, Path): + return str(obj) + if isinstance(obj, Decimal): + return str(obj) + return super().default(obj) + + +class Request: + """ + The API request object. + """ + + __slots__ = ( + 'config', 'session', + 'method', 'path', + 'date', 'headers', + 'params', 'content_type', + 'api_version', + '_content', '_attached_files', + 'reporthook', + ) + + _content: RequestContent + _attached_files: Optional[Sequence[AttachedFile]] + + date: Optional[datetime] + api_version: str + + _allowed_methods = frozenset([ + 'GET', 'HEAD', 'POST', + 'PUT', 'PATCH', 'DELETE', + 'OPTIONS']) + + def __init__( + self, + method: str = 'GET', + path: str = None, + content: RequestContent = None, *, + content_type: str = None, + params: Mapping[str, Union[str, int]] = None, + reporthook: Callable = None, + override_api_version: str = None, + ) -> None: + """ + Initialize an API request. + + :param BaseSession session: The session where this request is executed on. + + :param str path: The query path. When performing requests, the version number + prefix will be automatically perpended if required. + + :param RequestContent content: The API query body which will be encoded as + JSON. + + :param str content_type: Explicitly set the content type. See also + :func:`Request.set_content`. + """ + self.session = api_session.get() + self.config = self.session.config + self.method = method + if path is not None and path.startswith('/'): + path = path[1:] + self.path = path + self.params = params + self.date = None + if override_api_version: + self.api_version = override_api_version + else: + self.api_version = f"v{self.session.api_version[0]}.{self.session.api_version[1]}" + self.headers = CIMultiDict([ + ('User-Agent', self.config.user_agent), + ('X-BackendAI-Domain', self.config.domain), + ('X-BackendAI-Version', self.api_version), + ]) + self._content = b'' + self._attached_files = None + self.set_content(content, content_type=content_type) + self.reporthook = reporthook + + @property + def content(self) -> RequestContent: + """ + Retrieves the content in the original form. + Private codes should NOT use this as it incurs duplicate + encoding/decoding. + """ + return self._content + + def set_content( + self, + value: RequestContent, + *, + content_type: str = None, + ) -> None: + """ + Sets the content of the request. + """ + assert self._attached_files is None, \ + 'cannot set content because you already attached files.' + guessed_content_type = 'application/octet-stream' + if value is None: + guessed_content_type = 'text/plain' + self._content = b'' + elif isinstance(value, str): + guessed_content_type = 'text/plain' + self._content = value.encode('utf-8') + else: + guessed_content_type = 'application/octet-stream' + self._content = value + self.content_type = (content_type if content_type is not None + else guessed_content_type) + + def set_json(self, value: Any) -> None: + """ + A shortcut for set_content() with JSON objects. + """ + self.set_content(modjson.dumps(value, cls=ExtendedJSONEncoder), + content_type='application/json') + + def attach_files(self, files: Sequence[AttachedFile]) -> None: + """ + Attach a list of files represented as AttachedFile. + """ + assert not self._content, 'content must be empty to attach files.' 
+ self.content_type = 'multipart/form-data' + self._attached_files = files + + def _sign( + self, + rel_url: URL, + access_key: str = None, + secret_key: str = None, + hash_type: str = None, + ) -> None: + """ + Calculates the signature of the given request and adds the + Authorization HTTP header. + It should be called at the very end of request preparation and before + sending the request to the server. + """ + if access_key is None: + access_key = self.config.access_key + if secret_key is None: + secret_key = self.config.secret_key + if hash_type is None: + hash_type = self.config.hash_type + assert self.date is not None + if self.config.endpoint_type == 'api': + hdrs, _ = generate_signature( + method=self.method, + version=self.api_version, + endpoint=self.config.endpoint, + date=self.date, + rel_url=str(rel_url), + content_type=self.content_type, + access_key=access_key, + secret_key=secret_key, + hash_type=hash_type, + ) + self.headers.update(hdrs) + elif self.config.endpoint_type == 'session': + local_state_path = Path(appdirs.user_state_dir('backend.ai', 'Lablup')) + try: + cookie_jar = cast(aiohttp.CookieJar, self.session.aiohttp_session.cookie_jar) + cookie_jar.load(local_state_path / 'cookie.dat') + except (IOError, PermissionError): + pass + else: + raise ValueError('unsupported endpoint type') + + def _pack_content(self) -> Union[RequestContent, aiohttp.FormData]: + if self._attached_files is not None: + data = aiohttp.FormData() + for f in self._attached_files: + data.add_field('src', + f.stream, + filename=f.filename, + content_type=f.content_type) + assert data.is_multipart, 'Failed to pack files as multipart.' + # Let aiohttp fill up the content-type header including + # multipart boundaries. + self.headers.pop('Content-Type', None) + return data + else: + return self._content + + def _build_url(self) -> URL: + base_url = self.config.endpoint.path.rstrip('/') + query_path = self.path.lstrip('/') if self.path is not None and len(self.path) > 0 else '' + if self.config.endpoint_type == 'session': + if not query_path.startswith('server'): + query_path = 'func/{0}'.format(query_path) + path = '{0}/{1}'.format(base_url, query_path) + url = self.config.endpoint.with_path(path) + if self.params: + url = url.with_query(self.params) + return url + + # TODO: attach rate-limit information + + def fetch(self, **kwargs) -> FetchContextManager: + """ + Sends the request to the server and reads the response. + + You may use this method either with plain synchronous Session or + AsyncSession. Both the followings patterns are valid: + + .. code-block:: python3 + + from ai.backend.client.request import Request + from ai.backend.client.session import Session + + with Session() as sess: + rqst = Request('GET', ...) + with rqst.fetch() as resp: + print(resp.text()) + + .. code-block:: python3 + + from ai.backend.client.request import Request + from ai.backend.client.session import AsyncSession + + async with AsyncSession() as sess: + rqst = Request('GET', ...) 
+ async with rqst.fetch() as resp: + print(await resp.text()) + """ + assert self.method in self._allowed_methods, \ + 'Disallowed HTTP method: {}'.format(self.method) + self.date = datetime.now(tzutc()) + assert self.date is not None + self.headers['Date'] = self.date.isoformat() + if self.content_type is not None and 'Content-Type' not in self.headers: + self.headers['Content-Type'] = self.content_type + force_anonymous = kwargs.pop('anonymous', False) + + def _rqst_ctx_builder(): + timeout_config = aiohttp.ClientTimeout( + total=None, connect=None, + sock_connect=self.config.connection_timeout, + sock_read=self.config.read_timeout, + ) + full_url = self._build_url() + if not self.config.is_anonymous and not force_anonymous: + self._sign(full_url.relative()) + return self.session.aiohttp_session.request( + self.method, + str(full_url), + data=self._pack_content(), + timeout=timeout_config, + headers=self.headers) + + return FetchContextManager(self.session, _rqst_ctx_builder, **kwargs) + + def connect_websocket(self, **kwargs) -> WebSocketContextManager: + """ + Creates a WebSocket connection. + + .. warning:: + + This method only works with + :class:`~ai.backend.client.session.AsyncSession`. + """ + assert isinstance(self.session, AsyncSession), \ + 'Cannot use websockets with sessions in the synchronous mode' + assert self.method == 'GET', 'Invalid websocket method' + self.date = datetime.now(tzutc()) + assert self.date is not None + self.headers['Date'] = self.date.isoformat() + # websocket is always a "binary" stream. + self.content_type = 'application/octet-stream' + + def _ws_ctx_builder(): + full_url = self._build_url() + if not self.config.is_anonymous: + self._sign(full_url.relative()) + return self.session.aiohttp_session.ws_connect( + str(full_url), + autoping=True, heartbeat=30.0, + headers=self.headers) + + return WebSocketContextManager(self.session, _ws_ctx_builder, **kwargs) + + def connect_events(self, **kwargs) -> SSEContextManager: + """ + Creates a Server-Sent Events connection. + + .. warning:: + + This method only works with + :class:`~ai.backend.client.session.AsyncSession`. 
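+
+ A rough usage sketch (the request path is illustrative):
+
+ .. code-block:: python3
+
+    from ai.backend.client.request import Request
+    from ai.backend.client.session import AsyncSession
+
+    async with AsyncSession() as sess:
+        rqst = Request('GET', '/stream/events')
+        async with rqst.connect_websocket() as ws:
+            async for msg in ws:
+                print(msg.data)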
+ """ + assert isinstance(self.session, AsyncSession), \ + 'Cannot use event streams with sessions in the synchronous mode' + assert self.method == 'GET', 'Invalid event stream method' + self.date = datetime.now(tzutc()) + assert self.date is not None + self.headers['Date'] = self.date.isoformat() + self.content_type = 'application/octet-stream' + + def _rqst_ctx_builder(): + timeout_config = aiohttp.ClientTimeout( + total=None, connect=None, + sock_connect=self.config.connection_timeout, + sock_read=self.config.read_timeout, + ) + full_url = self._build_url() + if not self.config.is_anonymous: + self._sign(full_url.relative()) + return self.session.aiohttp_session.request( + self.method, + str(full_url), + timeout=timeout_config, + headers=self.headers) + + return SSEContextManager(self.session, _rqst_ctx_builder, **kwargs) + + +class AsyncResponseMixin: + + _session: BaseSession + _raw_response: aiohttp.ClientResponse + + async def text(self) -> str: + return await self._raw_response.text() + + async def json(self, *, loads=modjson.loads) -> Any: + loads = functools.partial(loads, object_pairs_hook=OrderedDict) + return await self._raw_response.json(loads=loads) + + async def read(self, n: int = -1) -> bytes: + return await self._raw_response.content.read(n) + + async def readall(self) -> bytes: + return await self._raw_response.content.read(-1) + + +class SyncResponseMixin: + + _session: BaseSession + _raw_response: aiohttp.ClientResponse + + def text(self) -> str: + sync_session = cast(SyncSession, self._session) + return sync_session.worker_thread.execute( + self._raw_response.text(), + ) + + def json(self, *, loads=modjson.loads) -> Any: + loads = functools.partial(loads, object_pairs_hook=OrderedDict) + sync_session = cast(SyncSession, self._session) + return sync_session.worker_thread.execute( + self._raw_response.json(loads=loads), + ) + + def read(self, n: int = -1) -> bytes: + sync_session = cast(SyncSession, self._session) + return sync_session.worker_thread.execute( + self._raw_response.content.read(n), + ) + + def readall(self) -> bytes: + sync_session = cast(SyncSession, self._session) + return sync_session.worker_thread.execute( + self._raw_response.content.read(-1), + ) + + +class BaseResponse: + """ + Represents the Backend.AI API response. + Also serves as a high-level wrapper of :class:`aiohttp.ClientResponse`. + + The response objects are meant to be created by the SDK, not the callers. + + :func:`text`, :func:`json` methods return the resolved content directly with + plain synchronous Session while they return the coroutines with AsyncSession. 
+ """ + + __slots__ = ( + '_session', '_raw_response', '_async_mode', + ) + + _session: BaseSession + _raw_response: aiohttp.ClientResponse + _async_mode: bool + + def __init__( + self, + session: BaseSession, + underlying_response: aiohttp.ClientResponse, + *, + async_mode: bool = False, + **kwargs, + ) -> None: + self._session = session + self._raw_response = underlying_response + self._async_mode = async_mode + + @property + def session(self) -> BaseSession: + return self._session + + @property + def status(self) -> int: + return self._raw_response.status + + @property + def reason(self) -> str: + if self._raw_response.reason is not None: + return self._raw_response.reason + return '' + + @property + def headers(self) -> Mapping[str, str]: + return self._raw_response.headers + + @property + def raw_response(self) -> aiohttp.ClientResponse: + return self._raw_response + + @property + def content_type(self) -> str: + return self._raw_response.content_type + + @property + def content_length(self) -> Optional[int]: + return self._raw_response.content_length + + @property + def content(self) -> aiohttp.StreamReader: + return self._raw_response.content + + +class Response(AsyncResponseMixin, BaseResponse): + pass + + +class FetchContextManager: + """ + The context manager returned by :func:`Request.fetch`. + + It provides both synchronous and asynchronous context manager interfaces. + """ + + __slots__ = ( + 'session', 'rqst_ctx_builder', 'response_cls', + 'check_status', + '_async_mode', + '_rqst_ctx', + ) + + _rqst_ctx: Optional[_RequestContextManager] + + def __init__( + self, + session: BaseSession, + rqst_ctx_builder: Callable[[], _RequestContextManager], + *, + response_cls: Type[Response] = Response, + check_status: bool = True, + ) -> None: + self.session = session + self.rqst_ctx_builder = rqst_ctx_builder + self.check_status = check_status + self.response_cls = response_cls + self._async_mode = isinstance(session, AsyncSession) + self._rqst_ctx = None + + async def __aenter__(self) -> Response: + max_retries = len(self.session.config.endpoints) + retry_count = 0 + while True: + try: + retry_count += 1 + self._rqst_ctx = self.rqst_ctx_builder() + assert self._rqst_ctx is not None + raw_resp = await self._rqst_ctx.__aenter__() + if self.check_status and raw_resp.status // 100 != 2: + msg = await raw_resp.text() + await raw_resp.__aexit__(None, None, None) + raise BackendAPIError(raw_resp.status, raw_resp.reason or '', msg) + return self.response_cls(self.session, raw_resp, + async_mode=self._async_mode) + except aiohttp.ClientConnectionError as e: + if retry_count == max_retries: + msg = 'Request to the API endpoint has failed.\n' \ + 'Check your network connection and/or the server status.\n' \ + '\u279c {!r}'.format(e) + raise BackendClientError(msg) from e + else: + self.session.config.rotate_endpoints() + continue + except aiohttp.ClientResponseError as e: + msg = 'API endpoint response error.\n' \ + '\u279c {!r}'.format(e) + await raw_resp.__aexit__(*sys.exc_info()) + raise BackendClientError(msg) from e + finally: + self.session.config.load_balance_endpoints() + + async def __aexit__(self, *exc_info) -> Optional[bool]: + assert self._rqst_ctx is not None + ret = await self._rqst_ctx.__aexit__(*exc_info) + self._rqst_ctx = None + return ret + + +class WebSocketResponse(BaseResponse): + """ + A high-level wrapper of :class:`aiohttp.ClientWebSocketResponse`. 
+ """ + + __slots__ = ('_raw_ws', ) + + def __init__( + self, + session: BaseSession, + underlying_response: aiohttp.ClientResponse, + **kwargs, + ) -> None: + # Unfortunately, aiohttp.ClientWebSocketResponse is not a subclass of aiohttp.ClientResponse. + # Since we block methods that require ClientResponse-specific methods, we just force-typecast. + super().__init__(session, underlying_response, **kwargs) + self._raw_ws = cast(aiohttp.ClientWebSocketResponse, underlying_response) + + @property + def content_type(self) -> str: + raise AttributeError("WebSocketResponse does not have an explicit content type.") + + @property + def content_length(self) -> Optional[int]: + raise AttributeError("WebSocketResponse does not have a fixed content length.") + + @property + def content(self) -> aiohttp.StreamReader: + raise AttributeError("WebSocketResponse does not support reading the content.") + + @property + def raw_websocket(self) -> aiohttp.ClientWebSocketResponse: + return self._raw_ws + + @property + def closed(self) -> bool: + return self._raw_ws.closed + + async def close(self) -> None: + await self._raw_ws.close() + + def __aiter__(self) -> AsyncIterator[aiohttp.WSMessage]: + return self._raw_ws.__aiter__() + + def exception(self) -> Optional[BaseException]: + return self._raw_ws.exception() + + async def send_str(self, raw_str: str) -> None: + if self._raw_ws.closed: + raise aiohttp.ServerDisconnectedError('server disconnected') + await self._raw_ws.send_str(raw_str) + + async def send_json(self, obj: Any) -> None: + if self._raw_ws.closed: + raise aiohttp.ServerDisconnectedError('server disconnected') + await self._raw_ws.send_json(obj) + + async def send_bytes(self, data: bytes) -> None: + if self._raw_ws.closed: + raise aiohttp.ServerDisconnectedError('server disconnected') + await self._raw_ws.send_bytes(data) + + async def receive_str(self) -> str: + if self._raw_ws.closed: + raise aiohttp.ServerDisconnectedError('server disconnected') + return await self._raw_ws.receive_str() + + async def receive_json(self) -> Any: + if self._raw_ws.closed: + raise aiohttp.ServerDisconnectedError('server disconnected') + return await self._raw_ws.receive_json() + + async def receive_bytes(self) -> bytes: + if self._raw_ws.closed: + raise aiohttp.ServerDisconnectedError('server disconnected') + return await self._raw_ws.receive_bytes() + + +class WebSocketContextManager: + """ + The context manager returned by :func:`Request.connect_websocket`. 
+ """ + + __slots__ = ( + 'session', 'ws_ctx_builder', 'response_cls', + 'on_enter', + '_ws_ctx', + ) + + _ws_ctx: Optional[_WSRequestContextManager] + + def __init__( + self, + session: BaseSession, + ws_ctx_builder: Callable[[], _WSRequestContextManager], + *, + on_enter: Callable = None, + response_cls: Type[WebSocketResponse] = WebSocketResponse, + ) -> None: + self.session = session + self.ws_ctx_builder = ws_ctx_builder + self.response_cls = response_cls + self.on_enter = on_enter + self._ws_ctx = None + + async def __aenter__(self) -> WebSocketResponse: + max_retries = len(self.session.config.endpoints) + retry_count = 0 + while True: + try: + retry_count += 1 + self._ws_ctx = self.ws_ctx_builder() + assert self._ws_ctx is not None + raw_ws = await self._ws_ctx.__aenter__() + except aiohttp.ClientConnectionError as e: + if retry_count == max_retries: + msg = 'Request to the API endpoint has failed.\n' \ + 'Check your network connection and/or the server status.\n' \ + 'Error detail: {!r}'.format(e) + raise BackendClientError(msg) from e + else: + self.session.config.rotate_endpoints() + continue + except aiohttp.ClientResponseError as e: + msg = 'API endpoint response error.\n' \ + '\u279c {!r}'.format(e) + raise BackendClientError(msg) from e + else: + break + finally: + self.session.config.load_balance_endpoints() + + wrapped_ws = self.response_cls(self.session, cast(aiohttp.ClientResponse, raw_ws)) + if self.on_enter is not None: + await self.on_enter(wrapped_ws) + return wrapped_ws + + async def __aexit__(self, *args) -> Optional[bool]: + assert self._ws_ctx is not None + ret = await self._ws_ctx.__aexit__(*args) + self._ws_ctx = None + return ret + + +@attr.s(auto_attribs=True, slots=True, frozen=True) +class SSEMessage: + event: str + data: str + id: Optional[str] = None + retry: Optional[int] = None + + +class SSEResponse(BaseResponse): + + __slots__ = ( + '_auto_reconnect', '_retry', '_connector', + ) + + def __init__( + self, + session: BaseSession, + underlying_response: aiohttp.ClientResponse, + *, + connector: Callable[[], Awaitable[aiohttp.ClientResponse]], + auto_reconnect: bool = True, + default_retry: int = 5, + **kwargs, + ) -> None: + super().__init__(session, underlying_response, async_mode=True, **kwargs) + self._auto_reconnect = auto_reconnect + self._retry = default_retry + self._connector = connector + + async def fetch_events(self) -> AsyncIterator[SSEMessage]: + msg_lines: List[str] = [] + server_closed = False + while True: + received_line = await self._raw_response.content.readline() + if not received_line: + # connection closed + if self._auto_reconnect and not server_closed: + await asyncio.sleep(self._retry) + self._raw_response = await self._connector() + continue + else: + break + received_line = received_line.strip(b'\r\n') + if received_line.startswith(b':'): + # comment + continue + if not received_line: + # message boundary + if len(msg_lines) == 0: + continue + event_type = 'message' + event_id = None + event_retry = None + data_lines = [] + try: + for stored_line in msg_lines: + hdr, text = stored_line.split(':', maxsplit=1) + text = text.lstrip(' ') + if hdr == 'data': + data_lines.append(text) + elif hdr == 'event': + event_type = text + elif hdr == 'id': + event_id = text + elif hdr == 'retry': + event_retry = int(text) + except (IndexError, ValueError): + log.exception('SSEResponse: parsing-error') + continue + event_data = '\n'.join(data_lines) + msg_lines.clear() + if event_retry is not None: + self._retry = event_retry + yield SSEMessage( 
+ event=event_type, + data=event_data, + id=event_id, + retry=event_retry, + ) + if event_type == 'server_close': + server_closed = True + break + else: + msg_lines.append(received_line.decode('utf-8')) + + def __aiter__(self) -> AsyncIterator[SSEMessage]: + return self.fetch_events() + + +class SSEContextManager: + + __slots__ = ( + 'session', 'rqst_ctx_builder', 'response_cls', + '_rqst_ctx', + ) + + _rqst_ctx: Optional[_RequestContextManager] + + def __init__( + self, + session: BaseSession, + rqst_ctx_builder: Callable[[], _RequestContextManager], + *, + response_cls: Type[SSEResponse] = SSEResponse, + ) -> None: + self.session = session + self.rqst_ctx_builder = rqst_ctx_builder + self.response_cls = response_cls + self._rqst_ctx = None + + async def reconnect(self) -> aiohttp.ClientResponse: + if self._rqst_ctx is not None: + await self._rqst_ctx.__aexit__(None, None, None) + self._rqst_ctx = self.rqst_ctx_builder() + assert self._rqst_ctx is not None + raw_resp = await self._rqst_ctx.__aenter__() + if raw_resp.status // 100 != 2: + msg = await raw_resp.text() + raise BackendAPIError(raw_resp.status, raw_resp.reason or '', msg) + return raw_resp + + async def __aenter__(self) -> SSEResponse: + max_retries = len(self.session.config.endpoints) + retry_count = 0 + while True: + try: + retry_count += 1 + raw_resp = await self.reconnect() + return self.response_cls(self.session, raw_resp, connector=self.reconnect) + except aiohttp.ClientConnectionError as e: + if retry_count == max_retries: + msg = 'Request to the API endpoint has failed.\n' \ + 'Check your network connection and/or the server status.\n' \ + '\u279c {!r}'.format(e) + raise BackendClientError(msg) from e + else: + self.session.config.rotate_endpoints() + continue + except aiohttp.ClientResponseError as e: + msg = 'API endpoint response error.\n' \ + '\u279c {!r}'.format(e) + raise BackendClientError(msg) from e + finally: + self.session.config.load_balance_endpoints() + + async def __aexit__(self, *args) -> Optional[bool]: + assert self._rqst_ctx is not None + ret = await self._rqst_ctx.__aexit__(*args) + self._rqst_ctx = None + return ret diff --git a/src/ai/backend/client/session.py b/src/ai/backend/client/session.py new file mode 100644 index 0000000000..c343be20b9 --- /dev/null +++ b/src/ai/backend/client/session.py @@ -0,0 +1,490 @@ +from __future__ import annotations + +import abc +import asyncio +from contextvars import Context, ContextVar, copy_context +import inspect +import threading +from typing import ( + Any, + AsyncIterator, + Awaitable, + Coroutine, + Iterator, + Literal, + Tuple, + Union, + TypeVar, +) +import queue +import warnings + +import aiohttp +from multidict import CIMultiDict +from .config import APIConfig, MIN_API_VERSION, get_config, parse_api_version +from .exceptions import APIVersionWarning, BackendAPIError, BackendClientError +from .types import Sentinel, sentinel + + +__all__ = ( + 'BaseSession', + 'Session', + 'AsyncSession', + 'api_session', +) + + +api_session: ContextVar[BaseSession] = ContextVar('api_session') + + +async def _negotiate_api_version( + http_session: aiohttp.ClientSession, + config: APIConfig, +) -> Tuple[int, str]: + client_version = parse_api_version(config.version) + try: + timeout_config = aiohttp.ClientTimeout( + total=None, connect=None, + sock_connect=config.connection_timeout, + sock_read=config.read_timeout, + ) + headers = CIMultiDict([ + ('User-Agent', config.user_agent), + ]) + probe_url = config.endpoint / 'func/' if config.endpoint_type == 'session' else 
config.endpoint + async with http_session.get(probe_url, timeout=timeout_config, headers=headers) as resp: + resp.raise_for_status() + server_info = await resp.json() + server_version = parse_api_version(server_info['version']) + if server_version > client_version: + warnings.warn( + "The server API version is higher than the client. " + "Please upgrade the client package.", + category=APIVersionWarning, + ) + if server_version < MIN_API_VERSION: + warnings.warn( + f"The server is too old and does not meet the minimum API version requirement: " + f"v{MIN_API_VERSION[0]}.{MIN_API_VERSION[1]}\n" + f"Please upgrade the server or downgrade/reinstall the client SDK with " + f"the same major.minor release of the server.", + category=APIVersionWarning, + ) + return min(server_version, client_version) + except (asyncio.TimeoutError, aiohttp.ClientError): + # fallback to the configured API version + return client_version + + +async def _close_aiohttp_session(session: aiohttp.ClientSession) -> None: + # This is a hacky workaround for premature closing of SSL transports + # on Windows Proactor event loops. + # Thanks to Vadim Markovtsev's comment on the aiohttp issue #1925. + # (https://github.com/aio-libs/aiohttp/issues/1925#issuecomment-592596034) + transports = 0 + all_is_lost = asyncio.Event() + if session.connector is None: + all_is_lost.set() + else: + if len(session.connector._conns) == 0: + all_is_lost.set() + for conn in session.connector._conns.values(): + for handler, _ in conn: + proto = getattr(handler.transport, "_ssl_protocol", None) + if proto is None: + continue + transports += 1 + orig_lost = proto.connection_lost + orig_eof_received = proto.eof_received + + def connection_lost(exc): + orig_lost(exc) + nonlocal transports + transports -= 1 + if transports == 0: + all_is_lost.set() + + def eof_received(): + try: + orig_eof_received() + except AttributeError: + # It may happen that eof_received() is called after + # _app_protocol and _transport are set to None. + pass + + proto.connection_lost = connection_lost + proto.eof_received = eof_received + await session.close() + if transports > 0: + await all_is_lost.wait() + + +_Item = TypeVar('_Item') + + +class _SyncWorkerThread(threading.Thread): + + work_queue: queue.Queue[Union[ + Tuple[Union[AsyncIterator, Coroutine], Context], + Sentinel, + ]] + done_queue: queue.Queue[Union[Any, Exception]] + stream_queue: queue.Queue[Union[Any, Exception, Sentinel]] + stream_block: threading.Event + agen_shutdown: bool + + __slots__ = ( + 'work_queue', + 'done_queue', + 'stream_queue', + 'stream_block', + 'agen_shutdown', + ) + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.work_queue = queue.Queue() + self.done_queue = queue.Queue() + self.stream_queue = queue.Queue() + self.stream_block = threading.Event() + self.agen_shutdown = False + + def run(self) -> None: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + while True: + item = self.work_queue.get() + if item is sentinel: + break + coro, ctx = item + if inspect.isasyncgen(coro): + ctx.run(loop.run_until_complete, + self.agen_wrapper(coro)) + else: + try: + # FIXME: Once python/mypy#12756 is resolved, remove the type-ignore tag. 
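+                        # ctx.run() re-activates the caller's contextvars (e.g. the
+                        # api_session ContextVar) inside this worker thread while the
+                        # coroutine runs to completion on the thread-local event loop.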
+ result = ctx.run(loop.run_until_complete, coro) # type: ignore + except Exception as e: + self.done_queue.put_nowait(e) + else: + self.done_queue.put_nowait(result) + self.work_queue.task_done() + except (SystemExit, KeyboardInterrupt): + pass + finally: + loop.run_until_complete(loop.shutdown_asyncgens()) + loop.stop() + loop.close() + + def execute(self, coro: Coroutine) -> Any: + ctx = copy_context() # preserve context for the worker thread + try: + self.work_queue.put((coro, ctx)) + result = self.done_queue.get() + self.done_queue.task_done() + if isinstance(result, Exception): + raise result + return result + finally: + del ctx + + async def agen_wrapper(self, agen): + self.agen_shutdown = False + try: + async for item in agen: + self.stream_block.clear() + self.stream_queue.put(item) + # flow-control the generator. + self.stream_block.wait() + if self.agen_shutdown: + break + except Exception as e: + self.stream_queue.put(e) + finally: + self.stream_queue.put(sentinel) + await agen.aclose() + + def execute_generator(self, asyncgen: AsyncIterator[_Item]) -> Iterator[_Item]: + ctx = copy_context() # preserve context for the worker thread + try: + self.work_queue.put((asyncgen, ctx)) + while True: + item = self.stream_queue.get() + try: + if item is sentinel: + break + if isinstance(item, Exception): + raise item + yield item + finally: + self.stream_block.set() + self.stream_queue.task_done() + finally: + del ctx + + def interrupt_generator(self): + self.agen_shutdown = True + self.stream_block.set() + self.stream_queue.put(sentinel) + + +class BaseSession(metaclass=abc.ABCMeta): + """ + The base abstract class for sessions. + """ + + __slots__ = ( + '_config', '_closed', '_context_token', '_proxy_mode', + 'aiohttp_session', 'api_version', + 'System', 'Manager', 'Admin', + 'Agent', 'AgentWatcher', 'ScalingGroup', 'Storage', + 'Image', 'ComputeSession', 'SessionTemplate', + 'Domain', 'Group', 'Auth', 'User', 'KeyPair', + 'BackgroundTask', + 'EtcdConfig', + 'Resource', 'KeypairResourcePolicy', + 'VFolder', 'Dotfile', + 'ServerLog', + ) + + aiohttp_session: aiohttp.ClientSession + api_version: Tuple[int, str] + + _closed: bool + _config: APIConfig + _proxy_mode: bool + + def __init__( + self, *, + config: APIConfig = None, + proxy_mode: bool = False, + ) -> None: + self._closed = False + self._config = config if config else get_config() + self._proxy_mode = proxy_mode + self.api_version = parse_api_version(self._config.version) + + from .func.system import System + from .func.admin import Admin + from .func.agent import Agent, AgentWatcher + from .func.storage import Storage + from .func.auth import Auth + from .func.bgtask import BackgroundTask + from .func.domain import Domain + from .func.etcd import EtcdConfig + from .func.group import Group + from .func.image import Image + from .func.session import ComputeSession + from .func.keypair import KeyPair + from .func.manager import Manager + from .func.resource import Resource + from .func.keypair_resource_policy import KeypairResourcePolicy + from .func.scaling_group import ScalingGroup + from .func.session_template import SessionTemplate + from .func.user import User + from .func.vfolder import VFolder + from .func.dotfile import Dotfile + from .func.server_log import ServerLog + + self.System = System + self.Admin = Admin + self.Agent = Agent + self.AgentWatcher = AgentWatcher + self.Storage = Storage + self.Auth = Auth + self.BackgroundTask = BackgroundTask + self.EtcdConfig = EtcdConfig + self.Domain = Domain + self.Group = Group + 
self.Image = Image + self.ComputeSession = ComputeSession + self.KeyPair = KeyPair + self.Manager = Manager + self.Resource = Resource + self.KeypairResourcePolicy = KeypairResourcePolicy + self.User = User + self.ScalingGroup = ScalingGroup + self.SessionTemplate = SessionTemplate + self.VFolder = VFolder + self.Dotfile = Dotfile + self.ServerLog = ServerLog + + @property + def proxy_mode(self) -> bool: + """ + If set True, it skips API version negotiation when opening the session. + """ + return self._proxy_mode + + @abc.abstractmethod + def open(self) -> Union[None, Awaitable[None]]: + """ + Initializes the session and perform version negotiation. + """ + raise NotImplementedError + + @abc.abstractmethod + def close(self) -> Union[None, Awaitable[None]]: + """ + Terminates the session and releases underlying resources. + """ + raise NotImplementedError + + @property + def closed(self) -> bool: + """ + Checks if the session is closed. + """ + return self._closed + + @property + def config(self) -> APIConfig: + """ + The configuration used by this session object. + """ + return self._config + + def __enter__(self) -> BaseSession: + raise NotImplementedError + + def __exit__(self, *exc_info) -> Literal[False]: + return False + + async def __aenter__(self) -> BaseSession: + raise NotImplementedError + + async def __aexit__(self, *exc_info) -> Literal[False]: + return False + + +class Session(BaseSession): + """ + A context manager for API client sessions that makes API requests synchronously. + You may call simple request-response APIs like a plain Python function, + but cannot use streaming APIs based on WebSocket and Server-Sent Events. + """ + + __slots__ = ( + '_worker_thread', + ) + + def __init__( + self, *, + config: APIConfig = None, + proxy_mode: bool = False, + ) -> None: + super().__init__(config=config, proxy_mode=proxy_mode) + self._worker_thread = _SyncWorkerThread() + self._worker_thread.start() + + async def _create_aiohttp_session() -> aiohttp.ClientSession: + ssl = None + if self._config.skip_sslcert_validation: + ssl = False + connector = aiohttp.TCPConnector(ssl=ssl) + return aiohttp.ClientSession(connector=connector) + + self.aiohttp_session = self.worker_thread.execute(_create_aiohttp_session()) + + def open(self) -> None: + self._context_token = api_session.set(self) + if not self._proxy_mode: + self.api_version = self.worker_thread.execute( + _negotiate_api_version(self.aiohttp_session, self.config)) + + def close(self) -> None: + """ + Terminates the session. It schedules the ``close()`` coroutine + of the underlying aiohttp session and then enqueues a sentinel + object to indicate termination. Then it waits until the worker + thread to self-terminate by joining. + """ + if self._closed: + return + self._closed = True + self._worker_thread.interrupt_generator() + self._worker_thread.execute(_close_aiohttp_session(self.aiohttp_session)) + self._worker_thread.work_queue.put(sentinel) + self._worker_thread.join() + api_session.reset(self._context_token) + + @property + def worker_thread(self): + """ + The thread that internally executes the asynchronous implementations + of the given API functions. 
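+
+        For instance, a synchronous API wrapper may submit its coroutine
+        implementation like this sketch (``do_request()`` stands for an
+        arbitrary coroutine, not an actual helper of this package), while
+        async generators go through ``execute_generator()`` instead:
+
+        .. code-block:: python3
+
+           session.worker_thread.execute(do_request())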
+ """ + return self._worker_thread + + def __enter__(self) -> Session: + assert not self.closed, 'Cannot reuse closed session' + self.open() + if self.config.announcement_handler: + try: + payload = self.Manager.get_announcement() + if payload['enabled']: + self.config.announcement_handler(payload['message']) + except (BackendClientError, BackendAPIError): + # The server may be an old one without annoucement API. + pass + return self + + def __exit__(self, *exc_info) -> Literal[False]: + self.close() + return False # raise up the inner exception + + +class AsyncSession(BaseSession): + """ + A context manager for API client sessions that makes API requests asynchronously. + You may call all APIs as coroutines. + WebSocket-based APIs and SSE-based APIs returns special response types. + """ + + def __init__( + self, *, + config: APIConfig = None, + proxy_mode: bool = False, + ) -> None: + super().__init__(config=config, proxy_mode=proxy_mode) + ssl = None + if self._config.skip_sslcert_validation: + ssl = False + connector = aiohttp.TCPConnector(ssl=ssl) + self.aiohttp_session = aiohttp.ClientSession(connector=connector) + + async def _aopen(self) -> None: + self._context_token = api_session.set(self) + if not self._proxy_mode: + self.api_version = await _negotiate_api_version(self.aiohttp_session, self.config) + + def open(self) -> Awaitable[None]: + return self._aopen() + + async def _aclose(self) -> None: + if self._closed: + return + self._closed = True + await _close_aiohttp_session(self.aiohttp_session) + api_session.reset(self._context_token) + + def close(self) -> Awaitable[None]: + return self._aclose() + + async def __aenter__(self) -> AsyncSession: + assert not self.closed, 'Cannot reuse closed session' + await self.open() + if self.config.announcement_handler: + try: + payload = await self.Manager.get_announcement() + if payload['enabled']: + self.config.announcement_handler(payload['message']) + except (BackendClientError, BackendAPIError): + # The server may be an old one without annoucement API. + pass + return self + + async def __aexit__(self, *exc_info) -> Literal[False]: + await self.close() + return False # raise up the inner exception diff --git a/src/ai/backend/client/test_utils.py b/src/ai/backend/client/test_utils.py new file mode 100644 index 0000000000..3aa1903e0b --- /dev/null +++ b/src/ai/backend/client/test_utils.py @@ -0,0 +1,91 @@ +""" +A support module to async mocks in Python versiosn prior to 3.8. +""" + +import sys +from unittest import mock +if sys.version_info >= (3, 8, 0): + # Since Python 3.8, AsyncMock is now part of the stdlib. + # Python 3.8 also adds magic-mocking async iterators and async context managers. + from unittest.mock import AsyncMock +else: + from asynctest import CoroutineMock as AsyncMock # type: ignore + + +class AsyncContextMock(mock.Mock): + """ + Provides a mock that can be used: + + async with mock(): + ... 
+ + Example: + + # In the test code: + mock_obj = unittest.mock.Mock() + mock_obj.fetch.return_value = AsyncContextMock( + status=200, + json=mock.AsyncMock(return_value={'hello': 'world'}) + ) + mocker.patch('mypkg.mymod.MyClass', return_value=mock_obj) + + # In the tested code: + obj = mpkg.mymod.MyClass() + async with obj.fetch() as resp: + # resp.status is 200 + result = await resp.json() + # result is {'hello': 'world'} + """ + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + +class AsyncContextMagicMock(mock.MagicMock): + """ + Provides a magic mock that can be used: + + async with mock(): + ... + """ + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + +class AsyncContextCoroutineMock(AsyncMock): + """ + Provides a mock that can be used: + + async with (await mock(...)): + ... + + Example: + + # In the test code: + mock_obj = unittest.mock.AsyncMock() + mock_obj.fetch.return_value = AsyncContextMock( + status=200, + json=mock.AsyncMock(return_value={'hello': 'world'}) + ) + mocker.patch('mypkg.mymod.MyClass', return_value=mock_obj) + + # In the tested code: + obj = mpkg.mymod.MyClass() + async with (await obj.fetch()) as resp: + # resp.status is 200 + result = await resp.json() + # result is {'hello': 'world'} + """ + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass diff --git a/src/ai/backend/client/types.py b/src/ai/backend/client/types.py new file mode 100644 index 0000000000..e67cd45050 --- /dev/null +++ b/src/ai/backend/client/types.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +import enum + +__all__ = ( + "Sentinel", + "Undefined", + "sentinel", + "undefined", +) + + +class Sentinel(enum.Enum): + """ + A special type to represent a special value to indicate closing/shutdown of queues. + """ + token = 0 + + +class Undefined(enum.Enum): + """ + A special type to represent an undefined value. 
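+
+    The ``undefined`` token below is typically used as a sentinel default that
+    distinguishes "argument not given" from an explicit ``None`` (illustrative
+    sketch, not an actual API of this module):
+
+    .. code-block:: python3
+
+       def update(name=undefined):
+           if name is not undefined:
+               ...  # only touch the field when a value was actually passed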
+ """ + token = 0 + + +sentinel = Sentinel.token +undefined = Undefined.token diff --git a/src/ai/backend/client/utils.py b/src/ai/backend/client/utils.py new file mode 100644 index 0000000000..4d4abe899c --- /dev/null +++ b/src/ai/backend/client/utils.py @@ -0,0 +1,51 @@ +import io +import os + +from tqdm import tqdm + + +class ProgressReportingReader(io.BufferedReader): + + def __init__(self, file_path, *, tqdm_instance=None): + super().__init__(open(file_path, 'rb')) + self._filename = os.path.basename(file_path) + if tqdm_instance is None: + self._owns_tqdm = True + self.tqdm = tqdm( + unit='bytes', + unit_scale=True, + total=os.path.getsize(file_path), + ) + else: + self._owns_tqdm = False + self.tqdm = tqdm_instance + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + if self._owns_tqdm: + self.tqdm.close() + self.close() + + def read(self, *args, **kwargs): + chunk = super().read(*args, **kwargs) + self.tqdm.set_postfix(file=self._filename, refresh=False) + self.tqdm.update(len(chunk)) + return chunk + + def read1(self, *args, **kwargs): + chunk = super().read1(*args, **kwargs) + self.tqdm.set_postfix(file=self._filename, refresh=False) + self.tqdm.update(len(chunk)) + return chunk + + def readinto(self, *args, **kwargs): + count = super().readinto(*args, **kwargs) + self.tqdm.set_postfix(file=self._filename, refresh=False) + self.tqdm.update(count) + + def readinto1(self, *args, **kwargs): + count = super().readinto1(*args, **kwargs) + self.tqdm.set_postfix(file=self._filename, refresh=False) + self.tqdm.update(count) diff --git a/src/ai/backend/client/versioning.py b/src/ai/backend/client/versioning.py new file mode 100644 index 0000000000..ef59e715ab --- /dev/null +++ b/src/ai/backend/client/versioning.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from typing import ( + Callable, + Sequence, + Tuple, + Union, + TYPE_CHECKING, +) + +if TYPE_CHECKING: + from .func.session import ComputeSession + + +naming_profile = { + 'path': ('kernel', 'session'), + 'session_events_path': ('/stream/kernel/_/events', '/events/session'), + 'name_arg': ('clientSessionToken', 'name'), + 'event_name_arg': ('sessionId', 'name'), + 'name_gql_field': ('sess_id', 'name'), + 'type_gql_field': ('sess_type', 'type'), +} + + +def get_naming(api_version: Tuple[int, str], key: str) -> str: + if api_version[0] <= 4: + return naming_profile[key][0] + return naming_profile[key][1] + + +def get_id_or_name(api_version: Tuple[int, str], obj: ComputeSession) -> str: + if api_version[0] <= 4: + assert obj.name is not None + return obj.name + if obj.id: + return str(obj.id) + else: + assert obj.name is not None + return obj.name + + +def apply_version_aware_fields( + api_session, + fields: Sequence[Tuple[str, Union[Callable, str]]], +) -> Sequence[Tuple[str, str]]: + version_aware_fields = [] + for f in fields: + if callable(f[1]): + version_aware_fields.append((f[0], f[1](api_session))) + else: + version_aware_fields.append((f[0], f[1])) + return version_aware_fields diff --git a/src/ai/backend/common/BUILD b/src/ai/backend/common/BUILD new file mode 100644 index 0000000000..5e5131d941 --- /dev/null +++ b/src/ai/backend/common/BUILD @@ -0,0 +1,35 @@ +python_sources( + name="lib", + sources=["**/*.py"], + dependencies=[ + ":resources", + "stubs/trafaret:stubs", + ], +) + +python_distribution( + name="dist", + dependencies=[ + ":lib", + "!!stubs/trafaret:stubs", + ], + provides=python_artifact( + name="backend.ai-common", + description="Backend.AI commons 
library", + license="LGPLv3", + ), + generate_setup=True, + tags=["wheel"], +) + +resource(name="version", source="VERSION") + +resources( + name="resources", + dependencies=[ + ":version", + ], + sources=[ + "**/py.typed", + ], +) diff --git a/src/ai/backend/common/README.md b/src/ai/backend/common/README.md new file mode 100644 index 0000000000..3e30000f2a --- /dev/null +++ b/src/ai/backend/common/README.md @@ -0,0 +1,39 @@ +Backend.AI Commons +================== + +[![PyPI release version](https://badge.fury.io/py/backend.ai-common.svg)](https://pypi.org/project/backend.ai-common/) +![Supported Python versions](https://img.shields.io/pypi/pyversions/backend.ai-common.svg) +[![Build Status](https://travis-ci.com/lablup/backend.ai-common.svg?branch=master)](https://travis-ci.com/lablup/backend.ai-common) +[![Gitter](https://badges.gitter.im/lablup/backend.ai-common.svg)](https://gitter.im/lablup/backend.ai-common) + +Common utilities library for Backend.AI + + +## Installation + +```console +$ pip install backend.ai-common +``` + +## For development + +```console +$ pip install -U pip setuptools +$ pip install -U -r requirements/dev.txt +``` + +### Running test suite + +```console +$ python -m pytest +``` + +With the default halfstack setup, you may need to set the environment variable `BACKEND_ETCD_ADDR` +to specify the non-standard etcd service port (e.g., `localhost:8110`). + +The tests for `common.redis` module requires availability of local TCP ports 16379, 16380, 16381, +26379, 26380, and 26381 to launch a temporary Redis sentinel cluster via `docker compose`. + +In macOS, they require a local `redis-server` executable to be installed, preferably via `brew`, +because `docker compose` in macOS does not support host-mode networking and Redis *cannot* be +configured to use different self IP addresses to announce to the cluster nodes and clients. diff --git a/src/ai/backend/common/VERSION b/src/ai/backend/common/VERSION new file mode 120000 index 0000000000..a4e948506b --- /dev/null +++ b/src/ai/backend/common/VERSION @@ -0,0 +1 @@ +../../../../VERSION \ No newline at end of file diff --git a/src/ai/backend/common/__init__.py b/src/ai/backend/common/__init__.py new file mode 100644 index 0000000000..17b3552989 --- /dev/null +++ b/src/ai/backend/common/__init__.py @@ -0,0 +1,3 @@ +from pathlib import Path + +__version__ = (Path(__file__).parent / 'VERSION').read_text().strip() diff --git a/src/ai/backend/common/argparse.py b/src/ai/backend/common/argparse.py new file mode 100644 index 0000000000..eac7dc28c9 --- /dev/null +++ b/src/ai/backend/common/argparse.py @@ -0,0 +1,101 @@ +import argparse +import ipaddress +import pathlib +from typing import cast, Tuple + +from .types import HostPortPair + + +def port_no(s: str) -> int: + try: + port = int(s) + assert port > 0 + assert port < 65536 + except (ValueError, AssertionError): + msg = f'{s!r} is not a valid port number.' + raise argparse.ArgumentTypeError(msg) + return port + + +def port_range(s: str) -> Tuple[int, int]: + try: + port_range = tuple(map(int, s.split('-'))) + except (TypeError, ValueError): + msg = f'{s!r} should be a hyphen-separated pair of integers.' + raise argparse.ArgumentTypeError(msg) + if len(port_range) != 2: + msg = f'{s!r} should have exactly two integers.' + raise argparse.ArgumentTypeError(msg) + if not (0 < port_range[0] < 65536): + msg = f'{port_range[0]} is not a valid port number.' 
+ raise argparse.ArgumentTypeError(msg) + if not (0 < port_range[1] < 65536): + msg = f'{port_range[1]} is not a valid port number.' + raise argparse.ArgumentTypeError(msg) + if not (port_range[0] < port_range[1]): + msg = f'{port_range[0]} should be less than {port_range[1]}.' + raise argparse.ArgumentTypeError(msg) + return cast(Tuple[int, int], port_range) + + +def positive_int(s: str) -> int: + try: + val = int(s) + assert val > 0 + except (ValueError, AssertionError): + msg = f'{s!r} is not a positive integer.' + raise argparse.ArgumentTypeError(msg) + return val + + +def non_negative_int(s: str) -> int: + try: + val = int(s) + assert val >= 0 + except (ValueError, AssertionError): + msg = f'{s!r} is not a non-negative integer.' + raise argparse.ArgumentTypeError(msg) + return val + + +def host_port_pair(s: str) -> Tuple[ipaddress._BaseAddress, int]: + host: str | ipaddress._BaseAddress + pieces = s.rsplit(':', maxsplit=1) + if len(pieces) == 1: + msg = f'{s!r} should contain both IP address and port number.' + raise argparse.ArgumentTypeError(msg) + elif len(pieces) == 2: + # strip potential brackets in IPv6 hostname-port strings (RFC 3986). + host = pieces[0].strip('[]') + try: + host = ipaddress.ip_address(host) + except ValueError: + # Let it be just a hostname. + host = host + try: + port = int(pieces[1]) + assert port > 0 + assert port < 65536 + except (ValueError, AssertionError): + msg = f'{pieces[1]!r} is not a valid port number.' + raise argparse.ArgumentTypeError(msg) + return HostPortPair(host, port) + + +def ipaddr(s: str) -> ipaddress._BaseAddress: + try: + ip = ipaddress.ip_address(s.strip('[]')) + except ValueError: + msg = f'{s!r} is not a valid IP address.' + raise argparse.ArgumentTypeError(msg) + return ip + + +def path(val: str) -> pathlib.Path: + if val is None: + return None + p = pathlib.Path(val) + if not p.exists(): + msg = f'{val!r} is not a valid file/dir path.' + raise argparse.ArgumentTypeError(msg) + return p diff --git a/src/ai/backend/common/asyncio.py b/src/ai/backend/common/asyncio.py new file mode 100644 index 0000000000..8b21c1f60a --- /dev/null +++ b/src/ai/backend/common/asyncio.py @@ -0,0 +1,124 @@ +from __future__ import annotations + +import asyncio +import inspect +from typing import ( + Any, + Awaitable, + Callable, + Collection, + Tuple, + Type, + TypeVar, + Sequence, + cast, +) + +__all__ = ( + 'AsyncBarrier', + 'cancel_tasks', + 'current_loop', + 'run_through', +) + +RT = TypeVar('RT') + + +async def cancel_tasks( + tasks: Collection[asyncio.Task[RT]], +) -> Sequence[RT | Exception]: + """ + Cancel all unfinished tasks from the given collection of asyncio tasks, + using :func:`asyncio.gather()` to let them clean up concurrently. + It returns the results and exceptions without raising them, for cases when + the caller wants to silent ignore errors or handle them at once. + """ + copied_tasks = {*tasks} + cancelled_tasks = [] + for task in copied_tasks: + if not task.done(): + task.cancel() + cancelled_tasks.append(task) + return await asyncio.gather(*cancelled_tasks, return_exceptions=True) + + +current_loop: Callable[[], asyncio.AbstractEventLoop] +if hasattr(asyncio, 'get_running_loop'): + current_loop = asyncio.get_running_loop # type: ignore +else: + current_loop = asyncio.get_event_loop # type: ignore + + +async def run_through( + *awaitable_or_callables: Callable[[], None] | Awaitable[None], + ignored_exceptions: Tuple[Type[Exception], ...], +) -> None: + """ + A syntactic sugar to simplify the code patterns like: + + .. 
code-block:: python3 + + try: + await do1() + except MyError: + pass + try: + await do2() + except MyError: + pass + try: + await do3() + except MyError: + pass + + Using ``run_through()``, it becomes: + + .. code-block:: python3 + + await run_through( + do1(), + do2(), + do3(), + ignored_exceptions=(MyError,), + ) + """ + for f in awaitable_or_callables: + try: + if inspect.iscoroutinefunction(f): + await f() # type: ignore + elif inspect.isawaitable(f): + await f # type: ignore + else: + f() # type: ignore + except Exception as e: + if isinstance(e, cast(Tuple[Any, ...], ignored_exceptions)): + continue + raise + + +class AsyncBarrier: + """ + This class provides a simplified asyncio-version of threading.Barrier class. + """ + + num_parties: int = 1 + cond: asyncio.Condition + + def __init__(self, num_parties: int) -> None: + self.num_parties = num_parties + self.count = 0 + self.cond = asyncio.Condition() + + async def wait(self) -> None: + async with self.cond: + self.count += 1 + if self.count == self.num_parties: + self.cond.notify_all() + else: + while self.count < self.num_parties: + await self.cond.wait() + + def reset(self) -> None: + self.count = 0 + # FIXME: if there are waiting coroutines, let them + # raise BrokenBarrierError like threading.Barrier diff --git a/src/ai/backend/common/bgtask.py b/src/ai/backend/common/bgtask.py new file mode 100644 index 0000000000..1c9205a083 --- /dev/null +++ b/src/ai/backend/common/bgtask.py @@ -0,0 +1,284 @@ +from __future__ import annotations + +import asyncio +import json +import logging +import time +import uuid +import weakref +from typing import ( + Awaitable, + Callable, + Final, + Literal, + Optional, + TypeAlias, + Union, + Set, + Type, +) + +import aioredis +import aioredis.client +from aiohttp import web +from aiohttp_sse import sse_response + +from . import redis +from .events import ( + BgtaskCancelledEvent, + BgtaskDoneEvent, + BgtaskFailedEvent, + BgtaskUpdatedEvent, + EventDispatcher, + EventProducer, +) +from .logging import BraceStyleAdapter +from .types import AgentId, Sentinel + +sentinel: Final = Sentinel.TOKEN +log = BraceStyleAdapter(logging.getLogger('ai.backend.manager.background')) +TaskResult = Literal['bgtask_done', 'bgtask_cancelled', 'bgtask_failed'] +BgtaskEvents: TypeAlias = BgtaskUpdatedEvent | BgtaskDoneEvent | BgtaskCancelledEvent | BgtaskFailedEvent + +MAX_BGTASK_ARCHIVE_PERIOD: Final = 86400 # 24 hours + + +class ProgressReporter: + event_producer: Final[EventProducer] + task_id: Final[uuid.UUID] + total_progress: Union[int, float] + current_progress: Union[int, float] + + def __init__( + self, + event_dispatcher: EventProducer, + task_id: uuid.UUID, + current_progress: int = 0, + total_progress: int = 0, + ) -> None: + self.event_producer = event_dispatcher + self.task_id = task_id + self.current_progress = current_progress + self.total_progress = total_progress + + async def update(self, increment: Union[int, float] = 0, message: str = None): + self.current_progress += increment + # keep the state as local variables because they might be changed + # due to interleaving at await statements below. 
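+        # (Another coroutine may call update() while we await the Redis pipeline
+        # and the event producer below; taking a snapshot here keeps the stored
+        # tracker and the emitted BgtaskUpdatedEvent consistent with each other.)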
+ current, total = self.current_progress, self.total_progress + redis_producer = self.event_producer.redis_client + + def _pipe_builder(r: aioredis.Redis) -> aioredis.client.Pipeline: + pipe = r.pipeline(transaction=False) + tracker_key = f'bgtask.{self.task_id}' + pipe.hset(tracker_key, mapping={ + 'current': str(current), + 'total': str(total), + 'msg': message or '', + 'last_update': str(time.time()), + }) + pipe.expire(tracker_key, MAX_BGTASK_ARCHIVE_PERIOD) + return pipe + + await redis.execute(redis_producer, _pipe_builder) + await self.event_producer.produce_event( + BgtaskUpdatedEvent( + self.task_id, + message=message, + current_progress=current, + total_progress=total, + ), + ) + + +BackgroundTask = Callable[[ProgressReporter], Awaitable[Optional[str]]] + + +class BackgroundTaskManager: + event_producer: EventProducer + ongoing_tasks: weakref.WeakSet[asyncio.Task] + task_update_queues: Set[asyncio.Queue[Sentinel | BgtaskEvents]] + + def __init__(self, event_producer: EventProducer) -> None: + self.event_producer = event_producer + self.ongoing_tasks = weakref.WeakSet() + self.task_update_queues = set() + + def register_event_handlers(self, event_dispatcher: EventDispatcher) -> None: + """ + Add bgtask related event handlers to the given event dispatcher. + """ + event_dispatcher.subscribe(BgtaskUpdatedEvent, None, self._enqueue_bgtask_status_update) + event_dispatcher.subscribe(BgtaskDoneEvent, None, self._enqueue_bgtask_status_update) + event_dispatcher.subscribe(BgtaskCancelledEvent, None, self._enqueue_bgtask_status_update) + event_dispatcher.subscribe(BgtaskFailedEvent, None, self._enqueue_bgtask_status_update) + + async def _enqueue_bgtask_status_update( + self, + context: None, + source: AgentId, + event: BgtaskEvents, + ) -> None: + for q in self.task_update_queues: + q.put_nowait(event) + + async def push_bgtask_events( + self, + request: web.Request, + task_id: uuid.UUID, + ) -> web.StreamResponse: + """ + A aiohttp-based server-sent events (SSE) responder that pushes the bgtask updates + to the clients. + """ + tracker_key = f'bgtask.{task_id}' + redis_producer = self.event_producer.redis_client + task_info = await redis.execute( + redis_producer, + lambda r: r.hgetall(tracker_key), + encoding='utf-8', + ) + + log.debug('task info: {}', task_info) + if task_info is None: + # The task ID is invalid or represents a task completed more than 24 hours ago. + raise ValueError('No such background task.') + + if task_info['status'] != 'started': + # It is an already finished task! + async with sse_response(request) as resp: + try: + body = { + 'task_id': str(task_id), + 'status': task_info['status'], + 'current_progress': task_info['current'], + 'total_progress': task_info['total'], + 'message': task_info['msg'], + } + await resp.send(json.dumps(body), event=f"task_{task_info['status']}") + finally: + await resp.send('{}', event="server_close") + return resp + + # It is an ongoing task. 
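+        # Subscribe to the fan-out queue set: _enqueue_bgtask_status_update(),
+        # registered via register_event_handlers(), pushes every bgtask event to
+        # all queues in task_update_queues, and events for other task IDs are
+        # skipped in the loop below.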
+ my_queue: asyncio.Queue[BgtaskEvents | Sentinel] = asyncio.Queue() + self.task_update_queues.add(my_queue) + try: + async with sse_response(request) as resp: + try: + while True: + event = await my_queue.get() + try: + if event is sentinel: + break + if task_id != event.task_id: + continue + body = { + 'task_id': str(task_id), + 'message': event.message, + } + if isinstance(event, BgtaskUpdatedEvent): + body['current_progress'] = event.current_progress + body['total_progress'] = event.total_progress + await resp.send(json.dumps(body), event=event.name, retry=5) + if (isinstance(event, BgtaskDoneEvent) or + isinstance(event, BgtaskFailedEvent) or + isinstance(event, BgtaskCancelledEvent)): + await resp.send('{}', event="server_close") + break + finally: + my_queue.task_done() + finally: + return resp + finally: + self.task_update_queues.remove(my_queue) + + async def start( + self, + func: BackgroundTask, + name: str = None, + ) -> uuid.UUID: + task_id = uuid.uuid4() + redis_producer = self.event_producer.redis_client + + def _pipe_builder(r: aioredis.Redis) -> aioredis.client.Pipeline: + pipe = r.pipeline() + tracker_key = f'bgtask.{task_id}' + now = str(time.time()) + pipe.hset(tracker_key, mapping={ + 'status': 'started', + 'current': '0', + 'total': '0', + 'msg': '', + 'started_at': now, + 'last_update': now, + }) + pipe.expire(tracker_key, MAX_BGTASK_ARCHIVE_PERIOD) + return pipe + + await redis.execute(redis_producer, _pipe_builder) + + task = asyncio.create_task(self._wrapper_task(func, task_id, name)) + self.ongoing_tasks.add(task) + return task_id + + async def _wrapper_task( + self, + func: BackgroundTask, + task_id: uuid.UUID, + task_name: Optional[str], + ) -> None: + task_result: TaskResult + reporter = ProgressReporter(self.event_producer, task_id) + message = '' + event_cls: Type[BgtaskDoneEvent] | Type[BgtaskCancelledEvent] | Type[BgtaskFailedEvent] = \ + BgtaskDoneEvent + try: + message = await func(reporter) or '' + task_result = 'bgtask_done' + except asyncio.CancelledError: + task_result = 'bgtask_cancelled' + event_cls = BgtaskCancelledEvent + except Exception as e: + task_result = 'bgtask_failed' + event_cls = BgtaskFailedEvent + message = repr(e) + log.exception("Task {} ({}): unhandled error", task_id, task_name) + finally: + redis_producer = self.event_producer.redis_client + + async def _pipe_builder(r: aioredis.Redis): + pipe = r.pipeline() + tracker_key = f'bgtask.{task_id}' + pipe.hset(tracker_key, mapping={ + 'status': task_result[7:], # strip "bgtask_" + 'msg': message, + 'last_update': str(time.time()), + }) + pipe.expire(tracker_key, MAX_BGTASK_ARCHIVE_PERIOD) + await pipe.execute() + + await redis.execute(redis_producer, _pipe_builder) + await self.event_producer.produce_event( + event_cls( + task_id, + message=message, + ), + ) + log.info('Task {} ({}): {}', task_id, task_name or '', task_result) + + async def shutdown(self) -> None: + join_tasks = [] + log.info('Cancelling remaining background tasks...') + for task in self.ongoing_tasks.copy(): + if task.done(): + continue + try: + task.cancel() + await task + except asyncio.CancelledError: + pass + for tq in self.task_update_queues: + tq.put_nowait(sentinel) + join_tasks.append(tq.join()) + await asyncio.gather(*join_tasks) diff --git a/src/ai/backend/common/cli.py b/src/ai/backend/common/cli.py new file mode 100644 index 0000000000..5c573ab47a --- /dev/null +++ b/src/ai/backend/common/cli.py @@ -0,0 +1,101 @@ +from decimal import Decimal +from enum import Enum +import functools +from importlib import 
import_module +import re +from types import FunctionType +from typing import Any, Optional, Union, Type + +import click + + +def wrap_method(method): + @functools.wraps(method) + def wrapped(self, *args, **kwargs): + return method(self._impl, *args, **kwargs) + return wrapped + + +class LazyClickMixin: + ''' + Click's documentations says "supports lazy loading of subcommands at runtime", + but there is no actual examples and how-tos as indicated by the issue: + https://github.com/pallets/click/issues/945 + + This class fills the gap by binding the methods of original Click classes to + a wrapper that lazily loads the underlying Click object. + ''' + + _import_name: str + _loaded_impl: Optional[Union[click.Command, click.Group]] + + def __init__(self, *, import_name, **kwargs): + self._import_name = import_name + self._loaded_impl = None + super().__init__(**kwargs) + for key, val in vars(type(self).__mro__[2]).items(): + if key.startswith('__'): + continue + if isinstance(val, FunctionType): + setattr(self, key, wrap_method(val).__get__(self, self.__class__)) + + @property + def _impl(self): + if self._loaded_impl: + return self._loaded_impl + # Load when first invoked. + module, name = self._import_name.split(':', 1) + self._loaded_impl = getattr(import_module(module), name) + return self._loaded_impl + + +class LazyGroup(LazyClickMixin, click.Group): + pass + + +class EnumChoice(click.Choice): + + enum: Type[Enum] + + def __init__(self, enum: Type[Enum]): + enum_members = [e.name for e in enum] + super().__init__(enum_members) + self.enum = enum + + def convert(self, value: Any, param, ctx): + if isinstance(value, self.enum): + # for default value, it is already the enum type. + return next(e for e in self.enum if e == value) + value = super().convert(value, param, ctx) + return next(k for k in self.enum.__members__.keys() if k == value) + + def get_metavar(self, param): + name = self.enum.__name__ + name = re.sub(r"([A-Z\d]+)([A-Z][a-z])", r'\1_\2', name) + name = re.sub(r"([a-z\d])([A-Z])", r'\1_\2', name) + return name.upper() + + +class MinMaxRangeParamType(click.ParamType): + name = "min-max decimal range" + + def convert(self, value, param, ctx): + try: + left, _, right = value.partition(':') + if left: + left = Decimal(left) + else: + left = None + if right: + right = Decimal(right) + else: + right = None + return left, right + except (ArithmeticError, ValueError): + self.fail(f"{value!r} contains an invalid number", param, ctx) + + def get_metavar(self, param): + return 'MIN:MAX' + + +MinMaxRange = MinMaxRangeParamType() diff --git a/src/ai/backend/common/config.py b/src/ai/backend/common/config.py new file mode 100644 index 0000000000..791230e7e0 --- /dev/null +++ b/src/ai/backend/common/config.py @@ -0,0 +1,163 @@ +from __future__ import annotations + +import os +from pathlib import Path +import sys +from typing import ( + Any, Optional, Union, + Dict, Mapping, MutableMapping, + Tuple, + cast, +) + +import toml +from toml.decoder import InlineTableDict +import trafaret as t + +from . 
import validators as tx +from .etcd import AsyncEtcd, ConfigScopes +from .exception import ConfigurationError + +__all__ = ( + 'ConfigurationError', + 'etcd_config_iv', + 'redis_config_iv', + 'vfolder_config_iv', + 'read_from_file', + 'read_from_etcd', + 'override_key', + 'override_with_env', + 'check', + 'merge', +) + + +etcd_config_iv = t.Dict({ + t.Key('etcd'): t.Dict({ + t.Key('namespace'): t.String, + t.Key('addr', ('127.0.0.1', 2379)): tx.HostPortPair, + t.Key('user', default=''): t.Null | t.String(allow_blank=True), + t.Key('password', default=''): t.Null | t.String(allow_blank=True), + }).allow_extra('*'), +}).allow_extra('*') + +redis_config_iv = t.Dict({ + t.Key('addr', default=('127.0.0.1', 6379)): tx.HostPortPair, + t.Key('password', default=None): t.Null | t.String, +}).allow_extra('*') + +vfolder_config_iv = t.Dict({ + tx.AliasedKey(['mount', '_mount'], default=None): t.Null | tx.Path(type='dir'), + tx.AliasedKey(['fsprefix', '_fsprefix'], default=''): + tx.Path(type='dir', resolve=False, relative_only=True, allow_nonexisting=True), +}).allow_extra('*') + + +def find_config_file(daemon_name: str) -> Path: + toml_path_from_env = os.environ.get('BACKEND_CONFIG_FILE', None) + if not toml_path_from_env: + toml_paths = [ + Path.cwd() / f'{daemon_name}.toml', + ] + if sys.platform.startswith('linux') or sys.platform.startswith('darwin'): + toml_paths += [ + Path.home() / '.config' / 'backend.ai' / f'{daemon_name}.toml', + Path(f'/etc/backend.ai/{daemon_name}.toml'), + ] + else: + raise ConfigurationError({ + 'read_from_file()': f"Unsupported platform for config path auto-discovery: {sys.platform}", + }) + else: + toml_paths = [Path(toml_path_from_env)] + for _path in toml_paths: + if _path.is_file(): + return _path + else: + searched_paths = ','.join(map(str, toml_paths)) + raise ConfigurationError({ + 'find_config_file()': f"Could not read config from: {searched_paths}", + }) + + +def read_from_file(toml_path: Optional[Union[Path, str]], daemon_name: str) -> Tuple[Dict[str, Any], Path]: + config: Dict[str, Any] + discovered_path: Path + if toml_path is None: + discovered_path = find_config_file(daemon_name) + else: + discovered_path = Path(toml_path) + try: + config = cast(Dict[str, Any], toml.loads(discovered_path.read_text())) + config = _sanitize_inline_dicts(config) + except IOError: + raise ConfigurationError({ + 'read_from_file()': f"Could not read config from: {discovered_path}", + }) + else: + return config, discovered_path + + +async def read_from_etcd(etcd_config: Mapping[str, Any], + scope_prefix_map: Mapping[ConfigScopes, str]) \ + -> Optional[Dict[str, Any]]: + etcd = AsyncEtcd(etcd_config['addr'], etcd_config['namespace'], scope_prefix_map) + raw_value = await etcd.get('daemon/config') + if raw_value is None: + return None + config: Dict[str, Any] + config = cast(Dict[str, Any], toml.loads(raw_value)) + config = _sanitize_inline_dicts(config) + return config + + +def override_key(table: MutableMapping[str, Any], key_path: Tuple[str, ...], value: Any): + for k in key_path[:-1]: + if k not in table: + table[k] = {} + table = table[k] + table[key_path[-1]] = value + + +def override_with_env(table: MutableMapping[str, Any], key_path: Tuple[str, ...], env_key: str): + val = os.environ.get(env_key, None) + if val is None: + return + override_key(table, key_path, val) + + +def check(table: Any, iv: t.Trafaret): + try: + config = iv.check(table) + except t.DataError as e: + raise ConfigurationError(e.as_dict()) + else: + return config + + +def merge(table: Mapping[str, Any], 
updates: Mapping[str, Any]) -> Mapping[str, Any]: + result = {**table} + for k, v in updates.items(): + if isinstance(v, Mapping): + orig = result.get(k, {}) + assert isinstance(orig, Mapping) + result[k] = merge(orig, v) + else: + result[k] = v + return result + + +def _sanitize_inline_dicts(table: Dict[str, Any] | InlineTableDict) -> Dict[str, Any]: + result: Dict[str, Any] = {} + # Due to the way of toml.decoder to use Python class hierarchy to annotate + # inline or non-inline tables of TOML, we need to skip type checking here. + for k, v in table.items(): # type: ignore + if isinstance(v, InlineTableDict): + # Since this function always returns a copied dict, + # this automatically converts InlineTableDict to dict. + result[k] = _sanitize_inline_dicts(cast(Dict[str, Any], v)) + elif isinstance(v, Dict): + result[k] = _sanitize_inline_dicts(v) + else: + result[k] = v + return result diff --git a/src/ai/backend/common/distributed.py b/src/ai/backend/common/distributed.py new file mode 100644 index 0000000000..b13150e19d --- /dev/null +++ b/src/ai/backend/common/distributed.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +import asyncio +import logging +from typing import ( + Callable, + Final, + TYPE_CHECKING, +) + +from .logging import BraceStyleAdapter + +if TYPE_CHECKING: + from .events import AbstractEvent, EventProducer + from .lock import AbstractDistributedLock + + +log = BraceStyleAdapter(logging.getLogger(__name__)) + + +class GlobalTimer: + + """ + Executes the given async function only once in the given interval, + uniquely among multiple manager instances across multiple nodes. + """ + + _event_producer: Final[EventProducer] + + def __init__( + self, + dist_lock: AbstractDistributedLock, + event_producer: EventProducer, + event_factory: Callable[[], AbstractEvent], + interval: float = 10.0, + initial_delay: float = 0.0, + ) -> None: + self._dist_lock = dist_lock + self._event_producer = event_producer + self._event_factory = event_factory + self._stopped = False + self.interval = interval + self.initial_delay = initial_delay + + async def generate_tick(self) -> None: + try: + await asyncio.sleep(self.initial_delay) + if self._stopped: + return + while True: + try: + async with self._dist_lock: + if self._stopped: + return + await self._event_producer.produce_event(self._event_factory()) + if self._stopped: + return + await asyncio.sleep(self.interval) + except asyncio.TimeoutError: # timeout raised from etcd lock + log.warn('timeout raised while trying to acquire lock. 
retrying...') + except asyncio.CancelledError: + pass + + async def join(self) -> None: + self._tick_task = asyncio.create_task(self.generate_tick()) + + async def leave(self) -> None: + self._stopped = True + await asyncio.sleep(0) + if not self._tick_task.done(): + try: + self._tick_task.cancel() + await self._tick_task + except asyncio.CancelledError: + pass diff --git a/src/ai/backend/common/docker.py b/src/ai/backend/common/docker.py new file mode 100644 index 0000000000..92065d7e30 --- /dev/null +++ b/src/ai/backend/common/docker.py @@ -0,0 +1,399 @@ +import ipaddress +import itertools +import json +import logging +from packaging import version +import re +from typing import ( + Any, Final, Optional, Union, + Dict, Mapping, + Iterable, + Tuple, Sequence, + MutableMapping, +) + +import aiohttp +import yarl + +from .logging import BraceStyleAdapter +from .etcd import ( + AsyncEtcd, + quote as etcd_quote, + unquote as etcd_unquote, +) +from .exception import UnknownImageRegistry + +__all__ = ( + 'arch_name_aliases', + 'default_registry', + 'default_repository', + 'docker_api_arch_aliases', + 'login', + 'get_known_registries', + 'is_known_registry', + 'get_registry_info', + 'MIN_KERNELSPEC', + 'MAX_KERNELSPEC', + 'ImageRef', +) + +arch_name_aliases: Final[Mapping[str, str]] = { + "arm64": "aarch64", # macOS with LLVM + "amd64": "x86_64", # Windows/Linux + "x64": "x86_64", # Windows + "x32": "x86", # Windows + "i686": "x86", # Windows +} +# generalize architecture symbols to match docker API's norm +docker_api_arch_aliases: Final[Mapping[str, str]] = { + 'aarch64': 'arm64', + 'arm64': 'arm64', + 'x86_64': 'amd64', + 'x64': 'amd64', + 'amd64': 'amd64', + 'x86': '386', + 'x32': '386', + 'i686': '386', + '386': '386', +} + +log = BraceStyleAdapter(logging.Logger('ai.backend.common.docker')) + +default_registry = 'index.docker.io' +default_repository = 'lablup' + +MIN_KERNELSPEC = 1 +MAX_KERNELSPEC = 1 + + +async def login( + sess: aiohttp.ClientSession, + registry_url: yarl.URL, + credentials: dict, + scope: str) -> dict: + """ + Authorize to the docker registry using the given credentials and token scope, and returns a set + of required aiohttp.ClientSession.request() keyword arguments for further API requests. + + Some registry servers only rely on HTTP Basic Authentication without token-based access controls + (usually via nginx proxy). We do support them also. :) + """ + basic_auth: Optional[aiohttp.BasicAuth] + + if credentials.get('username') and credentials.get('password'): + basic_auth = aiohttp.BasicAuth( + credentials['username'], credentials['password'], + ) + else: + basic_auth = None + realm = registry_url / 'token' # fallback + service = 'registry' # fallback + async with sess.get(registry_url / 'v2/', auth=basic_auth) as resp: + ping_status = resp.status + www_auth_header = resp.headers.get('WWW-Authenticate') + if www_auth_header: + match = re.search(r'realm="([^"]+)"', www_auth_header) + if match: + realm = yarl.URL(match.group(1)) + match = re.search(r'service="([^"]+)"', www_auth_header) + if match: + service = match.group(1) + if ping_status == 200: + log.debug('docker-registry: {0} -> basic-auth', registry_url) + return {'auth': basic_auth, 'headers': {}} + elif ping_status == 404: + raise RuntimeError(f'Unsupported docker registry: {registry_url}! 
' + '(API v2 not implemented)') + elif ping_status == 401: + params = { + 'scope': scope, + 'offline_token': 'true', + 'client_id': 'docker', + 'service': service, + } + async with sess.get(realm, params=params, auth=basic_auth) as resp: + log.debug('docker-registry: {0} -> {1}', registry_url, realm) + if resp.status == 200: + data = json.loads(await resp.read()) + token = data.get('token', None) + return {'auth': None, 'headers': { + 'Authorization': f'Bearer {token}', + }} + raise RuntimeError('authentication for docker registry ' + f'{registry_url} failed') + + +async def get_known_registries(etcd: AsyncEtcd) -> Mapping[str, yarl.URL]: + data = await etcd.get_prefix('config/docker/registry/') + results: MutableMapping[str, yarl.URL] = {} + for key, value in data.items(): + name = etcd_unquote(key) + if isinstance(value, str): + results[name] = yarl.URL(value) + elif isinstance(value, Mapping): + results[name] = yarl.URL(value['']) + return results + + +def is_known_registry(val: str, + known_registries: Union[Mapping[str, Any], Sequence[str]] = None): + if val == default_registry: + return True + if known_registries is not None and val in known_registries: + return True + try: + url = yarl.URL('//' + val) + if url.host and ipaddress.ip_address(url.host): + return True + except ValueError: + pass + return False + + +async def get_registry_info(etcd: AsyncEtcd, name: str) -> Tuple[yarl.URL, dict]: + reg_path = f'config/docker/registry/{etcd_quote(name)}' + item = await etcd.get_prefix(reg_path) + if not item: + raise UnknownImageRegistry(name) + registry_addr = item[''] + if not registry_addr: + raise UnknownImageRegistry(name) + creds = {} + username = item.get('username') + if username is not None: + creds['username'] = username + password = item.get('password') + if password is not None: + creds['password'] = password + return yarl.URL(registry_addr), creds + + +class PlatformTagSet(Mapping): + + __slots__ = ('_data', ) + _data: Dict[str, str] + _rx_ver = re.compile(r'^(?P[a-zA-Z]+)(?P\d+(?:\.\d+)*[a-z0-9]*)?$') + + def __init__(self, tags: Iterable[str]): + self._data = dict() + rx = type(self)._rx_ver + for t in tags: + match = rx.search(t) + if match is None: + raise ValueError('invalid tag-version string', t) + key = match.group('tag') + value = match.group('version') + if key in self._data: + raise ValueError('duplicate platform tag with different versions', t) + if value is None: + value = '' + self._data[key] = value + + def has(self, key: str, version: str = None): + if version is None: + return key in self._data + _v = self._data.get(key, None) + return _v == version + + def __getitem__(self, key: str): + return self._data[key] + + def __iter__(self): + return iter(self._data) + + def __len__(self): + return len(self._data) + + def __eq__(self, other): + if isinstance(other, (set, frozenset)): + return set(self._data.keys()) == other + return self._data == other + + +class ImageRef: + """ + Class to represent image reference. + passing ['*'] to `known_registries` when creating object + will allow any repository on canonical string. 
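+
+    A short parsing example (the image name is an arbitrary sample):
+
+    .. code-block:: python3
+
+       ref = ImageRef('python:3.6-ubuntu18.04')
+       ref.canonical   # 'index.docker.io/lablup/python:3.6-ubuntu18.04'
+       ref.short       # 'lablup/python:3.6-ubuntu18.04'
+       ref.tag_set[0]  # '3.6'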
+ """ + __slots__ = ('_registry', '_name', '_tag', '_arch', '_tag_set', '_sha') + + _rx_slug = re.compile(r'^[A-Za-z0-9](?:[A-Za-z0-9-._]*[A-Za-z0-9])?$') + + def __init__( + self, + value: str, + known_registries: Union[Mapping[str, Any], Sequence[str]] = None, + architecture='x86_64', + ): + self._arch = arch_name_aliases.get(architecture, architecture) + rx_slug = type(self)._rx_slug + if '://' in value or value.startswith('//'): + raise ValueError('ImageRef should not contain the protocol scheme.') + parts = value.split('/', maxsplit=1) + if len(parts) == 1: + self._registry = default_registry + self._name, self._tag = ImageRef._parse_image_tag(value, True) + if not rx_slug.search(self._tag): + raise ValueError('Invalid image tag') + else: + if is_known_registry(parts[0], known_registries): + self._registry = parts[0] + using_default = (parts[0].endswith('.docker.io') or parts[0] == 'docker.io') + self._name, self._tag = ImageRef._parse_image_tag(parts[1], using_default) + # add ['*'] as magic keyword to accept any repository as valid repo + elif known_registries == ['*']: + self._registry = parts[0] + self._name, self._tag = ImageRef._parse_image_tag(parts[1], False) + else: + self._registry = default_registry + self._name, self._tag = ImageRef._parse_image_tag(value, True) + if not rx_slug.search(self._tag): + raise ValueError('Invalid image tag') + self._update_tag_set() + + @staticmethod + def _parse_image_tag(s: str, using_default_registry: bool = False) -> Tuple[str, str]: + image_tag = s.rsplit(':', maxsplit=1) + if len(image_tag) == 1: + image = image_tag[0] + tag = 'latest' + else: + image = image_tag[0] + tag = image_tag[1] + if not image: + raise ValueError('Empty image repository/name') + if ('/' not in image) and using_default_registry: + image = default_repository + '/' + image + return image, tag + + def _update_tag_set(self): + if self._tag is None: + self._tag_set = (None, PlatformTagSet([])) + return + tags = self._tag.split('-') + self._tag_set = (tags[0], PlatformTagSet(tags[1:])) + + def generate_aliases(self) -> Mapping[str, 'ImageRef']: + basename = self.name.split('/')[-1] + possible_names = basename.rsplit('-') + if len(possible_names) > 1: + possible_names = [basename, possible_names[1]] + + possible_ptags = [] + tag_set = self.tag_set + if not tag_set[0]: + pass + else: + possible_ptags.append([tag_set[0]]) + for tag_key in tag_set[1]: + tag_ver = tag_set[1][tag_key] + tag_list = ['', tag_key, tag_key + tag_ver] + if '.' 
in tag_ver:
+                tag_list.append(tag_key + tag_ver.rsplit('.')[0])
+            elif tag_key == 'py' and len(tag_ver) > 1:
+                tag_list.append(tag_key + tag_ver[0])
+            if 'cuda' in tag_key:
+                tag_list.append('gpu')
+            possible_ptags.append(tag_list)
+
+        ret = {}
+        for name in possible_names:
+            ret[name] = self
+        for name, ptags in itertools.product(
+                possible_names,
+                itertools.product(*possible_ptags)):
+            ret[f"{name}:{'-'.join(t for t in ptags if t)}"] = self
+        return ret
+
+    @staticmethod
+    def merge_aliases(genned_aliases_1, genned_aliases_2) -> Mapping[str, 'ImageRef']:
+        ret = {}
+        aliases_set_1, aliases_set_2 = set(genned_aliases_1.keys()), set(genned_aliases_2.keys())
+        aliases_dup = aliases_set_1 & aliases_set_2
+
+        for alias in aliases_dup:
+            ret[alias] = max(genned_aliases_1[alias], genned_aliases_2[alias])
+
+        for alias in aliases_set_1 - aliases_dup:
+            ret[alias] = genned_aliases_1[alias]
+        for alias in aliases_set_2 - aliases_dup:
+            ret[alias] = genned_aliases_2[alias]
+
+        return ret
+
+    @property
+    def canonical(self) -> str:
+        # e.g., index.docker.io/lablup/python:3.6-ubuntu
+        return f'{self.registry}/{self.name}:{self.tag}'
+
+    @property
+    def registry(self) -> str:
+        # e.g., index.docker.io
+        return self._registry
+
+    @property
+    def name(self) -> str:
+        # e.g., lablup/python
+        return self._name
+
+    @property
+    def tag(self) -> str:
+        # e.g., 3.6-ubuntu
+        return self._tag
+
+    @property
+    def architecture(self) -> str:
+        # e.g., aarch64
+        return self._arch
+
+    @property
+    def tag_set(self) -> Tuple[str, PlatformTagSet]:
+        # e.g., ('3.6', {'ubuntu', 'cuda', ...})
+        return self._tag_set
+
+    @property
+    def short(self) -> str:
+        """
+        Returns the image reference string without the registry part.
+        """
+        # e.g., lablup/python:3.6-ubuntu
+        return f'{self.name}:{self.tag}' if self.tag is not None else self.name
+
+    def __str__(self) -> str:
+        return self.canonical
+
+    def __repr__(self) -> str:
+        return f'<ImageRef: "{self.canonical}" ({self.architecture})>'
+
+    def __hash__(self) -> int:
+        return hash((self._name, self._tag, self._registry, self._arch))
+
+    def __eq__(self, other) -> bool:
+        return (self._registry == other._registry and
+                self._name == other._name and
+                self._tag == other._tag and
+                self._arch == other._arch)
+
+    def __ne__(self, other) -> bool:
+        return (self._registry != other._registry or
+                self._name != other._name or
+                self._tag != other._tag or
+                self._arch != other._arch)
+
+    def __lt__(self, other) -> bool:
+        if self == other:   # resolve the equal case via __eq__ first
+            return False
+        if self.name != other.name:
+            raise ValueError('only image-refs with the same names can be compared.')
+        if self.tag_set[0] != other.tag_set[0]:
+            return version.parse(self.tag_set[0]) < version.parse(other.tag_set[0])
+        ptagset_self, ptagset_other = self.tag_set[1], other.tag_set[1]
+        for key_self in ptagset_self:
+            if ptagset_other.has(key_self):
+                version_self, version_other = ptagset_self.get(key_self), ptagset_other.get(key_self)
+                if version_self and version_other:
+                    parsed_version_self, parsed_version_other = version.parse(version_self), version.parse(version_other)
+                    if parsed_version_self != parsed_version_other:
+                        return parsed_version_self < parsed_version_other
+        return len(ptagset_self) > len(ptagset_other)
diff --git a/src/ai/backend/common/enum_extension.py b/src/ai/backend/common/enum_extension.py
new file mode 100644
index 0000000000..fe361779be
--- /dev/null
+++ b/src/ai/backend/common/enum_extension.py
@@ -0,0 +1,57 @@
+from __future__ import annotations
+
+import enum
+
+__all__ = (
+    'StringSetFlag',
+)
+
+
+class StringSetFlag(enum.Flag):
+    """
+    A flag enum whose members compare and combine by their string values,
+    so that they can be mixed freely with plain strings and string sets.
+    """
+
+    def __eq__(self, other):
+        return self.value == other
+
+    def __hash__(self):
+        return hash(self.value)
+
+    def __or__(self, other):
+        if isinstance(other, type(self)):
+            other = other.value
+        if not isinstance(other, (set, frozenset)):
+            other = set((other,))
+        return set((self.value,)) | other
+
+    __ror__ = __or__
+
+    def __and__(self, other):
+        if isinstance(other, (set, frozenset)):
+            return self.value in other
+        if isinstance(other, str):
+            return self.value == other
+        raise TypeError
+
+    __rand__ = __and__
+
+    def __xor__(self, other):
+        if isinstance(other, (set, frozenset)):
+            return set((self.value,)) ^ other
+        if isinstance(other, str):
+            if other == self.value:
+                return set()
+            else:
+                return other
+        raise TypeError
+
+    def __rxor__(self, other):
+        if isinstance(other, (set, frozenset)):
+            return other ^ set((self.value,))
+        if isinstance(other, str):
+            if other == self.value:
+                return set()
+            else:
+                return other
+        raise TypeError
+
+    def __str__(self):
+        return self.value
diff --git a/src/ai/backend/common/enum_extension.pyi b/src/ai/backend/common/enum_extension.pyi
new file mode 100644
index 0000000000..0f4ddb125b
--- /dev/null
+++ b/src/ai/backend/common/enum_extension.pyi
@@ -0,0 +1,22 @@
+import enum
+
+
+class StringSetFlag(enum.Flag):
+    def __eq__(self, other: object) -> bool: ...
+    def __hash__(self) -> int: ...
+    def __or__(  # type: ignore[override]
+        self,
+        other: StringSetFlag | str | set[str] | frozenset[str],
+    ) -> set[str]: ...
+    def __and__(  # type: ignore[override]
+        self,
+        other: StringSetFlag | str | set[str] | frozenset[str],
+    ) -> bool: ...
+    def __xor__(  # type: ignore[override]
+        self,
+        other: StringSetFlag | str | set[str] | frozenset[str],
+    ) -> set[str]: ...
+    def __ror__(self, other: StringSetFlag | str | set[str] | frozenset[str]) -> set[str]: ...
+    def __rand__(self, other: StringSetFlag | str | set[str] | frozenset[str]) -> bool: ...
+    def __rxor__(self, other: StringSetFlag | str | set[str] | frozenset[str]) -> set[str]: ...
+    def __str__(self) -> str: ...
diff --git a/src/ai/backend/common/etcd.py b/src/ai/backend/common/etcd.py
new file mode 100644
index 0000000000..dd39cc1cae
--- /dev/null
+++ b/src/ai/backend/common/etcd.py
@@ -0,0 +1,482 @@
+'''
+An asynchronous client wrapper for the etcd v3 API.
+
+It is built on top of the etcetra library, a pure-asyncio etcd client, so key-value
+operations and watchers run on the event loop without a thread pool executor.
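+
+A minimal usage sketch (the address, namespace, scope prefixes, and keys below
+are made-up examples; ``addr`` is a HostPortPair):
+
+.. code-block::
+
+   etcd = AsyncEtcd(addr, 'local', {ConfigScopes.GLOBAL: ''})
+   await etcd.put('config/redis/addr', '127.0.0.1:6379')
+   value = await etcd.get('config/redis/addr')  # scope defaults to MERGED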
+''' + +from __future__ import annotations + +import asyncio +from collections import namedtuple, ChainMap +import enum +import functools +import logging +from typing import ( + AsyncGenerator, + Dict, + Iterable, + List, + Mapping, + MutableMapping, + Optional, + ParamSpec, + Tuple, + TypeVar, + Union, + cast, +) +from urllib.parse import quote as _quote, unquote + +from etcetra.client import ( + EtcdClient, EtcdTransactionAction, +) +from etcetra.types import ( + CompareKey, EtcdCredential, HostPortPair as EtcetraHostPortPair, +) +import trafaret as t + +from .logging_utils import BraceStyleAdapter +from .types import HostPortPair, QueueSentinel + +__all__ = ( + 'quote', 'unquote', + 'AsyncEtcd', +) + +Event = namedtuple('Event', 'key event value') + +log = BraceStyleAdapter(logging.getLogger(__name__)) + + +class ConfigScopes(enum.Enum): + MERGED = 0 + GLOBAL = 1 + SGROUP = 2 + NODE = 3 + + +quote = functools.partial(_quote, safe='') + + +def make_dict_from_pairs(key_prefix, pairs, path_sep='/'): + result = {} + len_prefix = len(key_prefix) + if isinstance(pairs, dict): + iterator = pairs.items() + else: + iterator = pairs + for k, v in iterator: + if not k.startswith(key_prefix): + continue + subkey = k[len_prefix:] + if subkey.startswith(path_sep): + subkey = subkey[1:] + path_components = subkey.split('/') + parent = result + for p in path_components[:-1]: + p = unquote(p) + if p not in parent: + parent[p] = {} + if p in parent and not isinstance(parent[p], dict): + root = parent[p] + parent[p] = {'': root} + parent = parent[p] + parent[unquote(path_components[-1])] = v + return result + + +def _slash(v: str): + return v.rstrip('/') + '/' if len(v) > 0 else '' + + +P = ParamSpec("P") +R = TypeVar("R") + + +class AsyncEtcd: + + etcd: EtcdClient + + _creds: Optional[EtcdCredential] + + def __init__( + self, + addr: HostPortPair | EtcetraHostPortPair, + namespace: str, + scope_prefix_map: Mapping[ConfigScopes, str], + *, + credentials=None, + encoding='utf-8', + ) -> None: + self.scope_prefix_map = t.Dict({ + t.Key(ConfigScopes.GLOBAL): t.String(allow_blank=True), + t.Key(ConfigScopes.SGROUP, optional=True): t.String, + t.Key(ConfigScopes.NODE, optional=True): t.String, + }).check(scope_prefix_map) + if credentials is not None: + self._creds = EtcdCredential(credentials.get('user'), credentials.get('password')) + else: + self._creds = None + + self.ns = namespace + log.info('using etcd cluster from {} with namespace "{}"', addr, namespace) + self.encoding = encoding + self.etcd = EtcdClient( + EtcetraHostPortPair(str(addr.host), addr.port), + credentials=self._creds, + encoding=self.encoding, + ) + + async def close(self): + pass # for backward compatibility + + def _mangle_key(self, k: str) -> str: + if k.startswith('/'): + k = k[1:] + return f'/sorna/{self.ns}/{k}' + + def _demangle_key(self, k: Union[bytes, str]) -> str: + if isinstance(k, bytes): + k = k.decode(self.encoding) + prefix = f'/sorna/{self.ns}/' + if k.startswith(prefix): + k = k[len(prefix):] + return k + + def _merge_scope_prefix_map( + self, + override: Mapping[ConfigScopes, str] = None, + ) -> Mapping[ConfigScopes, str]: + """ + This stub ensures immutable usage of the ChainMap because ChainMap does *not* + have the immutable version in typeshed. 
+ (ref: https://github.com/python/typeshed/issues/6042) + """ + return ChainMap(cast(MutableMapping, override) or {}, self.scope_prefix_map) + + async def put( + self, + key: str, + val: str, + *, + scope: ConfigScopes = ConfigScopes.GLOBAL, + scope_prefix_map: Mapping[ConfigScopes, str] = None, + ): + """ + Put a single key-value pair to the etcd. + + :param key: The key. This must be quoted by the caller as needed. + :param val: The value. + :param scope: The config scope for putting the values. + :param scope_prefix_map: The scope map used to mangle the prefix for the config scope. + :return: + """ + scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope] + mangled_key = self._mangle_key(f'{_slash(scope_prefix)}{key}') + async with self.etcd.connect() as communicator: + await communicator.put(mangled_key, str(val)) + + async def put_prefix( + self, + key: str, + dict_obj: Mapping[str, str], + *, + scope: ConfigScopes = ConfigScopes.GLOBAL, + scope_prefix_map: Mapping[ConfigScopes, str] = None, + ): + """ + Put a nested dict object under the given key prefix. + All keys in the dict object are automatically quoted to avoid conflicts with the path separator. + + :param key: Prefix to put the given data. This must be quoted by the caller as needed. + :param dict_obj: Nested dictionary representing the data. + :param scope: The config scope for putting the values. + :param scope_prefix_map: The scope map used to mangle the prefix for the config scope. + :return: + """ + scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope] + flattened_dict: Dict[str, str] = {} + + def _flatten(prefix: str, inner_dict: Mapping[str, str]) -> None: + for k, v in inner_dict.items(): + if k == '': + flattened_key = prefix + else: + flattened_key = prefix + '/' + quote(k) + if isinstance(v, dict): + _flatten(flattened_key, v) + else: + flattened_dict[flattened_key] = v + + _flatten(key, dict_obj) + + def _txn(action: EtcdTransactionAction): + for k, v in flattened_dict.items(): + action.put(self._mangle_key(f'{_slash(scope_prefix)}{k}'), str(v)) + + async with self.etcd.connect() as communicator: + await communicator.txn(_txn) + + async def put_dict( + self, + dict_obj: Mapping[str, str], + *, + scope: ConfigScopes = ConfigScopes.GLOBAL, + scope_prefix_map: Mapping[ConfigScopes, str] = None, + ): + """ + Put a flattened key-value pairs into the etcd. + Since the given dict must be a flattened one, its keys must be quoted as needed by the caller. + For new codes, ``put_prefix()`` is recommended. + + :param dict_obj: Flattened key-value pairs to put. + :param scope: The config scope for putting the values. + :param scope_prefix_map: The scope map used to mangle the prefix for the config scope. + :return: + """ + scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope] + + def _pipe(txn: EtcdTransactionAction): + for k, v in dict_obj.items(): + txn.put(self._mangle_key(f'{_slash(scope_prefix)}{k}'), str(v)) + + async with self.etcd.connect() as communicator: + await communicator.txn(_pipe) + + async def get( + self, + key: str, + *, + scope: ConfigScopes = ConfigScopes.MERGED, + scope_prefix_map: Mapping[ConfigScopes, str] = None, + ) -> Optional[str]: + """ + Get a single key from the etcd. + Returns ``None`` if the key does not exist. + The returned value may be an empty string if the value is a zero-length string. + + :param key: The key. This must be quoted by the caller as needed. + :param scope: The config scope to get the value. 
+ :param scope_prefix_map: The scope map used to mangle the prefix for the config scope. + :return: + """ + + _scope_prefix_map = self._merge_scope_prefix_map(scope_prefix_map) + if scope == ConfigScopes.MERGED or scope == ConfigScopes.NODE: + scope_prefixes = [_scope_prefix_map[ConfigScopes.GLOBAL]] + p = _scope_prefix_map.get(ConfigScopes.SGROUP) + if p is not None: + scope_prefixes.insert(0, p) + p = _scope_prefix_map.get(ConfigScopes.NODE) + if p is not None: + scope_prefixes.insert(0, p) + elif scope == ConfigScopes.SGROUP: + scope_prefixes = [_scope_prefix_map[ConfigScopes.GLOBAL]] + p = _scope_prefix_map.get(ConfigScopes.SGROUP) + if p is not None: + scope_prefixes.insert(0, p) + elif scope == ConfigScopes.GLOBAL: + scope_prefixes = [_scope_prefix_map[ConfigScopes.GLOBAL]] + else: + raise ValueError('Invalid scope prefix value') + + async with self.etcd.connect() as communicator: + for scope_prefix in scope_prefixes: + value = await communicator.get(self._mangle_key(f'{_slash(scope_prefix)}{key}')) + if value is not None: + return value + return None + + async def get_prefix( + self, + key_prefix: str, + *, + scope: ConfigScopes = ConfigScopes.MERGED, + scope_prefix_map: Mapping[ConfigScopes, str] = None, + ) -> Mapping[str, Optional[str]]: + """ + Retrieves all key-value pairs under the given key prefix as a nested dictionary. + All dictionary keys are automatically unquoted. + If a key has a value while it is also used as path prefix for other keys, + the value directly referenced by the key itself is included as a value in a dictionary + with the empty-string key. + + For instance, when the etcd database has the following key-value pairs: + + .. code-block:: + + myprefix/mydata = abc + myprefix/mydata/x = 1 + myprefix/mydata/y = 2 + myprefix/mykey = def + + ``get_prefix("myprefix")`` returns the following dictionary: + + .. code-block:: + + { + "mydata": { + "": "abc", + "x": "1", + "y": "2", + }, + "mykey": "def", + } + + :param key_prefix: The key. This must be quoted by the caller as needed. + :param scope: The config scope to get the value. + :param scope_prefix_map: The scope map used to mangle the prefix for the config scope. 
+ :return: + """ + + _scope_prefix_map = self._merge_scope_prefix_map(scope_prefix_map) + if scope == ConfigScopes.MERGED or scope == ConfigScopes.NODE: + scope_prefixes = [_scope_prefix_map[ConfigScopes.GLOBAL]] + p = _scope_prefix_map.get(ConfigScopes.SGROUP) + if p is not None: + scope_prefixes.insert(0, p) + p = _scope_prefix_map.get(ConfigScopes.NODE) + if p is not None: + scope_prefixes.insert(0, p) + elif scope == ConfigScopes.SGROUP: + scope_prefixes = [_scope_prefix_map[ConfigScopes.GLOBAL]] + p = _scope_prefix_map.get(ConfigScopes.SGROUP) + if p is not None: + scope_prefixes.insert(0, p) + elif scope == ConfigScopes.GLOBAL: + scope_prefixes = [_scope_prefix_map[ConfigScopes.GLOBAL]] + else: + raise ValueError('Invalid scope prefix value') + pair_sets: List[List[Mapping | Tuple]] = [] + async with self.etcd.connect() as communicator: + for scope_prefix in scope_prefixes: + mangled_key_prefix = self._mangle_key(f'{_slash(scope_prefix)}{key_prefix}') + values = await communicator.get_prefix(mangled_key_prefix) + pair_sets.append([(self._demangle_key(k), v) for k, v in values.items()]) + + configs = [ + make_dict_from_pairs(f'{_slash(scope_prefix)}{key_prefix}', pairs, '/') + for scope_prefix, pairs in zip(scope_prefixes, pair_sets) + ] + return ChainMap(*configs) + + # for legacy + get_prefix_dict = get_prefix + + async def replace( + self, + key: str, + initial_val: str, + new_val: str, + *, + scope: ConfigScopes = ConfigScopes.GLOBAL, + scope_prefix_map: Mapping[ConfigScopes, str] = None, + ) -> bool: + scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope] + mangled_key = self._mangle_key(f'{_slash(scope_prefix)}{key}') + + def _txn(success: EtcdTransactionAction, _): + success.put(mangled_key, new_val) + async with self.etcd.connect() as communicator: + _, success = await communicator.txn_compare([ + CompareKey(mangled_key).value == initial_val, + ], _txn) + return success + + async def delete( + self, + key: str, + *, + scope: ConfigScopes = ConfigScopes.GLOBAL, + scope_prefix_map: Mapping[ConfigScopes, str] = None, + ): + scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope] + mangled_key = self._mangle_key(f'{_slash(scope_prefix)}{key}') + async with self.etcd.connect() as communicator: + await communicator.delete(mangled_key) + + async def delete_multi( + self, + keys: Iterable[str], + *, + scope: ConfigScopes = ConfigScopes.GLOBAL, + scope_prefix_map: Mapping[ConfigScopes, str] = None, + ): + scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope] + async with self.etcd.connect() as communicator: + def _txn(action: EtcdTransactionAction): + for k in keys: + action.delete(self._mangle_key(f'{_slash(scope_prefix)}{k}')) + await communicator.txn(_txn) + + async def delete_prefix( + self, + key_prefix: str, + *, + scope: ConfigScopes = ConfigScopes.GLOBAL, + scope_prefix_map: Mapping[ConfigScopes, str] = None, + ): + scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope] + mangled_key_prefix = self._mangle_key(f'{_slash(scope_prefix)}{key_prefix}') + async with self.etcd.connect() as communicator: + await communicator.delete_prefix(mangled_key_prefix) + + async def watch( + self, key: str, *, + scope: ConfigScopes = ConfigScopes.GLOBAL, + scope_prefix_map: Mapping[ConfigScopes, str] = None, + once: bool = False, + ready_event: asyncio.Event = None, + cleanup_event: asyncio.Event = None, + wait_timeout: float = None, + ) -> AsyncGenerator[Union[QueueSentinel, Event], None]: + scope_prefix = 
self._merge_scope_prefix_map(scope_prefix_map)[scope] + scope_prefix_len = len(self._mangle_key(f'{_slash(scope_prefix)}')) + mangled_key = self._mangle_key(f'{_slash(scope_prefix)}{key}') + # NOTE: yield from in async-generator is not supported. + try: + async with self.etcd.connect() as communicator: + iterator = communicator.watch(mangled_key, ready_event=ready_event) + async for ev in iterator: + if wait_timeout is not None: + try: + ev = await asyncio.wait_for(iterator.__anext__(), wait_timeout) + except asyncio.TimeoutError: + pass + yield Event(ev.key[scope_prefix_len:], ev.event, ev.value) + if once: + return + finally: + if cleanup_event: + cleanup_event.set() + + async def watch_prefix( + self, key_prefix: str, *, + scope: ConfigScopes = ConfigScopes.GLOBAL, + scope_prefix_map: Mapping[ConfigScopes, str] = None, + once: bool = False, + ready_event: asyncio.Event = None, + cleanup_event: asyncio.Event = None, + wait_timeout: float = None, + ) -> AsyncGenerator[Union[QueueSentinel, Event], None]: + scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope] + scope_prefix_len = len(self._mangle_key(f'{_slash(scope_prefix)}')) + mangled_key_prefix = self._mangle_key(f'{_slash(scope_prefix)}{key_prefix}') + try: + async with self.etcd.connect() as communicator: + iterator = communicator.watch_prefix(mangled_key_prefix, ready_event=ready_event) + async for ev in iterator: + if wait_timeout is not None: + try: + ev = await asyncio.wait_for(iterator.__anext__(), wait_timeout) + except asyncio.TimeoutError: + pass + yield Event(ev.key[scope_prefix_len:], ev.event, ev.value) + if once: + return + finally: + if cleanup_event: + cleanup_event.set() diff --git a/src/ai/backend/common/events.py b/src/ai/backend/common/events.py new file mode 100644 index 0000000000..40162ac29a --- /dev/null +++ b/src/ai/backend/common/events.py @@ -0,0 +1,880 @@ +from __future__ import annotations + +import abc +import asyncio +from collections import defaultdict +import hashlib +import logging +import secrets +import socket +from typing import ( + Any, + Awaitable, + Callable, + ClassVar, + Coroutine, + Generic, + Mapping, + Optional, + Protocol, + Sequence, + Type, + TypeVar, + TypedDict, + Union, + cast, +) +from types import TracebackType +from typing_extensions import TypeAlias +import uuid + +import aioredis +import aioredis.exceptions +import aioredis.sentinel +from aiotools.context import aclosing +from aiotools.server import process_index +from aiotools.taskgroup import PersistentTaskGroup +import attr + +from . import msgpack, redis +from .logging import BraceStyleAdapter +from .types import ( + EtcdRedisConfig, + RedisConnectionInfo, + aobject, + AgentId, + KernelId, + SessionId, + LogSeverity, +) + +__all__ = ( + 'AbstractEvent', + 'EventCallback', + 'EventDispatcher', + 'EventHandler', + 'EventProducer', +) + +log = BraceStyleAdapter(logging.getLogger('ai.backend.common.events')) + +PTGExceptionHandler: TypeAlias = Callable[[Type[Exception], Exception, TracebackType], Awaitable[None]] + + +class AbstractEvent(metaclass=abc.ABCMeta): + + # derivatives shoudld define the fields. + + name: ClassVar[str] = "undefined" + + @abc.abstractmethod + def serialize(self) -> tuple: + """ + Return a msgpack-serializable tuple. + """ + pass + + @classmethod + @abc.abstractmethod + def deserialize(cls, value: tuple): + """ + Construct the event args from a tuple deserialized from msgpack. 
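+
+        A round-trip sketch using one of the concrete event classes defined
+        below in this module:
+
+        .. code-block::
+
+           ev = DoTerminateSessionEvent(SessionId(uuid.uuid4()), 'forced-termination')
+           assert DoTerminateSessionEvent.deserialize(ev.serialize()) == ev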
+ """ + pass + + +class EmptyEventArgs(): + + def serialize(self) -> tuple: + return tuple() + + @classmethod + def deserialize(cls, value: tuple): + return cls() + + +class DoScheduleEvent(EmptyEventArgs, AbstractEvent): + name = "do_schedule" + + +class DoPrepareEvent(EmptyEventArgs, AbstractEvent): + name = "do_prepare" + + +class DoIdleCheckEvent(EmptyEventArgs, AbstractEvent): + name = "do_idle_check" + + +@attr.s(slots=True, frozen=True) +class DoTerminateSessionEvent(AbstractEvent): + name = "do_terminate_session" + + session_id: SessionId = attr.ib() + reason: str = attr.ib() + + def serialize(self) -> tuple: + return ( + str(self.session_id), + self.reason, + ) + + @classmethod + def deserialize(cls, value: tuple): + return cls( + SessionId(uuid.UUID(value[0])), + value[1], + ) + + +@attr.s(slots=True, frozen=True) +class GenericAgentEventArgs(): + + reason: str = attr.ib(default='') + + def serialize(self) -> tuple: + return (self.reason, ) + + @classmethod + def deserialize(cls, value: tuple): + return cls(value[0]) + + +class AgentStartedEvent(GenericAgentEventArgs, AbstractEvent): + name = "agent_started" + + +class AgentTerminatedEvent(GenericAgentEventArgs, AbstractEvent): + name = "agent_terminated" + + +@attr.s(slots=True, frozen=True) +class AgentErrorEvent(AbstractEvent): + name = "agent_error" + + message: str = attr.ib() + traceback: Optional[str] = attr.ib(default=None) + user: Optional[Any] = attr.ib(default=None) + context_env: Mapping[str, Any] = attr.ib(factory=dict) + severity: LogSeverity = attr.ib(default=LogSeverity.ERROR) + + def serialize(self) -> tuple: + return ( + self.message, + self.traceback, + self.user, + self.context_env, + self.severity.value, + ) + + @classmethod + def deserialize(cls, value: tuple): + return cls( + value[0], + value[1], + value[2], + value[3], + LogSeverity(value[4]), + ) + + +@attr.s(slots=True, frozen=True) +class AgentHeartbeatEvent(AbstractEvent): + name = "agent_heartbeat" + + agent_info: Mapping[str, Any] = attr.ib() + + def serialize(self) -> tuple: + return (self.agent_info, ) + + @classmethod + def deserialize(cls, value: tuple): + return cls(value[0]) + + +@attr.s(slots=True, frozen=True) +class KernelCreationEventArgs(): + kernel_id: KernelId = attr.ib() + creation_id: str = attr.ib() + reason: str = attr.ib(default='') + + def serialize(self) -> tuple: + return ( + str(self.kernel_id), + self.creation_id, + self.reason, + ) + + @classmethod + def deserialize(cls, value: tuple): + return cls( + kernel_id=KernelId(uuid.UUID(value[0])), + creation_id=value[1], + reason=value[2], + ) + + +class KernelEnqueuedEvent(KernelCreationEventArgs, AbstractEvent): + name = "kernel_enqueued" + + +class KernelPreparingEvent(KernelCreationEventArgs, AbstractEvent): + name = "kernel_preparing" + + +class KernelPullingEvent(KernelCreationEventArgs, AbstractEvent): + name = "kernel_pulling" + + +@attr.s(auto_attribs=True, slots=True) +class KernelPullProgressEvent(AbstractEvent): + name = "kernel_pull_progress" + kernel_id: uuid.UUID = attr.ib() + current_progress: float = attr.ib() + total_progress: float = attr.ib() + message: Optional[str] = attr.ib(default=None) + + def serialize(self) -> tuple: + return ( + str(self.kernel_id), + self.current_progress, + self.total_progress, + self.message, + ) + + @classmethod + def deserialize(cls, value: tuple): + return cls( + uuid.UUID(value[0]), + value[1], + value[2], + value[3], + ) + + +class KernelCreatingEvent(KernelCreationEventArgs, AbstractEvent): + name = "kernel_creating" + + +class 
KernelStartedEvent(KernelCreationEventArgs, AbstractEvent): + name = "kernel_started" + + +class KernelCancelledEvent(KernelCreationEventArgs, AbstractEvent): + name = "kernel_cancelled" + + +@attr.s(slots=True, frozen=True) +class KernelTerminationEventArgs(): + kernel_id: KernelId = attr.ib() + reason: str = attr.ib(default='') + exit_code: int = attr.ib(default=-1) + + def serialize(self) -> tuple: + return ( + str(self.kernel_id), + self.reason, + self.exit_code, + ) + + @classmethod + def deserialize(cls, value: tuple): + return cls( + KernelId(uuid.UUID(value[0])), + value[1], + value[2], + ) + + +class KernelTerminatingEvent(KernelTerminationEventArgs, AbstractEvent): + name = "kernel_terminating" + + +class KernelTerminatedEvent(KernelTerminationEventArgs, AbstractEvent): + name = "kernel_terminated" + + +@attr.s(slots=True, frozen=True) +class SessionCreationEventArgs(): + session_id: SessionId = attr.ib() + creation_id: str = attr.ib() + reason: str = attr.ib(default='') + + def serialize(self) -> tuple: + return ( + str(self.session_id), + self.creation_id, + self.reason, + ) + + @classmethod + def deserialize(cls, value: tuple): + return cls( + SessionId(uuid.UUID(value[0])), + value[1], + value[2], + ) + + +class SessionEnqueuedEvent(SessionCreationEventArgs, AbstractEvent): + name = "session_enqueued" + + +class SessionScheduledEvent(SessionCreationEventArgs, AbstractEvent): + name = "session_scheduled" + + +class SessionPreparingEvent(SessionCreationEventArgs, AbstractEvent): + name = "session_preparing" + + +class SessionCancelledEvent(SessionCreationEventArgs, AbstractEvent): + name = "session_cancelled" + + +class SessionStartedEvent(SessionCreationEventArgs, AbstractEvent): + name = "session_started" + + +@attr.s(slots=True, frozen=True) +class SessionTerminationEventArgs(): + session_id: SessionId = attr.ib() + reason: str = attr.ib(default='') + + def serialize(self) -> tuple: + return ( + str(self.session_id), + self.reason, + ) + + @classmethod + def deserialize(cls, value: tuple): + return cls( + SessionId(uuid.UUID(value[0])), + value[1], + ) + + +class SessionTerminatedEvent(SessionTerminationEventArgs, AbstractEvent): + name = "session_terminated" + + +@attr.s(slots=True, frozen=True) +class SessionResultEventArgs(): + session_id: SessionId = attr.ib() + reason: str = attr.ib(default='') + exit_code: int = attr.ib(default=-1) + + def serialize(self) -> tuple: + return ( + str(self.session_id), + self.reason, + self.exit_code, + ) + + @classmethod + def deserialize(cls, value: tuple): + return cls( + SessionId(uuid.UUID(value[0])), + value[1], + value[2], + ) + + +class SessionSuccessEvent(SessionResultEventArgs, AbstractEvent): + name = "session_success" + + +class SessionFailureEvent(SessionResultEventArgs, AbstractEvent): + name = "session_failure" + + +@attr.s(auto_attribs=True, slots=True) +class DoSyncKernelLogsEvent(AbstractEvent): + name = "do_sync_kernel_logs" + + kernel_id: KernelId = attr.ib() + container_id: str = attr.ib() + + def serialize(self) -> tuple: + return ( + str(self.kernel_id), + self.container_id, + ) + + @classmethod + def deserialize(cls, value: tuple): + return cls( + KernelId(uuid.UUID(value[0])), + value[1], + ) + + +@attr.s(auto_attribs=True, slots=True) +class DoSyncKernelStatsEvent(AbstractEvent): + name = "do_sync_kernel_stats" + + kernel_ids: Sequence[KernelId] = attr.ib() + + def serialize(self) -> tuple: + return ( + [*map(str, self.kernel_ids)], + ) + + @classmethod + def deserialize(cls, value: tuple): + return cls( + 
kernel_ids=tuple( + KernelId(uuid.UUID(item)) for item in value[0] + ), + ) + + +@attr.s(auto_attribs=True, slots=True) +class GenericSessionEventArgs(AbstractEvent): + session_id: SessionId = attr.ib() + + def serialize(self) -> tuple: + return ( + str(self.session_id), + ) + + @classmethod + def deserialize(cls, value: tuple): + return cls( + SessionId(uuid.UUID(value[0])), + ) + + +class ExecutionStartedEvent(GenericSessionEventArgs, AbstractEvent): + name = "execution_started" + + +class ExecutionFinishedEvent(GenericSessionEventArgs, AbstractEvent): + name = "execution_finished" + + +class ExecutionTimeoutEvent(GenericSessionEventArgs, AbstractEvent): + name = "execution_timeout" + + +class ExecutionCancelledEvent(GenericSessionEventArgs, AbstractEvent): + name = "execution_cancelled" + + +@attr.s(auto_attribs=True, slots=True) +class BgtaskUpdatedEvent(AbstractEvent): + name = "bgtask_updated" + + task_id: uuid.UUID = attr.ib() + current_progress: float = attr.ib() + total_progress: float = attr.ib() + message: Optional[str] = attr.ib(default=None) + + def serialize(self) -> tuple: + return ( + str(self.task_id), + self.current_progress, + self.total_progress, + self.message, + ) + + @classmethod + def deserialize(cls, value: tuple): + return cls( + uuid.UUID(value[0]), + value[1], + value[2], + value[3], + ) + + +@attr.s(auto_attribs=True, slots=True) +class BgtaskDoneEventArgs(): + task_id: uuid.UUID = attr.ib() + message: Optional[str] = attr.ib(default=None) + + def serialize(self) -> tuple: + return ( + str(self.task_id), + self.message, + ) + + @classmethod + def deserialize(cls, value: tuple): + return cls( + uuid.UUID(value[0]), + value[1], + ) + + +class BgtaskDoneEvent(BgtaskDoneEventArgs, AbstractEvent): + name = "bgtask_done" + + +class BgtaskCancelledEvent(BgtaskDoneEventArgs, AbstractEvent): + name = "bgtask_cancelled" + + +class BgtaskFailedEvent(BgtaskDoneEventArgs, AbstractEvent): + name = "bgtask_failed" + + +class RedisConnectorFunc(Protocol): + def __call__( + self, + ) -> aioredis.ConnectionPool: + ... 
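+
+
+# A usage sketch of the dispatcher/producer classes defined below.  Here
+# ``dispatcher`` and ``producer`` are instances created by the application,
+# and ``app_ctx`` and the handler callables are application-provided placeholders:
+#
+#   dispatcher.consume(SessionTerminatedEvent, app_ctx, update_database)
+#   dispatcher.subscribe(SessionTerminatedEvent, app_ctx, feed_event_stream)
+#   await producer.produce_event(DoScheduleEvent())
+#
+# Exactly one consumer among the manager worker processes handles each produced
+# event, while every subscriber in every worker process receives it.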
+ + +TEvent = TypeVar('TEvent', bound='AbstractEvent') +TEventCov = TypeVar('TEventCov', bound='AbstractEvent') +TContext = TypeVar('TContext') + +EventCallback = Union[ + Callable[[TContext, AgentId, TEvent], Coroutine[Any, Any, None]], + Callable[[TContext, AgentId, TEvent], None], +] + + +@attr.s(auto_attribs=True, slots=True, frozen=True, eq=False, order=False) +class EventHandler(Generic[TContext, TEvent]): + event_cls: Type[TEvent] + name: str + context: TContext + callback: EventCallback[TContext, TEvent] + coalescing_opts: Optional[CoalescingOptions] + coalescing_state: CoalescingState + + +class CoalescingOptions(TypedDict): + max_wait: float + max_batch_size: int + + +@attr.s(auto_attribs=True, slots=True) +class CoalescingState: + batch_size: int = 0 + last_added: float = 0.0 + last_handle: asyncio.TimerHandle | None = None + fut_sync: asyncio.Future | None = None + + def proceed(self): + if self.fut_sync is not None and not self.fut_sync.done(): + self.fut_sync.set_result(None) + + async def rate_control(self, opts: CoalescingOptions | None) -> bool: + if opts is None: + return True + loop = asyncio.get_running_loop() + if self.fut_sync is None: + self.fut_sync = loop.create_future() + assert self.fut_sync is not None + self.last_added = loop.time() + self.batch_size += 1 + if self.batch_size >= opts['max_batch_size']: + assert self.last_handle is not None + self.last_handle.cancel() + self.fut_sync.cancel() + self.last_handle = None + self.last_added = 0.0 + self.batch_size = 0 + return True + # Schedule. + self.last_handle = loop.call_later( + opts['max_wait'], + self.proceed, + ) + if self.last_added > 0 and loop.time() - self.last_added < opts['max_wait']: + # Cancel the previously pending task. + self.last_handle.cancel() + self.fut_sync.cancel() + # Reschedule. + self.fut_sync = loop.create_future() + self.last_handle = loop.call_later( + opts['max_wait'], + self.proceed, + ) + try: + await self.fut_sync + except asyncio.CancelledError: + if self.last_handle is not None and not self.last_handle.cancelled(): + self.last_handle.cancel() + return False + else: + self.fut_sync = None + self.last_handle = None + self.last_added = 0.0 + self.batch_size = 0 + return True + + +class EventDispatcher(aobject): + """ + We have two types of event handlers: consumer and subscriber. + + Consumers use the distribution pattern. Only one consumer among many manager worker processes + receives the event. + + Consumer example: database updates upon specific events. + + Subscribers use the broadcast pattern. All subscribers in many manager worker processes + receive the same event. 
+ + Subscriber example: enqueuing events to the queues for event streaming API handlers + """ + + consumers: defaultdict[str, set[EventHandler[Any, AbstractEvent]]] + subscribers: defaultdict[str, set[EventHandler[Any, AbstractEvent]]] + redis_client: RedisConnectionInfo + consumer_loop_task: asyncio.Task + subscriber_loop_task: asyncio.Task + consumer_taskgroup: PersistentTaskGroup + subscriber_taskgroup: PersistentTaskGroup + + _log_events: bool + _consumer_name: str + + def __init__( + self, + redis_config: EtcdRedisConfig, + db: int = 0, + log_events: bool = False, + *, + service_name: str = None, + stream_key: str = 'events', + consumer_group: str = "manager", + node_id: str = None, + consumer_exception_handler: PTGExceptionHandler = None, + subscriber_exception_handler: PTGExceptionHandler = None, + ) -> None: + _redis_config = redis_config.copy() + if service_name: + _redis_config['service_name'] = service_name + self.redis_client = redis.get_redis_object(_redis_config, db=db) + self._log_events = log_events + self._closed = False + self.consumers = defaultdict(set) + self.subscribers = defaultdict(set) + self._stream_key = stream_key + self._consumer_group = consumer_group + self._consumer_name = _generate_consumer_id(node_id) + self.consumer_taskgroup = PersistentTaskGroup( + name="consumer_taskgroup", + exception_handler=consumer_exception_handler, + ) + self.subscriber_taskgroup = PersistentTaskGroup( + name="subscriber_taskgroup", + exception_handler=subscriber_exception_handler, + ) + + async def __ainit__(self) -> None: + self.consumer_loop_task = asyncio.create_task(self._consume_loop()) + self.subscriber_loop_task = asyncio.create_task(self._subscribe_loop()) + + async def close(self) -> None: + self._closed = True + try: + cancelled_tasks = [] + await self.consumer_taskgroup.shutdown() + await self.subscriber_taskgroup.shutdown() + if not self.consumer_loop_task.done(): + self.consumer_loop_task.cancel() + cancelled_tasks.append(self.consumer_loop_task) + if not self.subscriber_loop_task.done(): + self.subscriber_loop_task.cancel() + cancelled_tasks.append(self.subscriber_loop_task) + await asyncio.gather(*cancelled_tasks, return_exceptions=True) + except Exception: + log.exception("unexpected error while closing event dispatcher") + finally: + await self.redis_client.close() + + def consume( + self, + event_cls: Type[TEvent], + context: TContext, + callback: EventCallback[TContext, TEvent], + coalescing_opts: CoalescingOptions = None, + *, + name: str = None, + ) -> EventHandler[TContext, TEvent]: + if name is None: + name = f"evh-{secrets.token_urlsafe(16)}" + handler = EventHandler(event_cls, name, context, callback, coalescing_opts, CoalescingState()) + self.consumers[event_cls.name].add(cast(EventHandler[Any, AbstractEvent], handler)) + return handler + + def unconsume( + self, + handler: EventHandler[TContext, TEvent], + ) -> None: + self.consumers[handler.event_cls.name].discard(cast(EventHandler[Any, AbstractEvent], handler)) + + def subscribe( + self, + event_cls: Type[TEvent], + context: TContext, + callback: EventCallback[TContext, TEvent], + coalescing_opts: CoalescingOptions = None, + *, + name: str = None, + ) -> EventHandler[TContext, TEvent]: + if name is None: + name = f"evh-{secrets.token_urlsafe(16)}" + handler = EventHandler(event_cls, name, context, callback, coalescing_opts, CoalescingState()) + self.subscribers[event_cls.name].add(cast(EventHandler[Any, AbstractEvent], handler)) + return handler + + def unsubscribe( + self, + handler: 
EventHandler[TContext, TEvent], + ) -> None: + self.subscribers[handler.event_cls.name].discard(cast(EventHandler[Any, AbstractEvent], handler)) + + async def handle(self, evh_type: str, evh: EventHandler, source: AgentId, args: tuple) -> None: + coalescing_opts = evh.coalescing_opts + coalescing_state = evh.coalescing_state + cb = evh.callback + event_cls = evh.event_cls + if self._closed: + return + if (await coalescing_state.rate_control(coalescing_opts)): + if self._closed: + return + if self._log_events: + log.debug("DISPATCH_{}(evh:{})", evh_type, evh.name) + if asyncio.iscoroutinefunction(cb): + # mypy cannot catch the meaning of asyncio.iscoroutinefunction(). + await cb(evh.context, source, event_cls.deserialize(args)) # type: ignore + else: + cb(evh.context, source, event_cls.deserialize(args)) # type: ignore + + async def dispatch_consumers( + self, + event_name: str, + source: AgentId, + args: tuple, + ) -> None: + if self._log_events: + log.debug('DISPATCH_CONSUMERS(ev:{}, ag:{})', event_name, source) + for consumer in self.consumers[event_name].copy(): + self.consumer_taskgroup.create_task( + self.handle("CONSUMER", consumer, source, args), + ) + await asyncio.sleep(0) + + async def dispatch_subscribers( + self, + event_name: str, + source: AgentId, + args: tuple, + ) -> None: + if self._log_events: + log.debug('DISPATCH_SUBSCRIBERS(ev:{}, ag:{})', event_name, source) + for subscriber in self.subscribers[event_name].copy(): + self.subscriber_taskgroup.create_task( + self.handle("SUBSCRIBER", subscriber, source, args), + ) + await asyncio.sleep(0) + + async def _consume_loop(self) -> None: + async with aclosing(redis.read_stream_by_group( + self.redis_client, + self._stream_key, + self._consumer_group, + self._consumer_name, + )) as agen: + async for msg_id, msg_data in agen: + if self._closed: + return + if msg_data is None: + continue + try: + await self.dispatch_consumers( + msg_data[b'name'].decode(), + msg_data[b'source'].decode(), + msgpack.unpackb(msg_data[b'args']), + ) + except asyncio.CancelledError: + raise + except Exception: + log.exception('EventDispatcher.consume(): unexpected-error') + + async def _subscribe_loop(self) -> None: + async with aclosing(redis.read_stream( + self.redis_client, + self._stream_key, + )) as agen: + async for msg_id, msg_data in agen: + if self._closed: + return + if msg_data is None: + continue + try: + await self.dispatch_subscribers( + msg_data[b'name'].decode(), + msg_data[b'source'].decode(), + msgpack.unpackb(msg_data[b'args']), + ) + except asyncio.CancelledError: + raise + except Exception: + log.exception('EventDispatcher.subscribe(): unexpected-error') + + +class EventProducer(aobject): + redis_client: RedisConnectionInfo + _log_events: bool + + def __init__( + self, + redis_config: EtcdRedisConfig, + db: int = 0, + *, + service_name: str = None, + stream_key: str = 'events', + log_events: bool = False, + ) -> None: + _redis_config = redis_config.copy() + if service_name: + _redis_config['service_name'] = service_name + self._closed = False + self.redis_client = redis.get_redis_object(_redis_config, db=db) + self._log_events = log_events + self._stream_key = stream_key + + async def __ainit__(self) -> None: + pass + + async def close(self) -> None: + self._closed = True + await self.redis_client.close() + + async def produce_event( + self, + event: AbstractEvent, + *, + source: str = 'manager', + ) -> None: + if self._closed: + return + raw_event = { + b'name': event.name.encode(), + b'source': source.encode(), + b'args': 
msgpack.packb(event.serialize()), + } + await redis.execute( + self.redis_client, + lambda r: r.xadd(self._stream_key, raw_event), # type: ignore # aio-libs/aioredis-py#1182 + ) + + +def _generate_consumer_id(node_id: str = None) -> str: + h = hashlib.sha1() + h.update(str(node_id or socket.getfqdn()).encode('utf8')) + hostname_hash = h.hexdigest() + h = hashlib.sha1() + h.update(__file__.encode('utf8')) + installation_path_hash = h.hexdigest() + pidx = process_index.get(0) + return f"{hostname_hash}:{installation_path_hash}:{pidx}" diff --git a/src/ai/backend/common/exception.py b/src/ai/backend/common/exception.py new file mode 100644 index 0000000000..ee56839b6a --- /dev/null +++ b/src/ai/backend/common/exception.py @@ -0,0 +1,50 @@ +from typing import Any, Mapping + + +class ConfigurationError(Exception): + + invalid_data: Mapping[str, Any] + + def __init__(self, invalid_data: Mapping[str, Any]) -> None: + super().__init__(invalid_data) + self.invalid_data = invalid_data + + +class UnknownImageReference(ValueError): + ''' + Represents an error for invalid/unknown image reference. + The first argument of this exception should be the reference given by the user. + ''' + + def __str__(self) -> str: + return f'Unknown image reference: {self.args[0]}' + + +class ImageNotAvailable(ValueError): + ''' + Represents an error for unavailability of the image in agents. + The first argument of this exception should be the reference given by the user. + ''' + + def __str__(self) -> str: + return f'Unavailable image in the agent: {self.args[0]}' + + +class UnknownImageRegistry(ValueError): + ''' + Represents an error for invalid/unknown image registry. + The first argument of this exception should be the registry given by the user. + ''' + + def __str__(self) -> str: + return f'Unknown image registry: {self.args[0]}' + + +class AliasResolutionFailed(ValueError): + ''' + Represents an alias resolution failure. + The first argument of this exception should be the alias given by the user. + ''' + + def __str__(self) -> str: + return f'Failed to resolve alias: {self.args[0]}' diff --git a/src/ai/backend/common/files.py b/src/ai/backend/common/files.py new file mode 100644 index 0000000000..fb129b01d7 --- /dev/null +++ b/src/ai/backend/common/files.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from pathlib import Path +from typing import ( + Callable, +) + +import janus + +from .asyncio import current_loop +from .types import Sentinel + +__all__ = ( + 'AsyncFileWriter', +) + + +class AsyncFileWriter: + """ + This class provides a context manager for making sequential async + writes using janus queue. 
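+
+    A minimal usage sketch (the target path is a made-up example):
+
+    .. code-block::
+
+       async with AsyncFileWriter('/tmp/example.log', 'a') as writer:
+           await writer.write('a line of text')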
+ """ + + def __init__( + self, + target_filename: str | Path, + access_mode: str, + encode: Callable[[str], bytes] = None, + max_chunks: int = None, + ) -> None: + if max_chunks is None: + max_chunks = 0 + self._q: janus.Queue[str | bytes | Sentinel] = janus.Queue(maxsize=max_chunks) + self._target_filename = target_filename + self._access_mode = access_mode + self._binary_mode = 'b' in access_mode + if encode is not None: + self._encode = encode + else: + self._encode = lambda v: v.encode() # default encoder + + async def __aenter__(self): + loop = current_loop() + self._fut = loop.run_in_executor(None, self._write) + return self + + def _write(self) -> None: + with open(self._target_filename, self._access_mode) as f: + while True: + item = self._q.sync_q.get() + if item is Sentinel.TOKEN: + break + if self._binary_mode: + encoded = self._encode(item) if isinstance(item, str) else item + f.write(encoded) + else: + f.write(item) + self._q.sync_q.task_done() + + async def __aexit__(self, exc_type, exc, tb): + await self._q.async_q.put(Sentinel.TOKEN) + try: + await self._fut + finally: + self._q.close() + await self._q.wait_closed() + + async def write(self, item) -> None: + await self._q.async_q.put(item) diff --git a/src/ai/backend/common/identity.py b/src/ai/backend/common/identity.py new file mode 100644 index 0000000000..8c05e0fd68 --- /dev/null +++ b/src/ai/backend/common/identity.py @@ -0,0 +1,245 @@ +from ipaddress import ( + ip_address, + _BaseNetwork as BaseIPNetwork, _BaseAddress as BaseIPAddress, +) +import json +import logging +import os +import socket +import sys +from typing import ( + Awaitable, Callable, Iterable, Optional, +) +from pathlib import Path + +import aiodns +import netifaces + +from .utils import curl + +__all__ = ( + 'detect_cloud', + 'current_provider', + 'get_instance_id', + 'get_instance_ip', + 'get_instance_type', + 'get_instance_region', +) + +log = logging.getLogger(__name__) + + +def is_containerized() -> bool: + ''' + Check if I am running inside a Linux container. + ''' + try: + cginfo = Path('/proc/self/cgroup').read_text() + if '/docker/' in cginfo or '/lxc/' in cginfo: + return True + return False + except IOError: + return False + + +def detect_cloud() -> Optional[str]: + ''' + Detect the cloud provider where I am running on. + ''' + # NOTE: Contributions are welcome! + # Please add other cloud providers such as Rackspace, IBM BlueMix, etc. + if sys.platform.startswith('linux'): + # Google Cloud Platform or Amazon AWS (hvm) + try: + # AWS Nitro-based instances + mb = Path('/sys/devices/virtual/dmi/id/board_vendor').read_text().lower() + if 'amazon' in mb: + return 'amazon' + except IOError: + pass + try: + bios = Path('/sys/devices/virtual/dmi/id/bios_version').read_text().lower() + if 'google' in bios: + return 'google' + if 'amazon' in bios: + return 'amazon' + except IOError: + pass + # Microsoft Azure + # https://gallery.technet.microsoft.com/scriptcenter/Detect-Windows-Azure-aed06d51 + # TODO: this only works with Debian/Ubuntu instances. + # TODO: this does not work inside containers. + try: + dhcp = Path('/var/lib/dhcp/dhclient.eth0.leases').read_text() + if 'unknown-245' in dhcp: + return 'azure' + # alternative method is to read /var/lib/waagent/GoalState.1.xml + # but it requires sudo privilege. 
+ except IOError: + pass + return None + + +def fetch_local_ipaddrs(cidr: BaseIPNetwork) -> Iterable[BaseIPAddress]: + ifnames = netifaces.interfaces() + proto = netifaces.AF_INET if cidr.version == 4 else netifaces.AF_INET6 + for ifname in ifnames: + addrs = netifaces.ifaddresses(ifname).get(proto, None) + if addrs is None: + continue + for entry in addrs: + addr = ip_address(entry['addr']) + if addr in cidr: + yield addr + + +# Detect upon module load. +current_provider = detect_cloud() +if current_provider is None: + log.info('Detected environment: on-premise setup') + log.info('The agent node ID is set using the hostname.') +else: + log.info(f'Detected environment: {current_provider} cloud') + log.info('The agent node ID will follow the instance ID.') + +_defined: bool = False +get_instance_id: Callable[[], Awaitable[str]] +get_instance_ip: Callable[[Optional[BaseIPNetwork]], Awaitable[str]] +get_instance_type: Callable[[], Awaitable[str]] +get_instance_region: Callable[[], Awaitable[str]] + + +def _define_functions(): + global _defined + global get_instance_id + global get_instance_ip + global get_instance_type + global get_instance_region + if _defined: + return + + if current_provider == 'amazon': + # ref: http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-instance-metadata.html + _metadata_prefix = 'http://169.254.169.254/latest/meta-data/' + _dynamic_prefix = 'http://169.254.169.254/latest/dynamic/' + + async def _get_instance_id() -> str: + return await curl(_metadata_prefix + 'instance-id', + lambda: f'i-{socket.gethostname()}') + + async def _get_instance_ip(subnet_hint: BaseIPNetwork = None) -> str: + return await curl(_metadata_prefix + 'local-ipv4', + '127.0.0.1') + + async def _get_instance_type() -> str: + return await curl(_metadata_prefix + 'instance-type', + 'unknown') + + async def _get_instance_region() -> str: + doc = await curl(_dynamic_prefix + 'instance-identity/document', None) + if doc is None: + return 'amazon/unknown' + region = json.loads(doc)['region'] + return f'amazon/{region}' + + elif current_provider == 'azure': + # ref: https://docs.microsoft.com/azure/virtual-machines/virtual-machines-instancemetadataservice-overview + _metadata_prefix = 'http://169.254.169.254/metadata/instance' + + async def _get_instance_id() -> str: + data = await curl(_metadata_prefix, None, + params={'version': '2017-03-01'}, + headers={'Metadata': 'true'}) + if data is None: + return f'i-{socket.gethostname()}' + o = json.loads(data) + return o['compute']['vmId'] + + async def _get_instance_ip(subnet_hint: BaseIPNetwork = None) -> str: + data = await curl(_metadata_prefix, None, + params={'version': '2017-03-01'}, + headers={'Metadata': 'true'}) + if data is None: + return '127.0.0.1' + o = json.loads(data) + return o['network']['interface'][0]['ipv4']['ipaddress'][0]['ipaddress'] + + async def _get_instance_type() -> str: + data = await curl(_metadata_prefix, None, + params={'version': '2017-03-01'}, + headers={'Metadata': 'true'}) + if data is None: + return 'unknown' + o = json.loads(data) + return o['compute']['vmSize'] + + async def _get_instance_region() -> str: + data = await curl(_metadata_prefix, None, + params={'version': '2017-03-01'}, + headers={'Metadata': 'true'}) + if data is None: + return 'azure/unknown' + o = json.loads(data) + region = o['compute']['location'] + return f'azure/{region}' + + elif current_provider == 'google': + # ref: https://cloud.google.com/compute/docs/storing-retrieving-metadata + _metadata_prefix = 
'http://metadata.google.internal/computeMetadata/v1/' + + async def _get_instance_id() -> str: + return await curl(_metadata_prefix + 'instance/id', + lambda: f'i-{socket.gethostname()}', + headers={'Metadata-Flavor': 'Google'}) + + async def _get_instance_ip(subnet_hint: BaseIPNetwork = None) -> str: + return await curl(_metadata_prefix + 'instance/network-interfaces/0/ip', + '127.0.0.1', + headers={'Metadata-Flavor': 'Google'}) + + async def _get_instance_type() -> str: + return await curl(_metadata_prefix + 'instance/machine-type', + 'unknown', + headers={'Metadata-Flavor': 'Google'}) + + async def _get_instance_region() -> str: + zone = await curl(_metadata_prefix + 'instance/zone', + 'unknown', + headers={'Metadata-Flavor': 'Google'}) + region = zone.rsplit('-', 1)[0] + return f'google/{region}' + + else: + _metadata_prefix = None + + async def _get_instance_id() -> str: + return f'i-{socket.gethostname()}' + + async def _get_instance_ip(subnet_hint: BaseIPNetwork = None) -> str: + if subnet_hint is not None and subnet_hint.prefixlen > 0: + local_ipaddrs = [*fetch_local_ipaddrs(subnet_hint)] + if local_ipaddrs: + return str(local_ipaddrs[0]) + raise RuntimeError('Could not find my IP address bound to subnet {}', subnet_hint) + try: + myself = socket.gethostname() + resolver = aiodns.DNSResolver() + result = await resolver.gethostbyname(myself, socket.AF_INET) + return result.addresses[0] + except aiodns.error.DNSError: + return '127.0.0.1' + + async def _get_instance_type() -> str: + return 'default' + + async def _get_instance_region() -> str: + return os.environ.get('BACKEND_REGION', 'local') + + get_instance_id = _get_instance_id + get_instance_ip = _get_instance_ip + get_instance_type = _get_instance_type + get_instance_region = _get_instance_region + _defined = True + + +_define_functions() diff --git a/src/ai/backend/common/json.py b/src/ai/backend/common/json.py new file mode 100644 index 0000000000..9ac84b16dc --- /dev/null +++ b/src/ai/backend/common/json.py @@ -0,0 +1,12 @@ +import datetime +import json +import uuid + + +class ExtendedJSONEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, uuid.UUID): + return str(o) + elif isinstance(o, datetime.datetime): + return o.isoformat() + return super().default(o) diff --git a/src/ai/backend/common/lock.py b/src/ai/backend/common/lock.py new file mode 100644 index 0000000000..96202cf835 --- /dev/null +++ b/src/ai/backend/common/lock.py @@ -0,0 +1,157 @@ +from __future__ import annotations + +import abc +import asyncio +import fcntl +import logging +from io import IOBase +from pathlib import Path +from typing import Any, Optional + +from tenacity import ( + AsyncRetrying, + RetryError, + retry_if_exception_type, + stop_after_delay, + stop_never, + wait_exponential, + wait_random, +) + +from etcetra.client import EtcdConnectionManager, EtcdCommunicator + +from ai.backend.common.etcd import AsyncEtcd + +from .logging import BraceStyleAdapter + +log = BraceStyleAdapter(logging.getLogger(__name__)) + + +class AbstractDistributedLock(metaclass=abc.ABCMeta): + + def __init__(self, *, lifetime: Optional[float] = None) -> None: + self._lifetime = lifetime + + @abc.abstractmethod + async def __aenter__(self) -> Any: + raise NotImplementedError + + @abc.abstractmethod + async def __aexit__(self, *exc_info) -> Optional[bool]: + raise NotImplementedError + + +class FileLock(AbstractDistributedLock): + + default_timeout: float = 3 # not allow infinite timeout for safety + + _fp: IOBase | None + _locked: bool = False + + def 
__init__( + self, + path: Path, + *, + timeout: Optional[float] = None, + lifetime: Optional[float] = None, + debug: bool = False, + ) -> None: + super().__init__(lifetime=lifetime) + self._fp = None + self._path = path + self._timeout = timeout if timeout is not None else self.default_timeout + self._debug = debug + + @property + def locked(self) -> bool: + return self._locked + + def __del__(self) -> None: + if self._fp is not None: + self._debug = False + self.release() + log.debug("file lock implicitly released: {}", self._path) + + async def acquire(self) -> None: + assert self._fp is None + assert not self._locked + self._path.touch(exist_ok=True) + self._fp = open(self._path, "rb") + stop_func = stop_never if self._timeout <= 0 else stop_after_delay(self._timeout) + try: + async for attempt in AsyncRetrying( + retry=retry_if_exception_type(BlockingIOError), + wait=wait_exponential(multiplier=0.02, min=0.02, max=1.0) + wait_random(0, 0.05), + stop=stop_func, + ): + with attempt: + fcntl.flock(self._fp, fcntl.LOCK_EX | fcntl.LOCK_NB) + self._locked = True + if self._debug: + log.debug("file lock acquired: {}", self._path) + except RetryError: + raise asyncio.TimeoutError(f"failed to lock file: {self._path}") + + def release(self) -> None: + assert self._fp is not None + if self._locked: + fcntl.flock(self._fp, fcntl.LOCK_UN) + self._locked = False + if self._debug: + log.debug("file lock explicitly released: {}", self._path) + self._fp.close() + self._fp = None + + async def __aenter__(self) -> None: + await self.acquire() + + async def __aexit__(self, *exc_info) -> bool | None: + self.release() + return None + + +class EtcdLock(AbstractDistributedLock): + + _con_mgr: Optional[EtcdConnectionManager] + _debug: bool + + lock_name: str + etcd: AsyncEtcd + timeout: float + + default_timeout: float = 9600 # not allow infinite timeout for safety + + def __init__( + self, + lock_name: str, + etcd: AsyncEtcd, + *, + timeout: Optional[float] = None, + lifetime: Optional[float] = None, + debug: bool = False, + ) -> None: + super().__init__(lifetime=lifetime) + self.lock_name = lock_name + self.etcd = etcd + self._timeout = timeout if timeout is not None else self.default_timeout + self._debug = debug + + async def __aenter__(self) -> EtcdCommunicator: + self._con_mgr = self.etcd.etcd.with_lock( + self.lock_name, + timeout=self._timeout, + ttl=int(self._lifetime) if self._lifetime is not None else None, + ) + assert self._con_mgr is not None # FIXME: not required if with_lock() has an explicit return type. 
+ communicator = await self._con_mgr.__aenter__() + if self._debug: + log.debug('etcd lock acquired') + return communicator + + async def __aexit__(self, *exc_info) -> Optional[bool]: + assert self._con_mgr is not None + await self._con_mgr.__aexit__(*exc_info) + if self._debug: + log.debug('etcd lock released') + self._con_mgr = None + return None diff --git a/src/ai/backend/common/logging.py b/src/ai/backend/common/logging.py new file mode 100644 index 0000000000..06041cc777 --- /dev/null +++ b/src/ai/backend/common/logging.py @@ -0,0 +1,509 @@ +from abc import ABCMeta, abstractmethod +from collections import OrderedDict +from contextvars import ContextVar +from datetime import datetime +import json +import logging, logging.config, logging.handlers +import threading +import os +from pathlib import Path +import pickle +import pprint +import time +from typing import ( + Any, Optional, + Mapping, MutableMapping, +) +import socket +import ssl +import sys + +import coloredlogs +from pythonjsonlogger.jsonlogger import JsonFormatter +import trafaret as t +from tblib import pickling_support +import yarl +import zmq + +from . import config +from . import validators as tx +from .logging_utils import BraceStyleAdapter +from .exception import ConfigurationError + +# public APIs of this module +__all__ = ( + 'AbstractLogger', + 'Logger', + 'NoopLogger', + 'BraceStyleAdapter', + 'LogstashHandler', + 'is_active', + 'pretty', +) + +is_active: ContextVar[bool] = ContextVar('is_active', default=False) + +loglevel_iv = t.Enum('DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL', 'NOTSET') +logformat_iv = t.Enum('simple', 'verbose') +default_pkg_ns = { + '': 'WARNING', + 'ai.backend': 'INFO', +} + +logging_config_iv = t.Dict({ + t.Key('level', default='INFO'): loglevel_iv, + t.Key('pkg-ns', default=default_pkg_ns): t.Mapping(t.String(allow_blank=True), loglevel_iv), + t.Key('drivers', default=['console']): t.List(t.Enum('console', 'logstash', 'file')), + t.Key('console', default=None): t.Null | t.Dict({ + t.Key('colored', default=True): t.Bool, + t.Key('format', default='verbose'): logformat_iv, + }).allow_extra('*'), + t.Key('file', default=None): t.Null | t.Dict({ + t.Key('path'): tx.Path(type='dir', auto_create=True), + t.Key('filename'): t.String, + t.Key('backup-count', default=5): t.Int[1:100], + t.Key('rotation-size', default='10M'): tx.BinarySize, + t.Key('format', default='verbose'): logformat_iv, + }).allow_extra('*'), + t.Key('logstash', default=None): t.Null | t.Dict({ + t.Key('endpoint'): tx.HostPortPair, + t.Key('protocol', default='tcp'): t.Enum('zmq.push', 'zmq.pub', 'tcp', 'udp'), + t.Key('ssl-enabled', default=True): t.Bool, + t.Key('ssl-verify', default=True): t.Bool, + # NOTE: logstash does not have format optoin. + }).allow_extra('*'), +}).allow_extra('*') + + +class PickledException(Exception): + """ + Serves as a wrapper for exceptions that contain unpicklable arguments. 
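    The original exception is kept only as its stringified ``repr()``
    (see ``RelayHandler.emit()``), while the traceback object itself is
    preserved so it can still be pickled across the process boundary.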
+ """ + pass + + +class LogstashHandler(logging.Handler): + + def __init__(self, endpoint, protocol: str, *, + ssl_enabled: bool = True, + ssl_verify: bool = True, + myhost: str = None): + super().__init__() + self._endpoint = endpoint + self._protocol = protocol + self._ssl_enabled = ssl_enabled + self._ssl_verify = ssl_verify + self._myhost = myhost + self._sock = None + self._sslctx = None + self._zmqctx = None + + def _setup_transport(self): + if self._sock is not None: + return + if self._protocol == 'zmq.push': + self._zmqctx = zmq.Context() + sock = self._zmqctx.socket(zmq.PUSH) + sock.setsockopt(zmq.LINGER, 50) + sock.setsockopt(zmq.SNDHWM, 20) + sock.connect(f'tcp://{self._endpoint[0]}:{self._endpoint[1]}') + self._sock = sock + elif self._protocol == 'zmq.pub': + self._zmqctx = zmq.Context() + sock = self._zmqctx.socket(zmq.PUB) + sock.setsockopt(zmq.LINGER, 50) + sock.setsockopt(zmq.SNDHWM, 20) + sock.connect(f'tcp://{self._endpoint[0]}:{self._endpoint[1]}') + self._sock = sock + elif self._protocol == 'tcp': + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + if self._ssl_enabled: + self._sslctx = ssl.create_default_context() + if not self._ssl_verify: + self._sslctx.check_hostname = False + self._sslctx.verify_mode = ssl.CERT_NONE + sock = self._sslctx.wrap_socket(sock, server_hostname=self._endpoint[0]) + sock.connect((str(self._endpoint.host), self._endpoint.port)) + self._sock = sock + elif self._protocol == 'udp': + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.connect((str(self._endpoint.host), self._endpoint.port)) + self._sock = sock + else: + raise ConfigurationError({'logging.LogstashHandler': f'unsupported protocol: {self._protocol}'}) + + def cleanup(self): + if self._sock: + self._sock.close() + self._sslctx = None + if self._zmqctx: + self._zmqctx.term() + + def emit(self, record): + self._setup_transport() + tags = set() + extra_data = dict() + + if record.exc_info: + tags.add('has_exception') + if self.formatter: + extra_data['exception'] = self.formatter.formatException(record.exc_info) + else: + extra_data['exception'] = logging._defaultFormatter.formatException(record.exc_info) + + # This log format follows logstash's event format. 
+ log = OrderedDict([ + ('@timestamp', datetime.now().isoformat()), + ('@version', 1), + ('host', self._myhost), + ('logger', record.name), + ('path', record.pathname), + ('func', record.funcName), + ('lineno', record.lineno), + ('message', record.getMessage()), + ('level', record.levelname), + ('tags', list(tags)), + ]) + log.update(extra_data) + if self._protocol.startswith('zmq'): + self._sock.send_json(log) + else: + # TODO: reconnect if disconnected + self._sock.sendall(json.dumps(log).encode('utf-8')) + + +class ConsoleFormatter(logging.Formatter): + + def formatTime(self, record: logging.LogRecord, datefmt: str = None) -> str: + ct = self.converter(record.created) # type: ignore + if datefmt: + datefmt = datefmt.replace("%f", f"{int(record.msecs):03d}") + return time.strftime(datefmt, ct) + else: + t = time.strftime("%Y-%m-%d %H:%M:%S", ct) + return f"{t}.{int(record.msecs):03d}" + + +class CustomJsonFormatter(JsonFormatter): + + def add_fields(self, log_record, record, message_dict): + super().add_fields(log_record, record, message_dict) + if not log_record.get('timestamp'): + # this doesn't use record.created, so it is slightly off + now = datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%S.%fZ') + log_record['timestamp'] = now + if log_record.get('level', record.levelname): + log_record['level'] = log_record['level'].upper() + else: + log_record['level'] = record.levelname + + +class pretty: + """A simple object wrapper to pretty-format it when formatting the log record.""" + + def __init__(self, obj: Any) -> None: + self.obj = obj + + def __repr__(self) -> str: + return pprint.pformat(self.obj) + + +def log_worker( + daemon_config: Mapping[str, Any], + parent_pid: int, + log_endpoint: str, + ready_event: threading.Event, +) -> None: + console_handler = None + file_handler = None + logstash_handler = None + + log_formats = { + 'simple': '%(levelname)s %(message)s', + 'verbose': '%(asctime)s %(levelname)s %(name)s [%(process)d] %(message)s', + } + + if 'console' in daemon_config['drivers']: + drv_config = daemon_config['console'] + console_formatter: logging.Formatter + if drv_config['colored']: + console_formatter = coloredlogs.ColoredFormatter( + log_formats[drv_config['format']], + datefmt="%Y-%m-%d %H:%M:%S.%f", # coloredlogs has intrinsic support for msec + field_styles={'levelname': {'color': 248, 'bold': True}, + 'name': {'color': 246, 'bold': False}, + 'process': {'color': 'cyan'}, + 'asctime': {'color': 240}}, + level_styles={'debug': {'color': 'green'}, + 'verbose': {'color': 'green', 'bright': True}, + 'info': {'color': 'cyan', 'bright': True}, + 'notice': {'color': 'cyan', 'bold': True}, + 'warning': {'color': 'yellow'}, + 'error': {'color': 'red', 'bright': True}, + 'success': {'color': 77}, + 'critical': {'background': 'red', 'color': 255, 'bold': True}}, + ) + else: + console_formatter = ConsoleFormatter( + log_formats[drv_config['format']], + datefmt="%Y-%m-%d %H:%M:%S.%f", + ) + console_handler = logging.StreamHandler( + stream=sys.stderr, + ) + console_handler.setLevel(daemon_config['level']) + console_handler.setFormatter(console_formatter) + + if 'file' in daemon_config['drivers']: + drv_config = daemon_config['file'] + fmt = '%(timestamp) %(level) %(name) %(processName) %(message)' + file_handler = logging.handlers.RotatingFileHandler( + filename=drv_config['path'] / drv_config['filename'], + backupCount=drv_config['backup-count'], + maxBytes=drv_config['rotation-size'], + encoding='utf-8', + ) + file_handler.setLevel(daemon_config['level']) + 
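        # CustomJsonFormatter writes one JSON object per line; the field names in
        # `fmt` above determine which keys appear in each object, and rotation
        # follows the `backup-count` / `rotation-size` settings from the config.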
file_handler.setFormatter(CustomJsonFormatter(fmt)) + + if 'logstash' in daemon_config['drivers']: + drv_config = daemon_config['logstash'] + logstash_handler = LogstashHandler( + endpoint=drv_config['endpoint'], + protocol=drv_config['protocol'], + ssl_enabled=drv_config['ssl-enabled'], + ssl_verify=drv_config['ssl-verify'], + myhost='hostname', # TODO: implement + ) + logstash_handler.setLevel(daemon_config['level']) + + zctx = zmq.Context() + agg_sock = zctx.socket(zmq.PULL) + agg_sock.bind(log_endpoint) + ep_url = yarl.URL(log_endpoint) + if ep_url.scheme.lower() == 'ipc': + os.chmod(ep_url.path, 0o777) + try: + ready_event.set() + while True: + data = agg_sock.recv() + if not data: + return + try: + rec = pickle.loads(data) + except (pickle.PickleError, TypeError): + # We have an unpickling error. + # Change into a self-created log record with exception info. + rec = logging.makeLogRecord({ + 'name': __name__, + 'msg': 'Cannot unpickle the log record (raw data: %r)', + 'levelno': logging.ERROR, + 'levelname': 'error', + 'args': (data,), # attach the original data for inspection + 'exc_info': sys.exc_info(), + }) + if rec is None: + break + if console_handler: + console_handler.emit(rec) + try: + if file_handler: + file_handler.emit(rec) + if logstash_handler: + logstash_handler.emit(rec) + except OSError: + # don't terminate the log worker. + continue + finally: + if logstash_handler: + logstash_handler.cleanup() + agg_sock.close() + zctx.term() + + +class RelayHandler(logging.Handler): + + _sock: zmq.Socket | None + + def __init__(self, *, endpoint: str) -> None: + super().__init__() + self.endpoint = endpoint + self._zctx = zmq.Context() + # We should use PUSH-PULL socket pairs to avoid + # lost of synchronization sentinel messages. + if endpoint: + self._sock = self._zctx.socket(zmq.PUSH) + assert self._sock is not None + self._sock.setsockopt(zmq.LINGER, 100) + self._sock.connect(self.endpoint) + else: + self._sock = None + + def close(self) -> None: + if self._sock is not None: + self._sock.close() + self._zctx.term() + + def _fallback(self, record: Optional[logging.LogRecord]) -> None: + if record is None: + return + print(record.getMessage(), file=sys.stderr) + + def emit(self, record: Optional[logging.LogRecord]) -> None: + if self._sock is None: + self._fallback(record) + return + # record may be None to signal shutdown. + try: + if record is not None and record.exc_info is not None: + pickling_support.install(record.exc_info[1]) + pickled_rec = pickle.dumps(record) + except ( + pickle.PickleError, + TypeError, + ImportError, # when "Python is likely to be shutting down" + ): + # We have a pickling error. + # Change it into a self-created picklable log record with exception info. 
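            # Keep the original traceback object (made picklable by tblib's
            # pickling_support) but replace the exception value with a
            # PickledException carrying only its repr().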
+ if record is not None: + exc_info: Any + if isinstance(record.exc_info, tuple): + exc_info = ( + PickledException, + PickledException(repr(record.exc_info[1])), # store stringified repr + record.exc_info[2], + ) + else: + exc_info = record.exc_info + record = logging.makeLogRecord({ + 'name': record.name, + 'pathname': record.pathname, + 'lineno': record.lineno, + 'msg': record.getMessage(), + 'levelno': record.levelno, + 'levelname': record.levelname, + 'exc_info': exc_info, + }) + pickled_rec = pickle.dumps(record) + try: + self._sock.send(pickled_rec) + except zmq.ZMQError: + self._fallback(record) + + +class AbstractLogger(metaclass=ABCMeta): + def __init__( + self, + daemon_config: MutableMapping[str, Any], + ) -> None: + pass + + @abstractmethod + def __enter__(self): + raise NotImplementedError + + @abstractmethod + def __exit__(self, *exc_info_args): + raise NotImplementedError + + +class NoopLogger(AbstractLogger): + def __init__( + self, + daemon_config: MutableMapping[str, Any], + ) -> None: + pass + + def __enter__(self): + pass + + def __exit__(self, *exc_info_args): + pass + + +class Logger(AbstractLogger): + + is_master: bool + log_endpoint: str + daemon_config: Mapping[str, Any] + log_config: MutableMapping[str, Any] + log_worker: threading.Thread + + def __init__( + self, + daemon_config: MutableMapping[str, Any], + *, + is_master: bool, + log_endpoint: str, + ) -> None: + legacy_logfile_path = os.environ.get('BACKEND_LOG_FILE') + if legacy_logfile_path: + p = Path(legacy_logfile_path) + config.override_key(daemon_config, ('file', 'path'), p.parent) + config.override_key(daemon_config, ('file', 'filename'), p.name) + config.override_with_env(daemon_config, ('file', 'backup-count'), 'BACKEND_LOG_FILE_COUNT') + legacy_logfile_size = os.environ.get('BACKEND_LOG_FILE_SIZE') + if legacy_logfile_size: + legacy_logfile_size = f'{legacy_logfile_size}M' + config.override_with_env(daemon_config, ('file', 'rotation-size'), legacy_logfile_size) + + cfg = logging_config_iv.check(daemon_config) + + def _check_driver_config_exists_if_activated(cfg, driver): + if driver in cfg['drivers'] and cfg[driver] is None: + raise ConfigurationError({'logging': f'{driver} driver is activated but no config given.'}) + + _check_driver_config_exists_if_activated(cfg, 'console') + _check_driver_config_exists_if_activated(cfg, 'file') + _check_driver_config_exists_if_activated(cfg, 'logstash') + + self.is_master = is_master + self.log_endpoint = log_endpoint + self.daemon_config = cfg + self.log_config = { + 'version': 1, + 'disable_existing_loggers': False, + 'handlers': { + 'null': {'class': 'logging.NullHandler'}, + }, + 'loggers': { + '': {'handlers': [], 'level': cfg['level']}, + **{k: {'handlers': [], 'level': v, 'propagate': False} for k, v in cfg['pkg-ns'].items()}, + }, + } + + def __enter__(self): + tx.fix_trafaret_pickle_support() # monkey-patch for pickling trafaret.DataError + pickling_support.install() # enable pickling of tracebacks + self.log_config['handlers']['relay'] = { + 'class': 'ai.backend.common.logging.RelayHandler', + 'level': self.daemon_config['level'], + 'endpoint': self.log_endpoint, + } + for _logger in self.log_config['loggers'].values(): + _logger['handlers'].append('relay') + logging.config.dictConfig(self.log_config) + self._is_active_token = is_active.set(True) + if self.is_master and self.log_endpoint: + self.relay_handler = logging.getLogger('').handlers[0] + self.ready_event = threading.Event() + assert isinstance(self.relay_handler, RelayHandler) + 
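            # Only the master process runs the log worker thread, which owns the
            # single PULL end of the log endpoint; every process attaches a
            # RelayHandler that PUSHes pickled records into it.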
self.log_worker = threading.Thread( + target=log_worker, name='Logger', + args=(self.daemon_config, os.getpid(), self.log_endpoint, self.ready_event)) + self.log_worker.start() + self.ready_event.wait() + + def __exit__(self, *exc_info_args): + # Resetting generates "different context" errors. + # Since practically we only need to check activeness in alembic scripts + # and it should be active until the program terminates, + # just leave it as-is. + is_active.reset(self._is_active_token) + if self.is_master and self.log_endpoint: + self.relay_handler.emit(None) + self.log_worker.join() + self.relay_handler.close() + ep_url = yarl.URL(self.log_endpoint) + if ep_url.scheme.lower() == 'ipc' and (ep_sock := Path(ep_url.path)).exists(): + ep_sock.unlink() diff --git a/src/ai/backend/common/logging_utils.py b/src/ai/backend/common/logging_utils.py new file mode 100644 index 0000000000..05bd6b0ce5 --- /dev/null +++ b/src/ai/backend/common/logging_utils.py @@ -0,0 +1,24 @@ +import logging + + +class BraceMessage: + + __slots__ = ('fmt', 'args') + + def __init__(self, fmt, args): + self.fmt = fmt + self.args = args + + def __str__(self): + return self.fmt.format(*self.args) + + +class BraceStyleAdapter(logging.LoggerAdapter): + + def __init__(self, logger, extra=None): + super().__init__(logger, extra) + + def log(self, level, msg, *args, **kwargs): + if self.isEnabledFor(level): + msg, kwargs = self.process(msg, kwargs) + self.logger._log(level, BraceMessage(msg, args), (), **kwargs) diff --git a/src/ai/backend/common/msgpack.py b/src/ai/backend/common/msgpack.py new file mode 100644 index 0000000000..b7db574c91 --- /dev/null +++ b/src/ai/backend/common/msgpack.py @@ -0,0 +1,17 @@ +''' +Wrapper of msgpack-python with good defaults. +''' + +from typing import Any + +import msgpack as _msgpack + + +def packb(data: Any, **kwargs) -> bytes: + opts = {"use_bin_type": True, **kwargs} + return _msgpack.packb(data, **opts) + + +def unpackb(packed: bytes, **kwargs) -> Any: + opts = {"raw": False, "use_list": False, **kwargs} + return _msgpack.unpackb(packed, **opts) diff --git a/src/ai/backend/common/networking.py b/src/ai/backend/common/networking.py new file mode 100644 index 0000000000..c93b47f14a --- /dev/null +++ b/src/ai/backend/common/networking.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +import asyncio +from contextlib import closing +import socket +from typing import ( + Callable, + Mapping, + TYPE_CHECKING, + TypeVar, +) + +import aiohttp +from async_timeout import timeout as _timeout + +if TYPE_CHECKING: + import yarl + +__all__ = ( + 'find_free_port', + 'curl', +) + +T = TypeVar('T') + + +async def curl( + url: str | yarl.URL, + default_value: str | T | Callable[[], str | T], + params: Mapping[str, str] = None, + headers: Mapping[str, str] = None, + timeout: float = 0.2, +) -> str | T: + """ + A simple curl-like helper function that uses aiohttp to fetch some string/data + from a remote HTTP endpoint. + """ + try: + async with aiohttp.ClientSession() as sess: + async with _timeout(timeout): + async with sess.get(url, params=params, headers=headers) as resp: + assert resp.status == 200 + body = await resp.text() + return body.strip() + except (asyncio.TimeoutError, aiohttp.ClientError, AssertionError): + if callable(default_value): + return default_value() + return default_value + + +def find_free_port(bind_addr: str = '127.0.0.1') -> int: + """ + Find a freely available TCP port in the current host. + Note that since under certain conditions this may have races. 
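    (Another process may claim the same port between the moment it is returned
    and the moment the caller actually binds to it.)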
+ """ + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind((bind_addr, 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return s.getsockname()[1] diff --git a/src/ai/backend/common/plugin/__init__.py b/src/ai/backend/common/plugin/__init__.py new file mode 100644 index 0000000000..e11533e039 --- /dev/null +++ b/src/ai/backend/common/plugin/__init__.py @@ -0,0 +1,177 @@ +from __future__ import annotations + +from abc import ABCMeta, abstractmethod +import asyncio +import logging +import pkg_resources +import re +from typing import ( + Any, + ClassVar, + Container, + Dict, + Generic, + Iterator, + Mapping, + Tuple, + Type, + TypeVar, +) +from weakref import WeakSet + +from ai.backend.common.asyncio import cancel_tasks + +from ..etcd import AsyncEtcd +from ..logging_utils import BraceStyleAdapter + +log = BraceStyleAdapter(logging.getLogger(__name__)) + +__all__ = ( + 'AbstractPlugin', + 'BasePluginContext', +) + + +class AbstractPlugin(metaclass=ABCMeta): + """ + The minimum generic plugin interface. + """ + + plugin_config: Mapping[str, Any] + """ + ``plugin_config`` contains the plugin-specific configuration read from the etcd. + """ + + local_config: Mapping[str, Any] + """ + ``local_config`` contains the configuration read from the disk TOML file of the current daemon. + This configuration is only updated when restarting the daemon and thus plugins should assume + that it's read-only and immutable during its lifetime. + e.g., If the plugin is running with the manager, it's the validated content of manager.toml file. + """ + + config_watch_enabled: ClassVar[bool] = True + """ + If set True (default), the hosting plugin context will watch and automatically update + the etcd's plugin configuration changes via the ``update_plugin_config()`` method. + """ + + def __init__(self, plugin_config: Mapping[str, Any], local_config: Mapping[str, Any]) -> None: + """ + Instantiate the plugin with the given initial configuration. + """ + self.plugin_config = plugin_config + self.local_config = local_config + + @abstractmethod + async def init(self, context: Any = None) -> None: + """ + Initialize any resource used by the plugin. + """ + pass + + @abstractmethod + async def cleanup(self) -> None: + """ + Clean up any resource used by the plugin upon server cleanup. + """ + pass + + @abstractmethod + async def update_plugin_config(self, plugin_config: Mapping[str, Any]) -> None: + """ + Handle runtime configuration updates. + The config parameter contains both the updated parts + and unchanged parts of the configuration. + + The default implementation is just to replace the config property, + but actual plugins may trigger other operations to reflect config changes + and/or inspect the differences of configs before replacing the current config. + """ + self.plugin_config = plugin_config + + +P = TypeVar('P', bound=AbstractPlugin) + + +class BasePluginContext(Generic[P]): + """ + A minimal plugin manager which controls the lifecycles of the given plugins + and watches & applies the configuration changes in etcd. + + The subclasses must redefine ``plugin_group``. 
+ """ + + etcd: AsyncEtcd + local_config: Mapping[str, Any] + plugins: Dict[str, P] + plugin_group: ClassVar[str] = 'backendai_XXX_v10' + + _config_watchers: WeakSet[asyncio.Task] + + def __init__(self, etcd: AsyncEtcd, local_config: Mapping[str, Any]) -> None: + self.etcd = etcd + self.local_config = local_config + self.plugins = {} + self._config_watchers = WeakSet() + if m := re.search(r'^backendai_(\w+)_v(\d+)$', self.plugin_group): + self._group_key = m.group(1) + else: + raise TypeError( + f"{type(self).__name__} has invalid plugin_group class attribute", + self.plugin_group, + ) + + @classmethod + def discover_plugins( + cls, + plugin_group: str, + blocklist: Container[str] = None, + ) -> Iterator[Tuple[str, Type[P]]]: + if blocklist is None: + blocklist = set() + for entrypoint in pkg_resources.iter_entry_points(plugin_group): + if entrypoint.name in blocklist: + continue + log.info('loading plugin (group:{}): {}', plugin_group, entrypoint.name) + yield entrypoint.name, entrypoint.load() + + async def init(self, context: Any = None) -> None: + scanned_plugins = self.discover_plugins(self.plugin_group) + for plugin_name, plugin_entry in scanned_plugins: + plugin_config = await self.etcd.get_prefix( + f"config/plugins/{self._group_key}/{plugin_name}/", + ) + try: + plugin_instance = plugin_entry(plugin_config, self.local_config) + await plugin_instance.init(context=context) + except Exception: + log.exception('error during initialization of plugin: {}', plugin_name) + continue + else: + self.plugins[plugin_name] = plugin_instance + if plugin_instance.config_watch_enabled: + await self.watch_config_changes(plugin_name) + await asyncio.sleep(0) + + async def cleanup(self) -> None: + await cancel_tasks(self._config_watchers) + await asyncio.sleep(0) + for plugin_instance in self.plugins.values(): + await plugin_instance.cleanup() + + async def _watcher(self, plugin_name: str) -> None: + # As wait_timeout applies to the waiting for an internal async queue, + # so short timeouts for polling the changes does not incur gRPC/network overheads. + async for _ in self.etcd.watch_prefix( + f"config/plugins/{self._group_key}/{plugin_name}", + wait_timeout=0.2, + ): + new_config = await self.etcd.get_prefix( + f"config/plugins/{self._group_key}/{plugin_name}/", + ) + await self.plugins[plugin_name].update_plugin_config(new_config) + + async def watch_config_changes(self, plugin_name: str) -> None: + wtask = asyncio.create_task(self._watcher(plugin_name)) + self._config_watchers.add(wtask) diff --git a/src/ai/backend/common/plugin/hook.py b/src/ai/backend/common/plugin/hook.py new file mode 100644 index 0000000000..f6705c9e25 --- /dev/null +++ b/src/ai/backend/common/plugin/hook.py @@ -0,0 +1,188 @@ +from __future__ import annotations + +from abc import ABCMeta, abstractmethod +import enum +import logging +from typing import ( + Any, + Final, + List, + Optional, + Protocol, + Sequence, + Tuple, + Union, +) + +import attr + +from . import AbstractPlugin, BasePluginContext +from ..logging_utils import BraceStyleAdapter + +log = BraceStyleAdapter(logging.getLogger(__name__)) + +__all__ = ( + 'HookHandler', + 'HookPlugin', + 'HookPluginContext', + 'Reject', + 'HookResults', + 'HookResult', + 'HookReturnTiming', + 'PASSED', + 'REJECTED', + 'ERROR', + 'ALL_COMPLETED', + 'FIRST_COMPLETED', +) + + +class HookHandler(Protocol): + """ + The handler should accept a single argument containing + a tuple of parameters passed to the handler. 
+ If it decides to cancel the ongoing event, it should raise + :class:`HookDenied` exception. + """ + + async def __call__(self, *args) -> Any: + # NOTE: Until https://github.com/python/mypy/issues/5876 is resolved, + # the get_handlers() in the HookPlugin subclasses should be marked + # with "type: ignore" comments. + ... + + +class HookPlugin(AbstractPlugin, metaclass=ABCMeta): + """ + The abstract interface for hook plugins. + """ + + @abstractmethod + def get_handlers(self) -> Sequence[Tuple[str, HookHandler]]: + """ + Returns a sequence of pairs of the event name + and its corresponding handler function. + """ + pass + + +class Reject(Exception): + def __init__(self, reason: str): + super().__init__(reason) + self.reason = reason + + +class HookResults(enum.Enum): + PASSED = 0 + REJECTED = 1 + ERROR = 2 + + +class HookReturnTiming(enum.Enum): + ALL_COMPLETED = 0 + FIRST_COMPLETED = 1 + + +PASSED: Final = HookResults.PASSED +REJECTED: Final = HookResults.REJECTED +ERROR: Final = HookResults.ERROR +ALL_COMPLETED: Final = HookReturnTiming.ALL_COMPLETED +FIRST_COMPLETED: Final = HookReturnTiming.FIRST_COMPLETED + + +@attr.s(auto_attribs=True, slots=True) +class HookResult: + status: HookResults + src_plugin: Optional[Union[str, Sequence[str]]] = None + reason: Optional[str] = None + result: Optional[Any] = None + + +class HookPluginContext(BasePluginContext[HookPlugin]): + """ + A manager for hook plugins with convenient handler invocation. + """ + + plugin_group = 'backendai_hook_v20' + + def _get_handlers( + self, event_name: str, order: Sequence[str] = None, + ) -> Sequence[Tuple[str, HookHandler]]: + handlers = [] + for plugin_name, plugin_instance in self.plugins.items(): + for hooked_event_name, hook_handler in plugin_instance.get_handlers(): + if event_name != hooked_event_name: + continue + handlers.append((plugin_name, hook_handler)) + if order is not None: + non_empty_order = order + handlers.sort(key=lambda item: non_empty_order.index(item)) + else: + # the default is alphabetical order with plugin names + handlers.sort(key=lambda item: item[0]) + return handlers + + async def dispatch( + self, event_name: str, args: Tuple[Any, ...], *, + return_when: HookReturnTiming = ALL_COMPLETED, + success_if_no_hook: bool = True, + order: Sequence[str] = None, + ) -> HookResult: + """ + Invoke the handlers that matches with the given ``event_name``. + If any of the handlers raises :class:`HookDenied`, + the event caller should seize the processing. + """ + executed_plugin_names = [] + results: List[Any] = [] + for plugin_name, hook_handler in self._get_handlers(event_name, order=order): + try: + executed_plugin_names.append(plugin_name) + result = await hook_handler(*args) + except Reject as e: + return HookResult( + status=REJECTED, + src_plugin=plugin_name, + reason=e.reason, + ) + except Exception as e: + return HookResult( + status=ERROR, + src_plugin=plugin_name, + reason=repr(e), + ) + else: + if return_when == FIRST_COMPLETED: + return HookResult( + status=PASSED, + src_plugin=plugin_name, + result=result, + ) + else: + results.append(result) + if not success_if_no_hook and not executed_plugin_names: + return HookResult( + status=REJECTED, + src_plugin=executed_plugin_names, # empty + result=results, # empty + ) + return HookResult( + status=PASSED, + src_plugin=executed_plugin_names, + result=results, + ) + + async def notify( + self, event_name: str, args: Tuple[Any, ...], + ) -> None: + """ + Invoke the handlers that matches with the given ``event_name``. 
+ Regardless of the handler results, the processing continues. + """ + for plugin_name, hook_handler in self._get_handlers(event_name): + try: + await hook_handler(*args) + except Exception: + log.exception('HookPluginContext.notify({}): skipping error in hook handler from {}', + event_name, plugin_name) + continue diff --git a/src/ai/backend/common/plugin/monitor.py b/src/ai/backend/common/plugin/monitor.py new file mode 100644 index 0000000000..ce60719aca --- /dev/null +++ b/src/ai/backend/common/plugin/monitor.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +from abc import ABCMeta, abstractmethod +import enum +from typing import ( + Any, + Mapping, + Union, +) + +from . import AbstractPlugin, BasePluginContext + +__all__ = ( + 'AbstractStatReporterPlugin', + 'AbstractErrorReporterPlugin', + 'StatsPluginContext', + 'ErrorPluginContext', + 'INCREMENT', + 'GAUGE', +) + + +class StatMetricTypes(enum.Enum): + INCREMENT = 0 + GAUGE = 1 + + +INCREMENT = StatMetricTypes.INCREMENT +GAUGE = StatMetricTypes.GAUGE + + +class AbstractStatReporterPlugin(AbstractPlugin, metaclass=ABCMeta): + + async def init(self, context: Any = None) -> None: + pass + + async def cleanup(self) -> None: + pass + + @abstractmethod + async def report_metric( + self, + metric_type: StatMetricTypes, + metric_name: str, + value: Union[float, int] = None, + ) -> None: + pass + + +class AbstractErrorReporterPlugin(AbstractPlugin, metaclass=ABCMeta): + async def init(self, context: Any = None) -> None: + pass + + async def cleanup(self) -> None: + pass + + @abstractmethod + async def capture_exception( + self, + exc_instance: Exception = None, + context: Mapping[str, Any] = None, + ) -> None: + pass + + @abstractmethod + async def capture_message(self, message: str) -> None: + pass + + +class StatsPluginContext(BasePluginContext[AbstractStatReporterPlugin]): + plugin_group = 'backendai_stats_monitor_v20' + + async def report_metric( + self, + metric_type: StatMetricTypes, + metric_name: str, + value: Union[float, int] = None, + ) -> None: + for plugin_instance in self.plugins.values(): + await plugin_instance.report_metric(metric_type, metric_name, value) + + +class ErrorPluginContext(BasePluginContext[AbstractErrorReporterPlugin]): + plugin_group = 'backendai_error_monitor_v20' + + async def capture_exception( + self, + exc_instance: Exception = None, + context: Mapping[str, Any] = None, + ) -> None: + for plugin_instance in self.plugins.values(): + await plugin_instance.capture_exception(exc_instance, context) + + async def capture_message(self, message: str) -> None: + for plugin_instance in self.plugins.values(): + await plugin_instance.capture_message(message) diff --git a/src/ai/backend/common/plugin/py.typed b/src/ai/backend/common/plugin/py.typed new file mode 100644 index 0000000000..5abed26af8 --- /dev/null +++ b/src/ai/backend/common/plugin/py.typed @@ -0,0 +1 @@ +marker diff --git a/src/ai/backend/common/py.typed b/src/ai/backend/common/py.typed new file mode 100644 index 0000000000..5abed26af8 --- /dev/null +++ b/src/ai/backend/common/py.typed @@ -0,0 +1 @@ +marker diff --git a/src/ai/backend/common/redis.py b/src/ai/backend/common/redis.py new file mode 100644 index 0000000000..f032f05582 --- /dev/null +++ b/src/ai/backend/common/redis.py @@ -0,0 +1,500 @@ +from __future__ import annotations + +import asyncio +import inspect +import logging +import socket +from typing import ( + Any, + AsyncIterator, + Awaitable, + Callable, + Dict, + Mapping, + MutableMapping, + Optional, + Sequence, + Tuple, + 
Union, +) + +import aioredis +import aioredis.client +import aioredis.sentinel +import aioredis.exceptions +import yarl + +from .logging import BraceStyleAdapter +from .types import EtcdRedisConfig, RedisConnectionInfo +from .validators import DelimiterSeperatedList, HostPortPair + +__all__ = ( + 'execute', + 'subscribe', + 'blpop', + 'read_stream', + 'read_stream_by_group', + 'get_redis_object', +) + +_keepalive_options: MutableMapping[int, int] = {} + +# macOS does not support several TCP_ options +# so check if socket package includes TCP options before adding it +if hasattr(socket, 'TCP_KEEPIDLE'): + _keepalive_options[socket.TCP_KEEPIDLE] = 20 + +if hasattr(socket, 'TCP_KEEPINTVL'): + _keepalive_options[socket.TCP_KEEPINTVL] = 5 + +if hasattr(socket, 'TCP_KEEPCNT'): + _keepalive_options[socket.TCP_KEEPCNT] = 3 + + +_default_conn_opts: Mapping[str, Any] = { + 'socket_timeout': 3.0, + 'socket_connect_timeout': 0.3, + 'socket_keepalive': True, + 'socket_keepalive_options': _keepalive_options, +} + + +_scripts: Dict[str, str] = {} + +log = BraceStyleAdapter(logging.getLogger(__name__)) + + +class ConnectionNotAvailable(Exception): + pass + + +def _calc_delay_exp_backoff(initial_delay: float, retry_count: float, time_limit: float) -> float: + if time_limit > 0: + return min(initial_delay * (2 ** retry_count), time_limit / 2) + return min(initial_delay * (2 ** retry_count), 30.0) + + +def _parse_stream_msg_id(msg_id: bytes) -> Tuple[int, int]: + timestamp, _, sequence = msg_id.partition(b'-') + return int(timestamp), int(sequence) + + +async def subscribe( + channel: aioredis.client.PubSub, + *, + reconnect_poll_interval: float = 0.3, +) -> AsyncIterator[Any]: + """ + An async-generator wrapper for pub-sub channel subscription. + It automatically recovers from server shutdowns until explicitly cancelled. + """ + async def _reset_chan(): + channel.connection = None + try: + await channel.ping() + except aioredis.exceptions.ConnectionError: + pass + else: + assert channel.connection is not None + await channel.on_connect(channel.connection) + + while True: + try: + if not channel.connection: + raise ConnectionNotAvailable + message = await channel.get_message(ignore_subscribe_messages=True, timeout=10.0) + if message is not None: + yield message["data"] + except ( + aioredis.exceptions.ConnectionError, + aioredis.sentinel.MasterNotFoundError, + aioredis.sentinel.SlaveNotFoundError, + aioredis.exceptions.ReadOnlyError, + aioredis.exceptions.ResponseError, + ConnectionResetError, + ConnectionNotAvailable, + ): + await asyncio.sleep(reconnect_poll_interval) + await _reset_chan() + continue + except aioredis.exceptions.ResponseError as e: + if e.args[0].startswith("NOREPLICAS "): + await asyncio.sleep(reconnect_poll_interval) + await _reset_chan() + continue + raise + except (TimeoutError, asyncio.TimeoutError): + continue + except asyncio.CancelledError: + raise + finally: + await asyncio.sleep(0) + + +async def blpop( + redis: RedisConnectionInfo | aioredis.Redis | aioredis.sentinel.Sentinel, + key: str, + *, + service_name: str = None, + reconnect_poll_interval: float = 0.3, +) -> AsyncIterator[Any]: + """ + An async-generator wrapper for blpop (blocking left pop). + It automatically recovers from server shutdowns until explicitly cancelled. 
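    A minimal consumer loop might look like this (names are illustrative):

    .. code-block:: python

       async for raw_msg in blpop(redis_conn_info, "myqueue"):
           handle_message(raw_msg)  # raw_msg is the popped value (bytes)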
+ """ + _conn_opts = { + **_default_conn_opts, + 'socket_connect_timeout': reconnect_poll_interval, + } + if isinstance(redis, RedisConnectionInfo): + redis_client = redis.client + service_name = service_name or redis.service_name + else: + redis_client = redis + + if isinstance(redis_client, aioredis.sentinel.Sentinel): + assert service_name is not None + r = redis_client.master_for( + service_name, + redis_class=aioredis.Redis, + connection_pool_class=aioredis.sentinel.SentinelConnectionPool, + **_conn_opts, + ) + else: + r = redis_client + while True: + try: + raw_msg = await r.blpop(key, timeout=10.0) + if not raw_msg: + continue + yield raw_msg[1] + except ( + aioredis.exceptions.ConnectionError, + aioredis.sentinel.MasterNotFoundError, + aioredis.exceptions.ReadOnlyError, + aioredis.exceptions.ResponseError, + ConnectionResetError, + ): + await asyncio.sleep(reconnect_poll_interval) + continue + except aioredis.exceptions.ResponseError as e: + if e.args[0].startswith("NOREPLICAS "): + await asyncio.sleep(reconnect_poll_interval) + continue + raise + except (TimeoutError, asyncio.TimeoutError): + continue + except asyncio.CancelledError: + raise + finally: + await asyncio.sleep(0) + + +async def execute( + redis: RedisConnectionInfo | aioredis.Redis | aioredis.sentinel.Sentinel, + func: Callable[[aioredis.Redis], Awaitable[Any]], + *, + service_name: str = None, + read_only: bool = False, + reconnect_poll_interval: float = 0.3, + encoding: Optional[str] = None, +) -> Any: + """ + Executes a function that issues Redis commands or returns a pipeline/transaction of commands, + with automatic retries upon temporary connection failures. + + Note that when retried, the given function may be executed *multiple* times, so the caller + should take care of side-effects of it. + """ + _conn_opts = { + **_default_conn_opts, + 'socket_connect_timeout': reconnect_poll_interval, + } + if isinstance(redis, RedisConnectionInfo): + redis_client = redis.client + service_name = service_name or redis.service_name + else: + redis_client = redis + + if isinstance(redis_client, aioredis.sentinel.Sentinel): + assert service_name is not None + if read_only: + r = redis_client.slave_for( + service_name, + redis_class=aioredis.Redis, + connection_pool_class=aioredis.sentinel.SentinelConnectionPool, + **_conn_opts, + ) + else: + r = redis_client.master_for( + service_name, + redis_class=aioredis.Redis, + connection_pool_class=aioredis.sentinel.SentinelConnectionPool, + **_conn_opts, + ) + else: + r = redis_client + while True: + try: + async with r: + if callable(func): + aw_or_pipe = func(r) + else: + raise TypeError('The func must be a function or a coroutinefunction ' + 'with no arguments.') + if isinstance(aw_or_pipe, aioredis.client.Pipeline): + result = await aw_or_pipe.execute() + elif inspect.isawaitable(aw_or_pipe): + result = await aw_or_pipe + else: + raise TypeError('The return value must be an awaitable' + 'or aioredis.commands.Pipeline object') + if isinstance(result, aioredis.client.Pipeline): + # This happens when func is an async function that returns a pipeline. 
+ result = await result.execute() + if encoding: + if isinstance(result, bytes): + return result.decode(encoding) + elif isinstance(result, dict): + newdict = {} + for k, v in result.items(): + newdict[k.decode(encoding)] = v.decode(encoding) + return newdict + else: + return result + except ( + aioredis.sentinel.MasterNotFoundError, + aioredis.sentinel.SlaveNotFoundError, + aioredis.exceptions.ReadOnlyError, + ConnectionResetError, + ): + await asyncio.sleep(reconnect_poll_interval) + continue + except aioredis.exceptions.ConnectionError as e: + log.exception(f'execute(): Connecting to redis failed: {e}') + await asyncio.sleep(reconnect_poll_interval) + continue + except aioredis.exceptions.ResponseError as e: + if "NOREPLICAS" in e.args[0]: + await asyncio.sleep(reconnect_poll_interval) + continue + raise + except (TimeoutError, asyncio.TimeoutError): + continue + except asyncio.CancelledError: + raise + finally: + await asyncio.sleep(0) + + +async def execute_script( + redis: RedisConnectionInfo | aioredis.Redis | aioredis.sentinel.Sentinel, + script_id: str, + script: str, + keys: Sequence[str], + args: Sequence[Union[bytes, memoryview, str, int, float]], # aioredis.connection.EncodableT +) -> Any: + """ + Auto-load and execute the given script. + It uses the hash keys for scripts so that it does not send the whole + script every time but only at the first time. + + Args: + conn: A Redis connection or pool with the commands mixin. + script_id: A human-readable identifier for the script. + This can be arbitrary string but must be unique for each script. + script: The script content. + keys: The Redis keys that will be passed to the script. + args: The arguments that will be passed to the script. + """ + script_hash = _scripts.get(script_id, 'x') + while True: + try: + ret = await execute(redis, lambda r: r.evalsha( + script_hash, + len(keys), + *keys, *args, + )) + break + except aioredis.exceptions.NoScriptError: + # Redis may have been restarted. + script_hash = await execute(redis, lambda r: r.script_load(script)) + _scripts[script_id] = script_hash + except aioredis.exceptions.ResponseError as e: + if 'NOSCRIPT' in e.args[0]: + # Redis may have been restarted. + script_hash = await execute(redis, lambda r: r.script_load(script)) + _scripts[script_id] = script_hash + else: + raise + continue + return ret + + +async def read_stream( + r: RedisConnectionInfo, + stream_key: str, + *, + block_timeout: int = 10_000, # in msec +) -> AsyncIterator[Tuple[bytes, bytes]]: + """ + A high-level wrapper for the XREAD command. + """ + last_id = b'$' + while True: + try: + reply = await execute( + r, + lambda r: r.xread( + {stream_key: last_id}, + block=block_timeout, + ), + ) + if reply is None: + continue + # Keep some latest messages so that other manager + # processes to have chances of fetching them. + await execute( + r, + lambda r: r.xtrim( + stream_key, + maxlen=128, + approximate=True, + ), + ) + for msg_id, msg_data in reply[0][1]: + try: + yield msg_id, msg_data + finally: + last_id = msg_id + except asyncio.CancelledError: + raise + + +async def read_stream_by_group( + r: RedisConnectionInfo, + stream_key: str, + group_name: str, + consumer_id: str, + *, + autoclaim_idle_timeout: int = 1_000, # in msec + block_timeout: int = 10_000, # in msec +) -> AsyncIterator[Tuple[bytes, bytes]]: + """ + A high-level wrapper for the XREADGROUP command + combined with XAUTOCLAIM and XGROUP_CREATE. 
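    Pending entries that have been idle longer than ``autoclaim_idle_timeout``
    are reclaimed first via XAUTOCLAIM, and the consumer group is created on
    demand when the server responds with a NOGROUP error.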
+ """ + while True: + try: + messages = [] + autoclaim_start_id = b'0-0' + while True: + reply = await execute( + r, + lambda r: r.execute_command( + "XAUTOCLAIM", + stream_key, + group_name, + consumer_id, + str(autoclaim_idle_timeout), + autoclaim_start_id, + ), + ) + for msg_id, msg_data in aioredis.client.parse_stream_list(reply[1]): + messages.append((msg_id, msg_data)) + if reply[0] == b'0-0': + break + autoclaim_start_id = reply[0] + reply = await execute( + r, + lambda r: r.xreadgroup( + group_name, + consumer_id, + {stream_key: b">"}, # fetch messages not seen by other consumers + block=block_timeout, + ), + ) + if len(reply) == 0: + continue + assert reply[0][0].decode() == stream_key + for msg_id, msg_data in reply[0][1]: + messages.append((msg_id, msg_data)) + await execute( + r, + lambda r: r.xack( + stream_key, + group_name, + *(msg_id for msg_id, msg_data in reply[0][1]), + ), + ) + for msg_id, msg_data in messages: + yield msg_id, msg_data + except asyncio.CancelledError: + raise + except aioredis.exceptions.ResponseError as e: + if e.args[0].startswith("NOGROUP "): + try: + await execute( + r, + lambda r: r.xgroup_create( + stream_key, + group_name, + b"$", + mkstream=True, + ), + ) + except aioredis.exceptions.ResponseError as e: + if e.args[0].startswith("BUSYGROUP "): + pass + else: + raise + continue + raise + + +def get_redis_object( + redis_config: EtcdRedisConfig, + db: int = 0, + **kwargs, +) -> RedisConnectionInfo: + if _sentinel_addresses := redis_config.get('sentinel'): + sentinel_addresses: Any = None + if isinstance(_sentinel_addresses, str): + sentinel_addresses = DelimiterSeperatedList(HostPortPair).check_and_return(_sentinel_addresses) + else: + sentinel_addresses = _sentinel_addresses + + assert redis_config.get('service_name') is not None + sentinel = aioredis.sentinel.Sentinel( + [(str(host), port) for host, port in sentinel_addresses], + password=redis_config.get('password'), + db=str(db), + sentinel_kwargs={ + **kwargs, + }, + ) + return RedisConnectionInfo( + client=sentinel, + service_name=redis_config.get('service_name'), + ) + else: + redis_url = redis_config.get('addr') + assert redis_url is not None + url = ( + yarl.URL('redis://host') + .with_host(str(redis_url[0])) + .with_port(redis_url[1]) + .with_password(redis_config.get('password')) / str(db) + ) + return RedisConnectionInfo( + client=aioredis.Redis.from_url(str(url), **kwargs), + service_name=None, + ) + + +async def ping_redis_connection(client: aioredis.client.Redis): + try: + _ = await client.time() + except aioredis.exceptions.ConnectionError as e: + log.exception(f'ping_redis_connection(): Connecting to redis failed: {e}') + raise e diff --git a/src/ai/backend/common/sd_notify.py b/src/ai/backend/common/sd_notify.py new file mode 100644 index 0000000000..ad24f66e33 --- /dev/null +++ b/src/ai/backend/common/sd_notify.py @@ -0,0 +1,119 @@ +""" +A wrapper for systemd's daemon status notification protocol. +The methods will silently becomes no-op if NOTIFY_SOCKET environment variable is not set. + +This module implements a subset of the notification protocol, excluding +file descriptor related messages. + +Reference: https://www.freedesktop.org/software/systemd/man/sd_notify.html + +Usage: + +.. 
code-block:: + + import asyncio + import sd_notify + + sdnotify = sd_notify.Notifier() + + # Report a status message + await sdnotify.update_status("Initialising my service...") + await asyncio.sleep(3) + + # Report that the program init is complete + await sdnotify.ready() + await sdnotify.update_status("Waiting for web requests...") + await asyncio.sleep(3) + + # Report an error to the service manager + await sdnotify.set_watchdog_error("An irrecoverable error occured!") +""" + +from __future__ import annotations + +import asyncio +import os +import socket + +import asyncudp + + +class SystemdNotifier(): + + socket: asyncudp.Socket | None + address: str | None + + def __init__(self) -> None: + self.socket = None + self.address = os.getenv("NOTIFY_SOCKET", None) + + @property + def enabled(self) -> bool: + return (self.address is not None) + + async def _send(self, raw_msg: bytes) -> None: + """ + Send a binary message via the notification socket. + If the `NOTIFY_SOCKET` environment variable is not set, + it will silently skip. + """ + if self.address is None: + return + loop = asyncio.get_running_loop() + if self.socket is None: + self.socket = asyncudp.Socket( + *(await loop.create_datagram_endpoint( + asyncudp._SocketProtocol, + family=socket.AF_UNIX, + remote_addr=self.address, # type: ignore + )), + ) + self.socket.sendto(raw_msg) + + async def ready(self) -> None: + """Report ready service state, i.e., completed initialization.""" + await self._send(b"READY=1\n") + + async def stopping(self) -> None: + """Report the stopping/shutting-down service state.""" + await self._send(b"STOPPING=1\n") + + async def reloading(self) -> None: + """Report the reloading service state.""" + await self._send(b"RELOADING=1\n") + + async def set_errno(self, errno: int) -> None: + """Set an errno-style integer code to indicate service failure.""" + await self._send(b"ERRNO=%d\n" % (errno, )) + + async def set_buserror(self, code: str) -> None: + """Set a D-Bus-style error code to indicate service failure.""" + await self._send(b"BUSERROR=%s\n" % (code.encode('utf8'), )) + + async def set_main_pid(self, pid: int) -> None: + """Set the main PID for the case when the service manager did not fork the process itself.""" + await self._send(b"MAINPID=%d\n" % (pid, )) + + async def update_status(self, msg: str) -> None: + """Set a custom service status message""" + await self._send(b"STATUS=%s\n" % (msg.encode('utf8'), )) + + async def keepalive(self) -> None: + """ + Send a keepalive message to extend the watchdog timestamp. + If the time that this keepalive message is not sent to systemd exceeds the watchdog + timeout (WatchdogSec) then systemd will try to restart the service depending on + the service configuration. + """ + await self._send(b"WATCHDOG=1\n") + + async def trigger_watchdog(self, msg: str = None) -> None: + """ + Triggers the systemd's watchdog handler immediately. + + If `msg` is specified, it will be reported as a custom status message to the + service manager to provide more information. 
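        According to the systemd documentation, this triggers the same recovery
        behaviour as if the watchdog timeout (``WatchdogSec=``) had expired.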
+ """ + if msg: + await self.update_status(msg) + await self._send(b"WATCHDOG=trigger\n") diff --git a/src/ai/backend/common/service_ports.py b/src/ai/backend/common/service_ports.py new file mode 100644 index 0000000000..aeeaa663c7 --- /dev/null +++ b/src/ai/backend/common/service_ports.py @@ -0,0 +1,60 @@ +import re +from typing import ( + List, + Sequence, + Set, + Type, +) + +from .types import ServicePort, ServicePortProtocols + +__all__ = ( + 'parse_service_ports' +) + +_rx_service_ports = re.compile( + r'^(?P[\w-]+):(?P\w+):(?P\[\d+(?:,\d+)*\]|\d+)(?:,|$)') + + +def parse_service_ports(s: str, exception_cls: Type[Exception] = None) -> Sequence[ServicePort]: + items: List[ServicePort] = [] + if exception_cls is None: + exception_cls = ValueError + used_ports: Set[int] = set() + while True: + match = _rx_service_ports.search(s) + if match: + s = s[len(match.group(0)):] + name = match.group('name') + if not name: + raise exception_cls('Service port name must be not empty.') + protocol = match.group('proto') + if protocol == 'pty': + # unsupported, skip + continue + if protocol not in ('tcp', 'http', 'preopen'): + raise exception_cls(f'Unsupported service port protocol: {protocol}') + ports = tuple(map(int, match.group('ports').strip('[]').split(','))) + for p in ports: + if p in used_ports: + raise exception_cls(f'The port {p} is already used by another service port.') + if p <= 1024: + raise exception_cls(f'The service port number {p} must be ' + f'larger than 1024 to run without the root privilege.') + if p >= 65535: + raise exception_cls(f'The service port number {p} must be smaller than 65535.') + if p in (2000, 2001, 2002, 2003, 2200, 7681): + raise exception_cls('The service ports 2000 to 2003, 2200 and 7681 ' + 'are reserved for internal use.') + used_ports.add(p) + items.append({ + 'name': name, + 'protocol': ServicePortProtocols(protocol), + 'container_ports': ports, + 'host_ports': (None,) * len(ports), + }) + else: + if len(s) > 0: + raise exception_cls('Invalid format') + break + return items diff --git a/src/ai/backend/common/testutils.py b/src/ai/backend/common/testutils.py new file mode 100644 index 0000000000..661eeafeac --- /dev/null +++ b/src/ai/backend/common/testutils.py @@ -0,0 +1,49 @@ +from unittest import mock +try: + # Since Python 3.8, AsyncMock is now part of the stdlib. + # Python 3.8 also adds magic-mocking async iterators and async context managers. + from unittest.mock import AsyncMock # type: ignore +except ImportError: + from asynctest import CoroutineMock as AsyncMock # type: ignore + + +def mock_corofunc(return_value): + """ + Return mock coroutine function. + + Python's default mock module does not support coroutines. + """ + async def _mock_corofunc(*args, **kargs): + return return_value + return mock.Mock(wraps=_mock_corofunc) + + +async def mock_awaitable(**kwargs): + """ + Mock awaitable. + + An awaitable can be a native coroutine object "returned from" a native + coroutine function. + """ + return AsyncMock(**kwargs) + + +class AsyncContextManagerMock: + """ + Mock async context manager. + + Can be used to get around `async with` statement for testing. + Must implement `__aenter__` and `__aexit__` which returns awaitable. + Attributes of the awaitable (and self for convenience) can be set by + passing `kwargs`. 
+ """ + def __init__(self, *args, **kwargs): + self.context = kwargs + for k, v in kwargs.items(): + setattr(self, k, v) + + async def __aenter__(self): + return AsyncMock(**self.context) + + async def __aexit__(self, exc_type, exc_value, exc_tb): + pass diff --git a/src/ai/backend/common/types.py b/src/ai/backend/common/types.py new file mode 100644 index 0000000000..f47c085ceb --- /dev/null +++ b/src/ai/backend/common/types.py @@ -0,0 +1,848 @@ +from __future__ import annotations + +from abc import ABCMeta, abstractmethod +from collections import UserDict, namedtuple +from contextvars import ContextVar +from decimal import Decimal +import enum +import ipaddress +import itertools +import math +import numbers +from pathlib import PurePosixPath +import sys +from typing import ( + Any, + Dict, + List, + Literal, + Mapping, + NewType, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + TypedDict, + TYPE_CHECKING, + Union, + cast, + overload, +) +import uuid + +import aioredis +import aioredis.client +import aioredis.sentinel +import attr +import trafaret as t +import typeguard + +__all__ = ( + 'aobject', + 'JSONSerializableMixin', + 'DeviceId', + 'ContainerId', + 'SessionId', + 'KernelId', + 'MetricKey', + 'MetricValue', + 'MovingStatValue', + 'PID', + 'HostPID', + 'ContainerPID', + 'BinarySize', + 'HostPortPair', + 'DeviceId', + 'SlotName', + 'IntrinsicSlotNames', + 'ResourceSlot', + 'HardwareMetadata', + 'MountPermission', + 'MountPermissionLiteral', + 'MountTypes', + 'VFolderMount', + 'KernelCreationConfig', + 'KernelCreationResult', + 'ServicePortProtocols', + 'ClusterInfo', + 'ClusterMode', + 'ClusterSSHKeyPair', + 'check_typed_dict', + 'EtcdRedisConfig', + 'RedisConnectionInfo', +) + +if TYPE_CHECKING: + from .docker import ImageRef + + +T_aobj = TypeVar('T_aobj', bound='aobject') + +current_resource_slots: ContextVar[Mapping[SlotName, SlotTypes]] = ContextVar('current_resource_slots') + + +class aobject(object): + ''' + An "asynchronous" object which guarantees to invoke both ``def __init__(self, ...)`` and + ``async def __ainit(self)__`` to ensure asynchronous initialization of the object. + + You can create an instance of subclasses of aboject in the following way: + + .. code-block:: python + + o = await SomeAObj.new(...) + ''' + + @classmethod + async def new(cls: Type[T_aobj], *args, **kwargs) -> T_aobj: + ''' + We can do ``await SomeAObject(...)``, but this makes mypy + to complain about its return type with ``await`` statement. + This is a copy of ``__new__()`` to workaround it. + ''' + instance = super().__new__(cls) + cls.__init__(instance, *args, **kwargs) + await instance.__ainit__() + return instance + + def __init__(self, *args, **kwargs) -> None: + pass + + async def __ainit__(self) -> None: + ''' + Automatically called when creating the instance using + ``await SubclassOfAObject(...)`` + where the arguments are passed to ``__init__()`` as in + the vanilla Python classes. + ''' + pass + + +T1 = TypeVar('T1') +T2 = TypeVar('T2') +T3 = TypeVar('T3') +T4 = TypeVar('T4') + + +@overload +def check_typed_tuple( + value: Tuple[Any], + types: Tuple[Type[T1]], +) -> Tuple[T1]: + ... + + +@overload +def check_typed_tuple( + value: Tuple[Any, Any], + types: Tuple[Type[T1], Type[T2]], +) -> Tuple[T1, T2]: + ... + + +@overload +def check_typed_tuple( + value: Tuple[Any, Any, Any], + types: Tuple[Type[T1], Type[T2], Type[T3]], +) -> Tuple[T1, T2, T3]: + ... 
+ + +@overload +def check_typed_tuple( + value: Tuple[Any, Any, Any, Any], + types: Tuple[Type[T1], Type[T2], Type[T3], Type[T4]], +) -> Tuple[T1, T2, T3, T4]: + ... + + +def check_typed_tuple(value: Tuple[Any, ...], types: Tuple[Type, ...]) -> Tuple: + for val, typ in itertools.zip_longest(value, types): + if typ is not None: + typeguard.check_type('item', val, typ) + return value + + +TD = TypeVar('TD') + + +def check_typed_dict(value: Mapping[Any, Any], expected_type: Type[TD]) -> TD: + """ + Validates the given dict against the given TypedDict class, + and wraps the value as the given TypedDict type. + + This is a shortcut to :func:`typeguard.check_typed_dict()` function to fill extra information + + Currently using this function may not be able to fix type errors, due to an upstream issue: + python/mypy#9827 + """ + assert issubclass(expected_type, dict) and hasattr(expected_type, '__annotations__'), \ + f"expected_type ({type(expected_type)}) must be a TypedDict class" + frame = sys._getframe(1) + _globals = frame.f_globals + _locals = frame.f_locals + memo = typeguard._TypeCheckMemo(_globals, _locals) + typeguard.check_typed_dict('value', value, expected_type, memo) + # Here we passed the check, so return it after casting. + return cast(TD, value) + + +PID = NewType('PID', int) +HostPID = NewType('HostPID', PID) +ContainerPID = NewType('ContainerPID', PID) + +ContainerId = NewType('ContainerId', str) +SessionId = NewType('SessionId', uuid.UUID) +KernelId = NewType('KernelId', uuid.UUID) +ImageAlias = NewType('ImageAlias', str) + +AgentId = NewType('AgentId', str) +DeviceName = NewType('DeviceName', str) +DeviceId = NewType('DeviceId', str) +SlotName = NewType('SlotName', str) +MetricKey = NewType('MetricKey', str) + +AccessKey = NewType('AccessKey', str) +SecretKey = NewType('SecretKey', str) + + +class LogSeverity(str, enum.Enum): + CRITICAL = 'critical' + ERROR = 'error' + WARNING = 'warning' + INFO = 'info' + DEBUG = 'debug' + + +class SlotTypes(str, enum.Enum): + COUNT = 'count' + BYTES = 'bytes' + UNIQUE = 'unique' + + +class HardwareMetadata(TypedDict): + status: Literal["healthy", "degraded", "offline", "unavailable"] + status_info: Optional[str] + metadata: Dict[str, str] + + +class AutoPullBehavior(str, enum.Enum): + DIGEST = 'digest' + TAG = 'tag' + NONE = 'none' + + +class ServicePortProtocols(str, enum.Enum): + HTTP = 'http' + TCP = 'tcp' + PREOPEN = 'preopen' + + +class SessionTypes(str, enum.Enum): + INTERACTIVE = 'interactive' + BATCH = 'batch' + + +class SessionResult(str, enum.Enum): + UNDEFINED = 'undefined' + SUCCESS = 'success' + FAILURE = 'failure' + + +class ClusterMode(str, enum.Enum): + SINGLE_NODE = 'single-node' + MULTI_NODE = 'multi-node' + + +class MovingStatValue(TypedDict): + min: str + max: str + sum: str + avg: str + diff: str + rate: str + version: Optional[int] # for legacy client compatibility + + +MetricValue = TypedDict('MetricValue', { + 'current': str, + 'capacity': Optional[str], + 'pct': Optional[str], + 'unit_hint': str, + 'stats.min': str, + 'stats.max': str, + 'stats.sum': str, + 'stats.avg': str, + 'stats.diff': str, + 'stats.rate': str, +}, total=False) + + +class IntrinsicSlotNames(enum.Enum): + CPU = SlotName('cpu') + MEMORY = SlotName('mem') + + +class DefaultForUnspecified(str, enum.Enum): + LIMITED = 'LIMITED' + UNLIMITED = 'UNLIMITED' + + +class HandlerForUnknownSlotName(str, enum.Enum): + DROP = 'drop' + ERROR = 'error' + + +Quantum = Decimal('0.000') + + +class MountPermission(str, enum.Enum): + READ_ONLY = 'ro' + READ_WRITE = 
'rw' + RW_DELETE = 'wd' + + +MountPermissionLiteral = Literal['ro', 'rw', 'wd'] + + +class MountTypes(str, enum.Enum): + VOLUME = 'volume' + BIND = 'bind' + TMPFS = 'tmpfs' + K8S_GENERIC = 'k8s-generic' + K8S_HOSTPATH = 'k8s-hostpath' + + +class HostPortPair(namedtuple('HostPortPair', 'host port')): + + def as_sockaddr(self) -> Tuple[str, int]: + return str(self.host), self.port + + def __str__(self) -> str: + if isinstance(self.host, ipaddress.IPv6Address): + return f'[{self.host}]:{self.port}' + return f'{self.host}:{self.port}' + + +class BinarySize(int): + ''' + A wrapper around Python integers to represent binary sizes for storage and + memory in various places. + + Its string representation and parser, ``from_str()`` classmethod, does not use + any locale-specific digit delimeters -- it supports only standard Python + digit delimeters. + ''' + + suffix_map = { + 'y': 2 ** 80, 'Y': 2 ** 80, # yotta + 'z': 2 ** 70, 'Z': 2 ** 70, # zetta + 'e': 2 ** 60, 'E': 2 ** 60, # exa + 'p': 2 ** 50, 'P': 2 ** 50, # peta + 't': 2 ** 40, 'T': 2 ** 40, # tera + 'g': 2 ** 30, 'G': 2 ** 30, # giga + 'm': 2 ** 20, 'M': 2 ** 20, # mega + 'k': 2 ** 10, 'K': 2 ** 10, # kilo + ' ': 1, + } + suffices = (' ', 'K', 'M', 'G', 'T', 'P', 'E', 'Z', 'Y') + endings = ('ibytes', 'ibyte', 'ib', 'bytes', 'byte', 'b') + + @classmethod + def _parse_str(cls, expr: str) -> Union[BinarySize, Decimal]: + if expr.lower() in ('inf', 'infinite', 'infinity'): + return Decimal('Infinity') + orig_expr = expr + expr = expr.strip().replace('_', '') + try: + return cls(expr) + except ValueError: + expr = expr.lower() + dec_expr: Decimal + try: + for ending in cls.endings: + if expr.endswith(ending): + length = len(ending) + 1 + suffix = expr[-length] + dec_expr = Decimal(expr[:-length]) + break + else: + # when there is suffix without scale (e.g., "2K") + if not str.isnumeric(expr[-1]): + suffix = expr[-1] + dec_expr = Decimal(expr[:-1]) + else: + # has no suffix and is not an integer + # -> fractional bytes (e.g., 1.5 byte) + raise ValueError('Fractional bytes are not allowed') + except ArithmeticError: + raise ValueError('Unconvertible value', orig_expr) + try: + multiplier = cls.suffix_map[suffix] + except KeyError: + raise ValueError('Unconvertible value', orig_expr) + return cls(dec_expr * multiplier) + + @classmethod + def finite_from_str( + cls, + expr: Union[str, Decimal, numbers.Integral], + ) -> BinarySize: + if isinstance(expr, Decimal): + if expr.is_infinite(): + raise ValueError('infinite values are not allowed') + return cls(expr) + if isinstance(expr, numbers.Integral): + return cls(int(expr)) + result = cls._parse_str(expr) + if isinstance(result, Decimal) and result.is_infinite(): + raise ValueError('infinite values are not allowed') + return cls(int(result)) + + @classmethod + def from_str( + cls, + expr: Union[str, Decimal, numbers.Integral], + ) -> Union[BinarySize, Decimal]: + if isinstance(expr, Decimal): + return cls(expr) + if isinstance(expr, numbers.Integral): + return cls(int(expr)) + return cls._parse_str(expr) + + def _preformat(self): + scale = self + suffix_idx = 0 + while scale >= 1024: + scale //= 1024 + suffix_idx += 1 + return suffix_idx + + @staticmethod + def _quantize(val, multiplier): + d = Decimal(val) / Decimal(multiplier) + if d == d.to_integral(): + value = d.quantize(Decimal(1)) + else: + value = d.quantize(Decimal('.00')).normalize() + return value + + def __str__(self): + suffix_idx = self._preformat() + if suffix_idx == 0: + if self == 1: + return f'{int(self)} byte' + else: + return 
f'{int(self)} bytes' + else: + suffix = type(self).suffices[suffix_idx] + multiplier = type(self).suffix_map[suffix] + value = self._quantize(self, multiplier) + return f'{value} {suffix.upper()}iB' + + def __format__(self, format_spec): + if len(format_spec) != 1: + raise ValueError('format-string for BinarySize can be only one character.') + if format_spec == 's': + # automatically scaled + suffix_idx = self._preformat() + if suffix_idx == 0: + return f'{int(self)}' + suffix = type(self).suffices[suffix_idx] + multiplier = type(self).suffix_map[suffix] + value = self._quantize(self, multiplier) + return f'{value}{suffix.lower()}' + else: + # use the given scale + suffix = format_spec.lower() + multiplier = type(self).suffix_map.get(suffix) + if multiplier is None: + raise ValueError('Unsupported scale unit.', suffix) + value = self._quantize(self, multiplier) + return f'{value}{suffix.lower()}'.strip() + + +class ResourceSlot(UserDict): + + __slots__ = ('data', ) + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def sync_keys(self, other: ResourceSlot) -> None: + self_only_keys = self.data.keys() - other.data.keys() + other_only_keys = other.data.keys() - self.data.keys() + for k in self_only_keys: + other.data[k] = Decimal(0) + for k in other_only_keys: + self.data[k] = Decimal(0) + + def __add__(self, other: ResourceSlot) -> ResourceSlot: + assert isinstance(other, ResourceSlot), 'Only can add ResourceSlot to ResourceSlot.' + self.sync_keys(other) + return type(self)({ + k: self.get(k, 0) + other.get(k, 0) + for k in (self.keys() | other.keys()) + }) + + def __sub__(self, other: ResourceSlot) -> ResourceSlot: + assert isinstance(other, ResourceSlot), 'Only can subtract ResourceSlot from ResourceSlot.' + self.sync_keys(other) + return type(self)({ + k: self.data[k] - other.get(k, 0) + for k in self.keys() + }) + + def __eq__(self, other: object) -> bool: + if other is self: + return True + assert isinstance(other, ResourceSlot), 'Only can compare ResourceSlot objects.' + self.sync_keys(other) + self_values = [self.data[k] for k in sorted(self.data.keys())] + other_values = [other.data[k] for k in sorted(other.data.keys())] + return self_values == other_values + + def __ne__(self, other: object) -> bool: + assert isinstance(other, ResourceSlot), 'Only can compare ResourceSlot objects.' + self.sync_keys(other) + return not self.__eq__(other) + + def eq_contains(self, other: ResourceSlot) -> bool: + assert isinstance(other, ResourceSlot), 'Only can compare ResourceSlot objects.' + common_keys = sorted(other.keys() & self.keys()) + only_other_keys = other.keys() - self.keys() + self_values = [self.data[k] for k in common_keys] + other_values = [other.data[k] for k in common_keys] + return self_values == other_values and all(other[k] == 0 for k in only_other_keys) + + def eq_contained(self, other: ResourceSlot) -> bool: + assert isinstance(other, ResourceSlot), 'Only can compare ResourceSlot objects.' + common_keys = sorted(other.keys() & self.keys()) + only_self_keys = self.keys() - other.keys() + self_values = [self.data[k] for k in common_keys] + other_values = [other.data[k] for k in common_keys] + return self_values == other_values and all(self[k] == 0 for k in only_self_keys) + + def __le__(self, other: ResourceSlot) -> bool: + assert isinstance(other, ResourceSlot), 'Only can compare ResourceSlot objects.' 
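+        # Element-wise "fits within" check: sync_keys() below treats missing
+        # slots as zero, so e.g. ResourceSlot({'cpu': 2}) <= ResourceSlot({'cpu': 4, 'mem': 8}).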
+ self.sync_keys(other) + self_values = [self.data[k] for k in self.keys()] + other_values = [other.data[k] for k in self.keys()] + return not any(s > o for s, o in zip(self_values, other_values)) + + def __lt__(self, other: ResourceSlot) -> bool: + assert isinstance(other, ResourceSlot), 'Only can compare ResourceSlot objects.' + self.sync_keys(other) + self_values = [self.data[k] for k in self.keys()] + other_values = [other.data[k] for k in self.keys()] + return (not any(s > o for s, o in zip(self_values, other_values)) and + not (self_values == other_values)) + + def __ge__(self, other: ResourceSlot) -> bool: + assert isinstance(other, ResourceSlot), 'Only can compare ResourceSlot objects.' + self.sync_keys(other) + self_values = [self.data[k] for k in other.keys()] + other_values = [other.data[k] for k in other.keys()] + return not any(s < o for s, o in zip(self_values, other_values)) + + def __gt__(self, other: ResourceSlot) -> bool: + assert isinstance(other, ResourceSlot), 'Only can compare ResourceSlot objects.' + self.sync_keys(other) + self_values = [self.data[k] for k in other.keys()] + other_values = [other.data[k] for k in other.keys()] + return (not any(s < o for s, o in zip(self_values, other_values)) and + not (self_values == other_values)) + + def normalize_slots(self, *, ignore_unknown: bool) -> ResourceSlot: + known_slots = current_resource_slots.get() + unset_slots = known_slots.keys() - self.data.keys() + if not ignore_unknown and (unknown_slots := self.data.keys() - known_slots.keys()): + raise ValueError('Unknown slots', unknown_slots) + data = { + k: v for k, v in self.data.items() + if k in known_slots + } + for k in unset_slots: + data[k] = Decimal(0) + return type(self)(data) + + @classmethod + def _normalize_value(cls, value: Any, unit: str) -> Decimal: + try: + if unit == 'bytes': + if isinstance(value, Decimal): + return Decimal(value) if value.is_finite() else value + if isinstance(value, int): + return Decimal(value) + value = Decimal(BinarySize.from_str(value)) + else: + value = Decimal(value) + if value.is_finite(): + value = value.quantize(Quantum).normalize() + except ArithmeticError: + raise ValueError('Cannot convert to decimal', value) + return value + + @classmethod + def _humanize_value(cls, value: Decimal, unit: str) -> str: + if unit == 'bytes': + try: + result = '{:s}'.format(BinarySize(value)) + except (OverflowError, ValueError): + result = _stringify_number(value) + else: + result = _stringify_number(value) + return result + + @classmethod + def _guess_slot_type(cls, key: str) -> str: + if 'mem' in key: + return 'bytes' + return 'count' + + @classmethod + def from_policy(cls, policy: Mapping[str, Any], slot_types: Mapping) -> 'ResourceSlot': + try: + data = { + k: cls._normalize_value(v, slot_types[k]) + for k, v in policy['total_resource_slots'].items() + if v is not None and k in slot_types + } + # fill missing (depending on the policy for unspecified) + fill = Decimal(0) + if policy['default_for_unspecified'] == DefaultForUnspecified.UNLIMITED: + fill = Decimal('Infinity') + for k in slot_types.keys(): + if k not in data: + data[k] = fill + except KeyError as e: + raise ValueError('unit unknown for slot', e.args[0]) + return cls(data) + + @classmethod + def from_user_input(cls, obj: Mapping[str, Any], slot_types: Optional[Mapping]) -> 'ResourceSlot': + try: + if slot_types is None: + data = { + k: cls._normalize_value(v, cls._guess_slot_type(k)) for k, v in obj.items() + if v is not None + } + else: + data = { + k: cls._normalize_value(v, 
slot_types[k]) for k, v in obj.items() + if v is not None + } + # fill missing + for k in slot_types.keys(): + if k not in data: + data[k] = Decimal(0) + except KeyError as e: + raise ValueError('unit unknown for slot', e.args[0]) + return cls(data) + + def to_humanized(self, slot_types: Mapping) -> Mapping[str, str]: + try: + return { + k: type(self)._humanize_value(v, slot_types[k]) for k, v in self.data.items() + if v is not None + } + except KeyError as e: + raise ValueError('unit unknown for slot', e.args[0]) + + @classmethod + def from_json(cls, obj: Mapping[str, Any]) -> 'ResourceSlot': + data = { + k: Decimal(v) for k, v in obj.items() + if v is not None + } + return cls(data) + + def to_json(self) -> Mapping[str, str]: + return { + k: _stringify_number(Decimal(v)) for k, v in self.data.items() + if v is not None + } + + +class JSONSerializableMixin(metaclass=ABCMeta): + + @abstractmethod + def to_json(self) -> dict[str, Any]: + raise NotImplementedError + + @classmethod + def from_json(cls, obj: Mapping[str, Any]) -> JSONSerializableMixin: + return cls(**cls.as_trafaret().check(obj)) + + @classmethod + @abstractmethod + def as_trafaret(cls) -> t.Trafaret: + raise NotImplementedError + + +@attr.define(slots=True) +class VFolderMount(JSONSerializableMixin): + name: str + vfid: uuid.UUID + vfsubpath: PurePosixPath + host_path: PurePosixPath + kernel_path: PurePosixPath + mount_perm: MountPermission + + def to_json(self) -> dict[str, Any]: + return { + 'name': self.name, + 'vfid': str(self.vfid), + 'vfsubpath': str(self.vfsubpath), + 'host_path': str(self.host_path), + 'kernel_path': str(self.kernel_path), + 'mount_perm': self.mount_perm.value, + } + + @classmethod + def from_json(cls, obj: Mapping[str, Any]) -> VFolderMount: + return cls(**cls.as_trafaret().check(obj)) + + @classmethod + def as_trafaret(cls) -> t.Trafaret: + from . 
import validators as tx + return t.Dict({ + t.Key('name'): t.String, + t.Key('vfid'): tx.UUID, + t.Key('vfsubpath', default="."): tx.PurePath, + t.Key('host_path'): tx.PurePath, + t.Key('kernel_path'): tx.PurePath, + t.Key('mount_perm'): tx.Enum(MountPermission), + }) + + +class ImageRegistry(TypedDict): + name: str + url: str + username: Optional[str] + password: Optional[str] + + +class ImageConfig(TypedDict): + canonical: str + architecture: str + digest: str + repo_digest: Optional[str] + registry: ImageRegistry + labels: Mapping[str, str] + + +class ServicePort(TypedDict): + name: str + protocol: ServicePortProtocols + container_ports: Sequence[int] + host_ports: Sequence[Optional[int]] + + +class ClusterInfo(TypedDict): + mode: ClusterMode + size: int + replicas: Mapping[str, int] # per-role kernel counts + network_name: Optional[str] + ssh_keypair: Optional[ClusterSSHKeyPair] + + +class ClusterSSHKeyPair(TypedDict): + public_key: str # OpenSSH authorized-keys compatible format + private_key: str # PEM-encoded string + + +class DeviceModelInfo(TypedDict): + device_id: DeviceId + model_name: str + data: Mapping[str, Any] + + +class KernelCreationResult(TypedDict): + id: KernelId + container_id: ContainerId + service_ports: Sequence[ServicePort] + kernel_host: str + resource_spec: Mapping[str, Any] + attached_devices: Mapping[DeviceName, Sequence[DeviceModelInfo]] + repl_in_port: int + repl_out_port: int + stdin_port: int # legacy + stdout_port: int # legacy + + +class KernelCreationConfig(TypedDict): + image: ImageConfig + auto_pull: AutoPullBehavior + session_type: SessionTypes + cluster_mode: ClusterMode + cluster_role: str # the kernel's role in the cluster + cluster_idx: int # the kernel's index in the cluster + cluster_hostname: str # the kernel's hostname in the cluster + resource_slots: Mapping[str, str] # json form of ResourceSlot + resource_opts: Mapping[str, str] # json form of resource options + environ: Mapping[str, str] + mounts: Sequence[Mapping[str, Any]] # list of serialized VFolderMount + package_directory: Sequence[str] + idle_timeout: int + bootstrap_script: Optional[str] + startup_command: Optional[str] + internal_data: Optional[Mapping[str, Any]] + preopen_ports: List[int] + + +class KernelEnqueueingConfig(TypedDict): + image_ref: ImageRef + cluster_role: str + cluster_idx: int + cluster_hostname: str + creation_config: dict + bootstrap_script: str + startup_command: str + + +def _stringify_number(v: Union[BinarySize, int, float, Decimal]) -> str: + ''' + Stringify a number, preventing unwanted scientific notations. 
+ ''' + if isinstance(v, (float, Decimal)): + if math.isinf(v) and v > 0: + result = 'Infinity' + elif math.isinf(v) and v < 0: + result = '-Infinity' + else: + result = '{:f}'.format(v) + elif isinstance(v, BinarySize): + result = '{:d}'.format(int(v)) + elif isinstance(v, int): + result = '{:d}'.format(v) + else: + result = str(v) + return result + + +class Sentinel(enum.Enum): + TOKEN = 0 + + +class QueueSentinel(enum.Enum): + CLOSED = 0 + TIMEOUT = 1 + + +class EtcdRedisConfig(TypedDict, total=False): + addr: Optional[HostPortPair] + sentinel: Optional[Union[str, List[HostPortPair]]] + service_name: Optional[str] + password: Optional[str] + + +@attr.s(auto_attribs=True) +class RedisConnectionInfo: + client: aioredis.Redis | aioredis.sentinel.Sentinel + service_name: Optional[str] + + async def close(self) -> None: + if isinstance(self.client, aioredis.Redis): + await self.client.close() diff --git a/src/ai/backend/common/utils.py b/src/ai/backend/common/utils.py new file mode 100644 index 0000000000..9a18add8fe --- /dev/null +++ b/src/ai/backend/common/utils.py @@ -0,0 +1,295 @@ +from __future__ import annotations + +import base64 +from collections import OrderedDict +from datetime import timedelta +from itertools import chain +import numbers +import random +import re +import sys +from typing import ( + Any, + Iterable, + Iterator, + Mapping, + Tuple, + TYPE_CHECKING, + TypeVar, + Union, +) +import uuid +if TYPE_CHECKING: + from decimal import Decimal + +# It is a bad practice to keep all "miscellaneous" stuffs +# into the single "utils" module. +# Let's categorize them by purpose and domain, and keep +# refactoring to use the proper module names. + +from .asyncio import ( # for legacy imports # noqa + AsyncBarrier, + cancel_tasks, + current_loop, + run_through, +) +from .enum_extension import StringSetFlag # for legacy imports # noqa +from .files import AsyncFileWriter # for legacy imports # noqa +from .networking import ( # for legacy imports # noqa + curl, + find_free_port, +) +from .types import BinarySize + + +KT = TypeVar('KT') +VT = TypeVar('VT') + + +def env_info() -> str: + """ + Returns a string that contains the Python version and runtime path. + """ + v = sys.version_info + pyver = f'Python {v.major}.{v.minor}.{v.micro}' + if v.releaselevel == 'alpha': + pyver += 'a' + if v.releaselevel == 'beta': + pyver += 'b' + if v.releaselevel == 'candidate': + pyver += 'rc' + if v.releaselevel != 'final': + pyver += str(v.serial) + return f'{pyver} (env: {sys.prefix})' + + +def odict(*args: Tuple[KT, VT]) -> OrderedDict[KT, VT]: + """ + A short-hand for the constructor of OrderedDict. + :code:`odict(('a',1), ('b',2))` is equivalent to + :code:`OrderedDict([('a',1), ('b',2)])`. + """ + return OrderedDict(args) + + +def dict2kvlist(o: Mapping[KT, VT]) -> Iterable[Union[KT, VT]]: + """ + Serializes a dict-like object into a generator of the flatten list of + repeating key-value pairs. It is useful when using HMSET method in Redis. + + Example: + + >>> list(dict2kvlist({'a': 1, 'b': 2})) + ['a', 1, 'b', 2] + """ + return chain.from_iterable((k, v) for k, v in o.items()) + + +def generate_uuid() -> str: + u = uuid.uuid4() + # Strip the last two padding characters because u always has fixed length. 
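+    # A UUID is 16 bytes, so urlsafe_b64encode() always yields 24 characters
+    # ending in '=='; slicing [:-2] leaves a fixed 22-character identifier.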
+ return base64.urlsafe_b64encode(u.bytes)[:-2].decode('ascii') + + +def get_random_seq(length: float, num_points: int, min_distance: float) -> Iterator[float]: + """ + Generate a random sequence of numbers within the range [0, length] + with the given number of points and the minimum distance between the points. + + Note that X ( = the minimum distance d x the number of points N) must be equivalent to or smaller than + the length L + d to guarantee the the minimum distance between the points. + If X == L + d, the points are always equally spaced with d. + + :return: An iterator over the generated sequence + """ + assert num_points * min_distance <= length + min_distance, \ + 'There are too many points or it has a too large distance which cannot be fit into the given length.' + extra = length - (num_points - 1) * min_distance + ro = [random.uniform(0, 1) for _ in range(num_points + 1)] + sum_ro = sum(ro) + rn = [extra * r / sum_ro for r in ro[0:num_points]] + spacing = [min_distance + rn[i] for i in range(num_points)] + cumulative_sum = 0.0 + for s in spacing: + cumulative_sum += s + yield cumulative_sum - min_distance + + +def nmget( + o: Mapping[str, Any], + key_path: str, + def_val: Any = None, + path_delimiter: str = '.', + null_as_default: bool = True, +) -> Any: + """ + A short-hand for retrieving a value from nested mappings + ("nested-mapping-get"). At each level it checks if the given "path" + component in the given key exists and return the default value whenever + fails. + + Example: + >>> o = {'a':{'b':1}, 'x': None} + >>> nmget(o, 'a', 0) + {'b': 1} + >>> nmget(o, 'a.b', 0) + 1 + >>> nmget(o, 'a/b', 0, '/') + 1 + >>> nmget(o, 'a.c', 0) + 0 + >>> nmget(o, 'x', 0) + 0 + >>> nmget(o, 'x', 0, null_as_default=False) + None + """ + pieces = key_path.split(path_delimiter) + while pieces: + p = pieces.pop(0) + if o is None or p not in o: + return def_val + o = o[p] + if o is None and null_as_default: + return def_val + return o + + +def readable_size_to_bytes(expr: Any) -> BinarySize | Decimal: + if isinstance(expr, numbers.Real): + return BinarySize(expr) + return BinarySize.from_str(expr) + + +def str_to_timedelta(tstr: str) -> timedelta: + """ + Convert humanized timedelta string into a Python timedelta object. + + Example: + >>> str_to_timedelta('30min') + datetime.timedelta(seconds=1800) + >>> str_to_timedelta('1d1hr') + datetime.timedelta(days=1, seconds=3600) + >>> str_to_timedelta('2hours 15min') + datetime.timedelta(seconds=8100) + >>> str_to_timedelta('20sec') + datetime.timedelta(seconds=20) + >>> str_to_timedelta('300') + datetime.timedelta(seconds=300) + >>> str_to_timedelta('-1day') + datetime.timedelta(days=-1) + """ + _rx = re.compile(r'(?P[+|-])?\s*' + r'((?P\d+(\.\d+)?)(d|day|days))?\s*' + r'((?P\d+(\.\d+)?)(h|hr|hrs|hour|hours))?\s*' + r'((?P\d+(\.\d+)?)(m|min|mins|minute|minutes))?\s*' + r'((?P\d+(\.\d+)?)(s|sec|secs|second|seconds))?$') + match = _rx.match(tstr) + if not match: + try: + return timedelta(seconds=float(tstr)) # consider bare number string as seconds + except TypeError: + pass + raise ValueError('Invalid time expression') + groups = match.groupdict() + sign = groups.pop('sign', None) + if set(groups.values()) == {None}: + raise ValueError('Invalid time expression') + params = {n: -float(t) if sign == '-' else float(t) for n, t in groups.items() if t} + return timedelta(**params) # type: ignore + + +class FstabEntry: + """ + Entry class represents a non-comment line on the `fstab` file. 
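+
+    For example, the line ``/dev/sda1 /data ext4 defaults 0 0`` is hydrated
+    into ``FstabEntry('/dev/sda1', '/data', 'ext4', 'defaults', '0', '0')``.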
+ """ + def __init__(self, device, mountpoint, fstype, options, d=0, p=0) -> None: + self.device = device + self.mountpoint = mountpoint + self.fstype = fstype + if not options: + options = 'defaults' + self.options = options + self.d = d + self.p = p + + def __eq__(self, o): + return str(self) == str(o) + + def __str__(self): + return "{} {} {} {} {} {}".format(self.device, + self.mountpoint, + self.fstype, + self.options, + self.d, + self.p) + + +class Fstab: + """ + Reader/writer for fstab file. + Takes aiofile pointer for async I/O. It should be writable if add/remove + operations are needed. + + NOTE: This class references Jorge Niedbalski R.'s gist snippet. + We have been converted it to be compatible with Python 3 + and to support async I/O. + (https://gist.github.com/niedbalski/507e974ed2d54a87ad37) + """ + def __init__(self, fp) -> None: + self._fp = fp + + def _hydrate_entry(self, line): + return FstabEntry(*[x for x in line.strip('\n').split(' ') if x not in ('', None)]) + + async def get_entries(self): + await self._fp.seek(0) + for line in await self._fp.readlines(): + try: + line = line.strip() + if not line.startswith("#"): + yield self._hydrate_entry(line) + except TypeError: + pass + + async def get_entry_by_attr(self, attr, value): + async for entry in self.get_entries(): + e_attr = getattr(entry, attr) + if e_attr == value: + return entry + return None + + async def add_entry(self, entry): + if await self.get_entry_by_attr('device', entry.device): + return False + await self._fp.write(str(entry) + '\n') + await self._fp.truncate() + return entry + + async def add(self, device, mountpoint, fstype, options=None, d=0, p=0): + return await self.add_entry(FstabEntry(device, mountpoint, fstype, options, d, p)) + + async def remove_entry(self, entry): + await self._fp.seek(0) + lines = await self._fp.readlines() + found = False + for index, line in enumerate(lines): + try: + if not line.strip().startswith("#"): + if self._hydrate_entry(line) == entry: + found = True + break + except TypeError: + pass + if not found: + return False + lines.remove(line) + await self._fp.seek(0) + await self._fp.write(''.join(lines)) + await self._fp.truncate() + return True + + async def remove_by_mountpoint(self, mountpoint): + entry = await self.get_entry_by_attr('mountpoint', mountpoint) + if entry: + return await self.remove_entry(entry) + return False diff --git a/src/ai/backend/common/validators.py b/src/ai/backend/common/validators.py new file mode 100644 index 0000000000..e3d99b978a --- /dev/null +++ b/src/ai/backend/common/validators.py @@ -0,0 +1,616 @@ +''' +An extension module to Trafaret which provides additional type checkers. 
+''' + +import datetime +from decimal import Decimal +import enum +import ipaddress +import json +import os +from pathlib import ( + PurePath as _PurePath, + Path as _Path, +) +import re +from typing import ( + Any, + List, + Literal, + Mapping, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, +) +import uuid +import pwd + +import dateutil.tz +from dateutil.relativedelta import relativedelta +try: + import jwt + jwt_available = True +except ImportError: + jwt_available = False +import multidict +import trafaret as t +from trafaret.base import TrafaretMeta +from trafaret.lib import _empty +import yarl + +from .types import ( + BinarySize as _BinarySize, + HostPortPair as _HostPortPair, +) + +__all__ = ( + 'AliasedKey', + 'MultiKey', + 'BinarySize', + 'DelimiterSeperatedList', + 'StringList', + 'Enum', + 'JSONString', + 'PurePath', + 'Path', + 'IPNetwork', + 'IPAddress', + 'HostPortPair', + 'PortRange', + 'UserID', + 'GroupID', + 'UUID', + 'TimeZone', + 'TimeDuration', + 'Slug', + 'URL', +) + + +def fix_trafaret_pickle_support(): + + def __reduce__(self): + return (type(self), (self.error, self.name, self.value, self.trafaret, self.code)) + + t.DataError.__reduce__ = __reduce__ + + +class StringLengthMeta(TrafaretMeta): + ''' + A metaclass that makes string-like trafarets to have sliced min/max length indicator. + ''' + + def __getitem__(cls, slice_): + return cls(min_length=slice_.start, max_length=slice_.stop) + + +class AliasedKey(t.Key): + ''' + An extension to trafaret.Key which accepts multiple aliases of a single key. + When successfully matched, the returned key name is the first one of the given aliases + or the renamed key set via ``to_name()`` method or the ``>>`` operator. + ''' + + def __init__(self, names: Sequence[str], **kwargs) -> None: + super().__init__(names[0], **kwargs) + self.names = names + + def __call__(self, data, context=None): + for name in self.names: + if name in data: + key = name + break + else: + key = None + + if key is None: # not specified + if self.default is not _empty: + default = self.default() if callable(self.default) else self.default + try: + result = self.trafaret(default, context=context) + except t.DataError as inner_error: + yield self.get_name(), inner_error, self.names + else: + yield self.get_name(), result, self.names + return + if not self.optional: + yield self.get_name(), t.DataError(error='is required'), self.names + # if optional, just bypass + else: + try: + result = self.trafaret(data[key], context=context) + except t.DataError as inner_error: + yield key, inner_error, self.names + else: + yield self.get_name(), result, self.names + + +class MultiKey(t.Key): + + def get_data(self, data, default): + if isinstance(data, (multidict.MultiDict, multidict.MultiDictProxy)): + return data.getall(self.name, default) + # fallback for plain dicts + raw_value = data.get(self.name, default) + if isinstance(raw_value, (List, Tuple)): + # if plain dict already contains list of values, just return it. + return raw_value + # otherwise, wrap the value in a list. 
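+        # e.g. a plain-dict value 'example.com' becomes ['example.com'],
+        # matching the list shape returned by MultiDict.getall() above.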
+ return [raw_value] + + +class BinarySize(t.Trafaret): + + def check_and_return(self, value: Any) -> Union[_BinarySize, Decimal]: + try: + if not isinstance(value, str): + value = str(value) + return _BinarySize.from_str(value) + except ValueError: + self._failure('value is not a valid binary size', value=value) + + +T_commalist = TypeVar('T_commalist', bound=Type) + + +class DelimiterSeperatedList(t.Trafaret): + + def __init__(self, value_cls: Optional[T_commalist], *, delimiter: str = ',') -> None: + self.delimiter = delimiter + self.value_cls = value_cls + + def check_and_return(self, value: Any) -> Sequence[T_commalist]: + try: + if not isinstance(value, str): + value = str(value) + splited = value.split(self.delimiter) + if self.value_cls: + return [self.value_cls().check_and_return(x) for x in splited] + else: + return splited + except ValueError: + self._failure('value is not a string or not convertible to string', value=value) + + +class StringList(DelimiterSeperatedList): + + def __init__(self, *, delimiter: str = ',') -> None: + super().__init__(None, delimiter=delimiter) + + +T_enum = TypeVar('T_enum', bound=enum.Enum) + + +class Enum(t.Trafaret): + + def __init__(self, enum_cls: Type[T_enum], *, use_name: bool = False) -> None: + self.enum_cls = enum_cls + self.use_name = use_name + + def check_and_return(self, value: Any) -> T_enum: + try: + if self.use_name: + return self.enum_cls[value] + else: + return self.enum_cls(value) + except (KeyError, ValueError): + self._failure(f'value is not a valid member of {self.enum_cls.__name__}', + value=value) + + +class JSONString(t.Trafaret): + + def check_and_return(self, value: Any) -> dict: + try: + return json.loads(value) + except (KeyError, ValueError): + self._failure('value is not a valid JSON string', value=value) + + +class PurePath(t.Trafaret): + + def __init__( + self, *, + base_path: _PurePath = None, + relative_only: bool = False, + ) -> None: + super().__init__() + self._base_path = base_path + self._relative_only = relative_only + + def check_and_return(self, value: Any) -> _PurePath: + p = _PurePath(value) + if self._relative_only and p.is_absolute(): + self._failure('expected relative path but the value is absolute', value=value) + if self._base_path is not None: + try: + p.relative_to(self._base_path) + except ValueError: + self._failure('value is not in the base path', value=value) + return p + + +class Path(PurePath): + + def __init__( + self, *, + type: Literal['dir', 'file'], + base_path: _Path = None, + auto_create: bool = False, + allow_nonexisting: bool = False, + allow_devnull: bool = False, + relative_only: bool = False, + resolve: bool = True, + ) -> None: + super().__init__( + base_path=base_path, + relative_only=relative_only, + ) + self._type = type + if auto_create and type != 'dir': + raise TypeError('Only directory paths can be set auto-created.') + self._auto_create = auto_create + self._allow_nonexisting = allow_nonexisting + self._allow_devnull = allow_devnull + self._resolve = resolve + + def check_and_return(self, value: Any) -> _Path: + try: + p = _Path(value).resolve() if self._resolve else _Path(value) + except (TypeError, ValueError): + self._failure('cannot parse value as a path', value=value) + if self._relative_only and p.is_absolute(): + self._failure('expected relative path but the value is absolute', value=value) + if self._base_path is not None: + try: + _base_path = _Path(self._base_path).resolve() if self._resolve else self._base_path + p.relative_to(_base_path) + except ValueError: + 
self._failure('value is not in the base path', value=value) + if self._type == 'dir': + if self._auto_create: + p.mkdir(parents=True, exist_ok=True) + if not self._allow_nonexisting and not p.is_dir(): + self._failure('value is not a directory', value=value) + elif self._type == 'file': + if not self._allow_devnull and str(p) == os.devnull: + # it may be not a regular file but a char-device. + return p + if not self._allow_nonexisting and not p.is_file(): + self._failure('value is not a regular file', value=value) + return p + + +class IPNetwork(t.Trafaret): + + def check_and_return(self, value: Any) -> ipaddress._BaseNetwork: + try: + return ipaddress.ip_network(value) + except ValueError: + self._failure('Invalid IP network format', value=value) + + +class IPAddress(t.Trafaret): + + def check_and_return(self, value: Any) -> ipaddress._BaseAddress: + try: + return ipaddress.ip_address(value) + except ValueError: + self._failure('Invalid IP address format', value=value) + + +class HostPortPair(t.Trafaret): + + def __init__(self, *, allow_blank_host: bool = False) -> None: + super().__init__() + self._allow_blank_host = allow_blank_host + + def check_and_return(self, value: Any) -> Tuple[ipaddress._BaseAddress, int]: + host: str | ipaddress._BaseAddress + if isinstance(value, str): + pair = value.rsplit(':', maxsplit=1) + if len(pair) == 1: + self._failure('value as string must contain both address and number', value=value) + host, port = pair[0], pair[1] + elif isinstance(value, Sequence): + if len(value) != 2: + self._failure('value as array must contain only two values for address and number', value=value) + host, port = value[0], value[1] + elif isinstance(value, Mapping): + try: + host, port = value['host'], value['port'] + except KeyError: + self._failure('value as map must contain "host" and "port" keys', value=value) + else: + self._failure('urecognized value type', value=value) + try: + if isinstance(host, str): + host = ipaddress.ip_address(host.strip('[]')) + elif isinstance(host, ipaddress._BaseAddress): + pass + except ValueError: + pass # just treat as a string hostname + if not self._allow_blank_host and not host: + self._failure('value has empty host', value=value) + try: + port = t.ToInt[1:65535].check(port) + except t.DataError: + self._failure('port number must be between 1 and 65535', value=value) + return _HostPortPair(host, port) + + +class PortRange(t.Trafaret): + + def check_and_return(self, value: Any) -> Tuple[int, int]: + if isinstance(value, str): + try: + value = tuple(map(int, value.split('-'))) + except (TypeError, ValueError): + self._failure('value as string should be a hyphen-separated pair of integers', value=value) + elif isinstance(value, Sequence): + if len(value) != 2: + self._failure('value as array must contain only two values', value=value) + else: + self._failure('urecognized value type', value=value) + try: + min_port = t.Int[1:65535].check(value[0]) + max_port = t.Int[1:65535].check(value[1]) + except t.DataError: + self._failure('each value must be a valid port number', value=value) + if not (min_port < max_port): + self._failure('first value must be less than second value', value=value) + return min_port, max_port + + +class UserID(t.Trafaret): + + def __init__(self, *, default_uid: int = None) -> None: + super().__init__() + self._default_uid = default_uid + + def check_and_return(self, value: Any) -> int: + if value is None: + if self._default_uid is not None: + return self._default_uid + else: + return os.getuid() + elif isinstance(value, 
int): + if value == -1: + return os.getuid() + elif isinstance(value, str): + if not value: + if self._default_uid is not None: + return self._default_uid + else: + return os.getuid() + try: + value = int(value) + except ValueError: + try: + return pwd.getpwnam(value).pw_uid + except KeyError: + self._failure('no such user in system', value=value) + else: + return self.check_and_return(value) + else: + self._failure('value must be either int or str', value=value) + return value + + +class GroupID(t.Trafaret): + + def __init__(self, *, default_gid: int = None) -> None: + super().__init__() + self._default_gid = default_gid + + def check_and_return(self, value: Any) -> int: + if value is None: + if self._default_gid is not None: + return self._default_gid + else: + return os.getgid() + elif isinstance(value, int): + if value == -1: + return os.getgid() + elif isinstance(value, str): + if not value: + if self._default_gid is not None: + return self._default_gid + else: + return os.getgid() + try: + value = int(value) + except ValueError: + try: + return pwd.getpwnam(value).pw_gid + except KeyError: + self._failure('no such group in system', value=value) + else: + return self.check_and_return(value) + else: + self._failure('value must be either int or str', value=value) + return value + + +class UUID(t.Trafaret): + + def check_and_return(self, value: Any) -> uuid.UUID: + try: + if isinstance(value, uuid.UUID): + return value + if isinstance(value, str): + return uuid.UUID(value) + elif isinstance(value, bytes): + return uuid.UUID(bytes=value) + else: + self._failure('value must be string or bytes', value=value) + except ValueError: + self._failure('cannot convert value to UUID', value=value) + + +class TimeZone(t.Trafaret): + + def check_and_return(self, value: Any) -> datetime.tzinfo: + if not isinstance(value, str): + self._failure('value must be string', value=value) + tz = dateutil.tz.gettz(value) + if tz is None: + self._failure('value is not a known timezone', value=value) + return tz + + +class TimeDuration(t.Trafaret): + ''' + Represent the relative difference between two datetime objects, + parsed from human-readable time duration expression strings. + + If you specify years or months, it returns an + :class:`dateutil.relativedelta.relativedelta` instance + which keeps the human-friendly year and month calculation + considering leap years and monthly day count differences, + instead of simply multiplying 365 days to the value of years. + Otherwise, it returns the stdlib's :class:`datetime.timedelta` + instance. 
+ + Example: + >>> t = datetime(2020, 2, 29) + >>> t + check_and_return(years=1) + datetime.datetime(2021, 2, 28, 0, 0) + >>> t + check_and_return(years=2) + datetime.datetime(2022, 2, 28, 0, 0) + >>> t + check_and_return(years=3) + datetime.datetime(2023, 2, 28, 0, 0) + >>> t + check_and_return(years=4) + datetime.datetime(2024, 2, 29, 0, 0) # preserves the same day of month + ''' + def __init__(self, *, allow_negative: bool = False) -> None: + self._allow_negative = allow_negative + + def check_and_return(self, value: Any) -> Union[datetime.timedelta, relativedelta]: + if not isinstance(value, (int, float, str)): + self._failure('value must be a number or string', value=value) + if isinstance(value, (int, float)): + return datetime.timedelta(seconds=value) + assert isinstance(value, str) + if len(value) == 0: + self._failure('value must not be empty', value=value) + try: + unit = value[-1] + if unit.isdigit(): + t = float(value) + if not self._allow_negative and t < 0: + self._failure('value must be positive', value=value) + return datetime.timedelta(seconds=t) + elif value[-2:].isalpha(): + t = int(value[:-2]) + if not self._allow_negative and t < 0: + self._failure('value must be positive', value=value) + if value[-2:] == 'yr': + return relativedelta(years=t) + elif value[-2:] == 'mo': + return relativedelta(months=t) + else: + self._failure('value is not a known time duration', value=value) + else: + t = float(value[:-1]) + if not self._allow_negative and t < 0: + self._failure('value must be positive', value=value) + if value[-1] == 'w': + return datetime.timedelta(weeks=t) + elif value[-1] == 'd': + return datetime.timedelta(days=t) + elif value[-1] == 'h': + return datetime.timedelta(hours=t) + elif value[-1] == 'm': + return datetime.timedelta(minutes=t) + else: + self._failure('value is not a known time duration', value=value) + except ValueError: + self._failure(f'invalid numeric literal: {value[:-1]}', value=value) + + +class Slug(t.Trafaret, metaclass=StringLengthMeta): + + _rx_slug = re.compile(r'^[a-zA-Z0-9]([a-zA-Z0-9._-]*[a-zA-Z0-9])?$') + + def __init__(self, *, min_length: Optional[int] = None, max_length: Optional[int] = None, + allow_dot: bool = False) -> None: + super().__init__() + self._allow_dot = allow_dot + if min_length is not None and min_length < 0: + raise TypeError('min_length must be larger than or equal to zero.') + if max_length is not None and max_length < 0: + raise TypeError('max_length must be larger than or equal to zero.') + if max_length is not None and min_length is not None and min_length > max_length: + raise TypeError('min_length must be less than or equal to max_length when both set.') + self._min_length = min_length + self._max_length = max_length + + def check_and_return(self, value: Any) -> str: + if isinstance(value, str): + if self._min_length is not None and len(value) < self._min_length: + self._failure(f'value is too short (min length {self._min_length})', value=value) + if self._max_length is not None and len(value) > self._max_length: + self._failure(f'value is too long (max length {self._max_length})', value=value) + if self._allow_dot and value.startswith('.'): + checked_value = value[1:] + else: + checked_value = value + m = type(self)._rx_slug.search(checked_value) + if not m: + self._failure('value must be a valid slug.', value=value) + else: + self._failure('value must be a string', value=value) + return value + + +if jwt_available: + class JsonWebToken(t.Trafaret): + + default_algorithms = ['HS256'] + + def __init__( + self, 
*, + secret: str, + inner_iv: t.Trafaret = None, + algorithms: list[str] = default_algorithms, + ) -> None: + self.secret = secret + self.algorithms = algorithms + self.inner_iv = inner_iv + + def check_and_return(self, value: Any) -> Mapping[str, Any]: + try: + token_data = jwt.decode(value, self.secret, algorithms=self.algorithms) + if self.inner_iv is not None: + return self.inner_iv.check(token_data) + return token_data + except jwt.PyJWTError as e: + self._failure(f'cannot decode the given value as JWT: {e}', value=value) + + +class URL(t.Trafaret): + + rx_scheme = re.compile(r"^[-a-z0-9]+://") + + def __init__( + self, *, + scheme_required: bool = True, + ) -> None: + self.scheme_required = scheme_required + + def check_and_return(self, value: Any) -> yarl.URL: + if not isinstance(value, (str, bytes)): + self._failure("A URL must be a unicode string or a byte sequence", value=value) + if isinstance(value, bytes): + value = value.decode('utf-8') + if self.scheme_required: + if not self.rx_scheme.match(value): + self._failure("The given value does not have the scheme (protocol) part", value=value) + try: + return yarl.URL(value) + except ValueError as e: + self._failure(f"cannot convert the given value to URL (error: {e!r})", value=value) diff --git a/src/ai/backend/helpers/BUILD b/src/ai/backend/helpers/BUILD new file mode 100644 index 0000000000..b533fe578b --- /dev/null +++ b/src/ai/backend/helpers/BUILD @@ -0,0 +1,31 @@ +python_sources( + name="lib", + dependencies=[ + ":resources", + ], +) + +# This distribution is not actually uploaded to PyPI. +# We have separate krunner static distributions. +python_distribution( + name="dist", + dependencies=[ + ":lib", + ], + provides=python_artifact( + name="backend.ai-kernel-helper", + description="Backend.AI Kernel Runner Prebuilt Binaries Package", + license="LGPLv3", + ), + generate_setup=True, + tags=["wheel"], +) + +resource(name="version", source="VERSION") + +target( + name="resources", + dependencies=[ + ":version", + ], +) diff --git a/src/ai/backend/helpers/README.md b/src/ai/backend/helpers/README.md new file mode 100644 index 0000000000..0a838d7193 --- /dev/null +++ b/src/ai/backend/helpers/README.md @@ -0,0 +1 @@ +# Backend.AI In-kernel Helper Package diff --git a/src/ai/backend/helpers/VERSION b/src/ai/backend/helpers/VERSION new file mode 120000 index 0000000000..a4e948506b --- /dev/null +++ b/src/ai/backend/helpers/VERSION @@ -0,0 +1 @@ +../../../../VERSION \ No newline at end of file diff --git a/src/ai/backend/helpers/__init__.py b/src/ai/backend/helpers/__init__.py new file mode 100644 index 0000000000..8688f5d222 --- /dev/null +++ b/src/ai/backend/helpers/__init__.py @@ -0,0 +1,9 @@ +''' +A helper package for user-written Python codes. +''' + +from .package import install + +__all__ = ( + 'install', +) diff --git a/src/ai/backend/helpers/package.py b/src/ai/backend/helpers/package.py new file mode 100644 index 0000000000..82d9350241 --- /dev/null +++ b/src/ai/backend/helpers/package.py @@ -0,0 +1,44 @@ +from collections import namedtuple +from pathlib import Path +import pkg_resources +import site +import subprocess +import sys + +Package = namedtuple('Package', 'name version is_user') + +__all__ = ( + 'install', +) + + +def install(pkgname, force_install=False): + ''' + Install a Python package from pypi.org or the given index server. + The package is installed inside the user site directory. 
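+
+    Example (from user code running inside a session)::
+
+        from ai.backend.helpers import install
+        install('pandas')  # the package name here is just an example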
+ ''' + + if not force_install: + user_path = Path(site.USER_SITE).resolve() + installed_pkgs = [] + for pkg in pkg_resources.working_set: + pkg_path = Path(pkg.location).resolve() + is_user = user_path in pkg_path.parents + installed_pkgs.append(Package(pkg.key, pkg.version, is_user)) + + for pkg in installed_pkgs: + if pkgname.lower() == pkg.name.lower(): + print(f"'{pkg.name}' is already installed (version: {pkg.version}).") + return + + sys.stdout.flush() + cmdargs = [sys.executable, '-m', 'pip', 'install', '--user'] + if force_install: + cmdargs.append('-I') + cmdargs.append(pkgname) + subprocess.call(cmdargs) + sys.stdout.flush() + + # Ensure the user site directory to be in sys.path + if site.USER_SITE not in sys.path: + sys.path.insert(0, site.USER_SITE) diff --git a/src/ai/backend/kernel/BUILD b/src/ai/backend/kernel/BUILD new file mode 100644 index 0000000000..f912aca6d5 --- /dev/null +++ b/src/ai/backend/kernel/BUILD @@ -0,0 +1,42 @@ +python_sources( + name="lib", + sources=[ + "**/*.py", + ], + dependencies=[ + ":resources", + ], +) + +python_distribution( + name="dist", + dependencies=[ + ":lib", + ], + provides=python_artifact( + name="backend.ai-kernel", + description="Backend.AI Kernel Runner", + license="LGPLv3", + ), + generate_setup=True, + tags=["wheel"], +) + +resource(name="version", source="VERSION") + +target( + name="resources", + dependencies=[ + ":version", + ], +) + +python_requirements( + name="reqs", + source="requirements.txt", + resolve="python-kernel", + module_mapping={ + "attrs": ["attr", "attrs"], + "pyzmq": ["zmq"], + }, +) diff --git a/src/ai/backend/kernel/README.md b/src/ai/backend/kernel/README.md new file mode 100644 index 0000000000..36c5db2031 --- /dev/null +++ b/src/ai/backend/kernel/README.md @@ -0,0 +1 @@ +# Backend.AI Kernel Runner diff --git a/src/ai/backend/kernel/VERSION b/src/ai/backend/kernel/VERSION new file mode 120000 index 0000000000..a4e948506b --- /dev/null +++ b/src/ai/backend/kernel/VERSION @@ -0,0 +1 @@ +../../../../VERSION \ No newline at end of file diff --git a/src/ai/backend/kernel/__init__.py b/src/ai/backend/kernel/__init__.py new file mode 100644 index 0000000000..1c764c0d3f --- /dev/null +++ b/src/ai/backend/kernel/__init__.py @@ -0,0 +1,30 @@ +from .base import BaseRunner +from .terminal import Terminal + + +__all__ = ( + 'BaseRunner', + 'Terminal', + 'lang_map', +) + +lang_map = { + 'app': 'ai.backend.kernel.app.Runner', + 'python': 'ai.backend.kernel.python.Runner', + 'c': 'ai.backend.kernel.c.Runner', + 'cpp': 'ai.backend.kernel.cpp.Runner', + 'golang': 'ai.backend.kernel.golang.Runner', + 'rust': 'ai.backend.kernel.rust.Runner', + 'java': 'ai.backend.kernel.java.Runner', + 'haskell': 'ai.backend.kernel.haskell.Runner', + 'julia': 'ai.backend.kernel.julia.Runner', + 'lua': 'ai.backend.kernel.lua.Runner', + 'nodejs': 'ai.backend.kernel.nodejs.Runner', + 'octave': 'ai.backend.kernel.octave.Runner', + 'php': 'ai.backend.kernel.php.Runner', + 'r': 'ai.backend.kernel.r.Runner', + 'scheme': 'ai.backend.kernel.scheme.Runner', + 'git': 'ai.backend.kernel.git.Runner', + 'vendor.aws_polly': 'ai.backend.kernel.vendor.aws_polly.Runner', + 'vendor.h2o': 'ai.backend.kernel.vendor.h2o.Runner', +} diff --git a/src/ai/backend/kernel/__main__.py b/src/ai/backend/kernel/__main__.py new file mode 100644 index 0000000000..113629f04b --- /dev/null +++ b/src/ai/backend/kernel/__main__.py @@ -0,0 +1,50 @@ +''' +The kernel main program. 
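+
+Invoked inside the kernel container roughly as::
+
+    python -m ai.backend.kernel [--debug] <lang> [runtime-path]
+
+where <lang> must be a key of ``lang_map`` and the optional runtime path
+falls back to the selected Runner's ``default_runtime_path``.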
+''' + +import argparse +import importlib +import os +from pathlib import Path +import signal +import sys + +import uvloop + +from . import lang_map +from .compat import asyncio_run_forever + + +def parse_args(args=None): + parser = argparse.ArgumentParser() + parser.add_argument('--debug', action='store_true', default=False) + parser.add_argument('lang', type=str, choices=lang_map.keys()) + parser.add_argument('runtime_path', type=Path, nargs='?', default=None) + return parser.parse_args(args) + + +def main(args) -> None: + cls_name = lang_map[args.lang] + imp_path, cls_name = cls_name.rsplit('.', 1) + mod = importlib.import_module(imp_path) + cls = getattr(mod, cls_name) + + if args.runtime_path is None: + runtime_path = cls.default_runtime_path + else: + runtime_path = args.runtime_path + runner = cls(runtime_path) + + # Replace stdin with a "null" file + # (trying to read stdin will raise EOFError immediately afterwards.) + sys.stdin = open(os.devnull, 'r', encoding='latin1') + asyncio_run_forever( + runner._init(args), + runner._shutdown(), + stop_signals={signal.SIGINT, signal.SIGTERM}, + ) + + +args = parse_args() +uvloop.install() +main(args) diff --git a/src/ai/backend/kernel/app/__init__.py b/src/ai/backend/kernel/app/__init__.py new file mode 100644 index 0000000000..386381e7e6 --- /dev/null +++ b/src/ai/backend/kernel/app/__init__.py @@ -0,0 +1,57 @@ +""" +This is a special kernel runner for application-only containers +which do not provide query/batch-mode code execution. +""" + +import logging +import os +from typing import List + +from .. import BaseRunner + +log = logging.getLogger() + +DEFAULT_PYFLAGS: List[str] = [] + + +class Runner(BaseRunner): + + log_prefix = 'app-kernel' + default_runtime_path = '/opt/backend.ai/bin/python' + default_child_env = { + 'TERM': 'xterm', + 'LANG': 'C.UTF-8', + 'SHELL': '/bin/bash', + 'USER': 'work', + 'HOME': '/home/work', + 'PATH': ':'.join([ + '/usr/local/nvidia/bin', + '/usr/local/cuda/bin', + '/usr/local/sbin', + '/usr/local/bin', + '/usr/sbin', + '/usr/bin', + '/sbin', + '/bin', + ]), + 'LD_LIBRARY_PATH': os.environ.get('LD_LIBRARY_PATH', ''), + 'LD_PRELOAD': os.environ.get('LD_PRELOAD', ''), + } + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + async def init_with_loop(self): + pass + + async def build_heuristic(self) -> int: + log.warning('batch-mode execution is not supported') + return 0 + + async def execute_heuristic(self) -> int: + log.warning('batch-mode execution is not supported') + return 0 + + async def start_service(self, service_info): + # app kernels use service-definition templates. 
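+        # Returning (None, {}) presumably defers to the parsed service
+        # definition template for the actual command and environment.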
+ return None, {} diff --git a/src/ai/backend/kernel/base.py b/src/ai/backend/kernel/base.py new file mode 100644 index 0000000000..8da04db727 --- /dev/null +++ b/src/ai/backend/kernel/base.py @@ -0,0 +1,841 @@ +from __future__ import annotations +from abc import ABCMeta, abstractmethod +import asyncio +import concurrent.futures +from functools import partial +import json +import logging +import os +from pathlib import Path +import signal +import sys +import time +from typing import ( + Awaitable, + ClassVar, + Dict, + List, + Mapping, + MutableMapping, + Optional, + Sequence, + Union, +) +import uuid + +from async_timeout import timeout +import janus +from jupyter_client import KernelManager +from jupyter_client.kernelspec import KernelSpecManager +import msgpack +import zmq + +from .service import ServiceParser +from .jupyter_client import aexecute_interactive +from .logging import BraceStyleAdapter, setup_logger +from .utils import wait_local_port_open +from .compat import current_loop +from .intrinsic import ( + init_sshd_service, + prepare_sshd_service, + prepare_ttyd_service, + prepare_vscode_service, +) + +log = BraceStyleAdapter(logging.getLogger()) + + +async def pipe_output(stream, outsock, target, log_fd): + assert target in ('stdout', 'stderr') + target = target.encode('ascii') + console_fd = sys.stdout.fileno() if target == 'stdout' else sys.stderr.fileno() + loop = current_loop() + try: + while True: + data = await stream.read(4096) + if not data: + break + await asyncio.gather( + loop.run_in_executor(None, os.write, console_fd, data), + loop.run_in_executor(None, os.write, log_fd, data), + outsock.send_multipart([target, data]), + return_exceptions=True, + ) + except asyncio.CancelledError: + pass + except Exception: + log.exception('unexpected error') + + +async def terminate_and_wait(proc: asyncio.subprocess.Process, timeout: float = 2.0) -> None: + try: + proc.terminate() + try: + await asyncio.wait_for(proc.wait(), timeout=timeout) + except asyncio.TimeoutError: + proc.kill() + await proc.wait() + except ProcessLookupError: + pass + + +def promote_path(path_env: str, path_to_promote: Union[Path, str]) -> str: + paths = path_env.split(':') + print(f"promote_path: {path_to_promote=} {path_env=}", file=sys.stderr) + path_to_promote = str(path_to_promote) + result_paths = [ + p for p in paths + if path_to_promote != p + ] + result_paths.insert(0, path_to_promote) + return ":".join(result_paths) + + +class BaseRunner(metaclass=ABCMeta): + + log_prefix: ClassVar[str] = 'generic-kernel' + log_queue: janus.Queue[logging.LogRecord] + task_queue: asyncio.Queue[Awaitable[None]] + default_runtime_path: ClassVar[Optional[str]] = None + default_child_env: ClassVar[MutableMapping[str, str]] = { + 'LANG': 'C.UTF-8', + 'SHELL': '/bin/sh', + 'HOME': '/home/work', + 'LD_LIBRARY_PATH': os.environ.get('LD_LIBRARY_PATH', ''), + 'LD_PRELOAD': os.environ.get('LD_PRELOAD', ''), + } + jupyter_kspec_name: ClassVar[str] = '' + kernel_mgr = None + kernel_client = None + + child_env: MutableMapping[str, str] + subproc: Optional[asyncio.subprocess.Process] + service_parser: Optional[ServiceParser] + runtime_path: Path + + services_running: Dict[str, asyncio.subprocess.Process] + + _build_success: Optional[bool] + + # Set by subclasses. 
+ user_input_queue: Optional[asyncio.Queue[str]] + + def __init__(self, runtime_path: Path) -> None: + self.subproc = None + self.runtime_path = runtime_path + + default_child_env_path = self.default_child_env.pop("PATH", None) + self.child_env = {**os.environ, **self.default_child_env} + if default_child_env_path is not None and "PATH" not in self.child_env: + # set the default PATH env-var only when it's missing from the image + self.child_env["PATH"] = default_child_env_path + config_dir = Path('/home/config') + try: + evdata = (config_dir / 'environ.txt').read_text() + for line in evdata.splitlines(): + k, v = line.split('=', 1) + self.child_env[k] = v + os.environ[k] = v + except FileNotFoundError: + pass + except Exception: + log.exception('Reading /home/config/environ.txt failed!') + + # Add ~/.local/bin to the default PATH + self.child_env["PATH"] += os.pathsep + '~/.local/bin' + os.environ["PATH"] += os.pathsep + '~/.local/bin' + + self.started_at: float = time.monotonic() + self.services_running = {} + + # If the subclass implements interatcive user inputs, it should set a + # asyncio.Queue-like object to self.user_input_queue in the + # init_with_loop() method. + self.user_input_queue = None + + # build status tracker to skip the execute step + self._build_success = None + + async def _init(self, cmdargs) -> None: + self.cmdargs = cmdargs + loop = current_loop() + self._service_lock = asyncio.Lock() + + # Initialize event loop. + executor = concurrent.futures.ThreadPoolExecutor(max_workers=2) + loop.set_default_executor(executor) + + self.zctx = zmq.asyncio.Context() + self.insock = self.zctx.socket(zmq.PULL) + self.insock.bind('tcp://*:2000') + self.outsock = self.zctx.socket(zmq.PUSH) + self.outsock.bind('tcp://*:2001') + + self.log_queue = janus.Queue() + self.task_queue = asyncio.Queue() + self.init_done = asyncio.Event() + + setup_logger(self.log_queue.sync_q, self.log_prefix, cmdargs.debug) + self._log_task = loop.create_task(self._handle_logs()) + await asyncio.sleep(0) + + service_def_folder = Path('/etc/backend.ai/service-defs') + if service_def_folder.is_dir(): + self.service_parser = ServiceParser({ + 'runtime_path': str(self.runtime_path), + }) + await self.service_parser.parse(service_def_folder) + log.debug('Loaded new-style service definitions.') + else: + self.service_parser = None + + self._main_task = loop.create_task(self.main_loop(cmdargs)) + self._run_task = loop.create_task(self.run_tasks()) + + async def _shutdown(self) -> None: + try: + self.insock.close() + log.debug('shutting down...') + self._run_task.cancel() + self._main_task.cancel() + await self._run_task + await self._main_task + log.debug('terminating service processes...') + running_procs = [*self.services_running.values()] + async with self._service_lock: + await asyncio.gather( + *(terminate_and_wait(proc) for proc in running_procs), + return_exceptions=True, + ) + await asyncio.sleep(0.01) + log.debug('terminated.') + finally: + # allow remaining logs to be flushed. + await asyncio.sleep(0.1) + try: + if self.outsock: + self.outsock.close() + await self._shutdown_jupyter_kernel() + finally: + self._log_task.cancel() + await self._log_task + + async def _init_jupyter_kernel(self) -> None: + """Detect ipython kernel spec for backend.ai and start it if found. + + Called after `init_with_loop`. `jupyter_kspec_name` should be defined to + initialize jupyter kernel. + """ + # Make inline backend defaults in Matplotlib. 
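+        # Writing ipython_kernel_config.py into the default profile makes
+        # figures render inline when code runs through the jupyter-based
+        # query mode set up below.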
+ kconfigdir = Path('/home/work/.ipython/profile_default/') + kconfigdir.mkdir(parents=True, exist_ok=True) + kconfig_file = kconfigdir / 'ipython_kernel_config.py' + kconfig_file.write_text("c.InteractiveShellApp.matplotlib = 'inline'") + + kernelspec_mgr = KernelSpecManager() + kernelspec_mgr.ensure_native_kernel = False + kspecs = kernelspec_mgr.get_all_specs() + for kname in kspecs: + if self.jupyter_kspec_name in kname: + log.debug('starting ' + kname + ' kernel...') + self.kernel_mgr = KernelManager(kernel_name=kname) + self.kernel_mgr.start_kernel() + if not self.kernel_mgr.is_alive(): + log.error('jupyter query mode is disabled: ' + 'failed to start jupyter kernel') + else: + self.kernel_client = self.kernel_mgr.client() + self.kernel_client.start_channels(shell=True, iopub=True, + stdin=True, hb=True) + try: + self.kernel_client.wait_for_ready(timeout=10) + # self.init_jupyter_kernel() + except RuntimeError: + # Clean up for client and kernel will be done in `shutdown`. + log.error('jupyter channel is not active!') + self.kernel_mgr = None + break + else: + log.debug('jupyter query mode is not available: ' + 'no jupyter kernelspec found') + self.kernel_mgr = None + + async def _shutdown_jupyter_kernel(self): + if self.kernel_mgr and self.kernel_mgr.is_alive(): + log.info('shutting down ' + self.jupyter_kspec_name + ' kernel...') + self.kernel_client.stop_channels() + self.kernel_mgr.shutdown_kernel() + assert not self.kernel_mgr.is_alive(), 'ipykernel failed to shutdown' + + async def _init_with_loop(self) -> None: + if self.init_done is not None: + self.init_done.clear() + try: + await self.init_with_loop() + await init_sshd_service(self.child_env) + except Exception: + log.exception('Unexpected error!') + log.warning('We are skipping the error but the container may not work as expected.') + finally: + if self.init_done is not None: + self.init_done.set() + + @abstractmethod + async def init_with_loop(self) -> None: + """Initialize after the event loop is created.""" + + async def _clean(self, clean_cmd: Optional[str]) -> None: + ret = 0 + try: + if clean_cmd is None or clean_cmd == '': + # skipped + return + elif clean_cmd == '*': + ret = await self.clean_heuristic() + else: + ret = await self.run_subproc(clean_cmd) + except Exception: + log.exception('unexpected error') + ret = -1 + finally: + await asyncio.sleep(0.01) # extra delay to flush logs + payload = json.dumps({ + 'exitCode': ret, + }).encode('utf8') + await self.outsock.send_multipart([b'clean-finished', payload]) + + async def clean_heuristic(self) -> int: + # it should not do anything by default. 
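+ # Language runners may override this; e.g., the C/C++ runners run
+ # `make clean` when a Makefile is present.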
+ return 0 + + async def _bootstrap(self, script_path: Path) -> None: + log.info('Running the user bootstrap script...') + ret = 0 + try: + ret = await self.run_subproc(['/bin/sh', str(script_path)]) + except Exception: + log.exception('unexpected error while executing the user bootstrap script') + ret = -1 + finally: + await asyncio.sleep(0.01) # extra delay to flush logs + log.info('The user bootstrap script has exited with code {}', ret) + + async def _build(self, build_cmd: Optional[str]) -> None: + ret = 0 + try: + if build_cmd is None or build_cmd == '': + # skipped + return + elif build_cmd == '*': + if Path('Makefile').is_file(): + ret = await self.run_subproc('make') + else: + ret = await self.build_heuristic() + else: + ret = await self.run_subproc(build_cmd) + except Exception: + log.exception('unexpected error') + ret = -1 + finally: + await asyncio.sleep(0.01) # extra delay to flush logs + self._build_success = (ret == 0) + payload = json.dumps({ + 'exitCode': ret, + }).encode('utf8') + await self.outsock.send_multipart([b'build-finished', payload]) + + @abstractmethod + async def build_heuristic(self) -> int: + """Process build step.""" + + async def _execute(self, exec_cmd: str) -> None: + ret = 0 + try: + if exec_cmd is None or exec_cmd == '': + # skipped + return + elif exec_cmd == '*': + ret = await self.execute_heuristic() + else: + ret = await self.run_subproc(exec_cmd, batch=True) + except Exception: + log.exception('unexpected error') + ret = -1 + finally: + await asyncio.sleep(0.01) # extra delay to flush logs + payload = json.dumps({ + 'exitCode': ret, + }).encode('utf8') + await self.outsock.send_multipart([b'finished', payload]) + + @abstractmethod + async def execute_heuristic(self) -> int: + """Process execute step.""" + + async def _query(self, code_text: str) -> None: + ret = 0 + try: + ret = await self.query(code_text) + except Exception: + log.exception('unexpected error') + ret = -1 + finally: + payload = json.dumps({ + 'exitCode': ret, + }).encode('utf8') + await self.outsock.send_multipart([b'finished', payload]) + + async def query(self, code_text) -> int: + """Run user's code in query mode. + + The default interface is jupyter kernel. To use different interface, + `Runner` subclass should override this method. + """ + if not hasattr(self, 'kernel_mgr') or self.kernel_mgr is None: + log.error('query mode is disabled: ' + 'failed to start jupyter kernel') + return 127 + + log.debug('executing in query mode...') + + async def output_hook(msg): + content = msg.get('content', '') + if msg['msg_type'] == 'stream': + # content['name'] will be 'stdout' or 'stderr'. 
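+ # It is reused as the channel tag (first frame) of the multipart
+ # message written to the output socket.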
+ await self.outsock.send_multipart([content['name'].encode('ascii'), + content['text'].encode('utf-8')]) + elif msg['msg_type'] == 'error': + tbs = '\n'.join(content['traceback']) + await self.outsock.send_multipart([b'stderr', tbs.encode('utf-8')]) + elif msg['msg_type'] in ['execute_result', 'display_data']: + data = content['data'] + if len(data) < 1: + return + if len(data) > 1: + data.pop('text/plain', None) + dtype, dval = list(data.items())[0] + + if dtype == 'text/plain': + await self.outsock.send_multipart([b'stdout', + dval.encode('utf-8')]) + elif dtype == 'text/html': + await self.outsock.send_multipart([b'media', + dval.encode('utf-8')]) + # elif dtype == 'text/markdown': + # NotImplementedError + # elif dtype == 'text/latex': + # NotImplementedError + # elif dtype in ['application/json', 'application/javascript']: + # NotImplementedError + elif dtype in ['image/png', 'image/jpeg']: + await self.outsock.send_multipart([ + b'media', + json.dumps({ + 'type': dtype, + 'data': f'data:{dtype};base64,{dval}', + }).encode('utf-8'), + ]) + elif dtype == 'image/svg+xml': + await self.outsock.send_multipart([ + b'media', + json.dumps({'type': dtype, 'data': dval}).encode('utf8'), + ]) + + async def stdin_hook(msg): + if msg['msg_type'] == 'input_request': + prompt = msg['content']['prompt'] + password = msg['content']['password'] + if prompt: + await self.outsock.send_multipart([ + b'stdout', prompt.encode('utf-8')]) + await self.outsock.send_multipart( + [b'waiting-input', + json.dumps({'is_password': password}).encode('utf-8')]) + user_input = await self.user_input_queue.async_q.get() + self.kernel_client.input(user_input) + + # Run jupyter kernel's blocking execution method in an executor pool. + allow_stdin = False if self.user_input_queue is None else True + stdin_hook = None if self.user_input_queue is None else stdin_hook # type: ignore + try: + await aexecute_interactive(self.kernel_client, code_text, timeout=None, + output_hook=output_hook, + allow_stdin=allow_stdin, + stdin_hook=stdin_hook) + except Exception as e: + log.error(str(e)) + return 127 + return 0 + + async def _complete(self, completion_data) -> Sequence[str]: + result: Sequence[str] = [] + try: + result = await self.complete(completion_data) + except Exception: + log.exception('unexpected error') + finally: + return result + + async def complete(self, completion_data) -> Sequence[str]: + """Return the list of strings to be shown in the auto-complete list. + + The default interface is jupyter kernel. To use different interface, + `Runner` subclass should override this method. + """ + # TODO: implement with jupyter_client + ''' + matches = [] + self.outsock.send_multipart([ + b'completion', + json.dumps(matches).encode('utf8'), + ]) + ''' + # if hasattr(self, 'kernel_mgr') and self.kernel_mgr is not None: + # self.kernel_mgr.complete(data, len(data)) + # else: + # return [] + return [] + + async def _interrupt(self): + try: + if self.subproc: + self.subproc.send_signal(signal.SIGINT) + return + return await self.interrupt() + except Exception: + log.exception('unexpected error') + finally: + # this is a unidirectional command -- no explicit finish! + pass + + async def interrupt(self): + """Interrupt the running user code (only called for query-mode). + + The default interface is jupyter kernel. To use different interface, + `Runner` subclass should implement its own `complete` method. 
+ """ + if hasattr(self, 'kernel_mgr') and self.kernel_mgr is not None: + self.kernel_mgr.interrupt_kernel() + + async def _send_status(self): + data = { + 'started_at': self.started_at, + } + await self.outsock.send_multipart([ + b'status', + msgpack.packb(data, use_bin_type=True), + ]) + + @abstractmethod + async def start_service(self, service_info): + """Start an application service daemon.""" + return None, {} + + async def _start_service(self, service_info, user_requested: bool = True): + async with self._service_lock: + try: + if service_info['protocol'] == 'preopen': + # skip subprocess spawning as we assume the user runs it manually. + result = {'status': 'started'} + return + if service_info['name'] in self.services_running: + result = {'status': 'running'} + return + if service_info['protocol'] == 'pty': + result = {'status': 'failed', + 'error': 'not implemented yet'} + return + cwd = Path.cwd() + cmdargs: Optional[Sequence[Union[str, os.PathLike]]] + env: Mapping[str, str] + cmdargs, env = None, {} + if service_info['name'] == 'ttyd': + cmdargs, env = await prepare_ttyd_service(service_info) + elif service_info['name'] == 'sshd': + cmdargs, env = await prepare_sshd_service(service_info) + elif service_info['name'] == 'vscode': + cmdargs, env = await prepare_vscode_service(service_info) + elif self.service_parser is not None: + self.service_parser.variables['ports'] = service_info['ports'] + cmdargs, env = await self.service_parser.start_service( + service_info['name'], + self.child_env.keys(), + service_info['options'], + ) + if cmdargs is None: + # fall-back to legacy service routine + start_info = await self.start_service(service_info) + if start_info is None: + cmdargs, env = None, {} + elif len(start_info) == 3: + cmdargs, env, cwd = start_info + elif len(start_info) == 2: + cmdargs, env = start_info + if cmdargs is None: + # still not found? + log.warning('The service {0} is not supported.', + service_info['name']) + result = { + 'status': 'failed', + 'error': 'unsupported service', + } + return + log.debug('cmdargs: {0}', cmdargs) + log.debug('env: {0}', env) + service_env = {**self.child_env, **env} + # avoid conflicts with Python binary used by service apps. + if 'LD_LIBRARY_PATH' in service_env: + service_env['LD_LIBRARY_PATH'] = \ + service_env['LD_LIBRARY_PATH'].replace('/opt/backend.ai/lib:', '') + try: + proc = await asyncio.create_subprocess_exec( + *map(str, cmdargs), + env=service_env, + cwd=cwd, + ) + self.services_running[service_info['name']] = proc + asyncio.create_task(self._wait_service_proc(service_info['name'], proc)) + with timeout(5.0): + await wait_local_port_open(service_info['port']) + log.info("Service {} has started (pid: {}, port: {})", + service_info['name'], proc.pid, service_info['port']) + result = {'status': "started"} + except asyncio.CancelledError: + # This may happen if the service process gets started but it fails to + # open the port and then terminates (with an error). + result = {'status': "failed", + 'error': f"the process did not start properly: {cmdargs[0]}"} + except asyncio.TimeoutError: + # Takes too much time to open a local port. 
+ if service_info['name'] in self.services_running: + await terminate_and_wait(proc, timeout=10.0) + self.services_running.pop(service_info['name'], None) + result = {'status': "failed", + 'error': f"opening the service port timed out: {service_info['name']}"} + except PermissionError: + result = {'status': "failed", + 'error': f"the target file is not executable: {cmdargs[0]}"} + except FileNotFoundError: + result = {'status': "failed", + 'error': f"the executable file is not found: {cmdargs[0]}"} + except Exception as e: + log.exception('start_service: unexpected error') + result = { + 'status': 'failed', + 'error': repr(e), + } + finally: + if user_requested: + await self.outsock.send_multipart([ + b'service-result', + json.dumps(result).encode('utf8'), + ]) + + async def _wait_service_proc( + self, + service_name: str, + proc: asyncio.subprocess.Process, + ) -> None: + exitcode = await proc.wait() + log.info(f"Service {service_name} (pid: {proc.pid}) has terminated with exit code: {exitcode}") + self.services_running.pop(service_name, None) + + async def run_subproc(self, cmd: Union[str, List[str]], batch: bool = False): + """A thin wrapper for an external command.""" + loop = current_loop() + if Path('/home/work/.logs').is_dir(): + kernel_id = os.environ['BACKENDAI_KERNEL_ID'] + kernel_id_hex = uuid.UUID(kernel_id).hex + log_path = Path( + '/home/work/.logs/task/' + f'{kernel_id_hex[:2]}/{kernel_id_hex[2:4]}/{kernel_id_hex[4:]}.log', + ) + log_path.parent.mkdir(parents=True, exist_ok=True) + else: + log_path = Path(os.path.devnull) + try: + # errors like "command not found" is handled by the spawned shell. + # (the subproc will terminate immediately with return code 127) + if isinstance(cmd, (list, tuple)): + exec_func = partial(asyncio.create_subprocess_exec, *map(str, cmd)) + else: + exec_func = partial(asyncio.create_subprocess_shell, str(cmd)) + pipe_opts = {} + pipe_opts['stdout'] = asyncio.subprocess.PIPE + pipe_opts['stderr'] = asyncio.subprocess.PIPE + with open(log_path, 'ab') as log_out: + env = {**self.child_env} + if batch: + env['_BACKEND_BATCH_MODE'] = '1' + proc = await exec_func( + env=env, + stdin=None, + **pipe_opts, + ) + self.subproc = proc + pipe_tasks = [ + loop.create_task( + pipe_output(proc.stdout, self.outsock, 'stdout', + log_out.fileno())), + loop.create_task( + pipe_output(proc.stderr, self.outsock, 'stderr', + log_out.fileno())), + ] + retcode = await proc.wait() + await asyncio.gather(*pipe_tasks) + return retcode + except Exception: + log.exception('unexpected error') + return -1 + finally: + self.subproc = None + + async def shutdown(self): + pass + + async def _shutdown_service(self, service_name: str): + try: + async with self._service_lock: + if service_name in self.services_running: + await terminate_and_wait(self.services_running[service_name]) + self.services_running.pop(service_name, None) + except Exception: + log.exception('unexpected error (shutdown_service)') + + async def handle_user_input(self, reader, writer): + try: + if self.user_input_queue is None: + writer.write(b'') + else: + await self.outsock.send_multipart([b'waiting-input', b'']) + text = await self.user_input_queue.get() + writer.write(text.encode('utf8')) + await writer.drain() + writer.close() + except Exception: + log.exception('unexpected error (handle_user_input)') + + async def run_tasks(self): + while True: + try: + coro = await self.task_queue.get() + + if (self._build_success is not None and + coro.func == self._execute and + not self._build_success): + 
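+ # Items on task_queue are functools.partial objects, so `.func`
+ # tells which handler was queued (here: an execute request that
+ # follows a failed build).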
self._build_success = None + # skip exec step with "command not found" exit code + payload = json.dumps({ + 'exitCode': 127, + }).encode('utf8') + await self.outsock.send_multipart([b'finished', payload]) + self.task_queue.task_done() + continue + + await coro() + self.task_queue.task_done() + except asyncio.CancelledError: + break + + async def _handle_logs(self): + log_queue = self.log_queue.async_q + try: + while True: + rec = await log_queue.get() + await self.outsock.send_multipart(rec) + log_queue.task_done() + except asyncio.CancelledError: + self.log_queue.close() + await self.log_queue.wait_closed() + + async def _get_apps(self, service_name): + result = {'status': 'done', 'data': []} + if self.service_parser is not None: + if service_name: + apps = await self.service_parser.get_apps(selected_service=service_name) + else: + apps = await self.service_parser.get_apps() + result['data'] = apps + await self.outsock.send_multipart([ + b'apps-result', + json.dumps(result).encode('utf8'), + ]) + + async def main_loop(self, cmdargs): + user_input_server = \ + await asyncio.start_server(self.handle_user_input, + '127.0.0.1', 65000) + await self._init_with_loop() + await self._init_jupyter_kernel() + + user_bootstrap_path = Path('/home/work/bootstrap.sh') + if user_bootstrap_path.is_file(): + await self._bootstrap(user_bootstrap_path) + + log.debug('starting intrinsic services: sshd, ttyd ...') + intrinsic_spawn_coros = [] + intrinsic_spawn_coros.append(self._start_service({ + 'name': 'sshd', + 'port': 2200, + 'protocol': 'tcp', + }, user_requested=False)) + intrinsic_spawn_coros.append(self._start_service({ + 'name': 'ttyd', + 'port': 7681, + 'protocol': 'http', + }, user_requested=False)) + results = await asyncio.gather(*intrinsic_spawn_coros, return_exceptions=True) + for result in results: + if isinstance(result, Exception): + log.exception( + 'error during starting intrinsic services', + exc_info=result, + ) + + log.debug('start serving...') + while True: + try: + data = await self.insock.recv_multipart() + if len(data) != 2: + # maybe some garbage data + continue + op_type = data[0].decode('ascii') + text = data[1].decode('utf8') + if op_type == 'clean': + await self.task_queue.put(partial(self._clean, text)) + if op_type == 'build': # batch-mode step 1 + await self.task_queue.put(partial(self._build, text)) + elif op_type == 'exec': # batch-mode step 2 + await self.task_queue.put(partial(self._execute, text)) + elif op_type == 'code': # query-mode + await self.task_queue.put(partial(self._query, text)) + elif op_type == 'input': # interactive input + if self.user_input_queue is not None: + await self.user_input_queue.put(text) + elif op_type == 'complete': # auto-completion + data = json.loads(text) + await self._complete(data) + elif op_type == 'interrupt': + await self._interrupt() + elif op_type == 'status': + await self._send_status() + elif op_type == 'start-service': # activate a service port + data = json.loads(text) + asyncio.create_task(self._start_service(data)) + elif op_type == 'shutdown-service': # shutdown the service by its name + data = json.loads(text) + await self._shutdown_service(data) + elif op_type == 'get-apps': + await self._get_apps(text) + except asyncio.CancelledError: + break + except NotImplementedError: + log.error('Unsupported operation for this kernel: {0}', op_type) + await asyncio.sleep(0) + except Exception: + log.exception('main_loop: unexpected error') + # we need to continue anyway unless we are shutting down + continue + 
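+ # The receive loop above exits only when cancelled during shutdown;
+ # stop accepting user input and let the subclass release its resources.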
user_input_server.close() + await user_input_server.wait_closed() + await self.shutdown() diff --git a/src/ai/backend/kernel/c/__init__.py b/src/ai/backend/kernel/c/__init__.py new file mode 100644 index 0000000000..5dc5462601 --- /dev/null +++ b/src/ai/backend/kernel/c/__init__.py @@ -0,0 +1,90 @@ +import asyncio +import logging +import os +from pathlib import Path +import tempfile + +from .. import BaseRunner + +log = logging.getLogger() + +DEFAULT_CFLAGS = ['-Wall'] +DEFAULT_LDFLAGS = ['-lrt', '-lm', '-lpthread', '-ldl'] + + +class Runner(BaseRunner): + + log_prefix = 'c-kernel' + default_runtime_path = '/usr/bin/g++' + default_child_env = { + 'TERM': 'xterm', + 'LANG': 'C.UTF-8', + 'SHELL': '/bin/ash', + 'USER': 'work', + 'HOME': '/home/work', + 'PATH': '/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin', + 'LD_LIBRARY_PATH': os.environ.get('LD_LIBRARY_PATH', ''), + 'LD_PRELOAD': os.environ.get('LD_PRELOAD', ''), + } + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + async def init_with_loop(self): + self.user_input_queue = asyncio.Queue() + + async def clean_heuristic(self) -> int: + if Path('Makefile').is_file(): + return await self.run_subproc('make clean') + log.warning('skipping the clean phase due to missing "Makefile".') + return 0 + + async def build_heuristic(self) -> int: + if self.runtime_path is None: + raise RuntimeError('Missing runtime path') + if Path('main.c').is_file(): + cfiles_glob = list(Path('.').glob('**/*.c')) + ofiles_glob = [Path(p.stem + '.o') for p in sorted(cfiles_glob)] + for cf in cfiles_glob: + cmd = [str(self.runtime_path), '-c', str(cf), *DEFAULT_CFLAGS] + ret = await self.run_subproc(cmd) + if ret != 0: # stop if gcc has failed + return ret + cmd = [str(self.runtime_path), *map(str, ofiles_glob), + *DEFAULT_CFLAGS, '-o', './main'] + return await self.run_subproc(cmd) + else: + log.error('cannot find build script ("Makefile") ' + 'or the main file ("main.c").') + return 127 + + async def execute_heuristic(self) -> int: + if Path('./main').is_file(): + return await self.run_subproc('./main') + elif Path('./a.out').is_file(): + return await self.run_subproc('./a.out') + else: + log.error('cannot find executable ("a.out" or "main").') + return 127 + + async def query(self, code_text) -> int: + with tempfile.NamedTemporaryFile(suffix='.c', dir='.') as tmpf: + tmpf.write(code_text.encode('utf8')) + tmpf.flush() + cmd = [str(self.runtime_path), tmpf.name, + *DEFAULT_CFLAGS, '-o', './main', *DEFAULT_LDFLAGS] + ret = await self.run_subproc(cmd) + if ret != 0: + return ret + cmd = ['./main'] + return await self.run_subproc(cmd) + + async def complete(self, data): + return [] + + async def interrupt(self): + # subproc interrupt is already handled by BaseRunner + pass + + async def start_service(self, service_info): + return None, {} diff --git a/src/ai/backend/kernel/compat.py b/src/ai/backend/kernel/compat.py new file mode 100644 index 0000000000..7482cb1624 --- /dev/null +++ b/src/ai/backend/kernel/compat.py @@ -0,0 +1,103 @@ +import asyncio +import signal +from typing import ( + Awaitable, Callable, Optional, + Collection, + TypeVar, +) + +__all__ = ( + 'current_loop', +) + + +current_loop: Callable[[], asyncio.AbstractEventLoop] +if hasattr(asyncio, 'get_running_loop'): + current_loop = asyncio.get_running_loop # type: ignore +else: + current_loop = asyncio.get_event_loop # type: ignore + + +all_tasks: Callable[[], Collection[asyncio.Task]] +if hasattr(asyncio, 'all_tasks'): + all_tasks = asyncio.all_tasks # type: 
ignore +else: + all_tasks = asyncio.Task.all_tasks # type: ignore + + +def _cancel_all_tasks(loop): + to_cancel = all_tasks(loop) + if not to_cancel: + return + for task in to_cancel: + task.cancel() + loop.run_until_complete( + asyncio.gather(*to_cancel, return_exceptions=True)) + for task in to_cancel: + if task.cancelled(): + continue + if task.exception() is not None: + loop.call_exception_handler({ + 'message': 'unhandled exception during asyncio.run() shutdown', + 'exception': task.exception(), + 'task': task, + }) + + +def _asyncio_run(coro, *, debug=False): + loop = asyncio.new_event_loop() + try: + asyncio.set_event_loop(loop) + loop.set_debug(debug) + return loop.run_until_complete(coro) + finally: + try: + _cancel_all_tasks(loop) + if hasattr(loop, 'shutdown_asyncgens'): # Python 3.6+ + loop.run_until_complete(loop.shutdown_asyncgens()) + finally: + asyncio.set_event_loop(None) + loop.close() + + +_T = TypeVar('_T') + +run: Callable[[Awaitable[_T], Optional[bool]], _T] +if hasattr(asyncio, 'run'): + asyncio_run = asyncio.run # type: ignore +else: + asyncio_run = _asyncio_run # type: ignore + + +def asyncio_run_forever(setup_coro, shutdown_coro, *, + stop_signals={signal.SIGINT}, debug=False): + ''' + A proposed-but-not-implemented asyncio.run_forever() API based on + @vxgmichel's idea. + See discussions on https://github.com/python/asyncio/pull/465 + ''' + async def wait_for_stop(): + loop = current_loop() + future = loop.create_future() + for stop_sig in stop_signals: + loop.add_signal_handler(stop_sig, future.set_result, stop_sig) + try: + recv_sig = await future + finally: + loop.remove_signal_handler(recv_sig) + + loop = asyncio.new_event_loop() + try: + asyncio.set_event_loop(loop) + loop.set_debug(debug) + loop.run_until_complete(setup_coro) + loop.run_until_complete(wait_for_stop()) + finally: + try: + loop.run_until_complete(shutdown_coro) + _cancel_all_tasks(loop) + if hasattr(loop, 'shutdown_asyncgens'): # Python 3.6+ + loop.run_until_complete(loop.shutdown_asyncgens()) + finally: + asyncio.set_event_loop(None) + loop.close() diff --git a/src/ai/backend/kernel/cpp/__init__.py b/src/ai/backend/kernel/cpp/__init__.py new file mode 100644 index 0000000000..417c4f32cd --- /dev/null +++ b/src/ai/backend/kernel/cpp/__init__.py @@ -0,0 +1,87 @@ +import logging +import os +from pathlib import Path +import tempfile + +from .. 
import BaseRunner + +log = logging.getLogger() + +DEFAULT_CFLAGS = ['-Wall'] +DEFAULT_LDFLAGS = ['-lrt', '-lm', '-lpthread', '-ldl'] + + +class Runner(BaseRunner): + + log_prefix = 'cpp-kernel' + default_runtime_path = '/usr/bin/g++' + default_child_env = { + 'TERM': 'xterm', + 'LANG': 'C.UTF-8', + 'SHELL': '/bin/ash', + 'USER': 'work', + 'HOME': '/home/work', + 'PATH': '/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin', + 'LD_LIBRARY_PATH': os.environ.get('LD_LIBRARY_PATH', ''), + 'LD_PRELOAD': os.environ.get('LD_PRELOAD', ''), + } + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + async def init_with_loop(self): + pass + + async def clean_heuristic(self) -> int: + if Path('Makefile').is_file(): + return await self.run_subproc('make clean') + log.warning('skipping the clean phase due to missing "Makefile".') + return 0 + + async def build_heuristic(self) -> int: + if Path('main.cpp').is_file(): + cppfiles_glob = list(Path('.').glob('**/*.cpp')) + ofiles_glob = [Path(p.stem + '.o') for p in sorted(cppfiles_glob)] + for cppf in cppfiles_glob: + cmd = [str(self.runtime_path), '-c', str(cppf), *DEFAULT_CFLAGS] + ret = await self.run_subproc(cmd) + if ret != 0: # stop if g++ has failed + return ret + cmd = [str(self.runtime_path), *map(str, ofiles_glob), + *DEFAULT_CFLAGS, '-o', './main'] + return await self.run_subproc(cmd) + else: + log.error('cannot find build script ("Makefile") ' + 'or the main file ("main.cpp").') + return 127 + + async def execute_heuristic(self) -> int: + if Path('./main').is_file(): + return await self.run_subproc('./main') + elif Path('./a.out').is_file(): + return await self.run_subproc('./a.out') + else: + log.error('cannot find executable ("a.out" or "main").') + return 127 + + async def query(self, code_text) -> int: + with tempfile.NamedTemporaryFile(suffix='.cpp', dir='.') as tmpf: + tmpf.write(code_text.encode('utf8')) + tmpf.flush() + cmd = [str(self.runtime_path), tmpf.name, + *DEFAULT_CFLAGS, '-o', './main', *DEFAULT_LDFLAGS] + ret = await self.run_subproc(cmd) + if ret != 0: + return ret + cmd = ['./main'] + return await self.run_subproc(cmd) + + async def complete(self, data): + return [] + + async def interrupt(self): + # subproc interrupt is already handled by BaseRunner + pass + + async def start_service(self, service_info): + return None, {} diff --git a/src/ai/backend/kernel/exception.py b/src/ai/backend/kernel/exception.py new file mode 100644 index 0000000000..6af65f4b4f --- /dev/null +++ b/src/ai/backend/kernel/exception.py @@ -0,0 +1,16 @@ +class MessageError(ValueError): + def __init__(self, message): + super().__init__(message) + self.message = message + + +class DisallowedEnvironment(MessageError): + pass + + +class DisallowedArgument(MessageError): + pass + + +class InvalidServiceDefinition(MessageError): + pass diff --git a/src/ai/backend/kernel/git/__init__.py b/src/ai/backend/kernel/git/__init__.py new file mode 100644 index 0000000000..12344d64a7 --- /dev/null +++ b/src/ai/backend/kernel/git/__init__.py @@ -0,0 +1,135 @@ +import asyncio +import json +import logging +import os +from pathlib import Path +import re +import subprocess + +# import pygit2 +# from pygit2 import GIT_SORT_TOPOLOGICAL, GIT_SORT_REVERSE + +from .. 
import BaseRunner, Terminal + +log = logging.getLogger() + +CHILD_ENV = { + 'TERM': 'xterm', + 'LANG': 'C.UTF-8', + 'SHELL': '/bin/bash', + 'USER': 'work', + 'HOME': '/home/work', + 'PATH': '/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin', + 'LD_PRELOAD': os.environ.get('LD_PRELOAD', '/home/backend.ai/libbaihook.so'), +} + + +class Runner(BaseRunner): + + log_prefix = 'shell-kernel' + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.child_env.update(CHILD_ENV) + + async def init_with_loop(self): + self.user_input_queue = asyncio.Queue() + self.term = Terminal( + '/bin/bash', + self.stopped, self.outsock, + auto_restart=True, + ) + + parser_show = self.term.subparsers.add_parser('show') + parser_show.add_argument('target', choices=('graph',), default='graph') + parser_show.add_argument('path', type=str) + parser_show.set_defaults(func=self.do_show) + + await self.term.start() + + async def build_heuristic(self) -> int: + raise NotImplementedError + + async def execute_heuristic(self) -> int: + raise NotImplementedError + + async def query(self, code_text) -> int: + return await self.term.handle_command(code_text) + + async def complete(self, data): + return [] + + async def interrupt(self): + # subproc interrupt is already handled by BaseRunner + pass + + async def shutdown(self): + await self.term.shutdown() + + def do_show(self, args): + if args.target == 'graph': + commit_branch_table = {} + commit_info = [] + + if args.path in ['.', None]: + current_dir = Path(f'/proc/{self.term.pid}/cwd').resolve() + else: + current_dir = Path(args.path).resolve() + os.chdir(current_dir) + + # Create commit-branch matching table. + tree_cmd = ['git', 'log', '--pretty=oneline', '--graph', + '--source', '--branches'] + run_res = subprocess.run(tree_cmd, stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + stdout = run_res.stdout.decode('utf-8') + stderr = run_res.stderr.decode('utf-8') + prog = re.compile(r'([a-z0-9]+)\s+(\S+).*') + if stderr: + self.outsock.send_multipart([b'stderr', stderr.encode('utf-8')]) + return + + for line in stdout.split('\n'): + r = prog.search(line) + if r and hasattr(r, 'group') and r.group(1) and r.group(2): + oid = r.group(1)[:7] # short oid + branch = r.group(2) + commit_branch_table[oid] = branch + + # Gather commit info w/ branch name. 
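+ # Each record below is formatted as
+ # '<short-hash>||<parent-hashes>||<subject>||<committer-name>'.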
+ log_cmd = ['git', 'log', '--pretty=format:%h||%p||%s||%cn', + '--all', '--topo-order', '--reverse'] + run_res = subprocess.run(log_cmd, stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + stdout = run_res.stdout.decode('utf-8') + for gitlog in stdout.split('\n'): + items = gitlog.split('||') + oid = items[0] + parent_ids = items[1].split(' ') + message = items[2] + author = items[3] + branch = commit_branch_table.get(oid, None) + parent_branches = [commit_branch_table.get(pid, None) + for pid in parent_ids] + info = dict( + oid=oid, + parent_ids=parent_ids, + author=author, + message=message, + branch=branch, + parent_branches=parent_branches, + ) + commit_info.append(info) + + self.outsock.send_multipart([ + b'media', + json.dumps({ + 'type': 'application/vnd.sorna.gitgraph', + 'data': commit_info, + }).encode('utf-8'), + ]) + else: + raise ValueError('Unsupported show target', args.target) + + async def start_service(self, service_info): + return None, {} diff --git a/src/ai/backend/kernel/golang/__init__.py b/src/ai/backend/kernel/golang/__init__.py new file mode 100644 index 0000000000..d284181b45 --- /dev/null +++ b/src/ai/backend/kernel/golang/__init__.py @@ -0,0 +1,72 @@ +import logging +import os +from pathlib import Path +import tempfile +from typing import List + +from .. import BaseRunner + +log = logging.getLogger() + +DEFAULT_BFLAGS: List[str] = [''] + + +class Runner(BaseRunner): + + log_prefix = 'go-kernel' + default_runtime_path = '/usr/local/bin/go' + default_child_env = { + 'TERM': 'xterm', + 'LANG': 'C.UTF-8', + 'SHELL': '/bin/ash', + 'USER': 'work', + 'HOME': '/home/work', + 'PATH': '/home/work/bin:/go/bin:/usr/local/go/bin:' + + '/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin', + 'GOPATH': '/home/work', + 'LD_LIBRARY_PATH': os.environ.get('LD_LIBRARY_PATH', ''), + 'LD_PRELOAD': os.environ.get('LD_PRELOAD', ''), + } + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + async def init_with_loop(self): + pass + + async def build_heuristic(self) -> int: + if Path('main.go').is_file(): + go_glob = Path('.').glob('**/*.go') + cmd = [ + str(self.runtime_path), + 'build', '-o', 'main', + *DEFAULT_BFLAGS, *map(str, go_glob), + ] + return await self.run_subproc(cmd) + else: + log.error('cannot find main file ("main.go").') + return 127 + + async def execute_heuristic(self) -> int: + if Path('./main').is_file(): + return await self.run_subproc('./main') + else: + log.error('cannot find executable ("main").') + return 127 + + async def query(self, code_text) -> int: + with tempfile.NamedTemporaryFile(suffix='.go', dir='.') as tmpf: + tmpf.write(code_text.encode('utf8')) + tmpf.flush() + cmd = [str(self.runtime_path), 'run', tmpf.name] + return await self.run_subproc(cmd) + + async def complete(self, data): + return [] + + async def interrupt(self): + # subproc interrupt is already handled by BaseRunner + pass + + async def start_service(self, service_info): + return None, {} diff --git a/src/ai/backend/kernel/haskell/__init__.py b/src/ai/backend/kernel/haskell/__init__.py new file mode 100644 index 0000000000..9baa97df7a --- /dev/null +++ b/src/ai/backend/kernel/haskell/__init__.py @@ -0,0 +1,64 @@ +import logging +import os +from pathlib import Path +import shlex +import tempfile + +from .. 
import BaseRunner + +log = logging.getLogger() + + +class Runner(BaseRunner): + + log_prefix = 'haskell-kernel' + default_runtime_path = '/usr/bin/ghc' + default_child_env = { + 'TERM': 'xterm', + 'LANG': 'C.UTF-8', + 'SHELL': '/bin/bash', + 'USER': 'work', + 'HOME': '/home/work', + 'PATH': ('/root/.cabal/bin:/root/.local/bin:/opt/cabal/2.0/bin:' + '/opt/ghc/8.2.1/bin:/opt/happy/1.19.5/bin:/opt/alex/3.1.7/bin:' + '/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin'), + 'LD_LIBRARY_PATH': os.environ.get('LD_LIBRARY_PATH', ''), + 'LD_PRELOAD': os.environ.get('LD_PRELOAD', ''), + } + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + async def init_with_loop(self): + pass + + async def build_heuristic(self) -> int: + # GHC will generate error if no Main module exist among srcfiles. + src_glob = Path('.').glob('**/*.hs') + src_files = ' '.join(map(lambda p: shlex.quote(str(p)), src_glob)) + cmd = f'ghc --make main {src_files}' + return await self.run_subproc(cmd) + + async def execute_heuristic(self) -> int: + if Path('./main').is_file(): + return await self.run_subproc('./main') + else: + log.error('cannot find executable ("main").') + return 127 + + async def query(self, code_text) -> int: + with tempfile.NamedTemporaryFile(suffix='.hs', dir='.') as tmpf: + tmpf.write(code_text.encode('utf8')) + tmpf.flush() + cmd = f'runhaskell {tmpf.name}' + return await self.run_subproc(cmd) + + async def complete(self, data): + return [] + + async def interrupt(self): + # subproc interrupt is already handled by BaseRunner + pass + + async def start_service(self, service_info): + return None, {} diff --git a/src/ai/backend/kernel/intrinsic.py b/src/ai/backend/kernel/intrinsic.py new file mode 100644 index 0000000000..d1ae5e3e23 --- /dev/null +++ b/src/ai/backend/kernel/intrinsic.py @@ -0,0 +1,144 @@ +import asyncio +from collections.abc import Iterable +import logging +import os +from pathlib import Path + +from .logging import BraceStyleAdapter + +log = BraceStyleAdapter(logging.getLogger()) + + +async def init_sshd_service(child_env): + Path('/tmp/dropbear').mkdir(parents=True, exist_ok=True) + auth_path = Path('/home/work/.ssh/authorized_keys') + if not auth_path.is_file(): + auth_path.parent.mkdir(parents=True, exist_ok=True) + auth_path.parent.chmod(0o700) + proc = await asyncio.create_subprocess_exec( + *[ + '/opt/kernel/dropbearkey', + '-t', 'rsa', + '-s', '2048', + '-f', '/tmp/dropbear/id_dropbear', + ], + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env=child_env) + stdout, stderr = await proc.communicate() + if proc.returncode != 0: + raise RuntimeError(f"sshd init error: {stderr.decode('utf8')}") + pub_key = stdout.splitlines()[1] + auth_path.write_bytes(pub_key) + auth_path.chmod(0o600) + + # Make the generated private key downloadable by users. 
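+ # dropbearconvert exports the key in OpenSSH format to
+ # /home/work/id_container inside the user's home directory.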
+ proc = await asyncio.create_subprocess_exec( + *[ + '/opt/kernel/dropbearconvert', + 'dropbear', 'openssh', + '/tmp/dropbear/id_dropbear', '/home/work/id_container', + ], + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env=child_env) + stdout, stderr = await proc.communicate() + if proc.returncode != 0: + raise RuntimeError(f"sshd init error: {stderr.decode('utf8')}") + else: + try: + if (auth_path.parent.stat().st_mode & 0o077) != 0: + auth_path.parent.chmod(0o700) + if (auth_path.stat().st_mode & 0o077) != 0: + auth_path.chmod(0o600) + except IOError: + log.warning('could not set the permission for /home/work/.ssh') + proc = await asyncio.create_subprocess_exec( + *[ + '/opt/kernel/dropbearkey', + '-t', 'rsa', + '-s', '2048', + '-f', '/tmp/dropbear/dropbear_rsa_host_key', + ], + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env=child_env) + stdout, stderr = await proc.communicate() + if proc.returncode != 0: + raise RuntimeError(f"sshd init error: {stderr.decode('utf8')}") + + cluster_privkey_src_path = Path('/home/config/ssh/id_cluster') + user_ssh_config_path = Path('/home/work/.ssh/config') + if cluster_privkey_src_path.is_file(): + replicas = { + k: v for k, v in + map(lambda pair: pair.rsplit(':', maxsplit=1), + os.environ.get('BACKENDAI_CLUSTER_REPLICAS', 'main:1').split(',')) + } + for role_name, role_replica in replicas.items(): + try: + existing_ssh_config = user_ssh_config_path.read_text() + if '\nHost {role_name}*\n' in existing_ssh_config: + continue + except FileNotFoundError: + pass + with open(user_ssh_config_path, 'a') as f: + f.write(f"\nHost {role_name}*\n") + f.write("\tPort 2200\n") + f.write("\tStrictHostKeyChecking no\n") + f.write("\tIdentityFile /home/config/ssh/id_cluster\n") + cluster_pubkey_src_path = Path('/home/config/ssh/id_cluster.pub') + if cluster_pubkey_src_path.is_file(): + pubkey = cluster_pubkey_src_path.read_bytes() + with open(auth_path, 'ab') as f: + f.write(b'\n') + f.write(pubkey) + f.write(b'\n') + + +async def prepare_sshd_service(service_info): + cmdargs = [ + '/opt/kernel/dropbear', + '-r', '/tmp/dropbear/dropbear_rsa_host_key', + '-E', # show logs in stderr + '-F', # run in foreground + '-s', # disable password logins + # '-W', str(256 * 1024), # recv buffer size (256 KiB) -> built-in during compilation + '-K', '15', # keepalive interval + '-I', '0', # idle timeout + ] + port_config = service_info['port'] + if isinstance(port_config, Iterable): + for port in port_config: + cmdargs.extend(['-p', f"0.0.0.0:{port}"]) + else: + cmdargs.extend(['-p', f"0.0.0.0:{port_config}"]) + env = {} + return cmdargs, env + + +async def prepare_ttyd_service(service_info): + shell = 'sh' + if Path('/bin/bash').exists(): + shell = 'bash' + elif Path('/bin/ash').exists(): + shell = 'ash' + + cmdargs = ['/opt/backend.ai/bin/ttyd', f'/bin/{shell}'] + if shell != 'ash': # Currently Alpine-based containers are not supported. + cmdargs += ['-c', + '/opt/kernel/tmux -2 attach'] + return cmdargs, {} + + +async def prepare_vscode_service(service_info): + # NOTE: This will be replaced as intrinsic binary: /opt/kernel/vscode/... 
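+ # For now this assumes a code-server binary installed at
+ # /usr/local/bin/code-server inside the kernel image.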
+ extension_dir = Path('/home/work/.vscode-exts') + extension_dir.mkdir(parents=True, exist_ok=True) + return [ + '/usr/local/bin/code-server', + '--auth', 'none', + '--bind-addr', '0.0.0.0', + '--port', str(service_info['port']), + '--extensions-dir', str(extension_dir), + ], {'PWD': '/home/work'} diff --git a/src/ai/backend/kernel/java/LablupPatches.java b/src/ai/backend/kernel/java/LablupPatches.java new file mode 100644 index 0000000000..10658f3d25 --- /dev/null +++ b/src/ai/backend/kernel/java/LablupPatches.java @@ -0,0 +1,57 @@ +class BackendInputStream extends InputStream { + private StringReader currentReader; + static Boolean waitUserInput = true; + + @Override + public int read() throws IOException { + if (waitUserInput) { + readFromInputServer(); + waitUserInput = false; + } + int character = currentReader.read(); + if (character == -1) waitUserInput = true; + return character; + } + + private void readFromInputServer() throws IOException { + String scriptPath = "/tmp/lablup_input_stream.py"; + File f = new File(scriptPath); + if (!f.exists()) { + String s = "import socket\n\n" + + "host = '127.0.0.1'\n" + + "port = 65000\n\n" + + "with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sk:\n" + + " try:\n" + + " sk.connect((host, port))\n" + + " userdata = sk.recv(1024)\n" + + " except ConnectRefusedError:\n" + + " userdata = b''\n" + + "print(userdata.decode())"; + try (PrintStream out = new PrintStream( + new FileOutputStream(scriptPath))) { + out.print(s); + } + } + String command = "python " + scriptPath; + String output = executeCommand(command); + currentReader = new StringReader(output); + } + + private String executeCommand(String command) { + StringBuffer output = new StringBuffer(); + Process p; + try { + p = Runtime.getRuntime().exec(command); + p.waitFor(); + BufferedReader reader = new BufferedReader( + new InputStreamReader(p.getInputStream())); + String line = ""; + while ((line = reader.readLine())!= null) { + output.append(line + "\n"); + } + } catch (Exception e) { + e.printStackTrace(); + } + return output.toString(); + } +} diff --git a/src/ai/backend/kernel/java/__init__.py b/src/ai/backend/kernel/java/__init__.py new file mode 100644 index 0000000000..1ad6441da1 --- /dev/null +++ b/src/ai/backend/kernel/java/__init__.py @@ -0,0 +1,110 @@ +import asyncio +import logging +import os +import re +from pathlib import Path +import shlex +import tempfile + +from .. import BaseRunner + +log = logging.getLogger() + +JCC = 'javac' +JCR = 'java' + +# Let Java respect container resource limits +DEFAULT_JFLAGS = ['-J-XX:+UnlockExperimentalVMOptions', + '-J-XX:+UseCGroupMemoryLimitForHeap', '-d', '.'] + + +class Runner(BaseRunner): + + log_prefix = 'java-kernel' + default_runtime_path = '/usr/lib/jvm/java-1.8-openjdk/bin/java' + default_child_env = { + 'TERM': 'xterm', + 'LANG': 'C.UTF-8', + 'SHELL': '/bin/ash', + 'USER': 'work', + 'HOME': '/home/work', + 'PATH': ('/usr/lib/jvm/java-1.8-openjdk/jre/bin:' + '/usr/lib/jvm/java-1.8-openjdk/bin:/usr/local/sbin:' + '/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin'), + 'LD_LIBRARY_PATH': os.environ.get('LD_LIBRARY_PATH', ''), + 'LD_PRELOAD': os.environ.get('LD_PRELOAD', ''), + } + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _code_for_user_input_server(self, code: str) -> str: + # TODO: More elegant way of not touching user code? This method does not work + # for batch exec (no way of knowing the main file). + # Way of monkey patching System.in? 
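+ # Current approach: prepend the java.io imports, inject a static
+ # initializer that swaps System.in for BackendInputStream, and append
+ # the LablupPatches helper class to the user code.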
+ modules = 'import java.io.*;' + static_initializer = (r'\1static{BackendInputStream stream = ' + r'new BackendInputStream();System.setIn(stream);}') + patch = Path(os.path.dirname(__file__)) / 'LablupPatches.java' + altered = re.sub(r'(public[\s]+class[\s]+[\w]+[\s]*{)', static_initializer, + code) + altered = modules + altered + altered = altered + '\n\n' + patch.read_text() + return altered + + async def init_with_loop(self): + self.user_input_queue = asyncio.Queue() + + async def build_heuristic(self) -> int: + if Path('Main.java').is_file(): + java_sources = Path('.').glob('**/*.java') + java_source_list = ' '.join(map(lambda p: shlex.quote(str(p)), java_sources)) + cmd = [JCC, *DEFAULT_JFLAGS, java_source_list] + return await self.run_subproc(cmd) + else: + java_sources = Path('.').glob('**/*.java') + java_source_list = ' '.join(map(lambda p: shlex.quote(str(p)), java_sources)) + cmd = [JCC, *DEFAULT_JFLAGS, java_source_list] + return await self.run_subproc(cmd) + + async def execute_heuristic(self) -> int: + if Path('./main/Main.class').is_file(): + return await self.run_subproc([JCR, 'main.Main']) + elif Path('./Main.class').is_file(): + return await self.run_subproc([JCR, 'Main']) + else: + log.error('cannot find entry class (main.Main).') + return 127 + + async def query(self, code_text) -> int: + # Try to get the name of the first public class using a simple regular + # expression and use it as the name of the main source/class file. + # (In Java, the main function must reside in a public class as a public + # static void method where the filename must be same to the class name) + # + # NOTE: This approach won't perfectly handle all edge cases! + with tempfile.TemporaryDirectory() as tmpdir: + m = re.search(r'public[\s]+class[\s]+([\w]+)[\s]*{', code_text) + if m: + mainpath = Path(tmpdir) / (m.group(1) + '.java') + else: + # TODO: wrap the code using a class skeleton?? + mainpath = Path(tmpdir) / 'main.java' + code = self._code_for_user_input_server(code_text) + with open(mainpath, 'w', encoding='utf-8') as tmpf: + tmpf.write(code) + ret = await self.run_subproc([JCC, str(mainpath)]) + if ret != 0: + return ret + cmd = [JCR, '-classpath', tmpdir, mainpath.stem] + return await self.run_subproc(cmd) + + async def complete(self, data): + return [] + + async def interrupt(self): + # subproc interrupt is already handled by BaseRunner + pass + + async def start_service(self, service_info): + return None, {} diff --git a/src/ai/backend/kernel/julia/__init__.py b/src/ai/backend/kernel/julia/__init__.py new file mode 100644 index 0000000000..fefad508ab --- /dev/null +++ b/src/ai/backend/kernel/julia/__init__.py @@ -0,0 +1,75 @@ +import logging +import os +from pathlib import Path +import tempfile + +import janus + +from .. 
import BaseRunner + +log = logging.getLogger() + + +class Runner(BaseRunner): + + log_prefix = 'julia-kernel' + default_runtime_path = '/usr/local/julia' + default_child_env = { + 'TERM': 'xterm', + 'LANG': 'C.UTF-8', + 'SHELL': '/bin/ash' if Path('/bin/ash').is_file() else '/bin/bash', + 'USER': 'work', + 'HOME': '/home/work', + 'PATH': ('/usr/local/julia:/usr/local/julia/bin:/usr/local/sbin:' + '/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin'), + 'LD_LIBRARY_PATH': os.environ.get('LD_LIBRARY_PATH', ''), + 'JULIA_LOAD_PATH': ':/opt/julia', + } + jupyter_kspec_name = 'julia' + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.input_queue = None + self.output_queue = None + + async def init_with_loop(self): + self.input_queue = janus.Queue() + self.output_queue = janus.Queue() + + # We have interactive input functionality! + self._user_input_queue = janus.Queue() + self.user_input_queue = self._user_input_queue.async_q + + # Preparation to initialize ijulia kernel. + cmd = '/usr/local/bin/movecompiled.sh' + await self.run_subproc(cmd) + + async def build_heuristic(self) -> int: + log.info('no build process for julia language') + return 0 + + async def execute_heuristic(self) -> int: + if Path('main.jl').is_file(): + cmd = 'julia main.jl' + return await self.run_subproc(cmd) + else: + log.error('cannot find executable ("main.jl").') + return 127 + + async def start_service(self, service_info): + if service_info['name'] == 'jupyter' or service_info['name'] == 'jupyterlab': + with tempfile.NamedTemporaryFile( + 'w', encoding='utf-8', suffix='.py', delete=False) as config: + print('c.NotebookApp.allow_root = True', file=config) + print('c.NotebookApp.ip = "0.0.0.0"', file=config) + print('c.NotebookApp.port = {}'.format(service_info['port']), file=config) + print('c.NotebookApp.token = ""', file=config) + print('c.FileContentsManager.delete_to_trash = False', file=config) + jupyter_service_type = service_info['name'] + if jupyter_service_type == 'jupyter': + jupyter_service_type = 'notebook' + return [ + self.runtime_path, '-m', 'jupyter', jupyter_service_type, + '--no-browser', + '--config', config.name, + ], {} diff --git a/src/ai/backend/kernel/jupyter_client.py b/src/ai/backend/kernel/jupyter_client.py new file mode 100644 index 0000000000..40786e4968 --- /dev/null +++ b/src/ai/backend/kernel/jupyter_client.py @@ -0,0 +1,63 @@ +import asyncio +from time import monotonic +import zmq + + +async def aexecute_interactive(kernel_client, code, silent=False, store_history=True, + user_expressions=None, allow_stdin=None, + stop_on_error=True, timeout=None, + output_hook=None, stdin_hook=None): + """Async version of jupyter_client's execute_interactive method. 
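+ Unlike the blocking original, this polls the iopub/stdin sockets with
+ zmq.asyncio and awaits the output hook (the stdin hook is scheduled as
+ a separate task).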
+ + https://github.com/jupyter/jupyter_client/blob/master/jupyter_client/blocking/client.py#L213 + """ + msg_id = kernel_client.execute(code, silent=silent, store_history=store_history, + user_expressions=user_expressions, + allow_stdin=allow_stdin, + stop_on_error=stop_on_error) + + stdin_hook = stdin_hook if stdin_hook else kernel_client._stdin_hook_default + output_hook = output_hook if output_hook else kernel_client._output_hook_default + + # Set deadline based on timeout + if timeout is not None: + deadline = monotonic() + timeout + else: + timeout_ms = None + + poller = zmq.asyncio.Poller() + iopub_socket = kernel_client.iopub_channel.socket + poller.register(iopub_socket, zmq.POLLIN) + if allow_stdin: + stdin_socket = kernel_client.stdin_channel.socket + poller.register(stdin_socket, zmq.POLLIN) + else: + stdin_socket = None + + # Wait for zmq events and handle them + while True: + if timeout is not None: + timeout = max(0, deadline - monotonic()) + timeout_ms = 1e3 * timeout + events = dict(await poller.poll(timeout_ms)) + if not events: + raise TimeoutError("Timeout waiting for output") + if iopub_socket in events: + msg = kernel_client.iopub_channel.get_msg(timeout=0) + if msg['parent_header'].get('msg_id') != msg_id: + continue # not from my request + await output_hook(msg) + + # Stop on idle + if msg['header']['msg_type'] == 'status' and \ + msg['content']['execution_state'] == 'idle': + break + if stdin_socket in events: + req = kernel_client.stdin_channel.get_msg(timeout=0) + loop = asyncio.get_event_loop() + loop.create_task(stdin_hook(req)) + + # Output is done, get the reply + if timeout is not None: + timeout = max(0, deadline - monotonic()) + return kernel_client._recv_reply(msg_id, timeout=timeout) diff --git a/src/ai/backend/kernel/logging.py b/src/ai/backend/kernel/logging.py new file mode 100644 index 0000000000..362edac195 --- /dev/null +++ b/src/ai/backend/kernel/logging.py @@ -0,0 +1,52 @@ +from contextlib import closing +from io import StringIO +import logging +from logging.handlers import QueueHandler + + +class LogQHandler(QueueHandler): + def enqueue(self, record): + with closing(StringIO()) as buf: + print(self.formatter.format(record), file=buf) + if record.exc_info is not None: + print(self.formatter.formatException(record.exc_info), file=buf) + self.queue.put_nowait(( + b'stderr', + buf.getvalue().encode('utf8'), + )) + + +class BraceMessage: + + __slots__ = ('fmt', 'args') + + def __init__(self, fmt, args): + self.fmt = fmt + self.args = args + + def __str__(self): + return self.fmt.format(*self.args) + + +class BraceStyleAdapter(logging.LoggerAdapter): + + def __init__(self, logger, extra=None): + super().__init__(logger, extra) + + def log(self, level, msg, *args, **kwargs): + if self.isEnabledFor(level): + msg, kwargs = self.process(msg, kwargs) + self.logger._log(level, BraceMessage(msg, args), (), **kwargs) + + +def setup_logger(log_queue, log_prefix, debug): + # configure logging to publish logs via outsock as well + loghandlers = [logging.StreamHandler()] + if not debug: + loghandlers.append(LogQHandler(log_queue)) + logging.basicConfig( + level=logging.DEBUG if debug else logging.INFO, + format=log_prefix + ': {message}', + style='{', + handlers=loghandlers, + ) diff --git a/src/ai/backend/kernel/lua/__init__.py b/src/ai/backend/kernel/lua/__init__.py new file mode 100644 index 0000000000..88b1fa3138 --- /dev/null +++ b/src/ai/backend/kernel/lua/__init__.py @@ -0,0 +1,59 @@ +import logging +import os +from pathlib import Path +import tempfile + 
+from .. import BaseRunner + +log = logging.getLogger() + + +class Runner(BaseRunner): + + log_prefix = 'lua-kernel' + default_runtime_path = '/usr/local/bin/lua' + default_child_env = { + 'TERM': 'xterm', + 'LANG': 'C.UTF-8', + 'SHELL': '/bin/ash', + 'USER': 'work', + 'HOME': '/home/work', + 'PATH': '/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin', + 'LD_LIBRARY_PATH': os.environ.get('LD_LIBRARY_PATH', ''), + 'LD_PRELOAD': os.environ.get('LD_PRELOAD', ''), + } + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + async def init_with_loop(self): + pass + + async def build_heuristic(self) -> int: + log.info('no build process for lua language') + return 0 + + async def execute_heuristic(self) -> int: + if Path('main.lua').is_file(): + cmd = [str(self.runtime_path), 'main.lua'] + return await self.run_subproc(cmd) + else: + log.error('cannot find executable ("main.lua").') + return 127 + + async def query(self, code_text) -> int: + with tempfile.NamedTemporaryFile(suffix='.lua', dir='.') as tmpf: + tmpf.write(code_text.encode('utf8')) + tmpf.flush() + cmd = [str(self.runtime_path), tmpf.name] + return await self.run_subproc(cmd) + + async def complete(self, data): + return [] + + async def interrupt(self): + # subproc interrupt is already handled by BaseRunner + pass + + async def start_service(self, service_info): + return None, {} diff --git a/src/ai/backend/kernel/nodejs/__init__.py b/src/ai/backend/kernel/nodejs/__init__.py new file mode 100644 index 0000000000..5709be958c --- /dev/null +++ b/src/ai/backend/kernel/nodejs/__init__.py @@ -0,0 +1,63 @@ +import logging +import os +from pathlib import Path +import tempfile + +from .. import BaseRunner + +log = logging.getLogger() + + +class Runner(BaseRunner): + + log_prefix = 'nodejs-kernel' + default_runtime_path = '/usr/bin/local/node' + default_child_env = { + 'TERM': 'xterm', + 'LANG': 'C.UTF-8', + 'SHELL': '/bin/ash', + 'USER': 'work', + 'HOME': '/home/work', + 'PATH': '/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin', + 'LD_LIBRARY_PATH': os.environ.get('LD_LIBRARY_PATH', ''), + 'LD_PRELOAD': os.environ.get('LD_PRELOAD', ''), + } + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + async def init_with_loop(self): + pass + + async def build_heuristic(self) -> int: + log.info('no build process for node.js language') + return 0 + + async def execute_heuristic(self) -> int: + if Path('main.js').is_file(): + cmd = [str(self.runtime_path), 'main.js'] + return await self.run_subproc(cmd) + else: + log.error('cannot find executable ("main.js").') + return 127 + + async def query(self, code_text) -> int: + with tempfile.NamedTemporaryFile(suffix='.js', dir='.') as tmpf: + tmpf.write(code_text.encode('utf8')) + tmpf.flush() + cmd = [str(self.runtime_path), tmpf.name] + return await self.run_subproc(cmd) + + async def complete(self, data): + return [] + + async def interrupt(self): + # subproc interrupt is already handled by BaseRunner + pass + + async def start_service(self, service_info): + print(service_info['name']) + if service_info['name'] == 'node': + return [ + self.runtime_path, '-m', 'node', + ], {} diff --git a/src/ai/backend/kernel/octave/__init__.py b/src/ai/backend/kernel/octave/__init__.py new file mode 100644 index 0000000000..1fd6d99eb5 --- /dev/null +++ b/src/ai/backend/kernel/octave/__init__.py @@ -0,0 +1,60 @@ +import logging +import os +from pathlib import Path +import tempfile + +from .. 
import BaseRunner + +log = logging.getLogger() + + +class Runner(BaseRunner): + + log_prefix = 'octave-kernel' + default_runtime_path = '/usr/bin/octave-cli' + default_child_env = { + 'TERM': 'xterm', + 'LANG': 'C.UTF-8', + 'SHELL': '/bin/bash', + 'USER': 'work', + 'HOME': '/home/work', + 'PATH': '/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin', + 'LD_LIBRARY_PATH': os.environ.get('LD_LIBRARY_PATH', ''), + 'LD_PRELOAD': os.environ.get('LD_PRELOAD', ''), + } + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + async def init_with_loop(self): + pass + + async def build_heuristic(self) -> int: + log.info('no build process for octave language') + return 0 + + async def execute_heuristic(self) -> int: + if Path('main.js').is_file(): + cmd = [str(self.runtime_path), 'main.m'] + return await self.run_subproc(cmd) + else: + log.error('cannot find executable ("main.m").') + return 127 + + async def query(self, code_text) -> int: + with tempfile.NamedTemporaryFile(suffix='.m', dir='.') as tmpf: + tmpf.write(code_text.encode('utf8')) + tmpf.flush() + # TODO: support graphics output to display + cmd = [str(self.runtime_path), tmpf.name] + return await self.run_subproc(cmd) + + async def complete(self, data): + return [] + + async def interrupt(self): + # subproc interrupt is already handled by BaseRunner + pass + + async def start_service(self, service_info): + return None, {} diff --git a/src/ai/backend/kernel/php/__init__.py b/src/ai/backend/kernel/php/__init__.py new file mode 100644 index 0000000000..0d9b783ac3 --- /dev/null +++ b/src/ai/backend/kernel/php/__init__.py @@ -0,0 +1,61 @@ +import logging +import os +from pathlib import Path +import tempfile + +from .. import BaseRunner + +log = logging.getLogger() + + +class Runner(BaseRunner): + + log_prefix = 'php-kernel' + default_runtime_path = '/usr/bin/php' + default_child_env = { + 'TERM': 'xterm', + 'LANG': 'C.UTF-8', + 'SHELL': '/bin/ash', + 'USER': 'work', + 'HOME': '/home/work', + 'PATH': '/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin', + 'GOPATH': '/home/work', + 'LD_LIBRARY_PATH': os.environ.get('LD_LIBRARY_PATH', ''), + 'LD_PRELOAD': os.environ.get('LD_PRELOAD', ''), + } + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + async def init_with_loop(self): + pass + + async def build_heuristic(self) -> int: + log.info('no build process for php language') + return 0 + + async def execute_heuristic(self) -> int: + if Path('main.php').is_file(): + cmd = [str(self.runtime_path), 'main.php'] + return await self.run_subproc(cmd) + else: + log.error('cannot find executable ("main.php").') + return 127 + + async def query(self, code_text) -> int: + with tempfile.NamedTemporaryFile(suffix='.php', dir='.') as tmpf: + tmpf.write(b' int: + if Path('setup.py').is_file(): + cmd = [ + str(self.runtime_path), *DEFAULT_PYFLAGS, + '-m', 'pip', 'install', '--user', '-e', '.', + ] + return await self.run_subproc(cmd) + else: + log.warning('skipping the build phase due to missing "setup.py" file') + return 0 + + async def execute_heuristic(self) -> int: + if Path('main.py').is_file(): + cmd = [ + str(self.runtime_path), *DEFAULT_PYFLAGS, + 'main.py', + ] + return await self.run_subproc(cmd, batch=True) + else: + log.error('cannot find the main script ("main.py").') + return 127 + + async def start_service(self, service_info): + if service_info['name'] in ['jupyter', 'jupyterlab']: + with tempfile.NamedTemporaryFile( + 'w', encoding='utf-8', suffix='.py', delete=False) as 
config: + print('c.NotebookApp.allow_root = True', file=config) + print('c.NotebookApp.ip = "0.0.0.0"', file=config) + print('c.NotebookApp.port = {}'.format(service_info['port']), file=config) + print('c.NotebookApp.token = ""', file=config) + print('c.FileContentsManager.delete_to_trash = False', file=config) + print('c.NotebookApp.tornado_settings = {\'ws_ping_interval\': 10000}', file=config) + jupyter_service_type = service_info['name'] + if jupyter_service_type == 'jupyter': + jupyter_service_type = 'notebook' + return [ + self.runtime_path, '-m', jupyter_service_type, + '--no-browser', + '--config', config.name, + ], {} + elif service_info['name'] == 'ipython': + return [ + self.runtime_path, '-m', 'IPython', + ], {} + elif service_info['name'] == 'digits': + return [ + self.runtime_path, '-m', 'digits', + ], {} + elif service_info['name'] == 'tensorboard': + Path('/home/work/logs').mkdir(parents=True, exist_ok=True) + return [ + self.runtime_path, '-m', 'tensorboard.main', + '--logdir', '/home/work/logs', + '--host', '0.0.0.0', + '--port', str(service_info['port']), + '--debugger_port', '6064', # used by in-container TensorFlow + ], {} + elif service_info['name'] == 'spectravis': + return [ + self.runtime_path, '-m', 'http.server', + '8000', + ], {}, '/home/work/spectravis' + elif service_info['name'] == 'sftp': + return [ + self.runtime_path, + '-m', 'sftpserver', + '--port', str(service_info['port']), + ], {} + return None, None diff --git a/src/ai/backend/kernel/python/drawing/__init__.py b/src/ai/backend/kernel/python/drawing/__init__.py new file mode 100644 index 0000000000..875ec0741f --- /dev/null +++ b/src/ai/backend/kernel/python/drawing/__init__.py @@ -0,0 +1,7 @@ +from .canvas import Canvas +from .color import Color, Colors +from .turtle import Turtle, Vec2D + +__all__ = ( + 'Canvas', 'Color', 'Colors', 'Turtle', 'Vec2D', +) diff --git a/src/ai/backend/kernel/python/drawing/canvas.py b/src/ai/backend/kernel/python/drawing/canvas.py new file mode 100644 index 0000000000..c72f109d68 --- /dev/null +++ b/src/ai/backend/kernel/python/drawing/canvas.py @@ -0,0 +1,212 @@ +from six.moves import builtins + +from ..types import MediaRecord +from .encoding import encode_commands +from .turtle import Turtle +from .color import Colors + + +_canvas_id_counter = 0 + + +class DrawingObject: + + def __init__(self, canvas, id_, args): + self._canvas = canvas + self._id = id_ + self._type = args[0] + + def set_x(self, x): + if self._type in (u'rect', u'circle', u'triangle'): + self._canvas._cmd_history.append(( + self._canvas._id, u'update', self._id, + u'x', x)) + + def set_y(self, y): + if self._type in (u'rect', u'circle', u'triangle'): + self._canvas._cmd_history.append(( + self._canvas._id, u'update', self._id, + u'y', y)) + + def set_x1(self, x): + if self._type == u'line': + self._canvas._cmd_history.append(( + self._canvas._id, u'update', self._id, + u'x1', x)) + + def set_y1(self, y): + if self._type == u'line': + self._canvas._cmd_history.append(( + self._canvas._id, u'update', self._id, + u'y1', y)) + + def set_x2(self, x): + if self._type == u'line': + self._canvas._cmd_history.append(( + self._canvas._id, u'update', self._id, + u'x2', x)) + + def set_y2(self, y): + if self._type == u'line': + self._canvas._cmd_history.append(( + self._canvas._id, u'update', self._id, + u'y2', y)) + + def set_radius(self, r): + if self._type == u'circle': + self._canvas._cmd_history.append(( + self._canvas._id, u'update', self._id, + u'radius', r)) + + def rotate(self, a): + 
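+        # Appends an 'update' command that rotates this object by the given angle
+        # to the owning canvas' command history (flushed to the client on Canvas.update()).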
self._canvas._cmd_history.append(( + self._canvas._id, u'update', self._id, + u'rotate', a)) + + def set_angle(self, a): + self._canvas._cmd_history.append(( + self._canvas._id, u'update', self._id, + u'angle', a)) + + def stroke(self, color): + color = color.to_hex() + if self._type == u'line': + self._canvas._cmd_history.append(( + self._canvas._id, u'update', self._id, + u'color', color)) + elif self._type == u'circle': + self._canvas._cmd_history.append(( + self._canvas._id, u'update', self._id, + u'border', color)) + elif self._type in (u'rect', u'triangle'): + self._canvas._cmd_history.append(( + self._canvas._id, u'update', self._id, + u'border', color)) + + def fill(self, color): + color = color.to_hex() + if self._type == u'circle': + self._canvas._cmd_history.append(( + self._canvas._id, u'update', self._id, + u'fill', color)) + elif self._type in (u'rect', u'triangle'): + self._canvas._cmd_history.append(( + self._canvas._id, u'update', self._id, + u'fill', color)) + + +class Canvas: + + def __init__(self, width, height, bgcolor=Colors.White, fgcolor=Colors.Black): + global _canvas_id_counter + self._id = _canvas_id_counter + _canvas_id_counter += 1 + self._cmd_history = [] + self._next_objid = 0 + self._cmd_history.append((self._id, u'canvas', + width, height, + bgcolor.to_hex(), + fgcolor.to_hex())) + self.width = width + self.height = height + self.bgcolor = bgcolor + self.fgcolor = fgcolor + + def update(self): + builtins._sorna_emit(MediaRecord( + u'application/x-sorna-drawing', + encode_commands(self._cmd_history), + )) + self._cmd_history = [] + + def show(self): # alias + self.update() + + def create_turtle(self): + t = Turtle(self) + return t + + def stop_animation(self): + self._cmd_history.append((self._id, u'stop-anim')) + + def resume_animation(self): + self._cmd_history.append((self._id, u'resume-anim')) + + def begin_group(self): + self._cmd_history.append((self._id, u'begin-group')) + + def end_group(self): + self._cmd_history.append((self._id, u'end-group')) + + def begin_fill(self, c): + self._cmd_history.append((self._id, u'begin-fill', c.to_hex())) + + def end_fill(self): + self._cmd_history.append((self._id, u'end-fill')) + + def background_color(self, c): + self.bgcolor = c + self._cmd_history.append((self._id, u'bgcolor', c.to_hex())) + + def stroke_color(self, c): + self.fgcolor = c + self._cmd_history.append((self._id, u'fgcolor', c.to_hex())) + + def line(self, x0, y0, x1, y1, color=None): + if color is None: + color = self.fgcolor + args = ( + u'line', x0, y0, x1, y1, + color.to_hex(), + ) + self._cmd_history.append((self._id, u'obj', self._next_objid, args)) + obj = DrawingObject(self, self._next_objid, args) + self._next_objid += 1 + return obj + + def circle(self, x, y, radius, border=None, fill=None, angle=0): + if border is None: + border = self.fgcolor + if fill is None: + fill = Colors.Transparent + args = ( + u'circle', x, y, radius, + border.to_hex(), fill.to_hex(), angle, + ) + self._cmd_history.append((self._id, u'obj', self._next_objid, args)) + obj = DrawingObject(self, self._next_objid, args) + self._next_objid += 1 + return obj + + def rectangle(self, left, top, width, height, border=None, fill=None, angle=0): + if border is None: + border = self.fgcolor + if fill is None: + fill = Colors.Transparent + args = ( + u'rect', left, top, width, height, + border.to_hex(), fill.to_hex(), angle, + ) + self._cmd_history.append((self._id, u'obj', self._next_objid, args)) + obj = DrawingObject(self, self._next_objid, args) + self._next_objid += 1 + 
return obj + + def triangle(self, left, top, width, height, border=None, fill=None, angle=0): + if border is None: + border = self.fgcolor + if fill is None: + fill = Colors.Transparent + args = ( + u'triangle', left, top, width, height, + border.to_hex(), fill.to_hex(), angle, + ) + self._cmd_history.append((self._id, u'obj', self._next_objid, args)) + obj = DrawingObject(self, self._next_objid, args) + self._next_objid += 1 + return obj + + +__all__ = [ + 'Canvas', +] diff --git a/src/ai/backend/kernel/python/drawing/color.py b/src/ai/backend/kernel/python/drawing/color.py new file mode 100644 index 0000000000..9f96356e94 --- /dev/null +++ b/src/ai/backend/kernel/python/drawing/color.py @@ -0,0 +1,63 @@ +import enum +import struct + +rgba = struct.Struct('BBBB') + + +class Color: + + def __init__(self, red, green, blue, alpha=255): + self.red = red + self.green = green + self.blue = blue + self.alpha = alpha + + @staticmethod + def from_hex(value): + value = value.replace('#', '') + r = int(value[0:2], 16) + g = int(value[2:4], 16) + b = int(value[4:6], 16) + a = int(value[6:8], 16) + return Color(r, g, b, a) + + @staticmethod + def from_rgba(value): + return Color(*value) + + @staticmethod + def from_bytes(value): + r, g, b, a = rgba.unpack(value) + return Color(r, g, b, a) + + def to_hex(self, include_alpha=True): + if include_alpha: + return u'#{:02x}{:02x}{:02x}{:02x}'.format( + self.red, self.green, self.blue, self.alpha) + else: + return u'#{:02x}{:02x}{:02x}'.format( + self.red, self.green, self.blue) + + def to_bytes(self): + return rgba.pack(self.red, self.green, self.blue, self.alpha) + + def to_rgba(self): + return 'rgba({},{},{},{})'.format(self.red, self.green, self.blue, self.alpha) + + +class Colors(Color, enum.Enum): + Transparent = (255, 255, 255, 0) + Black = (0, 0, 0, 255) + Gray = (128, 128, 128, 255) + White = (255, 255, 255, 255) + Red = (255, 0, 0, 255) + Green = (0, 255, 0, 255) + Blue = (0, 0, 255, 255) + Yellow = (255, 255, 0, 255) + Magenta = (255, 0, 255, 255) + Cyan = (0, 255, 255, 255) + + +__all__ = [ + 'Color', 'Colors', +] diff --git a/src/ai/backend/kernel/python/drawing/encoding.py b/src/ai/backend/kernel/python/drawing/encoding.py new file mode 100644 index 0000000000..381be6b198 --- /dev/null +++ b/src/ai/backend/kernel/python/drawing/encoding.py @@ -0,0 +1,12 @@ +import base64 +import msgpack + + +def encode_commands(cmdlist): + bindata = msgpack.packb(cmdlist, use_bin_type=True) + return base64.b64encode(bindata).decode('ascii') + + +def decode_commands(data): + bindata = base64.b64decode(data) + return msgpack.unpackb(bindata, raw=False) diff --git a/src/ai/backend/kernel/python/drawing/turtle.py b/src/ai/backend/kernel/python/drawing/turtle.py new file mode 100644 index 0000000000..950b2cb611 --- /dev/null +++ b/src/ai/backend/kernel/python/drawing/turtle.py @@ -0,0 +1,115 @@ +from .color import Colors +import math + + +class Vec2D(tuple): + '''A helper class taken from Python stdlib's Turtle package.''' + + def __new__(cls, x, y): + return tuple.__new__(cls, (x, y)) + + def __add__(self, other): + return Vec2D(self[0] + other[0], self[1] + other[1]) + + def __mul__(self, other): + if isinstance(other, Vec2D): + return self[0] * other[0] + self[1] * other[1] + return Vec2D(self[0] * other, self[1] * other) + + def __rmul__(self, other): + if isinstance(other, int) or isinstance(other, float): + return Vec2D(self[0] * other, self[1] * other) + + def __sub__(self, other): + return Vec2D(self[0] - other[0], self[1] - other[1]) + + def __neg__(self): 
+ return Vec2D(-self[0], -self[1]) + + def __abs__(self): + return (self[0] ** 2 + self[1] ** 2) ** 0.5 + + def rotate(self, angle): + """rotate self counterclockwise by angle + """ + perp = Vec2D(-self[1], self[0]) + angle = angle * math.pi / 180.0 + c, s = math.cos(angle), math.sin(angle) + return Vec2D(self[0] * c + perp[0] * s, self[1] * c + perp[1] * s) + + def __getnewargs__(self): + return (self[0], self[1]) + + def __repr__(self): + return "(%.2f,%.2f)" % self + + +class Turtle: + + def __init__(self, canvas): + self.canvas = canvas + self.points = [] + self.pen = True + w = self.canvas.width + h = self.canvas.height + self.cursor = self.canvas.triangle( + w / 2, h / 2, 12, 18, + border=Colors.Red, + fill=Colors.from_rgba([255, 200, 200, 255]), + angle=90) + self.angle = 90 + self.points.append((w / 2, h / 2)) + + def forward(self, amt): + x = self.points[-1][0] + y = self.points[-1][1] + x_diff = math.sin(math.radians(self.angle)) * amt + y_diff = -1 * math.cos(math.radians(self.angle)) * amt + self.canvas.begin_group() + if self.pen: + self.canvas.line(x, y, x + x_diff, y + y_diff, + color=Colors.from_rgba([255, 0, 0, 128])) + self.cursor.set_x(x + x_diff) + self.cursor.set_y(y + y_diff) + self.canvas.end_group() + self.points.append((x + x_diff, y + y_diff)) + + def left(self, deg): + self.cursor.rotate(-deg) + self.angle -= deg + + def right(self, deg): + self.cursor.rotate(deg) + self.angle += deg + + def pos(self): + base_x, base_y = self.points[0][0], self.points[0][1] + return Vec2D(self.points[-1][0] - base_x, + self.points[-1][1] - base_y) + + def penup(self): + self.pen = False + + def pendown(self): + self.pen = True + + def setpos(self, x, y=None): + base_x, base_y = self.points[0][0], self.points[0][1] + if y is None: + _x = x[0] + _y = x[1] + x, y = _x, _y + self.canvas.begin_group() + if self.pen: + self.canvas.line(self.points[-1][0], self.points[-1][1], + x + base_x, y + base_y, + color=Colors.from_rgba([255, 0, 0, 128])) + self.cursor.set_x(x + base_x) + self.cursor.set_y(y + base_y) + self.canvas.end_group() + self.points.append((x + base_x, y + base_y)) + + +__all__ = [ + 'Turtle', 'Vec2D', +] diff --git a/src/ai/backend/kernel/python/sitecustomize.py b/src/ai/backend/kernel/python/sitecustomize.py new file mode 100644 index 0000000000..f2063edae7 --- /dev/null +++ b/src/ai/backend/kernel/python/sitecustomize.py @@ -0,0 +1,46 @@ +import os +import socket +import sys + +input_host = '127.0.0.1' +input_port = 65000 + +batch_enabled = int(os.environ.get('_BACKEND_BATCH_MODE', '0')) +if batch_enabled: + # Since latest Python 2 has `builtins`and `input`, + # we cannot detect Python 2 with the existence of them. + if sys.version_info.major > 2: + import builtins + + def _input(prompt=''): + sys.stdout.write(prompt) + sys.stdout.flush() + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + try: + sock.connect((input_host, input_port)) + userdata = sock.recv(1024) + except ConnectionRefusedError: + userdata = b'' + return userdata.decode() + builtins._input = input # type: ignore + builtins.input = _input + else: + # __builtins__ is an alias dict for __builtin__ in modules other than __main__. + # Thus, we have to explicitly import __builtin__ module in Python 2. 
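+        # As in the Python 3 branch above, the replacement below reads the user's
+        # response from a local TCP socket served by the kernel runner instead of stdin.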
+ import __builtin__ + builtins = __builtin__ + + def _raw_input(prompt=''): + sys.stdout.write(prompt) + sys.stdout.flush() + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.connect((input_host, input_port)) + userdata = sock.recv(1024) + except socket.error: + userdata = b'' + finally: + sock.close() + return userdata.decode() + builtins._raw_input = builtins.raw_input # type: ignore + builtins.raw_input = _raw_input # type: ignore diff --git a/src/ai/backend/kernel/python/types.py b/src/ai/backend/kernel/python/types.py new file mode 100644 index 0000000000..bf8baa8db2 --- /dev/null +++ b/src/ai/backend/kernel/python/types.py @@ -0,0 +1,28 @@ +from namedlist import namedtuple, FACTORY + + +InputRequest = namedtuple('InputRequest', [ + ('is_password', False), +]) + +ControlRecord = namedtuple('ControlRecord', [ + ('event', None), +]) + +CompletionRecord = namedtuple('CompletionRecord', [ + ('matches', FACTORY(list)), +]) + +ConsoleRecord = namedtuple('ConsoleRecord', [ + ('target', 'stdout'), # or 'stderr' + ('data', ''), +]) + +MediaRecord = namedtuple('MediaRecord', [ + ('type', None), # mime-type + ('data', None), +]) + +HTMLRecord = namedtuple('HTMLRecord', [ + ('html', None), # raw HTML string +]) diff --git a/src/ai/backend/kernel/r/__init__.py b/src/ai/backend/kernel/r/__init__.py new file mode 100644 index 0000000000..b768677f51 --- /dev/null +++ b/src/ai/backend/kernel/r/__init__.py @@ -0,0 +1,69 @@ +import logging +import os +from pathlib import Path +import tempfile + +import janus + +from .. import BaseRunner + +log = logging.getLogger() + + +class Runner(BaseRunner): + + log_prefix = 'r-kernel' + default_runtime_path = '/usr/bin/R' + default_child_env = { + 'TERM': 'xterm', + 'LANG': 'C.UTF-8', + 'SHELL': '/bin/ash' if Path('/bin/ash').is_file() else '/bin/bash', + 'USER': 'work', + 'HOME': '/home/work', + 'PATH': '/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin', + 'LD_LIBRARY_PATH': os.environ.get('LD_LIBRARY_PATH', ''), + } + jupyter_kspec_name = 'ir' + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.input_queue = None + self.output_queue = None + + async def init_with_loop(self): + self.input_queue = janus.Queue() + self.output_queue = janus.Queue() + + # We have interactive input functionality! 
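+        # janus.Queue exposes both a sync_q (for worker threads) and an async_q
+        # (for asyncio code), so the same queue can be shared across the two worlds.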
+ self._user_input_queue = janus.Queue() + self.user_input_queue = self._user_input_queue.async_q + + async def build_heuristic(self): + log.info('no build process for R language') + return 0 + + async def execute_heuristic(self): + if Path('main.R').is_file(): + cmd = 'Rscript main.R' + return await self.run_subproc(cmd) + else: + log.error('cannot find executable ("main.R").') + return 127 + + async def start_service(self, service_info): + if service_info['name'] in ['jupyter', 'jupyterlab']: + with tempfile.NamedTemporaryFile( + 'w', encoding='utf-8', suffix='.py', delete=False) as config: + print('c.NotebookApp.allow_root = True', file=config) + print('c.NotebookApp.ip = "0.0.0.0"', file=config) + print('c.NotebookApp.port = {}'.format(service_info['port']), file=config) + print('c.NotebookApp.token = ""', file=config) + print('c.FileContentsManager.delete_to_trash = False', file=config) + jupyter_service_type = service_info['name'] + if jupyter_service_type == 'jupyter': + jupyter_service_type = 'notebook' + return [ + self.runtime_path, '-m', 'jupyter', jupyter_service_type, + '--no-browser', + '--config', config.name, + ], {} diff --git a/src/ai/backend/kernel/r_server_ms/__init__.py b/src/ai/backend/kernel/r_server_ms/__init__.py new file mode 100644 index 0000000000..5c0330bb31 --- /dev/null +++ b/src/ai/backend/kernel/r_server_ms/__init__.py @@ -0,0 +1,118 @@ +from datetime import datetime, timedelta +import logging +import os + +import aiohttp +from yarl import URL + +from .. import BaseRunner + +log = logging.getLogger() + + +class Runner(BaseRunner): + + ''' + Implements an adaptor to Microsoft R Server API. + ''' + + log_prefix = 'r-server' + + def __init__(self, *args, endpoint=None, credentials=None, **kwargs): + super().__init__(*args, **kwargs) + if endpoint is None: + endpoint = os.environ.get('MRS_ENDPOINT', 'localhost') + if credentials is None: + credentials = { + 'username': os.environ.get('MRS_USERNAME', 'anonymous'), + 'password': os.environ.get('MRS_PASSWORD', 'unknown'), + } + self.http_sess = None + self.endpoint = endpoint + self.credentials = credentials + self.access_token = None + self.expires_on = None + + async def init_with_loop(self): + self.http_sess = aiohttp.ClientSession() + await self._refresh_token() + sess_create_url = self.endpoint + '/sessions' + resp = await self.http_sess.post( + sess_create_url, + headers=self.auth_hdrs, + json={}) + data = await resp.json() + self.sess_id = data['sessionId'] + log.debug('created session:', self.sess_id) + + async def shutdown(self): + await self._refresh_token() + sess_url = f'{self.endpoint}/sessions/{self.sess_id}' + resp = await self.http_sess.delete( + sess_url, + headers=self.auth_hdrs) + resp.raise_for_status() + log.debug('deleted session:', self.sess_id) + revoke_url = URL(f'{self.endpoint}/login/refreshToken') + revoke_url = revoke_url.update_query({ + 'refreshToken': self.refresh_token, + }) + resp = await self.http_sess.delete( + revoke_url, + headers=self.auth_hdrs) + resp.raise_for_status() + await self.http_sess.close() + + async def build_heuristic(self) -> int: + raise NotImplementedError + + async def execute_heuristic(self) -> int: + raise NotImplementedError + + async def query(self, code_text) -> int: + await self._refresh_token() + execute_url = f'{self.endpoint}/sessions/{self.sess_id}/execute' + resp = await self.http_sess.post( + execute_url, + headers=self.auth_hdrs, + json={ + 'code': code_text, + }) + data = await resp.json() + self.outsock.send_multipart(['stdout', 
data['consoleOutput']]) + return 0 + + async def complete(self, data): + return [] + + async def interrupt(self): + # TODO: cancel session? + pass + + async def _refresh_token(self): + if self.access_token is None: + login_url = self.endpoint + '/login' + resp = await self.http_sess.post( + login_url, + json=self.credentials) + elif self.expires_on is not None and self.expires_on <= datetime.now(): + refresh_url = f'{self.endpoint}/login/refreshToken' + resp = await self.http_sess.post( + refresh_url, + headers=self.auth_hdrs, + json={ + 'refreshToken': self.refresh_token, + }) + else: + return + data = await resp.json() + self.access_token = data['access_token'] + self.refresh_token = data['refresh_token'] + self.expires_on = datetime.now() \ + + timedelta(seconds=int(data['expires_in'])) + self.auth_hdrs = { + 'Authorization': f'Bearer {self.access_token}', + } + + async def start_service(self, service_info): + return None, {} diff --git a/src/ai/backend/kernel/requirements.txt b/src/ai/backend/kernel/requirements.txt new file mode 100644 index 0000000000..6148d98114 --- /dev/null +++ b/src/ai/backend/kernel/requirements.txt @@ -0,0 +1,8 @@ +# copied from https://github.com/lablup/backend.ai-krunner-static-gnu/blob/main/src/ai/backend/krunner/static_gnu/requirements.txt +async_timeout~=3.0 +pyzmq~=22.2 +uvloop~=0.16 +attrs~=21.2 +janus~=0.6.1 +msgpack~=1.0 +jupyter-client~=6.1 diff --git a/src/ai/backend/kernel/rust/__init__.py b/src/ai/backend/kernel/rust/__init__.py new file mode 100644 index 0000000000..03cc094fa1 --- /dev/null +++ b/src/ai/backend/kernel/rust/__init__.py @@ -0,0 +1,76 @@ +import logging +import os +from pathlib import Path +import tempfile + +from .. import BaseRunner +from ..utils import find_executable + +log = logging.getLogger() + +CARGO = 'cargo' +RUSTC = 'rustc' + + +class Runner(BaseRunner): + + log_prefix = 'rust-kernel' + default_runtime_path = '/usr/bin/rustc' + default_child_env = { + 'TERM': 'xterm', + 'LANG': 'C.UTF-8', + 'SHELL': '/bin/ash', + 'USER': 'work', + 'HOME': '/home/work', + 'PATH': '/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin', + 'GOPATH': '/home/work', + 'LD_LIBRARY_PATH': os.environ.get('LD_LIBRARY_PATH', ''), + 'LD_PRELOAD': os.environ.get('LD_PRELOAD', ''), + } + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + async def init_with_loop(self): + pass + + async def build_heuristic(self) -> int: + if Path('Cargo.toml').is_file(): + return await self.run_subproc([CARGO, 'build']) + elif Path('main.rs').is_file(): + return await self.run_subproc([RUSTC, '-o', 'main', 'main.rs']) + else: + log.error( + 'cannot find the main/build file ("Cargo.toml" or "main.rs").') + return 127 + + async def execute_heuristic(self) -> int: + out = find_executable('./target/debug', './target/release') + if out is not None: + return await self.run_subproc([out]) + elif Path('./main').is_file(): + return await self.run_subproc(['./main']) + else: + log.error('cannot find executable ("main" or target directories).') + return 127 + + async def query(self, code_text) -> int: + with tempfile.NamedTemporaryFile(suffix='.rs', dir='.') as tmpf: + tmpf.write(code_text.encode('utf8')) + tmpf.flush() + cmd = [RUSTC, '-o', 'main', tmpf.name] + ret = await self.run_subproc(cmd) + if ret != 0: + return ret + cmd = ['./main'] + return await self.run_subproc(cmd) + + async def complete(self, data): + return [] + + async def interrupt(self): + # subproc interrupt is already handled by BaseRunner + pass + + async def 
start_service(self, service_info): + return None, {} diff --git a/src/ai/backend/kernel/scheme/__init__.py b/src/ai/backend/kernel/scheme/__init__.py new file mode 100644 index 0000000000..a1f6d27711 --- /dev/null +++ b/src/ai/backend/kernel/scheme/__init__.py @@ -0,0 +1,52 @@ +import logging +import os +import tempfile + +from .. import BaseRunner + +log = logging.getLogger() + + +class Runner(BaseRunner): + + log_prefix = 'scheme-kernel' + default_runtime_path = '/usr/bin/rustc' + default_child_env = { + 'TERM': 'xterm', + 'LANG': 'C.UTF-8', + 'SHELL': '/bin/ash', + 'USER': 'work', + 'HOME': '/home/work', + 'PATH': '/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin', + 'LD_LIBRARY_PATH': os.environ.get('LD_LIBRARY_PATH', ''), + 'LD_PRELOAD': os.environ.get('LD_PRELOAD', ''), + } + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + async def init_with_loop(self): + pass + + async def build_heuristic(self) -> int: + pass + + async def execute_heuristic(self) -> int: + pass + + async def query(self, code_text) -> int: + with tempfile.NamedTemporaryFile(suffix='.scm', dir='.') as tmpf: + tmpf.write(code_text.encode('utf8')) + tmpf.flush() + cmd = f'scheme --quiet < {tmpf.name}' + return await self.run_subproc(cmd) + + async def complete(self, data): + return [] + + async def interrupt(self): + # subproc interrupt is already handled by BaseRunner + pass + + async def start_service(self, service_info): + return None, {} diff --git a/src/ai/backend/kernel/service.py b/src/ai/backend/kernel/service.py new file mode 100644 index 0000000000..49efe03d49 --- /dev/null +++ b/src/ai/backend/kernel/service.py @@ -0,0 +1,154 @@ +""" +Parses and interpolates the service-definition templates stored as json in +``/etc/backend.ai/servce-defs`` of Backend.AI containers. + +See more details at `the documentation about adding new kernel images +`_. +""" + +import json +import logging +from pathlib import Path +from typing import ( + Any, Collection, Optional, Union, + TypedDict, + Mapping, MutableMapping, Dict, + Sequence, List, Tuple, +) + +import attr + +from . 
import service_actions +from .logging import BraceStyleAdapter +from .exception import DisallowedArgument, DisallowedEnvironment, InvalidServiceDefinition + +log = BraceStyleAdapter(logging.getLogger()) + + +class Action(TypedDict): + action: str + args: Mapping[str, str] + ref: Optional[str] + + +@attr.s(auto_attribs=True, slots=True) +class ServiceDefinition: + command: List[str] + noop: bool = False + url_template: str = '' + prestart_actions: List[Action] = attr.Factory(list) + env: Mapping[str, str] = attr.Factory(dict) + allowed_envs: List[str] = attr.Factory(list) + allowed_arguments: List[str] = attr.Factory(list) + default_arguments: Mapping[str, Union[None, str, List[str]]] = attr.Factory(dict) + + +class ServiceParser: + + variables: MutableMapping[str, str] + services: MutableMapping[str, ServiceDefinition] + + def __init__(self, variables: MutableMapping[str, str]) -> None: + self.variables = variables + self.services = {} + + async def parse(self, path: Path) -> None: + for service_def_file in path.glob('*.json'): + log.debug(f'loading service-definition from {service_def_file}') + try: + with open(service_def_file.absolute(), 'rb') as fr: + raw_service_def = json.load(fr) + # translate naming differences + if 'prestart' in raw_service_def: + raw_service_def['prestart_actions'] = raw_service_def['prestart'] + del raw_service_def['prestart'] + except IOError: + raise InvalidServiceDefinition( + f'could not read the service-def file: {service_def_file.name}') + except json.JSONDecodeError: + raise InvalidServiceDefinition( + f'malformed JSON in service-def file: {service_def_file.name}') + name = service_def_file.stem + try: + self.services[name] = ServiceDefinition(**raw_service_def) + except TypeError as e: + raise InvalidServiceDefinition(e.args[0][11:]) # lstrip "__init__() " + + async def start_service( + self, + service_name: str, + frozen_envs: Collection[str], + opts: Mapping[str, Any], + ) -> Tuple[Optional[Sequence[str]], Mapping[str, str]]: + if service_name not in self.services.keys(): + return None, {} + service = self.services[service_name] + if service.noop: + return [], {} + + for action in service.prestart_actions: + try: + action_impl = getattr(service_actions, action['action']) + except AttributeError: + raise InvalidServiceDefinition( + f"Service-def for {service_name} used invalid action: {action['action']}") + ret = await action_impl(self.variables, **action['args']) + if (ref := action.get('ref')) is not None: + self.variables[ref] = ret + + cmdargs, env = [], {} + + for arg in service.command: + cmdargs.append(arg.format_map(self.variables)) + + additional_arguments = dict(service.default_arguments) + if 'arguments' in opts.keys() and opts['arguments']: + for argname, argvalue in opts['arguments'].items(): + if argname not in service.allowed_arguments: + raise DisallowedArgument( + f'Argument {argname} not allowed for service {service_name}') + additional_arguments[argname] = argvalue + + for env_name, env_value in service.env.items(): + env[env_name.format_map(self.variables)] = env_value.format_map(self.variables) + + if 'envs' in opts.keys() and opts['envs']: + for envname, envvalue in opts['envs'].items(): + if envname not in service.allowed_envs: + raise DisallowedEnvironment( + f'Environment variable {envname} not allowed for service {service_name}') + elif envname in frozen_envs: + raise DisallowedEnvironment( + f'Environment variable {envname} can\'t be overwritten') + env[envname] = envvalue + + for arg_name, arg_value in 
additional_arguments.items(): + cmdargs.append(arg_name) + if isinstance(arg_value, str): + cmdargs.append(arg_value) + elif isinstance(arg_value, list): + cmdargs += arg_value + + return cmdargs, env + + async def get_apps(self, selected_service: str = '') -> Sequence[Mapping[str, Any]]: + + def _format(service_name: str) -> Mapping[str, Any]: + service_info: Dict[str, Any] = {'name': service_name} + service = self.services[service_name] + if len(service.url_template) > 0: + service_info['url_template'] = service.url_template + if len(service.allowed_arguments) > 0: + service_info['allowed_arguments'] = service.allowed_arguments + if len(service.allowed_envs) > 0: + service_info['allowed_envs'] = service.allowed_envs + return service_info + + apps = [] + if selected_service: + if selected_service in self.services.keys(): + apps.append(_format(selected_service)) + else: + for service_name in self.services.keys(): + apps.append(_format(service_name)) + return apps diff --git a/src/ai/backend/kernel/service_actions.py b/src/ai/backend/kernel/service_actions.py new file mode 100644 index 0000000000..e7fbb45d4b --- /dev/null +++ b/src/ai/backend/kernel/service_actions.py @@ -0,0 +1,72 @@ +from asyncio import create_subprocess_exec, subprocess +import logging +import os +from pathlib import Path +import tempfile +from typing import ( + Any, Optional, + Iterable, + Mapping, MutableMapping, +) + +from .logging import BraceStyleAdapter + +logger = BraceStyleAdapter(logging.getLogger()) + + +async def write_file( + variables: Mapping[str, Any], + filename: str, + body: Iterable[str], + mode: str = '644', + append: bool = False, +) -> None: + filename = filename.format_map(variables) + open_mode = 'w' + ('+' if append else '') + with open(filename, open_mode) as fw: + for line in body: + fw.write(line.format_map(variables) + '\n') + os.chmod(filename, int(mode, 8)) + + +async def write_tempfile( + variables: Mapping[str, Any], + body: Iterable[str], + mode: str = '644', +) -> Optional[str]: + with tempfile.NamedTemporaryFile( + 'w', encoding='utf-8', suffix='.py', delete=False) as config: + for line in body: + config.write(line.format_map(variables)) + os.chmod(config.name, int(mode, 8)) + return config.name + + +async def run_command( + variables: Mapping[str, Any], + command: Iterable[str], +) -> Optional[MutableMapping[str, str]]: + proc = await create_subprocess_exec(*( + str(piece).format_map(variables) for piece in command + ), stdout=subprocess.PIPE, stderr=subprocess.PIPE) + out, err = await proc.communicate() + return {'out': out.decode('utf8'), 'err': err.decode('utf8')} + + +async def mkdir( + variables: Mapping[str, Any], + path: str, +) -> None: + Path(path.format_map(variables)).mkdir(parents=True, exist_ok=True) + + +async def log( + variables: Mapping[str, Any], + message: str, + debug: bool = False, +) -> None: + message = message.format_map(variables).replace('{', '{{').replace('}', '}}') + if debug: + logger.debug(message) + else: + logger.info(message) diff --git a/src/ai/backend/kernel/terminal.py b/src/ai/backend/kernel/terminal.py new file mode 100644 index 0000000000..70b352d1a3 --- /dev/null +++ b/src/ai/backend/kernel/terminal.py @@ -0,0 +1,206 @@ +import argparse +import asyncio +import fcntl +import logging +import os +import pty +import shlex +import signal +import struct +import sys +import termios +import traceback + +import zmq, zmq.asyncio + +from .logging import BraceStyleAdapter +from .utils import safe_close_task + +log = BraceStyleAdapter(logging.getLogger()) + + 
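To illustrate how the `ServiceParser` defined above in `src/ai/backend/kernel/service.py` consumes these JSON templates, here is a minimal sketch. It assumes the `ai.backend.kernel` package is importable; the service name `myhttp`, the variable values, and the use of the stdlib `http.server` module are made-up examples, not definitions shipped with any Backend.AI image.

```python
# A minimal sketch of feeding a hypothetical service-definition JSON to ServiceParser.
# Field names follow the ServiceDefinition attrs class above; all concrete values are made up.
import asyncio
import json
import tempfile
from pathlib import Path

from ai.backend.kernel.service import ServiceParser

EXAMPLE_DEF = {
    "command": ["{runtime_path}", "-m", "http.server", "{port}"],
    "url_template": "http://{host}:{port}/",
    "prestart_actions": [
        {"action": "log", "args": {"message": "starting http.server on {port}"}},
    ],
}


async def main() -> None:
    with tempfile.TemporaryDirectory() as tmpdir:
        # ServiceParser.parse() scans a directory for *.json service definitions.
        (Path(tmpdir) / "myhttp.json").write_text(json.dumps(EXAMPLE_DEF))
        parser = ServiceParser({
            "runtime_path": "/usr/bin/python3",
            "port": "8080",
        })
        await parser.parse(Path(tmpdir))
        cmd, env = await parser.start_service("myhttp", frozen_envs=(), opts={})
        print(cmd)  # -> ['/usr/bin/python3', '-m', 'http.server', '8080']
        print(env)  # -> {}


asyncio.run(main())
```

In an actual container the definitions live under the directory mentioned in the module docstring (``/etc/backend.ai/service-defs``), and the interpolation variables are supplied by the kernel runner rather than hand-written as above.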
+class Terminal: + ''' + A wrapper for a terminal-based app. + ''' + + def __init__(self, shell_cmd, ev_term, sock_out, *, + auto_restart=True, loop=None): + self._sorna_media = [] + self.zctx = sock_out.context + + self.ev_term = ev_term + self.pid = None + self.fd = None + + self.shell_cmd = shell_cmd + self.auto_restart = auto_restart + + # For command output + self.sock_out = sock_out + + # For terminal I/O + self.sock_term_in = None + self.sock_term_out = None + self.term_in_task = None + self.term_out_task = None + self.start_lock = asyncio.Lock() + self.accept_term_input = False + + self.cmdparser = argparse.ArgumentParser() + self.subparsers = self.cmdparser.add_subparsers() + + # Base commands for generic terminal-based app + parser_ping = self.subparsers.add_parser('ping') + parser_ping.set_defaults(func=self.do_ping) + + parser_resize = self.subparsers.add_parser('resize') + parser_resize.add_argument('rows', type=int) + parser_resize.add_argument('cols', type=int) + parser_resize.set_defaults(func=self.do_resize_term) + + async def do_ping(self, args) -> int: + await self.sock_out.send_multipart([b'stdout', b'pong!']) + return 0 + + async def do_resize_term(self, args) -> int: + if self.fd is None: + return 0 + origsz_in = struct.pack('HHHH', 0, 0, 0, 0) + origsz_out = fcntl.ioctl(self.fd, termios.TIOCGWINSZ, origsz_in, False) + orig_lines, orig_cols, _, _ = struct.unpack('HHHH', origsz_out) + newsz_in = struct.pack('HHHH', args.rows, args.cols, orig_lines, orig_cols) + newsz_out = fcntl.ioctl(self.fd, termios.TIOCSWINSZ, newsz_in, False) + new_lines, new_cols, _, _ = struct.unpack('HHHH', newsz_out) + await self.sock_out.send_multipart([ + b'stdout', + f'OK; terminal resized to {new_lines} lines and {new_cols} columns'.encode(), + ]) + return 0 + + async def handle_command(self, code_txt) -> int: + try: + if code_txt.startswith('%'): + args = self.cmdparser.parse_args( + shlex.split(code_txt[1:], comments=True)) + if asyncio.iscoroutine(args.func) or \ + asyncio.iscoroutinefunction(args.func): + return await args.func(args) + else: + return args.func(args) + else: + await self.sock_out.send_multipart([b'stderr', b'Invalid command.']) + return 127 + except: + exc_type, exc_val, tb = sys.exc_info() + traces = traceback.format_exception(exc_type, exc_val, tb) + await self.sock_out.send_multipart([b'stderr', ''.join(traces).encode()]) + return 1 + finally: + await self.sock_out.send_multipart([b'finished', b'{}']) + + async def start(self): + assert not self.accept_term_input + await safe_close_task(self.term_in_task) + await safe_close_task(self.term_out_task) + pid, fd = pty.fork() + if pid == 0: + args = shlex.split(self.shell_cmd) + os.execv(args[0], args) + else: + self.pid = pid + self.fd = fd + + if self.sock_term_in is None: + self.sock_term_in = self.zctx.socket(zmq.SUB) + self.sock_term_in.bind('tcp://*:2002') + self.sock_term_in.subscribe(b'') + if self.sock_term_out is None: + self.sock_term_out = self.zctx.socket(zmq.PUB) + self.sock_term_out.bind('tcp://*:2003') + + loop = asyncio.get_running_loop() + term_reader = asyncio.StreamReader() + term_read_protocol = asyncio.StreamReaderProtocol(term_reader) + await loop.connect_read_pipe( + lambda: term_read_protocol, os.fdopen(self.fd, 'rb')) + + _reader_factory = lambda: asyncio.StreamReaderProtocol( + asyncio.StreamReader()) + term_writer_transport, term_writer_protocol = \ + await loop.connect_write_pipe(_reader_factory, + os.fdopen(self.fd, 'wb')) + term_writer = asyncio.StreamWriter(term_writer_transport, + 
term_writer_protocol, + None) + + self.term_in_task = asyncio.create_task(self.term_in(term_writer)) + self.term_out_task = asyncio.create_task(self.term_out(term_reader)) + self.accept_term_input = True + await asyncio.sleep(0) + + async def restart(self): + try: + async with self.start_lock: + if not self.accept_term_input: + return + self.accept_term_input = False + await self.sock_term_out.send_multipart([b'Restarting...\r\n']) + os.waitpid(self.pid, 0) + await self.start() + except Exception: + log.exception('Unexpected error during restart of terminal') + + async def term_in(self, term_writer): + try: + while True: + data = await self.sock_term_in.recv_multipart() + if not data: + break + if self.accept_term_input: + try: + term_writer.write(data[0]) + await term_writer.drain() + except IOError: + break + except asyncio.CancelledError: + pass + except Exception: + log.exception('Unexpected error at term_in()') + + async def term_out(self, term_reader): + try: + while not term_reader.at_eof(): + try: + data = await term_reader.read(4096) + except IOError: + # In docker containers, this path is taken. + break + if not data: + # In macOS, this path is taken. + break + await self.sock_term_out.send_multipart([data]) + self.fd = None + if not self.auto_restart: + await self.sock_term_out.send_multipart([b'Terminated.\r\n']) + return + if not self.ev_term.is_set() and self.accept_term_input: + asyncio.create_task(self.restart()) + except asyncio.CancelledError: + pass + except Exception: + log.exception('Unexpected error at term_out()') + + async def shutdown(self): + self.term_in_task.cancel() + self.term_out_task.cancel() + await self.term_in_task + await self.term_out_task + self.sock_term_in.close() + self.sock_term_out.close() + os.kill(self.pid, signal.SIGHUP) + os.kill(self.pid, signal.SIGCONT) + await asyncio.sleep(0) + os.waitpid(self.pid, 0) + self.pid = None + self.fd = None diff --git a/src/ai/backend/kernel/test_utils.py b/src/ai/backend/kernel/test_utils.py new file mode 100644 index 0000000000..c0b9c53da4 --- /dev/null +++ b/src/ai/backend/kernel/test_utils.py @@ -0,0 +1,33 @@ +import asynctest + + +class MockableZMQAsyncSock: + + # Since zmq.Socket/zmq.asyncio.Socket uses a special AttributeSetter mixin which + # breaks mocking of those instances as-is, we define a dummy socket interface + # which does not have such side effects. + + @classmethod + def create_mock(cls): + return asynctest.Mock(cls()) + + def bind(self, addr): + pass + + def connect(self, addr): + pass + + def close(self): + pass + + async def send(self, frame): + pass + + async def send_multipart(self, msg): + pass + + async def recv(self): + pass + + async def recv_multipart(self): + pass diff --git a/src/ai/backend/kernel/utils.py b/src/ai/backend/kernel/utils.py new file mode 100644 index 0000000000..465af07492 --- /dev/null +++ b/src/ai/backend/kernel/utils.py @@ -0,0 +1,57 @@ +import asyncio +from pathlib import Path + +from async_timeout import timeout + +__all__ = ( + 'current_loop', + 'find_executable', + 'safe_close_task', + 'wait_local_port_open', +) + + +if hasattr(asyncio, 'get_running_loop'): + current_loop = asyncio.get_running_loop # type: ignore # noqa +else: + current_loop = asyncio.get_event_loop # type: ignore # noqa + + +def find_executable(*paths): + ''' + Find the first executable regular file in the given list of paths. 
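+    Only the owner-execute permission bit (0o100) is checked.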
+ ''' + for path in paths: + if isinstance(path, (str, bytes)): + path = Path(path) + if not path.exists(): + continue + for child in path.iterdir(): + if child.is_file() and child.stat().st_mode & 0o100 != 0: + return child + return None + + +async def safe_close_task(task): + if task is not None and not task.done(): + task.cancel() + await task + + +async def wait_local_port_open(port): + while True: + try: + with timeout(10.0): + reader, writer = await asyncio.open_connection('127.0.0.1', port) + except ConnectionRefusedError: + await asyncio.sleep(0.1) + continue + except asyncio.TimeoutError: + raise + except Exception: + raise + else: + writer.close() + if hasattr(writer, 'wait_closed'): + await writer.wait_closed() + break diff --git a/src/ai/backend/kernel/vendor/__init__.py b/src/ai/backend/kernel/vendor/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/ai/backend/kernel/vendor/aws_polly/__init__.py b/src/ai/backend/kernel/vendor/aws_polly/__init__.py new file mode 100644 index 0000000000..143d8bf436 --- /dev/null +++ b/src/ai/backend/kernel/vendor/aws_polly/__init__.py @@ -0,0 +1,97 @@ +import asyncio +import ctypes +import logging +import os +import threading + +import janus + +from ... import BaseRunner +from .inproc import PollyInprocRunner + +log = logging.getLogger() + + +class Runner(BaseRunner): + + log_prefix = 'vendor.aws_polly-kernel' + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.inproc_runner = None + self.sentinel = object() + self.input_queue = None + self.output_queue = None + # NOTE: If credentials are missing, + # boto3 will try to use the instance role. + self.access_key = \ + self.child_env.get('AWS_ACCESS_KEY_ID', None) + self.secret_key = \ + self.child_env.get('AWS_SECRET_ACCESS_KEY', None) + os.environ['AWS_DEFAULT_REGION'] = \ + self.child_env.get('AWS_DEFAULT_REGION', 'ap-northeast-2') + + async def init_with_loop(self): + self.input_queue = janus.Queue() + self.output_queue = janus.Queue() + + async def build_heuristic(self) -> int: + raise NotImplementedError + + async def execute_heuristic(self) -> int: + raise NotImplementedError + + async def query(self, code_text) -> int: + self.ensure_inproc_runner() + await self.input_queue.async_q.put(code_text) + # Read the generated outputs until done + while True: + try: + msg = await self.output_queue.async_q.get() + except asyncio.CancelledError: + break + self.output_queue.async_q.task_done() + if msg is self.sentinel: + break + self.outsock.send_multipart(msg) + return 0 + + async def complete(self, data): + self.outsock.send_multipart([ + b'completion', + [], + ]) + + async def interrupt(self): + if self.inproc_runner is None: + log.error('No user code is running!') + return + # A dirty hack to raise an exception inside a running thread. 
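+        # CPython's PyThreadState_SetAsyncExc() schedules the given exception to be
+        # raised asynchronously in the thread with the given id and returns the number
+        # of thread states it modified; a result greater than 1 must be reverted by
+        # calling it again with the exception set to NULL (0), as done below.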
+ target_tid = self.inproc_runner.ident + if target_tid not in {t.ident for t in threading.enumerate()}: + log.error('Interrupt failed due to missing thread.') + return + affected_count = ctypes.pythonapi.PyThreadState_SetAsyncExc( + ctypes.c_long(target_tid), + ctypes.py_object(KeyboardInterrupt)) + if affected_count == 0: + log.error('Interrupt failed due to invalid thread identity.') + elif affected_count > 1: + ctypes.pythonapi.PyThreadState_SetAsyncExc( + ctypes.c_long(target_tid), + ctypes.c_long(0)) + log.error('Interrupt broke the interpreter state -- ' + 'recommended to reset the session.') + + async def start_service(self, service_info): + return None, {} + + def ensure_inproc_runner(self): + if self.inproc_runner is None: + self.inproc_runner = PollyInprocRunner( + self.input_queue.sync_q, + self.output_queue.sync_q, + self.sentinel, + self.access_key, + self.secret_key) + self.inproc_runner.start() diff --git a/src/ai/backend/kernel/vendor/aws_polly/inproc.py b/src/ai/backend/kernel/vendor/aws_polly/inproc.py new file mode 100644 index 0000000000..17cd8f1335 --- /dev/null +++ b/src/ai/backend/kernel/vendor/aws_polly/inproc.py @@ -0,0 +1,70 @@ +import base64 +import io +import json +import logging +import threading + +from boto3 import Session +from botocore.exceptions import BotoCoreError, ClientError + +log = logging.getLogger() + + +class PollyInprocRunner(threading.Thread): + + def __init__(self, input_queue, output_queue, sentinel, + access_key, secret_key): + super().__init__(name='InprocRunner', daemon=True) + + # for interoperability with the main asyncio loop + self.input_queue = input_queue + self.output_queue = output_queue + self.sentinel = sentinel + + self.session = Session( + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + ) + self.polly = self.session.client('polly') + + def run(self): + while True: + code_text = self.input_queue.get() + request = json.loads(code_text) + + content_type = 'application/octet-stream' + encoded_audio = '' + try: + response = self.polly.synthesize_speech( + Text=request.get('text'), + VoiceId=request.get('voiceId'), + TextType=request.get('textType', 'text'), + OutputFormat='ogg_vorbis', + ) + except (BotoCoreError, ClientError) as err: + self.output_queue.put([b'stderr', str(err).encode('utf8')]) + self.output_queue.put(self.sentinel) + self.input_queue.task_done() + continue + else: + content_type = response.get('ContentType').encode('ascii') + data_stream = response.get('AudioStream') + buffer = io.BytesIO() + while True: + chunk = data_stream.read(4096) + if not chunk: + break + buffer.write(chunk) + try: + encoded_audio = (b'data:%s;base64,' % content_type) + \ + base64.b64encode(buffer.getvalue()) + buffer.close() + except Exception as e: + log.error(str(e)) + + self.output_queue.put([ + b'media', + b'{"type":"%s","data":"%s"}' % (content_type, encoded_audio), + ]) + self.output_queue.put(self.sentinel) + self.input_queue.task_done() diff --git a/src/ai/backend/kernel/vendor/h2o/__init__.py b/src/ai/backend/kernel/vendor/h2o/__init__.py new file mode 100644 index 0000000000..0dcf257d84 --- /dev/null +++ b/src/ai/backend/kernel/vendor/h2o/__init__.py @@ -0,0 +1,98 @@ +import asyncio +import logging +import os +from pathlib import Path +import tempfile +from typing import List + +from ... 
import BaseRunner + +log = logging.getLogger() + +DEFAULT_PYFLAGS: List[str] = [] + + +class Runner(BaseRunner): + + log_prefix = 'h2o-kernel' + default_runtime_path = '/opt/h2oai/dai/python/bin/python' + default_child_env = { + 'TERM': 'xterm', + 'LANG': 'C.UTF-8', + 'SHELL': '/bin/bash', + 'USER': 'work', + 'HOME': '/home/work', + 'PATH': ':'.join([ + '/usr/local/nvidia/bin', + '/usr/local/cuda/bin', + '/usr/local/sbin', + '/usr/local/bin', + '/usr/sbin', + '/usr/bin', + '/sbin', + '/bin', + ]), + 'LD_LIBRARY_PATH': os.environ.get('LD_LIBRARY_PATH', ''), + 'LD_PRELOAD': os.environ.get('LD_PRELOAD', ''), + } + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + async def init_with_loop(self): + self.user_input_queue = asyncio.Queue() + + # Load H2O Daemon. + print('Daemonizing H2O (run-dai.sh)...') + Path('/opt/h2oai/dai').mkdir(parents=True, exist_ok=True) + cmd = ['/opt/h2oai/dai/run-dai.sh'] + await self.run_subproc(cmd) + + async def build_heuristic(self) -> int: + if Path('setup.py').is_file(): + cmd = [ + str(self.runtime_path), *DEFAULT_PYFLAGS, + '-m', 'pip', 'install', '--user', '-e', '.', + ] + return await self.run_subproc(cmd) + else: + log.warning('skipping the build phase due to missing "setup.py" file') + return 0 + + async def execute_heuristic(self) -> int: + if Path('main.py').is_file(): + cmd = [ + str(self.runtime_path), *DEFAULT_PYFLAGS, + 'main.py', + ] + return await self.run_subproc(cmd) + else: + log.error('cannot find the main script ("main.py").') + return 127 + + async def start_service(self, service_info): + if service_info['name'] in ['jupyter', 'jupyterlab']: + with tempfile.NamedTemporaryFile( + 'w', encoding='utf-8', suffix='.py', delete=False) as config: + print('c.NotebookApp.allow_root = True', file=config) + print('c.NotebookApp.ip = "0.0.0.0"', file=config) + print('c.NotebookApp.port = {}'.format(service_info['port']), file=config) + print('c.NotebookApp.token = ""', file=config) + print('c.FileContentsManager.delete_to_trash = False', file=config) + print('c.NotebookApp.tornado_settings = {\'ws_ping_interval\': 10000}', file=config) + jupyter_service_type = service_info['name'] + if jupyter_service_type == 'jupyter': + jupyter_service_type = 'notebook' + return [ + self.runtime_path, '-m', jupyter_service_type, + '--no-browser', + '--config', config.name, + ], {} + elif 'h2o' in service_info['name']: + return ['echo', 'h2o daemon already started'], {} + elif service_info['name'] == 'sftp': + return [ + self.runtime_path, + '-m', 'sftpserver', + '--port', str(service_info['port']), + ], {} diff --git a/src/ai/backend/manager/BUILD b/src/ai/backend/manager/BUILD new file mode 100644 index 0000000000..50445b95e2 --- /dev/null +++ b/src/ai/backend/manager/BUILD @@ -0,0 +1,75 @@ +python_sources( + name="service", + sources=["**/*.py"], + dependencies=[ + "src/ai/backend/cli:lib", + "src/ai/backend/common:lib", + ":resources", + ], +) + +pex_binary( + name="server", + dependencies=[ + ":service", + ], + entry_point="server.py", +) + +pex_binary( + name="cli", + dependencies=[ + ":service", + ], + entry_point="cli/__main__.py", +) + +pex_binary( + name="dump-gql-schema", + dependencies=[ + ":service", + ], + entry_point="api/admin.py", +) + +python_distribution( + name="dist", + dependencies=[ + ":service", + "!!stubs/trafaret:stubs", + ], + provides=python_artifact( + name="backend.ai-manager", + description="Backend.AI Manager", + license="LGPLv3", + ), + entry_points={ + "backendai_cli_v10": { + "mgr": 
"ai.backend.manager.cli.__main__:main", + "mgr.start-server": "ai.backend.manager.server:main", + }, + "backendai_scheduler_v10": { + "fifo": "ai.backend.manager.scheduler.fifo:FIFOSlotScheduler", + "lifo": "ai.backend.manager.scheduler.fifo:LIFOSlotScheduler", + "drf": "ai.backend.manager.scheduler.drf:DRFScheduler", + "mof": "ai.backend.manager.scheduler.mof:MOFScheduler", + }, + "backendai_error_monitor_v20": { + "intrinsic": "ai.backend.manager.plugin.error_monitor:ErrorMonitor", + }, + }, + generate_setup=True, + tags=["wheel"], +) + +resource(name="version", source="VERSION") + +resources( + name="resources", + dependencies=[ + ":version", + ], + sources=[ + "**/py.typed", + ], +) diff --git a/src/ai/backend/manager/README.md b/src/ai/backend/manager/README.md new file mode 100644 index 0000000000..9e78ecf926 --- /dev/null +++ b/src/ai/backend/manager/README.md @@ -0,0 +1,263 @@ +Backend.AI Manager with API Gateway +=================================== + +Package Structure +----------------- + +* `ai.backend.manager`: Computing resource and workload management with public APIs + +Installation +------------ + +Please visit [the installation guides](https://github.com/lablup/backend.ai/wiki). + + +### Kernel/system configuration + +#### Recommended resource limits: + +**`/etc/security/limits.conf`** +``` +root hard nofile 512000 +root soft nofile 512000 +root hard nproc 65536 +root soft nproc 65536 +user hard nofile 512000 +user soft nofile 512000 +user hard nproc 65536 +user soft nproc 65536 +``` + +**sysctl** +``` +fs.file-max=2048000 +net.core.somaxconn=1024 +net.ipv4.tcp_max_syn_backlog=1024 +net.ipv4.tcp_slow_start_after_idle=0 +net.ipv4.tcp_fin_timeout=10 +net.ipv4.tcp_window_scaling=1 +net.ipv4.tcp_tw_reuse=1 +net.ipv4.tcp_early_retrans=1 +net.ipv4.ip_local_port_range="10000 65000" +net.core.rmem_max=16777216 +net.core.wmem_max=16777216 +net.ipv4.tcp_rmem=4096 12582912 16777216 +net.ipv4.tcp_wmem=4096 12582912 16777216 +``` + + +### For development + +#### Prerequisites + +* `libnsappy-dev` or `snappy-devel` system package depending on your distro +* Python 3.6 or higher with [pyenv](https://github.com/pyenv/pyenv) +and [pyenv-virtualenv](https://github.com/pyenv/pyenv-virtualenv) (optional but recommneded) +* Docker 18.03 or later with docker-compose (18.09 or later is recommended) + +#### Common steps + +Clone [the meta repository](https://github.com/lablup/backend.ai) and install a "halfstack" +configuration. The halfstack configuration installs and runs several dependency daemons such as etcd in +the background. + +```console +$ git clone https://github.com/lablup/backend.ai halfstack +$ cd halfstack +$ docker-compose -f docker-compose.halfstack.yml up -d +``` + +Then prepare the source clone of the agent as follows. +First install the current working copy. + +```console +$ git clone https://github.com/lablup/backend.ai-manager manager +$ cd manager +$ pyenv virtualenv venv-manager +$ pyenv local venv-manager +$ pip install -U pip setuptools +$ pip install -U -r requirements/dev.txt +``` + +From now on, let's assume all shell commands are executed inside the virtualenv. 
+ +### Halfstack (single-node development & testing) + +#### Recommended directory structure + +* `backend.ai-dev` + - `manager` (git clone from this repo) + - `agent` (git clone from [the agent repo](https://github.com/lablup/backend.ai-agent)) + - `common` (git clone from [the common repo](https://github.com/lablup/backend.ai-common)) + +Install `backend.ai-common` as an editable package in the manager (and the agent) virtualenvs +to keep the codebase up-to-date. + +```console +$ cd manager +$ pip install -U -e ../common -r requirements/dev.txt +``` + +#### Steps + +Copy (or symlink) the halfstack configs: +```console +$ cp config/halfstack.toml ./manager.toml +$ cp config/halfstack.alembic.ini ./alembic.ini +``` + +Set up Redis: +```console +$ backend.ai mgr etcd put config/redis/addr 127.0.0.1:8110 +``` + +> ℹ️ NOTE: You may replace `backend.ai mgr` with `python -m ai.backend.manager.cli` in case your `PATH` is unmodifiable. + +Set up the public Docker registry: +```console +$ backend.ai mgr etcd put config/docker/registry/index.docker.io "https://registry-1.docker.io" +$ backend.ai mgr etcd put config/docker/registry/index.docker.io/username "lablup" +$ backend.ai mgr etcd rescan-images index.docker.io +``` + +Set up the vfolder paths: +```console +$ mkdir -p "$HOME/vfroot/local" +$ backend.ai mgr etcd put volumes/_mount "$HOME/vfroot" +$ backend.ai mgr etcd put volumes/_default_host local +``` + +Set up the allowed types of vfolder. Allowed values are "user" or "group". +If none is specified, "user" type is set implicitly: +```console +$ backend.ai mgr etcd put volumes/_types/user "" # enable user vfolder +$ backend.ai mgr etcd put volumes/_types/group "" # enable group vfolder +``` + +Set up the database: +```console +$ backend.ai mgr schema oneshot +$ backend.ai mgr fixture populate sample-configs/example-keypairs.json +$ backend.ai mgr fixture populate sample-configs/example-resource-presets.json +``` + +Then, run it (for debugging, append a `--debug` flag): + +```console +$ backend.ai mgr start-server +``` + +To run tests: + +```console +$ python -m flake8 src tests +$ python -m pytest -m 'not integration' tests +``` + +Now you are ready to install the agent. +Head to [the README of Backend.AI Agent](https://github.com/lablup/backend.ai-agent/blob/master/README.md). + +NOTE: To run tests including integration tests, you first need to install and run the agent on the same host. + +## Deployment + +### Configuration + +Put a TOML-formatted manager configuration (see the sample in `config/sample.toml`) +in one of the following locations: + + * `manager.toml` (current working directory) + * `~/.config/backend.ai/manager.toml` (user-config directory) + * `/etc/backend.ai/manager.toml` (system-config directory) + +Only the first found one is used by the daemon. + +Also many configurations shared by both manager and agent are stored in etcd. +As you might have noticed above, the manager provides a CLI interface to access and manipulate the etcd +data. Check out the help page of our etcd command set: + +```console +$ python -m ai.backend.manager.cli etcd --help +``` + +If you run etcd as a Docker container (e.g., via halfstack), you may use the native client as well. +In this case, PLEASE BE WARNED that you must prefix the keys with "/sorna/{namespace}" manaully: + +```console +$ docker exec -it ${ETCD_CONTAINER_ID} /bin/ash -c 'ETCDCTL_API=3 etcdctl ...' 
+```
+
+### Running from a command line
+
+The minimal command to execute:
+
+```sh
+python -m ai.backend.gateway.server
+```
+
+For more arguments and options, run the command with the `--help` option.
+
+### Writing a wrapper script
+
+To use with systemd, crontab, and other system-level daemons, you may need to write a shell script
+that executes specific CLI commands provided by Backend.AI modules.
+
+The following example shows how to set up pyenv and virtualenv for the script-local environment.
+It runs the gateway server if no arguments are given, and executes the given arguments as a shell
+command if any are given.
+For instance, you may get/set configurations like: `run-manager.sh python -m ai.backend.manager.etcd ...`
+where the name of the script is `run-manager.sh`.
+
+```bash
+#! /bin/bash
+if [ -z "$HOME" ]; then
+  export HOME="/home/devops"
+fi
+if [ -z "$PYENV_ROOT" ]; then
+  export PYENV_ROOT="$HOME/.pyenv"
+  export PATH="$PYENV_ROOT/bin:$PATH"
+fi
+eval "$(pyenv init -)"
+eval "$(pyenv virtualenv-init -)"
+pyenv activate venv-bai-manager
+
+if [ "$#" -eq 0 ]; then
+  exec python -m ai.backend.gateway.server
+else
+  exec "$@"
+fi
+```
+
+### Networking
+
+The manager and agent should run in the same local network or different
+networks reachable via VPNs, whereas the manager's API service must be exposed to
+the public network or another private network that users have access to.
+
+The manager requires access to etcd, the PostgreSQL database, and the Redis server.
+
+| User-to-Manager TCP Ports | Usage |
+|:-------------------------:|-------|
+| manager:{80,443} | Backend.AI API access |
+
+| Manager-to-X TCP Ports | Usage |
+|:----------------------:|-------|
+| etcd:2379 | etcd API access |
+| postgres:5432 | Database access |
+| redis:6379 | Redis API access |
+
+The manager must also be able to access TCP ports 6001, 6009, and 30000 to 31000 of the agents in default
+configurations. You can of course change those port numbers and ranges in the configuration.
+ +| Manager-to-Agent TCP Ports | Usage | +|:--------------------------:|-------| +| 6001 | ZeroMQ-based RPC calls from managers to agents | +| 6009 | HTTP watcher API | +| 30000-31000 | Port pool for in-container services | + + +LICENSES +-------- + +[GNU Lesser General Public License](https://github.com/lablup/backend.ai-manager/blob/master/LICENSE) +[Dependencies](https://github.com/lablup/backend.ai-manager/blob/master/DEPENDENCIES.md) diff --git a/src/ai/backend/manager/VERSION b/src/ai/backend/manager/VERSION new file mode 120000 index 0000000000..a4e948506b --- /dev/null +++ b/src/ai/backend/manager/VERSION @@ -0,0 +1 @@ +../../../../VERSION \ No newline at end of file diff --git a/src/ai/backend/manager/__init__.py b/src/ai/backend/manager/__init__.py new file mode 100644 index 0000000000..17b3552989 --- /dev/null +++ b/src/ai/backend/manager/__init__.py @@ -0,0 +1,3 @@ +from pathlib import Path + +__version__ = (Path(__file__).parent / 'VERSION').read_text().strip() diff --git a/src/ai/backend/manager/api/__init__.py b/src/ai/backend/manager/api/__init__.py new file mode 100644 index 0000000000..83bfffa83e --- /dev/null +++ b/src/ai/backend/manager/api/__init__.py @@ -0,0 +1,8 @@ +import enum + + +class ManagerStatus(str, enum.Enum): + TERMINATED = 'terminated' # deprecated + PREPARING = 'preparing' # deprecated + RUNNING = 'running' + FROZEN = 'frozen' diff --git a/src/ai/backend/manager/api/admin.py b/src/ai/backend/manager/api/admin.py new file mode 100644 index 0000000000..8aecf57c48 --- /dev/null +++ b/src/ai/backend/manager/api/admin.py @@ -0,0 +1,173 @@ +from __future__ import annotations + +import inspect +import logging +import re +from typing import ( + Any, + Iterable, + TYPE_CHECKING, + Tuple, +) + +from aiohttp import web +import aiohttp_cors +import attr +import graphene +from graphql.execution.executors.asyncio import AsyncioExecutor +from graphql.execution import ExecutionResult +from graphql.error import GraphQLError, format_error +import trafaret as t + +from ai.backend.common.logging import BraceStyleAdapter +from ai.backend.common import validators as tx + +from ..models.base import DataLoaderManager +from ..models.gql import ( + Mutations, Queries, + GraphQueryContext, + GQLMutationPrivilegeCheckMiddleware, +) +from .manager import GQLMutationUnfrozenRequiredMiddleware +from .exceptions import GraphQLError as BackendGQLError +from .auth import auth_required +from .types import CORSOptions, WebMiddleware +from .utils import check_api_params +if TYPE_CHECKING: + from .context import RootContext + +log = BraceStyleAdapter(logging.getLogger(__name__)) + +_rx_mutation_hdr = re.compile(r"^mutation(\s+\w+)?\s*(\(|{|@)", re.M) + + +class GQLLoggingMiddleware: + + def resolve(self, next, root, info: graphene.ResolveInfo, **args) -> Any: + graph_ctx: GraphQueryContext = info.context + if len(info.path) == 1: + log.info('ADMIN.GQL (ak:{}, {}:{}, op:{})', + graph_ctx.access_key, + info.operation.operation, + info.field_name, + info.operation.name) + return next(root, info, **args) + + +async def _handle_gql_common(request: web.Request, params: Any) -> ExecutionResult: + root_ctx: RootContext = request.app['_root.context'] + app_ctx: PrivateContext = request.app['admin.context'] + manager_status = await root_ctx.shared_config.get_manager_status() + known_slot_types = await root_ctx.shared_config.get_resource_slots() + + gql_ctx = GraphQueryContext( + schema=app_ctx.gql_schema, + dataloader_manager=DataLoaderManager(), + local_config=root_ctx.local_config, + 
shared_config=root_ctx.shared_config, + etcd=root_ctx.shared_config.etcd, + user=request['user'], + access_key=request['keypair']['access_key'], + db=root_ctx.db, + redis_stat=root_ctx.redis_stat, + redis_image=root_ctx.redis_image, + manager_status=manager_status, + known_slot_types=known_slot_types, + background_task_manager=root_ctx.background_task_manager, + storage_manager=root_ctx.storage_manager, + registry=root_ctx.registry, + ) + result = app_ctx.gql_schema.execute( + params['query'], + app_ctx.gql_executor, + variable_values=params['variables'], + operation_name=params['operation_name'], + context_value=gql_ctx, + middleware=[ + GQLLoggingMiddleware(), + GQLMutationUnfrozenRequiredMiddleware(), + GQLMutationPrivilegeCheckMiddleware(), + ], + return_promise=True) + if inspect.isawaitable(result): + result = await result + return result + + +@auth_required +@check_api_params( + t.Dict({ + t.Key('query'): t.String, + t.Key('variables', default=None): t.Null | t.Mapping(t.String, t.Any), + tx.AliasedKey(['operation_name', 'operationName'], default=None): t.Null | t.String, + })) +async def handle_gql(request: web.Request, params: Any) -> web.Response: + result = await _handle_gql_common(request, params) + return web.json_response(result.to_dict(), status=200) + + +@auth_required +@check_api_params( + t.Dict({ + t.Key('query'): t.String, + t.Key('variables', default=None): t.Null | t.Mapping(t.String, t.Any), + tx.AliasedKey(['operation_name', 'operationName'], default=None): t.Null | t.String, + })) +async def handle_gql_legacy(request: web.Request, params: Any) -> web.Response: + # FIXME: remove in v21.09 + result = await _handle_gql_common(request, params) + if result.errors: + errors = [] + for e in result.errors: + if isinstance(e, GraphQLError): + errmsg = format_error(e) + errors.append(errmsg) + else: + errmsg = {'message': str(e)} + errors.append(errmsg) + log.error('ADMIN.GQL Exception: {}', errmsg) + raise BackendGQLError(extra_data=errors) + return web.json_response(result.data, status=200) + + +@attr.s(auto_attribs=True, slots=True, init=False) +class PrivateContext: + gql_executor: AsyncioExecutor + gql_schema: graphene.Schema + + +async def init(app: web.Application) -> None: + app_ctx: PrivateContext = app['admin.context'] + app_ctx.gql_executor = AsyncioExecutor() + app_ctx.gql_schema = graphene.Schema( + query=Queries, + mutation=Mutations, + auto_camelcase=False, + ) + + +async def shutdown(app: web.Application) -> None: + pass + + +def create_app(default_cors_options: CORSOptions) -> Tuple[web.Application, Iterable[WebMiddleware]]: + app = web.Application() + app.on_startup.append(init) + app.on_shutdown.append(shutdown) + app['admin.context'] = PrivateContext() + cors = aiohttp_cors.setup(app, defaults=default_cors_options) + cors.add(app.router.add_route('POST', r'/graphql', handle_gql_legacy)) + cors.add(app.router.add_route('POST', r'/gql', handle_gql)) + return app, [] + + +if __name__ == '__main__': + # If executed as a main program, print all GraphQL schemas. + # (graphene transforms our object model into a textual representation) + # This is useful for writing documentation! 
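+    # For example, running `python -m ai.backend.manager.api.admin` prints the full
+    # schema definition to stdout.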
+ schema = graphene.Schema( + query=Queries, + mutation=Mutations, + auto_camelcase=False) + print('======== GraphQL API Schema ========') + print(str(schema)) diff --git a/src/ai/backend/manager/api/auth.py b/src/ai/backend/manager/api/auth.py new file mode 100644 index 0000000000..0f58916eab --- /dev/null +++ b/src/ai/backend/manager/api/auth.py @@ -0,0 +1,981 @@ +from collections import ChainMap +from datetime import datetime, timedelta +import functools +import hashlib, hmac +import logging +import secrets +from typing import ( + Any, + Final, + Iterable, + Mapping, + TYPE_CHECKING, + Tuple, + cast, +) + +from aiohttp import web +import aiohttp_cors +from aioredis import Redis +from aioredis.client import Pipeline as RedisPipeline +from dateutil.tz import tzutc +from dateutil.parser import parse as dtparse +import sqlalchemy as sa +import trafaret as t + +from ai.backend.common import redis, validators as tx +from ai.backend.common.logging import BraceStyleAdapter +from ai.backend.common.plugin.hook import ( + ALL_COMPLETED, + FIRST_COMPLETED, + PASSED, +) + +from ..models import ( + keypairs, keypair_resource_policies, users, +) +from ..models.user import UserRole, UserStatus, INACTIVE_USER_STATUSES, check_credential +from ..models.keypair import generate_keypair as _gen_keypair, generate_ssh_keypair +from ..models.group import association_groups_users, groups +from ..models.utils import execute_with_retry +from .exceptions import ( + AuthorizationFailed, + GenericBadRequest, + GenericForbidden, + ObjectNotFound, + InternalServerError, + InvalidAuthParameters, + InvalidAPIParameters, + RejectedByHook, +) +from .types import CORSOptions, WebMiddleware +from .utils import check_api_params, set_handler_attr, get_handler_attr + +if TYPE_CHECKING: + from .context import RootContext + +log: Final = BraceStyleAdapter(logging.getLogger(__name__)) + +whois_timezone_info: Final = { + "A": 1 * 3600, + "ACDT": 10.5 * 3600, + "ACST": 9.5 * 3600, + "ACT": -5 * 3600, + "ACWST": 8.75 * 3600, + "ADT": 4 * 3600, + "AEDT": 11 * 3600, + "AEST": 10 * 3600, + "AET": 10 * 3600, + "AFT": 4.5 * 3600, + "AKDT": -8 * 3600, + "AKST": -9 * 3600, + "ALMT": 6 * 3600, + "AMST": -3 * 3600, + "AMT": -4 * 3600, + "ANAST": 12 * 3600, + "ANAT": 12 * 3600, + "AQTT": 5 * 3600, + "ART": -3 * 3600, + "AST": 3 * 3600, + "AT": -4 * 3600, + "AWDT": 9 * 3600, + "AWST": 8 * 3600, + "AZOST": 0 * 3600, + "AZOT": -1 * 3600, + "AZST": 5 * 3600, + "AZT": 4 * 3600, + "AoE": -12 * 3600, + "B": 2 * 3600, + "BNT": 8 * 3600, + "BOT": -4 * 3600, + "BRST": -2 * 3600, + "BRT": -3 * 3600, + "BST": 6 * 3600, + "BTT": 6 * 3600, + "C": 3 * 3600, + "CAST": 8 * 3600, + "CAT": 2 * 3600, + "CCT": 6.5 * 3600, + "CDT": -5 * 3600, + "CEST": 2 * 3600, + "CET": 1 * 3600, + "CHADT": 13.75 * 3600, + "CHAST": 12.75 * 3600, + "CHOST": 9 * 3600, + "CHOT": 8 * 3600, + "CHUT": 10 * 3600, + "CIDST": -4 * 3600, + "CIST": -5 * 3600, + "CKT": -10 * 3600, + "CLST": -3 * 3600, + "CLT": -4 * 3600, + "COT": -5 * 3600, + "CST": -6 * 3600, + "CT": -6 * 3600, + "CVT": -1 * 3600, + "CXT": 7 * 3600, + "ChST": 10 * 3600, + "D": 4 * 3600, + "DAVT": 7 * 3600, + "DDUT": 10 * 3600, + "E": 5 * 3600, + "EASST": -5 * 3600, + "EAST": -6 * 3600, + "EAT": 3 * 3600, + "ECT": -5 * 3600, + "EDT": -4 * 3600, + "EEST": 3 * 3600, + "EET": 2 * 3600, + "EGST": 0 * 3600, + "EGT": -1 * 3600, + "EST": -5 * 3600, + "ET": -5 * 3600, + "F": 6 * 3600, + "FET": 3 * 3600, + "FJST": 13 * 3600, + "FJT": 12 * 3600, + "FKST": -3 * 3600, + "FKT": -4 * 3600, + "FNT": -2 * 3600, + "G": 7 * 3600, + "GALT": -6 
* 3600, + "GAMT": -9 * 3600, + "GET": 4 * 3600, + "GFT": -3 * 3600, + "GILT": 12 * 3600, + "GMT": 0 * 3600, + "GST": 4 * 3600, + "GYT": -4 * 3600, + "H": 8 * 3600, + "HDT": -9 * 3600, + "HKT": 8 * 3600, + "HOVST": 8 * 3600, + "HOVT": 7 * 3600, + "HST": -10 * 3600, + "I": 9 * 3600, + "ICT": 7 * 3600, + "IDT": 3 * 3600, + "IOT": 6 * 3600, + "IRDT": 4.5 * 3600, + "IRKST": 9 * 3600, + "IRKT": 8 * 3600, + "IRST": 3.5 * 3600, + "IST": 5.5 * 3600, + "JST": 9 * 3600, + "K": 10 * 3600, + "KGT": 6 * 3600, + "KOST": 11 * 3600, + "KRAST": 8 * 3600, + "KRAT": 7 * 3600, + "KST": 9 * 3600, + "KUYT": 4 * 3600, + "L": 11 * 3600, + "LHDT": 11 * 3600, + "LHST": 10.5 * 3600, + "LINT": 14 * 3600, + "M": 12 * 3600, + "MAGST": 12 * 3600, + "MAGT": 11 * 3600, + "MART": 9.5 * 3600, + "MAWT": 5 * 3600, + "MDT": -6 * 3600, + "MHT": 12 * 3600, + "MMT": 6.5 * 3600, + "MSD": 4 * 3600, + "MSK": 3 * 3600, + "MST": -7 * 3600, + "MT": -7 * 3600, + "MUT": 4 * 3600, + "MVT": 5 * 3600, + "MYT": 8 * 3600, + "N": -1 * 3600, + "NCT": 11 * 3600, + "NDT": 2.5 * 3600, + "NFT": 11 * 3600, + "NOVST": 7 * 3600, + "NOVT": 7 * 3600, + "NPT": 5.5 * 3600, + "NRT": 12 * 3600, + "NST": 3.5 * 3600, + "NUT": -11 * 3600, + "NZDT": 13 * 3600, + "NZST": 12 * 3600, + "O": -2 * 3600, + "OMSST": 7 * 3600, + "OMST": 6 * 3600, + "ORAT": 5 * 3600, + "P": -3 * 3600, + "PDT": -7 * 3600, + "PET": -5 * 3600, + "PETST": 12 * 3600, + "PETT": 12 * 3600, + "PGT": 10 * 3600, + "PHOT": 13 * 3600, + "PHT": 8 * 3600, + "PKT": 5 * 3600, + "PMDT": -2 * 3600, + "PMST": -3 * 3600, + "PONT": 11 * 3600, + "PST": -8 * 3600, + "PT": -8 * 3600, + "PWT": 9 * 3600, + "PYST": -3 * 3600, + "PYT": -4 * 3600, + "Q": -4 * 3600, + "QYZT": 6 * 3600, + "R": -5 * 3600, + "RET": 4 * 3600, + "ROTT": -3 * 3600, + "S": -6 * 3600, + "SAKT": 11 * 3600, + "SAMT": 4 * 3600, + "SAST": 2 * 3600, + "SBT": 11 * 3600, + "SCT": 4 * 3600, + "SGT": 8 * 3600, + "SRET": 11 * 3600, + "SRT": -3 * 3600, + "SST": -11 * 3600, + "SYOT": 3 * 3600, + "T": -7 * 3600, + "TAHT": -10 * 3600, + "TFT": 5 * 3600, + "TJT": 5 * 3600, + "TKT": 13 * 3600, + "TLT": 9 * 3600, + "TMT": 5 * 3600, + "TOST": 14 * 3600, + "TOT": 13 * 3600, + "TRT": 3 * 3600, + "TVT": 12 * 3600, + "U": -8 * 3600, + "ULAST": 9 * 3600, + "ULAT": 8 * 3600, + "UTC": 0 * 3600, + "UYST": -2 * 3600, + "UYT": -3 * 3600, + "UZT": 5 * 3600, + "V": -9 * 3600, + "VET": -4 * 3600, + "VLAST": 11 * 3600, + "VLAT": 10 * 3600, + "VOST": 6 * 3600, + "VUT": 11 * 3600, + "W": -10 * 3600, + "WAKT": 12 * 3600, + "WARST": -3 * 3600, + "WAST": 2 * 3600, + "WAT": 1 * 3600, + "WEST": 1 * 3600, + "WET": 0 * 3600, + "WFT": 12 * 3600, + "WGST": -2 * 3600, + "WGT": -3 * 3600, + "WIB": 7 * 3600, + "WIT": 9 * 3600, + "WITA": 8 * 3600, + "WST": 14 * 3600, + "WT": 0 * 3600, + "X": -11 * 3600, + "Y": -12 * 3600, + "YAKST": 10 * 3600, + "YAKT": 9 * 3600, + "YAPT": 10 * 3600, + "YEKST": 6 * 3600, + "YEKT": 5 * 3600, + "Z": 0 * 3600, +} + + +def _extract_auth_params(request): + """ + HTTP Authorization header must be formatted as: + "Authorization: BackendAI signMethod=HMAC-SHA256, + credential=:" + """ + auth_hdr = request.headers.get('Authorization') + if not auth_hdr: + return None + pieces = auth_hdr.split(' ', 1) + if len(pieces) != 2: + raise InvalidAuthParameters('Malformed authorization header') + auth_type, auth_str = pieces + if auth_type not in ('BackendAI', 'Sorna'): + raise InvalidAuthParameters('Invalid authorization type name') + + raw_params = map(lambda s: s.strip(), auth_str.split(',')) + params = {} + for param in raw_params: + key, value = param.split('=', 1) + 
params[key.strip()] = value.strip() + + try: + access_key, signature = params['credential'].split(':', 1) + ret = params['signMethod'], access_key, signature + return ret + except (KeyError, ValueError): + raise InvalidAuthParameters('Missing or malformed authorization parameters') + + +def check_date(request: web.Request) -> bool: + raw_date = request.headers.get('Date') + if not raw_date: + raw_date = request.headers.get('X-BackendAI-Date', + request.headers.get('X-Sorna-Date')) + if not raw_date: + return False + try: + # HTTP standard says "Date" header must be in GMT only. + # However, dateutil.parser can recognize other commonly used + # timezone names and offsets. + date = dtparse(raw_date, tzinfos=whois_timezone_info) + if date.tzinfo is None: + date = date.replace(tzinfo=tzutc()) # assume as UTC + now = datetime.now(tzutc()) + min_date = now - timedelta(minutes=15) + max_date = now + timedelta(minutes=15) + request['date'] = date + request['raw_date'] = raw_date + if not (min_date < date < max_date): + return False + except ValueError: + return False + return True + + +async def sign_request(sign_method: str, request: web.Request, secret_key: str) -> str: + try: + mac_type, hash_type = map(lambda s: s.lower(), sign_method.split('-')) + assert mac_type == 'hmac', 'Unsupported request signing method (MAC type)' + assert hash_type in hashlib.algorithms_guaranteed, \ + 'Unsupported request signing method (hash type)' + + new_api_version = request.headers.get('X-BackendAI-Version') + legacy_api_version = request.headers.get('X-Sorna-Version') + api_version = new_api_version or legacy_api_version + assert api_version is not None, 'API version missing in request headers' + body = b'' + if api_version < 'v4.20181215': + if (request.can_read_body and + request.content_type != 'multipart/form-data'): + # read the whole body if neither streaming nor bodyless + body = await request.read() + body_hash = hashlib.new(hash_type, body).hexdigest() + + sign_bytes = ('{0}\n{1}\n{2}\nhost:{3}\ncontent-type:{4}\n' + 'x-{name}-version:{5}\n{6}').format( + request.method, str(request.raw_path), request['raw_date'], + request.host, request.content_type, api_version, + body_hash, + name='backendai' if new_api_version is not None else 'sorna', + ).encode() + sign_key = hmac.new(secret_key.encode(), + request['date'].strftime('%Y%m%d').encode(), + hash_type).digest() + sign_key = hmac.new(sign_key, request.host.encode(), hash_type).digest() + return hmac.new(sign_key, sign_bytes, hash_type).hexdigest() + except ValueError: + raise AuthorizationFailed('Invalid signature') + except AssertionError as e: + raise InvalidAuthParameters(e.args[0]) + + +@web.middleware +async def auth_middleware(request: web.Request, handler) -> web.StreamResponse: + """ + Fetches user information and sets up keypair, uesr, and is_authorized + attributes. + """ + # This is a global middleware: request.app is the root app. + root_ctx: RootContext = request.app['_root.context'] + request['is_authorized'] = False + request['is_admin'] = False + request['is_superadmin'] = False + request['keypair'] = None + request['user'] = None + if not get_handler_attr(request, 'auth_required', False): + return (await handler(request)) + if not check_date(request): + raise InvalidAuthParameters('Date/time sync error') + + # PRE_AUTH_MIDDLEWARE allows authentication via 3rd-party request headers/cookies. + # Any responsible hook must return a valid keypair. 
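+    # When a hook passes, its result is interpreted as the access key of the
+    # authenticated keypair (or None to allow anonymous access).  When no hook is
+    # installed, the signature-based authentication below is performed instead.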
+ hook_result = await root_ctx.hook_plugin_ctx.dispatch( + 'PRE_AUTH_MIDDLEWARE', + (request,), + return_when=FIRST_COMPLETED, + ) + row = None + if hook_result.status != PASSED: + raise RejectedByHook.from_hook_result(hook_result) + elif hook_result.result: + # Passed one of the hook. + # The "None" access_key means that the hook has allowed anonymous access. + access_key = hook_result.result + if access_key is not None: + async def _query_cred(): + async with root_ctx.db.begin_readonly() as conn: + j = ( + keypairs + .join(users, keypairs.c.user == users.c.uuid) + .join( + keypair_resource_policies, + keypairs.c.resource_policy == keypair_resource_policies.c.name, + ) + ) + query = ( + sa.select([users, keypairs, keypair_resource_policies], use_labels=True) + .select_from(j) + .where( + (keypairs.c.access_key == access_key) & + (keypairs.c.is_active.is_(True)), + ) + ) + result = await conn.execute(query) + return result.first() + + row = await execute_with_retry(_query_cred) + if row is None: + raise AuthorizationFailed('Access key not found') + + async def _pipe_builder(r: Redis) -> RedisPipeline: + pipe = r.pipeline() + num_queries_key = f'kp:{access_key}:num_queries' + pipe.incr(num_queries_key) + pipe.expire(num_queries_key, 86400 * 30) # retention: 1 month + return pipe + + await redis.execute(root_ctx.redis_stat, _pipe_builder) + else: + # unsigned requests may be still accepted for public APIs + pass + else: + # There were no hooks configured. + # Perform our own authentication. + params = _extract_auth_params(request) + if params: + sign_method, access_key, signature = params + + async def _query_cred(): + async with root_ctx.db.begin_readonly() as conn: + j = ( + keypairs + .join(users, keypairs.c.user == users.c.uuid) + .join(keypair_resource_policies, + keypairs.c.resource_policy == keypair_resource_policies.c.name) + ) + query = ( + sa.select([users, keypairs, keypair_resource_policies], use_labels=True) + .select_from(j) + .where( + (keypairs.c.access_key == access_key) & + (keypairs.c.is_active.is_(True)), + ) + ) + result = await conn.execute(query) + return result.first() + + row = await execute_with_retry(_query_cred) + if row is None: + raise AuthorizationFailed('Access key not found') + my_signature = \ + await sign_request(sign_method, request, row['keypairs_secret_key']) + if not secrets.compare_digest(my_signature, signature): + raise AuthorizationFailed('Signature mismatch') + + async def _pipe_builder(r: Redis) -> RedisPipeline: + pipe = r.pipeline() + num_queries_key = f'kp:{access_key}:num_queries' + pipe.incr(num_queries_key) + pipe.expire(num_queries_key, 86400 * 30) # retention: 1 month + return pipe + + await redis.execute(root_ctx.redis_stat, _pipe_builder) + else: + # unsigned requests may be still accepted for public APIs + pass + + if row is not None: + auth_result = { + 'is_authorized': True, + 'keypair': { + col.name: row[f'keypairs_{col.name}'] + for col in keypairs.c + if col.name != 'secret_key' + }, + 'user': { + col.name: row[f'users_{col.name}'] + for col in users.c + if col.name not in ('password', 'description', 'created_at') + }, + 'is_admin': row['keypairs_is_admin'], + } + auth_result['keypair']['resource_policy'] = { + col.name: row[f'keypair_resource_policies_{col.name}'] + for col in keypair_resource_policies.c + } + auth_result['user']['id'] = row['keypairs_user_id'] # legacy + auth_result['is_superadmin'] = (auth_result['user']['role'] == 'superadmin') + # Populate the result to the per-request state dict. 
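+        # (Handlers then read request['user'], request['keypair'], request['is_superadmin'], etc.)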
+ request.update(auth_result) + + # No matter if authenticated or not, pass-through to the handler. + # (if it's required, auth_required decorator will handle the situation.) + return (await handler(request)) + + +def auth_required(handler): + + @functools.wraps(handler) + async def wrapped(request, *args, **kwargs): + if request.get('is_authorized', False): + return (await handler(request, *args, **kwargs)) + raise AuthorizationFailed('Unauthorized access') + + set_handler_attr(wrapped, 'auth_required', True) + return wrapped + + +def admin_required(handler): + + @functools.wraps(handler) + async def wrapped(request, *args, **kwargs): + if request.get('is_authorized', False) and request.get('is_admin', False): + return (await handler(request, *args, **kwargs)) + raise AuthorizationFailed('Unauthorized access') + + set_handler_attr(wrapped, 'auth_required', True) + return wrapped + + +def superadmin_required(handler): + + @functools.wraps(handler) + async def wrapped(request, *args, **kwargs): + if request.get('is_authorized', False) and request.get('is_superadmin', False): + return (await handler(request, *args, **kwargs)) + raise AuthorizationFailed('Unauthorized access') + + set_handler_attr(wrapped, 'auth_required', True) + return wrapped + + +@auth_required +@check_api_params( + t.Dict({ + t.Key('echo'): t.String, + })) +async def test(request: web.Request, params: Any) -> web.Response: + log.info('AUTH.TEST(ak:{})', request['keypair']['access_key']) + resp_data = {'authorized': 'yes'} + if 'echo' in params: + resp_data['echo'] = params['echo'] + return web.json_response(resp_data) + + +@auth_required +@check_api_params( + t.Dict({ + t.Key('group', default=None): t.Null | tx.UUID, + })) +async def get_role(request: web.Request, params: Any) -> web.Response: + group_role = None + root_ctx: RootContext = request.app['_root.context'] + log.info( + 'AUTH.ROLES(ak:{}, d:{}, g:{})', + request['keypair']['access_key'], + request['user']['domain_name'], + params['group'], + ) + if params['group'] is not None: + query = ( + # TODO: per-group role is not yet implemented. + sa.select([association_groups_users.c.group_id]) + .select_from(association_groups_users) + .where( + (association_groups_users.c.group_id == params['group']) & + (association_groups_users.c.user_id == request['user']['uuid']), + ) + ) + async with root_ctx.db.begin() as conn: + result = await conn.execute(query) + row = result.first() + if row is None: + raise ObjectNotFound( + extra_msg='No such project or you are not the member of it.', + object_name='project (user group)', + ) + group_role = 'user' + resp_data = { + 'global_role': 'superadmin' if request['is_superadmin'] else 'user', + 'domain_role': 'admin' if request['is_admin'] else 'user', + 'group_role': group_role, + } + return web.json_response(resp_data) + + +@check_api_params( + t.Dict({ + t.Key('type'): t.Enum('keypair', 'jwt'), + t.Key('domain'): t.String, + t.Key('username'): t.String, + t.Key('password'): t.String, + })) +async def authorize(request: web.Request, params: Any) -> web.Response: + if params['type'] != 'keypair': + # other types are not implemented yet. 
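+        # ('jwt' passes the parameter validation above but is rejected here until implemented.)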
+ raise InvalidAPIParameters('Unsupported authorization type') + log.info('AUTH.AUTHORIZE(d:{0[domain]}, u:{0[username]}, passwd:****, type:{0[type]})', params) + root_ctx: RootContext = request.app['_root.context'] + + # [Hooking point for AUTHORIZE with the FIRST_COMPLETED requirement] + # The hook handlers should accept the whole ``params`` dict, and optional + # ``db`` parameter (if the hook needs to query to database). + # They should return a corresponding Backend.AI user object after performing + # their own authentication steps, like LDAP authentication, etc. + hook_result = await root_ctx.hook_plugin_ctx.dispatch( + 'AUTHORIZE', + (request, params), + return_when=FIRST_COMPLETED, + ) + if hook_result.status != PASSED: + raise RejectedByHook.from_hook_result(hook_result) + elif hook_result.result: + # Passed one of AUTHORIZED hook + user = hook_result.result + else: + # No AUTHORIZE hook is defined (proceed with normal login) + user = await check_credential( + root_ctx.db, + params['domain'], params['username'], params['password'], + ) + if user is None: + raise AuthorizationFailed('User credential mismatch.') + if user['status'] == UserStatus.BEFORE_VERIFICATION: + raise AuthorizationFailed('This account needs email verification.') + if user['status'] in INACTIVE_USER_STATUSES: + raise AuthorizationFailed('User credential mismatch.') + async with root_ctx.db.begin() as conn: + query = (sa.select([keypairs.c.access_key, keypairs.c.secret_key]) + .select_from(keypairs) + .where( + (keypairs.c.user == user['uuid']) & + (keypairs.c.is_active), + ) + .order_by(sa.desc(keypairs.c.is_admin))) + result = await conn.execute(query) + keypair = result.first() + if keypair is None: + raise AuthorizationFailed('No API keypairs found.') + # [Hooking point for POST_AUTHORIZE as one-way notification] + # The hook handlers should accept a tuple of the request, user, and keypair objects. + await root_ctx.hook_plugin_ctx.notify( + 'POST_AUTHORIZE', + (request, user, keypair), + ) + return web.json_response({ + 'data': { + 'access_key': keypair['access_key'], + 'secret_key': keypair['secret_key'], + 'role': user['role'], + 'status': user['status'], + }, + }) + + +@check_api_params( + t.Dict({ + t.Key('domain'): t.String, + t.Key('email'): t.String, + t.Key('password'): t.String, + }).allow_extra('*')) +async def signup(request: web.Request, params: Any) -> web.Response: + log_fmt = 'AUTH.SIGNUP(d:{}, email:{}, passwd:****)' + log_args = (params['domain'], params['email']) + log.info(log_fmt, *log_args) + root_ctx: RootContext = request.app['_root.context'] + + # [Hooking point for PRE_SIGNUP with the ALL_COMPLETED requirement] + # The hook handlers should accept the whole ``params`` dict. + # They should return a dict to override the user information, + # where the keys must be a valid field name of the users table, + # with two exceptions: "resource_policy" (name) and "group" (name). + # A plugin may return an empty dict if it has nothing to override. + hook_result = await root_ctx.hook_plugin_ctx.dispatch( + 'PRE_SIGNUP', + (params, ), + return_when=ALL_COMPLETED, + ) + if hook_result.status != PASSED: + raise RejectedByHook.from_hook_result(hook_result) + else: + # Merge the hook results as a single map. + user_data_overriden = ChainMap(*cast(Mapping, hook_result.result)) + + async with root_ctx.db.begin() as conn: + # Check if email already exists. 
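+        # (If a matching row is found, the signup request is rejected below.)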
+ query = (sa.select([users]) + .select_from(users) + .where((users.c.email == params['email']))) + result = await conn.execute(query) + row = result.first() + if row is not None: + raise GenericBadRequest('Email already exists') + + # Create a user. + data = { + 'domain_name': params['domain'], + 'username': params['username'] if 'username' in params else params['email'], + 'email': params['email'], + 'password': params['password'], + 'need_password_change': False, + 'full_name': params['full_name'] if 'full_name' in params else '', + 'description': params['description'] if 'description' in params else '', + 'status': UserStatus.ACTIVE, + 'status_info': 'user-signup', + 'role': UserRole.USER, + 'integration_id': None, + } + if user_data_overriden: + for key, val in user_data_overriden.items(): + if key in data: # take only valid fields + data[key] = val + query = (users.insert().values(data)) + result = await conn.execute(query) + if result.rowcount > 0: + checkq = users.select().where(users.c.email == params['email']) + result = await conn.execute(checkq) + user = result.first() + # Create user's first access_key and secret_key. + ak, sk = _gen_keypair() + resource_policy = ( + user_data_overriden.get('resource_policy', 'default') + ) + kp_data = { + 'user_id': params['email'], + 'access_key': ak, + 'secret_key': sk, + 'is_active': True if data.get('status') == UserStatus.ACTIVE else False, + 'is_admin': False, + 'resource_policy': resource_policy, + 'rate_limit': 1000, + 'num_queries': 0, + 'user': user.uuid, + } + query = (keypairs.insert().values(kp_data)) + await conn.execute(query) + + # Add user to the default group. + group_name = user_data_overriden.get('group', 'default') + query = (sa.select([groups.c.id]) + .select_from(groups) + .where(groups.c.domain_name == params['domain']) + .where(groups.c.name == group_name)) + result = await conn.execute(query) + grp = result.first() + if grp is not None: + values = [{'user_id': user.uuid, 'group_id': grp.id}] + query = association_groups_users.insert().values(values) + await conn.execute(query) + else: + raise InternalServerError('Error creating user account') + + resp_data = { + 'access_key': ak, + 'secret_key': sk, + } + + # [Hooking point for POST_SIGNUP as one-way notification] + # The hook handlers should accept a tuple of the user email, + # the new user's UUID, and a dict with initial user's preferences. + initial_user_prefs = { + 'lang': request.headers.get('Accept-Language', 'en-us').split(',')[0].lower(), + } + await root_ctx.hook_plugin_ctx.notify( + 'POST_SIGNUP', + (params['email'], user.uuid, initial_user_prefs), + ) + return web.json_response(resp_data, status=201) + + +@auth_required +@check_api_params( + t.Dict({ + tx.AliasedKey(['email', 'username']): t.String, + t.Key('password'): t.String, + })) +async def signout(request: web.Request, params: Any) -> web.Response: + domain_name = request['user']['domain_name'] + log.info('AUTH.SIGNOUT(d:{}, email:{})', domain_name, params['email']) + root_ctx: RootContext = request.app['_root.context'] + if request['user']['email'] != params['email']: + raise GenericForbidden('Not the account owner') + result = await check_credential( + root_ctx.db, + domain_name, params['email'], params['password']) + if result is None: + raise GenericBadRequest('Invalid email and/or password') + async with root_ctx.db.begin() as conn: + # Inactivate the user. 
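+        # (Sign-out is a soft delete: the user row and all of the user's keypairs
+        #  are marked inactive rather than removed.)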
+ query = ( + users.update() + .values(status=UserStatus.INACTIVE) + .where(users.c.email == params['email']) + ) + await conn.execute(query) + # Inactivate every keypairs of the user. + query = ( + keypairs.update() + .values(is_active=False) + .where(keypairs.c.user_id == params['email']) + ) + await conn.execute(query) + return web.json_response({}) + + +@auth_required +@check_api_params( + t.Dict({ + t.Key('email'): t.String, + t.Key('full_name'): t.String, + })) +async def update_full_name(request: web.Request, params: Any) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + domain_name = request['user']['domain_name'] + email = request['user']['email'] + log_fmt = 'AUTH.UPDATE_FULL_NAME(d:{}, email:{})' + log_args = (domain_name, email) + log.info(log_fmt, *log_args) + async with root_ctx.db.begin() as conn: + query = ( + sa.select([users]) + .select_from(users) + .where( + (users.c.email == email) & + (users.c.domain_name == domain_name), + ) + ) + result = await conn.execute(query) + user = result.first() + if user is None: + log.info(log_fmt + ': Unknown user', *log_args) + return web.json_response({'error_msg': 'Unknown user'}, status=400) + + # If user is not null, then it updates user full_name. + data = { + 'full_name': params['full_name'], + } + update_query = (users.update().values(data).where(users.c.email == email)) + await conn.execute(update_query) + return web.json_response({}, status=200) + + +@auth_required +@check_api_params( + t.Dict({ + t.Key('old_password'): t.String, + t.Key('new_password'): t.String, + t.Key('new_password2'): t.String, + })) +async def update_password(request: web.Request, params: Any) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + domain_name = request['user']['domain_name'] + email = request['user']['email'] + log_fmt = 'AUTH.UDPATE_PASSWORD(d:{}, email:{})' + log_args = (domain_name, email) + log.info(log_fmt, *log_args) + + user = await check_credential(root_ctx.db, domain_name, email, params['old_password']) + if user is None: + log.info(log_fmt + ': old password mismtach', *log_args) + raise AuthorizationFailed('Old password mismatch') + if params['new_password'] != params['new_password2']: + log.info(log_fmt + ': new password mismtach', *log_args) + return web.json_response({'error_msg': 'new password mismitch'}, status=400) + + # [Hooking point for VERIFY_PASSWORD_FORMAT with the ALL_COMPLETED requirement] + # The hook handlers should accept the old password and the new password and implement their + # own password validation rules. + # They should return None if the validation is successful and raise the Reject error + # otherwise. + hook_result = await root_ctx.hook_plugin_ctx.dispatch( + 'VERIFY_PASSWORD_FORMAT', + (params['old_password'], params['new_password']), + return_when=ALL_COMPLETED, + ) + if hook_result.status != PASSED: + hook_result.reason = hook_result.reason or 'invalid password format' + raise RejectedByHook.from_hook_result(hook_result) + + async with root_ctx.db.begin() as conn: + # Update user password. 
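+        # (need_password_change is cleared since the user has just chosen a new password.)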
+ data = { + 'password': params['new_password'], + 'need_password_change': False, + } + query = (users.update().values(data).where(users.c.email == email)) + await conn.execute(query) + return web.json_response({}, status=200) + + +@auth_required +async def get_ssh_keypair(request: web.Request) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + domain_name = request['user']['domain_name'] + access_key = request['keypair']['access_key'] + log_fmt = 'AUTH.GET_SSH_KEYPAIR(d:{}, ak:{})' + log_args = (domain_name, access_key) + log.info(log_fmt, *log_args) + async with root_ctx.db.begin() as conn: + # Get SSH public key. Return partial string from the public key just for checking. + query = ( + sa.select([keypairs.c.ssh_public_key]) + .where(keypairs.c.access_key == access_key) + ) + pubkey = await conn.scalar(query) + return web.json_response({'ssh_public_key': pubkey}, status=200) + + +@auth_required +async def refresh_ssh_keypair(request: web.Request) -> web.Response: + domain_name = request['user']['domain_name'] + access_key = request['keypair']['access_key'] + log_fmt = 'AUTH.REFRESH_SSH_KEYPAIR(d:{}, ak:{})' + log_args = (domain_name, access_key) + log.info(log_fmt, *log_args) + root_ctx: RootContext = request.app['_root.context'] + async with root_ctx.db.begin() as conn: + pubkey, privkey = generate_ssh_keypair() + data = { + 'ssh_public_key': pubkey, + 'ssh_private_key': privkey, + } + query = ( + keypairs.update() + .values(data) + .where(keypairs.c.access_key == access_key) + ) + await conn.execute(query) + return web.json_response(data, status=200) + + +def create_app(default_cors_options: CORSOptions) -> Tuple[web.Application, Iterable[WebMiddleware]]: + app = web.Application() + app['prefix'] = 'auth' # slashed to distinguish with "/vN/authorize" + app['api_versions'] = (1, 2, 3, 4) + cors = aiohttp_cors.setup(app, defaults=default_cors_options) + root_resource = cors.add(app.router.add_resource(r'')) + cors.add(root_resource.add_route('GET', test)) + cors.add(root_resource.add_route('POST', test)) + test_resource = cors.add(app.router.add_resource('/test')) + cors.add(test_resource.add_route('GET', test)) + cors.add(test_resource.add_route('POST', test)) + cors.add(app.router.add_route('POST', '/authorize', authorize)) + cors.add(app.router.add_route('GET', '/role', get_role)) + cors.add(app.router.add_route('POST', '/signup', signup)) + cors.add(app.router.add_route('POST', '/signout', signout)) + cors.add(app.router.add_route('POST', '/update-password', update_password)) + cors.add(app.router.add_route('POST', '/update-full-name', update_full_name)) + cors.add(app.router.add_route('GET', '/ssh-keypair', get_ssh_keypair)) + cors.add(app.router.add_route('PATCH', '/ssh-keypair', refresh_ssh_keypair)) + return app, [auth_middleware] diff --git a/src/ai/backend/manager/api/cluster_template.py b/src/ai/backend/manager/api/cluster_template.py new file mode 100644 index 0000000000..05a83db50c --- /dev/null +++ b/src/ai/backend/manager/api/cluster_template.py @@ -0,0 +1,399 @@ +import json +import logging +from typing import ( + Any, + List, + Mapping, + TYPE_CHECKING, + Tuple, +) +import uuid + +from aiohttp import web +import aiohttp_cors +import sqlalchemy as sa +import trafaret as t +import yaml + +from ai.backend.common import validators as tx +from ai.backend.common.logging import BraceStyleAdapter + +from ..models import ( + association_groups_users as agus, domains, + groups, session_templates, keypairs, users, UserRole, + 
query_accessible_session_templates, TemplateType, +) +from ..models.session_template import check_cluster_template +from .auth import auth_required +from .exceptions import InvalidAPIParameters, TaskTemplateNotFound +from .manager import READ_ALLOWED, server_status_required +from .types import CORSOptions, Iterable, WebMiddleware +from .utils import check_api_params, get_access_key_scopes + +if TYPE_CHECKING: + from .context import RootContext + +log = BraceStyleAdapter(logging.getLogger(__name__)) + + +@server_status_required(READ_ALLOWED) +@auth_required +@check_api_params(t.Dict( + { + tx.AliasedKey(['group', 'groupName', 'group_name'], default='default'): t.String, + tx.AliasedKey(['domain', 'domainName', 'domain_name'], default='default'): t.String, + t.Key('owner_access_key', default=None): t.Null | t.String, + t.Key('payload'): t.String, + }, +)) +async def create(request: web.Request, params: Any) -> web.Response: + if params['domain'] is None: + params['domain'] = request['user']['domain_name'] + requester_access_key, owner_access_key = await get_access_key_scopes(request, params) + requester_uuid = request['user']['uuid'] + log.info( + 'CLUSTER_TEMPLATE.CREATE (ak:{0}/{1})', requester_access_key, + owner_access_key if owner_access_key != requester_access_key else '*', + ) + user_uuid = request['user']['uuid'] + + root_ctx: RootContext = request.app['_root.context'] + + async with root_ctx.db.begin() as conn: + if requester_access_key != owner_access_key: + # Admin or superadmin is creating sessions for another user. + # The check for admin privileges is already done in get_access_key_scope(). + query = ( + sa.select([keypairs.c.user, users.c.role, users.c.domain_name]) + .select_from(sa.join(keypairs, users, keypairs.c.user == users.c.uuid)) + .where(keypairs.c.access_key == owner_access_key) + ) + result = await conn.execute(query) + row = result.first() + owner_domain = row['domain_name'] + owner_uuid = row['user'] + owner_role = row['role'] + else: + # Normal case when the user is creating her/his own session. + owner_domain = request['user']['domain_name'] + owner_uuid = requester_uuid + owner_role = UserRole.USER + + query = ( + sa.select([domains.c.name]) + .select_from(domains) + .where( + (domains.c.name == owner_domain) & + (domains.c.is_active), + ) + ) + qresult = await conn.execute(query) + domain_name = qresult.scalar() + if domain_name is None: + raise InvalidAPIParameters('Invalid domain') + + if owner_role == UserRole.SUPERADMIN: + # superadmin can spawn container in any designated domain/group. + query = ( + sa.select([groups.c.id]) + .select_from(groups) + .where( + (groups.c.domain_name == params['domain']) & + (groups.c.name == params['group']) & + (groups.c.is_active), + )) + qresult = await conn.execute(query) + group_id = qresult.scalar() + elif owner_role == UserRole.ADMIN: + # domain-admin can spawn container in any group in the same domain. + if params['domain'] != owner_domain: + raise InvalidAPIParameters("You can only set the domain to the owner's domain.") + query = ( + sa.select([groups.c.id]) + .select_from(groups) + .where( + (groups.c.domain_name == owner_domain) & + (groups.c.name == params['group']) & + (groups.c.is_active), + )) + qresult = await conn.execute(query) + group_id = qresult.scalar() + else: + # normal users can spawn containers in their group and domain. 
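+            # (Membership is verified via the association_groups_users join below.)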
+ if params['domain'] != owner_domain: + raise InvalidAPIParameters("You can only set the domain to your domain.") + query = ( + sa.select([agus.c.group_id]) + .select_from(agus.join(groups, agus.c.group_id == groups.c.id)) + .where( + (agus.c.user_id == owner_uuid) & + (groups.c.domain_name == owner_domain) & + (groups.c.name == params['group']) & + (groups.c.is_active), + )) + qresult = await conn.execute(query) + group_id = qresult.scalar() + if group_id is None: + raise InvalidAPIParameters('Invalid group') + + log.debug('Params: {0}', params) + try: + body = json.loads(params['payload']) + except json.JSONDecodeError: + try: + body = yaml.safe_load(params['payload']) + except (yaml.YAMLError, yaml.MarkedYAMLError): + raise InvalidAPIParameters('Malformed payload') + template_data = check_cluster_template(body) + template_id = uuid.uuid4().hex + resp = { + 'id': template_id, + 'user': user_uuid.hex, + } + query = session_templates.insert().values({ + 'id': template_id, + 'domain_name': params['domain'], + 'group_id': group_id, + 'user_uuid': user_uuid, + 'name': template_data['metadata']['name'], + 'template': template_data, + 'type': TemplateType.CLUSTER, + }) + result = await conn.execute(query) + assert result.rowcount == 1 + return web.json_response(resp) + + +@auth_required +@server_status_required(READ_ALLOWED) +@check_api_params( + t.Dict({ + t.Key('all', default=False): t.ToBool, + tx.AliasedKey(['group_id', 'groupId'], default=None): tx.UUID | t.String | t.Null, + }), +) +async def list_template(request: web.Request, params: Any) -> web.Response: + resp = [] + root_ctx: RootContext = request.app['_root.context'] + access_key = request['keypair']['access_key'] + domain_name = request['user']['domain_name'] + user_role = request['user']['role'] + user_uuid = request['user']['uuid'] + + log.info('CLUSTER_TEMPLATE.LIST (ak:{})', access_key) + async with root_ctx.db.begin() as conn: + entries: List[Mapping[str, Any]] + if request['is_superadmin'] and params['all']: + j = ( + session_templates + .join(users, session_templates.c.user_uuid == users.c.uuid, isouter=True) + .join(groups, session_templates.c.group_id == groups.c.id, isouter=True) + ) + query = ( + sa.select([session_templates, users.c.email, groups.c.name], use_labels=True) + .select_from(j) + .where( + (session_templates.c.is_active) & + (session_templates.c.type == TemplateType.CLUSTER), + ) + ) + result = await conn.execute(query) + entries = [] + for row in result: + is_owner = True if row.session_templates_user == user_uuid else False + entries.append({ + 'name': row.session_templates_name, + 'id': row.session_templates_id, + 'created_at': row.session_templates_created_at, + 'is_owner': is_owner, + 'user': (str(row.session_templates_user_uuid) + if row.session_templates_user_uuid else None), + 'group': (str(row.session_templates_group_id) + if row.session_templates_group_id else None), + 'user_email': row.users_email, + 'group_name': row.groups_name, + }) + else: + extra_conds = None + if params['group_id'] is not None: + extra_conds = ((session_templates.c.group_id == params['group_id'])) + entries = await query_accessible_session_templates( + conn, + user_uuid, + TemplateType.CLUSTER, + user_role=user_role, + domain_name=domain_name, + allowed_types=['user', 'group'], + extra_conds=extra_conds, + ) + + for entry in entries: + resp.append({ + 'name': entry['name'], + 'id': entry['id'].hex, + 'created_at': str(entry['created_at']), + 'is_owner': entry['is_owner'], + 'user': str(entry['user']), + 'group': 
str(entry['group']), + 'user_email': entry['user_email'], + 'group_name': entry['group_name'], + 'type': 'user' if entry['user'] is not None else 'group', + }) + return web.json_response(resp) + + +@auth_required +@server_status_required(READ_ALLOWED) +@check_api_params( + t.Dict({ + t.Key('format', default='yaml'): t.Null | t.Enum('yaml', 'json'), + t.Key('owner_access_key', default=None): t.Null | t.String, + }), +) +async def get(request: web.Request, params: Any) -> web.Response: + if params['format'] not in ['yaml', 'json']: + raise InvalidAPIParameters('format should be "yaml" or "json"') + requester_access_key, owner_access_key = await get_access_key_scopes(request, params) + log.info( + 'CLUSTER_TEMPLATE.GET (ak:{0}/{1})', requester_access_key, + owner_access_key if owner_access_key != requester_access_key else '*', + ) + + template_id = request.match_info['template_id'] + root_ctx: RootContext = request.app['_root.context'] + + async with root_ctx.db.begin() as conn: + query = ( + sa.select([session_templates.c.template]) + .select_from(session_templates) + .where( + (session_templates.c.id == template_id) & + (session_templates.c.is_active) & + (session_templates.c.type == TemplateType.CLUSTER), + ) + ) + template = await conn.scalar(query) + if not template: + raise TaskTemplateNotFound + template = json.loads(template) + if params['format'] == 'yaml': + body = yaml.dump(template) + return web.Response(text=body, content_type='text/yaml') + else: + return web.json_response(template) + + +@auth_required +@server_status_required(READ_ALLOWED) +@check_api_params( + t.Dict({ + t.Key('payload'): t.String, + t.Key('owner_access_key', default=None): t.Null | t.String, + }), +) +async def put(request: web.Request, params: Any) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + template_id = request.match_info['template_id'] + + requester_access_key, owner_access_key = await get_access_key_scopes(request, params) + log.info( + 'CLUSTER_TEMPLATE.PUT (ak:{0}/{1})', requester_access_key, + owner_access_key if owner_access_key != requester_access_key else '*', + ) + + async with root_ctx.db.begin() as conn: + query = ( + sa.select([session_templates.c.id]) + .select_from(session_templates) + .where( + (session_templates.c.id == template_id) & + (session_templates.c.is_active) & + (session_templates.c.type == TemplateType.CLUSTER), + ) + ) + result = await conn.scalar(query) + if not result: + raise TaskTemplateNotFound + try: + body = json.loads(params['payload']) + except json.JSONDecodeError: + body = yaml.safe_load(params['payload']) + except (yaml.YAMLError, yaml.MarkedYAMLError): + raise InvalidAPIParameters('Malformed payload') + template_data = check_cluster_template(body) + query = ( + sa.update(session_templates) + .values(template=template_data, name=template_data['metadata']['name']) + .where((session_templates.c.id == template_id)) + ) + result = await conn.execute(query) + assert result.rowcount == 1 + + return web.json_response({'success': True}) + + +@auth_required +@server_status_required(READ_ALLOWED) +@check_api_params( + t.Dict({ + t.Key('owner_access_key', default=None): t.Null | t.String, + }), +) +async def delete(request: web.Request, params: Any) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + template_id = request.match_info['template_id'] + + requester_access_key, owner_access_key = await get_access_key_scopes(request, params) + log.info( + 'CLUSTER_TEMPLATE.DELETE (ak:{0}/{1})', requester_access_key, + 
owner_access_key if owner_access_key != requester_access_key else '*', + ) + + async with root_ctx.db.begin() as conn: + query = ( + sa.select([session_templates.c.id]) + .select_from(session_templates) + .where( + (session_templates.c.id == template_id) & + (session_templates.c.is_active) & + (session_templates.c.type == TemplateType.CLUSTER), + ) + ) + result = await conn.scalar(query) + if not result: + raise TaskTemplateNotFound + + query = ( + sa.update(session_templates) + .values(is_active=False) + .where((session_templates.c.id == template_id)) + ) + result = await conn.execute(query) + assert result.rowcount == 1 + + return web.json_response({'success': True}) + + +async def init(app: web.Application) -> None: + pass + + +async def shutdown(app: web.Application) -> None: + pass + + +def create_app(default_cors_options: CORSOptions) -> Tuple[web.Application, Iterable[WebMiddleware]]: + app = web.Application() + app.on_startup.append(init) + app.on_shutdown.append(shutdown) + app['api_versions'] = (4, 5) + app['prefix'] = 'template/cluster' + cors = aiohttp_cors.setup(app, defaults=default_cors_options) + cors.add(app.router.add_route('POST', '', create)) + cors.add(app.router.add_route('GET', '', list_template)) + template_resource = cors.add(app.router.add_resource(r'/{template_id}')) + cors.add(template_resource.add_route('GET', get)) + cors.add(template_resource.add_route('PUT', put)) + cors.add(template_resource.add_route('DELETE', delete)) + + return app, [] diff --git a/src/ai/backend/manager/api/context.py b/src/ai/backend/manager/api/context.py new file mode 100644 index 0000000000..796f6c7764 --- /dev/null +++ b/src/ai/backend/manager/api/context.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import attr + +if TYPE_CHECKING: + from ai.backend.common.bgtask import BackgroundTaskManager + from ai.backend.common.events import EventDispatcher, EventProducer + from ai.backend.common.plugin.hook import HookPluginContext + from ai.backend.common.plugin.monitor import ErrorPluginContext, StatsPluginContext + from ai.backend.common.types import RedisConnectionInfo + + from ..models.storage import StorageSessionManager + from ..models.utils import ExtendedAsyncSAEngine + from ..idle import IdleCheckerHost + from ..plugin.webapp import WebappPluginContext + from ..registry import AgentRegistry + from ..config import LocalConfig, SharedConfig + from ..types import DistributedLockFactory + from .types import CORSOptions + + +class BaseContext: + pass + + +@attr.s(slots=True, auto_attribs=True, init=False) +class RootContext(BaseContext): + pidx: int + db: ExtendedAsyncSAEngine + distributed_lock_factory: DistributedLockFactory + event_dispatcher: EventDispatcher + event_producer: EventProducer + redis_live: RedisConnectionInfo + redis_stat: RedisConnectionInfo + redis_image: RedisConnectionInfo + redis_stream: RedisConnectionInfo + shared_config: SharedConfig + local_config: LocalConfig + cors_options: CORSOptions + + webapp_plugin_ctx: WebappPluginContext + idle_checker_host: IdleCheckerHost + storage_manager: StorageSessionManager + hook_plugin_ctx: HookPluginContext + + registry: AgentRegistry + + error_monitor: ErrorPluginContext + stats_monitor: StatsPluginContext + background_task_manager: BackgroundTaskManager diff --git a/src/ai/backend/manager/api/domainconfig.py b/src/ai/backend/manager/api/domainconfig.py new file mode 100644 index 0000000000..9b2498ce90 --- /dev/null +++ b/src/ai/backend/manager/api/domainconfig.py @@ -0,0 
+1,196 @@ +import logging +import re +from typing import Any, TYPE_CHECKING, Tuple + +from aiohttp import web +import aiohttp_cors +import trafaret as t + +from ai.backend.common import msgpack +from ai.backend.common.logging import BraceStyleAdapter + +from ..models import ( + domains, + query_domain_dotfiles, + verify_dotfile_name, + MAXIMUM_DOTFILE_SIZE, +) +from .auth import auth_required, admin_required +from .exceptions import ( + InvalidAPIParameters, DotfileCreationFailed, + DotfileNotFound, DotfileAlreadyExists, + GenericForbidden, DomainNotFound, +) +from .manager import READ_ALLOWED, server_status_required +from .types import CORSOptions, Iterable, WebMiddleware +from .utils import check_api_params + +if TYPE_CHECKING: + from .context import RootContext + +log = BraceStyleAdapter(logging.getLogger(__name__)) + + +@server_status_required(READ_ALLOWED) +@admin_required +@check_api_params(t.Dict( + { + t.Key('domain'): t.String, + t.Key('data'): t.String(max_length=MAXIMUM_DOTFILE_SIZE), + t.Key('path'): t.String, + t.Key('permission'): t.Regexp(r'^[0-7]{3}$', re.ASCII), + }, +)) +async def create(request: web.Request, params: Any) -> web.Response: + log.info('DOMAINCOFNIG.CREATE_DOTFILE (domain: {0})', params['domain']) + if not request['is_superadmin'] and request['user']['domain_name'] != params['domain']: + raise GenericForbidden('Domain admins cannot create dotfiles of other domains') + root_ctx: RootContext = request.app['_root.context'] + async with root_ctx.db.begin() as conn: + dotfiles, leftover_space = await query_domain_dotfiles(conn, params['domain']) + if dotfiles is None: + raise DomainNotFound('Input domain is not found') + if leftover_space == 0: + raise DotfileCreationFailed('No leftover space for dotfile storage') + if len(dotfiles) == 100: + raise DotfileCreationFailed('Dotfile creation limit reached') + if not verify_dotfile_name(params['path']): + raise InvalidAPIParameters('dotfile path is reserved for internal operations.') + + duplicate = [x for x in dotfiles if x['path'] == params['path']] + if len(duplicate) > 0: + raise DotfileAlreadyExists + new_dotfiles = list(dotfiles) + new_dotfiles.append({'path': params['path'], 'perm': params['permission'], + 'data': params['data']}) + dotfile_packed = msgpack.packb(new_dotfiles) + if len(dotfile_packed) > MAXIMUM_DOTFILE_SIZE: + raise DotfileCreationFailed('No leftover space for dotfile storage') + + query = (domains.update() + .values(dotfiles=dotfile_packed) + .where(domains.c.name == params['domain'])) + await conn.execute(query) + return web.json_response({}) + + +@server_status_required(READ_ALLOWED) +@auth_required +@check_api_params(t.Dict({ + t.Key('domain'): t.String, + t.Key('path', default=None): t.Null | t.String, +})) +async def list_or_get(request: web.Request, params: Any) -> web.Response: + log.info('DOMAINCONFIG.LIST_OR_GET_DOTFILE (domain: {0})', params['domain']) + if not request['is_superadmin'] and request['user']['domain_name'] != params['domain']: + raise GenericForbidden('Users cannot access dotfiles of other domains') + resp = [] + root_ctx: RootContext = request.app['_root.context'] + async with root_ctx.db.begin() as conn: + if params['path']: + dotfiles, _ = await query_domain_dotfiles(conn, params['domain']) + if dotfiles is None: + raise DomainNotFound + for dotfile in dotfiles: + if dotfile['path'] == params['path']: + return web.json_response(dotfile) + raise DotfileNotFound + else: + dotfiles, _ = await query_domain_dotfiles(conn, params['domain']) + if dotfiles is None: + raise 
DomainNotFound + for entry in dotfiles: + resp.append({ + 'path': entry['path'], + 'permission': entry['perm'], + 'data': entry['data'], + }) + return web.json_response(resp) + + +@server_status_required(READ_ALLOWED) +@admin_required +@check_api_params(t.Dict( + { + t.Key('domain'): t.String, + t.Key('data'): t.String(max_length=MAXIMUM_DOTFILE_SIZE), + t.Key('path'): t.String, + t.Key('permission'): t.Regexp(r'^[0-7]{3}$', re.ASCII), + }, +)) +async def update(request: web.Request, params: Any) -> web.Response: + log.info('DOMAINCONFIG.UPDATE_DOTFILE (domain:{0})', params['domain']) + if not request['is_superadmin'] and request['user']['domain_name'] != params['domain']: + raise GenericForbidden('Domain admins cannot update dotfiles of other domains') + root_ctx: RootContext = request.app['_root.context'] + async with root_ctx.db.begin() as conn: + dotfiles, _ = await query_domain_dotfiles(conn, params['domain']) + if dotfiles is None: + raise DomainNotFound + new_dotfiles = [x for x in dotfiles if x['path'] != params['path']] + if len(new_dotfiles) == len(dotfiles): + raise DotfileNotFound + + new_dotfiles.append({'path': params['path'], 'perm': params['permission'], + 'data': params['data']}) + dotfile_packed = msgpack.packb(new_dotfiles) + if len(dotfile_packed) > MAXIMUM_DOTFILE_SIZE: + raise DotfileCreationFailed('No leftover space for dotfile storage') + + query = (domains.update() + .values(dotfiles=dotfile_packed) + .where(domains.c.name == params['domain'])) + await conn.execute(query) + return web.json_response({}) + + +@server_status_required(READ_ALLOWED) +@admin_required +@check_api_params( + t.Dict({ + t.Key('domain'): t.String, + t.Key('path'): t.String, + }), +) +async def delete(request: web.Request, params: Any) -> web.Response: + log.info('DOMAINCONFIG.DELETE_DOTFILE (domain:{0})', params['domain']) + if not request['is_superadmin'] and request['user']['domain_name'] != params['domain']: + raise GenericForbidden('Domain admins cannot delete dotfiles of other domains') + root_ctx: RootContext = request.app['_root.context'] + async with root_ctx.db.begin() as conn: + dotfiles, _ = await query_domain_dotfiles(conn, params['domain']) + if dotfiles is None: + raise DomainNotFound + new_dotfiles = [x for x in dotfiles if x['path'] != params['path']] + if len(new_dotfiles) == len(dotfiles): + raise DotfileNotFound + + dotfile_packed = msgpack.packb(new_dotfiles) + query = (domains.update() + .values(dotfiles=dotfile_packed) + .where(domains.c.name == params['domain'])) + await conn.execute(query) + return web.json_response({'success': True}) + + +async def init(app: web.Application) -> None: + pass + + +async def shutdown(app: web.Application) -> None: + pass + + +def create_app(default_cors_options: CORSOptions) -> Tuple[web.Application, Iterable[WebMiddleware]]: + app = web.Application() + app.on_startup.append(init) + app.on_shutdown.append(shutdown) + app['api_versions'] = (4, 5) + app['prefix'] = 'domain-config' + cors = aiohttp_cors.setup(app, defaults=default_cors_options) + cors.add(app.router.add_route('POST', '/dotfiles', create)) + cors.add(app.router.add_route('GET', '/dotfiles', list_or_get)) + cors.add(app.router.add_route('PATCH', '/dotfiles', update)) + cors.add(app.router.add_route('DELETE', '/dotfiles', delete)) + + return app, [] diff --git a/src/ai/backend/manager/api/etcd.py b/src/ai/backend/manager/api/etcd.py new file mode 100644 index 0000000000..4bdf9c3c5e --- /dev/null +++ b/src/ai/backend/manager/api/etcd.py @@ -0,0 +1,154 @@ +from __future__ import 
annotations + +import logging +from typing import ( + Any, + AsyncGenerator, + Iterable, + Mapping, + TYPE_CHECKING, + Tuple, +) + +from aiohttp import web +import aiohttp_cors +import trafaret as t + +from ai.backend.common.docker import get_known_registries +from ai.backend.common.logging import BraceStyleAdapter + +from .auth import superadmin_required +from .exceptions import InvalidAPIParameters +from .utils import check_api_params +from .types import CORSOptions, WebMiddleware + +if TYPE_CHECKING: + from .context import RootContext + +log = BraceStyleAdapter(logging.getLogger(__name__)) + + +async def get_resource_slots(request: web.Request) -> web.Response: + log.info('ETCD.GET_RESOURCE_SLOTS ()') + root_ctx: RootContext = request.app['_root.context'] + known_slots = await root_ctx.shared_config.get_resource_slots() + return web.json_response(known_slots, status=200) + + +async def get_vfolder_types(request: web.Request) -> web.Response: + log.info('ETCD.GET_VFOLDER_TYPES ()') + root_ctx: RootContext = request.app['_root.context'] + vfolder_types = await root_ctx.shared_config.get_vfolder_types() + return web.json_response(vfolder_types, status=200) + + +@superadmin_required +async def get_docker_registries(request: web.Request) -> web.Response: + """ + Returns the list of all registered docker registries. + """ + log.info('ETCD.GET_DOCKER_REGISTRIES ()') + root_ctx: RootContext = request.app['_root.context'] + _registries = await get_known_registries(root_ctx.shared_config.etcd) + # ``yarl.URL`` is not JSON-serializable, so we need to represent it as string. + known_registries: Mapping[str, str] = {k: v.human_repr() for k, v in _registries.items()} + return web.json_response(known_registries, status=200) + + +@superadmin_required +@check_api_params( + t.Dict({ + t.Key('key'): t.String, + t.Key('prefix', default=False): t.Bool, + })) +async def get_config(request: web.Request, params: Any) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + log.info( + 'ETCD.GET_CONFIG (ak:{}, key:{}, prefix:{})', + request['keypair']['access_key'], params['key'], params['prefix'], + ) + if params['prefix']: + # Flatten the returned ChainMap object for JSON serialization + tree_value = dict(await root_ctx.shared_config.etcd.get_prefix_dict(params['key'])) + return web.json_response({'result': tree_value}) + else: + scalar_value = await root_ctx.shared_config.etcd.get(params['key']) + return web.json_response({'result': scalar_value}) + + +@superadmin_required +@check_api_params( + t.Dict({ + t.Key('key'): t.String, + t.Key('value'): (t.String(allow_blank=True) | + t.Mapping(t.String(allow_blank=True), t.Any)), + })) +async def set_config(request: web.Request, params: Any) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + log.info( + 'ETCD.SET_CONFIG (ak:{}, key:{}, val:{})', + request['keypair']['access_key'], params['key'], params['value'], + ) + if isinstance(params['value'], Mapping): + updates = {} + + def flatten(prefix, o): + for k, v in o.items(): + inner_prefix = prefix if k == '' else f'{prefix}/{k}' + if isinstance(v, Mapping): + flatten(inner_prefix, v) + else: + updates[inner_prefix] = v + + flatten(params['key'], params['value']) + # TODO: chunk support if there are too many keys + if len(updates) > 16: + raise InvalidAPIParameters( + 'Too large update! 
Split into smaller key-value pair sets.') + await root_ctx.shared_config.etcd.put_dict(updates) + else: + await root_ctx.shared_config.etcd.put(params['key'], params['value']) + return web.json_response({'result': 'ok'}) + + +@superadmin_required +@check_api_params( + t.Dict({ + t.Key('key'): t.String, + t.Key('prefix', default=False): t.Bool, + })) +async def delete_config(request: web.Request, params: Any) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + log.info( + 'ETCD.DELETE_CONFIG (ak:{}, key:{}, prefix:{})', + request['keypair']['access_key'], params['key'], params['prefix'], + ) + if params['prefix']: + await root_ctx.shared_config.etcd.delete_prefix(params['key']) + else: + await root_ctx.shared_config.etcd.delete(params['key']) + return web.json_response({'result': 'ok'}) + + +async def app_ctx(app: web.Application) -> AsyncGenerator[None, None]: + root_ctx: RootContext = app['_root.context'] + if root_ctx.pidx == 0: + await root_ctx.shared_config.register_myself() + yield + if root_ctx.pidx == 0: + await root_ctx.shared_config.deregister_myself() + + +def create_app(default_cors_options: CORSOptions) -> Tuple[web.Application, Iterable[WebMiddleware]]: + app = web.Application() + app.cleanup_ctx.append(app_ctx) + app['prefix'] = 'config' + app['api_versions'] = (3, 4) + cors = aiohttp_cors.setup(app, defaults=default_cors_options) + cors.add(app.router.add_route('GET', r'/resource-slots', get_resource_slots)) + cors.add(app.router.add_route('GET', r'/vfolder-types', get_vfolder_types)) + cors.add(app.router.add_route('GET', r'/docker-registries', get_docker_registries)) + cors.add(app.router.add_route('POST', r'/get', get_config)) + cors.add(app.router.add_route('POST', r'/set', set_config)) + cors.add(app.router.add_route('POST', r'/delete', delete_config)) + return app, [] diff --git a/src/ai/backend/manager/api/events.py b/src/ai/backend/manager/api/events.py new file mode 100644 index 0000000000..3440cdca7f --- /dev/null +++ b/src/ai/backend/manager/api/events.py @@ -0,0 +1,417 @@ +from __future__ import annotations + +import asyncio +import logging +import json +from typing import ( + Any, + AsyncIterator, + Final, + Iterable, + Mapping, + Set, + Tuple, + Union, + TYPE_CHECKING, +) + +from aiohttp import web +import aiohttp_cors +from aiohttp_sse import sse_response +from aiotools import adefer +import attr +import sqlalchemy as sa +import trafaret as t + +from ai.backend.common import validators as tx +from ai.backend.common.events import ( + BgtaskCancelledEvent, + BgtaskDoneEvent, + BgtaskFailedEvent, + BgtaskUpdatedEvent, + EventDispatcher, + KernelCancelledEvent, + KernelCreatingEvent, + KernelPreparingEvent, + KernelPullingEvent, + KernelStartedEvent, + KernelTerminatedEvent, + KernelTerminatingEvent, + SessionCancelledEvent, + SessionEnqueuedEvent, + SessionFailureEvent, + SessionScheduledEvent, + SessionStartedEvent, + SessionSuccessEvent, + SessionTerminatedEvent, +) +from ai.backend.common.logging import BraceStyleAdapter +from ai.backend.common.types import AgentId + +from ..models import kernels, groups, UserRole +from ..models.utils import execute_with_retry +from ..types import Sentinel +from .auth import auth_required +from .exceptions import ObjectNotFound, GenericForbidden, GroupNotFound +from .manager import READ_ALLOWED, server_status_required +from .utils import check_api_params + +if TYPE_CHECKING: + from .context import RootContext + from .types import CORSOptions, WebMiddleware + +log = 
BraceStyleAdapter(logging.getLogger(__name__)) + +sentinel: Final = Sentinel.token + +SessionEventInfo = Tuple[str, dict, str] +BgtaskEvents = Union[BgtaskUpdatedEvent, BgtaskDoneEvent, BgtaskCancelledEvent, BgtaskFailedEvent] + + +@server_status_required(READ_ALLOWED) +@auth_required +@check_api_params( + t.Dict({ + tx.AliasedKey(['name', 'sessionName'], default='*') >> 'session_name': t.String, + t.Key('ownerAccessKey', default=None) >> 'owner_access_key': t.Null | t.String, + t.Key('sessionId', default=None) >> 'session_id': t.Null | tx.UUID, + # NOTE: if set, sessionId overrides sessionName and ownerAccessKey parameters. + tx.AliasedKey(['group', 'groupName'], default='*') >> 'group_name': t.String, + t.Key('scope', default='*'): t.Enum('*', 'session', 'kernel'), + })) +@adefer +async def push_session_events( + defer, + request: web.Request, + params: Mapping[str, Any], +) -> web.StreamResponse: + root_ctx: RootContext = request.app['_root.context'] + app_ctx: PrivateContext = request.app['events.context'] + session_name = params['session_name'] + session_id = params['session_id'] + scope = params['scope'] + user_role = request['user']['role'] + user_uuid = request['user']['uuid'] + access_key = params['owner_access_key'] + if access_key is None: + access_key = request['keypair']['access_key'] + if user_role == UserRole.USER: + if access_key != request['keypair']['access_key']: + raise GenericForbidden + group_name = params['group_name'] + my_queue: asyncio.Queue[Sentinel | SessionEventInfo] = asyncio.Queue() + log.info('PUSH_SESSION_EVENTS (ak:{}, s:{}, g:{})', access_key, session_name, group_name) + if group_name == '*': + group_id = '*' + else: + async with root_ctx.db.begin_readonly() as conn: + query = ( + sa.select([groups.c.id]) + .select_from(groups) + .where(groups.c.name == group_name) + ) + result = await conn.execute(query) + row = result.first() + if row is None: + raise GroupNotFound + group_id = row['id'] + app_ctx.session_event_queues.add(my_queue) + defer(lambda: app_ctx.session_event_queues.remove(my_queue)) + async with sse_response(request) as resp: + try: + while True: + evdata = await my_queue.get() + try: + if evdata is sentinel: + break + event_name, row, reason = evdata + if user_role in (UserRole.USER, UserRole.ADMIN): + if row['domain_name'] != request['user']['domain_name']: + continue + if user_role == UserRole.USER: + if row['user_uuid'] != user_uuid: + continue + if group_id != '*' and row['group_id'] != group_id: + continue + if scope == 'session' and not event_name.startswith('session_'): + continue + if scope == 'kernel' and not event_name.startswith('kernel_'): + continue + if session_id is not None: + if row['session_id'] != session_id: + continue + else: + if session_name != '*' and not ( + (row['session_name'] == session_name) and + (row['access_key'] == access_key)): + continue + response_data = { + 'reason': reason, + 'sessionName': row['session_name'], + 'ownerAccessKey': row['access_key'], + 'sessionId': str(row['session_id']), + } + if kernel_id := row.get('id'): + response_data['kernelId'] = str(kernel_id) + if cluster_role := row.get('cluster_role'): + response_data['clusterRole'] = cluster_role + if cluster_idx := row.get('cluster_idx'): + response_data['clusterIdx'] = cluster_idx + await resp.send(json.dumps(response_data), event=event_name) + finally: + my_queue.task_done() + finally: + return resp + + +@server_status_required(READ_ALLOWED) +@auth_required +@check_api_params(t.Dict({ + tx.AliasedKey(['task_id', 'taskId']): tx.UUID, +})) 
+async def push_background_task_events( + request: web.Request, + params: Mapping[str, Any], +) -> web.StreamResponse: + root_ctx: RootContext = request.app['_root.context'] + task_id = params['task_id'] + access_key = request['keypair']['access_key'] + log.info('PUSH_BACKGROUND_TASK_EVENTS (ak:{}, t:{})', access_key, task_id) + try: + return await root_ctx.background_task_manager.push_bgtask_events(request, task_id) + except ValueError as e: + raise ObjectNotFound(extra_data=str(e), object_name='background task') + + +async def enqueue_kernel_creation_status_update( + app: web.Application, + source: AgentId, + event: KernelPreparingEvent | KernelPullingEvent | KernelCreatingEvent | KernelStartedEvent, +) -> None: + root_ctx: RootContext = app['_root.context'] + app_ctx: PrivateContext = app['events.context'] + + async def _fetch(): + async with root_ctx.db.begin_readonly() as conn: + query = ( + sa.select([ + kernels.c.id, + kernels.c.session_id, + kernels.c.session_name, + kernels.c.access_key, + kernels.c.cluster_role, + kernels.c.cluster_idx, + kernels.c.domain_name, + kernels.c.group_id, + kernels.c.user_uuid, + ]) + .select_from(kernels) + .where( + (kernels.c.id == event.kernel_id), + ) + ) + result = await conn.execute(query) + return result.first() + + row = await execute_with_retry(_fetch) + if row is None: + return + for q in app_ctx.session_event_queues: + q.put_nowait((event.name, row._mapping, event.reason)) + + +async def enqueue_kernel_termination_status_update( + app: web.Application, + agent_id: AgentId, + event: KernelCancelledEvent | KernelTerminatingEvent | KernelTerminatedEvent, +) -> None: + root_ctx: RootContext = app['_root.context'] + app_ctx: PrivateContext = app['events.context'] + + async def _fetch(): + async with root_ctx.db.begin_readonly() as conn: + query = ( + sa.select([ + kernels.c.id, + kernels.c.session_id, + kernels.c.session_name, + kernels.c.access_key, + kernels.c.cluster_role, + kernels.c.cluster_idx, + kernels.c.domain_name, + kernels.c.group_id, + kernels.c.user_uuid, + ]) + .select_from(kernels) + .where( + (kernels.c.id == event.kernel_id), + ) + ) + result = await conn.execute(query) + return result.first() + + row = await execute_with_retry(_fetch) + if row is None: + return + for q in app_ctx.session_event_queues: + q.put_nowait((event.name, row._mapping, event.reason)) + + +async def enqueue_session_creation_status_update( + app: web.Application, + source: AgentId, + event: SessionEnqueuedEvent | SessionScheduledEvent | SessionStartedEvent | SessionCancelledEvent, +) -> None: + root_ctx: RootContext = app['_root.context'] + app_ctx: PrivateContext = app['events.context'] + + async def _fetch(): + async with root_ctx.db.begin_readonly() as conn: + query = ( + sa.select([ + kernels.c.id, + kernels.c.session_id, + kernels.c.session_name, + kernels.c.access_key, + kernels.c.domain_name, + kernels.c.group_id, + kernels.c.user_uuid, + ]) + .select_from(kernels) + .where( + (kernels.c.id == event.session_id), + # for the main kernel, kernel ID == session ID + ) + ) + result = await conn.execute(query) + return result.first() + + row = await execute_with_retry(_fetch) + if row is None: + return + for q in app_ctx.session_event_queues: + q.put_nowait((event.name, row._mapping, event.reason)) + + +async def enqueue_session_termination_status_update( + app: web.Application, + agent_id: AgentId, + event: SessionTerminatedEvent, +) -> None: + root_ctx: RootContext = app['_root.context'] + app_ctx: PrivateContext = app['events.context'] + + async def 
_fetch(): + async with root_ctx.db.begin_readonly() as conn: + query = ( + sa.select([ + kernels.c.id, + kernels.c.session_id, + kernels.c.session_name, + kernels.c.access_key, + kernels.c.domain_name, + kernels.c.group_id, + kernels.c.user_uuid, + ]) + .select_from(kernels) + .where( + (kernels.c.id == event.session_id), + # for the main kernel, kernel ID == session ID + ) + ) + result = await conn.execute(query) + return result.first() + + row = await execute_with_retry(_fetch) + if row is None: + return + for q in app_ctx.session_event_queues: + q.put_nowait((event.name, row._mapping, event.reason)) + + +async def enqueue_batch_task_result_update( + app: web.Application, + agent_id: AgentId, + event: SessionSuccessEvent | SessionFailureEvent, +) -> None: + root_ctx: RootContext = app['_root.context'] + app_ctx: PrivateContext = app['events.context'] + + async def _fetch(): + async with root_ctx.db.begin_readonly() as conn: + query = ( + sa.select([ + kernels.c.id, + kernels.c.session_id, + kernels.c.session_name, + kernels.c.access_key, + kernels.c.domain_name, + kernels.c.group_id, + kernels.c.user_uuid, + ]) + .select_from(kernels) + .where( + (kernels.c.id == event.session_id), + ) + ) + result = await conn.execute(query) + return result.first() + + row = await execute_with_retry(_fetch) + if row is None: + return + for q in app_ctx.session_event_queues: + q.put_nowait((event.name, row._mapping, event.reason)) + + +@attr.s(slots=True, auto_attribs=True, init=False) +class PrivateContext: + session_event_queues: Set[asyncio.Queue[Sentinel | SessionEventInfo]] + + +async def events_app_ctx(app: web.Application) -> AsyncIterator[None]: + root_ctx: RootContext = app['_root.context'] + app_ctx: PrivateContext = app['events.context'] + app_ctx.session_event_queues = set() + event_dispatcher: EventDispatcher = root_ctx.event_dispatcher + event_dispatcher.subscribe(SessionEnqueuedEvent, app, enqueue_session_creation_status_update) + event_dispatcher.subscribe(SessionScheduledEvent, app, enqueue_session_creation_status_update) + event_dispatcher.subscribe(KernelPreparingEvent, app, enqueue_kernel_creation_status_update) + event_dispatcher.subscribe(KernelPullingEvent, app, enqueue_kernel_creation_status_update) + event_dispatcher.subscribe(KernelCreatingEvent, app, enqueue_kernel_creation_status_update) + event_dispatcher.subscribe(KernelStartedEvent, app, enqueue_kernel_creation_status_update) + event_dispatcher.subscribe(SessionStartedEvent, app, enqueue_session_creation_status_update) + event_dispatcher.subscribe(KernelTerminatingEvent, app, enqueue_kernel_termination_status_update) + event_dispatcher.subscribe(KernelTerminatedEvent, app, enqueue_kernel_termination_status_update) + event_dispatcher.subscribe(KernelCancelledEvent, app, enqueue_kernel_termination_status_update) + event_dispatcher.subscribe(SessionTerminatedEvent, app, enqueue_session_termination_status_update) + event_dispatcher.subscribe(SessionCancelledEvent, app, enqueue_session_creation_status_update) + event_dispatcher.subscribe(SessionSuccessEvent, app, enqueue_batch_task_result_update) + event_dispatcher.subscribe(SessionFailureEvent, app, enqueue_batch_task_result_update) + root_ctx.background_task_manager.register_event_handlers(event_dispatcher) + yield + + +async def events_shutdown(app: web.Application) -> None: + # shutdown handler is called before waiting for closing active connections. + # We need to put sentinels here to ensure delivery of them to active SSE connections. 
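+    # Each per-connection queue receives the sentinel via put_nowait() below, and
+    # sq.join() resolves only after the corresponding push_session_events() handler
+    # has drained the queue (it calls task_done() for every item, including the
+    # sentinel), so gathering the join tasks lets active SSE streams flush first.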
+    app_ctx: PrivateContext = app['events.context']
+    join_tasks = []
+    for sq in app_ctx.session_event_queues:
+        sq.put_nowait(sentinel)
+        join_tasks.append(sq.join())
+    await asyncio.gather(*join_tasks)
+
+
+def create_app(default_cors_options: CORSOptions) -> Tuple[web.Application, Iterable[WebMiddleware]]:
+    app = web.Application()
+    app['prefix'] = 'events'
+    app['events.context'] = PrivateContext()
+    app['api_versions'] = (3, 4)
+    app.on_shutdown.append(events_shutdown)
+    cors = aiohttp_cors.setup(app, defaults=default_cors_options)
+    add_route = app.router.add_route
+    app.cleanup_ctx.append(events_app_ctx)
+    cors.add(add_route('GET', r'/background-task', push_background_task_events))
+    cors.add(add_route('GET', r'/session', push_session_events))
+    return app, []
diff --git a/src/ai/backend/manager/api/exceptions.py b/src/ai/backend/manager/api/exceptions.py
new file mode 100644
index 0000000000..22e374c1c5
--- /dev/null
+++ b/src/ai/backend/manager/api/exceptions.py
@@ -0,0 +1,418 @@
+"""
+This module defines a series of Backend.AI-specific errors based on HTTP Error
+classes from aiohttp.
+Raising a BackendError is automatically mapped to a corresponding HTTP error
+response with an RFC7807-style JSON-encoded description in its response body.
+
+On the client side, you should use the "type" field in the body to distinguish
+canonical error types because the "title" field may change due to localization and
+future UX improvements.
+"""
+
+from __future__ import annotations
+
+import json
+from typing import (
+    Any,
+    Dict,
+    Optional,
+    Mapping,
+    Union,
+    cast,
+)
+
+from aiohttp import web
+
+from ai.backend.common.plugin.hook import HookResult
+
+from ..exceptions import AgentError
+
+
+class BackendError(web.HTTPError):
+    """
+    An RFC-7807 error class as a drop-in replacement of the original
+    aiohttp.web.HTTPError subclasses.
+    """
+
+    error_type: str = 'https://api.backend.ai/probs/general-error'
+    error_title: str = 'General Backend API Error.'
+
+    content_type: str
+    extra_msg: Optional[str]
+
+    def __init__(self, extra_msg: str = None, extra_data: Any = None, **kwargs):
+        super().__init__(**kwargs)
+        self.args = (self.status_code, self.reason, self.error_type)
+        self.empty_body = False
+        self.content_type = 'application/problem+json'
+        self.extra_msg = extra_msg
+        self.extra_data = extra_data
+        body = {
+            'type': self.error_type,
+            'title': self.error_title,
+        }
+        if extra_msg is not None:
+            body['msg'] = extra_msg
+        if extra_data is not None:
+            body['data'] = extra_data
+        self.body = json.dumps(body).encode()
+
+    def __str__(self):
+        lines = []
+        if self.extra_msg:
+            lines.append(f'{self.error_title} ({self.extra_msg})')
+        else:
+            lines.append(self.error_title)
+        if self.extra_data:
+            lines.append(' -> extra_data: ' + repr(self.extra_data))
+        return '\n'.join(lines)
+
+    def __repr__(self):
+        lines = []
+        if self.extra_msg:
+            lines.append(f'<{type(self).__name__}({self.status}): '
+                         f'{self.error_title} ({self.extra_msg})>')
+        else:
+            lines.append(f'<{type(self).__name__}({self.status}): '
+                         f'{self.error_title}>')
+        if self.extra_data:
+            lines.append(' -> extra_data: ' + repr(self.extra_data))
+        return '\n'.join(lines)
+
+    def __reduce__(self):
+        return (
+            type(self),
+            (),  # empty the constructor args to make the unpickler use
+                 # only the exact current state in __dict__
+            self.__dict__,
+        )
+
+
+class URLNotFound(BackendError, web.HTTPNotFound):
+    error_type = 'https://api.backend.ai/probs/url-not-found'
+    error_title = 'Unknown URL path.'
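+
+# Illustrative sketch of the wire format produced by the classes above: raising,
+# for example, ``URLNotFound(extra_msg='/v4/unknown is not routed')`` from a handler
+# yields an HTTP 404 response with ``Content-Type: application/problem+json`` and a
+# body like the following (the ``msg`` value here is a hypothetical placeholder):
+#
+#   {"type": "https://api.backend.ai/probs/url-not-found",
+#    "title": "Unknown URL path.",
+#    "msg": "/v4/unknown is not routed"}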
+
+
+class ObjectNotFound(BackendError, web.HTTPNotFound):
+    error_type = 'https://api.backend.ai/probs/object-not-found'
+    object_name = 'object'
+
+    def __init__(
+        self,
+        extra_msg: str = None,
+        extra_data: Any = None,
+        *,
+        object_name: str = None,
+        **kwargs,
+    ) -> None:
+        if object_name:
+            self.object_name = object_name
+        self.error_title = f'No such {self.object_name}.'
+        super().__init__(extra_msg, extra_data, **kwargs)
+
+
+class GenericBadRequest(BackendError, web.HTTPBadRequest):
+    error_type = 'https://api.backend.ai/probs/generic-bad-request'
+    error_title = 'Bad request.'
+
+
+class RejectedByHook(BackendError, web.HTTPBadRequest):
+    error_type = 'https://api.backend.ai/probs/rejected-by-hook'
+    error_title = 'Operation rejected by a hook plugin.'
+
+    @classmethod
+    def from_hook_result(cls, result: HookResult) -> RejectedByHook:
+        return cls(
+            extra_msg=result.reason,
+            extra_data={
+                'plugins': result.src_plugin,
+            },
+        )
+
+
+class InvalidCredentials(BackendError, web.HTTPBadRequest):
+    error_type = 'https://api.backend.ai/probs/invalid-credentials'
+    error_title = 'Invalid credentials for authentication.'
+
+
+class GenericForbidden(BackendError, web.HTTPForbidden):
+    error_type = 'https://api.backend.ai/probs/generic-forbidden'
+    error_title = 'Forbidden operation.'
+
+
+class InsufficientPrivilege(BackendError, web.HTTPForbidden):
+    error_type = 'https://api.backend.ai/probs/insufficient-privilege'
+    error_title = 'Insufficient privilege.'
+
+
+class MethodNotAllowed(BackendError, web.HTTPMethodNotAllowed):
+    error_type = 'https://api.backend.ai/probs/method-not-allowed'
+    error_title = 'HTTP Method Not Allowed.'
+
+
+class InternalServerError(BackendError, web.HTTPInternalServerError):
+    error_type = 'https://api.backend.ai/probs/internal-server-error'
+    error_title = 'Internal server error.'
+
+
+class ServerMisconfiguredError(BackendError, web.HTTPInternalServerError):
+    error_type = 'https://api.backend.ai/probs/server-misconfigured'
+    error_title = 'Service misconfigured.'
+
+
+class ServiceUnavailable(BackendError, web.HTTPServiceUnavailable):
+    error_type = 'https://api.backend.ai/probs/service-unavailable'
+    error_title = 'Service unavailable.'
+
+
+class QueryNotImplemented(BackendError, web.HTTPServiceUnavailable):
+    error_type = 'https://api.backend.ai/probs/not-implemented'
+    error_title = 'This API query is not implemented.'
+
+
+class InvalidAuthParameters(BackendError, web.HTTPBadRequest):
+    error_type = 'https://api.backend.ai/probs/invalid-auth-params'
+    error_title = 'Missing or invalid authorization parameters.'
+
+
+class AuthorizationFailed(BackendError, web.HTTPUnauthorized):
+    error_type = 'https://api.backend.ai/probs/auth-failed'
+    error_title = 'Credential/signature mismatch.'
+
+
+class InvalidAPIParameters(BackendError, web.HTTPBadRequest):
+    error_type = 'https://api.backend.ai/probs/invalid-api-params'
+    error_title = 'Missing or invalid API parameters.'
+
+
+class GraphQLError(BackendError, web.HTTPBadRequest):
+    error_type = 'https://api.backend.ai/probs/graphql-error'
+    error_title = 'GraphQL-generated error.'
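+
+# Note (illustrative): ``ObjectNotFound`` derives its title from ``object_name``, so a
+# subclass below such as ``DomainNotFound`` reports "No such domain.", while an ad-hoc
+# ``ObjectNotFound(object_name='background task')`` reports "No such background task."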
+ + +class InstanceNotFound(ObjectNotFound): + object_name = 'agent instance' + + +class ImageNotFound(ObjectNotFound): + object_name = 'environment image' + + +class DomainNotFound(ObjectNotFound): + object_name = 'domain' + + +class GroupNotFound(ObjectNotFound): + object_name = 'user group (or project)' + + +class ScalingGroupNotFound(ObjectNotFound): + object_name = 'scaling group' + + +class SessionNotFound(ObjectNotFound): + object_name = 'session' + + +class TooManySessionsMatched(BackendError, web.HTTPNotFound): + error_type = 'https://api.backend.ai/probs/too-many-sessions-matched' + error_title = 'Too many sessions matched.' + + def __init__(self, extra_msg: str = None, extra_data: Dict[str, Any] = None, **kwargs): + if ( + extra_data is not None and + (matches := extra_data.get('matches', None)) is not None + ): + serializable_matches = [{ + 'id': str(item['session_id']), + 'name': item['session_name'], + 'status': item['status'].name, + 'created_at': item['created_at'].isoformat(), + } for item in matches] + extra_data['matches'] = serializable_matches + super().__init__(extra_msg, extra_data, **kwargs) + + +class TooManyKernelsFound(BackendError, web.HTTPNotFound): + error_type = 'https://api.backend.ai/probs/too-many-kernels' + error_title = 'There are two or more matching kernels.' + + +class TaskTemplateNotFound(ObjectNotFound): + object_name = 'task template' + + +class AppNotFound(ObjectNotFound): + object_name = 'app service' + + +class SessionAlreadyExists(BackendError, web.HTTPBadRequest): + error_type = 'https://api.backend.ai/probs/session-already-exists' + error_title = 'The session already exists but you requested not to reuse existing one.' + + +class VFolderCreationFailed(BackendError, web.HTTPBadRequest): + error_type = 'https://api.backend.ai/probs/vfolder-creation-failed' + error_title = 'Virtual folder creation has failed.' + + +class VFolderNotFound(ObjectNotFound): + object_name = 'virtual folder' + + +class VFolderAlreadyExists(BackendError, web.HTTPBadRequest): + error_type = 'https://api.backend.ai/probs/vfolder-already-exists' + error_title = 'The virtual folder already exists with the same name.' + + +class VFolderOperationFailed(BackendError, web.HTTPBadRequest): + error_type = 'https://api.backend.ai/probs/vfolder-operation-failed' + error_title = 'Virtual folder operation has failed.' + + +class DotfileCreationFailed(BackendError, web.HTTPBadRequest): + error_type = 'https://api.backend.ai/probs/generic-bad-request' + error_title = 'Dotfile creation has failed.' + + +class DotfileAlreadyExists(BackendError, web.HTTPBadRequest): + error_type = 'https://api.backend.ai/probs/generic-bad-request' + error_title = 'Dotfile already exists.' + + +class DotfileNotFound(ObjectNotFound): + object_name = 'dotfile' + + +class QuotaExceeded(BackendError, web.HTTPPreconditionFailed): + error_type = 'https://api.backend.ai/probs/quota-exceeded' + error_title = 'You have reached your resource limit.' + + +class RateLimitExceeded(BackendError, web.HTTPTooManyRequests): + error_type = 'https://api.backend.ai/probs/rate-limit-exceeded' + error_title = 'You have reached your API query rate limit.' + + +class InstanceNotAvailable(BackendError, web.HTTPServiceUnavailable): + error_type = 'https://api.backend.ai/probs/instance-not-available' + error_title = 'There is no available instance.' + + +class ServerFrozen(BackendError, web.HTTPServiceUnavailable): + error_type = 'https://api.backend.ai/probs/server-frozen' + error_title = 'The server is frozen due to maintenance. 
Please try again later.'
+
+
+class StorageProxyError(BackendError, web.HTTPError):
+    error_type = 'https://api.backend.ai/probs/storage-proxy-error'
+    error_title = 'The storage proxy returned an error.'
+
+    def __init__(self, status: int, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        # Currently there is no good public way to override the status code
+        # after initialization of aiohttp.web.StreamResponse objects. :(
+        self.status_code = status  # HTTPException uses self.status_code
+        self._status = status      # StreamResponse uses self._status
+        self.args = (status, self.args[1], self.args[2])
+
+    @property
+    def status(self) -> int:
+        # override the status property again to refer to the subclass' attribute.
+        return self.status_code
+
+
+class BackendAgentError(BackendError):
+    """
+    An RFC-7807 error class that wraps agent-side errors.
+    """
+
+    _short_type_map = {
+        'TIMEOUT': 'https://api.backend.ai/probs/agent-timeout',
+        'INVALID_INPUT': 'https://api.backend.ai/probs/agent-invalid-input',
+        'FAILURE': 'https://api.backend.ai/probs/agent-failure',
+    }
+
+    def __init__(self, agent_error_type: str,
+                 exc_info: Union[str, AgentError, Exception, Mapping[str, Optional[str]], None] = None):
+        super().__init__()
+        agent_details: Mapping[str, Optional[str]]
+        if not agent_error_type.startswith('https://'):
+            agent_error_type = self._short_type_map[agent_error_type.upper()]
+        self.args = (
+            self.status_code,
+            self.reason,
+            self.error_type,
+            agent_error_type,
+        )
+        if isinstance(exc_info, str):
+            agent_details = {
+                'type': agent_error_type,
+                'title': exc_info,
+            }
+        elif isinstance(exc_info, AgentError):
+            e = cast(AgentError, exc_info)
+            agent_details = {
+                'type': agent_error_type,
+                'title': 'Agent-side exception occurred.',
+                'exception': e.exc_repr,
+            }
+        elif isinstance(exc_info, Exception):
+            agent_details = {
+                'type': agent_error_type,
+                'title': 'Unexpected exception occurred.',
+                'exception': repr(exc_info),
+            }
+        elif isinstance(exc_info, Mapping):
+            agent_details = exc_info
+        else:
+            agent_details = {
+                'type': agent_error_type,
+                'title': None if exc_info is None else str(exc_info),
+            }
+        self.agent_details = agent_details
+        self.agent_error_type = agent_error_type
+        self.agent_error_title = agent_details['title']
+        self.agent_exception = agent_details.get('exception', '')
+        self.body = json.dumps({
+            'type': self.error_type,
+            'title': self.error_title,
+            'agent-details': agent_details,
+        }).encode()
+
+    def __str__(self):
+        if self.agent_exception:
+            return f'{self.agent_error_title} ({self.agent_exception})'
+        return f'{self.agent_error_title}'
+
+    def __repr__(self):
+        if self.agent_exception:
+            return f'<{type(self).__name__}: {self.agent_error_title} ({self.agent_exception})>'
+        return f'<{type(self).__name__}: {self.agent_error_title}>'
+
+    def __reduce__(self):
+        return (type(self), (self.agent_error_type, self.agent_details))
+
+
+class KernelCreationFailed(BackendAgentError, web.HTTPInternalServerError):
+    error_type = 'https://api.backend.ai/probs/kernel-creation-failed'
+    error_title = 'Kernel creation has failed.'
+
+
+class KernelDestructionFailed(BackendAgentError, web.HTTPInternalServerError):
+    error_type = 'https://api.backend.ai/probs/kernel-destruction-failed'
+    error_title = 'Kernel destruction has failed.'
+
+
+class KernelRestartFailed(BackendAgentError, web.HTTPInternalServerError):
+    error_type = 'https://api.backend.ai/probs/kernel-restart-failed'
+    error_title = 'Kernel restart has failed.'
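+
+# Illustrative sketch: the agent-error wrappers above and below accept either a short
+# code or a full type URL, e.g. ``KernelCreationFailed('FAILURE', exc)`` expands
+# 'FAILURE' to 'https://api.backend.ai/probs/agent-failure' via ``_short_type_map``
+# and embeds the agent-side exception details under the ``agent-details`` key of the
+# response body. (``exc`` stands for a caught ``AgentError``; the call is hypothetical.)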
+ + +class KernelExecutionFailed(BackendAgentError, web.HTTPInternalServerError): + error_type = 'https://api.backend.ai/probs/kernel-execution-failed' + error_title = 'Executing user code in the kernel has failed.' + + +class UnknownImageReferenceError(ObjectNotFound): + object_name = 'image reference' diff --git a/src/ai/backend/manager/api/groupconfig.py b/src/ai/backend/manager/api/groupconfig.py new file mode 100644 index 0000000000..2377ce7687 --- /dev/null +++ b/src/ai/backend/manager/api/groupconfig.py @@ -0,0 +1,283 @@ +import logging +import re +from typing import Any, TYPE_CHECKING, Tuple + +import sqlalchemy as sa +from aiohttp import web +import aiohttp_cors +import trafaret as t + +from ai.backend.common import validators as tx +from ai.backend.common import msgpack +from ai.backend.common.logging import BraceStyleAdapter + +from ..models import ( + groups, + association_groups_users as agus, + query_group_dotfiles, + query_group_domain, + verify_dotfile_name, + MAXIMUM_DOTFILE_SIZE, +) +from .auth import auth_required, admin_required +from .exceptions import ( + InvalidAPIParameters, DotfileCreationFailed, + DotfileNotFound, DotfileAlreadyExists, + GenericForbidden, GroupNotFound, +) +from .manager import READ_ALLOWED, server_status_required +from .types import CORSOptions, Iterable, WebMiddleware +from .utils import check_api_params + +if TYPE_CHECKING: + from .context import RootContext + +log = BraceStyleAdapter(logging.getLogger(__name__)) + + +@server_status_required(READ_ALLOWED) +@admin_required +@check_api_params(t.Dict( + { + tx.AliasedKey(['group', 'groupId', 'group_id']): tx.UUID | t.String, + t.Key('domain', default=None): t.String | t.Null, + t.Key('data'): t.String(max_length=MAXIMUM_DOTFILE_SIZE), + t.Key('path'): t.String, + t.Key('permission'): t.Regexp(r'^[0-7]{3}$', re.ASCII), + }, +)) +async def create(request: web.Request, params: Any) -> web.Response: + log.info('GROUPCONFIG.CREATE_DOTFILE (group: {0})', params['group']) + root_ctx: RootContext = request.app['_root.context'] + group_id_or_name = params['group'] + async with root_ctx.db.begin() as conn: + if isinstance(group_id_or_name, str): + if params['domain'] is None: + raise InvalidAPIParameters('Missing parameter \'domain\'') + + query = (sa.select([groups.c.id]) + .select_from(groups) + .where(groups.c.domain_name == params['domain']) + .where(groups.c.name == group_id_or_name)) + group_id = await conn.scalar(query) + domain = params['domain'] + else: + group_id = group_id_or_name + # if group UUID is given, override input domain + domain = await query_group_domain(conn, group_id) + if group_id is None or domain is None: + raise GroupNotFound + if not request['is_superadmin'] and request['user']['domain_name'] != domain: + raise GenericForbidden('Admins cannot create group dotfiles of other domains') + + dotfiles, leftover_space = await query_group_dotfiles(conn, group_id) + if dotfiles is None: + raise GroupNotFound + if leftover_space == 0: + raise DotfileCreationFailed('No leftover space for dotfile storage') + if len(dotfiles) == 100: + raise DotfileCreationFailed('Dotfile creation limit reached') + if not verify_dotfile_name(params['path']): + raise InvalidAPIParameters('dotfile path is reserved for internal operations.') + + duplicate = [x for x in dotfiles if x['path'] == params['path']] + if len(duplicate) > 0: + raise DotfileAlreadyExists + new_dotfiles = list(dotfiles) + new_dotfiles.append({'path': params['path'], 'perm': params['permission'], + 'data': params['data']}) + 
dotfile_packed = msgpack.packb(new_dotfiles) + if len(dotfile_packed) > MAXIMUM_DOTFILE_SIZE: + raise DotfileCreationFailed('No leftover space for dotfile storage') + + query = (groups.update() + .values(dotfiles=dotfile_packed) + .where(groups.c.id == group_id)) + await conn.execute(query) + return web.json_response({}) + + +@server_status_required(READ_ALLOWED) +@auth_required +@check_api_params(t.Dict({ + tx.AliasedKey(['group', 'groupId', 'group_id']): tx.UUID | t.String, + t.Key('domain', default=None): t.String | t.Null, + t.Key('path', default=None): t.Null | t.String, +})) +async def list_or_get(request: web.Request, params: Any) -> web.Response: + log.info('GROUPCONFIG.LIST_OR_GET_DOTFILE (group: {0})', params['group']) + root_ctx: RootContext = request.app['_root.context'] + resp = [] + group_id_or_name = params['group'] + async with root_ctx.db.begin() as conn: + if isinstance(group_id_or_name, str): + if params['domain'] is None: + raise InvalidAPIParameters('Missing parameter \'domain\'') + query = (sa.select([groups.c.id]) + .select_from(groups) + .where(groups.c.domain_name == params['domain']) + .where(groups.c.name == group_id_or_name)) + group_id = await conn.scalar(query) + domain = params['domain'] + else: + group_id = group_id_or_name + domain = await query_group_domain(conn, group_id) + if group_id is None or domain is None: + raise GroupNotFound + if not request['is_superadmin']: + if request['is_admin']: + if request['user']['domain_name'] != domain: + raise GenericForbidden( + 'Domain admins cannot access group dotfiles of other domains') + else: + # check if user (non-admin) is in the group + query = (sa.select([agus.c.group_id]) + .select_from(agus) + .where(agus.c.user_id == request['user']['uuid'])) + result = await conn.execute(query) + rows = result.fetchall() + if group_id not in map(lambda x: x.group_id, rows): + raise GenericForbidden( + 'Users cannot access group dotfiles of other groups') + + if params['path']: + dotfiles, _ = await query_group_dotfiles(conn, group_id) + if dotfiles is None: + raise GroupNotFound + for dotfile in dotfiles: + if dotfile['path'] == params['path']: + return web.json_response(dotfile) + raise DotfileNotFound + else: + dotfiles, _ = await query_group_dotfiles(conn, group_id) + if dotfiles is None: + raise GroupNotFound + for entry in dotfiles: + resp.append({ + 'path': entry['path'], + 'permission': entry['perm'], + 'data': entry['data'], + }) + return web.json_response(resp) + + +@server_status_required(READ_ALLOWED) +@admin_required +@check_api_params(t.Dict( + { + tx.AliasedKey(['group', 'groupId', 'group_id']): tx.UUID | t.String, + t.Key('domain', default=None): t.String | t.Null, + t.Key('data'): t.String(max_length=MAXIMUM_DOTFILE_SIZE), + t.Key('path'): t.String, + t.Key('permission'): t.Regexp(r'^[0-7]{3}$', re.ASCII), + }, +)) +async def update(request: web.Request, params: Any) -> web.Response: + log.info('GROUPCONFIG.UPDATE_DOTFILE (domain:{0})', params['domain']) + root_ctx: RootContext = request.app['_root.context'] + group_id_or_name = params['group'] + async with root_ctx.db.begin() as conn: + if isinstance(group_id_or_name, str): + if params['domain'] is None: + raise InvalidAPIParameters('Missing parameter \'domain\'') + query = (sa.select([groups.c.id]) + .select_from(groups) + .where(groups.c.domain_name == params['domain']) + .where(groups.c.name == group_id_or_name)) + group_id = await conn.scalar(query) + domain = params['domain'] + else: + group_id = group_id_or_name + domain = await 
query_group_domain(conn, group_id) + if group_id is None or domain is None: + raise GroupNotFound + if not request['is_superadmin'] and request['user']['domain_name'] != domain: + raise GenericForbidden('Admins cannot update group dotfiles of other domains') + + dotfiles, _ = await query_group_dotfiles(conn, group_id) + if dotfiles is None: + raise GroupNotFound + new_dotfiles = [x for x in dotfiles if x['path'] != params['path']] + if len(new_dotfiles) == len(dotfiles): + raise DotfileNotFound + + new_dotfiles.append({'path': params['path'], 'perm': params['permission'], + 'data': params['data']}) + dotfile_packed = msgpack.packb(new_dotfiles) + if len(dotfile_packed) > MAXIMUM_DOTFILE_SIZE: + raise DotfileCreationFailed('No leftover space for dotfile storage') + + query = (groups.update() + .values(dotfiles=dotfile_packed) + .where(groups.c.id == group_id)) + await conn.execute(query) + return web.json_response({}) + + +@server_status_required(READ_ALLOWED) +@admin_required +@check_api_params( + t.Dict({ + tx.AliasedKey(['group', 'groupId', 'group_id']): tx.UUID | t.String, + t.Key('domain', default=None): t.String | t.Null, + t.Key('path'): t.String, + }), +) +async def delete(request: web.Request, params: Any) -> web.Response: + log.info('GROUPCONFIG.DELETE_DOTFILE (domain:{0})', params['domain']) + root_ctx: RootContext = request.app['_root.context'] + group_id_or_name = params['group'] + async with root_ctx.db.begin() as conn: + if isinstance(group_id_or_name, str): + if params['domain'] is None: + raise InvalidAPIParameters('Missing parameter \'domain\'') + query = (sa.select([groups.c.id]) + .select_from(groups) + .where(groups.c.domain_name == params['domain']) + .where(groups.c.name == group_id_or_name)) + group_id = await conn.scalar(query) + domain = params['domain'] + else: + group_id = group_id_or_name + domain = await query_group_domain(conn, group_id) + if group_id is None or domain is None: + raise GroupNotFound + if not request['is_superadmin'] and request['user']['domain_name'] != domain: + raise GenericForbidden('Admins cannot delete dotfiles of other domains') + + dotfiles, _ = await query_group_dotfiles(conn, group_id) + if dotfiles is None: + raise DotfileNotFound + new_dotfiles = [x for x in dotfiles if x['path'] != params['path']] + if len(new_dotfiles) == len(dotfiles): + raise DotfileNotFound + + dotfile_packed = msgpack.packb(new_dotfiles) + query = (groups.update() + .values(dotfiles=dotfile_packed) + .where(groups.c.id == group_id)) + await conn.execute(query) + return web.json_response({'success': True}) + + +async def init(app: web.Application) -> None: + pass + + +async def shutdown(app: web.Application) -> None: + pass + + +def create_app(default_cors_options: CORSOptions) -> Tuple[web.Application, Iterable[WebMiddleware]]: + app = web.Application() + app.on_startup.append(init) + app.on_shutdown.append(shutdown) + app['api_versions'] = (4, 5) + app['prefix'] = 'group-config' + cors = aiohttp_cors.setup(app, defaults=default_cors_options) + cors.add(app.router.add_route('POST', '/dotfiles', create)) + cors.add(app.router.add_route('GET', '/dotfiles', list_or_get)) + cors.add(app.router.add_route('PATCH', '/dotfiles', update)) + cors.add(app.router.add_route('DELETE', '/dotfiles', delete)) + + return app, [] diff --git a/src/ai/backend/manager/api/image.py b/src/ai/backend/manager/api/image.py new file mode 100644 index 0000000000..5d9cf44e09 --- /dev/null +++ b/src/ai/backend/manager/api/image.py @@ -0,0 +1,472 @@ +import base64 +import secrets +from 
typing import ( + Any, + Iterable, + TYPE_CHECKING, + Tuple, +) + +from aiohttp import web +import aiohttp_cors +import jinja2 +import sqlalchemy as sa +import trafaret as t + +from ai.backend.common import validators as tx +from ai.backend.common.docker import ImageRef +from ai.backend.common.etcd import ( + quote as etcd_quote, +) +from ai.backend.common.types import ( + SessionTypes, +) + +from ..defs import DEFAULT_IMAGE_ARCH, DEFAULT_ROLE +from ..models import ( + domains, groups, query_allowed_sgroups, + association_groups_users as agus, +) +from ..types import UserScope +from .auth import admin_required +from .exceptions import InvalidAPIParameters +from .manager import ALL_ALLOWED, READ_ALLOWED, server_status_required +from .types import CORSOptions, WebMiddleware +from .utils import ( + check_api_params, +) + +if TYPE_CHECKING: + from .context import RootContext + + +DOCKERFILE_TEMPLATE = r"""# syntax = docker/dockerfile:1.0-experimental +FROM {{ src }} +MAINTAINER Backend.AI Manager + +USER root + +{% if runtime_type == 'python' -%} +ENV PYTHONUNBUFFERED=1 \ + LANG=C.UTF-8 + +RUN --mount=type=bind,source=wheelhouse,target=/root/wheelhouse \ + PIP_OPTS="--no-cache-dir --no-index --find-links=/root/wheelhouse" && \ + {{ runtime_path }} -m pip install ${PIP_OPTS} -U pip setuptools && \ + {{ runtime_path }} -m pip install ${PIP_OPTS} Pillow && \ + {{ runtime_path }} -m pip install ${PIP_OPTS} h5py && \ + {{ runtime_path }} -m pip install ${PIP_OPTS} ipython && \ + {{ runtime_path }} -m pip install ${PIP_OPTS} jupyter && \ + {{ runtime_path }} -m pip install ${PIP_OPTS} jupyterlab + +# Install ipython kernelspec +RUN {{ runtime_path }} -m ipykernel install \ + --prefix={{ runtime_path.parent.parent }} \ + --display-name "{{ brand }} on Backend.AI" +{%- endif %} + +LABEL ai.backend.kernelspec="1" \ + ai.backend.envs.corecount="{{ cpucount_envvars | join(',') }}" \ + ai.backend.features="{% if has_ipykernel %}query batch {% endif %}uid-match" \ + ai.backend.resource.min.cpu="{{ min_cpu }}" \ + ai.backend.resource.min.mem="{{ min_mem }}" \ + ai.backend.resource.preferred.shmem="{{ pref_shmem }}" \ + ai.backend.accelerators="{{ accelerators | join(',') }}" \ +{%- if 'cuda' is in accelerators %} + ai.backend.resource.min.cuda.device=1 \ + ai.backend.resource.min.cuda.shares=0.1 \ +{%- endif %} + ai.backend.base-distro="{{ base_distro }}" \ +{%- if service_ports %} + ai.backend.service-ports="{% for item in service_ports -%} + {{- item['name'] }}: + {{- item['protocol'] }}: + {%- if (item['ports'] | length) > 1 -%} + [{{ item['ports'] | join(',') }}] + {%- else -%} + {{ item['ports'][0] }} + {%- endif -%} + {{- ',' if not loop.last }} + {%- endfor %}" \ +{%- endif %} + ai.backend.runtime-type="{{ runtime_type }}" \ + ai.backend.runtime-path="{{ runtime_path }}" +""" # noqa + + +@server_status_required(READ_ALLOWED) +@admin_required +async def get_import_image_form(request: web.Request) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + async with root_ctx.db.begin() as conn: + query = ( + sa.select([groups.c.name]) + .select_from( + sa.join( + groups, domains, + groups.c.domain_name == domains.c.name, + ), + ) + .where( + (domains.c.name == request['user']['domain_name']) & + (domains.c.is_active) & + (groups.c.is_active), + ) + ) + result = await conn.execute(query) + rows = result.fetchall() + accessible_groups = [row['name'] for row in rows] + + # FIXME: Currently this only consider domain-level scaling group associations, + # thus ignoring the group name query. 
+        rows = await query_allowed_sgroups(
+            conn, request['user']['domain_name'], '', request['keypair']['access_key'],
+        )
+        accessible_scaling_groups = [row['name'] for row in rows]
+
+    return web.json_response({
+        'fieldGroups': [
+            {
+                'name': 'Import options',
+                'fields': [
+                    {
+                        'name': 'src',
+                        'type': 'string',
+                        'label': 'Source Docker image',
+                        'placeholder': 'index.docker.io/lablup/tensorflow:2.0-source',
+                        'help': 'The full Docker image name to import from. '
+                                'The registry must be accessible by the client.',
+                    },
+                    {
+                        'name': 'target',
+                        'type': 'string',
+                        'label': 'Target Docker image',
+                        'placeholder': 'index.docker.io/lablup/tensorflow:2.0-target',
+                        'help': 'The full Docker image name of the imported image. '
+                                'The registry must be accessible by the client.',
+                    },
+                    {
+                        'name': 'brand',
+                        'type': 'string',
+                        'label': 'Name of Jupyter kernel',
+                        'placeholder': 'TensorFlow 2.0',
+                        'help': 'The name of the kernel to be shown in Jupyter\'s kernel menu. '
+                                'This will be suffixed with "on Backend.AI".',
+                    },
+                    {
+                        'name': 'baseDistro',
+                        'type': 'choice',
+                        'choices': ['ubuntu', 'centos'],
+                        'default': 'ubuntu',
+                        'label': 'Base LINUX distribution',
+                        'help': 'The base Linux distribution used by the source image',
+                    },
+                    {
+                        'name': 'minCPU',
+                        'type': 'number',
+                        'min': 1,
+                        'max': None,
+                        'label': 'Minimum required CPU core(s)',
+                        'help': 'The minimum number of CPU cores required by the image',
+                    },
+                    {
+                        'name': 'minMemory',
+                        'type': 'binarysize',
+                        'min': '64m',
+                        'max': None,
+                        'label': 'Minimum required memory size',
+                        'help': 'The minimum size of the main memory required by the image',
+                    },
+                    {
+                        'name': 'preferredSharedMemory',
+                        'type': 'binarysize',
+                        'min': '64m',
+                        'max': None,
+                        'label': 'Preferred shared memory size',
+                        'help': 'The preferred (default) size of the shared memory',
+                    },
+                    {
+                        'name': 'supportedAccelerators',
+                        'type': 'multichoice[str]',
+                        'choices': ['cuda'],
+                        'default': 'cuda',
+                        'label': 'Supported accelerators',
+                        'help': 'The list of accelerators supported by the image',
+                    },
+                    {
+                        'name': 'runtimeType',
+                        'type': 'choice',
+                        'choices': ['python'],
+                        'default': 'python',
+                        'label': 'Runtime type of the image',
+                        'help': 'The runtime type of the image. '
+                                'Currently, the source image must have installed Python 2.7, 3.5, 3.6, '
+                                'or 3.7 at least to import. '
+                                'This will be used as the kernel of Jupyter service in this image.',
+                    },
+                    {
+                        'name': 'runtimePath',
+                        'type': 'string',
+                        'default': '/usr/local/bin/python',
+                        'label': 'Path of the runtime',
+                        'placeholder': '/usr/local/bin/python',
+                        'help': 'The path to the main executable of the runtime language of the image. '
+                                'Even for the same "python"-based images, this may differ significantly '
+                                'image by image. (e.g., /usr/bin/python, /usr/local/bin/python, '
+                                '/opt/something/bin/python, ...) '
+                                'Please check this carefully not to get confused with OS-default ones '
+                                'and custom-installed ones.',
+                    },
+                    {
+                        'name': 'CPUCountEnvs',
+                        'type': 'list[string]',
+                        'default': ['NPROC', 'OMP_NUM_THREADS', 'OPENBLAS_NUM_THREADS'],
+                        'label': 'CPU count environment variables',
+                        'help': 'The name of environment variables to be overridden to the number of CPU '
+                                'cores actually allocated to the container. 
Required for legacy ' + 'computation libraries.', + }, + { + 'name': 'servicePorts', + 'type': 'multichoice[template]', + 'templates': [ + {'name': 'jupyter', 'protocol': 'http', 'ports': [8080]}, + {'name': 'jupyterlab', 'protocol': 'http', 'ports': [8090]}, + {'name': 'tensorboard', 'protocol': 'http', 'ports': [6006]}, + {'name': 'digits', 'protocol': 'http', 'ports': [5000]}, + {'name': 'vscode', 'protocol': 'http', 'ports': [8180]}, + {'name': 'h2o-dai', 'protocol': 'http', 'ports': [12345]}, + ], + 'label': 'Supported service ports', + 'help': 'The list of service ports supported by this image. ' + 'Note that sshd (port 2200) and ttyd (port 7681) are intrinsic; ' + 'they are always included regardless of the source image. ' + 'The port number 2000-2003 are reserved by Backend.AI, and ' + 'all port numbers must be larger than 1024 and smaller than 65535.', + }, + ], + }, + { + 'name': 'Import Task Options', + 'help': 'The import task uses 1 CPU core and 2 GiB of memory.', + 'fields': [ + { + 'name': 'group', + 'type': 'choice', + 'choices': accessible_groups, + 'label': 'Group to build image', + 'help': 'The user group where the import task will be executed.', + }, + { + 'name': 'scalingGroup', + 'type': 'choice', + 'choices': accessible_scaling_groups, + 'label': 'Scaling group to build image', + 'help': 'The scaling group where the import task will take resources from.', + }, + ], + }, + ], + }) + + +@server_status_required(ALL_ALLOWED) +@admin_required +@check_api_params( + t.Dict({ + t.Key('src'): t.String, + t.Key('target'): t.String, + t.Key('architecture', default=DEFAULT_IMAGE_ARCH): t.String, + t.Key('launchOptions', default={}): t.Dict({ + t.Key('scalingGroup', default='default'): t.String, + t.Key('group', default='default'): t.String, + }).allow_extra('*'), + t.Key('brand'): t.String, + t.Key('baseDistro'): t.Enum('ubuntu', 'centos'), + t.Key('minCPU', default=1): t.Int[1:], + t.Key('minMemory', default='64m'): tx.BinarySize, + t.Key('preferredSharedMemory', default='64m'): tx.BinarySize, + t.Key('supportedAccelerators'): t.List(t.String), + t.Key('runtimeType'): t.Enum('python'), + t.Key('runtimePath'): tx.Path(type='file', allow_nonexisting=True, resolve=False), + t.Key('CPUCountEnvs'): t.List(t.String), + t.Key('servicePorts', default=[]): t.List(t.Dict({ + t.Key('name'): t.String, + t.Key('protocol'): t.Enum('http', 'tcp', 'pty'), + t.Key('ports'): t.List(t.Int[1:65535], min_length=1), + })), + }).allow_extra('*')) +async def import_image(request: web.Request, params: Any) -> web.Response: + """ + Import a docker image and convert it to a Backend.AI-compatible one, + by automatically installing a few packages and adding image labels. + + Currently we only support auto-conversion of Python-based kernels (e.g., + NGC images) which has its own Python version installed. + + Internally, it launches a temporary kernel in an arbitrary agent within + the client's domain, the "default" group, and the "default" scaling group. + (The client may change the group and scaling group using *launchOptions.* + If the client is a super-admin, it uses the "default" domain.) + + This temporary kernel occupies only 1 CPU core and 1 GiB memory. + The kernel concurrency limit is not applied here, but we choose an agent + based on their resource availability. + The owner of this kernel is always the client that makes the API request. + + This API returns immediately after launching the temporary kernel. + The client may check the progress of the import task using session logs. 
+ """ + + tpl = jinja2.Template(DOCKERFILE_TEMPLATE) + root_ctx: RootContext = request.app['_root.context'] + + async with root_ctx.db.begin() as conn: + query = ( + sa.select([domains.c.allowed_docker_registries]) + .select_from(domains) + .where(domains.c.name == request['user']['domain_name']) + ) + result = await conn.execute(query) + allowed_docker_registries = result.scalar() + + # TODO: select agent to run image builder based on image architecture + source_image = ImageRef(params['src'], allowed_docker_registries, params['architecture']) + target_image = ImageRef(params['target'], allowed_docker_registries, params['architecture']) + + # TODO: validate and convert arguments to template variables + dockerfile_content = tpl.render({ + 'base_distro': params['baseDistro'], + 'cpucount_envvars': ['NPROC', 'OMP_NUM_THREADS', 'OPENBLAS_NUM_THREADS'], + 'runtime_type': params['runtimeType'], + 'runtime_path': params['runtimePath'], + 'service_ports': params['servicePorts'], + 'min_cpu': params['minCPU'], + 'min_mem': params['minMemory'], + 'pref_shmem': params['preferredSharedMemory'], + 'accelerators': params['supportedAccelerators'], + 'src': params['src'], + 'brand': params['brand'], + 'has_ipykernel': True, # TODO: in the future, we may allow import of service-port only kernels. + }) + + session_creation_id = secrets.token_urlsafe(32) + session_id = f'image-import-{secrets.token_urlsafe(8)}' + access_key = request['keypair']['access_key'] + resource_policy = request['keypair']['resource_policy'] + + async with root_ctx.db.begin() as conn: + query = ( + sa.select([groups.c.id]) + .select_from( + sa.join( + groups, domains, + groups.c.domain_name == domains.c.name, + ), + ) + .where( + (domains.c.name == request['user']['domain_name']) & + (groups.c.name == params['launchOptions']['group']) & + (domains.c.is_active) & + (groups.c.is_active), + ) + ) + result = await conn.execute(query) + group_id = result.scalar() + if group_id is None: + raise InvalidAPIParameters("Invalid domain or group.") + + query = ( + sa.select([agus]) + .select_from(agus) + .where( + (agus.c.user_id == request['user']['uuid']) & + (agus.c.group_id == group_id), + ) + ) + result = await conn.execute(query) + row = result.first() + if row is None: + raise InvalidAPIParameters("You do not belong to the given group.") + + importer_image = ImageRef( + root_ctx.local_config['manager']['importer-image'], + allowed_docker_registries, + params['architecture'], + ) + + docker_creds = {} + for img_ref in (source_image, target_image): + registry_info = await root_ctx.shared_config.etcd.get_prefix_dict( + f'config/docker/registry/{etcd_quote(img_ref.registry)}') + docker_creds[img_ref.registry] = { + 'username': registry_info.get('username'), + 'password': registry_info.get('password'), + } + + kernel_id = await root_ctx.registry.enqueue_session( + session_creation_id, + session_id, + access_key, + [{ + 'image_ref': importer_image, + 'cluster_role': DEFAULT_ROLE, + 'cluster_idx': 1, + 'cluster_hostname': f"{DEFAULT_ROLE}1", + 'creation_config': { + 'resources': {'cpu': '1', 'mem': '2g'}, + 'scaling_group': params['launchOptions']['scalingGroup'], + 'environ': { + 'SRC_IMAGE': source_image.canonical, + 'TARGET_IMAGE': target_image.canonical, + 'RUNTIME_PATH': params['runtimePath'], + 'BUILD_SCRIPT': ( + base64.b64encode(dockerfile_content.encode('utf8')).decode('ascii') + ), + }, + }, + 'startup_command': '/root/build-image.sh', + 'bootstrap_script': '', + }], + None, + SessionTypes.BATCH, + resource_policy, + 
user_scope=UserScope( + domain_name=request['user']['domain_name'], + group_id=group_id, + user_uuid=request['user']['uuid'], + user_role=request['user']['role'], + ), + internal_data={ + 'domain_socket_proxies': ['/var/run/docker.sock'], + 'docker_credentials': docker_creds, + 'prevent_vfolder_mounts': True, + 'block_service_ports': True, + }, + ) + return web.json_response({ + 'importTask': { + 'sessionId': session_id, + 'taskId': str(kernel_id), + }, + }, status=200) + + +async def init(app: web.Application) -> None: + pass + + +async def shutdown(app: web.Application) -> None: + pass + + +def create_app(default_cors_options: CORSOptions) -> Tuple[web.Application, Iterable[WebMiddleware]]: + app = web.Application() + app.on_startup.append(init) + app.on_shutdown.append(shutdown) + app['prefix'] = 'image' + app['api_versions'] = (4,) + cors = aiohttp_cors.setup(app, defaults=default_cors_options) + cors.add(app.router.add_route('GET', '/import', get_import_image_form)) + cors.add(app.router.add_route('POST', '/import', import_image)) + return app, [] diff --git a/src/ai/backend/manager/api/logs.py b/src/ai/backend/manager/api/logs.py new file mode 100644 index 0000000000..7bf6b3caf7 --- /dev/null +++ b/src/ai/backend/manager/api/logs.py @@ -0,0 +1,287 @@ +from __future__ import annotations + +import datetime as dt +from datetime import datetime +import logging +import uuid +from ai.backend.common.events import EventHandler + +from aiohttp import web +import aiohttp_cors +import attr +import sqlalchemy as sa +import trafaret as t +from typing import Any, TYPE_CHECKING, Tuple, MutableMapping + +from ai.backend.common import redis, validators as tx +from ai.backend.common.distributed import GlobalTimer +from ai.backend.common.events import AbstractEvent, EmptyEventArgs +from ai.backend.common.logging import BraceStyleAdapter +from ai.backend.common.types import AgentId, LogSeverity, RedisConnectionInfo + +from ..defs import REDIS_LIVE_DB, LockID +from ..models import ( + error_logs, UserRole, groups, + association_groups_users as agus, +) +from .auth import auth_required +from .manager import READ_ALLOWED, server_status_required +from .types import CORSOptions, Iterable, WebMiddleware +from .utils import check_api_params, get_access_key_scopes + +if TYPE_CHECKING: + from .context import RootContext + +log = BraceStyleAdapter(logging.getLogger(__name__)) + + +class DoLogCleanupEvent(EmptyEventArgs, AbstractEvent): + name = "do_log_cleanup" + + +@server_status_required(READ_ALLOWED) +@auth_required +@check_api_params(t.Dict( + { + t.Key('severity'): tx.Enum(LogSeverity), + t.Key('source'): t.String, + t.Key('message'): t.String, + t.Key('context_lang'): t.String, + t.Key('context_env'): tx.JSONString, + t.Key('request_url', default=None): t.Null | t.String, + t.Key('request_status', default=None): t.Null | t.Int, + t.Key('traceback', default=None): t.Null | t.String, + }, +)) +async def append(request: web.Request, params: Any) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + params['domain'] = request['user']['domain_name'] + requester_access_key, owner_access_key = await get_access_key_scopes(request, params) + requester_uuid = request['user']['uuid'] + log.info('CREATE (ak:{0}/{1})', + requester_access_key, owner_access_key if owner_access_key != requester_access_key else '*') + + async with root_ctx.db.begin() as conn: + resp = { + 'success': True, + } + query = error_logs.insert().values({ + 'severity': params['severity'], + 'source': params['source'], + 'user': 
requester_uuid, + 'message': params['message'], + 'context_lang': params['context_lang'], + 'context_env': params['context_env'], + 'request_url': params['request_url'], + 'request_status': params['request_status'], + 'traceback': params['traceback'], + }) + result = await conn.execute(query) + assert result.rowcount == 1 + return web.json_response(resp) + + +@auth_required +@server_status_required(READ_ALLOWED) +@check_api_params( + t.Dict({ + t.Key('mark_read', default=False): t.ToBool(), + t.Key('page_size', default=20): t.ToInt(lt=101), + t.Key('page_no', default=1): t.ToInt(), + }), +) +async def list_logs(request: web.Request, params: Any) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + resp: MutableMapping[str, Any] = {'logs': []} + domain_name = request['user']['domain_name'] + user_role = request['user']['role'] + user_uuid = request['user']['uuid'] + + requester_access_key, owner_access_key = await get_access_key_scopes(request, params) + log.info('LIST (ak:{0}/{1})', + requester_access_key, owner_access_key if owner_access_key != requester_access_key else '*') + async with root_ctx.db.begin() as conn: + is_admin = True + select_query = ( + sa.select([error_logs]) + .select_from(error_logs) + .order_by(sa.desc(error_logs.c.created_at)) + .limit(params['page_size']) + ) + count_query = ( + sa.select([sa.func.count()]) + .select_from(error_logs) + ) + if params['page_no'] > 1: + select_query = select_query.offset((params['page_no'] - 1) * params['page_size']) + if request['is_superadmin']: + pass + elif user_role == UserRole.ADMIN or user_role == 'admin': + j = (groups.join(agus, groups.c.id == agus.c.group_id)) + usr_query = ( + sa.select([agus.c.user_id]) + .select_from(j) + .where(groups.c.domain_name == domain_name) + ) + result = await conn.execute(usr_query) + usrs = result.fetchall() + user_ids = [g.id for g in usrs] + where = error_logs.c.user.in_(user_ids) + select_query = select_query.where(where) + count_query = count_query.where(where) + else: + is_admin = False + where = ((error_logs.c.user == user_uuid) & + (~error_logs.c.is_cleared)) + select_query = select_query.where(where) + count_query = count_query.where(where) + + result = await conn.execute(select_query) + for row in result: + result_item = { + 'log_id': str(row['id']), + 'created_at': datetime.timestamp(row['created_at']), + 'severity': row['severity'], + 'source': row['source'], + 'user': row['user'], + 'is_read': row['is_read'], + 'message': row['message'], + 'context_lang': row['context_lang'], + 'context_env': row['context_env'], + 'request_url': row['request_url'], + 'request_status': row['request_status'], + 'traceback': row['traceback'], + } + if result_item['user'] is not None: + result_item['user'] = str(result_item['user']) + if is_admin: + result_item['is_cleared'] = row['is_cleared'] + resp['logs'].append(result_item) + resp['count'] = await conn.scalar(count_query) + if params['mark_read']: + read_update_query = ( + sa.update(error_logs) + .values(is_read=True) + .where(error_logs.c.id.in_([x['log_id'] for x in resp['logs']])) + ) + await conn.execute(read_update_query) + return web.json_response(resp, status=200) + + +@auth_required +@server_status_required(READ_ALLOWED) +async def mark_cleared(request: web.Request) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + domain_name = request['user']['domain_name'] + user_role = request['user']['role'] + user_uuid = request['user']['uuid'] + log_id = uuid.UUID(request.match_info['log_id']) + + 
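+    # The UPDATE below is scoped by privilege level:
+    #  - superadmins may clear any log by its id,
+    #  - domain admins may clear logs of users belonging to groups in their domain,
+    #  - regular users may clear only their own logs.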
log.info('CLEAR') + async with root_ctx.db.begin() as conn: + update_query = ( + sa.update(error_logs) + .values(is_cleared=True) + ) + if request['is_superadmin']: + update_query = update_query.where(error_logs.c.id == log_id) + elif user_role == UserRole.ADMIN or user_role == 'admin': + j = (groups.join(agus, groups.c.id == agus.c.group_id)) + usr_query = ( + sa.select([agus.c.user_id]) + .select_from(j) + .where(groups.c.domain_name == domain_name) + ) + result = await conn.execute(usr_query) + usrs = result.fetchall() + user_ids = [g.id for g in usrs] + update_query = update_query.where( + (error_logs.c.user.in_(user_ids)) & + (error_logs.c.id == log_id), + ) + else: + update_query = update_query.where( + (error_logs.c.user == user_uuid) & + (error_logs.c.id == log_id), + ) + + result = await conn.execute(update_query) + assert result.rowcount == 1 + + return web.json_response({'success': True}, status=200) + + +async def log_cleanup_task(app: web.Application, src: AgentId, event: DoLogCleanupEvent) -> None: + root_ctx: RootContext = app['_root.context'] + etcd = root_ctx.shared_config.etcd + raw_lifetime = await etcd.get('config/logs/error/retention') + if raw_lifetime is None: + raw_lifetime = '90d' + try: + lifetime = tx.TimeDuration().check(raw_lifetime) + except ValueError: + lifetime = dt.timedelta(days=90) + log.warning( + "Failed to parse the error log retention period ({}) read from etcd; " + "falling back to 90 days", + raw_lifetime, + ) + boundary = datetime.now() - lifetime + async with root_ctx.db.begin() as conn: + query = ( + sa.delete(error_logs) + .where(error_logs.c.created_at < boundary) + ) + result = await conn.execute(query) + if result.rowcount > 0: + log.info('Cleaned up {} log(s) filed before {}', result.rowcount, boundary) + + +@attr.s(slots=True, auto_attribs=True, init=False) +class PrivateContext: + log_cleanup_timer: GlobalTimer + log_cleanup_timer_redis: RedisConnectionInfo + log_cleanup_timer_evh: EventHandler[web.Application, DoLogCleanupEvent] + + +async def init(app: web.Application) -> None: + root_ctx: RootContext = app['_root.context'] + app_ctx: PrivateContext = app['logs.context'] + app_ctx.log_cleanup_timer_evh = root_ctx.event_dispatcher.consume( + DoLogCleanupEvent, app, log_cleanup_task, + ) + app_ctx.log_cleanup_timer_redis = redis.get_redis_object( + root_ctx.shared_config.data['redis'], + db=REDIS_LIVE_DB, + ) + app_ctx.log_cleanup_timer = GlobalTimer( + root_ctx.distributed_lock_factory(LockID.LOCKID_LOG_CLEANUP_TIMER, 20.0), + root_ctx.event_producer, + lambda: DoLogCleanupEvent(), + 20.0, + initial_delay=17.0, + ) + await app_ctx.log_cleanup_timer.join() + + +async def shutdown(app: web.Application) -> None: + root_ctx: RootContext = app['_root.context'] + app_ctx: PrivateContext = app['logs.context'] + await app_ctx.log_cleanup_timer.leave() + root_ctx.event_dispatcher.unconsume(app_ctx.log_cleanup_timer_evh) + await app_ctx.log_cleanup_timer_redis.close() + + +def create_app(default_cors_options: CORSOptions) -> Tuple[web.Application, Iterable[WebMiddleware]]: + app = web.Application() + app.on_startup.append(init) + app.on_shutdown.append(shutdown) + app['api_versions'] = (4, 5) + app['prefix'] = 'logs/error' + app['logs.context'] = PrivateContext() + cors = aiohttp_cors.setup(app, defaults=default_cors_options) + cors.add(app.router.add_route('POST', '', append)) + cors.add(app.router.add_route('GET', '', list_logs)) + cors.add(app.router.add_route('POST', r'/{log_id}/clear', mark_cleared)) + + return app, [] diff --git 
a/src/ai/backend/manager/api/manager.py b/src/ai/backend/manager/api/manager.py new file mode 100644 index 0000000000..2986c44531 --- /dev/null +++ b/src/ai/backend/manager/api/manager.py @@ -0,0 +1,272 @@ +from __future__ import annotations + +import asyncio +import enum +import functools +import json +import logging +import socket +import sqlalchemy as sa +import trafaret as t +from typing import ( + Any, + Final, + FrozenSet, + Iterable, + Tuple, + TYPE_CHECKING, +) + +from aiohttp import web +import aiohttp_cors +from aiotools import aclosing +import attr +import graphene + +from ai.backend.common import validators as tx +from ai.backend.common.events import DoScheduleEvent +from ai.backend.common.logging import BraceStyleAdapter + +from .. import __version__ +from ..defs import DEFAULT_ROLE +from ..models import agents, kernels, AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES +from . import ManagerStatus +from .auth import superadmin_required +from .exceptions import ( + InstanceNotFound, + InvalidAPIParameters, + GenericBadRequest, + ServerFrozen, + ServiceUnavailable, +) +from .types import CORSOptions, WebMiddleware +from .utils import check_api_params + +if TYPE_CHECKING: + from .context import RootContext + from ai.backend.manager.models.gql import GraphQueryContext + +log = BraceStyleAdapter(logging.getLogger(__name__)) + + +class SchedulerOps(enum.Enum): + INCLUDE_AGENTS = 'include-agents' + EXCLUDE_AGENTS = 'exclude-agents' + + +def server_status_required(allowed_status: FrozenSet[ManagerStatus]): + + def decorator(handler): + + @functools.wraps(handler) + async def wrapped(request, *args, **kwargs): + root_ctx: RootContext = request.app['_root.context'] + status = await root_ctx.shared_config.get_manager_status() + if status not in allowed_status: + if status == ManagerStatus.FROZEN: + raise ServerFrozen + msg = f'Server is not in the required status: {allowed_status}' + raise ServiceUnavailable(msg) + return (await handler(request, *args, **kwargs)) + + return wrapped + + return decorator + + +READ_ALLOWED: Final = frozenset({ManagerStatus.RUNNING, ManagerStatus.FROZEN}) +ALL_ALLOWED: Final = frozenset({ManagerStatus.RUNNING}) + + +class GQLMutationUnfrozenRequiredMiddleware: + + def resolve(self, next, root, info: graphene.ResolveInfo, **args) -> Any: + graph_ctx: GraphQueryContext = info.context + if info.operation.operation == 'mutation' and \ + graph_ctx.manager_status == ManagerStatus.FROZEN: + raise ServerFrozen + return next(root, info, **args) + + +async def detect_status_update(root_ctx: RootContext) -> None: + try: + async with aclosing(root_ctx.shared_config.watch_manager_status()) as agen: + async for ev in agen: + if ev.event == 'put': + root_ctx.shared_config.get_manager_status.cache_clear() + updated_status = await root_ctx.shared_config.get_manager_status() + log.debug('Process-{0} detected manager status update: {1}', + root_ctx.pidx, updated_status) + except asyncio.CancelledError: + pass + + +async def fetch_manager_status(request: web.Request) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + log.info('MANAGER.FETCH_MANAGER_STATUS ()') + try: + status = await root_ctx.shared_config.get_manager_status() + # etcd_info = await root_ctx.shared_config.get_manager_nodes_info() + configs = root_ctx.local_config['manager'] + + async with root_ctx.db.begin() as conn: + query = ( + sa.select([sa.func.count()]) + .select_from(kernels) + .where( + (kernels.c.cluster_role == DEFAULT_ROLE) & + 
(kernels.c.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES)), + ) + ) + active_sessions_num = await conn.scalar(query) + + _id = configs['id'] if configs.get('id') else socket.gethostname() + nodes = [ + { + 'id': _id, + 'num_proc': configs['num-proc'], + 'service_addr': str(configs['service-addr']), + 'heartbeat_timeout': configs['heartbeat-timeout'], + 'ssl_enabled': configs['ssl-enabled'], + 'active_sessions': active_sessions_num, + 'status': status.value, + 'version': __version__, + 'api_version': request['api_version'], + }, + ] + return web.json_response({ + 'nodes': nodes, + 'status': status.value, # legacy? + 'active_sessions': active_sessions_num, # legacy? + }) + except: + log.exception('GET_MANAGER_STATUS: exception') + raise + + +@superadmin_required +@check_api_params( + t.Dict({ + t.Key('status'): tx.Enum(ManagerStatus, use_name=True), + t.Key('force_kill', default=False): t.ToBool, + })) +async def update_manager_status(request: web.Request, params: Any) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + log.info('MANAGER.UPDATE_MANAGER_STATUS (status:{}, force_kill:{})', + params['status'], params['force_kill']) + try: + params = await request.json() + status = params['status'] + force_kill = params['force_kill'] + except json.JSONDecodeError: + raise InvalidAPIParameters(extra_msg='No request body!') + except (AssertionError, ValueError) as e: + raise InvalidAPIParameters(extra_msg=str(e.args[0])) + + if force_kill: + await root_ctx.registry.kill_all_sessions() + await root_ctx.shared_config.update_manager_status(status) + + return web.Response(status=204) + + +async def get_announcement(request: web.Request) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + data = await root_ctx.shared_config.etcd.get('manager/announcement') + if data is None: + ret = {'enabled': False, 'message': ''} + else: + ret = {'enabled': True, 'message': data} + return web.json_response(ret) + + +@superadmin_required +@check_api_params( + t.Dict({ + t.Key('enabled', default='false'): t.ToBool, + t.Key('message', default=None): t.Null | t.String, + })) +async def update_announcement(request: web.Request, params: Any) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + if params['enabled']: + if not params['message']: + raise InvalidAPIParameters(extra_msg='Empty message not allowed to enable announcement') + await root_ctx.shared_config.etcd.put('manager/announcement', params['message']) + else: + await root_ctx.shared_config.etcd.delete('manager/announcement') + return web.Response(status=204) + + +iv_scheduler_ops_args = { + SchedulerOps.INCLUDE_AGENTS: t.List(t.String), + SchedulerOps.EXCLUDE_AGENTS: t.List(t.String), +} + + +@superadmin_required +@check_api_params( + t.Dict({ + t.Key('op'): tx.Enum(SchedulerOps), + t.Key('args'): t.Any, + })) +async def perform_scheduler_ops(request: web.Request, params: Any) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + try: + args = iv_scheduler_ops_args[params['op']].check(params['args']) + except t.DataError as e: + raise InvalidAPIParameters( + f"Input validation failed for args with {params['op']}", + extra_data=e.as_dict(), + ) + if params['op'] in (SchedulerOps.INCLUDE_AGENTS, SchedulerOps.EXCLUDE_AGENTS): + schedulable = (params['op'] == SchedulerOps.INCLUDE_AGENTS) + async with root_ctx.db.begin() as conn: + query = ( + agents.update() + .values(schedulable=schedulable) + .where(agents.c.id.in_(args)) + ) + result = await conn.execute(query) + if 
result.rowcount < len(args): + raise InstanceNotFound() + if schedulable: + # trigger scheduler + await root_ctx.event_producer.produce_event(DoScheduleEvent()) + else: + raise GenericBadRequest('Unknown scheduler operation') + return web.Response(status=204) + + +@attr.s(slots=True, auto_attribs=True, init=False) +class PrivateContext: + status_watch_task: asyncio.Task + + +async def init(app: web.Application) -> None: + root_ctx: RootContext = app['_root.context'] + app_ctx: PrivateContext = app['manager.context'] + app_ctx.status_watch_task = asyncio.create_task(detect_status_update(root_ctx)) + + +async def shutdown(app: web.Application) -> None: + app_ctx: PrivateContext = app['manager.context'] + if app_ctx.status_watch_task is not None: + app_ctx.status_watch_task.cancel() + await app_ctx.status_watch_task + + +def create_app(default_cors_options: CORSOptions) -> Tuple[web.Application, Iterable[WebMiddleware]]: + app = web.Application() + app['api_versions'] = (2, 3, 4) + app['manager.context'] = PrivateContext() + cors = aiohttp_cors.setup(app, defaults=default_cors_options) + status_resource = cors.add(app.router.add_resource('/status')) + cors.add(status_resource.add_route('GET', fetch_manager_status)) + cors.add(status_resource.add_route('PUT', update_manager_status)) + announcement_resource = cors.add(app.router.add_resource('/announcement')) + cors.add(announcement_resource.add_route('GET', get_announcement)) + cors.add(announcement_resource.add_route('POST', update_announcement)) + cors.add(app.router.add_route('POST', '/scheduler/operation', perform_scheduler_ops)) + app.on_startup.append(init) + app.on_shutdown.append(shutdown) + return app, [] diff --git a/src/ai/backend/manager/api/py.typed b/src/ai/backend/manager/api/py.typed new file mode 100644 index 0000000000..5abed26af8 --- /dev/null +++ b/src/ai/backend/manager/api/py.typed @@ -0,0 +1 @@ +marker diff --git a/src/ai/backend/manager/api/ratelimit.py b/src/ai/backend/manager/api/ratelimit.py new file mode 100644 index 0000000000..4b45200bcb --- /dev/null +++ b/src/ai/backend/manager/api/ratelimit.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +from decimal import Decimal +import logging +import time +from typing import ( + Iterable, + Final, + Tuple, +) + +from aiohttp import web +from aiotools import apartial +import attr + +from ai.backend.common import redis +from ai.backend.common.logging import BraceStyleAdapter +from ai.backend.common.types import RedisConnectionInfo + +from ..defs import REDIS_RLIM_DB +from .context import RootContext +from .exceptions import RateLimitExceeded +from .types import CORSOptions, WebRequestHandler, WebMiddleware + +log = BraceStyleAdapter(logging.getLogger(__name__)) + +_time_prec: Final = Decimal('1e-3') # msec +_rlim_window: Final = 60 * 15 + +# We implement rate limiting using a rolling counter, which prevents +# last-minute and first-minute bursts between the intervals. 
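As a rough picture of what the Lua script below does atomically inside Redis, the following stdlib-only sketch implements the same rolling-window counter in a single process. The class name and the in-memory dict are illustrative only; the production middleware keeps this state in Redis so that every manager process shares one counter per access key.

    import time
    from collections import defaultdict, deque

    class RollingCounter:
        """Single-process sketch of a rolling-window request counter."""

        def __init__(self, window: float = 60 * 15) -> None:
            self.window = window
            self._hits: dict[str, deque[float]] = defaultdict(deque)

        def hit(self, access_key: str) -> int:
            now = time.monotonic()
            q = self._hits[access_key]
            # Drop requests that fell out of the window (ZREMRANGEBYSCORE in the script).
            while q and q[0] <= now - self.window:
                q.popleft()
            # Record this request (ZADD) and return the rolling count (ZCARD).
            q.append(now)
            return len(q)

A rolling count above the keypair's rate_limit corresponds to the RateLimitExceeded raised by the middleware defined below.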
+ +_rlim_script = ''' +local access_key = KEYS[1] +local now = tonumber(ARGV[1]) +local window = tonumber(ARGV[2]) +local request_id = tonumber(redis.call('INCR', '__request_id')) +if request_id >= 1e12 then + redis.call('SET', '__request_id', 1) +end +if redis.call('EXISTS', access_key) == 1 then + redis.call('ZREMRANGEBYSCORE', access_key, 0, now - window) +end +redis.call('ZADD', access_key, now, tostring(request_id)) +redis.call('EXPIRE', access_key, window) +return redis.call('ZCARD', access_key) +''' + + +@web.middleware +async def rlim_middleware( + app: web.Application, + request: web.Request, + handler: WebRequestHandler, +) -> web.StreamResponse: + # This is a global middleware: request.app is the root app. + app_ctx: PrivateContext = app['ratelimit.context'] + now = Decimal(time.time()).quantize(_time_prec) + rr = app_ctx.redis_rlim + if request['is_authorized']: + rate_limit = request['keypair']['rate_limit'] + access_key = request['keypair']['access_key'] + ret = await redis.execute_script( + rr, 'ratelimit', _rlim_script, + [access_key], + [str(now), str(_rlim_window)], + ) + if ret is None: + remaining = rate_limit + else: + rolling_count = int(ret) + if rolling_count > rate_limit: + raise RateLimitExceeded + remaining = rate_limit - rolling_count + response = await handler(request) + response.headers['X-RateLimit-Limit'] = str(rate_limit) + response.headers['X-RateLimit-Remaining'] = str(remaining) + response.headers['X-RateLimit-Window'] = str(_rlim_window) + return response + else: + # No checks for rate limiting for non-authorized queries. + response = await handler(request) + response.headers['X-RateLimit-Limit'] = '1000' + response.headers['X-RateLimit-Remaining'] = '1000' + response.headers['X-RateLimit-Window'] = str(_rlim_window) + return response + + +@attr.s(slots=True, auto_attribs=True, init=False) +class PrivateContext: + redis_rlim: RedisConnectionInfo + redis_rlim_script: str + + +async def init(app: web.Application) -> None: + root_ctx: RootContext = app['_root.context'] + app_ctx: PrivateContext = app['ratelimit.context'] + app_ctx.redis_rlim = redis.get_redis_object(root_ctx.shared_config.data['redis'], db=REDIS_RLIM_DB) + app_ctx.redis_rlim_script = \ + await redis.execute(app_ctx.redis_rlim, lambda r: r.script_load(_rlim_script)) + + +async def shutdown(app: web.Application) -> None: + app_ctx: PrivateContext = app['ratelimit.context'] + await redis.execute(app_ctx.redis_rlim, lambda r: r.flushdb()) + await app_ctx.redis_rlim.close() + + +def create_app(default_cors_options: CORSOptions) -> Tuple[web.Application, Iterable[WebMiddleware]]: + app = web.Application() + app['api_versions'] = (1, 2, 3, 4) + app['ratelimit.context'] = PrivateContext() + app.on_startup.append(init) + app.on_shutdown.append(shutdown) + # middleware must be wrapped by web.middleware at the outermost level. + return app, [web.middleware(apartial(rlim_middleware, app))] diff --git a/src/ai/backend/manager/api/resource.py b/src/ai/backend/manager/api/resource.py new file mode 100644 index 0000000000..b1ea40376d --- /dev/null +++ b/src/ai/backend/manager/api/resource.py @@ -0,0 +1,796 @@ +""" +Resource preset APIs. 
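+
+Besides the presets themselves, this module also serves resource usage
+statistics and agent-watcher control endpoints.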
+""" + +import copy +from datetime import datetime, timedelta +from dateutil.relativedelta import relativedelta +from decimal import Decimal +import functools +import json +import logging +import re +from typing import ( + Any, + Iterable, + TYPE_CHECKING, + Tuple, + MutableMapping, +) + +import aiohttp +from aiohttp import web +import aiohttp_cors +from aioredis import Redis +from aioredis.client import Pipeline as RedisPipeline +from async_timeout import timeout as _timeout +from dateutil.tz import tzutc +import msgpack +import sqlalchemy as sa +import trafaret as t +import yarl + +from ai.backend.common import redis, validators as tx +from ai.backend.common.logging import BraceStyleAdapter +from ai.backend.common.utils import nmget +from ai.backend.common.types import DefaultForUnspecified, ResourceSlot + +from ..models import ( + agents, resource_presets, + domains, groups, kernels, users, + AgentStatus, + association_groups_users, + query_allowed_sgroups, + AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES, + RESOURCE_USAGE_KERNEL_STATUSES, LIVE_STATUS, +) +from .auth import auth_required, superadmin_required +from .exceptions import ( + InvalidAPIParameters, +) +from .manager import READ_ALLOWED, server_status_required +from .types import CORSOptions, WebMiddleware +from .utils import check_api_params + +if TYPE_CHECKING: + from .context import RootContext + +log = BraceStyleAdapter(logging.getLogger(__name__)) + +_json_loads = functools.partial(json.loads, parse_float=Decimal) + + +@auth_required +async def list_presets(request: web.Request) -> web.Response: + """ + Returns the list of all resource presets. + """ + log.info('LIST_PRESETS (ak:{})', request['keypair']['access_key']) + root_ctx: RootContext = request.app['_root.context'] + await root_ctx.shared_config.get_resource_slots() + async with root_ctx.db.begin_readonly() as conn: + query = ( + sa.select([resource_presets]) + .select_from(resource_presets) + ) + # TODO: uncomment when we implement scaling group. + # scaling_group = request.query.get('scaling_group') + # if scaling_group is not None: + # query = query.where(resource_presets.c.scaling_group == scaling_group) + resp: MutableMapping[str, Any] = {'presets': []} + async for row in (await conn.stream(query)): + preset_slots = row['resource_slots'].normalize_slots(ignore_unknown=True) + resp['presets'].append({ + 'name': row['name'], + 'shared_memory': str(row['shared_memory']) if row['shared_memory'] else None, + 'resource_slots': preset_slots.to_json(), + }) + return web.json_response(resp, status=200) + + +@server_status_required(READ_ALLOWED) +@auth_required +@check_api_params( + t.Dict({ + t.Key('scaling_group', default=None): t.Null | t.String, + t.Key('group', default='default'): t.String, + })) +async def check_presets(request: web.Request, params: Any) -> web.Response: + """ + Returns the list of all resource presets in the current scaling group, + with additional information including allocatability of each preset, + amount of total remaining resources, and the current keypair resource limits. + """ + root_ctx: RootContext = request.app['_root.context'] + try: + access_key = request['keypair']['access_key'] + resource_policy = request['keypair']['resource_policy'] + domain_name = request['user']['domain_name'] + # TODO: uncomment when we implement scaling group. + # scaling_group = request.query.get('scaling_group') + # assert scaling_group is not None, 'scaling_group parameter is missing.' 
+ except (json.decoder.JSONDecodeError, AssertionError) as e: + raise InvalidAPIParameters(extra_msg=str(e.args[0])) + known_slot_types = await root_ctx.shared_config.get_resource_slots() + resp: MutableMapping[str, Any] = { + 'keypair_limits': None, + 'keypair_using': None, + 'keypair_remaining': None, + 'scaling_group_remaining': None, + 'scaling_groups': None, + 'presets': [], + } + log.info('CHECK_PRESETS (ak:{}, g:{}, sg:{})', + request['keypair']['access_key'], params['group'], params['scaling_group']) + + async with root_ctx.db.begin_readonly() as conn: + # Check keypair resource limit. + keypair_limits = ResourceSlot.from_policy(resource_policy, known_slot_types) + keypair_occupied = await root_ctx.registry.get_keypair_occupancy(access_key, conn=conn) + keypair_remaining = keypair_limits - keypair_occupied + + # Check group resource limit and get group_id. + j = sa.join( + groups, association_groups_users, + association_groups_users.c.group_id == groups.c.id, + ) + query = ( + sa.select([groups.c.id, groups.c.total_resource_slots]) + .select_from(j) + .where( + (association_groups_users.c.user_id == request['user']['uuid']) & + (groups.c.name == params['group']) & + (domains.c.name == domain_name), + ) + ) + result = await conn.execute(query) + row = result.first() + group_id = row['id'] + group_resource_slots = row['total_resource_slots'] + if group_id is None: + raise InvalidAPIParameters('Unknown user group') + group_resource_policy = { + 'total_resource_slots': group_resource_slots, + 'default_for_unspecified': DefaultForUnspecified.UNLIMITED, + } + group_limits = ResourceSlot.from_policy(group_resource_policy, known_slot_types) + group_occupied = await root_ctx.registry.get_group_occupancy(group_id, conn=conn) + group_remaining = group_limits - group_occupied + + # Check domain resource limit. + query = (sa.select([domains.c.total_resource_slots]) + .where(domains.c.name == domain_name)) + domain_resource_slots = await conn.scalar(query) + domain_resource_policy = { + 'total_resource_slots': domain_resource_slots, + 'default_for_unspecified': DefaultForUnspecified.UNLIMITED, + } + domain_limits = ResourceSlot.from_policy(domain_resource_policy, known_slot_types) + domain_occupied = await root_ctx.registry.get_domain_occupancy(domain_name, conn=conn) + domain_remaining = domain_limits - domain_occupied + + # Take minimum remaining resources. There's no need to merge limits and occupied. + # To keep legacy, we just merge all remaining slots into `keypair_remainig`. + for slot in known_slot_types: + keypair_remaining[slot] = min( + keypair_remaining[slot], + group_remaining[slot], + domain_remaining[slot], + ) + + # Prepare per scaling group resource. + sgroups = await query_allowed_sgroups(conn, domain_name, group_id, access_key) + sgroup_names = [sg.name for sg in sgroups] + if params['scaling_group'] is not None: + if params['scaling_group'] not in sgroup_names: + raise InvalidAPIParameters('Unknown scaling group') + sgroup_names = [params['scaling_group']] + per_sgroup = { + sgname: { + 'using': ResourceSlot({k: Decimal(0) for k in known_slot_types.keys()}), + 'remaining': ResourceSlot({k: Decimal(0) for k in known_slot_types.keys()}), + } for sgname in sgroup_names + } + + # Per scaling group resource using from resource occupying kernels. 
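An aside before the per-scaling-group usage query that the comment above introduces: the remaining-resource merge a few lines earlier can be pictured with plain Decimal dicts standing in for ResourceSlot. This is only a sketch of the idea, not the production type; each pool's remainder is its limit minus its occupancy, and the figure finally reported to the client is the element-wise minimum across the keypair, group, and domain pools.

    from decimal import Decimal

    Slots = dict[str, Decimal]

    def remaining(limits: Slots, occupied: Slots) -> Slots:
        # limits - occupied, slot by slot (missing slots count as zero usage)
        return {slot: limits[slot] - occupied.get(slot, Decimal(0)) for slot in limits}

    def effective_remaining(*pools: Slots) -> Slots:
        # element-wise minimum across the keypair / group / domain pools
        all_slots = set().union(*pools)
        return {
            slot: min(pool.get(slot, Decimal('Infinity')) for pool in pools)
            for slot in all_slots
        }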
+ query = ( + sa.select([kernels.c.occupied_slots, kernels.c.scaling_group]) + .select_from(kernels) + .where( + (kernels.c.user_uuid == request['user']['uuid']) & + (kernels.c.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES)) & + (kernels.c.scaling_group.in_(sgroup_names)), + ) + ) + async for row in (await conn.stream(query)): + per_sgroup[row['scaling_group']]['using'] += row['occupied_slots'] + + # Per scaling group resource remaining from agents stats. + sgroup_remaining = ResourceSlot({k: Decimal(0) for k in known_slot_types.keys()}) + query = ( + sa.select([agents.c.available_slots, agents.c.occupied_slots, agents.c.scaling_group]) + .select_from(agents) + .where( + (agents.c.status == AgentStatus.ALIVE) & + (agents.c.scaling_group.in_(sgroup_names)), + ) + ) + agent_slots = [] + async for row in (await conn.stream(query)): + remaining = row['available_slots'] - row['occupied_slots'] + remaining += ResourceSlot({k: Decimal(0) for k in known_slot_types.keys()}) + sgroup_remaining += remaining + agent_slots.append(remaining) + per_sgroup[row['scaling_group']]['remaining'] += remaining + + # Take maximum allocatable resources per sgroup. + for sgname, sgfields in per_sgroup.items(): + for rtype, slots in sgfields.items(): + if rtype == 'remaining': + for slot in known_slot_types.keys(): + if slot in slots: + slots[slot] = min(keypair_remaining[slot], slots[slot]) + per_sgroup[sgname][rtype] = slots.to_json() # type: ignore # it's serialization + for slot in known_slot_types.keys(): + sgroup_remaining[slot] = min(keypair_remaining[slot], sgroup_remaining[slot]) + + # Fetch all resource presets in the current scaling group. + query = ( + sa.select([resource_presets]) + .select_from(resource_presets) + ) + async for row in (await conn.stream(query)): + # Check if there are any agent that can allocate each preset. + allocatable = False + preset_slots = row['resource_slots'].normalize_slots(ignore_unknown=True) + for agent_slot in agent_slots: + if agent_slot >= preset_slots and keypair_remaining >= preset_slots: + allocatable = True + break + resp['presets'].append({ + 'name': row['name'], + 'resource_slots': preset_slots.to_json(), + 'shared_memory': str(row['shared_memory']) if row['shared_memory'] is not None else None, + 'allocatable': allocatable, + }) + + # Return group resource status as NaN if not allowed. + group_resource_visibility = \ + await root_ctx.shared_config.get_raw('config/api/resources/group_resource_visibility') + group_resource_visibility = t.ToBool().check(group_resource_visibility) + if not group_resource_visibility: + group_limits = ResourceSlot({k: Decimal('NaN') for k in known_slot_types.keys()}) + group_occupied = ResourceSlot({k: Decimal('NaN') for k in known_slot_types.keys()}) + group_remaining = ResourceSlot({k: Decimal('NaN') for k in known_slot_types.keys()}) + + resp['keypair_limits'] = keypair_limits.to_json() + resp['keypair_using'] = keypair_occupied.to_json() + resp['keypair_remaining'] = keypair_remaining.to_json() + resp['group_limits'] = group_limits.to_json() + resp['group_using'] = group_occupied.to_json() + resp['group_remaining'] = group_remaining.to_json() + resp['scaling_group_remaining'] = sgroup_remaining.to_json() + resp['scaling_groups'] = per_sgroup + return web.json_response(resp, status=200) + + +@server_status_required(READ_ALLOWED) +@superadmin_required +async def recalculate_usage(request: web.Request) -> web.Response: + """ + Update `keypair_resource_usages` in redis and `agents.c.occupied_slots`. 
+ + Those two values are sometimes out of sync. In that case, calling this API + re-calculates the values for running containers and updates them in DB. + """ + log.info('RECALCULATE_USAGE ()') + root_ctx: RootContext = request.app['_root.context'] + await root_ctx.registry.recalc_resource_usage() + return web.json_response({}, status=200) + + +async def get_container_stats_for_period(request: web.Request, start_date, end_date, group_ids=None): + root_ctx: RootContext = request.app['_root.context'] + async with root_ctx.db.begin_readonly() as conn: + j = ( + kernels + .join(groups, groups.c.id == kernels.c.group_id) + .join(users, users.c.uuid == kernels.c.user_uuid) + ) + query = ( + sa.select([ + kernels.c.id, + kernels.c.container_id, + kernels.c.session_name, + kernels.c.access_key, + kernels.c.agent, + kernels.c.domain_name, + kernels.c.group_id, + kernels.c.attached_devices, + kernels.c.occupied_slots, + kernels.c.resource_opts, + kernels.c.vfolder_mounts, + kernels.c.mounts, + kernels.c.image, + kernels.c.status, + kernels.c.status_changed, + kernels.c.last_stat, + kernels.c.created_at, + kernels.c.terminated_at, + groups.c.name, + users.c.email, + ]) + .select_from(j) + .where( + # Filter sessions which existence period overlaps with requested period + ((kernels.c.terminated_at >= start_date) & (kernels.c.created_at < end_date) & + (kernels.c.status.in_(RESOURCE_USAGE_KERNEL_STATUSES))) | + # Or, filter running sessions which created before requested end_date + ((kernels.c.created_at < end_date) & (kernels.c.status.in_(LIVE_STATUS))), + ) + .order_by(sa.asc(kernels.c.terminated_at)) + ) + if group_ids: + query = query.where(kernels.c.group_id.in_(group_ids)) + result = await conn.execute(query) + rows = result.fetchall() + + def _pipe_builder(r: Redis) -> RedisPipeline: + pipe = r.pipeline() + for row in rows: + pipe.get(str(row['id'])) + return pipe + + raw_stats = await redis.execute(root_ctx.redis_stat, _pipe_builder) + + objs_per_group = {} + local_tz = root_ctx.shared_config['system']['timezone'] + + for row, raw_stat in zip(rows, raw_stats): + group_id = str(row['group_id']) + last_stat = row['last_stat'] + if not last_stat: + if raw_stat is None: + log.warn('stat object for {} not found on redis, skipping', str(row['id'])) + continue + last_stat = msgpack.unpackb(raw_stat) + nfs = None + if row['vfolder_mounts']: + # For >=22.03, return used host directories instead of volume host, which is not so useful. + nfs = list(set([str(mount.host_path) for mount in row['vfolder_mounts']])) + elif row['mounts'] and isinstance(row['mounts'][0], list): + # For the kernel records that have legacy contents of `mounts`. 
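+            # Each legacy entry is itself a list; its third element (presumably
+            # the host-side location of the mount) is collected and de-duplicated.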
+ nfs = list(set([mount[2] for mount in row['mounts']])) + if row['terminated_at'] is None: + used_time = used_days = None + else: + used_time = str(row['terminated_at'] - row['created_at']) + used_days = (row['terminated_at'].astimezone(local_tz).toordinal() - + row['created_at'].astimezone(local_tz).toordinal() + 1) + device_type = set() + smp = 0 + gpu_mem_allocated = 0 + if row.attached_devices and row.attached_devices.get('cuda'): + for dev_info in row.attached_devices['cuda']: + if dev_info.get('model_name'): + device_type.add(dev_info['model_name']) + smp += int(nmget(dev_info, 'data.smp', 0)) + gpu_mem_allocated += int(nmget(dev_info, 'data.mem', 0)) + gpu_allocated = 0 + if 'cuda.devices' in row.occupied_slots: + gpu_allocated = row.occupied_slots['cuda.devices'] + if 'cuda.shares' in row.occupied_slots: + gpu_allocated = row.occupied_slots['cuda.shares'] + c_info = { + 'id': str(row['id']), + 'container_id': row['container_id'], + 'domain_name': row['domain_name'], + 'group_id': str(row['group_id']), + 'group_name': row['name'], + 'name': row['session_name'], + 'access_key': row['access_key'], + 'email': row['email'], + 'agent': row['agent'], + 'cpu_allocated': float(row.occupied_slots.get('cpu', 0)), + 'cpu_used': float(nmget(last_stat, 'cpu_used.current', 0)), + 'mem_allocated': int(row.occupied_slots.get('mem', 0)), + 'mem_used': int(nmget(last_stat, 'mem.capacity', 0)), + 'shared_memory': int(nmget(row.resource_opts, 'shmem', 0)), + 'disk_allocated': 0, # TODO: disk quota limit + 'disk_used': (int(nmget(last_stat, 'io_scratch_size/stats.max', 0, '/'))), + 'io_read': int(nmget(last_stat, 'io_read.current', 0)), + 'io_write': int(nmget(last_stat, 'io_write.current', 0)), + 'used_time': used_time, + 'used_days': used_days, + 'device_type': list(device_type), + 'smp': float(smp), + 'gpu_mem_allocated': float(gpu_mem_allocated), + 'gpu_allocated': float(gpu_allocated), # devices or shares + 'nfs': nfs, + 'image_id': row['image'], # TODO: image id + 'image_name': row['image'], + 'created_at': str(row['created_at']), + 'terminated_at': str(row['terminated_at']), + 'status': row['status'].name, + 'status_changed': str(row['status_changed']), + } + if group_id not in objs_per_group: + objs_per_group[group_id] = { + 'domain_name': row['domain_name'], + 'g_id': group_id, + 'g_name': row['name'], # this is group's name + 'g_cpu_allocated': c_info['cpu_allocated'], + 'g_cpu_used': c_info['cpu_used'], + 'g_mem_allocated': c_info['mem_allocated'], + 'g_mem_used': c_info['mem_used'], + 'g_shared_memory': c_info['shared_memory'], + 'g_disk_allocated': c_info['disk_allocated'], + 'g_disk_used': c_info['disk_used'], + 'g_io_read': c_info['io_read'], + 'g_io_write': c_info['io_write'], + 'g_device_type': copy.deepcopy(c_info['device_type']), + 'g_smp': c_info['smp'], + 'g_gpu_mem_allocated': c_info['gpu_mem_allocated'], + 'g_gpu_allocated': c_info['gpu_allocated'], + 'c_infos': [c_info], + } + else: + objs_per_group[group_id]['g_cpu_allocated'] += c_info['cpu_allocated'] + objs_per_group[group_id]['g_cpu_used'] += c_info['cpu_used'] + objs_per_group[group_id]['g_mem_allocated'] += c_info['mem_allocated'] + objs_per_group[group_id]['g_mem_used'] += c_info['mem_used'] + objs_per_group[group_id]['g_shared_memory'] += c_info['shared_memory'] + objs_per_group[group_id]['g_disk_allocated'] += c_info['disk_allocated'] + objs_per_group[group_id]['g_disk_used'] += c_info['disk_used'] + objs_per_group[group_id]['g_io_read'] += c_info['io_read'] + objs_per_group[group_id]['g_io_write'] += 
c_info['io_write'] + for device in c_info['device_type']: + if device not in objs_per_group[group_id]['g_device_type']: + g_dev_type = objs_per_group[group_id]['g_device_type'] + g_dev_type.append(device) + objs_per_group[group_id]['g_device_type'] = list(set(g_dev_type)) + objs_per_group[group_id]['g_smp'] += c_info['smp'] + objs_per_group[group_id]['g_gpu_mem_allocated'] += c_info['gpu_mem_allocated'] + objs_per_group[group_id]['g_gpu_allocated'] += c_info['gpu_allocated'] + objs_per_group[group_id]['c_infos'].append(c_info) + return list(objs_per_group.values()) + + +@server_status_required(READ_ALLOWED) +@superadmin_required +@check_api_params( + t.Dict({ + tx.MultiKey('group_ids'): t.List(t.String) | t.Null, + t.Key('month'): t.Regexp(r'^\d{6}', re.ASCII), + }), + loads=_json_loads) +async def usage_per_month(request: web.Request, params: Any) -> web.Response: + """ + Return usage statistics of terminated containers for a specified month. + The date/time comparison is done using the configured timezone. + + :param group_ids: If not None, query containers only in those groups. + :param month: The year-month to query usage statistics. ex) "202006" to query for Jun 2020 + """ + log.info('USAGE_PER_MONTH (g:[{}], month:{})', + ','.join(params['group_ids']), params['month']) + root_ctx: RootContext = request.app['_root.context'] + local_tz = root_ctx.shared_config['system']['timezone'] + try: + start_date = datetime.strptime(params['month'], '%Y%m').replace(tzinfo=local_tz) + end_date = start_date + relativedelta(months=+1) + except ValueError: + raise InvalidAPIParameters(extra_msg='Invalid date values') + resp = await get_container_stats_for_period(request, start_date, end_date, params['group_ids']) + log.debug('container list are retrieved for month {0}', params['month']) + return web.json_response(resp, status=200) + + +@server_status_required(READ_ALLOWED) +@superadmin_required +@check_api_params( + t.Dict({ + t.Key('group_id'): t.String | t.Null, + t.Key('start_date'): t.Regexp(r'^\d{8}$', re.ASCII), + t.Key('end_date'): t.Regexp(r'^\d{8}$', re.ASCII), + }), + loads=_json_loads) +async def usage_per_period(request: web.Request, params: Any) -> web.Response: + """ + Return usage statistics of terminated containers belonged to the given group for a specified + period in dates. + The date/time comparison is done using the configured timezone. + + :param group_id: If not None, query containers only in the group. + :param start_date str: "yyyymmdd" format. + :param end_date str: "yyyymmdd" format. 
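+    :raises InvalidAPIParameters: if the dates are malformed, end_date is not
+        later than start_date, or the requested period exceeds 100 days.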
+ """ + root_ctx: RootContext = request.app['_root.context'] + group_id = params['group_id'] + local_tz = root_ctx.shared_config['system']['timezone'] + try: + start_date = datetime.strptime(params['start_date'], '%Y%m%d').replace(tzinfo=local_tz) + end_date = datetime.strptime(params['end_date'], '%Y%m%d').replace(tzinfo=local_tz) + end_date = end_date + timedelta(days=1) # include sessions in end_date + if end_date - start_date > timedelta(days=100): + raise InvalidAPIParameters('Cannot query more than 100 days') + except ValueError: + raise InvalidAPIParameters(extra_msg='Invalid date values') + if end_date <= start_date: + raise InvalidAPIParameters(extra_msg='end_date must be later than start_date.') + log.info('USAGE_PER_MONTH (g:{}, start_date:{}, end_date:{})', + group_id, start_date, end_date) + group_ids = [group_id] if group_id is not None else None + resp = await get_container_stats_for_period(request, start_date, end_date, group_ids=group_ids) + log.debug('container list are retrieved from {0} to {1}', start_date, end_date) + return web.json_response(resp, status=200) + + +async def get_time_binned_monthly_stats(request: web.Request, user_uuid=None): + """ + Generate time-binned (15 min) stats for the last one month (2880 points). + The structure of the result would be: + + [ + # [ + # timestamp, num_sessions, + # cpu_allocated, mem_allocated, gpu_allocated, + # io_read, io_write, scratch_used, + # ] + [1562083808.657106, 1, 1.2, 1073741824, ...], + [1562084708.657106, 2, 4.0, 1073741824, ...], + ] + + Note that the timestamp is in UNIX-timestamp. + """ + # Get all or user kernels for the last month from DB. + time_window = 900 # 15 min + now = datetime.now(tzutc()) + start_date = now - timedelta(days=30) + root_ctx: RootContext = request.app['_root.context'] + async with root_ctx.db.begin_readonly() as conn: + query = ( + sa.select([kernels]) + .select_from(kernels) + .where( + (kernels.c.terminated_at >= start_date) & + (kernels.c.status.in_(RESOURCE_USAGE_KERNEL_STATUSES)), + ) + .order_by(sa.asc(kernels.c.created_at)) + ) + if user_uuid is not None: + query = query.where(kernels.c.user_uuid == user_uuid) + result = await conn.execute(query) + rows = result.fetchall() + + # Build time-series of time-binned stats. + rowcount = len(rows) + now_ts = now.timestamp() + start_date_ts = start_date.timestamp() + ts = start_date_ts + idx = 0 + tseries = [] + # Iterate over each time window. + while ts < now_ts: + # Initialize the time-binned stats. + num_sessions = 0 + cpu_allocated = 0 + mem_allocated = 0 + gpu_allocated = Decimal(0) + io_read_bytes = 0 + io_write_bytes = 0 + disk_used = 0 + # Accumulate stats for containers overlapping with this time window. + while idx < rowcount and \ + ts + time_window > rows[idx].created_at.timestamp() and \ + ts < rows[idx].terminated_at.timestamp(): + # Accumulate stats for overlapping containers in this time window. 
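+            # A kernel is pulled into this bin when its lifetime overlaps the
+            # window [ts, ts + time_window); since `rows` is ordered by
+            # created_at and `idx` never rewinds, each kernel is attributed to
+            # the first overlapping window only.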
+ row = rows[idx] + num_sessions += 1 + cpu_allocated += int(row.occupied_slots.get('cpu', 0)) + mem_allocated += int(row.occupied_slots.get('mem', 0)) + if 'cuda.devices' in row.occupied_slots: + gpu_allocated += int(row.occupied_slots['cuda.devices']) + if 'cuda.shares' in row.occupied_slots: + gpu_allocated += Decimal(row.occupied_slots['cuda.shares']) + raw_stat = await redis.execute(root_ctx.redis_stat, lambda r: r.get(str(row['id']))) + if raw_stat: + last_stat = msgpack.unpackb(raw_stat) + io_read_bytes += int(nmget(last_stat, 'io_read.current', 0)) + io_write_bytes += int(nmget(last_stat, 'io_write.current', 0)) + disk_used += int(nmget(last_stat, 'io_scratch_size/stats.max', 0, '/')) + idx += 1 + stat = { + "date": ts, + "num_sessions": { + "value": num_sessions, + "unit_hint": "count", + }, + "cpu_allocated": { + "value": cpu_allocated, + "unit_hint": "count", + }, + "mem_allocated": { + "value": mem_allocated, + "unit_hint": "bytes", + }, + "gpu_allocated": { + "value": float(gpu_allocated), + "unit_hint": "count", + }, + "io_read_bytes": { + "value": io_read_bytes, + "unit_hint": "bytes", + }, + "io_write_bytes": { + "value": io_write_bytes, + "unit_hint": "bytes", + }, + "disk_used": { + "value ": disk_used, + "unit_hint": "bytes", + }, + } + tseries.append(stat) + ts += time_window + return tseries + + +@server_status_required(READ_ALLOWED) +@auth_required +async def user_month_stats(request: web.Request) -> web.Response: + """ + Return time-binned (15 min) stats for terminated user sessions + over last 30 days. + """ + access_key = request['keypair']['access_key'] + user_uuid = request['user']['uuid'] + log.info('USER_LAST_MONTH_STATS (ak:{}, u:{})', access_key, user_uuid) + stats = await get_time_binned_monthly_stats(request, user_uuid=user_uuid) + return web.json_response(stats, status=200) + + +@server_status_required(READ_ALLOWED) +@superadmin_required +async def admin_month_stats(request: web.Request) -> web.Response: + """ + Return time-binned (15 min) stats for all terminated sessions + over last 30 days. + """ + log.info('ADMIN_LAST_MONTH_STATS ()') + stats = await get_time_binned_monthly_stats(request, user_uuid=None) + return web.json_response(stats, status=200) + + +async def get_watcher_info(request: web.Request, agent_id: str) -> dict: + """ + Get watcher information. 
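+
+    The watcher address is composed from the agent's IP and watcher port
+    registered in etcd; the port falls back to 6099 when not registered.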
+ + :return addr: address of agent watcher (eg: http://127.0.0.1:6009) + :return token: agent watcher token ("insecure" if not set in config server) + """ + root_ctx: RootContext = request.app['_root.context'] + token = root_ctx.shared_config['watcher']['token'] + if token is None: + token = 'insecure' + agent_ip = await root_ctx.shared_config.etcd.get(f'nodes/agents/{agent_id}/ip') + raw_watcher_port = await root_ctx.shared_config.etcd.get( + f'nodes/agents/{agent_id}/watcher_port', + ) + watcher_port = 6099 if raw_watcher_port is None else int(raw_watcher_port) + # TODO: watcher scheme is assumed to be http + addr = yarl.URL(f'http://{agent_ip}:{watcher_port}') + return { + 'addr': addr, + 'token': token, + } + + +@server_status_required(READ_ALLOWED) +@superadmin_required +@check_api_params( + t.Dict({ + tx.AliasedKey(['agent_id', 'agent']): t.String, + })) +async def get_watcher_status(request: web.Request, params: Any) -> web.Response: + log.info('GET_WATCHER_STATUS (ag:{})', params['agent_id']) + watcher_info = await get_watcher_info(request, params['agent_id']) + connector = aiohttp.TCPConnector() + async with aiohttp.ClientSession(connector=connector) as sess: + with _timeout(5.0): + headers = {'X-BackendAI-Watcher-Token': watcher_info['token']} + async with sess.get(watcher_info['addr'], headers=headers) as resp: + if resp.status == 200: + data = await resp.json() + return web.json_response(data, status=resp.status) + else: + data = await resp.text() + return web.Response(text=data, status=resp.status) + + +@server_status_required(READ_ALLOWED) +@superadmin_required +@check_api_params( + t.Dict({ + tx.AliasedKey(['agent_id', 'agent']): t.String, + })) +async def watcher_agent_start(request: web.Request, params: Any) -> web.Response: + log.info('WATCHER_AGENT_START (ag:{})', params['agent_id']) + watcher_info = await get_watcher_info(request, params['agent_id']) + connector = aiohttp.TCPConnector() + async with aiohttp.ClientSession(connector=connector) as sess: + with _timeout(20.0): + watcher_url = watcher_info['addr'] / 'agent/start' + headers = {'X-BackendAI-Watcher-Token': watcher_info['token']} + async with sess.post(watcher_url, headers=headers) as resp: + if resp.status == 200: + data = await resp.json() + return web.json_response(data, status=resp.status) + else: + data = await resp.text() + return web.Response(text=data, status=resp.status) + + +@server_status_required(READ_ALLOWED) +@superadmin_required +@check_api_params( + t.Dict({ + tx.AliasedKey(['agent_id', 'agent']): t.String, + })) +async def watcher_agent_stop(request: web.Request, params: Any) -> web.Response: + log.info('WATCHER_AGENT_STOP (ag:{})', params['agent_id']) + watcher_info = await get_watcher_info(request, params['agent_id']) + connector = aiohttp.TCPConnector() + async with aiohttp.ClientSession(connector=connector) as sess: + with _timeout(20.0): + watcher_url = watcher_info['addr'] / 'agent/stop' + headers = {'X-BackendAI-Watcher-Token': watcher_info['token']} + async with sess.post(watcher_url, headers=headers) as resp: + if resp.status == 200: + data = await resp.json() + return web.json_response(data, status=resp.status) + else: + data = await resp.text() + return web.Response(text=data, status=resp.status) + + +@server_status_required(READ_ALLOWED) +@superadmin_required +@check_api_params( + t.Dict({ + tx.AliasedKey(['agent_id', 'agent']): t.String, + })) +async def watcher_agent_restart(request: web.Request, params: Any) -> web.Response: + log.info('WATCHER_AGENT_RESTART (ag:{})', 
params['agent_id']) + watcher_info = await get_watcher_info(request, params['agent_id']) + connector = aiohttp.TCPConnector() + async with aiohttp.ClientSession(connector=connector) as sess: + with _timeout(20.0): + watcher_url = watcher_info['addr'] / 'agent/restart' + headers = {'X-BackendAI-Watcher-Token': watcher_info['token']} + async with sess.post(watcher_url, headers=headers) as resp: + if resp.status == 200: + data = await resp.json() + return web.json_response(data, status=resp.status) + else: + data = await resp.text() + return web.Response(text=data, status=resp.status) + + +def create_app(default_cors_options: CORSOptions) -> Tuple[web.Application, Iterable[WebMiddleware]]: + app = web.Application() + app['api_versions'] = (4,) + cors = aiohttp_cors.setup(app, defaults=default_cors_options) + add_route = app.router.add_route + cors.add(add_route('GET', '/presets', list_presets)) + cors.add(add_route('POST', '/check-presets', check_presets)) + cors.add(add_route('POST', '/recalculate-usage', recalculate_usage)) + cors.add(add_route('GET', '/usage/month', usage_per_month)) + cors.add(add_route('GET', '/usage/period', usage_per_period)) + cors.add(add_route('GET', '/stats/user/month', user_month_stats)) + cors.add(add_route('GET', '/stats/admin/month', admin_month_stats)) + cors.add(add_route('GET', '/watcher', get_watcher_status)) + cors.add(add_route('POST', '/watcher/agent/start', watcher_agent_start)) + cors.add(add_route('POST', '/watcher/agent/stop', watcher_agent_stop)) + cors.add(add_route('POST', '/watcher/agent/restart', watcher_agent_restart)) + return app, [] diff --git a/src/ai/backend/manager/api/scaling_group.py b/src/ai/backend/manager/api/scaling_group.py new file mode 100644 index 0000000000..16d4ccffb7 --- /dev/null +++ b/src/ai/backend/manager/api/scaling_group.py @@ -0,0 +1,125 @@ +import logging +from typing import ( + Any, + Iterable, + TYPE_CHECKING, + Tuple, +) + +from aiohttp import web +import aiohttp +import aiohttp_cors +import aiotools +from dataclasses import dataclass, field +import trafaret as t + +from ai.backend.common import validators as tx +from ai.backend.common.logging import BraceStyleAdapter + +from ai.backend.manager.api.exceptions import ObjectNotFound + +from ai.backend.manager.models.utils import ExtendedAsyncSAEngine + +from ..models import ( + query_allowed_sgroups, +) +from .auth import auth_required +from .manager import ( + READ_ALLOWED, + server_status_required) +from .types import CORSOptions, WebMiddleware +from .utils import check_api_params + +if TYPE_CHECKING: + from .context import RootContext + +log = BraceStyleAdapter(logging.getLogger(__name__)) + + +@dataclass(unsafe_hash=True) +class WSProxyVersionQueryParams: + db_ctx: ExtendedAsyncSAEngine = field(hash=False) + + +@aiotools.lru_cache(expire_after=30) # expire after 30 seconds +async def query_wsproxy_version( + wsproxy_addr: str, +) -> str: + async with aiohttp.ClientSession() as session: + async with session.get(wsproxy_addr + '/status') as resp: + version_json = await resp.json() + return version_json['api_version'] + + +@auth_required +@server_status_required(READ_ALLOWED) +@check_api_params( + t.Dict({ + tx.AliasedKey(['group', 'group_id', 'group_name']): tx.UUID | t.String, + }), +) +async def list_available_sgroups(request: web.Request, params: Any) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + access_key = request['keypair']['access_key'] + domain_name = request['user']['domain_name'] + group_id_or_name = params['group'] + 
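A note on query_wsproxy_version defined above: the aiotools.lru_cache(expire_after=30) decorator memoizes the coroutine per wsproxy address, so the manager polls each proxy's /status endpoint at most once every 30 seconds rather than on every API call. The following stdlib-only sketch shows that kind of time-bounded memoization; the decorator name and the lock-free cache are illustrative assumptions, not the aiotools implementation.

    import functools
    import time

    def timed_cache(expire_after: float):
        """Memoize an async function's result per argument tuple for a limited time."""
        def decorator(fn):
            cache: dict[tuple, tuple[float, object]] = {}

            @functools.wraps(fn)
            async def wrapper(*args):
                now = time.monotonic()
                entry = cache.get(args)
                if entry is not None and now - entry[0] < expire_after:
                    return entry[1]
                value = await fn(*args)  # no locking: concurrent misses may each call fn
                cache[args] = (now, value)
                return value

            return wrapper
        return decorator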
log.info('SGROUPS.LIST(ak:{}, g:{}, d:{})', access_key, group_id_or_name, domain_name) + async with root_ctx.db.begin() as conn: + sgroups = await query_allowed_sgroups( + conn, domain_name, group_id_or_name, access_key) + return web.json_response({ + 'scaling_groups': [ + {'name': sgroup['name']} + for sgroup in sgroups + ], + }, status=200) + + +@auth_required +@server_status_required(READ_ALLOWED) +@check_api_params(t.Dict({ + tx.AliasedKey(['group', 'group_id', 'group_name'], default=None): t.Null | tx.UUID | t.String, +})) +async def get_wsproxy_version(request: web.Request, params: Any) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + scaling_group_name = request.match_info['scaling_group'] + access_key = request['keypair']['access_key'] + domain_name = request['user']['domain_name'] + group_id_or_name = params['group'] + log.info('SGROUPS.LIST(ak:{}, g:{}, d:{})', access_key, group_id_or_name, domain_name) + async with root_ctx.db.begin_readonly() as conn: + sgroups = await query_allowed_sgroups( + conn, domain_name, group_id_or_name or '', access_key) + for sgroup in sgroups: + if sgroup['name'] == scaling_group_name: + wsproxy_addr = sgroup['wsproxy_addr'] + if not wsproxy_addr: + wsproxy_version = 'v1' + else: + wsproxy_version = await query_wsproxy_version(wsproxy_addr) + return web.json_response({ + 'wsproxy_version': wsproxy_version, + }) + else: + raise ObjectNotFound(object_name='scaling group') + + +async def init(app: web.Application) -> None: + pass + + +async def shutdown(app: web.Application) -> None: + pass + + +def create_app(default_cors_options: CORSOptions) -> Tuple[web.Application, Iterable[WebMiddleware]]: + app = web.Application() + app['prefix'] = 'scaling-groups' + app['api_versions'] = (2, 3, 4) + app.on_startup.append(init) + app.on_shutdown.append(shutdown) + cors = aiohttp_cors.setup(app, defaults=default_cors_options) + root_resource = cors.add(app.router.add_resource(r'')) + cors.add(root_resource.add_route('GET', list_available_sgroups)) + cors.add(app.router.add_route('GET', '/{scaling_group}/wsproxy-version', get_wsproxy_version)) + return app, [] diff --git a/src/ai/backend/manager/api/session.py b/src/ai/backend/manager/api/session.py new file mode 100644 index 0000000000..3cfe34d52b --- /dev/null +++ b/src/ai/backend/manager/api/session.py @@ -0,0 +1,2278 @@ +""" +REST-style session management APIs. 
+""" +from __future__ import annotations + +import asyncio +import base64 +import functools +import json +import logging +import re +import secrets +import time +import uuid +import yarl +from typing import ( + Any, + Dict, + Iterable, + List, + Mapping, + MutableMapping, + Optional, + Set, + Tuple, + Union, + TYPE_CHECKING, + cast, +) +from decimal import Decimal +from datetime import datetime, timedelta +from io import BytesIO +from pathlib import PurePosixPath +from urllib.parse import urlparse + +import aiohttp +import aiohttp_cors +import aioredis +import aiotools +import attr +import multidict +import sqlalchemy as sa +import trafaret as t +from aiohttp import web, hdrs +from async_timeout import timeout +from dateutil.parser import isoparse +from dateutil.tz import tzutc +from sqlalchemy.sql.expression import true, null + +from ai.backend.manager.models.image import ImageRow + +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection + +from ai.backend.common import redis, validators as tx +from ai.backend.common.docker import ImageRef +from ai.backend.common.exception import ( + UnknownImageReference, + AliasResolutionFailed, +) +from ai.backend.common.events import ( + AgentHeartbeatEvent, + AgentStartedEvent, + AgentTerminatedEvent, + DoSyncKernelLogsEvent, + DoSyncKernelStatsEvent, + DoTerminateSessionEvent, + KernelCancelledEvent, + KernelCreatingEvent, + KernelPreparingEvent, + KernelPullingEvent, + KernelStartedEvent, + KernelTerminatedEvent, + KernelTerminatingEvent, + SessionEnqueuedEvent, + SessionScheduledEvent, + SessionPreparingEvent, + SessionCancelledEvent, + SessionFailureEvent, + SessionStartedEvent, + SessionSuccessEvent, + SessionTerminatedEvent, +) +from ai.backend.common.logging import BraceStyleAdapter +from ai.backend.common.utils import cancel_tasks, str_to_timedelta +from ai.backend.common.types import ( + AccessKey, + AgentId, + KernelId, + ClusterMode, + KernelEnqueueingConfig, + SessionTypes, + check_typed_dict, +) +from ai.backend.common.plugin.monitor import GAUGE + +from ..config import DEFAULT_CHUNK_SIZE +from ..defs import DEFAULT_IMAGE_ARCH, DEFAULT_ROLE, REDIS_STREAM_DB +from ..types import UserScope +from ..models import ( + domains, + association_groups_users as agus, groups, + keypairs, kernels, AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES, + query_bootstrap_script, + keypair_resource_policies, + scaling_groups, + users, UserRole, + vfolders, + AgentStatus, KernelStatus, + query_accessible_vfolders, + session_templates, + verify_vfolder_name, + DEAD_KERNEL_STATUSES, +) +from ..models.kernel import match_session_ids +from ..models.utils import execute_with_retry +from .exceptions import ( + AppNotFound, + InvalidAPIParameters, + ObjectNotFound, + ImageNotFound, + InsufficientPrivilege, + ServiceUnavailable, + SessionNotFound, + SessionAlreadyExists, + TooManySessionsMatched, + BackendError, + InternalServerError, + TaskTemplateNotFound, + StorageProxyError, + UnknownImageReferenceError, +) +from .auth import auth_required +from .types import CORSOptions, WebMiddleware +from .utils import ( + catch_unexpected, check_api_params, get_access_key_scopes, undefined, +) +from .manager import ALL_ALLOWED, READ_ALLOWED, server_status_required +if TYPE_CHECKING: + from .context import RootContext + +log = BraceStyleAdapter(logging.getLogger(__name__)) + +_json_loads = functools.partial(json.loads, parse_float=Decimal) + + +class UndefChecker(t.Trafaret): + def check_and_return(self, value: Any) -> object: + if value == undefined: 
+ return value + else: + self._failure('Invalid Undef format', value=value) + return None + + +creation_config_v1 = t.Dict({ + t.Key('mounts', default=None): t.Null | t.List(t.String), + t.Key('environ', default=None): t.Null | t.Mapping(t.String, t.String), + t.Key('clusterSize', default=None): t.Null | t.Int[1:], +}) +creation_config_v2 = t.Dict({ + t.Key('mounts', default=None): t.Null | t.List(t.String), + t.Key('environ', default=None): t.Null | t.Mapping(t.String, t.String), + t.Key('clusterSize', default=None): t.Null | t.Int[1:], + t.Key('instanceMemory', default=None): t.Null | tx.BinarySize, + t.Key('instanceCores', default=None): t.Null | t.Int, + t.Key('instanceGPUs', default=None): t.Null | t.Float, + t.Key('instanceTPUs', default=None): t.Null | t.Int, +}) +creation_config_v3 = t.Dict({ + t.Key('mounts', default=None): t.Null | t.List(t.String), + t.Key('environ', default=None): t.Null | t.Mapping(t.String, t.String), + tx.AliasedKey(['cluster_size', 'clusterSize'], default=None): + t.Null | t.Int[1:], + tx.AliasedKey(['scaling_group', 'scalingGroup'], default=None): + t.Null | t.String, + t.Key('resources', default=None): t.Null | t.Mapping(t.String, t.Any), + tx.AliasedKey(['resource_opts', 'resourceOpts'], default=None): + t.Null | t.Mapping(t.String, t.Any), +}) +creation_config_v3_template = t.Dict({ + t.Key('mounts', default=undefined): UndefChecker | t.Null | t.List(t.String), + t.Key('environ', default=undefined): UndefChecker | t.Null | t.Mapping(t.String, t.String), + tx.AliasedKey(['cluster_size', 'clusterSize'], default=undefined): + UndefChecker | t.Null | t.Int[1:], + tx.AliasedKey(['scaling_group', 'scalingGroup'], default=undefined): + UndefChecker | t.Null | t.String, + t.Key('resources', default=undefined): UndefChecker | t.Null | t.Mapping(t.String, t.Any), + tx.AliasedKey(['resource_opts', 'resourceOpts'], default=undefined): + UndefChecker | t.Null | t.Mapping(t.String, t.Any), +}) +creation_config_v4 = t.Dict({ + t.Key('mounts', default=None): t.Null | t.List(t.String), + tx.AliasedKey(['mount_map', 'mountMap'], default=None): t.Null | t.Mapping(t.String, t.String), + t.Key('environ', default=None): t.Null | t.Mapping(t.String, t.String), + tx.AliasedKey(['cluster_size', 'clusterSize'], default=None): t.Null | t.Int[1:], + tx.AliasedKey(['scaling_group', 'scalingGroup'], default=None): t.Null | t.String, + t.Key('resources', default=None): t.Null | t.Mapping(t.String, t.Any), + tx.AliasedKey(['resource_opts', 'resourceOpts'], default=None): t.Null | t.Mapping(t.String, t.Any), + tx.AliasedKey(['preopen_ports', 'preopenPorts'], default=None): t.Null | t.List(t.Int[1024:65535]), +}) +creation_config_v4_template = t.Dict({ + t.Key('mounts', default=undefined): UndefChecker | t.Null | t.List(t.String), + tx.AliasedKey(['mount_map', 'mountMap'], default=undefined): + UndefChecker | t.Null | t.Mapping(t.String, t.String), + t.Key('environ', default=undefined): UndefChecker | t.Null | t.Mapping(t.String, t.String), + tx.AliasedKey(['cluster_size', 'clusterSize'], default=undefined): + UndefChecker | t.Null | t.Int[1:], + tx.AliasedKey(['scaling_group', 'scalingGroup'], default=undefined): + UndefChecker | t.Null | t.String, + t.Key('resources', default=undefined): UndefChecker | t.Null | t.Mapping(t.String, t.Any), + tx.AliasedKey(['resource_opts', 'resourceOpts'], default=undefined): + UndefChecker | t.Null | t.Mapping(t.String, t.Any), +}) +creation_config_v5 = t.Dict({ + t.Key('mounts', default=None): t.Null | t.List(t.String), + tx.AliasedKey(['mount_map', 
'mountMap'], default=None): + t.Null | t.Mapping(t.String, t.String), + t.Key('environ', default=None): t.Null | t.Mapping(t.String, t.String), + # cluster_size is moved to the root-level parameters + tx.AliasedKey(['scaling_group', 'scalingGroup'], default=None): t.Null | t.String, + t.Key('resources', default=None): t.Null | t.Mapping(t.String, t.Any), + tx.AliasedKey(['resource_opts', 'resourceOpts'], default=None): t.Null | t.Mapping(t.String, t.Any), + tx.AliasedKey(['preopen_ports', 'preopenPorts'], default=None): t.Null | t.List(t.Int[1024:65535]), + tx.AliasedKey(['agent_list', 'agentList'], default=None): t.Null | t.List(t.String), +}) +creation_config_v5_template = t.Dict({ + t.Key('mounts', default=undefined): UndefChecker | t.Null | t.List(t.String), + tx.AliasedKey(['mount_map', 'mountMap'], default=undefined): + UndefChecker | t.Null | t.Mapping(t.String, t.String), + t.Key('environ', default=undefined): UndefChecker | t.Null | t.Mapping(t.String, t.String), + # cluster_size is moved to the root-level parameters + tx.AliasedKey(['scaling_group', 'scalingGroup'], default=undefined): + UndefChecker | t.Null | t.String, + t.Key('resources', default=undefined): UndefChecker | t.Null | t.Mapping(t.String, t.Any), + tx.AliasedKey(['resource_opts', 'resourceOpts'], default=undefined): + UndefChecker | t.Null | t.Mapping(t.String, t.Any), +}) + + +overwritten_param_check = t.Dict({ + t.Key('template_id'): tx.UUID, + t.Key('session_name'): t.Regexp(r'^(?=.{4,64}$)\w[\w.-]*\w$', re.ASCII), + t.Key('image', default=None): t.Null | t.String, + tx.AliasedKey(['session_type', 'sess_type']): tx.Enum(SessionTypes), + t.Key('group', default=None): t.Null | t.String, + t.Key('domain', default=None): t.Null | t.String, + t.Key('config', default=None): t.Null | t.Mapping(t.String, t.Any), + t.Key('tag', default=None): t.Null | t.String, + t.Key('enqueue_only', default=False): t.ToBool, + t.Key('max_wait_seconds', default=0): t.Int[0:], + t.Key('reuse', default=True): t.ToBool, + t.Key('startup_command', default=None): t.Null | t.String, + t.Key('bootstrap_script', default=None): t.Null | t.String, + t.Key('owner_access_key', default=None): t.Null | t.String, + tx.AliasedKey(['scaling_group', 'scalingGroup'], default=None): t.Null | t.String, + tx.AliasedKey(['cluster_size', 'clusterSize'], default=None): t.Null | t.Int[1:], + tx.AliasedKey(['cluster_mode', 'clusterMode'], default='single-node'): tx.Enum(ClusterMode), + tx.AliasedKey(['starts_at', 'startsAt'], default=None): t.Null | t.String, +}).allow_extra('*') + + +def sub(d, old, new): + for k, v in d.items(): + if isinstance(v, Mapping) or isinstance(v, dict): + d[k] = sub(v, old, new) + elif d[k] == old: + d[k] = new + return d + + +def drop(d, dropval): + newd = {} + for k, v in d.items(): + if isinstance(v, Mapping) or isinstance(v, dict): + newval = drop(v, dropval) + if len(newval.keys()) > 0: # exclude empty dict always + newd[k] = newval + elif v != dropval: + newd[k] = v + return newd + + +async def _query_userinfo( + request: web.Request, + params: Any, + conn: SAConnection, +) -> Tuple[uuid.UUID, uuid.UUID, dict]: + if params['domain'] is None: + params['domain'] = request['user']['domain_name'] + scopes_param = { + 'owner_access_key': ( + None if params['owner_access_key'] is undefined + else params['owner_access_key'] + ), + } + requester_access_key, owner_access_key = await get_access_key_scopes(request, scopes_param) + requester_uuid = request['user']['uuid'] + + owner_uuid = None + group_id = None + resource_policy = None + + 
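+    # Resolve whose resources this session will consume: get_access_key_scopes()
+    # returns distinct requester/owner keys only when a privileged caller passed a
+    # separate owner_access_key parameter; otherwise both refer to the requester.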
if requester_access_key != owner_access_key: + # Admin or superadmin is creating sessions for another user. + # The check for admin privileges is already done in get_access_key_scope(). + query = ( + sa.select([keypairs.c.user, keypairs.c.resource_policy, + users.c.role, users.c.domain_name]) + .select_from(sa.join(keypairs, users, keypairs.c.user == users.c.uuid)) + .where(keypairs.c.access_key == owner_access_key) + ) + result = await conn.execute(query) + row = result.first() + owner_domain = row['domain_name'] + owner_uuid = row['user'] + owner_role = row['role'] + query = ( + sa.select([keypair_resource_policies]) + .select_from(keypair_resource_policies) + .where(keypair_resource_policies.c.name == row['resource_policy']) + ) + result = await conn.execute(query) + resource_policy = result.first() + else: + # Normal case when the user is creating her/his own session. + owner_domain = request['user']['domain_name'] + owner_uuid = requester_uuid + owner_role = UserRole.USER + resource_policy = request['keypair']['resource_policy'] + + query = ( + sa.select([domains.c.name]) + .select_from(domains) + .where( + (domains.c.name == owner_domain) & + (domains.c.is_active), + ) + ) + qresult = await conn.execute(query) + domain_name = qresult.scalar() + if domain_name is None: + raise InvalidAPIParameters('Invalid domain') + + if owner_role == UserRole.SUPERADMIN: + # superadmin can spawn container in any designated domain/group. + query = ( + sa.select([groups.c.id]) + .select_from(groups) + .where( + (groups.c.domain_name == params['domain']) & + (groups.c.name == params['group']) & + (groups.c.is_active), + )) + qresult = await conn.execute(query) + group_id = qresult.scalar() + elif owner_role == UserRole.ADMIN: + # domain-admin can spawn container in any group in the same domain. + if params['domain'] != owner_domain: + raise InvalidAPIParameters("You can only set the domain to the owner's domain.") + query = ( + sa.select([groups.c.id]) + .select_from(groups) + .where( + (groups.c.domain_name == owner_domain) & + (groups.c.name == params['group']) & + (groups.c.is_active), + )) + qresult = await conn.execute(query) + group_id = qresult.scalar() + else: + # normal users can spawn containers in their group and domain. 
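+            # Membership is enforced through the association_groups_users (agus)
+            # join below, so the target group must be one the owner belongs to.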
+ if params['domain'] != owner_domain: + raise InvalidAPIParameters("You can only set the domain to your domain.") + query = ( + sa.select([agus.c.group_id]) + .select_from(agus.join(groups, agus.c.group_id == groups.c.id)) + .where( + (agus.c.user_id == owner_uuid) & + (groups.c.domain_name == owner_domain) & + (groups.c.name == params['group']) & + (groups.c.is_active), + )) + qresult = await conn.execute(query) + group_id = qresult.scalar() + if group_id is None: + raise InvalidAPIParameters('Invalid group') + + return owner_uuid, group_id, resource_policy + + +async def _create(request: web.Request, params: dict[str, Any]) -> web.Response: + if params['domain'] is None: + params['domain'] = request['user']['domain_name'] + scopes_param = { + 'owner_access_key': ( + None if params['owner_access_key'] is undefined + else params['owner_access_key'] + ), + } + requester_access_key, owner_access_key = await get_access_key_scopes(request, scopes_param) + log.info('GET_OR_CREATE (ak:{0}/{1}, img:{2}, s:{3})', + requester_access_key, owner_access_key if owner_access_key != requester_access_key else '*', + params['image'], params['session_name']) + + root_ctx: RootContext = request.app['_root.context'] + app_ctx: PrivateContext = request.app['session.context'] + + resp: MutableMapping[str, Any] = {} + current_task = asyncio.current_task() + assert current_task is not None + + # Check work directory and reserved name directory. + mount_map = params['config'].get('mount_map') + if mount_map is not None: + original_folders = mount_map.keys() + alias_folders = mount_map.values() + if len(alias_folders) != len(set(alias_folders)): + raise InvalidAPIParameters('Duplicate alias folder name exists.') + + alias_name: str + for alias_name in alias_folders: + if alias_name is None: + continue + if alias_name.startswith("/home/work/"): + alias_name = alias_name.replace('/home/work/', '') + if alias_name == '': + raise InvalidAPIParameters('Alias name cannot be empty.') + if not verify_vfolder_name(alias_name): + raise InvalidAPIParameters(str(alias_name) + ' is reserved for internal path.') + if alias_name in original_folders: + raise InvalidAPIParameters('Alias name cannot be set to an existing folder name: ' + + str(alias_name)) + + # Resolve the image reference. + try: + async with root_ctx.db.begin_readonly_session() as session: + image_row = await ImageRow.resolve(session, [ + ImageRef(params['image'], ['*'], params['architecture']), + params['image'], + ]) + requested_image_ref = image_row.image_ref + async with root_ctx.db.begin_readonly() as conn: + query = ( + sa.select([domains.c.allowed_docker_registries]) + .select_from(domains) + .where(domains.c.name == params['domain']) + ) + allowed_registries = await conn.scalar(query) + if requested_image_ref.registry not in allowed_registries: + raise AliasResolutionFailed + except AliasResolutionFailed: + raise ImageNotFound('unknown alias or disallowed registry') + + # Check existing (owner_access_key, session_name) instance + try: + # NOTE: We can reuse the session IDs of TERMINATED sessions only. + # NOTE: Reusing a session in the PENDING status returns an empty value in service_ports. 
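+        # If no live session matches, get_session() raises SessionNotFound and we
+        # fall through to the creation path; a match with a different image or with
+        # reuse disabled is rejected via SessionAlreadyExists below.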
+ kern = await root_ctx.registry.get_session(params['session_name'], owner_access_key) + running_image_ref = ImageRef(kern['image'], [kern['registry']], kern['architecture']) + if running_image_ref != requested_image_ref: + # The image must be same if get_or_create() called multiple times + # against an existing (non-terminated) session + raise SessionAlreadyExists(extra_data={'existingSessionId': str(kern['id'])}) + if not params['reuse']: + # Respond as error since the client did not request to reuse, + # but provide the overlapping session ID for later use. + raise SessionAlreadyExists(extra_data={'existingSessionId': str(kern['id'])}) + # Respond as success with the reused session's information. + return web.json_response({ + 'sessionId': str(kern['id']), + 'sessionName': str(kern['session_name']), + 'status': kern['status'].name, + 'service_ports': kern['service_ports'], + 'created': False, + }, status=200) + except SessionNotFound: + # It's time to create a new session. + pass + + if params['session_type'] == SessionTypes.BATCH and not params['startup_command']: + raise InvalidAPIParameters('Batch sessions must have a non-empty startup command.') + if params['session_type'] != SessionTypes.BATCH and params['starts_at']: + raise InvalidAPIParameters('Parameter starts_at should be used only for batch sessions') + starts_at: Union[datetime, None] = None + if params['starts_at']: + try: + starts_at = isoparse(params['starts_at']) + except ValueError: + _td = str_to_timedelta(params['starts_at']) + starts_at = datetime.now(tzutc()) + _td + + if params['cluster_size'] > 1: + log.debug(" -> cluster_mode:{} (replicate)", params['cluster_mode']) + + if params['dependencies'] is None: + params['dependencies'] = [] + + session_creation_id = secrets.token_urlsafe(16) + start_event = asyncio.Event() + kernel_id: Optional[KernelId] = None + session_creation_tracker = app_ctx.session_creation_tracker + session_creation_tracker[session_creation_id] = start_event + + async with root_ctx.db.begin_readonly() as conn: + owner_uuid, group_id, resource_policy = await _query_userinfo(request, params, conn) + + # Use keypair bootstrap_script if it is not delivered as a parameter + # (only for INTERACTIVE sessions). 
+ if params['session_type'] == SessionTypes.INTERACTIVE and not params['bootstrap_script']: + script, _ = await query_bootstrap_script(conn, owner_access_key) + params['bootstrap_script'] = script + + try: + kernel_id = await asyncio.shield(app_ctx.database_ptask_group.create_task( + root_ctx.registry.enqueue_session( + session_creation_id, + params['session_name'], owner_access_key, + [{ + 'image_ref': requested_image_ref, + 'cluster_role': DEFAULT_ROLE, + 'cluster_idx': 1, + 'cluster_hostname': f"{DEFAULT_ROLE}1", + 'creation_config': params['config'], + 'bootstrap_script': params['bootstrap_script'], + 'startup_command': params['startup_command'], + }], + params['config']['scaling_group'], + params['session_type'], + resource_policy, + user_scope=UserScope( + domain_name=params['domain'], # type: ignore # params always have it + group_id=group_id, + user_uuid=owner_uuid, + user_role=request['user']['role'], + ), + cluster_mode=params['cluster_mode'], + cluster_size=params['cluster_size'], + session_tag=params['tag'], + starts_at=starts_at, + agent_list=params['config']['agent_list'], + dependency_sessions=params['dependencies'], + callback_url=params['callback_url'], + )), + ) + resp['sessionId'] = str(kernel_id) # changed since API v5 + resp['sessionName'] = str(params['session_name']) + resp['status'] = 'PENDING' + resp['servicePorts'] = [] + resp['created'] = True + + if not params['enqueue_only']: + app_ctx.pending_waits.add(current_task) + max_wait = params['max_wait_seconds'] + try: + if max_wait > 0: + with timeout(max_wait): + await start_event.wait() + else: + await start_event.wait() + except asyncio.TimeoutError: + resp['status'] = 'TIMEOUT' + else: + await asyncio.sleep(0.5) + async with root_ctx.db.begin_readonly() as conn: + query = ( + sa.select([ + kernels.c.status, + kernels.c.service_ports, + ]) + .select_from(kernels) + .where(kernels.c.id == kernel_id) + ) + result = await conn.execute(query) + row = result.first() + if row['status'] == KernelStatus.RUNNING: + resp['status'] = 'RUNNING' + for item in row['service_ports']: + response_dict = { + 'name': item['name'], + 'protocol': item['protocol'], + 'ports': item['container_ports'], + } + if 'url_template' in item.keys(): + response_dict['url_template'] = item['url_template'] + if 'allowed_arguments' in item.keys(): + response_dict['allowed_arguments'] = item['allowed_arguments'] + if 'allowed_envs' in item.keys(): + response_dict['allowed_envs'] = item['allowed_envs'] + resp['servicePorts'].append(response_dict) + else: + resp['status'] = row['status'].name + except asyncio.CancelledError: + raise + except BackendError: + log.exception('GET_OR_CREATE: exception') + raise + except UnknownImageReference: + raise UnknownImageReferenceError(f"Unknown image reference: {params['image']}") + except Exception: + await root_ctx.error_monitor.capture_exception(context={'user': owner_uuid}) + log.exception('GET_OR_CREATE: unexpected error!') + raise InternalServerError + finally: + app_ctx.pending_waits.discard(current_task) + del session_creation_tracker[session_creation_id] + return web.json_response(resp, status=201) + + +@server_status_required(ALL_ALLOWED) +@auth_required +@check_api_params(t.Dict( + { + tx.AliasedKey(['template_id', 'templateId']): t.Null | tx.UUID, + tx.AliasedKey(['name', 'clientSessionToken'], default=undefined) >> 'session_name': + UndefChecker | t.Regexp(r'^(?=.{4,64}$)\w[\w.-]*\w$', re.ASCII), + tx.AliasedKey(['image', 'lang'], default=undefined): UndefChecker | t.Null | t.String, + 
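+        # Keys left as `undefined` here are dropped later so that the values from
+        # the referenced session template take effect (see create_from_template below).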
tx.AliasedKey(['arch', 'architecture'], default=DEFAULT_IMAGE_ARCH) >> 'architecture': t.String, + tx.AliasedKey(['type', 'sessionType'], default='interactive') >> 'session_type': + tx.Enum(SessionTypes), + tx.AliasedKey(['group', 'groupName', 'group_name'], default=undefined): + UndefChecker | t.Null | t.String, + tx.AliasedKey(['domain', 'domainName', 'domain_name'], default=undefined): + UndefChecker | t.Null | t.String, + tx.AliasedKey(['cluster_size', 'clusterSize'], default=1): + t.ToInt[1:], # new in APIv6 + tx.AliasedKey(['cluster_mode', 'clusterMode'], default='single-node'): + tx.Enum(ClusterMode), # new in APIv6 + t.Key('config', default=dict): t.Mapping(t.String, t.Any), + t.Key('tag', default=undefined): UndefChecker | t.Null | t.String, + t.Key('enqueueOnly', default=False) >> 'enqueue_only': t.ToBool, + t.Key('maxWaitSeconds', default=0) >> 'max_wait_seconds': t.Int[0:], + tx.AliasedKey(['starts_at', 'startsAt'], default=None): t.Null | t.String, + t.Key('reuseIfExists', default=True) >> 'reuse': t.ToBool, + t.Key('startupCommand', default=None) >> 'startup_command': + UndefChecker | t.Null | t.String, + tx.AliasedKey(['bootstrap_script', 'bootstrapScript'], default=undefined): + UndefChecker | t.Null | t.String, + t.Key('dependencies', default=None): + UndefChecker | t.Null | t.List(tx.UUID) | t.List(t.String), + tx.AliasedKey(['callback_url', 'callbackUrl', 'callbackURL'], default=None): + UndefChecker | t.Null | tx.URL, + t.Key('owner_access_key', default=undefined): UndefChecker | t.Null | t.String, + }, +), loads=_json_loads) +async def create_from_template(request: web.Request, params: dict[str, Any]) -> web.Response: + # TODO: we need to refactor session_template model to load the template configs + # by one batch. Currently, we need to set every template configs one by one. 
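+    # Merge order (illustrative values): request parameters left as `undefined` are
+    # removed by drop() below so the template's values survive, while explicitly
+    # given parameters override the template, e.g.
+    #   template: {'image': 'python:3.9', 'tag': 'demo'}
+    #   request:  {'image': undefined,    'tag': 'override'}
+    #   merged:   {'image': 'python:3.9', 'tag': 'override'}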
+ root_ctx: RootContext = request.app['_root.context'] + + if params['image'] is None and params['template_id'] is None: + raise InvalidAPIParameters('Both image and template_id can\'t be None!') + + api_version = request['api_version'] + try: + if 6 <= api_version[0]: + params['config'] = creation_config_v5_template.check(params['config']) + elif 5 <= api_version[0]: + params['config'] = creation_config_v4_template.check(params['config']) + elif (4, '20190315') <= api_version: + params['config'] = creation_config_v3_template.check(params['config']) + except t.DataError as e: + log.debug('Validation error: {0}', e.as_dict()) + raise InvalidAPIParameters('Input validation error', + extra_data=e.as_dict()) + async with root_ctx.db.begin_readonly() as conn: + query = ( + sa.select([session_templates]) + .select_from(session_templates) + .where( + (session_templates.c.id == params['template_id']) & + session_templates.c.is_active, + ) + ) + result = await conn.execute(query) + template_info = result.fetchone() + template = template_info['template'] + if not template: + raise TaskTemplateNotFound + group_name = None + if template_info['domain_name'] and template_info['group_id']: + query = ( + sa.select([groups.c.name]) + .select_from(groups) + .where( + (groups.c.domain_name == template_info['domain_name']) & + (groups.c.id == template_info['group_id']), + ) + ) + group_name = await conn.scalar(query) + + if isinstance(template, str): + template = json.loads(template) + log.debug('Template: {0}', template) + + param_from_template = { + 'image': template['spec']['kernel']['image'], + 'architecture': template['spec']['kernel'].get('architecture', DEFAULT_IMAGE_ARCH), + } + if 'domain_name' in template_info: + param_from_template['domain'] = template_info['domain_name'] + if group_name: + param_from_template['group'] = group_name + if template['spec']['session_type'] == 'interactive': + param_from_template['session_type'] = SessionTypes.INTERACTIVE + elif template['spec']['session_type'] == 'batch': + param_from_template['session_type'] = SessionTypes.BATCH + + # TODO: Remove `type: ignore` when mypy supports type inference for walrus operator + # Check https://github.com/python/mypy/issues/7316 + # TODO: remove `NOQA` when flake8 supports Python 3.8 and walrus operator + # Check https://gitlab.com/pycqa/flake8/issues/599 + if tag := template['metadata'].get('tag'): # noqa + param_from_template['tag'] = tag + if runtime_opt := template['spec']['kernel']['run']: # noqa + if bootstrap := runtime_opt['bootstrap']: # noqa + param_from_template['bootstrap_script'] = bootstrap + if startup := runtime_opt['startup_command']: # noqa + param_from_template['startup_command'] = startup + + config_from_template: MutableMapping[Any, Any] = {} + if scaling_group := template['spec'].get('scaling_group'): # noqa + config_from_template['scaling_group'] = scaling_group + if mounts := template['spec'].get('mounts'): # noqa + config_from_template['mounts'] = list(mounts.keys()) + config_from_template['mount_map'] = { + key: value + for (key, value) in mounts.items() + if len(value) > 0 + } + if environ := template['spec']['kernel'].get('environ'): # noqa + config_from_template['environ'] = environ + if resources := template['spec'].get('resources'): # noqa + config_from_template['resources'] = resources + if 'agent_list' in template['spec']: + config_from_template['agent_list'] = template['spec']['agent_list'] + + override_config = drop(dict(params['config']), undefined) + override_params = drop(dict(params), 
undefined) + + log.debug('Default config: {0}', config_from_template) + log.debug('Default params: {0}', param_from_template) + + log.debug('Override config: {0}', override_config) + log.debug('Override params: {0}', override_params) + if override_config: + config_from_template.update(override_config) + if override_params: + param_from_template.update(override_params) + try: + params = overwritten_param_check.check(param_from_template) + except RuntimeError as e1: + log.exception(e1) + except t.DataError as e2: + log.debug('Error: {0}', str(e2)) + raise InvalidAPIParameters('Error while validating template') + params['config'] = config_from_template + + log.debug('Updated param: {0}', params) + + if git := template['spec']['kernel']['git']: # noqa + if _dest := git.get('dest_dir'): # noqa + target = _dest + else: + target = git['repository'].split('/')[-1] + + cmd_builder = 'git clone ' + if credential := git.get('credential'): # noqa + proto, url = git['repository'].split('://') + cmd_builder += f'{proto}://{credential["username"]}:{credential["password"]}@{url}' + else: + cmd_builder += git['repository'] + if branch := git.get('branch'): # noqa + cmd_builder += f' -b {branch}' + cmd_builder += f' {target}\n' + + if commit := git.get('commit'): # noqa + cmd_builder = 'CWD=$(pwd)\n' + cmd_builder + cmd_builder += f'cd {target}\n' + cmd_builder += f'git checkout {commit}\n' + cmd_builder += 'cd $CWD\n' + + bootstrap = base64.b64decode(params.get('bootstrap_script') or b'').decode() + bootstrap += '\n' + bootstrap += cmd_builder + params['bootstrap_script'] = base64.b64encode(bootstrap.encode()).decode() + return await _create(request, params) + + +@server_status_required(ALL_ALLOWED) +@auth_required +@check_api_params( + t.Dict({ + tx.AliasedKey(['name', 'clientSessionToken']) >> 'session_name': + t.Regexp(r'^(?=.{4,64}$)\w[\w.-]*\w$', re.ASCII), + tx.AliasedKey(['image', 'lang']): t.String, + tx.AliasedKey(['arch', 'architecture'], default=DEFAULT_IMAGE_ARCH) >> 'architecture': t.String, + tx.AliasedKey(['type', 'sessionType'], default='interactive') >> 'session_type': + tx.Enum(SessionTypes), + tx.AliasedKey(['group', 'groupName', 'group_name'], default='default'): t.String, + tx.AliasedKey(['domain', 'domainName', 'domain_name'], default='default'): t.String, + tx.AliasedKey(['cluster_size', 'clusterSize'], default=1): + t.ToInt[1:], # new in APIv6 + tx.AliasedKey(['cluster_mode', 'clusterMode'], default='single-node'): + tx.Enum(ClusterMode), # new in APIv6 + t.Key('config', default=dict): t.Mapping(t.String, t.Any), + t.Key('tag', default=None): t.Null | t.String, + t.Key('enqueueOnly', default=False) >> 'enqueue_only': t.ToBool, + t.Key('maxWaitSeconds', default=0) >> 'max_wait_seconds': t.ToInt[0:], + tx.AliasedKey(['starts_at', 'startsAt'], default=None): t.Null | t.String, + t.Key('reuseIfExists', default=True) >> 'reuse': t.ToBool, + t.Key('startupCommand', default=None) >> 'startup_command': t.Null | t.String, + tx.AliasedKey(['bootstrap_script', 'bootstrapScript'], default=None): t.Null | t.String, + t.Key('dependencies', default=None): t.Null | t.List(tx.UUID) | t.List(t.String), + tx.AliasedKey(['callback_url', 'callbackUrl', 'callbackURL'], default=None): t.Null | tx.URL, + t.Key('owner_access_key', default=None): t.Null | t.String, + }), + loads=_json_loads) +async def create_from_params(request: web.Request, params: dict[str, Any]) -> web.Response: + if params['session_name'] in ['from-template']: + raise InvalidAPIParameters(f'Requested session ID {params["session_name"]} 
is a reserved word')
+    api_version = request['api_version']
+    if 6 <= api_version[0]:
+        creation_config = creation_config_v5.check(params['config'])
+    elif 5 <= api_version[0]:
+        creation_config = creation_config_v4.check(params['config'])
+    elif (4, '20190315') <= api_version:
+        creation_config = creation_config_v3.check(params['config'])
+    elif 2 <= api_version[0] <= 4:
+        creation_config = creation_config_v2.check(params['config'])
+    elif api_version[0] == 1:
+        creation_config = creation_config_v1.check(params['config'])
+    else:
+        raise InvalidAPIParameters('API version not supported')
+    params['config'] = creation_config
+    if params['config']['agent_list'] is not None and request['user']['role'] != (UserRole.SUPERADMIN):
+        raise InsufficientPrivilege('You are not allowed to manually assign agents for your session.')
+    if request['user']['role'] == (UserRole.SUPERADMIN):
+        if not params['config']['agent_list']:
+            pass
+        else:
+            agent_count = len(params['config']['agent_list'])
+            if params['cluster_mode'] == "multi-node":
+                if agent_count != params['cluster_size']:
+                    raise InvalidAPIParameters(
+                        "For multi-node cluster sessions, the number of manually assigned agents "
+                        "must be equal to the cluster size. "
+                        "Note that you may specify duplicate agents in the list.",
+                    )
+            else:
+                if agent_count != 1:
+                    raise InvalidAPIParameters(
+                        "For non-cluster sessions and single-node cluster sessions, "
+                        "you may specify only one manually assigned agent.",
+                    )
+    return await _create(request, params)
+
+
+@server_status_required(ALL_ALLOWED)
+@auth_required
+@check_api_params(
+    t.Dict({
+        t.Key('clientSessionToken') >> 'session_name':
+            t.Regexp(r'^(?=.{4,64}$)\w[\w.-]*\w$', re.ASCII),
+        tx.AliasedKey(['template_id', 'templateId']): t.Null | tx.UUID,
+        tx.AliasedKey(['type', 'sessionType'], default='interactive') >> 'sess_type':
+            tx.Enum(SessionTypes),
+        tx.AliasedKey(['group', 'groupName', 'group_name'], default='default'): t.String,
+        tx.AliasedKey(['domain', 'domainName', 'domain_name'], default='default'): t.String,
+        tx.AliasedKey(['scaling_group', 'scalingGroup'], default=None): t.Null | t.String,
+        t.Key('tag', default=None): t.Null | t.String,
+        t.Key('enqueueOnly', default=False) >> 'enqueue_only': t.ToBool,
+        t.Key('maxWaitSeconds', default=0) >> 'max_wait_seconds': t.Int[0:],
+        t.Key('owner_access_key', default=None): t.Null | t.String,
+    }),
+    loads=_json_loads)
+async def create_cluster(request: web.Request, params: dict[str, Any]) -> web.Response:
+    root_ctx: RootContext = request.app['_root.context']
+    app_ctx: PrivateContext = request.app['session.context']
+    if params['domain'] is None:
+        params['domain'] = request['user']['domain_name']
+    scopes_param = {
+        'owner_access_key': (
+            None if params['owner_access_key'] is undefined
+            else params['owner_access_key']
+        ),
+    }
+    requester_access_key, owner_access_key = await get_access_key_scopes(request, scopes_param)
+    log.info('CREATE_CLUSTER (ak:{0}/{1}, s:{2})',
+             requester_access_key, owner_access_key if owner_access_key != requester_access_key else '*',
+             params['session_name'])
+
+    resp: MutableMapping[str, Any] = {}
+
+    # Check existing (owner_access_key, session) kernel instance
+    try:
+        # NOTE: We can reuse the session IDs of TERMINATED sessions only.
+        # NOTE: Reusing a session in the PENDING status returns an empty value in service_ports.
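+        # Unlike _create(), this handler never reuses an existing session: any
+        # match (including an ambiguous prefix match) is rejected below with
+        # SessionAlreadyExists.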
+ await root_ctx.registry.get_session(params['session_name'], owner_access_key) + except SessionNotFound: + pass + except TooManySessionsMatched: + raise SessionAlreadyExists + else: + raise SessionAlreadyExists + + async with root_ctx.db.begin_readonly() as conn: + query = ( + sa.select([session_templates.c.template]) + .select_from(session_templates) + .where( + (session_templates.c.id == params['template_id']) & + session_templates.c.is_active, + ) + ) + template = await conn.scalar(query) + log.debug('task template: {}', template) + if not template: + raise TaskTemplateNotFound + + mounts = [] + mount_map = {} + environ = {} + + if _mounts := template['spec'].get('mounts'): # noqa + mounts = list(_mounts.keys()) + mount_map = { + key: value + for (key, value) in _mounts.items() + if len(value) > 0 + } + if _environ := template['spec'].get('environ'): # noqa + environ = _environ + + log.debug('cluster template: {}', template) + + kernel_configs: List[KernelEnqueueingConfig] = [] + for node in template['spec']['nodes']: + # Resolve session template. + kernel_config = { + 'image': template['spec']['kernel']['image'], + 'architecture': template['spec']['kernel'].get('architecture', DEFAULT_IMAGE_ARCH), + 'cluster_role': node['cluster_role'], + 'creation_config': { + 'mount': mounts, + 'mount_map': mount_map, + 'environ': environ, + }, + } + + if template['spec']['sess_type'] == 'interactive': + kernel_config['sess_type'] = SessionTypes.INTERACTIVE + elif template['spec']['sess_type'] == 'batch': + kernel_config['sess_type'] = SessionTypes.BATCH + + if tag := template['metadata'].get('tag', None): + kernel_config['tag'] = tag + if runtime_opt := template['spec']['kernel']['run']: + if bootstrap := runtime_opt['bootstrap']: + kernel_config['bootstrap_script'] = bootstrap + if startup := runtime_opt['startup_command']: + kernel_config['startup_command'] = startup + + if resources := template['spec'].get('resources'): + kernel_config['creation_config']['resources'] = resources + + if git := template['spec']['kernel']['git']: + if _dest := git.get('dest_dir'): + target = _dest + else: + target = git['repository'].split('/')[-1] + + cmd_builder = 'git clone ' + if credential := git.get('credential'): + proto, url = git['repository'].split('://') + cmd_builder += f'{proto}://{credential["username"]}:{credential["password"]}@{url}' + else: + cmd_builder += git['repository'] + if branch := git.get('branch'): + cmd_builder += f' -b {branch}' + cmd_builder += f' {target}\n' + + if commit := git.get('commit'): + cmd_builder = 'CWD=$(pwd)\n' + cmd_builder + cmd_builder += f'cd {target}\n' + cmd_builder += f'git checkout {commit}\n' + cmd_builder += 'cd $CWD\n' + + bootstrap = base64.b64decode(kernel_config.get('bootstrap_script') or b'').decode() + bootstrap += '\n' + bootstrap += cmd_builder + kernel_config['bootstrap_script'] = base64.b64encode(bootstrap.encode()).decode() + + # Resolve the image reference. 
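+        # Same resolution rule as _create(): the alias must resolve to an image
+        # hosted on a registry allowed for the requested domain.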
+ try: + async with root_ctx.db.begin_readonly_session() as session: + image_row = await ImageRow.resolve(session, [ + ImageRef(kernel_config['image'], ['*'], kernel_config['architecture']), + kernel_config['image'], + ]) + requested_image_ref = image_row.image_ref + async with root_ctx.db.begin_readonly() as conn: + query = ( + sa.select([domains.c.allowed_docker_registries]) + .select_from(domains) + .where(domains.c.name == params['domain']) + ) + allowed_registries = await conn.scalar(query) + if requested_image_ref.registry not in allowed_registries: + raise AliasResolutionFailed + kernel_config['image_ref'] = requested_image_ref + except AliasResolutionFailed: + raise ImageNotFound('unknown alias or disallowed registry') + + for i in range(node['replicas']): + kernel_config['cluster_idx'] = i + 1 + kernel_configs.append( + check_typed_dict(kernel_config, KernelEnqueueingConfig), # type: ignore + ) + + session_creation_id = secrets.token_urlsafe(16) + start_event = asyncio.Event() + kernel_id: Optional[KernelId] = None + session_creation_tracker = app_ctx.session_creation_tracker + session_creation_tracker[session_creation_id] = start_event + current_task = asyncio.current_task() + assert current_task is not None + + try: + async with root_ctx.db.begin_readonly() as conn: + owner_uuid, group_id, resource_policy = await _query_userinfo(request, params, conn) + + session_id = await asyncio.shield(app_ctx.database_ptask_group.create_task( + root_ctx.registry.enqueue_session( + session_creation_id, + params['session_name'], + owner_access_key, + kernel_configs, + params['scaling_group'], + params['sess_type'], + resource_policy, + user_scope=UserScope( + domain_name=params['domain'], # type: ignore + group_id=group_id, + user_uuid=owner_uuid, + user_role=request['user']['role'], + ), + session_tag=params['tag'], + ), + )) + kernel_id = cast(KernelId, session_id) # the main kernel's ID is the session ID. 
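+        # The response mirrors _create(): report PENDING immediately when
+        # enqueue_only is set, otherwise wait up to max_wait_seconds for the
+        # start event and return the RUNNING status with its service ports.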
+ resp['kernelId'] = str(kernel_id) + resp['status'] = 'PENDING' + resp['servicePorts'] = [] + resp['created'] = True + + if not params['enqueue_only']: + app_ctx.pending_waits.add(current_task) + max_wait = params['max_wait_seconds'] + try: + if max_wait > 0: + with timeout(max_wait): + await start_event.wait() + else: + await start_event.wait() + except asyncio.TimeoutError: + resp['status'] = 'TIMEOUT' + else: + await asyncio.sleep(0.5) + async with root_ctx.db.begin_readonly() as conn: + query = ( + sa.select([ + kernels.c.status, + kernels.c.service_ports, + ]) + .select_from(kernels) + .where(kernels.c.id == kernel_id) + ) + result = await conn.execute(query) + row = result.first() + if row['status'] == KernelStatus.RUNNING: + resp['status'] = 'RUNNING' + for item in row['service_ports']: + response_dict = { + 'name': item['name'], + 'protocol': item['protocol'], + 'ports': item['container_ports'], + } + if 'url_template' in item.keys(): + response_dict['url_template'] = item['url_template'] + if 'allowed_arguments' in item.keys(): + response_dict['allowed_arguments'] = item['allowed_arguments'] + if 'allowed_envs' in item.keys(): + response_dict['allowed_envs'] = item['allowed_envs'] + resp['servicePorts'].append(response_dict) + else: + resp['status'] = row['status'].name + + except asyncio.CancelledError: + raise + except BackendError: + log.exception('GET_OR_CREATE: exception') + raise + except UnknownImageReference: + raise UnknownImageReferenceError(f"Unknown image reference: {params['image']}") + except Exception: + await root_ctx.error_monitor.capture_exception() + log.exception('GET_OR_CREATE: unexpected error!') + raise InternalServerError + finally: + app_ctx.pending_waits.discard(current_task) + del session_creation_tracker[session_creation_id] + return web.json_response(resp, status=201) + + +@server_status_required(READ_ALLOWED) +@auth_required +@check_api_params( + t.Dict({ + t.Key('login_session_token', default=None): t.Null | t.String, + tx.AliasedKey(['app', 'service']): t.String, + # The port argument is only required to use secondary ports + # when the target app listens multiple TCP ports. + # Otherwise it should be omitted or set to the same value of + # the actual port number used by the app. 
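+        # An illustrative request body (app name and values are example data only):
+        #   {"app": "jupyter", "port": 8090, "envs": "{\"PASSWORD\": \"mypass\"}"}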
+ tx.AliasedKey(['port'], default=None): t.Null | t.Int[1024:65535], + tx.AliasedKey(['envs'], default=None): t.Null | t.String, # stringified JSON + # e.g., '{"PASSWORD": "12345"}' + tx.AliasedKey(['arguments'], default=None): t.Null | t.String, # stringified JSON + # e.g., '{"-P": "12345"}' + # The value can be one of: + # None, str, List[str] + })) +async def start_service(request: web.Request, params: Mapping[str, Any]) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + session_name: str = request.match_info['session_name'] + app_ctx: PrivateContext = request.app['session.context'] + access_key: AccessKey = request['keypair']['access_key'] + service: str = params['app'] + myself = asyncio.current_task() + assert myself is not None + try: + kernel = await asyncio.shield(app_ctx.database_ptask_group.create_task( + root_ctx.registry.get_session(session_name, access_key), + )) + except (SessionNotFound, TooManySessionsMatched): + raise + + query = (sa.select([scaling_groups.c.wsproxy_addr]) + .select_from(scaling_groups) + .where((scaling_groups.c.name == kernel['scaling_group']))) + + async with root_ctx.db.begin_readonly() as conn: + result = await conn.execute(query) + sgroup = result.first() + wsproxy_addr = sgroup['wsproxy_addr'] + if not wsproxy_addr: + raise ServiceUnavailable('No coordinator configured for this resource group') + + if kernel['kernel_host'] is None: + kernel_host = urlparse(kernel['agent_addr']).hostname + else: + kernel_host = kernel['kernel_host'] + for sport in kernel['service_ports']: + if sport['name'] == service: + if params['port']: + # using one of the primary/secondary ports of the app + try: + hport_idx = sport['container_ports'].index(params['port']) + except ValueError: + raise InvalidAPIParameters( + f"Service {service} does not open the port number {params['port']}.") + host_port = sport['host_ports'][hport_idx] + else: + # using the default (primary) port of the app + if 'host_ports' not in sport: + host_port = sport['host_port'] # legacy kernels + else: + host_port = sport['host_ports'][0] + break + else: + raise AppNotFound(f'{session_name}:{service}') + + await asyncio.shield(app_ctx.database_ptask_group.create_task( + root_ctx.registry.increment_session_usage(session_name, access_key), + )) + + opts: MutableMapping[str, Union[None, str, List[str]]] = {} + if params['arguments'] is not None: + opts['arguments'] = json.loads(params['arguments']) + if params['envs'] is not None: + opts['envs'] = json.loads(params['envs']) + + result = await asyncio.shield( + app_ctx.rpc_ptask_group.create_task( + root_ctx.registry.start_service(session_name, access_key, service, opts), + ), + ) + if result['status'] == 'failed': + raise InternalServerError( + "Failed to launch the app service", + extra_data=result['error']) + + async with aiohttp.ClientSession() as session: + async with session.post(f'{wsproxy_addr}/v2/conf', json={ + 'login_session_token': params['login_session_token'], + 'kernel_host': kernel_host, + 'kernel_port': host_port, + }) as resp: + token_json = await resp.json() + return web.json_response({ + 'token': token_json['token'], + 'wsproxy_addr': wsproxy_addr, + }) + + +async def handle_kernel_creation_lifecycle( + app: web.Application, + source: AgentId, + event: (KernelPreparingEvent | KernelPullingEvent | KernelCreatingEvent | + KernelStartedEvent | KernelCancelledEvent), +) -> None: + """ + Update the database and perform post_create_kernel() upon + the events for each step of kernel creation. 
+ + To avoid race condition between consumer and subscriber event handlers, + we only have this handler to subscribe all kernel creation events, + but distinguish which one to process using a unique creation_id + generated when initiating the create_kernels() agent RPC call. + """ + root_ctx: RootContext = app['_root.context'] + # ck_id = (event.creation_id, event.kernel_id) + ck_id = event.kernel_id + if ck_id in root_ctx.registry.kernel_creation_tracker: + log.debug( + "handle_kernel_creation_lifecycle: ev:{} k:{}", + event.name, event.kernel_id, + ) + if isinstance(event, KernelPreparingEvent): + # State transition is done by the DoPrepareEvent handler inside the scheduler-distpacher object. + pass + elif isinstance(event, KernelPullingEvent): + await root_ctx.registry.set_kernel_status(event.kernel_id, KernelStatus.PULLING, event.reason) + elif isinstance(event, KernelCreatingEvent): + await root_ctx.registry.set_kernel_status(event.kernel_id, KernelStatus.PREPARING, event.reason) + elif isinstance(event, KernelStartedEvent): + # post_create_kernel() coroutines are waiting for the creation tracker events to be set. + if (tracker := root_ctx.registry.kernel_creation_tracker.get(ck_id)) and not tracker.done(): + tracker.set_result(None) + elif isinstance(event, KernelCancelledEvent): + if (tracker := root_ctx.registry.kernel_creation_tracker.get(ck_id)) and not tracker.done(): + tracker.cancel() + + +async def handle_kernel_termination_lifecycle( + app: web.Application, + source: AgentId, + event: KernelTerminatingEvent | KernelTerminatedEvent, +) -> None: + root_ctx: RootContext = app['_root.context'] + if isinstance(event, KernelTerminatingEvent): + # The destroy_kernel() API handler will set the "TERMINATING" status. + pass + elif isinstance(event, KernelTerminatedEvent): + await root_ctx.registry.mark_kernel_terminated(event.kernel_id, event.reason, event.exit_code) + await root_ctx.registry.check_session_terminated(event.kernel_id, event.reason) + + +async def handle_session_creation_lifecycle( + app: web.Application, + source: AgentId, + event: SessionStartedEvent | SessionCancelledEvent, +) -> None: + """ + Update the database according to the session-level lifecycle events + published by the manager. + """ + app_ctx: PrivateContext = app['session.context'] + if event.creation_id not in app_ctx.session_creation_tracker: + return + log.debug('handle_session_creation_lifecycle: ev:{} s:{}', event.name, event.session_id) + if isinstance(event, SessionStartedEvent): + if tracker := app_ctx.session_creation_tracker.get(event.creation_id): + tracker.set() + elif isinstance(event, SessionCancelledEvent): + if tracker := app_ctx.session_creation_tracker.get(event.creation_id): + tracker.set() + + +async def handle_session_termination_lifecycle( + app: web.Application, + agent_id: AgentId, + event: SessionTerminatedEvent, +) -> None: + """ + Update the database according to the session-level lifecycle events + published by the manager. 
+ """ + root_ctx: RootContext = app['_root.context'] + if isinstance(event, SessionTerminatedEvent): + await root_ctx.registry.mark_session_terminated(event.session_id, event.reason) + + +async def handle_destroy_session( + app: web.Application, + source: AgentId, + event: DoTerminateSessionEvent, +) -> None: + root_ctx: RootContext = app['_root.context'] + await root_ctx.registry.destroy_session( + functools.partial( + root_ctx.registry.get_session_by_session_id, + event.session_id, + ), + forced=False, + reason=event.reason or 'killed-by-event', + ) + + +async def handle_kernel_stat_sync( + app: web.Application, + agent_id: AgentId, + event: DoSyncKernelStatsEvent, +) -> None: + root_ctx: RootContext = app['_root.context'] + if root_ctx.local_config['debug']['periodic-sync-stats']: + await root_ctx.registry.sync_kernel_stats(event.kernel_ids) + + +async def _make_session_callback(data: dict[str, Any], url: yarl.URL) -> None: + log_func = log.info + log_msg: str = "" + log_fmt: str = "" + log_arg: Any = None + begin = time.monotonic() + try: + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=30.0), + ) as session: + try: + async with session.post(url, json=data) as response: + if response.content_length is not None and response.content_length > 0: + log_func = log.warning + log_msg = "warning" + log_fmt = "{3[0]} {3[1]} - the callback response body was not empty! " \ + "(len: {3[2]:,} bytes)" + log_arg = (response.status, response.reason, response.content_length) + else: + log_msg = "result" + log_fmt = "{3[0]} {3[1]}" + log_arg = (response.status, response.reason) + except aiohttp.ClientError as e: + log_func = log.warning + log_msg, log_fmt, log_arg = "failed", "{3}", repr(e) + except asyncio.CancelledError: + log_func = log.warning + log_msg, log_fmt, log_arg = "cancelled", "elapsed_time = {3:.6f}", time.monotonic() - begin + except asyncio.TimeoutError: + log_func = log.warning + log_msg, log_fmt, log_arg = "timeout", "elapsed_time = {3:.6f}", time.monotonic() - begin + finally: + log_func( + "Session lifecycle callback " + log_msg + " (e:{0}, s:{1}, url:{2}): " + log_fmt, + data['event'], data['session_id'], url, + log_arg, + ) + + +async def invoke_session_callback( + app: web.Application, + source: AgentId, + event: SessionEnqueuedEvent | SessionScheduledEvent | SessionPreparingEvent + | SessionStartedEvent | SessionCancelledEvent | SessionTerminatedEvent + | SessionSuccessEvent | SessionFailureEvent, +) -> None: + app_ctx: PrivateContext = app['session.context'] + root_ctx: RootContext = app['_root.context'] + data = { + "type": "session_lifecycle", + "event": event.name.removeprefix("session_"), + "session_id": str(event.session_id), + "when": datetime.now(tzutc()).isoformat(), + } + try: + async with root_ctx.db.begin_readonly() as db: + session = await root_ctx.registry.get_session_by_session_id( + event.session_id, + db_connection=db, + ) + except SessionNotFound: + return + url = session['callback_url'] + if url is None: + return + app_ctx.webhook_ptask_group.create_task( + _make_session_callback(data, url), + ) + + +async def handle_batch_result( + app: web.Application, + source: AgentId, + event: SessionSuccessEvent | SessionFailureEvent, +) -> None: + """ + Update the database according to the batch-job completion results + """ + root_ctx: RootContext = app['_root.context'] + if isinstance(event, SessionSuccessEvent): + await root_ctx.registry.set_session_result(event.session_id, True, event.exit_code) + elif isinstance(event, SessionFailureEvent): + 
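+        # Record the failure with its exit code; the session itself is destroyed
+        # just below with reason 'task-finished', as in the success case.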
await root_ctx.registry.set_session_result(event.session_id, False, event.exit_code) + await root_ctx.registry.destroy_session( + functools.partial( + root_ctx.registry.get_session_by_session_id, + event.session_id, + ), + reason='task-finished', + ) + + +async def handle_agent_lifecycle( + app: web.Application, + source: AgentId, + event: AgentStartedEvent | AgentTerminatedEvent, +) -> None: + root_ctx: RootContext = app['_root.context'] + if isinstance(event, AgentStartedEvent): + log.info('instance_lifecycle: ag:{0} joined ({1})', source, event.reason) + await root_ctx.registry.update_instance(source, { + 'status': AgentStatus.ALIVE, + }) + if isinstance(event, AgentTerminatedEvent): + if event.reason == 'agent-lost': + await root_ctx.registry.mark_agent_terminated(source, AgentStatus.LOST) + elif event.reason == 'agent-restart': + log.info('agent@{0} restarting for maintenance.', source) + await root_ctx.registry.update_instance(source, { + 'status': AgentStatus.RESTARTING, + }) + else: + # On normal instance termination, kernel_terminated events were already + # triggered by the agent. + await root_ctx.registry.mark_agent_terminated(source, AgentStatus.TERMINATED) + + +async def handle_agent_heartbeat( + app: web.Application, + source: AgentId, + event: AgentHeartbeatEvent, +) -> None: + root_ctx: RootContext = app['_root.context'] + await root_ctx.registry.handle_heartbeat(source, event.agent_info) + + +@catch_unexpected(log) +async def check_agent_lost(root_ctx: RootContext, interval: float) -> None: + try: + now = datetime.now(tzutc()) + timeout = timedelta(seconds=root_ctx.local_config['manager']['heartbeat-timeout']) + + async def _check_impl(r: aioredis.Redis): + async for agent_id, prev in r.hscan_iter('agent.last_seen'): + prev = datetime.fromtimestamp(float(prev), tzutc()) + if now - prev > timeout: + await root_ctx.event_producer.produce_event( + AgentTerminatedEvent("agent-lost"), + source=agent_id.decode()) + + await redis.execute(root_ctx.redis_live, _check_impl) + except asyncio.CancelledError: + pass + + +async def handle_kernel_log( + app: web.Application, + source: AgentId, + event: DoSyncKernelLogsEvent, +) -> None: + root_ctx: RootContext = app['_root.context'] + redis_conn = redis.get_redis_object(root_ctx.shared_config.data['redis'], db=REDIS_STREAM_DB) + # The log data is at most 10 MiB. + log_buffer = BytesIO() + log_key = f'containerlog.{event.container_id}' + try: + list_size = await redis.execute( + redis_conn, + lambda r: r.llen(log_key), + ) + if list_size is None: + # The log data is expired due to a very slow event delivery. + # (should never happen!) + log.warning('tried to store console logs for cid:{}, but the data is expired', + event.container_id) + return + for _ in range(list_size): + # Read chunk-by-chunk to allow interleaving with other Redis operations. + chunk = await redis.execute(redis_conn, lambda r: r.lpop(log_key)) + if chunk is None: # maybe missing + log_buffer.write(b"(container log unavailable)\n") + break + log_buffer.write(chunk) + try: + log_data = log_buffer.getvalue() + + async def _update_log() -> None: + async with root_ctx.db.begin() as conn: + update_query = ( + sa.update(kernels) + .values(container_log=log_data) + .where(kernels.c.id == event.kernel_id) + ) + await conn.execute(update_query) + + await execute_with_retry(_update_log) + finally: + # Clear the log data from Redis when done. 
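+            # The key is removed even if the DB update above fails, since the log
+            # chunks have already been drained into the in-memory buffer.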
+            await redis.execute(
+                redis_conn,
+                lambda r: r.delete(log_key),
+            )
+    finally:
+        log_buffer.close()
+        await redis_conn.close()
+
+
+async def report_stats(root_ctx: RootContext, interval: float) -> None:
+    stats_monitor = root_ctx.stats_monitor
+    await stats_monitor.report_metric(
+        GAUGE, 'ai.backend.manager.coroutines', len(asyncio.all_tasks()))
+
+    all_inst_ids = [
+        inst_id async for inst_id
+        in root_ctx.registry.enumerate_instances()]
+    await stats_monitor.report_metric(
+        GAUGE, 'ai.backend.manager.agent_instances', len(all_inst_ids))
+
+    async with root_ctx.db.begin_readonly() as conn:
+        query = (
+            sa.select([sa.func.count()])
+            .select_from(kernels)
+            .where(
+                (kernels.c.cluster_role == DEFAULT_ROLE) &
+                (kernels.c.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES)),
+            )
+        )
+        n = await conn.scalar(query)
+        await stats_monitor.report_metric(
+            GAUGE, 'ai.backend.manager.active_kernels', n)
+        subquery = (
+            sa.select([sa.func.count()])
+            .select_from(keypairs)
+            .where(keypairs.c.is_active == true())
+            .group_by(keypairs.c.user_id)
+        )
+        query = sa.select([sa.func.count()]).select_from(subquery.alias())
+        n = await conn.scalar(query)
+        await stats_monitor.report_metric(
+            GAUGE, 'ai.backend.users.has_active_key', n)
+
+        subquery = subquery.where(keypairs.c.last_used != null())
+        query = sa.select([sa.func.count()]).select_from(subquery.alias())
+        n = await conn.scalar(query)
+        await stats_monitor.report_metric(
+            GAUGE, 'ai.backend.users.has_used_key', n)
+
+        """
+        query = sa.select([sa.func.count()]).select_from(usage)
+        n = await conn.scalar(query)
+        await stats_monitor.report_metric(
+            GAUGE, 'ai.backend.manager.accum_kernels', n)
+        """
+
+
+@server_status_required(ALL_ALLOWED)
+@auth_required
+@check_api_params(
+    t.Dict({
+        tx.AliasedKey(['name', 'clientSessionToken']) >> 'session_name':
+            t.Regexp(r'^(?=.{4,64}$)\w[\w.-]*\w$', re.ASCII),
+    }),
+)
+async def rename_session(request: web.Request, params: Any) -> web.Response:
+    root_ctx: RootContext = request.app['_root.context']
+    session_name = request.match_info['session_name']
+    new_name = params['session_name']
+    requester_access_key, owner_access_key = await get_access_key_scopes(request)
+    log.info(
+        'RENAME_SESSION (ak:{0}/{1}, s:{2}, newname:{3})',
+        requester_access_key, owner_access_key, session_name, new_name,
+    )
+    async with root_ctx.db.begin() as conn:
+        compute_session = await root_ctx.registry.get_session(
+            session_name, owner_access_key,
+            allow_stale=True,
+            db_connection=conn,
+            for_update=True,
+        )
+        if compute_session['status'] != KernelStatus.RUNNING:
+            raise InvalidAPIParameters('Cannot rename a session that is not running')
+        update_query = (
+            sa.update(kernels)
+            .values(session_name=new_name)
+            .where(kernels.c.session_id == compute_session['session_id'])
+        )
+        await conn.execute(update_query)
+
+    return web.Response(status=204)
+
+
+@server_status_required(READ_ALLOWED)
+@auth_required
+@check_api_params(
+    t.Dict({
+        t.Key('forced', default='false'): t.ToBool(),
+        t.Key('owner_access_key', default=None): t.Null | t.String,
+    }))
+async def destroy(request: web.Request, params: Any) -> web.Response:
+    root_ctx: RootContext = request.app['_root.context']
+    session_name = request.match_info['session_name']
+    if params['forced'] and request['user']['role'] not in (UserRole.ADMIN, UserRole.SUPERADMIN):
+        raise InsufficientPrivilege('You are not allowed to force-terminate')
+    requester_access_key, owner_access_key = await get_access_key_scopes(request, params)
+    # domain_name = None
+    # if requester_access_key
!= owner_access_key and \ + # not request['is_superadmin'] and request['is_admin']: + # domain_name = request['user']['domain_name'] + log.info('DESTROY (ak:{0}/{1}, s:{2}, forced:{3})', + requester_access_key, owner_access_key, session_name, params['forced']) + last_stat = await root_ctx.registry.destroy_session( + functools.partial( + root_ctx.registry.get_session, + session_name, owner_access_key, + # domain_name=domain_name, + ), + forced=params['forced'], + ) + resp = { + 'stats': last_stat, + } + return web.json_response(resp, status=200) + + +@server_status_required(READ_ALLOWED) +@auth_required +@check_api_params( + t.Dict({ + t.Key('id'): t.String(), + })) +async def match_sessions(request: web.Request, params: Any) -> web.Response: + """ + A quick session-ID matcher API for use with auto-completion in CLI. + """ + root_ctx: RootContext = request.app['_root.context'] + id_or_name_prefix = params['id'] + requester_access_key, owner_access_key = await get_access_key_scopes(request) + log.info('MATCH_SESSIONS(ak:{0}/{1}, prefix:{2})', + requester_access_key, owner_access_key, id_or_name_prefix) + matches: List[Dict[str, Any]] = [] + async with root_ctx.db.begin_readonly() as conn: + session_infos = await match_session_ids( + id_or_name_prefix, + owner_access_key, + db_connection=conn, + ) + if session_infos: + matches.extend({ + 'id': str(item['session_id']), + 'name': item['session_name'], + 'status': item['status'].name, + } for item in session_infos) + return web.json_response({ + 'matches': matches, + }, status=200) + + +@server_status_required(READ_ALLOWED) +@auth_required +async def get_info(request: web.Request) -> web.Response: + # NOTE: This API should be replaced with GraphQL version. + resp = {} + root_ctx: RootContext = request.app['_root.context'] + session_name = request.match_info['session_name'] + requester_access_key, owner_access_key = await get_access_key_scopes(request) + log.info('GET_INFO (ak:{0}/{1}, s:{2})', + requester_access_key, owner_access_key, session_name) + try: + await root_ctx.registry.increment_session_usage(session_name, owner_access_key) + kern = await root_ctx.registry.get_session(session_name, owner_access_key) + resp['domainName'] = kern['domain_name'] + resp['groupId'] = str(kern['group_id']) + resp['userId'] = str(kern['user_uuid']) + resp['lang'] = kern['image'] # legacy + resp['image'] = kern['image'] + resp['architecture'] = kern['architecture'] + resp['registry'] = kern['registry'] + resp['tag'] = kern['tag'] + + # Resource occupation + resp['containerId'] = str(kern['container_id']) + resp['occupiedSlots'] = str(kern['occupied_slots']) + resp['occupiedShares'] = str(kern['occupied_shares']) + resp['environ'] = str(kern['environ']) + + # Lifecycle + resp['status'] = kern['status'].name # "e.g. 
'KernelStatus.RUNNING' -> 'RUNNING' " + resp['statusInfo'] = str(kern['status_info']) + resp['statusData'] = kern['status_data'] + age = datetime.now(tzutc()) - kern['created_at'] + resp['age'] = int(age.total_seconds() * 1000) # age in milliseconds + resp['creationTime'] = str(kern['created_at']) + resp['terminationTime'] = str(kern['terminated_at']) if kern['terminated_at'] else None + + resp['numQueriesExecuted'] = kern['num_queries'] + resp['lastStat'] = kern['last_stat'] + + # Resource limits collected from agent heartbeats were erased, as they were deprecated + # TODO: factor out policy/image info as a common repository + + log.info('information retrieved: {0!r}', resp) + except BackendError: + log.exception('GET_INFO: exception') + raise + return web.json_response(resp, status=200) + + +@server_status_required(READ_ALLOWED) +@auth_required +async def restart(request: web.Request) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + session_creation_id = secrets.token_urlsafe(16) + session_name = request.match_info['session_name'] + requester_access_key, owner_access_key = await get_access_key_scopes(request) + log.info('RESTART (ak:{0}/{1}, s:{2})', + requester_access_key, owner_access_key, session_name) + try: + await root_ctx.registry.increment_session_usage(session_name, owner_access_key) + await root_ctx.registry.restart_session(session_creation_id, session_name, owner_access_key) + except BackendError: + log.exception('RESTART: exception') + raise + except: + await root_ctx.error_monitor.capture_exception(context={'user': request['user']['uuid']}) + log.exception('RESTART: unexpected error') + raise web.HTTPInternalServerError + return web.Response(status=204) + + +@server_status_required(READ_ALLOWED) +@auth_required +async def execute(request: web.Request) -> web.Response: + resp = {} + root_ctx: RootContext = request.app['_root.context'] + session_name = request.match_info['session_name'] + requester_access_key, owner_access_key = await get_access_key_scopes(request) + try: + params = await request.json(loads=json.loads) + log.info('EXECUTE(ak:{0}/{1}, s:{2})', + requester_access_key, owner_access_key, session_name) + except json.decoder.JSONDecodeError: + log.warning('EXECUTE: invalid/missing parameters') + raise InvalidAPIParameters + try: + await root_ctx.registry.increment_session_usage(session_name, owner_access_key) + api_version = request['api_version'] + if api_version[0] == 1: + run_id = params.get('runId', secrets.token_hex(8)) + mode = 'query' + code = params.get('code', None) + opts = None + elif api_version[0] >= 2: + assert 'runId' in params, 'runId is missing!' + run_id = params['runId'] # maybe None + assert params.get('mode'), 'mode is missing or empty!' + mode = params['mode'] + assert mode in {'query', 'batch', 'complete', 'continue', 'input'}, \ + 'mode has an invalid value.' 
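+            # Illustrative payloads for this branch (values are hypothetical):
+            #   {"runId": "abcd1234", "mode": "query", "code": "print('hi')", "options": {}}
+            #   {"runId": "abcd1234", "mode": "continue"}
+            # A 'continue' or 'input' request must carry the runId of the run it continues,
+            # which is what the assertion below enforces.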
+ if mode in {'continue', 'input'}: + assert run_id is not None, 'continuation requires explicit run ID' + code = params.get('code', None) + opts = params.get('options', None) + else: + raise RuntimeError("should not reach here") + # handle cases when some params are deliberately set to None + if code is None: code = '' # noqa + if opts is None: opts = {} # noqa + if mode == 'complete': + # For legacy + resp['result'] = await root_ctx.registry.get_completions( + session_name, owner_access_key, code, opts) + else: + raw_result = await root_ctx.registry.execute( + session_name, owner_access_key, + api_version, run_id, mode, code, opts, + flush_timeout=2.0) + if raw_result is None: + # the kernel may have terminated from its side, + # or there was interruption of agents. + resp['result'] = { + 'status': 'finished', + 'runId': run_id, + 'exitCode': 130, + 'options': {}, + 'files': [], + 'console': [], + } + return web.json_response(resp, status=200) + # Keep internal/public API compatilibty + result = { + 'status': raw_result['status'], + 'runId': raw_result['runId'], + 'exitCode': raw_result.get('exitCode'), + 'options': raw_result.get('options'), + 'files': raw_result.get('files'), + } + if api_version[0] == 1: + result['stdout'] = raw_result.get('stdout') + result['stderr'] = raw_result.get('stderr') + result['media'] = raw_result.get('media') + result['html'] = raw_result.get('html') + else: + result['console'] = raw_result.get('console') + resp['result'] = result + except AssertionError as e: + log.warning('EXECUTE: invalid/missing parameters: {0!r}', e) + raise InvalidAPIParameters(extra_msg=e.args[0]) + except BackendError: + log.exception('EXECUTE: exception') + raise + return web.json_response(resp, status=200) + + +@server_status_required(READ_ALLOWED) +@auth_required +async def interrupt(request: web.Request) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + session_name = request.match_info['session_name'] + requester_access_key, owner_access_key = await get_access_key_scopes(request) + log.info('INTERRUPT(ak:{0}/{1}, s:{2})', + requester_access_key, owner_access_key, session_name) + try: + await root_ctx.registry.increment_session_usage(session_name, owner_access_key) + await root_ctx.registry.interrupt_session(session_name, owner_access_key) + except BackendError: + log.exception('INTERRUPT: exception') + raise + return web.Response(status=204) + + +@server_status_required(READ_ALLOWED) +@auth_required +async def complete(request: web.Request) -> web.Response: + resp = { + 'result': { + 'status': 'finished', + 'completions': [], + }, + } + root_ctx: RootContext = request.app['_root.context'] + session_name = request.match_info['session_name'] + requester_access_key, owner_access_key = await get_access_key_scopes(request) + try: + params = await request.json(loads=json.loads) + log.info('COMPLETE(ak:{0}/{1}, s:{2})', + requester_access_key, owner_access_key, session_name) + except json.decoder.JSONDecodeError: + raise InvalidAPIParameters + try: + code = params.get('code', '') + opts = params.get('options', None) or {} + await root_ctx.registry.increment_session_usage(session_name, owner_access_key) + resp['result'] = cast( + Dict[str, Any], + await root_ctx.registry.get_completions(session_name, owner_access_key, code, opts), + ) + except AssertionError: + raise InvalidAPIParameters + except BackendError: + log.exception('COMPLETE: exception') + raise + return web.json_response(resp, status=200) + + +@server_status_required(READ_ALLOWED) +@auth_required 
+@check_api_params( + t.Dict({ + t.Key('service_name'): t.String, + })) +async def shutdown_service(request: web.Request, params: Any) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + session_name = request.match_info['session_name'] + requester_access_key, owner_access_key = await get_access_key_scopes(request) + log.info('SHUTDOWN_SERVICE (ak:{0}/{1}, s:{2})', + requester_access_key, owner_access_key, session_name) + service_name = params.get('service_name') + try: + await root_ctx.registry.shutdown_service(session_name, owner_access_key, service_name) + except BackendError: + log.exception('SHUTDOWN_SERVICE: exception') + raise + return web.Response(status=204) + + +@server_status_required(READ_ALLOWED) +@auth_required +async def upload_files(request: web.Request) -> web.Response: + loop = asyncio.get_event_loop() + reader = await request.multipart() + root_ctx: RootContext = request.app['_root.context'] + session_name = request.match_info['session_name'] + requester_access_key, owner_access_key = await get_access_key_scopes(request) + log.info('UPLOAD_FILE (ak:{0}/{1}, s:{2})', + requester_access_key, owner_access_key, session_name) + try: + await root_ctx.registry.increment_session_usage(session_name, owner_access_key) + file_count = 0 + upload_tasks = [] + async for file in aiotools.aiter(reader.next, None): + if file_count == 20: + raise InvalidAPIParameters('Too many files') + file_count += 1 + # This API handles only small files, so let's read it at once. + chunks = [] + recv_size = 0 + while True: + chunk = await file.read_chunk(size=1048576) + if not chunk: + break + chunk_size = len(chunk) + if recv_size + chunk_size >= 1048576: + raise InvalidAPIParameters('Too large file') + chunks.append(chunk) + recv_size += chunk_size + data = file.decode(b''.join(chunks)) + log.debug('received file: {0} ({1:,} bytes)', file.filename, recv_size) + t = loop.create_task( + root_ctx.registry.upload_file(session_name, owner_access_key, + file.filename, data)) + upload_tasks.append(t) + await asyncio.gather(*upload_tasks) + except BackendError: + log.exception('UPLOAD_FILES: exception') + raise + return web.Response(status=204) + + +@server_status_required(READ_ALLOWED) +@auth_required +@check_api_params( + t.Dict({ + tx.MultiKey('files'): t.List(t.String), + })) +async def download_files(request: web.Request, params: Any) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + session_name = request.match_info['session_name'] + requester_access_key, owner_access_key = await get_access_key_scopes(request) + files = params.get('files') + log.info( + 'DOWNLOAD_FILE (ak:{0}/{1}, s:{2}, path:{3!r})', + requester_access_key, owner_access_key, session_name, + files[0], + ) + try: + assert len(files) <= 5, 'Too many files' + await root_ctx.registry.increment_session_usage(session_name, owner_access_key) + # TODO: Read all download file contents. Need to fix by using chuncking, etc. 
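+        # Note: every requested file is fetched in full and buffered in memory before the
+        # multipart response below is assembled, hence the small cap on the file count above.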
+ results = await asyncio.gather( + *map( + functools.partial(root_ctx.registry.download_file, session_name, owner_access_key), + files, + ), + ) + log.debug('file(s) inside container retrieved') + except asyncio.CancelledError: + raise + except BackendError: + log.exception('DOWNLOAD_FILE: exception') + raise + except (ValueError, FileNotFoundError): + raise InvalidAPIParameters('The file is not found.') + except Exception: + await root_ctx.error_monitor.capture_exception(context={'user': request['user']['uuid']}) + log.exception('DOWNLOAD_FILE: unexpected error!') + raise InternalServerError + + with aiohttp.MultipartWriter('mixed') as mpwriter: + headers = multidict.MultiDict({'Content-Encoding': 'identity'}) + for tarbytes in results: + mpwriter.append(tarbytes, headers) + return web.Response(body=mpwriter, status=200) + + +@auth_required +@server_status_required(READ_ALLOWED) +@check_api_params( + t.Dict({ + t.Key('file'): t.String, + })) +async def download_single(request: web.Request, params: Any) -> web.Response: + """ + Download a single file from the scratch root. Only for small files. + """ + root_ctx: RootContext = request.app['_root.context'] + session_name = request.match_info['session_name'] + requester_access_key, owner_access_key = await get_access_key_scopes(request) + file = params['file'] + log.info( + 'DOWNLOAD_SINGLE (ak:{0}/{1}, s:{2}, path:{3!r})', + requester_access_key, owner_access_key, session_name, file, + ) + try: + await root_ctx.registry.increment_session_usage(session_name, owner_access_key) + result = await root_ctx.registry.download_file(session_name, owner_access_key, file) + except asyncio.CancelledError: + raise + except BackendError: + log.exception('DOWNLOAD_SINGLE: exception') + raise + except (ValueError, FileNotFoundError): + raise InvalidAPIParameters('The file is not found.') + except Exception: + await root_ctx.error_monitor.capture_exception(context={'user': request['user']['uuid']}) + log.exception('DOWNLOAD_SINGLE: unexpected error!') + raise InternalServerError + return web.Response(body=result, status=200) + + +@server_status_required(READ_ALLOWED) +@auth_required +async def list_files(request: web.Request) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + try: + session_name = request.match_info['session_name'] + requester_access_key, owner_access_key = await get_access_key_scopes(request) + params = await request.json(loads=json.loads) + path = params.get('path', '.') + log.info( + 'LIST_FILES (ak:{0}/{1}, s:{2}, path:{3})', + requester_access_key, owner_access_key, session_name, path, + ) + except (asyncio.TimeoutError, AssertionError, + json.decoder.JSONDecodeError) as e: + log.warning('LIST_FILES: invalid/missing parameters, {0!r}', e) + raise InvalidAPIParameters(extra_msg=str(e.args[0])) + resp: MutableMapping[str, Any] = {} + try: + await root_ctx.registry.increment_session_usage(session_name, owner_access_key) + result = await root_ctx.registry.list_files(session_name, owner_access_key, path) + resp.update(result) + log.debug('container file list for {0} retrieved', path) + except asyncio.CancelledError: + raise + except BackendError: + log.exception('LIST_FILES: exception') + raise + except Exception: + await root_ctx.error_monitor.capture_exception(context={'user': request['user']['uuid']}) + log.exception('LIST_FILES: unexpected error!') + raise InternalServerError + return web.json_response(resp, status=200) + + +@server_status_required(READ_ALLOWED) +@auth_required +@check_api_params( + t.Dict({ + 
t.Key('owner_access_key', default=None): t.Null | t.String, + })) +async def get_container_logs(request: web.Request, params: Any) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + session_name: str = request.match_info['session_name'] + requester_access_key, owner_access_key = await get_access_key_scopes(request, params) + log.info('GET_CONTAINER_LOG (ak:{}/{}, s:{})', + requester_access_key, owner_access_key, session_name) + resp = {'result': {'logs': ''}} + async with root_ctx.db.begin_readonly() as conn: + compute_session = await root_ctx.registry.get_session( + session_name, owner_access_key, + allow_stale=True, + db_connection=conn, + ) + if ( + compute_session['status'] in DEAD_KERNEL_STATUSES + and compute_session['container_log'] is not None + ): + log.debug('returning log from database record') + resp['result']['logs'] = compute_session['container_log'].decode('utf-8') + return web.json_response(resp, status=200) + try: + registry = root_ctx.registry + await registry.increment_session_usage(session_name, owner_access_key) + resp['result']['logs'] = await registry.get_logs_from_agent(session_name, owner_access_key) + log.debug('returning log from agent') + except BackendError: + log.exception('GET_CONTAINER_LOG(ak:{}/{}, s:{}): unexpected error', + requester_access_key, owner_access_key, session_name) + raise + return web.json_response(resp, status=200) + + +@server_status_required(READ_ALLOWED) +@auth_required +@check_api_params( + t.Dict({ + tx.AliasedKey(['session_name', 'sessionName', 'task_id', 'taskId']) >> 'kernel_id': tx.UUID, + })) +async def get_task_logs(request: web.Request, params: Any) -> web.StreamResponse: + log.info('GET_TASK_LOG (ak:{}, k:{})', + request['keypair']['access_key'], params['kernel_id']) + root_ctx: RootContext = request.app['_root.context'] + domain_name = request['user']['domain_name'] + user_role = request['user']['role'] + user_uuid = request['user']['uuid'] + kernel_id_str = params['kernel_id'].hex + async with root_ctx.db.begin_readonly() as conn: + matched_vfolders = await query_accessible_vfolders( + conn, user_uuid, + user_role=user_role, domain_name=domain_name, + allowed_vfolder_types=['user'], + extra_vf_conds=(vfolders.c.name == '.logs')) + if not matched_vfolders: + raise ObjectNotFound( + extra_data={'vfolder_name': '.logs'}, + object_name='vfolder', + ) + log_vfolder = matched_vfolders[0] + + proxy_name, volume_name = root_ctx.storage_manager.split_host(log_vfolder['host']) + response = web.StreamResponse(status=200) + response.headers[hdrs.CONTENT_TYPE] = "text/plain" + prepared = False + try: + async with root_ctx.storage_manager.request( + log_vfolder['host'], 'POST', 'folder/file/fetch', + json={ + 'volume': volume_name, + 'vfid': str(log_vfolder['id']), + 'relpath': str( + PurePosixPath('task') + / kernel_id_str[:2] / kernel_id_str[2:4] + / f'{kernel_id_str[4:]}.log', + ), + }, + raise_for_status=True, + ) as (_, storage_resp): + while True: + chunk = await storage_resp.content.read(DEFAULT_CHUNK_SIZE) + if not chunk: + break + if not prepared: + await response.prepare(request) + prepared = True + await response.write(chunk) + except aiohttp.ClientResponseError as e: + raise StorageProxyError(status=e.status, extra_msg=e.message) + finally: + if prepared: + await response.write_eof() + return response + + +@attr.s(slots=True, auto_attribs=True, init=False) +class PrivateContext: + session_creation_tracker: Dict[str, asyncio.Event] + pending_waits: Set[asyncio.Task[None]] + agent_lost_checker: 
asyncio.Task[None] + stats_task: asyncio.Task[None] + database_ptask_group: aiotools.PersistentTaskGroup + rpc_ptask_group: aiotools.PersistentTaskGroup + webhook_ptask_group: aiotools.PersistentTaskGroup + + +async def init(app: web.Application) -> None: + root_ctx: RootContext = app['_root.context'] + app_ctx: PrivateContext = app['session.context'] + app_ctx.session_creation_tracker = {} + app_ctx.database_ptask_group = aiotools.PersistentTaskGroup() + app_ctx.rpc_ptask_group = aiotools.PersistentTaskGroup() + app_ctx.webhook_ptask_group = aiotools.PersistentTaskGroup() + + # passive events + evd = root_ctx.event_dispatcher + evd.subscribe(KernelPreparingEvent, app, handle_kernel_creation_lifecycle, name="api.session.kprep") + evd.subscribe(KernelPullingEvent, app, handle_kernel_creation_lifecycle, name="api.session.kpull") + evd.subscribe(KernelCreatingEvent, app, handle_kernel_creation_lifecycle, name="api.session.kcreat") + evd.subscribe(KernelStartedEvent, app, handle_kernel_creation_lifecycle, name="api.session.kstart") + evd.subscribe(KernelCancelledEvent, app, handle_kernel_creation_lifecycle, name="api.session.kstart") + evd.subscribe( + SessionStartedEvent, app, handle_session_creation_lifecycle, name="api.session.sstart", + ) + evd.subscribe( + SessionCancelledEvent, app, handle_session_creation_lifecycle, name="api.session.scancel", + ) + evd.consume( + KernelTerminatingEvent, app, handle_kernel_termination_lifecycle, name="api.session.kterming", + ) + evd.consume( + KernelTerminatedEvent, app, handle_kernel_termination_lifecycle, name="api.session.kterm", + ) + evd.consume( + SessionTerminatedEvent, app, handle_session_termination_lifecycle, name="api.session.sterm", + ) + evd.consume(SessionEnqueuedEvent, app, invoke_session_callback) + evd.consume(SessionScheduledEvent, app, invoke_session_callback) + evd.consume(SessionPreparingEvent, app, invoke_session_callback) + evd.consume(SessionStartedEvent, app, invoke_session_callback) + evd.consume(SessionCancelledEvent, app, invoke_session_callback) + evd.consume(SessionTerminatedEvent, app, invoke_session_callback) + evd.consume(SessionSuccessEvent, app, invoke_session_callback) + evd.consume(SessionFailureEvent, app, invoke_session_callback) + evd.consume(SessionSuccessEvent, app, handle_batch_result) + evd.consume(SessionFailureEvent, app, handle_batch_result) + evd.consume(AgentStartedEvent, app, handle_agent_lifecycle) + evd.consume(AgentTerminatedEvent, app, handle_agent_lifecycle) + evd.consume(AgentHeartbeatEvent, app, handle_agent_heartbeat) + + # action-trigerring events + evd.consume(DoSyncKernelStatsEvent, app, handle_kernel_stat_sync, name="api.session.synckstat") + evd.consume(DoSyncKernelLogsEvent, app, handle_kernel_log, name="api.session.syncklog") + evd.consume(DoTerminateSessionEvent, app, handle_destroy_session, name="api.session.doterm") + + app_ctx.pending_waits = set() + + # Scan ALIVE agents + app_ctx.agent_lost_checker = aiotools.create_timer( + functools.partial(check_agent_lost, root_ctx), 1.0) + app_ctx.stats_task = aiotools.create_timer( + functools.partial(report_stats, root_ctx), 5.0, + ) + + +async def shutdown(app: web.Application) -> None: + app_ctx: PrivateContext = app['session.context'] + app_ctx.agent_lost_checker.cancel() + await app_ctx.agent_lost_checker + app_ctx.stats_task.cancel() + await app_ctx.stats_task + + await app_ctx.webhook_ptask_group.shutdown() + await app_ctx.database_ptask_group.shutdown() + await app_ctx.rpc_ptask_group.shutdown() + + await 
cancel_tasks(app_ctx.pending_waits) + + +def create_app(default_cors_options: CORSOptions) -> Tuple[web.Application, Iterable[WebMiddleware]]: + app = web.Application() + app.on_startup.append(init) + app.on_shutdown.append(shutdown) + app['api_versions'] = (1, 2, 3, 4) + app['session.context'] = PrivateContext() + cors = aiohttp_cors.setup(app, defaults=default_cors_options) + cors.add(app.router.add_route('POST', '', create_from_params)) + cors.add(app.router.add_route('POST', '/_/create', create_from_params)) + cors.add(app.router.add_route('POST', '/_/create-from-template', create_from_template)) + cors.add(app.router.add_route('POST', '/_/create-cluster', create_cluster)) + cors.add(app.router.add_route('GET', '/_/match', match_sessions)) + session_resource = cors.add(app.router.add_resource(r'/{session_name}')) + cors.add(session_resource.add_route('GET', get_info)) + cors.add(session_resource.add_route('PATCH', restart)) + cors.add(session_resource.add_route('DELETE', destroy)) + cors.add(session_resource.add_route('POST', execute)) + task_log_resource = cors.add(app.router.add_resource(r'/_/logs')) + cors.add(task_log_resource.add_route('HEAD', get_task_logs)) + cors.add(task_log_resource.add_route('GET', get_task_logs)) + cors.add(app.router.add_route('GET', '/{session_name}/logs', get_container_logs)) + cors.add(app.router.add_route('POST', '/{session_name}/rename', rename_session)) + cors.add(app.router.add_route('POST', '/{session_name}/interrupt', interrupt)) + cors.add(app.router.add_route('POST', '/{session_name}/complete', complete)) + cors.add(app.router.add_route('POST', '/{session_name}/shutdown-service', shutdown_service)) + cors.add(app.router.add_route('POST', '/{session_name}/upload', upload_files)) + cors.add(app.router.add_route('GET', '/{session_name}/download', download_files)) + cors.add(app.router.add_route('GET', '/{session_name}/download_single', download_single)) + cors.add(app.router.add_route('GET', '/{session_name}/files', list_files)) + cors.add(app.router.add_route('POST', '/{session_name}/start-service', start_service)) + return app, [] diff --git a/src/ai/backend/manager/api/session_template.py b/src/ai/backend/manager/api/session_template.py new file mode 100644 index 0000000000..a9fe5af233 --- /dev/null +++ b/src/ai/backend/manager/api/session_template.py @@ -0,0 +1,342 @@ +import json +import datetime +import logging +from typing import ( + Any, + List, + Dict, + Mapping, + TYPE_CHECKING, + Tuple, +) +import uuid + +from aiohttp import web +import aiohttp_cors +import sqlalchemy as sa +import trafaret as t +import yaml + +from ai.backend.common import validators as tx +from ai.backend.common.logging import BraceStyleAdapter + +from ..models import ( + groups, session_templates, users, TemplateType, +) +from ..models.session_template import check_task_template + +from .auth import auth_required +from .exceptions import InvalidAPIParameters, TaskTemplateNotFound +from .manager import READ_ALLOWED, server_status_required +from .types import CORSOptions, Iterable, WebMiddleware +from .utils import check_api_params, get_access_key_scopes +from .session import _query_userinfo + +if TYPE_CHECKING: + from .context import RootContext + +log = BraceStyleAdapter(logging.getLogger(__name__)) + + +@server_status_required(READ_ALLOWED) +@auth_required +@check_api_params(t.Dict( + { + tx.AliasedKey(['group', 'groupName', 'group_name'], default='default'): t.String, + tx.AliasedKey(['domain', 'domainName', 'domain_name'], default='default'): t.String, + 
t.Key('owner_access_key', default=None): t.Null | t.String, + t.Key('payload'): t.String, + }, +)) +async def create(request: web.Request, params: Any) -> web.Response: + if params['domain'] is None: + params['domain'] = request['user']['domain_name'] + requester_access_key, owner_access_key = await get_access_key_scopes(request, params) + log.info( + 'SESSION_TEMPLATE.CREATE (ak:{0}/{1})', + requester_access_key, + owner_access_key if owner_access_key != requester_access_key else '*', + ) + root_ctx: RootContext = request.app['_root.context'] + async with root_ctx.db.begin() as conn: + user_uuid, group_id, _ = await _query_userinfo(request, params, conn) + log.debug('Params: {0}', params) + try: + body = json.loads(params['payload']) + except json.JSONDecodeError: + try: + body = yaml.safe_load(params['payload']) + except (yaml.YAMLError, yaml.MarkedYAMLError): + raise InvalidAPIParameters('Malformed payload') + for st in body['session_templates']: + template_data = check_task_template(st['template']) + template_id = uuid.uuid4().hex + resp = { + 'id': template_id, + 'user': user_uuid.hex, + } + name = st['name'] if 'name' in st else template_data['metadata']['name'] + if 'group_id' in st: + group_id = st['group_id'] + if 'user_uuid' in st: + user_uuid = st['user_uuid'] + query = session_templates.insert().values({ + 'id': template_id, + 'created_at': datetime.datetime.now(), + 'domain_name': params['domain'], + 'group_id': group_id, + 'user_uuid': user_uuid, + 'name': name, + 'template': template_data, + 'type': TemplateType.TASK, + }) + result = await conn.execute(query) + assert result.rowcount == 1 + return web.json_response(resp) + + +@auth_required +@server_status_required(READ_ALLOWED) +@check_api_params( + t.Dict({ + t.Key('all', default=False): t.ToBool, + tx.AliasedKey(['group_id', 'groupId'], default=None): tx.UUID | t.String | t.Null, + }), +) +async def list_template(request: web.Request, params: Any) -> web.Response: + resp = [] + access_key = request['keypair']['access_key'] + domain_name = request['user']['domain_name'] + user_uuid = request['user']['uuid'] + log.info('SESSION_TEMPLATE.LIST (ak:{})', access_key) + root_ctx: RootContext = request.app['_root.context'] + async with root_ctx.db.begin() as conn: + entries: List[Mapping[str, Any]] + j = ( + session_templates + .join(users, session_templates.c.user_uuid == users.c.uuid, isouter=True) + .join(groups, session_templates.c.group_id == groups.c.id, isouter=True) + ) + query = ( + sa.select([session_templates, users.c.email, groups.c.name], use_labels=True) + .select_from(j) + .where( + (session_templates.c.is_active) & + (session_templates.c.type == TemplateType.TASK), + ) + ) + result = await conn.execute(query) + entries = [] + for row in result.fetchall(): + is_owner = True if row.session_templates_user_uuid == user_uuid else False + entries.append({ + 'name': row.session_templates_name, + 'id': row.session_templates_id, + 'created_at': row.session_templates_created_at, + 'is_owner': is_owner, + 'user': (str(row.session_templates_user_uuid) + if row.session_templates_user_uuid else None), + 'group': (str(row.session_templates_group_id) + if row.session_templates_group_id else None), + 'user_email': row.users_email, + 'group_name': row.groups_name, + 'domain_name': domain_name, + 'type': row.session_templates_type, + 'template': row.session_templates_template, + }) + for entry in entries: + resp.append({ + 'name': entry['name'], + 'id': entry['id'].hex, + 'created_at': str(entry['created_at']), + 'is_owner': 
entry['is_owner'], + 'user': str(entry['user']), + 'group': str(entry['group']), + 'user_email': entry['user_email'], + 'group_name': entry['group_name'], + 'domain_name': domain_name, + 'type': entry['type'], + 'template': entry['template'], + }) + return web.json_response(resp) + + +@auth_required +@server_status_required(READ_ALLOWED) +@check_api_params( + t.Dict({ + t.Key('format', default='json'): t.Null | t.Enum('yaml', 'json'), + t.Key('owner_access_key', default=None): t.Null | t.String, + }), +) +async def get(request: web.Request, params: Any) -> web.Response: + if params['format'] not in ['yaml', 'json']: + raise InvalidAPIParameters('format should be "yaml" or "json"') + resp: Dict[str, Any] = {} + domain_name = request['user']['domain_name'] + requester_access_key, owner_access_key = await get_access_key_scopes(request, params) + log.info( + 'SESSION_TEMPLATE.GET (ak:{0}/{1})', + requester_access_key, + owner_access_key if owner_access_key != requester_access_key else '*', + ) + template_id = request.match_info['template_id'] + root_ctx: RootContext = request.app['_root.context'] + async with root_ctx.db.begin() as conn: + query = ( + sa.select([ + session_templates.c.template, + session_templates.c.name, + session_templates.c.user_uuid, + session_templates.c.group_id, + ]) + .select_from(session_templates) + .where( + (session_templates.c.id == template_id) & + (session_templates.c.is_active) & + (session_templates.c.type == TemplateType.TASK), + ) + ) + result = await conn.execute(query) + for row in result.fetchall(): + resp.update({ + 'template': row.template, + 'name': row.name, + 'user_uuid': str(row.user_uuid), + 'group_id': str(row.group_id), + 'domain_name': domain_name, + }) + if isinstance(resp, str): + resp = json.loads(resp) + else: + resp = json.loads(json.dumps(resp)) + return web.json_response(resp) + + +@auth_required +@server_status_required(READ_ALLOWED) +@check_api_params( + t.Dict({ + tx.AliasedKey(['group', 'groupName', 'group_name'], default='default'): t.String, + tx.AliasedKey(['domain', 'domainName', 'domain_name'], default='default'): t.String, + t.Key('payload'): t.String, + t.Key('owner_access_key', default=None): t.Null | t.String, + }), +) +async def put(request: web.Request, params: Any) -> web.Response: + if params['domain'] is None: + params['domain'] = request['user']['domain_name'] + template_id = request.match_info['template_id'] + + requester_access_key, owner_access_key = await get_access_key_scopes(request, params) + log.info( + 'SESSION_TEMPLATE.PUT (ak:{0}/{1})', + requester_access_key, + owner_access_key if owner_access_key != requester_access_key else '*', + ) + root_ctx: RootContext = request.app['_root.context'] + async with root_ctx.db.begin() as conn: + user_uuid, group_id, _ = await _query_userinfo(request, params, conn) + query = ( + sa.select([session_templates.c.id]) + .select_from(session_templates) + .where( + (session_templates.c.id == template_id) & + (session_templates.c.is_active) & + (session_templates.c.type == TemplateType.TASK), + ) + ) + result = await conn.scalar(query) + if not result: + raise TaskTemplateNotFound + try: + body = json.loads(params['payload']) + except json.JSONDecodeError: + body = yaml.safe_load(params['payload']) + except (yaml.YAMLError, yaml.MarkedYAMLError): + raise InvalidAPIParameters('Malformed payload') + for st in body['session_templates']: + template_data = check_task_template(st['template']) + name = st['name'] if 'name' in st else template_data['metadata']['name'] + if 'group_id' in 
st: + group_id = st['group_id'] + if 'user_uuid' in st: + user_uuid = st['user_uuid'] + query = ( + sa.update(session_templates) + .values({ + 'group_id': group_id, + 'user_uuid': user_uuid, + 'name': name, + 'template': template_data, + }) + .where((session_templates.c.id == template_id)) + ) + result = await conn.execute(query) + assert result.rowcount == 1 + return web.json_response({'success': True}) + + +@auth_required +@server_status_required(READ_ALLOWED) +@check_api_params( + t.Dict({ + t.Key('owner_access_key', default=None): t.Null | t.String, + }), +) +async def delete(request: web.Request, params: Any) -> web.Response: + template_id = request.match_info['template_id'] + requester_access_key, owner_access_key = await get_access_key_scopes(request, params) + log.info( + 'SESSION_TEMPLATE.DELETE (ak:{0}/{1})', + requester_access_key, + owner_access_key if owner_access_key != requester_access_key else '*', + ) + root_ctx: RootContext = request.app['_root.context'] + async with root_ctx.db.begin() as conn: + query = ( + sa.select([session_templates.c.id]) + .select_from(session_templates) + .where( + (session_templates.c.id == template_id) & + (session_templates.c.is_active) & + (session_templates.c.type == TemplateType.TASK), + ) + ) + result = await conn.scalar(query) + if not result: + raise TaskTemplateNotFound + query = ( + sa.update(session_templates) + .values(is_active=False) + .where((session_templates.c.id == template_id)) + ) + result = await conn.execute(query) + assert result.rowcount == 1 + + return web.json_response({'success': True}) + + +async def init(app: web.Application) -> None: + pass + + +async def shutdown(app: web.Application) -> None: + pass + + +def create_app(default_cors_options: CORSOptions) -> Tuple[web.Application, Iterable[WebMiddleware]]: + app = web.Application() + app.on_startup.append(init) + app.on_shutdown.append(shutdown) + app['api_versions'] = (4, 5) + app['prefix'] = 'template/session' + cors = aiohttp_cors.setup(app, defaults=default_cors_options) + cors.add(app.router.add_route('POST', '', create)) + cors.add(app.router.add_route('GET', '', list_template)) + template_resource = cors.add(app.router.add_resource(r'/{template_id}')) + cors.add(template_resource.add_route('GET', get)) + cors.add(template_resource.add_route('PUT', put)) + cors.add(template_resource.add_route('DELETE', delete)) + + return app, [] diff --git a/src/ai/backend/manager/api/stream.py b/src/ai/backend/manager/api/stream.py new file mode 100644 index 0000000000..f12e384a4e --- /dev/null +++ b/src/ai/backend/manager/api/stream.py @@ -0,0 +1,743 @@ +''' +WebSocket-based streaming kernel interaction APIs. + +NOTE: For nginx-based setups, we need to gather all websocket-based API handlers + under this "/stream/"-prefixed app. 
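+
+When a reverse proxy such as nginx sits in front of the manager, it must be configured to
+pass WebSocket upgrade requests through for every route under the "/stream/" prefix;
+plain HTTP proxying will break the handlers defined here.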
+''' + +from __future__ import annotations + +import asyncio +import base64 +from collections import defaultdict +from datetime import timedelta +import json +import logging +import secrets +import textwrap +from typing import ( + Any, + AsyncIterator, + DefaultDict, + Iterable, + List, + Mapping, + MutableMapping, + TYPE_CHECKING, + Tuple, + Union, +) +from urllib.parse import urlparse +import uuid +import weakref + +import aiohttp +import aiotools +from aiohttp import web +import aiohttp_cors +from aiotools import apartial, adefer +import attr +import trafaret as t +import zmq, zmq.asyncio + +from ai.backend.common import redis, validators as tx +from ai.backend.common.events import KernelTerminatingEvent +from ai.backend.common.logging import BraceStyleAdapter +from ai.backend.common.types import ( + AccessKey, + AgentId, + KernelId, SessionId, +) + +from ai.backend.manager.idle import AppStreamingStatus + +from ..defs import DEFAULT_ROLE +from ..models import kernels +from .auth import auth_required +from .exceptions import ( + AppNotFound, + BackendError, + InternalServerError, + InvalidAPIParameters, + SessionNotFound, + TooManySessionsMatched, +) +from .manager import READ_ALLOWED, server_status_required +from .types import CORSOptions, WebMiddleware +from .utils import check_api_params, call_non_bursty +from .wsproxy import TCPProxy +if TYPE_CHECKING: + from ..config import SharedConfig + from .context import RootContext + +log = BraceStyleAdapter(logging.getLogger(__name__)) + + +@server_status_required(READ_ALLOWED) +@auth_required +@adefer +async def stream_pty(defer, request: web.Request) -> web.StreamResponse: + root_ctx: RootContext = request.app['_root.context'] + app_ctx: PrivateContext = request.app['stream.context'] + database_ptask_group: aiotools.PersistentTaskGroup = request.app['database_ptask_group'] + session_name = request.match_info['session_name'] + access_key = request['keypair']['access_key'] + api_version = request['api_version'] + try: + compute_session = await asyncio.shield( + database_ptask_group.create_task(root_ctx.registry.get_session(session_name, access_key)), + ) + except SessionNotFound: + raise + log.info('STREAM_PTY(ak:{0}, s:{1})', access_key, session_name) + stream_key = compute_session['id'] + + await asyncio.shield(database_ptask_group.create_task( + root_ctx.registry.increment_session_usage(session_name, access_key), + )) + ws = web.WebSocketResponse(max_msg_size=root_ctx.local_config['manager']['max-wsmsg-size']) + await ws.prepare(request) + + myself = asyncio.current_task() + assert myself is not None + app_ctx.stream_pty_handlers[stream_key].add(myself) + defer(lambda: app_ctx.stream_pty_handlers[stream_key].discard(myself)) + + async def connect_streams(compute_session) -> Tuple[zmq.asyncio.Socket, zmq.asyncio.Socket]: + # TODO: refactor as custom row/table method + if compute_session.kernel_host is None: + kernel_host = urlparse(compute_session.agent_addr).hostname + else: + kernel_host = compute_session.kernel_host + stdin_addr = f'tcp://{kernel_host}:{compute_session.stdin_port}' + log.debug('stream_pty({0}): stdin: {1}', stream_key, stdin_addr) + stdin_sock = await app_ctx.zctx.socket(zmq.PUB) + stdin_sock.connect(stdin_addr) + stdin_sock.setsockopt(zmq.LINGER, 100) + stdout_addr = f'tcp://{kernel_host}:{compute_session.stdout_port}' + log.debug('stream_pty({0}): stdout: {1}', stream_key, stdout_addr) + stdout_sock = await app_ctx.zctx.socket(zmq.SUB) + stdout_sock.connect(stdout_addr) + stdout_sock.setsockopt(zmq.LINGER, 100) + 
stdout_sock.subscribe(b'') + return stdin_sock, stdout_sock + + # Wrap sockets in a list so that below coroutines can share reference changes. + socks = list(await connect_streams(compute_session)) + app_ctx.stream_stdin_socks[stream_key].add(socks[0]) + defer(lambda: app_ctx.stream_stdin_socks[stream_key].discard(socks[0])) + stream_sync = asyncio.Event() + + async def stream_stdin(): + nonlocal socks + try: + async for msg in ws: + if msg.type == aiohttp.WSMsgType.TEXT: + data = json.loads(msg.data) + if data['type'] == 'stdin': + raw_data = base64.b64decode(data['chars'].encode('ascii')) + try: + await socks[0].send_multipart([raw_data]) + except (RuntimeError, zmq.error.ZMQError): + # when socks[0] is closed, re-initiate the connection. + app_ctx.stream_stdin_socks[stream_key].discard(socks[0]) + socks[1].close() + kernel = await asyncio.shield( + database_ptask_group.create_task( + root_ctx.registry.get_session( + session_name, + access_key, + ), + ), + ) + stdin_sock, stdout_sock = await connect_streams(kernel) + socks[0] = stdin_sock + socks[1] = stdout_sock + app_ctx.stream_stdin_socks[stream_key].add(socks[0]) + await socks[0].send_multipart([raw_data]) + log.debug('stream_stdin({0}): zmq stream reset', + stream_key) + stream_sync.set() + continue + else: + await asyncio.shield( + database_ptask_group.create_task( + root_ctx.registry.increment_session_usage(session_name, access_key), + ), + ) + run_id = secrets.token_hex(8) + if data['type'] == 'resize': + code = f"%resize {data['rows']} {data['cols']}" + await root_ctx.registry.execute( + session_name, access_key, + api_version, run_id, 'query', code, {}, + flush_timeout=None, + ) + elif data['type'] == 'ping': + await root_ctx.registry.execute( + session_name, access_key, + api_version, run_id, 'query', '%ping', {}, + flush_timeout=None, + ) + elif data['type'] == 'restart': + # Close existing zmq sockets and let stream + # handlers get a new one with changed stdin/stdout + # ports. + log.debug('stream_stdin: restart requested') + if not socks[0].closed: + await asyncio.shield( + database_ptask_group.create_task( + root_ctx.registry.restart_session( + run_id, + session_name, + access_key, + ), + ), + ) + socks[0].close() + else: + log.warning( + "stream_stdin({0}): " + "duplicate kernel restart request; " + "ignoring it.", + stream_key, + ) + elif msg.type == aiohttp.WSMsgType.ERROR: + log.warning('stream_stdin({0}): connection closed ({1})', + stream_key, ws.exception()) + except asyncio.CancelledError: + # Agent or kernel is terminated. + raise + except Exception: + await root_ctx.error_monitor.capture_exception(context={'user': request['user']['uuid']}) + log.exception('stream_stdin({0}): unexpected error', stream_key) + finally: + log.debug('stream_stdin({0}): terminated', stream_key) + if not socks[0].closed: + socks[0].close() + + async def stream_stdout(): + nonlocal socks + log.debug('stream_stdout({0}): started', stream_key) + try: + while True: + try: + data = await socks[1].recv_multipart() + except (asyncio.CancelledError, zmq.error.ZMQError): + if socks[0] not in app_ctx.stream_stdin_socks: + # we are terminating + return + # connection is closed, so wait until stream_stdin() recovers it.
+ await stream_sync.wait() + stream_sync.clear() + log.debug('stream_stdout({0}): zmq stream reset', stream_key) + continue + if ws.closed: + break + await ws.send_str(json.dumps({ + 'type': 'out', + 'data': base64.b64encode(data[0]).decode('ascii'), + }, ensure_ascii=False)) + except asyncio.CancelledError: + pass + except: + await root_ctx.error_monitor.capture_exception(context={'user': request['user']['uuid']}) + log.exception('stream_stdout({0}): unexpected error', stream_key) + finally: + log.debug('stream_stdout({0}): terminated', stream_key) + socks[1].close() + + # According to aiohttp docs, reading ws must be done inside this task. + # We execute the stdout handler as another task. + stdout_task = asyncio.create_task(stream_stdout()) + try: + await stream_stdin() + except Exception: + await root_ctx.error_monitor.capture_exception(context={'user': request['user']['uuid']}) + log.exception('stream_pty({0}): unexpected error', stream_key) + finally: + stdout_task.cancel() + await stdout_task + return ws + + +@server_status_required(READ_ALLOWED) +@auth_required +@adefer +async def stream_execute(defer, request: web.Request) -> web.StreamResponse: + ''' + WebSocket-version of gateway.kernel.execute(). + ''' + root_ctx: RootContext = request.app['_root.context'] + app_ctx: PrivateContext = request.app['stream.context'] + database_ptask_group: aiotools.PersistentTaskGroup = request.app['database_ptask_group'] + rpc_ptask_group: aiotools.PersistentTaskGroup = request.app['rpc_ptask_group'] + + local_config = root_ctx.local_config + registry = root_ctx.registry + session_name = request.match_info['session_name'] + access_key = request['keypair']['access_key'] + api_version = request['api_version'] + log.info('STREAM_EXECUTE(ak:{0}, s:{1})', access_key, session_name) + try: + compute_session = await asyncio.shield( + database_ptask_group.create_task( + registry.get_session(session_name, access_key), # noqa + ), + ) + except SessionNotFound: + raise + stream_key = compute_session['id'] + + await asyncio.shield(database_ptask_group.create_task( + registry.increment_session_usage(session_name, access_key), + )) + ws = web.WebSocketResponse(max_msg_size=local_config['manager']['max-wsmsg-size']) + await ws.prepare(request) + + myself = asyncio.current_task() + assert myself is not None + app_ctx.stream_execute_handlers[stream_key].add(myself) + defer(lambda: app_ctx.stream_execute_handlers[stream_key].discard(myself)) + + # This websocket connection itself is a "run". + run_id = secrets.token_hex(8) + + try: + if ws.closed: + log.debug('STREAM_EXECUTE: client disconnected (cancelled)') + return ws + params = await ws.receive_json() + assert params.get('mode'), 'mode is missing or empty!' + mode = params['mode'] + assert mode in {'query', 'batch'}, 'mode has an invalid value.' + code = params.get('code', '') + opts = params.get('options', None) or {} + + while True: + # TODO: rewrite agent and kernel-runner for unbuffered streaming. 
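+            # The short flush_timeout (0.2s here, vs. 2.0s in the plain HTTP execute
+            # handler) keeps output latency low over the WebSocket at the cost of more
+            # polling round-trips.  A None result means the run has not reached a
+            # 'finished' or 'waiting-input' state within that window, so the loop
+            # below re-polls with mode='continue' and empty code.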
+ raw_result = await registry.execute( + session_name, access_key, + api_version, run_id, mode, code, opts, + flush_timeout=0.2) + if ws.closed: + log.debug('STREAM_EXECUTE: client disconnected (interrupted)') + await asyncio.shield(rpc_ptask_group.create_task( + registry.interrupt_session(session_name, access_key), + )) + break + if raw_result is None: + # repeat until we get finished + log.debug('STREAM_EXECUTE: none returned, continuing...') + mode = 'continue' + code = '' + opts.clear() + continue + await ws.send_json({ + 'status': raw_result['status'], + 'console': raw_result.get('console'), + 'exitCode': raw_result.get('exitCode'), + 'options': raw_result.get('options'), + 'files': raw_result.get('files'), + }) + if raw_result['status'] == 'waiting-input': + mode = 'input' + code = await ws.receive_str() + elif raw_result['status'] == 'finished': + break + else: + # repeat until we get finished + mode = 'continue' + code = '' + opts.clear() + except (json.decoder.JSONDecodeError, AssertionError) as e: + log.warning('STREAM_EXECUTE: invalid/missing parameters: {0!r}', e) + if not ws.closed: + await ws.send_json({ + 'status': 'error', + 'msg': f'Invalid API parameters: {e!r}', + }) + except BackendError as e: + log.exception('STREAM_EXECUTE: exception') + if not ws.closed: + await ws.send_json({ + 'status': 'error', + 'msg': f'BackendError: {e!r}', + }) + raise + except asyncio.CancelledError: + if not ws.closed: + await ws.send_json({ + 'status': 'server-restarting', + 'msg': 'The API server is going to restart for maintenance. ' + 'Please connect again with the same run ID.', + }) + raise + finally: + return ws + + +@server_status_required(READ_ALLOWED) +@auth_required +@check_api_params( + t.Dict({ + tx.AliasedKey(['app', 'service']): t.String, + # The port argument is only required to use secondary ports + # when the target app listens multiple TCP ports. + # Otherwise it should be omitted or set to the same value of + # the actual port number used by the app. 
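+        # For example (all values hypothetical), a request such as
+        #   GET /stream/session/mysess/httpproxy?app=jupyter&port=8090
+        # selects the service port whose container_ports entry equals 8090.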
+ tx.AliasedKey(['port'], default=None): t.Null | t.Int[1024:65535], + tx.AliasedKey(['envs'], default=None): t.Null | t.String, # stringified JSON + # e.g., '{"PASSWORD": "12345"}' + tx.AliasedKey(['arguments'], default=None): t.Null | t.String, # stringified JSON + # e.g., '{"-P": "12345"}' + # The value can be one of: + # None, str, List[str] + })) +@adefer +async def stream_proxy(defer, request: web.Request, params: Mapping[str, Any]) -> web.StreamResponse: + root_ctx: RootContext = request.app['_root.context'] + app_ctx: PrivateContext = request.app['stream.context'] + database_ptask_group: aiotools.PersistentTaskGroup = request.app['database_ptask_group'] + rpc_ptask_group: aiotools.PersistentTaskGroup = request.app['rpc_ptask_group'] + session_name: str = request.match_info['session_name'] + access_key: AccessKey = request['keypair']['access_key'] + service: str = params['app'] + myself = asyncio.current_task() + assert myself is not None + try: + kernel = await asyncio.shield(database_ptask_group.create_task( + root_ctx.registry.get_session(session_name, access_key), + )) + except (SessionNotFound, TooManySessionsMatched): + raise + stream_key = kernel['id'] + stream_id = uuid.uuid4().hex + app_ctx.stream_proxy_handlers[stream_key].add(myself) + defer(lambda: app_ctx.stream_proxy_handlers[stream_key].discard(myself)) + if kernel['kernel_host'] is None: + kernel_host = urlparse(kernel['agent_addr']).hostname + else: + kernel_host = kernel['kernel_host'] + for sport in kernel['service_ports']: + if sport['name'] == service: + if params['port']: + # using one of the primary/secondary ports of the app + try: + hport_idx = sport['container_ports'].index(params['port']) + except ValueError: + raise InvalidAPIParameters( + f"Service {service} does not open the port number {params['port']}.") + host_port = sport['host_ports'][hport_idx] + else: # using the default (primary) port of the app + if 'host_ports' not in sport: + host_port = sport['host_port'] # legacy kernels + else: + host_port = sport['host_ports'][0] + dest = (kernel_host, host_port) + break + else: + raise AppNotFound(f'{session_name}:{service}') + + log.info( + 'STREAM_WSPROXY (ak:{}, s:{}): tunneling {}:{} to {}', + access_key, session_name, + service, sport['protocol'], '{}:{}'.format(*dest), + ) + if sport['protocol'] == 'tcp': + proxy_cls = TCPProxy + elif sport['protocol'] == 'pty': + raise NotImplementedError + elif sport['protocol'] == 'http': + proxy_cls = TCPProxy + elif sport['protocol'] == 'preopen': + proxy_cls = TCPProxy + else: + raise InvalidAPIParameters( + f"Unsupported service protocol: {sport['protocol']}") + + redis_live = root_ctx.redis_live + conn_tracker_key = f"session.{kernel['id']}.active_app_connections" + conn_tracker_val = f"{kernel['id']}:{service}:{stream_id}" + + _conn_tracker_script = textwrap.dedent(''' + local now = redis.call('TIME') + now = now[1] + (now[2] / (10^6)) + redis.call('ZADD', KEYS[1], now, ARGV[1]) + ''') + + async def refresh_cb(kernel_id: str, data: bytes) -> None: + await asyncio.shield(rpc_ptask_group.create_task( + call_non_bursty( + conn_tracker_key, + apartial( + redis.execute_script, + redis_live, 'update_conn_tracker', _conn_tracker_script, + [conn_tracker_key], + [conn_tracker_val], + ), + max_bursts=128, max_idle=5000, + ), + )) + + down_cb = apartial(refresh_cb, kernel['id']) + up_cb = apartial(refresh_cb, kernel['id']) + ping_cb = apartial(refresh_cb, kernel['id']) + + kernel_id = kernel['id'] + + async def add_conn_track() -> None: + async with 
app_ctx.conn_tracker_lock: + app_ctx.active_session_ids[kernel_id] += 1 + now = await redis.execute(redis_live, lambda r: r.time()) + now = now[0] + (now[1] / (10**6)) + await redis.execute( + redis_live, + # aioredis' ZADD implementation flattens mapping in value-key order + lambda r: r.zadd(conn_tracker_key, {conn_tracker_val: now}), + ) + await root_ctx.idle_checker_host.update_app_streaming_status( + kernel_id, + AppStreamingStatus.HAS_ACTIVE_CONNECTIONS, + ) + + async def clear_conn_track() -> None: + async with app_ctx.conn_tracker_lock: + app_ctx.active_session_ids[kernel_id] -= 1 + if app_ctx.active_session_ids[kernel_id] <= 0: + del app_ctx.active_session_ids[kernel_id] + await redis.execute(redis_live, lambda r: r.zrem(conn_tracker_key, conn_tracker_val)) + remaining_count = await redis.execute( + redis_live, + lambda r: r.zcount( + conn_tracker_key, + float('-inf'), float('+inf'), + ), + ) + if remaining_count == 0: + await root_ctx.idle_checker_host.update_app_streaming_status( + kernel_id, + AppStreamingStatus.NO_ACTIVE_CONNECTIONS, + ) + + try: + await asyncio.shield(database_ptask_group.create_task( + add_conn_track(), + )) + await asyncio.shield(database_ptask_group.create_task( + root_ctx.registry.increment_session_usage(session_name, access_key), + )) + + opts: MutableMapping[str, Union[None, str, List[str]]] = {} + if params['arguments'] is not None: + opts['arguments'] = json.loads(params['arguments']) + if params['envs'] is not None: + opts['envs'] = json.loads(params['envs']) + + result = await asyncio.shield( + rpc_ptask_group.create_task( + root_ctx.registry.start_service(session_name, access_key, service, opts), + ), + ) + if result['status'] == 'failed': + raise InternalServerError( + "Failed to launch the app service", + extra_data=result['error']) + + # TODO: weakref to proxies for graceful shutdown? 
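+        # At this point the in-kernel service is confirmed to be running.  The handler now
+        # upgrades the request to a WebSocket and hands it to the selected proxy class, which
+        # pipes bytes between the client and (kernel_host, host_port) while the down/up/ping
+        # callbacks keep refreshing the Redis connection tracker so that the idle checker
+        # continues to see the app as active.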
+ ws = web.WebSocketResponse( + autoping=False, + max_msg_size=root_ctx.local_config['manager']['max-wsmsg-size'], + ) + await ws.prepare(request) + proxy = proxy_cls( + ws, dest[0], dest[1], + downstream_callback=down_cb, + upstream_callback=up_cb, + ping_callback=ping_cb, + ) + return await proxy.proxy() + except asyncio.CancelledError: + log.debug('stream_proxy({}, {}) cancelled', stream_key, service) + raise + finally: + await asyncio.shield(database_ptask_group.create_task(clear_conn_track())) + + +@server_status_required(READ_ALLOWED) +@auth_required +async def get_stream_apps(request: web.Request) -> web.Response: + session_name = request.match_info['session_name'] + access_key = request['keypair']['access_key'] + root_ctx: RootContext = request.app['_root.context'] + compute_session = await root_ctx.registry.get_session(session_name, access_key) + if compute_session['service_ports'] is None: + return web.json_response([]) + resp = [] + for item in compute_session['service_ports']: + response_dict = { + 'name': item['name'], + 'protocol': item['protocol'], + 'ports': item['container_ports'], + } + if 'url_template' in item.keys(): + response_dict['url_template'] = item['url_template'] + if 'allowed_arguments' in item.keys(): + response_dict['allowed_arguments'] = item['allowed_arguments'] + if 'allowed_envs' in item.keys(): + response_dict['allowed_envs'] = item['allowed_envs'] + resp.append(response_dict) + return web.json_response(resp) + + +async def handle_kernel_terminating( + app: web.Application, + source: AgentId, + event: KernelTerminatingEvent, +) -> None: + root_ctx: RootContext = app['_root.context'] + app_ctx: PrivateContext = app['stream.context'] + try: + kernel = await root_ctx.registry.get_kernel( + event.kernel_id, + (kernels.c.cluster_role, kernels.c.status), + allow_stale=True, + ) + except SessionNotFound: + return + if kernel['cluster_role'] == DEFAULT_ROLE: + stream_key = kernel['id'] + cancelled_tasks = [] + for sock in app_ctx.stream_stdin_socks[stream_key]: + sock.close() + for handler in list(app_ctx.stream_pty_handlers.get(stream_key, [])): + handler.cancel() + cancelled_tasks.append(handler) + for handler in list(app_ctx.stream_execute_handlers.get(stream_key, [])): + handler.cancel() + cancelled_tasks.append(handler) + for handler in list(app_ctx.stream_proxy_handlers.get(stream_key, [])): + handler.cancel() + cancelled_tasks.append(handler) + await asyncio.gather(*cancelled_tasks, return_exceptions=True) + # TODO: reconnect if restarting? 
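+
+
+# The connection tracking shared by stream_proxy() above and stream_conn_tracker_gc() below
+# boils down to one Redis sorted set per session: members are "<kernel_id>:<service>:<stream_id>"
+# strings and scores are last-seen timestamps, so refreshing a connection is a ZADD, expiring
+# idle ones is a ZREMRANGEBYSCORE, and counting live ones is a ZCOUNT.  The sketch below shows
+# that pattern in isolation with plain redis-py (redis.asyncio), whereas the handlers here go
+# through the ai.backend.common.redis wrapper.  It is not called anywhere in this module; the
+# key name, member value, and URL are hypothetical placeholders, the 5-minute timeout mirrors
+# the GC's default, and unlike the real code it uses the local clock instead of the Redis TIME
+# command.
+async def _conn_tracker_pattern_demo(redis_url: str = 'redis://localhost:6379/0') -> int:
+    import time
+    from redis.asyncio import Redis
+    r = Redis.from_url(redis_url)
+    try:
+        key = 'demo.session.active_app_connections'
+        now = time.time()
+        # a proxy handler would do this on every downstream/upstream/ping event
+        await r.zadd(key, {'kernel0:jupyter:stream0': now})
+        # the periodic GC drops members that have not been refreshed within the timeout
+        await r.zremrangebyscore(key, float('-inf'), now - 300.0)
+        # whatever remains counts as active app connections
+        return await r.zcount(key, float('-inf'), float('+inf'))
+    finally:
+        await r.close()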
+ + +async def stream_conn_tracker_gc(root_ctx: RootContext, app_ctx: PrivateContext) -> None: + redis_live = root_ctx.redis_live + shared_config: SharedConfig = root_ctx.shared_config + try: + while True: + no_packet_timeout: timedelta = tx.TimeDuration().check( + await shared_config.etcd.get('config/idle/app-streaming-packet-timeout') or '5m', + ) + async with app_ctx.conn_tracker_lock: + now = await redis.execute(redis_live, lambda r: r.time()) + now = now[0] + (now[1] / (10**6)) + for session_id in app_ctx.active_session_ids.keys(): + conn_tracker_key = f"session.{session_id}.active_app_connections" + prev_remaining_count = await redis.execute( + redis_live, + lambda r: r.zcount(conn_tracker_key, float('-inf'), float('+inf')), + ) + removed_count = await redis.execute( + redis_live, + lambda r: r.zremrangebyscore( + conn_tracker_key, float('-inf'), now - no_packet_timeout.total_seconds(), + ), + ) + remaining_count = await redis.execute( + redis_live, + lambda r: r.zcount(conn_tracker_key, float('-inf'), float('+inf')), + ) + log.debug(f"conn_tracker: gc {session_id} " + f"removed/remaining = {removed_count}/{remaining_count}") + if prev_remaining_count > 0 and remaining_count == 0: + await root_ctx.idle_checker_host.update_app_streaming_status( + session_id, + AppStreamingStatus.NO_ACTIVE_CONNECTIONS, + ) + await asyncio.sleep(10) + except asyncio.CancelledError: + pass + + +@attr.s(slots=True, auto_attribs=True, init=False) +class PrivateContext: + stream_pty_handlers: DefaultDict[KernelId, weakref.WeakSet[asyncio.Task]] + stream_execute_handlers: DefaultDict[KernelId, weakref.WeakSet[asyncio.Task]] + stream_proxy_handlers: DefaultDict[KernelId, weakref.WeakSet[asyncio.Task]] + stream_stdin_socks: DefaultDict[KernelId, weakref.WeakSet[zmq.asyncio.Socket]] + zctx: zmq.asyncio.Context + conn_tracker_lock: asyncio.Lock + conn_tracker_gc_task: asyncio.Task + active_session_ids: DefaultDict[SessionId, int] + + +async def stream_app_ctx(app: web.Application) -> AsyncIterator[None]: + root_ctx: RootContext = app['_root.context'] + app_ctx: PrivateContext = app['stream.context'] + + app_ctx.stream_pty_handlers = defaultdict(weakref.WeakSet) + app_ctx.stream_execute_handlers = defaultdict(weakref.WeakSet) + app_ctx.stream_proxy_handlers = defaultdict(weakref.WeakSet) + app_ctx.stream_stdin_socks = defaultdict(weakref.WeakSet) + app_ctx.zctx = zmq.asyncio.Context() + app_ctx.conn_tracker_lock = asyncio.Lock() + app_ctx.active_session_ids = defaultdict(int) # multiset[int] + app_ctx.conn_tracker_gc_task = asyncio.create_task(stream_conn_tracker_gc(root_ctx, app_ctx)) + + root_ctx.event_dispatcher.subscribe(KernelTerminatingEvent, app, handle_kernel_terminating) + + yield + + # The shutdown handler below is called before this cleanup. 
+ app_ctx.zctx.term() + + +async def stream_shutdown(app: web.Application) -> None: + database_ptask_group: aiotools.PersistentTaskGroup = app['database_ptask_group'] + rpc_ptask_group: aiotools.PersistentTaskGroup = app['rpc_ptask_group'] + await database_ptask_group.shutdown() + await rpc_ptask_group.shutdown() + cancelled_tasks: List[asyncio.Task] = [] + app_ctx: PrivateContext = app['stream.context'] + app_ctx.conn_tracker_gc_task.cancel() + cancelled_tasks.append(app_ctx.conn_tracker_gc_task) + for per_kernel_handlers in app_ctx.stream_pty_handlers.values(): + for handler in list(per_kernel_handlers): + if not handler.done(): + handler.cancel() + cancelled_tasks.append(handler) + for per_kernel_handlers in app_ctx.stream_execute_handlers.values(): + for handler in list(per_kernel_handlers): + if not handler.done(): + handler.cancel() + cancelled_tasks.append(handler) + for per_kernel_handlers in app_ctx.stream_proxy_handlers.values(): + for handler in list(per_kernel_handlers): + if not handler.done(): + handler.cancel() + cancelled_tasks.append(handler) + await asyncio.gather(*cancelled_tasks, return_exceptions=True) + + +def create_app(default_cors_options: CORSOptions) -> Tuple[web.Application, Iterable[WebMiddleware]]: + app = web.Application() + app.cleanup_ctx.append(stream_app_ctx) + app.on_shutdown.append(stream_shutdown) + app['prefix'] = 'stream' + app['api_versions'] = (2, 3, 4) + app['stream.context'] = PrivateContext() + app["database_ptask_group"] = aiotools.PersistentTaskGroup() + app["rpc_ptask_group"] = aiotools.PersistentTaskGroup() + cors = aiohttp_cors.setup(app, defaults=default_cors_options) + add_route = app.router.add_route + cors.add(add_route('GET', r'/session/{session_name}/pty', stream_pty)) + cors.add(add_route('GET', r'/session/{session_name}/execute', stream_execute)) + cors.add(add_route('GET', r'/session/{session_name}/apps', get_stream_apps)) + # internally both tcp/http proxies use websockets as API/agent-level transports, + # and thus they have the same implementation here. 
+ cors.add(add_route('GET', r'/session/{session_name}/httpproxy', stream_proxy)) + cors.add(add_route('GET', r'/session/{session_name}/tcpproxy', stream_proxy)) + return app, [] diff --git a/src/ai/backend/manager/api/types.py b/src/ai/backend/manager/api/types.py new file mode 100644 index 0000000000..a5c08a4310 --- /dev/null +++ b/src/ai/backend/manager/api/types.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from typing import ( + Awaitable, + Callable, + Iterable, + AsyncContextManager, + Mapping, + Tuple, + TYPE_CHECKING, +) +from typing_extensions import TypeAlias + +from aiohttp import web +import aiohttp_cors + +if TYPE_CHECKING: + from .context import RootContext + + +WebRequestHandler: TypeAlias = Callable[ + [web.Request], + Awaitable[web.StreamResponse], +] +WebMiddleware: TypeAlias = Callable[ + [web.Request, WebRequestHandler], + Awaitable[web.StreamResponse], +] + +CORSOptions: TypeAlias = Mapping[str, aiohttp_cors.ResourceOptions] +AppCreator: TypeAlias = Callable[ + [CORSOptions], + Tuple[web.Application, Iterable[WebMiddleware]], +] + +CleanupContext: TypeAlias = Callable[['RootContext'], AsyncContextManager[None]] diff --git a/src/ai/backend/manager/api/userconfig.py b/src/ai/backend/manager/api/userconfig.py new file mode 100644 index 0000000000..513afe59e1 --- /dev/null +++ b/src/ai/backend/manager/api/userconfig.py @@ -0,0 +1,241 @@ +import logging +import re +from typing import Any, TYPE_CHECKING, Tuple + +from aiohttp import web +import aiohttp_cors +import trafaret as t + +from ai.backend.common import msgpack +from ai.backend.common.logging import BraceStyleAdapter + +from ..models import ( + keypairs, + vfolders, + query_accessible_vfolders, + query_bootstrap_script, + query_owned_dotfiles, + verify_dotfile_name, + MAXIMUM_DOTFILE_SIZE, +) +from .auth import auth_required +from .exceptions import ( + InvalidAPIParameters, DotfileCreationFailed, + DotfileNotFound, DotfileAlreadyExists, +) +from .manager import READ_ALLOWED, server_status_required +from .types import CORSOptions, Iterable, WebMiddleware +from .utils import check_api_params, get_access_key_scopes + +if TYPE_CHECKING: + from .context import RootContext + +log = BraceStyleAdapter(logging.getLogger(__name__)) + + +@server_status_required(READ_ALLOWED) +@auth_required +@check_api_params(t.Dict( + { + t.Key('data'): t.String(max_length=MAXIMUM_DOTFILE_SIZE), + t.Key('path'): t.String, + t.Key('permission'): t.Regexp(r'^[0-7]{3}$', re.ASCII), + t.Key('owner_access_key', default=None): t.Null | t.String, + }, +)) +async def create(request: web.Request, params: Any) -> web.Response: + requester_access_key, owner_access_key = await get_access_key_scopes(request, params) + log.info( + 'USERCONFIG.CREATE (ak:{0}/{1})', requester_access_key, + owner_access_key if owner_access_key != requester_access_key else '*', + ) + root_ctx: RootContext = request.app['_root.context'] + user_uuid = request['user']['uuid'] + async with root_ctx.db.begin() as conn: + path: str = params['path'] + dotfiles, leftover_space = await query_owned_dotfiles(conn, owner_access_key) + if leftover_space == 0: + raise DotfileCreationFailed('No leftover space for dotfile storage') + if len(dotfiles) == 100: + raise DotfileCreationFailed('Dotfile creation limit reached') + if not verify_dotfile_name(path): + raise InvalidAPIParameters('dotfile path is reserved for internal operations.') + duplicate_vfolder = \ + await query_accessible_vfolders(conn, user_uuid, extra_vf_conds=(vfolders.c.name == path)) + if len(duplicate_vfolder) 
> 0: + raise InvalidAPIParameters('dotfile path conflicts with your dot-prefixed vFolder') + duplicate = [x for x in dotfiles if x['path'] == path] + if len(duplicate) > 0: + raise DotfileAlreadyExists + new_dotfiles = list(dotfiles) + new_dotfiles.append({'path': path, 'perm': params['permission'], 'data': params['data']}) + dotfile_packed = msgpack.packb(new_dotfiles) + if len(dotfile_packed) > MAXIMUM_DOTFILE_SIZE: + raise DotfileCreationFailed('No leftover space for dotfile storage') + + query = ( + keypairs.update() + .values(dotfiles=dotfile_packed) + .where(keypairs.c.access_key == owner_access_key) + ) + await conn.execute(query) + return web.json_response({}) + + +@auth_required +@server_status_required(READ_ALLOWED) +@check_api_params(t.Dict({ + t.Key('path', default=None): t.Null | t.String, + t.Key('owner_access_key', default=None): t.Null | t.String, +})) +async def list_or_get(request: web.Request, params: Any) -> web.Response: + resp = [] + root_ctx: RootContext = request.app['_root.context'] + access_key = request['keypair']['access_key'] + + requester_access_key, owner_access_key = await get_access_key_scopes(request, params) + log.info( + 'USERCONFIG.LIST_OR_GET (ak:{0}/{1})', requester_access_key, + owner_access_key if owner_access_key != requester_access_key else '*', + ) + async with root_ctx.db.begin() as conn: + if params['path']: + dotfiles, _ = await query_owned_dotfiles(conn, owner_access_key) + for dotfile in dotfiles: + if dotfile['path'] == params['path']: + return web.json_response(dotfile) + raise DotfileNotFound + else: + dotfiles, _ = await query_owned_dotfiles(conn, access_key) + for entry in dotfiles: + resp.append({ + 'path': entry['path'], + 'permission': entry['perm'], + 'data': entry['data'], + }) + return web.json_response(resp) + + +@server_status_required(READ_ALLOWED) +@auth_required +@check_api_params(t.Dict( + { + t.Key('data'): t.String(max_length=MAXIMUM_DOTFILE_SIZE), + t.Key('path'): t.String, + t.Key('permission'): t.Regexp(r'^[0-7]{3}$', re.ASCII), + t.Key('owner_access_key', default=None): t.Null | t.String, + }, +)) +async def update(request: web.Request, params: Any) -> web.Response: + requester_access_key, owner_access_key = await get_access_key_scopes(request, params) + log.info( + 'USERCONFIG.CREATE (ak:{0}/{1})', requester_access_key, + owner_access_key if owner_access_key != requester_access_key else '*', + ) + root_ctx: RootContext = request.app['_root.context'] + async with root_ctx.db.begin() as conn: + path: str = params['path'] + dotfiles, _ = await query_owned_dotfiles(conn, owner_access_key) + new_dotfiles = [x for x in dotfiles if x['path'] != path] + if len(new_dotfiles) == len(dotfiles): + raise DotfileNotFound + + new_dotfiles.append({'path': path, 'perm': params['permission'], 'data': params['data']}) + dotfile_packed = msgpack.packb(new_dotfiles) + if len(dotfile_packed) > MAXIMUM_DOTFILE_SIZE: + raise DotfileCreationFailed('No leftover space for dotfile storage') + + query = ( + keypairs.update() + .values(dotfiles=dotfile_packed) + .where(keypairs.c.access_key == owner_access_key) + ) + await conn.execute(query) + return web.json_response({}) + + +@auth_required +@server_status_required(READ_ALLOWED) +@check_api_params( + t.Dict({ + t.Key('path'): t.String, + t.Key('owner_access_key', default=None): t.Null | t.String, + }), +) +async def delete(request: web.Request, params: Any) -> web.Response: + requester_access_key, owner_access_key = await get_access_key_scopes(request, params) + log.info( + 'USERCONFIG.DELETE 
(ak:{0}/{1})', requester_access_key, + owner_access_key if owner_access_key != requester_access_key else '*', + ) + root_ctx: RootContext = request.app['_root.context'] + path = params['path'] + async with root_ctx.db.begin() as conn: + dotfiles, _ = await query_owned_dotfiles(conn, owner_access_key) + new_dotfiles = [x for x in dotfiles if x['path'] != path] + if len(new_dotfiles) == len(dotfiles): + raise DotfileNotFound + dotfile_packed = msgpack.packb(new_dotfiles) + query = (keypairs.update() + .values(dotfiles=dotfile_packed) + .where(keypairs.c.access_key == owner_access_key)) + await conn.execute(query) + return web.json_response({'success': True}) + + +@server_status_required(READ_ALLOWED) +@auth_required +@check_api_params(t.Dict( + { + t.Key('script'): t.String(allow_blank=True, max_length=MAXIMUM_DOTFILE_SIZE), + }, +)) +async def update_bootstrap_script(request: web.Request, params: Any) -> web.Response: + access_key = request['keypair']['access_key'] + log.info('UPDATE_BOOTSTRAP_SCRIPT (ak:{0})', access_key) + root_ctx: RootContext = request.app['_root.context'] + async with root_ctx.db.begin() as conn: + script = params.get('script', '').strip() + if len(script) > MAXIMUM_DOTFILE_SIZE: + raise DotfileCreationFailed('Maximum bootstrap script length reached') + query = (keypairs.update() + .values(bootstrap_script=script) + .where(keypairs.c.access_key == access_key)) + await conn.execute(query) + return web.json_response({}) + + +@auth_required +@server_status_required(READ_ALLOWED) +async def get_bootstrap_script(request: web.Request) -> web.Response: + access_key = request['keypair']['access_key'] + log.info('USERCONFIG.GET_BOOTSTRAP_SCRIPT (ak:{0})', access_key) + root_ctx: RootContext = request.app['_root.context'] + async with root_ctx.db.begin() as conn: + script, _ = await query_bootstrap_script(conn, access_key) + return web.json_response(script) + + +async def init(app: web.Application) -> None: + pass + + +async def shutdown(app: web.Application) -> None: + pass + + +def create_app(default_cors_options: CORSOptions) -> Tuple[web.Application, Iterable[WebMiddleware]]: + app = web.Application() + app.on_startup.append(init) + app.on_shutdown.append(shutdown) + app['api_versions'] = (4, 5) + app['prefix'] = 'user-config' + cors = aiohttp_cors.setup(app, defaults=default_cors_options) + cors.add(app.router.add_route('POST', '/dotfiles', create)) + cors.add(app.router.add_route('GET', '/dotfiles', list_or_get)) + cors.add(app.router.add_route('PATCH', '/dotfiles', update)) + cors.add(app.router.add_route('DELETE', '/dotfiles', delete)) + cors.add(app.router.add_route('POST', '/bootstrap-script', update_bootstrap_script)) + cors.add(app.router.add_route('GET', '/bootstrap-script', get_bootstrap_script)) + + return app, [] diff --git a/src/ai/backend/manager/api/utils.py b/src/ai/backend/manager/api/utils.py new file mode 100644 index 0000000000..ea91412efd --- /dev/null +++ b/src/ai/backend/manager/api/utils.py @@ -0,0 +1,372 @@ +import asyncio +from collections import defaultdict +import functools +import io +import inspect +import itertools +import json +import logging +import numbers +import re +import time +import traceback +from typing import ( + Any, + Awaitable, + Callable, + Hashable, + Mapping, + MutableMapping, + Optional, TYPE_CHECKING, + Tuple, + Union, +) +import uuid + +from aiohttp import web +import trafaret as t +import sqlalchemy as sa +import yaml + +from ai.backend.common.logging import BraceStyleAdapter +from ai.backend.common.types import AccessKey + 
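+# NOTE: The decorators and helpers in this module are meant to be stacked on
+# aiohttp handlers.  A rough, illustrative usage sketch (the handler and key
+# names here are made up for the example):
+#
+#   @server_status_required(READ_ALLOWED)
+#   @auth_required
+#   @check_api_params(t.Dict({
+#       t.Key('owner_access_key', default=None): t.Null | t.String,
+#   }))
+#   async def my_handler(request: web.Request, params: Any) -> web.Response:
+#       requester_ak, owner_ak = await get_access_key_scopes(request, params)
+#       return web.json_response({})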
+from ..models import keypairs, users, UserRole +from .exceptions import InvalidAPIParameters, GenericForbidden, QueryNotImplemented + +if TYPE_CHECKING: + from .context import RootContext + +log = BraceStyleAdapter(logging.getLogger(__name__)) + +_rx_sitepkg_path = re.compile(r'^.+/site-packages/') + + +def method_placeholder(orig_method): + async def _handler(request): + raise web.HTTPMethodNotAllowed(request.method, [orig_method]) + + return _handler + + +async def get_access_key_scopes(request: web.Request, params: Any = None) -> Tuple[AccessKey, AccessKey]: + if not request['is_authorized']: + raise GenericForbidden('Only authorized requests may have access key scopes.') + root_ctx: RootContext = request.app['_root.context'] + requester_access_key: AccessKey = request['keypair']['access_key'] + if ( + params is not None and + (owner_access_key := params.get('owner_access_key', None)) is not None and + owner_access_key != requester_access_key + ): + async with root_ctx.db.begin_readonly() as conn: + query = ( + sa.select([users.c.domain_name, users.c.role]) + .select_from( + sa.join(keypairs, users, + keypairs.c.user == users.c.uuid)) + .where(keypairs.c.access_key == owner_access_key) + ) + result = await conn.execute(query) + row = result.first() + if row is None: + raise InvalidAPIParameters("Unknown owner access key") + owner_domain = row['domain_name'] + owner_role = row['role'] + if request['is_superadmin']: + pass + elif request['is_admin']: + if request['user']['domain_name'] != owner_domain: + raise GenericForbidden( + "Domain-admins can perform operations on behalf of " + "other users in the same domain only.", + ) + if owner_role == UserRole.SUPERADMIN: + raise GenericForbidden( + "Domain-admins cannot perform operations on behalf of super-admins.", + ) + pass + else: + raise GenericForbidden( + "Only admins can perform operations on behalf of other users.", + ) + return requester_access_key, owner_access_key + return requester_access_key, requester_access_key + + +async def get_user_scopes( + request: web.Request, + params: Optional[dict[str, Any]] = None, +) -> tuple[uuid.UUID, UserRole]: + root_ctx: RootContext = request.app['_root.context'] + if not request['is_authorized']: + raise GenericForbidden("Only authorized requests may have user scopes.") + if ( + params is not None and + (owner_user_email := params.get('owner_user_email')) is not None + ): + if not request['is_superadmin']: + raise InvalidAPIParameters("Only superadmins may have user scopes.") + async with root_ctx.db.begin_readonly() as conn: + user_query = ( + sa.select([users.c.uuid, users.c.role, users.c.domain_name]) + .select_from(users) + .where( + (users.c.email == owner_user_email), + ) + ) + result = await conn.execute(user_query) + row = result.first() + if row is None: + raise InvalidAPIParameters("Cannot delegate an unknown user") + owner_user_uuid = row['uuid'] + owner_user_role = row['role'] + owner_user_domain = row['domain_name'] + if request['is_superadmin']: + pass + elif request['is_admin']: + if request['user']['domain_name'] != owner_user_domain: + raise GenericForbidden( + "Domain-admins can perform operations on behalf of " + "other users in the same domain only.", + ) + if owner_user_role == UserRole.SUPERADMIN: + raise GenericForbidden( + "Domain-admins cannot perform operations on behalf of super-admins.", + ) + pass + else: + raise GenericForbidden( + "Only admins can perform operations on behalf of other users.", + ) + else: + owner_user_uuid = request['user']['uuid'] + 
owner_user_role = request['user']['role'] + return owner_user_uuid, owner_user_role + + +def check_api_params( + checker: t.Trafaret, + loads: Callable[[str], Any] = None, + query_param_checker: t.Trafaret = None, +) -> Any: + # FIXME: replace ... with [web.Request, Any...] in the future mypy + def wrap(handler: Callable[..., Awaitable[web.Response]]): + + @functools.wraps(handler) + async def wrapped(request: web.Request, *args, **kwargs) -> web.Response: + orig_params: Any + body: str = '' + try: + body_exists = request.can_read_body + if body_exists: + body = await request.text() + if request.content_type == 'text/yaml': + orig_params = yaml.load(body, Loader=yaml.BaseLoader) + else: + orig_params = (loads or json.loads)(body) + else: + orig_params = dict(request.query) + stripped_params = orig_params.copy() + log.debug('stripped raw params: {}', mask_sensitive_keys(stripped_params)) + checked_params = checker.check(stripped_params) + if body_exists and query_param_checker: + query_params = query_param_checker.check(request.query) + kwargs['query'] = query_params + except (json.decoder.JSONDecodeError, yaml.YAMLError, yaml.MarkedYAMLError): + raise InvalidAPIParameters('Malformed body') + except t.DataError as e: + raise InvalidAPIParameters('Input validation error', + extra_data=e.as_dict()) + return await handler(request, checked_params, *args, **kwargs) + + return wrapped + + return wrap + + +_danger_words = ['password', 'passwd', 'secret'] + + +def mask_sensitive_keys(data: Mapping[str, Any]) -> Mapping[str, Any]: + """ + Returns a new cloned mapping by masking the values of + sensitive keys with "***" from the given mapping. + """ + sanitized = dict() + for k, v in data.items(): + if any((w in k.lower()) for w in _danger_words): + sanitized[k] = '***' + else: + sanitized[k] = v + return sanitized + + +def trim_text(value: str, maxlen: int) -> str: + if len(value) <= maxlen: + return value + value = value[:maxlen - 3] + '...' 
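+    # e.g. trim_text('abcdefgh', 5) -> 'ab...' (the ellipsis counts toward maxlen)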
+    return value
+
+
+class _Infinity(numbers.Number):
+
+    def __lt__(self, o):
+        return False
+
+    def __le__(self, o):
+        return False
+
+    def __gt__(self, o):
+        return True
+
+    def __ge__(self, o):
+        return True
+
+    def __float__(self):
+        return float('inf')
+
+    def __int__(self):
+        return 0xffff_ffff_ffff_ffff  # a practical 64-bit maximum
+
+    def __hash__(self):
+        return hash(float('inf'))
+
+
+numbers.Number.register(_Infinity)
+Infinity = _Infinity()
+
+
+def prettify_traceback(exc):
+    # Make a compact stack trace string
+    with io.StringIO() as buf:
+        while exc is not None:
+            print(f'Exception: {exc!r}', file=buf)
+            if exc.__traceback__ is None:
+                print('  (no traceback available)', file=buf)
+            else:
+                for frame in traceback.extract_tb(exc.__traceback__):
+                    short_path = _rx_sitepkg_path.sub('/', frame.filename)
+                    print(f'  {short_path}:{frame.lineno} ({frame.name})', file=buf)
+            exc = exc.__context__
+        return f'Traceback:\n{buf.getvalue()}'
+
+
+def catch_unexpected(log, reraise_cancellation: bool = True, raven=None):
+    def _wrap(func):
+
+        @functools.wraps(func)
+        async def _wrapped(*args, **kwargs):
+            try:
+                return await func(*args, **kwargs)
+            except asyncio.CancelledError:
+                if reraise_cancellation:
+                    raise
+            except Exception:
+                if raven:
+                    raven.captureException()
+                log.exception('unexpected error!')
+                raise
+
+        return _wrapped
+
+    return _wrap
+
+
+def set_handler_attr(func, key, value):
+    attrs = getattr(func, '_backend_attrs', None)
+    if attrs is None:
+        attrs = {}
+    attrs[key] = value
+    setattr(func, '_backend_attrs', attrs)
+
+
+def get_handler_attr(request, key, default=None):
+    # When used in the aiohttp server-side code, we should use
+    # request.match_info.handler instead of the handler passed to the middleware
+    # functions, because aiohttp wraps the original handler with functools.partial
+    # multiple times to implement its internal middleware processing.
+    attrs = getattr(request.match_info.handler, '_backend_attrs', None)
+    if attrs is not None:
+        return attrs.get(key, default)
+    return default
+
+
+async def not_impl_stub(request) -> web.Response:
+    raise QueryNotImplemented
+
+
+def chunked(iterable, n):
+    it = iter(iterable)
+    while True:
+        chunk = tuple(itertools.islice(it, n))
+        if not chunk:
+            return
+        yield chunk
+
+
+_burst_last_call: float = 0.0
+_burst_times: MutableMapping[Hashable, float] = dict()
+_burst_counts: MutableMapping[Hashable, int] = defaultdict(int)
+
+
+async def call_non_bursty(key: Hashable, coro: Callable[[], Any], *,
+                          max_bursts: int = 64,
+                          max_idle: Union[int, float] = 100.0):
+    '''
+    Execute a coroutine once upon max_bursts bursty invocations or max_idle
+    milliseconds after bursts smaller than max_bursts.
+    '''
+    global _burst_last_call, _burst_times, _burst_counts
+    if inspect.iscoroutine(coro):
+        # Coroutine objects would never be awaited when this function throttles
+        # the invocation, which makes the garbage collector emit a bogus
+        # "coroutine was never awaited" warning under the asyncio debug mode.
+ raise TypeError('You must pass coroutine function, not coroutine object.') + now = time.monotonic() + + if now - _burst_last_call > 3.0: + # garbage-collect keys + cleaned_keys = [] + for k, tick in _burst_times.items(): + if now - tick > (max_idle / 1e3): + cleaned_keys.append(k) + for k in cleaned_keys: + del _burst_times[k] + _burst_counts.pop(k, None) + + last_called = _burst_times.get(key, 0) + _burst_times[key] = now + _burst_last_call = now + invoke = False + + if now - last_called > (max_idle / 1e3): + invoke = True + _burst_counts.pop(key, None) + else: + _burst_counts[key] += 1 + if _burst_counts[key] >= max_bursts: + invoke = True + del _burst_counts[key] + + if invoke: + if inspect.iscoroutinefunction(coro): + return await coro() + else: + return coro() + + +class Singleton(type): + _instances: MutableMapping[Any, Any] = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] + + +class Undefined(metaclass=Singleton): + pass + + +undefined = Undefined() diff --git a/src/ai/backend/manager/api/vfolder.py b/src/ai/backend/manager/api/vfolder.py new file mode 100644 index 0000000000..c128dea0a9 --- /dev/null +++ b/src/ai/backend/manager/api/vfolder.py @@ -0,0 +1,2380 @@ +from __future__ import annotations + +import asyncio +import functools +import json +import logging +import math +import stat +import uuid +from datetime import datetime +from pathlib import Path +from typing import ( + Any, + Awaitable, + Callable, + Dict, + List, + Mapping, + MutableMapping, + Sequence, + Set, + TYPE_CHECKING, + Tuple, +) + +import aiohttp +from aiohttp import web +import aiohttp_cors +import sqlalchemy as sa +import trafaret as t + +from ai.backend.common import validators as tx +from ai.backend.common.bgtask import ProgressReporter +from ai.backend.common.logging import BraceStyleAdapter + +from ..models import ( + agents, + kernels, + users, groups, keypairs, + vfolders, vfolder_invitations, vfolder_permissions, + AgentStatus, + KernelStatus, + VFolderInvitationState, + VFolderOwnershipType, + VFolderPermission, + VFolderPermissionValidator, + VFolderUsageMode, + UserRole, + query_accessible_vfolders, + query_owned_dotfiles, + get_allowed_vfolder_hosts_by_group, + get_allowed_vfolder_hosts_by_user, + verify_vfolder_name, +) +from .auth import admin_required, auth_required, superadmin_required +from .exceptions import ( + VFolderCreationFailed, VFolderNotFound, VFolderAlreadyExists, + GenericForbidden, ObjectNotFound, InvalidAPIParameters, ServerMisconfiguredError, + BackendAgentError, InternalServerError, GroupNotFound, +) +from .manager import ( + READ_ALLOWED, ALL_ALLOWED, + server_status_required, +) +from .resource import get_watcher_info +from .utils import check_api_params, get_user_scopes +if TYPE_CHECKING: + from .context import RootContext + +log = BraceStyleAdapter(logging.getLogger(__name__)) + +VFolderRow = Mapping[str, Any] + + +def vfolder_permission_required(perm: VFolderPermission): + """ + Checks if the target vfolder exists and is either: + - owned by the current access key, or + - allowed accesses by the access key under the specified permission. + + The decorated handler should accept an extra argument + which contains a dict object describing the matched VirtualFolder table row. + """ + + # FIXME: replace ... with [web.Request, VFolderRow, Any...] 
in the future mypy + def _wrapper(handler: Callable[..., Awaitable[web.Response]]): + + @functools.wraps(handler) + async def _wrapped(request: web.Request, *args, **kwargs) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + domain_name = request['user']['domain_name'] + user_role = request['user']['role'] + user_uuid = request['user']['uuid'] + folder_name = request.match_info['name'] + allowed_vfolder_types = await root_ctx.shared_config.get_vfolder_types() + vf_user_cond = None + vf_group_cond = None + if perm == VFolderPermission.READ_ONLY: + # if READ_ONLY is requested, any permission accepts. + invited_perm_cond = vfolder_permissions.c.permission.in_([ + VFolderPermission.READ_ONLY, + VFolderPermission.READ_WRITE, + VFolderPermission.RW_DELETE, + ]) + if not request['is_admin']: + vf_group_cond = vfolders.c.permission.in_([ + VFolderPermission.READ_ONLY, + VFolderPermission.READ_WRITE, + VFolderPermission.RW_DELETE, + ]) + elif perm == VFolderPermission.READ_WRITE: + invited_perm_cond = vfolder_permissions.c.permission.in_([ + VFolderPermission.READ_WRITE, + VFolderPermission.RW_DELETE, + ]) + if not request['is_admin']: + vf_group_cond = vfolders.c.permission.in_([ + VFolderPermission.READ_WRITE, + VFolderPermission.RW_DELETE, + ]) + elif perm == VFolderPermission.RW_DELETE: + # If RW_DELETE is requested, only RW_DELETE accepts. + invited_perm_cond = ( + vfolder_permissions.c.permission == VFolderPermission.RW_DELETE + ) + if not request['is_admin']: + vf_group_cond = ( + vfolders.c.permission == VFolderPermission.RW_DELETE + ) + else: + # Otherwise, just compare it as-is (for future compatibility). + invited_perm_cond = (vfolder_permissions.c.permission == perm) + if not request['is_admin']: + vf_group_cond = (vfolders.c.permission == perm) + async with root_ctx.db.begin() as conn: + entries = await query_accessible_vfolders( + conn, user_uuid, + user_role=user_role, domain_name=domain_name, + allowed_vfolder_types=allowed_vfolder_types, + extra_vf_conds=(vfolders.c.name == folder_name), + extra_vfperm_conds=invited_perm_cond, + extra_vf_user_conds=vf_user_cond, + extra_vf_group_conds=vf_group_cond, + ) + if len(entries) == 0: + raise VFolderNotFound( + 'Your operation may be permission denied.') + return await handler(request, entries[0], *args, **kwargs) + + return _wrapped + + return _wrapper + + +# FIXME: replace ... with [web.Request, VFolderRow, Any...] in the future mypy +def vfolder_check_exists(handler: Callable[..., Awaitable[web.Response]]): + """ + Checks if the target vfolder exists and is owned by the current user. + + The decorated handler should accept an extra "row" argument + which contains the matched VirtualFolder table row. 
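+
+    (The lookup below uses an outer join with vfolder_permissions, so vfolders
+    shared with the current user are matched as well as owned ones.)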
+ """ + + @functools.wraps(handler) + async def _wrapped(request: web.Request, *args, **kwargs) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + user_uuid = request['user']['uuid'] + folder_name = request.match_info['name'] + async with root_ctx.db.begin() as conn: + j = sa.join( + vfolders, vfolder_permissions, + vfolders.c.id == vfolder_permissions.c.vfolder, isouter=True) + query = ( + sa.select('*') + .select_from(j) + .where(((vfolders.c.user == user_uuid) | + (vfolder_permissions.c.user == user_uuid)) & + (vfolders.c.name == folder_name))) + try: + result = await conn.execute(query) + except sa.exc.DataError: + raise InvalidAPIParameters + row = result.first() + if row is None: + raise VFolderNotFound() + return await handler(request, row, *args, **kwargs) + + return _wrapped + + +@auth_required +@server_status_required(ALL_ALLOWED) +@check_api_params( + t.Dict({ + t.Key('name'): tx.Slug(allow_dot=True), + t.Key('host', default=None) >> 'folder_host': t.String | t.Null, + t.Key('usage_mode', default='general'): tx.Enum(VFolderUsageMode) | t.Null, + t.Key('permission', default='rw'): tx.Enum(VFolderPermission) | t.Null, + tx.AliasedKey(['unmanaged_path', 'unmanagedPath'], default=None): t.String | t.Null, + tx.AliasedKey(['group', 'groupId', 'group_id'], default=None): tx.UUID | t.String | t.Null, + t.Key('quota', default=None): tx.BinarySize | t.Null, + t.Key('cloneable', default=False): t.Bool, + }), +) +async def create(request: web.Request, params: Any) -> web.Response: + resp: Dict[str, Any] = {} + root_ctx: RootContext = request.app['_root.context'] + access_key = request['keypair']['access_key'] + user_role = request['user']['role'] + user_uuid = request['user']['uuid'] + resource_policy = request['keypair']['resource_policy'] + domain_name = request['user']['domain_name'] + group_id_or_name = params['group'] + log.info('VFOLDER.CREATE (ak:{}, vf:{}, vfh:{}, umod:{}, perm:{})', + access_key, params['name'], params['folder_host'], + params['usage_mode'].value, params['permission'].value) + folder_host = params['folder_host'] + unmanaged_path = params['unmanaged_path'] + # Check if user is trying to created unmanaged vFolder + if unmanaged_path: + # Approve only if user is Admin or Superadmin + if user_role not in (UserRole.ADMIN, UserRole.SUPERADMIN): + raise GenericForbidden('Insufficient permission') + else: + # Resolve host for the new virtual folder. + if not folder_host: + folder_host = \ + await root_ctx.shared_config.etcd.get('volumes/default_host') + if not folder_host: + raise InvalidAPIParameters( + 'You must specify the vfolder host ' + 'because the default host is not configured.') + allowed_vfolder_types = await root_ctx.shared_config.get_vfolder_types() + for vf_type in allowed_vfolder_types: + if vf_type not in ('user', 'group'): + raise ServerMisconfiguredError( + f'Invalid vfolder type(s): {str(allowed_vfolder_types)}.' + ' Only "user" or "group" is allowed.') + + if not verify_vfolder_name(params['name']): + raise InvalidAPIParameters(f'{params["name"]} is reserved for internal operations.') + if params['name'].startswith('.') and params['name'] != '.local': + if params['group'] is not None: + raise InvalidAPIParameters('dot-prefixed vfolders cannot be a group folder.') + + async with root_ctx.db.begin() as conn: + # Convert group name to uuid if group name is given. 
+ if isinstance(group_id_or_name, str): + query = ( + sa.select([groups.c.id]) + .select_from(groups) + .where(groups.c.domain_name == domain_name) + .where(groups.c.name == group_id_or_name) + ) + group_id = await conn.scalar(query) + else: + group_id = group_id_or_name + if not unmanaged_path: + # Check resource policy's allowed_vfolder_hosts + if group_id is not None: + allowed_hosts = await get_allowed_vfolder_hosts_by_group(conn, resource_policy, + domain_name, group_id) + else: + allowed_hosts = await get_allowed_vfolder_hosts_by_user(conn, resource_policy, + domain_name, user_uuid) + # TODO: handle legacy host lists assuming that volume names don't overlap? + if folder_host not in allowed_hosts: + raise InvalidAPIParameters('You are not allowed to use this vfolder host.') + + # Check resource policy's max_vfolder_count + if resource_policy['max_vfolder_count'] > 0: + query = (sa.select([sa.func.count()]) + .where(vfolders.c.user == user_uuid)) + result = await conn.scalar(query) + if result >= resource_policy['max_vfolder_count']: + raise InvalidAPIParameters('You cannot create more vfolders.') + + # Limit vfolder size quota if it is larger than max_vfolder_size of the resource policy. + max_vfolder_size = resource_policy.get('max_vfolder_size', 0) + if ( + max_vfolder_size > 0 + and ( + params['quota'] is None + or params['quota'] <= 0 + or params['quota'] > max_vfolder_size + ) + ): + params['quota'] = max_vfolder_size + + # Prevent creation of vfolder with duplicated name. + extra_vf_conds = [vfolders.c.name == params['name']] + if not unmanaged_path: + extra_vf_conds.append(vfolders.c.host == folder_host) + entries = await query_accessible_vfolders( + conn, user_uuid, + user_role=user_role, domain_name=domain_name, + allowed_vfolder_types=allowed_vfolder_types, + extra_vf_conds=(sa.and_(*extra_vf_conds)), + ) + if len(entries) > 0: + raise VFolderAlreadyExists + + # Check if group exists. 
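+        # (group_id stays None when the name given above did not resolve to any
+        # group in this domain, which is reported as GroupNotFound below.)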
+ if group_id_or_name and group_id is None: + raise GroupNotFound + if group_id is not None: + if 'group' not in allowed_vfolder_types: + raise InvalidAPIParameters('group vfolder cannot be created in this host') + if not request['is_admin']: + raise GenericForbidden('no permission') + query = ( + sa.select([groups.c.id]) + .select_from(groups) + .where(groups.c.domain_name == domain_name) + .where(groups.c.id == group_id) + ) + _gid = await conn.scalar(query) + if str(_gid) != str(group_id): + raise InvalidAPIParameters('No such group.') + else: + if 'user' not in allowed_vfolder_types: + raise InvalidAPIParameters('user vfolder cannot be created in this host') + try: + folder_id = uuid.uuid4() + if not unmanaged_path: + # Try to create actual only if vFolder is managed one + async with root_ctx.storage_manager.request( + folder_host, 'POST', 'folder/create', + json={ + 'volume': root_ctx.storage_manager.split_host(folder_host)[1], + 'vfid': str(folder_id), + 'options': {'quota': params['quota']}, + }, + ): + pass + except aiohttp.ClientResponseError: + raise VFolderCreationFailed + user_uuid = str(user_uuid) if group_id is None else None + group_uuid = str(group_id) if group_id is not None else None + ownership_type = 'group' if group_uuid is not None else 'user' + insert_values = { + 'id': folder_id.hex, + 'name': params['name'], + 'usage_mode': params['usage_mode'], + 'permission': params['permission'], + 'last_used': None, + 'max_size': int(params['quota'] / (2**20)) if params['quota'] else None, # in MBytes + 'host': folder_host, + 'creator': request['user']['email'], + 'ownership_type': VFolderOwnershipType(ownership_type), + 'user': user_uuid, + 'group': group_uuid, + 'unmanaged_path': '', + 'cloneable': params['cloneable'], + } + resp = { + 'id': folder_id.hex, + 'name': params['name'], + 'host': folder_host, + 'usage_mode': params['usage_mode'].value, + 'permission': params['permission'].value, + 'max_size': int(params['quota'] / (2**20)) if params['quota'] else None, # in MBytes + 'creator': request['user']['email'], + 'ownership_type': ownership_type, + 'user': user_uuid, + 'group': group_uuid, + 'cloneable': params['cloneable'], + } + if unmanaged_path: + insert_values.update({ + 'host': '', + 'unmanaged_path': unmanaged_path, + }) + resp['unmanaged_path'] = unmanaged_path + query = (sa.insert(vfolders, insert_values)) + try: + result = await conn.execute(query) + except sa.exc.DataError: + raise InvalidAPIParameters + assert result.rowcount == 1 + return web.json_response(resp, status=201) + + +@auth_required +@server_status_required(READ_ALLOWED) +@check_api_params( + t.Dict({ + t.Key('all', default=False): t.ToBool, + tx.AliasedKey(['group_id', 'groupId'], default=None): tx.UUID | t.String | t.Null, + tx.AliasedKey(['owner_user_email', 'ownerUserEmail'], default=None): t.Email | t.Null, + }), +) +async def list_folders(request: web.Request, params: Any) -> web.Response: + resp = [] + root_ctx: RootContext = request.app['_root.context'] + access_key = request['keypair']['access_key'] + domain_name = request['user']['domain_name'] + + def make_entries(result, user_uuid) -> List[Dict[str, Any]]: + entries = [] + for row in result: + entries.append({ + 'name': row.vfolders_name, + 'id': row.vfolders_id, + 'host': row.vfolders_host, + 'usage_mode': row.vfolders_usage_mode, + 'created_at': row.vfolders_created_at, + 'is_owner': (row.vfolders_user == user_uuid), + 'permission': row.vfolders_permission, + 'user': str(row.vfolders_user) if row.vfolders_user else None, + 'group': 
str(row.vfolders_group) if row.vfolders_group else None, + 'creator': row.vfolders_creator, + 'user_email': row.users_email, + 'group_name': row.groups_name, + 'ownership_type': row.vfolders_ownership_type, + 'type': row.vfolders_ownership_type, # legacy + 'unmanaged_path': row.vfolders_unmanaged_path, + 'cloneable': row.vfolders_cloneable if row.vfolders_cloneable else False, + 'max_files': row.vfolders_max_files, + 'max_size': row.vfolders_max_size, + }) + return entries + + log.info('VFOLDER.LIST (ak:{})', access_key) + entries: List[Mapping[str, Any]] | Sequence[Mapping[str, Any]] + owner_user_uuid, owner_user_role = await get_user_scopes(request, params) + async with root_ctx.db.begin_readonly() as conn: + allowed_vfolder_types = await root_ctx.shared_config.get_vfolder_types() + if params['all']: + raise InvalidAPIParameters("Deprecated use of 'all' option") + else: + extra_vf_conds = None + if params['group_id'] is not None: + # Note: user folders should be returned even when group_id is specified. + extra_vf_conds = ( + (vfolders.c.group == params['group_id']) | + (vfolders.c.user.isnot(None)) + ) + entries = await query_accessible_vfolders( + conn, + owner_user_uuid, + user_role=owner_user_role, + domain_name=domain_name, + allowed_vfolder_types=allowed_vfolder_types, + extra_vf_conds=extra_vf_conds, + ) + for entry in entries: + resp.append({ + 'name': entry['name'], + 'id': entry['id'].hex, + 'host': entry['host'], + 'usage_mode': entry['usage_mode'].value, + 'created_at': str(entry['created_at']), + 'is_owner': entry['is_owner'], + 'permission': entry['permission'].value, + 'user': str(entry['user']) if entry['user'] else None, + 'group': str(entry['group']) if entry['group'] else None, + 'creator': entry['creator'], + 'user_email': entry['user_email'], + 'group_name': entry['group_name'], + 'ownership_type': entry['ownership_type'].value, + 'type': entry['ownership_type'].value, # legacy + 'cloneable': entry['cloneable'], + 'max_files': entry['max_files'], + 'max_size': entry['max_size'], + }) + return web.json_response(resp, status=200) + + +@superadmin_required +@server_status_required(ALL_ALLOWED) +@check_api_params( + t.Dict({ + t.Key('id'): t.String, + }), +) +async def delete_by_id(request: web.Request, params: Any) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + access_key = request['keypair']['access_key'] + log.info('VFOLDER.DELETE_BY_ID (ak:{}, vf:{})', access_key, params['id']) + async with root_ctx.db.begin() as conn: + query = ( + sa.select([vfolders.c.host]) + .select_from(vfolders) + .where(vfolders.c.id == params['id']) + ) + folder_host = await conn.scalar(query) + folder_id = uuid.UUID(params['id']) + query = (sa.delete(vfolders).where(vfolders.c.id == folder_id)) + await conn.execute(query) + # fs-level deletion may fail or take longer time + # but let's complete the db transaction to reflect that it's deleted. 
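+    # (If the storage-proxy call below fails, the directory may be left behind
+    # on the volume even though the vfolder row is already gone.)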
+ async with root_ctx.storage_manager.request( + folder_host, 'POST', 'folder/delete', + json={ + 'volume': root_ctx.storage_manager.split_host(folder_host)[1], + 'vfid': str(folder_id), + }, + ): + pass + return web.Response(status=204) + + +@auth_required +@server_status_required(READ_ALLOWED) +@check_api_params( + t.Dict({ + tx.AliasedKey(['group_id', 'groupId'], default=None): tx.UUID | t.String | t.Null, + }), +) +async def list_hosts(request: web.Request, params: Any) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + access_key = request['keypair']['access_key'] + log.info('VFOLDER.LIST_HOSTS (ak:{})', access_key) + domain_name = request['user']['domain_name'] + group_id = params['group_id'] + domain_admin = request['user']['role'] == UserRole.ADMIN + resource_policy = request['keypair']['resource_policy'] + allowed_vfolder_types = await root_ctx.shared_config.get_vfolder_types() + async with root_ctx.db.begin() as conn: + allowed_hosts: Set[str] = set() + if 'user' in allowed_vfolder_types: + allowed_hosts_by_user = await get_allowed_vfolder_hosts_by_user( + conn, resource_policy, domain_name, request['user']['uuid'], group_id) + allowed_hosts = allowed_hosts | allowed_hosts_by_user + if 'group' in allowed_vfolder_types: + allowed_hosts_by_group = await get_allowed_vfolder_hosts_by_group( + conn, resource_policy, domain_name, group_id, domain_admin=domain_admin) + allowed_hosts = allowed_hosts | allowed_hosts_by_group + all_volumes = await root_ctx.storage_manager.get_all_volumes() + all_hosts = {f"{proxy_name}:{volume_data['name']}" for proxy_name, volume_data in all_volumes} + allowed_hosts = allowed_hosts & all_hosts + default_host = await root_ctx.shared_config.get_raw('volumes/default_host') + if default_host not in allowed_hosts: + default_host = None + volume_info = { + f"{proxy_name}:{volume_data['name']}": { + 'backend': volume_data['backend'], + 'capabilities': volume_data['capabilities'], + } + for proxy_name, volume_data in all_volumes + if f"{proxy_name}:{volume_data['name']}" in allowed_hosts + } + resp = { + 'default': default_host, + 'allowed': sorted(allowed_hosts), + 'volume_info': volume_info, + } + return web.json_response(resp, status=200) + + +@superadmin_required +@server_status_required(READ_ALLOWED) +async def list_all_hosts(request: web.Request) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + access_key = request['keypair']['access_key'] + log.info('VFOLDER.LIST_ALL_HOSTS (ak:{})', access_key) + all_volumes = await root_ctx.storage_manager.get_all_volumes() + all_hosts = {f"{proxy_name}:{volume_data['name']}" for proxy_name, volume_data in all_volumes} + default_host = await root_ctx.shared_config.get_raw('volumes/default_host') + if default_host not in all_hosts: + default_host = None + resp = { + 'default': default_host, + 'allowed': sorted(all_hosts), + } + return web.json_response(resp, status=200) + + +@superadmin_required +@server_status_required(READ_ALLOWED) +@check_api_params( + t.Dict({ + t.Key('folder_host'): t.String, + })) +async def get_volume_perf_metric(request: web.Request, params: Any) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + access_key = request['keypair']['access_key'] + log.info('VFOLDER.VOLUME_PERF_METRIC (ak:{})', access_key) + proxy_name, volume_name = root_ctx.storage_manager.split_host(params['folder_host']) + async with root_ctx.storage_manager.request( + proxy_name, 'GET', 'volume/performance-metric', + json={ + 'volume': volume_name, + }, + ) as 
(_, storage_resp): + storage_reply = await storage_resp.json() + return web.json_response(storage_reply, status=200) + + +@auth_required +@server_status_required(READ_ALLOWED) +async def list_allowed_types(request: web.Request) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + access_key = request['keypair']['access_key'] + log.info('VFOLDER.LIST_ALLOWED_TYPES (ak:{})', access_key) + allowed_vfolder_types = await root_ctx.shared_config.get_vfolder_types() + return web.json_response(allowed_vfolder_types, status=200) + + +@auth_required +@server_status_required(READ_ALLOWED) +@vfolder_permission_required(VFolderPermission.READ_ONLY) +async def get_info(request: web.Request, row: VFolderRow) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + resp: Dict[str, Any] = {} + folder_name = request.match_info['name'] + access_key = request['keypair']['access_key'] + log.info('VFOLDER.GETINFO (ak:{}, vf:{})', access_key, folder_name) + if row['permission'] is None: + is_owner = True + permission = VFolderPermission.OWNER_PERM + else: + is_owner = row['is_owner'] + permission = row['permission'] + proxy_name, volume_name = root_ctx.storage_manager.split_host(row['host']) + async with root_ctx.storage_manager.request( + proxy_name, 'GET', 'folder/usage', + json={ + 'volume': volume_name, + 'vfid': str(row['id']), + }, + ) as (_, storage_resp): + usage = await storage_resp.json() + resp = { + 'name': row['name'], + 'id': row['id'].hex, + 'host': row['host'], + 'numFiles': usage['file_count'], # legacy + 'num_files': usage['file_count'], + 'used_bytes': usage['used_bytes'], # added in v20.09 + 'created': str(row['created_at']), # legacy + 'created_at': str(row['created_at']), + 'last_used': str(row['created_at']), + 'user': str(row['user']), + 'group': str(row['group']), + 'type': 'user' if row['user'] is not None else 'group', + 'is_owner': is_owner, + 'permission': permission, + 'usage_mode': row['usage_mode'], + 'cloneable': row['cloneable'], + 'max_size': row['max_size'], + } + return web.json_response(resp, status=200) + + +@auth_required +@server_status_required(READ_ALLOWED) +@check_api_params( + t.Dict({ + t.Key('folder_host'): t.String, + t.Key('id'): tx.UUID, + })) +async def get_quota(request: web.Request, params: Any) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + proxy_name, volume_name = root_ctx.storage_manager.split_host(params['folder_host']) + log.info('VFOLDER.GET_QUOTA (volume_name:{}, vf:{})', volume_name, params['id']) + + # Permission check for the requested vfolder. 
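+    # (Superadmins may query any vfolder; other users must have the vfolder among
+    # their accessible ones as resolved by query_accessible_vfolders().)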
+    user_role = request['user']['role']
+    user_uuid = request['user']['uuid']
+    domain_name = request['user']['domain_name']
+    if user_role == UserRole.SUPERADMIN:
+        pass
+    else:
+        allowed_vfolder_types = await root_ctx.shared_config.get_vfolder_types()
+        async with root_ctx.db.begin_readonly() as conn:
+            extra_vf_conds = [vfolders.c.id == params['id']]
+            entries = await query_accessible_vfolders(
+                conn,
+                user_uuid,
+                user_role=user_role,
+                domain_name=domain_name,
+                allowed_vfolder_types=allowed_vfolder_types,
+                extra_vf_conds=(sa.and_(*extra_vf_conds)),
+            )
+            if len(entries) == 0:
+                raise VFolderNotFound('no such accessible vfolder')
+
+    async with root_ctx.storage_manager.request(
+        proxy_name, 'GET', 'volume/quota',
+        json={
+            'volume': volume_name,
+            'vfid': str(params['id']),
+        },
+    ) as (_, storage_resp):
+        storage_reply = await storage_resp.json()
+        return web.json_response(storage_reply, status=200)
+
+
+@auth_required
+@server_status_required(ALL_ALLOWED)
+@check_api_params(
+    t.Dict({
+        t.Key('folder_host'): t.String,
+        t.Key('id'): tx.UUID,
+        t.Key('input'): t.Mapping(t.String, t.Any),
+    }),
+)
+async def update_quota(request: web.Request, params: Any) -> web.Response:
+    root_ctx: RootContext = request.app['_root.context']
+    proxy_name, volume_name = root_ctx.storage_manager.split_host(params['folder_host'])
+    quota = int(params['input']['size_bytes'])
+    log.info('VFOLDER.UPDATE_QUOTA (volume_name:{}, quota:{}, vf:{})', volume_name, quota, params['id'])
+
+    # Permission check for the requested vfolder.
+    user_role = request['user']['role']
+    user_uuid = request['user']['uuid']
+    domain_name = request['user']['domain_name']
+    if user_role == UserRole.SUPERADMIN:
+        pass
+    else:
+        allowed_vfolder_types = await root_ctx.shared_config.get_vfolder_types()
+        async with root_ctx.db.begin_readonly() as conn:
+            extra_vf_conds = [vfolders.c.id == params['id']]
+            entries = await query_accessible_vfolders(
+                conn,
+                user_uuid,
+                user_role=user_role,
+                domain_name=domain_name,
+                allowed_vfolder_types=allowed_vfolder_types,
+                extra_vf_conds=(sa.and_(*extra_vf_conds)),
+            )
+            if len(entries) == 0:
+                raise VFolderNotFound('no such accessible vfolder')
+
+    # Limit vfolder size quota if it is larger than max_vfolder_size of the resource policy.
+    resource_policy = request['keypair']['resource_policy']
+    max_vfolder_size = resource_policy.get('max_vfolder_size', 0)
+    if (
+        max_vfolder_size > 0
+        and (
+            quota <= 0
+            or quota > max_vfolder_size
+        )
+    ):
+        quota = max_vfolder_size
+
+    async with root_ctx.storage_manager.request(
+        proxy_name, 'PATCH', 'volume/quota',
+        json={
+            'volume': volume_name,
+            'vfid': str(params['id']),
+            'size_bytes': quota,
+        },
+    ):
+        pass
+
+    # Update the quota for the vfolder in DB.
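+    # (vfolders.max_size is stored in MiB, so the byte quota is rounded up,
+    # e.g. 1_500_000 bytes -> ceil(1_500_000 / 2**20) = 2 MiB.)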
+ async with root_ctx.db.begin() as conn: + query = ( + sa.update(vfolders) + .values(max_size=math.ceil(quota / 2**20)) # in Mbytes + .where(vfolders.c.id == params['id']) + ) + result = await conn.execute(query) + assert result.rowcount == 1 + + return web.json_response({'size_bytes': quota}, status=200) + + +@superadmin_required +@server_status_required(READ_ALLOWED) +@check_api_params( + t.Dict({ + t.Key('folder_host'): t.String, + t.Key('id'): tx.UUID, + })) +async def get_usage(request: web.Request, params: Any) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + proxy_name, volume_name = root_ctx.storage_manager.split_host(params['folder_host']) + log.info('VFOLDER.GET_USAGE (volume_name:{}, vf:{})', volume_name, params['id']) + async with root_ctx.storage_manager.request( + proxy_name, 'GET', 'folder/usage', + json={ + 'volume': volume_name, + 'vfid': str(params['id']), + }, + ) as (_, storage_resp): + usage = await storage_resp.json() + return web.json_response(usage, status=200) + + +@auth_required +@server_status_required(ALL_ALLOWED) +@vfolder_permission_required(VFolderPermission.OWNER_PERM) +@check_api_params( + t.Dict({ + t.Key('new_name'): tx.Slug(allow_dot=True), + })) +async def rename_vfolder(request: web.Request, params: Any, row: VFolderRow) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + old_name = request.match_info['name'] + access_key = request['keypair']['access_key'] + domain_name = request['user']['domain_name'] + user_role = request['user']['role'] + user_uuid = request['user']['uuid'] + new_name = params['new_name'] + allowed_vfolder_types = await root_ctx.shared_config.get_vfolder_types() + log.info('VFOLDER.RENAME (ak:{}, vf.old:{}, vf.new:{})', + access_key, old_name, new_name) + async with root_ctx.db.begin() as conn: + entries = await query_accessible_vfolders( + conn, + user_uuid, + user_role=user_role, + domain_name=domain_name, + allowed_vfolder_types=allowed_vfolder_types, + ) + for entry in entries: + if entry['name'] == new_name: + raise InvalidAPIParameters( + 'One of your accessible vfolders already has ' + 'the name you requested.') + for entry in entries: + if entry['name'] == old_name: + if not entry['is_owner']: + raise InvalidAPIParameters( + 'Cannot change the name of a vfolder ' + 'that is not owned by myself.') + query = ( + sa.update(vfolders) + .values(name=new_name) + .where(vfolders.c.id == entry['id'])) + await conn.execute(query) + break + return web.Response(status=201) + + +@auth_required +@server_status_required(ALL_ALLOWED) +@vfolder_permission_required(VFolderPermission.OWNER_PERM) +@check_api_params( + t.Dict({ + t.Key('cloneable', default=None): t.Bool | t.Null, + t.Key('permission', default=None): tx.Enum(VFolderPermission) | t.Null, + })) +async def update_vfolder_options(request: web.Request, params: Any, row: VFolderRow) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + updated_fields = {} + if params['cloneable'] is not None and params['cloneable'] != row['cloneable']: + updated_fields['cloneable'] = params['cloneable'] + if params['permission'] is not None and params['permission'] != row['permission']: + updated_fields['permission'] = params['permission'] + if not row['is_owner']: + raise InvalidAPIParameters( + 'Cannot change the options of a vfolder ' + 'that is not owned by myself.') + + if len(updated_fields) > 0: + async with root_ctx.db.begin() as conn: + query = ( + sa.update(vfolders) + .values(**updated_fields) + .where(vfolders.c.id == 
row['id'])) + await conn.execute(query) + return web.Response(status=201) + + +@auth_required +@server_status_required(READ_ALLOWED) +@vfolder_permission_required(VFolderPermission.READ_WRITE) +@check_api_params( + t.Dict({ + t.Key('path'): t.String, + t.Key('parents', default=True): t.ToBool, + t.Key('exist_ok', default=False): t.ToBool, + })) +async def mkdir(request: web.Request, params: Any, row: VFolderRow) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + folder_name = request.match_info['name'] + access_key = request['keypair']['access_key'] + log.info('VFOLDER.MKDIR (ak:{}, vf:{}, path:{})', access_key, folder_name, params['path']) + proxy_name, volume_name = root_ctx.storage_manager.split_host(row['host']) + async with root_ctx.storage_manager.request( + proxy_name, 'POST', 'folder/file/mkdir', + json={ + 'volume': volume_name, + 'vfid': str(row['id']), + 'relpath': params['path'], + 'parents': params['parents'], + 'exist_ok': params['exist_ok'], + }, + ): + pass + return web.Response(status=201) + + +@auth_required +@server_status_required(READ_ALLOWED) +@vfolder_permission_required(VFolderPermission.READ_ONLY) +@check_api_params( + t.Dict({ + tx.AliasedKey(['path', 'file']): t.String, + t.Key('archive', default=False): t.ToBool, + })) +async def create_download_session(request: web.Request, params: Any, row: VFolderRow) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + log_fmt = 'VFOLDER.CREATE_DOWNLOAD_SESSION(ak:{}, vf:{}, path:{})' + log_args = (request['keypair']['access_key'], row['name'], params['path']) + log.info(log_fmt, *log_args) + unmanaged_path = row['unmanaged_path'] + proxy_name, volume_name = root_ctx.storage_manager.split_host(row['host']) + async with root_ctx.storage_manager.request( + proxy_name, 'POST', 'folder/file/download', + json={ + 'volume': volume_name, + 'vfid': str(row['id']), + 'relpath': params['path'], + 'archive': params['archive'], + 'unmanaged_path': unmanaged_path if unmanaged_path else None, + }, + ) as (client_api_url, storage_resp): + storage_reply = await storage_resp.json() + resp = { + 'token': storage_reply['token'], + 'url': str(client_api_url / 'download'), + } + return web.json_response(resp, status=200) + + +@auth_required +@server_status_required(READ_ALLOWED) +@vfolder_permission_required(VFolderPermission.READ_WRITE) +@check_api_params( + t.Dict({ + t.Key('path'): t.String, + t.Key('size'): t.ToInt, + })) +async def create_upload_session(request: web.Request, params: Any, row: VFolderRow) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + folder_name = request.match_info['name'] + access_key = request['keypair']['access_key'] + log_fmt = 'VFOLDER.CREATE_UPLOAD_SESSION (ak:{}, vf:{})' + log_args = (access_key, folder_name) + log.info(log_fmt, *log_args) + proxy_name, volume_name = root_ctx.storage_manager.split_host(row['host']) + async with root_ctx.storage_manager.request( + proxy_name, 'POST', 'folder/file/upload', + json={ + 'volume': volume_name, + 'vfid': str(row['id']), + 'relpath': params['path'], + 'size': params['size'], + }, + ) as (client_api_url, storage_resp): + storage_reply = await storage_resp.json() + resp = { + 'token': storage_reply['token'], + 'url': str(client_api_url / 'upload'), + } + return web.json_response(resp, status=200) + + +@auth_required +@server_status_required(READ_ALLOWED) +@vfolder_permission_required(VFolderPermission.READ_WRITE) +@check_api_params( + t.Dict({ + t.Key('target_path'): t.String, + t.Key('new_name'): 
t.String, + t.Key('is_dir', default=False): t.ToBool, # ignored since 22.03 + })) +async def rename_file(request: web.Request, params: Any, row: VFolderRow) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + folder_name = request.match_info['name'] + access_key = request['keypair']['access_key'] + log.info('VFOLDER.RENAME_FILE (ak:{}, vf:{}, target_path:{}, new_name:{})', + access_key, folder_name, params['target_path'], params['new_name']) + proxy_name, volume_name = root_ctx.storage_manager.split_host(row['host']) + async with root_ctx.storage_manager.request( + proxy_name, 'POST', 'folder/file/rename', + json={ + 'volume': volume_name, + 'vfid': str(row['id']), + 'relpath': params['target_path'], + 'new_name': params['new_name'], + }, + ): + pass + return web.json_response({}, status=200) + + +@auth_required +@server_status_required(READ_ALLOWED) +@vfolder_permission_required(VFolderPermission.READ_WRITE) +@check_api_params( + t.Dict({ + t.Key('src'): t.String, + t.Key('dst'): t.String, + })) +async def move_file(request: web.Request, params: Any, row: VFolderRow) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + folder_name = request.match_info['name'] + access_key = request['keypair']['access_key'] + log.info('VFOLDER.MOVE_FILE (ak:{}, vf:{}, src:{}, dst:{})', + access_key, folder_name, params['src'], params['dst']) + proxy_name, volume_name = root_ctx.storage_manager.split_host(row['host']) + async with root_ctx.storage_manager.request( + proxy_name, 'POST', 'folder/file/move', + json={ + 'volume': volume_name, + 'vfid': str(row['id']), + 'src_relpath': params['src'], + 'dst_relpath': params['dst'], + }, + ): + pass + return web.json_response({}, status=200) + + +@auth_required +@server_status_required(READ_ALLOWED) +@vfolder_permission_required(VFolderPermission.READ_WRITE) +@check_api_params( + t.Dict({ + t.Key('files'): t.List(t.String), + t.Key('recursive', default=False): t.ToBool, + })) +async def delete_files(request: web.Request, params: Any, row: VFolderRow) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + folder_name = request.match_info['name'] + access_key = request['keypair']['access_key'] + recursive = params['recursive'] + log.info('VFOLDER.DELETE_FILES (ak:{}, vf:{}, path:{}, recursive:{})', + access_key, folder_name, folder_name, recursive) + proxy_name, volume_name = root_ctx.storage_manager.split_host(row['host']) + async with root_ctx.storage_manager.request( + proxy_name, 'POST', 'folder/file/delete', + json={ + 'volume': volume_name, + 'vfid': str(row['id']), + 'relpaths': params['files'], + 'recursive': recursive, + }, + ): + pass + return web.json_response({}, status=200) + + +@auth_required +@server_status_required(READ_ALLOWED) +@vfolder_permission_required(VFolderPermission.READ_ONLY) +@check_api_params( + t.Dict({ + t.Key('path', default=''): t.String(allow_blank=True), + })) +async def list_files(request: web.Request, params: Any, row: VFolderRow) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + folder_name = request.match_info['name'] + access_key = request['keypair']['access_key'] + log.info('VFOLDER.LIST_FILES (ak:{}, vf:{}, path:{})', + access_key, folder_name, params['path']) + proxy_name, volume_name = root_ctx.storage_manager.split_host(row['host']) + async with root_ctx.storage_manager.request( + proxy_name, 'POST', 'folder/file/list', + json={ + 'volume': volume_name, + 'vfid': str(row['id']), + 'relpath': params['path'], + }, + ) as (_, 
storage_resp): + result = await storage_resp.json() + resp = { + 'items': [ + { + 'name': item['name'], + 'type': item['type'], + 'size': item['stat']['size'], # humanize? + 'mode': oct(item['stat']['mode'])[2:][-3:], + 'created': item['stat']['created'], + 'modified': item['stat']['modified'], + } + for item in result['items'] + ], + 'files': json.dumps([ # for legacy (to be removed in 21.03) + { + 'filename': item['name'], + 'size': item['stat']['size'], + 'mode': stat.filemode(item['stat']['mode']), + 'ctime': datetime.fromisoformat(item['stat']['created']).timestamp(), + 'atime': 0, + 'mtime': datetime.fromisoformat(item['stat']['modified']).timestamp(), + } + for item in result['items'] + ]), + } + return web.json_response(resp, status=200) + + +@auth_required +@server_status_required(READ_ALLOWED) +async def list_sent_invitations(request: web.Request) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + access_key = request['keypair']['access_key'] + log.info('VFOLDER.LIST_SENT_INVITATIONS (ak:{})', access_key) + async with root_ctx.db.begin() as conn: + j = sa.join(vfolders, vfolder_invitations, + vfolders.c.id == vfolder_invitations.c.vfolder) + query = ( + sa.select([vfolder_invitations, vfolders.c.name]) + .select_from(j) + .where( + (vfolder_invitations.c.inviter == request['user']['email']) & + (vfolder_invitations.c.state == VFolderInvitationState.PENDING), + ) + ) + result = await conn.execute(query) + invitations = result.fetchall() + invs_info = [] + for inv in invitations: + invs_info.append({ + 'id': str(inv.id), + 'inviter': inv.inviter, + 'invitee': inv.invitee, + 'perm': inv.permission, + 'state': inv.state.value, + 'created_at': str(inv.created_at), + 'modified_at': str(inv.modified_at), + 'vfolder_id': str(inv.vfolder), + 'vfolder_name': inv.name, + }) + resp = {'invitations': invs_info} + return web.json_response(resp, status=200) + + +@auth_required +@server_status_required(ALL_ALLOWED) +@check_api_params( + t.Dict({ + tx.AliasedKey(['perm', 'permission']): VFolderPermissionValidator, + }), +) +async def update_invitation(request: web.Request, params: Any) -> web.Response: + """ + Update sent invitation's permission. Other fields are not allowed to be updated. 
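+    Only invitations that are still pending and were sent by the requesting user
+    are affected.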
+ """ + root_ctx: RootContext = request.app['_root.context'] + access_key = request['keypair']['access_key'] + inv_id = request.match_info['inv_id'] + perm = params['perm'] + log.info('VFOLDER.UPDATE_INVITATION (ak:{}, inv:{})', access_key, inv_id) + async with root_ctx.db.begin() as conn: + query = ( + sa.update(vfolder_invitations) + .values(permission=perm) + .where( + (vfolder_invitations.c.id == inv_id) & + (vfolder_invitations.c.inviter == request['user']['email']) & + (vfolder_invitations.c.state == VFolderInvitationState.PENDING), + ) + ) + await conn.execute(query) + resp = {'msg': f'vfolder invitation updated: {inv_id}.'} + return web.json_response(resp, status=200) + + +@auth_required +@server_status_required(ALL_ALLOWED) +@check_api_params( + t.Dict({ + tx.AliasedKey(['perm', 'permission'], default='rw'): VFolderPermissionValidator, + tx.AliasedKey(['emails', 'user_ids', 'userIDs']): t.List(t.String), + }), +) +async def invite(request: web.Request, params: Any) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + folder_name = request.match_info['name'] + access_key = request['keypair']['access_key'] + user_uuid = request['user']['uuid'] + perm = params['perm'] + invitee_emails = params['emails'] + log.info('VFOLDER.INVITE (ak:{}, vf:{}, inv.users:{})', + access_key, folder_name, ','.join(invitee_emails)) + if folder_name.startswith('.'): + raise GenericForbidden('Cannot share private dot-prefixed vfolders.') + async with root_ctx.db.begin() as conn: + # Get virtual folder. + query = (sa.select('*') + .select_from(vfolders) + .where((vfolders.c.user == user_uuid) & + (vfolders.c.name == folder_name))) + try: + result = await conn.execute(query) + except sa.exc.DataError: + raise InvalidAPIParameters + vf = result.first() + if vf is None: + raise VFolderNotFound() + + # Get invited user's keypairs except vfolder owner. + query = ( + sa.select([keypairs.c.user_id, keypairs.c.user]) + .select_from(keypairs) + .where(keypairs.c.user_id.in_(invitee_emails)) + .where(keypairs.c.user_id != request['user']['email']) + ) + try: + result = await conn.execute(query) + except sa.exc.DataError: + raise InvalidAPIParameters + kps = result.fetchall() + if len(kps) < 1: + raise ObjectNotFound(object_name='vfolder invitation') + + # Prevent inviting user who already share the target folder. + invitee_uuids = [kp.user for kp in kps] + j = sa.join(vfolders, vfolder_permissions, + vfolders.c.id == vfolder_permissions.c.vfolder) + query = ( + sa.select([sa.func.count()]) + .select_from(j) + .where( + ( + vfolders.c.user.in_(invitee_uuids) | + vfolder_permissions.c.user.in_(invitee_uuids) + ) & + (vfolders.c.name == folder_name), + ) + ) + result = await conn.execute(query) + if result.scalar() > 0: + raise VFolderAlreadyExists + + # Create invitation. + invitees = [kp.user_id for kp in kps] + invited_ids = [] + for invitee in set(invitees): + inviter = request['user']['id'] + # Do not create invitation if already exists. + query = ( + sa.select([sa.func.count()]) + .select_from(vfolder_invitations) + .where( + (vfolder_invitations.c.inviter == inviter) & + (vfolder_invitations.c.invitee == invitee) & + (vfolder_invitations.c.vfolder == vf.id) & + (vfolder_invitations.c.state == VFolderInvitationState.PENDING), + ) + ) + result = await conn.execute(query) + if result.scalar() > 0: + continue + + # TODO: insert multiple values with one query. 
+ # insert().values([{}, {}, ...]) does not work: + # sqlalchemy.exc.CompileError: The 'default' dialect with current + # database version settings does not support in-place multirow + # inserts. + query = (sa.insert(vfolder_invitations, { + 'id': uuid.uuid4().hex, + 'permission': perm, + 'vfolder': vf.id, + 'inviter': inviter, + 'invitee': invitee, + 'state': VFolderInvitationState.PENDING, + })) + try: + await conn.execute(query) + invited_ids.append(invitee) + except sa.exc.DataError: + pass + resp = {'invited_ids': invited_ids} + return web.json_response(resp, status=201) + + +@auth_required +@server_status_required(READ_ALLOWED) +async def invitations(request: web.Request) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + access_key = request['keypair']['access_key'] + log.info('VFOLDER.INVITATIONS (ak:{})', access_key) + async with root_ctx.db.begin() as conn: + j = sa.join(vfolders, vfolder_invitations, + vfolders.c.id == vfolder_invitations.c.vfolder) + query = ( + sa.select([vfolder_invitations, vfolders.c.name]) + .select_from(j) + .where( + (vfolder_invitations.c.invitee == request['user']['id']) & + (vfolder_invitations.c.state == VFolderInvitationState.PENDING), + ) + ) + result = await conn.execute(query) + invitations = result.fetchall() + invs_info = [] + for inv in invitations: + invs_info.append({ + 'id': str(inv.id), + 'inviter': inv.inviter, + 'invitee': inv.invitee, + 'perm': inv.permission, + 'state': inv.state, + 'created_at': str(inv.created_at), + 'modified_at': str(inv.modified_at), + 'vfolder_id': str(inv.vfolder), + 'vfolder_name': inv.name, + }) + resp = {'invitations': invs_info} + return web.json_response(resp, status=200) + + +@auth_required +@server_status_required(ALL_ALLOWED) +@check_api_params( + t.Dict({ + t.Key('inv_id'): t.String, + }), +) +async def accept_invitation(request: web.Request, params: Any) -> web.Response: + """Accept invitation by invitee. + + * `inv_ak` parameter is removed from 19.06 since virtual folder's ownership is + moved from keypair to a user or a group. + + :param inv_id: ID of vfolder_invitations row. + """ + root_ctx: RootContext = request.app['_root.context'] + access_key = request['keypair']['access_key'] + user_uuid = request['user']['uuid'] + inv_id = params['inv_id'] + log.info('VFOLDER.ACCEPT_INVITATION (ak:{}, inv:{})', access_key, inv_id) + async with root_ctx.db.begin() as conn: + # Get invitation. + query = ( + sa.select([vfolder_invitations]) + .select_from(vfolder_invitations) + .where( + (vfolder_invitations.c.id == inv_id) & + (vfolder_invitations.c.state == VFolderInvitationState.PENDING), + ) + ) + result = await conn.execute(query) + invitation = result.first() + if invitation is None: + raise ObjectNotFound(object_name='vfolder invitation') + + # Get target virtual folder. + query = ( + sa.select([vfolders.c.name]) + .select_from(vfolders) + .where(vfolders.c.id == invitation.vfolder) + ) + result = await conn.execute(query) + target_vfolder = result.first() + if target_vfolder is None: + raise VFolderNotFound + + # Prevent accepting vfolder with duplicated name. 
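As an aside on the TODO above: with SQLAlchemy 1.4's async core, passing a list of parameter dictionaries to `conn.execute()` issues an `executemany()`-style batch, which sidesteps the multirow-`VALUES` compilation error mentioned in the comment. A minimal sketch under those assumptions; the table object is passed in and the per-invitee duplicate check is simplified away, so this is not the patch's actual code:

```python
import uuid
import sqlalchemy as sa


async def insert_pending_invitations(conn, invitations_table, *, vfolder_id,
                                     inviter, invitees, permission, pending_state):
    """Batch-insert one pending invitation per invitee in a single round-trip."""
    rows = [
        {
            'id': uuid.uuid4().hex,
            'permission': permission,
            'vfolder': vfolder_id,
            'inviter': inviter,
            'invitee': invitee,
            'state': pending_state,
        }
        for invitee in set(invitees)
    ]
    if rows:
        # A list of dicts makes SQLAlchemy use executemany(), so no
        # dialect-specific multirow VALUES clause needs to be compiled.
        await conn.execute(sa.insert(invitations_table), rows)
    return [row['invitee'] for row in rows]
```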
+        j = sa.join(
+            vfolders, vfolder_permissions,
+            vfolders.c.id == vfolder_permissions.c.vfolder,
+            isouter=True,
+        )
+        query = (
+            sa.select([sa.func.count()])
+            .select_from(j)
+            .where(
+                ((vfolders.c.user == user_uuid) |
+                 (vfolder_permissions.c.user == user_uuid)) &
+                (vfolders.c.name == target_vfolder.name),
+            )
+        )
+        result = await conn.execute(query)
+        if result.scalar() > 0:
+            raise VFolderAlreadyExists
+
+        # Create permission relation between the vfolder and the invitee.
+        query = (sa.insert(vfolder_permissions, {
+            'permission': VFolderPermission(invitation.permission),
+            'vfolder': invitation.vfolder,
+            'user': user_uuid,
+        }))
+        await conn.execute(query)
+
+        # Clear used invitation.
+        query = (
+            sa.update(vfolder_invitations)
+            .where(vfolder_invitations.c.id == inv_id)
+            .values(state=VFolderInvitationState.ACCEPTED)
+        )
+        await conn.execute(query)
+    return web.json_response({})
+
+
+@auth_required
+@server_status_required(ALL_ALLOWED)
+@check_api_params(
+    t.Dict({
+        t.Key('inv_id'): t.String,
+    }))
+async def delete_invitation(request: web.Request, params: Any) -> web.Response:
+    root_ctx: RootContext = request.app['_root.context']
+    access_key = request['keypair']['access_key']
+    request_email = request['user']['email']
+    inv_id = params['inv_id']
+    log.info('VFOLDER.DELETE_INVITATION (ak:{}, inv:{})', access_key, inv_id)
+    try:
+        async with root_ctx.db.begin() as conn:
+            query = (
+                sa.select([
+                    vfolder_invitations.c.inviter,
+                    vfolder_invitations.c.invitee,
+                ])
+                .select_from(vfolder_invitations)
+                .where(
+                    (vfolder_invitations.c.id == inv_id) &
+                    (vfolder_invitations.c.state == VFolderInvitationState.PENDING),
+                )
+            )
+            result = await conn.execute(query)
+            row = result.first()
+            if row is None:
+                raise ObjectNotFound(object_name='vfolder invitation')
+            if request_email == row.inviter:
+                state = VFolderInvitationState.CANCELED
+            elif request_email == row.invitee:
+                state = VFolderInvitationState.REJECTED
+            else:
+                raise GenericForbidden('Cannot change another user\'s invitation')
+            query = (
+                sa.update(vfolder_invitations)
+                .values(state=state)
+                .where(vfolder_invitations.c.id == inv_id)
+            )
+            await conn.execute(query)
+    except sa.exc.IntegrityError as e:
+        raise InternalServerError(f'integrity error: {e}')
+    except (asyncio.CancelledError, asyncio.TimeoutError):
+        raise
+    except Exception as e:
+        raise InternalServerError(f'unexpected error: {e}')
+    return web.json_response({})
+
+
+@admin_required
+@server_status_required(ALL_ALLOWED)
+@check_api_params(
+    t.Dict({
+        t.Key('permission', default='rw'): VFolderPermissionValidator,
+        t.Key('emails'): t.List(t.String),
+    }),
+)
+async def share(request: web.Request, params: Any) -> web.Response:
+    """
+    Share a group folder with users, overriding their permission.
+
+    This will create vfolder_permission(s) relation directly without
+    creating invitation(s). Only group-type vfolders are allowed to
+    be shared directly.
+    """
+    root_ctx: RootContext = request.app['_root.context']
+    access_key = request['keypair']['access_key']
+    folder_name = request.match_info['name']
+    log.info('VFOLDER.SHARE (ak:{}, vf:{}, perm:{}, users:{})',
+             access_key, folder_name, params['permission'], ','.join(params['emails']))
+    async with root_ctx.db.begin() as conn:
+        from ..models import association_groups_users as agus
+
+        # Get the group-type virtual folder.
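The invitation handlers above all gate on the PENDING state, and delete_invitation picks CANCELED or REJECTED depending on whether the requester is the inviter or the invitee. The rule, reduced to a small standalone function; the enum values here are illustrative and may not match the real VFolderInvitationState in the manager's models:

```python
import enum


class InvitationState(enum.Enum):
    PENDING = 'pending'
    CANCELED = 'canceled'
    REJECTED = 'rejected'
    ACCEPTED = 'accepted'


def resolve_withdrawal(requester_email: str, inviter: str, invitee: str,
                       current: InvitationState) -> InvitationState:
    """Decide the terminal state when someone withdraws a pending invitation."""
    if current is not InvitationState.PENDING:
        raise ValueError('only pending invitations can be withdrawn')
    if requester_email == inviter:
        return InvitationState.CANCELED
    if requester_email == invitee:
        return InvitationState.REJECTED
    raise PermissionError("cannot change another user's invitation")


# resolve_withdrawal('a@x.com', 'a@x.com', 'b@x.com', InvitationState.PENDING)
# -> InvitationState.CANCELED
```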
+ query = ( + sa.select([vfolders.c.id, vfolders.c.ownership_type, vfolders.c.group]) + .select_from(vfolders) + .where( + (vfolders.c.ownership_type == VFolderOwnershipType.GROUP) & + (vfolders.c.name == folder_name), + ) + ) + result = await conn.execute(query) + vf_infos = result.fetchall() + if len(vf_infos) < 1: + raise VFolderNotFound('Only project folders are directly sharable.') + if len(vf_infos) > 1: + raise InternalServerError(f'Multiple project folders found: {folder_name}') + vf_info = vf_infos[0] + + # Convert users' emails to uuids and check if user belong to the group of vfolder. + j = users.join(agus, users.c.uuid == agus.c.user_id) + query = ( + sa.select([users.c.uuid, users.c.email]) + .select_from(j) + .where( + (users.c.email.in_(params['emails'])) & + (users.c.email != request['user']['email']) & + (agus.c.group_id == vf_info['group']), + ) + ) + result = await conn.execute(query) + user_info = result.fetchall() + users_to_share = [u['uuid'] for u in user_info] + emails_to_share = [u['email'] for u in user_info] + if len(user_info) < 1: + raise ObjectNotFound(object_name='user') + if len(user_info) < len(params['emails']): + users_not_in_vfolder_group = list(set(params['emails']) - set(emails_to_share)) + raise ObjectNotFound( + 'Some user does not belong to folder\'s group: ' + ','.join(users_not_in_vfolder_group), + object_name='user', + ) + + # Do not share to users who have already been shared the folder. + query = ( + sa.select([vfolder_permissions.c.user]) + .select_from(vfolder_permissions) + .where( + (vfolder_permissions.c.user.in_(users_to_share)) & + (vfolder_permissions.c.vfolder == vf_info['id']), + ) + ) + result = await conn.execute(query) + users_not_to_share = [u.user for u in result.fetchall()] + users_to_share = list(set(users_to_share) - set(users_not_to_share)) + + # Create vfolder_permission(s). + for _user in users_to_share: + query = (sa.insert(vfolder_permissions, { + 'permission': params['permission'], + 'vfolder': vf_info['id'], + 'user': _user, + })) + await conn.execute(query) + # Update existing vfolder_permission(s). + for _user in users_not_to_share: + query = ( + sa.update(vfolder_permissions) + .values(permission=params['permission']) + .where(vfolder_permissions.c.vfolder == vf_info['id']) + .where(vfolder_permissions.c.user == _user) + ) + await conn.execute(query) + + return web.json_response({'shared_emails': emails_to_share}, status=201) + + +@admin_required +@server_status_required(ALL_ALLOWED) +@check_api_params( + t.Dict({ + t.Key('emails'): t.List(t.String), + }), +) +async def unshare(request: web.Request, params: Any) -> web.Response: + """ + Unshare a group folder from users. + """ + root_ctx: RootContext = request.app['_root.context'] + access_key = request['keypair']['access_key'] + folder_name = request.match_info['name'] + log.info('VFOLDER.UNSHARE (ak:{}, vf:{}, users:{})', + access_key, folder_name, ','.join(params['emails'])) + async with root_ctx.db.begin() as conn: + # Get the group-type virtual folder. + query = ( + sa.select([vfolders.c.id]) + .select_from(vfolders) + .where( + (vfolders.c.ownership_type == VFolderOwnershipType.GROUP) & + (vfolders.c.name == folder_name), + ) + ) + result = await conn.execute(query) + vf_infos = result.fetchall() + if len(vf_infos) < 1: + raise VFolderNotFound('Only project folders are directly unsharable.') + if len(vf_infos) > 1: + raise InternalServerError(f'Multiple project folders found: {folder_name}') + vf_info = vf_infos[0] + + # Convert users' emails to uuids. 
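To summarize the bookkeeping that share() performs above: requested emails are resolved against the vfolder's group membership, then split into users who need a new vfolder_permission row and users whose existing row is only updated, while unknown emails trigger an error. A standalone sketch of that set arithmetic with hypothetical names:

```python
def plan_share(requested_emails: set[str],
               group_members: dict[str, str],
               already_shared_uuids: set[str]) -> tuple[set[str], set[str], set[str]]:
    """Return (uuids_to_insert, uuids_to_update, unknown_emails)."""
    unknown = requested_emails - group_members.keys()
    requested_uuids = {group_members[email]
                       for email in requested_emails & group_members.keys()}
    to_update = requested_uuids & already_shared_uuids
    to_insert = requested_uuids - already_shared_uuids
    return to_insert, to_update, unknown


# plan_share({'a@x.com', 'b@x.com'}, {'a@x.com': 'uuid-a'}, set())
# -> ({'uuid-a'}, set(), {'b@x.com'})
```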
+ query = ( + sa.select([users.c.uuid]) + .select_from(users) + .where(users.c.email.in_(params['emails'])) + ) + result = await conn.execute(query) + users_to_unshare = [u['uuid'] for u in result.fetchall()] + if len(users_to_unshare) < 1: + raise ObjectNotFound(object_name='user(s).') + + # Delete vfolder_permission(s). + query = ( + sa.delete(vfolder_permissions) + .where( + (vfolder_permissions.c.vfolder == vf_info['id']) & + (vfolder_permissions.c.user.in_(users_to_unshare)), + ) + ) + await conn.execute(query) + return web.json_response({'unshared_emails': params['emails']}, status=200) + + +@auth_required +@server_status_required(ALL_ALLOWED) +async def delete(request: web.Request) -> web.Response: + root_ctx: RootContext = request.app['_root.context'] + folder_name = request.match_info['name'] + access_key = request['keypair']['access_key'] + domain_name = request['user']['domain_name'] + user_role = request['user']['role'] + user_uuid = request['user']['uuid'] + allowed_vfolder_types = await root_ctx.shared_config.get_vfolder_types() + log.info('VFOLDER.DELETE (ak:{}, vf:{})', access_key, folder_name) + async with root_ctx.db.begin() as conn: + entries = await query_accessible_vfolders( + conn, + user_uuid, + user_role=user_role, + domain_name=domain_name, + allowed_vfolder_types=allowed_vfolder_types, + ) + for entry in entries: + if entry['name'] == folder_name: + # Folder owner OR user who have DELETE permission can delete folder. + if ( + not entry['is_owner'] + and entry['permission'] != VFolderPermission.RW_DELETE + ): + raise InvalidAPIParameters( + 'Cannot delete the vfolder ' + 'that is not owned by myself.') + break + else: + raise InvalidAPIParameters('No such vfolder.') + folder_host = entry['host'] + folder_id = entry['id'] + query = (sa.delete(vfolders).where(vfolders.c.id == folder_id)) + await conn.execute(query) + # fs-level deletion may fail or take longer time + # but let's complete the db transaction to reflect that it's deleted. + proxy_name, volume_name = root_ctx.storage_manager.split_host(folder_host) + async with root_ctx.storage_manager.request( + proxy_name, 'POST', 'folder/delete', + json={ + 'volume': volume_name, + 'vfid': str(folder_id), + }, + ): + pass + return web.Response(status=204) + + +@auth_required +@server_status_required(ALL_ALLOWED) +@vfolder_permission_required(VFolderPermission.READ_ONLY) +async def leave(request: web.Request, row: VFolderRow) -> web.Response: + """ + Leave a shared vfolder. + + Cannot leave a group vfolder or a vfolder that the requesting user owns. 
+ """ + if row['ownership_type'] == VFolderOwnershipType.GROUP: + raise InvalidAPIParameters('Cannot leave a group vfolder.') + if row['is_owner']: + raise InvalidAPIParameters('Cannot leave a vfolder owned by the requesting user.') + + root_ctx: RootContext = request.app['_root.context'] + access_key = request['keypair']['access_key'] + user_uuid = request['user']['uuid'] + vfolder_id = row['id'] + perm = None + log.info('VFOLDER.LEAVE(ak:{}, vfid:{}, uid:{}, perm:{})', + access_key, vfolder_id, user_uuid, perm) + async with root_ctx.db.begin() as conn: + query = ( + sa.delete(vfolder_permissions) + .where(vfolder_permissions.c.vfolder == vfolder_id) + .where(vfolder_permissions.c.user == user_uuid) + ) + await conn.execute(query) + resp = {'msg': 'left the shared vfolder'} + return web.json_response(resp, status=200) + + +@auth_required +@server_status_required(ALL_ALLOWED) +@vfolder_permission_required(VFolderPermission.READ_ONLY) +@check_api_params( + t.Dict({ + t.Key('cloneable', default=False): t.Bool, + t.Key('target_name'): tx.Slug(allow_dot=True), + t.Key('target_host', default=None) >> 'folder_host': t.String | t.Null, + t.Key('usage_mode', default='general'): tx.Enum(VFolderUsageMode) | t.Null, + t.Key('permission', default='rw'): tx.Enum(VFolderPermission) | t.Null, + }), +) +async def clone(request: web.Request, params: Any, row: VFolderRow) -> web.Response: + resp: Dict[str, Any] = {} + root_ctx: RootContext = request.app['_root.context'] + access_key = request['keypair']['access_key'] + user_role = request['user']['role'] + user_uuid = request['user']['uuid'] + resource_policy = request['keypair']['resource_policy'] + domain_name = request['user']['domain_name'] + log.info('VFOLDER.CLONE (ak:{}, vf:{}, vft:{}, vfh:{}, umod:{}, perm:{})', + access_key, row['name'], params['target_name'], params['folder_host'], + params['usage_mode'].value, params['permission'].value) + source_folder_host = row['host'] + target_folder_host = params['folder_host'] + source_proxy_name, source_volume_name = root_ctx.storage_manager.split_host(source_folder_host) + target_proxy_name, target_volume_name = root_ctx.storage_manager.split_host(target_folder_host) + + # check if the source vfolder is allowed to be cloned + if not row['cloneable']: + raise GenericForbidden('The source vfolder is not permitted to be cloned.') + + if not target_folder_host: + target_folder_host = \ + await root_ctx.shared_config.etcd.get('volumes/default_host') + if not target_folder_host: + raise InvalidAPIParameters( + 'You must specify the vfolder host ' + 'because the default host is not configured.') + + allowed_vfolder_types = await root_ctx.shared_config.get_vfolder_types() + for vf_type in allowed_vfolder_types: + if vf_type not in ('user', 'group'): + raise ServerMisconfiguredError( + f'Invalid vfolder type(s): {str(allowed_vfolder_types)}.' + ' Only "user" or "group" is allowed.') + + if not verify_vfolder_name(params['target_name']): + raise InvalidAPIParameters(f'{params["target_name"]} is reserved for internal operations.') + + if source_proxy_name != target_proxy_name: + raise InvalidAPIParameters('proxy name of source and target vfolders must be equal.') + + async with root_ctx.db.begin() as conn: + allowed_hosts = await get_allowed_vfolder_hosts_by_user( + conn, resource_policy, domain_name, user_uuid, + ) + # TODO: handle legacy host lists assuming that volume names don't overlap? 
+ if target_folder_host not in allowed_hosts: + raise InvalidAPIParameters('You are not allowed to use this vfolder host.') + + # Check resource policy's max_vfolder_count + if resource_policy['max_vfolder_count'] > 0: + query = ( + sa.select([sa.func.count()]) + .where(vfolders.c.user == user_uuid) + ) + result = await conn.scalar(query) + if result >= resource_policy['max_vfolder_count']: + raise InvalidAPIParameters('You cannot create more vfolders.') + + # Prevent creation of vfolder with duplicated name. + extra_vf_conds = [vfolders.c.name == params['target_name']] + extra_vf_conds.append(vfolders.c.host == target_folder_host) + entries = await query_accessible_vfolders( + conn, user_uuid, + user_role=user_role, domain_name=domain_name, + allowed_vfolder_types=allowed_vfolder_types, + extra_vf_conds=(sa.and_(*extra_vf_conds)), + ) + if len(entries) > 0: + raise VFolderAlreadyExists + if params['target_name'].startswith('.'): + dotfiles, _ = await query_owned_dotfiles(conn, access_key) + for dotfile in dotfiles: + if params['target_name'] == dotfile['path']: + raise InvalidAPIParameters('vFolder name conflicts with your dotfile.') + + if 'user' not in allowed_vfolder_types: + raise InvalidAPIParameters('user vfolder cannot be created in this host') + + # Generate the ID of the destination vfolder. + # TODO: If we refactor to use ORM, the folder ID will be created from the database by inserting + # the actual object (with RETURNING clause). In that case, we need to temporarily + # mark the object to be "unusable-yet" until the storage proxy craetes the destination + # vfolder. After done, we need to make another transaction to clear the unusable state. + folder_id = uuid.uuid4() + + # Create the destination vfolder. + # (assuming that this operation finishes quickly!) + # TODO: copy vfolder options + async with root_ctx.storage_manager.request( + target_folder_host, 'POST', 'folder/create', + json={ + 'volume': target_volume_name, + 'vfid': str(folder_id), + # 'options': {'quota': params['quota']}, + }, + ): + pass + + # Insert the row for the destination vfolder. + user_uuid = str(user_uuid) + group_uuid = None + ownership_type = 'user' + insert_values = { + 'id': folder_id, + 'name': params['target_name'], + 'usage_mode': params['usage_mode'], + 'permission': params['permission'], + 'last_used': None, + 'host': target_folder_host, + 'creator': request['user']['email'], + 'ownership_type': VFolderOwnershipType(ownership_type), + 'user': user_uuid, + 'group': group_uuid, + 'unmanaged_path': '', + 'cloneable': params['cloneable'], + } + insert_query = sa.insert(vfolders, insert_values) + try: + result = await conn.execute(insert_query) + except sa.exc.DataError: + # TODO: pass exception info + raise InvalidAPIParameters + + # Start the clone operation as a background task. + async def _clone_bgtask(reporter: ProgressReporter) -> None: + async with root_ctx.storage_manager.request( + source_folder_host, 'POST', 'folder/clone', + json={ + 'src_volume': source_volume_name, + 'src_vfid': str(row['id']), + 'dst_volume': target_volume_name, + 'dst_vfid': str(folder_id), + }, + ): + pass + + task_id = await root_ctx.background_task_manager.start(_clone_bgtask) + + # Return the information about the destination vfolder. 
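The clone handler above returns immediately with a bgtask_id while the storage-proxy copy proceeds asynchronously. The overall shape of that pattern, reduced to plain asyncio; this is not the project's BackgroundTaskManager, whose start() also injects a ProgressReporter and publishes progress events:

```python
import asyncio
import uuid
from typing import Awaitable, Callable, Dict


class MiniTaskManager:
    """Track fire-and-forget coroutines by a generated task ID (illustrative only)."""

    def __init__(self) -> None:
        self._tasks: Dict[uuid.UUID, asyncio.Task] = {}

    async def start(self, func: Callable[[], Awaitable[None]]) -> uuid.UUID:
        task_id = uuid.uuid4()
        task = asyncio.create_task(func())
        # Forget the handle once the task finishes, successfully or not.
        task.add_done_callback(lambda _t: self._tasks.pop(task_id, None))
        self._tasks[task_id] = task
        return task_id


async def demo() -> None:
    mgr = MiniTaskManager()

    async def _clone_bgtask() -> None:
        await asyncio.sleep(0.1)  # stands in for the storage-proxy 'folder/clone' call

    task_id = await mgr.start(_clone_bgtask)
    print('bgtask id:', task_id)   # returned to the client right away
    await asyncio.sleep(0.2)       # the copy keeps running in the background

# asyncio.run(demo())
```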
+ resp = { + 'id': folder_id.hex, + 'name': params['target_name'], + 'host': target_folder_host, + 'usage_mode': params['usage_mode'].value, + 'permission': params['permission'].value, + 'creator': request['user']['email'], + 'ownership_type': ownership_type, + 'user': user_uuid, + 'group': group_uuid, + 'cloneable': params['cloneable'], + 'bgtask_id': str(task_id), + } + return web.json_response(resp, status=201) + + +@auth_required +@server_status_required(READ_ALLOWED) +@check_api_params( + t.Dict({ + tx.AliasedKey(['vfolder_id', 'vfolderId']): tx.UUID, + }), +) +async def list_shared_vfolders(request: web.Request, params: Any) -> web.Response: + """ + List shared vfolders. + + Not available for group vfolders. + """ + root_ctx: RootContext = request.app['_root.context'] + access_key = request['keypair']['access_key'] + target_vfid = params['vfolder_id'] + log.info('VFOLDER.LIST_SHARED_VFOLDERS (ak:{})', access_key) + async with root_ctx.db.begin() as conn: + j = ( + vfolder_permissions + .join(vfolders, vfolders.c.id == vfolder_permissions.c.vfolder) + .join(users, users.c.uuid == vfolder_permissions.c.user) + ) + query = ( + sa.select([vfolder_permissions, vfolders.c.id, vfolders.c.name, users.c.email]) + .select_from(j) + ) + if target_vfid is not None: + query = query.where(vfolders.c.id == target_vfid) + result = await conn.execute(query) + shared_list = result.fetchall() + shared_info = [] + for shared in shared_list: + shared_info.append({ + 'vfolder_id': str(shared.id), + 'vfolder_name': str(shared.name), + 'shared_by': request['user']['email'], + 'shared_to': { + 'uuid': str(shared.user), + 'email': shared.email, + }, + 'perm': shared.permission.value, + }) + resp = {'shared': shared_info} + return web.json_response(resp, status=200) + + +@auth_required +@server_status_required(ALL_ALLOWED) +@check_api_params( + t.Dict({ + t.Key('vfolder'): tx.UUID, + t.Key('user'): tx.UUID, + tx.AliasedKey(['perm', 'permission']): VFolderPermissionValidator | t.Null, + }), +) +async def update_shared_vfolder(request: web.Request, params: Any) -> web.Response: + """ + Update permission for shared vfolders. + + If params['perm'] is None, remove user's permission for the vfolder. + """ + root_ctx: RootContext = request.app['_root.context'] + access_key = request['keypair']['access_key'] + vfolder_id = params['vfolder'] + user_uuid = params['user'] + perm = params['perm'] + log.info('VFOLDER.UPDATE_SHARED_VFOLDER(ak:{}, vfid:{}, uid:{}, perm:{})', + access_key, vfolder_id, user_uuid, perm) + async with root_ctx.db.begin() as conn: + if perm is not None: + query = ( + sa.update(vfolder_permissions) + .values(permission=perm) + .where(vfolder_permissions.c.vfolder == vfolder_id) + .where(vfolder_permissions.c.user == user_uuid) + ) + else: + query = ( + sa.delete(vfolder_permissions) + .where(vfolder_permissions.c.vfolder == vfolder_id) + .where(vfolder_permissions.c.user == user_uuid) + ) + await conn.execute(query) + resp = {'msg': 'shared vfolder permission updated'} + return web.json_response(resp, status=200) + + +@superadmin_required +@server_status_required(READ_ALLOWED) +@check_api_params( + t.Dict({ + t.Key('fstab_path', default=None): t.String | t.Null, + t.Key('agent_id', default=None): t.String | t.Null, + }), +) +async def get_fstab_contents(request: web.Request, params: Any) -> web.Response: + """ + Return the contents of `/etc/fstab` file. 
+ """ + access_key = request['keypair']['access_key'] + log.info('VFOLDER.GET_FSTAB_CONTENTS(ak:{}, ag:{})', access_key, params['agent_id']) + if params['fstab_path'] is None: + params['fstab_path'] = '/etc/fstab' + if params['agent_id'] is not None: + # Return specific agent's fstab. + watcher_info = await get_watcher_info(request, params['agent_id']) + try: + client_timeout = aiohttp.ClientTimeout(total=10.0) + async with aiohttp.ClientSession(timeout=client_timeout) as sess: + headers = {'X-BackendAI-Watcher-Token': watcher_info['token']} + url = watcher_info['addr'] / 'fstab' + async with sess.get(url, headers=headers, params=params) as watcher_resp: + if watcher_resp.status == 200: + content = await watcher_resp.text() + resp = { + 'content': content, + 'node': 'agent', + 'node_id': params['agent_id'], + } + return web.json_response(resp) + else: + message = await watcher_resp.text() + raise BackendAgentError( + 'FAILURE', f'({watcher_resp.status}: {watcher_resp.reason}) {message}') + except asyncio.CancelledError: + raise + except asyncio.TimeoutError: + log.error('VFOLDER.GET_FSTAB_CONTENTS(u:{}): timeout from watcher (agent:{})', + access_key, params['agent_id']) + raise BackendAgentError('TIMEOUT', 'Could not fetch fstab data from agent') + except Exception: + log.exception('VFOLDER.GET_FSTAB_CONTENTS(u:{}): ' + 'unexpected error while reading from watcher (agent:{})', + access_key, params['agent_id']) + raise InternalServerError + else: + resp = { + 'content': + "# Since Backend.AI 20.09, reading the manager fstab is no longer supported.", + 'node': 'manager', + 'node_id': 'manager', + } + return web.json_response(resp) + + +@superadmin_required +@server_status_required(READ_ALLOWED) +async def list_mounts(request: web.Request) -> web.Response: + """ + List all mounted vfolder hosts in vfroot. + + All mounted hosts from connected (ALIVE) agents are also gathered. + Generally, agents should be configured to have same hosts structure, + but newly introduced one may not. + """ + root_ctx: RootContext = request.app['_root.context'] + access_key = request['keypair']['access_key'] + log.info('VFOLDER.LIST_MOUNTS(ak:{})', access_key) + mount_prefix = await root_ctx.shared_config.get_raw('volumes/_mount') + if mount_prefix is None: + mount_prefix = '/mnt' + + # NOTE: Changed in 20.09: the manager instances no longer have mountpoints. + all_volumes = [*await root_ctx.storage_manager.get_all_volumes()] + all_mounts = [ + volume_data['path'] + for proxy_name, volume_data in all_volumes + ] + all_vfolder_hosts = [ + f"{proxy_name}:{volume_data['name']}" + for proxy_name, volume_data in all_volumes + ] + resp: MutableMapping[str, Any] = { + 'manager': { + 'success': True, + 'mounts': all_mounts, + 'message': '(legacy)', + }, + 'storage-proxy': { + 'success': True, + 'mounts': [*zip(all_vfolder_hosts, all_mounts)], + 'message': '', + }, + 'agents': {}, + } + + # Scan mounted vfolder hosts for connected agents. 
+ async def _fetch_mounts( + sema: asyncio.Semaphore, + sess: aiohttp.ClientSession, + agent_id: str, + ) -> Tuple[str, Mapping]: + async with sema: + watcher_info = await get_watcher_info(request, agent_id) + headers = {'X-BackendAI-Watcher-Token': watcher_info['token']} + url = watcher_info['addr'] / 'mounts' + try: + async with sess.get(url, headers=headers) as watcher_resp: + if watcher_resp.status == 200: + data = { + 'success': True, + 'mounts': await watcher_resp.json(), + 'message': '', + } + else: + data = { + 'success': False, + 'mounts': [], + 'message': await watcher_resp.text(), + } + return (agent_id, data) + except asyncio.CancelledError: + raise + except asyncio.TimeoutError: + log.error( + 'VFOLDER.LIST_MOUNTS(u:{}): timeout from watcher (agent:{})', + access_key, agent_id, + ) + raise + except Exception: + log.exception( + 'VFOLDER.LIST_MOUNTS(u:{}): ' + 'unexpected error while reading from watcher (agent:{})', + access_key, agent_id, + ) + raise + + async with root_ctx.db.begin() as conn: + query = ( + sa.select([agents.c.id]) + .select_from(agents) + .where(agents.c.status == AgentStatus.ALIVE) + ) + result = await conn.execute(query) + rows = result.fetchall() + + client_timeout = aiohttp.ClientTimeout(total=10.0) + async with aiohttp.ClientSession(timeout=client_timeout) as sess: + sema = asyncio.Semaphore(8) + mounts = await asyncio.gather(*[ + _fetch_mounts(sema, sess, row.id) for row in rows + ], return_exceptions=True) + for mount in mounts: + if isinstance(mount, Exception): + # exceptions are already logged. + continue + resp['agents'][mount[0]] = mount[1] + + return web.json_response(resp, status=200) + + +@superadmin_required +@server_status_required(ALL_ALLOWED) +@check_api_params( + t.Dict({ + t.Key('fs_location'): t.String, + t.Key('name'): t.String, + t.Key('fs_type', default='nfs'): t.String, + t.Key('options', default=None): t.String | t.Null, + t.Key('scaling_group', default=None): t.String | t.Null, + t.Key('fstab_path', default=None): t.String | t.Null, + t.Key('edit_fstab', default=False): t.ToBool, + }), +) +async def mount_host(request: web.Request, params: Any) -> web.Response: + """ + Mount device into vfolder host. + + Mount a device (eg: nfs) located at `fs_location` into `/name` in the + host machines (manager and all agents). `fs_type` can be specified by requester, + which fallbaks to 'nfs'. + + If `scaling_group` is specified, try to mount for agents in the scaling group. + """ + root_ctx: RootContext = request.app['_root.context'] + access_key = request['keypair']['access_key'] + log_fmt = 'VFOLDER.MOUNT_HOST(ak:{}, name:{}, fs:{}, sg:{})' + log_args = (access_key, params['name'], params['fs_location'], params['scaling_group']) + log.info(log_fmt, *log_args) + mount_prefix = await root_ctx.shared_config.get_raw('volumes/_mount') + if mount_prefix is None: + mount_prefix = '/mnt' + + # NOTE: Changed in 20.09: the manager instances no longer have mountpoints. + resp: MutableMapping[str, Any] = { + 'manager': { + 'success': True, + 'message': 'Managers do not have mountpoints since v20.09.', + }, + 'agents': {}, + } + + # Mount on running agents. 
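_fetch_mounts above and the _mount/_umount helpers below all follow the same fan-out recipe: one watcher call per ALIVE agent, bounded by an asyncio.Semaphore, gathered with return_exceptions=True so a single failing agent does not abort the rest. The skeleton of that pattern with the actual watcher call abstracted into a work callback (agent IDs and the callback are placeholders):

```python
import asyncio
from typing import Awaitable, Callable, Dict, Tuple


async def fan_out(
    agent_ids: list[str],
    work: Callable[[str], Awaitable[dict]],
    *,
    concurrency: int = 8,
) -> Dict[str, dict]:
    """Run `work(agent_id)` for every agent, at most `concurrency` at a time."""
    sema = asyncio.Semaphore(concurrency)

    async def _one(agent_id: str) -> Tuple[str, dict]:
        async with sema:
            return agent_id, await work(agent_id)

    results = await asyncio.gather(
        *(_one(agent_id) for agent_id in agent_ids),
        return_exceptions=True,  # one failing agent must not abort the others
    )
    collected: Dict[str, dict] = {}
    for item in results:
        if isinstance(item, Exception):
            continue  # assume the worker already logged its own failure
        agent_id, data = item
        collected[agent_id] = data
    return collected


# Example:
#   async def probe(agent_id: str) -> dict:
#       return {'success': True, 'mounts': []}
#   asyncio.run(fan_out(['i-agent01', 'i-agent02'], probe))
```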
+ async with root_ctx.db.begin() as conn: + query = ( + sa.select([agents.c.id]) + .select_from(agents) + .where(agents.c.status == AgentStatus.ALIVE) + ) + if params['scaling_group'] is not None: + query = query.where(agents.c.scaling == params['scaling_group']) + result = await conn.execute(query) + rows = result.fetchall() + + async def _mount( + sema: asyncio.Semaphore, + sess: aiohttp.ClientSession, + agent_id: str, + ) -> Tuple[str, Mapping]: + async with sema: + watcher_info = await get_watcher_info(request, agent_id) + try: + headers = {'X-BackendAI-Watcher-Token': watcher_info['token']} + url = watcher_info['addr'] / 'mounts' + async with sess.post(url, json=params, headers=headers) as resp: + if resp.status == 200: + data = { + 'success': True, + 'message': await resp.text(), + } + else: + data = { + 'success': False, + 'message': await resp.text(), + } + return (agent_id, data) + except asyncio.CancelledError: + raise + except asyncio.TimeoutError: + log.error( + log_fmt + ': timeout from watcher (ag:{})', + *log_args, agent_id, + ) + raise + except Exception: + log.exception( + log_fmt + ': unexpected error while reading from watcher (ag:{})', + *log_args, agent_id, + ) + raise + + client_timeout = aiohttp.ClientTimeout(total=10) + async with aiohttp.ClientSession(timeout=client_timeout) as sess: + sema = asyncio.Semaphore(8) + results = await asyncio.gather(*[ + _mount(sema, sess, row.id) for row in rows + ], return_exceptions=True) + for result in results: + if isinstance(result, Exception): + # exceptions are already logged. + continue + resp['agents'][result[0]] = result[1] + + return web.json_response(resp, status=200) + + +@superadmin_required +@server_status_required(ALL_ALLOWED) +@check_api_params( + t.Dict({ + t.Key('name'): t.String, + t.Key('scaling_group', default=None): t.String | t.Null, + t.Key('fstab_path', default=None): t.String | t.Null, + t.Key('edit_fstab', default=False): t.ToBool, + }), +) +async def umount_host(request: web.Request, params: Any) -> web.Response: + """ + Unmount device from vfolder host. + + Unmount a device (eg: nfs) located at `/name` from the host machines + (manager and all agents). + + If `scaling_group` is specified, try to unmount for agents in the scaling group. + """ + root_ctx: RootContext = request.app['_root.context'] + access_key = request['keypair']['access_key'] + log_fmt = 'VFOLDER.UMOUNT_HOST(ak:{}, name:{}, sg:{})' + log_args = (access_key, params['name'], params['scaling_group']) + log.info(log_fmt, *log_args) + mount_prefix = await root_ctx.shared_config.get_raw('volumes/_mount') + if mount_prefix is None: + mount_prefix = '/mnt' + mountpoint = Path(mount_prefix) / params['name'] + assert Path(mount_prefix) != mountpoint + + async with root_ctx.db.begin() as conn, conn.begin(): + # Prevent unmount if target host is mounted to running kernels. 
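The guard that follows refuses to unmount while any non-terminated kernel still references the target name; kernel rows store their mounts as small sequences whose second element is the name being compared. As a pure function the check looks roughly like this, where the exact mount-tuple layout is an assumption inferred from the `m[1]` access below:

```python
from typing import Iterable, Optional, Sequence


def is_host_in_use(kernel_mounts: Iterable[Optional[Sequence[Sequence[str]]]],
                   target_name: str) -> bool:
    """True if any running kernel's mount list references `target_name`."""
    mounted: set[str] = set()
    for mounts in kernel_mounts:
        if mounts:
            mounted.update(m[1] for m in mounts)
    return target_name in mounted


# is_host_in_use([[('data1', 'data1', '/home/work/data1')], None], 'data1') -> True
```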
+ query = ( + sa.select([kernels.c.mounts]) + .select_from(kernels) + .where(kernels.c.status != KernelStatus.TERMINATED) + ) + result = await conn.execute(query) + _kernels = result.fetchall() + _mounted = set() + for kern in _kernels: + if kern.mounts: + _mounted.update([m[1] for m in kern.mounts]) + if params['name'] in _mounted: + return web.json_response({ + 'title': 'Target host is used in sessions', + 'message': 'Target host is used in sessions', + }, status=409) + + query = (sa.select([agents.c.id]) + .select_from(agents) + .where(agents.c.status == AgentStatus.ALIVE)) + if params['scaling_group'] is not None: + query = query.where(agents.c.scaling == params['scaling_group']) + result = await conn.execute(query) + _agents = result.fetchall() + + # Unmount from manager. + # NOTE: Changed in 20.09: the manager instances no longer have mountpoints. + resp: MutableMapping[str, Any] = { + 'manager': { + 'success': True, + 'message': 'Managers do not have mountpoints since v20.09.', + }, + 'agents': {}, + } + + # Unmount from running agents. + async def _umount( + sema: asyncio.Semaphore, + sess: aiohttp.ClientSession, + agent_id: str, + ) -> Tuple[str, Mapping]: + async with sema: + watcher_info = await get_watcher_info(request, agent_id) + try: + headers = {'X-BackendAI-Watcher-Token': watcher_info['token']} + url = watcher_info['addr'] / 'mounts' + async with sess.delete(url, json=params, headers=headers) as resp: + if resp.status == 200: + data = { + 'success': True, + 'message': await resp.text(), + } + else: + data = { + 'success': False, + 'message': await resp.text(), + } + return (agent_id, data) + except asyncio.CancelledError: + raise + except asyncio.TimeoutError: + log.error( + log_fmt + ': timeout from watcher (agent:{})', + *log_args, agent_id, + ) + raise + except Exception: + log.exception( + log_fmt + ': unexpected error while reading from watcher (agent:{})', + *log_args, agent_id, + ) + raise + + client_timeout = aiohttp.ClientTimeout(total=10.0) + async with aiohttp.ClientSession(timeout=client_timeout) as sess: + sema = asyncio.Semaphore(8) + results = await asyncio.gather(*[ + _umount(sema, sess, _agent.id) for _agent in _agents + ], return_exceptions=True) + for result in results: + if isinstance(result, Exception): + # exceptions are already logged. 
+ continue + resp['agents'][result[0]] = result[1] + + return web.json_response(resp, status=200) + + +async def init(app: web.Application) -> None: + pass + + +async def shutdown(app: web.Application) -> None: + pass + + +def create_app(default_cors_options): + app = web.Application() + app['prefix'] = 'folders' + app['api_versions'] = (2, 3, 4) + app.on_startup.append(init) + app.on_shutdown.append(shutdown) + cors = aiohttp_cors.setup(app, defaults=default_cors_options) + add_route = app.router.add_route + root_resource = cors.add(app.router.add_resource(r'')) + cors.add(root_resource.add_route('POST', create)) + cors.add(root_resource.add_route('GET', list_folders)) + cors.add(root_resource.add_route('DELETE', delete_by_id)) + vfolder_resource = cors.add(app.router.add_resource(r'/{name}')) + cors.add(vfolder_resource.add_route('GET', get_info)) + cors.add(vfolder_resource.add_route('DELETE', delete)) + cors.add(add_route('GET', r'/_/hosts', list_hosts)) + cors.add(add_route('GET', r'/_/all-hosts', list_all_hosts)) + cors.add(add_route('GET', r'/_/allowed-types', list_allowed_types)) + cors.add(add_route('GET', r'/_/all_hosts', list_all_hosts)) # legacy underbar + cors.add(add_route('GET', r'/_/allowed_types', list_allowed_types)) # legacy underbar + cors.add(add_route('GET', r'/_/perf-metric', get_volume_perf_metric)) + cors.add(add_route('POST', r'/{name}/rename', rename_vfolder)) + cors.add(add_route('POST', r'/{name}/update-options', update_vfolder_options)) + cors.add(add_route('POST', r'/{name}/mkdir', mkdir)) + cors.add(add_route('POST', r'/{name}/request-upload', create_upload_session)) + cors.add(add_route('POST', r'/{name}/request-download', create_download_session)) + cors.add(add_route('POST', r'/{name}/move-file', move_file)) + cors.add(add_route('POST', r'/{name}/rename-file', rename_file)) + cors.add(add_route('DELETE', r'/{name}/delete-files', delete_files)) + cors.add(add_route('POST', r'/{name}/rename_file', rename_file)) # legacy underbar + cors.add(add_route('DELETE', r'/{name}/delete_files', delete_files)) # legacy underbar + cors.add(add_route('GET', r'/{name}/files', list_files)) + cors.add(add_route('POST', r'/{name}/invite', invite)) + cors.add(add_route('POST', r'/{name}/leave', leave)) + cors.add(add_route('POST', r'/{name}/share', share)) + cors.add(add_route('DELETE', r'/{name}/unshare', unshare)) + cors.add(add_route('POST', r'/{name}/clone', clone)) + cors.add(add_route('GET', r'/invitations/list-sent', list_sent_invitations)) + cors.add(add_route('GET', r'/invitations/list_sent', list_sent_invitations)) # legacy underbar + cors.add(add_route('POST', r'/invitations/update/{inv_id}', update_invitation)) + cors.add(add_route('GET', r'/invitations/list', invitations)) + cors.add(add_route('POST', r'/invitations/accept', accept_invitation)) + cors.add(add_route('DELETE', r'/invitations/delete', delete_invitation)) + cors.add(add_route('GET', r'/_/shared', list_shared_vfolders)) + cors.add(add_route('POST', r'/_/shared', update_shared_vfolder)) + cors.add(add_route('GET', r'/_/fstab', get_fstab_contents)) + cors.add(add_route('GET', r'/_/mounts', list_mounts)) + cors.add(add_route('POST', r'/_/mounts', mount_host)) + cors.add(add_route('DELETE', r'/_/mounts', umount_host)) + cors.add(add_route('GET', r'/_/quota', get_quota)) + cors.add(add_route('POST', r'/_/quota', update_quota)) + cors.add(add_route('GET', r'/_/usage', get_usage)) + return app, [] diff --git a/src/ai/backend/manager/api/wsproxy.py b/src/ai/backend/manager/api/wsproxy.py new file mode 100644 
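create_app() above registers each modern hyphenated path next to a legacy underscore alias bound to the same handler. A compact way to express that convention in an aiohttp app, purely as an illustration; the patch registers each alias explicitly, which keeps the route table greppable:

```python
from aiohttp import web


async def rename_file(request: web.Request) -> web.Response:
    return web.json_response({})  # placeholder handler


def add_with_legacy_alias(app: web.Application, method: str, path: str, handler) -> None:
    """Register `path` and its legacy underscore spelling on the same handler."""
    app.router.add_route(method, path, handler)
    legacy_path = path.replace('-', '_')
    if legacy_path != path:
        app.router.add_route(method, legacy_path, handler)  # legacy underbar route


app = web.Application()
add_with_legacy_alias(app, 'POST', r'/{name}/rename-file', rename_file)
```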
index 0000000000..3d09ed8f4f --- /dev/null +++ b/src/ai/backend/manager/api/wsproxy.py @@ -0,0 +1,257 @@ +""" +WebSocket-based streaming kernel interaction APIs. +""" + +from __future__ import annotations + +from abc import ABCMeta, abstractmethod +import asyncio +import logging +from typing import ( + Any, + Awaitable, + Callable, + Optional, + Union, +) + +import aiohttp +import aiotools +from aiohttp import WSCloseCode +from aiohttp import web + +from ai.backend.common.logging import BraceStyleAdapter + +from ..config import DEFAULT_CHUNK_SIZE + +log = BraceStyleAdapter(logging.getLogger(__name__)) + + +class ServiceProxy(metaclass=ABCMeta): + """ + The abstract base class to implement service proxy handlers. + """ + + __slots__ = ( + 'ws', + 'host', + 'port', + 'downstream_cb', + 'upstream_cb', + 'ping_cb', + ) + + def __init__( + self, + down_ws: web.WebSocketResponse, + dest_host: str, + dest_port: int, + *, + downstream_callback: Callable[[Any], Awaitable[None]] = None, + upstream_callback: Callable[[Any], Awaitable[None]] = None, + ping_callback: Callable[[Any], Awaitable[None]] = None, + ) -> None: + self.ws = down_ws + self.host = dest_host + self.port = dest_port + self.downstream_cb = downstream_callback + self.upstream_cb = upstream_callback + self.ping_cb = ping_callback + + @abstractmethod + async def proxy(self) -> web.WebSocketResponse: + pass + + +class TCPProxy(ServiceProxy): + + __slots__ = ( + *ServiceProxy.__slots__, + 'down_task', + ) + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.down_task: Optional[asyncio.Task] = None + + async def proxy(self) -> web.WebSocketResponse: + try: + try: + log.debug('Trying to open proxied TCP connection to {}:{}', self.host, self.port) + reader, writer = await asyncio.open_connection(self.host, self.port) + except ConnectionRefusedError: + await self.ws.close(code=WSCloseCode.TRY_AGAIN_LATER) + return self.ws + except Exception: + log.exception("TCPProxy.proxy(): unexpected initial connection error") + await self.ws.close(code=WSCloseCode.INTERNAL_ERROR) + return self.ws + + async def downstream() -> None: + try: + while True: + try: + chunk = await reader.read(DEFAULT_CHUNK_SIZE) + if not chunk: + break + await self.ws.send_bytes(chunk) + except (RuntimeError, ConnectionResetError, + asyncio.CancelledError): + # connection interrupted by client-side + break + else: + if self.downstream_cb is not None: + await self.downstream_cb(chunk) + except asyncio.CancelledError: + pass + except Exception: + log.exception("TCPProxy.proxy(): unexpected downstream error") + finally: + await self.ws.close(code=WSCloseCode.GOING_AWAY) + + log.debug('TCPProxy connected {0}:{1}', self.host, self.port) + self.down_task = asyncio.create_task(downstream()) + async for msg in self.ws: + if msg.type == web.WSMsgType.BINARY: + try: + writer.write(msg.data) + await writer.drain() + except RuntimeError: + log.debug("Error on writing: Is it closed?") + if self.upstream_cb is not None: + await self.upstream_cb(msg.data) + elif msg.type == web.WSMsgType.PING: + await self.ws.pong(msg.data) + if self.ping_cb is not None: + await self.ping_cb(msg.data) + elif msg.type == web.WSMsgType.ERROR: + log.debug("TCPProxy.proxy(): websocket upstream error", exc_info=msg.data) + writer.close() + await writer.wait_closed() + + except asyncio.CancelledError: + pass + except Exception: + log.exception("TCPProxy.proxy(): unexpected upstream error") + finally: + if self.down_task is not None and not self.down_task.done(): + 
self.down_task.cancel() + await self.down_task + log.debug('websocket connection closed') + return self.ws + + +class WebSocketProxy: + __slots__ = ( + 'up_conn', 'down_conn', + 'upstream_buffer', 'upstream_buffer_task', + 'downstream_cb', 'upstream_cb', 'ping_cb', + ) + + up_conn: aiohttp.ClientWebSocketResponse + down_conn: web.WebSocketResponse + # FIXME: use __future__.annotations in Python 3.7+ + upstream_buffer: asyncio.Queue # contains: Tuple[Union[bytes, str], web.WSMsgType] + upstream_buffer_task: Optional[asyncio.Task] + downstream_cb: Callable[[str | bytes], Awaitable[None]] | None + upstream_cb: Callable[[str | bytes], Awaitable[None]] | None + ping_cb: Callable[[str | bytes], Awaitable[None]] | None + + def __init__( + self, + up_conn: aiohttp.ClientWebSocketResponse, + down_conn: web.WebSocketResponse, + *, + downstream_callback: Callable[[str | bytes], Awaitable[None]] = None, + upstream_callback: Callable[[str | bytes], Awaitable[None]] = None, + ping_callback: Callable[[str | bytes], Awaitable[None]] = None, + ): + self.up_conn = up_conn + self.down_conn = down_conn + self.upstream_buffer = asyncio.Queue() + self.upstream_buffer_task = None + self.downstream_cb = downstream_callback + self.upstream_cb = upstream_callback + self.ping_cb = ping_callback + + async def proxy(self) -> None: + asyncio.create_task(self.downstream()) + await self.upstream() + + async def upstream(self) -> None: + try: + async for msg in self.down_conn: + if msg.type in (web.WSMsgType.TEXT, web.WSMsgType.binary): + await self.write(msg.data, msg.type) + if self.upstream_cb is not None: + await self.upstream_cb(msg.data) + elif msg.type == web.WSMsgType.PING: + if self.ping_cb is not None: + await self.ping_cb(msg.data) + elif msg.type == aiohttp.WSMsgType.ERROR: + log.error("ws connection closed with exception {}", + self.up_conn.exception()) + break + elif msg.type == aiohttp.WSMsgType.CLOSE: + break + # here, client gracefully disconnected + except asyncio.CancelledError: + # here, client forcibly disconnected + raise + finally: + await self.close_downstream() + + async def downstream(self) -> None: + try: + async with aiotools.PersistentTaskGroup() as tg: + self.upstream_buffer_task = tg.create_task( + self.consume_upstream_buffer(), + ) + async for msg in self.up_conn: + if msg.type == aiohttp.WSMsgType.TEXT: + await self.down_conn.send_str(msg.data) + if self.downstream_cb is not None: + await asyncio.shield(tg.create_task(self.downstream_cb(msg.data))) + if msg.type == aiohttp.WSMsgType.BINARY: + await self.down_conn.send_bytes(msg.data) + if self.downstream_cb is not None: + await asyncio.shield(tg.create_task(self.downstream_cb(msg.data))) + elif msg.type == aiohttp.WSMsgType.CLOSED: + break + elif msg.type == aiohttp.WSMsgType.ERROR: + break + # here, server gracefully disconnected + except asyncio.CancelledError: + raise + except Exception: + log.exception('unexpected error') + finally: + await self.close_upstream() + + async def consume_upstream_buffer(self) -> None: + while True: + msg, tp = await self.upstream_buffer.get() + try: + if self.up_conn and not self.up_conn.closed: + if tp == aiohttp.WSMsgType.TEXT: + await self.up_conn.send_str(msg) + elif tp == aiohttp.WSMsgType.binary: + await self.up_conn.send_bytes(msg) + else: + await self.close_downstream() + finally: + self.upstream_buffer.task_done() + + async def write(self, msg: Union[bytes, str], tp: web.WSMsgType) -> None: + await self.upstream_buffer.put((msg, tp)) + + async def close_downstream(self) -> None: + if not 
self.down_conn.closed: + await self.down_conn.close() + + async def close_upstream(self) -> None: + if self.upstream_buffer_task: + self.upstream_buffer_task.cancel() + await self.upstream_buffer_task + if self.up_conn: + await self.up_conn.close() diff --git a/src/ai/backend/manager/cli/__init__.py b/src/ai/backend/manager/cli/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/ai/backend/manager/cli/__main__.py b/src/ai/backend/manager/cli/__main__.py new file mode 100644 index 0000000000..cf58b24d5e --- /dev/null +++ b/src/ai/backend/manager/cli/__main__.py @@ -0,0 +1,273 @@ +from __future__ import annotations + +import asyncio +import logging +import subprocess +import sys +from datetime import datetime +from functools import partial +from pathlib import Path + +import aioredis, aioredis.client +import click +import psycopg2 +import sqlalchemy as sa +from more_itertools import chunked +from setproctitle import setproctitle + +from ai.backend.common import redis as redis_helper +from ai.backend.common.cli import LazyGroup +from ai.backend.common.logging import BraceStyleAdapter +from ai.backend.common.validators import TimeDuration + +from ai.backend.manager.models import kernels +from ai.backend.manager.models.utils import connect_database + +from ..config import load as load_config +from ..models.keypair import generate_keypair as _gen_keypair +from .context import CLIContext, init_logger, redis_ctx + +log = BraceStyleAdapter(logging.getLogger('ai.backend.manager.cli')) + + +@click.group(invoke_without_command=True, context_settings={'help_option_names': ['-h', '--help']}) +@click.option('-f', '--config-path', '--config', type=Path, default=None, + help='The config file path. (default: ./manager.conf and /etc/backend.ai/manager.conf)') +@click.option('--debug', is_flag=True, + help='Enable the debug mode and override the global log level to DEBUG.') +@click.pass_context +def main(ctx, config_path, debug): + """ + Manager Administration CLI + """ + local_config = load_config(config_path) + setproctitle(f"backend.ai: manager.cli {local_config['etcd']['namespace']}") + ctx.obj = CLIContext( + logger=init_logger(local_config), + local_config=local_config, + ) + + +@main.command(context_settings=dict( + ignore_unknown_options=True, +)) +@click.option('--psql-container', 'container_name', type=str, default=None, + metavar='ID_OR_NAME', + help='Open a postgres client shell using the psql executable ' + 'shipped with the given postgres container. ' + 'If not set or set as an empty string "", it will auto-detect ' + 'the psql container from the halfstack. ' + 'If set "-", it will use the host-provided psql executable. ' + 'You may append additional arguments passed to the psql cli command. ' + '[default: auto-detect from halfstack]') +@click.option('--psql-help', is_flag=True, + help='Show the help text of the psql command instead of ' + 'this dbshell command.') +@click.argument('psql_args', nargs=-1, type=click.UNPROCESSED) +@click.pass_obj +def dbshell(cli_ctx: CLIContext, container_name, psql_help, psql_args): + """ + Run the database shell. + + All arguments except `--psql-container` and `--psql-help` are transparently + forwarded to the psql command. For instance, you can use `-c` to execute a + psql/SQL statement on the command line. Note that you do not have to specify + connection-related options because the dbshell command fills out them from the + manager configuration. 
+ """ + local_config = cli_ctx.local_config + if psql_help: + psql_args = ['--help'] + if not container_name: + # Try to get the database container name of the halfstack + candidate_container_names = subprocess.check_output( + ['docker', 'ps', '--format', '{{.Names}}', '--filter', 'name=half-db'], + ) + if not candidate_container_names: + click.echo("Could not find the halfstack postgres container. " + "Please set the container name explicitly.", + err=True) + sys.exit(1) + container_name = candidate_container_names.decode().splitlines()[0].strip() + elif container_name == '-': + # Use the host-provided psql command + cmd = [ + 'psql', + (f"postgres://{local_config['db']['user']}:{local_config['db']['password']}" + f"@{local_config['db']['addr']}/{local_config['db']['name']}"), + *psql_args, + ] + subprocess.call(cmd) + return + # Use the container to start the psql client command + print(f"using the db container {container_name} ...") + cmd = [ + 'docker', 'exec', '-i', '-t', + container_name, + 'psql', + '-U', local_config['db']['user'], + '-d', local_config['db']['name'], + *psql_args, + ] + subprocess.call(cmd) + + +@main.command() +@click.pass_obj +def generate_keypair(cli_ctx: CLIContext): + """ + Generate a random keypair and print it out to stdout. + """ + log.info('generating keypair...') + ak, sk = _gen_keypair() + print(f'Access Key: {ak} ({len(ak)} bytes)') + print(f'Secret Key: {sk} ({len(sk)} bytes)') + + +@main.command() +@click.option('-r', '--retention', type=str, default='1yr', + help='The retention limit. e.g., 20d, 1mo, 6mo, 1yr') +@click.option('-v', '--vacuum-full', type=bool, default=False, + help='Reclaim storage occupied by dead tuples.' + 'If not set or set False, it will run VACUUM without FULL.' + 'If set True, it will run VACUUM FULL.' + 'When VACUUM FULL is being processed, the database is locked.' + '[default: False]') +@click.pass_obj +def clear_history(cli_ctx: CLIContext, retention, vacuum_full) -> None: + """ + Delete old records from the kernels table and + invoke the PostgreSQL's vaccuum operation to clear up the actual disk space. + """ + local_config = cli_ctx.local_config + with cli_ctx.logger: + today = datetime.now() + duration = TimeDuration() + expiration = today - duration.check_and_return(retention) + expiration_date = expiration.strftime('%Y-%m-%d %H:%M:%S') + + async def _clear_redis_history(): + try: + async with connect_database(cli_ctx.local_config) as db: + async with db.begin_readonly() as conn: + query = ( + sa.select([kernels.c.id]) + .select_from(kernels) + .where( + (kernels.c.terminated_at < expiration), + ) + ) + result = await conn.execute(query) + target_kernels = [str(x['id']) for x in result.all()] + + delete_count = 0 + async with redis_ctx(cli_ctx) as redis_conn_set: + + def _build_pipe( + r: aioredis.Redis, + kernel_ids: list[str], + ) -> aioredis.client.Pipeline: + pipe = r.pipeline(transaction=False) + pipe.delete(*kernel_ids) + return pipe + + if len(target_kernels) > 0: + # Apply chunking to avoid excessive length of command params + # and indefinite blocking of the Redis server. + for kernel_ids in chunked(target_kernels, 32): + results = await redis_helper.execute( + redis_conn_set.stat, + partial(_build_pipe, kernel_ids=kernel_ids), + ) + # Each DEL command returns the number of keys deleted. 
+ delete_count += sum(results) + log.info( + "Cleaned up {:,} redis statistics records older than {:}.", + delete_count, expiration_date, + ) + + # Sync and compact the persistent database of Redis + redis_config = await redis_helper.execute( + redis_conn_set.stat, + lambda r: r.config_get("appendonly"), + ) + if redis_config['appendonly'] == 'yes': + await redis_helper.execute( + redis_conn_set.stat, + lambda r: r.bgrewriteaof(), + ) + log.info("Issued BGREWRITEAOF to the Redis database.") + else: + await redis_helper.execute( + redis_conn_set.stat, + lambda r: r.execute_command("BGSAVE SCHEDULE"), + ) + log.info("Issued BGSAVE to the Redis database.") + except: + log.exception("Unexpected error while cleaning up redis history") + + asyncio.run(_clear_redis_history()) + + conn = psycopg2.connect( + host=local_config['db']['addr'][0], + port=local_config['db']['addr'][1], + dbname=local_config['db']['name'], + user=local_config['db']['user'], + password=local_config['db']['password'], + ) + with conn.cursor() as curs: + if vacuum_full: + vacuum_sql = "VACUUM FULL" + else: + vacuum_sql = "VACUUM" + + curs.execute(f""" + SELECT COUNT(*) FROM kernels WHERE terminated_at < '{expiration_date}'; + """) + deleted_count = curs.fetchone()[0] + + conn.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT) + log.info('Deleting old records...') + curs.execute(f""" + DELETE FROM kernels WHERE terminated_at < '{expiration_date}'; + """) + log.info(f'Perfoming {vacuum_sql} operation...') + curs.execute(vacuum_sql) + conn.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED) + + curs.execute(""" + SELECT COUNT(*) FROM kernels; + """) + table_size = curs.fetchone()[0] + log.info(f'kernels table size: {table_size}') + + log.info('Cleaned up {:,} database records older than {:}.', deleted_count, expiration_date) + + +@main.group(cls=LazyGroup, import_name='ai.backend.manager.cli.dbschema:cli') +def schema(): + '''Command set for managing the database schema.''' + + +@main.group(cls=LazyGroup, import_name='ai.backend.manager.cli.etcd:cli') +def etcd(): + '''Command set for putting/getting data to/from etcd.''' + + +@main.group(cls=LazyGroup, import_name='ai.backend.manager.cli.fixture:cli') +def fixture(): + '''Command set for managing fixtures.''' + + +@main.group(cls=LazyGroup, import_name='ai.backend.manager.cli.gql:cli') +def gql(): + '''Command set for GraphQL schema.''' + + +@main.group(cls=LazyGroup, import_name='ai.backend.manager.cli.image:cli') +def image(): + '''Command set for managing images.''' + + +if __name__ == '__main__': + main() diff --git a/src/ai/backend/manager/cli/context.py b/src/ai/backend/manager/cli/context.py new file mode 100644 index 0000000000..d3c3c28738 --- /dev/null +++ b/src/ai/backend/manager/cli/context.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +import atexit +import contextlib +import attr +import os +from typing import TYPE_CHECKING, AsyncIterator + +from ai.backend.common import redis +from ai.backend.common.config import redis_config_iv +from ai.backend.common.logging import AbstractLogger, Logger, NoopLogger +from ai.backend.common.types import RedisConnectionInfo + +from ai.backend.manager.config import SharedConfig +from ai.backend.manager.defs import REDIS_STAT_DB, REDIS_LIVE_DB, REDIS_IMAGE_DB, REDIS_STREAM_DB + + +if TYPE_CHECKING: + from ..config import LocalConfig + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class CLIContext: + logger: AbstractLogger + local_config: LocalConfig + + 
+@attr.s(auto_attribs=True, frozen=True, slots=True) +class RedisConnectionSet: + live: RedisConnectionInfo + stat: RedisConnectionInfo + image: RedisConnectionInfo + stream: RedisConnectionInfo + + +def init_logger(local_config: LocalConfig, nested: bool = False) -> AbstractLogger: + if nested: + # return a dummy logger that does nothing. + return NoopLogger(local_config) + if 'drivers' in local_config['logging'] and 'file' in local_config['logging']['drivers']: + local_config['logging']['drivers'].remove('file') + # log_endpoint = f'tcp://127.0.0.1:{find_free_port()}' + ipc_base_path = local_config['manager']['ipc-base-path'] + log_sockpath = ipc_base_path / f'manager-cli-{os.getpid()}.sock' + log_endpoint = f'ipc://{log_sockpath}' + + def _clean_logger(): + try: + os.unlink(log_sockpath) + except FileNotFoundError: + pass + + atexit.register(_clean_logger) + return Logger(local_config['logging'], is_master=True, log_endpoint=log_endpoint) + + +@contextlib.asynccontextmanager +async def redis_ctx(cli_ctx: CLIContext) -> AsyncIterator[RedisConnectionSet]: + local_config = cli_ctx.local_config + shared_config = SharedConfig( + local_config['etcd']['addr'], + local_config['etcd']['user'], + local_config['etcd']['password'], + local_config['etcd']['namespace'], + ) + await shared_config.reload() + raw_redis_config = await shared_config.etcd.get_prefix('config/redis') + local_config['redis'] = redis_config_iv.check(raw_redis_config) + redis_live = redis.get_redis_object(shared_config.data['redis'], db=REDIS_LIVE_DB) + redis_stat = redis.get_redis_object(shared_config.data['redis'], db=REDIS_STAT_DB) + redis_image = redis.get_redis_object( + shared_config.data['redis'], db=REDIS_IMAGE_DB, + ) + redis_stream = redis.get_redis_object( + shared_config.data['redis'], db=REDIS_STREAM_DB, + ) + yield RedisConnectionSet( + live=redis_live, + stat=redis_stat, + image=redis_image, + stream=redis_stream, + ) + await redis_stream.close() + await redis_image.close() + await redis_stat.close() + await redis_live.close() diff --git a/src/ai/backend/manager/cli/dbschema.py b/src/ai/backend/manager/cli/dbschema.py new file mode 100644 index 0000000000..8d7b44bcbb --- /dev/null +++ b/src/ai/backend/manager/cli/dbschema.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from alembic.config import Config +from alembic import command +from alembic.runtime.migration import MigrationContext +from alembic.script import ScriptDirectory +import click +import sqlalchemy as sa + +from ai.backend.common.logging import BraceStyleAdapter + +from ..models.base import metadata +if TYPE_CHECKING: + from .context import CLIContext + +log = BraceStyleAdapter(logging.getLogger(__name__)) + + +@click.group() +def cli(args) -> None: + pass + + +@cli.command() +@click.option('-f', '--alembic-config', + default='alembic.ini', metavar='PATH', + help='The path to Alembic config file. 
' + '[default: alembic.ini]') +@click.pass_obj +def show(cli_ctx: CLIContext, alembic_config) -> None: + '''Show the current schema information.''' + with cli_ctx.logger: + alembic_cfg = Config(alembic_config) + sa_url = alembic_cfg.get_main_option('sqlalchemy.url') + engine = sa.create_engine(sa_url) + with engine.begin() as connection: + context = MigrationContext.configure(connection) + current_rev = context.get_current_revision() + + script = ScriptDirectory.from_config(alembic_cfg) + heads = script.get_heads() + head_rev = heads[0] if len(heads) > 0 else None + print(f'Current database revision: {current_rev}') + print(f'The head revision of available migrations: {head_rev}') + + +@cli.command() +@click.option('-f', '--alembic-config', default='alembic.ini', metavar='PATH', + help='The path to Alembic config file. ' + '[default: alembic.ini]') +@click.pass_obj +def oneshot(cli_ctx: CLIContext, alembic_config) -> None: + ''' + Set up your database with one-shot schema migration instead of + iterating over multiple revisions if there is no existing database. + It uses alembic.ini to configure database connection. + + Reference: http://alembic.zzzcomputing.com/en/latest/cookbook.html + #building-an-up-to-date-database-from-scratch + ''' + with cli_ctx.logger: + alembic_cfg = Config(alembic_config) + sa_url = alembic_cfg.get_main_option('sqlalchemy.url') + + engine = sa.create_engine(sa_url) + engine.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";') + + with engine.begin() as connection: + context = MigrationContext.configure(connection) + current_rev = context.get_current_revision() + + if current_rev is None: + # For a fresh clean database, create all from scratch. + # (it will raise error if tables already exist.) + log.info('Detected a fresh new database.') + log.info('Creating tables...') + with engine.begin() as connection: + alembic_cfg.attributes['connection'] = connection + metadata.create_all(engine, checkfirst=False) + log.info('Stamping alembic version to head...') + command.stamp(alembic_cfg, 'head') + else: + # If alembic version info is already available, perform incremental upgrade. 
+ log.info('Detected an existing database.') + log.info('Performing schema upgrade to head...') + with engine.begin() as connection: + alembic_cfg.attributes['connection'] = connection + command.upgrade(alembic_cfg, 'head') + + log.info("If you don't need old migrations, delete them and set " + "\"down_revision\" value in the earliest migration to \"None\".") diff --git a/src/ai/backend/manager/cli/etcd.py b/src/ai/backend/manager/cli/etcd.py new file mode 100644 index 0000000000..044ae1d5a3 --- /dev/null +++ b/src/ai/backend/manager/cli/etcd.py @@ -0,0 +1,317 @@ +from __future__ import annotations + +import asyncio +import contextlib +import json +import logging +from typing import AsyncIterator, TYPE_CHECKING +import sys + +import click + +from ai.backend.common.cli import EnumChoice, MinMaxRange +from ai.backend.common.config import redis_config_iv +from ai.backend.common.etcd import ( + AsyncEtcd, ConfigScopes, + quote as etcd_quote, + unquote as etcd_unquote, +) +from ai.backend.common.logging import BraceStyleAdapter + +from .image_impl import ( + alias as alias_impl, + dealias as dealias_impl, + forget_image as forget_image_impl, + inspect_image as inspect_image_impl, + list_images as list_images_impl, + rescan_images as rescan_images_impl, + set_image_resource_limit as set_image_resource_limit_impl, +) +from ..config import SharedConfig + +if TYPE_CHECKING: + from .context import CLIContext + +log = BraceStyleAdapter(logging.getLogger(__name__)) + + +@click.group() +def cli() -> None: + pass + + +@contextlib.asynccontextmanager +async def etcd_ctx(cli_ctx: CLIContext) -> AsyncIterator[AsyncEtcd]: + local_config = cli_ctx.local_config + creds = None + if local_config['etcd']['user']: + creds = { + 'user': local_config['etcd']['user'], + 'password': local_config['etcd']['password'], + } + scope_prefix_map = { + ConfigScopes.GLOBAL: '', + # TODO: provide a way to specify other scope prefixes + } + etcd = AsyncEtcd(local_config['etcd']['addr'], local_config['etcd']['namespace'], + scope_prefix_map, credentials=creds) + try: + yield etcd + finally: + await etcd.close() + + +@contextlib.asynccontextmanager +async def config_ctx(cli_ctx: CLIContext) -> AsyncIterator[SharedConfig]: + local_config = cli_ctx.local_config + # scope_prefix_map is created inside ConfigServer + shared_config = SharedConfig( + local_config['etcd']['addr'], + local_config['etcd']['user'], + local_config['etcd']['password'], + local_config['etcd']['namespace'], + ) + await shared_config.reload() + raw_redis_config = await shared_config.etcd.get_prefix('config/redis') + local_config['redis'] = redis_config_iv.check(raw_redis_config) + try: + yield shared_config + finally: + await shared_config.close() + + +@cli.command() +@click.argument('key') +@click.argument('value') +@click.option('-s', '--scope', type=EnumChoice(ConfigScopes), default=ConfigScopes.GLOBAL, + help='The configuration scope to put the value.') +@click.pass_obj +def put(cli_ctx: CLIContext, key, value, scope) -> None: + '''Put a single key-value pair into the etcd.''' + async def _impl(): + async with etcd_ctx(cli_ctx) as etcd: + try: + await etcd.put(key, value, scope=scope) + except Exception: + log.exception('An error occurred.') + with cli_ctx.logger: + asyncio.run(_impl()) + + +@cli.command() +@click.argument('key', type=str) +@click.argument('file', type=click.File('rb')) +@click.option('-s', '--scope', type=EnumChoice(ConfigScopes), + default=ConfigScopes.GLOBAL, + help='The configuration scope to put the value.') +@click.pass_obj +def 
put_json(cli_ctx: CLIContext, key, file, scope) -> None: + ''' + Put a JSON object from FILE to the etcd as flattened key-value pairs + under the given KEY prefix. + ''' + async def _impl(): + async with etcd_ctx(cli_ctx) as etcd: + try: + value = json.load(file) + await etcd.put_prefix(key, value, scope=scope) + except Exception: + log.exception('An error occurred.') + with cli_ctx.logger: + asyncio.run(_impl()) + + +@cli.command() +@click.argument('src_prefix', type=str) +@click.argument('dst_prefix', type=str) +@click.option('-s', '--scope', type=EnumChoice(ConfigScopes), + default=ConfigScopes.GLOBAL, + help='The configuration scope to get/put the subtree. ' + 'To move between different scopes, use the global scope ' + 'and specify the per-scope prefixes manually.') +@click.pass_obj +def move_subtree(cli_ctx: CLIContext, src_prefix, dst_prefix, scope) -> None: + ''' + Move a subtree to another key prefix. + ''' + async def _impl(): + async with etcd_ctx(cli_ctx) as etcd: + try: + subtree = await etcd.get_prefix(src_prefix, scope=scope) + await etcd.put_prefix(dst_prefix, subtree, scope=scope) + await etcd.delete_prefix(src_prefix, scope=scope) + except Exception: + log.exception('An error occurred.') + with cli_ctx.logger: + asyncio.run(_impl()) + + +@cli.command() +@click.argument('key') +@click.option('--prefix', is_flag=True, + help='Get all key-value pairs prefixed with the given key ' + 'as a JSON form.') +@click.option('-s', '--scope', type=EnumChoice(ConfigScopes), + default=ConfigScopes.GLOBAL, + help='The configuration scope to put the value.') +@click.pass_obj +def get(cli_ctx: CLIContext, key, prefix, scope) -> None: + ''' + Get the value of a key in the configured etcd namespace. + ''' + async def _impl(): + async with etcd_ctx(cli_ctx) as etcd: + try: + if prefix: + data = await etcd.get_prefix(key, scope=scope) + print(json.dumps(dict(data), indent=4)) + else: + val = await etcd.get(key, scope=scope) + if val is None: + sys.exit(1) + print(val) + except Exception: + log.exception('An error occurred.') + with cli_ctx.logger: + asyncio.run(_impl()) + + +@cli.command() +@click.argument('key') +@click.option('--prefix', is_flag=True, + help='Delete all keys prefixed with the given key.') +@click.option('-s', '--scope', type=EnumChoice(ConfigScopes), + default=ConfigScopes.GLOBAL, + help='The configuration scope to put the value.') +@click.pass_obj +def delete(cli_ctx: CLIContext, key, prefix, scope) -> None: + '''Delete the key in the configured etcd namespace.''' + async def _impl(): + async with etcd_ctx(cli_ctx) as etcd: + try: + if prefix: + await etcd.delete_prefix(key, scope=scope) + else: + await etcd.delete(key, scope=scope) + except Exception: + log.exception('An error occurred.') + with cli_ctx.logger: + asyncio.run(_impl()) + + +@cli.command() +@click.option('-s', '--short', is_flag=True, + help='Show only the image references and digests.') +@click.option('-i', '--installed', is_flag=True, + help='Show only the installed images.') +@click.pass_obj +def list_images(cli_ctx, short, installed) -> None: + '''List all configured images.''' + with cli_ctx.logger: + log.warn('etcd list-images command is deprecated, use image list instead') + asyncio.run(list_images_impl(cli_ctx, short, installed)) + + +@cli.command() +@click.argument('canonical_or_alias') +@click.argument('architecture') +@click.pass_obj +def inspect_image(cli_ctx, canonical_or_alias, architecture) -> None: + '''Show the details of the given image or alias.''' + with cli_ctx.logger: + log.warn('etcd 
inspect-image command is deprecated, use image inspect instead') + asyncio.run(inspect_image_impl(cli_ctx, canonical_or_alias, architecture)) + + +@cli.command() +@click.argument('canonical_or_alias') +@click.argument('architecture') +@click.pass_obj +def forget_image(cli_ctx, canonical_or_alias, architecture) -> None: + '''Forget (delete) a specific image.''' + with cli_ctx.logger: + log.warn('etcd forget-image command is deprecated, use image forget instead') + asyncio.run(forget_image_impl(cli_ctx, canonical_or_alias, architecture)) + + +@cli.command() +@click.argument('canonical_or_alias') +@click.argument('slot_type') +@click.argument('range_value', type=MinMaxRange) +@click.argument('architecture') +@click.pass_obj +def set_image_resource_limit( + cli_ctx, + canonical_or_alias, + slot_type, + range_value, + architecture, +) -> None: + '''Set the MIN:MAX values of a SLOT_TYPE limit for the given image REFERENCE.''' + with cli_ctx.logger: + log.warn('etcd set-image-resource-limit command is deprecated, ' + 'use image set-resource-limit instead') + asyncio.run(set_image_resource_limit_impl( + cli_ctx, + canonical_or_alias, + slot_type, + range_value, + architecture, + )) + + +@cli.command() +@click.argument('registry') +@click.pass_obj +def rescan_images(cli_ctx: CLIContext, registry) -> None: + ''' + Update the kernel image metadata from all configured docker registries. + + Pass the name (usually hostname or "lablup") of the Docker registry configured as REGISTRY. + ''' + with cli_ctx.logger: + log.warn('etcd rescan-images command is deprecated, use image rescan instead') + asyncio.run(rescan_images_impl(cli_ctx, registry)) + + +@cli.command() +@click.argument('alias') +@click.argument('target') +@click.argument('architecture') +@click.pass_obj +def alias(cli_ctx, alias, target, architecture) -> None: + '''Add an image alias from the given alias to the target image reference.''' + with cli_ctx.logger: + log.warn('etcd alias command is deprecated, use image alias instead') + asyncio.run(alias_impl(cli_ctx, alias, target, architecture)) + + +@cli.command() +@click.argument('alias') +@click.pass_obj +def dealias(cli_ctx, alias) -> None: + '''Remove an alias.''' + with cli_ctx.logger: + log.warn('etcd dealias command is deprecated, use image dealias instead') + asyncio.run(dealias_impl(cli_ctx, alias)) + + +@cli.command() +@click.argument('value') +@click.pass_obj +def quote(cli_ctx: CLIContext, value) -> None: + ''' + Quote the given string for use as a URL piece in etcd keys. + Use this to generate argument inputs for aliases and raw image keys. + ''' + print(etcd_quote(value)) + + +@cli.command() +@click.argument('value') +@click.pass_obj +def unquote(cli_ctx: CLIContext, value) -> None: + ''' + Unquote the given string used as a URL piece in etcd keys. 
+ ''' + print(etcd_unquote(value)) diff --git a/src/ai/backend/manager/cli/fixture.py b/src/ai/backend/manager/cli/fixture.py new file mode 100644 index 0000000000..407cc7bf04 --- /dev/null +++ b/src/ai/backend/manager/cli/fixture.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +import asyncio +import logging +import json +from pathlib import Path +from typing import TYPE_CHECKING +from urllib.parse import quote_plus as urlquote + +import click +import sqlalchemy as sa + +from ai.backend.common.logging import BraceStyleAdapter + +from ..models.base import populate_fixture + +if TYPE_CHECKING: + from .context import CLIContext + +log = BraceStyleAdapter(logging.getLogger(__name__)) + + +@click.group() +def cli(): + pass + + +@cli.command() +@click.argument('fixture_path', type=Path) +@click.pass_obj +def populate(cli_ctx: CLIContext, fixture_path) -> None: + + async def _impl(): + log.info("Populating fixture '{0}' ...", fixture_path) + try: + fixture = json.loads(fixture_path.read_text(encoding='utf8')) + except AttributeError: + log.error('No such fixture.') + return + db_username = cli_ctx.local_config['db']['user'] + db_password = cli_ctx.local_config['db']['password'] + db_addr = cli_ctx.local_config['db']['addr'] + db_name = cli_ctx.local_config['db']['name'] + engine = sa.ext.asyncio.create_async_engine( + f"postgresql+asyncpg://{urlquote(db_username)}:{urlquote(db_password)}@{db_addr}/{db_name}", + ) + try: + await populate_fixture(engine, fixture) + except: + log.exception("Failed to populate fixtures due to the following error:") + else: + log.info("Done") + log.warning("Some rows may be skipped if they already exist.") + finally: + await engine.dispose() + + """Populate fixtures.""" + with cli_ctx.logger: + asyncio.run(_impl()) + + +@cli.command() +@click.pass_obj +def list(cli_ctx: CLIContext) -> None: + """List all available fixtures.""" + with cli_ctx.logger: + log.warning('This command is deprecated.') diff --git a/src/ai/backend/manager/cli/gql.py b/src/ai/backend/manager/cli/gql.py new file mode 100644 index 0000000000..4a74af66be --- /dev/null +++ b/src/ai/backend/manager/cli/gql.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +import click +import graphene + +from ai.backend.common.logging import BraceStyleAdapter + +from ..models.gql import Queries, Mutations + +if TYPE_CHECKING: + from .context import CLIContext + +log = BraceStyleAdapter(logging.getLogger(__name__)) + + +@click.group() +def cli(args) -> None: + pass + + +@cli.command() +@click.pass_obj +def show(cli_ctx: CLIContext) -> None: + with cli_ctx.logger: + schema = graphene.Schema( + query=Queries, + mutation=Mutations, + auto_camelcase=False) + log.info('======== GraphQL API Schema ========') + print(str(schema)) diff --git a/src/ai/backend/manager/cli/image.py b/src/ai/backend/manager/cli/image.py new file mode 100644 index 0000000000..8c2a755e28 --- /dev/null +++ b/src/ai/backend/manager/cli/image.py @@ -0,0 +1,112 @@ +import asyncio +import logging + +import click + +from ai.backend.common.cli import MinMaxRange +from ai.backend.common.logging import BraceStyleAdapter + +from .image_impl import ( + alias as alias_impl, + dealias as dealias_impl, + forget_image as forget_image_impl, + inspect_image as inspect_image_impl, + list_images as list_images_impl, + rescan_images as rescan_images_impl, + set_image_resource_limit as set_image_resource_limit_impl, +) +log = BraceStyleAdapter(logging.getLogger(__name__)) + + +@click.group() +def 
cli() -> None: + pass + + +@cli.command() +@click.option('-s', '--short', is_flag=True, + help='Show only the image references and digests.') +@click.option('-i', '--installed', is_flag=True, + help='Show only the installed images.') +@click.pass_obj +def list(cli_ctx, short, installed) -> None: + '''List all configured images.''' + with cli_ctx.logger: + asyncio.run(list_images_impl(cli_ctx, short, installed)) + + +@cli.command() +@click.argument('canonical_or_alias') +@click.argument('architecture') +@click.pass_obj +def inspect(cli_ctx, canonical_or_alias, architecture) -> None: + '''Show the details of the given image or alias.''' + with cli_ctx.logger: + asyncio.run(inspect_image_impl(cli_ctx, canonical_or_alias, architecture)) + + +@cli.command() +@click.argument('canonical_or_alias') +@click.argument('architecture') +@click.pass_obj +def forget(cli_ctx, canonical_or_alias, architecture) -> None: + '''Forget (delete) a specific image.''' + with cli_ctx.logger: + asyncio.run(forget_image_impl(cli_ctx, canonical_or_alias, architecture)) + + +@cli.command() +@click.argument('canonical_or_alias') +@click.argument('slot_type') +@click.argument('range_value', type=MinMaxRange) +@click.argument('architecture') +@click.pass_obj +def set_resource_limit( + cli_ctx, + canonical_or_alias, + slot_type, + range_value, + architecture, +) -> None: + '''Set the MIN:MAX values of a SLOT_TYPE limit for the given image REFERENCE.''' + with cli_ctx.logger: + asyncio.run(set_image_resource_limit_impl( + cli_ctx, + canonical_or_alias, + slot_type, + range_value, + architecture, + )) + + +@cli.command() +@click.argument('registry') +@click.pass_obj +def rescan(cli_ctx, registry) -> None: + ''' + Update the kernel image metadata from all configured docker registries. + + Pass the name (usually hostname or "lablup") of the Docker registry configured as REGISTRY. 
+ ''' + with cli_ctx.logger: + asyncio.run(rescan_images_impl(cli_ctx, registry)) + + +@cli.command() +@click.argument('alias') +@click.argument('target') +@click.argument('architecture') +@click.pass_obj +def alias(cli_ctx, alias, target, architecture) -> None: + '''Add an image alias from the given alias to the target image reference.''' + with cli_ctx.logger: + asyncio.run(alias_impl(cli_ctx, alias, target, architecture)) + + +@cli.command() +@click.argument('alias') +@click.pass_obj +def dealias(cli_ctx, alias) -> None: + '''Remove an alias.''' + with cli_ctx.logger: + asyncio.run(dealias_impl(cli_ctx, alias)) diff --git a/src/ai/backend/manager/cli/image_impl.py b/src/ai/backend/manager/cli/image_impl.py new file mode 100644 index 0000000000..4b909c0065 --- /dev/null +++ b/src/ai/backend/manager/cli/image_impl.py @@ -0,0 +1,152 @@ +import contextlib +import logging +from pprint import pprint + +import sqlalchemy as sa +from tabulate import tabulate +from typing import AsyncIterator + +from ai.backend.common.docker import ImageRef +from ai.backend.common.etcd import AsyncEtcd, ConfigScopes +from ai.backend.common.exception import UnknownImageReference +from ai.backend.common.logging import BraceStyleAdapter + +from ai.backend.manager.models.image import ( + ImageAliasRow, + ImageRow, + rescan_images as rescan_images_func, +) +from ai.backend.manager.models.utils import ( + connect_database, +) + +log = BraceStyleAdapter(logging.getLogger(__name__)) + + +@contextlib.asynccontextmanager +async def etcd_ctx(cli_ctx) -> AsyncIterator[AsyncEtcd]: + local_config = cli_ctx.local_config + creds = None + if local_config['etcd']['user']: + creds = { + 'user': local_config['etcd']['user'], + 'password': local_config['etcd']['password'], + } + scope_prefix_map = { + ConfigScopes.GLOBAL: '', + # TODO: provide a way to specify other scope prefixes + } + etcd = AsyncEtcd(local_config['etcd']['addr'], local_config['etcd']['namespace'], + scope_prefix_map, credentials=creds) + try: + yield etcd + finally: + await etcd.close() + + +async def list_images(cli_ctx, short, installed): + async with connect_database(cli_ctx.local_config) as db: + async with db.begin_readonly_session() as session: + displayed_items = [] + try: + items = await ImageRow.list(session) + # NOTE: installed/installed_agents fields are no longer provided in CLI, + # until we finish the epic refactoring of image metadata db. 
+ for item in items: + if installed and not item.installed: + continue + if short: + displayed_items.append((item.image_ref.canonical, item.config_digest)) + else: + pprint(item) + if short: + print(tabulate(displayed_items, tablefmt='plain')) + except Exception: + log.exception('An error occurred.') + + +async def inspect_image(cli_ctx, canonical_or_alias, architecture): + async with connect_database(cli_ctx.local_config) as db: + async with db.begin_readonly_session() as session: + try: + image_row = await ImageRow.resolve(session, [ + ImageRef(canonical_or_alias, ['*'], architecture), + canonical_or_alias, + ]) + pprint(await image_row.inspect()) + except UnknownImageReference: + log.exception('Image not found.') + except Exception: + log.exception('An error occurred.') + + +async def forget_image(cli_ctx, canonical_or_alias, architecture): + async with connect_database(cli_ctx.local_config) as db: + async with db.begin_session() as session: + try: + image_row = await ImageRow.resolve(session, [ + ImageRef(canonical_or_alias, ['*'], architecture), + canonical_or_alias, + ]) + await session.delete(image_row) + except UnknownImageReference: + log.exception('Image not found.') + except Exception: + log.exception('An error occurred.') + + +async def set_image_resource_limit( + cli_ctx, + canonical_or_alias, + slot_type, + range_value, + architecture, +): + async with connect_database(cli_ctx.local_config) as db: + async with db.begin_session() as session: + try: + image_row = await ImageRow.resolve(session, [ + ImageRef(canonical_or_alias, ['*'], architecture), + canonical_or_alias, + ]) + await image_row.set_resource_limit(slot_type, range_value) + except UnknownImageReference: + log.exception('Image not found.') + except Exception: + log.exception('An error occurred.') + + +async def rescan_images(cli_ctx, registry): + async with connect_database(cli_ctx.local_config) as db: + async with etcd_ctx(cli_ctx) as etcd: + try: + await rescan_images_func(etcd, db, registry=registry) + except Exception: + log.exception('An error occurred.') + + +async def alias(cli_ctx, alias, target, architecture): + async with connect_database(cli_ctx.local_config) as db: + async with db.begin_session() as session: + try: + image_row = await ImageRow.resolve(session, [ + ImageRef(target, ['*'], architecture), + ]) + await ImageAliasRow.create(session, alias, image_row) + except UnknownImageReference: + log.exception('Image not found.') + except Exception: + log.exception('An error occurred.') + + +async def dealias(cli_ctx, alias): + async with connect_database(cli_ctx.local_config) as db: + async with db.begin_session() as session: + alias_row = await session.scalar( + sa.select(ImageAliasRow) + .where(ImageAliasRow.alias == alias), + ) + if alias_row is None: + log.exception('Alias not found.') + return + await session.delete(alias_row) diff --git a/src/ai/backend/manager/config.py b/src/ai/backend/manager/config.py new file mode 100644 index 0000000000..daa844c73c --- /dev/null +++ b/src/ai/backend/manager/config.py @@ -0,0 +1,594 @@ +from __future__ import annotations + +""" +Configuration Schema on etcd +---------------------------- + +The etcd (v3) itself is a flat key-value storage, but we use its prefix-based filtering +by using a directory-like configuration structure. +At the root, it contains "/sorna/{namespace}" as the common prefix. 
+ +In most cases, a single global configurations are sufficient, but cluster administrators +may want to apply different settings (e.g., resource slot types, vGPU sizes, etc.) +to different scaling groups or even each node. + +To support such requirements, we add another level of prefix named "configuration scope". +There are three types of configuration scopes: + + * Global + * Scaling group + * Node + +When reading configurations, the underlying `ai.backend.common.etcd.AsyncEtcd` class +returns a `collections.ChainMap` instance that merges three configuration scopes +in the order of node, scaling group, and global, so that node-level configs override +scaling-group configs, and scaling-group configs override global configs if they exist. + +Note that the global scope prefix may be an empty string; this allows use of legacy +etcd databases without explicit migration. When the global scope prefix is an empty string, +it does not make a new depth in the directory structure, so "{namespace}/config/x" (not +"{namespace}//config/x"!) is recognized as the global config. + +Notes on Docker registry configurations +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +A registry name contains the host, port (only for non-standards), and the path. +So, they must be URL-quoted (including slashes) to avoid parsing +errors due to intermediate slashes and colons. +Alias keys are also URL-quoted in the same way. + +{namespace} + + '' # ConfigScoeps.GLOBAL + + config + + system + - timezone: "UTC" # pytz-compatible timezone names (e.g., "Asia/Seoul") + + api + - allow-origins: "*" + + resources + - group_resource_visibility: "true" # return group resource status in check-presets + # (default: false) + + docker + + image + - auto_pull: "digest" (default) | "tag" | "none" + + registry + + "index.docker.io": "https://registry-1.docker.io" + - username: "lablup" + + {registry-name}: {registry-URL} # {registry-name} is url-quoted + - username: {username} + - password: {password} + - type: "docker" | "harbor" | "harbor2" + - project: "project1-name,project2-name,..." # harbor only + - ssl-verify: "yes" | "no" + ... + + redis + - addr: "{redis-host}:{redis-port}" + - password: {password} + + idle + - enabled: "timeout,utilization" # comma-separated list of checker names + - app-streaming-packet-timeout: "5m" # in seconds; idleness of app-streaming TCP connections + # NOTE: idle checkers get activated AFTER the app-streaming packet timeout has passed. + - checkers + + "timeout" + - threshold: "10m" + + "utilization" + + resource-thresholds + + "cpu_util" + - average: 30 # in percent + + "mem" + - average: 30 # in percent + + "cuda_util" + - average: 30 # in percent # CUDA core utilization + + "cuda_mem" + - average: 30 # in percent + # NOTE: To use "cuda.mem" criteria, user programs must use + # an incremental allocation strategy for CUDA memory. + - thresholds-check-operator: "and" + # "and" (default, so any other words except the "or"): + # garbage collect a session only when ALL of the resources are + # under-utilized not exceeding their thresholds. + # ex) (cpu < threshold) AND (mem < threshold) AND ... + # "or": + # garbage collect a session when ANY of the resources is + # under-utilized not exceeding their thresholds. + # ex) (cpu < threshold) OR (mem < threshold) OR ... 
+ - time-window: "12h" # time window to average utilization + # a session will not be terminated until this time + - initial-grace-period: "5m" # time to allow to be idle for first + # "session_lifetime" does not have etcd config but it is configured via + # the keypair_resource_polices table. + + resource_slots + - {"cuda.device"}: {"count"} + - {"cuda.mem"}: {"bytes"} + - {"cuda.smp"}: {"count"} + ... + + plugins + + accelerator + + "cuda" + - allocation_mode: "discrete" + ... + + scheduler + + "fifo" + + "lifo" + + "drf" + ... + + network + + subnet + - agent: "0.0.0.0/0" + - container: "0.0.0.0/0" + + overlay + - mtu: 1500 # Maximum Transmission Unit + + rpc + - keepalive-timeout: 60 # seconds + + watcher + - token: {some-secret} + + volumes + # pre-20.09 + - _mount: {path-to-mount-root-for-vfolder-partitions} + - _default_host: {default-vfolder-partition-name} + - _fsprefix: {path-prefix-inside-host-mounts} + # 20.09 and later + - default_host: "{default-proxy}:{default-volume}" + + proxies: # each proxy may provide multiple volumes + + "local" # proxy name + - client_api: "http://localhost:6021" + - manager_api: "http://localhost:6022" + - secret: "xxxxxx..." # for manager API + - ssl_verify: true | false # for manager API + + "mynas1" + - client_api: "https://proxy1.example.com:6021" + - manager_api: "https://proxy1.example.com:6022" + - secret: "xxxxxx..." # for manager API + - ssl_verify: true | false # for manager API + ... + ... + ... + + nodes + + manager + - {instance-id}: "up" + ... + # etcd.get("config/redis/addr") is not None => single redis node + # etcd.get("config/redis/sentinel") is not None => redis sentinel + + redis: + - addr: "tcp://redis:6379" + - sentinel: {comma-seperated list of sentinel addresses} + - service_name: "mymanager" + - password: {redis-auth-password} + + agents + + {instance-id}: {"starting","running"} # ConfigScopes.NODE + - ip: {"127.0.0.1"} + - watcher_port: {"6009"} + ... 
+ + sgroup
+   + {name}   # ConfigScopes.SGROUP
+     - swarm-manager/token
+     - swarm-manager/host
+     - swarm-worker/token
+     - iprange           # to choose ethernet iface when creating containers
+     - resource_policy   # the name of scaling-group resource-policy in database
+     + nodes
+       - {instance-id}: 1   # just a membership set
+"""
+
+from abc import abstractmethod
+from collections import UserDict
+from contextvars import ContextVar
+import logging
+import os
+from pathlib import Path
+from pprint import pformat
+import secrets
+import socket
+import sys
+from typing import (
+    Any,
+    Awaitable,
+    Callable,
+    Final,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+)
+
+import aiotools
+import click
+import trafaret as t
+import yarl
+
+from ai.backend.common import config, validators as tx
+from ai.backend.common.etcd import AsyncEtcd
+from ai.backend.common.identity import get_instance_id
+from ai.backend.common.logging import BraceStyleAdapter
+from ai.backend.common.types import (
+    SlotName, SlotTypes,
+    HostPortPair,
+    current_resource_slots,
+)
+from ai.backend.common.etcd import ConfigScopes
+
+from .api.exceptions import ServerMisconfiguredError
+from .api.manager import ManagerStatus
+from ..manager.defs import INTRINSIC_SLOTS
+
+log = BraceStyleAdapter(logging.getLogger(__name__))
+
+_max_cpu_count = os.cpu_count()
+_file_perm = (Path(__file__).parent / 'server.py').stat()
+
+DEFAULT_CHUNK_SIZE: Final = 256 * 1024  # 256 KiB
+DEFAULT_INFLIGHT_CHUNKS: Final = 8
+
+shared_config_defaults = {
+    'volumes/_mount': '/mnt',
+    'volumes/_default_host': 'local',
+    'volumes/_fsprefix': '/',
+    'config/api/allow-origins': '*',
+    'config/docker/image/auto_pull': 'digest',
+}
+
+current_vfolder_types: ContextVar[List[str]] = ContextVar('current_vfolder_types')
+
+manager_local_config_iv = t.Dict({
+    t.Key('db'): t.Dict({
+        t.Key('type', default='postgresql'): t.Enum('postgresql'),
+        t.Key('addr'): tx.HostPortPair,
+        t.Key('name'): tx.Slug[2:64],
+        t.Key('user'): t.String,
+        t.Key('password'): t.String,
+    }),
+    t.Key('manager'): t.Dict({
+        t.Key('ipc-base-path', default="/tmp/backend.ai/manager/ipc"):
+            tx.Path(type='dir', auto_create=True),
+        t.Key('num-proc', default=_max_cpu_count): t.Int[1:_max_cpu_count],
+        t.Key('id', default=f"i-{socket.gethostname()}"): t.String,
+        t.Key('user', default=None): tx.UserID(default_uid=_file_perm.st_uid),
+        t.Key('group', default=None): tx.GroupID(default_gid=_file_perm.st_gid),
+        t.Key('service-addr', default=('0.0.0.0', 8080)): tx.HostPortPair,
+        t.Key('heartbeat-timeout', default=5.0): t.Float[1.0:],  # type: ignore
+        t.Key('secret', default=None): t.Null | t.String,
+        t.Key('ssl-enabled', default=False): t.ToBool,
+        t.Key('ssl-cert', default=None): t.Null | tx.Path(type='file'),
+        t.Key('ssl-privkey', default=None): t.Null | tx.Path(type='file'),
+        t.Key('event-loop', default='asyncio'): t.Enum('asyncio', 'uvloop'),
+        t.Key('distributed-lock', default='pg_advisory'):
+            t.Enum('filelock', 'pg_advisory', 'redlock', 'etcd'),
+        t.Key('pid-file', default=os.devnull): tx.Path(
+            type='file',
+            allow_nonexisting=True,
+            allow_devnull=True,
+        ),
+        t.Key('hide-agents', default=False): t.Bool,
+        t.Key('importer-image', default='lablup/importer:manylinux2010'): t.String,
+        t.Key('max-wsmsg-size', default=16 * (2**20)): t.ToInt,  # default: 16 MiB
+        t.Key('aiomonitor-port', default=50001): t.Int[1:65535],
+    }).allow_extra('*'),
+    t.Key('docker-registry'): t.Dict({  # deprecated in v20.09
+        t.Key('ssl-verify', 
default=True): t.ToBool, + }).allow_extra('*'), + t.Key('logging'): t.Any, # checked in ai.backend.common.logging + t.Key('debug'): t.Dict({ + t.Key('enabled', default=False): t.ToBool, + t.Key('log-events', default=False): t.ToBool, + t.Key('log-scheduler-ticks', default=False): t.ToBool, + t.Key('periodic-sync-stats', default=False): t.ToBool, + }).allow_extra('*'), +}).merge(config.etcd_config_iv).allow_extra('*') + +_shdefs: Mapping[str, Any] = { + 'system': { + 'timezone': 'UTC', + }, + 'api': { + 'allow-origins': '*', + }, + 'redis': { + 'addr': '127.0.0.1:6379', + 'password': None, + }, + 'docker': { + 'registry': {}, + }, + 'network': { + 'subnet': { + 'agent': '0.0.0.0/0', + 'container': '0.0.0.0/0', + }, + }, + 'plugins': { + 'accelerator': {}, + 'scheduler': {}, + }, + 'watcher': { + 'token': None, + }, +} + +container_registry_iv = t.Dict({ + t.Key(''): tx.URL, + t.Key('type', default="docker"): t.String, + t.Key('username', default=None): t.Null | t.String, + t.Key('password', default=None): t.Null | t.String, + t.Key('project', default=None): t.Null | tx.StringList | t.List(t.String), + t.Key('ssl-verify', default=True): t.ToBool, +}).allow_extra('*') + +shared_config_iv = t.Dict({ + t.Key('system', default=_shdefs['system']): t.Dict({ + t.Key('timezone', default=_shdefs['system']['timezone']): tx.TimeZone, + }).allow_extra('*'), + t.Key('api', default=_shdefs['api']): t.Dict({ + t.Key('allow-origins', default=_shdefs['api']['allow-origins']): t.String, + }).allow_extra('*'), + t.Key('redis', default=_shdefs['redis']): t.Dict({ + t.Key('addr', default=_shdefs['redis']['addr']): t.Null | tx.HostPortPair, + t.Key('sentinel', default=None): t.Null | tx.DelimiterSeperatedList(tx.HostPortPair), + t.Key('service_name', default=None): t.Null | t.String, + t.Key('password', default=_shdefs['redis']['password']): t.Null | t.String, + }).allow_extra('*'), + t.Key('docker', default=_shdefs['docker']): t.Dict({ + t.Key('registry'): t.Mapping(t.String, container_registry_iv), + }).allow_extra('*'), + t.Key('plugins', default=_shdefs['plugins']): t.Dict({ + t.Key('accelerator', default=_shdefs['plugins']['accelerator']): + t.Mapping(t.String, t.Mapping(t.String, t.Any)), + t.Key('scheduler', default=_shdefs['plugins']['scheduler']): + t.Mapping(t.String, t.Mapping(t.String, t.Any)), + }).allow_extra('*'), + t.Key('network', default=_shdefs['network']): t.Dict({ + t.Key('subnet', default=_shdefs['network']['subnet']): t.Dict({ + t.Key('agent', default=_shdefs['network']['subnet']['agent']): tx.IPNetwork, + t.Key('container', default=_shdefs['network']['subnet']['container']): tx.IPNetwork, + }).allow_extra('*'), + t.Key('overlay', default=None): t.Null | t.Dict({ + t.Key('mtu', default=1500): t.Int[1:], + }).allow_extra('*'), + }).allow_extra('*'), + t.Key('watcher', default=_shdefs['watcher']): t.Dict({ + t.Key('token', default=_shdefs['watcher']['token']): t.Null | t.String, + }).allow_extra('*'), +}).allow_extra('*') + +volume_config_iv = t.Dict({ + t.Key('default_host'): t.String, + t.Key('proxies'): t.Mapping( + tx.Slug, + t.Dict({ + t.Key('client_api'): t.String, + t.Key('manager_api'): t.String, + t.Key('secret'): t.String, + t.Key('ssl_verify'): t.ToBool, + }), + ), +}).allow_extra('*') + + +ConfigWatchCallback = Callable[[Sequence[str]], Awaitable[None]] + + +class AbstractConfig(UserDict): + + _watch_callbacks: List[ConfigWatchCallback] + + def __init__(self, initial_data: Mapping[str, Any] = None) -> None: + super().__init__(initial_data) + self._watch_callbacks = [] + + 
@abstractmethod + async def reload(self) -> None: + pass + + def add_watch_callback(self, cb: ConfigWatchCallback) -> None: + self._watch_callbacks.append(cb) + + async def dispatch_watch_callbacks(self, updated_keys: Sequence[str]) -> None: + for cb in self._watch_callbacks: + await cb(updated_keys) + + +class LocalConfig(AbstractConfig): + + async def reload(self) -> None: + raise NotImplementedError + + +def load(config_path: Path = None, debug: bool = False) -> LocalConfig: + + # Determine where to read configuration. + raw_cfg, cfg_src_path = config.read_from_file(config_path, 'manager') + + # Override the read config with environment variables (for legacy). + config.override_with_env(raw_cfg, ('etcd', 'namespace'), 'BACKEND_NAMESPACE') + config.override_with_env(raw_cfg, ('etcd', 'addr'), 'BACKEND_ETCD_ADDR') + config.override_with_env(raw_cfg, ('etcd', 'user'), 'BACKEND_ETCD_USER') + config.override_with_env(raw_cfg, ('etcd', 'password'), 'BACKEND_ETCD_PASSWORD') + config.override_with_env(raw_cfg, ('db', 'addr'), 'BACKEND_DB_ADDR') + config.override_with_env(raw_cfg, ('db', 'name'), 'BACKEND_DB_NAME') + config.override_with_env(raw_cfg, ('db', 'user'), 'BACKEND_DB_USER') + config.override_with_env(raw_cfg, ('db', 'password'), 'BACKEND_DB_PASSWORD') + config.override_with_env(raw_cfg, ('manager', 'num-proc'), 'BACKEND_MANAGER_NPROC') + config.override_with_env(raw_cfg, ('manager', 'ssl-cert'), 'BACKEND_SSL_CERT') + config.override_with_env(raw_cfg, ('manager', 'ssl-privkey'), 'BACKEND_SSL_KEY') + config.override_with_env(raw_cfg, ('manager', 'pid-file'), 'BACKEND_PID_FILE') + config.override_with_env(raw_cfg, ('manager', 'api-listen-addr', 'host'), + 'BACKEND_SERVICE_IP') + config.override_with_env(raw_cfg, ('manager', 'api-listen-addr', 'port'), + 'BACKEND_SERVICE_PORT') + config.override_with_env(raw_cfg, ('manager', 'event-listen-addr', 'host'), + 'BACKEND_ADVERTISED_MANAGER_HOST') + config.override_with_env(raw_cfg, ('manager', 'event-listen-addr', 'port'), + 'BACKEND_EVENTS_PORT') + config.override_with_env(raw_cfg, ('docker-registry', 'ssl-verify'), + 'BACKEND_SKIP_SSLCERT_VALIDATION') + if debug: + config.override_key(raw_cfg, ('debug', 'enabled'), True) + config.override_key(raw_cfg, ('logging', 'level'), 'DEBUG') + config.override_key(raw_cfg, ('logging', 'pkg-ns', 'ai.backend'), 'DEBUG') + config.override_key(raw_cfg, ('logging', 'pkg-ns', 'aiohttp'), 'DEBUG') + + # Validate and fill configurations + # (allow_extra will make configs to be forward-copmatible) + try: + cfg = config.check(raw_cfg, manager_local_config_iv) + if 'debug' in cfg and cfg['debug']['enabled']: + print('== Manager configuration ==', file=sys.stderr) + print(pformat(cfg), file=sys.stderr) + cfg['_src'] = cfg_src_path + if cfg['manager']['secret'] is None: + cfg['manager']['secret'] = secrets.token_urlsafe(16) + except config.ConfigurationError as e: + print('Validation of manager configuration has failed:', file=sys.stderr) + print(pformat(e.invalid_data), file=sys.stderr) + raise click.Abort() + else: + return LocalConfig(cfg) + + +class SharedConfig(AbstractConfig): + + def __init__( + self, + etcd_addr: HostPortPair, + etcd_user: Optional[str], + etcd_password: Optional[str], + namespace: str, + ) -> None: + # WARNING: importing etcd3/grpc must be done after forks. 
+ super().__init__() + credentials = None + if etcd_user: + credentials = { + 'user': etcd_user, + 'password': etcd_password, + } + scope_prefix_map = { + ConfigScopes.GLOBAL: '', + # TODO: provide a way to specify other scope prefixes + } + self.etcd = AsyncEtcd(etcd_addr, namespace, scope_prefix_map, credentials=credentials) + + async def close(self) -> None: + await self.etcd.close() + + async def reload(self) -> None: + raw_cfg = await self.etcd.get_prefix('config') + try: + cfg = shared_config_iv.check(raw_cfg) + except config.ConfigurationError as e: + print('Validation of shared etcd configuration has failed:', file=sys.stderr) + print(pformat(e.invalid_data), file=sys.stderr) + raise click.Abort() + else: + self.data = cfg + + def __hash__(self) -> int: + # When used as a key in dicts, we don't care our contents. + # Just treat it lke an opaque object. + return hash(id(self)) + + async def get_raw(self, key: str, allow_null: bool = True) -> Optional[str]: + value = await self.etcd.get(key) + if value is None: + value = shared_config_defaults.get(key, None) + if not allow_null and value is None: + raise ServerMisconfiguredError( + 'A required etcd config is missing.', key) + return value + + async def register_myself(self) -> None: + instance_id = await get_instance_id() + manager_info = { + f'nodes/manager/{instance_id}': 'up', + } + await self.etcd.put_dict(manager_info) + + async def deregister_myself(self) -> None: + instance_id = await get_instance_id() + await self.etcd.delete_prefix(f'nodes/manager/{instance_id}') + + async def update_resource_slots( + self, + slot_key_and_units: Mapping[SlotName, SlotTypes], + ) -> None: + updates = {} + known_slots = await self.get_resource_slots() + for k, v in slot_key_and_units.items(): + if k not in known_slots or v != known_slots[k]: + updates[f'config/resource_slots/{k}'] = v.value + if updates: + await self.etcd.put_dict(updates) + + async def update_manager_status(self, status) -> None: + await self.etcd.put('manager/status', status.value) + self.get_manager_status.cache_clear() + + @aiotools.lru_cache(maxsize=1, expire_after=2.0) + async def _get_resource_slots(self): + raw_data = await self.etcd.get_prefix_dict('config/resource_slots') + return { + SlotName(k): SlotTypes(v) for k, v in raw_data.items() + } + + async def get_resource_slots(self) -> Mapping[SlotName, SlotTypes]: + """ + Returns the system-wide known resource slots and their units. + """ + try: + ret = current_resource_slots.get() + except LookupError: + configured_slots = await self._get_resource_slots() + ret = {**INTRINSIC_SLOTS, **configured_slots} + current_resource_slots.set(ret) + return ret + + @aiotools.lru_cache(maxsize=1, expire_after=2.0) + async def _get_vfolder_types(self): + return await self.etcd.get_prefix('volumes/_types') + + async def get_vfolder_types(self) -> Sequence[str]: + """ + Returns the vfolder types currently set. One of "user" and/or "group". + If none is specified, "user" type is implicitly assumed. 
+ """ + try: + ret = current_vfolder_types.get() + except LookupError: + vf_types = await self._get_vfolder_types() + if not vf_types: + vf_types = {'user': ''} + ret = list(vf_types.keys()) + current_vfolder_types.set(ret) + return ret + + @aiotools.lru_cache(maxsize=1, expire_after=5.0) + async def get_manager_nodes_info(self): + return await self.etcd.get_prefix_dict('nodes/manager') + + @aiotools.lru_cache(maxsize=1, expire_after=2.0) + async def get_manager_status(self) -> ManagerStatus: + status = await self.etcd.get('manager/status') + if status is None: + return ManagerStatus.TERMINATED + return ManagerStatus(status) + + async def watch_manager_status(self): + async with aiotools.aclosing(self.etcd.watch('manager/status')) as agen: + async for ev in agen: + yield ev + + # TODO: refactor using contextvars in Python 3.7 so that the result is cached + # in a per-request basis. + @aiotools.lru_cache(maxsize=1, expire_after=2.0) + async def get_allowed_origins(self): + return await self.etcd.get('config/api/allow-origins') + + def get_redis_url(self, db: int = 0) -> yarl.URL: + """ + Returns a complete URL composed from the given Redis config. + """ + url = (yarl.URL('redis://host') + .with_host(str(self.data['redis']['addr'][0])) + .with_port(self.data['redis']['addr'][1]) + .with_password(self.data['redis']['password']) + / str(db)) + return url diff --git a/src/ai/backend/manager/container_registry/__init__.py b/src/ai/backend/manager/container_registry/__init__.py new file mode 100644 index 0000000000..35159be142 --- /dev/null +++ b/src/ai/backend/manager/container_registry/__init__.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from typing import Any, Mapping, Type, TYPE_CHECKING + +import yarl + +if TYPE_CHECKING: + from .base import BaseContainerRegistry + + +def get_container_registry(registry_info: Mapping[str, Any]) -> Type[BaseContainerRegistry]: + registry_url = yarl.URL(registry_info['']) + registry_type = registry_info.get('type', 'docker') + cr_cls: Type[BaseContainerRegistry] + if registry_url.host is not None and registry_url.host.endswith('.docker.io'): + from .docker import DockerHubRegistry + cr_cls = DockerHubRegistry + elif registry_type == 'docker': + from .docker import DockerRegistry_v2 + cr_cls = DockerRegistry_v2 + elif registry_type == 'harbor': + from .harbor import HarborRegistry_v1 + cr_cls = HarborRegistry_v1 + elif registry_type == 'harbor2': + from .harbor import HarborRegistry_v2 + cr_cls = HarborRegistry_v2 + else: + raise RuntimeError(f"Unsupported registry type: {registry_type}") + return cr_cls diff --git a/src/ai/backend/manager/container_registry/base.py b/src/ai/backend/manager/container_registry/base.py new file mode 100644 index 0000000000..51c7360a35 --- /dev/null +++ b/src/ai/backend/manager/container_registry/base.py @@ -0,0 +1,315 @@ +from __future__ import annotations + +import asyncio +import logging +import json +from contextvars import ContextVar +from typing import ( + Any, + AsyncIterator, + Dict, + Mapping, + Optional, + cast, +) + +import aiohttp +import aiotools +import sqlalchemy as sa +import yarl + +from abc import ABCMeta, abstractmethod + +from ai.backend.common.bgtask import ProgressReporter +from ai.backend.common.docker import ( + ImageRef, + MIN_KERNELSPEC, MAX_KERNELSPEC, + arch_name_aliases, + login as registry_login, +) +from ai.backend.common.logging import BraceStyleAdapter + +from ai.backend.manager.models.image import ImageRow, ImageType +from ai.backend.manager.models.utils import ExtendedAsyncSAEngine 
+ +log = BraceStyleAdapter(logging.getLogger(__name__)) + + +class BaseContainerRegistry(metaclass=ABCMeta): + + db: ExtendedAsyncSAEngine + registry_name: str + registry_info: Mapping[str, Any] + registry_url: yarl.URL + max_concurrency_per_registry: int + base_hdrs: Dict[str, str] + credentials: Dict[str, str] + ssl_verify: bool + + sema: ContextVar[asyncio.Semaphore] + reporter: ContextVar[Optional[ProgressReporter]] + all_updates: ContextVar[Dict[ImageRef, Dict[str, Any]]] + + def __init__( + self, + db: ExtendedAsyncSAEngine, + registry_name: str, + registry_info: Mapping[str, Any], + *, + max_concurrency_per_registry: int = 4, + ssl_verify: bool = True, + ) -> None: + self.db = db + self.registry_name = registry_name + self.registry_info = registry_info + self.registry_url = registry_info[''] + self.max_concurrency_per_registry = max_concurrency_per_registry + self.base_hdrs = { + 'Accept': 'application/vnd.docker.distribution.manifest.v2+json', + } + self.credentials = {} + self.ssl_verify = ssl_verify + self.sema = ContextVar('sema') + self.reporter = ContextVar('reporter', default=None) + self.all_updates = ContextVar('all_updates') + + async def rescan_single_registry( + self, + reporter: ProgressReporter = None, + ) -> None: + self.all_updates.set({}) + self.sema.set(asyncio.Semaphore(self.max_concurrency_per_registry)) + self.reporter.set(reporter) + username = self.registry_info['username'] + if username is not None: + self.credentials['username'] = username + password = self.registry_info['password'] + if password is not None: + self.credentials['password'] = password + non_kernel_words = ( + 'common-', 'commons-', 'base-', + 'krunner', 'builder', + 'backendai', 'geofront', + ) + ssl_ctx = None # default + if not self.registry_info['ssl-verify']: + ssl_ctx = False + connector = aiohttp.TCPConnector(ssl=ssl_ctx) + async with aiohttp.ClientSession(connector=connector) as sess: + async with aiotools.TaskGroup() as tg: + async for image in self.fetch_repositories(sess): + if not any((w in image) for w in non_kernel_words): # skip non-kernel images + tg.create_task(self._scan_image(sess, image)) + + all_updates = self.all_updates.get() + if not all_updates: + log.info('No images found in registry {0}', self.registry_url) + else: + image_identifiers = [ + (k.canonical, k.architecture) for k in all_updates.keys() + ] + async with self.db.begin_session() as session: + existing_images = await session.scalars( + sa.select(ImageRow) + .where( + sa.func.ROW(ImageRow.name, ImageRow.architecture) + .in_(image_identifiers), + ), + ) + + for image_row in existing_images: + key = image_row.image_ref + values = all_updates.get(key) + if values is None: + continue + all_updates.pop(key) + image_row.config_digest = values['config_digest'] + image_row.size_bytes = values['size_bytes'] + image_row.accelerators = values.get('accels') + image_row.labels = values['labels'] + image_row.resources = values['resources'] + + session.add_all([ + ImageRow( + name=k.canonical, + registry=k.registry, + image=k.name, + tag=k.tag, + architecture=k.architecture, + config_digest=v['config_digest'], + size_bytes=v['size_bytes'], + type=ImageType.COMPUTE, + accelerators=v.get('accels'), + labels=v['labels'], + resources=v['resources'], + ) for k, v in all_updates.items() + ]) + + async def _scan_image( + self, + sess: aiohttp.ClientSession, + image: str, + ) -> None: + rqst_args = await registry_login( + sess, + self.registry_url, + self.credentials, + f'repository:{image}:pull', + ) + 
rqst_args['headers'].update(**self.base_hdrs) + tags = [] + tag_list_url: Optional[yarl.URL] + tag_list_url = (self.registry_url / f'v2/{image}/tags/list').with_query( + {'n': '10'}, + ) + while tag_list_url is not None: + async with sess.get(tag_list_url, **rqst_args) as resp: + data = json.loads(await resp.read()) + if 'tags' in data: + # sometimes there are dangling image names in the hub. + tags.extend(data['tags']) + tag_list_url = None + next_page_link = resp.links.get('next') + if next_page_link: + next_page_url = cast(yarl.URL, next_page_link['url']) + tag_list_url = ( + self.registry_url + .with_path(next_page_url.path) + .with_query(next_page_url.query) + ) + if (reporter := self.reporter.get()) is not None: + reporter.total_progress += len(tags) + async with aiotools.TaskGroup() as tg: + for tag in tags: + tg.create_task(self._scan_tag(sess, rqst_args, image, tag)) + + async def _scan_tag( + self, + sess: aiohttp.ClientSession, + rqst_args, + image: str, + tag: str, + ) -> None: + skip_reason = None + + async def _load_manifest(_tag: str): + async with sess.get(self.registry_url / f'v2/{image}/manifests/{_tag}', + **rqst_args) as resp: + if resp.status == 404: + # ignore missing tags + # (may occur after deleting an image from the docker hub) + return {} + resp.raise_for_status() + data = await resp.json() + + if data['mediaType'] == 'application/vnd.docker.distribution.manifest.list.v2+json': + # recursively call _load_manifests with detected arch and corresponding image digest + ret = {} + for m in data['manifests']: + ret.update( + await _load_manifest( + m['digest'], + ), + ) + if (reporter := self.reporter.get()) is not None: + reporter.total_progress += len(ret) - 1 + return ret + + config_digest = data['config']['digest'] + size_bytes = (sum(layer['size'] for layer in data['layers']) + + data['config']['size']) + async with sess.get(self.registry_url / f'v2/{image}/blobs/{config_digest}', + **rqst_args) as resp: + resp.raise_for_status() + data = json.loads(await resp.read()) + architecture = arch_name_aliases.get(data['architecture'], data['architecture']) + labels = {} + if 'container_config' in data: + raw_labels = data['container_config'].get('Labels') + if raw_labels: + labels.update(raw_labels) + else: + log.warn('label not found on image {}:{}/{}', image, _tag, architecture) + else: + raw_labels = data['config'].get('Labels') + if raw_labels: + labels.update(raw_labels) + else: + log.warn('label not found on image {}:{}/{}', image, _tag, architecture) + return { + architecture: { + 'size': size_bytes, + 'labels': labels, + 'digest': config_digest, + }, + } + + async with self.sema.get(): + manifests = await _load_manifest(tag) + + if len(manifests.keys()) == 0: + log.warning('Skipped image - {}:{} (missing/deleted)', image, tag) + progress_msg = f"Skipped {image}:{tag} (missing/deleted)" + if (reporter := self.reporter.get()) is not None: + await reporter.update(1, message=progress_msg) + + idx = 0 + for architecture, manifest in manifests.items(): + idx += 1 + if manifest is None: + skip_reason = 'missing/deleted' + continue + + try: + size_bytes = manifest['size'] + labels = manifest['labels'] + config_digest = manifest['digest'] + if 'ai.backend.kernelspec' not in labels: + # Skip non-Backend.AI kernel images + skip_reason = architecture + ": missing kernelspec" + continue + if not (MIN_KERNELSPEC <= int(labels['ai.backend.kernelspec']) <= MAX_KERNELSPEC): + # Skip unsupported kernelspec images + skip_reason = architecture + ": unsupported kernelspec" + continue 
+ + update_key = ImageRef( + f'{self.registry_name}/{image}:{tag}', + [self.registry_name], + architecture, + ) + updates = { + 'config_digest': config_digest, + 'size_bytes': size_bytes, + 'labels': labels, + } + accels = labels.get('ai.backend.accelerators') + if accels: + updates['accels'] = accels + + resources = {} + res_prefix = 'ai.backend.resource.min.' + for k, v in filter(lambda pair: pair[0].startswith(res_prefix), + labels.items()): + res_key = k[len(res_prefix):] + resources[res_key] = {'min': v} + updates['resources'] = resources + self.all_updates.get().update({ + update_key: updates, + }) + finally: + if skip_reason: + log.warning('Skipped image - {}:{}/{} ({})', image, tag, architecture, skip_reason) + progress_msg = f"Skipped {image}:{tag}/{architecture} ({skip_reason})" + else: + log.info('Updated image - {0}:{1}/{2}', image, tag, architecture) + progress_msg = f"Updated {image}:{tag}/{architecture}" + if (reporter := self.reporter.get()) is not None: + await reporter.update(1, message=progress_msg) + + @abstractmethod + async def fetch_repositories( + self, + sess: aiohttp.ClientSession, + ) -> AsyncIterator[str]: + yield "" diff --git a/src/ai/backend/manager/container_registry/docker.py b/src/ai/backend/manager/container_registry/docker.py new file mode 100644 index 0000000000..03e08cab14 --- /dev/null +++ b/src/ai/backend/manager/container_registry/docker.py @@ -0,0 +1,92 @@ +import json +import logging +from typing import AsyncIterator, Optional, cast + +import aiohttp +import yarl + +from ai.backend.common.docker import ( + login as registry_login, +) +from ai.backend.common.logging import BraceStyleAdapter + +from .base import BaseContainerRegistry + +log = BraceStyleAdapter(logging.getLogger(__name__)) + + +class DockerHubRegistry(BaseContainerRegistry): + + async def fetch_repositories( + self, + sess: aiohttp.ClientSession, + ) -> AsyncIterator[str]: + # We need some special treatment for the Docker Hub. + params = {'page_size': '30'} + username = self.registry_info['username'] + hub_url = yarl.URL('https://hub.docker.com') + repo_list_url: Optional[yarl.URL] + repo_list_url = hub_url / f'v2/repositories/{username}/' + while repo_list_url is not None: + async with sess.get(repo_list_url, params=params) as resp: + if resp.status == 200: + data = await resp.json() + for item in data['results']: + # skip legacy images + if item['name'].startswith('kernel-'): + continue + yield f"{username}/{item['name']}" + else: + log.error('Failed to fetch repository list from {0} ' + '(status={1})', + repo_list_url, resp.status) + break + repo_list_url = None + next_page_link = data.get('next', None) + if next_page_link: + next_page_url = yarl.URL(next_page_link) + repo_list_url = ( + hub_url + .with_path(next_page_url.path) + .with_query(next_page_url.query) + ) + + +class DockerRegistry_v2(BaseContainerRegistry): + + async def fetch_repositories( + self, + sess: aiohttp.ClientSession, + ) -> AsyncIterator[str]: + # The credential should have the catalog search privilege. 
+ rqst_args = await registry_login( + sess, + self.registry_url, + self.credentials, + 'registry:catalog:*', + ) + catalog_url: Optional[yarl.URL] + catalog_url = (self.registry_url / 'v2/_catalog').with_query( + {'n': '30'}, + ) + while catalog_url is not None: + async with sess.get(catalog_url, **rqst_args) as resp: + if resp.status == 200: + data = json.loads(await resp.read()) + for item in data['repositories']: + yield item + log.debug('found {} repositories', len(data['repositories'])) + else: + log.warning('Docker registry {0} does not allow/support ' + 'catalog search. (status={1})', + self.registry_url, resp.status) + break + catalog_url = None + next_page_link = resp.links.get('next') + if next_page_link: + next_page_url = cast(yarl.URL, next_page_link['url']) + catalog_url = ( + self.registry_url + .with_path(next_page_url.path) + .with_query(next_page_url.query) + ) diff --git a/src/ai/backend/manager/container_registry/harbor.py b/src/ai/backend/manager/container_registry/harbor.py new file mode 100644 index 0000000000..553a7ac4a2 --- /dev/null +++ b/src/ai/backend/manager/container_registry/harbor.py @@ -0,0 +1,109 @@ +import logging +from typing import AsyncIterator, Optional, cast + +import aiohttp +import yarl + +from ai.backend.common.logging import BraceStyleAdapter + +from .base import BaseContainerRegistry + +log = BraceStyleAdapter(logging.getLogger(__name__)) + + +class HarborRegistry_v1(BaseContainerRegistry): + + async def fetch_repositories( + self, + sess: aiohttp.ClientSession, + ) -> AsyncIterator[str]: + api_url = self.registry_url / 'api' + registry_projects = self.registry_info['project'] + rqst_args = {} + if self.credentials: + rqst_args['auth'] = aiohttp.BasicAuth( + self.credentials['username'], + self.credentials['password'], + ) + project_list_url: Optional[yarl.URL] + project_list_url = (api_url / 'projects').with_query( + {'page_size': '30'}, + ) + project_ids = [] + while project_list_url is not None: + async with sess.get(project_list_url, allow_redirects=False, **rqst_args) as resp: + projects = await resp.json() + for item in projects: + if item['name'] in registry_projects: + project_ids.append(item['project_id']) + project_list_url = None + next_page_link = resp.links.get('next') + if next_page_link: + next_page_url = cast(yarl.URL, next_page_link['url']) + project_list_url = ( + self.registry_url + .with_path(next_page_url.path) + .with_query(next_page_url.query) + ) + if not project_ids: + log.warning('There is no given project.') + return + repo_list_url: Optional[yarl.URL] + for project_id in project_ids: + repo_list_url = (api_url / 'repositories').with_query( + {'project_id': project_id, 'page_size': '30'}, + ) + while repo_list_url is not None: + async with sess.get(repo_list_url, allow_redirects=False, **rqst_args) as resp: + items = await resp.json() + repos = [item['name'] for item in items] + for item in repos: + yield item + repo_list_url = None + next_page_link = resp.links.get('next') + if next_page_link: + next_page_url = cast(yarl.URL, next_page_link['url']) + repo_list_url = ( + self.registry_url + .with_path(next_page_url.path) + .with_query(next_page_url.query) + ) + + +class HarborRegistry_v2(BaseContainerRegistry): + + async def fetch_repositories( + self, + sess: aiohttp.ClientSession, + ) -> AsyncIterator[str]: + api_url = self.registry_url / 'api' / 'v2.0' + registry_projects = self.registry_info['project'] + rqst_args = {} + if self.credentials: + rqst_args['auth'] = aiohttp.BasicAuth( + self.credentials['username'], + 
self.credentials['password'], + ) + repo_list_url: Optional[yarl.URL] + for project_name in registry_projects: + repo_list_url = (api_url / 'projects' / project_name / 'repositories').with_query( + {'page_size': '30'}, + ) + while repo_list_url is not None: + async with sess.get(repo_list_url, allow_redirects=False, **rqst_args) as resp: + items = await resp.json() + if isinstance(items, dict) and (errors := items.get('errors', [])): + raise RuntimeError(f"failed to fetch repositories in project {project_name}", + errors[0]['code'], errors[0]['message']) + repos = [item['name'] for item in items] + for item in repos: + yield item + repo_list_url = None + next_page_link = resp.links.get('next') + if next_page_link: + next_page_url = cast(yarl.URL, next_page_link['url']) + repo_list_url = ( + self.registry_url + .with_path(next_page_url.path) + .with_query(next_page_url.query) + ) diff --git a/src/ai/backend/manager/defs.py b/src/ai/backend/manager/defs.py new file mode 100644 index 0000000000..17935229d2 --- /dev/null +++ b/src/ai/backend/manager/defs.py @@ -0,0 +1,51 @@ +""" +Common definitions/constants used throughout the manager. +""" + +import enum +import platform +import re +from typing import Final + +from ai.backend.common.docker import arch_name_aliases +from ai.backend.common.types import SlotName, SlotTypes + +INTRINSIC_SLOTS: Final = { + SlotName('cpu'): SlotTypes('count'), + SlotName('mem'): SlotTypes('bytes'), +} + +MANAGER_ARCH = platform.machine().lower().strip() + + +DEFAULT_IMAGE_ARCH = arch_name_aliases.get(MANAGER_ARCH, MANAGER_ARCH) +# DEFAULT_IMAGE_ARCH = 'x86_64' + +# The default container role name for multi-container sessions +DEFAULT_ROLE: Final = "main" + +_RESERVED_VFOLDER_PATTERNS = [r'^\.[a-z0-9]+rc$', r'^\.[a-z0-9]+_profile$'] +RESERVED_DOTFILES = ['.terminfo', '.jupyter', '.ssh', '.ssh/authorized_keys', '.local', '.config'] +RESERVED_VFOLDERS = ['.terminfo', '.jupyter', '.tmux.conf', '.ssh', '/bin', '/boot', '/dev', '/etc', + '/lib', '/lib64', '/media', '/mnt', '/opt', '/proc', '/root', '/run', '/sbin', + '/srv', '/sys', '/tmp', '/usr', '/var', '/home'] +RESERVED_VFOLDER_PATTERNS = [re.compile(x) for x in _RESERVED_VFOLDER_PATTERNS] + +# Redis database IDs depending on purposes +REDIS_STAT_DB: Final = 0 +REDIS_RLIM_DB: Final = 1 +REDIS_LIVE_DB: Final = 2 +REDIS_IMAGE_DB: Final = 3 +REDIS_STREAM_DB: Final = 4 + + +# The unique identifiers for distributed locks. +# To be used with PostgreSQL advisory locks, the values are defined as integers. +class LockID(enum.IntEnum): + LOCKID_TEST = 42 + LOCKID_SCHEDULE = 91 + LOCKID_PREPARE = 92 + LOCKID_SCHEDULE_TIMER = 191 + LOCKID_PREPARE_TIMER = 192 + LOCKID_LOG_CLEANUP_TIMER = 195 + LOCKID_IDLE_CHECK_TIMER = 196 diff --git a/src/ai/backend/manager/exceptions.py b/src/ai/backend/manager/exceptions.py new file mode 100644 index 0000000000..dd98c35929 --- /dev/null +++ b/src/ai/backend/manager/exceptions.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +from typing import ( + Any, + List, + Tuple, + TypedDict, + TYPE_CHECKING, +) + +from aiotools import TaskGroupError + +if TYPE_CHECKING: + from ai.backend.common.types import AgentId + + +class InvalidArgument(Exception): + """ + An internal exception class to represent invalid arguments in internal APIs. + This is wrapped as InvalidAPIParameters in web request handlers. + """ + pass + + +class AgentError(RuntimeError): + """ + A dummy exception class to distinguish agent-side errors passed via + agent rpc calls. 
+ + It carries two args tuple: the exception type and exception arguments from + the agent. + """ + + __slots__ = ( + 'agent_id', 'exc_name', 'exc_repr', 'exc_tb', + ) + + def __init__( + self, + agent_id: AgentId, + exc_name: str, + exc_repr: str, + exc_args: Tuple[Any, ...], + exc_tb: str = None, + ) -> None: + super().__init__(agent_id, exc_name, exc_repr, exc_args, exc_tb) + self.agent_id = agent_id + self.exc_name = exc_name + self.exc_repr = exc_repr + self.exc_args = exc_args + self.exc_tb = exc_tb + + +class MultiAgentError(TaskGroupError): + """ + An exception that is a collection of multiple errors from multiple agents. + """ + + +class ErrorDetail(TypedDict, total=False): + src: str + name: str + repr: str + agent_id: str # optional + collection: List[Any] # optional; currently mypy cannot handle recursive types + + +class ErrorStatusInfo(TypedDict): + error: ErrorDetail + + +def convert_to_status_data(e: Exception, is_debug: bool = False) -> ErrorStatusInfo: + if isinstance(e, MultiAgentError): + data = ErrorStatusInfo( + error={ + "src": "agent", + "name": "MultiAgentError", + "repr": f"MultiAgentError({len(e.__errors__)})", + "collection": [ + convert_to_status_data(sub_error, is_debug)['error'] + for sub_error in + e.__errors__ + ], + }, + ) + return data + elif isinstance(e, AgentError): + data = ErrorStatusInfo( + error={ + "src": "agent", + "name": e.exc_name, + "repr": e.exc_repr, + }, + ) + if is_debug: + data["error"]["agent_id"] = e.agent_id + return data + return ErrorStatusInfo( + error={ + "src": "other", + "name": e.__class__.__name__, + "repr": repr(e), + }, + ) diff --git a/src/ai/backend/manager/idle.py b/src/ai/backend/manager/idle.py new file mode 100644 index 0000000000..a759ffd89a --- /dev/null +++ b/src/ai/backend/manager/idle.py @@ -0,0 +1,637 @@ +from __future__ import annotations + +from collections import defaultdict +from decimal import Decimal +import enum +import logging +import math +from abc import ABCMeta, abstractmethod +from datetime import datetime, timedelta +from typing import ( + Any, + ClassVar, + DefaultDict, + List, + Mapping, + MutableMapping, + Sequence, + Set, + Type, + TYPE_CHECKING, + Union, +) + +import sqlalchemy as sa +import trafaret as t +from dateutil.tz import tzutc +from sqlalchemy.engine import Row + +import ai.backend.common.validators as tx +from ai.backend.common import msgpack, redis as redis_helper +from ai.backend.common.distributed import GlobalTimer +from ai.backend.common.events import ( + AbstractEvent, + DoIdleCheckEvent, + DoTerminateSessionEvent, + EventDispatcher, + EventHandler, + EventProducer, + ExecutionCancelledEvent, + ExecutionFinishedEvent, + ExecutionStartedEvent, + ExecutionTimeoutEvent, + SessionStartedEvent, +) +from ai.backend.common.logging import BraceStyleAdapter +from ai.backend.common.types import AccessKey, RedisConnectionInfo, SessionTypes +from ai.backend.common.utils import nmget + +from .defs import DEFAULT_ROLE, REDIS_LIVE_DB, REDIS_STAT_DB, LockID +from .models import kernels, keypair_resource_policies, keypairs +from .models.kernel import LIVE_STATUS +from .types import DistributedLockFactory + +if TYPE_CHECKING: + from .config import SharedConfig + from ai.backend.common.types import AgentId, KernelId, SessionId + from sqlalchemy.ext.asyncio import ( + AsyncConnection as SAConnection, + ) + from .models.utils import ExtendedAsyncSAEngine as SAEngine + +log = BraceStyleAdapter(logging.getLogger("ai.backend.manager.idle")) + + +class AppStreamingStatus(enum.Enum): + 
NO_ACTIVE_CONNECTIONS = 0 + HAS_ACTIVE_CONNECTIONS = 1 + + +class ThresholdOperator(enum.Enum): + AND = 'and' + OR = 'or' + + +class IdleCheckerHost: + + check_interval: ClassVar[float] = 15.0 + + def __init__( + self, + db: SAEngine, + shared_config: SharedConfig, + event_dispatcher: EventDispatcher, + event_producer: EventProducer, + lock_factory: DistributedLockFactory, + ) -> None: + self._checkers: list[BaseIdleChecker] = [] + self._frozen = False + self._db = db + self._shared_config = shared_config + self._event_dispatcher = event_dispatcher + self._event_producer = event_producer + self._lock_factory = lock_factory + self._redis_live = redis_helper.get_redis_object( + self._shared_config.data['redis'], + db=REDIS_LIVE_DB, + ) + self._redis_stat = redis_helper.get_redis_object( + self._shared_config.data['redis'], + db=REDIS_STAT_DB, + ) + + def add_checker(self, checker: BaseIdleChecker): + if self._frozen: + raise RuntimeError("Cannot add a new idle checker after the idle checker host is frozen.") + self._checkers.append(checker) + + async def start(self) -> None: + self._frozen = True + for checker in self._checkers: + raw_config = await self._shared_config.etcd.get_prefix_dict( + f"config/idle/checkers/{checker.name}", + ) + await checker.populate_config(raw_config or {}) + self.timer = GlobalTimer( + self._lock_factory(LockID.LOCKID_IDLE_CHECK_TIMER, self.check_interval), + self._event_producer, + lambda: DoIdleCheckEvent(), + self.check_interval, + ) + self._evh_idle_check = self._event_dispatcher.consume( + DoIdleCheckEvent, + None, + self._do_idle_check, + ) + await self.timer.join() + + async def shutdown(self) -> None: + for checker in self._checkers: + await checker.aclose() + await self.timer.leave() + self._event_dispatcher.unconsume(self._evh_idle_check) + await self._redis_stat.close() + await self._redis_live.close() + + async def update_app_streaming_status( + self, + session_id: SessionId, + status: AppStreamingStatus, + ) -> None: + for checker in self._checkers: + await checker.update_app_streaming_status(session_id, status) + + async def _do_idle_check( + self, + context: None, + source: AgentId, + event: DoIdleCheckEvent, + ) -> None: + log.debug("do_idle_check(): triggered") + policy_cache: dict[AccessKey, Row] = {} + async with self._db.begin_readonly() as conn: + query = ( + sa.select([kernels]) + .select_from(kernels) + .where( + (kernels.c.status.in_(LIVE_STATUS)) & + (kernels.c.cluster_role == DEFAULT_ROLE), + ) + ) + result = await conn.execute(query) + rows = result.fetchall() + for session in rows: + policy = policy_cache.get(session["access_key"], None) + if policy is None: + query = ( + sa.select([keypair_resource_policies]) + .select_from( + sa.join( + keypairs, + keypair_resource_policies, + keypair_resource_policies.c.name == keypairs.c.resource_policy, + ), + ) + .where(keypairs.c.access_key == session["access_key"]) + ) + result = await conn.execute(query) + policy = result.first() + assert policy is not None + policy_cache[session["access_key"]] = policy + for checker in self._checkers: + if not (await checker.check_session(session, conn, policy)): + log.info( + "The {} idle checker triggered termination of s:{}", + checker.name, session['id'], + ) + await self._event_producer.produce_event( + DoTerminateSessionEvent(session["id"], f"idle-{checker.name}"), + ) + # If any one of checkers decided to terminate the session, + # we can skip over remaining checkers. 
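+                        # (For example, a keypair policy with idle_timeout=600 would let
+                        #  the "timeout" checker trigger this branch after ~10 minutes of
+                        #  inactivity; the number 600 is purely illustrative.)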
+ break + + +class BaseIdleChecker(metaclass=ABCMeta): + + name: ClassVar[str] = "base" + + def __init__( + self, + event_dispatcher: EventDispatcher, + redis_live: RedisConnectionInfo, + redis_stat: RedisConnectionInfo, + ) -> None: + self._event_dispatcher = event_dispatcher + self._redis_live = redis_live + self._redis_stat = redis_stat + + async def aclose(self) -> None: + pass + + @abstractmethod + async def populate_config(self, config: Mapping[str, Any]) -> None: + raise NotImplementedError + + async def update_app_streaming_status( + self, + session_id: SessionId, + status: AppStreamingStatus, + ) -> None: + pass + + @abstractmethod + async def check_session(self, session: Row, dbconn: SAConnection, policy: Row) -> bool: + """ + Return True if the session should be kept alive or + return False if the session should be terminated. + """ + return True + + +class TimeoutIdleChecker(BaseIdleChecker): + """ + Checks the idleness of a session by the elapsed time since last used. + The usage means processing of any computation requests, such as + query/batch-mode code execution and having active service-port connections. + """ + + name: ClassVar[str] = "timeout" + + _config_iv = t.Dict( + { + t.Key("threshold", default="10m"): tx.TimeDuration(), + }, + ).allow_extra("*") + + idle_timeout: timedelta + _evhandlers: List[EventHandler[None, AbstractEvent]] + + def __init__( + self, + event_dispatcher: EventDispatcher, + redis_live: RedisConnectionInfo, + redis_stat: RedisConnectionInfo, + ) -> None: + super().__init__(event_dispatcher, redis_live, redis_stat) + d = self._event_dispatcher + self._evhandlers = [ + d.consume(SessionStartedEvent, None, self._session_started_cb), # type: ignore + d.consume(ExecutionStartedEvent, None, self._execution_started_cb), # type: ignore + d.consume(ExecutionFinishedEvent, None, self._execution_exited_cb), # type: ignore + d.consume(ExecutionTimeoutEvent, None, self._execution_exited_cb), # type: ignore + d.consume(ExecutionCancelledEvent, None, self._execution_exited_cb), # type: ignore + ] + + async def aclose(self) -> None: + for _evh in self._evhandlers: + self._event_dispatcher.unconsume(_evh) + + async def populate_config(self, raw_config: Mapping[str, Any]) -> None: + config = self._config_iv.check(raw_config) + self.idle_timeout = config["threshold"] + log.info( + "TimeoutIdleChecker: default idle_timeout = {0:,} seconds", + self.idle_timeout.total_seconds(), + ) + + async def update_app_streaming_status( + self, + session_id: SessionId, + status: AppStreamingStatus, + ) -> None: + if status == AppStreamingStatus.HAS_ACTIVE_CONNECTIONS: + await self._disable_timeout(session_id) + elif status == AppStreamingStatus.NO_ACTIVE_CONNECTIONS: + await self._update_timeout(session_id) + + async def _disable_timeout(self, session_id: SessionId) -> None: + log.debug(f"TimeoutIdleChecker._disable_timeout({session_id})") + await redis_helper.execute( + self._redis_live, + lambda r: r.set( + f"session.{session_id}.last_access", "0", xx=True, + ), + ) + + async def _update_timeout(self, session_id: SessionId) -> None: + log.debug(f"TimeoutIdleChecker._update_timeout({session_id})") + t = await redis_helper.execute(self._redis_live, lambda r: r.time()) + t = t[0] + (t[1] / (10**6)) + await redis_helper.execute( + self._redis_live, + lambda r: r.set( + f"session.{session_id}.last_access", + f"{t:.06f}", + ex=max(86400, int(self.idle_timeout.total_seconds() * 2)), + ), + ) + + async def _session_started_cb( + self, + context: None, + source: AgentId, + event: 
SessionStartedEvent, + ) -> None: + await self._update_timeout(event.session_id) + + async def _execution_started_cb( + self, + context: None, + source: AgentId, + event: ExecutionStartedEvent, + ) -> None: + await self._disable_timeout(event.session_id) + + async def _execution_exited_cb( + self, + context: None, + source: AgentId, + event: ExecutionFinishedEvent | ExecutionTimeoutEvent | ExecutionCancelledEvent, + ) -> None: + await self._update_timeout(event.session_id) + + async def check_session(self, session: Row, dbconn: SAConnection, policy: Row) -> bool: + session_id = session["id"] + if session["session_type"] == SessionTypes.BATCH: + return True + active_streams = await redis_helper.execute( + self._redis_live, + lambda r: r.zcount( + f"session.{session_id}.active_app_connections", + float('-inf'), float('+inf'), + ), + ) + if active_streams is not None and active_streams > 0: + return True + t = await redis_helper.execute(self._redis_live, lambda r: r.time()) + t = t[0] + (t[1] / (10**6)) + raw_last_access = \ + await redis_helper.execute( + self._redis_live, + lambda r: r.get(f"session.{session_id}.last_access"), + ) + if raw_last_access is None or raw_last_access == "0": + return True + last_access = float(raw_last_access) + # serves as the default fallback if keypair resource policy's idle_timeout is "undefined" + idle_timeout = self.idle_timeout.total_seconds() + # setting idle_timeout: + # - zero/inf means "infinite" + # - negative means "undefined" + if policy["idle_timeout"] >= 0: + idle_timeout = float(policy["idle_timeout"]) + if ( + (idle_timeout <= 0) + or (math.isinf(idle_timeout) and idle_timeout > 0) + or (t - last_access <= idle_timeout) + ): + return True + return False + + +class SessionLifetimeChecker(BaseIdleChecker): + + name: ClassVar[str] = "session_lifetime" + + async def populate_config(self, config: Mapping[str, Any]) -> None: + pass + + async def check_session(self, session: Row, dbconn: SAConnection, policy: Row) -> bool: + now = await dbconn.scalar(sa.select(sa.func.now())) + if policy["max_session_lifetime"] > 0: + # TODO: once per-status time tracking is implemented, let's change created_at + # to the timestamp when the session entered PREPARING status. + if now - session["created_at"] >= timedelta(seconds=policy["max_session_lifetime"]): + return False + return True + + +class UtilizationIdleChecker(BaseIdleChecker): + """ + Checks the idleness of a session by the average utilization of compute devices. 
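+
+    As a rough illustration, the etcd keys consumed by this checker (read from
+    the "config/idle/checkers/utilization" prefix) look like the following;
+    the threshold numbers are made-up examples, while the other values match
+    the defaults declared in _config_iv below:
+
+        time-window = "10m"
+        initial-grace-period = "5m"
+        thresholds-check-operator = "and"
+        resource-thresholds/cpu_util/average = 30
+        resource-thresholds/mem/average = 20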
+ """ + + name: ClassVar[str] = "utilization" + + _config_iv = t.Dict( + { + t.Key("time-window", default="10m"): tx.TimeDuration(), + t.Key("initial-grace-period", default="5m"): tx.TimeDuration(), + t.Key("thresholds-check-operator", default=ThresholdOperator.AND): + tx.Enum(ThresholdOperator), + t.Key("resource-thresholds"): t.Dict( + { + t.Key("cpu_util", default=None): t.Null | t.Dict({t.Key("average"): t.Float}), + t.Key("mem", default=None): t.Null | t.Dict({t.Key("average"): t.Float}), + t.Key("cuda_util", default=None): t.Null | t.Dict({t.Key("average"): t.Float}), + t.Key("cuda_mem", default=None): t.Null | t.Dict({t.Key("average"): t.Float}), + }, + ), + }, + ).allow_extra("*") + + resource_thresholds: MutableMapping[str, Union[int, float, Decimal]] + thresholds_check_operator: str + time_window: timedelta + initial_grace_period: timedelta + _evhandlers: List[EventHandler[None, AbstractEvent]] + slot_resource_map: Mapping[str, Set[str]] = { + 'cpu': {'cpu_util'}, + 'mem': {'mem'}, + 'cuda': {'cuda_util', 'cuda_mem'}, + } + + async def populate_config(self, raw_config: Mapping[str, Any]) -> None: + config = self._config_iv.check(raw_config) + self.resource_thresholds = { + k: nmget(v, 'average') for k, v in config.get('resource-thresholds').items() + } + self.thresholds_check_operator = config.get("thresholds-check-operator") + self.time_window = config.get("time-window") + self.initial_grace_period = config.get("initial-grace-period") + + thresholds_log = " ".join([f"{k}({threshold})," for k, + threshold in self.resource_thresholds.items()]) + log.info( + f"UtilizationIdleChecker(%): {thresholds_log} " + f"thresholds-check-operator(\"{self.thresholds_check_operator}\"), " + f"time-window({self.time_window.total_seconds()}s)", + ) + + async def check_session(self, session: Row, dbconn: SAConnection, policy: Row) -> bool: + session_id = session["id"] + interval = IdleCheckerHost.check_interval + window_size = int(self.time_window.total_seconds() / interval) + occupied_slots = session["occupied_slots"] + unavailable_resources: Set[str] = set() + + util_series_key = f"session.{session_id}.util_series" + util_last_collected_key = f"session.{session_id}.util_last_collected" + + # Wait until the time "interval" is passed after the last udpated time. + t = await redis_helper.execute(self._redis_live, lambda r: r.time()) + t = t[0] + (t[1] / (10**6)) + raw_util_last_collected = await redis_helper.execute( + self._redis_live, + lambda r: r.get(util_last_collected_key), + ) + util_last_collected = float(raw_util_last_collected) if raw_util_last_collected else 0 + if t - util_last_collected < interval: + return True + + # Respect initial grace period (no termination of the session) + now = datetime.now(tzutc()) + if now - session["created_at"] <= self.initial_grace_period: + return True + + # Merge same type of (exclusive) resources as a unique resource with the values added. + # Example: {cuda.device: 0, cuda.shares: 0.5} -> {cuda: 0.5}. + unique_res_map: DefaultDict[str, Any] = defaultdict(Decimal) + for k, v in occupied_slots.items(): + unique_key = k.split('.')[0] + unique_res_map[unique_key] += v + + # Do not take into account unallocated resources. For example, do not garbage collect + # a session without GPU even if cuda_util is configured in resource-thresholds. + for slot in unique_res_map: + if unique_res_map[slot] == 0: + unavailable_resources.update(self.slot_resource_map[slot]) + + # Respect idle_timeout, from keypair resource policy, over time_window. 
+ if policy["idle_timeout"] >= 0: + window_size = int(float(policy["idle_timeout"]) / interval) + if (window_size <= 0) or (math.isinf(window_size) and window_size > 0): + return True + + # Get current utilization data from all containers of the session. + if session["cluster_size"] > 1: + query = ( + sa.select([kernels.c.id]) + .select_from(kernels) + .where( + (kernels.c.session_id == session_id) & + (kernels.c.status.in_(LIVE_STATUS)), + ) + ) + result = await dbconn.execute(query) + rows = result.fetchall() + kernel_ids = [k["id"] for k in rows] + else: + kernel_ids = [session_id] + current_utilizations = await self.get_current_utilization(kernel_ids, occupied_slots) + if current_utilizations is None: + return True + + # Update utilization time-series data. + not_enough_data = False + raw_util_series = await redis_helper.execute(self._redis_live, lambda r: r.get(util_series_key)) + + try: + util_series = msgpack.unpackb(raw_util_series, use_list=True) + except TypeError: + util_series = {k: [] for k in self.resource_thresholds.keys()} + + for k in util_series: + util_series[k].append(current_utilizations[k]) + if len(util_series[k]) > window_size: + util_series[k].pop(0) + else: + not_enough_data = True + await redis_helper.execute( + self._redis_live, + lambda r: r.set( + util_series_key, + msgpack.packb(util_series), + ex=max(86400, int(self.time_window.total_seconds() * 2)), + ), + ) + await redis_helper.execute( + self._redis_live, + lambda r: r.set( + util_last_collected_key, + f"{t:.06f}", + ex=max(86400, int(self.time_window.total_seconds() * 2)), + ), + ) + + if not_enough_data: + return True + + # Check over-utilized (not to be collected) resources. + avg_utils = {k: sum(v) / len(v) for k, v in util_series.items()} + sufficiently_utilized = { + k: (float(avg_utils[k]) >= float(threshold)) + for k, threshold in self.resource_thresholds.items() + if (threshold is not None) and (k not in unavailable_resources) + } + + if len(sufficiently_utilized) < 1: + check_result = True + elif self.thresholds_check_operator == ThresholdOperator.OR: + check_result = all(sufficiently_utilized.values()) + else: # "and" operation is the default + check_result = any(sufficiently_utilized.values()) + if not check_result: + log.info("utilization timeout: {} ({}, {})", + session_id, avg_utils, self.thresholds_check_operator) + return check_result + + async def get_current_utilization( + self, + kernel_ids: Sequence[KernelId], + occupied_slots: Mapping[str, Any], + ) -> Mapping[str, float] | None: + """ + Return the current utilization key-value pairs of multiple kernels, possibly the + components of a cluster session. If there are multiple kernel_ids, this method + will return the averaged values over the kernels for each utilization. + """ + try: + utilizations = {k: 0.0 for k in self.resource_thresholds.keys()} + live_stat = {} + for kernel_id in kernel_ids: + raw_live_stat = await redis_helper.execute( + self._redis_stat, + lambda r: r.get(str(kernel_id)), + ) + live_stat = msgpack.unpackb(raw_live_stat) + kernel_utils = { + k: float(nmget(live_stat, f"{k}.pct", 0.0)) + for k in self.resource_thresholds.keys() + } + + utilizations = { + k: utilizations[k] + kernel_utils[k] + for k in self.resource_thresholds.keys() + } + utilizations = { + k: utilizations[k] / len(kernel_ids) + for k in self.resource_thresholds.keys() + } + + # NOTE: Manual calculation of mem utilization. 
+ # mem.capacity does not report total amount of memory allocated to + # the container, and mem.pct always report >90% even when nothing is + # executing. So, we just replace it with the value of occupied slot. + mem_slots = float(occupied_slots.get('mem', 0)) + mem_current = float(nmget(live_stat, "mem.current", 0.0)) + utilizations['mem'] = mem_current / mem_slots * 100 if mem_slots > 0 else 0 + return utilizations + except Exception as e: + log.warning("Unable to collect utilization for idleness check", exc_info=e) + return None + + +checker_registry: Mapping[str, Type[BaseIdleChecker]] = { + TimeoutIdleChecker.name: TimeoutIdleChecker, + UtilizationIdleChecker.name: UtilizationIdleChecker, +} + + +async def init_idle_checkers( + db: SAEngine, + shared_config: SharedConfig, + event_dispatcher: EventDispatcher, + event_producer: EventProducer, + lock_factory: DistributedLockFactory, +) -> IdleCheckerHost: + """ + Create an instance of session idleness checker + from the given configuration and using the given event dispatcher. + """ + checker_host = IdleCheckerHost(db, shared_config, event_dispatcher, event_producer, lock_factory) + checker_init_args = (event_dispatcher, checker_host._redis_live, checker_host._redis_stat) + log.info("Initializing idle checker: session_lifetime") + checker_host.add_checker(SessionLifetimeChecker(*checker_init_args)) # enabled by default + enabled_checkers = await shared_config.etcd.get("config/idle/enabled") + if enabled_checkers: + for checker_name in enabled_checkers.split(","): + checker_cls = checker_registry.get(checker_name, None) + if checker_cls is None: + log.warning("ignoring an unknown idle checker name: {}", checker_name) + continue + log.info("Initializing idle checker: {}", checker_name) + checker_instance = checker_cls(*checker_init_args) + checker_host.add_checker(checker_instance) + return checker_host diff --git a/src/ai/backend/manager/models/__init__.py b/src/ai/backend/manager/models/__init__.py new file mode 100644 index 0000000000..d464244d2f --- /dev/null +++ b/src/ai/backend/manager/models/__init__.py @@ -0,0 +1,52 @@ +from .base import metadata + +from . import agent as _agent +from . import domain as _domain +from . import group as _group +from . import image as _image +from . import kernel as _kernel +from . import keypair as _keypair +from . import user as _user +from . import vfolder as _vfolder +from . import dotfile as _dotfile +from . import resource_policy as _rpolicy +from . import resource_preset as _rpreset +from . import scaling_group as _sgroup +from . import session_template as _sessiontemplate +from . import storage as _storage +from . 
import error_logs as _errorlogs + +__all__ = ( + 'metadata', + *_agent.__all__, + *_domain.__all__, + *_group.__all__, + *_image.__all__, + *_kernel.__all__, + *_keypair.__all__, + *_user.__all__, + *_vfolder.__all__, + *_dotfile.__all__, + *_rpolicy.__all__, + *_rpreset.__all__, + *_sgroup.__all__, + *_sessiontemplate.__all__, + *_storage.__all__, + *_errorlogs.__all__, +) + +from .agent import * # noqa +from .domain import * # noqa +from .group import * # noqa +from .image import * # noqa +from .kernel import * # noqa +from .keypair import * # noqa +from .user import * # noqa +from .vfolder import * # noqa +from .dotfile import * # noqa +from .resource_policy import * # noqa +from .resource_preset import * # noqa +from .scaling_group import * # noqa +from .session_template import * # noqa +from .storage import * # noqa +from .error_logs import * # noqa diff --git a/src/ai/backend/manager/models/agent.py b/src/ai/backend/manager/models/agent.py new file mode 100644 index 0000000000..41e6aa70c3 --- /dev/null +++ b/src/ai/backend/manager/models/agent.py @@ -0,0 +1,398 @@ +from __future__ import annotations + +import enum +from typing import ( + Any, + Dict, + Mapping, + Sequence, + TYPE_CHECKING, +) + +from dateutil.parser import parse as dtparse +import graphene +from graphene.types.datetime import DateTime as GQLDateTime +import sqlalchemy as sa +from sqlalchemy.sql.expression import true +from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection +from sqlalchemy.engine.row import Row +from sqlalchemy.dialects import postgresql as pgsql + +from ai.backend.common import msgpack, redis +from ai.backend.common.types import ( + AgentId, + BinarySize, + HardwareMetadata, + ResourceSlot, +) + +from .kernel import AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES, kernels +from .base import ( + batch_result, + EnumType, Item, + metadata, + PaginatedList, + privileged_mutation, + ResourceSlotColumn, + set_if_set, + simple_db_mutate, +) +from .user import UserRole +from .minilang.queryfilter import QueryFilterParser +from .minilang.ordering import QueryOrderParser +if TYPE_CHECKING: + from ai.backend.manager.models.gql import GraphQueryContext + +__all__: Sequence[str] = ( + 'agents', 'AgentStatus', + 'AgentList', 'Agent', 'ModifyAgent', + 'recalc_agent_resource_occupancy', +) + + +class AgentStatus(enum.Enum): + ALIVE = 0 + LOST = 1 + RESTARTING = 2 + TERMINATED = 3 + + +agents = sa.Table( + 'agents', metadata, + sa.Column('id', sa.String(length=64), primary_key=True), + sa.Column('status', EnumType(AgentStatus), nullable=False, index=True, + default=AgentStatus.ALIVE), + sa.Column('status_changed', sa.DateTime(timezone=True), nullable=True), + sa.Column('region', sa.String(length=64), index=True, nullable=False), + sa.Column('scaling_group', sa.ForeignKey('scaling_groups.name'), index=True, + nullable=False, server_default='default', default='default'), + sa.Column('schedulable', sa.Boolean(), + nullable=False, server_default=true(), default=True), + + sa.Column('available_slots', ResourceSlotColumn(), nullable=False), + sa.Column('occupied_slots', ResourceSlotColumn(), nullable=False), + + sa.Column('addr', sa.String(length=128), nullable=False), + sa.Column('first_contact', sa.DateTime(timezone=True), + server_default=sa.func.now()), + sa.Column('lost_at', sa.DateTime(timezone=True), nullable=True), + + sa.Column('version', sa.String(length=64), nullable=False), + sa.Column('architecture', sa.String(length=32), nullable=False), + sa.Column('compute_plugins', pgsql.JSONB(), nullable=False, 
default={}), +) + + +class Agent(graphene.ObjectType): + + class Meta: + interfaces = (Item, ) + + status = graphene.String() + status_changed = GQLDateTime() + region = graphene.String() + scaling_group = graphene.String() + schedulable = graphene.Boolean() + available_slots = graphene.JSONString() + occupied_slots = graphene.JSONString() + addr = graphene.String() + architecture = graphene.String() + first_contact = GQLDateTime() + lost_at = GQLDateTime() + live_stat = graphene.JSONString() + version = graphene.String() + compute_plugins = graphene.JSONString() + hardware_metadata = graphene.JSONString() + + # Legacy fields + mem_slots = graphene.Int() + cpu_slots = graphene.Float() + gpu_slots = graphene.Float() + tpu_slots = graphene.Float() + used_mem_slots = graphene.Int() + used_cpu_slots = graphene.Float() + used_gpu_slots = graphene.Float() + used_tpu_slots = graphene.Float() + cpu_cur_pct = graphene.Float() + mem_cur_bytes = graphene.Float() + + compute_containers = graphene.List( + 'ai.backend.manager.models.ComputeContainer', + status=graphene.String()) + + @classmethod + def from_row( + cls, + ctx: GraphQueryContext, + row: Row, + ) -> Agent: + mega = 2 ** 20 + return cls( + id=row['id'], + status=row['status'].name, + status_changed=row['status_changed'], + region=row['region'], + scaling_group=row['scaling_group'], + schedulable=row['schedulable'], + available_slots=row['available_slots'].to_json(), + occupied_slots=row['occupied_slots'].to_json(), + addr=row['addr'], + architecture=row['architecture'], + first_contact=row['first_contact'], + lost_at=row['lost_at'], + version=row['version'], + compute_plugins=row['compute_plugins'], + # legacy fields + mem_slots=BinarySize.from_str(row['available_slots']['mem']) // mega, + cpu_slots=row['available_slots']['cpu'], + gpu_slots=row['available_slots'].get('cuda.device', 0), + tpu_slots=row['available_slots'].get('tpu.device', 0), + used_mem_slots=BinarySize.from_str( + row['occupied_slots'].get('mem', 0)) // mega, + used_cpu_slots=float(row['occupied_slots'].get('cpu', 0)), + used_gpu_slots=float(row['occupied_slots'].get('cuda.device', 0)), + used_tpu_slots=float(row['occupied_slots'].get('tpu.device', 0)), + ) + + async def resolve_live_stat(self, info: graphene.ResolveInfo) -> Any: + ctx: GraphQueryContext = info.context + rs = ctx.redis_stat + live_stat = await redis.execute(rs, lambda r: r.get(str(self.id))) + if live_stat is not None: + live_stat = msgpack.unpackb(live_stat) + return live_stat + + async def resolve_cpu_cur_pct(self, info: graphene.ResolveInfo) -> Any: + ctx: GraphQueryContext = info.context + rs = ctx.redis_stat + live_stat = await redis.execute(rs, lambda r: r.get(str(self.id))) + if live_stat is not None: + live_stat = msgpack.unpackb(live_stat) + try: + return float(live_stat['node']['cpu_util']['pct']) + except (KeyError, TypeError, ValueError): + return 0.0 + return 0.0 + + async def resolve_mem_cur_bytes(self, info: graphene.ResolveInfo) -> Any: + ctx: GraphQueryContext = info.context + rs = ctx.redis_stat + live_stat = await redis.execute(rs, lambda r: r.get(str(self.id))) + if live_stat is not None: + live_stat = msgpack.unpackb(live_stat) + try: + return int(live_stat['node']['mem']['current']) + except (KeyError, TypeError, ValueError): + return 0 + return 0 + + async def resolve_hardware_metadata( + self, + info: graphene.ResolveInfo, + ) -> Mapping[str, HardwareMetadata]: + graph_ctx: GraphQueryContext = info.context + return await graph_ctx.registry.gather_agent_hwinfo(self.id) + + 
_queryfilter_fieldspec = { + "id": ("id", None), + "status": ("status", lambda s: AgentStatus[s]), + "status_changed": ("status_changed", dtparse), + "region": ("region", None), + "scaling_group": ("scaling_group", None), + "schedulable": ("schedulabe", None), + "addr": ("addr", None), + "first_contact": ("first_contat", dtparse), + "lost_at": ("lost_at", dtparse), + "version": ("version", None), + } + + _queryorder_colmap = { + "id": "id", + "status": "status", + "status_changed": "status_changed", + "region": "region", + "scaling_group": "scaling_group", + "schedulable": "schedulable", + "first_contact": "first_contact", + "lost_at": "lost_at", + "version": "version", + "available_slots": "available_slots", + "occupied_slots": "occupied_slots", + } + + @classmethod + async def load_count( + cls, + graph_ctx: GraphQueryContext, *, + scaling_group: str = None, + raw_status: str = None, + filter: str = None, + ) -> int: + query = ( + sa.select([sa.func.count()]) + .select_from(agents) + ) + if scaling_group is not None: + query = query.where(agents.c.scaling_group == scaling_group) + if raw_status is not None: + query = query.where(agents.c.status == AgentStatus[raw_status]) + if filter is not None: + qfparser = QueryFilterParser(cls._queryfilter_fieldspec) + query = qfparser.append_filter(query, filter) + async with graph_ctx.db.begin_readonly() as conn: + result = await conn.execute(query) + return result.scalar() + + @classmethod + async def load_slice( + cls, + graph_ctx: GraphQueryContext, + limit: int, offset: int, *, + scaling_group: str = None, + raw_status: str = None, + filter: str = None, + order: str = None, + ) -> Sequence[Agent]: + query = ( + sa.select([agents]) + .select_from(agents) + .limit(limit) + .offset(offset) + ) + if scaling_group is not None: + query = query.where(agents.c.scaling_group == scaling_group) + if raw_status is not None: + query = query.where(agents.c.status == AgentStatus[raw_status]) + if filter is not None: + qfparser = QueryFilterParser(cls._queryfilter_fieldspec) + query = qfparser.append_filter(query, filter) + if order is not None: + qoparser = QueryOrderParser(cls._queryorder_colmap) + query = qoparser.append_ordering(query, order) + else: + query = query.order_by( + agents.c.status.asc(), + agents.c.scaling_group.asc(), + agents.c.id.asc(), + ) + async with graph_ctx.db.begin_readonly() as conn: + return [ + cls.from_row(graph_ctx, row) + async for row in (await conn.stream(query)) + ] + + @classmethod + async def load_all( + cls, + graph_ctx: GraphQueryContext, *, + scaling_group: str = None, + raw_status: str = None, + ) -> Sequence[Agent]: + query = ( + sa.select([agents]) + .select_from(agents) + ) + if scaling_group is not None: + query = query.where(agents.c.scaling_group == scaling_group) + if raw_status is not None: + query = query.where(agents.c.status == AgentStatus[raw_status]) + async with graph_ctx.db.begin_readonly() as conn: + return [ + cls.from_row(graph_ctx, row) + async for row in (await conn.stream(query)) + ] + + @classmethod + async def batch_load( + cls, + graph_ctx: GraphQueryContext, + agent_ids: Sequence[AgentId], *, + raw_status: str = None, + ) -> Sequence[Agent | None]: + query = ( + sa.select([agents]) + .select_from(agents) + .where(agents.c.id.in_(agent_ids)) + .order_by( + agents.c.id, + ) + ) + if raw_status is not None: + query = query.where(agents.c.status == AgentStatus[raw_status]) + async with graph_ctx.db.begin_readonly() as conn: + return await batch_result( + graph_ctx, conn, query, cls, + agent_ids, 
lambda row: row['id'], + ) + + +class ModifyAgentInput(graphene.InputObjectType): + schedulable = graphene.Boolean(required=False, default=True) + scaling_group = graphene.String(required=False) + + +class AgentList(graphene.ObjectType): + class Meta: + interfaces = (PaginatedList, ) + + items = graphene.List(Agent, required=True) + + +async def recalc_agent_resource_occupancy(db_conn: SAConnection, agent_id: AgentId) -> None: + query = ( + sa.select([ + kernels.c.occupied_slots, + ]) + .select_from(kernels) + .where( + (kernels.c.agent == agent_id) & + (kernels.c.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES)), + ) + ) + occupied_slots = ResourceSlot() + result = await db_conn.execute(query) + for row in result: + occupied_slots += row['occupied_slots'] + query = ( + sa.update(agents) + .values({ + 'occupied_slots': occupied_slots, + }) + .where(agents.c.id == agent_id) + ) + await db_conn.execute(query) + + +class ModifyAgent(graphene.Mutation): + + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + id = graphene.String(required=True) + props = ModifyAgentInput(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + + @classmethod + @privileged_mutation( + UserRole.SUPERADMIN, + lambda id, **kwargs: (None, id), + ) + async def mutate( + cls, + root, + info: graphene.ResolveInfo, + id: str, + props: ModifyAgentInput, + ) -> ModifyAgent: + graph_ctx: GraphQueryContext = info.context + data: Dict[str, Any] = {} + set_if_set(props, data, 'schedulable') + set_if_set(props, data, 'scaling_group') + await graph_ctx.registry.update_scaling_group(id, data['scaling_group']) + + update_query = ( + sa.update(agents).values(data).where(agents.c.id == id) + ) + return await simple_db_mutate(cls, graph_ctx, update_query) diff --git a/src/ai/backend/manager/models/alembic/README b/src/ai/backend/manager/models/alembic/README new file mode 100644 index 0000000000..98e4f9c44e --- /dev/null +++ b/src/ai/backend/manager/models/alembic/README @@ -0,0 +1 @@ +Generic single-database configuration. \ No newline at end of file diff --git a/src/ai/backend/manager/models/alembic/env.py b/src/ai/backend/manager/models/alembic/env.py new file mode 100644 index 0000000000..7bf29c765c --- /dev/null +++ b/src/ai/backend/manager/models/alembic/env.py @@ -0,0 +1,79 @@ +from __future__ import with_statement +from alembic import context +from sqlalchemy import engine_from_config, pool +from logging.config import fileConfig +from ai.backend.common.logging import is_active as logging_active + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. + +if not logging_active.get(): + assert config.config_file_name is not None + fileConfig(config.config_file_name) + +# Import the shared metadata and all models. +# (We need to explicilty import models because model modules +# should be executed to add table definitions to the metadata.) +from ai.backend.manager.models.base import metadata +import ai.backend.manager.models.agent +import ai.backend.manager.models.keypair +import ai.backend.manager.models.kernel +import ai.backend.manager.models.vfolder + +target_metadata = metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline(): + """Run migrations in 'offline' mode. 
+ + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, target_metadata=target_metadata, literal_binds=True) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online(): + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + connectable = engine_from_config( + config.get_section(config.config_ini_section), + prefix='sqlalchemy.', + poolclass=pool.NullPool) + + with connectable.connect() as connection: + context.configure( + connection=connection, + target_metadata=target_metadata + ) + + with context.begin_transaction(): + context.run_migrations() + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/src/ai/backend/manager/models/alembic/script.py.mako b/src/ai/backend/manager/models/alembic/script.py.mako new file mode 100644 index 0000000000..2c0156303a --- /dev/null +++ b/src/ai/backend/manager/models/alembic/script.py.mako @@ -0,0 +1,24 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision = ${repr(up_revision)} +down_revision = ${repr(down_revision)} +branch_labels = ${repr(branch_labels)} +depends_on = ${repr(depends_on)} + + +def upgrade(): + ${upgrades if upgrades else "pass"} + + +def downgrade(): + ${downgrades if downgrades else "pass"} diff --git a/src/ai/backend/manager/models/alembic/versions/01456c812164_add_idle_timeout_to_keypair_resource_.py b/src/ai/backend/manager/models/alembic/versions/01456c812164_add_idle_timeout_to_keypair_resource_.py new file mode 100644 index 0000000000..9575b7a0c0 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/01456c812164_add_idle_timeout_to_keypair_resource_.py @@ -0,0 +1,28 @@ +"""add-idle-timeout-to-keypair-resource-policy + +Revision ID: 01456c812164 +Revises: dbc1e053b880 +Create Date: 2019-02-22 22:16:47.685740 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision = '01456c812164' +down_revision = 'dbc1e053b880' +branch_labels = None +depends_on = None + + +def upgrade(): + op.add_column('keypair_resource_policies', + sa.Column('idle_timeout', sa.BigInteger(), + nullable=False, server_default='1800')) + op.alter_column('keypair_resource_policies', 'idle_timeout', + server_default=None) + + +def downgrade(): + op.drop_column('keypair_resource_policies', 'idle_timeout') diff --git a/src/ai/backend/manager/models/alembic/versions/015d84d5a5ef_add_image_table.py b/src/ai/backend/manager/models/alembic/versions/015d84d5a5ef_add_image_table.py new file mode 100644 index 0000000000..d815fa972c --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/015d84d5a5ef_add_image_table.py @@ -0,0 +1,59 @@ +"""add image table + +Revision ID: 015d84d5a5ef +Revises: 60a1effa77d2 +Create Date: 2022-02-15 23:45:19.814677 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from ai.backend.manager.models.base import ForeignKeyIDColumn, IDColumn, convention + +from ai.backend.manager.models.image import ImageType + + +# revision identifiers, used by Alembic. +revision = '015d84d5a5ef' +down_revision = '60a1effa77d2' +branch_labels = None +depends_on = None + + +def upgrade(): + metadata = sa.MetaData(naming_convention=convention) + op.create_table( + 'images', metadata, + IDColumn('id'), + sa.Column('name', sa.String, index=True, nullable=False), + sa.Column('image', sa.String, nullable=False, index=True), + sa.Column( + 'created_at', sa.DateTime(timezone=True), + server_default=sa.func.now(), index=True), + sa.Column('tag', sa.String, nullable=False, index=True), + sa.Column('registry', sa.String, nullable=False, index=True), + sa.Column('architecture', sa.String, nullable=False, server_default='x86_64', index=True), + sa.Column('config_digest', sa.CHAR(length=72), nullable=False), + sa.Column('size_bytes', sa.BigInteger, nullable=False), + sa.Column('type', sa.Enum(ImageType, name='image_type'), nullable=False), + sa.Column('accelerators', sa.String), + sa.Column('labels', postgresql.JSONB(), nullable=False), + sa.Column('resources', postgresql.JSONB(), nullable=False), + sa.Index('ix_image_name_architecture', 'name', 'architecture', unique=True), + sa.Index('ix_image_image_tag_registry', 'image', 'tag', 'registry'), + ) + + op.create_table( + 'image_aliases', metadata, + IDColumn('id'), + sa.Column('alias', sa.String, unique=True, index=True), + ForeignKeyIDColumn('image', 'images.id', nullable=False), + sa.Index('ix_image_alias_unique_ref', 'image', 'alias', unique=True), + ) + + +def downgrade(): + op.drop_table('image_aliases') + op.drop_table('images') + op.execute('DROP TYPE image_type') diff --git a/src/ai/backend/manager/models/alembic/versions/0262e50e90e0_add_ssh_keypair_into_keypair.py b/src/ai/backend/manager/models/alembic/versions/0262e50e90e0_add_ssh_keypair_into_keypair.py new file mode 100644 index 0000000000..a37709d8e0 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/0262e50e90e0_add_ssh_keypair_into_keypair.py @@ -0,0 +1,69 @@ +"""add_ssh_keypair_into_keypair + +Revision ID: 0262e50e90e0 +Revises: 4b7b650bc30e +Create Date: 2019-12-12 07:19:48.052928 + +""" +from alembic import op +import sqlalchemy as sa + +from cryptography.hazmat.primitives import serialization as crypto_serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.backends import default_backend as crypto_default_backend + +from 
ai.backend.manager.models.base import convention + +# revision identifiers, used by Alembic. +revision = '0262e50e90e0' +down_revision = '4b7b650bc30e' +branch_labels = None +depends_on = None + + +def generate_ssh_keypair(): + key = rsa.generate_private_key( + backend=crypto_default_backend(), + public_exponent=65537, + key_size=2048 + ) + private_key = key.private_bytes( + crypto_serialization.Encoding.PEM, + crypto_serialization.PrivateFormat.TraditionalOpenSSL, + crypto_serialization.NoEncryption() + ).decode("utf-8") + public_key = key.public_key().public_bytes( + crypto_serialization.Encoding.OpenSSH, + crypto_serialization.PublicFormat.OpenSSH + ).decode("utf-8") + return (public_key, private_key) + + +def upgrade(): + op.add_column('keypairs', sa.Column('ssh_public_key', sa.String(length=750), nullable=True)) + op.add_column('keypairs', sa.Column('ssh_private_key', sa.String(length=2000), nullable=True)) + + # partial table to be preserved and referred + metadata = sa.MetaData(naming_convention=convention) + keypairs = sa.Table( + 'keypairs', metadata, + sa.Column('access_key', sa.String(length=20), primary_key=True), + sa.Column('ssh_public_key', sa.String(length=750), nullable=True), + sa.Column('ssh_private_key', sa.String(length=2000), nullable=True), + ) + + # Fill in SSH keypairs in every keypairs. + conn = op.get_bind() + query = sa.select([keypairs.c.access_key]).select_from(keypairs) + rows = conn.execute(query).fetchall() + for row in rows: + pubkey, privkey = generate_ssh_keypair() + query = (sa.update(keypairs) + .values(ssh_public_key=pubkey, ssh_private_key=privkey) + .where(keypairs.c.access_key == row.access_key)) + conn.execute(query) + + +def downgrade(): + op.drop_column('keypairs', 'ssh_public_key') + op.drop_column('keypairs', 'ssh_private_key') diff --git a/src/ai/backend/manager/models/alembic/versions/02950808ca3d_add_agent_version.py b/src/ai/backend/manager/models/alembic/versions/02950808ca3d_add_agent_version.py new file mode 100644 index 0000000000..2555a7e95f --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/02950808ca3d_add_agent_version.py @@ -0,0 +1,38 @@ +"""add-agent-version + +Revision ID: 02950808ca3d +Revises: 4cc87e7fbfdf +Create Date: 2019-06-02 21:14:12.320029 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '02950808ca3d' +down_revision = '4cc87e7fbfdf' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + 'agents', + # Set the defualt to "19.06.0" for now (since it is the first version to have this field) + # and let the heartbeat handler route to update with the exact value. + sa.Column('compute_plugins', postgresql.JSONB(astext_type=sa.Text()), nullable=False, + server_default=sa.text("'{}'::jsonb"))) + op.add_column( + 'agents', + sa.Column('version', sa.String(length=64), nullable=False, + server_default=sa.literal('19.06.0'))) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column('agents', 'version') + op.drop_column('agents', 'compute_plugins') + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/06184d82a211_add_session_creation_id.py b/src/ai/backend/manager/models/alembic/versions/06184d82a211_add_session_creation_id.py new file mode 100644 index 0000000000..99fcd9583d --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/06184d82a211_add_session_creation_id.py @@ -0,0 +1,24 @@ +"""add-session_creation_id + +Revision ID: 06184d82a211 +Revises: 250e8656cf45 +Create Date: 2020-12-24 19:58:44.515321 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '06184d82a211' +down_revision = '250e8656cf45' +branch_labels = None +depends_on = None + + +def upgrade(): + op.add_column('kernels', sa.Column('session_creation_id', sa.String(length=32), nullable=True)) + + +def downgrade(): + op.drop_column('kernels', 'session_creation_id') diff --git a/src/ai/backend/manager/models/alembic/versions/0c5733f80e4d_index_kernel_timestamps.py b/src/ai/backend/manager/models/alembic/versions/0c5733f80e4d_index_kernel_timestamps.py new file mode 100644 index 0000000000..41b1d25694 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/0c5733f80e4d_index_kernel_timestamps.py @@ -0,0 +1,46 @@ +"""index-kernel-timestamps + +Revision ID: 0c5733f80e4d +Revises: 9bd986a75a2a +Create Date: 2019-09-24 15:58:58.932029 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '0c5733f80e4d' +down_revision = '9bd986a75a2a' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column('kernels', 'status', + existing_type=postgresql.ENUM('PENDING', 'PREPARING', 'BUILDING', 'PULLING', 'RUNNING', 'RESTARTING', 'RESIZING', 'SUSPENDED', 'TERMINATING', 'TERMINATED', 'ERROR', 'CANCELLED', name='kernelstatus'), + nullable=False, + existing_server_default=sa.text("'PENDING'::kernelstatus")) + op.alter_column('kernels', 'type', + existing_type=postgresql.ENUM('INTERACTIVE', 'BATCH', name='sessiontypes'), + nullable=False, + existing_server_default=sa.text("'INTERACTIVE'::sessiontypes")) + op.create_index(op.f('ix_kernels_status_changed'), 'kernels', ['status_changed'], unique=False) + op.create_index('ix_kernels_updated_order', 'kernels', ['created_at', 'terminated_at', 'status_changed'], unique=False) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_index('ix_kernels_updated_order', table_name='kernels') + op.drop_index(op.f('ix_kernels_status_changed'), table_name='kernels') + op.alter_column('kernels', 'type', + existing_type=postgresql.ENUM('INTERACTIVE', 'BATCH', name='sessiontypes'), + nullable=True, + existing_server_default=sa.text("'INTERACTIVE'::sessiontypes")) + op.alter_column('kernels', 'status', + existing_type=postgresql.ENUM('PENDING', 'PREPARING', 'BUILDING', 'PULLING', 'RUNNING', 'RESTARTING', 'RESIZING', 'SUSPENDED', 'TERMINATING', 'TERMINATED', 'ERROR', 'CANCELLED', name='kernelstatus'), + nullable=True, + existing_server_default=sa.text("'PENDING'::kernelstatus")) + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/0d553d59f369_users_replace_is_active_to_status_and_its_info.py b/src/ai/backend/manager/models/alembic/versions/0d553d59f369_users_replace_is_active_to_status_and_its_info.py new file mode 100644 index 0000000000..8947c889c1 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/0d553d59f369_users_replace_is_active_to_status_and_its_info.py @@ -0,0 +1,66 @@ +"""replace_users_is_active_to_status_and_its_info + +Revision ID: 0d553d59f369 +Revises: 9cd61b1ae70d +Create Date: 2020-07-04 23:44:09.191729 + +""" +import textwrap + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '0d553d59f369' +down_revision = '9cd61b1ae70d' +branch_labels = None +depends_on = None + +userstatus_choices = ( + 'active', + 'inactive', + 'deleted', + 'before-verification', +) +userstatus = postgresql.ENUM( + *userstatus_choices, + name='userstatus' +) + +def upgrade(): + userstatus.create(op.get_bind()) + op.add_column( + 'users', + sa.Column('status', sa.Enum(*userstatus_choices, name='userstatus'), nullable=True) + ) + op.add_column('users', sa.Column('status_info', sa.Unicode(), nullable=True)) + + # Set user's status field. + conn = op.get_bind() + query = textwrap.dedent( + "UPDATE users SET status = 'active', status_info = 'migrated' WHERE is_active = 't';" + ) + conn.execute(query) + query = textwrap.dedent( + "UPDATE users SET status = 'inactive', status_info = 'migrated' WHERE is_active <> 't';" + ) + conn.execute(query) + + op.alter_column('users', column_name='status', nullable=False) + op.drop_column('users', 'is_active') + + +def downgrade(): + op.add_column('users', sa.Column('is_active', sa.Boolean(), nullable=True)) + + # Set user's is_active field. + conn = op.get_bind() + query = textwrap.dedent("UPDATE users SET is_active = 't' WHERE status = 'active';") + conn.execute(query) + query = textwrap.dedent("UPDATE users SET is_active = 'f' WHERE status <> 'active';") + conn.execute(query) + + op.drop_column('users', 'status_info') + op.drop_column('users', 'status') + userstatus.drop(op.get_bind()) diff --git a/src/ai/backend/manager/models/alembic/versions/0e558d06e0e3_add_service_ports.py b/src/ai/backend/manager/models/alembic/versions/0e558d06e0e3_add_service_ports.py new file mode 100644 index 0000000000..759c7f74ee --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/0e558d06e0e3_add_service_ports.py @@ -0,0 +1,24 @@ +"""add-service-ports + +Revision ID: 0e558d06e0e3 +Revises: 10e39a34eed5 +Create Date: 2018-12-13 17:39:35.573747 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision = '0e558d06e0e3' +down_revision = '10e39a34eed5' +branch_labels = None +depends_on = None + + +def upgrade(): + op.add_column('kernels', sa.Column('service_ports', sa.JSON(), nullable=True)) + + +def downgrade(): + op.drop_column('kernels', 'service_ports') diff --git a/src/ai/backend/manager/models/alembic/versions/0f3bc98edaa0_more_status.py b/src/ai/backend/manager/models/alembic/versions/0f3bc98edaa0_more_status.py new file mode 100644 index 0000000000..e980cbc7c1 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/0f3bc98edaa0_more_status.py @@ -0,0 +1,71 @@ +"""more_status + +Revision ID: 0f3bc98edaa0 +Revises: 7ea324d0535b +Create Date: 2017-08-11 13:12:55.236519 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '0f3bc98edaa0' +down_revision = '7ea324d0535b' +branch_labels = None +depends_on = None + +agentstatus = postgresql.ENUM( + 'ALIVE', 'LOST', 'RESTARTING', 'TERMINATED', + name='agentstatus', +) + +kernelstatus_choices = ( + 'PREPARING', 'BUILDING', 'RUNNING', + 'RESTARTING', 'RESIZING', 'SUSPENDED', + 'TERMINATING', 'TERMINATED', 'ERROR', +) +kernelstatus = postgresql.ENUM( + *kernelstatus_choices, + name='kernelstatus') + + +def upgrade(): + agentstatus.create(op.get_bind()) + kernelstatus.create(op.get_bind()) + op.add_column('agents', sa.Column('lost_at', sa.DateTime(timezone=True), nullable=True)) + op.add_column('agents', sa.Column('status', sa.Enum('ALIVE', 'LOST', 'RESTARTING', 'TERMINATED', name='agentstatus'), nullable=False)) + op.create_index(op.f('ix_agents_status'), 'agents', ['status'], unique=False) + op.add_column('kernels', sa.Column('agent_addr', sa.String(length=128), nullable=False)) + op.add_column('kernels', sa.Column('cpu_slot', sa.Integer(), nullable=False)) + op.add_column('kernels', sa.Column('gpu_slot', sa.Integer(), nullable=False)) + op.add_column('kernels', sa.Column('mem_slot', sa.Integer(), nullable=False)) + op.add_column('kernels', sa.Column('repl_in_port', sa.Integer(), nullable=False)) + op.add_column('kernels', sa.Column('repl_out_port', sa.Integer(), nullable=False)) + op.add_column('kernels', sa.Column('stdin_port', sa.Integer(), nullable=False)) + op.add_column('kernels', sa.Column('stdout_port', sa.Integer(), nullable=False)) + op.drop_column('kernels', 'allocated_cores') + op.add_column('kernels', sa.Column('cpu_set', sa.ARRAY(sa.Integer), nullable=True)) + op.add_column('kernels', sa.Column('gpu_set', sa.ARRAY(sa.Integer), nullable=True)) + op.alter_column('kernels', column_name='status', type_=sa.Enum(*kernelstatus_choices, name='kernelstatus'), + postgresql_using='status::kernelstatus') + + +def downgrade(): + op.drop_column('kernels', 'stdout_port') + op.drop_column('kernels', 'stdin_port') + op.drop_column('kernels', 'repl_out_port') + op.drop_column('kernels', 'repl_in_port') + op.drop_column('kernels', 'mem_slot') + op.drop_column('kernels', 'gpu_slot') + op.drop_column('kernels', 'cpu_slot') + op.drop_column('kernels', 'agent_addr') + op.drop_index(op.f('ix_agents_status'), table_name='agents') + op.drop_column('agents', 'status') + op.drop_column('agents', 'lost_at') + op.alter_column('kernels', column_name='status', type_=sa.String(length=64)) + op.add_column('kernels', sa.Column('allocated_cores', sa.ARRAY(sa.Integer), nullable=True)) + op.drop_column('kernels', 'cpu_set') + op.drop_column('kernels', 'gpu_set') + agentstatus.drop(op.get_bind()) + kernelstatus.drop(op.get_bind()) diff --git 
a/src/ai/backend/manager/models/alembic/versions/0f7a4b643940_.py b/src/ai/backend/manager/models/alembic/versions/0f7a4b643940_.py new file mode 100644 index 0000000000..ad258aca93 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/0f7a4b643940_.py @@ -0,0 +1,24 @@ +"""empty message + +Revision ID: 0f7a4b643940 +Revises: 7dd1d81c3204, 911023380bc9 +Create Date: 2022-03-14 06:20:12.850338 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '0f7a4b643940' +down_revision = ('7dd1d81c3204', '911023380bc9') +branch_labels = None +depends_on = None + + +def upgrade(): + pass + + +def downgrade(): + pass diff --git a/src/ai/backend/manager/models/alembic/versions/10e39a34eed5_enlarge_kernels_lang_column_length.py b/src/ai/backend/manager/models/alembic/versions/10e39a34eed5_enlarge_kernels_lang_column_length.py new file mode 100644 index 0000000000..e72c13bb67 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/10e39a34eed5_enlarge_kernels_lang_column_length.py @@ -0,0 +1,32 @@ +"""Enlarge kernels.lang column length + +Revision ID: 10e39a34eed5 +Revises: d582942886ad +Create Date: 2018-10-29 13:52:10.583443 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '10e39a34eed5' +down_revision = 'd582942886ad' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column('kernels', 'lang', + existing_type=sa.String(length=64), + type_=sa.String(length=512)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column('kernels', 'lang', + existing_type=sa.String(length=512), + type_=sa.String(length=64)) + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/11146ba02235_change_char_col_to_str.py b/src/ai/backend/manager/models/alembic/versions/11146ba02235_change_char_col_to_str.py new file mode 100644 index 0000000000..7d1376b46a --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/11146ba02235_change_char_col_to_str.py @@ -0,0 +1,29 @@ +"""change char col to str + +Revision ID: 11146ba02235 +Revises: 0f7a4b643940 +Create Date: 2022-03-25 12:32:05.637628 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.sql.expression import bindparam + +# revision identifiers, used by Alembic. 
+revision = '11146ba02235' +down_revision = '0f7a4b643940' +branch_labels = None +depends_on = None + + +def upgrade(): + conn = op.get_bind() + op.alter_column('agents', column_name='architecture', type_=sa.String(length=32)) + query = ''' + UPDATE agents + SET architecture = TRIM (architecture); + ''' + conn.execute(query) + +def downgrade(): + op.alter_column('agents', column_name='architecture', type_=sa.CHAR(length=32)) diff --git a/src/ai/backend/manager/models/alembic/versions/185852ff9872_add_vfolder_permissions_table.py b/src/ai/backend/manager/models/alembic/versions/185852ff9872_add_vfolder_permissions_table.py new file mode 100644 index 0000000000..57abc631d5 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/185852ff9872_add_vfolder_permissions_table.py @@ -0,0 +1,39 @@ +"""add vfolder_permissions table + +Revision ID: 185852ff9872 +Revises: 1fa6a31ea8e3 +Create Date: 2018-07-05 16:02:05.225094 + +""" +from alembic import op +import sqlalchemy as sa +from ai.backend.manager.models.base import GUID + + +# revision identifiers, used by Alembic. +revision = '185852ff9872' +down_revision = '1fa6a31ea8e3' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + 'vfolder_permissions', + sa.Column('permission', sa.String(length=2), nullable=True), + sa.Column('vfolder', GUID(), nullable=False), + sa.Column('access_key', sa.String(length=20), nullable=False), + sa.ForeignKeyConstraint(['access_key'], ['keypairs.access_key'], + name=op.f('fk_vfolder_permissions_access_key_keypairs')), + sa.ForeignKeyConstraint(['vfolder'], ['vfolders.id'], + name=op.f('fk_vfolder_permissions_vfolder_vfolders'), + onupdate='CASCADE', ondelete='CASCADE') + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('vfolder_permissions') + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/1e673659b283_add_clusterized_column_to_agents_table.py b/src/ai/backend/manager/models/alembic/versions/1e673659b283_add_clusterized_column_to_agents_table.py new file mode 100644 index 0000000000..b050f4acb7 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/1e673659b283_add_clusterized_column_to_agents_table.py @@ -0,0 +1,27 @@ +"""Add clusterized column to agents table + +Revision ID: 1e673659b283 +Revises: d5cc54fd36b5 +Create Date: 2020-01-07 17:52:51.771357 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '1e673659b283' +down_revision = 'd5cc54fd36b5' +branch_labels = None +depends_on = None + + +def upgrade(): + op.add_column( + 'agents', + sa.Column('clusterized', sa.Boolean, default=False) + ) + + +def downgrade(): + op.drop_column('agents', 'clusterized') diff --git a/src/ai/backend/manager/models/alembic/versions/1e8531583e20_add_dotfile_column_to_keypairs.py b/src/ai/backend/manager/models/alembic/versions/1e8531583e20_add_dotfile_column_to_keypairs.py new file mode 100644 index 0000000000..5053c3820e --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/1e8531583e20_add_dotfile_column_to_keypairs.py @@ -0,0 +1,26 @@ +"""Add dotfile column to keypairs + +Revision ID: 1e8531583e20 +Revises: ce209920f654 +Create Date: 2020-01-17 15:59:09.367691 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. 
+revision = '1e8531583e20' +down_revision = 'ce209920f654' +branch_labels = None +depends_on = None + + +def upgrade(): + op.add_column( + 'keypairs', + sa.Column('dotfiles', sa.LargeBinary(length=64 * 1024), nullable=False, server_default='\\x90') + ) + + +def downgrade(): + op.drop_column('keypairs', 'dotfiles') diff --git a/src/ai/backend/manager/models/alembic/versions/1fa6a31ea8e3_add_inviter_field_for_vfolder_.py b/src/ai/backend/manager/models/alembic/versions/1fa6a31ea8e3_add_inviter_field_for_vfolder_.py new file mode 100644 index 0000000000..b765e75098 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/1fa6a31ea8e3_add_inviter_field_for_vfolder_.py @@ -0,0 +1,28 @@ +"""add inviter field for vfolder_invitations + +Revision ID: 1fa6a31ea8e3 +Revises: 26d0c387e764 +Create Date: 2018-07-05 00:09:35.230704 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '1fa6a31ea8e3' +down_revision = '26d0c387e764' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('vfolder_invitations', sa.Column('inviter', sa.String(length=256))) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('vfolder_invitations', 'inviter') + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/202b6dcbc159_add_internal_data_to_kernels.py b/src/ai/backend/manager/models/alembic/versions/202b6dcbc159_add_internal_data_to_kernels.py new file mode 100644 index 0000000000..c46f306dca --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/202b6dcbc159_add_internal_data_to_kernels.py @@ -0,0 +1,28 @@ +"""add-internal-data-to-kernels + +Revision ID: 202b6dcbc159 +Revises: 3f1dafab60b2 +Create Date: 2019-10-01 16:13:20.935285 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '202b6dcbc159' +down_revision = '3f1dafab60b2' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('kernels', sa.Column('internal_data', postgresql.JSONB(astext_type=sa.Text()), nullable=True)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('kernels', 'internal_data') + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/22964745c12b_add_total_resource_slots_to_group.py b/src/ai/backend/manager/models/alembic/versions/22964745c12b_add_total_resource_slots_to_group.py new file mode 100644 index 0000000000..73b0c2715e --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/22964745c12b_add_total_resource_slots_to_group.py @@ -0,0 +1,49 @@ +"""add_total_resource_slots_to_group + +Revision ID: 22964745c12b +Revises: 02950808ca3d +Create Date: 2019-06-17 15:57:39.442741 + +""" +import textwrap +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '22964745c12b' +down_revision = '02950808ca3d' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.add_column('domains', sa.Column('integration_id', sa.String(length=512), nullable=True)) + op.alter_column('domains', 'total_resource_slots', + existing_type=postgresql.JSONB(astext_type=sa.Text()), + nullable=True) + op.add_column('groups', sa.Column('integration_id', sa.String(length=512), nullable=True)) + op.add_column('groups', sa.Column('total_resource_slots', + postgresql.JSONB(astext_type=sa.Text()), nullable=True)) + op.add_column('users', sa.Column('integration_id', sa.String(length=512), nullable=True)) + # ### end Alembic commands ### + + print('\nSet group\'s total_resource_slots with empty dictionary.') + query = textwrap.dedent('''\ + UPDATE groups SET total_resource_slots = '{}'::jsonb; + ''') + connection = op.get_bind() + connection.execute(query) + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('users', 'integration_id') + op.drop_column('groups', 'total_resource_slots') + op.drop_column('groups', 'integration_id') + op.alter_column('domains', 'total_resource_slots', + existing_type=postgresql.JSONB(astext_type=sa.Text()), + nullable=False) + op.drop_column('domains', 'integration_id') + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/22e52d03fc61_add_allowed_docker_registries_in_domains.py b/src/ai/backend/manager/models/alembic/versions/22e52d03fc61_add_allowed_docker_registries_in_domains.py new file mode 100644 index 0000000000..2cad0f53c9 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/22e52d03fc61_add_allowed_docker_registries_in_domains.py @@ -0,0 +1,46 @@ +"""add_allowed_docker_registries_in_domains + +Revision ID: 22e52d03fc61 +Revises: c401d78cc7b9 +Create Date: 2019-07-29 11:44:55.593760 + +""" +import os + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '22e52d03fc61' +down_revision = 'c401d78cc7b9' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('domains', + sa.Column('allowed_docker_registries', + postgresql.ARRAY(sa.String()), nullable=True)) + # ### end Alembic commands ### + + print('\nSet default allowed_docker_registries.') + allowed_registries = os.environ.get('ALLOWED_DOCKER_REGISTRIES', None) + if allowed_registries: + allowed_registries = allowed_registries.replace(' ', '') + allowed_registries = '{index.docker.io,' + allowed_registries + '}' + else: + allowed_registries = '{index.docker.io}' + connection = op.get_bind() + query = ("UPDATE domains SET allowed_docker_registries = '{}';".format(allowed_registries)) + connection.execute(query) + + op.alter_column('domains', column_name='allowed_docker_registries', + nullable=False) + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust!
### + op.drop_column('domains', 'allowed_docker_registries') + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/250e8656cf45_add_status_data.py b/src/ai/backend/manager/models/alembic/versions/250e8656cf45_add_status_data.py new file mode 100644 index 0000000000..eaef0d5c6e --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/250e8656cf45_add_status_data.py @@ -0,0 +1,24 @@ +"""add-status_data + +Revision ID: 250e8656cf45 +Revises: 57e717103287 +Create Date: 2020-12-23 14:19:08.801283 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '250e8656cf45' +down_revision = '57e717103287' +branch_labels = None +depends_on = None + + +def upgrade(): + op.add_column('kernels', sa.Column('status_data', postgresql.JSONB(astext_type=sa.Text()), nullable=True)) + + +def downgrade(): + op.drop_column('kernels', 'status_data') diff --git a/src/ai/backend/manager/models/alembic/versions/25e903510fa1_add_dotfiles_to_domains_and_groups.py b/src/ai/backend/manager/models/alembic/versions/25e903510fa1_add_dotfiles_to_domains_and_groups.py new file mode 100644 index 0000000000..f9c67d38b6 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/25e903510fa1_add_dotfiles_to_domains_and_groups.py @@ -0,0 +1,36 @@ +"""add_dotfiles_to_domains_and_groups + +Revision ID: 25e903510fa1 +Revises: 0d553d59f369 +Create Date: 2020-09-11 17:00:00.564219 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '25e903510fa1' +down_revision = '0d553d59f369' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + 'domains', + sa.Column('dotfiles', sa.LargeBinary(length=65536), nullable=False, server_default='\\x90') + ) + op.add_column( + 'groups', + sa.Column('dotfiles', sa.LargeBinary(length=65536), nullable=True, server_default='\\x90') + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('groups', 'dotfiles') + op.drop_column('domains', 'dotfiles') + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/26d0c387e764_create_vfolder_invitations_table.py b/src/ai/backend/manager/models/alembic/versions/26d0c387e764_create_vfolder_invitations_table.py new file mode 100644 index 0000000000..6f198b77cb --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/26d0c387e764_create_vfolder_invitations_table.py @@ -0,0 +1,39 @@ +"""create vfolder_invitations table + +Revision ID: 26d0c387e764 +Revises: f8a71c3bffa2 +Create Date: 2018-07-04 14:57:46.517587 + +""" +from alembic import op +import sqlalchemy as sa +from ai.backend.manager.models.base import GUID + + +# revision identifiers, used by Alembic. +revision = '26d0c387e764' +down_revision = 'f8a71c3bffa2' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table( + 'vfolder_invitations', + sa.Column('id', GUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('permission', sa.String(length=2), nullable=True), + sa.Column('invitee', sa.String(length=256), nullable=False), + sa.Column('vfolder', GUID(), nullable=False), + sa.ForeignKeyConstraint(['vfolder'], ['vfolders.id'], + name=op.f('fk_vfolder_invitations_vfolder_vfolders'), + onupdate='CASCADE', ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id', name=op.f('pk_vfolder_invitations')) + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('vfolder_invitations') + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/2a82340fa30e_add_mounts_info_in_kernel_db.py b/src/ai/backend/manager/models/alembic/versions/2a82340fa30e_add_mounts_info_in_kernel_db.py new file mode 100644 index 0000000000..5d16610c8b --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/2a82340fa30e_add_mounts_info_in_kernel_db.py @@ -0,0 +1,28 @@ +"""add_mounts_info_in_kernel_db + +Revision ID: 2a82340fa30e +Revises: c1409ad0e8da +Create Date: 2019-08-01 15:59:41.807766 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '2a82340fa30e' +down_revision = 'c1409ad0e8da' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('kernels', sa.Column('mounts', sa.ARRAY(sa.String()), nullable=True)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('kernels', 'mounts') + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/2b0931e4a059_convert_lang_to_image_and_registry.py b/src/ai/backend/manager/models/alembic/versions/2b0931e4a059_convert_lang_to_image_and_registry.py new file mode 100644 index 0000000000..6c60363ab8 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/2b0931e4a059_convert_lang_to_image_and_registry.py @@ -0,0 +1,31 @@ +"""convert-lang-to-image-and-registry + +Revision ID: 2b0931e4a059 +Revises: f0f4ee907155 +Create Date: 2019-01-28 23:53:44.342786 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '2b0931e4a059' +down_revision = 'f0f4ee907155' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column('kernels', column_name='lang', new_column_name='image') + op.add_column('kernels', sa.Column('registry', sa.String(length=512), + nullable=True)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.alter_column('kernels', column_name='image', new_column_name='lang') + op.drop_column('kernels', 'registry') + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/352fa4f88f61_add_tpu_slot_on_kernel_model.py b/src/ai/backend/manager/models/alembic/versions/352fa4f88f61_add_tpu_slot_on_kernel_model.py new file mode 100644 index 0000000000..91fcb1e71b --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/352fa4f88f61_add_tpu_slot_on_kernel_model.py @@ -0,0 +1,32 @@ +"""add tpu slot on kernel model + +Revision ID: 352fa4f88f61 +Revises: 57b523dec0e8 +Create Date: 2018-11-12 11:39:30.613081 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '352fa4f88f61' +down_revision = '57b523dec0e8' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('kernels', sa.Column('tpu_set', sa.ARRAY(sa.Integer()), nullable=True)) + op.add_column('kernels', sa.Column('tpu_slot', sa.Float(), nullable=False, + server_default='0')) + op.alter_column('kernels', 'tpu_slot', server_default=None) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('kernels', 'tpu_slot') + op.drop_column('kernels', 'tpu_set') + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/3bb80d1887d6_add_preopen_ports.py b/src/ai/backend/manager/models/alembic/versions/3bb80d1887d6_add_preopen_ports.py new file mode 100644 index 0000000000..1aa68253c8 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/3bb80d1887d6_add_preopen_ports.py @@ -0,0 +1,28 @@ +"""add preopen ports + +Revision ID: 3bb80d1887d6 +Revises: 1e8531583e20 +Create Date: 2020-02-05 17:02:42.344726 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '3bb80d1887d6' +down_revision = '1e8531583e20' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('kernels', sa.Column('preopen_ports', sa.ARRAY(sa.Integer()), nullable=True)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('kernels', 'preopen_ports') + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/3cf19d906e71_.py b/src/ai/backend/manager/models/alembic/versions/3cf19d906e71_.py new file mode 100644 index 0000000000..7242091af5 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/3cf19d906e71_.py @@ -0,0 +1,24 @@ +"""empty message + +Revision ID: 3cf19d906e71 +Revises: 22964745c12b, 5d8e6043455e +Create Date: 2019-06-17 16:45:14.580560 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision = '3cf19d906e71' +down_revision = ('22964745c12b', '5d8e6043455e') +branch_labels = None +depends_on = None + + +def upgrade(): + pass + + +def downgrade(): + pass diff --git a/src/ai/backend/manager/models/alembic/versions/3f1dafab60b2_merge.py b/src/ai/backend/manager/models/alembic/versions/3f1dafab60b2_merge.py new file mode 100644 index 0000000000..121738bb21 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/3f1dafab60b2_merge.py @@ -0,0 +1,24 @@ +"""merge + +Revision ID: 3f1dafab60b2 +Revises: c092dabf3ee5, 6f5fe19894b7 +Create Date: 2019-09-30 04:34:42.092031 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '3f1dafab60b2' +down_revision = ('c092dabf3ee5', '6f5fe19894b7') +branch_labels = None +depends_on = None + + +def upgrade(): + pass + + +def downgrade(): + pass diff --git a/src/ai/backend/manager/models/alembic/versions/405aa2c39458_job_queue.py b/src/ai/backend/manager/models/alembic/versions/405aa2c39458_job_queue.py new file mode 100644 index 0000000000..832c655671 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/405aa2c39458_job_queue.py @@ -0,0 +1,192 @@ +"""job-queue + +Revision ID: 405aa2c39458 +Revises: 5b45f28d2cac +Create Date: 2019-09-16 02:08:41.396372 + +""" +import textwrap + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from ai.backend.manager.models.base import GUID + + +# revision identifiers, used by Alembic. +revision = '405aa2c39458' +down_revision = '5b45f28d2cac' +branch_labels = None +depends_on = None + +sessionresult = postgresql.ENUM( + 'UNDEFINED', 'SUCCESS', 'FAILURE', + name='sessionresult' +) + +sessiontypes = postgresql.ENUM( + 'INTERACTIVE', 'BATCH', + name='sessiontypes' +) + +kernelstatus_new_values = [ + 'PENDING', # added + 'PREPARING', + 'BUILDING', + 'PULLING', # added + 'RUNNING', + 'RESTARTING', + 'RESIZING', + 'SUSPENDED', + 'TERMINATING', + 'TERMINATED', + 'ERROR', +] +kernelstatus_new = postgresql.ENUM(*kernelstatus_new_values, name='kernelstatus') + +kernelstatus_old_values = [ + # 'PENDING', # added + 'PREPARING', + 'BUILDING', + # 'PULLING', # added + 'RUNNING', + 'RESTARTING', + 'RESIZING', + 'SUSPENDED', + 'TERMINATING', + 'TERMINATED', + 'ERROR', +] +kernelstatus_old = postgresql.ENUM(*kernelstatus_old_values, name='kernelstatus') + + +def upgrade(): + conn = op.get_bind() + sessionresult.create(conn) + sessiontypes.create(conn) + conn.execute('ALTER TYPE kernelstatus RENAME TO kernelstatus_old;') + kernelstatus_new.create(conn) + conn.execute(textwrap.dedent('''\ + CREATE FUNCTION kernelstatus_new_old_compare( + new_enum_val kernelstatus, old_enum_val kernelstatus_old + ) + RETURNS boolean AS $$ + SELECT new_enum_val::text <> old_enum_val::text; + $$ LANGUAGE SQL IMMUTABLE; + + CREATE OPERATOR <> ( + leftarg = kernelstatus, + rightarg = kernelstatus_old, + procedure = kernelstatus_new_old_compare + ); + + ALTER TABLE kernels + ALTER COLUMN "status" DROP DEFAULT, + ALTER COLUMN "status" TYPE kernelstatus USING "status"::text::kernelstatus, + ALTER COLUMN "status" SET DEFAULT 'PENDING'::kernelstatus; + + DROP FUNCTION kernelstatus_new_old_compare( + new_enum_val kernelstatus, old_enum_val kernelstatus_old + ) CASCADE; + + DROP TYPE kernelstatus_old; + ''')) + + op.add_column('agents', sa.Column('status_changed', sa.DateTime(timezone=True), nullable=True)) + + op.create_table( + 'kernel_dependencies', + sa.Column('kernel_id', GUID(), nullable=False), + 
sa.Column('depends_on', GUID(), nullable=False), + sa.ForeignKeyConstraint(['depends_on'], ['kernels.id'], + name=op.f('fk_kernel_dependencies_depends_on_kernels')), + sa.ForeignKeyConstraint(['kernel_id'], ['kernels.id'], + name=op.f('fk_kernel_dependencies_kernel_id_kernels')), + sa.PrimaryKeyConstraint('kernel_id', 'depends_on', name=op.f('pk_kernel_dependencies')) + ) + op.create_index(op.f('ix_kernel_dependencies_depends_on'), + 'kernel_dependencies', ['depends_on'], unique=False) + op.create_index(op.f('ix_kernel_dependencies_kernel_id'), + 'kernel_dependencies', ['kernel_id'], unique=False) + op.add_column( + 'kernels', + sa.Column( + 'result', + sa.Enum('UNDEFINED', 'SUCCESS', 'FAILURE', name='sessionresult'), + default='UNDEFINED', + server_default='UNDEFINED', + nullable=False, + ) + ) + op.add_column('kernels', sa.Column('status_changed', sa.DateTime(timezone=True), nullable=True)) + op.add_column( + 'kernels', + sa.Column( + 'type', + sa.Enum('INTERACTIVE', 'BATCH', name='sessiontypes'), + default='INTERACTIVE', + server_default='INTERACTIVE', + nullable=True, + ) + ) + op.alter_column( + 'kernels', 'agent_addr', + existing_type=sa.VARCHAR(length=128), + nullable=True) + op.create_index(op.f('ix_kernels_result'), 'kernels', ['result'], unique=False) + op.create_index(op.f('ix_kernels_type'), 'kernels', ['type'], unique=False) + + +def downgrade(): + conn = op.get_bind() + conn.execute('ALTER TYPE kernelstatus RENAME TO kernelstatus_new;') + kernelstatus_old.create(conn) + conn.execute(textwrap.dedent('''\ + CREATE FUNCTION kernelstatus_new_old_compare( + old_enum_val kernelstatus, new_enum_val kernelstatus_new + ) + RETURNS boolean AS $$ + SELECT old_enum_val::text <> new_enum_val::text; + $$ LANGUAGE SQL IMMUTABLE; + + CREATE OPERATOR <> ( + leftarg = kernelstatus, + rightarg = kernelstatus_new, + procedure = kernelstatus_new_old_compare + ); + + ALTER TABLE kernels + ALTER COLUMN "status" DROP DEFAULT, + ALTER COLUMN "status" TYPE kernelstatus USING ( + CASE "status"::text + WHEN 'PULLING' THEN 'PREPARING' + WHEN 'PENDING' THEN 'PREPARING' + ELSE "status"::text + END + )::kernelstatus, + ALTER COLUMN "status" SET DEFAULT 'PREPARING'::kernelstatus; + + DROP FUNCTION kernelstatus_new_old_compare( + old_enum_val kernelstatus, new_enum_val kernelstatus_new + ) CASCADE; + + DROP TYPE kernelstatus_new; + ''')) + + op.drop_index(op.f('ix_kernels_type'), table_name='kernels') + op.drop_index(op.f('ix_kernels_result'), table_name='kernels') + op.alter_column( + 'kernels', 'agent_addr', + existing_type=sa.VARCHAR(length=128), + nullable=False) + op.drop_column('kernels', 'type') + op.drop_column('kernels', 'status_changed') + op.drop_column('kernels', 'result') + op.drop_column('agents', 'status_changed') + op.drop_index(op.f('ix_kernel_dependencies_kernel_id'), table_name='kernel_dependencies') + op.drop_index(op.f('ix_kernel_dependencies_depends_on'), table_name='kernel_dependencies') + op.drop_table('kernel_dependencies') + + sessionresult.drop(op.get_bind()) + sessiontypes.drop(op.get_bind()) diff --git a/src/ai/backend/manager/models/alembic/versions/4545f5c948b3_add_io_scratch_size_stats.py b/src/ai/backend/manager/models/alembic/versions/4545f5c948b3_add_io_scratch_size_stats.py new file mode 100644 index 0000000000..86766ef114 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/4545f5c948b3_add_io_scratch_size_stats.py @@ -0,0 +1,25 @@ +"""add_io_scratch_size_stats + +Revision ID: 4545f5c948b3 +Revises: e7371ca5797a +Create Date: 2017-10-10 15:57:48.463055 + +""" 
+from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = '4545f5c948b3' +down_revision = 'e7371ca5797a' +branch_labels = None +depends_on = None + + +def upgrade(): + op.add_column('kernels', sa.Column('io_max_scratch_size', sa.BigInteger(), nullable=True)) + op.drop_column('kernels', 'mem_cur_bytes') + + +def downgrade(): + op.add_column('kernels', sa.Column('mem_cur_bytes', sa.BIGINT(), autoincrement=False, nullable=True)) + op.drop_column('kernels', 'io_max_scratch_size') diff --git a/src/ai/backend/manager/models/alembic/versions/48ab2dfefba9_reindex_kernel_updated_order.py b/src/ai/backend/manager/models/alembic/versions/48ab2dfefba9_reindex_kernel_updated_order.py new file mode 100644 index 0000000000..e56fd90e37 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/48ab2dfefba9_reindex_kernel_updated_order.py @@ -0,0 +1,30 @@ +"""reindex-kernel-updated-order + +Revision ID: 48ab2dfefba9 +Revises: 0c5733f80e4d +Create Date: 2019-09-24 16:04:29.928068 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '48ab2dfefba9' +down_revision = '0c5733f80e4d' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index('ix_kernels_updated_order', table_name='kernels') + op.create_index('ix_kernels_updated_order', 'kernels', [sa.text("greatest('created_at', 'terminated_at', 'status_changed')")], unique=False) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index('ix_kernels_updated_order', table_name='kernels') + op.create_index('ix_kernels_updated_order', 'kernels', ['created_at', 'terminated_at', 'status_changed'], unique=False) + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/4b7b650bc30e_add_creator_in_vfolders.py b/src/ai/backend/manager/models/alembic/versions/4b7b650bc30e_add_creator_in_vfolders.py new file mode 100644 index 0000000000..03ea33cbcf --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/4b7b650bc30e_add_creator_in_vfolders.py @@ -0,0 +1,28 @@ +"""add_creator_in_vfolders + +Revision ID: 4b7b650bc30e +Revises: 202b6dcbc159 +Create Date: 2019-10-30 00:23:57.085692 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '4b7b650bc30e' +down_revision = '202b6dcbc159' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('vfolders', sa.Column('creator', sa.String(length=128), nullable=True)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('vfolders', 'creator') + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/4b8a66fb8d82_revamp_keypairs.py b/src/ai/backend/manager/models/alembic/versions/4b8a66fb8d82_revamp_keypairs.py new file mode 100644 index 0000000000..5c61ee36fc --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/4b8a66fb8d82_revamp_keypairs.py @@ -0,0 +1,30 @@ +"""revamp_keypairs + +Revision ID: 4b8a66fb8d82 +Revises: 854bd902b1bc +Create Date: 2017-09-13 01:57:42.355633 + +""" +from alembic import op + +# revision identifiers, used by Alembic. 
+revision = '4b8a66fb8d82' +down_revision = '854bd902b1bc' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column('keypairs', column_name='billing_plan', new_column_name='resource_policy') + op.create_index(op.f('ix_keypairs_is_active'), 'keypairs', ['is_active'], unique=False) + op.create_index(op.f('ix_keypairs_user_id'), 'keypairs', ['user_id'], unique=False) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column('keypairs', column_name='resource_policy', new_column_name='billing_plan') + op.drop_index(op.f('ix_keypairs_user_id'), table_name='keypairs') + op.drop_index(op.f('ix_keypairs_is_active'), table_name='keypairs') + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/4cc87e7fbfdf_stats_refactor.py b/src/ai/backend/manager/models/alembic/versions/4cc87e7fbfdf_stats_refactor.py new file mode 100644 index 0000000000..d3981e2d64 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/4cc87e7fbfdf_stats_refactor.py @@ -0,0 +1,216 @@ +"""stats-refactor + +Revision ID: 4cc87e7fbfdf +Revises: e18ed5fcfedf +Create Date: 2019-05-30 18:40:17.669756 + +""" +from datetime import timedelta +from decimal import Decimal +import math + +from alembic import op +import sqlalchemy as sa +from ai.backend.manager.models.base import ( + convention, IDColumn, ResourceSlotColumn, +) +from sqlalchemy.dialects import postgresql +from sqlalchemy.sql.expression import bindparam + +# revision identifiers, used by Alembic. +revision = '4cc87e7fbfdf' +down_revision = 'e18ed5fcfedf' +branch_labels = None +depends_on = None + + +def upgrade(): + metadata = sa.MetaData(naming_convention=convention) + + # previous table def used for migration + kernels = sa.Table( + 'kernels', metadata, + + # preserved and referred columns + IDColumn(), + sa.Column('created_at', sa.DateTime(timezone=True), + server_default=sa.func.now(), index=True), + sa.Column('terminated_at', sa.DateTime(timezone=True), + nullable=True, default=sa.null(), index=True), + sa.Column('occupied_slots', ResourceSlotColumn(), nullable=False), + sa.Column('occupied_shares', postgresql.JSONB(), nullable=False, default={}), + + # old column(s) to be migrated and removed + sa.Column('cpu_used', sa.BigInteger(), default=0), # msec + sa.Column('mem_max_bytes', sa.BigInteger(), default=0), + sa.Column('net_rx_bytes', sa.BigInteger(), default=0), + sa.Column('net_tx_bytes', sa.BigInteger(), default=0), + sa.Column('io_read_bytes', sa.BigInteger(), default=0), + sa.Column('io_write_bytes', sa.BigInteger(), default=0), + sa.Column('io_max_scratch_size', sa.BigInteger(), default=0), + + # new column(s) to be added + sa.Column('last_stat', postgresql.JSONB(), nullable=True, default=sa.null()), + ) + + op.add_column('kernels', sa.Column('last_stat', postgresql.JSONB(astext_type=sa.Text()), + nullable=True)) + + connection = op.get_bind() + query = sa.select([ + kernels.c.id, + kernels.c.created_at, + kernels.c.terminated_at, + kernels.c.occupied_slots, + kernels.c.occupied_shares, + kernels.c.cpu_used, + kernels.c.mem_max_bytes, + kernels.c.net_rx_bytes, + kernels.c.net_tx_bytes, + kernels.c.io_read_bytes, + kernels.c.io_write_bytes, + kernels.c.io_max_scratch_size, + ]).select_from(kernels).order_by(kernels.c.created_at) + results = connection.execute(query).fetchall() + + updates = [] + q_pct = Decimal('0.00') + for row in results: + if 
row['terminated_at'] is None: + cpu_avg_pct = 0 + else: + cpu_avg_pct = ( + Decimal(100) * + Decimal(timedelta(microseconds=1e3 * row['cpu_used']) / + (row['terminated_at'] - row['created_at'])) + ).quantize(q_pct).normalize() + mem_capacity = 0 + _oslots = row['occupied_slots'] + if _oslots: + mem_capacity = _oslots.get('mem') + if mem_capacity is None or mem_capacity == 0: + # fallback: try legacy field + _oshares = row['occupied_shares'] + mem_capacity = _oshares.get('mem') + if mem_capacity is None or mem_capacity == 0: + # fallback: round-up to nearest GiB + mem_capacity = math.ceil(row['mem_max_bytes'] / (2**30)) * (2**30) + if mem_capacity is None or mem_capacity == 0: + # fallback: assume 1 GiB + mem_capacity = 2**30 + last_stat = { + 'cpu_used': { + 'current': str(row['cpu_used']), + 'capacity': None, + }, + 'cpu_util': { + 'current': str(cpu_avg_pct), + 'capacity': None, + 'stats.avg': str(cpu_avg_pct), + }, + 'mem': { + 'current': str(row['mem_max_bytes']), + 'capacity': str(mem_capacity), + 'stats.max': str(row['mem_max_bytes']), + }, + 'io_read': { + 'current': str(row['io_read_bytes']), + 'capacity': None, + 'stats.rate': '0', + }, + 'io_write': { + 'current': str(row['io_write_bytes']), + 'capacity': None, + 'stats.rate': '0', + }, + 'io_scratch_size': { + 'current': str(row['io_max_scratch_size']), + 'capacity': str(10 * (2**30)), # 10 GiB + 'stats.max': str(row['io_max_scratch_size']), + }, + } + updates.append({'row_id': row['id'], 'last_stat': last_stat}) + + if updates: + query = (sa.update(kernels) + .values(last_stat=bindparam('last_stat')) + .where(kernels.c.id == bindparam('row_id'))) + connection.execute(query, updates) + + op.drop_column('kernels', 'io_max_scratch_size') + op.drop_column('kernels', 'net_rx_bytes') + op.drop_column('kernels', 'net_tx_bytes') + op.drop_column('kernels', 'mem_max_bytes') + op.drop_column('kernels', 'io_write_bytes') + op.drop_column('kernels', 'io_read_bytes') + op.drop_column('kernels', 'cpu_used') + + +def downgrade(): + metadata = sa.MetaData(naming_convention=convention) + + kernels = sa.Table( + 'kernels', metadata, + + # preserved and referred columns + IDColumn(), + sa.Column('created_at', sa.DateTime(timezone=True), + server_default=sa.func.now(), index=True), + sa.Column('terminated_at', sa.DateTime(timezone=True), + nullable=True, default=sa.null(), index=True), + sa.Column('occupied_slots', ResourceSlotColumn(), nullable=False), + sa.Column('occupied_shares', postgresql.JSONB(), nullable=False, default={}), + + # old column(s) to be migrated + sa.Column('last_stat', postgresql.JSONB(), nullable=True, default=sa.null()), + + # new column(s) to be added + sa.Column('cpu_used', sa.BigInteger(), default=0), # msec + sa.Column('mem_max_bytes', sa.BigInteger(), default=0), + sa.Column('net_rx_bytes', sa.BigInteger(), default=0), + sa.Column('net_tx_bytes', sa.BigInteger(), default=0), + sa.Column('io_read_bytes', sa.BigInteger(), default=0), + sa.Column('io_write_bytes', sa.BigInteger(), default=0), + sa.Column('io_max_scratch_size', sa.BigInteger(), default=0), + ) + + op.add_column('kernels', sa.Column('cpu_used', sa.BIGINT(), autoincrement=False, nullable=True)) + op.add_column('kernels', sa.Column('io_read_bytes', sa.BIGINT(), autoincrement=False, nullable=True)) + op.add_column('kernels', sa.Column('io_write_bytes', sa.BIGINT(), + autoincrement=False, nullable=True)) + op.add_column('kernels', sa.Column('mem_max_bytes', sa.BIGINT(), autoincrement=False, nullable=True)) + op.add_column('kernels', sa.Column('net_tx_bytes', 
sa.BIGINT(), autoincrement=False, nullable=True)) + op.add_column('kernels', sa.Column('net_rx_bytes', sa.BIGINT(), autoincrement=False, nullable=True)) + op.add_column('kernels', sa.Column('io_max_scratch_size', sa.BIGINT(), + autoincrement=False, nullable=True)) + + # Restore old stats + connection = op.get_bind() + query = sa.select([kernels.c.id, kernels.c.last_stat]).select_from(kernels) + results = connection.execute(query).fetchall() + updates = [] + for row in results: + last_stat = row['last_stat'] + updates.append({ + 'row_id': row['id'], + 'cpu_used': Decimal(last_stat['cpu_used']['current']), + 'io_read_bytes': int(last_stat['io_read']['current']), + 'io_write_bytes': int(last_stat['io_write']['current']), + 'mem_max_bytes': int(last_stat['mem']['stats.max']), + 'io_max_scratch_size': int(last_stat['io_scratch_size']['stats.max']), + }) + if updates: + query = (sa.update(kernels) + .values({ + 'cpu_used': bindparam('cpu_used'), + 'io_read_bytes': bindparam('io_read_bytes'), + 'io_write_bytes': bindparam('io_write_bytes'), + 'mem_max_bytes': bindparam('mem_max_bytes'), + 'net_tx_bytes': 0, + 'net_rx_bytes': 0, + 'io_max_scratch_size': bindparam('io_max_scratch_size'), + }) + .where(kernels.c.id == bindparam('row_id'))) + connection.execute(query, updates) + + op.drop_column('kernels', 'last_stat') diff --git a/src/ai/backend/manager/models/alembic/versions/513164749de4_add_cancelled_to_kernelstatus.py b/src/ai/backend/manager/models/alembic/versions/513164749de4_add_cancelled_to_kernelstatus.py new file mode 100644 index 0000000000..3cc0cac6ef --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/513164749de4_add_cancelled_to_kernelstatus.py @@ -0,0 +1,92 @@ +"""add-cancelled-to-kernelstatus + +Revision ID: 513164749de4 +Revises: 405aa2c39458 +Create Date: 2019-09-20 11:13:39.157834 + +""" +import textwrap + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. 
+revision = '513164749de4' +down_revision = '405aa2c39458' +branch_labels = None +depends_on = None + +kernelstatus_new_values = [ + 'PENDING', + 'PREPARING', + 'BUILDING', + 'PULLING', + 'RUNNING', + 'RESTARTING', + 'RESIZING', + 'SUSPENDED', + 'TERMINATING', + 'TERMINATED', + 'ERROR', + 'CANCELLED' # added +] +kernelstatus_new = postgresql.ENUM(*kernelstatus_new_values, name='kernelstatus') + +kernelstatus_old_values = [ + 'PENDING', + 'PREPARING', + 'BUILDING', + 'PULLING', + 'RUNNING', + 'RESTARTING', + 'RESIZING', + 'SUSPENDED', + 'TERMINATING', + 'TERMINATED', + # 'ERROR', +] +kernelstatus_old = postgresql.ENUM(*kernelstatus_old_values, name='kernelstatus') + + +def upgrade(): + conn = op.get_bind() + conn.execute('DROP INDEX IF EXISTS ix_kernels_unique_sess_token;') + conn.execute('ALTER TYPE kernelstatus RENAME TO kernelstatus_old;') + kernelstatus_new.create(conn) + conn.execute(textwrap.dedent('''\ + ALTER TABLE kernels + ALTER COLUMN "status" DROP DEFAULT, + ALTER COLUMN "status" TYPE kernelstatus USING "status"::text::kernelstatus, + ALTER COLUMN "status" SET DEFAULT 'PENDING'::kernelstatus; + DROP TYPE kernelstatus_old; + ''')) + op.create_index( + 'ix_kernels_unique_sess_token', 'kernels', ['access_key', 'sess_id'], + unique=True, postgresql_where=sa.text( + "status NOT IN ('TERMINATED', 'CANCELLED') and role = 'master'" + )) + + +def downgrade(): + op.drop_index('ix_kernels_unique_sess_token', table_name='kernels') + conn = op.get_bind() + conn.execute('ALTER TYPE kernelstatus RENAME TO kernelstatus_new;') + kernelstatus_old.create(conn) + conn.execute(textwrap.dedent('''\ + ALTER TABLE kernels + ALTER COLUMN "status" DROP DEFAULT, + ALTER COLUMN "status" TYPE kernelstatus USING ( + CASE "status"::text + WHEN 'CANCELLED' THEN 'TERMINATED' + ELSE "status"::text + END + )::kernelstatus, + ALTER COLUMN "status" SET DEFAULT 'PREPARING'::kernelstatus; + DROP TYPE kernelstatus_new; + ''')) + op.create_index( + 'ix_kernels_unique_sess_token', 'kernels', ['access_key', 'sess_id'], + unique=True, postgresql_where=sa.text( + "status != 'TERMINATED' and role = 'master'" + )) diff --git a/src/ai/backend/manager/models/alembic/versions/518ecf41f567_add_index_for_cluster_role.py b/src/ai/backend/manager/models/alembic/versions/518ecf41f567_add_index_for_cluster_role.py new file mode 100644 index 0000000000..efa32f60f7 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/518ecf41f567_add_index_for_cluster_role.py @@ -0,0 +1,28 @@ +"""add-index-for-cluster_role + +Revision ID: 518ecf41f567 +Revises: dc9b66466e43 +Create Date: 2021-01-07 00:04:53.794638 + +""" +from alembic import op + +# revision identifiers, used by Alembic. +revision = '518ecf41f567' +down_revision = 'dc9b66466e43' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_index(op.f('ix_kernels_cluster_role'), 'kernels', ['cluster_role'], unique=False) + op.create_index('ix_kernels_status_role', 'kernels', ['status', 'cluster_role'], unique=False) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_index('ix_kernels_status_role', table_name='kernels') + op.drop_index(op.f('ix_kernels_cluster_role'), table_name='kernels') + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/51dddd79aa21_add_logs_column_on_kernel_table.py b/src/ai/backend/manager/models/alembic/versions/51dddd79aa21_add_logs_column_on_kernel_table.py new file mode 100644 index 0000000000..d560387997 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/51dddd79aa21_add_logs_column_on_kernel_table.py @@ -0,0 +1,24 @@ +"""Add logs column on kernel table + +Revision ID: 51dddd79aa21 +Revises: 3bb80d1887d6 +Create Date: 2020-02-11 14:45:55.496745 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '51dddd79aa21' +down_revision = '3bb80d1887d6' +branch_labels = None +depends_on = None + + +def upgrade(): + op.add_column('kernels', sa.Column('container_log', sa.LargeBinary(), nullable=True)) + + +def downgrade(): + op.drop_column('kernels', 'container_log') diff --git a/src/ai/backend/manager/models/alembic/versions/529113b08c2c_add_vfolder_type_column.py b/src/ai/backend/manager/models/alembic/versions/529113b08c2c_add_vfolder_type_column.py new file mode 100644 index 0000000000..5fc60ce7a4 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/529113b08c2c_add_vfolder_type_column.py @@ -0,0 +1,128 @@ +"""add_vfolder_type_column + +Revision ID: 529113b08c2c +Revises: c481d3dc6c7d +Create Date: 2020-04-09 16:37:35.460936 + +""" +import textwrap +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql +from sqlalchemy.sql.expression import bindparam +from ai.backend.manager.models import VFolderPermission, VFolderUsageMode, VFolderOwnershipType +from ai.backend.manager.models.base import convention, EnumValueType, IDColumn, GUID + + +# revision identifiers, used by Alembic. +revision = '529113b08c2c' +down_revision = 'c481d3dc6c7d' +branch_labels = None +depends_on = None + +vfperm_choices = list(map(lambda v: v.value, VFolderPermission)) +# vfolderpermission type should already be defined. 
+ +vfusagemode_choices = list(map(lambda v: v.value, VFolderUsageMode)) +vfolderusagemode = postgresql.ENUM( + *vfusagemode_choices, + name='vfolderusagemode', +) + +vfownershiptype_choices= list(map(lambda v: v.value, VFolderOwnershipType)) +vfolderownershiptype = postgresql.ENUM( + *vfownershiptype_choices, + name='vfolderownershiptype', +) + + +def upgrade(): + metadata = sa.MetaData(naming_convention=convention) + # partial table to be preserved and referred + vfolders = sa.Table( + 'vfolders', metadata, + IDColumn('id'), + sa.Column('ownership_type', EnumValueType(VFolderOwnershipType), nullable=False), + sa.Column('user', GUID, sa.ForeignKey('users.uuid'), nullable=True), + sa.Column('group', GUID, sa.ForeignKey('groups.id'), nullable=True), + ) + + vfolderusagemode.create(op.get_bind()) + vfolderownershiptype.create(op.get_bind()) + op.add_column('vfolder_invitations', + sa.Column('modified_at', sa.DateTime(timezone=True), nullable=True, + onupdate=sa.func.current_timestamp())) + op.add_column( + 'vfolders', + sa.Column('usage_mode', sa.Enum(*vfusagemode_choices, name='vfolderusagemode'), nullable=True) + ) + op.add_column( + 'vfolders', + sa.Column('ownership_type', + sa.Enum(*vfownershiptype_choices, name='vfolderownershiptype'), + nullable=True) + ) + op.add_column( + 'vfolders', + sa.Column('permission', + sa.Enum(*vfperm_choices, name='vfolderpermission'), + nullable=True) + ) + + # Fill vfolders.c.usage_mode with 'general' and vfolders.c.permission. + conn = op.get_bind() + query = textwrap.dedent("UPDATE vfolders SET usage_mode = 'general';") + conn.execute(query) + query = textwrap.dedent("UPDATE vfolders SET permission = 'wd' WHERE \"user\" IS NOT NULL;") + conn.execute(query) + query = textwrap.dedent("UPDATE vfolders SET permission = 'rw' WHERE \"group\" IS NOT NULL;") + conn.execute(query) + + # Set vfolders.c.ownership_type field based on user and group column. + query = (sa.select([vfolders.c.id, vfolders.c.user, vfolders.c.group]) + .select_from(vfolders)) + updates = [] + for row in conn.execute(query).fetchall(): + if row['group']: + ownership_type = VFolderOwnershipType.GROUP + else: + ownership_type = VFolderOwnershipType.USER + updates.append({'vfid': row['id'], 'otype': ownership_type}) + if updates: + query = (sa.update(vfolders) + .values(ownership_type=bindparam('otype')) + .where(vfolders.c.id == bindparam('vfid'))) + conn.execute(query, updates) + + # Create indexes for name. + op.create_index(op.f('ix_vfolders_name'), 'vfolders', ['name'], unique=False) + + # Constraints + op.create_check_constraint( + 'ownership_type_match_with_user_or_group', + 'vfolders', + '(ownership_type = \'user\' AND "user" IS NOT NULL) OR ' + '(ownership_type = \'group\' AND "group" IS NOT NULL)' + ) + op.create_check_constraint( + 'either_one_of_user_or_group', + 'vfolders', + '("user" IS NULL AND "group" IS NOT NULL) OR ("user" IS NOT NULL AND "group" IS NULL)' + ) + + op.alter_column('vfolders', column_name='usage_mode', nullable=False) + op.alter_column('vfolders', column_name='permission', nullable=False) + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_constraint('ck_vfolders_ownership_type_match_with_user_or_group', 'vfolders') + op.drop_constraint('ck_vfolders_either_one_of_user_or_group', 'vfolders') + op.drop_index(op.f('ix_vfolders_name'), table_name='vfolders') + op.drop_column('vfolders', 'usage_mode') + op.drop_column('vfolders', 'ownership_type') + op.drop_column('vfolders', 'permission') + op.drop_column('vfolder_invitations', 'modified_at') + vfolderusagemode.drop(op.get_bind()) + vfolderownershiptype.drop(op.get_bind()) + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/548cc8aa49c8_update_cluster_columns_in_kernels.py b/src/ai/backend/manager/models/alembic/versions/548cc8aa49c8_update_cluster_columns_in_kernels.py new file mode 100644 index 0000000000..4b3db775f4 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/548cc8aa49c8_update_cluster_columns_in_kernels.py @@ -0,0 +1,55 @@ +"""update_cluster_columns_in_kernels + +Revision ID: 548cc8aa49c8 +Revises: 1e673659b283 +Create Date: 2020-09-08 18:50:05.594899 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = '548cc8aa49c8' +down_revision = '1e673659b283' +branch_labels = None +depends_on = None + + +def upgrade(): + op.drop_index('ix_kernels_sess_id_role', table_name='kernels') + op.drop_index('ix_kernels_unique_sess_token', table_name='kernels') + + op.add_column('kernels', sa.Column('cluster_size', sa.Integer, nullable=False, + default=1, server_default=sa.text('1'))) + op.add_column('kernels', sa.Column('cluster_hostname', sa.String(length=64), nullable=True)) + conn = op.get_bind() + query = "UPDATE kernels k " \ + " SET cluster_size = (SELECT COUNT(*) FROM kernels j WHERE j.session_id = k.session_id);" + conn.execute(query) + query = "UPDATE kernels SET cluster_hostname = CONCAT(role, CAST(idx AS TEXT));" + conn.execute(query) + op.alter_column('kernels', 'cluster_hostname', nullable=False) + + op.alter_column('kernels', 'idx', new_column_name='cluster_idx', nullable=False) + op.alter_column('kernels', 'role', new_column_name='cluster_role', nullable=False) + + op.create_index('ix_kernels_sess_id_role', 'kernels', ['session_id', 'cluster_role'], unique=False) + op.create_index('ix_kernels_unique_sess_token', 'kernels', ['access_key', 'session_id'], unique=True, + postgresql_where=sa.text("status NOT IN ('TERMINATED', 'CANCELLED') " + "and cluster_role = 'main'")) + + +def downgrade(): + op.drop_index('ix_kernels_unique_sess_token', table_name='kernels') + op.drop_index('ix_kernels_sess_id_role', table_name='kernels') + + op.alter_column('kernels', 'cluster_idx', new_column_name='idx') + op.alter_column('kernels', 'cluster_role', new_column_name='role') + op.drop_column('kernels', 'cluster_size') + op.drop_column('kernels', 'cluster_hostname') + + op.create_index('ix_kernels_unique_sess_token', 'kernels', + ['access_key', 'session_name'], unique=True, + postgresql_where=sa.text("status NOT IN ('TERMINATED', 'CANCELLED') " + "and role = 'main'")) + op.create_index('ix_kernels_sess_id_role', 'kernels', ['session_name', 'role'], unique=False) diff --git a/src/ai/backend/manager/models/alembic/versions/57b523dec0e8_add_tpu_slots.py b/src/ai/backend/manager/models/alembic/versions/57b523dec0e8_add_tpu_slots.py new file mode 100644 index 0000000000..453aa8e28a --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/57b523dec0e8_add_tpu_slots.py @@ -0,0 +1,34 @@ +"""add tpu slots + +Revision ID: 57b523dec0e8 +Revises: 10e39a34eed5 
+Create Date: 2018-11-12 10:54:45.271417 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '57b523dec0e8' +down_revision = '10e39a34eed5' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('agents', sa.Column('tpu_slots', sa.Float(), nullable=False, + server_default='0')) + op.add_column('agents', sa.Column('used_tpu_slots', sa.Float(), nullable=False, + server_default='0')) + op.alter_column('agents', 'tpu_slots', server_default=None) + op.alter_column('agents', 'used_tpu_slots', server_default=None) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('agents', 'used_tpu_slots') + op.drop_column('agents', 'tpu_slots') + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/57e717103287_rename_clone_allowed_to_cloneable.py b/src/ai/backend/manager/models/alembic/versions/57e717103287_rename_clone_allowed_to_cloneable.py new file mode 100644 index 0000000000..2c10f4d085 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/57e717103287_rename_clone_allowed_to_cloneable.py @@ -0,0 +1,22 @@ +"""rename-clone_allowed-to-cloneable + +Revision ID: 57e717103287 +Revises: eec98e65902a +Create Date: 2020-10-04 14:14:55.167654 + +""" +from alembic import op + +# revision identifiers, used by Alembic. +revision = '57e717103287' +down_revision = 'eec98e65902a' +branch_labels = None +depends_on = None + + +def upgrade(): + op.alter_column('vfolders', 'clone_allowed', new_column_name='cloneable') + + +def downgrade(): + op.alter_column('vfolders', 'cloneable', new_column_name='clone_allowed') diff --git a/src/ai/backend/manager/models/alembic/versions/5b45f28d2cac_add_resource_opts_in_kernels.py b/src/ai/backend/manager/models/alembic/versions/5b45f28d2cac_add_resource_opts_in_kernels.py new file mode 100644 index 0000000000..9caee57a70 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/5b45f28d2cac_add_resource_opts_in_kernels.py @@ -0,0 +1,30 @@ +"""add_resource_opts_in_kernels + +Revision ID: 5b45f28d2cac +Revises: 9c89b9011872 +Create Date: 2019-09-08 10:07:20.971662 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '5b45f28d2cac' +down_revision = '9c89b9011872' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('kernels', sa.Column('resource_opts', + postgresql.JSONB(astext_type=sa.Text()), + nullable=True)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column('kernels', 'resource_opts') + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/5d8e6043455e_add_user_group_ids_in_vfolder.py b/src/ai/backend/manager/models/alembic/versions/5d8e6043455e_add_user_group_ids_in_vfolder.py new file mode 100644 index 0000000000..61447dc933 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/5d8e6043455e_add_user_group_ids_in_vfolder.py @@ -0,0 +1,162 @@ +"""add_user_group_ids_in_vfolder + +Revision ID: 5d8e6043455e +Revises: 02950808ca3d +Create Date: 2019-06-06 15:02:58.804516 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.sql.expression import bindparam +from ai.backend.manager.models.base import convention, ForeignKeyIDColumn, GUID, IDColumn + + +# revision identifiers, used by Alembic. +revision = '5d8e6043455e' +down_revision = '02950808ca3d' +branch_labels = None +depends_on = None + + +def upgrade(): + metadata = sa.MetaData(naming_convention=convention) + # partial table to be preserved and referred + keypairs = sa.Table( + 'keypairs', metadata, + sa.Column('access_key', sa.String(length=20), primary_key=True), + ForeignKeyIDColumn('user', 'users.uuid', nullable=False), + ) + vfolders = sa.Table( + 'vfolders', metadata, + IDColumn('id'), + sa.Column('belongs_to', sa.String(length=20), sa.ForeignKey('keypairs.access_key')), + sa.Column('user', GUID, sa.ForeignKey('users.uuid'), nullable=True), + sa.Column('group', GUID, sa.ForeignKey('groups.id'), nullable=True), + ) + vfolder_permissions = sa.Table( + 'vfolder_permissions', metadata, + sa.Column('vfolder', GUID, + sa.ForeignKey('vfolders.id', onupdate='CASCADE', ondelete='CASCADE'), + nullable=False), + sa.Column('access_key', sa.String(length=20), sa.ForeignKey('keypairs.access_key')), + sa.Column('user', GUID, sa.ForeignKey('users.uuid'), nullable=True), + ) + + op.add_column('vfolders', sa.Column('user', GUID(), nullable=True)) + op.add_column('vfolders', sa.Column('group', GUID(), nullable=True)) + op.create_foreign_key(op.f('fk_vfolders_user_users'), + 'vfolders', 'users', ['user'], ['uuid']) + op.create_foreign_key(op.f('fk_vfolders_group_groups'), + 'vfolders', 'groups', ['group'], ['id']) + op.add_column('vfolder_permissions', sa.Column('user', GUID(), nullable=True)) + op.create_foreign_key(op.f('fk_vfolder_permissions_user_users'), + 'vfolder_permissions', 'users', ['user'], ['uuid']) + + connection = op.get_bind() + + # Migrate vfolders' belongs_to keypair into user. + j = vfolders.join(keypairs, vfolders.c.belongs_to == keypairs.c.access_key) + query = sa.select([vfolders.c.id, keypairs.c.user]).select_from(j) + results = connection.execute(query).fetchall() + updates = [{'vid': row.id, 'user': row.user} for row in results] + if updates: + query = (sa.update(vfolders) + .values(user=bindparam('user')) + .where(vfolders.c.id == bindparam('vid'))) + connection.execute(query, updates) + + # Migrate vfolder_permissions' access_key into user. 
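Both this back-fill and the `belongs_to` one just above use the same bulk-update recipe: select the joined rows once, build one parameter dict per row, and run a single executemany `UPDATE` through `bindparam`. A minimal standalone sketch of that recipe (the `items` table and its columns are illustrative only, not names from this patch):

```python
# Minimal sketch of the bindparam/executemany back-fill pattern used in
# 5d8e6043455e; the `items` table and its columns are illustrative only.
import sqlalchemy as sa
from sqlalchemy.sql.expression import bindparam
from alembic import op


def backfill_owner():
    metadata = sa.MetaData()
    items = sa.Table(
        'items', metadata,
        sa.Column('id', sa.Integer, primary_key=True),
        sa.Column('owner', sa.String(length=20), nullable=True),
    )
    conn = op.get_bind()
    rows = conn.execute(sa.select([items.c.id])).fetchall()
    # One parameter dict per row; executing with a list runs a single
    # executemany UPDATE instead of one round-trip per row.
    updates = [{'b_id': row.id, 'b_owner': f'user-{row.id}'} for row in rows]
    if updates:
        stmt = (sa.update(items)
                .values(owner=bindparam('b_owner'))
                .where(items.c.id == bindparam('b_id')))
        conn.execute(stmt, updates)
```

Using distinct bind-parameter names (here `b_id`, `b_owner`) avoids clashes with the column names being updated, which is why the migrations below prefix theirs with underscores.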
+ j = vfolder_permissions.join(keypairs, + vfolder_permissions.c.access_key == keypairs.c.access_key) + query = (sa.select([vfolder_permissions.c.vfolder, keypairs.c.access_key, keypairs.c.user]) + .select_from(j)) + results = connection.execute(query).fetchall() + updates = [{'_vfolder': row.vfolder, '_access_key': row.access_key, '_user': row.user} + for row in results] + if updates: + query = (sa.update(vfolder_permissions) + .values(user=bindparam('_user')) + .where(vfolder_permissions.c.vfolder == bindparam('_vfolder')) + .where(vfolder_permissions.c.access_key == bindparam('_access_key'))) + connection.execute(query, updates) + + op.drop_constraint('fk_vfolders_belongs_to_keypairs', 'vfolders', type_='foreignkey') + op.drop_column('vfolders', 'belongs_to') + op.alter_column('vfolder_permissions', 'user', nullable=False) + op.drop_constraint('fk_vfolder_permissions_access_key_keypairs', + 'vfolder_permissions', type_='foreignkey') + op.drop_column('vfolder_permissions', 'access_key') + + +def downgrade(): + ####################################################################### + # CAUTION: group vfolders will be lost by downgrading this migration! + ####################################################################### + + metadata = sa.MetaData(naming_convention=convention) + # partial table to be preserved and referred + keypairs = sa.Table( + 'keypairs', metadata, + sa.Column('access_key', sa.String(length=20), primary_key=True), + ForeignKeyIDColumn('user', 'users.uuid', nullable=False), + ) + vfolders = sa.Table( + 'vfolders', metadata, + IDColumn('id'), + sa.Column('belongs_to', sa.String(length=20), sa.ForeignKey('keypairs.access_key')), + sa.Column('user', GUID, sa.ForeignKey('users.uuid'), nullable=True), + # sa.Column('group', GUID, sa.ForeignKey('groups.id'), nullable=True), + ) + vfolder_permissions = sa.Table( + 'vfolder_permissions', metadata, + sa.Column('vfolder', GUID, + sa.ForeignKey('vfolders.id', onupdate='CASCADE', ondelete='CASCADE'), + nullable=False), + sa.Column('access_key', sa.String(length=20), sa.ForeignKey('keypairs.access_key')), + sa.Column('user', GUID, sa.ForeignKey('users.uuid'), nullable=True), + ) + + op.add_column('vfolders', + sa.Column('belongs_to', sa.String(length=20), autoincrement=False, nullable=True)) + op.create_foreign_key('fk_vfolders_belongs_to_keypairs', + 'vfolders', 'keypairs', ['belongs_to'], ['access_key']) + op.add_column('vfolder_permissions', + sa.Column('access_key', sa.String(length=20), autoincrement=False, nullable=True)) + op.create_foreign_key('fk_vfolder_permissions_access_key_keypairs', + 'vfolder_permissions', 'keypairs', ['access_key'], ['access_key']) + + connection = op.get_bind() + + # Migrate vfolders' user_id into belongs_to. + j = vfolders.join(keypairs, vfolders.c.user == keypairs.c.user) + query = sa.select([vfolders.c.id, keypairs.c.access_key]).select_from(j) + results = connection.execute(query).fetchall() + updates = [{'vid': row.id, 'belongs_to': row.access_key} for row in results] + if updates: + query = (sa.update(vfolders) + .values(belongs_to=bindparam('belongs_to')) + .where(vfolders.c.id == bindparam('vid'))) + connection.execute(query, updates) + + # Migrate vfolder_permissions' used into access_key. 
+ j = (vfolder_permissions.join(keypairs, vfolder_permissions.c.user == keypairs.c.user)) + query = (sa.select([vfolder_permissions.c.vfolder, keypairs.c.user, keypairs.c.access_key]) + .select_from(j)) + results = connection.execute(query).fetchall() + updates = [{'_vfolder': row.vfolder, '_access_key': row.access_key, '_user': row.user} \ + for row in results] + if updates: + query = (sa.update(vfolder_permissions) + .values(access_key=bindparam('_access_key')) + .where(vfolder_permissions.c.vfolder == bindparam('_vfolder')) + .where(vfolder_permissions.c.user == bindparam('_user'))) + connection.execute(query, updates) + + op.alter_column('vfolders', 'belongs_to', nullable=False) + op.alter_column('vfolder_permissions', 'access_key', nullable=False) + op.drop_constraint(op.f('fk_vfolders_user_users'), 'vfolders', type_='foreignkey') + op.drop_constraint(op.f('fk_vfolders_group_groups'), 'vfolders', type_='foreignkey') + op.drop_column('vfolders', 'user') + op.drop_column('vfolders', 'group') + op.drop_constraint(op.f('fk_vfolder_permissions_user_users'), 'vfolder_permissions', type_='foreignkey') + op.drop_column('vfolder_permissions', 'user') diff --git a/src/ai/backend/manager/models/alembic/versions/5de06da3c2b5_init.py b/src/ai/backend/manager/models/alembic/versions/5de06da3c2b5_init.py new file mode 100644 index 0000000000..788ea53674 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/5de06da3c2b5_init.py @@ -0,0 +1,76 @@ +"""init + +Revision ID: 5de06da3c2b5 +Revises: +Create Date: 2017-06-08 15:08:23.166237 + +""" +from alembic import op +import sqlalchemy as sa +from ai.backend.manager.models.base import GUID + + +# revision identifiers, used by Alembic. +revision = '5de06da3c2b5' +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade(): + op.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";') + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table( + 'keypairs', + sa.Column('user_id', sa.Integer(), nullable=True), + sa.Column('access_key', sa.String(length=20), nullable=False), + sa.Column('secret_key', sa.String(length=40), nullable=True), + sa.Column('is_active', sa.Boolean(), nullable=True), + sa.Column('billing_plan', sa.String(), nullable=True), + sa.Column('created_at', sa.DateTime(timezone=True), + server_default=sa.text('now()'), nullable=True), + sa.Column('last_used', sa.DateTime(timezone=True), nullable=True), + sa.Column('concurrency_limit', sa.Integer(), nullable=True), + sa.Column('concurrency_used', sa.Integer(), nullable=True), + sa.Column('rate_limit', sa.Integer(), nullable=True), + sa.Column('num_queries', sa.Integer(), server_default='0', nullable=True), + sa.PrimaryKeyConstraint('access_key') + ) + op.create_table( + 'kernels', + sa.Column('sess_id', GUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('lang', sa.String(length=64), nullable=True), + sa.Column('access_key', sa.String(length=20), nullable=True), + sa.Column('created_at', sa.DateTime(timezone=True), + server_default=sa.text('now()'), nullable=True), + sa.Column('terminated_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('status', sa.String(), nullable=True), + sa.Column('agent_id', sa.String(), nullable=True), + sa.Column('container_id', sa.String(), nullable=True), + sa.ForeignKeyConstraint(['access_key'], ['keypairs.access_key'], ), + sa.PrimaryKeyConstraint('sess_id') + ) + op.create_table( + 'usage', + sa.Column('id', GUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('access_key_id', sa.String(length=20), nullable=True), + sa.Column('kernel_type', sa.String(), nullable=True), + sa.Column('kernel_id', sa.String(), nullable=True), + sa.Column('started_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('terminated_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('cpu_used', sa.Integer(), server_default='0', nullable=True), + sa.Column('mem_used', sa.Integer(), server_default='0', nullable=True), + sa.Column('io_used', sa.Integer(), server_default='0', nullable=True), + sa.Column('net_used', sa.Integer(), server_default='0', nullable=True), + sa.ForeignKeyConstraint(['access_key_id'], ['keypairs.access_key'], ), + sa.PrimaryKeyConstraint('id') + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('usage') + op.drop_table('kernels') + op.drop_table('keypairs') + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/5e88398bc340_add_unmanaged_path_column_to_vfolders.py b/src/ai/backend/manager/models/alembic/versions/5e88398bc340_add_unmanaged_path_column_to_vfolders.py new file mode 100644 index 0000000000..c3f0a0d8f6 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/5e88398bc340_add_unmanaged_path_column_to_vfolders.py @@ -0,0 +1,26 @@ +"""Add unmanaged_path column to vfolders + +Revision ID: 5e88398bc340 +Revises: d452bacd085c +Create Date: 2019-11-28 13:41:03.545551 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision = '5e88398bc340' +down_revision = 'd452bacd085c' +branch_labels = None +depends_on = None + + +def upgrade(): + op.add_column('vfolders', sa.Column('unmanaged_path', sa.String(length=512), nullable=True)) + + +def downgrade(): + op.drop_column('vfolders', 'unmanaged_path') + + diff --git a/src/ai/backend/manager/models/alembic/versions/60a1effa77d2_add_coordinator_address_column_on_.py b/src/ai/backend/manager/models/alembic/versions/60a1effa77d2_add_coordinator_address_column_on_.py new file mode 100644 index 0000000000..a0744f1e37 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/60a1effa77d2_add_coordinator_address_column_on_.py @@ -0,0 +1,24 @@ +"""Add wsproxy_addr column on scaling_group + +Revision ID: 60a1effa77d2 +Revises: 8679d0a7e22b +Create Date: 2021-09-17 13:19:57.525513 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '60a1effa77d2' +down_revision = '8679d0a7e22b' +branch_labels = None +depends_on = None + + +def upgrade(): + op.add_column('scaling_groups', sa.Column('wsproxy_addr', sa.String(length=1024), nullable=True)) + + +def downgrade(): + op.drop_column('scaling_groups', 'wsproxy_addr') diff --git a/src/ai/backend/manager/models/alembic/versions/65c4a109bbc7_.py b/src/ai/backend/manager/models/alembic/versions/65c4a109bbc7_.py new file mode 100644 index 0000000000..8a0c465bee --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/65c4a109bbc7_.py @@ -0,0 +1,22 @@ +"""Merge migration + +Revision ID: 65c4a109bbc7 +Revises: 0262e50e90e0, 5e88398bc340 +Create Date: 2019-12-16 01:42:44.316419 + +""" + + +# revision identifiers, used by Alembic. +revision = '65c4a109bbc7' +down_revision = ('0262e50e90e0', '5e88398bc340') +branch_labels = None +depends_on = None + + +def upgrade(): + pass + + +def downgrade(): + pass diff --git a/src/ai/backend/manager/models/alembic/versions/6f1c1b83870a_merge_user_s_first__last_name_into_full_.py b/src/ai/backend/manager/models/alembic/versions/6f1c1b83870a_merge_user_s_first__last_name_into_full_.py new file mode 100644 index 0000000000..0429726f9f --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/6f1c1b83870a_merge_user_s_first__last_name_into_full_.py @@ -0,0 +1,34 @@ +"""merge user's first_/last_name into full_name + +Revision ID: 6f1c1b83870a +Revises: 7a82e0c70122 +Create Date: 2019-05-22 15:52:57.173180 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '6f1c1b83870a' +down_revision = '7a82e0c70122' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('users', sa.Column('full_name', sa.String(length=64), nullable=True)) + op.drop_column('users', 'last_name') + op.drop_column('users', 'first_name') + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.add_column('users', sa.Column('first_name', sa.VARCHAR(length=32), + autoincrement=False, nullable=True)) + op.add_column('users', sa.Column('last_name', sa.VARCHAR(length=32), + autoincrement=False, nullable=True)) + op.drop_column('users', 'full_name') + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/6f5fe19894b7_vfolder_invitation_state_to_enum_type.py b/src/ai/backend/manager/models/alembic/versions/6f5fe19894b7_vfolder_invitation_state_to_enum_type.py new file mode 100644 index 0000000000..2a3571d3be --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/6f5fe19894b7_vfolder_invitation_state_to_enum_type.py @@ -0,0 +1,45 @@ +"""vfolder_invitation_state_to_enum_type + +Revision ID: 6f5fe19894b7 +Revises: 48ab2dfefba9 +Create Date: 2019-09-28 21:05:55.409422 + +""" +from alembic import op +import sqlalchemy as sa +from ai.backend.manager.models import VFolderInvitationState +from sqlalchemy.dialects import postgresql + + +# revision identifiers, used by Alembic. +revision = '6f5fe19894b7' +down_revision = '48ab2dfefba9' +branch_labels = None +depends_on = None + +vfinvs_choices = list(map(lambda v: v.value, VFolderInvitationState)) +vfolderinvitationstate = postgresql.ENUM( + *vfinvs_choices, + name='vfolderinvitationstate', +) + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + pass + # ### end Alembic commands ### + vfolderinvitationstate.create(op.get_bind()) + op.alter_column('vfolder_invitations', column_name='state', + type_=sa.Enum(*vfinvs_choices, name='vfolderinvitationstate'), + postgresql_using='state::vfolderinvitationstate') + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + pass + # ### end Alembic commands ### + vfolderinvitationstate.create(op.get_bind()) + op.alter_column('vfolder_invitations', column_name='state', + type_=sa.String(length=10), + postgresql_using='state::text::vfolderinvitationstate') + vfolderinvitationstate.drop(op.get_bind()) diff --git a/src/ai/backend/manager/models/alembic/versions/7a82e0c70122_add_group_model.py b/src/ai/backend/manager/models/alembic/versions/7a82e0c70122_add_group_model.py new file mode 100644 index 0000000000..98745798e9 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/7a82e0c70122_add_group_model.py @@ -0,0 +1,54 @@ +"""add group model + +Revision ID: 7a82e0c70122 +Revises: bae1a7326e8a +Create Date: 2019-05-09 10:00:55.788734 + +""" +from alembic import op +import sqlalchemy as sa +from ai.backend.manager.models.base import GUID + + +# revision identifiers, used by Alembic. 
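For reference, the `6f5fe19894b7` migration above converts a plain string column into a PostgreSQL ENUM. Its `downgrade()` re-creates the type and casts back through the enum name, which looks like a leftover from copying the upgrade path; the usual shape of the pattern, sketched here with a made-up `jobstate` enum and `jobs` table, is:

```python
# Sketch of the String -> PostgreSQL ENUM conversion pattern; `jobstate`
# and `jobs` are illustrative names, not part of this patch.
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from alembic import op

jobstate = postgresql.ENUM('PENDING', 'RUNNING', 'DONE', name='jobstate')


def upgrade():
    jobstate.create(op.get_bind())            # the type must exist first
    op.alter_column(
        'jobs', 'state',
        type_=sa.Enum('PENDING', 'RUNNING', 'DONE', name='jobstate'),
        postgresql_using='state::jobstate')


def downgrade():
    op.alter_column(
        'jobs', 'state',
        type_=sa.String(length=16),
        postgresql_using='state::text')
    jobstate.drop(op.get_bind())              # drop only after no column uses it
```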
+revision = '7a82e0c70122' +down_revision = 'bae1a7326e8a' +branch_labels = None +depends_on = None + + +def upgrade(): + op.create_table( + 'groups', + sa.Column('id', GUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('name', sa.String(length=64), nullable=False), + sa.Column('description', sa.String(length=512), nullable=True), + sa.Column('is_active', sa.Boolean(), nullable=True), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True), + sa.Column('modified_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True), + sa.Column('domain_name', sa.String(length=64), nullable=False), + sa.ForeignKeyConstraint(['domain_name'], ['domains.name'], + name=op.f('fk_groups_domain_name_domains'), + onupdate='CASCADE', ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id', name=op.f('pk_groups')), + sa.UniqueConstraint('name', 'domain_name', name='uq_groups_name_domain_name') + ) + op.create_index(op.f('ix_groups_domain_name'), 'groups', ['domain_name'], unique=False) + op.create_table( + 'association_groups_users', + sa.Column('user_id', GUID(), nullable=False), + sa.Column('group_id', GUID(), nullable=False), + sa.ForeignKeyConstraint(['group_id'], ['groups.id'], + name=op.f('fk_association_groups_users_group_id_groups'), + onupdate='CASCADE', ondelete='CASCADE'), + sa.ForeignKeyConstraint(['user_id'], ['users.uuid'], + name=op.f('fk_association_groups_users_user_id_users'), + onupdate='CASCADE', ondelete='CASCADE'), + sa.UniqueConstraint('user_id', 'group_id', name='uq_association_user_id_group_id') + ) + + +def downgrade(): + op.drop_table('association_groups_users') + op.drop_index(op.f('ix_groups_domain_name'), table_name='groups') + op.drop_table('groups') diff --git a/src/ai/backend/manager/models/alembic/versions/7dd1d81c3204_add_vfolder_mounts_to_kernels.py b/src/ai/backend/manager/models/alembic/versions/7dd1d81c3204_add_vfolder_mounts_to_kernels.py new file mode 100644 index 0000000000..89759c1a71 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/7dd1d81c3204_add_vfolder_mounts_to_kernels.py @@ -0,0 +1,29 @@ +"""add-vfolder-mounts-to-kernels + +Revision ID: 7dd1d81c3204 +Revises: 60a1effa77d2 +Create Date: 2022-03-09 16:41:48.304128 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = '7dd1d81c3204' +down_revision = '60a1effa77d2' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('kernels', sa.Column('vfolder_mounts', sa.JSON(), nullable=True)) + op.create_index('ix_keypairs_resource_policy', 'keypairs', ['resource_policy'], unique=False) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_index('ix_keypairs_resource_policy', table_name='keypairs') + op.drop_column('kernels', 'vfolder_mounts') + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/7ea324d0535b_vfolder_and_kernel.py b/src/ai/backend/manager/models/alembic/versions/7ea324d0535b_vfolder_and_kernel.py new file mode 100644 index 0000000000..c28b89d906 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/7ea324d0535b_vfolder_and_kernel.py @@ -0,0 +1,114 @@ +"""vfolder-and-kernel + +Revision ID: 7ea324d0535b +Revises: 5de06da3c2b5 +Create Date: 2017-08-08 16:25:59.553570 + +""" +from alembic import op +import sqlalchemy as sa +from ai.backend.manager.models.base import GUID +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '7ea324d0535b' +down_revision = '5de06da3c2b5' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + 'agents', + sa.Column('id', sa.String(length=64), nullable=False), + sa.Column('mem_slots', sa.Integer(), nullable=False), + sa.Column('cpu_slots', sa.Integer(), nullable=False), + sa.Column('gpu_slots', sa.Integer(), nullable=False), + sa.Column('used_mem_slots', sa.Integer(), nullable=False), + sa.Column('used_cpu_slots', sa.Integer(), nullable=False), + sa.Column('used_gpu_slots', sa.Integer(), nullable=False), + sa.Column('addr', sa.String(length=128), nullable=False), + sa.Column('first_contact', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True), + sa.PrimaryKeyConstraint('id', name=op.f('pk_agents')) + ) + op.create_table( + 'vfolders', + sa.Column('id', GUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('host', sa.String(length=128), nullable=False), + sa.Column('name', sa.String(length=64), nullable=False), + sa.Column('max_files', sa.Integer(), nullable=True), + sa.Column('max_size', sa.Integer(), nullable=True), + sa.Column('num_files', sa.Integer(), nullable=True), + sa.Column('cur_size', sa.Integer(), nullable=True), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True), + sa.Column('last_used', sa.DateTime(timezone=True), nullable=True), + sa.Column('belongs_to', sa.String(length=20), nullable=False), + sa.ForeignKeyConstraint(['belongs_to'], ['keypairs.access_key'], name=op.f('fk_vfolders_belongs_to_keypairs')), + sa.PrimaryKeyConstraint('id', name=op.f('pk_vfolders')) + ) + op.create_table( + 'vfolder_attachment', + sa.Column('vfolder', GUID(), nullable=False), + sa.Column('kernel', GUID(), nullable=False), + sa.ForeignKeyConstraint(['kernel'], ['kernels.sess_id'], name=op.f('fk_vfolder_attachment_kernel_kernels')), + sa.ForeignKeyConstraint(['vfolder'], ['vfolders.id'], name=op.f('fk_vfolder_attachment_vfolder_vfolders')), + sa.PrimaryKeyConstraint('vfolder', 'kernel', name=op.f('pk_vfolder_attachment')) + ) + op.drop_table('usage') + op.add_column('kernels', sa.Column('agent', sa.String(length=64), nullable=True)) + op.add_column('kernels', sa.Column('allocated_cores', sa.ARRAY(sa.Integer()), nullable=True)) + op.add_column('kernels', sa.Column('cpu_used', sa.BigInteger(), nullable=True)) + op.add_column('kernels', sa.Column('cur_mem_bytes', sa.BigInteger(), nullable=True)) + op.add_column('kernels', sa.Column('io_read_bytes', sa.BigInteger(), nullable=True)) + op.add_column('kernels', sa.Column('io_write_bytes', sa.BigInteger(), nullable=True)) + op.add_column('kernels', 
sa.Column('max_mem_bytes', sa.BigInteger(), nullable=True)) + op.add_column('kernels', sa.Column('net_rx_bytes', sa.BigInteger(), nullable=True)) + op.add_column('kernels', sa.Column('net_tx_bytes', sa.BigInteger(), nullable=True)) + op.add_column('kernels', sa.Column('num_queries', sa.BigInteger(), nullable=True)) + op.add_column('kernels', sa.Column('status_info', sa.Unicode(), nullable=True)) + op.create_index(op.f('ix_kernels_created_at'), 'kernels', ['created_at'], unique=False) + op.create_index(op.f('ix_kernels_status'), 'kernels', ['status'], unique=False) + op.create_index(op.f('ix_kernels_terminated_at'), 'kernels', ['terminated_at'], unique=False) + op.create_foreign_key(op.f('fk_kernels_agent_agents'), 'kernels', 'agents', ['agent'], ['id']) + op.drop_column('kernels', 'agent_id') + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('kernels', sa.Column('agent_id', sa.VARCHAR(), autoincrement=False, nullable=True)) + op.drop_constraint(op.f('fk_kernels_agent_agents'), 'kernels', type_='foreignkey') + op.drop_index(op.f('ix_kernels_terminated_at'), table_name='kernels') + op.drop_index(op.f('ix_kernels_status'), table_name='kernels') + op.drop_index(op.f('ix_kernels_created_at'), table_name='kernels') + op.drop_column('kernels', 'status_info') + op.drop_column('kernels', 'num_queries') + op.drop_column('kernels', 'net_tx_bytes') + op.drop_column('kernels', 'net_rx_bytes') + op.drop_column('kernels', 'max_mem_bytes') + op.drop_column('kernels', 'io_write_bytes') + op.drop_column('kernels', 'io_read_bytes') + op.drop_column('kernels', 'cur_mem_bytes') + op.drop_column('kernels', 'cpu_used') + op.drop_column('kernels', 'allocated_cores') + op.drop_column('kernels', 'agent') + op.create_table( + 'usage', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), autoincrement=False, nullable=False), + sa.Column('access_key_id', sa.VARCHAR(length=20), autoincrement=False, nullable=True), + sa.Column('kernel_type', sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column('kernel_id', sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column('started_at', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), + sa.Column('terminated_at', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=True), + sa.Column('cpu_used', sa.INTEGER(), server_default=sa.text('0'), autoincrement=False, nullable=True), + sa.Column('mem_used', sa.INTEGER(), server_default=sa.text('0'), autoincrement=False, nullable=True), + sa.Column('io_used', sa.INTEGER(), server_default=sa.text('0'), autoincrement=False, nullable=True), + sa.Column('net_used', sa.INTEGER(), server_default=sa.text('0'), autoincrement=False, nullable=True), + sa.ForeignKeyConstraint(['access_key_id'], ['keypairs.access_key'], name='fk_usage_access_key_id_keypairs'), + sa.PrimaryKeyConstraint('id', name='pk_usage') + ) + op.drop_table('vfolder_attachment') + op.drop_table('vfolders') + op.drop_table('agents') + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/80176413d8aa_keypairs_get_is_admin.py b/src/ai/backend/manager/models/alembic/versions/80176413d8aa_keypairs_get_is_admin.py new file mode 100644 index 0000000000..a1acc412ef --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/80176413d8aa_keypairs_get_is_admin.py @@ -0,0 +1,30 @@ +"""keypairs_get_is_admin + +Revision ID: 80176413d8aa +Revises: 4b8a66fb8d82 +Create Date: 2017-09-14 
16:01:59.994941 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.sql.expression import false + +# revision identifiers, used by Alembic. +revision = '80176413d8aa' +down_revision = '4b8a66fb8d82' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('keypairs', sa.Column('is_admin', sa.Boolean(), nullable=False, default=False, server_default=false())) + op.create_index(op.f('ix_keypairs_is_admin'), 'keypairs', ['is_admin'], unique=False) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f('ix_keypairs_is_admin'), table_name='keypairs') + op.drop_column('keypairs', 'is_admin') + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/819c2b3830a9_add_user_model.py b/src/ai/backend/manager/models/alembic/versions/819c2b3830a9_add_user_model.py new file mode 100644 index 0000000000..d8aa6247c1 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/819c2b3830a9_add_user_model.py @@ -0,0 +1,136 @@ +"""add user model + +Revision ID: 819c2b3830a9 +Revises: 8e660aa31fe3 +Create Date: 2019-05-02 00:21:43.704843 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql +from ai.backend.manager.models.base import ( + convention, EnumValueType, ForeignKeyIDColumn, GUID, IDColumn, +) +from ai.backend.manager.models.user import PasswordColumn +from ai.backend.manager.models import UserRole +# from ai.backend.manager.models import keypairs, users, UserRole + + +# revision identifiers, used by Alembic. +revision = '819c2b3830a9' +down_revision = '8e660aa31fe3' +branch_labels = None +depends_on = None + + +userrole_choices = list(map(lambda v: v.value, UserRole)) +userrole = postgresql.ENUM(*userrole_choices, name='userrole') + + +def upgrade(): + metadata = sa.MetaData(naming_convention=convention) + # partial table to be preserved and referred + keypairs = sa.Table( + 'keypairs', metadata, + sa.Column('user_id', sa.String(length=256), index=True), + sa.Column('access_key', sa.String(length=20), primary_key=True), + sa.Column('secret_key', sa.String(length=40)), + sa.Column('is_active', sa.Boolean, index=True), + sa.Column('is_admin', sa.Boolean, index=True), + ForeignKeyIDColumn('user', 'users.uuid', nullable=False), + ) + # partial table to insert the migrated data + users = sa.Table( + 'users', metadata, + IDColumn('uuid'), + sa.Column('username', sa.String(length=64), unique=True), + sa.Column('email', sa.String(length=64), index=True, + nullable=False, unique=True), + sa.Column('password', PasswordColumn()), + sa.Column('need_password_change', sa.Boolean), + sa.Column('is_active', sa.Boolean, default=True), + sa.Column('role', EnumValueType(UserRole), default=UserRole.USER), + ) + + userrole.create(op.get_bind()) + op.create_table( + 'users', + sa.Column('uuid', GUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('username', sa.String(length=64), nullable=True), + sa.Column('email', sa.String(length=64), nullable=False), + sa.Column('password', PasswordColumn(), nullable=True), + sa.Column('need_password_change', sa.Boolean(), nullable=True), + sa.Column('first_name', sa.String(length=32), nullable=True), + sa.Column('last_name', sa.String(length=32), nullable=True), + sa.Column('description', sa.String(length=500), nullable=True), + sa.Column('is_active', sa.Boolean(), nullable=True), + 
sa.Column('created_at', sa.DateTime(timezone=True), + server_default=sa.text('now()'), nullable=True), + sa.Column('role', postgresql.ENUM(*userrole_choices, name='userrole', create_type=False), + nullable=True), + sa.PrimaryKeyConstraint('uuid', name=op.f('pk_users')), + sa.UniqueConstraint('username', name=op.f('uq_users_username')) + ) + op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True) + op.add_column('keypairs', sa.Column('user', GUID(), nullable=True)) + op.create_foreign_key(op.f('fk_keypairs_user_users'), 'keypairs', 'users', ['user'], ['uuid']) + + # ### Create users based on keypair.user_id & associate keypairs.user to user record ### + # Get all keypairs + connection = op.get_bind() + query = ( + sa.select([keypairs.c.user_id, + keypairs.c.access_key, + keypairs.c.secret_key, + keypairs.c.is_admin]) + .select_from(keypairs)) + results = connection.execute(query).fetchall() + for keypair in results: + email = keypair['user_id'] + access_key = keypair['access_key'] + is_admin = keypair['is_admin'] + if email in [None, '']: + continue + # Try to get a user whose email matches with current keypair's email + query = sa.select([users.c.uuid, users.c.role]).select_from(users).where(users.c.email == email) + user = connection.execute(query).first() + if user: + # Update user's role if current keypair is admin keypair + user_uuid = user['uuid'] + role = UserRole.ADMIN if is_admin else UserRole.USER + if role == UserRole.ADMIN and user['role'] != UserRole.ADMIN: + query = (sa.update(users) + .values(role=UserRole.ADMIN) + .where(users.c.email == email)) + connection.execute(query) + else: + # Create new user (set username with email) + role = UserRole.ADMIN if is_admin else UserRole.USER + temp_password = keypair['secret_key'][:8] + query = (sa.insert(users) + .returning(users.c.uuid) + .values(username=email, email=email, + password=temp_password, + need_password_change=True, + is_active=True, role=role)) + user = connection.execute(query).first() + user_uuid = user[0] + # Update current keypair's `user` field with associated user's uuid. + query = (sa.update(keypairs) + .values(user=user_uuid) + .where(keypairs.c.access_key == access_key)) + connection.execute(query) + + # Make keypairs.user column NOT NULL. + op.alter_column('keypairs', column_name='user', nullable=False) + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint(op.f('fk_keypairs_user_users'), 'keypairs', type_='foreignkey') + op.drop_column('keypairs', 'user') + op.drop_index(op.f('ix_users_email'), table_name='users') + op.drop_table('users') + # ### end Alembic commands ### + + userrole.drop(op.get_bind()) diff --git a/src/ai/backend/manager/models/alembic/versions/81c264528f20_add_max_session_lifetime.py b/src/ai/backend/manager/models/alembic/versions/81c264528f20_add_max_session_lifetime.py new file mode 100644 index 0000000000..0f15bd5294 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/81c264528f20_add_max_session_lifetime.py @@ -0,0 +1,24 @@ +"""add-max-session-lifetime + +Revision ID: 81c264528f20 +Revises: d727b5da20e6 +Create Date: 2022-04-21 09:22:01.405710 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
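Several of these data migrations (for example `5d8e6043455e` and `819c2b3830a9` above) declare small partial `sa.Table` objects rather than importing the live ORM models, so the migration keeps working even after the model definitions evolve. A minimal sketch of that pattern, with an illustrative fix-up query:

```python
# Sketch of the "partial table" pattern used by the data migrations above:
# declare only the columns the migration needs.  Table/column names and the
# fix-up itself are illustrative.
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from alembic import op


def upgrade():
    metadata = sa.MetaData()
    users = sa.Table(
        'users', metadata,
        sa.Column('uuid', postgresql.UUID(), primary_key=True),
        sa.Column('email', sa.String(length=64)),
    )
    conn = op.get_bind()
    # Example fix-up: normalize stored e-mail addresses to lower case.
    conn.execute(sa.update(users).values(email=sa.func.lower(users.c.email)))
```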
+revision = '81c264528f20' +down_revision = 'd727b5da20e6' +branch_labels = None +depends_on = None + + +def upgrade(): + op.add_column('keypair_resource_policies', sa.Column('max_session_lifetime', sa.Integer(), server_default=sa.text('0'), nullable=False)) + + +def downgrade(): + op.drop_column('keypair_resource_policies', 'max_session_lifetime') diff --git a/src/ai/backend/manager/models/alembic/versions/854bd902b1bc_change_kernel_identification.py b/src/ai/backend/manager/models/alembic/versions/854bd902b1bc_change_kernel_identification.py new file mode 100644 index 0000000000..7d39f14435 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/854bd902b1bc_change_kernel_identification.py @@ -0,0 +1,67 @@ +"""change-kernel-identification + +Revision ID: 854bd902b1bc +Revises: 0f3bc98edaa0 +Create Date: 2017-08-21 17:08:20.581565 + +""" +from alembic import op +import sqlalchemy as sa +from ai.backend.manager.models.base import GUID +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '854bd902b1bc' +down_revision = '0f3bc98edaa0' +branch_labels = None +depends_on = None + + +def upgrade(): + op.drop_constraint('fk_vfolder_attachment_vfolder_vfolders', 'vfolder_attachment', type_='foreignkey') + op.drop_constraint('fk_vfolder_attachment_kernel_kernels', 'vfolder_attachment', type_='foreignkey') + op.drop_constraint('pk_kernels', 'kernels', type_='primary') + op.add_column('kernels', + sa.Column('id', GUID(), + server_default=sa.text('uuid_generate_v4()'), + nullable=False)) + op.add_column('kernels', sa.Column('role', sa.String(length=16), nullable=False, default='master')) + op.create_primary_key('pk_kernels', 'kernels', ['id']) + op.alter_column( + 'kernels', 'sess_id', + existing_type=postgresql.UUID(), + type_=sa.String(length=64), + nullable=True, + existing_server_default=sa.text('uuid_generate_v4()')) + op.create_index(op.f('ix_kernels_sess_id'), 'kernels', ['sess_id'], unique=False) + op.create_index(op.f('ix_kernels_sess_id_role'), 'kernels', ['sess_id', 'role'], unique=False) + op.create_foreign_key('fk_vfolder_attachment_vfolder_vfolders', + 'vfolder_attachment', 'vfolders', + ['vfolder'], ['id'], onupdate='CASCADE', ondelete='CASCADE') + op.create_foreign_key('fk_vfolder_attachment_kernel_kernels', + 'vfolder_attachment', 'kernels', + ['kernel'], ['id'], onupdate='CASCADE', ondelete='CASCADE') + + +def downgrade(): + op.drop_constraint('fk_vfolder_attachment_vfolder_vfolders', 'vfolder_attachment', type_='foreignkey') + op.drop_constraint('fk_vfolder_attachment_kernel_kernels', 'vfolder_attachment', type_='foreignkey') + op.drop_constraint('pk_kernels', 'kernels', type_='primary') + op.drop_index(op.f('ix_kernels_sess_id'), table_name='kernels') + op.drop_index(op.f('ix_kernels_sess_id_role'), table_name='kernels') + op.alter_column( + 'kernels', 'sess_id', + existing_type=sa.String(length=64), + type_=postgresql.UUID(), + nullable=False, + existing_server_default=sa.text('uuid_generate_v4()'), + postgresql_using='sess_id::uuid') + op.create_primary_key('pk_kernels', 'kernels', ['sess_id']) + op.drop_column('kernels', 'id') + op.drop_column('kernels', 'role') + op.create_foreign_key('fk_vfolder_attachment_vfolder_vfolders', + 'vfolder_attachment', 'vfolders', + ['vfolder'], ['id']) + op.create_foreign_key('fk_vfolder_attachment_kernel_kernels', + 'vfolder_attachment', 'kernels', + ['kernel'], ['sess_id']) diff --git a/src/ai/backend/manager/models/alembic/versions/8679d0a7e22b_add_scheduled_to_kernelstatus.py 
b/src/ai/backend/manager/models/alembic/versions/8679d0a7e22b_add_scheduled_to_kernelstatus.py new file mode 100644 index 0000000000..5c03acf082 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/8679d0a7e22b_add_scheduled_to_kernelstatus.py @@ -0,0 +1,96 @@ +"""add-scheduled-to-kernelstatus + +Revision ID: 8679d0a7e22b +Revises: 518ecf41f567 +Create Date: 2021-04-01 14:24:27.885209 + +""" +import textwrap + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '8679d0a7e22b' +down_revision = '518ecf41f567' +branch_labels = None +depends_on = None + +kernelstatus_new_values = [ + 'PENDING', + 'SCHEDULED', # added + 'PREPARING', + 'BUILDING', + 'PULLING', + 'RUNNING', + 'RESTARTING', + 'RESIZING', + 'SUSPENDED', + 'TERMINATING', + 'TERMINATED', + 'ERROR', + 'CANCELLED', +] +kernelstatus_new = postgresql.ENUM(*kernelstatus_new_values, name='kernelstatus') + +kernelstatus_old_values = [ + 'PENDING', + 'PREPARING', + 'BUILDING', + 'PULLING', + 'RUNNING', + 'RESTARTING', + 'RESIZING', + 'SUSPENDED', + 'TERMINATING', + 'TERMINATED', + 'ERROR', + 'CANCELLED', +] +kernelstatus_old = postgresql.ENUM(*kernelstatus_old_values, name='kernelstatus') + + +def upgrade(): + conn = op.get_bind() + conn.execute('DROP INDEX IF EXISTS ix_kernels_unique_sess_token;') + conn.execute('ALTER TYPE kernelstatus RENAME TO kernelstatus_old;') + kernelstatus_new.create(conn) + conn.execute(textwrap.dedent('''\ + ALTER TABLE kernels + ALTER COLUMN "status" DROP DEFAULT, + ALTER COLUMN "status" TYPE kernelstatus USING "status"::text::kernelstatus, + ALTER COLUMN "status" SET DEFAULT 'PENDING'::kernelstatus; + DROP TYPE kernelstatus_old; + ''')) + # This also fixes the unique constraint columns: + # (access_key, session_id) -> (access_key, session_name) + op.create_index( + 'ix_kernels_unique_sess_token', 'kernels', ['access_key', 'session_name'], + unique=True, postgresql_where=sa.text( + "status NOT IN ('TERMINATED', 'CANCELLED') and cluster_role = 'main'" + )) + + +def downgrade(): + op.drop_index('ix_kernels_unique_sess_token', table_name='kernels') + conn = op.get_bind() + conn.execute('ALTER TYPE kernelstatus RENAME TO kernelstatus_new;') + kernelstatus_old.create(conn) + conn.execute(textwrap.dedent('''\ + ALTER TABLE kernels + ALTER COLUMN "status" DROP DEFAULT, + ALTER COLUMN "status" TYPE kernelstatus USING ( + CASE "status"::text + WHEN 'SCHEDULED' THEN 'PREPARING' + ELSE "status"::text + END + )::kernelstatus, + ALTER COLUMN "status" SET DEFAULT 'PENDING'::kernelstatus; + DROP TYPE kernelstatus_new; + ''')) + op.create_index( + 'ix_kernels_unique_sess_token', 'kernels', ['access_key', 'session_id'], + unique=True, postgresql_where=sa.text( + "status != 'TERMINATED' and cluster_role = 'main'" + )) diff --git a/src/ai/backend/manager/models/alembic/versions/8e660aa31fe3_add_resource_presets.py b/src/ai/backend/manager/models/alembic/versions/8e660aa31fe3_add_resource_presets.py new file mode 100644 index 0000000000..0ad7eee27d --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/8e660aa31fe3_add_resource_presets.py @@ -0,0 +1,98 @@ +"""add-resource-presets + +Revision ID: 8e660aa31fe3 +Revises: 01456c812164 +Create Date: 2019-03-30 01:45:07.525096 + +""" +from alembic import op +from decimal import Decimal +import sqlalchemy as sa +from ai.backend.manager.models.base import ResourceSlotColumn +from ai.backend.manager.models import keypair_resource_policies +from ai.backend.common.types 
import BinarySize, ResourceSlot + + +# revision identifiers, used by Alembic. +revision = '8e660aa31fe3' +down_revision = '01456c812164' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + 'resource_presets', + sa.Column('name', sa.String(length=256), nullable=False), + sa.Column('resource_slots', + ResourceSlotColumn(), + nullable=False), + sa.PrimaryKeyConstraint('name', name=op.f('pk_resource_presets')) + ) + # Add initial fixtures for resource presets + query = ''' + INSERT INTO resource_presets + VALUES ( + 'small', + '{"cpu":"1","mem":"2147483648"}'::jsonb + ); + INSERT INTO resource_presets + VALUES ( + 'small-gpu', + '{"cpu":"1","mem":"2147483648","cuda.device":"1","cuda.shares":"0.5"}'::jsonb + ); + INSERT INTO resource_presets + VALUES ( + 'medium', + '{"cpu":"2","mem":"4294967296"}'::jsonb + ); + INSERT INTO resource_presets + VALUES ( + 'medium-gpu', + '{"cpu":"2","mem":"4294967296","cuda.device":"1","cuda.shares":"1.0"}'::jsonb + ); + INSERT INTO resource_presets + VALUES ( + 'large', + '{"cpu":"4","mem":"8589934592"}'::jsonb + ); + INSERT INTO resource_presets + VALUES ( + 'large-gpu', + '{"cpu":"4","mem":"8589934592","cuda.device":"2","cuda.shares":"2.0"}'::jsonb + ); + ''' + connection = op.get_bind() + connection.execute(query) + + query = ''' + SELECT name, total_resource_slots + FROM keypair_resource_policies + ''' + connection = op.get_bind() + result = connection.execute(query) + updates = [] + for row in result: + converted = ResourceSlot(row['total_resource_slots']) + if 'mem' in converted: + converted['mem'] = Decimal(BinarySize.from_str(converted['mem'])) + updates.append(( + row['name'], + converted, + )) + for name, slots in updates: + query = ( + sa.update(keypair_resource_policies) + .values(total_resource_slots=slots) + .where(keypair_resource_policies.c.name == name) + ) + connection.execute(query) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('resource_presets') + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/911023380bc9_add_architecture_column_on_agents.py b/src/ai/backend/manager/models/alembic/versions/911023380bc9_add_architecture_column_on_agents.py new file mode 100644 index 0000000000..6bda687a4f --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/911023380bc9_add_architecture_column_on_agents.py @@ -0,0 +1,34 @@ +"""add architecture column on agents + +Revision ID: 911023380bc9 +Revises: 015d84d5a5ef +Create Date: 2022-02-16 00:54:23.261212 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
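The `8679d0a7e22b` migration above adds the `SCHEDULED` value to `kernelstatus` by renaming the old ENUM, creating the new one, and retyping the column in a single `ALTER TABLE`. A condensed sketch of that trick with a made-up `jobstate` enum:

```python
# Condensed sketch of the rename-and-recreate trick from 8679d0a7e22b for
# adding a value to a PostgreSQL ENUM; `jobstate` and `jobs` are illustrative.
import textwrap

from alembic import op
from sqlalchemy.dialects import postgresql

jobstate_new = postgresql.ENUM(
    'PENDING', 'RUNNING', 'PAUSED', 'DONE', name='jobstate')


def upgrade():
    conn = op.get_bind()
    conn.execute('ALTER TYPE jobstate RENAME TO jobstate_old;')
    jobstate_new.create(conn)
    conn.execute(textwrap.dedent('''\
        ALTER TABLE jobs
            ALTER COLUMN "state" DROP DEFAULT,
            ALTER COLUMN "state" TYPE jobstate USING "state"::text::jobstate,
            ALTER COLUMN "state" SET DEFAULT 'PENDING'::jobstate;
        DROP TYPE jobstate_old;
    '''))
```

Before PostgreSQL 12, `ALTER TYPE ... ADD VALUE` could not run inside a transaction block, which is why migrations like this rebuild the type instead of extending it in place.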
+revision = '911023380bc9' +down_revision = '015d84d5a5ef' +branch_labels = None +depends_on = None + + +def upgrade(): + op.add_column( + 'agents', + sa.Column('architecture', sa.String, default='x86_64')) + op.execute('UPDATE agents SET architecture=\'x86_64\'') + op.alter_column('agents', 'architecture', nullable=False) + op.add_column( + 'kernels', + sa.Column('architecture', sa.String, default='x86_64')) + op.execute('UPDATE kernels SET architecture=\'x86_64\'') + op.alter_column('kernels', 'architecture', nullable=False) + + +def downgrade(): + op.drop_column('kernels', 'architecture') + op.drop_column('agents', 'architecture') diff --git a/src/ai/backend/manager/models/alembic/versions/93e9d31d40bf_agent_add_region.py b/src/ai/backend/manager/models/alembic/versions/93e9d31d40bf_agent_add_region.py new file mode 100644 index 0000000000..c5ff170dcc --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/93e9d31d40bf_agent_add_region.py @@ -0,0 +1,41 @@ +"""agent_add_region + +Revision ID: 93e9d31d40bf +Revises: 80176413d8aa +Create Date: 2017-09-28 15:01:38.944738 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = '93e9d31d40bf' +down_revision = '80176413d8aa' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('agents', sa.Column('region', sa.String(length=64), + nullable=False, + server_default='amazon/ap-northeast-2')) + op.create_index(op.f('ix_agents_region'), 'agents', ['region'], unique=False) + op.alter_column( + 'keypairs', 'is_admin', + existing_type=sa.BOOLEAN(), + nullable=True, + existing_server_default=sa.text('false')) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column( + 'keypairs', 'is_admin', + existing_type=sa.BOOLEAN(), + nullable=False, + existing_server_default=sa.text('false')) + op.drop_index(op.f('ix_agents_region'), table_name='agents') + op.drop_column('agents', 'region') + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/97f6c80c8aa5_merge.py b/src/ai/backend/manager/models/alembic/versions/97f6c80c8aa5_merge.py new file mode 100644 index 0000000000..791cddf60f --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/97f6c80c8aa5_merge.py @@ -0,0 +1,24 @@ +"""merge + +Revision ID: 97f6c80c8aa5 +Revises: e421c02cf9e4, 25e903510fa1 +Create Date: 2020-09-28 18:00:35.664882 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '97f6c80c8aa5' +down_revision = ('e421c02cf9e4', '25e903510fa1') +branch_labels = None +depends_on = None + + +def upgrade(): + pass + + +def downgrade(): + pass diff --git a/src/ai/backend/manager/models/alembic/versions/9a91532c8534_add_scaling_group.py b/src/ai/backend/manager/models/alembic/versions/9a91532c8534_add_scaling_group.py new file mode 100644 index 0000000000..fe4902646c --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/9a91532c8534_add_scaling_group.py @@ -0,0 +1,131 @@ +"""add-scaling-group + +Revision ID: 9a91532c8534 +Revises: c401d78cc7b9 +Create Date: 2019-07-25 22:32:25.974046 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +import ai.backend.manager.models.base + +# revision identifiers, used by Alembic. 
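The `911023380bc9` migration above (and `57b523dec0e8` earlier, via a temporary `server_default`) follows the standard three-step recipe for introducing a NOT NULL column on a table that already has rows. A minimal sketch, with an illustrative `widgets` table:

```python
# Sketch of the three-step pattern for adding a NOT NULL column to a
# populated table; the `widgets` table and its values are illustrative.
import sqlalchemy as sa
from alembic import op


def upgrade():
    # 1) add the column as nullable (or with a server_default) so existing rows pass
    op.add_column('widgets', sa.Column('color', sa.String(length=16), nullable=True))
    # 2) back-fill existing rows
    op.execute("UPDATE widgets SET color = 'unknown' WHERE color IS NULL")
    # 3) tighten the constraint once every row has a value
    op.alter_column('widgets', 'color', nullable=False)
```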
+revision = '9a91532c8534' +down_revision = 'c401d78cc7b9' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + 'scaling_groups', + sa.Column('name', sa.String(length=64), nullable=False), + sa.Column('description', sa.String(length=512), nullable=True), + sa.Column('is_active', sa.Boolean(), nullable=True), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True), + sa.Column('driver', sa.String(length=64), nullable=False), + sa.Column('driver_opts', postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.Column('scheduler', sa.String(length=64), nullable=False), + sa.Column('scheduler_opts', postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.PrimaryKeyConstraint('name', name=op.f('pk_scaling_groups')) + ) + op.create_index(op.f('ix_scaling_groups_is_active'), 'scaling_groups', ['is_active'], unique=False) + op.create_table( + 'sgroups_for_domains', + sa.Column('scaling_group', sa.String(length=64), nullable=False), + sa.Column('domain', sa.String(length=64), nullable=False), + sa.ForeignKeyConstraint(['domain'], ['domains.name'], + name=op.f('fk_sgroups_for_domains_domain_domains'), + onupdate='CASCADE', ondelete='CASCADE'), + sa.ForeignKeyConstraint(['scaling_group'], ['scaling_groups.name'], + name=op.f('fk_sgroups_for_domains_scaling_group_scaling_groups'), + onupdate='CASCADE', ondelete='CASCADE'), + sa.UniqueConstraint('scaling_group', 'domain', name='uq_sgroup_domain') + ) + op.create_index(op.f('ix_sgroups_for_domains_domain'), + 'sgroups_for_domains', ['domain'], unique=False) + op.create_index(op.f('ix_sgroups_for_domains_scaling_group'), + 'sgroups_for_domains', ['scaling_group'], unique=False) + op.create_table( + 'sgroups_for_groups', + sa.Column('scaling_group', sa.String(length=64), nullable=False), + sa.Column('group', ai.backend.manager.models.base.GUID(), nullable=False), + sa.ForeignKeyConstraint(['group'], ['groups.id'], + name=op.f('fk_sgroups_for_groups_group_groups'), + onupdate='CASCADE', ondelete='CASCADE'), + sa.ForeignKeyConstraint(['scaling_group'], ['scaling_groups.name'], + name=op.f('fk_sgroups_for_groups_scaling_group_scaling_groups'), + onupdate='CASCADE', ondelete='CASCADE'), + sa.UniqueConstraint('scaling_group', 'group', name='uq_sgroup_ugroup') + ) + op.create_index(op.f('ix_sgroups_for_groups_group'), + 'sgroups_for_groups', ['group'], unique=False) + op.create_index(op.f('ix_sgroups_for_groups_scaling_group'), + 'sgroups_for_groups', ['scaling_group'], unique=False) + op.create_table( + 'sgroups_for_keypairs', + sa.Column('scaling_group', sa.String(length=64), nullable=False), + sa.Column('access_key', sa.String(length=20), nullable=False), + sa.ForeignKeyConstraint(['access_key'], ['keypairs.access_key'], + name=op.f('fk_sgroups_for_keypairs_access_key_keypairs'), + onupdate='CASCADE', ondelete='CASCADE'), + sa.ForeignKeyConstraint(['scaling_group'], ['scaling_groups.name'], + name=op.f('fk_sgroups_for_keypairs_scaling_group_scaling_groups'), + onupdate='CASCADE', ondelete='CASCADE'), + sa.UniqueConstraint('scaling_group', 'access_key', name='uq_sgroup_akey') + ) + op.create_index(op.f('ix_sgroups_for_keypairs_access_key'), 'sgroups_for_keypairs', ['access_key'], unique=False) + op.create_index(op.f('ix_sgroups_for_keypairs_scaling_group'), 'sgroups_for_keypairs', ['scaling_group'], unique=False) + + # create the default sgroup + query = ''' + INSERT INTO scaling_groups + VALUES ( + 'default', + 
'The default agent scaling group', + 't', + now(), + 'static', + '{}'::jsonb, + 'fifo', + '{}'::jsonb + ); + INSERT INTO sgroups_for_domains + VALUES ('default', 'default'); + ''' + connection = op.get_bind() + connection.execute(query) + + op.add_column('agents', sa.Column('scaling_group', sa.String(length=64), server_default='default', nullable=False)) + op.create_index(op.f('ix_agents_scaling_group'), 'agents', ['scaling_group'], unique=False) + op.create_foreign_key(op.f('fk_agents_scaling_group_scaling_groups'), + 'agents', 'scaling_groups', ['scaling_group'], ['name']) + op.add_column('kernels', sa.Column('scaling_group', sa.String(length=64), server_default='default', nullable=False)) + op.create_index(op.f('ix_kernels_scaling_group'), 'kernels', ['scaling_group'], unique=False) + op.create_foreign_key(op.f('fk_kernels_scaling_group_scaling_groups'), 'kernels', 'scaling_groups', ['scaling_group'], ['name']) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint(op.f('fk_kernels_scaling_group_scaling_groups'), 'kernels', type_='foreignkey') + op.drop_index(op.f('ix_kernels_scaling_group'), table_name='kernels') + op.drop_column('kernels', 'scaling_group') + op.drop_constraint(op.f('fk_agents_scaling_group_scaling_groups'), 'agents', type_='foreignkey') + op.drop_index(op.f('ix_agents_scaling_group'), table_name='agents') + op.drop_column('agents', 'scaling_group') + op.drop_index(op.f('ix_sgroups_for_keypairs_scaling_group'), table_name='sgroups_for_keypairs') + op.drop_index(op.f('ix_sgroups_for_keypairs_access_key'), table_name='sgroups_for_keypairs') + op.drop_table('sgroups_for_keypairs') + op.drop_index(op.f('ix_sgroups_for_groups_scaling_group'), table_name='sgroups_for_groups') + op.drop_index(op.f('ix_sgroups_for_groups_group'), table_name='sgroups_for_groups') + op.drop_table('sgroups_for_groups') + op.drop_index(op.f('ix_sgroups_for_domains_scaling_group'), table_name='sgroups_for_domains') + op.drop_index(op.f('ix_sgroups_for_domains_domain'), table_name='sgroups_for_domains') + op.drop_table('sgroups_for_domains') + op.drop_index(op.f('ix_scaling_groups_is_active'), table_name='scaling_groups') + op.drop_table('scaling_groups') + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/9bd986a75a2a_allow_kernels_scaling_group_nullable.py b/src/ai/backend/manager/models/alembic/versions/9bd986a75a2a_allow_kernels_scaling_group_nullable.py new file mode 100644 index 0000000000..344fbcc38f --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/9bd986a75a2a_allow_kernels_scaling_group_nullable.py @@ -0,0 +1,34 @@ +"""allow_kernels_scaling_group_nullable + +Revision ID: 9bd986a75a2a +Revises: 513164749de4 +Create Date: 2019-09-20 14:39:57.761791 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '9bd986a75a2a' +down_revision = '513164749de4' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column('kernels', 'scaling_group', + existing_type=sa.VARCHAR(length=64), + nullable=True, + existing_server_default=sa.text("'default'::character varying")) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.alter_column('kernels', 'scaling_group', + existing_type=sa.VARCHAR(length=64), + nullable=False, + existing_server_default=sa.text("'default'::character varying")) + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/9c89b9011872_add_attached_devices_field_in_kernels.py b/src/ai/backend/manager/models/alembic/versions/9c89b9011872_add_attached_devices_field_in_kernels.py new file mode 100644 index 0000000000..a62ebe6c92 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/9c89b9011872_add_attached_devices_field_in_kernels.py @@ -0,0 +1,30 @@ +"""add_attached_devices_field_in_kernels + +Revision ID: 9c89b9011872 +Revises: 2a82340fa30e +Create Date: 2019-08-04 16:38:52.781990 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '9c89b9011872' +down_revision = '2a82340fa30e' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('kernels', sa.Column('attached_devices', + postgresql.JSONB(astext_type=sa.Text()), + nullable=True)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('kernels', 'attached_devices') + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/9cd61b1ae70d_add_scheduable_field_to_agents.py b/src/ai/backend/manager/models/alembic/versions/9cd61b1ae70d_add_scheduable_field_to_agents.py new file mode 100644 index 0000000000..f68ba694a4 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/9cd61b1ae70d_add_scheduable_field_to_agents.py @@ -0,0 +1,34 @@ +"""add_scheduable_field_to_agents + +Revision ID: 9cd61b1ae70d +Revises: e35332f8d23d +Create Date: 2020-07-01 15:02:13.979828 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.sql.expression import true +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '9cd61b1ae70d' +down_revision = 'e35332f8d23d' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('agents', sa.Column( + 'schedulable', sa.Boolean(), + server_default=true(), + default=True, + nullable=False, + )) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('agents', 'schedulable') + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/a1fd4e7b7782_enumerate_vfolder_perms.py b/src/ai/backend/manager/models/alembic/versions/a1fd4e7b7782_enumerate_vfolder_perms.py new file mode 100644 index 0000000000..569541eb8e --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/a1fd4e7b7782_enumerate_vfolder_perms.py @@ -0,0 +1,49 @@ +"""enumerate_vfolder_perms + +Revision ID: a1fd4e7b7782 +Revises: f9971fbb34d9 +Create Date: 2018-09-05 16:51:49.973195 + +""" +from alembic import op +import sqlalchemy as sa +from ai.backend.manager.models import VFolderPermission +from sqlalchemy.dialects import postgresql + + +# revision identifiers, used by Alembic. 
+revision = 'a1fd4e7b7782' +down_revision = 'f9971fbb34d9' +branch_labels = None +depends_on = None + +# NOTE: VFolderPermission is EnumValueType +vfperm_choices = list(map(lambda v: v.value, VFolderPermission)) +vfolderpermission = postgresql.ENUM( + *vfperm_choices, + name='vfolderpermission', +) + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + vfolderpermission.create(op.get_bind()) + op.alter_column('vfolder_invitations', column_name='permission', + type_=sa.Enum(*vfperm_choices, name='vfolderpermission'), + postgresql_using='permission::vfolderpermission') + op.alter_column('vfolder_permissions', column_name='permission', + type_=sa.Enum(*vfperm_choices, name='vfolderpermission'), + postgresql_using='permission::vfolderpermission') + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column('vfolder_invitations', column_name='permission', + type_=sa.String(length=2), + postgresql_using='permission::text::vfolderpermission') + op.alter_column('vfolder_permissions', column_name='permission', + type_=sa.String(length=2), + postgresql_using='permission::text::vfolderpermission') + vfolderpermission.drop(op.get_bind()) + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/a7ca9f175d5f_merge.py b/src/ai/backend/manager/models/alembic/versions/a7ca9f175d5f_merge.py new file mode 100644 index 0000000000..c8d180d23c --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/a7ca9f175d5f_merge.py @@ -0,0 +1,24 @@ +"""merge + +Revision ID: a7ca9f175d5f +Revises: d59ff89e7514, 11146ba02235 +Create Date: 2022-03-28 15:25:22.965843 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'a7ca9f175d5f' +down_revision = ('d59ff89e7514', '11146ba02235') +branch_labels = None +depends_on = None + + +def upgrade(): + pass + + +def downgrade(): + pass diff --git a/src/ai/backend/manager/models/alembic/versions/babc74594aa6_add_partial_index_to_kernels.py b/src/ai/backend/manager/models/alembic/versions/babc74594aa6_add_partial_index_to_kernels.py new file mode 100644 index 0000000000..b0079e94ce --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/babc74594aa6_add_partial_index_to_kernels.py @@ -0,0 +1,31 @@ +"""add-partial-index-to-kernels + +Revision ID: babc74594aa6 +Revises: c3e74dcf1808 +Create Date: 2018-01-04 14:33:39.173062 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision = 'babc74594aa6'
+down_revision = 'c3e74dcf1808'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    op.create_index(
+        op.f('ix_kernels_unique_sess_token'),
+        'kernels', ['access_key', 'sess_id'],
+        unique=True,
+        postgresql_where=sa.text(
+            "kernels.status != 'TERMINATED' and "
+            "kernels.role = 'master'"),
+    )
+
+
+def downgrade():
+    op.drop_index(op.f('ix_kernels_unique_sess_token'), table_name='kernels')
diff --git a/src/ai/backend/manager/models/alembic/versions/bae1a7326e8a_add_domain_model.py b/src/ai/backend/manager/models/alembic/versions/bae1a7326e8a_add_domain_model.py
new file mode 100644
index 0000000000..4a77d529af
--- /dev/null
+++ b/src/ai/backend/manager/models/alembic/versions/bae1a7326e8a_add_domain_model.py
@@ -0,0 +1,79 @@
+"""add domain model
+
+Revision ID: bae1a7326e8a
+Revises: 819c2b3830a9
+Create Date: 2019-05-08 08:29:29.588817
+
+"""
+import textwrap
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+from ai.backend.manager.models.base import (
+    convention, ResourceSlotColumn,
+)
+
+
+# revision identifiers, used by Alembic.
+revision = 'bae1a7326e8a'
+down_revision = '819c2b3830a9'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    metadata = sa.MetaData(naming_convention=convention)
+
+    # partial table to insert "default" domain
+    domains = sa.Table(
+        'domains', metadata,
+        sa.Column('name', sa.String(length=64), primary_key=True),
+        sa.Column('description', sa.String(length=512)),
+        sa.Column('is_active', sa.Boolean, default=True),
+        sa.Column('total_resource_slots', ResourceSlotColumn(), nullable=False),
+    )
+
+    op.create_table(
+        'domains',
+        sa.Column('name', sa.String(length=64), nullable=False),
+        sa.Column('description', sa.String(length=512), nullable=True),
+        sa.Column('is_active', sa.Boolean(), nullable=True),
+        sa.Column('created_at', sa.DateTime(timezone=True),
+                  server_default=sa.text('now()'), nullable=True),
+        sa.Column('modified_at', sa.DateTime(timezone=True),
+                  server_default=sa.text('now()'), nullable=True),
+        sa.Column('total_resource_slots',
+                  postgresql.JSONB(astext_type=sa.Text()), nullable=False),
+        sa.PrimaryKeyConstraint('name', name=op.f('pk_domains'))
+    )
+    op.add_column('users', sa.Column('domain_name', sa.String(length=64), nullable=True))
+    op.create_index(op.f('ix_users_domain_name'), 'users', ['domain_name'], unique=False)
+    op.create_foreign_key(op.f('fk_users_domain_name_domains'),
+                          'users', 'domains', ['domain_name'], ['name'])
+
+    # Fill in users' domain_name column with the default domain.
+    # Create the default domain if it does not exist.
+    connection = op.get_bind()
+    query = sa.select([domains]).select_from(domains).where(domains.c.name == 'default')
+    results = connection.execute(query).first()
+    if results is None:
+        query = (sa.insert(domains)
+                 .values(name='default',
+                         description='Default domain',
+                         is_active=True,
+                         total_resource_slots='{}'))
+        query = textwrap.dedent('''\
+            INSERT INTO domains (name, description, is_active, total_resource_slots)
+            VALUES ('default', 'Default domain', True, '{}'::jsonb);''')
+        connection.execute(query)
+
+    # Fill in users' domain_name field.
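+    # (admin@lablup.com is excluded below; the later e18ed5fcfedf migration in this
+    # patch assigns its domain to 'default' together with the superadmin role.)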
+ query = "UPDATE users SET domain_name = 'default' WHERE email != 'admin@lablup.com';" + connection.execute(query) + + +def downgrade(): + op.drop_constraint(op.f('fk_users_domain_name_domains'), 'users', type_='foreignkey') + op.drop_index(op.f('ix_users_domain_name'), table_name='users') + op.drop_column('users', 'domain_name') + op.drop_table('domains') diff --git a/src/ai/backend/manager/models/alembic/versions/bf4bae8f942e_add_kernel_host.py b/src/ai/backend/manager/models/alembic/versions/bf4bae8f942e_add_kernel_host.py new file mode 100644 index 0000000000..86ecdfc146 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/bf4bae8f942e_add_kernel_host.py @@ -0,0 +1,24 @@ +"""add-kernel-host + +Revision ID: bf4bae8f942e +Revises: babc74594aa6 +Create Date: 2018-02-02 11:29:38.752576 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'bf4bae8f942e' +down_revision = 'babc74594aa6' +branch_labels = None +depends_on = None + + +def upgrade(): + op.add_column('kernels', sa.Column('kernel_host', sa.String(length=128), nullable=True)) + + +def downgrade(): + op.drop_column('kernels', 'kernel_host') diff --git a/src/ai/backend/manager/models/alembic/versions/c092dabf3ee5_add_batch_session.py b/src/ai/backend/manager/models/alembic/versions/c092dabf3ee5_add_batch_session.py new file mode 100644 index 0000000000..daaeb0206e --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/c092dabf3ee5_add_batch_session.py @@ -0,0 +1,43 @@ +"""add-batch-session + +Revision ID: c092dabf3ee5 +Revises: c1409ad0e8da +Create Date: 2019-08-01 15:18:20.306290 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'c092dabf3ee5' +down_revision = '48ab2dfefba9' +branch_labels = None +depends_on = None + + +sessiontypes = postgresql.ENUM( + 'INTERACTIVE', 'BATCH', + name='sessiontypes') + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('kernels', sa.Column('startup_command', sa.Text(), nullable=True)) + op.drop_column('kernels', 'type') + op.add_column('kernels', + sa.Column('sess_type', sa.Enum('INTERACTIVE', 'BATCH', name='sessiontypes'), + nullable=False, server_default='INTERACTIVE')) + op.create_index(op.f('ix_kernels_sess_type'), 'kernels', ['sess_type'], unique=False) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f('ix_kernels_sess_type'), table_name='kernels') + op.drop_column('kernels', 'sess_type') + op.add_column('kernels', + sa.Column('type', sa.Enum('INTERACTIVE', 'BATCH', name='sessiontypes'), + nullable=False, server_default='INTERACTIVE')) + op.drop_column('kernels', 'startup_command') + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/c1409ad0e8da_.py b/src/ai/backend/manager/models/alembic/versions/c1409ad0e8da_.py new file mode 100644 index 0000000000..47aa4bdeae --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/c1409ad0e8da_.py @@ -0,0 +1,24 @@ +"""empty message + +Revision ID: c1409ad0e8da +Revises: 22e52d03fc61, 9a91532c8534 +Create Date: 2019-07-29 20:18:52.291350 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision = 'c1409ad0e8da' +down_revision = ('22e52d03fc61', '9a91532c8534') +branch_labels = None +depends_on = None + + +def upgrade(): + pass + + +def downgrade(): + pass diff --git a/src/ai/backend/manager/models/alembic/versions/c3e74dcf1808_add_environ_to_kernels.py b/src/ai/backend/manager/models/alembic/versions/c3e74dcf1808_add_environ_to_kernels.py new file mode 100644 index 0000000000..af46de3a5a --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/c3e74dcf1808_add_environ_to_kernels.py @@ -0,0 +1,23 @@ +"""add_environ_to_kernels + +Revision ID: c3e74dcf1808 +Revises: d52bf5ec9ef3 +Create Date: 2017-11-15 11:31:54.083566 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = 'c3e74dcf1808' +down_revision = 'd52bf5ec9ef3' +branch_labels = None +depends_on = None + + +def upgrade(): + op.add_column('kernels', sa.Column('environ', sa.ARRAY(sa.String()), nullable=True)) + + +def downgrade(): + op.drop_column('kernels', 'environ') diff --git a/src/ai/backend/manager/models/alembic/versions/c401d78cc7b9_add_allowed_vfolder_hosts_to_domain_and_.py b/src/ai/backend/manager/models/alembic/versions/c401d78cc7b9_add_allowed_vfolder_hosts_to_domain_and_.py new file mode 100644 index 0000000000..2a46e1a6b2 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/c401d78cc7b9_add_allowed_vfolder_hosts_to_domain_and_.py @@ -0,0 +1,42 @@ +"""add_allowed_vfolder_hosts_to_domain_and_group + +Revision ID: c401d78cc7b9 +Revises: 3cf19d906e71 +Create Date: 2019-06-26 11:34:55.426107 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'c401d78cc7b9' +down_revision = '3cf19d906e71' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('domains', sa.Column('allowed_vfolder_hosts', + postgresql.ARRAY(sa.String()), nullable=True)) + op.add_column('groups', sa.Column('allowed_vfolder_hosts', + postgresql.ARRAY(sa.String()), nullable=True)) + # ### end Alembic commands ### + + print('\nSet domain and group\'s allowed_vfolder_hosts with empty array.') + connection = op.get_bind() + query = ("UPDATE domains SET allowed_vfolder_hosts = '{}';") + connection.execute(query) + query = ("UPDATE groups SET allowed_vfolder_hosts = '{}';") + connection.execute(query) + + op.alter_column('domains', column_name='allowed_vfolder_hosts', nullable=False) + op.alter_column('groups', column_name='allowed_vfolder_hosts', nullable=False) + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('groups', 'allowed_vfolder_hosts') + op.drop_column('domains', 'allowed_vfolder_hosts') + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/c481d3dc6c7d_add_shared_memory_to_resource_presets.py b/src/ai/backend/manager/models/alembic/versions/c481d3dc6c7d_add_shared_memory_to_resource_presets.py new file mode 100644 index 0000000000..ee8ec1da3c --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/c481d3dc6c7d_add_shared_memory_to_resource_presets.py @@ -0,0 +1,28 @@ +"""add_shared_memory_to_resource_presets + +Revision ID: c481d3dc6c7d +Revises: f5530eccf202 +Create Date: 2020-04-20 14:10:35.591063 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision = 'c481d3dc6c7d' +down_revision = 'f5530eccf202' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('resource_presets', sa.Column('shared_memory', sa.BigInteger(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('resource_presets', 'shared_memory') + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/c5e4e764f9e3_add_domain_group_user_fields_to_kernels.py b/src/ai/backend/manager/models/alembic/versions/c5e4e764f9e3_add_domain_group_user_fields_to_kernels.py new file mode 100644 index 0000000000..c00d3f5b44 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/c5e4e764f9e3_add_domain_group_user_fields_to_kernels.py @@ -0,0 +1,125 @@ +"""add domain, group, user fields to kernels + +Revision ID: c5e4e764f9e3 +Revises: 6f1c1b83870a +Create Date: 2019-05-28 10:22:56.904061 + +""" +import textwrap +from alembic import op +import sqlalchemy as sa +from ai.backend.manager.models.base import ( + convention, ForeignKeyIDColumn, GUID, IDColumn, +) + + +# revision identifiers, used by Alembic. +revision = 'c5e4e764f9e3' +down_revision = '6f1c1b83870a' +branch_labels = None +depends_on = None + + +def upgrade(): + metadata = sa.MetaData(naming_convention=convention) + # partial tables for data migration + groups = sa.Table( + 'groups', metadata, + IDColumn('id'), + sa.Column('name', sa.String(length=64), nullable=False), + sa.Column('description', sa.String(length=512)), + sa.Column('is_active', sa.Boolean, default=True), + sa.Column('domain_name', sa.String(length=64), + sa.ForeignKey('domains.name', onupdate='CASCADE', ondelete='CASCADE'), + nullable=False, index=True), + ) + users = sa.Table( + 'users', metadata, + IDColumn('uuid'), + ) + kernels = sa.Table( + 'kernels', metadata, + IDColumn(), + sa.Column('access_key', sa.String(length=20), sa.ForeignKey('keypairs.access_key')), + ) + keypairs = sa.Table( + 'keypairs', metadata, + sa.Column('user_id', sa.String(length=256), index=True), + sa.Column('access_key', sa.String(length=20), primary_key=True), + ForeignKeyIDColumn('user', 'users.uuid', nullable=False), + ) + + op.add_column('kernels', sa.Column('domain_name', sa.String(length=64), nullable=True)) + op.add_column('kernels', sa.Column('group_id', GUID(), nullable=True)) + op.add_column('kernels', sa.Column('user_uuid', GUID(), nullable=True)) + op.create_foreign_key(op.f('fk_kernels_group_id_groups'), + 'kernels', 'groups', ['group_id'], ['id']) + op.create_foreign_key(op.f('fk_kernels_user_uuid_users'), + 'kernels', 'users', ['user_uuid'], ['uuid']) + op.create_foreign_key(op.f('fk_kernels_domain_name_domains'), + 'kernels', 'domains', ['domain_name'], ['name']) + + # Create default group in the default domain. 
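+    # (The "default" domain itself is created by the earlier bae1a7326e8a migration
+    # included in this patch.)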
+    # Assumption: the "default" domain must exist
+    connection = op.get_bind()
+    query = (sa.insert(groups)
+             .values(name='default',
+                     description='Default group',
+                     is_active=True,
+                     domain_name='default'))
+    query = textwrap.dedent('''\
+        INSERT INTO groups (name, description, is_active, domain_name)
+        VALUES ('default', 'Default group', True, 'default')
+        ON CONFLICT (name, domain_name) DO NOTHING
+        RETURNING id;
+    ''')
+    result = connection.execute(query).first()
+    gid = result.id if hasattr(result, 'id') else None
+    if gid is None:  # group already exists
+        query = textwrap.dedent('''\
+            SELECT id FROM groups where name = 'default' and domain_name = 'default';
+        ''')
+        gid = connection.execute(query).first().id
+
+    # Fill in kernels' domain_name, group_id, and user_uuid.
+    query = sa.select([kernels.c.id, kernels.c.access_key]).select_from(kernels)
+    all_kernels = connection.execute(query).fetchall()
+    for kernel in all_kernels:
+        # Get the kernel's keypair (access_key).
+        query = (sa.select([keypairs.c.user]).select_from(keypairs)
+                 .where(keypairs.c.access_key == kernel['access_key']))
+        kp = connection.execute(query).first()
+        # Update kernel information.
+        query = '''\
+            UPDATE kernels SET domain_name = 'default', group_id = '%s', user_uuid = '%s'
+            WHERE id = '%s';
+        ''' % (gid, kp.user, kernel['id'])
+        connection.execute(query)
+
+    # Associate every user with the default group.
+    # NOTE: this operation is not undoable unless you drop the groups table.
+    query = sa.select([users.c.uuid]).select_from(users)
+    all_users = connection.execute(query).fetchall()
+    for user in all_users:
+        query = '''\
+            INSERT INTO association_groups_users (group_id, user_id)
+            VALUES ('%s', '%s')
+            ON CONFLICT (group_id, user_id) DO NOTHING;
+        ''' % (gid, user.uuid,)
+        connection.execute(query)
+
+    # Make the kernels' new fields non-nullable.
+    op.alter_column('kernels', column_name='domain_name', nullable=False)
+    op.alter_column('kernels', column_name='group_id', nullable=False)
+    op.alter_column('kernels', column_name='user_uuid', nullable=False)
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_constraint(op.f('fk_kernels_domain_name_domains'), 'kernels', type_='foreignkey')
+    op.drop_constraint(op.f('fk_kernels_user_uuid_users'), 'kernels', type_='foreignkey')
+    op.drop_constraint(op.f('fk_kernels_group_id_groups'), 'kernels', type_='foreignkey')
+    op.drop_column('kernels', 'user_uuid')
+    op.drop_column('kernels', 'group_id')
+    op.drop_column('kernels', 'domain_name')
+    # ### end Alembic commands ###
diff --git a/src/ai/backend/manager/models/alembic/versions/ce209920f654_create_task_template_table.py b/src/ai/backend/manager/models/alembic/versions/ce209920f654_create_task_template_table.py
new file mode 100644
index 0000000000..4d5321a265
--- /dev/null
+++ b/src/ai/backend/manager/models/alembic/versions/ce209920f654_create_task_template_table.py
@@ -0,0 +1,49 @@
+"""Create task_template table
+
+Revision ID: ce209920f654
+Revises: 5e88398bc340
+Create Date: 2019-12-16 13:39:13.210996
+
+"""
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql as pgsql
+from ai.backend.manager.models.base import GUID, IDColumn, ForeignKeyIDColumn
+
+
+# revision identifiers, used by Alembic.
+revision = 'ce209920f654' +down_revision = '65c4a109bbc7' +branch_labels = None +depends_on = None + + +def upgrade(): + op.create_table( + 'session_templates', + IDColumn('id'), + sa.Column('created_at', sa.DateTime(timezone=True), + server_default=sa.func.now(), index=True), + sa.Column('is_active', sa.Boolean, default=True), + sa.Column('type', + sa.Enum('TASK', 'CLUSTER', name='templatetypes'), + nullable=False, + server_default='TASK' + ), + sa.Column('domain_name', sa.String(length=64), sa.ForeignKey('domains.name'), nullable=False), + sa.Column('group_id', GUID, sa.ForeignKey('groups.id'), nullable=True), + sa.Column('user_uuid', GUID, sa.ForeignKey('users.uuid'), nullable=False), + + sa.Column('name', sa.String(length=128), nullable=True), + sa.Column('template', sa.String(length=16 * 1024), nullable=False) + ) + op.add_column( + 'kernels', + sa.Column('bootstrap_script', sa.String(length=4 * 1024), nullable=True) + ) + + +def downgrade(): + op.drop_table('session_templates') + op.execute('DROP TYPE templatetypes') + op.drop_column('kernels', 'bootstrap_script') diff --git a/src/ai/backend/manager/models/alembic/versions/d2aafa234374_create_error_logs_table.py b/src/ai/backend/manager/models/alembic/versions/d2aafa234374_create_error_logs_table.py new file mode 100644 index 0000000000..54ebdb8339 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/d2aafa234374_create_error_logs_table.py @@ -0,0 +1,44 @@ +"""Create error_logs table + +Revision ID: d2aafa234374 +Revises: 3bb80d1887d6 +Create Date: 2020-02-12 13:55:12.450743 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from ai.backend.manager.models.base import IDColumn, GUID + +# revision identifiers, used by Alembic. +revision = 'd2aafa234374' +down_revision = '3bb80d1887d6' +branch_labels = None +depends_on = None + + +def upgrade(): + op.create_table( + 'error_logs', + IDColumn(), + sa.Column('created_at', sa.DateTime(timezone=True), + server_default=sa.func.now(), index=True), + sa.Column('severity', sa.Enum('critical', 'error', 'warning', 'info', 'debug', name='errorlog_severity'), + index=True), + sa.Column('source', sa.String), + sa.Column('user', GUID, sa.ForeignKey('users.uuid'), nullable=True, index=True), + sa.Column('is_read', sa.Boolean, default=False, index=True), + sa.Column('is_cleared', sa.Boolean, default=False, index=True), + sa.Column('message', sa.Text), + sa.Column('context_lang', sa.String), + sa.Column('context_env', postgresql.JSONB()), + sa.Column('request_url', sa.String, nullable=True), + sa.Column('request_status', sa.Integer, nullable=True), + sa.Column('traceback', sa.Text, nullable=True), + ) + + +def downgrade(): + op.drop_table('error_logs') + op.execute('DROP TYPE errorlog_severity') diff --git a/src/ai/backend/manager/models/alembic/versions/d452bacd085c_add_mount_map_column_to_kernel.py b/src/ai/backend/manager/models/alembic/versions/d452bacd085c_add_mount_map_column_to_kernel.py new file mode 100644 index 0000000000..caac9cc902 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/d452bacd085c_add_mount_map_column_to_kernel.py @@ -0,0 +1,26 @@ +"""Add mount_map column to kernel + +Revision ID: d452bacd085c +Revises: 4b7b650bc30e +Create Date: 2019-11-19 14:43:12.728678 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql as pgsql + + +# revision identifiers, used by Alembic. 
+revision = 'd452bacd085c' +down_revision = '4b7b650bc30e' +branch_labels = None +depends_on = None + + +def upgrade(): + op.add_column('kernels', sa.Column('mount_map', pgsql.JSONB(), nullable=True, default={})) + + +def downgrade(): + op.drop_column('kernels', 'mount_map') + diff --git a/src/ai/backend/manager/models/alembic/versions/d463fc5d6109_add_clone_allowed_to_vfolders.py b/src/ai/backend/manager/models/alembic/versions/d463fc5d6109_add_clone_allowed_to_vfolders.py new file mode 100644 index 0000000000..9025bb2f2b --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/d463fc5d6109_add_clone_allowed_to_vfolders.py @@ -0,0 +1,30 @@ +"""add_clone_allowed_to_vfolders + +Revision ID: d463fc5d6109 +Revises: 0d553d59f369 +Create Date: 2020-09-01 16:57:55.339619 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'd463fc5d6109' +down_revision = '0d553d59f369' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('vfolders', sa.Column('clone_allowed', sa.Boolean(), nullable=True)) + op.execute("UPDATE vfolders SET clone_allowed = false") + op.alter_column('vfolders', 'clone_allowed', nullable=False) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('vfolders', 'clone_allowed') + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/d52bf5ec9ef3_convert_cpu_gpu_slots_to_float.py b/src/ai/backend/manager/models/alembic/versions/d52bf5ec9ef3_convert_cpu_gpu_slots_to_float.py new file mode 100644 index 0000000000..483f05f7b4 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/d52bf5ec9ef3_convert_cpu_gpu_slots_to_float.py @@ -0,0 +1,75 @@ +"""convert_cpu_gpu_slots_to_float + +Revision ID: d52bf5ec9ef3 +Revises: 4545f5c948b3 +Create Date: 2017-11-09 14:30:20.737908 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. 
+revision = 'd52bf5ec9ef3' +down_revision = '4545f5c948b3' +branch_labels = None +depends_on = None + + +def upgrade(): + op.alter_column('agents', 'mem_slots', + existing_type=sa.Integer(), + type_=sa.BigInteger()) + op.alter_column('agents', 'cpu_slots', + existing_type=sa.Integer(), + type_=sa.Float()) + op.alter_column('agents', 'gpu_slots', + existing_type=sa.Integer(), + type_=sa.Float()) + op.alter_column('agents', 'used_mem_slots', + existing_type=sa.Integer(), + type_=sa.BigInteger()) + op.alter_column('agents', 'used_cpu_slots', + existing_type=sa.Integer(), + type_=sa.Float()) + op.alter_column('agents', 'used_gpu_slots', + existing_type=sa.Integer(), + type_=sa.Float()) + op.alter_column('kernels', 'mem_slot', + existing_type=sa.Integer(), + type_=sa.BigInteger()) + op.alter_column('kernels', 'cpu_slot', + existing_type=sa.Integer(), + type_=sa.Float()) + op.alter_column('kernels', 'gpu_slot', + existing_type=sa.Integer(), + type_=sa.Float()) + + +def downgrade(): + op.alter_column('agents', 'mem_slots', + existing_type=sa.BigInteger(), + type_=sa.Integer()) + op.alter_column('agents', 'cpu_slots', + existing_type=sa.Float(), + type_=sa.Integer()) + op.alter_column('agents', 'gpu_slots', + existing_type=sa.Float(), + type_=sa.Integer()) + op.alter_column('agents', 'used_mem_slots', + existing_type=sa.BigInteger(), + type_=sa.Integer()) + op.alter_column('agents', 'used_cpu_slots', + existing_type=sa.Float(), + type_=sa.Integer()) + op.alter_column('agents', 'used_gpu_slots', + existing_type=sa.Float(), + type_=sa.Integer()) + op.alter_column('kernels', 'mem_slot', + existing_type=sa.BigInteger(), + type_=sa.Integer()) + op.alter_column('kernels', 'cpu_slot', + existing_type=sa.Float(), + type_=sa.Integer()) + op.alter_column('kernels', 'gpu_slot', + existing_type=sa.Float(), + type_=sa.Integer()) diff --git a/src/ai/backend/manager/models/alembic/versions/d582942886ad_add_tag_to_kernels.py b/src/ai/backend/manager/models/alembic/versions/d582942886ad_add_tag_to_kernels.py new file mode 100644 index 0000000000..21d48e3995 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/d582942886ad_add_tag_to_kernels.py @@ -0,0 +1,28 @@ +"""add tag to kernels + +Revision ID: d582942886ad +Revises: a1fd4e7b7782 +Create Date: 2018-10-25 10:51:39.448309 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'd582942886ad' +down_revision = 'a1fd4e7b7782' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('kernels', sa.Column('tag', sa.String(length=64), nullable=True)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('kernels', 'tag') + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/d59ff89e7514_remove_keypair_concurrency_used.py b/src/ai/backend/manager/models/alembic/versions/d59ff89e7514_remove_keypair_concurrency_used.py new file mode 100644 index 0000000000..5f11da6786 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/d59ff89e7514_remove_keypair_concurrency_used.py @@ -0,0 +1,26 @@ +"""remove_keypair_concurrency_used + +Revision ID: d59ff89e7514 +Revises: 0f7a4b643940 +Create Date: 2022-03-21 16:43:29.899251 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision = 'd59ff89e7514'
+down_revision = '0f7a4b643940'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    op.drop_column('keypairs', 'concurrency_used')
+
+
+def downgrade():
+    op.add_column('keypairs', sa.Column(
+        'concurrency_used', sa.Integer, nullable=True, default=0, server_default=0,
+    ))
diff --git a/src/ai/backend/manager/models/alembic/versions/d5cc54fd36b5_update_for_multicontainer_sessions.py b/src/ai/backend/manager/models/alembic/versions/d5cc54fd36b5_update_for_multicontainer_sessions.py
new file mode 100644
index 0000000000..7cf831bcb6
--- /dev/null
+++ b/src/ai/backend/manager/models/alembic/versions/d5cc54fd36b5_update_for_multicontainer_sessions.py
@@ -0,0 +1,104 @@
+"""Update for multi-container sessions.
+
+Revision ID: d5cc54fd36b5
+Revises: 0d553d59f369
+Create Date: 2020-01-06 13:56:50.885635
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+from ai.backend.manager.models.base import GUID
+
+# revision identifiers, used by Alembic.
+revision = 'd5cc54fd36b5'
+down_revision = '0d553d59f369'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # In this migration, we finally clean up the column naming:
+    #   sess_id -> session_name
+    #     => client-provided alias
+    #   (new) -> session_id
+    #     => for single-container sessions, it may be derived from the kernel id.
+    #   sess_type -> session_type
+    #
+    op.drop_index('ix_kernels_sess_id', table_name='kernels')
+    op.drop_index('ix_kernels_sess_type', table_name='kernels')
+
+    conn = op.get_bind()
+    op.add_column(
+        'kernels',
+        sa.Column('idx', sa.Integer, nullable=True, default=None))
+    op.add_column(
+        'kernels',
+        sa.Column('cluster_mode', sa.String(16), nullable=False,
+                  default='single-node', server_default='single-node'))
+
+    # Set idx to 1 (previous sessions are all composed of one kernel).
+    query = "UPDATE kernels SET idx = 1;"
+    conn.execute(query)
+
+    # Convert "master" to "main".
+    # NOTE: "main" is defined from ai.backend.manager.defs.DEFAULT_ROLE
+    op.alter_column('kernels', 'role', server_default='main')
+    query = "UPDATE kernels SET role = 'main' WHERE role = 'master'"
+    conn.execute(query)
+
+    # First add a session_id column as nullable and fill it up before setting it non-nullable.
+    op.add_column('kernels', sa.Column('session_id', GUID, nullable=True))
+    query = "UPDATE kernels SET session_id = kernels.id WHERE role = 'main'"
+    conn.execute(query)
+    # If we upgrade from a database downgraded in the past with sub-kernel records,
+    # we lose the information of the kernel_id -> session_id mapping.
+    # Try to restore it by getting the session ID of a main-kernel record which was created
+    # within a similar time range. This will raise an error if there are two or more such records,
+    # and it is based on the assumption that development setups with manual tests would not make
+    # such overlaps.
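+    # For example, a sub-kernel created at 12:00:01.200 will match a main kernel
+    # created between 12:00:00.700 and 12:00:04.200, i.e., within the [-0.5s, +3s]
+    # window used in the query below.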
+ query = """ + UPDATE kernels t SET session_id = ( + SELECT session_id + FROM kernels s + WHERE + s.role = 'main' + AND ( + s.created_at BETWEEN + t.created_at - (interval '0.5s') + AND t.created_at + (interval '3s') + ) + ) + WHERE t.role <> 'main' + """ + conn.execute(query) + op.alter_column('kernels', 'session_id', nullable=False) + + op.alter_column('kernels', 'sess_id', new_column_name='session_name') + op.alter_column('kernels', 'sess_type', new_column_name='session_type') + + op.create_index(op.f('ix_kernels_session_id'), 'kernels', ['session_id'], unique=False) + op.create_index(op.f('ix_kernels_session_name'), 'kernels', ['session_name'], unique=False) + op.create_index(op.f('ix_kernels_session_type'), 'kernels', ['session_type'], unique=False) + + +def downgrade(): + op.drop_index(op.f('ix_kernels_session_type'), table_name='kernels') + op.drop_index(op.f('ix_kernels_session_name'), table_name='kernels') + op.drop_index(op.f('ix_kernels_session_id'), table_name='kernels') + + op.alter_column('kernels', 'session_type', new_column_name='sess_type') + op.alter_column('kernels', 'session_name', new_column_name='sess_id') + op.drop_column('kernels', 'session_id') + + # Convert "main" to "master" for backward compatibility + op.alter_column('kernels', 'role', server_default='master') + conn = op.get_bind() + query = "UPDATE kernels SET role = 'master' WHERE role = 'main'" + conn.execute(query) + + op.drop_column('kernels', 'cluster_mode') + op.drop_column('kernels', 'idx') + + op.create_index('ix_kernels_sess_type', 'kernels', ['sess_type'], unique=False) + op.create_index('ix_kernels_sess_id', 'kernels', ['sess_id'], unique=False) diff --git a/src/ai/backend/manager/models/alembic/versions/d643752544de_.py b/src/ai/backend/manager/models/alembic/versions/d643752544de_.py new file mode 100644 index 0000000000..313ac08ac8 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/d643752544de_.py @@ -0,0 +1,24 @@ +"""Merge 51dd and d2aa + +Revision ID: d643752544de +Revises: 51dddd79aa21, d2aafa234374 +Create Date: 2020-03-09 12:04:27.013567 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'd643752544de' +down_revision = ('51dddd79aa21', 'd2aafa234374') +branch_labels = None +depends_on = None + + +def upgrade(): + pass + + +def downgrade(): + pass diff --git a/src/ai/backend/manager/models/alembic/versions/d727b5da20e6_add_callback_url_to_kernels.py b/src/ai/backend/manager/models/alembic/versions/d727b5da20e6_add_callback_url_to_kernels.py new file mode 100644 index 0000000000..e187ba573d --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/d727b5da20e6_add_callback_url_to_kernels.py @@ -0,0 +1,26 @@ +"""add-callback_url-to-kernels + +Revision ID: d727b5da20e6 +Revises: a7ca9f175d5f +Create Date: 2022-03-31 07:22:28.426046 + +""" +from alembic import op +import sqlalchemy as sa + +from ai.backend.manager.models.base import URLColumn + + +# revision identifiers, used by Alembic. 
+revision = 'd727b5da20e6' +down_revision = 'a7ca9f175d5f' +branch_labels = None +depends_on = None + + +def upgrade(): + op.add_column('kernels', sa.Column('callback_url', URLColumn(), nullable=True)) + + +def downgrade(): + op.drop_column('kernels', 'callback_url') diff --git a/src/ai/backend/manager/models/alembic/versions/da24ff520049_add_starts_at_field_into_kernels.py b/src/ai/backend/manager/models/alembic/versions/da24ff520049_add_starts_at_field_into_kernels.py new file mode 100644 index 0000000000..198ae3cd83 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/da24ff520049_add_starts_at_field_into_kernels.py @@ -0,0 +1,28 @@ +"""add_startsat_field_into_kernels + +Revision ID: da24ff520049 +Revises: 529113b08c2c +Create Date: 2020-06-18 20:47:22.152831 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'da24ff520049' +down_revision = '529113b08c2c' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('kernels', sa.Column('starts_at', sa.DateTime(timezone=True), nullable=True)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('kernels', 'starts_at') + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/dbc1e053b880_add_keypair_resource_policy.py b/src/ai/backend/manager/models/alembic/versions/dbc1e053b880_add_keypair_resource_policy.py new file mode 100644 index 0000000000..c07e63b8eb --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/dbc1e053b880_add_keypair_resource_policy.py @@ -0,0 +1,97 @@ +"""add-keypair-resource-policy + +Revision ID: dbc1e053b880 +Revises: 2b0931e4a059 +Create Date: 2019-02-07 15:30:54.861821 + +""" +from alembic import op +import sqlalchemy as sa +from ai.backend.common.types import DefaultForUnspecified +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'dbc1e053b880' +down_revision = '2b0931e4a059' +branch_labels = None +depends_on = None + + +default_for_unspecified_choices = list(map(lambda v: v.name, DefaultForUnspecified)) +default_for_unspecified = postgresql.ENUM( + *default_for_unspecified_choices, + name='default_for_unspecified', +) + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table( + 'keypair_resource_policies', + sa.Column('name', sa.String(length=256), nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), + server_default=sa.text('now()'), + nullable=True), + sa.Column('default_for_unspecified', + type_=sa.Enum(*default_for_unspecified_choices, + name='default_for_unspecified'), + nullable=False), + sa.Column('total_resource_slots', + postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.Column('max_concurrent_sessions', sa.Integer(), nullable=False), + sa.Column('max_containers_per_session', sa.Integer(), nullable=False), + sa.Column('max_vfolder_count', sa.Integer(), nullable=False), + sa.Column('max_vfolder_size', sa.BigInteger(), nullable=False), + sa.Column('allowed_vfolder_hosts', + postgresql.ARRAY(sa.String()), nullable=False), + sa.PrimaryKeyConstraint('name', name=op.f('pk_keypair_resource_policies')) + ) + # Create a default resource policy + query = ''' + INSERT INTO keypair_resource_policies + VALUES ( + 'default', + now(), + 'UNLIMITED', + '{}'::jsonb, + 30, + 1, + 10, + 0, + array['local'] + ); + UPDATE keypairs + SET resource_policy = 'default'; + ''' + connection = op.get_bind() + connection.execute(query) + print('\n!!! NOTICE !!!\n') + print('Created a default resource policy and linked all keypairs to it.') + print('Please inspect and adjust it!\n') + op.alter_column('keypairs', 'resource_policy', + existing_type=sa.VARCHAR(), + type_=sa.String(length=256), + nullable=False) + op.create_foreign_key( + op.f('fk_keypairs_resource_policy_keypair_resource_policies'), + 'keypairs', 'keypair_resource_policies', + ['resource_policy'], ['name']) + op.drop_column('keypairs', 'concurrency_limit') + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint(op.f('fk_keypairs_resource_policy_keypair_resource_policies'), + 'keypairs', type_='foreignkey') + op.add_column( + 'keypairs', + sa.Column('concurrency_limit', sa.INTEGER(), + autoincrement=False, nullable=True)) + op.alter_column('keypairs', 'resource_policy', + existing_type=sa.String(length=256), + type_=sa.VARCHAR(), + nullable=True) + op.drop_table('keypair_resource_policies') + default_for_unspecified.drop(op.get_bind()) + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/dc9b66466e43_remove_clusterized.py b/src/ai/backend/manager/models/alembic/versions/dc9b66466e43_remove_clusterized.py new file mode 100644 index 0000000000..f12b5cdcd1 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/dc9b66466e43_remove_clusterized.py @@ -0,0 +1,24 @@ +"""remove-clusterized + +Revision ID: dc9b66466e43 +Revises: 06184d82a211 +Create Date: 2020-12-25 04:45:20.245137 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision = 'dc9b66466e43'
+down_revision = '06184d82a211'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    op.drop_column('agents', 'clusterized')
+
+
+def downgrade():
+    op.add_column('agents', sa.Column('clusterized', sa.BOOLEAN(), autoincrement=False, nullable=True))
diff --git a/src/ai/backend/manager/models/alembic/versions/e18ed5fcfedf_add_superadmin_role_for_user.py b/src/ai/backend/manager/models/alembic/versions/e18ed5fcfedf_add_superadmin_role_for_user.py
new file mode 100644
index 0000000000..1dde76fb2b
--- /dev/null
+++ b/src/ai/backend/manager/models/alembic/versions/e18ed5fcfedf_add_superadmin_role_for_user.py
@@ -0,0 +1,78 @@
+"""add superadmin role for user
+
+Revision ID: e18ed5fcfedf
+Revises: c5e4e764f9e3
+Create Date: 2019-05-29 23:17:17.762968
+
+"""
+import textwrap
+
+from alembic import op
+from ai.backend.manager.models import UserRole
+
+
+# revision identifiers, used by Alembic.
+revision = 'e18ed5fcfedf'
+down_revision = 'c5e4e764f9e3'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    pass
+    # ### end Alembic commands ###
+
+    # Add superadmin to the user role choices.
+    userrole_choices = list(map(lambda v: v.value, UserRole))
+    assert 'superadmin' in userrole_choices, 'superadmin in UserRole is required!'
+
+    conn = op.get_bind()
+    conn.execute('ALTER TYPE userrole RENAME TO userrole__;')
+    conn.execute('CREATE TYPE userrole as enum (%s)' % ("'" + "','".join(userrole_choices) + "'"))
+    conn.execute(textwrap.dedent('''\
+        ALTER TABLE users
+        ALTER COLUMN role TYPE userrole USING role::text::userrole;
+    '''))
+    conn.execute('DROP TYPE userrole__;')
+
+    # Set admin@lablup.com's role as superadmin.
+    # Also, set admin@lablup.com's domain to default.
+    #
+    # We originally treated superadmin as an admin user not associated with any domain.
+    # This resulted in broken code execution for superadmin since domain_name should not be null.
+    # So, this policy is changed to simply adopt the superadmin role, and a superadmin can
+    # also have a domain and groups as well.
+    query = "SELECT uuid FROM users where email = 'admin@lablup.com';"
+    result = conn.execute(query).first()
+    uuid = result.uuid if hasattr(result, 'uuid') else None
+    if uuid is not None:  # update only when the admin@lablup.com user exists
+        query = textwrap.dedent('''\
+            UPDATE users SET domain_name = 'default', role = 'superadmin'
+            WHERE email = 'admin@lablup.com';
+        ''')
+        conn.execute(query)
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    pass
+    # ### end Alembic commands ###
+
+    userrole_choices = list(map(lambda v: v.value, UserRole))
+    if 'superadmin' in userrole_choices:
+        userrole_choices.remove('superadmin')
+    conn = op.get_bind()
+
+    # First, change all superadmin roles to admin.
+    query = textwrap.dedent("UPDATE users SET role = 'admin' WHERE role = 'superadmin';")
+    conn.execute(query)
+
+    # Remove superadmin from the user role choices.
+ conn.execute('ALTER TYPE userrole RENAME TO userrole___;') + conn.execute('CREATE TYPE userrole as enum (%s)' % ("'" + "','".join(userrole_choices) + "'")) + conn.execute(textwrap.dedent('''\ + ALTER TABLE users + ALTER COLUMN role TYPE userrole USING role::text::userrole; + ''')) + conn.execute('DROP TYPE userrole___;') diff --git a/src/ai/backend/manager/models/alembic/versions/e35332f8d23d_add_modified_at_to_users_and_kernels.py b/src/ai/backend/manager/models/alembic/versions/e35332f8d23d_add_modified_at_to_users_and_kernels.py new file mode 100644 index 0000000000..18988eeb62 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/e35332f8d23d_add_modified_at_to_users_and_kernels.py @@ -0,0 +1,78 @@ +"""add_modified_at_to_users_and_kernels + +Revision ID: e35332f8d23d +Revises: da24ff520049 +Create Date: 2020-07-01 14:02:11.022032 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql +from sqlalchemy.sql.expression import bindparam + +from ai.backend.manager.models.base import convention, IDColumn + +# revision identifiers, used by Alembic. +revision = 'e35332f8d23d' +down_revision = 'da24ff520049' +branch_labels = None +depends_on = None + + +def upgrade(): + metadata = sa.MetaData(naming_convention=convention) + # partial table to be preserved and referred + users = sa.Table( + 'users', metadata, + IDColumn('uuid'), + sa.Column('created_at', sa.DateTime(timezone=True), + server_default=sa.func.now()), + sa.Column('modified_at', sa.DateTime(timezone=True), + server_default=sa.func.now(), onupdate=sa.func.current_timestamp()), + ) + keypairs = sa.Table( + 'keypairs', metadata, + sa.Column('access_key', sa.String(length=20), primary_key=True), + sa.Column('created_at', sa.DateTime(timezone=True), + server_default=sa.func.now()), + sa.Column('modified_at', sa.DateTime(timezone=True), + server_default=sa.func.now(), onupdate=sa.func.current_timestamp()), + ) + + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('keypairs', sa.Column('modified_at', sa.DateTime(timezone=True), + server_default=sa.text('now()'), nullable=True)) + op.add_column('users', sa.Column('modified_at', sa.DateTime(timezone=True), + server_default=sa.text('now()'), nullable=True)) + # ### end Alembic commands ### + + conn = op.get_bind() + + # Set user's modified_at with the value of created_at. + query = sa.select([users.c.uuid, users.c.created_at]).select_from(users) + updates = [] + for row in conn.execute(query).fetchall(): + updates.append({'b_uuid': row['uuid'], 'modified_at': row['created_at']}) + if updates: + query = (sa.update(users) + .values(modified_at=bindparam('modified_at')) + .where(users.c.uuid == bindparam('b_uuid'))) + conn.execute(query, updates) + + # Set keypairs's modified_at with the value of created_at. + query = sa.select([keypairs.c.access_key, keypairs.c.created_at]).select_from(keypairs) + updates = [] + for row in conn.execute(query).fetchall(): + updates.append({'b_access_key': row['access_key'], 'modified_at': row['created_at']}) + if updates: + query = (sa.update(keypairs) + .values(modified_at=bindparam('modified_at')) + .where(keypairs.c.access_key == bindparam('b_access_key'))) + conn.execute(query, updates) + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
###
+    op.drop_column('users', 'modified_at')
+    op.drop_column('keypairs', 'modified_at')
+    # ### end Alembic commands ###
diff --git a/src/ai/backend/manager/models/alembic/versions/e421c02cf9e4_rename_kernel_dependencies_to_session_.py b/src/ai/backend/manager/models/alembic/versions/e421c02cf9e4_rename_kernel_dependencies_to_session_.py
new file mode 100644
index 0000000000..3404eeb7e6
--- /dev/null
+++ b/src/ai/backend/manager/models/alembic/versions/e421c02cf9e4_rename_kernel_dependencies_to_session_.py
@@ -0,0 +1,51 @@
+"""rename_kernel_dependencies_to_session_dependencies
+
+Revision ID: e421c02cf9e4
+Revises: 548cc8aa49c8
+Create Date: 2020-09-14 10:45:40.218548
+
+"""
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = 'e421c02cf9e4'
+down_revision = '548cc8aa49c8'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    op.drop_constraint('fk_kernel_dependencies_depends_on_kernels', 'kernel_dependencies')
+    op.drop_constraint('fk_kernel_dependencies_kernel_id_kernels', 'kernel_dependencies')
+    op.rename_table('kernel_dependencies', 'session_dependencies')
+    op.alter_column('session_dependencies', 'kernel_id', new_column_name='session_id')
+    op.execute('ALTER INDEX pk_kernel_dependencies '
+               'RENAME TO pk_session_dependencies')
+    op.execute('ALTER INDEX ix_kernel_dependencies_depends_on '
+               'RENAME TO ix_session_dependencies_depends_on')
+    op.execute('ALTER INDEX ix_kernel_dependencies_kernel_id '
+               'RENAME TO ix_session_dependencies_session_id')
+    # NOTE: we keep the fkey target as "kernels.id" instead of "kernels.session_id"
+    # because the fkey target must be a unique index, and in Backend.AI `kernels.session_id`
+    # is the same as the main kernel's `kernels.id`.
+    op.create_foreign_key(None, 'session_dependencies', 'kernels', ['session_id'], ['id'],
+                          onupdate='CASCADE', ondelete='CASCADE')
+    op.create_foreign_key(None, 'session_dependencies', 'kernels', ['depends_on'], ['id'],
+                          onupdate='CASCADE', ondelete='CASCADE')
+
+
+def downgrade():
+    op.drop_constraint('fk_session_dependencies_depends_on_kernels', 'session_dependencies')
+    op.drop_constraint('fk_session_dependencies_session_id_kernels', 'session_dependencies')
+    op.rename_table('session_dependencies', 'kernel_dependencies')
+    op.alter_column('kernel_dependencies', 'session_id', new_column_name='kernel_id')
+    op.execute('ALTER INDEX pk_session_dependencies '
+               'RENAME TO pk_kernel_dependencies')
+    op.execute('ALTER INDEX ix_session_dependencies_depends_on '
+               'RENAME TO ix_kernel_dependencies_depends_on')
+    op.execute('ALTER INDEX ix_session_dependencies_session_id '
+               'RENAME TO ix_kernel_dependencies_kernel_id')
+    op.create_foreign_key(None, 'kernel_dependencies', 'kernels', ['kernel_id'], ['id'],
+                          onupdate='CASCADE', ondelete='CASCADE')
+    op.create_foreign_key(None, 'kernel_dependencies', 'kernels', ['depends_on'], ['id'],
+                          onupdate='CASCADE', ondelete='CASCADE')
diff --git a/src/ai/backend/manager/models/alembic/versions/e7371ca5797a_rename_mem_stats.py b/src/ai/backend/manager/models/alembic/versions/e7371ca5797a_rename_mem_stats.py
new file mode 100644
index 0000000000..aed8e29964
--- /dev/null
+++ b/src/ai/backend/manager/models/alembic/versions/e7371ca5797a_rename_mem_stats.py
@@ -0,0 +1,24 @@
+"""rename_mem_stats
+
+Revision ID: e7371ca5797a
+Revises: 93e9d31d40bf
+Create Date: 2017-10-10 13:01:37.169568
+
+"""
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = 'e7371ca5797a' +down_revision = '93e9d31d40bf' +branch_labels = None +depends_on = None + + +def upgrade(): + op.alter_column('kernels', column_name='max_mem_bytes', new_column_name='mem_max_bytes') + op.alter_column('kernels', column_name='cur_mem_bytes', new_column_name='mem_cur_bytes') + + +def downgrade(): + op.alter_column('kernels', column_name='mem_max_bytes', new_column_name='max_mem_bytes') + op.alter_column('kernels', column_name='mem_cur_bytes', new_column_name='cur_mem_bytes') diff --git a/src/ai/backend/manager/models/alembic/versions/ed666f476f39_add_bootstrap_script_to_keypairs.py b/src/ai/backend/manager/models/alembic/versions/ed666f476f39_add_bootstrap_script_to_keypairs.py new file mode 100644 index 0000000000..3bdc70b7dd --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/ed666f476f39_add_bootstrap_script_to_keypairs.py @@ -0,0 +1,31 @@ +"""add_bootstrap_script_to_keypairs + +Revision ID: ed666f476f39 +Revises: d643752544de +Create Date: 2020-03-15 17:40:46.754121 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'ed666f476f39' +down_revision = 'd643752544de' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('keypairs', sa.Column('bootstrap_script', + sa.String(length=64 * 1024), + nullable=False, + server_default='')) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('keypairs', 'bootstrap_script') + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/eec98e65902a_merge_with_vfolder_clone.py b/src/ai/backend/manager/models/alembic/versions/eec98e65902a_merge_with_vfolder_clone.py new file mode 100644 index 0000000000..fca173b6ad --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/eec98e65902a_merge_with_vfolder_clone.py @@ -0,0 +1,24 @@ +"""merge-with-vfolder-clone + +Revision ID: eec98e65902a +Revises: d463fc5d6109, 97f6c80c8aa5 +Create Date: 2020-10-03 18:11:06.270486 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'eec98e65902a' +down_revision = ('d463fc5d6109', '97f6c80c8aa5') +branch_labels = None +depends_on = None + + +def upgrade(): + pass + + +def downgrade(): + pass diff --git a/src/ai/backend/manager/models/alembic/versions/f0f4ee907155_dynamic_resource_slots.py b/src/ai/backend/manager/models/alembic/versions/f0f4ee907155_dynamic_resource_slots.py new file mode 100644 index 0000000000..9fdecd0c12 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/f0f4ee907155_dynamic_resource_slots.py @@ -0,0 +1,210 @@ +"""dynamic-resource-slots + +Revision ID: f0f4ee907155 +Revises: ff4bfca66bf8 +Create Date: 2019-01-27 17:05:13.997279 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'f0f4ee907155' +down_revision = 'ff4bfca66bf8' +branch_labels = None +depends_on = None + + +def upgrade(): + + # ### commands auto generated by Alembic - please adjust! 
### + connection = op.get_bind() + op.alter_column('kernels', 'service_ports', + existing_type=sa.JSON(), + type_=postgresql.JSONB(), + postgresql_using='CAST(service_ports AS jsonb)') + op.add_column('agents', sa.Column('available_slots', postgresql.JSONB(), + nullable=False, + server_default=sa.text("'{}'::jsonb"))) + op.add_column('agents', sa.Column('occupied_slots', postgresql.JSONB(), + nullable=False, + server_default=sa.text("'{}'::jsonb"))) + query = ''' + UPDATE agents SET available_slots = json_strip_nulls(json_build_object( + 'cpu', cpu_slots, + 'mem', mem_slots::text || 'g' + )); + UPDATE agents SET available_slots = available_slots || json_build_object( + 'cuda.device', gpu_slots + )::jsonb + WHERE gpu_slots > 0; + UPDATE agents SET available_slots = available_slots || json_build_object( + 'tpu.device', tpu_slots + )::jsonb + WHERE tpu_slots > 0; + + UPDATE agents SET occupied_slots = json_strip_nulls(json_build_object( + 'cpu', used_cpu_slots, + 'mem', used_mem_slots::text || 'g' + )); + UPDATE agents SET occupied_slots = occupied_slots || json_build_object( + 'cuda.device', used_gpu_slots + )::jsonb + WHERE used_gpu_slots > 0; + UPDATE agents SET occupied_slots = occupied_slots || json_build_object( + 'tpu.device', used_tpu_slots + )::jsonb + WHERE used_tpu_slots > 0; + ''' + connection.execute(query) + op.drop_column('agents', 'cpu_slots') + op.drop_column('agents', 'mem_slots') + op.drop_column('agents', 'gpu_slots') + op.drop_column('agents', 'tpu_slots') + op.drop_column('agents', 'used_cpu_slots') + op.drop_column('agents', 'used_mem_slots') + op.drop_column('agents', 'used_gpu_slots') + op.drop_column('agents', 'used_tpu_slots') + + op.add_column('kernels', sa.Column('occupied_slots', postgresql.JSONB(), + nullable=False, + server_default=sa.text("'{}'::jsonb"))) + op.add_column('kernels', sa.Column('occupied_shares', postgresql.JSONB(), + nullable=False, + server_default=sa.text("'{}'::jsonb"))) + query = ''' + UPDATE kernels SET occupied_slots = json_build_object( + 'cpu', cpu_slot, + 'mem', mem_slot, + 'cuda.device', gpu_slot, + 'tpu.device', tpu_slot + ); + UPDATE kernels SET occupied_shares = json_build_object( + 'cpu', cpu_set, + 'mem', mem_slot, + 'cuda.device', '{}'::json, + 'tpu.device', '{}'::json + ); + ''' + connection.execute(query) + op.drop_column('kernels', 'cpu_slot') + op.drop_column('kernels', 'mem_slot') + op.drop_column('kernels', 'gpu_slot') + op.drop_column('kernels', 'tpu_slot') + op.drop_column('kernels', 'cpu_set') + op.drop_column('kernels', 'gpu_set') + op.drop_column('kernels', 'tpu_set') + # ### end Alembic commands ### + + +def downgrade(): + op.alter_column('kernels', 'service_ports', + existing_type=postgresql.JSONB(), + type_=sa.JSON(), + postgresql_using='CAST(service_ports AS json)') + connection = op.get_bind() + + # ### commands auto generated by Alembic - please adjust! 
### + op.add_column('kernels', sa.Column('cpu_set', postgresql.ARRAY(sa.INTEGER()), + autoincrement=False, nullable=True)) + op.add_column('kernels', sa.Column('gpu_set', postgresql.ARRAY(sa.INTEGER()), + autoincrement=False, nullable=True)) + op.add_column('kernels', sa.Column('tpu_set', postgresql.ARRAY(sa.INTEGER()), + autoincrement=False, nullable=True)) + op.add_column('kernels', sa.Column('cpu_slot', + postgresql.DOUBLE_PRECISION(precision=53), + autoincrement=False, nullable=False, + server_default='0')) + op.add_column('kernels', sa.Column('mem_slot', sa.BIGINT(), + autoincrement=False, nullable=False, + server_default='0')) + op.add_column('kernels', sa.Column('gpu_slot', + postgresql.DOUBLE_PRECISION(precision=53), + autoincrement=False, nullable=False, + server_default='0')) + op.add_column('kernels', sa.Column('tpu_slot', + postgresql.DOUBLE_PRECISION(precision=53), + autoincrement=False, nullable=False, + server_default='0')) + query = ''' + UPDATE kernels + SET + cpu_set = ( + SELECT coalesce(array_agg(v::text::int), '{}') + FROM json_array_elements((occupied_shares->>'cpu')::json) v), + gpu_set = ( + SELECT coalesce(array_agg(v::text::int), '{}') + FROM json_array_elements((occupied_shares->>'cuda.device')::json) v), + tpu_set = ( + SELECT coalesce(array_agg(v::text::int), '{}') + FROM json_array_elements((occupied_shares->>'tpu.device')::json) v) + ; + ''' + connection.execute(query) + query = ''' + UPDATE kernels + SET + cpu_slot = coalesce((occupied_slots->>'cpu')::text::float, 0), + mem_slot = coalesce((occupied_slots->>'mem')::text::bigint, 0), + gpu_slot = coalesce((occupied_slots->>'cuda.device')::text::float, 0), + tpu_slot = coalesce((occupied_slots->>'tpu.device')::text::float, 0) + ; + ''' + connection.execute(query) + op.drop_column('kernels', 'occupied_shares') + op.drop_column('kernels', 'occupied_slots') + + op.add_column('agents', sa.Column('gpu_slots', + postgresql.DOUBLE_PRECISION(precision=53), + autoincrement=False, nullable=False, + server_default='0')) + op.add_column('agents', sa.Column('used_cpu_slots', + postgresql.DOUBLE_PRECISION(precision=53), + autoincrement=False, nullable=False, + server_default='0')) + op.add_column('agents', sa.Column('tpu_slots', + postgresql.DOUBLE_PRECISION(precision=53), + autoincrement=False, nullable=False, + server_default='0')) + op.add_column('agents', sa.Column('used_mem_slots', sa.BIGINT(), + autoincrement=False, nullable=False, + server_default='0')) + op.add_column('agents', sa.Column('cpu_slots', + postgresql.DOUBLE_PRECISION(precision=53), + autoincrement=False, nullable=False, + server_default='0')) + op.add_column('agents', sa.Column('used_gpu_slots', + postgresql.DOUBLE_PRECISION(precision=53), + autoincrement=False, nullable=False, + server_default='0')) + op.add_column('agents', sa.Column('mem_slots', sa.BIGINT(), + autoincrement=False, nullable=False, + server_default='0')) + op.add_column('agents', sa.Column('used_tpu_slots', + postgresql.DOUBLE_PRECISION(precision=53), + autoincrement=False, nullable=False, + server_default='0')) + query = ''' + UPDATE agents + SET + cpu_slots = coalesce((available_slots->>'cpu')::text::float, 0), + mem_slots = coalesce((available_slots->>'mem')::text::bigint, 0), + gpu_slots = coalesce((available_slots->>'cuda.device')::text::float, 0), + tpu_slots = coalesce((available_slots->>'tpu.device')::text::float, 0) + ; + ''' + connection.execute(query) + query = ''' + UPDATE agents + SET + used_cpu_slots = coalesce((occupied_slots->>'cpu')::text::float, 0), + used_mem_slots = 
coalesce((occupied_slots->>'mem')::text::bigint, 0), + used_gpu_slots = coalesce((occupied_slots->>'cuda.device')::text::float, 0), + used_tpu_slots = coalesce((occupied_slots->>'tpu.device')::text::float, 0) + ; + ''' + connection.execute(query) + op.drop_column('agents', 'occupied_slots') + op.drop_column('agents', 'available_slots') + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/f5530eccf202_add_kernels_uuid_prefix_index.py b/src/ai/backend/manager/models/alembic/versions/f5530eccf202_add_kernels_uuid_prefix_index.py new file mode 100644 index 0000000000..b40d7f9444 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/f5530eccf202_add_kernels_uuid_prefix_index.py @@ -0,0 +1,31 @@ +"""add-kernels-uuid-prefix-index + +Revision ID: f5530eccf202 +Revises: ed666f476f39 +Create Date: 2020-03-25 17:29:50.696450 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'f5530eccf202' +down_revision = 'ed666f476f39' +branch_labels = None +depends_on = None + + +def upgrade(): + op.create_index( + op.f('ix_kernels_uuid_prefix'), + 'kernels', + [sa.text('CAST("id" AS VARCHAR) COLLATE "C"')], + ) + + +def downgrade(): + op.drop_index( + op.f('ix_kernels_uuid_prefix'), + 'kernels', + ) diff --git a/src/ai/backend/manager/models/alembic/versions/f8a71c3bffa2_stringify_userid.py b/src/ai/backend/manager/models/alembic/versions/f8a71c3bffa2_stringify_userid.py new file mode 100644 index 0000000000..f90ed5f45a --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/f8a71c3bffa2_stringify_userid.py @@ -0,0 +1,79 @@ +"""stringify_userid + +Revision ID: f8a71c3bffa2 +Revises: bf4bae8f942e +Create Date: 2018-06-17 13:52:13.346856 + +""" +from alembic import op +import sqlalchemy as sa +from ai.backend.manager.models.base import convention +import os + +# revision identifiers, used by Alembic. +revision = 'f8a71c3bffa2' +down_revision = 'bf4bae8f942e' +branch_labels = None +depends_on = None + + +def upgrade(): + metadata = sa.MetaData(naming_convention=convention) + keypairs = sa.Table( + 'keypairs', metadata, + sa.Column('user_id', sa.String(length=256), index=True), + ) + + print('Choose keypairs.user_id column migrate option:') + print(' [a] Convert all numeric user IDs to strings directly') + print(' [b] Convert numeric user IDs to strings using a mapping table\n' + ' (user_id_map.txt must be present in the current working directory\n' + ' which contains a space-sep.list of numeric and string ID pairs.)') + print('NOTE: If you choose [b], you will not be able to downgrade!') + + choice = os.environ.get('STRINGIFY_USERID_CHOICE') + if choice is None: + while True: + choice = input('Your choice? [a/b] ') + if choice in ('a', 'b'): + break + print('Invalid choice.') + continue + + op.alter_column('keypairs', 'user_id', + existing_type=sa.Integer(), + type_=sa.String(length=256)) + + # NOTE: We do the data migration after converting column type. + + if choice == 'b': + # query all unique user ids + q = sa.select([keypairs.c.user_id]).group_by(keypairs.c.user_id) + rows = op.get_bind().execute(q) + user_ids = set(int(row.user_id) for row in rows) + print(f'There are {len(user_ids)} unique user IDs.') + + user_id_map = {} + with open('user_id_map.txt', 'r') as f: + for line in f: + num_id, str_id = line.split(maxsplit=1) + assert len(str_id) <= 256, \ + f'Too long target user ID! 
({num_id} -> {str_id!r})' + user_id_map[int(num_id)] = str_id + + map_diff = user_ids - set(user_id_map.keys()) + assert len(map_diff) == 0, \ + f'There are unmapped user IDs!\n{map_diff}' + + for num_id, str_id in user_id_map.items(): + op.execute( + keypairs.update() + .values({'user_id': str_id}) + .where(keypairs.c.user_id == str(num_id)) + ) + + +def downgrade(): + op.alter_column('keypairs', 'user_id', + existing_type=sa.Integer(), + type_=sa.String(length=256)) diff --git a/src/ai/backend/manager/models/alembic/versions/f9971fbb34d9_add_state_column_to_vfolder_invitations.py b/src/ai/backend/manager/models/alembic/versions/f9971fbb34d9_add_state_column_to_vfolder_invitations.py new file mode 100644 index 0000000000..45e6f1103d --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/f9971fbb34d9_add_state_column_to_vfolder_invitations.py @@ -0,0 +1,30 @@ +"""add state column to vfolder_invitations + +Revision ID: f9971fbb34d9 +Revises: 185852ff9872 +Create Date: 2018-07-12 23:30:14.942845 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'f9971fbb34d9' +down_revision = '185852ff9872' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('vfolder_invitations', sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True)) + op.add_column('vfolder_invitations', sa.Column('state', sa.String(length=10), nullable=True)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('vfolder_invitations', 'state') + op.drop_column('vfolder_invitations', 'created_at') + # ### end Alembic commands ### diff --git a/src/ai/backend/manager/models/alembic/versions/ff4bfca66bf8_.py b/src/ai/backend/manager/models/alembic/versions/ff4bfca66bf8_.py new file mode 100644 index 0000000000..d977f75971 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/ff4bfca66bf8_.py @@ -0,0 +1,22 @@ +"""empty message + +Revision ID: ff4bfca66bf8 +Revises: 0e558d06e0e3, 352fa4f88f61 +Create Date: 2018-12-24 22:42:54.188099 + +""" + + +# revision identifiers, used by Alembic. 
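# --- Editorial sketch (not part of the patch) --------------------------------
# Why the f5530eccf202 migration above indexes CAST("id" AS VARCHAR) COLLATE "C":
# kernel rows are looked up by short ID prefixes, and a b-tree index over the
# "C"-collated text form is usable for left-anchored LIKE patterns.  The query
# below is an assumed illustration of such a lookup; the lightweight `kernels`
# table object exists only for this sketch.
import sqlalchemy as sa

kernels = sa.table('kernels', sa.column('id'))

def kernels_by_uuid_prefix(prefix: str) -> sa.sql.Select:
    # The expression must match the indexed one for the planner to be able to
    # use ix_kernels_uuid_prefix.
    id_as_text = sa.cast(kernels.c.id, sa.String).collate('C')
    return sa.select([kernels.c.id]).where(id_as_text.like(prefix + '%'))
# ------------------------------------------------------------------------------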
+revision = 'ff4bfca66bf8' +down_revision = ('0e558d06e0e3', '352fa4f88f61') +branch_labels = None +depends_on = None + + +def upgrade(): + pass + + +def downgrade(): + pass diff --git a/src/ai/backend/manager/models/base.py b/src/ai/backend/manager/models/base.py new file mode 100644 index 0000000000..d7c32368b8 --- /dev/null +++ b/src/ai/backend/manager/models/base.py @@ -0,0 +1,845 @@ +from __future__ import annotations + +import asyncio +import collections +import enum +import functools +import logging +import trafaret as t +from typing import ( + Any, + Awaitable, + Callable, + ClassVar, + Dict, + Generic, + Iterable, + List, + Mapping, + MutableMapping, + Optional, + Protocol, + Sequence, + TYPE_CHECKING, + Type, + TypeVar, + Union, + cast, +) +import sys +import uuid + +from aiodataloader import DataLoader +from aiotools import apartial +import graphene +from graphene.types import Scalar +from graphql.language import ast +from graphene.types.scalars import MIN_INT, MAX_INT +import sqlalchemy as sa +from sqlalchemy.engine.result import Result +from sqlalchemy.engine.row import Row +from sqlalchemy.ext.asyncio import ( + AsyncConnection as SAConnection, + AsyncEngine as SAEngine, +) +from sqlalchemy.orm import ( + registry, +) +from sqlalchemy.types import ( + SchemaType, + TypeDecorator, + CHAR, +) +from sqlalchemy.dialects.postgresql import UUID, ENUM, JSONB +import yarl + +from ai.backend.common.logging import BraceStyleAdapter +from ai.backend.common.types import ( + BinarySize, + KernelId, + ResourceSlot, + SessionId, + JSONSerializableMixin, +) + +from ai.backend.manager.models.utils import execute_with_retry + +from .. import models +from ..api.exceptions import ( + GenericForbidden, InvalidAPIParameters, +) +if TYPE_CHECKING: + from graphql.execution.executors.asyncio import AsyncioExecutor + + from .gql import GraphQueryContext + from .user import UserRole + +SAFE_MIN_INT = -9007199254740991 +SAFE_MAX_INT = 9007199254740991 + +log = BraceStyleAdapter(logging.getLogger(__name__)) + +# The common shared metadata instance +convention = { + "ix": 'ix_%(column_0_label)s', + "uq": "uq_%(table_name)s_%(column_0_name)s", + "ck": "ck_%(table_name)s_%(constraint_name)s", + "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", + "pk": "pk_%(table_name)s", +} +metadata = sa.MetaData(naming_convention=convention) +mapper_registry = registry(metadata=metadata) +Base: Any = mapper_registry.generate_base() # TODO: remove Any after #422 is merged + +pgsql_connect_opts = { + 'server_settings': { + 'jit': 'off', + # 'deadlock_timeout': '10000', # FIXME: AWS RDS forbids settings this via connection arguments + 'lock_timeout': '60000', # 60 secs + 'idle_in_transaction_session_timeout': '60000', # 60 secs + }, +} + + +# helper functions +def zero_if_none(val): + return 0 if val is None else val + + +class EnumType(TypeDecorator, SchemaType): + """ + A stripped-down version of Spoqa's sqlalchemy-enum34. + It also handles postgres-specific enum type creation. + + The actual postgres enum choices are taken from the Python enum names. 
+ """ + + impl = ENUM + cache_ok = True + + def __init__(self, enum_cls, **opts): + assert issubclass(enum_cls, enum.Enum) + if 'name' not in opts: + opts['name'] = enum_cls.__name__.lower() + self._opts = opts + enums = (m.name for m in enum_cls) + super().__init__(*enums, **opts) + self._enum_cls = enum_cls + + def process_bind_param(self, value, dialect): + return value.name if value else None + + def process_result_value(self, value: str, dialect): + return self._enum_cls[value] if value else None + + def copy(self): + return EnumType(self._enum_cls, **self._opts) + + @property + def python_type(self): + return self._enum_class + + +class EnumValueType(TypeDecorator, SchemaType): + """ + A stripped-down version of Spoqa's sqlalchemy-enum34. + It also handles postgres-specific enum type creation. + + The actual postgres enum choices are taken from the Python enum values. + """ + + impl = ENUM + cache_ok = True + + def __init__(self, enum_cls, **opts): + assert issubclass(enum_cls, enum.Enum) + if 'name' not in opts: + opts['name'] = enum_cls.__name__.lower() + self._opts = opts + enums = (m.value for m in enum_cls) + super().__init__(*enums, **opts) + self._enum_cls = enum_cls + + def process_bind_param(self, value, dialect): + return value.value if value else None + + def process_result_value(self, value: str, dialect): + return self._enum_cls(value) if value else None + + def copy(self): + return EnumValueType(self._enum_cls, **self._opts) + + @property + def python_type(self): + return self._enum_class + + +class ResourceSlotColumn(TypeDecorator): + """ + A column type wrapper for ResourceSlot from JSONB. + """ + + impl = JSONB + cache_ok = True + + def process_bind_param(self, value: Union[Mapping, ResourceSlot], dialect): + if isinstance(value, Mapping) and not isinstance(value, ResourceSlot): + return value + return value.to_json() if value is not None else None + + def process_result_value(self, raw_value: Dict[str, str], dialect): + # legacy handling + interim_value: Dict[str, Any] = raw_value + mem = raw_value.get('mem') + if isinstance(mem, str) and not mem.isdigit(): + interim_value['mem'] = BinarySize.from_str(mem) + return ResourceSlot.from_json(interim_value) if raw_value is not None else None + + def copy(self): + return ResourceSlotColumn() + + +class StructuredJSONColumn(TypeDecorator): + """ + A column type to convert JSON values back and forth using a Trafaret. + """ + + impl = JSONB + cache_ok = True + + def __init__(self, schema: t.Trafaret) -> None: + super().__init__() + self._schema = schema + + def load_dialect_impl(self, dialect): + if dialect.name == 'sqlite': + return dialect.type_descriptor(sa.JSON) + else: + return super().load_dialect_impl(dialect) + + def process_bind_param(self, value, dialect): + if value is None: + return self._schema.check({}) + try: + self._schema.check(value) + except t.DataError as e: + raise ValueError( + "The given value does not conform with the structured json column format.", + e.as_dict(), + ) + return value + + def process_result_value(self, raw_value, dialect): + if raw_value is None: + return self._schema.check({}) + return self._schema.check(raw_value) + + def copy(self): + return StructuredJSONColumn(self._schema) + + +class StructuredJSONObjectColumn(TypeDecorator): + """ + A column type to convert JSON values back and forth using JSONSerializableMixin. 
+ """ + + impl = JSONB + cache_ok = True + + def __init__(self, schema: Type[JSONSerializableMixin]) -> None: + super().__init__() + self._schema = schema + + def process_bind_param(self, value, dialect): + return self._schema.to_json(value) + + def process_result_value(self, raw_value, dialect): + return self._schema.from_json(raw_value) + + def copy(self): + return StructuredJSONObjectColumn(self._schema) + + +class StructuredJSONObjectListColumn(TypeDecorator): + """ + A column type to convert JSON values back and forth using JSONSerializableMixin, + but store and load a list of the objects. + """ + + impl = JSONB + cache_ok = True + + def __init__(self, schema: Type[JSONSerializableMixin]) -> None: + super().__init__() + self._schema = schema + + def process_bind_param(self, value, dialect): + return [self._schema.to_json(item) for item in value] + + def process_result_value(self, raw_value, dialect): + if raw_value is None: + return [] + return [self._schema.from_json(item) for item in raw_value] + + def copy(self): + return StructuredJSONObjectListColumn(self._schema) + + +class URLColumn(TypeDecorator): + """ + A column type for URL strings + """ + + impl = sa.types.UnicodeText + cache_ok = True + + def process_bind_param(self, value, dialect): + if isinstance(value, yarl.URL): + return str(value) + return value + + def process_result_value(self, value, dialect): + if value is None: + return None + if value is not None: + return yarl.URL(value) + + +class CurrencyTypes(enum.Enum): + KRW = 'KRW' + USD = 'USD' + + +UUID_SubType = TypeVar('UUID_SubType', bound=uuid.UUID) + + +class GUID(TypeDecorator, Generic[UUID_SubType]): + """ + Platform-independent GUID type. + Uses PostgreSQL's UUID type, otherwise uses CHAR(16) storing as raw bytes. + """ + impl = CHAR + uuid_subtype_func: ClassVar[Callable[[Any], Any]] = lambda v: v + cache_ok = True + + def load_dialect_impl(self, dialect): + if dialect.name == 'postgresql': + return dialect.type_descriptor(UUID()) + else: + return dialect.type_descriptor(CHAR(16)) + + def process_bind_param(self, value: Union[UUID_SubType, uuid.UUID], dialect): + # NOTE: SessionId, KernelId are *not* actual types defined as classes, + # but a "virtual" type that is an identity function at runtime. + # The type checker treats them as distinct derivatives of uuid.UUID. + # Therefore, we just do isinstance on uuid.UUID only below. 
+ if value is None: + return value + elif dialect.name == 'postgresql': + if isinstance(value, uuid.UUID): + return str(value) + else: + return str(uuid.UUID(value)) + else: + if isinstance(value, uuid.UUID): + return value.bytes + else: + return uuid.UUID(value).bytes + + def process_result_value(self, value: Any, dialect) -> Optional[UUID_SubType]: + if value is None: + return value + else: + cls = type(self) + if isinstance(value, bytes): + return cast(UUID_SubType, cls.uuid_subtype_func(uuid.UUID(bytes=value))) + else: + return cast(UUID_SubType, cls.uuid_subtype_func(uuid.UUID(value))) + + +class SessionIDColumnType(GUID[SessionId]): + uuid_subtype_func = SessionId + cache_ok = True + + +class KernelIDColumnType(GUID[KernelId]): + uuid_subtype_func = KernelId + cache_ok = True + + +def IDColumn(name='id'): + return sa.Column(name, GUID, primary_key=True, + server_default=sa.text("uuid_generate_v4()")) + + +def SessionIDColumn(name='id'): + return sa.Column(name, SessionIDColumnType, primary_key=True, + server_default=sa.text("uuid_generate_v4()")) + + +def KernelIDColumn(name='id'): + return sa.Column(name, KernelIDColumnType, primary_key=True, + server_default=sa.text("uuid_generate_v4()")) + + +def ForeignKeyIDColumn(name, fk_field, nullable=True): + return sa.Column(name, GUID, sa.ForeignKey(fk_field), nullable=nullable) + + +class DataLoaderManager: + """ + For every different combination of filtering conditions, we need to make a + new DataLoader instance because it "batches" the database queries. + This manager get-or-creates dataloaders with fixed conditions (represetned + as arguments) like a cache. + + NOTE: Just like DataLoaders, it is recommended to instantiate this manager + for every incoming API request. + """ + + cache: Dict[int, DataLoader] + + def __init__(self) -> None: + self.cache = {} + self.mod = sys.modules['ai.backend.manager.models'] + + @staticmethod + def _get_key(otname: str, args, kwargs) -> int: + """ + Calculate the hash of the all arguments and keyword arguments. + """ + key = (otname, ) + args + for item in kwargs.items(): + key += item + return hash(key) + + def get_loader(self, context: GraphQueryContext, objtype_name: str, *args, **kwargs) -> DataLoader: + k = self._get_key(objtype_name, args, kwargs) + loader = self.cache.get(k) + if loader is None: + objtype_name, has_variant, variant_name = objtype_name.partition('.') + objtype = getattr(self.mod, objtype_name) + if has_variant: + batch_load_fn = getattr(objtype, 'batch_load_' + variant_name) + else: + batch_load_fn = objtype.batch_load + loader = DataLoader( + apartial(batch_load_fn, context, *args, **kwargs), + max_batch_size=128, + ) + self.cache[k] = loader + return loader + + +class ResourceLimit(graphene.ObjectType): + key = graphene.String() + min = graphene.String() + max = graphene.String() + + +class KVPair(graphene.ObjectType): + key = graphene.String() + value = graphene.String() + + +class ResourceLimitInput(graphene.InputObjectType): + key = graphene.String() + min = graphene.String() + max = graphene.String() + + +class KVPairInput(graphene.InputObjectType): + key = graphene.String() + value = graphene.String() + + +class BigInt(Scalar): + """ + BigInt is an extension of the regular graphene.Int scalar type + to support integers outside the range of a signed 32-bit integer. 
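# --- Editorial sketch (not part of the patch) --------------------------------
# How DataLoaderManager above is used from resolvers (this mirrors
# Queries.resolve_domain later in this patch): one loader is get-or-created per
# (object type, filter-arguments) combination, and .load() calls issued within
# the same tick are coalesced into a single batch_load_* database query.
async def resolve_domain_example(ctx, name: str):
    loader = ctx.dataloader_manager.get_loader(ctx, 'Domain.by_name')
    # Several .load() calls with different names become one
    # Domain.batch_load_by_name() round-trip to the database.
    return await loader.load(name)
# ------------------------------------------------------------------------------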
+ """ + + @staticmethod + def coerce_bigint(value): + num = int(value) + if not (SAFE_MIN_INT <= num <= SAFE_MAX_INT): + raise ValueError( + 'Cannot serialize integer out of the safe range.') + if not (MIN_INT <= num <= MAX_INT): + # treat as float + return float(int(num)) + return num + + serialize = coerce_bigint + parse_value = coerce_bigint + + @staticmethod + def parse_literal(node): + if isinstance(node, ast.IntValue): + num = int(node.value) + if not (SAFE_MIN_INT <= num <= SAFE_MAX_INT): + raise ValueError( + 'Cannot parse integer out of the safe range.') + if not (MIN_INT <= num <= MAX_INT): + # treat as float + return float(int(num)) + return num + + +class Item(graphene.Interface): + id = graphene.ID() + + +class PaginatedList(graphene.Interface): + items = graphene.List(Item, required=True) + total_count = graphene.Int(required=True) + + +# ref: https://github.com/python/mypy/issues/1212 +_GenericSQLBasedGQLObject = TypeVar('_GenericSQLBasedGQLObject', bound='_SQLBasedGQLObject') +_Key = TypeVar('_Key') + + +class _SQLBasedGQLObject(Protocol): + @classmethod + def from_row( + cls: Type[_GenericSQLBasedGQLObject], + ctx: GraphQueryContext, + row: Row, + ) -> _GenericSQLBasedGQLObject: + ... + + +async def batch_result( + graph_ctx: GraphQueryContext, + db_conn: SAConnection, + query: sa.sql.Select, + obj_type: Type[_GenericSQLBasedGQLObject], + key_list: Iterable[_Key], + key_getter: Callable[[Row], _Key], +) -> Sequence[Optional[_GenericSQLBasedGQLObject]]: + """ + A batched query adaptor for (key -> item) resolving patterns. + """ + objs_per_key: Dict[_Key, Optional[_GenericSQLBasedGQLObject]] + objs_per_key = collections.OrderedDict() + for key in key_list: + objs_per_key[key] = None + async for row in (await db_conn.stream(query)): + objs_per_key[key_getter(row)] = obj_type.from_row(graph_ctx, row) + return [*objs_per_key.values()] + + +async def batch_multiresult( + graph_ctx: GraphQueryContext, + db_conn: SAConnection, + query: sa.sql.Select, + obj_type: Type[_GenericSQLBasedGQLObject], + key_list: Iterable[_Key], + key_getter: Callable[[Row], _Key], +) -> Sequence[Sequence[_GenericSQLBasedGQLObject]]: + """ + A batched query adaptor for (key -> [item]) resolving patterns. + """ + objs_per_key: Dict[_Key, List[_GenericSQLBasedGQLObject]] + objs_per_key = collections.OrderedDict() + for key in key_list: + objs_per_key[key] = list() + async for row in (await db_conn.stream(query)): + objs_per_key[key_getter(row)].append( + obj_type.from_row(graph_ctx, row), + ) + return [*objs_per_key.values()] + + +def privileged_query(required_role: UserRole): + + def wrap(func): + + @functools.wraps(func) + async def wrapped(executor: AsyncioExecutor, info: graphene.ResolveInfo, *args, **kwargs) -> Any: + from .user import UserRole + ctx: GraphQueryContext = info.context + if ctx.user['role'] != UserRole.SUPERADMIN: + raise GenericForbidden('superadmin privilege required') + return await func(executor, info, *args, **kwargs) + + return wrapped + + return wrap + + +def scoped_query( + *, + autofill_user: bool = False, + user_key: str = 'access_key', +): + """ + Prepends checks for domain/group/user access rights depending + on the client's user and keypair information. + + :param autofill_user: When the *user_key* is not specified, + automatically fills out the user data with the current + user who is makeing the API request. + :param user_key: The key used for storing user identification value + in the keyword arguments. 
+ """ + + def wrap(resolve_func): + + @functools.wraps(resolve_func) + async def wrapped(executor: AsyncioExecutor, info: graphene.ResolveInfo, *args, **kwargs) -> Any: + from .user import UserRole + ctx: GraphQueryContext = info.context + client_role = ctx.user['role'] + if user_key == 'access_key': + client_user_id = ctx.access_key + elif user_key == 'email': + client_user_id = ctx.user['email'] + else: + client_user_id = ctx.user['uuid'] + client_domain = ctx.user['domain_name'] + domain_name = kwargs.get('domain_name', None) + group_id = kwargs.get('group_id', None) + user_id = kwargs.get(user_key, None) + if client_role == UserRole.SUPERADMIN: + if autofill_user: + if user_id is None: + user_id = client_user_id + elif client_role == UserRole.ADMIN: + if domain_name is not None and domain_name != client_domain: + raise GenericForbidden + domain_name = client_domain + if group_id is not None: + # TODO: check if the group is a member of the domain + pass + if autofill_user: + if user_id is None: + user_id = client_user_id + elif client_role == UserRole.USER: + if domain_name is not None and domain_name != client_domain: + raise GenericForbidden + domain_name = client_domain + if group_id is not None: + # TODO: check if the group is a member of the domain + # TODO: check if the client is a member of the group + pass + if user_id is not None and user_id != client_user_id: + raise GenericForbidden + user_id = client_user_id + else: + raise InvalidAPIParameters('Unknown client role') + kwargs['domain_name'] = domain_name + if group_id is not None: + kwargs['group_id'] = group_id + kwargs[user_key] = user_id + return await resolve_func(executor, info, *args, **kwargs) + + return wrapped + + return wrap + + +def privileged_mutation(required_role, target_func=None): + + def wrap(func): + + @functools.wraps(func) + async def wrapped(cls, root, info: graphene.ResolveInfo, *args, **kwargs) -> Any: + from .user import UserRole + from .group import groups # , association_groups_users + ctx: GraphQueryContext = info.context + permitted = False + if required_role == UserRole.SUPERADMIN: + if ctx.user['role'] == required_role: + permitted = True + elif required_role == UserRole.ADMIN: + if ctx.user['role'] == UserRole.SUPERADMIN: + permitted = True + elif ctx.user['role'] == UserRole.USER: + permitted = False + else: + if target_func is None: + return cls(False, 'misconfigured privileged mutation: no target_func', None) + target_domain, target_group = target_func(*args, **kwargs) + if target_domain is None and target_group is None: + return cls(False, 'misconfigured privileged mutation: ' + 'both target_domain and target_group missing', None) + permit_chains = [] + if target_domain is not None: + if ctx.user['domain_name'] == target_domain: + permit_chains.append(True) + if target_group is not None: + async with ctx.db.begin() as conn: + # check if the group is part of the requester's domain. 
+ query = ( + groups.select() + .where( + (groups.c.id == target_group) & + (groups.c.domain_name == ctx.user['domain_name']), + ) + ) + result = await conn.execute(query) + if result.rowcount > 0: + permit_chains.append(True) + # TODO: check the group permission if implemented + # query = ( + # association_groups_users.select() + # .where(association_groups_users.c.group_id == target_group) + # ) + # result = await conn.execute(query) + # if result.rowcount > 0: + # permit_chains.append(True) + permitted = all(permit_chains) if permit_chains else False + elif required_role == UserRole.USER: + permitted = True + # assuming that mutation result objects has 2 or 3 fields: + # success(bool), message(str) - usually for delete mutations + # success(bool), message(str), item(object) + if permitted: + return await func(cls, root, info, *args, **kwargs) + return cls(False, f"no permission to execute {info.path[0]}") + + return wrapped + + return wrap + + +ResultType = TypeVar('ResultType', bound=graphene.ObjectType) +ItemType = TypeVar('ItemType', bound=graphene.ObjectType) + + +async def simple_db_mutate( + result_cls: Type[ResultType], + graph_ctx: GraphQueryContext, + mutation_query: sa.sql.Update | sa.sql.Insert | Callable[[], sa.sql.Update | sa.sql.Insert], + *, + pre_func: Callable[[SAConnection], Awaitable[None]] | None = None, + post_func: Callable[[SAConnection, Result], Awaitable[None]] | None = None, +) -> ResultType: + """ + Performs a database mutation based on the given + :class:`sqlalchemy.sql.Update` or :class:`sqlalchemy.sql.Insert` query, + and return the wrapped result as the GraphQL object type given as **result_cls**. + **result_cls** should have two initialization arguments: success (bool) + and message (str). + + See details about the arguments in :func:`simple_db_mutate_returning_item`. + """ + + async def _do_mutate() -> ResultType: + async with graph_ctx.db.begin() as conn: + if pre_func: + await pre_func(conn) + _query = mutation_query() if callable(mutation_query) else mutation_query + result = await conn.execute(_query) + if post_func: + await post_func(conn, result) + if result.rowcount > 0: + return result_cls(True, "success") + else: + return result_cls(False, f"no matching {result_cls.__name__.lower()}") + + try: + return await execute_with_retry(_do_mutate) + except sa.exc.IntegrityError as e: + return result_cls(False, f"integrity error: {e}") + except (asyncio.CancelledError, asyncio.TimeoutError): + raise + except Exception as e: + return result_cls(False, f"unexpected error: {e}") + + +async def simple_db_mutate_returning_item( + result_cls: Type[ResultType], + graph_ctx: GraphQueryContext, + mutation_query: sa.sql.Update | sa.sql.Insert | Callable[[], sa.sql.Update | sa.sql.Insert], + *, + item_cls: Type[ItemType], + pre_func: Callable[[SAConnection], Awaitable[None]] | None = None, + post_func: Callable[[SAConnection, Result], Awaitable[Row]] | None = None, +) -> ResultType: + """ + Performs a database mutation based on the given + :class:`sqlalchemy.sql.Update` or :class:`sqlalchemy.sql.Insert` query, + and return the wrapped result as the GraphQL object type given as **result_cls** + and the inserted/updated row wrapped as its 3rd argument in **item_cls**. + + If mutation_query uses external variable updated by pre_func, you should wrap the query + with lambda so that its parameters are re-evaluated when the transaction is retried. + + :param result_cls: The GraphQL Object Type used to wrap the result. 
+ It should have two initialization arguments: success (bool), + message (str), and the item (ItemType). + :param graph_ctx: The common context that provides the reference to the database engine + and other stuffs required to resolve the GraphQL query. + :param mutation_query: A SQLAlchemy query object. + :param item_cls: The GraphQL Object Type used to wrap the returned row from the mutation query. + :param pre_func: An extra function that is executed before the mutation query, where the caller + may perform additional database queries. + :param post_func: An extra function that is executed after the mutation query, where the caller + may perform additional database queries. Note that it **MUST return the returned row + from the given mutation result**, because the result object could be fetched only one + time due to its cursor-like nature. + """ + + async def _do_mutate() -> ResultType: + async with graph_ctx.db.begin() as conn: + if pre_func: + await pre_func(conn) + _query = mutation_query() if callable(mutation_query) else mutation_query + _query = _query.returning(_query.table) + result = await conn.execute(_query) + if post_func: + row = await post_func(conn, result) + else: + row = result.first() + if result.rowcount > 0: + return result_cls(True, "success", item_cls.from_row(graph_ctx, row)) + else: + return result_cls(False, f"no matching {result_cls.__name__.lower()}", None) + + try: + return await execute_with_retry(_do_mutate) + except sa.exc.IntegrityError as e: + return result_cls(False, f"integrity error: {e}", None) + except (asyncio.CancelledError, asyncio.TimeoutError): + raise + except Exception as e: + return result_cls(False, f"unexpected error: {e}", None) + + +def set_if_set( + src: object, target: MutableMapping[str, Any], name: str, *, + clean_func=None, target_key: Optional[str] = None, +) -> None: + v = getattr(src, name) + # NOTE: unset optional fields are passed as null. 
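# --- Editorial sketch (not part of the patch) --------------------------------
# The calling convention for the two mutation helpers above, and why the query
# may be passed as a lambda: when values referenced by the query are computed
# inside pre_func (or otherwise change between attempts), deferring query
# construction lets execute_with_retry() re-evaluate it on every retry.
# Everything except `sa` and simple_db_mutate_returning_item() is a parameter
# here, so the names are not tied to any real table or GraphQL type.
async def _update_example(result_cls, item_cls, graph_ctx, table, name, data):
    update_query = lambda: (
        sa.update(table).values(data).where(table.c.name == name)
    )
    return await simple_db_mutate_returning_item(
        result_cls, graph_ctx, update_query, item_cls=item_cls,
    )
# ------------------------------------------------------------------------------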
+ if v is not None: + if callable(clean_func): + target[target_key or name] = clean_func(v) + else: + target[target_key or name] = v + + +async def populate_fixture( + engine: SAEngine, + fixture_data: Mapping[str, Sequence[Dict[str, Any]]], + *, + ignore_unique_violation: bool = False, +) -> None: + for table_name, rows in fixture_data.items(): + table: sa.Table = getattr(models, table_name) + assert isinstance(table, sa.Table) + async with engine.begin() as conn: + for col in table.columns: + if isinstance(col.type, EnumType): + for row in rows: + row[col.name] = col.type._enum_cls[row[col.name]] + elif isinstance(col.type, EnumValueType): + for row in rows: + row[col.name] = col.type._enum_cls(row[col.name]) + elif isinstance(col.type, (StructuredJSONObjectColumn, StructuredJSONObjectListColumn)): + for row in rows: + row[col.name] = col.type._schema.from_json(row[col.name]) + await conn.execute(sa.dialects.postgresql.insert(table, rows).on_conflict_do_nothing()) diff --git a/src/ai/backend/manager/models/domain.py b/src/ai/backend/manager/models/domain.py new file mode 100644 index 0000000000..dc1a34ad10 --- /dev/null +++ b/src/ai/backend/manager/models/domain.py @@ -0,0 +1,381 @@ +from __future__ import annotations + +import logging +import re +from typing import ( + Any, + Dict, + List, + Optional, + Sequence, + TYPE_CHECKING, + TypedDict, +) + +import graphene +from graphene.types.datetime import DateTime as GQLDateTime +import sqlalchemy as sa +from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection +from sqlalchemy.engine.row import Row +from sqlalchemy.dialects import postgresql as pgsql + +from ai.backend.common import msgpack +from ai.backend.common.logging import BraceStyleAdapter +from ai.backend.common.types import ResourceSlot +from .base import ( + metadata, ResourceSlotColumn, + simple_db_mutate, + simple_db_mutate_returning_item, + set_if_set, + batch_result, +) +from .scaling_group import ScalingGroup +from .user import UserRole +from ..defs import RESERVED_DOTFILES + +if TYPE_CHECKING: + from .gql import GraphQueryContext + +log = BraceStyleAdapter(logging.getLogger(__file__)) + + +__all__: Sequence[str] = ( + 'domains', + 'Domain', 'DomainInput', 'ModifyDomainInput', + 'CreateDomain', 'ModifyDomain', 'DeleteDomain', + 'DomainDotfile', 'MAXIMUM_DOTFILE_SIZE', + 'query_domain_dotfiles', + 'verify_dotfile_name', +) + +MAXIMUM_DOTFILE_SIZE = 64 * 1024 # 61 KiB +_rx_slug = re.compile(r'^[a-zA-Z0-9]([a-zA-Z0-9._-]*[a-zA-Z0-9])?$') + +domains = sa.Table( + 'domains', metadata, + sa.Column('name', sa.String(length=64), primary_key=True), + sa.Column('description', sa.String(length=512)), + sa.Column('is_active', sa.Boolean, default=True), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column('modified_at', sa.DateTime(timezone=True), + server_default=sa.func.now(), onupdate=sa.func.current_timestamp()), + # TODO: separate resource-related fields with new domain resource policy table when needed. + sa.Column('total_resource_slots', ResourceSlotColumn(), default='{}'), + sa.Column('allowed_vfolder_hosts', pgsql.ARRAY(sa.String), nullable=False, default='{}'), + sa.Column('allowed_docker_registries', pgsql.ARRAY(sa.String), nullable=False, default='{}'), + #: Field for synchronization with external services. 
+ sa.Column('integration_id', sa.String(length=512)), + # dotfiles column, \x90 means empty list in msgpack + sa.Column('dotfiles', sa.LargeBinary(length=MAXIMUM_DOTFILE_SIZE), nullable=False, default=b'\x90'), +) + + +class Domain(graphene.ObjectType): + name = graphene.String() + description = graphene.String() + is_active = graphene.Boolean() + created_at = GQLDateTime() + modified_at = GQLDateTime() + total_resource_slots = graphene.JSONString() + allowed_vfolder_hosts = graphene.List(lambda: graphene.String) + allowed_docker_registries = graphene.List(lambda: graphene.String) + integration_id = graphene.String() + + # Dynamic fields. + scaling_groups = graphene.List(lambda: graphene.String) + + async def resolve_scaling_groups(self, info: graphene.ResolveInfo) -> Sequence[str]: + sgroups = await ScalingGroup.load_by_domain(info.context, self.name) + return [sg.name for sg in sgroups] + + @classmethod + def from_row(cls, ctx: GraphQueryContext, row: Row) -> Optional[Domain]: + if row is None: + return None + return cls( + name=row['name'], + description=row['description'], + is_active=row['is_active'], + created_at=row['created_at'], + modified_at=row['modified_at'], + total_resource_slots=row['total_resource_slots'].to_json(), + allowed_vfolder_hosts=row['allowed_vfolder_hosts'], + allowed_docker_registries=row['allowed_docker_registries'], + integration_id=row['integration_id'], + ) + + @classmethod + async def load_all( + cls, + ctx: GraphQueryContext, + *, + is_active: bool = None, + ) -> Sequence[Domain]: + async with ctx.db.begin_readonly() as conn: + query = sa.select([domains]).select_from(domains) + if is_active is not None: + query = query.where(domains.c.is_active == is_active) + return [ + obj async for row in (await conn.stream(query)) + if (obj := cls.from_row(ctx, row)) is not None + ] + + @classmethod + async def batch_load_by_name( + cls, + ctx: GraphQueryContext, + names: Sequence[str], + *, + is_active: bool = None, + ) -> Sequence[Optional[Domain]]: + async with ctx.db.begin_readonly() as conn: + query = ( + sa.select([domains]) + .select_from(domains) + .where(domains.c.name.in_(names)) + ) + if is_active is not None: + query = query.where(domains.c.is_active == is_active) + return await batch_result( + ctx, conn, query, cls, + names, lambda row: row['name'], + ) + + +class DomainInput(graphene.InputObjectType): + description = graphene.String(required=False) + is_active = graphene.Boolean(required=False, default=True) + total_resource_slots = graphene.JSONString(required=False) + allowed_vfolder_hosts = graphene.List(lambda: graphene.String, required=False) + allowed_docker_registries = graphene.List(lambda: graphene.String, required=False) + integration_id = graphene.String(required=False) + + +class ModifyDomainInput(graphene.InputObjectType): + name = graphene.String(required=False) + description = graphene.String(required=False) + is_active = graphene.Boolean(required=False) + total_resource_slots = graphene.JSONString(required=False) + allowed_vfolder_hosts = graphene.List(lambda: graphene.String, required=False) + allowed_docker_registries = graphene.List(lambda: graphene.String, required=False) + integration_id = graphene.String(required=False) + + +class CreateDomain(graphene.Mutation): + + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + name = graphene.String(required=True) + props = DomainInput(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + domain = graphene.Field(lambda: Domain, required=False) + + @classmethod + 
async def mutate( + cls, + root, + info: graphene.ResolveInfo, + name: str, + props: DomainInput, + ) -> CreateDomain: + if _rx_slug.search(name) is None: + return cls(False, 'invalid name format. slug format required.', None) + ctx: GraphQueryContext = info.context + data = { + 'name': name, + 'description': props.description, + 'is_active': props.is_active, + 'total_resource_slots': ResourceSlot.from_user_input( + props.total_resource_slots, None), + 'allowed_vfolder_hosts': props.allowed_vfolder_hosts, + 'allowed_docker_registries': props.allowed_docker_registries, + 'integration_id': props.integration_id, + } + insert_query = ( + sa.insert(domains).values(data) + ) + return await simple_db_mutate_returning_item(cls, ctx, insert_query, item_cls=Domain) + + +class ModifyDomain(graphene.Mutation): + + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + name = graphene.String(required=True) + props = ModifyDomainInput(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + domain = graphene.Field(lambda: Domain, required=False) + + @classmethod + async def mutate( + cls, + root, + info: graphene.ResolveInfo, + name: str, + props: ModifyDomainInput, + ) -> ModifyDomain: + ctx: GraphQueryContext = info.context + data: Dict[str, Any] = {} + set_if_set(props, data, 'name') # data['name'] is new domain name + set_if_set(props, data, 'description') + set_if_set(props, data, 'is_active') + set_if_set(props, data, 'total_resource_slots', + clean_func=lambda v: ResourceSlot.from_user_input(v, None)) + set_if_set(props, data, 'allowed_vfolder_hosts') + set_if_set(props, data, 'allowed_docker_registries') + set_if_set(props, data, 'integration_id') + if 'name' in data and _rx_slug.search(data['name']) is None: + raise ValueError('invalid name format. slug format required.') + update_query = ( + sa.update(domains).values(data).where(domains.c.name == name) + ) + return await simple_db_mutate_returning_item(cls, ctx, update_query, item_cls=Domain) + + +class DeleteDomain(graphene.Mutation): + """ + Instead of deleting the domain, just mark it as inactive. + """ + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + name = graphene.String(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + + @classmethod + async def mutate(cls, root, info: graphene.ResolveInfo, name: str) -> DeleteDomain: + ctx: GraphQueryContext = info.context + update_query = ( + sa.update(domains) + .values(is_active=False) + .where(domains.c.name == name) + ) + return await simple_db_mutate(cls, ctx, update_query) + + +class PurgeDomain(graphene.Mutation): + """ + Completely delete domain from DB. + + Domain-bound kernels will also be all deleted. + To purge domain, there should be no users and groups in the target domain. + """ + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + name = graphene.String(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + + @classmethod + async def mutate(cls, root, info: graphene.ResolveInfo, name: str) -> PurgeDomain: + from . import users, groups + ctx: GraphQueryContext = info.context + + async def _pre_func(conn: SAConnection) -> None: + if await cls.domain_has_active_kernels(conn, name): + raise RuntimeError("Domain has some active kernels. Terminate them first.") + query = ( + sa.select([sa.func.count()]) + .where(users.c.domain_name == name) + ) + user_count = await conn.scalar(query) + if user_count > 0: + raise RuntimeError("There are users bound to the domain. 
Remove users first.") + query = ( + sa.select([sa.func.count()]) + .where(groups.c.domain_name == name) + ) + group_count = await conn.scalar(query) + if group_count > 0: + raise RuntimeError("There are groups bound to the domain. Remove groups first.") + + await cls.delete_kernels(conn, name) + + delete_query = (sa.delete(domains).where(domains.c.name == name)) + return await simple_db_mutate(cls, ctx, delete_query, pre_func=_pre_func) + + @classmethod + async def delete_kernels( + cls, + conn: SAConnection, + domain_name: str, + ) -> int: + """ + Delete all kernels run from the target domain. + + :param conn: DB connection + :param domain_name: domain's name to delete kernels + + :return: number of deleted rows + """ + from . import kernels + delete_query = ( + sa.delete(kernels) + .where(kernels.c.domain_name == domain_name) + ) + result = await conn.execute(delete_query) + if result.rowcount > 0: + log.info("deleted {0} domain\"s kernels ({1})", result.rowcount, domain_name) + return result.rowcount + + @classmethod + async def domain_has_active_kernels( + cls, + conn: SAConnection, + domain_name: str, + ) -> bool: + """ + Check if the domain does not have active kernels. + + :param conn: DB connection + :param domain_name: domain's name + + :return: True if the domain has some active kernels. + """ + from . import kernels, AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES + query = ( + sa.select([sa.func.count()]) + .select_from(kernels) + .where((kernels.c.domain_name == domain_name) & + (kernels.c.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES))) + ) + active_kernel_count = await conn.scalar(query) + return (active_kernel_count > 0) + + +class DomainDotfile(TypedDict): + data: str + path: str + perm: str + + +async def query_domain_dotfiles( + conn: SAConnection, + name: str, +) -> tuple[List[DomainDotfile], int]: + query = ( + sa.select([domains.c.dotfiles]) + .select_from(domains) + .where(domains.c.name == name) + ) + packed_dotfile = await conn.scalar(query) + if packed_dotfile is None: + return [], MAXIMUM_DOTFILE_SIZE + rows = msgpack.unpackb(packed_dotfile) + return rows, MAXIMUM_DOTFILE_SIZE - len(packed_dotfile) + + +def verify_dotfile_name(dotfile: str) -> bool: + if dotfile in RESERVED_DOTFILES: + return False + return True diff --git a/src/ai/backend/manager/models/dotfile.py b/src/ai/backend/manager/models/dotfile.py new file mode 100644 index 0000000000..911de78231 --- /dev/null +++ b/src/ai/backend/manager/models/dotfile.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from pathlib import PurePosixPath +from typing import Any, Mapping, Sequence, TYPE_CHECKING + +import sqlalchemy as sa +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import ( + AsyncConnection as SAConnection, + ) + +from ai.backend.common import msgpack +from ai.backend.common.types import VFolderMount + +from ..api.exceptions import BackendError +from ..types import UserScope +from .keypair import keypairs +from .domain import query_domain_dotfiles +from .group import query_group_dotfiles + +__all__ = ( + 'prepare_dotfiles', +) + + +async def prepare_dotfiles( + conn: SAConnection, + user_scope: UserScope, + access_key: str, + vfolder_mounts: Sequence[VFolderMount], +) -> Mapping[str, Any]: + # Feed SSH keypair and dotfiles if exists. 
+ internal_data = {} + query = ( + sa.select([ + keypairs.c.ssh_public_key, + keypairs.c.ssh_private_key, + keypairs.c.dotfiles, + ]) + .select_from(keypairs) + .where(keypairs.c.access_key == access_key) + ) + result = await conn.execute(query) + row = result.first() + dotfiles = msgpack.unpackb(row['dotfiles']) + internal_data.update({'dotfiles': dotfiles}) + if row['ssh_public_key'] and row['ssh_private_key']: + internal_data['ssh_keypair'] = { + 'public_key': row['ssh_public_key'], + 'private_key': row['ssh_private_key'], + } + # use dotfiles in the priority of keypair > group > domain + dotfile_paths = set(map(lambda x: x['path'], dotfiles)) + # add keypair dotfiles + internal_data.update({'dotfiles': list(dotfiles)}) + # add group dotfiles + dotfiles, _ = await query_group_dotfiles(conn, user_scope.group_id) + for dotfile in dotfiles: + if dotfile['path'] not in dotfile_paths: + internal_data['dotfiles'].append(dotfile) + dotfile_paths.add(dotfile['path']) + # add domain dotfiles + dotfiles, _ = await query_domain_dotfiles(conn, user_scope.domain_name) + for dotfile in dotfiles: + if dotfile['path'] not in dotfile_paths: + internal_data['dotfiles'].append(dotfile) + dotfile_paths.add(dotfile['path']) + # reverse the dotfiles list so that higher priority can overwrite + # in case the actual path is the same + internal_data['dotfiles'].reverse() + + # check if there is no name conflict of dotfile and vfolder + vfolder_kernel_paths = {m.kernel_path for m in vfolder_mounts} + for dotfile in internal_data.get('dotfiles', []): + dotfile_path = PurePosixPath(dotfile['path']) + if not dotfile_path.is_absolute(): + dotfile_path = PurePosixPath('/home/work', dotfile['path']) + if dotfile_path in vfolder_kernel_paths: + raise BackendError( + f"There is a kernel-side path from vfolders that conflicts with " + f"a dotfile '{dotfile['path']}'.", + ) + + return internal_data diff --git a/src/ai/backend/manager/models/error_logs.py b/src/ai/backend/manager/models/error_logs.py new file mode 100644 index 0000000000..8e8bb9e3bb --- /dev/null +++ b/src/ai/backend/manager/models/error_logs.py @@ -0,0 +1,26 @@ +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from .base import metadata, IDColumn, GUID +__all__ = [ + 'error_logs', +] + +error_logs = sa.Table( + 'error_logs', metadata, + IDColumn(), + sa.Column('created_at', sa.DateTime(timezone=True), + server_default=sa.func.now(), index=True), + sa.Column('severity', sa.Enum('critical', 'error', 'warning', name='errorlog_severity'), + index=True), + sa.Column('source', sa.String), + sa.Column('user', GUID, sa.ForeignKey('users.uuid'), nullable=True, index=True), + sa.Column('is_read', sa.Boolean, default=False, index=True), + sa.Column('is_cleared', sa.Boolean, default=False, index=True), + sa.Column('message', sa.Text), + sa.Column('context_lang', sa.String), + sa.Column('context_env', postgresql.JSONB()), + sa.Column('request_url', sa.String, nullable=True), + sa.Column('request_status', sa.Integer, nullable=True), + sa.Column('traceback', sa.Text, nullable=True), +) diff --git a/src/ai/backend/manager/models/gql.py b/src/ai/backend/manager/models/gql.py new file mode 100644 index 0000000000..7dc8ee7543 --- /dev/null +++ b/src/ai/backend/manager/models/gql.py @@ -0,0 +1,1294 @@ +from __future__ import annotations + +from typing import Any, Optional, Mapping, Sequence, TYPE_CHECKING +import uuid + +import attr +import graphene + +from ai.backend.manager.defs import DEFAULT_IMAGE_ARCH + +if TYPE_CHECKING: + from 
graphql.execution.executors.asyncio import AsyncioExecutor + + from ai.backend.common.bgtask import BackgroundTaskManager + from ai.backend.common.etcd import AsyncEtcd + from ai.backend.common.types import ( + AccessKey, + AgentId, + RedisConnectionInfo, + SlotName, + SlotTypes, + SessionId, + ) + + from ..api.manager import ManagerStatus + from ..config import LocalConfig, SharedConfig + from ..registry import AgentRegistry + from ..models.utils import ExtendedAsyncSAEngine + from .storage import StorageSessionManager + +from .base import DataLoaderManager, privileged_query, scoped_query +from .agent import ( + Agent, + AgentList, + ModifyAgent, +) +from .domain import ( + Domain, + CreateDomain, + ModifyDomain, + DeleteDomain, + PurgeDomain, +) +from .group import ( + Group, + CreateGroup, + ModifyGroup, + DeleteGroup, + PurgeGroup, +) +from .image import ( + ClearImages, + Image, + ModifyImage, + RescanImages, + PreloadImage, + UnloadImage, + ForgetImage, + AliasImage, + DealiasImage, +) +from .kernel import ( + ComputeSession, + ComputeSessionList, + ComputeContainer, + ComputeContainerList, + LegacyComputeSession, + LegacyComputeSessionList, +) +from .keypair import ( + KeyPair, + KeyPairList, + CreateKeyPair, + ModifyKeyPair, + DeleteKeyPair, +) +from .resource_policy import ( + KeyPairResourcePolicy, + CreateKeyPairResourcePolicy, + ModifyKeyPairResourcePolicy, + DeleteKeyPairResourcePolicy, +) +from .resource_preset import ( + ResourcePreset, + CreateResourcePreset, + ModifyResourcePreset, + DeleteResourcePreset, +) +from .scaling_group import ( + ScalingGroup, + CreateScalingGroup, + ModifyScalingGroup, + DeleteScalingGroup, + AssociateScalingGroupWithDomain, + DisassociateScalingGroupWithDomain, + DisassociateAllScalingGroupsWithDomain, + AssociateScalingGroupWithUserGroup, + DisassociateScalingGroupWithUserGroup, + DisassociateAllScalingGroupsWithGroup, + AssociateScalingGroupWithKeyPair, + DisassociateScalingGroupWithKeyPair, +) +from .storage import ( + StorageVolume, + StorageVolumeList, +) +from .user import ( + User, + UserList, + CreateUser, + ModifyUser, + DeleteUser, + PurgeUser, + UserRole, + UserStatus, +) +from .vfolder import ( + VirtualFolder, + VirtualFolderList, +) +from ..api.exceptions import ( + ObjectNotFound, + ImageNotFound, + InsufficientPrivilege, + InvalidAPIParameters, + TooManyKernelsFound, +) + + +@attr.s(auto_attribs=True, slots=True) +class GraphQueryContext: + schema: graphene.Schema + dataloader_manager: DataLoaderManager + local_config: LocalConfig + shared_config: SharedConfig + etcd: AsyncEtcd + user: Mapping[str, Any] # TODO: express using typed dict + access_key: str + db: ExtendedAsyncSAEngine + redis_stat: RedisConnectionInfo + redis_image: RedisConnectionInfo + manager_status: ManagerStatus + known_slot_types: Mapping[SlotName, SlotTypes] + background_task_manager: BackgroundTaskManager + storage_manager: StorageSessionManager + registry: AgentRegistry + + +class Mutations(graphene.ObjectType): + """ + All available GraphQL mutations. 
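# --- Editorial note (not part of the patch) ----------------------------------
# GraphQueryContext above is the per-request bundle that every resolver and
# mutation below receives via `info.context`.  The admin API handler elsewhere
# in this patch builds a graphene.Schema from the Queries/Mutations classes and
# supplies an instance of this context at execution time, roughly:
#
#     schema = graphene.Schema(query=Queries, mutation=Mutations,
#                              auto_camelcase=False)
#     # ... schema.execute(query_text, context_value=gql_ctx, ...) per request
#
# (The exact execution call, including the AsyncioExecutor wiring, lives in the
# manager API layer and is not reproduced here.)
# ------------------------------------------------------------------------------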
+ """ + + # super-admin only + modify_agent = ModifyAgent.Field() + + # super-admin only + create_domain = CreateDomain.Field() + modify_domain = ModifyDomain.Field() + delete_domain = DeleteDomain.Field() + purge_domain = PurgeDomain.Field() + + # admin only + create_group = CreateGroup.Field() + modify_group = ModifyGroup.Field() + delete_group = DeleteGroup.Field() + purge_group = PurgeGroup.Field() + + # super-admin only + create_user = CreateUser.Field() + modify_user = ModifyUser.Field() + delete_user = DeleteUser.Field() + purge_user = PurgeUser.Field() + + # admin only + create_keypair = CreateKeyPair.Field() + modify_keypair = ModifyKeyPair.Field() + delete_keypair = DeleteKeyPair.Field() + + # admin only + rescan_images = RescanImages.Field() + preload_image = PreloadImage.Field() + unload_image = UnloadImage.Field() + modify_image = ModifyImage.Field() + forget_image = ForgetImage.Field() + alias_image = AliasImage.Field() + dealias_image = DealiasImage.Field() + clear_images = ClearImages.Field() + + # super-admin only + create_keypair_resource_policy = CreateKeyPairResourcePolicy.Field() + modify_keypair_resource_policy = ModifyKeyPairResourcePolicy.Field() + delete_keypair_resource_policy = DeleteKeyPairResourcePolicy.Field() + + # super-admin only + create_resource_preset = CreateResourcePreset.Field() + modify_resource_preset = ModifyResourcePreset.Field() + delete_resource_preset = DeleteResourcePreset.Field() + + # super-admin only + create_scaling_group = CreateScalingGroup.Field() + modify_scaling_group = ModifyScalingGroup.Field() + delete_scaling_group = DeleteScalingGroup.Field() + associate_scaling_group_with_domain = AssociateScalingGroupWithDomain.Field() + associate_scaling_group_with_user_group = AssociateScalingGroupWithUserGroup.Field() + associate_scaling_group_with_keypair = AssociateScalingGroupWithKeyPair.Field() + disassociate_scaling_group_with_domain = DisassociateScalingGroupWithDomain.Field() + disassociate_scaling_group_with_user_group = DisassociateScalingGroupWithUserGroup.Field() + disassociate_scaling_group_with_keypair = DisassociateScalingGroupWithKeyPair.Field() + disassociate_all_scaling_groups_with_domain = DisassociateAllScalingGroupsWithDomain.Field() + disassociate_all_scaling_groups_with_group = DisassociateAllScalingGroupsWithGroup.Field() + + +class Queries(graphene.ObjectType): + """ + All available GraphQL queries. + """ + + # super-admin only + agent = graphene.Field( + Agent, + agent_id=graphene.String(required=True), + ) + + # super-admin only + agent_list = graphene.Field( + AgentList, + limit=graphene.Int(required=True), + offset=graphene.Int(required=True), + filter=graphene.String(), + order=graphene.String(), + # filters + scaling_group=graphene.String(), + status=graphene.String(), + ) + + # super-admin only + agents = graphene.List( # legacy non-paginated list + Agent, + scaling_group=graphene.String(), + status=graphene.String(), + ) + + domain = graphene.Field( + Domain, + name=graphene.String(), + ) + + # super-admin only + domains = graphene.List( + Domain, + is_active=graphene.Boolean(), + ) + + group = graphene.Field( + Group, + id=graphene.UUID(required=True), + domain_name=graphene.String(), + ) + + # Within a single domain, this will always return nothing or a single item, + # but if queried across all domains by superadmins, it may return multiple results + # because the group name is unique only inside each domain. 
+ groups_by_name = graphene.List( + Group, + name=graphene.String(required=True), + domain_name=graphene.String(), + ) + + groups = graphene.List( + Group, + domain_name=graphene.String(), + is_active=graphene.Boolean(), + ) + + image = graphene.Field( + Image, + reference=graphene.String(required=True), + architecture=graphene.String(default_value=DEFAULT_IMAGE_ARCH), + ) + + images = graphene.List( + Image, + is_installed=graphene.Boolean(), + is_operation=graphene.Boolean(), + ) + + user = graphene.Field( + User, + domain_name=graphene.String(), + email=graphene.String(), + ) + + user_from_uuid = graphene.Field( + User, + domain_name=graphene.String(), + user_id=graphene.ID(), + ) + + users = graphene.List( # legacy non-paginated list + User, + domain_name=graphene.String(), + group_id=graphene.UUID(), + is_active=graphene.Boolean(), + status=graphene.String(), + ) + + user_list = graphene.Field( + UserList, + limit=graphene.Int(required=True), + offset=graphene.Int(required=True), + filter=graphene.String(), + order=graphene.String(), + # intrinsic filters + domain_name=graphene.String(), + group_id=graphene.UUID(), + is_active=graphene.Boolean(), + status=graphene.String(), + ) + + keypair = graphene.Field( + KeyPair, + domain_name=graphene.String(), + access_key=graphene.String(), + ) + + keypairs = graphene.List( # legacy non-paginated list + KeyPair, + domain_name=graphene.String(), + email=graphene.String(), + is_active=graphene.Boolean(), + ) + + keypair_list = graphene.Field( + KeyPairList, + limit=graphene.Int(required=True), + offset=graphene.Int(required=True), + filter=graphene.String(), + order=graphene.String(), + # intrinsic filters + domain_name=graphene.String(), + email=graphene.String(), + is_active=graphene.Boolean(), + ) + + # NOTE: maybe add keypairs_from_user_id? 
+ + keypair_resource_policy = graphene.Field( + KeyPairResourcePolicy, + name=graphene.String(), + ) + + keypair_resource_policies = graphene.List( + KeyPairResourcePolicy) + + resource_preset = graphene.Field( + ResourcePreset, + name=graphene.String(), + ) + + resource_presets = graphene.List( + ResourcePreset, + ) + + # super-admin only + scaling_group = graphene.Field( + ScalingGroup, + name=graphene.String(), + ) + + # super-admin only + scaling_groups = graphene.List( + ScalingGroup, + name=graphene.String(), + is_active=graphene.Boolean(), + ) + + # super-admin only + scaling_groups_for_domain = graphene.List( + ScalingGroup, + domain=graphene.String(required=True), + is_active=graphene.Boolean(), + ) + + # super-admin only + scaling_groups_for_user_group = graphene.List( + ScalingGroup, + user_group=graphene.String(required=True), + is_active=graphene.Boolean(), + ) + + # super-admin only + scaling_groups_for_keypair = graphene.List( + ScalingGroup, + access_key=graphene.String(required=True), + is_active=graphene.Boolean(), + ) + + # super-admin only + storage_volume = graphene.Field( + StorageVolume, + id=graphene.String(), + ) + + # super-admin only + storage_volume_list = graphene.Field( + StorageVolumeList, + limit=graphene.Int(required=True), + offset=graphene.Int(required=True), + filter=graphene.String(), + order=graphene.String(), + ) + + vfolder_list = graphene.Field( # legacy non-paginated list + VirtualFolderList, + limit=graphene.Int(required=True), + offset=graphene.Int(required=True), + filter=graphene.String(), + order=graphene.String(), + # intrinsic filters + domain_name=graphene.String(), + group_id=graphene.UUID(), + access_key=graphene.String(), # must be empty for user requests + ) + + vfolders = graphene.List( # legacy non-paginated list + VirtualFolder, + domain_name=graphene.String(), + group_id=graphene.String(), + access_key=graphene.String(), # must be empty for user requests + ) + + compute_session = graphene.Field( + ComputeSession, + id=graphene.UUID(required=True), + ) + + compute_container = graphene.Field( + ComputeContainer, + id=graphene.UUID(required=True), + ) + + compute_session_list = graphene.Field( + ComputeSessionList, + limit=graphene.Int(required=True), + offset=graphene.Int(required=True), + filter=graphene.String(), + order=graphene.String(), + # intrinsic filters + domain_name=graphene.String(), + group_id=graphene.String(), + access_key=graphene.String(), + status=graphene.String(), + ) + + compute_container_list = graphene.Field( + ComputeContainerList, + limit=graphene.Int(required=True), + offset=graphene.Int(required=True), + filter=graphene.String(), + order=graphene.String(), + # intrinsic filters + session_id=graphene.ID(required=True), + role=graphene.String(), + ) + + legacy_compute_session_list = graphene.Field( + LegacyComputeSessionList, + limit=graphene.Int(required=True), + offset=graphene.Int(required=True), + # legacy ordering + order_key=graphene.String(), + order_asc=graphene.Boolean(), + # intrinsic filters + domain_name=graphene.String(), + group_id=graphene.String(), + access_key=graphene.String(), + status=graphene.String(), + ) + + legacy_compute_session = graphene.Field( + LegacyComputeSession, + sess_id=graphene.String(required=True), + domain_name=graphene.String(), + access_key=graphene.String(), + ) + + @staticmethod + @privileged_query(UserRole.SUPERADMIN) + async def resolve_agent( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + agent_id: AgentId, + ) -> Agent: + ctx: GraphQueryContext = 
info.context + loader = ctx.dataloader_manager.get_loader( + ctx, + 'Agent', + raw_status=None, + ) + return await loader.load(agent_id) + + @staticmethod + @privileged_query(UserRole.SUPERADMIN) + async def resolve_agents( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + *, + scaling_group: str = None, + status: str = None, + ) -> Sequence[Agent]: + return await Agent.load_all( + info.context, + scaling_group=scaling_group, + raw_status=status, + ) + + @staticmethod + @privileged_query(UserRole.SUPERADMIN) + async def resolve_agent_list( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + limit: int, + offset: int, + *, + filter: str = None, + order: str = None, + scaling_group: str = None, + status: str = None, + ) -> AgentList: + total_count = await Agent.load_count( + info.context, + scaling_group=scaling_group, + raw_status=status, + filter=filter, + ) + agent_list = await Agent.load_slice( + info.context, limit, offset, + scaling_group=scaling_group, + raw_status=status, + filter=filter, + order=order, + ) + return AgentList(agent_list, total_count) + + @staticmethod + async def resolve_domain( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, *, + name: str = None, + ) -> Domain: + ctx: GraphQueryContext = info.context + name = ctx.user['domain_name'] if name is None else name + if ctx.user['role'] != UserRole.SUPERADMIN: + if name != ctx.user['domain_name']: + # prevent querying other domains if not superadmin + raise ObjectNotFound(object_name='domain') + loader = ctx.dataloader_manager.get_loader(ctx, 'Domain.by_name') + return await loader.load(name) + + @staticmethod + @privileged_query(UserRole.SUPERADMIN) + async def resolve_domains( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + *, + is_active: bool = None, + ) -> Sequence[Domain]: + return await Domain.load_all(info.context, is_active=is_active) + + @staticmethod + async def resolve_group( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + id: uuid.UUID, + *, + domain_name: str = None, + ) -> Group: + ctx: GraphQueryContext = info.context + client_role = ctx.user['role'] + client_domain = ctx.user['domain_name'] + client_user_id = ctx.user['uuid'] + if client_role == UserRole.SUPERADMIN: + loader = ctx.dataloader_manager.get_loader( + ctx, 'Group.by_id', domain_name=domain_name, + ) + group = await loader.load(id) + elif client_role == UserRole.ADMIN: + if domain_name is not None and domain_name != client_domain: + raise InsufficientPrivilege + loader = ctx.dataloader_manager.get_loader( + ctx, 'Group.by_id', domain_name=client_domain, + ) + group = await loader.load(id) + elif client_role == UserRole.USER: + if domain_name is not None and domain_name != client_domain: + raise InsufficientPrivilege + loader = ctx.dataloader_manager.get_loader( + ctx, 'Group.by_id', domain_name=client_domain, + ) + group = await loader.load(id) + loader = ctx.dataloader_manager.get_loader( + ctx, 'Group.by_user', + ) + client_groups = await loader.load(client_user_id) + if group.id not in (g.id for g in client_groups): + raise InsufficientPrivilege + else: + raise InvalidAPIParameters('Unknown client role') + return group + + @staticmethod + async def resolve_groups_by_name( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + name: str, + *, + domain_name: str = None, + ) -> Sequence[Group]: + ctx: GraphQueryContext = info.context + client_role = ctx.user['role'] + client_domain = ctx.user['domain_name'] + client_user_id = ctx.user['uuid'] + if client_role == UserRole.SUPERADMIN: 
+ loader = ctx.dataloader_manager.get_loader( + ctx, 'Group.by_name', domain_name=domain_name, + ) + groups = await loader.load(name) + elif client_role == UserRole.ADMIN: + if domain_name is not None and domain_name != client_domain: + raise InsufficientPrivilege + loader = ctx.dataloader_manager.get_loader( + ctx, 'Group.by_name', domain_name=client_domain, + ) + groups = await loader.load(name) + elif client_role == UserRole.USER: + if domain_name is not None and domain_name != client_domain: + raise InsufficientPrivilege + loader = ctx.dataloader_manager.get_loader( + ctx, 'Group.by_name', domain_name=client_domain, + ) + groups = await loader.load(name) + loader = ctx.dataloader_manager.get_loader( + ctx, 'Group.by_user', + ) + client_groups = await loader.load(client_user_id) + client_group_ids = set(g.id for g in client_groups) + groups = filter(lambda g: g.id in client_group_ids, groups) + else: + raise InvalidAPIParameters('Unknown client role') + return groups + + @staticmethod + async def resolve_groups( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + *, + domain_name: str = None, + is_active: bool = None, + ) -> Sequence[Group]: + ctx: GraphQueryContext = info.context + client_role = ctx.user['role'] + client_domain = ctx.user['domain_name'] + client_user_id = ctx.user['uuid'] + if client_role == UserRole.SUPERADMIN: + pass + elif client_role == UserRole.ADMIN: + if domain_name is not None and domain_name != client_domain: + raise InsufficientPrivilege + domain_name = client_domain + elif client_role == UserRole.USER: + loader = ctx.dataloader_manager.get_loader( + ctx, 'Group.by_user', + ) + client_groups = await loader.load(client_user_id) + return client_groups + else: + raise InvalidAPIParameters('Unknown client role') + return await Group.load_all( + info.context, + domain_name=domain_name, + is_active=is_active) + + @staticmethod + async def resolve_image( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + reference: str, + architecture: str, + ) -> Image: + ctx: GraphQueryContext = info.context + client_role = ctx.user['role'] + client_domain = ctx.user['domain_name'] + item = await Image.load_item(info.context, reference, architecture) + if client_role == UserRole.SUPERADMIN: + pass + elif client_role in (UserRole.ADMIN, UserRole.USER): + items = await Image.filter_allowed(info.context, [item], client_domain) + if not items: + raise ImageNotFound + item = items[0] + else: + raise InvalidAPIParameters('Unknown client role') + return item + + @staticmethod + async def resolve_images( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + *, + is_installed=None, + is_operation=False, + ) -> Sequence[Image]: + ctx: GraphQueryContext = info.context + client_role = ctx.user['role'] + client_domain = ctx.user['domain_name'] + items = await Image.load_all(ctx, is_installed=is_installed, is_operation=is_operation) + if client_role == UserRole.SUPERADMIN: + pass + elif client_role in (UserRole.ADMIN, UserRole.USER): + items = await Image.filter_allowed( + info.context, + items, + client_domain, + is_installed=is_installed, + is_operation=is_operation, + ) + else: + raise InvalidAPIParameters('Unknown client role') + return items + + @staticmethod + @scoped_query(autofill_user=True, user_key='email') + async def resolve_user( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + *, + domain_name: str = None, + email: str = None, + ) -> User: + ctx: GraphQueryContext = info.context + loader = ctx.dataloader_manager.get_loader( + ctx, 
'User.by_email', domain_name=domain_name, + ) + return await loader.load(email) + + @staticmethod + @scoped_query(autofill_user=True, user_key='user_id') + async def resolve_user_from_uuid( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + *, + domain_name: str = None, + user_id: uuid.UUID | str | None = None, + ) -> User: + ctx: GraphQueryContext = info.context + loader = ctx.dataloader_manager.get_loader( + ctx, 'User.by_uuid', domain_name=domain_name, + ) + # user_id is retrieved as string since it's a GraphQL's generic ID field + user_uuid = uuid.UUID(user_id) if isinstance(user_id, str) else user_id + return await loader.load(user_uuid) + + @staticmethod + async def resolve_users( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + *, + domain_name: str = None, + group_id: uuid.UUID = None, + is_active: bool = None, + status: UserStatus = None, + ) -> Sequence[User]: + from .user import UserRole + ctx: GraphQueryContext = info.context + client_role = ctx.user['role'] + client_domain = ctx.user['domain_name'] + if client_role == UserRole.SUPERADMIN: + pass + elif client_role == UserRole.ADMIN: + if domain_name is not None and domain_name != client_domain: + raise InsufficientPrivilege + domain_name = client_domain + elif client_role == UserRole.USER: + # Users cannot query other users. + raise InsufficientPrivilege() + else: + raise InvalidAPIParameters('Unknown client role') + return await User.load_all( + info.context, + domain_name=domain_name, + group_id=group_id, + is_active=is_active, + status=status, + limit=100) + + @staticmethod + async def resolve_user_list( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + limit: int, + offset: int, + *, + filter: str = None, + order: str = None, + domain_name: str = None, + group_id: uuid.UUID = None, + is_active: bool = None, + status: UserStatus = None, + ) -> UserList: + from .user import UserRole + ctx: GraphQueryContext = info.context + client_role = ctx.user['role'] + client_domain = ctx.user['domain_name'] + if client_role == UserRole.SUPERADMIN: + pass + elif client_role == UserRole.ADMIN: + if domain_name is not None and domain_name != client_domain: + raise InsufficientPrivilege + domain_name = client_domain + elif client_role == UserRole.USER: + # Users cannot query other users. 
+ raise InsufficientPrivilege() + else: + raise InvalidAPIParameters('Unknown client role') + total_count = await User.load_count( + info.context, + domain_name=domain_name, + group_id=group_id, + is_active=is_active, + status=status, + filter=filter, + ) + user_list = await User.load_slice( + info.context, + limit, + offset, + domain_name=domain_name, + group_id=group_id, + is_active=is_active, + status=status, + filter=filter, + order=order, + ) + return UserList(user_list, total_count) + + @staticmethod + @scoped_query(autofill_user=True, user_key='access_key') + async def resolve_keypair( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + *, + domain_name: str = None, + access_key: AccessKey = None, + ) -> KeyPair: + ctx: GraphQueryContext = info.context + loader = ctx.dataloader_manager.get_loader( + ctx, + 'KeyPair.by_ak', + domain_name=domain_name, + ) + return await loader.load(access_key) + + @staticmethod + @scoped_query(autofill_user=False, user_key='email') + async def resolve_keypairs( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + *, + domain_name: str = None, + email: str = None, + is_active: bool = None, + ) -> Sequence[KeyPair]: + ctx: GraphQueryContext = info.context + if email is None: + return await KeyPair.load_all( + info.context, + domain_name=domain_name, + is_active=is_active, + limit=100, + ) + else: + loader = ctx.dataloader_manager.get_loader( + ctx, + 'KeyPair.by_email', + domain_name=domain_name, + is_active=is_active, + ) + return await loader.load(email) + + @staticmethod + @scoped_query(autofill_user=False, user_key='email') + async def resolve_keypair_list( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + limit: int, + offset: int, + *, + filter: str = None, + order: str = None, + domain_name: str = None, + email: str = None, + is_active: bool = None, + ) -> KeyPairList: + total_count = await KeyPair.load_count( + info.context, + domain_name=domain_name, + email=email, + is_active=is_active, + filter=filter, + ) + keypair_list = await KeyPair.load_slice( + info.context, + limit, + offset, + domain_name=domain_name, + email=email, + is_active=is_active, + filter=filter, + order=order, + ) + return KeyPairList(keypair_list, total_count) + + @staticmethod + async def resolve_keypair_resource_policy( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + name: str = None, + ) -> KeyPairResourcePolicy: + ctx: GraphQueryContext = info.context + client_access_key = ctx.access_key + if name is None: + loader = ctx.dataloader_manager.get_loader( + ctx, 'KeyPairResourcePolicy.by_ak', + ) + return await loader.load(client_access_key) + else: + loader = ctx.dataloader_manager.get_loader( + ctx, 'KeyPairResourcePolicy.by_name', + ) + return await loader.load(name) + + @staticmethod + async def resolve_keypair_resource_policies( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + ) -> Sequence[KeyPairResourcePolicy]: + ctx: GraphQueryContext = info.context + client_role = ctx.user['role'] + client_access_key = ctx.access_key + if client_role == UserRole.SUPERADMIN: + return await KeyPairResourcePolicy.load_all(info.context) + elif client_role == UserRole.ADMIN: + # TODO: filter resource policies by domains? 
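# Editor's note (descriptive comment, not part of the original patch): this
# ADMIN branch currently returns every keypair resource policy, the same as
# the SUPERADMIN branch above; USER-role requests (next branch) are narrowed
# to the requester's own access key via load_all_user().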
+ return await KeyPairResourcePolicy.load_all(info.context) + elif client_role == UserRole.USER: + return await KeyPairResourcePolicy.load_all_user( + info.context, client_access_key, + ) + else: + raise InvalidAPIParameters('Unknown client role') + + @staticmethod + async def resolve_resource_preset( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + name: str, + ) -> ResourcePreset: + ctx: GraphQueryContext = info.context + loader = ctx.dataloader_manager.get_loader(ctx, 'ResourcePreset.by_name') + return await loader.load(name) + + @staticmethod + async def resolve_resource_presets( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + ) -> Sequence[ResourcePreset]: + return await ResourcePreset.load_all(info.context) + + @staticmethod + @privileged_query(UserRole.SUPERADMIN) + async def resolve_scaling_group( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + name: str, + ) -> ScalingGroup: + ctx: GraphQueryContext = info.context + loader = ctx.dataloader_manager.get_loader( + ctx, 'ScalingGroup.by_name', + ) + return await loader.load(name) + + @staticmethod + @privileged_query(UserRole.SUPERADMIN) + async def resolve_scaling_groups( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + is_active: bool = None, + ) -> Sequence[ScalingGroup]: + return await ScalingGroup.load_all(info.context, is_active=is_active) + + @staticmethod + @privileged_query(UserRole.SUPERADMIN) + async def resolve_scaling_groups_for_domain( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + domain: str, + is_active: bool = None, + ) -> Sequence[ScalingGroup]: + return await ScalingGroup.load_by_domain( + info.context, domain, is_active=is_active, + ) + + @staticmethod + @privileged_query(UserRole.SUPERADMIN) + async def resolve_scaling_groups_for_group( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + user_group, + is_active: bool = None, + ) -> Sequence[ScalingGroup]: + return await ScalingGroup.load_by_group( + info.context, user_group, is_active=is_active, + ) + + @staticmethod + @privileged_query(UserRole.SUPERADMIN) + async def resolve_scaling_groups_for_keypair( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + access_key: AccessKey, + is_active: bool = None, + ) -> Sequence[ScalingGroup]: + return await ScalingGroup.load_by_keypair( + info.context, access_key, is_active=is_active, + ) + + @staticmethod + @privileged_query(UserRole.SUPERADMIN) + async def resolve_storage_volume( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + id: str, + ) -> StorageVolume: + return await StorageVolume.load_by_id(info.context, id) + + @staticmethod + @privileged_query(UserRole.SUPERADMIN) + async def resolve_storage_volume_list( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + limit: int, + offset: int, + *, + filter: str = None, + order: str = None, + ) -> StorageVolumeList: + total_count = await StorageVolume.load_count( + info.context, + filter=filter, + ) + items = await StorageVolume.load_slice( + info.context, + limit, + offset, + filter=filter, + order=order, + ) + return StorageVolumeList(items, total_count) + + @staticmethod + @scoped_query(autofill_user=False, user_key='user_id') + async def resolve_vfolder_list( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + limit: int, + offset: int, + *, + domain_name: str = None, + group_id: uuid.UUID = None, + user_id: uuid.UUID = None, + filter: str = None, + order: str = None, + ) -> VirtualFolderList: + # TODO: adopt the generic queryfilter language + total_count = 
await VirtualFolder.load_count( + info.context, + domain_name=domain_name, # scope + group_id=group_id, # scope + user_id=user_id, # scope + filter=filter, + ) + items = await VirtualFolder.load_slice( + info.context, + limit, + offset, + domain_name=domain_name, # scope + group_id=group_id, # scope + user_id=user_id, # scope + filter=filter, + order=order, + ) + return VirtualFolderList(items, total_count) + + @staticmethod + @scoped_query(autofill_user=False, user_key='access_key') + async def resolve_compute_container_list( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + limit: int, + offset: int, + *, + filter: str = None, + order: str = None, + session_id: SessionId, + role: UserRole = None, + domain_name: str = None, + group_id: uuid.UUID = None, + access_key: AccessKey = None, + ) -> ComputeContainerList: + # TODO: adopt the generic queryfilter language + total_count = await ComputeContainer.load_count( + info.context, + session_id, # filter (mandatory) + cluster_role=role, # filter + domain_name=domain_name, # scope + group_id=group_id, # scope + access_key=access_key, # scope + filter=filter, + ) + items = await ComputeContainer.load_slice( + info.context, + limit, offset, # slice + session_id, # filter (mandatory) + cluster_role=role, # filter + domain_name=domain_name, # scope + group_id=group_id, # scope + access_key=access_key, # scope + filter=filter, + order=order, + ) + return ComputeContainerList(items, total_count) + + @staticmethod + @scoped_query(autofill_user=False, user_key='access_key') + async def resolve_compute_container( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + container_id: str, + ) -> ComputeContainer: + # We need to check the group membership of the designated kernel, + # but practically a user cannot guess the IDs of kernels launched + # by other users and in other groups. + # Let's just protect the domain/user boundary here. + graph_ctx: GraphQueryContext = info.context + loader = graph_ctx.dataloader_manager.get_loader(graph_ctx, 'ComputeContainer.detail') + return await loader.load(container_id) + + @staticmethod + @scoped_query(autofill_user=False, user_key='access_key') + async def resolve_compute_session_list( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + limit: int, + offset: int, + *, + filter: str = None, + order: str = None, + domain_name: str = None, + group_id: uuid.UUID = None, + access_key: AccessKey = None, + status: str = None, + ) -> ComputeSessionList: + total_count = await ComputeSession.load_count( + info.context, + status=status, # filter + domain_name=domain_name, # scope + group_id=group_id, # scope + access_key=access_key, # scope + filter=filter, + ) + items = await ComputeSession.load_slice( + info.context, + limit, offset, # slice + status=status, # filter + domain_name=domain_name, # scope + group_id=group_id, # scope + access_key=access_key, # scope + filter=filter, + order=order, + ) + return ComputeSessionList(items, total_count) + + @staticmethod + @scoped_query(autofill_user=False, user_key='access_key') + async def resolve_compute_session( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + id: SessionId, + *, + domain_name: str = None, + access_key: AccessKey = None, + ) -> ComputeSession: + # We need to check the group membership of the designated kernel, + # but practically a user cannot guess the IDs of kernels launched + # by other users and in other groups. + # Let's just protect the domain/user boundary here. 
+ graph_ctx: GraphQueryContext = info.context + loader = graph_ctx.dataloader_manager.get_loader( + graph_ctx, + 'ComputeSession.detail', + domain_name=domain_name, + access_key=access_key, + ) + return await loader.load(id) + + @staticmethod + @scoped_query(autofill_user=False, user_key='access_key') + async def resolve_legacy_compute_session_list( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + limit: int, + offset: int, + *, + domain_name: str = None, + group_id: uuid.UUID = None, + access_key: AccessKey = None, + status: str = None, + order_key: str = None, + order_asc: bool = True, + ) -> LegacyComputeSessionList: + total_count = await LegacyComputeSession.load_count( + info.context, + domain_name=domain_name, + group_id=group_id, + access_key=access_key, + status=status, + ) + items = await LegacyComputeSession.load_slice( + info.context, + limit, + offset, + domain_name=domain_name, + group_id=group_id, + access_key=access_key, + status=status, + order_key=order_key, + order_asc=order_asc, + ) + return LegacyComputeSessionList(items, total_count) + + @staticmethod + @scoped_query(autofill_user=False, user_key='access_key') + async def resolve_legacy_compute_session( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + sess_id: str, + *, + domain_name: str = None, + access_key: AccessKey = None, + status: str = None, + ) -> Optional[LegacyComputeSession]: + # We need to check the group membership of the designated kernel, + # but practically a user cannot guess the IDs of kernels launched + # by other users and in other groups. + # Let's just protect the domain/user boundary here. + graph_ctx: GraphQueryContext = info.context + loader = graph_ctx.dataloader_manager.get_loader( + graph_ctx, + 'LegacyComputeSession.detail', + domain_name=domain_name, + access_key=access_key, + status=status, + ) + matches = await loader.load(sess_id) + if len(matches) == 0: + return None + elif len(matches) == 1: + return matches[0] + else: + raise TooManyKernelsFound + + +class GQLMutationPrivilegeCheckMiddleware: + + def resolve(self, next, root, info: graphene.ResolveInfo, **args) -> Any: + graph_ctx: GraphQueryContext = info.context + if info.operation.operation == 'mutation' and len(info.path) == 1: + mutation_cls = getattr(Mutations, info.path[0]).type + # default is allow nobody. 
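# Editor's note (descriptive comment, not part of the original patch): each
# mutation class opts in by declaring an `allowed_roles` tuple, e.g.
# CreateGroup/ModifyGroup/DeleteGroup below declare
# (UserRole.ADMIN, UserRole.SUPERADMIN); a mutation class without the
# attribute is therefore rejected for every role.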
+ allowed_roles = getattr(mutation_cls, 'allowed_roles', []) + if graph_ctx.user['role'] not in allowed_roles: + return mutation_cls(False, f"no permission to execute {info.path[0]}") + return next(root, info, **args) diff --git a/src/ai/backend/manager/models/group.py b/src/ai/backend/manager/models/group.py new file mode 100644 index 0000000000..93fc4348c3 --- /dev/null +++ b/src/ai/backend/manager/models/group.py @@ -0,0 +1,665 @@ +from __future__ import annotations + +import asyncio +import logging +import re +from typing import ( + Any, + Dict, + Optional, + Sequence, TYPE_CHECKING, + TypedDict, + Union, +) +import uuid + +import aiohttp +import graphene +from graphene.types.datetime import DateTime as GQLDateTime +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql as pgsql +from sqlalchemy.engine.row import Row +from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection + +from ai.backend.common import msgpack +from ai.backend.common.logging import BraceStyleAdapter +from ai.backend.common.types import ResourceSlot + +from ..api.exceptions import VFolderOperationFailed +from ..defs import RESERVED_DOTFILES +from .base import ( + metadata, GUID, IDColumn, ResourceSlotColumn, + privileged_mutation, + set_if_set, + simple_db_mutate, + simple_db_mutate_returning_item, + batch_result, + batch_multiresult, +) +from .storage import StorageSessionManager +from .user import ModifyUserInput, UserRole +from .utils import execute_with_retry + +if TYPE_CHECKING: + from .gql import GraphQueryContext + from .scaling_group import ScalingGroup + +log = BraceStyleAdapter(logging.getLogger(__file__)) + + +__all__: Sequence[str] = ( + 'groups', 'association_groups_users', + 'resolve_group_name_or_id', + 'Group', 'GroupInput', 'ModifyGroupInput', + 'CreateGroup', 'ModifyGroup', 'DeleteGroup', + 'GroupDotfile', 'MAXIMUM_DOTFILE_SIZE', + 'query_group_dotfiles', + 'query_group_domain', + 'verify_dotfile_name', +) + +MAXIMUM_DOTFILE_SIZE = 64 * 1024 # 61 KiB +_rx_slug = re.compile(r'^[a-zA-Z0-9]([a-zA-Z0-9._-]*[a-zA-Z0-9])?$') + +association_groups_users = sa.Table( + 'association_groups_users', metadata, + sa.Column('user_id', GUID, + sa.ForeignKey('users.uuid', onupdate='CASCADE', ondelete='CASCADE'), + nullable=False), + sa.Column('group_id', GUID, + sa.ForeignKey('groups.id', onupdate='CASCADE', ondelete='CASCADE'), + nullable=False), + sa.UniqueConstraint('user_id', 'group_id', name='uq_association_user_id_group_id'), +) + + +groups = sa.Table( + 'groups', metadata, + IDColumn('id'), + sa.Column('name', sa.String(length=64), nullable=False), + sa.Column('description', sa.String(length=512)), + sa.Column('is_active', sa.Boolean, default=True), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column('modified_at', sa.DateTime(timezone=True), + server_default=sa.func.now(), onupdate=sa.func.current_timestamp()), + #: Field for synchronization with external services. + sa.Column('integration_id', sa.String(length=512)), + + sa.Column('domain_name', sa.String(length=64), + sa.ForeignKey('domains.name', onupdate='CASCADE', ondelete='CASCADE'), + nullable=False, index=True), + # TODO: separate resource-related fields with new domain resource policy table when needed. 
+ sa.Column('total_resource_slots', ResourceSlotColumn(), default='{}'), + sa.Column('allowed_vfolder_hosts', pgsql.ARRAY(sa.String), nullable=False, default='{}'), + sa.UniqueConstraint('name', 'domain_name', name='uq_groups_name_domain_name'), + # dotfiles column, \x90 means empty list in msgpack + sa.Column('dotfiles', sa.LargeBinary(length=MAXIMUM_DOTFILE_SIZE), nullable=False, default=b'\x90'), +) + + +async def resolve_group_name_or_id( + db_conn: SAConnection, + domain_name: str, + value: Union[str, uuid.UUID], +) -> Optional[uuid.UUID]: + if isinstance(value, str): + query = ( + sa.select([groups.c.id]) + .select_from(groups) + .where( + (groups.c.name == value) & + (groups.c.domain_name == domain_name), + ) + ) + return await db_conn.scalar(query) + elif isinstance(value, uuid.UUID): + query = ( + sa.select([groups.c.id]) + .select_from(groups) + .where( + (groups.c.id == value) & + (groups.c.domain_name == domain_name), + ) + ) + return await db_conn.scalar(query) + else: + raise TypeError('unexpected type for group_name_or_id') + + +class Group(graphene.ObjectType): + id = graphene.UUID() + name = graphene.String() + description = graphene.String() + is_active = graphene.Boolean() + created_at = GQLDateTime() + modified_at = GQLDateTime() + domain_name = graphene.String() + total_resource_slots = graphene.JSONString() + allowed_vfolder_hosts = graphene.List(lambda: graphene.String) + integration_id = graphene.String() + + scaling_groups = graphene.List(lambda: graphene.String) + + @classmethod + def from_row(cls, graph_ctx: GraphQueryContext, row: Row) -> Optional[Group]: + if row is None: + return None + return cls( + id=row['id'], + name=row['name'], + description=row['description'], + is_active=row['is_active'], + created_at=row['created_at'], + modified_at=row['modified_at'], + domain_name=row['domain_name'], + total_resource_slots=row['total_resource_slots'].to_json(), + allowed_vfolder_hosts=row['allowed_vfolder_hosts'], + integration_id=row['integration_id'], + ) + + async def resolve_scaling_groups(self, info: graphene.ResolveInfo) -> Sequence[ScalingGroup]: + graph_ctx: GraphQueryContext = info.context + loader = graph_ctx.dataloader_manager.get_loader( + graph_ctx, "ScalingGroup.by_group", + ) + sgroups = await loader.load(self.id) + return [sg.name for sg in sgroups] + + @classmethod + async def load_all( + cls, + graph_ctx: GraphQueryContext, + *, + domain_name: str = None, + is_active: bool = None, + ) -> Sequence[Group]: + query = ( + sa.select([groups]) + .select_from(groups) + ) + if domain_name is not None: + query = query.where(groups.c.domain_name == domain_name) + if is_active is not None: + query = query.where(groups.c.is_active == is_active) + async with graph_ctx.db.begin_readonly() as conn: + return [ + obj async for row in (await conn.stream(query)) + if (obj := cls.from_row(graph_ctx, row)) is not None + ] + + @classmethod + async def batch_load_by_id( + cls, + graph_ctx: GraphQueryContext, + group_ids: Sequence[uuid.UUID], + *, + domain_name: str = None, + ) -> Sequence[Group | None]: + query = ( + sa.select([groups]) + .select_from(groups) + .where(groups.c.id.in_(group_ids)) + ) + if domain_name is not None: + query = query.where(groups.c.domain_name == domain_name) + async with graph_ctx.db.begin_readonly() as conn: + return await batch_result( + graph_ctx, conn, query, cls, + group_ids, lambda row: row['id'], + ) + + @classmethod + async def batch_load_by_name( + cls, + graph_ctx: GraphQueryContext, + group_names: Sequence[str], + *, + domain_name: 
str = None, + ) -> Sequence[Sequence[Group | None]]: + query = ( + sa.select([groups]) + .select_from(groups) + .where(groups.c.name.in_(group_names)) + ) + if domain_name is not None: + query = query.where(groups.c.domain_name == domain_name) + async with graph_ctx.db.begin_readonly() as conn: + return await batch_multiresult( + graph_ctx, conn, query, cls, + group_names, lambda row: row['name'], + ) + + @classmethod + async def batch_load_by_user( + cls, + graph_ctx: GraphQueryContext, + user_ids: Sequence[uuid.UUID], + ) -> Sequence[Sequence[Group | None]]: + j = sa.join( + groups, association_groups_users, + groups.c.id == association_groups_users.c.group_id, + ) + query = ( + sa.select([groups, association_groups_users.c.user_id]) + .select_from(j) + .where(association_groups_users.c.user_id.in_(user_ids)) + ) + async with graph_ctx.db.begin_readonly() as conn: + return await batch_multiresult( + graph_ctx, conn, query, cls, + user_ids, lambda row: row['user_id'], + ) + + @classmethod + async def get_groups_for_user( + cls, + graph_ctx: GraphQueryContext, + user_id: uuid.UUID, + ) -> Sequence[Group]: + j = sa.join( + groups, association_groups_users, + groups.c.id == association_groups_users.c.group_id, + ) + query = ( + sa.select([groups]) + .select_from(j) + .where(association_groups_users.c.user_id == user_id) + ) + async with graph_ctx.db.begin_readonly() as conn: + return [ + obj async for row in (await conn.stream(query)) + if (obj := cls.from_row(graph_ctx, row)) is not None + ] + + +class GroupInput(graphene.InputObjectType): + description = graphene.String(required=False) + is_active = graphene.Boolean(required=False, default=True) + domain_name = graphene.String(required=True) + total_resource_slots = graphene.JSONString(required=False) + allowed_vfolder_hosts = graphene.List(lambda: graphene.String, required=False) + integration_id = graphene.String(required=False) + + +class ModifyGroupInput(graphene.InputObjectType): + name = graphene.String(required=False) + description = graphene.String(required=False) + is_active = graphene.Boolean(required=False) + domain_name = graphene.String(required=False) + total_resource_slots = graphene.JSONString(required=False) + user_update_mode = graphene.String(required=False) + user_uuids = graphene.List(lambda: graphene.String, required=False) + allowed_vfolder_hosts = graphene.List(lambda: graphene.String, required=False) + integration_id = graphene.String(required=False) + + +class CreateGroup(graphene.Mutation): + + allowed_roles = (UserRole.ADMIN, UserRole.SUPERADMIN) + + class Arguments: + name = graphene.String(required=True) + props = GroupInput(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + group = graphene.Field(lambda: Group, required=False) + + @classmethod + @privileged_mutation( + UserRole.ADMIN, + lambda name, props, **kwargs: (props.domain_name, None), + ) + async def mutate( + cls, + root, + info: graphene.ResolveInfo, + name: str, + props: GroupInput, + ) -> CreateGroup: + if _rx_slug.search(name) is None: + raise ValueError('invalid name format. 
slug format required.') + graph_ctx: GraphQueryContext = info.context + data = { + 'name': name, + 'description': props.description, + 'is_active': props.is_active, + 'domain_name': props.domain_name, + 'total_resource_slots': ResourceSlot.from_user_input( + props.total_resource_slots, None), + 'allowed_vfolder_hosts': props.allowed_vfolder_hosts, + 'integration_id': props.integration_id, + } + insert_query = ( + sa.insert(groups).values(data) + ) + return await simple_db_mutate_returning_item(cls, graph_ctx, insert_query, item_cls=Group) + + +class ModifyGroup(graphene.Mutation): + + allowed_roles = (UserRole.ADMIN, UserRole.SUPERADMIN) + + class Arguments: + gid = graphene.UUID(required=True) + props = ModifyGroupInput(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + group = graphene.Field(lambda: Group, required=False) + + @classmethod + @privileged_mutation( + UserRole.ADMIN, + lambda gid, **kwargs: (None, gid), + ) + async def mutate( + cls, + root, + info: graphene.ResolveInfo, + gid: uuid.UUID, + props: ModifyUserInput, + ) -> ModifyGroup: + graph_ctx: GraphQueryContext = info.context + data: Dict[str, Any] = {} + set_if_set(props, data, 'name') + set_if_set(props, data, 'description') + set_if_set(props, data, 'is_active') + set_if_set(props, data, 'domain_name') + set_if_set(props, data, 'total_resource_slots', + clean_func=lambda v: ResourceSlot.from_user_input(v, None)) + set_if_set(props, data, 'allowed_vfolder_hosts') + set_if_set(props, data, 'integration_id') + + if 'name' in data and _rx_slug.search(data['name']) is None: + raise ValueError('invalid name format. slug format required.') + if props.user_update_mode not in (None, 'add', 'remove'): + raise ValueError('invalid user_update_mode') + if not props.user_uuids: + props.user_update_mode = None + if not data and props.user_update_mode is None: + return cls(ok=False, msg='nothing to update', group=None) + + async def _do_mutate() -> ModifyGroup: + async with graph_ctx.db.begin() as conn: + # TODO: refactor user addition/removal in groups as separate mutations + # (to apply since 21.09) + if props.user_update_mode == 'add': + values = [{'user_id': uuid, 'group_id': gid} for uuid in props.user_uuids] + await conn.execute( + sa.insert(association_groups_users).values(values), + ) + elif props.user_update_mode == 'remove': + await conn.execute( + sa.delete(association_groups_users) + .where( + (association_groups_users.c.user_id.in_(props.user_uuids)) & + (association_groups_users.c.group_id == gid), + ), + ) + if data: + result = await conn.execute( + sa.update(groups) + .values(data) + .where(groups.c.id == gid) + .returning(groups), + ) + if result.rowcount > 0: + o = Group.from_row(graph_ctx, result.first()) + return cls(ok=True, msg='success', group=o) + return cls(ok=False, msg='no such group', group=None) + else: # updated association_groups_users table + return cls(ok=True, msg='success', group=None) + + try: + return await execute_with_retry(_do_mutate) + except sa.exc.IntegrityError as e: + return cls(ok=False, msg=f'integrity error: {e}', group=None) + except (asyncio.CancelledError, asyncio.TimeoutError): + raise + except Exception as e: + return cls(ok=False, msg=f'unexpected error: {e}', group=None) + + +class DeleteGroup(graphene.Mutation): + """ + Instead of deleting the group, just mark it as inactive. 
+ """ + allowed_roles = (UserRole.ADMIN, UserRole.SUPERADMIN) + + class Arguments: + gid = graphene.UUID(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + + @classmethod + @privileged_mutation( + UserRole.ADMIN, + lambda gid, **kwargs: (None, gid), + ) + async def mutate(cls, root, info: graphene.ResolveInfo, gid: uuid.UUID) -> DeleteGroup: + ctx: GraphQueryContext = info.context + update_query = ( + sa.update(groups).values( + is_active=False, + integration_id=None, + ).where(groups.c.id == gid) + ) + return await simple_db_mutate(cls, ctx, update_query) + + +class PurgeGroup(graphene.Mutation): + """ + Completely deletes a group from DB. + + Group's vfolders and their data will also be lost + as well as the kernels run from the group. + There is no migration of the ownership for group folders. + """ + allowed_roles = (UserRole.ADMIN, UserRole.SUPERADMIN) + + class Arguments: + gid = graphene.UUID(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + + @classmethod + @privileged_mutation( + UserRole.ADMIN, + lambda gid, **kwargs: (None, gid), + ) + async def mutate(cls, root, info: graphene.ResolveInfo, gid: uuid.UUID) -> PurgeGroup: + graph_ctx: GraphQueryContext = info.context + + async def _pre_func(conn: SAConnection) -> None: + if await cls.group_vfolder_mounted_to_active_kernels(conn, gid): + raise RuntimeError( + "Some of virtual folders that belong to this group " + "are currently mounted to active sessions. " + "Terminate them first to proceed removal.", + ) + if await cls.group_has_active_kernels(conn, gid): + raise RuntimeError( + "Group has some active session. " + "Terminate them first to proceed removal.", + ) + await cls.delete_vfolders(conn, gid, graph_ctx.storage_manager) + await cls.delete_kernels(conn, gid) + + delete_query = sa.delete(groups).where(groups.c.id == gid) + return await simple_db_mutate(cls, graph_ctx, delete_query, pre_func=_pre_func) + + @classmethod + async def delete_vfolders( + cls, + db_conn: SAConnection, + group_id: uuid.UUID, + storage_manager: StorageSessionManager, + ) -> int: + """ + Delete group's all virtual folders as well as their physical data. + + :param conn: DB connection + :param group_id: group's UUID to delete virtual folders + + :return: number of deleted rows + """ + from . import vfolders + query = ( + sa.select([vfolders.c.id, vfolders.c.host]) + .select_from(vfolders) + .where(vfolders.c.group == group_id) + ) + result = await db_conn.execute(query) + target_vfs = result.fetchall() + delete_query = (sa.delete(vfolders).where(vfolders.c.group == group_id)) + result = await db_conn.execute(delete_query) + for row in target_vfs: + try: + async with storage_manager.request( + row['host'], 'POST', 'folder/delete', + json={ + 'volume': storage_manager.split_host(row['host'])[1], + 'vfid': str(row['id']), + }, + raise_for_status=True, + ): + pass + except aiohttp.ClientResponseError: + log.error('error on deleting vfolder filesystem directory: {0}', row['id']) + raise VFolderOperationFailed + if result.rowcount > 0: + log.info('deleted {0} group\'s virtual folders ({1})', result.rowcount, group_id) + return result.rowcount + + @classmethod + async def delete_kernels( + cls, + db_conn: SAConnection, + group_id: uuid.UUID, + ) -> int: + """ + Delete all kernels run from the target groups. + + :param conn: DB connection + :param group_id: group's UUID to delete kernels + + :return: number of deleted rows + """ + from . 
import kernels + query = ( + sa.delete(kernels) + .where(kernels.c.group_id == group_id) + ) + result = await db_conn.execute(query) + if result.rowcount > 0: + log.info('deleted {0} group\'s kernels ({1})', result.rowcount, group_id) + return result.rowcount + + @classmethod + async def group_vfolder_mounted_to_active_kernels( + cls, + db_conn: SAConnection, + group_id: uuid.UUID, + ) -> bool: + """ + Check if no active kernel is using the group's virtual folders. + + :param conn: DB connection + :param group_id: group's ID + + :return: True if a virtual folder is mounted to active kernels. + """ + from . import kernels, vfolders, AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES + query = ( + sa.select([vfolders.c.id]) + .select_from(vfolders) + .where(vfolders.c.group == group_id) + ) + result = await db_conn.execute(query) + rows = result.fetchall() + group_vfolder_ids = [row['id'] for row in rows] + query = ( + sa.select([kernels.c.mounts]) + .select_from(kernels) + .where( + (kernels.c.group_id == group_id) & + (kernels.c.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES)), + ) + ) + async for row in (await db_conn.stream(query)): + for _mount in row['mounts']: + try: + vfolder_id = uuid.UUID(_mount[2]) + if vfolder_id in group_vfolder_ids: + return True + except Exception: + pass + return False + + @classmethod + async def group_has_active_kernels( + cls, + db_conn: SAConnection, + group_id: uuid.UUID, + ) -> bool: + """ + Check if the group does not have active kernels. + + :param conn: DB connection + :param group_id: group's UUID + + :return: True if the group has some active kernels. + """ + from . import kernels, AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES + query = ( + sa.select([sa.func.count()]) + .select_from(kernels) + .where((kernels.c.group_id == group_id) & + (kernels.c.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES))) + ) + active_kernel_count = await db_conn.scalar(query) + return True if active_kernel_count > 0 else False + + +class GroupDotfile(TypedDict): + data: str + path: str + perm: str + + +async def query_group_dotfiles( + db_conn: SAConnection, + group_id: Union[GUID, uuid.UUID], +) -> tuple[list[GroupDotfile], int]: + query = ( + sa.select([groups.c.dotfiles]) + .select_from(groups) + .where(groups.c.id == group_id) + ) + packed_dotfile = await db_conn.scalar(query) + if packed_dotfile is None: + return [], MAXIMUM_DOTFILE_SIZE + rows = msgpack.unpackb(packed_dotfile) + return rows, MAXIMUM_DOTFILE_SIZE - len(packed_dotfile) + + +async def query_group_domain( + db_conn: SAConnection, + group_id: Union[GUID, uuid.UUID], +) -> str: + query = ( + sa.select([groups.c.domain_name]) + .select_from(groups) + .where(groups.c.id == group_id) + ) + domain = await db_conn.scalar(query) + return domain + + +def verify_dotfile_name(dotfile: str) -> bool: + if dotfile in RESERVED_DOTFILES: + return False + return True diff --git a/src/ai/backend/manager/models/image.py b/src/ai/backend/manager/models/image.py new file mode 100644 index 0000000000..2c6494bf1b --- /dev/null +++ b/src/ai/backend/manager/models/image.py @@ -0,0 +1,910 @@ +from __future__ import annotations + +from decimal import Decimal +import enum +import functools +import logging +from pathlib import Path +from typing import ( + Any, + List, + Mapping, + MutableMapping, + Optional, + Sequence, + TYPE_CHECKING, + Tuple, + Union, +) +import aiotools + +import graphene +from graphql.execution.executors.asyncio import AsyncioExecutor +import sqlalchemy as sa +from sqlalchemy.ext.asyncio import AsyncSession +from 
sqlalchemy.orm import ( + relationship, + selectinload, +) +import trafaret as t +import yaml + +from ai.backend.common import redis +from ai.backend.common.docker import ImageRef +from ai.backend.common.etcd import AsyncEtcd +from ai.backend.common.exception import UnknownImageReference +from ai.backend.common.logging import BraceStyleAdapter +from ai.backend.common.types import ( + BinarySize, + ImageAlias, + ResourceSlot, +) + +from ai.backend.manager.container_registry import get_container_registry +from ai.backend.manager.api.exceptions import ImageNotFound +from ai.backend.manager.defs import DEFAULT_IMAGE_ARCH + +from .base import ( + BigInt, ForeignKeyIDColumn, IDColumn, + KVPair, ResourceLimit, KVPairInput, ResourceLimitInput, + Base, StructuredJSONColumn, set_if_set, +) +from .user import UserRole +from .utils import ExtendedAsyncSAEngine + +if TYPE_CHECKING: + from ai.backend.common.bgtask import ProgressReporter + from ai.backend.manager.config import SharedConfig + + from .gql import GraphQueryContext + +log = BraceStyleAdapter(logging.getLogger(__name__)) + +__all__ = ( + 'rescan_images', + 'update_aliases_from_file', + 'ImageType', + 'ImageAliasRow', + 'ImageRow', + 'Image', + 'PreloadImage', + 'RescanImages', + 'ForgetImage', + 'ModifyImage', + 'AliasImage', + 'DealiasImage', + 'ClearImages', +) + + +async def rescan_images( + etcd: AsyncEtcd, + db: ExtendedAsyncSAEngine, + registry: str = None, + *, + reporter: ProgressReporter = None, +) -> None: + # cannot import ai.backend.manager.config at start due to circular import + from ai.backend.manager.config import container_registry_iv + + registry_config_iv = t.Mapping(t.String, container_registry_iv) + latest_registry_config = registry_config_iv.check( + await etcd.get_prefix('config/docker/registry'), + ) + # TODO: delete images from registries removed from the previous config? + if registry is None: + # scan all configured registries + registries = latest_registry_config + else: + try: + registries = {registry: latest_registry_config[registry]} + except KeyError: + raise RuntimeError("It is an unknown registry.", registry) + async with aiotools.TaskGroup() as tg: + for registry_name, registry_info in registries.items(): + log.info('Scanning kernel images from the registry "{0}"', registry_name) + scanner_cls = get_container_registry(registry_info) + scanner = scanner_cls(db, registry_name, registry_info) + tg.create_task(scanner.rescan_single_registry(reporter)) + # TODO: delete images removed from registry? 
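# --- Editor's note: illustrative example, not part of the original patch. ---
# update_aliases_from_file() below parses a YAML document whose top-level
# "aliases" key holds [alias, target, architecture] entries; the architecture
# element is intended to fall back to DEFAULT_IMAGE_ARCH when omitted (see the
# log message in the loop).  A minimal input file might look like this (the
# image names are hypothetical):
#
#   aliases:
#     - ["python", "cr.backend.ai/stable/python:3.9-ubuntu20.04", "x86_64"]
#     - ["python-arm", "cr.backend.ai/stable/python:3.9-ubuntu20.04", "aarch64"]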
+ + +async def update_aliases_from_file(session: AsyncSession, file: Path) -> List[ImageAliasRow]: + log.info('Updating image aliases from "{0}"', file) + ret: List[ImageAliasRow] = [] + try: + data = yaml.safe_load(open(file, 'r', encoding='utf-8')) + except IOError: + log.error('Cannot open "{0}".', file) + return [] + for item in data['aliases']: + alias = item[0] + target = item[1] + if len(item) >= 2: + architecture = item[2] + else: + log.warn( + 'architecture not set for {} => {}, assuming as {}', + target, alias, DEFAULT_IMAGE_ARCH) + architecture = DEFAULT_IMAGE_ARCH + try: + image_row = await ImageRow.from_image_ref( + session, ImageRef(target, ['*'], architecture), + ) + image_alias = ImageAliasRow( + alias=alias, + image=image_row, + ) + # let user call session.begin() + session.add(image_alias) + ret.append(image_alias) + print(f'{alias} -> {image_row.image_ref}') + except UnknownImageReference: + print(f'{alias} -> target image not found') + log.info('Done.') + return ret + + +class ImageType(enum.Enum): + COMPUTE = 'compute' + SYSTEM = 'system' + SERVICE = 'service' + + +class ImageRow(Base): + __tablename__ = 'images' + id = IDColumn('id') + name = sa.Column('name', sa.String, nullable=False, index=True) + image = sa.Column('image', sa.String, nullable=False, index=True) + created_at = sa.Column( + 'created_at', sa.DateTime(timezone=True), + server_default=sa.func.now(), index=True, + ) + tag = sa.Column('tag', sa.TEXT) + registry = sa.Column('registry', sa.String, nullable=False, index=True) + architecture = sa.Column('architecture', sa.String, nullable=False, index=True, default='x86_64') + config_digest = sa.Column('config_digest', sa.CHAR(length=72), nullable=False) + size_bytes = sa.Column('size_bytes', sa.BigInteger, nullable=False) + type = sa.Column('type', sa.Enum(ImageType), nullable=False) + accelerators = sa.Column('accelerators', sa.String) + labels = sa.Column('labels', sa.JSON, nullable=False) + resources = sa.Column('resources', StructuredJSONColumn( + t.Mapping( + t.String, + t.Dict({ + t.Key('min'): t.String, + t.Key('max', default=None): t.Null | t.String, + }), + ), + ), nullable=False) + aliases: relationship + + def __init__( + self, + name, + architecture, + registry=None, + image=None, + tag=None, + config_digest=None, + size_bytes=None, + type=None, + accelerators=None, + labels=None, + resources=None, + ) -> None: + self.name = name + self.registry = registry + self.image = image + self.tag = tag + self.architecture = architecture + self.config_digest = config_digest + self.size_bytes = size_bytes + self.type = type + self.accelerators = accelerators + self.labels = labels + self.resources = resources + + @property + def image_ref(self): + return ImageRef(self.name, [self.registry], self.architecture) + + @classmethod + async def from_alias( + cls, + session: AsyncSession, + alias: str, + load_aliases=False, + ) -> ImageRow: + query = ( + sa.select(ImageRow).select_from(ImageRow) + .join(ImageAliasRow, ImageRow.aliases.and_(ImageAliasRow.alias == alias)) + ) + if load_aliases: + query = query.options(selectinload(ImageRow.aliases)) + result = await session.scalar(query) + if result is not None: + return result + else: + raise UnknownImageReference + + @classmethod + async def from_image_ref( + cls, + session: AsyncSession, + ref: ImageRef, + *, + strict_arch: bool = False, + load_aliases: bool = False, + ) -> ImageRow: + """ + Loads a image row that corresponds to the given ImageRef object. 
+ + When *strict_arch* is False and the image table has only one row + with respect to requested canonical, this function will + return that row regardless of the image architecture. + """ + query = sa.select(ImageRow).where(ImageRow.name == ref.canonical) + if load_aliases: + query = query.options(selectinload(ImageRow.aliases)) + + result = await session.execute(query) + candidates: List[ImageRow] = result.scalars().all() + + if len(candidates) == 0: + raise UnknownImageReference(ref) + if len(candidates) == 1 and not strict_arch: + return candidates[0] + for row in candidates: + if row.architecture == ref.architecture: + return row + raise UnknownImageReference(ref) + + @classmethod + async def resolve( + cls, + session: AsyncSession, + reference_candidates: List[Union[ImageAlias, ImageRef]], + *, + strict_arch: bool = False, + load_aliases: bool = True, + ) -> ImageRow: + """ + Resolves a matching row in the image table from image references and/or aliases. + If candidate element is `ImageRef`, this method will try to resolve image with matching + `ImageRef` description. Otherwise, if element is `str`, this will try to follow the alias. + If multiple elements are supplied, this method will return the first matched `ImageRow` + among those elements. + Passing the canonical image reference as string directly to resolve image data + is no longer possible. You need to declare ImageRef object explicitly if you're using string + as an canonical image references. For example: + .. code-block:: + await ImageRow.resolve( + conn, + [ + ImageRef( + image, + registry, + architecture, + ), + image_alias, + ], + ) + + When *strict_arch* is False and the image table has only one row + with respect to requested canonical, this function will + return that row regardless of the image architecture. + + When *load_aliases* is True, it tries to resolve the alias chain. + Otherwise it finds only the direct image references. 
+ """ + searched_refs = [] + for reference in reference_candidates: + resolver_func: Any = None + if isinstance(reference, str): + resolver_func = cls.from_alias + searched_refs.append(f"alias:{reference!r}") + elif isinstance(reference, ImageRef): + resolver_func = functools.partial(cls.from_image_ref, strict_arch=strict_arch) + searched_refs.append(f"ref:{reference.canonical!r}") + try: + if (row := await resolver_func(session, reference, load_aliases=load_aliases)): + return row + except UnknownImageReference: + continue + raise ImageNotFound("Unkown image references: " + ", ".join(searched_refs)) + + @classmethod + async def list(cls, session: AsyncSession, load_aliases=False) -> List[ImageRow]: + query = sa.select(ImageRow) + if load_aliases: + query = query.options(selectinload(ImageRow.aliases)) + result = await session.execute(query) + return result.scalars().all() + + def __str__(self) -> str: + return self.image_ref.canonical + f' ({self.image_ref.architecture})' + + def __repr__(self) -> str: + return self.__str__() + + async def get_slot_ranges( + self, + shared_config: SharedConfig, + ) -> Tuple[ResourceSlot, ResourceSlot]: + slot_units = await shared_config.get_resource_slots() + min_slot = ResourceSlot() + max_slot = ResourceSlot() + + for slot_key, resource in self.resources.items(): + slot_unit = slot_units.get(slot_key) + if slot_unit is None: + # ignore unknown slots + continue + min_value = resource.get('min') + if min_value is None: + min_value = Decimal(0) + max_value = resource.get('max') + if max_value is None: + max_value = Decimal('Infinity') + if slot_unit == 'bytes': + if not isinstance(min_value, Decimal): + min_value = BinarySize.from_str(min_value) + if not isinstance(max_value, Decimal): + max_value = BinarySize.from_str(max_value) + else: + if not isinstance(min_value, Decimal): + min_value = Decimal(min_value) + if not isinstance(max_value, Decimal): + max_value = Decimal(max_value) + min_slot[slot_key] = min_value + max_slot[slot_key] = max_value + + # fill missing + for slot_key in slot_units.keys(): + if slot_key not in min_slot: + min_slot[slot_key] = Decimal(0) + if slot_key not in max_slot: + max_slot[slot_key] = Decimal('Infinity') + + return min_slot, max_slot + + def _parse_row(self): + res_limits = [] + for slot_key, slot_range in self.resources.items(): + min_value = slot_range.get('min') + if min_value is None: + min_value = Decimal(0) + max_value = slot_range.get('max') + if max_value is None: + max_value = Decimal('Infinity') + res_limits.append({ + 'key': slot_key, + 'min': min_value, + 'max': max_value, + }) + + accels = self.accelerators + if accels is None: + accels = [] + else: + accels = accels.split(',') + + return { + 'canonical_ref': self.name, + 'name': self.image, + 'humanized_name': self.image, # TODO: implement + 'tag': self.tag, + 'architecture': self.architecture, + 'registry': self.registry, + 'digest': self.config_digest, + 'labels': self.labels, + 'size_bytes': self.size_bytes, + 'resource_limits': res_limits, + 'supported_accelerators': accels, + } + + async def inspect(self) -> Mapping[str, Any]: + parsed_image_info = self._parse_row() + parsed_image_info['reverse_aliases'] = [x.alias for x in self.aliases] + return parsed_image_info + + def set_resource_limit( + self, slot_type: str, + value_range: Tuple[Optional[Decimal], Optional[Decimal]], + ): + resources = self.resources + if resources.get(slot_type) is None: + resources[slot_type] = {} + if value_range[0] is not None: + resources[slot_type]['min'] = 
str(value_range[0]) + if value_range[1] is not None: + resources[slot_type]['max'] = str(value_range[1]) + + self.resources = resources + + +class ImageAliasRow(Base): + __tablename__ = 'image_aliases' + id = IDColumn('id') + alias = sa.Column('alias', sa.String, unique=True, index=True) + image_id = ForeignKeyIDColumn('image', 'images.id', nullable=False) + image: relationship + + @classmethod + async def create( + cls, + session: AsyncSession, + alias: str, + target: ImageRow, + ) -> ImageAliasRow: + existing_alias: Optional[ImageRow] = await session.scalar( + sa.select(ImageAliasRow) + .where(ImageAliasRow.alias == alias) + .options(selectinload(ImageAliasRow.image)), + ) + if existing_alias is not None: + raise ValueError( + f'alias already created with ({existing_alias.image})', + ) + new_alias = ImageAliasRow( + alias=alias, + image_id=target.id, + ) + session.add_all([new_alias]) + return new_alias + + +ImageRow.aliases = relationship('ImageAliasRow', back_populates='image') +ImageAliasRow.image = relationship('ImageRow', back_populates='aliases') + + +class Image(graphene.ObjectType): + id = graphene.UUID() + name = graphene.String() + humanized_name = graphene.String() + tag = graphene.String() + registry = graphene.String() + architecture = graphene.String() + digest = graphene.String() + labels = graphene.List(KVPair) + aliases = graphene.List(graphene.String) + size_bytes = BigInt() + resource_limits = graphene.List(ResourceLimit) + supported_accelerators = graphene.List(graphene.String) + installed = graphene.Boolean() + installed_agents = graphene.List(graphene.String) + # legacy field + hash = graphene.String() + + @classmethod + async def from_row( + cls, + ctx: GraphQueryContext, + row: ImageRow, + ) -> Image: + # TODO: add architecture + installed = ( + await redis.execute(ctx.redis_image, lambda r: r.scard(row.name)) + ) > 0 + _installed_agents = await redis.execute( + ctx.redis_image, + lambda r: r.smembers(row.name), + ) + installed_agents: List[str] = [] + if installed_agents is not None: + for agent_id in _installed_agents: + if isinstance(agent_id, bytes): + installed_agents.append(agent_id.decode()) + else: + installed_agents.append(agent_id) + is_superadmin = (ctx.user['role'] == UserRole.SUPERADMIN) + hide_agents = False if is_superadmin else ctx.local_config['manager']['hide-agents'] + return cls( + id=row.id, + name=row.image, + humanized_name=row.image, + tag=row.tag, + registry=row.registry, + architecture=row.architecture, + digest=row.config_digest, + labels=[ + KVPair(key=k, value=v) + for k, v in row.labels.items()], + aliases=[alias_row.alias for alias_row in row.aliases], + size_bytes=row.size_bytes, + resource_limits=[ + ResourceLimit( + key=k, + min=v.get('min', Decimal(0)), + max=v.get('max', Decimal('Infinity')), + ) + for k, v in row.resources.items()], + supported_accelerators=(row.accelerators or '').split(','), + installed=installed, + installed_agents=installed_agents if not hide_agents else None, + # legacy + hash=row.config_digest, + ) + + @classmethod + async def batch_load_by_canonical( + cls, + graph_ctx: GraphQueryContext, + image_names: Sequence[str], + ) -> Sequence[Optional[Image]]: + query = ( + sa.select(ImageRow) + .where(ImageRow.name.in_(image_names)) + .options(selectinload(ImageRow.aliases)) + ) + async with graph_ctx.db.begin_readonly_session() as session: + result = await session.execute(query) + return [ + await Image.from_row(graph_ctx, row) + for row in result.scalars.all() + ] + + @classmethod + async def 
batch_load_by_image_ref( + cls, + graph_ctx: GraphQueryContext, + image_refs: Sequence[ImageRef], + ) -> Sequence[Optional[Image]]: + image_names = [x.canonical for x in image_refs] + return await cls.batch_load_by_canonical(graph_ctx, image_names) + + @classmethod + async def load_item( + cls, + ctx: GraphQueryContext, + reference: str, + architecture: str, + ) -> Image: + try: + async with ctx.db.begin_readonly_session() as session: + row = await ImageRow.resolve(session, [ + ImageRef(reference, ['*'], architecture), + ImageAlias(reference), + ]) + except UnknownImageReference: + raise ImageNotFound + return await cls.from_row(ctx, row) + + @classmethod + async def load_all( + cls, + ctx: GraphQueryContext, + *, + is_installed: bool = None, + is_operation: bool = None, + ) -> Sequence[Image]: + async with ctx.db.begin_readonly_session() as session: + rows = await ImageRow.list(session, load_aliases=True) + items: List[Image] = [] + # Convert to GQL objects + for r in rows: + item = await cls.from_row(ctx, r) + items.append(item) + if is_installed is not None: + items = [*filter(lambda item: item.installed == is_installed, items)] + if is_operation is not None: + def _filter_operation(item): + for label in item.labels: + if label.key == 'ai.backend.features' and 'operation' in label.value: + return not is_operation + return not is_operation + + items = [*filter(_filter_operation, items)] + return items + + @staticmethod + async def filter_allowed( + ctx: GraphQueryContext, + items: Sequence[Image], + domain_name: str, + *, + is_installed: bool = None, + is_operation: bool = None, + ) -> Sequence[Image]: + from .domain import domains + async with ctx.db.begin() as conn: + query = ( + sa.select([domains.c.allowed_docker_registries]) + .select_from(domains) + .where(domains.c.name == domain_name) + ) + result = await conn.execute(query) + allowed_docker_registries = result.scalar() + items = [ + item for item in items + if item.registry in allowed_docker_registries + ] + if is_installed is not None: + items = [*filter(lambda item: item.installed == is_installed, items)] + if is_operation is not None: + + def _filter_operation(item): + for label in item.labels: + if label.key == 'ai.backend.features' and 'operation' in label.value: + return not is_operation + return not is_operation + + items = [*filter(_filter_operation, items)] + return items + + +class PreloadImage(graphene.Mutation): + + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + references = graphene.List(graphene.String, required=True) + target_agents = graphene.List(graphene.String, required=True) + + ok = graphene.Boolean() + msg = graphene.String() + task_id = graphene.String() + + @staticmethod + async def mutate( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + references: Sequence[str], + target_agents: Sequence[str], + ) -> PreloadImage: + return PreloadImage(ok=False, msg='Not implemented.', task_id=None) + + +class UnloadImage(graphene.Mutation): + + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + references = graphene.List(graphene.String, required=True) + target_agents = graphene.List(graphene.String, required=True) + + ok = graphene.Boolean() + msg = graphene.String() + task_id = graphene.String() + + @staticmethod + async def mutate( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + references: Sequence[str], + target_agents: Sequence[str], + ) -> UnloadImage: + return UnloadImage(ok=False, msg='Not implemented.', task_id=None) + + +class 
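# (Editor's note) The is_operation filters above rely on the "ai.backend.features"
# image label: an image counts as an "operation" image when that label's value
# contains the substring "operation". A tiny illustration with KVPair objects:
_example_labels = [KVPair(key='ai.backend.features', value='batch operation')]
assert any(
    pair.key == 'ai.backend.features' and 'operation' in pair.value
    for pair in _example_labels
)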
RescanImages(graphene.Mutation): + + allowed_roles = (UserRole.ADMIN, UserRole.SUPERADMIN) + + class Arguments: + registry = graphene.String() + + ok = graphene.Boolean() + msg = graphene.String() + task_id = graphene.UUID() + + @staticmethod + async def mutate( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + registry: str = None, + ) -> RescanImages: + log.info('rescanning docker registry {0} by API request', + f'({registry})' if registry else '(all)') + ctx: GraphQueryContext = info.context + + async def _rescan_task(reporter: ProgressReporter) -> None: + await rescan_images(ctx.etcd, ctx.db, registry, reporter=reporter) + + task_id = await ctx.background_task_manager.start(_rescan_task) + return RescanImages(ok=True, msg='', task_id=task_id) + + +class ForgetImage(graphene.Mutation): + + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + reference = graphene.String(required=True) + architecture = graphene.String(default_value=DEFAULT_IMAGE_ARCH) + + ok = graphene.Boolean() + msg = graphene.String() + + @staticmethod + async def mutate( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + reference: str, + architecture: str, + ) -> ForgetImage: + log.info('forget image {0} by API request', reference) + ctx: GraphQueryContext = info.context + async with ctx.db.begin_session() as session: + image_row = await ImageRow.resolve(session, [ + ImageRef(reference, ['*'], architecture), + ImageAlias(reference), + ]) + await session.delete(image_row) + return ForgetImage(ok=True, msg='') + + +class AliasImage(graphene.Mutation): + + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + alias = graphene.String(required=True) + target = graphene.String(required=True) + architecture = graphene.String(default_value=DEFAULT_IMAGE_ARCH) + + ok = graphene.Boolean() + msg = graphene.String() + + @staticmethod + async def mutate( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + alias: str, + target: str, + architecture: str, + ) -> AliasImage: + image_ref = ImageRef(target, ['*'], architecture) + log.info('alias image {0} -> {1} by API request', alias, image_ref) + ctx: GraphQueryContext = info.context + try: + async with ctx.db.begin_session() as session: + try: + image_row = await ImageRow.from_image_ref(session, image_ref, load_aliases=True) + except UnknownImageReference: + raise ImageNotFound + else: + image_row.aliases.append(ImageAliasRow(alias=alias, image_id=image_row.id)) + except ValueError as e: + return AliasImage(ok=False, msg=str(e)) + return AliasImage(ok=True, msg='') + + +class DealiasImage(graphene.Mutation): + + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + alias = graphene.String(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + + @staticmethod + async def mutate( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + alias: str, + ) -> DealiasImage: + log.info('dealias image {0} by API request', alias) + ctx: GraphQueryContext = info.context + try: + async with ctx.db.begin_session() as session: + existing_alias = await session.scalar( + sa.select(ImageAliasRow) + .where(ImageAliasRow.alias == alias), + ) + if existing_alias is None: + raise DealiasImage(ok=False, msg=str('No such alias')) + await session.delete(existing_alias) + except ValueError as e: + return DealiasImage(ok=False, msg=str(e)) + return DealiasImage(ok=True, msg='') + + +class ClearImages(graphene.Mutation): + + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + registry = graphene.String() + + ok = 
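# (Editor's note) RescanImages above follows the manager's background-task pattern:
# wrap the long-running work in a coroutine taking a ProgressReporter, hand it to
# ctx.background_task_manager.start(), and return the resulting task ID so the client
# can poll its progress. A hypothetical reuse of the same pattern:
async def start_registry_rescan(ctx: GraphQueryContext, registry: Optional[str]):
    async def _task(reporter: ProgressReporter) -> None:
        await rescan_images(ctx.etcd, ctx.db, registry, reporter=reporter)
    return await ctx.background_task_manager.start(_task)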
graphene.Boolean() + msg = graphene.String() + + @staticmethod + async def mutate( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + registry: str, + ) -> ClearImages: + ctx: GraphQueryContext = info.context + try: + async with ctx.db.begin_session() as session: + result = await session.execute( + sa.select(ImageRow).where(ImageRow.registry == registry)) + image_ids = [x.id for x in result.scalars().all()] + + await session.execute( + sa.delete(ImageAliasRow).where(ImageAliasRow.image_id.in_(image_ids))) + await session.execute(sa.delete(ImageRow).where(ImageRow.registry == registry)) + except ValueError as e: + return ClearImages(ok=False, msg=str(e)) + return ClearImages(ok=True, msg='') + + +class ModifyImageInput(graphene.InputObjectType): + name = graphene.String(required=False) + registry = graphene.String(required=False) + image = graphene.String(required=False) + tag = graphene.String(required=False) + architecture = graphene.String(required=False) + size_bytes = graphene.Int(required=False) + type = graphene.String(required=False) + + digest = graphene.String(required=False) + labels = graphene.List(lambda: KVPairInput, required=False) + supported_accelerators = graphene.List(graphene.String, required=False) + resource_limits = graphene.List(lambda: ResourceLimitInput, required=False) + + +class ModifyImage(graphene.Mutation): + + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + target = graphene.String(required=True, default_value=None) + architecture = graphene.String(required=False, default_value=DEFAULT_IMAGE_ARCH) + props = ModifyImageInput(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + + @staticmethod + async def mutate( + executor: AsyncioExecutor, + info: graphene.ResolveInfo, + target: str, + architecture: str, + props: ModifyImageInput, + ) -> AliasImage: + ctx: GraphQueryContext = info.context + data: MutableMapping[str, Any] = {} + set_if_set(props, data, 'name') + set_if_set(props, data, 'registry') + set_if_set(props, data, 'image') + set_if_set(props, data, 'tag') + set_if_set(props, data, 'architecture') + set_if_set(props, data, 'size_bytes') + set_if_set(props, data, 'type') + set_if_set(props, data, 'digest', target_key='config_digest') + set_if_set( + props, data, 'supported_accelerators', + clean_func=lambda v: ','.join(v), target_key='accelerators', + ) + set_if_set(props, data, 'labels', clean_func=lambda v: {pair.key: pair.value for pair in v}) + + if props.resource_limits is not None: + resources_data = {} + for limit_option in props.resource_limits: + limit_data = {} + if limit_option.min is not None and len(limit_option.min) > 0: + limit_data['min'] = limit_option.min + if limit_option.max is not None and len(limit_option.max) > 0: + limit_data['max'] = limit_option.max + resources_data[limit_option.key] = limit_data + data['resources'] = resources_data + + try: + async with ctx.db.begin_session() as session: + image_ref = ImageRef(target, ['*'], architecture) + try: + row = await ImageRow.from_image_ref(session, image_ref) + except UnknownImageReference: + return ModifyImage(ok=False, msg='Image not found') + for k, v in data.items(): + setattr(row, k, v) + except ValueError as e: + return ModifyImage(ok=False, msg=str(e)) + return ModifyImage(ok=True, msg='') diff --git a/src/ai/backend/manager/models/kernel.py b/src/ai/backend/manager/models/kernel.py new file mode 100644 index 0000000000..c24d14109e --- /dev/null +++ b/src/ai/backend/manager/models/kernel.py @@ -0,0 +1,1542 @@ +from __future__ import 
annotations + +from collections import OrderedDict +from datetime import datetime +from decimal import Decimal +import enum +from typing import ( + Any, + Dict, + Iterable, + List, + Mapping, + Optional, + Sequence, + Type, + TypedDict, + TypeVar, + TYPE_CHECKING, + Union, +) +from uuid import UUID +import uuid + +import aioredis +import aioredis.client +from dateutil.parser import parse as dtparse +import graphene +from graphene.types.datetime import DateTime as GQLDateTime +import sqlalchemy as sa +from sqlalchemy.engine.row import Row +from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection +from sqlalchemy.dialects import postgresql as pgsql + +from ai.backend.common import msgpack, redis +from ai.backend.common.types import ( + AccessKey, + BinarySize, + ClusterMode, + KernelId, + RedisConnectionInfo, + SessionId, + SessionTypes, + SessionResult, + SlotName, + ResourceSlot, + VFolderMount, +) + +from ..defs import DEFAULT_ROLE +from .base import ( + BigInt, + EnumType, + GUID, + Item, + KernelIDColumn, + PaginatedList, + ResourceSlotColumn, + SessionIDColumnType, + StructuredJSONObjectListColumn, + URLColumn, + batch_result, + batch_multiresult, + metadata, +) +from .group import groups +from .minilang.queryfilter import QueryFilterParser +from .minilang.ordering import QueryOrderParser +from .user import users +if TYPE_CHECKING: + from .gql import GraphQueryContext + +__all__ = ( + 'kernels', + 'session_dependencies', + 'KernelStatistics', + 'KernelStatus', + 'ComputeContainer', + 'ComputeSession', + 'ComputeContainerList', + 'ComputeSessionList', + 'LegacyComputeSession', + 'LegacyComputeSessionList', + 'AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES', + 'USER_RESOURCE_OCCUPYING_KERNEL_STATUSES', + 'RESOURCE_USAGE_KERNEL_STATUSES', + 'DEAD_KERNEL_STATUSES', + 'LIVE_STATUS', + 'recalc_concurrency_used', +) + + +class KernelStatus(enum.Enum): + # values are only meaningful inside the manager + PENDING = 0 + # --- + SCHEDULED = 5 + PREPARING = 10 + # --- + BUILDING = 20 + PULLING = 21 + # --- + RUNNING = 30 + RESTARTING = 31 + RESIZING = 32 + SUSPENDED = 33 + # --- + TERMINATING = 40 + TERMINATED = 41 + ERROR = 42 + CANCELLED = 43 + + +# statuses to consider when calculating current resource usage +AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES = tuple( + e for e in KernelStatus + if e not in ( + KernelStatus.TERMINATED, + KernelStatus.PENDING, + KernelStatus.CANCELLED, + ) +) + +USER_RESOURCE_OCCUPYING_KERNEL_STATUSES = tuple( + e for e in KernelStatus + if e not in ( + KernelStatus.TERMINATING, + KernelStatus.TERMINATED, + KernelStatus.PENDING, + KernelStatus.CANCELLED, + ) +) + +# statuses to consider when calculating historical resource usage +RESOURCE_USAGE_KERNEL_STATUSES = ( + KernelStatus.TERMINATED, + KernelStatus.RUNNING, +) + +DEAD_KERNEL_STATUSES = ( + KernelStatus.CANCELLED, + KernelStatus.TERMINATED, +) + +LIVE_STATUS = ( + KernelStatus.RUNNING, +) + + +def default_hostname(context) -> str: + params = context.get_current_parameters() + return f"{params['cluster_role']}{params['cluster_idx']}" + + +kernels = sa.Table( + 'kernels', metadata, + # The Backend.AI-side UUID for each kernel + # (mapped to a container in the docker backend and a pod in the k8s backend) + KernelIDColumn(), + # session_id == id when the kernel is the main container in a multi-container session or a + # single-container session. + # Otherwise, it refers the kernel ID of the main contaienr of the belonged multi-container session. 
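# (Editor's note) The status groups above are intended for membership tests in
# resource-accounting queries; a quick self-check of the intended semantics:
assert KernelStatus.RUNNING in USER_RESOURCE_OCCUPYING_KERNEL_STATUSES
assert KernelStatus.TERMINATED not in USER_RESOURCE_OCCUPYING_KERNEL_STATUSES
assert KernelStatus.TERMINATING in AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES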
+ sa.Column('session_id', SessionIDColumnType, unique=False, index=True, nullable=False), + sa.Column('session_creation_id', sa.String(length=32), unique=False, index=False), + sa.Column('session_name', sa.String(length=64), unique=False, index=True), # previously sess_id + sa.Column('session_type', EnumType(SessionTypes), index=True, nullable=False, # previously sess_type + default=SessionTypes.INTERACTIVE, server_default=SessionTypes.INTERACTIVE.name), + sa.Column('cluster_mode', sa.String(length=16), nullable=False, + default=ClusterMode.SINGLE_NODE, server_default=ClusterMode.SINGLE_NODE.name), + sa.Column('cluster_size', sa.Integer, nullable=False, default=1), + sa.Column('cluster_role', sa.String(length=16), nullable=False, default=DEFAULT_ROLE, index=True), + sa.Column('cluster_idx', sa.Integer, nullable=False, default=0), + sa.Column('cluster_hostname', sa.String(length=64), nullable=False, default=default_hostname), + + # Resource ownership + sa.Column('scaling_group', sa.ForeignKey('scaling_groups.name'), index=True, nullable=True), + sa.Column('agent', sa.String(length=64), sa.ForeignKey('agents.id'), nullable=True), + sa.Column('agent_addr', sa.String(length=128), nullable=True), + sa.Column('domain_name', sa.String(length=64), sa.ForeignKey('domains.name'), nullable=False), + sa.Column('group_id', GUID, sa.ForeignKey('groups.id'), nullable=False), + sa.Column('user_uuid', GUID, sa.ForeignKey('users.uuid'), nullable=False), + sa.Column('access_key', sa.String(length=20), sa.ForeignKey('keypairs.access_key')), + sa.Column('image', sa.String(length=512)), + sa.Column('architecture', sa.String(length=32), default='x86_64'), + sa.Column('registry', sa.String(length=512)), + sa.Column('tag', sa.String(length=64), nullable=True), + + # Resource occupation + sa.Column('container_id', sa.String(length=64)), + sa.Column('occupied_slots', ResourceSlotColumn(), nullable=False), + sa.Column('occupied_shares', pgsql.JSONB(), nullable=False, default={}), # legacy + sa.Column('environ', sa.ARRAY(sa.String), nullable=True), + sa.Column('mounts', sa.ARRAY(sa.String), nullable=True), # list of list; legacy since 22.03 + sa.Column('mount_map', pgsql.JSONB(), nullable=True, default={}), # legacy since 22.03 + sa.Column('vfolder_mounts', StructuredJSONObjectListColumn(VFolderMount), nullable=True), + sa.Column('attached_devices', pgsql.JSONB(), nullable=True, default={}), + sa.Column('resource_opts', pgsql.JSONB(), nullable=True, default={}), + sa.Column('bootstrap_script', sa.String(length=16 * 1024), nullable=True), + + # Port mappings + # If kernel_host is NULL, it is assumed to be same to the agent host or IP. 
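# (Editor's note) The cluster_hostname column above defaults to default_hostname(),
# which simply joins cluster_role and cluster_idx; a minimal check with a stand-in
# for the SQLAlchemy execution context:
class _FakeContext:
    def get_current_parameters(self):
        return {'cluster_role': 'sub', 'cluster_idx': 2}

assert default_hostname(_FakeContext()) == 'sub2'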
+ sa.Column('kernel_host', sa.String(length=128), nullable=True), + sa.Column('repl_in_port', sa.Integer(), nullable=False), + sa.Column('repl_out_port', sa.Integer(), nullable=False), + sa.Column('stdin_port', sa.Integer(), nullable=False), # legacy for stream_pty + sa.Column('stdout_port', sa.Integer(), nullable=False), # legacy for stream_pty + sa.Column('service_ports', pgsql.JSONB(), nullable=True), + sa.Column('preopen_ports', sa.ARRAY(sa.Integer), nullable=True), + + # Lifecycle + sa.Column('created_at', sa.DateTime(timezone=True), + server_default=sa.func.now(), index=True), + sa.Column('terminated_at', sa.DateTime(timezone=True), + nullable=True, default=sa.null(), index=True), + sa.Column('starts_at', sa.DateTime(timezone=True), + nullable=True, default=sa.null()), + sa.Column('status', EnumType(KernelStatus), + default=KernelStatus.PENDING, + server_default=KernelStatus.PENDING.name, + nullable=False, index=True), + sa.Column('status_changed', sa.DateTime(timezone=True), nullable=True, index=True), + sa.Column('status_info', sa.Unicode(), nullable=True, default=sa.null()), + # status_info contains a kebab-cased string that expresses a summary of the last status change. + # Examples: "user-requested", "self-terminated", "predicate-checks-failed", "no-available-instances" + + sa.Column('status_data', pgsql.JSONB(), nullable=True, default=sa.null()), + # status_data contains a JSON object that contains detailed data for the last status change. + # During scheduling (as PENDING + ("no-available-instances" | "predicate-checks-failed")): + # { + # "scheduler": { + # // shceudler attempt information + # // NOTE: the whole field may be NULL before the first attempt! + # "retries": 5, + # // the number of scheudling attempts (used to avoid HoL blocking as well) + # "last_try": "2021-05-01T12:34:56.123456+09:00", + # // an ISO 8601 formatted timestamp of the last attempt + # "failed_predicates": [ + # { "name": "concurrency", "msg": "You cannot run more than 30 concurrent sessions." }, + # // see the manager.scheduler.predicates module for possible messages + # ... + # ], + # "passed_predicates": [ {"name": "reserved_time"}, ... ], // names only + # } + # } + # + # While running: the field is NULL. + # + # After termination: + # { + # "kernel": { + # // termination info for the individual kernel + # "exit_code": 123, + # // maybe null during termination + # }, + # "session": { + # // termination info for the session + # "status": "terminating" | "terminated" + # // "terminated" means all kernels that belong to the same session has terminated. 
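# (Editor's note) Example of the status_data shape documented in the comments above
# for a kernel that is still PENDING after failed scheduling attempts; the values are
# illustrative only:
_EXAMPLE_STATUS_DATA = {
    "scheduler": {
        "retries": 5,
        "last_try": "2021-05-01T12:34:56.123456+09:00",
        "failed_predicates": [
            {"name": "concurrency", "msg": "You cannot run more than 30 concurrent sessions."},
        ],
        "passed_predicates": [{"name": "reserved_time"}],
    },
}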
+ # // used to prevent duplication of SessionTerminatedEvent + # } + # } + sa.Column('callback_url', URLColumn, nullable=True, default=sa.null()), + + sa.Column('startup_command', sa.Text, nullable=True), + sa.Column('result', EnumType(SessionResult), + default=SessionResult.UNDEFINED, + server_default=SessionResult.UNDEFINED.name, + nullable=False, index=True), + sa.Column('internal_data', pgsql.JSONB(), nullable=True), + sa.Column('container_log', sa.LargeBinary(), nullable=True), + # Resource metrics measured upon termination + sa.Column('num_queries', sa.BigInteger(), default=0), + sa.Column('last_stat', pgsql.JSONB(), nullable=True, default=sa.null()), + + sa.Index('ix_kernels_sess_id_role', 'session_id', 'cluster_role', unique=False), + sa.Index('ix_kernels_status_role', 'status', 'cluster_role'), + sa.Index('ix_kernels_updated_order', + sa.func.greatest('created_at', 'terminated_at', 'status_changed'), + unique=False), + sa.Index('ix_kernels_unique_sess_token', 'access_key', 'session_name', + unique=True, + postgresql_where=sa.text( + "status NOT IN ('TERMINATED', 'CANCELLED') and " + "cluster_role = 'main'")), +) + +session_dependencies = sa.Table( + 'session_dependencies', metadata, + sa.Column('session_id', GUID, + sa.ForeignKey('kernels.id', onupdate='CASCADE', ondelete='CASCADE'), + index=True, nullable=False), + sa.Column('depends_on', GUID, + sa.ForeignKey('kernels.id', onupdate='CASCADE', ondelete='CASCADE'), + index=True, nullable=False), + sa.PrimaryKeyConstraint('session_id', 'depends_on'), +) + +DEFAULT_SESSION_ORDERING = [ + sa.desc(sa.func.greatest( + kernels.c.created_at, + kernels.c.terminated_at, + kernels.c.status_changed, + )), +] + + +class SessionInfo(TypedDict): + session_id: SessionId + session_name: str + status: KernelStatus + created_at: datetime + + +async def match_session_ids( + session_name_or_id: Union[str, UUID], + access_key: AccessKey, + *, + db_connection: SAConnection, + extra_cond=None, + for_update: bool = False, + max_matches: int = 10, +) -> Sequence[SessionInfo]: + """ + Match the prefix of session ID or session name among the sessions that belongs to the given + access key, and return the list of session IDs with matching prefixes. 
+ """ + cond_id = ( + (sa.sql.expression.cast(kernels.c.id, sa.String).like(f'{session_name_or_id}%')) & + (kernels.c.access_key == access_key) + ) + if extra_cond is not None: + cond_id = cond_id & extra_cond + cond_equal_name = ( + (kernels.c.session_name == (f'{session_name_or_id}')) & + (kernels.c.access_key == access_key) + ) + cond_prefix_name = ( + (kernels.c.session_name.like(f'{session_name_or_id}%')) & + (kernels.c.access_key == access_key) + ) + if extra_cond is not None: + cond_equal_name = cond_equal_name & extra_cond + cond_prefix_name = cond_prefix_name & extra_cond + cond_session_id = ( + (sa.sql.expression.cast(kernels.c.session_id, sa.String).like(f'{session_name_or_id}%')) & + (kernels.c.access_key == access_key) + ) + if extra_cond is not None: + cond_session_id = cond_session_id & extra_cond + info_cols = [ + kernels.c.session_id, + kernels.c.session_name, + kernels.c.status, + kernels.c.created_at, + ] + match_sid_by_id = ( + sa.select(info_cols) + .select_from(kernels) + .where( + (kernels.c.session_id.in_( + sa.select( + [kernels.c.session_id], + ) + .select_from(kernels) + .where(cond_id) + .group_by(kernels.c.session_id) + .limit(max_matches).offset(0), + )) & + (kernels.c.cluster_role == DEFAULT_ROLE), + ) + .order_by(sa.desc(kernels.c.created_at)) + ) + if for_update: + match_sid_by_id = match_sid_by_id.with_for_update() + match_sid_by_equal_name = ( + sa.select(info_cols) + .select_from(kernels) + .where( + (kernels.c.session_id.in_( + sa.select( + [kernels.c.session_id], + ) + .select_from(kernels) + .where(cond_equal_name) + .group_by(kernels.c.session_id) + .limit(max_matches).offset(0), + )) & + (kernels.c.cluster_role == DEFAULT_ROLE), + ) + .order_by(sa.desc(kernels.c.created_at)) + ) + match_sid_by_prefix_name = ( + sa.select(info_cols) + .select_from(kernels) + .where( + (kernels.c.session_id.in_( + sa.select( + [kernels.c.session_id], + ) + .select_from(kernels) + .where(cond_prefix_name) + .group_by(kernels.c.session_id) + .limit(max_matches).offset(0), + )) & + (kernels.c.cluster_role == DEFAULT_ROLE), + ) + .order_by(sa.desc(kernels.c.created_at)) + ) + if for_update: + match_sid_by_equal_name = match_sid_by_equal_name.with_for_update() + match_sid_by_prefix_name = match_sid_by_prefix_name.with_for_update() + match_sid_by_session_id = ( + sa.select(info_cols) + .select_from(kernels) + .where( + (kernels.c.session_id.in_( + sa.select( + [kernels.c.session_id], + ) + .select_from(kernels) + .where(cond_session_id) + .group_by(kernels.c.session_id) + .limit(max_matches).offset(0), + )) & + (kernels.c.cluster_role == DEFAULT_ROLE), + ) + .order_by(sa.desc(kernels.c.created_at)) + ) + if for_update: + match_sid_by_session_id = match_sid_by_session_id.with_for_update() + for match_query in [ + match_sid_by_session_id, + match_sid_by_equal_name, + match_sid_by_prefix_name, + match_sid_by_id, + ]: + result = await db_connection.execute(match_query) + rows = result.fetchall() + if not rows: + continue + return [ + SessionInfo( + session_id=row['session_id'], + session_name=row['session_name'], + status=row['status'], + created_at=row['created_at'], + ) for row in rows + ] + return [] + + +async def get_main_kernels( + session_ids: Sequence[SessionId], + *, + db_connection: SAConnection, + for_update: bool = False, +) -> Sequence[Row]: + """ + Return a list of the main kernels for the given session IDs. + If a given session ID does not exist, its position will be ``None``. 
+ """ + session_id_to_rows = OrderedDict( + (session_id, None) for session_id in session_ids + ) + query = ( + sa.select([kernels]) + .select_from(kernels) + .where( + (kernels.c.session_id.in_(session_ids)) & + (kernels.c.cluster_role == DEFAULT_ROLE), + ) + ) + result = await db_connection.execute(query) + for row in result.fetchall(): + session_id_to_rows[row['session_id']] = row + return [*session_id_to_rows.values()] + + +async def get_all_kernels( + session_ids: Sequence[SessionId], + *, + db_connection: SAConnection, + for_update: bool = False, +) -> Sequence[Sequence[Row]]: + """ + Return a list of all belonging kernel lists per the given session IDs + in the order they are given. + If a given session ID does not exist, an empty list will be returned + at the position of that session ID. + """ + session_id_to_rowsets: Dict[SessionId, List[Row]] + session_id_to_rowsets = OrderedDict( + (session_id, []) for session_id in session_ids + ) + for session_id in session_ids: + query = ( + sa.select([sa.text('*')]) + .select_from(kernels) + .where( + (kernels.c.session_id == session_id), + ) + ) + result = await db_connection.execute(query) + if result.rowcount == 0: + continue + session_id_to_rowsets[session_id].extend( + row for row in result.fetchall() + ) + return [*session_id_to_rowsets.values()] + + +class KernelStatistics: + @classmethod + async def batch_load_by_kernel( + cls, + ctx: GraphQueryContext, + session_ids: Sequence[SessionId], + ) -> Sequence[Optional[Mapping[str, Any]]]: + + def _build_pipeline(redis: aioredis.Redis) -> aioredis.client.Pipeline: + pipe = redis.pipeline() + for sess_id in session_ids: + pipe.get(str(sess_id)) + return pipe + + stats = [] + results = await redis.execute(ctx.redis_stat, _build_pipeline) + for result in results: + if result is not None: + stats.append(msgpack.unpackb(result)) + else: + stats.append(None) + return stats + + +class ComputeContainer(graphene.ObjectType): + class Meta: + interfaces = (Item, ) + + # identity + idx = graphene.Int() # legacy + role = graphene.String() # legacy + hostname = graphene.String() # legacy + cluster_idx = graphene.Int() + cluster_role = graphene.String() + cluster_hostname = graphene.String() + session_id = graphene.UUID() # owner session + + # image + image = graphene.String() + architecture = graphene.String() + registry = graphene.String() + + # status + status = graphene.String() + status_changed = GQLDateTime() + status_info = graphene.String() + status_data = graphene.JSONString() + created_at = GQLDateTime() + terminated_at = GQLDateTime() + starts_at = GQLDateTime() + + # resources + agent = graphene.String() + container_id = graphene.String() + resource_opts = graphene.JSONString() + occupied_slots = graphene.JSONString() + live_stat = graphene.JSONString() + last_stat = graphene.JSONString() + + @classmethod + def parse_row(cls, ctx: GraphQueryContext, row: Row) -> Mapping[str, Any]: + assert row is not None + from .user import UserRole + is_superadmin = (ctx.user['role'] == UserRole.SUPERADMIN) + if is_superadmin: + hide_agents = False + else: + hide_agents = ctx.local_config['manager']['hide-agents'] + return { + # identity + 'id': row['id'], + 'idx': row['cluster_idx'], + 'role': row['cluster_role'], + 'hostname': row['cluster_hostname'], + 'cluster_idx': row['cluster_idx'], + 'cluster_role': row['cluster_role'], + 'cluster_hostname': row['cluster_hostname'], + 'session_id': row['session_id'], + + # image + 'image': row['image'], + 'architecture': row['architecture'], + 'registry': 
row['registry'], + + # status + 'status': row['status'].name, + 'status_changed': row['status_changed'], + 'status_info': row['status_info'], + 'status_data': row['status_data'], + 'created_at': row['created_at'], + 'terminated_at': row['terminated_at'], + 'starts_at': row['starts_at'], + 'occupied_slots': row['occupied_slots'].to_json(), + + # resources + 'agent': row['agent'] if not hide_agents else None, + 'container_id': row['container_id'] if not hide_agents else None, + 'resource_opts': row['resource_opts'], + + # statistics + # last_stat is resolved by Graphene (resolve_last_stat method) + } + + @classmethod + def from_row(cls, ctx: GraphQueryContext, row: Row) -> Optional[ComputeContainer]: + if row is None: + return None + props = cls.parse_row(ctx, row) + return cls(**props) + + # last_stat also fetches data from Redis, meaning that + # both live_stat and last_stat will reference same data from same source + # we can leave last_stat value for legacy support, as an alias to last_stat + async def resolve_live_stat(self, info: graphene.ResolveInfo) -> Optional[Mapping[str, Any]]: + graph_ctx: GraphQueryContext = info.context + loader = graph_ctx.dataloader_manager.get_loader(graph_ctx, 'KernelStatistics.by_kernel') + return await loader.load(self.id) + + async def resolve_last_stat(self, info: graphene.ResolveInfo) -> Optional[Mapping[str, Any]]: + return await self.resolve_live_stat(info) + + _queryfilter_fieldspec = { + "image": ("image", None), + "architecture": ("architecture", None), + "agent": ("agent", None), + "cluster_idx": ("cluster_idx", None), + "cluster_role": ("cluster_role", None), + "cluster_hostname": ("cluster_hostname", None), + "status": ("status", lambda s: KernelStatus[s]), + "status_info": ("status_info", None), + "created_at": ("created_at", dtparse), + "status_changed": ("status_changed", dtparse), + "terminated_at": ("terminated_at", dtparse), + } + + _queryorder_colmap = { + "image": "image", + "architecture": "architecture", + "agent": "agent", + "cluster_idx": "cluster_idx", + "cluster_role": "cluster_role", + "cluster_hostname": "cluster_hostname", + "status": "status", + "status_info": "status_info", + "status_changed": "status_info", + "created_at": "created_at", + "terminated_at": "terminated_at", + } + + @classmethod + async def load_count( + cls, + ctx: GraphQueryContext, + session_id: SessionId, + *, + cluster_role: str = None, + domain_name: str = None, + group_id: uuid.UUID = None, + access_key: str = None, + filter: str = None, + ) -> int: + query = ( + sa.select([sa.func.count()]) + .select_from(kernels) + .where(kernels.c.session_id == session_id) + ) + if cluster_role is not None: + query = query.where(kernels.c.cluster_role == cluster_role) + if domain_name is not None: + query = query.where(kernels.c.domain_name == domain_name) + if group_id is not None: + query = query.where(kernels.c.group_id == group_id) + if access_key is not None: + query = query.where(kernels.c.access_key == access_key) + if filter is not None: + qfparser = QueryFilterParser(cls._queryfilter_fieldspec) + query = qfparser.append_filter(query, filter) + async with ctx.db.begin_readonly() as conn: + result = await conn.execute(query) + return result.scalar() + + @classmethod + async def load_slice( + cls, + ctx: GraphQueryContext, + limit: int, + offset: int, + session_id: SessionId, + *, + cluster_role: str = None, + domain_name: str = None, + group_id: uuid.UUID = None, + access_key: AccessKey = None, + filter: str = None, + order: str = None, + ) -> 
Sequence[Optional[ComputeContainer]]: + query = ( + sa.select([kernels]) + .select_from(kernels) + .where(kernels.c.session_id == session_id) + .limit(limit) + .offset(offset) + ) + if cluster_role is not None: + query = query.where(kernels.c.cluster_role == cluster_role) + if domain_name is not None: + query = query.where(kernels.c.domain_name == domain_name) + if group_id is not None: + query = query.where(kernels.c.group_id == group_id) + if access_key is not None: + query = query.where(kernels.c.access_key == access_key) + if filter is not None: + qfparser = QueryFilterParser(cls._queryfilter_fieldspec) + query = qfparser.append_filter(query, filter) + if order is not None: + qoparser = QueryOrderParser(cls._queryorder_colmap) + query = qoparser.append_ordering(query, order) + else: + query = query.order_by(*DEFAULT_SESSION_ORDERING) + async with ctx.db.begin_readonly() as conn: + return [cls.from_row(ctx, r) async for r in (await conn.stream(query))] + + @classmethod + async def batch_load_by_session( + cls, + ctx: GraphQueryContext, + session_ids: Sequence[SessionId], + ) -> Sequence[Sequence[ComputeContainer]]: + query = ( + sa.select([kernels]) + .select_from(kernels) + # TODO: use "owner session ID" when we implement multi-container session + .where(kernels.c.session_id.in_(session_ids)) + ) + async with ctx.db.begin_readonly() as conn: + return await batch_multiresult( + ctx, conn, query, cls, + session_ids, lambda row: row['session_id'], + ) + + @classmethod + async def batch_load_detail( + cls, + ctx: GraphQueryContext, + container_ids: Sequence[KernelId], + *, + domain_name: str = None, + access_key: AccessKey = None, + ) -> Sequence[Optional[ComputeContainer]]: + j = ( + kernels + .join(groups, groups.c.id == kernels.c.group_id) + .join(users, users.c.uuid == kernels.c.user_uuid) + ) + query = ( + sa.select([kernels]) + .select_from(j) + .where( + (kernels.c.id.in_(container_ids)), + )) + if domain_name is not None: + query = query.where(kernels.c.domain_name == domain_name) + if access_key is not None: + query = query.where(kernels.c.access_key == access_key) + async with ctx.db.begin_readonly() as conn: + return await batch_result( + ctx, conn, query, cls, + container_ids, lambda row: row['id'], + ) + + +class ComputeSession(graphene.ObjectType): + class Meta: + interfaces = (Item, ) + + # identity + tag = graphene.String() + name = graphene.String() + type = graphene.String() + session_id = graphene.UUID() + + # image + image = graphene.String() # image for the main container + architecture = graphene.String() # image architecture for the main container + registry = graphene.String() # image registry for the main container + cluster_template = graphene.String() + cluster_mode = graphene.String() + cluster_size = graphene.Int() + + # ownership + domain_name = graphene.String() + group_name = graphene.String() + group_id = graphene.UUID() + user_email = graphene.String() + user_id = graphene.UUID() + access_key = graphene.String() + created_user_email = graphene.String() + created_user_id = graphene.UUID() + + # status + status = graphene.String() + status_changed = GQLDateTime() + status_info = graphene.String() + status_data = graphene.JSONString() + created_at = GQLDateTime() + terminated_at = GQLDateTime() + starts_at = GQLDateTime() + startup_command = graphene.String() + result = graphene.String() + + # resources + resource_opts = graphene.JSONString() + scaling_group = graphene.String() + service_ports = graphene.JSONString() + mounts = graphene.List(lambda: 
graphene.String) + occupied_slots = graphene.JSONString() + + # statistics + num_queries = BigInt() + + # owned containers (aka kernels) + containers = graphene.List(lambda: ComputeContainer) + + # relations + dependencies = graphene.List(lambda: ComputeSession) + + @classmethod + def parse_row(cls, ctx: GraphQueryContext, row: Row) -> Mapping[str, Any]: + assert row is not None + return { + # identity + 'id': row['id'], + 'tag': row['tag'], + 'name': row['session_name'], + 'type': row['session_type'].name, + 'session_id': row['session_id'], + + # image + 'image': row['image'], + 'architecture': row['architecture'], + 'registry': row['registry'], + 'cluster_template': None, # TODO: implement + 'cluster_mode': row['cluster_mode'], + 'cluster_size': row['cluster_size'], + + # ownership + 'domain_name': row['domain_name'], + 'group_name': row['group_name'], + 'group_id': row['group_id'], + 'user_email': row['email'], + 'user_id': row['user_uuid'], + 'access_key': row['access_key'], + 'created_user_email': None, # TODO: implement + 'created_user_id': None, # TODO: implement + + # status + 'status': row['status'].name, + 'status_changed': row['status_changed'], + 'status_info': row['status_info'], + 'status_data': row['status_data'], + 'created_at': row['created_at'], + 'terminated_at': row['terminated_at'], + 'starts_at': row['starts_at'], + 'startup_command': row['startup_command'], + 'result': row['result'].name, + + # resources + 'resource_opts': row['resource_opts'], + 'scaling_group': row['scaling_group'], + 'service_ports': row['service_ports'], + 'mounts': row['mounts'], + + # statistics + 'num_queries': row['num_queries'], + } + + @classmethod + def from_row(cls, ctx: GraphQueryContext, row: Row) -> ComputeSession | None: + if row is None: + return None + props = cls.parse_row(ctx, row) + return cls(**props) + + async def resolve_occupied_slots(self, info: graphene.ResolveInfo) -> Mapping[str, Any]: + """ + Calculate the sum of occupied resource slots of all sub-kernels, + and return the JSON-serializable object from the sum result. 
+ """ + graph_ctx: GraphQueryContext = info.context + loader = graph_ctx.dataloader_manager.get_loader(graph_ctx, 'ComputeContainer.by_session') + containers = await loader.load(self.session_id) + zero = ResourceSlot() + return sum( + (ResourceSlot({ + SlotName(k): Decimal(v) for k, v in c.occupied_slots.items() + }) for c in containers), + start=zero, + ).to_json() + + async def resolve_containers( + self, + info: graphene.ResolveInfo, + ) -> Iterable[ComputeContainer]: + graph_ctx: GraphQueryContext = info.context + loader = graph_ctx.dataloader_manager.get_loader(graph_ctx, 'ComputeContainer.by_session') + return await loader.load(self.session_id) + + async def resolve_dependencies( + self, + info: graphene.ResolveInfo, + ) -> Iterable[ComputeSession]: + graph_ctx: GraphQueryContext = info.context + loader = graph_ctx.dataloader_manager.get_loader(graph_ctx, 'ComputeSession.by_dependency') + return await loader.load(self.id) + + _queryfilter_fieldspec = { + "type": ("kernels_session_type", lambda s: SessionTypes[s]), + "name": ("kernels_session_name", None), + "image": ("kernels_image", None), + "architecture": ("kernels_architecture", None), + "domain_name": ("kernels_domain_name", None), + "group_name": ("groups_group_name", None), + "user_email": ("users_email", None), + "access_key": ("kernels_access_key", None), + "scaling_group": ("kernels_scaling_group", None), + "cluster_mode": ("kernels_cluster_mode", lambda s: ClusterMode[s]), + "cluster_template": ("kernels_cluster_template", None), + "cluster_size": ("kernels_cluster_size", None), + "status": ("kernels_status", lambda s: KernelStatus[s]), + "status_info": ("kernels_status_info", None), + "status_changed": ("kernels_status_changed", dtparse), + "result": ("kernels_result", lambda s: SessionResult[s]), + "created_at": ("kernels_created_at", dtparse), + "terminated_at": ("kernels_terminated_at", dtparse), + "starts_at": ("kernels_starts_at", dtparse), + "startup_command": ("kernels_startup_command", None), + "agent": ("kernels_agent", None), + "agents": ("kernels_agent", None), + } + + _queryorder_colmap = { + "id": "kernels_id", + "type": "kernels_session_type", + "name": "kernels_session_name", + "image": "kernels_image", + "architecture": "kernels_architecture", + "domain_name": "kernels_domain_name", + "group_name": "kernels_group_name", + "user_email": "users_email", + "access_key": "kernels_access_key", + "scaling_group": "kernels_scaling_group", + "cluster_mode": "kernels_cluster_mode", + "cluster_template": "kernels_cluster_template", + "cluster_size": "kernels_cluster_size", + "status": "kernels_status", + "status_info": "kernels_status_info", + "status_changed": "kernels_status_info", + "result": "kernels_result", + "created_at": "kernels_created_at", + "terminated_at": "kernels_terminated_at", + "starts_at": "kernels_starts_at", + } + + @classmethod + async def load_count( + cls, + ctx: GraphQueryContext, + *, + domain_name: str = None, + group_id: uuid.UUID = None, + access_key: str = None, + status: str = None, + filter: str = None, + ) -> int: + if isinstance(status, str): + status_list = [KernelStatus[s] for s in status.split(',')] + elif isinstance(status, KernelStatus): + status_list = [status] + j = ( + kernels + .join(groups, groups.c.id == kernels.c.group_id) + .join(users, users.c.uuid == kernels.c.user_uuid) + ) + query = ( + sa.select([sa.func.count()]) + .select_from(j) + .where(kernels.c.cluster_role == DEFAULT_ROLE) + ) + if domain_name is not None: + query = query.where(kernels.c.domain_name == 
domain_name) + if group_id is not None: + query = query.where(kernels.c.group_id == group_id) + if access_key is not None: + query = query.where(kernels.c.access_key == access_key) + if status is not None: + query = query.where(kernels.c.status.in_(status_list)) + if filter is not None: + qfparser = QueryFilterParser(cls._queryfilter_fieldspec) + query = qfparser.append_filter(query, filter) + async with ctx.db.begin_readonly() as conn: + result = await conn.execute(query) + return result.scalar() + + @classmethod + async def load_slice( + cls, + ctx: GraphQueryContext, + limit: int, + offset: int, + *, + domain_name: str = None, + group_id: uuid.UUID = None, + access_key: str = None, + status: str = None, + filter: str = None, + order: str = None, + ) -> Sequence[ComputeSession | None]: + if isinstance(status, str): + status_list = [KernelStatus[s] for s in status.split(',')] + elif isinstance(status, KernelStatus): + status_list = [status] + j = ( + kernels + .join(groups, groups.c.id == kernels.c.group_id) + .join(users, users.c.uuid == kernels.c.user_uuid) + ) + query = ( + sa.select([ + kernels, + groups.c.name.label('group_name'), + users.c.email, + ]) + .select_from(j) + .where(kernels.c.cluster_role == DEFAULT_ROLE) + .limit(limit) + .offset(offset) + ) + if domain_name is not None: + query = query.where(kernels.c.domain_name == domain_name) + if group_id is not None: + query = query.where(kernels.c.group_id == group_id) + if access_key is not None: + query = query.where(kernels.c.access_key == access_key) + if status is not None: + query = query.where(kernels.c.status.in_(status_list)) + if filter is not None: + parser = QueryFilterParser(cls._queryfilter_fieldspec) + query = parser.append_filter(query, filter) + if order is not None: + qoparser = QueryOrderParser(cls._queryorder_colmap) + query = qoparser.append_ordering(query, order) + else: + query = query.order_by(*DEFAULT_SESSION_ORDERING) + async with ctx.db.begin_readonly() as conn: + return [cls.from_row(ctx, r) async for r in (await conn.stream(query))] + + @classmethod + async def batch_load_by_dependency( + cls, + ctx: GraphQueryContext, + session_ids: Sequence[SessionId], + ) -> Sequence[Sequence[ComputeSession]]: + j = sa.join( + kernels, session_dependencies, + kernels.c.session_id == session_dependencies.c.depends_on, + ) + query = ( + sa.select([kernels]) + .select_from(j) + .where( + (kernels.c.cluster_role == DEFAULT_ROLE) & + (session_dependencies.c.session_id.in_(session_ids)), + ) + ) + async with ctx.db.begin_readonly() as conn: + return await batch_multiresult( + ctx, conn, query, cls, + session_ids, lambda row: row['id'], + ) + + @classmethod + async def batch_load_detail( + cls, + ctx: GraphQueryContext, + session_ids: Sequence[SessionId], + *, + domain_name: str = None, + access_key: str = None, + ) -> Sequence[ComputeSession | None]: + j = ( + kernels + .join(groups, groups.c.id == kernels.c.group_id) + .join(users, users.c.uuid == kernels.c.user_uuid) + ) + query = ( + sa.select([ + kernels, + groups.c.name.label('group_name'), + users.c.email, + ]) + .select_from(j) + .where( + (kernels.c.cluster_role == DEFAULT_ROLE) & + (kernels.c.id.in_(session_ids)), + )) + if domain_name is not None: + query = query.where(kernels.c.domain_name == domain_name) + if access_key is not None: + query = query.where(kernels.c.access_key == access_key) + async with ctx.db.begin_readonly() as conn: + return await batch_result( + ctx, conn, query, cls, + session_ids, lambda row: row['id'], + ) + + +class 
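# (Editor's note) Sketch of how ComputeSession.load_count() and load_slice() above are
# typically combined for a paginated listing, assuming a GraphQueryContext; both calls
# share the same filtering arguments so the total count matches the page contents.
async def list_running_sessions(ctx: GraphQueryContext, domain: str, page: int, page_size: int = 20):
    total = await ComputeSession.load_count(ctx, domain_name=domain, status='RUNNING')
    items = await ComputeSession.load_slice(
        ctx, page_size, page * page_size,
        domain_name=domain, status='RUNNING',
    )
    return total, items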
ComputeContainerList(graphene.ObjectType): + class Meta: + interfaces = (PaginatedList, ) + + items = graphene.List(ComputeContainer, required=True) + + +class ComputeSessionList(graphene.ObjectType): + class Meta: + interfaces = (PaginatedList, ) + + items = graphene.List(ComputeSession, required=True) + + +# --------- pre-v5 legacy ----------- + +MetricValueType = TypeVar('MetricValueType', int, float) + + +class LegacyComputeSession(graphene.ObjectType): + """ + Represents a main session. + """ + class Meta: + interfaces = (Item, ) + + tag = graphene.String() # Only for ComputeSession + sess_id = graphene.String() # legacy + sess_type = graphene.String() # legacy + session_name = graphene.String() + session_type = graphene.String() + role = graphene.String() + image = graphene.String() + architecture = graphene.String() + registry = graphene.String() + domain_name = graphene.String() + group_name = graphene.String() + group_id = graphene.UUID() + scaling_group = graphene.String() + user_uuid = graphene.UUID() + access_key = graphene.String() + + status = graphene.String() + status_changed = GQLDateTime() + status_info = graphene.String() + created_at = GQLDateTime() + terminated_at = GQLDateTime() + startup_command = graphene.String() + result = graphene.String() + + # hidable fields by configuration + agent = graphene.String() + container_id = graphene.String() + + service_ports = graphene.JSONString() + + occupied_slots = graphene.JSONString() + occupied_shares = graphene.JSONString() + mounts = graphene.List(lambda: graphene.List(lambda: graphene.String)) + resource_opts = graphene.JSONString() + + num_queries = BigInt() + live_stat = graphene.JSONString() + last_stat = graphene.JSONString() + + user_email = graphene.String() + + # Legacy fields + lang = graphene.String() + mem_slot = graphene.Int() + cpu_slot = graphene.Float() + gpu_slot = graphene.Float() + tpu_slot = graphene.Float() + cpu_used = BigInt() + cpu_using = graphene.Float() + mem_max_bytes = BigInt() + mem_cur_bytes = BigInt() + net_rx_bytes = BigInt() + net_tx_bytes = BigInt() + io_read_bytes = BigInt() + io_write_bytes = BigInt() + io_max_scratch_size = BigInt() + io_cur_scratch_size = BigInt() + + # last_stat also fetches data from Redis, meaning that + # both live_stat and last_stat will reference same data from same source + # we can leave last_stat value for legacy support, as an alias to last_stat + async def resolve_live_stat(self, info: graphene.ResolveInfo) -> Optional[Mapping[str, Any]]: + graph_ctx: GraphQueryContext = info.context + loader = graph_ctx.dataloader_manager.get_loader(graph_ctx, 'KernelStatistics.by_kernel') + return await loader.load(self.id) + + async def resolve_last_stat(self, info: graphene.ResolveInfo) -> Optional[Mapping[str, Any]]: + return await self.resolve_live_stat(info) + + async def _resolve_legacy_metric( + self, + info: graphene.ResolveInfo, + metric_key: str, + metric_field: str, + convert_type: Type[MetricValueType], + ) -> Optional[MetricValueType]: + if not hasattr(self, 'status'): + return None + graph_ctx: GraphQueryContext = info.context + if KernelStatus[self.status] not in LIVE_STATUS: + if self.last_stat is None: + return convert_type(0) + metric = self.last_stat.get(metric_key) + if metric is None: + return convert_type(0) + value = metric.get(metric_field) + if value is None: + return convert_type(0) + return convert_type(value) + else: + loader = graph_ctx.dataloader_manager.get_loader(graph_ctx, 'KernelStatistics.by_kernel') + kstat = await loader.load(self.id) + 
if kstat is None: + return convert_type(0) + metric = kstat.get(metric_key) + if metric is None: + return convert_type(0) + value = metric.get(metric_field) + if value is None: + return convert_type(0) + return convert_type(value) + + async def resolve_cpu_used(self, info: graphene.ResolveInfo) -> Optional[float]: + return await self._resolve_legacy_metric(info, 'cpu_used', 'current', float) + + async def resolve_cpu_using(self, info: graphene.ResolveInfo) -> Optional[float]: + return await self._resolve_legacy_metric(info, 'cpu_util', 'pct', float) + + async def resolve_mem_max_bytes(self, info: graphene.ResolveInfo) -> Optional[int]: + return await self._resolve_legacy_metric(info, 'mem', 'stats.max', int) + + async def resolve_mem_cur_bytes(self, info: graphene.ResolveInfo) -> Optional[int]: + return await self._resolve_legacy_metric(info, 'mem', 'current', int) + + async def resolve_net_rx_bytes(self, info: graphene.ResolveInfo) -> Optional[int]: + return await self._resolve_legacy_metric(info, 'net_rx', 'stats.rate', int) + + async def resolve_net_tx_bytes(self, info: graphene.ResolveInfo) -> Optional[int]: + return await self._resolve_legacy_metric(info, 'net_tx', 'stats.rate', int) + + async def resolve_io_read_bytes(self, info: graphene.ResolveInfo) -> Optional[int]: + return await self._resolve_legacy_metric(info, 'io_read', 'current', int) + + async def resolve_io_write_bytes(self, info: graphene.ResolveInfo) -> Optional[int]: + return await self._resolve_legacy_metric(info, 'io_write', 'current', int) + + async def resolve_io_max_scratch_size(self, info: graphene.ResolveInfo) -> Optional[int]: + return await self._resolve_legacy_metric(info, 'io_scratch_size', 'stats.max', int) + + async def resolve_io_cur_scratch_size(self, info: graphene.ResolveInfo) -> Optional[int]: + return await self._resolve_legacy_metric(info, 'io_scratch_size', 'current', int) + + @classmethod + def parse_row(cls, ctx: GraphQueryContext, row: Row) -> Mapping[str, Any]: + assert row is not None + from .user import UserRole + mega = 2 ** 20 + is_superadmin = (ctx.user['role'] == UserRole.SUPERADMIN) + if is_superadmin: + hide_agents = False + else: + hide_agents = ctx.local_config['manager']['hide-agents'] + return { + 'id': row['id'], + 'sess_id': row['session_name'], # legacy, will be deprecated + 'sess_type': row['session_type'].name, # legacy, will be deprecated + 'session_name': row['session_name'], + 'session_type': row['session_type'].name, + 'role': row['cluster_role'], + 'tag': row['tag'], + 'image': row['image'], + 'architecture': row['architecture'], + 'registry': row['registry'], + 'domain_name': row['domain_name'], + 'group_name': row['name'], # group.name (group is omitted since use_labels=True is not used) + 'group_id': row['group_id'], + 'scaling_group': row['scaling_group'], + 'user_uuid': row['user_uuid'], + 'access_key': row['access_key'], + 'status': row['status'].name, + 'status_changed': row['status_changed'], + 'status_info': row['status_info'], + 'status_data': row['status_data'], + 'created_at': row['created_at'], + 'terminated_at': row['terminated_at'], + 'startup_command': row['startup_command'], + 'result': row['result'].name, + 'service_ports': row['service_ports'], + 'occupied_slots': row['occupied_slots'].to_json(), + 'vfolder_mounts': row['vfolder_mounts'], + 'resource_opts': row['resource_opts'], + 'num_queries': row['num_queries'], + # optionally hidden + 'agent': row['agent'] if not hide_agents else None, + 'container_id': row['container_id'] if not hide_agents else 
None, + # live_stat is resolved by Graphene + # last_stat is resolved by Graphene + 'user_email': row['email'], + # Legacy fields + # NOTE: currently graphene always uses resolve methods! + 'cpu_used': 0, + 'mem_max_bytes': 0, + 'mem_cur_bytes': 0, + 'net_rx_bytes': 0, + 'net_tx_bytes': 0, + 'io_read_bytes': 0, + 'io_write_bytes': 0, + 'io_max_scratch_size': 0, + 'io_cur_scratch_size': 0, + 'lang': row['image'], + 'occupied_shares': row['occupied_shares'], + 'mem_slot': BinarySize.from_str( + row['occupied_slots'].get('mem', 0)) // mega, + 'cpu_slot': float(row['occupied_slots'].get('cpu', 0)), + 'gpu_slot': float(row['occupied_slots'].get('cuda.device', 0)), + 'tpu_slot': float(row['occupied_slots'].get('tpu.device', 0)), + } + + @classmethod + def from_row(cls, context: GraphQueryContext, row: Row) -> Optional[LegacyComputeSession]: + if row is None: + return None + props = cls.parse_row(context, row) + return cls(**props) + + @classmethod + async def load_count( + cls, + ctx: GraphQueryContext, + *, + domain_name: str = None, + group_id: uuid.UUID = None, + access_key: AccessKey = None, + status: str = None, + ) -> int: + if isinstance(status, str): + status_list = [KernelStatus[s] for s in status.split(',')] + elif isinstance(status, KernelStatus): + status_list = [status] + query = ( + sa.select([sa.func.count()]) + .select_from(kernels) + .where(kernels.c.cluster_role == DEFAULT_ROLE) + ) + if domain_name is not None: + query = query.where(kernels.c.domain_name == domain_name) + if group_id is not None: + query = query.where(kernels.c.group_id == group_id) + if access_key is not None: + query = query.where(kernels.c.access_key == access_key) + if status is not None: + query = query.where(kernels.c.status.in_(status_list)) + async with ctx.db.begin_readonly() as conn: + result = await conn.execute(query) + return result.scalar() + + @classmethod + async def load_slice( + cls, + ctx: GraphQueryContext, + limit: int, + offset: int, + *, + domain_name: str = None, + group_id: uuid.UUID = None, + access_key: AccessKey = None, + status: str = None, + order_key: str = None, + order_asc: bool = True, + ) -> Sequence[LegacyComputeSession]: + if isinstance(status, str): + status_list = [KernelStatus[s] for s in status.split(',')] + elif isinstance(status, KernelStatus): + status_list = [status] + if order_key is None: + _ordering = DEFAULT_SESSION_ORDERING + else: + _order_func = sa.asc if order_asc else sa.desc + _ordering = [_order_func(getattr(kernels.c, order_key))] + j = (kernels.join(groups, groups.c.id == kernels.c.group_id) + .join(users, users.c.uuid == kernels.c.user_uuid)) + query = ( + sa.select([kernels, groups.c.name, users.c.email]) + .select_from(j) + .where(kernels.c.cluster_role == DEFAULT_ROLE) + .order_by(*_ordering) + .limit(limit) + .offset(offset) + ) + if domain_name is not None: + query = query.where(kernels.c.domain_name == domain_name) + if group_id is not None: + query = query.where(kernels.c.group_id == group_id) + if access_key is not None: + query = query.where(kernels.c.access_key == access_key) + if status is not None: + query = query.where(kernels.c.status.in_(status_list)) + async with ctx.db.begin_readonly() as conn: + return [ + obj async for r in (await conn.stream(query)) + if (obj := cls.from_row(ctx, r)) is not None + ] + + @classmethod + async def batch_load( + cls, + ctx: GraphQueryContext, + access_keys: AccessKey, + *, + domain_name: str = None, + group_id: uuid.UUID = None, + status: str = None, + ) -> Sequence[Optional[LegacyComputeSession]]: + j 
= (kernels.join(groups, groups.c.id == kernels.c.group_id) + .join(users, users.c.uuid == kernels.c.user_uuid)) + query = ( + sa.select([kernels, groups.c.name, users.c.email]) + .select_from(j) + .where( + (kernels.c.access_key.in_(access_keys)) & + (kernels.c.cluster_role == DEFAULT_ROLE), + ) + .order_by( + sa.desc(sa.func.greatest( + kernels.c.created_at, + kernels.c.terminated_at, + kernels.c.status_changed, + )), + ) + .limit(100)) + if domain_name is not None: + query = query.where(kernels.c.domain_name == domain_name) + if group_id is not None: + query = query.where(kernels.c.group_id == group_id) + if status is not None: + query = query.where(kernels.c.status == status) + async with ctx.db.begin_readonly() as conn: + return await batch_result( + ctx, conn, query, cls, + access_keys, lambda row: row['access_key'], + ) + + @classmethod + async def batch_load_detail( + cls, + ctx: GraphQueryContext, + sess_ids: Sequence[SessionId], + *, + domain_name: str = None, + access_key: AccessKey = None, + status: str = None, + ) -> Sequence[Sequence[LegacyComputeSession]]: + status_list = [] + if isinstance(status, str): + status_list = [KernelStatus[s] for s in status.split(',')] + elif isinstance(status, KernelStatus): + status_list = [status] + elif status is None: + status_list = [KernelStatus['RUNNING']] + j = (kernels.join(groups, groups.c.id == kernels.c.group_id) + .join(users, users.c.uuid == kernels.c.user_uuid)) + query = (sa.select([kernels, groups.c.name, users.c.email]) + .select_from(j) + .where((kernels.c.cluster_role == DEFAULT_ROLE) & + (kernels.c.session_id.in_(sess_ids)))) + if domain_name is not None: + query = query.where(kernels.c.domain_name == domain_name) + if access_key is not None: + query = query.where(kernels.c.access_key == access_key) + if status_list: + query = query.where(kernels.c.status.in_(status_list)) + async with ctx.db.begin_readonly() as conn: + return await batch_multiresult( + ctx, conn, query, cls, + sess_ids, lambda row: row['session_name'], + ) + + +class LegacyComputeSessionList(graphene.ObjectType): + class Meta: + interfaces = (PaginatedList, ) + + items = graphene.List(LegacyComputeSession, required=True) + + +async def recalc_concurrency_used( + db_conn: SAConnection, + redis_stat: RedisConnectionInfo, + access_key: AccessKey, +) -> None: + + concurrency_used: int + async with db_conn.begin_nested(): + query = ( + sa.select([sa.func.count()]) + .select_from(kernels) + .where( + (kernels.c.access_key == access_key) & + (kernels.c.status.in_(USER_RESOURCE_OCCUPYING_KERNEL_STATUSES)), + ) + ) + result = await db_conn.execute(query) + concurrency_used = result.first()[0] + + await redis.execute( + redis_stat, + lambda r: r.set( + f'keypair.concurrency_used.{access_key}', concurrency_used, + ), + ) diff --git a/src/ai/backend/manager/models/keypair.py b/src/ai/backend/manager/models/keypair.py new file mode 100644 index 0000000000..804c9bbc4f --- /dev/null +++ b/src/ai/backend/manager/models/keypair.py @@ -0,0 +1,616 @@ +from __future__ import annotations + +import base64 +import secrets +from typing import ( + Any, + Dict, + Optional, + Sequence, + List, TYPE_CHECKING, + Tuple, + TypedDict, +) +import uuid + +from cryptography.hazmat.primitives import serialization as crypto_serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.backends import default_backend as crypto_default_backend +from dateutil.parser import parse as dtparse +import graphene +from graphene.types.datetime import DateTime as 
GQLDateTime +import sqlalchemy as sa +from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection +from sqlalchemy.engine.row import Row +from sqlalchemy.sql.expression import false + +from ai.backend.common import msgpack, redis +from ai.backend.common.types import ( + AccessKey, + SecretKey, +) + +if TYPE_CHECKING: + from .gql import GraphQueryContext + from .vfolder import VirtualFolder + +from .base import ( + ForeignKeyIDColumn, + Item, + PaginatedList, + metadata, + batch_result, + batch_multiresult, + set_if_set, + simple_db_mutate, + simple_db_mutate_returning_item, +) +from .minilang.queryfilter import QueryFilterParser +from .minilang.ordering import QueryOrderParser +from .user import ModifyUserInput, UserRole +from ..defs import RESERVED_DOTFILES + +__all__: Sequence[str] = ( + 'keypairs', + 'KeyPair', 'KeyPairList', + 'UserInfo', + 'KeyPairInput', + 'CreateKeyPair', 'ModifyKeyPair', 'DeleteKeyPair', + 'Dotfile', 'MAXIMUM_DOTFILE_SIZE', + 'query_owned_dotfiles', + 'query_bootstrap_script', + 'verify_dotfile_name', +) + + +MAXIMUM_DOTFILE_SIZE = 64 * 1024 # 64 KiB + +keypairs = sa.Table( + 'keypairs', metadata, + sa.Column('user_id', sa.String(length=256), index=True), + sa.Column('access_key', sa.String(length=20), primary_key=True), + sa.Column('secret_key', sa.String(length=40)), + sa.Column('is_active', sa.Boolean, index=True), + sa.Column('is_admin', sa.Boolean, index=True, + default=False, server_default=false()), + sa.Column('created_at', sa.DateTime(timezone=True), + server_default=sa.func.now()), + sa.Column('modified_at', sa.DateTime(timezone=True), + server_default=sa.func.now(), onupdate=sa.func.current_timestamp()), + sa.Column('last_used', sa.DateTime(timezone=True), nullable=True), + sa.Column('rate_limit', sa.Integer), + sa.Column('num_queries', sa.Integer, server_default='0'), + + # SSH Keypairs.
+ sa.Column('ssh_public_key', sa.String(length=750), nullable=True), + sa.Column('ssh_private_key', sa.String(length=2000), nullable=True), + + ForeignKeyIDColumn('user', 'users.uuid', nullable=False), + sa.Column('resource_policy', sa.String(length=256), + sa.ForeignKey('keypair_resource_policies.name'), + nullable=False), + # dotfiles column, \x90 means empty list in msgpack + sa.Column('dotfiles', sa.LargeBinary(length=MAXIMUM_DOTFILE_SIZE), nullable=False, default=b'\x90'), + sa.Column('bootstrap_script', sa.String(length=MAXIMUM_DOTFILE_SIZE), nullable=False, default=''), +) + + +class UserInfo(graphene.ObjectType): + email = graphene.String() + full_name = graphene.String() + + @classmethod + def from_row( + cls, + ctx: GraphQueryContext, + row: Row, + ) -> Optional[UserInfo]: + if row is None: + return None + return cls(email=row['email'], full_name=row['full_name']) + + @classmethod + async def batch_load_by_uuid( + cls, + ctx: GraphQueryContext, + user_uuids: Sequence[uuid.UUID], + ) -> Sequence[Optional[UserInfo]]: + async with ctx.db.begin_readonly() as conn: + from .user import users + query = ( + sa.select([users.c.uuid, users.c.email, users.c.full_name]) + .select_from(users) + .where(users.c.uuid.in_(user_uuids)) + ) + return await batch_result( + ctx, conn, query, cls, + user_uuids, lambda row: row['uuid'], + ) + + +class KeyPair(graphene.ObjectType): + + class Meta: + interfaces = (Item, ) + + user_id = graphene.String() + full_name = graphene.String() + access_key = graphene.String() + secret_key = graphene.String() + is_active = graphene.Boolean() + is_admin = graphene.Boolean() + resource_policy = graphene.String() + created_at = GQLDateTime() + last_used = GQLDateTime() + rate_limit = graphene.Int() + num_queries = graphene.Int() + user = graphene.UUID() + + ssh_public_key = graphene.String() + + vfolders = graphene.List('ai.backend.manager.models.VirtualFolder') + compute_sessions = graphene.List( + 'ai.backend.manager.models.ComputeSession', + status=graphene.String(), + ) + concurrency_used = graphene.Int() + + user_info = graphene.Field(lambda: UserInfo) + + # Deprecated + concurrency_limit = graphene.Int( + deprecation_reason='Moved to KeyPairResourcePolicy object as ' + 'the max_concurrent_sessions field.') + + async def resolve_user_info( + self, + info: graphene.ResolveInfo, + ) -> UserInfo: + ctx: GraphQueryContext = info.context + loader = ctx.dataloader_manager.get_loader(ctx, 'UserInfo.by_uuid') + return await loader.load(self.user) + + @classmethod + def from_row( + cls, + ctx: GraphQueryContext, + row: Row, + ) -> KeyPair: + return cls( + id=row['access_key'], + user_id=row['user_id'], + full_name=row['full_name'] if 'full_name' in row.keys() else None, + access_key=row['access_key'], + secret_key=row['secret_key'], + is_active=row['is_active'], + is_admin=row['is_admin'], + resource_policy=row['resource_policy'], + created_at=row['created_at'], + last_used=row['last_used'], + rate_limit=row['rate_limit'], + user=row['user'], + ssh_public_key=row['ssh_public_key'], + concurrency_limit=0, # deprecated + ) + + async def resolve_num_queries(self, info: graphene.ResolveInfo) -> int: + ctx: GraphQueryContext = info.context + n = await redis.execute(ctx.redis_stat, lambda r: r.get(f"kp:{self.access_key}:num_queries")) + if n is not None: + return n + return 0 + + async def resolve_vfolders(self, info: graphene.ResolveInfo) -> Sequence[VirtualFolder]: + ctx: GraphQueryContext = info.context + loader = ctx.dataloader_manager.get_loader(ctx, 'VirtualFolder') + 
return await loader.load(self.access_key) + + async def resolve_compute_sessions(self, info: graphene.ResolveInfo, raw_status: str = None): + ctx: GraphQueryContext = info.context + from . import KernelStatus # noqa: avoid circular imports + if raw_status is not None: + status = KernelStatus[raw_status] + loader = ctx.dataloader_manager.get_loader(ctx, 'ComputeSession', status=status) + return await loader.load(self.access_key) + + async def resolve_concurrency_used(self, info: graphene.ResolveInfo) -> int: + ctx: GraphQueryContext = info.context + kp_key = 'keypair.concurrency_used' + concurrency_used = await redis.execute( + ctx.redis_stat, + lambda r: r.get(f'{kp_key}.{self.access_key}'), + ) + if concurrency_used is not None: + return int(concurrency_used) + return 0 + + @classmethod + async def load_all( + cls, + graph_ctx: GraphQueryContext, + *, + domain_name: str = None, + is_active: bool = None, + limit: int = None, + ) -> Sequence[KeyPair]: + from .user import users + j = sa.join( + keypairs, users, + keypairs.c.user == users.c.uuid, + ) + query = ( + sa.select([keypairs]) + .select_from(j) + ) + if domain_name is not None: + query = query.where(users.c.domain_name == domain_name) + if is_active is not None: + query = query.where(keypairs.c.is_active == is_active) + if limit is not None: + query = query.limit(limit) + async with graph_ctx.db.begin_readonly() as conn: + return [ + obj async for row in (await conn.stream(query)) + if (obj := cls.from_row(graph_ctx, row)) is not None + ] + + _queryfilter_fieldspec = { + "access_key": ("keypairs_access_key", None), + "user_id": ("users_uuid", None), + "email": ("users_email", None), + "full_name": ("users_full_name", None), + "is_active": ("keypairs_is_active", None), + "is_admin": ("keypairs_is_admin", None), + "resource_policy": ("keypairs_resource_policy", None), + "created_at": ("keypairs_created_at", dtparse), + "last_used": ("keypairs_last_used", dtparse), + "rate_limit": ("keypairs_rate_limit", None), + "num_queries": ("keypairs_num_queries", None), + "ssh_public_key": ("keypairs_ssh_public_key", None), + } + + _queryorder_colmap = { + "access_key": "keypairs_access_key", + "email": "users_email", + "full_name": "users_full_name", + "is_active": "keypairs_is_active", + "is_admin": "keypairs_is_admin", + "resource_policy": "keypairs_resource_policy", + "created_at": "keypairs_created_at", + "last_used": "keypairs_last_used", + "rate_limit": "keypairs_rate_limit", + "num_queries": "keypairs_num_queries", + } + + @classmethod + async def load_count( + cls, + graph_ctx: GraphQueryContext, + *, + domain_name: str = None, + email: str = None, + is_active: bool = None, + filter: str = None, + ) -> int: + from .user import users + j = sa.join(keypairs, users, keypairs.c.user == users.c.uuid) + query = ( + sa.select([sa.func.count()]) + .select_from(j) + ) + if domain_name is not None: + query = query.where(users.c.domain_name == domain_name) + if email is not None: + query = query.where(keypairs.c.user_id == email) + if is_active is not None: + query = query.where(keypairs.c.is_active == is_active) + if filter is not None: + qfparser = QueryFilterParser(cls._queryfilter_fieldspec) + query = qfparser.append_filter(query, filter) + async with graph_ctx.db.begin_readonly() as conn: + result = await conn.execute(query) + return result.scalar() + + @classmethod + async def load_slice( + cls, + graph_ctx: GraphQueryContext, + limit: int, + offset: int, + *, + domain_name: str = None, + email: str = None, + is_active: bool = None, + filter: 
str = None, + order: str = None, + ) -> Sequence[KeyPair]: + from .user import users + j = sa.join(keypairs, users, keypairs.c.user == users.c.uuid) + query = ( + sa.select([keypairs, users.c.email, users.c.full_name]) + .select_from(j) + .limit(limit) + .offset(offset) + ) + if domain_name is not None: + query = query.where(users.c.domain_name == domain_name) + if email is not None: + query = query.where(keypairs.c.user_id == email) + if is_active is not None: + query = query.where(keypairs.c.is_active == is_active) + if filter is not None: + qfparser = QueryFilterParser(cls._queryfilter_fieldspec) + query = qfparser.append_filter(query, filter) + if order is not None: + qoparser = QueryOrderParser(cls._queryorder_colmap) + query = qoparser.append_ordering(query, order) + else: + query = query.order_by(keypairs.c.created_at.desc()) + async with graph_ctx.db.begin_readonly() as conn: + return [ + obj async for row in (await conn.stream(query)) + if (obj := cls.from_row(graph_ctx, row)) is not None + ] + + @classmethod + async def batch_load_by_email( + cls, + graph_ctx: GraphQueryContext, + user_ids: Sequence[uuid.UUID], + *, + domain_name: str = None, + is_active: bool = None, + ) -> Sequence[Sequence[Optional[KeyPair]]]: + from .user import users + j = sa.join( + keypairs, users, + keypairs.c.user == users.c.uuid, + ) + query = ( + sa.select([keypairs]) + .select_from(j) + .where(keypairs.c.user_id.in_(user_ids)) + ) + if domain_name is not None: + query = query.where(users.c.domain_name == domain_name) + if is_active is not None: + query = query.where(keypairs.c.is_active == is_active) + async with graph_ctx.db.begin_readonly() as conn: + return await batch_multiresult( + graph_ctx, conn, query, cls, + user_ids, lambda row: row['user_id'], + ) + + @classmethod + async def batch_load_by_ak( + cls, + graph_ctx: GraphQueryContext, + access_keys: Sequence[AccessKey], + *, + domain_name: str = None, + ) -> Sequence[Optional[KeyPair]]: + from .user import users + j = sa.join( + keypairs, users, + keypairs.c.user == users.c.uuid, + ) + query = ( + sa.select([keypairs]) + .select_from(j) + .where(keypairs.c.access_key.in_(access_keys)) + ) + if domain_name is not None: + query = query.where(users.c.domain_name == domain_name) + async with graph_ctx.db.begin_readonly() as conn: + return await batch_result( + graph_ctx, conn, query, cls, + access_keys, lambda row: row['access_key'], + ) + + +class KeyPairList(graphene.ObjectType): + class Meta: + interfaces = (PaginatedList, ) + + items = graphene.List(KeyPair, required=True) + + +class KeyPairInput(graphene.InputObjectType): + is_active = graphene.Boolean(required=False, default=True) + is_admin = graphene.Boolean(required=False, default=False) + resource_policy = graphene.String(required=True) + concurrency_limit = graphene.Int(required=False) # deprecated and ignored + rate_limit = graphene.Int(required=True) + + # When creating, you MUST set all fields. + # When modifying, set the field to "None" to skip setting the value. 
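The Modify* mutations below rely on the None-skip convention noted above: they collect only the fields the caller actually set (via the set_if_set helper imported from .base, whose definition is outside this hunk) and build the UPDATE statement from that partial dict, so omitted fields never overwrite stored values. A minimal sketch of how such a helper could behave; the exact semantics (e.g. whether GraphQL "undefined" is distinguished from an explicit null) are an assumption here, and set_if_set_sketch is a hypothetical name.

from typing import Any, Callable, Dict, Optional

def set_if_set_sketch(
    props: Any,
    data: Dict[str, Any],
    name: str,
    *,
    clean_func: Optional[Callable[[Any], Any]] = None,
) -> None:
    # Copy props.<name> into the update dict only when the caller provided it.
    value = getattr(props, name, None)
    if value is None:
        return  # field left unset -> keep the stored value unchanged
    data[name] = clean_func(value) if clean_func is not None else value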
+ + +class ModifyKeyPairInput(graphene.InputObjectType): + is_active = graphene.Boolean(required=False) + is_admin = graphene.Boolean(required=False) + resource_policy = graphene.String(required=False) + concurrency_limit = graphene.Int(required=False) # deprecated and ignored + rate_limit = graphene.Int(required=False) + + +class CreateKeyPair(graphene.Mutation): + + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + user_id = graphene.String(required=True) + props = KeyPairInput(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + keypair = graphene.Field(lambda: KeyPair, required=False) + + @classmethod + async def mutate( + cls, + root, + info: graphene.ResolveInfo, + user_id: str, + props: KeyPairInput, + ) -> CreateKeyPair: + from .user import users # noqa + graph_ctx: GraphQueryContext = info.context + data = cls.prepare_new_keypair(user_id, props) + insert_query = ( + sa.insert(keypairs) + .values( + **data, + user=sa.select([users.c.uuid]).where(users.c.email == user_id).as_scalar(), + ) + ) + return await simple_db_mutate_returning_item(cls, graph_ctx, insert_query, item_cls=KeyPair) + + @classmethod + def prepare_new_keypair(cls, user_email: str, props: KeyPairInput) -> Dict[str, Any]: + ak, sk = generate_keypair() + pubkey, privkey = generate_ssh_keypair() + data = { + 'user_id': user_email, + 'access_key': ak, + 'secret_key': sk, + 'is_active': props.is_active, + 'is_admin': props.is_admin, + 'resource_policy': props.resource_policy, + 'rate_limit': props.rate_limit, + 'num_queries': 0, + 'ssh_public_key': pubkey, + 'ssh_private_key': privkey, + } + return data + + +class ModifyKeyPair(graphene.Mutation): + + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + access_key = graphene.String(required=True) + props = ModifyKeyPairInput(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + + @classmethod + async def mutate( + cls, + root, + info: graphene.ResolveInfo, + access_key: AccessKey, + props: ModifyUserInput, + ) -> ModifyKeyPair: + ctx: GraphQueryContext = info.context + data: Dict[str, Any] = {} + set_if_set(props, data, 'is_active') + set_if_set(props, data, 'is_admin') + set_if_set(props, data, 'resource_policy') + set_if_set(props, data, 'rate_limit') + update_query = ( + sa.update(keypairs) + .values(data) + .where(keypairs.c.access_key == access_key) + ) + return await simple_db_mutate(cls, ctx, update_query) + + +class DeleteKeyPair(graphene.Mutation): + + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + access_key = graphene.String(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + + @classmethod + async def mutate( + cls, + root, + info: graphene.ResolveInfo, + access_key: AccessKey, + ) -> DeleteKeyPair: + ctx: GraphQueryContext = info.context + delete_query = ( + sa.delete(keypairs) + .where(keypairs.c.access_key == access_key) + ) + await redis.execute( + ctx.redis_stat, + lambda r: r.delete(f'keypair.concurrency_used.{access_key}'), + ) + return await simple_db_mutate(cls, ctx, delete_query) + + +class Dotfile(TypedDict): + data: str + path: str + perm: str + + +def generate_keypair() -> Tuple[AccessKey, SecretKey]: + ''' + AWS-like access key and secret key generation. + ''' + ak = 'AKIA' + base64.b32encode(secrets.token_bytes(10)).decode('ascii') + sk = secrets.token_urlsafe(30) + return AccessKey(ak), SecretKey(sk) + + +def generate_ssh_keypair() -> Tuple[str, str]: + ''' + Generate RSA keypair for SSH/SFTP connection. 
+ ''' + key = rsa.generate_private_key( + backend=crypto_default_backend(), + public_exponent=65537, + key_size=2048, + ) + private_key = key.private_bytes( + crypto_serialization.Encoding.PEM, + crypto_serialization.PrivateFormat.TraditionalOpenSSL, + crypto_serialization.NoEncryption(), + ).decode("utf-8") + public_key = key.public_key().public_bytes( + crypto_serialization.Encoding.OpenSSH, + crypto_serialization.PublicFormat.OpenSSH, + ).decode("utf-8") + return (public_key, private_key) + + +async def query_owned_dotfiles( + conn: SAConnection, + access_key: AccessKey, +) -> Tuple[List[Dotfile], int]: + query = ( + sa.select([keypairs.c.dotfiles]) + .select_from(keypairs) + .where(keypairs.c.access_key == access_key) + ) + packed_dotfile = (await conn.execute(query)).scalar() + rows = msgpack.unpackb(packed_dotfile) + return rows, MAXIMUM_DOTFILE_SIZE - len(packed_dotfile) + + +async def query_bootstrap_script( + conn: SAConnection, + access_key: AccessKey, +) -> Tuple[str, int]: + query = ( + sa.select([keypairs.c.bootstrap_script]) + .select_from(keypairs) + .where(keypairs.c.access_key == access_key) + ) + script = (await conn.execute(query)).scalar() + return script, MAXIMUM_DOTFILE_SIZE - len(script) + + +def verify_dotfile_name(dotfile: str) -> bool: + if dotfile in RESERVED_DOTFILES: + return False + return True diff --git a/src/ai/backend/manager/models/minilang/__init__.py b/src/ai/backend/manager/models/minilang/__init__.py new file mode 100644 index 0000000000..8c7101af01 --- /dev/null +++ b/src/ai/backend/manager/models/minilang/__init__.py @@ -0,0 +1,8 @@ +from typing import ( + Any, + Callable, + Optional, + Tuple, +) + +FieldSpecItem = Tuple[str, Optional[Callable[[str], Any]]] diff --git a/src/ai/backend/manager/models/minilang/ordering.py b/src/ai/backend/manager/models/minilang/ordering.py new file mode 100644 index 0000000000..a1c7cb5eb2 --- /dev/null +++ b/src/ai/backend/manager/models/minilang/ordering.py @@ -0,0 +1,82 @@ +from typing import ( + Mapping, +) + +from lark import Lark, LarkError, Transformer +import sqlalchemy as sa + +__all__ = ( + 'QueryOrderParser', +) + +_grammar = r""" + ?start: expr + expr : [col ("," col)*] + col : ORDER? 
CNAME + ORDER : "+" | "-" + %import common.CNAME + %import common.WS + %ignore WS +""" +_parser = Lark( + _grammar, + parser='lalr', + maybe_placeholders=False, +) + + +class QueryOrderTransformer(Transformer): + + def __init__(self, sa_table: sa.Table, column_map: Mapping[str, str] = None) -> None: + super().__init__() + self._sa_table = sa_table + self._column_map = column_map + + def _get_col(self, col_name: str) -> sa.Column: + try: + if self._column_map: + col = self._sa_table.c[self._column_map[col_name]] + else: + col = self._sa_table.c[col_name] + return col + except KeyError: + raise ValueError("Unknown/unsupported field name", col_name) + + def col(self, *args): + children = args[0] + if len(children) == 2: + op = children[0].value + col = self._get_col(children[1].value) + else: + op = "+" # assume ascending if not marked + col = self._get_col(children[0].value) + if op == "+": + return col.asc() + elif op == "-": + return col.desc() + + expr = tuple + + +class QueryOrderParser(): + + def __init__(self, column_map: Mapping[str, str] = None) -> None: + self._column_map = column_map + self._parser = _parser + + def append_ordering( + self, + sa_query: sa.sql.Select, + order_expr: str, + ) -> sa.sql.Select: + """ + Parse the given filter expression and build the where clause based on the first target table from + the given SQLAlchemy query object. + """ + table = sa_query.froms[0] + try: + ast = self._parser.parse(order_expr) + orders = QueryOrderTransformer(table, self._column_map).transform(ast) + except LarkError as e: + raise ValueError(f"Query ordering parsing error: {e}") + return sa_query.order_by(*orders) diff --git a/src/ai/backend/manager/models/minilang/queryfilter.py b/src/ai/backend/manager/models/minilang/queryfilter.py new file mode 100644 index 0000000000..7822393681 --- /dev/null +++ b/src/ai/backend/manager/models/minilang/queryfilter.py @@ -0,0 +1,196 @@ +from typing import ( + Any, + Mapping, + Union, +) + +from lark import Lark, LarkError, Transformer, Tree +import sqlalchemy as sa + +from . import FieldSpecItem + +__all__ = ( + 'FilterableSQLQuery', + 'QueryFilterParser', +) + + +FilterableSQLQuery = Union[sa.sql.Select, sa.sql.Update, sa.sql.Delete] + +_grammar = r""" + ?start: expr + value: string + | number + | array + | ATOM -> atom + ATOM : "null" | "true" | "false" + COMBINE_OP : "&" | "|" + UNARY_OP : "!" + BINARY_OP : "==" | "!=" + | ">" | ">=" + | "<" | "<=" + | "contains" | "in" + | "isnot" | "is" + | "like" | "ilike" + expr: UNARY_OP expr -> unary_expr + | CNAME BINARY_OP value -> binary_expr + | expr COMBINE_OP expr -> combine_expr + | "(" expr ")" -> paren_expr + array : "[" [value ("," value)*] "]" + string : ESCAPED_STRING + number : SIGNED_NUMBER + %import common.CNAME + %import common.ESCAPED_STRING + %import common.SIGNED_NUMBER + %import common.WS + %ignore WS +""" +_parser = Lark( + _grammar, + parser='lalr', + maybe_placeholders=False, +) + + +class QueryFilterTransformer(Transformer): + + def __init__(self, sa_table: sa.Table, fieldspec: Mapping[str, FieldSpecItem] = None) -> None: + super().__init__() + self._sa_table = sa_table + self._fieldspec = fieldspec + + def string(self, s): + (s,) = s + # SQL-side escaping is handled by SQLAlchemy + return s[1:-1].replace("\\\"", '"') + + def number(self, n): + (n,) = n + if '.' 
in n: + return float(n) + return int(n) + + array = list + + def atom(self, a): + (a,) = a + if a.value == "null": + return sa.null() + elif a.value == "true": + return sa.true() + elif a.value == "false": + return sa.false() + + def _get_col(self, col_name: str) -> sa.Column: + try: + if self._fieldspec: + col = self._sa_table.c[self._fieldspec[col_name][0]] + else: + col = self._sa_table.c[col_name] + return col + except KeyError: + raise ValueError("Unknown/unsupported field name", col_name) + + def _transform_val_leaf(self, col_name: str, value: Any) -> Any: + if self._fieldspec: + try: + func = self._fieldspec[col_name][1] + except KeyError: + raise ValueError("Unknown/unsupported field name", col_name) + return func(value) if func is not None else value + else: + return value + + def _transform_val(self, col_name: str, value: Any) -> Any: + if isinstance(value, Tree): + val = self._transform_val(col_name, value.children[0]) + elif isinstance(value, list): + val = [self._transform_val(col_name, v) for v in value] + else: + val = self._transform_val_leaf(col_name, value) + return val + + def binary_expr(self, *args): + children = args[0] + col = self._get_col(children[0].value) + op = children[1].value + val = self._transform_val(children[0].value, children[2]) + if op == "==": + return (col == val) + elif op == "!=": + return (col != val) + elif op == ">": + return (col > val) + elif op == ">=": + return (col >= val) + elif op == "<": + return (col < val) + elif op == "<=": + return (col <= val) + elif op == "contains": + return (col.contains(val)) + elif op == "in": + return (col.in_(val)) + elif op == "isnot": + return (col.isnot(val)) + elif op == "is": + return (col.is_(val)) + elif op == "like": + return (col.like(val)) + elif op == "ilike": + return (col.ilike(val)) + return args + + def unary_expr(self, *args): + children = args[0] + op = children[0].value + expr = children[1] + if op in ("not", "!"): + return (sa.not_(expr)) + return args + + def combine_expr(self, *args): + children = args[0] + op = children[1].value + expr1 = children[0] + expr2 = children[2] + if op == "&": + return (sa.and_(expr1, expr2)) + elif op == "|": + return (sa.or_(expr1, expr2)) + return args + + def paren_expr(self, *args): + children = args[0] + return children[0] + + +class QueryFilterParser(): + + def __init__(self, fieldspec: Mapping[str, FieldSpecItem] = None) -> None: + self._fieldspec = fieldspec + self._parser = _parser + + def append_filter( + self, + sa_query: FilterableSQLQuery, + filter_expr: str, + ) -> FilterableSQLQuery: + """ + Parse the given filter expression and build the where clause based on the first target table from + the given SQLAlchemy query object. 
+ """ + if isinstance(sa_query, sa.sql.Select): + table = sa_query.froms[0] + elif isinstance(sa_query, sa.sql.Delete): + table = sa_query.table + elif isinstance(sa_query, sa.sql.Update): + table = sa_query.table + else: + raise ValueError('Unsupported SQLAlchemy query object type') + try: + ast = self._parser.parse(filter_expr) + where_clause = QueryFilterTransformer(table, self._fieldspec).transform(ast) + except LarkError as e: + raise ValueError(f"Query filter parsing error: {e}") + return sa_query.where(where_clause) diff --git a/src/ai/backend/manager/models/resource_policy.py b/src/ai/backend/manager/models/resource_policy.py new file mode 100644 index 0000000000..1603a15b74 --- /dev/null +++ b/src/ai/backend/manager/models/resource_policy.py @@ -0,0 +1,334 @@ +from __future__ import annotations + +import logging +from typing import ( + Any, + Dict, + Sequence, + TYPE_CHECKING, +) + +import graphene +from graphene.types.datetime import DateTime as GQLDateTime +import sqlalchemy as sa +from sqlalchemy.engine.row import Row +from sqlalchemy.dialects import postgresql as pgsql + +from ai.backend.common.logging import BraceStyleAdapter +from ai.backend.common.types import DefaultForUnspecified, ResourceSlot +from .base import ( + metadata, BigInt, EnumType, ResourceSlotColumn, + simple_db_mutate, + simple_db_mutate_returning_item, + set_if_set, + batch_result, +) +from .keypair import keypairs +from .user import UserRole + +if TYPE_CHECKING: + from .gql import GraphQueryContext + +log = BraceStyleAdapter(logging.getLogger('ai.backend.manager.models')) + +__all__: Sequence[str] = ( + 'keypair_resource_policies', + 'KeyPairResourcePolicy', + 'DefaultForUnspecified', + 'CreateKeyPairResourcePolicy', + 'ModifyKeyPairResourcePolicy', + 'DeleteKeyPairResourcePolicy', +) + + +keypair_resource_policies = sa.Table( + 'keypair_resource_policies', metadata, + sa.Column('name', sa.String(length=256), primary_key=True), + sa.Column('created_at', sa.DateTime(timezone=True), + server_default=sa.func.now()), + sa.Column('default_for_unspecified', + EnumType(DefaultForUnspecified), + default=DefaultForUnspecified.LIMITED, + nullable=False), + sa.Column('total_resource_slots', ResourceSlotColumn(), nullable=False), + sa.Column('max_session_lifetime', sa.Integer(), nullable=False, server_default=sa.text('0')), + sa.Column('max_concurrent_sessions', sa.Integer(), nullable=False), + sa.Column('max_containers_per_session', sa.Integer(), nullable=False), + sa.Column('max_vfolder_count', sa.Integer(), nullable=False), + sa.Column('max_vfolder_size', sa.BigInteger(), nullable=False), + sa.Column('idle_timeout', sa.BigInteger(), nullable=False), + sa.Column('allowed_vfolder_hosts', pgsql.ARRAY(sa.String), nullable=False), + # TODO: implement with a many-to-many association table + # sa.Column('allowed_scaling_groups', sa.Array(sa.String), nullable=False), +) + + +class KeyPairResourcePolicy(graphene.ObjectType): + name = graphene.String() + created_at = GQLDateTime() + default_for_unspecified = graphene.String() + total_resource_slots = graphene.JSONString() + max_session_lifetime = graphene.Int() + max_concurrent_sessions = graphene.Int() + max_containers_per_session = graphene.Int() + idle_timeout = BigInt() + max_vfolder_count = graphene.Int() + max_vfolder_size = BigInt() + allowed_vfolder_hosts = graphene.List(lambda: graphene.String) + + @classmethod + def from_row( + cls, + ctx: GraphQueryContext, + row: Row | None, + ) -> KeyPairResourcePolicy | None: + if row is None: + return None + return cls( + 
name=row['name'], + created_at=row['created_at'], + default_for_unspecified=row['default_for_unspecified'].name, + total_resource_slots=row['total_resource_slots'].to_json(), + max_session_lifetime=row['max_session_lifetime'], + max_concurrent_sessions=row['max_concurrent_sessions'], + max_containers_per_session=row['max_containers_per_session'], + idle_timeout=row['idle_timeout'], + max_vfolder_count=row['max_vfolder_count'], + max_vfolder_size=row['max_vfolder_size'], + allowed_vfolder_hosts=row['allowed_vfolder_hosts'], + ) + + @classmethod + async def load_all(cls, ctx: GraphQueryContext) -> Sequence[KeyPairResourcePolicy]: + query = ( + sa.select([keypair_resource_policies]) + .select_from(keypair_resource_policies) + ) + async with ctx.db.begin_readonly() as conn: + return [ + obj async for r in (await conn.stream(query)) + if (obj := cls.from_row(ctx, r)) is not None + ] + + @classmethod + async def load_all_user( + cls, + ctx: GraphQueryContext, + access_key: str, + ) -> Sequence[KeyPairResourcePolicy]: + j = sa.join( + keypairs, keypair_resource_policies, + keypairs.c.resource_policy == keypair_resource_policies.c.name, + ) + query = ( + sa.select([keypair_resource_policies]) + .select_from(j) + .where( + keypairs.c.user_id == ( + sa.select([keypairs.c.user_id]) + .select_from(keypairs) + .where(keypairs.c.access_key == access_key) + .as_scalar() + ), + ) + ) + async with ctx.db.begin_readonly() as conn: + return [ + obj async for r in (await conn.stream(query)) + if (obj := cls.from_row(ctx, r)) is not None + ] + + @classmethod + async def batch_load_by_name( + cls, + ctx: GraphQueryContext, + names: Sequence[str], + ) -> Sequence[KeyPairResourcePolicy | None]: + query = ( + sa.select([keypair_resource_policies]) + .select_from(keypair_resource_policies) + .where(keypair_resource_policies.c.name.in_(names)) + .order_by(keypair_resource_policies.c.name) + ) + async with ctx.db.begin_readonly() as conn: + return await batch_result( + ctx, conn, query, cls, + names, lambda row: row['name'], + ) + + @classmethod + async def batch_load_by_name_user( + cls, + ctx: GraphQueryContext, + names: Sequence[str], + ) -> Sequence[KeyPairResourcePolicy | None]: + access_key = ctx.access_key + j = sa.join( + keypairs, keypair_resource_policies, + keypairs.c.resource_policy == keypair_resource_policies.c.name, + ) + query = ( + sa.select([keypair_resource_policies]) + .select_from(j) + .where( + (keypair_resource_policies.c.name.in_(names)) & + (keypairs.c.access_key == access_key), + ) + .order_by(keypair_resource_policies.c.name) + ) + async with ctx.db.begin_readonly() as conn: + return await batch_result( + ctx, conn, query, cls, + names, lambda row: row['name'], + ) + + @classmethod + async def batch_load_by_ak( + cls, + ctx: GraphQueryContext, + access_keys: Sequence[str], + ) -> Sequence[KeyPairResourcePolicy]: + j = sa.join( + keypairs, keypair_resource_policies, + keypairs.c.resource_policy == keypair_resource_policies.c.name, + ) + query = ( + sa.select([keypair_resource_policies]) + .select_from(j) + .where((keypairs.c.access_key.in_(access_keys))) + .order_by(keypair_resource_policies.c.name) + ) + async with ctx.db.begin_readonly() as conn: + return [ + obj async for r in (await conn.stream(query)) + if (obj := cls.from_row(ctx, r)) is not None + ] + + +class CreateKeyPairResourcePolicyInput(graphene.InputObjectType): + default_for_unspecified = graphene.String(required=True) + total_resource_slots = graphene.JSONString(required=True) + max_session_lifetime = 
graphene.Int(required=True, default_value=0) + max_concurrent_sessions = graphene.Int(required=True) + max_containers_per_session = graphene.Int(required=True) + idle_timeout = BigInt(required=True) + max_vfolder_count = graphene.Int(required=True) + max_vfolder_size = BigInt(required=True) + allowed_vfolder_hosts = graphene.List(lambda: graphene.String) + + +class ModifyKeyPairResourcePolicyInput(graphene.InputObjectType): + default_for_unspecified = graphene.String(required=False) + total_resource_slots = graphene.JSONString(required=False) + max_session_lifetime = graphene.Int(required=False) + max_concurrent_sessions = graphene.Int(required=False) + max_containers_per_session = graphene.Int(required=False) + idle_timeout = BigInt(required=False) + max_vfolder_count = graphene.Int(required=False) + max_vfolder_size = BigInt(required=False) + allowed_vfolder_hosts = graphene.List(lambda: graphene.String, required=False) + + +class CreateKeyPairResourcePolicy(graphene.Mutation): + + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + name = graphene.String(required=True) + props = CreateKeyPairResourcePolicyInput(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + resource_policy = graphene.Field(lambda: KeyPairResourcePolicy, required=False) + + @classmethod + async def mutate( + cls, + root, + info: graphene.ResolveInfo, + name: str, + props: CreateKeyPairResourcePolicyInput, + ) -> CreateKeyPairResourcePolicy: + data = { + 'name': name, + 'default_for_unspecified': + DefaultForUnspecified[props.default_for_unspecified], + 'total_resource_slots': ResourceSlot.from_user_input( + props.total_resource_slots, None), + 'max_session_lifetime': props.max_session_lifetime, + 'max_concurrent_sessions': props.max_concurrent_sessions, + 'max_containers_per_session': props.max_containers_per_session, + 'idle_timeout': props.idle_timeout, + 'max_vfolder_count': props.max_vfolder_count, + 'max_vfolder_size': props.max_vfolder_size, + 'allowed_vfolder_hosts': props.allowed_vfolder_hosts, + } + insert_query = ( + sa.insert(keypair_resource_policies).values(data) + ) + return await simple_db_mutate_returning_item( + cls, info.context, insert_query, item_cls=KeyPairResourcePolicy, + ) + + +class ModifyKeyPairResourcePolicy(graphene.Mutation): + + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + name = graphene.String(required=True) + props = ModifyKeyPairResourcePolicyInput(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + + @classmethod + async def mutate( + cls, + root, + info: graphene.ResolveInfo, + name: str, + props: ModifyKeyPairResourcePolicyInput, + ) -> ModifyKeyPairResourcePolicy: + data: Dict[str, Any] = {} + set_if_set(props, data, 'default_for_unspecified', + clean_func=lambda v: DefaultForUnspecified[v]) + set_if_set(props, data, 'total_resource_slots', + clean_func=lambda v: ResourceSlot.from_user_input(v, None)) + set_if_set(props, data, 'max_session_lifetime') + set_if_set(props, data, 'max_concurrent_sessions') + set_if_set(props, data, 'max_containers_per_session') + set_if_set(props, data, 'idle_timeout') + set_if_set(props, data, 'max_vfolder_count') + set_if_set(props, data, 'max_vfolder_size') + set_if_set(props, data, 'allowed_vfolder_hosts') + update_query = ( + sa.update(keypair_resource_policies) + .values(data) + .where(keypair_resource_policies.c.name == name) + ) + return await simple_db_mutate(cls, info.context, update_query) + + +class DeleteKeyPairResourcePolicy(graphene.Mutation): + + allowed_roles = 
(UserRole.SUPERADMIN,) + + class Arguments: + name = graphene.String(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + + @classmethod + async def mutate( + cls, + root, + info: graphene.ResolveInfo, + name: str, + ) -> DeleteKeyPairResourcePolicy: + delete_query = ( + sa.delete(keypair_resource_policies) + .where(keypair_resource_policies.c.name == name) + ) + return await simple_db_mutate(cls, info.context, delete_query) diff --git a/src/ai/backend/manager/models/resource_preset.py b/src/ai/backend/manager/models/resource_preset.py new file mode 100644 index 0000000000..3f41d22387 --- /dev/null +++ b/src/ai/backend/manager/models/resource_preset.py @@ -0,0 +1,195 @@ +from __future__ import annotations + +import logging +from typing import ( + Any, + Dict, + Sequence, + TYPE_CHECKING, +) + +import graphene +import sqlalchemy as sa +from sqlalchemy.engine.row import Row + +from ai.backend.common.logging import BraceStyleAdapter +from ai.backend.common.types import ResourceSlot +from .base import ( + metadata, BigInt, BinarySize, ResourceSlotColumn, + simple_db_mutate, + simple_db_mutate_returning_item, + set_if_set, + batch_result, +) +from .user import UserRole + +if TYPE_CHECKING: + from .gql import GraphQueryContext + +log = BraceStyleAdapter(logging.getLogger('ai.backend.manager.models')) + +__all__: Sequence[str] = ( + 'resource_presets', + 'ResourcePreset', + 'CreateResourcePreset', + 'ModifyResourcePreset', + 'DeleteResourcePreset', +) + + +resource_presets = sa.Table( + 'resource_presets', metadata, + sa.Column('name', sa.String(length=256), primary_key=True), + sa.Column('resource_slots', ResourceSlotColumn(), nullable=False), + sa.Column('shared_memory', sa.BigInteger(), nullable=True), +) + + +class ResourcePreset(graphene.ObjectType): + name = graphene.String() + resource_slots = graphene.JSONString() + shared_memory = BigInt() + + @classmethod + def from_row( + cls, + ctx: GraphQueryContext, + row: Row | None, + ) -> ResourcePreset | None: + if row is None: + return None + shared_memory = str(row['shared_memory']) if row['shared_memory'] else None + return cls( + name=row['name'], + resource_slots=row['resource_slots'].to_json(), + shared_memory=shared_memory, + ) + + @classmethod + async def load_all(cls, ctx: GraphQueryContext) -> Sequence[ResourcePreset]: + query = ( + sa.select([resource_presets]) + .select_from(resource_presets) + ) + async with ctx.db.begin_readonly() as conn: + return [ + obj async for r in (await conn.stream(query)) + if (obj := cls.from_row(ctx, r)) is not None + ] + + @classmethod + async def batch_load_by_name( + cls, + ctx: GraphQueryContext, + names: Sequence[str], + ) -> Sequence[ResourcePreset | None]: + query = ( + sa.select([resource_presets]) + .select_from(resource_presets) + .where(resource_presets.c.name.in_(names)) + .order_by(resource_presets.c.name) + ) + async with ctx.db.begin_readonly() as conn: + return await batch_result( + ctx, conn, query, cls, + names, lambda row: row['name'], + ) + + +class CreateResourcePresetInput(graphene.InputObjectType): + resource_slots = graphene.JSONString(required=True) + shared_memory = graphene.String(required=False) + + +class ModifyResourcePresetInput(graphene.InputObjectType): + resource_slots = graphene.JSONString(required=False) + shared_memory = graphene.String(required=False) + + +class CreateResourcePreset(graphene.Mutation): + + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + name = graphene.String(required=True) + props = 
CreateResourcePresetInput(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + resource_preset = graphene.Field(lambda: ResourcePreset, required=False) + + @classmethod + async def mutate( + cls, + root, + info: graphene.ResolveInfo, + name: str, + props: CreateResourcePresetInput, + ) -> CreateResourcePreset: + data = { + 'name': name, + 'resource_slots': ResourceSlot.from_user_input( + props.resource_slots, None), + 'shared_memory': BinarySize.from_str(props.shared_memory) if props.shared_memory else None, + } + insert_query = sa.insert(resource_presets).values(data) + return await simple_db_mutate_returning_item( + cls, info.context, insert_query, + item_cls=ResourcePreset, + ) + + +class ModifyResourcePreset(graphene.Mutation): + + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + name = graphene.String(required=True) + props = ModifyResourcePresetInput(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + + @classmethod + async def mutate( + cls, + root, + info: graphene.ResolveInfo, + name: str, + props: ModifyResourcePresetInput, + ) -> ModifyResourcePreset: + data: Dict[str, Any] = {} + set_if_set(props, data, 'resource_slots', + clean_func=lambda v: ResourceSlot.from_user_input(v, None)) + set_if_set(props, data, 'shared_memory', + clean_func=lambda v: BinarySize.from_str(v) if v else None) + update_query = ( + sa.update(resource_presets) + .values(data) + .where(resource_presets.c.name == name) + ) + return await simple_db_mutate(cls, info.context, update_query) + + +class DeleteResourcePreset(graphene.Mutation): + + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + name = graphene.String(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + + @classmethod + async def mutate( + cls, + root, + info: graphene.ResolveInfo, + name: str, + ) -> DeleteResourcePreset: + delete_query = ( + sa.delete(resource_presets) + .where(resource_presets.c.name == name) + ) + return await simple_db_mutate(cls, info.context, delete_query) diff --git a/src/ai/backend/manager/models/scaling_group.py b/src/ai/backend/manager/models/scaling_group.py new file mode 100644 index 0000000000..fd7a2e9fec --- /dev/null +++ b/src/ai/backend/manager/models/scaling_group.py @@ -0,0 +1,707 @@ +from __future__ import annotations + +from datetime import timedelta +from typing import ( + Any, + Dict, + Mapping, + Sequence, + Set, + TYPE_CHECKING, + Union, +) +import uuid + +import attr +import graphene +from graphene.types.datetime import DateTime as GQLDateTime +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql as pgsql +from sqlalchemy.engine.row import Row +from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection +import trafaret as t + +from ai.backend.common import validators as tx +from ai.backend.common.types import SessionTypes, JSONSerializableMixin + +from .base import ( + metadata, + simple_db_mutate, + simple_db_mutate_returning_item, + set_if_set, + batch_result, + batch_multiresult, + StructuredJSONObjectColumn, +) +from .group import resolve_group_name_or_id +from .user import UserRole + +if TYPE_CHECKING: + from .gql import GraphQueryContext + +__all__: Sequence[str] = ( + # table defs + 'scaling_groups', + 'sgroups_for_domains', + 'sgroups_for_groups', + 'sgroups_for_keypairs', + # functions + 'query_allowed_sgroups', + 'ScalingGroup', + 'CreateScalingGroup', + 'ModifyScalingGroup', + 'DeleteScalingGroup', + 'AssociateScalingGroupWithDomain', + 'AssociateScalingGroupWithUserGroup', + 
'AssociateScalingGroupWithKeyPair', + 'DisassociateScalingGroupWithDomain', + 'DisassociateScalingGroupWithUserGroup', + 'DisassociateScalingGroupWithKeyPair', +) + + +@attr.define(slots=True) +class ScalingGroupOpts(JSONSerializableMixin): + allowed_session_types: list[SessionTypes] = attr.Factory( + lambda: [SessionTypes.INTERACTIVE, SessionTypes.BATCH], + ) + pending_timeout: timedelta = timedelta(seconds=0) + config: Mapping[str, Any] = attr.Factory(dict) + + def to_json(self) -> dict[str, Any]: + return { + "allowed_session_types": [ + item.value for item in self.allowed_session_types + ], + "pending_timeout": self.pending_timeout.total_seconds(), + "config": self.config, + } + + @classmethod + def from_json(cls, obj: Mapping[str, Any]) -> ScalingGroupOpts: + return cls(**cls.as_trafaret().check(obj)) + + @classmethod + def as_trafaret(cls) -> t.Trafaret: + return t.Dict({ + t.Key('allowed_session_types', default=['interactive', 'batch']): + t.List(tx.Enum(SessionTypes), min_length=1), + t.Key('pending_timeout', default=0): + tx.TimeDuration(allow_negative=False), + # Each scheduler impl refers an additional "config" key. + t.Key("config", default={}): t.Mapping(t.String, t.Any), + }).allow_extra('*') + + +scaling_groups = sa.Table( + 'scaling_groups', metadata, + sa.Column('name', sa.String(length=64), primary_key=True), + sa.Column('description', sa.String(length=512)), + sa.Column('is_active', sa.Boolean, index=True, default=True), + sa.Column('created_at', sa.DateTime(timezone=True), + server_default=sa.func.now()), + sa.Column('wsproxy_addr', sa.String(length=1024), nullable=True), + sa.Column('driver', sa.String(length=64), nullable=False), + sa.Column('driver_opts', pgsql.JSONB(), nullable=False, default={}), + sa.Column('scheduler', sa.String(length=64), nullable=False), + sa.Column( + 'scheduler_opts', StructuredJSONObjectColumn(ScalingGroupOpts), + nullable=False, default={}, + ), +) + + +# When scheduling, we take the union of allowed scaling groups for +# each domain, group, and keypair. 
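To make the union rule above concrete, here is a tiny worked illustration that mirrors the variable names used in query_allowed_sgroups further below; the scaling-group names are hypothetical examples, not values shipped with the system.

from_domain = {"default", "gpu-a100"}   # allowed via the session's domain
from_group = {"default"}                # allowed via the user's group
from_keypair = {"private-pool"}         # allowed via the requesting keypair
allowed = from_domain | from_group | from_keypair
# -> {"default", "gpu-a100", "private-pool"}: the scheduler may place the session
#    in any of these pools, subject to each scaling group's is_active flag.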
+ + +sgroups_for_domains = sa.Table( + 'sgroups_for_domains', metadata, + sa.Column('scaling_group', + sa.ForeignKey('scaling_groups.name', + onupdate='CASCADE', + ondelete='CASCADE'), + index=True, nullable=False), + sa.Column('domain', + sa.ForeignKey('domains.name', + onupdate='CASCADE', + ondelete='CASCADE'), + index=True, nullable=False), + sa.UniqueConstraint('scaling_group', 'domain', name='uq_sgroup_domain'), +) + + +sgroups_for_groups = sa.Table( + 'sgroups_for_groups', metadata, + sa.Column('scaling_group', + sa.ForeignKey('scaling_groups.name', + onupdate='CASCADE', + ondelete='CASCADE'), + index=True, nullable=False), + sa.Column('group', + sa.ForeignKey('groups.id', + onupdate='CASCADE', + ondelete='CASCADE'), + index=True, nullable=False), + sa.UniqueConstraint('scaling_group', 'group', name='uq_sgroup_ugroup'), +) + + +sgroups_for_keypairs = sa.Table( + 'sgroups_for_keypairs', metadata, + sa.Column('scaling_group', + sa.ForeignKey('scaling_groups.name', + onupdate='CASCADE', + ondelete='CASCADE'), + index=True, nullable=False), + sa.Column('access_key', + sa.ForeignKey('keypairs.access_key', + onupdate='CASCADE', + ondelete='CASCADE'), + index=True, nullable=False), + sa.UniqueConstraint('scaling_group', 'access_key', name='uq_sgroup_akey'), +) + + +async def query_allowed_sgroups( + db_conn: SAConnection, + domain_name: str, + group: Union[uuid.UUID, str], + access_key: str, +) -> Sequence[Row]: + query = ( + sa.select([sgroups_for_domains]) + .where(sgroups_for_domains.c.domain == domain_name) + ) + result = await db_conn.execute(query) + from_domain = {row['scaling_group'] for row in result} + + group_id = await resolve_group_name_or_id(db_conn, domain_name, group) + from_group: Set[str] + if group_id is None: + from_group = set() # empty + else: + query = ( + sa.select([sgroups_for_groups]) + .where( + (sgroups_for_groups.c.group == group_id), + ) + ) + result = await db_conn.execute(query) + from_group = {row['scaling_group'] for row in result} + + query = (sa.select([sgroups_for_keypairs]) + .where(sgroups_for_keypairs.c.access_key == access_key)) + result = await db_conn.execute(query) + from_keypair = {row['scaling_group'] for row in result} + + sgroups = from_domain | from_group | from_keypair + query = ( + sa.select([scaling_groups]) + .where( + (scaling_groups.c.name.in_(sgroups)) & + (scaling_groups.c.is_active), + ) + .order_by(scaling_groups.c.name) + ) + result = await db_conn.execute(query) + return [row for row in result] + + +class ScalingGroup(graphene.ObjectType): + name = graphene.String() + description = graphene.String() + is_active = graphene.Boolean() + created_at = GQLDateTime() + wsproxy_addr = graphene.String() + driver = graphene.String() + driver_opts = graphene.JSONString() + scheduler = graphene.String() + scheduler_opts = graphene.JSONString() + + @classmethod + def from_row( + cls, + ctx: GraphQueryContext, + row: Row | None, + ) -> ScalingGroup | None: + if row is None: + return None + return cls( + name=row['name'], + description=row['description'], + is_active=row['is_active'], + created_at=row['created_at'], + wsproxy_addr=row['wsproxy_addr'], + driver=row['driver'], + driver_opts=row['driver_opts'], + scheduler=row['scheduler'], + scheduler_opts=row['scheduler_opts'].to_json(), + ) + + @classmethod + async def load_all( + cls, + ctx: GraphQueryContext, + *, + is_active: bool = None, + ) -> Sequence[ScalingGroup]: + query = sa.select([scaling_groups]).select_from(scaling_groups) + if is_active is not None: + query = 
query.where(scaling_groups.c.is_active == is_active) + async with ctx.db.begin_readonly() as conn: + return [ + obj async for row in (await conn.stream(query)) + if (obj := cls.from_row(ctx, row)) is not None + ] + + @classmethod + async def load_by_domain( + cls, + ctx: GraphQueryContext, + domain: str, + *, + is_active: bool = None, + ) -> Sequence[ScalingGroup]: + j = sa.join( + scaling_groups, sgroups_for_domains, + scaling_groups.c.name == sgroups_for_domains.c.scaling_group) + query = ( + sa.select([scaling_groups]) + .select_from(j) + .where(sgroups_for_domains.c.domain == domain) + ) + if is_active is not None: + query = query.where(scaling_groups.c.is_active == is_active) + async with ctx.db.begin_readonly() as conn: + return [ + obj async for row in (await conn.stream(query)) + if (obj := cls.from_row(ctx, row)) is not None + ] + + @classmethod + async def load_by_group( + cls, + ctx: GraphQueryContext, + group: uuid.UUID, + *, + is_active: bool = None, + ) -> Sequence[ScalingGroup]: + j = sa.join( + scaling_groups, sgroups_for_groups, + scaling_groups.c.name == sgroups_for_groups.c.scaling_group, + ) + query = ( + sa.select([scaling_groups]) + .select_from(j) + .where(sgroups_for_groups.c.group == group) + ) + if is_active is not None: + query = query.where(scaling_groups.c.is_active == is_active) + async with ctx.db.begin_readonly() as conn: + return [ + obj async for row in (await conn.stream(query)) + if (obj := cls.from_row(ctx, row)) is not None + ] + + @classmethod + async def load_by_keypair( + cls, + ctx: GraphQueryContext, + access_key: str, + *, + is_active: bool = None, + ) -> Sequence[ScalingGroup]: + j = sa.join( + scaling_groups, sgroups_for_keypairs, + scaling_groups.c.name == sgroups_for_keypairs.c.scaling_group, + ) + query = ( + sa.select([scaling_groups]) + .select_from(j) + .where(sgroups_for_keypairs.c.access_key == access_key) + ) + if is_active is not None: + query = query.where(scaling_groups.c.is_active == is_active) + async with ctx.db.begin_readonly() as conn: + return [ + obj async for row in (await conn.stream(query)) + if (obj := cls.from_row(ctx, row)) is not None + ] + + @classmethod + async def batch_load_by_group( + cls, + ctx: GraphQueryContext, + group_ids: Sequence[uuid.UUID], + ) -> Sequence[Sequence[ScalingGroup | None]]: + j = sa.join( + scaling_groups, sgroups_for_groups, + scaling_groups.c.name == sgroups_for_groups.c.scaling_group, + ) + query = ( + sa.select([scaling_groups, sgroups_for_groups.c.group]) + .select_from(j) + .where(sgroups_for_groups.c.group.in_(group_ids)) + ) + async with ctx.db.begin_readonly() as conn: + return await batch_multiresult( + ctx, conn, query, cls, + group_ids, lambda row: row['group'], + ) + + @classmethod + async def batch_load_by_name( + cls, + ctx: GraphQueryContext, + names: Sequence[str], + ) -> Sequence[ScalingGroup | None]: + query = ( + sa.select([scaling_groups]) + .select_from(scaling_groups) + .where(scaling_groups.c.name.in_(names)) + ) + async with ctx.db.begin_readonly() as conn: + return await batch_result( + ctx, conn, query, cls, + names, lambda row: row['name'], + ) + + +class CreateScalingGroupInput(graphene.InputObjectType): + description = graphene.String(required=False, default='') + is_active = graphene.Boolean(required=False, default=True) + wsproxy_addr = graphene.String(required=False) + driver = graphene.String(required=True) + driver_opts = graphene.JSONString(required=False, default={}) + scheduler = graphene.String(required=True) + scheduler_opts = 
graphene.JSONString(required=False, default={}) + + +class ModifyScalingGroupInput(graphene.InputObjectType): + description = graphene.String(required=False) + is_active = graphene.Boolean(required=False) + wsproxy_addr = graphene.String(required=False) + driver = graphene.String(required=False) + driver_opts = graphene.JSONString(required=False) + scheduler = graphene.String(required=False) + scheduler_opts = graphene.JSONString(required=False) + + +class CreateScalingGroup(graphene.Mutation): + + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + name = graphene.String(required=True) + props = CreateScalingGroupInput(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + scaling_group = graphene.Field(lambda: ScalingGroup, required=False) + + @classmethod + async def mutate( + cls, + root, + info: graphene.ResolveInfo, + name: str, + props: CreateScalingGroupInput, + ) -> CreateScalingGroup: + data = { + 'name': name, + 'description': props.description, + 'is_active': bool(props.is_active), + 'wsproxy_addr': props.wsproxy_addr, + 'driver': props.driver, + 'driver_opts': props.driver_opts, + 'scheduler': props.scheduler, + 'scheduler_opts': ScalingGroupOpts.from_json(props.scheduler_opts), + } + insert_query = ( + sa.insert(scaling_groups).values(data) + ) + return await simple_db_mutate_returning_item( + cls, info.context, insert_query, item_cls=ScalingGroup, + ) + + +class ModifyScalingGroup(graphene.Mutation): + + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + name = graphene.String(required=True) + props = ModifyScalingGroupInput(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + + @classmethod + async def mutate( + cls, + root, + info: graphene.ResolveInfo, + name: str, + props: ModifyScalingGroupInput, + ) -> ModifyScalingGroup: + data: Dict[str, Any] = {} + set_if_set(props, data, 'description') + set_if_set(props, data, 'is_active') + set_if_set(props, data, 'driver') + set_if_set(props, data, 'wsproxy_addr') + set_if_set(props, data, 'driver_opts') + set_if_set(props, data, 'scheduler') + set_if_set(props, data, 'scheduler_opts', clean_func=lambda v: ScalingGroupOpts.from_json(v)) + update_query = ( + sa.update(scaling_groups) + .values(data) + .where(scaling_groups.c.name == name) + ) + return await simple_db_mutate(cls, info.context, update_query) + + +class DeleteScalingGroup(graphene.Mutation): + + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + name = graphene.String(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + + @classmethod + async def mutate( + cls, + root, + info: graphene.ResolveInfo, + name: str, + ) -> DeleteScalingGroup: + delete_query = ( + sa.delete(scaling_groups) + .where(scaling_groups.c.name == name) + ) + return await simple_db_mutate(cls, info.context, delete_query) + + +class AssociateScalingGroupWithDomain(graphene.Mutation): + + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + scaling_group = graphene.String(required=True) + domain = graphene.String(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + + @classmethod + async def mutate( + cls, + root, + info: graphene.ResolveInfo, + scaling_group: str, + domain: str, + ) -> AssociateScalingGroupWithDomain: + insert_query = ( + sa.insert(sgroups_for_domains) + .values({ + 'scaling_group': scaling_group, + 'domain': domain, + }) + ) + return await simple_db_mutate(cls, info.context, insert_query) + + +class DisassociateScalingGroupWithDomain(graphene.Mutation): + + allowed_roles 
= (UserRole.SUPERADMIN,) + + class Arguments: + scaling_group = graphene.String(required=True) + domain = graphene.String(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + + @classmethod + async def mutate( + cls, + root, + info: graphene.ResolveInfo, + scaling_group: str, + domain: str, + ) -> DisassociateScalingGroupWithDomain: + delete_query = ( + sa.delete(sgroups_for_domains) + .where( + (sgroups_for_domains.c.scaling_group == scaling_group) & + (sgroups_for_domains.c.domain == domain), + ) + ) + return await simple_db_mutate(cls, info.context, delete_query) + + +class DisassociateAllScalingGroupsWithDomain(graphene.Mutation): + + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + domain = graphene.String(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + + @classmethod + async def mutate( + cls, + root, + info: graphene.ResolveInfo, + domain: str, + ) -> DisassociateAllScalingGroupsWithDomain: + delete_query = ( + sa.delete(sgroups_for_domains) + .where(sgroups_for_domains.c.domain == domain) + ) + return await simple_db_mutate(cls, info.context, delete_query) + + +class AssociateScalingGroupWithUserGroup(graphene.Mutation): + + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + scaling_group = graphene.String(required=True) + user_group = graphene.UUID(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + + @classmethod + async def mutate( + cls, + root, + info: graphene.ResolveInfo, + scaling_group: str, + user_group: uuid.UUID, + ) -> AssociateScalingGroupWithUserGroup: + insert_query = ( + sa.insert(sgroups_for_groups) + .values({ + 'scaling_group': scaling_group, + 'group': user_group, + }) + ) + return await simple_db_mutate(cls, info.context, insert_query) + + +class DisassociateScalingGroupWithUserGroup(graphene.Mutation): + + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + scaling_group = graphene.String(required=True) + user_group = graphene.UUID(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + + @classmethod + async def mutate( + cls, + root, + info: graphene.ResolveInfo, + scaling_group: str, + user_group: uuid.UUID, + ) -> DisassociateScalingGroupWithUserGroup: + delete_query = ( + sa.delete(sgroups_for_groups) + .where( + (sgroups_for_groups.c.scaling_group == scaling_group) & + (sgroups_for_groups.c.group == user_group), + ) + ) + return await simple_db_mutate(cls, info.context, delete_query) + + +class DisassociateAllScalingGroupsWithGroup(graphene.Mutation): + + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + user_group = graphene.UUID(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + + @classmethod + async def mutate( + cls, + root, + info: graphene.ResolveInfo, + user_group: uuid.UUID, + ) -> DisassociateAllScalingGroupsWithGroup: + delete_query = ( + sa.delete(sgroups_for_groups) + .where(sgroups_for_groups.c.group == user_group) + ) + return await simple_db_mutate(cls, info.context, delete_query) + + +class AssociateScalingGroupWithKeyPair(graphene.Mutation): + + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + scaling_group = graphene.String(required=True) + access_key = graphene.String(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + + @classmethod + async def mutate( + cls, + root, + info: graphene.ResolveInfo, + scaling_group: str, + access_key: str, + ) -> AssociateScalingGroupWithKeyPair: + insert_query = ( + sa.insert(sgroups_for_keypairs) + .values({ + 'scaling_group': 
scaling_group, + 'access_key': access_key, + }) + ) + return await simple_db_mutate(cls, info.context, insert_query) + + +class DisassociateScalingGroupWithKeyPair(graphene.Mutation): + + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + scaling_group = graphene.String(required=True) + access_key = graphene.String(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + + @classmethod + async def mutate( + cls, + root, + info: graphene.ResolveInfo, + scaling_group: str, + access_key: str, + ) -> DisassociateScalingGroupWithKeyPair: + delete_query = ( + sa.delete(sgroups_for_keypairs) + .where( + (sgroups_for_keypairs.c.scaling_group == scaling_group) & + (sgroups_for_keypairs.c.access_key == access_key), + ) + ) + return await simple_db_mutate(cls, info.context, delete_query) diff --git a/src/ai/backend/manager/models/session_template.py b/src/ai/backend/manager/models/session_template.py new file mode 100644 index 0000000000..5f24a3b012 --- /dev/null +++ b/src/ai/backend/manager/models/session_template.py @@ -0,0 +1,246 @@ +from __future__ import annotations + +import enum +from typing import ( + Any, + Iterable, + List, + Mapping, + Sequence, +) +import uuid + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql as pgsql +from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection +import trafaret as t + +from ai.backend.common import validators as tx +from ai.backend.common.types import SessionTypes + +from ..defs import DEFAULT_ROLE +from ..exceptions import InvalidArgument +from .base import metadata, GUID, IDColumn, EnumType +from .user import UserRole +from .vfolder import verify_vfolder_name + +__all__: Sequence[str] = ( + 'TemplateType', 'session_templates', 'query_accessible_session_templates', +) + + +class TemplateType(str, enum.Enum): + TASK = 'task' + CLUSTER = 'cluster' + + +session_templates = sa.Table( + 'session_templates', metadata, + IDColumn('id'), + sa.Column('created_at', sa.DateTime(timezone=True), + server_default=sa.func.now(), index=True), + sa.Column('is_active', sa.Boolean, default=True), + + sa.Column('domain_name', sa.String(length=64), sa.ForeignKey('domains.name'), nullable=False), + sa.Column('group_id', GUID, sa.ForeignKey('groups.id'), nullable=True), + sa.Column('user_uuid', GUID, sa.ForeignKey('users.uuid'), index=True, nullable=False), + sa.Column('type', EnumType(TemplateType), nullable=False, server_default='TASK', index=True), + + sa.Column('name', sa.String(length=128), nullable=True), + sa.Column('template', pgsql.JSONB(), nullable=False), +) + + +task_template_v1 = t.Dict({ + tx.AliasedKey(['api_version', 'apiVersion']): t.String, + t.Key('kind'): t.Enum('taskTemplate', 'task_template'), + t.Key('metadata'): t.Dict({ + t.Key('name'): t.String, + t.Key('tag', default=None): t.Null | t.String, + }), + t.Key('spec'): t.Dict({ + tx.AliasedKey(['type', 'session_type', 'sessionType'], + default='interactive') >> 'session_type': tx.Enum(SessionTypes), + t.Key('kernel'): t.Dict({ + t.Key('image'): t.String, + t.Key('architecture', default='x86_64'): t.Null | t.String, + t.Key('environ', default={}): t.Null | t.Mapping(t.String, t.String), + t.Key('run', default=None): t.Null | t.Dict({ + t.Key('bootstrap', default=None): t.Null | t.String, + tx.AliasedKey(['startup', 'startup_command', 'startupCommand'], + default=None) >> 'startup_command': t.Null | t.String, + }), + t.Key('git', default=None): t.Null | t.Dict({ + t.Key('repository'): t.String, + t.Key('commit', default=None): t.Null | t.String, + 
t.Key('branch', default=None): t.Null | t.String, + t.Key('credential', default=None): t.Null | t.Dict({ + t.Key('username'): t.String, + t.Key('password'): t.String, + }), + tx.AliasedKey(['destination_dir', 'destinationDir'], + default=None) >> 'dest_dir': t.Null | t.String, + }), + }), + t.Key('scaling_group', default=None): t.Null | t.String, + t.Key('mounts', default={}): t.Null | t.Mapping(t.String, t.Any), + t.Key('resources', default=None): t.Null | t.Mapping(t.String, t.Any), + tx.AliasedKey(['agent_list', 'agentList'], + default=None) >> 'agent_list': t.Null | t.List(t.String), + }), +}).allow_extra('*') + + +def check_task_template(raw_data: Mapping[str, Any]) -> Mapping[str, Any]: + data = task_template_v1.check(raw_data) + if mounts := data['spec'].get('mounts'): + for p in mounts.values(): + if p is None: + continue + if p.startswith("/home/work/"): + p = p.replace("/home/work/", "") + if not verify_vfolder_name(p): + raise InvalidArgument(f'Path {p} is reserved for internal operations.') + return data + + +cluster_template_v1 = t.Dict({ + tx.AliasedKey(['api_version', 'apiVersion']): t.String, + t.Key('kind'): t.Enum('clusterTemplate', 'cluster_template'), + t.Key('mode'): t.Enum('single-node', 'multi-node'), + t.Key('metadata'): t.Dict({ + t.Key('name'): t.String, + }), + t.Key('spec'): t.Dict({ + t.Key('environ', default={}): t.Null | t.Mapping(t.String, t.String), + t.Key('mounts', default={}): t.Null | t.Mapping(t.String, t.Any), + t.Key('nodes'): t.List(t.Dict({ + t.Key('role'): t.String, + tx.AliasedKey(['session_template', 'sessionTemplate']): tx.UUID, + t.Key('replicas', default=1): t.Int, + })), + }), +}).allow_extra('*') + + +def check_cluster_template(raw_data: Mapping[str, Any]) -> Mapping[str, Any]: + data = cluster_template_v1.check(raw_data) + defined_roles: List[str] = [] + for node in data['spec']['nodes']: + node['session_template'] = str(node['session_template']) + if node['role'] in defined_roles: + raise InvalidArgument("Each role can only be defined once") + if node['role'] == DEFAULT_ROLE and node['replicas'] != 1: + raise InvalidArgument( + f"One and only one {DEFAULT_ROLE} node must be created per cluster", + ) + defined_roles.append(node['role']) + if DEFAULT_ROLE not in defined_roles: + raise InvalidArgument( + f"One and only one {DEFAULT_ROLE} node must be created per cluster", + ) + return data + + +async def query_accessible_session_templates( + conn: SAConnection, + user_uuid: uuid.UUID, + template_type: TemplateType, + *, + user_role: UserRole = None, + domain_name: str = None, + allowed_types: Iterable[str] = ['user'], + extra_conds=None, +) -> List[Mapping[str, Any]]: + from ai.backend.manager.models import groups, users, association_groups_users as agus + entries: List[Mapping[str, Any]] = [] + if 'user' in allowed_types: + # Query user templates + j = (session_templates.join(users, session_templates.c.user_uuid == users.c.uuid)) + query = ( + sa.select([ + session_templates.c.name, + session_templates.c.id, + session_templates.c.created_at, + session_templates.c.user_uuid, + session_templates.c.group_id, + users.c.email, + ]) + .select_from(j) + .where( + (session_templates.c.user_uuid == user_uuid) & + session_templates.c.is_active & + (session_templates.c.type == template_type), + ) + ) + if extra_conds is not None: + query = query.where(extra_conds) + result = await conn.execute(query) + for row in result: + entries.append({ + 'name': row.name, + 'id': row.id, + 'created_at': row.created_at, + 'is_owner': True, + 'user': 
str(row.user_uuid) if row.user_uuid else None, + 'group': str(row.group_id) if row.group_id else None, + 'user_email': row.email, + 'group_name': None, + }) + if 'group' in allowed_types: + # Query group session_templates + if user_role == UserRole.ADMIN or user_role == 'admin': + query = ( + sa.select([groups.c.id]) + .select_from(groups) + .where(groups.c.domain_name == domain_name) + ) + result = await conn.execute(query) + grps = result.fetchall() + group_ids = [g.id for g in grps] + else: + j = sa.join(agus, users, agus.c.user_id == users.c.uuid) + query = ( + sa.select([agus.c.group_id]) + .select_from(j) + .where(agus.c.user_id == user_uuid) + ) + result = await conn.execute(query) + grps = result.fetchall() + group_ids = [g.group_id for g in grps] + j = (session_templates.join(groups, session_templates.c.group_id == groups.c.id)) + query = ( + sa.select([ + session_templates.c.name, + session_templates.c.id, + session_templates.c.created_at, + session_templates.c.user_uuid, + session_templates.c.group_id, + groups.c.name, + ], use_labels=True) + .select_from(j) + .where( + session_templates.c.group_id.in_(group_ids) & + session_templates.c.is_active & + (session_templates.c.type == template_type), + ) + ) + if extra_conds is not None: + query = query.where(extra_conds) + if 'user' in allowed_types: + query = query.where(session_templates.c.user_uuid != user_uuid) + result = await conn.execute(query) + is_owner = (user_role == UserRole.ADMIN or user_role == 'admin') + for row in result: + entries.append({ + 'name': row.session_templates_name, + 'id': row.session_templates_id, + 'created_at': row.session_templates_created_at, + 'is_owner': is_owner, + 'user': (str(row.session_templates_user_uuid) if row.session_templates_user_uuid + else None), + 'group': str(row.session_templates_group_id) if row.session_templates_group_id else None, + 'user_email': None, + 'group_name': row.groups_name, + }) + return entries diff --git a/src/ai/backend/manager/models/storage.py b/src/ai/backend/manager/models/storage.py new file mode 100644 index 0000000000..f1625d8009 --- /dev/null +++ b/src/ai/backend/manager/models/storage.py @@ -0,0 +1,303 @@ +from __future__ import annotations + +import asyncio +from contextlib import asynccontextmanager as actxmgr +from contextvars import ContextVar +import itertools +import logging +from pathlib import PurePosixPath +from typing import ( + Any, + AsyncIterator, + Final, + Iterable, + List, + Mapping, + Sequence, + Tuple, + TypedDict, + TYPE_CHECKING, +) +from uuid import UUID + +from ai.backend.common.logging import BraceStyleAdapter +from ai.backend.common.types import HardwareMetadata + +import aiohttp +import attr +import graphene +import yarl + +from .base import ( + Item, PaginatedList, +) +from ..api.exceptions import VFolderOperationFailed +from ..exceptions import InvalidArgument +if TYPE_CHECKING: + from .gql import GraphQueryContext + +__all__ = ( + 'StorageProxyInfo', + 'VolumeInfo', + 'StorageSessionManager', + 'StorageVolume', +) + +log = BraceStyleAdapter(logging.getLogger(__name__)) + + +@attr.s(auto_attribs=True, slots=True, frozen=True) +class StorageProxyInfo: + session: aiohttp.ClientSession + secret: str + client_api_url: yarl.URL + manager_api_url: yarl.URL + + +AUTH_TOKEN_HDR: Final = 'X-BackendAI-Storage-Auth-Token' + +_ctx_volumes_cache: ContextVar[List[Tuple[str, VolumeInfo]]] = ContextVar('_ctx_volumes') + + +class VolumeInfo(TypedDict): + name: str + backend: str + path: str + fsprefix: str + capabilities: List[str] + + +class 
StorageSessionManager: + + _proxies: Mapping[str, StorageProxyInfo] + + def __init__(self, storage_config: Mapping[str, Any]) -> None: + self.config = storage_config + self._proxies = {} + for proxy_name, proxy_config in self.config['proxies'].items(): + connector = aiohttp.TCPConnector(ssl=proxy_config['ssl_verify']) + session = aiohttp.ClientSession(connector=connector) + self._proxies[proxy_name] = StorageProxyInfo( + session=session, + secret=proxy_config['secret'], + client_api_url=yarl.URL(proxy_config['client_api']), + manager_api_url=yarl.URL(proxy_config['manager_api']), + ) + + async def aclose(self) -> None: + close_aws = [] + for proxy_info in self._proxies.values(): + close_aws.append(proxy_info.session.close()) + await asyncio.gather(*close_aws, return_exceptions=True) + + @staticmethod + def split_host(vfolder_host: str) -> Tuple[str, str]: + proxy_name, _, volume_name = vfolder_host.partition(':') + return proxy_name, volume_name + + async def get_all_volumes(self) -> Iterable[Tuple[str, VolumeInfo]]: + try: + # per-asyncio-task cache + return _ctx_volumes_cache.get() + except LookupError: + pass + fetch_aws = [] + + async def _fetch( + proxy_name: str, + proxy_info: StorageProxyInfo, + ) -> Iterable[Tuple[str, VolumeInfo]]: + async with proxy_info.session.request( + 'GET', proxy_info.manager_api_url / 'volumes', + raise_for_status=True, + headers={AUTH_TOKEN_HDR: proxy_info.secret}, + ) as resp: + reply = await resp.json() + return ((proxy_name, volume_data) for volume_data in reply['volumes']) + + for proxy_name, proxy_info in self._proxies.items(): + fetch_aws.append(_fetch(proxy_name, proxy_info)) + results = [*itertools.chain(*await asyncio.gather(*fetch_aws))] + _ctx_volumes_cache.set(results) + return results + + async def get_mount_path( + self, + vfolder_host: str, + vfolder_id: UUID, + subpath: PurePosixPath = PurePosixPath("."), + ) -> str: + async with self.request( + vfolder_host, 'GET', 'folder/mount', + json={ + 'volume': self.split_host(vfolder_host)[1], + 'vfid': str(vfolder_id), + 'subpath': str(subpath), + }, + ) as (_, resp): + reply = await resp.json() + return reply['path'] + + @actxmgr + async def request( + self, + vfolder_host_or_proxy_name: str, + method: str, + request_relpath: str, + /, + *args, + **kwargs, + ) -> AsyncIterator[Tuple[yarl.URL, aiohttp.ClientResponse]]: + proxy_name, _ = self.split_host(vfolder_host_or_proxy_name) + try: + proxy_info = self._proxies[proxy_name] + except KeyError: + raise InvalidArgument('There is no such storage proxy', proxy_name) + headers = kwargs.pop('headers', {}) + headers[AUTH_TOKEN_HDR] = proxy_info.secret + async with proxy_info.session.request( + method, proxy_info.manager_api_url / request_relpath, + *args, + headers=headers, + **kwargs, + ) as client_resp: + if client_resp.status // 100 != 2: + try: + error_data = await client_resp.json() + raise VFolderOperationFailed( + extra_msg=error_data.pop("msg"), + extra_data=error_data, + ) + except aiohttp.ClientResponseError: + # when the response body is not JSON, just raise with status info. 
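+ # (client_resp.json() raises aiohttp.ContentTypeError, a ClientResponseError
+ # subclass, when the body cannot be decoded as JSON.)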
+ raise VFolderOperationFailed( + extra_msg=f"Storage proxy responded with " + f"{client_resp.status} {client_resp.reason}", + ) + yield proxy_info.client_api_url, client_resp + + +class StorageVolume(graphene.ObjectType): + + class Meta: + interfaces = (Item, ) + + # id: {proxy_name}:{name} + backend = graphene.String() + fsprefix = graphene.String() + path = graphene.String() + capabilities = graphene.List(graphene.String) + hardware_metadata = graphene.JSONString() + performance_metric = graphene.JSONString() + usage = graphene.JSONString() + + async def resolve_hardware_metadata(self, info: graphene.ResolveInfo) -> HardwareMetadata: + ctx: GraphQueryContext = info.context + return await ctx.registry.gather_storage_hwinfo(self.id) + + async def resolve_performance_metric(self, info: graphene.ResolveInfo) -> Mapping[str, Any]: + ctx: GraphQueryContext = info.context + proxy_name, volume_name = ctx.storage_manager.split_host(self.id) + try: + proxy_info = ctx.storage_manager._proxies[proxy_name] + except KeyError: + raise ValueError(f"no such storage proxy: {proxy_name!r}") + try: + async with proxy_info.session.request( + 'GET', proxy_info.manager_api_url / 'volume/performance-metric', + json={'volume': volume_name}, + raise_for_status=True, + headers={AUTH_TOKEN_HDR: proxy_info.secret}, + ) as resp: + reply = await resp.json() + return reply['metric'] + except aiohttp.ClientResponseError: + return {} + + async def resolve_usage(self, info: graphene.ResolveInfo) -> Mapping[str, Any]: + ctx: GraphQueryContext = info.context + proxy_name, volume_name = ctx.storage_manager.split_host(self.id) + try: + proxy_info = ctx.storage_manager._proxies[proxy_name] + except KeyError: + raise ValueError(f"no such storage proxy: {proxy_name!r}") + try: + async with proxy_info.session.request( + 'GET', proxy_info.manager_api_url / 'folder/fs-usage', + json={'volume': volume_name}, + raise_for_status=True, + headers={AUTH_TOKEN_HDR: proxy_info.secret}, + ) as resp: + reply = await resp.json() + return reply + except aiohttp.ClientResponseError: + return {} + + @classmethod + def from_info(cls, proxy_name: str, volume_info: VolumeInfo) -> StorageVolume: + return cls( + id=f"{proxy_name}:{volume_info['name']}", + backend=volume_info['backend'], + path=volume_info['path'], + fsprefix=volume_info['fsprefix'], + capabilities=volume_info['capabilities'], + ) + + @classmethod + async def load_count( + cls, + ctx: GraphQueryContext, + filter: str = None, + ) -> int: + volumes = [*await ctx.storage_manager.get_all_volumes()] + return len(volumes) + + @classmethod + async def load_slice( + cls, + ctx: GraphQueryContext, + limit: int, + offset: int, + filter: str = None, + order: str = None, + ) -> Sequence[StorageVolume]: + # For consistency we add filter/order params here, but it's actually noop. 
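+ # (Volumes are gathered from all storage proxies and sliced in memory,
+ # so there is no backing DB query to push the filter/order down into.)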
+ if filter is not None or order is not None: + log.warning("Paginated list of storage volumes ignores custom filtering and/or ordering") + volumes = [*await ctx.storage_manager.get_all_volumes()] + return [ + cls.from_info(proxy_name, volume_info) + for proxy_name, volume_info in volumes[offset:offset + limit] + ] + + @classmethod + async def load_by_id( + cls, + ctx: GraphQueryContext, + id: str, + ) -> StorageVolume: + proxy_name, volume_name = ctx.storage_manager.split_host(id) + try: + proxy_info = ctx.storage_manager._proxies[proxy_name] + except KeyError: + raise ValueError(f"no such storage proxy: {proxy_name!r}") + async with proxy_info.session.request( + 'GET', proxy_info.manager_api_url / 'volumes', + raise_for_status=True, + headers={AUTH_TOKEN_HDR: proxy_info.secret}, + ) as resp: + reply = await resp.json() + for volume_data in reply['volumes']: + if volume_data['name'] == volume_name: + return cls.from_info(proxy_name, volume_data) + else: + raise ValueError( + f"no such volume in the storage proxy {proxy_name!r}: {volume_name!r}", + ) + + +class StorageVolumeList(graphene.ObjectType): + class Meta: + interfaces = (PaginatedList, ) + + items = graphene.List(StorageVolume, required=True) diff --git a/src/ai/backend/manager/models/user.py b/src/ai/backend/manager/models/user.py new file mode 100644 index 0000000000..2a0ebd493c --- /dev/null +++ b/src/ai/backend/manager/models/user.py @@ -0,0 +1,1151 @@ +from __future__ import annotations + +import enum +import logging +from typing import ( + Any, + Dict, + Iterable, + Optional, + Sequence, + TYPE_CHECKING, +) +from uuid import UUID, uuid4 + +import aiohttp +from dateutil.parser import parse as dtparse +import graphene +from graphene.types.datetime import DateTime as GQLDateTime +from passlib.hash import bcrypt +import sqlalchemy as sa +from sqlalchemy.engine.result import Result +from sqlalchemy.engine.row import Row +from sqlalchemy.ext.asyncio import ( + AsyncConnection as SAConnection, + AsyncEngine as SAEngine, +) +from sqlalchemy.sql.expression import bindparam +from sqlalchemy.types import TypeDecorator, VARCHAR + +from ai.backend.common import redis +from ai.backend.common.logging import BraceStyleAdapter +from ai.backend.common.types import RedisConnectionInfo + +from ..api.exceptions import VFolderOperationFailed +from .base import ( + EnumValueType, + IDColumn, + Item, + PaginatedList, + metadata, + set_if_set, + batch_result, + batch_multiresult, + simple_db_mutate, + simple_db_mutate_returning_item, +) +from .minilang.queryfilter import QueryFilterParser +from .minilang.ordering import QueryOrderParser +from .storage import StorageSessionManager + +if TYPE_CHECKING: + from .gql import GraphQueryContext + +log = BraceStyleAdapter(logging.getLogger(__file__)) + + +__all__: Sequence[str] = ( + 'users', + 'User', 'UserList', + 'UserGroup', 'UserRole', + 'UserInput', 'ModifyUserInput', + 'CreateUser', 'ModifyUser', 'DeleteUser', + 'UserStatus', 'ACTIVE_USER_STATUSES', 'INACTIVE_USER_STATUSES', +) + + +class PasswordColumn(TypeDecorator): + impl = VARCHAR + + def process_bind_param(self, value, dialect): + return _hash_password(value) + + +class UserRole(str, enum.Enum): + """ + User's role. + """ + SUPERADMIN = 'superadmin' + ADMIN = 'admin' + USER = 'user' + MONITOR = 'monitor' + + +class UserStatus(str, enum.Enum): + """ + User account status.
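+
+ Only ACTIVE is treated as an active account state; INACTIVE, DELETED, and
+ BEFORE_VERIFICATION are grouped into INACTIVE_USER_STATUSES below.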
+ """ + ACTIVE = 'active' + INACTIVE = 'inactive' + DELETED = 'deleted' + BEFORE_VERIFICATION = 'before-verification' + + +ACTIVE_USER_STATUSES = ( + UserStatus.ACTIVE, +) + +INACTIVE_USER_STATUSES = ( + UserStatus.INACTIVE, + UserStatus.DELETED, + UserStatus.BEFORE_VERIFICATION, +) + + +users = sa.Table( + 'users', metadata, + IDColumn('uuid'), + sa.Column('username', sa.String(length=64), unique=True), + sa.Column('email', sa.String(length=64), index=True, + nullable=False, unique=True), + sa.Column('password', PasswordColumn()), + sa.Column('need_password_change', sa.Boolean), + sa.Column('full_name', sa.String(length=64)), + sa.Column('description', sa.String(length=500)), + sa.Column('status', EnumValueType(UserStatus), default=UserStatus.ACTIVE, nullable=False), + sa.Column('status_info', sa.Unicode(), nullable=True, default=sa.null()), + sa.Column('created_at', sa.DateTime(timezone=True), + server_default=sa.func.now()), + sa.Column('modified_at', sa.DateTime(timezone=True), + server_default=sa.func.now(), onupdate=sa.func.current_timestamp()), + #: Field for synchronization with external services. + sa.Column('integration_id', sa.String(length=512)), + + sa.Column('domain_name', sa.String(length=64), + sa.ForeignKey('domains.name'), index=True), + sa.Column('role', EnumValueType(UserRole), default=UserRole.USER), +) + + +class UserGroup(graphene.ObjectType): + id = graphene.UUID() + name = graphene.String() + + @classmethod + def from_row(cls, ctx: GraphQueryContext, row: Row) -> Optional[UserGroup]: + if row is None: + return None + return cls( + id=row['id'], + name=row['name'], + ) + + @classmethod + async def batch_load_by_user_id(cls, ctx: GraphQueryContext, user_ids: Sequence[UUID]): + async with ctx.db.begin() as conn: + from .group import groups, association_groups_users as agus + j = agus.join(groups, agus.c.group_id == groups.c.id) + query = ( + sa.select([agus.c.user_id, groups.c.name, groups.c.id]) + .select_from(j) + .where(agus.c.user_id.in_(user_ids)) + ) + return await batch_multiresult( + ctx, conn, query, cls, + user_ids, lambda row: row['user_id'], + ) + + +class User(graphene.ObjectType): + + class Meta: + interfaces = (Item, ) + + uuid = graphene.UUID() # legacy + username = graphene.String() + email = graphene.String() + need_password_change = graphene.Boolean() + full_name = graphene.String() + description = graphene.String() + is_active = graphene.Boolean() + status = graphene.String() + status_info = graphene.String() + created_at = GQLDateTime() + modified_at = GQLDateTime() + domain_name = graphene.String() + role = graphene.String() + + groups = graphene.List(lambda: UserGroup) + + async def resolve_groups( + self, + info: graphene.ResolveInfo, + ) -> Iterable[UserGroup]: + ctx: GraphQueryContext = info.context + manager = ctx.dataloader_manager + loader = manager.get_loader(ctx, 'UserGroup.by_user_id') + return await loader.load(self.id) + + @classmethod + def from_row( + cls, + ctx: GraphQueryContext, + row: Row, + ) -> User: + return cls( + id=row['uuid'], + uuid=row['uuid'], + username=row['username'], + email=row['email'], + need_password_change=row['need_password_change'], + full_name=row['full_name'], + description=row['description'], + is_active=True if row['status'] == UserStatus.ACTIVE else False, # legacy + status=row['status'], + status_info=row['status_info'], + created_at=row['created_at'], + modified_at=row['modified_at'], + domain_name=row['domain_name'], + role=row['role'], + ) + + @classmethod + async def load_all( + cls, + ctx: 
GraphQueryContext, + *, + domain_name: str = None, + group_id: UUID = None, + is_active: bool = None, + status: str = None, + limit: int = None, + ) -> Sequence[User]: + """ + Load user's information. Group names associated with the user are also returned. + """ + if group_id is not None: + from .group import association_groups_users as agus + j = (users.join(agus, agus.c.user_id == users.c.uuid)) + query = ( + sa.select([users]) + .select_from(j) + .where(agus.c.group_id == group_id) + ) + else: + query = ( + sa.select([users]) + .select_from(users) + ) + if ctx.user['role'] != UserRole.SUPERADMIN: + query = query.where(users.c.domain_name == ctx.user['domain_name']) + if domain_name is not None: + query = query.where(users.c.domain_name == domain_name) + if status is not None: + query = query.where(users.c.status == UserStatus(status)) + elif is_active is not None: # consider is_active field only if status is empty + _statuses = ACTIVE_USER_STATUSES if is_active else INACTIVE_USER_STATUSES + query = query.where(users.c.status.in_(_statuses)) + if limit is not None: + query = query.limit(limit) + async with ctx.db.begin_readonly() as conn: + return [cls.from_row(ctx, row) async for row in (await conn.stream(query))] + + _queryfilter_fieldspec = { + "uuid": ("uuid", None), + "username": ("username", None), + "email": ("email", None), + "need_password_change": ("need_password_change", None), + "full_name": ("full_name", None), + "description": ("description", None), + "is_active": ("is_active", None), + "status": ("status", lambda s: UserStatus[s]), + "status_info": ("status_info", None), + "created_at": ("created_at", dtparse), + "modified_at": ("modified_at", dtparse), + "domain_name": ("domain_name", None), + "role": ("role", lambda s: UserRole[s]), + } + + _queryorder_colmap = { + "uuid": "uuid", + "username": "username", + "email": "email", + "need_password_change": "need_password_change", + "full_name": "full_name", + "is_active": "is_active", + "status": "status", + "status_info": "status_info", + "created_at": "created_at", + "modified_at": "modified_at", + "domain_name": "domain_name", + "role": "role", + } + + @classmethod + async def load_count( + cls, + ctx: GraphQueryContext, + *, + domain_name: str = None, + group_id: UUID = None, + is_active: bool = None, + status: str = None, + filter: str = None, + ) -> int: + if group_id is not None: + from .group import association_groups_users as agus + j = (users.join(agus, agus.c.user_id == users.c.uuid)) + query = ( + sa.select([sa.func.count()]) + .select_from(j) + .where(agus.c.group_id == group_id) + ) + else: + query = ( + sa.select([sa.func.count()]) + .select_from(users) + ) + if domain_name is not None: + query = query.where(users.c.domain_name == domain_name) + if status is not None: + query = query.where(users.c.status == UserStatus(status)) + elif is_active is not None: # consider is_active field only if status is empty + _statuses = ACTIVE_USER_STATUSES if is_active else INACTIVE_USER_STATUSES + query = query.where(users.c.status.in_(_statuses)) + if filter is not None: + if group_id is not None: + qfparser = QueryFilterParser({ + k: ('users_' + v[0], v[1]) + for k, v in cls._queryfilter_fieldspec.items() + }) + else: + qfparser = QueryFilterParser(cls._queryfilter_fieldspec) + query = qfparser.append_filter(query, filter) + async with ctx.db.begin_readonly() as conn: + result = await conn.execute(query) + return result.scalar() + + @classmethod + async def load_slice( + cls, + ctx: GraphQueryContext, + limit: int, + 
offset: int, + *, + domain_name: str = None, + group_id: UUID = None, + is_active: bool = None, + status: str = None, + filter: str = None, + order: str = None, + ) -> Sequence[User]: + if group_id is not None: + from .group import association_groups_users as agus + j = (users.join(agus, agus.c.user_id == users.c.uuid)) + query = ( + sa.select([users]) + .select_from(j) + .where(agus.c.group_id == group_id) + .limit(limit) + .offset(offset) + ) + else: + query = ( + sa.select([users]) + .select_from(users) + .limit(limit) + .offset(offset) + ) + if domain_name is not None: + query = query.where(users.c.domain_name == domain_name) + if status is not None: + query = query.where(users.c.status == UserStatus(status)) + elif is_active is not None: # consider is_active field only if status is empty + _statuses = ACTIVE_USER_STATUSES if is_active else INACTIVE_USER_STATUSES + query = query.where(users.c.status.in_(_statuses)) + if filter is not None: + if group_id is not None: + qfparser = QueryFilterParser({ + k: ('users_' + v[0], v[1]) + for k, v in cls._queryfilter_fieldspec.items() + }) + else: + qfparser = QueryFilterParser(cls._queryfilter_fieldspec) + query = qfparser.append_filter(query, filter) + if order is not None: + if group_id is not None: + qoparser = QueryOrderParser({ + k: 'users_' + v + for k, v in cls._queryorder_colmap.items() + }) + else: + qoparser = QueryOrderParser(cls._queryorder_colmap) + query = qoparser.append_ordering(query, order) + else: + query = query.order_by( + users.c.created_at.desc(), + ) + async with ctx.db.begin_readonly() as conn: + return [ + cls.from_row(ctx, row) async for row in (await conn.stream(query)) + ] + + @classmethod + async def batch_load_by_email( + cls, + ctx: GraphQueryContext, + emails: Sequence[str] = None, + *, + domain_name: str = None, + is_active: bool = None, + status: str = None, + ) -> Sequence[Optional[User]]: + if not emails: + return [] + query = ( + sa.select([users]) + .select_from(users) + .where(users.c.email.in_(emails)) + ) + if domain_name is not None: + query = query.where(users.c.domain_name == domain_name) + if status is not None: + query = query.where(users.c.status == UserStatus(status)) + elif is_active is not None: # consider is_active field only if status is empty + _statuses = ACTIVE_USER_STATUSES if is_active else INACTIVE_USER_STATUSES + query = query.where(users.c.status.in_(_statuses)) + async with ctx.db.begin_readonly() as conn: + return await batch_result( + ctx, conn, query, cls, + emails, lambda row: row['email'], + ) + + @classmethod + async def batch_load_by_uuid( + cls, + ctx: GraphQueryContext, + user_ids: Sequence[UUID] = None, + *, + domain_name: str = None, + is_active: bool = None, + status: str = None, + ) -> Sequence[Optional[User]]: + if not user_ids: + return [] + query = ( + sa.select([users]) + .select_from(users) + .where(users.c.uuid.in_(user_ids)) + ) + if domain_name is not None: + query = query.where(users.c.domain_name == domain_name) + if status is not None: + query = query.where(users.c.status == UserStatus(status)) + elif is_active is not None: # consider is_active field only if status is empty + _statuses = ACTIVE_USER_STATUSES if is_active else INACTIVE_USER_STATUSES + query = query.where(users.c.status.in_(_statuses)) + async with ctx.db.begin_readonly() as conn: + return await batch_result( + ctx, conn, query, cls, + user_ids, lambda row: row['uuid'], + ) + + +class UserList(graphene.ObjectType): + class Meta: + interfaces = (PaginatedList, ) + + items = graphene.List(User, 
required=True) + + +class UserInput(graphene.InputObjectType): + username = graphene.String(required=True) + password = graphene.String(required=True) + need_password_change = graphene.Boolean(required=True) + full_name = graphene.String(required=False, default='') + description = graphene.String(required=False, default='') + is_active = graphene.Boolean(required=False, default=True) + status = graphene.String(required=False, default=UserStatus.ACTIVE) + domain_name = graphene.String(required=True, default='default') + role = graphene.String(required=False, default=UserRole.USER) + group_ids = graphene.List(lambda: graphene.String, required=False) + + # When creating, you MUST set all fields. + # When modifying, set the field to "None" to skip setting the value. + + +class ModifyUserInput(graphene.InputObjectType): + username = graphene.String(required=False) + password = graphene.String(required=False) + need_password_change = graphene.Boolean(required=False) + full_name = graphene.String(required=False) + description = graphene.String(required=False) + is_active = graphene.Boolean(required=False) + status = graphene.String(required=False) + domain_name = graphene.String(required=False) + role = graphene.String(required=False) + group_ids = graphene.List(lambda: graphene.String, required=False) + + +class PurgeUserInput(graphene.InputObjectType): + purge_shared_vfolders = graphene.Boolean(required=False, default=False) + + +class CreateUser(graphene.Mutation): + + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + email = graphene.String(required=True) + props = UserInput(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + user = graphene.Field(lambda: User, required=False) + + @classmethod + async def mutate( + cls, + root, + info: graphene.ResolveInfo, + email: str, + props: UserInput, + ) -> CreateUser: + graph_ctx: GraphQueryContext = info.context + username = props.username if props.username else email + if props.status is None and props.is_active is not None: + _status = UserStatus.ACTIVE if props.is_active else UserStatus.INACTIVE + else: + _status = UserStatus(props.status) + user_data = { + 'username': username, + 'email': email, + 'password': props.password, + 'need_password_change': props.need_password_change, + 'full_name': props.full_name, + 'description': props.description, + 'status': _status, + 'status_info': 'admin-requested', # user mutation is only for admin + 'domain_name': props.domain_name, + 'role': UserRole(props.role), + } + user_insert_query = ( + sa.insert(users) + .values(user_data) + ) + + async def _post_func(conn: SAConnection, result: Result) -> Row: + if result.rowcount == 0: + return + created_user = result.first() + + # Create a default keypair for the user. + from .keypair import CreateKeyPair, keypairs + kp_data = CreateKeyPair.prepare_new_keypair( + email, + graph_ctx.schema.get_type('KeyPairInput').create_container({ + 'is_active': (_status == UserStatus.ACTIVE), + 'is_admin': (user_data['role'] in [UserRole.SUPERADMIN, UserRole.ADMIN]), + 'resource_policy': 'default', + 'rate_limit': 10000, + }), + ) + kp_insert_query = ( + sa.insert(keypairs) + .values( + **kp_data, + user=created_user.uuid, + ) + ) + await conn.execute(kp_insert_query) + + # Add user to groups if group_ids parameter is provided. 
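+ # (Only group IDs that exist within the new user's domain are linked;
+ # any other IDs in group_ids are silently skipped.)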
+ from .group import association_groups_users, groups + if props.group_ids: + query = ( + sa.select([groups.c.id]) + .select_from(groups) + .where(groups.c.domain_name == props.domain_name) + .where(groups.c.id.in_(props.group_ids)) + ) + grps = (await conn.execute(query)).all() + if grps: + group_data = [ + {'user_id': created_user.uuid, 'group_id': grp.id} + for grp in grps + ] + group_insert_query = ( + sa.insert(association_groups_users) + .values(group_data) + ) + await conn.execute(group_insert_query) + + return created_user + + return await simple_db_mutate_returning_item( + cls, graph_ctx, user_insert_query, item_cls=User, post_func=_post_func, + ) + + +class ModifyUser(graphene.Mutation): + + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + email = graphene.String(required=True) + props = ModifyUserInput(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + user = graphene.Field(lambda: User) + + @classmethod + async def mutate( + cls, + root, + info: graphene.ResolveInfo, + email: str, + props: ModifyUserInput, + ) -> ModifyUser: + graph_ctx: GraphQueryContext = info.context + data: Dict[str, Any] = {} + set_if_set(props, data, 'username') + set_if_set(props, data, 'password') + set_if_set(props, data, 'need_password_change') + set_if_set(props, data, 'full_name') + set_if_set(props, data, 'description') + set_if_set(props, data, 'status', clean_func=UserStatus) + set_if_set(props, data, 'domain_name') + set_if_set(props, data, 'role', clean_func=UserRole) + if not data and not props.group_ids: + return cls(ok=False, msg='nothing to update', user=None) + if data.get('status') is None and props.is_active is not None: + data['status'] = UserStatus.ACTIVE if props.is_active else UserStatus.INACTIVE + + user_update_data: Dict[str, Any] + prev_domain_name: str + prev_role: UserRole + + async def _pre_func(conn: SAConnection) -> None: + nonlocal user_update_data, prev_domain_name, prev_role + result = await conn.execute( + sa.select([ + users.c.domain_name, + users.c.role, + users.c.status, + ]) + .select_from(users) + .where(users.c.email == email), + ) + row = result.first() + prev_domain_name = row.domain_name + prev_role = row.role + user_update_data = data.copy() + if 'status' in data and row.status != data['status']: + user_update_data['status_info'] = 'admin-requested' # user mutation is only for admin + + update_query = lambda: ( # uses lambda because user_update_data is modified in _pre_func() + sa.update(users) + .values(user_update_data) + .where(users.c.email == email) + ) + + async def _post_func(conn: SAConnection, result: Result) -> Row: + nonlocal prev_domain_name, prev_role + updated_user = result.first() + + if 'role' in data and data['role'] != prev_role: + from ai.backend.manager.models import keypairs + result = await conn.execute( + sa.select([ + keypairs.c.user, + keypairs.c.is_active, + keypairs.c.is_admin, + ]) + .select_from(keypairs) + .where(keypairs.c.user == updated_user.uuid) + .order_by(sa.desc(keypairs.c.is_admin)) + .order_by(sa.desc(keypairs.c.is_active)), + ) + if data['role'] in [UserRole.SUPERADMIN, UserRole.ADMIN]: + # User's becomes admin. Set the keypair as active admin. + kp = result.first() + kp_data = dict() + if not kp.is_admin: + kp_data['is_admin'] = True + if not kp.is_active: + kp_data['is_active'] = True + if len(kp_data.keys()) > 0: + await conn.execute( + sa.update(keypairs) + .values(kp_data) + .where(keypairs.c.user == updated_user.uuid), + ) + else: + # User becomes non-admin. 
Make the keypair non-admin as well. + # If there are multiple admin keypairs, inactivate them. + rows = result.fetchall() + cnt = 0 + kp_updates = [] + for row in rows: + kp_data = { + 'user': row.user, + 'is_admin': keypairs.c.is_admin, + 'is_active': keypairs.c.is_active, + } + changed = False + if cnt == 0: + kp_data['is_admin'] = False + changed = True + elif row.is_admin and row.is_active: + kp_data['is_active'] = False + changed = True + if changed: + kp_updates.append(kp_data) + cnt += 1 + await conn.execute( + sa.update(keypairs) + .values({ + 'is_admin': bindparam('is_admin'), + 'is_active': bindparam('is_active'), + }) + .where(keypairs.c.user == bindparam('user')), + kp_updates, + ) + + # If domain is changed and no group is associated, clear previous domain's group. + if prev_domain_name != updated_user.domain_name and not props.group_ids: + from .group import association_groups_users, groups + await conn.execute( + sa.delete(association_groups_users) + .where(association_groups_users.c.user_id == updated_user.uuid), + ) + + # Update user's group if group_ids parameter is provided. + if props.group_ids and updated_user is not None: + from .group import association_groups_users, groups # noqa + # Clear previous groups associated with the user. + await conn.execute( + sa.delete(association_groups_users) + .where(association_groups_users.c.user_id == updated_user.uuid), + ) + # Add user to new groups. + result = await conn.execute( + sa.select([groups.c.id]) + .select_from(groups) + .where(groups.c.domain_name == updated_user.domain_name) + .where(groups.c.id.in_(props.group_ids)), + ) + grps = result.fetchall() + if grps: + values = [ + {'user_id': updated_user.uuid, 'group_id': grp.id} + for grp in grps + ] + await conn.execute( + sa.insert(association_groups_users).values(values), + ) + + return updated_user + + return await simple_db_mutate_returning_item( + cls, graph_ctx, update_query, item_cls=User, + pre_func=_pre_func, post_func=_post_func, + ) + + +class DeleteUser(graphene.Mutation): + """ + Instead of really deleting the user, just mark the account status as deleted. + + All related keypairs will also be inactivated. + """ + + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + email = graphene.String(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + + @classmethod + async def mutate( + cls, + root, + info: graphene.ResolveInfo, + email: str, + ) -> DeleteUser: + graph_ctx: GraphQueryContext = info.context + + async def _pre_func(conn: SAConnection) -> None: + # Make all user keypairs inactive. + from ai.backend.manager.models import keypairs + await conn.execute( + sa.update(keypairs) + .values(is_active=False) + .where(keypairs.c.user_id == email), + ) + + update_query = ( + sa.update(users) + .values(status=UserStatus.DELETED, + status_info='admin-requested') + .where(users.c.email == email) + ) + return await simple_db_mutate(cls, graph_ctx, update_query, pre_func=_pre_func) + + +class PurgeUser(graphene.Mutation): + """ + Delete the user as well as all user-related DB information such as keypairs, kernels, etc. + + If the target user has virtual folders, they can be purged together or migrated to the superadmin. + + vFolder treatment policy: + User-type: + - vfolder is not shared: delete + - vfolder is shared: + + if purge_shared_vfolders is True: delete + + else: change the vfolder's owner to the requesting admin + + This action cannot be undone.
+ """ + allowed_roles = (UserRole.SUPERADMIN,) + + class Arguments: + email = graphene.String(required=True) + props = PurgeUserInput(required=True) + + ok = graphene.Boolean() + msg = graphene.String() + + @classmethod + async def mutate( + cls, + root, + info: graphene.ResolveInfo, + email: str, + props: PurgeUserInput, + ) -> PurgeUser: + graph_ctx: GraphQueryContext = info.context + + async def _pre_func(conn: SAConnection) -> None: + user_uuid = await conn.scalar( + sa.select([users.c.uuid]) + .select_from(users) + .where(users.c.email == email), + ) + log.info("Purging all records of the user {0}...", email) + + if await cls.user_vfolder_mounted_to_active_kernels(conn, user_uuid): + raise RuntimeError( + "Some of user's virtual folders are mounted to active kernels. " + "Terminate those kernels first.", + ) + if await cls.user_has_active_kernels(conn, user_uuid): + raise RuntimeError("User has some active kernels. Terminate them first.") + + if not props.purge_shared_vfolders: + await cls.migrate_shared_vfolders( + conn, + deleted_user_uuid=user_uuid, + target_user_uuid=graph_ctx.user['uuid'], + target_user_email=graph_ctx.user['email'], + ) + await cls.delete_vfolders(conn, user_uuid, graph_ctx.storage_manager) + await cls.delete_kernels(conn, user_uuid) + await cls.delete_keypairs(conn, graph_ctx.redis_stat, user_uuid) + + delete_query = ( + sa.delete(users) + .where(users.c.email == email) + ) + return await simple_db_mutate(cls, graph_ctx, delete_query, pre_func=_pre_func) + + @classmethod + async def migrate_shared_vfolders( + cls, + conn: SAConnection, + deleted_user_uuid: UUID, + target_user_uuid: UUID, + target_user_email: str, + ) -> int: + """ + Migrate shared virtual folders' ownership to a target user. + + If migrating virtual folder's name collides with target user's already + existing folder, append random string to the migrating one. + + :param conn: DB connection + :param deleted_user_uuid: user's UUID who will be deleted + :param target_user_uuid: user's UUID who will get the ownership of virtual folders + + :return: number of deleted rows + """ + from . import vfolders, vfolder_invitations, vfolder_permissions + # Gather target user's virtual folders' names. + query = ( + sa.select([vfolders.c.name]) + .select_from(vfolders) + .where(vfolders.c.user == target_user_uuid) + ) + existing_vfolder_names = [row.name async for row in (await conn.stream(query))] + + # Migrate shared virtual folders. + # If virtual folder's name collides with target user's folder, + # append random string to the name of the migrating folder. + j = vfolder_permissions.join( + vfolders, + vfolder_permissions.c.vfolder == vfolders.c.id, + ) + query = ( + sa.select([vfolders.c.id, vfolders.c.name]) + .select_from(j) + .where(vfolders.c.user == deleted_user_uuid) + ) + migrate_updates = [] + async for row in (await conn.stream(query)): + name = row.name + if name in existing_vfolder_names: + name += f'-{uuid4().hex[:10]}' + migrate_updates.append({'vid': row.id, 'vname': name}) + if migrate_updates: + # Remove invitations and vfolder_permissions from target user. + # Target user will be the new owner, and it does not make sense to have + # invitation and shared permission for its own folder. 
+ migrate_vfolder_ids = [item['vid'] for item in migrate_updates] + delete_query = ( + sa.delete(vfolder_invitations) + .where((vfolder_invitations.c.invitee == target_user_email) & + (vfolder_invitations.c.vfolder.in_(migrate_vfolder_ids))) + ) + await conn.execute(delete_query) + delete_query = ( + sa.delete(vfolder_permissions) + .where((vfolder_permissions.c.user == target_user_uuid) & + (vfolder_permissions.c.vfolder.in_(migrate_vfolder_ids))) + ) + await conn.execute(delete_query) + + rowcount = 0 + for item in migrate_updates: + update_query = ( + sa.update(vfolders) + .values( + user=target_user_uuid, + name=item['vname'], + ) + .where(vfolders.c.id == item['vid']) + ) + result = await conn.execute(update_query) + rowcount += result.rowcount + if rowcount > 0: + log.info( + "{0} shared folders are detected and migrated to user {1}", + rowcount, target_user_uuid, + ) + return rowcount + else: + return 0 + + @classmethod + async def delete_vfolders( + cls, + conn: SAConnection, + user_uuid: UUID, + storage_manager: StorageSessionManager, + ) -> int: + """ + Delete user's all virtual folders as well as their physical data. + + :param conn: DB connection + :param user_uuid: user's UUID to delete virtual folders + + :return: number of deleted rows + """ + from . import vfolders, vfolder_permissions + await conn.execute( + vfolder_permissions.delete() + .where(vfolder_permissions.c.user == user_uuid), + ) + result = await conn.execute( + sa.select([vfolders.c.id, vfolders.c.host]) + .select_from(vfolders) + .where(vfolders.c.user == user_uuid), + ) + target_vfs = result.fetchall() + result = await conn.execute( + sa.delete(vfolders).where(vfolders.c.user == user_uuid), + ) + for row in target_vfs: + try: + async with storage_manager.request( + row['host'], 'POST', 'folder/delete', + json={ + 'volume': storage_manager.split_host(row['host'])[1], + 'vfid': str(row['id']), + }, + raise_for_status=True, + ): + pass + except aiohttp.ClientResponseError: + log.error('error on deleting vfolder filesystem directory: {0}', row['id']) + raise VFolderOperationFailed + if result.rowcount > 0: + log.info('deleted {0} user\'s virtual folders ({1})', result.rowcount, user_uuid) + return result.rowcount + + @classmethod + async def user_vfolder_mounted_to_active_kernels( + cls, + conn: SAConnection, + user_uuid: UUID, + ) -> bool: + """ + Check if no active kernel is using the user's virtual folders. + + :param conn: DB connection + :param user_uuid: user's UUID + + :return: True if a virtual folder is mounted to active kernels. + """ + from . import kernels, vfolders, AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES + result = await conn.execute( + sa.select([vfolders.c.id]) + .select_from(vfolders) + .where(vfolders.c.user == user_uuid), + ) + rows = result.fetchall() + user_vfolder_ids = [row.id for row in rows] + query = ( + sa.select([kernels.c.mounts]) + .select_from(kernels) + .where(kernels.c.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES)) + ) + async for row in (await conn.stream(query)): + for _mount in row['mounts']: + try: + vfolder_id = UUID(_mount[2]) + if vfolder_id in user_vfolder_ids: + return True + except Exception: + pass + return False + + @classmethod + async def user_has_active_kernels( + cls, + conn: SAConnection, + user_uuid: UUID, + ) -> bool: + """ + Check if the user does not have active kernels. + + :param conn: DB connection + :param user_uuid: user's UUID + + :return: True if the user has some active kernels. + """ + from . 
import kernels, AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES + active_kernel_count = await conn.scalar( + sa.select([sa.func.count()]) + .select_from(kernels) + .where( + (kernels.c.user_uuid == user_uuid) & + (kernels.c.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES)), + ), + ) + return (active_kernel_count > 0) + + @classmethod + async def delete_kernels( + cls, + conn: SAConnection, + user_uuid: UUID, + ) -> int: + """ + Delete user's all kernels. + + :param conn: DB connection + :param user_uuid: user's UUID to delete kernels + :return: number of deleted rows + """ + from . import kernels + result = await conn.execute( + sa.delete(kernels) + .where(kernels.c.user_uuid == user_uuid), + ) + if result.rowcount > 0: + log.info('deleted {0} user\'s kernels ({1})', result.rowcount, user_uuid) + return result.rowcount + + @classmethod + async def delete_keypairs( + cls, + conn: SAConnection, + redis_conn: RedisConnectionInfo, + user_uuid: UUID, + ) -> int: + """ + Delete user's all keypairs. + + :param conn: DB connection + :param redis_conn: redis connection info + :param user_uuid: user's UUID to delete keypairs + :return: number of deleted rows + """ + from . import keypairs + ak_rows = await conn.execute( + sa.select([keypairs.c.access_key]) + .where(keypairs.c.user == user_uuid), + ) + access_key = ak_rows.first().access_key + await redis.execute( + redis_conn, + lambda r: r.delete(f'keypair.concurrency_used.{access_key}'), + ) + result = await conn.execute( + sa.delete(keypairs) + .where(keypairs.c.user == user_uuid), + ) + if result.rowcount > 0: + log.info('deleted {0} user\'s keypairs ({1})', result.rowcount, user_uuid) + return result.rowcount + + +def _hash_password(password): + return bcrypt.using(rounds=12).hash(password) + + +def _verify_password(guess, hashed): + return bcrypt.verify(guess, hashed) + + +async def check_credential( + db: SAEngine, + domain: str, + email: str, + password: str, +) -> Any: + async with db.begin_readonly() as conn: + result = await conn.execute( + sa.select([users]) + .select_from(users) + .where( + (users.c.email == email) & + (users.c.domain_name == domain), + ), + ) + row = result.first() + if row is None: + return None + if row['password'] is None: + # user password is not set. 
+ return None + try: + if _verify_password(password, row['password']): + return row + except ValueError: + return None + return None diff --git a/src/ai/backend/manager/models/utils.py b/src/ai/backend/manager/models/utils.py new file mode 100644 index 0000000000..8642fa3a47 --- /dev/null +++ b/src/ai/backend/manager/models/utils.py @@ -0,0 +1,306 @@ +from __future__ import annotations + +import asyncio +from contextlib import asynccontextmanager as actxmgr +import functools +import json +import logging +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Awaitable, + Callable, + Mapping, + Tuple, + TypeVar, +) +from urllib.parse import quote_plus as urlquote + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql as psql +from sqlalchemy.engine import create_engine as _create_engine +from sqlalchemy.exc import DBAPIError +from sqlalchemy.ext.asyncio import ( + AsyncConnection as SAConnection, + AsyncEngine as SAEngine, + AsyncSession as SASession, +) +from tenacity import ( + AsyncRetrying, + RetryError, + TryAgain, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + +from ai.backend.common.json import ExtendedJSONEncoder +from ai.backend.common.logging import BraceStyleAdapter + +if TYPE_CHECKING: + from ..config import LocalConfig +from ..defs import LockID +from ..types import Sentinel + +log = BraceStyleAdapter(logging.getLogger(__name__)) +column_constraints = ['nullable', 'index', 'unique', 'primary_key'] + +# TODO: Implement begin(), begin_readonly() for AsyncSession also + + +class ExtendedAsyncSAEngine(SAEngine): + """ + A subclass to add a few more convenience methods to the SQLAlchemy's async engine. + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._readonly_txn_count = 0 + self._generic_txn_count = 0 + self._txn_concurrency_threshold = kwargs.pop("txn_concurrency_threshold", 8) + + @actxmgr + async def begin(self) -> AsyncIterator[SAConnection]: + async with super().begin() as conn: + self._generic_txn_count += 1 + if self._generic_txn_count >= self._txn_concurrency_threshold: + log.warning( + "The number of concurrent generic transaction ({}) exceeded the threshold {}.", + self._generic_txn_count, self._txn_concurrency_threshold, + stack_info=False, + ) + try: + yield conn + finally: + self._generic_txn_count -= 1 + + @actxmgr + async def begin_readonly(self, deferrable: bool = False) -> AsyncIterator[SAConnection]: + async with self.connect() as conn: + self._readonly_txn_count += 1 + if self._readonly_txn_count >= self._txn_concurrency_threshold: + log.warning( + "The number of concurrent read-only transaction ({}) exceeded the threshold {}.", + self._readonly_txn_count, self._txn_concurrency_threshold, + stack_info=False, + ) + conn_with_exec_opts = await conn.execution_options( + postgresql_readonly=True, + postgresql_deferrable=deferrable, + ) + async with conn_with_exec_opts.begin(): + try: + yield conn_with_exec_opts + finally: + self._readonly_txn_count -= 1 + + @actxmgr + async def begin_session(self) -> AsyncIterator[SASession]: + async with self.begin() as conn: + session = SASession(bind=conn) + try: + yield session + await session.commit() + except Exception as e: + await session.rollback() + raise e + + @actxmgr + async def begin_readonly_session(self, deferrable: bool = False) -> AsyncIterator[SASession]: + async with self.begin_readonly(deferrable=deferrable) as conn: + yield SASession(bind=conn) + + @actxmgr + async def advisory_lock(self, lock_id: LockID) -> 
AsyncIterator[None]: + lock_acquired = False + # Here we use the session-level advisory lock, + # which follows the lifetime of underlying DB connection. + # As such, we should keep using one single connection for both lock and unlock ops. + async with self.connect() as lock_conn: + try: + # It is usually a BAD practice to directly interpolate strings into SQL statements, + # but in this case: + # - The lock ID is only given from trusted codes. + # - asyncpg does not support parameter interpolation with raw SQL statements. + await lock_conn.exec_driver_sql( + f"SELECT pg_advisory_lock({lock_id:d});", + ) + except sa.exc.DBAPIError as e: + if getattr(e.orig, 'pgcode', None) == '55P03': # lock not available error + # This may happen upon shutdown after some time. + raise asyncio.CancelledError() + raise + except asyncio.CancelledError: + raise + else: + lock_acquired = True + yield + finally: + if lock_acquired and not lock_conn.closed: + await lock_conn.exec_driver_sql( + f"SELECT pg_advisory_unlock({lock_id:d})", + ) + + +def create_async_engine(*args, **kwargs) -> ExtendedAsyncSAEngine: + kwargs["future"] = True + sync_engine = _create_engine(*args, **kwargs) + return ExtendedAsyncSAEngine(sync_engine) + + +@actxmgr +async def connect_database( + local_config: LocalConfig | Mapping[str, Any], +) -> AsyncIterator[ExtendedAsyncSAEngine]: + from .base import pgsql_connect_opts + username = local_config['db']['user'] + password = local_config['db']['password'] + address = local_config['db']['addr'] + dbname = local_config['db']['name'] + url = f"postgresql+asyncpg://{urlquote(username)}:{urlquote(password)}@{address}/{urlquote(dbname)}" + + version_check_db = create_async_engine(url) + async with version_check_db.begin() as conn: + result = await conn.execute(sa.text("show server_version")) + major, minor, *_ = map(int, result.scalar().split(".")) + if (major, minor) < (11, 0): + pgsql_connect_opts['server_settings'].pop("jit") + await version_check_db.dispose() + + db = create_async_engine( + url, + connect_args=pgsql_connect_opts, + pool_size=8, + max_overflow=64, + json_serializer=functools.partial(json.dumps, cls=ExtendedJSONEncoder), + isolation_level="SERIALIZABLE", + future=True, + ) + yield db + await db.dispose() + + +@actxmgr +async def reenter_txn( + pool: ExtendedAsyncSAEngine, + conn: SAConnection, + execution_opts: Mapping[str, Any] | None = None, +) -> AsyncIterator[SAConnection]: + if conn is None: + async with pool.connect() as conn: + if execution_opts: + await conn.execution_options(**execution_opts) + async with conn.begin(): + yield conn + else: + async with conn.begin_nested(): + yield conn + + +TQueryResult = TypeVar('TQueryResult') + + +async def execute_with_retry(txn_func: Callable[[], Awaitable[TQueryResult]]) -> TQueryResult: + max_attempts = 20 + result: TQueryResult | Sentinel = Sentinel.token + try: + async for attempt in AsyncRetrying( + wait=wait_exponential(multiplier=0.02, min=0.02, max=5.0), + stop=stop_after_attempt(max_attempts), + retry=retry_if_exception_type(TryAgain), + ): + with attempt: + try: + result = await txn_func() + except DBAPIError as e: + if getattr(e.orig, 'pgcode', None) == '40001': + raise TryAgain + raise + except RetryError: + raise RuntimeError(f"DB serialization failed after {max_attempts} retries") + assert result is not Sentinel.token + return result + + +def sql_json_merge( + col, + key: Tuple[str, ...], + obj: Mapping[str, Any], + *, + _depth: int = 0, +): + """ + Generate an SQLAlchemy column update expression that merges the 
+    given object with
+    the existing object at a specific (nested) key of the given JSONB column,
+    with automatic creation of empty objects in parents and the target level.
+
+    Note that the existing value must also be an object, not a primitive value.
+    """
+    expr = sa.func.coalesce(
+        col if _depth == 0 else col[key[:_depth]],
+        sa.text("'{}'::jsonb"),
+    ).concat(
+        sa.func.jsonb_build_object(
+            key[_depth],
+            (
+                sa.func.coalesce(col[key], sa.text("'{}'::jsonb"))
+                .concat(sa.func.cast(obj, psql.JSONB))
+                if _depth == len(key) - 1
+                else sql_json_merge(col, key, obj=obj, _depth=_depth + 1)
+            ),
+        ),
+    )
+    return expr
+
+
+def sql_json_increment(
+    col,
+    key: Tuple[str, ...],
+    *,
+    parent_updates: Mapping[str, Any] = None,
+    _depth: int = 0,
+):
+    """
+    Generate an SQLAlchemy column update expression that increments the value at a specific
+    (nested) key of the given JSONB column,
+    with automatic creation of empty objects in parents and population of the
+    optional parent_updates object to the target key's parent.
+
+    Note that the existing value of the parent key must also be an object, not a primitive value.
+    """
+    expr = sa.func.coalesce(
+        col if _depth == 0 else col[key[:_depth]],
+        sa.text("'{}'::jsonb"),
+    ).concat(
+        sa.func.jsonb_build_object(
+            key[_depth],
+            (
+                sa.func.coalesce(col[key].as_integer(), 0) + 1
+                if _depth == len(key) - 1
+                else sql_json_increment(col, key, parent_updates=parent_updates, _depth=_depth + 1)
+            ),
+        ),
+    )
+    if _depth == len(key) - 1 and parent_updates is not None:
+        expr = expr.concat(sa.func.cast(parent_updates, psql.JSONB))
+    return expr
+
+
+def _populate_column(column: sa.Column):
+    column_attrs = dict(column.__dict__)
+    name = column_attrs.pop('name')
+    return sa.Column(name, column.type, **{k: column_attrs[k] for k in column_constraints})
+
+
+def regenerate_table(table: sa.Table, new_metadata: sa.MetaData) -> sa.Table:
+    '''
+    This function can be used to regenerate a table which belongs to a SQLAlchemy ORM class,
+    which can be helpful when you're trying to build a fresh new table for use in a different
+    context than the main manager logic (e.g., test code).
+    Check out tests/test_image.py for more details.
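+
+    Example (an illustrative sketch; ``kernels`` here stands in for any table
+    defined on the manager's shared ``metadata`` object)::
+
+        import sqlalchemy as sa
+        from ai.backend.manager.models import kernels
+
+        test_metadata = sa.MetaData()
+        test_kernels = regenerate_table(kernels, test_metadata)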
+ ''' + return sa.Table( + table.name, new_metadata, + *[_populate_column(c) for c in table.columns], + ) diff --git a/src/ai/backend/manager/models/vfolder.py b/src/ai/backend/manager/models/vfolder.py new file mode 100644 index 0000000000..12d59da3a3 --- /dev/null +++ b/src/ai/backend/manager/models/vfolder.py @@ -0,0 +1,823 @@ +from __future__ import annotations + +import enum +import os.path +import uuid +from pathlib import PurePosixPath +from typing import ( + Any, + List, + Mapping, + Optional, + Sequence, + Set, + TYPE_CHECKING, +) + +from dateutil.parser import parse as dtparse +import graphene +from graphene.types.datetime import DateTime as GQLDateTime +import sqlalchemy as sa +from sqlalchemy.engine.row import Row +from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection +import trafaret as t + +from ai.backend.common.types import VFolderMount + +from ..api.exceptions import InvalidAPIParameters, VFolderNotFound, VFolderOperationFailed +from ..defs import RESERVED_VFOLDER_PATTERNS, RESERVED_VFOLDERS +from ..types import UserScope +from .base import ( + metadata, EnumValueType, GUID, IDColumn, + Item, PaginatedList, BigInt, + batch_multiresult, +) +from .minilang.queryfilter import QueryFilterParser +from .minilang.ordering import QueryOrderParser +from .user import UserRole +if TYPE_CHECKING: + from .gql import GraphQueryContext + from .storage import StorageSessionManager + +__all__: Sequence[str] = ( + 'vfolders', + 'vfolder_invitations', + 'vfolder_permissions', + 'VirtualFolder', + 'VFolderUsageMode', + 'VFolderOwnershipType', + 'VFolderInvitationState', + 'VFolderPermission', + 'VFolderPermissionValidator', + 'query_accessible_vfolders', + 'get_allowed_vfolder_hosts_by_group', + 'get_allowed_vfolder_hosts_by_user', + 'verify_vfolder_name', + 'prepare_vfolder_mounts', +) + + +class VFolderUsageMode(str, enum.Enum): + ''' + Usage mode of virtual folder. + + GENERAL: normal virtual folder + MODEL: virtual folder which provides shared models + DATA: virtual folder which provides shared data + ''' + GENERAL = 'general' + MODEL = 'model' + DATA = 'data' + + +class VFolderOwnershipType(str, enum.Enum): + ''' + Ownership type of virtual folder. + ''' + USER = 'user' + GROUP = 'group' + + +class VFolderPermission(str, enum.Enum): + ''' + Permissions for a virtual folder given to a specific access key. + RW_DELETE includes READ_WRITE and READ_WRITE includes READ_ONLY. + ''' + READ_ONLY = 'ro' + READ_WRITE = 'rw' + RW_DELETE = 'wd' + OWNER_PERM = 'wd' # resolved as RW_DELETE + + +class VFolderPermissionValidator(t.Trafaret): + def check_and_return(self, value: Any) -> VFolderPermission: + if value not in ['ro', 'rw', 'wd']: + self._failure('one of "ro", "rw", or "wd" required', value=value) + return VFolderPermission(value) + + +class VFolderInvitationState(str, enum.Enum): + ''' + Virtual Folder invitation state. 
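+
+    PENDING: the invitation is awaiting the invitee's response
+    CANCELED: canceled by the inviter
+    ACCEPTED: accepted by the invitee
+    REJECTED: rejected by the invitee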
+ ''' + PENDING = 'pending' + CANCELED = 'canceled' # canceled by inviter + ACCEPTED = 'accepted' + REJECTED = 'rejected' # rejected by invitee + + +vfolders = sa.Table( + 'vfolders', metadata, + IDColumn('id'), + # host will be '' if vFolder is unmanaged + sa.Column('host', sa.String(length=128), nullable=False), + sa.Column('name', sa.String(length=64), nullable=False, index=True), + sa.Column('usage_mode', EnumValueType(VFolderUsageMode), + default=VFolderUsageMode.GENERAL, nullable=False), + sa.Column('permission', EnumValueType(VFolderPermission), + default=VFolderPermission.READ_WRITE), + sa.Column('max_files', sa.Integer(), default=1000), + sa.Column('max_size', sa.Integer(), default=None), # in MBytes + sa.Column('num_files', sa.Integer(), default=0), + sa.Column('cur_size', sa.Integer(), default=0), # in KBytes + sa.Column('created_at', sa.DateTime(timezone=True), + server_default=sa.func.now()), + sa.Column('last_used', sa.DateTime(timezone=True), nullable=True), + # creator is always set to the user who created vfolder (regardless user/project types) + sa.Column('creator', sa.String(length=128), nullable=True), + # unmanaged vfolder represents the host-side absolute path instead of storage-based path. + sa.Column('unmanaged_path', sa.String(length=512), nullable=True), + sa.Column('ownership_type', EnumValueType(VFolderOwnershipType), + default=VFolderOwnershipType.USER, nullable=False), + sa.Column('user', GUID, sa.ForeignKey('users.uuid'), nullable=True), # owner if user vfolder + sa.Column('group', GUID, sa.ForeignKey('groups.id'), nullable=True), # owner if project vfolder + sa.Column('cloneable', sa.Boolean, default=False, nullable=False), + + sa.CheckConstraint( + '(ownership_type = \'user\' AND "user" IS NOT NULL) OR ' + '(ownership_type = \'group\' AND "group" IS NOT NULL)', + name='ownership_type_match_with_user_or_group', + ), + sa.CheckConstraint( + '("user" IS NULL AND "group" IS NOT NULL) OR ("user" IS NOT NULL AND "group" IS NULL)', + name='either_one_of_user_or_group', + ), +) + + +vfolder_attachment = sa.Table( + 'vfolder_attachment', metadata, + sa.Column('vfolder', GUID, + sa.ForeignKey('vfolders.id', onupdate='CASCADE', ondelete='CASCADE'), + nullable=False), + sa.Column('kernel', GUID, + sa.ForeignKey('kernels.id', onupdate='CASCADE', ondelete='CASCADE'), + nullable=False), + sa.PrimaryKeyConstraint('vfolder', 'kernel'), +) + + +vfolder_invitations = sa.Table( + 'vfolder_invitations', metadata, + IDColumn('id'), + sa.Column('permission', EnumValueType(VFolderPermission), + default=VFolderPermission.READ_WRITE), + sa.Column('inviter', sa.String(length=256)), # email + sa.Column('invitee', sa.String(length=256), nullable=False), # email + sa.Column('state', EnumValueType(VFolderInvitationState), + default=VFolderInvitationState.PENDING), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column('modified_at', sa.DateTime(timezone=True), nullable=True, + onupdate=sa.func.current_timestamp()), + sa.Column('vfolder', GUID, + sa.ForeignKey('vfolders.id', + onupdate='CASCADE', + ondelete='CASCADE'), + nullable=False), +) + + +vfolder_permissions = sa.Table( + 'vfolder_permissions', metadata, + sa.Column('permission', EnumValueType(VFolderPermission), + default=VFolderPermission.READ_WRITE), + sa.Column('vfolder', GUID, + sa.ForeignKey('vfolders.id', + onupdate='CASCADE', + ondelete='CASCADE'), + nullable=False), + sa.Column('user', GUID, sa.ForeignKey('users.uuid'), nullable=False), +) + + +def verify_vfolder_name(folder: str) 
-> bool: + if folder in RESERVED_VFOLDERS: + return False + for pattern in RESERVED_VFOLDER_PATTERNS: + if pattern.match(folder): + return False + return True + + +async def query_accessible_vfolders( + conn: SAConnection, + user_uuid: uuid.UUID, + *, + user_role=None, + domain_name=None, + allowed_vfolder_types=None, + extra_vf_conds=None, + extra_vfperm_conds=None, + extra_vf_user_conds=None, + extra_vf_group_conds=None, +) -> Sequence[Mapping[str, Any]]: + from ai.backend.manager.models import groups, users, association_groups_users as agus + if allowed_vfolder_types is None: + allowed_vfolder_types = ['user'] # legacy default + + vfolders_selectors = [ + vfolders.c.name, + vfolders.c.id, + vfolders.c.host, + vfolders.c.usage_mode, + vfolders.c.created_at, + vfolders.c.last_used, + vfolders.c.max_files, + vfolders.c.max_size, + vfolders.c.ownership_type, + vfolders.c.user, + vfolders.c.group, + vfolders.c.creator, + vfolders.c.unmanaged_path, + vfolders.c.cloneable, + # vfolders.c.permission, + # users.c.email, + ] + + async def _append_entries(_query, _is_owner=True): + if extra_vf_conds is not None: + _query = _query.where(extra_vf_conds) + if extra_vf_user_conds is not None: + _query = _query.where(extra_vf_user_conds) + result = await conn.execute(_query) + for row in result: + row_keys = row.keys() + _perm = row.vfolder_permissions_permission \ + if 'vfolder_permissions_permission' in row_keys \ + else row.vfolders_permission + entries.append({ + 'name': row.vfolders_name, + 'id': row.vfolders_id, + 'host': row.vfolders_host, + 'usage_mode': row.vfolders_usage_mode, + 'created_at': row.vfolders_created_at, + 'last_used': row.vfolders_last_used, + 'max_size': row.vfolders_max_size, + 'max_files': row.vfolders_max_files, + 'ownership_type': row.vfolders_ownership_type, + 'user': str(row.vfolders_user) if row.vfolders_user else None, + 'group': str(row.vfolders_group) if row.vfolders_group else None, + 'creator': row.vfolders_creator, + 'user_email': row.users_email if 'users_email' in row_keys else None, + 'group_name': row.groups_name if 'groups_name' in row_keys else None, + 'is_owner': _is_owner, + 'permission': _perm, + 'unmanaged_path': row.vfolders_unmanaged_path, + 'cloneable': row.vfolders_cloneable, + }) + + entries: List[dict] = [] + # User vfolders. + if 'user' in allowed_vfolder_types: + # Scan my owned vfolders. + j = (vfolders.join(users, vfolders.c.user == users.c.uuid)) + query = ( + sa.select(vfolders_selectors + [vfolders.c.permission, users.c.email], use_labels=True) + .select_from(j) + .where(vfolders.c.user == user_uuid) + ) + await _append_entries(query) + + # Scan vfolders shared with me. + j = ( + vfolders.join( + vfolder_permissions, + vfolders.c.id == vfolder_permissions.c.vfolder, + isouter=True, + ) + .join( + users, + vfolders.c.user == users.c.uuid, + isouter=True, + ) + ) + query = ( + sa.select( + vfolders_selectors + [vfolder_permissions.c.permission, users.c.email], + use_labels=True, + ) + .select_from(j) + .where( + (vfolder_permissions.c.user == user_uuid) & + (vfolders.c.ownership_type == VFolderOwnershipType.USER), + ) + ) + await _append_entries(query, _is_owner=False) + + if 'group' in allowed_vfolder_types: + # Scan group vfolders. 
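+        # Domain admins can see every group vfolder in their domain, regular users
+        # only those of the groups they belong to, and superadmins are not filtered
+        # by group membership at all (no group restriction is applied below).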
+ if user_role == UserRole.ADMIN or user_role == 'admin': + query = (sa.select([groups.c.id]) + .select_from(groups) + .where(groups.c.domain_name == domain_name)) + result = await conn.execute(query) + grps = result.fetchall() + group_ids = [g.id for g in grps] + else: + j = sa.join(agus, users, agus.c.user_id == users.c.uuid) + query = (sa.select([agus.c.group_id]) + .select_from(j) + .where(agus.c.user_id == user_uuid)) + result = await conn.execute(query) + grps = result.fetchall() + group_ids = [g.group_id for g in grps] + j = (vfolders.join(groups, vfolders.c.group == groups.c.id)) + query = ( + sa.select(vfolders_selectors + [vfolders.c.permission, groups.c.name], use_labels=True) + .select_from(j) + ) + if user_role != UserRole.SUPERADMIN: + query = query.where(vfolders.c.group.in_(group_ids)) + is_owner = ((user_role == UserRole.ADMIN or user_role == 'admin') or + (user_role == UserRole.SUPERADMIN or user_role == 'superadmin')) + await _append_entries(query, is_owner) + + # Override permissions, if exists, for group vfolders. + j = sa.join( + vfolders, vfolder_permissions, vfolders.c.id == vfolder_permissions.c.vfolder, + ) + query = ( + sa.select(vfolder_permissions.c.permission, vfolder_permissions.c.vfolder) + .select_from(j) + .where( + (vfolders.c.group.in_(group_ids)) & + (vfolder_permissions.c.user == user_uuid), + ) + ) + if extra_vf_conds is not None: + query = query.where(extra_vf_conds) + if extra_vf_user_conds is not None: + query = query.where(extra_vf_user_conds) + result = await conn.execute(query) + overriding_permissions: dict = {row.vfolder: row.permission for row in result} + for entry in entries: + if entry['id'] in overriding_permissions and \ + entry['ownership_type'] == VFolderOwnershipType.GROUP: + entry['permission'] = overriding_permissions[entry['id']] + + return entries + + +async def get_allowed_vfolder_hosts_by_group( + conn: SAConnection, + resource_policy, + domain_name: str, + group_id: uuid.UUID = None, + domain_admin: bool = False, +) -> Set[str]: + ''' + Union `allowed_vfolder_hosts` from domain, group, and keypair_resource_policy. + + If `group_id` is not None, `allowed_vfolder_hosts` from the group is also merged. + If the requester is a domain admin, gather all `allowed_vfolder_hosts` of the domain groups. + ''' + from . import domains, groups + # Domain's allowed_vfolder_hosts. + allowed_hosts = set() + query = ( + sa.select([domains.c.allowed_vfolder_hosts]) + .where( + (domains.c.name == domain_name) & + (domains.c.is_active), + ) + ) + allowed_hosts.update(await conn.scalar(query)) + # Group's allowed_vfolder_hosts. + if group_id is not None: + query = ( + sa.select([groups.c.allowed_vfolder_hosts]) + .where( + (groups.c.domain_name == domain_name) & + (groups.c.id == group_id) & + (groups.c.is_active), + ) + ) + allowed_hosts.update(await conn.scalar(query)) + elif domain_admin: + query = ( + sa.select([groups.c.allowed_vfolder_hosts]) + .where( + (groups.c.domain_name == domain_name) & + (groups.c.is_active), + ) + ) + result = await conn.execute(query) + for row in result: + allowed_hosts.update(row.allowed_vfolder_hosts) + # Keypair Resource Policy's allowed_vfolder_hosts + allowed_hosts.update(resource_policy['allowed_vfolder_hosts']) + return allowed_hosts + + +async def get_allowed_vfolder_hosts_by_user( + conn: SAConnection, + resource_policy, + domain_name: str, + user_uuid: uuid.UUID, + group_id: uuid.UUID = None, +) -> Set[str]: + ''' + Union `allowed_vfolder_hosts` from domain, groups, and keypair_resource_policy. 
+ + All available `allowed_vfolder_hosts` of groups which requester associated will be merged. + ''' + from . import association_groups_users, domains, groups + # Domain's allowed_vfolder_hosts. + allowed_hosts = set() + query = ( + sa.select([domains.c.allowed_vfolder_hosts]) + .where( + (domains.c.name == domain_name) & + (domains.c.is_active), + ) + ) + allowed_hosts.update(await conn.scalar(query)) + # User's Groups' allowed_vfolder_hosts. + if group_id is not None: + j = groups.join( + association_groups_users, + ( + (groups.c.id == association_groups_users.c.group_id) & + (groups.c.id == group_id) & + (association_groups_users.c.user_id == user_uuid) + ), + ) + else: + j = groups.join( + association_groups_users, + ( + (groups.c.id == association_groups_users.c.group_id) & + (association_groups_users.c.user_id == user_uuid) + ), + ) + query = ( + sa.select([groups.c.allowed_vfolder_hosts]) + .select_from(j) + .where( + (domains.c.name == domain_name) & + (groups.c.is_active), + ) + ) + result = await conn.execute(query) + rows = result.fetchall() + for row in rows: + allowed_hosts.update(row['allowed_vfolder_hosts']) + # Keypair Resource Policy's allowed_vfolder_hosts + allowed_hosts.update(resource_policy['allowed_vfolder_hosts']) + return allowed_hosts + + +async def prepare_vfolder_mounts( + conn: SAConnection, + storage_manager: StorageSessionManager, + allowed_vfolder_types: Sequence[str], + user_scope: UserScope, + requested_mounts: Sequence[str], + requested_mount_map: Mapping[str, str], +) -> Sequence[VFolderMount]: + """ + Determine the actual mount information from the requested vfolder lists, + vfolder configurations, and the given user scope. + """ + + requested_vfolder_names: dict[str, str] = {} + requested_vfolder_subpaths: dict[str, str] = {} + requested_vfolder_dstpaths: dict[str, str] = {} + matched_vfolder_mounts: list[VFolderMount] = [] + + # Split the vfolder name and subpaths + for key in requested_mounts: + name, _, subpath = key.partition("/") + if not PurePosixPath(os.path.normpath(key)).is_relative_to(name): + raise InvalidAPIParameters( + f"The subpath '{subpath}' should designate " + f"a subdirectory of the vfolder '{name}'.", + ) + requested_vfolder_names[key] = name + requested_vfolder_subpaths[key] = os.path.normpath(subpath) + for key, value in requested_mount_map.items(): + requested_vfolder_dstpaths[key] = value + + # Check if there are overlapping mount sources + for p1 in requested_mounts: + for p2 in requested_mounts: + if p1 == p2: + continue + if PurePosixPath(p1).is_relative_to(PurePosixPath(p2)): + raise InvalidAPIParameters( + f"VFolder source path '{p1}' overlaps with '{p2}'", + ) + + # Query the accessible vfolders that satisfy either: + # - the name matches with the requested vfolder name, or + # - the name starts with a dot (dot-prefixed vfolder) for automatic mounting. 
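+    # For example (illustrative values), requested_mounts=["data/train"] results in
+    # requested_vfolder_names={"data/train": "data"} and
+    # requested_vfolder_subpaths={"data/train": "train"}, so the condition below
+    # matches the vfolder named "data" as well as any dot-prefixed automount
+    # vfolders (e.g., ".local").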
+ if requested_vfolder_names: + extra_vf_conds = ( + vfolders.c.name.in_(requested_vfolder_names.values()) | + vfolders.c.name.startswith('.') + ) + else: + extra_vf_conds = vfolders.c.name.startswith('.') + accessible_vfolders = await query_accessible_vfolders( + conn, user_scope.user_uuid, + user_role=user_scope.user_role, + domain_name=user_scope.domain_name, + allowed_vfolder_types=allowed_vfolder_types, + extra_vf_conds=extra_vf_conds, + ) + + # Fast-path for empty requested mounts + if not accessible_vfolders: + if requested_vfolder_names: + raise VFolderNotFound("There is no accessible vfolders at all.") + else: + return [] + accessible_vfolders_map = { + vfolder['name']: vfolder for vfolder in accessible_vfolders + } + + # add automount folder list into requested_vfolder_names + # and requested_vfolder_subpath + for _vfolder in accessible_vfolders: + if _vfolder['name'].startswith('.'): + requested_vfolder_names.setdefault(_vfolder['name'], _vfolder['name']) + requested_vfolder_subpaths.setdefault(_vfolder['name'], '.') + + # for vfolder in accessible_vfolders: + for key, vfolder_name in requested_vfolder_names.items(): + if not (vfolder := accessible_vfolders_map.get(vfolder_name)): + raise VFolderNotFound(f"VFolder {vfolder_name} is not found or accessible.") + if vfolder['group'] is not None and vfolder['group'] != str(user_scope.group_id): + # User's accessible group vfolders should not be mounted + # if not belong to the execution kernel. + continue + try: + mount_base_path = PurePosixPath( + await storage_manager.get_mount_path( + vfolder['host'], + vfolder['id'], + PurePosixPath(requested_vfolder_subpaths[key]), + ), + ) + except VFolderOperationFailed as e: + raise InvalidAPIParameters(e.extra_msg, e.extra_data) from None + if vfolder['name'] == '.local' and vfolder['group'] is not None: + # Auto-create per-user subdirectory inside the group-owned ".local" vfolder. + async with storage_manager.request( + vfolder['host'], 'POST', 'folder/file/mkdir', + params={ + 'volume': storage_manager.split_host(vfolder['host'])[1], + 'vfid': vfolder['id'], + 'relpath': str(user_scope.user_uuid.hex), + 'exist_ok': True, + }, + ): + pass + # Mount the per-user subdirectory as the ".local" vfolder. 
+ matched_vfolder_mounts.append(VFolderMount( + name=vfolder['name'], + vfid=vfolder['id'], + vfsubpath=PurePosixPath(user_scope.user_uuid.hex), + host_path=mount_base_path / user_scope.user_uuid.hex, + kernel_path=PurePosixPath("/home/work/.local"), + mount_perm=vfolder['permission'], + )) + else: + # Normal vfolders + kernel_path_raw = requested_vfolder_dstpaths.get(key) + if kernel_path_raw is None: + kernel_path = PurePosixPath(f"/home/work/{vfolder['name']}") + else: + kernel_path = PurePosixPath(kernel_path_raw) + if not kernel_path.is_absolute(): + kernel_path = PurePosixPath("/home/work", kernel_path_raw) + matched_vfolder_mounts.append(VFolderMount( + name=vfolder['name'], + vfid=vfolder['id'], + vfsubpath=PurePosixPath(requested_vfolder_subpaths[key]), + host_path=mount_base_path / requested_vfolder_subpaths[key], + kernel_path=kernel_path, + mount_perm=vfolder['permission'], + )) + + # Check if there are overlapping mount targets + for vf1 in matched_vfolder_mounts: + for vf2 in matched_vfolder_mounts: + if vf1.name == vf2.name: + continue + if vf1.kernel_path.is_relative_to(vf2.kernel_path): + raise InvalidAPIParameters( + f"VFolder mount path {vf1.kernel_path} overlaps with {vf2.kernel_path}", + ) + + return matched_vfolder_mounts + + +class VirtualFolder(graphene.ObjectType): + class Meta: + interfaces = (Item, ) + + host = graphene.String() + name = graphene.String() + user = graphene.UUID() # User.id (current owner, null in project vfolders) + user_email = graphene.String() # User.email (current owner, null in project vfolders) + group = graphene.UUID() # Group.id (current owner, null in user vfolders) + group_name = graphene.String() # Group.name (current owenr, null in user vfolders) + creator = graphene.String() # User.email (always set) + unmanaged_path = graphene.String() + usage_mode = graphene.String() + permission = graphene.String() + ownership_type = graphene.String() + max_files = graphene.Int() + max_size = BigInt() # in MiB + created_at = GQLDateTime() + last_used = GQLDateTime() + + num_files = graphene.Int() + cur_size = BigInt() + # num_attached = graphene.Int() + cloneable = graphene.Boolean() + + @classmethod + def from_row(cls, ctx: GraphQueryContext, row: Row) -> Optional[VirtualFolder]: + if row is None: + return None + return cls( + id=row['id'], + host=row['host'], + name=row['name'], + user=row['user'], + user_email=row['users_email'], + group=row['group'], + group_name=row['groups_name'], + creator=row['creator'], + unmanaged_path=row['unmanaged_path'], + usage_mode=row['usage_mode'], + permission=row['permission'], + ownership_type=row['ownership_type'], + max_files=row['max_files'], + max_size=row['max_size'], # in MiB + created_at=row['created_at'], + last_used=row['last_used'], + # num_attached=row['num_attached'], + cloneable=row['cloneable'], + ) + + async def resolve_num_files(self, info: graphene.ResolveInfo) -> int: + # TODO: measure on-the-fly + return 0 + + async def resolve_cur_size(self, info: graphene.ResolveInfo) -> int: + # TODO: measure on-the-fly + return 0 + + _queryfilter_fieldspec = { + "id": ("vfolders_id", uuid.UUID), + "host": ("vfolders_host", None), + "name": ("vfolders_name", None), + "group": ("vfolders_group", uuid.UUID), + "group_name": ("groups_name", None), + "user": ("vfolders_user", uuid.UUID), + "user_email": ("users_email", None), + "creator": ("vfolders_creator", None), + "unmanaged_path": ("vfolders_unmanaged_path", None), + "usage_mode": ("vfolders_usage_mode", lambda s: VFolderUsageMode[s]), + "permission": 
("vfolders_permission", lambda s: VFolderPermission[s]), + "ownership_type": ("vfolders_ownership_type", lambda s: VFolderOwnershipType[s]), + "max_files": ("vfolders_max_files", None), + "max_size": ("vfolders_max_size", None), + "created_at": ("vfolders_created_at", dtparse), + "last_used": ("vfolders_last_used", dtparse), + "cloneable": ("vfolders_cloneable", None), + } + + _queryorder_colmap = { + "id": "vfolders_id", + "host": "vfolders_host", + "name": "vfolders_name", + "group": "vfolders_group", + "group_name": "groups_name", + "user": "vfolders_user", + "user_email": "users_email", + "usage_mode": "vfolders_usage_mode", + "permission": "vfolders_permission", + "ownership_type": "vfolders_ownership_type", + "max_files": "vfolders_max_files", + "max_size": "vfolders_max_size", + "created_at": "vfolders_created_at", + "last_used": "vfolders_last_used", + "cloneable": "vfolders_cloneable", + } + + @classmethod + async def load_count( + cls, + graph_ctx: GraphQueryContext, + *, + domain_name: str = None, + group_id: uuid.UUID = None, + user_id: uuid.UUID = None, + filter: str = None, + ) -> int: + from .user import users + j = ( + vfolders + .join(users, vfolders.c.user == users.c.uuid, isouter=True) + ) + query = ( + sa.select([sa.func.count()]) + .select_from(j) + ) + if domain_name is not None: + query = query.where(users.c.domain_name == domain_name) + if group_id is not None: + query = query.where(vfolders.c.group == group_id) + if user_id is not None: + query = query.where(vfolders.c.user == user_id) + if filter is not None: + qfparser = QueryFilterParser(cls._queryfilter_fieldspec) + query = qfparser.append_filter(query, filter) + async with graph_ctx.db.begin_readonly() as conn: + result = await conn.execute(query) + return result.scalar() + + @classmethod + async def load_slice( + cls, + graph_ctx: GraphQueryContext, + limit: int, + offset: int, + *, + domain_name: str = None, + group_id: uuid.UUID = None, + user_id: uuid.UUID = None, + filter: str = None, + order: str = None, + ) -> Sequence[VirtualFolder]: + from .user import users + from .group import groups + j = ( + vfolders + .join(users, vfolders.c.user == users.c.uuid, isouter=True) + .join(groups, vfolders.c.group == groups.c.id, isouter=True) + ) + query = ( + sa.select([vfolders, users.c.email, groups.c.name.label('groups_name')]) + .select_from(j) + .limit(limit) + .offset(offset) + ) + if domain_name is not None: + query = query.where(users.c.domain_name == domain_name) + if group_id is not None: + query = query.where(vfolders.c.group == group_id) + if user_id is not None: + query = query.where(vfolders.c.user == user_id) + if filter is not None: + qfparser = QueryFilterParser(cls._queryfilter_fieldspec) + query = qfparser.append_filter(query, filter) + if order is not None: + qoparser = QueryOrderParser(cls._queryorder_colmap) + query = qoparser.append_ordering(query, order) + else: + query = query.order_by(vfolders.c.created_at.desc()) + async with graph_ctx.db.begin_readonly() as conn: + return [ + obj async for r in (await conn.stream(query)) + if (obj := cls.from_row(graph_ctx, r)) is not None + ] + + @classmethod + async def batch_load_by_user( + cls, + graph_ctx: GraphQueryContext, + user_uuids: Sequence[uuid.UUID], + *, + domain_name: str = None, + group_id: uuid.UUID = None, + ) -> Sequence[Sequence[VirtualFolder]]: + from .user import users + # TODO: num_attached count group-by + j = sa.join(vfolders, users, vfolders.c.user == users.c.uuid) + query = ( + sa.select([vfolders]) + .select_from(j) + 
.where(vfolders.c.user.in_(user_uuids)) + .order_by(sa.desc(vfolders.c.created_at)) + ) + if domain_name is not None: + query = query.where(users.c.domain_name == domain_name) + if group_id is not None: + query = query.where(vfolders.c.group == group_id) + async with graph_ctx.db.begin_readonly() as conn: + return await batch_multiresult( + graph_ctx, conn, query, cls, + user_uuids, lambda row: row['user'], + ) + + +class VirtualFolderList(graphene.ObjectType): + class Meta: + interfaces = (PaginatedList, ) + + items = graphene.List(VirtualFolder, required=True) diff --git a/src/ai/backend/manager/pglock.py b/src/ai/backend/manager/pglock.py new file mode 100644 index 0000000000..7905a2629e --- /dev/null +++ b/src/ai/backend/manager/pglock.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from typing import Any, AsyncContextManager + +from ai.backend.common.lock import AbstractDistributedLock + +from .models.utils import ExtendedAsyncSAEngine +from .defs import LockID + + +class PgAdvisoryLock(AbstractDistributedLock): + + _lock_ctx: AsyncContextManager | None + + def __init__(self, db: ExtendedAsyncSAEngine, lock_id: LockID) -> None: + self.db = db + self.lock_id = lock_id + self._lock_ctx = None + + async def __aenter__(self) -> Any: + self._lock_ctx = self.db.advisory_lock(self.lock_id) + await self._lock_ctx.__aenter__() + + async def __aexit__(self, *exc_info) -> bool | None: + assert self._lock_ctx is not None + try: + return await self._lock_ctx.__aexit__(*exc_info) + finally: + self._lock_ctx = None diff --git a/src/ai/backend/manager/plugin/__init__.py b/src/ai/backend/manager/plugin/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/ai/backend/manager/plugin/error_monitor.py b/src/ai/backend/manager/plugin/error_monitor.py new file mode 100644 index 0000000000..6a402e9af2 --- /dev/null +++ b/src/ai/backend/manager/plugin/error_monitor.py @@ -0,0 +1,120 @@ +from __future__ import annotations + +import logging +import sys +import traceback +from typing import Any, Mapping, TYPE_CHECKING + +from ai.backend.common.events import AgentErrorEvent +from ai.backend.common.logging import BraceStyleAdapter +from ai.backend.common.types import ( + AgentId, + LogSeverity, +) +from ai.backend.common.plugin.monitor import AbstractErrorReporterPlugin + +from ..models import error_logs + +if TYPE_CHECKING: + from ai.backend.manager.api.context import RootContext + +log = BraceStyleAdapter(logging.getLogger(__name__)) + + +class ErrorMonitor(AbstractErrorReporterPlugin): + + async def init(self, context: Any = None) -> None: + if context is None: + log.warning( + "manager.plugin.error_monitor is initialized without the root context. 
" + "The plugin is disabled.", + ) + self.enabled = False + return + else: + self.enabled = True + root_ctx: RootContext = context['_root.context'] # type: ignore + self.event_dispatcher = root_ctx.event_dispatcher + self._evh = self.event_dispatcher.consume(AgentErrorEvent, None, self.handle_agent_error) + self.db = root_ctx.db + + async def cleanup(self) -> None: + if self.enabled: + self.event_dispatcher.unconsume(self._evh) + + async def update_plugin_config(self, plugin_config: Mapping[str, Any]) -> None: + pass + + async def capture_message(self, message: str) -> None: + pass + + async def capture_exception( + self, + exc_instance: Exception = None, + context: Mapping[str, Any] = None, + ) -> None: + if not self.enabled: + return + if exc_instance: + tb = exc_instance.__traceback__ + else: + _, sys_exc_instance, tb = sys.exc_info() + if ( + isinstance(sys_exc_instance, BaseException) + and not isinstance(sys_exc_instance, Exception) + ): + # bypass BaseException as they are used for controlling the process/coroutine lifecycles + # instead of indicating actual errors + return + exc_instance = sys_exc_instance + exc_type: Any = type(exc_instance) + + if context is None or 'severity' not in context: + severity = LogSeverity.ERROR + else: + severity = context['severity'] + if context is None or 'user' not in context: + user = None + else: + user = context['user'] + message = ''.join(traceback.format_exception_only(exc_type, exc_instance)).strip() + + async with self.db.begin() as conn: + query = error_logs.insert().values({ + 'severity': severity, + 'source': 'manager', + 'user': user, + 'message': message, + 'context_lang': 'python', + 'context_env': context, + 'traceback': ''.join(traceback.format_tb(tb)).strip(), + }) + await conn.execute(query) + log.debug( + "collected an error log [{}] \"{}\" from manager", + severity.name, message, + ) + + async def handle_agent_error( + self, + context: None, + source: AgentId, + event: AgentErrorEvent, + ) -> None: + if not self.enabled: + return + async with self.db.begin() as conn: + query = error_logs.insert().values({ + 'severity': event.severity, + 'source': source, + 'user': event.user, + 'message': event.message, + 'context_lang': 'python', + 'context_env': event.context_env, + 'traceback': event.traceback, + }) + await conn.execute(query) + log.debug( + "collected an error log [{}] \"{}\" from agent:{}", + event.severity.name, event.message, source, + ) diff --git a/src/ai/backend/manager/plugin/exceptions.py b/src/ai/backend/manager/plugin/exceptions.py new file mode 100644 index 0000000000..6e39afbf1d --- /dev/null +++ b/src/ai/backend/manager/plugin/exceptions.py @@ -0,0 +1,15 @@ +""" +This module defines a series of Backend.AI's plugin-specific errors. 
+""" +from aiohttp import web +from ai.backend.manager.api.exceptions import BackendError + + +class PluginError(web.HTTPBadRequest, BackendError): + error_type = 'https://api.backend.ai/probs/plugin-error' + error_title = 'Plugin generated error' + + +class PluginConfigurationError(PluginError): + error_type = 'https://api.backend.ai/probs/plugin-config-error' + error_title = 'Plugin configuration error' diff --git a/src/ai/backend/manager/plugin/webapp.py b/src/ai/backend/manager/plugin/webapp.py new file mode 100644 index 0000000000..003c51eb3f --- /dev/null +++ b/src/ai/backend/manager/plugin/webapp.py @@ -0,0 +1,31 @@ +from abc import ABCMeta, abstractmethod +from typing import ( + Tuple, + Sequence, +) + +from aiohttp import web + +from ai.backend.common.plugin import AbstractPlugin, BasePluginContext +from ai.backend.manager.api.types import CORSOptions, WebMiddleware + + +class WebappPlugin(AbstractPlugin, metaclass=ABCMeta): + """ + Webapp plugins should create a valid aiohttp.web.Application instance. The returned app + instance will be a subapp of the root app defined by the manager, and additional user-properties + will be set as defined in ``ai.backend.gateway.server.PUBLIC_INTERFACES``. + + The init/cleanup methods of the plugin are ignored and the manager uses the standard aiohttp's + application lifecycle handlers attached to the returned app instance. + """ + + @abstractmethod + async def create_app( + self, cors_options: CORSOptions, + ) -> Tuple[web.Application, Sequence[WebMiddleware]]: + pass + + +class WebappPluginContext(BasePluginContext[WebappPlugin]): + plugin_group = 'backendai_webapp_v20' diff --git a/src/ai/backend/manager/py.typed b/src/ai/backend/manager/py.typed new file mode 100644 index 0000000000..5abed26af8 --- /dev/null +++ b/src/ai/backend/manager/py.typed @@ -0,0 +1 @@ +marker diff --git a/src/ai/backend/manager/registry.py b/src/ai/backend/manager/registry.py new file mode 100644 index 0000000000..3132410f96 --- /dev/null +++ b/src/ai/backend/manager/registry.py @@ -0,0 +1,2827 @@ +from __future__ import annotations + +import asyncio +from contextvars import ContextVar +from contextlib import asynccontextmanager as actxmgr +from collections import defaultdict +import copy +from datetime import datetime +from decimal import Decimal +import itertools +import logging +import secrets +import time +import re +from typing import ( + Any, + AsyncIterator, + Callable, + Container, + Dict, + List, + Mapping, + MutableMapping, + Optional, + Sequence, + Tuple, + TYPE_CHECKING, + Union, + cast, +) +import uuid +import weakref + +import aiodocker +import aioredis +import aiotools +from async_timeout import timeout as _timeout +from callosum.rpc import Peer, RPCUserError +from callosum.lower.zeromq import ZeroMQAddress, ZeroMQRPCTransport +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.backends import default_backend +from dateutil.tz import tzutc +import snappy +import sqlalchemy as sa +from sqlalchemy.exc import DBAPIError +from sqlalchemy.sql.expression import true +from yarl import URL +import zmq + +from ai.backend.common import msgpack, redis +from ai.backend.common.docker import get_registry_info, get_known_registries, ImageRef +from ai.backend.common.events import ( + AgentStartedEvent, + KernelCancelledEvent, + KernelTerminatedEvent, + KernelTerminatingEvent, + SessionCancelledEvent, + SessionEnqueuedEvent, + SessionStartedEvent, + SessionTerminatedEvent, +) 
+from ai.backend.common.logging import BraceStyleAdapter +from ai.backend.common.plugin.hook import ( + HookPluginContext, + ALL_COMPLETED, + PASSED, +) +from ai.backend.common.service_ports import parse_service_ports +from ai.backend.common.types import ( + AccessKey, + AgentId, + BinarySize, + ClusterInfo, + ClusterMode, + ClusterSSHKeyPair, + DeviceId, + HardwareMetadata, + KernelEnqueueingConfig, + KernelId, + RedisConnectionInfo, + ResourceSlot, + SessionId, + SessionResult, + SessionTypes, + SlotName, + SlotTypes, + check_typed_dict, +) +from ai.backend.common.utils import nmget + +from .api.exceptions import ( + BackendError, InvalidAPIParameters, + RejectedByHook, + InstanceNotFound, + SessionNotFound, TooManySessionsMatched, + KernelCreationFailed, KernelDestructionFailed, + KernelExecutionFailed, KernelRestartFailed, + ScalingGroupNotFound, + AgentError, + GenericForbidden, + QuotaExceeded, +) +from .config import SharedConfig +from .exceptions import MultiAgentError +from .defs import DEFAULT_ROLE, INTRINSIC_SLOTS +from .types import SessionGetter, UserScope +from .models import ( + agents, kernels, + keypair_resource_policies, + session_dependencies, + AgentStatus, KernelStatus, + ImageRow, + query_allowed_sgroups, + prepare_dotfiles, + prepare_vfolder_mounts, + recalc_agent_resource_occupancy, + recalc_concurrency_used, + AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES, + USER_RESOURCE_OCCUPYING_KERNEL_STATUSES, + DEAD_KERNEL_STATUSES, +) +from .models.kernel import match_session_ids, get_all_kernels, get_main_kernels +from .models.utils import ( + ExtendedAsyncSAEngine, + execute_with_retry, + reenter_txn, sql_json_merge, +) + +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import ( + AsyncConnection as SAConnection, + ) + from sqlalchemy.engine.row import Row + + from ai.backend.common.events import EventDispatcher, EventProducer + + from .models.storage import StorageSessionManager + from .scheduler.types import ( + AgentAllocationContext, + KernelAgentBinding, + SchedulingContext, + PendingSession, + ) + +__all__ = ['AgentRegistry', 'InstanceNotFound'] + +log = BraceStyleAdapter(logging.getLogger('ai.backend.manager.registry')) + +_read_only_txn_opts = { + 'postgresql_readonly': True, +} + + +class PeerInvoker(Peer): + + class _CallStub: + + _cached_funcs: Dict[str, Callable] + order_key: ContextVar[Optional[str]] + + def __init__(self, peer: Peer): + self._cached_funcs = {} + self.peer = peer + self.order_key = ContextVar('order_key', default=None) + + def __getattr__(self, name: str): + if f := self._cached_funcs.get(name, None): + return f + else: + async def _wrapped(*args, **kwargs): + request_body = { + 'args': args, + 'kwargs': kwargs, + } + self.peer.last_used = time.monotonic() + ret = await self.peer.invoke(name, request_body, + order_key=self.order_key.get()) + self.peer.last_used = time.monotonic() + return ret + self._cached_funcs[name] = _wrapped + return _wrapped + + call: _CallStub + last_used: float + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.call = self._CallStub(self) + self.last_used = time.monotonic() + + +@actxmgr +async def RPCContext( + agent_id: AgentId, + addr, + *, + invoke_timeout: float = None, + order_key: str = None, + keepalive_timeout: int = 60, +) -> AsyncIterator[PeerInvoker]: + keepalive_retry_count = 3 + keepalive_interval = keepalive_timeout // keepalive_retry_count + if keepalive_interval < 2: + keepalive_interval = 2 + peer = PeerInvoker( + connect=ZeroMQAddress(addr), + 
transport=ZeroMQRPCTransport, + transport_opts={ + 'zsock_opts': { + zmq.TCP_KEEPALIVE: 1, + zmq.TCP_KEEPALIVE_IDLE: keepalive_timeout, + zmq.TCP_KEEPALIVE_INTVL: keepalive_interval, + zmq.TCP_KEEPALIVE_CNT: keepalive_retry_count, + }, + }, + serializer=msgpack.packb, + deserializer=msgpack.unpackb, + ) + try: + with _timeout(invoke_timeout): + async with peer: + okey_token = peer.call.order_key.set('') + try: + yield peer + finally: + peer.call.order_key.reset(okey_token) + except RPCUserError as orig_exc: + raise AgentError(agent_id, orig_exc.name, orig_exc.repr, orig_exc.args) + except Exception: + raise + + +class AgentRegistry: + """ + Provide a high-level API to create, destroy, and query the computation + kernels. + + The registry is also responsible to implement our resource management + policy, such as the limitation of maximum number of kernels per instance. + """ + + kernel_creation_tracker: Dict[KernelId, asyncio.Future] + _post_kernel_creation_tasks: weakref.WeakValueDictionary[KernelId, asyncio.Task] + _post_kernel_creation_infos: dict[KernelId, asyncio.Future] + _kernel_actual_allocated_resources: dict[KernelId, ResourceSlot] + + def __init__( + self, + shared_config: SharedConfig, + db: ExtendedAsyncSAEngine, + redis_stat: RedisConnectionInfo, + redis_live: RedisConnectionInfo, + redis_image: RedisConnectionInfo, + event_dispatcher: EventDispatcher, + event_producer: EventProducer, + storage_manager: StorageSessionManager, + hook_plugin_ctx: HookPluginContext, + ) -> None: + self.shared_config = shared_config + self.docker = aiodocker.Docker() + self.db = db + self.redis_stat = redis_stat + self.redis_live = redis_live + self.redis_image = redis_image + self.event_dispatcher = event_dispatcher + self.event_producer = event_producer + self.storage_manager = storage_manager + self.hook_plugin_ctx = hook_plugin_ctx + self.kernel_creation_tracker = {} + self._post_kernel_creation_tasks = weakref.WeakValueDictionary() + self._post_kernel_creation_infos = {} + self._kernel_actual_allocated_resources = {} + self.rpc_keepalive_timeout = \ + int(shared_config.get("config/network/rpc/keepalive-timeout", "60")) + + async def init(self) -> None: + self.heartbeat_lock = asyncio.Lock() + + async def shutdown(self) -> None: + pass + + async def get_instance(self, inst_id: AgentId, field=None): + async with self.db.begin_readonly() as conn: + cols = [agents.c.id] + if field is not None: + cols.append(field) + query = (sa.select(cols) + .select_from(agents) + .where(agents.c.id == inst_id)) + result = await conn.execute(query) + row = result.first() + if not row: + raise InstanceNotFound(inst_id) + return row + + async def enumerate_instances(self, check_shadow=True): + + async with self.db.begin_readonly() as conn: + query = (sa.select('*').select_from(agents)) + if check_shadow: + query = query.where(agents.c.status == AgentStatus.ALIVE) + async for row in (await conn.stream(query)): + yield row + + async def update_instance(self, inst_id, updated_fields): + + async def _update() -> None: + async with self.db.begin() as conn: + query = ( + sa.update(agents) + .values(**updated_fields) + .where(agents.c.id == inst_id) + ) + await conn.execute(query) + + await execute_with_retry(_update) + + async def gather_agent_hwinfo(self, instance_id: AgentId) -> Mapping[str, HardwareMetadata]: + agent = await self.get_instance(instance_id, agents.c.addr) + async with RPCContext( + agent['id'], agent['addr'], + invoke_timeout=None, + keepalive_timeout=self.rpc_keepalive_timeout, + ) as rpc: + result = 
await rpc.call.gather_hwinfo() + return { + k: check_typed_dict(v, HardwareMetadata) # type: ignore # (python/mypy#9827) + for k, v in result.items() + } + + async def gather_storage_hwinfo(self, vfolder_host: str) -> HardwareMetadata: + proxy_name, volume_name = self.storage_manager.split_host(vfolder_host) + async with self.storage_manager.request( + proxy_name, 'GET', 'volume/hwinfo', + json={'volume': volume_name}, + raise_for_status=True, + ) as (_, storage_resp): + return check_typed_dict( + await storage_resp.json(), HardwareMetadata, # type: ignore # (python/mypy#9827) + ) + + @actxmgr + async def handle_kernel_exception( + self, + op: str, + session_id: SessionId, + access_key: AccessKey, + error_callback=None, + cancellation_callback=None, + set_error: bool = False, + ) -> AsyncIterator[None]: + op_exc = { + 'create_session': KernelCreationFailed, + 'restart_session': KernelRestartFailed, + 'destroy_session': KernelDestructionFailed, + 'execute': KernelExecutionFailed, + 'shutdown_service': KernelExecutionFailed, + 'upload_file': KernelExecutionFailed, + 'download_file': KernelExecutionFailed, + 'list_files': KernelExecutionFailed, + 'get_logs_from_agent': KernelExecutionFailed, + 'refresh_session': KernelExecutionFailed, + } + exc_class = op_exc[op] + # NOTE: Error logging is done outside of this actxmanager. + try: + yield + except asyncio.TimeoutError: + if set_error: + await self.set_session_status( + session_id, + access_key, + KernelStatus.ERROR, + status_info=f'operation-timeout ({op})', + ) + if error_callback: + await error_callback() + raise exc_class('TIMEOUT') from None + except asyncio.CancelledError: + if cancellation_callback: + await cancellation_callback() + raise + except AgentError as e: + if set_error: + await self.set_session_status( + session_id, + access_key, + KernelStatus.ERROR, + status_info=f'agent-error ({e!r})', + status_data={ + "error": { + "src": "agent", + "agent_id": e.agent_id, + "name": e.exc_name, + "repr": e.exc_repr, + }, + }, + ) + if error_callback: + await error_callback() + raise exc_class('FAILURE', e) from None + except BackendError: + # silently re-raise to make them handled by gateway http handlers + raise + except Exception as e: + if set_error: + await self.set_session_status( + session_id, + access_key, + KernelStatus.ERROR, + status_info=f'other-error ({e!r})', + status_data={ + "error": { + "src": "other", + "name": e.__class__.__name__, + "repr": repr(e), + }, + }, + ) + if error_callback: + await error_callback() + raise + + async def get_kernel( + self, + kern_id: uuid.UUID, + field=None, + allow_stale: bool = False, + db_connection=None, + ): + """ + Retrieve the kernel information from the given kernel ID. + This ID is unique for all individual agent-spawned containers. + + If ``field`` is given, it extracts only the raw value of the given + field, without wrapping it as Kernel object. + If ``allow_stale`` is true, it skips checking validity of the kernel + owner instance. 
+ """ + cols = [kernels.c.id, kernels.c.session_id, + kernels.c.agent_addr, kernels.c.kernel_host, kernels.c.access_key] + if field == '*': + cols = [sa.text('*')] + elif isinstance(field, (tuple, list)): + cols.extend(field) + elif isinstance(field, (sa.Column, sa.sql.elements.ColumnClause)): + cols.append(field) + elif isinstance(field, str): + cols.append(sa.column(field)) + async with reenter_txn(self.db, db_connection, _read_only_txn_opts) as conn: + if allow_stale: + query = ( + sa.select(cols) + .select_from(kernels) + .where(kernels.c.id == kern_id) + .limit(1).offset(0)) + else: + query = ( + sa.select(cols) + .select_from(kernels.join(agents)) + .where( + (kernels.c.id == kern_id) & + ~(kernels.c.status.in_(DEAD_KERNEL_STATUSES)) & + (agents.c.status == AgentStatus.ALIVE) & + (agents.c.id == kernels.c.agent), + ) + .limit(1).offset(0)) + result = await conn.execute(query) + row = result.first() + if row is None: + raise SessionNotFound + return row + + async def get_kernels( + self, + session_name_or_id: Union[str, uuid.UUID], + access_key: str, *, + field=None, + allow_stale: bool = False, + for_update: bool = False, + db_connection: SAConnection = None, + cluster_role: str = None, + ) -> Sequence[sa.engine.Row]: + """ + Retrieve the kernel information by kernel's ID, kernel's session UUID + (session_id), or kernel's name (session_id) paired with access_key. + If the session is composed of multiple containers, this will return + every container information, unless field and role is specified by the caller. + + :param session_name_or_id: kernel's id, session_id (session name), or session_id. + :param access_key: Access key used to create kernels. + :param field: If given, it extracts only the raw value of the given field, without + wrapping it as Kernel object. + :param allow_stale: If True, filter "inactive" kernels as well as "active" ones. + If False, filter "active" kernels only. + :param for_update: Apply for_update during select query. + :param db_connection: Database connection for reuse. + :param cluster_role: Filter kernels by role. "main", "sub", or None (all). 
+ """ + cols = [ + kernels.c.id, + kernels.c.session_id, + kernels.c.session_name, + kernels.c.session_type, + kernels.c.status, + kernels.c.cluster_mode, + kernels.c.cluster_role, + kernels.c.cluster_idx, + kernels.c.access_key, + kernels.c.agent_addr, + kernels.c.kernel_host, + kernels.c.image, + kernels.c.registry, + kernels.c.service_ports, + ] + if field == '*': + cols = [sa.text('*')] + elif isinstance(field, (tuple, list)): + cols.extend(field) + elif isinstance(field, (sa.Column, sa.sql.elements.ColumnClause)): + cols.append(field) + elif isinstance(field, str): + cols.append(sa.column(field)) + + cond_id = ( + (sa.sql.expression.cast(kernels.c.id, sa.String).like(f'{session_name_or_id}%')) & + (kernels.c.access_key == access_key) + ) + cond_name = ( + (kernels.c.session_name.like(f'{session_name_or_id}%')) & + (kernels.c.access_key == access_key) + ) + cond_session_id = ( + (sa.sql.expression.cast(kernels.c.session_id, sa.String).like(f'{session_name_or_id}%')) & + (kernels.c.access_key == access_key) + ) + if cluster_role is not None: + cond_id = cond_id & (kernels.c.cluster_role == cluster_role) + cond_name = cond_name & (kernels.c.cluster_role == cluster_role) + cond_session_id = cond_session_id & (kernels.c.cluster_role == cluster_role) + if allow_stale: + cond_status = true() # any status + else: + cond_status = ~(kernels.c.status.in_(DEAD_KERNEL_STATUSES)) + query_by_id = ( + sa.select(cols) + .select_from(kernels) + .where(cond_id & cond_status) + .order_by(sa.desc(kernels.c.created_at)) + .limit(10).offset(0) + ) + if for_update: + query_by_id = query_by_id.with_for_update() + query_by_name = ( + sa.select(cols) + .select_from(kernels) + .where(cond_name & cond_status) + .order_by(sa.desc(kernels.c.created_at)) + ) + if for_update: + query_by_name = query_by_name.with_for_update() + query_by_session_id = ( + sa.select(cols) + .select_from(kernels) + .where(cond_session_id & cond_status) + .order_by(sa.desc(kernels.c.created_at)) + .limit(10).offset(0) + ) + if for_update: + query_by_session_id = query_by_session_id.with_for_update() + if allow_stale: + query_by_name = query_by_name.limit(10).offset(0) + else: + # for backward-compatibility + query_by_name = query_by_name.limit(1).offset(0) + + async with reenter_txn(self.db, db_connection) as conn: + for query in [ + query_by_id, + query_by_session_id, + query_by_name, + ]: + result = await conn.execute(query) + if result.rowcount == 0: + continue + return result.fetchall() + raise SessionNotFound + + async def get_session_by_session_id( + self, + session_id: SessionId, + *, + db_connection: SAConnection, + for_update: bool = False, + ) -> sa.engine.Row: + query = ( + sa.select( + [sa.text('*')], + ) + .select_from(kernels) + .where( + (kernels.c.session_id == session_id) & + (kernels.c.cluster_role == DEFAULT_ROLE), + ) + ) + if for_update: + query = query.with_for_update() + result = await db_connection.execute(query) + row = result.first() + if row is None: + raise SessionNotFound + return row + + async def get_session_by_kernel_id( + self, + kernel_id: KernelId, + *, + db_connection: SAConnection, + for_update: bool = False, + ) -> sa.engine.Row: + query = ( + sa.select( + [sa.text('*')], + ) + .select_from(kernels) + .where( + (kernels.c.session_id == ( + sa.select([kernels.c.session_id]) + .select_from(kernels) + .where(kernels.c.id == kernel_id) + )) & + (kernels.c.cluster_role == DEFAULT_ROLE), + ) + ) + if for_update: + query = query.with_for_update() + result = await db_connection.execute(query) + row = 
result.first() + if row is None: + raise SessionNotFound + return row + + async def get_session( + self, + session_name_or_id: Union[str, uuid.UUID], + access_key: Union[str, AccessKey], + *, + allow_stale: bool = False, + for_update: bool = False, + db_connection: SAConnection = None, + ) -> sa.engine.Row: + """ + Retrieve the session information by kernel's ID, kernel's session UUID + (session_id), or kernel's name (session_id) paired with access_key. + If the session is composed of multiple containers, this will return + the information of the main kernel. + + :param session_name_or_id: kernel's id, session_id (session name), or session_id. + :param access_key: Access key used to create kernels. + :param field: If given, it extracts only the raw value of the given field, without + wrapping it as Kernel object. + :param allow_stale: If True, filter "inactive" kernels as well as "active" ones. + If False, filter "active" kernels only. + :param for_update: Apply for_update during select query. + :param db_connection: Database connection for reuse. + :param cluster_role: Filter kernels by role. "main", "sub", or None (all). + """ + async with reenter_txn(self.db, db_connection, _read_only_txn_opts) as conn: + if allow_stale: + extra_cond = None + else: + extra_cond = (~kernels.c.status.in_(DEAD_KERNEL_STATUSES)) + session_infos = await match_session_ids( + session_name_or_id, + AccessKey(access_key), + for_update=for_update, + extra_cond=extra_cond, + db_connection=conn, + ) + if not session_infos: + raise SessionNotFound() + if len(session_infos) > 1: + raise TooManySessionsMatched(extra_data={'matches': session_infos}) + kernel_list = await get_main_kernels( + [SessionId(session_infos[0]['session_id'])], + db_connection=conn, + ) + return kernel_list[0] + + async def get_session_kernels( + self, + session_id: str, + access_key: str, *, + field=None, + allow_stale: bool = False, + for_update: bool = False, + db_connection: SAConnection = None, + cluster_role: str = None, + ) -> Sequence[sa.engine.Row]: + """ + Retrieve the information of all kernels of a session by session UUID. + If the session is bundled with multiple containers, + this will return every information of them. + + :param session_id: Session's UUID. + :param access_key: Access key used to create the session. + :param field: If given, it extracts only the raw value of the given field, without + wrapping it as Kernel object. + :param allow_stale: If True, filter "inactive" kernels as well as "active" ones. + If False, filter "active" kernels only. + :param for_update: Apply for_update during select query. + :param db_connection: Database connection for reuse. + :param cluster_role: Filter kernels by role. "main", "sub", or None (all). + """ + return await self.get_kernels( + session_id, access_key, + field=field, for_update=for_update, + db_connection=db_connection, + cluster_role=cluster_role, + ) + + async def get_sessions( + self, + session_names: Container[str], + field=None, + allow_stale=False, + db_connection=None, + ): + """ + Batched version of :meth:`get_session() `. + The order of the returend array is same to the order of ``sess_ids``. + For non-existent or missing kernel IDs, it fills None in their + positions without raising SessionNotFound exception. 
+ """ + + cols = [kernels.c.id, kernels.c.session_id, + kernels.c.agent_addr, kernels.c.kernel_host, kernels.c.access_key, + kernels.c.service_ports] + if isinstance(field, (tuple, list)): + cols.extend(field) + elif isinstance(field, (sa.Column, sa.sql.elements.ColumnClause)): + cols.append(field) + elif isinstance(field, str): + cols.append(sa.column(field)) + async with reenter_txn(self.db, db_connection, _read_only_txn_opts) as conn: + if allow_stale: + query = (sa.select(cols) + .select_from(kernels) + .where((kernels.c.session_id.in_(session_names)) & + (kernels.c.cluster_role == DEFAULT_ROLE))) + else: + query = (sa.select(cols) + .select_from(kernels.join(agents)) + .where((kernels.c.session_id.in_(session_names)) & + (kernels.c.cluster_role == DEFAULT_ROLE) & + (agents.c.status == AgentStatus.ALIVE) & + (agents.c.id == kernels.c.agent))) + result = await conn.execute(query) + rows = result.fetchall() + return rows + + async def enqueue_session( + self, + session_creation_id: str, + session_name: str, + access_key: AccessKey, + kernel_enqueue_configs: List[KernelEnqueueingConfig], + scaling_group: Optional[str], + session_type: SessionTypes, + resource_policy: dict, + *, + user_scope: UserScope, + cluster_mode: ClusterMode = ClusterMode.SINGLE_NODE, + cluster_size: int = 1, + session_tag: str = None, + internal_data: dict = None, + starts_at: datetime = None, + agent_list: Sequence[str] = None, + dependency_sessions: Sequence[SessionId] = None, + callback_url: URL = None, + ) -> SessionId: + + session_id = SessionId(uuid.uuid4()) + + # Check keypair resource limit + if cluster_size > int(resource_policy['max_containers_per_session']): + raise QuotaExceeded( + f"You cannot create session with more than " + f"{resource_policy['max_containers_per_session']} containers.", + ) + + async with self.db.begin_readonly() as conn: + # Check scaling group availability if scaling_group parameter is given. + # If scaling_group is not provided, it will be selected as the first one among + # the list of allowed scaling groups. + sgroups = await query_allowed_sgroups( + conn, user_scope.domain_name, user_scope.group_id, access_key, + ) + if not sgroups: + raise ScalingGroupNotFound("You have no scaling groups allowed to use.") + if scaling_group is None: + scaling_group = sgroups[0]['name'] + log.warning( + f"enqueue_session(s:{session_name}, ak:{access_key}): " + f"The client did not specify the scaling group for session; " + f"falling back to {scaling_group}", + ) + else: + for sgroup in sgroups: + if scaling_group == sgroup['name']: + break + else: + raise ScalingGroupNotFound(f"The scaling group {scaling_group} does not exist.") + assert scaling_group is not None + + # Translate mounts/mount_map into vfolder mounts + requested_mounts = kernel_enqueue_configs[0]['creation_config'].get('mounts') or [] + requested_mount_map = kernel_enqueue_configs[0]['creation_config'].get('mount_map') or {} + allowed_vfolder_types = await self.shared_config.get_vfolder_types() + vfolder_mounts = await prepare_vfolder_mounts( + conn, + self.storage_manager, + allowed_vfolder_types, + user_scope, + requested_mounts, + requested_mount_map, + ) + + # Prepare internal data for common dotfiles. 
+ dotfile_data = await prepare_dotfiles( + conn, + user_scope, + access_key, + vfolder_mounts, + ) + + is_multicontainer = cluster_size > 1 + if is_multicontainer: + if len(kernel_enqueue_configs) == 1: + log.debug( + 'enqueue_session(): replicating kernel_enqueue_config with cluster_size={}', + cluster_size, + ) + # the first kernel_config is repliacted to sub-containers + assert kernel_enqueue_configs[0]['cluster_role'] == DEFAULT_ROLE + kernel_enqueue_configs[0]['cluster_idx'] = 1 + for i in range(cluster_size - 1): + sub_kernel_config = cast(KernelEnqueueingConfig, {**kernel_enqueue_configs[0]}) + sub_kernel_config['cluster_role'] = 'sub' + sub_kernel_config['cluster_idx'] = i + 1 + sub_kernel_config['cluster_hostname'] = sub_kernel_config['cluster_role'] + \ + str(sub_kernel_config['cluster_idx']) + kernel_enqueue_configs.append(sub_kernel_config) + elif len(kernel_enqueue_configs) > 1: + # each container should have its own kernel_config + log.debug( + 'enqueue_session(): using given kernel_enqueue_configs with cluster_size={}', + cluster_size, + ) + if len(kernel_enqueue_configs) != cluster_size: + raise InvalidAPIParameters( + "The number of kernel configs differs from the cluster size") + else: + raise InvalidAPIParameters("Missing kernel configurations") + + # Prepare internal data. + internal_data = {} if internal_data is None else internal_data + internal_data.update(dotfile_data) + + hook_result = await self.hook_plugin_ctx.dispatch( + 'PRE_ENQUEUE_SESSION', + (session_id, session_name, access_key), + return_when=ALL_COMPLETED, + ) + if hook_result.status != PASSED: + raise RejectedByHook.from_hook_result(hook_result) + + kernel_bulk_insert_query = kernels.insert().values({ + 'agent': sa.bindparam('mapped_agent'), + 'id': sa.bindparam('kernel_id'), + 'status': KernelStatus.PENDING, + 'session_creation_id': session_creation_id, + 'session_id': session_id, + 'session_name': session_name, + 'session_type': session_type, + 'cluster_mode': cluster_mode.value, + 'cluster_size': cluster_size, + 'cluster_role': sa.bindparam('cluster_role'), + 'cluster_idx': sa.bindparam('cluster_idx'), + 'cluster_hostname': sa.bindparam('cluster_hostname'), + 'scaling_group': scaling_group, + 'domain_name': user_scope.domain_name, + 'group_id': user_scope.group_id, + 'user_uuid': user_scope.user_uuid, + 'access_key': access_key, + 'image': sa.bindparam('image'), + 'registry': sa.bindparam('registry'), + 'tag': session_tag, + 'starts_at': starts_at, + 'internal_data': internal_data, + 'callback_url': callback_url, + 'startup_command': sa.bindparam('startup_command'), + 'occupied_slots': sa.bindparam('occupied_slots'), + 'occupied_shares': {}, + 'resource_opts': sa.bindparam('resource_opts'), + 'environ': sa.bindparam('environ'), + 'mounts': [ # TODO: keep for legacy? 
+ mount.name for mount in vfolder_mounts + ], + 'vfolder_mounts': vfolder_mounts, + 'bootstrap_script': sa.bindparam('bootstrap_script'), + 'repl_in_port': 0, + 'repl_out_port': 0, + 'stdin_port': 0, + 'stdout_port': 0, + 'preopen_ports': sa.bindparam('preopen_ports'), + }) + kernel_data = [] + + for idx, kernel in enumerate(kernel_enqueue_configs): + kernel_id: KernelId + if kernel['cluster_role'] == DEFAULT_ROLE: + kernel_id = cast(KernelId, session_id) + else: + kernel_id = KernelId(uuid.uuid4()) + creation_config = kernel['creation_config'] + image_ref = kernel['image_ref'] + resource_opts = creation_config.get('resource_opts') or {} + + creation_config['mounts'] = [vfmount.to_json() for vfmount in vfolder_mounts] + # TODO: merge into a single call + async with self.db.begin_readonly_session() as session: + log.debug('enqueue_session(): image ref => {} ({})', image_ref, image_ref.architecture) + image_row = await ImageRow.resolve(session, [image_ref]) + image_min_slots, image_max_slots = await image_row.get_slot_ranges(self.shared_config) + known_slot_types = await self.shared_config.get_resource_slots() + + labels = image_row.labels + # Parse service ports to check for port errors + parse_service_ports(labels.get('ai.backend.service-ports', ''), BackendError) + + # Shared memory. + # We need to subtract the amount of shared memory from the memory limit of + # a container, since tmpfs including /dev/shm uses host-side kernel memory + # and cgroup's memory limit does not apply. + shmem = resource_opts.get('shmem', None) + if shmem is None: + shmem = labels.get('ai.backend.resource.preferred.shmem', '64m') + shmem = BinarySize.from_str(shmem) + resource_opts['shmem'] = shmem + image_min_slots = copy.deepcopy(image_min_slots) + image_min_slots['mem'] += shmem + + # Sanitize user input: does it have resource config? + if 'resources' in creation_config: + # Sanitize user input: does it have "known" resource slots only? + for slot_key, slot_value in creation_config['resources'].items(): + if slot_key not in known_slot_types: + raise InvalidAPIParameters( + f'Unknown requested resource slot: {slot_key}') + try: + requested_slots = ResourceSlot.from_user_input( + creation_config['resources'], known_slot_types) + except ValueError: + log.exception('request_slots & image_slots calculation error') + # happens when requested_slots have more keys + # than the image-defined slots + # (e.g., image does not support accelerators + # requested by the client) + raise InvalidAPIParameters( + 'Your resource request has resource type(s) ' + 'not supported by the image.') + + # If intrinsic resources are not specified, + # fill them with image minimums. + for k, v in requested_slots.items(): + if (v is None or v == 0) and k in INTRINSIC_SLOTS: + requested_slots[k] = image_min_slots[k] + else: + # Handle the legacy clients (prior to v19.03) + # We support CPU/memory conversion, but to use accelerators users + # must update their clients because the slots names are not provided + # by the accelerator plugins. + cpu = creation_config.get('instanceCores') + if cpu is None: # the key is there but may be null. + cpu = image_min_slots['cpu'] + mem = creation_config.get('instanceMemory') + if mem is None: # the key is there but may be null. + mem = image_min_slots['mem'] + else: + # In legacy clients, memory is normalized to GiB. 
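+ # For example, a legacy client sending instanceMemory=4 means 4 GiB,
+ # so it is converted to the string '4g' below before slot parsing.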
+ mem = str(mem) + 'g' + requested_slots = ResourceSlot.from_user_input({ + 'cpu': cpu, + 'mem': mem, + }, known_slot_types) + gpu = creation_config.get('instanceGPUs') + if gpu is not None: + raise InvalidAPIParameters('Client upgrade required ' + 'to use GPUs (v19.03+).') + tpu = creation_config.get('instanceTPUs') + if tpu is not None: + raise InvalidAPIParameters('Client upgrade required ' + 'to use TPUs (v19.03+).') + + # Check the image resource slots. + log_fmt = "s:{} k:{} r:{}-{}" + log_args = (session_id, kernel_id, kernel['cluster_role'], kernel['cluster_idx']) + log.debug(log_fmt + ' -> requested_slots: {}', *log_args, requested_slots) + log.debug(log_fmt + ' -> resource_opts: {}', *log_args, resource_opts) + log.debug(log_fmt + ' -> image_min_slots: {}', *log_args, image_min_slots) + log.debug(log_fmt + ' -> image_max_slots: {}', *log_args, image_max_slots) + + # Check if: requested >= image-minimum + if image_min_slots > requested_slots: + raise InvalidAPIParameters( + 'Your resource request is smaller than ' + 'the minimum required by the image. ({})'.format(' '.join( + f'{k}={v}' for k, v in + image_min_slots.to_humanized(known_slot_types).items() + ))) + + # Check if: requested <= image-maximum + if not (requested_slots <= image_max_slots): + raise InvalidAPIParameters( + 'Your resource request is larger than ' + 'the maximum allowed by the image. ({})' + .format(' '.join( + f'{k}={v}' for k, v in + image_max_slots.to_humanized(known_slot_types).items() + ))) + + # Check if: shmem < memory + if shmem >= requested_slots['mem']: + raise InvalidAPIParameters( + 'Shared memory should be less than the main memory. (s:{}, m:{})' + .format(str(shmem), str(BinarySize(requested_slots['mem']))), + ) + + environ = kernel_enqueue_configs[0]['creation_config'].get('environ') or {} + + # Create kernel object in PENDING state. 
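+ # If the caller pre-assigned agents via `agent_list`, bind the idx-th kernel to
+ # the idx-th agent; otherwise leave it unassigned so that an agent is chosen
+ # later during scheduling.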
+ mapped_agent = None
+ if agent_list:
+ mapped_agent = agent_list[idx]
+
+ kernel_data.append({
+ 'mapped_agent': mapped_agent,
+ 'kernel_id': kernel_id,
+ 'cluster_role': kernel['cluster_role'],
+ 'cluster_idx': kernel['cluster_idx'],
+ 'cluster_hostname': f"{kernel['cluster_role']}{kernel['cluster_idx']}",
+ 'image': image_ref.canonical,
+ 'architecture': image_ref.architecture,
+ 'registry': image_ref.registry,
+ 'startup_command': kernel.get('startup_command'),
+ 'occupied_slots': requested_slots,
+ 'resource_opts': resource_opts,
+ 'environ': [f'{k}={v}' for k, v in environ.items()],
+ 'bootstrap_script': kernel.get('bootstrap_script'),
+ 'preopen_ports': creation_config.get('preopen_ports', []),
+ })
+
+ try:
+ async def _enqueue() -> None:
+ async with self.db.begin() as conn:
+ await conn.execute(kernel_bulk_insert_query, kernel_data)
+ if dependency_sessions:
+ matched_dependency_session_ids = []
+ for dependency_id in dependency_sessions:
+ match_info = await match_session_ids(
+ dependency_id,
+ access_key,
+ db_connection=conn,
+ )
+ if match_info:
+ matched_dependency_session_ids.append(match_info[0]['session_id'])
+ else:
+ raise InvalidAPIParameters(
+ "Unknown session ID or name in the dependency list",
+ extra_data={"session_ref": dependency_id},
+ )
+ dependency_bulk_insert_query = session_dependencies.insert().values(
+ {
+ 'session_id': session_id,
+ 'depends_on': sa.bindparam('dependency_id'),
+ },
+ )
+ await conn.execute(dependency_bulk_insert_query, [
+ {'dependency_id': dependency_id}
+ for dependency_id in matched_dependency_session_ids
+ ])
+
+ await execute_with_retry(_enqueue)
+ except DBAPIError as e:
+ if getattr(e.orig, "pgcode", None) == '23503':
+ match = re.search(r'Key \(agent\)=\((?P<agent>[^)]+)\)', repr(e.orig))
+ if match:
+ raise InvalidAPIParameters(f"No such agent: {match.group('agent')}")
+ else:
+ raise InvalidAPIParameters("No such agent")
+ raise
+
+ await self.hook_plugin_ctx.notify(
+ 'POST_ENQUEUE_SESSION',
+ (session_id, session_name, access_key),
+ )
+ await self.event_producer.produce_event(
+ SessionEnqueuedEvent(session_id, session_creation_id),
+ )
+ return session_id
+
+ async def start_session(
+ self,
+ sched_ctx: SchedulingContext,
+ scheduled_session: PendingSession,
+ ) -> None:
+ from .scheduler.types import KernelAgentBinding, AgentAllocationContext
+ kernel_agent_bindings: Sequence[KernelAgentBinding] = [
+ KernelAgentBinding(
+ kernel=k,
+ agent_alloc_ctx=AgentAllocationContext(
+ agent_id=k.agent_id,
+ agent_addr=k.agent_addr,
+ scaling_group=scheduled_session.scaling_group,
+ ),
+ )
+ for k in scheduled_session.kernels
+ ]
+ session_creation_id = scheduled_session.session_creation_id
+
+ hook_result = await self.hook_plugin_ctx.dispatch(
+ 'PRE_START_SESSION',
+ (scheduled_session.session_id, scheduled_session.session_name, scheduled_session.access_key),
+ return_when=ALL_COMPLETED,
+ )
+ if hook_result.status != PASSED:
+ raise RejectedByHook.from_hook_result(hook_result)
+
+ # Get resource policy for the session
+ # TODO: memoize with TTL
+ async with self.db.begin_readonly() as conn:
+ query = (
+ sa.select([keypair_resource_policies])
+ .select_from(keypair_resource_policies)
+ .where(keypair_resource_policies.c.name == scheduled_session.resource_policy)
+ )
+ result = await conn.execute(query)
+ resource_policy = result.first()
+ auto_pull = await self.shared_config.get_raw('config/docker/image/auto_pull')
+
+ # Aggregate image registry information
+ keyfunc = lambda item: item.kernel.image_ref
+
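+ # Group the kernel-agent bindings by image reference so that the image metadata
+ # and registry credentials are resolved only once per distinct image.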
image_infos = {} + async with self.db.begin_readonly_session() as session: + for image_ref, _ in itertools.groupby( + sorted(kernel_agent_bindings, key=keyfunc), key=keyfunc, + ): + log.debug('start_session(): image ref => {} ({})', image_ref, image_ref.architecture) + image_infos[image_ref] = await ImageRow.resolve(session, [image_ref]) + registry_url, registry_creds = \ + await get_registry_info(self.shared_config.etcd, image_ref.registry) + image_info = { + 'image_infos': image_infos, + 'registry_url': registry_url, + 'registry_creds': registry_creds, + 'resource_policy': resource_policy, + 'auto_pull': auto_pull, + } + + network_name: Optional[str] = None + if scheduled_session.cluster_mode == ClusterMode.SINGLE_NODE: + if scheduled_session.cluster_size > 1: + network_name = f'bai-singlenode-{scheduled_session.session_id}' + assert kernel_agent_bindings[0].agent_alloc_ctx.agent_id is not None + assert scheduled_session.session_id is not None + try: + async with RPCContext( + kernel_agent_bindings[0].agent_alloc_ctx.agent_id, + kernel_agent_bindings[0].agent_alloc_ctx.agent_addr, + invoke_timeout=None, + order_key=str(scheduled_session.session_id), + keepalive_timeout=self.rpc_keepalive_timeout, + ) as rpc: + await rpc.call.create_local_network(network_name) + except Exception: + log.exception(f"Failed to create an agent-local network {network_name}") + raise + else: + network_name = None + elif scheduled_session.cluster_mode == ClusterMode.MULTI_NODE: + # Create overlay network for multi-node sessions + network_name = f'bai-multinode-{scheduled_session.session_id}' + mtu = await self.shared_config.get_raw('config/network/overlay/mtu') + try: + # Overlay networks can only be created at the Swarm manager. + create_options = { + 'Name': network_name, + 'Driver': 'overlay', + 'Attachable': True, + 'Labels': { + 'ai.backend.cluster-network': '1', + }, + 'Options': {}, + } + if mtu: + create_options['Options'] = {'com.docker.network.driver.mtu': mtu} + await self.docker.networks.create(create_options) + except Exception: + log.exception(f"Failed to create an overlay network {network_name}") + raise + keyfunc = lambda item: item.kernel.cluster_role + replicas = { + cluster_role: len([*group_iterator]) + for cluster_role, group_iterator in itertools.groupby( + sorted(kernel_agent_bindings, key=keyfunc), + key=keyfunc, + ) + } + cluster_info = ClusterInfo( + mode=scheduled_session.cluster_mode, + size=scheduled_session.cluster_size, + replicas=replicas, + network_name=network_name, + ssh_keypair=( + await self.create_cluster_ssh_keypair() + if scheduled_session.cluster_size > 1 else None + ), + ) + scheduled_session.environ.update({ + 'BACKENDAI_SESSION_ID': str(scheduled_session.session_id), + 'BACKENDAI_SESSION_NAME': str(scheduled_session.session_name), + 'BACKENDAI_CLUSTER_SIZE': str(scheduled_session.cluster_size), + 'BACKENDAI_CLUSTER_REPLICAS': + ",".join(f"{k}:{v}" for k, v in replicas.items()), + 'BACKENDAI_CLUSTER_HOSTS': + ",".join(binding.kernel.cluster_hostname for binding in kernel_agent_bindings), + 'BACKENDAI_ACCESS_KEY': scheduled_session.access_key, + }) + + # Aggregate by agents to minimize RPC calls + per_agent_tasks = [] + keyfunc = lambda item: item.agent_alloc_ctx.agent_id + for agent_id, group_iterator in itertools.groupby( + sorted(kernel_agent_bindings, key=keyfunc), key=keyfunc, + ): + items = [*group_iterator] + # Within a group, agent_alloc_ctx are same. 
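+ # (hence it is safe to take the allocation context from the first item of the group)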
+ agent_alloc_ctx = items[0].agent_alloc_ctx + per_agent_tasks.append( + ( + agent_alloc_ctx, + self._create_kernels_in_one_agent( + agent_alloc_ctx, + scheduled_session, + items, + image_info, + cluster_info, + ), + ), + ) + if per_agent_tasks: + agent_errors = [] + results = await asyncio.gather( + *[item[1] for item in per_agent_tasks], + return_exceptions=True, + ) + for agent_alloc_tx, result in zip((item[0] for item in per_agent_tasks), results): + if isinstance(result, aiotools.TaskGroupError): + agent_errors.extend(result.__errors__) + elif isinstance(result, Exception): + # mark to be destroyed afterwards + agent_errors.append(result) + if agent_errors: + raise MultiAgentError( + "agent(s) raise errors during kernel creation", + errors=agent_errors, + ) + await self.settle_agent_alloc(kernel_agent_bindings) + # If all is well, let's say the session is ready. + await self.event_producer.produce_event( + SessionStartedEvent(scheduled_session.session_id, session_creation_id), + ) + await self.hook_plugin_ctx.notify( + 'POST_START_SESSION', + (scheduled_session.session_id, scheduled_session.session_name, scheduled_session.access_key), + ) + + def convert_resource_spec_to_resource_slot( + self, + allocations: Mapping[str, Mapping[SlotName, Mapping[DeviceId, str]]], + ) -> ResourceSlot: + """ + Convert per-device resource spec allocations (agent-side format) + back into a resource slot (manager-side format). + """ + slots = ResourceSlot() + for alloc_map in allocations.values(): + for slot_name, allocation_by_device in alloc_map.items(): + total_allocs: List[Decimal] = [] + for allocation in allocation_by_device.values(): + if BinarySize.suffix_map.get(allocation[-1].lower()) is not None: + total_allocs.append(Decimal(BinarySize.from_str(allocation))) + else: + total_allocs.append(Decimal(allocation)) + slots[slot_name] = str(sum(total_allocs)) + return slots + + async def _post_create_kernel( + self, + agent_alloc_ctx: AgentAllocationContext, + kernel_id: KernelId, + ) -> None: + # Wait until the kernel_started event. 
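+ # Both futures are resolved elsewhere: the creation-info future is set with the
+ # agent's RPC result in _create_kernels_in_one_agent(), and the creation tracker
+ # appears to be resolved when the agent reports the kernel_started event.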
+ try: + created_info, _ = await asyncio.gather( + self._post_kernel_creation_infos[kernel_id], + self.kernel_creation_tracker[kernel_id], + ) + except asyncio.CancelledError: + log.warning("post_create_kernel(k:{}) cancelled", kernel_id) + return + except Exception: + log.exception("post_create_kernel(k:{}) unexpected error", kernel_id) + return + else: + + async def _finialize_running() -> None: + # Record kernel access information + try: + async with self.db.begin() as conn: + agent_host = URL(agent_alloc_ctx.agent_addr).host + kernel_host = created_info.get('kernel_host', agent_host) + service_ports = created_info.get('service_ports', []) + # NOTE: created_info contains resource_spec + values = { + 'scaling_group': agent_alloc_ctx.scaling_group, + 'status': KernelStatus.RUNNING, + 'container_id': created_info['container_id'], + 'occupied_shares': {}, + 'attached_devices': created_info.get('attached_devices', {}), + 'kernel_host': kernel_host, + 'repl_in_port': created_info['repl_in_port'], + 'repl_out_port': created_info['repl_out_port'], + 'stdin_port': created_info['stdin_port'], + 'stdout_port': created_info['stdout_port'], + 'service_ports': service_ports, + } + actual_allocs = self.convert_resource_spec_to_resource_slot( + created_info['resource_spec']['allocations']) + values['occupied_slots'] = actual_allocs + self._kernel_actual_allocated_resources[kernel_id] = actual_allocs + update_query = ( + kernels.update() + .values(values) + .where(kernels.c.id == created_info['id']) + ) + await conn.execute(update_query) + except Exception: + log.exception('error while executing _finalize_running') + raise + await execute_with_retry(_finialize_running) + finally: + try: + await asyncio.sleep(1) + finally: + del self._post_kernel_creation_infos[kernel_id] + del self.kernel_creation_tracker[kernel_id] + + async def _create_kernels_in_one_agent( + self, + agent_alloc_ctx: AgentAllocationContext, + scheduled_session: PendingSession, + items: Sequence[KernelAgentBinding], + image_info, + cluster_info, + ) -> None: + loop = asyncio.get_running_loop() + registry_url = image_info['registry_url'] + registry_creds = image_info['registry_creds'] + image_infos = image_info['image_infos'] + resource_policy = image_info['resource_policy'] + auto_pull = image_info['auto_pull'] + assert agent_alloc_ctx.agent_id is not None + assert scheduled_session.session_id is not None + async with RPCContext( + agent_alloc_ctx.agent_id, + agent_alloc_ctx.agent_addr, + invoke_timeout=None, + order_key=str(scheduled_session.session_id), + keepalive_timeout=self.rpc_keepalive_timeout, + ) as rpc: + kernel_creation_id = secrets.token_urlsafe(16) + # Prepare kernel_started event handling + for binding in items: + self.kernel_creation_tracker[ + binding.kernel.kernel_id + ] = loop.create_future() + # Spawn post-processing tasks + post_tasks = [] + for binding in items: + self._post_kernel_creation_infos[binding.kernel.kernel_id] = loop.create_future() + post_task = asyncio.create_task(self._post_create_kernel( + agent_alloc_ctx, + binding.kernel.kernel_id, + )) + self._post_kernel_creation_tasks[binding.kernel.kernel_id] = post_task + post_tasks.append(post_task) + try: + # Issue a batched RPC call to create kernels on this agent + created_infos = await rpc.call.create_kernels( + kernel_creation_id, + str(scheduled_session.session_id), + [str(binding.kernel.kernel_id) for binding in items], + [ + { + 'image': { + 'registry': { + 'name': binding.kernel.image_ref.registry, + 'url': str(registry_url), + **registry_creds, # 
type: ignore + }, + 'digest': image_infos[binding.kernel.image_ref].config_digest, + 'repo_digest': None, + 'canonical': binding.kernel.image_ref.canonical, + 'architecture': binding.kernel.image_ref.architecture, + 'labels': image_infos[binding.kernel.image_ref].labels, + }, + 'session_type': scheduled_session.session_type.value, + 'cluster_role': binding.kernel.cluster_role, + 'cluster_idx': binding.kernel.cluster_idx, + 'cluster_hostname': binding.kernel.cluster_hostname, + 'idle_timeout': resource_policy['idle_timeout'], + 'mounts': [item.to_json() for item in scheduled_session.vfolder_mounts], + 'environ': { + # inherit per-session environment variables + **scheduled_session.environ, + # set per-kernel environment variables + 'BACKENDAI_KERNEL_ID': str(binding.kernel.kernel_id), + 'BACKENDAI_KERNEL_IMAGE': str(binding.kernel.image_ref), + 'BACKENDAI_CLUSTER_ROLE': binding.kernel.cluster_role, + 'BACKENDAI_CLUSTER_IDX': str(binding.kernel.cluster_idx), + 'BACKENDAI_CLUSTER_HOST': str(binding.kernel.cluster_hostname), + }, + 'resource_slots': binding.kernel.requested_slots.to_json(), + 'resource_opts': binding.kernel.resource_opts, + 'bootstrap_script': binding.kernel.bootstrap_script, + 'startup_command': binding.kernel.startup_command, + 'internal_data': scheduled_session.internal_data, + 'auto_pull': auto_pull, + 'preopen_ports': scheduled_session.preopen_ports, + } + for binding in items + ], + cluster_info, + ) + log.debug( + 'start_session(s:{}, ak:{}, k:{}) -> created on ag:{}', + scheduled_session.session_name, + scheduled_session.access_key, + [binding.kernel.kernel_id for binding in items], + agent_alloc_ctx.agent_id, + ) + # Pass the return value of RPC calls to post-processing tasks + for created_info in created_infos: + kernel_id = KernelId(uuid.UUID(created_info['id'])) + self._post_kernel_creation_infos[kernel_id].set_result(created_info) + await asyncio.gather(*post_tasks, return_exceptions=True) + except Exception as e: + # The agent has already cancelled or issued the destruction lifecycle event + # for this batch of kernels. 
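+ # Propagate the failure to the pending futures so that the post-processing tasks
+ # do not hang, then re-raise so that start_session() can collect this as an
+ # agent error (MultiAgentError).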
+ for binding in items: + kernel_id = binding.kernel.kernel_id + self.kernel_creation_tracker[kernel_id].cancel() + self._post_kernel_creation_infos[kernel_id].set_exception(e) + await asyncio.gather(*post_tasks, return_exceptions=True) + raise + + async def create_cluster_ssh_keypair(self) -> ClusterSSHKeyPair: + key = rsa.generate_private_key( + backend=default_backend(), + public_exponent=65537, + key_size=2048, + ) + public_key = key.public_key().public_bytes( + serialization.Encoding.OpenSSH, + serialization.PublicFormat.OpenSSH, + ) + public_key += b' work@cluster.backend.ai.local' + pem = key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + return { + 'private_key': pem.decode('utf-8'), + 'public_key': public_key.decode('utf-8'), + } + + async def get_keypair_occupancy(self, access_key, *, conn=None): + known_slot_types = \ + await self.shared_config.get_resource_slots() + + async def _query() -> ResourceSlot: + async with reenter_txn(self.db, conn) as _conn: + query = ( + sa.select([kernels.c.occupied_slots]) + .where( + (kernels.c.access_key == access_key) & + (kernels.c.status.in_(USER_RESOURCE_OCCUPYING_KERNEL_STATUSES)), + ) + ) + zero = ResourceSlot() + key_occupied = sum([ + row['occupied_slots'] + async for row in (await _conn.stream(query))], zero) + # drop no-longer used slot types + drops = [k for k in key_occupied.keys() if k not in known_slot_types] + for k in drops: + del key_occupied[k] + return key_occupied + + return await execute_with_retry(_query) + + async def get_domain_occupancy(self, domain_name, *, conn=None): + # TODO: store domain occupied_slots in Redis? + known_slot_types = await self.shared_config.get_resource_slots() + + async def _query() -> ResourceSlot: + async with reenter_txn(self.db, conn) as _conn: + query = ( + sa.select([kernels.c.occupied_slots]) + .where( + (kernels.c.domain_name == domain_name) & + (kernels.c.status.in_(USER_RESOURCE_OCCUPYING_KERNEL_STATUSES)), + ) + ) + zero = ResourceSlot() + key_occupied = sum( + [ + row['occupied_slots'] + async for row in (await _conn.stream(query)) + ], + zero, + ) + # drop no-longer used slot types + drops = [k for k in key_occupied.keys() if k not in known_slot_types] + for k in drops: + del key_occupied[k] + return key_occupied + + return await execute_with_retry(_query) + + async def get_group_occupancy(self, group_id, *, conn=None): + # TODO: store domain occupied_slots in Redis? 
+ known_slot_types = await self.shared_config.get_resource_slots() + + async def _query() -> ResourceSlot: + async with reenter_txn(self.db, conn) as _conn: + query = ( + sa.select([kernels.c.occupied_slots]) + .where( + (kernels.c.group_id == group_id) & + (kernels.c.status.in_(USER_RESOURCE_OCCUPYING_KERNEL_STATUSES)), + ) + ) + zero = ResourceSlot() + key_occupied = sum( + [ + row['occupied_slots'] + async for row in (await _conn.stream(query)) + ], + zero, + ) + # drop no-longer used slot types + drops = [k for k in key_occupied.keys() if k not in known_slot_types] + for k in drops: + del key_occupied[k] + return key_occupied + + return await execute_with_retry(_query) + + async def update_scaling_group(self, id, scaling_group) -> None: + agent = await self.get_instance(id, agents.c.addr) + async with RPCContext( + agent['id'], + agent['addr'], + invoke_timeout=None, + keepalive_timeout=self.rpc_keepalive_timeout, + ) as rpc: + await rpc.call.update_scaling_group(scaling_group) + + async def settle_agent_alloc( + self, kernel_agent_bindings: Sequence[KernelAgentBinding], + ) -> None: + """ + Tries to settle down agent row's occupied_slots with real value. This must be called + after kernel creation is completed, to prevent fraction of resource dropped by agent scheduler + during kernel creation still being reported as used. + """ + + keyfunc = lambda item: item.agent_alloc_ctx.agent_id + for agent_id, group_iterator in itertools.groupby( + sorted(kernel_agent_bindings, key=keyfunc), key=keyfunc, + ): + actual_allocated_slots = ResourceSlot() + requested_slots = ResourceSlot() + + for kernel_agent_binding in group_iterator: + # this value must be set while running _post_create_kernel + actual_allocated_slot = self._kernel_actual_allocated_resources.get( + kernel_agent_binding.kernel.kernel_id) + requested_slots += kernel_agent_binding.kernel.requested_slots + if actual_allocated_slot is not None: + actual_allocated_slots += ResourceSlot.from_json(actual_allocated_slot) + del self._kernel_actual_allocated_resources[kernel_agent_binding.kernel.kernel_id] + else: # something's wrong; just fall back to requested slot value + actual_allocated_slots += kernel_agent_binding.kernel.requested_slots + + # perform DB update only if requested slots and actual allocated value differs + if actual_allocated_slots != requested_slots: + log.debug('calibrating resource slot usage for agent {}', agent_id) + async with self.db.begin() as conn: + select_query = ( + sa.select([agents.c.occupied_slots]) + .select_from(agents).where(agents.c.id == agent_id) + ) + result = await conn.execute(select_query) + occupied_slots: ResourceSlot = result.scalar() + diff = actual_allocated_slots - requested_slots + update_query = ( + sa.update(agents).values({ + 'occupied_slots': ResourceSlot.from_json(occupied_slots) + diff, + }).where(agents.c.id == agent_id) + ) + await conn.execute(update_query) + + async def recalc_resource_usage(self, do_fullscan: bool = False) -> None: + concurrency_used_per_key: MutableMapping[str, int] = defaultdict(lambda: 0) + occupied_slots_per_agent: MutableMapping[str, ResourceSlot] = \ + defaultdict(lambda: ResourceSlot({'cpu': 0, 'mem': 0})) + + async def _recalc() -> None: + async with self.db.begin() as conn: + # Query running containers and calculate concurrency_used per AK and + # occupied_slots per agent. 
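+ # ("AK" = access key.) Two passes are made below: one over agent-occupying
+ # kernel statuses to rebuild per-agent occupied_slots, and one over
+ # user-occupying statuses to rebuild per-keypair concurrency_used.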
+ query = (
+ sa.select([kernels.c.access_key, kernels.c.agent, kernels.c.occupied_slots])
+ .where(kernels.c.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES))
+ .order_by(sa.asc(kernels.c.access_key))
+ )
+ async for row in (await conn.stream(query)):
+ occupied_slots_per_agent[row.agent] += ResourceSlot(row.occupied_slots)
+ query = (
+ sa.select([kernels.c.access_key, kernels.c.agent, kernels.c.occupied_slots])
+ .where(kernels.c.status.in_(USER_RESOURCE_OCCUPYING_KERNEL_STATUSES))
+ .order_by(sa.asc(kernels.c.access_key))
+ )
+ async for row in (await conn.stream(query)):
+ concurrency_used_per_key[row.access_key] += 1
+
+ if len(occupied_slots_per_agent) > 0:
+ # Update occupied_slots for agents with running containers.
+ for aid, slots in occupied_slots_per_agent.items():
+ query = (
+ sa.update(agents)
+ .values(occupied_slots=slots)
+ .where(agents.c.id == aid)
+ )
+ await conn.execute(query)
+ # Update all other agents to have empty occupied_slots.
+ query = (
+ sa.update(agents)
+ .values(occupied_slots=ResourceSlot({}))
+ .where(agents.c.status == AgentStatus.ALIVE)
+ .where(sa.not_(agents.c.id.in_(occupied_slots_per_agent.keys())))
+ )
+ await conn.execute(query)
+ else:
+ query = (
+ sa.update(agents)
+ .values(occupied_slots=ResourceSlot({}))
+ .where(agents.c.status == AgentStatus.ALIVE)
+ )
+ await conn.execute(query)
+
+ await execute_with_retry(_recalc)
+
+ # Update keypair resource usage for keypairs with running containers.
+ kp_key = 'keypair.concurrency_used'
+
+ async def _update(r: aioredis.Redis):
+ updates: Mapping[str, int] = \
+ {f'{kp_key}.{k}': concurrency_used_per_key[k] for k in concurrency_used_per_key}
+ if updates:
+ await r.mset(updates)
+
+ async def _update_by_fullscan(r: aioredis.Redis):
+ updates: Dict[str, int] = {}
+ keys = await r.keys(f'{kp_key}.*')
+ for ak in keys:
+ usage = concurrency_used_per_key.get(ak, 0)
+ updates[f'{kp_key}.{ak}'] = usage
+ if updates:
+ await r.mset(updates)
+
+ if do_fullscan:
+ await redis.execute(
+ self.redis_stat,
+ _update_by_fullscan,
+ )
+ else:
+ await redis.execute(
+ self.redis_stat,
+ _update,
+ )
+
+ async def destroy_session_lowlevel(
+ self,
+ session_id: SessionId,
+ kernels: Sequence[Row], # should have (id, agent, agent_addr, container_id) columns
+ ) -> None:
+ """
+ Destroy the kernels that belong to the given session unconditionally,
+ without generating any relevant events or invoking plugin hooks.
+ """
+ keyfunc = lambda item: item['agent'] if item['agent'] is not None else ''
+ for agent_id, group_iterator in itertools.groupby(
+ sorted(kernels, key=keyfunc), key=keyfunc,
+ ):
+ rpc_coros = []
+ destroyed_kernels = []
+ grouped_kernels = [*group_iterator]
+ for kernel in grouped_kernels:
+ if kernel['container_id'] is not None and kernel['agent_addr'] is not None:
+ destroyed_kernels.append(kernel)
+ if not destroyed_kernels:
+ return
+ async with RPCContext(
+ destroyed_kernels[0]['agent'],
+ destroyed_kernels[0]['agent_addr'],
+ invoke_timeout=None,
+ order_key=str(session_id),
+ keepalive_timeout=self.rpc_keepalive_timeout,
+ ) as rpc:
+ for kernel in destroyed_kernels:
+ # Internally, this enqueues a "destroy" lifecycle event.
+ rpc_coros.append(
+ rpc.call.destroy_kernel(
+ str(kernel['id']),
+ "failed-to-start",
+ suppress_events=True,
+ ),
+ )
+ await asyncio.gather(*rpc_coros)
+
+ async def destroy_session(
+ self,
+ session_getter: SessionGetter,
+ *,
+ forced: bool = False,
+ reason: str = 'user-requested',
+ ) -> Mapping[str, Any]:
+ """
+ Destroy session kernels.
Do not destroy + PREPARING/TERMINATING/ERROR and PULLING sessions. + + :param forced: If True, destroy PREPARING/TERMINATING/ERROR session. + However, PULLING session still cannot be destroyed. + :param reason: Reason to destroy a session if client wants to specify it manually. + """ + async with self.db.begin_readonly() as conn: + session = await session_getter(db_connection=conn) + if forced: + reason = 'force-terminated' + hook_result = await self.hook_plugin_ctx.dispatch( + 'PRE_DESTROY_SESSION', + (session['session_id'], session['session_name'], session['access_key']), + return_when=ALL_COMPLETED, + ) + if hook_result.status != PASSED: + raise RejectedByHook.from_hook_result(hook_result) + + async with self.handle_kernel_exception( + 'destroy_session', session['id'], session['access_key'], set_error=True, + ): + + async def _fetch() -> Sequence[Row]: + async with self.db.begin_readonly() as conn: + query = ( + sa.select([ + kernels.c.id, + kernels.c.session_id, + kernels.c.session_creation_id, + kernels.c.status, + kernels.c.access_key, + kernels.c.cluster_role, + kernels.c.agent, + kernels.c.agent_addr, + kernels.c.container_id, + ]) + .select_from(kernels) + .where(kernels.c.session_id == session['id']) + ) + result = await conn.execute(query) + kernel_list = result.fetchall() + return kernel_list + + kernel_list = await execute_with_retry(_fetch) + main_stat = {} + per_agent_tasks = [] + now = datetime.now(tzutc()) + + keyfunc = lambda item: item['agent'] if item['agent'] is not None else '' + for agent_id, group_iterator in itertools.groupby( + sorted(kernel_list, key=keyfunc), key=keyfunc, + ): + destroyed_kernels = [] + grouped_kernels = [*group_iterator] + for kernel in grouped_kernels: + if kernel['status'] == KernelStatus.PENDING: + + async def _update() -> None: + async with self.db.begin() as conn: + await conn.execute( + sa.update(kernels) + .values({ + 'status': KernelStatus.CANCELLED, + 'status_info': reason, + 'status_changed': now, + 'terminated_at': now, + }) + .where(kernels.c.id == kernel['id']), + ) + + await execute_with_retry(_update) + await self.event_producer.produce_event( + KernelCancelledEvent(kernel['id'], '', reason), + ) + if kernel['cluster_role'] == DEFAULT_ROLE: + main_stat = {'status': 'cancelled'} + await self.event_producer.produce_event( + SessionCancelledEvent( + kernel['session_id'], + kernel['session_creation_id'], + reason, + ), + ) + elif kernel['status'] == KernelStatus.PULLING: + raise GenericForbidden('Cannot destroy kernels in pulling status') + elif kernel['status'] in ( + KernelStatus.SCHEDULED, + KernelStatus.PREPARING, + KernelStatus.TERMINATING, + KernelStatus.ERROR, + ): + if not forced: + raise GenericForbidden( + 'Cannot destroy kernels in scheduled/preparing/terminating/error status', + ) + log.warning('force-terminating kernel (k:{}, status:{})', + kernel['id'], kernel['status']) + if kernel['container_id'] is not None: + destroyed_kernels.append(kernel) + + async def _update() -> None: + kern_stat = await redis.execute( + self.redis_stat, + lambda r: r.get(str(kernel['id'])), + ) + async with self.db.begin() as conn: + values = { + 'status': KernelStatus.TERMINATED, + 'status_info': reason, + 'status_changed': now, + 'terminated_at': now, + } + if kern_stat: + values['last_stat'] = msgpack.unpackb(kern_stat) + await conn.execute( + sa.update(kernels) + .values(values) + .where(kernels.c.id == kernel['id']), + ) + + if kernel['cluster_role'] == DEFAULT_ROLE: + # The main session is terminated; + # decrement the user's 
concurrency counter + await redis.execute( + self.redis_stat, + lambda r: r.incrby( + f"keypair.concurrency_used.{kernel['access_key']}", + -1, + ), + ) + + await execute_with_retry(_update) + await self.event_producer.produce_event( + KernelTerminatedEvent(kernel['id'], reason), + ) + else: + + async def _update() -> None: + async with self.db.begin() as conn: + await conn.execute( + sa.update(kernels) + .values({ + 'status': KernelStatus.TERMINATING, + 'status_info': reason, + 'status_changed': now, + 'status_data': { + "kernel": {"exit_code": None}, + "session": {"status": "terminating"}, + }, + }) + .where(kernels.c.id == kernel['id']), + ) + + if kernel['cluster_role'] == DEFAULT_ROLE: + # The main session is terminated; + # decrement the user's concurrency counter + await redis.execute( + self.redis_stat, + lambda r: r.incrby( + f"keypair.concurrency_used.{kernel['access_key']}", + -1, + ), + ) + + await execute_with_retry(_update) + await self.event_producer.produce_event( + KernelTerminatingEvent(kernel['id'], reason), + ) + + if kernel['agent_addr'] is None: + await self.mark_kernel_terminated(kernel['id'], 'missing-agent-allocation') + if kernel['cluster_role'] == DEFAULT_ROLE: + main_stat = {'status': 'terminated'} + else: + destroyed_kernels.append(kernel) + + async def _destroy_kernels_in_agent(session, destroyed_kernels) -> None: + nonlocal main_stat + async with RPCContext( + destroyed_kernels[0]['agent'], + destroyed_kernels[0]['agent_addr'], + invoke_timeout=None, + order_key=session['session_id'], + keepalive_timeout=self.rpc_keepalive_timeout, + ) as rpc: + rpc_coros = [] + for kernel in destroyed_kernels: + # internally it enqueues a "destroy" lifecycle event. + if kernel['status'] != KernelStatus.SCHEDULED: + rpc_coros.append( + rpc.call.destroy_kernel(str(kernel['id']), reason), + ) + try: + await asyncio.gather(*rpc_coros) + except Exception: + log.exception( + "destroy_kernels_in_agent(a:{}, s:{}): unexpected error", + destroyed_kernels[0]['agent'], + session['session_id'], + ) + for kernel in destroyed_kernels: + last_stat: Optional[Dict[str, Any]] + last_stat = None + try: + raw_last_stat = await redis.execute( + self.redis_stat, + lambda r: r.get(str(kernel['id']))) + if raw_last_stat is not None: + last_stat = msgpack.unpackb(raw_last_stat) + last_stat['version'] = 2 + except asyncio.TimeoutError: + pass + if kernel['cluster_role'] == DEFAULT_ROLE: + main_stat = { + **(last_stat if last_stat is not None else {}), + 'status': 'terminated', + } + + if destroyed_kernels: + per_agent_tasks.append(_destroy_kernels_in_agent(session, destroyed_kernels)) + + if per_agent_tasks: + await asyncio.gather(*per_agent_tasks, return_exceptions=True) + await self.hook_plugin_ctx.notify( + 'POST_DESTROY_SESSION', + (session['session_id'], session['session_name'], session['access_key']), + ) + if forced: + await self.recalc_resource_usage() + return main_stat + + async def clean_session( + self, + session_id: SessionId, + ) -> None: + + async def _fetch() -> Row: + async with self.db.begin_readonly() as conn: + query = ( + sa.select([ + kernels.c.session_id, + kernels.c.cluster_mode, + kernels.c.cluster_size, + kernels.c.agent, + kernels.c.agent_addr, + ]) + .select_from(kernels) + .where( + (kernels.c.session_id == session_id) & + (kernels.c.cluster_role == DEFAULT_ROLE), + ) + ) + result = await conn.execute(query) + return result.first() + + session = await execute_with_retry(_fetch) + if session is None: + return + if session['cluster_mode'] == ClusterMode.SINGLE_NODE and 
session['cluster_size'] > 1:
+ network_name = f'bai-singlenode-{session["session_id"]}'
+ try:
+ async with RPCContext(
+ session['agent'], # the main-container's agent
+ session['agent_addr'],
+ invoke_timeout=None,
+ order_key=session['session_id'],
+ keepalive_timeout=self.rpc_keepalive_timeout,
+ ) as rpc:
+ await rpc.call.destroy_local_network(network_name)
+ except Exception:
+ log.exception(f"Failed to destroy the agent-local network {network_name}")
+ elif session['cluster_mode'] == ClusterMode.MULTI_NODE:
+ network_name = f'bai-multinode-{session["session_id"]}'
+ try:
+ try:
+ # await rpc.call.destroy_overlay_network(network_name)
+ await asyncio.sleep(2.0)
+ network = await self.docker.networks.get(network_name)
+ await network.delete()
+ except aiodocker.DockerError as e:
+ if e.status == 404:
+ # It may have been auto-destructed when the last container was detached.
+ pass
+ else:
+ raise
+ except Exception:
+ log.exception(f"Failed to destroy the overlay network {network_name}")
+ else:
+ pass
+
+ async def restart_session(
+ self,
+ session_creation_id: str,
+ session_name_or_id: Union[str, SessionId],
+ access_key: AccessKey,
+ ) -> None:
+ log.warning('restart_session({})', session_name_or_id)
+ async with self.db.begin_readonly() as conn:
+ session_infos = await match_session_ids(
+ session_name_or_id,
+ access_key,
+ db_connection=conn,
+ )
+ if len(session_infos) > 1:
+ raise TooManySessionsMatched(extra_data={'matches': session_infos})
+ elif len(session_infos) == 0:
+ raise SessionNotFound()
+ session_id = session_infos[0]['session_id']
+ kernel_list = [row for row in await get_all_kernels(
+ [session_id],
+ db_connection=conn,
+ )][0]
+
+ async def _restart_kernel(kernel) -> None:
+ loop = asyncio.get_running_loop()
+ try:
+ kernel_creation_id = secrets.token_urlsafe(16)
+ start_future = loop.create_future()
+ self.kernel_creation_tracker[
+ kernel['id']
+ ] = start_future
+ try:
+ async with self.db.begin() as conn:
+ query = (
+ kernels.update()
+ .values({
+ 'status': KernelStatus.RESTARTING,
+ })
+ .where(kernels.c.id == kernel['id'])
+ )
+ await conn.execute(query)
+ async with RPCContext(
+ kernel['agent'], # the main-container's agent
+ kernel['agent_addr'],
+ invoke_timeout=None,
+ order_key=None,
+ keepalive_timeout=self.rpc_keepalive_timeout,
+ ) as rpc:
+ updated_config: Dict[str, Any] = {
+ # TODO: support rescaling of sub-containers
+ }
+ kernel_info = await rpc.call.restart_kernel(
+ kernel_creation_id,
+ str(kernel['session_id']),
+ str(kernel['id']),
+ updated_config,
+ )
+ await start_future
+ async with self.db.begin() as conn:
+ query = (
+ kernels.update()
+ .values({
+ 'status': KernelStatus.RUNNING,
+ 'container_id': kernel_info['container_id'],
+ 'repl_in_port': kernel_info['repl_in_port'],
+ 'repl_out_port': kernel_info['repl_out_port'],
+ 'stdin_port': kernel_info['stdin_port'],
+ 'stdout_port': kernel_info['stdout_port'],
+ 'service_ports': kernel_info.get('service_ports', []),
+ })
+ .where(kernels.c.id == kernel['id'])
+ )
+ await conn.execute(query)
+ finally:
+ del self.kernel_creation_tracker[
+ kernel['id']
+ ]
+ except Exception:
+ log.exception('unexpected error in _restart_kernel()')
+
+ restart_coros = []
+ for kernel in kernel_list:
+ restart_coros.append(_restart_kernel(kernel))
+ async with self.handle_kernel_exception(
+ 'restart_session', session_id, access_key, set_error=True,
+ ):
+ await asyncio.gather(*restart_coros)
+
+ # NOTE: If the restarted session is a batch-type one, then the startup command
+ # will be executed
again after restart. + await self.event_producer.produce_event( + SessionStartedEvent(session_id, session_creation_id), + ) + + async def execute( + self, + session_name_or_id: Union[str, SessionId], + access_key: AccessKey, + api_version: Tuple[int, str], + run_id: str, + mode: str, + code: str, + opts: Mapping[str, Any], + *, + flush_timeout: float = None, + ) -> Mapping[str, Any]: + kernel = await self.get_session(session_name_or_id, access_key) + async with self.handle_kernel_exception('execute', kernel['id'], access_key): + # The agent aggregates at most 2 seconds of outputs + # if the kernel runs for a long time. + major_api_version = api_version[0] + if major_api_version == 4: # manager-agent protocol is same. + major_api_version = 3 + async with RPCContext( + kernel['agent'], + kernel['agent_addr'], + invoke_timeout=30, + order_key=kernel['id'], + keepalive_timeout=self.rpc_keepalive_timeout, + ) as rpc: + return await rpc.call.execute( + str(kernel['id']), + major_api_version, + run_id, mode, code, opts, + flush_timeout, + ) + + async def interrupt_session( + self, + session_name_or_id: Union[str, SessionId], + access_key: AccessKey, + ) -> Mapping[str, Any]: + kernel = await self.get_session(session_name_or_id, access_key) + async with self.handle_kernel_exception('execute', kernel['id'], access_key): + async with RPCContext( + kernel['agent'], + kernel['agent_addr'], + invoke_timeout=30, + order_key=kernel['id'], + keepalive_timeout=self.rpc_keepalive_timeout, + ) as rpc: + return await rpc.call.interrupt_kernel(str(kernel['id'])) + + async def get_completions( + self, + session_name_or_id: Union[str, SessionId], + access_key: AccessKey, + text: str, + opts: Mapping[str, Any], + ) -> Mapping[str, Any]: + kernel = await self.get_session(session_name_or_id, access_key) + async with self.handle_kernel_exception('execute', kernel['id'], access_key): + async with RPCContext( + kernel['agent'], + kernel['agent_addr'], + invoke_timeout=10, + order_key=kernel['id'], + keepalive_timeout=self.rpc_keepalive_timeout, + ) as rpc: + return await rpc.call.get_completions(str(kernel['id']), text, opts) + + async def start_service( + self, + session_name_or_id: Union[str, SessionId], + access_key: AccessKey, + service: str, + opts: Mapping[str, Any], + ) -> Mapping[str, Any]: + kernel = await self.get_session(session_name_or_id, access_key) + async with self.handle_kernel_exception('execute', kernel['id'], access_key): + async with RPCContext( + kernel['agent'], + kernel['agent_addr'], + invoke_timeout=None, + order_key=kernel['id'], + keepalive_timeout=self.rpc_keepalive_timeout, + ) as rpc: + return await rpc.call.start_service(str(kernel['id']), service, opts) + + async def shutdown_service( + self, + session_name_or_id: Union[str, SessionId], + access_key: AccessKey, + service: str, + ) -> None: + kernel = await self.get_session(session_name_or_id, access_key) + async with self.handle_kernel_exception('shutdown_service', kernel['id'], access_key): + async with RPCContext( + kernel['agent'], + kernel['agent_addr'], + invoke_timeout=None, + order_key=kernel['id'], + keepalive_timeout=self.rpc_keepalive_timeout, + ) as rpc: + return await rpc.call.shutdown_service(str(kernel['id']), service) + + async def upload_file( + self, + session_name_or_id: Union[str, SessionId], + access_key: AccessKey, + filename: str, + payload: bytes, + ) -> Mapping[str, Any]: + kernel = await self.get_session(session_name_or_id, access_key) + async with self.handle_kernel_exception('upload_file', kernel['id'], 
access_key): + async with RPCContext( + kernel['agent'], + kernel['agent_addr'], + invoke_timeout=None, + order_key=kernel['id'], + keepalive_timeout=self.rpc_keepalive_timeout, + ) as rpc: + return await rpc.call.upload_file(str(kernel['id']), filename, payload) + + async def download_file( + self, + session_name_or_id: Union[str, SessionId], + access_key: AccessKey, + filepath: str, + ) -> bytes: + kernel = await self.get_session(session_name_or_id, access_key) + async with self.handle_kernel_exception('download_file', kernel['id'], + access_key): + async with RPCContext( + kernel['agent'], + kernel['agent_addr'], + invoke_timeout=None, + order_key=kernel['id'], + keepalive_timeout=self.rpc_keepalive_timeout, + ) as rpc: + return await rpc.call.download_file(str(kernel['id']), filepath) + + async def list_files( + self, + session_name_or_id: Union[str, SessionId], + access_key: AccessKey, + path: str, + ) -> Mapping[str, Any]: + kernel = await self.get_session(session_name_or_id, access_key) + async with self.handle_kernel_exception('list_files', kernel['id'], access_key): + async with RPCContext( + kernel['agent'], + kernel['agent_addr'], + invoke_timeout=30, + order_key=kernel['id'], + keepalive_timeout=self.rpc_keepalive_timeout, + ) as rpc: + return await rpc.call.list_files(str(kernel['id']), path) + + async def get_logs_from_agent( + self, + session_name_or_id: Union[str, SessionId], + access_key: AccessKey, + ) -> str: + kernel = await self.get_session(session_name_or_id, access_key) + async with self.handle_kernel_exception('get_logs_from_agent', kernel['id'], access_key): + async with RPCContext( + kernel['agent'], + kernel['agent_addr'], + invoke_timeout=30, + order_key=kernel['id'], + keepalive_timeout=self.rpc_keepalive_timeout, + ) as rpc: + reply = await rpc.call.get_logs(str(kernel['id'])) + return reply['logs'] + + async def increment_session_usage( + self, + session_name: str, + access_key: AccessKey, + conn: SAConnection = None, + ) -> None: + pass + # async with reenter_txn(self.db, conn) as conn: + # query = ( + # sa.update(kernels) + # .values(num_queries=kernels.c.num_queries + 1) + # .where( + # (kernels.c.session_name == session_name) & + # (kernels.c.access_key == access_key) & + # (kernels.c.cluster_role == DEFAULT_ROLE) + # ) + # ) + # await execute_with_retry(conn, query) + + async def kill_all_sessions_in_agent(self, agent_id, agent_addr): + async with RPCContext( + agent_id, + agent_addr, + invoke_timeout=None, + keepalive_timeout=self.rpc_keepalive_timeout, + ) as rpc: + coro = rpc.call.clean_all_kernels('manager-freeze-force-kill') + return await coro + + async def kill_all_sessions(self, conn=None): + async with reenter_txn(self.db, conn, {'postgresql_readonly': True}) as conn: + query = (sa.select([agents.c.id, agents.c.addr]) + .where(agents.c.status == AgentStatus.ALIVE)) + result = await conn.execute(query) + rows = result.fetchall() + tasks = [] + for row in rows: + tasks.append( + self.kill_all_sessions_in_agent(row['id'], row['addr']), + ) + await asyncio.gather(*tasks, return_exceptions=True) + + async def handle_heartbeat(self, agent_id, agent_info): + now = datetime.now(tzutc()) + slot_key_and_units = { + SlotName(k): SlotTypes(v[0]) for k, v in + agent_info['resource_slots'].items()} + available_slots = ResourceSlot({ + SlotName(k): Decimal(v[1]) for k, v in + agent_info['resource_slots'].items()}) + current_addr = agent_info['addr'] + sgroup = agent_info.get('scaling_group', 'default') + async with self.heartbeat_lock: + + instance_rejoin = 
False + + # Update "last seen" timestamp for liveness tracking + await redis.execute( + self.redis_live, + lambda r: r.hset('agent.last_seen', agent_id, now.timestamp()), + ) + + # Check and update status of the agent record in DB + async def _update() -> None: + nonlocal instance_rejoin + async with self.db.begin() as conn: + fetch_query = ( + sa.select([ + agents.c.status, + agents.c.addr, + agents.c.scaling_group, + agents.c.available_slots, + agents.c.version, + agents.c.compute_plugins, + agents.c.architecture, + ]) + .select_from(agents) + .where(agents.c.id == agent_id) + .with_for_update() + ) + result = await conn.execute(fetch_query) + row = result.first() + + if row is None or row['status'] is None: + # new agent detected! + log.info('agent {0} joined!', agent_id) + await self.shared_config.update_resource_slots(slot_key_and_units) + insert_query = sa.insert(agents).values({ + 'id': agent_id, + 'status': AgentStatus.ALIVE, + 'region': agent_info['region'], + 'scaling_group': sgroup, + 'available_slots': available_slots, + 'occupied_slots': {}, + 'addr': agent_info['addr'], + 'first_contact': now, + 'lost_at': sa.null(), + 'version': agent_info['version'], + 'compute_plugins': agent_info['compute_plugins'], + 'architecture': agent_info.get('architecture', 'x86_64'), + }) + result = await conn.execute(insert_query) + assert result.rowcount == 1 + elif row['status'] == AgentStatus.ALIVE: + updates = {} + if row['available_slots'] != available_slots: + updates['available_slots'] = available_slots + if row['scaling_group'] != sgroup: + updates['scaling_group'] = sgroup + if row['addr'] != current_addr: + updates['addr'] = current_addr + if row['version'] != agent_info['version']: + updates['version'] = agent_info['version'] + if row['compute_plugins'] != agent_info['compute_plugins']: + updates['compute_plugins'] = agent_info['compute_plugins'] + if row['architecture'] != agent_info['architecture']: + updates['architecture'] = agent_info['architecture'] + # occupied_slots are updated when kernels starts/terminates + if updates: + await self.shared_config.update_resource_slots(slot_key_and_units) + update_query = ( + sa.update(agents) + .values(updates) + .where(agents.c.id == agent_id) + ) + await conn.execute(update_query) + elif row['status'] in (AgentStatus.LOST, AgentStatus.TERMINATED): + await self.shared_config.update_resource_slots(slot_key_and_units) + instance_rejoin = True + update_query = ( + sa.update(agents) + .values({ + 'status': AgentStatus.ALIVE, + 'region': agent_info['region'], + 'scaling_group': sgroup, + 'addr': agent_info['addr'], + 'lost_at': sa.null(), + 'available_slots': available_slots, + 'version': agent_info['version'], + 'compute_plugins': agent_info['compute_plugins'], + 'architecture': agent_info['architecture'], + }) + .where(agents.c.id == agent_id) + ) + await conn.execute(update_query) + else: + log.error('should not reach here! {0}', type(row['status'])) + + try: + await execute_with_retry(_update) + except sa.exc.IntegrityError: + log.error("Scaling group named [{}] does not exist.", sgroup) + return + + if instance_rejoin: + await self.event_producer.produce_event( + AgentStartedEvent('revived'), + source=agent_id, + ) + + # Update the mapping of kernel images to agents. 
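+ # Each image's canonical reference becomes a Redis set whose members are the IDs
+ # of agents that have the image loaded; presumably other components use this to
+ # find agents that already hold a given image.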
+ known_registries = await get_known_registries(self.shared_config.etcd) + loaded_images = msgpack.unpackb(snappy.decompress(agent_info['images'])) + + def _pipe_builder(r: aioredis.Redis): + pipe = r.pipeline() + for image in loaded_images: + image_ref = ImageRef(image[0], known_registries, agent_info['architecture']) + pipe.sadd(image_ref.canonical, agent_id) + return pipe + await redis.execute(self.redis_image, _pipe_builder) + + await self.hook_plugin_ctx.notify( + 'POST_AGENT_HEARTBEAT', + (agent_id, sgroup, available_slots), + ) + + async def mark_agent_terminated(self, agent_id: AgentId, status: AgentStatus) -> None: + await redis.execute(self.redis_live, lambda r: r.hdel('agent.last_seen', agent_id)) + + async def _pipe_builder(r: aioredis.Redis): + pipe = r.pipeline() + async for imgname in r.scan_iter(): + pipe.srem(imgname, agent_id) + return pipe + + async def _update() -> None: + async with self.db.begin() as conn: + fetch_query = ( + sa.select([ + agents.c.status, + agents.c.addr, + ]) + .select_from(agents) + .where(agents.c.id == agent_id) + .with_for_update() + ) + result = await conn.execute(fetch_query) + row = result.first() + prev_status = row['status'] + if prev_status in (None, AgentStatus.LOST, AgentStatus.TERMINATED): + return + + if status == AgentStatus.LOST: + log.warning('agent {0} heartbeat timeout detected.', agent_id) + elif status == AgentStatus.TERMINATED: + log.info('agent {0} has terminated.', agent_id) + now = datetime.now(tzutc()) + update_query = ( + sa.update(agents) + .values({ + 'status': status, + 'status_changed': now, + 'lost_at': now, + }) + .where(agents.c.id == agent_id) + ) + await conn.execute(update_query) + + await redis.execute(self.redis_image, _pipe_builder) + await execute_with_retry(_update) + + async def set_session_status( + self, + session_id: SessionId, + access_key: AccessKey, + status: KernelStatus, + reason: str = '', + **extra_fields, + ) -> None: + now = datetime.now(tzutc()) + data = { + 'status': status, + 'status_info': reason, + 'status_changed': now, + } + if status in (KernelStatus.CANCELLED, KernelStatus.TERMINATED): + data['terminated_at'] = now + data.update(extra_fields) + + async def _update() -> None: + async with self.db.begin() as conn: + query = ( + sa.update(kernels) + .values(data) + .where( + (kernels.c.session_id == session_id) & + (kernels.c.access_key == access_key) & + ~(kernels.c.status.in_(DEAD_KERNEL_STATUSES)), + ) + ) + await conn.execute(query) + + await execute_with_retry(_update) + + async def set_kernel_status( + self, kernel_id: KernelId, + status: KernelStatus, + reason: str = '', + ) -> None: + assert status != KernelStatus.TERMINATED, \ + 'TERMINATED status update must be handled in ' \ + 'mark_kernel_terminated()' + now = datetime.now(tzutc()) + data = { + 'status': status, + 'status_info': reason, + 'status_changed': now, + } + if status in (KernelStatus.CANCELLED, KernelStatus.TERMINATED): + data['terminated_at'] = now + + async def _update() -> None: + async with self.db.begin() as conn: + query = ( + sa.update(kernels) + .values(data) + .where(kernels.c.id == kernel_id) + ) + await conn.execute(query) + + await execute_with_retry(_update) + + async def set_session_result( + self, + session_id: SessionId, + success: bool, + exit_code: int, + ) -> None: + # TODO: store exit code? 
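+ # Only the overall pass/fail result is stored here; the exit code itself is kept
+ # in the kernel's status_data by mark_kernel_terminated() below.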
+ data = { + 'result': SessionResult.SUCCESS if success else SessionResult.FAILURE, + } + + async def _update() -> None: + async with self.db.begin() as conn: + query = ( + sa.update(kernels) + .values(data) + .where(kernels.c.id == session_id) + ) + await conn.execute(query) + + await execute_with_retry(_update) + + async def sync_kernel_stats( + self, kernel_ids: Sequence[KernelId], + ) -> None: + per_kernel_updates = {} + log.debug('sync_kernel_stats(k:{!r})', kernel_ids) + for kernel_id in kernel_ids: + raw_kernel_id = str(kernel_id) + kern_stat = await redis.execute( + self.redis_stat, + lambda r: r.get(raw_kernel_id), + ) + if kern_stat is None: + log.warning('sync_kernel_stats(k:{}): no statistics updates', kernel_id) + continue + else: + per_kernel_updates[kernel_id] = msgpack.unpackb(kern_stat) + + async def _update(): + async with self.db.begin() as conn: + update_query = ( + sa.update(kernels) + .where(kernels.c.id == sa.bindparam('kernel_id')) + .values({kernels.c.last_stat: sa.bindparam('last_stat')}) + ) + params = [] + for kernel_id, updates in per_kernel_updates.items(): + params.append({ + 'kernel_id': kernel_id, + 'last_stat': updates, + }) + await conn.execute(update_query, params) + + if per_kernel_updates: + await execute_with_retry(_update) + + async def mark_kernel_terminated( + self, + kernel_id: KernelId, + reason: str, + exit_code: int = None, + ) -> None: + """ + Mark the kernel (individual worker) terminated and release + the resource slots occupied by it. + """ + post_task = self._post_kernel_creation_tasks.get(kernel_id, None) + if post_task is not None and not post_task.done(): + post_task.cancel() + try: + await post_task + except asyncio.CancelledError: + pass + + kern_stat = await redis.execute( + self.redis_stat, + lambda r: r.get(str(kernel_id)), + ) + + async def _update_kernel_status() -> Row | None: + async with self.db.begin() as conn: + # Check the current status. + select_query = ( + sa.select([ + kernels.c.access_key, + kernels.c.agent, + kernels.c.status, + kernels.c.occupied_slots, + kernels.c.session_id, + ]) + .select_from(kernels) + .where(kernels.c.id == kernel_id) + .with_for_update() + ) + result = await conn.execute(select_query) + kernel = result.first() + if ( + kernel is None + or kernel['status'] in ( + KernelStatus.CANCELLED, + KernelStatus.TERMINATED, + KernelStatus.RESTARTING, + ) + ): + # Skip if non-existent, already terminated, or restarting. + return None + + # Change the status to TERMINATED. + # (we don't delete the row for later logging and billing) + now = datetime.now(tzutc()) + values = { + 'status': KernelStatus.TERMINATED, + 'status_info': reason, + 'status_changed': now, + 'status_data': sql_json_merge( + kernels.c.status_data, + ("kernel",), + {"exit_code": exit_code}, + ), + 'terminated_at': now, + } + if kern_stat: + values['last_stat'] = msgpack.unpackb(kern_stat) + update_query = ( + sa.update(kernels) + .values(values) + .where(kernels.c.id == kernel_id) + ) + await conn.execute(update_query) + return kernel + + kernel = await execute_with_retry(_update_kernel_status) + if kernel is None: + return + + async def _recalc() -> None: + assert kernel is not None + async with self.db.begin() as conn: + await recalc_concurrency_used(conn, self.redis_stat, kernel['access_key']) + await recalc_agent_resource_occupancy(conn, kernel['agent']) + + await execute_with_retry(_recalc) + + # Perform statistics sync in a separate transaction block, since + # it may take a while to fetch stats from Redis. 
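+ # (sync_kernel_stats() below fetches the stats from Redis first and then runs its
+ # own retried transaction to update last_stat.)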
+ + await self.sync_kernel_stats([kernel_id]) + + async def check_session_terminated( + self, + kernel_id: KernelId, + reason: str, + ) -> None: + + async def _check_and_mark() -> Tuple[bool, SessionId | None]: + async with self.db.begin() as conn: + session_id_query = ( + sa.select([ + kernels.c.session_id, + ]) + .select_from(kernels) + .where(kernels.c.id == kernel_id) + ) + kernels_query = ( + sa.select([ + kernels.c.session_id, + kernels.c.status_data, + kernels.c.status, + ]) + .select_from(kernels) + .where( + (kernels.c.session_id == session_id_query.scalar_subquery()), + ) + .with_for_update() + ) + result = await conn.execute(kernels_query) + rows = result.fetchall() + if not rows: + return False, None + session_id = rows[0]['session_id'] + if nmget(rows[0]['status_data'], "session.status") == "terminated": + # if already marked "session-terminated", skip the rest process + return False, session_id + all_terminated = all(map( + lambda row: row['status'] in (KernelStatus.TERMINATED, KernelStatus.CANCELLED), + rows, + )) + if all_terminated: + await conn.execute( + sa.update(kernels) + .values( + status_data=sql_json_merge( + kernels.c.status_data, + ("session",), + { + "status": "terminated", + }, + ), + ) + .where( + (kernels.c.session_id == session_id), + ), + ) + return all_terminated, session_id + + do_fire_event, session_id = await execute_with_retry(_check_and_mark) + if session_id is None: + return + if do_fire_event: + await self.event_producer.produce_event( + SessionTerminatedEvent(session_id, reason), + ) + + async def mark_session_terminated( + self, + session_id: SessionId, + reason: str, + ) -> None: + await self.clean_session(session_id) diff --git a/src/ai/backend/manager/scheduler/__init__.py b/src/ai/backend/manager/scheduler/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/ai/backend/manager/scheduler/dispatcher.py b/src/ai/backend/manager/scheduler/dispatcher.py new file mode 100644 index 0000000000..ea3b7b71c5 --- /dev/null +++ b/src/ai/backend/manager/scheduler/dispatcher.py @@ -0,0 +1,1019 @@ +from __future__ import annotations + +import asyncio +import logging +from contextvars import ContextVar +from datetime import datetime, timedelta +from typing import ( + Any, + Awaitable, + Final, + List, + Sequence, + Tuple, + Union, + TYPE_CHECKING, + Optional, +) + +import aiotools +from dateutil.tz import tzutc +import sqlalchemy as sa +from sqlalchemy.engine.row import Row +from sqlalchemy.exc import DBAPIError +from sqlalchemy.ext.asyncio import ( + AsyncConnection as SAConnection, +) +from sqlalchemy.sql.expression import true + +from ai.backend.common.distributed import GlobalTimer +from ai.backend.common.events import ( + AgentStartedEvent, + CoalescingOptions, + DoScheduleEvent, + DoPrepareEvent, + SessionCancelledEvent, + SessionEnqueuedEvent, + SessionPreparingEvent, + SessionScheduledEvent, + SessionTerminatedEvent, + EventDispatcher, + EventProducer, +) +from ai.backend.common.logging import BraceStyleAdapter +from ai.backend.common.types import ( + aobject, + AgentId, + ClusterMode, + ResourceSlot, +) +from ai.backend.plugin.entrypoint import scan_entrypoints + +from ai.backend.manager.types import DistributedLockFactory + +from ..api.exceptions import GenericBadRequest, InstanceNotAvailable +from ..defs import ( + LockID, +) +from ..exceptions import convert_to_status_data +from ..models import ( + agents, kernels, scaling_groups, + recalc_agent_resource_occupancy, + recalc_concurrency_used, + AgentStatus, KernelStatus, + 
AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES, +) +from ..models.scaling_group import ScalingGroupOpts +from ..models.utils import ( + ExtendedAsyncSAEngine as SAEngine, + execute_with_retry, + sql_json_increment, + sql_json_merge, +) +from .types import ( + PredicateResult, + PendingSession, + ExistingSession, + SchedulingContext, + AgentContext, + AgentAllocationContext, + AbstractScheduler, + KernelAgentBinding, +) +from .predicates import ( + check_reserved_batch_session, + check_concurrency, + check_dependencies, + check_keypair_resource_limit, + check_group_resource_limit, + check_domain_resource_limit, + check_scaling_group, +) + +if TYPE_CHECKING: + from ..config import LocalConfig, SharedConfig + from ..registry import AgentRegistry + +__all__ = ( + 'load_scheduler', + 'SchedulerDispatcher', +) + +log = BraceStyleAdapter(logging.getLogger('ai.backend.manager.scheduler')) + +_log_fmt: ContextVar[str] = ContextVar('_log_fmt') +_log_args: ContextVar[Tuple[Any, ...]] = ContextVar('_log_args') + +_key_schedule_prep_tasks: Final = "scheduler.preptasks" + + +def load_scheduler( + name: str, + sgroup_opts: ScalingGroupOpts, + scheduler_config: dict[str, Any], +) -> AbstractScheduler: + entry_prefix = 'backendai_scheduler_v10' + for entrypoint in scan_entrypoints(entry_prefix): + if entrypoint.name == name: + log.debug('loading scheduler plugin "{}" from {}', name, entrypoint.module) + scheduler_cls = entrypoint.load() + return scheduler_cls(sgroup_opts, scheduler_config) + raise ImportError('Cannot load the scheduler plugin', name) + + +StartTaskArgs = Tuple[ + Tuple[Any, ...], + SchedulingContext, + Tuple[PendingSession, List[KernelAgentBinding]], + List[Tuple[str, Union[Exception, PredicateResult]]], +] + + +class SchedulerDispatcher(aobject): + + config: LocalConfig + shared_config: SharedConfig + registry: AgentRegistry + db: SAEngine + + event_dispatcher: EventDispatcher + event_producer: EventProducer + schedule_timer: GlobalTimer + prepare_timer: GlobalTimer + + def __init__( + self, + local_config: LocalConfig, + shared_config: SharedConfig, + event_dispatcher: EventDispatcher, + event_producer: EventProducer, + lock_factory: DistributedLockFactory, + registry: AgentRegistry, + ) -> None: + self.local_config = local_config + self.shared_config = shared_config + self.event_dispatcher = event_dispatcher + self.event_producer = event_producer + self.registry = registry + self.lock_factory = lock_factory + self.db = registry.db + + async def __ainit__(self) -> None: + coalescing_opts: CoalescingOptions = { + 'max_wait': 0.5, + 'max_batch_size': 32, + } + # coalescing_opts = None + evd = self.registry.event_dispatcher + evd.consume(SessionEnqueuedEvent, None, self.schedule, coalescing_opts, name="dispatcher.enq") + evd.consume(SessionTerminatedEvent, None, self.schedule, coalescing_opts, name="dispatcher.term") + evd.consume(AgentStartedEvent, None, self.schedule) + evd.consume(DoScheduleEvent, None, self.schedule, coalescing_opts) + evd.consume(DoPrepareEvent, None, self.prepare) + self.schedule_timer = GlobalTimer( + self.lock_factory(LockID.LOCKID_SCHEDULE_TIMER, 10.0), + self.event_producer, + lambda: DoScheduleEvent(), + interval=10.0, + ) + self.prepare_timer = GlobalTimer( + self.lock_factory(LockID.LOCKID_PREPARE_TIMER, 10.0), + self.event_producer, + lambda: DoPrepareEvent(), + interval=10.0, + initial_delay=5.0, + ) + await self.schedule_timer.join() + await self.prepare_timer.join() + log.info('Session scheduler started') + + async def close(self) -> None: + async with 
aiotools.TaskGroup() as tg: + tg.create_task(self.prepare_timer.leave()) + tg.create_task(self.schedule_timer.leave()) + log.info('Session scheduler stopped') + + async def schedule( + self, + context: None, + source: AgentId, + event: SessionEnqueuedEvent | SessionTerminatedEvent | AgentStartedEvent | DoScheduleEvent, + ) -> None: + """ + Trigger the scheduler to scan pending sessions and mark them scheduled if they fulfill + the scheduling requirements. + + HoL blocking issue due to indefinitely preparing sessions will be mitigated because + they will be treated as already "scheduled" sessions and the scheduler will continue to + work on other pending sessions. + + Session status transition: PENDING -> SCHEDULED + """ + log.debug('schedule(): triggered') + known_slot_types = await self.shared_config.get_resource_slots() + sched_ctx = SchedulingContext( + registry=self.registry, + known_slot_types=known_slot_types, + ) + + try: + # The schedule() method should be executed with a global lock + # as its individual steps are composed of many short-lived transactions. + async with self.lock_factory(LockID.LOCKID_SCHEDULE, 60): + async with self.db.begin_readonly() as conn: + query = ( + sa.select([agents.c.scaling_group]) + .select_from(agents) + .where(agents.c.status == AgentStatus.ALIVE) + .group_by(agents.c.scaling_group) + ) + result = await conn.execute(query) + schedulable_scaling_groups = [ + row.scaling_group for row in result.fetchall() + ] + for sgroup_name in schedulable_scaling_groups: + try: + await self._schedule_in_sgroup( + sched_ctx, sgroup_name, + ) + except InstanceNotAvailable: + # Proceed to the next scaling group and come back later. + log.debug('schedule({}): instance not available', sgroup_name) + except Exception as e: + log.exception('schedule({}): scheduling error!\n{}', sgroup_name, repr(e)) + except DBAPIError as e: + if getattr(e.orig, 'pgcode', None) == '55P03': + log.info("schedule(): cancelled due to advisory lock timeout; " + "maybe another schedule() call is still running") + raise asyncio.CancelledError() + raise + + async def _load_scheduler( + self, + db_conn: SAConnection, + sgroup_name: str, + ) -> AbstractScheduler: + query = ( + sa.select([scaling_groups.c.scheduler, scaling_groups.c.scheduler_opts]) + .select_from(scaling_groups) + .where(scaling_groups.c.name == sgroup_name) + ) + result = await db_conn.execute(query) + row = result.first() + scheduler_name = row['scheduler'] + sgroup_opts: ScalingGroupOpts = row['scheduler_opts'] + global_scheduler_opts = {} + if self.shared_config['plugins']['scheduler']: + global_scheduler_opts = self.shared_config['plugins']['scheduler'].get(scheduler_name, {}) + scheduler_specific_config = {**global_scheduler_opts, **sgroup_opts.config} + return load_scheduler(scheduler_name, sgroup_opts, scheduler_specific_config) + + async def _schedule_in_sgroup( + self, + sched_ctx: SchedulingContext, + sgroup_name: str, + ) -> None: + async with self.db.begin_readonly() as kernel_db_conn: + scheduler = await self._load_scheduler(kernel_db_conn, sgroup_name) + pending_session_rows, cancelled_session_rows = \ + await _list_pending_sessions(kernel_db_conn, scheduler, sgroup_name) + pending_sessions = PendingSession.from_rows(pending_session_rows) + existing_sessions = await _list_existing_sessions(kernel_db_conn, sgroup_name) + + if cancelled_session_rows: + now = datetime.now(tzutc()) + + async def _apply_cancellation(): + async with self.db.begin() as db_conn: + query = kernels.update().values({ + 'status': 
KernelStatus.CANCELLED, + 'status_changed': now, + 'status_info': "pending-timeout", + 'terminated_at': now, + }).where(kernels.c.session_id.in_([ + item['session_id'] for item in cancelled_session_rows + ])) + await db_conn.execute(query) + + await execute_with_retry(_apply_cancellation) + for item in cancelled_session_rows: + await self.event_producer.produce_event( + SessionCancelledEvent( + item['session_id'], + item['session_creation_id'], + reason="pending timeout", + ), + ) + + log.debug( + "running scheduler (sgroup:{}, pending:{}, existing:{}, cancelled:{})", + sgroup_name, len(pending_sessions), len(existing_sessions), len(cancelled_session_rows), + ) + zero = ResourceSlot() + num_scheduled = 0 + while len(pending_sessions) > 0: + + async with self.db.begin_readonly() as conn: + candidate_agents = await _list_agents_by_sgroup(conn, sgroup_name) + total_capacity = sum((ag.available_slots for ag in candidate_agents), zero) + + picked_session_id = scheduler.pick_session( + total_capacity, + pending_sessions, + existing_sessions, + ) + if picked_session_id is None: + # no session is picked. + # continue to next sgroup. + return + for picked_idx, sess_ctx in enumerate(pending_sessions): + if sess_ctx.session_id == picked_session_id: + break + else: + # no matching entry for picked session? + raise RuntimeError('should not reach here') + sess_ctx = pending_sessions.pop(picked_idx) + requested_architectures = set([ + x.image_ref.architecture for x in sess_ctx.kernels + ]) + candidate_agents = list( + filter( + lambda x: x.architecture in requested_architectures, + candidate_agents, + ), + ) + + log_fmt = 'schedule(s:{}, type:{}, name:{}, ak:{}, cluster_mode:{}): ' + log_args = ( + sess_ctx.session_id, + sess_ctx.session_type, + sess_ctx.session_name, + sess_ctx.access_key, + sess_ctx.cluster_mode, + ) + _log_fmt.set(log_fmt) + _log_args.set(log_args) + log.debug(log_fmt + 'try-scheduling', *log_args) + + async def _check_predicates() -> List[Tuple[str, Union[Exception, PredicateResult]]]: + check_results: List[Tuple[str, Union[Exception, PredicateResult]]] = [] + async with self.db.begin() as kernel_db_conn: + predicates: Sequence[Tuple[str, Awaitable[PredicateResult]]] = [ + ( + 'reserved_time', + check_reserved_batch_session(kernel_db_conn, sched_ctx, sess_ctx), + ), + ('concurrency', check_concurrency(kernel_db_conn, sched_ctx, sess_ctx)), + ('dependencies', check_dependencies(kernel_db_conn, sched_ctx, sess_ctx)), + ( + 'keypair_resource_limit', + check_keypair_resource_limit(kernel_db_conn, sched_ctx, sess_ctx), + ), + ( + 'user_group_resource_limit', + check_group_resource_limit(kernel_db_conn, sched_ctx, sess_ctx), + ), + ( + 'domain_resource_limit', + check_domain_resource_limit(kernel_db_conn, sched_ctx, sess_ctx), + ), + ( + 'scaling_group_resource_limit', + check_scaling_group(kernel_db_conn, sched_ctx, sess_ctx), + ), + ] + for predicate_name, check_coro in predicates: + try: + check_results.append((predicate_name, await check_coro)) + except DBAPIError: + raise + except Exception as e: + log.exception(log_fmt + 'predicate-error', *log_args) + check_results.append((predicate_name, e)) + return check_results + + check_results = await execute_with_retry(_check_predicates) + has_failure = False + has_permanent_failure = False + failed_predicates = [] + passed_predicates = [] + for predicate_name, result in check_results: + if isinstance(result, Exception): + has_failure = True + failed_predicates.append({ + 'name': predicate_name, + 'msg': repr(result), + }) + continue + if 
result.passed: + passed_predicates.append({ + 'name': predicate_name, + }) + else: + failed_predicates.append({ + 'name': predicate_name, + 'msg': result.message or "", + }) + has_failure = True + if result.permanent: + has_permanent_failure = True # noqa + if has_failure: + log.debug(log_fmt + 'predicate-checks-failed (temporary)', *log_args) + # TODO: handle has_permanent_failure as cancellation + # - An early implementation of it has caused DB query blocking due to + # the inclusion of the kernels.status field. :( + # Let's fix it. + + async def _update() -> None: + async with self.db.begin() as conn: + await _rollback_predicate_mutations( + conn, sched_ctx, sess_ctx, + ) + query = kernels.update().values({ + 'status_info': "predicate-checks-failed", + 'status_data': sql_json_increment( + kernels.c.status_data, + ('scheduler', 'retries'), + parent_updates={ + 'last_try': datetime.now(tzutc()).isoformat(), + 'failed_predicates': failed_predicates, + 'passed_predicates': passed_predicates, + }, + ), + }).where(kernels.c.id == sess_ctx.session_id) + await conn.execute(query) + + await execute_with_retry(_update) + # Predicate failures are *NOT* permanent errors. + # We need to retry the scheduling afterwards. + continue + else: + async def _update() -> None: + async with self.db.begin() as conn: + query = kernels.update().values({ + 'status_data': sql_json_merge( + kernels.c.status_data, + ('scheduler',), + { + 'last_try': datetime.now(tzutc()).isoformat(), + 'failed_predicates': failed_predicates, + 'passed_predicates': passed_predicates, + }, + ), + }).where(kernels.c.id == sess_ctx.session_id) + await conn.execute(query) + + await execute_with_retry(_update) + + if sess_ctx.cluster_mode == ClusterMode.SINGLE_NODE: + # Single node session can't have multiple containers with different arch + if len(requested_architectures) > 1: + raise GenericBadRequest( + 'Cannot assign multiple kernels with different architecture' + 'on single node session', + ) + requested_architecture = requested_architectures.pop() + candidate_agents = list( + filter( + lambda x: x.architecture == requested_architecture, + candidate_agents, + ), + ) + await self._schedule_single_node_session( + sched_ctx, + scheduler, + sgroup_name, + candidate_agents, + sess_ctx, + check_results, + ) + elif sess_ctx.cluster_mode == ClusterMode.MULTI_NODE: + await self._schedule_multi_node_session( + sched_ctx, + scheduler, + sgroup_name, + candidate_agents, + sess_ctx, + check_results, + ) + else: + raise RuntimeError( + f"should not reach here; unknown cluster_mode: {sess_ctx.cluster_mode}", + ) + num_scheduled += 1 + if num_scheduled > 0: + await self.event_producer.produce_event(DoPrepareEvent()) + + async def _schedule_single_node_session( + self, + sched_ctx: SchedulingContext, + scheduler: AbstractScheduler, + sgroup_name: str, + candidate_agents: Sequence[AgentContext], + sess_ctx: PendingSession, + check_results: List[Tuple[str, Union[Exception, PredicateResult]]], + ) -> None: + # Assign agent resource per session. + log_fmt = _log_fmt.get("") + log_args = _log_args.get(tuple()) + try: + # If sess_ctx.agent_id is already set for manual assignment by superadmin, + # skip assign_agent_for_session(). 
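# [Illustrative example; all field values below are hypothetical.]
# The two _update() branches above record each scheduling attempt under the
# "scheduler" key of kernels.status_data via sql_json_increment() and
# sql_json_merge().  After a failed round of checks the column holds roughly:
_example_status_data = {
    "scheduler": {
        "retries": 2,  # bumped by sql_json_increment() on every failed attempt
        "last_try": "2022-05-31T14:53:18.000000+00:00",
        "failed_predicates": [
            {"name": "keypair_resource_limit",
             "msg": "Your keypair resource quota is exceeded. (cpu=8 mem=32g)"},
        ],
        "passed_predicates": [
            {"name": "concurrency"},
            {"name": "dependencies"},
        ],
    },
}
# FIFOSlotScheduler (fifo.py below) later reads
# status_data["scheduler"]["retries"] to decide whether a repeatedly failing
# session should be skipped.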
+ agent_id = None + if sess_ctx.agent_id is not None: + agent_id = sess_ctx.agent_id + else: + agent_id = scheduler.assign_agent_for_session(candidate_agents, sess_ctx) + async with self.db.begin() as agent_db_conn: + query = ( + sa.select([agents.c.available_slots]) + .select_from(agents) + .where(agents.c.id == agent_id) + ) + available_agent_slots = (await agent_db_conn.execute(query)).scalar() + # if pass the available test + if available_agent_slots is None: + raise InstanceNotAvailable("There is no such agent.") + for key in available_agent_slots: + if available_agent_slots[key] >= sess_ctx.requested_slots[key]: + continue + else: + raise InstanceNotAvailable( + "The resource slot does not have the enough remaining capacity.", + ) + agent_alloc_ctx = await _reserve_agent( + sched_ctx, agent_db_conn, sgroup_name, agent_id, sess_ctx.requested_slots, + ) + except InstanceNotAvailable: + log.debug(log_fmt + 'no-available-instances', *log_args) + + async def _update() -> None: + async with self.db.begin() as kernel_db_conn: + await _rollback_predicate_mutations( + kernel_db_conn, sched_ctx, sess_ctx, + ) + query = kernels.update().values({ + 'status_info': "no-available-instances", + 'status_data': sql_json_increment( + kernels.c.status_data, + ('scheduler', 'retries'), + parent_updates={ + 'last_try': datetime.now(tzutc()).isoformat(), + }, + ), + }).where(kernels.c.id == sess_ctx.session_id) + await kernel_db_conn.execute(query) + + await execute_with_retry(_update) + raise + except Exception as e: + log.exception( + log_fmt + 'unexpected-error, during agent allocation', + *log_args, + ) + exc_data = convert_to_status_data(e) + + async def _update() -> None: + async with self.db.begin() as kernel_db_conn: + await _rollback_predicate_mutations( + kernel_db_conn, sched_ctx, sess_ctx, + ) + query = kernels.update().values({ + 'status_info': "scheduler-error", + 'status_data': exc_data, + }).where(kernels.c.id == sess_ctx.session_id) + await kernel_db_conn.execute(query) + + await execute_with_retry(_update) + raise + + async def _finalize_scheduled() -> None: + async with self.db.begin() as kernel_db_conn: + query = kernels.update().values({ + 'agent': agent_alloc_ctx.agent_id, + 'agent_addr': agent_alloc_ctx.agent_addr, + 'scaling_group': sgroup_name, + 'status': KernelStatus.SCHEDULED, + 'status_info': 'scheduled', + 'status_data': {}, + 'status_changed': datetime.now(tzutc()), + }).where(kernels.c.session_id == sess_ctx.session_id) + await kernel_db_conn.execute(query) + + await execute_with_retry(_finalize_scheduled) + await self.registry.event_producer.produce_event( + SessionScheduledEvent(sess_ctx.session_id, sess_ctx.session_creation_id), + ) + + async def _schedule_multi_node_session( + self, + sched_ctx: SchedulingContext, + scheduler: AbstractScheduler, + sgroup_name: str, + candidate_agents: Sequence[AgentContext], + sess_ctx: PendingSession, + check_results: List[Tuple[str, Union[Exception, PredicateResult]]], + ) -> None: + # Assign agent resource per kernel in the session. + log_fmt = _log_fmt.get() + log_args = _log_args.get() + agent_query_extra_conds = None + kernel_agent_bindings: List[KernelAgentBinding] = [] + async with self.db.begin() as agent_db_conn: + # This outer transaction is rolled back when any exception occurs inside, + # including scheduling failures of a kernel. + # It ensures that occupied_slots are recovered when there are partial + # scheduling failures. 
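# [Illustrative sketch; the engine URL, table, and agent names are made up.]
# The per-kernel loop below reserves each agent inside a SAVEPOINT
# (agent_db_conn.begin_nested()), so a single failed reservation can be rolled
# back and retried without abandoning the outer transaction, while an
# exception that escapes the loop still reverts every reservation at once.
# A self-contained sketch of that SQLAlchemy pattern:
import asyncio

import sqlalchemy as sa
from sqlalchemy import event
from sqlalchemy.ext.asyncio import create_async_engine


async def _demo_savepoints() -> None:
    engine = create_async_engine("sqlite+aiosqlite:///:memory:")

    # pysqlite/aiosqlite need these two hooks for SAVEPOINT to behave
    # (see the SQLAlchemy SQLite dialect docs on savepoint support).
    @event.listens_for(engine.sync_engine, "connect")
    def _no_auto_begin(dbapi_conn, _rec):
        dbapi_conn.isolation_level = None

    @event.listens_for(engine.sync_engine, "begin")
    def _emit_begin(conn):
        conn.exec_driver_sql("BEGIN")

    async with engine.begin() as conn:           # outer transaction
        await conn.execute(sa.text("CREATE TABLE reservations (agent TEXT)"))
        for agent in ("i-agent01", "i-agent02"):
            try:
                async with conn.begin_nested():  # one SAVEPOINT per reservation
                    await conn.execute(
                        sa.text("INSERT INTO reservations VALUES (:a)"), {"a": agent})
                    if agent == "i-agent02":
                        raise RuntimeError("simulated reservation failure")
            except RuntimeError:
                pass  # only this SAVEPOINT is rolled back; i-agent01 survives
        rows = (await conn.execute(sa.text("SELECT agent FROM reservations"))).fetchall()
        print(rows)                              # expected: [('i-agent01',)]
    await engine.dispose()


if __name__ == "__main__":
    asyncio.run(_demo_savepoints())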
+ for kernel in sess_ctx.kernels: + agent_alloc_ctx: AgentAllocationContext | None = None + try: + agent_id: AgentId | None + if kernel.agent_id is not None: + agent_id = kernel.agent_id + else: + # limit agent candidates with requested image architecture + candidate_agents = list( + filter( + lambda x: x.architecture == kernel.image_ref.architecture, + candidate_agents, + ), + ) + agent_id = scheduler.assign_agent_for_kernel(candidate_agents, kernel) + assert agent_id is not None + + query = ( + sa.select([agents.c.available_slots]) + .select_from(agents) + .where(agents.c.id == agent_id) + ) + available_agent_slots = (await agent_db_conn.execute(query)).scalar() + if available_agent_slots is None: + raise InstanceNotAvailable("There is no such agent.") + available_test_pass = False + for key in available_agent_slots: + if available_agent_slots[key] >= kernel.requested_slots[key]: + available_test_pass = True + continue + else: + raise InstanceNotAvailable( + "The resource slot does not have the enough remaining capacity.", + ) + if available_test_pass: + + async def _reserve() -> None: + nonlocal agent_alloc_ctx, candidate_agents + async with agent_db_conn.begin_nested(): + agent_alloc_ctx = await _reserve_agent( + sched_ctx, agent_db_conn, + sgroup_name, agent_id, kernel.requested_slots, + extra_conds=agent_query_extra_conds, + ) + candidate_agents = await _list_agents_by_sgroup( + agent_db_conn, sgroup_name, + ) + + await execute_with_retry(_reserve) + except InstanceNotAvailable: + log.debug(log_fmt + 'no-available-instances', *log_args) + + async def _update() -> None: + async with self.db.begin() as kernel_db_conn: + await _rollback_predicate_mutations( + kernel_db_conn, sched_ctx, sess_ctx, + ) + query = kernels.update().values({ + 'status_info': "no-available-instances", + 'status_data': sql_json_increment( + kernels.c.status_data, + ('scheduler', 'retries'), + parent_updates={ + 'last_try': datetime.now(tzutc()).isoformat(), + }, + ), + }).where(kernels.c.id == kernel.kernel_id) + await kernel_db_conn.execute(query) + + await execute_with_retry(_update) + raise + except Exception as e: + log.exception( + log_fmt + 'unexpected-error, during agent allocation', + *log_args, + ) + exc_data = convert_to_status_data(e) + + async def _update() -> None: + async with self.db.begin() as kernel_db_conn: + await _rollback_predicate_mutations( + kernel_db_conn, sched_ctx, sess_ctx, + ) + query = kernels.update().values({ + 'status_info': "scheduler-error", + 'status_data': exc_data, + }).where(kernels.c.id == kernel.kernel_id) + await kernel_db_conn.execute(query) + + await execute_with_retry(_update) + raise + else: + assert agent_alloc_ctx is not None + kernel_agent_bindings.append(KernelAgentBinding(kernel, agent_alloc_ctx)) + + assert len(kernel_agent_bindings) == len(sess_ctx.kernels) + # Proceed to PREPARING only when all kernels are successfully scheduled. 
+ + async def _finalize_scheduled() -> None: + async with self.db.begin() as kernel_db_conn: + for binding in kernel_agent_bindings: + query = kernels.update().values({ + 'agent': binding.agent_alloc_ctx.agent_id, + 'agent_addr': binding.agent_alloc_ctx.agent_addr, + 'scaling_group': sgroup_name, + 'status': KernelStatus.SCHEDULED, + 'status_info': 'scheduled', + 'status_data': {}, + 'status_changed': datetime.now(tzutc()), + }).where(kernels.c.id == binding.kernel.kernel_id) + await kernel_db_conn.execute(query) + + await execute_with_retry(_finalize_scheduled) + await self.registry.event_producer.produce_event( + SessionScheduledEvent(sess_ctx.session_id, sess_ctx.session_creation_id), + ) + + async def prepare( + self, + context: None, + source: AgentId, + event: DoPrepareEvent, + ) -> None: + """ + Scan the scheduled sessions and perform the agent RPC calls to begin preparation of them. + Each RPC calls are done in separate asyncio tasks. + + Session status transition: SCHEDULED -> PREPARING + """ + known_slot_types = await self.shared_config.get_resource_slots() + sched_ctx = SchedulingContext( + self.registry, + known_slot_types, + ) + try: + async with self.lock_factory(LockID.LOCKID_PREPARE, 600): + now = datetime.now(tzutc()) + + async def _mark_session_preparing() -> Sequence[PendingSession]: + async with self.db.begin() as conn: + update_query = ( + sa.update(kernels) + .values({ + 'status': KernelStatus.PREPARING, + 'status_changed': now, + 'status_info': "", + 'status_data': {}, + }) + .where( + (kernels.c.status == KernelStatus.SCHEDULED), + ) + .returning(kernels.c.id) + ) + rows = (await conn.execute(update_query)).fetchall() + if len(rows) == 0: + return [] + target_kernel_ids = [r['id'] for r in rows] + select_query = ( + PendingSession.base_query() + .where( + kernels.c.id.in_(target_kernel_ids), + ) + ) + rows = (await conn.execute(select_query)).fetchall() + return PendingSession.from_rows(rows) + + scheduled_sessions = await execute_with_retry(_mark_session_preparing) + log.debug("prepare(): preparing {} session(s)", len(scheduled_sessions)) + async with aiotools.TaskGroup() as tg: + for scheduled_session in scheduled_sessions: + await self.registry.event_producer.produce_event( + SessionPreparingEvent( + scheduled_session.session_id, + scheduled_session.session_creation_id, + ), + ) + tg.create_task(self.start_session( + sched_ctx, + scheduled_session, + )) + + except DBAPIError as e: + if getattr(e.orig, 'pgcode', None) == '55P03': + log.info("prepare(): cancelled due to advisory lock timeout; " + "maybe another prepare() call is still running") + raise asyncio.CancelledError() + raise + + async def start_session( + self, + sched_ctx: SchedulingContext, + session: PendingSession, + ) -> None: + log_fmt = "prepare(s:{0.session_id}, type:{0.session_type}, name:{0.session_name}, " \ + "ak:{0.access_key}, cluster_mode:{0.cluster_mode}): " + log_args = (session, ) + log.debug(log_fmt + 'try-starting', *log_args) + try: + assert len(session.kernels) > 0 + await self.registry.start_session(sched_ctx, session) + except Exception as e: + status_data = convert_to_status_data(e, self.local_config['debug']['enabled']) + log.warning(log_fmt + 'failed-starting: {1!r}', *log_args, status_data) + # TODO: instead of instantly cancelling upon exception, we could mark it as + # SCHEDULED and retry within some limit using status_data. 
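# [Tiny illustration; the _Sess class and values are made up.]
# The log_fmt above relies on str.format()-style attribute access ("{0.field}")
# plus "{1!r}" for the extra argument; BraceStyleAdapter is assumed to defer
# this interpolation until the log record is actually emitted.  Standalone:
import dataclasses


@dataclasses.dataclass
class _Sess:
    session_id: str
    access_key: str


_fmt = "prepare(s:{0.session_id}, ak:{0.access_key}): "
print((_fmt + "failed-starting: {1!r}").format(
    _Sess("sess-0123", "AKIAEXAMPLE"), {"error": "boom"}))
# -> prepare(s:sess-0123, ak:AKIAEXAMPLE): failed-starting: {'error': 'boom'}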
+ + async def _mark_session_cancelled() -> None: + async with self.db.begin() as db_conn: + affected_agents = set(k.agent_id for k in session.kernels) + for agent_id in affected_agents: + await recalc_agent_resource_occupancy(db_conn, agent_id) + await _rollback_predicate_mutations(db_conn, sched_ctx, session) + now = datetime.now(tzutc()) + update_query = sa.update(kernels).values({ + 'status': KernelStatus.CANCELLED, + 'status_changed': now, + 'status_info': "failed-to-start", + 'status_data': status_data, + 'terminated_at': now, + }).where(kernels.c.session_id == session.session_id) + await db_conn.execute(update_query) + + log.debug(log_fmt + 'cleanup-start-failure: begin', *log_args) + try: + await execute_with_retry(_mark_session_cancelled) + await self.registry.event_producer.produce_event( + SessionCancelledEvent( + session.session_id, + session.session_creation_id, + "failed-to-start", + ), + ) + async with self.db.begin_readonly() as db_conn: + query = ( + sa.select([kernels.c.id, kernels.c.container_id]) + .where(kernels.c.session_id == session.session_id) + ) + rows = (await db_conn.execute(query)).fetchall() + cid_map = {row['id']: row['container_id'] for row in rows} + destroyed_kernels = [ + { + "agent": k.agent_id, + "agent_addr": k.agent_addr, + "id": k.kernel_id, + "container_id": cid_map[k.kernel_id], + } + for k in session.kernels + ] + await self.registry.destroy_session_lowlevel( + session.session_id, destroyed_kernels, + ) + await self.registry.recalc_resource_usage() + except Exception as destroy_err: + log.error(log_fmt + 'cleanup-start-failure: error', *log_args, exc_info=destroy_err) + finally: + log.debug(log_fmt + 'cleanup-start-failure: done', *log_args) + else: + log.info(log_fmt + 'started', *log_args) + + +async def _list_pending_sessions( + db_conn: SAConnection, + scheduler: AbstractScheduler, + sgroup_name: str, +) -> tuple[list[Row], list[Row]]: + """ + Return two lists of pending sessions and to-be-cancelled sessions due to pending timeout. 
+ """ + pending_timeout: timedelta = scheduler.sgroup_opts.pending_timeout + query = ( + PendingSession.base_query() + .where( + (kernels.c.status == KernelStatus.PENDING) & + ( + (kernels.c.scaling_group == sgroup_name) + ), + ) + ) + rows = (await db_conn.execute(query)).fetchall() + candidate_rows = [] + cancelled_rows = [] + now = datetime.now(tzutc()) + for row in rows: + elapsed_pending_time = now - row['created_at'] + if pending_timeout.total_seconds() > 0 and elapsed_pending_time >= pending_timeout: + cancelled_rows.append(row) + else: + candidate_rows.append(row) + return candidate_rows, cancelled_rows + + +async def _list_existing_sessions( + db_conn: SAConnection, + sgroup_name: str, +) -> List[ExistingSession]: + query = ( + ExistingSession.base_query() + .where( + (kernels.c.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES)) & + (kernels.c.scaling_group == sgroup_name), + ) + ) + rows = (await db_conn.execute(query)).fetchall() + return ExistingSession.from_rows(rows) + + +async def _list_agents_by_sgroup( + db_conn: SAConnection, + sgroup_name: str, +) -> Sequence[AgentContext]: + query = ( + sa.select([ + agents.c.id, + agents.c.architecture, + agents.c.addr, + agents.c.scaling_group, + agents.c.available_slots, + agents.c.occupied_slots, + ]) + .select_from(agents) + .where( + (agents.c.status == AgentStatus.ALIVE) & + (agents.c.scaling_group == sgroup_name) & + (agents.c.schedulable == true()), + ) + ) + items = [] + for row in (await db_conn.execute(query)): + item = AgentContext( + row['id'], + row['addr'], + row['architecture'], + row['scaling_group'], + row['available_slots'], + row['occupied_slots'], + ) + items.append(item) + return items + + +async def _reserve_agent( + sched_ctx: SchedulingContext, + db_conn: SAConnection, + scaling_group: str, + agent_id: Optional[AgentId], + requested_slots: ResourceSlot, + extra_conds: Any = None, +) -> AgentAllocationContext: + query = ( + sa.select([agents.c.occupied_slots]) + .select_from(agents) + .where(agents.c.id == agent_id) + .with_for_update() + ) + if extra_conds is not None: + query = query.where(extra_conds) + current_occupied_slots = (await db_conn.execute(query)).scalar() + if current_occupied_slots is None: + raise RuntimeError(f"No agent matching condition: {extra_conds}") + update_query = ( + sa.update(agents) + .values({ + 'occupied_slots': current_occupied_slots + requested_slots, + }) + .where(agents.c.id == agent_id) + ) + await db_conn.execute(update_query) + # Get the agent address for later RPC calls + query = (sa.select([agents.c.addr]) + .where(agents.c.id == agent_id)) + agent_addr = await db_conn.scalar(query) + assert agent_addr is not None + return AgentAllocationContext(agent_id, agent_addr, scaling_group) + + +async def _rollback_predicate_mutations( + db_conn: SAConnection, + sched_ctx: SchedulingContext, + session: PendingSession, +) -> None: + """ + Rollback any changes performed by predicates. + + NOTE: We don't use the DB-level transaction rollback because we need to + store the "ERROR" status to corresponding rows in the kernels table. + """ + + # Instead of decrementing concurrency_used, we recalculate the access_key's usage, + # because asynchronous container launch failures and agent failures + # (especially with multi-node multi-container cluster sessions) + # may accumulate up multiple subtractions, resulting in + # negative concurrency_occupied values. 
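# [Toy illustration of the comment above; the values are made up.]
# Why recalculate instead of decrement: two failure paths (say, an agent-level
# error plus a per-kernel cleanup) may both try to "release" the same kernel.
# Repeated decrements drift negative, while recounting the surviving kernels
# is idempotent no matter how many times it runs:
active_kernels = {"k1", "k2"}        # kernels currently counted as occupying

counter = len(active_kernels)
counter -= 1                         # release of "k2" via failure path A
counter -= 1                         # duplicate release of "k2" via path B
print(counter)                       # -> 0, and a third path would go negative

active_kernels.discard("k2")
counter = len(active_kernels)        # recalculation from the source of truth
counter = len(active_kernels)        # running it again changes nothing
print(counter)                       # -> 1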
+ await recalc_concurrency_used(db_conn, sched_ctx.registry.redis_stat, session.access_key) diff --git a/src/ai/backend/manager/scheduler/drf.py b/src/ai/backend/manager/scheduler/drf.py new file mode 100644 index 0000000000..e21e306d5e --- /dev/null +++ b/src/ai/backend/manager/scheduler/drf.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +from collections import defaultdict +from decimal import Decimal +import logging +from typing import ( + Any, Optional, + Dict, + Sequence, + Mapping, + Set, +) + +import trafaret as t + +from ai.backend.common.logging import BraceStyleAdapter +from ai.backend.common.types import ( + AccessKey, AgentId, + ResourceSlot, + SessionId, +) + +from ..models.scaling_group import ScalingGroupOpts +from .types import ( + AbstractScheduler, + AgentContext, + PendingSession, + ExistingSession, + KernelInfo, +) + +log = BraceStyleAdapter(logging.getLogger('ai.backend.manager.scheduler')) + + +class DRFScheduler(AbstractScheduler): + + config_iv = t.Dict({}).allow_extra('*') + per_user_dominant_share: Dict[AccessKey, Decimal] + total_capacity: ResourceSlot + + def __init__(self, sgroup_opts: ScalingGroupOpts, config: Mapping[str, Any]) -> None: + super().__init__(sgroup_opts, config) + self.per_user_dominant_share = defaultdict(lambda: Decimal(0)) + + def pick_session( + self, + total_capacity: ResourceSlot, + pending_sessions: Sequence[PendingSession], + existing_sessions: Sequence[ExistingSession], + ) -> Optional[SessionId]: + self.total_capacity = total_capacity + + # Calculate the initial dominant shares of all users. + for existing_sess in existing_sessions: + dominant_share = Decimal(0) + self.total_capacity.sync_keys(existing_sess.occupying_slots) + for slot, value in existing_sess.occupying_slots.items(): + slot_cap = Decimal(self.total_capacity[slot]) + if slot_cap == 0: + continue + slot_share = Decimal(value) / slot_cap + if dominant_share < slot_share: + dominant_share = slot_share + if self.per_user_dominant_share[existing_sess.access_key] < dominant_share: + self.per_user_dominant_share[existing_sess.access_key] = dominant_share + log.debug('per-user dominant share: {}', dict(self.per_user_dominant_share)) + + # Find who has the least dominant share among the pending session. + users_with_pending_session: Set[AccessKey] = { + pending_sess.access_key for pending_sess in pending_sessions + } + if not users_with_pending_session: + return None + least_dominant_share_user, dshare = min( + ((akey, self.per_user_dominant_share[akey]) + for akey in users_with_pending_session), + key=lambda item: item[1]) + log.debug('least dominant share user: {} ({})', least_dominant_share_user, dshare) + + # Pick the first pending session of the user + # who has the lowest dominant share. + for pending_sess in pending_sessions: + if pending_sess.access_key == least_dominant_share_user: + return SessionId(pending_sess.session_id) + + return None + + def _assign_agent( + self, + agents: Sequence[AgentContext], + access_key: AccessKey, + requested_slots: ResourceSlot, + ) -> Optional[AgentId]: + # If some predicate checks for a picked session fail, + # this method is NOT called at all for the picked session. + # In such case, we just skip updating self.per_user_dominant_share state + # and the scheduler dispatcher continues to pick another session within the same scaling group. 
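# [Worked example for pick_session() above; all numbers are hypothetical.]
# With a total capacity of cpu=16, mem=64 (GiB), cuda.device=8:
#   user A occupies cpu=8, mem=16, cuda.device=0 -> shares 0.50/0.25/0.00 -> dominant 0.50
#   user B occupies cpu=2, mem=8,  cuda.device=4 -> shares 0.125/0.125/0.50 -> dominant 0.50
#   user C occupies nothing                      -> dominant 0.00
# so the scheduler serves user C's oldest pending session first.
from decimal import Decimal

_total = {"cpu": Decimal(16), "mem": Decimal(64), "cuda.device": Decimal(8)}
_occupied = {
    "A": {"cpu": Decimal(8), "mem": Decimal(16), "cuda.device": Decimal(0)},
    "B": {"cpu": Decimal(2), "mem": Decimal(8), "cuda.device": Decimal(4)},
    "C": {},
}
_dominant = {
    user: max((v / _total[k] for k, v in slots.items() if _total[k] > 0),
              default=Decimal(0))
    for user, slots in _occupied.items()
}
print(min(_dominant, key=_dominant.__getitem__))  # -> C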
+ + possible_agents = [] + for agent in agents: + remaining_slots = agent.available_slots - agent.occupied_slots + if remaining_slots >= requested_slots: + possible_agents.append(agent) + + if possible_agents: + # We have one or more agents that can host the picked session. + + # Update the dominant share. + # This is required to use to the latest dominant share information + # when iterating over multiple pending sessions in a single scaling group. + dominant_share_from_request = Decimal(0) + for slot, value in requested_slots.items(): + self.total_capacity.sync_keys(requested_slots) + slot_cap = Decimal(self.total_capacity[slot]) + if slot_cap == 0: + continue + slot_share = Decimal(value) / slot_cap + if dominant_share_from_request < slot_share: + dominant_share_from_request = slot_share + if self.per_user_dominant_share[access_key] < dominant_share_from_request: + self.per_user_dominant_share[access_key] = dominant_share_from_request + + # Choose the agent. + chosen_agent = \ + max(possible_agents, key=lambda a: a.available_slots) + return chosen_agent.agent_id + + return None + + def assign_agent_for_session( + self, + agents: Sequence[AgentContext], + pending_session: PendingSession, + ) -> Optional[AgentId]: + return self._assign_agent( + agents, pending_session.access_key, pending_session.requested_slots, + ) + + def assign_agent_for_kernel( + self, + agents: Sequence[AgentContext], + pending_kernel: KernelInfo, + ) -> Optional[AgentId]: + return self._assign_agent( + agents, pending_kernel.access_key, pending_kernel.requested_slots, + ) diff --git a/src/ai/backend/manager/scheduler/fifo.py b/src/ai/backend/manager/scheduler/fifo.py new file mode 100644 index 0000000000..ffa4d79333 --- /dev/null +++ b/src/ai/backend/manager/scheduler/fifo.py @@ -0,0 +1,166 @@ +from __future__ import annotations +from decimal import Decimal +from typing import ( + List, + Optional, + Sequence, + Tuple, +) + +import trafaret as t + +from ai.backend.common.types import ( + AgentId, + ResourceSlot, + SessionId, +) +from .types import ( + AbstractScheduler, + AgentContext, + PendingSession, + ExistingSession, + KernelInfo, +) + + +def key_by_requested_slots( + agent: AgentContext, + requested_slots: ResourceSlot, +) -> Tuple[int, ResourceSlot]: + unused_slot_keys = set() + for k, v in requested_slots.items(): + if v == Decimal(0): + unused_slot_keys.add(k) + num_extras = 0 + for k, v in agent.available_slots.items(): + if k in unused_slot_keys and v > Decimal(0): + num_extras += 1 + # Put back agents with more extra slot types + # (e.g., accelerators) + # Also put front agents with exactly required slot types + return (-num_extras, agent.available_slots) + + +class FIFOSlotScheduler(AbstractScheduler): + + config_iv = t.Dict({ + t.Key('num_retries_to_skip', default=0): t.ToInt(gte=0), + }).allow_extra('*') + + def pick_session( + self, + total_capacity: ResourceSlot, + pending_sessions: Sequence[PendingSession], + existing_sessions: Sequence[ExistingSession], + ) -> Optional[SessionId]: + local_pending_sessions = list(pending_sessions) + skipped_sessions: List[PendingSession] = [] + max_retries = self.config['num_retries_to_skip'] + while local_pending_sessions: + # Just pick the first pending session, but skip it + # if it has more than 3 failures. 
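# [Toy model of the skip rule below; the threshold 3 is a hypothetical
#  num_retries_to_skip value taken from the scaling group's scheduler_opts,
#  where the default of 0 keeps strict FIFO behaviour.]
_max_retries = 3
_queue = [
    # "retries" stands in for status_data["scheduler"]["retries"]
    {"id": "sess-old", "retries": 5},   # keeps failing its predicate checks
    {"id": "sess-new", "retries": 0},
]
_picked = next(
    (s for s in _queue if s["retries"] < _max_retries),
    _queue[0],              # if every session is skipped, fall back to the head
)
print(_picked["id"])        # -> sess-new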
+ s = local_pending_sessions.pop(0) + if max_retries == 0: # it's strict FIFO + return s.session_id + if s.status_data is not None: + sched_data = s.status_data.get('scheduler', {}) + if sched_data.get('retries', 0) >= max_retries: + skipped_sessions.append(s) + continue + return s.session_id + # But if all sessions are skipped, then choose the first one. + if skipped_sessions: + return skipped_sessions[0].session_id + return None + + def _assign_agent( + self, + agents: Sequence[AgentContext], + requested_slots: ResourceSlot, + ) -> Optional[AgentId]: + possible_agents = [] + for agent in agents: + remaining_slots = agent.available_slots - agent.occupied_slots + if remaining_slots >= requested_slots: + possible_agents.append(agent) + if possible_agents: + chosen_agent = max( + possible_agents, + key=lambda a: key_by_requested_slots( + a, + requested_slots, + ), + ) + return chosen_agent.agent_id + return None + + def assign_agent_for_session( + self, + agents: Sequence[AgentContext], + pending_session: PendingSession, + ) -> Optional[AgentId]: + return self._assign_agent( + agents, pending_session.requested_slots, + ) + + def assign_agent_for_kernel( + self, + agents: Sequence[AgentContext], + pending_kernel: KernelInfo, + ) -> Optional[AgentId]: + return self._assign_agent( + agents, pending_kernel.requested_slots, + ) + + +class LIFOSlotScheduler(AbstractScheduler): + + config_iv = t.Dict({}).allow_extra('*') + + def pick_session( + self, + total_capacity: ResourceSlot, + pending_sessions: Sequence[PendingSession], + existing_sessions: Sequence[ExistingSession], + ) -> Optional[SessionId]: + # Just pick the last pending session. + return SessionId(pending_sessions[-1].session_id) + + def _assign_agent( + self, + agents: Sequence[AgentContext], + requested_slots: ResourceSlot, + ) -> Optional[AgentId]: + possible_agents = [] + for agent in agents: + remaining_slots = agent.available_slots - agent.occupied_slots + if remaining_slots >= requested_slots: + possible_agents.append(agent) + if possible_agents: + chosen_agent = max( + possible_agents, + key=lambda a: key_by_requested_slots( + a, + requested_slots, + ), + ) + return chosen_agent.agent_id + return None + + def assign_agent_for_session( + self, + agents: Sequence[AgentContext], + pending_session: PendingSession, + ) -> Optional[AgentId]: + return self._assign_agent( + agents, pending_session.requested_slots, + ) + + def assign_agent_for_kernel( + self, + agents: Sequence[AgentContext], + pending_kernel: KernelInfo, + ) -> Optional[AgentId]: + return self._assign_agent( + agents, pending_kernel.requested_slots, + ) diff --git a/src/ai/backend/manager/scheduler/mof.py b/src/ai/backend/manager/scheduler/mof.py new file mode 100644 index 0000000000..4b1f6af73b --- /dev/null +++ b/src/ai/backend/manager/scheduler/mof.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +from typing import ( + Optional, + Sequence, +) + +import trafaret as t + +from ai.backend.common.types import ( + AccessKey, + AgentId, + SessionId, + ResourceSlot, +) + +from .types import ( + AbstractScheduler, + PendingSession, + ExistingSession, + AgentContext, + KernelInfo, +) + + +class MOFScheduler(AbstractScheduler): + """Minimum Occupied slot First Scheduler""" + + config_iv = t.Dict({}).allow_extra('*') + + def pick_session( + self, + total_capacity: ResourceSlot, + pending_sessions: Sequence[PendingSession], + existing_sessions: Sequence[ExistingSession], + ) -> Optional[SessionId]: + # Just pick the first pending session. 
+ return SessionId(pending_sessions[0].session_id) + + def _assign_agent( + self, + agents: Sequence[AgentContext], + access_key: AccessKey, + requested_slots: ResourceSlot, + ) -> Optional[AgentId]: + # return min occupied slot agent or None + return next((one_agent.agent_id for one_agent in (sorted( + (agent for agent in agents if ( + (agent.available_slots - agent.occupied_slots) + >= requested_slots + )), + key=lambda a: a.occupied_slots) + )), None) + + def assign_agent_for_session( + self, + agents: Sequence[AgentContext], + pending_session: PendingSession, + ) -> Optional[AgentId]: + return self._assign_agent( + agents, pending_session.access_key, pending_session.requested_slots, + ) + + def assign_agent_for_kernel( + self, + agents: Sequence[AgentContext], + pending_kernel: KernelInfo, + ) -> Optional[AgentId]: + return self._assign_agent( + agents, pending_kernel.access_key, pending_kernel.requested_slots, + ) diff --git a/src/ai/backend/manager/scheduler/predicates.py b/src/ai/backend/manager/scheduler/predicates.py new file mode 100644 index 0000000000..91d22751ac --- /dev/null +++ b/src/ai/backend/manager/scheduler/predicates.py @@ -0,0 +1,299 @@ +from datetime import datetime +import logging +from typing import ( + List, +) + +from dateutil.tz import tzutc +import sqlalchemy as sa +from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection + +from ai.backend.common import redis +from ai.backend.common.logging import BraceStyleAdapter +from ai.backend.common.types import ( + ResourceSlot, + SessionResult, + SessionTypes, +) + +from ..models import ( + domains, groups, kernels, + keypair_resource_policies, + session_dependencies, + query_allowed_sgroups, + DefaultForUnspecified, +) +from ..models.utils import execute_with_retry, reenter_txn +from .types import ( + SchedulingContext, + PendingSession, + PredicateResult, +) + +log = BraceStyleAdapter(logging.getLogger('ai.backend.manager.scheduler')) + +_check_keypair_concurrency_script = ''' +local key = KEYS[1] +local limit = tonumber(ARGV[1]) +local result = {} +redis.call('SETNX', key, 0) +local count = tonumber(redis.call('GET', key)) +if limit > 0 and count >= limit then + result[1] = 0 + result[2] = count + return result +end +redis.call('INCR', key) +result[1] = 1 +result[2] = count + 1 +return result +''' + + +async def check_reserved_batch_session( + db_conn: SAConnection, + sched_ctx: SchedulingContext, + sess_ctx: PendingSession, +) -> PredicateResult: + """ + Check if a batch-type session should not be started for a certain amount of time. 
+ """ + if sess_ctx.session_type == SessionTypes.BATCH: + query = ( + sa.select([kernels.c.starts_at]) + .select_from(kernels) + .where(kernels.c.id == sess_ctx.session_id) + ) + starts_at = await db_conn.scalar(query) + if starts_at is not None and datetime.now(tzutc()) < starts_at: + return PredicateResult( + False, + 'Before start time', + ) + return PredicateResult(True) + + +async def check_concurrency( + db_conn: SAConnection, + sched_ctx: SchedulingContext, + sess_ctx: PendingSession, +) -> PredicateResult: + + async def _get_max_concurrent_sessions() -> int: + select_query = ( + sa.select([keypair_resource_policies]) + .select_from(keypair_resource_policies) + .where(keypair_resource_policies.c.name == sess_ctx.resource_policy) + ) + result = await db_conn.execute(select_query) + return result.first()['max_concurrent_sessions'] + + max_concurrent_sessions = await execute_with_retry(_get_max_concurrent_sessions) + ok, concurrency_used = await redis.execute_script( + sched_ctx.registry.redis_stat, + 'check_keypair_concurrency_used', + _check_keypair_concurrency_script, + [f"keypair.concurrency_used.{sess_ctx.access_key}"], + [max_concurrent_sessions], + ) + if ok == 0: + return PredicateResult( + False, + "You cannot run more than " + f"{max_concurrent_sessions} concurrent sessions", + ) + log.debug( + 'number of concurrent sessions of ak:{0} = {1} / {2}', + sess_ctx.access_key, + concurrency_used, + max_concurrent_sessions, + ) + return PredicateResult(True) + + +async def check_dependencies( + db_conn: SAConnection, + sched_ctx: SchedulingContext, + sess_ctx: PendingSession, +) -> PredicateResult: + j = sa.join( + session_dependencies, + kernels, + session_dependencies.c.depends_on == kernels.c.session_id, + ) + query = ( + sa.select([ + kernels.c.session_id, + kernels.c.session_name, + kernels.c.result, + ]) + .select_from(j) + .where(session_dependencies.c.session_id == sess_ctx.session_id) + ) + result = await db_conn.execute(query) + rows = result.fetchall() + pending_dependencies = [] + for row in rows: + if row['result'] != SessionResult.SUCCESS: + pending_dependencies.append(row) + all_success = (not pending_dependencies) + if all_success: + return PredicateResult(True) + return PredicateResult( + False, + "Waiting dependency sessions to finish as success. ({})".format( + ", ".join(f"{row['session_name']} ({row['session_id']})" for row in pending_dependencies), + ), + ) + + +async def check_keypair_resource_limit( + db_conn: SAConnection, + sched_ctx: SchedulingContext, + sess_ctx: PendingSession, +) -> PredicateResult: + query = ( + sa.select([keypair_resource_policies]) + .select_from(keypair_resource_policies) + .where(keypair_resource_policies.c.name == sess_ctx.resource_policy) + ) + result = await db_conn.execute(query) + resource_policy = result.first() + total_keypair_allowed = ResourceSlot.from_policy(resource_policy, + sched_ctx.known_slot_types) + key_occupied = await sched_ctx.registry.get_keypair_occupancy( + sess_ctx.access_key, conn=db_conn) + log.debug('keypair:{} current-occupancy: {}', sess_ctx.access_key, key_occupied) + log.debug('keypair:{} total-allowed: {}', sess_ctx.access_key, total_keypair_allowed) + if not (key_occupied + sess_ctx.requested_slots <= total_keypair_allowed): + return PredicateResult( + False, + "Your keypair resource quota is exceeded. 
({})" + .format(' '.join( + f'{k}={v}' for k, v in + total_keypair_allowed.to_humanized(sched_ctx.known_slot_types).items() + )), + ) + return PredicateResult(True) + + +async def check_group_resource_limit( + db_conn: SAConnection, + sched_ctx: SchedulingContext, + sess_ctx: PendingSession, +) -> PredicateResult: + query = (sa.select([groups.c.total_resource_slots]) + .where(groups.c.id == sess_ctx.group_id)) + group_resource_slots = await db_conn.scalar(query) + group_resource_policy = {'total_resource_slots': group_resource_slots, + 'default_for_unspecified': DefaultForUnspecified.UNLIMITED} + total_group_allowed = ResourceSlot.from_policy(group_resource_policy, + sched_ctx.known_slot_types) + group_occupied = await sched_ctx.registry.get_group_occupancy( + sess_ctx.group_id, conn=db_conn) + log.debug('group:{} current-occupancy: {}', sess_ctx.group_id, group_occupied) + log.debug('group:{} total-allowed: {}', sess_ctx.group_id, total_group_allowed) + if not (group_occupied + sess_ctx.requested_slots <= total_group_allowed): + return PredicateResult( + False, + "Your group resource quota is exceeded. ({})" + .format(' '.join( + f'{k}={v}' for k, v in + total_group_allowed.to_humanized(sched_ctx.known_slot_types).items() + )), + ) + return PredicateResult(True) + + +async def check_domain_resource_limit( + db_conn: SAConnection, + sched_ctx: SchedulingContext, + sess_ctx: PendingSession, +) -> PredicateResult: + query = (sa.select([domains.c.total_resource_slots]) + .where(domains.c.name == sess_ctx.domain_name)) + domain_resource_slots = await db_conn.scalar(query) + domain_resource_policy = { + 'total_resource_slots': domain_resource_slots, + 'default_for_unspecified': DefaultForUnspecified.UNLIMITED, + } + total_domain_allowed = ResourceSlot.from_policy(domain_resource_policy, + sched_ctx.known_slot_types) + domain_occupied = await sched_ctx.registry.get_domain_occupancy( + sess_ctx.domain_name, conn=db_conn) + log.debug('domain:{} current-occupancy: {}', sess_ctx.domain_name, domain_occupied) + log.debug('domain:{} total-allowed: {}', sess_ctx.domain_name, total_domain_allowed) + if not (domain_occupied + sess_ctx.requested_slots <= total_domain_allowed): + return PredicateResult( + False, + 'Your domain resource quota is exceeded. ({})' + .format(' '.join( + f'{k}={v}' for k, v in + total_domain_allowed.to_humanized(sched_ctx.known_slot_types).items() + )), + ) + return PredicateResult(True) + + +async def check_scaling_group( + db_conn: SAConnection, + sched_ctx: SchedulingContext, + sess_ctx: PendingSession, +) -> PredicateResult: + + async def _query(): + async with reenter_txn(sched_ctx.registry.db, db_conn) as _conn: + return await query_allowed_sgroups( + _conn, + sess_ctx.domain_name, + sess_ctx.group_id, + sess_ctx.access_key, + ) + + sgroups = await execute_with_retry(_query) + if not sgroups: + return PredicateResult( + False, + "You do not have any scaling groups allowed to use.", + permanent=True, + ) + target_sgroup_names: List[str] = [] + preferred_sgroup_name = sess_ctx.scaling_group + if preferred_sgroup_name is not None: + # Consider only the preferred scaling group. 
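# [Hypothetical predicate for illustration only; not registered anywhere.]
# Every check_*() coroutine in this module shares one call convention,
# formalized as the SchedulingPredicate protocol in types.py: take the DB
# connection, the SchedulingContext, and the PendingSession, and return a
# PredicateResult whose `permanent` flag separates "retry on a later round"
# from "can never pass".  A minimal example of that shape:
async def check_max_cluster_size_example(
    db_conn: SAConnection,
    sched_ctx: SchedulingContext,
    sess_ctx: PendingSession,
) -> PredicateResult:
    # Made-up policy: refuse cluster sessions larger than 16 kernels.
    if sess_ctx.cluster_size > 16:
        return PredicateResult(
            False,
            f"Cluster size {sess_ctx.cluster_size} exceeds the example limit of 16.",
            permanent=True,   # resubmitting the same spec can never succeed
        )
    return PredicateResult(True)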
+ for sgroup in sgroups: + if preferred_sgroup_name == sgroup['name']: + break + else: + return PredicateResult( + False, + f"You do not have access to the scaling group '{preferred_sgroup_name}'.", + permanent=True, + ) + allowed_session_types = sgroup['scheduler_opts'].allowed_session_types + if sess_ctx.session_type.value.lower() not in allowed_session_types: + return PredicateResult( + False, + f"The scaling group '{preferred_sgroup_name}' does not accept " + f"the session type '{sess_ctx.session_type}'. ", + permanent=True, + ) + target_sgroup_names = [preferred_sgroup_name] + else: + # Consider all allowed scaling groups. + usable_sgroups = [] + for sgroup in sgroups: + allowed_session_types = sgroup['scheduler_opts'].allowed_session_types + if sess_ctx.session_type.value.lower() in allowed_session_types: + usable_sgroups.append(sgroup) + if not usable_sgroups: + return PredicateResult( + False, + f"No scaling groups accept the session type '{sess_ctx.session_type}'.", + permanent=True, + ) + target_sgroup_names = [sgroup['name'] for sgroup in usable_sgroups] + assert target_sgroup_names + log.debug("scaling groups considered for s:{} are {}", sess_ctx.session_id, target_sgroup_names) + sess_ctx.target_sgroup_names.extend(target_sgroup_names) + return PredicateResult(True) diff --git a/src/ai/backend/manager/scheduler/types.py b/src/ai/backend/manager/scheduler/types.py new file mode 100644 index 0000000000..cfd8c33dd0 --- /dev/null +++ b/src/ai/backend/manager/scheduler/types.py @@ -0,0 +1,447 @@ +from __future__ import annotations + +from abc import ABCMeta, abstractmethod +import logging +from typing import ( + Any, + Dict, + List, + Mapping, + MutableMapping, + MutableSequence, + Optional, + Protocol, + Sequence, + Set, +) +import uuid + +import attr +import sqlalchemy as sa +from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection +from sqlalchemy.sql import Select, ColumnElement +from sqlalchemy.engine.row import Row +from datetime import datetime +import trafaret as t + +from ai.backend.common.logging import BraceStyleAdapter +from ai.backend.common.docker import ( + ImageRef, +) +from ai.backend.common.types import ( + AccessKey, + AgentId, + ClusterMode, + KernelId, + SessionId, + SessionTypes, + ResourceSlot, + SlotName, + SlotTypes, + VFolderMount, +) + +from ..defs import DEFAULT_ROLE +from ..models import ( + kernels, keypairs, +) +from ..models.scaling_group import ScalingGroupOpts +from ..registry import AgentRegistry + +log = BraceStyleAdapter(logging.getLogger('ai.backend.manager.scheduler')) + + +def merge_resource( + target: MutableMapping[str, Any], + update: MutableMapping[str, Any], +) -> None: + for k in update.keys(): + if k in target.keys(): + target[k] += update[k] + else: + target[k] = update[k] + + +@attr.s(auto_attribs=True, slots=True) +class AgentAllocationContext: + agent_id: Optional[AgentId] + agent_addr: str + scaling_group: str + + +@attr.s(auto_attribs=True, slots=True) +class AgentContext: + agent_id: AgentId + agent_addr: str + architecture: str + scaling_group: str + available_slots: ResourceSlot + occupied_slots: ResourceSlot + + +@attr.s(auto_attribs=True, slots=True) +class ScheduleDecision: + agent_id: AgentId + kernel_id: KernelId + + +@attr.s(auto_attribs=True, slots=True) +class SchedulingContext: + """ + Context for each scheduling decision. 
+ """ + registry: AgentRegistry + known_slot_types: Mapping[SlotName, SlotTypes] + + +@attr.s(auto_attribs=True, slots=True) +class ExistingSession: + kernels: List[KernelInfo] + access_key: AccessKey + session_id: uuid.UUID + session_type: SessionTypes + session_name: str + cluster_mode: ClusterMode + cluster_size: int + domain_name: str + group_id: uuid.UUID + scaling_group: str + occupying_slots: ResourceSlot + + @classmethod + def db_cols(cls) -> Set[ColumnElement]: + return { + kernels.c.id, + kernels.c.status, + kernels.c.access_key, + kernels.c.session_id, + kernels.c.session_type, + kernels.c.session_name, + kernels.c.cluster_mode, + kernels.c.cluster_size, + kernels.c.cluster_role, + kernels.c.domain_name, + kernels.c.group_id, + kernels.c.scaling_group, + kernels.c.occupied_slots, + } + + @classmethod + def base_query(cls) -> Select: + return ( + sa.select( + list(cls.db_cols() | KernelInfo.db_cols()), + ) + .select_from(kernels) + .order_by(kernels.c.created_at) + ) + + @classmethod + def from_row(cls, row: Row) -> ExistingSession: + return ExistingSession( + kernels=[], + access_key=row['access_key'], + session_id=row['session_id'], + session_type=row['session_type'], + session_name=row['session_name'], + cluster_mode=row['cluster_mode'], + cluster_size=row['cluster_size'], + domain_name=row['domain_name'], + group_id=row['group_id'], + scaling_group=row['scaling_group'], + occupying_slots=ResourceSlot(), + ) + + @classmethod + def from_rows(cls, rows: Sequence[Row]) -> List[ExistingSession]: + items: Dict[str, ExistingSession] = {} + for row in rows: + if row['cluster_role'] == "main": + items[row['session_id']] = cls.from_row(row) + for row in rows: + session_id = row['session_id'] + if session_id not in items: + # In some cases, sub containers are still RUNNING + # even though main container is TERMINATED. + # To circumvent this edge case, we skip if main container + # is not registered in `items`. + continue + session = items[session_id] + session.kernels.append(KernelInfo.from_row(row)) + session.occupying_slots += row['occupied_slots'] # type: ignore + return list(items.values()) + + +@attr.s(auto_attribs=True, slots=True) +class PendingSession: + """ + Context for individual session-related information used during scheduling. + Resource parameters defined here should contain total amount of resources + for all kernels in one session. 
+ """ + kernels: List[KernelInfo] + access_key: AccessKey + agent_id: AgentId + agent_addr: str + session_id: SessionId + session_creation_id: str + session_type: SessionTypes + session_name: str + cluster_mode: ClusterMode + cluster_size: int + domain_name: str + group_id: uuid.UUID + status_data: Mapping[str, Any] + scaling_group: str + resource_policy: str + resource_opts: Mapping[str, Any] + requested_slots: ResourceSlot + target_sgroup_names: MutableSequence[str] + environ: MutableMapping[str, str] + vfolder_mounts: Sequence[VFolderMount] + bootstrap_script: Optional[str] + startup_command: Optional[str] + internal_data: Optional[MutableMapping[str, Any]] + preopen_ports: List[int] + created_at: datetime + + @property + def main_kernel_id(self) -> KernelId: + for k in self.kernels: + if k.cluster_role == DEFAULT_ROLE: + return k.kernel_id + raise RuntimeError('Unable to get the main kernel ID') + + @classmethod + def db_cols(cls) -> Set[ColumnElement]: + return { + kernels.c.id, + kernels.c.access_key, + kernels.c.agent, + kernels.c.agent_addr, + kernels.c.session_creation_id, + kernels.c.session_id, + kernels.c.session_type, + kernels.c.session_name, + kernels.c.cluster_mode, + kernels.c.cluster_size, + kernels.c.domain_name, + kernels.c.group_id, + kernels.c.status_data, + kernels.c.scaling_group, + keypairs.c.resource_policy, + kernels.c.occupied_slots, + kernels.c.internal_data, + kernels.c.resource_opts, + kernels.c.environ, + kernels.c.vfolder_mounts, + kernels.c.bootstrap_script, + kernels.c.startup_command, + kernels.c.preopen_ports, + kernels.c.created_at, + } + + @classmethod + def base_query(cls) -> Select: + return ( + sa.select( + list(cls.db_cols() | KernelInfo.db_cols()), + ) + .select_from(sa.join( + kernels, keypairs, + keypairs.c.access_key == kernels.c.access_key, + )) + .order_by(kernels.c.created_at) + ) + + @classmethod + def from_row(cls, row: Row) -> PendingSession: + return cls( + kernels=[], + access_key=row['access_key'], + agent_id=row['agent'], + agent_addr=row['agent_addr'], + session_creation_id=row['session_creation_id'], + session_id=row['session_id'], + session_type=row['session_type'], + session_name=row['session_name'], + cluster_mode=row['cluster_mode'], + cluster_size=row['cluster_size'], + domain_name=row['domain_name'], + group_id=row['group_id'], + status_data=row['status_data'], + scaling_group=row['scaling_group'], + resource_policy=row['resource_policy'], + resource_opts={}, + requested_slots=ResourceSlot(), + internal_data=row['internal_data'], + target_sgroup_names=[], + environ={ + k: v for k, v + in map(lambda s: s.split('=', maxsplit=1), row['environ']) + }, + vfolder_mounts=row['vfolder_mounts'], + bootstrap_script=row['bootstrap_script'], + startup_command=row['startup_command'], + preopen_ports=row['preopen_ports'], + created_at=row['created_at'], + ) + + @classmethod + def from_rows(cls, rows: Sequence[Row]) -> List[PendingSession]: + items: Dict[SessionId, PendingSession] = {} + for row in rows: + if row['cluster_role'] == "main": + items[row['session_id']] = cls.from_row(row) + for row in rows: + session = items[row['session_id']] + session.kernels.append(KernelInfo.from_row(row)) + session.requested_slots += row['occupied_slots'] # type: ignore + merge_resource(session.resource_opts, row['resource_opts']) # type: ignore + return list(items.values()) + + +@attr.s(auto_attribs=True, slots=True) +class KernelInfo: + """ + Representing invididual kernel info. 
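A side note on `PendingSession.from_row()` above: the `kernels.environ` column stores a list of `KEY=VALUE` strings, and the comprehension splits each entry only on the first `=` so that values containing `=` survive intact. A small illustration with made-up values:

```python
# How PendingSession.from_row() above converts the stored environ list into a
# mapping; the entries below are made-up examples.
raw_environ = ["LANG=C.UTF-8", "PYTHONUNBUFFERED=1", "EXTRA_OPTS=a=b"]
environ = {k: v for k, v in map(lambda s: s.split('=', maxsplit=1), raw_environ)}
assert environ == {"LANG": "C.UTF-8", "PYTHONUNBUFFERED": "1", "EXTRA_OPTS": "a=b"}
```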
+ Resource parameters defined here should contain single value of resource + for each kernel. + """ + kernel_id: KernelId + session_id: SessionId + access_key: AccessKey + agent_id: AgentId + agent_addr: str + cluster_role: str + cluster_idx: int + cluster_hostname: str + image_ref: ImageRef + resource_opts: Mapping[str, Any] + requested_slots: ResourceSlot + bootstrap_script: Optional[str] + startup_command: Optional[str] + created_at: datetime + + def __str__(self): + return f'{self.kernel_id}#{self.cluster_role}{self.cluster_idx}' + + @classmethod + def db_cols(cls) -> Set[ColumnElement]: + return { + kernels.c.id, + kernels.c.session_id, + kernels.c.access_key, + kernels.c.agent, # for scheduled kernels + kernels.c.agent_addr, # for scheduled kernels + kernels.c.cluster_role, + kernels.c.cluster_idx, + kernels.c.cluster_hostname, + kernels.c.image, + kernels.c.architecture, + kernels.c.registry, + kernels.c.resource_opts, + kernels.c.occupied_slots, + kernels.c.bootstrap_script, + kernels.c.startup_command, + kernels.c.created_at, + } + + @classmethod + def from_row(cls, row: Row) -> KernelInfo: + return cls( + kernel_id=row['id'], + session_id=row['session_id'], + access_key=row['access_key'], + agent_id=row['agent'], + agent_addr=row['agent_addr'], + cluster_role=row['cluster_role'], + cluster_idx=row['cluster_idx'], + cluster_hostname=row['cluster_hostname'], + image_ref=ImageRef(row['image'], [row['registry']], row['architecture']), + resource_opts=row['resource_opts'], + requested_slots=row['occupied_slots'], + bootstrap_script=row['bootstrap_script'], + startup_command=row['startup_command'], + created_at=row['created_at'], + ) + + +@attr.s(auto_attribs=True, slots=True) +class KernelAgentBinding: + kernel: KernelInfo + agent_alloc_ctx: AgentAllocationContext + + +@attr.s(auto_attribs=True, slots=True) +class PredicateResult: + passed: bool + message: Optional[str] = None + permanent: bool = False + + +class SchedulingPredicate(Protocol): + async def __call__( + self, + db_conn: SAConnection, + sched_ctx: SchedulingContext, + sess_ctx: PendingSession, + ) -> PredicateResult: + ... + + +class AbstractScheduler(metaclass=ABCMeta): + + """ + Interface for scheduling algorithms where the + ``schedule()`` method is a pure function. + """ + + sgroup_opts: ScalingGroupOpts # sgroup-specific config + config: Mapping[str, Any] # scheduler-specific config + config_iv: t.Dict + + def __init__(self, sgroup_opts: ScalingGroupOpts, config: Mapping[str, Any]) -> None: + self.sgroup_opts = sgroup_opts + self.config = self.config_iv.check(config) + + @abstractmethod + def pick_session( + self, + total_capacity: ResourceSlot, + pending_sessions: Sequence[PendingSession], + existing_sessions: Sequence[ExistingSession], + ) -> Optional[SessionId]: + """ + Pick a session to try schedule. + This is where the queueing semantics is implemented such as prioritization. + """ + return None + + @abstractmethod + def assign_agent_for_session( + self, + possible_agents: Sequence[AgentContext], + pending_session: PendingSession, + ) -> Optional[AgentId]: + """ + Assign an agent for the entire session, only considering the total requested + slots of the session. This is used for both single-container sessions and + single-node multi-container sessions. + + In single-node multi-container sessions, all sub-containers are spawned by + slicing the assigned agent's resource. 
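To make the scheduler contract concrete, here is a minimal FIFO-style sketch that conforms to this interface. It is not the scheduler shipped in this patch, and it assumes `ResourceSlot` supports element-wise subtraction and comparison as provided by `ai.backend.common.types`:

```python
# Illustrative sketch only: a toy FIFO scheduler conforming to the
# AbstractScheduler interface defined here. It is NOT the implementation
# shipped in this patch, and it assumes ResourceSlot's element-wise
# subtraction/comparison from ai.backend.common.types.
from typing import Optional, Sequence

import trafaret as t

from ai.backend.common.types import AgentId, ResourceSlot, SessionId


class ToyFIFOScheduler(AbstractScheduler):
    config_iv = t.Dict({}).allow_extra('*')

    def pick_session(
        self,
        total_capacity: ResourceSlot,
        pending_sessions: Sequence[PendingSession],
        existing_sessions: Sequence[ExistingSession],
    ) -> Optional[SessionId]:
        # Plain FIFO: always try the oldest pending session first.
        return pending_sessions[0].session_id if pending_sessions else None

    def assign_agent_for_session(
        self,
        possible_agents: Sequence[AgentContext],
        pending_session: PendingSession,
    ) -> Optional[AgentId]:
        # Pick the first agent whose unused slots can hold the whole session.
        for agent in possible_agents:
            unused = agent.available_slots - agent.occupied_slots
            if pending_session.requested_slots <= unused:
                return agent.agent_id
        return None

    def assign_agent_for_kernel(
        self,
        possible_agents: Sequence[AgentContext],
        pending_kernel: KernelInfo,
    ) -> Optional[AgentId]:
        # Same policy, applied per kernel for multi-node sessions.
        for agent in possible_agents:
            unused = agent.available_slots - agent.occupied_slots
            if pending_kernel.requested_slots <= unused:
                return agent.agent_id
        return None
```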
+ """ + return None + + @abstractmethod + def assign_agent_for_kernel( + self, + possible_agents: Sequence[AgentContext], + pending_kernel: KernelInfo, + ) -> Optional[AgentId]: + """ + Assign an agent for a kernel of the session. + This may be called multiple times for multi-node multi-container sessions. + """ + return None diff --git a/src/ai/backend/manager/server.py b/src/ai/backend/manager/server.py new file mode 100644 index 0000000000..0a3d35c249 --- /dev/null +++ b/src/ai/backend/manager/server.py @@ -0,0 +1,759 @@ +from __future__ import annotations + +import asyncio +from contextlib import asynccontextmanager as actxmgr, closing +from datetime import datetime +import functools +import importlib +import logging +import os +import pwd, grp +import ssl +import sys +import traceback +from typing import ( + Any, + AsyncIterator, + Final, + Iterable, + List, + Mapping, + MutableMapping, + Sequence, + cast, +) + +from aiohttp import web +import aiohttp_cors +import aiotools +import click +from pathlib import Path +from setproctitle import setproctitle +import aiomonitor + +from ai.backend.common import redis +from ai.backend.common.bgtask import BackgroundTaskManager +from ai.backend.common.cli import LazyGroup +from ai.backend.common.events import EventDispatcher, EventProducer +from ai.backend.common.utils import env_info +from ai.backend.common.logging import Logger, BraceStyleAdapter +from ai.backend.common.plugin.hook import HookPluginContext, ALL_COMPLETED, PASSED +from ai.backend.common.plugin.monitor import ( + ErrorPluginContext, + StatsPluginContext, + INCREMENT, +) + +from . import __version__ +from .api.context import RootContext +from .api.exceptions import ( + BackendError, + MethodNotAllowed, + URLNotFound, + GenericBadRequest, + InternalServerError, + InvalidAPIParameters, +) +from .api.manager import ManagerStatus +from .api.types import ( + AppCreator, + WebRequestHandler, WebMiddleware, + CleanupContext, +) +from .config import ( + LocalConfig, + SharedConfig, + load as load_config, + volume_config_iv, +) +from .defs import REDIS_STAT_DB, REDIS_LIVE_DB, REDIS_IMAGE_DB, REDIS_STREAM_DB +from .exceptions import InvalidArgument +from .idle import init_idle_checkers +from .models.storage import StorageSessionManager +from .models.utils import connect_database +from .plugin.webapp import WebappPluginContext +from .registry import AgentRegistry +from .scheduler.dispatcher import SchedulerDispatcher +from .types import DistributedLockFactory + +VALID_VERSIONS: Final = frozenset([ + # 'v1.20160915', # deprecated + # 'v2.20170315', # deprecated + # 'v3.20170615', # deprecated + + # authentication changed not to use request bodies + 'v4.20181215', + + # added & enabled streaming-execute API + 'v4.20190115', + + # changed resource/image formats + 'v4.20190315', + + # added user mgmt and ID/password authentication + # added domain/group/scaling-group + # added domain/group/scaling-group ref. 
fields to user/keypair/vfolder objects + 'v4.20190615', + + # added mount_map parameter when creating kernel + # changed GraphQL query structures for multi-container bundled sessions + 'v5.20191215', + + # rewrote vfolder upload/download APIs to migrate to external storage proxies + 'v6.20200815', + + # added standard-compliant /admin/gql endpoint + # deprecated /admin/graphql endpoint (still present for backward compatibility) + # added "groups_by_name" GQL query + # added "filter" and "order" arg to all paginated GQL queries with their own expression mini-langs + # removed "order_key" and "order_asc" arguments from all paginated GQL queries (never used!) + 'v6.20210815', + + # added session dependencies and state callback URLs configs when creating sessions + # added session event webhook option to session creation API + # added architecture option when making image aliases + 'v6.20220315', +]) +LATEST_REV_DATES: Final = { + 1: '20160915', + 2: '20170915', + 3: '20181215', + 4: '20190615', + 5: '20191215', + 6: '20220315', +} +LATEST_API_VERSION: Final = 'v6.20220315' + +log = BraceStyleAdapter(logging.getLogger(__name__)) + +PUBLIC_INTERFACES: Final = [ + 'pidx', + 'background_task_manager', + 'local_config', + 'shared_config', + 'db', + 'registry', + 'redis_live', + 'redis_stat', + 'redis_image', + 'redis_stream', + 'event_dispatcher', + 'event_producer', + 'idle_checkers', + 'storage_manager', + 'stats_monitor', + 'error_monitor', + 'hook_plugin_ctx', +] + +public_interface_objs: MutableMapping[str, Any] = {} + + +async def hello(request: web.Request) -> web.Response: + """ + Returns the API version number. + """ + return web.json_response({ + 'version': LATEST_API_VERSION, + 'manager': __version__, + }) + + +async def on_prepare(request: web.Request, response: web.StreamResponse) -> None: + response.headers['Server'] = 'BackendAI' + + +@web.middleware +async def api_middleware(request: web.Request, + handler: WebRequestHandler) -> web.StreamResponse: + _handler = handler + method_override = request.headers.get('X-Method-Override', None) + if method_override: + request = request.clone(method=method_override) + new_match_info = await request.app.router.resolve(request) + if new_match_info is None: + raise InternalServerError('No matching method handler found') + _handler = new_match_info.handler + request._match_info = new_match_info # type: ignore # this is a hack + ex = request.match_info.http_exception + if ex is not None: + # handled by exception_middleware + raise ex + new_api_version = request.headers.get('X-BackendAI-Version') + legacy_api_version = request.headers.get('X-Sorna-Version') + api_version = new_api_version or legacy_api_version + try: + if api_version is None: + path_major_version = int(request.match_info.get('version', 5)) + revision_date = LATEST_REV_DATES[path_major_version] + request['api_version'] = (path_major_version, revision_date) + elif api_version in VALID_VERSIONS: + hdr_major_version, revision_date = api_version.split('.', maxsplit=1) + request['api_version'] = (int(hdr_major_version[1:]), revision_date) + else: + return GenericBadRequest('Unsupported API version.') + except (ValueError, KeyError): + return GenericBadRequest('Unsupported API version.') + resp = (await _handler(request)) + return resp + + +@web.middleware +async def exception_middleware(request: web.Request, + handler: WebRequestHandler) -> web.StreamResponse: + root_ctx: RootContext = request.app['_root.context'] + error_monitor = root_ctx.error_monitor + stats_monitor = 
root_ctx.stats_monitor + try: + await stats_monitor.report_metric(INCREMENT, 'ai.backend.manager.api.requests') + resp = (await handler(request)) + except InvalidArgument as ex: + if len(ex.args) > 1: + raise InvalidAPIParameters(f"{ex.args[0]}: {', '.join(map(str, ex.args[1:]))}") + elif len(ex.args) == 1: + raise InvalidAPIParameters(ex.args[0]) + else: + raise InvalidAPIParameters() + except BackendError as ex: + if ex.status_code == 500: + log.warning('Internal server error raised inside handlers') + await error_monitor.capture_exception() + await stats_monitor.report_metric(INCREMENT, 'ai.backend.manager.api.failures') + await stats_monitor.report_metric(INCREMENT, f'ai.backend.manager.api.status.{ex.status_code}') + raise + except web.HTTPException as ex: + await stats_monitor.report_metric(INCREMENT, 'ai.backend.manager.api.failures') + await stats_monitor.report_metric(INCREMENT, f'ai.backend.manager.api.status.{ex.status_code}') + if ex.status_code == 404: + raise URLNotFound(extra_data=request.path) + if ex.status_code == 405: + concrete_ex = cast(web.HTTPMethodNotAllowed, ex) + raise MethodNotAllowed(concrete_ex.method, concrete_ex.allowed_methods) + log.warning('Bad request: {0!r}', ex) + raise GenericBadRequest + except asyncio.CancelledError as e: + # The server is closing or the client has disconnected in the middle of + # request. Atomic requests are still executed to their ends. + log.debug('Request cancelled ({0} {1})', request.method, request.rel_url) + raise e + except Exception as e: + await error_monitor.capture_exception() + log.exception('Uncaught exception in HTTP request handlers {0!r}', e) + if root_ctx.local_config['debug']['enabled']: + raise InternalServerError(traceback.format_exc()) + else: + raise InternalServerError() + else: + await stats_monitor.report_metric(INCREMENT, f'ai.backend.manager.api.status.{resp.status}') + return resp + + +@actxmgr +async def shared_config_ctx(root_ctx: RootContext) -> AsyncIterator[None]: + # populate public interfaces + root_ctx.shared_config = SharedConfig( + root_ctx.local_config['etcd']['addr'], + root_ctx.local_config['etcd']['user'], + root_ctx.local_config['etcd']['password'], + root_ctx.local_config['etcd']['namespace'], + ) + await root_ctx.shared_config.reload() + yield + await root_ctx.shared_config.close() + + +@actxmgr +async def webapp_plugin_ctx(root_app: web.Application) -> AsyncIterator[None]: + root_ctx: RootContext = root_app['_root.context'] + plugin_ctx = WebappPluginContext(root_ctx.shared_config.etcd, root_ctx.local_config) + await plugin_ctx.init() + root_ctx.webapp_plugin_ctx = plugin_ctx + for plugin_name, plugin_instance in plugin_ctx.plugins.items(): + if root_ctx.pidx == 0: + log.info('Loading webapp plugin: {0}', plugin_name) + subapp, global_middlewares = await plugin_instance.create_app(root_ctx.cors_options) + _init_subapp(plugin_name, root_app, subapp, global_middlewares) + yield + await plugin_ctx.cleanup() + + +@actxmgr +async def manager_status_ctx(root_ctx: RootContext) -> AsyncIterator[None]: + if root_ctx.pidx == 0: + mgr_status = await root_ctx.shared_config.get_manager_status() + if mgr_status is None or mgr_status not in (ManagerStatus.RUNNING, ManagerStatus.FROZEN): + # legacy transition: we now have only RUNNING or FROZEN for HA setup. 
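As a usage note for the request path wired up above: clients pin an API revision with the `X-BackendAI-Version` header, and the root `hello()` handler reports the latest supported revision. A hedged client-side sketch follows; only the header name and response keys come from the code, while the manager address and port are placeholders:

```python
# Hedged client-side sketch of the API version negotiation handled by
# api_middleware and hello() above. Only the header name and response keys
# come from this patch; the address/port are placeholders.
import asyncio

import aiohttp


async def check_manager_version() -> None:
    async with aiohttp.ClientSession() as http:
        async with http.get(
            "http://127.0.0.1:8081/",
            headers={"X-BackendAI-Version": "v6.20220315"},
        ) as resp:
            data = await resp.json()
            print(data["version"], data["manager"])


if __name__ == "__main__":
    asyncio.run(check_manager_version())
```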
+ await root_ctx.shared_config.update_manager_status(ManagerStatus.RUNNING) + mgr_status = ManagerStatus.RUNNING + log.info('Manager status: {}', mgr_status) + tz = root_ctx.shared_config['system']['timezone'] + log.info('Configured timezone: {}', tz.tzname(datetime.now())) + yield + + +@actxmgr +async def redis_ctx(root_ctx: RootContext) -> AsyncIterator[None]: + + root_ctx.redis_live = redis.get_redis_object(root_ctx.shared_config.data['redis'], db=REDIS_LIVE_DB) + root_ctx.redis_stat = redis.get_redis_object(root_ctx.shared_config.data['redis'], db=REDIS_STAT_DB) + root_ctx.redis_image = redis.get_redis_object( + root_ctx.shared_config.data['redis'], db=REDIS_IMAGE_DB, + ) + root_ctx.redis_stream = redis.get_redis_object( + root_ctx.shared_config.data['redis'], db=REDIS_STREAM_DB, + ) + yield + await root_ctx.redis_stream.close() + await root_ctx.redis_image.close() + await root_ctx.redis_stat.close() + await root_ctx.redis_live.close() + + +@actxmgr +async def database_ctx(root_ctx: RootContext) -> AsyncIterator[None]: + async with connect_database(root_ctx.local_config) as db: + root_ctx.db = db + yield + + +@actxmgr +async def distributed_lock_ctx(root_ctx: RootContext) -> AsyncIterator[None]: + root_ctx.distributed_lock_factory = init_lock_factory(root_ctx) + yield + + +@actxmgr +async def event_dispatcher_ctx(root_ctx: RootContext) -> AsyncIterator[None]: + root_ctx.event_producer = await EventProducer.new( + root_ctx.shared_config.data['redis'], + db=REDIS_STREAM_DB, + ) + root_ctx.event_dispatcher = await EventDispatcher.new( + root_ctx.shared_config.data['redis'], + db=REDIS_STREAM_DB, + log_events=root_ctx.local_config['debug']['log-events'], + node_id=root_ctx.local_config['manager']['id'], + ) + yield + await root_ctx.event_producer.close() + await asyncio.sleep(0.2) + await root_ctx.event_dispatcher.close() + + +@actxmgr +async def idle_checker_ctx(root_ctx: RootContext) -> AsyncIterator[None]: + root_ctx.idle_checker_host = await init_idle_checkers( + root_ctx.db, + root_ctx.shared_config, + root_ctx.event_dispatcher, + root_ctx.event_producer, + root_ctx.distributed_lock_factory, + ) + await root_ctx.idle_checker_host.start() + yield + await root_ctx.idle_checker_host.shutdown() + + +@actxmgr +async def storage_manager_ctx(root_ctx: RootContext) -> AsyncIterator[None]: + raw_vol_config = await root_ctx.shared_config.etcd.get_prefix('volumes') + config = volume_config_iv.check(raw_vol_config) + root_ctx.storage_manager = StorageSessionManager(config) + yield + await root_ctx.storage_manager.aclose() + + +@actxmgr +async def hook_plugin_ctx(root_ctx: RootContext) -> AsyncIterator[None]: + ctx = HookPluginContext(root_ctx.shared_config.etcd, root_ctx.local_config) + root_ctx.hook_plugin_ctx = ctx + await ctx.init() + hook_result = await ctx.dispatch( + 'ACTIVATE_MANAGER', + (), + return_when=ALL_COMPLETED, + ) + if hook_result.status != PASSED: + raise RuntimeError('Could not activate the manager instance.') + yield + await ctx.cleanup() + + +@actxmgr +async def agent_registry_ctx(root_ctx: RootContext) -> AsyncIterator[None]: + root_ctx.registry = AgentRegistry( + root_ctx.shared_config, + root_ctx.db, + root_ctx.redis_stat, + root_ctx.redis_live, + root_ctx.redis_image, + root_ctx.event_dispatcher, + root_ctx.event_producer, + root_ctx.storage_manager, + root_ctx.hook_plugin_ctx, + ) + await root_ctx.registry.init() + yield + await root_ctx.registry.shutdown() + + +@actxmgr +async def sched_dispatcher_ctx(root_ctx: RootContext) -> AsyncIterator[None]: + sched_dispatcher = 
await SchedulerDispatcher.new( + root_ctx.local_config, root_ctx.shared_config, + root_ctx.event_dispatcher, root_ctx.event_producer, + root_ctx.distributed_lock_factory, + root_ctx.registry, + ) + yield + await sched_dispatcher.close() + + +@actxmgr +async def monitoring_ctx(root_ctx: RootContext) -> AsyncIterator[None]: + ectx = ErrorPluginContext(root_ctx.shared_config.etcd, root_ctx.local_config) + sctx = StatsPluginContext(root_ctx.shared_config.etcd, root_ctx.local_config) + await ectx.init(context={'_root.context': root_ctx}) + await sctx.init() + root_ctx.error_monitor = ectx + root_ctx.stats_monitor = sctx + yield + await sctx.cleanup() + await ectx.cleanup() + + +class background_task_ctx: + + def __init__(self, root_ctx: RootContext) -> None: + self.root_ctx = root_ctx + + async def __aenter__(self) -> None: + self.root_ctx.background_task_manager = BackgroundTaskManager(self.root_ctx.event_producer) + + async def __aexit__(self, *exc_info) -> None: + pass + + async def shutdown(self) -> None: + if hasattr(self.root_ctx, 'background_task_manager'): + await self.root_ctx.background_task_manager.shutdown() + + +def handle_loop_error( + root_ctx: RootContext, + loop: asyncio.AbstractEventLoop, + context: Mapping[str, Any], +) -> None: + exception = context.get('exception') + msg = context.get('message', '(empty message)') + if exception is not None: + if sys.exc_info()[0] is not None: + log.exception('Error inside event loop: {0}', msg) + if (error_monitor := getattr(root_ctx, 'error_monitor', None)) is not None: + loop.create_task(error_monitor.capture_exception()) + else: + exc_info = (type(exception), exception, exception.__traceback__) + log.error('Error inside event loop: {0}', msg, exc_info=exc_info) + if (error_monitor := getattr(root_ctx, 'error_monitor', None)) is not None: + loop.create_task(error_monitor.capture_exception(exc_instance=exception)) + + +def _init_subapp( + pkg_name: str, + root_app: web.Application, + subapp: web.Application, + global_middlewares: Iterable[WebMiddleware], +) -> None: + subapp.on_response_prepare.append(on_prepare) + + async def _set_root_ctx(subapp: web.Application): + # Allow subapp's access to the root app properties. + # These are the public APIs exposed to plugins as well. + subapp['_root.context'] = root_app['_root.context'] + + # We must copy the public interface prior to all user-defined startup signal handlers. 
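For context on `_init_subapp()` above: each API module (and webapp plugin) exposes a `create_app()` function that returns an aiohttp sub-application plus any global middlewares, and the copied `'_root.context'` key is how handlers reach the shared manager state. A minimal, hypothetical subapp sketch:

```python
# Hedged sketch of a minimal subapp module of the kind mounted by
# _init_subapp()/init_subapp(). The route, handler, and "status" prefix are
# illustrative only; the (app, middlewares) return shape and the
# '_root.context' key follow the code above.
from typing import Iterable, Tuple

from aiohttp import web


async def status(request: web.Request) -> web.Response:
    # '_root.context' is copied in by _set_root_ctx() before user startup handlers.
    root_ctx = request.app['_root.context']
    return web.json_response({'pidx': root_ctx.pidx})


def create_app(cors_options) -> Tuple[web.Application, Iterable]:
    app = web.Application()
    app['prefix'] = 'status'  # mounted as /status by _init_subapp()
    app.router.add_route('GET', '/', status)
    return app, []
```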
+ subapp.on_startup.insert(0, _set_root_ctx) + prefix = subapp.get('prefix', pkg_name.split('.')[-1].replace('_', '-')) + root_app.add_subapp('/' + prefix, subapp) + root_app.middlewares.extend(global_middlewares) + + +def init_subapp(pkg_name: str, root_app: web.Application, create_subapp: AppCreator) -> None: + root_ctx: RootContext = root_app['_root.context'] + subapp, global_middlewares = create_subapp(root_ctx.cors_options) + _init_subapp(pkg_name, root_app, subapp, global_middlewares) + + +def init_lock_factory(root_ctx: RootContext) -> DistributedLockFactory: + ipc_base_path = root_ctx.local_config['manager']['ipc-base-path'] + manager_id = root_ctx.local_config['manager']['id'] + lock_backend = root_ctx.local_config['manager']['distributed-lock'] + log.debug("using {} as the distributed lock backend", lock_backend) + match lock_backend: + case 'filelock': + from ai.backend.common.lock import FileLock + return lambda lock_id, lifetime_hint: FileLock( + ipc_base_path / f"{manager_id}.{lock_id}.lock", + timeout=0, + ) + case 'pg_advisory': + from .pglock import PgAdvisoryLock + return lambda lock_id, lifetime_hint: PgAdvisoryLock(root_ctx.db, lock_id) + case 'redlock': + raise NotImplementedError("Redlock on aioredis v2 is not supported yet.") + case 'etcd': + from ai.backend.common.lock import EtcdLock + return lambda lock_id, lifetime_hint: EtcdLock( + str(lock_id), + root_ctx.shared_config.etcd, + lifetime=min(lifetime_hint * 2, lifetime_hint + 30), + ) + case other: + raise ValueError(f"Invalid lock backend: {other}") + + +def build_root_app( + pidx: int, + local_config: LocalConfig, *, + cleanup_contexts: Sequence[CleanupContext] = None, + subapp_pkgs: Sequence[str] = None, + scheduler_opts: Mapping[str, Any] = None, +) -> web.Application: + public_interface_objs.clear() + app = web.Application(middlewares=[ + exception_middleware, + api_middleware, + ]) + root_ctx = RootContext() + global_exception_handler = functools.partial(handle_loop_error, root_ctx) + loop = asyncio.get_running_loop() + loop.set_exception_handler(global_exception_handler) + app['_root.context'] = root_ctx + root_ctx.local_config = local_config + root_ctx.pidx = pidx + root_ctx.cors_options = { + '*': aiohttp_cors.ResourceOptions( + allow_credentials=False, + expose_headers="*", allow_headers="*"), + } + default_scheduler_opts = { + 'limit': 2048, + 'close_timeout': 30, + 'exception_handler': global_exception_handler, + } + app['scheduler_opts'] = { + **default_scheduler_opts, + **(scheduler_opts if scheduler_opts is not None else {}), + } + app.on_response_prepare.append(on_prepare) + + if cleanup_contexts is None: + cleanup_contexts = [ + manager_status_ctx, + redis_ctx, + database_ctx, + distributed_lock_ctx, + event_dispatcher_ctx, + idle_checker_ctx, + storage_manager_ctx, + hook_plugin_ctx, + monitoring_ctx, + agent_registry_ctx, + sched_dispatcher_ctx, + background_task_ctx, + ] + + async def _cleanup_context_wrapper(cctx, app: web.Application) -> AsyncIterator[None]: + # aiohttp's cleanup contexts are just async generators, not async context managers. 
+ cctx_instance = cctx(app['_root.context']) + app['_cctx_instances'].append(cctx_instance) + try: + async with cctx_instance: + yield + except Exception as e: + exc_info = (type(e), e, e.__traceback__) + log.error('Error initializing cleanup_contexts: {0}', cctx.__name__, exc_info=exc_info) + + async def _call_cleanup_context_shutdown_handlers(app: web.Application) -> None: + for cctx in app['_cctx_instances']: + if hasattr(cctx, 'shutdown'): + try: + await cctx.shutdown() + except Exception: + log.exception("error while shutting down a cleanup context") + + app['_cctx_instances'] = [] + app.on_shutdown.append(_call_cleanup_context_shutdown_handlers) + for cleanup_ctx in cleanup_contexts: + app.cleanup_ctx.append( + functools.partial(_cleanup_context_wrapper, cleanup_ctx), + ) + cors = aiohttp_cors.setup(app, defaults=root_ctx.cors_options) + # should be done in create_app() in other modules. + cors.add(app.router.add_route('GET', r'', hello)) + cors.add(app.router.add_route('GET', r'/', hello)) + if subapp_pkgs is None: + subapp_pkgs = [] + for pkg_name in subapp_pkgs: + if pidx == 0: + log.info('Loading module: {0}', pkg_name[1:]) + subapp_mod = importlib.import_module(pkg_name, 'ai.backend.manager.api') + init_subapp(pkg_name, app, getattr(subapp_mod, 'create_app')) + return app + + +@actxmgr +async def server_main( + loop: asyncio.AbstractEventLoop, + pidx: int, + _args: List[Any], +) -> AsyncIterator[None]: + subapp_pkgs = [ + '.etcd', '.events', + '.auth', '.ratelimit', + '.vfolder', '.admin', + '.session', + '.stream', + '.manager', + '.resource', + '.scaling_group', + '.cluster_template', + '.session_template', + '.image', + '.userconfig', + '.domainconfig', + '.groupconfig', + '.logs', + ] + root_app = build_root_app(pidx, _args[0], subapp_pkgs=subapp_pkgs) + root_ctx: RootContext = root_app['_root.context'] + + # Start aiomonitor. + # Port is set by config (default=50001). + m = aiomonitor.Monitor( + loop, + port=root_ctx.local_config['manager']['aiomonitor-port'] + pidx, + console_enabled=False, + ) + m.prompt = f"monitor (manager[{pidx}@{os.getpid()}]) >>> " + m.start() + + # Plugin webapps should be loaded before runner.setup(), + # which freezes on_startup event. 
+ with closing(m): + async with ( + shared_config_ctx(root_ctx), + webapp_plugin_ctx(root_app), + ): + ssl_ctx = None + if root_ctx.local_config['manager']['ssl-enabled']: + ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + ssl_ctx.load_cert_chain( + str(root_ctx.local_config['manager']['ssl-cert']), + str(root_ctx.local_config['manager']['ssl-privkey']), + ) + + runner = web.AppRunner(root_app, keepalive_timeout=30.0) + await runner.setup() + service_addr = root_ctx.local_config['manager']['service-addr'] + site = web.TCPSite( + runner, + str(service_addr.host), + service_addr.port, + backlog=1024, + reuse_port=True, + ssl_context=ssl_ctx, + ) + await site.start() + + if os.geteuid() == 0: + uid = root_ctx.local_config['manager']['user'] + gid = root_ctx.local_config['manager']['group'] + os.setgroups([ + g.gr_gid for g in grp.getgrall() + if pwd.getpwuid(uid).pw_name in g.gr_mem + ]) + os.setgid(gid) + os.setuid(uid) + log.info('changed process uid and gid to {}:{}', uid, gid) + log.info('started handling API requests at {}', service_addr) + + try: + yield + finally: + log.info('shutting down...') + await runner.cleanup() + + +@actxmgr +async def server_main_logwrapper( + loop: asyncio.AbstractEventLoop, + pidx: int, + _args: List[Any], +) -> AsyncIterator[None]: + setproctitle(f"backend.ai: manager worker-{pidx}") + log_endpoint = _args[1] + logger = Logger(_args[0]['logging'], is_master=False, log_endpoint=log_endpoint) + try: + with logger: + async with server_main(loop, pidx, _args): + yield + except Exception: + traceback.print_exc() + + +@click.group(invoke_without_command=True) +@click.option('-f', '--config-path', '--config', type=Path, default=None, + help='The config file path. (default: ./manager.toml and /etc/backend.ai/manager.toml)') +@click.option('--debug', is_flag=True, + help='Enable the debug mode and override the global log level to DEBUG.') +@click.pass_context +def main(ctx: click.Context, config_path: Path, debug: bool) -> None: + """ + Start the manager service as a foreground process. + """ + + cfg = load_config(config_path, debug) + + if ctx.invoked_subcommand is None: + cfg['manager']['pid-file'].write_text(str(os.getpid())) + ipc_base_path = cfg['manager']['ipc-base-path'] + log_sockpath = ipc_base_path / f'manager-logger-{os.getpid()}.sock' + log_endpoint = f'ipc://{log_sockpath}' + try: + logger = Logger(cfg['logging'], is_master=True, log_endpoint=log_endpoint) + with logger: + ns = cfg['etcd']['namespace'] + setproctitle(f"backend.ai: manager {ns}") + log.info('Backend.AI Manager {0}', __version__) + log.info('runtime: {0}', env_info()) + log_config = logging.getLogger('ai.backend.manager.config') + log_config.debug('debug mode enabled.') + if cfg['manager']['event-loop'] == 'uvloop': + import uvloop + uvloop.install() + log.info('Using uvloop as the event loop backend') + try: + aiotools.start_server( + server_main_logwrapper, + num_workers=cfg['manager']['num-proc'], + args=(cfg, log_endpoint), + wait_timeout=5.0, + ) + finally: + log.info('terminated.') + finally: + if cfg['manager']['pid-file'].is_file(): + # check is_file() to prevent deleting /dev/null! + cfg['manager']['pid-file'].unlink() + else: + # Click is going to invoke a subcommand. 
+ pass + + +@main.group(cls=LazyGroup, import_name='ai.backend.manager.api.auth:cli') +def auth() -> None: + pass + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/src/ai/backend/manager/types.py b/src/ai/backend/manager/types.py new file mode 100644 index 0000000000..7b5439115f --- /dev/null +++ b/src/ai/backend/manager/types.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +import attr +import enum +import uuid +from typing import ( + Protocol, + TYPE_CHECKING, +) + +from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection +from sqlalchemy.engine.row import Row + +if TYPE_CHECKING: + from ai.backend.common.lock import AbstractDistributedLock + from .defs import LockID + + +class SessionGetter(Protocol): + + def __call__(self, *, db_connection: SAConnection) -> Row: + ... + + +# Sentinel is a special object that indicates a special status instead of a value +# where the user expects a value. +# According to the discussion in https://github.com/python/typing/issues/236, +# we define our Sentinel type as an enum with only one special value. +# This enables passing of type checks by "value is sentinel" (or "value is Sentinel.token") +# instead of more expensive "isinstance(value, Sentinel)" because we can assure type checkers +# to think there is no other possible instances of the Sentinel type. + +class Sentinel(enum.Enum): + token = 0 + + +@attr.define(slots=True) +class UserScope: + domain_name: str + group_id: uuid.UUID + user_uuid: uuid.UUID + user_role: str + + +class DistributedLockFactory(Protocol): + + def __call__(self, lock_id: LockID, lifetime_hint: float) -> AbstractDistributedLock: + ... diff --git a/src/ai/backend/meta/BUILD b/src/ai/backend/meta/BUILD new file mode 100644 index 0000000000..4a902ca870 --- /dev/null +++ b/src/ai/backend/meta/BUILD @@ -0,0 +1,3 @@ +python_sources( + name="lib", +) diff --git a/src/ai/backend/meta/__init__.py b/src/ai/backend/meta/__init__.py index d5c87375cb..e69de29bb2 100644 --- a/src/ai/backend/meta/__init__.py +++ b/src/ai/backend/meta/__init__.py @@ -1,5 +0,0 @@ -""" -This is a meta-package which contains nothing yet. 
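The `DistributedLockFactory` protocol above is what `init_lock_factory()` in `server.py` returns; callers obtain a lock object per lock ID and hold it around work that must not overlap across manager processes. A hedged usage sketch, assuming the returned `AbstractDistributedLock` is used as an async context manager:

```python
# Hedged usage sketch of the DistributedLockFactory protocol above. The lock id
# and lifetime hint are placeholders; we assume AbstractDistributedLock is used
# as an async context manager.
async def run_exclusively(lock_factory: DistributedLockFactory, lock_id: 'LockID') -> None:
    async with lock_factory(lock_id, 60.0):
        ...  # critical section serialized across manager processes
```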
-""" - -__version__ = '22.03.0' diff --git a/src/ai/backend/plugin/BUILD b/src/ai/backend/plugin/BUILD new file mode 100644 index 0000000000..93e3ddc21c --- /dev/null +++ b/src/ai/backend/plugin/BUILD @@ -0,0 +1,34 @@ +python_sources( + name="lib", + sources=["**/*.py"], + dependencies=[ + ":resources", + ], +) + +python_distribution( + name="dist", + dependencies=[ + ":lib", + "!!stubs/trafaret:stubs", + ], + provides=python_artifact( + name="backend.ai-plugin", + description="Backend.AI Plugin Subsystem", + license="MIT", + ), + generate_setup=True, + tags=["wheel"], +) + +resource(name="version", source="VERSION") + +resources( + name="resources", + dependencies=[ + ":version", + ], + sources=[ + "**/py.typed", + ], +) diff --git a/src/ai/backend/plugin/README.md b/src/ai/backend/plugin/README.md new file mode 100644 index 0000000000..cc57c434be --- /dev/null +++ b/src/ai/backend/plugin/README.md @@ -0,0 +1,7 @@ +Backend.AI Plugin Subsystem +=========================== + +Package Structure +----------------- + +* `ai.backend.plugin`: Abstract types for plugins and a common base plugin set diff --git a/src/ai/backend/plugin/VERSION b/src/ai/backend/plugin/VERSION new file mode 120000 index 0000000000..a4e948506b --- /dev/null +++ b/src/ai/backend/plugin/VERSION @@ -0,0 +1 @@ +../../../../VERSION \ No newline at end of file diff --git a/src/ai/backend/plugin/__init__.py b/src/ai/backend/plugin/__init__.py new file mode 100644 index 0000000000..17b3552989 --- /dev/null +++ b/src/ai/backend/plugin/__init__.py @@ -0,0 +1,3 @@ +from pathlib import Path + +__version__ = (Path(__file__).parent / 'VERSION').read_text().strip() diff --git a/src/ai/backend/plugin/entrypoint.py b/src/ai/backend/plugin/entrypoint.py new file mode 100644 index 0000000000..ff141bfe3e --- /dev/null +++ b/src/ai/backend/plugin/entrypoint.py @@ -0,0 +1,99 @@ +import ast +import itertools +import logging +from importlib.metadata import EntryPoint, entry_points +from pathlib import Path +from typing import Container, Iterator, Optional + +log = logging.getLogger(__name__) + + +def scan_entrypoints( + group_name: str, + blocklist: Container[str] = None, +) -> Iterator[EntryPoint]: + if blocklist is None: + blocklist = set() + existing_names: dict[str, EntryPoint] = {} + for entrypoint in itertools.chain( + scan_entrypoint_from_buildscript(group_name), + scan_entrypoint_from_package_metadata(group_name), + ): + if entrypoint.name in blocklist: + continue + if existing_entrypoint := existing_names.get(entrypoint.name): + raise RuntimeError( + f"Detected a duplicate plugin entrypoint name {entrypoint.name!r} " + f"from {existing_entrypoint.value} and {entrypoint.value}", + ) + existing_names[entrypoint.name] = entrypoint + yield entrypoint + + +def scan_entrypoint_from_package_metadata(group_name: str) -> Iterator[EntryPoint]: + yield from entry_points().select(group=group_name) + + +def scan_entrypoint_from_buildscript(group_name: str) -> Iterator[EntryPoint]: + entrypoints = {} + # Scan self-exported entrypoints when executed via pex. 
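For illustration, the kind of BUILD-file fragment this scanner understands looks like the sketch below. Pants BUILD files use Python syntax; the distribution name, entrypoint group, and target shown here are made up, but the `entry_points` keyword shape matches what `extract_entrypoints_from_buildscript()` (defined just below) parses:

```python
# Hypothetical BUILD-file fragment; only the python_distribution(entry_points=...)
# shape is what the scanner looks for. Names, groups, and targets are examples.
python_distribution(
    name="dist",
    provides=python_artifact(name="backend.ai-example"),
    entry_points={
        "backendai_cli_v10": {
            "example": "ai.backend.example.cli:main",
        },
    },
)
```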
+ ai_backend_ns_path = Path(__file__).parent.parent + log.debug("scan_entrypoint_from_buildscript({!r}): Namespace path: {}", group_name, ai_backend_ns_path) + for buildscript_path in ai_backend_ns_path.glob("**/BUILD"): + log.debug("reading entry points [{}] from {}", group_name, buildscript_path) + for entrypoint in extract_entrypoints_from_buildscript(group_name, buildscript_path): + entrypoints[entrypoint.name] = entrypoint + # Override with the entrypoints found in the current source directories, + try: + build_root = find_build_root() + except ValueError: + pass + else: + src_path = build_root / 'src' + plugins_path = build_root / 'plugins' + log.debug("scan_entrypoint_from_buildscript({!r}): current src: {}", group_name, src_path) + for buildscript_path in src_path.glob("**/BUILD"): + if buildscript_path.is_relative_to(plugins_path): + # Prevent loading BUILD files in plugin checkouts if they use Pants on their own. + continue + log.debug("reading entry points [{}] from {}", group_name, buildscript_path) + for entrypoint in extract_entrypoints_from_buildscript(group_name, buildscript_path): + entrypoints[entrypoint.name] = entrypoint + yield from entrypoints.values() + + +def find_build_root(path: Optional[Path] = None) -> Path: + cwd = Path.cwd() if path is None else path + while True: + if (cwd / 'BUILD_ROOT').exists(): + return cwd + cwd = cwd.parent + if cwd.parent == cwd: + # reached the root directory + break + raise ValueError("Could not find the build root directory") + + +def extract_entrypoints_from_buildscript( + group_name: str, + buildscript_path: Path, +) -> Iterator[EntryPoint]: + tree = ast.parse(buildscript_path.read_bytes()) + for node in tree.body: + if ( + isinstance(node, ast.Expr) and + isinstance(node.value, ast.Call) and + isinstance(node.value.func, ast.Name) and + node.value.func.id == "python_distribution" + ): + for kwarg in node.value.keywords: + if kwarg.arg == "entry_points": + raw_data = ast.literal_eval(kwarg.value) + for key, raw_entry_points in raw_data.items(): + if key != group_name: + continue + for name, ref in raw_entry_points.items(): + try: + yield EntryPoint(name=name, value=ref, group=group_name) + except ValueError: + pass diff --git a/src/ai/backend/plugin/py.typed b/src/ai/backend/plugin/py.typed new file mode 100644 index 0000000000..48cdce8528 --- /dev/null +++ b/src/ai/backend/plugin/py.typed @@ -0,0 +1 @@ +placeholder diff --git a/src/ai/backend/runner/.bash_profile b/src/ai/backend/runner/.bash_profile new file mode 100644 index 0000000000..d884a8f5ee --- /dev/null +++ b/src/ai/backend/runner/.bash_profile @@ -0,0 +1,6 @@ +# From https://unix.stackexchange.com/a/541352 +if [ -n "$BASH_VERSION" ]; then + if [ -f "$HOME/.bashrc" ]; then + . 
"$HOME/.bashrc" + fi +fi diff --git a/src/ai/backend/runner/.bashrc b/src/ai/backend/runner/.bashrc new file mode 100644 index 0000000000..ca736a42d7 --- /dev/null +++ b/src/ai/backend/runner/.bashrc @@ -0,0 +1,7 @@ +export PS1="\[\033[01;32m\]\u@${BACKENDAI_CLUSTER_HOST:-main}\[\033[01;33m\][${BACKENDAI_SESSION_NAME}]\[\033[00m\]:\[\033[01;34m\]\w\[\033[00m\]\$ " + +if [[ `uname` == "Linux" ]]; then + alias ls="ls --color" +fi +alias ll="ls -al" +alias l="ls -a" diff --git a/src/ai/backend/runner/.dockerignore b/src/ai/backend/runner/.dockerignore new file mode 100644 index 0000000000..42ed402165 --- /dev/null +++ b/src/ai/backend/runner/.dockerignore @@ -0,0 +1,3 @@ +*.tar.xz +*.so +*.bin diff --git a/src/ai/backend/runner/.tmux.conf b/src/ai/backend/runner/.tmux.conf new file mode 100644 index 0000000000..9112dfba41 --- /dev/null +++ b/src/ai/backend/runner/.tmux.conf @@ -0,0 +1,82 @@ +# tmux configuration by Joongi Kim (https://github.com/achimnol/dotfiles) + +new-session -n $HOST + +unbind C-c +unbind c +bind C-c new-window +bind c new-window + +unbind C-d +unbind d +bind C-d detach +bind d detach + +# Tip: to select windows by index, use ' +unbind C-h +unbind C-l +bind -r C-h select-window -t :- +bind -r C-l select-window -t :+ + +# Terminal colors +# If it does not work, alias tmux='tmux -2' in your shell configuration. +# override inner $TERM +set-option -g default-terminal "xterm-256color" +# override outer $TERM for true-color support (tmux 2.3+) +set-option -as terminal-overrides ",xterm-256color:Tc,gnome*:RGB" + +# Enable mouse features +set -g mouse on +set -g focus-events on + +# Let function keys to work +setw -g xterm-keys on + +# Remove delays when hitting Esc + arrow keys in Vim +set -s escape-time 0 + +# +" and +_ splits the window horizontally. +# +% and +| splits the window vertically. 
+bind | split-window -h +bind _ split-window -v + +set-option -g base-index 1 +set-option -g set-titles on +set-option -g visual-activity on +# set-window-option -g mode-keys vi +# set-window-option -g automatic-rename +# set-window-option -g monitor-activity on +# set-window-option -g aggressive-resize on +# set -g status off + +set-option -g history-limit 5000 + +# redisplay ^L l +unbind ^L +bind ^L refresh-client +unbind l +bind l refresh-client + +## vim style key bindings +unbind Tab +bind Tab select-pane -t :.+ + +# move around panes with j and k, a bit like vim +# as of tmux 1.1, there is no way to move based on pane position (ie, no way to +# move the pane to the right) +bind j select-pane -t :.+ # down-pane +bind k select-pane -t :.- # up-pane + +# resize panes like vim +# feel free to change the "1" to however many lines you want to resize by, only +# one at a time can be slow +bind < resize-pane -L 1 +bind > resize-pane -R 1 +bind - resize-pane -D 1 +bind + resize-pane -U 1 + +## key bindings compatible with my Emacs config +bind C-o select-pane -t :.- # up-pane + +## apply tmux configure local file if it exists +source -q ~/.tmux.conf.local \ No newline at end of file diff --git a/src/ai/backend/runner/.vimrc b/src/ai/backend/runner/.vimrc new file mode 100644 index 0000000000..6b90b953a3 --- /dev/null +++ b/src/ai/backend/runner/.vimrc @@ -0,0 +1,10 @@ +set nocompatible +set laststatus=2 " display status line always +set number " show line number +set encoding=utf-8 " use UTF-8 +set autoindent " indent when moving to the next line while coding +set smartindent " tab is converted to 4 spaces (Python convention) +set bs=2 " make backspace behave naturally +set ignorecase " make search case insensitive +set smartcase +set pastetoggle= " fix indent problem when pasting from ext. source diff --git a/src/ai/backend/runner/BUILD b/src/ai/backend/runner/BUILD new file mode 100644 index 0000000000..8fb489b71d --- /dev/null +++ b/src/ai/backend/runner/BUILD @@ -0,0 +1,77 @@ +python_sources( + name="lib", + dependencies=[ + ":resources", + ], +) + +pex_binary( + name="extract-dotfiles", + entry_point="extract_dotfiles.py", +) + +# This package does not have external requirements. 
+python_requirements() + +python_distribution( + name="dist", + dependencies=[ + ":lib", + ], + provides=python_artifact( + name="backend.ai-kernel-binary", + description="Backend.AI Kernel Runner Prebuilt Binaries Package", + license="LGPLv3", + ), + generate_setup=True, + tags=["wheel", "platform-specific"], +) + +resource(name="version", source="VERSION") + +resources( + name="resources", + dependencies=[ + ":version", + ":platform-binaries", + ], + sources=[ + "*.sh", + "*.svg", + "*.md", + "*.css", + "*.ttf", + "*.dockerfile", + ".bash_profile", + ".bashrc", + ".tmux.conf", + ".vimrc", + "terminfo.alpine3.8/**/*", + ], +) + +platform_resources( + name="platform-binaries", + dependency_map={ + "linux_x86_64": ":linux-x86_64-binaries", + "linux_arm64": ":linux-arm64-binaries", + }, +) + +resources( + name="linux-x86_64-binaries", + sources=[ + "*.x86_64.bin", + "*.x86_64.so", + "*.x86_64.tar.xz", + ], +) + +resources( + name="linux-arm64-binaries", + sources=[ + "*.aarch64.bin", + "*.aarch64.so", + "*.aarch64.tar.xz", + ], +) diff --git a/src/ai/backend/runner/DO_NOT_STORE_PERSISTENT_FILES_HERE.md b/src/ai/backend/runner/DO_NOT_STORE_PERSISTENT_FILES_HERE.md new file mode 100644 index 0000000000..7e7c27f833 --- /dev/null +++ b/src/ai/backend/runner/DO_NOT_STORE_PERSISTENT_FILES_HERE.md @@ -0,0 +1,9 @@ +This directory (/home/work) will be deleted when the compute session is terminated. +To keep persistent files after termination: +- Create and mount data folder(s) when starting new sessions. +- Always use them to put any persistent files. + +이 디렉토리(/home/work)는 연산 세션이 종료될 때 함께 삭제됩니다. +계속 보존할 파일들을 잃어버리지 않으려면: +- 데이터 폴더를 생성하여 세션 실행 시 마운트합니다. +- 보존할 파일들을 항상 마운트된 폴더들 안에 저장하십시오. diff --git a/src/ai/backend/runner/README.md b/src/ai/backend/runner/README.md new file mode 100644 index 0000000000..c32693a176 --- /dev/null +++ b/src/ai/backend/runner/README.md @@ -0,0 +1,2 @@ +# Backend.AI Kernel Runner Binary Components + diff --git a/src/ai/backend/runner/VERSION b/src/ai/backend/runner/VERSION new file mode 120000 index 0000000000..a4e948506b --- /dev/null +++ b/src/ai/backend/runner/VERSION @@ -0,0 +1 @@ +../../../../VERSION \ No newline at end of file diff --git a/src/ai/backend/runner/__init__.py b/src/ai/backend/runner/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/ai/backend/runner/dropbear.glibc.aarch64.bin b/src/ai/backend/runner/dropbear.glibc.aarch64.bin new file mode 100755 index 0000000000..763b425fc6 --- /dev/null +++ b/src/ai/backend/runner/dropbear.glibc.aarch64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1555c7cbe29c2791d652035f423103020d1ed7df1d793c80775992730a18c649 +size 1396936 diff --git a/src/ai/backend/runner/dropbear.glibc.x86_64.bin b/src/ai/backend/runner/dropbear.glibc.x86_64.bin new file mode 100755 index 0000000000..45cf5283ae --- /dev/null +++ b/src/ai/backend/runner/dropbear.glibc.x86_64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:520a7b596c19131d3545121ebc9d89ccd43995019c8a42828041b8e2521507cc +size 1750056 diff --git a/src/ai/backend/runner/dropbear.musl.aarch64.bin b/src/ai/backend/runner/dropbear.musl.aarch64.bin new file mode 100755 index 0000000000..ad97138474 --- /dev/null +++ b/src/ai/backend/runner/dropbear.musl.aarch64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:667033ef151d7c16318a7a768cce374d5d7e6755f3a2a0d404ae06d21ded99da +size 1137816 diff --git a/src/ai/backend/runner/dropbear.musl.x86_64.bin 
b/src/ai/backend/runner/dropbear.musl.x86_64.bin new file mode 100755 index 0000000000..6057d3b3c5 --- /dev/null +++ b/src/ai/backend/runner/dropbear.musl.x86_64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9f482388283903faa9c0bca5f2a6db85daffe0f1f916791f263c612026afddeb +size 1086024 diff --git a/src/ai/backend/runner/dropbearconvert.glibc.aarch64.bin b/src/ai/backend/runner/dropbearconvert.glibc.aarch64.bin new file mode 100755 index 0000000000..39f0fcf065 --- /dev/null +++ b/src/ai/backend/runner/dropbearconvert.glibc.aarch64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6fef3ec5f750eb79a87d473121b532005e67bdb528327195aef5a0247ed6bb40 +size 905448 diff --git a/src/ai/backend/runner/dropbearconvert.glibc.x86_64.bin b/src/ai/backend/runner/dropbearconvert.glibc.x86_64.bin new file mode 100755 index 0000000000..ba6f88d4ce --- /dev/null +++ b/src/ai/backend/runner/dropbearconvert.glibc.x86_64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1b1eb6ef98e25b4901ae4806c410f8cc233c7cecc3a489a26ad3d3bc8e994fc1 +size 1221640 diff --git a/src/ai/backend/runner/dropbearconvert.musl.aarch64.bin b/src/ai/backend/runner/dropbearconvert.musl.aarch64.bin new file mode 100755 index 0000000000..53c518d8cf --- /dev/null +++ b/src/ai/backend/runner/dropbearconvert.musl.aarch64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:11c80d1ad4b542a6bae90d5781cde68e7d1c32b2afefc0d230b3582eaa220e9d +size 718856 diff --git a/src/ai/backend/runner/dropbearconvert.musl.x86_64.bin b/src/ai/backend/runner/dropbearconvert.musl.x86_64.bin new file mode 100755 index 0000000000..4a14b85fe1 --- /dev/null +++ b/src/ai/backend/runner/dropbearconvert.musl.x86_64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b4b2c17f29fe8b4d9226912728cb39f03a56d04674db5f1f2d78702724e1c615 +size 651608 diff --git a/src/ai/backend/runner/dropbearkey.glibc.aarch64.bin b/src/ai/backend/runner/dropbearkey.glibc.aarch64.bin new file mode 100755 index 0000000000..4033ad3d88 --- /dev/null +++ b/src/ai/backend/runner/dropbearkey.glibc.aarch64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7d3dd9e19e5f963f4bf8bc62dcc4620e802397bd03ac75b59f000904812d324c +size 896616 diff --git a/src/ai/backend/runner/dropbearkey.glibc.x86_64.bin b/src/ai/backend/runner/dropbearkey.glibc.x86_64.bin new file mode 100755 index 0000000000..1761a80862 --- /dev/null +++ b/src/ai/backend/runner/dropbearkey.glibc.x86_64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e7275d8bb392cdcb28011dab9eb318b72cea01321fcc3437a00dffcb992a13e7 +size 1212792 diff --git a/src/ai/backend/runner/dropbearkey.musl.aarch64.bin b/src/ai/backend/runner/dropbearkey.musl.aarch64.bin new file mode 100755 index 0000000000..cea5a6c0a0 --- /dev/null +++ b/src/ai/backend/runner/dropbearkey.musl.aarch64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:48b6a99ee4ea7e45d0e317a53ef9449e783862fa6a5a425720d5b8adb8307ea5 +size 698080 diff --git a/src/ai/backend/runner/dropbearkey.musl.x86_64.bin b/src/ai/backend/runner/dropbearkey.musl.x86_64.bin new file mode 100755 index 0000000000..4006e26fc5 --- /dev/null +++ b/src/ai/backend/runner/dropbearkey.musl.x86_64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:432c828c39e1c9ee2997717ac7b101ec52976936b7ce4ccb95c2c1c83d3ec8d6 +size 632992 diff --git a/src/ai/backend/runner/entrypoint.sh 
b/src/ai/backend/runner/entrypoint.sh new file mode 100755 index 0000000000..149c588a6f --- /dev/null +++ b/src/ai/backend/runner/entrypoint.sh @@ -0,0 +1,89 @@ +#! /bin/sh + +USER_ID=${LOCAL_USER_ID:-9001} +GROUP_ID=${LOCAL_GROUP_ID:-9001} + +echo "Kernel started at: $(date -Iseconds -u)" + +if [ $USER_ID -eq 0 ]; then + + echo "WARNING: Running the user codes as root is not recommended." + if [ -f /bin/ash ]; then # for alpine + export SHELL=/bin/ash + else + export SHELL=/bin/bash + echo "$LD_PRELOAD" | tr ':' '\n' > /etc/ld.so.preload + unset LD_PRELOAD + fi + export LD_LIBRARY_PATH="/opt/backend.ai/lib:$LD_LIBRARY_PATH" + export HOME="/home/work" + + # Invoke image-specific bootstrap hook. + if [ -x "/opt/container/bootstrap.sh" ]; then + echo 'Executing image bootstrap... ' + . /opt/container/bootstrap.sh + echo 'Image bootstrap executed.' + fi + + # Extract dotfiles + /opt/backend.ai/bin/python /opt/kernel/extract_dotfiles.py + + echo "Executing the main program..." + exec "$@" + +else + + echo "Setting up uid and gid: $USER_ID:$GROUP_ID" + USER_NAME=$(getent group $USER_ID | cut -d: -f1) + GROUP_NAME=$(getent group $GROUP_ID | cut -d: -f1) + if [ -f /bin/ash ]; then # for alpine (busybox) + if [ -z "$GROUP_NAME" ]; then + GROUP_NAME=work + addgroup -g $GROUP_ID $GROUP_NAME + fi + if [ -z "$USER_NAME" ]; then + USER_NAME=work + adduser -s /bin/ash -h "/home/$USER_NAME" -H -D -u $USER_ID -G $GROUP_NAME -g "User" $USER_NAME + fi + export SHELL=/bin/ash + else + echo "$LD_PRELOAD" | tr ':' '\n' > /etc/ld.so.preload + unset LD_PRELOAD + if [ -z "$GROUP_NAME" ]; then + GROUP_NAME=work + groupadd -g $GROUP_ID $GROUP_NAME + fi + if [ -z "$USER_NAME" ]; then + USER_NAME=work + useradd -s /bin/bash -d "/home/$USER_NAME" -M -r -u $USER_ID -g $GROUP_NAME -o -c "User" $USER_NAME + else + cp -R "/home/$USER_NAME/*" /home/work/ + cp -R "/home/$USER_NAME/.*" /home/work/ + usermod -s /bin/bash -d /home/work -l work -g $GROUP_NAME $USER_NAME + USER_NAME=work + chown -R $USER_NAME:$GROUP_NAME /home/work + fi + export SHELL=/bin/bash + fi + export LD_LIBRARY_PATH="/opt/backend.ai/lib:$LD_LIBRARY_PATH" + export HOME="/home/$USER_NAME" + + # Invoke image-specific bootstrap hook. + if [ -x "/opt/container/bootstrap.sh" ]; then + echo 'Executing image bootstrap... ' + export LOCAL_USER_ID=$USER_ID + export LOCAL_GROUP_ID=$GROUP_ID + . /opt/container/bootstrap.sh + echo 'Image bootstrap executed.' + fi + + # Correct the ownership of agent socket. + chown $USER_ID:$GROUP_ID /opt/kernel/agent.sock + + # Extract dotfiles + /opt/kernel/su-exec $USER_ID:$GROUP_ID /opt/backend.ai/bin/python /opt/kernel/extract_dotfiles.py + + echo "Executing the main program..." 
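The dotfile extraction step above reads `/home/config/dotfiles.json`, which `extract_dotfiles.py` (included just after this script) expects to be a JSON list of entries with `path`, `data`, and an octal permission string `perm`. A hedged illustration of that payload, written as the equivalent Python literal with made-up entries:

```python
# Hedged illustration of the /home/config/dotfiles.json payload consumed by
# extract_dotfiles.py; the entries are made-up examples.
example_dotfiles = [
    {"path": ".gitconfig", "data": "[user]\n\tname = work\n", "perm": "644"},
    {"path": ".ssh/config", "data": "Host internal\n  User work\n", "perm": "600"},
]
```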
+ exec /opt/kernel/su-exec $USER_ID:$GROUP_ID "$@" + +fi diff --git a/src/ai/backend/runner/extract_dotfiles.py b/src/ai/backend/runner/extract_dotfiles.py new file mode 100644 index 0000000000..65b3db5eae --- /dev/null +++ b/src/ai/backend/runner/extract_dotfiles.py @@ -0,0 +1,30 @@ +import json +from pathlib import Path +import sys + + +def extract_dotfiles(): + try: + with open('/home/config/dotfiles.json') as fr: + dotfiles = json.loads(fr.read()) + except FileNotFoundError: + return + work_dir = Path('/home/work') + for dotfile in dotfiles: + file_path = work_dir / dotfile['path'] + try: + file_path.parent.mkdir(parents=True, exist_ok=True) + file_path.write_text(dotfile['data']) + except IOError: + print(f"failed to write dotfile: {file_path}", file=sys.stderr) + try: + tmp = Path(file_path) + while tmp != work_dir: + tmp.chmod(int(dotfile['perm'], 8)) + tmp = tmp.parent + except IOError: + print(f"failed to chmod dotfile: {file_path}", file=sys.stderr) + + +if __name__ == '__main__': + extract_dotfiles() diff --git a/src/ai/backend/runner/jail.alpine3.8.bin b/src/ai/backend/runner/jail.alpine3.8.bin new file mode 100755 index 0000000000..f171f5ec54 --- /dev/null +++ b/src/ai/backend/runner/jail.alpine3.8.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3af7f73ad20e38e00c4568ceb50b21389b063d6d6f0e7a72b40895d36b5f4e73 +size 3689112 diff --git a/src/ai/backend/runner/jail.ubuntu16.04.bin b/src/ai/backend/runner/jail.ubuntu16.04.bin new file mode 100755 index 0000000000..692e29226f --- /dev/null +++ b/src/ai/backend/runner/jail.ubuntu16.04.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9ad7d9eea555876b906bfb5eb8c83b6d5df820c409a0849acce5f9c21efbe207 +size 4590720 diff --git a/src/ai/backend/runner/jupyter-custom.css b/src/ai/backend/runner/jupyter-custom.css new file mode 100644 index 0000000000..7cf1ec487b --- /dev/null +++ b/src/ai/backend/runner/jupyter-custom.css @@ -0,0 +1,275 @@ +.navbar-brand { + width: 50px; +} + +#ipython_notebook img { + display: block; + background: url(logo.svg) no-repeat; + background-size: contain; + width: 50px !important; + height: 33px; + padding-left: 80px; + -moz-box-sizing: border-box; + box-sizing: border-box; +} + +#shutdown_widget { + display: none; +} + +.CodeMirror { + font-family: "Source Code Pro", "Menlo", "Consolas", monospace; +} + +.output * { + font-family: "Source Code Pro", "Menlo", "Consolas", monospace !important; +} + +.output button .fa { + font-family: "FontAwesome" !important; +} + +.btn { + -webkit-tap-highlight-color: rgba(0, 0, 0, 0); + -webkit-tap-highlight-color: transparent; + text-transform: uppercase; + outline-width: 0; + -moz-user-select: none; + -ms-user-select: none; + -webkit-user-select: none; + user-select: none; + cursor: pointer; + margin: 0 0.29em; + border-radius: 3px; + vertical-align: middle; +} + +.btn-xs { + height: 32px; + line-height: 28px !important; +} + +#select-all { + margin-top: 9px; +} + +/* File list */ +#tabs { + position: absolute; + top: 11px; + z-index: 110; +} + +#tabs:first-child { + margin-left: 70px; +} + +.nav-tabs {} + +.nav-tabs li { + height: 100%; + transform: translateZ(0); + -webkit-transform: translateZ(0); + transition: opacity 0.1s cubic-bezier(0.4, 0.0, 1, 1); +} + + +.list_container { + box-shadow: 0 2px 2px 0 rgba(0, 0, 0, 0.14), + 0 1px 5px 0 rgba(0, 0, 0, 0.12), + 0 3px 1px -2px rgba(0, 0, 0, 0.2); +} + +.navbar-nav { + font-size: 14px; + margin-top: 4px; + margin-bottom: 4px; + height: 32px; +} + +.indicator_area { + 
margin-top: 4px; + margin-bottom: 4px; + height: 32px; +} + +.alternate_upload .btn-upload, +.list_toolbar .btn-group button { + height: 32px; +} + +.alternate_upload .btn-upload { + line-height: 28px !important; +} + +.dropdown-menu { + font-size: 14px; + box-shadow: 0 2px 2px 0 rgba(0, 0, 0, 0.14), + 0 1px 5px 0 rgba(0, 0, 0, 0.12), + 0 3px 1px -2px rgba(0, 0, 0, 0.2); +} + +/* Notebook UI */ +div.cell.code_cell.selected { + box-shadow: 0 2px 2px 0 rgba(0, 0, 0, 0.14), + 0 1px 5px 0 rgba(0, 0, 0, 0.12), + 0 3px 1px -2px rgba(0, 0, 0, 0.2); +} + +.checkpoint_status { + color: #5CAD7C; +} + +.autosave_status { + font-size: 10px !important; + color: #C0C0C0; +} + +.kernel_idle_icon, +.kernel_busy_icon { + color: #3E872D; + font-weight: 900; +} + +#kernel_indicator { + border: 0; +} + +.kernel_indicator_name { + height: 32px; + display: inline-block; + font-size: 18px; + line-height: 28px; + color: #5CAD7C; + font-weight: 500; + vertical-align: middle; +} + +/* Icon menu */ +.navbar-default { + background-color: transparent; + border-bottom: 0; + border-left: 0; + border-right: 0; +} + +#maintoolbar-container:first-child { + margin-left: 5px; +} + +.container.toolbar .fa { + font-size: 18px; +} + +.container.toolbar .btn { + border: none; +} + +.toolbar-btn-label { + display: none; +} + +/* Code container */ +.container { + border-radius: 5px; +} + +/* Running button */ +/* +.prompt_container .run_this_cell { + display:flex!important; + color: #009fb7; + margin: 0 auto auto auto; + width: 45px!important; + height: 25px!important; + background: transparent; + background-color: transparent; + border-radius: 3px; + text-align: center; + box-sizing: border-box; + -moz-user-select: none; + -ms-user-select: none; + -webkit-user-select: none; + user-select: none; + cursor: pointer; + padding: 0.7em 0.57em; + } + .running .prompt_container .run_this_cell i { + display: none!important; + } + .running .prompt_container .run_this_cell:after { + font-size:10px; + content: "Running"; + } + .prompt_container .run_this_cell i { + margin: auto; + } + div.code_cell:hover div.input .run_this_cell { + visibility: visible; + } + div.cell.code_cell.rendered.selected .run_this_cell:hover { + background-color: #e3e3e3; + background: #e3e3e3; + color: #5EAE7E !important; + } + div.cell.code_cell.rendered.unselected .run_this_cell:hover { + background-color: #e3e3e3; + background: #e3e3e3; + color: #5EAE7E !important; + } +*/ +/* misc. 
*/ +#logout { + display: none; +} + +#ipython-main-app #tab_content #tab:last-child { + display: none !important; +} + +.clusters_tab_link, +#ipyclusters { + display: none !important; +} + +body.terminal-app #header { + display: none !important; +} + + +/* Editor */ +/* +.CodeMirror { background: #272822; color: #f8f8f2; } +.CodeMirror div.CodeMirror-selected { background: #49483E; } +.CodeMirror .CodeMirror-line::selection, .CodeMirror .CodeMirror-line > span::selection, .CodeMirror .CodeMirror-line > span > span::selection { background: rgba(73, 72, 62, .99); } +.CodeMirror .CodeMirror-line::-moz-selection, .CodeMirror .CodeMirror-line > span::-moz-selection, .CodeMirror .CodeMirror-line > span > span::-moz-selection { background: rgba(73, 72, 62, .99); } +.CodeMirror .CodeMirror-gutters { background: #272822; border-right: 0px; } +.CodeMirror .CodeMirror-guttermarker { color: white; } +.CodeMirror .CodeMirror-guttermarker-subtle { color: #d0d0d0; } +.CodeMirror .CodeMirror-linenumber { color: #d0d0d0; } +.CodeMirror .CodeMirror-cursor { border-left: 1px solid #f8f8f0; } + +.CodeMirror span.cm-comment { color: #75715e; } +.CodeMirror span.cm-atom { color: #ae81ff; } +.CodeMirror span.cm-number { color: #ae81ff; } + +.CodeMirror span.cm-property, .CodeMirror span.cm-attribute { color: #a6e22e; } +.CodeMirror span.cm-keyword { color: #f92672; } +.CodeMirror span.cm-string { color: #e6db74; } + +.CodeMirror span.cm-variable { color: #f8f8f2; } +.CodeMirror span.cm-variable-2 { color: #9effff; } +.CodeMirror span.cm-variable-3 { color: #66d9ef; } +.CodeMirror span.cm-def { color: #fd971f; } +.CodeMirror span.cm-bracket { color: #f8f8f2; } +.CodeMirror span.cm-tag { color: #f92672; } +.CodeMirror span.cm-header { color: #ae81ff; } +.CodeMirror span.cm-link { color: #ae81ff; } +.CodeMirror span.cm-error { background: #f92672; color: #f8f8f0; } + +.CodeMirror .CodeMirror-activeline-background { background: #373831; } +.CodeMirror .CodeMirror-matchingbracket { + text-decoration: underline; + color: white !important; +} +*/ diff --git a/src/ai/backend/runner/krunner-extractor.dockerfile b/src/ai/backend/runner/krunner-extractor.dockerfile new file mode 100644 index 0000000000..85ce04692b --- /dev/null +++ b/src/ai/backend/runner/krunner-extractor.dockerfile @@ -0,0 +1,5 @@ +FROM alpine:3.8 + +RUN apk add --no-cache xz + +# vim: ft=dockerfile diff --git a/src/ai/backend/runner/krunner-extractor.img.aarch64.tar.xz b/src/ai/backend/runner/krunner-extractor.img.aarch64.tar.xz new file mode 100644 index 0000000000..079214be06 --- /dev/null +++ b/src/ai/backend/runner/krunner-extractor.img.aarch64.tar.xz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b56c8abdfd6d07ceb57d0a6bdf4c404a60edf15bee1658e44ce8cb467261c869 +size 1674528 diff --git a/src/ai/backend/runner/krunner-extractor.img.x86_64.tar.xz b/src/ai/backend/runner/krunner-extractor.img.x86_64.tar.xz new file mode 100644 index 0000000000..57d5e3c3d3 --- /dev/null +++ b/src/ai/backend/runner/krunner-extractor.img.x86_64.tar.xz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cd989ba19320da4ed97c7f3a938ed526af4b4ef1173dfd3ff853c0a8ad724ea9 +size 1853504 diff --git a/src/ai/backend/runner/krunner-extractor.sh b/src/ai/backend/runner/krunner-extractor.sh new file mode 100755 index 0000000000..7115c04416 --- /dev/null +++ b/src/ai/backend/runner/krunner-extractor.sh @@ -0,0 +1,4 @@ +#! 
/bin/sh +rm -rf /root/volume/* +tar xJf /root/archive.tar.xz -C /root/volume/ +echo "$KRUNNER_VERSION" > /root/volume/VERSION diff --git a/src/ai/backend/runner/libbaihook.alpine3.8.aarch64.so b/src/ai/backend/runner/libbaihook.alpine3.8.aarch64.so new file mode 100755 index 0000000000..2001582b95 --- /dev/null +++ b/src/ai/backend/runner/libbaihook.alpine3.8.aarch64.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0508907e5cb0ffde95832d8b1c537ee06df8f63916cde5a45c85548e5ac13a1e +size 5511528 diff --git a/src/ai/backend/runner/libbaihook.alpine3.8.x86_64.so b/src/ai/backend/runner/libbaihook.alpine3.8.x86_64.so new file mode 100755 index 0000000000..e349d0bae5 --- /dev/null +++ b/src/ai/backend/runner/libbaihook.alpine3.8.x86_64.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:81e7d8989d64c605ce6921054d2577083f76c4d401cb8032819768a118a40aef +size 5294944 diff --git a/src/ai/backend/runner/libbaihook.centos7.6.aarch64.so b/src/ai/backend/runner/libbaihook.centos7.6.aarch64.so new file mode 100755 index 0000000000..95cbdbc412 --- /dev/null +++ b/src/ai/backend/runner/libbaihook.centos7.6.aarch64.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:19618d9d0a24f8fd7f0c951d592f3c2feae05a3a0563f2fec98b1e740a9c32e6 +size 521584 diff --git a/src/ai/backend/runner/libbaihook.centos7.6.x86_64.so b/src/ai/backend/runner/libbaihook.centos7.6.x86_64.so new file mode 100755 index 0000000000..61983b08be --- /dev/null +++ b/src/ai/backend/runner/libbaihook.centos7.6.x86_64.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ce314c5f7f670935c912b5c2f5447a4efef5a427b2a6332b75a7b2ca2145cac2 +size 481984 diff --git a/src/ai/backend/runner/libbaihook.ubuntu18.04.x86_64.so b/src/ai/backend/runner/libbaihook.ubuntu18.04.x86_64.so new file mode 100755 index 0000000000..8cc9ce3836 --- /dev/null +++ b/src/ai/backend/runner/libbaihook.ubuntu18.04.x86_64.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:45d7605d992ed7ada814143fea081f56e8a6fcc11eababf35fdb8f1e45df9b9f +size 833240 diff --git a/src/ai/backend/runner/libbaihook.ubuntu20.04.aarch64.so b/src/ai/backend/runner/libbaihook.ubuntu20.04.aarch64.so new file mode 100755 index 0000000000..0464d1ab7b --- /dev/null +++ b/src/ai/backend/runner/libbaihook.ubuntu20.04.aarch64.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3765209d2863d1a93cbbc4a7bc3bc0bc106dbb9069f6704d33511fb30c88ddbe +size 497632 diff --git a/src/ai/backend/runner/libbaihook.ubuntu20.04.x86_64.so b/src/ai/backend/runner/libbaihook.ubuntu20.04.x86_64.so new file mode 100755 index 0000000000..4d461f0b04 --- /dev/null +++ b/src/ai/backend/runner/libbaihook.ubuntu20.04.x86_64.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8c8d4c355d537ef4ed601ba3f58c474981303d7087c153063e936ebe5ba48361 +size 478480 diff --git a/src/ai/backend/runner/logo.svg b/src/ai/backend/runner/logo.svg new file mode 100644 index 0000000000..77b67ee4a1 --- /dev/null +++ b/src/ai/backend/runner/logo.svg @@ -0,0 +1 @@ + diff --git a/src/ai/backend/runner/requirements.txt b/src/ai/backend/runner/requirements.txt new file mode 100644 index 0000000000..1bf0cd96e4 --- /dev/null +++ b/src/ai/backend/runner/requirements.txt @@ -0,0 +1 @@ +# see src/ai/backend/krunner/*/requirements.txt in "backend.ai-krunner-*" repositories diff --git a/src/ai/backend/runner/roboto-italic.ttf b/src/ai/backend/runner/roboto-italic.ttf new file mode 100644 index 
0000000000..663297055b --- /dev/null +++ b/src/ai/backend/runner/roboto-italic.ttf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3dbdd703b09a17ae4c16d0d9265eb582349416cb0ba3c77b383f29de6e2b27c5 +size 120832 diff --git a/src/ai/backend/runner/roboto.ttf b/src/ai/backend/runner/roboto.ttf new file mode 100644 index 0000000000..90e78edc7f --- /dev/null +++ b/src/ai/backend/runner/roboto.ttf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c7ab2d73cf7d538face08bcdde95b928ce609a970237c8811ca3c76059c8bb2f +size 114624 diff --git a/src/ai/backend/runner/scp.alpine3.8.aarch64.bin b/src/ai/backend/runner/scp.alpine3.8.aarch64.bin new file mode 100755 index 0000000000..530dcb0ef9 --- /dev/null +++ b/src/ai/backend/runner/scp.alpine3.8.aarch64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5404df3a51f73a8c814ace7a0eb34d1eef2f11e30e1aa2ddc8e78de5345944a6 +size 2732960 diff --git a/src/ai/backend/runner/scp.alpine3.8.x86_64.bin b/src/ai/backend/runner/scp.alpine3.8.x86_64.bin new file mode 100755 index 0000000000..fde3f56cc3 --- /dev/null +++ b/src/ai/backend/runner/scp.alpine3.8.x86_64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e164c144fedb8c6ee36d1722bbe552743a4c2d80fcf04b3e4b6ad032e1207702 +size 3233552 diff --git a/src/ai/backend/runner/scp.centos7.6.aarch64.bin b/src/ai/backend/runner/scp.centos7.6.aarch64.bin new file mode 100755 index 0000000000..5597988602 --- /dev/null +++ b/src/ai/backend/runner/scp.centos7.6.aarch64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e53076c5d6972416ba7ce41427dc904b779c79639c57420a65e61d3292e8ae44 +size 3110032 diff --git a/src/ai/backend/runner/scp.centos7.6.x86_64.bin b/src/ai/backend/runner/scp.centos7.6.x86_64.bin new file mode 100755 index 0000000000..bb1bac7df1 --- /dev/null +++ b/src/ai/backend/runner/scp.centos7.6.x86_64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4d8c1592f4e2ceb3c86d0bf3649663cc282efc3813c9233f058379eaaf0f0151 +size 3277080 diff --git a/src/ai/backend/runner/scp.ubuntu16.04.aarch64.bin b/src/ai/backend/runner/scp.ubuntu16.04.aarch64.bin new file mode 100755 index 0000000000..5df63728cb --- /dev/null +++ b/src/ai/backend/runner/scp.ubuntu16.04.aarch64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d937fef0e78a1e33c4c5623807b7163179e320213b7e36a57d3b9c4ee6df4b54 +size 2850312 diff --git a/src/ai/backend/runner/scp.ubuntu16.04.x86_64.bin b/src/ai/backend/runner/scp.ubuntu16.04.x86_64.bin new file mode 100755 index 0000000000..22b4cb5e9b --- /dev/null +++ b/src/ai/backend/runner/scp.ubuntu16.04.x86_64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:acf41af20594c73557e8e031feb124536f88b1278866969a17a46a9906dd5c68 +size 3346408 diff --git a/src/ai/backend/runner/scp.ubuntu18.04.aarch64.bin b/src/ai/backend/runner/scp.ubuntu18.04.aarch64.bin new file mode 100755 index 0000000000..f71995152f --- /dev/null +++ b/src/ai/backend/runner/scp.ubuntu18.04.aarch64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4c11082424a17bd7481d9550ae8e024816e3afab27ea7eeefd65b706913ba83a +size 2880440 diff --git a/src/ai/backend/runner/scp.ubuntu18.04.x86_64.bin b/src/ai/backend/runner/scp.ubuntu18.04.x86_64.bin new file mode 100755 index 0000000000..de9517eeb0 --- /dev/null +++ b/src/ai/backend/runner/scp.ubuntu18.04.x86_64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:0ef5a5c8b05be87cb09b4307f77f295d160f8fadcfe28f893ec5638beec394ad +size 3367576 diff --git a/src/ai/backend/runner/scp.ubuntu20.04.aarch64.bin b/src/ai/backend/runner/scp.ubuntu20.04.aarch64.bin new file mode 100755 index 0000000000..41baee85d2 --- /dev/null +++ b/src/ai/backend/runner/scp.ubuntu20.04.aarch64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cc7999c7de0ef0285d5701be12c13ddd2ec6208309f2ef97cece427d2a59594b +size 3137936 diff --git a/src/ai/backend/runner/scp.ubuntu20.04.x86_64.bin b/src/ai/backend/runner/scp.ubuntu20.04.x86_64.bin new file mode 100755 index 0000000000..781961dd97 --- /dev/null +++ b/src/ai/backend/runner/scp.ubuntu20.04.x86_64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9dc02a34ca477a811673f6b269f9d5f7f16a3cc10b7efaf5726f1ae577381e52 +size 3389872 diff --git a/src/ai/backend/runner/sftp-server.alpine3.8.aarch64.bin b/src/ai/backend/runner/sftp-server.alpine3.8.aarch64.bin new file mode 100755 index 0000000000..430b10e73f --- /dev/null +++ b/src/ai/backend/runner/sftp-server.alpine3.8.aarch64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:947df921c0030e4a9aee617217f4e287c8f51afb11476a7731912af9fdb650f7 +size 2788048 diff --git a/src/ai/backend/runner/sftp-server.alpine3.8.x86_64.bin b/src/ai/backend/runner/sftp-server.alpine3.8.x86_64.bin new file mode 100755 index 0000000000..939c7c6458 --- /dev/null +++ b/src/ai/backend/runner/sftp-server.alpine3.8.x86_64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a3b95eef9f04a3a2d3abfb0e472d63aaaf14efb02167c5c8cc84d2145229c77e +size 3284144 diff --git a/src/ai/backend/runner/sftp-server.centos7.6.aarch64.bin b/src/ai/backend/runner/sftp-server.centos7.6.aarch64.bin new file mode 100755 index 0000000000..a6c5570c81 --- /dev/null +++ b/src/ai/backend/runner/sftp-server.centos7.6.aarch64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e14dbf5c18e424212f389fc677d107bb5d7893d13131d2f1c8dc970138428768 +size 3218208 diff --git a/src/ai/backend/runner/sftp-server.centos7.6.x86_64.bin b/src/ai/backend/runner/sftp-server.centos7.6.x86_64.bin new file mode 100755 index 0000000000..6120e66e6e --- /dev/null +++ b/src/ai/backend/runner/sftp-server.centos7.6.x86_64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fc67889781ebafa289162906da3723e269558e53a9cb701afd0719a7d8611c8c +size 3328888 diff --git a/src/ai/backend/runner/sftp-server.ubuntu16.04.aarch64.bin b/src/ai/backend/runner/sftp-server.ubuntu16.04.aarch64.bin new file mode 100755 index 0000000000..a6d3755c97 --- /dev/null +++ b/src/ai/backend/runner/sftp-server.ubuntu16.04.aarch64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1a215859c842e974b36798831cf37f2e6c8c0dad128ebd9408718ce109d05d0c +size 2914472 diff --git a/src/ai/backend/runner/sftp-server.ubuntu16.04.x86_64.bin b/src/ai/backend/runner/sftp-server.ubuntu16.04.x86_64.bin new file mode 100755 index 0000000000..5f38ac6b9b --- /dev/null +++ b/src/ai/backend/runner/sftp-server.ubuntu16.04.x86_64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d115de4b4865fa00c143607d664ecb55bd15e8fd37a0d4e2093188cb53a7e401 +size 3396384 diff --git a/src/ai/backend/runner/sftp-server.ubuntu18.04.aarch64.bin b/src/ai/backend/runner/sftp-server.ubuntu18.04.aarch64.bin new file mode 100755 index 0000000000..646eed5b62 --- /dev/null +++ b/src/ai/backend/runner/sftp-server.ubuntu18.04.aarch64.bin @@ -0,0 +1,3 
@@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9b264472c34a30c05ff96cd3a2130a7a33cc8dd4a23632b9f5ede9ac57504e7b +size 2944072 diff --git a/src/ai/backend/runner/sftp-server.ubuntu18.04.x86_64.bin b/src/ai/backend/runner/sftp-server.ubuntu18.04.x86_64.bin new file mode 100755 index 0000000000..5d3a36cf7a --- /dev/null +++ b/src/ai/backend/runner/sftp-server.ubuntu18.04.x86_64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:30dbff514702d1aeacc6841789f7f1fb1b31388287e88fe162037696c83185d8 +size 3416792 diff --git a/src/ai/backend/runner/sftp-server.ubuntu20.04.aarch64.bin b/src/ai/backend/runner/sftp-server.ubuntu20.04.aarch64.bin new file mode 100755 index 0000000000..29c03c5c61 --- /dev/null +++ b/src/ai/backend/runner/sftp-server.ubuntu20.04.aarch64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aff00b1781bcb9215a43f0b9083721b63ef539fef57363799ffe3025e6aaad89 +size 3206952 diff --git a/src/ai/backend/runner/sftp-server.ubuntu20.04.x86_64.bin b/src/ai/backend/runner/sftp-server.ubuntu20.04.x86_64.bin new file mode 100755 index 0000000000..11111b39e4 --- /dev/null +++ b/src/ai/backend/runner/sftp-server.ubuntu20.04.x86_64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3d566588d9c7d62e213ef230ce66e591d0c7a852dce203f032ff803007de2ef3 +size 3456536 diff --git a/src/ai/backend/runner/su-exec.alpine3.8.aarch64.bin b/src/ai/backend/runner/su-exec.alpine3.8.aarch64.bin new file mode 100755 index 0000000000..fd864e3ae4 --- /dev/null +++ b/src/ai/backend/runner/su-exec.alpine3.8.aarch64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4d90aa643242e20f39c06b123a1e933e0b301a96ce02a52ad175e2336bc9a684 +size 17560 diff --git a/src/ai/backend/runner/su-exec.alpine3.8.x86_64.bin b/src/ai/backend/runner/su-exec.alpine3.8.x86_64.bin new file mode 100755 index 0000000000..bfce35ebdb --- /dev/null +++ b/src/ai/backend/runner/su-exec.alpine3.8.x86_64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bda79354f8b2dbdf9826aec939f575b6004140fc6ed3426f095414fb5dafa725 +size 17088 diff --git a/src/ai/backend/runner/su-exec.centos7.6.aarch64.bin b/src/ai/backend/runner/su-exec.centos7.6.aarch64.bin new file mode 100755 index 0000000000..4f4302325f --- /dev/null +++ b/src/ai/backend/runner/su-exec.centos7.6.aarch64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4e8fae5aa2fee973fdad8e44ed093cd6356d97b25561c6d8cdf11ecb24f902c9 +size 73992 diff --git a/src/ai/backend/runner/su-exec.centos7.6.x86_64.bin b/src/ai/backend/runner/su-exec.centos7.6.x86_64.bin new file mode 100755 index 0000000000..fc52f19845 --- /dev/null +++ b/src/ai/backend/runner/su-exec.centos7.6.x86_64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0548cb317865741f8ae8ae9dd3de976b54e2c910db9c818cb03cb52254d01a38 +size 15600 diff --git a/src/ai/backend/runner/su-exec.ubuntu16.04.aarch64.bin b/src/ai/backend/runner/su-exec.ubuntu16.04.aarch64.bin new file mode 100755 index 0000000000..37b0ace0c8 --- /dev/null +++ b/src/ai/backend/runner/su-exec.ubuntu16.04.aarch64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0bc0b7f624ad155068b780375fce35cfaef6688d504c5488e0bc6de612b21db6 +size 16960 diff --git a/src/ai/backend/runner/su-exec.ubuntu16.04.x86_64.bin b/src/ai/backend/runner/su-exec.ubuntu16.04.x86_64.bin new file mode 100755 index 0000000000..fab201ecce --- /dev/null +++ 
b/src/ai/backend/runner/su-exec.ubuntu16.04.x86_64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b4a3f844430fd9e32c8fa11783ddab8b185fc0d810581a0d714e999687bc06dd +size 15808 diff --git a/src/ai/backend/runner/su-exec.ubuntu18.04.aarch64.bin b/src/ai/backend/runner/su-exec.ubuntu18.04.aarch64.bin new file mode 100755 index 0000000000..df8b270610 --- /dev/null +++ b/src/ai/backend/runner/su-exec.ubuntu18.04.aarch64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6fae9967e6a730a72cf057cf0b074b81ead325976f91fe13455c00829766cb9a +size 17760 diff --git a/src/ai/backend/runner/su-exec.ubuntu18.04.x86_64.bin b/src/ai/backend/runner/su-exec.ubuntu18.04.x86_64.bin new file mode 100755 index 0000000000..1e351f4765 --- /dev/null +++ b/src/ai/backend/runner/su-exec.ubuntu18.04.x86_64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7e525ab77eeef5463379befe654014192602dcdeda1527c0f2033ba4cf3e62d8 +size 16784 diff --git a/src/ai/backend/runner/su-exec.ubuntu20.04.aarch64.bin b/src/ai/backend/runner/su-exec.ubuntu20.04.aarch64.bin new file mode 100755 index 0000000000..711d19cb08 --- /dev/null +++ b/src/ai/backend/runner/su-exec.ubuntu20.04.aarch64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:50d750bba958688df2ce68a8f9a8831a33e1192c2b9c5610a76dd9bcf76f43df +size 17992 diff --git a/src/ai/backend/runner/su-exec.ubuntu20.04.x86_64.bin b/src/ai/backend/runner/su-exec.ubuntu20.04.x86_64.bin new file mode 100755 index 0000000000..30d088073c --- /dev/null +++ b/src/ai/backend/runner/su-exec.ubuntu20.04.x86_64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ee966945d4da3180f3791989fb6e9b79c20afec7fabda8cef673c96e03abf733 +size 21328 diff --git a/src/ai/backend/runner/terminfo.alpine3.8/s/screen b/src/ai/backend/runner/terminfo.alpine3.8/s/screen new file mode 100644 index 0000000000000000000000000000000000000000..3687e9ed31adf5b78e3a9174ddf8bdc09ea68283 GIT binary patch literal 1660 zcmchYJxo(k6vzL!5=4Wb#P4TJ9bQyE-&bjCqZOoDlu{IghAvdTTG|Jt9|(d(6CKsT z#OUDSVsvmb8l#JoiGzc2ad2>QaBy&P@OjRCMMGfH#23#0{LZ=O-t*plua~`|5lz^N zjJW41juytpr9f4G))t5Vn zKGa}0oOP}m9LG-ws6sW>?7@DhIE)sm4WJum5XT^K0%N3UWHF1fTxS7`xPt4%H*s5L z?&3Zk5Iw{rJR$!KFYy|0Ld+MAf2JG>Bco zarTjUg(@1!^VKR0aiqd-X0>+GSVZ)SAu%E{q97Kf}< z3?<%4#?1_!qqAh^$oeUx`zVvt{Za!`gRF)oSt3cii*oH8aebVgBHlqkSkEcPq)#hH zp+pGf)^RThu@#^@a5KkOfV<0@o)zF8GDFL|Ks`p~lWJOyyy2tm2=Ji$hQ^@ydAqrb zriRsgyIKnAYKT!MG+i_ZEf-Bfzl#>3&qW`hp@wXp7gS4i>Ug~xf)d3hG&MChx3sjj zYTC$1Dm5`NIhjspGE-C8Y%XWp)6+9E`TXo`p-?QAO6Bt0+{KIY^9u`?E-fxDEnU96 z%sVy=lN-T~IQ>tX0KPNL7vV#12`8iCv0~}cu!Uj6AYtev{^3di3+HtCWpMrnB ze^6yL-E%E{Bn{X>=eL3bI7t5~+F-)Zw?W3Mj-n%hQyLNjI3GYnL&QYChJk)0V#r2% XyL`w-<1x~df$RmH%tUt7z`TuLdM_rP literal 0 HcmV?d00001 diff --git a/src/ai/backend/runner/terminfo.alpine3.8/s/screen-256color b/src/ai/backend/runner/terminfo.alpine3.8/s/screen-256color new file mode 100644 index 0000000000000000000000000000000000000000..cba35030f8c8c4494db40534d103bee81ebb44fa GIT binary patch literal 1995 zcmcgtJ#1T56#lLg6^cqLDt`5#GDufY5=ZYnKmXN$X_^q1#&&2D8pTUe+lgJCs7?@bhckXjQRx_Yf;>qVb-}%lt z_q=zX?RzRPf)SiRMGX(N3e8e!E#yX`h4tEcb7v|$ADlH>aI4ZT2bl<(ObbxvBaXzN zu))hu&QQFU_9(6QAFuwp>=?!oc0Yzw&^V0{SBT?0E+B_j$rn&0mQlq9HqGkW=;Afp zAb$(*80B5Oj}OUi;bVM4{~5l*H~7xvdcE)a{s;OW`}`;RpZmPm`vt$^H~fx2@F)J_ z!yZw`)JgS->h&I{1(jCM(#tibTy>_;ePs~GnVVEubxAF$ifXE^x=!WIJ34L}<(9gw zK2@KqFV)xTfk3bKAnl;*Vd|F#a`}NsdM@9g9JJEiYpa3!|1_y}$T^)u_TA0iyxyOA zBk*m&$}i#`+z%_SE4tz6Vp2C;U2IYauM*N${$;ugH1YkSseh?opT>QZ(zKKtsqq)W 
zcCT!rGWj`J86(w*Gqenj6C-tkBam`zFCmq#M4qsiu&@d1{6nzvbpcZ-t9Syg$-5q_ zc(bQ%O4I)EJy^MMnhbLtiC&>yq|H)sB2@BD%wXJLLeRAhHT`UG+*t-G zr)IrieA#*))+DLfI+ASFTj6wOYNtwzj_BXl!gWo2^#6y}5by>eg1Lv%P)o zTDQBiv%9+|JGKbxnQ{M>coMr|gN8fubqXhvUP4~4*)4PL-T1D4PPfBmJ9J50*NGY} zVFKnux~NOxcrudotCF&W|LFa8?|^p49_aT5`iJ=U{v@8p7v|*f50Bx0{%;TWO9O^@ zQXIuGesxdcA)Yo*;3>`+g@XvrAcYB@8(GX@5ldJ>4NY{ghwFF)Z{sH3!v{PiKEiF; zMw;Ay0cRI5Wn*Cu^Knetn6)uCi-ioT6X(JLo9A zw7YC~8&OQ8%7-LqG-wFXlmJ0P;0MO2DTxqJq7f4{n*Koffr-i=?H?VV_q=yTH^x{u zd!FY#@AIDboH<|j?9A<>8)!StsDUkw@p`ScxOcT#t1s*^Q9Qmhzf@m8Tp2D5S*5Tt z)0{3?xp1m6vp7{iP{BquNGT%x+hQxQZH=AnKShHuKY(%_O6xNJvYcq)&{nj*o_1o) zH_~2YDA5t>L)N=t-$RqYSz4r3TDRFhLl4qJ^a$*y=-bxwEPaPwqF3kw{g{3V&qew@ zy+?nsTC0^)GNfz3*Jbucz#q-*R_o)cOYKp<856)x({3n}6yP>YG+MufD5ZQ5V#k>TUHq^{)D}x|Fq0xh=UXbFJ3)+zw!?wL909 zGr1_2%{(>XRE$S@6gxlg6`6Lb*%U6L%O2x()a2~y{J#=b^SSgPCu@{uAkP=>KFC*^acHf z{<;2@zNp{Tf6{-||Ik~V4>})qb~-mWdz{coo!gxw&Qa&MGv>@Xb!W}_tn-lbnDbSq z)Bj}|PdewF=baxqul}FyE-U%}2>MCZ=g*w&Tk_T8bOl{a`KmkPF8l68>Mk31q7Fwe z1?b6F%@Ob&D9l;Z;eVOZpogx6s|TfQqInz<7`~ae)oD}?pxh?T)I{tVMCqseCp>Vp z0Pq>gz899J(Wyzj)N8_H*d20y7K;{3 zjK;IJH30($;fuM(@dG>&erk}G2~#p4(;^+0{ptKZNSKDej@%mjATB;yN-Gy^$=zxW(3o?DVDE?wqZWH>Aw3$M<*wzreW62#$?l{YB_n@Do!o`eB z!pLp?$SoXX918nsAh`p{9Z2pV%WY>5xrK`vmxPg9W*14aIS~%CxHOd9q2vy2ZoYA_ z>2M)Ua6MiGl-+9T`6dXXI4PD+@Mh%_l+v)o$KM_cY&_pRQ`~TGahrkN?rzPx{Mhh_U{S{!UtLQ`^Nkq*@9K!dtH>JF&CK|>x5)o7?e$FLRqX{tW9v_Ml0#<_zB0)~(wVu%?M zK!pp76&5QjR#>b+?5-Sy0)zyF280L{wMcV41O12rA_j;UAYy=s0V1ZCd++7md%5>s z?!A|L?>lr13J?+y8W187YGs5~AXFe!AXFgK!J~IU1wsWv1wsWvJ=)7{flz@^flz@^ ohX=VJV0aMv@F3*jL8!xn5JwHojrl0h#|_mNR#&OPu)0A11P6C+4*&oF literal 0 HcmV?d00001 diff --git a/src/ai/backend/runner/terminfo.alpine3.8/x/xterm b/src/ai/backend/runner/terminfo.alpine3.8/x/xterm new file mode 100644 index 0000000000000000000000000000000000000000..84e576e2b4dd8b9822f41b64f99c495e2726dc6a GIT binary patch literal 3617 zcmb_edu*Fm6+g${7uuphC#nIPwUJ8wXsIak_q4sQtyj`)aj9d|G;62{YErvN8aqjC zw@uch8WIyK6N4c(sc1uEXkrY+w9!c&;;Bg+Z&dcr_77;{Aq_O4ipRubGQV@rXIFzs z3~5pB?|1I`oqO)}eSP;j&d_^lBR#+u2bMeaR&&+hpqO(DQ;orT^JHVHv)CHkHZgc; zZlSh#dT?x|-KjUiA4#EBi9RnVx1zM(6ljRQj1f8S1VLwO@)1%I>l6Du;4O*W?QY}uv&RwN#JhPf@8b&J#vkP)(Ca93yufWf&1YfX$>;c< ze!JKHr1SicQNG3(_&fXz|AL?A7x@qTXZ}~xVj#64_4d@J)YjDYl(gNMDx~(MN~v<{ zKpE+*sci$^nE0c< z-tF$PhODAhu?|@uv+7pUI%%z0cUqse?zJAU9=5)2ebf4u^#kii)+Or&>qYA&>$3HC zi|zIHM*E%i^|oX0utR&7J#3fkaeLCP+4FYGUa>!JpR@0=&)Z+JziK~jKV^U4e%AiE z{YHMlm6+%3U)jI0|7QQA-_9B+4qVa2`pkdS#rOHw@rnQc`TIhW1`Qs*aWx&1veAqXVIw_XW~o;9#7nXw8)8-?L#1z z8S@0vd{9cdDn63~U3-2nN6>)VGfh3ytPsv_+3997H||C_#+$BgUV`Pv0otMLhG)O| z75-UDmyx_IKT&9cy8_)GBV@Y-!o(Y#v^?j<6zH_zUNQ^ zfpCsQJt{$-bH-hBb~v?ZN*{&xUsIHWGuaR|N@ryYl*lPKOL?b*g>~v!5+_bqGvbS! 
zPV73dm?2D0O%GVLEbTYUi%?+iyQ}W_5LK?W|1yVU$8v=CY8*s+Vs; zk$LoESteNxc!bQ#^_^z?a&3yK$tsBh4LY?kLG^wym3u zm2(N_m66-{kz3hMI8b(DUvvAK+t=KFlH1H4ax3Q&&MPCg&MuN9bD|t1adDuz1I-qiileDSKIppO1sQTo}G=O4sc6T}~%(>ZazeIjBWR zEig4VsYRydYprH>W;J+IviLF}&)`loG`W<;7noBE3;VW77G6k2x{KNt!|HYkw|KX^ z$-I%aQwH`;=H+sjrf8NH=@hNu#d8;Zj?UAA^cDIBJxSlDrxEjGx@uj%)&U#7p% zKPkl<_-cL^-@p#uG(KK5dw7J)c+X67jpw<=D|`oklJDkw`F{Q~Kgy5uMgCs@Me_{* zlrP}{XW>n@fv&>4@EvpwMwZ6wY%^`4t$3kjD54lIw_&`}$}~n3G)dDmLk()-jdm7e z{SOAdc>PJ2W0<%&ZVrBEqS$_hoACDJUBW{EUQq*+1@2{k0t zkWfRyen)a`q9dhHDN@QR$rlkJzEYqRDn&|JAjg385*n8Sep)F8au5m-5)c{?A`t56gs4EMK&U{dK&X`iw?PF$1wsWv1wuVgl(s;qK&U{d pK&aytDF_&^Kp(F_9`Zq~n#h(BG literal 0 HcmV?d00001 diff --git a/src/ai/backend/runner/terminfo.alpine3.8/x/xterm+256color b/src/ai/backend/runner/terminfo.alpine3.8/x/xterm+256color new file mode 100644 index 0000000000000000000000000000000000000000..ccccb4a48e18561cceff38e6fa69f17dc4a406dc GIT binary patch literal 1090 zcmb1RQfH81;AhBWe8Nysl3J9jZDeYeoS&0lR0HNJ0J*wgzCv1RVo7OHDgz$C$nc*G zz`zJHl03a+SU0L=1cd+t13NfP{`>!r&c04%Si&G3YiMA?4x~)1RSOJNQ>=^9ldMz= zj8v}=Si*`c;6#XVB`T<7yL|Fg; literal 0 HcmV?d00001 diff --git a/src/ai/backend/runner/terminfo.alpine3.8/x/xterm-256color b/src/ai/backend/runner/terminfo.alpine3.8/x/xterm-256color new file mode 100644 index 0000000000000000000000000000000000000000..225a17bd6f4897515e28e28d4a3d1a5282dbbb8d GIT binary patch literal 3713 zcmb_eZH!!18Gg^4yR7R65J;oVXx+wMOR8<}J@;c~dY2AmyRcKbJ43fy*AAPs`_cXW z=*-Y=+1*N_hF}x?iVz|M0x=j!Boc^V5KKslQ3*j!gkQu2HE0YmX!^sF@p;dCXUfJJ z`@_wi=XuZjyyrb<&euI#GA@0duAzrj=Bky(WNl%miQ@64`K8*Lt#+TDX-s#+*KM73 zB4z)XU|3{>|Jya#6*xbF@+p+&75=3;(ZZpt(0Vg%!Ps`tE@UXs0Xm4RUxNK*Isu%e zMOvXXoBeM38r@6x!G4r}Ku?07rf2B|dWl}8U(s(+e}~?s_voKCy4mbhGNjGm&!qNe z!P`>1*}Or0Q5hAhJ!-GouMVh+I->4S$DmK5%&0|GSEtok*k4!os&BX3ZTzVZs>iJI zlzLYEOr2M+t2fmj)L+#5>SEfWBh!_+I&*zyd&bLf+nY0enM!6TGn^U8JO|J7*4xVZ z({}A8Snu1~h0F(SZ!1n%>rTDVb0xKp**4d;TW;3fdb_Rd0!FF7r>!@ey}G2!`iLIY z6M9b9^{W1=KBvE{zpo$HPw8j$3;O5!RsFjDy?$H&S-+<*=%(&;u68y#pLRaye8CBv z#OZZzbB3I%Gw#efC!H1NE6zR6H=PHZN1P{|r=6cTFFWU**PP!uSMm!k#k}RbXrg%>MHs%Bx7BG>dQtX?Gcgf+ zDkuY#y}<)V3jm*?=zC#d8l9TNOS}d=hTS2z$g`DeL&%mL^9a)1%q3f8pUr`;P0-2_ z&ckhNQ)8PIqv>ruzUS@Ui*SrLTiv=0D@a1LL)i<@LHjF$vy?3(d6~y+e6HP&atLJz z9zTMo>gO!xEfyG!2Q2rq#m0UZG5a~Z%#-0K23eUfEoGS+>9^=lXK#drpAp!BTZ1RS z{#YJ|H3nJ+=dqrfS_Y5IdJeV>9)R_5-_|(XGPv&h$TAl{07Vx@&cHZEb!1EHA=Qlwyu&w3x%%mv2Gg*$fh1EM60MnC$ws zopuqW7PXx`pN2O{)r4tdULiEbHetd}VUo>~Hp``)Pq`p$yp$1dGo}MFVH%f2k~@;z zk>rlj-04HZ!K6MH<8+SJ${dj?Dtt-CvL7Cg^ecXzvHml@g3 zyL9!xDB>b=2dcuBT>_Opk;Jd5&b+^ZD1h!*$(z^WE`U))_$IzV&u&6RTdki)< z4=>Z)_cpW-H@3&#Um%S~REf?}UE@7wUps^^LbG{2USiGW4)TC|tp6w-r)gTEQ?!oP z(>eMMJxGtz6Z9i`j$Wi!5c5lVgWjaK=})l#M*pCHQATyCYt<*!7Ukl76ySBVPxY%Z z-bSNpLd~g~I-~AZ_o(~S_tZn`arLD7v3g#;+96#6xF{I1$^Q8NhlNv#X>nI zlO94Dl1ZV#yLqmRPbx5)w&BBq5Q6MD33F@Y)<11wsWv1wsWv9jaYMC*l@+QptSr#K0hz?*l>h($ literal 0 HcmV?d00001 diff --git a/src/ai/backend/runner/tmux.glibc.aarch64.bin b/src/ai/backend/runner/tmux.glibc.aarch64.bin new file mode 100755 index 0000000000..de57d2b278 --- /dev/null +++ b/src/ai/backend/runner/tmux.glibc.aarch64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a4c9f89c8b6ba146e521347c24e430ba49a40b8a6886c40ce160e14f882d153a +size 2499160 diff --git a/src/ai/backend/runner/tmux.glibc.x86_64.bin b/src/ai/backend/runner/tmux.glibc.x86_64.bin new file mode 100755 index 0000000000..9e093c9441 --- /dev/null +++ b/src/ai/backend/runner/tmux.glibc.x86_64.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:c4bddedd23d0955afdd977feaa8c2e1a23e6bca6d68d78f73dec9179595ca62b
+size 2922768
diff --git a/src/ai/backend/runner/tmux.musl.aarch64.bin b/src/ai/backend/runner/tmux.musl.aarch64.bin
new file mode 100755
index 0000000000..c239998399
--- /dev/null
+++ b/src/ai/backend/runner/tmux.musl.aarch64.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0a2457e9115de2b8428ac46b32a03be715a1744e18136d7f748da3eedfd74e00
+size 2456544
diff --git a/src/ai/backend/runner/tmux.musl.x86_64.bin b/src/ai/backend/runner/tmux.musl.x86_64.bin
new file mode 100755
index 0000000000..d8f71e8467
--- /dev/null
+++ b/src/ai/backend/runner/tmux.musl.x86_64.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0b9165cfe09e5f8c4362b7b52d5d267bfc56b52e8f0a60adc8b07c477a21d722
+size 2416808
diff --git a/src/ai/backend/storage/BUILD b/src/ai/backend/storage/BUILD
new file mode 100644
index 0000000000..c567443a67
--- /dev/null
+++ b/src/ai/backend/storage/BUILD
@@ -0,0 +1,44 @@
+python_sources(
+    name="service",
+    sources=["**/*.py"],
+    dependencies=[
+        "src/ai/backend/common:lib",
+        ":resources",
+    ],
+)
+
+pex_binary(
+    name="server",
+    dependencies=[
+        ":service",
+    ],
+    entry_point="server.py",
+)
+
+python_distribution(
+    name="dist",
+    dependencies=[
+        ":service",
+        "!!stubs/trafaret:stubs",
+    ],
+    provides=python_artifact(
+        name="backend.ai-storage-proxy",
+        description="Backend.AI Storage Proxy",
+        license="LGPLv3",
+    ),
+    entry_points={},
+    generate_setup=True,
+    tags=["wheel"],
+)
+
+resource(name="version", source="VERSION")
+
+resources(
+    name="resources",
+    dependencies=[
+        ":version",
+    ],
+    sources=[
+        "**/py.typed",
+    ],
+)
diff --git a/src/ai/backend/storage/README.md b/src/ai/backend/storage/README.md
new file mode 100644
index 0000000000..6df908efa6
--- /dev/null
+++ b/src/ai/backend/storage/README.md
@@ -0,0 +1,162 @@
+# Backend.AI Storage Proxy
+Backend.AI Storage Proxy is an RPC daemon to manage vfolders used in the Backend.AI Agent, with quota and
+storage-specific optimization support.
+
+
+## Package Structure
+* `ai.backend.storage`
+  - `server`: The server daemon which communicates with the Backend.AI Manager
+  - `api.client`: The client-facing API to handle the tus.io server-side protocol for uploads and ranged HTTP
+    queries for downloads.
+  - `api.manager`: The manager-facing (internal) API to provide an abstraction of volumes and separation of
+    the hardware resources for volume and file operations.
+  - `vfs`
+    - The minimal fallback backend which only uses the standard Linux filesystem interfaces
+  - `xfs`
+    - XFS-optimized backend with a small daemon to manage XFS project IDs for quota limits
+    - `agent`: Implementation of `AbstractVolumeAgent` with XFS support
+  - `purestorage`
+    - PureStorage's FlashBlade-optimized backend with RapidFile Toolkit (formerly PureTools)
+  - `netapp`
+    - NetApp QTree integration backend based on the NetApp ONTAP REST API
+  - `cephfs` (TODO)
+    - CephFS-optimized backend with quota limit support
+
+
+## Installation
+
+### Prerequisites
+* Python 3.8 or higher with [pyenv](https://github.com/pyenv/pyenv)
+and [pyenv-virtualenv](https://github.com/pyenv/pyenv-virtualenv) (optional but recommended)
+
+### Installation Process
+
+First, prepare a source clone of the storage proxy:
+```console
+# git clone https://github.com/lablup/backend.ai-storage-agent
+```
+
+From now on, let's assume all shell commands are executed inside the virtualenv.
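A minimal sketch of preparing such a virtualenv with pyenv and pyenv-virtualenv, for illustration only: the Python version and the virtualenv name below are arbitrary placeholders (not part of this patch), and the clone directory name simply follows the repository URL above.

```console
# pyenv install 3.10.4
# pyenv virtualenv 3.10.4 venv-bai-storage
# cd backend.ai-storage-agent
# pyenv local venv-bai-storage
```

With `pyenv local`, the virtualenv is activated automatically whenever you enter the clone directory, provided pyenv-virtualenv's shell integration is enabled.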
+
+Now install dependencies:
+```console
+# pip install -U -r requirements/dist.txt # for deployment
+# pip install -U -r requirements/dev.txt # for development
+```
+
+Then, copy the sample configuration to the project root as `storage-proxy.toml` and edit it to match your machine:
+```console
+# cp config/sample.toml storage-proxy.toml
+```
+
+When done, start the storage server:
+```console
+# python -m ai.backend.storage.server
+```
+
+It will start the Storage Proxy daemon bound to `127.0.0.1:6021` (client API) and
+`127.0.0.1:6022` (manager API).
+
+NOTE: Depending on the backend, the server may need to be run as root.
+
+### Production Deployment
+
+To benefit from the OS-provided `sendfile()` syscall for file transfers,
+terminate SSL at a reverse proxy such as nginx and run the storage proxy
+daemon itself without SSL.
+
+
+## Filesystem Backends
+
+### VFS
+
+#### Prerequisites
+
+* User account permission to access the given directory
+  - Make sure the directory you want to mount, such as `/vfroot/vfs`, exists
+
+
+### XFS
+
+#### Prerequisites
+
+* Local device mounted under `/vfroot`
+* Native support for the XFS filesystem
+  - Mount the XFS volume with the `-o pquota` option to enable project quotas
+  - To turn on quotas on the root filesystem, the quota mount flags must be
+    set with the `rootflags=` boot parameter. Usually, this is not recommended.
+* Access to root privilege
+  - Execution of `xfs_quota`, which performs quota-related commands, requires
+    the `root` privilege.
+  - Thus, you need to start the Storage-Proxy service as the `root` user or a
+    user with passwordless sudo access.
+  - If the root user starts the Storage-Proxy, the owner of every file created
+    is also root. In some situations, this would not be the desired setting.
+    In that case, it might be better to start the service with a regular user
+    with passwordless sudo privilege.
+
+#### Creating a virtual XFS device for testing
+
+Create a virtual block device backed by a loopback device if you are the only
+one using the storage for testing:
+
+1. Create a file with your desired size
+```console
+# dd if=/dev/zero of=xfs_test.img bs=1G count=100
+```
+2. Format the file as an XFS partition
+```console
+# mkfs.xfs xfs_test.img
+```
+3. Attach it to a loopback device
+```console
+# export LODEVICE=$(losetup -f)
+# losetup $LODEVICE xfs_test.img
+```
+4. Create the mount point and mount the loopback device with the `pquota` option
+```console
+# mkdir -p /vfroot/xfs
+# mount -o loop -o pquota $LODEVICE /vfroot/xfs
+```
+
+#### Note on operation
+
+XFS keeps quota mapping information in two files: `/etc/projects` and
+`/etc/projid`. If they are deleted or damaged in any way, per-directory quota
+information will also be lost. So, it is crucial not to delete them
+accidentally. If possible, it is a good idea to back them up to a different
+disk or NFS. (An illustrative sketch of these files appears at the end of this README.)
+
+
+### PureStorage FlashBlade
+
+#### Prerequisites
+
+* NFSv3 export mounted under `/vfroot`
+* Purity API access
+
+
+### CephFS
+
+#### Prerequisites
+
+* FUSE export mounted under `/vfroot`
+
+
+### NetApp ONTAP
+
+#### Prerequisites
+
+* NFSv3 export mounted under `/vfroot`
+* NetApp ONTAP API access
+* Native NetApp XCP or a Dockerized NetApp XCP container
+  - To install NetApp XCP, please refer to the [NetApp XCP install guide](https://xcp.netapp.com/)
+* A Qtree created explicitly in the volume using the NetApp ONTAP Sysmgr GUI
+
+
+#### Note on operation
+The volume host of the Backend.AI Storage Proxy corresponds to a NetApp ONTAP Qtree, not a NetApp ONTAP volume.
+Please DO NOT remove a Backend.AI-mapped Qtree in the NetApp ONTAP Sysmgr GUI; otherwise, the NetApp ONTAP volume becomes inaccessible through Backend.AI.
+
+> NOTE:
+> The Qtree name in the configuration file (`storage-proxy.toml`) must match the name created in NetApp ONTAP Sysmgr.
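For illustration only (this sketch is not part of the patch): the XFS backend's per-directory quotas boil down to entries in `/etc/projects` and `/etc/projid` plus `xfs_quota` invocations, roughly as below. The project ID `1000`, the project name `vfolder1000`, the `<vfolder-dir>` placeholder, and the `10g` limit are made-up values; the storage proxy's XFS backend is expected to manage these entries itself, so this only shows what to expect when inspecting a host by hand.

```console
# echo "1000:/vfroot/xfs/<vfolder-dir>" >> /etc/projects
# echo "vfolder1000:1000" >> /etc/projid
# xfs_quota -x -c "project -s vfolder1000" /vfroot/xfs
# xfs_quota -x -c "limit -p bhard=10g vfolder1000" /vfroot/xfs
# xfs_quota -x -c "report -p" /vfroot/xfs
```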
diff --git a/src/ai/backend/storage/VERSION b/src/ai/backend/storage/VERSION
new file mode 120000
index 0000000000..a4e948506b
--- /dev/null
+++ b/src/ai/backend/storage/VERSION
@@ -0,0 +1 @@
+../../../../VERSION
\ No newline at end of file
diff --git a/src/ai/backend/storage/__init__.py b/src/ai/backend/storage/__init__.py
new file mode 100644
index 0000000000..17b3552989
--- /dev/null
+++ b/src/ai/backend/storage/__init__.py
@@ -0,0 +1,3 @@
+from pathlib import Path
+
+__version__ = (Path(__file__).parent / 'VERSION').read_text().strip()
diff --git a/src/ai/backend/storage/abc.py b/src/ai/backend/storage/abc.py
new file mode 100644
index 0000000000..d9cb52a458
--- /dev/null
+++ b/src/ai/backend/storage/abc.py
@@ -0,0 +1,251 @@
+from __future__ import annotations
+
+from abc import ABCMeta, abstractmethod
+from pathlib import Path, PurePath, PurePosixPath
+from typing import Any, AsyncIterator, Final, FrozenSet, Mapping, Sequence
+from uuid import UUID
+
+from ai.backend.common.types import BinarySize, HardwareMetadata
+
+from .exception import InvalidSubpathError, VFolderNotFoundError
+from .types import (
+    DirEntry,
+    FSPerfMetric,
+    FSUsage,
+    VFolderCreationOptions,
+    VFolderUsage,
+)
+
+# Available capabilities of a volume implementation
+CAP_VFOLDER: Final = "vfolder"
+CAP_VFHOST_QUOTA: Final = "vfhost-quota"
+CAP_METRIC: Final = "metric"
+CAP_QUOTA: Final = "quota"
+CAP_FAST_SCAN: Final = "fast-scan"
+
+
+class AbstractVolume(metaclass=ABCMeta):
+    def __init__(
+        self,
+        local_config: Mapping[str, Any],
+        mount_path: Path,
+        *,
+        fsprefix: PurePath = None,
+        options: Mapping[str, Any] = None,
+    ) -> None:
+        self.local_config = local_config
+        self.mount_path = mount_path
+        self.fsprefix = fsprefix or PurePath(".")
+        self.config = options or {}
+
+    async def init(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        pass
+
+    def mangle_vfpath(self, vfid: UUID) -> Path:
+        prefix1 = vfid.hex[0:2]
+        prefix2 = vfid.hex[2:4]
+        rest = vfid.hex[4:]
+        return Path(self.mount_path, prefix1, prefix2, rest)
+
+    def sanitize_vfpath(
+        self,
+        vfid: UUID,
+        relpath: PurePosixPath = PurePosixPath("."),
+    ) -> Path:
+        vfpath = self.mangle_vfpath(vfid).resolve()
+        if not (vfpath.exists() and vfpath.is_dir()):
+            raise VFolderNotFoundError(vfid)
+        target_path = (vfpath / relpath).resolve()
+        if not target_path.is_relative_to(vfpath):
+            raise InvalidSubpathError(vfid, relpath)
+        return target_path
+
+    def strip_vfpath(self, vfid: UUID, target_path: Path) -> PurePosixPath:
+        vfpath = self.mangle_vfpath(vfid).resolve()
+        return PurePosixPath(target_path.relative_to(vfpath))
+
+    # ------ volume operations -------
+
+    @abstractmethod
+    async def get_capabilities(self) -> FrozenSet[str]:
+        pass
+
+    @abstractmethod
+    async def get_hwinfo(self) -> HardwareMetadata:
+        pass
+
+    @abstractmethod
+    async def create_vfolder(
+        self,
+        vfid: UUID,
+        options: VFolderCreationOptions = None,
+        *,
+        exist_ok: bool = False,
+    ) -> None:
+        pass
+
+    @abstractmethod
+    async def delete_vfolder(self, vfid: UUID) -> None:
+        pass
+
+    @abstractmethod
+    async def clone_vfolder(
+        self,
+        src_vfid: UUID,
+        dst_volume: AbstractVolume,
+        dst_vfid: UUID,
+        options: VFolderCreationOptions = None,
+    ) -> None:
+        """
+        Create a new vfolder on the destination volume
with + ``exist_ok=True`` option and copy all contents of the source + vfolder into it, preserving file permissions and timestamps. + """ + pass + + @abstractmethod + async def copy_tree( + self, + src_vfpath: Path, + dst_vfpath: Path, + ) -> None: + """ + The actual backend-specific implementation of copying + files from a directory to another in an efficient way. + The source and destination are in the same filesystem namespace + but they may be on different physical media. + """ + pass + + @abstractmethod + async def get_vfolder_mount(self, vfid: UUID, subpath: str) -> Path: + pass + + @abstractmethod + async def put_metadata(self, vfid: UUID, payload: bytes) -> None: + pass + + @abstractmethod + async def get_metadata(self, vfid: UUID) -> bytes: + pass + + @abstractmethod + async def get_performance_metric(self) -> FSPerfMetric: + pass + + @abstractmethod + async def get_quota(self, vfid: UUID) -> BinarySize: + pass + + @abstractmethod + async def set_quota(self, vfid: UUID, size_bytes: BinarySize) -> None: + pass + + @abstractmethod + async def get_fs_usage(self) -> FSUsage: + pass + + @abstractmethod + async def get_usage( + self, + vfid: UUID, + relpath: PurePosixPath = PurePosixPath("."), + ) -> VFolderUsage: + pass + + @abstractmethod + async def get_used_bytes(self, vfid: UUID) -> BinarySize: + pass + + # ------ vfolder operations ------- + + @abstractmethod + def scandir(self, vfid: UUID, relpath: PurePosixPath) -> AsyncIterator[DirEntry]: + pass + + @abstractmethod + async def mkdir( + self, + vfid: UUID, + relpath: PurePosixPath, + *, + parents: bool = False, + exist_ok: bool = False, + ) -> None: + pass + + @abstractmethod + async def rmdir( + self, + vfid: UUID, + relpath: PurePosixPath, + *, + recursive: bool = False, + ) -> None: + pass + + @abstractmethod + async def move_file( + self, + vfid: UUID, + src: PurePosixPath, + dst: PurePosixPath, + ) -> None: + pass + + @abstractmethod + async def move_tree( + self, + vfid: UUID, + src: PurePosixPath, + dst: PurePosixPath, + ) -> None: + pass + + @abstractmethod + async def copy_file( + self, + vfid: UUID, + src: PurePosixPath, + dst: PurePosixPath, + ) -> None: + pass + + @abstractmethod + async def prepare_upload(self, vfid: UUID) -> str: + """ + Prepare an upload session by creating a dedicated temporary directory. + Returns a unique session identifier. 
+ """ + pass + + @abstractmethod + async def add_file( + self, + vfid: UUID, + relpath: PurePosixPath, + payload: AsyncIterator[bytes], + ) -> None: + pass + + @abstractmethod + def read_file( + self, + vfid: UUID, + relpath: PurePosixPath, + *, + chunk_size: int = 0, + ) -> AsyncIterator[bytes]: + pass + + @abstractmethod + async def delete_files( + self, + vfid: UUID, + relpaths: Sequence[PurePosixPath], + recursive: bool = False, + ) -> None: + pass diff --git a/src/ai/backend/storage/api/__init__.py b/src/ai/backend/storage/api/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/ai/backend/storage/api/client.py b/src/ai/backend/storage/api/client.py new file mode 100644 index 0000000000..2e0f5a8ceb --- /dev/null +++ b/src/ai/backend/storage/api/client.py @@ -0,0 +1,348 @@ +""" +Client-facing API +""" + +import asyncio +import json +import logging +import os +import urllib.parse +from pathlib import Path +from typing import Any, Final, Mapping, MutableMapping, cast + +import aiohttp_cors +import janus +import trafaret as t +import zipstream +from aiohttp import hdrs, web + +from ai.backend.common import validators as tx +from ai.backend.common.files import AsyncFileWriter +from ai.backend.common.logging import BraceStyleAdapter + +from ..abc import AbstractVolume +from ..context import Context +from ..exception import InvalidAPIParameters +from ..types import SENTINEL +from ..utils import CheckParamSource, check_params + +log = BraceStyleAdapter(logging.getLogger(__name__)) + +DEFAULT_CHUNK_SIZE: Final = 256 * 1024 # 256 KiB +DEFAULT_INFLIGHT_CHUNKS: Final = 8 + + +download_token_data_iv = t.Dict( + { + t.Key("op"): t.Atom("download"), + t.Key("volume"): t.String, + t.Key("vfid"): tx.UUID, + t.Key("relpath"): t.String, + t.Key("archive", default=False): t.Bool, + t.Key("unmanaged_path", default=None): t.Null | t.String, + }, +).allow_extra( + "*", +) # allow JWT-intrinsic keys + +upload_token_data_iv = t.Dict( + { + t.Key("op"): t.Atom("upload"), + t.Key("volume"): t.String, + t.Key("vfid"): tx.UUID, + t.Key("relpath"): t.String, + t.Key("session"): t.String, + t.Key("size"): t.Int, + }, +).allow_extra( + "*", +) # allow JWT-intrinsic keys + + +async def download(request: web.Request) -> web.StreamResponse: + ctx: Context = request.app["ctx"] + secret = ctx.local_config["storage-proxy"]["secret"] + async with check_params( + request, + t.Dict( + { + t.Key("token"): tx.JsonWebToken( + secret=secret, + inner_iv=download_token_data_iv, + ), + t.Key("archive", default=False): t.ToBool, + t.Key("no_cache", default=False): t.ToBool, + }, + ), + read_from=CheckParamSource.QUERY, + ) as params: + async with ctx.get_volume(params["token"]["volume"]) as volume: + token_data = params["token"] + if token_data["unmanaged_path"] is not None: + vfpath = Path(token_data["unmanaged_path"]) + else: + vfpath = volume.mangle_vfpath(token_data["vfid"]) + try: + file_path = (vfpath / token_data["relpath"]).resolve() + file_path.relative_to(vfpath) + if not file_path.exists(): + raise FileNotFoundError + except (ValueError, FileNotFoundError): + raise web.HTTPNotFound( + body=json.dumps( + { + "title": "File not found", + "type": "https://api.backend.ai/probs/storage/file-not-found", + }, + ), + content_type="application/problem+json", + ) + if not file_path.is_file(): + if params["archive"]: + # Download directory as an archive when archive param is set. 
+ return await download_directory_as_archive(request, file_path) + else: + raise InvalidAPIParameters("The file is not a regular file.") + if request.method == "HEAD": + return web.Response( + status=200, + headers={ + hdrs.ACCEPT_RANGES: "bytes", + hdrs.CONTENT_LENGTH: str(file_path.stat().st_size), + }, + ) + ascii_filename = ( + file_path.name.encode("ascii", errors="ignore") + .decode("ascii") + .replace('"', r"\"") + ) + encoded_filename = urllib.parse.quote(file_path.name, encoding="utf-8") + headers = { + hdrs.CONTENT_TYPE: "application/octet-stream", + hdrs.CONTENT_DISPOSITION: " ".join( + [ + "attachment;" f'filename="{ascii_filename}";', # RFC-2616 sec2.2 + f"filename*=UTF-8''{encoded_filename}", # RFC-5987 + ], + ), + } + if params["no_cache"]: + headers[hdrs.CACHE_CONTROL] = "no-store" + return web.FileResponse(file_path, headers=cast(Mapping[str, str], headers)) + + +async def download_directory_as_archive( + request: web.Request, + file_path: Path, + zip_filename: str = None, +) -> web.StreamResponse: + """ + Serve a directory as a zip archive on the fly. + """ + + def _iter2aiter(iter): + """Iterable to async iterable""" + + def _consume(loop, iter, q): + for item in iter: + q.put(item) + q.put(SENTINEL) + + async def _aiter(): + loop = asyncio.get_running_loop() + q = janus.Queue(maxsize=DEFAULT_INFLIGHT_CHUNKS) + try: + fut = loop.run_in_executor(None, lambda: _consume(loop, iter, q.sync_q)) + while True: + item = await q.async_q.get() + if item is SENTINEL: + break + yield item + q.async_q.task_done() + await fut + finally: + q.close() + await q.wait_closed() + + return _aiter() + + if zip_filename is None: + zip_filename = file_path.name + ".zip" + zf = zipstream.ZipFile(compression=zipstream.ZIP_DEFLATED) + async for root, dirs, files in _iter2aiter(os.walk(file_path)): + for file in files: + zf.write(Path(root) / file, Path(root).relative_to(file_path) / file) + if len(dirs) == 0 and len(files) == 0: + # Include an empty directory in the archive as well. + zf.write(root, Path(root).relative_to(file_path)) + ascii_filename = ( + zip_filename.encode("ascii", errors="ignore") + .decode("ascii") + .replace('"', r"\"") + ) + encoded_filename = urllib.parse.quote(zip_filename, encoding="utf-8") + response = web.StreamResponse( + headers={ + hdrs.CONTENT_TYPE: "application/zip", + hdrs.CONTENT_DISPOSITION: " ".join( + [ + "attachment;" f'filename="{ascii_filename}";', # RFC-2616 sec2.2 + f"filename*=UTF-8''{encoded_filename}", # RFC-5987 + ], + ), + }, + ) + await response.prepare(request) + async for chunk in _iter2aiter(zf): + await response.write(chunk) + return response + + +async def tus_check_session(request: web.Request) -> web.Response: + """ + Check the availability of an upload session. + """ + ctx: Context = request.app["ctx"] + secret = ctx.local_config["storage-proxy"]["secret"] + async with check_params( + request, + t.Dict( + { + t.Key("token"): tx.JsonWebToken( + secret=secret, + inner_iv=upload_token_data_iv, + ), + }, + ), + read_from=CheckParamSource.QUERY, + ) as params: + token_data = params["token"] + async with ctx.get_volume(token_data["volume"]) as volume: + headers = await prepare_tus_session_headers(request, token_data, volume) + return web.Response(headers=headers) + + +async def tus_upload_part(request: web.Request) -> web.Response: + """ + Perform the chunk upload. 
+ """ + ctx: Context = request.app["ctx"] + secret = ctx.local_config["storage-proxy"]["secret"] + async with check_params( + request, + t.Dict( + { + t.Key("token"): tx.JsonWebToken( + secret=secret, + inner_iv=upload_token_data_iv, + ), + }, + ), + read_from=CheckParamSource.QUERY, + ) as params: + token_data = params["token"] + async with ctx.get_volume(token_data["volume"]) as volume: + headers = await prepare_tus_session_headers(request, token_data, volume) + vfpath = volume.mangle_vfpath(token_data["vfid"]) + upload_temp_path = vfpath / ".upload" / token_data["session"] + + async with AsyncFileWriter( + target_filename=upload_temp_path, + access_mode="ab", + max_chunks=DEFAULT_INFLIGHT_CHUNKS, + ) as writer: + while not request.content.at_eof(): + chunk = await request.content.read(DEFAULT_CHUNK_SIZE) + await writer.write(chunk) + + current_size = Path(upload_temp_path).stat().st_size + if current_size >= int(token_data["size"]): + target_path = vfpath / token_data["relpath"] + upload_temp_path.rename(target_path) + try: + loop = asyncio.get_running_loop() + await loop.run_in_executor( + None, + lambda: upload_temp_path.parent.rmdir(), + ) + except OSError: + pass + headers["Upload-Offset"] = str(current_size) + return web.Response(status=204, headers=headers) + + +async def tus_options(request: web.Request) -> web.Response: + """ + Let clients discover the supported features of our tus.io server-side implementation. + """ + ctx: Context = request.app["ctx"] + headers = {} + headers["Access-Control-Allow-Origin"] = "*" + headers[ + "Access-Control-Allow-Headers" + ] = "Tus-Resumable, Upload-Length, Upload-Metadata, Upload-Offset, Content-Type" + headers[ + "Access-Control-Expose-Headers" + ] = "Tus-Resumable, Upload-Length, Upload-Metadata, Upload-Offset, Content-Type" + headers["Access-Control-Allow-Methods"] = "*" + headers["Tus-Resumable"] = "1.0.0" + headers["Tus-Version"] = "1.0.0" + headers["Tus-Max-Size"] = str( + int(ctx.local_config["storage-proxy"]["max-upload-size"]), + ) + headers["X-Content-Type-Options"] = "nosniff" + return web.Response(headers=headers) + + +async def prepare_tus_session_headers( + request: web.Request, + token_data: Mapping[str, Any], + volume: AbstractVolume, +) -> MutableMapping[str, str]: + vfpath = volume.mangle_vfpath(token_data["vfid"]) + upload_temp_path = vfpath / ".upload" / token_data["session"] + if not Path(upload_temp_path).exists(): + raise web.HTTPNotFound( + body=json.dumps( + { + "title": "No such upload session", + "type": "https://api.backend.ai/probs/storage/no-such-upload-session", + }, + ), + content_type="application/problem+json", + ) + headers = {} + headers["Access-Control-Allow-Origin"] = "*" + headers[ + "Access-Control-Allow-Headers" + ] = "Tus-Resumable, Upload-Length, Upload-Metadata, Upload-Offset, Content-Type" + headers[ + "Access-Control-Expose-Headers" + ] = "Tus-Resumable, Upload-Length, Upload-Metadata, Upload-Offset, Content-Type" + headers["Access-Control-Allow-Methods"] = "*" + headers["Cache-Control"] = "no-store" + headers["Tus-Resumable"] = "1.0.0" + headers["Upload-Offset"] = str(Path(upload_temp_path).stat().st_size) + headers["Upload-Length"] = str(token_data["size"]) + return headers + + +async def init_client_app(ctx: Context) -> web.Application: + app = web.Application() + app["ctx"] = ctx + cors_options = { + "*": aiohttp_cors.ResourceOptions( + allow_credentials=True, + allow_methods="*", + expose_headers="*", + allow_headers="*", + ), + } + cors = aiohttp_cors.setup(app, defaults=cors_options) + r = 
cors.add(app.router.add_resource("/download")) + r.add_route("GET", download) + r = app.router.add_resource("/upload") # tus handlers handle CORS by themselves + r.add_route("OPTIONS", tus_options) + r.add_route("HEAD", tus_check_session) + r.add_route("PATCH", tus_upload_part) + return app diff --git a/src/ai/backend/storage/api/manager.py b/src/ai/backend/storage/api/manager.py new file mode 100644 index 0000000000..4ca7c9ac38 --- /dev/null +++ b/src/ai/backend/storage/api/manager.py @@ -0,0 +1,658 @@ +""" +Manager-facing API +""" + +import json +import logging +from contextlib import contextmanager as ctxmgr +from datetime import datetime +from pathlib import Path +from typing import Awaitable, Callable, Iterator, List +from uuid import UUID + +import attr +import jwt +import trafaret as t +from aiohttp import hdrs, web + +from ai.backend.common import validators as tx +from ai.backend.common.logging import BraceStyleAdapter +from ai.backend.storage.exception import ExecutionError + +from ..abc import AbstractVolume +from ..context import Context +from ..exception import InvalidSubpathError, VFolderNotFoundError +from ..types import VFolderCreationOptions +from ..utils import check_params, log_manager_api_entry + +log = BraceStyleAdapter(logging.getLogger(__name__)) + + +@web.middleware +async def token_auth_middleware( + request: web.Request, + handler: Callable[[web.Request], Awaitable[web.StreamResponse]], +) -> web.StreamResponse: + token = request.headers.get("X-BackendAI-Storage-Auth-Token", None) + if not token: + raise web.HTTPForbidden() + ctx: Context = request.app["ctx"] + if token != ctx.local_config["api"]["manager"]["secret"]: + raise web.HTTPForbidden() + return await handler(request) + + +async def get_status(request: web.Request) -> web.Response: + async with check_params(request, None) as params: + await log_manager_api_entry(log, "get_status", params) + return web.json_response( + { + "status": "ok", + }, + ) + + +@ctxmgr +def handle_fs_errors( + volume: AbstractVolume, + vfid: UUID, +) -> Iterator[None]: + try: + yield + except OSError as e: + related_paths = [] + msg = str(e) if e.strerror is None else e.strerror + if e.filename: + related_paths.append(str(volume.strip_vfpath(vfid, Path(e.filename)))) + if e.filename2: + related_paths.append(str(volume.strip_vfpath(vfid, Path(e.filename2)))) + raise web.HTTPBadRequest( + body=json.dumps( + { + "msg": msg, + "errno": e.errno, + "paths": related_paths, + }, + ), + content_type="application/json", + ) + + +async def get_volumes(request: web.Request) -> web.Response: + async def _get_caps(ctx: Context, volume_name: str) -> List[str]: + async with ctx.get_volume(volume_name) as volume: + return [*await volume.get_capabilities()] + + async with check_params(request, None) as params: + await log_manager_api_entry(log, "get_volumes", params) + ctx: Context = request.app["ctx"] + volumes = ctx.list_volumes() + return web.json_response( + { + "volumes": [ + { + "name": name, + "backend": info.backend, + "path": str(info.path), + "fsprefix": str(info.fsprefix), + "capabilities": await _get_caps(ctx, name), + } + for name, info in volumes.items() + ], + }, + ) + + +async def get_hwinfo(request: web.Request) -> web.Response: + async with check_params( + request, + t.Dict( + { + t.Key("volume"): t.String(), + }, + ), + ) as params: + await log_manager_api_entry(log, "get_hwinfo", params) + ctx: Context = request.app["ctx"] + async with ctx.get_volume(params["volume"]) as volume: + data = await volume.get_hwinfo() + return 
web.json_response(data) + + +async def create_vfolder(request: web.Request) -> web.Response: + async with check_params( + request, + t.Dict( + { + t.Key("volume"): t.String(), + t.Key("vfid"): tx.UUID(), + t.Key("options", default=None): t.Null + | VFolderCreationOptions.as_trafaret(), + }, + ), + ) as params: + await log_manager_api_entry(log, "create_vfolder", params) + ctx: Context = request.app["ctx"] + async with ctx.get_volume(params["volume"]) as volume: + obj_opts = VFolderCreationOptions.as_object(params["options"]) + await volume.create_vfolder(params["vfid"], obj_opts) + return web.Response(status=204) + + +async def delete_vfolder(request: web.Request) -> web.Response: + async with check_params( + request, + t.Dict( + { + t.Key("volume"): t.String(), + t.Key("vfid"): tx.UUID(), + }, + ), + ) as params: + await log_manager_api_entry(log, "delete_vfolder", params) + ctx: Context = request.app["ctx"] + async with ctx.get_volume(params["volume"]) as volume: + await volume.delete_vfolder(params["vfid"]) + return web.Response(status=204) + + +async def clone_vfolder(request: web.Request) -> web.Response: + async with check_params( + request, + t.Dict( + { + t.Key("src_volume"): t.String(), + t.Key("src_vfid"): tx.UUID(), + t.Key("dst_volume"): t.String(), + t.Key("dst_vfid"): tx.UUID(), + t.Key("options", default=None): t.Null + | VFolderCreationOptions.as_trafaret(), + }, + ), + ) as params: + await log_manager_api_entry(log, "clone_vfolder", params) + ctx: Context = request.app["ctx"] + async with ctx.get_volume(params["src_volume"]) as src_volume: + async with ctx.get_volume(params["dst_volume"]) as dst_volume: + await src_volume.clone_vfolder( + params["src_vfid"], + dst_volume, + params["dst_vfid"], + ) + return web.Response(status=204) + + +async def get_vfolder_mount(request: web.Request) -> web.Response: + async with check_params( + request, + t.Dict( + { + t.Key("volume"): t.String(), + t.Key("vfid"): tx.UUID(), + t.Key("subpath", default="."): t.String(), + }, + ), + ) as params: + await log_manager_api_entry(log, "get_container_mount", params) + ctx: Context = request.app["ctx"] + async with ctx.get_volume(params["volume"]) as volume: + try: + mount_path = await volume.get_vfolder_mount( + params["vfid"], + params["subpath"], + ) + except VFolderNotFoundError: + raise web.HTTPBadRequest( + body=json.dumps( + { + "msg": "VFolder not found", + "vfid": str(params["vfid"]), + }, + ), + content_type="application/json", + ) + except InvalidSubpathError as e: + raise web.HTTPBadRequest( + body=json.dumps( + { + "msg": "Invalid vfolder subpath", + "vfid": str(params["vfid"]), + "subpath": str(e.args[1]), + }, + ), + content_type="application/json", + ) + return web.json_response( + { + "path": str(mount_path), + }, + ) + + +async def get_performance_metric(request: web.Request) -> web.Response: + async with check_params( + request, + t.Dict( + { + t.Key("volume"): t.String(), + }, + ), + ) as params: + await log_manager_api_entry(log, "get_performance_metric", params) + ctx: Context = request.app["ctx"] + async with ctx.get_volume(params["volume"]) as volume: + metric = await volume.get_performance_metric() + return web.json_response( + { + "metric": attr.asdict(metric), + }, + ) + + +async def fetch_file(request: web.Request) -> web.StreamResponse: + """ + Direct file streaming API for internal use, such as retrieving + task logs from a user vfolder ".logs". 
+ """ + async with check_params( + request, + t.Dict( + { + t.Key("volume"): t.String(), + t.Key("vfid"): tx.UUID(), + t.Key("relpath"): tx.PurePath(relative_only=True), + }, + ), + ) as params: + await log_manager_api_entry(log, "fetch_file", params) + ctx: Context = request.app["ctx"] + response = web.StreamResponse(status=200) + response.headers[hdrs.CONTENT_TYPE] = "application/octet-stream" + try: + prepared = False + async with ctx.get_volume(params["volume"]) as volume: + with handle_fs_errors(volume, params["vfid"]): + async for chunk in volume.read_file( + params["vfid"], + params["relpath"], + ): + if not chunk: + return response + if not prepared: + await response.prepare(request) + prepared = True + await response.write(chunk) + except FileNotFoundError: + response = web.Response(status=404, reason="Log data not found") + finally: + if prepared: + await response.write_eof() + return response + + +async def get_metadata(request: web.Request) -> web.Response: + async with check_params( + request, + t.Dict( + { + t.Key("volume"): t.String(), + t.Key("vfid"): tx.UUID(), + }, + ), + ) as params: + await log_manager_api_entry(log, "get_metadata", params) + return web.json_response( + { + "status": "ok", + }, + ) + + +async def set_metadata(request: web.Request) -> web.Response: + async with check_params( + request, + t.Dict( + { + t.Key("volume"): t.String(), + t.Key("vfid"): tx.UUID(), + t.Key("payload"): t.Bytes(), + }, + ), + ) as params: + await log_manager_api_entry(log, "set_metadata", params) + return web.json_response( + { + "status": "ok", + }, + ) + + +async def get_vfolder_fs_usage(request: web.Request) -> web.Response: + async with check_params( + request, + t.Dict( + { + t.Key("volume"): t.String(), + }, + ), + ) as params: + await log_manager_api_entry(log, "get_vfolder_fs_usage", params) + ctx: Context = request.app["ctx"] + async with ctx.get_volume(params["volume"]) as volume: + fs_usage = await volume.get_fs_usage() + return web.json_response( + { + "capacity_bytes": fs_usage.capacity_bytes, + "used_bytes": fs_usage.used_bytes, + }, + ) + + +async def get_vfolder_usage(request: web.Request) -> web.Response: + async with check_params( + request, + t.Dict( + { + t.Key("volume"): t.String(), + t.Key("vfid"): tx.UUID(), + }, + ), + ) as params: + try: + await log_manager_api_entry(log, "get_vfolder_usage", params) + ctx: Context = request.app["ctx"] + async with ctx.get_volume(params["volume"]) as volume: + usage = await volume.get_usage(params["vfid"]) + return web.json_response( + { + "file_count": usage.file_count, + "used_bytes": usage.used_bytes, + }, + ) + except ExecutionError: + return web.Response( + status=500, + reason="Storage server is busy. 
Please try again", + ) + + +async def get_quota(request: web.Request) -> web.Response: + async with check_params( + request, + t.Dict( + { + t.Key("volume"): t.String(), + t.Key("vfid", default=None): t.Null | t.String, + }, + ), + ) as params: + await log_manager_api_entry(log, "get_quota", params) + ctx: Context = request.app["ctx"] + async with ctx.get_volume(params["volume"]) as volume: + quota = await volume.get_quota(params["vfid"]) + return web.json_response(quota) + + +async def set_quota(request: web.Request) -> web.Response: + async with check_params( + request, + t.Dict( + { + t.Key("volume"): t.String(), + t.Key("vfid", default=None): t.Null | t.String, + t.Key("size_bytes"): tx.BinarySize, + }, + ), + ) as params: + await log_manager_api_entry(log, "update_quota", params) + ctx: Context = request.app["ctx"] + async with ctx.get_volume(params["volume"]) as volume: + await volume.set_quota(params["vfid"], params["size_bytes"]) + return web.Response(status=204) + + +async def mkdir(request: web.Request) -> web.Response: + async with check_params( + request, + t.Dict( + { + t.Key("volume"): t.String(), + t.Key("vfid"): tx.UUID(), + t.Key("relpath"): tx.PurePath(relative_only=True), + t.Key("parents", default=True): t.ToBool, + t.Key("exist_ok", default=False): t.ToBool, + }, + ), + ) as params: + await log_manager_api_entry(log, "mkdir", params) + ctx: Context = request.app["ctx"] + async with ctx.get_volume(params["volume"]) as volume: + with handle_fs_errors(volume, params["vfid"]): + await volume.mkdir( + params["vfid"], + params["relpath"], + parents=params["parents"], + exist_ok=params["exist_ok"], + ) + return web.Response(status=204) + + +async def list_files(request: web.Request) -> web.Response: + async with check_params( + request, + t.Dict( + { + t.Key("volume"): t.String(), + t.Key("vfid"): tx.UUID(), + t.Key("relpath"): tx.PurePath(relative_only=True), + }, + ), + ) as params: + await log_manager_api_entry(log, "list_files", params) + ctx: Context = request.app["ctx"] + async with ctx.get_volume(params["volume"]) as volume: + with handle_fs_errors(volume, params["vfid"]): + items = [ + { + "name": item.name, + "type": item.type.name, + "stat": { + "mode": item.stat.mode, + "size": item.stat.size, + "created": item.stat.created.isoformat(), + "modified": item.stat.modified.isoformat(), + }, + "symlink_target": item.symlink_target, + } + async for item in volume.scandir( + params["vfid"], + params["relpath"], + ) + ] + return web.json_response( + { + "items": items, + }, + ) + + +async def rename_file(request: web.Request) -> web.Response: + async with check_params( + request, + t.Dict( + { + t.Key("volume"): t.String(), + t.Key("vfid"): tx.UUID(), + t.Key("relpath"): tx.PurePath(relative_only=True), + t.Key("new_name"): t.String(), + t.Key("is_dir"): t.ToBool(), # ignored since 22.03 + }, + ), + ) as params: + await log_manager_api_entry(log, "rename_file", params) + ctx: Context = request.app["ctx"] + async with ctx.get_volume(params["volume"]) as volume: + with handle_fs_errors(volume, params["vfid"]): + await volume.move_file( + params["vfid"], + params["relpath"], + params["relpath"].with_name(params["new_name"]), + ) + return web.Response(status=204) + + +async def move_file(request: web.Request) -> web.Response: + async with check_params( + request, + t.Dict( + { + t.Key("volume"): t.String(), + t.Key("vfid"): tx.UUID(), + t.Key("src_relpath"): tx.PurePath(relative_only=True), + t.Key("dst_relpath"): tx.PurePath(relative_only=True), + }, + ), + ) as params: + 
await log_manager_api_entry(log, "move_file", params) + ctx: Context = request.app["ctx"] + async with ctx.get_volume(params["volume"]) as volume: + with handle_fs_errors(volume, params["vfid"]): + await volume.move_file( + params["vfid"], + params["src_relpath"], + params["dst_relpath"], + ) + return web.Response(status=204) + + +async def create_download_session(request: web.Request) -> web.Response: + async with check_params( + request, + t.Dict( + { + t.Key("volume"): t.String(), + t.Key("vfid"): tx.UUID(), + t.Key("relpath"): tx.PurePath(relative_only=True), + t.Key("archive", default=False): t.ToBool, + t.Key("unmanaged_path", default=None): t.Null | t.String, + }, + ), + ) as params: + await log_manager_api_entry(log, "create_download_session", params) + ctx: Context = request.app["ctx"] + token_data = { + "op": "download", + "volume": params["volume"], + "vfid": str(params["vfid"]), + "relpath": str(params["relpath"]), + "exp": datetime.utcnow() + + ctx.local_config["storage-proxy"]["session-expire"], + } + token = jwt.encode( + token_data, + ctx.local_config["storage-proxy"]["secret"], + algorithm="HS256", + ) + return web.json_response( + { + "token": token, + }, + ) + + +async def create_upload_session(request: web.Request) -> web.Response: + async with check_params( + request, + t.Dict( + { + t.Key("volume"): t.String(), + t.Key("vfid"): tx.UUID(), + t.Key("relpath"): tx.PurePath(relative_only=True), + t.Key("size"): t.ToInt, + }, + ), + ) as params: + await log_manager_api_entry(log, "create_upload_session", params) + ctx: Context = request.app["ctx"] + async with ctx.get_volume(params["volume"]) as volume: + session_id = await volume.prepare_upload(params["vfid"]) + token_data = { + "op": "upload", + "volume": params["volume"], + "vfid": str(params["vfid"]), + "relpath": str(params["relpath"]), + "size": params["size"], + "session": session_id, + "exp": datetime.utcnow() + + ctx.local_config["storage-proxy"]["session-expire"], + } + token = jwt.encode( + token_data, + ctx.local_config["storage-proxy"]["secret"], + algorithm="HS256", + ) + return web.json_response( + { + "token": token, + }, + ) + + +async def delete_files(request: web.Request) -> web.Response: + async with check_params( + request, + t.Dict( + { + t.Key("volume"): t.String(), + t.Key("vfid"): tx.UUID(), + t.Key("relpaths"): t.List(tx.PurePath(relative_only=True)), + t.Key("recursive", default=False): t.ToBool, + }, + ), + ) as params: + await log_manager_api_entry(log, "delete_files", params) + ctx: Context = request.app["ctx"] + async with ctx.get_volume(params["volume"]) as volume: + with handle_fs_errors(volume, params["vfid"]): + await volume.delete_files( + params["vfid"], + params["relpaths"], + params["recursive"], + ) + return web.json_response( + { + "status": "ok", + }, + ) + + +async def init_manager_app(ctx: Context) -> web.Application: + app = web.Application( + middlewares=[ + token_auth_middleware, + ], + ) + app["ctx"] = ctx + app.router.add_route("GET", "/", get_status) + app.router.add_route("GET", "/volumes", get_volumes) + app.router.add_route("GET", "/volume/hwinfo", get_hwinfo) + app.router.add_route("POST", "/folder/create", create_vfolder) + app.router.add_route("POST", "/folder/delete", delete_vfolder) + app.router.add_route("POST", "/folder/clone", clone_vfolder) + app.router.add_route("GET", "/folder/mount", get_vfolder_mount) + app.router.add_route("GET", "/volume/performance-metric", get_performance_metric) + app.router.add_route("GET", "/folder/metadata", get_metadata) + 
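+    # All routes registered here require the X-BackendAI-Storage-Auth-Token header
+    # enforced by token_auth_middleware above. A minimal manager-side call might
+    # look like the following sketch (the host/port, volume name, and secret value
+    # are hypothetical; the secret must match api.manager.secret in the config):
+    #
+    #   async with aiohttp.ClientSession() as sess:
+    #       await sess.post(
+    #           "http://127.0.0.1:6022/folder/create",
+    #           headers={"X-BackendAI-Storage-Auth-Token": secret},
+    #           json={"volume": "vol1", "vfid": str(uuid.uuid4()), "options": None},
+    #       )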
app.router.add_route("POST", "/folder/metadata", set_metadata) + app.router.add_route("GET", "/volume/quota", get_quota) + app.router.add_route("PATCH", "/volume/quota", set_quota) + app.router.add_route("GET", "/folder/usage", get_vfolder_usage) + app.router.add_route("GET", "/folder/fs-usage", get_vfolder_fs_usage) + app.router.add_route("POST", "/folder/file/mkdir", mkdir) + app.router.add_route("POST", "/folder/file/list", list_files) + app.router.add_route("POST", "/folder/file/rename", rename_file) + app.router.add_route("POST", "/folder/file/move", move_file) + app.router.add_route("POST", "/folder/file/fetch", fetch_file) + app.router.add_route("POST", "/folder/file/download", create_download_session) + app.router.add_route("POST", "/folder/file/upload", create_upload_session) + app.router.add_route("POST", "/folder/file/delete", delete_files) + return app diff --git a/src/ai/backend/storage/config.py b/src/ai/backend/storage/config.py new file mode 100644 index 0000000000..1d432fcfd8 --- /dev/null +++ b/src/ai/backend/storage/config.py @@ -0,0 +1,84 @@ +import os +from pathlib import Path + +import trafaret as t + +from ai.backend.common import validators as tx +from ai.backend.common.config import etcd_config_iv +from ai.backend.common.logging import logging_config_iv + +from .types import VolumeInfo + +_max_cpu_count = os.cpu_count() +_file_perm = (Path(__file__).parent / "server.py").stat() + + +local_config_iv = ( + t.Dict( + { + t.Key("storage-proxy"): t.Dict( + { + t.Key("node-id"): t.String, + t.Key("num-proc", default=_max_cpu_count): t.Int[1:_max_cpu_count], + t.Key("pid-file", default=os.devnull): tx.Path( + type="file", + allow_nonexisting=True, + allow_devnull=True, + ), + t.Key("event-loop", default="asyncio"): t.Enum("asyncio", "uvloop"), + t.Key("scandir-limit", default=1000): t.Int[0:], + t.Key("max-upload-size", default="100g"): tx.BinarySize, + t.Key("secret"): t.String, # used to generate JWT tokens + t.Key("session-expire"): tx.TimeDuration, + t.Key("user", default=None): tx.UserID( + default_uid=_file_perm.st_uid, + ), + t.Key("group", default=None): tx.GroupID( + default_gid=_file_perm.st_gid, + ), + }, + ), + t.Key("logging"): logging_config_iv, + t.Key("api"): t.Dict( + { + t.Key("client"): t.Dict( + { + t.Key("service-addr"): tx.HostPortPair( + allow_blank_host=True, + ), + t.Key("ssl-enabled"): t.ToBool, + t.Key("ssl-cert", default=None): t.Null + | tx.Path(type="file"), + t.Key("ssl-privkey", default=None): t.Null + | tx.Path(type="file"), + }, + ), + t.Key("manager"): t.Dict( + { + t.Key("service-addr"): tx.HostPortPair( + allow_blank_host=True, + ), + t.Key("ssl-enabled"): t.ToBool, + t.Key("ssl-cert", default=None): t.Null + | tx.Path(type="file"), + t.Key("ssl-privkey", default=None): t.Null + | tx.Path(type="file"), + t.Key("secret"): t.String, # used to authenticate managers + }, + ), + }, + ), + t.Key("volume"): t.Mapping( + t.String, + VolumeInfo.as_trafaret(), # volume name -> details + ), + t.Key("debug"): t.Dict( + { + t.Key("enabled", default=False): t.ToBool, + }, + ).allow_extra("*"), + }, + ) + .merge(etcd_config_iv) + .allow_extra("*") +) diff --git a/src/ai/backend/storage/context.py b/src/ai/backend/storage/context.py new file mode 100644 index 0000000000..1f228af681 --- /dev/null +++ b/src/ai/backend/storage/context.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from contextlib import asynccontextmanager as actxmgr +from pathlib import Path, PurePosixPath +from typing import Any, AsyncIterator, Mapping, Type + +from 
ai.backend.common.etcd import AsyncEtcd + +from .abc import AbstractVolume +from .exception import InvalidVolumeError +from .netapp import NetAppVolume +from .purestorage import FlashBladeVolume +from .types import VolumeInfo +from .vfs import BaseVolume +from .xfs import XfsVolume + +BACKENDS: Mapping[str, Type[AbstractVolume]] = { + "purestorage": FlashBladeVolume, + "vfs": BaseVolume, + "xfs": XfsVolume, + "netapp": NetAppVolume, +} + + +class Context: + + __slots__ = ("pid", "etcd", "local_config") + + pid: int + etcd: AsyncEtcd + local_config: Mapping[str, Any] + + def __init__( + self, + pid: int, + local_config: Mapping[str, Any], + etcd: AsyncEtcd, + ) -> None: + self.pid = pid + self.etcd = etcd + self.local_config = local_config + + def list_volumes(self) -> Mapping[str, VolumeInfo]: + return { + name: VolumeInfo(**info) + for name, info in self.local_config["volume"].items() + } + + @actxmgr + async def get_volume(self, name: str) -> AsyncIterator[AbstractVolume]: + try: + volume_config = self.local_config["volume"][name] + except KeyError: + raise InvalidVolumeError(name) + volume_cls: Type[AbstractVolume] = BACKENDS[volume_config["backend"]] + volume_obj = volume_cls( + local_config=self.local_config, + mount_path=Path(volume_config["path"]), + fsprefix=PurePosixPath(volume_config["fsprefix"]), + options=volume_config["options"] or {}, + ) + await volume_obj.init() + try: + yield volume_obj + finally: + await volume_obj.shutdown() diff --git a/src/ai/backend/storage/exception.py b/src/ai/backend/storage/exception.py new file mode 100644 index 0000000000..0e04f1d5a3 --- /dev/null +++ b/src/ai/backend/storage/exception.py @@ -0,0 +1,50 @@ +import json +from typing import Any + +from aiohttp import web + + +class StorageProxyError(Exception): + pass + + +class ExecutionError(StorageProxyError): + pass + + +class VFolderCreationError(StorageProxyError): + pass + + +class VFolderNotFoundError(StorageProxyError): + pass + + +class InvalidSubpathError(StorageProxyError): + pass + + +class InvalidVolumeError(StorageProxyError): + pass + + +class InvalidAPIParameters(web.HTTPBadRequest): + def __init__( + self, + type_suffix: str = "invalid-api-params", + title: str = "Invalid API parameters", + msg: str = None, + data: Any = None, + ) -> None: + payload = { + "type": f"https://api.backend.ai/probs/storage/{type_suffix}", + "title": title, + } + if msg is not None: + payload["title"] = f"{title} ({msg})" + if data is not None: + payload["data"] = data + super().__init__( + text=json.dumps(payload), + content_type="application/problem+json", + ) diff --git a/src/ai/backend/storage/filelock.py b/src/ai/backend/storage/filelock.py new file mode 100644 index 0000000000..9f8d976a48 --- /dev/null +++ b/src/ai/backend/storage/filelock.py @@ -0,0 +1,50 @@ +import asyncio +import fcntl +import logging +import time +from pathlib import Path + +from ai.backend.common.logging import BraceStyleAdapter + +log = BraceStyleAdapter(logging.getLogger(__name__)) + + +class FileLock: + default_timeout: int = 3 # not allow infinite timeout for safety + locked: bool = False + + def __init__(self, path: Path, *, mode: str = "rb", timeout: int = None): + self._path = path + self._mode = mode + self._timeout = timeout if timeout is not None else self.default_timeout + + async def __aenter__(self): + def _lock(): + start_time = time.perf_counter() + self._fp = open(self._path, self._mode) + while True: + try: + fcntl.flock(self._fp, fcntl.LOCK_EX | fcntl.LOCK_NB) + self.locked = True + log.debug("file lock 
acquired: {}", self._path) + return self._fp + except BlockingIOError: + # Failed to get file lock. Waiting until timeout ... + if time.perf_counter() - start_time > self._timeout: + raise TimeoutError(f"failed to lock file: {self._path}") + time.sleep(0.1) + + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, _lock) + + async def __aexit__(self, *args): + def _unlock(): + if self.locked: + fcntl.flock(self._fp, fcntl.LOCK_UN) + self.locked = False + log.debug("file lock released: {}", self._path) + self._fp.close() + self.f_fp = None + + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, _unlock) diff --git a/src/ai/backend/storage/netapp/__init__.py b/src/ai/backend/storage/netapp/__init__.py new file mode 100644 index 0000000000..a87833d78b --- /dev/null +++ b/src/ai/backend/storage/netapp/__init__.py @@ -0,0 +1,394 @@ +from __future__ import annotations + +import asyncio +import glob +import json +import os +import time +from pathlib import Path, PurePosixPath +from typing import FrozenSet +from uuid import UUID + +import aiofiles + +from ai.backend.common.types import BinarySize, HardwareMetadata + +from ..abc import CAP_METRIC, CAP_VFHOST_QUOTA, CAP_VFOLDER, AbstractVolume +from ..exception import ExecutionError, StorageProxyError, VFolderCreationError +from ..types import FSPerfMetric, FSUsage, VFolderCreationOptions, VFolderUsage +from ..vfs import BaseVolume +from .netappclient import NetAppClient +from .quotamanager import QuotaManager + + +class NetAppVolume(BaseVolume): + + endpoint: str + netapp_admin: str + netapp_password: str + netapp_svm: str + netapp_volume_name: str + netapp_volume_uuid: str + netapp_qtree_name: str + netapp_qtree_id: str + + async def init(self) -> None: + + self.endpoint = self.config["netapp_endpoint"] + self.netapp_admin = self.config["netapp_admin"] + self.netapp_password = str(self.config["netapp_password"]) + self.netapp_svm = self.config["netapp_svm"] + self.netapp_volume_name = self.config["netapp_volume_name"] + self.netapp_xcp_hostname = self.config["netapp_xcp_hostname"] + self.netapp_xcp_catalog_path = self.config["netapp_xcp_catalog_path"] + self.netapp_xcp_container_name = self.config["netapp_xcp_container_name"] + + self.netapp_client = NetAppClient( + str(self.endpoint), + self.netapp_admin, + self.netapp_password, + str(self.netapp_svm), + self.netapp_volume_name, + ) + + self.quota_manager = QuotaManager( + endpoint=str(self.endpoint), + user=self.netapp_admin, + password=self.netapp_password, + svm=str(self.netapp_svm), + volume_name=self.netapp_volume_name, + ) + + # assign qtree info after netapp_client and quotamanager are initiated + self.netapp_volume_uuid = await self.netapp_client.get_volume_uuid_by_name() + default_qtree = await self.get_default_qtree_by_volume_id( + self.netapp_volume_uuid, + ) + self.netapp_qtree_name = default_qtree.get( + "name", + self.config["netapp_qtree_name"], + ) + self.netapp_qtree_id = await self.get_qtree_id_by_name(self.netapp_qtree_name) + + # adjust mount path (volume + qtree) + self.mount_path = (self.mount_path / Path(self.netapp_qtree_name)).resolve() + + async def get_capabilities(self) -> FrozenSet[str]: + return frozenset([CAP_VFOLDER, CAP_VFHOST_QUOTA, CAP_METRIC]) + + async def get_hwinfo(self) -> HardwareMetadata: + raw_metadata = await self.netapp_client.get_metadata() + qtree_info = await self.get_default_qtree_by_volume_id(self.netapp_volume_uuid) + self.netapp_qtree_name = qtree_info["name"] + quota = await 
self.quota_manager.get_quota_by_qtree_name(self.netapp_qtree_name) + # add quota in hwinfo + metadata = {"quota": json.dumps(quota), **raw_metadata} + return {"status": "healthy", "status_info": None, "metadata": {**metadata}} + + async def get_fs_usage(self) -> FSUsage: + volume_usage = await self.netapp_client.get_usage() + qtree_info = await self.get_default_qtree_by_volume_id(self.netapp_volume_uuid) + self.netapp_qtree_name = qtree_info["name"] + quota = await self.quota_manager.get_quota_by_qtree_name(self.netapp_qtree_name) + space = quota.get("space") + if space and space.get("hard_limit"): + capacity_bytes = space["hard_limit"] + else: + capacity_bytes = volume_usage["capacity_bytes"] + return FSUsage( + capacity_bytes=capacity_bytes, + used_bytes=volume_usage["used_bytes"], + ) + + async def get_performance_metric(self) -> FSPerfMetric: + uuid = await self.get_volume_uuid_by_name() + volume_info = await self.get_volume_info(uuid) + metric = volume_info["metric"] + return FSPerfMetric( + iops_read=metric["iops"]["read"], + iops_write=metric["iops"]["write"], + io_bytes_read=metric["throughput"]["read"], + io_bytes_write=metric["throughput"]["write"], + io_usec_read=metric["latency"]["read"], + io_usec_write=metric["latency"]["write"], + ) + + async def delete_vfolder(self, vfid: UUID) -> None: + vfpath = self.mangle_vfpath(vfid) + + # extract target_dir from vfpath + target_dir = str(vfpath).split(self.netapp_qtree_name + "/", 1)[1].split("/")[0] + nfs_path = ( + f"{self.netapp_xcp_hostname}:/{self.netapp_volume_name}/" + + f"{self.netapp_qtree_name}/{target_dir}" + ) + + async def watch_delete_dir(root_dir): + delete_cmd = ["xcp", "delete", "-force", nfs_path] + if self.netapp_xcp_container_name is not None: + delete_cmd = [ + "docker", + "exec", + self.netapp_xcp_container_name, + ] + delete_cmd + # remove vfolder by xcp command + proc = await asyncio.create_subprocess_exec( + *delete_cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.STDOUT, + ) + # readline and send + while True: + line = await proc.stdout.readline() + if not line: + break + yield line.rstrip() + + async def read_progress(root_dir): + async for line in watch_delete_dir(root_dir): + # TODO: line for bgtask + pass + # remove intermediate prefix directories if they become empty + from aiofiles import os as aiofile_os + + await aiofile_os.rmdir(vfpath.parent.parent) + + await read_progress(nfs_path) + + async def clone_vfolder( + self, + src_vfid: UUID, + dst_volume: AbstractVolume, + dst_vfid: UUID, + options: VFolderCreationOptions = None, + ) -> None: + # check if there is enough space in destination + fs_usage = await dst_volume.get_fs_usage() + vfolder_usage = await self.get_usage(src_vfid) + if vfolder_usage.used_bytes > fs_usage.capacity_bytes - fs_usage.used_bytes: + raise VFolderCreationError("Not enough space available for clone") + + # create the target vfolder + await dst_volume.create_vfolder(dst_vfid, options=options, exist_ok=True) + + # arrange directory based on nfs + src_vfpath = str(self.mangle_vfpath(src_vfid)).split( + self.netapp_qtree_name + "/", + 1, + )[1] + dst_vfpath = str(dst_volume.mangle_vfpath(dst_vfid)).split( + self.netapp_qtree_name + "/", + 1, + )[1] + + nfs_src_path = ( + f"{self.netapp_xcp_hostname}:/{self.netapp_volume_name}/" + + f"{self.netapp_qtree_name}/{src_vfpath}" + ) + nfs_dst_path = ( + f"{self.netapp_xcp_hostname}:/{dst_volume.config['netapp_volume_name']}/" + + f"{dst_volume.config['netapp_qtree_name']}/{dst_vfpath}" + ) + + # perform clone using xcp 
copy (exception handling needed) + try: + + async def watch_copy_dir(src_path, dst_path): + copy_cmd = ["xcp", "copy", src_path, dst_path] + if self.netapp_xcp_container_name is not None: + copy_cmd = [ + "docker", + "exec", + self.netapp_xcp_container_name, + ] + copy_cmd + proc = await asyncio.create_subprocess_exec( + *copy_cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.STDOUT, + ) + stdout, stderr = await proc.communicate() + # readline and send + while True: + line = await proc.stdout.readline() + if not line: + break + if b"xcp: ERROR:" in line: + raise Exception + yield line.rstrip() + + async def read_progress(src_path, dst_path): + async for line in watch_copy_dir(src_path, dst_path): + # TODO: line for bgtask + pass + + await read_progress(nfs_src_path, nfs_dst_path) + + except Exception: + await dst_volume.delete_vfolder(dst_vfid) + raise RuntimeError("Copying files from source directories failed.") + + async def shutdown(self) -> None: + await self.netapp_client.aclose() + await self.quota_manager.aclose() + + # ------ volume operations ------ + async def get_list_volumes(self): + resp = await self.netapp_client.get_list_volumes() + + if "error" in resp: + raise ExecutionError("api error") + return resp + + async def get_volume_uuid_by_name(self): + resp = await self.netapp_client.get_volume_uuid_by_name() + + if "error" in resp: + raise ExecutionError("api error") + return resp + + async def get_volume_info(self, volume_uuid): + resp = await self.netapp_client.get_volume_info(volume_uuid) + + if "error" in resp: + raise ExecutionError("api error") + return resp + + # ------ qtree and quotas operations ------ + async def get_default_qtree_by_volume_id(self, volume_uuid): + volume_uuid = volume_uuid if volume_uuid else self.netapp_volume_uuid + resp = await self.netapp_client.get_default_qtree_by_volume_id(volume_uuid) + if "error" in resp: + raise ExecutionError("api error") + return resp + + async def get_qtree_id_by_name(self, qtree_name): + qtree_name = ( + qtree_name if qtree_name else await self.get_default_qtree_by_volume_id() + ) + resp = await self.netapp_client.get_qtree_id_by_name(qtree_name) + + if "error" in resp: + raise ExecutionError("api error") + return resp + + async def get_quota(self, vfid: UUID) -> BinarySize: + raise NotImplementedError + + async def set_quota(self, vfid: UUID, size_bytes: BinarySize) -> None: + raise NotImplementedError + + async def get_usage( + self, + vfid: UUID, + relpath: PurePosixPath = PurePosixPath("."), + ) -> VFolderUsage: + target_path = self.sanitize_vfpath(vfid, relpath) + total_size = 0 + total_count = 0 + raw_target_path = str(target_path).split(self.netapp_qtree_name + "/", 1)[1] + nfs_path = ( + f"{self.netapp_xcp_hostname}:/{self.netapp_volume_name}/" + + f"{self.netapp_qtree_name}/{raw_target_path}" + ) + start_time = time.monotonic() + available = True + + prev_files_count = 0 + curr_files_count = 0 + + # check the number of scan result files changed + # NOTE: if directory contains small amout of files, scan result doesn't get saved + files = list(glob.iglob(f"{self.netapp_xcp_catalog_path}/stats/*.json")) + prev_files_count = len(files) + + scan_cmd = ["xcp", "scan", "-q", nfs_path] + if self.netapp_xcp_container_name is not None: + scan_cmd = ["docker", "exec", self.netapp_xcp_container_name] + scan_cmd + # Measure the exact file sizes and bytes + proc = await asyncio.create_subprocess_exec( + *scan_cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.STDOUT, + ) + try: + stdout, 
stderr = await proc.communicate()
+            if b"xcp: ERROR:" in stdout:
+                # The destination directory is busy with other operations.
+                if b"xcp: ERROR: mnt3 MOUNT" in stdout:
+                    raise StorageProxyError
+                available = False
+            elif await proc.wait() != 0:
+                available = False
+            # Get the latest saved result file; the scan command saves a JSON file
+            # only when the operation completes.
+            files = sorted(
+                glob.iglob(f"{self.netapp_xcp_catalog_path}/stats/*.json"),
+                key=os.path.getctime,
+                reverse=True,
+            )
+            curr_files_count = len(files)
+
+            # A new scan result file has been created.
+            if prev_files_count < curr_files_count and available:
+                file = files[0]
+                async with aiofiles.open(file, "r", encoding="utf8") as scan_result:
+                    contents = await scan_result.read()
+                    data = json.loads(contents)
+                    # includes size element
+                    count_keys = [
+                        "numberOfDirectories",
+                        "numberOfHardlinkedFiles",
+                        "numberOfHardlinks",
+                        "numberOfRegularFiles",
+                        "numberOfSpecialFiles",
+                        "numberOfSymbolicLinks",
+                        "numberOfUnreadableDirs",
+                        "numberOfUnreadableFiles",
+                    ]
+                    size_keys = [
+                        "spaceSavedByHardlinks",
+                        "spaceUsedDirectories",
+                        "spaceUsedRegularFiles",
+                        "spaceUsedSpecialFiles",
+                        "spaceUsedSymbolicLinks",
+                    ]
+                    total_count = sum([data[item] for item in count_keys])
+                    total_size = sum([data[item] for item in size_keys])
+            else:
+                # If there is no scan result file or the xcp command cannot be
+                # executed, fall back to the same traversal used by the vfs backend.
+                def _calc_usage(target_path: os.DirEntry | Path) -> None:
+                    nonlocal total_size, total_count
+                    _timeout = 3
+                    # FIXME: Remove "type: ignore" when python/mypy#11964 is resolved.
+                    with os.scandir(target_path) as scanner:  # type: ignore
+                        for entry in scanner:
+                            if entry.is_dir():
+                                _calc_usage(entry)
+                                continue
+                            if entry.is_file() or entry.is_symlink():
+                                stat = entry.stat(follow_symlinks=False)
+                                total_size += stat.st_size
+                                total_count += 1
+                                if total_count % 1000 == 0:
+                                    # Cancel if this I/O operation takes too much time.
+                                    if time.monotonic() - start_time > _timeout:
+                                        raise TimeoutError
+
+                loop = asyncio.get_running_loop()
+                await loop.run_in_executor(None, _calc_usage, target_path)
+        except StorageProxyError:
+            raise ExecutionError("Storage server is busy. Please try again")
+        except FileNotFoundError:
+            available = False
+        except IndexError:
+            available = False
+        except TimeoutError:
+            # -1 indicates "too many"
+            total_size = -1
+            total_count = -1
+        if not available:
+            raise ExecutionError(
+                "Cannot access the scan result file. 
Please check xcp is activated.", + ) + + return VFolderUsage(file_count=total_count, used_bytes=total_size) diff --git a/src/ai/backend/storage/netapp/netappclient.py b/src/ai/backend/storage/netapp/netappclient.py new file mode 100644 index 0000000000..82bb7d7856 --- /dev/null +++ b/src/ai/backend/storage/netapp/netappclient.py @@ -0,0 +1,222 @@ +from __future__ import annotations + +import json +from typing import Any, List, Mapping + +import aiohttp + + +class NetAppClient: + + endpoint: str + user: str + password: str + _session: aiohttp.ClientSession + svm: str + volume_name: str + + def __init__( + self, + endpoint: str, + user: str, + password: str, + svm: str, + volume_name: str, + ) -> None: + self.endpoint = endpoint + self.user = user + self.password = password + self.svm = svm + self.volume_name = volume_name + self._session = aiohttp.ClientSession() + + async def aclose(self) -> None: + await self._session.close() + + async def get_metadata(self) -> Mapping[str, Any]: + volume_uuid = await self.get_volume_uuid_by_name() + data = await self.get_volume_info(volume_uuid) + qos = await self.get_qos_by_volume_id(volume_uuid) + qos_policies = await self.get_qos_policies() + qtree_metadata = await self.get_default_qtree_by_volume_id(volume_uuid) + qtree = await self.get_qtree_info(qtree_metadata.get("id")) + + # mapping certain data for better explanation + volume_qtree_cluster = { + # ------ use volume info ------ + "id": data["uuid"], + "local_tier": data["aggregates"][0]["name"], + "create_time": data["create_time"], + "snapshot_policy": data["snapshot_policy"]["name"], + "snapmirroring": str(data["snapmirror"]["is_protected"]), + "state": data["state"], + "style": data["style"], + "svm_name": data["svm"]["name"], + "svm_id": data["svm"]["uuid"], + # ------ use qtree info ------ + "name": qtree["name"], + "path": qtree["path"], + "security_style": qtree["security_style"], + "export_policy": qtree["export_policy"]["name"], + "timestamp": qtree["statistics"].get("timestamp"), # last check time + } + # optional values to add in volume_qtree_cluster + if qos: + volume_qtree_cluster.update({"qos": json.dumps(qos["policy"])}) + if qos_policies: + volume_qtree_cluster.update({"qos_policies": json.dumps(qos_policies)}) + return volume_qtree_cluster + + async def get_usage(self) -> Mapping[str, Any]: + # volume specific usage check + uuid = await self.get_volume_uuid_by_name() + data = await self.get_volume_info(uuid) + return { + "capacity_bytes": data["space"]["available"], + "used_bytes": data["space"]["used"], + } + + async def get_list_volumes(self) -> Mapping[str, Any]: + async with self._session.get( + f"{self.endpoint}/api/storage/volumes", + auth=aiohttp.BasicAuth(self.user, self.password), + ssl=False, + raise_for_status=True, + ) as resp: + data = await resp.json() + return data["records"] + + async def get_volume_name_by_uuid(self, volume_uuid) -> str: + async with self._session.get( + f"{self.endpoint}/api/storage/volumes?uuid={volume_uuid}", + auth=aiohttp.BasicAuth(self.user, self.password), + ssl=False, + raise_for_status=True, + ) as resp: + data = await resp.json() + name = data["records"][0]["name"] + return name + + async def get_volume_uuid_by_name(self) -> str: + async with self._session.get( + f"{self.endpoint}/api/storage/volumes?name={self.volume_name}", + auth=aiohttp.BasicAuth(self.user, self.password), + ssl=False, + raise_for_status=True, + ) as resp: + data = await resp.json() + uuid = data["records"][0]["uuid"] + return uuid + + async def 
get_volume_info(self, volume_uuid) -> Mapping[str, Any]: + async with self._session.get( + f"{self.endpoint}/api/storage/volumes/{volume_uuid}", + auth=aiohttp.BasicAuth(self.user, self.password), + ssl=False, + raise_for_status=True, + ) as resp: + data = await resp.json() + return data + + async def get_default_qtree_by_volume_id(self, volume_uuid) -> Mapping[str, Any]: + qtrees = await self.list_qtrees_by_volume_id(volume_uuid) + for qtree in qtrees: + # skip the default qtree made by NetApp ONTAP internally + # It will not be used in Backend.AI NetApp ONTAP Plugin + if not qtree["name"]: + continue + else: + return qtree + return {} + + async def get_qtree_name_by_id(self, qtree_id) -> str: + async with self._session.get( + f"{self.endpoint}/api/storage/qtrees?id={qtree_id}", + auth=aiohttp.BasicAuth(self.user, self.password), + ssl=False, + raise_for_status=True, + ) as resp: + data = await resp.json() + return data["name"] + + async def get_qtree_id_by_name(self, qtree_name) -> str: + async with self._session.get( + f"{self.endpoint}/api/storage/qtrees?name={qtree_name}", + auth=aiohttp.BasicAuth(self.user, self.password), + ssl=False, + raise_for_status=True, + ) as resp: + data = await resp.json() + id = str(data["records"][0]["id"]) if data["num_records"] > 0 else "" + return id + + async def list_qtrees_by_volume_id(self, volume_uuid) -> List[Mapping[str, Any]]: + if not volume_uuid: + volume_uuid = await self.get_volume_uuid_by_name() + async with self._session.get( + f"{self.endpoint}/api/storage/qtrees/{volume_uuid}", + auth=aiohttp.BasicAuth(self.user, self.password), + ssl=False, + raise_for_status=True, + ) as resp: + data = await resp.json() + return data["records"] + + async def get_qtree_info(self, qtree_id) -> Mapping[str, Any]: + uuid = await self.get_volume_uuid_by_name() + async with self._session.get( + f"{self.endpoint}/api/storage/qtrees/{uuid}/{qtree_id}", + auth=aiohttp.BasicAuth(self.user, self.password), + ssl=False, + raise_for_status=True, + ) as resp: + data = await resp.json() + return data + + async def get_qos_policies(self) -> List[Mapping[str, Any]]: + async with self._session.get( + f"{self.endpoint}/api/storage/qos/policies", + auth=aiohttp.BasicAuth(self.user, self.password), + ssl=False, + raise_for_status=True, + ) as resp: + data = await resp.json() + qos_policies_metadata = data["records"] + qos_policies = [] + for qos in qos_policies_metadata: + policy = await self.get_qos_by_uuid(qos["uuid"]) + qos_policies.append(policy) + return qos_policies + + async def get_qos_by_uuid(self, qos_uuid) -> Mapping[str, Any]: + async with self._session.get( + f"{self.endpoint}/api/storage/qos/policies/{qos_uuid}", + auth=aiohttp.BasicAuth(self.user, self.password), + ssl=False, + raise_for_status=True, + ) as resp: + data = await resp.json() + fixed = data["fixed"] + qos_policy = { + "uuid": data["uuid"], + "name": data["name"], + "fixed": { + "max_throughput_iops": fixed.get("max_throughput_iops", 0), + "max_throughput_mbps": fixed.get("max_throughput_mbps", 0), + "min_throughput_iops": fixed.get("min_throughput_iops", 0), + "min_throughput_mbps": fixed.get("min_throughput_mbps", 0), + "capacity_shared": fixed["capacity_shared"], + }, + "svm": data["svm"], + } + return qos_policy + + async def get_qos_by_volume_id(self, volume_uuid) -> Mapping[str, Any]: + async with self._session.get( + f"{self.endpoint}/api/storage/volumes/{volume_uuid}?fields=qos", + auth=aiohttp.BasicAuth(self.user, self.password), + ssl=False, + raise_for_status=True, + ) as resp: + 
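+            # The volume record is requested with "?fields=qos" (see the URL above),
+            # so the QoS policy reference comes back embedded in the same response
+            # and can be returned from a single REST round trip.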
data = await resp.json() + return data["qos"] diff --git a/src/ai/backend/storage/netapp/quotamanager.py b/src/ai/backend/storage/netapp/quotamanager.py new file mode 100644 index 0000000000..16a32f5a05 --- /dev/null +++ b/src/ai/backend/storage/netapp/quotamanager.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +from typing import Any, Mapping + +import aiohttp +from aiohttp.client_reqrep import ClientResponse + + +class QuotaManager: + + endpoint: str + user: str + password: str + _session: aiohttp.ClientSession + svm: str + volume_name: str + + def __init__( + self, + endpoint: str, + user: str, + password: str, + svm: str, + volume_name: str, + ) -> None: + self.endpoint = endpoint + self.user = user + self.password = password + self._session = aiohttp.ClientSession() + self.svm = svm + self.volume_name = volume_name + + async def aclose(self) -> None: + await self._session.close() + + async def list_quotarules(self): + async with self._session.get( + f"{self.endpoint}/api/storage/quota/rules", + auth=aiohttp.BasicAuth(self.user, self.password), + ssl=False, + raise_for_status=False, + ) as resp: + data = await resp.json() + await self._session.close() + + rules = [rule for rule in data["uuid"]] + self.rules = rules + return rules + + async def list_all_qtrees_with_quotas(self) -> Mapping[str, Any]: + rules = await self.list_quotarules() + qtrees = {} + + for rule in rules: + async with self._session.get( + f"{self.endpoint}/api/storage/quota/rules/{rule}", + auth=aiohttp.BasicAuth(self.user, self.password), + ssl=False, + raise_for_status=False, + ) as resp: + data = await resp.json() + qtree_uuid = data["uuid"] + qtree_name = data["qtree"]["name"] + qtrees[qtree_uuid] = qtree_name + self.qtrees = qtrees + return qtrees + + async def get_quota_by_rule(self, rule_uuid) -> Mapping[str, Any]: + async with self._session.get( + f"{self.endpoint}/api/storage/quota/rules/{rule_uuid}", + auth=aiohttp.BasicAuth(self.user, self.password), + ssl=False, + raise_for_status=False, + ) as resp: + data = await resp.json() + quota = {} + if data.get("space"): + quota["space"] = data["space"] + if data.get("files"): + quota["files"] = data["files"] + return quota + + async def get_quota_by_qtree_name(self, qtree_name) -> Mapping[str, Any]: + async with self._session.get( + f"{self.endpoint}/api/storage/quota/rules?volume={self.volume_name}&qtree={qtree_name}", + auth=aiohttp.BasicAuth(self.user, self.password), + ssl=False, + raise_for_status=False, + ) as resp: + data = await resp.json() + rule_uuid = data["records"][0]["uuid"] + quota = await self.get_quota_by_rule(rule_uuid) + return quota + + # For now, Only Read / Update operation for qtree is available + # in NetApp ONTAP Plugin of Backend.AI + async def create_quotarule_qtree( + self, + qtree_name: str, + spahali: int, + spasoli: int, + fihali: int, + fisoli: int, + ) -> Mapping[str, Any]: + dataobj = { + "svm": {"name": self.svm}, + "volume": {"name": self.volume_name}, + "type": "tree", + "space": {"hard_limit": spahali, "soft_limit": spasoli}, + "files": {"hard_limit": fihali, "soft_limit": fisoli}, + "qtree": {"name": qtree_name}, + } + + headers = {"content-type": "application/json", "accept": "application/hal+json"} + + async with self._session.post( + f"{self.endpoint}/api/storage/quota/rules", + auth=aiohttp.BasicAuth(self.user, self.password), + headers=headers, + json=dataobj, + ssl=False, + raise_for_status=True, + ) as resp: + + msg = await resp.json() + return msg + + async def update_quotarule_qtree( + self, + spahali: 
int, + spasoli: int, + fihali: int, + fisoli: int, + rule_uuid, + ) -> ClientResponse: + dataobj = { + "space": {"hard_limit": spahali, "soft_limit": spasoli}, + "files": {"hard_limit": fihali, "soft_limit": fisoli}, + } + + headers = {"content-type": "application/json", "accept": "application/hal+json"} + + async with self._session.patch( + f"{self.endpoint}/api/storage/quota/rules/{rule_uuid}", + auth=aiohttp.BasicAuth(self.user, self.password), + headers=headers, + json=dataobj, + ssl=False, + raise_for_status=True, + ) as resp: + return await resp.json() + + # For now, Only Read / Update operation for qtree is available + # in NetApp ONTAP Plugin of Backend.AI + async def delete_quotarule_qtree(self, rule_uuid) -> ClientResponse: + headers = {"content-type": "application/json", "accept": "application/hal+json"} + + async with self._session.delete( + f"{self.endpoint}/api/storage/quota/rules/{rule_uuid}", + auth=aiohttp.BasicAuth(self.user, self.password), + headers=headers, + ssl=False, + raise_for_status=True, + ) as resp: + return await resp.json() diff --git a/src/ai/backend/storage/purestorage/__init__.py b/src/ai/backend/storage/purestorage/__init__.py new file mode 100644 index 0000000000..8410447c46 --- /dev/null +++ b/src/ai/backend/storage/purestorage/__init__.py @@ -0,0 +1,267 @@ +from __future__ import annotations + +import asyncio +import json +from pathlib import Path, PurePosixPath +from typing import AsyncIterator, FrozenSet, Sequence +from uuid import UUID + +from aiotools import aclosing + +from ai.backend.common.types import BinarySize, HardwareMetadata + +from ..abc import CAP_FAST_SCAN, CAP_METRIC, CAP_VFOLDER +from ..types import ( + DirEntry, + DirEntryType, + FSPerfMetric, + FSUsage, + Stat, + VFolderUsage, +) +from ..utils import fstime2datetime +from ..vfs import BaseVolume +from .purity import PurityClient + + +class FlashBladeVolume(BaseVolume): + async def init(self) -> None: + available = True + try: + proc = await asyncio.create_subprocess_exec( + b"pdu", + b"--version", + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.STDOUT, + ) + except FileNotFoundError: + available = False + else: + try: + stdout, stderr = await proc.communicate() + if b"RapidFile Toolkit" not in stdout or proc.returncode != 0: + available = False + finally: + await proc.wait() + if not available: + raise RuntimeError( + "PureStorage RapidFile Toolkit is not installed. 
" + "You cannot use the PureStorage backend for the stroage proxy.", + ) + self.purity_client = PurityClient( + self.config["purity_endpoint"], + self.config["purity_api_token"], + api_version=self.config["purity_api_version"], + ) + + async def shutdown(self) -> None: + await self.purity_client.aclose() + + async def get_capabilities(self) -> FrozenSet[str]: + return frozenset( + [ + CAP_VFOLDER, + CAP_METRIC, + CAP_FAST_SCAN, + ], + ) + + async def get_hwinfo(self) -> HardwareMetadata: + async with self.purity_client as client: + metadata = await client.get_metadata() + return { + "status": "healthy", + "status_info": None, + "metadata": { + **metadata, + }, + } + + async def get_fs_usage(self) -> FSUsage: + async with self.purity_client as client: + usage = await client.get_usage(self.config["purity_fs_name"]) + return FSUsage( + capacity_bytes=usage["capacity_bytes"], + used_bytes=usage["used_bytes"], + ) + + async def copy_tree( + self, + src_vfpath: Path, + dst_vfpath: Path, + ) -> None: + proc = await asyncio.create_subprocess_exec( + b"pcp", + b"-r", + b"-p", + bytes(src_vfpath / "."), + bytes(dst_vfpath), + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, stderr = await proc.communicate() + if proc.returncode != 0: + raise RuntimeError(f'"pcp" command failed: {stderr.decode()}') + + async def get_quota(self, vfid: UUID) -> BinarySize: + raise NotImplementedError + + async def set_quota(self, vfid: UUID, size_bytes: BinarySize) -> None: + raise NotImplementedError + + async def get_performance_metric(self) -> FSPerfMetric: + async with self.purity_client as client: + async with aclosing( + client.get_nfs_metric(self.config["purity_fs_name"]), + ) as items: + async for item in items: + return FSPerfMetric( + iops_read=item["reads_per_sec"], + iops_write=item["writes_per_sec"], + io_bytes_read=item["read_bytes_per_sec"], + io_bytes_write=item["write_bytes_per_sec"], + io_usec_read=item["usec_per_read_op"], + io_usec_write=item["usec_per_write_op"], + ) + else: + raise RuntimeError( + "no metric found for the configured flashblade filesystem", + ) + + async def get_usage( + self, + vfid: UUID, + relpath: PurePosixPath = PurePosixPath("."), + ) -> VFolderUsage: + target_path = self.sanitize_vfpath(vfid, relpath) + total_size = 0 + total_count = 0 + raw_target_path = bytes(target_path) + # Measure the exact file sizes and bytes + proc = await asyncio.create_subprocess_exec( + b"pdu", + b"-0", + b"-b", + b"-a", + b"-s", + raw_target_path, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.STDOUT, + ) + assert proc.stdout is not None + try: + # TODO: check slowdowns when there are millions of files + while True: + try: + line = await proc.stdout.readuntil(b"\0") + line = line.rstrip(b"\0") + except asyncio.IncompleteReadError: + break + size, name = line.split(maxsplit=1) + if len(name) != len(raw_target_path) and name != raw_target_path: + total_size += int(size) + total_count += 1 + finally: + await proc.wait() + return VFolderUsage(file_count=total_count, used_bytes=total_size) + + async def get_used_bytes(self, vfid: UUID) -> BinarySize: + vfpath = self.mangle_vfpath(vfid) + proc = await asyncio.create_subprocess_exec( + b"pdu", + b"-hs", + bytes(vfpath), + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, stderr = await proc.communicate() + if proc.returncode != 0: + raise RuntimeError(f"pdu command failed: {stderr.decode()}") + used_bytes, _ = stdout.decode().split() + return 
BinarySize.finite_from_str(used_bytes) + + # ------ vfolder internal operations ------- + + def scandir(self, vfid: UUID, relpath: PurePosixPath) -> AsyncIterator[DirEntry]: + target_path = self.sanitize_vfpath(vfid, relpath) + raw_target_path = bytes(target_path) + + async def _aiter() -> AsyncIterator[DirEntry]: + proc = await asyncio.create_subprocess_exec( + b"pls", + b"--json", + raw_target_path, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.STDOUT, + ) + assert proc.stdout is not None + try: + while True: + line = await proc.stdout.readline() + if not line: + break + line = line.rstrip(b"\n") + item = json.loads(line) + item_path = Path(item["path"]) + entry_type = DirEntryType.FILE + if item["filetype"] == 40000: + entry_type = DirEntryType.DIRECTORY + if item["filetype"] == 120000: + entry_type = DirEntryType.SYMLINK + yield DirEntry( + name=item_path.name, + path=item_path, + type=entry_type, + stat=Stat( + size=item["size"], + owner=str(item["uid"]), + # The integer represents the octal number in decimal + # (e.g., 644 which actually means 0o644) + mode=int(str(item["mode"]), 8), + modified=fstime2datetime(item["mtime"]), + created=fstime2datetime(item["ctime"]), + ), + symlink_target="", # TODO: should be tested on PureStorage + ) + finally: + await proc.wait() + + return _aiter() + + async def copy_file( + self, + vfid: UUID, + src: PurePosixPath, + dst: PurePosixPath, + ) -> None: + src_path = self.sanitize_vfpath(vfid, src) + dst_path = self.sanitize_vfpath(vfid, dst) + proc = await asyncio.create_subprocess_exec( + b"pcp", + b"-p", + bytes(src_path), + bytes(dst_path), + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, stderr = await proc.communicate() + if proc.returncode != 0: + raise RuntimeError(f'"pcp" command failed: {stderr.decode()}') + + async def delete_files( + self, + vfid: UUID, + relpaths: Sequence[PurePosixPath], + recursive: bool = False, + ) -> None: + target_paths = [bytes(self.sanitize_vfpath(vfid, p)) for p in relpaths] + proc = await asyncio.create_subprocess_exec( + b"prm", + b"-r", + *target_paths, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + await proc.communicate() + if proc.returncode != 0: + raise RuntimeError("'prm' command returned a non-zero exit code.") diff --git a/src/ai/backend/storage/purestorage/purity.py b/src/ai/backend/storage/purestorage/purity.py new file mode 100644 index 0000000000..2836f8127e --- /dev/null +++ b/src/ai/backend/storage/purestorage/purity.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +from contextvars import ContextVar, Token +from typing import Any, AsyncGenerator, Mapping + +import aiohttp +from yarl import URL + + +class PurityClient: + + endpoint: URL + api_token: str + api_version: str + auth_token: ContextVar[str] + + _session: aiohttp.ClientSession + _auth_token_cvtoken: Token + + def __init__( + self, + endpoint: str, + api_token: str, + *, + api_version: str = "1.8", + ) -> None: + self.endpoint = URL(endpoint) + self.api_token = api_token + self.api_version = api_version + self.auth_token = ContextVar("auth_token") + self._session = aiohttp.ClientSession() + + async def aclose(self) -> None: + await self._session.close() + + async def __aenter__(self) -> PurityClient: + async with self._session.post( + self.endpoint / "api" / "login", + headers={"api-token": self.api_token}, + ssl=False, + raise_for_status=True, + ) as resp: + auth_token = resp.headers["x-auth-token"] + self._auth_token_cvtoken = 
self.auth_token.set(auth_token) + _ = await resp.json() + return self + + async def __aexit__(self, *exc_info) -> None: + self.auth_token.reset(self._auth_token_cvtoken) + + # For the concrete API reference, check out: + # https://purity-fb.readthedocs.io/en/latest/ + + async def get_metadata(self) -> Mapping[str, Any]: + if self.auth_token is None: + raise RuntimeError("The auth token for Purity API is not initialized.") + items = [] + pagination_token = "" + while True: + async with self._session.get( + (self.endpoint / "api" / self.api_version / "arrays"), + headers={"x-auth-token": self.auth_token.get()}, + params={ + "items_returned": 10, + "token": pagination_token, + }, + ssl=False, + raise_for_status=True, + ) as resp: + data = await resp.json() + for item in data["items"]: + items.append(item) + pagination_token = data["pagination_info"]["continuation_token"] + if pagination_token is None: + break + if not items: + return {} + first = items[0] + return { + "id": first["id"], + "name": first["name"], + "os": first["os"], + "revision": first["revision"], + "version": first["version"], + "blade_count": str(len(items)), + "console_url": str(self.endpoint), + } + + async def get_nfs_metric( + self, + fs_name: str, + ) -> AsyncGenerator[Mapping[str, Any], None]: + if self.auth_token is None: + raise RuntimeError("The auth token for Purity API is not initialized.") + pagination_token = "" + while True: + async with self._session.get( + ( + self.endpoint + / "api" + / self.api_version + / "file-systems" + / "performance" + ), + headers={"x-auth-token": self.auth_token.get()}, + params={ + "names": fs_name, + "protocol": "NFS", + "items_returned": 10, + "token": pagination_token, + }, + ssl=False, + raise_for_status=True, + ) as resp: + data = await resp.json() + for item in data["items"]: + yield item + pagination_token = data["pagination_info"]["continuation_token"] + if pagination_token is None: + break + + async def get_usage(self, fs_name: str) -> Mapping[str, Any]: + if self.auth_token is None: + raise RuntimeError("The auth token for Purity API is not initialized.") + items = [] + pagination_token = "" + while True: + async with self._session.get( + (self.endpoint / "api" / self.api_version / "file-systems"), + headers={"x-auth-token": self.auth_token.get()}, + params={ + "names": fs_name, + "items_returned": 10, + "token": pagination_token, + }, + ssl=False, + raise_for_status=True, + ) as resp: + data = await resp.json() + for item in data["items"]: + items.append(item) + pagination_token = data["pagination_info"]["continuation_token"] + if pagination_token is None: + break + if not items: + return {} + first = items[0] + return { + "capacity_bytes": data["total"]["provisioned"], + "used_bytes": first["space"]["total_physical"], + } diff --git a/src/ai/backend/storage/py.typed b/src/ai/backend/storage/py.typed new file mode 100644 index 0000000000..5abed26af8 --- /dev/null +++ b/src/ai/backend/storage/py.typed @@ -0,0 +1 @@ +marker diff --git a/src/ai/backend/storage/server.py b/src/ai/backend/storage/server.py new file mode 100644 index 0000000000..e91a9189e9 --- /dev/null +++ b/src/ai/backend/storage/server.py @@ -0,0 +1,222 @@ +import asyncio +import grp +import logging +import multiprocessing +import os +import pwd +import ssl +import sys +from pathlib import Path +from pprint import pformat, pprint +from typing import Any, AsyncIterator, Sequence + +import aiotools +import click +from aiohttp import web +from setproctitle import setproctitle + +from ai.backend.common import 
config +from ai.backend.common.etcd import AsyncEtcd, ConfigScopes +from ai.backend.common.logging import BraceStyleAdapter, Logger +from ai.backend.common.utils import env_info + +from . import __version__ as VERSION +from .api.client import init_client_app +from .api.manager import init_manager_app +from .config import local_config_iv +from .context import Context + +log = BraceStyleAdapter(logging.getLogger("ai.backend.storage.server")) + + +@aiotools.server +async def server_main_logwrapper(loop, pidx, _args): + setproctitle(f"backend.ai: storage-proxy worker-{pidx}") + try: + asyncio.get_child_watcher() + except (AttributeError, NotImplementedError): + pass + log_endpoint = _args[1] + logger = Logger(_args[0]["logging"], is_master=False, log_endpoint=log_endpoint) + with logger: + async with server_main(loop, pidx, _args): + yield + + +@aiotools.server +async def server_main( + loop: asyncio.AbstractEventLoop, + pidx: int, + _args: Sequence[Any], +) -> AsyncIterator[None]: + local_config = _args[0] + + etcd_credentials = None + if local_config["etcd"]["user"]: + etcd_credentials = { + "user": local_config["etcd"]["user"], + "password": local_config["etcd"]["password"], + } + scope_prefix_map = { + ConfigScopes.GLOBAL: "", + ConfigScopes.NODE: f"nodes/storage/{local_config['storage-proxy']['node-id']}", + } + etcd = AsyncEtcd( + local_config["etcd"]["addr"], + local_config["etcd"]["namespace"], + scope_prefix_map, + credentials=etcd_credentials, + ) + ctx = Context(pid=os.getpid(), local_config=local_config, etcd=etcd) + client_api_app = await init_client_app(ctx) + manager_api_app = await init_manager_app(ctx) + + client_ssl_ctx = None + manager_ssl_ctx = None + if local_config["api"]["client"]["ssl-enabled"]: + client_ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + client_ssl_ctx.load_cert_chain( + str(local_config["api"]["client"]["ssl-cert"]), + str(local_config["api"]["client"]["ssl-privkey"]), + ) + if local_config["api"]["manager"]["ssl-enabled"]: + manager_ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + manager_ssl_ctx.load_cert_chain( + str(local_config["api"]["manager"]["ssl-cert"]), + str(local_config["api"]["manager"]["ssl-privkey"]), + ) + client_api_runner = web.AppRunner(client_api_app) + manager_api_runner = web.AppRunner(manager_api_app) + await client_api_runner.setup() + await manager_api_runner.setup() + client_service_addr = local_config["api"]["client"]["service-addr"] + manager_service_addr = local_config["api"]["manager"]["service-addr"] + client_api_site = web.TCPSite( + client_api_runner, + str(client_service_addr.host), + client_service_addr.port, + backlog=1024, + reuse_port=True, + ssl_context=client_ssl_ctx, + ) + manager_api_site = web.TCPSite( + manager_api_runner, + str(manager_service_addr.host), + manager_service_addr.port, + backlog=1024, + reuse_port=True, + ssl_context=manager_ssl_ctx, + ) + await client_api_site.start() + await manager_api_site.start() + if os.geteuid() == 0: + uid = local_config["storage-proxy"]["user"] + gid = local_config["storage-proxy"]["group"] + os.setgroups( + [g.gr_gid for g in grp.getgrall() if pwd.getpwuid(uid).pw_name in g.gr_mem], + ) + os.setgid(gid) + os.setuid(uid) + log.info("Changed process uid:gid to {}:{}", uid, gid) + log.info("Started service.") + try: + yield + finally: + log.info("Shutting down...") + await manager_api_runner.cleanup() + await client_api_runner.cleanup() + + +@click.group(invoke_without_command=True) +@click.option( + "-f", + "--config-path", + "--config", + 
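+    # "--config" is kept as an alias of "--config-path"; when neither is given,
+    # config.read_from_file() below falls back to the default locations listed in
+    # the help text.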
type=Path, + default=None, + help="The config file path. " + "(default: ./storage-proxy.toml and /etc/backend.ai/storage-proxy.toml)", +) +@click.option( + "--debug", + is_flag=True, + help="Enable the debug mode and override the global log level to DEBUG.", +) +@click.pass_context +def main(cli_ctx, config_path, debug): + # Determine where to read configuration. + raw_cfg, cfg_src_path = config.read_from_file(config_path, "storage-proxy") + + config.override_with_env(raw_cfg, ("etcd", "namespace"), "BACKEND_NAMESPACE") + config.override_with_env(raw_cfg, ("etcd", "addr"), "BACKEND_ETCD_ADDR") + config.override_with_env(raw_cfg, ("etcd", "user"), "BACKEND_ETCD_USER") + config.override_with_env(raw_cfg, ("etcd", "password"), "BACKEND_ETCD_PASSWORD") + if debug: + config.override_key(raw_cfg, ("debug", "enabled"), True) + + try: + local_config = config.check(raw_cfg, local_config_iv) + local_config["_src"] = cfg_src_path + except config.ConfigurationError as e: + print( + "ConfigurationError: Validation of storage-proxy configuration has failed:", + file=sys.stderr, + ) + print(pformat(e.invalid_data), file=sys.stderr) + raise click.Abort() + + if local_config["debug"]["enabled"]: + config.override_key(local_config, ("logging", "level"), "DEBUG") + config.override_key(local_config, ("logging", "pkg-ns", "ai.backend"), "DEBUG") + + # if os.getuid() != 0: + # print('Storage agent can only be run as root', file=sys.stderr) + # raise click.Abort() + + multiprocessing.set_start_method("spawn") + + if cli_ctx.invoked_subcommand is None: + local_config["storage-proxy"]["pid-file"].write_text(str(os.getpid())) + log_sockpath = Path( + f"/tmp/backend.ai/ipc/storage-proxy-logger-{os.getpid()}.sock", + ) + log_sockpath.parent.mkdir(parents=True, exist_ok=True) + log_endpoint = f"ipc://{log_sockpath}" + local_config["logging"]["endpoint"] = log_endpoint + try: + logger = Logger( + local_config["logging"], + is_master=True, + log_endpoint=log_endpoint, + ) + with logger: + setproctitle("backend.ai: storage-proxy") + log.info("Backend.AI Storage Proxy {0}", VERSION) + log.info("Runtime: {0}", env_info()) + log.info("Node ID: {0}", local_config["storage-proxy"]["node-id"]) + log_config = logging.getLogger("ai.backend.storage.config") + if local_config["debug"]["enabled"]: + log_config.debug("debug mode enabled.") + if "debug" in local_config and local_config["debug"]["enabled"]: + print("== Storage proxy configuration ==") + pprint(local_config) + if local_config["storage-proxy"]["event-loop"] == "uvloop": + import uvloop + + uvloop.install() + log.info("Using uvloop as the event loop backend") + aiotools.start_server( + server_main_logwrapper, + num_workers=local_config["storage-proxy"]["num-proc"], + args=(local_config, log_endpoint), + ) + log.info("exit.") + finally: + if local_config["storage-proxy"]["pid-file"].is_file(): + # check is_file() to prevent deleting /dev/null!
+ local_config["storage-proxy"]["pid-file"].unlink() + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/ai/backend/storage/types.py b/src/ai/backend/storage/types.py new file mode 100644 index 0000000000..cd153064e2 --- /dev/null +++ b/src/ai/backend/storage/types.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +import enum +from datetime import datetime +from pathlib import Path, PurePath +from typing import Any, Final, Mapping, Optional + +import attr +import trafaret as t + +from ai.backend.common import validators as tx +from ai.backend.common.types import BinarySize + + +class Sentinel(enum.Enum): + token = 0 + + +SENTINEL: Final = Sentinel.token + + +@attr.s(auto_attribs=True, slots=True, frozen=True) +class FSPerfMetric: + # iops + iops_read: int + iops_write: int + # thruput + io_bytes_read: int + io_bytes_write: int + # latency + io_usec_read: float + io_usec_write: float + + +@attr.s(auto_attribs=True, slots=True, frozen=True) +class FSUsage: + capacity_bytes: BinarySize + used_bytes: BinarySize + + +@attr.s(auto_attribs=True, slots=True, frozen=True) +class VolumeInfo: + backend: str + path: Path + fsprefix: Optional[PurePath] + options: Optional[Mapping[str, Any]] + + @classmethod + def as_trafaret(cls) -> t.Trafaret: + return t.Dict( + { + t.Key("backend"): t.String, + t.Key("path"): tx.Path(type="dir"), + t.Key("fsprefix", default="."): tx.PurePath(relative_only=True), + t.Key("options", default=None): t.Null | t.Mapping(t.String, t.Any), + }, + ) + + +@attr.s(auto_attribs=True, slots=True, frozen=True) +class VFolderCreationOptions: + quota: Optional[BinarySize] + + @classmethod + def as_trafaret(cls) -> t.Trafaret: + return t.Dict({t.Key("quota", default=None): t.Null | tx.BinarySize}) + + @classmethod + def as_object(cls, dict_opts: Mapping | None) -> VFolderCreationOptions: + if dict_opts is None: + quota = None + else: + quota = dict_opts.get("quota") + return VFolderCreationOptions(quota=quota) + + +@attr.s(auto_attribs=True, slots=True, frozen=True) +class VFolderUsage: + file_count: int + used_bytes: int + + +@attr.s(auto_attribs=True, slots=True, frozen=True) +class Stat: + size: int + owner: str + mode: int + modified: datetime + created: datetime + + +class DirEntryType(enum.Enum): + FILE = 0 + DIRECTORY = 1 + SYMLINK = 2 + + +@attr.s(auto_attribs=True, slots=True, frozen=True) +class DirEntry: + name: str + path: Path + type: DirEntryType + stat: Stat + symlink_target: str diff --git a/src/ai/backend/storage/utils.py b/src/ai/backend/storage/utils.py new file mode 100644 index 0000000000..4d065d52f5 --- /dev/null +++ b/src/ai/backend/storage/utils.py @@ -0,0 +1,129 @@ +import enum +import json +import logging +from contextlib import asynccontextmanager as actxmgr +from datetime import datetime +from datetime import timezone as tz +from typing import Any, Optional, Union + +import trafaret as t +from aiohttp import web + +from ai.backend.common.logging import BraceStyleAdapter + +log = BraceStyleAdapter(logging.getLogger(__name__)) + + +class CheckParamSource(enum.Enum): + BODY = 0 + QUERY = 1 + + +def fstime2datetime(t: Union[float, int]) -> datetime: + return datetime.utcfromtimestamp(t).replace(tzinfo=tz.utc) + + +@actxmgr +async def check_params( + request: web.Request, + checker: Optional[t.Trafaret], + *, + read_from: CheckParamSource = CheckParamSource.BODY, + auth_required: bool = True, +) -> Any: + if checker is None: + if request.can_read_body: + raise web.HTTPBadRequest( + text=json.dumps( + { + "type": 
"https://api.backend.ai/probs/storage/malformed-request", + "title": "Malformed request (request body should be empty)", + }, + ), + content_type="application/problem+json", + ) + else: + if read_from == CheckParamSource.BODY: + raw_params = await request.json() + elif read_from == CheckParamSource.QUERY: + raw_params = request.query + else: + raise ValueError("Invalid source for check_params() helper") + try: + if checker is None: + yield None + else: + yield checker.check(raw_params) + except t.DataError as e: + log.debug("check_params IV error", exc_info=e) + raise web.HTTPBadRequest( + text=json.dumps( + { + "type": "https://api.backend.ai/probs/storage/invalid-api-params", + "title": "Invalid API parameters", + "data": e.as_dict(), + }, + ), + content_type="application/problem+json", + ) + except NotImplementedError: + raise web.HTTPBadRequest( + text=json.dumps( + { + "type": "https://api.backend.ai/probs/storage/unsupported-operation", + "title": "Unsupported operation by the storage backend", + }, + ), + content_type="application/problem+json", + ) + + +async def log_manager_api_entry( + log: Union[logging.Logger, BraceStyleAdapter], + name: str, + params: Any, +) -> None: + if params is not None: + if "src_vfid" in params and "dst_vfid" in params: + log.info( + "ManagerAPI::{}(v:{}, f:{} -> dst_v: {}, dst_f:{})", + name.upper(), + params["src_volume"], + params["src_vfid"], + params["dst_volume"], + params["dst_vfid"], + ) + elif "relpaths" in params: + log.info( + "ManagerAPI::{}(v:{}, f:{}, p*:{})", + name.upper(), + params["volume"], + params["vfid"], + str(params["relpaths"][0]) + "...", + ) + elif "relpath" in params: + log.info( + "ManagerAPI::{}(v:{}, f:{}, p:{})", + name.upper(), + params["volume"], + params["vfid"], + params["relpath"], + ) + elif "vfid" in params: + log.info( + "ManagerAPI::{}(v:{}, f:{})", + name.upper(), + params["volume"], + params["vfid"], + ) + elif "volume" in params: + log.info( + "ManagerAPI::{}(v:{})", + name.upper(), + params["volume"], + ) + return + log.info( + "ManagerAPI::{}()", + name.upper(), + ) diff --git a/src/ai/backend/storage/vfs/__init__.py b/src/ai/backend/storage/vfs/__init__.py new file mode 100644 index 0000000000..b836d63a9d --- /dev/null +++ b/src/ai/backend/storage/vfs/__init__.py @@ -0,0 +1,482 @@ +from __future__ import annotations + +import asyncio +import functools +import logging +import os +import secrets +import shutil +import time +import warnings +from pathlib import Path, PurePosixPath +from typing import AsyncIterator, FrozenSet, Sequence, Union +from uuid import UUID + +import janus + +from ai.backend.common.logging import BraceStyleAdapter +from ai.backend.common.types import BinarySize, HardwareMetadata + +from ..abc import CAP_VFOLDER, AbstractVolume +from ..exception import ExecutionError, InvalidAPIParameters +from ..types import ( + SENTINEL, + DirEntry, + DirEntryType, + FSPerfMetric, + FSUsage, + Sentinel, + Stat, + VFolderCreationOptions, + VFolderUsage, +) +from ..utils import fstime2datetime + +log = BraceStyleAdapter(logging.getLogger(__name__)) + + +async def run(cmd: Sequence[Union[str, Path]]) -> str: + proc = await asyncio.create_subprocess_exec( + *cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + out, err = await proc.communicate() + if err: + raise ExecutionError(err.decode()) + return out.decode() + + +class BaseVolume(AbstractVolume): + + # ------ volume operations ------- + + async def get_capabilities(self) -> FrozenSet[str]: + return frozenset([CAP_VFOLDER]) + 
+ async def get_hwinfo(self) -> HardwareMetadata: + return { + "status": "healthy", + "status_info": None, + "metadata": {}, + } + + async def create_vfolder( + self, + vfid: UUID, + options: VFolderCreationOptions = None, + *, + exist_ok: bool = False, + ) -> None: + vfpath = self.mangle_vfpath(vfid) + loop = asyncio.get_running_loop() + await loop.run_in_executor( + None, + lambda: vfpath.mkdir(0o755, parents=True, exist_ok=exist_ok), + ) + + async def delete_vfolder(self, vfid: UUID) -> None: + vfpath = self.mangle_vfpath(vfid) + loop = asyncio.get_running_loop() + + def _delete_vfolder(): + try: + shutil.rmtree(vfpath) + except FileNotFoundError: + pass + # remove intermediate prefix directories if they become empty + if not os.listdir(vfpath.parent): + vfpath.parent.rmdir() + if not os.listdir(vfpath.parent.parent): + vfpath.parent.parent.rmdir() + + await loop.run_in_executor(None, _delete_vfolder) + + async def clone_vfolder( + self, + src_vfid: UUID, + dst_volume: AbstractVolume, + dst_vfid: UUID, + options: VFolderCreationOptions = None, + ) -> None: + # check if there is enough space in the destination + fs_usage = await dst_volume.get_fs_usage() + vfolder_usage = await self.get_usage(src_vfid) + if vfolder_usage.used_bytes > fs_usage.capacity_bytes - fs_usage.used_bytes: + raise ExecutionError("Not enough space available for clone.") + + # create the target vfolder + src_vfpath = self.mangle_vfpath(src_vfid) + await dst_volume.create_vfolder(dst_vfid, options=options, exist_ok=True) + dst_vfpath = dst_volume.mangle_vfpath(dst_vfid) + + # perform the file-tree copy + try: + await self.copy_tree(src_vfpath, dst_vfpath) + except Exception: + await dst_volume.delete_vfolder(dst_vfid) + log.exception("clone_vfolder: error during copy_tree()") + raise ExecutionError("Copying files from source directories failed.") + + async def copy_tree( + self, + src_vfpath: Path, + dst_vfpath: Path, + ) -> None: + loop = asyncio.get_running_loop() + await loop.run_in_executor( + None, + functools.partial( + shutil.copytree, + src_vfpath, + dst_vfpath, + dirs_exist_ok=True, + ), + ) + + async def get_vfolder_mount(self, vfid: UUID, subpath: str) -> Path: + self.sanitize_vfpath(vfid, PurePosixPath(subpath)) + return self.mangle_vfpath(vfid).resolve() + + async def put_metadata(self, vfid: UUID, payload: bytes) -> None: + vfpath = self.mangle_vfpath(vfid) + metadata_path = vfpath / "metadata.json" + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, metadata_path.write_bytes, payload) + + async def get_metadata(self, vfid: UUID) -> bytes: + vfpath = self.mangle_vfpath(vfid) + metadata_path = vfpath / "metadata.json" + loop = asyncio.get_running_loop() + try: + stat = await loop.run_in_executor(None, metadata_path.stat) + if stat.st_size > 10 * (2**20): + raise RuntimeError("Too large metadata (more than 10 MiB)") + data = await loop.run_in_executor(None, metadata_path.read_bytes) + return data + except FileNotFoundError: + return b"" + # Other IO errors should be bubbled up. 
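The `BaseVolume` class being added here implements the generic vfolder lifecycle and per-vfolder metadata storage directly on the local filesystem. A minimal sketch of driving that API from application code (the `volume` argument is assumed to be an already-initialized `BaseVolume` or subclass, and the metadata payload is illustrative):

```python
import json
from uuid import uuid4

from ai.backend.storage.vfs import BaseVolume


async def metadata_roundtrip(volume: BaseVolume) -> None:
    vfid = uuid4()
    await volume.create_vfolder(vfid)
    try:
        # put_metadata()/get_metadata() store the blob as metadata.json under
        # the mangled vfolder path; get_metadata() returns b"" when the file
        # does not exist yet, so missing metadata is not an error.
        await volume.put_metadata(vfid, json.dumps({"owner": "alice"}).encode())
        raw = await volume.get_metadata(vfid)
        print(json.loads(raw) if raw else None)
    finally:
        await volume.delete_vfolder(vfid)
```

All blocking filesystem calls in these methods are pushed to the default executor, so the sketch stays safe to run inside the storage proxy's event loop.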
+ + async def get_quota(self, vfid: UUID) -> BinarySize: + raise NotImplementedError + + async def set_quota(self, vfid: UUID, size_bytes: BinarySize) -> None: + raise NotImplementedError + + async def get_performance_metric(self) -> FSPerfMetric: + raise NotImplementedError + + async def get_fs_usage(self) -> FSUsage: + loop = asyncio.get_running_loop() + stat = await loop.run_in_executor(None, os.statvfs, self.mount_path) + return FSUsage( + capacity_bytes=BinarySize(stat.f_frsize * stat.f_blocks), + used_bytes=BinarySize(stat.f_frsize * (stat.f_blocks - stat.f_bavail)), + ) + + async def get_usage( + self, + vfid: UUID, + relpath: PurePosixPath = PurePosixPath("."), + ) -> VFolderUsage: + target_path = self.sanitize_vfpath(vfid, relpath) + total_size = 0 + total_count = 0 + start_time = time.monotonic() + + def _calc_usage(target_path: os.DirEntry | Path) -> None: + nonlocal total_size, total_count + _timeout = 3 + # FIXME: Remove "type: ignore" when python/mypy#11964 is resolved. + with os.scandir(target_path) as scanner: # type: ignore + for entry in scanner: + if entry.is_dir(): + _calc_usage(entry) + continue + if entry.is_file() or entry.is_symlink(): + stat = entry.stat(follow_symlinks=False) + total_size += stat.st_size + total_count += 1 + if total_count % 1000 == 0: + # Cancel if this I/O operation takes too much time. + if time.monotonic() - start_time > _timeout: + raise TimeoutError + + loop = asyncio.get_running_loop() + try: + await loop.run_in_executor(None, _calc_usage, target_path) + except TimeoutError: + # -1 indicates "too many" + total_size = -1 + total_count = -1 + return VFolderUsage(file_count=total_count, used_bytes=total_size) + + async def get_used_bytes(self, vfid: UUID) -> BinarySize: + vfpath = self.mangle_vfpath(vfid) + info = await run(["du", "-hs", vfpath]) + used_bytes, _ = info.split() + return BinarySize.finite_from_str(used_bytes) + + # ------ vfolder internal operations ------- + + def scandir(self, vfid: UUID, relpath: PurePosixPath) -> AsyncIterator[DirEntry]: + target_path = self.sanitize_vfpath(vfid, relpath) + q: janus.Queue[Union[Sentinel, DirEntry]] = janus.Queue() + loop = asyncio.get_running_loop() + + def _scandir(q: janus._SyncQueueProxy[Union[Sentinel, DirEntry]]) -> None: + count = 0 + limit = self.local_config["storage-proxy"]["scandir-limit"] + try: + with os.scandir(target_path) as scanner: + for entry in scanner: + symlink_target = "" + entry_type = DirEntryType.FILE + if entry.is_dir(): + entry_type = DirEntryType.DIRECTORY + if entry.is_symlink(): + entry_type = DirEntryType.SYMLINK + symlink_target = str(Path(entry).resolve()) + entry_stat = entry.stat(follow_symlinks=False) + q.put( + DirEntry( + name=entry.name, + path=Path(entry.path), + type=entry_type, + stat=Stat( + size=entry_stat.st_size, + owner=str(entry_stat.st_uid), + mode=entry_stat.st_mode, + modified=fstime2datetime(entry_stat.st_mtime), + created=fstime2datetime(entry_stat.st_ctime), + ), + symlink_target=symlink_target, + ), + ) + count += 1 + if limit > 0 and count == limit: + break + finally: + q.put(SENTINEL) + + async def _scan_task(_scandir, q) -> None: + await loop.run_in_executor(None, _scandir, q.sync_q) + + async def _aiter() -> AsyncIterator[DirEntry]: + scan_task = asyncio.create_task(_scan_task(_scandir, q)) + await asyncio.sleep(0) + try: + while True: + item = await q.async_q.get() + if item is SENTINEL: + break + yield item + q.async_q.task_done() + finally: + await scan_task + q.close() + await q.wait_closed() + + return _aiter() + + async def 
mkdir( + self, + vfid: UUID, + relpath: PurePosixPath, + *, + parents: bool = False, + exist_ok: bool = False, + ) -> None: + target_path = self.sanitize_vfpath(vfid, relpath) + loop = asyncio.get_running_loop() + await loop.run_in_executor( + None, + lambda: target_path.mkdir(0o755, parents=parents, exist_ok=exist_ok), + ) + + async def rmdir( + self, + vfid: UUID, + relpath: PurePosixPath, + *, + recursive: bool = False, + ) -> None: + target_path = self.sanitize_vfpath(vfid, relpath) + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, target_path.rmdir) + + async def move_file( + self, + vfid: UUID, + src: PurePosixPath, + dst: PurePosixPath, + ) -> None: + src_path = self.sanitize_vfpath(vfid, src) + dst_path = self.sanitize_vfpath(vfid, dst) + loop = asyncio.get_running_loop() + await loop.run_in_executor( + None, + lambda: shutil.move(str(src_path), str(dst_path)), + ) + + async def move_tree( + self, + vfid: UUID, + src: PurePosixPath, + dst: PurePosixPath, + ) -> None: + warnings.warn( + "Use move_file() instead. move_tree() will be deprecated", + DeprecationWarning, + stacklevel=2, + ) + src_path = self.sanitize_vfpath(vfid, src) + if not src_path.is_dir(): + raise InvalidAPIParameters( + msg=f"source path {str(src_path)} is not a directory", + ) + dst_path = self.sanitize_vfpath(vfid, dst) + src_path = self.sanitize_vfpath(vfid, src) + loop = asyncio.get_running_loop() + await loop.run_in_executor( + None, + lambda: shutil.move(str(src_path), str(dst_path)), + ) + + async def copy_file( + self, + vfid: UUID, + src: PurePosixPath, + dst: PurePosixPath, + ) -> None: + src_path = self.sanitize_vfpath(vfid, src) + if not src_path.is_file(): + raise InvalidAPIParameters(msg=f"source path {str(src_path)} is not a file") + dst_path = self.sanitize_vfpath(vfid, dst) + loop = asyncio.get_running_loop() + await loop.run_in_executor( + None, + lambda: dst_path.parent.mkdir(parents=True, exist_ok=True), + ) + await loop.run_in_executor( + None, + lambda: shutil.copyfile(str(src_path), str(dst_path)), + ) + + async def prepare_upload(self, vfid: UUID) -> str: + vfpath = self.mangle_vfpath(vfid) + session_id = secrets.token_hex(16) + + def _create_target(): + upload_base_path = vfpath / ".upload" + upload_base_path.mkdir(exist_ok=True) + upload_target_path = upload_base_path / session_id + upload_target_path.touch() + + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, _create_target) + return session_id + + async def add_file( + self, + vfid: UUID, + relpath: PurePosixPath, + payload: AsyncIterator[bytes], + ) -> None: + target_path = self.sanitize_vfpath(vfid, relpath) + q: janus.Queue[bytes] = janus.Queue() + + def _write(q: janus._SyncQueueProxy[bytes]) -> None: + with open(target_path, "wb") as f: + while True: + buf = q.get() + try: + if not buf: + return + f.write(buf) + finally: + q.task_done() + + loop = asyncio.get_running_loop() + write_task: asyncio.Task = asyncio.create_task( + loop.run_in_executor(None, _write, q.sync_q), # type: ignore + ) + try: + async for buf in payload: + await q.async_q.put(buf) + await q.async_q.put(b"") + await q.async_q.join() + finally: + await write_task + + def read_file( + self, + vfid: UUID, + relpath: PurePosixPath, + *, + chunk_size: int = 0, + ) -> AsyncIterator[bytes]: + target_path = self.sanitize_vfpath(vfid, relpath) + q: janus.Queue[Union[bytes, Exception]] = janus.Queue() + loop = asyncio.get_running_loop() + + def _read( + q: janus._SyncQueueProxy[Union[bytes, Exception]], + chunk_size: int, + ) -> 
None: + try: + with open(target_path, "rb") as f: + while True: + buf = f.read(chunk_size) + if not buf: + return + q.put(buf) + except Exception as e: + q.put(e) + finally: + q.put(b"") + + async def _aiter() -> AsyncIterator[bytes]: + nonlocal chunk_size + if chunk_size == 0: + # get the preferred io block size + _vfs_stat = await loop.run_in_executor( + None, + os.statvfs, + self.mount_path, + ) + chunk_size = _vfs_stat.f_bsize + read_fut = loop.run_in_executor(None, _read, q.sync_q, chunk_size) + await asyncio.sleep(0) + try: + while True: + buf = await q.async_q.get() + if isinstance(buf, Exception): + raise buf + yield buf + q.async_q.task_done() + if not buf: + return + finally: + await read_fut + + return _aiter() + + async def delete_files( + self, + vfid: UUID, + relpaths: Sequence[PurePosixPath], + recursive: bool = False, + ) -> None: + target_paths = [self.sanitize_vfpath(vfid, p) for p in relpaths] + + def _delete() -> None: + for p in target_paths: + if p.is_dir() and recursive: + shutil.rmtree(p) + else: + p.unlink() + + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, _delete) diff --git a/src/ai/backend/storage/xfs/__init__.py b/src/ai/backend/storage/xfs/__init__.py new file mode 100644 index 0000000000..a058eef775 --- /dev/null +++ b/src/ai/backend/storage/xfs/__init__.py @@ -0,0 +1,255 @@ +import asyncio +import logging +import os +from pathlib import Path, PurePosixPath +from tempfile import NamedTemporaryFile +from typing import Dict, List +from uuid import UUID + +from ai.backend.common.logging import BraceStyleAdapter +from ai.backend.common.types import BinarySize + +from ..exception import ExecutionError, VFolderCreationError +from ..filelock import FileLock +from ..types import VFolderCreationOptions, VFolderUsage +from ..vfs import BaseVolume, run + +log = BraceStyleAdapter(logging.getLogger(__name__)) + +LOCK_FILE = Path("/tmp/backendai-xfs-file-lock") +Path(LOCK_FILE).touch() + + +class XfsProjectRegistry: + file_projects: Path = Path("/etc/projects") + file_projid: Path = Path("/etc/projid") + backend: BaseVolume + name_id_map: Dict[UUID, int] = dict() + project_id_pool: List[int] = list() + + async def init(self, backend: BaseVolume) -> None: + self.backend = backend + + async def read_project_info(self): + def _read_projid_file(): + return self.file_projid.read_text() + + # TODO: how to handle if /etc/proj* files are deleted by external reason? + # TODO: do we need to use /etc/proj* files to enlist the project information? 
+ if self.file_projid.is_file(): + project_id_pool = [] + self.name_id_map = {} + loop = asyncio.get_running_loop() + raw_projid = await loop.run_in_executor(None, _read_projid_file) + for line in raw_projid.splitlines(): + proj_name, proj_id = line.split(":")[:2] + project_id_pool.append(int(proj_id)) + self.name_id_map[UUID(proj_name)] = int(proj_id) + self.project_id_pool = sorted(project_id_pool) + else: + await run(["sudo", "touch", self.file_projid]) + if not Path(self.file_projects).is_file(): + await run(["sudo", "touch", self.file_projects]) + + async def add_project_entry( + self, + *, + vfid: UUID, + quota: int, + project_id: int = None, + ) -> None: + vfpath = self.backend.mangle_vfpath(vfid) + if project_id is None: + project_id = self.get_project_id() + + temp_name_projects = "" + temp_name_projid = "" + + def _create_temp_files(): + nonlocal temp_name_projects, temp_name_projid + _tmp_projects = NamedTemporaryFile(delete=False) + _tmp_projid = NamedTemporaryFile(delete=False) + try: + _projects_content = Path(self.file_projects).read_text() + if _projects_content.strip() != "" and not _projects_content.endswith( + "\n", + ): + _projects_content += "\n" + _projects_content += f"{project_id}:{vfpath}\n" + _tmp_projects.write(_projects_content.encode("ascii")) + temp_name_projects = _tmp_projects.name + + _projid_content = Path(self.file_projid).read_text() + if _projid_content.strip() != "" and not _projid_content.endswith("\n"): + _projid_content += "\n" + _projid_content += f"{str(vfid)}:{project_id}\n" + _tmp_projid.write(_projid_content.encode("ascii")) + temp_name_projid = _tmp_projid.name + finally: + _tmp_projects.close() + _tmp_projid.close() + + def _delete_temp_files(): + try: + os.unlink(temp_name_projects) + except FileNotFoundError: + pass + try: + os.unlink(temp_name_projid) + except FileNotFoundError: + pass + + loop = asyncio.get_running_loop() + try: + await loop.run_in_executor(None, _create_temp_files) + await run(["sudo", "cp", "-rp", temp_name_projects, self.file_projects]) + await run(["sudo", "cp", "-rp", temp_name_projid, self.file_projid]) + finally: + await loop.run_in_executor(None, _delete_temp_files) + + async def remove_project_entry(self, vfid: UUID) -> None: + await run(["sudo", "sed", "-i.bak", f"/{vfid.hex[4:]}/d", self.file_projects]) + await run(["sudo", "sed", "-i.bak", f"/{vfid}/d", self.file_projid]) + + def get_project_id(self) -> int: + """ + Get the next project_id, which is the smallest unused integer. + """ + project_id = -1 + for i in range(len(self.project_id_pool) - 1): + if self.project_id_pool[i] + 1 != self.project_id_pool[i + 1]: + project_id = self.project_id_pool[i] + 1 + break + if len(self.project_id_pool) == 0: + project_id = 1 + if project_id == -1: + project_id = self.project_id_pool[-1] + 1 + return project_id + + +class XfsVolume(BaseVolume): + """ + XFS volume backend. XFS natively supports per-directory quota through + the project quota feature. To enable project quota, the XFS volume should be + mounted with the `-o pquota` option. + + This backend requires `root` or password-less `sudo` permission to run the + `xfs_quota` command and write to `/etc/projects` and `/etc/projid`.
+ """ + + registry: XfsProjectRegistry + + async def init(self, uid: int = None, gid: int = None) -> None: + self.uid = uid if uid is not None else os.getuid() + self.gid = gid if gid is not None else os.getgid() + self.registry = XfsProjectRegistry() + await self.registry.init(self) + + # ----- volume opeartions ----- + async def create_vfolder( + self, + vfid: UUID, + options: VFolderCreationOptions = None, + *, + exist_ok: bool = False, + ) -> None: + await super().create_vfolder(vfid, options, exist_ok=exist_ok) + + # NOTE: Do we need to register project ID for a directory without quota? + # Yes, to easily get the file size and used bytes of a directory. + if options is None or options.quota is None: # max quota i.e. the whole fs size + fs_usage = await self.get_fs_usage() + quota = fs_usage.capacity_bytes + else: + quota = options.quota + # quota = options.quota if options and options.quota else None + # if not quota: + # return + try: + async with FileLock(LOCK_FILE): + log.info("setting project quota (f:{}, q:{})", vfid, str(quota)) + await self.registry.read_project_info() + await self.registry.add_project_entry(vfid=vfid, quota=quota) + await self.set_quota(vfid, quota) + await self.registry.read_project_info() + except (asyncio.CancelledError, asyncio.TimeoutError) as e: + log.exception("vfolder creation timeout", exc_info=e) + await self.delete_vfolder(vfid) + raise + except Exception as e: + log.exception("vfolder creation error", exc_info=e) + await self.delete_vfolder(vfid) + raise VFolderCreationError("problem in setting vfolder quota") + + async def delete_vfolder(self, vfid: UUID) -> None: + async with FileLock(LOCK_FILE): + await self.registry.read_project_info() + if vfid in self.registry.name_id_map.keys(): + try: + log.info("removing project quota (f:{})", vfid) + await self.set_quota(vfid, BinarySize(0)) + except (asyncio.CancelledError, asyncio.TimeoutError) as e: + log.exception("vfolder deletion timeout", exc_info=e) + pass # Pass to delete the physical directlry anyway. + except Exception as e: + log.exception("vfolder deletion error", exc_info=e) + pass # Pass to delete the physical directlry anyway. 
+ finally: + await self.registry.remove_project_entry(vfid) + await super().delete_vfolder(vfid) + await self.registry.read_project_info() + + async def get_quota(self, vfid: UUID) -> BinarySize: + full_report = await run( + ["sudo", "xfs_quota", "-x", "-c", "report -h", self.mount_path], + ) + for line in full_report.split("\n"): + if str(vfid) in line: + report = line + break + if len(report.split()) != 6: + raise ExecutionError("unexpected format for xfs_quota report") + proj_name, _, _, quota, _, _ = report.split() + if not str(vfid).startswith(proj_name): + raise ExecutionError("vfid and project name does not match") + return BinarySize.finite_from_str(quota) + + async def set_quota(self, vfid: UUID, size_bytes: BinarySize) -> None: + if vfid not in self.registry.name_id_map.keys(): + await run( + [ + "sudo", + "xfs_quota", + "-x", + "-c", + f"project -s {vfid}", + self.mount_path, + ], + ) + await run( + [ + "sudo", + "xfs_quota", + "-x", + "-c", + f"limit -p bsoft={int(size_bytes)} bhard={int(size_bytes)} {vfid}", + self.mount_path, + ], + ) + + async def get_usage(self, vfid: UUID, relpath: PurePosixPath = PurePosixPath(".")): + full_report = await run( + ["sudo", "xfs_quota", "-x", "-c", "report -pbih", self.mount_path], + ) + report = "" + for line in full_report.split("\n"): + if str(vfid) in line: + report = line + break + if len(report.split()) != 11: + raise ExecutionError("unexpected format for xfs_quota report") + proj_name, used_size, _, _, _, _, inode_used, _, _, _, _ = report.split() + used_bytes = int(BinarySize.finite_from_str(used_size)) + if not str(vfid).startswith(proj_name): + raise ExecutionError("vfid and project name does not match") + return VFolderUsage(file_count=int(inode_used), used_bytes=used_bytes) diff --git a/src/ai/backend/test/BUILD b/src/ai/backend/test/BUILD new file mode 100644 index 0000000000..5f231f4232 --- /dev/null +++ b/src/ai/backend/test/BUILD @@ -0,0 +1,49 @@ +python_sources( + name="lib", + sources=["**/*.py"], + dependencies=[ + ":resources", + "//:reqs#pytest", + "//:reqs#pytest-dependency", + ], +) + +pex_binary( + name="cli", + dependencies=[ + ":lib", + ], + entry_point="ai.backend.test.cli.__main__:main", +) + +python_distribution( + name="dist", + dependencies=[ + ":lib", + "!!stubs/trafaret:stubs", + ], + provides=python_artifact( + name="backend.ai-test", + description="Backend.AI Integration Test Suite", + license="MIT", + ), + entry_points={ + "backendai_cli_v10": { + "test": "ai.backend.test.cli.__main__:main", + }, + }, + generate_setup=True, + tags=["wheel"], +) + +resource(name="version", source="VERSION") + +resources( + name="resources", + dependencies=[ + ":version", + ], + sources=[ + "**/py.typed", + ], +) diff --git a/src/ai/backend/test/VERSION b/src/ai/backend/test/VERSION new file mode 120000 index 0000000000..a4e948506b --- /dev/null +++ b/src/ai/backend/test/VERSION @@ -0,0 +1 @@ +../../../../VERSION \ No newline at end of file diff --git a/src/ai/backend/test/__init__.py b/src/ai/backend/test/__init__.py new file mode 100644 index 0000000000..17b3552989 --- /dev/null +++ b/src/ai/backend/test/__init__.py @@ -0,0 +1,3 @@ +from pathlib import Path + +__version__ = (Path(__file__).parent / 'VERSION').read_text().strip() diff --git a/src/ai/backend/test/cli/__init__.py b/src/ai/backend/test/cli/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/ai/backend/test/cli/__main__.py b/src/ai/backend/test/cli/__main__.py new file mode 100644 index 0000000000..85c2011790 --- /dev/null +++ 
b/src/ai/backend/test/cli/__main__.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +import click +import subprocess +import sys + +from .context import CLIContext +from .utils import CommaSeparatedChoice, CustomUsageArgsCommand + + +@click.group(invoke_without_command=True, context_settings={'help_option_names': ['-h', '--help']}) +@click.pass_context +def main(ctx: click.Context) -> None: + """ + The integration test suite + """ + ctx.obj = CLIContext() + + +@main.command(cls=CustomUsageArgsCommand, context_settings={ + 'ignore_unknown_options': True, + 'allow_extra_args': True, + 'allow_interspersed_args': True, +}, usage_args='[PKGS] [PYTEST_ARGS]') +@click.argument("pkgs", type=CommaSeparatedChoice([ + 'admin', 'user', +]), metavar='PKGS') +@click.pass_context +def run_cli( + ctx: click.Context, + pkgs: list[str], +) -> None: + """A shortcut command to run pytest against a specific set of CLI-based + integration tests + + It takes one or more test package names in a comma-separated list (PKGS) + and forwards all other extra arguments and options (PYTEST_ARGS) to + the underlying pytest command. + + \b + Available CLI-based integration test package names: + admin + user + """ + pytest_args = ctx.args + result = subprocess.run([ + sys.executable, '-m', 'pytest', + '--pyargs', + *(f'ai.backend.test.cli_integration.{pkg}' for pkg in pkgs), + *pytest_args, + ]) + ctx.exit(result.returncode) + + +if __name__ == '__main__': + main() diff --git a/src/ai/backend/test/cli/context.py b/src/ai/backend/test/cli/context.py new file mode 100644 index 0000000000..77c1bf4070 --- /dev/null +++ b/src/ai/backend/test/cli/context.py @@ -0,0 +1,6 @@ +import attr + + +@attr.s(auto_attribs=True, frozen=True) +class CLIContext: + pass diff --git a/src/ai/backend/test/cli/utils.py b/src/ai/backend/test/cli/utils.py new file mode 100644 index 0000000000..b77b303316 --- /dev/null +++ b/src/ai/backend/test/cli/utils.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from typing import Optional + +import click + + +class CommaSeparatedChoice(click.Choice): + + def convert( + self, + value: str, + param: Optional[click.Parameter], + ctx: Optional[click.Context], + ) -> Optional[list[str]]: + pieces = value.split(',') + return [super(click.Choice, self).convert(piece, param, ctx) for piece in pieces] + + +class CustomUsageArgsCommand(click.Command): + + def __init__(self, *args, **kwargs) -> None: + self._usage_args = kwargs.pop('usage_args') + super().__init__(*args, **kwargs) + + def format_usage(self, ctx: click.Context, formatter: click.HelpFormatter) -> None: + if self._usage_args: + formatter.write_usage(ctx.command_path, self._usage_args) + else: + super().format_usage(ctx, formatter) diff --git a/src/ai/backend/test/cli_integration/__init__.py b/src/ai/backend/test/cli_integration/__init__.py new file mode 100644 index 0000000000..adca930eea --- /dev/null +++ b/src/ai/backend/test/cli_integration/__init__.py @@ -0,0 +1,12 @@ +from __future__ import annotations + +import click + +from ..cli.context import CLIContext + + +@click.group() +@click.pass_obj +def cli(cli_context: CLIContext) -> None: + """CLI-based integration tests""" + pass diff --git a/src/ai/backend/test/cli_integration/admin/__init__.py b/src/ai/backend/test/cli_integration/admin/__init__.py new file mode 100644 index 0000000000..643fc682fc --- /dev/null +++ b/src/ai/backend/test/cli_integration/admin/__init__.py @@ -0,0 +1,9 @@ +""" +An integration test for the Backend.AI admin APIs. 
+ +It runs and checks the result of a series of CRUD commands for various entities including domain, group, +user, scaling group, resource policy, resource preset, etc. + +The location of the client executable and the credential scripts are configured as environment variables +read by the pytest fixtures in the ai.backend.test.cli_integration.conftest module. +""" diff --git a/src/ai/backend/test/cli_integration/admin/test_domain.py b/src/ai/backend/test/cli_integration/admin/test_domain.py new file mode 100644 index 0000000000..2a9e9dfb27 --- /dev/null +++ b/src/ai/backend/test/cli_integration/admin/test_domain.py @@ -0,0 +1,102 @@ +import json +from contextlib import closing + +from ...utils.cli import EOF, ClientRunnerFunc + + +def test_add_domain(run: ClientRunnerFunc): + print("[ Add domain ]") + + # Add domain + add_arguments = [ + 'admin', 'domain', 'add', + '-d', 'Test domain', + '-i', + '--total-resource-slots', '{}', + '--allowed-vfolder-hosts', 'local:volume1', + '--allowed-docker-registries', 'cr.backend.ai', + 'test', + ] + with closing(run(add_arguments)) as p: + p.expect(EOF) + assert 'Domain name test is created.' in p.before.decode(), 'Domain creation not successful' + + # Check if domain is added + with closing(run(['--output=json', 'admin', 'domain', 'list'])) as p: + p.expect(EOF) + decoded = p.before.decode() + loaded = json.loads(decoded) + domain_list = loaded.get('items') + assert isinstance(domain_list, list), 'Domain list not printed properly' + + test_domain = get_domain_from_list(domain_list, 'test') + + assert bool(test_domain), 'Test domain doesn\'t exist' + assert test_domain.get('description') == 'Test domain', 'Domain description mismatch' + assert test_domain.get('is_active') is False, 'Domain active status mismatch' + assert test_domain.get('total_resource_slots') == {}, 'Domain total resource slots mismatch' + assert test_domain.get('allowed_vfolder_hosts') == ['local:volume1'], 'Domain allowed vfolder hosts mismatch' + assert test_domain.get('allowed_docker_registries') == ['cr.backend.ai'], 'Domain allowed docker registries mismatch' + + +def test_update_domain(run: ClientRunnerFunc): + print("[ Update domain ]") + + # Update domain + add_arguments = [ + 'admin', 'domain', 'update', + '--new-name', 'test123', + '--description', 'Test domain updated', + '--is-active', 'TRUE', + '--total-resource-slots', '{}', + '--allowed-vfolder-hosts', 'local:volume2', + '--allowed-docker-registries', 'cr1.backend.ai', + 'test', + ] + with closing(run(add_arguments)) as p: + p.expect(EOF) + assert 'Domain test is updated.' 
in p.before.decode(), 'Domain update not successful' + + # Check if domain is updated + with closing(run(['--output=json', 'admin', 'domain', 'list'])) as p: + p.expect(EOF) + decoded = p.before.decode() + loaded = json.loads(decoded) + domain_list = loaded.get('items') + assert isinstance(domain_list, list), 'Domain list not printed properly' + + test_domain = get_domain_from_list(domain_list, 'test123') + + assert bool(test_domain), 'Test domain doesn\'t exist' + assert test_domain.get('description') == 'Test domain updated', 'Domain description mismatch' + assert test_domain.get('is_active') is True, 'Domain active status mismatch' + assert test_domain.get('total_resource_slots') == {}, 'Domain total resource slots mismatch' + assert test_domain.get('allowed_vfolder_hosts') == ['local:volume2'], 'Domain allowed vfolder hosts mismatch' + assert test_domain.get('allowed_docker_registries') == ['cr1.backend.ai'], \ + 'Domain allowed docker registries mismatch' + + +def test_delete_domain(run: ClientRunnerFunc): + print("[ Delete domain ]") + + # Delete domain + with closing(run(['admin', 'domain', 'purge', 'test123'])) as p: + p.sendline('y') + p.expect(EOF) + assert 'Domain is deleted:' in p.before.decode(), 'Domain deletion failed' + + +def test_list_domain(run: ClientRunnerFunc): + with closing(run(['--output=json', 'admin', 'domain', 'list'])) as p: + p.expect(EOF) + decoded = p.before.decode() + loaded = json.loads(decoded) + domain_list = loaded.get('items') + assert isinstance(domain_list, list), 'Domain list not printed properly' + + +def get_domain_from_list(domains: list, name: str) -> dict: + for domain in domains: + if domain.get('name') == name: + return domain + return {} diff --git a/src/ai/backend/test/cli_integration/admin/test_group.py b/src/ai/backend/test/cli_integration/admin/test_group.py new file mode 100644 index 0000000000..d381aac4df --- /dev/null +++ b/src/ai/backend/test/cli_integration/admin/test_group.py @@ -0,0 +1,26 @@ +import json +from contextlib import closing + +from ...utils.cli import EOF, ClientRunnerFunc + + +def test_add_group(run: ClientRunnerFunc): + pass + + +def test_update_group(run: ClientRunnerFunc): + pass + + +def test_delete_group(run: ClientRunnerFunc): + pass + + +def test_list_group(run: ClientRunnerFunc): + print("[ List group ]") + with closing(run(['--output=json', 'admin', 'group', 'list'])) as p: + p.expect(EOF) + decoded = p.before.decode() + loaded = json.loads(decoded) + group_list = loaded.get('items') + assert isinstance(group_list, list), 'Group list not printed properly' diff --git a/src/ai/backend/test/cli_integration/admin/test_image.py b/src/ai/backend/test/cli_integration/admin/test_image.py new file mode 100644 index 0000000000..688a7435cf --- /dev/null +++ b/src/ai/backend/test/cli_integration/admin/test_image.py @@ -0,0 +1,22 @@ +import json +from contextlib import closing + +from ...utils.cli import EOF, ClientRunnerFunc + + +def test_alias_image(run: ClientRunnerFunc): + pass + + +def test_dealias_image(run: ClientRunnerFunc): + pass + + +def test_list_image(run: ClientRunnerFunc): + print("[ List image ]") + with closing(run(['--output=json', 'admin', 'image', 'list'])) as p: + p.expect(EOF) + decoded = p.before.decode() + loaded = json.loads(decoded) + image_list = loaded.get('items') + assert isinstance(image_list, list), 'Image list not printed properly' diff --git a/src/ai/backend/test/cli_integration/admin/test_keypair.py b/src/ai/backend/test/cli_integration/admin/test_keypair.py new file mode 100644 index 
0000000000..5af6778e8e --- /dev/null +++ b/src/ai/backend/test/cli_integration/admin/test_keypair.py @@ -0,0 +1,179 @@ +import json +from contextlib import closing + +from ...utils.cli import EOF, ClientRunnerFunc + + +def test_add_keypair(run: ClientRunnerFunc): + """ + Test add keypair. + This test should be executed first in test_keypair.py. + """ + print("[ Add keypair ]") + # Add test user + add_arguments = ['--output=json', 'admin', 'user', 'add', '-u', 'adminkeypair', '-n', 'John Doe', + '-r', 'admin', 'default', 'adminkeypair@lablup.com', '1q2w3e4r'] + with closing(run(add_arguments)) as p: + p.expect(EOF) + assert 'User adminkeypair@lablup.com is created' in p.before.decode(), 'Account add error' + + add_arguments = ['--output=json', 'admin', 'user', 'add', '-u', 'userkeypair', '-n', 'Richard Doe', + 'default', 'userkeypair@lablup.com', '1q2w3e4r'] + with closing(run(add_arguments)) as p: + p.expect(EOF) + assert 'User userkeypair@lablup.com is created' in p.before.decode(), 'Account add error' + + # Create keypair + with closing(run(['admin', 'keypair', 'add', '-a', '-i', '-r', '25000', 'adminkeypair@lablup.com', 'default'])) as p: + p.expect(EOF) + assert 'Access Key:' in p.before.decode() and 'Secret Key:' in p.before.decode(), 'Keypair add error' + + with closing(run(['admin', 'keypair', 'add', 'userkeypair@lablup.com', 'default'])) as p: + p.expect(EOF) + assert 'Access Key:' in p.before.decode() and 'Secret Key:' in p.before.decode(), 'Keypair add error' + + # Check if keypair is added + with closing(run(['--output=json', 'admin', 'keypair', 'list'])) as p: + p.expect(EOF) + decoded = p.before.decode() + loaded = json.loads(decoded) + keypair_list = loaded.get('items') + assert isinstance(keypair_list, list), 'List not printed properly!' + + admin_keypair = get_keypair_from_list(keypair_list, 'adminkeypair@lablup.com') + user_keypair = get_keypair_from_list(keypair_list, 'userkeypair@lablup.com') + + assert 'access_key' in admin_keypair, 'Admin keypair doesn\'t exist' + assert admin_keypair.get('is_active') is False, 'Admin keypair is_active mismatch' + assert admin_keypair.get('is_admin') is True, 'Admin keypair is_admin mismatch' + assert admin_keypair.get('rate_limit') == 25000, 'Admin keypair rate_limit mismatch' + assert admin_keypair.get('resource_policy') == 'default', 'Admin keypair resource_policy mismatch' + + assert 'access_key' in user_keypair, 'Admin keypair doesn\'t exist' + assert user_keypair.get('is_active') is True, 'User keypair is_active mismatch' + assert user_keypair.get('is_admin') is False, 'User keypair is_admin mismatch' + assert user_keypair.get('rate_limit') == 5000, 'User keypair rate_limit mismatch' + assert user_keypair.get('resource_policy') == 'default', 'User keypair resource_policy mismatch' + + +def test_update_keypair(run: ClientRunnerFunc): + """ + Test update keypair. + This test must be executed after test_add_keypair. + """ + print("[ Update keypair ]") + # Get access key + with closing(run(['--output=json', 'admin', 'keypair', 'list'])) as p: + p.expect(EOF) + decoded = p.before.decode() + loaded = json.loads(decoded) + keypair_list = loaded.get('items') + assert isinstance(keypair_list, list), 'List not printed properly!'
+ + admin_keypair = get_keypair_from_list(keypair_list, 'adminkeypair@lablup.com') + user_keypair = get_keypair_from_list(keypair_list, 'userkeypair@lablup.com') + assert 'access_key' in admin_keypair, 'Admin keypair info doesn\'t exist' + assert 'access_key' in user_keypair, 'User keypair info doesn\'t exist' + + # Update keypair + with closing(run([ + 'admin', 'keypair', 'update', + '--is-active', 'TRUE', + '--is-admin', 'FALSE', + '-r', '15000', + admin_keypair['access_key'], + ])) as p: + p.expect(EOF) + assert 'Key pair is updated:' in p.before.decode(), 'Admin keypair update error' + + with closing(run([ + 'admin', 'keypair', 'update', + '--is-active', 'FALSE', + '--is-admin', 'TRUE', + '-r', '15000', + user_keypair['access_key'], + ])) as p: + p.expect(EOF) + assert 'Key pair is updated:' in p.before.decode(), 'User keypair update error' + + # Check if keypair is updated + with closing(run(['--output=json', 'admin', 'keypair', 'list'])) as p: + p.expect(EOF) + decoded = p.before.decode() + loaded = json.loads(decoded) + updated_keypair_list = loaded.get('items') + assert isinstance(updated_keypair_list, list), 'List not printed properly!' + + updated_admin_keypair = get_keypair_from_list(updated_keypair_list, 'adminkeypair@lablup.com') + updated_user_keypair = get_keypair_from_list(updated_keypair_list, 'userkeypair@lablup.com') + + assert 'access_key' in updated_admin_keypair, 'Admin keypair doesn\'t exist' + assert updated_admin_keypair.get('is_active') is True, 'Admin keypair is_active mismatch' + assert updated_admin_keypair.get('is_admin') is False, 'Admin keypair is_admin mismatch' + assert updated_admin_keypair.get('rate_limit') == 15000, 'Admin keypair rate_limit mismatch' + assert updated_admin_keypair.get('resource_policy') == 'default', 'Admin keypair resource_policy mismatch' + + assert 'access_key' in updated_user_keypair, 'Admin keypair doesn\'t exist' + assert updated_user_keypair.get('is_active') is False, 'User keypair is_active mismatch' + assert updated_user_keypair.get('is_admin') is True, 'User keypair is_admin mismatch' + assert updated_user_keypair.get('rate_limit') == 15000, 'User keypair rate_limit mismatch' + assert updated_user_keypair.get('resource_policy') == 'default', 'User keypair resource_policy mismatch' + + +def test_delete_keypair(run: ClientRunnerFunc): + """ + Test delete keypair. + This test must be executed after test_add_keypair. + """ + print("[ Delete keypair ]") + # Get access key + with closing(run(['--output=json', 'admin', 'keypair', 'list'])) as p: + p.expect(EOF) + decoded = p.before.decode() + loaded = json.loads(decoded) + keypair_list = loaded.get('items') + assert isinstance(keypair_list, list), 'List not printed properly!' 
+ + admin_keypair = get_keypair_from_list(keypair_list, 'adminkeypair@lablup.com') + user_keypair = get_keypair_from_list(keypair_list, 'userkeypair@lablup.com') + assert 'access_key' in admin_keypair, 'Admin keypair info doesn\'t exist' + assert 'access_key' in user_keypair, 'User keypair info doesn\'t exist' + + # Delete keypair + with closing(run(['admin', 'keypair', 'delete', admin_keypair['access_key']])) as p: + p.expect(EOF) + print(p.before.decode()) + + with closing(run(['admin', 'keypair', 'delete', user_keypair['access_key']])) as p: + p.expect(EOF) + print(p.before.decode()) + + # Delete test user + with closing(run(['admin', 'user', 'purge', 'adminkeypair@lablup.com'])) as p: + p.sendline('y') + p.expect(EOF) + assert 'User is deleted:' in p.before.decode(), 'Account deletion failed: adminkeypair' + + with closing(run(['admin', 'user', 'purge', 'userkeypair@lablup.com'])) as p: + p.sendline('y') + p.expect(EOF) + assert 'User is deleted:' in p.before.decode(), 'Account deletion failed: userkeypair' + + +def test_list_keypair(run: ClientRunnerFunc): + """ + Test list keypair. + """ + with closing(run(['--output=json', 'admin', 'keypair', 'list'])) as p: + p.expect(EOF) + decoded = p.before.decode() + loaded = json.loads(decoded) + keypair_list = loaded.get('items') + assert isinstance(keypair_list, list), 'List not printed properly!' + + +def get_keypair_from_list(keypairs: list, userid: str) -> dict: + for keypair in keypairs: + if keypair.get('user_id', '') == userid: + return keypair + return {} diff --git a/src/ai/backend/test/cli_integration/admin/test_keypair_resource_policy.py b/src/ai/backend/test/cli_integration/admin/test_keypair_resource_policy.py new file mode 100644 index 0000000000..3fb897b34e --- /dev/null +++ b/src/ai/backend/test/cli_integration/admin/test_keypair_resource_policy.py @@ -0,0 +1,116 @@ +import json +from contextlib import closing + +from ...utils.cli import EOF, ClientRunnerFunc + + +def test_add_keypair_resource_policy(run: ClientRunnerFunc): + print("[ Add keypair resource policy ]") + + # Add keypair resource policy + add_arguments = [ + 'admin', 'keypair-resource-policy', 'add', + '--default-for-unspecified', 'LIMITED', + '--total-resource-slots', '{}', + '--max-concurrent-sessions', '20', + '--max-containers-per-session', '2', + '--max-vfolder-count', '15', + '--max-vfolder-size', '0', + '--allowed-vfolder-hosts', 'local:volume1', + '--idle-timeout', '1200', + 'test_krp', + ] + with closing(run(add_arguments)) as p: + p.expect(EOF) + assert 'Keypair resource policy test_krp is created.' 
in p.before.decode(), \ + 'Keypair resource policy creation not successful' + + # Check if keypair resource policy is created + with closing(run(['--output=json', 'admin', 'keypair-resource-policy', 'list'])) as p: + p.expect(EOF) + decoded = p.before.decode() + loaded = json.loads(decoded) + krp_list = loaded.get('items') + assert isinstance(krp_list, list), 'Keypair resource policy list not printed properly' + + test_krp = get_keypair_resource_policy_from_list(krp_list, 'test_krp') + + assert bool(test_krp), 'Test keypair resource policy doesn\'t exist' + assert test_krp.get('total_resource_slots') == '{}', 'Test keypair resource policy total resource slot mismatch' + assert test_krp.get('max_concurrent_sessions') == 20, 'Test keypair resource policy max concurrent session mismatch' + assert test_krp.get('max_vfolder_count') == 15, 'Test keypair resource policy max vfolder count mismatch' + assert test_krp.get('max_vfolder_size') == '0 Bytes', 'Test keypair resource policy max vfolder size mismatch' + assert test_krp.get('idle_timeout') == 1200, 'Test keypair resource policy idle timeout mismatch' + assert test_krp.get('max_containers_per_session') == 2,\ + 'Test keypair resource policy max containers per session mismatch' + assert test_krp.get('allowed_vfolder_hosts') == ['local:volume1'], \ + 'Test keypair resource policy allowed vfolder hosts mismatch' + + +def test_update_keypair_resource_policy(run: ClientRunnerFunc): + print("[ Update keypair resource policy ]") + + # Update keypair resource policy + add_arguments = [ + 'admin', 'keypair-resource-policy', 'update', + '--default-for-unspecified', 'UNLIMITED', + '--total-resource-slots', '{}', + '--max-concurrent-sessions', '30', + '--max-containers-per-session', '1', + '--max-vfolder-count', '10', + '--max-vfolder-size', '0', + '--allowed-vfolder-hosts', 'local:volume2', + '--idle-timeout', '1800', + 'test_krp', + ] + with closing(run(add_arguments)) as p: + p.expect(EOF) + assert 'Update succeeded.'
in p.before.decode(), 'Keypair resource policy update not successful' + + # Check if keypair resource policy is updated + with closing(run(['--output=json', 'admin', 'keypair-resource-policy', 'list'])) as p: + p.expect(EOF) + decoded = p.before.decode() + loaded = json.loads(decoded) + krp_list = loaded.get('items') + assert isinstance(krp_list, list), 'Keypair resource policy list not printed properly' + + test_krp = get_keypair_resource_policy_from_list(krp_list, 'test_krp') + + assert bool(test_krp), 'Test keypair resource policy doesn\'t exist' + assert test_krp.get('total_resource_slots') == '{}', 'Test keypair resource policy total resource slot mismatch' + assert test_krp.get('max_concurrent_sessions') == 30, 'Test keypair resource policy max concurrent session mismatch' + assert test_krp.get('max_vfolder_count') == 10, 'Test keypair resource policy max vfolder count mismatch' + assert test_krp.get('max_vfolder_size') == '0 Bytes', 'Test keypair resource policy max vfolder size mismatch' + assert test_krp.get('idle_timeout') == 1800, 'Test keypair resource policy idle timeout mismatch' + assert test_krp.get('max_containers_per_session') == 1,\ + 'Test keypair resource policy max containers per session mismatch' + assert test_krp.get('allowed_vfolder_hosts') == ['local:volume2'], \ + 'Test keypair resource policy allowed vfolder hosts mismatch' + + +def test_delete_keypair_resource_policy(run: ClientRunnerFunc): + print("[ Delete keypair resource policy ]") + + # Delete keypair resource policy + with closing(run(['admin', 'keypair-resource-policy', 'delete', 'test_krp'])) as p: + p.sendline('y') + p.expect(EOF) + assert 'Resource policy test_krp is deleted.' in p.before.decode(), 'Keypair resource policy deletion failed' + + +def test_list_keypair_resource_policy(run: ClientRunnerFunc): + print("[ List keypair resource policy ]") + with closing(run(['--output=json', 'admin', 'keypair-resource-policy', 'list'])) as p: + p.expect(EOF) + decoded = p.before.decode() + loaded = json.loads(decoded) + krp_list = loaded.get('items') + assert isinstance(krp_list, list), 'Keypair resource policy list not printed properly' + + +def get_keypair_resource_policy_from_list(krps: list, name: str) -> dict: + for krp in krps: + if krp.get('name') == name: + return krp + return {} diff --git a/src/ai/backend/test/cli_integration/admin/test_scaling_group.py b/src/ai/backend/test/cli_integration/admin/test_scaling_group.py new file mode 100644 index 0000000000..ffac82965d --- /dev/null +++ b/src/ai/backend/test/cli_integration/admin/test_scaling_group.py @@ -0,0 +1,101 @@ +import json +from contextlib import closing + +from ...utils.cli import EOF, ClientRunnerFunc + + +def test_add_scaling_group(run: ClientRunnerFunc): + # Create scaling group + with closing(run([ + 'admin', 'scaling-group', 'add', + '-d', 'Test scaling group', + '-i', + '--driver', 'static', + '--driver-opts', '{"x": 1}', + '--scheduler', 'fifo', + 'test_group1', + ])) as p: + p.expect(EOF) + assert 'Scaling group name test_group1 is created.'
in p.before.decode(), \ + 'Test scaling group not created successfully' + + # Check if scaling group is created + with closing(run(['--output=json', 'admin', 'scaling-group', 'list'])) as p: + p.expect(EOF) + decoded = p.before.decode() + loaded = json.loads(decoded) + scaling_group_list = loaded.get('items') + assert isinstance(scaling_group_list, list), 'Scaling group list not printed properly' + + test_group = get_scaling_group_from_list(scaling_group_list, 'test_group1') + assert bool(test_group), 'Test scaling group doesn\'t exist' + + # Get the full detail. + with closing(run(['--output=json', 'admin', 'scaling-group', 'info', 'test_group1'])) as p: + p.expect(EOF) + decoded = p.before.decode() + loaded = json.loads(decoded) + scaling_group_list = loaded.get('items') + assert isinstance(scaling_group_list, list), 'Scaling group info not printed properly' + + test_group = get_scaling_group_from_list(scaling_group_list, 'test_group1') + assert test_group.get('description') == 'Test scaling group', 'Scaling group description mismatch' + assert test_group.get('is_active') is False, 'Scaling group active status mismatch' + assert test_group.get('driver') == 'static', 'Scaling group driver mismatch' + assert test_group.get('driver_opts') == {'x': 1}, 'Scaling group driver options mismatch' + assert test_group.get('scheduler') == 'fifo', 'Scaling group scheduler mismatch' + assert test_group.get('scheduler_opts') == {}, 'Scaling group scheduler options mismatch' + + +def test_update_scaling_group(run: ClientRunnerFunc): + # Update scaling group + with closing(run([ + 'admin', 'scaling-group', 'update', + '-d', 'Test scaling group updated', + '--driver', 'non-static', + '--scheduler', 'lifo', + 'test_group1', + ])) as p: + p.expect(EOF) + assert 'Scaling group test_group1 is updated.' in p.before.decode(), \ + 'Test scaling group not updated successfully' + + # Check if scaling group is updated + with closing(run(['--output=json', 'admin', 'scaling-group', 'info', 'test_group1'])) as p: + p.expect(EOF) + decoded = p.before.decode() + loaded = json.loads(decoded) + scaling_group_list = loaded.get('items') + assert isinstance(scaling_group_list, list), 'Scaling group list not printed properly' + + test_group = get_scaling_group_from_list(scaling_group_list, 'test_group1') + + assert bool(test_group), 'Test scaling group doesn\'t exist' + assert test_group.get('description') == 'Test scaling group updated', 'Scaling group description mismatch' + assert test_group.get('is_active') is True, 'Scaling group active status mismatch' + assert test_group.get('driver') == 'non-static', 'Scaling group driver mismatch' + assert test_group.get('driver_opts') == {'x': 1}, 'Scaling group driver options mismatch' + assert test_group.get('scheduler') == 'lifo', 'Scaling group scheduler mismatch' + assert test_group.get('scheduler_opts') == {}, 'Scaling group scheduler options mismatch' + + +def test_delete_scaling_group(run: ClientRunnerFunc): + with closing(run(['admin', 'scaling-group', 'delete', 'test_group1'])) as p: + p.expect(EOF) + assert 'Scaling group is deleted: test_group1.' 
in p.before.decode(), 'Test scaling group deletion unsuccessful' + + +def test_list_scaling_group(run: ClientRunnerFunc): + with closing(run(['--output=json', 'admin', 'scaling-group', 'list'])) as p: + p.expect(EOF) + decoded = p.before.decode() + loaded = json.loads(decoded) + scaling_group_list = loaded.get('items') + assert isinstance(scaling_group_list, list), 'Scaling group list not printed properly' + + +def get_scaling_group_from_list(scaling_groups: list, groupname: str) -> dict: + for sg in scaling_groups: + if sg.get('name') == groupname: + return sg + return {} diff --git a/src/ai/backend/test/cli_integration/admin/test_storage.py b/src/ai/backend/test/cli_integration/admin/test_storage.py new file mode 100644 index 0000000000..2f561deb50 --- /dev/null +++ b/src/ai/backend/test/cli_integration/admin/test_storage.py @@ -0,0 +1,30 @@ +import json +from contextlib import closing + +from ...utils.cli import EOF, ClientRunnerFunc + + +def test_list_storage(run: ClientRunnerFunc): + """ + Test list storage. + """ + print("[ List storage ]") + with closing(run(['--output=json', 'admin', 'storage', 'list'])) as p: + p.expect(EOF) + decoded = p.before.decode() + loaded = json.loads(decoded) + storage_list = loaded.get('items') + assert isinstance(storage_list, list), 'Storage list not printed properly' + + +def test_info_storage(run: ClientRunnerFunc): + """ + Test storage info. + """ + print("[ Print storage info ]") + with closing(run(['--output=json', 'admin', 'storage', 'info', 'local:volume1'])) as p: + p.expect(EOF) + decoded = p.before.decode() + loaded = json.loads(decoded) + storage_list = loaded.get('items') + assert isinstance(storage_list, list), 'Storage info not printed properly' diff --git a/src/ai/backend/test/cli_integration/admin/test_user.py b/src/ai/backend/test/cli_integration/admin/test_user.py new file mode 100644 index 0000000000..0fa5d2dcf4 --- /dev/null +++ b/src/ai/backend/test/cli_integration/admin/test_user.py @@ -0,0 +1,206 @@ +import json +from contextlib import closing + +from ...utils.cli import EOF, ClientRunnerFunc + + +def test_add_user(run: ClientRunnerFunc): + """ + Testcase for user addition. 
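+    Creates three test accounts (testaccount1, testaccount2, testaccount3) unless they already exist
+    and verifies their attributes from the user list output.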
+ """ + print("[ Add user ]") + + # Check if test account exists + with closing(run(['--output=json', 'admin', 'user', 'list'])) as p: + p.expect(EOF) + decoded = p.before.decode() + loaded = json.loads(decoded) + user_list = loaded.get('items') + + test_user1 = get_user_from_list(user_list, 'testaccount1') + test_user2 = get_user_from_list(user_list, 'testaccount2') + test_user3 = get_user_from_list(user_list, 'testaccount3') + + if not bool(test_user1): + # Add user + add_arguments = [ + '--output=json', 'admin', 'user', 'add', + '-u', 'testaccount1', + '-n', 'John Doe', + '--need-password-change', 'default', + 'testaccount1@lablup.com', + '1q2w3e4r', + ] + with closing(run(add_arguments)) as p: + p.expect(EOF) + assert 'User testaccount1@lablup.com is created' in p.before.decode(), 'Account add error' + + if not bool(test_user2): + # Add user + add_arguments = [ + '--output=json', 'admin', 'user', 'add', + '-u', 'testaccount2', + '-n', 'John Roe', + '-r', 'admin', + '-s', 'inactive', + 'default', + 'testaccount2@lablup.com', + '1q2w3e4r', + ] + with closing(run(add_arguments)) as p: + p.expect(EOF) + assert 'User testaccount2@lablup.com is created' in p.before.decode(), 'Account add error' + + if not bool(test_user3): + # Add user + add_arguments = [ + '--output=json', 'admin', 'user', 'add', + '-u', 'testaccount3', + '-n', 'Richard Roe', + '-r', 'monitor', + '-s', 'before-verification', + '--need-password-change', 'default', + 'testaccount3@lablup.com', + '1q2w3e4r', + ] + with closing(run(add_arguments)) as p: + p.expect(EOF) + assert 'User testaccount3@lablup.com is created' in p.before.decode(), 'Account add error' + + # Check if user is added + with closing(run(['--output=json', 'admin', 'user', 'list'])) as p: + p.expect(EOF) + decoded = p.before.decode() + loaded = json.loads(decoded) + user_list = loaded.get('items') + + assert isinstance(user_list, list), 'Expected user list' + added_user1 = get_user_from_list(user_list, 'testaccount1') + added_user2 = get_user_from_list(user_list, 'testaccount2') + added_user3 = get_user_from_list(user_list, 'testaccount3') + + assert bool(added_user1), 'Added account doesn\'t exist: Account#1' + assert added_user1.get('email') == 'testaccount1@lablup.com', 'E-mail mismatch: Account#1' + assert added_user1.get('full_name') == 'John Doe', 'Full name mismatch: Account#1' + assert added_user1.get('status') == 'active', 'User status mismatch: Account#1' + assert added_user1.get('role') == 'user', 'Role mismatch: Account#1' + assert added_user1.get('need_password_change') is True, 'Password change status mismatch: Account#1' + + assert bool(added_user2), 'Added account doesn\'t exist: Account#2' + assert added_user2.get('email') == 'testaccount2@lablup.com', 'E-mail mismatch: Account#2' + assert added_user2.get('full_name') == 'John Roe', 'Full name mismatch: Account#2' + assert added_user2.get('status') == 'inactive', 'User status mismatch: Account#2' + assert added_user2.get('role') == 'admin', 'Role mismatch: Account#2' + assert added_user2.get('need_password_change') is False, 'Password change status mismatch: Account#2' + + assert bool(added_user3), 'Added account doesn\'t exist: Account#3' + assert added_user3.get('email') == 'testaccount3@lablup.com', 'E-mail mismatch: Account#3' + assert added_user3.get('full_name') == 'Richard Roe', 'Full name mismatch: Account#3' + assert added_user3.get('status') == 'before-verification', 'User status mismatch: Account#3' + assert added_user3.get('role') == 'monitor', 'Role mismatch: Account#3' + assert 
added_user3.get('need_password_change') is True, 'Password change status mismatch: Account#3' + + +def test_update_user(run: ClientRunnerFunc): + """ + Run this testcase after test_add_user. + Testcase for user update. + TODO: User update with roles is not fully covered yet. + """ + print("[ Update user ]") + + # Check if user exists + with closing(run(['--output=json', 'admin', 'user', 'list'])) as p: + p.expect(EOF) + decoded = p.before.decode() + loaded = json.loads(decoded) + user_list = loaded.get('items') + assert isinstance(user_list, list), 'Expected user list' + + # Update user + update_arguments = ['--output=json', 'admin', 'user', 'update', '-u', 'testaccount123', '-n', 'Foo Bar', '-s', + 'inactive', '-d', 'default', 'testaccount1@lablup.com'] + with closing(run(update_arguments)) as p: + p.expect(EOF) + + update_arguments = ['--output=json', 'admin', 'user', 'update', '-u', 'testaccount231', '-n', 'Baz Quz', '-s', + 'active', '-r', 'admin', '--need-password-change', 'testaccount2@lablup.com'] + with closing(run(update_arguments)) as p: + p.expect(EOF) + + update_arguments = ['--output=json', 'admin', 'user', 'update', '-u', 'testaccount312', '-n', 'Alice B.', '-s', + 'active', '-r', 'monitor', 'testaccount3@lablup.com'] + with closing(run(update_arguments)) as p: + p.expect(EOF) + + # Check if user is updated correctly + with closing(run(['--output=json', 'admin', 'user', 'list'])) as p: + p.expect(EOF) + after_update_decoded = p.before.decode() + after_update_loaded = json.loads(after_update_decoded) + updated_user_list = after_update_loaded.get('items') + assert isinstance(updated_user_list, list), 'Expected user list' + + test_user1 = get_user_from_list(updated_user_list, 'testaccount123') + test_user2 = get_user_from_list(updated_user_list, 'testaccount231') + test_user3 = get_user_from_list(updated_user_list, 'testaccount312') + + assert bool(test_user1), 'Account not found - Account#1' + assert test_user1.get('full_name') == 'Foo Bar', 'Full name mismatch: Account#1' + assert test_user1.get('status') == 'inactive', 'User status mismatch: Account#1' + assert test_user1.get('role') == 'user', 'Role mismatch: Account#1' + assert test_user1.get('need_password_change') is False, 'Password change status mismatch: Account#1' + assert test_user1.get('domain_name') == 'default', 'Domain mismatch: Account#1' + + assert bool(test_user2), 'Account not found - Account#2' + assert test_user2.get('full_name') == 'Baz Quz', 'Full name mismatch: Account#2' + assert test_user2.get('status') == 'active', 'User status mismatch: Account#2' + assert test_user2.get('role') == 'admin', 'Role mismatch: Account#2' + assert test_user2.get('need_password_change') is True, 'Password change status mismatch: Account#2' + + assert bool(test_user3), 'Account not found - Account#3' + assert test_user3.get('full_name') == 'Alice B.', 'Full name mismatch: Account#3' + assert test_user3.get('status') == 'active', 'User status mismatch: Account#3' + assert test_user3.get('role') == 'monitor', 'Role mismatch: Account#3' + assert test_user3.get('need_password_change') is False, 'Password change status mismatch: Account#3' + + +def test_delete_user(run: ClientRunnerFunc): + """ + !!Run this testcase after running test_add_user + Testcase for user deletion.
+ """ + print("[ Delete user ]") + with closing(run(['admin', 'user', 'purge', 'testaccount1@lablup.com'])) as p: + p.sendline('y') + p.expect(EOF) + assert 'User is deleted:' in p.before.decode(), 'Account deletion failed: Account#1' + + with closing(run(['admin', 'user', 'purge', 'testaccount2@lablup.com'])) as p: + p.sendline('y') + p.expect(EOF) + assert 'User is deleted:' in p.before.decode(), 'Account deletion failed: Account#2' + + with closing(run(['admin', 'user', 'purge', 'testaccount3@lablup.com'])) as p: + p.sendline('y') + p.expect(EOF) + assert 'User is deleted:' in p.before.decode(), 'Account deletion failed: Account#3' + + +def test_list_user(run: ClientRunnerFunc): + """ + Testcase for user listing. + """ + with closing(run(['--output=json', 'admin', 'user', 'list'])) as p: + p.expect(EOF) + decoded = p.before.decode() + loaded = json.loads(decoded) + user_list = loaded.get('items') + assert isinstance(user_list, list) + + +def get_user_from_list(users: list, username: str) -> dict: + for user in users: + if user.get('username') == username: + return user + return {} diff --git a/src/ai/backend/test/cli_integration/conftest.py b/src/ai/backend/test/cli_integration/conftest.py new file mode 100644 index 0000000000..3aa8bdef8c --- /dev/null +++ b/src/ai/backend/test/cli_integration/conftest.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +from contextlib import closing +import os +from pathlib import Path +import re +import secrets +from typing import Iterator, Sequence + +import pexpect +import pytest + +from ai.backend.test.utils.cli import ClientRunnerFunc, EOF, run as _run + +_rx_env_export = re.compile(r"^(export )?(?P\w+)=(?P.*)$") + + +@pytest.fixture(scope="session") +def client_venv() -> Path: + p = os.environ.get("BACKENDAI_TEST_CLIENT_VENV", None) + if p is None: + raise RuntimeError("Missing BACKENDAI_TEST_CLIENT_VENV env-var!") + return Path(p) + + +@pytest.fixture(scope="session") +def client_bin( + client_venv: Path, +) -> Path: + return client_venv / 'bin' / 'backend.ai' + + +@pytest.fixture(scope="session") +def client_environ() -> dict[str, str]: + p = os.environ.get("BACKENDAI_TEST_CLIENT_ENV", None) + if p is None: + raise RuntimeError("Missing BACKENDAI_TEST_CLIENT_ENV env-var!") + envs = {} + sample_admin_sh = Path(p) + if sample_admin_sh.exists(): + lines = sample_admin_sh.read_text().splitlines() + for line in lines: + if m := _rx_env_export.search(line.strip()): + envs[m.group('key')] = m.group('val') + return envs + + +@pytest.fixture(scope="session") +def run(client_bin: Path, client_environ: dict[str, str]) -> Iterator[ClientRunnerFunc]: + + def run_impl(cmdargs: Sequence[str | Path], *args, **kwargs) -> pexpect.spawn: + return _run([client_bin, *cmdargs], *args, **kwargs, env=client_environ) + + yield run_impl + + +@pytest.fixture +def domain_name() -> str: + return f"testing-{secrets.token_hex(8)}" + + +@pytest.fixture +def temp_domain(domain_name: str, run: ClientRunnerFunc) -> Iterator[str]: + run(['admin', 'domains', 'add', domain_name]) + print("==== temp_domain created ====") + try: + yield domain_name + finally: + with closing(run(['admin', 'domains', 'purge', domain_name])) as p: + p.expect_exact("Are you sure?") + p.sendline("Y") + p.expect(EOF) diff --git a/src/ai/backend/test/cli_integration/user/__init__.py b/src/ai/backend/test/cli_integration/user/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/ai/backend/test/cli_integration/user/test_vfolder.py 
b/src/ai/backend/test/cli_integration/user/test_vfolder.py new file mode 100644 index 0000000000..e50bc65cea --- /dev/null +++ b/src/ai/backend/test/cli_integration/user/test_vfolder.py @@ -0,0 +1,102 @@ +import json +from contextlib import closing + +from ...utils.cli import EOF, ClientRunnerFunc + + +def test_create_vfolder(run: ClientRunnerFunc): + """ + Test create vfolder function. + This test should be executed first in test_vfolder.py. + TODO: Uncomment the following code after the group deletion issue is resolved. + """ + # Create group first + # with closing(run(['admin', 'group', 'add', 'default', 'testgroup'])) as p: + # p.expect(EOF) + # assert 'Group name testgroup is created in domain default' in p.before.decode(), \ + # 'Test group not created successfully.' + print("[ Create vfolder ]") + # Create vfolder + with closing(run(['vfolder', 'create', '-p', 'rw', 'test_folder1', 'local:volume1'])) as p: + p.expect(EOF) + assert 'Virtual folder "test_folder1" is created' in p.before.decode(), 'Test folder1 not created successfully.' + + with closing(run(['vfolder', 'create', '-p', 'ro', 'test_folder2', 'local:volume1'])) as p: + p.expect(EOF) + assert 'Virtual folder "test_folder2" is created' in p.before.decode(), 'Test folder2 not created successfully.' + + # Check if vfolder is created + with closing(run(['--output=json', 'vfolder', 'list'])) as p: + p.expect(EOF) + decoded = p.before.decode() + loaded = json.loads(decoded) + folder_list = loaded.get('items') + assert isinstance(folder_list, list), 'Error in listing test folders!' + + test_folder1 = get_folder_from_list(folder_list, 'test_folder1') + test_folder2 = get_folder_from_list(folder_list, 'test_folder2') + + assert bool(test_folder1), 'Test folder 1 doesn\'t exist!' + assert test_folder1.get('permission') == 'rw', 'Test folder 1 permission mismatch.' + + assert bool(test_folder2), 'Test folder 2 doesn\'t exist!' + assert test_folder2.get('permission') == 'ro', 'Test folder 2 permission mismatch.' + + +def test_rename_vfolder(run: ClientRunnerFunc): + """ + Test rename vfolder function. + !! Make sure you execute this test after test_create_vfolder !! + Otherwise, it will raise an error. + """ + print("[ Rename vfolder ]") + # Rename vfolder + with closing(run(['vfolder', 'rename', 'test_folder1', 'test_folder3'])) as p: + p.expect(EOF) + assert 'Renamed' in p.before.decode(), 'Test folder1 not renamed successfully.' + + # Check if vfolder is updated + with closing(run(['--output=json', 'vfolder', 'list'])) as p: + p.expect(EOF) + decoded = p.before.decode() + loaded = json.loads(decoded) + folder_list = loaded.get('items') + assert isinstance(folder_list, list), 'Error in listing test folders!' + + test_folder3 = get_folder_from_list(folder_list, 'test_folder3') + assert bool(test_folder3), 'Test folder 3 doesn\'t exist!' + + +def test_delete_vfolder(run: ClientRunnerFunc): + """ + Test delete vfolder function. + !! Make sure you execute this test after 1. test_create_vfolder, 2. test_rename_vfolder !! + Otherwise, it will raise an error. + """ + print("[ Delete vfolder ]") + with closing(run(['vfolder', 'delete', 'test_folder2'])) as p: + p.expect(EOF) + assert 'Deleted' in p.before.decode(), 'Test folder 2 not deleted successfully.' + + with closing(run(['vfolder', 'delete', 'test_folder3'])) as p: + p.expect(EOF) + assert 'Deleted' in p.before.decode(), 'Test folder 3 not deleted successfully.' + + +def test_list_vfolder(run: ClientRunnerFunc): + """ + Test list vfolder function.
+ """ + with closing(run(['--output=json', 'vfolder', 'list'])) as p: + p.expect(EOF) + decoded = p.before.decode() + loaded = json.loads(decoded) + folder_list = loaded.get('items') + assert isinstance(folder_list, list) + + +def get_folder_from_list(folders: list, foldername: str) -> dict: + for folder in folders: + if folder.get('name', '') == foldername: + return folder + return {} diff --git a/src/ai/backend/test/py.typed b/src/ai/backend/test/py.typed new file mode 100644 index 0000000000..48cdce8528 --- /dev/null +++ b/src/ai/backend/test/py.typed @@ -0,0 +1 @@ +placeholder diff --git a/src/ai/backend/test/utils/__init__.py b/src/ai/backend/test/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/ai/backend/test/utils/cli.py b/src/ai/backend/test/utils/cli.py new file mode 100644 index 0000000000..92997d7e71 --- /dev/null +++ b/src/ai/backend/test/utils/cli.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from pathlib import Path +from typing import ( + Protocol, + Sequence, +) + +import pexpect + + +EOF = pexpect.EOF +TIMEOUT = pexpect.TIMEOUT + + +class ClientRunnerFunc(Protocol): + + def __call__( + self, + cmdargs: Sequence[str | Path], + *args, + **kwargs, + ) -> pexpect.spawn: + pass + + +def run( + args: Sequence[str | Path], + *, + default_timeout: int = 5, + **kwargs, +) -> pexpect.spawn: + p = pexpect.spawn( + str(args[0]), + [str(arg) for arg in args[1:]], + timeout=default_timeout, + **kwargs, + ) + return p diff --git a/src/ai/backend/testutils/BUILD b/src/ai/backend/testutils/BUILD new file mode 100644 index 0000000000..1d8ef822d6 --- /dev/null +++ b/src/ai/backend/testutils/BUILD @@ -0,0 +1,6 @@ +python_test_utils( + name="lib", + sources=[ + "**/*.py", + ], +) diff --git a/src/ai/backend/testutils/__init__.py b/src/ai/backend/testutils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/ai/backend/testutils/bootstrap.py b/src/ai/backend/testutils/bootstrap.py new file mode 100644 index 0000000000..f9bec1e69c --- /dev/null +++ b/src/ai/backend/testutils/bootstrap.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +import json +import subprocess +import time +from typing import ( + Iterator, +) + +import pytest + +from ai.backend.common.types import HostPortPair + + +@pytest.fixture(scope='session') +def etcd_container() -> Iterator[tuple[str, HostPortPair]]: + # Spawn a single-node etcd container for a testing session. + proc = subprocess.run( + [ + 'docker', 'run', '-d', + '-p', ':2379', + '-p', ':4001', + 'quay.io/coreos/etcd:v3.5.4', + '/usr/local/bin/etcd', + '-advertise-client-urls', 'http://0.0.0.0:2379', + '-listen-client-urls', 'http://0.0.0.0:2379', + ], + capture_output=True, + ) + container_id = proc.stdout.decode().strip() + proc = subprocess.run( + [ + 'docker', 'inspect', container_id, + ], + capture_output=True, + ) + container_info = json.loads(proc.stdout) + host_port = int(container_info[0]['NetworkSettings']['Ports']['2379/tcp'][0]['HostPort']) + yield container_id, HostPortPair('127.0.0.1', host_port) + subprocess.run( + [ + 'docker', 'rm', '-v', '-f', container_id, + ], + capture_output=True, + ) + + +@pytest.fixture(scope='session') +def redis_container() -> Iterator[tuple[str, HostPortPair]]: + # Spawn a single-node etcd container for a testing session. 
+ proc = subprocess.run( + [ + 'docker', 'run', '-d', + '-p', ':6379', + 'redis:6.2-alpine', + ], + capture_output=True, + ) + container_id = proc.stdout.decode().strip() + proc = subprocess.run( + [ + 'docker', 'inspect', container_id, + ], + capture_output=True, + ) + container_info = json.loads(proc.stdout) + host_port = int(container_info[0]['NetworkSettings']['Ports']['6379/tcp'][0]['HostPort']) + yield container_id, HostPortPair('127.0.0.1', host_port) + subprocess.run( + [ + 'docker', 'rm', '-v', '-f', container_id, + ], + capture_output=True, + ) + + +@pytest.fixture(scope='session') +def postgres_container() -> Iterator[tuple[str, HostPortPair]]: + # Spawn a single-node PostgreSQL container for a testing session. + proc = subprocess.run( + [ + 'docker', 'run', '-d', + '-p', ':5432', + '-e', 'POSTGRES_PASSWORD=develove', + '-e', 'POSTGRES_DB=testing', + '--health-cmd', 'pg_isready -U postgres', + '--health-interval', '1s', + '--health-start-period', '2s', + 'postgres:13.6-alpine', + ], + capture_output=True, + ) + container_id = proc.stdout.decode().strip() + host_port = 0 + while host_port == 0: + proc = subprocess.run( + [ + 'docker', 'inspect', container_id, + ], + capture_output=True, + ) + container_info = json.loads(proc.stdout) + if container_info[0]['State']['Health']['Status'].lower() != 'healthy': + time.sleep(0.2) + continue + host_port = int(container_info[0]['NetworkSettings']['Ports']['5432/tcp'][0]['HostPort']) + yield container_id, HostPortPair('127.0.0.1', host_port) + subprocess.run( + [ + 'docker', 'rm', '-v', '-f', container_id, + ], + capture_output=True, + ) diff --git a/src/ai/backend/testutils/pants.py b/src/ai/backend/testutils/pants.py new file mode 100644 index 0000000000..6204b1cc68 --- /dev/null +++ b/src/ai/backend/testutils/pants.py @@ -0,0 +1,9 @@ +from __future__ import annotations + +import functools +import os + + +@functools.lru_cache +def get_parallel_slot() -> int: + return int(os.environ.get('BACKEND_TEST_EXEC_SLOT', '0')) diff --git a/src/ai/backend/web/BUILD b/src/ai/backend/web/BUILD new file mode 100644 index 0000000000..2274a532a0 --- /dev/null +++ b/src/ai/backend/web/BUILD @@ -0,0 +1,45 @@ +python_sources( + name="service", + sources=["**/*.py"], + dependencies=[ + "src/ai/backend/client:lib", + ":resources", + ], +) + +pex_binary( + name="server", + dependencies=[ + ":service", + ], + entry_point="server.py", +) + +python_distribution( + name="dist", + dependencies=[ + ":service", + "!!stubs/trafaret:stubs", + ], + provides=python_artifact( + name="backend.ai-webserver", + description="Backend.AI WebUI Host", + license="LGPLv3", + ), + entry_points={}, + generate_setup=True, + tags=["wheel"], +) + +resource(name="version", source="VERSION") + +resources( + name="resources", + dependencies=[ + ":version", + ], + sources=[ + "**/py.typed", + "static/**/*", + ], +) diff --git a/src/ai/backend/web/README.md b/src/ai/backend/web/README.md new file mode 100644 index 0000000000..cc9b05bbdb --- /dev/null +++ b/src/ai/backend/web/README.md @@ -0,0 +1,64 @@ +# Backend.AI Web Server + +[![GitHub version](https://badge.fury.io/gh/lablup%2Fbackend.ai-webserver.svg)](https://badge.fury.io/gh/lablup%2Fbackend.ai-webserver) [![PyPI version](https://badge.fury.io/py/backend.ai-webserver.svg)](https://badge.fury.io/py/backend.ai-webserver) + +A webapp hosting daemon that serves our `webui` as a SPA and proxies API requests. + + +## Installation + +Prepare a Python virtualenv (Python 3.9 or higher) and a Redis server (6.2 or higher).
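+For example, one possible way to prepare them (assuming Python 3.10 and Docker are available on the host; adjust to your own environment):
+
+```console
+$ python3.10 -m venv venv
+$ source venv/bin/activate
+$ docker run -d -p 6379:6379 redis:6.2-alpine
+```
+
+Then clone and install the webserver itself as shown below: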
+ +```console +$ git clone https://github.com/lablup/backend.ai-webserver webserver +$ cd webserver +$ pip install -U -e . +$ cp webserver.sample.conf webserver.conf +``` + +## Mode + +If `service.mode` is set to "webui" (the default), the webserver handles +PWA-style fallbacks (e.g., serving `index.html` when there are no matching +files for the requested URL path). +The PWA must exclude the `/server` and `/func` URL prefixes from its own routing +to work with the webserver's web sessions and the API proxy. + +If it is set to "static", the webserver serves the static files as-is, +without any fallbacks or hooking, while preserving the `/server` and `/func` +prefixed URLs and their functionalities. + +If you want to serve the web UI from the webserver in "webui" mode, prepare the static web UI sources by choosing one of the following options. + +### Option 1: Build web UI from source + +Build **[backend.ai-webui](https://github.com/lablup/backend.ai-webui)** and copy all files under `build/bundle` +into the `src/ai/backend/web/static` directory. + +### Option 2: Use pre-built web UI + +To download and deploy the web UI from the pre-built source, do the following: + +```console +git submodule init +git submodule update +cd src/ai/backend/web/static +git checkout main # or target branch +git fetch +git pull +``` + +### Setup configuration for webserver + +You don't have to write `config.toml` for the web UI as this webserver auto-generates it on-the-fly. + +Edit `webserver.conf` to match your environment. + + +## Usage + +To run the web server, use the command below (append the `--debug` flag for more verbose logging). + + +```console +$ python -m ai.backend.web.server +``` diff --git a/src/ai/backend/web/VERSION b/src/ai/backend/web/VERSION new file mode 120000 index 0000000000..a4e948506b --- /dev/null +++ b/src/ai/backend/web/VERSION @@ -0,0 +1 @@ +../../../../VERSION \ No newline at end of file diff --git a/src/ai/backend/web/__init__.py b/src/ai/backend/web/__init__.py new file mode 100644 index 0000000000..7fc3c57564 --- /dev/null +++ b/src/ai/backend/web/__init__.py @@ -0,0 +1,5 @@ +from pathlib import Path + +__version__ = (Path(__file__).parent / 'VERSION').read_text().strip() + +user_agent = f'Backend.AI Web Server {__version__}' diff --git a/src/ai/backend/web/auth.py b/src/ai/backend/web/auth.py new file mode 100644 index 0000000000..08f269ca43 --- /dev/null +++ b/src/ai/backend/web/auth.py @@ -0,0 +1,66 @@ +import copy +import json + +from aiohttp import web +from aiohttp_session import get_session + +from ai.backend.client.session import AsyncSession as APISession +from ai.backend.client.config import APIConfig + +from . 
import user_agent + + +async def get_api_session( + request: web.Request, + api_endpoint: str = None, +) -> APISession: + config = request.app['config'] + if api_endpoint is not None: + config = copy.deepcopy(config) + config['api']['endpoint'] = api_endpoint + session = await get_session(request) + if not session.get('authenticated', False): + raise web.HTTPUnauthorized(text=json.dumps({ + 'type': 'https://api.backend.ai/probs/auth-failed', + 'title': 'Unauthorized access', + }), content_type='application/problem+json') + if 'token' not in session: + raise web.HTTPUnauthorized(text=json.dumps({ + 'type': 'https://api.backend.ai/probs/auth-failed', + 'title': 'Unauthorized access', + }), content_type='application/problem+json') + token = session['token'] + if token['type'] != 'keypair': + raise web.HTTPBadRequest(text=json.dumps({ + 'type': 'https://api.backend.ai/probs/invalid-auth-params', + 'title': 'Incompatible auth token type.', + }), content_type='application/problem+json') + ak, sk = token['access_key'], token['secret_key'] + config = APIConfig( + domain=config['api']['domain'], + endpoint=config['api']['endpoint'], + access_key=ak, + secret_key=sk, + user_agent=user_agent, + skip_sslcert_validation=not config['api'].get('ssl-verify', True), + ) + return APISession(config=config, proxy_mode=True) + + +async def get_anonymous_session( + request: web.Request, + api_endpoint: str = None, +) -> APISession: + config = request.app['config'] + if api_endpoint is not None: + config = copy.deepcopy(config) + config['api']['endpoint'] = api_endpoint + config = APIConfig( + domain=config['api']['domain'], + endpoint=config['api']['endpoint'], + access_key='', + secret_key='', + user_agent=user_agent, + skip_sslcert_validation=not config['api'].get('ssl-verify', True), + ) + return APISession(config=config, proxy_mode=True) diff --git a/src/ai/backend/web/logging.py b/src/ai/backend/web/logging.py new file mode 100644 index 0000000000..05bd6b0ce5 --- /dev/null +++ b/src/ai/backend/web/logging.py @@ -0,0 +1,24 @@ +import logging + + +class BraceMessage: + + __slots__ = ('fmt', 'args') + + def __init__(self, fmt, args): + self.fmt = fmt + self.args = args + + def __str__(self): + return self.fmt.format(*self.args) + + +class BraceStyleAdapter(logging.LoggerAdapter): + + def __init__(self, logger, extra=None): + super().__init__(logger, extra) + + def log(self, level, msg, *args, **kwargs): + if self.isEnabledFor(level): + msg, kwargs = self.process(msg, kwargs) + self.logger._log(level, BraceMessage(msg, args), (), **kwargs) diff --git a/src/ai/backend/web/proxy.py b/src/ai/backend/web/proxy.py new file mode 100644 index 0000000000..058ea627c3 --- /dev/null +++ b/src/ai/backend/web/proxy.py @@ -0,0 +1,316 @@ +from __future__ import annotations + +import asyncio +import logging +import json +import random +from typing import ( + Optional, Union, + Tuple, + cast, +) + +import aiohttp +from aiohttp import web +from aiohttp_session import get_session, STORAGE_KEY + +from ai.backend.client.exceptions import BackendAPIError, BackendClientError +from ai.backend.client.request import Request + +from .auth import get_api_session, get_anonymous_session +from .logging import BraceStyleAdapter + +log = BraceStyleAdapter(logging.getLogger('ai.backend.console.proxy')) + +HTTP_HEADERS_TO_FORWARD = [ + 'Accept-Language', +] + + +class WebSocketProxy: + __slots__ = ( + 'up_conn', 'down_conn', + 'upstream_buffer', 'upstream_buffer_task', + ) + + up_conn: aiohttp.ClientWebSocketResponse + down_conn: 
web.WebSocketResponse + upstream_buffer: asyncio.Queue[Tuple[Union[str, bytes], aiohttp.WSMsgType]] + upstream_buffer_task: Optional[asyncio.Task] + + def __init__(self, up_conn: aiohttp.ClientWebSocketResponse, + down_conn: web.WebSocketResponse) -> None: + self.up_conn = up_conn + self.down_conn = down_conn + self.upstream_buffer = asyncio.Queue() + self.upstream_buffer_task = None + + async def proxy(self) -> None: + asyncio.ensure_future(self.downstream()) + await self.upstream() + + async def upstream(self) -> None: + try: + async for msg in self.down_conn: + if msg.type in (aiohttp.WSMsgType.TEXT, aiohttp.WSMsgType.BINARY): + await self.send(msg.data, msg.type) + elif msg.type == aiohttp.WSMsgType.ERROR: + log.error("WebSocketProxy: connection closed with exception {}", + self.up_conn.exception()) + break + elif msg.type == aiohttp.WSMsgType.CLOSE: + break + # here, client gracefully disconnected + except asyncio.CancelledError: + # here, client forcibly disconnected + pass + finally: + await self.close_downstream() + + async def downstream(self) -> None: + try: + self.upstream_buffer_task = \ + asyncio.create_task(self.consume_upstream_buffer()) + async for msg in self.up_conn: + if msg.type == aiohttp.WSMsgType.TEXT: + await self.down_conn.send_str(msg.data) + elif msg.type == aiohttp.WSMsgType.BINARY: + await self.down_conn.send_bytes(msg.data) + elif msg.type == aiohttp.WSMsgType.CLOSED: + break + elif msg.type == aiohttp.WSMsgType.ERROR: + break + # here, server gracefully disconnected + except asyncio.CancelledError: + pass + except Exception as e: + log.error('WebSocketProxy: unexpected error: {}', e) + finally: + await self.close_upstream() + + async def consume_upstream_buffer(self) -> None: + try: + while True: + data, tp = await self.upstream_buffer.get() + if not self.up_conn.closed: + if tp == aiohttp.WSMsgType.BINARY: + await self.up_conn.send_bytes(cast(bytes, data)) + elif tp == aiohttp.WSMsgType.TEXT: + await self.up_conn.send_str(cast(str, data)) + except asyncio.CancelledError: + pass + + async def send(self, msg: str, tp: aiohttp.WSMsgType) -> None: + await self.upstream_buffer.put((msg, tp)) + + async def close_downstream(self) -> None: + if not self.down_conn.closed: + await self.down_conn.close() + + async def close_upstream(self) -> None: + if self.upstream_buffer_task is not None and not self.upstream_buffer_task.done(): + self.upstream_buffer_task.cancel() + await self.upstream_buffer_task + if not self.up_conn.closed: + await self.up_conn.close() + + +async def web_handler(request, *, is_anonymous=False) -> web.StreamResponse: + path = request.match_info.get('path', '') + if is_anonymous: + api_session = await asyncio.shield(get_anonymous_session(request)) + else: + api_session = await asyncio.shield(get_api_session(request)) + try: + async with api_session: + # We perform request signing by ourselves using the HTTP session data, + # but need to keep the client's version header so that + # the final clients may perform its own API versioning support. + request_api_version = request.headers.get('X-BackendAI-Version', None) + # Send X-Forwarded-For header for token authentication with the client IP. + client_ip = request.headers.get('X-Forwarded-For') + if not client_ip: + client_ip = request.remote + _headers = {'X-Forwarded-For': client_ip} + api_session.aiohttp_session.headers.update(_headers) + # Deliver cookie for token-based authentication. 
+ api_session.aiohttp_session.cookie_jar.update_cookies(request.cookies) + # We treat all requests and responses as streaming universally + # to be a transparent proxy. + api_rqst = Request( + request.method, path, request.content, + params=request.query, + override_api_version=request_api_version) + if 'Content-Type' in request.headers: + api_rqst.content_type = request.content_type # set for signing + api_rqst.headers['Content-Type'] = request.headers['Content-Type'] # preserve raw value + if 'Content-Length' in request.headers: + api_rqst.headers['Content-Length'] = request.headers['Content-Length'] + for hdr in HTTP_HEADERS_TO_FORWARD: + if request.headers.get(hdr) is not None: + api_rqst.headers[hdr] = request.headers[hdr] + # Uploading request body happens at the entering of the block, + # and downloading response body happens in the read loop inside. + async with api_rqst.fetch() as up_resp: + down_resp = web.StreamResponse() + down_resp.set_status(up_resp.status, up_resp.reason) + down_resp.headers.update(up_resp.headers) + # We already have configured CORS handlers and the API server + # also provides those headers. Just let them as-is. + await down_resp.prepare(request) + while True: + chunk = await up_resp.read(8192) + if not chunk: + break + await down_resp.write(chunk) + return down_resp + except asyncio.CancelledError: + raise + except BackendAPIError as e: + return web.Response(body=json.dumps(e.data), + content_type="application/problem+json", + status=e.status, reason=e.reason) + except BackendClientError: + log.exception('web_handler: BackendClientError') + return web.HTTPBadGateway(text=json.dumps({ + 'type': 'https://api.backend.ai/probs/bad-gateway', + 'title': "The proxy target server is inaccessible.", + }), content_type='application/problem+json') + except Exception: + log.exception('web_handler: unexpected error') + return web.HTTPInternalServerError(text=json.dumps({ + 'type': 'https://api.backend.ai/probs/internal-server-error', + 'title': "Something has gone wrong.", + }), content_type='application/problem+json') + finally: + await api_session.close() + + +async def web_plugin_handler(request, *, is_anonymous=False) -> web.StreamResponse: + """ + This handler is almost same to web_handler, but does not manipulate the + content-type and content-length headers before sending up-requests. + It also configures the domain in the json body for "auth/signup" requests. + """ + path = request.match_info['path'] + if is_anonymous: + api_session = await asyncio.shield(get_anonymous_session(request)) + else: + api_session = await asyncio.shield(get_api_session(request)) + try: + async with api_session: + content = request.content + if path == 'auth/signup': + body = await request.json() + body['domain'] = request.app['config']['api']['domain'] + content = json.dumps(body).encode('utf8') + request_api_version = request.headers.get('X-BackendAI-Version', None) + # Send X-Forwarded-For header for token authentication with the client IP. + client_ip = request.headers.get('X-Forwarded-For') + if not client_ip: + client_ip = request.remote + _headers = {'X-Forwarded-For': client_ip} + api_session.aiohttp_session.headers.update(_headers) + # Deliver cookie for token-based authentication. 
+ api_session.aiohttp_session.cookie_jar.update_cookies(request.cookies) + api_rqst = Request( + request.method, path, content, + params=request.query, + content_type=request.content_type, + override_api_version=request_api_version) + for hdr in HTTP_HEADERS_TO_FORWARD: + if request.headers.get(hdr) is not None: + api_rqst.headers[hdr] = request.headers[hdr] + async with api_rqst.fetch() as up_resp: + down_resp = web.StreamResponse() + down_resp.set_status(up_resp.status, up_resp.reason) + down_resp.headers.update(up_resp.headers) + # We already have configured CORS handlers and the API server + # also provides those headers. Just let them as-is. + await down_resp.prepare(request) + while True: + chunk = await up_resp.read(8192) + if not chunk: + break + await down_resp.write(chunk) + return down_resp + except asyncio.CancelledError: + raise + except BackendAPIError as e: + return web.Response(body=json.dumps(e.data), + content_type='application/problem+json', + status=e.status, reason=e.reason) + except BackendClientError: + log.exception('web_plugin_handler: BackendClientError') + return web.HTTPBadGateway(text=json.dumps({ + 'type': 'https://api.backend.ai/probs/bad-gateway', + 'title': "The proxy target server is inaccessible.", + }), content_type='application/problem+json') + except Exception: + log.exception('web_plugin_handler: unexpected error') + return web.HTTPInternalServerError(text=json.dumps({ + 'type': 'https://api.backend.ai/probs/internal-server-error', + 'title': "Something has gone wrong.", + }), content_type='application/problem+json') + + +async def websocket_handler(request, *, is_anonymous=False) -> web.StreamResponse: + path = request.match_info['path'] + session = await get_session(request) + app = request.query.get('app') + + # Choose a specific Manager endpoint for persistent web app connection. 
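+    # The api.endpoint config may contain multiple manager endpoints separated by commas.
+    # Reuse the endpoint previously recorded for this app in the web session when it is still
+    # listed; otherwise pick one at random and remember it so that later websocket connections
+    # for the same app stick to the same manager.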
+ api_endpoint = None + should_save_session = False + _endpoints = request.app['config']['api']['endpoint'].split(',') + _endpoints = [e.strip() for e in _endpoints] + if session.get('api_endpoints', {}).get(app): + if session['api_endpoints'][app] in _endpoints: + api_endpoint = session['api_endpoints'][app] + if api_endpoint is None: + api_endpoint = random.choice(_endpoints) + if 'api_endpoints' not in session: + session['api_endpoints'] = {} + session['api_endpoints'][app] = api_endpoint + should_save_session = True + + if is_anonymous: + api_session = await asyncio.shield(get_anonymous_session(request, api_endpoint)) + else: + api_session = await asyncio.shield(get_api_session(request, api_endpoint)) + try: + async with api_session: + request_api_version = request.headers.get('X-BackendAI-Version', None) + params = request.query if request.query else None + api_rqst = Request( + request.method, path, request.content, + params=params, + content_type=request.content_type, + override_api_version=request_api_version) + async with api_rqst.connect_websocket() as up_conn: + down_conn = web.WebSocketResponse() + await down_conn.prepare(request) + web_socket_proxy = WebSocketProxy(up_conn.raw_websocket, down_conn) + await web_socket_proxy.proxy() + if should_save_session: + storage = request.get(STORAGE_KEY) + await storage.save_session(request, down_conn, session) + return down_conn + except asyncio.CancelledError: + raise + except BackendAPIError as e: + return web.Response(body=json.dumps(e.data), + content_type='application/problem+json', + status=e.status, reason=e.reason) + except BackendClientError: + log.exception('websocket_handler: BackendClientError') + return web.HTTPBadGateway(text=json.dumps({ + 'type': 'https://api.backend.ai/probs/bad-gateway', + 'title': "The proxy target server is inaccessible.", + }), content_type='application/problem+json') + except Exception: + log.exception('websocket_handler: unexpected error') + return web.HTTPInternalServerError(text=json.dumps({ + 'type': 'https://api.backend.ai/probs/internal-server-error', + 'title': "Something has gone wrong.", + }), content_type='application/problem+json') diff --git a/src/ai/backend/web/py.typed b/src/ai/backend/web/py.typed new file mode 100644 index 0000000000..48cdce8528 --- /dev/null +++ b/src/ai/backend/web/py.typed @@ -0,0 +1 @@ +placeholder diff --git a/src/ai/backend/web/server.py b/src/ai/backend/web/server.py new file mode 100644 index 0000000000..6b638d50e8 --- /dev/null +++ b/src/ai/backend/web/server.py @@ -0,0 +1,694 @@ +import asyncio +from functools import partial +import logging +import logging.config +import json +import os +from pathlib import Path +import pkg_resources +from pprint import pprint +import re +import socket +import ssl +import sys +import time +from typing import ( + Any, + AsyncIterator, + MutableMapping, + Tuple, +) + +from aiohttp import web +import aiohttp_cors +from aiohttp_session import get_session, setup as setup_session +from aiohttp_session.redis_storage import RedisStorage +import aiotools +import aioredis +import click +import jinja2 +from setproctitle import setproctitle +import toml +import uvloop +import yarl + +from ai.backend.client.config import APIConfig +from ai.backend.client.exceptions import BackendClientError, BackendAPIError +from ai.backend.client.session import AsyncSession as APISession + +from . 
import __version__, user_agent +from .logging import BraceStyleAdapter +from .proxy import web_handler, websocket_handler, web_plugin_handler + +log = BraceStyleAdapter(logging.getLogger('ai.backend.web.server')) +static_path = Path(pkg_resources.resource_filename('ai.backend.web', 'static')).resolve() +assert static_path.is_dir() + + +console_config_ini_template = jinja2.Template('''[general] +apiEndpoint = {{endpoint_url}} +apiEndpointText = {{endpoint_text}} +{% if default_environment %} +defaultSessionEnvironment = "{{default_environment}}" +{% endif %} +siteDescription = {{site_description}} +connectionMode = "SESSION" + +[wsproxy] +proxyURL = {{proxy_url}}/ +proxyBaseURL = +proxyListenIP = +''') + +console_config_toml_template = jinja2.Template('''[general] +apiEndpoint = "{{endpoint_url}}" +apiEndpointText = "{{endpoint_text}}" +{% if default_environment %} +defaultSessionEnvironment = "{{default_environment}}" +{% endif %} +{% if default_import_environment %} +defaultImportEnvironment = "{{default_import_environment}}" +{% endif %} +siteDescription = "{{site_description}}" +connectionMode = "SESSION" +signupSupport = {{signup_support}} +allowChangeSigninMode = {{allow_change_signin_mode}} +allowAnonymousChangePassword = {{allow_anonymous_change_password}} +allowProjectResourceMonitor = {{allow_project_resource_monitor}} +allowManualImageNameForSession = {{allow_manual_image_name_for_session}} +allowSignupWithoutConfirmation = {{allow_signup_without_confirmation}} +autoLogout = {{auto_logout}} +debug = {{webui_debug}} +maskUserInfo = {{mask_user_info}} + +[resources] +openPortToPublic = {{open_port_to_public}} +maxCPUCoresPerContainer = {{max_cpu_cores_per_container}} +maxMemoryPerContainer = {{max_memory_per_container}} +maxCUDADevicesPerContainer = {{max_cuda_devices_per_container}} +maxCUDASharesPerContainer = {{max_cuda_shares_per_container}} +maxShmPerContainer = {{max_shm_per_container}} +maxFileUploadSize = {{max_file_upload_size}} + +[environments] +{% if environment_allowlist %} +allowlist = "{{environment_allowlist}}" +{% endif %} + +[menu] +{% if menu_blocklist %} +blocklist = "{{menu_blocklist}}" +{% endif %} + +{% if console_menu_plugins %} +[plugin] +page = "{{console_menu_plugins}}" + +{% endif %} +[wsproxy] +proxyURL = "{{proxy_url}}/" +#proxyBaseURL = +#proxyListenIP = + +[license] +edition = "{{license_edition}}" +validSince = "{{license_valid_since}}" +validUntil = "{{license_valid_until}}" +''') + + +async def static_handler(request: web.Request) -> web.StreamResponse: + request_path = request.match_info['path'] + file_path = (static_path / request_path).resolve() + try: + file_path.relative_to(static_path) + except (ValueError, FileNotFoundError): + return web.HTTPNotFound(text=json.dumps({ + 'type': 'https://api.backend.ai/probs/generic-not-found', + 'title': 'Not Found', + }), content_type='application/problem+json') + if file_path.is_file(): + return header_handler(web.FileResponse(file_path), request_path) + return web.HTTPNotFound(text=json.dumps({ + 'type': 'https://api.backend.ai/probs/generic-not-found', + 'title': 'Not Found', + }), content_type='application/problem+json') + + +async def console_handler(request: web.Request) -> web.StreamResponse: + request_path = request.match_info['path'] + file_path = (static_path / request_path).resolve() + config = request.app['config'] + scheme = config['service'].get('force-endpoint-protocol') + if scheme is None: + scheme = request.scheme + + if request_path == 'config.ini': + config_content = 
console_config_ini_template.render(**{ + 'endpoint_url': f'{scheme}://{request.host}', # must be absolute + 'endpoint_text': config['api']['text'], + 'site_description': config['ui']['brand'], + 'default_environment': config['ui'].get('default_environment'), + 'proxy_url': config['service']['wsproxy']['url'], + }) + return web.Response(text=config_content) + + if request_path == 'config.toml': + if 'license' in config: + license_edition = config['license'].get('edition', 'Open Source') + license_valid_since = config['license'].get('valid_since', '') + license_valid_until = config['license'].get('valid_until', '') + else: + license_edition = 'Open Source' + license_valid_since = '' + license_valid_until = '' + if 'resources' in config: + open_port_to_public = 'true' if config['resources'].get('open_port_to_public') else 'false' + max_cpu_cores_per_container = config['resources'].get('max_cpu_cores_per_container', 64) + max_memory_per_container = config['resources'].get('max_memory_per_container', 64) + max_cuda_devices_per_container = config['resources'].get( + 'max_cuda_devices_per_container', 16) + max_cuda_shares_per_container = config['resources'].get( + 'max_cuda_shares_per_container', 16) + max_shm_per_container = config['resources'].get('max_shm_per_container', 2) + max_file_upload_size = config['resources'].get('max_file_upload_size', 4294967296) + else: + open_port_to_public = 'false' + max_cpu_cores_per_container = 64 + max_memory_per_container = 64 + max_cuda_devices_per_container = 16 + max_cuda_shares_per_container = 16 + max_shm_per_container = 2 + max_file_upload_size = 4294967296 + if 'plugin' in config: + console_menu_plugins = config['plugin'].get('page', '') + else: + console_menu_plugins = False + config_content = console_config_toml_template.render(**{ + 'endpoint_url': f'{scheme}://{request.host}', # must be absolute + 'endpoint_text': config['api']['text'], + 'site_description': config['ui']['brand'], + 'default_environment': config['ui'].get('default_environment'), + 'default_import_environment': config['ui'].get('default_import_environment'), + 'proxy_url': config['service']['wsproxy']['url'], + 'signup_support': 'true' if config['service']['enable_signup'] else 'false', + 'allow_change_signin_mode': + 'true' if config['service'].get('allow_change_signin_mode') else 'false', + 'allow_anonymous_change_password': + 'true' if config['service'].get('allow_anonymous_change_password') else 'false', + 'allow_project_resource_monitor': + 'true' if config['service']['allow_project_resource_monitor'] else 'false', + 'allow_manual_image_name_for_session': + 'true' if config['service'].get('allow_manual_image_name_for_session') else 'false', + 'allow_signup_without_confirmation': + 'true' if config['service'].get('allow_signup_without_confirmation') else 'false', + 'webui_debug': 'true' if config['service'].get('webui_debug') else 'false', + 'auto_logout': + 'true' if config['session'].get('auto_logout') else 'false', + 'mask_user_info': + 'true' if config['service'].get('mask_user_info') else 'false', + 'open_port_to_public': open_port_to_public, + 'max_cpu_cores_per_container': max_cpu_cores_per_container, + 'max_memory_per_container': max_memory_per_container, + 'max_cuda_devices_per_container': max_cuda_devices_per_container, + 'max_cuda_shares_per_container': max_cuda_shares_per_container, + 'max_shm_per_container': max_shm_per_container, + 'max_file_upload_size': max_file_upload_size, + 'environment_allowlist': config['environments'].get('allowlist', ''), + 
'menu_blocklist': config['ui'].get('menu_blocklist', ''), + 'console_menu_plugins': console_menu_plugins, + 'license_edition': license_edition, + 'license_valid_since': license_valid_since, + 'license_valid_until': license_valid_until, + }) + return web.Response(text=config_content) + # SECURITY: only allow reading files under static_path + try: + file_path.relative_to(static_path) + except (ValueError, FileNotFoundError): + return web.HTTPNotFound(text=json.dumps({ + 'type': 'https://api.backend.ai/probs/generic-not-found', + 'title': 'Not Found', + }), content_type='application/problem+json') + if file_path.is_file(): + return header_handler(web.FileResponse(file_path), request_path) + + return header_handler(web.FileResponse(static_path / 'index.html'), 'index.html') + + +cache_patterns = { + r'\.(?:manifest|appcache|html?|xml|json|ini|toml)$': { + 'Cache-Control': 'no-store', + }, + r'(?:backend.ai-webui.js)$': { + 'Cache-Control': 'no-store', + }, + r'\.(?:jpg|jpeg|gif|png|ico|cur|gz|svg|svgz|mp4|ogg|ogv|webm|htc|woff|woff2)$': { + 'Cache-Control': 'max-age=259200, public', + }, + r'\.(?:css|js)$': { + 'Cache-Control': 'max-age=86400, public, must-revalidate, proxy-revalidate', + }, + r'\.(?:py|log?|txt)$': { + 'Cache-Control': 'no-store', + }, +} +_cache_patterns = {re.compile(k): v for k, v in cache_patterns.items()} + + +def header_handler(response: web.StreamResponse, path: str) -> web.StreamResponse: + for regex, headers in _cache_patterns.items(): + mo = regex.search(path) + if mo is not None: + response.headers.update(headers) + break + return response + + +async def login_check_handler(request: web.Request) -> web.Response: + session = await get_session(request) + authenticated = bool(session.get('authenticated', False)) + public_data = None + if authenticated: + stored_token = session['token'] + public_data = { + 'access_key': stored_token['access_key'], + 'role': stored_token['role'], + 'status': stored_token.get('status'), + } + return web.json_response({ + 'authenticated': authenticated, + 'data': public_data, + 'session_id': session.identity, # temporary wsproxy interop patch + }) + + +async def login_handler(request: web.Request) -> web.Response: + config = request.app['config'] + session = await get_session(request) + if session.get('authenticated', False): + return web.HTTPBadRequest(text=json.dumps({ + 'type': 'https://api.backend.ai/probs/generic-bad-request', + 'title': 'You have already logged in.', + }), content_type='application/problem+json') + creds = await request.json() + if 'username' not in creds or not creds['username']: + return web.HTTPBadRequest(text=json.dumps({ + 'type': 'https://api.backend.ai/probs/invalid-api-params', + 'title': 'You must provide the username field.', + }), content_type='application/problem+json') + if 'password' not in creds or not creds['password']: + return web.HTTPBadRequest(text=json.dumps({ + 'type': 'https://api.backend.ai/probs/invalid-api-params', + 'title': 'You must provide the password field.', + }), content_type='application/problem+json') + result: MutableMapping[str, Any] = { + 'authenticated': False, + 'data': None, + } + try: + async def _get_login_history(): + login_history = await request.app['redis'].get( + f'login_history_{creds["username"]}', + ) + if not login_history: + login_history = { + 'last_login_attempt': 0, + 'login_fail_count': 0, + } + else: + login_history = json.loads(login_history) + if login_history['last_login_attempt'] < 0: + login_history['last_login_attempt'] = 0 + if 
login_history['login_fail_count'] < 0: + login_history['login_fail_count'] = 0 + return login_history + + async def _set_login_history(last_login_attempt, login_fail_count): + """ + Set login history per email (not in browser session). + """ + key = f'login_history_{creds["username"]}' + value = json.dumps({ + 'last_login_attempt': last_login_attempt, + 'login_fail_count': login_fail_count, + }) + await request.app['redis'].set(key, value) + + # Block login if there are too many consecutive failed login attempts. + BLOCK_TIME = config['session'].get('login_block_time', 1200) + ALLOWED_FAIL_COUNT = config['session'].get('login_allowed_fail_count', 10) + login_time = time.time() + login_history = await _get_login_history() + last_login_attempt = login_history.get('last_login_attempt', 0) + login_fail_count = login_history.get('login_fail_count', 0) + if login_time - last_login_attempt > BLOCK_TIME: + # If last attempt is far past, allow login again. + login_fail_count = 0 + last_login_attempt = login_time + if login_fail_count >= ALLOWED_FAIL_COUNT: + log.info('Too many consecutive login attempts for {}: {}', + creds['username'], login_fail_count) + await _set_login_history(last_login_attempt, login_fail_count) + return web.HTTPTooManyRequests(text=json.dumps({ + 'type': 'https://api.backend.ai/probs/too-many-requests', + 'title': 'Too many failed login attempts', + }), content_type='application/problem+json') + + anon_api_config = APIConfig( + domain=config['api']['domain'], + endpoint=config['api']['endpoint'], + access_key='', secret_key='', # anonymous session + user_agent=user_agent, + skip_sslcert_validation=not config['api'].get('ssl-verify', True), + ) + assert anon_api_config.is_anonymous + async with APISession(config=anon_api_config) as api_session: + token = await api_session.User.authorize(creds['username'], creds['password']) + stored_token = { + 'type': 'keypair', + 'access_key': token.content['access_key'], + 'secret_key': token.content['secret_key'], + 'role': token.content['role'], + 'status': token.content.get('status'), + } + public_return = { + 'access_key': token.content['access_key'], + 'role': token.content['role'], + 'status': token.content.get('status'), + } + session['authenticated'] = True + session['token'] = stored_token # store full token + result['authenticated'] = True + result['data'] = public_return # store public info from token + login_fail_count = 0 + await _set_login_history(last_login_attempt, login_fail_count) + except BackendClientError as e: + # This is error, not failed login, so we should not update login history. 
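+        # Transport-level failures surface as 502 Bad Gateway here, whereas authentication
+        # rejections are handled in the BackendAPIError branch below and do count toward
+        # the consecutive-failure limit.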
+ return web.HTTPBadGateway(text=json.dumps({ + 'type': 'https://api.backend.ai/probs/bad-gateway', + 'title': "The proxy target server is inaccessible.", + 'details': str(e), + }), content_type='application/problem+json') + except BackendAPIError as e: + log.info('Authorization failed for {}: {}', creds['username'], e) + result['authenticated'] = False + result['data'] = { + 'type': e.data.get('type'), + 'title': e.data.get('title'), + 'details': e.data.get('msg'), + } + session['authenticated'] = False + login_fail_count += 1 + await _set_login_history(last_login_attempt, login_fail_count) + return web.json_response(result) + + +async def logout_handler(request: web.Request) -> web.Response: + session = await get_session(request) + session.invalidate() + return web.Response(status=201) + + +async def webserver_healthcheck(request: web.Request) -> web.Response: + result = { + 'version': __version__, + 'details': 'Success', + } + return web.json_response(result) + + +async def token_login_handler(request: web.Request) -> web.Response: + config = request.app['config'] + + # Check browser session exists. + session = await get_session(request) + if session.get('authenticated', False): + return web.HTTPBadRequest(text=json.dumps({ + 'type': 'https://api.backend.ai/probs/generic-bad-request', + 'title': 'You have already logged in.', + }), content_type='application/problem+json') + + # Check if auth token is delivered through cookie. + auth_token_name = config['api'].get('auth_token_name') + if not auth_token_name: + return web.HTTPBadRequest(text=json.dumps({ + 'type': 'https://api.backend.ai/probs/invalid-api-params', + 'title': 'Auth token name is not defined', + }), content_type='application/problem+json') + auth_token = request.cookies.get(auth_token_name) + if not auth_token: + return web.HTTPBadRequest(text=json.dumps({ + 'type': 'https://api.backend.ai/probs/invalid-api-params', + 'title': 'You must provide cookie-based authentication token', + }), content_type='application/problem+json') + + # Login with the token. + # We do not pose consecutive login failure for this handler since + # user may frequently click edu-api launcher button. + result: MutableMapping[str, Any] = { + 'authenticated': False, + 'data': None, + } + try: + anon_api_config = APIConfig( + domain=config['api']['domain'], + endpoint=config['api']['endpoint'], + access_key='', secret_key='', # anonymous session + user_agent=user_agent, + skip_sslcert_validation=not config['api'].get('ssl-verify', True), + ) + assert anon_api_config.is_anonymous + async with APISession(config=anon_api_config) as api_session: + # Send X-Forwarded-For header for token authentication with the client IP. + client_ip = request.headers.get('X-Forwarded-For', request.remote) + if client_ip: + _headers = {'X-Forwarded-For': client_ip} + api_session.aiohttp_session.headers.update(_headers) + # Instead of email and password, cookie token will be used for auth. 
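+            # The literal credentials passed to User.authorize() below are placeholders;
+            # the manager authenticates this request via the auth cookie copied into the
+            # API session's cookie jar.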
+ api_session.aiohttp_session.cookie_jar.update_cookies(request.cookies) + token = await api_session.User.authorize('fake-email', 'fake-pwd') + stored_token = { + 'type': 'keypair', + 'access_key': token.content['access_key'], + 'secret_key': token.content['secret_key'], + 'role': token.content['role'], + 'status': token.content.get('status'), + } + public_return = { + 'access_key': token.content['access_key'], + 'role': token.content['role'], + 'status': token.content.get('status'), + } + session['authenticated'] = True + session['token'] = stored_token # store full token + result['authenticated'] = True + result['data'] = public_return # store public info from token + except BackendClientError as e: + return web.HTTPBadGateway(text=json.dumps({ + 'type': 'https://api.backend.ai/probs/bad-gateway', + 'title': "The proxy target server is inaccessible.", + 'details': str(e), + }), content_type='application/problem+json') + except BackendAPIError as e: + log.info('Authorization failed for token {}: {}', auth_token, e) + result['authenticated'] = False + result['data'] = { + 'type': e.data.get('type'), + 'title': e.data.get('title'), + 'details': e.data.get('msg'), + } + session['authenticated'] = False + return web.json_response(result) + + +async def server_shutdown(app) -> None: + pass + + +async def server_cleanup(app) -> None: + await app['redis'].close() + + +@aiotools.server +async def server_main( + loop: asyncio.AbstractEventLoop, + pidx: int, + args: Tuple[Any, ...], +) -> AsyncIterator[None]: + config = args[0] + app = web.Application() + app['config'] = config + redis_url = ( + yarl.URL("redis://host") + .with_host(config['session']['redis']['host']) + .with_port(config['session']['redis']['port']) + .with_password(config['session']['redis'].get('password', None)) + / str(config['session']['redis'].get('db', 0)) # noqa + ) + keepalive_options = {} + if hasattr(socket, 'TCP_KEEPIDLE'): + keepalive_options[socket.TCP_KEEPIDLE] = 20 + if hasattr(socket, 'TCP_KEEPINTVL'): + keepalive_options[socket.TCP_KEEPINTVL] = 5 + if hasattr(socket, 'TCP_KEEPCNT'): + keepalive_options[socket.TCP_KEEPCNT] = 3 + app['redis'] = await aioredis.Redis.from_url( + str(redis_url), + socket_keepalive=True, + socket_keepalive_options=keepalive_options, + ) + + if pidx == 0 and config['session'].get('flush_on_startup', False): + await app['redis'].flushdb() + log.info('flushed session storage.') + redis_storage = RedisStorage( + app['redis'], + max_age=config['session']['max_age']) + + setup_session(app, redis_storage) + cors_options = { + '*': aiohttp_cors.ResourceOptions( + allow_credentials=True, + allow_methods='*', + expose_headers="*", + allow_headers="*"), + } + cors = aiohttp_cors.setup(app, defaults=cors_options) + + anon_web_handler = partial(web_handler, is_anonymous=True) + anon_web_plugin_handler = partial(web_plugin_handler, is_anonymous=True) + + app.router.add_route('HEAD', '/func/{path:folders/_/tus/upload/.*$}', anon_web_plugin_handler) + app.router.add_route('PATCH', '/func/{path:folders/_/tus/upload/.*$}', anon_web_plugin_handler) + app.router.add_route('OPTIONS', '/func/{path:folders/_/tus/upload/.*$}', anon_web_plugin_handler) + cors.add(app.router.add_route('POST', '/server/login', login_handler)) + cors.add(app.router.add_route('POST', '/server/token-login', token_login_handler)) + cors.add(app.router.add_route('POST', '/server/login-check', login_check_handler)) + cors.add(app.router.add_route('POST', '/server/logout', logout_handler)) + cors.add(app.router.add_route('GET', 
'/func/ping', webserver_healthcheck)) + cors.add(app.router.add_route('GET', '/func/{path:hanati/user}', anon_web_plugin_handler)) + cors.add(app.router.add_route('GET', '/func/{path:cloud/.*$}', anon_web_plugin_handler)) + cors.add(app.router.add_route('POST', '/func/{path:cloud/.*$}', anon_web_plugin_handler)) + cors.add(app.router.add_route('POST', '/func/{path:auth/signup}', anon_web_plugin_handler)) + cors.add(app.router.add_route('POST', '/func/{path:auth/signout}', web_handler)) + cors.add(app.router.add_route('GET', '/func/{path:stream/kernel/_/events}', web_handler)) + cors.add(app.router.add_route('GET', '/func/{path:stream/session/[^/]+/apps$}', web_handler)) + cors.add(app.router.add_route('GET', '/func/{path:stream/.*$}', websocket_handler)) + cors.add(app.router.add_route('GET', '/func/', anon_web_handler)) + cors.add(app.router.add_route('HEAD', '/func/{path:.*$}', web_handler)) + cors.add(app.router.add_route('GET', '/func/{path:.*$}', web_handler)) + cors.add(app.router.add_route('PUT', '/func/{path:.*$}', web_handler)) + cors.add(app.router.add_route('POST', '/func/{path:.*$}', web_handler)) + cors.add(app.router.add_route('PATCH', '/func/{path:.*$}', web_handler)) + cors.add(app.router.add_route('DELETE', '/func/{path:.*$}', web_handler)) + if config['service']['mode'] == 'webui': + fallback_handler = console_handler + elif config['service']['mode'] == 'static': + fallback_handler = static_handler + else: + raise ValueError('Unrecognized service.mode', config['service']['mode']) + cors.add(app.router.add_route('GET', '/{path:.*$}', fallback_handler)) + + app.on_shutdown.append(server_shutdown) + app.on_cleanup.append(server_cleanup) + + ssl_ctx = None + if 'ssl-enabled' in config['service'] and config['service']['ssl-enabled']: + ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + ssl_ctx.load_cert_chain( + str(config['service']['ssl-cert']), + str(config['service']['ssl-privkey']), + ) + + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite( + runner, + str(config['service']['ip']), + config['service']['port'], + backlog=1024, + reuse_port=True, + ssl_context=ssl_ctx, + ) + await site.start() + log.info('started.') + + try: + yield + finally: + log.info('shutting down...') + await runner.cleanup() + + +@click.command() +@click.option('-f', '--config', 'config_path', + type=click.Path(exists=True), + default='webserver.conf', + help='The configuration file to use.') +@click.option('--debug', is_flag=True, + default=False, + help='Use more verbose logging.') +def main(config_path: str, debug: bool) -> None: + config = toml.loads(Path(config_path).read_text(encoding='utf-8')) + config['debug'] = debug + if config['debug']: + debugFlag = 'DEBUG' + else: + debugFlag = 'INFO' + setproctitle(f"backend.ai: webserver " + f"{config['service']['ip']}:{config['service']['port']}") + + logging.config.dictConfig({ + 'version': 1, + 'disable_existing_loggers': False, + 'formatters': { + 'colored': { + '()': 'coloredlogs.ColoredFormatter', + 'format': '%(asctime)s %(levelname)s %(name)s ' + '[%(process)d] %(message)s', + 'field_styles': {'levelname': {'color': 248, 'bold': True}, + 'name': {'color': 246, 'bold': False}, + 'process': {'color': 'cyan'}, + 'asctime': {'color': 240}}, + }, + }, + 'handlers': { + 'console': { + 'class': 'logging.StreamHandler', + 'level': 'DEBUG', + 'formatter': 'colored', + 'stream': 'ext://sys.stderr', + }, + 'null': { + 'class': 'logging.NullHandler', + }, + }, + 'loggers': { + '': { + 'handlers': ['console'], + 'level': 
debugFlag, + }, + }, + }) + log.info('Backend.AI Web Server {0}', __version__) + log.info('runtime: {0}', sys.prefix) + log_config = logging.getLogger('ai.backend.web.config') + log_config.debug('debug mode enabled.') + print('== Web Server configuration ==') + pprint(config) + log.info('serving at {0}:{1}', config['service']['ip'], config['service']['port']) + + try: + uvloop.install() + aiotools.start_server( + server_main, + num_workers=min(4, os.cpu_count() or 1), + args=(config,), + ) + finally: + log.info('terminated.') + + +if __name__ == '__main__': + main() diff --git a/src/ai/backend/web/static b/src/ai/backend/web/static new file mode 160000 index 0000000000..abe911b036 --- /dev/null +++ b/src/ai/backend/web/static @@ -0,0 +1 @@ +Subproject commit abe911b036964c2d0bf97dc3f9e79ed070fc86ca diff --git a/stubs/trafaret/BUILD b/stubs/trafaret/BUILD new file mode 100644 index 0000000000..ddeb5171a8 --- /dev/null +++ b/stubs/trafaret/BUILD @@ -0,0 +1,3 @@ +python_sources( + name="stubs", +) diff --git a/stubs/trafaret/__init__.pyi b/stubs/trafaret/__init__.pyi new file mode 100644 index 0000000000..4bbe7e8b2b --- /dev/null +++ b/stubs/trafaret/__init__.pyi @@ -0,0 +1,76 @@ +from typing import Tuple as _Tuple + +from trafaret.base import ( + Trafaret as Trafaret, + TrafaretMeta as TrafaretMeta, + TypeMeta as TypeMeta, + SquareBracketsMeta as SquareBracketsMeta, + OnError as OnError, + TypingTrafaret as TypingTrafaret, + Subclass as Subclass, + Type as Type, + Any as Any, + And as And, + Or as Or, + Key as Key, + Dict as Dict, + DictKeys as DictKeys, + Mapping as Mapping, + Enum as Enum, + Callable as Callable, + Call as Call, + Forward as Forward, + List as List, + Tuple as Tuple, + Atom as Atom, + String as String, + Bytes as Bytes, + FromBytes as FromBytes, + Null as Null, + Bool as Bool, + ToBool as ToBool, + guard as guard, + ignore as ignore, + catch as catch, + extract_error as extract_error, + GuardError as GuardError, +) +from trafaret.constructor import ( + ConstructMeta as ConstructMeta, + C as C, + construct as construct, + construct_key as construct_key, +) +from trafaret.keys import ( + KeysSubset as KeysSubset, + subdict as subdict, + xor_key as xor_key, + confirm_key as confirm_key, +) +from trafaret.internet import ( + Email as Email, + Hex as Hex, + URL as URL, + URLSafe as URLSafe, + IPv4 as IPv4, + IPv6 as IPv6, + IP as IP, +) +from trafaret.numeric import ( + NumberMeta as NumberMeta, + Int as Int, + ToInt as ToInt, + Float as Float, + ToFloat as ToFloat, + ToDecimal as ToDecimal, +) +from trafaret.regexp import ( + RegexpRaw as RegexpRaw, + Regexp as Regexp, +) +from trafaret.dataerror import ( + DataError as DataError, +) + +__all__: _Tuple[str] +__VERSION__: _Tuple[int, int, int] diff --git a/stubs/trafaret/base.pyi b/stubs/trafaret/base.pyi new file mode 100644 index 0000000000..dd280fde19 --- /dev/null +++ b/stubs/trafaret/base.pyi @@ -0,0 +1,94 @@ +from typing import Any as _Any, Hashable, Optional, NoReturn, Type as _Type + +from trafaret.dataerror import DataError + + +class TrafaretMeta(type): + def __or__(cls, other): ... + def __and__(cls, other): ... + def __rshift__(cls, other: str): ... + +class Trafaret(metaclass=TrafaretMeta): + def check(self, value: _Any, context: Optional[_Any] = ...) -> _Any: ... + def _failure(self, error: str = None, value: _Any = ...) -> NoReturn: ... + def append(self, other): ... + def __or__(self, other): ... + def __and__(self, other): ... + def __rshift__(self, other: str): ... + def __call__(self, val, context=None): ... 
+ +class Key: + def __init__(self, name: Hashable, + default: _Any = ..., + optional: bool = False, + to_name: Hashable = None, + trafaret: Trafaret = None) -> None: ... + def get_name(self) -> Hashable: ... + def __rshift__(self, other: str): ... + +def ensure_trafaret(trafaret): ... + +class TypeMeta(TrafaretMeta): + def __getitem__(self, type_): ... + +class SquareBracketsMeta(TrafaretMeta): + def __getitem__(self, args): ... + +class OnError(Trafaret): + def __init__(self, trafaret: Trafaret, message: str, code: str = None) -> None: ... + def transform(self, value: _Any, context: _Any = None): ... +class WithRepr(Trafaret): ... + +class TypingTrafaret(Trafaret, metaclass=TypeMeta): ... +class Subclass(TypingTrafaret): ... +class Type(TypingTrafaret): ... + +class Any(Trafaret): ... +class And(Trafaret): ... +class Or(Trafaret): ... + +class Dict(Trafaret): + def __init__(self, *args, **trafarets) -> None: ... + def allow_extra(self, *names: str, **kw) -> Dict: ... + def ignore_extra(self, *names: str) -> Dict: ... + def merge(self, other: Dict | list | tuple | dict) -> Dict: ... +class DictKeys(Trafaret): ... +class Mapping(Trafaret): + def __init__(self, key, value) -> None: ... +class Enum(Trafaret): + def __init__(self, *variants) -> None: ... +class Callable(Trafaret): ... +class Call(Trafaret): + def __init__(self, fn) -> None: ... +class Forward(Trafaret): ... +class List(Trafaret, metaclass=SquareBracketsMeta): + def __init__(self, trafaret: Trafaret | _Type[Trafaret], + min_length: int = 0, max_length: int = None) -> None: ... +class Iterable(Trafaret): + def __init__(self, trafaret: Trafaret, + min_length: int = 0, max_length: int = None) -> None: ... +class Tuple(Trafaret): + def __init__(self, *args): ... +class Atom(Trafaret): + def __init__(self, value: str) -> None: ... +class String(Trafaret): + def __init__(self, allow_blank: bool = False, + min_length: int = None, max_length: int = None): ... +class Bytes(Trafaret): ... +class FromBytes(Trafaret): + def __init__(self, encoding: str = 'utf-8') -> None: ... +class Null(Trafaret): ... +class Bool(Trafaret): ... +class ToBool(Trafaret): ... + +class Date(Trafaret): ... +class ToDate(Trafaret): ... +class DateTime(Trafaret): ... +class ToDateTime(Trafaret): ... + +def guard(trafaret: Trafaret = None, **kwargs): ... +def ignore(val): ... +def catch(checker, *a, **kw): ... +def extract_error(checker, *a, **kw): ... + +class GuardError(DataError): ... diff --git a/stubs/trafaret/constructor.pyi b/stubs/trafaret/constructor.pyi new file mode 100644 index 0000000000..d7f299eb98 --- /dev/null +++ b/stubs/trafaret/constructor.pyi @@ -0,0 +1,12 @@ +from typing import Any as _Any +from .base import Trafaret, Key + + +class ConstructMeta(type): + def __or__(self, other): ... + def __and__(self, other): ... + +class C(object, metaclass=ConstructMeta): ... + +def construct(arg: _Any) -> Trafaret: ... +def construct_key(key: _Any) -> Key: ... diff --git a/stubs/trafaret/dataerror.pyi b/stubs/trafaret/dataerror.pyi new file mode 100644 index 0000000000..a2a028ebaf --- /dev/null +++ b/stubs/trafaret/dataerror.pyi @@ -0,0 +1,11 @@ +from typing import Any +from trafaret.base import Trafaret + +class DataError(ValueError): + def __init__(self, + error: str = None, + name: str = None, + value: Any = ..., + trafaret: Trafaret = None): ... + def as_dict(self, value=False): ... + ...
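These `.pyi` files only declare signatures for the `trafaret` package so that mypy can type-check callers; they add no runtime behavior. For reference, the kind of code they cover looks like this (a generic illustration of the upstream `trafaret` API, not code from this patch):

import trafaret as t

schema = t.Dict({
    t.Key('host'): t.String(),
    t.Key('port'): t.ToInt(),
    t.Key('debug', default=False): t.ToBool(),
})

try:
    cfg = schema.check({'host': 'localhost', 'port': '8080'})
except t.DataError as e:
    # DataError.as_dict() gives a per-key error mapping, as stubbed above.
    print(e.as_dict())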
diff --git a/stubs/trafaret/internet.pyi b/stubs/trafaret/internet.pyi new file mode 100644 index 0000000000..edb3961ef5 --- /dev/null +++ b/stubs/trafaret/internet.pyi @@ -0,0 +1,13 @@ +from .base import OnError, WithRepr +from .regexp import Regexp, RegexpString + + +Email: WithRepr +URL: WithRepr +IPv4: WithRepr +IPv6: WithRepr +IP: WithRepr + + +class Hex(RegexpString): ... +class URLSafe(RegexpString): ... diff --git a/stubs/trafaret/keys.pyi b/stubs/trafaret/keys.pyi new file mode 100644 index 0000000000..d4d27bed85 --- /dev/null +++ b/stubs/trafaret/keys.pyi @@ -0,0 +1,15 @@ +from typing import Sequence, Union +from .base import Dict, Key, Trafaret + + +class KeysSubset(Key): + def __init__(self, *keys: Sequence[str]): ... + def __call__(self, data): ... + + +def subdict(name: str, *keys: Sequence[Union[str, Key]], + trafaret: Trafaret) -> Dict: ... + +def xor_key(first, second, trafaret: Trafaret): ... + +def confirm_key(name, confirm_name, trafaret: Trafaret): ... diff --git a/stubs/trafaret/lib.pyi b/stubs/trafaret/lib.pyi new file mode 100644 index 0000000000..dafc38c76e --- /dev/null +++ b/stubs/trafaret/lib.pyi @@ -0,0 +1 @@ +_empty: object diff --git a/stubs/trafaret/numeric.pyi b/stubs/trafaret/numeric.pyi new file mode 100644 index 0000000000..25d55034e7 --- /dev/null +++ b/stubs/trafaret/numeric.pyi @@ -0,0 +1,15 @@ +from .base import TrafaretMeta, Trafaret + + +class NumberMeta(TrafaretMeta): + def __getitem__(cls, slice_): ... + def __lt__(cls, lt): ... + def __gt__(cls, gt): ... + +class Float(Trafaret, metaclass=NumberMeta): + def __init__(self, gte: float = None, lte: float = None, + gt: float = None, lt: float = None): ... +class ToFloat(Float): ... +class Int(Float): ... +class ToInt(Int): ... +class ToDecimal(Float): ... diff --git a/stubs/trafaret/regexp.pyi b/stubs/trafaret/regexp.pyi new file mode 100644 index 0000000000..ef6373adfe --- /dev/null +++ b/stubs/trafaret/regexp.pyi @@ -0,0 +1,9 @@ +from .base import Trafaret, String + + +class RegexpRaw(Trafaret): + def __init__(self, regexp: str, re_flags: int = 0): ... + +class Regexp(RegexpRaw): ... + +class RegexpString(String, Regexp): ... 
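The agent test fixtures that follow pick their container port range as `19000 + 200 * get_parallel_slot()` so that test processes running in different Pants execution slots (see the `[pytest].execution_slot_var` setting mentioned in this change) cannot collide on host ports. A hedged sketch of the idea with a hypothetical environment variable name; the real `get_parallel_slot()` lives in `ai.backend.testutils.pants` in this patch and its implementation may differ:

import os

def get_parallel_slot() -> int:
    # Hypothetical env var name; the real one is whatever
    # [pytest].execution_slot_var configures in pants.toml.
    return int(os.environ.get('BACKEND_TEST_EXEC_SLOT', '0'))

def container_port_range(width: int = 200, base: int = 19000) -> tuple[int, int]:
    start = base + width * get_parallel_slot()
    return (start, start + width)

# slot 0 -> (19000, 19200), slot 1 -> (19200, 19400), ...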
diff --git a/tests/agent/BUILD b/tests/agent/BUILD new file mode 100644 index 0000000000..6a349731a3 --- /dev/null +++ b/tests/agent/BUILD @@ -0,0 +1,11 @@ +python_test_utils( + name="test_utils", +) + +python_tests( + name="tests", + dependencies=[ + "src/ai/backend/agent:service", + "src/ai/backend/testutils:lib", + ], +) diff --git a/tests/agent/__init__.py b/tests/agent/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/agent/conftest.py b/tests/agent/conftest.py new file mode 100644 index 0000000000..5197b8cc27 --- /dev/null +++ b/tests/agent/conftest.py @@ -0,0 +1,164 @@ +import asyncio +import os +import secrets +import shutil +from collections import defaultdict +from pathlib import Path + +from ai.backend.common import config +from ai.backend.common import validators as tx +from ai.backend.common.types import EtcdRedisConfig, HostPortPair +from ai.backend.testutils.bootstrap import etcd_container, redis_container # noqa: F401 +from ai.backend.testutils.pants import get_parallel_slot + +import aiodocker +import pytest + + +@pytest.fixture(scope='session') +def test_id(): + return f'testing-{secrets.token_urlsafe(8)}' + + +@pytest.fixture(scope='session') +def local_config(test_id, etcd_container, redis_container): # noqa: F811 + # ipc_base_path = Path.cwd() / f'tmp/backend.ai/ipc-{test_id}' + ipc_base_path = Path.cwd() / f'ipc/ipc-{test_id}' + ipc_base_path.mkdir(parents=True, exist_ok=True) + etcd_addr = etcd_container[1] + + cfg = { + 'agent': { + 'region': f"rg-{test_id}", + 'id': f"i-{test_id}", + 'scaling-group': f"sg-{test_id}", + 'ipc-base-path': ipc_base_path, + 'backend': 'docker', + 'rpc-listen-addr': HostPortPair('', 6001), + 'agent-sock-port': 6009, + }, + 'container': { + 'scratch-type': 'hostdir', + 'stats-type': 'docker', + 'port-range': [ + 19000 + 200 * get_parallel_slot(), + 19200 + 200 * get_parallel_slot(), + ], + }, + 'resource': { + 'reserved-cpu': 1, + 'reserved-mem': tx.BinarySize().check('256M'), + 'reserved-disk': tx.BinarySize().check('1G'), + }, + 'logging': {}, + 'debug': defaultdict(lambda: False), + 'etcd': { + 'addr': etcd_addr, + 'namespace': f'ns-{test_id}', + }, + 'redis': EtcdRedisConfig( + addr=redis_container[1], + sentinel=None, + service_name=None, + password=None, + ), + 'plugins': {}, + } + + def _override_if_exists(src: dict, dst: dict, key: str) -> None: + sentinel = object() + if (val := src.get(key, sentinel)) is not sentinel: + dst[key] = val + + try: + # Override external database config with the current environment's config. 
+ fs_local_config, cfg_src_path = config.read_from_file(None, 'agent') + cfg['etcd']['addr'] = fs_local_config['etcd']['addr'] + _override_if_exists(fs_local_config['etcd'], cfg['etcd'], 'user') + _override_if_exists(fs_local_config['etcd'], cfg['etcd'], 'password') + except config.ConfigurationError: + pass + yield cfg + shutil.rmtree(ipc_base_path) + + +@pytest.fixture(scope='session', autouse=True) +def test_local_instance_id(local_config, session_mocker, test_id): + ipc_base_path = local_config['agent']['ipc-base-path'] + registry_state_path = ipc_base_path / f'last_registry.{test_id}.dat' + try: + os.unlink(registry_state_path) + except FileNotFoundError: + pass + mock_generate_local_instance_id = session_mocker.patch( + 'ai.backend.agent.agent.generate_local_instance_id', + ) + mock_generate_local_instance_id.return_value = f"i-{test_id}" + yield + try: + os.unlink(registry_state_path) + except FileNotFoundError: + pass + + +@pytest.fixture(scope='session') +def prepare_images(): + + async def pull(): + docker = aiodocker.Docker() + images_to_pull = [ + 'alpine:3.8', + 'nginx:1.17-alpine', + ] + for img in images_to_pull: + try: + await docker.images.inspect(img) + except aiodocker.exceptions.DockerError as e: + assert e.status == 404 + print(f'Pulling image "{img}" for testing...') + await docker.pull(img) + await docker.close() + + # We need to preserve the current loop configured by pytest-asyncio + # because asyncio.run() calls asyncio.set_event_loop(None) upon its completion. + # Here we cannot just use "event_loop" fixture because this fixture + # is session-scoped and pytest does not allow calling function-scoped fixtuers + # from session-scoped fixtures. + try: + old_loop = asyncio.get_event_loop() + except RuntimeError as exc: + if 'no current event loop' not in str(exc): + raise + try: + asyncio.run(pull()) + finally: + asyncio.set_event_loop(old_loop) + + +@pytest.fixture +async def docker(): + docker = aiodocker.Docker() + try: + yield docker + finally: + await docker.close() + + +@pytest.fixture +async def create_container(test_id, docker): + container = None + cont_id = secrets.token_urlsafe(4) + + async def _create_container(config): + nonlocal container + container = await docker.containers.create_or_replace( + config=config, + name=f'kernel.{test_id}-{cont_id}', + ) + return container + + try: + yield _create_container + finally: + if container is not None: + await container.delete(force=True) diff --git a/tests/agent/docker/BUILD b/tests/agent/docker/BUILD new file mode 100644 index 0000000000..b446bff8fd --- /dev/null +++ b/tests/agent/docker/BUILD @@ -0,0 +1,6 @@ +python_tests( + name="tests", + dependencies=[ + "src/ai/backend/common:lib", + ], +) diff --git a/tests/agent/docker/test_agent.py b/tests/agent/docker/test_agent.py new file mode 100644 index 0000000000..0b9b8f76be --- /dev/null +++ b/tests/agent/docker/test_agent.py @@ -0,0 +1,211 @@ +import platform +import signal +from typing import ( + Any, + Mapping, +) +from unittest.mock import AsyncMock, MagicMock + +from aiodocker.exceptions import DockerError +import pytest + +from ai.backend.agent.docker.agent import DockerAgent +from ai.backend.common.exception import ImageNotAvailable +from ai.backend.common.types import AutoPullBehavior +from ai.backend.common.docker import ImageRef + + +class DummyEtcd: + async def get_prefix(self, key: str) -> Mapping[str, Any]: + pass + + +@pytest.fixture +async def agent(local_config, mocker): + dummy_etcd = DummyEtcd() + mocked_etcd_get_prefix = 
AsyncMock(return_value={}) + mocker.patch.object(dummy_etcd, 'get_prefix', new=mocked_etcd_get_prefix) + agent = await DockerAgent.new( + dummy_etcd, + local_config, + stats_monitor=None, + error_monitor=None, + skip_initial_scan=True, + ) # for faster test iteration + try: + yield agent + finally: + await agent.shutdown(signal.SIGTERM) + + +@pytest.mark.asyncio +async def test_init(agent, mocker): + print(agent) + +ret = platform.machine().lower() +aliases = { + "arm64": "aarch64", # macOS with LLVM + "amd64": "x86_64", # Windows/Linux + "x64": "x86_64", # Windows + "x32": "x86", # Windows + "i686": "x86", # Windows +} +arch = aliases.get(ret, ret) + +imgref = ImageRef( + 'index.docker.io/lablup/lua:5.3-alpine3.8', architecture=arch) +query_digest = "sha256:b000000000000000000000000000000000000000000000000000000000000001" +digest_matching_image_info = { + "Id": "sha256:b000000000000000000000000000000000000000000000000000000000000001", + "RepoTags": [ + "lablup/lua:5.3-alpine3.8", + ], +} +digest_mismatching_image_info = { + "Id": "sha256:a000000000000000000000000000000000000000000000000000000000000002", + "RepoTags": [ + "lablup/lua:5.3-alpine3.8", + ], +} + + +@pytest.mark.asyncio +async def test_auto_pull_digest_when_digest_matching(agent, mocker): + behavior = AutoPullBehavior.DIGEST + docker_mock = MagicMock() + docker_mock.close = AsyncMock() + docker_mock.images = MagicMock() + inspect_mock = AsyncMock(return_value=digest_matching_image_info) + docker_mock.images.inspect = inspect_mock + mocker.patch('ai.backend.agent.docker.agent.Docker', return_value=docker_mock) + pull = await agent.check_image(imgref, query_digest, behavior) + assert not pull + inspect_mock.assert_awaited_with(imgref.canonical) + + +@pytest.mark.asyncio +async def test_auto_pull_digest_when_digest_mismatching(agent, mocker): + behavior = AutoPullBehavior.DIGEST + docker_mock = MagicMock() + docker_mock.close = AsyncMock() + docker_mock.images = MagicMock() + inspect_mock = AsyncMock(return_value=digest_mismatching_image_info) + docker_mock.images.inspect = inspect_mock + mocker.patch('ai.backend.agent.docker.agent.Docker', return_value=docker_mock) + pull = await agent.check_image(imgref, query_digest, behavior) + assert pull + inspect_mock.assert_awaited_with(imgref.canonical) + + +@pytest.mark.asyncio +async def test_auto_pull_digest_when_missing(agent, mocker): + behavior = AutoPullBehavior.DIGEST + docker_mock = MagicMock() + docker_mock.close = AsyncMock() + docker_mock.images = MagicMock() + inspect_mock = AsyncMock( + side_effect=DockerError( + status=404, + data={'message': 'Simulated missing image'}, + ), + ) + docker_mock.images.inspect = inspect_mock + mocker.patch('ai.backend.agent.docker.agent.Docker', return_value=docker_mock) + pull = await agent.check_image(imgref, query_digest, behavior) + assert pull + inspect_mock.assert_called_with(imgref.canonical) + + +@pytest.mark.asyncio +async def test_auto_pull_tag_when_digest_matching(agent, mocker): + behavior = AutoPullBehavior.TAG + docker_mock = MagicMock() + docker_mock.close = AsyncMock() + docker_mock.images = MagicMock() + inspect_mock = AsyncMock(return_value=digest_matching_image_info) + docker_mock.images.inspect = inspect_mock + mocker.patch('ai.backend.agent.docker.agent.Docker', return_value=docker_mock) + pull = await agent.check_image(imgref, query_digest, behavior) + assert not pull + inspect_mock.assert_awaited_with(imgref.canonical) + + +@pytest.mark.asyncio +async def test_auto_pull_tag_when_digest_mismatching(agent, mocker): + 
behavior = AutoPullBehavior.TAG + docker_mock = MagicMock() + docker_mock.close = AsyncMock() + docker_mock.images = MagicMock() + inspect_mock = AsyncMock(return_value=digest_mismatching_image_info) + docker_mock.images.inspect = inspect_mock + mocker.patch('ai.backend.agent.docker.agent.Docker', return_value=docker_mock) + pull = await agent.check_image(imgref, query_digest, behavior) + assert not pull + inspect_mock.assert_awaited_with(imgref.canonical) + + +@pytest.mark.asyncio +async def test_auto_pull_tag_when_missing(agent, mocker): + behavior = AutoPullBehavior.TAG + docker_mock = MagicMock() + docker_mock.close = AsyncMock() + docker_mock.images = MagicMock() + inspect_mock = AsyncMock( + side_effect=DockerError( + status=404, + data={'message': 'Simulated missing image'}, + ), + ) + docker_mock.images.inspect = inspect_mock + mocker.patch('ai.backend.agent.docker.agent.Docker', return_value=docker_mock) + pull = await agent.check_image(imgref, query_digest, behavior) + assert pull + inspect_mock.assert_called_with(imgref.canonical) + + +@pytest.mark.asyncio +async def test_auto_pull_none_when_digest_matching(agent, mocker): + behavior = AutoPullBehavior.NONE + docker_mock = MagicMock() + docker_mock.close = AsyncMock() + docker_mock.images = MagicMock() + inspect_mock = AsyncMock(return_value=digest_matching_image_info) + docker_mock.images.inspect = inspect_mock + mocker.patch('ai.backend.agent.docker.agent.Docker', return_value=docker_mock) + pull = await agent.check_image(imgref, query_digest, behavior) + assert not pull + inspect_mock.assert_awaited_with(imgref.canonical) + + +@pytest.mark.asyncio +async def test_auto_pull_none_when_digest_mismatching(agent, mocker): + behavior = AutoPullBehavior.NONE + docker_mock = MagicMock() + docker_mock.close = AsyncMock() + docker_mock.images = MagicMock() + inspect_mock = AsyncMock(return_value=digest_mismatching_image_info) + docker_mock.images.inspect = inspect_mock + mocker.patch('ai.backend.agent.docker.agent.Docker', return_value=docker_mock) + pull = await agent.check_image(imgref, query_digest, behavior) + assert not pull + inspect_mock.assert_awaited_with(imgref.canonical) + + +@pytest.mark.asyncio +async def test_auto_pull_none_when_missing(agent, mocker): + behavior = AutoPullBehavior.NONE + docker_mock = MagicMock() + docker_mock.close = AsyncMock() + docker_mock.images = MagicMock() + inspect_mock = AsyncMock( + side_effect=DockerError( + status=404, + data={'message': 'Simulated missing image'}, + ), + ) + docker_mock.images.inspect = inspect_mock + mocker.patch('ai.backend.agent.docker.agent.Docker', return_value=docker_mock) + with pytest.raises(ImageNotAvailable) as e: + await agent.check_image(imgref, query_digest, behavior) + assert e.value.args[0] is imgref + inspect_mock.assert_called_with(imgref.canonical) diff --git a/tests/agent/test_agent.py b/tests/agent/test_agent.py new file mode 100644 index 0000000000..f66db52491 --- /dev/null +++ b/tests/agent/test_agent.py @@ -0,0 +1,64 @@ +''' +TODO: rewrite +''' + +import pytest + +from unittest.mock import AsyncMock + +from ai.backend.agent.server import AgentRPCServer + + +class Dummy: + pass + + +kgid = "kernel-gid" +kuid = "kernel-uid" +ctnr = "container" + + +@pytest.fixture +async def arpcs_no_ainit(test_id, redis_container): + etcd = Dummy() + etcd.get_prefix = None + ars = AgentRPCServer(etcd=etcd, local_config={ctnr: {}}, skip_detect_manager=True) + yield ars + + +@pytest.mark.asyncio +async def test_read_agent_config_container_invalid01(arpcs_no_ainit, 
mocker): + inspect_mock = AsyncMock(return_value={'a': 1, 'b': 2}) + mocker.patch.object(arpcs_no_ainit.etcd, 'get_prefix', new=inspect_mock) + await arpcs_no_ainit.read_agent_config_container() + assert kgid not in arpcs_no_ainit.local_config[ctnr] + assert kuid not in arpcs_no_ainit.local_config[ctnr] + + +@pytest.mark.asyncio +async def test_read_agent_config_container_invalid02(arpcs_no_ainit, mocker): + inspect_mock = AsyncMock(return_value={}) + mocker.patch.object(arpcs_no_ainit.etcd, 'get_prefix', new=inspect_mock) + await arpcs_no_ainit.read_agent_config_container() + assert kgid not in arpcs_no_ainit.local_config[ctnr] + assert kuid not in arpcs_no_ainit.local_config[ctnr] + + +@pytest.mark.asyncio +async def test_read_agent_config_container_1valid(arpcs_no_ainit, mocker): + inspect_mock = AsyncMock(return_value={kgid: 10}) + mocker.patch.object(arpcs_no_ainit.etcd, 'get_prefix', new=inspect_mock) + await arpcs_no_ainit.read_agent_config_container() + + assert arpcs_no_ainit.local_config[ctnr][kgid] == 10 + assert kuid not in arpcs_no_ainit.local_config[ctnr] + + +@pytest.mark.asyncio +async def test_read_agent_config_container_2valid(arpcs_no_ainit, mocker): + inspect_mock = AsyncMock(return_value={kgid: 10, kuid: 20}) + mocker.patch.object(arpcs_no_ainit.etcd, 'get_prefix', new=inspect_mock) + await arpcs_no_ainit.read_agent_config_container() + + assert arpcs_no_ainit.local_config[ctnr][kgid] == 10 + assert arpcs_no_ainit.local_config[ctnr][kuid] == 20 diff --git a/tests/agent/test_alloc_map.py b/tests/agent/test_alloc_map.py new file mode 100644 index 0000000000..1532280fff --- /dev/null +++ b/tests/agent/test_alloc_map.py @@ -0,0 +1,874 @@ +from decimal import Decimal, ROUND_DOWN + +import attr +import pytest +import random + +from ai.backend.agent.resources import ( + AbstractComputeDevice, + DeviceSlotInfo, + DiscretePropertyAllocMap, + FractionAllocMap, AllocationStrategy, +) +from ai.backend.agent.exception import ( + InsufficientResource, + InvalidResourceArgument, + InvalidResourceCombination, NotMultipleOfQuantum, +) +from ai.backend.common.types import ( + DeviceId, + SlotName, + SlotTypes, +) + + +@attr.s(auto_attribs=True) +class DummyDevice(AbstractComputeDevice): + pass + + +@pytest.mark.parametrize("alloc_strategy", [AllocationStrategy.FILL, AllocationStrategy.EVENLY]) +def test_discrete_alloc_map(alloc_strategy: AllocationStrategy): + alloc_map = DiscretePropertyAllocMap( + device_slots={ + DeviceId('a0'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(1)), + DeviceId('a1'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(1)), + }, + allocation_strategy=alloc_strategy, + ) + assert alloc_map.allocations[SlotName('x')][DeviceId('a0')] == 0 + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == 0 + + result = alloc_map.allocate({ + SlotName('x'): Decimal('1'), + }) + assert result[SlotName('x')][DeviceId('a0')] == 1 + assert DeviceId('a1') not in result[SlotName('x')] + assert alloc_map.allocations[SlotName('x')][DeviceId('a0')] == 1 + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == 0 + + with pytest.raises(InsufficientResource): + alloc_map.allocate({ + SlotName('x'): Decimal('3'), + }) + assert alloc_map.allocations[SlotName('x')][DeviceId('a0')] == 1 + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == 0 + + alloc_map.free(result) + assert alloc_map.allocations[SlotName('x')][DeviceId('a0')] == 0 + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == 0 + + +def 
test_discrete_alloc_map_large_number_fill(): + alloc_map = DiscretePropertyAllocMap( + device_slots={ + DeviceId('a0'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(100)), + DeviceId('a1'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(100)), + }, + allocation_strategy=AllocationStrategy.FILL, + ) + assert alloc_map.allocations[SlotName('x')][DeviceId('a0')] == 0 + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == 0 + + result = alloc_map.allocate({ + SlotName('x'): Decimal('130'), + }) + assert result[SlotName('x')][DeviceId('a0')] == 100 + assert result[SlotName('x')][DeviceId('a1')] == 30 + assert alloc_map.allocations[SlotName('x')][DeviceId('a0')] == 100 + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == 30 + + with pytest.raises(InsufficientResource): + alloc_map.allocate({ + SlotName('x'): Decimal('71'), + }) + assert alloc_map.allocations[SlotName('x')][DeviceId('a0')] == 100 + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == 30 + + alloc_map.free(result) + assert alloc_map.allocations[SlotName('x')][DeviceId('a0')] == 0 + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == 0 + + +def test_discrete_alloc_map_large_number_even(): + alloc_map = DiscretePropertyAllocMap( + device_slots={ + DeviceId('a0'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(100)), + DeviceId('a1'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(100)), + }, + allocation_strategy=AllocationStrategy.EVENLY, + ) + assert alloc_map.allocations[SlotName('x')][DeviceId('a0')] == 0 + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == 0 + + result1 = alloc_map.allocate({ + SlotName('x'): Decimal('130'), + }) + assert result1[SlotName('x')][DeviceId('a0')] == 65 + assert result1[SlotName('x')][DeviceId('a1')] == 65 + assert alloc_map.allocations[SlotName('x')][DeviceId('a0')] == 65 + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == 65 + + result2 = alloc_map.allocate({ + SlotName('x'): Decimal('15'), + }) + assert result2[SlotName('x')][DeviceId('a0')] == 8 + assert result2[SlotName('x')][DeviceId('a1')] == 7 + assert alloc_map.allocations[SlotName('x')][DeviceId('a0')] == 73 + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == 72 + + with pytest.raises(InsufficientResource): + alloc_map.allocate({ + SlotName('x'): Decimal('99'), + }) + assert alloc_map.allocations[SlotName('x')][DeviceId('a0')] == 73 + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == 72 + + alloc_map.free(result1) + alloc_map.free(result2) + assert alloc_map.allocations[SlotName('x')][DeviceId('a0')] == 0 + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == 0 + + +def test_discrete_alloc_map_even_to_tightly_fill(): + alloc_map = DiscretePropertyAllocMap( + device_slots={ + DeviceId('a0'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(10)), + DeviceId('a1'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(10)), + DeviceId('a2'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(10)), + }, + allocation_strategy=AllocationStrategy.EVENLY, + ) + assert alloc_map.allocations[SlotName('x')][DeviceId('a0')] == 0 + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == 0 + assert alloc_map.allocations[SlotName('x')][DeviceId('a2')] == 0 + + result1 = alloc_map.allocate({ + SlotName('x'): Decimal('7'), + }) + assert result1[SlotName('x')][DeviceId('a0')] == 3 + assert result1[SlotName('x')][DeviceId('a1')] == 2 + assert result1[SlotName('x')][DeviceId('a2')] == 2 + assert 
alloc_map.allocations[SlotName('x')][DeviceId('a0')] == 3 + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == 2 + assert alloc_map.allocations[SlotName('x')][DeviceId('a2')] == 2 + + result2 = alloc_map.allocate({ + SlotName('x'): Decimal('23'), + }) + assert result2[SlotName('x')][DeviceId('a0')] == 7 + assert result2[SlotName('x')][DeviceId('a1')] == 8 + assert result2[SlotName('x')][DeviceId('a2')] == 8 + assert alloc_map.allocations[SlotName('x')][DeviceId('a0')] == 10 + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == 10 + assert alloc_map.allocations[SlotName('x')][DeviceId('a2')] == 10 + + alloc_map.free(result1) + alloc_map.free(result2) + assert alloc_map.allocations[SlotName('x')][DeviceId('a0')] == 0 + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == 0 + assert alloc_map.allocations[SlotName('x')][DeviceId('a2')] == 0 + + +def test_discrete_alloc_map_cpu_even(): + alloc_map = DiscretePropertyAllocMap( + device_slots={ + DeviceId('cpu0'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('cpu'), Decimal(2)), + DeviceId('cpu1'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('cpu'), Decimal(2)), + DeviceId('cpu2'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('cpu'), Decimal(2)), + DeviceId('cpu3'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('cpu'), Decimal(2)), + }, + allocation_strategy=AllocationStrategy.EVENLY, + ) + + def check_clean(): + assert alloc_map.allocations[SlotName('cpu')][DeviceId('cpu0')] == 0 + assert alloc_map.allocations[SlotName('cpu')][DeviceId('cpu1')] == 0 + assert alloc_map.allocations[SlotName('cpu')][DeviceId('cpu2')] == 0 + assert alloc_map.allocations[SlotName('cpu')][DeviceId('cpu3')] == 0 + + check_clean() + + result1 = alloc_map.allocate({ + SlotName('cpu'): Decimal('4'), + }) + assert result1[SlotName('cpu')][DeviceId('cpu0')] == 1 + assert result1[SlotName('cpu')][DeviceId('cpu1')] == 1 + assert result1[SlotName('cpu')][DeviceId('cpu2')] == 1 + assert result1[SlotName('cpu')][DeviceId('cpu3')] == 1 + assert alloc_map.allocations[SlotName('cpu')][DeviceId('cpu0')] == 1 + assert alloc_map.allocations[SlotName('cpu')][DeviceId('cpu1')] == 1 + assert alloc_map.allocations[SlotName('cpu')][DeviceId('cpu2')] == 1 + assert alloc_map.allocations[SlotName('cpu')][DeviceId('cpu3')] == 1 + + result2 = alloc_map.allocate({ + SlotName('cpu'): Decimal('2'), + }) + assert result2[SlotName('cpu')][DeviceId('cpu0')] == 1 + assert result2[SlotName('cpu')][DeviceId('cpu1')] == 1 + assert alloc_map.allocations[SlotName('cpu')][DeviceId('cpu0')] == 2 + assert alloc_map.allocations[SlotName('cpu')][DeviceId('cpu1')] == 2 + assert alloc_map.allocations[SlotName('cpu')][DeviceId('cpu2')] == 1 + assert alloc_map.allocations[SlotName('cpu')][DeviceId('cpu3')] == 1 + + with pytest.raises(InsufficientResource): + alloc_map.allocate({ + SlotName('cpu'): Decimal('3'), + }) + assert alloc_map.allocations[SlotName('cpu')][DeviceId('cpu0')] == 2 + assert alloc_map.allocations[SlotName('cpu')][DeviceId('cpu1')] == 2 + assert alloc_map.allocations[SlotName('cpu')][DeviceId('cpu2')] == 1 + assert alloc_map.allocations[SlotName('cpu')][DeviceId('cpu3')] == 1 + + result3 = alloc_map.allocate({ + SlotName('cpu'): Decimal('2'), + }) + assert result3[SlotName('cpu')][DeviceId('cpu2')] == 1 + assert result3[SlotName('cpu')][DeviceId('cpu3')] == 1 + assert alloc_map.allocations[SlotName('cpu')][DeviceId('cpu0')] == 2 + assert alloc_map.allocations[SlotName('cpu')][DeviceId('cpu1')] == 2 + assert alloc_map.allocations[SlotName('cpu')][DeviceId('cpu2')] == 2 + assert 
alloc_map.allocations[SlotName('cpu')][DeviceId('cpu3')] == 2 + + alloc_map.free(result1) + alloc_map.free(result2) + alloc_map.free(result3) + check_clean() + + +def test_fraction_alloc_map(): + alloc_map = FractionAllocMap( + device_slots={ + DeviceId('a0'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(1.0)), + DeviceId('a1'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(1.0)), + }, + allocation_strategy=AllocationStrategy.FILL, + ) + assert alloc_map.allocations[SlotName('x')][DeviceId('a0')] == Decimal('0') + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == Decimal('0') + + result = alloc_map.allocate({ + SlotName('x'): Decimal('1.5'), + }) + assert result[SlotName('x')][DeviceId('a0')] == Decimal('1.0') + assert result[SlotName('x')][DeviceId('a1')] == Decimal('0.5') + assert alloc_map.allocations[SlotName('x')][DeviceId('a0')] == Decimal('1.0') + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == Decimal('0.5') + + with pytest.raises(InsufficientResource): + alloc_map.allocate({ + SlotName('x'): Decimal('1.5'), + }) + assert alloc_map.allocations[SlotName('x')][DeviceId('a0')] == Decimal('1.0') + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == Decimal('0.5') + + alloc_map.free(result) + assert alloc_map.allocations[SlotName('x')][DeviceId('a0')] == Decimal('0') + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == Decimal('0') + + +def test_fraction_alloc_map_many_device(): + alloc_map = FractionAllocMap( + device_slots={ + DeviceId('a0'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(1.0)), + DeviceId('a1'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(1.0)), + DeviceId('a2'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(1.0)), + DeviceId('a3'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(1.0)), + DeviceId('a4'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(1.0)), + DeviceId('a5'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(1.0)), + DeviceId('a6'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(1.0)), + DeviceId('a7'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(1.0)), + }, + allocation_strategy=AllocationStrategy.FILL, + ) + for idx in range(8): + assert alloc_map.allocations[SlotName('x')][DeviceId(f'a{idx}')] == Decimal('0') + + result = alloc_map.allocate({ + SlotName('x'): Decimal('7.95'), + }) + for idx in range(7): + assert result[SlotName('x')][DeviceId(f'a{idx}')] == Decimal('1.0') + assert result[SlotName('x')][DeviceId('a7')] == Decimal('0.95') + for idx in range(7): + assert alloc_map.allocations[SlotName('x')][DeviceId(f'a{idx}')] == Decimal('1.0') + assert alloc_map.allocations[SlotName('x')][DeviceId('a7')] == Decimal('0.95') + + with pytest.raises(InsufficientResource): + alloc_map.allocate({ + SlotName('x'): Decimal('1.0'), + }) + for idx in range(7): + assert alloc_map.allocations[SlotName('x')][DeviceId(f'a{idx}')] == Decimal('1.0') + assert alloc_map.allocations[SlotName('x')][DeviceId('a7')] == Decimal('0.95') + + alloc_map.free(result) + for idx in range(8): + assert alloc_map.allocations[SlotName('x')][DeviceId(f'a{idx}')] == Decimal('0') + + +def test_fraction_alloc_map_iteration(): + alloc_map = FractionAllocMap( + device_slots={ + DeviceId('a0'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(1.0)), + DeviceId('a1'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(1.0)), + }, + allocation_strategy=AllocationStrategy.FILL, + quantum_size=Decimal("0.00001"), + ) + assert 
alloc_map.allocations[SlotName('x')][DeviceId('a0')] == Decimal('0') + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == Decimal('0') + + for _ in range(1000): + alloc_map.allocate({ + SlotName('x'): Decimal('0.00001'), + }) + assert alloc_map.allocations[SlotName('x')][DeviceId('a0')] == Decimal('0.005') + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == Decimal('0.005') + + alloc_map.free({SlotName('x'): {DeviceId('a0'): Decimal('0.00001')}}) + assert alloc_map.allocations[SlotName('x')][DeviceId('a0')] == Decimal('0.00499') + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == Decimal('0.005') + + for _ in range(499): + alloc_map.free({SlotName('x'): {DeviceId('a0'): Decimal('0.00001')}}) + assert alloc_map.allocations[SlotName('x')][DeviceId('a0')] == Decimal('0') + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == Decimal('0.005') + + +def test_fraction_alloc_map_random_generated_allocations(): + alloc_map = FractionAllocMap( + device_slots={ + DeviceId('a0'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(1.0)), + DeviceId('a1'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(1.0)), + }, + allocation_strategy=AllocationStrategy.FILL, + ) + assert alloc_map.allocations[SlotName('x')][DeviceId('a0')] == Decimal('0') + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == Decimal('0') + + quantum = Decimal('.01') + for _ in range(5): + allocations = [] + for _ in range(10): + result = alloc_map.allocate({ + SlotName('x'): Decimal(random.uniform(0, 0.1)).quantize(quantum, ROUND_DOWN), + }) + allocations.append(result) + assert alloc_map.allocations[SlotName('x')][DeviceId('a0')] >= Decimal('0') + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] >= Decimal('0') + for a in allocations: + alloc_map.free(a) + assert alloc_map.allocations[SlotName('x')][DeviceId('a0')] == Decimal('0') + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == Decimal('0') + + +def test_fraction_alloc_map_even_allocation(): + alloc_map = FractionAllocMap( + device_slots={ + DeviceId('a0'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(0.05)), + DeviceId('a1'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(0.1)), + DeviceId('a2'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(0.2)), + DeviceId('a3'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(0.3)), + DeviceId('a4'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(0.0)), + }, + allocation_strategy=AllocationStrategy.EVENLY, + ) + for idx in range(5): + assert alloc_map.allocations[SlotName('x')][DeviceId(f'a{idx}')] == Decimal('0') + + with pytest.raises(InsufficientResource): + alloc_map.allocate({ + SlotName('x'): Decimal('0.66'), + }) + + with pytest.raises(InsufficientResource): + alloc_map.allocate({ + SlotName('x'): Decimal('0.06'), + }, min_memory=Decimal(0.6)) + for _ in range(20): + alloc_map.allocate({ + SlotName('x'): Decimal('0.01'), + }) + + assert alloc_map.allocations[SlotName('x')][DeviceId('a0')] == Decimal('0.05') + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == Decimal('0.1') + assert alloc_map.allocations[SlotName('x')][DeviceId('a2')] == Decimal('0.05') + alloc_map.free({SlotName('x'): {DeviceId('a0'): Decimal('0.05'), + DeviceId('a1'): Decimal('0.1'), + DeviceId('a2'): Decimal('0.05')}}) + for idx in range(0): + assert alloc_map.allocations[SlotName('x')][DeviceId(f'a{idx}')] == Decimal('0') + + result = alloc_map.allocate({ + SlotName('x'): Decimal('0.2'), + }) + assert 
alloc_map.allocations[SlotName('x')][DeviceId('a2')] == Decimal('0.2') + + alloc_map.free(result) + assert alloc_map.allocations[SlotName('x')][DeviceId('a2')] == Decimal('0') + + result = alloc_map.allocate({ + SlotName('x'): Decimal('0.2'), + }, min_memory=Decimal('0.25')) + assert alloc_map.allocations[SlotName('x')][DeviceId('a3')] == Decimal('0.2') + alloc_map.free(result) + for idx in range(5): + assert alloc_map.allocations[SlotName('x')][DeviceId(f'a{idx}')] == Decimal('0') + + result = alloc_map.allocate({ + SlotName('x'): Decimal('0.5'), + }) + assert alloc_map.allocations[SlotName('x')][DeviceId('a2')] == Decimal('0.2') + assert alloc_map.allocations[SlotName('x')][DeviceId('a3')] == Decimal('0.3') + alloc_map.free(result) + for idx in range(5): + assert alloc_map.allocations[SlotName('x')][DeviceId(f'a{idx}')] == Decimal('0') + + result = alloc_map.allocate({ + SlotName('x'): Decimal('0.65'), + }) + assert alloc_map.allocations[SlotName('x')][DeviceId('a0')] == Decimal('0.05') + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == Decimal('0.1') + assert alloc_map.allocations[SlotName('x')][DeviceId('a2')] == Decimal('0.2') + assert alloc_map.allocations[SlotName('x')][DeviceId('a3')] == Decimal('0.3') + alloc_map.free(result) + for idx in range(5): + assert alloc_map.allocations[SlotName('x')][DeviceId(f'a{idx}')] == Decimal('0') + + result = alloc_map.allocate({ + SlotName('x'): Decimal('0.6'), + }, min_memory=Decimal('0.1')) + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == Decimal('0.1') + assert alloc_map.allocations[SlotName('x')][DeviceId('a2')] == Decimal('0.2') + assert alloc_map.allocations[SlotName('x')][DeviceId('a3')] == Decimal('0.3') + alloc_map.free(result) + for idx in range(5): + assert alloc_map.allocations[SlotName('x')][DeviceId(f'a{idx}')] == Decimal('0') + + alloc_map = FractionAllocMap( + device_slots={ + DeviceId('a0'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal('0.3')), + DeviceId('a1'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal('0.3')), + DeviceId('a2'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal('0.9')), + }, + ) + result = alloc_map.allocate({ + SlotName('x'): Decimal('1'), + }) + assert alloc_map.allocations[SlotName('x')][DeviceId('a0')] == Decimal('0.3') + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == Decimal('0.3') + assert alloc_map.allocations[SlotName('x')][DeviceId('a2')] == Decimal('0.4') + + +def test_fraction_alloc_map_even_allocation_fractions(): + alloc_map = FractionAllocMap( + device_slots={ + DeviceId('a0'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal('0.8')), + DeviceId('a1'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal('0.75')), + DeviceId('a2'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal('0.7')), + DeviceId('a3'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal('0.3')), + DeviceId('a4'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal('0.0')), + }, + allocation_strategy=AllocationStrategy.EVENLY, + ) + result = alloc_map.allocate({ + SlotName('x'): Decimal('2.31'), + }) + assert alloc_map.allocations[SlotName('x')][DeviceId('a0')] == Decimal('0.67') + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == Decimal('0.67') + assert alloc_map.allocations[SlotName('x')][DeviceId('a2')] == Decimal('0.67') + assert alloc_map.allocations[SlotName('x')][DeviceId('a3')] == Decimal('0.3') + alloc_map.free(result) + for idx in range(4): + assert alloc_map.allocations[SlotName('x')][DeviceId(f'a{idx}')] == 
Decimal('0') + + result = alloc_map.allocate({ + SlotName('x'): Decimal('2'), + }) + assert alloc_map.allocations[SlotName('x')][DeviceId('a0')] == Decimal('0.67') + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == Decimal('0.67') + assert alloc_map.allocations[SlotName('x')][DeviceId('a2')] == Decimal('0.66') + alloc_map.free(result) + for idx in range(3): + assert alloc_map.allocations[SlotName('x')][DeviceId(f'a{idx}')] == Decimal('0') + + +def test_fraction_alloc_map_even_allocation_many_devices(): + alloc_map = FractionAllocMap( + device_slots={ + DeviceId('a0'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(2)), + DeviceId('a1'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(3)), + DeviceId('a2'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(3)), + DeviceId('a3'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(5)), + }, + allocation_strategy=AllocationStrategy.EVENLY, + ) + result = alloc_map.allocate({ + SlotName('x'): Decimal('6'), + }) + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == Decimal('3') + assert alloc_map.allocations[SlotName('x')][DeviceId('a2')] == Decimal('3') + alloc_map.free(result) + for idx in range(4): + assert alloc_map.allocations[SlotName('x')][DeviceId(f'a{idx}')] == Decimal('0') + + alloc_map = FractionAllocMap( + device_slots={ + DeviceId('a0'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(1)), + DeviceId('a1'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(1.5)), + DeviceId('a2'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(2)), + DeviceId('a3'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(3)), + DeviceId('a4'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(3)), + DeviceId('a5'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(4)), + DeviceId('a6'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(4.5)), + DeviceId('a7'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(5)), + DeviceId('a8'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(5)), + }, + allocation_strategy=AllocationStrategy.EVENLY, + ) + + result = alloc_map.allocate({ + SlotName('x'): Decimal('6'), + }, min_memory=Decimal('2.5')) + assert alloc_map.allocations[SlotName('x')][DeviceId('a3')] == Decimal('3') + assert alloc_map.allocations[SlotName('x')][DeviceId('a4')] == Decimal('3') + alloc_map.free(result) + for idx in range(9): + assert alloc_map.allocations[SlotName('x')][DeviceId(f'a{idx}')] == Decimal('0') + + result = alloc_map.allocate({ + SlotName('x'): Decimal('11'), + }, min_memory=Decimal('0.84')) + assert alloc_map.allocations[SlotName('x')][DeviceId('a3')] == Decimal('2.75') + assert alloc_map.allocations[SlotName('x')][DeviceId('a4')] == Decimal('2.75') + assert alloc_map.allocations[SlotName('x')][DeviceId('a5')] == Decimal('2.75') + assert alloc_map.allocations[SlotName('x')][DeviceId('a5')] == Decimal('2.75') + alloc_map.free(result) + for idx in range(9): + assert alloc_map.allocations[SlotName('x')][DeviceId(f'a{idx}')] == Decimal('0') + + +def test_fraction_alloc_map_even_allocation_many_devices_2(): + alloc_map = FractionAllocMap( + device_slots={ + DeviceId('a0'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal('1.0')), + DeviceId('a1'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal('1.0')), + DeviceId('a2'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal('1.0')), + DeviceId('a3'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal('1.0')), + DeviceId('a4'): DeviceSlotInfo(SlotTypes.COUNT, 
SlotName('x'), Decimal('1.0')), + DeviceId('a5'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal('1.0')), + DeviceId('a6'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal('1.0')), + DeviceId('a7'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal('1.0')), + }, + allocation_strategy=AllocationStrategy.EVENLY, + ) + result = alloc_map.allocate({ + SlotName('x'): Decimal('6'), + }) + count_0 = 0 + count_1 = 0 + # NOTE: the even allocator favors the tail of device list when it fills up. + # So we rely on the counting of desire per-device allocations instead of matching + # the device index and the allocations. + for idx in range(8): + if alloc_map.allocations[SlotName('x')][DeviceId(f'a{idx}')] == Decimal('1.0'): + count_1 += 1 + if alloc_map.allocations[SlotName('x')][DeviceId(f'a{idx}')] == Decimal('0'): + count_0 += 1 + assert count_0 == 2 + assert count_1 == 6 + alloc_map.free(result) + for idx in range(8): + assert alloc_map.allocations[SlotName('x')][DeviceId(f'a{idx}')] == Decimal('0') + + +@pytest.mark.parametrize( + "alloc_strategy", + [AllocationStrategy.FILL, AllocationStrategy.EVENLY], +) +def test_quantum_size(alloc_strategy): + alloc_map = FractionAllocMap( + device_slots={ + DeviceId('a0'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(1)), # noqa + DeviceId('a1'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(1)), # noqa + }, + quantum_size=Decimal("0.25"), + allocation_strategy=alloc_strategy, + ) + result = alloc_map.allocate({ + SlotName('x'): Decimal("0.5"), + }) + assert sum(alloc_map.allocations[SlotName('x')].values()) == Decimal("0.5") + alloc_map.free(result) + + result = alloc_map.allocate({ + SlotName('x'): Decimal("1.5"), + }) + assert sum(alloc_map.allocations[SlotName('x')].values()) == Decimal("1.5") + if alloc_strategy == AllocationStrategy.EVENLY: + assert alloc_map.allocations[SlotName('x')][DeviceId('a0')] == Decimal("0.75") + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == Decimal("0.75") + else: + assert alloc_map.allocations[SlotName('x')][DeviceId('a0')] == Decimal("1.00") + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == Decimal("0.50") + alloc_map.free(result) + + # input is below 0.25 + with pytest.raises(NotMultipleOfQuantum, match='actual calculated amount is zero'): + alloc_map.allocate({ + SlotName('x'): Decimal("0.24"), + }) + + if alloc_strategy == AllocationStrategy.EVENLY: + # input IS multiple of 0.25 but the CALCULATED allocations are not multiple of 0.25 + result = alloc_map.allocate({ + SlotName('x'): Decimal("1.75"), # divided to 0.88 and 0.87 + }) + assert sum(alloc_map.allocations[SlotName('x')].values()) == Decimal("1.5") + assert alloc_map.allocations[SlotName('x')][DeviceId('a0')] == Decimal('0.75') + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == Decimal('0.75') + alloc_map.free(result) + + # inputs are not multiple of 0.25 + result = alloc_map.allocate({ + SlotName('x'): Decimal("0.52"), + }) + assert sum(alloc_map.allocations[SlotName('x')].values()) == Decimal("0.5") + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == Decimal('0.5') + alloc_map.free(result) + + result = alloc_map.allocate({ + SlotName('x'): Decimal("0.42"), + }) + assert sum(alloc_map.allocations[SlotName('x')].values()) == Decimal("0.25") + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == Decimal('0.25') + alloc_map.free(result) + + with pytest.raises(InsufficientResource): + alloc_map.allocate({ + SlotName('x'): Decimal("3.99"), + }) + else: + # 
inputs are not multiple of 0.25 + result = alloc_map.allocate({ + SlotName('x'): Decimal("0.52"), + }) + assert sum(alloc_map.allocations[SlotName('x')].values()) == Decimal("0.5") + assert alloc_map.allocations[SlotName('x')][DeviceId('a0')] == Decimal('0.5') + alloc_map.free(result) + + result = alloc_map.allocate({ + SlotName('x'): Decimal("0.42"), + }) + assert sum(alloc_map.allocations[SlotName('x')].values()) == Decimal("0.25") + assert alloc_map.allocations[SlotName('x')][DeviceId('a0')] == Decimal('0.25') + alloc_map.free(result) + + with pytest.raises(InsufficientResource): + alloc_map.allocate({ + SlotName('x'): Decimal("3.99"), + }) + # In this case, it satisfies the quantum condition, because the capacity of devices are + # multiples of the quantum. + alloc_map.allocate({ + SlotName('x'): Decimal("1.75"), + }) + assert alloc_map.allocations[SlotName('x')][DeviceId('a0')] == Decimal("1.00") + assert alloc_map.allocations[SlotName('x')][DeviceId('a1')] == Decimal("0.75") + + # So let's change the situation. + alloc_map = FractionAllocMap( + device_slots={ + DeviceId('a0'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(1)), # noqa + DeviceId('a1'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('x'), Decimal(1)), # noqa + }, + quantum_size=Decimal("0.3"), + allocation_strategy=alloc_strategy, + ) + result = alloc_map.allocate({ + SlotName('x'): Decimal("0.5"), + }) + assert sum(alloc_map.allocations[SlotName('x')].values()) == Decimal("0.3") + assert alloc_map.allocations[SlotName('x')][DeviceId('a0')] == Decimal('0.3') + alloc_map.free(result) + + +def test_exclusive_resource_slots(): + alloc_map = DiscretePropertyAllocMap( + device_slots={ + DeviceId('a0'): DeviceSlotInfo(SlotTypes.UNIQUE, SlotName('cuda.device:1g.5gb-mig'), Decimal(1)), # noqa + DeviceId('a1'): DeviceSlotInfo(SlotTypes.UNIQUE, SlotName('cuda.device:1g.5gb-mig'), Decimal(1)), # noqa + DeviceId('a2'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('cuda.device'), Decimal(1)), + DeviceId('a3'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('cuda.device'), Decimal(1)), + DeviceId('a4'): DeviceSlotInfo(SlotTypes.UNIQUE, SlotName('cuda.device:3g.20gb-mig'), Decimal(1)), # noqa + }, + exclusive_slot_types={'cuda.device:*-mig', 'cuda.device', 'cuda.shares'}, + ) + + def check_clean(): + assert alloc_map.allocations[SlotName('cuda.device:1g.5gb-mig')][DeviceId('a0')] == Decimal('0') + assert alloc_map.allocations[SlotName('cuda.device:1g.5gb-mig')][DeviceId('a1')] == Decimal('0') + assert alloc_map.allocations[SlotName('cuda.device')][DeviceId('a2')] == Decimal('0') + assert alloc_map.allocations[SlotName('cuda.device')][DeviceId('a3')] == Decimal('0') + assert alloc_map.allocations[SlotName('cuda.device:3g.20gb-mig')][DeviceId('a4')] == Decimal('0') + + with pytest.raises(InvalidResourceCombination): + alloc_map.allocate({ + SlotName('cuda.device'): Decimal('2'), + SlotName('cuda.device:1g.5gb-mig'): Decimal('1'), + }) + check_clean() + + +def test_heterogeneous_resource_slots_with_discrete_alloc_map(): + alloc_map = DiscretePropertyAllocMap( + device_slots={ + DeviceId('a0'): DeviceSlotInfo(SlotTypes.UNIQUE, SlotName('cuda.device:1g.5gb-mig'), Decimal(1)), # noqa + DeviceId('a1'): DeviceSlotInfo(SlotTypes.UNIQUE, SlotName('cuda.device:1g.5gb-mig'), Decimal(1)), # noqa + DeviceId('a2'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('cuda.device'), Decimal(1)), + DeviceId('a3'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('cuda.device'), Decimal(1)), + DeviceId('a4'): DeviceSlotInfo(SlotTypes.UNIQUE, 
SlotName('cuda.device:3g.20gb-mig'), Decimal(1)), # noqa + }, + exclusive_slot_types={'cuda.device:*-mig', 'cuda.device', 'cuda.shares'}, + ) + + def check_clean(): + assert alloc_map.allocations[SlotName('cuda.device:1g.5gb-mig')][DeviceId('a0')] == Decimal('0') + assert alloc_map.allocations[SlotName('cuda.device:1g.5gb-mig')][DeviceId('a1')] == Decimal('0') + assert alloc_map.allocations[SlotName('cuda.device')][DeviceId('a2')] == Decimal('0') + assert alloc_map.allocations[SlotName('cuda.device')][DeviceId('a3')] == Decimal('0') + assert alloc_map.allocations[SlotName('cuda.device:3g.20gb-mig')][DeviceId('a4')] == Decimal('0') + + check_clean() + + # check allocation of non-unique slots + result = alloc_map.allocate({ + SlotName('cuda.device'): Decimal('2'), + }) + assert alloc_map.allocations[SlotName('cuda.device:1g.5gb-mig')][DeviceId('a0')] == Decimal('0') + assert alloc_map.allocations[SlotName('cuda.device:1g.5gb-mig')][DeviceId('a1')] == Decimal('0') + assert alloc_map.allocations[SlotName('cuda.device')][DeviceId('a2')] == Decimal('1') + assert alloc_map.allocations[SlotName('cuda.device')][DeviceId('a3')] == Decimal('1') + assert alloc_map.allocations[SlotName('cuda.device:3g.20gb-mig')][DeviceId('a4')] == Decimal('0') + alloc_map.free(result) + check_clean() + + with pytest.raises(InsufficientResource): + alloc_map.allocate({ + SlotName('cuda.device'): Decimal('3'), + }) + check_clean() + + # allocating zero means no-op. + alloc_map.allocate({ + SlotName('cuda.device:1g.5gb-mig'): Decimal('0'), + }) + check_clean() + + # any allocation request for unique slots should specify the amount 1. + with pytest.raises(InvalidResourceArgument): + alloc_map.allocate({ + SlotName('cuda.device:1g.5gb-mig'): Decimal('1.1'), + }) + with pytest.raises(InvalidResourceArgument): + alloc_map.allocate({ + SlotName('cuda.device:1g.5gb-mig'): Decimal('2'), + }) + check_clean() + + # test alloaction of unique slots + result1 = alloc_map.allocate({ + SlotName('cuda.device:1g.5gb-mig'): Decimal('1'), + }) + assert alloc_map.allocations[SlotName('cuda.device:1g.5gb-mig')][DeviceId('a0')] == Decimal('1') + assert alloc_map.allocations[SlotName('cuda.device:1g.5gb-mig')][DeviceId('a1')] == Decimal('0') + result2 = alloc_map.allocate({ + SlotName('cuda.device:1g.5gb-mig'): Decimal('1'), + }) + assert alloc_map.allocations[SlotName('cuda.device:1g.5gb-mig')][DeviceId('a0')] == Decimal('1') + assert alloc_map.allocations[SlotName('cuda.device:1g.5gb-mig')][DeviceId('a1')] == Decimal('1') + with pytest.raises(InsufficientResource): + alloc_map.allocate({ + SlotName('cuda.device:1g.5gb-mig'): Decimal('1'), + }) + alloc_map.free(result1) + alloc_map.free(result2) + check_clean() + + +def test_heterogeneous_resource_slots_with_fractional_alloc_map(): + alloc_map = FractionAllocMap( + device_slots={ + DeviceId('a0'): DeviceSlotInfo(SlotTypes.UNIQUE, SlotName('cuda.device:1g.5gb-mig'), Decimal(1)), # noqa + DeviceId('a1'): DeviceSlotInfo(SlotTypes.UNIQUE, SlotName('cuda.device:1g.5gb-mig'), Decimal(1)), # noqa + DeviceId('a2'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('cuda.shares'), Decimal('1.0')), + DeviceId('a3'): DeviceSlotInfo(SlotTypes.COUNT, SlotName('cuda.shares'), Decimal('1.0')), + DeviceId('a4'): DeviceSlotInfo(SlotTypes.UNIQUE, SlotName('cuda.device:3g.20gb-mig'), Decimal(1)), # noqa + }, + exclusive_slot_types={'cuda.device:*-mig', 'cuda.device', 'cuda.shares'}, + allocation_strategy=AllocationStrategy.FILL, + ) + + def check_clean(): + assert 
alloc_map.allocations[SlotName('cuda.device:1g.5gb-mig')][DeviceId('a0')] == Decimal('0') + assert alloc_map.allocations[SlotName('cuda.device:1g.5gb-mig')][DeviceId('a1')] == Decimal('0') + assert alloc_map.allocations[SlotName('cuda.shares')][DeviceId('a2')] == Decimal('0') + assert alloc_map.allocations[SlotName('cuda.shares')][DeviceId('a3')] == Decimal('0') + assert alloc_map.allocations[SlotName('cuda.device:3g.20gb-mig')][DeviceId('a4')] == Decimal('0') + + check_clean() + + # check allocation of non-unique slots + result = alloc_map.allocate({ + SlotName('cuda.shares'): Decimal('2.0'), + }) + assert alloc_map.allocations[SlotName('cuda.device:1g.5gb-mig')][DeviceId('a0')] == Decimal('0') + assert alloc_map.allocations[SlotName('cuda.device:1g.5gb-mig')][DeviceId('a1')] == Decimal('0') + assert alloc_map.allocations[SlotName('cuda.shares')][DeviceId('a2')] == Decimal('1.0') + assert alloc_map.allocations[SlotName('cuda.shares')][DeviceId('a3')] == Decimal('1.0') + assert alloc_map.allocations[SlotName('cuda.device:3g.20gb-mig')][DeviceId('a4')] == Decimal('0') + alloc_map.free(result) + check_clean() + + with pytest.raises(InsufficientResource): + alloc_map.allocate({ + SlotName('cuda.shares'): Decimal('2.5'), + }) + check_clean() + + # allocating zero means no-op. + alloc_map.allocate({ + SlotName('cuda.device:1g.5gb-mig'): Decimal('0'), + }) + check_clean() + + # any allocation request for unique slots should specify the amount 1. + with pytest.raises(InvalidResourceArgument): + alloc_map.allocate({ + SlotName('cuda.device:1g.5gb-mig'): Decimal('0.3'), + }) + with pytest.raises(InvalidResourceArgument): + alloc_map.allocate({ + SlotName('cuda.device:1g.5gb-mig'): Decimal('1.5'), + }) + check_clean() + + # test alloaction of unique slots + result1 = alloc_map.allocate({ + SlotName('cuda.device:1g.5gb-mig'): Decimal('1'), + }) + assert alloc_map.allocations[SlotName('cuda.device:1g.5gb-mig')][DeviceId('a0')] == Decimal('1') + assert alloc_map.allocations[SlotName('cuda.device:1g.5gb-mig')][DeviceId('a1')] == Decimal('0') + result2 = alloc_map.allocate({ + SlotName('cuda.device:1g.5gb-mig'): Decimal('1'), + }) + assert alloc_map.allocations[SlotName('cuda.device:1g.5gb-mig')][DeviceId('a0')] == Decimal('1') + assert alloc_map.allocations[SlotName('cuda.device:1g.5gb-mig')][DeviceId('a1')] == Decimal('1') + with pytest.raises(InsufficientResource): + alloc_map.allocate({ + SlotName('cuda.device:1g.5gb-mig'): Decimal('1'), + }) + alloc_map.free(result1) + alloc_map.free(result2) + check_clean() diff --git a/tests/agent/test_files.py b/tests/agent/test_files.py new file mode 100644 index 0000000000..87973b354b --- /dev/null +++ b/tests/agent/test_files.py @@ -0,0 +1,105 @@ +import os +from pathlib import Path +import tempfile + +from ai.backend.agent.docker.files import ( + scandir, diff_file_stats, +) + + +def test_scandir(): + # Create two files. 
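+    # The second file is given an mtime 5 seconds newer than the first,
+    # so the two returned stat values must differ by exactly 5.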
+ with tempfile.TemporaryDirectory() as tmpdir: + first = Path(tmpdir) / 'first.txt' + first.write_text('first') + second = Path(tmpdir) / 'second.txt' + second.write_text('second') + new_time = first.stat().st_mtime + 5 + os.utime(second, (new_time, new_time)) + + file_stats = scandir(Path(tmpdir), 1000) + + assert len(file_stats) == 2 + assert int(file_stats[second]) == int(file_stats[first]) + 5 + + +def test_scandir_skip_hidden_files(): + with tempfile.TemporaryDirectory() as tmpdir: + file = Path(tmpdir) / '.hidden_file' + file.write_text('dark templar') + file_stats = scandir(Path(tmpdir), 1000) + + assert len(file_stats) == 0 + + +def test_scandir_skip_large_files(): + with tempfile.TemporaryDirectory() as tmpdir: + file = Path(tmpdir) / 'file.jpg' + file.write_text('large file') + file_stats = scandir(Path(tmpdir), 1) + + assert len(file_stats) == 0 + + +def test_scandir_returns_files_in_sub_folder(): + with tempfile.TemporaryDirectory() as tmpdir: + sub_folder = Path(tmpdir) / 'sub' + sub_folder.mkdir() + sub_file = sub_folder / 'sub-file.txt' + sub_file.write_text('somedata') + + file_stats = scandir(Path(tmpdir), 1000) + + assert len(file_stats) == 1 + + +def test_get_new_file_diff_stats(): + with tempfile.TemporaryDirectory() as tmpdir: + first = Path(tmpdir) / 'first.txt' + first.write_text('first') + fs1 = scandir(tmpdir, 1000) + + second = Path(tmpdir) / 'second.txt' + second.write_text('second') + fs2 = scandir(tmpdir, 1000) + + diff_stats = diff_file_stats(fs1, fs2) + + assert first not in diff_stats + assert second in diff_stats + + +def test_get_modified_file_diff_stats(): + with tempfile.TemporaryDirectory() as tmpdir: + first = Path(tmpdir) / 'first.txt' + first.write_text('first') + second = Path(tmpdir) / 'second.txt' + second.write_text('second') + fs1 = scandir(tmpdir, 1000) + + new_time = first.stat().st_mtime + 5 + os.utime(second, (new_time, new_time)) + fs2 = scandir(tmpdir, 1000) + + diff_stats = diff_file_stats(fs1, fs2) + + assert first not in diff_stats + assert second in diff_stats + + +def test_get_both_new_and_modified_files_stat(): + with tempfile.TemporaryDirectory() as tmpdir: + first = Path(tmpdir) / 'first.txt' + first.write_text('first') + fs1 = scandir(tmpdir, 1000) + + new_time = first.stat().st_mtime + 5 + os.utime(first, (new_time, new_time)) + second = Path(tmpdir) / 'second.txt' + second.write_text('second') + fs2 = scandir(tmpdir, 1000) + + diff_stats = diff_file_stats(fs1, fs2) + + assert first in diff_stats + assert second in diff_stats diff --git a/tests/agent/test_kernel.py b/tests/agent/test_kernel.py new file mode 100644 index 0000000000..a935a92675 --- /dev/null +++ b/tests/agent/test_kernel.py @@ -0,0 +1,145 @@ +import pytest + +from ai.backend.agent.kernel import ( + match_distro_data, +) + + +def test_match_distro_data(): + krunner_volumes = { + 'ubuntu8.04': 'u1', + 'ubuntu18.04': 'u2', + 'centos7.6': 'c1', + 'centos8.0': 'c2', + 'centos5.0': 'c3', + } + + ret = match_distro_data(krunner_volumes, 'centos7.6') + assert ret[0] == 'centos7.6' + assert ret[1] == 'c1' + + ret = match_distro_data(krunner_volumes, 'centos8.0') + assert ret[0] == 'centos8.0' + assert ret[1] == 'c2' + + ret = match_distro_data(krunner_volumes, 'centos') + assert ret[0] == 'centos8.0' # assume latest + assert ret[1] == 'c2' + + ret = match_distro_data(krunner_volumes, 'ubuntu18.04') + assert ret[0] == 'ubuntu18.04' + assert ret[1] == 'u2' + + ret = match_distro_data(krunner_volumes, 'ubuntu20.04') + assert ret[0] == 'ubuntu18.04' + assert ret[1] == 'u2' + + 
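+    # A bare distro name without a version should resolve to the latest matching volume.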
ret = match_distro_data(krunner_volumes, 'ubuntu') + assert ret[0] == 'ubuntu18.04' # assume latest + assert ret[1] == 'u2' + + with pytest.raises(RuntimeError): + match_distro_data(krunner_volumes, 'ubnt') + + with pytest.raises(RuntimeError): + match_distro_data(krunner_volumes, 'xyz') + + +def test_match_distro_data_with_libc_based_krunners(): + krunner_volumes = { + 'static-gnu': 'x1', + 'static-musl': 'x2', + } + + # when there are static builds, it returns the distro name as-is + # and only distinguish the libc flavor (gnu or musl). + + ret = match_distro_data(krunner_volumes, 'centos7.6') + assert ret[0] == 'static-gnu' + assert ret[1] == 'x1' + + ret = match_distro_data(krunner_volumes, 'centos8.0') + assert ret[0] == 'static-gnu' + assert ret[1] == 'x1' + + ret = match_distro_data(krunner_volumes, 'centos') + assert ret[0] == 'static-gnu' + assert ret[1] == 'x1' + + ret = match_distro_data(krunner_volumes, 'ubuntu18.04') + assert ret[0] == 'static-gnu' + assert ret[1] == 'x1' + + ret = match_distro_data(krunner_volumes, 'ubuntu') + assert ret[0] == 'static-gnu' + assert ret[1] == 'x1' + + ret = match_distro_data(krunner_volumes, 'alpine3.8') + assert ret[0] == 'static-musl' + assert ret[1] == 'x2' + + ret = match_distro_data(krunner_volumes, 'alpine') + assert ret[0] == 'static-musl' + assert ret[1] == 'x2' + + ret = match_distro_data(krunner_volumes, 'alpine3.11') + assert ret[0] == 'static-musl' + assert ret[1] == 'x2' + + # static-gnu works as a generic fallback in all unknown distributions + ret = match_distro_data(krunner_volumes, 'ubnt') + assert ret[0] == 'static-gnu' + assert ret[1] == 'x1' + + ret = match_distro_data(krunner_volumes, 'xyz') + assert ret[0] == 'static-gnu' + assert ret[1] == 'x1' + + +def test_match_distro_data_with_libc_based_krunners_mixed(): + krunner_volumes = { + 'static-gnu': 'x1', + 'alpine3.8': 'c1', + 'alpine3.11': 'c2', + } + + ret = match_distro_data(krunner_volumes, 'centos7.6') + assert ret[0] == 'static-gnu' + assert ret[1] == 'x1' + + ret = match_distro_data(krunner_volumes, 'centos8.0') + assert ret[0] == 'static-gnu' + assert ret[1] == 'x1' + + ret = match_distro_data(krunner_volumes, 'centos') + assert ret[0] == 'static-gnu' + assert ret[1] == 'x1' + + ret = match_distro_data(krunner_volumes, 'ubuntu18.04') + assert ret[0] == 'static-gnu' + assert ret[1] == 'x1' + + ret = match_distro_data(krunner_volumes, 'ubuntu') + assert ret[0] == 'static-gnu' + assert ret[1] == 'x1' + + ret = match_distro_data(krunner_volumes, 'alpine3.8') + assert ret[0] == 'alpine3.8' + assert ret[1] == 'c1' + + ret = match_distro_data(krunner_volumes, 'alpine') + assert ret[0] == 'alpine3.11' # assume latest + assert ret[1] == 'c2' + + ret = match_distro_data(krunner_volumes, 'alpine3.11') + assert ret[0] == 'alpine3.11' + assert ret[1] == 'c2' + + # static-gnu works as a generic fallback in all unknown distributions + ret = match_distro_data(krunner_volumes, 'ubnt') + assert ret[0] == 'static-gnu' + assert ret[1] == 'x1' + + ret = match_distro_data(krunner_volumes, 'xyz') + assert ret[0] == 'static-gnu' + assert ret[1] == 'x1' diff --git a/tests/agent/test_resources.py b/tests/agent/test_resources.py new file mode 100644 index 0000000000..ef0b1777b0 --- /dev/null +++ b/tests/agent/test_resources.py @@ -0,0 +1,91 @@ +import json +from unittest import mock + +import pytest + +from aioresponses import aioresponses +from ai.backend.agent.vendor import linux + +# TODO: write tests for KernelResourceSpec (read/write consistency) +# from ai.backend.agent.resources 
import ( +# KernelResourceSpec, +# ) + +# TODO: write tests for DiscretePropertyAllocMap, FractionAllocMap + + +def test_node_of_cpu(): + numa = linux.libnuma() + + # When NUMA is not supported. + linux._numa_supported = False + assert numa.node_of_cpu(5) == 0 + + # When NUMA is supported. + original_numa_supported = linux._numa_supported + linux._numa_supported = True + with mock.patch.object(linux, '_libnuma', create=True) \ + as mock_libnuma: + numa.node_of_cpu(5) + mock_libnuma.numa_node_of_cpu.assert_called_once_with(5) + + linux._numa_supported = original_numa_supported + + +def test_num_nodes(): + numa = linux.libnuma() + + # When NUMA is not supported. + linux._numa_supported = False + assert numa.num_nodes() == 1 + + # When NUMA is supported. + original_numa_supported = linux._numa_supported + linux._numa_supported = True + with mock.patch.object(linux, '_libnuma', create=True) \ + as mock_libnuma: + numa.num_nodes() + mock_libnuma.numa_num_configured_nodes.assert_called_once_with() + + linux._numa_supported = original_numa_supported + + +@pytest.mark.skip(reason='aioresponses 0.7 is incompatible with aiohttp 3.7+') +@pytest.mark.asyncio +async def test_get_available_cores_without_docker(monkeypatch): + + def mock_sched_getaffinity(pid): + raise AttributeError + + def mock_sched_getaffinity2(pid): + return {0, 1} + + numa = linux.libnuma() + with aioresponses() as m: + m.get('http://docker/info', body=json.dumps({ + 'NCPU': 4, + })) + + monkeypatch.setattr(linux.os, 'sched_getaffinity', + mock_sched_getaffinity, + raising=False) + monkeypatch.setattr(linux.os, 'cpu_count', lambda: 4) + numa.get_available_cores.cache_clear() + assert (await numa.get_available_cores()) == {0, 1, 2, 3} + + monkeypatch.setattr(linux.os, 'sched_getaffinity', + mock_sched_getaffinity2, + raising=False) + numa.get_available_cores.cache_clear() + assert (await numa.get_available_cores()) == {0, 1} + + +@pytest.mark.asyncio +async def test_get_core_topology(mocker): + mocker.patch.object(linux.libnuma, 'num_nodes', return_value=2) + mocker.patch.object(linux.libnuma, 'get_available_cores', + new=mock.AsyncMock(return_value={0, 1, 2, 3})) + mocker.patch.object(linux.libnuma, 'node_of_cpu', new=lambda n: n % 2 == 1) + + numa = linux.libnuma() + assert (await numa.get_core_topology()) == ([0, 2], [1, 3]) diff --git a/tests/agent/test_server.py b/tests/agent/test_server.py new file mode 100644 index 0000000000..ac8dc46b6d --- /dev/null +++ b/tests/agent/test_server.py @@ -0,0 +1,370 @@ +# import asyncio + +# import pytest + +# from ai.backend.agent.server import ( +# AgentRPCServer, +# ) + + +def test_dummy(): + # prevent pants error due to pytest exit code 5: "no tests collected" + pass + + +# TODO: rewrite +''' +@pytest.fixture +async def agent(request, tmpdir, event_loop): + config = argparse.Namespace() + config.namespace = os.environ.get('BACKEND_NAMESPACE', 'testing') + config.agent_host = '127.0.0.1' + config.agent_port = 6001 # default 6001 + config.stat_port = 6002 + config.kernel_host_override = '127.0.0.1' + etcd_addr = os.environ.get('BACKEND_ETCD_ADDR', '127.0.0.1:2379') + redis_addr = os.environ.get('BACKEND_REDIS_ADDR', '127.0.0.1:6379') + config.etcd_addr = host_port_pair(etcd_addr) + config.redis_addr = host_port_pair(redis_addr) + config.event_addr = '127.0.0.1:5000' # dummy value + config.docker_registry = 'lablup' + config.debug = True + config.debug_kernel = None + config.kernel_aliases = None + config.scratch_root = Path(tmpdir) + config.limit_cpus = None + config.limit_gpus = None + 
config.debug_kernel = None + config.debug_hook = None + config.debug_jail = None + config.debug_skip_container_deletion = False + + agent = None + + config.instance_id = await identity.get_instance_id() + config.inst_type = await identity.get_instance_type() + config.region = await identity.get_instance_region() + print(f'serving test agent: {config.instance_id} ({config.inst_type}),' + f' ip: {config.agent_host}') + agent = AgentRPCServer(config, loop=event_loop) + await agent.init(skip_detect_manager=True) + await asyncio.sleep(0) + + yield agent + + print('shutting down test agent...') + if agent: + await agent.shutdown() + await asyncio.sleep(3) + + +@pytest.mark.asyncio +async def test_get_extra_volumes(docker): + # No extra volumes + mnt_list = await get_extra_volumes(docker, 'python:latest') + assert len(mnt_list) == 0 + + # Create fake deeplearning sample volume and check it will be returned + vol = None + try: + config = {'Name': 'deeplearning-samples'} + vol = await docker.volumes.create(config) + mnt_list = await get_extra_volumes(docker, 'python-tensorflow:latest') + finally: + if vol: + await vol.delete() + + assert len(mnt_list) == 1 + assert mnt_list[0].name == 'deeplearning-samples' + + +@pytest.mark.asyncio +async def test_get_kernel_id_from_container(docker, container): + container_list = await docker.containers.list() + kid = await get_kernel_id_from_container(container_list[0]) + + assert kid == 'test-container' # defined as in the fixture + + +@pytest.fixture +async def kernel_info(agent, docker): + kernel_id = str(uuid.uuid4()) + config = { + 'lang': 'lua:5.3-alpine', + 'limits': {'cpu_slot': 1, 'gpu_slot': 0, 'mem_slot': 1, 'tpu_slot': 0}, + 'mounts': [], + 'environ': {}, + } + kernel_info = await agent.create_kernel(kernel_id, config) + + try: + yield kernel_info + finally: + if kernel_info['id'] in agent.container_registry: + # Container id may be changed (e.g. restarting kernel), so we + # should not rely on the initial value of the container_id. + container_info = agent.container_registry[kernel_info['id']] + container_id = container_info['container_id'] + else: + # If fallback to initial container_id if kernel is deleted. 
+ container_id = kernel_info['container_id'] + try: + container = docker.containers.container(container_id) + cinfo = await container.show() if container else None + except aiodocker.exceptions.DockerError: + cinfo = None + if cinfo and cinfo['State']['Status'] != 'removing': + await container.delete(force=True) + + +@pytest.mark.integration +def test_ping(agent): + ret = agent.ping('ping~') + assert ret == 'ping~' + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_scan_running_containers(agent, kernel_info, docker): + agent.container_registry.clear() + assert kernel_info['id'] not in agent.container_registry + await agent.scan_running_containers() + assert agent.container_registry[kernel_info['id']] + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_create_kernel(agent, docker): + kernel_id = str(uuid.uuid4()) + config = { + 'lang': 'lablup/lua:5.3-alpine', + 'limits': {'cpu_slot': 1, 'gpu_slot': 0, 'mem_slot': 1, 'tpu_slot': 0}, + 'mounts': [], + 'environ': {}, + } + + kernel_info = container_info = None + try: + kernel_info = await agent.create_kernel(kernel_id, config) + container_info = agent.container_registry[kernel_id] + finally: + container = docker.containers.container(kernel_info['container_id']) + await container.delete(force=True) + + assert kernel_info + assert container_info + assert kernel_info['id'] == kernel_id + # TODO: rewrite using resource_spec: + # assert len(kernel_info['cpu_set']) == 1 + assert container_info['lang'] == config['lang'] + assert container_info['container_id'] == kernel_info['container_id'] + # TODO: rewrite using resource_spec: + # assert container_info['limits'] == config['limits'] + # assert container_info['mounts'] == config['mounts'] + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_destroy_kernel(agent, kernel_info): + stat = await agent.destroy_kernel(kernel_info['id']) + + assert stat + assert 'cpu_used' in stat + assert 'mem_max_bytes' in stat + assert 'mem_cur_bytes' in stat + assert 'net_rx_bytes' in stat + assert 'net_tx_bytes' in stat + assert 'io_read_bytes' in stat + assert 'io_write_bytes' in stat + assert 'io_max_scratch_size' in stat + assert 'io_cur_scratch_size' in stat + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_restart_kernel(agent, kernel_info): + kernel_id = kernel_info['id'] + container_id = kernel_info['container_id'] + new_config = { + 'lang': 'lablup/lua:5.3-alpine', + 'limits': {'cpu_slot': 1, 'gpu_slot': 0, 'mem_slot': 1, 'tpu_slot': 0}, + 'mounts': [], + } + + ret = await agent.restart_kernel(kernel_id, new_config) + + assert container_id != ret['container_id'] + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_restart_kernel_cancel_code_execution( + agent, kernel_info, event_loop): + async def execute_code(): + nonlocal kernel_info + api_ver = 2 + kid = kernel_info['id'] + runid = 'test-run-id' + mode = 'query' + code = ('local clock = os.clock\n' + 'function sleep(n)\n' + ' local t0 = clock()\n' + ' while clock() - t0 <= n do end\n' + 'end\n' + 'sleep(10)\nprint("code executed")') + while True: + ret = await agent.execute(api_ver, kid, runid, mode, code, {}) + if ret is None: + break + elif ret['status'] == 'finished': + break + elif ret['status'] == 'continued': + mode = 'continue', + code = '' + else: + raise Exception('Invalid execution status') + return ret + + async def restart_kernel(): + nonlocal kernel_info + kernel_id = kernel_info['id'] + new_config = { + 'lang': 'lablup/lua:5.3-alpine', + 'limits': {'cpu_slot': 1, 
'gpu_slot': 0, 'mem_slot': 1, 'tpu_slot': 0}, + 'mounts': [], + } + await agent.restart_kernel(kernel_id, new_config) + + t1 = asyncio.ensure_future(execute_code(), loop=event_loop) + start = datetime.now() + await asyncio.sleep(1) + t2 = asyncio.ensure_future(restart_kernel(), loop=event_loop) + results = await asyncio.gather(t1, t2) + end = datetime.now() + + assert results[0] is None # no execution result + assert (end - start).total_seconds() < 10 + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_execute(agent, kernel_info): + # Test with lua:5.3-alpine image only + api_ver = 2 + kid = kernel_info['id'] + runid = 'test-run-id' + mode = 'query' + code = 'print(17)' + + while True: + ret = await agent.execute(api_ver, kid, runid, mode, code, {}) + if ret['status'] == 'finished': + break + elif ret['status'] == 'continued': + mode = 'continue', + code = '' + else: + raise Exception('Invalid execution status') + + assert ret['console'][0][0] == 'stdout' + assert ret['console'][0][1] == '17\n' + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_execute_batch_mode(agent, kernel_info): + # Test with lua:5.3-alpine image only + api_ver = 2 + kid = kernel_info['id'] + runid = 'test-run-id' + mode = 'batch' + code = '' + opt = {'clean': '*', + 'build': '*', + 'exec': '*'} + + # clean_finished = False + build_finished = False + + await agent.upload_file(kid, 'main.lua', b'print(17)') + while True: + ret = await agent.execute(api_ver, kid, runid, mode, code, opt) + if ret['status'] == 'finished': + # assert clean_finished and build_finished + assert build_finished + break + # elif ret['status'] == 'clean-finished': + # assert not clean_finished and not build_finished + # clean_finished = True + # mode = 'continue' + elif ret['status'] == 'build-finished': + # assert clean_finished and not build_finished + assert not build_finished + build_finished = True + mode = 'continue' + elif ret['status'] == 'continued': + mode = 'continue' + else: + raise Exception('Invalid execution status') + + assert ret['console'][0][0] == 'stdout' + assert ret['console'][0][1] == '17\n' + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_upload_file(agent, kernel_info): + fname = 'test.txt' + await agent.upload_file(kernel_info['id'], fname, b'test content') + uploaded_to = agent.config.scratch_root / kernel_info['id'] / '.work' / fname + assert uploaded_to.exists() + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_reset(agent, docker): + kernel_ids = [] + container_ids = [] + config = { + 'lang': 'lablup/lua:5.3-alpine', + 'limits': {'cpu_slot': 1, 'gpu_slot': 0, 'mem_slot': 1, 'tpu_slot': 0}, + 'mounts': [], + } + + try: + # Create two kernels + for i in range(2): + kid = str(uuid.uuid4()) + kernel_ids.append(kid) + info = await agent.create_kernel(kid, config) + container_ids.append(info['container_id']) + + # 2 containers are created + assert docker.containers.container(container_ids[0]) + assert docker.containers.container(container_ids[1]) + + await agent.reset() + + # Containers are destroyed + with pytest.raises(aiodocker.exceptions.DockerError): + c1 = docker.containers.container(container_ids[0]) + c1info = await c1.show() + if c1info['State']['Status'] == 'removing': + raise aiodocker.exceptions.DockerError( + 404, {'message': 'success'}) + with pytest.raises(aiodocker.exceptions.DockerError): + c2 = docker.containers.container(container_ids[1]) + c2info = await c2.show() + if c2info['State']['Status'] == 'removing': + raise 
aiodocker.exceptions.DockerError( + 404, {'message': 'success'}) + finally: + for cid in container_ids: + try: + container = docker.containers.container(cid) + cinfo = await container.show() if container else None + except aiodocker.exceptions.DockerError: + cinfo = None + if cinfo and cinfo['State']['Status'] != 'removing': + await container.delete(force=True) +''' diff --git a/tests/agent/test_stats.py b/tests/agent/test_stats.py new file mode 100644 index 0000000000..e0d3f21446 --- /dev/null +++ b/tests/agent/test_stats.py @@ -0,0 +1,5 @@ +# currently empty + +def test_dummy(): + # prevent pants error due to pytest exit code 5: "no tests collected" + pass diff --git a/tests/agent/test_utils.py b/tests/agent/test_utils.py new file mode 100644 index 0000000000..365c06a4eb --- /dev/null +++ b/tests/agent/test_utils.py @@ -0,0 +1,90 @@ +import tempfile + +import pytest + +from ai.backend.agent import utils + + +def test_read_sysfs(): + with tempfile.NamedTemporaryFile('w') as f: + f.write('10') + f.flush() + val = utils.read_sysfs(f.name, int) + assert isinstance(val, int) + assert val == 10 + val = utils.read_sysfs(f.name, str) + assert isinstance(val, str) + assert val == '10' + val = utils.read_sysfs(f.name, float) + assert isinstance(val, float) + assert val == 10.0 + + with tempfile.NamedTemporaryFile('w') as f: + f.write('1') + f.flush() + val = utils.read_sysfs(f.name, bool) + assert isinstance(val, bool) + assert val is True + f.seek(0, 0) + f.write('0') + f.flush() + val = utils.read_sysfs(f.name, bool) + assert isinstance(val, bool) + assert val is False + + val = utils.read_sysfs('/tmp/xxxxx-non-existent-file', int) + assert isinstance(val, int) + assert val == 0 + + val = utils.read_sysfs('/tmp/xxxxx-non-existent-file', int, -1) + assert isinstance(val, int) + assert val == -1 + + with pytest.raises(TypeError): + val = utils.read_sysfs('/tmp/xxxxx-non-existent-file', object) + + with pytest.raises(TypeError): + val = utils.read_sysfs('/tmp/xxxxx-non-existent-file', object, -1) + + +def test_update_nested_dict(): + o = { + 'a': 1, + 'b': 2, + } + utils.update_nested_dict(o, {'a': 3, 'c': 4}) + assert o == { + 'a': 3, + 'b': 2, + 'c': 4, + } + + o = { + 'a': { + 'x': 1, + }, + 'b': 2, + } + with pytest.raises(AssertionError): + utils.update_nested_dict(o, {'a': 3}) + + o = { + 'a': { + 'x': 1, + }, + 'b': 2, + } + utils.update_nested_dict(o, {'a': {'x': 3, 'y': 4}, 'b': 5}) + assert o['a'] == { + 'x': 3, + 'y': 4, + } + assert o['b'] == 5 + + o = { + 'a': [1, 2], + 'b': 3, + } + utils.update_nested_dict(o, {'a': [4, 5], 'b': 6}) + assert o['a'] == [1, 2, 4, 5] + assert o['b'] == 6 diff --git a/tests/common/BUILD b/tests/common/BUILD new file mode 100644 index 0000000000..a53eb1d84d --- /dev/null +++ b/tests/common/BUILD @@ -0,0 +1,44 @@ +python_test_utils( + name="test_utils", + sources=[ + "**/__init__.py", + "**/conftest.py", + "redis/*.py", + "!redis/test_*.py", + ], +) + +python_tests( + name="tests", + dependencies=[ + "src/ai/backend/common:lib", + "src/ai/backend/testutils:lib", + ], + sources=[ + "**/test_*.py", + ], +) + +pex_binary( + name="spawn-sentinel-cluster", + dependencies=[ + ":test_utils", + ], + entry_point="redis/docker.py", +) + +pex_binary( + name="spawn-compose-redis-sentinel-cluster", + dependencies=[ + ":test_utils", + ], + entry_point="redis/docker.py", +) + +pex_binary( + name="spawn-native-redis-sentinel-cluster", + dependencies=[ + ":test_utils", + ], + entry_point="redis/native.py", +) diff --git a/tests/common/__init__.py b/tests/common/__init__.py 
new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/common/conftest.py b/tests/common/conftest.py new file mode 100644 index 0000000000..8ab66e4c83 --- /dev/null +++ b/tests/common/conftest.py @@ -0,0 +1,150 @@ +import asyncio +import secrets +import time +from decimal import Decimal + +import pytest + +from ai.backend.common.etcd import AsyncEtcd, ConfigScopes +from ai.backend.testutils.bootstrap import etcd_container, redis_container # noqa: F401 + + +def pytest_addoption(parser): + parser.addoption( + "--do-test-redis", + action="store_true", + default=False, + help="run Redis tests", + ) + + +def pytest_configure(config): + config.addinivalue_line("markers", "redis: mark test as part of Redis test suite") + + +def pytest_collection_modifyitems(config, items): + if not config.getoption("--do-test-redis"): + # auto-skip tests marked with "redis" unless --test-redis option is given. + do_skip = pytest.mark.skip( + reason="skipped because no related files are changed", + ) + for item in items: + if "redis" in item.keywords: + item.add_marker(do_skip) + + +@pytest.fixture(scope="session", autouse=True) +def event_loop(): + # uvloop.install() + loop = asyncio.new_event_loop() + # setup_child_watcher() + yield loop + loop.close() + + +@pytest.fixture(scope="session", autouse=True) +def test_ns(): + return f'test-{secrets.token_hex(8)}' + + +@pytest.fixture +def test_case_ns(): + return secrets.token_hex(8) + + +@pytest.fixture +async def etcd(etcd_container, test_ns): # noqa: F811 + etcd = AsyncEtcd( + addr=etcd_container[1], + namespace=test_ns, + scope_prefix_map={ + ConfigScopes.GLOBAL: 'global', + ConfigScopes.SGROUP: 'sgroup/testing', + ConfigScopes.NODE: 'node/i-test', + }, + ) + try: + await etcd.delete_prefix('', scope=ConfigScopes.GLOBAL) + await etcd.delete_prefix('', scope=ConfigScopes.SGROUP) + await etcd.delete_prefix('', scope=ConfigScopes.NODE) + yield etcd + finally: + await etcd.delete_prefix('', scope=ConfigScopes.GLOBAL) + await etcd.delete_prefix('', scope=ConfigScopes.SGROUP) + await etcd.delete_prefix('', scope=ConfigScopes.NODE) + await etcd.close() + del etcd + + +@pytest.fixture +async def gateway_etcd(etcd_container, test_ns): # noqa: F811 + etcd = AsyncEtcd( + addr=etcd_container[1], + namespace=test_ns, + scope_prefix_map={ + ConfigScopes.GLOBAL: '', + }, + ) + try: + await etcd.delete_prefix('', scope=ConfigScopes.GLOBAL) + yield etcd + finally: + await etcd.delete_prefix('', scope=ConfigScopes.GLOBAL) + del etcd + + +@pytest.fixture +async def chaos_generator(): + + async def _chaos(): + try: + while True: + await asyncio.sleep(0.001) + except asyncio.CancelledError: + return + + tasks = [] + for i in range(20): + tasks.append(asyncio.create_task(_chaos())) + yield + for i in range(20): + tasks[i].cancel() + await asyncio.gather(*tasks, return_exceptions=True) + + +@pytest.fixture +def mock_time(mocker): + total_delay = Decimal(0) + call_count = 0 + base_time = time.monotonic() + accum_time = Decimal(0) + q = Decimal('.000000') + + async def _mock_async_sleep(delay: float) -> None: + nonlocal total_delay, call_count, accum_time, q + call_count += 1 + quantized_delay = Decimal(delay).quantize(q) + accum_time += quantized_delay + total_delay += quantized_delay + + def _reset() -> None: + nonlocal total_delay, call_count + total_delay = Decimal(0) + call_count = 0 + + def _get_total_delay() -> float: + nonlocal total_delay + return float(total_delay) + + def _get_call_count() -> int: + nonlocal call_count + return call_count + + def 
_mock_time_monotonic() -> float: + nonlocal accum_time + return base_time + float(accum_time) + + _mock_async_sleep.reset = _reset + _mock_async_sleep.get_total_delay = _get_total_delay + _mock_async_sleep.get_call_count = _get_call_count + yield _mock_async_sleep, _mock_time_monotonic diff --git a/tests/common/redis/.gitignore b/tests/common/redis/.gitignore new file mode 100644 index 0000000000..751553b3ac --- /dev/null +++ b/tests/common/redis/.gitignore @@ -0,0 +1 @@ +*.bak diff --git a/tests/common/redis/__init__.py b/tests/common/redis/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/common/redis/conftest.py b/tests/common/redis/conftest.py new file mode 100644 index 0000000000..0e00a93610 --- /dev/null +++ b/tests/common/redis/conftest.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +import asyncio +import sys +from typing import ( + AsyncIterator, +) + +import pytest + +from .types import RedisClusterInfo +from .docker import DockerComposeRedisSentinelCluster +from .native import NativeRedisSentinelCluster +from .utils import wait_redis_ready + + +# A simple "redis_container" fixture is defined in ai.backend.testutils.bootstrap. + + +@pytest.fixture +async def redis_cluster(test_ns, test_case_ns) -> AsyncIterator[RedisClusterInfo]: + if sys.platform.startswith("darwin"): + impl = NativeRedisSentinelCluster + else: + impl = DockerComposeRedisSentinelCluster + cluster = impl(test_ns, test_case_ns, password="develove", service_name="mymaster") + async with cluster.make_cluster() as info: + node_wait_tasks = [ + wait_redis_ready(host, port, "develove") + for host, port in info.node_addrs + ] + sentinel_wait_tasks = [ + wait_redis_ready(host, port, None) + for host, port in info.sentinel_addrs + ] + await asyncio.gather(*node_wait_tasks, *sentinel_wait_tasks) + yield info diff --git a/tests/common/redis/docker.py b/tests/common/redis/docker.py new file mode 100644 index 0000000000..a662c43d70 --- /dev/null +++ b/tests/common/redis/docker.py @@ -0,0 +1,217 @@ +import asyncio +import contextlib +import json +import os +from pathlib import Path +import re +import signal +from typing import ( + AsyncIterator, + Tuple, +) + +import async_timeout +import pytest + +from ai.backend.testutils.pants import get_parallel_slot + +from .types import ( + AbstractRedisSentinelCluster, + AbstractRedisNode, + RedisClusterInfo, +) +from .utils import simple_run_cmd + + +class DockerRedisNode(AbstractRedisNode): + + def __init__(self, node_type: str, port: int, container_id: str) -> None: + self.node_type = node_type + self.port = port + get_parallel_slot() * 10 + self.container_id = container_id + + @property + def addr(self) -> Tuple[str, int]: + return ('127.0.0.1', self.port) + + def __str__(self) -> str: + return f"DockerRedisNode(cid:{self.container_id[:12]})" + + async def pause(self) -> None: + assert self.container_id is not None + print(f"Docker container {self.container_id[:12]} is being paused...") + await simple_run_cmd( + ['docker', 'pause', self.container_id], + # stdout=asyncio.subprocess.DEVNULL, + # stderr=asyncio.subprocess.DEVNULL, + ) + print(f"Docker container {self.container_id[:12]} is paused") + + async def unpause(self) -> None: + assert self.container_id is not None + await simple_run_cmd( + ['docker', 'unpause', self.container_id], + # stdout=asyncio.subprocess.DEVNULL, + # stderr=asyncio.subprocess.DEVNULL, + ) + print(f"Docker container {self.container_id[:12]} is unpaused") + + async def stop(self, force_kill: bool = False) -> None: + 
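+        # force_kill uses `docker kill` (immediate SIGKILL); otherwise `docker stop`
+        # gives the container a chance to shut down gracefully.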
assert self.container_id is not None + if force_kill: + await simple_run_cmd( + ['docker', 'kill', self.container_id], + # stdout=asyncio.subprocess.DEVNULL, + # stderr=asyncio.subprocess.DEVNULL, + ) + print(f"Docker container {self.container_id[:12]} is killed") + else: + await simple_run_cmd( + ['docker', 'stop', self.container_id], + # stdout=asyncio.subprocess.DEVNULL, + # stderr=asyncio.subprocess.DEVNULL, + ) + print(f"Docker container {self.container_id[:12]} is terminated") + + async def start(self) -> None: + assert self.container_id is not None + await simple_run_cmd( + ['docker', 'start', self.container_id], + # stdout=asyncio.subprocess.DEVNULL, + # stderr=asyncio.subprocess.DEVNULL, + ) + print(f"Docker container {self.container_id[:12]} started") + + +class DockerComposeRedisSentinelCluster(AbstractRedisSentinelCluster): + + async def probe_docker_compose(self) -> list[str]: + # Try v2 first and fallback to v1 + p = await asyncio.create_subprocess_exec( + 'docker', 'compose', 'version', + stdout=asyncio.subprocess.DEVNULL, + stderr=asyncio.subprocess.DEVNULL, + ) + exit_code = await p.wait() + if exit_code == 0: + compose_cmd = ['docker', 'compose'] + else: + compose_cmd = ['docker-compose'] + return compose_cmd + + @contextlib.asynccontextmanager + async def make_cluster(self) -> AsyncIterator[RedisClusterInfo]: + cfg_dir = Path(__file__).parent + compose_cfg = cfg_dir / 'redis-cluster.yml' + project_name = f"{self.test_ns}_{self.test_case_ns}" + compose_cmd = await self.probe_docker_compose() + + async with async_timeout.timeout(30.0): + p = await simple_run_cmd([ + *compose_cmd, + '-p', project_name, + '-f', os.fsencode(compose_cfg), + 'up', '-d', '--build', + ], stdout=asyncio.subprocess.DEVNULL, stderr=asyncio.subprocess.DEVNULL) + assert p.returncode == 0, "Compose cluster creation has failed." + + await asyncio.sleep(0.2) + try: + p = await asyncio.create_subprocess_exec( + *[ + *compose_cmd, + '-p', project_name, + '-f', str(compose_cfg), + 'ps', + '--format', 'json', + ], + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.DEVNULL, + ) + assert p.stdout is not None + try: + ps_output = json.loads(await p.stdout.read()) + except json.JSONDecodeError: + pytest.fail("Cannot parse \"docker compose ... ps --format json\" output. 
" + "You may need to upgrade to docker-compose v2.0.0.rc.3 or later") + await p.wait() + worker_cids = {} + sentinel_cids = {} + + def find_port_node(item): + if m := re.search(r"--port (\d+) ", item['Command']): + return int(m.group(1)) + return None + + def find_port_sentinel(item): + if m := re.search(r"redis-sentinel(\d+)", item['Name']): + return 26379 + (int(m.group(1)) - 1) + return None + + if not ps_output: + pytest.fail("Cannot detect the temporary Redis cluster running as docker compose containers") + for item in ps_output: + if 'redis-node' in item['Name']: + port = find_port_node(item) + worker_cids[port] = item['ID'] + elif 'redis-sentinel' in item['Name']: + port = find_port_sentinel(item) + sentinel_cids[port] = item['ID'] + + yield RedisClusterInfo( + node_addrs=[ + ('127.0.0.1', 16379), + ('127.0.0.1', 16380), + ('127.0.0.1', 16381), + ], + nodes=[ + DockerRedisNode("node", 16379, worker_cids[16379]), + DockerRedisNode("node", 16380, worker_cids[16380]), + DockerRedisNode("node", 16381, worker_cids[16381]), + ], + sentinel_addrs=[ + ('127.0.0.1', 26379), + ('127.0.0.1', 26380), + ('127.0.0.1', 26381), + ], + sentinels=[ + DockerRedisNode("sentinel", 26379, sentinel_cids[26379]), + DockerRedisNode("sentinel", 26380, sentinel_cids[26380]), + DockerRedisNode("sentinel", 26381, sentinel_cids[26381]), + ], + ) + finally: + await asyncio.sleep(0.2) + async with async_timeout.timeout(30.0): + await simple_run_cmd([ + *compose_cmd, + '-p', project_name, + '-f', os.fsencode(compose_cfg), + 'down', + ], stdout=asyncio.subprocess.DEVNULL, stderr=asyncio.subprocess.DEVNULL) + await asyncio.sleep(0.2) + + +async def main(): + loop = asyncio.get_running_loop() + + async def redis_task(): + native_cluster = DockerComposeRedisSentinelCluster("testing", "testing-main", "develove", "testing") + async with native_cluster.make_cluster(): + while True: + await asyncio.sleep(10) + + t = asyncio.create_task(redis_task()) + loop.add_signal_handler(signal.SIGINT, t.cancel) + loop.add_signal_handler(signal.SIGTERM, t.cancel) + try: + await t + except asyncio.CancelledError: + pass + + +if __name__ == "__main__": + try: + asyncio.run(main()) + finally: + print("Terminated.") diff --git a/tests/common/redis/native.py b/tests/common/redis/native.py new file mode 100644 index 0000000000..6e1a6d8a05 --- /dev/null +++ b/tests/common/redis/native.py @@ -0,0 +1,185 @@ +from __future__ import annotations + +import asyncio +import contextlib +import os +from pathlib import Path +import signal +import textwrap +from typing import ( + AsyncIterator, + Sequence, + Tuple, +) + +from ai.backend.testutils.pants import get_parallel_slot + +from .types import ( + AbstractRedisSentinelCluster, + AbstractRedisNode, + RedisClusterInfo, +) + + +class NativeRedisNode(AbstractRedisNode): + + proc: asyncio.subprocess.Process | None + + def __init__(self, node_type: str, port: int, start_args: Sequence[str | bytes]) -> None: + self.node_type = node_type + self.port = port + get_parallel_slot() * 10 + self.start_args = start_args + self.proc = None + + @property + def addr(self) -> Tuple[str, int]: + return ('127.0.0.1', self.port) + + def __str__(self) -> str: + if self.proc is None: + return "NativeRedisNode(not-running)" + return f"NativeRedisNode(pid:{self.proc.pid})" + + async def pause(self) -> None: + assert self.proc is not None + self.proc.send_signal(signal.SIGSTOP) + await asyncio.sleep(0) + + async def unpause(self) -> None: + assert self.proc is not None + self.proc.send_signal(signal.SIGCONT) + await 
asyncio.sleep(0) + + async def stop(self, force_kill: bool = False) -> None: + assert self.proc is not None + try: + if force_kill: + self.proc.kill() + else: + self.proc.terminate() + exit_code = await self.proc.wait() + print(f"Redis {self.node_type} (pid:{self.proc.pid}) has terminated with exit code {exit_code}.") + except ProcessLookupError: + print(f"Redis {self.node_type} (pid:{self.proc.pid}) already terminated") + finally: + self.proc = None + + async def start(self) -> None: + assert self.proc is None + self.proc = await asyncio.create_subprocess_exec( + *self.start_args, + stdout=asyncio.subprocess.DEVNULL, + stderr=asyncio.subprocess.DEVNULL, + start_new_session=True, # prevent signal propagation + ) + print(f"Redis {self.node_type} (pid:{self.proc.pid}, port:{self.port}) started.") + + +class NativeRedisSentinelCluster(AbstractRedisSentinelCluster): + + @contextlib.asynccontextmanager + async def make_cluster(self) -> AsyncIterator[RedisClusterInfo]: + nodes = [] + sentinels = [] + sentinel_config = textwrap.dedent(f""" + sentinel resolve-hostnames yes + sentinel monitor {self.service_name} 127.0.0.1 16379 2 + sentinel auth-pass {self.service_name} {self.password} + sentinel down-after-milliseconds {self.service_name} 1000 + sentinel failover-timeout {self.service_name} 5000 + sentinel parallel-syncs {self.service_name} 2 + protected-mode no + """).lstrip() + for node_port in [16379, 16380, 16381]: + rdb_path = Path(f"node.{node_port}.rdb") + try: + rdb_path.unlink() + except FileNotFoundError: + pass + node = NativeRedisNode( + "node", + node_port, + [ + "redis-server", + "--bind", "127.0.0.1", + "--port", str(node_port), + "--requirepass", self.password, + "--masterauth", self.password, + ] + ( + [] + if node_port == 16379 + else ["--slaveof", "127.0.0.1", "16379"] + ) + [ + "--cluster-announce-ip", "127.0.0.1", + "--min-slaves-to-write", "1", + "--min-slaves-max-lag", "10", + "--dbfilename", str(rdb_path), + ], + ) + nodes.append(node) + for sentinel_port in [26379, 26380, 26381]: + # Redis sentinels store their states in the config files (not rdb!), + # so the files should be separate to each sentinel instance. 
+ sentinel_conf_path = Path(f"sentinel.{sentinel_port}.conf") + sentinel_conf_path.write_text(sentinel_config) + sentinel = NativeRedisNode( + "sentinel", + sentinel_port, + [ + "redis-server", + os.fsencode(sentinel_conf_path), + "--bind", "127.0.0.1", + "--port", str(sentinel_port), + "--sentinel", + ], + ) + sentinels.append(sentinel) + await asyncio.gather(*[node.start() for node in nodes]) + await asyncio.sleep(0.1) + await asyncio.gather(*[sentinel.start() for sentinel in sentinels]) + try: + yield RedisClusterInfo( + node_addrs=[ + ('127.0.0.1', 16379), + ('127.0.0.1', 16380), + ('127.0.0.1', 16381), + ], + nodes=nodes, + sentinel_addrs=[ + ('127.0.0.1', 26379), + ('127.0.0.1', 26380), + ('127.0.0.1', 26381), + ], + sentinels=sentinels, + ) + except asyncio.CancelledError: + raise + finally: + await asyncio.gather(*[sentinel.stop() for sentinel in sentinels]) + await asyncio.sleep(0.1) + await asyncio.gather(*[node.stop() for node in nodes]) + + +async def main(): + loop = asyncio.get_running_loop() + + async def redis_task(): + native_cluster = NativeRedisSentinelCluster("testing", "testing-main", "develove", "testing") + async with native_cluster.make_cluster(): + while True: + await asyncio.sleep(10) + + t = asyncio.create_task(redis_task()) + loop.add_signal_handler(signal.SIGINT, t.cancel) + loop.add_signal_handler(signal.SIGTERM, t.cancel) + try: + await t + except asyncio.CancelledError: + pass + + +if __name__ == "__main__": + try: + asyncio.run(main()) + finally: + print("Terminated.") diff --git a/tests/common/redis/redis-cluster.yml b/tests/common/redis/redis-cluster.yml new file mode 100644 index 0000000000..1a06306a30 --- /dev/null +++ b/tests/common/redis/redis-cluster.yml @@ -0,0 +1,90 @@ +version: "3.7" + +services: + + # Initial master is node01. + backendai-half-redis-node01: + image: redis:6-alpine + command: > + redis-server + --port 16379 + --requirepass ${REDIS_PASSWORD:-develove} + --masterauth ${REDIS_PASSWORD:-develove} + --cluster-announce-ip 127.0.0.1 + --min-slaves-to-write 1 + --min-slaves-max-lag 10 + network_mode: host + + backendai-half-redis-node02: + image: redis:6-alpine + command: > + redis-server + --port 16380 + --requirepass ${REDIS_PASSWORD:-develove} + --masterauth ${REDIS_PASSWORD:-develove} + --slaveof 127.0.0.1 16379 + --cluster-announce-ip 127.0.0.1 + --min-slaves-to-write 1 + --min-slaves-max-lag 10 + network_mode: host + + backendai-half-redis-node03: + image: redis:6-alpine + command: > + redis-server + --port 16381 + --requirepass ${REDIS_PASSWORD:-develove} + --masterauth ${REDIS_PASSWORD:-develove} + --slaveof 127.0.0.1 16379 + --cluster-announce-ip 127.0.0.1 + --min-slaves-to-write 1 + --min-slaves-max-lag 10 + network_mode: host + + backendai-half-redis-sentinel01: + build: + context: . + dockerfile: redis-sentinel.dockerfile + cache_from: + - redis:5-alpine + image: redis-sentinel:testing + environment: + - REDIS_PASSWORD=${REDIS_PASSWORD:-develove} + - REDIS_PORT=26379 + depends_on: + - backendai-half-redis-node01 + - backendai-half-redis-node02 + - backendai-half-redis-node03 + network_mode: host + + backendai-half-redis-sentinel02: + build: + context: . + dockerfile: redis-sentinel.dockerfile + cache_from: + - redis:6-alpine + image: redis-sentinel:testing + environment: + - REDIS_PASSWORD=${REDIS_PASSWORD:-develove} + - REDIS_PORT=26380 + depends_on: + - backendai-half-redis-node01 + - backendai-half-redis-node02 + - backendai-half-redis-node03 + network_mode: host + + backendai-half-redis-sentinel03: + build: + context: . 
+ dockerfile: redis-sentinel.dockerfile + cache_from: + - redis:6-alpine + image: redis-sentinel:testing + environment: + - REDIS_PASSWORD=${REDIS_PASSWORD:-develove} + - REDIS_PORT=26381 + depends_on: + - backendai-half-redis-node01 + - backendai-half-redis-node02 + - backendai-half-redis-node03 + network_mode: host diff --git a/tests/common/redis/redis-sentinel.dockerfile b/tests/common/redis/redis-sentinel.dockerfile new file mode 100644 index 0000000000..d412d71483 --- /dev/null +++ b/tests/common/redis/redis-sentinel.dockerfile @@ -0,0 +1,9 @@ +FROM redis:6-alpine + +COPY ./sentinel.conf /etc/redis-sentinel.conf + +CMD sed -i'' "s/REDIS_PASSWORD/${REDIS_PASSWORD}/g" /etc/redis-sentinel.conf; \ + redis-server /etc/redis-sentinel.conf --sentinel --port ${REDIS_PORT} + + +# vim: ft=dockerfile diff --git a/tests/common/redis/sentinel.conf b/tests/common/redis/sentinel.conf new file mode 100644 index 0000000000..db84cf1ebb --- /dev/null +++ b/tests/common/redis/sentinel.conf @@ -0,0 +1,7 @@ +sentinel resolve-hostnames yes +sentinel monitor mymaster 127.0.0.1 16379 2 +sentinel auth-pass mymaster REDIS_PASSWORD +sentinel down-after-milliseconds mymaster 1000 +sentinel failover-timeout mymaster 5000 +sentinel parallel-syncs mymaster 2 +protected-mode no diff --git a/tests/common/redis/test_connect.py b/tests/common/redis/test_connect.py new file mode 100644 index 0000000000..2f912d4b80 --- /dev/null +++ b/tests/common/redis/test_connect.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING + +import aioredis +import aioredis.client +import aioredis.exceptions +import aioredis.sentinel +import aiotools +import pytest + +from .types import RedisClusterInfo +from .utils import interrupt, with_timeout + +from ai.backend.common import redis, validators as tx +from ai.backend.common.types import HostPortPair + +if TYPE_CHECKING: + from typing import Any + + +@pytest.mark.asyncio +async def test_connect(redis_container: tuple[str, HostPortPair]) -> None: + addr = redis_container[1] + r = aioredis.from_url( + url=f'redis://{addr.host}:{addr.port}', + socket_timeout=0.5, + ) + await r.ping() + + +@pytest.mark.redis +@pytest.mark.asyncio +async def test_instantiate_redisconninfo() -> None: + sentinels = '127.0.0.1:26379,127.0.0.1:26380,127.0.0.1:26381' + r1 = redis.get_redis_object({ + 'sentinel': sentinels, + 'service_name': 'mymaster', + 'password': 'develove', + }) + + assert isinstance(r1.client, aioredis.sentinel.Sentinel) + + for i in range(3): + assert r1.client.sentinels[i].connection_pool.connection_kwargs['host'] == '127.0.0.1' + assert r1.client.sentinels[i].connection_pool.connection_kwargs['port'] == (26379 + i) + assert r1.client.sentinels[i].connection_pool.connection_kwargs['db'] == 0 + + parsed_addresses: Any = tx.DelimiterSeperatedList(tx.HostPortPair).check_and_return(sentinels) + r2 = redis.get_redis_object({ + 'sentinel': parsed_addresses, + 'service_name': 'mymaster', + 'password': 'develove', + }) + + assert isinstance(r2.client, aioredis.sentinel.Sentinel) + + for i in range(3): + assert r2.client.sentinels[i].connection_pool.connection_kwargs['host'] == '127.0.0.1' + assert r2.client.sentinels[i].connection_pool.connection_kwargs['port'] == (26379 + i) + assert r2.client.sentinels[i].connection_pool.connection_kwargs['db'] == 0 + + +@pytest.mark.redis +@pytest.mark.asyncio +@with_timeout(30.0) +async def test_connect_cluster_sentinel(redis_cluster: RedisClusterInfo) -> None: + do_pause = asyncio.Event() + paused = 
asyncio.Event() + do_unpause = asyncio.Event() + unpaused = asyncio.Event() + + async def control_interrupt() -> None: + await asyncio.sleep(1) + do_pause.set() + await paused.wait() + await asyncio.sleep(2) + do_unpause.set() + await unpaused.wait() + + s = aioredis.sentinel.Sentinel( + redis_cluster.sentinel_addrs, + password='develove', + socket_timeout=0.5, + ) + async with aiotools.TaskGroup() as tg: + tg.create_task(control_interrupt()) + tg.create_task(interrupt( + 'stop', + redis_cluster.nodes[0], + do_pause=do_pause, + do_unpause=do_unpause, + paused=paused, + unpaused=unpaused, + redis_password='develove', + )) + await asyncio.sleep(0) + + for _ in range(5): + print(f"CONNECT REPEAT {_}") + try: + master_addr = await s.discover_master('mymaster') + print("MASTER", master_addr) + except aioredis.sentinel.MasterNotFoundError: + print("MASTER (not found)") + try: + slave_addrs = await s.discover_slaves('mymaster') + print("SLAVE", slave_addrs) + slave = s.slave_for('mymaster', db=9) + await slave.ping() + except aioredis.sentinel.SlaveNotFoundError: + print("SLAVE (not found)") + await asyncio.sleep(1) diff --git a/tests/common/redis/test_list.py b/tests/common/redis/test_list.py new file mode 100644 index 0000000000..594cdfd328 --- /dev/null +++ b/tests/common/redis/test_list.py @@ -0,0 +1,265 @@ +from __future__ import annotations + +import asyncio +from typing import ( + List, +) + +import aioredis +import aioredis.client +import aioredis.exceptions +import aioredis.sentinel +import aiotools +import pytest + +from ai.backend.common import redis +from ai.backend.common.types import RedisConnectionInfo + +from .docker import DockerRedisNode +from .types import RedisClusterInfo +from .utils import interrupt, with_timeout + + +@pytest.mark.redis +@pytest.mark.asyncio +@pytest.mark.xfail +@pytest.mark.parametrize("disruption_method", ['stop', 'pause']) +async def test_blist(redis_container: str, disruption_method: str) -> None: + do_pause = asyncio.Event() + paused = asyncio.Event() + do_unpause = asyncio.Event() + unpaused = asyncio.Event() + received_messages: List[str] = [] + + async def pop(r: RedisConnectionInfo, key: str) -> None: + try: + async with aiotools.aclosing( + redis.blpop(r, key, reconnect_poll_interval=0.3), + ) as agen: + async for raw_msg in agen: + msg = raw_msg.decode() + received_messages.append(msg) + except asyncio.CancelledError: + pass + + r = RedisConnectionInfo( + aioredis.from_url(url='redis://localhost:9379', socket_timeout=0.5), + service_name=None, + ) + assert isinstance(r.client, aioredis.Redis) + await r.client.delete("bl1") + + pop_task = asyncio.create_task(pop(r, "bl1")) + interrupt_task = asyncio.create_task(interrupt( + disruption_method, + DockerRedisNode("node", 9379, redis_container), + do_pause=do_pause, + do_unpause=do_unpause, + paused=paused, + unpaused=unpaused, + )) + await asyncio.sleep(0) + + for i in range(5): + await r.client.rpush("bl1", str(i)) + await asyncio.sleep(0.1) + do_pause.set() + await paused.wait() + for i in range(5): + # The Redis server is dead temporarily... 
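+            # A stopped container resets the connection right away, whereas a paused
+            # one leaves it hanging until the client's socket timeout expires.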
+ if disruption_method == 'stop': + with pytest.raises(aioredis.exceptions.ConnectionError): + await r.client.rpush("bl1", str(5 + i)) + elif disruption_method == 'pause': + with pytest.raises(asyncio.TimeoutError): + await r.client.rpush("bl1", str(5 + i)) + else: + raise RuntimeError("should not reach here") + await asyncio.sleep(0.1) + do_unpause.set() + await unpaused.wait() + for i in range(5): + await r.client.rpush("bl1", str(10 + i)) + await asyncio.sleep(0.1) + + await interrupt_task + pop_task.cancel() + await pop_task + assert pop_task.done() + + all_messages = set(map(int, received_messages)) + assert set(range(0, 5)) < all_messages + assert set(range(13, 15)) < all_messages # more msgs may be lost during restart + assert all_messages <= set(range(0, 15)) + + +@pytest.mark.redis +@pytest.mark.asyncio +@pytest.mark.xfail +@pytest.mark.parametrize("disruption_method", ['stop', 'pause']) +async def test_blist_with_retrying_rpush(redis_container: str, disruption_method: str) -> None: + do_pause = asyncio.Event() + paused = asyncio.Event() + do_unpause = asyncio.Event() + unpaused = asyncio.Event() + received_messages: List[str] = [] + + async def pop(r: RedisConnectionInfo, key: str) -> None: + try: + async with aiotools.aclosing( + redis.blpop(r, key, reconnect_poll_interval=0.3), + ) as agen: + async for raw_msg in agen: + msg = raw_msg.decode() + received_messages.append(msg) + except asyncio.CancelledError: + pass + + r = RedisConnectionInfo( + aioredis.from_url(url='redis://localhost:9379', socket_timeout=0.5), + service_name=None, + ) + assert isinstance(r.client, aioredis.Redis) + await r.client.delete("bl1") + + pop_task = asyncio.create_task(pop(r, "bl1")) + interrupt_task = asyncio.create_task(interrupt( + disruption_method, + DockerRedisNode("node", 9379, redis_container), + do_pause=do_pause, + do_unpause=do_unpause, + paused=paused, + unpaused=unpaused, + )) + await asyncio.sleep(0) + + for i in range(5): + await redis.execute(r, lambda r: r.rpush("bl1", str(i))) + await asyncio.sleep(0.1) + do_pause.set() + await paused.wait() + + async def wakeup(): + await asyncio.sleep(2.0) + do_unpause.set() + + wakeup_task = asyncio.create_task(wakeup()) + for i in range(5): + await redis.execute(r, lambda r: r.rpush("bl1", str(5 + i))) + await asyncio.sleep(0.1) + await wakeup_task + + await unpaused.wait() + for i in range(5): + await redis.execute(r, lambda r: r.rpush("bl1", str(10 + i))) + await asyncio.sleep(0.1) + + await interrupt_task + pop_task.cancel() + await pop_task + assert pop_task.done() + + all_messages = set(map(int, received_messages)) + assert set(range(0, 5)) < all_messages + assert set(range(13, 15)) < all_messages # more msgs may be lost during restart + assert all_messages <= set(range(0, 15)) + + +@pytest.mark.redis +@pytest.mark.asyncio +@pytest.mark.xfail +@pytest.mark.parametrize("disruption_method", ['stop', 'pause']) +@with_timeout(30.0) +async def test_blist_cluster_sentinel( + redis_cluster: RedisClusterInfo, + disruption_method: str, +) -> None: + do_pause = asyncio.Event() + paused = asyncio.Event() + do_unpause = asyncio.Event() + unpaused = asyncio.Event() + received_messages: List[str] = [] + + async def pop(s: RedisConnectionInfo, key: str) -> None: + try: + async with aiotools.aclosing( + redis.blpop( + s, key, + reconnect_poll_interval=0.3, + service_name="mymaster", + ), + ) as agen: + async for raw_msg in agen: + msg = raw_msg.decode() + received_messages.append(msg) + except asyncio.CancelledError: + pass + + s = RedisConnectionInfo( + 
aioredis.sentinel.Sentinel( + redis_cluster.sentinel_addrs, + password='develove', + socket_timeout=0.5, + ), + service_name='mymaster', + ) + await redis.execute(s, lambda r: r.delete("bl1")) + + pop_task = asyncio.create_task(pop(s, "bl1")) + interrupt_task = asyncio.create_task(interrupt( + disruption_method, + redis_cluster.nodes[0], + do_pause=do_pause, + do_unpause=do_unpause, + paused=paused, + unpaused=unpaused, + redis_password='develove', + )) + await asyncio.sleep(0) + + for i in range(5): + await redis.execute( + s, + lambda r: r.rpush("bl1", str(i)), + service_name="mymaster", + ) + await asyncio.sleep(0.1) + do_pause.set() + await paused.wait() + + async def wakeup(): + await asyncio.sleep(2.0) + do_unpause.set() + + wakeup_task = asyncio.create_task(wakeup()) + for i in range(5): + await redis.execute( + s, + lambda r: r.rpush("bl1", str(5 + i)), + service_name="mymaster", + ) + await asyncio.sleep(0.1) + await wakeup_task + + await unpaused.wait() + for i in range(5): + await redis.execute( + s, + lambda r: r.rpush("bl1", str(10 + i)), + service_name="mymaster", + ) + await asyncio.sleep(0.1) + + await interrupt_task + pop_task.cancel() + await pop_task + assert pop_task.done() + + if disruption_method == "stop": + assert [*map(int, received_messages)] == [*range(0, 15)] + else: + # loss happens during failover + all_messages = set(map(int, received_messages)) + assert set(range(0, 5)) < all_messages + assert set(range(10, 15)) < all_messages + assert all_messages <= set(range(0, 15)) diff --git a/tests/common/redis/test_pipeline.py b/tests/common/redis/test_pipeline.py new file mode 100644 index 0000000000..570b81324d --- /dev/null +++ b/tests/common/redis/test_pipeline.py @@ -0,0 +1,134 @@ +from __future__ import annotations + +from unittest import mock + +import aioredis +import aioredis.client +import aioredis.sentinel +import pytest + +from ai.backend.common.redis import execute +from ai.backend.common.types import RedisConnectionInfo + +from .types import RedisClusterInfo + + +@pytest.mark.redis +@pytest.mark.asyncio +async def test_pipeline_single_instance(redis_container: str) -> None: + rconn = RedisConnectionInfo( + aioredis.from_url(url='redis://localhost:9379', socket_timeout=0.5), + service_name=None, + ) + + def _build_pipeline(r: aioredis.Redis) -> aioredis.client.Pipeline: + pipe = r.pipeline(transaction=False) + pipe.set("xyz", "123") + pipe.incr("xyz") + return pipe + + results = await execute(rconn, _build_pipeline) + assert results[0] is True + assert str(results[1]) == "124" + + actual_value = await execute(rconn, lambda r: r.get("xyz")) + assert actual_value == b"124" + + async def _build_pipeline_async(r: aioredis.Redis) -> aioredis.client.Pipeline: + pipe = r.pipeline(transaction=False) + pipe.set("abc", "123") + pipe.incr("abc") + return pipe + + results = await execute(rconn, _build_pipeline_async) + assert results[0] is True + assert str(results[1]) == "124" + + actual_value = await execute(rconn, lambda r: r.get("abc")) + assert actual_value == b"124" + + +@pytest.mark.redis +@pytest.mark.asyncio +async def test_pipeline_single_instance_retries(redis_container: str) -> None: + rconn = RedisConnectionInfo( + aioredis.from_url(url='redis://localhost:9379', socket_timeout=0.5), + service_name=None, + ) + + build_count = 0 + + patcher = mock.patch( + 'aioredis.client.Pipeline._execute_pipeline', + side_effect=[ConnectionResetError, ConnectionResetError, mock.DEFAULT], + ) + patcher.start() + + def _build_pipeline(r: aioredis.Redis) -> 
aioredis.client.Pipeline: + nonlocal build_count, patcher + build_count += 1 + if build_count == 3: + # Restore the original function. + patcher.stop() + pipe = r.pipeline(transaction=False) + pipe.set("xyz", "123") + pipe.incr("xyz") + return pipe + + results = await execute(rconn, _build_pipeline, reconnect_poll_interval=0.01) + assert build_count == 3 + assert results[0] is True + assert results[1] == 124 + + actual_value = await execute(rconn, lambda r: r.get("xyz")) + assert actual_value == b"124" + + build_count = 0 + + patcher = mock.patch( + 'aioredis.client.Pipeline._execute_pipeline', + side_effect=[ConnectionResetError, ConnectionResetError, mock.DEFAULT], + ) + patcher.start() + + async def _build_pipeline_async(r: aioredis.Redis) -> aioredis.client.Pipeline: + nonlocal build_count, patcher + build_count += 1 + if build_count == 3: + # Restore the original function. + patcher.stop() + pipe = r.pipeline(transaction=False) + pipe.set("abc", "456") + pipe.incr("abc") + return pipe + + results = await execute(rconn, _build_pipeline_async, reconnect_poll_interval=0.01) + assert build_count == 3 + assert results[0] is True + assert results[1] == 457 + + actual_value = await execute(rconn, lambda r: r.get("abc")) + assert actual_value == b"457" + + +@pytest.mark.redis +@pytest.mark.asyncio +async def test_pipeline_sentinel_cluster(redis_cluster: RedisClusterInfo) -> None: + rconn = RedisConnectionInfo( + aioredis.sentinel.Sentinel( + redis_cluster.sentinel_addrs, + password='develove', + socket_timeout=0.5, + ), + service_name='mymaster', + ) + + def _build_pipeline(r: aioredis.Redis) -> aioredis.client.Pipeline: + pipe = r.pipeline(transaction=False) + pipe.set("xyz", "123") + pipe.incr("xyz") + return pipe + + results = await execute(rconn, _build_pipeline) + assert results[0] is True + assert str(results[1]) == "124" diff --git a/tests/common/redis/test_pubsub.py b/tests/common/redis/test_pubsub.py new file mode 100644 index 0000000000..31893be0bc --- /dev/null +++ b/tests/common/redis/test_pubsub.py @@ -0,0 +1,256 @@ +from __future__ import annotations + +import asyncio +from typing import ( + List, +) + +import aioredis +import aioredis.client +import aioredis.exceptions +import aiotools +import pytest + +from .docker import DockerRedisNode +from .utils import interrupt + +from ai.backend.common import redis +from ai.backend.common.types import RedisConnectionInfo + + +@pytest.mark.redis +@pytest.mark.asyncio +@pytest.mark.xfail +@pytest.mark.parametrize("disruption_method", ['stop', 'pause']) +async def test_pubsub(redis_container: str, disruption_method: str) -> None: + do_pause = asyncio.Event() + paused = asyncio.Event() + do_unpause = asyncio.Event() + unpaused = asyncio.Event() + received_messages: List[str] = [] + + async def subscribe(pubsub: aioredis.client.PubSub) -> None: + try: + async with aiotools.aclosing( + redis.subscribe(pubsub, reconnect_poll_interval=0.3), + ) as agen: + async for raw_msg in agen: + msg = raw_msg.decode() + received_messages.append(msg) + except asyncio.CancelledError: + pass + + r = RedisConnectionInfo( + aioredis.from_url(url='redis://localhost:9379', socket_timeout=0.5), + service_name=None, + ) + assert isinstance(r.client, aioredis.Redis) + await r.client.delete("ch1") + pubsub = r.client.pubsub() + async with pubsub: + await pubsub.subscribe("ch1") + + subscribe_task = asyncio.create_task(subscribe(pubsub)) + interrupt_task = asyncio.create_task(interrupt( + disruption_method, + DockerRedisNode("node", 9379, redis_container), + 
do_pause=do_pause, + do_unpause=do_unpause, + paused=paused, + unpaused=unpaused, + )) + await asyncio.sleep(0) + + for i in range(5): + await r.client.publish("ch1", str(i)) + await asyncio.sleep(0.1) + do_pause.set() + await paused.wait() + for i in range(5): + # The Redis server is dead temporarily... + if disruption_method == 'stop': + with pytest.raises(aioredis.exceptions.ConnectionError): + await r.client.publish("ch1", str(5 + i)) + elif disruption_method == 'pause': + with pytest.raises(asyncio.TimeoutError): + await r.client.publish("ch1", str(5 + i)) + else: + raise RuntimeError("should not reach here") + await asyncio.sleep(0.1) + do_unpause.set() + await unpaused.wait() + for i in range(5): + await r.client.publish("ch1", str(10 + i)) + await asyncio.sleep(0.1) + + await interrupt_task + subscribe_task.cancel() + await subscribe_task + assert subscribe_task.done() + + if disruption_method == 'stop': + all_messages = set(map(int, received_messages)) + assert set(range(0, 5)) <= all_messages + assert set(range(13, 15)) <= all_messages # more msgs may be lost during restart + assert all_messages <= set(range(0, 15)) + elif disruption_method == 'pause': + # Temporary pause of the container makes the kernel TCP stack to keep the packets. + assert [*map(int, received_messages)] == [*range(0, 15)] + else: + raise RuntimeError("should not reach here") + + +@pytest.mark.redis +@pytest.mark.asyncio +@pytest.mark.xfail +@pytest.mark.parametrize("disruption_method", ['stop', 'pause']) +async def test_pubsub_with_retrying_pub(redis_container: str, disruption_method: str) -> None: + do_pause = asyncio.Event() + paused = asyncio.Event() + do_unpause = asyncio.Event() + unpaused = asyncio.Event() + received_messages: List[str] = [] + + async def subscribe(pubsub: aioredis.client.PubSub) -> None: + try: + async with aiotools.aclosing( + redis.subscribe(pubsub, reconnect_poll_interval=0.3), + ) as agen: + async for raw_msg in agen: + msg = raw_msg.decode() + received_messages.append(msg) + except asyncio.CancelledError: + pass + + r = RedisConnectionInfo( + aioredis.from_url(url='redis://localhost:9379', socket_timeout=0.5), + service_name=None, + ) + assert isinstance(r.client, aioredis.Redis) + await r.client.delete("ch1") + pubsub = r.client.pubsub() + async with pubsub: + await pubsub.subscribe("ch1") + + subscribe_task = asyncio.create_task(subscribe(pubsub)) + interrupt_task = asyncio.create_task(interrupt( + disruption_method, + DockerRedisNode("node", 9379, redis_container), + do_pause=do_pause, + do_unpause=do_unpause, + paused=paused, + unpaused=unpaused, + )) + await asyncio.sleep(0) + + for i in range(5): + await redis.execute(r, lambda r: r.publish("ch1", str(i))) + await asyncio.sleep(0.1) + do_pause.set() + await paused.wait() + + async def wakeup(): + await asyncio.sleep(0.3) + do_unpause.set() + + wakeup_task = asyncio.create_task(wakeup()) + for i in range(5): + await redis.execute(r, lambda r: r.publish("ch1", str(5 + i))) + await asyncio.sleep(0.1) + await wakeup_task + + await unpaused.wait() + for i in range(5): + await redis.execute(r, lambda r: r.publish("ch1", str(10 + i))) + await asyncio.sleep(0.1) + + await interrupt_task + subscribe_task.cancel() + await subscribe_task + assert subscribe_task.done() + + all_messages = set(map(int, received_messages)) + assert set(range(0, 5)) <= all_messages + assert set(range(13, 15)) <= all_messages # more msgs may be lost during restart + assert all_messages <= set(range(0, 15)) + + +# FIXME: The below test case hangs... 
+# We skipped this issue because now we use Redis streams instead of pub-sub. +r""" +@pytest.mark.redis +@pytest.mark.asyncio +async def test_pubsub_cluster_sentinel(redis_cluster: RedisClusterInfo) -> None: + do_pause = asyncio.Event() + paused = asyncio.Event() + do_unpause = asyncio.Event() + unpaused = asyncio.Event() + received_messages: List[str] = [] + + async def interrupt() -> None: + await do_pause.wait() + await simple_run_cmd(['docker', 'stop', redis_container]) + paused.set() + await do_unpause.wait() + await simple_run_cmd(['docker', 'start', redis_container]) + # The pub-sub channel may loose some messages while starting up. + # Make a pause here to wait until the container actually begins to listen. + await asyncio.sleep(0.5) + unpaused.set() + + async def subscribe(pubsub: aioredis.client.PubSub) -> None: + try: + async with aiotools.aclosing( + redis.subscribe(pubsub, reconnect_poll_interval=0.3) + ) as agen: + async for raw_msg in agen: + msg = raw_msg.decode() + print("SUBSCRIBE", msg) + received_messages.append(msg) + except asyncio.CancelledError: + pass + + s = aioredis.sentinel.Sentinel( + redis_cluster.sentinel_addrs, + password='develove', + socket_timeout=0.5, + ) + await redis.execute(s, lambda r: r.delete("ch1"), service_name="mymaster") + + m = s.master_for("mymaster") + pubsub = m.pubsub() + async with pubsub: + await pubsub.subscribe("ch1") + + subscribe_task = asyncio.create_task(subscribe(pubsub)) + interrupt_task = asyncio.create_task(interrupt()) + await asyncio.sleep(0) + + for i in range(5): + await redis.execute(s, lambda r: r.publish("ch1", str(i)), service_name="mymaster") + await asyncio.sleep(0.1) + do_pause.set() + await paused.wait() + + async def wakeup(): + await asyncio.sleep(2.0) + do_unpause.set() + + wakeup_task = asyncio.create_task(wakeup()) + for i in range(5): + await redis.execute(s, lambda r: r.publish("ch1", str(5 + i)), service_name="mymaster") + await asyncio.sleep(0.1) + await wakeup_task + + await unpaused.wait() + for i in range(5): + await redis.execute(s, lambda r: r.publish("ch1", str(10 + i)), service_name="mymaster") + await asyncio.sleep(0.1) + + await interrupt_task + subscribe_task.cancel() + await subscribe_task + assert subscribe_task.done() + + assert [*map(int, received_messages)] == [*range(0, 15)] +""" diff --git a/tests/common/redis/test_stream.py b/tests/common/redis/test_stream.py new file mode 100644 index 0000000000..484c12b12d --- /dev/null +++ b/tests/common/redis/test_stream.py @@ -0,0 +1,395 @@ +from __future__ import annotations + +import asyncio +import sys +import traceback +from typing import ( + Dict, + List, +) + +import aioredis +import aioredis.client +import aioredis.exceptions +import aioredis.sentinel +import aiotools +from aiotools.context import aclosing +import pytest + +from ai.backend.common import redis +from ai.backend.common.types import RedisConnectionInfo + +from .docker import DockerRedisNode +from .types import RedisClusterInfo +from .utils import interrupt, with_timeout + + +@pytest.mark.redis +@pytest.mark.asyncio +@pytest.mark.parametrize("disruption_method", ['stop', 'pause']) +async def test_stream_fanout(redis_container: str, disruption_method: str, chaos_generator) -> None: + do_pause = asyncio.Event() + paused = asyncio.Event() + do_unpause = asyncio.Event() + unpaused = asyncio.Event() + received_messages: Dict[str, List[str]] = { + "c1": [], + "c2": [], + } + + async def consume( + consumer_id: str, + r: RedisConnectionInfo, + key: str, + ) -> None: + try: + async with 
aclosing(redis.read_stream(r, key)) as agen: + async for msg_id, msg_data in agen: + print(f"XREAD[{consumer_id}]", msg_id, repr(msg_data), file=sys.stderr) + received_messages[consumer_id].append(msg_data[b"idx"]) + except asyncio.CancelledError: + return + except Exception as e: + print("STREAM_FANOUT.CONSUME: unexpected error", repr(e), file=sys.stderr) + raise + + r = RedisConnectionInfo( + aioredis.from_url('redis://localhost:9379', socket_timeout=0.5), + service_name=None, + ) + assert isinstance(r.client, aioredis.Redis) + await redis.execute(r, lambda r: r.delete("stream1")) + + consumer_tasks = [ + asyncio.create_task(consume("c1", r, "stream1")), + asyncio.create_task(consume("c2", r, "stream1")), + ] + await asyncio.sleep(0.1) + interrupt_task = asyncio.create_task(interrupt( + disruption_method, + DockerRedisNode("node", 9379, redis_container), + do_pause=do_pause, + do_unpause=do_unpause, + paused=paused, + unpaused=unpaused, + )) + await asyncio.sleep(0) + + for i in range(5): + await r.client.xadd("stream1", {"idx": i}) + await asyncio.sleep(0.1) + do_pause.set() + await paused.wait() + loop = asyncio.get_running_loop() + loop.call_later(5.0, do_unpause.set) + for i in range(5): + # The Redis server is dead temporarily... + if disruption_method == 'stop': + with pytest.raises(aioredis.exceptions.ConnectionError): + await r.client.xadd("stream1", {"idx": 5 + i}) + elif disruption_method == 'pause': + with pytest.raises(asyncio.TimeoutError): + await r.client.xadd("stream1", {"idx": 5 + i}) + else: + raise RuntimeError("should not reach here") + await asyncio.sleep(0.1) + await unpaused.wait() + for i in range(5): + await r.client.xadd("stream1", {"idx": 10 + i}) + await asyncio.sleep(0.1) + + await interrupt_task + for t in consumer_tasks: + t.cancel() + await t + for t in consumer_tasks: + assert t.done() + + if disruption_method == "stop": + # loss happens + assert {*map(int, received_messages["c1"])} >= {*range(0, 5)} | {*range(10, 15)} + assert {*map(int, received_messages["c2"])} >= {*range(0, 5)} | {*range(10, 15)} + else: + # loss does not happen + # pause keeps the TCP connection and the messages are delivered late. 
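+        # (the XADDs attempted while the server was paused timed out on the client side above,
+        # but the server still executes them once unpaused, so all 15 indices are expected here)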
+ assert [*map(int, received_messages["c1"])] == [*range(0, 15)] + assert [*map(int, received_messages["c2"])] == [*range(0, 15)] + + +@pytest.mark.redis +@pytest.mark.asyncio +@pytest.mark.parametrize("disruption_method", ['stop', 'pause']) +@with_timeout(30.0) +async def test_stream_fanout_cluster(redis_cluster: RedisClusterInfo, disruption_method: str, chaos_generator) -> None: + do_pause = asyncio.Event() + paused = asyncio.Event() + do_unpause = asyncio.Event() + unpaused = asyncio.Event() + received_messages: Dict[str, List[str]] = { + "c1": [], + "c2": [], + } + + async def consume( + consumer_id: str, + r: RedisConnectionInfo, + key: str, + ) -> None: + try: + async with aclosing(redis.read_stream(r, key)) as agen: + async for msg_id, msg_data in agen: + print(f"XREAD[{consumer_id}]", msg_id, repr(msg_data), file=sys.stderr) + received_messages[consumer_id].append(msg_data[b"idx"]) + except asyncio.CancelledError: + return + except Exception as e: + print("STREAM_FANOUT.CONSUME: unexpected error", repr(e), file=sys.stderr) + raise + + s = RedisConnectionInfo( + aioredis.sentinel.Sentinel( + redis_cluster.sentinel_addrs, + password='develove', + socket_timeout=0.5, + ), + service_name='mymaster', + ) + _execute = aiotools.apartial(redis.execute, s) + await _execute(lambda r: r.delete("stream1")) + + consumer_tasks = [ + asyncio.create_task(consume("c1", s, "stream1")), + asyncio.create_task(consume("c2", s, "stream1")), + ] + await asyncio.sleep(0.1) + interrupt_task = asyncio.create_task(interrupt( + disruption_method, + redis_cluster.nodes[0], + do_pause=do_pause, + do_unpause=do_unpause, + paused=paused, + unpaused=unpaused, + redis_password='develove', + )) + await asyncio.sleep(0) + + try: + for i in range(5): + await _execute(lambda r: r.xadd("stream1", {"idx": i})) + await asyncio.sleep(0.1) + do_pause.set() + await paused.wait() + loop = asyncio.get_running_loop() + loop.call_later(5.0, do_unpause.set) + for i in range(5): + await _execute(lambda r: r.xadd("stream1", {"idx": 5 + i})) + await asyncio.sleep(0.1) + await unpaused.wait() + for i in range(5): + await _execute(lambda r: r.xadd("stream1", {"idx": 10 + i})) + await asyncio.sleep(0.1) + finally: + await interrupt_task + for t in consumer_tasks: + t.cancel() + await t + for t in consumer_tasks: + assert t.done() + + if disruption_method == "stop": + # loss does not happen due to retries + assert [*map(int, received_messages["c1"])] == [*range(0, 15)] + assert [*map(int, received_messages["c2"])] == [*range(0, 15)] + else: + # loss happens during failover + assert {*map(int, received_messages["c1"])} >= {*range(0, 5)} | {*range(10, 15)} + assert {*map(int, received_messages["c2"])} >= {*range(0, 5)} | {*range(10, 15)} + + +@pytest.mark.redis +@pytest.mark.asyncio +@pytest.mark.parametrize("disruption_method", ['stop', 'pause']) +@with_timeout(30.0) +async def test_stream_loadbalance(redis_container: str, disruption_method: str, chaos_generator) -> None: + do_pause = asyncio.Event() + paused = asyncio.Event() + do_unpause = asyncio.Event() + unpaused = asyncio.Event() + received_messages: Dict[str, List[str]] = { + "c1": [], + "c2": [], + } + + async def consume( + group_name: str, + consumer_id: str, + r: RedisConnectionInfo, + key: str, + ) -> None: + try: + async with aclosing(redis.read_stream_by_group( + r, key, group_name, consumer_id, + autoclaim_idle_timeout=500, + )) as agen: + async for msg_id, msg_data in agen: + print(f"-> message: {msg_id} {msg_data!r}") + 
received_messages[consumer_id].append(msg_data[b"idx"]) + except asyncio.CancelledError: + return + except Exception: + traceback.print_exc() + return + + r = RedisConnectionInfo( + aioredis.from_url(url='redis://localhost:9379', socket_timeout=0.5), + service_name=None, + ) + assert isinstance(r.client, aioredis.Redis) + await redis.execute(r, lambda r: r.delete("stream1")) + await redis.execute(r, lambda r: r.xgroup_create("stream1", "group1", b"$", mkstream=True)) + + consumer_tasks = [ + asyncio.create_task(consume("group1", "c1", r, "stream1")), + asyncio.create_task(consume("group1", "c2", r, "stream1")), + ] + await asyncio.sleep(0.1) + interrupt_task = asyncio.create_task(interrupt( + disruption_method, + DockerRedisNode("node", 9379, redis_container), + do_pause=do_pause, + do_unpause=do_unpause, + paused=paused, + unpaused=unpaused, + )) + await asyncio.sleep(0) + + for i in range(5): + await r.client.xadd("stream1", {"idx": i}) + await asyncio.sleep(0.1) + do_pause.set() + await paused.wait() + loop = asyncio.get_running_loop() + loop.call_later(5.0, do_unpause.set) + for i in range(5): + # The Redis server is dead temporarily... + if disruption_method == 'stop': + with pytest.raises(aioredis.exceptions.ConnectionError): + await r.client.xadd("stream1", {"idx": 5 + i}) + elif disruption_method == 'pause': + with pytest.raises(asyncio.TimeoutError): + await r.client.xadd("stream1", {"idx": 5 + i}) + else: + raise RuntimeError("should not reach here") + await asyncio.sleep(0.1) + await unpaused.wait() + print("RESUME TEST", file=sys.stderr) + for i in range(5): + await r.client.xadd("stream1", {"idx": 10 + i}) + await asyncio.sleep(0.1) + print("RESUME TEST DONE", file=sys.stderr) + + await interrupt_task + for t in consumer_tasks: + t.cancel() + try: + await t + except asyncio.CancelledError: + pass + await asyncio.gather(*consumer_tasks, return_exceptions=True) + + # loss happens + all_messages = set(map(int, received_messages["c1"])) | set(map(int, received_messages["c2"])) + print(f"{all_messages=}") + assert all_messages >= set(range(0, 5)) | set(range(10, 15)) + assert len(all_messages) >= 10 + + +@pytest.mark.redis +@pytest.mark.asyncio +@pytest.mark.parametrize("disruption_method", ['stop', 'pause']) +@with_timeout(30.0) +async def test_stream_loadbalance_cluster(redis_cluster: RedisClusterInfo, disruption_method: str, chaos_generator) -> None: + do_pause = asyncio.Event() + paused = asyncio.Event() + do_unpause = asyncio.Event() + unpaused = asyncio.Event() + received_messages: Dict[str, List[str]] = { + "c1": [], + "c2": [], + } + + async def consume( + group_name: str, + consumer_id: str, + r: RedisConnectionInfo, + key: str, + ) -> None: + try: + async with aclosing(redis.read_stream_by_group( + r, key, group_name, consumer_id, + autoclaim_idle_timeout=500, + )) as agen: + async for msg_id, msg_data in agen: + print(f"-> message: {msg_id} {msg_data!r}") + received_messages[consumer_id].append(msg_data[b"idx"]) + except asyncio.CancelledError: + return + except Exception: + traceback.print_exc() + return + + s = RedisConnectionInfo( + aioredis.sentinel.Sentinel( + redis_cluster.sentinel_addrs, + password='develove', + socket_timeout=0.5, + ), + service_name='mymaster', + ) + _execute = aiotools.apartial(redis.execute, s) + await _execute(lambda r: r.delete("stream1")) + await _execute(lambda r: r.xgroup_create("stream1", "group1", b"$", mkstream=True)) + + consumer_tasks = [ + asyncio.create_task(consume("group1", "c1", s, "stream1")), + 
asyncio.create_task(consume("group1", "c2", s, "stream1")), + ] + await asyncio.sleep(0.1) + interrupt_task = asyncio.create_task(interrupt( + disruption_method, + redis_cluster.nodes[0], + do_pause=do_pause, + do_unpause=do_unpause, + paused=paused, + unpaused=unpaused, + redis_password='develove', + )) + await asyncio.sleep(0) + + try: + for i in range(5): + await _execute(lambda r: r.xadd("stream1", {"idx": i})) + await asyncio.sleep(0.1) + do_pause.set() + await paused.wait() + loop = asyncio.get_running_loop() + loop.call_later(5.0, do_unpause.set) + for i in range(5): + # The Redis server is dead temporarily... + await _execute(lambda r: r.xadd("stream1", {"idx": 5 + i})) + await asyncio.sleep(0.1) + await unpaused.wait() + for i in range(5): + await _execute(lambda r: r.xadd("stream1", {"idx": 10 + i})) + await asyncio.sleep(0.1) + finally: + await interrupt_task + for t in consumer_tasks: + t.cancel() + await asyncio.gather(*consumer_tasks, return_exceptions=True) + + # loss may happen + all_messages = set(map(int, received_messages["c1"])) | set(map(int, received_messages["c2"])) + print(f"{all_messages=}") + assert all_messages >= set(range(0, 5)) | set(range(10, 15)) + assert len(all_messages) >= 10 diff --git a/tests/common/redis/types.py b/tests/common/redis/types.py new file mode 100644 index 0000000000..747d1d8ebc --- /dev/null +++ b/tests/common/redis/types.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +from abc import ABCMeta, abstractmethod +import contextlib +from typing import ( + AsyncIterator, + Sequence, + Tuple, +) +import attr + + +@attr.define +class RedisClusterInfo: + node_addrs: Sequence[Tuple[str, int]] + nodes: Sequence[AbstractRedisNode] + sentinel_addrs: Sequence[Tuple[str, int]] + sentinels: Sequence[AbstractRedisNode] + + +class AbstractRedisSentinelCluster(metaclass=ABCMeta): + + def __init__(self, test_ns: str, test_case_ns: str, password: str, service_name: str) -> None: + self.test_ns = test_ns + self.test_case_ns = test_case_ns + self.password = password + self.service_name = service_name + + @contextlib.asynccontextmanager + @abstractmethod + async def make_cluster(self) -> AsyncIterator[RedisClusterInfo]: + raise NotImplementedError + yield self + + +class AbstractRedisNode(metaclass=ABCMeta): + + @property + @abstractmethod + def addr(self) -> Tuple[str, int]: + raise NotImplementedError + + @abstractmethod + async def pause(self) -> None: + raise NotImplementedError + + @abstractmethod + async def unpause(self) -> None: + raise NotImplementedError + + @abstractmethod + async def stop(self, force_kill: bool = False) -> None: + raise NotImplementedError + + @abstractmethod + async def start(self) -> None: + raise NotImplementedError diff --git a/tests/common/redis/utils.py b/tests/common/redis/utils.py new file mode 100644 index 0000000000..54b4394bb3 --- /dev/null +++ b/tests/common/redis/utils.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +import aioredis +import aioredis.exceptions +import async_timeout +import asyncio +import functools +import sys +from typing import ( + Awaitable, + Callable, + Final, + Sequence, + TYPE_CHECKING, + TypeVar, + Union, +) +from typing_extensions import ( + ParamSpec, +) +if TYPE_CHECKING: + from .types import AbstractRedisNode + + +disruptions: Final = { + 'stop': { + 'begin': 'stop', + 'end': 'start', + }, + 'pause': { + 'begin': 'pause', + 'end': 'unpause', + }, +} + + +async def simple_run_cmd(cmdargs: Sequence[Union[str, bytes]], **kwargs) -> asyncio.subprocess.Process: + p = await 
asyncio.create_subprocess_exec(*cmdargs, **kwargs) + await p.wait() + return p + + +async def wait_redis_ready(host: str, port: int, password: str = None) -> None: + r = aioredis.from_url(f"redis://{host}:{port}", password=password, socket_timeout=0.2) + while True: + try: + print("CheckReady.PING", port, file=sys.stderr) + await r.ping() + print("CheckReady.PONG", port, file=sys.stderr) + except aioredis.exceptions.AuthenticationError: + raise + except ( + ConnectionResetError, + aioredis.exceptions.ConnectionError, + ): + await asyncio.sleep(0.1) + except aioredis.exceptions.TimeoutError: + pass + else: + break + + +async def interrupt( + disruption_method: str, + node: AbstractRedisNode, + *, + do_pause: asyncio.Event, + do_unpause: asyncio.Event, + paused: asyncio.Event, + unpaused: asyncio.Event, + redis_password: str = None, +) -> None: + # Interrupt + await do_pause.wait() + print(f"STOPPING {node}", file=sys.stderr) + if disruption_method == "stop": + await node.stop(force_kill=True) + elif disruption_method == "pause": + await node.pause() + print(f"STOPPED {node}", file=sys.stderr) + paused.set() + # Resume + await do_unpause.wait() + print(f"STARTING {node}", file=sys.stderr) + if disruption_method == "stop": + await node.start() + elif disruption_method == "pause": + await node.unpause() + await wait_redis_ready(*node.addr, password=redis_password) + await asyncio.sleep(0.6) + print(f"STARTED {node}", file=sys.stderr) + unpaused.set() + + +_TReturn = TypeVar('_TReturn') +_PInner = ParamSpec('_PInner') + + +# FIXME: mypy 0.910 does not support PEP-612 (ParamSpec) yet... + +def with_timeout(t: float) -> Callable[ # type: ignore + [Callable[_PInner, Awaitable[_TReturn]]], + Callable[_PInner, Awaitable[_TReturn]], +]: + def wrapper( + corofunc: Callable[_PInner, Awaitable[_TReturn]], # type: ignore + ) -> Callable[_PInner, Awaitable[_TReturn]]: # type: ignore + @functools.wraps(corofunc) + async def run(*args: _PInner.args, **kwargs: _PInner.kwargs) -> _TReturn: # type: ignore + async with async_timeout.timeout(t): + return await corofunc(*args, **kwargs) + return run + return wrapper diff --git a/tests/common/test_argparse.py b/tests/common/test_argparse.py new file mode 100644 index 0000000000..d53c03ca0d --- /dev/null +++ b/tests/common/test_argparse.py @@ -0,0 +1,145 @@ +import argparse +import ipaddress + +import pytest + +from ai.backend.common.argparse import ( + port_no, port_range, positive_int, non_negative_int, + HostPortPair, host_port_pair, ipaddr, path, +) + +localhost_ipv4 = ipaddress.ip_address('127.0.0.1') +localhost_ipv6 = ipaddress.ip_address('::1') + + +def test_port_no(): + assert port_no(1) == 1 + assert port_no(20) == 20 + assert port_no(65535) == 65535 + + with pytest.raises(argparse.ArgumentTypeError): + port_no(-1) + with pytest.raises(argparse.ArgumentTypeError): + port_no(0) + with pytest.raises(argparse.ArgumentTypeError): + port_no(65536) + with pytest.raises(argparse.ArgumentTypeError): + port_no(65537) + + +def test_port_range(): + assert port_range('1-2') == (1, 2) + assert port_range('1000-2000') == (1000, 2000) + assert port_range('1-65535') == (1, 65535) + + with pytest.raises(argparse.ArgumentTypeError): + port_range('0-65535') + with pytest.raises(argparse.ArgumentTypeError): + port_range('1-65536') + with pytest.raises(argparse.ArgumentTypeError): + port_range('1-2-3') + with pytest.raises(argparse.ArgumentTypeError): + port_range('1') + with pytest.raises(argparse.ArgumentTypeError): + port_range('xxx') + with 
pytest.raises(argparse.ArgumentTypeError): + port_range('-') + with pytest.raises(argparse.ArgumentTypeError): + port_range('') + with pytest.raises(argparse.ArgumentTypeError): + port_range('10-5') + + +def test_positive_int(): + assert positive_int(1) + assert positive_int(100000) + + with pytest.raises(argparse.ArgumentTypeError): + positive_int(0) + with pytest.raises(argparse.ArgumentTypeError): + positive_int(-1) + with pytest.raises(argparse.ArgumentTypeError): + positive_int(-10) + + +def test_non_positive_int(): + assert non_negative_int(1) + assert non_negative_int(100000) + assert non_negative_int(0) == 0 + + with pytest.raises(argparse.ArgumentTypeError): + non_negative_int(-1) + with pytest.raises(argparse.ArgumentTypeError): + non_negative_int(-10) + + +def test_host_port_pair_direct_creation(): + ip = ipaddress.ip_address('1.2.3.4') + pair = HostPortPair(ip, 8000) + + assert pair.as_sockaddr() == ('1.2.3.4', 8000) + assert '{}'.format(pair) == '1.2.3.4:8000' + assert str(pair) == '1.2.3.4:8000' + + +def test_host_port_pair_parse(): + with pytest.raises(argparse.ArgumentTypeError): + host_port_pair('oihasdfoih') + with pytest.raises(argparse.ArgumentTypeError): + host_port_pair('99999') + with pytest.raises(argparse.ArgumentTypeError): + host_port_pair('oihasdfoih:oixzcghboihx') + with pytest.raises(argparse.ArgumentTypeError): + host_port_pair('oihasdfoih:-1') + with pytest.raises(argparse.ArgumentTypeError): + host_port_pair('oihasdfoih:99999') + with pytest.raises(argparse.ArgumentTypeError): + host_port_pair('oihasdfoih:123.45') + with pytest.raises(argparse.ArgumentTypeError): + host_port_pair(':') + with pytest.raises(argparse.ArgumentTypeError): + host_port_pair('::') + with pytest.raises(argparse.ArgumentTypeError): + host_port_pair(':::') + + a = host_port_pair('oihasdfoih:123') + assert a.host == 'oihasdfoih' + assert a.port == 123 + + a = host_port_pair('[::1]:9871') + assert a.host == localhost_ipv6 + assert a.port == 9871 + + a = host_port_pair('::1:9871') + assert a.host == localhost_ipv6 + assert a.port == 9871 + + +def test_host_port_pair_comparison(): + a = host_port_pair('oihasdfoih:123') + b = host_port_pair('oihasdfoih:123') + assert a == b + b = host_port_pair('oihasdfoih:124') + assert a != b + b = host_port_pair('oihasdfoix:123') + assert a != b + + +def test_ipaddr(): + assert ipaddr('[192.168.0.1]') == ipaddress.ip_address('192.168.0.1') + assert ipaddr('192.168.0.1') == ipaddress.ip_address('192.168.0.1') + assert ipaddr('2001:DB8::1') == ipaddress.ip_address('2001:DB8::1') + + with pytest.raises(argparse.ArgumentTypeError): + ipaddr('50') + with pytest.raises(argparse.ArgumentTypeError): + ipaddr('1.1') + with pytest.raises(argparse.ArgumentTypeError): + ipaddr('1.1.1') + + +def test_path(tmpdir): + assert path(None) is None + assert path(tmpdir) == tmpdir + with pytest.raises(argparse.ArgumentTypeError): + assert path('/path/not/exist/') diff --git a/tests/common/test_config.py b/tests/common/test_config.py new file mode 100644 index 0000000000..7d4e913920 --- /dev/null +++ b/tests/common/test_config.py @@ -0,0 +1,89 @@ +import pickle + +import toml +from toml.decoder import InlineTableDict + +from ai.backend.common.config import override_key, merge, _sanitize_inline_dicts + + +def test_override_key(): + sample = { + 'a': { + 'b': 0, + }, + 'c': 1, + } + override_key(sample, ('a', 'b'), -1) + assert sample['a']['b'] == -1 + assert sample['c'] == 1 + + sample = { + 'a': { + 'b': 0, + }, + 'c': 1, + } + override_key(sample, ('c',), -1) + assert 
sample['a']['b'] == 0 + assert sample['c'] == -1 + + +def test_merge(): + left = { + 'a': { + 'a': 5, + 'b': 0, + }, + 'c': 1, + } + right = { + 'a': { + 'b': 2, + 'c': 3, + }, + 'x': 10, + } + result = merge(left, right) + assert result == { + 'a': { + 'a': 5, + 'b': 2, + 'c': 3, + }, + 'c': 1, + 'x': 10, + } + + +def test_sanitize_inline_dicts(): + sample = ''' + [section] + a = { x = 1, y = 1 } + b = { x = 1, y = { t = 2, u = 2 } } + ''' + + result = toml.loads(sample) + assert isinstance(result['section']['a'], dict) + assert isinstance(result['section']['a'], InlineTableDict) + assert isinstance(result['section']['b'], dict) + assert isinstance(result['section']['b'], InlineTableDict) + assert isinstance(result['section']['b']['y'], dict) + assert isinstance(result['section']['b']['y'], InlineTableDict) + + result = _sanitize_inline_dicts(result) + assert isinstance(result['section']['a'], dict) + assert not isinstance(result['section']['a'], InlineTableDict) + assert isinstance(result['section']['b'], dict) + assert not isinstance(result['section']['b'], InlineTableDict) + assert isinstance(result['section']['b']['y'], dict) + assert not isinstance(result['section']['b']['y'], InlineTableDict) + + # Also ensure the result is picklable. + data = pickle.dumps(result) + result = pickle.loads(data) + assert result == { + 'section': { + 'a': {'x': 1, 'y': 1}, + 'b': {'x': 1, 'y': {'t': 2, 'u': 2}}, + }, + } diff --git a/tests/common/test_distributed.py b/tests/common/test_distributed.py new file mode 100644 index 0000000000..ecf4aa7163 --- /dev/null +++ b/tests/common/test_distributed.py @@ -0,0 +1,303 @@ +from __future__ import annotations + +import asyncio +import tempfile +import threading +import time +from dataclasses import dataclass +from decimal import Decimal +from functools import partial +from multiprocessing import Event, Process, Queue +from pathlib import Path +from typing import ( + Any, + Iterable, + List, +) + +import attr +from etcetra.types import HostPortPair as EtcdHostPortPair +import pytest + +from ai.backend.common.distributed import GlobalTimer +from ai.backend.common.events import AbstractEvent, EventDispatcher, EventProducer +from ai.backend.common.lock import EtcdLock, FileLock +from ai.backend.common.types import AgentId, EtcdRedisConfig, HostPortPair + +from ai.backend.common.etcd import AsyncEtcd, ConfigScopes + + +@dataclass +class TimerNodeContext: + test_ns: str + redis_addr: HostPortPair + interval: float + + +@dataclass +class EtcdLockContext: + namespace: str + addr: EtcdHostPortPair + lock_name: str + + +def drange(start: Decimal, stop: Decimal, step: Decimal) -> Iterable[Decimal]: + while start < stop: + yield start + start += step + + +def dslice(start: Decimal, stop: Decimal, num: int): + """ + A simplified version of numpy.linspace with default options + """ + delta = stop - start + step = delta / (num - 1) + yield from (start + step * Decimal(tick) for tick in range(0, num)) + + +@attr.s(slots=True, frozen=True) +class NoopEvent(AbstractEvent): + name = "_noop" + + test_ns: str = attr.ib() + + def serialize(self) -> tuple: + return (self.test_ns, ) + + @classmethod + def deserialize(cls, value: tuple): + return cls(value[0]) + + +class TimerNode(threading.Thread): + + def __init__( + self, + event_records: list[float], + lock_path: Path, + thread_idx: int, + timer_ctx: TimerNodeContext, + ) -> None: + super().__init__() + self.event_records = event_records + self.lock_path = lock_path + self.thread_idx = thread_idx + self.interval = 
timer_ctx.interval + self.test_ns = timer_ctx.test_ns + self.redis_addr = timer_ctx.redis_addr + + async def timer_node_async(self) -> None: + self.loop = asyncio.get_running_loop() + self.stop_event = asyncio.Event() + + async def _tick(context: Any, source: AgentId, event: NoopEvent) -> None: + print("_tick") + self.event_records.append(time.monotonic()) + + redis_config = EtcdRedisConfig(addr=self.redis_addr) + event_dispatcher = await EventDispatcher.new( + redis_config, + node_id=self.test_ns, + ) + event_producer = await EventProducer.new( + redis_config, + ) + event_dispatcher.consume(NoopEvent, None, _tick) + + timer = GlobalTimer( + FileLock(self.lock_path, timeout=0, debug=True), + event_producer, + lambda: NoopEvent(self.test_ns), + self.interval, + ) + try: + await timer.join() + await self.stop_event.wait() + finally: + await timer.leave() + await event_producer.close() + await event_dispatcher.close() + + def run(self) -> None: + asyncio.run(self.timer_node_async()) + + +@pytest.mark.asyncio +async def test_global_timer_filelock(request, test_ns, redis_container) -> None: + lock_path = Path(tempfile.gettempdir()) / f'{test_ns}.lock' + request.addfinalizer(partial(lock_path.unlink, missing_ok=True)) + event_records: List[float] = [] + num_threads = 7 + num_records = 0 + delay = 3.0 + interval = 0.5 + target_count = (delay / interval) + threads: List[TimerNode] = [] + for thread_idx in range(num_threads): + timer_node = TimerNode( + event_records, + lock_path, + thread_idx, + TimerNodeContext( + test_ns=test_ns, + redis_addr=redis_container[1], + interval=interval, + ), + ) + threads.append(timer_node) + timer_node.start() + print(f"spawned {num_threads} timers") + print(threads) + print("waiting") + time.sleep(delay) + print("stopping timers") + for timer_node in threads: + timer_node.loop.call_soon_threadsafe(timer_node.stop_event.set) + print("joining timer threads") + for timer_node in threads: + timer_node.join() + print("checking records") + print(event_records) + num_records = len(event_records) + print(f"{num_records=}") + assert target_count - 2 <= num_records <= target_count + 2 + + +def etcd_timer_node_process( + queue, + stop_event, + etcd_ctx: EtcdLockContext, + timer_ctx: TimerNodeContext, +) -> None: + asyncio.set_event_loop(asyncio.new_event_loop()) + + async def _main() -> None: + + async def _tick(context: Any, source: AgentId, event: NoopEvent) -> None: + print("_tick") + queue.put(time.monotonic()) + + redis_config = EtcdRedisConfig(addr=timer_ctx.redis_addr) + event_dispatcher = await EventDispatcher.new( + redis_config, + node_id=timer_ctx.test_ns, + ) + event_producer = await EventProducer.new( + redis_config, + ) + event_dispatcher.consume(NoopEvent, None, _tick) + + etcd = AsyncEtcd( + addr=etcd_ctx.addr, + namespace=etcd_ctx.namespace, + scope_prefix_map={ + ConfigScopes.GLOBAL: 'global', + ConfigScopes.SGROUP: 'sgroup/testing', + ConfigScopes.NODE: 'node/i-test', + }, + ) + timer = GlobalTimer( + EtcdLock(etcd_ctx.lock_name, etcd, timeout=None, debug=True), + event_producer, + lambda: NoopEvent(timer_ctx.test_ns), + timer_ctx.interval, + ) + try: + await timer.join() + while not stop_event.is_set(): + await asyncio.sleep(0) + finally: + await timer.leave() + await event_producer.close() + await event_dispatcher.close() + + asyncio.run(_main()) + + +@pytest.mark.asyncio +async def test_global_timer_etcdlock( + test_ns, etcd_container, redis_container, +) -> None: + lock_name = f'{test_ns}lock' + event_records_queue: Queue = Queue() + num_processes = 7 
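+    # Even with several competing processes, the shared EtcdLock should let only one global
+    # timer fire per tick, so roughly (delay / interval) ticks are expected in total
+    # (see the assertion at the end of this test).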
+ num_records = 0 + delay = 3.0 + interval = 0.5 + target_count = (delay / interval) + processes: List[Process] = [] + stop_event = Event() + for proc_idx in range(num_processes): + process = Process( + target=etcd_timer_node_process, + name=f'proc-{proc_idx}', + args=( + event_records_queue, + stop_event, + EtcdLockContext( + addr=etcd_container[1], + namespace=test_ns, + lock_name=lock_name, + ), + TimerNodeContext( + test_ns=test_ns, + redis_addr=redis_container[1], + interval=interval, + ), + ), + ) + process.start() + processes.append(process) + print(f"spawned {num_processes} timers") + print(processes) + print("waiting") + time.sleep(delay) + print("stopping timers") + stop_event.set() + print("joining timer processes") + for timer_node in processes: + timer_node.join() + print("checking records") + event_records: List[float] = [] + while not event_records_queue.empty(): + event_records.append(event_records_queue.get()) + print(event_records) + num_records = len(event_records) + print(f"{num_records=}") + assert target_count - 2 <= num_records <= target_count + 2 + + +@pytest.mark.asyncio +async def test_global_timer_join_leave(request, test_ns, redis_container) -> None: + + event_records = [] + + async def _tick(context: Any, source: AgentId, event: NoopEvent) -> None: + print("_tick") + event_records.append(time.monotonic()) + + redis_config = EtcdRedisConfig(addr=redis_container[1]) + event_dispatcher = await EventDispatcher.new( + redis_config, + node_id=test_ns, + ) + event_producer = await EventProducer.new( + redis_config, + ) + event_dispatcher.consume(NoopEvent, None, _tick) + + lock_path = Path(tempfile.gettempdir()) / f'{test_ns}.lock' + request.addfinalizer(partial(lock_path.unlink, missing_ok=True)) + for _ in range(10): + timer = GlobalTimer( + FileLock(lock_path, timeout=0, debug=True), + event_producer, + lambda: NoopEvent(test_ns), + 0.01, + ) + await timer.join() + await timer.leave() + + await event_producer.close() + await event_dispatcher.close() diff --git a/tests/common/test_docker.py b/tests/common/test_docker.py new file mode 100644 index 0000000000..d9b29e4b70 --- /dev/null +++ b/tests/common/test_docker.py @@ -0,0 +1,292 @@ +import collections +import functools +import itertools +import typing +from ai.backend.common.docker import ( + default_registry, default_repository, + ImageRef, PlatformTagSet, +) + +import pytest + + +def test_image_ref_typing(): + ref = ImageRef('c') + assert isinstance(ref, collections.abc.Hashable) + + +def test_image_ref_parsing(): + ref = ImageRef('c') + assert ref.name == f'{default_repository}/c' + assert ref.architecture == 'x86_64' + assert ref.tag == 'latest' + assert ref.registry == default_registry + assert ref.tag_set == ('latest', set()) + + ref = ImageRef('c:gcc6.3-alpine3.8', architecture='aarch64') + assert ref.name == f'{default_repository}/c' + assert ref.architecture == 'aarch64' + assert ref.tag == 'gcc6.3-alpine3.8' + assert ref.registry == default_registry + assert ref.tag_set == ('gcc6.3', {'alpine'}) + + ref = ImageRef('python:3.6-ubuntu', architecture='amd64') + assert ref.name == f'{default_repository}/python' + assert ref.architecture == 'x86_64' + assert ref.tag == '3.6-ubuntu' + assert ref.registry == default_registry + assert ref.tag_set == ('3.6', {'ubuntu'}) + + ref = ImageRef('kernel-python:3.6-ubuntu') + assert ref.name == f'{default_repository}/kernel-python' + assert ref.tag == '3.6-ubuntu' + assert ref.registry == default_registry + assert ref.tag_set == ('3.6', {'ubuntu'}) + + ref = 
ImageRef('lablup/python-tensorflow:1.10-py36-ubuntu') + assert ref.name == 'lablup/python-tensorflow' + assert ref.tag == '1.10-py36-ubuntu' + assert ref.registry == default_registry + assert ref.tag_set == ('1.10', {'ubuntu', 'py'}) + + ref = ImageRef('lablup/kernel-python:3.6-ubuntu') + assert ref.name == 'lablup/kernel-python' + assert ref.tag == '3.6-ubuntu' + assert ref.registry == default_registry + assert ref.tag_set == ('3.6', {'ubuntu'}) + + # To parse registry URLs correctly, we first need to give + # the valid registry URLs! + ref = ImageRef('myregistry.org/lua', []) + assert ref.name == 'myregistry.org/lua' + assert ref.tag == 'latest' + assert ref.registry == default_registry + assert ref.tag_set == ('latest', set()) + + ref = ImageRef('myregistry.org/lua', ['myregistry.org']) + assert ref.name == 'lua' + assert ref.tag == 'latest' + assert ref.registry == 'myregistry.org' + assert ref.tag_set == ('latest', set()) + + ref = ImageRef('myregistry.org/lua:5.3-alpine', ['myregistry.org']) + assert ref.name == 'lua' + assert ref.tag == '5.3-alpine' + assert ref.registry == 'myregistry.org' + assert ref.tag_set == ('5.3', {'alpine'}) + + # Non-standard port number should be a part of the known registry value. + ref = ImageRef('myregistry.org:999/mybase/python:3.6-cuda9-ubuntu', + ['myregistry.org:999']) + assert ref.name == 'mybase/python' + assert ref.tag == '3.6-cuda9-ubuntu' + assert ref.registry == 'myregistry.org:999' + assert ref.tag_set == ('3.6', {'ubuntu', 'cuda'}) + + ref = ImageRef('myregistry.org/mybase/moon/python:3.6-cuda9-ubuntu', + ['myregistry.org']) + assert ref.name == 'mybase/moon/python' + assert ref.tag == '3.6-cuda9-ubuntu' + assert ref.registry == 'myregistry.org' + assert ref.tag_set == ('3.6', {'ubuntu', 'cuda'}) + + # IP addresses are treated as valid registry URLs. + ref = ImageRef('127.0.0.1:5000/python:3.6-cuda9-ubuntu') + assert ref.name == 'python' + assert ref.tag == '3.6-cuda9-ubuntu' + assert ref.registry == '127.0.0.1:5000' + assert ref.tag_set == ('3.6', {'ubuntu', 'cuda'}) + + # IPv6 addresses must be bracketted. 
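+    # (an unbracketed IPv6 address such as '::1/python' below stays part of the image name and
+    # falls back to the default registry, while '[::1]' and '[::1]:5000' are parsed as registries)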
+ ref = ImageRef('::1/python:3.6-cuda9-ubuntu') + assert ref.name == '::1/python' + assert ref.tag == '3.6-cuda9-ubuntu' + assert ref.registry == default_registry + assert ref.tag_set == ('3.6', {'ubuntu', 'cuda'}) + + ref = ImageRef('[::1]/python:3.6-cuda9-ubuntu') + assert ref.name == 'python' + assert ref.tag == '3.6-cuda9-ubuntu' + assert ref.registry == '[::1]' + assert ref.tag_set == ('3.6', {'ubuntu', 'cuda'}) + + ref = ImageRef('[::1]:5000/python:3.6-cuda9-ubuntu') + assert ref.name == 'python' + assert ref.tag == '3.6-cuda9-ubuntu' + assert ref.registry == '[::1]:5000' + assert ref.tag_set == ('3.6', {'ubuntu', 'cuda'}) + + ref = ImageRef('[212c:9cb9:eada:e57b:84c9:6a9:fbec:bdd2]:1024/python') + assert ref.name == 'python' + assert ref.tag == 'latest' + assert ref.registry == '[212c:9cb9:eada:e57b:84c9:6a9:fbec:bdd2]:1024' + assert ref.tag_set == ('latest', set()) + + with pytest.raises(ValueError): + ref = ImageRef('a:!') + + with pytest.raises(ValueError): + ref = ImageRef('127.0.0.1:5000/a:-x-') + + with pytest.raises(ValueError): + ref = ImageRef('http://127.0.0.1:5000/xyz') + + with pytest.raises(ValueError): + ref = ImageRef('//127.0.0.1:5000/xyz') + + +def test_image_ref_formats(): + ref = ImageRef('python:3.6-cuda9-ubuntu', []) + assert ref.canonical == 'index.docker.io/lablup/python:3.6-cuda9-ubuntu' + assert ref.short == 'lablup/python:3.6-cuda9-ubuntu' + assert str(ref) == ref.canonical + assert repr(ref) == f'' + + ref = ImageRef('myregistry.org/user/python:3.6-cuda9-ubuntu', ['myregistry.org'], 'aarch64') + assert ref.canonical == 'myregistry.org/user/python:3.6-cuda9-ubuntu' + assert ref.short == 'user/python:3.6-cuda9-ubuntu' + assert str(ref) == ref.canonical + assert repr(ref) == f'' + + +def test_platform_tag_set_typing(): + tags = PlatformTagSet(['py36', 'cuda9']) + assert isinstance(tags, collections.abc.Mapping) + assert isinstance(tags, typing.Mapping) + assert not isinstance(tags, collections.abc.MutableMapping) + assert not isinstance(tags, typing.MutableMapping) + + +def test_platform_tag_set(): + tags = PlatformTagSet(['py36', 'cuda9', 'ubuntu16.04', 'mkl2018.3']) + assert 'py' in tags + assert 'cuda' in tags + assert 'ubuntu' in tags + assert 'mkl' in tags + assert tags['py'] == '36' + assert tags['cuda'] == '9' + assert tags['ubuntu'] == '16.04' + assert tags['mkl'] == '2018.3' + + with pytest.raises(ValueError): + tags = PlatformTagSet(['cuda9', 'cuda8']) + + tags = PlatformTagSet(['myplatform9b1', 'other']) + assert 'myplatform' in tags + assert tags['myplatform'] == '9b1' + assert 'other' in tags + assert tags['other'] == '' + + with pytest.raises(ValueError): + tags = PlatformTagSet(['1234']) + + +def test_platform_tag_set_abbreviations(): + pass + + +def test_image_ref_generate_aliases(): + ref = ImageRef('lablup/python-tensorflow:1.5-py36-ubuntu16.04') + aliases = ref.generate_aliases() + possible_names = ['python-tensorflow', 'tensorflow'] + possible_platform_tags = [ + ['1.5'], + ['', 'py', 'py3', 'py36'], + ['', 'ubuntu', 'ubuntu16', 'ubuntu16.04'], + ] + # combinations of abbreviated/omitted platforms tags + for name, ptags in itertools.product( + possible_names, + itertools.product(*possible_platform_tags)): + assert f"{name}:{'-'.join(t for t in ptags if t)}" in aliases + + +def test_image_ref_generate_aliases_with_accelerator(): + ref = ImageRef('lablup/python-tensorflow:1.5-py36-ubuntu16.04-cuda10.0') + aliases = ref.generate_aliases() + possible_names = ['python-tensorflow', 'tensorflow'] + possible_platform_tags = [ + ['1.5'], + ['', 
'py', 'py3', 'py36'], + ['', 'ubuntu', 'ubuntu16', 'ubuntu16.04'], + ['cuda', 'cuda10', 'cuda10.0'], # cannot be empty! + ] + # combinations of abbreviated/omitted platforms tags + for name, ptags in itertools.product( + possible_names, + itertools.product(*possible_platform_tags)): + assert f"{name}:{'-'.join(t for t in ptags if t)}" in aliases + + +def test_image_ref_generate_aliases_of_names(): + # an alias may include only last framework name in the name. + ref = ImageRef('lablup/python-tensorflow:1.5-py36-ubuntu16.04-cuda10.0') + aliases = ref.generate_aliases() + assert 'python-tensorflow' in aliases + assert 'tensorflow' in aliases + assert 'python' not in aliases + + +def test_image_ref_generate_aliases_disallowed(): + # an alias must include the main platform version tag + ref = ImageRef('lablup/python-tensorflow:1.5-py36-ubuntu16.04-cuda10.0') + aliases = ref.generate_aliases() + # always the main version must be included! + assert 'python-tensorflow:py3' not in aliases + assert 'python-tensorflow:py36' not in aliases + assert 'python-tensorflow:ubuntu' not in aliases + assert 'python-tensorflow:ubuntu16.04' not in aliases + assert 'python-tensorflow:cuda' not in aliases + assert 'python-tensorflow:cuda10.0' not in aliases + + +def test_image_ref_ordering(): + # ordering is defined as the tuple-ordering of platform tags. + # (tag components that come first have higher priority when comparing.) + r1 = ImageRef('lablup/python-tensorflow:1.5-py36-ubuntu16.04-cuda10.0') + r2 = ImageRef('lablup/python-tensorflow:1.7-py36-ubuntu16.04-cuda10.0') + r3 = ImageRef('lablup/python-tensorflow:1.7-py37-ubuntu18.04-cuda9.0') + assert r1 < r2 + assert r1 < r3 + assert r2 < r3 + + # only the image-refs with same names can be compared. + rx = ImageRef('lablup/python:3.6-ubuntu') + with pytest.raises(ValueError): + rx < r1 + with pytest.raises(ValueError): + r1 < rx + + # test case added for explicit behavior documentation + # ImageRef(...:ubuntu16.04) > ImageRef(...:ubuntu) == False + # ImageRef(...:ubuntu16.04) > ImageRef(...:ubuntu) == False + # by keeping naming convetion, no need to handle these cases + r4 = ImageRef('lablup/python-tensorflow:1.5-py36-ubuntu16.04-cuda9.0') + r5 = ImageRef('lablup/python-tensorflow:1.5-py36-ubuntu-cuda9.0') + assert not r4 > r5 + assert not r5 > r4 + + +def test_image_ref_merge_aliases(): + # After merging, aliases that indicates two or more references should + # indicate most recent versions. 
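+    # (e.g. the bare 'python-tensorflow' alias below must resolve to refs[6], the newest
+    # 1.7-py37-ubuntu16.04 build, rather than any older or CUDA-specific variant)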
+ refs = [ + ImageRef('lablup/python:3.7-ubuntu18.04'), # 0 + ImageRef('lablup/python-tensorflow:1.5-py36-ubuntu16.04-cuda10.0'), # 1 + ImageRef('lablup/python-tensorflow:1.7-py36-ubuntu16.04-cuda10.0'), # 2 + ImageRef('lablup/python-tensorflow:1.7-py37-ubuntu16.04-cuda9.0'), # 3 + ImageRef('lablup/python-tensorflow:1.5-py36-ubuntu16.04'), # 4 + ImageRef('lablup/python-tensorflow:1.7-py36-ubuntu16.04'), # 5 + ImageRef('lablup/python-tensorflow:1.7-py37-ubuntu16.04'), # 6 + ] + aliases = [ref.generate_aliases() for ref in refs] + aliases = functools.reduce(ImageRef.merge_aliases, aliases) + assert aliases['python-tensorflow'] is refs[6] + assert aliases['python-tensorflow:1.5'] is refs[4] + assert aliases['python-tensorflow:1.7'] is refs[6] + assert aliases['python-tensorflow:1.7-py36'] is refs[5] + assert aliases['python-tensorflow:1.5'] is refs[4] + assert aliases['python-tensorflow:1.5-cuda'] is refs[1] + assert aliases['python-tensorflow:1.7-cuda10'] is refs[2] + assert aliases['python-tensorflow:1.7-cuda9'] is refs[3] + assert aliases['python'] is refs[0] diff --git a/tests/common/test_etcd.py b/tests/common/test_etcd.py new file mode 100644 index 0000000000..ccdf71e12c --- /dev/null +++ b/tests/common/test_etcd.py @@ -0,0 +1,297 @@ +import asyncio + +from etcetra.types import WatchEventType +import pytest + +from ai.backend.common.etcd import ConfigScopes + + +@pytest.mark.asyncio +async def test_basic_crud(etcd): + + await etcd.put('wow', 'abc') + + v = await etcd.get('wow') + assert v == 'abc' + v = await etcd.get_prefix('wow') + assert len(v) == 1 + assert v == {'': 'abc'} + + r = await etcd.replace('wow', 'aaa', 'ccc') + assert r is False + r = await etcd.replace('wow', 'abc', 'def') + assert r is True + v = await etcd.get('wow') + assert v == 'def' + + await etcd.delete('wow') + + v = await etcd.get('wow') + assert v is None + v = await etcd.get_prefix('wow') + assert len(v) == 0 + + +@pytest.mark.asyncio +async def test_quote_for_put_prefix(etcd): + await etcd.put_prefix('data', { + 'aa:bb': { + 'option1': 'value1', + 'option2': 'value2', + 'myhost/path': 'this', + }, + 'aa:cc': 'wow', + 'aa:dd': { + '': 'oops', + }, + }, scope=ConfigScopes.GLOBAL) + v = await etcd.get('data/aa%3Abb/option1') + assert v == 'value1' + v = await etcd.get('data/aa%3Abb/option2') + assert v == 'value2' + v = await etcd.get('data/aa%3Abb/myhost%2Fpath') + assert v == 'this' + v = await etcd.get('data/aa%3Acc') + assert v == 'wow' + v = await etcd.get('data/aa%3Add') + assert v == 'oops' + + +@pytest.mark.asyncio +async def test_unquote_for_get_prefix(etcd): + await etcd.put('obj/aa%3Abb/option1', 'value1') + await etcd.put('obj/aa%3Abb/option2', 'value2') + await etcd.put('obj/aa%3Abb/myhost%2Fpath', 'this') + await etcd.put('obj/aa%3Acc', 'wow') + + v = await etcd.get_prefix('obj') + assert dict(v) == { + 'aa:bb': { + 'option1': 'value1', + 'option2': 'value2', + 'myhost/path': 'this', + }, + 'aa:cc': 'wow', + } + + v = await etcd.get_prefix('obj/aa%3Abb') + assert dict(v) == { + 'option1': 'value1', + 'option2': 'value2', + 'myhost/path': 'this', + } + + v = await etcd.get_prefix('obj/aa%3Acc') + assert dict(v) == {'': 'wow'} + + +@pytest.mark.asyncio +async def test_scope_empty_prefix(gateway_etcd): + # This test case is to ensure compatibility with the legacy managers. + # gateway_etcd is created with a scope prefix map that contains + # ConfigScopes.GLOBAL => '' + # setting so that global scope configurations have the same key + # used before introduction of scoped configurations. 
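+    # In other words, keys written below (e.g. 'wow') are stored without any scope prefix,
+    # exactly as a pre-scoping manager would have stored them.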
+ await gateway_etcd.put('wow', 'abc') + v = await gateway_etcd.get('wow') + assert v == 'abc' + + v = await gateway_etcd.get_prefix('wow') + assert len(v) == 1 + assert v == {'': 'abc'} + + r = await gateway_etcd.replace('wow', 'aaa', 'ccc') + assert r is False + r = await gateway_etcd.replace('wow', 'abc', 'def') + assert r is True + v = await gateway_etcd.get('wow') + assert v == 'def' + + await gateway_etcd.delete('wow') + + v = await gateway_etcd.get('wow') + assert v is None + v = await gateway_etcd.get_prefix('wow') + assert len(v) == 0 + + +@pytest.mark.asyncio +async def test_scope(etcd): + await etcd.put('wow', 'abc', scope=ConfigScopes.GLOBAL) + await etcd.put('wow', 'def', scope=ConfigScopes.SGROUP) + await etcd.put('wow', 'ghi', scope=ConfigScopes.NODE) + v = await etcd.get('wow') + assert v == 'ghi' + + await etcd.delete('wow', scope=ConfigScopes.NODE) + v = await etcd.get('wow') + assert v == 'def' + + await etcd.delete('wow', scope=ConfigScopes.SGROUP) + v = await etcd.get('wow') + assert v == 'abc' + + await etcd.delete('wow', scope=ConfigScopes.GLOBAL) + v = await etcd.get('wow') + assert v is None + + await etcd.put('wow', '000', scope=ConfigScopes.NODE) + v = await etcd.get('wow') + assert v == '000' + + +@pytest.mark.asyncio +async def test_scope_dict(etcd): + await etcd.put_dict({'point/x': '1', 'point/y': '2'}, scope=ConfigScopes.GLOBAL) + await etcd.put_dict({'point/y': '3'}, scope=ConfigScopes.SGROUP) + await etcd.put_dict({'point/x': '4', 'point/z': '5'}, scope=ConfigScopes.NODE) + v = await etcd.get_prefix('point', scope=ConfigScopes.MERGED) + assert v == {'x': '4', 'y': '3', 'z': '5'} + v = await etcd.get_prefix('point', scope=ConfigScopes.SGROUP) + assert v == {'x': '1', 'y': '3'} + v = await etcd.get_prefix('point', scope=ConfigScopes.GLOBAL) + assert v == {'x': '1', 'y': '2'} + + await etcd.delete_prefix('point', scope=ConfigScopes.NODE) + v = await etcd.get_prefix('point', scope=ConfigScopes.MERGED) + assert v == {'x': '1', 'y': '3'} + + await etcd.delete_prefix('point', scope=ConfigScopes.SGROUP) + v = await etcd.get_prefix('point', scope=ConfigScopes.MERGED) + assert v == {'x': '1', 'y': '2'} + + await etcd.delete_prefix('point', scope=ConfigScopes.GLOBAL) + v = await etcd.get_prefix('point', scope=ConfigScopes.MERGED) + assert len(v) == 0 + + +@pytest.mark.asyncio +async def test_multi(etcd): + + v = await etcd.get('foo') + assert v is None + v = await etcd.get('bar') + assert v is None + + await etcd.put_dict({'foo': 'x', 'bar': 'y'}) + v = await etcd.get('foo') + assert v == 'x' + v = await etcd.get('bar') + assert v == 'y' + + await etcd.delete_multi(['foo', 'bar']) + v = await etcd.get('foo') + assert v is None + v = await etcd.get('bar') + assert v is None + + +@pytest.mark.asyncio +async def test_watch(etcd): + + records = [] + records_prefix = [] + r_ready = asyncio.Event() + rp_ready = asyncio.Event() + + async def _record(): + try: + async for ev in etcd.watch('wow', ready_event=r_ready): + records.append(ev) + except asyncio.CancelledError: + pass + + async def _record_prefix(): + try: + async for ev in etcd.watch_prefix('wow', ready_event=rp_ready): + records_prefix.append(ev) + except asyncio.CancelledError: + pass + + t1 = asyncio.create_task(_record()) + t2 = asyncio.create_task(_record_prefix()) + + await r_ready.wait() + await rp_ready.wait() + + await etcd.put('wow', '123') + await etcd.delete('wow') + await etcd.put('wow/child', 'hello') + await etcd.delete_prefix('wow') + + await asyncio.sleep(0.2) + t1.cancel() + t2.cancel() + await t1 
+ await t2 + + assert len(records) == 2 + assert records[0].key == 'wow' + assert records[0].event == WatchEventType.PUT + assert records[0].value == '123' + assert records[1].key == 'wow' + assert records[1].event == WatchEventType.DELETE + assert records[1].value == '' + + assert len(records_prefix) == 4 + assert records_prefix[0].key == 'wow' + assert records_prefix[0].event == WatchEventType.PUT + assert records_prefix[0].value == '123' + assert records_prefix[1].key == 'wow' + assert records_prefix[1].event == WatchEventType.DELETE + assert records_prefix[1].value == '' + assert records_prefix[2].key == 'wow/child' + assert records_prefix[2].event == WatchEventType.PUT + assert records_prefix[2].value == 'hello' + assert records_prefix[3].key == 'wow/child' + assert records_prefix[3].event == WatchEventType.DELETE + assert records_prefix[3].value == '' + + +@pytest.mark.asyncio +async def test_watch_once(etcd): + + records = [] + records_prefix = [] + r_ready = asyncio.Event() + rp_ready = asyncio.Event() + + async def _record(): + try: + async for ev in etcd.watch('wow', once=True, ready_event=r_ready): + records.append(ev) + except asyncio.CancelledError: + pass + + async def _record_prefix(): + try: + async for ev in etcd.watch_prefix('wow/city', once=True, ready_event=rp_ready): + records_prefix.append(ev) + except asyncio.CancelledError: + pass + + t1 = asyncio.create_task(_record()) + t2 = asyncio.create_task(_record_prefix()) + await r_ready.wait() + await rp_ready.wait() + + await etcd.put('wow/city1', 'seoul') + await etcd.put('wow/city2', 'daejeon') + await etcd.put('wow', 'korea') + await etcd.delete_prefix('wow') + + await asyncio.sleep(0.2) + t1.cancel() + t2.cancel() + await t1 + await t2 + + assert len(records) == 1 + assert records[0].key == 'wow' + assert records[0].event == WatchEventType.PUT + assert records[0].value == 'korea' + + assert len(records_prefix) == 1 + assert records_prefix[0].key == 'wow/city1' + assert records_prefix[0].event == WatchEventType.PUT + assert records_prefix[0].value == 'seoul' diff --git a/tests/common/test_events.py b/tests/common/test_events.py new file mode 100644 index 0000000000..ada48f5dd4 --- /dev/null +++ b/tests/common/test_events.py @@ -0,0 +1,179 @@ +import asyncio +from typing import Type +from types import TracebackType + +import aiotools +import attr +import pytest + +from ai.backend.common.events import ( + AbstractEvent, + CoalescingOptions, + CoalescingState, + EventDispatcher, + EventProducer, +) +from ai.backend.common.types import ( + AgentId, + EtcdRedisConfig, +) +from ai.backend.common import redis + + +@attr.s(slots=True, frozen=True) +class DummyEvent(AbstractEvent): + name = "testing" + + value: int = attr.ib() + + def serialize(self) -> tuple: + return (self.value + 1, ) + + @classmethod + def deserialize(cls, value: tuple): + return cls(value[0] + 1) + + +@pytest.mark.asyncio +async def test_dispatch(redis_container) -> None: + app = object() + + redis_config = EtcdRedisConfig(addr=redis_container[1]) + dispatcher = await EventDispatcher.new(redis_config) + producer = await EventProducer.new(redis_config) + + records = set() + + async def acb(context: object, source: AgentId, event: DummyEvent) -> None: + assert context is app + assert source == AgentId('i-test') + assert isinstance(event, DummyEvent) + assert event.name == "testing" + assert event.value == 1001 + await asyncio.sleep(0.01) + records.add('async') + + def scb(context: object, source: AgentId, event: DummyEvent) -> None: + assert context is app + 
assert source == AgentId('i-test') + assert isinstance(event, DummyEvent) + assert event.name == "testing" + assert event.value == 1001 + records.add('sync') + + dispatcher.subscribe(DummyEvent, app, acb) + dispatcher.subscribe(DummyEvent, app, scb) + await asyncio.sleep(0.1) + + # Dispatch the event + await producer.produce_event(DummyEvent(999), source='i-test') + await asyncio.sleep(0.2) + assert records == {'async', 'sync'} + + await redis.execute(producer.redis_client, lambda r: r.flushdb()) + await producer.close() + await dispatcher.close() + + +@pytest.mark.asyncio +async def test_error_on_dispatch(redis_container) -> None: + app = object() + exception_log: list[str] = [] + + async def handle_exception( + et: Type[Exception], + exc: Exception, + tb: TracebackType, + ) -> None: + exception_log.append(type(exc).__name__) + + redis_config = EtcdRedisConfig(addr=redis_container[1]) + dispatcher = await EventDispatcher.new( + redis_config, + consumer_exception_handler=handle_exception, + subscriber_exception_handler=handle_exception, + ) + producer = await EventProducer.new(redis_config) + + async def acb(context: object, source: AgentId, event: DummyEvent) -> None: + assert context is app + assert source == AgentId('i-test') + assert isinstance(event, DummyEvent) + raise ZeroDivisionError + + def scb(context: object, source: AgentId, event: DummyEvent) -> None: + assert context is app + assert source == AgentId('i-test') + assert isinstance(event, DummyEvent) + raise OverflowError + + dispatcher.subscribe(DummyEvent, app, scb) + dispatcher.subscribe(DummyEvent, app, acb) + await asyncio.sleep(0.1) + + await producer.produce_event(DummyEvent(0), source='i-test') + await asyncio.sleep(0.5) + assert len(exception_log) == 2 + assert 'ZeroDivisionError' in exception_log + assert 'OverflowError' in exception_log + + await redis.execute(producer.redis_client, lambda r: r.flushdb()) + await producer.close() + await dispatcher.close() + + +@pytest.mark.asyncio +async def test_event_dispatcher_rate_control(): + opts = CoalescingOptions(max_wait=0.1, max_batch_size=5) + state = CoalescingState() + assert await state.rate_control(None) is True + epsilon = 0.01 + clock = aiotools.VirtualClock() + with clock.patch_loop(): + for _ in range(2): # repetition should not affect the behavior + t1 = asyncio.create_task(state.rate_control(opts)) + await asyncio.sleep(0.1 + epsilon) + assert t1.result() is True + + t1 = asyncio.create_task(state.rate_control(opts)) + t2 = asyncio.create_task(state.rate_control(opts)) + t3 = asyncio.create_task(state.rate_control(opts)) + await asyncio.sleep(0.1 + epsilon) + assert t1.result() is False + assert t2.result() is False + assert t3.result() is True + + t1 = asyncio.create_task(state.rate_control(opts)) + await asyncio.sleep(0.1 + epsilon) + t2 = asyncio.create_task(state.rate_control(opts)) + await asyncio.sleep(0.1 + epsilon) + assert t1.result() is True + assert t2.result() is True + + t1 = asyncio.create_task(state.rate_control(opts)) + t2 = asyncio.create_task(state.rate_control(opts)) + t3 = asyncio.create_task(state.rate_control(opts)) + t4 = asyncio.create_task(state.rate_control(opts)) + t5 = asyncio.create_task(state.rate_control(opts)) + await asyncio.sleep(epsilon) # should be executed immediately + assert t1.result() is False + assert t2.result() is False + assert t3.result() is False + assert t4.result() is False + assert t5.result() is True + + t1 = asyncio.create_task(state.rate_control(opts)) + t2 = asyncio.create_task(state.rate_control(opts)) + 
t3 = asyncio.create_task(state.rate_control(opts)) + t4 = asyncio.create_task(state.rate_control(opts)) + t5 = asyncio.create_task(state.rate_control(opts)) + t6 = asyncio.create_task(state.rate_control(opts)) + await asyncio.sleep(epsilon) + assert t1.result() is False + assert t2.result() is False + assert t3.result() is False + assert t4.result() is False + assert t5.result() is True + assert not t6.done() # t5 executed but t6 should be pending + await asyncio.sleep(0.1 + epsilon) + assert t6.result() is True diff --git a/tests/common/test_identity.py b/tests/common/test_identity.py new file mode 100644 index 0000000000..cf746491c7 --- /dev/null +++ b/tests/common/test_identity.py @@ -0,0 +1,199 @@ +import secrets +import socket +import random +from unittest.mock import patch, MagicMock + +import pytest +import aiodns +from aioresponses import aioresponses + +import ai.backend.common.identity + + +def test_is_containerized(): + mocked_path = MagicMock() + mocked_path.read_text.return_value = '\n'.join([ + '13:name=systemd:/docker-ce/docker/67bfa4f7a0d87eb95592dd95ce851fe6625db539fa2ea616000202b328c32c92', + '12:pids:/docker-ce/docker/67bfa4f7a0d87eb95592dd95ce851fe6625db539fa2ea616000202b328c32c92', + '11:hugetlb:/docker-ce/docker/67bfa4f7a0d87eb95592dd95ce851fe6625db539fa2ea616000202b328c32c92', + '10:net_prio:/docker-ce/docker/67bfa4f7a0d87eb95592dd95ce851fe6625db539fa2ea616000202b328c32c92', + '9:perf_event:/docker-ce/docker/67bfa4f7a0d87eb95592dd95ce851fe6625db539fa2ea616000202b328c32c92', + '8:net_cls:/docker-ce/docker/67bfa4f7a0d87eb95592dd95ce851fe6625db539fa2ea616000202b328c32c92', + '7:freezer:/docker-ce/docker/67bfa4f7a0d87eb95592dd95ce851fe6625db539fa2ea616000202b328c32c92', + '6:devices:/docker-ce/docker/67bfa4f7a0d87eb95592dd95ce851fe6625db539fa2ea616000202b328c32c92', + '5:memory:/docker-ce/docker/67bfa4f7a0d87eb95592dd95ce851fe6625db539fa2ea616000202b328c32c92', + '4:blkio:/docker-ce/docker/67bfa4f7a0d87eb95592dd95ce851fe6625db539fa2ea616000202b328c32c92', + '3:cpuacct:/docker-ce/docker/67bfa4f7a0d87eb95592dd95ce851fe6625db539fa2ea616000202b328c32c92', + '2:cpu:/docker-ce/docker/67bfa4f7a0d87eb95592dd95ce851fe6625db539fa2ea616000202b328c32c92', + '1:cpuset:/docker-ce/docker/67bfa4f7a0d87eb95592dd95ce851fe6625db539fa2ea616000202b328c32c92', + ]) + with patch('ai.backend.common.identity.Path', return_value=mocked_path): + assert ai.backend.common.identity.is_containerized() + mocked_path = MagicMock() + mocked_path.read_text.return_value = '\n'.join([ + '11:devices:/user.slice', + '10:pids:/user.slice/user-1000.slice', + '9:hugetlb:/', + '8:cpuset:/', + '7:blkio:/user.slice', + '6:memory:/user.slice', + '5:cpu,cpuacct:/user.slice', + '4:freezer:/', + '3:net_cls,net_prio:/', + '2:perf_event:/', + '1:name=systemd:/user.slice/user-1000.slice/session-3.scope', + ]) + with patch('ai.backend.common.identity.Path', return_value=mocked_path): + assert not ai.backend.common.identity.is_containerized() + mocked_path = MagicMock() + mocked_path.side_effect = FileNotFoundError('no such file') + with patch('ai.backend.common.identity.Path', return_value=mocked_path): + assert not ai.backend.common.identity.is_containerized() + + +@pytest.mark.skip +@pytest.mark.asyncio +@pytest.mark.parametrize('provider', ['amazon', 'google', 'azure', None]) +async def test_get_instance_id(mocker, provider): + ai.backend.common.identity.current_provider = provider + ai.backend.common.identity._defined = False + ai.backend.common.identity._define_functions() + + with aioresponses() as m: + random_id = 
secrets.token_hex(16) + if provider == 'amazon': + m.get('http://169.254.169.254/latest/meta-data/instance-id', + body=random_id) + ret = await ai.backend.common.identity.get_instance_id() + assert ret == random_id + elif provider == 'azure': + m.get( + 'http://169.254.169.254/metadata/instance?version=2017-03-01', + payload={ + 'compute': { + 'vmId': random_id, + }, + }) + ret = await ai.backend.common.identity.get_instance_id() + assert ret == random_id + elif provider == 'google': + m.get('http://metadata.google.internal/computeMetadata/v1/instance/id', + body=random_id) + ret = await ai.backend.common.identity.get_instance_id() + assert ret == random_id + elif provider is None: + with patch('socket.gethostname', return_value='myname'): + ret = await ai.backend.common.identity.get_instance_id() + assert ret == 'i-myname' + + +@pytest.mark.skip +@pytest.mark.asyncio +@pytest.mark.parametrize('provider', ['amazon', 'google', 'azure', None]) +async def test_get_instance_id_failures(mocker, provider): + ai.backend.common.identity.current_provider = provider + ai.backend.common.identity._defined = False + ai.backend.common.identity._define_functions() + + with aioresponses(): + # If we don't set any mocked responses, aioresponses will raise ClientConnectionError. + ret = await ai.backend.common.identity.get_instance_id() + assert ret == f'i-{socket.gethostname()}' + + +@pytest.mark.skip +@pytest.mark.asyncio +@pytest.mark.parametrize('provider', ['amazon', 'google', 'azure', None]) +async def test_get_instance_ip(mocker, provider): + ai.backend.common.identity.current_provider = provider + ai.backend.common.identity._defined = False + ai.backend.common.identity._define_functions() + + with aioresponses() as m: + random_ip = '.'.join(str(random.randint(0, 255)) for _ in range(4)) + if provider == 'amazon': + m.get('http://169.254.169.254/latest/meta-data/local-ipv4', + body=random_ip) + ret = await ai.backend.common.identity.get_instance_ip() + assert ret == random_ip + elif provider == 'azure': + m.get( + 'http://169.254.169.254/metadata/instance?version=2017-03-01', + payload={ + 'network': { + 'interface': [ + { + 'ipv4': { + 'ipaddress': [ + {'ipaddress': random_ip}, + ], + }, + }, + ], + }, + }) + ret = await ai.backend.common.identity.get_instance_ip() + assert ret == random_ip + elif provider == 'google': + m.get('http://metadata.google.internal/computeMetadata/v1/instance/network-interfaces/0/ip', + body=random_ip) + ret = await ai.backend.common.identity.get_instance_ip() + assert ret == random_ip + elif provider is None: + mocked_ares_host_result = MagicMock() + mocked_ares_host_result.addresses = ['10.1.2.3'] + mocked_resolver = MagicMock() + + async def coro_return_mocked_result(*args): + return mocked_ares_host_result + + mocked_resolver.gethostbyname = coro_return_mocked_result + with patch('aiodns.DNSResolver', return_value=mocked_resolver), \ + patch('socket.gethostname', return_value='myname'): + ret = await ai.backend.common.identity.get_instance_ip() + assert ret == '10.1.2.3' + + async def coro_raise_error(*args): + raise aiodns.error.DNSError('domain not found') + + mocked_resolver = MagicMock() + mocked_resolver.gethostbyname = coro_raise_error + with patch('aiodns.DNSResolver', return_value=mocked_resolver), \ + patch('socket.gethostname', return_value='myname'): + ret = await ai.backend.common.identity.get_instance_ip() + assert ret == '127.0.0.1' + + +@pytest.mark.skip +@pytest.mark.asyncio +@pytest.mark.parametrize('provider', ['amazon', 'google', 'azure', None]) 
+async def test_get_instance_type(mocker, provider): + ai.backend.common.identity.current_provider = provider + ai.backend.common.identity._defined = False + ai.backend.common.identity._define_functions() + + with aioresponses() as m: + random_type = secrets.token_hex(16) + if provider == 'amazon': + m.get('http://169.254.169.254/latest/meta-data/instance-type', + body=random_type) + ret = await ai.backend.common.identity.get_instance_type() + assert ret == random_type + elif provider == 'azure': + m.get( + 'http://169.254.169.254/metadata/instance?version=2017-03-01', + payload={ + 'compute': { + 'vmSize': random_type, + }, + }) + ret = await ai.backend.common.identity.get_instance_type() + assert ret == random_type + elif provider == 'google': + m.get('http://metadata.google.internal/computeMetadata/v1/instance/machine-type', + body=random_type) + ret = await ai.backend.common.identity.get_instance_type() + assert ret == random_type + elif provider is None: + ret = await ai.backend.common.identity.get_instance_type() + assert ret == 'default' diff --git a/tests/common/test_json.py b/tests/common/test_json.py new file mode 100644 index 0000000000..ae1df2df1a --- /dev/null +++ b/tests/common/test_json.py @@ -0,0 +1,19 @@ +import datetime +import json +import uuid + +from ai.backend.common.json import ExtendedJSONEncoder + + +def test_encode(): + ret = json.dumps( + {'x': uuid.UUID('78bd79c7-214b-4ec6-9a22-3461785bced6')}, + cls=ExtendedJSONEncoder, + ) + assert '"78bd79c7-214b-4ec6-9a22-3461785bced6"' in ret + ret = json.dumps( + {'x': datetime.datetime(year=2000, month=1, day=1, hour=11, minute=30, second=22, + tzinfo=datetime.timezone.utc)}, + cls=ExtendedJSONEncoder, + ) + assert '2000-01-01T11:30:22+00:00' in ret diff --git a/tests/common/test_logging.py b/tests/common/test_logging.py new file mode 100644 index 0000000000..a820672072 --- /dev/null +++ b/tests/common/test_logging.py @@ -0,0 +1,81 @@ +import logging +import os +from pathlib import Path +import threading +import time + +from ai.backend.common.logging import Logger, BraceStyleAdapter + + +test_log_config = { + 'level': 'DEBUG', + 'drivers': ['console'], + 'pkg-ns': {'': 'DEBUG'}, + 'console': { + 'colored': True, + }, +} + +test_log_path = Path(f'/tmp/bai-testing-agent-logger-{os.getpid()}.sock') + +log = BraceStyleAdapter(logging.getLogger('ai.backend.common.testing')) + + +def get_logger_thread(): + for t in threading.enumerate(): + if t.name == 'Logger': + return t + return None + + +def test_logger(unused_tcp_port): + test_log_path.parent.mkdir(parents=True, exist_ok=True) + log_endpoint = f'ipc://{test_log_path}' + logger = Logger(test_log_config, is_master=True, log_endpoint=log_endpoint) + with logger: + assert test_log_path.exists() + log.warning('blizzard warning {}', 123) + assert get_logger_thread() is not None + assert not test_log_path.exists() + assert get_logger_thread() is None + + +class NotPicklableClass: + """A class that cannot be pickled.""" + + def __reduce__(self): + raise TypeError('this is not picklable') + + +class NotUnpicklableClass: + """A class that is pickled successfully but cannot be unpickled.""" + + def __init__(self, x): + if x == 1: + raise TypeError('this is not unpicklable') + + def __reduce__(self): + return type(self), (1, ) + + +def test_logger_not_picklable(): + test_log_path.parent.mkdir(parents=True, exist_ok=True) + log_endpoint = f'ipc://{test_log_path}' + logger = Logger(test_log_config, is_master=True, log_endpoint=log_endpoint) + with logger: + # The following line should 
not throw an error. + log.warning('blizzard warning {}', NotPicklableClass()) + assert not test_log_path.exists() + assert get_logger_thread() is None + + +def test_logger_not_unpicklable(): + test_log_path.parent.mkdir(parents=True, exist_ok=True) + log_endpoint = f'ipc://{test_log_path}' + logger = Logger(test_log_config, is_master=True, log_endpoint=log_endpoint) + with logger: + log.warning('blizzard warning {}', NotUnpicklableClass(0)) + time.sleep(1.0) + assert get_logger_thread() is not None, 'logger thread must be alive' + assert not test_log_path.exists() + assert get_logger_thread() is None diff --git a/tests/common/test_msgpack.py b/tests/common/test_msgpack.py new file mode 100644 index 0000000000..23f6dc775b --- /dev/null +++ b/tests/common/test_msgpack.py @@ -0,0 +1,29 @@ +from ai.backend.common import msgpack + + +def test_msgpack_with_unicode(): + # msgpack-python module requires special treatment + # to distinguish unicode strings and binary data + # correctly, and ai.backend.common.msgpack wraps it for that. + + data = [b'\xff', '한글', 12345, 12.5] + packed = msgpack.packb(data) + unpacked = msgpack.unpackb(packed) + + # We also use tuples when unpacking for performance. + assert unpacked == tuple(data) + + +def test_msgpack_kwargs(): + x = {'cpu': [0.42, 0.44], 'cuda_mem': [0.0, 0.0], 'cuda_util': [0.0, 0.0], 'mem': [30.0, 30.0]} + packed = msgpack.packb(x) + unpacked = msgpack.unpackb(packed, use_list=False) + assert isinstance(unpacked['cpu'], tuple) + assert isinstance(unpacked['mem'], tuple) + assert isinstance(unpacked['cuda_mem'], tuple) + assert isinstance(unpacked['cuda_util'], tuple) + unpacked = msgpack.unpackb(packed, use_list=True) + assert isinstance(unpacked['cpu'], list) + assert isinstance(unpacked['mem'], list) + assert isinstance(unpacked['cuda_mem'], list) + assert isinstance(unpacked['cuda_util'], list) diff --git a/tests/common/test_plugin.py b/tests/common/test_plugin.py new file mode 100644 index 0000000000..64b2f5396c --- /dev/null +++ b/tests/common/test_plugin.py @@ -0,0 +1,294 @@ +import asyncio +import functools +from typing import ( + Any, + Mapping, +) +from unittest.mock import AsyncMock + +from ai.backend.common.plugin import ( + AbstractPlugin, + BasePluginContext, +) +from ai.backend.common.plugin.hook import ( + HookPlugin, + HookPluginContext, + Reject, + PASSED, + REJECTED, + ERROR, + ALL_COMPLETED, + FIRST_COMPLETED, +) + +import pytest + + +class DummyPlugin(AbstractPlugin): + + def __init__(self, plugin_config, local_config) -> None: + super().__init__(plugin_config, local_config) + + async def init(self, context: Any = None) -> None: + pass + + async def cleanup(self) -> None: + pass + + async def update_plugin_config(self, plugin_config: Mapping[str, Any]) -> None: + pass + + +class DummyEntrypoint: + + def __init__(self, name: str, load_result: Any) -> None: + self.name = name + self._load_result = load_result + + def load(self) -> Any: + return self._load_result + + +def mock_entrypoints_with_instance(plugin_group_name: str, *, mocked_plugin): + # Since mocked_plugin is already an instance constructed via AsyncMock, + # we emulate the original constructor using a lambda function.
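+ # The plugin context presumably calls entrypoint.load() and then invokes the result + # with (plugin_config, local_config); the lambda below stands in for that constructor + # and simply returns the pre-built mock instance.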
+ yield DummyEntrypoint('dummy', lambda plugin_config, local_config: mocked_plugin) + + +def mock_entrypoints_with_class(plugin_group_name: str, *, plugin_cls): + if isinstance(plugin_cls, list): + yield from (DummyEntrypoint(getattr(p, '_entrypoint_name', 'dummy'), p) for p in plugin_cls) + else: + yield DummyEntrypoint('dummy', plugin_cls) + + +@pytest.mark.asyncio +async def test_plugin_context_init_cleanup(etcd, mocker): + print('test plugin context init cleanup') + mocked_plugin = AsyncMock(DummyPlugin) + mocked_entrypoints = functools.partial(mock_entrypoints_with_instance, + mocked_plugin=mocked_plugin) + mocker.patch('ai.backend.common.plugin.pkg_resources.iter_entry_points', mocked_entrypoints) + ctx = BasePluginContext(etcd, {}) + try: + assert not ctx.plugins + await ctx.init() + assert ctx.plugins + ctx.plugins['dummy'].init.assert_awaited_once() + finally: + await ctx.cleanup() + ctx.plugins['dummy'].cleanup.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_plugin_context_config(etcd, mocker): + mocked_entrypoints = functools.partial(mock_entrypoints_with_class, plugin_cls=DummyPlugin) + mocker.patch('ai.backend.common.plugin.pkg_resources.iter_entry_points', mocked_entrypoints) + await etcd.put('config/plugins/XXX/dummy/etcd-key', 'etcd-value') + ctx = BasePluginContext( + etcd, + {'local-key': 'local-value'}, + ) + try: + assert not ctx.plugins + await ctx.init() + assert ctx.plugins + assert isinstance(ctx.plugins['dummy'], DummyPlugin) + ctx.plugins['dummy'].local_config['local-key'] == 'local-value' + ctx.plugins['dummy'].plugin_config['etcd-key'] == 'etcd-value' + finally: + await ctx.cleanup() + + +@pytest.mark.asyncio +async def test_plugin_context_config_autoupdate(etcd, mocker): + mocked_plugin = AsyncMock(DummyPlugin) + mocked_entrypoints = functools.partial(mock_entrypoints_with_instance, + mocked_plugin=mocked_plugin) + mocker.patch('ai.backend.common.plugin.pkg_resources.iter_entry_points', mocked_entrypoints) + await etcd.put_prefix('config/plugins/XXX/dummy', {'a': '1', 'b': '2'}) + ctx = BasePluginContext( + etcd, + {'local-key': 'local-value'}, + ) + try: + await ctx.init() + await asyncio.sleep(0.01) + await etcd.put_prefix('config/plugins/XXX/dummy', {'a': '3', 'b': '4'}) + await asyncio.sleep(0.6) # we should see the update only once + await etcd.put_prefix('config/plugins/XXX/dummy', {'a': '5', 'b': '6'}) + await asyncio.sleep(0.3) + print(mocked_plugin.update_plugin_config) + args_list = mocked_plugin.update_plugin_config.await_args_list + assert len(args_list) == 2 + assert args_list[0].args[0] == {'a': '3', 'b': '4'} + assert args_list[1].args[0] == {'a': '5', 'b': '6'} + finally: + await ctx.cleanup() + + +class DummyHookPassingPlugin(HookPlugin): + + config_watch_enabled = False + + _entrypoint_name = 'hook-p' + + def get_handlers(self): + return [ + ('HOOK1', self.hook1_handler), + ('HOOK2', self.hook2_handler), + ] + + async def init(self, context: Any = None) -> None: + pass + + async def cleanup(self) -> None: + pass + + async def update_plugin_config(self, new_config: Mapping[str, Any]) -> None: + pass + + async def hook1_handler(self, arg1, arg2): + assert arg1 == 'a' + assert arg2 == 'b' + return 1 + + async def hook2_handler(self, arg1, arg2): + assert arg1 == 'c' + assert arg2 == 'd' + return 2 + + +class DummyHookRejectingPlugin(HookPlugin): + + config_watch_enabled = False + + _entrypoint_name = 'hook-r' + + def get_handlers(self): + return [ + ('HOOK1', self.hook1_handler), + ('HOOK2', self.hook2_handler), + ] + + async def 
init(self, context: Any = None) -> None: + pass + + async def cleanup(self) -> None: + pass + + async def update_plugin_config(self, new_config: Mapping[str, Any]) -> None: + pass + + async def hook1_handler(self, arg1, arg2): + assert arg1 == 'a' + assert arg2 == 'b' + raise Reject('dummy rejected 1') + + async def hook2_handler(self, arg1, arg2): + assert arg1 == 'c' + assert arg2 == 'd' + return 3 + + +class DummyHookErrorPlugin(HookPlugin): + + config_watch_enabled = False + + _entrypoint_name = 'hook-e' + + def get_handlers(self): + return [ + ('HOOK3', self.hook3_handler), + ] + + async def init(self, context: Any = None) -> None: + pass + + async def cleanup(self) -> None: + pass + + async def update_plugin_config(self, new_config: Mapping[str, Any]) -> None: + pass + + async def hook3_handler(self, arg1, arg2): + assert arg1 == 'e' + assert arg2 == 'f' + raise ZeroDivisionError('oops') + + +@pytest.mark.asyncio +async def test_hook_dispatch(etcd, mocker): + mocked_entrypoints = functools.partial( + mock_entrypoints_with_class, + plugin_cls=[DummyHookPassingPlugin, DummyHookRejectingPlugin, DummyHookErrorPlugin], + ) + mocker.patch('ai.backend.common.plugin.pkg_resources.iter_entry_points', mocked_entrypoints) + ctx = HookPluginContext(etcd, {}) + try: + await ctx.init() + + hook_result = await ctx.dispatch('HOOK1', ('a', 'b'), return_when=FIRST_COMPLETED) + assert hook_result.status == PASSED + assert hook_result.result == 1 + assert hook_result.reason is None + assert hook_result.src_plugin == 'hook-p' + + # If a plugin rejects when return_when=ALL_COMPLETED, only the rejected result is returned. + hook_result = await ctx.dispatch('HOOK1', ('a', 'b'), return_when=ALL_COMPLETED) + assert hook_result.status == REJECTED + assert hook_result.result is None + assert hook_result.reason == 'dummy rejected 1' + assert hook_result.src_plugin == 'hook-r' + + # Even when all plugins pass, FIRST_COMPLETED executes only the first successful plugin. + hook_result = await ctx.dispatch('HOOK2', ('c', 'd'), return_when=FIRST_COMPLETED) + assert hook_result.status == PASSED + assert hook_result.result == 2 + assert hook_result.reason is None + assert hook_result.src_plugin == 'hook-p' + + # When return_when=ALL_COMPLETED and all plugins succeed, + # the caller receives the results as a list, with src_plugin returned as a matching list. + hook_result = await ctx.dispatch('HOOK2', ('c', 'd'), return_when=ALL_COMPLETED) + assert hook_result.status == PASSED + assert hook_result.result == [2, 3] + assert hook_result.reason is None + assert hook_result.src_plugin == ['hook-p', 'hook-r'] + + # If a plugin raises an arbitrary exception other than Reject, it's marked as ERROR.
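+ # The reason string appears to contain the exception type name and message, + # which is what the substring checks below rely on.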
+ hook_result = await ctx.dispatch('HOOK3', ('e', 'f'), return_when=FIRST_COMPLETED) + assert hook_result.status == ERROR + assert hook_result.result is None + assert 'ZeroDivisionError' in hook_result.reason + assert 'oops' in hook_result.reason + assert hook_result.src_plugin == 'hook-e' + hook_result = await ctx.dispatch('HOOK3', ('e', 'f'), return_when=ALL_COMPLETED) + assert hook_result.status == ERROR + assert hook_result.result is None + assert 'ZeroDivisionError' in hook_result.reason + assert 'oops' in hook_result.reason + assert hook_result.src_plugin == 'hook-e' + finally: + await ctx.cleanup() + + +@pytest.mark.asyncio +async def test_hook_notify(etcd, mocker): + mocked_entrypoints = functools.partial( + mock_entrypoints_with_class, + plugin_cls=[DummyHookPassingPlugin, DummyHookRejectingPlugin, DummyHookErrorPlugin], + ) + mocker.patch('ai.backend.common.plugin.pkg_resources.iter_entry_points', mocked_entrypoints) + ctx = HookPluginContext(etcd, {}) + try: + await ctx.init() + # notify() should return successfully no matter a plugin rejects/fails or not. + hook_result = await ctx.notify('HOOK1', ('a', 'b')) + assert hook_result is None + hook_result = await ctx.notify('HOOK2', ('c', 'd')) + assert hook_result is None + hook_result = await ctx.notify('HOOK3', ('e', 'f')) + assert hook_result is None + finally: + await ctx.cleanup() diff --git a/tests/common/test_service_ports.py b/tests/common/test_service_ports.py new file mode 100644 index 0000000000..ac1b490c76 --- /dev/null +++ b/tests/common/test_service_ports.py @@ -0,0 +1,82 @@ +import pytest + +from ai.backend.common.service_ports import parse_service_ports + + +def test_parse_service_ports(): + result = parse_service_ports('') + assert len(result) == 0 + + result = parse_service_ports('a:http:1230') + assert len(result) == 1 + assert result[0] == { + 'name': 'a', 'protocol': 'http', + 'container_ports': (1230,), + 'host_ports': (None,), + } + + result = parse_service_ports('a:tcp:[5000,5005]') + assert len(result) == 1 + assert result[0] == { + 'name': 'a', 'protocol': 'tcp', + 'container_ports': (5000, 5005), + 'host_ports': (None, None), + } + + result = parse_service_ports('a:tcp:[1230,1240,9000],x:http:3000,t:http:[5000,5001]') + assert len(result) == 3 + assert result[0] == { + 'name': 'a', 'protocol': 'tcp', + 'container_ports': (1230, 1240, 9000), + 'host_ports': (None, None, None), + } + assert result[1] == { + 'name': 'x', 'protocol': 'http', + 'container_ports': (3000,), + 'host_ports': (None,), + } + assert result[2] == { + 'name': 't', 'protocol': 'http', + 'container_ports': (5000, 5001), + 'host_ports': (None, None), + } + + +def test_parse_service_ports_invalid_values(): + with pytest.raises(ValueError, match="Unsupported"): + parse_service_ports('x:unsupported:1234') + + with pytest.raises(ValueError, match="smaller than"): + parse_service_ports('x:http:65536') + + with pytest.raises(ValueError, match="larger than"): + parse_service_ports('x:http:1000') + + with pytest.raises(ValueError, match="Invalid format"): + parse_service_ports('x:http:-1') + + with pytest.raises(ValueError, match="Invalid format"): + parse_service_ports('abcdefg') + + with pytest.raises(ValueError, match="Invalid format"): + parse_service_ports('x:tcp:1234,abcdefg') + + with pytest.raises(ValueError, match="Invalid format"): + parse_service_ports('abcdefg,x:tcp:1234') + + with pytest.raises(ValueError, match="already used"): + parse_service_ports('x:tcp:1234,y:tcp:1234') + + with pytest.raises(ValueError, match="reserved"): + 
parse_service_ports('y:tcp:7711,x:tcp:2200') + + +def test_parse_service_ports_custom_exception(): + with pytest.raises(ZeroDivisionError): + parse_service_ports('x:unsupported:1234', ZeroDivisionError) + + +def test_parse_service_ports_ignore_pty(): + result = parse_service_ports('x:pty:1234,y:tcp:1235') + assert len(result) == 1 + assert result[0]['name'] == 'y' diff --git a/tests/common/test_types.py b/tests/common/test_types.py new file mode 100644 index 0000000000..528a187a61 --- /dev/null +++ b/tests/common/test_types.py @@ -0,0 +1,288 @@ +import asyncio +from decimal import Decimal +from ai.backend.common.types import ( + BinarySize, ResourceSlot, + DefaultForUnspecified, + aobject, + HardwareMetadata, + check_typed_dict, +) + +import pytest + + +@pytest.mark.asyncio +async def test_aobject(): + + init_count = 0 + ainit_count = 0 + + class MyBase(aobject): + def __init__(self, x: int) -> None: + nonlocal init_count + init_count += 1 + self.x = x + + async def __ainit__(self) -> None: + await asyncio.sleep(0.01) + nonlocal ainit_count + ainit_count += 1 + + class MyDerived(MyBase): + def __init__(self, x: int, y: int) -> None: + super().__init__(x) + nonlocal init_count + init_count += 1 + self.y = y + + async def __ainit__(self) -> None: + await super().__ainit__() + await asyncio.sleep(0.01) + nonlocal ainit_count + ainit_count += 1 + + init_count = 0 + ainit_count = 0 + o = await MyBase.new(1) + assert o.x == 1 + assert init_count == 1 + assert ainit_count == 1 + + init_count = 0 + ainit_count = 0 + o = await MyDerived.new(2, 3) + assert o.x == 2 + assert o.y == 3 + assert init_count == 2 + assert ainit_count == 2 + + +def test_check_typed_dict(): + with pytest.raises(TypeError): + check_typed_dict({}, {}) + with pytest.raises(AssertionError): + check_typed_dict({}, dict) + with pytest.raises(AssertionError): + check_typed_dict({}, int) + with pytest.raises(TypeError): + check_typed_dict({}, HardwareMetadata) + with pytest.raises(TypeError): + check_typed_dict({'status': 'oops', 'status_info': None, 'metadata': {}}, HardwareMetadata) + with pytest.raises(TypeError): + check_typed_dict({'status': 'healthy', 'status_info': None, 'metadata': {'a': 1}}, HardwareMetadata) + + a = check_typed_dict({'status': 'healthy', 'status_info': None, 'metadata': {'a': 'b'}}, HardwareMetadata) + assert isinstance(a, dict) + + +def test_binary_size(): + assert 1 == BinarySize.from_str('1 byte') + assert 19291991 == BinarySize.from_str(19291991) + with pytest.raises(ValueError): + BinarySize.from_str('1.1') + assert 1126 == BinarySize.from_str('1.1k') + assert 11021204 == BinarySize.from_str('11_021_204') + assert 12345 == BinarySize.from_str('12345 bytes') + assert 12345 == BinarySize.from_str('12345 B') + assert 12345 == BinarySize.from_str('12_345 bytes') + assert 99 == BinarySize.from_str('99 bytes') + assert 1024 == BinarySize.from_str('1 KiB') + assert 2048 == BinarySize.from_str('2 KiBytes') + assert 127303 == BinarySize.from_str('124.32 KiB') + assert str(BinarySize(1)) == '1 byte' + assert str(BinarySize(2)) == '2 bytes' + assert str(BinarySize(1024)) == '1 KiB' + assert str(BinarySize(2048)) == '2 KiB' + assert str(BinarySize(105935)) == '103.45 KiB' + assert str(BinarySize(127303)) == '124.32 KiB' + assert str(BinarySize(1048576)) == '1 MiB' + + x = BinarySize.from_str('inf') + assert isinstance(x, Decimal) + assert x.is_infinite() + with pytest.raises(ValueError): + BinarySize.finite_from_str('inf') + + # short-hand formats + assert 2 ** 30 == BinarySize.from_str('1g') + assert 1048576 == 
BinarySize.from_str('1m') + assert 524288 == BinarySize.from_str('0.5m') + assert 524288 == BinarySize.from_str('512k') + assert '{: }'.format(BinarySize(930)) == '930' + assert '{:k}'.format(BinarySize(1024)) == '1k' # type: ignore + assert '{:k}'.format(BinarySize(524288)) == '512k' # type: ignore + assert '{:k}'.format(BinarySize(1048576)) == '1024k' # type: ignore + assert '{:m}'.format(BinarySize(524288)) == '0.5m' # type: ignore + assert '{:m}'.format(BinarySize(1048576)) == '1m' # type: ignore + assert '{:g}'.format(BinarySize(2 ** 30)) == '1g' + with pytest.raises(ValueError): + '{:x}'.format(BinarySize(1)) + with pytest.raises(ValueError): + '{:qqqq}'.format(BinarySize(1)) + with pytest.raises(ValueError): + '{:}'.format(BinarySize(1)) + assert '{:s}'.format(BinarySize(930)) == '930' + assert '{:s}'.format(BinarySize(1024)) == '1k' + assert '{:s}'.format(BinarySize(524288)) == '512k' + assert '{:s}'.format(BinarySize(1048576)) == '1m' + assert '{:s}'.format(BinarySize(2 ** 30)) == '1g' + + +def test_resource_slot_serialization(): + + # from_user_input() and from_policy() takes the explicit slot type information to + # convert human-readable values to raw decimal values, + # while from_json() treats those values as stringified decimal expressions "as-is". + + st = {'a': 'count', 'b': 'bytes'} + r1 = ResourceSlot.from_user_input({'a': '1', 'b': '2g'}, st) + r2 = ResourceSlot.from_user_input({'a': '2', 'b': '1g'}, st) + r3 = ResourceSlot.from_user_input({'a': '1'}, st) + with pytest.raises(ValueError): + ResourceSlot.from_user_input({'x': '1'}, st) + + assert r1['a'] == Decimal(1) + assert r2['a'] == Decimal(2) + assert r3['a'] == Decimal(1) + assert r1['b'] == Decimal(2 * (2**30)) + assert r2['b'] == Decimal(1 * (2**30)) + assert r3['b'] == Decimal(0) + + x = r2 - r3 + assert x['a'] == Decimal(1) + assert x['b'] == Decimal(1 * (2**30)) + + # Conversely, to_json() stringifies the decimal values as-is, + # while to_humanized() takes the explicit slot type information + # to generate human-readable strings. + + assert r1.to_json() == {'a': '1', 'b': '2147483648'} + assert r2.to_json() == {'a': '2', 'b': '1073741824'} + assert r3.to_json() == {'a': '1', 'b': '0'} + assert r1.to_humanized(st) == {'a': '1', 'b': '2g'} + assert r2.to_humanized(st) == {'a': '2', 'b': '1g'} + assert r3.to_humanized(st) == {'a': '1', 'b': '0'} + assert r1 == ResourceSlot.from_json({'a': '1', 'b': '2147483648'}) + assert r2 == ResourceSlot.from_json({'a': '2', 'b': '1073741824'}) + assert r3 == ResourceSlot.from_json({'a': '1', 'b': '0'}) + + r4 = ResourceSlot.from_user_input({'a': Decimal('Infinity'), 'b': Decimal('-Infinity')}, st) + assert not r4['a'].is_finite() + assert not r4['b'].is_finite() + assert r4['a'] > 0 + assert r4['b'] < 0 + assert r4.to_humanized(st) == {'a': 'Infinity', 'b': '-Infinity'} + + # The result for "unspecified" fields may be different + # depending on the policy options. 
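+ # DefaultForUnspecified.UNLIMITED presumably fills missing slot types with Infinity, + # while LIMITED fills them with zero, as the two cases below check.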
+ + r1 = ResourceSlot.from_policy({ + 'total_resource_slots': {'a': '10'}, + 'default_for_unspecified': DefaultForUnspecified.UNLIMITED, + }, st) + assert r1['a'] == Decimal(10) + assert r1['b'] == Decimal('Infinity') + r2 = ResourceSlot.from_policy({ + 'total_resource_slots': {'a': '10'}, + 'default_for_unspecified': DefaultForUnspecified.LIMITED, + }, st) + assert r2['a'] == Decimal(10) + assert r2['b'] == Decimal(0) + + +def test_resource_slot_serialization_prevent_scientific_notation(): + r1 = ResourceSlot({'a': '2E+1', 'b': '200'}) + assert r1.to_json()['a'] == '20' + assert r1.to_json()['b'] == '200' + + +def test_resource_slot_serialization_filter_null(): + r1 = ResourceSlot({'a': '1', 'x': None}) + assert r1.to_json()['a'] == '1' + assert 'x' not in r1.to_json() + + +def test_resource_slot_serialization_typeless(): + r1 = ResourceSlot.from_user_input({'a': '1', 'cuda.mem': '2g'}, None) + assert r1['a'] == Decimal(1) + assert r1['cuda.mem'] == Decimal(2 * (2**30)) + + r1 = ResourceSlot.from_user_input({'a': 'inf', 'cuda.mem': 'inf'}, None) + assert r1['a'].is_infinite() + assert r1['cuda.mem'].is_infinite() + + with pytest.raises(ValueError): + r1 = ResourceSlot.from_user_input({'a': '1', 'cuda.smp': '2g'}, None) + + r1 = ResourceSlot.from_user_input({'a': 'inf', 'cuda.smp': 'inf'}, None) + assert r1['a'].is_infinite() + assert r1['cuda.smp'].is_infinite() + + +def test_resource_slot_comparison_simple_equality(): + r1 = ResourceSlot.from_json({'a': '3', 'b': '200'}) + r2 = ResourceSlot.from_json({'a': '4', 'b': '100'}) + r3 = ResourceSlot.from_json({'a': '2'}) + r4 = ResourceSlot.from_json({'a': '1'}) + r5 = ResourceSlot.from_json({'b': '100', 'a': '4'}) + assert r1 != r2 + assert r1 != r3 + assert r2 != r3 + assert r3 != r4 + assert r2 == r5 + + +def test_resource_slot_comparison_ordering(): + r1 = ResourceSlot.from_json({'a': '3', 'b': '200'}) + r2 = ResourceSlot.from_json({'a': '4', 'b': '100'}) + r3 = ResourceSlot.from_json({'a': '2'}) + r4 = ResourceSlot.from_json({'a': '1'}) + assert not r2 < r1 + assert not r2 <= r1 + assert r4 < r1 + assert r4 <= r1 + assert r4['b'] == 0 # auto-sync of slots + assert r3 < r1 + assert r3 <= r1 + assert r3['b'] == 0 # auto-sync of slots + + +def test_resource_slot_comparison_ordering_reverse(): + r1 = ResourceSlot.from_json({'a': '3', 'b': '200'}) + r2 = ResourceSlot.from_json({'a': '4', 'b': '100'}) + r3 = ResourceSlot.from_json({'a': '2'}) + r4 = ResourceSlot.from_json({'a': '1'}) + assert not r2 > r1 + assert not r2 >= r1 + assert r1 > r3 + assert r1 >= r3 + assert r3['b'] == 0 # auto-sync of slots + assert r1 > r4 + assert r1 >= r4 + assert r4['b'] == 0 # auto-sync of slots + + +def test_resource_slot_comparison_subset(): + r1 = ResourceSlot.from_json({'a': '3', 'b': '200'}) + r3 = ResourceSlot.from_json({'a': '3'}) + assert r3.eq_contained(r1) + assert not r3.eq_contains(r1) + assert not r1.eq_contained(r3) + assert r1.eq_contains(r3) + + +def test_resource_slot_calc_with_infinity(): + r1 = ResourceSlot.from_json({'a': 'Infinity'}) + r2 = ResourceSlot.from_json({'a': '3'}) + r3 = r1 - r2 + assert r3['a'] == Decimal('Infinity') + r3 = r1 + r2 + assert r3['a'] == Decimal('Infinity') + + r4 = ResourceSlot.from_json({'b': '5'}) + r5 = r1 - r4 + assert r5['a'] == Decimal('Infinity') + assert r5['b'] == -5 + r5 = r1 + r4 + assert r5['a'] == Decimal('Infinity') + assert r5['b'] == 5 diff --git a/tests/common/test_utils.py b/tests/common/test_utils.py new file mode 100644 index 0000000000..b94cad7d0e --- /dev/null +++ b/tests/common/test_utils.py 
@@ -0,0 +1,329 @@ +import asyncio +from collections import OrderedDict +from datetime import timedelta +from pathlib import Path +from random import choice, randint +from string import ascii_uppercase +import sys +from tempfile import NamedTemporaryFile +from unittest import mock + +import aiohttp +import pytest + +from ai.backend.common.asyncio import AsyncBarrier, run_through +from ai.backend.common.enum_extension import StringSetFlag +from ai.backend.common.files import AsyncFileWriter +from ai.backend.common.networking import curl +from ai.backend.common.utils import ( + odict, dict2kvlist, nmget, + generate_uuid, get_random_seq, + readable_size_to_bytes, + str_to_timedelta, +) +from ai.backend.common.testutils import ( + mock_corofunc, mock_awaitable, AsyncContextManagerMock, +) + + +def test_odict() -> None: + assert odict(('a', 1), ('b', 2)) == OrderedDict([('a', 1), ('b', 2)]) + + +def test_dict2kvlist() -> None: + ret = list(dict2kvlist({'a': 1, 'b': 2})) + assert set(ret) == {'a', 1, 'b', 2} + + +def test_generate_uuid() -> None: + u = generate_uuid() + assert len(u) == 22 + assert isinstance(u, str) + + +def test_random_seq() -> None: + assert [*get_random_seq(10, 11, 1)] == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + assert [*get_random_seq(10, 6, 2)] == [0, 2, 4, 6, 8, 10] + with pytest.raises(AssertionError): + [*get_random_seq(10, 12, 1)] + with pytest.raises(AssertionError): + [*get_random_seq(10, 7, 2)] + for _ in range(30): + result = [*get_random_seq(10, 9, 1)] + assert result[0] >= 0 + assert result[-1] <= 10 + last_x = result[0] + for x in result[1:]: + assert x > last_x + 1 + + +def test_nmget() -> None: + o = {'a': {'b': 1}, 'x': None} + assert nmget(o, 'a', 0) == {'b': 1} + assert nmget(o, 'a.b', 0) == 1 + assert nmget(o, 'a/b', 0, '/') == 1 + assert nmget(o, 'a.c', 0) == 0 + assert nmget(o, 'a.c', 100) == 100 + assert nmget(o, 'x', 0) == 0 + assert nmget(o, 'x', 0, null_as_default=False) is None + + +def test_readable_size_to_bytes() -> None: + assert readable_size_to_bytes(2) == 2 + assert readable_size_to_bytes('2') == 2 + assert readable_size_to_bytes('2K') == 2 * (2 ** 10) + assert readable_size_to_bytes('2k') == 2 * (2 ** 10) + assert readable_size_to_bytes('2M') == 2 * (2 ** 20) + assert readable_size_to_bytes('2m') == 2 * (2 ** 20) + assert readable_size_to_bytes('2G') == 2 * (2 ** 30) + assert readable_size_to_bytes('2g') == 2 * (2 ** 30) + assert readable_size_to_bytes('2T') == 2 * (2 ** 40) + assert readable_size_to_bytes('2t') == 2 * (2 ** 40) + assert readable_size_to_bytes('2P') == 2 * (2 ** 50) + assert readable_size_to_bytes('2p') == 2 * (2 ** 50) + assert readable_size_to_bytes('2E') == 2 * (2 ** 60) + assert readable_size_to_bytes('2e') == 2 * (2 ** 60) + assert readable_size_to_bytes('2Z') == 2 * (2 ** 70) + assert readable_size_to_bytes('2z') == 2 * (2 ** 70) + assert readable_size_to_bytes('2Y') == 2 * (2 ** 80) + assert readable_size_to_bytes('2y') == 2 * (2 ** 80) + with pytest.raises(ValueError): + readable_size_to_bytes('3A') + with pytest.raises(ValueError): + readable_size_to_bytes('TT') + + +def test_str_to_timedelta() -> None: + assert str_to_timedelta('1d2h3m4s') == timedelta(days=1, hours=2, minutes=3, seconds=4) + assert str_to_timedelta('1d2h3m') == timedelta(days=1, hours=2, minutes=3) + assert str_to_timedelta('1d2h') == timedelta(days=1, hours=2) + assert str_to_timedelta('1d') == timedelta(days=1) + assert str_to_timedelta('2h3m4s') == timedelta(hours=2, minutes=3, seconds=4) + assert str_to_timedelta('2h3m') == timedelta(hours=2, 
minutes=3) + assert str_to_timedelta('2h') == timedelta(hours=2) + assert str_to_timedelta('3m4s') == timedelta(minutes=3, seconds=4) + assert str_to_timedelta('3m') == timedelta(minutes=3) + assert str_to_timedelta('4s') == timedelta(seconds=4) + assert str_to_timedelta('4') == timedelta(seconds=4) + + assert str_to_timedelta('+1d2h3m4s') == timedelta(days=1, hours=2, minutes=3, seconds=4) + assert str_to_timedelta('-1d2h3m4s') == timedelta(days=-1, hours=-2, minutes=-3, seconds=-4) + assert str_to_timedelta('1day2hr3min4sec') == timedelta(days=1, hours=2, minutes=3, seconds=4) + assert str_to_timedelta('1day2hour3minute4second') == timedelta(days=1, hours=2, minutes=3, seconds=4) + assert str_to_timedelta('1day 2hour 3minute 4second') == timedelta(days=1, hours=2, minutes=3, seconds=4) + assert str_to_timedelta('1days 2hours 3minutes 4seconds') == timedelta(days=1, hours=2, minutes=3, seconds=4) + assert str_to_timedelta('0.1d0.2h0.3m0.4s') == timedelta(days=.1, hours=.2, minutes=.3, seconds=.4) + assert str_to_timedelta('1d 2h 3m 4s') == timedelta(days=1, hours=2, minutes=3, seconds=4) + assert str_to_timedelta('-1d 2h 3m 4s') == timedelta(days=-1, hours=-2, minutes=-3, seconds=-4) + assert str_to_timedelta('- 1d 2h 3m 4s') == timedelta(days=-1, hours=-2, minutes=-3, seconds=-4) + + with pytest.raises(ValueError): + assert str_to_timedelta('1da1hr') + with pytest.raises(ValueError): + assert str_to_timedelta('--1d2h3m4s') + with pytest.raises(ValueError): + assert str_to_timedelta('+') + with pytest.raises(ValueError): + assert str_to_timedelta('') + + +@pytest.mark.asyncio +async def test_curl_returns_stripped_body(mocker) -> None: + mock_get = mocker.patch.object(aiohttp.ClientSession, 'get') + mock_resp = {'status': 200, 'text': mock_corofunc(b'success ')} + mock_get.return_value = AsyncContextManagerMock(**mock_resp) + + resp = await curl('/test/url', '') + + body = await mock_resp['text']() + assert resp == body.strip() + + +@pytest.mark.asyncio +async def test_curl_returns_default_value_if_not_success(mocker) -> None: + mock_get = mocker.patch.object(aiohttp.ClientSession, 'get') + mock_resp = {'status': 400, 'text': mock_corofunc(b'bad request')} + mock_get.return_value = AsyncContextManagerMock(**mock_resp) + + # Value. + resp = await curl('/test/url', default_value='default') + assert resp == 'default' + + # Callable. 
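+ # A callable default is presumably invoked to produce the fallback value + # when the response status is not successful.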
+ resp = await curl('/test/url', default_value=lambda: 'default') + assert resp == 'default' + + +def test_string_set_flag() -> None: + + class MyFlags(StringSetFlag): + A = 'a' + B = 'b' + + assert MyFlags.A in {'a', 'c'} + assert MyFlags.B not in {'a', 'c'} + + assert MyFlags.A == 'a' + assert MyFlags.A != 'b' + assert 'a' == MyFlags.A + assert 'b' != MyFlags.A + + assert {'a', 'b'} == MyFlags.A | MyFlags.B + assert {'a', 'b'} == MyFlags.A | 'b' + assert {'a', 'b'} == 'a' | MyFlags.B + assert {'a', 'b', 'c'} == {'b', 'c'} | MyFlags.A + assert {'a', 'b', 'c'} == MyFlags.A | {'b', 'c'} + + assert {'b', 'c'} == {'a', 'b', 'c'} ^ MyFlags.A + assert {'a', 'b', 'c'} == {'b', 'c'} ^ MyFlags.A + assert set() == MyFlags.A ^ 'a' + assert {'b'} == MyFlags.A ^ {'a', 'b'} + assert {'a', 'b', 'c'} == MyFlags.A ^ {'b', 'c'} + with pytest.raises(TypeError): + 123 & MyFlags.A # type: ignore[operator] + + assert {'a', 'c'} & MyFlags.A + assert not {'a', 'c'} & MyFlags.B + assert 'a' & MyFlags.A + assert not 'a' & MyFlags.B + assert MyFlags.A & 'a' + assert not MyFlags.A & 'b' + assert MyFlags.A & {'a', 'b'} + assert not MyFlags.A & {'b', 'c'} + + +class TestAsyncBarrier: + def test_async_barrier_initialization(self) -> None: + barrier = AsyncBarrier(num_parties=5) + + assert barrier.num_parties == 5 + assert barrier.cond is not None # default condition + + @pytest.mark.asyncio + async def test_wait_notify_all_if_count_eq_num_parties(self, mocker) -> None: + mock_cond = mocker.patch.object(asyncio, 'Condition') + mock_resp = { + 'notify_all': mock.Mock(), + 'wait': await mock_awaitable(), + } + mock_cond.return_value = AsyncContextManagerMock(**mock_resp) + + barrier = AsyncBarrier(num_parties=1) + assert barrier.count == 0 + + await barrier.wait() + + assert barrier.count == 1 + # The methods are added at runtime. + mock_cond.return_value.notify_all.assert_called_once_with() # type: ignore + mock_cond.return_value.wait.assert_not_called() # type: ignore + + def test_async_barrier_reset(self): + barrier = AsyncBarrier(num_parties=5) + barrier.count = 5 + + assert barrier.count == 5 + barrier.reset() + assert barrier.count == 0 + + +@pytest.mark.asyncio +async def test_run_through() -> None: + + i = 0 + + async def do(): + nonlocal i + i += 1 + raise ZeroDivisionError + + def do_sync(): + nonlocal i + i += 1 + raise ZeroDivisionError + + await run_through( + do(), + do(), + do(), + ignored_exceptions=(ZeroDivisionError,), + ) + assert i == 3 + + with pytest.raises(ZeroDivisionError): + await run_through( + do(), + do(), + do(), + ignored_exceptions=(KeyError,), + ) + # only the first increment is executed before the non-ignored exception propagates. + assert i == 4 + + await run_through( + do, # coroutine-function + do_sync, # function + lambda: do_sync(), # function wrapped with lambda + do(), # coroutine + ignored_exceptions=(ZeroDivisionError,), + ) + assert i == 8 + + +@pytest.mark.asyncio +async def test_async_file_writer_str() -> None: + # 1. Get temporary filename + with NamedTemporaryFile() as temp_file: + file_name = temp_file.name + + # 2. Generate random string + init_str = (''.join(choice(ascii_uppercase) for i in range(100))) + + # 3. Write the chunked string into the file + async with AsyncFileWriter( + target_filename=file_name, + access_mode='w', + encode=lambda v: v.upper().encode(), + max_chunks=1, + ) as file_writer: + for i in range(0, 100, 20): + await file_writer.write(init_str[i:i + 20]) + + # 4. Read string from the file and close it + with open(file_name, 'r') as f: + final_str = f.read() + Path(file_name).unlink() + + # 5.
Check initial and final strings + assert init_str.upper() == final_str + + +@pytest.mark.asyncio +async def test_async_file_writer_bytes() -> None: + # 1. Get temporary filename + with NamedTemporaryFile() as temp_file: + file_name = temp_file.name + + # 2. Generate random binary data + init_data = (b''.join(randint(0, 255).to_bytes(1, sys.byteorder) for i in range(100))) + + def dummy_encode(v: str) -> bytes: + assert False, "should not be called" + + # 3. Write the chunked binary data into the file + async with AsyncFileWriter( + target_filename=file_name, + access_mode='wb', + encode=dummy_encode, + max_chunks=1, + ) as file_writer: + for i in range(0, 100, 20): + await file_writer.write(init_data[i:i + 20]) + + # 4. Read the data from the file and close it + with open(file_name, 'rb') as f: + final_data = f.read() + Path(file_name).unlink() + + # 5. Check initial and final data + assert init_data == final_data diff --git a/tests/common/test_validators.py b/tests/common/test_validators.py new file mode 100644 index 0000000000..a33c424e35 --- /dev/null +++ b/tests/common/test_validators.py @@ -0,0 +1,469 @@ +from datetime import datetime, timedelta +import enum +from ipaddress import IPv4Address, ip_address +import multidict +import pickle +import os +import pwd + +from dateutil.relativedelta import relativedelta +import pytest +import trafaret as t +import yarl + +from ai.backend.common import validators as tx + + +def test_trafaret_dataerror_pickling(): + + with pytest.raises(t.DataError): + iv = t.Int() + iv.check('x') + + # Remove the already installed monkey-patch. + # (e.g., when running the whole test suite) + try: + if hasattr(t.DataError, '__reduce__'): + delattr(t.DataError, '__reduce__') + except AttributeError: + pass + + with pytest.raises(RuntimeError): + try: + iv = t.Int() + iv.check('x') + except t.DataError as e: + bindata = pickle.dumps(e) + pickle.loads(bindata) + + tx.fix_trafaret_pickle_support() + + try: + iv = t.Int() + iv.check('x') + except t.DataError as e: + bindata = pickle.dumps(e) + unpacked = pickle.loads(bindata) + assert unpacked.error == e.error + assert unpacked.name == e.name + assert unpacked.value == e.value + + +def test_aliased_key(): + iv = t.Dict({ + t.Key('x') >> 'z': t.Int, + tx.AliasedKey(['y', 'Y']): t.Int, + }) + assert iv.check({'x': 1, 'y': 2}) == {'z': 1, 'y': 2} + + with pytest.raises(t.DataError) as e: + iv.check({'x': 1}) + err_data = e.value.as_dict() + assert 'y' in err_data + assert "is required" in err_data['y'] + + with pytest.raises(t.DataError) as e: + iv.check({'y': 2}) + err_data = e.value.as_dict() + assert 'x' in err_data + assert "is required" in err_data['x'] + + with pytest.raises(t.DataError) as e: + iv.check({'x': 1, 'y': 'string'}) + err_data = e.value.as_dict() + assert 'y' in err_data + assert "can't be converted to int" in err_data['y'] + + with pytest.raises(t.DataError) as e: + iv.check({'x': 1, 'Y': 'string'}) + err_data = e.value.as_dict() + assert 'Y' in err_data + assert "can't be converted to int" in err_data['Y'] + + iv = t.Dict({ + t.Key('x', default=0): t.Int, + tx.AliasedKey(['y', 'Y'], default=1): t.Int, + }) + assert iv.check({'x': 5, 'Y': 6}) == {'x': 5, 'y': 6} + assert iv.check({'x': 5, 'y': 6}) == {'x': 5, 'y': 6} + assert iv.check({'y': 3}) == {'x': 0, 'y': 3} + assert iv.check({'Y': 3}) == {'x': 0, 'y': 3} + assert iv.check({'x': 3}) == {'x': 3, 'y': 1} + assert iv.check({}) == {'x': 0, 'y': 1} + + with pytest.raises(t.DataError) as e: + iv.check({'z': 99}) + err_data = e.value.as_dict() + assert 'z'
in err_data + assert "not allowed key" in err_data['z'] + + +def test_multikey(): + iv = t.Dict({ + tx.MultiKey('x'): t.List(t.Int), + t.Key('y'): t.Int, + }) + + data = multidict.MultiDict() + data.add('x', 1) + data.add('x', 2) + data.add('y', 3) + result = iv.check(data) + assert result['x'] == [1, 2] + assert result['y'] == 3 + + data = multidict.MultiDict() + data.add('x', 1) + data.add('y', 3) + result = iv.check(data) + assert result['x'] == [1] + assert result['y'] == 3 + + plain_data = { + 'x': [10, 20], + 'y': 30, + } + result = iv.check(plain_data) + assert result['x'] == [10, 20] + assert result['y'] == 30 + + plain_data_nolist = { + 'x': 10, + 'y': 30, + } + result = iv.check(plain_data_nolist) + assert result['x'] == [10] + assert result['y'] == 30 + + +def test_multikey_string(): + iv = t.Dict({ + tx.MultiKey('x'): t.List(t.String), + t.Key('y'): t.String, + }) + + plain_data = { + 'x': ['abc'], + 'y': 'def', + } + result = iv.check(plain_data) + assert result['x'] == ['abc'] + assert result['y'] == 'def' + + plain_data_nolist = { + 'x': 'abc', + 'y': 'def', + } + result = iv.check(plain_data_nolist) + assert result['x'] == ['abc'] + assert result['y'] == 'def' + + +def test_binary_size(): + iv = tx.BinarySize() + assert iv.check('10M') == 10 * (2 ** 20) + assert iv.check('1K') == 1024 + assert iv.check(1058476) == 1058476 + + with pytest.raises(t.DataError): + iv.check('XX') + + +def test_binary_size_commutative_with_null(): + iv1 = t.Null | tx.BinarySize() + iv2 = tx.BinarySize() | t.Null + + iv1.check(None) + iv2.check(None) + + with pytest.raises(t.DataError): + iv1.check('xxxxx') + with pytest.raises(t.DataError): + iv2.check('xxxxx') + + +def test_delimiter_list(): + iv = tx.DelimiterSeperatedList(t.String, delimiter=':') + assert iv.check('aaa:bbb:ccc') == ['aaa', 'bbb', 'ccc'] + assert iv.check('xxx') == ['xxx'] + iv = tx.DelimiterSeperatedList(tx.HostPortPair, delimiter=',') + assert iv.check('127.0.0.1:6379,127.0.0.1:6380') == \ + [(ip_address('127.0.0.1'), 6379), (ip_address('127.0.0.1'), 6380)] + + +def test_string_list(): + iv = tx.StringList(delimiter=':') + assert iv.check('aaa:bbb:ccc') == ['aaa', 'bbb', 'ccc'] + assert iv.check(':bbb') == ['', 'bbb'] + assert iv.check('aaa:') == ['aaa', ''] + assert iv.check('xxx') == ['xxx'] + assert iv.check('') == [''] + assert iv.check(123) == ['123'] + + +def test_enum(): + + class MyTypes(enum.Enum): + TYPE1 = 1 + TYPE2 = 2 + + iv = tx.Enum(MyTypes) + assert iv.check(1) == MyTypes.TYPE1 + assert iv.check(2) == MyTypes.TYPE2 + with pytest.raises(t.DataError): + iv.check(3) + with pytest.raises(t.DataError): + iv.check('STRING') + + iv = tx.Enum(MyTypes, use_name=True) + assert iv.check('TYPE1') == MyTypes.TYPE1 + assert iv.check('TYPE2') == MyTypes.TYPE2 + with pytest.raises(t.DataError): + iv.check('TYPE3') + with pytest.raises(t.DataError): + iv.check(0) + + +def test_path(): + # TODO: write tests + pass + + +def test_host_port_pair(): + iv = tx.HostPortPair() + + p = iv.check(('127.0.0.1', 80)) + assert isinstance(p, tx._HostPortPair) + assert p.host == IPv4Address('127.0.0.1') + assert p.port == 80 + + p = iv.check('127.0.0.1:80') + assert isinstance(p, tx._HostPortPair) + assert p.host == IPv4Address('127.0.0.1') + assert p.port == 80 + + p = iv.check({'host': '127.0.0.1', 'port': 80}) + assert isinstance(p, tx._HostPortPair) + assert p.host == IPv4Address('127.0.0.1') + assert p.port == 80 + + p = iv.check({'host': '127.0.0.1', 'port': '80'}) + assert isinstance(p, tx._HostPortPair) + assert p.host == 
IPv4Address('127.0.0.1') + assert p.port == 80 + + p = iv.check(('mydomain.com', 443)) + assert isinstance(p, tx._HostPortPair) + assert p.host == 'mydomain.com' + assert p.port == 443 + + p = iv.check('mydomain.com:443') + assert isinstance(p, tx._HostPortPair) + assert p.host == 'mydomain.com' + assert p.port == 443 + + p = iv.check({'host': 'mydomain.com', 'port': 443}) + assert isinstance(p, tx._HostPortPair) + assert p.host == 'mydomain.com' + assert p.port == 443 + + p = iv.check({'host': 'mydomain.com', 'port': '443'}) + assert isinstance(p, tx._HostPortPair) + assert p.host == 'mydomain.com' + assert p.port == 443 + + with pytest.raises(t.DataError): + p = iv.check(('127.0.0.1', -1)) + with pytest.raises(t.DataError): + p = iv.check(('127.0.0.1', 0)) + with pytest.raises(t.DataError): + p = iv.check(('127.0.0.1', 65536)) + with pytest.raises(t.DataError): + p = iv.check('127.0.0.1:65536') + with pytest.raises(t.DataError): + p = iv.check(('', 80)) + with pytest.raises(t.DataError): + p = iv.check(':80') + with pytest.raises(t.DataError): + p = iv.check({}) + with pytest.raises(t.DataError): + p = iv.check({'host': 'x'}) + with pytest.raises(t.DataError): + p = iv.check({'port': 80}) + with pytest.raises(t.DataError): + p = iv.check({'host': '', 'port': 80}) + + +def test_port_range(): + iv = tx.PortRange() + + r = iv.check('1000-2000') + assert isinstance(r, tuple) + assert len(r) == 2 + assert r[0] == 1000 + assert r[1] == 2000 + + r = iv.check([1000, 2000]) + assert isinstance(r, tuple) + assert len(r) == 2 + assert r[0] == 1000 + assert r[1] == 2000 + + r = iv.check((1000, 2000)) + assert isinstance(r, tuple) + assert len(r) == 2 + assert r[0] == 1000 + assert r[1] == 2000 + + with pytest.raises(t.DataError): + r = iv.check([0, 1000]) + with pytest.raises(t.DataError): + r = iv.check([1000, 65536]) + with pytest.raises(t.DataError): + r = iv.check([2000, 1000]) + with pytest.raises(t.DataError): + r = iv.check('x-y') + + +def test_user_id(): + iv = tx.UserID() + assert iv.check(123) == 123 + assert iv.check('123') == 123 + assert iv.check(os.getuid()) == os.getuid() + assert iv.check(None) == os.getuid() + assert iv.check('') == os.getuid() + + iv = tx.UserID(default_uid=1) + assert iv.check(os.getuid()) == os.getuid() + assert iv.check(None) == 1 + assert iv.check(1) == 1 + assert iv.check(123) == 123 + assert iv.check('123') == 123 + assert iv.check('') == 1 + assert iv.check(pwd.getpwuid(os.getuid())[0]) == os.getuid() + assert iv.check(-1) == os.getuid() + assert iv.check('-1') == os.getuid() + + with pytest.raises(t.DataError): + iv.check('nonExistentUserName') + with pytest.raises(t.DataError): + iv.check([1, 2]) + with pytest.raises(t.DataError): + iv.check((1, 2)) + + +def test_slug(): + iv = tx.Slug() + assert iv.check('a') == 'a' + assert iv.check('0Z') == '0Z' + assert iv.check('abc') == 'abc' + assert iv.check('a-b') == 'a-b' + assert iv.check('a_b') == 'a_b' + + with pytest.raises(t.DataError): + iv.check('_') + with pytest.raises(t.DataError): + iv.check('') + + iv = tx.Slug(allow_dot=True) + assert iv.check('.a') == '.a' + assert iv.check('a') == 'a' + with pytest.raises(t.DataError): + iv.check('..a') + + iv = tx.Slug[:4] + assert iv.check('abc') == 'abc' + assert iv.check('abcd') == 'abcd' + with pytest.raises(t.DataError): + iv.check('abcde') + + iv = tx.Slug[4:] + with pytest.raises(t.DataError): + iv.check('abc') + assert iv.check('abcd') == 'abcd' + assert iv.check('abcde') == 'abcde' + + iv = tx.Slug[2:4] + with pytest.raises(t.DataError): + iv.check('a') + 
assert iv.check('ab') == 'ab' + assert iv.check('abcd') == 'abcd' + with pytest.raises(t.DataError): + iv.check('abcde') + + iv = tx.Slug[2:2] + with pytest.raises(t.DataError): + iv.check('a') + assert iv.check('ab') == 'ab' + with pytest.raises(t.DataError): + iv.check('abc') + + with pytest.raises(TypeError): + tx.Slug[2:1] + with pytest.raises(TypeError): + tx.Slug[-1:] + with pytest.raises(TypeError): + tx.Slug[:-1] + + +def test_json_string(): + iv = tx.JSONString() + assert iv.check('{}') == {} + assert iv.check('{"a":123}') == {'a': 123} + assert iv.check('[]') == [] + with pytest.raises(ValueError): + iv.check('x') + + +def test_time_duration(): + iv = tx.TimeDuration() + date = datetime(2020, 2, 29) + with pytest.raises(t.DataError): + iv.check('') + assert iv.check(0) == timedelta(seconds=0) + assert iv.check(10) == timedelta(seconds=10) + assert iv.check(86400.55) == timedelta(days=1, microseconds=550000) + assert iv.check('1w') == timedelta(weeks=1) + assert iv.check('1d') == timedelta(days=1) + assert iv.check('0.5d') == timedelta(hours=12) + assert iv.check('1h') == timedelta(hours=1) + assert iv.check('1m') == timedelta(minutes=1) + assert iv.check('1') == timedelta(seconds=1) + assert iv.check('0.5h') == timedelta(minutes=30) + assert iv.check('0.001') == timedelta(milliseconds=1) + assert iv.check('1yr') == relativedelta(years=1) + assert iv.check('1mo') == relativedelta(months=1) + assert date + iv.check('4yr') == date + relativedelta(years=4) + with pytest.raises(t.DataError): + iv.check('-1') + with pytest.raises(t.DataError): + iv.check('a') + with pytest.raises(t.DataError): + iv.check('xxh') + + +def test_time_duration_negative(): + iv = tx.TimeDuration(allow_negative=True) + with pytest.raises(t.DataError): + iv.check('') + assert iv.check('0.5h') == timedelta(minutes=30) + assert iv.check('0.001') == timedelta(milliseconds=1) + assert iv.check('-1') == timedelta(seconds=-1) + assert iv.check('-3d') == timedelta(days=-3) + assert iv.check('-1yr') == relativedelta(years=-1) + assert iv.check('-1mo') == relativedelta(months=-1) + with pytest.raises(t.DataError): + iv.check('-a') + with pytest.raises(t.DataError): + iv.check('-xxh') + + +def test_url(): + iv = tx.URL() + with pytest.raises(t.DataError): + iv.check('') + with pytest.raises(t.DataError): + iv.check('example.com') + assert iv.check('https://example.com') == yarl.URL('https://example.com') + iv = tx.URL(scheme_required=False) + assert iv.check('example.com') == yarl.URL('example.com') diff --git a/tests/manager/BUILD b/tests/manager/BUILD new file mode 100644 index 0000000000..dacb555d9b --- /dev/null +++ b/tests/manager/BUILD @@ -0,0 +1,24 @@ +python_test_utils( + name="test_utils", + sources=[ + "conftest.py", + "model_factory.py", + ], + dependencies=[ + ":fixtures", + ], +) + +python_tests( + name="tests", + dependencies=[ + "src/ai/backend/manager:service", + "src/ai/backend/testutils:lib", + "//:reqs#aiosqlite", + ], +) + +files( + name="fixtures", + sources=["fixtures/*"], +) diff --git a/tests/manager/__init__.py b/tests/manager/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/manager/api/BUILD b/tests/manager/api/BUILD new file mode 100644 index 0000000000..2c1c32f7e8 --- /dev/null +++ b/tests/manager/api/BUILD @@ -0,0 +1,6 @@ +python_tests( + name="tests", + dependencies=[ + "src/ai/backend/manager:service", + ], +) diff --git a/tests/manager/api/test_auth.py b/tests/manager/api/test_auth.py new file mode 100644 index 0000000000..aa6f4f5189 --- /dev/null +++ 
b/tests/manager/api/test_auth.py
@@ -0,0 +1,126 @@
+from datetime import datetime, timedelta
+import json
+from unittest.mock import MagicMock
+import uuid
+
+from aiohttp import web
+from dateutil.tz import tzutc, gettz
+import pytest
+
+from ai.backend.manager.api.auth import _extract_auth_params, check_date
+from ai.backend.manager.api.exceptions import InvalidAuthParameters
+from ai.backend.manager.server import (
+    database_ctx,
+    event_dispatcher_ctx,
+    hook_plugin_ctx,
+    monitoring_ctx,
+    redis_ctx,
+    shared_config_ctx,
+
+)
+
+
+def test_extract_auth_params():
+    request = MagicMock(spec=web.Request)
+
+    request.headers = {}
+    assert _extract_auth_params(request) is None
+
+    request.headers = {'Authorization': 'no-space'}
+    with pytest.raises(InvalidAuthParameters):
+        _extract_auth_params(request)
+
+    request.headers = {'Authorization': ('BadAuthType signMethod=HMAC-SHA256,'
+                                         'credential=fake-ak:fake-sig')}
+    with pytest.raises(InvalidAuthParameters):
+        _extract_auth_params(request)
+
+    request.headers = {'Authorization': ('BackendAI signMethod=HMAC-SHA256,'
+                                         'credential=fake-ak:fake-sig')}
+    ret = _extract_auth_params(request)
+    assert ret is not None
+    assert ret[0] == 'HMAC-SHA256'
+    assert ret[1] == 'fake-ak'
+    assert ret[2] == 'fake-sig'
+
+
+def test_check_date():
+    # A MagicMock specced to web.Request accepts attribute assignment,
+    # so we can swap in plain dicts as the headers for each case below.
+    request = MagicMock(spec=web.Request)
+
+    request.headers = {'X-Nothing': ''}
+    assert not check_date(request)
+
+    now = datetime.now(tzutc())
+    request.headers = {'Date': now.isoformat()}
+    assert check_date(request)
+
+    # Timestamps without timezone info
+    request.headers = {'Date': f'{now:%Y%m%dT%H:%M:%S}'}
+    assert check_date(request)
+
+    request.headers = {'Date': (now - timedelta(minutes=14, seconds=55)).isoformat()}
+    assert check_date(request)
+    request.headers = {'Date': (now + timedelta(minutes=14, seconds=55)).isoformat()}
+    assert check_date(request)
+
+    request.headers = {'Date': (now - timedelta(minutes=15, seconds=5)).isoformat()}
+    assert not check_date(request)
+    request.headers = {'Date': (now + timedelta(minutes=15, seconds=5)).isoformat()}
+    assert not check_date(request)
+
+    # RFC822-style date formatting used in plain HTTP
+    request.headers = {'Date': '{:%a, %d %b %Y %H:%M:%S GMT}'.format(now)}
+    assert check_date(request)
+
+    # RFC822-style date formatting used in plain HTTP with a non-UTC timezone
+    now_kst = now.astimezone(gettz('Asia/Seoul'))
+    request.headers = {'Date': '{:%a, %d %b %Y %H:%M:%S %Z}'.format(now_kst)}
+    assert check_date(request)
+    now_est = now.astimezone(gettz('America/Panama'))
+    request.headers = {'Date': '{:%a, %d %b %Y %H:%M:%S %Z}'.format(now_est)}
+    assert check_date(request)
+
+    request.headers = {'Date': 'some-unrecognizable-malformed-date-time'}
+    assert not check_date(request)
+
+    request.headers = {'X-BackendAI-Date': now.isoformat()}
+    assert check_date(request)
+
+
+@pytest.mark.asyncio
+async def test_authorize(etcd_fixture, database_fixture, create_app_and_client, get_headers):
+    # The auth module requires the shared config and database to be set up.
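+    # These cleanup contexts provide the shared config, Redis, the event
+    # dispatcher, the database, monitoring, and the hook plugins that the
+    # '.auth' handlers rely on.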
+    app, client = await create_app_and_client(
+        [
+            shared_config_ctx,
+            redis_ctx,
+            event_dispatcher_ctx,
+            database_ctx,
+            monitoring_ctx,
+            hook_plugin_ctx,
+        ],
+        ['.auth'])
+
+    async def do_authorize(hash_type, api_version):
+        url = '/auth/test'
+        req_data = {'echo': str(uuid.uuid4())}
+        req_bytes = json.dumps(req_data).encode()
+        headers = get_headers(
+            'POST',
+            url,
+            req_bytes,
+            hash_type=hash_type,
+            api_version=api_version,
+        )
+        resp = await client.post(url, data=req_bytes, headers=headers)
+        assert resp.status == 200
+        data = json.loads(await resp.text())
+        assert data['authorized'] == 'yes'
+        assert data['echo'] == req_data['echo']
+
+    # Try multiple different hashing schemes
+    await do_authorize('sha256', 'v5.20191215')
+    await do_authorize('sha256', 'v4.20190615')
+    await do_authorize('sha1', 'v4.20190615')
diff --git a/tests/manager/api/test_bgtask.py b/tests/manager/api/test_bgtask.py
new file mode 100644
index 0000000000..afcf360219
--- /dev/null
+++ b/tests/manager/api/test_bgtask.py
@@ -0,0 +1,124 @@
+from __future__ import annotations
+
+import asyncio
+
+from aiohttp import web
+import attr
+import pytest
+
+from ai.backend.common import redis
+from ai.backend.common.events import (
+    BgtaskDoneEvent,
+    BgtaskFailedEvent,
+    BgtaskUpdatedEvent,
+    EventDispatcher,
+    EventProducer,
+)
+from ai.backend.common.types import (
+    AgentId,
+)
+from ai.backend.manager.api.context import RootContext
+from ai.backend.manager.server import (
+    shared_config_ctx, event_dispatcher_ctx, background_task_ctx,
+)
+
+
+@pytest.mark.asyncio
+async def test_background_task(etcd_fixture, create_app_and_client) -> None:
+    app, client = await create_app_and_client(
+        [shared_config_ctx, event_dispatcher_ctx, background_task_ctx],
+        ['.events'],
+    )
+    root_ctx: RootContext = app['_root.context']
+    producer: EventProducer = root_ctx.event_producer
+    dispatcher: EventDispatcher = root_ctx.event_dispatcher
+    update_handler_ctx = {}
+    done_handler_ctx = {}
+
+    async def update_sub(
+        context: web.Application,
+        source: AgentId,
+        event: BgtaskUpdatedEvent,
+    ) -> None:
+        # Copy the arguments to the outer scope, because assertions inside
+        # the handler do not affect the test result: the handlers run in a
+        # separate asyncio task.
+        update_handler_ctx['event_name'] = event.name
+        update_handler_ctx.update(**attr.asdict(event))
+
+    async def done_sub(
+        context: web.Application,
+        source: AgentId,
+        event: BgtaskDoneEvent,
+    ) -> None:
+        done_handler_ctx['event_name'] = event.name
+        done_handler_ctx.update(**attr.asdict(event))
+
+    async def _mock_task(reporter):
+        reporter.total_progress = 2
+        await asyncio.sleep(1)
+        await reporter.update(1, message='BGTask ex1')
+        await asyncio.sleep(0.5)
+        await reporter.update(1, message='BGTask ex2')
+        return 'hooray'
+
+    dispatcher.subscribe(BgtaskUpdatedEvent, app, update_sub)
+    dispatcher.subscribe(BgtaskDoneEvent, app, done_sub)
+    task_id = await root_ctx.background_task_manager.start(_mock_task, name='MockTask1234')
+    await asyncio.sleep(2)
+
+    try:
+        assert update_handler_ctx['task_id'] == task_id
+        assert update_handler_ctx['event_name'] == 'bgtask_updated'
+        assert update_handler_ctx['total_progress'] == 2
+        assert update_handler_ctx['message'] in ['BGTask ex1', 'BGTask ex2']
+        if update_handler_ctx['message'] == 'BGTask ex1':
+            assert update_handler_ctx['current_progress'] == 1
+        else:
+            assert update_handler_ctx['current_progress'] == 2
+        assert done_handler_ctx['task_id'] == task_id
+        assert done_handler_ctx['event_name'] == 'bgtask_done'
+        assert done_handler_ctx['message'] == 'hooray'
+    finally:
+        await redis.execute(producer.redis_client, lambda r: r.flushdb())
+        await producer.close()
+        await dispatcher.close()
+
+
+@pytest.mark.asyncio
+async def test_background_task_fail(etcd_fixture, create_app_and_client) -> None:
+    app, client = await create_app_and_client(
+        [shared_config_ctx, event_dispatcher_ctx, background_task_ctx],
+        ['.events'],
+    )
+    root_ctx: RootContext = app['_root.context']
+    producer: EventProducer = root_ctx.event_producer
+    dispatcher: EventDispatcher = root_ctx.event_dispatcher
+    fail_handler_ctx = {}
+
+    async def fail_sub(
+        context: web.Application,
+        source: AgentId,
+        event: BgtaskFailedEvent,
+    ) -> None:
+        fail_handler_ctx['event_name'] = event.name
+        fail_handler_ctx.update(**attr.asdict(event))
+
+    async def _mock_task(reporter):
+        reporter.total_progress = 2
+        await asyncio.sleep(1)
+        await reporter.update(1, message='BGTask ex1')
+        raise ZeroDivisionError('oops')
+
+    dispatcher.subscribe(BgtaskFailedEvent, app, fail_sub)
+    task_id = await root_ctx.background_task_manager.start(_mock_task, name='MockTask1234')
+    await asyncio.sleep(2)
+    try:
+        assert fail_handler_ctx['task_id'] == task_id
+        assert fail_handler_ctx['event_name'] == 'bgtask_failed'
+        assert fail_handler_ctx['message'] is not None
+        assert 'ZeroDivisionError' in fail_handler_ctx['message']
+    finally:
+        await redis.execute(producer.redis_client, lambda r: r.flushdb())
+        await producer.close()
+        await dispatcher.close()
diff --git a/tests/manager/api/test_config.py b/tests/manager/api/test_config.py
new file mode 100644
index 0000000000..2902afbab5
--- /dev/null
+++ b/tests/manager/api/test_config.py
@@ -0,0 +1,21 @@
+from unittest.mock import AsyncMock
+
+import pytest
+
+
+@pytest.mark.asyncio
+async def test_register_myself(shared_config, mocker):
+    instance_id = 'i-test-manager'
+    from ai.backend.manager import config as config_mod
+    mocked_get_instance_id = AsyncMock(return_value=instance_id)
+    mocker.patch.object(config_mod, 'get_instance_id', mocked_get_instance_id)
+
+    await shared_config.register_myself()
+    assert mocked_get_instance_id.await_count == 1
+    data = await shared_config.etcd.get_prefix(f'nodes/manager/{instance_id}')
+    assert data[''] == 'up'
+
+    await shared_config.deregister_myself()
+    assert mocked_get_instance_id.await_count == 2
+    data = await shared_config.etcd.get_prefix(f'nodes/manager/{instance_id}')
+    assert len(data) == 0
diff --git a/tests/manager/api/test_exceptions.py b/tests/manager/api/test_exceptions.py
new file mode 100644
index 0000000000..51b18258af
--- /dev/null
+++ b/tests/manager/api/test_exceptions.py
@@ -0,0 +1,52 @@
+import json
+import pickle
+
+from ai.backend.manager.api.exceptions import BackendError, BackendAgentError
+from ai.backend.common.utils import odict
+
+
+def test_backend_error_obj():
+    eobj = BackendError()
+    assert eobj.args == (eobj.status_code, eobj.reason, eobj.error_type)
+    assert eobj.body == json.dumps(odict(
+        ('type', eobj.error_type), ('title', eobj.error_title),
+    )).encode()
+
+    extra_msg = '!@#$'
+    eobj = BackendError(extra_msg)
+    assert extra_msg in str(eobj)
+    assert extra_msg in repr(eobj)
+
+
+def test_backend_error_obj_pickle():
+    eobj = BackendError()
+    encoded = pickle.dumps(eobj)
+    decoded = pickle.loads(encoded)
+    assert eobj.status_code == decoded.status_code
+    assert eobj.error_type == decoded.error_type
+    assert eobj.error_title == decoded.error_title
+    assert eobj.content_type == decoded.content_type
+    assert eobj.extra_msg == decoded.extra_msg
+
+
+def test_backend_agent_error_obj():
+    eobj = BackendAgentError('timeout')
+
+    assert eobj.args == (eobj.status_code, eobj.reason,
+                         eobj.error_type, eobj.agent_error_type)
+    assert eobj.body == json.dumps(odict(
+        ('type', eobj.error_type),
+        ('title', eobj.error_title),
+        ('agent-details', odict(
+            ('type', eobj.agent_error_type),
+            ('title', eobj.agent_error_title),
+        )),
+    )).encode()
+
+
+def test_backend_agent_error_obj_pickle():
+    eobj = BackendAgentError('timeout')
+    encoded = pickle.dumps(eobj)
+    decoded = pickle.loads(encoded)
+    assert eobj.body == decoded.body
+    assert eobj.agent_details == decoded.agent_details
diff --git a/tests/manager/api/test_middlewares.py b/tests/manager/api/test_middlewares.py
new file mode 100644
index 0000000000..4943061efc
--- /dev/null
+++ b/tests/manager/api/test_middlewares.py
@@ -0,0 +1,134 @@
+from aiohttp import web
+
+from ai.backend.manager.api.utils import method_placeholder
+from ai.backend.manager.server import api_middleware
+
+
+async def test_api_method_override(aiohttp_client):
+    observed_method = None
+    app = web.Application()
+
+    async def service_handler(request):
+        nonlocal observed_method
+        observed_method = request.method
+        return web.Response(body=b'test')
+
+    app.router.add_route('POST', r'/test',
+                         method_placeholder('REPORT'))
+    app.router.add_route('REPORT', r'/test',
+                         service_handler)
+    app.middlewares.append(api_middleware)
+    client = await aiohttp_client(app)
+
+    # native method
+    resp = await client.request('REPORT', '/test')
+    assert resp.status == 200
+    assert (await resp.read()) == b'test'
+    assert observed_method == 'REPORT'
+
+    # overridden method
+    observed_method = None
+    resp = await client.post('/test', headers={
+        'X-Method-Override': 'REPORT',
+    })
+    assert resp.status == 200
+    assert (await resp.read()) == b'test'
+    assert observed_method == 'REPORT'
+
+    # calling placeholder
+    observed_method = None
+    resp = await client.post('/test')
+    assert resp.status == 405
+    assert observed_method is None
+
+    # calling with non-relevant method
+    observed_method = None
+    resp = await client.delete('/test')
+    assert resp.status == 405
+    assert observed_method is None
+
+
+async def test_api_method_override_with_different_ops(aiohttp_client):
+    observed_method = None
+ app = web.Application() + + async def op1_handler(request): + nonlocal observed_method + observed_method = request.method + return web.Response(body=b'op1') + + async def op2_handler(request): + nonlocal observed_method + observed_method = request.method + return web.Response(body=b'op2') + + app.router.add_route('POST', r'/test', op1_handler) + app.router.add_route('REPORT', r'/test', op2_handler) + app.middlewares.append(api_middleware) + client = await aiohttp_client(app) + + # native method + resp = await client.request('POST', '/test') + assert resp.status == 200 + assert (await resp.read()) == b'op1' + assert observed_method == 'POST' + + # native method + observed_method = None + resp = await client.request('REPORT', '/test') + assert resp.status == 200 + assert (await resp.read()) == b'op2' + assert observed_method == 'REPORT' + + # overriden method + observed_method = None + resp = await client.request('REPORT', '/test', headers={ + 'X-Method-Override': 'POST', + }) + assert resp.status == 200 + assert (await resp.read()) == b'op1' + assert observed_method == 'POST' + + # overriden method + observed_method = None + resp = await client.request('POST', '/test', headers={ + 'X-Method-Override': 'REPORT', + }) + assert resp.status == 200 + assert (await resp.read()) == b'op2' + assert observed_method == 'REPORT' + + +async def test_api_ver(aiohttp_client): + inner_request = None + app = web.Application() + + async def dummy_handler(request): + nonlocal inner_request + inner_request = request + return web.Response(body=b'test') + + app.router.add_post(r'/test', dummy_handler) + app.middlewares.append(api_middleware) + client = await aiohttp_client(app) + + # normal call + resp = await client.post('/test', headers={ + 'X-BackendAI-Version': 'v5.20191215', + }) + assert resp.status == 200 + assert inner_request['api_version'][0] == 5 + + # normal call with different version + resp = await client.post('/test', headers={ + 'X-BackendAI-Version': 'v4.20190615', + }) + assert resp.status == 200 + assert inner_request['api_version'][0] == 4 + + # calling with invalid/deprecated version + resp = await client.post('/test', headers={ + 'X-BackendAI-Version': 'v2.20170315', + }) + assert resp.status == 400 + assert 'Unsupported' in (await resp.json())['msg'] diff --git a/tests/manager/api/test_ratelimit.py b/tests/manager/api/test_ratelimit.py new file mode 100644 index 0000000000..78b5b4b4f3 --- /dev/null +++ b/tests/manager/api/test_ratelimit.py @@ -0,0 +1,67 @@ +import json + +import pytest + +from ai.backend.manager.server import ( + database_ctx, + event_dispatcher_ctx, + hook_plugin_ctx, + monitoring_ctx, + redis_ctx, + shared_config_ctx, +) +import ai.backend.manager.api.ratelimit as rlim + + +@pytest.mark.asyncio +async def test_check_rlim_for_anonymous_query( + etcd_fixture, + database_fixture, + create_app_and_client, +): + app, client = await create_app_and_client( + [ + shared_config_ctx, + redis_ctx, + event_dispatcher_ctx, + database_ctx, + monitoring_ctx, + hook_plugin_ctx, + ], + ['.auth', '.ratelimit'], + ) + ret = await client.get('/') + assert ret.status == 200 + assert '1000' == ret.headers['X-RateLimit-Limit'] + assert '1000' == ret.headers['X-RateLimit-Remaining'] + assert str(rlim._rlim_window) == ret.headers['X-RateLimit-Window'] + + +@pytest.mark.asyncio +async def test_check_rlim_for_authorized_query( + etcd_fixture, + database_fixture, + create_app_and_client, + get_headers, +): + app, client = await create_app_and_client( + [ + shared_config_ctx, + redis_ctx, + 
event_dispatcher_ctx, + database_ctx, + monitoring_ctx, + hook_plugin_ctx, + ], + ['.auth', '.ratelimit'], + ) + url = '/auth/test' + req_bytes = json.dumps({'echo': 'hello!'}).encode() + headers = get_headers('POST', url, req_bytes) + ret = await client.post(url, data=req_bytes, headers=headers) + + assert ret.status == 200 + # The default example keypair's ratelimit is 30000. + assert '30000' == ret.headers['X-RateLimit-Limit'] + assert '29999' == ret.headers['X-RateLimit-Remaining'] + assert str(rlim._rlim_window) == ret.headers['X-RateLimit-Window'] diff --git a/tests/manager/api/test_utils.py b/tests/manager/api/test_utils.py new file mode 100644 index 0000000000..cfe12b85a5 --- /dev/null +++ b/tests/manager/api/test_utils.py @@ -0,0 +1,85 @@ +import asyncio + +import pytest + +from ai.backend.manager.models import verify_dotfile_name, verify_vfolder_name +from ai.backend.manager.api.utils import ( + call_non_bursty, + mask_sensitive_keys, +) + + +@pytest.mark.asyncio +async def test_call_non_bursty(): + key = 'x' + execution_count = 0 + + async def execute(): + nonlocal execution_count + await asyncio.sleep(0) + execution_count += 1 + + # ensure reset + await asyncio.sleep(0.11) + + # check run as coroutine + execution_count = 0 + with pytest.raises(TypeError): + await call_non_bursty(key, execute()) + + # check run as coroutinefunction + execution_count = 0 + await call_non_bursty(key, execute) + assert execution_count == 1 + await asyncio.sleep(0.11) + + # check burstiness control + execution_count = 0 + for _ in range(129): + await call_non_bursty(key, execute) + assert execution_count == 3 + await asyncio.sleep(0.01) + await call_non_bursty(key, execute) + assert execution_count == 3 + await asyncio.sleep(0.11) + await call_non_bursty(key, execute) + assert execution_count == 4 + for _ in range(64): + await call_non_bursty(key, execute) + assert execution_count == 5 + + +def test_vfolder_name_validator(): + assert not verify_vfolder_name('.bashrc') + assert not verify_vfolder_name('.terminfo') + assert verify_vfolder_name('bashrc') + assert verify_vfolder_name('.config') + assert verify_vfolder_name('bin') + assert verify_vfolder_name('boot') + assert verify_vfolder_name('root') + assert not verify_vfolder_name('/bin') + assert not verify_vfolder_name('/boot') + assert not verify_vfolder_name('/root') + assert verify_vfolder_name('/home/work/bin') + assert verify_vfolder_name('/home/work/boot') + assert verify_vfolder_name('/home/work/root') + assert verify_vfolder_name('home/work') + + +def test_dotfile_name_validator(): + assert not verify_dotfile_name('.terminfo') + assert not verify_dotfile_name('.config') + assert not verify_dotfile_name('.ssh/authorized_keys') + assert verify_dotfile_name('.bashrc') + assert verify_dotfile_name('.ssh/id_rsa') + + +def test_mask_sensitive_keys(): + a = {'a': 123, 'my-Secret': 'hello'} + b = mask_sensitive_keys(a) + # original is untouched + assert a['a'] == 123 + assert a['my-Secret'] == 'hello' + # cloned has masked fields + assert b['a'] == 123 + assert b['my-Secret'] == '***' diff --git a/tests/manager/conftest.py b/tests/manager/conftest.py new file mode 100644 index 0000000000..08448658b1 --- /dev/null +++ b/tests/manager/conftest.py @@ -0,0 +1,714 @@ +import asyncio +import hashlib, hmac +import json +import os +import secrets +import shutil +import tempfile +import textwrap +from datetime import datetime +from functools import partial +from pathlib import Path +from typing import ( + Any, + AsyncContextManager, + AsyncIterator, + 
Iterator, + List, + Mapping, + Sequence, + Tuple, + Type, +) +from unittest.mock import MagicMock, AsyncMock +from urllib.parse import quote_plus as urlquote + +import aiohttp +from aiohttp import web +from dateutil.tz import tzutc +import sqlalchemy as sa +import pytest +from sqlalchemy.ext.asyncio.engine import AsyncEngine as SAEngine + +from ai.backend.common.config import ConfigurationError, etcd_config_iv, redis_config_iv +from ai.backend.common.plugin.hook import HookPluginContext +from ai.backend.common.types import HostPortPair +from ai.backend.manager.api.context import RootContext +from ai.backend.manager.cli.context import CLIContext, init_logger +from ai.backend.manager.cli.dbschema import oneshot as cli_schema_oneshot +from ai.backend.manager.cli.etcd import ( + put_json as cli_etcd_put_json, + delete as cli_etcd_delete, +) +from ai.backend.manager.config import LocalConfig, SharedConfig, load as load_config +from ai.backend.manager.server import ( + build_root_app, +) +from ai.backend.manager.api.types import ( + CleanupContext, +) +from ai.backend.manager.models.base import populate_fixture, pgsql_connect_opts +from ai.backend.manager.models import ( + domains, + scaling_groups, + agents, + kernels, + keypairs, + users, + vfolders, +) +from ai.backend.manager.models.utils import connect_database +from ai.backend.manager.registry import AgentRegistry +from ai.backend.testutils.bootstrap import ( # noqa: F401 + etcd_container, + redis_container, + postgres_container, +) +from ai.backend.testutils.pants import get_parallel_slot + +here = Path(__file__).parent + + +@pytest.fixture(scope='session', autouse=True) +def test_id(): + return secrets.token_hex(12) + + +@pytest.fixture(scope='session', autouse=True) +def test_ns(test_id): + ret = f'testing-ns-{test_id}' + os.environ['BACKEND_NAMESPACE'] = ret + return ret + + +@pytest.fixture(scope='session') +def test_db(test_id): + return f'test_db_{test_id}' + + +@pytest.fixture(scope='session') +def vfolder_mount(test_id): + ret = Path.cwd() / f'tmp/backend.ai/manager-testing/vfolders-{test_id}' + ret.mkdir(parents=True, exist_ok=True) + yield ret + try: + shutil.rmtree(ret.parent) + except IOError: + pass + + +@pytest.fixture(scope='session') +def vfolder_fsprefix(test_id): + # NOTE: the prefix must NOT start with "/" + return Path('fsprefix/inner/') + + +@pytest.fixture(scope='session') +def vfolder_host(): + return 'local' + + +@pytest.fixture(scope='session') +def local_config( + test_id, + etcd_container, # noqa: F811 + redis_container, # noqa: F811 + postgres_container, # noqa: F811 + test_db, +) -> Iterator[LocalConfig]: + ipc_base_path = Path.cwd() / f'tmp/backend.ai/manager-testing/ipc-{test_id}' + ipc_base_path.mkdir(parents=True, exist_ok=True) + etcd_addr = etcd_container[1] + redis_addr = redis_container[1] + postgres_addr = postgres_container[1] + + # Establish a self-contained config. 
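+    # The etcd/Redis/PostgreSQL addresses above come from the containers
+    # spawned by the ai.backend.testutils.bootstrap fixtures for this test session.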
+ cfg = LocalConfig({ + **etcd_config_iv.check({ + 'etcd': { + 'namespace': test_id, + 'addr': {'host': etcd_addr.host, 'port': etcd_addr.port}, + }, + }), + 'redis': redis_config_iv.check({ + 'addr': {'host': redis_addr.host, 'port': redis_addr.port}, + }), + 'db': { + 'addr': postgres_addr, + 'name': test_db, + 'user': 'postgres', + 'password': 'develove', + }, + 'manager': { + 'id': f"i-{test_id}", + 'num-proc': 1, + 'distributed-lock': 'filelock', + 'ipc-base-path': ipc_base_path, + 'service-addr': HostPortPair('127.0.0.1', 29100 + get_parallel_slot() * 10), + }, + 'debug': { + 'enabled': False, + 'log-events': False, + 'log-scheduler-ticks': False, + 'periodic-sync-stats': False, + }, + 'logging': { + 'drivers': ['console'], + 'console': {'colored': False, 'format': 'verbose'}, + }, + }) + + def _override_if_exists(src: dict, dst: dict, key: str) -> None: + sentinel = object() + if (val := src.get(key, sentinel)) is not sentinel: + dst[key] = val + + try: + # Override external database config with the current environment's config. + fs_local_config = load_config() + cfg['etcd']['addr'] = fs_local_config['etcd']['addr'] + _override_if_exists(fs_local_config['etcd'], cfg['etcd'], 'user') + _override_if_exists(fs_local_config['etcd'], cfg['etcd'], 'password') + cfg['redis']['addr'] = fs_local_config['redis']['addr'] + _override_if_exists(fs_local_config['redis'], cfg['redis'], 'password') + cfg['db']['addr'] = fs_local_config['db']['addr'] + _override_if_exists(fs_local_config['db'], cfg['db'], 'user') + _override_if_exists(fs_local_config['db'], cfg['db'], 'password') + except ConfigurationError: + pass + yield cfg + try: + shutil.rmtree(ipc_base_path) + except IOError: + pass + + +@pytest.fixture(scope='session') +def etcd_fixture(test_id, local_config, vfolder_mount, vfolder_fsprefix, vfolder_host) -> Iterator[None]: + # Clear and reset etcd namespace using CLI functions. 
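+    # The fixture dict below is dumped into a temporary JSON file, loaded through
+    # cli_etcd_put_json, and wiped through cli_etcd_delete (--prefix) at teardown.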
+ redis_addr = local_config['redis']['addr'] + cli_ctx = CLIContext( + logger=init_logger(local_config, nested=True), + local_config=local_config, + ) + with tempfile.NamedTemporaryFile(mode='w', suffix='.etcd.json') as f: + etcd_fixture = { + 'volumes': { + '_mount': str(vfolder_mount), + '_fsprefix': str(vfolder_fsprefix), + '_default_host': str(vfolder_host), + }, + 'nodes': { + }, + 'config': { + 'docker': { + 'registry': { + 'cr.backend.ai': { + '': 'https://cr.backend.ai', + 'type': 'harbor2', + 'project': 'stable', + }, + }, + }, + 'redis': { + 'addr': f"{redis_addr.host}:{redis_addr.port}", + }, + 'plugins': { + 'cloudia': { + 'base_url': '127.0.0.1:8090', + 'user': 'fake-cloudia-user@lablup.com', + 'password': 'fake-password', + }, + }, + }, + } + json.dump(etcd_fixture, f) + f.flush() + click_ctx = cli_etcd_put_json.make_context( + 'test', ['', f.name], obj=cli_ctx, + ) + click_ctx.obj = cli_ctx + cli_etcd_put_json.invoke(click_ctx) + yield + click_ctx = cli_etcd_delete.make_context( + 'test', ['--prefix', ''], obj=cli_ctx, + ) + cli_etcd_delete.invoke(click_ctx) + + +@pytest.fixture +async def shared_config(app, etcd_fixture): + root_ctx: RootContext = app['_root.context'] + shared_config = SharedConfig( + root_ctx.local_config['etcd']['addr'], + root_ctx.local_config['etcd']['user'], + root_ctx.local_config['etcd']['password'], + root_ctx.local_config['etcd']['namespace'], + ) + await shared_config.reload() + root_ctx: RootContext = app['_root.context'] + root_ctx.shared_config = shared_config + yield shared_config + + +@pytest.fixture(scope='session') +def database(request, local_config, test_db): + """ + Create a new database for the current test session + and install the table schema using alembic. + """ + db_addr = local_config['db']['addr'] + db_user = local_config['db']['user'] + db_pass = local_config['db']['password'] + + # Create database using low-level psycopg2 API. + # Temporarily use "testing" dbname until we create our own db. 
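+    # init_db() creates the per-session test database, and finalize_db() drops it
+    # after revoking new connections and terminating the remaining backends.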
+ if db_pass: + db_url = f'postgresql+asyncpg://{urlquote(db_user)}:{urlquote(db_pass)}@{db_addr}/testing' + else: + db_url = f'postgresql+asyncpg://{urlquote(db_user)}@{db_addr}/testing' + + async def init_db(): + engine = sa.ext.asyncio.create_async_engine( + db_url, + connect_args=pgsql_connect_opts, + isolation_level="AUTOCOMMIT", + ) + async with engine.connect() as conn: + await conn.execute(sa.text(f'CREATE DATABASE "{test_db}";')) + await engine.dispose() + + asyncio.run(init_db()) + + async def finalize_db(): + engine = sa.ext.asyncio.create_async_engine( + db_url, + connect_args=pgsql_connect_opts, + isolation_level="AUTOCOMMIT", + ) + async with engine.connect() as conn: + await conn.execute(sa.text(f'REVOKE CONNECT ON DATABASE "{test_db}" FROM public;')) + await conn.execute(sa.text('SELECT pg_terminate_backend(pid) FROM pg_stat_activity ' + 'WHERE pid <> pg_backend_pid();')) + await conn.execute(sa.text(f'DROP DATABASE "{test_db}";')) + await engine.dispose() + + request.addfinalizer(lambda: asyncio.run(finalize_db())) + + alembic_config_template = textwrap.dedent(""" + [alembic] + script_location = ai.backend.manager.models:alembic + sqlalchemy.url = {sqlalchemy_url:s} + + [loggers] + keys = root + + [logger_root] + level = WARNING + handlers = console + + [handlers] + keys = console + + [handler_console] + class = StreamHandler + args = (sys.stdout,) + formatter = simple + level = INFO + + [formatters] + keys = simple + + [formatter_simple] + format = [%(name)s] %(message)s + """).strip() + + # Load the database schema using CLI function. + cli_ctx = CLIContext( + logger=init_logger(local_config, nested=True), + local_config=local_config, + ) + sqlalchemy_url = f'postgresql://{db_user}:{db_pass}@{db_addr}/{test_db}' + with tempfile.NamedTemporaryFile(mode='w', encoding='utf8') as alembic_cfg: + alembic_cfg_data = alembic_config_template.format( + sqlalchemy_url=sqlalchemy_url, + ) + alembic_cfg.write(alembic_cfg_data) + alembic_cfg.flush() + click_ctx = cli_schema_oneshot.make_context( + 'test', ['-f', alembic_cfg.name], obj=cli_ctx, + ) + cli_schema_oneshot.invoke(click_ctx) + + +@pytest.fixture() +async def database_engine(local_config, database): + async with connect_database(local_config) as db: + yield db + + +@pytest.fixture() +def database_fixture(local_config, test_db, database): + """ + Populate the example data as fixtures to the database + and delete them after use. + """ + db_addr = local_config['db']['addr'] + db_user = local_config['db']['user'] + db_pass = local_config['db']['password'] + db_url = f'postgresql+asyncpg://{db_user}:{urlquote(db_pass)}@{db_addr}/{test_db}' + + fixtures = {} + # NOTE: The fixtures must be loaded in the order that they are present. + # Normal dicts on Python 3.6 or later guarantees the update ordering. 
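+    # example-keypairs.json provides the domain, group, user, and keypair rows,
+    # while example-resource-presets.json provides the resource preset rows.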
+ fixtures.update(json.loads( + (Path(__file__).parent / + 'fixtures' / 'example-keypairs.json').read_text(), + )) + fixtures.update(json.loads( + (Path(__file__).parent / + 'fixtures' / 'example-resource-presets.json').read_text(), + )) + + async def init_fixture(): + engine: SAEngine = sa.ext.asyncio.create_async_engine( + db_url, + connect_args=pgsql_connect_opts, + ) + try: + await populate_fixture(engine, fixtures) + finally: + await engine.dispose() + + asyncio.run(init_fixture()) + + yield + + async def clean_fixture(): + engine: SAEngine = sa.ext.asyncio.create_async_engine( + db_url, + connect_args=pgsql_connect_opts, + ) + try: + async with engine.begin() as conn: + await conn.execute((vfolders.delete())) + await conn.execute((kernels.delete())) + await conn.execute((agents.delete())) + await conn.execute((keypairs.delete())) + await conn.execute((users.delete())) + await conn.execute((scaling_groups.delete())) + await conn.execute((domains.delete())) + finally: + await engine.dispose() + + asyncio.run(clean_fixture()) + + +@pytest.fixture +def file_lock_factory(local_config, request): + from ai.backend.common.lock import FileLock + + def _make_lock(lock_id): + lock_path = local_config['manager']['ipc-base-path'] / f'testing.{lock_id}.lock' + lock = FileLock(lock_path, timeout=0) + request.addfinalizer(partial(lock_path.unlink, missing_ok=True)) + return lock + + return _make_lock + + +class Client: + def __init__(self, session: aiohttp.ClientSession, url) -> None: + self._session = session + if not url.endswith('/'): + url += '/' + self._url = url + + def request(self, method, path, **kwargs): + while path.startswith('/'): + path = path[1:] + url = self._url + path + return self._session.request(method, url, **kwargs) + + def get(self, path, **kwargs): + while path.startswith('/'): + path = path[1:] + url = self._url + path + return self._session.get(url, **kwargs) + + def post(self, path, **kwargs): + while path.startswith('/'): + path = path[1:] + url = self._url + path + return self._session.post(url, **kwargs) + + def put(self, path, **kwargs): + while path.startswith('/'): + path = path[1:] + url = self._url + path + return self._session.put(url, **kwargs) + + def patch(self, path, **kwargs): + while path.startswith('/'): + path = path[1:] + url = self._url + path + return self._session.patch(url, **kwargs) + + def delete(self, path, **kwargs): + while path.startswith('/'): + path = path[1:] + url = self._url + path + return self._session.delete(url, **kwargs) + + def ws_connect(self, path, **kwargs): + while path.startswith('/'): + path = path[1:] + url = self._url + path + return self._session.ws_connect(url, **kwargs) + + +@pytest.fixture +async def app(local_config, event_loop): + """ + Create an empty application with the test configuration. 
+ """ + return build_root_app( + 0, + local_config, + cleanup_contexts=[], + subapp_pkgs=[], + ) + + +@pytest.fixture +async def create_app_and_client(local_config, event_loop) -> AsyncIterator: + client: Client | None = None + client_session: aiohttp.ClientSession | None = None + runner: web.BaseRunner | None = None + _outer_ctxs: List[AsyncContextManager] = [] + + async def app_builder( + cleanup_contexts: Sequence[CleanupContext] = None, + subapp_pkgs: Sequence[str] = None, + scheduler_opts: Mapping[str, Any] = None, + ) -> Tuple[web.Application, Client]: + nonlocal client, client_session, runner + nonlocal _outer_ctxs + + if scheduler_opts is None: + scheduler_opts = {} + _cleanup_ctxs = [] + _outer_ctx_classes: List[Type[AsyncContextManager]] = [] + if cleanup_contexts is not None: + for ctx in cleanup_contexts: + # if isinstance(ctx, AsyncContextManager): + if ctx.__name__ in ['shared_config_ctx', 'webapp_plugins_ctx']: + _outer_ctx_classes.append(ctx) # type: ignore + else: + _cleanup_ctxs.append(ctx) + app = build_root_app( + 0, + local_config, + cleanup_contexts=_cleanup_ctxs, + subapp_pkgs=subapp_pkgs, + scheduler_opts={ + 'close_timeout': 10, + **scheduler_opts, + }, + ) + root_ctx: RootContext = app['_root.context'] + for octx_cls in _outer_ctx_classes: + octx = octx_cls(root_ctx) # type: ignore + _outer_ctxs.append(octx) + await octx.__aenter__() + runner = web.AppRunner(app, handle_signals=False) + await runner.setup() + site = web.TCPSite( + runner, + str(root_ctx.local_config['manager']['service-addr'].host), + root_ctx.local_config['manager']['service-addr'].port, + reuse_port=True, + ) + await site.start() + port = root_ctx.local_config['manager']['service-addr'].port + client_session = aiohttp.ClientSession() + client = Client(client_session, f'http://127.0.0.1:{port}') + return app, client + + yield app_builder + + if client_session is not None: + await client_session.close() + if runner is not None: + await runner.cleanup() + for octx in reversed(_outer_ctxs): + await octx.__aexit__(None, None, None) + + +@pytest.fixture +def default_keypair(): + return { + 'access_key': 'AKIAIOSFODNN7EXAMPLE', + 'secret_key': 'wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY', + } + + +@pytest.fixture +def default_domain_keypair(): + """Default domain admin keypair""" + return { + 'access_key': 'AKIAHUKCHDEZGEXAMPLE', + 'secret_key': 'cWbsM_vBB4CzTW7JdORRMx8SjGI3-wEXAMPLEKEY', + } + + +@pytest.fixture +def user_keypair(): + return { + 'access_key': 'AKIANABBDUSEREXAMPLE', + 'secret_key': 'C8qnIo29EZvXkPK_MXcuAakYTy4NYrxwmCEyNPlf', + } + + +@pytest.fixture +def monitor_keypair(): + return { + 'access_key': 'AKIANAMONITOREXAMPLE', + 'secret_key': '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', + } + + +@pytest.fixture +def get_headers(app, default_keypair): + def create_header( + method, + url, + req_bytes, + ctype='application/json', + hash_type='sha256', + api_version='v5.20191215', + keypair=default_keypair, + ) -> dict[str, str]: + now = datetime.now(tzutc()) + root_ctx: RootContext = app['_root.context'] + hostname = f"127.0.0.1:{root_ctx.local_config['manager']['service-addr'].port}" + headers = { + 'Date': now.isoformat(), + 'Content-Type': ctype, + 'Content-Length': str(len(req_bytes)), + 'X-BackendAI-Version': api_version, + } + if api_version >= 'v4.20181215': + req_bytes = b'' + else: + if ctype.startswith('multipart'): + req_bytes = b'' + if ctype.startswith('multipart'): + # Let aiohttp to create appropriate header values + # (e.g., multipart content-type header with message 
boundaries) + del headers['Content-Type'] + del headers['Content-Length'] + req_hash = hashlib.new(hash_type, req_bytes).hexdigest() + sign_bytes = method.upper().encode() + b'\n' \ + + url.encode() + b'\n' \ + + now.isoformat().encode() + b'\n' \ + + b'host:' + hostname.encode() + b'\n' \ + + b'content-type:' + ctype.encode() + b'\n' \ + + b'x-backendai-version:' + api_version.encode() + b'\n' \ + + req_hash.encode() + sign_key = hmac.new(keypair['secret_key'].encode(), + now.strftime('%Y%m%d').encode(), hash_type).digest() + sign_key = hmac.new(sign_key, hostname.encode(), hash_type).digest() + signature = hmac.new(sign_key, sign_bytes, hash_type).hexdigest() + headers['Authorization'] = \ + f'BackendAI signMethod=HMAC-{hash_type.upper()}, ' \ + + f'credential={keypair["access_key"]}:{signature}' + return headers + return create_header + + +@pytest.fixture +async def prepare_kernel(request, create_app_and_client, + get_headers, default_keypair): + sess_id = f'test-kernel-session-{secrets.token_hex(8)}' + app, client = await create_app_and_client( + modules=['etcd', 'events', 'auth', 'vfolder', + 'admin', 'ratelimit', 'kernel', 'stream', 'manager'], + spawn_agent=True) + root_ctx: RootContext = app['_root.context'] + + async def create_kernel(image='lua:5.3-alpine', tag=None): + url = '/v3/kernel/' + req_bytes = json.dumps({ + 'image': image, + 'tag': tag, + 'clientSessionToken': sess_id, + }).encode() + headers = get_headers('POST', url, req_bytes) + response = await client.post(url, data=req_bytes, headers=headers) + return await response.json() + + yield app, client, create_kernel + + access_key = default_keypair['access_key'] + try: + await root_ctx.registry.destroy_session(sess_id, access_key) + except Exception: + pass + + +class DummyEtcd: + async def get_prefix(self, key: str) -> Mapping[str, Any]: + return {} + + +@pytest.fixture +async def registry_ctx(mocker): + mock_shared_config = MagicMock() + mock_shared_config.update_resource_slots = AsyncMock() + mock_shared_config.etcd = None + mock_db = MagicMock() + mock_dbconn = MagicMock() + mock_dbconn_ctx = MagicMock() + mock_dbresult = MagicMock() + mock_dbresult.rowcount = 1 + mock_db.connect = MagicMock(return_value=mock_dbconn_ctx) + mock_db.begin = MagicMock(return_value=mock_dbconn_ctx) + mock_dbconn_ctx.__aenter__ = AsyncMock(return_value=mock_dbconn) + mock_dbconn_ctx.__aexit__ = AsyncMock() + mock_dbconn.execute = AsyncMock(return_value=mock_dbresult) + mock_dbconn.begin = MagicMock(return_value=mock_dbconn_ctx) + mock_redis_stat = MagicMock() + mock_redis_live = MagicMock() + mock_redis_live.hset = AsyncMock() + mock_redis_image = MagicMock() + mock_event_dispatcher = MagicMock() + mock_event_producer = MagicMock() + mock_event_producer.produce_event = AsyncMock() + mocked_etcd = DummyEtcd() + # mocker.object.patch(mocked_etcd, 'get_prefix', AsyncMock(return_value={})) + hook_plugin_ctx = HookPluginContext(mocked_etcd, {}) # type: ignore + + registry = AgentRegistry( + shared_config=mock_shared_config, + db=mock_db, + redis_stat=mock_redis_stat, + redis_live=mock_redis_live, + redis_image=mock_redis_image, + event_dispatcher=mock_event_dispatcher, + event_producer=mock_event_producer, + storage_manager=None, # type: ignore + hook_plugin_ctx=hook_plugin_ctx, + ) + await registry.init() + try: + yield ( + registry, + mock_dbconn, + mock_dbresult, + mock_shared_config, + mock_event_dispatcher, + mock_event_producer, + ) + finally: + await registry.shutdown() diff --git a/tests/manager/fixtures/example-keypairs.json 
b/tests/manager/fixtures/example-keypairs.json new file mode 100644 index 0000000000..5d76bc821c --- /dev/null +++ b/tests/manager/fixtures/example-keypairs.json @@ -0,0 +1,172 @@ +{ + "domains": [ + { + "name": "default", + "description": "The default domain", + "is_active": true, + "total_resource_slots": {}, + "allowed_vfolder_hosts": {}, + "allowed_docker_registries": ["cr.backend.ai"] + } + ], + "groups": [ + { + "id": "2de2b969-1d04-48a6-af16-0bc8adb3c831", + "name": "default", + "description": "The default user group", + "is_active": true, + "domain_name": "default", + "total_resource_slots": {}, + "allowed_vfolder_hosts": {} + } + ], + "scaling_groups": [ + { + "name": "default", + "description": "The default agent scaling group", + "is_active": true, + "driver": "static", + "driver_opts": {}, + "scheduler": "fifo", + "scheduler_opts": {} + } + ], + "sgroups_for_domains": [ + { + "scaling_group": "default", + "domain": "default" + } + ], + "users": [ + { + "uuid": "f38dea23-50fa-42a0-b5ae-338f5f4693f4", + "username": "admin", + "email": "admin@lablup.com", + "password": "wJalrXUt", + "need_password_change": false, + "full_name": "Admin Lablup", + "description": "Lablup's Admin Account", + "status": "active", + "status_info": "admin-requested", + "domain_name": "default", + "role": "superadmin" + }, + { + "uuid": "4f13d193-f646-425a-a340-270c4d2b9860", + "username": "domain-admin", + "email": "domain-admin@lablup.com", + "password": "cWbsM_vB", + "need_password_change": false, + "full_name": "Default Domain Admin Lablup", + "description": "Lablup's Default Domain Admin Account", + "status": "active", + "status_info": "admin-requested", + "domain_name": "default", + "role": "admin" + }, + { + "uuid": "dfa9da54-4b28-432f-be29-c0d680c7a412", + "username": "user", + "email": "user@lablup.com", + "password": "C8qnIo29", + "need_password_change": false, + "full_name": "User Lablup", + "description": "Lablup's User Account", + "status": "active", + "status_info": "admin-requested", + "domain_name": "default", + "role": "user" + }, + { + "uuid": "2e10157d-20ca-4bd0-9806-3f909cbcd0e6", + "username": "monitor", + "email": "monitor@lablup.com", + "password": "7tuEwF1J", + "need_password_change": false, + "full_name": "Monitor Lablup", + "description": "Lablup's Monitor Account", + "status": "active", + "status_info": "admin-requested", + "domain_name": "default", + "role": "monitor" + } + ], + "association_groups_users": [ + { + "group_id": "2de2b969-1d04-48a6-af16-0bc8adb3c831", + "user_id": "f38dea23-50fa-42a0-b5ae-338f5f4693f4" + }, + { + "group_id": "2de2b969-1d04-48a6-af16-0bc8adb3c831", + "user_id": "4f13d193-f646-425a-a340-270c4d2b9860" + }, + { + "group_id": "2de2b969-1d04-48a6-af16-0bc8adb3c831", + "user_id": "dfa9da54-4b28-432f-be29-c0d680c7a412" + }, + { + "group_id": "2de2b969-1d04-48a6-af16-0bc8adb3c831", + "user_id": "2e10157d-20ca-4bd0-9806-3f909cbcd0e6" + } + ], + "keypair_resource_policies": [ + { + "name": "default", + "default_for_unspecified": "UNLIMITED", + "total_resource_slots": {}, + "max_session_lifetime": 0, + "max_concurrent_sessions": 5, + "max_containers_per_session": 1, + "max_vfolder_count": 10, + "max_vfolder_size": 0, + "idle_timeout": 3600, + "allowed_vfolder_hosts": ["local:volume1"] + } + ], + "keypairs": [ + { + "user_id": "admin@lablup.com", + "access_key": "AKIAIOSFODNN7EXAMPLE", + "secret_key": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + "is_active": true, + "resource_policy": "default", + "rate_limit": 30000, + "num_queries": 0, + "is_admin": true, + 
"user": "f38dea23-50fa-42a0-b5ae-338f5f4693f4" + }, + { + "user_id": "domain-admin@lablup.com", + "access_key": "AKIAHUKCHDEZGEXAMPLE", + "secret_key": "cWbsM_vBB4CzTW7JdORRMx8SjGI3-wEXAMPLEKEY", + "is_active": true, + "resource_policy": "default", + "rate_limit": 30000, + "num_queries": 0, + "is_admin": true, + "user": "4f13d193-f646-425a-a340-270c4d2b9860" + }, + { + "user_id": "user@lablup.com", + "access_key": "AKIANABBDUSEREXAMPLE", + "secret_key": "C8qnIo29EZvXkPK_MXcuAakYTy4NYrxwmCEyNPlf", + "is_active": true, + "resource_policy": "default", + "rate_limit": 30000, + "num_queries": 0, + "is_admin": false, + "user": "dfa9da54-4b28-432f-be29-c0d680c7a412" + }, + { + "user_id": "monitor@lablup.com", + "access_key": "AKIANAMONITOREXAMPLE", + "secret_key": "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", + "is_active": true, + "resource_policy": "default", + "rate_limit": 30000, + "num_queries": 0, + "is_admin": false, + "user": "2e10157d-20ca-4bd0-9806-3f909cbcd0e6" + } + ] +} diff --git a/tests/manager/fixtures/example-resource-presets.json b/tests/manager/fixtures/example-resource-presets.json new file mode 100644 index 0000000000..63c4169013 --- /dev/null +++ b/tests/manager/fixtures/example-resource-presets.json @@ -0,0 +1,33 @@ +{ + "resource_presets": [ + { + "name": "01-small", + "resource_slots": { + "cpu": "8", + "mem": "34359738368", + "cuda.device": "1", + "cuda.shares": "0.5" + } + }, + { + "name": "02-medium", + "resource_slots": { + "cpu": "24", + "mem": "171798691840", + "cuda.device": "2", + "cuda.shares": "2.0" + }, + "shared_memory": "1073741824" + }, + { + "name": "03-large", + "resource_slots": { + "cpu": "64", + "mem": "343597383680", + "cuda.device": "4", + "cuda.shares": "4.0" + }, + "shared_memory": "2147483648" + } + ] +} diff --git a/tests/manager/fixtures/example-session-templates.json b/tests/manager/fixtures/example-session-templates.json new file mode 100755 index 0000000000..4e69cca076 --- /dev/null +++ b/tests/manager/fixtures/example-session-templates.json @@ -0,0 +1,71 @@ +{ + "session_templates": [ + { + "id": "c1b8441a-ba46-4a83-8727-de6645f521b4", + "is_active": true, + "domain_name": "default", + "group_id": "2de2b969-1d04-48a6-af16-0bc8adb3c831", + "user_uuid": "f38dea23-50fa-42a0-b5ae-338f5f4693f4", + "type": "TASK", + "name": "jupyter", + "template": { + "api_version": "6", + "kind": "task_template", + "metadata": { + "name": "cr.backend.ai/testing/ngc-pytorch", + "tag": "20.11-py3" + }, + "spec": { + "session_type": "interactive", + "kernel": { + "image": "cr.backend.ai/testing/ngc-pytorch:20.11-py3", + "environ": {}, + "run": null, + "git": null + }, + "scaling_group": "default", + "mounts": { + }, + "resources": { + "cpu": "2", + "mem": "4g", + "cuda.shares": "0.2" + } + } + } + }, + { + "id": "59062449-4f57-4434-975d-add2a593438c", + "is_active": true, + "domain_name": "default", + "group_id": "2de2b969-1d04-48a6-af16-0bc8adb3c831", + "user_uuid": "f38dea23-50fa-42a0-b5ae-338f5f4693f4", + "type": "TASK", + "name": "rstudio", + "template": { + "api_version": "6", + "kind": "task_template", + "metadata": { + "name": "cr.backend.ai/cloud/r-base", + "tag": "4.0" + }, + "spec": { + "session_type": "interactive", + "kernel": { + "image": "cr.backend.ai/cloud/r-base:4.0", + "environ": {}, + "run": null, + "git": null + }, + "scaling_group": "default", + "mounts": { + }, + "resources": { + "cpu": "1", + "mem": "2g" + } + } + } + } + ] +} diff --git a/tests/manager/model_factory.py b/tests/manager/model_factory.py new file mode 100644 index 
0000000000..444de336b2 --- /dev/null +++ b/tests/manager/model_factory.py @@ -0,0 +1,194 @@ +from abc import ABC, abstractmethod +import sqlalchemy as sa +import uuid + +from ai.backend.manager.api.context import RootContext +import ai.backend.manager.models as models + + +def get_random_string(length=10): + return uuid.uuid4().hex[:length] + + +class ModelFactory(ABC): + + model = None + app = None + defaults = None + + def __init__(self, app): + self.app = app + + @abstractmethod + def get_creation_defaults(self): + return {} + + async def before_creation(self): + pass + + async def after_creation(self, row): + return row + + async def create(self, **kwargs): + self.defaults = self.get_creation_defaults() + self.defaults.update(**kwargs) + await self.before_creation() + root_ctx: RootContext = self.app['_root.context'] + async with root_ctx.db.begin() as conn: + query = (self.model.insert().returning(self.model).values(self.defaults)) + result = await conn.execute(query) + row = result.first() + row = dict(row.items()) + row = await self.after_creation(row) + return row + + async def get(self, **kwargs): + root_ctx: RootContext = self.app['_root.context'] + async with root_ctx.db.begin() as conn: + filters = [sa.sql.column(key) == value for key, value in kwargs.items()] + query = sa.select([self.model]).where(sa.and_(*filters)) + result = await conn.execute(query) + rows = result.fetchall() + assert len(rows) < 2, 'Multiple items found' + return rows[0] if len(rows) == 1 else None + + async def list(self, **kwargs): + root_ctx: RootContext = self.app['_root.context'] + async with root_ctx.db.begin() as conn: + filters = [sa.sql.column(key) == value for key, value in kwargs.items()] + query = sa.select([self.model]).where(sa.and_(*filters)) + result = await conn.execute(query) + return result.fetchall() + + +class KeyPairFactory(ModelFactory): + + model = models.keypairs + + def get_creation_defaults(self, **kwargs): + from ai.backend.manager.models.keypair import generate_keypair + ak, sk = generate_keypair() + return { + 'access_key': ak, + 'secret_key': sk, + 'is_active': True, + 'is_admin': False, + 'resource_policy': 'default', + } + + async def before_creation(self): + assert 'user_id' in self.defaults and 'user' in self.defaults, \ + 'user_id and user should be provided to create a keypair' + + +class UserFactory(ModelFactory): + + model = models.users + + def get_creation_defaults(self, **kwargs): + username = f'test-user-{get_random_string()}' + return { + 'username': username, + 'email': username + '@lablup.com', + 'password': get_random_string(), + 'domain_name': 'default', + } + + async def after_creation(self, row): + kp = await KeyPairFactory(self.app).create(user_id=row['email'], user=row['uuid']) + row['keypair'] = { + 'access_key': kp['access_key'], + 'secret_key': kp['secret_key'], + } + return row + + +class DomainFactory(ModelFactory): + + model = models.domains + + def get_creation_defaults(self, **kwargs): + return { + 'name': f'test-domain-{get_random_string()}', + 'total_resource_slots': {}, + } + + +class GroupFactory(ModelFactory): + + model = models.groups + + def get_creation_defaults(self, **kwargs): + return { + 'name': f'test-group-{get_random_string()}', + 'domain_name': 'default', + 'total_resource_slots': {}, + } + + +class AssociationGroupsUsersFactory(ModelFactory): + + model = models.association_groups_users + + def get_creation_defaults(self, **kwargs): + return {} + + async def before_creation(self): + assert 'user_id' in self.defaults and 'group_id' 
in self.defaults, \ + 'user_id and group_id should be provided to associate a group and a user' + + +class VFolderFactory(ModelFactory): + + model = models.vfolders + + def get_creation_defaults(self, **kwargs): + return { + 'host': 'local', + 'name': f'test-vfolder-{get_random_string()}', + } + + async def before_creation(self): + if 'user' not in self.defaults and 'group' not in self.defaults: + user = await UserFactory(self.app).create() + self.defaults['user'] = user['uuid'] + + +class VFolderInvitationFactory(ModelFactory): + + model = models.vfolder_invitations + + def get_creation_defaults(self, **kwargs): + return { + 'permission': models.VFolderPermission('ro'), + 'state': 'pending', + } + + async def before_creation(self): + if 'vfolder' not in self.defaults: + vf = await VFolderFactory(self.app).create() + self.defaults['vfolder'] = vf['id'] + if 'inviter' not in self.defaults: + user = await UserFactory(self.app).create() + self.defaults['inviter'] = user['email'] + if 'invitee' not in self.defaults: + user = await UserFactory(self.app).create() + self.defaults['invitee'] = user['email'] + + +class VFolderPermissionFactory(ModelFactory): + + model = models.vfolder_permissions + + def get_creation_defaults(self, **kwargs): + return { + 'permission': models.VFolderPermission('ro'), + } + + async def before_creation(self): + if 'vfolder' not in self.defaults: + vf = await VFolderFactory(self.app).create() + self.defaults['vfolder'] = vf['id'] + if 'user' not in self.defaults: + user = await UserFactory(self.app).create() + self.defaults['user'] = user['uuid'] diff --git a/tests/manager/models/BUILD b/tests/manager/models/BUILD new file mode 100644 index 0000000000..2c1c32f7e8 --- /dev/null +++ b/tests/manager/models/BUILD @@ -0,0 +1,6 @@ +python_tests( + name="tests", + dependencies=[ + "src/ai/backend/manager:service", + ], +) diff --git a/tests/manager/models/test_dbutils.py b/tests/manager/models/test_dbutils.py new file mode 100644 index 0000000000..d5f2882c8b --- /dev/null +++ b/tests/manager/models/test_dbutils.py @@ -0,0 +1,47 @@ +import aiotools +import pytest +import sqlalchemy as sa + +from ai.backend.manager.models.utils import execute_with_retry + + +@pytest.mark.asyncio +async def test_execute_with_retry(): + + class DummyDBError(Exception): + def __init__(self, pgcode): + self.pgcode = pgcode + + async def txn_func_generic_failure(): + raise sa.exc.IntegrityError('DUMMY_SQL', params=None, orig=DummyDBError('999')) + + async def txn_func_generic_failure_2(): + raise ZeroDivisionError("oops") + + async def txn_func_permanent_serialization_failure(): + raise sa.exc.DBAPIError('DUMMY_SQL', params=None, orig=DummyDBError('40001')) + + _fail_count = 0 + + async def txn_func_temporary_serialization_failure(): + nonlocal _fail_count + _fail_count += 1 + if _fail_count == 10: + return 1234 + raise sa.exc.DBAPIError('DUMMY_SQL', params=None, orig=DummyDBError('40001')) + + vclock = aiotools.VirtualClock() + with vclock.patch_loop(): + + with pytest.raises(sa.exc.IntegrityError): + await execute_with_retry(txn_func_generic_failure) + + with pytest.raises(ZeroDivisionError): + await execute_with_retry(txn_func_generic_failure_2) + + with pytest.raises(RuntimeError) as e: + await execute_with_retry(txn_func_permanent_serialization_failure) + assert "serialization failed" in e.value.args[0].lower() + + ret = await execute_with_retry(txn_func_temporary_serialization_failure) + assert ret == 1234 diff --git a/tests/manager/sample-ssl-cert/sample.crt 
b/tests/manager/sample-ssl-cert/sample.crt new file mode 100644 index 0000000000..6025a17a2a --- /dev/null +++ b/tests/manager/sample-ssl-cert/sample.crt @@ -0,0 +1,15 @@ +-----BEGIN CERTIFICATE----- +MIICRzCCAbACCQCIXVju6dmcdzANBgkqhkiG9w0BAQUFADBoMQswCQYDVQQGEwJL +UjETMBEGA1UECBMKU29tZS1TdGF0ZTEOMAwGA1UEBxMFU2VvdWwxDzANBgNVBAoT +BkxhYmx1cDEPMA0GA1UECxMGRGV2T3BzMRIwEAYDVQQDEwlsb2NhbGhvc3QwHhcN +MTYxMDEwMDI0MTA2WhcNMjYxMDA4MDI0MTA2WjBoMQswCQYDVQQGEwJLUjETMBEG +A1UECBMKU29tZS1TdGF0ZTEOMAwGA1UEBxMFU2VvdWwxDzANBgNVBAoTBkxhYmx1 +cDEPMA0GA1UECxMGRGV2T3BzMRIwEAYDVQQDEwlsb2NhbGhvc3QwgZ8wDQYJKoZI +hvcNAQEBBQADgY0AMIGJAoGBANWBj4K90ZI7mSco5vLT1YZb/57xgb8e0qOFq0wG +bSFfTl//6bzw0G3+GPl/2L/9DMMivi7HS9iAT9/T7NusiHNDPhC8bqRnQYOYO67s +k7UCXeOkMl59MJqU4rn4IhHj8X1huOW8BosDMCkRx9PuS9FHUTJsCp1vnxi0G4Lo +uP5rAgMBAAEwDQYJKoZIhvcNAQEFBQADgYEAigmiXFi4n6h1B8w01l5Q38Ge1Rpp ++7fHAI+4FyNnsJKBhuCBX4AMmqLzgzNDpGyv4QEEUzMWERuAP0vpYNRj09i+xAXB +DeFgrIGbEKCbG4Ukp9U4R5kewp+qJnBfwGlBA1r9SF2ejWr7fPobGj1SrviZrLZ5 +f/7uWD54ie5aPkk= +-----END CERTIFICATE----- diff --git a/tests/manager/sample-ssl-cert/sample.csr b/tests/manager/sample-ssl-cert/sample.csr new file mode 100644 index 0000000000..d85ff77560 --- /dev/null +++ b/tests/manager/sample-ssl-cert/sample.csr @@ -0,0 +1,11 @@ +-----BEGIN CERTIFICATE REQUEST----- +MIIBqDCCARECAQAwaDELMAkGA1UEBhMCS1IxEzARBgNVBAgTClNvbWUtU3RhdGUx +DjAMBgNVBAcTBVNlb3VsMQ8wDQYDVQQKEwZMYWJsdXAxDzANBgNVBAsTBkRldk9w +czESMBAGA1UEAxMJbG9jYWxob3N0MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKB +gQDVgY+CvdGSO5knKOby09WGW/+e8YG/HtKjhatMBm0hX05f/+m88NBt/hj5f9i/ +/QzDIr4ux0vYgE/f0+zbrIhzQz4QvG6kZ0GDmDuu7JO1Al3jpDJefTCalOK5+CIR +4/F9YbjlvAaLAzApEcfT7kvRR1EybAqdb58YtBuC6Lj+awIDAQABoAAwDQYJKoZI +hvcNAQEFBQADgYEAB+QwJKRAW9Du7MvZKE8xVuKamI3q13vuAOK+uFWU4iIwqfgR +OhjCrizkStOIRcScsKu023hmEhph8XHHN1IBOm3EjQ4iOZqXBgKAoEMiqPJjGRGk +LAQ7KPuDFv5QumKbTbd+mfvu56+o5U086+fo5pKVAcXsjNf9Sc90JEF4dtE= +-----END CERTIFICATE REQUEST----- diff --git a/tests/manager/sample-ssl-cert/sample.key b/tests/manager/sample-ssl-cert/sample.key new file mode 100644 index 0000000000..274bf6ab44 --- /dev/null +++ b/tests/manager/sample-ssl-cert/sample.key @@ -0,0 +1,15 @@ +-----BEGIN RSA PRIVATE KEY----- +MIICXQIBAAKBgQDVgY+CvdGSO5knKOby09WGW/+e8YG/HtKjhatMBm0hX05f/+m8 +8NBt/hj5f9i//QzDIr4ux0vYgE/f0+zbrIhzQz4QvG6kZ0GDmDuu7JO1Al3jpDJe +fTCalOK5+CIR4/F9YbjlvAaLAzApEcfT7kvRR1EybAqdb58YtBuC6Lj+awIDAQAB +AoGAbywYWv6V+mv4EnD02Ko++8g5sTyVz7uv+J+ok1yhRIhI2in6PnyyOyPdQ0Uz +yrxsAcu8dcUmlCQz8xt5sOUE4vOyXXgXil4v7/amMmKwhPXKssCwYA58U5S5e/I1 +DVHw4OaxT7qiPPZteZaJa2QgH1ihtXhNGbqYeTv9nBeEKAkCQQD6Bb8TLiWz1GFV +YgnEk+wAHX7f0RfQAwqr3W3Xc+Os0iLGt1s0Wu7kvnzzWMBQAMSXjLjvEABTM8zP +eXx7dpQdAkEA2pxRMU4ZjFjTQy/CJtRf7aWFj+0ctGv/2D0VXdmv7ArrjTVkAD9e +culPueqzKcdC53fZn8SnHuiA2FTBcGLGJwJBAMC4rzmEp9E/Uyuyn17kutS357V0 +gkt4HMCvtVyPWx86901/xpDLyzuNTdlyPwMsJF3BPkggaG+6DRScS4ULuU0CQC1Y +Y1cQ1ifQfPHgxCr9vnAy90NlcaDTDhyyfu4aq20Qzs9ZlcafXl4Dmy/7SPKPjIcq +yw9i4S9+FsvIuN8w/d0CQQCLEZ3PGNorU+lvfVl3YbAek/qY7bN5fdUI23+6JImD +O/NvyY0RKCgNo8EnCFtVqgE9YI7DvcMHT338Kmizj8FL +-----END RSA PRIVATE KEY----- diff --git a/tests/manager/test_advisory_lock.py b/tests/manager/test_advisory_lock.py new file mode 100644 index 0000000000..6579231b4f --- /dev/null +++ b/tests/manager/test_advisory_lock.py @@ -0,0 +1,68 @@ +import asyncio + +import pytest + +from ai.backend.manager.defs import LockID +from ai.backend.manager.models.utils import ExtendedAsyncSAEngine + + +@pytest.mark.asyncio +async def test_lock(database_engine: ExtendedAsyncSAEngine) -> None: + + enter_count = 0 + done_count = 0 + + async def 
critical_section(db: ExtendedAsyncSAEngine) -> None: + nonlocal enter_count + async with db.advisory_lock(LockID.LOCKID_TEST): + enter_count += 1 + await asyncio.sleep(1.0) + + tasks = [] + for idx in range(5): + tasks.append( + asyncio.create_task( + critical_section(database_engine), + name=f"critical-section-{idx}", + ), + ) + await asyncio.sleep(0.5) + + async with database_engine.connect() as conn: + result = await conn.exec_driver_sql( + "SELECT objid, granted, pid FROM pg_locks " + "WHERE locktype = 'advisory' AND objid = 42;", + ) + rows = result.fetchall() + print(rows) + result = await conn.exec_driver_sql( + "SELECT objid, granted FROM pg_locks " + "WHERE locktype = 'advisory' AND objid = 42 AND granted = 't';", + ) + rows = result.fetchall() + assert len(rows) == 1 + + await asyncio.sleep(2.5) + for t in tasks: + if t.done(): + done_count += 1 + else: + try: + t.cancel() + await t + except asyncio.CancelledError: + pass + await asyncio.sleep(0.1) + + assert 2 <= done_count <= 3 + assert enter_count >= done_count + + # Check all tasks have unlocked. + async with database_engine.connect() as conn: + result = await conn.exec_driver_sql( + "SELECT objid, granted, pid FROM pg_locks " + "WHERE locktype = 'advisory' AND objid = 42 AND granted = 't';", + ) + rows = result.fetchall() + print(rows) + assert len(rows) == 0 diff --git a/tests/manager/test_image.py b/tests/manager/test_image.py new file mode 100644 index 0000000000..cb693e7b49 --- /dev/null +++ b/tests/manager/test_image.py @@ -0,0 +1,95 @@ +from pathlib import Path +import uuid + +import pytest +import sqlalchemy as sa +from sqlalchemy.ext.asyncio import ( + AsyncSession, + create_async_engine, +) +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import ( + selectinload, + sessionmaker, +) + +from ai.backend.common.docker import ImageRef + +from ai.backend.manager.models import ( + update_aliases_from_file, + ImageAliasRow, + ImageRow, +) +from ai.backend.manager.models.base import metadata as old_metadata +from ai.backend.manager.models.utils import regenerate_table + +column_keys = ['nullable', 'index', 'unique', 'primary_key'] + + +@pytest.fixture +async def virtual_image_db(): + engine = create_async_engine('sqlite+aiosqlite:///:memory:', echo=True) + base = declarative_base() + metadata = base.metadata + + regenerate_table(old_metadata.tables['images'], metadata) + regenerate_table(old_metadata.tables['image_aliases'], metadata) + ImageAliasRow.metadata = metadata + ImageRow.metadata = metadata + async_session = sessionmaker(engine, class_=AsyncSession, autoflush=False) + async with engine.begin() as conn: + await conn.run_sync(metadata.create_all) + await conn.commit() + async with async_session() as session: + image_1 = ImageRow( + 'index.docker.io/lablup/test-python:latest', 'x86_64', + 'index.docker.io', 'lablup/test-python', 'latest', + 'sha256:2d577a600afe2d1b38d78bc2ee5abe3bd350890d0652e48096249694e074f9c3', + 123123123, 'COMPUTE', '', {}, {}, + ) + image_1.id = uuid.uuid4() + image_2 = ImageRow( + 'index.docker.io/lablup/test-python:3.6-debian', 'aarch64', + 'index.docker.io', 'lablup/test-python', '3.6-debian', + 'sha256:2d577a600afe2d1b38d78bc2ee5abe3bd350890d0652e48096249694e074f9c3', + 123123123, 'COMPUTE', '', {}, {}, + ) + image_2.id = uuid.uuid4() + session.add(image_1) + session.add(image_2) + await session.commit() + yield async_session + await engine.dispose() + + +@pytest.fixture +async def image_aliases(tmpdir): + content = ''' +aliases: + - ['my-python', 
'test-python:latest', 'x86_64'] + - ['my-python:3.6', 'test-python:3.6-debian', 'aarch64'] # preferred +''' + p = Path(tmpdir) / 'test-image-aliases.yml' + p.write_text(content) + + yield p + + +@pytest.mark.asyncio +async def test_update_aliases_from_file(virtual_image_db, image_aliases): + async_session = virtual_image_db + async with async_session() as session: + created_aliases = await update_aliases_from_file(session, image_aliases) + for alias in created_aliases: + alias.id = uuid.uuid4() + await session.commit() + result = await session.execute( + sa.select(ImageAliasRow).options(selectinload(ImageAliasRow.image)), + ) + aliases = {} + for row in result.scalars().all(): + aliases[row.alias] = row.image.image_ref + assert aliases == { + 'my-python': ImageRef('lablup/test-python:latest', architecture='x86_64'), + 'my-python:3.6': ImageRef('lablup/test-python:3.6-debian', architecture='aarch64'), + } diff --git a/tests/manager/test_predicates.py b/tests/manager/test_predicates.py new file mode 100644 index 0000000000..0229437735 --- /dev/null +++ b/tests/manager/test_predicates.py @@ -0,0 +1,162 @@ +from __future__ import annotations + +from unittest import mock +from unittest.mock import MagicMock + +import pytest + +from ai.backend.common.types import SessionTypes +from ai.backend.manager.models.scaling_group import ScalingGroupOpts +from ai.backend.manager.scheduler.predicates import check_scaling_group + + +@pytest.mark.asyncio +@mock.patch('ai.backend.manager.scheduler.predicates.execute_with_retry') +async def test_allowed_session_types_check(mock_query): + mock_query.return_value = [ + { + 'name': 'a', + 'scheduler_opts': ScalingGroupOpts().from_json({ + 'allowed_session_types': ['batch'], + }), + }, + { + 'name': 'b', + 'scheduler_opts': ScalingGroupOpts().from_json({ + 'allowed_session_types': ['interactive'], + }), + }, + { + 'name': 'c', + 'scheduler_opts': ScalingGroupOpts().from_json({ + 'allowed_session_types': ['batch', 'interactive'], + }), + }, + ] + mock_conn = MagicMock() + mock_sched_ctx = MagicMock() + mock_sess_ctx = MagicMock() + + # Preferred scaling group with one match in allowed sgroups + + mock_sess_ctx.session_type = SessionTypes.BATCH + mock_sess_ctx.scaling_group = 'a' + mock_sess_ctx.target_sgroup_names = [] + result = await check_scaling_group(mock_conn, mock_sched_ctx, mock_sess_ctx) + assert result.passed + assert mock_sess_ctx.target_sgroup_names == ['a'] + + mock_sess_ctx.session_type = SessionTypes.BATCH + mock_sess_ctx.scaling_group = 'b' + mock_sess_ctx.target_sgroup_names = [] + result = await check_scaling_group(mock_conn, mock_sched_ctx, mock_sess_ctx) + assert not result.passed + assert result.message is not None + assert "does not accept" in result.message + assert mock_sess_ctx.target_sgroup_names == [] + + mock_sess_ctx.session_type = SessionTypes.BATCH + mock_sess_ctx.scaling_group = 'c' + mock_sess_ctx.target_sgroup_names = [] + result = await check_scaling_group(mock_conn, mock_sched_ctx, mock_sess_ctx) + assert result.passed + assert mock_sess_ctx.target_sgroup_names == ['c'] + + mock_sess_ctx.session_type = SessionTypes.INTERACTIVE + mock_sess_ctx.scaling_group = 'a' + mock_sess_ctx.target_sgroup_names = [] + result = await check_scaling_group(mock_conn, mock_sched_ctx, mock_sess_ctx) + assert not result.passed + assert result.message is not None + assert "does not accept" in result.message + assert mock_sess_ctx.target_sgroup_names == [] + + mock_sess_ctx.session_type = SessionTypes.INTERACTIVE + mock_sess_ctx.scaling_group = 'b' + 
mock_sess_ctx.target_sgroup_names = [] + result = await check_scaling_group(mock_conn, mock_sched_ctx, mock_sess_ctx) + assert result.passed + assert mock_sess_ctx.target_sgroup_names == ['b'] + + mock_sess_ctx.session_type = SessionTypes.INTERACTIVE + mock_sess_ctx.scaling_group = 'c' + mock_sess_ctx.target_sgroup_names = [] + result = await check_scaling_group(mock_conn, mock_sched_ctx, mock_sess_ctx) + assert result.passed + assert mock_sess_ctx.target_sgroup_names == ['c'] + + # Non-existent/disallowed preferred scaling group + + mock_sess_ctx.session_type = SessionTypes.INTERACTIVE + mock_sess_ctx.scaling_group = 'x' + mock_sess_ctx.target_sgroup_names = [] + result = await check_scaling_group(mock_conn, mock_sched_ctx, mock_sess_ctx) + assert not result.passed + assert result.message is not None + assert "do not have access" in result.message + assert mock_sess_ctx.target_sgroup_names == [] + + # No preferred scaling group with partially matching allowed sgroups + + mock_sess_ctx.session_type = SessionTypes.BATCH + mock_sess_ctx.scaling_group = None + mock_sess_ctx.target_sgroup_names = [] + result = await check_scaling_group(mock_conn, mock_sched_ctx, mock_sess_ctx) + assert result.passed + assert mock_sess_ctx.target_sgroup_names == ['a', 'c'] + + mock_sess_ctx.session_type = SessionTypes.INTERACTIVE + mock_sess_ctx.scaling_group = None + mock_sess_ctx.target_sgroup_names = [] + result = await check_scaling_group(mock_conn, mock_sched_ctx, mock_sess_ctx) + assert result.passed + assert mock_sess_ctx.target_sgroup_names == ['b', 'c'] + + # No preferred scaling group with an empty list of allowed sgroups + + mock_query.return_value = [] + + mock_sess_ctx.session_type = SessionTypes.BATCH + mock_sess_ctx.scaling_group = 'x' + mock_sess_ctx.target_sgroup_names = [] + result = await check_scaling_group(mock_conn, mock_sched_ctx, mock_sess_ctx) + assert not result.passed + assert result.message is not None + assert "do not have any" in result.message + assert mock_sess_ctx.target_sgroup_names == [] + + mock_sess_ctx.session_type = SessionTypes.INTERACTIVE + mock_sess_ctx.scaling_group = 'x' + mock_sess_ctx.target_sgroup_names = [] + result = await check_scaling_group(mock_conn, mock_sched_ctx, mock_sess_ctx) + assert not result.passed + assert result.message is not None + assert "do not have any" in result.message + assert mock_sess_ctx.target_sgroup_names == [] + + # No preferred scaling group with a non-empty list of allowed sgroups + + mock_query.return_value = [ + { + 'name': 'a', + 'scheduler_opts': ScalingGroupOpts.from_json({ + 'allowed_session_types': ['batch'], + }), + }, + ] + + mock_sess_ctx.session_type = SessionTypes.BATCH + mock_sess_ctx.scaling_group = None + mock_sess_ctx.target_sgroup_names = [] + result = await check_scaling_group(mock_conn, mock_sched_ctx, mock_sess_ctx) + assert result.passed + assert mock_sess_ctx.target_sgroup_names == ['a'] + + mock_sess_ctx.session_type = SessionTypes.INTERACTIVE + mock_sess_ctx.scaling_group = None + mock_sess_ctx.target_sgroup_names = [] + result = await check_scaling_group(mock_conn, mock_sched_ctx, mock_sess_ctx) + assert not result.passed + assert result.message is not None + assert "No scaling groups accept" in result.message + assert mock_sess_ctx.target_sgroup_names == [] diff --git a/tests/manager/test_queryfilter.py b/tests/manager/test_queryfilter.py new file mode 100644 index 0000000000..56fc3e5c62 --- /dev/null +++ b/tests/manager/test_queryfilter.py @@ -0,0 +1,301 @@ +import enum + +import pytest +import sqlalchemy 
as sa +from sqlalchemy.ext.declarative import declarative_base + +from ai.backend.manager.models.minilang.queryfilter import QueryFilterParser + + +class UserTypes(enum.Enum): + ADMIN = 0 + USER = 1 + + +@pytest.fixture +def virtual_user_db(): + engine = sa.engine.create_engine('sqlite:///:memory:', echo=False) + base = declarative_base() + metadata = base.metadata + users = sa.Table( + 'users', metadata, + sa.Column('id', sa.Integer, sa.Sequence('user_id_seq'), primary_key=True), + sa.Column('name', sa.String(50)), + sa.Column('full_name', sa.String(50)), + sa.Column('type', sa.Enum(UserTypes)), + sa.Column('age', sa.Integer), + sa.Column('is_active', sa.Boolean), + sa.Column('data', sa.Float, nullable=True), + ) + metadata.create_all(engine) + with engine.connect() as conn: + conn.execute( + users.insert(), [ + { + 'name': 'tester', + 'full_name': 'tester1', + 'type': UserTypes.ADMIN, + 'age': 30, + 'is_active': True, + 'data': 10.5, + }, + { + 'name': 'test\"er', + 'full_name': 'tester2', + 'type': UserTypes.USER, + 'age': 40, + 'is_active': True, + 'data': None, + }, + { + 'name': 'test\'er', + 'full_name': 'tester3', + 'type': UserTypes.USER, + 'age': 50, + 'is_active': False, + 'data': 2.33, + }, + { + 'name': 'tester ♪', + 'full_name': 'tester4', + 'type': UserTypes.USER, + 'age': 20, + 'is_active': False, + 'data': None, + }, + ], + ) + yield conn, users + engine.dispose() + + +def test_select_queries(virtual_user_db) -> None: + conn, users = virtual_user_db + parser = QueryFilterParser() + + sa_query = parser.append_filter( + sa.select([users.c.name, users.c.age]).select_from(users), + "full_name == \"tester1\"", + ) + actual_ret = list(conn.execute(sa_query)) + test_ret = [("tester", 30)] + assert test_ret == actual_ret + + sa_query = parser.append_filter( + sa.select([users.c.name, users.c.age]).select_from(users), + "name == \"test'er\"", + ) + actual_ret = list(conn.execute(sa_query)) + test_ret = [("test'er", 50)] + assert test_ret == actual_ret + + sa_query = parser.append_filter( + sa.select([users.c.name, users.c.age]).select_from(users), + "name == \"test\\\"er\"", + ) + actual_ret = list(conn.execute(sa_query)) + test_ret = [("test\"er", 40)] + assert test_ret == actual_ret + + sa_query = parser.append_filter( + sa.select([users.c.name, users.c.age]).select_from(users), + "(full_name == \"tester1\")", + ) + actual_ret = list(conn.execute(sa_query)) + test_ret = [("tester", 30)] + assert test_ret == actual_ret + + sa_query = parser.append_filter( + sa.select([users.c.name, users.c.age]).select_from(users), + "full_name in [\"tester1\", \"tester3\", \"tester9\"]", + ) + actual_ret = list(conn.execute(sa_query)) + test_ret = [("tester", 30), ("test\'er", 50)] + assert test_ret == actual_ret + + sa_query = parser.append_filter( + sa.select([users.c.name, users.c.age]).select_from(users), + "type in [\"USER\", \"ADMIN\"]", + ) + actual_ret = list(conn.execute(sa_query)) + assert len(actual_ret) == 4 + + sa_query = parser.append_filter( + sa.select([users.c.name, users.c.age]).select_from(users), + "full_name == \"tester1\" & age == 20", + ) + actual_ret = list(conn.execute(sa_query)) + test_ret = [] + assert test_ret == actual_ret + + sa_query = parser.append_filter( + sa.select([users.c.name, users.c.age]).select_from(users), + "(full_name == \"tester1\") & (age == 20)", + ) + actual_ret = list(conn.execute(sa_query)) + test_ret = [] + assert test_ret == actual_ret + + sa_query = parser.append_filter( + sa.select([users.c.name, users.c.age]).select_from(users), + "(full_name 
== \"tester1\") | (age == 20)", + ) + actual_ret = list(conn.execute(sa_query)) + test_ret = [("tester", 30), ("tester ♪", 20)] + assert test_ret == actual_ret + + sa_query = parser.append_filter( + sa.select([users.c.name, users.c.age]).select_from(users), + "(name contains \"test\") & (age > 30) & (is_active is true)", + ) + actual_ret = list(conn.execute(sa_query)) + test_ret = [("test\"er", 40)] + assert test_ret == actual_ret + + sa_query = parser.append_filter( + sa.select([users.c.name, users.c.age]).select_from(users), + "data isnot null", + ) + actual_ret = list(conn.execute(sa_query)) + test_ret = [("tester", 30), ("test\'er", 50)] + assert test_ret == actual_ret + + sa_query = parser.append_filter( + sa.select([users.c.name, users.c.age]).select_from(users), + "data is null", + ) + actual_ret = list(conn.execute(sa_query)) + test_ret = [("test\"er", 40), ("tester ♪", 20)] + assert test_ret == actual_ret + + sa_query = parser.append_filter( + sa.select([users.c.name, users.c.age]).select_from(users), + "data is null | data isnot null", + ) + actual_ret = list(conn.execute(sa_query)) + assert len(actual_ret) == 4 # all rows + + sa_query = parser.append_filter( + sa.select([users.c.name, users.c.age]).select_from(users), + "data < 9.4", + ) + actual_ret = list(conn.execute(sa_query)) + test_ret = [("test\'er", 50)] # Note: null values are not matched + assert test_ret == actual_ret + + # invalid syntax + with pytest.raises(ValueError): + parser.append_filter( + sa.select([users.c.name, users.c.age]).select_from(users), + "", + ) + with pytest.raises(ValueError): + parser.append_filter( + sa.select([users.c.name, users.c.age]).select_from(users), + "!!!", + ) + with pytest.raises(ValueError): + parser.append_filter( + sa.select([users.c.name, users.c.age]).select_from(users), + "123", + ) + with pytest.raises(ValueError): + parser.append_filter( + sa.select([users.c.name, users.c.age]).select_from(users), + "\"abc\"", + ) + with pytest.raises(ValueError): + parser.append_filter( + sa.select([users.c.name, users.c.age]).select_from(users), + "name =", + ) + + # invalid value type + # => This case is handled during the actual execution of SQL statements + # in the database, not when preparing statements. + # So it is the out of scope issue. 
+ # with pytest.raises(ValueError): + # parser.append_filter( + # sa.select([users.c.name, users.c.age]).select_from(users), + # "full_name == 123", + # ) + + # non-existent column + with pytest.raises(ValueError): + parser.append_filter( + sa.select([users.c.name, users.c.age]).select_from(users), + "xyz == 123", + ) + + +def test_modification_queries(virtual_user_db) -> None: + conn, users = virtual_user_db + parser = QueryFilterParser() + + sa_query = parser.append_filter( + sa.update(users).values({'name': 'hello'}), + "full_name == \"tester1\"", + ) + result = conn.execute(sa_query) + assert result.rowcount == 1 + + sa_query = parser.append_filter( + sa.delete(users), + "full_name like \"tester%\"", + ) + result = conn.execute(sa_query) + assert result.rowcount == 4 + + +def test_fieldspec(virtual_user_db) -> None: + conn, users = virtual_user_db + parser = QueryFilterParser({ + "n1": ("name", None), + "n2": ("full_name", lambda s: s.lower()), + "t1": ("type", lambda s: UserTypes[s]), + }) + + sa_query = parser.append_filter( + sa.select([users.c.name, users.c.age]).select_from(users), + "n1 == \"tester\"", + ) + actual_ret = list(conn.execute(sa_query)) + test_ret = [("tester", 30)] + assert test_ret == actual_ret + + sa_query = parser.append_filter( + sa.select([users.c.name, users.c.age]).select_from(users), + "n2 == \"TESTER1\"", + ) + actual_ret = list(conn.execute(sa_query)) + test_ret = [("tester", 30)] + assert test_ret == actual_ret + + sa_query = parser.append_filter( + sa.select([users.c.name, users.c.age]).select_from(users), + "n2 in [\"TESTER2\", \"TESTER4\"]", + ) + actual_ret = list(conn.execute(sa_query)) + test_ret = [("test\"er", 40), ("tester ♪", 20)] + assert test_ret == actual_ret + + sa_query = parser.append_filter( + sa.select([users.c.name, users.c.age]).select_from(users), + "t1 in [\"USER\", \"ADMIN\"]", + ) + actual_ret = list(conn.execute(sa_query)) + assert len(actual_ret) == 4 + + # non-existent column in fieldspec + with pytest.raises(ValueError): + parser.append_filter( + sa.select([users.c.name, users.c.age]).select_from(users), + "full_name == \"TESTER1\"", + ) + + # non-existent enum value + with pytest.raises(ValueError): + parser.append_filter( + sa.select([users.c.name, users.c.age]).select_from(users), + "t1 == \"XYZ\"", + ) diff --git a/tests/manager/test_queryorder.py b/tests/manager/test_queryorder.py new file mode 100644 index 0000000000..53a5dc0d69 --- /dev/null +++ b/tests/manager/test_queryorder.py @@ -0,0 +1,115 @@ +import pytest +import sqlalchemy as sa +from sqlalchemy.ext.declarative import declarative_base + +from ai.backend.manager.models.minilang.ordering import QueryOrderParser + + +@pytest.fixture +def virtual_grid_db(): + engine = sa.engine.create_engine('sqlite:///:memory:', echo=False) + base = declarative_base() + metadata = base.metadata + grid = sa.Table( + 'users', metadata, + sa.Column('id', sa.Integer, sa.Sequence('user_id_seq'), primary_key=True), + sa.Column('data1', sa.Integer), + sa.Column('data2', sa.Float), + sa.Column('data3', sa.String(10)), + ) + metadata.create_all(engine) + with engine.connect() as conn: + conn.execute(grid.insert(), [ + {'data1': 10, 'data2': 0.2, 'data3': 'a'}, + {'data1': 10, 'data2': 0.1, 'data3': 'c'}, + {'data1': 20, 'data2': 0.0, 'data3': 'b'}, + {'data1': 20, 'data2': -0.1, 'data3': 'd'}, + ]) + yield conn, grid + engine.dispose() + + +def test_select_queries(virtual_grid_db) -> None: + conn, grid = virtual_grid_db + parser = QueryOrderParser() + + sa_query = parser.append_ordering( + 
sa.select([grid.c.id]).select_from(grid), + "+data1", + ) + actual_ret = list(conn.execute(sa_query)) + test_ret = [(1,), (2,), (3,), (4,)] + assert test_ret == actual_ret + + sa_query = parser.append_ordering( + sa.select([grid.c.id]).select_from(grid), + "-data1", + ) + actual_ret = list(conn.execute(sa_query)) + test_ret = [(3,), (4,), (1,), (2,)] + assert test_ret == actual_ret + + sa_query = parser.append_ordering( + sa.select([grid.c.id]).select_from(grid), + "-data1,+data2", + ) + actual_ret = list(conn.execute(sa_query)) + test_ret = [(4,), (3,), (2,), (1,)] + assert test_ret == actual_ret + + sa_query = parser.append_ordering( + sa.select([grid.c.id]).select_from(grid), + "-data1,+data3,-data2", + ) + actual_ret = list(conn.execute(sa_query)) + test_ret = [(3,), (4,), (1,), (2,)] + assert test_ret == actual_ret + + # default ordering + sa_query = parser.append_ordering( + sa.select([grid.c.id]).select_from(grid), + "", + ) + actual_ret = list(conn.execute(sa_query)) + test_ret = [(1,), (2,), (3,), (4,)] + assert test_ret == actual_ret + + # without order marks, it's assumed to be ascending + sa_query = parser.append_ordering( + sa.select([grid.c.id]).select_from(grid), + "data3,-data2,data1", + ) + actual_ret = list(conn.execute(sa_query)) + test_ret = [(1,), (3,), (2,), (4,)] + assert test_ret == actual_ret + + # invalid syntax + with pytest.raises(ValueError): + parser.append_ordering( + sa.select([grid.c.id]).select_from(grid), + "xxx", + ) + + +def test_column_map(virtual_grid_db) -> None: + conn, grid = virtual_grid_db + parser = QueryOrderParser({ + "v1": "data1", + "v2": "data2", + "v3": "data3", + }) + + sa_query = parser.append_ordering( + sa.select([grid.c.id]).select_from(grid), + "-v3", + ) + actual_ret = list(conn.execute(sa_query)) + test_ret = [(4,), (2,), (3,), (1,)] + assert test_ret == actual_ret + + # non-existent column in the column map + with pytest.raises(ValueError): + parser.append_ordering( + sa.select([grid.c.id]).select_from(grid), + "-data1,+data2", + ) diff --git a/tests/manager/test_registry.py b/tests/manager/test_registry.py new file mode 100644 index 0000000000..6f9db0f647 --- /dev/null +++ b/tests/manager/test_registry.py @@ -0,0 +1,157 @@ +from __future__ import annotations + +from decimal import Decimal +from unittest.mock import MagicMock, AsyncMock + +import pytest +import snappy +from sqlalchemy.sql.dml import Insert, Update + +from ai.backend.manager.defs import DEFAULT_IMAGE_ARCH +from ai.backend.manager.registry import AgentRegistry +from ai.backend.manager.models import AgentStatus +from ai.backend.common import msgpack +from ai.backend.common.types import BinarySize, DeviceId, ResourceSlot, SlotName + + +@pytest.mark.asyncio +async def test_handle_heartbeat( + registry_ctx: tuple[AgentRegistry, MagicMock, MagicMock, MagicMock, MagicMock, MagicMock], + mocker, +) -> None: + mock_get_known_registries = AsyncMock(return_value=[ + {'index.docker.io': 'https://registry-1.docker.io'}, + ]) + mocker.patch('ai.backend.manager.registry.get_known_registries', mock_get_known_registries) + mock_redis_wrapper = MagicMock() + mock_redis_wrapper.execute = AsyncMock() + mocker.patch('ai.backend.manager.registry.redis', mock_redis_wrapper) + + def mocked_entrypoints(entry_point_group: str): + return [] + + mocker.patch('ai.backend.common.plugin.pkg_resources.iter_entry_points', mocked_entrypoints) + + registry, mock_dbconn, mock_dbresult, mock_shared_config, _, _ = registry_ctx + image_data = snappy.compress(msgpack.packb([ + 
('index.docker.io/lablup/python:3.6-ubuntu18.04', ), + ])) + + _1 = Decimal('1') + _4 = Decimal('4') + _1g = Decimal('1073741824') + _2g = Decimal('2147483648') + + # Join + mock_dbresult.first = MagicMock(return_value=None) + await registry.handle_heartbeat('i-001', { + 'scaling_group': 'sg-testing', + 'resource_slots': {'cpu': ('count', _1), 'mem': ('bytes', _1g)}, + 'region': 'ap-northeast-2', + 'addr': '10.0.0.5', + 'architecture': DEFAULT_IMAGE_ARCH, + 'version': '19.12.0', + 'compute_plugins': [], + 'images': image_data, + }) + mock_shared_config.update_resource_slots.assert_awaited_once() + q = mock_dbconn.execute.await_args_list[1].args[0] + assert isinstance(q, Insert) + + # Update alive instance + mock_shared_config.update_resource_slots.reset_mock() + mock_dbconn.execute.reset_mock() + mock_dbresult.first = MagicMock(return_value={ + 'status': AgentStatus.ALIVE, + 'addr': '10.0.0.5', + 'architecture': DEFAULT_IMAGE_ARCH, + 'scaling_group': 'sg-testing', + 'available_slots': ResourceSlot({'cpu': _1, 'mem': _1g}), + 'version': '19.12.0', + 'compute_plugins': [], + }) + await registry.handle_heartbeat('i-001', { + 'scaling_group': 'sg-testing', + 'resource_slots': {'cpu': ('count', _1), 'mem': ('bytes', _2g)}, + 'region': 'ap-northeast-2', + 'addr': '10.0.0.6', + 'architecture': DEFAULT_IMAGE_ARCH, + 'version': '19.12.0', + 'compute_plugins': [], + 'images': image_data, + }) + mock_shared_config.update_resource_slots.assert_awaited_once() + q = mock_dbconn.execute.await_args_list[1].args[0] + assert isinstance(q, Update) + q_params = q.compile().params + assert q_params['addr'] == '10.0.0.6' + assert q_params['available_slots'] == ResourceSlot({'cpu': _1, 'mem': _2g}) + assert 'scaling_group' not in q_params + + # Rejoin + mock_shared_config.update_resource_slots.reset_mock() + mock_dbconn.execute.reset_mock() + mock_dbresult.first = MagicMock(return_value={ + 'status': AgentStatus.LOST, + 'addr': '10.0.0.5', + 'architecture': DEFAULT_IMAGE_ARCH, + 'scaling_group': 'sg-testing', + 'available_slots': ResourceSlot({'cpu': _1, 'mem': _1g}), + 'version': '19.12.0', + 'compute_plugins': [], + }) + await registry.handle_heartbeat('i-001', { + 'scaling_group': 'sg-testing2', + 'resource_slots': {'cpu': ('count', _4), 'mem': ('bytes', _2g)}, + 'region': 'ap-northeast-2', + 'addr': '10.0.0.6', + 'architecture': DEFAULT_IMAGE_ARCH, + 'version': '19.12.0', + 'compute_plugins': [], + 'images': image_data, + }) + mock_shared_config.update_resource_slots.assert_awaited_once() + q = mock_dbconn.execute.await_args_list[1].args[0] + assert isinstance(q, Update) + q_params = q.compile().params + assert q_params['status'] == AgentStatus.ALIVE + assert q_params['addr'] == '10.0.0.6' + assert "lost_at=NULL" in str(q) # stringified and removed from bind params + assert q_params['available_slots'] == ResourceSlot({'cpu': _4, 'mem': _2g}) + assert q_params['scaling_group'] == 'sg-testing2' + assert 'compute_plugins' in q_params + assert 'version' in q_params + + +@pytest.mark.asyncio +async def test_convert_resource_spec_to_resource_slot( + registry_ctx: tuple[AgentRegistry, MagicMock, MagicMock, MagicMock, MagicMock, MagicMock], +): + registry, _, _, _, _, _ = registry_ctx + allocations = { + 'cuda': { + SlotName('cuda.shares'): { + DeviceId('a0'): '2.5', + DeviceId('a1'): '2.0', + }, + }, + } + converted_allocations = registry.convert_resource_spec_to_resource_slot(allocations) + assert converted_allocations['cuda.shares'] == '4.5' + allocations = { + 'cpu': { + SlotName('cpu'): { + DeviceId('a0'): '3', 
+ DeviceId('a1'): '1', + }, + }, + 'ram': { + SlotName('ram'): { + DeviceId('b0'): '2.5g', + DeviceId('b1'): '512m', + }, + }, + } + converted_allocations = registry.convert_resource_spec_to_resource_slot(allocations) + assert converted_allocations['cpu'] == '4' + assert converted_allocations['ram'] == str(Decimal(BinarySize.from_str('1g')) * 3) diff --git a/tests/manager/test_scheduler.py b/tests/manager/test_scheduler.py new file mode 100644 index 0000000000..7823ee06f2 --- /dev/null +++ b/tests/manager/test_scheduler.py @@ -0,0 +1,1110 @@ +from __future__ import annotations + +import secrets +from datetime import datetime, timedelta +from decimal import Decimal +from typing import ( + Any, + Mapping, + Sequence, +) +from unittest import mock +from unittest.mock import AsyncMock, MagicMock +from uuid import uuid4, UUID +from pprint import pprint + +import attr +import pytest +import trafaret as t +from dateutil.parser import parse as dtparse + +from ai.backend.common.docker import ImageRef +from ai.backend.common.types import ( + AccessKey, AgentId, KernelId, SessionId, + ResourceSlot, SessionTypes, + ClusterMode, +) +from ai.backend.manager.defs import DEFAULT_IMAGE_ARCH, DEFAULT_ROLE +from ai.backend.manager.models.scaling_group import ScalingGroupOpts +from ai.backend.manager.scheduler.types import ( + KernelInfo, + PendingSession, + ExistingSession, + AgentContext, +) +from ai.backend.manager.registry import AgentRegistry +from ai.backend.manager.scheduler.dispatcher import ( + load_scheduler, + SchedulerDispatcher, + _list_pending_sessions, +) +from ai.backend.manager.scheduler.fifo import FIFOSlotScheduler, LIFOSlotScheduler +from ai.backend.manager.scheduler.drf import DRFScheduler +from ai.backend.manager.scheduler.mof import MOFScheduler +from ai.backend.manager.scheduler.predicates import check_reserved_batch_session + + +def test_load_intrinsic(): + default_sgroup_opts = ScalingGroupOpts() + assert isinstance(load_scheduler('fifo', default_sgroup_opts, {}), FIFOSlotScheduler) + assert isinstance(load_scheduler('lifo', default_sgroup_opts, {}), LIFOSlotScheduler) + assert isinstance(load_scheduler('drf', default_sgroup_opts, {}), DRFScheduler) + assert isinstance(load_scheduler('mof', default_sgroup_opts, {}), MOFScheduler) + + +def test_scheduler_configs(): + example_sgroup_opts = ScalingGroupOpts( # already processed by column trafaret + allowed_session_types=[SessionTypes.BATCH], + pending_timeout=timedelta(seconds=86400 * 2), + config={ + 'extra_config': None, + 'num_retries_to_skip': 5, + }, + ) + scheduler = load_scheduler('fifo', example_sgroup_opts, example_sgroup_opts.config) + assert isinstance(scheduler, FIFOSlotScheduler) + assert scheduler.config == { + 'extra_config': None, + 'num_retries_to_skip': 5, + } + with pytest.raises(t.DataError): + example_sgroup_opts.config['num_retries_to_skip'] = -1 # invalid value + scheduler = load_scheduler('fifo', example_sgroup_opts, example_sgroup_opts.config) + + +example_group_id = uuid4() + +example_total_capacity = ResourceSlot({'cpu': '4.0', 'mem': '4096'}) + + +@pytest.fixture +def example_agents(): + return [ + AgentContext( + agent_id=AgentId('i-001'), + agent_addr='10.0.1.1:6001', + architecture=DEFAULT_IMAGE_ARCH, + scaling_group='sg01', + available_slots=ResourceSlot({ + 'cpu': Decimal('4.0'), + 'mem': Decimal('4096'), + 'cuda.shares': Decimal('4.0'), + 'rocm.devices': Decimal('2'), + }), + occupied_slots=ResourceSlot({ + 'cpu': Decimal('0'), + 'mem': Decimal('0'), + 'cuda.shares': Decimal('0'), + 'rocm.devices': 
Decimal('0'), + }), + ), + AgentContext( + agent_id=AgentId('i-101'), + agent_addr='10.0.2.1:6001', + architecture=DEFAULT_IMAGE_ARCH, + scaling_group='sg02', + available_slots=ResourceSlot({ + 'cpu': Decimal('3.0'), + 'mem': Decimal('2560'), + 'cuda.shares': Decimal('1.0'), + 'rocm.devices': Decimal('8'), + }), + occupied_slots=ResourceSlot({ + 'cpu': Decimal('0'), + 'mem': Decimal('0'), + 'cuda.shares': Decimal('0'), + 'rocm.devices': Decimal('0'), + }), + ), + ] + + +@pytest.fixture +def example_mixed_agents(): + return [ + AgentContext( + agent_id=AgentId('i-gpu'), + agent_addr='10.0.1.1:6001', + architecture=DEFAULT_IMAGE_ARCH, + scaling_group='sg01', + available_slots=ResourceSlot({ + 'cpu': Decimal('4.0'), + 'mem': Decimal('4096'), + 'cuda.shares': Decimal('4.0'), + }), + occupied_slots=ResourceSlot({ + 'cpu': Decimal('0'), + 'mem': Decimal('0'), + 'cuda.shares': Decimal('0'), + }), + ), + AgentContext( + agent_id=AgentId('i-cpu'), + agent_addr='10.0.2.1:6001', + architecture=DEFAULT_IMAGE_ARCH, + scaling_group='sg02', + available_slots=ResourceSlot({ + 'cpu': Decimal('3.0'), + 'mem': Decimal('2560'), + 'cuda.shares': Decimal('0'), + }), + occupied_slots=ResourceSlot({ + 'cpu': Decimal('0'), + 'mem': Decimal('0'), + 'cuda.shares': Decimal('0'), + }), + ), + ] + + +@pytest.fixture +def example_agents_first_one_assigned(): + return [ + AgentContext( + agent_id=AgentId('i-001'), + agent_addr='10.0.1.1:6001', + architecture=DEFAULT_IMAGE_ARCH, + scaling_group='sg01', + available_slots=ResourceSlot({ + 'cpu': Decimal('2.0'), + 'mem': Decimal('2048'), + 'cuda.shares': Decimal('2.0'), + 'rocm.devices': Decimal('1'), + }), + occupied_slots=ResourceSlot({ + 'cpu': Decimal('2.0'), + 'mem': Decimal('2048'), + 'cuda.shares': Decimal('2.0'), + 'rocm.devices': Decimal('1'), + }), + ), + AgentContext( + agent_id=AgentId('i-101'), + agent_addr='10.0.2.1:6001', + architecture=DEFAULT_IMAGE_ARCH, + scaling_group='sg02', + available_slots=ResourceSlot({ + 'cpu': Decimal('3.0'), + 'mem': Decimal('2560'), + 'cuda.shares': Decimal('1.0'), + 'rocm.devices': Decimal('8'), + }), + occupied_slots=ResourceSlot({ + 'cpu': Decimal('0'), + 'mem': Decimal('0'), + 'cuda.shares': Decimal('0'), + 'rocm.devices': Decimal('0'), + }), + ), + ] + + +@pytest.fixture +def example_agents_no_valid(): + return [ + AgentContext( + agent_id=AgentId('i-001'), + agent_addr='10.0.1.1:6001', + architecture=DEFAULT_IMAGE_ARCH, + scaling_group='sg01', + available_slots=ResourceSlot({ + 'cpu': Decimal('0'), + 'mem': Decimal('0'), + 'cuda.shares': Decimal('0'), + 'rocm.devices': Decimal('0'), + }), + occupied_slots=ResourceSlot({ + 'cpu': Decimal('4.0'), + 'mem': Decimal('4096'), + 'cuda.shares': Decimal('4.0'), + 'rocm.devices': Decimal('2'), + }), + ), + AgentContext( + agent_id=AgentId('i-101'), + agent_addr='10.0.2.1:6001', + architecture=DEFAULT_IMAGE_ARCH, + scaling_group='sg02', + available_slots=ResourceSlot({ + 'cpu': Decimal('0'), + 'mem': Decimal('0'), + 'cuda.shares': Decimal('0'), + 'rocm.devices': Decimal('0'), + }), + occupied_slots=ResourceSlot({ + 'cpu': Decimal('3.0'), + 'mem': Decimal('2560'), + 'cuda.shares': Decimal('1.0'), + 'rocm.devices': Decimal('8'), + }), + ), + ] + + +@attr.s(auto_attribs=True, slots=True) +class SessionKernelIdPair: + session_id: UUID + kernel_ids: Sequence[KernelId] + + +pending_session_kernel_ids = [ + SessionKernelIdPair( + session_id=UUID('251907d9-1290-4126-bc6c-000000000100'), + kernel_ids=[KernelId(UUID('251907d9-1290-4126-bc6c-000000000100'))]), + SessionKernelIdPair( + 
session_id=UUID('251907d9-1290-4126-bc6c-000000000200'), + kernel_ids=[KernelId(UUID('251907d9-1290-4126-bc6c-000000000200'))]), + SessionKernelIdPair( + # single-node mode multi-container session + session_id=UUID('251907d9-1290-4126-bc6c-000000000300'), + kernel_ids=[ + KernelId(UUID('251907d9-1290-4126-bc6c-000000000300')), + KernelId(UUID('251907d9-1290-4126-bc6c-000000000301')), + KernelId(UUID('251907d9-1290-4126-bc6c-000000000302')), + ]), + SessionKernelIdPair( + session_id=UUID('251907d9-1290-4126-bc6c-000000000400'), + kernel_ids=[KernelId(UUID('251907d9-1290-4126-bc6c-000000000400'))]), +] + +existing_session_kernel_ids = [ + SessionKernelIdPair( + session_id=UUID('251907d9-1290-4126-bc6c-100000000100'), + kernel_ids=[ + KernelId(UUID('251907d9-1290-4126-bc6c-100000000100')), + KernelId(UUID('251907d9-1290-4126-bc6c-100000000101')), + ]), + SessionKernelIdPair( + session_id=UUID('251907d9-1290-4126-bc6c-100000000200'), + kernel_ids=[KernelId(UUID('251907d9-1290-4126-bc6c-100000000200'))]), + SessionKernelIdPair( + # single-node mode multi-container session + session_id=UUID('251907d9-1290-4126-bc6c-100000000300'), + kernel_ids=[KernelId(UUID('251907d9-1290-4126-bc6c-100000000300'))]), +] + +common_image_ref = ImageRef('lablup/python:3.6-ubunt18.04') + +_common_dummy_for_pending_session: Mapping[str, Any] = dict( + domain_name='default', + group_id=example_group_id, + resource_policy={}, + resource_opts={}, + vfolder_mounts=[], + environ={}, + bootstrap_script=None, + startup_command=None, + internal_data=None, + preopen_ports=[], +) + +_common_dummy_for_existing_session: Mapping[str, Any] = dict( + domain_name='default', + group_id=example_group_id, +) + + +@pytest.fixture +def example_pending_sessions(): + # lower indicies are enqueued first. 
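+    # It contains a ROCm session, a CUDA session, and a CPU-only single-node
+    # multi-container session (cluster_size=3), enqueued in that order.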
+ return [ + PendingSession( # rocm + kernels=[ + KernelInfo( + kernel_id=pending_session_kernel_ids[0].kernel_ids[0], + session_id=pending_session_kernel_ids[0].session_id, + access_key='dummy-access-key', + agent_id=None, + agent_addr=None, + cluster_role=DEFAULT_ROLE, + cluster_idx=1, + cluster_hostname=f"{DEFAULT_ROLE}0", + image_ref=common_image_ref, + resource_opts={}, + requested_slots=ResourceSlot({ + 'cpu': Decimal('2.0'), + 'mem': Decimal('1024'), + 'cuda.shares': Decimal('0'), + 'rocm.devices': Decimal('1'), + }), + bootstrap_script=None, + startup_command=None, + created_at=dtparse('2021-12-28T23:59:59+00:00'), + ), + ], + access_key=AccessKey('user01'), + agent_id=None, + agent_addr=None, + status_data={}, + session_id=pending_session_kernel_ids[0].session_id, + session_creation_id='aaa100', + session_name='es01', + session_type=SessionTypes.BATCH, + cluster_mode='single-node', + cluster_size=1, + scaling_group='sg01', + requested_slots=ResourceSlot({ + 'cpu': Decimal('2.0'), + 'mem': Decimal('1024'), + 'cuda.shares': Decimal('0'), + 'rocm.devices': Decimal('1'), + }), + target_sgroup_names=[], + **_common_dummy_for_pending_session, + created_at=dtparse('2021-12-28T23:59:59+00:00'), + ), + PendingSession( # cuda + kernels=[ + KernelInfo( + kernel_id=pending_session_kernel_ids[1].kernel_ids[0], + session_id=pending_session_kernel_ids[1].session_id, + access_key='dummy-access-key', + agent_id=None, + agent_addr=None, + cluster_role=DEFAULT_ROLE, + cluster_idx=1, + cluster_hostname=f"{DEFAULT_ROLE}0", + image_ref=common_image_ref, + resource_opts={}, + requested_slots=ResourceSlot({ + 'cpu': Decimal('1.0'), + 'mem': Decimal('2048'), + 'cuda.shares': Decimal('0.5'), + 'rocm.devices': Decimal('0'), + }), + bootstrap_script=None, + startup_command=None, + created_at=dtparse('2022-02-01T23:59:59+00:00'), + ), + ], + access_key=AccessKey('user02'), + agent_id=None, + agent_addr=None, + status_data={}, + session_id=pending_session_kernel_ids[1].session_id, + session_creation_id='aaa101', + session_name='es01', + session_type=SessionTypes.BATCH, + cluster_mode='single-node', + cluster_size=1, + scaling_group='sg01', + requested_slots=ResourceSlot({ + 'cpu': Decimal('1.0'), + 'mem': Decimal('2048'), + 'cuda.shares': Decimal('0.5'), + 'rocm.devices': Decimal('0'), + }), + target_sgroup_names=[], + **_common_dummy_for_pending_session, + created_at=dtparse('2022-02-01T23:59:59+00:00'), + ), + PendingSession( # cpu-only + kernels=[ + KernelInfo( + kernel_id=pending_session_kernel_ids[2].kernel_ids[0], + session_id=pending_session_kernel_ids[2].session_id, + access_key='dummy-access-key', + agent_id=None, + agent_addr=None, + cluster_role=DEFAULT_ROLE, + cluster_idx=1, + cluster_hostname=f"{DEFAULT_ROLE}0", + image_ref=common_image_ref, + resource_opts={}, + requested_slots=ResourceSlot({ + 'cpu': Decimal('0.4'), + 'mem': Decimal('512'), + 'cuda.shares': Decimal('0'), + 'rocm.devices': Decimal('0'), + }), + bootstrap_script=None, + startup_command=None, + created_at=dtparse('2021-12-01T23:59:59+00:00'), + ), + KernelInfo( + kernel_id=pending_session_kernel_ids[2].kernel_ids[1], + session_id=pending_session_kernel_ids[2].session_id, + access_key='dummy-access-key', + agent_id=None, + agent_addr=None, + cluster_role='sub', + cluster_idx=2, + cluster_hostname="sub1", + image_ref=common_image_ref, + resource_opts={}, + requested_slots=ResourceSlot({ + 'cpu': Decimal('0.3'), + 'mem': Decimal('256'), + 'cuda.shares': Decimal('0'), + 'rocm.devices': Decimal('0'), + }), + bootstrap_script=None, + 
startup_command=None, + created_at=dtparse('2021-12-01T23:59:59+00:00'), + ), + KernelInfo( + kernel_id=pending_session_kernel_ids[2].kernel_ids[2], + session_id=pending_session_kernel_ids[2].session_id, + access_key='dummy-access-key', + agent_id=None, + agent_addr=None, + cluster_role='sub', + cluster_idx=3, + cluster_hostname="sub2", + image_ref=common_image_ref, + resource_opts={}, + requested_slots=ResourceSlot({ + 'cpu': Decimal('0.3'), + 'mem': Decimal('256'), + 'cuda.shares': Decimal('0'), + 'rocm.devices': Decimal('0'), + }), + bootstrap_script=None, + startup_command=None, + created_at=dtparse('2021-12-01T23:59:59+00:00'), + ), + ], + access_key=AccessKey('user03'), + agent_id=None, + agent_addr=None, + status_data={}, + session_id=pending_session_kernel_ids[2].session_id, + session_creation_id='aaa102', + session_name='es01', + session_type=SessionTypes.BATCH, + cluster_mode='single-node', + cluster_size=3, + scaling_group='sg01', + requested_slots=ResourceSlot({ + 'cpu': Decimal('1.0'), + 'mem': Decimal('1024'), + 'cuda.shares': Decimal('0'), + 'rocm.devices': Decimal('0'), + }), + target_sgroup_names=[], + **_common_dummy_for_pending_session, + created_at=dtparse('2021-12-01T23:59:59+00:00'), + ), + ] + + +@pytest.fixture +def example_existing_sessions(): + return [ + ExistingSession( + kernels=[ + KernelInfo( + kernel_id=existing_session_kernel_ids[0].kernel_ids[0], + session_id=existing_session_kernel_ids[0].session_id, + access_key='dummy-access-key', + agent_id=None, + agent_addr=None, + cluster_role=DEFAULT_ROLE, + cluster_idx=1, + cluster_hostname=f"{DEFAULT_ROLE}0", + image_ref=common_image_ref, + resource_opts={}, + requested_slots=ResourceSlot({ + 'cpu': Decimal('1.0'), + 'mem': Decimal('512'), + 'cuda.shares': Decimal('0'), + 'rocm.devices': Decimal('0'), + }), + bootstrap_script=None, + startup_command=None, + created_at=dtparse('2022-02-05T00:00:00+00:00'), + ), + KernelInfo( + kernel_id=existing_session_kernel_ids[0].kernel_ids[1], + session_id=existing_session_kernel_ids[0].session_id, + access_key='dummy-access-key', + agent_id=None, + agent_addr=None, + cluster_role='sub', + cluster_idx=2, + cluster_hostname="sub1", + image_ref=common_image_ref, + resource_opts={}, + requested_slots=ResourceSlot({ + 'cpu': Decimal('2.0'), + 'mem': Decimal('512'), + 'cuda.shares': Decimal('0'), + 'rocm.devices': Decimal('1'), + }), + bootstrap_script=None, + startup_command=None, + created_at=dtparse('2022-02-05T00:00:00+00:00'), + ), + ], + access_key=AccessKey('user01'), + session_id=existing_session_kernel_ids[0].session_id, + session_name='es01', + session_type=SessionTypes.BATCH, + cluster_mode='single-node', + cluster_size=2, + occupying_slots=ResourceSlot({ + 'cpu': Decimal('3.0'), + 'mem': Decimal('1024'), + 'cuda.shares': Decimal('0'), + 'rocm.devices': Decimal('1'), + }), + scaling_group='sg01', + **_common_dummy_for_existing_session, + ), + ExistingSession( + kernels=[ + KernelInfo( + kernel_id=existing_session_kernel_ids[1].kernel_ids[0], + session_id=existing_session_kernel_ids[1].session_id, + access_key='dummy-access-key', + agent_id=None, + agent_addr=None, + cluster_role=DEFAULT_ROLE, + cluster_idx=1, + cluster_hostname=f"{DEFAULT_ROLE}0", + image_ref=common_image_ref, + resource_opts={}, + requested_slots=ResourceSlot({ + 'cpu': Decimal('1.0'), + 'mem': Decimal('2048'), + 'cuda.shares': Decimal('0.5'), + 'rocm.devices': Decimal('0'), + }), + bootstrap_script=None, + startup_command=None, + created_at=dtparse('2021-09-03T00:00:00+00:00'), + ), + ], + 
access_key=AccessKey('user02'), + session_id=existing_session_kernel_ids[1].session_id, + session_type=SessionTypes.BATCH, + session_name='es01', + cluster_mode='single-node', + cluster_size=1, + occupying_slots=ResourceSlot({ + 'cpu': Decimal('1.0'), + 'mem': Decimal('2048'), + 'cuda.shares': Decimal('0.5'), + 'rocm.devices': Decimal('0'), + }), + scaling_group='sg01', + **_common_dummy_for_existing_session, + ), + ExistingSession( + kernels=[ + KernelInfo( + kernel_id=existing_session_kernel_ids[2].kernel_ids[0], + session_id=existing_session_kernel_ids[2].session_id, + access_key='dummy-access-key', + agent_id=None, + agent_addr=None, + cluster_role=DEFAULT_ROLE, + cluster_idx=1, + cluster_hostname=f"{DEFAULT_ROLE}0", + image_ref=common_image_ref, + resource_opts={}, + requested_slots=ResourceSlot({ + 'cpu': Decimal('4.0'), + 'mem': Decimal('4096'), + 'cuda.shares': Decimal('0'), + 'rocm.devices': Decimal('0'), + }), + bootstrap_script=None, + startup_command=None, + created_at=dtparse('2022-01-15T00:00:00+00:00'), + ), + ], + access_key=AccessKey('user03'), + session_id=existing_session_kernel_ids[2].session_id, + session_type=SessionTypes.BATCH, + session_name='es01', + cluster_mode='single-node', + cluster_size=1, + occupying_slots=ResourceSlot({ + 'cpu': Decimal('4.0'), + 'mem': Decimal('4096'), + 'cuda.shares': Decimal('0'), + 'rocm.devices': Decimal('0'), + }), + scaling_group='sg01', + **_common_dummy_for_existing_session, + ), + ] + + +def _find_and_pop_picked_session(pending_sessions, picked_session_id): + for picked_idx, pending_sess in enumerate(pending_sessions): + if pending_sess.session_id == picked_session_id: + break + else: + # no matching entry for picked session? + raise RuntimeError('should not reach here') + return pending_sessions.pop(picked_idx) + + +def test_fifo_scheduler(example_agents, example_pending_sessions, example_existing_sessions): + scheduler = FIFOSlotScheduler(ScalingGroupOpts(), {}) + picked_session_id = scheduler.pick_session( + example_total_capacity, + example_pending_sessions, + example_existing_sessions, + ) + assert picked_session_id == example_pending_sessions[0].session_id + picked_session = _find_and_pop_picked_session( + example_pending_sessions, + picked_session_id, + ) + agent_id = scheduler.assign_agent_for_session( + example_agents, + picked_session, + ) + assert agent_id == AgentId('i-001') + + +def test_lifo_scheduler(example_agents, example_pending_sessions, example_existing_sessions): + scheduler = LIFOSlotScheduler(ScalingGroupOpts(), {}) + picked_session_id = scheduler.pick_session( + example_total_capacity, + example_pending_sessions, + example_existing_sessions, + ) + assert picked_session_id == example_pending_sessions[2].session_id + picked_session = _find_and_pop_picked_session( + example_pending_sessions, + picked_session_id, + ) + agent_id = scheduler.assign_agent_for_session(example_agents, picked_session) + assert agent_id == 'i-001' + + +def test_fifo_scheduler_favor_cpu_for_requests_without_accelerators( + example_mixed_agents, + example_pending_sessions, +): + scheduler = FIFOSlotScheduler(ScalingGroupOpts(), {}) + for idx in range(3): + picked_session_id = scheduler.pick_session( + example_total_capacity, + example_pending_sessions, + [], + ) + assert picked_session_id == example_pending_sessions[0].session_id + picked_session = _find_and_pop_picked_session( + example_pending_sessions, + picked_session_id, + ) + agent_id = scheduler.assign_agent_for_session(example_mixed_agents, picked_session) + if idx == 0: + # 
example_mixed_agents do not have any agent with ROCM accelerators. + assert agent_id is None + elif idx == 1: + assert agent_id == AgentId('i-gpu') + elif idx == 2: + # It should favor the CPU-only agent if the requested slots + # do not include accelerators. + assert agent_id == AgentId('i-cpu') + + +def gen_pending_for_holb_tests(session_id: str, status_data: Mapping[str, Any]) -> PendingSession: + return PendingSession( + session_id=SessionId(session_id), # type: ignore + session_name=secrets.token_hex(8), + access_key=AccessKey('ak1'), + agent_id=AgentId('i-001'), + agent_addr='10.0.1.1:6001', + status_data=status_data, + session_creation_id=secrets.token_urlsafe(8), + kernels=[], + session_type=SessionTypes.INTERACTIVE, + cluster_mode=ClusterMode.SINGLE_NODE, + cluster_size=1, + scaling_group='sg01', + requested_slots=ResourceSlot({'cpu': Decimal(1), 'mem': Decimal(1024)}), + target_sgroup_names=[], + **_common_dummy_for_pending_session, + created_at=dtparse('2020-03-21T00:00:00+00:00'), + ) + + +def test_fifo_scheduler_hol_blocking_avoidance_empty_status_data(): + """ + Without any status_data, it should just pick the first session. + """ + scheduler = FIFOSlotScheduler(ScalingGroupOpts(), {'num_retries_to_skip': 5}) + pending_sessions = [ + gen_pending_for_holb_tests("s0", {}), + gen_pending_for_holb_tests("s1", {}), + gen_pending_for_holb_tests("s2", {}), + ] + picked_session_id = scheduler.pick_session( + example_total_capacity, + pending_sessions, + []) + assert picked_session_id == 's0' + + +def test_fifo_scheduler_hol_blocking_avoidance_config(): + """ + If the upfront sessions have enough number of retries, + it should skip them. + """ + scheduler = FIFOSlotScheduler(ScalingGroupOpts(), {'num_retries_to_skip': 0}) + pending_sessions = [ + gen_pending_for_holb_tests("s0", {'scheduler': {'retries': 5}}), + gen_pending_for_holb_tests("s1", {}), + gen_pending_for_holb_tests("s2", {}), + ] + picked_session_id = scheduler.pick_session( + example_total_capacity, + pending_sessions, + []) + assert picked_session_id == 's0' + + scheduler = FIFOSlotScheduler(ScalingGroupOpts(), {'num_retries_to_skip': 5}) + pending_sessions = [ + gen_pending_for_holb_tests("s0", {'scheduler': {'retries': 5}}), + gen_pending_for_holb_tests("s1", {'scheduler': {'retries': 4}}), + gen_pending_for_holb_tests("s2", {'scheduler': {'retries': 3}}), + ] + picked_session_id = scheduler.pick_session( + example_total_capacity, + pending_sessions, + []) + assert picked_session_id == 's1' + + +def test_fifo_scheduler_hol_blocking_avoidance_skips(): + """ + If the upfront sessions have enough number of retries, + it should skip them. 
+ """ + scheduler = FIFOSlotScheduler(ScalingGroupOpts(), {'num_retries_to_skip': 5}) + pending_sessions = [ + gen_pending_for_holb_tests("s0", {'scheduler': {'retries': 5}}), + gen_pending_for_holb_tests("s1", {}), + gen_pending_for_holb_tests("s2", {}), + ] + picked_session_id = scheduler.pick_session( + example_total_capacity, + pending_sessions, + []) + assert picked_session_id == 's1' + + pending_sessions = [ + gen_pending_for_holb_tests("s0", {'scheduler': {'retries': 5}}), + gen_pending_for_holb_tests("s1", {'scheduler': {'retries': 10}}), + gen_pending_for_holb_tests("s2", {}), + ] + picked_session_id = scheduler.pick_session( + example_total_capacity, + pending_sessions, + []) + assert picked_session_id == 's2' + + +def test_fifo_scheduler_hol_blocking_avoidance_all_skipped(): + """ + If all sessions are skipped due to excessive number of retries, + then we go back to the normal FIFO by choosing the first of them. + """ + scheduler = FIFOSlotScheduler(ScalingGroupOpts(), {'num_retries_to_skip': 5}) + pending_sessions = [ + gen_pending_for_holb_tests("s0", {'scheduler': {'retries': 5}}), + gen_pending_for_holb_tests("s1", {'scheduler': {'retries': 5}}), + gen_pending_for_holb_tests("s2", {'scheduler': {'retries': 5}}), + ] + picked_session_id = scheduler.pick_session( + example_total_capacity, + pending_sessions, + []) + assert picked_session_id == 's0' + + +def test_fifo_scheduler_hol_blocking_avoidance_no_skip(): + """ + If non-first sessions have to be skipped, the scheduler should still + choose the first session. + """ + scheduler = FIFOSlotScheduler(ScalingGroupOpts(), {'num_retries_to_skip': 5}) + pending_sessions = [ + gen_pending_for_holb_tests("s0", {}), + gen_pending_for_holb_tests("s1", {'scheduler': {'retries': 10}}), + gen_pending_for_holb_tests("s2", {}), + ] + picked_session_id = scheduler.pick_session( + example_total_capacity, + pending_sessions, + []) + assert picked_session_id == 's0' + + +def test_lifo_scheduler_favor_cpu_for_requests_without_accelerators( + example_mixed_agents, + example_pending_sessions, +): + # Check the reverse with the LIFO scheduler. + # The result must be same. + scheduler = LIFOSlotScheduler(ScalingGroupOpts(), {}) + for idx in range(3): + picked_session_id = scheduler.pick_session( + example_total_capacity, + example_pending_sessions, + []) + assert picked_session_id == example_pending_sessions[-1].session_id + picked_session = _find_and_pop_picked_session( + example_pending_sessions, picked_session_id) + agent_id = scheduler.assign_agent_for_session(example_mixed_agents, picked_session) + if idx == 2: + # example_mixed_agents do not have any agent with ROCM accelerators. + assert agent_id is None + elif idx == 1: + assert agent_id == AgentId('i-gpu') + elif idx == 0: + # It should favor the CPU-only agent if the requested slots + # do not include accelerators. 
+ assert agent_id == AgentId('i-cpu') + + +def test_drf_scheduler( + example_agents, + example_pending_sessions, + example_existing_sessions, +): + scheduler = DRFScheduler(ScalingGroupOpts(), {}) + picked_session_id = scheduler.pick_session( + example_total_capacity, + example_pending_sessions, + example_existing_sessions, + ) + pprint(example_pending_sessions) + assert picked_session_id == example_pending_sessions[1].session_id + picked_session = _find_and_pop_picked_session( + example_pending_sessions, + picked_session_id, + ) + agent_id = scheduler.assign_agent_for_session(example_agents, picked_session) + assert agent_id == 'i-001' + + +def test_mof_scheduler_first_assign( + example_agents, + example_pending_sessions, + example_existing_sessions, +): + scheduler = MOFScheduler(ScalingGroupOpts(), {}) + picked_session_id = scheduler.pick_session( + example_total_capacity, + example_pending_sessions, + example_existing_sessions) + assert picked_session_id == example_pending_sessions[0].session_id + picked_session = _find_and_pop_picked_session( + example_pending_sessions, picked_session_id) + + agent_id = scheduler.assign_agent_for_session(example_agents, picked_session) + assert agent_id == 'i-001' + + +def test_mof_scheduler_second_assign( + example_agents_first_one_assigned, + example_pending_sessions, + example_existing_sessions, +): + scheduler = MOFScheduler(ScalingGroupOpts(), {}) + picked_session_id = scheduler.pick_session( + example_total_capacity, + example_pending_sessions, + example_existing_sessions) + assert picked_session_id == example_pending_sessions[0].session_id + picked_session = _find_and_pop_picked_session( + example_pending_sessions, picked_session_id) + + agent_id = scheduler.assign_agent_for_session( + example_agents_first_one_assigned, picked_session) + assert agent_id == 'i-101' + + +def test_mof_scheduler_no_valid_agent( + example_agents_no_valid, + example_pending_sessions, + example_existing_sessions, +): + scheduler = MOFScheduler(ScalingGroupOpts(), {}) + picked_session_id = scheduler.pick_session( + example_total_capacity, + example_pending_sessions, + example_existing_sessions) + assert picked_session_id == example_pending_sessions[0].session_id + picked_session = _find_and_pop_picked_session( + example_pending_sessions, picked_session_id) + + agent_id = scheduler.assign_agent_for_session(example_agents_no_valid, picked_session) + assert agent_id is None + + +@pytest.mark.asyncio +async def test_pending_timeout(mocker): + + class MockDatetime: + @classmethod + def now(cls, tzinfo): + return datetime(2021, 1, 1, 0, 0, 0) + + mocker.patch('ai.backend.manager.scheduler.dispatcher.datetime', MockDatetime) + mock_query_result = MagicMock() + mock_query_result.fetchall = MagicMock(return_value=[ + {'id': 'session3', 'created_at': datetime(2020, 12, 31, 23, 59, 59)}, + {'id': 'session2', 'created_at': datetime(2020, 12, 30, 23, 59, 59)}, + {'id': 'session1', 'created_at': datetime(2020, 12, 29, 23, 59, 59)}, + ]) + mock_execute = AsyncMock(return_value=mock_query_result) + mock_dbconn = MagicMock() + mock_dbconn.execute = mock_execute + + scheduler = FIFOSlotScheduler(ScalingGroupOpts(pending_timeout=timedelta(seconds=86400 * 2)), {}) + candidate_session_rows, cancelled_session_rows = await _list_pending_sessions( + mock_dbconn, scheduler, 'default', + ) + assert len(candidate_session_rows) == 2 + assert len(cancelled_session_rows) == 1 + assert cancelled_session_rows[0]['id'] == 'session1' + + scheduler = 
FIFOSlotScheduler(ScalingGroupOpts(pending_timeout=timedelta(seconds=0)), {}) + candidate_session_rows, cancelled_session_rows = await _list_pending_sessions( + mock_dbconn, scheduler, 'default', + ) + assert len(candidate_session_rows) == 3 + assert len(cancelled_session_rows) == 0 + + +class DummyEtcd: + async def get_prefix(self, key: str) -> Mapping[str, Any]: + return {} + + +@pytest.mark.asyncio +async def test_manually_assign_agent_available( + file_lock_factory, + registry_ctx: tuple[AgentRegistry, MagicMock, MagicMock, MagicMock, MagicMock, MagicMock], + example_agents, + example_pending_sessions, +): + mock_local_config = MagicMock() + registry, mock_db, mock_dbresult, mock_shared_config, mock_event_dispatcher, mock_event_producer = \ + registry_ctx + sess_ctx = example_pending_sessions[0] + mock_sched_ctx = MagicMock() + mock_check_result = MagicMock() + scheduler = FIFOSlotScheduler(ScalingGroupOpts(), {}) + sgroup_name = example_agents[0].scaling_group + candidate_agents = example_agents + example_pending_sessions[0].agent_id = 'i-001' + sess_ctx = example_pending_sessions[0] + + dispatcher = SchedulerDispatcher( + local_config=mock_local_config, + shared_config=mock_shared_config, + event_dispatcher=mock_event_dispatcher, + event_producer=mock_event_producer, + lock_factory=file_lock_factory, + registry=registry, + ) + + # manually assigned agent is None + mock_dbresult.scalar = MagicMock(return_value=None) + await dispatcher._schedule_single_node_session( + mock_sched_ctx, + scheduler, + sgroup_name, + candidate_agents, + sess_ctx, + mock_check_result, + ) + result = mock_dbresult.scalar() + assert result is None + + # manually assigned agent is enough capacity + mock_dbresult.scalar = MagicMock(return_value={ + 'cpu': Decimal('8.0'), + 'mem': Decimal('8192'), + 'cuda.shares': Decimal('4'), + 'rocm.devices': Decimal('4'), + }) + await dispatcher._schedule_single_node_session( + mock_sched_ctx, + scheduler, + sgroup_name, + candidate_agents, + sess_ctx, + mock_check_result, + ) + result = mock_dbresult.scalar() + for key in result: + assert result[key] >= example_pending_sessions[0].requested_slots[key] + + # manually assigned agent is not enough capacity. 
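+ # (here the mocked DB result reports zero available capacity for every slot type)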
+ mock_dbresult.scalar = MagicMock(return_value={ + 'cpu': Decimal('0.0'), + 'mem': Decimal('0'), + 'cuda.shares': Decimal('0'), + 'rocm.devices': Decimal('0'), + }) + await dispatcher._schedule_single_node_session( + mock_sched_ctx, + scheduler, + sgroup_name, + candidate_agents, + sess_ctx, + mock_check_result, + ) + result = mock_dbresult.scalar() + for key in result: + assert result[key] <= example_pending_sessions[0].requested_slots[key] + + +@pytest.mark.asyncio +@mock.patch('ai.backend.manager.scheduler.predicates.datetime') +async def test_multiple_timezones_for_reserved_batch_session_predicate(mock_dt): + mock_db_conn = MagicMock() + mock_sched_ctx = MagicMock() + mock_sess_ctx = MagicMock() + mock_sess_ctx.session_type = SessionTypes.BATCH + mock_sess_ctx.kernel_id = 'fake-kernel-id' + + now = '2020-06-29T00:00:00+00:00' + mock_dt.now = MagicMock(return_value=dtparse(now)) + + # Start time is not yet reached (now < start time) + start_time = '2020-06-29T00:00:01+00:00' + mock_db_conn.scalar = AsyncMock(return_value=dtparse(start_time)) + result = await check_reserved_batch_session(mock_db_conn, mock_sched_ctx, mock_sess_ctx) + assert not result.passed, (now, start_time) + + # Start time is reached (now > start time) + start_time = '2020-06-28T23:59:59+00:00' + mock_db_conn.scalar = AsyncMock(return_value=dtparse(start_time)) + result = await check_reserved_batch_session(mock_db_conn, mock_sched_ctx, mock_sess_ctx) + assert result.passed, (now, start_time) + + # Start time is not yet reached by timezone (now < start time) + # Note that 6/29 00:00 (UTC) < 6/29 00:00 (-09:00) == 6/29 09:00 (UTC) + for i in range(1, 12): + start_time = f'2020-06-29T00:00:00-{i:02d}:00' + mock_db_conn.scalar = AsyncMock(return_value=dtparse(start_time)) + result = await check_reserved_batch_session(mock_db_conn, mock_sched_ctx, mock_sess_ctx) + assert not result.passed, (now, start_time) + + # Start time is reached by timezone (now > start time) + # Note that 6/29 00:00 (UTC) > 6/29 00:00 (+09:00) == 6/28 15:00 (UTC) + for i in range(1, 12): + start_time = f'2020-06-29T00:00:00+{i:02d}:00' + mock_db_conn.scalar = AsyncMock(return_value=dtparse(start_time)) + result = await check_reserved_batch_session(mock_db_conn, mock_sched_ctx, mock_sess_ctx) + assert result.passed, (now, start_time) + + # Should pass if start time is not specified (start immediately). 
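+ # (the mocked DB scalar returns None, i.e. no reserved start time is stored)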
+ mock_db_conn.scalar = AsyncMock(return_value=None) + result = await check_reserved_batch_session(mock_db_conn, mock_sched_ctx, mock_sess_ctx) + assert result.passed + + +# TODO: write tests for multiple agents and scaling groups diff --git a/tests/plugin/BUILD b/tests/plugin/BUILD new file mode 100644 index 0000000000..6322b48a81 --- /dev/null +++ b/tests/plugin/BUILD @@ -0,0 +1,10 @@ +# python_test_utils( +# name="test_utils", +# ) + +python_tests( + name="tests", + runtime_package_dependencies=[ + "src/ai/backend/plugin:lib", + ], +) diff --git a/tests/plugin/test_entrypoint.py b/tests/plugin/test_entrypoint.py new file mode 100644 index 0000000000..f3cd7f9881 --- /dev/null +++ b/tests/plugin/test_entrypoint.py @@ -0,0 +1,63 @@ +import tempfile +import textwrap as tw +from pathlib import Path + +from ai.backend.plugin.entrypoint import extract_entrypoints_from_buildscript + + +def test_parse_build(): + with tempfile.NamedTemporaryFile('w') as f: + f.write(tw.dedent(''' + python_sources( + name="lib", + ) + python_distribution( + name="dist", + dependencies=[ + ":service", + ], + provides=python_artifact( + name="backend.ai-manager", + description="Backend.AI Manager", + license="LGPLv3", + ), + entry_points={ + "backendai_cli_v10": { + "mgr": "ai.backend.manager.cli.__main__:main", + "mgr.start-server": "ai.backend.manager.server:main", + }, + "backendai_scheduler_v10": { + "fifo": "ai.backend.manager.scheduler.fifo:FIFOSlotScheduler", + "lifo": "ai.backend.manager.scheduler.fifo:LIFOSlotScheduler", + "drf": "ai.backend.manager.scheduler.drf:DRFScheduler", + "mof": "ai.backend.manager.scheduler.mof:MOFScheduler", + }, + "backendai_error_monitor_v20": { + "intrinsic": "ai.backend.manager.plugin.error_monitor:ErrorMonitor", + }, + }, + generate_setup=True, + ) + python_tests( + name="tests", + ) + ''')) + f.flush() + p = Path(f.name) + items = [*extract_entrypoints_from_buildscript("backendai_cli_v10", p)] + assert (items[0].name, items[0].module, items[0].attr) == \ + ("mgr", "ai.backend.manager.cli.__main__", "main") + assert (items[1].name, items[1].module, items[1].attr) == \ + ("mgr.start-server", "ai.backend.manager.server", "main") + items = [*extract_entrypoints_from_buildscript("backendai_scheduler_v10", p)] + assert (items[0].name, items[0].module, items[0].attr) == \ + ("fifo", "ai.backend.manager.scheduler.fifo", "FIFOSlotScheduler") + assert (items[1].name, items[1].module, items[1].attr) == \ + ("lifo", "ai.backend.manager.scheduler.fifo", "LIFOSlotScheduler") + assert (items[2].name, items[2].module, items[2].attr) == \ + ("drf", "ai.backend.manager.scheduler.drf", "DRFScheduler") + assert (items[3].name, items[3].module, items[3].attr) == \ + ("mof", "ai.backend.manager.scheduler.mof", "MOFScheduler") + items = [*extract_entrypoints_from_buildscript("backendai_error_monitor_v20", p)] + assert (items[0].name, items[0].module, items[0].attr) == \ + ("intrinsic", "ai.backend.manager.plugin.error_monitor", "ErrorMonitor") diff --git a/tests/storage-proxy/BUILD b/tests/storage-proxy/BUILD new file mode 100644 index 0000000000..b4b0f6fa77 --- /dev/null +++ b/tests/storage-proxy/BUILD @@ -0,0 +1,10 @@ +python_test_utils( + name="test_utils", +) + +python_tests( + name="tests", + dependencies=[ + "src/ai/backend/storage:service", + ], +) diff --git a/tests/storage-proxy/conftest.py b/tests/storage-proxy/conftest.py new file mode 100644 index 0000000000..78e95beb48 --- /dev/null +++ b/tests/storage-proxy/conftest.py @@ -0,0 +1,17 @@ +import tempfile +from pathlib import Path + 
+import pytest + + +@pytest.fixture +def vfroot(): + with tempfile.TemporaryDirectory(prefix="bai-storage-test-") as tmpdir: + yield Path(tmpdir) + + +@pytest.fixture +def local_volume(vfroot): + volume = vfroot / "local" + volume.mkdir(parents=True, exist_ok=True) + yield volume diff --git a/tests/storage-proxy/test_netapp.py b/tests/storage-proxy/test_netapp.py new file mode 100644 index 0000000000..c6dac64afb --- /dev/null +++ b/tests/storage-proxy/test_netapp.py @@ -0,0 +1,73 @@ +import uuid +from pathlib import PurePath + +import pytest + +from ai.backend.storage.netapp import NetAppVolume + + +@pytest.fixture +async def netapp_volume(vfroot): + options = { + # TODO: mock options + } + netapp = NetAppVolume( + {}, + vfroot / "netapp", + fsprefix=PurePath("fsprefix"), + options=options, + ) + await netapp.init() + try: + yield netapp + finally: + await netapp.shutdown() + + +@pytest.fixture +async def empty_vfolder(netapp_volume): + vfid = uuid.uuid4() + await netapp_volume.create_vfolder(vfid) + yield vfid + await netapp_volume.delete_vfolder(vfid) + + +def test_dummy(): + # prevent pants error due to when no tests are selected. + pass + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_netapp_get_usage(netapp_volume, empty_vfolder): + vfpath = netapp_volume.mangle_vfpath(empty_vfolder) + (vfpath / "test.txt").write_bytes(b"12345") + (vfpath / "inner").mkdir() + (vfpath / "inner" / "hello.txt").write_bytes(b"678") + (vfpath / "inner" / "world.txt").write_bytes(b"901") + (vfpath / "test2.txt").symlink_to((vfpath / "inner" / "hello.txt")) + (vfpath / "inner2").symlink_to((vfpath / "inner")) + usage = await netapp_volume.get_usage(empty_vfolder) + assert usage.file_count == 6 + assert usage.used_bytes == 92 + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_netapp_clone(netapp_volume): + vfid1 = uuid.uuid4() + vfid2 = uuid.uuid4() + vfpath1 = netapp_volume.mount_path / vfid1.hex[0:2] / vfid1.hex[2:4] / vfid1.hex[4:] + vfpath2 = netapp_volume.mount_path / vfid2.hex[0:2] / vfid2.hex[2:4] / vfid2.hex[4:] + await netapp_volume.create_vfolder(vfid1) + assert vfpath1.is_dir() + (vfpath1 / "test.txt").write_bytes(b"12345") + (vfpath1 / "inner").mkdir() + (vfpath1 / "inner" / "hello.txt").write_bytes(b"678") + await netapp_volume.clone_vfolder(vfid1, netapp_volume, vfid2, None) + assert vfpath2.is_dir() + assert (vfpath2 / "test.txt").is_file() + assert (vfpath2 / "inner").is_dir() + assert (vfpath2 / "inner" / "hello.txt").is_file() + await netapp_volume.delete_vfolder(vfid1) + await netapp_volume.delete_vfolder(vfid2) diff --git a/tests/storage-proxy/test_purestorage.py b/tests/storage-proxy/test_purestorage.py new file mode 100644 index 0000000000..ecc8af4d29 --- /dev/null +++ b/tests/storage-proxy/test_purestorage.py @@ -0,0 +1,89 @@ +import os +import secrets +import shutil +import uuid +from pathlib import Path, PurePath, PurePosixPath + +import pytest + +from ai.backend.storage.purestorage import FlashBladeVolume +from ai.backend.storage.types import DirEntryType + + +@pytest.fixture +def fbroot(): + tmpdir_name = f"bai-storage-test-{secrets.token_urlsafe(12)}" + tmpdir = Path(os.environ["BACKEND_STORAGE_TEST_FBMOUNT"]) / tmpdir_name + tmpdir.mkdir() + try: + yield tmpdir + finally: + shutil.rmtree(tmpdir) + + +@pytest.fixture +async def fb_volume(fbroot): + options = { + # TODO: mock options + } + host = FlashBladeVolume(fbroot, fsprefix=PurePath("fsprefix"), options=options) + await host.init() + try: + yield host + finally: + await host.shutdown() + + 
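+# Provisions a fresh vfolder on the FlashBlade volume and deletes it once the test finishes.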
+@pytest.fixture +async def empty_vfolder(fb_volume): + vfid = uuid.uuid4() + await fb_volume.create_vfolder(vfid) + yield vfid + await fb_volume.delete_vfolder(vfid) + + +def test_dummy(): + # prevent pants error due to when no tests are selected. + pass + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_fb_get_usage(fb_volume, empty_vfolder): + vfpath = fb_volume._mangle_vfpath(empty_vfolder) + (vfpath / "test.txt").write_bytes(b"12345") + (vfpath / "inner").mkdir() + (vfpath / "inner" / "hello.txt").write_bytes(b"678") + (vfpath / "inner" / "world.txt").write_bytes(b"901") + (vfpath / "test2.txt").symlink_to((vfpath / "inner" / "hello.txt")) + (vfpath / "inner2").symlink_to((vfpath / "inner")) + usage = await fb_volume.get_usage(empty_vfolder) + assert usage.file_count == 5 # including symlinks + assert usage.used_bytes == ( + 11 + len(bytes(vfpath / "inner" / "hello.txt")) + len(bytes(vfpath / "inner")) + ) + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_fb_scandir(fb_volume, empty_vfolder): + vfpath = fb_volume._mangle_vfpath(empty_vfolder) + (vfpath / "test1.txt").write_bytes(b"12345") + (vfpath / "inner").mkdir() + (vfpath / "inner" / "hello.txt").write_bytes(b"abc") + (vfpath / "inner" / "world.txt").write_bytes(b"def") + (vfpath / "test2.txt").symlink_to((vfpath / "inner" / "hello.txt")) + (vfpath / "inner2").symlink_to((vfpath / "inner")) + entries = [ + item async for item in fb_volume.scandir(empty_vfolder, PurePosixPath(".")) + ] + assert len(entries) == 4 + entries.sort(key=lambda entry: entry.name) + assert entries[0].name == "inner" + assert entries[0].type == DirEntryType.DIRECTORY + assert entries[1].name == "inner2" + assert entries[1].type == DirEntryType.SYMLINK + assert entries[2].name == "test1.txt" + assert entries[2].type == DirEntryType.FILE + assert entries[3].name == "test2.txt" + assert entries[3].type == DirEntryType.SYMLINK diff --git a/tests/storage-proxy/test_vfs.py b/tests/storage-proxy/test_vfs.py new file mode 100644 index 0000000000..3868031b18 --- /dev/null +++ b/tests/storage-proxy/test_vfs.py @@ -0,0 +1,123 @@ +import uuid +from pathlib import Path, PurePath + +import pytest + +from ai.backend.storage.vfs import BaseVolume + + +@pytest.fixture +async def vfs(local_volume): + vfs = BaseVolume({}, local_volume, fsprefix=PurePath("fsprefix"), options={}) + await vfs.init() + try: + yield vfs + finally: + await vfs.shutdown() + + +@pytest.fixture +async def empty_vfolder(vfs): + vfid = uuid.uuid4() + await vfs.create_vfolder(vfid) + yield vfid + await vfs.delete_vfolder(vfid) + + +@pytest.mark.asyncio +async def test_vfs_vfolder_mgmt(vfs): + vfid = uuid.uuid4() + await vfs.create_vfolder(vfid) + vfpath = vfs.mount_path / vfid.hex[0:2] / vfid.hex[2:4] / vfid.hex[4:] + assert vfpath.is_dir() + await vfs.delete_vfolder(vfid) + assert not vfpath.exists() + assert not vfpath.parent.exists() + assert not vfpath.parent.parent.exists() + + vfid1 = uuid.UUID(hex="82a6ba2b7b8e41deb5ee2c909ce34bcb") + vfid2 = uuid.UUID(hex="82a6ba2b7b8e41deb5ee2c909ce34bcc") + await vfs.create_vfolder(vfid1) + await vfs.create_vfolder(vfid2) + vfpath1 = vfs.mount_path / vfid1.hex[0:2] / vfid1.hex[2:4] / vfid1.hex[4:] + vfpath2 = vfs.mount_path / vfid2.hex[0:2] / vfid2.hex[2:4] / vfid2.hex[4:] + assert vfpath2.relative_to(vfpath1.parent).name == vfpath2.name + assert vfpath1.is_dir() + await vfs.delete_vfolder(vfid1) + assert not vfpath1.exists() + # if the prefix dirs are not empty, they shouldn't be deleted + assert vfpath1.parent.exists() + 
assert vfpath1.parent.parent.exists() + await vfs.delete_vfolder(vfid2) + assert not vfpath2.exists() + # if the prefix dirs become empty, they should be deleted + assert not vfpath2.parent.exists() + assert not vfpath2.parent.parent.exists() + + +@pytest.mark.asyncio +async def test_vfs_get_usage(vfs, empty_vfolder): + vfpath = vfs.mangle_vfpath(empty_vfolder) + (vfpath / "test.txt").write_bytes(b"12345") + (vfpath / "inner").mkdir() + (vfpath / "inner" / "hello.txt").write_bytes(b"678") + (vfpath / "inner" / "world.txt").write_bytes(b"901") + usage = await vfs.get_usage(empty_vfolder) + assert usage.file_count == 3 + assert usage.used_bytes == 11 + + +@pytest.mark.asyncio +async def test_vfs_clone(vfs): + vfid1 = uuid.uuid4() + vfid2 = uuid.uuid4() + vfpath1 = vfs.mount_path / vfid1.hex[0:2] / vfid1.hex[2:4] / vfid1.hex[4:] + vfpath2 = vfs.mount_path / vfid2.hex[0:2] / vfid2.hex[2:4] / vfid2.hex[4:] + await vfs.create_vfolder(vfid1) + assert vfpath1.is_dir() + (vfpath1 / "test.txt").write_bytes(b"12345") + (vfpath1 / "inner").mkdir() + (vfpath1 / "inner" / "hello.txt").write_bytes(b"678") + await vfs.clone_vfolder(vfid1, vfs, vfid2) + assert vfpath2.is_dir() + assert (vfpath2 / "test.txt").is_file() + assert (vfpath2 / "inner").is_dir() + assert (vfpath2 / "inner" / "hello.txt").is_file() + await vfs.delete_vfolder(vfid1) + await vfs.delete_vfolder(vfid2) + + +@pytest.mark.asyncio +async def test_vfs_operation(vfs, empty_vfolder): + vfpath = vfs.mangle_vfpath(empty_vfolder) + (vfpath / "test0").mkdir() + (vfpath / "test0" / "test.txt").write_bytes(b"12345") + with pytest.raises(FileNotFoundError): + await vfs.move_file(empty_vfolder, Path("test0/test.txt"), Path("test1/test.txt")) + (vfpath / "test1").mkdir() + await vfs.move_file(empty_vfolder, Path("test0/test.txt"), Path("test1/test.txt")) + assert (vfpath / "test1" / "test.txt").is_file() + assert (vfpath / "test1" / "test.txt").read_bytes() == b"12345" + assert not (vfpath / "test0" / "test.txt").is_file() + + # rename directory from test1 to test2 + await vfs.move_tree(empty_vfolder, Path("test1"), Path("test2")) + assert (vfpath / "test2").is_dir() + assert (vfpath / "test2" / "test.txt").read_bytes() == b"12345" + + # move directory into another directory that not exists + await vfs.move_tree(empty_vfolder, Path("test2"), Path("test0/inner/test2/test3")) + assert (vfpath / "test0" / "inner").is_dir() + assert (vfpath / "test0" / "inner" / "test2" / "test3").is_dir() + assert ( + vfpath / "test0" / "inner" / "test2" / "test3" / "test.txt" + ).read_bytes() == b"12345" + + # move directory into another directory that already exists + await vfs.move_tree(empty_vfolder, Path("test0/inner/test2/"), Path("test0/")) + assert (vfpath / "test0" / "test2" / "test3").is_dir() + + # do not let move directory to non-relative directory + with pytest.raises(Exception): + await vfs.move_tree(empty_vfolder, Path("test0"), Path("../")) + await vfs.move_tree(empty_vfolder, Path("/"), Path("./")) diff --git a/tests/storage-proxy/test_xfs.py b/tests/storage-proxy/test_xfs.py new file mode 100644 index 0000000000..fdec5743cb --- /dev/null +++ b/tests/storage-proxy/test_xfs.py @@ -0,0 +1,267 @@ +import os +import uuid +from pathlib import Path, PurePath + +import pytest + +from ai.backend.common.types import BinarySize +from ai.backend.storage.vfs import BaseVolume, run +from ai.backend.storage.xfs import XfsVolume + + +def read_etc_projid(): + with open("/etc/projid") as fp: + content = fp.read() + project_id_dict = {} + for line in 
content.splitlines(): + proj_name, proj_id = line.split(":")[:2] + project_id_dict[proj_name] = int(proj_id) + return project_id_dict + + +def read_etc_projects(): + with open("/etc/projects") as fp: + content = fp.read() + vfpath_id_dict = {} + for line in content.splitlines(): + proj_id, vfpath = line.split(":")[:2] + vfpath_id_dict[int(proj_id)] = vfpath + return vfpath_id_dict + + +def create_sample_dir_tree(vfpath: Path) -> int: + (vfpath / "test.txt").write_bytes(b"12345") + (vfpath / "inner").mkdir() + (vfpath / "inner" / "hello.txt").write_bytes(b"678") + return 8 # return number of bytes written + + +def assert_sample_dir_tree(vfpath: Path) -> None: + assert (vfpath / "test.txt").is_file() + assert (vfpath / "test.txt").read_bytes() == b"12345" + assert (vfpath / "inner").is_dir() + assert (vfpath / "inner" / "hello.txt").is_file() + assert (vfpath / "inner" / "hello.txt").read_bytes() == b"678" + + +@pytest.fixture +async def xfs(vfroot): + xfs = XfsVolume({}, vfroot / "xfs") + await xfs.init(os.getuid(), os.getgid()) + try: + yield xfs + finally: + await xfs.shutdown() + + +@pytest.fixture +async def vfs(local_volume): + vfs = BaseVolume({}, local_volume, fsprefix=PurePath("fsprefix"), options={}) + await vfs.init() + try: + yield vfs + finally: + await vfs.shutdown() + + +@pytest.fixture +async def empty_vfolder(xfs): + vfid = uuid.uuid4() + await xfs.create_vfolder(vfid, options={"quota": BinarySize.from_str("10m")}) + yield vfid + await xfs.delete_vfolder(vfid) + + +def test_dummy(): + # prevent pants error due to when no tests are selected. + pass + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_xfs_single_vfolder_mgmt(xfs): + vfid = uuid.uuid4() + options = {"quota": BinarySize.from_str("10m")} + # vfolder create test + await xfs.create_vfolder(vfid, options=options) + vfpath = xfs.mount_path / vfid.hex[0:2] / vfid.hex[2:4] / vfid.hex[4:] + project_id_dict = read_etc_projid() + vfpath_id_dict = read_etc_projects() + assert vfpath.is_dir() + assert str(vfid) in project_id_dict + vfid_project_id = project_id_dict[str(vfid)] + # vfolder delete test + assert vfpath_id_dict[project_id_dict[str(vfid)]] == str(vfpath) + await xfs.delete_vfolder(vfid) + assert not vfpath.exists() + assert not vfpath.parent.exists() or not (vfpath.parent / vfid.hex[2:4]).exists() + assert ( + not vfpath.parent.parent.exists() + or not (vfpath.parent.parent / vfid.hex[0:2]).exists() + ) + project_id_dict = read_etc_projid() + vfpath_id_dict = read_etc_projects() + assert str(vfid) not in project_id_dict + assert vfid_project_id not in vfpath_id_dict + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_xfs_multiple_vfolder_mgmt(xfs): + vfid1 = uuid.UUID(hex="83a6ba2b7b8e41deb5ee2c909ce34bcb") + vfid2 = uuid.UUID(hex="83a6ba2b7b8e41deb5ee2c909ce34bcc") + options = {"quota": BinarySize.from_str("10m")} + await xfs.create_vfolder(vfid1, options=options) + await xfs.create_vfolder(vfid2, options=options) + vfpath1 = xfs.mount_path / vfid1.hex[0:2] / vfid1.hex[2:4] / vfid1.hex[4:] + vfpath2 = xfs.mount_path / vfid2.hex[0:2] / vfid2.hex[2:4] / vfid2.hex[4:] + assert vfpath2.relative_to(vfpath1.parent).name == vfpath2.name + assert vfpath1.is_dir() + await xfs.delete_vfolder(vfid1) + assert not vfpath1.exists() + # if the prefix dirs are not empty, they shouldn't be deleted + assert vfpath1.parent.exists() + assert vfpath1.parent.parent.exists() + await xfs.delete_vfolder(vfid2) + # if the prefix dirs become empty, they should be deleted + assert not vfpath2.parent.exists() 
+ assert not vfpath2.parent.parent.exists() + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_xfs_quota(xfs): + vfid = uuid.uuid4() + options = {"quota": BinarySize.from_str("10m")} + await xfs.create_vfolder(vfid, options=options) + vfpath = xfs.mount_path / vfid.hex[0:2] / vfid.hex[2:4] / vfid.hex[4:] + assert vfpath.is_dir() + assert await xfs.get_quota(vfid) == BinarySize.from_str("10m") + await xfs.set_quota(vfid, BinarySize.from_str("1m")) + assert await xfs.get_quota(vfid) == BinarySize.from_str("1m") + await xfs.delete_vfolder(vfid) + assert not vfpath.is_dir() + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_xfs_get_usage(xfs, empty_vfolder): + vfpath = xfs.mangle_vfpath(empty_vfolder) + (vfpath / "test.txt").write_bytes(b"12345") + (vfpath / "inner").mkdir() + (vfpath / "inner" / "hello.txt").write_bytes(b"678") + (vfpath / "inner" / "world.txt").write_bytes(b"901") + usage = await xfs.get_usage(empty_vfolder) + assert usage.file_count == 3 + assert usage.used_bytes == 11 + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_xfs_get_used_bytes(xfs): + vfid = uuid.uuid4() + options = {"quota": BinarySize.from_str("10m")} + await xfs.create_vfolder(vfid, options=options) + vfpath = xfs.mount_path / vfid.hex[0:2] / vfid.hex[2:4] / vfid.hex[4:] + (vfpath / "test.txt").write_bytes(b"12345") + (vfpath / "inner").mkdir() + (vfpath / "inner" / "hello.txt").write_bytes(b"678") + (vfpath / "inner" / "world.txt").write_bytes(b"901") + + used_bytes = await xfs.get_used_bytes(vfid) + full_report = await run( + ["sudo", "xfs_quota", "-x", "-c", "report -h", xfs.mount_path], + ) + report = "" + for line in full_report.split("\n"): + if str(vfid) in line: + report = line + break + assert len(report.split()) == 6 + proj_name, xfs_used, _, _, _, _ = report.split() + assert str(vfid)[:-5] == proj_name + assert used_bytes == BinarySize.from_str(xfs_used) + await xfs.delete_vfolder(vfid) + assert not vfpath.is_dir() + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_xfs_mkdir_rmdir(xfs, empty_vfolder): + vfpath = xfs.mangle_vfpath(empty_vfolder) + test_rel_path = "test/abc" + await xfs.mkdir(empty_vfolder, Path(test_rel_path), parents=True) + assert Path(vfpath, test_rel_path).is_dir() + await xfs.rmdir(empty_vfolder, Path(test_rel_path), recursive=True) + assert not Path(vfpath, test_rel_path).is_dir() + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_xfs_vfolder_operations(xfs, empty_vfolder): + vfpath = xfs.mangle_vfpath(empty_vfolder) + (vfpath / "test0").mkdir() + (vfpath / "test0" / "test.txt").write_bytes(b"12345") + await xfs.move_file(empty_vfolder, Path("test0/test.txt"), Path("test1/test.txt")) + assert (vfpath / "test1" / "test.txt").is_file() + assert (vfpath / "test1" / "test.txt").read_bytes() == b"12345" + assert not (vfpath / "test0" / "test.txt").is_file() + + await xfs.copy_file(empty_vfolder, Path("test1/test.txt"), Path("test2/test.txt")) + assert (vfpath / "test1" / "test.txt").is_file() + assert (vfpath / "test2" / "test.txt").is_file() + assert (vfpath / "test2" / "test.txt").read_bytes() == b"12345" + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_xfs_clone_to_vfs(xfs, vfs): + vfid_src = uuid.uuid4() + vfid_dst = uuid.uuid4() + vfpath_src = xfs.mangle_vfpath(vfid_src) + vfpath_dst = vfs.mangle_vfpath(vfid_dst) + await xfs.create_vfolder(vfid_src) + assert vfpath_src.is_dir() + create_sample_dir_tree(vfpath_src) + + await xfs.clone_vfolder(vfid_src, vfs, vfid_dst) + 
assert_sample_dir_tree(vfpath_dst) + + await xfs.delete_vfolder(vfid_src) + await vfs.delete_vfolder(vfid_dst) + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_vfs_clone_to_xfs(xfs, vfs): + vfid_src = uuid.uuid4() + vfid_dst = uuid.uuid4() + vfpath_src = vfs.mangle_vfpath(vfid_src) + vfpath_dst = xfs.mangle_vfpath(vfid_dst) + await vfs.create_vfolder(vfid_src) + assert vfpath_src.is_dir() + create_sample_dir_tree(vfpath_src) + + await vfs.clone_vfolder(vfid_src, xfs, vfid_dst) + assert_sample_dir_tree(vfpath_dst) + + await vfs.delete_vfolder(vfid_src) + await xfs.delete_vfolder(vfid_dst) + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_xfs_clone_to_xfs(xfs, vfs): + vfid_src = uuid.uuid4() + vfid_dst = uuid.uuid4() + vfpath_src = xfs.mangle_vfpath(vfid_src) + vfpath_dst = xfs.mangle_vfpath(vfid_dst) + await xfs.create_vfolder(vfid_src) + assert vfpath_src.is_dir() + create_sample_dir_tree(vfpath_src) + + await xfs.clone_vfolder(vfid_src, xfs, vfid_dst) + assert_sample_dir_tree(vfpath_dst) + + await xfs.delete_vfolder(vfid_src) + await xfs.delete_vfolder(vfid_dst) diff --git a/tests/test_dummy.py b/tests/test_dummy.py deleted file mode 100644 index 7f117b3af3..0000000000 --- a/tests/test_dummy.py +++ /dev/null @@ -1,8 +0,0 @@ -''' -Since this is a meta-package without any functionality, -we place a dummy placeholder that always passes. -''' - - -def test_dummy(): - pass diff --git a/tests/webserver/BUILD b/tests/webserver/BUILD new file mode 100644 index 0000000000..d44a25174e --- /dev/null +++ b/tests/webserver/BUILD @@ -0,0 +1,10 @@ +python_test_utils( + name="test_utils", +) + +python_tests( + name="tests", + dependencies=[ + "src/ai/backend/web:service", + ], +) diff --git a/tests/webserver/conftest.py b/tests/webserver/conftest.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/webserver/test_auth.py b/tests/webserver/test_auth.py new file mode 100644 index 0000000000..f639f67c95 --- /dev/null +++ b/tests/webserver/test_auth.py @@ -0,0 +1,99 @@ +import pytest + +from typing import Any, Dict +from unittest.mock import MagicMock, AsyncMock + +from aiohttp import web + +from ai.backend.web.auth import get_api_session, get_anonymous_session + + +class DummyRequest: + def __init__(self, app_data: Dict[str, Any]) -> None: + self.app = app_data + + +@pytest.mark.asyncio +async def test_get_api_session(mocker): + mock_request = DummyRequest({'config': { + 'api': {'domain': 'default', 'endpoint': 'https://api.backend.ai'}, + }}) + + mock_get_session = AsyncMock(return_value={ + 'authenticated': False, + }) + mocker.patch('ai.backend.web.auth.get_session', mock_get_session) + with pytest.raises(web.HTTPUnauthorized): + await get_api_session(mock_request) + mock_get_session.assert_awaited_once() + + mock_get_session = AsyncMock(return_value={ + 'authenticated': True, + 'token': {'type': 'something-else'}, + }) + mocker.patch('ai.backend.web.auth.get_session', mock_get_session) + with pytest.raises(web.HTTPBadRequest): + await get_api_session(mock_request) + mock_get_session.assert_awaited_once() + + mock_get_session = AsyncMock(return_value={ + 'authenticated': True, + 'token': {'type': 'keypair', 'access_key': 'ABC', 'secret_key': 'xyz'}, + }) + mocker.patch('ai.backend.web.auth.get_session', mock_get_session) + api_session = await get_api_session(mock_request) + mock_get_session.assert_awaited_once() + async with api_session: + assert not api_session.config.is_anonymous + assert api_session.config.domain == 'default' + assert 
str(api_session.config.endpoint) == 'https://api.backend.ai' + assert api_session.config.access_key == 'ABC' + assert api_session.config.secret_key == 'xyz' + + +@pytest.mark.asyncio +async def test_get_api_session_with_specific_api_endpoint(mocker): + mock_request = DummyRequest({'config': { + 'api': {'domain': 'default', 'endpoint': 'https://api.backend.ai'}, + }}) + mock_get_session = AsyncMock(return_value={ + 'authenticated': True, + 'token': {'type': 'keypair', 'access_key': 'ABC', 'secret_key': 'xyz'}, + }) + specific_api_endpoint = 'https://alternative.backend.ai' + mocker.patch('ai.backend.web.auth.get_session', mock_get_session) + api_session = await get_api_session(mock_request, specific_api_endpoint) + mock_get_session.assert_awaited_once() + async with api_session: + assert str(api_session.config.endpoint) == specific_api_endpoint + + +@pytest.mark.asyncio +async def test_get_anonymous_session(mocker): + mock_request = DummyRequest({'config': { + 'api': {'domain': 'default', 'endpoint': 'https://api.backend.ai'}, + }}) + mock_get_session = MagicMock() + mocker.patch('ai.backend.web.auth.get_session', mock_get_session) + api_session = await get_anonymous_session(mock_request) + mock_get_session.assert_not_called() + async with api_session: + assert api_session.config.is_anonymous + assert api_session.config.domain == 'default' + assert str(api_session.config.endpoint) == 'https://api.backend.ai' + assert api_session.config.access_key == '' + assert api_session.config.secret_key == '' + + +@pytest.mark.asyncio +async def test_get_anonymous_session_with_specific_api_endpoint(mocker): + mock_request = DummyRequest({'config': { + 'api': {'domain': 'default', 'endpoint': 'https://api.backend.ai'}, + }}) + specific_api_endpoint = 'https://alternative.backend.ai' + mock_get_session = MagicMock() + mocker.patch('ai.backend.web.auth.get_session', mock_get_session) + api_session = await get_anonymous_session(mock_request, specific_api_endpoint) + mock_get_session.assert_not_called() + async with api_session: + assert str(api_session.config.endpoint) == specific_api_endpoint diff --git a/tools/flake8.lock b/tools/flake8.lock new file mode 100644 index 0000000000..f1165ffa9e --- /dev/null +++ b/tools/flake8.lock @@ -0,0 +1,208 @@ +// This lockfile was autogenerated by Pants. 
To regenerate, run: +// +// ./pants generate-lockfiles --resolve=flake8 +// +// --- BEGIN PANTS LOCKFILE METADATA: DO NOT EDIT OR REMOVE --- +// { +// "version": 2, +// "valid_for_interpreter_constraints": [ +// "CPython==3.10.4" +// ], +// "generated_with_requirements": [ +// "flake8-commas>=2.1", +// "flake8>=4.0", +// "setuptools>=60.0" +// ] +// } +// --- END PANTS LOCKFILE METADATA --- + +{ + "allow_builds": true, + "allow_prereleases": false, + "allow_wheels": true, + "build_isolation": true, + "constraints": [], + "locked_resolves": [ + { + "locked_requirements": [ + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "479b1304f72536a55948cb40a32dce8bb0ffe3501e26eaf292c7e60eb5e0428d", + "url": "https://files.pythonhosted.org/packages/34/39/cde2c8a227abb4f9ce62fe55586b920f438f1d2903a1a22514d0b982c333/flake8-4.0.1-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "806e034dda44114815e23c16ef92f95c91e4c71100ff52813adf7132a6ad870d", + "url": "https://files.pythonhosted.org/packages/e6/84/d8db922289195c435779b4ca3a3f583f263f87e67954f7b2e83c8da21f48/flake8-4.0.1.tar.gz" + } + ], + "project_name": "flake8", + "requires_dists": [ + "importlib-metadata<4.3; python_version < \"3.8\"", + "mccabe<0.7.0,>=0.6.0", + "pycodestyle<2.9.0,>=2.8.0", + "pyflakes<2.5.0,>=2.4.0" + ], + "requires_python": ">=3.6", + "version": "4.0.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "ebb96c31e01d0ef1d0685a21f3f0e2f8153a0381430e748bf0bbbb5d5b453d54", + "url": "https://files.pythonhosted.org/packages/18/0d/41895badcdbbe84893b95c114d5bd4345d69c9d5645a42857f1ccb84d556/flake8_commas-2.1.0-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "940441ab8ee544df564ae3b3f49f20462d75d5c7cac2463e0b27436e2050f263", + "url": "https://files.pythonhosted.org/packages/0e/83/814bc8eb02b8883bc004384a1fb8b1f45b4a0b892e579fec7c80a9368526/flake8-commas-2.1.0.tar.gz" + } + ], + "project_name": "flake8-commas", + "requires_dists": [ + "flake8>=2" + ], + "requires_python": null, + "version": "2.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42", + "url": "https://files.pythonhosted.org/packages/87/89/479dc97e18549e21354893e4ee4ef36db1d237534982482c3681ee6e7b57/mccabe-0.6.1-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "dd8d182285a0fe56bace7f45b5e7d1a6ebcbf524e8f3bd87eb0f125271b8831f", + "url": "https://files.pythonhosted.org/packages/06/18/fa675aa501e11d6d6ca0ae73a101b2f3571a565e0f7d38e062eec18a91ee/mccabe-0.6.1.tar.gz" + } + ], + "project_name": "mccabe", + "requires_dists": [], + "requires_python": null, + "version": "0.6.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "720f8b39dde8b293825e7ff02c475f3077124006db4f440dcbc9a20b76548a20", + "url": "https://files.pythonhosted.org/packages/15/94/bc43a2efb7b8615e38acde2b6624cae8c9ec86faf718ff5676c5179a7714/pycodestyle-2.8.0-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "eddd5847ef438ea1c7870ca7eb78a9d47ce0cdb4851a5523949f2601d0cbbe7f", + "url": "https://files.pythonhosted.org/packages/08/dc/b29daf0a202b03f57c19e7295b60d1d5e1281c45a6f5f573e41830819918/pycodestyle-2.8.0.tar.gz" + } + ], + "project_name": "pycodestyle", + "requires_dists": [], + "requires_python": "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7", + "version": "2.8" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "3bb3a3f256f4b7968c9c788781e4ff07dce46bdf12339dcda61053375426ee2e", + "url": 
"https://files.pythonhosted.org/packages/43/fb/38848eb494af7df9aeb2d7673ace8b213313eb7e391691a79dbaeb6a838f/pyflakes-2.4.0-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "05a85c2872edf37a4ed30b0cce2f6093e1d0581f8c19d7393122da7e25b2b24c", + "url": "https://files.pythonhosted.org/packages/15/60/c577e54518086e98470e9088278247f4af1d39cb43bcbd731e2c307acd6a/pyflakes-2.4.0.tar.gz" + } + ], + "project_name": "pyflakes", + "requires_dists": [], + "requires_python": "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7", + "version": "2.4" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "68e45d17c9281ba25dc0104eadd2647172b3472d9e01f911efa57965e8d51a36", + "url": "https://files.pythonhosted.org/packages/e9/1c/ec080fde54ab30a738c92f794eab7f5d2f354f2b619ee95b2efe353e0766/setuptools-62.3.2-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "a43bdedf853c670e5fed28e5623403bad2f73cf02f9a2774e91def6bda8265a7", + "url": "https://files.pythonhosted.org/packages/4a/25/ec29a23ef38b9456f9965c57a9e1221e6c246d87abbf2a31158799bca201/setuptools-62.3.2.tar.gz" + } + ], + "project_name": "setuptools", + "requires_dists": [ + "build[virtualenv]; extra == \"testing\"", + "build[virtualenv]; extra == \"testing-integration\"", + "filelock>=3.4.0; extra == \"testing\"", + "filelock>=3.4.0; extra == \"testing-integration\"", + "flake8-2020; extra == \"testing\"", + "furo; extra == \"docs\"", + "ini2toml[lite]>=0.9; extra == \"testing\"", + "jaraco.envs>=2.2; extra == \"testing\"", + "jaraco.envs>=2.2; extra == \"testing-integration\"", + "jaraco.packaging>=9; extra == \"docs\"", + "jaraco.path>=3.2.0; extra == \"testing\"", + "jaraco.path>=3.2.0; extra == \"testing-integration\"", + "jaraco.tidelift>=1.4; extra == \"docs\"", + "mock; extra == \"testing\"", + "pip-run>=8.8; extra == \"testing\"", + "pip>=19.1; extra == \"testing\"", + "pygments-github-lexers==0.0.5; extra == \"docs\"", + "pytest-black>=0.3.7; platform_python_implementation != \"PyPy\" and extra == \"testing\"", + "pytest-checkdocs>=2.4; extra == \"testing\"", + "pytest-cov; platform_python_implementation != \"PyPy\" and extra == \"testing\"", + "pytest-enabler; extra == \"testing-integration\"", + "pytest-enabler>=1.0.1; extra == \"testing\"", + "pytest-flake8; extra == \"testing\"", + "pytest-mypy>=0.9.1; platform_python_implementation != \"PyPy\" and extra == \"testing\"", + "pytest-perf; extra == \"testing\"", + "pytest-xdist; extra == \"testing\"", + "pytest-xdist; extra == \"testing-integration\"", + "pytest; extra == \"testing-integration\"", + "pytest>=6; extra == \"testing\"", + "rst.linker>=1.9; extra == \"docs\"", + "sphinx-favicon; extra == \"docs\"", + "sphinx-inline-tabs; extra == \"docs\"", + "sphinx-reredirects; extra == \"docs\"", + "sphinx; extra == \"docs\"", + "sphinxcontrib-towncrier; extra == \"docs\"", + "tomli-w>=1.0.0; extra == \"testing\"", + "tomli; extra == \"testing-integration\"", + "virtualenv>=13.0.0; extra == \"testing\"", + "virtualenv>=13.0.0; extra == \"testing-integration\"", + "wheel; extra == \"testing\"", + "wheel; extra == \"testing-integration\"" + ], + "requires_python": ">=3.7", + "version": "62.3.2" + } + ], + "platform_tag": [ + "cp310", + "cp310", + "manylinux_2_31_aarch64" + ] + } + ], + "path_mappings": {}, + "pex_version": "2.1.84", + "prefer_older_binary": false, + "requirements": [ + "flake8-commas>=2.1", + "flake8>=4.0", + "setuptools>=60.0" + ], + "requires_python": [ + "==3.10.4" + ], + "resolver_version": "pip-2020-resolver", + "style": "universal", + "transitive": true, 
+ "use_pep517": null +} \ No newline at end of file diff --git a/tools/mypy.lock b/tools/mypy.lock new file mode 100644 index 0000000000..ebbf913208 --- /dev/null +++ b/tools/mypy.lock @@ -0,0 +1,154 @@ +// This lockfile was autogenerated by Pants. To regenerate, run: +// +// ./pants generate-lockfiles --resolve=mypy +// +// --- BEGIN PANTS LOCKFILE METADATA: DO NOT EDIT OR REMOVE --- +// { +// "version": 2, +// "valid_for_interpreter_constraints": [ +// "CPython==3.10.4" +// ], +// "generated_with_requirements": [ +// "mypy>=0.950" +// ] +// } +// --- END PANTS LOCKFILE METADATA --- + +{ + "allow_builds": true, + "allow_prereleases": false, + "allow_wheels": true, + "build_isolation": true, + "constraints": [], + "locked_resolves": [ + { + "locked_requirements": [ + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "bfd4f6536bd384c27c392a8b8f790fd0ed5c0cf2f63fc2fed7bce56751d53026", + "url": "https://files.pythonhosted.org/packages/7e/53/386d8b939c5654c7a2218c1625f991cc123c3d3f69d2c0e2aed16c475eb4/mypy-0.960-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "7a76dc4f91e92db119b1be293892df8379b08fd31795bb44e0ff84256d34c251", + "url": "https://files.pythonhosted.org/packages/0c/8f/9b3e12e9971389befe2bfca98eb99b5555a5f6f625f8482ded3a0e97715e/mypy-0.960-cp310-cp310-macosx_10_9_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "d4fccf04c1acf750babd74252e0f2db6bd2ac3aa8fe960797d9f3ef41cf2bfd4", + "url": "https://files.pythonhosted.org/packages/22/22/49792504e249a774554cd473e69af411a62c7d0591651104538fbcdaec10/mypy-0.960.tar.gz" + }, + { + "algorithm": "sha256", + "hash": "ffdad80a92c100d1b0fe3d3cf1a4724136029a29afe8566404c0146747114382", + "url": "https://files.pythonhosted.org/packages/33/b8/bab515402dbb5874b3bc5c1586c9a70bccee73f5daa82c170c54a72b729e/mypy-0.960-cp310-cp310-macosx_11_0_arm64.whl" + }, + { + "algorithm": "sha256", + "hash": "7d390248ec07fa344b9f365e6ed9d205bd0205e485c555bed37c4235c868e9d5", + "url": "https://files.pythonhosted.org/packages/7f/b8/aaa840a5f37e39bb3eccef982d48dccb8689074e7f6fd76d8dfffe3a8a66/mypy-0.960-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "3a3e525cd76c2c4f90f1449fd034ba21fcca68050ff7c8397bb7dd25dd8b8248", + "url": "https://files.pythonhosted.org/packages/a7/1c/ac13b5d83aa025adc9d7e525b8eea180f3600ac40d47b6ad7746cf796a9a/mypy-0.960-cp310-cp310-macosx_10_9_universal2.whl" + }, + { + "algorithm": "sha256", + "hash": "925aa84369a07846b7f3b8556ccade1f371aa554f2bd4fb31cb97a24b73b036e", + "url": "https://files.pythonhosted.org/packages/f1/8f/1cf7aaad4957c3e329e160d993353b440da7499365d2f639cf973ab0f235/mypy-0.960-cp310-cp310-win_amd64.whl" + } + ], + "project_name": "mypy", + "requires_dists": [ + "lxml; extra == \"reports\"", + "mypy-extensions>=0.4.3", + "psutil>=4.0; extra == \"dmypy\"", + "tomli>=1.1.0; python_version < \"3.11\"", + "typed-ast<2,>=1.4.0; extra == \"python2\"", + "typed-ast<2,>=1.4.0; python_version < \"3.8\"", + "typing-extensions>=3.10" + ], + "requires_python": ">=3.6", + "version": "0.960" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d", + "url": "https://files.pythonhosted.org/packages/5c/eb/975c7c080f3223a5cdaff09612f3a5221e4ba534f7039db34c35d95fa6a5/mypy_extensions-0.4.3-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8", + "url": 
"https://files.pythonhosted.org/packages/63/60/0582ce2eaced55f65a4406fc97beba256de4b7a95a0034c6576458c6519f/mypy_extensions-0.4.3.tar.gz" + } + ], + "project_name": "mypy-extensions", + "requires_dists": [ + "typing>=3.5.3; python_version < \"3.5\"" + ], + "requires_python": null, + "version": "0.4.3" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc", + "url": "https://files.pythonhosted.org/packages/97/75/10a9ebee3fd790d20926a90a2547f0bf78f371b2f13aa822c759680ca7b9/tomli-2.0.1-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f", + "url": "https://files.pythonhosted.org/packages/c0/3f/d7af728f075fb08564c5949a9c95e44352e23dee646869fa104a3b2060a3/tomli-2.0.1.tar.gz" + } + ], + "project_name": "tomli", + "requires_dists": [], + "requires_python": ">=3.7", + "version": "2.0.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "6657594ee297170d19f67d55c05852a874e7eb634f4f753dbd667855e07c1708", + "url": "https://files.pythonhosted.org/packages/75/e1/932e06004039dd670c9d5e1df0cd606bf46e29a28e65d5bb28e894ea29c9/typing_extensions-4.2.0-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "f1c24655a0da0d1b67f07e17a5e6b2a105894e6824b92096378bb3668ef02376", + "url": "https://files.pythonhosted.org/packages/fe/71/1df93bd59163c8084d812d166c907639646e8aac72886d563851b966bf18/typing_extensions-4.2.0.tar.gz" + } + ], + "project_name": "typing-extensions", + "requires_dists": [], + "requires_python": ">=3.7", + "version": "4.2" + } + ], + "platform_tag": [ + "cp310", + "cp310", + "manylinux_2_31_aarch64" + ] + } + ], + "path_mappings": {}, + "pex_version": "2.1.84", + "prefer_older_binary": false, + "requirements": [ + "mypy>=0.950" + ], + "requires_python": [ + "==3.10.4" + ], + "resolver_version": "pip-2020-resolver", + "style": "universal", + "transitive": true, + "use_pep517": null +} \ No newline at end of file diff --git a/tools/pants-linux-aarch64.patch b/tools/pants-linux-aarch64.patch new file mode 100644 index 0000000000..f7b39ede67 --- /dev/null +++ b/tools/pants-linux-aarch64.patch @@ -0,0 +1,68 @@ +diff --git a/build-support/common.sh b/build-support/common.sh +index 0601b0d04..56235678a 100644 +--- a/build-support/common.sh ++++ b/build-support/common.sh +@@ -52,7 +52,7 @@ function determine_python() { + echo "${PY}" + return 0 + fi +- for version in '3.7' '3.8' '3.9'; do ++ for version in '3.7' '3.8' '3.9' '3.10'; do + local interpreter_path + interpreter_path="$(command -v "python${version}")" + if [[ -z "${interpreter_path}" ]]; then +diff --git a/src/python/pants/backend/python/util_rules/pex_cli.py b/src/python/pants/backend/python/util_rules/pex_cli.py +index e6526e3ed..56589e86a 100644 +--- a/src/python/pants/backend/python/util_rules/pex_cli.py ++++ b/src/python/pants/backend/python/util_rules/pex_cli.py +@@ -54,7 +54,7 @@ class PexCli(TemplatedExternalTool): + "3741845", + ) + ) +- for plat in ["macos_arm64", "macos_x86_64", "linux_x86_64"] ++ for plat in ["macos_arm64", "macos_x86_64", "linux_arm64", "linux_x86_64"] + ] + + +diff --git a/src/rust/engine/process_execution/src/lib.rs b/src/rust/engine/process_execution/src/lib.rs +index 762944d65..580ccc471 100644 +--- a/src/rust/engine/process_execution/src/lib.rs ++++ b/src/rust/engine/process_execution/src/lib.rs +@@ -87,6 +87,7 @@ pub enum Platform { + Macos_x86_64, + Macos_arm64, + Linux_x86_64, ++ Linux_arm64, + } + + impl Platform { +@@ -115,6 
+116,13 @@ impl Platform { + } if sysname.to_lowercase() == "darwin" && machine.to_lowercase() == "x86_64" => { + Ok(Platform::Macos_x86_64) + } ++ uname::Info { ++ ref sysname, ++ ref machine, ++ .. ++ } if sysname.to_lowercase() == "linux" && (machine.to_lowercase() == "arm64" || machine.to_lowercase() == "aarch64") => { ++ Ok(Platform::Linux_arm64) ++ } + uname::Info { + ref sysname, + ref machine, +@@ -130,6 +138,7 @@ impl From<Platform> for String { + fn from(platform: Platform) -> String { + match platform { ++ Platform::Linux_arm64 => "linux_arm64".to_string(), + Platform::Linux_x86_64 => "linux_x86_64".to_string(), + Platform::Macos_arm64 => "macos_arm64".to_string(), + Platform::Macos_x86_64 => "macos_x86_64".to_string(), +@@ -143,6 +152,7 @@ impl TryFrom<String> for Platform { + match variant_candidate.as_ref() { + "macos_arm64" => Ok(Platform::Macos_arm64), + "macos_x86_64" => Ok(Platform::Macos_x86_64), ++ "linux_arm64" => Ok(Platform::Linux_arm64), + "linux_x86_64" => Ok(Platform::Linux_x86_64), + other => Err(format!( + "Unknown platform {:?} encountered in parsing", diff --git a/tools/pants-local b/tools/pants-local new file mode 100755 index 0000000000..e120d16b82 --- /dev/null +++ b/tools/pants-local @@ -0,0 +1,28 @@ +#!/usr/bin/env bash +# Copyright 2020 Pants project contributors. +# Licensed under the Apache License, Version 2.0 (see LICENSE). + +# Run pants from a locally built source clone. + +set -euo pipefail + +cd "$(git rev-parse --show-toplevel)" + +if [ -f '.pants.env' ]; then + source .pants.env +fi +if [ -z "$PY" ]; then + echo "The Python version for source-built Pants is not configured in ./.pants.env" + exit 1 +fi +PANTS_SOURCE="${PANTS_SOURCE:-$(pwd)/tools/pants-src}" + +# When running pants from sources you are likely to be modifying those sources, so +# you won't want pantsd running. You can override this by setting ENABLE_PANTSD=true. +ENABLE_PANTSD="${ENABLE_PANTSD:-false}" + +export PANTS_VERSION="$(cat "${PANTS_SOURCE}/src/python/pants/VERSION")" +export PANTS_PANTSD="${ENABLE_PANTSD}" +export no_proxy="*" + +exec "${PANTS_SOURCE}/pants" "--no-verify-config" "$@" diff --git a/tools/pants-plugins/platform_resources/BUILD b/tools/pants-plugins/platform_resources/BUILD new file mode 100644 index 0000000000..db46e8d6c9 --- /dev/null +++ b/tools/pants-plugins/platform_resources/BUILD @@ -0,0 +1 @@ +python_sources() diff --git a/tools/pants-plugins/platform_resources/register.py b/tools/pants-plugins/platform_resources/register.py new file mode 100644 index 0000000000..9ea2cadb1d --- /dev/null +++ b/tools/pants-plugins/platform_resources/register.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +import logging + +from pants.engine.addresses import Address, Addresses, UnparsedAddressInputs +from pants.engine.platform import Platform +from pants.engine.rules import ( + Get, + SubsystemRule, + collect_rules, + rule, +) +from pants.engine.target import ( + COMMON_TARGET_FIELDS, + DictStringToStringField, + Dependencies, + InjectedDependencies, + InjectDependenciesRequest, + Target, + WrappedTarget, +) +from pants.engine.unions import UnionRule +from pants.option.option_types import EnumOption +from pants.option.subsystem import Subsystem + +logger = logging.getLogger(__name__) + + +class PlatformResourcesSusbystem(Subsystem): + options_scope = "platform-specific-resources" + help = "The platform-specific resource provider."
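+ # Exposed on the command line as `--platform-specific-resources-target`; defaults to the current platform.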
+ platform = EnumOption( + "--target", + default=Platform.current, + enum_type=Platform, + advanced=False, + help="Select only resources compatible with the given platform", + ) + + +class PlatformDependencyMapField(DictStringToStringField): + alias = "dependency_map" + help = "Specifies platform-specific dependencies as a dictionary from platform names to dependency lists." + + +class PlatformSpecificDependencies(Dependencies): + """ + This field will be populated by injection based on the `--platform-specific-resources-target` option + and from the `dependency_map` field of the `platform_resources` target. + """ + + +class PlatformResourcesTarget(Target): + alias = "platform_resources" + core_fields = (*COMMON_TARGET_FIELDS, PlatformDependencyMapField, PlatformSpecificDependencies) + help = "A target to declare selective dependency sets for multiple different platforms" + + +class InjectPlatformSpecificDependenciesRequest(InjectDependenciesRequest): + inject_for = PlatformSpecificDependencies + + +@rule +async def inject_platform_specific_dependencies( + request: InjectPlatformSpecificDependenciesRequest, + subsystem: PlatformResourcesSusbystem, +) -> InjectedDependencies: + logger.info( + "configured target platform (%s) = %s", + request.dependencies_field.address, + subsystem.platform.value, + ) + wrapped_target = await Get(WrappedTarget, Address, request.dependencies_field.address) + platforms = wrapped_target.target.get(PlatformDependencyMapField).value + platform_resources_unparsed_address = platforms and platforms.get(subsystem.platform.value) + if not platform_resources_unparsed_address: + return InjectedDependencies() + parsed_addresses = await Get( + Addresses, + UnparsedAddressInputs( + (platform_resources_unparsed_address,), + owning_address=request.dependencies_field.address, + ), + ) + return InjectedDependencies(Addresses(parsed_addresses)) + + +# Plugin registration + + +def target_types(): + return ( + PlatformResourcesTarget, + ) + + +def rules(): + return [ + *collect_rules(), + SubsystemRule(PlatformResourcesSusbystem), + UnionRule(InjectDependenciesRequest, InjectPlatformSpecificDependenciesRequest), + ] diff --git a/tools/pants-plugins/setupgen/BUILD b/tools/pants-plugins/setupgen/BUILD new file mode 100644 index 0000000000..db46e8d6c9 --- /dev/null +++ b/tools/pants-plugins/setupgen/BUILD @@ -0,0 +1 @@ +python_sources() diff --git a/tools/pants-plugins/setupgen/register.py b/tools/pants-plugins/setupgen/register.py new file mode 100644 index 0000000000..d9a0fe18e8 --- /dev/null +++ b/tools/pants-plugins/setupgen/register.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +import re +from pathlib import Path + +from pants.backend.python.goals.setup_py import SetupKwargs, SetupKwargsRequest +from pants.engine.fs import DigestContents, GlobMatchErrorBehavior, PathGlobs +from pants.engine.rules import Get, collect_rules, rule +from pants.engine.target import Target +from pants.engine.unions import UnionRule + + +class CustomSetupKwargsRequest(SetupKwargsRequest): + @classmethod + def is_applicable(cls, _: Target) -> bool: + # We always use our custom `setup()` kwargs generator for `python_distribution` targets in + # this repo. 
+ return True + + +license_classifier_map = { + "LGPLv3": "License :: OSI Approved :: GNU Lesser General Public License v3 or later (LGPLv3+)", + "Apache 2.0": "License :: OSI Approved :: Apache Software License", + "BSD": "License :: OSI Approved :: BSD License", + "MIT": "License :: OSI Approved :: MIT License", +} + + +@rule +async def setup_kwargs_plugin(request: CustomSetupKwargsRequest) -> SetupKwargs: + kwargs = request.explicit_kwargs.copy() + + # Single-source the version from VERSION. + _digest_contents = await Get( + DigestContents, + PathGlobs( + ["VERSION"], + description_of_origin="setupgen plugin", + glob_match_error_behavior=GlobMatchErrorBehavior.error, + ), + ) + VERSION = _digest_contents[0].content.decode() + + # Validate that required fields are set. + if not kwargs["name"].startswith("backend.ai-"): + raise ValueError( + f"Invalid `name` kwarg in the `provides` field for {request.target.address}. The name " + f"must start with 'backend.ai-', but was {kwargs['name']}.", + ) + if "description" not in kwargs: + raise ValueError( + f"Missing a `description` kwarg in the `provides` field for {request.target.address}.", + ) + + # Add classifiers. We preserve any that were already set. + standard_classifiers = [ + "Intended Audience :: Developers", + "Operating System :: MacOS :: MacOS X", + "Operating System :: POSIX :: Linux", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Environment :: No Input/Output (Daemon)", + "Topic :: Scientific/Engineering", + "Topic :: Software Development", + ] + if re.search(r"\.?dev\d*$", VERSION): + standard_classifiers.append("Development Status :: 2 - Pre-Alpha") + elif re.search(r"\.?a(lpha)?\d*$", VERSION): + standard_classifiers.append("Development Status :: 3 - Alpha") + elif re.search(r"\.?b(eta)?\d*$", VERSION): + standard_classifiers.append("Development Status :: 4 - Beta") + elif re.search(r"\.?rc?\d*$", VERSION): + standard_classifiers.append("Development Status :: 4 - Beta") + else: + standard_classifiers.append("Development Status :: 5 - Production/Stable") + + license_classifier = license_classifier_map.get(kwargs["license"]) + if license_classifier: + standard_classifiers.append(license_classifier) + + kwargs["classifiers"] = [*standard_classifiers, *kwargs.get("classifiers", [])] + + # Determine the long description by reading from ABOUT.md and the release notes. + spec_path = Path(request.target.address.spec_path) + if (spec_path / "README.md").is_file(): + readme_path = spec_path / "README.md" + elif (spec_path / "README.rst").is_file(): + readme_path = spec_path / "README.rst" + else: + readme_path = spec_path / "README" + _digest_contents = await Get( + DigestContents, + PathGlobs( + [str(readme_path)], + description_of_origin="setupgen plugin", + glob_match_error_behavior=GlobMatchErrorBehavior.error, + ), + ) + long_description = _digest_contents[0].content.decode() + + # Hardcode certain kwargs and validate that they weren't already set. + hardcoded_kwargs = dict( + version=VERSION, + long_description=long_description, + long_description_content_type="text/markdown", + url="https://github.com/lablup/backend.ai", + project_urls={ + "Documentation": "https://docs.backend.ai/", + "Source": "https://github.com/lablup/backend.ai", + }, + author="Lablup Inc. 
and contributors", + zip_safe=False, + ) + conflicting_hardcoded_kwargs = set(kwargs.keys()).intersection(hardcoded_kwargs.keys()) + if conflicting_hardcoded_kwargs: + raise ValueError( + f"These kwargs should not be set in the `provides` field for {request.target.address} " + "because Pants's internal plugin will automatically set them: " + f"{sorted(conflicting_hardcoded_kwargs)}", + ) + kwargs.update(hardcoded_kwargs) + + return SetupKwargs(kwargs, address=request.target.address) + + +def rules(): + return [ + *collect_rules(), + UnionRule(SetupKwargsRequest, CustomSetupKwargsRequest), + ] diff --git a/tools/pytest.lock b/tools/pytest.lock new file mode 100644 index 0000000000..3a90fb963d --- /dev/null +++ b/tools/pytest.lock @@ -0,0 +1,955 @@ +// This lockfile was autogenerated by Pants. To regenerate, run: +// +// ./pants generate-lockfiles --resolve=pytest +// +// --- BEGIN PANTS LOCKFILE METADATA: DO NOT EDIT OR REMOVE --- +// { +// "version": 2, +// "valid_for_interpreter_constraints": [ +// "CPython==3.10.4" +// ], +// "generated_with_requirements": [ +// "aioresponses>=0.7.3", +// "pytest-aiohttp>=1.0.4", +// "pytest-asyncio>=0.18", +// "pytest-cov!=2.12.1,<3.1,>=2.12", +// "pytest-dependency>=0.5.1", +// "pytest-mock>=3.5.0", +// "pytest>=7.0" +// ] +// } +// --- END PANTS LOCKFILE METADATA --- + +{ + "allow_builds": true, + "allow_prereleases": false, + "allow_wheels": true, + "build_isolation": true, + "constraints": [], + "locked_resolves": [ + { + "locked_requirements": [ + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "713ac174a629d39b7c6a3aa757b337599798da4c1157114a314e4e391cd28e32", + "url": "https://files.pythonhosted.org/packages/2e/4f/119a8efad036d1f766ad736864a6dbfc8db9596e74ce9820f8c1282a240b/aiohttp-3.8.1-cp310-cp310-win_amd64.whl" + }, + { + "algorithm": "sha256", + "hash": "12de6add4038df8f72fac606dff775791a60f113a725c960f2bab01d8b8e6b15", + "url": "https://files.pythonhosted.org/packages/48/08/c3efb449dea5f38292804e4fbf8eaef1b3f168535a4163cc3fce3f9b4915/aiohttp-3.8.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" + }, + { + "algorithm": "sha256", + "hash": "97ef77eb6b044134c0b3a96e16abcb05ecce892965a2124c566af0fd60f717e2", + "url": "https://files.pythonhosted.org/packages/4f/c6/a8ce9fc6bbf9c0dbdaa631bcb8f9da5b532fd22ead50ef7390976fc9bf0d/aiohttp-3.8.1-cp310-cp310-musllinux_1_1_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "fc5471e1a54de15ef71c1bc6ebe80d4dc681ea600e68bfd1cbce40427f0b7578", + "url": "https://files.pythonhosted.org/packages/5a/86/5f63de7a202550269a617a5d57859a2961f3396ecd1739a70b92224766bc/aiohttp-3.8.1.tar.gz" + }, + { + "algorithm": "sha256", + "hash": "c2aef4703f1f2ddc6df17519885dbfa3514929149d3ff900b73f45998f2532fa", + "url": "https://files.pythonhosted.org/packages/75/86/c55c7b6b9d0d9e25b1d721e204424f154bd72bb172d2056f0f9f06c50254/aiohttp-3.8.1-cp310-cp310-win32.whl" + }, + { + "algorithm": "sha256", + "hash": "61bfc23df345d8c9716d03717c2ed5e27374e0fe6f659ea64edcd27b4b044cf7", + "url": "https://files.pythonhosted.org/packages/76/3d/8f64ed6d429f9feeefc52b551f4ba5554d2f7a6f46d92c080f4ae48e0478/aiohttp-3.8.1-cp310-cp310-musllinux_1_1_i686.whl" + }, + { + "algorithm": "sha256", + "hash": "7dadf3c307b31e0e61689cbf9e06be7a867c563d5a63ce9dca578f956609abf8", + "url": "https://files.pythonhosted.org/packages/7e/9f/3cd2502f3cab61eccd7c20f5ab67447cf891ad8613282141955df1b7fb98/aiohttp-3.8.1-cp310-cp310-macosx_10_9_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": 
"099ebd2c37ac74cce10a3527d2b49af80243e2a4fa39e7bce41617fbc35fa3c1", + "url": "https://files.pythonhosted.org/packages/80/a3/9403173d3a6ba5893a4e0a1816b211da7ba0cb7c00c9ac0279ec2dbbf576/aiohttp-3.8.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl" + }, + { + "algorithm": "sha256", + "hash": "eaba923151d9deea315be1f3e2b31cc39a6d1d2f682f942905951f4e40200922", + "url": "https://files.pythonhosted.org/packages/85/e6/d52a342bf22b5b5c759a94af340836490bcbffd288d4a65494234d8298f7/aiohttp-3.8.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl" + }, + { + "algorithm": "sha256", + "hash": "a79004bb58748f31ae1cbe9fa891054baaa46fb106c2dc7af9f8e3304dc30316", + "url": "https://files.pythonhosted.org/packages/a6/7f/4c202b0fd3c33029e45bb0d06eaac2886be4427763cc9589774fb39b5da7/aiohttp-3.8.1-cp310-cp310-macosx_11_0_arm64.whl" + }, + { + "algorithm": "sha256", + "hash": "6f0d5f33feb5f69ddd57a4a4bd3d56c719a141080b445cbf18f238973c5c9923", + "url": "https://files.pythonhosted.org/packages/b1/bd/e412cb6cd12b7a86966239a97ed0391e1ad5ac6f8a749caddc49e18264ec/aiohttp-3.8.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl" + }, + { + "algorithm": "sha256", + "hash": "fa0ffcace9b3aa34d205d8130f7873fcfefcb6a4dd3dd705b0dab69af6712642", + "url": "https://files.pythonhosted.org/packages/c0/6d/f5423a7c899c538e2cff2e713f9eb2c51b02fad909ec8e8b1c3ed713049a/aiohttp-3.8.1-cp310-cp310-musllinux_1_1_aarch64.whl" + }, + { + "algorithm": "sha256", + "hash": "01d7bdb774a9acc838e6b8f1d114f45303841b89b95984cbb7d80ea41172a9e3", + "url": "https://files.pythonhosted.org/packages/cc/28/c95a0694da3082cb76808799017b02db6c10ec8687ee1ac5edad091ab070/aiohttp-3.8.1-cp310-cp310-musllinux_1_1_s390x.whl" + }, + { + "algorithm": "sha256", + "hash": "1ed0b6477896559f17b9eaeb6d38e07f7f9ffe40b9f0f9627ae8b9926ae260a8", + "url": "https://files.pythonhosted.org/packages/e3/3a/720635a98bb0eef9179d12ee3ccca659d1fcccfbafaacdf42ed5536a0861/aiohttp-3.8.1-cp310-cp310-macosx_10_9_universal2.whl" + }, + { + "algorithm": "sha256", + "hash": "31560d268ff62143e92423ef183680b9829b1b482c011713ae941997921eebc8", + "url": "https://files.pythonhosted.org/packages/f3/0d/a035862f8a11b6cba4220b0c1201443fa6f5151137889e2dfe1cc983e58e/aiohttp-3.8.1-cp310-cp310-musllinux_1_1_ppc64le.whl" + }, + { + "algorithm": "sha256", + "hash": "2e5d962cf7e1d426aa0e528a7e198658cdc8aa4fe87f781d039ad75dcd52c516", + "url": "https://files.pythonhosted.org/packages/f4/2d/07e3ba718571e79509f88a791611a3e156e8915ed9a19116547806bce8fa/aiohttp-3.8.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl" + } + ], + "project_name": "aiohttp", + "requires_dists": [ + "Brotli; extra == \"speedups\"", + "aiodns; extra == \"speedups\"", + "aiosignal>=1.1.2", + "async-timeout<5.0,>=4.0.0a3", + "asynctest==0.13.0; python_version < \"3.8\"", + "attrs>=17.3.0", + "cchardet; extra == \"speedups\"", + "charset-normalizer<3.0,>=2.0", + "frozenlist>=1.1.1", + "idna-ssl>=1.0; python_version < \"3.7\"", + "multidict<7.0,>=4.5", + "typing-extensions>=3.7.4; python_version < \"3.8\"", + "yarl<2.0,>=1.0" + ], + "requires_python": ">=3.6", + "version": "3.8.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "7b1897169062c92fa87d6ecc503ac566ac87fbfacb2504f8ca81c8035a2eb068", + "url": "https://files.pythonhosted.org/packages/74/47/ac822df01b018323d325d04d35a2df406d22e2d399f7ddb7cb8cb0805dbb/aioresponses-0.7.3-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": 
"2c64ed5710ee8cb4e958c569184dad12f4c9cd5939135cb38f88c6a8261cceb3", + "url": "https://files.pythonhosted.org/packages/20/7e/ecc31ff29b3e14859c4e7edf2d9fd38154c4d775a1646ee0441f34b61571/aioresponses-0.7.3.tar.gz" + } + ], + "project_name": "aioresponses", + "requires_dists": [ + "aiohttp<4.0.0,>=2.0.0" + ], + "requires_python": null, + "version": "0.7.3" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "26e62109036cd181df6e6ad646f91f0dcfd05fe16d0cb924138ff2ab75d64e3a", + "url": "https://files.pythonhosted.org/packages/3b/87/fe94898f2d44a93a35d5aa74671ed28094d80753a1113d68b799fab6dc22/aiosignal-1.2.0-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "78ed67db6c7b7ced4f98e495e572106d5c432a93e1ddd1bf475e1dc05f5b7df2", + "url": "https://files.pythonhosted.org/packages/27/6b/a89fbcfae70cf53f066ec22591938296889d3cc58fec1e1c393b10e8d71d/aiosignal-1.2.0.tar.gz" + } + ], + "project_name": "aiosignal", + "requires_dists": [ + "frozenlist>=1.1.0" + ], + "requires_python": ">=3.6", + "version": "1.2" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "8ca1e4fcf50d07413d66d1a5e416e42cfdf5851c981d679a09851a6853383b3c", + "url": "https://files.pythonhosted.org/packages/d6/c1/8991e7c5385b897b8c020cdaad718c5b087a6626d1d11a23e1ea87e325a7/async_timeout-4.0.2-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "2163e1640ddb52b7a8c80d0a67a08587e5d245cc9c553a74a847056bc2976b15", + "url": "https://files.pythonhosted.org/packages/54/6e/9678f7b2993537452710ffb1750c62d2c26df438aa621ad5fa9d1507a43a/async-timeout-4.0.2.tar.gz" + } + ], + "project_name": "async-timeout", + "requires_dists": [ + "typing-extensions>=3.6.5; python_version < \"3.8\"" + ], + "requires_python": ">=3.6", + "version": "4.0.2" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "6d1784dea7c0c8d4a5172b6c620f40b6e4cbfdf96d783691f2e1302a7b88e197", + "url": "https://files.pythonhosted.org/packages/2c/a0/da5f49008ec6e9a658dbf5d7310a4debd397bce0b4db03cf8a410066bb87/atomicwrites-1.4.0-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "ae70396ad1a434f9c7046fd2dd196fc04b12f9e91ffb859164193be8b6168a7a", + "url": "https://files.pythonhosted.org/packages/55/8d/74a75635f2c3c914ab5b3850112fd4b0c8039975ecb320e4449aa363ba54/atomicwrites-1.4.0.tar.gz" + } + ], + "project_name": "atomicwrites", + "requires_dists": [], + "requires_python": "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7", + "version": "1.4" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "2d27e3784d7a565d36ab851fe94887c5eccd6a463168875832a1be79c82828b4", + "url": "https://files.pythonhosted.org/packages/be/be/7abce643bfdf8ca01c48afa2ddf8308c2308b0c3b239a44e57d020afa0ef/attrs-21.4.0-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "626ba8234211db98e869df76230a137c4c40a12d72445c45d5f5b716f076e2fd", + "url": "https://files.pythonhosted.org/packages/d7/77/ebb15fc26d0f815839ecd897b919ed6d85c050feeb83e100e020df9153d2/attrs-21.4.0.tar.gz" + } + ], + "project_name": "attrs", + "requires_dists": [ + "cloudpickle; platform_python_implementation == \"CPython\" and extra == \"dev\"", + "cloudpickle; platform_python_implementation == \"CPython\" and extra == \"tests\"", + "cloudpickle; platform_python_implementation == \"CPython\" and extra == \"tests_no_zope\"", + "coverage[toml]>=5.0.2; extra == \"dev\"", + "coverage[toml]>=5.0.2; extra == \"tests\"", + "coverage[toml]>=5.0.2; extra == \"tests_no_zope\"", + "furo; extra == \"dev\"", + "furo; extra == \"docs\"", + "hypothesis; extra == \"dev\"", + 
"hypothesis; extra == \"tests\"", + "hypothesis; extra == \"tests_no_zope\"", + "mypy; extra == \"dev\"", + "mypy; extra == \"tests\"", + "mypy; extra == \"tests_no_zope\"", + "pre-commit; extra == \"dev\"", + "pympler; extra == \"dev\"", + "pympler; extra == \"tests\"", + "pympler; extra == \"tests_no_zope\"", + "pytest-mypy-plugins; extra == \"dev\"", + "pytest-mypy-plugins; extra == \"tests\"", + "pytest-mypy-plugins; extra == \"tests_no_zope\"", + "pytest>=4.3.0; extra == \"dev\"", + "pytest>=4.3.0; extra == \"tests\"", + "pytest>=4.3.0; extra == \"tests_no_zope\"", + "six; extra == \"dev\"", + "six; extra == \"tests\"", + "six; extra == \"tests_no_zope\"", + "sphinx-notfound-page; extra == \"dev\"", + "sphinx-notfound-page; extra == \"docs\"", + "sphinx; extra == \"dev\"", + "sphinx; extra == \"docs\"", + "zope.interface; extra == \"dev\"", + "zope.interface; extra == \"docs\"", + "zope.interface; extra == \"tests\"" + ], + "requires_python": "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7", + "version": "21.4" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "6881edbebdb17b39b4eaaa821b438bf6eddffb4468cf344f09f89def34a8b1df", + "url": "https://files.pythonhosted.org/packages/06/b3/24afc8868eba069a7f03650ac750a778862dc34941a4bebeb58706715726/charset_normalizer-2.0.12-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "2857e29ff0d34db842cd7ca3230549d1a697f96ee6d3fb071cfa6c7393832597", + "url": "https://files.pythonhosted.org/packages/56/31/7bcaf657fafb3c6db8c787a865434290b726653c912085fbd371e9b92e1c/charset-normalizer-2.0.12.tar.gz" + } + ], + "project_name": "charset-normalizer", + "requires_dists": [ + "unicodedata2; extra == \"unicode_backport\"" + ], + "requires_python": ">=3.5.0", + "version": "2.0.12" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2", + "url": "https://files.pythonhosted.org/packages/44/98/5b86278fbbf250d239ae0ecb724f8572af1c91f4a11edf4d36a206189440/colorama-0.4.4-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b", + "url": "https://files.pythonhosted.org/packages/1f/bb/5d3246097ab77fa083a61bd8d3d527b7ae063c7d8e8671b1cf8c4ec10cbe/colorama-0.4.4.tar.gz" + } + ], + "project_name": "colorama", + "requires_dists": [], + "requires_python": "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7", + "version": "0.4.4" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "3cfd07c5889ddb96a401449109a8b97a165be9d67077df6802f59708bfb07720", + "url": "https://files.pythonhosted.org/packages/d8/e8/dd4a92c84359e6c5647de9245a7210eba015d919e8e14be8a7b485850805/coverage-6.4-cp310-cp310-win_amd64.whl" + }, + { + "algorithm": "sha256", + "hash": "af5b9ee0fc146e907aa0f5fb858c3b3da9199d78b7bb2c9973d95550bd40f701", + "url": "https://files.pythonhosted.org/packages/03/ad/d820de6679d051e20e0f72254a09a2ea2a6f6948dbb94e28dd460e72ea5d/coverage-6.4-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "5a78cf2c43b13aa6b56003707c5203f28585944c277c1f3f109c7b041b16bd39", + "url": "https://files.pythonhosted.org/packages/09/0d/5a60d0d14feb2c498ebcd00b5c57c459ba7f779b75e4579e1e74fcb37124/coverage-6.4-cp310-cp310-musllinux_1_1_i686.whl" + }, + { + "algorithm": "sha256", + "hash": "60c2147921da7f4d2d04f570e1838db32b95c5509d248f3fe6417e91437eaf41", + "url": 
"https://files.pythonhosted.org/packages/13/af/418aa14415712839604fb2eadc307aaf4913540cf42e9b76fd41330b3d74/coverage-6.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" + }, + { + "algorithm": "sha256", + "hash": "50ed480b798febce113709846b11f5d5ed1e529c88d8ae92f707806c50297abf", + "url": "https://files.pythonhosted.org/packages/27/1e/b36f224db34a0d77be90da93224ae972f799373af803718f2a21d9689c6c/coverage-6.4-cp310-cp310-macosx_10_9_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "a022394996419142b33a0cf7274cb444c01d2bb123727c4bb0b9acabcb515dea", + "url": "https://files.pythonhosted.org/packages/2b/e7/88b69e5aec5b317c1f8371ea9a0e386dae0e61ded27ab85649ab471a59ce/coverage-6.4-cp310-cp310-musllinux_1_1_aarch64.whl" + }, + { + "algorithm": "sha256", + "hash": "9229d074e097f21dfe0643d9d0140ee7433814b3f0fc3706b4abffd1e3038632", + "url": "https://files.pythonhosted.org/packages/46/09/941c11c98d56758eb13e5da0f451f00a051072c751fff3e191bf4723b1d7/coverage-6.4-cp310-cp310-musllinux_1_1_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "26f8f92699756cb7af2b30720de0c5bb8d028e923a95b6d0c891088025a1ac8f", + "url": "https://files.pythonhosted.org/packages/7e/3a/ad9b121206034eff6e4d144adf0e017cbecca56bc0b29de4e5fcf20a5941/coverage-6.4-cp310-cp310-macosx_11_0_arm64.whl" + }, + { + "algorithm": "sha256", + "hash": "727dafd7f67a6e1cad808dc884bd9c5a2f6ef1f8f6d2f22b37b96cb0080d4f49", + "url": "https://files.pythonhosted.org/packages/af/c9/91533449c1c8685aa09b89069a73f181a0d841b53a4516798e80af6efc5e/coverage-6.4.tar.gz" + }, + { + "algorithm": "sha256", + "hash": "fb45fe08e1abc64eb836d187b20a59172053999823f7f6ef4f18a819c44ba16f", + "url": "https://files.pythonhosted.org/packages/d3/5b/fe12a83e9671cb9c5a7c9e623735c121cd80ca1e049fd049b371dd704f4f/coverage-6.4-cp310-cp310-win32.whl" + }, + { + "algorithm": "sha256", + "hash": "750e13834b597eeb8ae6e72aa58d1d831b96beec5ad1d04479ae3772373a8088", + "url": "https://files.pythonhosted.org/packages/ea/1a/1356f2fe9dbe2f7b61dfecaef298a36503a9392fe961959e0ec6558f0155/coverage-6.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl" + } + ], + "project_name": "coverage", + "requires_dists": [ + "tomli; python_version < \"3.11\" and extra == \"toml\"" + ], + "requires_python": ">=3.7", + "version": "6.4" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "f7353ba3367473d1d616ee727945f439e027f0bb16ac1a750219a8344d1d5d3c", + "url": "https://files.pythonhosted.org/packages/a0/fa/7e6e4cbd0911966ca52846deee74b6ef9b138c45765bdb0f7242f14688e4/frozenlist-1.3.0-cp310-cp310-win_amd64.whl" + }, + { + "algorithm": "sha256", + "hash": "6eb275c6385dd72594758cbe96c07cdb9bd6becf84235f4a594bdf21e3596c9d", + "url": "https://files.pythonhosted.org/packages/0e/36/c4659bee33cab5ed22b7df23bafc3841a269793ca8e5527822f3fe41b568/frozenlist-1.3.0-cp310-cp310-musllinux_1_1_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "754728d65f1acc61e0f4df784456106e35afb7bf39cfe37227ab00436fb38676", + "url": "https://files.pythonhosted.org/packages/14/36/9a396760b7d1a48efe3520e994064401b36dfa9286e5b5e5bfb5bde16db7/frozenlist-1.3.0-cp310-cp310-musllinux_1_1_s390x.whl" + }, + { + "algorithm": "sha256", + "hash": "e30b2f9683812eb30cf3f0a8e9f79f8d590a7999f731cf39f9105a7c4a39489d", + "url": "https://files.pythonhosted.org/packages/24/1c/076b1a5a0b8b4af0bae5f999eaf0e3deaa25eb08fe195cdc3e628e41c279/frozenlist-1.3.0-cp310-cp310-win32.whl" + }, + { + "algorithm": "sha256", + "hash": 
"4a44ebbf601d7bac77976d429e9bdb5a4614f9f4027777f9e54fd765196e9d3b", + "url": "https://files.pythonhosted.org/packages/29/03/a300b151ecb1cf78c4fe404978ffbdb719eed810a1606e6afc8ae8f16837/frozenlist-1.3.0-cp310-cp310-macosx_10_9_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "45334234ec30fc4ea677f43171b18a27505bfb2dba9aca4398a62692c0ea8868", + "url": "https://files.pythonhosted.org/packages/32/61/b322998b806633b7df19d614916600d00439099dbb030a623eeb0694304e/frozenlist-1.3.0-cp310-cp310-macosx_11_0_arm64.whl" + }, + { + "algorithm": "sha256", + "hash": "04cb491c4b1c051734d41ea2552fde292f5f3a9c911363f74f39c23659c4af78", + "url": "https://files.pythonhosted.org/packages/3f/9e/991076d645ddfff334ace95b9386daef81cc144676c7f0057938f29ffa48/frozenlist-1.3.0-cp310-cp310-musllinux_1_1_i686.whl" + }, + { + "algorithm": "sha256", + "hash": "436496321dad302b8b27ca955364a439ed1f0999311c393dccb243e451ff66aa", + "url": "https://files.pythonhosted.org/packages/49/22/cb44c4c4671c55fc2ecf0727496d466390315f705ec3f0b0c7aeb5658a50/frozenlist-1.3.0-cp310-cp310-musllinux_1_1_ppc64le.whl" + }, + { + "algorithm": "sha256", + "hash": "b9e3e9e365991f8cc5f5edc1fd65b58b41d0514a6a7ad95ef5c7f34eb49b3d3e", + "url": "https://files.pythonhosted.org/packages/4c/4e/0a153040dc966105dc99ccb597358d30a9bbda4a13aa753d0f382eced4fb/frozenlist-1.3.0-cp310-cp310-musllinux_1_1_aarch64.whl" + }, + { + "algorithm": "sha256", + "hash": "6a202458d1298ced3768f5a7d44301e7c86defac162ace0ab7434c2e961166e8", + "url": "https://files.pythonhosted.org/packages/5d/98/10edca86eb789469648049d0f8ea0b5bd74f5a3e11064ae620095db8595e/frozenlist-1.3.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "691ddf6dc50480ce49f68441f1d16a4c3325887453837036e0fb94736eae1e58", + "url": "https://files.pythonhosted.org/packages/71/46/d96b08a7f84bf77a7e4a5238bfabd7a1c34b2c1617476c69445668de7923/frozenlist-1.3.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl" + }, + { + "algorithm": "sha256", + "hash": "47be22dc27ed933d55ee55845d34a3e4e9f6fee93039e7f8ebadb0c2f60d403f", + "url": "https://files.pythonhosted.org/packages/79/58/3a0a77a6be2c368f8e52f4aeba0016bb3a040c9a43553b901bc0e969f54f/frozenlist-1.3.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" + }, + { + "algorithm": "sha256", + "hash": "bde99812f237f79eaf3f04ebffd74f6718bbd216101b35ac7955c2d47c17da02", + "url": "https://files.pythonhosted.org/packages/b3/ac/ac631cdb022ddcf199305c03e45b3234aaab79e00663c4d96dacc39013d9/frozenlist-1.3.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl" + }, + { + "algorithm": "sha256", + "hash": "03a7dd1bfce30216a3f51a84e6dd0e4a573d23ca50f0346634916ff105ba6e6b", + "url": "https://files.pythonhosted.org/packages/cd/e5/c813ed0b4efa409ba74eb001f552243d4cb8d180723745f04a92340cc3fe/frozenlist-1.3.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl" + }, + { + "algorithm": "sha256", + "hash": "d2257aaba9660f78c7b1d8fea963b68f3feffb1a9d5d05a18401ca9eb3e8d0a3", + "url": "https://files.pythonhosted.org/packages/e8/28/da4e60e30dad3638570db89f9d6be26ae1f3e183607629b48cd5e35b1c81/frozenlist-1.3.0-cp310-cp310-macosx_10_9_universal2.whl" + }, + { + "algorithm": "sha256", + "hash": "ce6f2ba0edb7b0c1d8976565298ad2deba6f8064d2bebb6ffce2ca896eb35b0b", + "url": "https://files.pythonhosted.org/packages/f4/f7/8dfeb76d2a52bcea2b0718427af954ffec98be1d34cd8f282034b3e36829/frozenlist-1.3.0.tar.gz" + } + ], + "project_name": "frozenlist", 
+ "requires_dists": [], + "requires_python": ">=3.7", + "version": "1.3" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "84d9dd047ffa80596e0f246e2eab0b391788b0503584e8945f2368256d2735ff", + "url": "https://files.pythonhosted.org/packages/04/a2/d918dcd22354d8958fe113e1a3630137e0fc8b44859ade3063982eacd2a4/idna-3.3-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "9d643ff0a55b762d5cdb124b8eaa99c66322e2157b69160bc32796e824360e6d", + "url": "https://files.pythonhosted.org/packages/62/08/e3fc7c8161090f742f504f40b1bccbfc544d4a4e09eb774bf40aafce5436/idna-3.3.tar.gz" + } + ], + "project_name": "idna", + "requires_dists": [], + "requires_python": ">=3.5", + "version": "3.3" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3", + "url": "https://files.pythonhosted.org/packages/9b/dd/b3c12c6d707058fa947864b67f0c4e0c39ef8610988d7baea9578f3c48f3/iniconfig-1.1.1-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32", + "url": "https://files.pythonhosted.org/packages/23/a2/97899f6bd0e873fed3a7e67ae8d3a08b21799430fb4da15cfedf10d6e2c2/iniconfig-1.1.1.tar.gz" + } + ], + "project_name": "iniconfig", + "requires_dists": [], + "requires_python": null, + "version": "1.1.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "8cbf0132f3de7cc6c6ce00147cc78e6439ea736cee6bca4f068bcf892b0fd658", + "url": "https://files.pythonhosted.org/packages/72/e4/9ea1c573503ddf11ea56c48e9af49660fbd45a13ceb394a48e437c32eba9/multidict-6.0.2-cp310-cp310-win_amd64.whl" + }, + { + "algorithm": "sha256", + "hash": "041b81a5f6b38244b34dc18c7b6aba91f9cdaf854d9a39e5ff0b58e2b5773b9c", + "url": "https://files.pythonhosted.org/packages/14/7b/d11a6dec8996ca054e727f7d3b1578753b44ba9e378c9449404aef076b47/multidict-6.0.2-cp310-cp310-macosx_11_0_arm64.whl" + }, + { + "algorithm": "sha256", + "hash": "0b9e95a740109c6047602f4db4da9949e6c5945cefbad34a1299775ddc9a62e2", + "url": "https://files.pythonhosted.org/packages/1d/35/0ea9ce0cc0aeb3b4c898595d807ac80ebbd295efefabc80c4f6c6bee8106/multidict-6.0.2-cp310-cp310-macosx_10_9_universal2.whl" + }, + { + "algorithm": "sha256", + "hash": "fcb91630817aa8b9bc4a74023e4198480587269c272c58b3279875ed7235c293", + "url": "https://files.pythonhosted.org/packages/23/31/c8736506ae534e20c8f0b1b090bc2ad89349d96e5e7c5928464c6c876599/multidict-6.0.2-cp310-cp310-win32.whl" + }, + { + "algorithm": "sha256", + "hash": "50bd442726e288e884f7be9071016c15a8742eb689a593a0cac49ea093eef0a7", + "url": "https://files.pythonhosted.org/packages/2a/c2/0f63e839b93a68dd2bcfbf30cc35dbdb4b172ad0078e32176628ec7d91d5/multidict-6.0.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl" + }, + { + "algorithm": "sha256", + "hash": "0556a1d4ea2d949efe5fd76a09b4a82e3a4a30700553a6725535098d8d9fb672", + "url": "https://files.pythonhosted.org/packages/31/b1/eb1a8cdb3bb177929dfee9543c0fd8074768c9e4431c7b3da7d01a3c66d8/multidict-6.0.2-cp310-cp310-musllinux_1_1_i686.whl" + }, + { + "algorithm": "sha256", + "hash": "8064b7c6f0af936a741ea1efd18690bacfbae4078c0c385d7c3f611d11f0cf87", + "url": "https://files.pythonhosted.org/packages/3f/44/83e4bd573cc80c41896394129f162b69fe1ed9fd7a99ca4153740e20349c/multidict-6.0.2-cp310-cp310-musllinux_1_1_s390x.whl" + }, + { + "algorithm": "sha256", + "hash": "2d36e929d7f6a16d4eb11b250719c39560dd70545356365b494249e2186bc389", + "url": 
"https://files.pythonhosted.org/packages/69/d7/c49e9ca438846658191905f5df53a895738b478cdca98580f092b557802c/multidict-6.0.2-cp310-cp310-musllinux_1_1_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "f4f052ee022928d34fe1f4d2bc743f32609fb79ed9c49a1710a5ad6b2198db20", + "url": "https://files.pythonhosted.org/packages/7e/21/73f8a51219fd9b4b04badcc7933ce5f5344ab33308492755220524bc4faf/multidict-6.0.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl" + }, + { + "algorithm": "sha256", + "hash": "47e6a7e923e9cada7c139531feac59448f1f47727a79076c0b1ee80274cd8eee", + "url": "https://files.pythonhosted.org/packages/9b/a4/a8d3c6bb884d97fd1e9d37c5c9a8c46de799d7465e455b617f33dfbb52ba/multidict-6.0.2-cp310-cp310-musllinux_1_1_aarch64.whl" + }, + { + "algorithm": "sha256", + "hash": "3368bf2398b0e0fcbf46d85795adc4c259299fec50c1416d0f77c0a843a3eed9", + "url": "https://files.pythonhosted.org/packages/bf/b9/b8c9845853b7086476201ff18bcff5a169e945c5d8397e234ba4453a38d4/multidict-6.0.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl" + }, + { + "algorithm": "sha256", + "hash": "626fe10ac87851f4cffecee161fc6f8f9853f0f6f1035b59337a51d29ff3b4f9", + "url": "https://files.pythonhosted.org/packages/ce/b3/7b2ed0a1fca198da0e6354ccd0358757c12b56f204c179271cf81a7372ae/multidict-6.0.2-cp310-cp310-musllinux_1_1_ppc64le.whl" + }, + { + "algorithm": "sha256", + "hash": "ac0e27844758d7177989ce406acc6a83c16ed4524ebc363c1f748cba184d89d3", + "url": "https://files.pythonhosted.org/packages/d2/67/ef1ef8f3539642d90c77bc7c86cc7283297cd2ab100b45d7541476ef641e/multidict-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "225383a6603c086e6cef0f2f05564acb4f4d5f019a4e3e983f572b8530f70c88", + "url": "https://files.pythonhosted.org/packages/df/93/34efbfa7aa778b04b365960f52f7071d7942ce386572aac8940ae032dd48/multidict-6.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "5fdda29a3c7e76a064f2477c9aab1ba96fd94e02e386f1e665bca1807fc5386f", + "url": "https://files.pythonhosted.org/packages/ee/a1/a7cc44b7ed84e430c2c176420ffa432a74a2432f7df4f71988365fa8772a/multidict-6.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" + }, + { + "algorithm": "sha256", + "hash": "5ff3bd75f38e4c43f1f470f2df7a4d430b821c4ce22be384e1459cb57d6bb013", + "url": "https://files.pythonhosted.org/packages/fa/a7/71c253cdb8a1528802bac7503bf82fe674367e4055b09c28846fdfa4ab90/multidict-6.0.2.tar.gz" + } + ], + "project_name": "multidict", + "requires_dists": [], + "requires_python": ">=3.7", + "version": "6.0.2" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "ef103e05f519cdc783ae24ea4e2e0f508a9c99b2d4969652eed6a2e1ea5bd522", + "url": "https://files.pythonhosted.org/packages/05/8e/8de486cbd03baba4deef4142bd643a3e7bbe954a784dc1bb17142572d127/packaging-21.3-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb", + "url": "https://files.pythonhosted.org/packages/df/9e/d1a7217f69310c1db8fdf8ab396229f55a699ce34a203691794c5d1cad0c/packaging-21.3.tar.gz" + } + ], + "project_name": "packaging", + "requires_dists": [ + "pyparsing!=3.0.5,>=2.0.2" + ], + "requires_python": ">=3.6", + "version": "21.3" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3", + "url": 
"https://files.pythonhosted.org/packages/9e/01/f38e2ff29715251cf25532b9082a1589ab7e4f571ced434f98d0139336dc/pluggy-1.0.0-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159", + "url": "https://files.pythonhosted.org/packages/a1/16/db2d7de3474b6e37cbb9c008965ee63835bba517e22cdb8c35b5116b5ce1/pluggy-1.0.0.tar.gz" + } + ], + "project_name": "pluggy", + "requires_dists": [ + "importlib-metadata>=0.12; python_version < \"3.8\"", + "pre-commit; extra == \"dev\"", + "pytest-benchmark; extra == \"testing\"", + "pytest; extra == \"testing\"", + "tox; extra == \"dev\"" + ], + "requires_python": ">=3.6", + "version": "1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378", + "url": "https://files.pythonhosted.org/packages/f6/f0/10642828a8dfb741e5f3fbaac830550a518a775c7fff6f04a007259b0548/py-1.11.0-py2.py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719", + "url": "https://files.pythonhosted.org/packages/98/ff/fec109ceb715d2a6b4c4a85a61af3b40c723a961e8828319fbcb15b868dc/py-1.11.0.tar.gz" + } + ], + "project_name": "py", + "requires_dists": [], + "requires_python": "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7", + "version": "1.11" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "5026bae9a10eeaefb61dab2f09052b9f4307d44aee4eda64b309723d8d206bbc", + "url": "https://files.pythonhosted.org/packages/6c/10/a7d0fa5baea8fe7b50f448ab742f26f52b80bfca85ac2be9d35cdd9a3246/pyparsing-3.0.9-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "2b020ecf7d21b687f219b71ecad3631f644a47f01403fa1d1036b0c6416d70fb", + "url": "https://files.pythonhosted.org/packages/71/22/207523d16464c40a0310d2d4d8926daffa00ac1f5b1576170a32db749636/pyparsing-3.0.9.tar.gz" + } + ], + "project_name": "pyparsing", + "requires_dists": [ + "jinja2; extra == \"diagrams\"", + "railroad-diagrams; extra == \"diagrams\"" + ], + "requires_python": ">=3.6.8", + "version": "3.0.9" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "13d0e3ccfc2b6e26be000cb6568c832ba67ba32e719443bfe725814d3c42433c", + "url": "https://files.pythonhosted.org/packages/fb/d0/bae533985f2338c5d02184b4a7083b819f6b3fc101da792e0d96e6e5299d/pytest-7.1.2-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "a06a0425453864a270bc45e71f783330a7428defb4230fb5e6a731fde06ecd45", + "url": "https://files.pythonhosted.org/packages/4e/1f/34657c6ac56f3c58df650ba41f8ffb2620281ead8e11bcdc7db63cf72a78/pytest-7.1.2.tar.gz" + } + ], + "project_name": "pytest", + "requires_dists": [ + "argcomplete; extra == \"testing\"", + "atomicwrites>=1.0; sys_platform == \"win32\"", + "attrs>=19.2.0", + "colorama; sys_platform == \"win32\"", + "hypothesis>=3.56; extra == \"testing\"", + "importlib-metadata>=0.12; python_version < \"3.8\"", + "iniconfig", + "mock; extra == \"testing\"", + "nose; extra == \"testing\"", + "packaging", + "pluggy<2.0,>=0.12", + "py>=1.8.2", + "pygments>=2.7.2; extra == \"testing\"", + "requests; extra == \"testing\"", + "tomli>=1.0.0", + "xmlschema; extra == \"testing\"" + ], + "requires_python": ">=3.7", + "version": "7.1.2" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "1d2dc3a304c2be1fd496c0c2fb6b31ab60cd9fc33984f761f951f8ea1eb4ca95", + "url": 
"https://files.pythonhosted.org/packages/02/ee/871f5d1833e7a3f2325796ab9509de52fd934ca71a37cffbea5c3b6d7ecf/pytest_aiohttp-1.0.4-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "39ff3a0d15484c01d1436cbedad575c6eafbf0f57cdf76fb94994c97b5b8c5a4", + "url": "https://files.pythonhosted.org/packages/11/fa/64b1bbc2514c934fd8cd251cc91ba38faa533c3fbbab5b7cf17d54b05e22/pytest-aiohttp-1.0.4.tar.gz" + } + ], + "project_name": "pytest-aiohttp", + "requires_dists": [ + "aiohttp>=3.8.1", + "coverage==6.2; extra == \"testing\"", + "mypy==0.931; extra == \"testing\"", + "pytest-asyncio>=0.17.2", + "pytest>=6.1.0" + ], + "requires_python": ">=3.7", + "version": "1.0.4" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "16cf40bdf2b4fb7fc8e4b82bd05ce3fbcd454cbf7b92afc445fe299dabb88213", + "url": "https://files.pythonhosted.org/packages/8b/d6/4ecdd0c5b49a2209131b6af78baa643cec35f213abbc54d0eb1542b3786d/pytest_asyncio-0.18.3-1-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "7659bdb0a9eb9c6e3ef992eef11a2b3e69697800ad02fb06374a210d85b29f91", + "url": "https://files.pythonhosted.org/packages/4d/73/769d29676fb36a36e5a57c198154171081aabcfd08112a24a4e3fb5c9f10/pytest-asyncio-0.18.3.tar.gz" + }, + { + "algorithm": "sha256", + "hash": "8fafa6c52161addfd41ee7ab35f11836c5a16ec208f93ee388f752bea3493a84", + "url": "https://files.pythonhosted.org/packages/ac/4b/7c400506ec484ec999b10133aa8e31af39dfc727042dc6944cd45fd927d0/pytest_asyncio-0.18.3-py3-none-any.whl" + } + ], + "project_name": "pytest-asyncio", + "requires_dists": [ + "coverage==6.2; extra == \"testing\"", + "flaky>=3.5.0; extra == \"testing\"", + "hypothesis>=5.7.1; extra == \"testing\"", + "mypy==0.931; extra == \"testing\"", + "pytest-trio>=0.7.0; extra == \"testing\"", + "pytest>=6.1.0", + "typing-extensions>=3.7.2; python_version < \"3.8\"" + ], + "requires_python": ">=3.7", + "version": "0.18.3" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "578d5d15ac4a25e5f961c938b85a05b09fdaae9deef3bb6de9a6e766622ca7a6", + "url": "https://files.pythonhosted.org/packages/20/49/b3e0edec68d81846f519c602ac38af9db86e1e71275528b3e814ae236063/pytest_cov-3.0.0-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "e7f0f5b1617d2210a2cabc266dfe2f4c75a8d32fb89eafb7ad9d06f6d076d470", + "url": "https://files.pythonhosted.org/packages/61/41/e046526849972555928a6d31c2068410e47a31fb5ab0a77f868596811329/pytest-cov-3.0.0.tar.gz" + } + ], + "project_name": "pytest-cov", + "requires_dists": [ + "coverage[toml]>=5.2.1", + "fields; extra == \"testing\"", + "hunter; extra == \"testing\"", + "process-tests; extra == \"testing\"", + "pytest-xdist; extra == \"testing\"", + "pytest>=4.6", + "six; extra == \"testing\"", + "virtualenv; extra == \"testing\"" + ], + "requires_python": ">=3.6", + "version": "3" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "c2a892906192663f85030a6ab91304e508e546cddfe557d692d61ec57a1d946b", + "url": "https://files.pythonhosted.org/packages/69/6d/cfd6d654877f75e0368e4040f1cf0350dd9f427b578bf7b685af629f8167/pytest-dependency-0.5.1.tar.gz" + } + ], + "project_name": "pytest-dependency", + "requires_dists": [ + "pytest>=3.6.0" + ], + "requires_python": null, + "version": "0.5.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "6cff27cec936bf81dc5ee87f07132b807bcda51106b5ec4b90a04331cba76231", + "url": "https://files.pythonhosted.org/packages/11/40/8fcb3c0f72e11dc44e1102b2adf5f160b8a00e84d915798c60aabcd9257a/pytest_mock-3.7.0-py3-none-any.whl" + }, + { + "algorithm": 
"sha256", + "hash": "5112bd92cc9f186ee96e1a92efc84969ea494939c3aead39c50f421c4cc69534", + "url": "https://files.pythonhosted.org/packages/96/e1/fb53b62056e6840a36d9a4beb4e42726155594c567b574103435a7131c60/pytest-mock-3.7.0.tar.gz" + } + ], + "project_name": "pytest-mock", + "requires_dists": [ + "pre-commit; extra == \"dev\"", + "pytest-asyncio; extra == \"dev\"", + "pytest>=5.0", + "tox; extra == \"dev\"" + ], + "requires_python": ">=3.7", + "version": "3.7" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc", + "url": "https://files.pythonhosted.org/packages/97/75/10a9ebee3fd790d20926a90a2547f0bf78f371b2f13aa822c759680ca7b9/tomli-2.0.1-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f", + "url": "https://files.pythonhosted.org/packages/c0/3f/d7af728f075fb08564c5949a9c95e44352e23dee646869fa104a3b2060a3/tomli-2.0.1.tar.gz" + } + ], + "project_name": "tomli", + "requires_dists": [], + "requires_python": ">=3.7", + "version": "2.0.1" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "c9c6d927e098c2d360695f2e9d38870b2e92e0919be07dbe339aefa32a090265", + "url": "https://files.pythonhosted.org/packages/7c/ad/bf6dfc6521394aa7d0b3ecbdf5e2b272fd1e79d585107869e75f0e283245/yarl-1.7.2-cp310-cp310-win_amd64.whl" + }, + { + "algorithm": "sha256", + "hash": "cff3ba513db55cc6a35076f32c4cdc27032bd075c9faef31fec749e64b45d26c", + "url": "https://files.pythonhosted.org/packages/1a/09/a9b4fc484f562297158ad03f6db123f9e1f39424a969599ca0b6cbe5367f/yarl-1.7.2-cp310-cp310-win32.whl" + }, + { + "algorithm": "sha256", + "hash": "167ab7f64e409e9bdd99333fe8c67b5574a1f0495dcfd905bc7454e766729b9e", + "url": "https://files.pythonhosted.org/packages/48/2d/3992de6e80cacc12b51f3cb690590a5a834f9ac2022c88e9ac0d3b293c77/yarl-1.7.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "f2a8508f7350512434e41065684076f640ecce176d262a7d54f0da41d99c5a95", + "url": "https://files.pythonhosted.org/packages/4e/a5/edfa475dc2138da03cc7561b4fbfb26c2bb18c1f41a99333adb28a9a90e5/yarl-1.7.2-cp310-cp310-macosx_10_9_universal2.whl" + }, + { + "algorithm": "sha256", + "hash": "6152224d0a1eb254f97df3997d79dadd8bb2c1a02ef283dbb34b97d4f8492d23", + "url": "https://files.pythonhosted.org/packages/69/4d/a64f3371ff9e599aa738699a539d6391cea226299b28a922900b3e5a2bd1/yarl-1.7.2-cp310-cp310-musllinux_1_1_i686.whl" + }, + { + "algorithm": "sha256", + "hash": "95a1873b6c0dd1c437fb3bb4a4aaa699a48c218ac7ca1e74b0bee0ab16c7d60d", + "url": "https://files.pythonhosted.org/packages/90/6c/23b7bba775522b819b2b6616aa83fd1f4577fea3e7c6ed0a862df1aeb855/yarl-1.7.2-cp310-cp310-musllinux_1_1_aarch64.whl" + }, + { + "algorithm": "sha256", + "hash": "a1d0894f238763717bdcfea74558c94e3bc34aeacd3351d769460c1a586a8b05", + "url": "https://files.pythonhosted.org/packages/94/d3/434dca72103d1280dd3e1281f501fb5e6ad0eb6c18ae92ca8d43fb8c2fa7/yarl-1.7.2-cp310-cp310-macosx_11_0_arm64.whl" + }, + { + "algorithm": "sha256", + "hash": "1d3d5ad8ea96bd6d643d80c7b8d5977b4e2fb1bab6c9da7322616fd26203d125", + "url": "https://files.pythonhosted.org/packages/a9/3a/19cb4d33a7b3e81d2a3663803c59a7365bf4694077823c3d1ff2f82a2481/yarl-1.7.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl" + }, + { + "algorithm": "sha256", + "hash": "da6df107b9ccfe52d3a48165e48d72db0eca3e3029b5b8cb4fe6ee3cb870ba8b", + 
"url": "https://files.pythonhosted.org/packages/b8/43/bd158143b6facbd309fd0b10a21b9546f455db6f851be6911e6b25c40c47/yarl-1.7.2-cp310-cp310-macosx_10_9_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "f44477ae29025d8ea87ec308539f95963ffdc31a82f42ca9deecf2d505242e72", + "url": "https://files.pythonhosted.org/packages/bc/4a/a6f020c4be2654bf8d375731fcacfdcfd1d2f5fd0c48c8dfebb6ec14a84b/yarl-1.7.2-cp310-cp310-musllinux_1_1_x86_64.whl" + }, + { + "algorithm": "sha256", + "hash": "1ca56f002eaf7998b5fcf73b2421790da9d2586331805f38acd9997743114e98", + "url": "https://files.pythonhosted.org/packages/d0/5f/0410c8c038e626b8732db53bf7ca2b5deb2b1ac8b4a4659763890a61a43c/yarl-1.7.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl" + }, + { + "algorithm": "sha256", + "hash": "5bb7d54b8f61ba6eee541fba4b83d22b8a046b4ef4d8eb7f15a7e35db2e1e245", + "url": "https://files.pythonhosted.org/packages/d8/71/c3b593ccef94111a41aed0cf068be3a5f0e331eb1ff9ea538d21b523e6f4/yarl-1.7.2-cp310-cp310-musllinux_1_1_ppc64le.whl" + }, + { + "algorithm": "sha256", + "hash": "dfe4b95b7e00c6635a72e2d00b478e8a28bfb122dc76349a06e20792eb53a523", + "url": "https://files.pythonhosted.org/packages/db/c7/6f0ae227ea247012055daf4856a8cd85d690f0b18480c54da0b919d2beba/yarl-1.7.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" + }, + { + "algorithm": "sha256", + "hash": "c145ab54702334c42237a6c6c4cc08703b6aa9b94e2f227ceb3d477d20c36c63", + "url": "https://files.pythonhosted.org/packages/e8/ce/920cebfb0fef407eae4d21b37be949d9c4e47671bb9d7271dd8203cd55d8/yarl-1.7.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl" + }, + { + "algorithm": "sha256", + "hash": "9c1f083e7e71b2dd01f7cd7434a5f88c15213194df38bc29b388ccdf1492b739", + "url": "https://files.pythonhosted.org/packages/f2/0b/b897521eb6367f97f452bb6313d99e3653f93e5e62b53c60c865c4bc23b0/yarl-1.7.2-cp310-cp310-musllinux_1_1_s390x.whl" + }, + { + "algorithm": "sha256", + "hash": "45399b46d60c253327a460e99856752009fcee5f5d3c80b2f7c0cae1c38d56dd", + "url": "https://files.pythonhosted.org/packages/f6/da/46d1b3d69a9a0835dabf9d59c7eb0f1600599edd421a4c5a15ab09f527e0/yarl-1.7.2.tar.gz" + } + ], + "project_name": "yarl", + "requires_dists": [ + "idna>=2.0", + "multidict>=4.0", + "typing-extensions>=3.7.4; python_version < \"3.8\"" + ], + "requires_python": ">=3.6", + "version": "1.7.2" + } + ], + "platform_tag": [ + "cp310", + "cp310", + "manylinux_2_31_aarch64" + ] + } + ], + "path_mappings": {}, + "pex_version": "2.1.84", + "prefer_older_binary": false, + "requirements": [ + "aioresponses>=0.7.3", + "pytest-aiohttp>=1.0.4", + "pytest-asyncio>=0.18", + "pytest-cov!=2.12.1,<3.1,>=2.12", + "pytest-dependency>=0.5.1", + "pytest-mock>=3.5.0", + "pytest>=7.0" + ], + "requires_python": [ + "==3.10.4" + ], + "resolver_version": "pip-2020-resolver", + "style": "universal", + "transitive": true, + "use_pep517": null +} \ No newline at end of file