From 1dfccca85d4ab87f1238b8554e0a49f8f6a20e8e Mon Sep 17 00:00:00 2001 From: Michael Sharp <51342856+michaelgsharp@users.noreply.github.com> Date: Mon, 11 Oct 2021 18:19:36 -0700 Subject: [PATCH] Prediction engine options (#5964) Co-authored-by: Eric Erhardt --- .../Model/ModelOperationsCatalog.cs | 17 ++++++++ .../Model/PredictionEngineExtensions.cs | 5 ++- .../Prediction/PredictionEngine.cs | 39 ++++++++++++++++-- .../PredictionEngine.cs | 37 +++++++++++++++++ .../Prediction.cs | 41 +++++++++++++++++++ 5 files changed, 133 insertions(+), 6 deletions(-) diff --git a/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs b/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs index b449bdfa8e..b9f9f4bb80 100644 --- a/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs +++ b/src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs @@ -333,5 +333,22 @@ public ITransformer LoadWithDataLoader(string filePath, out IDataLoader(_env, false, DataViewConstructionUtils.GetSchemaDefinition(_env, inputSchema)); } + + /// + /// Create a prediction engine for one-time prediction. + /// It's mainly used in conjunction with , + /// where input schema is extracted during loading the model. + /// + /// The class that defines the input data. + /// The class that defines the output data. + /// The transformer to use for prediction. + /// Advanced configuration options. 
+ public PredictionEngine CreatePredictionEngine(ITransformer transformer, PredictionEngineOptions options) + where TSrc : class + where TDst : class, new() + { + return transformer.CreatePredictionEngine(_env, options.IgnoreMissingColumns, + options.InputSchemaDefinition, options.OutputSchemaDefinition, options.OwnsTransformer); + } } } diff --git a/src/Microsoft.ML.Data/Model/PredictionEngineExtensions.cs b/src/Microsoft.ML.Data/Model/PredictionEngineExtensions.cs index 35d7fbf148..2e162b764a 100644 --- a/src/Microsoft.ML.Data/Model/PredictionEngineExtensions.cs +++ b/src/Microsoft.ML.Data/Model/PredictionEngineExtensions.cs @@ -24,10 +24,11 @@ internal static class PredictionEngineExtensions /// . /// Additional settings of the input schema. /// Additional settings of the output schema. + /// Whether the prediction engine owns the transformer and should dispose of it. public static PredictionEngine CreatePredictionEngine(this ITransformer transformer, - IHostEnvironment env, bool ignoreMissingColumns = true, SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null) + IHostEnvironment env, bool ignoreMissingColumns = true, SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null, bool ownsTransformer = true) where TSrc : class where TDst : class, new() - => new PredictionEngine(env, transformer, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition); + => new PredictionEngine(env, transformer, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition, ownsTransformer); } } diff --git a/src/Microsoft.ML.Data/Prediction/PredictionEngine.cs b/src/Microsoft.ML.Data/Prediction/PredictionEngine.cs index b932369570..8fdac4874b 100644 --- a/src/Microsoft.ML.Data/Prediction/PredictionEngine.cs +++ b/src/Microsoft.ML.Data/Prediction/PredictionEngine.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; +using Microsoft.ML.CommandLine; using Microsoft.ML.Data; 
using Microsoft.ML.Runtime; @@ -58,8 +59,8 @@ public sealed class PredictionEngine : PredictionEngineBase : IDisposable private readonly DataViewConstructionUtils.InputRow _inputRow; private readonly IRowReadableAs _outputRow; private readonly Action _disposer; + private readonly bool _ownsTransformer; private bool _disposed; /// @@ -104,7 +106,7 @@ public abstract class PredictionEngineBase : IDisposable [BestFriend] private protected PredictionEngineBase(IHostEnvironment env, ITransformer transformer, bool ignoreMissingColumns, - SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null) + SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null, bool ownsTransformer = true) { Contracts.CheckValue(env, nameof(env)); env.AssertValue(transformer); @@ -112,6 +114,7 @@ public abstract class PredictionEngineBase : IDisposable var makeMapper = TransformerChecker(env, transformer); env.AssertValue(makeMapper); _inputRow = DataViewConstructionUtils.CreateInputRow(env, inputSchemaDefinition); + _ownsTransformer = ownsTransformer; PredictionEngineCore(env, _inputRow, makeMapper(_inputRow.Schema), ignoreMissingColumns, outputSchemaDefinition, out _disposer, out _outputRow); OutputSchema = Transformer.GetOutputSchema(_inputRow.Schema); } @@ -139,7 +142,9 @@ public void Dispose() return; _disposer?.Invoke(); - (Transformer as IDisposable)?.Dispose(); + + if (_ownsTransformer) + (Transformer as IDisposable)?.Dispose(); _disposed = true; } @@ -170,4 +175,30 @@ public TDst Predict(TSrc example) /// is reused. 
public abstract void Predict(TSrc example, ref TDst prediction); } + + /// + /// Options for the + /// + public sealed class PredictionEngineOptions + { + [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to throw an error if a column exists in the output schema but not the output object.", ShortName = "ignore", SortOrder = 50)] + public bool IgnoreMissingColumns = Defaults.IgnoreMissingColumns; + + [Argument(ArgumentType.AtMostOnce, HelpText = "Additional settings of the input schema.", ShortName = "input", SortOrder = 50)] + public SchemaDefinition InputSchemaDefinition = Defaults.InputSchemaDefinition; + + [Argument(ArgumentType.AtMostOnce, HelpText = "Additional settings of the output schema.", ShortName = "output")] + public SchemaDefinition OutputSchemaDefinition = Defaults.OutputSchemaDefinition; + + [Argument(ArgumentType.AtMostOnce, HelpText = "Whether the prediction engine owns the transformer and should dispose of it.", ShortName = "own")] + public bool OwnsTransformer = Defaults.OwnsTransformer; + + internal static class Defaults + { + public const bool IgnoreMissingColumns = true; + public const SchemaDefinition InputSchemaDefinition = null; + public const SchemaDefinition OutputSchemaDefinition = null; + public const bool OwnsTransformer = true; + } + } } diff --git a/src/Microsoft.ML.TimeSeries/PredictionEngine.cs b/src/Microsoft.ML.TimeSeries/PredictionEngine.cs index 7f75ad1d63..d1742dc941 100644 --- a/src/Microsoft.ML.TimeSeries/PredictionEngine.cs +++ b/src/Microsoft.ML.TimeSeries/PredictionEngine.cs @@ -148,6 +148,15 @@ private static ITransformer CloneTransformers(ITransformer transformer) { } + /// + /// Constructor for creating time series specific prediction engine. 
It allows the time series model to be updated with the observations + /// seen at prediction time via + /// + internal TimeSeriesPredictionEngine(IHostEnvironment env, ITransformer transformer, PredictionEngineOptions options) : + base(env, CloneTransformers(transformer), options.IgnoreMissingColumns, options.InputSchemaDefinition, options.OutputSchemaDefinition, options.OwnsTransformer) + { + } + internal DataViewRow GetStatefulRows(DataViewRow input, IRowToRowMapper mapper, IEnumerable activeColumns, List rows) { Contracts.CheckValue(input, nameof(input)); @@ -398,5 +407,33 @@ public static class PredictionFunctionExtensions env.CheckValueOrNull(outputSchemaDefinition); return new TimeSeriesPredictionEngine(env, transformer, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition); } + + /// + /// creates a prediction engine for a time series pipeline. + /// It updates the state of time series model with observations seen at prediction phase and allows checkpointing the model. + /// + /// Class describing input schema to the model. + /// Class describing the output schema of the prediction. + /// The time series pipeline in the form of a . + /// Usually + /// Advanced configuration options. + ///

