[Bug Fix] Fix global_scatter/global_gather in ProcessGroup (#43027)
* fix alltoall

* rename utest
ForFishes committed May 28, 2022
1 parent 9eb18c7 commit 8cc2e28
Showing 11 changed files with 376 additions and 11 deletions.
13 changes: 13 additions & 0 deletions paddle/fluid/distributed/collective/ProcessGroup.h
@@ -113,6 +113,19 @@ class ProcessGroup {
"ProcessGroup%s does not support receive", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> Send_Partial(phi::DenseTensor&,
int, int,
int) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support send", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> Recv_Partial(
phi::DenseTensor& tensors, int, int, int) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support receive", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&) { // NOLINT
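Unlike Send/Recv, which take whole tensor vectors, the two new virtuals take a single phi::DenseTensor plus an element offset and length into its flattened view. A rough caller-side sketch of how a backend that overrides them (such as ProcessGroupNCCL below) might be driven; it is not part of this commit, and ExchangeRows and its parameters are purely illustrative:

// Illustrative only: exchange one block of rows with `peer` using the new
// partial point-to-point API. Offsets and lengths are element counts into
// the flattened tensor, so row offsets are scaled by the row width.
void ExchangeRows(paddle::distributed::ProcessGroup* pg,
                  phi::DenseTensor& send_buf,  // NOLINT
                  phi::DenseTensor& recv_buf,  // NOLINT
                  int peer, int row_begin, int row_end, int recv_row) {
  const int row_width = static_cast<int>(send_buf.dims()[1]);
  const int length = (row_end - row_begin) * row_width;
  pg->Send_Partial(send_buf, peer, row_begin * row_width, length);
  pg->Recv_Partial(recv_buf, peer, recv_row * row_width, length);
}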
47 changes: 47 additions & 0 deletions paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
@@ -428,6 +428,53 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
return task;
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send_Partial(
phi::DenseTensor& tensors, int dst_rank, int offset, int length) {
// CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));

phi::DenseTensor flatten_tensor;
flatten_tensor.ShareDataWith(tensors).Resize({tensors.numel()});

phi::DenseTensor shared_input = flatten_tensor.Slice(offset, offset + length);

std::vector<phi::DenseTensor> shared_tensors;
shared_tensors.push_back(shared_input);

auto task = PointToPoint(shared_tensors,
[&](phi::DenseTensor& input, ncclComm_t comm,
const gpuStream_t& stream, int dst_rank) {
return platform::dynload::ncclSend(
input.data(), input.numel(),
platform::ToNCCLDataType(input.dtype()),
dst_rank, comm, stream);
},
dst_rank, CommType::SEND);
return task;
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv_Partial(
phi::DenseTensor& tensors, int src_rank, int offset, int length) {
// phi::DenseTensor shared_input = tensors.Slice(offset, offset+length);

phi::DenseTensor flatten_tensor;
flatten_tensor.ShareDataWith(tensors).Resize({tensors.numel()});
phi::DenseTensor shared_input = flatten_tensor.Slice(offset, offset + length);

std::vector<phi::DenseTensor> shared_tensors;
shared_tensors.push_back(shared_input);

auto task = PointToPoint(shared_tensors,
[&](phi::DenseTensor& output, ncclComm_t comm,
const gpuStream_t& stream, int src_rank) {
return platform::dynload::ncclRecv(
output.data(), output.numel(),
platform::ToNCCLDataType(output.dtype()),
src_rank, comm, stream);
},
src_rank, CommType::RECV);
return task;
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) {
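Both implementations use the same addressing idiom: ShareDataWith plus Resize gives a 1-D view of the tensor without copying, and Slice(offset, offset + length) then selects the contiguous element window that ncclSend/ncclRecv operate on. A standalone stand-in for that arithmetic, with no Paddle or NCCL dependency and purely illustrative values:

#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  std::vector<float> tensor(3 * 4);              // a 3x4 tensor, flattened
  std::iota(tensor.begin(), tensor.end(), 0.f);
  const int offset = 4;                          // skip row 0 (4 elements)
  const int length = 8;                          // rows 1 and 2
  const float* window = tensor.data() + offset;  // what the NCCL call would see
  std::printf("send %d elements starting at value %.0f\n", length, window[0]);
  return 0;
}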
8 changes: 8 additions & 0 deletions paddle/fluid/distributed/collective/ProcessGroupNCCL.h
@@ -102,6 +102,14 @@ class ProcessGroupNCCL : public ProcessGroup {
std::shared_ptr<ProcessGroup::Task> Recv(
std::vector<phi::DenseTensor>& tensors, int src_rank) override;

std::shared_ptr<ProcessGroup::Task> Send_Partial(phi::DenseTensor& tensors,
int dst_rank, int offset,
int length) override;

std::shared_ptr<ProcessGroup::Task> Recv_Partial(phi::DenseTensor& tensors,
int src_rank, int offset,
int length) override;

std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) override;
132 changes: 129 additions & 3 deletions paddle/fluid/operators/collective/global_gather_op.cu.cc
@@ -22,10 +22,10 @@ limitations under the License. */

namespace paddle {
namespace operators {

template <typename T>
-class GlobalGatherOpCUDAKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
+struct GlobalGatherFunctor<phi::GPUContext, T> {
+  void operator()(const framework::ExecutionContext& ctx) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#if NCCL_VERSION_CODE >= 2703
auto x = ctx.Input<framework::LoDTensor>("X");
@@ -137,6 +137,132 @@ class GlobalGatherOpCUDAKernel : public framework::OpKernel<T> {
}
};

template <typename T>
struct GlobalGatherProcessGroupFunctor<phi::GPUContext, T> {
void operator()(const framework::ExecutionContext& ctx) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#if NCCL_VERSION_CODE >= 2703
auto x = ctx.Input<framework::LoDTensor>("X");
auto local_count = ctx.Input<framework::LoDTensor>("local_count");
auto global_count = ctx.Input<framework::LoDTensor>("global_count");
auto local_count_type =
framework::TransToProtoVarType(local_count->dtype());
auto global_count_type =
framework::TransToProtoVarType(global_count->dtype());
if (local_count_type != framework::proto::VarType::INT64) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Please use int64 type in local_count."));
}
if (global_count_type != framework::proto::VarType::INT64) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Please use int64 type in global_count."));
}
auto out = ctx.Output<framework::LoDTensor>("Out");
const int64_t* cpu_local_count_data;
const int64_t* cpu_global_count_data;
auto local_count_len = 0;

framework::Tensor cpu_local_count;
if (platform::is_cpu_place(local_count->place())) {
cpu_local_count_data = local_count->data<int64_t>();
local_count_len = local_count->numel();
} else {
framework::TensorCopySync(*local_count, platform::CPUPlace(),
&cpu_local_count);
cpu_local_count_data = cpu_local_count.data<int64_t>();
local_count_len = cpu_local_count.numel();
}

framework::Tensor cpu_global_count;
if (platform::is_cpu_place(global_count->place())) {
cpu_global_count_data = global_count->data<int64_t>();
} else {
framework::TensorCopySync(*global_count, platform::CPUPlace(),
&cpu_global_count);
cpu_global_count_data = cpu_global_count.data<int64_t>();
}

int ring_id = ctx.Attr<int>("ring_id");
PADDLE_ENFORCE_GE(
ring_id, 0,
platform::errors::InvalidArgument(
"The ring_id (%d) for global gather op must be non-negative.",
ring_id));
auto place = ctx.GetPlace();

auto map = distributed::ProcessGroupMapFromGid::getInstance();
distributed::ProcessGroup* pg = map->get(ring_id);

int nranks = pg->GetSize();
auto in_feat = x->dims()[1];
auto n_expert = local_count->dims()[0] / nranks;

auto fwd_count = 0;

for (auto i = 0; i < local_count_len; ++i) {
fwd_count += cpu_local_count_data[i];
}
framework::DDim out_dims = phi::make_ddim({fwd_count, in_feat});
int64_t* expert_ptr = new int64_t[n_expert * nranks];
expert_ptr[0] = 0;
auto tot_experts = n_expert * nranks;
for (auto i = 1; i < tot_experts; ++i) {
expert_ptr[i] = expert_ptr[i - 1] + cpu_local_count_data[i - 1];
}
auto send_ptr = 0;
out->mutable_data<T>(out_dims, place);

for (auto i = 0; i < n_expert; ++i) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
for (auto j = 0; j < nranks; ++j) {
int idx = i + j * n_expert;
if (cpu_global_count_data[idx]) {
phi::DenseTensor tmp = *x;
pg->Send_Partial(tmp, j, send_ptr * in_feat,
cpu_global_count_data[idx] * in_feat);
send_ptr += cpu_global_count_data[idx];
}
if (cpu_local_count_data[idx]) {
pg->Recv_Partial(*out, j, expert_ptr[idx] * in_feat,
cpu_local_count_data[idx] * in_feat);
}
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
}

#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
#else
PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
#endif

#else
PADDLE_THROW(
platform::errors::Unavailable("NCCL version >= 2.7.3 is needed."));
#endif
#else
PADDLE_THROW(
platform::errors::Unavailable("PaddlePaddle should compile with GPU."));
#endif
}
};
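The offsets passed to Send_Partial and Recv_Partial above come from two pieces of bookkeeping: expert_ptr is a prefix sum over local_count, giving the row in out where data for each (expert, rank) index lands, while send_ptr accumulates global_count to walk forward through x on the send side; both row offsets are scaled by in_feat to become element offsets. A toy, standalone trace of that bookkeeping with made-up counts and no Paddle dependency:

#include <cstdio>
#include <vector>

int main() {
  const int n_expert = 2, nranks = 2, in_feat = 4;
  // One count per idx = expert + rank * n_expert, mirroring the kernel above.
  std::vector<long long> local_count{1, 0, 2, 3};   // rows to receive into out
  std::vector<long long> global_count{2, 1, 0, 1};  // rows to send from x

  std::vector<long long> expert_ptr(n_expert * nranks, 0);
  for (int i = 1; i < n_expert * nranks; ++i)
    expert_ptr[i] = expert_ptr[i - 1] + local_count[i - 1];

  long long send_ptr = 0;
  for (int i = 0; i < n_expert; ++i) {
    for (int j = 0; j < nranks; ++j) {
      const int idx = i + j * n_expert;
      if (global_count[idx]) {
        std::printf("send %lld rows to rank %d from element offset %lld\n",
                    global_count[idx], j, send_ptr * in_feat);
        send_ptr += global_count[idx];
      }
      if (local_count[idx]) {
        std::printf("recv %lld rows from rank %d at element offset %lld\n",
                    local_count[idx], j, expert_ptr[idx] * in_feat);
      }
    }
  }
  return 0;
}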

template <typename T>
class GlobalGatherOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const int rid = ctx.Attr<int>("ring_id");
auto map = distributed::ProcessGroupMapFromGid::getInstance();
if (map->has(rid)) {
GlobalGatherProcessGroupFunctor<phi::GPUContext, T> functor_;
functor_(ctx);
} else {
GlobalGatherFunctor<phi::GPUContext, T> functor_;
functor_(ctx);
}
}
};

} // namespace operators
} // namespace paddle

11 changes: 11 additions & 0 deletions paddle/fluid/operators/collective/global_gather_op.h
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
@@ -33,5 +34,15 @@ class GlobalGatherOpCPUKernel : public framework::OpKernel<T> {
}
};

template <typename Context, typename T>
struct GlobalGatherFunctor {
void operator()(const framework::ExecutionContext& ctx);
};

template <typename Context, typename T>
struct GlobalGatherProcessGroupFunctor {
void operator()(const framework::ExecutionContext& ctx);
};

} // namespace operators
} // namespace paddle
