Skip to content

Commit

Permalink
Merge branch 'nanquantile' of https://github.com/Asthestarsfalll/Paddle into nanquantile
Browse files Browse the repository at this point in the history
  • Loading branch information
Asthestarsfalll committed Apr 14, 2022
2 parents 86d65bb + 4250918 commit b3bc441
Show file tree
Hide file tree
Showing 92 changed files with 1,479 additions and 1,355 deletions.
18 changes: 7 additions & 11 deletions paddle/fluid/distributed/collective/Common.cc
Expand Up @@ -17,11 +17,11 @@
namespace paddle {
namespace distributed {

// Collects the placement (device) of every tensor in `tensors`, in order.
// The resulting list is later turned into a string key (GetKeyFromPlaces)
// that identifies the communicator group.
std::vector<Place> GetPlaceList(const std::vector<phi::DenseTensor>& tensors) {
  std::vector<Place> places;
  places.reserve(tensors.size());  // exactly one place per tensor
  for (auto& tensor : tensors) {
    places.push_back(tensor.place());
  }
  return places;
}
Expand All @@ -40,15 +40,11 @@ std::string GetKeyFromPlaces(const std::vector<Place>& places) {
return placeList;
}

// Returns true iff every tensor in `tensors` resides on a CUDA (GPU) place.
// Note: std::all_of over an empty vector yields true, so an empty input is
// considered "all on CUDA".
bool CheckTensorsInCudaPlace(const std::vector<phi::DenseTensor>& tensors) {
  return std::all_of(tensors.cbegin(), tensors.cend(),
                     [&](const phi::DenseTensor& t) {
                       return platform::is_gpu_place(t.place());
                     });
}

} // namespace distributed
Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/distributed/collective/Common.h
Expand Up @@ -16,18 +16,18 @@

#include "paddle/fluid/platform/place.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"

namespace paddle {
namespace distributed {

using Place = paddle::platform::Place;

// Get the list of device placements from a list of tensors.
std::vector<Place> GetPlaceList(const std::vector<phi::DenseTensor>& tensors);

// Build the device-list key string from a list of places.
std::string GetKeyFromPlaces(const std::vector<Place>& places);

// True iff every tensor in the list is located on a CUDA (GPU) place.
bool CheckTensorsInCudaPlace(const std::vector<phi::DenseTensor>& tensors);

}  // namespace distributed
}  // namespace paddle
3 changes: 2 additions & 1 deletion paddle/fluid/distributed/collective/ProcessGroup.cc
Expand Up @@ -17,7 +17,8 @@
namespace paddle {
namespace distributed {

// Bookkeeping object for one collective operation issued to this group.
// `inputTensors` is part of the interface for backends but is not retained
// by the base class; only the rank and the collective type are stored.
ProcessGroup::Task::Task(int rank,
                         const std::vector<phi::DenseTensor>& inputTensors,
                         CommType comm_type)
    : rank_(rank), comm_type_(comm_type) {}

Expand Down
37 changes: 17 additions & 20 deletions paddle/fluid/distributed/collective/ProcessGroup.h
Expand Up @@ -54,7 +54,7 @@ class ProcessGroup {
public:
class Task {
public:
Task(int rank, const std::vector<Tensor>& inputTensors,
Task(int rank, const std::vector<phi::DenseTensor>& inputTensors,
CommType opType = CommType::UNKNOWN);

virtual ~Task();
Expand All @@ -79,68 +79,65 @@ class ProcessGroup {
virtual const std::string GetBackendName() const = 0;

virtual std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<Tensor>& /* tensors */,
std::vector<phi::DenseTensor>& /* input tensors */, // NOLINT
std::vector<phi::DenseTensor>& /* output tensors */, // NOLINT
const AllreduceOptions& = AllreduceOptions()) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support allreduce", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<Tensor>& /* tensors */,
std::vector<phi::DenseTensor>& /* input tensors */, // NOLINT
std::vector<phi::DenseTensor>& /* output tensors */, // NOLINT
const BroadcastOptions& = BroadcastOptions()) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support broadcast", GetBackendName()));
}

// Broadcast overload for static-graph-mode runtime, operating directly on
// DenseTensor pointers rather than tensor vectors. The base class does not
// support it (throws Fatal); backends serving static mode override this.
virtual void Broadcast(const phi::DenseTensor* in, phi::DenseTensor* out) {
PADDLE_THROW(platform::errors::Fatal(
"ProcessGroup%s does not support broadcast for static mode runtime",
GetBackendName()));
}

// Base-class stub for the barrier collective. Always throws
// InvalidArgument; backends that support barriers override this.
virtual std::shared_ptr<ProcessGroup::Task> Barrier(
const BarrierOptions& = BarrierOptions()) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support barrier", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> Send(
std::vector<Tensor>& tensors /* tensors */, int dst_rank) { // NOLINT
std::vector<phi::DenseTensor>&, int) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support send", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> Recv(
std::vector<Tensor>& tensors /* tensors */, int src_rank) { // NOLINT
std::vector<phi::DenseTensor>& tensors, int) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support receive", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<Tensor>& in_tensors /* tensors */, // NOLINT
std::vector<Tensor>& out_tensors /* tensors */) { // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support AllGather", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> AllToAll(
std::vector<Tensor>& in /* tensors */, // NOLINT
std::vector<Tensor>& out /* tensors */) { // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support AllToAll", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<Tensor>& tensors /* tensors */, // NOLINT
const ReduceOptions& opts) { // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
const ReduceOptions& opts) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support Reduce", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> Scatter(
std::vector<Tensor>& in_tensors /* tensors */, // NOLINT
std::vector<Tensor>& out_tensors /* tensors */, // NOLINT
const ScatterOptions&) { // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
const ScatterOptions&) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support Scatter", GetBackendName()));
}
Expand Down

0 comments on commit b3bc441

Please sign in to comment.