Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cjian/rel 1.13.1 cherry pick round 1 #13372

Merged
merged 3 commits into from
Oct 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
38 changes: 26 additions & 12 deletions include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,18 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter

ThreadPoolProfiler profiler_;

  // Shuts the pool down: signals all workers to exit and joins them.
  // Shared by the destructor and by the constructor's exception path, so a
  // partially-constructed pool can be torn down the same way as a live one.
  // NOTE(review): assumes done_ is an atomic/flag read by WorkerLoop and that
  // WakeAllWorkersForExit() wakes blocked workers — confirm in the enclosing class.
  void SignalAllAndWait() {
    done_ = true;

    // Now if all threads block without work, they will start exiting.
    // But note that threads can continue to work arbitrarily long —
    // block, submit new work, unblock and otherwise live a full life —
    // before they observe done_ and exit.
    WakeAllWorkersForExit();
    // Join threads explicitly (by destroying each thread object) rather than
    // relying on member destruction order within this class.
    for (size_t i = 0; i < worker_data_.size(); ++i) worker_data_[i].thread.reset();
  }

public:
void StartProfiling() override {
profiler_.Start();
Expand Down Expand Up @@ -750,22 +762,24 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter
ComputeCoprimes(i, &all_coprimes_.back());
}

worker_data_.resize(num_threads_);
for (auto i = 0u; i < num_threads_; i++) {
worker_data_[i].thread.reset(env_.CreateThread(name, i, WorkerLoop, this, thread_options));
// Eigen::MaxSizeVector has neither essential exception safety features
// such as swap, nor it is movable. So we have to join threads right here
// on exception
ORT_TRY {
worker_data_.resize(num_threads_);
for (auto i = 0u; i < num_threads_; i++) {
worker_data_[i].thread.reset(env_.CreateThread(name, i, WorkerLoop, this, thread_options));
}
} ORT_CATCH(...) {
ORT_HANDLE_EXCEPTION([&]() {
SignalAllAndWait();
throw;
});
}
}

~ThreadPoolTempl() override {
done_ = true;

// Now if all threads block without work, they will start exiting.
// But note that threads can continue to work arbitrary long,
// block, submit new work, unblock and otherwise live full life.
WakeAllWorkersForExit();
// Join threads explicitly (by destroying) to avoid destruction order within
// this class.
for (size_t i = 0; i < worker_data_.size(); ++i) worker_data_[i].thread.reset();
SignalAllAndWait();
}

