[DML EP] Contrib Op: FusedMatMul #12898

Merged · 11 commits · Sep 9, 2022
18 changes: 9 additions & 9 deletions onnxruntime/core/graph/dml_ops/dml_defs.cc
@@ -25,7 +25,7 @@ using ONNX_NAMESPACE::OPTIONAL_VALUE;

void RegisterDmlSchemas() {

-MS_DML_OPERATOR_SCHEMA(FusedConv)
+MS_DML_OPERATOR_SCHEMA(DmlFusedConv)
.SetDomain(kMSDmlDomain)
.SinceVersion(1)
.SetDoc(R"DOC(DirectML fused Conv+Activation)DOC")
@@ -52,7 +52,7 @@ void RegisterDmlSchemas() {
ONNX_NAMESPACE::convPoolShapeInference(ctx, true, false, 0, 1);
});

-MS_DML_OPERATOR_SCHEMA(FusedConvTranspose)
+MS_DML_OPERATOR_SCHEMA(DmlFusedConvTranspose)
.SetDomain(kMSDmlDomain)
.SinceVersion(1)
.SetDoc(R"DOC(DirectML fused ConvTranspose+Activation)DOC")
@@ -79,7 +79,7 @@ void RegisterDmlSchemas() {
.TypeAndShapeInferenceFunction(
[](ONNX_NAMESPACE::InferenceContext& ctx) { ONNX_NAMESPACE::convTransposeShapeInference(ctx); });

-MS_DML_OPERATOR_SCHEMA(FusedInstanceNormalization)
+MS_DML_OPERATOR_SCHEMA(DmlFusedInstanceNormalization)
.SetDomain(kMSDmlDomain)
.SinceVersion(1)
.SetDoc(R"DOC(DirectML fused InstanceNormalization+Activation)DOC")
@@ -100,7 +100,7 @@ void RegisterDmlSchemas() {
ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput(ctx);
});

-MS_DML_OPERATOR_SCHEMA(FusedBatchNormalization)
+MS_DML_OPERATOR_SCHEMA(DmlFusedBatchNormalization)
.SetDomain(kMSDmlDomain)
.SinceVersion(1)
.SetDoc(R"DOC(DirectML fused BatchNormalization+Activation)DOC")
@@ -133,7 +133,7 @@ void RegisterDmlSchemas() {
// the other outputs as well.
});

-MS_DML_OPERATOR_SCHEMA(FusedMeanVarianceNormalization)
+MS_DML_OPERATOR_SCHEMA(DmlFusedMeanVarianceNormalization)
.SetDomain(kMSDmlDomain)
.SinceVersion(1)
.SetDoc(R"DOC(DirectML fused MeanVarianceNormalization+Activation)DOC")
@@ -151,7 +151,7 @@ void RegisterDmlSchemas() {
.Attr(AttrName::FusedRatio, "", onnx::AttributeProto::FLOAT, false)
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);

-MS_DML_OPERATOR_SCHEMA(FusedGemm)
+MS_DML_OPERATOR_SCHEMA(DmlFusedGemm)
.SetDomain(kMSDmlDomain)
.SinceVersion(1)
.SetDoc(R"DOC(DirectML fused Gemm+Activation)DOC")
@@ -194,7 +194,7 @@ void RegisterDmlSchemas() {
}
});

-MS_DML_OPERATOR_SCHEMA(FusedMatMul)
+MS_DML_OPERATOR_SCHEMA(DmlFusedMatMul)
.SetDomain(kMSDmlDomain)
.SinceVersion(1)
.SetDoc(R"DOC(DirectML fused MatMul+Activation)DOC")
@@ -283,7 +283,7 @@ void RegisterDmlSchemas() {
resultShape;
});

-MS_DML_OPERATOR_SCHEMA(FusedAdd)
+MS_DML_OPERATOR_SCHEMA(DmlFusedAdd)
.SetDomain(kMSDmlDomain)
.SinceVersion(1)
.SetDoc(R"DOC(DirectML fused Add+Activation)DOC")
@@ -307,7 +307,7 @@ void RegisterDmlSchemas() {
*ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape());
});

-MS_DML_OPERATOR_SCHEMA(FusedSum)
+MS_DML_OPERATOR_SCHEMA(DmlFusedSum)
.SetDomain(kMSDmlDomain)
.SinceVersion(1)
.SetDoc(R"DOC(DirectML fused Sum+Activation)DOC")
onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp
@@ -64,7 +64,7 @@ namespace Dml
auto kernelInputIndices = ReplaceUnusedEdgeIndicesWithSentinel(m_kernelInputIndices);
properties.dmlInputCount = static_cast<uint32_t>(kernelInputIndices.size());
properties.kernelInputIndices = kernelInputIndices.data();

auto kernelOutputIndices = ReplaceUnusedEdgeIndicesWithSentinel(m_kernelOutputIndices);
properties.dmlOutputCount = static_cast<uint32_t>(kernelOutputIndices.size());
properties.kernelOutputIndices = kernelOutputIndices.data();
@@ -88,7 +88,7 @@ namespace Dml

m_persistentResourceBinding = DML_BUFFER_BINDING{ m_persistentResource.Get(), 0, persistentResourceSize };
}

std::vector<DML_BUFFER_BINDING> initializationInputBindings(m_kernelInputIndices.size());

ORT_THROW_IF_FAILED(m_executionProvider->InitializeOperator(
@@ -183,7 +183,7 @@ namespace Dml
else
{
m_inputTensorDescs.push_back(CreateTensorDescFromInput(
kernelInfo,
*m_kernelInputIndices[i],
TensorAxis::DoNotCoerce,
TensorAxis::W,
@@ -205,7 +205,7 @@ namespace Dml
else
{
m_outputTensorDescs.push_back(CreateTensorDescFromOutput(
kernelInfo,
*m_kernelOutputIndices[i],
TensorAxis::DoNotCoerce,
TensorAxis::W,
@@ -216,6 +216,112 @@ namespace Dml
}
}

void DmlOperator::InitializeWithShapes(
const MLOperatorKernelCreationContext& kernelInfo,
const std::optional<const std::vector<std::optional<uint32_t>>>& kernelInputIndices,
const std::optional<const std::vector<std::optional<uint32_t>>>& kernelOutputIndices,
const std::optional<gsl::span<gsl::span<const uint32_t>>> inputShapes,
const std::optional<gsl::span<gsl::span<const uint32_t>>> outputShapes,
uint32_t minDimensionCount
)
{
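// Default to an identity index mapping (DML input/output i maps to kernel
// input/output i) when the caller does not supply explicit index lists.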
if (kernelInputIndices)
{
m_kernelInputIndices = *kernelInputIndices;
}
else
{
m_kernelInputIndices.resize(kernelInfo.GetInputCount());
std::iota(m_kernelInputIndices.begin(), m_kernelInputIndices.end(), 0);
}

if (kernelOutputIndices)
{
m_kernelOutputIndices = *kernelOutputIndices;
}
else
{
m_kernelOutputIndices.resize(kernelInfo.GetOutputCount());
std::iota(m_kernelOutputIndices.begin(), m_kernelOutputIndices.end(), 0);
}

for (uint32_t i = 0; i < m_kernelInputIndices.size(); i++)
{
// Update m_kernelInputIndices to reflect optional tensors.
if (m_kernelInputIndices[i] == std::nullopt ||
!kernelInfo.IsInputValid(*m_kernelInputIndices[i]))
{
m_kernelInputIndices[i] = std::nullopt;
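// Push an empty placeholder desc so the desc list stays index-aligned with the kernel bindings.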
m_inputTensorDescs.push_back(TensorDesc());
}
else
{
auto edgeDesc = kernelInfo.GetInputEdgeDescription(*m_kernelInputIndices[i]);
assert(edgeDesc.edgeType == MLOperatorEdgeType::Tensor);

// Prefer the caller-provided input shape; otherwise fall back to the actual
// tensor shape from the kernel's shape description.
TensorDesc tensorDesc;
if (inputShapes.has_value() && i < (*inputShapes).size())
{
tensorDesc = TensorDesc(
edgeDesc.tensorDataType,
(*inputShapes)[i], // desired
(*inputShapes)[i], // original
TensorAxis::DoNotCoerce,
TensorAxis::W,
TensorAxis::RightAligned,
minDimensionCount,
0
);
}
else if (kernelInfo.HasTensorShapeDescription())
{
std::vector<uint32_t> actualTensorShape = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(*m_kernelInputIndices[i]);
tensorDesc = TensorDesc(
edgeDesc.tensorDataType,
actualTensorShape, // desired
actualTensorShape, // original
TensorAxis::DoNotCoerce,
TensorAxis::W,
TensorAxis::RightAligned,
minDimensionCount,
0
);
}
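// If neither a caller-provided shape nor a tensor shape description is
// available, tensorDesc stays default-constructed (empty).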
m_inputTensorDescs.push_back(tensorDesc);
}
}

for (uint32_t i = 0; i < m_kernelOutputIndices.size(); i++)
{
// Update m_kernelOutputIndices to reflect optional tensors.
if (m_kernelOutputIndices[i] == std::nullopt ||
!kernelInfo.IsOutputValid(*m_kernelOutputIndices[i]))
{
m_kernelOutputIndices[i] = std::nullopt;
m_outputTensorDescs.push_back(TensorDesc());
}
else
{
std::optional<gsl::span<const uint32_t>> outputShape;
if (outputShapes.has_value() && i < (*outputShapes).size())
{
outputShape = (*outputShapes)[i];
}
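// CreateTensorDescFromOutput falls back to the shape from the kernel's
// shape description when outputShape is nullopt.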

m_outputTensorDescs.push_back(CreateTensorDescFromOutput(
kernelInfo,
*m_kernelOutputIndices[i],
TensorAxis::DoNotCoerce,
TensorAxis::W,
TensorAxis::RightAligned,
outputShape,
minDimensionCount
));
}
}
}

void DmlOperator::Compute(const MLOperatorKernelContext& kernelContext)
{
std::vector<IMLOperatorTensor*> inputTensors = GetInputTensorsForExecute(kernelContext);
@@ -231,7 +337,7 @@ namespace Dml
bool DmlOperator::AllowHalfPrecisionComputation() const
{
// Most of our operators work with float data, but some do not. In those cases
// no input params are float tensors. This function returns true if the operator
// works with at least one float16 tensor and has no tensors of float32 type
bool usesFloat16Tensors = false;

@@ -464,7 +570,7 @@ namespace Dml
}

auto outputShape = outputShapeDescription.GetOutputTensorShape(index);

return TensorDesc(
edgeDesc.tensorDataType,
tensorShape ? *tensorShape : outputShape,
onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h
@@ -17,7 +17,7 @@ namespace Dml
virtual void Compute(const MLOperatorKernelContext& kernelContext);

protected:
ComPtr<IExecutionProvider> m_executionProvider;
ComPtr<IDMLDevice> m_dmlDevice;

// Tensor descs ordered based on index arrays passed to Initialize
@@ -43,23 +43,35 @@ namespace Dml
uint32_t minDimensionCount = NchwDimensionCount
);

// First tries to create TensorDescs from the given input and output shapes, with no broadcasting.
// If those shapes are not present, falls back to TensorDescs built from the actual input tensor shapes and shape inference.
// inputShapes and kernelInputIndices must have the same length; likewise outputShapes and kernelOutputIndices.
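// Illustrative call (caller-side names hypothetical): an operator that already
// knows its static shapes might use:
//     InitializeWithShapes(kernelInfo, std::nullopt, std::nullopt, myInputShapes, myOutputShapes);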
void InitializeWithShapes(
const MLOperatorKernelCreationContext& kernelInfo,
const std::optional<const std::vector<std::optional<uint32_t>>>& kernelInputIndices = std::nullopt,
const std::optional<const std::vector<std::optional<uint32_t>>>& kernelOutputIndices = std::nullopt,
const std::optional<gsl::span<gsl::span<const uint32_t>>> inputShapes = std::nullopt,
const std::optional<gsl::span<gsl::span<const uint32_t>>> outputShapes = std::nullopt,
uint32_t minDimensionCount = NchwDimensionCount
);

bool AllowHalfPrecisionComputation() const;
DML_EXECUTION_FLAGS GetExecutionFlags() const;

void SetDmlOperatorDesc(
const DML_OPERATOR_DESC& operatorDesc,
const MLOperatorKernelCreationContext& kernelInfo
);

void SetDmlOperatorDesc(
const DML_OPERATOR_DESC& operatorDesc,
const MLOperatorKernelContext& kernelInfo
);

// Tensors ordered based on index arrays passed to Initialize
std::vector<IMLOperatorTensor*> GetInputTensors(const MLOperatorKernelContext& kernelContext);
std::vector<IMLOperatorTensor*> GetOutputTensors(const MLOperatorKernelContext& kernelContext);

// Retrieves the input/output tensors to be supplied to DirectML for execution. These differ from
// Get[Input|Output]Tensors in that they account for the binding requirements of DML, instead of
// unconditionally retrieving all input and output tensors.
@@ -106,7 +118,7 @@ namespace Dml
) const;

private:
// For each input or output of the DML kernel, the corresponding input or output of the original
// kernel. Entries for unused DML inputs are nullopt.
std::vector<std::optional<uint32_t>> m_kernelInputIndices;
std::vector<std::optional<uint32_t>> m_kernelOutputIndices;
onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBatchNormalization.cpp
@@ -202,6 +202,6 @@ void CALLBACK QueryBatchNormalization(IMLOperatorSupportQueryContextPrivate* con
}

DML_OP_DEFINE_CREATION_FUNCTION(BatchNormalization, DmlOperatorBatchNormalization);
-DML_OP_DEFINE_CREATION_FUNCTION(FusedBatchNormalization, DmlOperatorBatchNormalization);
+DML_OP_DEFINE_CREATION_FUNCTION(DmlFusedBatchNormalization, DmlOperatorBatchNormalization);

} // namespace Dml
onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorConvolution.cpp
@@ -107,8 +107,8 @@ class DmlOperatorConvolutionTemplate : public DmlOperatorConvolution

DML_OP_DEFINE_CREATION_FUNCTION(Conv, DmlOperatorConvolutionTemplate<DML_CONVOLUTION_MODE_CROSS_CORRELATION, DML_CONVOLUTION_DIRECTION_FORWARD>);
DML_OP_DEFINE_CREATION_FUNCTION(ConvTranspose, DmlOperatorConvolutionTemplate<DML_CONVOLUTION_MODE_CROSS_CORRELATION, DML_CONVOLUTION_DIRECTION_BACKWARD>);
-DML_OP_DEFINE_CREATION_FUNCTION(FusedConv, DmlOperatorConvolutionTemplate<DML_CONVOLUTION_MODE_CROSS_CORRELATION, DML_CONVOLUTION_DIRECTION_FORWARD>);
-DML_OP_DEFINE_CREATION_FUNCTION(FusedConvTranspose, DmlOperatorConvolutionTemplate<DML_CONVOLUTION_MODE_CROSS_CORRELATION, DML_CONVOLUTION_DIRECTION_BACKWARD>);
+DML_OP_DEFINE_CREATION_FUNCTION(DmlFusedConv, DmlOperatorConvolutionTemplate<DML_CONVOLUTION_MODE_CROSS_CORRELATION, DML_CONVOLUTION_DIRECTION_FORWARD>);
+DML_OP_DEFINE_CREATION_FUNCTION(DmlFusedConvTranspose, DmlOperatorConvolutionTemplate<DML_CONVOLUTION_MODE_CROSS_CORRELATION, DML_CONVOLUTION_DIRECTION_BACKWARD>);
DML_OP_DEFINE_CREATION_FUNCTION(ConvTransposeWithDynamicPads, DmlOperatorConvolutionTemplate<DML_CONVOLUTION_MODE_CROSS_CORRELATION, DML_CONVOLUTION_DIRECTION_BACKWARD, true>);

} // namespace Dml
onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp
@@ -756,7 +756,7 @@ DML_OP_DEFINE_CREATION_FUNCTION(IsInf, DmlOperatorElementwiseIsInf);
DML_OP_DEFINE_CREATION_FUNCTION(Round, DmlOperatorElementwiseRound);

// Fused operators:
-DML_OP_DEFINE_CREATION_FUNCTION(FusedAdd, DmlOperatorElementwiseBinary<DML_ELEMENT_WISE_ADD1_OPERATOR_DESC>);
-DML_OP_DEFINE_CREATION_FUNCTION(FusedSum, DmlOperatorElementwiseBinaryLoop<DML_ELEMENT_WISE_ADD1_OPERATOR_DESC>);
+DML_OP_DEFINE_CREATION_FUNCTION(DmlFusedAdd, DmlOperatorElementwiseBinary<DML_ELEMENT_WISE_ADD1_OPERATOR_DESC>);
+DML_OP_DEFINE_CREATION_FUNCTION(DmlFusedSum, DmlOperatorElementwiseBinaryLoop<DML_ELEMENT_WISE_ADD1_OPERATOR_DESC>);

} // namespace Dml