Skip to content

Commit

Permalink
Add sweepable estimator to NER (#6965)
Browse files Browse the repository at this point in the history
  • Loading branch information
zewditu committed Jan 19, 2024
1 parent 48b6fbe commit 125b6d5
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 3 deletions.
3 changes: 2 additions & 1 deletion src/Microsoft.ML.AutoML/CodeGen/estimator-schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@
"TextClassifcation",
"SentenceSimilarity",
"ObjectDetection",
"QuestionAnswering"
"QuestionAnswering",
"NamedEntityRecognition"
]
},
"nugetDependencies": {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
{
"$schema": "./search-space-schema.json#",
"name": "named_entity_recognition_option",
"search_space": [
{
"name": "PredictionColumnName",
"type": "string",
"default": "predictedLabel"
},
{
"name": "LabelColumnName",
"type": "string",
"default": "Label"
},
{
"name": "Sentence1ColumnName",
"type": "string",
"default": "Sentence"
},
{
"name": "BatchSize",
"type": "integer",
"default": 32
},
{
"name": "MaxEpochs",
"type": "integer",
"default": 10
},
{
"name": "Architecture",
"type": "bertArchitecture",
"default": "BertArchitecture.Roberta"
}
]
}
6 changes: 4 additions & 2 deletions src/Microsoft.ML.AutoML/CodeGen/search-space-schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,8 @@
"text_classification_option",
"sentence_similarity_option",
"object_detection_option",
"question_answering_option"
"question_answering_option",
"named_entity_recognition_option"
]
},
"option_name": {
Expand Down Expand Up @@ -238,7 +239,8 @@
"AnswerIndexStartColumnName",
"predictedAnswerColumnName",
"TopKAnswers",
"TargetType"
"TargetType",
"PredictionColumnName"
]
},
"option_type": {
Expand Down
7 changes: 7 additions & 0 deletions src/Microsoft.ML.AutoML/CodeGen/trainer-estimators.json
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,13 @@
"usingStatements": [ "Microsoft.ML", "Microsoft.ML.Trainers", "Microsoft.ML.TorchSharp" ],
"searchOption": "question_answering_option"
},
{
"functionName": "NamedEntityRecognition",
"estimatorTypes": [ "MultiClassification" ],
"nugetDependencies": [ "Microsoft.ML", "Microsoft.ML.TorchSharp" ],
"usingStatements": [ "Microsoft.ML", "Microsoft.ML.Trainers", "Microsoft.ML.TorchSharp" ],
"searchOption": "named_entity_recognition_option"
},
{
"functionName": "ForecastBySsa",
"estimatorTypes": [ "Forecasting" ],
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Text;
using Microsoft.ML.TorchSharp;
using Microsoft.ML.TorchSharp.NasBert;

namespace Microsoft.ML.AutoML.CodeGen
{
internal partial class NamedEntityRecognitionMulti
{
public override IEstimator<ITransformer> BuildFromOption(MLContext context, NamedEntityRecognitionOption param)
{
return context.MulticlassClassification.Trainers.NamedEntityRecognition(
labelColumnName: param.LabelColumnName,
outputColumnName: param.PredictionColumnName,
sentence1ColumnName: param.Sentence1ColumnName,
batchSize: param.BatchSize,
maxEpochs: param.MaxEpochs,
architecture: BertArchitecture.Roberta);
}
}
}

0 comments on commit 125b6d5

Please sign in to comment.