Skip to content

Commit

Permalink
Prediction engine options (#5964)
Browse files Browse the repository at this point in the history
Co-authored-by: Eric Erhardt <eric.erhardt@microsoft.com>
  • Loading branch information
michaelgsharp and eerhardt committed Oct 12, 2021
1 parent f696661 commit 1dfccca
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 6 deletions.
17 changes: 17 additions & 0 deletions src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs
Expand Up @@ -333,5 +333,22 @@ public ITransformer LoadWithDataLoader(string filePath, out IDataLoader<IMultiSt
return transformer.CreatePredictionEngine<TSrc, TDst>(_env, false,
DataViewConstructionUtils.GetSchemaDefinition<TSrc>(_env, inputSchema));
}

/// <summary>
/// Create a prediction engine for one-time prediction.
/// It's mainly used in conjunction with <see cref="Load(Stream, out DataViewSchema)"/>,
/// where input schema is extracted during loading the model.
/// </summary>
/// <typeparam name="TSrc">The class that defines the input data.</typeparam>
/// <typeparam name="TDst">The class that defines the output data.</typeparam>
/// <param name="transformer">The transformer to use for prediction.</param>
/// <param name="options">Advanced configuration options.</param>
public PredictionEngine<TSrc, TDst> CreatePredictionEngine<TSrc, TDst>(ITransformer transformer, PredictionEngineOptions options)
where TSrc : class
where TDst : class, new()
{
return transformer.CreatePredictionEngine<TSrc, TDst>(_env, options.IgnoreMissingColumns,
options.InputSchemaDefinition, options.OutputSchemaDefinition, options.OwnsTransformer);
}
}
}
5 changes: 3 additions & 2 deletions src/Microsoft.ML.Data/Model/PredictionEngineExtensions.cs
Expand Up @@ -24,10 +24,11 @@ internal static class PredictionEngineExtensions
/// <typeparamref name="TDst"/>.</param>
/// <param name="inputSchemaDefinition">Additional settings of the input schema.</param>
/// <param name="outputSchemaDefinition">Additional settings of the output schema.</param>
/// <param name="ownsTransformer">Whether the prediction engine owns the transformer and should dispose of it.</param>
public static PredictionEngine<TSrc, TDst> CreatePredictionEngine<TSrc, TDst>(this ITransformer transformer,
IHostEnvironment env, bool ignoreMissingColumns = true, SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
IHostEnvironment env, bool ignoreMissingColumns = true, SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null, bool ownsTransformer = true)
where TSrc : class
where TDst : class, new()
=> new PredictionEngine<TSrc, TDst>(env, transformer, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition);
=> new PredictionEngine<TSrc, TDst>(env, transformer, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition, ownsTransformer);
}
}
39 changes: 35 additions & 4 deletions src/Microsoft.ML.Data/Prediction/PredictionEngine.cs
Expand Up @@ -4,6 +4,7 @@

using System;
using System.Collections.Generic;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;

Expand Down Expand Up @@ -58,8 +59,8 @@ public sealed class PredictionEngine<TSrc, TDst> : PredictionEngineBase<TSrc, TD
where TDst : class, new()
{
internal PredictionEngine(IHostEnvironment env, ITransformer transformer, bool ignoreMissingColumns,
SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
: base(env, transformer, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition)
SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null, bool ownsTransformer = true)
: base(env, transformer, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition, ownsTransformer)
{
}

Expand Down Expand Up @@ -92,6 +93,7 @@ public abstract class PredictionEngineBase<TSrc, TDst> : IDisposable
private readonly DataViewConstructionUtils.InputRow<TSrc> _inputRow;
private readonly IRowReadableAs<TDst> _outputRow;
private readonly Action _disposer;
private readonly bool _ownsTransformer;
private bool _disposed;

/// <summary>
Expand All @@ -104,14 +106,15 @@ public abstract class PredictionEngineBase<TSrc, TDst> : IDisposable

[BestFriend]
private protected PredictionEngineBase(IHostEnvironment env, ITransformer transformer, bool ignoreMissingColumns,
SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null, bool ownsTransformer = true)
{
Contracts.CheckValue(env, nameof(env));
env.AssertValue(transformer);
Transformer = transformer;
var makeMapper = TransformerChecker(env, transformer);
env.AssertValue(makeMapper);
_inputRow = DataViewConstructionUtils.CreateInputRow<TSrc>(env, inputSchemaDefinition);
_ownsTransformer = ownsTransformer;
PredictionEngineCore(env, _inputRow, makeMapper(_inputRow.Schema), ignoreMissingColumns, outputSchemaDefinition, out _disposer, out _outputRow);
OutputSchema = Transformer.GetOutputSchema(_inputRow.Schema);
}
Expand Down Expand Up @@ -139,7 +142,9 @@ public void Dispose()
return;

_disposer?.Invoke();
(Transformer as IDisposable)?.Dispose();

if (_ownsTransformer)
(Transformer as IDisposable)?.Dispose();

_disposed = true;
}
Expand Down Expand Up @@ -170,4 +175,30 @@ public TDst Predict(TSrc example)
/// is reused.</param>
public abstract void Predict(TSrc example, ref TDst prediction);
}

