diff --git a/CHANGELOG.md b/CHANGELOG.md index 46bb8106bf80b..0126950f2697b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Allow logging to an existing run ID in MLflow with `MLFlowLogger` ([#12290](https://github.com/PyTorchLightning/pytorch-lightning/pull/12290)) + + - Enable gradient accumulation using Horovod's `backward_passes_per_step` ([#11911](https://github.com/PyTorchLightning/pytorch-lightning/pull/11911)) diff --git a/pytorch_lightning/loggers/mlflow.py b/pytorch_lightning/loggers/mlflow.py index f43d8358e1d43..a74dda324e15f 100644 --- a/pytorch_lightning/loggers/mlflow.py +++ b/pytorch_lightning/loggers/mlflow.py @@ -87,7 +87,7 @@ def any_lightning_module_function_or_hook(self): self.logger.experiment.whatever_ml_flow_supports(...) Args: - experiment_name: The name of the experiment + experiment_name: The name of the experiment. run_name: Name of the new run. The `run_name` is internally stored as a ``mlflow.runName`` tag. If the ``mlflow.runName`` tag has already been set in `tags`, the value is overridden by the `run_name`. tracking_uri: Address of local or remote tracking server. @@ -100,6 +100,7 @@ def any_lightning_module_function_or_hook(self): prefix: A string to put at the beginning of metric keys. artifact_location: The location to store run artifacts. If not provided, the server picks an appropriate default. + run_id: The run identifier of the experiment. If not provided, a new run is started. 
Raises: ModuleNotFoundError: @@ -117,6 +118,7 @@ def __init__( save_dir: Optional[str] = "./mlruns", prefix: str = "", artifact_location: Optional[str] = None, + run_id: Optional[str] = None, ): if mlflow is None: raise ModuleNotFoundError( @@ -130,11 +132,13 @@ def __init__( self._experiment_id = None self._tracking_uri = tracking_uri self._run_name = run_name - self._run_id = None + self._run_id = run_id self.tags = tags self._prefix = prefix self._artifact_location = artifact_location + self._initialized = False + self._mlflow_client = MlflowClient(tracking_uri) @property @@ -149,6 +153,16 @@ def experiment(self) -> MlflowClient: self.logger.experiment.some_mlflow_function() """ + + if self._initialized: + return self._mlflow_client + + if self._run_id is not None: + run = self._mlflow_client.get_run(self._run_id) + self._experiment_id = run.info.experiment_id + self._initialized = True + return self._mlflow_client + if self._experiment_id is None: expt = self._mlflow_client.get_experiment_by_name(self._experiment_name) if expt is not None: @@ -169,6 +183,7 @@ def experiment(self) -> MlflowClient: self.tags[MLFLOW_RUN_NAME] = self._run_name run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=resolve_tags(self.tags)) self._run_id = run.info.run_id + self._initialized = True return self._mlflow_client @property diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py index 77afe361b035f..1e9e8ec271bd5 100644 --- a/tests/loggers/test_mlflow.py +++ b/tests/loggers/test_mlflow.py @@ -40,6 +40,7 @@ def test_mlflow_logger_exists(client, mlflow, tmpdir): run1 = MagicMock() run1.info.run_id = "run-id-1" + run1.info.experiment_id = "exp-id-1" run2 = MagicMock() run2.info.run_id = "run-id-2" @@ -113,6 +114,27 @@ def test_mlflow_run_name_setting(client, mlflow, tmpdir): client.return_value.create_run.assert_called_with(experiment_id="exp-id", tags=default_tags) +@mock.patch("pytorch_lightning.loggers.mlflow.mlflow") 
+@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient") +def test_mlflow_run_id_setting(client, mlflow, tmpdir): + """Test that the run_id argument uses the provided run_id.""" + + run = MagicMock() + run.info.run_id = "run-id" + run.info.experiment_id = "experiment-id" + + # simulate an existing run + client.return_value.get_run = MagicMock(return_value=run) + + # when run_id is given, the logger looks up the existing run and takes its experiment ID + logger = MLFlowLogger("test", run_id=run.info.run_id, save_dir=tmpdir) + _ = logger.experiment + client.return_value.get_run.assert_called_with(run.info.run_id) + assert logger.experiment_id == run.info.experiment_id + assert logger.run_id == run.info.run_id + client.reset_mock(return_value=True) + + @mock.patch("pytorch_lightning.loggers.mlflow.mlflow") @mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient") def test_mlflow_log_dir(client, mlflow, tmpdir):