diff --git a/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h b/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h
index ed94b14578e8..5c34df984152 100644
--- a/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h
+++ b/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h
@@ -681,6 +681,18 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInterface {
   ThreadPoolProfiler profiler_;
 
+  void SignalAllAndWait() {
+    done_ = true;
+
+    // Now if all threads block without work, they will start exiting.
+    // But note that threads can continue to work arbitrarily long,
+    // block, submit new work, unblock and otherwise live a full life.
+    WakeAllWorkersForExit();
+    // Join the threads explicitly (by destroying them) to avoid destruction-order issues within
+    // this class.
+    for (size_t i = 0; i < worker_data_.size(); ++i) worker_data_[i].thread.reset();
+  }
+
  public:
   void StartProfiling() override {
     profiler_.Start();
@@ -750,22 +762,24 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInterface {
       ComputeCoprimes(i, &all_coprimes_.back());
     }
 
-    worker_data_.resize(num_threads_);
-    for (auto i = 0u; i < num_threads_; i++) {
-      worker_data_[i].thread.reset(env_.CreateThread(name, i, WorkerLoop, this, thread_options));
+    // Eigen::MaxSizeVector has neither essential exception-safety features
+    // such as swap, nor is it movable. So we have to join the threads right here
+    // on exception.
+    ORT_TRY {
+      worker_data_.resize(num_threads_);
+      for (auto i = 0u; i < num_threads_; i++) {
+        worker_data_[i].thread.reset(env_.CreateThread(name, i, WorkerLoop, this, thread_options));
+      }
+    } ORT_CATCH(...) {
+      ORT_HANDLE_EXCEPTION([&]() {
+        SignalAllAndWait();
+        throw;
+      });
     }
   }
 
   ~ThreadPoolTempl() override {
-    done_ = true;
-
-    // Now if all threads block without work, they will start exiting.
-    // But note that threads can continue to work arbitrary long,
-    // block, submit new work, unblock and otherwise live full life.
-    WakeAllWorkersForExit();
-    // Join threads explicitly (by destroying) to avoid destruction order within
-    // this class.
-    for (size_t i = 0; i < worker_data_.size(); ++i) worker_data_[i].thread.reset();
+    SignalAllAndWait();
   }
 
   // Run fn(). Ordinarily, the function will be added to the thread pool and executed
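For context, a minimal standalone sketch (not ORT code, using plain `std::thread`) of the clean-up-on-partial-construction pattern the constructor above now follows: if spawning a worker throws, the workers already started are signalled and joined before the exception escapes, because the destructor never runs for a partially constructed object. The `Pool`/`SignalAllAndWait` names here only mirror the diff.

```cpp
#include <stdexcept>
#include <thread>
#include <vector>

class Pool {
 public:
  explicit Pool(int n) {
    try {
      for (int i = 0; i < n; ++i) {
        if (i == 2) throw std::runtime_error("simulated CreateThread failure");
        workers_.emplace_back([] { /* worker loop would go here */ });
      }
    } catch (...) {
      SignalAllAndWait();  // same cleanup the destructor uses
      throw;               // re-raise for the caller
    }
  }

  ~Pool() { SignalAllAndWait(); }

 private:
  void SignalAllAndWait() {
    // Join every thread that was actually started, then drop the handles.
    for (auto& t : workers_) {
      if (t.joinable()) t.join();
    }
    workers_.clear();
  }

  std::vector<std::thread> workers_;
};

int main() {
  try {
    Pool p(4);
  } catch (const std::exception&) {
    // The already-started workers were joined before we got here.
  }
}
```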
diff --git a/onnxruntime/contrib_ops/cuda/fused_conv.cc b/onnxruntime/contrib_ops/cuda/fused_conv.cc
index fa3b33166fa1..975a73d21284 100644
--- a/onnxruntime/contrib_ops/cuda/fused_conv.cc
+++ b/onnxruntime/contrib_ops/cuda/fused_conv.cc
@@ -37,7 +37,7 @@ class FusedConv : public onnxruntime::cuda::Conv {
   Status ComputeInternal(OpKernelContext* context) const override {
     CUDNN_RETURN_IF_ERROR(status_);
     std::lock_guard lock(Base::s_.mutex);
-    ORT_RETURN_IF_ERROR(Base::UpdateState(context));
+    ORT_RETURN_IF_ERROR(Base::UpdateState(context, true));
     if (Base::s_.Y->Shape().Size() == 0) {
       return Status::OK();
     }
@@ -47,27 +47,25 @@ class FusedConv : public onnxruntime::cuda::Conv {
     const auto alpha = onnxruntime::cuda::Consts::One;
     const auto beta = onnxruntime::cuda::Consts::Zero;
     IAllocatorUniquePtr workspace = Base::GetWorkSpace();
-
-    if (has_b && has_z && !Base::s_.post_slicing_required) {
-      CUDNN_RETURN_IF_ERROR(cudnnConvolutionBiasActivationForward(Base::CudnnHandle(),
-                                                                  &alpha,
-                                                                  Base::s_.x_tensor,
-                                                                  Base::s_.x_data,
-                                                                  Base::s_.w_desc,
-                                                                  Base::s_.w_data,
-                                                                  Base::s_.conv_desc,
-                                                                  Base::s_.algo,
-                                                                  workspace.get(),
-                                                                  Base::s_.workspace_bytes,
-                                                                  &alpha,
-                                                                  Base::s_.z_tensor,
-                                                                  Base::s_.z_data,
-                                                                  Base::s_.b_tensor,
-                                                                  Base::s_.b_data,
-                                                                  activation_desc_,
-                                                                  Base::s_.y_tensor,
-                                                                  Base::s_.y_data));
-    } else {
+    auto cudnn_status = cudnnConvolutionBiasActivationForward(Base::CudnnHandle(),
+                                                              &alpha,
+                                                              Base::s_.x_tensor,
+                                                              Base::s_.x_data,
+                                                              Base::s_.w_desc,
+                                                              Base::s_.w_data,
+                                                              Base::s_.conv_desc,
+                                                              Base::s_.algo,
+                                                              workspace.get(),
+                                                              Base::s_.workspace_bytes,
+                                                              has_z ? &alpha : &beta,
+                                                              has_z ? Base::s_.z_tensor : Base::s_.y_tensor,
+                                                              has_z ? Base::s_.z_data : Base::s_.y_data,
+                                                              Base::s_.b_tensor,
+                                                              has_b ? Base::s_.b_data : Base::s_.b_zero,
+                                                              activation_desc_,
+                                                              Base::s_.y_tensor,
+                                                              Base::s_.y_data);
+    if (CUDNN_STATUS_SUCCESS != cudnn_status) {
       CUDNN_RETURN_IF_ERROR(cudnnConvolutionForward(Base::CudnnHandle(),
                                                     &alpha,
                                                     Base::s_.x_tensor,
@@ -81,38 +79,21 @@ class FusedConv : public onnxruntime::cuda::Conv {
                                                     &beta,
                                                     Base::s_.y_tensor,
                                                     Base::s_.y_data));
-
-      if (Base::s_.post_slicing_required) {
-        ORT_RETURN_IF_ERROR(onnxruntime::cuda::SliceOutUnwantedOutputSection(
-            this->Stream(), Base::s_.y_data, Base::s_.y_dims_with_adjusted_pads, Base::s_.Y->MutableDataRaw(),
-            Base::s_.y_dims.GetDims(), Base::s_.slice_starts, Base::s_.slice_ends, Base::s_.slice_axes, Base::s_.element_size));
-
-        onnxruntime::cuda::CudnnTensor sliced_y_tensor;
-        ORT_RETURN_IF_ERROR(sliced_y_tensor.Set(Base::s_.y_dims.GetDims(), onnxruntime::cuda::CudnnTensor::GetDataType()));
-
-        if (has_b) {
-          CUDNN_RETURN_IF_ERROR(cudnnAddTensor(Base::CudnnHandle(), &alpha, Base::s_.b_tensor, Base::s_.b_data,
-                                               &alpha, sliced_y_tensor, Base::s_.Y->MutableDataRaw()));
-        }
-        if (has_z) {
-          CUDNN_RETURN_IF_ERROR(cudnnAddTensor(Base::CudnnHandle(), &alpha, Base::s_.z_tensor, Base::s_.z_data,
-                                               &alpha, sliced_y_tensor, Base::s_.Y->MutableDataRaw()));
-        }
-
-        CUDNN_RETURN_IF_ERROR(cudnnActivationForward(Base::CudnnHandle(), activation_desc_, &alpha, sliced_y_tensor,
-                                                     Base::s_.y_data, &beta, sliced_y_tensor, Base::s_.y_data));
-      } else {
-        if (has_b) {
-          CUDNN_RETURN_IF_ERROR(cudnnAddTensor(Base::CudnnHandle(), &alpha, Base::s_.b_tensor, Base::s_.b_data,
-                                               &alpha, Base::s_.y_tensor, Base::s_.y_data));
-        }
-        if (has_z) {
-          CUDNN_RETURN_IF_ERROR(cudnnAddTensor(Base::CudnnHandle(), &alpha, Base::s_.z_tensor, Base::s_.z_data,
-                                               &alpha, Base::s_.y_tensor, Base::s_.y_data));
-        }
-        CUDNN_RETURN_IF_ERROR(cudnnActivationForward(Base::CudnnHandle(), activation_desc_, &alpha, Base::s_.y_tensor,
-                                                     Base::s_.y_data, &beta, Base::s_.y_tensor, Base::s_.y_data));
+      if (has_b) {
+        CUDNN_RETURN_IF_ERROR(cudnnAddTensor(Base::CudnnHandle(), &alpha, Base::s_.b_tensor, Base::s_.b_data,
+                                             &alpha, Base::s_.y_tensor, Base::s_.y_data));
       }
+      if (has_z) {
+        CUDNN_RETURN_IF_ERROR(cudnnAddTensor(Base::CudnnHandle(), &alpha, Base::s_.z_tensor, Base::s_.z_data,
+                                             &alpha, Base::s_.y_tensor, Base::s_.y_data));
+      }
+      CUDNN_RETURN_IF_ERROR(cudnnActivationForward(Base::CudnnHandle(), activation_desc_, &alpha, Base::s_.y_tensor,
+                                                   Base::s_.y_data, &beta, Base::s_.y_tensor, Base::s_.y_data));
+    }
+    if (Base::s_.post_slicing_required) {
+      ORT_RETURN_IF_ERROR(onnxruntime::cuda::SliceOutUnwantedOutputSection(
+          this->Stream(), Base::s_.y_data, Base::s_.y_dims_with_adjusted_pads, Base::s_.Y->MutableDataRaw(),
+          Base::s_.y_dims.GetDims(), Base::s_.slice_starts, Base::s_.slice_ends, Base::s_.slice_axes, Base::s_.element_size));
     }
     return Status::OK();
   }
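The control flow the FusedConv change introduces can be summarized with a small illustrative sketch. The helpers below (`RunFusedConvBiasActivation`, `RunPlainConv`, etc.) are hypothetical stand-ins, not the ORT or cuDNN API; the point is only the shape of the logic: always attempt the fused kernel first, decompose into conv + bias + z + activation when it fails, and run post-slicing on either path.

```cpp
#include <cstdio>

enum class St { OK, Fail };

St RunFusedConvBiasActivation() { return St::Fail; }  // pretend cuDNN rejected the fused config
St RunPlainConv() { return St::OK; }
void AddBias() {}
void AddZ() {}
void ApplyActivation() {}
void SliceOutput() {}

St Compute(bool has_b, bool has_z, bool post_slicing_required) {
  if (RunFusedConvBiasActivation() != St::OK) {
    // Fallback: unfused convolution followed by explicit bias/z adds and activation.
    if (RunPlainConv() != St::OK) return St::Fail;
    if (has_b) AddBias();
    if (has_z) AddZ();
    ApplyActivation();
  }
  // Post-slicing now happens after either path, instead of only inside the fallback.
  if (post_slicing_required) SliceOutput();
  return St::OK;
}

int main() {
  std::printf("status: %d\n", static_cast<int>(Compute(true, false, true)));
}
```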
diff --git a/onnxruntime/core/platform/posix/env.cc b/onnxruntime/core/platform/posix/env.cc
index 06e0cd2a7de6..bcdd748f4057 100644
--- a/onnxruntime/core/platform/posix/env.cc
+++ b/onnxruntime/core/platform/posix/env.cc
@@ -26,12 +26,15 @@ limitations under the License.
 #include
 #include
 #include
+#include
 #include
 #include
 #include  // for std::forward
 #include
 #include
+#include
+
 #include "core/common/common.h"
 #include "core/common/logging/logging.h"
 #include "core/platform/scoped_resource.h"
@@ -54,8 +57,7 @@ class UnmapFileParam {
  *
  * @return a pair of {errno, error message}
  */
-static std::pair GetSystemError() {
-  auto e = errno;
+static std::pair GetSystemError(int e) {
   char buf[1024];
   const char* msg = "";
   if (e > 0) {
@@ -73,6 +75,11 @@
   return std::make_pair(e, msg);
 }
 
+static std::pair GetSystemError() {
+  auto e = errno;
+  return GetSystemError(e);
+}
+
 static void UnmapFile(void* param) noexcept {
   std::unique_ptr p(reinterpret_cast(param));
   int ret = munmap(p->addr, p->len);
@@ -128,6 +135,7 @@ struct Freer {
 using MallocdStringPtr = std::unique_ptr >;
 
+
 class PosixThread : public EnvThread {
  private:
   struct Param {
     const ORTCHAR_T* name_prefix;
     int index;
     unsigned (*start_address)(int id, Eigen::ThreadPoolInterface* param);
     Eigen::ThreadPoolInterface* param;
-    const ThreadOptions& thread_options;
+    std::optional affinity_mask;
+
+    Param(const ORTCHAR_T* name_prefix1,
+          int index1,
+          unsigned (*start_address1)(int id, Eigen::ThreadPoolInterface* param),
+          Eigen::ThreadPoolInterface* param1)
+        : name_prefix(name_prefix1),
+          index(index1),
+          start_address(start_address1),
+          param(param1) {}
   };
 
  public:
   PosixThread(const ORTCHAR_T* name_prefix, int index,
               unsigned (*start_address)(int id, Eigen::ThreadPoolInterface* param), Eigen::ThreadPoolInterface* param,
               const ThreadOptions& thread_options) {
+    ORT_ENFORCE(index >= 0, "Negative thread index is not allowed");
     custom_create_thread_fn = thread_options.custom_create_thread_fn;
     custom_thread_creation_options = thread_options.custom_thread_creation_options;
     custom_join_thread_fn = thread_options.custom_join_thread_fn;
+
+    auto param_ptr = std::make_unique(name_prefix, index, start_address, param);
+    if (gsl::narrow(index) < thread_options.affinity.size()) {
+      param_ptr->affinity_mask = thread_options.affinity[index];
+    }
+
     if (custom_create_thread_fn) {
-      custom_thread_handle = custom_create_thread_fn(custom_thread_creation_options, CustomThreadMain, new Param{name_prefix, index, start_address, param, thread_options});
+      custom_thread_handle = custom_create_thread_fn(custom_thread_creation_options, CustomThreadMain, param_ptr.get());
       if (!custom_thread_handle) {
         ORT_THROW("custom_create_thread_fn returned invalid handle.");
       }
+      param_ptr.release();
     } else {
       pthread_attr_t attr;
       int s = pthread_attr_init(&attr);
@@ -165,24 +189,14 @@ class PosixThread : public EnvThread {
           ORT_THROW("pthread_attr_setstacksize failed, error code: ", err_no, " error msg: ", err_msg);
         }
       }
-      s = pthread_create(&hThread, &attr, ThreadMain,
-                         new Param{name_prefix, index, start_address, param, thread_options});
+
+      s = pthread_create(&hThread, &attr, ThreadMain, param_ptr.get());
       if (s != 0) {
         auto [err_no, err_msg] = GetSystemError();
         ORT_THROW("pthread_create failed, error code: ", err_no, " error msg: ", err_msg);
       }
-#if !defined(__APPLE__) && !defined(__ANDROID__) && !defined(__wasm__) && !defined(_AIX)
-      if (!thread_options.affinity.empty()) {
-        cpu_set_t cpuset;
-        CPU_ZERO(&cpuset);
-        CPU_SET(thread_options.affinity[index], &cpuset);
-        s = pthread_setaffinity_np(hThread, sizeof(cpu_set_t), &cpuset);
-        if (s != 0) {
-          auto [err_no, err_msg] = GetSystemError();
-          ORT_THROW("pthread_setaffinity_np failed, error code: ", err_no, " error msg: ", err_msg);
-        }
-      }
-#endif
+      param_ptr.release();
+      // Do not throw beyond this point, otherwise we lose the thread handle and can no longer join the thread.
     }
   }
 
@@ -203,13 +217,29 @@ class PosixThread : public EnvThread {
 
  private:
   static void* ThreadMain(void* param) {
-    std::unique_ptr p((Param*)param);
+    std::unique_ptr p(static_cast(param));
     ORT_TRY {
+#if !defined(__APPLE__) && !defined(__ANDROID__) && !defined(__wasm__) && !defined(_AIX)
+      if (p->affinity_mask.has_value()) {
+        cpu_set_t cpuset;
+        CPU_ZERO(&cpuset);
+        CPU_SET(*p->affinity_mask, &cpuset);
+        // pthread_setaffinity_np() does not set errno; it returns the error code.
+        auto ret = pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpuset);
+        if (ret != 0) {
+          auto [err_no, err_msg] = GetSystemError(ret);
+          LOGS_DEFAULT(ERROR) << "pthread_setaffinity_np failed for thread: " << pthread_self()
+                              << ", mask: " << *p->affinity_mask
+                              << ", error code: " << err_no << " error msg: " << err_msg
+                              << ". Specify the number of threads explicitly so the affinity is not set.";
+        }
+      }
+#endif
       // Ignore the returned value for now
       p->start_address(p->index, p->param);
     }
-    ORT_CATCH(const std::exception&) {
-      //ignore any exceptions
+    ORT_CATCH(...) {
+      // Ignore exceptions
     }
     return nullptr;
   }
@@ -440,7 +470,7 @@ class PosixEnv : public Env {
   common::Status GetCanonicalPath(
       const PathString& path,
       PathString& canonical_path) const override {
-    MallocdStringPtr canonical_path_cstr{realpath(path.c_str(), nullptr)};
+    MallocdStringPtr canonical_path_cstr{realpath(path.c_str(), nullptr), Freer()};
    if (!canonical_path_cstr) {
      return ReportSystemError("realpath", path);
    }
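For reference, a minimal standalone sketch of the pattern the POSIX change adopts: the affinity mask is applied from inside the newly created thread via `pthread_setaffinity_np(pthread_self(), ...)`, and a failure is logged rather than thrown. This is not ORT code; it assumes Linux/glibc (the call is a GNU extension) and should be built with `-pthread`. The `worker` name and single-CPU mask are illustrative.

```cpp
#include <pthread.h>
#include <sched.h>
#include <cstdio>
#include <cstring>

// On glibc, g++ defines _GNU_SOURCE by default, which exposes pthread_setaffinity_np.
void* worker(void* arg) {
  int cpu = *static_cast<int*>(arg);
  cpu_set_t cpuset;
  CPU_ZERO(&cpuset);
  CPU_SET(cpu, &cpuset);
  // Returns the error number directly instead of setting errno.
  int rc = pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpuset);
  if (rc != 0) {
    // Log and keep running; failing to pin is not fatal (mirrors the PR's behavior).
    std::fprintf(stderr, "pthread_setaffinity_np failed: %s\n", std::strerror(rc));
  }
  return nullptr;
}

int main() {
  int cpu = 0;
  pthread_t t;
  if (pthread_create(&t, nullptr, worker, &cpu) != 0) return 1;
  pthread_join(t, nullptr);
}
```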
diff --git a/onnxruntime/core/platform/windows/env.cc b/onnxruntime/core/platform/windows/env.cc
index c711f6d36a09..23eb51f3b4e6 100644
--- a/onnxruntime/core/platform/windows/env.cc
+++ b/onnxruntime/core/platform/windows/env.cc
@@ -19,12 +19,14 @@ limitations under the License.
 #include
 #include
+#include
 #include
 #include
 #include
 #include
 #include
+#include
 #include "core/common/logging/logging.h"
 #include "core/platform/env.h"
 #include "core/platform/scoped_resource.h"
@@ -68,31 +70,53 @@ class WindowsThread : public EnvThread {
     int index;
     unsigned (*start_address)(int id, Eigen::ThreadPoolInterface* param);
     Eigen::ThreadPoolInterface* param;
-    const ThreadOptions& thread_options;
+    std::optional affinity_mask;
 
     Param(const ORTCHAR_T* name_prefix1,
           int index1,
           unsigned (*start_address1)(int id, Eigen::ThreadPoolInterface* param),
-          Eigen::ThreadPoolInterface* param1,
-          const ThreadOptions& thread_options1) : name_prefix(name_prefix1), index(index1), start_address(start_address1), param(param1), thread_options(thread_options1) {}
+          Eigen::ThreadPoolInterface* param1)
+        : name_prefix(name_prefix1),
+          index(index1),
+          start_address(start_address1),
+          param(param1) {}
   };
 
  public:
   WindowsThread(const ORTCHAR_T* name_prefix, int index,
                 unsigned (*start_address)(int id, Eigen::ThreadPoolInterface* param), Eigen::ThreadPoolInterface* param,
                 const ThreadOptions& thread_options) {
+    ORT_ENFORCE(index >= 0, "Negative thread index is not allowed");
     custom_create_thread_fn = thread_options.custom_create_thread_fn;
     custom_thread_creation_options = thread_options.custom_thread_creation_options;
     custom_join_thread_fn = thread_options.custom_join_thread_fn;
-    std::unique_ptr local_param = std::make_unique(name_prefix, index, start_address, param, thread_options);
+
+    std::unique_ptr local_param = std::make_unique(name_prefix, index, start_address, param);
+    if (gsl::narrow(index) < thread_options.affinity.size()) {
+      local_param->affinity_mask = thread_options.affinity[index];
+    }
+
     if (custom_create_thread_fn) {
-      custom_thread_handle = custom_create_thread_fn(custom_thread_creation_options, (OrtThreadWorkerFn)CustomThreadMain, local_param.release());
+      custom_thread_handle = custom_create_thread_fn(custom_thread_creation_options, (OrtThreadWorkerFn)CustomThreadMain, local_param.get());
       if (!custom_thread_handle) {
         ORT_THROW("custom_create_thread_fn returned invalid handle.");
       }
+      local_param.release();
     } else {
-      hThread.reset(reinterpret_cast(_beginthreadex(nullptr, thread_options.stack_size, ThreadMain,
-                                                     local_param.release(), 0,
-                                                     &threadID)));
+      _set_errno(0);
+      _set_doserrno(0);
+      auto th_handle = _beginthreadex(nullptr, thread_options.stack_size, ThreadMain,
+                                      local_param.get(), 0,
+                                      &threadID);
+      if (th_handle == 0) {
+        auto err = errno;
+        auto dos_error = _doserrno;
+        char message_buf[256];
+        strerror_s(message_buf, sizeof(message_buf), err);
+        ORT_THROW("WindowsThread: _beginthreadex failed with message: ", message_buf, " doserrno: ", dos_error);
+      }
+      local_param.release();
+      hThread.reset(reinterpret_cast(th_handle));
+      // Do not throw beyond this point, otherwise we lose the thread handle and can no longer join the thread.
     }
   }
@@ -112,10 +136,7 @@
 #pragma warning(push)
 #pragma warning(disable : 6387)
   static unsigned __stdcall ThreadMain(void* param) {
-    std::unique_ptr p((Param*)param);
-    // TODO: should I try to use SetThreadSelectedCpuSets?
-    if (!p->thread_options.affinity.empty())
-      SetThreadAffinityMask(GetCurrentThread(), p->thread_options.affinity[p->index]);
+    std::unique_ptr p(static_cast(param));
 #if WINVER >= _WIN32_WINNT_WIN10
     constexpr SetThreadDescriptionFunc pSetThrDesc = SetThreadDescription;
 #elif WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP)
@@ -137,9 +158,22 @@
     }
     unsigned ret = 0;
     ORT_TRY {
+      // TODO: should I try to use SetThreadSelectedCpuSets?
+      if (p->affinity_mask.has_value()) {
+        auto rc = SetThreadAffinityMask(GetCurrentThread(), *p->affinity_mask);
+        if (!rc) {
+          const auto error_code = GetLastError();
+          LOGS_DEFAULT(ERROR) << "SetThreadAffinityMask failed for thread: " << GetCurrentThreadId()
+                              << ", mask: " << *p->affinity_mask
+                              << ", error code: " << error_code
+                              << ", error msg: " << std::system_category().message(error_code)
+                              << ". Specify the number of threads explicitly so the affinity is not set.";
+        }
+      }
+
       ret = p->start_address(p->index, p->param);
     }
-    ORT_CATCH(const std::exception&) {
+    ORT_CATCH(...) {
       p->param->Cancel();
       ret = 1;
     }
@@ -148,11 +182,11 @@
 #pragma warning(pop)
 
   static void __stdcall CustomThreadMain(void* param) {
-    std::unique_ptr p((Param*)param);
+    std::unique_ptr p(static_cast(param));
     ORT_TRY {
       p->start_address(p->index, p->param);
     }
-    ORT_CATCH(const std::exception&) {
+    ORT_CATCH(...) {
       p->param->Cancel();
     }
   }
@@ -222,7 +256,7 @@ class WindowsEnv : public Env {
         ret.push_back(buffer[i].ProcessorMask);
       }
     }
-    if (ret.empty()){
+    if (ret.empty()) {
      return generate_vector_of_n(std::thread::hardware_concurrency());
    }
    return ret;
@@ -363,9 +397,9 @@ class WindowsEnv : public Env {
     if (file_handle.get() == INVALID_HANDLE_VALUE) {
       const auto error_code = GetLastError();
       return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
-                            "open file ", ToUTF8String(Basename(file_path)),
-                            " fail, errcode = ", error_code,
-                            " - ", std::system_category().message(error_code));
+                             "open file ", ToUTF8String(Basename(file_path)),
+                             " fail, errcode = ", error_code,
+                             " - ", std::system_category().message(error_code));
     }
 
 #if NTDDI_VERSION >= NTDDI_WIN10_RS5 && WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP | WINAPI_PARTITION_SYSTEM)
@@ -391,9 +425,9 @@ class WindowsEnv : public Env {
     if (file_mapping_handle.get() == INVALID_HANDLE_VALUE) {
       const auto error_code = GetLastError();
       return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
-                            "open file mapping ", ToUTF8String(Basename(file_path)),
-                            " fail, errcode = ", error_code,
-                            " - ", std::system_category().message(error_code));
+                             "open file mapping ", ToUTF8String(Basename(file_path)),
+                             " fail, errcode = ", error_code,
+                             " - ", std::system_category().message(error_code));
     }
 
     SYSTEM_INFO sysinfo;
@@ -407,11 +441,11 @@ class WindowsEnv : public Env {
     if (mapped_offset % allocation_granularity != 0) {
       const auto error_code = GetLastError();
       return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
-                            "mapped offset must be a multiple of the allocation granularity",
-                            " , mapped_offset = ", mapped_offset,
-                            " , allocation_granularity = ", allocation_granularity,
-                            " , errcode = ", error_code,
-                            " - ", std::system_category().message(error_code));
+                             "mapped offset must be a multiple of the allocation granularity",
+                             " , mapped_offset = ", mapped_offset,
+                             " , allocation_granularity = ", allocation_granularity,
+                             " , errcode = ", error_code,
+                             " - ", std::system_category().message(error_code));
     }
 
     void* const mapped_base = MapViewOfFile(file_mapping_handle.get(),
@@ -650,7 +684,7 @@ class WindowsEnv : public Env {
     static constexpr DWORD bufferLength = 64 * 1024;
     std::wstring s(bufferLength, '\0');
     FormatMessageW(
-        FORMAT_MESSAGE_FROM_SYSTEM |
+        FORMAT_MESSAGE_FROM_SYSTEM |
             FORMAT_MESSAGE_IGNORE_INSERTS,
         NULL,
         error_code,
@@ -682,7 +716,7 @@ class WindowsEnv : public Env {
     static constexpr DWORD bufferLength = 64 * 1024;
     std::wstring s(bufferLength, '\0');
     FormatMessageW(
-        FORMAT_MESSAGE_FROM_SYSTEM |
+        FORMAT_MESSAGE_FROM_SYSTEM |
            FORMAT_MESSAGE_IGNORE_INSERTS,
        NULL,
        error_code,
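The Windows side follows the same two ideas: check `_beginthreadex` for a zero return (the failure reason is reported through `errno`/`_doserrno`), and apply the affinity mask from inside the new thread with `SetThreadAffinityMask`, logging rather than throwing on failure. A standalone sketch, not ORT code; the `Worker` function and the single-CPU mask are illustrative only.

```cpp
#include <windows.h>
#include <process.h>
#include <cerrno>
#include <cstdio>

static unsigned __stdcall Worker(void* arg) {
  DWORD_PTR mask = *static_cast<DWORD_PTR*>(arg);
  // Returns the previous mask, or 0 on failure.
  if (SetThreadAffinityMask(GetCurrentThread(), mask) == 0) {
    // Failure is logged, not fatal, mirroring the PR's behavior.
    std::fprintf(stderr, "SetThreadAffinityMask failed: %lu\n", GetLastError());
  }
  return 0;
}

int main() {
  DWORD_PTR mask = 1;  // pin to CPU 0
  unsigned tid = 0;
  uintptr_t handle = _beginthreadex(nullptr, 0, Worker, &mask, 0, &tid);
  if (handle == 0) {  // _beginthreadex reports failure via a 0 return and errno
    std::fprintf(stderr, "_beginthreadex failed, errno=%d\n", errno);
    return 1;
  }
  WaitForSingleObject(reinterpret_cast<HANDLE>(handle), INFINITE);
  CloseHandle(reinterpret_cast<HANDLE>(handle));
}
```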
diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc
index 8e2983dd40d8..fd0d15640f47 100644
--- a/onnxruntime/core/providers/cuda/nn/conv.cc
+++ b/onnxruntime/core/providers/cuda/nn/conv.cc
@@ -87,7 +87,7 @@ Status SliceOutUnwantedOutputSection(cudaStream_t stream,
 }
 
 template
-Status Conv::UpdateState(OpKernelContext* context) const {
+Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) const {
   //set X
   const Tensor* X = context->Input(0);
   const TensorShape& x_shape = X->Shape();
@@ -109,7 +109,8 @@ Status Conv::UpdateState(OpKernelContext* context) const {
   //set Z
   if (context->InputCount() >= 4) {
     const Tensor* Z = context->Input(3);
-    s_.z_data = reinterpret_cast(Z->template Data());
+    ORT_RETURN_IF_ERROR(s_.z_tensor.Set(Z->Shape().GetDims(), CudnnTensor::GetDataType()));
+    s_.z_data = reinterpret_cast(Z->Data());
   } else {
     s_.z_data = nullptr;
   }
@@ -236,43 +237,22 @@ Status Conv::UpdateState(OpKernelContext* context) const {
     if (context->InputCount() >= 3) {
       const Tensor* B = context->Input(2);
       const auto& b_shape = B->Shape();
-      if (b_shape.NumDimensions() == 1) {
-        TensorShapeVector b_dims(2 + kernel_shape.size(), 1);
-        b_dims[1] = b_shape[0];
-        ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType()));
-      } else {
-        const auto& y_rank = y_dims_cudnn.size();
-        const auto& b_rank = b_shape.GetDims().size();
-        ORT_RETURN_IF_NOT(b_rank <= y_rank, "rank of B is ", b_rank, ", which is bigger than the rank of Y - ", y_rank);
-        if (b_rank == y_rank) {
-          ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_shape.GetDims(), CudnnTensor::GetDataType()));
-        } else {
-          TensorShapeVector b_extended_dims = b_shape.AsShapeVector();
-          for (auto i = b_rank; i < y_rank; ++i) {
-            ORT_RETURN_IF_NOT(y_dims_cudnn[i] == 1, "dim ", i, " of Y is ", y_dims_cudnn[i], ", cannot apply it to that dim of B");
-            b_extended_dims.push_back(1);
-          }
-          ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_extended_dims, CudnnTensor::GetDataType()));
-        }
-      }
-    }
-
-    if (context->InputCount() >= 4) {
-      const Tensor* Z = context->Input(3);
-      const auto& z_shape = Z->Shape();
-      const auto& z_rank = z_shape.GetDims().size();
-      const auto& y_rank = y_dims_cudnn.size();
-      ORT_RETURN_IF_NOT(z_rank <= y_rank, "rank of Z is ", z_rank, ", which is bigger than the rank of Y - ", y_rank);
-      if (z_rank == y_rank) {
-        ORT_RETURN_IF_ERROR(s_.z_tensor.Set(z_shape.GetDims(), CudnnTensor::GetDataType()));
-      } else {
-        TensorShapeVector z_extended_dims = z_shape.AsShapeVector();
-        for (auto i = z_rank; i < y_rank; ++i) {
-          ORT_RETURN_IF_NOT(y_dims_cudnn[i] == 1, "dim ", i, " of Y is ", y_dims_cudnn[i], ", cannot apply it to that dim of Z");
-          z_extended_dims.push_back(1);
-        }
-        ORT_RETURN_IF_ERROR(s_.z_tensor.Set(z_extended_dims, CudnnTensor::GetDataType()));
+      ORT_RETURN_IF_NOT(b_shape.NumDimensions() == 1, "bias should be 1D");
+      TensorShapeVector b_dims(2 + kernel_shape.size(), 1);
+      b_dims[1] = b_shape[0];
+      ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType()));
+      //s_.b_data = reinterpret_cast(B->Data());
+    } else if (bias_expected) {
+      TensorShapeVector b_dims(2 + kernel_shape.size(), 1);
+      b_dims[1] = w_dims[0];
+      auto malloc_size = b_dims[1] * sizeof(CudaT);
+      ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType()));
+      if (s_.b_zero) {
+        CUDA_CALL_THROW(cudaFree(s_.b_zero));
+        s_.b_zero = nullptr;
       }
+      CUDA_CALL_THROW(cudaMalloc(&s_.b_zero, malloc_size));
+      CUDA_CALL_THROW(cudaMemsetAsync(s_.b_zero, 0, malloc_size, Stream()));
     }
 
     if (!s_.cached_benchmark_results.contains(x_dims_cudnn)) {
diff --git a/onnxruntime/core/providers/cuda/nn/conv.h b/onnxruntime/core/providers/cuda/nn/conv.h
index 822785255246..135b189d4b2a 100644
--- a/onnxruntime/core/providers/cuda/nn/conv.h
+++ b/onnxruntime/core/providers/cuda/nn/conv.h
@@ -141,6 +141,7 @@ struct CudnnConvState {
   const void* w_data = nullptr;
   CudnnTensor b_tensor;
   const void* b_data = nullptr;
+  void* b_zero = nullptr;
   CudnnTensor y_tensor;
   Tensor* Y = nullptr;
   void* y_data = nullptr;
@@ -165,6 +166,13 @@ struct CudnnConvState {
   // note that conv objects are shared between execution frames, and a lock is needed to avoid multi-thread racing
   OrtMutex mutex;
   IAllocatorUniquePtr memory_for_cudnn_conv_results;
+
+  ~CudnnConvState() {
+    if (b_zero) {
+      CUDA_CALL_THROW(cudaFree(b_zero));
+      b_zero = nullptr;
+    }
+  }
 };
 
 enum : size_t {
@@ -189,7 +197,7 @@ class Conv : public CudaKernel {
     return GetScratchBuffer(s_.workspace_bytes);
   }
 
-  Status UpdateState(OpKernelContext* context) const;
+  Status UpdateState(OpKernelContext* context, bool bias_expected = false) const;
   ConvAttributes conv_attrs_;
   mutable CudnnConvState s_;
   constexpr static auto kDefaultConvAlgo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
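The `b_zero` member added above exists because the fused cuDNN call always needs a bias tensor: when the node has no bias input but `bias_expected` is set, the kernel now keeps a zero-filled device buffer and frees it when the state is destroyed. A standalone sketch of that allocate-zero-free pattern with the CUDA runtime API (the `ZeroBias`/`Ensure` names are mine, not ORT's):

```cpp
#include <cuda_runtime.h>
#include <cstdio>

struct ZeroBias {
  void* data = nullptr;
  size_t bytes = 0;

  cudaError_t Ensure(size_t required, cudaStream_t stream) {
    if (data && bytes >= required) return cudaSuccess;   // reuse the existing buffer
    if (data) cudaFree(data);
    bytes = required;
    cudaError_t err = cudaMalloc(&data, bytes);
    if (err != cudaSuccess) return err;
    return cudaMemsetAsync(data, 0, bytes, stream);      // zero it on the compute stream
  }

  ~ZeroBias() {
    if (data) cudaFree(data);                            // released on teardown, like ~CudnnConvState
  }
};

int main() {
  ZeroBias b;
  cudaStream_t stream = nullptr;                          // default stream for this demo
  if (b.Ensure(64 * sizeof(float), stream) != cudaSuccess) {
    std::puts("allocation failed");
  }
  cudaStreamSynchronize(stream);
}
```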
diff --git a/onnxruntime/python/tools/quantization/calibrate.py b/onnxruntime/python/tools/quantization/calibrate.py
index 42aeeaef0ef4..51497e90581a 100644
--- a/onnxruntime/python/tools/quantization/calibrate.py
+++ b/onnxruntime/python/tools/quantization/calibrate.py
@@ -817,7 +817,12 @@ def get_entropy_threshold(self, histogram, num_quantized_bins):
 
         min_kl_divergence_idx = np.argmin(kl_divergence)
         optimal_threshold = thresholds[min_kl_divergence_idx]
-
+        min_value = histogram[2]
+        max_value = histogram[3]
+        if optimal_threshold[0] < min_value:
+            optimal_threshold = (min_value, optimal_threshold[1])
+        if optimal_threshold[1] > max_value:
+            optimal_threshold = (optimal_threshold[0], max_value)
         return optimal_threshold
 
 
diff --git a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc
index 8cb79461b902..890bccd4f18c 100644
--- a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc
+++ b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc
@@ -3,9 +3,6 @@
 
 #include "gtest/gtest.h"
 #include "test/providers/provider_test_utils.h"
-#include "core/session/inference_session.h"
-#include "test/framework/test_utils.h"
-
 using namespace std;
 namespace onnxruntime {
 namespace test {
@@ -728,141 +725,5 @@ TEST(ConvTest, Conv_AutoPad_with_non_default_strides) {
   TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true);
 }
 
-#ifdef USE_CUDA
-TEST(ConvTest, Fuse_Conv_Bias) {
-  auto model_uri = ORT_TSTR("testdata/fuse_conv_bias.onnx");
-  SessionOptions so;
-  InferenceSession session{so, GetEnvironment()};
-  ASSERT_STATUS_OK(session.Load(model_uri));
-  ASSERT_TRUE(session.Initialize().IsOK());
-
-  NameMLValMap feeds;
-  OrtValue ml_value;
-
-  size_t X_count = 1 * 3 * 32 * 32;
-  std::vector X_data(X_count, 1.f);
-  std::vector X_shape{1, 3, 32, 32};
-
-  size_t W_count = 1 * 3 * 5 * 32;
-  std::vector W_data(W_count, 2.f);
-  std::vector W_shape{1, 3, 5, 32};
-
-  size_t B_count = 1;
-  std::vector B_data(B_count, 5.f);
-  std::vector B_shape{1};
-
-  size_t Z_count = 1 * 1 * 28;
-  std::vector Z_data(Z_count, 1.f);
-  std::vector Z_shape{1, 1, 28};
-
-  CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), X_shape, X_data, &ml_value);
-  feeds.insert(std::make_pair("X", ml_value));
-
-  CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), W_shape, W_data, &ml_value);
-  feeds.insert(std::make_pair("W", ml_value));
-
-  CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), B_shape, B_data, &ml_value);
-  feeds.insert(std::make_pair("B", ml_value));
-
-  CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), Z_shape, Z_data, &ml_value);
-  feeds.insert(std::make_pair("Z", ml_value));
-
-  std::vector output_names{"R"};
-  std::vector fetches;
-
-  onnxruntime::RunOptions run_options;
-  auto st = session.Run(run_options, feeds, output_names, &fetches);
-  ASSERT_TRUE(st.IsOK()) << st;
-  ASSERT_EQ(1u, fetches.size());
-}
-
-TEST(ConvTest, Fuse_Conv_Bias_Slice) {
-  auto model_uri = ORT_TSTR("testdata/fuse_conv_bias_slice.onnx");
-  SessionOptions so;
-  InferenceSession session{so, GetEnvironment()};
-  ASSERT_STATUS_OK(session.Load(model_uri));
-  ASSERT_TRUE(session.Initialize().IsOK());
-
-  NameMLValMap feeds;
-  OrtValue ml_value;
-
-  size_t X_count = 1 * 2 * 6 * 6;
-  std::vector X_data(X_count, 1.f);
-  std::vector X_shape{1, 2, 6, 6};
-
-  size_t W_count = 1 * 2 * 4 * 4;
-  std::vector W_data(W_count, 2.f);
-  std::vector W_shape{1, 2, 4, 4};
-
-  size_t B_count = 1;
-  std::vector B_data(B_count, 5.f);
-  std::vector B_shape{1};
-
-  size_t Z_count = 1 * 1 * 4 * 2;
-  std::vector Z_data(Z_count, 1.f);
-  std::vector Z_shape{1, 1, 4, 2};
-
-  CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), X_shape, X_data, &ml_value);
-  feeds.insert(std::make_pair("X", ml_value));
-
-  CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), W_shape, W_data, &ml_value);
-  feeds.insert(std::make_pair("W", ml_value));
-
-  CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), B_shape, B_data, &ml_value);
-  feeds.insert(std::make_pair("B", ml_value));
-
-  CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), Z_shape, Z_data, &ml_value);
-  feeds.insert(std::make_pair("Z", ml_value));
-
-  std::vector output_names{"R"};
-  std::vector fetches;
-
-  onnxruntime::RunOptions run_options;
-  auto st = session.Run(run_options, feeds, output_names, &fetches);
-  ASSERT_TRUE(st.IsOK()) << st;
-  ASSERT_EQ(1u, fetches.size());
-}
-
-TEST(ConvTest, Fuse_Conv_No_Bias) {
-  auto model_uri = ORT_TSTR("testdata/fuse_conv_no_bias.onnx");
-  SessionOptions so;
-  InferenceSession session{so, GetEnvironment()};
-  ASSERT_STATUS_OK(session.Load(model_uri));
-  ASSERT_TRUE(session.Initialize().IsOK());
-
-  NameMLValMap feeds;
-  OrtValue ml_value;
-
-  size_t X_count = 1 * 3 * 32 * 32;
-  std::vector X_data(X_count, 1.f);
-  std::vector X_shape{1, 3, 32, 32};
-
-  size_t W_count = 1 * 3 * 5 * 32;
-  std::vector W_data(W_count, 2.f);
-  std::vector W_shape{1, 3, 5, 32};
-
-  size_t Z_count = 1 * 1 * 28;
-  std::vector Z_data(Z_count, 1.f);
-  std::vector Z_shape{1, 1, 28};
-
-  CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), X_shape, X_data, &ml_value);
-  feeds.insert(std::make_pair("X", ml_value));
-
-  CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), W_shape, W_data, &ml_value);
-  feeds.insert(std::make_pair("W", ml_value));
-
-  CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), Z_shape, Z_data, &ml_value);
-  feeds.insert(std::make_pair("Z", ml_value));
-
-  std::vector output_names{"R"};
-  std::vector fetches;
-
-  onnxruntime::RunOptions run_options;
-  auto st = session.Run(run_options, feeds, output_names, &fetches);
-  ASSERT_TRUE(st.IsOK()) << st;
-  ASSERT_EQ(1u, fetches.size());
-}
-#endif
-
 }  // namespace test
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/testdata/fuse_conv_bias.onnx b/onnxruntime/test/testdata/fuse_conv_bias.onnx
deleted file mode 100644
index e6cf15c34c1e..000000000000
--- a/onnxruntime/test/testdata/fuse_conv_bias.onnx
+++ /dev/null
@@ -1,37 +0,0 @@
 (deleted binary ONNX test model: Conv -> Add -> Relu graph with inputs X, W, B, Z and output R; raw protobuf bytes omitted)
diff --git a/onnxruntime/test/testdata/fuse_conv_bias_slice.onnx b/onnxruntime/test/testdata/fuse_conv_bias_slice.onnx
deleted file mode 100644
index c55e611c9618..000000000000
--- a/onnxruntime/test/testdata/fuse_conv_bias_slice.onnx
+++ /dev/null
@@ -1,40 +0,0 @@
 (deleted binary ONNX test model: Conv with pads/strides attributes -> Add -> Relu graph with inputs X, W, B, Z and output R; raw protobuf bytes omitted)
diff --git a/onnxruntime/test/testdata/fuse_conv_no_bias.onnx b/onnxruntime/test/testdata/fuse_conv_no_bias.onnx
deleted file mode 100644
index 2094d7532673..000000000000
--- a/onnxruntime/test/testdata/fuse_conv_no_bias.onnx
+++ /dev/null
@@ -1,32 +0,0 @@
 (deleted binary ONNX test model: Conv -> Add -> Relu graph with inputs X, W, Z and output R; raw protobuf bytes omitted)