diff --git a/.github/checkgroup.yml b/.github/checkgroup.yml
index eeb707267184e..b6fd22af0fbce 100644
--- a/.github/checkgroup.yml
+++ b/.github/checkgroup.yml
@@ -6,7 +6,7 @@ subprojects:
- id: "pytorch_lightning: Tests workflow"
paths:
- ".actions/**"
- - ".github/workflows/ci-pytorch-tests.yml"
+ - ".github/workflows/ci-tests-pytorch.yml"
- "requirements/lite/**"
- "src/lightning_lite/**"
- "requirements/pytorch/**"
@@ -178,7 +178,7 @@ subprojects:
- "src/lightning_lite/**"
- "tests/tests_lite/**"
- "setup.cfg" # includes pytest config
- - ".github/workflows/ci-lite-tests.yml"
+ - ".github/workflows/ci-tests-lite.yml"
- "!requirements/*/docs.txt"
- "!*.md"
- "!**/*.md"
@@ -221,7 +221,7 @@ subprojects:
- id: "lightning_app: Tests workflow"
paths:
- ".actions/**"
- - ".github/workflows/ci-app-tests.yml"
+ - ".github/workflows/ci-tests-app.yml"
- "src/lightning_app/**"
- "tests/tests_app/**"
- "requirements/app/**"
@@ -243,7 +243,7 @@ subprojects:
- id: "lightning_app: Examples"
paths:
- ".actions/**"
- - ".github/workflows/ci-app-examples.yml"
+ - ".github/workflows/ci-examples-app.yml"
- "src/lightning_app/**"
- "tests/tests_examples_app/**"
- "examples/app_*/**"
diff --git a/.github/workflows/README.md b/.github/workflows/README.md
index 3437dd03e6d50..9f3d7a05584b7 100644
--- a/.github/workflows/README.md
+++ b/.github/workflows/README.md
@@ -4,10 +4,10 @@
## Unit and Integration Testing
-| workflow name | workflow file | action | accelerator\* |
-| -------------------------- | ------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------- |
-| Test PyTorch full | .github/workflows/ci-pytorch-tests.yml | Run all tests except for accelerator-specific, standalone and slow tests. | CPU |
-| Test PyTorch slow | .github/workflows/ci-pytorch-tests-slow.yml | Run only slow tests. Slow tests usually need to spawn threads and cannot be speed up or simplified. | CPU |
+| workflow name | workflow file | action | accelerator\* |
+| ----------------- | -------------------------------------- | ------------------------------------------------------------------------- | ------------- |
+| Test PyTorch full | .github/workflows/ci-tests-pytorch.yml | Run all tests except for accelerator-specific, standalone and slow tests. | CPU |
| pytorch-lightning (IPUs) | .azure-pipelines/ipu-tests.yml | Run only IPU-specific tests. | IPU |
| pytorch-lightning (HPUs) | .azure-pipelines/hpu-tests.yml | Run only HPU-specific tests. | HPU |
| pytorch-lightning (GPUs) | .azure-pipelines/gpu-tests-pytorch.yml | Run all CPU and GPU-specific tests, standalone, and examples. Each standalone test needs to be run in separate processes to avoid unwanted interactions between test cases. | GPU |
diff --git a/.github/workflows/ci-pytorch-dockers.yml b/.github/workflows/ci-dockers-pytorch.yml
similarity index 100%
rename from .github/workflows/ci-pytorch-dockers.yml
rename to .github/workflows/ci-dockers-pytorch.yml
diff --git a/.github/workflows/ci-app-examples.yml b/.github/workflows/ci-examples-app.yml
similarity index 99%
rename from .github/workflows/ci-app-examples.yml
rename to .github/workflows/ci-examples-app.yml
index 201f3981f2619..e0713eef896ad 100644
--- a/.github/workflows/ci-app-examples.yml
+++ b/.github/workflows/ci-examples-app.yml
@@ -9,7 +9,7 @@ on:
types: [opened, reopened, ready_for_review, synchronize] # added `ready_for_review` since draft is skipped
paths:
- ".actions/**"
- - ".github/workflows/ci-app-examples.yml"
+ - ".github/workflows/ci-examples-app.yml"
- "src/lightning_app/**"
- "tests/tests_examples_app/**"
- "examples/app_*/**"
diff --git a/.github/workflows/ci-app-tests.yml b/.github/workflows/ci-tests-app.yml
similarity index 99%
rename from .github/workflows/ci-app-tests.yml
rename to .github/workflows/ci-tests-app.yml
index 1a082f69c0a1d..c07cb6faeed38 100644
--- a/.github/workflows/ci-app-tests.yml
+++ b/.github/workflows/ci-tests-app.yml
@@ -9,7 +9,7 @@ on:
types: [opened, reopened, ready_for_review, synchronize] # added `ready_for_review` since draft is skipped
paths:
- ".actions/**"
- - ".github/workflows/ci-app-tests.yml"
+ - ".github/workflows/ci-tests-app.yml"
- "src/lightning_app/**"
- "tests/tests_app/**"
- "requirements/app/**"
diff --git a/.github/workflows/ci-lite-tests.yml b/.github/workflows/ci-tests-lite.yml
similarity index 99%
rename from .github/workflows/ci-lite-tests.yml
rename to .github/workflows/ci-tests-lite.yml
index 1db82fe8ba52c..8d38f3247b529 100644
--- a/.github/workflows/ci-lite-tests.yml
+++ b/.github/workflows/ci-tests-lite.yml
@@ -13,7 +13,7 @@ on:
- "src/lightning_lite/**"
- "tests/tests_lite/**"
- "setup.cfg" # includes pytest config
- - ".github/workflows/ci-lite-tests.yml"
+ - ".github/workflows/ci-tests-lite.yml"
- "!requirements/*/docs.txt"
- "!*.md"
- "!**/*.md"
diff --git a/.github/workflows/ci-pytorch-tests.yml b/.github/workflows/ci-tests-pytorch.yml
similarity index 99%
rename from .github/workflows/ci-pytorch-tests.yml
rename to .github/workflows/ci-tests-pytorch.yml
index 34ef2b0834949..fd6692c69a459 100644
--- a/.github/workflows/ci-pytorch-tests.yml
+++ b/.github/workflows/ci-tests-pytorch.yml
@@ -14,7 +14,7 @@ on:
- "tests/tests_pytorch/**"
- "tests/legacy/back-compatible-versions.txt"
- "setup.cfg" # includes pytest config
- - ".github/workflows/ci-pytorch-tests.yml"
+ - ".github/workflows/ci-tests-pytorch.yml"
- "requirements/lite/**"
- "src/lightning_lite/**"
- "!requirements/pytorch/docs.txt"
diff --git a/.github/workflows/probot-check-group.yml b/.github/workflows/probot-check-group.yml
index 47a60061cc8a3..c9b0efdd9b2d8 100644
--- a/.github/workflows/probot-check-group.yml
+++ b/.github/workflows/probot-check-group.yml
@@ -14,7 +14,7 @@ jobs:
if: github.event.pull_request.draft == false
timeout-minutes: 61 # in case something is wrong with the internal timeout
steps:
- - uses: Lightning-AI/probot@v5.1
+ - uses: Lightning-AI/probot@v5.3
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
diff --git a/.gitignore b/.gitignore
index 054a5ba16aff5..835a98a19efd9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -109,8 +109,8 @@ celerybeat-schedule
# dotenv
.env
-.env_staging
-.env_local
+.env.staging
+.env.local
# virtualenv
.venv
diff --git a/README.md b/README.md
index 28e588a52145c..19c618122a8a8 100644
--- a/README.md
+++ b/README.md
@@ -90,15 +90,15 @@ Lightning is rigorously tested across multiple CPUs, GPUs, TPUs, IPUs, and HPUs
-| System / PyTorch ver. | 1.9 | 1.10 | 1.12 (latest) |
-| :------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
-| Linux py3.7 \[GPUs\*\*\] | - | - | - |
-| Linux py3.7 \[TPUs\*\*\*\] | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/tpu-tests.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/tpu-tests.yml) | - | - |
-| Linux py3.8 \[IPUs\] | [![Build Status]()](https://dev.azure.com/Lightning-AI/lightning/_build/latest?definitionId=25&branchName=master) | - | - |
-| Linux py3.8 \[HPUs\] | - | [![Build Status]()](https://dev.azure.com/Lightning-AI/lightning/_build/latest?definitionId=26&branchName=master) | - |
-| Linux py3.{7,9} | - | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-tests.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-tests.yml) |
-| OSX py3.{7,9} | - | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-tests.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-tests.yml) |
-| Windows py3.{7,9} | - | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-tests.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-tests.yml) |
+| System / PyTorch ver. | 1.10 | 1.12 |
+| :------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| Linux py3.7 \[GPUs\*\*\] | - | - |
+| Linux py3.7 \[TPUs\*\*\*\] | - | - |
+| Linux py3.8 \[IPUs\] | - | - |
+| Linux py3.8 \[HPUs\] | [![Build Status]()](https://dev.azure.com/Lightning-AI/lightning/_build/latest?definitionId=26&branchName=master) | - |
+| Linux py3.{7,9} | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml) |
+| OSX py3.{7,9} | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml) |
+| Windows py3.{7,9} | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml) |
- _\*\* tests run on two NVIDIA P100_
- _\*\*\* tests run on Google GKE TPUv2/3. TPU py3.7 means we support Colab and Kaggle env._
diff --git a/docs/source-app/api_references.rst b/docs/source-app/api_references.rst
index 2272f7bf13c41..9bb5874b533e4 100644
--- a/docs/source-app/api_references.rst
+++ b/docs/source-app/api_references.rst
@@ -45,7 +45,8 @@ ___________________
~multi_node.lite.LiteMultiNode
~multi_node.pytorch_spawn.PyTorchSpawnMultiNode
~multi_node.trainer.LightningTrainerMultiNode
- ~auto_scaler.AutoScaler
+ ~serve.auto_scaler.AutoScaler
+ ~serve.auto_scaler.ColdStartProxy
----
diff --git a/docs/source-pytorch/index.rst b/docs/source-pytorch/index.rst
index e8803ba147e83..1d40b3e35fc0c 100644
--- a/docs/source-pytorch/index.rst
+++ b/docs/source-pytorch/index.rst
@@ -64,6 +64,8 @@ Conda users
Or read the `advanced install guide `_
+We are fully compatible with any stable PyTorch version v1.10 and above.
+
.. raw:: html
diff --git a/docs/source-pytorch/model/build_model.rst b/docs/source-pytorch/model/build_model.rst
index 8d12110db8053..c480a90b75e87 100644
--- a/docs/source-pytorch/model/build_model.rst
+++ b/docs/source-pytorch/model/build_model.rst
@@ -23,7 +23,7 @@ Build a Model
:header: 2: Validate and test a model
:description: Add a validation and test data split to avoid overfitting.
:col_css: col-md-4
- :button_link: validate_model_basic.html
+ :button_link: ../common/evaluation_basic.html
:height: 150
:tag: basic
diff --git a/examples/app_boring/app.py b/examples/app_boring/app.py
index aad288a11acb4..78a9b1c819f06 100644
--- a/examples/app_boring/app.py
+++ b/examples/app_boring/app.py
@@ -43,6 +43,10 @@ def __init__(self):
raise_exception=True,
)
+ @property
+ def ready(self) -> bool:
+ return self.dest_work.is_running
+
def run(self):
self.source_work.run()
if self.source_work.has_succeeded:
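The `ready` hook added above lets an app signal when it is actually serving. A minimal sketch of the same pattern, assuming a flow that wraps a single long-running work (the class names here are illustrative):

    import lightning as L

    class ServerWork(L.LightningWork):
        def run(self):
            ...  # start a server or other long-running job here

    class RootFlow(L.LightningFlow):
        def __init__(self):
            super().__init__()
            self.server = ServerWork(parallel=True)

        @property
        def ready(self) -> bool:
            # Mirror the app.py change above: report readiness of the underlying work.
            return self.server.is_running

        def run(self):
            self.server.run()

    app = L.LightningApp(RootFlow())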
diff --git a/examples/app_display_name/.lightningignore b/examples/app_display_name/.lightningignore
new file mode 100644
index 0000000000000..f7275bbbd035b
--- /dev/null
+++ b/examples/app_display_name/.lightningignore
@@ -0,0 +1 @@
+venv/
diff --git a/examples/app_display_name/app.py b/examples/app_display_name/app.py
new file mode 100644
index 0000000000000..f06d8ee562fdf
--- /dev/null
+++ b/examples/app_display_name/app.py
@@ -0,0 +1,25 @@
+import lightning as L
+
+
+class Work(L.LightningWork):
+ def __init__(self, start_with_flow=True):
+ super().__init__(start_with_flow=start_with_flow)
+
+ def run(self):
+ pass
+
+
+class Flow(L.LightningFlow):
+ def __init__(self):
+ super().__init__()
+ self.w = Work()
+ self.w1 = Work(start_with_flow=False)
+ self.w.display_name = "My Custom Name" # Not supported yet
+ self.w1.display_name = "My Custom Name 1"
+
+ def run(self):
+ self.w.run()
+ self.w1.run()
+
+
+app = L.LightningApp(Flow())
diff --git a/examples/app_server_with_auto_scaler/app.py b/examples/app_server_with_auto_scaler/app.py
index 70799827776a8..2c8fb744c4fcf 100644
--- a/examples/app_server_with_auto_scaler/app.py
+++ b/examples/app_server_with_auto_scaler/app.py
@@ -1,5 +1,5 @@
# ! pip install torch torchvision
-from typing import Any, List
+from typing import List
import torch
import torchvision
@@ -8,16 +8,12 @@
import lightning as L
-class RequestModel(BaseModel):
- image: str # bytecode
-
-
class BatchRequestModel(BaseModel):
- inputs: List[RequestModel]
+ inputs: List[L.app.components.Image]
class BatchResponse(BaseModel):
- outputs: List[Any]
+ outputs: List[L.app.components.Number]
class PyTorchServer(L.app.components.PythonServer):
@@ -79,10 +75,11 @@ def scale(self, replicas: int, metrics: dict) -> int:
# autoscaler specific args
min_replicas=1,
max_replicas=4,
- autoscale_interval=10,
+ scale_out_interval=10,
+ scale_in_interval=10,
endpoint="predict",
- input_type=RequestModel,
- output_type=Any,
+ input_type=L.app.components.Image,
+ output_type=L.app.components.Number,
timeout_batching=1,
max_batch_size=8,
)
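For reference, a client for the updated example could look like the sketch below. The payload follows the `L.app.components.Image` schema and the response the `Number` schema; the localhost URL and port are assumptions for a local run:

    import base64
    import requests

    # Fetch a sample image and base64-encode it, matching the Image input type.
    url = "https://raw.githubusercontent.com/Lightning-AI/LAI-Triton-Server-Component/main/catimage.png"
    img = base64.b64encode(requests.get(url).content).decode("UTF-8")

    # Assumes the app's load balancer is reachable locally; adjust the URL for your run.
    response = requests.post("http://127.0.0.1:7501/predict", json={"image": img})
    print(response.json()["prediction"])  # field of the Number output type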
diff --git a/pyproject.toml b/pyproject.toml
index 8611ef9323deb..4461d956634c6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -79,8 +79,8 @@ module = [
"lightning_app.components.serve.types.image",
"lightning_app.components.serve.types.type",
"lightning_app.components.serve.python_server",
+ "lightning_app.components.serve.auto_scaler",
"lightning_app.components.training",
- "lightning_app.components.auto_scaler",
"lightning_app.core.api",
"lightning_app.core.app",
"lightning_app.core.flow",
diff --git a/requirements/app/cloud.txt b/requirements/app/cloud.txt
index 14b4f30d7db5f..45b237ec2bd71 100644
--- a/requirements/app/cloud.txt
+++ b/requirements/app/cloud.txt
@@ -1,5 +1,3 @@
-# WARNING: this file is not used directly by the backend
-# any dependency here needs to be shipped with the base image
redis>=4.0.1, <=4.2.4
docker>=5.0.0, <6.0.2
s3fs>=2022.5.0, <2022.8.3
diff --git a/requirements/app/test.txt b/requirements/app/test.txt
index ddbe7f1e0be12..44990fafc137e 100644
--- a/requirements/app/test.txt
+++ b/requirements/app/test.txt
@@ -4,6 +4,7 @@ pytest==7.2.0
pytest-timeout==2.1.0
pytest-cov==4.0.0
pytest-doctestplus>=0.9.0
+pytest-asyncio==0.20.3
playwright==1.28.0
httpx
trio<0.22.0
diff --git a/requirements/lite/base.txt b/requirements/lite/base.txt
index 6595b229ebf52..f80d5292d77c3 100644
--- a/requirements/lite/base.txt
+++ b/requirements/lite/base.txt
@@ -2,7 +2,7 @@
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
numpy>=1.17.2, <1.23.1
-torch>=1.9.0, <=1.13.0
+torch>=1.9.0, <=1.13.1
fsspec[http]>2021.06.0, <2022.6.0
packaging>=17.0, <=21.3
typing-extensions>=4.0.0, <=4.4.0
diff --git a/requirements/lite/examples.txt b/requirements/lite/examples.txt
index 43bb03e07cc80..e4d4136b6b0c4 100644
--- a/requirements/lite/examples.txt
+++ b/requirements/lite/examples.txt
@@ -1,4 +1,4 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
-torchvision>=0.10.0, <=0.13.0
+torchvision>=0.10.0, <=0.14.1
diff --git a/requirements/pytorch/adjust-versions.py b/requirements/pytorch/adjust-versions.py
index 69d61e130ca4b..567449e577b71 100644
--- a/requirements/pytorch/adjust-versions.py
+++ b/requirements/pytorch/adjust-versions.py
@@ -5,7 +5,9 @@
# IMPORTANT: this list needs to be sorted in reverse
VERSIONS = [
- dict(torch="1.13.0", torchvision="0.14.0"), # stable
+ dict(torch="1.14.0", torchvision="0.15.0"), # nightly
+ dict(torch="1.13.1", torchvision="0.14.1"), # stable
+ dict(torch="1.13.0", torchvision="0.14.0"),
dict(torch="1.12.1", torchvision="0.13.1"),
dict(torch="1.12.0", torchvision="0.13.0"),
dict(torch="1.11.0", torchvision="0.12.0"),
diff --git a/requirements/pytorch/base.txt b/requirements/pytorch/base.txt
index cd9c8c603ded0..0a50e8427055d 100644
--- a/requirements/pytorch/base.txt
+++ b/requirements/pytorch/base.txt
@@ -2,7 +2,7 @@
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
numpy>=1.17.2, <1.23.1
-torch>=1.9.0, <=1.13.0
+torch>=1.9.0, <=1.13.1
tqdm>=4.57.0, <4.65.0
PyYAML>=5.4, <=6.0
fsspec[http]>2021.06.0, <2022.8.0
diff --git a/requirements/pytorch/examples.txt b/requirements/pytorch/examples.txt
index 82ad1ecf400ef..8d96b4290eda2 100644
--- a/requirements/pytorch/examples.txt
+++ b/requirements/pytorch/examples.txt
@@ -1,6 +1,6 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
-torchvision>=0.10.0, <=0.14.0
+torchvision>=0.10.0, <=0.14.1
gym[classic_control]>=0.17.0, <0.26.3
ipython[all] <8.6.1
diff --git a/requirements/pytorch/strategies.txt b/requirements/pytorch/strategies.txt
index 101415a7c9e1a..4a7bee510cbdb 100644
--- a/requirements/pytorch/strategies.txt
+++ b/requirements/pytorch/strategies.txt
@@ -1,7 +1,7 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
-colossalai>=0.1.10
+# colossalai>=0.1.10 # TODO: uncomment when there's a stable version released
fairscale>=0.4.5, <0.4.13
deepspeed>=0.6.0, <=0.7.0
# no need to install with [pytorch] as pytorch is already installed
diff --git a/requirements/pytorch/test.txt b/requirements/pytorch/test.txt
index 485e6e6fd0393..ec1d68ccb8459 100644
--- a/requirements/pytorch/test.txt
+++ b/requirements/pytorch/test.txt
@@ -8,7 +8,7 @@ pre-commit==2.20.0
# needed in tests
cloudpickle>=1.3, <2.3.0
-scikit-learn>0.22.1, <1.1.3
+scikit-learn>0.22.1, <1.2.1
onnxruntime<1.14.0
psutil<5.9.4 # for `DeviceStatsMonitor`
pandas>1.0, <1.5.2 # needed in benchmarks
diff --git a/src/lightning/__version__.py b/src/lightning/__version__.py
index 4a295c3c7c531..fd11d27b1347a 100644
--- a/src/lightning/__version__.py
+++ b/src/lightning/__version__.py
@@ -1 +1 @@
-version = "1.8.5.post0"
+version = "1.8.6"
diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md
index eb428e3876642..b5a7e6f849b4d 100644
--- a/src/lightning_app/CHANGELOG.md
+++ b/src/lightning_app/CHANGELOG.md
@@ -5,6 +5,73 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
+## [1.8.6] - 2022-12-21
+
+### Added
+
+- Added partial support for fastapi `Request` annotation in `configure_api` handlers ([#16047](https://github.com/Lightning-AI/lightning/pull/16047))
+
+
+- Added a nicer UI with URL and examples for the autoscaler component ([#16063](https://github.com/Lightning-AI/lightning/pull/16063))
+
+
+- Enabled users to have more control over the scale-out/scale-in intervals ([#16093](https://github.com/Lightning-AI/lightning/pull/16093))
+
+
+- Added more datatypes to the serving component ([#16018](https://github.com/Lightning-AI/lightning/pull/16018))
+
+
+- Added a `work.delete` method to delete the work ([#16103](https://github.com/Lightning-AI/lightning/pull/16103))
+
+
+- Added `display_name` property to LightningWork for the cloud ([#16095](https://github.com/Lightning-AI/lightning/pull/16095))
+
+
+- Added `ColdStartProxy` to the AutoScaler ([#16094](https://github.com/Lightning-AI/lightning/pull/16094))
+
+
+### Changed
+
+
+- The default `start_method` for creating Work processes locally on MacOS is now 'spawn' (previously 'fork') ([#16089](https://github.com/Lightning-AI/lightning/pull/16089))
+
+
+- The utility `lightning.app.utilities.cloud.is_running_in_cloud` now returns `True` during loading of the app locally when running with `--cloud` ([#16045](https://github.com/Lightning-AI/lightning/pull/16045))
+
+
+### Deprecated
+
+-
+
+
+### Removed
+
+-
+
+
+### Fixed
+
+- Fixed `PythonServer` messaging "Your app has started" ([#15989](https://github.com/Lightning-AI/lightning/pull/15989))
+
+
+- Fixed auto-batching so that requests already waiting in the queue are batched even if they arrived after the batching interval ([#16110](https://github.com/Lightning-AI/lightning/pull/16110))
+
+
+- Fixed a bug where `AutoScaler` would fail with `min_replicas=0` ([#16092](https://github.com/Lightning-AI/lightning/pull/16092))
+
+
+- Fixed a non-thread-safe deepcopy in the scheduler ([#16114](https://github.com/Lightning-AI/lightning/pull/16114))
+
+
+- Fixed the HTTP queue sleeping for 1 second by default if no deltas were found ([#16114](https://github.com/Lightning-AI/lightning/pull/16114))
+
+
+- Fixed the endpoint info tab not showing up in `AutoScaler` UI ([#16128](https://github.com/Lightning-AI/lightning/pull/16128))
+
+
+- Fixed an issue where an exception would be raised in the logs when using a recent version of streamlit ([#16139](https://github.com/Lightning-AI/lightning/pull/16139))
+
+
## [1.8.5] - 2022-12-15
### Added
diff --git a/src/lightning_app/__version__.py b/src/lightning_app/__version__.py
index 4a295c3c7c531..fd11d27b1347a 100644
--- a/src/lightning_app/__version__.py
+++ b/src/lightning_app/__version__.py
@@ -1 +1 @@
-version = "1.8.5.post0"
+version = "1.8.6"
diff --git a/src/lightning_app/api/http_methods.py b/src/lightning_app/api/http_methods.py
index ca09a9a83eecc..379e87cb68676 100644
--- a/src/lightning_app/api/http_methods.py
+++ b/src/lightning_app/api/http_methods.py
@@ -2,12 +2,14 @@
import inspect
import time
from copy import deepcopy
+from dataclasses import dataclass
from functools import wraps
from multiprocessing import Queue
from typing import Any, Callable, Dict, List, Optional
from uuid import uuid4
-from fastapi import FastAPI, HTTPException
+from fastapi import FastAPI, HTTPException, Request, status
+from lightning_utilities.core.apply_func import apply_to_collection
from lightning_app.api.request_types import _APIRequest, _CommandRequest, _RequestResponse
from lightning_app.utilities.app_helpers import Logger
@@ -19,6 +21,77 @@ def _signature_proxy_function():
pass
+@dataclass
+class _FastApiMockRequest:
+ """This class is meant to mock FastAPI Request class that isn't pickle-able.
+
+ If a user relies on FastAPI Request annotation, the Lightning framework
+ patches the annotation before pickling and replace them right after.
+
+ Finally, the FastAPI request is converted back to the _FastApiMockRequest
+ before being delivered to the users.
+
+ Example:
+
+ import lightning as L
+ from fastapi import Request
+ from lightning.app.api import Post
+
+ class Flow(L.LightningFlow):
+
+ def request(self, request: Request) -> OutputRequestModel:
+ ...
+
+ def configure_api(self):
+ return [Post("/api/v1/request", self.request)]
+ """
+
+ _body: Optional[str] = None
+ _json: Optional[str] = None
+ _method: Optional[str] = None
+ _headers: Optional[Dict] = None
+
+ @property
+ def receive(self):
+ raise NotImplementedError
+
+ @property
+ def method(self):
+ return self._method
+
+ @property
+ def headers(self):
+ return self._headers
+
+ def body(self):
+ return self._body
+
+ def json(self):
+ return self._json
+
+ def stream(self):
+ raise NotImplementedError
+
+ def form(self):
+ raise NotImplementedError
+
+ def close(self):
+ raise NotImplementedError
+
+ def is_disconnected(self):
+ raise NotImplementedError
+
+
+async def _mock_fastapi_request(request: Request):
+ # TODO: Add more requests parameters.
+ return _FastApiMockRequest(
+ _body=await request.body(),
+ _json=await request.json(),
+ _headers=request.headers,
+ _method=request.method,
+ )
+
+
class _HttpMethod:
def __init__(self, route: str, method: Callable, method_name: Optional[str] = None, timeout: int = 30, **kwargs):
"""This class is used to inject user defined methods within the App Rest API.
@@ -34,6 +107,7 @@ def __init__(self, route: str, method: Callable, method_name: Optional[str] = No
self.method_annotations = method.__annotations__
# TODO: Validate the signature contains only pydantic models.
self.method_signature = inspect.signature(method)
+
if not self.attached_to_flow:
self.component_name = method.__name__
self.method = method
@@ -43,10 +117,16 @@ def __init__(self, route: str, method: Callable, method_name: Optional[str] = No
self.timeout = timeout
self.kwargs = kwargs
+ # Enable the users to rely on FastAPI annotation typing with Request.
+ # Note: Only a subset of the Request functionality is supported.
+ self._patch_fast_api_request()
+
def add_route(self, app: FastAPI, request_queue: Queue, responses_store: Dict[str, Any]) -> None:
# 1: Get the route associated with the http method.
route = getattr(app, self.__class__.__name__.lower())
+ self._unpatch_fast_api_request()
+
# 2: Create a proxy function with the signature of the wrapped method.
fn = deepcopy(_signature_proxy_function)
fn.__annotations__ = self.method_annotations
@@ -69,6 +149,11 @@ async def _handle_request(*args, **kwargs):
@wraps(_signature_proxy_function)
async def _handle_request(*args, **kwargs):
async def fn(*args, **kwargs):
+ args, kwargs = apply_to_collection((args, kwargs), Request, _mock_fastapi_request)
+ for k, v in kwargs.items():
+ if hasattr(v, "__await__"):
+ kwargs[k] = await v
+
request_id = str(uuid4()).split("-")[0]
logger.debug(f"Processing request {request_id} for route: {self.route}")
request_queue.put(
@@ -85,7 +170,10 @@ async def fn(*args, **kwargs):
while request_id not in responses_store:
await asyncio.sleep(0.01)
if (time.time() - t0) > self.timeout:
- raise Exception("The response was never received.")
+ raise HTTPException(
+ status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail="The response was never received.",
+ )
logger.debug(f"Processed request {request_id} for route: {self.route}")
@@ -101,6 +189,26 @@ async def fn(*args, **kwargs):
# 4: Register the user provided route to the Rest API.
route(self.route, **self.kwargs)(_handle_request)
+ def _patch_fast_api_request(self):
+ """This function replaces signature annotation for Request with its mock."""
+ for k, v in self.method_annotations.items():
+ if v == Request:
+ self.method_annotations[k] = _FastApiMockRequest
+
+ for v in self.method_signature.parameters.values():
+ if v._annotation == Request:
+ v._annotation = _FastApiMockRequest
+
+ def _unpatch_fast_api_request(self):
+ """This function replaces back signature annotation to fastapi Request."""
+ for k, v in self.method_annotations.items():
+ if v == _FastApiMockRequest:
+ self.method_annotations[k] = Request
+
+ for v in self.method_signature.parameters.values():
+ if v._annotation == _FastApiMockRequest:
+ v._annotation = Request
+
class Post(_HttpMethod):
pass
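With the patching above in place, a flow can annotate an exposed route with the FastAPI `Request` type directly. A minimal sketch adapted from the `_FastApiMockRequest` docstring; the route path and the response model are illustrative:

    from typing import Optional

    import lightning as L
    from fastapi import Request
    from lightning.app.api import Post
    from pydantic import BaseModel

    class EchoResponse(BaseModel):  # illustrative response model
        method: Optional[str]

    class Flow(L.LightningFlow):
        def request(self, request: Request) -> EchoResponse:
            # Inside the flow, `request` is the pickle-friendly _FastApiMockRequest.
            return EchoResponse(method=request.method)

        def run(self):
            pass

        def configure_api(self):
            return [Post("/api/v1/request", self.request)]

    app = L.LightningApp(Flow())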
diff --git a/src/lightning_app/cli/lightning_cli.py b/src/lightning_app/cli/lightning_cli.py
index 98df731415920..71e57d8b14565 100644
--- a/src/lightning_app/cli/lightning_cli.py
+++ b/src/lightning_app/cli/lightning_cli.py
@@ -233,7 +233,7 @@ def _run_app(
if not os.path.exists(file):
original_file = file
- file = cmd_install.gallery_apps_and_components(file, True, "latest", overwrite=False) # type: ignore[assignment] # noqa E501
+ file = cmd_install.gallery_apps_and_components(file, True, "latest", overwrite=True) # type: ignore[assignment] # noqa E501
if file is None:
click.echo(f"The provided entrypoint `{original_file}` doesn't exist.")
sys.exit(1)
diff --git a/src/lightning_app/components/__init__.py b/src/lightning_app/components/__init__.py
index ca47c36071dae..0275596288ff0 100644
--- a/src/lightning_app/components/__init__.py
+++ b/src/lightning_app/components/__init__.py
@@ -1,4 +1,3 @@
-from lightning_app.components.auto_scaler import AutoScaler
from lightning_app.components.database.client import DatabaseClient
from lightning_app.components.database.server import Database
from lightning_app.components.multi_node import (
@@ -9,14 +8,16 @@
)
from lightning_app.components.python.popen import PopenPythonScript
from lightning_app.components.python.tracer import Code, TracerPythonScript
+from lightning_app.components.serve.auto_scaler import AutoScaler, ColdStartProxy
from lightning_app.components.serve.gradio import ServeGradio
-from lightning_app.components.serve.python_server import Image, Number, PythonServer
+from lightning_app.components.serve.python_server import Category, Image, Number, PythonServer, Text
from lightning_app.components.serve.serve import ModelInferenceAPI
from lightning_app.components.serve.streamlit import ServeStreamlit
from lightning_app.components.training import LightningTrainerScript, PyTorchLightningScriptRunner
__all__ = [
"AutoScaler",
+ "ColdStartProxy",
"DatabaseClient",
"Database",
"PopenPythonScript",
@@ -28,6 +29,8 @@
"PythonServer",
"Image",
"Number",
+ "Category",
+ "Text",
"MultiNode",
"LiteMultiNode",
"LightningTrainerScript",
diff --git a/src/lightning_app/components/multi_node/base.py b/src/lightning_app/components/multi_node/base.py
index 5662442b7375a..ac99abecff028 100644
--- a/src/lightning_app/components/multi_node/base.py
+++ b/src/lightning_app/components/multi_node/base.py
@@ -56,12 +56,12 @@ def run(
"""
super().__init__()
if num_nodes > 1 and not is_running_in_cloud():
- num_nodes = 1
warnings.warn(
f"You set {type(self).__name__}(num_nodes={num_nodes}, ...)` but this app is running locally."
" We assume you are debugging and will ignore the `num_nodes` argument."
" To run on multiple nodes in the cloud, launch your app with `--cloud`."
)
+ num_nodes = 1
self.ws = structures.List(
*[
work_cls(
diff --git a/src/lightning_app/components/serve/__init__.py b/src/lightning_app/components/serve/__init__.py
index cb46a71bf9ea5..39dafe2f7ff1b 100644
--- a/src/lightning_app/components/serve/__init__.py
+++ b/src/lightning_app/components/serve/__init__.py
@@ -1,5 +1,16 @@
+from lightning_app.components.serve.auto_scaler import AutoScaler, ColdStartProxy
from lightning_app.components.serve.gradio import ServeGradio
-from lightning_app.components.serve.python_server import Image, Number, PythonServer
+from lightning_app.components.serve.python_server import Category, Image, Number, PythonServer, Text
from lightning_app.components.serve.streamlit import ServeStreamlit
-__all__ = ["ServeGradio", "ServeStreamlit", "PythonServer", "Image", "Number"]
+__all__ = [
+ "ServeGradio",
+ "ServeStreamlit",
+ "PythonServer",
+ "Image",
+ "Number",
+ "Category",
+ "Text",
+ "AutoScaler",
+ "ColdStartProxy",
+]
diff --git a/src/lightning_app/components/auto_scaler.py b/src/lightning_app/components/serve/auto_scaler.py
similarity index 59%
rename from src/lightning_app/components/auto_scaler.py
rename to src/lightning_app/components/serve/auto_scaler.py
index 13948ba50af89..4ba662603e552 100644
--- a/src/lightning_app/components/auto_scaler.py
+++ b/src/lightning_app/components/serve/auto_scaler.py
@@ -6,7 +6,7 @@
import uuid
from base64 import b64encode
from itertools import cycle
-from typing import Any, Dict, List, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
import requests
import uvicorn
@@ -15,11 +15,13 @@
from fastapi.responses import RedirectResponse
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from pydantic import BaseModel
+from starlette.staticfiles import StaticFiles
from starlette.status import HTTP_401_UNAUTHORIZED
from lightning_app.core.flow import LightningFlow
from lightning_app.core.work import LightningWork
from lightning_app.utilities.app_helpers import Logger
+from lightning_app.utilities.cloud import is_running_in_cloud
from lightning_app.utilities.imports import _is_aiohttp_available, requires
from lightning_app.utilities.packaging.cloud_compute import CloudCompute
@@ -30,7 +32,53 @@
logger = Logger(__name__)
-def _raise_granular_exception(exception: Exception) -> None:
+class ColdStartProxy:
+ """ColdStartProxy allows users to configure the load balancer to use a proxy service while the work is cold
+ starting. This is useful with services that gets realtime requests but startup time for workers is high.
+
+ If the request body is same and the method is POST for the proxy service,
+ then the default implementation of `handle_request` can be used. In that case
+ initialize the proxy with the proxy url. Otherwise, the user can override the `handle_request`
+
+ Args:
+ proxy_url (str): The url of the proxy service
+ """
+
+ def __init__(self, proxy_url):
+ self.proxy_url = proxy_url
+ self.proxy_timeout = 50
+ # checking `asyncio.iscoroutinefunction` instead of `inspect.iscoroutinefunction`
+ # because AsyncMock in the tests requires the former to pass
+ if not asyncio.iscoroutinefunction(self.handle_request):
+ raise TypeError("handle_request must be an `async` function")
+
+ async def handle_request(self, request: BaseModel) -> Any:
+ """This method is called when the request is received while the work is cold starting. The default
+ implementation of this method is to forward the request body to the proxy service with POST method but the
+ user can override this method to handle the request in any way.
+
+ Args:
+ request (BaseModel): The request body, a pydantic model that is being
+ forwarded by load balancer which is a FastAPI service
+ """
+ try:
+ async with aiohttp.ClientSession() as session:
+ headers = {
+ "accept": "application/json",
+ "Content-Type": "application/json",
+ }
+ async with session.post(
+ self.proxy_url,
+ json=request.dict(),
+ timeout=self.proxy_timeout,
+ headers=headers,
+ ) as response:
+ return await response.json()
+ except Exception as ex:
+ raise HTTPException(status_code=500, detail=f"Error in proxy: {ex}")
+
+
+def _maybe_raise_granular_exception(exception: Exception) -> None:
"""Handle an exception from hitting the model servers."""
if not isinstance(exception, Exception):
return
@@ -114,20 +162,24 @@ class _LoadBalancer(LightningWork):
requests to be batched. In any case, requests are processed as soon as `max_batch_size` is reached.
timeout_keep_alive: The number of seconds until it closes Keep-Alive connections if no new data is received.
timeout_inference_request: The number of seconds to wait for inference.
- \**kwargs: Arguments passed to :func:`LightningWork.init` like ``CloudCompute``, ``BuildConfig``, etc.
+ api_name: The name to be displayed on the UI. Normally, it is the name of the work class.
+ cold_start_proxy: The proxy service to use while the work is cold starting.
+ **kwargs: Arguments passed to :func:`LightningWork.init` like ``CloudCompute``, ``BuildConfig``, etc.
"""
@requires(["aiohttp"])
def __init__(
self,
- input_type: BaseModel,
- output_type: BaseModel,
+ input_type: Type[BaseModel],
+ output_type: Type[BaseModel],
endpoint: str,
max_batch_size: int = 8,
# all timeout args are in seconds
- timeout_batching: int = 1,
+ timeout_batching: float = 1,
timeout_keep_alive: int = 60,
timeout_inference_request: int = 60,
+ api_name: Optional[str] = "API", # used for displaying the name in the UI
+ cold_start_proxy: Union[ColdStartProxy, str, None] = None,
**kwargs: Any,
) -> None:
super().__init__(cloud_compute=CloudCompute("default"), **kwargs)
@@ -135,36 +187,56 @@ def __init__(
self._output_type = output_type
self._timeout_keep_alive = timeout_keep_alive
self._timeout_inference_request = timeout_inference_request
- self.servers = []
+ self._servers = []
self.max_batch_size = max_batch_size
self.timeout_batching = timeout_batching
self._iter = None
self._batch = []
self._responses = {} # {request_id: response}
self._last_batch_sent = 0
+ self._server_status = {}
+ self._api_name = api_name
if not endpoint.startswith("/"):
endpoint = "/" + endpoint
self.endpoint = endpoint
+ self._fastapi_app = None
+
+ self._cold_start_proxy = None
+ if cold_start_proxy:
+ if isinstance(cold_start_proxy, str):
+ self._cold_start_proxy = ColdStartProxy(proxy_url=cold_start_proxy)
+ elif isinstance(cold_start_proxy, ColdStartProxy):
+ self._cold_start_proxy = cold_start_proxy
+ else:
+ raise ValueError("cold_start_proxy must be of type ColdStartProxy or str")
+
+ self.ready = False
- async def send_batch(self, batch: List[Tuple[str, _BatchRequestModel]]):
- server = next(self._iter) # round-robin
+ async def send_batch(self, batch: List[Tuple[str, _BatchRequestModel]], server_url: str):
request_data: List[_LoadBalancer._input_type] = [b[1] for b in batch]
batch_request_data = _BatchRequestModel(inputs=request_data)
try:
+ self._server_status[server_url] = False
async with aiohttp.ClientSession() as session:
headers = {
"accept": "application/json",
"Content-Type": "application/json",
}
async with session.post(
- f"{server}{self.endpoint}",
+ f"{server_url}{self.endpoint}",
json=batch_request_data.dict(),
timeout=self._timeout_inference_request,
headers=headers,
) as response:
+ # resetting the server status so other requests can be
+ # scheduled on this node
+ if server_url in self._server_status:
+ # TODO - if the server returns an error, track that so
+ # we don't send more requests to it
+ self._server_status[server_url] = True
if response.status == 408:
raise HTTPException(408, "Request timed out")
response.raise_for_status()
@@ -177,48 +249,89 @@ async def send_batch(self, batch: List[Tuple[str, _BatchRequestModel]]):
except Exception as ex:
result = {request[0]: ex for request in batch}
self._responses.update(result)
+ finally:
+ self._server_status[server_url] = True
+
+ def _find_free_server(self) -> Optional[str]:
+ existing = set(self._server_status.keys())
+ for server in existing:
+ status = self._server_status.get(server, None)
+ if status is None:
+ logger.error("Server is not found in the status list. This should not happen.")
+ if status:
+ return server
async def consumer(self):
+ """The consumer process that continuously checks for new requests and sends them to the API.
+
+ Two instances of this function should not be running with a shared `_server_status` as that would create race
+ conditions.
+ """
+ self._last_batch_sent = time.time()
while True:
await asyncio.sleep(0.05)
-
batch = self._batch[: self.max_batch_size]
- while batch and (
- (len(batch) == self.max_batch_size) or ((time.time() - self._last_batch_sent) > self.timeout_batching)
- ):
- asyncio.create_task(self.send_batch(batch))
-
- self._batch = self._batch[self.max_batch_size :]
- batch = self._batch[: self.max_batch_size]
+ is_batch_ready = len(batch) == self.max_batch_size
+ is_batch_timeout = time.time() - self._last_batch_sent > self.timeout_batching
+ server_url = self._find_free_server()
+ # setting the server status to be busy! This will be reset by
+ # the send_batch function after the server responds
+ if server_url is None:
+ continue
+ if batch and (is_batch_ready or is_batch_timeout):
+ # find server with capacity
+ asyncio.create_task(self.send_batch(batch, server_url))
+ # resetting the batch array, TODO - not locking the array
+ self._batch = self._batch[len(batch) :]
self._last_batch_sent = time.time()
- async def process_request(self, data: BaseModel):
- if not self.servers:
+ async def process_request(self, data: BaseModel, request_id: Optional[str] = None):
+ # Generate a fresh id per call; a `uuid.uuid4().hex` default argument would be evaluated only once.
+ request_id = request_id or uuid.uuid4().hex
+ if not self._servers and not self._cold_start_proxy:
raise HTTPException(500, "None of the workers are healthy!")
- request_id = uuid.uuid4().hex
- request: Tuple = (request_id, data)
- self._batch.append(request)
+ # if no servers are available, proxy the request to cold start proxy handler
+ if not self._servers and self._cold_start_proxy:
+ return await self._cold_start_proxy.handle_request(data)
+ # if out of capacity, proxy the request to cold start proxy handler
+ if not self._has_processing_capacity() and self._cold_start_proxy:
+ return await self._cold_start_proxy.handle_request(data)
+
+ # if we have capacity, process the request
+ self._batch.append((request_id, data))
while True:
await asyncio.sleep(0.05)
-
if request_id in self._responses:
result = self._responses[request_id]
del self._responses[request_id]
- _raise_granular_exception(result)
+ _maybe_raise_granular_exception(result)
return result
+ def _has_processing_capacity(self):
+ """This function checks if we have processing capacity for one more request or not.
+
+ Depends on the value from here, we decide whether we should proxy the request or not
+ """
+ if not self._fastapi_app:
+ return False
+ active_server_count = len(self._servers)
+ max_processable = self.max_batch_size * active_server_count
+ current_req_count = self._fastapi_app.num_current_requests
+ return current_req_count < max_processable
+
def run(self):
- logger.info(f"servers: {self.servers}")
+ logger.info(f"servers: {self._servers}")
lock = asyncio.Lock()
- self._iter = cycle(self.servers)
+ self._iter = cycle(self._servers)
self._last_batch_sent = time.time()
fastapi_app = _create_fastapi("Load Balancer")
security = HTTPBasic()
fastapi_app.SEND_TASK = None
+ self._fastapi_app = fastapi_app
+
+ input_type = self._input_type
@fastapi_app.middleware("http")
async def current_request_counter(request: Request, call_next):
@@ -263,8 +376,8 @@ def authenticate_private_endpoint(credentials: HTTPBasicCredentials = Depends(se
@fastapi_app.get("/system/info", response_model=_SysInfo)
async def sys_info(authenticated: bool = Depends(authenticate_private_endpoint)):
return _SysInfo(
- num_workers=len(self.servers),
- servers=self.servers,
+ num_workers=len(self._servers),
+ servers=self._servers,
num_requests=fastapi_app.num_current_requests,
processing_time=fastapi_app.last_processing_time,
global_request_count=fastapi_app.global_request_count,
@@ -273,13 +386,34 @@ async def sys_info(authenticated: bool = Depends(authenticate_private_endpoint))
@fastapi_app.put("/system/update-servers")
async def update_servers(servers: List[str], authenticated: bool = Depends(authenticate_private_endpoint)):
async with lock:
- self.servers = servers
- self._iter = cycle(self.servers)
+ self._servers = servers
+ self._iter = cycle(self._servers)
+ updated_servers = set()
+ # do not try to loop over the dict keys as the dict might change from other places
+ existing_servers = list(self._server_status.keys())
+ for server in servers:
+ updated_servers.add(server)
+ if server not in existing_servers:
+ self._server_status[server] = True
+ logger.info(f"Registering server {server}", self._server_status)
+ for existing in existing_servers:
+ if existing not in updated_servers:
+ logger.info(f"De-Registering server {existing}", self._server_status)
+ del self._server_status[existing]
@fastapi_app.post(self.endpoint, response_model=self._output_type)
- async def balance_api(inputs: self._input_type):
+ async def balance_api(inputs: input_type):
return await self.process_request(inputs)
+ endpoint_info_page = self._get_endpoint_info_page()
+ if endpoint_info_page:
+ fastapi_app.mount(
+ "/endpoint-info", StaticFiles(directory=endpoint_info_page.serve_dir, html=True), name="static"
+ )
+
+ logger.info(f"Your load balancer has started. The endpoint is 'http://{self.host}:{self.port}{self.endpoint}'")
+ self.ready = True
+
uvicorn.run(
fastapi_app,
host=self.host,
@@ -294,7 +428,7 @@ def update_servers(self, server_works: List[LightningWork]):
AutoScaler uses this method to increase/decrease the number of works.
"""
- old_servers = set(self.servers)
+ old_servers = set(self._servers)
server_urls: List[str] = [server.url for server in server_works if server.url]
new_servers = set(server_urls)
@@ -332,6 +466,60 @@ def send_request_to_update_servers(self, servers: List[str]):
response = requests.put(f"{self.url}/system/update-servers", json=servers, headers=headers, timeout=10)
response.raise_for_status()
+ @staticmethod
+ def _get_sample_dict_from_datatype(datatype: Any) -> dict:
+ if not hasattr(datatype, "schema"):
+ # not a pydantic model
+ raise TypeError(f"datatype must be a pydantic model, for the UI to be generated. but got {datatype}")
+
+ if hasattr(datatype, "get_sample_data"):
+ return datatype.get_sample_data()
+
+ datatype_props = datatype.schema()["properties"]
+ out: Dict[str, Any] = {}
+ lut = {"string": "data string", "number": 0.0, "integer": 0, "boolean": False}
+ for k, v in datatype_props.items():
+ if v["type"] not in lut:
+ raise TypeError("Unsupported type")
+ out[k] = lut[v["type"]]
+ return out
+
+ def get_code_sample(self, url: str) -> Optional[str]:
+ input_type: Any = self._input_type
+ output_type: Any = self._output_type
+
+ if not (hasattr(input_type, "request_code_sample") and hasattr(output_type, "response_code_sample")):
+ return None
+ return f"{input_type.request_code_sample(url)}\n{output_type.response_code_sample()}"
+
+ def _get_endpoint_info_page(self) -> Optional["APIAccessFrontend"]: # noqa: F821
+ try:
+ from lightning_api_access import APIAccessFrontend
+ except ModuleNotFoundError:
+ logger.warn("APIAccessFrontend not found. Please install lightning-api-access to enable the UI")
+ return
+
+ if is_running_in_cloud():
+ url = f"{self._future_url}{self.endpoint}"
+ else:
+ url = f"http://localhost:{self.port}{self.endpoint}"
+
+ frontend_objects = {"name": self._api_name, "url": url, "method": "POST", "request": None, "response": None}
+ code_samples = self.get_code_sample(url)
+ if code_samples:
+ frontend_objects["code_samples"] = code_samples
+ # TODO also set request/response for JS UI
+ else:
+ try:
+ request = self._get_sample_dict_from_datatype(self._input_type)
+ response = self._get_sample_dict_from_datatype(self._output_type)
+ except TypeError:
+ return None
+ else:
+ frontend_objects["request"] = request
+ frontend_objects["response"] = response
+ return APIAccessFrontend(apis=[frontend_objects])
+
class AutoScaler(LightningFlow):
"""The ``AutoScaler`` can be used to automatically change the number of replicas of the given server in
@@ -341,12 +529,14 @@ class AutoScaler(LightningFlow):
Args:
min_replicas: The number of works to start when app initializes.
max_replicas: The max number of works to spawn to handle the incoming requests.
- autoscale_interval: The number of seconds to wait before checking whether to upscale or downscale the works.
+ scale_out_interval: The number of seconds to wait before checking whether to increase the number of servers.
+ scale_in_interval: The number of seconds to wait before checking whether to decrease the number of servers.
endpoint: Provide the REST API path.
max_batch_size: (auto-batching) The number of requests to process at once.
timeout_batching: (auto-batching) The number of seconds to wait before sending the requests to process.
input_type: Input type.
output_type: Output type.
+ cold_start_proxy: If provided, the proxy will be used while the worker machines are warming up.
.. testcode::
@@ -358,7 +548,8 @@ class AutoScaler(LightningFlow):
MyPythonServer,
min_replicas=1,
max_replicas=8,
- autoscale_interval=10,
+ scale_out_interval=10,
+ scale_in_interval=10,
)
)
@@ -387,7 +578,8 @@ def scale(self, replicas: int, metrics: dict) -> int:
MyPythonServer,
min_replicas=1,
max_replicas=8,
- autoscale_interval=10,
+ scale_out_interval=10,
+ scale_in_interval=10,
max_batch_size=8, # for auto batching
timeout_batching=1, # for auto batching
)
@@ -399,12 +591,14 @@ def __init__(
work_cls: Type[LightningWork],
min_replicas: int = 1,
max_replicas: int = 4,
- autoscale_interval: int = 10,
+ scale_out_interval: int = 10,
+ scale_in_interval: int = 10,
max_batch_size: int = 8,
timeout_batching: float = 1,
endpoint: str = "api/predict",
- input_type: BaseModel = Dict,
- output_type: BaseModel = Dict,
+ input_type: Type[BaseModel] = Dict,
+ output_type: Type[BaseModel] = Dict,
+ cold_start_proxy: Union[ColdStartProxy, str, None] = None,
*work_args: Any,
**work_kwargs: Any,
) -> None:
@@ -418,7 +612,8 @@ def __init__(
self._input_type = input_type
self._output_type = output_type
- self.autoscale_interval = autoscale_interval
+ self.scale_out_interval = scale_out_interval
+ self.scale_in_interval = scale_in_interval
self.max_batch_size = max_batch_size
if max_replicas < min_replicas:
@@ -438,6 +633,8 @@ def __init__(
timeout_batching=timeout_batching,
cache_calls=True,
parallel=True,
+ api_name=self._work_cls.__name__,
+ cold_start_proxy=cold_start_proxy,
)
for _ in range(min_replicas):
work = self.create_work()
@@ -447,6 +644,10 @@ def __init__(
def workers(self) -> List[LightningWork]:
return [self.get_work(i) for i in range(self.num_replicas)]
+ @property
+ def ready(self) -> bool:
+ return self.load_balancer.ready
+
def create_work(self) -> LightningWork:
"""Replicates a LightningWork instance with args and kwargs provided via ``__init__``."""
cloud_compute = self._work_kwargs.get("cloud_compute", None)
@@ -511,9 +712,13 @@ def scale(self, replicas: int, metrics: dict) -> int:
The target number of running works. The value will be adjusted after this method runs
so that it satisfies ``min_replicas<=replicas<=max_replicas``.
"""
- pending_requests_per_running_or_pending_work = metrics["pending_requests"] / (
- replicas + metrics["pending_works"]
- )
+ pending_requests = metrics["pending_requests"]
+ active_or_pending_works = replicas + metrics["pending_works"]
+
+ if active_or_pending_works == 0:
+ return 1 if pending_requests > 0 else 0
+
+ pending_requests_per_running_or_pending_work = pending_requests / active_or_pending_works
# scale out if the number of pending requests exceeds max batch size.
max_requests_per_work = self.max_batch_size
@@ -539,11 +744,6 @@ def num_pending_works(self) -> int:
def autoscale(self) -> None:
"""Adjust the number of works based on the target number returned by ``self.scale``."""
- if time.time() - self._last_autoscale < self.autoscale_interval:
- return
-
- self.load_balancer.update_servers(self.workers)
-
metrics = {
"pending_requests": self.num_pending_requests,
"pending_works": self.num_pending_works,
@@ -555,24 +755,33 @@ def autoscale(self) -> None:
min(self.max_replicas, self.scale(self.num_replicas, metrics)),
)
- # upscale
- num_workers_to_add = num_target_workers - self.num_replicas
- for _ in range(num_workers_to_add):
- logger.info(f"Upscaling from {self.num_replicas} to {self.num_replicas + 1}")
- work = self.create_work()
- new_work_id = self.add_work(work)
- logger.info(f"Work created: '{new_work_id}'")
-
- # downscale
- num_workers_to_remove = self.num_replicas - num_target_workers
- for _ in range(num_workers_to_remove):
- logger.info(f"Downscaling from {self.num_replicas} to {self.num_replicas - 1}")
- removed_work_id = self.remove_work(self.num_replicas - 1)
- logger.info(f"Work removed: '{removed_work_id}'")
+ # scale-out
+ if time.time() - self._last_autoscale > self.scale_out_interval:
+ num_workers_to_add = num_target_workers - self.num_replicas
+ for _ in range(num_workers_to_add):
+ logger.info(f"Scaling out from {self.num_replicas} to {self.num_replicas + 1}")
+ work = self.create_work()
+ # TODO: move works into structures
+ new_work_id = self.add_work(work)
+ logger.info(f"Work created: '{new_work_id}'")
+ if num_workers_to_add > 0:
+ self._last_autoscale = time.time()
+
+ # scale-in
+ if time.time() - self._last_autoscale > self.scale_in_interval:
+ num_workers_to_remove = self.num_replicas - num_target_workers
+ for _ in range(num_workers_to_remove):
+ logger.info(f"Scaling in from {self.num_replicas} to {self.num_replicas - 1}")
+ removed_work_id = self.remove_work(self.num_replicas - 1)
+ logger.info(f"Work removed: '{removed_work_id}'")
+ if num_workers_to_remove > 0:
+ self._last_autoscale = time.time()
self.load_balancer.update_servers(self.workers)
- self._last_autoscale = time.time()
def configure_layout(self):
- tabs = [{"name": "Swagger", "content": self.load_balancer.url}]
+ tabs = [
+ {"name": "Endpoint Info", "content": f"{self.load_balancer.url}/endpoint-info"},
+ {"name": "Swagger", "content": self.load_balancer.url},
+ ]
return tabs
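Putting the renamed intervals and the new proxy together, a usage sketch adapted from the updated docstring; `MyPythonServer` and the proxy URL are placeholders:

    import lightning as L
    from lightning_app.components.serve import AutoScaler, ColdStartProxy

    class MyPythonServer(L.app.components.PythonServer):  # placeholder server
        def predict(self, request):
            return {"prediction": 0}

    app = L.LightningApp(
        AutoScaler(
            MyPythonServer,
            min_replicas=1,
            max_replicas=8,
            scale_out_interval=10,
            scale_in_interval=10,
            # While no replica can take the request, it is forwarded to this service.
            cold_start_proxy=ColdStartProxy(proxy_url="https://example.com/predict"),
        )
    )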
diff --git a/src/lightning_app/components/serve/gradio.py b/src/lightning_app/components/serve/gradio.py
index 7c07129d39b25..29af372e03172 100644
--- a/src/lightning_app/components/serve/gradio.py
+++ b/src/lightning_app/components/serve/gradio.py
@@ -42,6 +42,8 @@ def __init__(self, *args, **kwargs):
assert self.outputs
self._model = None
+ self.ready = False
+
@property
def model(self):
return self._model
@@ -62,6 +64,7 @@ def run(self, *args, **kwargs):
self._model = self.build_model()
fn = partial(self.predict, *args, **kwargs)
fn.__name__ = self.predict.__name__
+ self.ready = True
gradio.Interface(
fn=fn,
inputs=self.inputs,
diff --git a/src/lightning_app/components/serve/python_server.py b/src/lightning_app/components/serve/python_server.py
index 40b7e83a3bdca..55760bd06e3be 100644
--- a/src/lightning_app/components/serve/python_server.py
+++ b/src/lightning_app/components/serve/python_server.py
@@ -2,9 +2,9 @@
import base64
import os
import platform
-from pathlib import Path
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, TYPE_CHECKING
+import requests
import uvicorn
from fastapi import FastAPI
from lightning_utilities.core.imports import compare_version, module_available
@@ -14,6 +14,9 @@
from lightning_app.utilities.app_helpers import Logger
from lightning_app.utilities.imports import _is_torch_available, requires
+if TYPE_CHECKING:
+ from lightning_app.frontend.frontend import Frontend
+
logger = Logger(__name__)
# Skip doctests if requirements aren't available
@@ -48,18 +51,80 @@ class Image(BaseModel):
image: Optional[str]
@staticmethod
- def _get_sample_data() -> Dict[Any, Any]:
- imagepath = Path(__file__).parent / "catimage.png"
- with open(imagepath, "rb") as image_file:
- encoded_string = base64.b64encode(image_file.read())
- return {"image": encoded_string.decode("UTF-8")}
+ def get_sample_data() -> Dict[Any, Any]:
+ url = "https://raw.githubusercontent.com/Lightning-AI/LAI-Triton-Server-Component/main/catimage.png"
+ img = requests.get(url).content
+ img = base64.b64encode(img).decode("UTF-8")
+ return {"image": img}
+
+ @staticmethod
+ def request_code_sample(url: str) -> str:
+ return (
+ """import base64
+from pathlib import Path
+import requests
+
+imgurl = "https://raw.githubusercontent.com/Lightning-AI/LAI-Triton-Server-Component/main/catimage.png"
+img = requests.get(imgurl).content
+img = base64.b64encode(img).decode("UTF-8")
+response = requests.post('"""
+ + url
+ + """', json={
+ "image": img
+})"""
+ )
+
+ @staticmethod
+ def response_code_sample() -> str:
+ return """img = response.json()["image"]
+img = base64.b64decode(img.encode("utf-8"))
+Path("response.png").write_bytes(img)
+"""
+
+
+class Category(BaseModel):
+ category: Optional[int]
+
+ @staticmethod
+ def get_sample_data() -> Dict[Any, Any]:
+ return {"prediction": 463}
+
+ @staticmethod
+ def response_code_sample() -> str:
+ return """print("Predicted category is: ", response.json()["category"])
+"""
+
+
+class Text(BaseModel):
+ text: Optional[str]
+
+ @staticmethod
+ def get_sample_data() -> Dict[Any, Any]:
+ return {"text": "A portrait of a person looking away from the camera"}
+
+ @staticmethod
+ def request_code_sample(url: str) -> str:
+ return (
+ """import base64
+from pathlib import Path
+import requests
+
+response = requests.post('"""
+ + url
+ + """', json={
+ "text": "A portrait of a person looking away from the camera"
+})
+"""
+ )
class Number(BaseModel):
+ # deprecated
+ # TODO remove this in favour of Category
prediction: Optional[int]
@staticmethod
- def _get_sample_data() -> Dict[Any, Any]:
+ def get_sample_data() -> Dict[Any, Any]:
return {"prediction": 463}
@@ -128,6 +193,8 @@ def predict(self, request):
self._input_type = input_type
self._output_type = output_type
+ self.ready = False
+
def setup(self, *args, **kwargs) -> None:
"""This method is called before the server starts. Override this if you need to download the model or
initialize the weights, setting up pipelines etc.
@@ -154,8 +221,8 @@ def predict(self, request: Any) -> Any:
@staticmethod
def _get_sample_dict_from_datatype(datatype: Any) -> dict:
- if hasattr(datatype, "_get_sample_data"):
- return datatype._get_sample_data()
+ if hasattr(datatype, "get_sample_data"):
+ return datatype.get_sample_data()
datatype_props = datatype.schema()["properties"]
out: Dict[str, Any] = {}
@@ -187,7 +254,15 @@ def predict_fn(request: input_type): # type: ignore
fastapi_app.post("/predict", response_model=output_type)(predict_fn)
- def configure_layout(self) -> None:
+ def get_code_sample(self, url: str) -> Optional[str]:
+ input_type: Any = self.configure_input_type()
+ output_type: Any = self.configure_output_type()
+
+ if not (hasattr(input_type, "request_code_sample") and hasattr(output_type, "response_code_sample")):
+ return None
+ return f"{input_type.request_code_sample(url)}\n{output_type.response_code_sample()}"
+
+ def configure_layout(self) -> Optional["Frontend"]:
try:
from lightning_api_access import APIAccessFrontend
except ModuleNotFoundError:
@@ -203,17 +278,19 @@ def configure_layout(self) -> None:
except TypeError:
return None
- return APIAccessFrontend(
- apis=[
- {
- "name": class_name,
- "url": url,
- "method": "POST",
- "request": request,
- "response": response,
- }
- ]
- )
+ frontend_payload = {
+ "name": class_name,
+ "url": url,
+ "method": "POST",
+ "request": request,
+ "response": response,
+ }
+
+ code_sample = self.get_code_sample(url)
+ if code_sample:
+ frontend_payload["code_sample"] = code_sample
+
+ return APIAccessFrontend(apis=[frontend_payload])
def run(self, *args: Any, **kwargs: Any) -> Any:
"""Run method takes care of configuring and setting up a FastAPI server behind the scenes.
@@ -225,5 +302,8 @@ def run(self, *args: Any, **kwargs: Any) -> Any:
fastapi_app = FastAPI()
self._attach_predict_fn(fastapi_app)
- logger.info(f"Your app has started. View it in your browser: http://{self.host}:{self.port}")
+ self.ready = True
+ logger.info(
+ f"Your {self.__class__.__qualname__} has started. View it in your browser: http://{self.host}:{self.port}"
+ )
uvicorn.run(app=fastapi_app, host=self.host, port=self.port, log_level="error")
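The hunk above replaces the private `_get_sample_data` hook with public static methods (`get_sample_data`, `request_code_sample`, `response_code_sample`) that the server's `get_code_sample` stitches together for the API access frontend. A minimal sketch of a custom datatype plugging into the same hooks — the `Embedding` model and its payload are purely illustrative, not part of this change:

```python
from typing import Any, Dict, List, Optional

from pydantic import BaseModel


class Embedding(BaseModel):
    """Hypothetical datatype following the same hook pattern as ``Image``/``Text`` above."""

    vector: Optional[List[float]]

    @staticmethod
    def get_sample_data() -> Dict[Any, Any]:
        # Rendered on the API access page as the example request/response body.
        return {"vector": [0.1, 0.2, 0.3]}

    @staticmethod
    def request_code_sample(url: str) -> str:
        # `url` is injected by the server's get_code_sample at layout time.
        return (
            "import requests\n"
            f"response = requests.post('{url}', json={{'vector': [0.1, 0.2, 0.3]}})"
        )

    @staticmethod
    def response_code_sample() -> str:
        return 'print(response.json()["vector"])'
```

Because `get_code_sample` only checks `hasattr(...)` on the configured input and output types, any pydantic model exposing these static methods gets a generated code sample in `APIAccessFrontend`; types without them fall back to the schema-derived request/response preview.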
diff --git a/src/lightning_app/components/serve/serve.py b/src/lightning_app/components/serve/serve.py
index 8b6f35364cc38..50caca0079852 100644
--- a/src/lightning_app/components/serve/serve.py
+++ b/src/lightning_app/components/serve/serve.py
@@ -64,6 +64,8 @@ def __init__(
self.workers = workers
self._model = None
+ self.ready = False
+
@property
def model(self):
return self._model
@@ -108,9 +110,11 @@ def run(self):
"serve:fastapi_service",
]
process = subprocess.Popen(command, env=env, cwd=os.path.dirname(__file__))
+ self.ready = True
process.wait()
else:
self._populate_app(fastapi_service)
+ self.ready = True
self._launch_server(fastapi_service)
def _populate_app(self, fastapi_service: FastAPI):
diff --git a/src/lightning_app/components/serve/streamlit.py b/src/lightning_app/components/serve/streamlit.py
index 9b943a1708fa3..720139f93f25e 100644
--- a/src/lightning_app/components/serve/streamlit.py
+++ b/src/lightning_app/components/serve/streamlit.py
@@ -20,6 +20,8 @@ class ServeStreamlit(LightningWork, abc.ABC):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
+ self.ready = False
+
self._process = None
@property
@@ -58,6 +60,7 @@ def run(self) -> None:
],
env=env,
)
+ self.ready = True
self._process.wait()
def on_exit(self) -> None:
diff --git a/src/lightning_app/core/api.py b/src/lightning_app/core/api.py
index e6f7b6ad0024c..51b57b2c732a2 100644
--- a/src/lightning_app/core/api.py
+++ b/src/lightning_app/core/api.py
@@ -1,10 +1,12 @@
import asyncio
+import json
import os
import queue
import sys
import traceback
from copy import deepcopy
from multiprocessing import Queue
+from pathlib import Path
from tempfile import TemporaryDirectory
from threading import Event, Lock, Thread
from time import sleep
@@ -24,16 +26,17 @@
from lightning_app.api.http_methods import _HttpMethod
from lightning_app.api.request_types import _DeltaRequest
from lightning_app.core.constants import (
- CLOUD_QUEUE_TYPE,
ENABLE_PULLING_STATE_ENDPOINT,
ENABLE_PUSHING_STATE_ENDPOINT,
ENABLE_STATE_WEBSOCKET,
ENABLE_UPLOAD_ENDPOINT,
FRONTEND_DIR,
+ get_cloud_queue_type,
)
from lightning_app.core.queues import QueuingSystem
from lightning_app.storage import Drive
from lightning_app.utilities.app_helpers import InMemoryStateStore, Logger, StateStore
+from lightning_app.utilities.app_status import AppStatus
from lightning_app.utilities.cloud import is_running_in_cloud
from lightning_app.utilities.component import _context
from lightning_app.utilities.enum import ComponentContext, OpenAPITags
@@ -66,18 +69,25 @@ class SessionMiddleware:
lock = Lock()
app_spec: Optional[List] = None
+app_status: Optional[AppStatus] = None
+app_annotations: Optional[List] = None
+
# In the future, this would be abstracted to support horizontal scaling.
responses_store = {}
logger = Logger(__name__)
-
# This can be replaced with a consumer that publishes states in a kv-store
# in a serverless architecture
class UIRefresher(Thread):
- def __init__(self, api_publish_state_queue, api_response_queue, refresh_interval: float = 0.1) -> None:
+ def __init__(
+ self,
+ api_publish_state_queue,
+ api_response_queue,
+ refresh_interval: float = 0.1,
+ ) -> None:
super().__init__(daemon=True)
self.api_publish_state_queue = api_publish_state_queue
self.api_response_queue = api_response_queue
@@ -98,7 +108,8 @@ def run(self):
def run_once(self):
try:
- state = self.api_publish_state_queue.get(timeout=0)
+ global app_status
+ state, app_status = self.api_publish_state_queue.get(timeout=0)
with lock:
global_app_state_store.set_app_state(TEST_SESSION_UUID, state)
except queue.Empty:
@@ -326,12 +337,30 @@ async def upload_file(response: Response, filename: str, uploaded_file: UploadFi
return f"Successfully uploaded '{filename}' to the Drive"
+@fastapi_service.get("/api/v1/status", response_model=AppStatus)
+async def get_status() -> AppStatus:
+ """Get the current status of the app and works."""
+ global app_status
+ if app_status is None:
+ raise HTTPException(
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="App status hasn't been reported yet."
+ )
+ return app_status
+
+
+@fastapi_service.get("/api/v1/annotations", response_class=JSONResponse)
+async def get_annotations() -> Union[List, Dict]:
+ """Get the annotations associated with this app."""
+ global app_annotations
+ return app_annotations or []
+
+
@fastapi_service.get("/healthz", status_code=200)
async def healthz(response: Response):
"""Health check endpoint used in the cloud FastAPI servers to check the status periodically."""
# check the queue status only if running in cloud
if is_running_in_cloud():
- queue_obj = QueuingSystem(CLOUD_QUEUE_TYPE).get_queue(queue_name="healthz")
+ queue_obj = QueuingSystem(get_cloud_queue_type()).get_queue(queue_name="healthz")
# this is only being implemented on Redis Queue. For HTTP Queue, it doesn't make sense to have every single
# app checking the status of the Queue server
if not queue_obj.is_running:
@@ -421,6 +450,7 @@ def start_server(
global api_app_delta_queue
global global_app_state_store
global app_spec
+ global app_annotations
app_spec = spec
api_app_delta_queue = api_delta_queue
@@ -430,6 +460,12 @@ def start_server(
global_app_state_store.add(TEST_SESSION_UUID)
+ # Load annotations
+ annotations_path = Path("lightning-annotations.json").resolve()
+ if annotations_path.exists():
+ with open(annotations_path) as f:
+ app_annotations = json.load(f)
+
refresher = UIRefresher(api_publish_state_queue, api_response_queue)
refresher.setDaemon(True)
refresher.start()
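With these changes the app server exposes the reported `AppStatus` at `/api/v1/status` (503 until the first status arrives on the publish queue) and the contents of any `lightning-annotations.json` found at startup at `/api/v1/annotations`. A small polling sketch, assuming a locally running app server on the default port 7501:

```python
import requests

BASE = "http://127.0.0.1:7501"  # assumed local app server address

resp = requests.get(f"{BASE}/api/v1/status")
if resp.status_code == 503:
    print("App status hasn't been reported yet.")
else:
    status = resp.json()  # shaped like the AppStatus model: is_ui_ready + work_statuses
    print("UI ready:", status["is_ui_ready"])
    for name, work_status in status["work_statuses"].items():
        print(name, work_status["stage"])

# Returns [] when no lightning-annotations.json was found at server start.
print(requests.get(f"{BASE}/api/v1/annotations").json())
```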
diff --git a/src/lightning_app/core/app.py b/src/lightning_app/core/app.py
index 9c3aeeb650de0..1b113e1ceb48f 100644
--- a/src/lightning_app/core/app.py
+++ b/src/lightning_app/core/app.py
@@ -35,6 +35,7 @@
_should_dispatch_app,
Logger,
)
+from lightning_app.utilities.app_status import AppStatus
from lightning_app.utilities.commands.base import _process_requests
from lightning_app.utilities.component import _convert_paths_after_init, _validate_root_flow
from lightning_app.utilities.enum import AppStage, CacheCallsKeys
@@ -140,6 +141,7 @@ def __init__(
self.exception = None
self.collect_changes: bool = True
+ self.status: Optional[AppStatus] = None
# TODO: Enable ready locally for opening the UI.
self.ready = False
@@ -150,6 +152,7 @@ def __init__(
self.checkpointing: bool = False
self._update_layout()
+ self._update_status()
self.is_headless: Optional[bool] = None
@@ -353,6 +356,8 @@ def _collect_deltas_from_ui_and_work_queues(self) -> List[Union[Delta, _APIReque
deltas.append(delta)
else:
api_or_command_request_deltas.append(delta)
+ else:
+ break
if api_or_command_request_deltas:
_process_requests(self, api_or_command_request_deltas)
@@ -418,6 +423,7 @@ def run_once(self):
self._update_layout()
self._update_is_headless()
+ self._update_status()
self.maybe_apply_changes()
if self.checkpointing and self._should_snapshot():
@@ -485,19 +491,12 @@ def _run(self) -> bool:
self._original_state = deepcopy(self.state)
done = False
- # TODO: Re-enable the `ready` property once issues are resolved
- if not self.root.ready:
- warnings.warn(
- "One of your Flows returned `.ready` as `False`. "
- "This feature is not yet enabled so this will be ignored.",
- UserWarning,
- )
- self.ready = True
+ self.ready = self.root.ready
self._start_with_flow_works()
- if self.ready and self.should_publish_changes_to_api and self.api_publish_state_queue:
- self.api_publish_state_queue.put(self.state_vars)
+ if self.should_publish_changes_to_api and self.api_publish_state_queue is not None:
+ self.api_publish_state_queue.put((self.state_vars, self.status))
self._reset_run_time_monitor()
@@ -506,8 +505,8 @@ def _run(self) -> bool:
self._update_run_time_monitor()
- if self.ready and self._has_updated and self.should_publish_changes_to_api and self.api_publish_state_queue:
- self.api_publish_state_queue.put(self.state_vars)
+ if self._has_updated and self.should_publish_changes_to_api and self.api_publish_state_queue is not None:
+ self.api_publish_state_queue.put((self.state_vars, self.status))
self._has_updated = False
@@ -532,6 +531,23 @@ def _update_is_headless(self) -> None:
# This ensures support for apps which dynamically add a UI at runtime.
_handle_is_headless(self)
+ def _update_status(self) -> None:
+ old_status = self.status
+
+ work_statuses = {}
+ for work in breadth_first(self.root, types=(lightning_app.LightningWork,)):
+ work_statuses[work.name] = work.status
+
+ self.status = AppStatus(
+ is_ui_ready=self.ready,
+ work_statuses=work_statuses,
+ )
+
+ # If the work statuses changed, the state delta will trigger an update.
+ # If ready has changed, we trigger an update manually.
+ if self.status != old_status:
+ self._has_updated = True
+
def _apply_restarting(self) -> bool:
self._reset_original_state()
# apply stage after restoring the original state.
diff --git a/src/lightning_app/core/constants.py b/src/lightning_app/core/constants.py
index da99db9018320..6882598cab223 100644
--- a/src/lightning_app/core/constants.py
+++ b/src/lightning_app/core/constants.py
@@ -1,5 +1,6 @@
import os
from pathlib import Path
+from typing import Optional
import lightning_cloud.env
@@ -13,7 +14,7 @@ def get_lightning_cloud_url() -> str:
SUPPORTED_PRIMITIVE_TYPES = (type(None), str, int, float, bool)
STATE_UPDATE_TIMEOUT = 0.001
-STATE_ACCUMULATE_WAIT = 0.05
+STATE_ACCUMULATE_WAIT = 0.15
# Duration in seconds of a moving average of a full flow execution
# beyond which an exception is raised.
FLOW_DURATION_THRESHOLD = 1.0
@@ -25,7 +26,6 @@ def get_lightning_cloud_url() -> str:
APP_SERVER_PORT = _find_lit_app_port(7501)
APP_STATE_MAX_SIZE_BYTES = 1024 * 1024 # 1 MB
-CLOUD_QUEUE_TYPE = os.getenv("LIGHTNING_CLOUD_QUEUE_TYPE", None)
WARNING_QUEUE_SIZE = 1000
# different flag because queue debug can be very noisy, and almost always not useful unless debugging the queue itself.
QUEUE_DEBUG_ENABLED = bool(int(os.getenv("LIGHTNING_QUEUE_DEBUG_ENABLED", "0")))
@@ -77,5 +77,9 @@ def enable_multiple_works_in_default_container() -> bool:
return bool(int(os.getenv("ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER", "0")))
+def get_cloud_queue_type() -> Optional[str]:
+ return os.getenv("LIGHTNING_CLOUD_QUEUE_TYPE", None)
+
+
# Number of seconds to wait between filesystem checks when waiting for files in remote storage
REMOTE_STORAGE_WAIT = 0.5
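Reading the queue type through `get_cloud_queue_type()` instead of a module-level constant means the environment variable is evaluated at call time, so it can be changed after `lightning_app.core.constants` has been imported — which is what the updated `healthz` endpoint and `CloudRuntime.dispatch` rely on. A quick sketch of the difference:

```python
import os

from lightning_app.core.constants import get_cloud_queue_type

os.environ["LIGHTNING_CLOUD_QUEUE_TYPE"] = "http"
assert get_cloud_queue_type() == "http"  # picked up even though constants was imported earlier

os.environ["LIGHTNING_CLOUD_QUEUE_TYPE"] = "redis"
assert get_cloud_queue_type() == "redis"  # the old CLOUD_QUEUE_TYPE would have kept its import-time value
```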
diff --git a/src/lightning_app/core/flow.py b/src/lightning_app/core/flow.py
index 0be6b6f8ade98..ee2931a6afadb 100644
--- a/src/lightning_app/core/flow.py
+++ b/src/lightning_app/core/flow.py
@@ -249,10 +249,7 @@ def __getattr__(self, item):
@property
def ready(self) -> bool:
- """Not currently enabled.
-
- Override to customize when your App should be ready.
- """
+ """Override to customize when your App should be ready."""
flows = self.flows
return all(flow.ready for flow in flows.values()) if flows else True
@@ -800,7 +797,7 @@ def __init__(self, work):
@property
def ready(self) -> bool:
ready = getattr(self.work, "ready", None)
- if ready:
+ if ready is not None:
return ready
return self.work.url != ""
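Since `LightningApp._run` now takes `self.ready = self.root.ready` at face value, overriding `ready` on the root flow (or exposing a `ready` attribute on a work, as `_RootFlow.ready` checks above) controls when the app reports itself as ready. A minimal sketch under those assumptions — `Server` and its attributes are illustrative only:

```python
from lightning_app import LightningApp, LightningFlow, LightningWork


class Server(LightningWork):
    def __init__(self):
        super().__init__(parallel=True)
        self.ready = False  # mirrored into the flow's ready property below

    def run(self):
        # ... start the actual server here ...
        self.ready = True


class RootFlow(LightningFlow):
    def __init__(self):
        super().__init__()
        self.server = Server()

    @property
    def ready(self) -> bool:
        # The app reports is_ui_ready=True only once the work flips its flag.
        return self.server.ready

    def run(self):
        self.server.run()


app = LightningApp(RootFlow())
```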
diff --git a/src/lightning_app/core/queues.py b/src/lightning_app/core/queues.py
index db150a57eb098..0579552de1875 100644
--- a/src/lightning_app/core/queues.py
+++ b/src/lightning_app/core/queues.py
@@ -364,12 +364,18 @@ def get(self, timeout: int = None) -> Any:
# timeout is some value - loop until the timeout is reached
start_time = time.time()
- timeout += 0.1 # add 0.1 seconds as a safe margin
while (time.time() - start_time) < timeout:
try:
return self._get()
except queue.Empty:
- time.sleep(HTTP_QUEUE_REFRESH_INTERVAL)
+ # Note: In theory, no sleep is needed here because an empty queue shouldn't
+ # block the flow. However, since the HTTP server can saturate, sleep briefly
+ # when the caller provides a timeout larger than the default one.
+ if timeout > self.default_timeout:
+ time.sleep(0.05)
+ pass
def _get(self):
resp = self.client.post(f"v1/{self.app_id}/{self._name_suffix}", query_params={"action": "pop"})
diff --git a/src/lightning_app/core/work.py b/src/lightning_app/core/work.py
index 029f01fd2f7ae..863d50db47cec 100644
--- a/src/lightning_app/core/work.py
+++ b/src/lightning_app/core/work.py
@@ -12,13 +12,13 @@
from lightning_app.storage.drive import _maybe_create_drive, Drive
from lightning_app.storage.payload import Payload
from lightning_app.utilities.app_helpers import _is_json_serializable, _LightningAppRef, is_overridden
+from lightning_app.utilities.app_status import WorkStatus
from lightning_app.utilities.component import _is_flow_context, _sanitize_state
from lightning_app.utilities.enum import (
CacheCallsKeys,
make_status,
WorkFailureReasons,
WorkStageStatus,
- WorkStatus,
WorkStopReasons,
)
from lightning_app.utilities.exceptions import LightningWorkException
@@ -51,7 +51,7 @@ class LightningWork:
_run_executor_cls: Type[WorkRunExecutor] = WorkRunExecutor
# TODO: Move to spawn for all Operating System.
- _start_method = "spawn" if sys.platform == "win32" else "fork"
+ _start_method = "spawn" if sys.platform in ("darwin", "win32") else "fork"
def __init__(
self,
@@ -119,7 +119,16 @@ def __init__(
" in the next version. Use `cache_calls` instead."
)
self._cache_calls = run_once if run_once is not None else cache_calls
- self._state = {"_host", "_port", "_url", "_future_url", "_internal_ip", "_restarting", "_cloud_compute"}
+ self._state = {
+ "_host",
+ "_port",
+ "_url",
+ "_future_url",
+ "_internal_ip",
+ "_restarting",
+ "_cloud_compute",
+ "_display_name",
+ }
self._parallel = parallel
self._host: str = host
self._port: Optional[int] = port
@@ -129,6 +138,7 @@ def __init__(
# setattr_replacement is used by the multiprocessing runtime to send the latest changes to the main coordinator
self._setattr_replacement: Optional[Callable[[str, Any], None]] = None
self._name = ""
+ self._display_name = ""
# The ``self._calls`` is used to track whether the run
# method with a given set of input arguments has already been called.
# Example of its usage:
@@ -207,6 +217,22 @@ def name(self):
"""Returns the name of the LightningWork."""
return self._name
+ @property
+ def display_name(self):
+ """Returns the display name of the LightningWork in the cloud.
+
+ The display name needs to be set before the run method of the work is called.
+ """
+ return self._display_name
+
+ @display_name.setter
+ def display_name(self, display_name: str):
+ """Sets the display name of the LightningWork in the cloud."""
+ if not self.has_started:
+ self._display_name = display_name
+ elif self._display_name != display_name:
+ raise RuntimeError("The display name can be set only before the work has started.")
+
@property
def cache_calls(self) -> bool:
"""Returns whether the ``run`` method should cache its input arguments and not run again when provided with
@@ -604,12 +630,12 @@ def on_exit(self):
pass
def stop(self):
- """Stops LightingWork component and shuts down hardware provisioned via L.CloudCompute."""
+ """Stops LightingWork component and shuts down hardware provisioned via L.CloudCompute.
+
+ This can only be called from a ``LightningFlow``.
+ """
if not self._backend:
- raise Exception(
- "Can't stop the work, it looks like it isn't attached to a LightningFlow. "
- "Make sure to assign the Work to a flow instance."
- )
+ raise RuntimeError(f"Only the `LightningFlow` can request this work ({self.name!r}) to stop.")
if self.status.stage == WorkStageStatus.STOPPED:
return
latest_hash = self._calls[CacheCallsKeys.LATEST_CALL_HASH]
@@ -618,6 +644,19 @@ def stop(self):
app = _LightningAppRef().get_current()
self._backend.stop_work(app, self)
+ def delete(self):
+ """Delete LightingWork component and shuts down hardware provisioned via L.CloudCompute.
+
+ Locally, the work.delete() behaves as work.stop().
+ """
+ if not self._backend:
+ raise Exception(
+ "Can't delete the work, it looks like it isn't attached to a LightningFlow. "
+ "Make sure to assign the Work to a flow instance."
+ )
+ app = _LightningAppRef().get_current()
+ self._backend.delete_work(app, self)
+
def _check_run_is_implemented(self) -> None:
if not is_overridden("run", instance=self, parent=LightningWork):
raise TypeError(
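The new `display_name` setter is only writable before the work starts, and `delete()` mirrors `stop()` locally while also releasing cloud hardware. A hedged usage sketch from a flow, assuming the usual status helpers such as `has_succeeded` behave as elsewhere in `LightningWork`; the component names are illustrative:

```python
from lightning_app import LightningFlow, LightningWork


class TrainingWork(LightningWork):
    def run(self):
        ...


class RootFlow(LightningFlow):
    def __init__(self):
        super().__init__()
        self.trainer = TrainingWork()
        # Allowed: the work has not started yet, so the cloud display name can still be set.
        self.trainer.display_name = "Training run #1"

    def run(self):
        self.trainer.run()
        if self.trainer.has_succeeded:
            # Changing display_name to a different value at this point would raise a RuntimeError.
            self.trainer.delete()  # locally equivalent to stop(); in the cloud it also tears down the machine
```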
diff --git a/src/lightning_app/runners/backends/mp_process.py b/src/lightning_app/runners/backends/mp_process.py
index dc0681390046e..36f3cb8097604 100644
--- a/src/lightning_app/runners/backends/mp_process.py
+++ b/src/lightning_app/runners/backends/mp_process.py
@@ -88,6 +88,9 @@ def stop_work(self, app, work: "lightning_app.LightningWork") -> None:
work_manager: MultiProcessWorkManager = app.processes[work.name]
work_manager.kill()
+ def delete_work(self, app, work: "lightning_app.LightningWork") -> None:
+ self.stop_work(app, work)
+
class CloudMultiProcessingBackend(MultiProcessingBackend):
def __init__(self, *args, **kwargs):
@@ -108,3 +111,6 @@ def stop_work(self, app, work: "lightning_app.LightningWork") -> None:
disable_port(work._port)
self.ports = [port for port in self.ports if port != work._port]
return super().stop_work(app, work)
+
+ def delete_work(self, app, work: "lightning_app.LightningWork") -> None:
+ self.stop_work(app, work)
diff --git a/src/lightning_app/runners/cloud.py b/src/lightning_app/runners/cloud.py
index 0011c29ce0e6b..884f809266582 100644
--- a/src/lightning_app/runners/cloud.py
+++ b/src/lightning_app/runners/cloud.py
@@ -1,5 +1,6 @@
import fnmatch
import json
+import os
import random
import re
import string
@@ -49,7 +50,6 @@
from lightning_app import LightningWork
from lightning_app.core.app import LightningApp
from lightning_app.core.constants import (
- CLOUD_QUEUE_TYPE,
CLOUD_UPLOAD_WARNING,
DEFAULT_NUMBER_OF_EXPOSED_PORTS,
DISABLE_DEPENDENCY_CACHE,
@@ -59,6 +59,7 @@
ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER,
ENABLE_PULLING_STATE_ENDPOINT,
ENABLE_PUSHING_STATE_ENDPOINT,
+ get_cloud_queue_type,
get_lightning_cloud_url,
)
from lightning_app.runners.backends.cloud import CloudBackend
@@ -417,9 +418,11 @@ def dispatch(
initial_port += 1
queue_server_type = V1QueueServerType.UNSPECIFIED
- if CLOUD_QUEUE_TYPE == "http":
+ # Note: Enable apps to select their own queue type.
+ queue_type = get_cloud_queue_type()
+ if queue_type == "http":
queue_server_type = V1QueueServerType.HTTP
- elif CLOUD_QUEUE_TYPE == "redis":
+ elif queue_type == "redis":
queue_server_type = V1QueueServerType.REDIS
release_body = Body8(
@@ -445,7 +448,7 @@ def dispatch(
raise RuntimeError("The source upload url is empty.")
if getattr(lightning_app_release, "cluster_id", None):
- logger.info(f"Running app on {lightning_app_release.cluster_id}")
+ print(f"Running app on {lightning_app_release.cluster_id}")
# Save the config for re-runs
app_config.save_to_dir(root)
@@ -495,7 +498,8 @@ def dispatch(
if lightning_app_instance.status.phase == V1LightningappInstanceState.FAILED:
raise RuntimeError("Failed to create the application. Cannot upload the source code.")
- if open_ui:
+ # TODO: Remove testing dependency, but this would open a tab for each test...
+ if open_ui and "PYTEST_CURRENT_TEST" not in os.environ:
click.launch(self._get_app_url(lightning_app_instance, not has_sufficient_credits))
if cleanup_handle:
@@ -589,6 +593,10 @@ def _project_has_sufficient_credits(self, project: V1Membership, app: Optional[L
@classmethod
def load_app_from_file(cls, filepath: str) -> "LightningApp":
"""Load a LightningApp from a file, mocking the imports."""
+
+ # Pretend we are running in the cloud when loading the app locally
+ os.environ["LAI_RUNNING_IN_CLOUD"] = "1"
+
try:
app = load_app_from_file(filepath, raise_exception=True, mock_imports=True)
except FileNotFoundError as e:
@@ -599,6 +607,8 @@ def load_app_from_file(cls, filepath: str) -> "LightningApp":
# Create a generic app.
logger.info("Could not load the app locally. Starting the app directly on the cloud.")
app = LightningApp(EmptyFlow())
+ finally:
+ del os.environ["LAI_RUNNING_IN_CLOUD"]
return app
@staticmethod
diff --git a/src/lightning_app/runners/runtime.py b/src/lightning_app/runners/runtime.py
index a30b78f9178a0..c6d8c3d4394b9 100644
--- a/src/lightning_app/runners/runtime.py
+++ b/src/lightning_app/runners/runtime.py
@@ -121,7 +121,7 @@ def terminate(self) -> None:
self._add_stopped_status_to_work(work)
# Publish the updated state and wait for the frontend to update.
- self.app.api_publish_state_queue.put(self.app.state)
+ self.app.api_publish_state_queue.put((self.app.state, self.app.status))
for thread in self.threads + self.app.threads:
thread.join(timeout=0)
diff --git a/src/lightning_app/testing/testing.py b/src/lightning_app/testing/testing.py
index 40b705458dd49..5602af1e523be 100644
--- a/src/lightning_app/testing/testing.py
+++ b/src/lightning_app/testing/testing.py
@@ -386,34 +386,24 @@ def run_app_in_cloud(
process = Process(target=_print_logs, kwargs={"app_id": app_id})
process.start()
- if not app.spec.is_headless:
- while True:
- try:
- with admin_page.context.expect_page() as page_catcher:
- admin_page.locator('[data-cy="open"]').click()
- view_page = page_catcher.value
- view_page.wait_for_load_state(timeout=0)
- break
- except (playwright._impl._api_types.Error, playwright._impl._api_types.TimeoutError):
- pass
- else:
- view_page = None
-
- # Wait until the app is running
- while True:
- sleep(1)
-
- lit_apps = [
- app
- for app in client.lightningapp_instance_service_list_lightningapp_instances(
- project_id=project.project_id
- ).lightningapps
- if app.name == name
- ]
- app = lit_apps[0]
-
- if app.status.phase == V1LightningappInstanceState.RUNNING:
- break
+ # Wait until the app is running
+ while True:
+ sleep(1)
+
+ lit_apps = [
+ app
+ for app in client.lightningapp_instance_service_list_lightningapp_instances(
+ project_id=project.project_id
+ ).lightningapps
+ if app.name == name
+ ]
+ app = lit_apps[0]
+
+ if app.status.phase == V1LightningappInstanceState.RUNNING:
+ break
+
+ view_page = context.new_page()
+ view_page.goto(f"{app.status.url}/view")
# TODO: is re-creating this redundant?
lit_apps = [
@@ -488,12 +478,12 @@ def wait_for(page, callback: Callable, *args, **kwargs) -> Any:
except (playwright._impl._api_types.Error, playwright._impl._api_types.TimeoutError) as e:
print(e)
try:
- sleep(5)
+ sleep(7)
page.reload()
except (playwright._impl._api_types.Error, playwright._impl._api_types.TimeoutError) as e:
print(e)
pass
- sleep(2)
+ sleep(3)
def _delete_lightning_app(client, project_id, app_id, app_name):
diff --git a/src/lightning_app/utilities/app_helpers.py b/src/lightning_app/utilities/app_helpers.py
index bc3d092b280dd..d9efeb6862ba7 100644
--- a/src/lightning_app/utilities/app_helpers.py
+++ b/src/lightning_app/utilities/app_helpers.py
@@ -184,11 +184,22 @@ def render_non_authorized(self):
def target_fn():
- from streamlit.server.server import Server
+ try:
+ # streamlit >= 1.14.0
+ from streamlit import runtime
+
+ get_instance = runtime.get_instance
+ exists = runtime.exists()
+ except ImportError:
+ # Older versions
+ from streamlit.server.server import Server
+
+ get_instance = Server.get_current
+ exists = bool(Server._singleton)
async def update_fn():
- server = Server.get_current()
- sessions = list(server._session_info_by_id.values())
+ runtime_instance = get_instance()
+ sessions = list(runtime_instance._session_info_by_id.values())
url = (
"localhost:8080"
if "LIGHTNING_APP_STATE_URL" in os.environ
@@ -198,15 +209,20 @@ async def update_fn():
last_updated = time.time()
async with websockets.connect(ws_url) as websocket:
while True:
- _ = await websocket.recv()
- while (time.time() - last_updated) < 1:
- time.sleep(0.1)
- for session in sessions:
- session = session.session
- session.request_rerun(session._client_state)
- last_updated = time.time()
-
- if Server._singleton:
+ try:
+ _ = await websocket.recv()
+
+ while (time.time() - last_updated) < 1:
+ time.sleep(0.1)
+ for session in sessions:
+ session = session.session
+ session.request_rerun(session._client_state)
+ last_updated = time.time()
+ except websockets.exceptions.ConnectionClosedOK:
+ # The websocket is not enabled
+ break
+
+ if exists:
asyncio.run(update_fn())
diff --git a/src/lightning_app/utilities/app_logs.py b/src/lightning_app/utilities/app_logs.py
index 369adc5d09f11..04ee7435eb4fa 100644
--- a/src/lightning_app/utilities/app_logs.py
+++ b/src/lightning_app/utilities/app_logs.py
@@ -79,7 +79,7 @@ def _app_logs_reader(
# And each socket on separate thread pushing log event to print queue
# run_forever() will run until we close() the connection from outside
- log_threads = [Thread(target=work.run_forever) for work in log_sockets]
+ log_threads = [Thread(target=work.run_forever, daemon=True) for work in log_sockets]
# Establish connection and begin pushing logs to the print queue
for th in log_threads:
diff --git a/src/lightning_app/utilities/app_status.py b/src/lightning_app/utilities/app_status.py
new file mode 100644
index 0000000000000..232c3f0b65210
--- /dev/null
+++ b/src/lightning_app/utilities/app_status.py
@@ -0,0 +1,29 @@
+from datetime import datetime
+from typing import Any, Dict, Optional
+
+from pydantic import BaseModel
+
+
+class WorkStatus(BaseModel):
+ """The ``WorkStatus`` captures the status of a work according to the app."""
+
+ stage: str
+ timestamp: float
+ reason: Optional[str] = None
+ message: Optional[str] = None
+ count: int = 1
+
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ super().__init__(*args, **kwargs)
+
+ assert self.timestamp > 0 and self.timestamp < (int(datetime.now().timestamp()) + 10)
+
+
+class AppStatus(BaseModel):
+ """The ``AppStatus`` captures the current status of the app and its components."""
+
+ # ``True`` when the app UI is ready to be viewed
+ is_ui_ready: bool
+
+ # The statuses of ``LightningWork`` objects currently associated with this app
+ work_statuses: Dict[str, WorkStatus]
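These pydantic models define the payload published on the API state queue and returned by `GET /api/v1/status`; `WorkStatus` also sanity-checks that the timestamp is positive and not in the future beyond a 10-second margin. A small construction sketch:

```python
from datetime import datetime

from lightning_app.utilities.app_status import AppStatus, WorkStatus

work_status = WorkStatus(stage="running", timestamp=datetime.now().timestamp())
status = AppStatus(is_ui_ready=True, work_statuses={"root.server": work_status})

# This is the JSON shape served by /api/v1/status.
print(status.dict())

# A timestamp <= 0, or one far in the future, trips the assertion in WorkStatus.__init__.
```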
diff --git a/src/lightning_app/utilities/cloud.py b/src/lightning_app/utilities/cloud.py
index 20ab6d14827c9..5fb3cc768a973 100644
--- a/src/lightning_app/utilities/cloud.py
+++ b/src/lightning_app/utilities/cloud.py
@@ -1,5 +1,4 @@
import os
-import warnings
from lightning_cloud.openapi import V1Membership
@@ -25,11 +24,7 @@ def _get_project(client: LightningClient, project_id: str = LIGHTNING_CLOUD_PROJ
if len(projects.memberships) == 0:
raise ValueError("No valid projects found. Please reach out to lightning.ai team to create a project")
if len(projects.memberships) > 1:
- warnings.warn(
- f"It is currently not supported to have multiple projects but "
- f"found {len(projects.memberships)} projects."
- f" Defaulting to the project {projects.memberships[0].name}"
- )
+ print(f"Defaulting to the project: {projects.memberships[0].name}")
return projects.memberships[0]
@@ -39,4 +34,4 @@ def _sigterm_flow_handler(*_, app: "lightning_app.LightningApp"):
def is_running_in_cloud() -> bool:
"""Returns True if the Lightning App is running in the cloud."""
- return "LIGHTNING_APP_STATE_URL" in os.environ
+ return bool(int(os.environ.get("LAI_RUNNING_IN_CLOUD", "0"))) or "LIGHTNING_APP_STATE_URL" in os.environ
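`is_running_in_cloud()` now also honours the `LAI_RUNNING_IN_CLOUD` flag, which `CloudRuntime.load_app_from_file` sets temporarily so that an app being loaded for cloud dispatch behaves as if it were already running in the cloud. A sketch of the two signals, assuming a clean environment beforehand:

```python
import os

from lightning_app.utilities.cloud import is_running_in_cloud

assert not is_running_in_cloud()  # assumes neither variable is set in this shell

os.environ["LAI_RUNNING_IN_CLOUD"] = "1"  # set temporarily by CloudRuntime.load_app_from_file
assert is_running_in_cloud()

del os.environ["LAI_RUNNING_IN_CLOUD"]
os.environ["LIGHTNING_APP_STATE_URL"] = "http://example"  # the pre-existing cloud signal
assert is_running_in_cloud()
```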
diff --git a/src/lightning_app/utilities/enum.py b/src/lightning_app/utilities/enum.py
index 11cd7fabc4299..4c92ffba3db11 100644
--- a/src/lightning_app/utilities/enum.py
+++ b/src/lightning_app/utilities/enum.py
@@ -1,5 +1,4 @@
import enum
-from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Optional
@@ -47,18 +46,6 @@ class WorkStageStatus:
FAILED = "failed"
-@dataclass
-class WorkStatus:
- stage: WorkStageStatus
- timestamp: float
- reason: Optional[str] = None
- message: Optional[str] = None
- count: int = 1
-
- def __post_init__(self):
- assert self.timestamp > 0 and self.timestamp < (int(datetime.now().timestamp()) + 10)
-
-
def make_status(stage: str, message: Optional[str] = None, reason: Optional[str] = None):
status = {
"stage": stage,
diff --git a/src/lightning_app/utilities/packaging/cloud_compute.py b/src/lightning_app/utilities/packaging/cloud_compute.py
index ca6c9705ae866..db890b3301f76 100644
--- a/src/lightning_app/utilities/packaging/cloud_compute.py
+++ b/src/lightning_app/utilities/packaging/cloud_compute.py
@@ -71,7 +71,7 @@ class CloudCompute:
name: str = "default"
disk_size: int = 0
idle_timeout: Optional[int] = None
- shm_size: Optional[int] = 0
+ shm_size: Optional[int] = None
mounts: Optional[Union[Mount, List[Mount]]] = None
_internal_id: Optional[str] = None
@@ -80,6 +80,12 @@ def __post_init__(self) -> None:
self.name = self.name.lower()
+ if self.shm_size is None:
+ if "gpu" in self.name:
+ self.shm_size = 1024
+ else:
+ self.shm_size = 0
+
# All `default` CloudCompute are identified in the same way.
if self._internal_id is None:
self._internal_id = self._generate_id()
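With `shm_size` defaulting to `None`, the post-init hook now picks a shared-memory size of 1024 for any compute whose name contains "gpu" and 0 otherwise, while explicit values are left untouched. For example:

```python
from lightning_app import CloudCompute

gpu = CloudCompute("gpu")
cpu = CloudCompute("default")
assert gpu.shm_size == 1024  # names containing "gpu" get the larger default
assert cpu.shm_size == 0

custom = CloudCompute("gpu", shm_size=2048)
assert custom.shm_size == 2048  # explicit values are respected
```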
diff --git a/src/lightning_app/utilities/scheduler.py b/src/lightning_app/utilities/scheduler.py
index e45b0879246b9..b15e49a92673d 100644
--- a/src/lightning_app/utilities/scheduler.py
+++ b/src/lightning_app/utilities/scheduler.py
@@ -1,10 +1,9 @@
import threading
-from copy import deepcopy
from datetime import datetime
from typing import Optional
from croniter import croniter
-from deepdiff import DeepDiff, Delta
+from deepdiff import Delta
from lightning_app.utilities.proxies import ComponentDelta
@@ -34,11 +33,15 @@ def run_once(self):
next_event = croniter(metadata["cron_pattern"], start_time).get_next(datetime)
# When the event is reached, send a delta to activate scheduling.
if current_date > next_event:
- flow = self._app.get_component_by_name(metadata["name"])
- previous_state = deepcopy(flow.state)
- flow._enable_schedule(call_hash)
component_delta = ComponentDelta(
- id=flow.name, delta=Delta(DeepDiff(previous_state, flow.state, verbose_level=2))
+ id=metadata["name"],
+ delta=Delta(
+ {
+ "values_changed": {
+ f"root['calls']['scheduling']['{call_hash}']['running']": {"new_value": True}
+ }
+ }
+ ),
)
self._app.delta_queue.put(component_delta)
metadata["start_time"] = next_event.isoformat()
diff --git a/src/lightning_lite/CHANGELOG.md b/src/lightning_lite/CHANGELOG.md
index baae664a95cb2..3f661b727c3a7 100644
--- a/src/lightning_lite/CHANGELOG.md
+++ b/src/lightning_lite/CHANGELOG.md
@@ -5,6 +5,17 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
+## [1.8.6] - 2022-12-21
+
+### Added
+
+
+### Changed
+
+
+### Fixed
+
+
## [1.8.5] - 2022-12-15
- minor cleaning
diff --git a/src/lightning_lite/__version__.py b/src/lightning_lite/__version__.py
index 4a295c3c7c531..fd11d27b1347a 100644
--- a/src/lightning_lite/__version__.py
+++ b/src/lightning_lite/__version__.py
@@ -1 +1 @@
-version = "1.8.5.post0"
+version = "1.8.6"
diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index 29cf1e28428d2..dfdeea3252bb8 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -5,6 +5,17 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
+## [1.8.6] - 2022-12-21
+
+### Added
+
+
+### Changed
+
+
+### Fixed
+
+
## [1.8.5] - 2022-12-15
- Add function to remove checkpoint to allow override for extended classes ([#16067](https://github.com/Lightning-AI/lightning/pull/16067))
diff --git a/src/pytorch_lightning/README.md b/src/pytorch_lightning/README.md
index 54c3db39c4973..67ddf6eeca6a6 100644
--- a/src/pytorch_lightning/README.md
+++ b/src/pytorch_lightning/README.md
@@ -78,15 +78,15 @@ Lightning is rigorously tested across multiple CPUs, GPUs, TPUs, IPUs, and HPUs
-| System / PyTorch ver. | 1.9 | 1.10 | 1.12 (latest) |
-| :------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
-| Linux py3.7 \[GPUs\*\*\] | - | - | - |
-| Linux py3.7 \[TPUs\*\*\*\] | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/tpu-tests.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/tpu-tests.yml) | - | - |
-| Linux py3.8 \[IPUs\] | [![Build Status]()](https://dev.azure.com/Lightning-AI/lightning/_build/latest?definitionId=25&branchName=master) | - | - |
-| Linux py3.8 \[HPUs\] | - | [![Build Status]()](https://dev.azure.com/Lightning-AI/lightning/_build/latest?definitionId=26&branchName=master) | - |
-| Linux py3.{7,9} | - | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-tests.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-tests.yml) |
-| OSX py3.{7,9} | - | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-tests.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-tests.yml) |
-| Windows py3.{7,9} | - | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-tests.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-pytorch-tests.yml) |
+| System / PyTorch ver. | 1.10 | 1.12 |
+| :------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| Linux py3.7 \[GPUs\*\*\] | - | - |
+| Linux py3.7 \[TPUs\*\*\*\] | - | - |
+| Linux py3.8 \[IPUs\] | - | - |
+| Linux py3.8 \[HPUs\] | [![Build Status]()](https://dev.azure.com/Lightning-AI/lightning/_build/latest?definitionId=26&branchName=master) | - |
+| Linux py3.{7,9} | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml) |
+| OSX py3.{7,9} | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml) |
+| Windows py3.{7,9} | - | [![Test](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg?branch=master&event=push)](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml) |
- _\*\* tests run on two NVIDIA P100_
- _\*\*\* tests run on Google GKE TPUv2/3. TPU py3.7 means we support Colab and Kaggle env._
diff --git a/src/pytorch_lightning/__version__.py b/src/pytorch_lightning/__version__.py
index 4a295c3c7c531..fd11d27b1347a 100644
--- a/src/pytorch_lightning/__version__.py
+++ b/src/pytorch_lightning/__version__.py
@@ -1 +1 @@
-version = "1.8.5.post0"
+version = "1.8.6"
diff --git a/tests/tests_app/components/multi_node/test_base.py b/tests/tests_app/components/multi_node/test_base.py
index 2c6aed1120c0a..2eb73d7c026f6 100644
--- a/tests/tests_app/components/multi_node/test_base.py
+++ b/tests/tests_app/components/multi_node/test_base.py
@@ -13,7 +13,7 @@ class Work(LightningWork):
def run(self):
pass
- with pytest.warns(UserWarning, match=escape("You set MultiNode(num_nodes=1, ...)` but ")):
+ with pytest.warns(UserWarning, match=escape("You set MultiNode(num_nodes=2, ...)` but ")):
MultiNode(Work, num_nodes=2, cloud_compute=CloudCompute("gpu"))
with no_warning_call(UserWarning, match=escape("You set MultiNode(num_nodes=1, ...)` but ")):
diff --git a/tests/tests_app/components/serve/test_auto_scaler.py b/tests/tests_app/components/serve/test_auto_scaler.py
new file mode 100644
index 0000000000000..e53c7890696a4
--- /dev/null
+++ b/tests/tests_app/components/serve/test_auto_scaler.py
@@ -0,0 +1,230 @@
+import time
+import uuid
+from unittest import mock
+from unittest.mock import patch
+
+import pytest
+from fastapi import HTTPException
+
+from lightning_app import CloudCompute, LightningWork
+from lightning_app.components import AutoScaler, ColdStartProxy, Text
+from lightning_app.components.serve.auto_scaler import _LoadBalancer
+
+
+class EmptyWork(LightningWork):
+ def run(self):
+ pass
+
+
+class AutoScaler1(AutoScaler):
+ def scale(self, replicas: int, metrics) -> int:
+ # only upscale
+ return replicas + 1
+
+
+class AutoScaler2(AutoScaler):
+ def scale(self, replicas: int, metrics) -> int:
+ # only downscale
+ return replicas - 1
+
+
+def test_num_replicas_after_init():
+ """Test the number of works is the same as min_replicas after initialization."""
+ min_replicas = 2
+ auto_scaler = AutoScaler(EmptyWork, min_replicas=min_replicas)
+ assert auto_scaler.num_replicas == min_replicas
+
+
+@patch("uvicorn.run")
+@patch("lightning_app.components.serve.auto_scaler._LoadBalancer.url")
+@patch("lightning_app.components.serve.auto_scaler.AutoScaler.num_pending_requests")
+def test_num_replicas_not_above_max_replicas(*_):
+ """Test self.num_replicas doesn't exceed max_replicas."""
+ max_replicas = 6
+ auto_scaler = AutoScaler1(
+ EmptyWork,
+ min_replicas=1,
+ max_replicas=max_replicas,
+ scale_out_interval=0.001,
+ scale_in_interval=0.001,
+ )
+
+ for _ in range(max_replicas + 1):
+ time.sleep(0.002)
+ auto_scaler.run()
+
+ assert auto_scaler.num_replicas == max_replicas
+
+
+@patch("uvicorn.run")
+@patch("lightning_app.components.serve.auto_scaler._LoadBalancer.url")
+@patch("lightning_app.components.serve.auto_scaler.AutoScaler.num_pending_requests")
+def test_num_replicas_not_below_min_replicas(*_):
+ """Test self.num_replicas doesn't go below min_replicas."""
+ min_replicas = 1
+ auto_scaler = AutoScaler2(
+ EmptyWork,
+ min_replicas=min_replicas,
+ max_replicas=4,
+ scale_out_interval=0.001,
+ scale_in_interval=0.001,
+ )
+
+ for _ in range(3):
+ time.sleep(0.002)
+ auto_scaler.run()
+
+ assert auto_scaler.num_replicas == min_replicas
+
+
+@pytest.mark.parametrize(
+ "replicas, metrics, expected_replicas",
+ [
+ pytest.param(1, {"pending_requests": 1, "pending_works": 0}, 2, id="increase if no pending work"),
+ pytest.param(1, {"pending_requests": 1, "pending_works": 1}, 1, id="dont increase if pending works"),
+ pytest.param(8, {"pending_requests": 1, "pending_works": 0}, 7, id="reduce if requests < 25% capacity"),
+ pytest.param(8, {"pending_requests": 2, "pending_works": 0}, 8, id="dont reduce if requests >= 25% capacity"),
+ ],
+)
+def test_scale(replicas, metrics, expected_replicas):
+ """Test `scale()`, the default scaling strategy."""
+ auto_scaler = AutoScaler(
+ EmptyWork,
+ min_replicas=1,
+ max_replicas=8,
+ max_batch_size=1,
+ )
+
+ assert auto_scaler.scale(replicas, metrics) == expected_replicas
+
+
+def test_scale_from_zero_min_replica():
+ auto_scaler = AutoScaler(
+ EmptyWork,
+ min_replicas=0,
+ max_replicas=2,
+ max_batch_size=10,
+ )
+
+ resp = auto_scaler.scale(0, {"pending_requests": 0, "pending_works": 0})
+ assert resp == 0
+
+ resp = auto_scaler.scale(0, {"pending_requests": 1, "pending_works": 0})
+ assert resp == 1
+
+ resp = auto_scaler.scale(0, {"pending_requests": 1, "pending_works": 1})
+ assert resp <= 0
+
+
+def test_create_work_cloud_compute_cloned():
+ """Test CloudCompute is cloned to avoid creating multiple works in a single machine."""
+ cloud_compute = CloudCompute("gpu")
+ auto_scaler = AutoScaler(EmptyWork, cloud_compute=cloud_compute)
+ _ = auto_scaler.create_work()
+ assert auto_scaler._work_kwargs["cloud_compute"] is not cloud_compute
+
+
+fastapi_mock = mock.MagicMock()
+mocked_fastapi_creator = mock.MagicMock(return_value=fastapi_mock)
+
+
+@patch("lightning_app.components.serve.auto_scaler._create_fastapi", mocked_fastapi_creater)
+@patch("lightning_app.components.serve.auto_scaler.uvicorn.run", mock.MagicMock())
+def test_API_ACCESS_ENDPOINT_creation():
+ auto_scaler = AutoScaler(EmptyWork, input_type=Text, output_type=Text)
+ assert auto_scaler.load_balancer._api_name == "EmptyWork"
+
+ auto_scaler.load_balancer.run()
+ fastapi_mock.mount.assert_called_once_with("/endpoint-info", mock.ANY, name="static")
+
+
+def test_autoscaler_scale_up(monkeypatch):
+ monkeypatch.setattr(AutoScaler, "num_pending_works", 0)
+ monkeypatch.setattr(AutoScaler, "num_pending_requests", 100)
+ monkeypatch.setattr(AutoScaler, "scale", mock.MagicMock(return_value=1))
+ monkeypatch.setattr(AutoScaler, "create_work", mock.MagicMock())
+ monkeypatch.setattr(AutoScaler, "add_work", mock.MagicMock())
+
+ auto_scaler = AutoScaler(EmptyWork, min_replicas=0, max_replicas=4, scale_out_interval=0.001)
+
+ # Mocking the attributes
+ auto_scaler._last_autoscale = time.time() - 100000
+ auto_scaler.num_replicas = 0
+
+ # triggering scale up
+ auto_scaler.autoscale()
+ auto_scaler.scale.assert_called_once()
+ auto_scaler.create_work.assert_called_once()
+ auto_scaler.add_work.assert_called_once()
+
+
+def test_autoscaler_scale_down(monkeypatch):
+ monkeypatch.setattr(AutoScaler, "num_pending_works", 0)
+ monkeypatch.setattr(AutoScaler, "num_pending_requests", 0)
+ monkeypatch.setattr(AutoScaler, "scale", mock.MagicMock(return_value=0))
+ monkeypatch.setattr(AutoScaler, "remove_work", mock.MagicMock())
+ monkeypatch.setattr(AutoScaler, "workers", mock.MagicMock())
+
+ auto_scaler = AutoScaler(EmptyWork, min_replicas=0, max_replicas=4, scale_in_interval=0.001)
+
+ # Mocking the attributes
+ auto_scaler._last_autoscale = time.time() - 100000
+ auto_scaler.num_replicas = 1
+ auto_scaler.__dict__["load_balancer"] = mock.MagicMock()
+
+ # triggering scale down
+ auto_scaler.autoscale()
+ auto_scaler.scale.assert_called_once()
+ auto_scaler.remove_work.assert_called_once()
+
+
+class TestLoadBalancerProcessRequest:
+ @pytest.mark.asyncio
+ async def test_workers_not_ready_with_cold_start_proxy(self, monkeypatch):
+ monkeypatch.setattr(ColdStartProxy, "handle_request", mock.AsyncMock())
+ load_balancer = _LoadBalancer(
+ input_type=Text, output_type=Text, endpoint="/predict", cold_start_proxy=ColdStartProxy("url")
+ )
+ req_id = uuid.uuid4().hex
+ await load_balancer.process_request("test", req_id)
+ load_balancer._cold_start_proxy.handle_request.assert_called_once_with("test")
+
+ @pytest.mark.asyncio
+ async def test_workers_not_ready_without_cold_start_proxy(self, monkeypatch):
+ load_balancer = _LoadBalancer(
+ input_type=Text,
+ output_type=Text,
+ endpoint="/predict",
+ )
+ req_id = uuid.uuid4().hex
+ # populating the responses so the while loop exits
+ load_balancer._responses = {req_id: "Dummy"}
+ with pytest.raises(HTTPException):
+ await load_balancer.process_request("test", req_id)
+
+ @pytest.mark.asyncio
+ async def test_workers_have_no_capacity_with_cold_start_proxy(self, monkeypatch):
+ monkeypatch.setattr(ColdStartProxy, "handle_request", mock.AsyncMock())
+ load_balancer = _LoadBalancer(
+ input_type=Text, output_type=Text, endpoint="/predict", cold_start_proxy=ColdStartProxy("url")
+ )
+ load_balancer._fastapi_app = mock.MagicMock()
+ load_balancer._fastapi_app.num_current_requests = 1000
+ load_balancer._servers.append(mock.MagicMock())
+ req_id = uuid.uuid4().hex
+ await load_balancer.process_request("test", req_id)
+ load_balancer._cold_start_proxy.handle_request.assert_called_once_with("test")
+
+ @pytest.mark.asyncio
+ async def test_workers_are_free(self):
+ load_balancer = _LoadBalancer(
+ input_type=Text,
+ output_type=Text,
+ endpoint="/predict",
+ )
+ load_balancer._servers.append(mock.MagicMock())
+ req_id = uuid.uuid4().hex
+ # populating the responses so the while loop exits
+ load_balancer._responses = {req_id: "Dummy"}
+ await load_balancer.process_request("test", req_id)
+ assert load_balancer._batch == [(req_id, "test")]
diff --git a/tests/tests_app/components/serve/test_model_inference_api.py b/tests/tests_app/components/serve/test_model_inference_api.py
index 17ed09aa2eea8..06a2ea9186ff6 100644
--- a/tests/tests_app/components/serve/test_model_inference_api.py
+++ b/tests/tests_app/components/serve/test_model_inference_api.py
@@ -48,6 +48,7 @@ def test_model_inference_api(workers):
process.terminate()
# TODO: Investigate why this doesn't match exactly `imgstr`.
assert res.json()
+ process.kill()
class EmptyServer(serve.ModelInferenceAPI):
diff --git a/tests/tests_app/components/serve/test_python_server.py b/tests/tests_app/components/serve/test_python_server.py
index 313638e9ec42a..f497927a4897b 100644
--- a/tests/tests_app/components/serve/test_python_server.py
+++ b/tests/tests_app/components/serve/test_python_server.py
@@ -29,17 +29,18 @@ def test_python_server_component():
res = session.post(f"http://127.0.0.1:{port}/predict", json={"payload": "test"})
process.terminate()
assert res.json()["prediction"] == "test"
+ process.kill()
def test_image_sample_data():
- data = Image()._get_sample_data()
+ data = Image().get_sample_data()
assert isinstance(data, dict)
assert "image" in data
assert len(data["image"]) > 100
def test_number_sample_data():
- data = Number()._get_sample_data()
+ data = Number().get_sample_data()
assert isinstance(data, dict)
assert "prediction" in data
assert data["prediction"] == 463
diff --git a/tests/tests_app/components/test_auto_scaler.py b/tests/tests_app/components/test_auto_scaler.py
deleted file mode 100644
index 672b05bbc9a15..0000000000000
--- a/tests/tests_app/components/test_auto_scaler.py
+++ /dev/null
@@ -1,100 +0,0 @@
-import time
-from unittest.mock import patch
-
-import pytest
-
-from lightning_app import CloudCompute, LightningWork
-from lightning_app.components import AutoScaler
-
-
-class EmptyWork(LightningWork):
- def run(self):
- pass
-
-
-class AutoScaler1(AutoScaler):
- def scale(self, replicas: int, metrics) -> int:
- # only upscale
- return replicas + 1
-
-
-class AutoScaler2(AutoScaler):
- def scale(self, replicas: int, metrics) -> int:
- # only downscale
- return replicas - 1
-
-
-def test_num_replicas_after_init():
- """Test the number of works is the same as min_replicas after initialization."""
- min_replicas = 2
- auto_scaler = AutoScaler(EmptyWork, min_replicas=min_replicas)
- assert auto_scaler.num_replicas == min_replicas
-
-
-@patch("uvicorn.run")
-@patch("lightning_app.components.auto_scaler._LoadBalancer.url")
-@patch("lightning_app.components.auto_scaler.AutoScaler.num_pending_requests")
-def test_num_replicas_not_above_max_replicas(*_):
- """Test self.num_replicas doesn't exceed max_replicas."""
- max_replicas = 6
- auto_scaler = AutoScaler1(
- EmptyWork,
- min_replicas=1,
- max_replicas=max_replicas,
- autoscale_interval=0.001,
- )
-
- for _ in range(max_replicas + 1):
- time.sleep(0.002)
- auto_scaler.run()
-
- assert auto_scaler.num_replicas == max_replicas
-
-
-@patch("uvicorn.run")
-@patch("lightning_app.components.auto_scaler._LoadBalancer.url")
-@patch("lightning_app.components.auto_scaler.AutoScaler.num_pending_requests")
-def test_num_replicas_not_belo_min_replicas(*_):
- """Test self.num_replicas doesn't exceed max_replicas."""
- min_replicas = 1
- auto_scaler = AutoScaler2(
- EmptyWork,
- min_replicas=min_replicas,
- max_replicas=4,
- autoscale_interval=0.001,
- )
-
- for _ in range(3):
- time.sleep(0.002)
- auto_scaler.run()
-
- assert auto_scaler.num_replicas == min_replicas
-
-
-@pytest.mark.parametrize(
- "replicas, metrics, expected_replicas",
- [
- pytest.param(1, {"pending_requests": 1, "pending_works": 0}, 2, id="increase if no pending work"),
- pytest.param(1, {"pending_requests": 1, "pending_works": 1}, 1, id="dont increase if pending works"),
- pytest.param(8, {"pending_requests": 1, "pending_works": 0}, 7, id="reduce if requests < 25% capacity"),
- pytest.param(8, {"pending_requests": 2, "pending_works": 0}, 8, id="dont reduce if requests >= 25% capacity"),
- ],
-)
-def test_scale(replicas, metrics, expected_replicas):
- """Test `scale()`, the default scaling strategy."""
- auto_scaler = AutoScaler(
- EmptyWork,
- min_replicas=1,
- max_replicas=8,
- max_batch_size=1,
- )
-
- assert auto_scaler.scale(replicas, metrics) == expected_replicas
-
-
-def test_create_work_cloud_compute_cloned():
- """Test CloudCompute is cloned to avoid creating multiple works in a single machine."""
- cloud_compute = CloudCompute("gpu")
- auto_scaler = AutoScaler(EmptyWork, cloud_compute=cloud_compute)
- _ = auto_scaler.create_work()
- assert auto_scaler._work_kwargs["cloud_compute"] is not cloud_compute
diff --git a/tests/tests_app/conftest.py b/tests/tests_app/conftest.py
index 6f74feb8a360c..d0df4feaa11fa 100644
--- a/tests/tests_app/conftest.py
+++ b/tests/tests_app/conftest.py
@@ -1,14 +1,17 @@
import os
import shutil
+import signal
import threading
from datetime import datetime
from pathlib import Path
+from threading import Thread
import psutil
import py
import pytest
from lightning_app.storage.path import _storage_root_dir
+from lightning_app.utilities.app_helpers import _collect_child_process_pids
from lightning_app.utilities.component import _set_context
from lightning_app.utilities.packaging import cloud_compute
from lightning_app.utilities.packaging.app_config import _APP_CONFIG_FILENAME
@@ -16,6 +19,15 @@
os.environ["LIGHTNING_DISPATCHED"] = "1"
+original_method = Thread._wait_for_tstate_lock
+
+
+def fn(self, *args, timeout=None, **kwargs):
+ original_method(self, *args, timeout=1, **kwargs)
+
+
+Thread._wait_for_tstate_lock = fn
+
def pytest_sessionfinish(session, exitstatus):
"""Pytest hook that get called after whole test run finished, right before returning the exit status to the
@@ -40,6 +52,9 @@ def pytest_sessionfinish(session, exitstatus):
if t is not main_thread:
t.join(0)
+ for child_pid in _collect_child_process_pids(os.getpid()):
+ os.kill(child_pid, signal.SIGTERM)
+
@pytest.fixture(scope="function", autouse=True)
def cleanup():
diff --git a/tests/tests_app/core/test_lightning_api.py b/tests/tests_app/core/test_lightning_api.py
index 04b89c927941a..3003ecdd62e2d 100644
--- a/tests/tests_app/core/test_lightning_api.py
+++ b/tests/tests_app/core/test_lightning_api.py
@@ -5,6 +5,7 @@
import sys
from copy import deepcopy
from multiprocessing import Process
+from pathlib import Path
from time import sleep, time
from unittest import mock
@@ -12,7 +13,7 @@
import pytest
import requests
from deepdiff import DeepDiff, Delta
-from fastapi import HTTPException
+from fastapi import HTTPException, Request
from httpx import AsyncClient
from pydantic import BaseModel
@@ -31,6 +32,7 @@
from lightning_app.runners import MultiProcessRuntime
from lightning_app.storage.drive import Drive
from lightning_app.testing.helpers import _MockQueue
+from lightning_app.utilities.app_status import AppStatus
from lightning_app.utilities.component import _set_frontend_context, _set_work_context
from lightning_app.utilities.enum import AppStage
from lightning_app.utilities.load_app import extract_metadata_from_app
@@ -195,7 +197,7 @@ def test_update_publish_state_and_maybe_refresh_ui():
publish_state_queue = _MockQueue("publish_state_queue")
api_response_queue = _MockQueue("api_response_queue")
- publish_state_queue.put(app.state_with_changes)
+ publish_state_queue.put((app.state_with_changes, None))
thread = UIRefresher(publish_state_queue, api_response_queue)
thread.run_once()
@@ -226,7 +228,7 @@ def get(self, timeout: int = 0):
has_started_queue = _MockQueue("has_started_queue")
api_response_queue = _MockQueue("api_response_queue")
state = app.state_with_changes
- publish_state_queue.put(state)
+ publish_state_queue.put((state, AppStatus(is_ui_ready=True, work_statuses={})))
spec = extract_metadata_from_app(app)
ui_refresher = start_server(
publish_state_queue,
@@ -284,6 +286,9 @@ def get(self, timeout: int = 0):
{"name": "main_4", "content": "https://te"},
]
+ response = await client.get("/api/v1/status")
+ assert response.json() == {"is_ui_ready": True, "work_statuses": {}}
+
response = await client.post("/api/v1/state", json={"state": new_state}, headers=headers)
assert change_state_queue._queue[1].to_dict() == {
"values_changed": {"root['vars']['counter']": {"new_value": 1}}
@@ -380,7 +385,7 @@ async def test_health_endpoint_success():
@pytest.mark.anyio
async def test_health_endpoint_failure(monkeypatch):
monkeypatch.setenv("LIGHTNING_APP_STATE_URL", "http://someurl") # adding this to make is_running_in_cloud pass
- monkeypatch.setattr(api, "CLOUD_QUEUE_TYPE", "redis")
+ monkeypatch.setitem(os.environ, "LIGHTNING_CLOUD_QUEUE_TYPE", "redis")
async with AsyncClient(app=fastapi_service, base_url="http://test") as client:
# will respond 503 if redis is not running
response = await client.get("/healthz")
@@ -479,10 +484,13 @@ def run(self):
if self.counter == 501:
self._exit()
- def request(self, config: InputRequestModel) -> OutputRequestModel:
+ def request(self, config: InputRequestModel, request: Request) -> OutputRequestModel:
self.counter += 1
if config.index % 5 == 0:
raise HTTPException(status_code=400, detail="HERE")
+ assert request.body()
+ assert request.json()
+ assert request.headers
return OutputRequestModel(name=config.name, counter=self.counter)
def configure_api(self):
@@ -554,3 +562,37 @@ def test_configure_api():
sleep(0.1)
time_left -= 0.1
assert process.exitcode == 0
+ process.kill()
+
+
+@pytest.mark.anyio
+@mock.patch("lightning_app.core.api.UIRefresher", mock.MagicMock())
+async def test_get_annotations(tmpdir):
+ cwd = os.getcwd()
+ os.chdir(tmpdir)
+
+ Path("lightning-annotations.json").write_text('[{"test": 3}]')
+
+ try:
+ app = AppStageTestingApp(FlowA(), log_level="debug")
+ app._update_layout()
+ app.stage = AppStage.BLOCKING
+ change_state_queue = _MockQueue("change_state_queue")
+ has_started_queue = _MockQueue("has_started_queue")
+ api_response_queue = _MockQueue("api_response_queue")
+ spec = extract_metadata_from_app(app)
+ start_server(
+ None,
+ change_state_queue,
+ api_response_queue,
+ has_started_queue=has_started_queue,
+ uvicorn_run=False,
+ spec=spec,
+ )
+
+ async with AsyncClient(app=fastapi_service, base_url="http://test") as client:
+ response = await client.get("/api/v1/annotations")
+ assert response.json() == [{"test": 3}]
+ finally:
+ # Cleanup
+ os.chdir(cwd)
diff --git a/tests/tests_app/core/test_lightning_app.py b/tests/tests_app/core/test_lightning_app.py
index ea552adad7972..68284434d15c8 100644
--- a/tests/tests_app/core/test_lightning_app.py
+++ b/tests/tests_app/core/test_lightning_app.py
@@ -2,7 +2,7 @@
import os
import pickle
from re import escape
-from time import sleep
+from time import sleep, time
from unittest import mock
import pytest
@@ -124,6 +124,7 @@ def test_simple_app(tmpdir):
"_paths": {},
"_port": None,
"_restarting": False,
+ "_display_name": "",
},
"calls": {"latest_call_hash": None},
"changes": {},
@@ -140,6 +141,7 @@ def test_simple_app(tmpdir):
"_paths": {},
"_port": None,
"_restarting": False,
+ "_display_name": "",
},
"calls": {"latest_call_hash": None},
"changes": {},
@@ -480,6 +482,21 @@ def make_delta(i):
assert generated > expect
+def test_lightning_app_aggregation_empty():
+ """Verify the while loop exits before `state_accumulate_wait` is reached if no deltas are found."""
+
+ class SlowQueue(MultiProcessQueue):
+ def get(self, timeout):
+ out = super().get(timeout)
+ return out
+
+ app = LightningApp(EmptyFlow())
+ app.delta_queue = SlowQueue("api_delta_queue", 0)
+ t0 = time()
+ assert app._collect_deltas_from_ui_and_work_queues() == []
+ assert (time() - t0) < app.state_accumulate_wait
+
+
class SimpleFlow2(LightningFlow):
def __init__(self):
super().__init__()
@@ -639,6 +656,7 @@ def run(self):
self.flow.run()
+@pytest.mark.skip(reason="reloading isn't properly supported")
def test_lightning_app_checkpointing_with_nested_flows():
work = CheckpointCounter()
app = LightningApp(CheckpointFlow(work))
@@ -969,7 +987,7 @@ def run(self):
def test_state_size_constant_growth():
app = LightningApp(SizeFlow())
MultiProcessRuntime(app, start_server=False).dispatch()
- assert app.root._state_sizes[0] <= 7824
+ assert app.root._state_sizes[0] <= 7888
assert app.root._state_sizes[20] <= 26500
diff --git a/tests/tests_app/core/test_lightning_flow.py b/tests/tests_app/core/test_lightning_flow.py
index c8e9921f29eec..7e547d75c5117 100644
--- a/tests/tests_app/core/test_lightning_flow.py
+++ b/tests/tests_app/core/test_lightning_flow.py
@@ -5,21 +5,21 @@
from dataclasses import dataclass
from functools import partial
from time import time
-from unittest.mock import ANY, MagicMock
+from unittest.mock import ANY
import pytest
from deepdiff import DeepDiff, Delta
import lightning_app
from lightning_app import CloudCompute, LightningApp
-from lightning_app.core.flow import LightningFlow
+from lightning_app.core.flow import _RootFlow, LightningFlow
from lightning_app.core.work import LightningWork
from lightning_app.runners import MultiProcessRuntime
from lightning_app.storage import Path
from lightning_app.storage.path import _storage_root_dir
from lightning_app.structures import Dict as LDict
from lightning_app.structures import List as LList
-from lightning_app.testing.helpers import EmptyFlow, EmptyWork
+from lightning_app.testing.helpers import _MockQueue, EmptyFlow, EmptyWork
from lightning_app.utilities.app_helpers import (
_delta_to_app_state_delta,
_LightningAppRef,
@@ -329,6 +329,7 @@ def run(self):
"_paths": {},
"_restarting": False,
"_internal_ip": "",
+ "_display_name": "",
"_cloud_compute": {
"type": "__cloud_compute__",
"name": "default",
@@ -352,6 +353,7 @@ def run(self):
"_paths": {},
"_restarting": False,
"_internal_ip": "",
+ "_display_name": "",
"_cloud_compute": {
"type": "__cloud_compute__",
"name": "default",
@@ -391,6 +393,7 @@ def run(self):
"_paths": {},
"_restarting": False,
"_internal_ip": "",
+ "_display_name": "",
"_cloud_compute": {
"type": "__cloud_compute__",
"name": "default",
@@ -414,6 +417,7 @@ def run(self):
"_paths": {},
"_restarting": False,
"_internal_ip": "",
+ "_display_name": "",
"_cloud_compute": {
"type": "__cloud_compute__",
"name": "default",
@@ -644,14 +648,14 @@ def run(self):
if len(self._last_times) < 3:
self._last_times.append(time())
else:
- assert abs((time() - self._last_times[-1]) - self.target) < 3
+ assert abs((time() - self._last_times[-1]) - self.target) < 12
self._exit()
def test_scheduling_api():
app = LightningApp(FlowSchedule())
- MultiProcessRuntime(app, start_server=True).dispatch()
+ MultiProcessRuntime(app, start_server=False).dispatch()
def test_lightning_flow():
@@ -864,10 +868,10 @@ def test_lightning_flow_flows_and_works():
class WorkReady(LightningWork):
def __init__(self):
super().__init__(parallel=True)
- self.counter = 0
+ self.ready = False
def run(self):
- self.counter += 1
+ self.ready = True
class FlowReady(LightningFlow):
@@ -886,22 +890,44 @@ def run(self):
self._exit()
-def test_flow_ready():
- """This test validates the api publish state queue is populated only once ready is True."""
+class RootFlowReady(_RootFlow):
+ def __init__(self):
+ super().__init__(WorkReady())
+
+
+@pytest.mark.parametrize("flow", [FlowReady, RootFlowReady])
+def test_flow_ready(flow):
+ """This test validates that the app status queue is populated correctly."""
+
+ mock_queue = _MockQueue("api_publish_state_queue")
def run_patch(method):
- app.api_publish_state_queue = MagicMock()
- app.should_publish_changes_to_api = False
+ app.should_publish_changes_to_api = True
+ app.api_publish_state_queue = mock_queue
method()
- app = LightningApp(FlowReady())
+ state = {"done": False}
+
+ def lagged_run_once(method):
+ """Ensure that the full loop is run after the app exits."""
+ new_done = method()
+ if state["done"]:
+ return True
+ state["done"] = new_done
+ return False
+
+ app = LightningApp(flow())
app._run = partial(run_patch, method=app._run)
+ app.run_once = partial(lagged_run_once, method=app.run_once)
MultiProcessRuntime(app, start_server=False).dispatch()
- # Validates the state has been added only when ready was true.
- state = app.api_publish_state_queue.put._mock_call_args[0][0]
- call_hash = state["works"]["w"]["calls"]["latest_call_hash"]
- assert state["works"]["w"]["calls"][call_hash]["statuses"][0]["stage"] == "succeeded"
+ _, first_status = mock_queue.get()
+ assert not first_status.is_ui_ready
+
+ _, last_status = mock_queue.get()
+ while len(mock_queue) > 0:
+ _, last_status = mock_queue.get()
+ assert last_status.is_ui_ready
def test_structures_register_work_cloudcompute():
diff --git a/tests/tests_app/core/test_lightning_work.py b/tests/tests_app/core/test_lightning_work.py
index cb97eabfa237c..8eb69c6539168 100644
--- a/tests/tests_app/core/test_lightning_work.py
+++ b/tests/tests_app/core/test_lightning_work.py
@@ -1,6 +1,6 @@
from queue import Empty
from re import escape
-from unittest.mock import Mock
+from unittest.mock import MagicMock, Mock
import pytest
@@ -11,7 +11,7 @@
from lightning_app.storage import Path
from lightning_app.testing.helpers import _MockQueue, EmptyFlow, EmptyWork
from lightning_app.testing.testing import LightningTestApp
-from lightning_app.utilities.enum import WorkStageStatus
+from lightning_app.utilities.enum import make_status, WorkStageStatus
from lightning_app.utilities.exceptions import LightningWorkException
from lightning_app.utilities.packaging.build_config import BuildConfig
from lightning_app.utilities.proxies import ProxyWorkRun, WorkRunner
@@ -203,7 +203,8 @@ def run(self):
pass
res = delta_queue._queue[0].delta.to_dict()["iterable_item_added"]
- res_end = delta_queue._queue[1].delta.to_dict()["iterable_item_added"]
+ index = 1 if len(delta_queue._queue) == 2 else 2
+ res_end = delta_queue._queue[index].delta.to_dict()["iterable_item_added"]
if enable_exception:
exception_cls = Exception if raise_exception else Empty
assert isinstance(error_queue._queue[0], exception_cls)
@@ -211,7 +212,8 @@ def run(self):
res_end[f"root['calls']['{call_hash}']['statuses'][1]"]["message"] == "Custom Exception"
else:
assert res[f"root['calls']['{call_hash}']['statuses'][0]"]["stage"] == "running"
- assert res_end[f"root['calls']['{call_hash}']['statuses'][1]"]["stage"] == "succeeded"
+ key = f"root['calls']['{call_hash}']['statuses'][1]"
+ assert res_end[key]["stage"] == "succeeded"
# Stop blocking and let the thread join
work_runner.copier.join()
@@ -384,3 +386,36 @@ def run(self):
def test_lightning_app_work_start(cache_calls, parallel):
app = LightningApp(FlowStart(cache_calls, parallel))
MultiProcessRuntime(app, start_server=False).dispatch()
+
+
+def test_lightning_work_delete():
+ work = WorkCounter()
+
+ with pytest.raises(Exception, match="Can't delete the work"):
+ work.delete()
+
+ mock = MagicMock()
+ work._backend = mock
+ work.delete()
+ assert work == mock.delete_work._mock_call_args_list[0].args[1]
+
+
+class WorkDisplay(LightningWork):
+ def __init__(self):
+ super().__init__()
+
+ def run(self):
+ pass
+
+
+def test_lightning_work_display_name():
+ work = WorkDisplay()
+ assert work.state_vars["vars"]["_display_name"] == ""
+ work.display_name = "Hello"
+ assert work.state_vars["vars"]["_display_name"] == "Hello"
+
+ work._calls["latest_call_hash"] = "test"
+ work._calls["test"] = {"statuses": [make_status(WorkStageStatus.PENDING)]}
+ with pytest.raises(RuntimeError, match="The display name can be set only before the work has started."):
+ work.display_name = "HELLO"
+ work.display_name = "Hello"
diff --git a/tests/tests_app/runners/backends/__init__.py b/tests/tests_app/runners/backends/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/tests/tests_app/runners/backends/test_mp_process.py b/tests/tests_app/runners/backends/test_mp_process.py
new file mode 100644
index 0000000000000..868a5b37717da
--- /dev/null
+++ b/tests/tests_app/runners/backends/test_mp_process.py
@@ -0,0 +1,28 @@
+from unittest import mock
+from unittest.mock import MagicMock, Mock
+
+from lightning_app import LightningApp, LightningWork
+from lightning_app.runners.backends import MultiProcessingBackend
+
+
+@mock.patch("lightning_app.core.app.AppStatus")
+@mock.patch("lightning_app.runners.backends.mp_process.multiprocessing")
+def test_backend_create_work_with_set_start_method(multiprocessing_mock, *_):
+ backend = MultiProcessingBackend(entrypoint_file="fake.py")
+ work = Mock(spec=LightningWork)
+ work._start_method = "test_start_method"
+
+ app = LightningApp(work)
+ app.caller_queues = MagicMock()
+ app.delta_queue = MagicMock()
+ app.readiness_queue = MagicMock()
+ app.error_queue = MagicMock()
+ app.request_queues = MagicMock()
+ app.response_queues = MagicMock()
+ app.copy_request_queues = MagicMock()
+ app.copy_response_queues = MagicMock()
+ app.flow_to_work_delta_queues = MagicMock()
+
+ backend.create_work(app=app, work=work)
+ multiprocessing_mock.get_context.assert_called_with("test_start_method")
+ multiprocessing_mock.get_context().Process().start.assert_called_once()
diff --git a/tests/tests_app/runners/test_cloud.py b/tests/tests_app/runners/test_cloud.py
index cb4bd5ddaa3c0..8ea068fc0992c 100644
--- a/tests/tests_app/runners/test_cloud.py
+++ b/tests/tests_app/runners/test_cloud.py
@@ -675,7 +675,7 @@ def test_call_with_queue_server_type_specified(self, lightningapps, monkeypatch,
)
# calling with env variable set to http
- monkeypatch.setattr(cloud, "CLOUD_QUEUE_TYPE", "http")
+ monkeypatch.setitem(os.environ, "LIGHTNING_CLOUD_QUEUE_TYPE", "http")
cloud_runtime.backend.client.reset_mock()
cloud_runtime.dispatch()
body = IdGetBody(
@@ -1383,9 +1383,8 @@ def test_get_project(monkeypatch):
V1Membership(name="test-project2", project_id="test-project-id2"),
]
)
- with pytest.warns(UserWarning, match="Defaulting to the project test-project1"):
- ret = _get_project(mock_client)
- assert ret.project_id == "test-project-id1"
+ ret = _get_project(mock_client)
+ assert ret.project_id == "test-project-id1"
def write_file_of_size(path, size):
diff --git a/tests/tests_app/runners/test_multiprocess.py b/tests/tests_app/runners/test_multiprocess.py
index 2e1a34ab38677..48bbedf555d63 100644
--- a/tests/tests_app/runners/test_multiprocess.py
+++ b/tests/tests_app/runners/test_multiprocess.py
@@ -68,7 +68,7 @@ def run(self):
assert _get_context().value == "work"
-class ContxtFlow(LightningFlow):
+class ContextFlow(LightningFlow):
def __init__(self):
super().__init__()
self.work = ContextWork()
@@ -83,7 +83,7 @@ def run(self):
def test_multiprocess_runtime_sets_context():
"""Test that the runtime sets the global variable COMPONENT_CONTEXT in Flow and Work."""
- MultiProcessRuntime(LightningApp(ContxtFlow())).dispatch()
+ MultiProcessRuntime(LightningApp(ContextFlow())).dispatch()
@pytest.mark.parametrize(
diff --git a/tests/tests_app/storage/test_orchestrator.py b/tests/tests_app/storage/test_orchestrator.py
index ca671e6b93704..4b391a890f1a9 100644
--- a/tests/tests_app/storage/test_orchestrator.py
+++ b/tests/tests_app/storage/test_orchestrator.py
@@ -39,7 +39,7 @@ def test_orchestrator():
# orchestrator is now waiting for a response for copier in Work A
assert "work_b" in orchestrator.waiting_for_response
- assert not request_queues["work_a"]
+ assert len(request_queues["work_a"]) == 0
assert request in copy_request_queues["work_a"]
assert request.destination == "work_b"
@@ -54,7 +54,7 @@ def test_orchestrator():
# orchestrator processes confirmation and confirms to the pending request from Work B
orchestrator.run_once("work_a")
- assert not copy_response_queues["work_a"]
+ assert len(copy_response_queues["work_a"]) == 0
assert response in response_queues["work_b"]
assert not orchestrator.waiting_for_response
orchestrator.run_once("work_b")
@@ -71,7 +71,7 @@ def test_orchestrator():
assert response.exception is None
# all queues should be empty
- assert all(not queue for queue in request_queues.values())
- assert all(not queue for queue in response_queues.values())
- assert all(not queue for queue in copy_request_queues.values())
- assert all(not queue for queue in copy_response_queues.values())
+ assert all(len(queue) == 0 for queue in request_queues.values())
+ assert all(len(queue) == 0 for queue in response_queues.values())
+ assert all(len(queue) == 0 for queue in copy_request_queues.values())
+ assert all(len(queue) == 0 for queue in copy_response_queues.values())
diff --git a/tests/tests_app/storage/test_path.py b/tests/tests_app/storage/test_path.py
index 3cd501f7344c8..2310b8034c303 100644
--- a/tests/tests_app/storage/test_path.py
+++ b/tests/tests_app/storage/test_path.py
@@ -606,7 +606,7 @@ def test_path_response_not_matching_reqeuest(tmpdir):
path.get()
# simulate a response that has a different hash than the request had
- assert not response_queue
+ assert len(response_queue) == 0
response.path = str(path)
response.hash = "other_hash"
response_queue.put(response)
diff --git a/tests/tests_app/structures/test_structures.py b/tests/tests_app/structures/test_structures.py
index 3346da5a858fc..852589a4443eb 100644
--- a/tests/tests_app/structures/test_structures.py
+++ b/tests/tests_app/structures/test_structures.py
@@ -44,6 +44,7 @@ def run(self):
"_host": "127.0.0.1",
"_paths": {},
"_restarting": False,
+ "_display_name": "",
"_internal_ip": "",
"_cloud_compute": {
"type": "__cloud_compute__",
@@ -76,6 +77,7 @@ def run(self):
"_host": "127.0.0.1",
"_paths": {},
"_restarting": False,
+ "_display_name": "",
"_internal_ip": "",
"_cloud_compute": {
"type": "__cloud_compute__",
@@ -108,6 +110,7 @@ def run(self):
"_host": "127.0.0.1",
"_paths": {},
"_restarting": False,
+ "_display_name": "",
"_internal_ip": "",
"_cloud_compute": {
"type": "__cloud_compute__",
@@ -193,6 +196,7 @@ def run(self):
"_paths": {},
"_restarting": False,
"_internal_ip": "",
+ "_display_name": "",
"_cloud_compute": {
"type": "__cloud_compute__",
"name": "default",
@@ -225,6 +229,7 @@ def run(self):
"_paths": {},
"_restarting": False,
"_internal_ip": "",
+ "_display_name": "",
"_cloud_compute": {
"type": "__cloud_compute__",
"name": "default",
@@ -252,6 +257,7 @@ def run(self):
"_paths": {},
"_restarting": False,
"_internal_ip": "",
+ "_display_name": "",
"_cloud_compute": {
"type": "__cloud_compute__",
"name": "default",
diff --git a/tests/tests_app/utilities/packaging/test_cloud_compute.py b/tests/tests_app/utilities/packaging/test_cloud_compute.py
index f2670723f132a..67b5a25ab8c46 100644
--- a/tests/tests_app/utilities/packaging/test_cloud_compute.py
+++ b/tests/tests_app/utilities/packaging/test_cloud_compute.py
@@ -14,6 +14,12 @@ def test_cloud_compute_shared_memory():
cloud_compute = CloudCompute("gpu", shm_size=1100)
assert cloud_compute.shm_size == 1100
+ cloud_compute = CloudCompute("gpu")
+ assert cloud_compute.shm_size == 1024
+
+ cloud_compute = CloudCompute("cpu")
+ assert cloud_compute.shm_size == 0
+
def test_cloud_compute_with_mounts():
mount_1 = Mount(source="s3://foo/", mount_path="/foo")
diff --git a/tests/tests_app/utilities/test_cloud.py b/tests/tests_app/utilities/test_cloud.py
index 573ec46106b84..6e93ad1e68d57 100644
--- a/tests/tests_app/utilities/test_cloud.py
+++ b/tests/tests_app/utilities/test_cloud.py
@@ -4,13 +4,18 @@
from lightning_app.utilities.cloud import is_running_in_cloud
-@mock.patch.dict(os.environ, clear=True)
-def test_is_running_locally():
- """We can determine if Lightning is running locally."""
- assert not is_running_in_cloud()
-
-
-@mock.patch.dict(os.environ, {"LIGHTNING_APP_STATE_URL": "127.0.0.1"})
def test_is_running_cloud():
"""We can determine if Lightning is running in the cloud."""
- assert is_running_in_cloud()
+ with mock.patch.dict(os.environ, {}, clear=True):
+ assert not is_running_in_cloud()
+
+ with mock.patch.dict(os.environ, {"LAI_RUNNING_IN_CLOUD": "0"}, clear=True):
+ assert not is_running_in_cloud()
+
+ # in the cloud, LIGHTNING_APP_STATE_URL is defined
+ with mock.patch.dict(os.environ, {"LIGHTNING_APP_STATE_URL": "defined"}, clear=True):
+ assert is_running_in_cloud()
+
+ # LAI_RUNNING_IN_CLOUD is used to fake the value of `is_running_in_cloud` when loading the app for --cloud
+ with mock.patch.dict(os.environ, {"LAI_RUNNING_IN_CLOUD": "1"}):
+ assert is_running_in_cloud()
diff --git a/tests/tests_app/utilities/test_commands.py b/tests/tests_app/utilities/test_commands.py
index 81415cee7b7d8..87623bb4547ed 100644
--- a/tests/tests_app/utilities/test_commands.py
+++ b/tests/tests_app/utilities/test_commands.py
@@ -160,3 +160,4 @@ def test_configure_commands(monkeypatch):
time_left -= 0.1
assert process.exitcode == 0
disconnect()
+ process.kill()
diff --git a/tests/tests_app/utilities/test_git.py b/tests/tests_app/utilities/test_git.py
index cb2db0a2bfe33..554c32d6fd82d 100644
--- a/tests/tests_app/utilities/test_git.py
+++ b/tests/tests_app/utilities/test_git.py
@@ -1,5 +1,7 @@
import sys
+import pytest
+
from lightning_app.utilities.git import (
check_github_repository,
check_if_remote_head_is_different,
@@ -10,6 +12,7 @@
)
+@pytest.mark.skipif(sys.platform == "win32", reason="Don't run on Windows")
def test_execute_git_command():
res = execute_git_command(["pull"])
diff --git a/tests/tests_app/utilities/test_load_app.py b/tests/tests_app/utilities/test_load_app.py
index 573f73a670aad..c92c4261daab6 100644
--- a/tests/tests_app/utilities/test_load_app.py
+++ b/tests/tests_app/utilities/test_load_app.py
@@ -85,7 +85,7 @@ def test_extract_metadata_from_component():
"name": "gpu",
"disk_size": 0,
"idle_timeout": None,
- "shm_size": 0,
+ "shm_size": 1024,
"mounts": None,
"_internal_id": ANY,
},
diff --git a/tests/tests_app/utilities/test_proxies.py b/tests/tests_app/utilities/test_proxies.py
index 4b8a5f25f71e3..42d1fb8f82ba6 100644
--- a/tests/tests_app/utilities/test_proxies.py
+++ b/tests/tests_app/utilities/test_proxies.py
@@ -67,8 +67,9 @@ def proxy_setattr():
@pytest.mark.parametrize("parallel", [True, False])
@pytest.mark.parametrize("cache_calls", [False, True])
+@mock.patch("lightning_app.utilities.proxies._Copier", MagicMock())
@pytest.mark.skipif(sys.platform == "win32", reason="TODO (@ethanwharris): Fix this on Windows")
-def test_work_runner(parallel, cache_calls):
+def test_work_runner(parallel, cache_calls, *_):
"""This test validates the `WorkRunner` runs the work.run method and properly populates the `delta_queue`,
`error_queue` and `readiness_queue`."""
@@ -149,13 +150,14 @@ def get(self, timeout: int = 0):
assert isinstance(error_queue._queue[0], Exception)
else:
assert isinstance(error_queue._queue[0], Empty)
- assert len(delta_queue._queue) == 3
+ assert len(delta_queue._queue) in [3, 4]
res = delta_queue._queue[0].delta.to_dict()["iterable_item_added"]
assert res[f"root['calls']['{call_hash}']['statuses'][0]"]["stage"] == "running"
assert delta_queue._queue[1].delta.to_dict() == {
"values_changed": {"root['vars']['counter']": {"new_value": 1}}
}
- res = delta_queue._queue[2].delta.to_dict()["dictionary_item_added"]
+ index = 3 if len(delta_queue._queue) == 4 else 2
+ res = delta_queue._queue[index].delta.to_dict()["dictionary_item_added"]
assert res[f"root['calls']['{call_hash}']['ret']"] is None
# Stop blocking and let the thread join
@@ -250,6 +252,7 @@ def __call__(self):
state = deepcopy(self.work.state)
self.work._calls[call_hash]["statuses"].append(
{
+ "name": self.work.name,
"stage": WorkStageStatus.FAILED,
"reason": WorkFailureReasons.TIMEOUT,
"timestamp": time.time(),
@@ -547,7 +550,7 @@ def run(self, use_setattr=False, use_containers=False):
# 1. Simulate no state changes
##############################
work.run(use_setattr=False, use_containers=False)
- assert not delta_queue
+ assert len(delta_queue) == 0
############################
# 2. Simulate a setattr call
@@ -563,16 +566,16 @@ def run(self, use_setattr=False, use_containers=False):
assert len(observer._delta_memory) == 1
# The observer should not trigger any deltas being sent and only consume the delta memory
- assert not delta_queue
+ assert len(delta_queue) == 0
observer.run_once()
- assert not delta_queue
+ assert len(delta_queue) == 0
assert not observer._delta_memory
################################
# 3. Simulate a container update
################################
work.run(use_setattr=False, use_containers=True)
- assert not delta_queue
+ assert len(delta_queue) == 0
assert not observer._delta_memory
observer.run_once()
observer.run_once() # multiple runs should not affect how many deltas are sent unless there are changes
@@ -591,7 +594,7 @@ def run(self, use_setattr=False, use_containers=False):
delta = delta_queue.get().delta.to_dict()
assert delta == {"values_changed": {"root['vars']['var']": {"new_value": 3}}}
- assert not delta_queue
+ assert len(delta_queue) == 0
assert len(observer._delta_memory) == 1
observer.run_once()
@@ -599,7 +602,7 @@ def run(self, use_setattr=False, use_containers=False):
assert delta["values_changed"] == {"root['vars']['dict']['counter']": {"new_value": 2}}
assert delta["iterable_item_added"] == {"root['vars']['list'][1]": 1}
- assert not delta_queue
+ assert len(delta_queue) == 0
assert not observer._delta_memory
diff --git a/tests/tests_examples_app/public/test_boring_app.py b/tests/tests_examples_app/public/test_boring_app.py
index a5177d0a18062..80be34745a1de 100644
--- a/tests/tests_examples_app/public/test_boring_app.py
+++ b/tests/tests_examples_app/public/test_boring_app.py
@@ -10,7 +10,7 @@
@pytest.mark.cloud
def test_boring_app_example_cloud() -> None:
- with run_app_in_cloud(os.path.join(_PATH_EXAMPLES, "app_boring"), app_name="app_dynamic.py", debug=True,) as (
+ with run_app_in_cloud(os.path.join(_PATH_EXAMPLES, "app_boring"), app_name="app_dynamic.py", debug=True) as (
_,
view_page,
_,
diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py
index 5e864cea3568d..8694a87489403 100644
--- a/tests/tests_pytorch/test_cli.py
+++ b/tests/tests_pytorch/test_cli.py
@@ -1233,7 +1233,7 @@ def configure_optimizers(self, optimizer, lr_scheduler=None):
[optimizer], [scheduler] = cli.model.configure_optimizers()
assert isinstance(optimizer, SGD)
assert isinstance(scheduler, StepLR)
- with mock.patch("sys.argv", ["any.py", "--lr_scheduler=StepLR"]):
+ with mock.patch("sys.argv", ["any.py", "--lr_scheduler=StepLR", "--lr_scheduler.step_size=50"]):
cli = MyCLI()
[optimizer], [scheduler] = cli.model.configure_optimizers()
assert isinstance(optimizer, SGD)