diff --git a/mlflow/R/mlflow/R/tracking-runs.R b/mlflow/R/mlflow/R/tracking-runs.R index 509bd23c10087..4336b6bc5e95d 100644 --- a/mlflow/R/mlflow/R/tracking-runs.R +++ b/mlflow/R/mlflow/R/tracking-runs.R @@ -599,6 +599,12 @@ mlflow_get_run_context.default <- function(client, experiment_id, ...) { tags[[MLFLOW_TAGS$MLFLOW_SOURCE_NAME]] <- get_source_name() tags[[MLFLOW_TAGS$MLFLOW_SOURCE_VERSION]] <- get_source_version() tags[[MLFLOW_TAGS$MLFLOW_SOURCE_TYPE]] <- MLFLOW_SOURCE_TYPE$LOCAL + parent_run_id <- mlflow_get_active_run_id() + if (!is.null(parent_run_id)) { + # create a tag containing the parent run ID so that MLflow UI can display + # nested runs properly + tags[[MLFLOW_TAGS$MLFLOW_PARENT_RUN_ID]] <- parent_run_id + } list( client = client, tags = tags, @@ -649,5 +655,6 @@ MLFLOW_TAGS <- list( MLFLOW_USER = "mlflow.user", MLFLOW_SOURCE_NAME = "mlflow.source.name", MLFLOW_SOURCE_VERSION = "mlflow.source.version", - MLFLOW_SOURCE_TYPE = "mlflow.source.type" + MLFLOW_SOURCE_TYPE = "mlflow.source.type", + MLFLOW_PARENT_RUN_ID = "mlflow.parentRunId" ) diff --git a/mlflow/R/mlflow/tests/testthat/test-tracking-runs.R b/mlflow/R/mlflow/tests/testthat/test-tracking-runs.R index bd744bb90b113..920f149098ae5 100644 --- a/mlflow/R/mlflow/tests/testthat/test-tracking-runs.R +++ b/mlflow/R/mlflow/tests/testthat/test-tracking-runs.R @@ -77,6 +77,13 @@ test_that("mlflow_start_run()/mlflow_end_run() works properly with nested runs", expect_equal(mlflow:::mlflow_get_active_run_id(), runs[[i]]$run_uuid) run <- mlflow_end_run(client = client, run_id = runs[[i]]$run_uuid) expect_identical(run$run_uuid, runs[[i]]$run_uuid) + if (i > 1) { + tags <- run$tags[[1]] + expect_equal( + tags[tags$key == "mlflow.parentRunId",]$value, + runs[[i - 1]]$run_uuid + ) + } } expect_null(mlflow:::mlflow_get_active_run_id()) })