// Run fn(). Ordinarily, the function will be added to the thread pool and executed
Expand Down
87 changes: 34 additions & 53 deletions onnxruntime/contrib_ops/cuda/fused_conv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class FusedConv : public onnxruntime::cuda::Conv<T> {
Status ComputeInternal(OpKernelContext* context) const override {
CUDNN_RETURN_IF_ERROR(status_);
std::lock_guard<OrtMutex> lock(Base::s_.mutex);
ORT_RETURN_IF_ERROR(Base::UpdateState(context));
ORT_RETURN_IF_ERROR(Base::UpdateState(context, true));
if (Base::s_.Y->Shape().Size() == 0) {
return Status::OK();
}
Expand All @@ -47,27 +47,25 @@ class FusedConv : public onnxruntime::cuda::Conv<T> {
const auto alpha = onnxruntime::cuda::Consts<CudaT>::One;
const auto beta = onnxruntime::cuda::Consts<CudaT>::Zero;
IAllocatorUniquePtr<void> workspace = Base::GetWorkSpace();

if (has_b && has_z && !Base::s_.post_slicing_required) {
CUDNN_RETURN_IF_ERROR(cudnnConvolutionBiasActivationForward(Base::CudnnHandle(),
&alpha,
Base::s_.x_tensor,
Base::s_.x_data,
Base::s_.w_desc,
Base::s_.w_data,
Base::s_.conv_desc,
Base::s_.algo,
workspace.get(),
Base::s_.workspace_bytes,
&alpha,
Base::s_.z_tensor,
Base::s_.z_data,
Base::s_.b_tensor,
Base::s_.b_data,
activation_desc_,
Base::s_.y_tensor,
Base::s_.y_data));
} else {
auto cudnn_status = cudnnConvolutionBiasActivationForward(Base::CudnnHandle(),
&alpha,
Base::s_.x_tensor,
Base::s_.x_data,
Base::s_.w_desc,
Base::s_.w_data,
Base::s_.conv_desc,
Base::s_.algo,
workspace.get(),
Base::s_.workspace_bytes,
has_z ? &alpha : &beta,
has_z ? Base::s_.z_tensor : Base::s_.y_tensor,
has_z ? Base::s_.z_data : Base::s_.y_data,
Base::s_.b_tensor,
has_b ? Base::s_.b_data : Base::s_.b_zero,
activation_desc_,
Base::s_.y_tensor,
Base::s_.y_data);
if (CUDNN_STATUS_SUCCESS != cudnn_status) {
CUDNN_RETURN_IF_ERROR(cudnnConvolutionForward(Base::CudnnHandle(),
&alpha,
Base::s_.x_tensor,
Expand All @@ -81,38 +79,21 @@ class FusedConv : public onnxruntime::cuda::Conv<T> {
&beta,
Base::s_.y_tensor,
Base::s_.y_data));

if (Base::s_.post_slicing_required) {
ORT_RETURN_IF_ERROR(onnxruntime::cuda::SliceOutUnwantedOutputSection(
this->Stream(), Base::s_.y_data, Base::s_.y_dims_with_adjusted_pads, Base::s_.Y->MutableDataRaw(),
Base::s_.y_dims.GetDims(), Base::s_.slice_starts, Base::s_.slice_ends, Base::s_.slice_axes, Base::s_.element_size));

onnxruntime::cuda::CudnnTensor sliced_y_tensor;
ORT_RETURN_IF_ERROR(sliced_y_tensor.Set(Base::s_.y_dims.GetDims(), onnxruntime::cuda::CudnnTensor::GetDataType<CudaT>()));

if (has_b) {
CUDNN_RETURN_IF_ERROR(cudnnAddTensor(Base::CudnnHandle(), &alpha, Base::s_.b_tensor, Base::s_.b_data,
&alpha, sliced_y_tensor, Base::s_.Y->MutableDataRaw()));
}
if (has_z) {
CUDNN_RETURN_IF_ERROR(cudnnAddTensor(Base::CudnnHandle(), &alpha, Base::s_.z_tensor, Base::s_.z_data,
&alpha, sliced_y_tensor, Base::s_.Y->MutableDataRaw()));
}

CUDNN_RETURN_IF_ERROR(cudnnActivationForward(Base::CudnnHandle(), activation_desc_, &alpha, sliced_y_tensor,
Base::s_.y_data, &beta, sliced_y_tensor, Base::s_.y_data));
} else {
if (has_b) {
CUDNN_RETURN_IF_ERROR(cudnnAddTensor(Base::CudnnHandle(), &alpha, Base::s_.b_tensor, Base::s_.b_data,
&alpha, Base::s_.y_tensor, Base::s_.y_data));
}
if (has_z) {
CUDNN_RETURN_IF_ERROR(cudnnAddTensor(Base::CudnnHandle(), &alpha, Base::s_.z_tensor, Base::s_.z_data,
&alpha, Base::s_.y_tensor, Base::s_.y_data));
}
CUDNN_RETURN_IF_ERROR(cudnnActivationForward(Base::CudnnHandle(), activation_desc_, &alpha, Base::s_.y_tensor,
Base::s_.y_data, &beta, Base::s_.y_tensor, Base::s_.y_data));
if (has_b) {
CUDNN_RETURN_IF_ERROR(cudnnAddTensor(Base::CudnnHandle(), &alpha, Base::s_.b_tensor, Base::s_.b_data,
&alpha, Base::s_.y_tensor, Base::s_.y_data));
}
if (has_z) {
CUDNN_RETURN_IF_ERROR(cudnnAddTensor(Base::CudnnHandle(), &alpha, Base::s_.z_tensor, Base::s_.z_data,
&alpha, Base::s_.y_tensor, Base::s_.y_data));
}
CUDNN_RETURN_IF_ERROR(cudnnActivationForward(Base::CudnnHandle(), activation_desc_, &alpha, Base::s_.y_tensor,
Base::s_.y_data, &beta, Base::s_.y_tensor, Base::s_.y_data));
}
if (Base::s_.post_slicing_required) {
ORT_RETURN_IF_ERROR(onnxruntime::cuda::SliceOutUnwantedOutputSection(
this->Stream(), Base::s_.y_data, Base::s_.y_dims_with_adjusted_pads, Base::s_.Y->MutableDataRaw(),
Base::s_.y_dims.GetDims(), Base::s_.slice_starts, Base::s_.slice_ends, Base::s_.slice_axes, Base::s_.element_size));
}
return Status::OK();
}
Expand Down
74 changes: 52 additions & 22 deletions onnxruntime/core/platform/posix/env.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,15 @@ limitations under the License.
#include <fcntl.h>
#include <dlfcn.h>
#include <ftw.h>
#include <optional>
#include <string.h>
#include <thread>
#include <utility> // for std::forward
#include <vector>
#include <assert.h>

#include <gsl/gsl>

#include "core/common/common.h"
#include "core/common/logging/logging.h"
#include "core/platform/scoped_resource.h"
Expand All @@ -54,8 +57,7 @@ class UnmapFileParam {
*
* @return a pair of {errno, error message}
*/
static std::pair<int, std::string> GetSystemError() {
auto e = errno;
static std::pair<int, std::string> GetSystemError(int e) {
char buf[1024];
const char* msg = "";
if (e > 0) {
Expand All @@ -73,6 +75,11 @@ static std::pair<int, std::string> GetSystemError() {
return std::make_pair(e, msg);
}

/**
 * Convenience overload: reports the error described by the calling thread's
 * current `errno` value.
 *
 * @return a pair of {errno, error message}
 */
static std::pair<int, std::string> GetSystemError() {
  // Delegate to the explicit-errno overload so both share one formatting path.
  return GetSystemError(errno);
}

static void UnmapFile(void* param) noexcept {
std::unique_ptr<UnmapFileParam> p(reinterpret_cast<UnmapFileParam*>(param));
int ret = munmap(p->addr, p->len);
Expand Down Expand Up @@ -128,29 +135,46 @@ struct Freer {

using MallocdStringPtr = std::unique_ptr<char, Freer<char> >;


class PosixThread : public EnvThread {
private:
struct Param {
const ORTCHAR_T* name_prefix;
int index;
unsigned (*start_address)(int id, Eigen::ThreadPoolInterface* param);
Eigen::ThreadPoolInterface* param;
const ThreadOptions& thread_options;
std::optional<size_t> affinity_mask;

Param(const ORTCHAR_T* name_prefix1,
int index1,
unsigned (*start_address1)(int id, Eigen::ThreadPoolInterface* param),
Eigen::ThreadPoolInterface* param1)
: name_prefix(name_prefix1),
index(index1),
start_address(start_address1),
param(param1) {}
};

public:
PosixThread(const ORTCHAR_T* name_prefix, int index,
unsigned (*start_address)(int id, Eigen::ThreadPoolInterface* param), Eigen::ThreadPoolInterface* param,
const ThreadOptions& thread_options) {
ORT_ENFORCE(index >= 0, "Negative thread index is not allowed");
custom_create_thread_fn = thread_options.custom_create_thread_fn;
custom_thread_creation_options = thread_options.custom_thread_creation_options;
custom_join_thread_fn = thread_options.custom_join_thread_fn;

auto param_ptr = std::make_unique<Param>(name_prefix, index, start_address, param);
if (gsl::narrow<size_t>(index) < thread_options.affinity.size()) {
param_ptr->affinity_mask = thread_options.affinity[index];
}

if (custom_create_thread_fn) {
custom_thread_handle = custom_create_thread_fn(custom_thread_creation_options, CustomThreadMain, new Param{name_prefix, index, start_address, param, thread_options});
custom_thread_handle = custom_create_thread_fn(custom_thread_creation_options, CustomThreadMain, param_ptr.get());
if (!custom_thread_handle) {
ORT_THROW("custom_create_thread_fn returned invalid handle.");
}
param_ptr.release();
} else {
pthread_attr_t attr;
int s = pthread_attr_init(&attr);
Expand All @@ -165,24 +189,14 @@ class PosixThread : public EnvThread {
ORT_THROW("pthread_attr_setstacksize failed, error code: ", err_no, " error msg: ", err_msg);
}
}
s = pthread_create(&hThread, &attr, ThreadMain,
new Param{name_prefix, index, start_address, param, thread_options});

s = pthread_create(&hThread, &attr, ThreadMain, param_ptr.get());
if (s != 0) {
auto [err_no, err_msg] = GetSystemError();
ORT_THROW("pthread_create failed, error code: ", err_no, " error msg: ", err_msg);
}
#if !defined(__APPLE__) && !defined(__ANDROID__) && !defined(__wasm__) && !defined(_AIX)
if (!thread_options.affinity.empty()) {
cpu_set_t cpuset;
CPU_ZERO(&cpuset);
CPU_SET(thread_options.affinity[index], &cpuset);
s = pthread_setaffinity_np(hThread, sizeof(cpu_set_t), &cpuset);
if (s != 0) {
auto [err_no, err_msg] = GetSystemError();
ORT_THROW("pthread_setaffinity_np failed, error code: ", err_no, " error msg: ", err_msg);
}
}
#endif
param_ptr.release();
// Do not throw beyond this point so we do not lose thread handle and then not being able to join it.
}
}

Expand All @@ -203,13 +217,29 @@ class PosixThread : public EnvThread {

private:
static void* ThreadMain(void* param) {
std::unique_ptr<Param> p((Param*)param);
std::unique_ptr<Param> p(static_cast<Param*>(param));
ORT_TRY {
#if !defined(__APPLE__) && !defined(__ANDROID__) && !defined(__wasm__) && !defined(_AIX)
if (p->affinity_mask.has_value()) {
cpu_set_t cpuset;
CPU_ZERO(&cpuset);
CPU_SET(*p->affinity_mask, &cpuset);
// pthread_setaffinity_np() does not set errno, it returns it.
auto ret = pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpuset);
if (ret != 0) {
auto [err_no, err_msg] = GetSystemError(ret);
LOGS_DEFAULT(ERROR) << "pthread_setaffinity_np failed for thread: " << pthread_self()
<< ", mask: " << *p->affinity_mask
<< ", error code: " << err_no << " error msg: " << err_msg
<< ". Specify the number of threads explicitly so the affinity is not set.";
}
}
#endif
// Ignore the returned value for now
p->start_address(p->index, p->param);
}
ORT_CATCH(const std::exception&) {
//ignore any exceptions
ORT_CATCH(...) {
// Ignore exceptions
}
return nullptr;
}
Expand Down Expand Up @@ -440,7 +470,7 @@ class PosixEnv : public Env {
common::Status GetCanonicalPath(
const PathString& path,
PathString& canonical_path) const override {
MallocdStringPtr canonical_path_cstr{realpath(path.c_str(), nullptr)};
MallocdStringPtr canonical_path_cstr{realpath(path.c_str(), nullptr), Freer<char>()};
if (!canonical_path_cstr) {
return ReportSystemError("realpath", path);
}
Expand Down