diff --git a/src/Microsoft.ML.AutoML/Utils/UserInputValidationUtil.cs b/src/Microsoft.ML.AutoML/Utils/UserInputValidationUtil.cs index 6255e526ee..143e3de1a7 100644 --- a/src/Microsoft.ML.AutoML/Utils/UserInputValidationUtil.cs +++ b/src/Microsoft.ML.AutoML/Utils/UserInputValidationUtil.cs @@ -102,7 +102,7 @@ private static void ValidateTrainData(IDataView trainData, ColumnInformation col } } - private static void ValidateColumnInformation(IDataView trainData, ColumnInformation columnInformation, TaskKind task) + private static void ValidateColumnInformation(IDataView trainData, ColumnInformation columnInformation, TaskKind task) { ValidateColumnInformation(columnInformation); ValidateTrainDataColumn(trainData, columnInformation.LabelColumnName, LabelColumnPurposeName, GetAllowedLabelTypes(task)); @@ -217,7 +217,7 @@ private static void ValidateValidationData(IDataView trainData, IDataView valida throw new ArgumentException($"{schemaMismatchError} Column '{trainCol.Name}' exists in train data, but not in validation data.", nameof(validationData)); } - if (trainCol.Type != validCol.Value.Type) + if (trainCol.Type != validCol.Value.Type && !trainCol.Type.Equals(validCol.Value.Type)) { throw new ArgumentException($"{schemaMismatchError} Column '{trainCol.Name}' is of type {trainCol.Type} in train data, and type " + $"{validCol.Value.Type} in validation data.", nameof(validationData)); @@ -260,7 +260,7 @@ private static void ValidateTrainDataColumn(IDataView trainData, string columnNa throw new ArgumentException(exceptionMessage); } - if(allowedTypes == null) + if (allowedTypes == null) { return; }