/// <summary>
/// Options for the <see cref="PredictionEngine{TSrc, TDst}"/>
/// </summary>
public sealed class PredictionEngineOptions
{
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether to throw an error if a column exists in the output schema but not the output object.", ShortName = "ignore", SortOrder = 50)]
public bool IgnoreMissingColumns = Defaults.IgnoreMissingColumns;

[Argument(ArgumentType.AtMostOnce, HelpText = "Additional settings of the input schema.", ShortName = "input", SortOrder = 50)]
public SchemaDefinition InputSchemaDefinition = Defaults.InputSchemaDefinition;

[Argument(ArgumentType.AtMostOnce, HelpText = "Additional settings of the output schema.", ShortName = "output")]
public SchemaDefinition OutputSchemaDefinition = Defaults.OutputSchemaDefinition;

[Argument(ArgumentType.AtMostOnce, HelpText = "Whether the prediction engine owns the transformer and should dispose of it.", ShortName = "own")]
public bool OwnsTransformer = Defaults.OwnsTransformer;

internal static class Defaults
{
public const bool IgnoreMissingColumns = true;
public const SchemaDefinition InputSchemaDefinition = null;
public const SchemaDefinition OutputSchemaDefinition = null;
public const bool OwnsTransformer = true;
}
}
}
37 changes: 37 additions & 0 deletions src/Microsoft.ML.TimeSeries/PredictionEngine.cs
Expand Up @@ -148,6 +148,15 @@ private static ITransformer CloneTransformers(ITransformer transformer)
{
}

/// <summary>
/// Contructor for creating time series specific prediction engine. It allows the time series model to be updated with the observations
/// seen at prediction time via <see cref="CheckPoint(IHostEnvironment, string)"/>
/// </summary>
internal TimeSeriesPredictionEngine(IHostEnvironment env, ITransformer transformer, PredictionEngineOptions options) :
base(env, CloneTransformers(transformer), options.IgnoreMissingColumns, options.InputSchemaDefinition, options.OutputSchemaDefinition, options.OwnsTransformer)
{
}

internal DataViewRow GetStatefulRows(DataViewRow input, IRowToRowMapper mapper, IEnumerable<DataViewSchema.Column> activeColumns, List<StatefulRow> rows)
{
Contracts.CheckValue(input, nameof(input));
Expand Down Expand Up @@ -398,5 +407,33 @@ public static class PredictionFunctionExtensions
env.CheckValueOrNull(outputSchemaDefinition);
return new TimeSeriesPredictionEngine<TSrc, TDst>(env, transformer, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition);
}

/// <summary>
/// <see cref="TimeSeriesPredictionEngine{TSrc, TDst}"/> creates a prediction engine for a time series pipeline.
/// It updates the state of time series model with observations seen at prediction phase and allows checkpointing the model.
/// </summary>
/// <typeparam name="TSrc">Class describing input schema to the model.</typeparam>
/// <typeparam name="TDst">Class describing the output schema of the prediction.</typeparam>
/// <param name="transformer">The time series pipeline in the form of a <see cref="ITransformer"/>.</param>
/// <param name="env">Usually <see cref="MLContext"/></param>
/// <param name="options">Advanced configuration options.</param>
/// <p>Example code can be found by searching for <i>TimeSeriesPredictionEngine</i> in <a href='https://github.com/dotnet/machinelearning'>ML.NET.</a></p>
/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// This is an example for detecting change point using Singular Spectrum Analysis (SSA) model.
/// [!code-csharp[MF](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/TimeSeries/DetectChangePointBySsa.cs)]
/// ]]>
/// </format>
/// </example>
public static TimeSeriesPredictionEngine<TSrc, TDst> CreateTimeSeriesEngine<TSrc, TDst>(this ITransformer transformer, IHostEnvironment env,
PredictionEngineOptions options)
where TSrc : class
where TDst : class, new()
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(options, nameof(options));
return new TimeSeriesPredictionEngine<TSrc, TDst>(env, transformer, options);
}
}
}
41 changes: 41 additions & 0 deletions test/Microsoft.ML.IntegrationTests/Prediction.cs
Expand Up @@ -4,6 +4,7 @@

using System;
using System.Collections.Generic;
using System.Reflection;
using Microsoft.ML.Calibrators;
using Microsoft.ML.Data;
using Microsoft.ML.IntegrationTests.Datasets;
Expand Down Expand Up @@ -97,5 +98,45 @@ public void ReconfigurablePredictionNoPipeline()
Assert.True(pr.Score <= 0);
}

[Fact]
public void PredictionEngineModelDisposal()
{
var mlContext = new MLContext(seed: 1);
var data = mlContext.Data.LoadFromEnumerable(TypeTestData.GenerateDataset());
var pipeline = mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(
new Trainers.LbfgsLogisticRegressionBinaryTrainer.Options { NumberOfThreads = 1 });
var model = pipeline.Fit(data);

var engine = mlContext.Model.CreatePredictionEngine<TypeTestData, Prediction>(model, new PredictionEngineOptions());

// Dispose of prediction engine, should dispose of model
engine.Dispose();

// Get disposed flag using reflection
var bfIsDisposed = BindingFlags.Instance | BindingFlags.NonPublic;
var field = model.GetType().BaseType.BaseType.GetField("_disposed", bfIsDisposed);

// Make sure the model is actually disposed
Assert.True((bool)field.GetValue(model));

// Make a new model/prediction engine. Set the options so prediction engine doesn't dispose
model = pipeline.Fit(data);

var options = new PredictionEngineOptions()
{
OwnsTransformer = false
};

engine = mlContext.Model.CreatePredictionEngine<TypeTestData, Prediction>(model, options);

// Dispose of prediction engine, shouldn't dispose of model
engine.Dispose();

// Make sure model is not disposed of.
Assert.False((bool)field.GetValue(model));

// Dispose of the model for test cleanliness
model.Dispose();
}
}
}

0 comments on commit 1dfccca

Please sign in to comment.