Example code can be found by searching for TimeSeriesPredictionEngine in ML.NET.

+ /// + /// + /// + /// + /// + public static TimeSeriesPredictionEngine CreateTimeSeriesEngine(this ITransformer transformer, IHostEnvironment env, + PredictionEngineOptions options) + where TSrc : class + where TDst : class, new() + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(options, nameof(options)); + return new TimeSeriesPredictionEngine(env, transformer, options); + } } } diff --git a/test/Microsoft.ML.IntegrationTests/Prediction.cs b/test/Microsoft.ML.IntegrationTests/Prediction.cs index 40df4c104b..ec5c333561 100644 --- a/test/Microsoft.ML.IntegrationTests/Prediction.cs +++ b/test/Microsoft.ML.IntegrationTests/Prediction.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; +using System.Reflection; using Microsoft.ML.Calibrators; using Microsoft.ML.Data; using Microsoft.ML.IntegrationTests.Datasets; @@ -97,5 +98,45 @@ public void ReconfigurablePredictionNoPipeline() Assert.True(pr.Score <= 0); } + [Fact] + public void PredictionEngineModelDisposal() + { + var mlContext = new MLContext(seed: 1); + var data = mlContext.Data.LoadFromEnumerable(TypeTestData.GenerateDataset()); + var pipeline = mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression( + new Trainers.LbfgsLogisticRegressionBinaryTrainer.Options { NumberOfThreads = 1 }); + var model = pipeline.Fit(data); + + var engine = mlContext.Model.CreatePredictionEngine(model, new PredictionEngineOptions()); + + // Dispose of prediction engine, should dispose of model + engine.Dispose(); + + // Get disposed flag using reflection + var bfIsDisposed = BindingFlags.Instance | BindingFlags.NonPublic; + var field = model.GetType().BaseType.BaseType.GetField("_disposed", bfIsDisposed); + + // Make sure the model is actually disposed + Assert.True((bool)field.GetValue(model)); + + // Make a new model/prediction engine. 
Set the options so prediction engine doesn't dispose + model = pipeline.Fit(data); + + var options = new PredictionEngineOptions() + { + OwnsTransformer = false + }; + + engine = mlContext.Model.CreatePredictionEngine(model, options); + + // Dispose of prediction engine, shouldn't dispose of model + engine.Dispose(); + + // Make sure model is not disposed of. + Assert.False((bool)field.GetValue(model)); + + // Dispose of the model for test cleanliness + model.Dispose(); + } } }