Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DFT on DirectML #12710

Merged
merged 6 commits on Sep 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 4 additions & 4 deletions cmake/onnxruntime_providers.cmake
Expand Up @@ -621,7 +621,7 @@ if (onnxruntime_USE_TENSORRT)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter -Wno-missing-field-initializers")
endif()
set(CXX_VERSION_DEFINED TRUE)

if (onnxruntime_USE_TENSORRT_BUILTIN_PARSER)
# Add TensorRT library
find_path(TENSORRT_INCLUDE_DIR NvInfer.h
Expand Down Expand Up @@ -658,9 +658,9 @@ if (onnxruntime_USE_TENSORRT)
include_directories(${TENSORRT_INCLUDE_DIR})
set(onnxparser_link_libs nvonnxparser_static)
endif()

set(trt_link_libs cudnn ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY})

file(GLOB_RECURSE onnxruntime_providers_tensorrt_cc_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/core/providers/tensorrt/*.h"
"${ONNXRUNTIME_ROOT}/core/providers/tensorrt/*.cc"
Expand Down Expand Up @@ -1157,7 +1157,7 @@ if (onnxruntime_USE_DML)
if (GDK_PLATFORM STREQUAL Scarlett)
target_link_libraries(onnxruntime_providers_dml PRIVATE ${gdk_dx_libs})
else()
target_link_libraries(onnxruntime_providers_dml PRIVATE d3d12.lib dxgi.lib)
target_link_libraries(onnxruntime_providers_dml PRIVATE dxguid.lib d3d12.lib dxgi.lib)
endif()

target_link_libraries(onnxruntime_providers_dml PRIVATE delayimp.lib)
Expand Down
Expand Up @@ -544,8 +544,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
// Currently unsupported for external operators
if (canAliasFirstInput ||
supportsGraph ||
requiredInputCountForGraph ||
requiredConstantCpuInputs)
requiredInputCountForGraph)
{
ORT_THROW_HR(E_INVALIDARG);
}
Expand Down
Expand Up @@ -1408,11 +1408,14 @@ namespace Windows::AI::MachineLearning::Adapter
ComPtr<IMLOperatorTensor> tensor;
ORT_THROW_IF_FAILED(GetInputTensor(i, tensor.GetAddressOf()));

ComPtr<IUnknown> resource;
tensor->GetDataInterface(resource.GetAddressOf());
if (resource)
if (tensor)
{
resourcesToTransition.push_back(resource.Get());
ComPtr<IUnknown> resource;
tensor->GetDataInterface(resource.GetAddressOf());
if (resource)
{
resourcesToTransition.push_back(resource.Get());
}
}
}

Expand Down Expand Up @@ -1525,21 +1528,27 @@ namespace Windows::AI::MachineLearning::Adapter

ML_CHECK_BOOL(inputIndex < m_inputTensors.size());

auto opKernelContextWrapper = const_cast<OpKernelContextWrapper*>(this);
if (m_inputTensors[inputIndex]->GetInterface() == nullptr)
{
auto inputTensor = m_impl->Input<onnxruntime::Tensor>(inputIndex);
if (inputTensor != nullptr)
{
ComPtr<TensorWrapper> tensorWrapper = wil::MakeOrThrow<TensorWrapper>(
const_cast<onnxruntime::Tensor*>(inputTensor),
IsAllocationInterface(inputTensor->Location()),
m_winmlProvider.Get(),
m_internalOperator);

ComPtr<TensorWrapper> tensorWrapper = wil::MakeOrThrow<TensorWrapper>(
const_cast<onnxruntime::Tensor*>(inputTensor),
IsAllocationInterface(inputTensor->Location()),
m_winmlProvider.Get(),
m_internalOperator);

const_cast<OpKernelContextWrapper*>(this)->m_inputTensors[inputIndex] = tensorWrapper;
opKernelContextWrapper->m_inputTensors[inputIndex] = tensorWrapper;
}
}

const_cast<OpKernelContextWrapper*>(this)->m_inputTensors[inputIndex].CopyTo(tensor);

if (opKernelContextWrapper->m_inputTensors[inputIndex] != nullptr)
{
opKernelContextWrapper->m_inputTensors[inputIndex].CopyTo(tensor);
}
return S_OK;
}
ORT_CATCH_RETURN
Expand Down