Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow log to an existing run ID in MLflow with MLFlowLogger #12290

Merged
merged 22 commits into from Mar 27, 2022
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
6719acb
specify mlflow exception type for mlflow get_run call
Kr4is Mar 9, 2022
7a6cf0f
Merge branch 'PyTorchLightning:master' into master
Kr4is Mar 10, 2022
0a52a8a
update changelog
Kr4is Mar 10, 2022
9d75176
place run_id parameter at constructor end
Kr4is Mar 11, 2022
e98c390
move changelog description from changed to added
Kr4is Mar 11, 2022
c050e86
Update CHANGELOG.md
Kr4is Mar 12, 2022
28d606d
set existing run experiment_id if run_if already exists
Kr4is Mar 12, 2022
bd11951
add mlflow logger tests
Kr4is Mar 12, 2022
c6aa700
Merge branch 'PyTorchLightning:master' into master
Kr4is Mar 12, 2022
e43e6c3
Merge branch 'PyTorchLightning:master' into master
Kr4is Mar 18, 2022
df7e23e
solve logger initialization and tests
Kr4is Mar 18, 2022
e73d60e
Merge branch 'PyTorchLightning:master' into master
Kr4is Mar 18, 2022
c22e564
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 18, 2022
2500173
remove comment
Kr4is Mar 18, 2022
075099d
Merge branch 'PyTorchLightning:master' into master
Kr4is Mar 19, 2022
319db6f
Merge branch 'PyTorchLightning:master' into master
Kr4is Mar 21, 2022
83c5304
move run_id to the end documentation args
Kr4is Mar 23, 2022
606f8a9
Merge branch 'PyTorchLightning:master' into master
Kr4is Mar 23, 2022
2b517a3
Merge branch 'PyTorchLightning:master' into master
Kr4is Mar 24, 2022
be67516
Merge branch 'PyTorchLightning:master' into master
Kr4is Mar 25, 2022
9c76ad3
Merge branch 'PyTorchLightning:master' into master
Kr4is Mar 26, 2022
260d735
Merge branch 'PyTorchLightning:master' into master
Kr4is Mar 27, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Expand Up @@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Allow logging to an existing run ID in MLflow with `MLFlowLogger` ([#12290](https://github.com/PyTorchLightning/pytorch-lightning/pull/12290))


- Enable gradient accumulation using Horovod's `backward_passes_per_step` ([#11911](https://github.com/PyTorchLightning/pytorch-lightning/pull/11911))


Expand Down
19 changes: 17 additions & 2 deletions pytorch_lightning/loggers/mlflow.py
Expand Up @@ -87,7 +87,8 @@ def any_lightning_module_function_or_hook(self):
self.logger.experiment.whatever_ml_flow_supports(...)

Args:
experiment_name: The name of the experiment
run_id: The run identifier of the experiment. If not provided, a new run is started.
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
experiment_name: The name of the experiment.
run_name: Name of the new run. The `run_name` is internally stored as a ``mlflow.runName`` tag.
If the ``mlflow.runName`` tag has already been set in `tags`, the value is overridden by the `run_name`.
tracking_uri: Address of local or remote tracking server.
Expand Down Expand Up @@ -117,6 +118,7 @@ def __init__(
save_dir: Optional[str] = "./mlruns",
prefix: str = "",
artifact_location: Optional[str] = None,
run_id: Optional[str] = None,
):
if mlflow is None:
raise ModuleNotFoundError(
Expand All @@ -130,11 +132,13 @@ def __init__(
self._experiment_id = None
self._tracking_uri = tracking_uri
self._run_name = run_name
self._run_id = None
self._run_id = run_id
self.tags = tags
self._prefix = prefix
self._artifact_location = artifact_location

self._initialized = False

self._mlflow_client = MlflowClient(tracking_uri)

@property
Expand All @@ -149,6 +153,16 @@ def experiment(self) -> MlflowClient:
self.logger.experiment.some_mlflow_function()

"""

if self._initialized:
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
return self._mlflow_client

if self._run_id is not None:
tchaton marked this conversation as resolved.
Show resolved Hide resolved
run = self._mlflow_client.get_run(self._run_id)
self._experiment_id = run.info.experiment_id
self._initialized = True
return self._mlflow_client

if self._experiment_id is None:
expt = self._mlflow_client.get_experiment_by_name(self._experiment_name)
if expt is not None:
Expand All @@ -169,6 +183,7 @@ def experiment(self) -> MlflowClient:
self.tags[MLFLOW_RUN_NAME] = self._run_name
run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=resolve_tags(self.tags))
self._run_id = run.info.run_id
self._initialized = True
return self._mlflow_client

@property
Expand Down
22 changes: 22 additions & 0 deletions tests/loggers/test_mlflow.py
Expand Up @@ -40,6 +40,7 @@ def test_mlflow_logger_exists(client, mlflow, tmpdir):

run1 = MagicMock()
run1.info.run_id = "run-id-1"
run1.info.experiment_id = "exp-id-1"

run2 = MagicMock()
run2.info.run_id = "run-id-2"
Expand Down Expand Up @@ -113,6 +114,27 @@ def test_mlflow_run_name_setting(client, mlflow, tmpdir):
client.return_value.create_run.assert_called_with(experiment_id="exp-id", tags=default_tags)


@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
def test_mlflow_run_id_setting(client, mlflow, tmpdir):
"""Test that the run_id argument uses the provided run_id."""

run = MagicMock()
run.info.run_id = "run-id"
run.info.experiment_id = "experiment-id"

# simulate existing run
client.return_value.get_run = MagicMock(return_value=run)

# run_id exists uses the existing run
logger = MLFlowLogger("test", run_id=run.info.run_id, save_dir=tmpdir)
_ = logger.experiment
client.return_value.get_run.assert_called_with(run.info.run_id)
assert logger.experiment_id == run.info.experiment_id
assert logger.run_id == run.info.run_id
client.reset_mock(return_value=True)


@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
def test_mlflow_log_dir(client, mlflow, tmpdir):
Expand Down