Merge branch 'develop' into triplet_margin_loss
yangguohao committed Jun 1, 2022
2 parents 2be018e + 8162270 commit b7b06e3
Showing 695 changed files with 20,712 additions and 6,826 deletions.
5 changes: 5 additions & 0 deletions .gitignore
@@ -66,3 +66,8 @@ paddle/infrt/tests/lit.cfg.py
paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launchers.cc
paddle/fluid/pybind/eager_final_state_op_function_impl.h
paddle/fluid/pybind/tmp_eager_final_state_op_function_impl.h

# these files (directories) are generated before build system generation
paddle/fluid/operators/generated_op.cc
paddle/phi/ops/compat/generated_sig.cc
python/paddle/utils/code_gen/parsed_apis/
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -60,6 +60,7 @@ option(WITH_ONNXRUNTIME "Compile PaddlePaddle with ONNXRUNTIME"
# Note(zhouwei): It uses the options above, so put it here
include(init)
include(generic) # simplify cmake module
include(experimental) # experimental build options

if (WITH_GPU AND WITH_XPU)
message(FATAL_ERROR "Error when compile GPU and XPU at the same time")
10 changes: 6 additions & 4 deletions cmake/cblas.cmake
@@ -52,6 +52,7 @@ if(NOT DEFINED CBLAS_PROVIDER)
set(OPENBLAS_INCLUDE_SEARCH_PATHS
${OPENBLAS_ROOT}/include
/usr/include
/usr/include/lapacke
/usr/include/openblas
/usr/local/opt/openblas/include)
set(OPENBLAS_LIB_SEARCH_PATHS
@@ -65,15 +66,17 @@ if(NOT DEFINED CBLAS_PROVIDER)
PATHS ${OPENBLAS_INCLUDE_SEARCH_PATHS} NO_DEFAULT_PATH)
find_path(OPENBLAS_LAPACKE_INC_DIR NAMES lapacke.h
PATHS ${OPENBLAS_INCLUDE_SEARCH_PATHS})
find_path(OPENBLAS_CONFIG_INC_DIR NAMES openblas_config.h
PATHS ${OPENBLAS_INCLUDE_SEARCH_PATHS})
find_library(OPENBLAS_LIB NAMES openblas
PATHS ${OPENBLAS_LIB_SEARCH_PATHS})

if(OPENBLAS_LAPACKE_INC_DIR AND OPENBLAS_INC_DIR AND OPENBLAS_LIB)
file(READ "${OPENBLAS_INC_DIR}/openblas_config.h" config_file)
if(OPENBLAS_LAPACKE_INC_DIR AND OPENBLAS_INC_DIR AND OPENBLAS_CONFIG_INC_DIR AND OPENBLAS_LIB)
file(READ "${OPENBLAS_CONFIG_INC_DIR}/openblas_config.h" config_file)
string(REGEX MATCH "OpenBLAS ([0-9]+\.[0-9]+\.[0-9]+)" tmp ${config_file})
string(REGEX MATCH "([0-9]+\.[0-9]+\.[0-9]+)" ver ${tmp})

if (${ver} VERSION_GREATER_EQUAL "0.3.7")
if (${ver} VERSION_GREATER_EQUAL "0.3.5")
set(CBLAS_PROVIDER OPENBLAS)
set(CBLAS_INC_DIR ${OPENBLAS_INC_DIR} ${OPENBLAS_LAPACKE_INC_DIR})
set(CBLAS_LIBRARIES ${OPENBLAS_LIB})
@@ -138,4 +141,3 @@ if(${CBLAS_PROVIDER} STREQUAL REFERENCE_CBLAS)
elseif(NOT ${CBLAS_PROVIDER} STREQUAL MKLML)
target_link_libraries(cblas ${CBLAS_LIBRARIES})
endif()

17 changes: 17 additions & 0 deletions cmake/experimental.cmake
@@ -0,0 +1,17 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# this file contains experimental build options

include(experiments/cuda_module_loading_lazy)
40 changes: 40 additions & 0 deletions cmake/experiments/cuda_module_loading_lazy.cmake
@@ -0,0 +1,40 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# this file contains experimental build options for lazy CUDA module loading
# CUDA module lazy loading is supported by CUDA 11.6+
# this experimental option makes Paddle support lazy loading before CUDA 11.6.

option(EXP_CUDA_MODULE_LOADING_LAZY "enable lazy cuda module loading" OFF)
if (${EXP_CUDA_MODULE_LOADING_LAZY})
if (NOT ${ON_INFER} OR NOT ${LINUX})
message("EXP_CUDA_MODULE_LOADING_LAZY only works with ON_INFER=ON on Linux platforms")
return()
endif ()
if (NOT ${CUDA_FOUND})
message("EXP_CUDA_MODULE_LOADING_LAZY only works with CUDA")
return()
endif ()
if (${CUDA_VERSION} VERSION_GREATER_EQUAL "11.6")
message("cuda 11.6+ already support lazy module loading")
return()
endif ()

message("for cuda before 11.6, libcudart.so must be used for the lazy module loading trick to work, instead of libcudart_static.a")
set(CUDA_USE_STATIC_CUDA_RUNTIME OFF CACHE BOOL "" FORCE)
set(CMAKE_CUDA_FLAGS "--cudart shared")
enable_language(CUDA)
set(CUDA_NVCC_EXECUTABLE "${CMAKE_SOURCE_DIR}/tools/nvcc_lazy" CACHE FILEPATH "" FORCE)
set(CMAKE_CUDA_COMPILER "${CMAKE_SOURCE_DIR}/tools/nvcc_lazy" CACHE FILEPATH "" FORCE)
endif()
4 changes: 2 additions & 2 deletions cmake/external/xpu.cmake
@@ -9,15 +9,15 @@ SET(XPU_RT_LIB_NAME "libxpurt.so")

if(NOT DEFINED XPU_BASE_URL)
SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220511")
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220520")
else()
SET(XPU_BASE_URL "${XPU_BASE_URL}")
endif()

# ubuntu and centos: use output by XDNN API team
if(NOT DEFINED XPU_XDNN_BASE_URL)
SET(XPU_XDNN_BASE_URL_WITHOUT_DATE "https://klx-sdk-release-public.su.bcebos.com/xdnn/dev")
SET(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220511")
SET(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220520")
else()
SET(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL}")
endif()
2 changes: 0 additions & 2 deletions cmake/flags.cmake
@@ -142,12 +142,10 @@ set(COMMON_FLAGS
-Wno-unused-function
-Wno-error=literal-suffix
-Wno-error=unused-local-typedefs
-Wno-error=parentheses-equality # Warnings in pybind11
-Wno-error=ignored-attributes # Warnings in Eigen, gcc 6.3
-Wno-error=terminate # Warning in PADDLE_ENFORCE
-Wno-error=int-in-bool-context # Warning in Eigen gcc 7.2
-Wimplicit-fallthrough=0 # Warning in tinyformat.h
-Wno-error=maybe-uninitialized # Warning in boost gcc 7.2
${fsanitize}
)

2 changes: 1 addition & 1 deletion paddle/fluid/distributed/CMakeLists.txt
@@ -26,7 +26,7 @@ add_custom_command(TARGET ps_framework_proto POST_BUILD
COMMAND mv the_one_ps.pb.h ps.pb.h
COMMAND mv the_one_ps.pb.cc ps.pb.cc)

set(DISTRIBUTE_COMPILE_FLAGS "-Wno-error=unused-value -Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor -Wno-error=sign-compare -Wno-error=unused-variable -Wno-error=return-type -Wno-error=unused-but-set-variable -Wno-error=type-limits -Wno-error=unknown-pragmas -Wno-error=parentheses -Wno-error=unused-result")
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-error=unused-value -Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor -Wno-error=sign-compare -Wno-error=unused-variable -Wno-error=return-type -Wno-error=unused-but-set-variable -Wno-error=unknown-pragmas -Wno-error=parentheses -Wno-error=unused-result")

if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
set(DISTRIBUTE_COMPILE_FLAGS
13 changes: 13 additions & 0 deletions paddle/fluid/distributed/collective/ProcessGroup.h
@@ -113,6 +113,19 @@ class ProcessGroup {
"ProcessGroup%s does not support receive", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> Send_Partial(phi::DenseTensor&,
int, int,
int) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support send", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> Recv_Partial(
phi::DenseTensor& tensors, int, int, int) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support receive", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&) { // NOLINT
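Note: the new Send_Partial/Recv_Partial virtuals follow the convention used for the other collectives in ProcessGroup: the base class supplies a default body that throws, and only backends that implement the operation override it. A minimal standalone sketch of this "optional capability" pattern (plain C++; Group and NcclGroup are illustrative names, not Paddle types):

#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>

// Base class: the default implementation reports "unsupported" by throwing,
// mirroring how ProcessGroup's default virtuals call PADDLE_THROW.
class Group {
 public:
  virtual ~Group() = default;
  virtual std::string BackendName() const = 0;
  virtual void SendPartial(int /*dst_rank*/, int /*offset*/, int /*length*/) {
    throw std::runtime_error("ProcessGroup" + BackendName() +
                             " does not support send");
  }
};

// Derived backend: overrides only the operations it can actually perform.
class NcclGroup : public Group {
 public:
  std::string BackendName() const override { return "NCCL"; }
  void SendPartial(int dst_rank, int offset, int length) override {
    std::cout << "send window [" << offset << ", " << offset + length
              << ") to rank " << dst_rank << "\n";
  }
};

int main() {
  std::unique_ptr<Group> g = std::make_unique<NcclGroup>();
  g->SendPartial(/*dst_rank=*/1, /*offset=*/4, /*length=*/4);
  return 0;
}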
47 changes: 47 additions & 0 deletions paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
@@ -428,6 +428,53 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
return task;
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send_Partial(
phi::DenseTensor& tensors, int dst_rank, int offset, int length) {
// CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));

phi::DenseTensor flatten_tensor;
flatten_tensor.ShareDataWith(tensors).Resize({tensors.numel()});

phi::DenseTensor shared_input = flatten_tensor.Slice(offset, offset + length);

std::vector<phi::DenseTensor> shared_tensors;
shared_tensors.push_back(shared_input);

auto task = PointToPoint(shared_tensors,
[&](phi::DenseTensor& input, ncclComm_t comm,
const gpuStream_t& stream, int dst_rank) {
return platform::dynload::ncclSend(
input.data(), input.numel(),
platform::ToNCCLDataType(input.dtype()),
dst_rank, comm, stream);
},
dst_rank, CommType::SEND);
return task;
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv_Partial(
phi::DenseTensor& tensors, int src_rank, int offset, int length) {
// phi::DenseTensor shared_input = tensors.Slice(offset, offset+length);

phi::DenseTensor flatten_tensor;
flatten_tensor.ShareDataWith(tensors).Resize({tensors.numel()});
phi::DenseTensor shared_input = flatten_tensor.Slice(offset, offset + length);

std::vector<phi::DenseTensor> shared_tensors;
shared_tensors.push_back(shared_input);

auto task = PointToPoint(shared_tensors,
[&](phi::DenseTensor& output, ncclComm_t comm,
const gpuStream_t& stream, int src_rank) {
return platform::dynload::ncclRecv(
output.data(), output.numel(),
platform::ToNCCLDataType(output.dtype()),
src_rank, comm, stream);
},
src_rank, CommType::RECV);
return task;
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) {
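Note: Send_Partial/Recv_Partial select the transferred region by flattening the tensor to 1-D (ShareDataWith + Resize) and slicing the contiguous window [offset, offset + length); the slice is a view over the same allocation, not a copy. A self-contained sketch of just that windowing arithmetic (plain C++, no Paddle/NCCL dependency; Window and SlicePartial are illustrative helpers):

#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

// A non-owning view over [offset, offset + length) of a flattened buffer,
// analogous to flatten_tensor.Slice(offset, offset + length) above.
struct Window {
  const float* data;
  std::size_t length;
};

Window SlicePartial(const std::vector<float>& flat, std::size_t offset,
                    std::size_t length) {
  assert(offset + length <= flat.size());
  return Window{flat.data() + offset, length};
}

int main() {
  std::vector<float> tensor(12);  // e.g. a 3x4 tensor, flattened to 1-D
  for (std::size_t i = 0; i < tensor.size(); ++i) tensor[i] = float(i);

  Window w = SlicePartial(tensor, /*offset=*/4, /*length=*/4);  // second row
  for (std::size_t i = 0; i < w.length; ++i) std::cout << w.data[i] << ' ';
  std::cout << '\n';  // prints: 4 5 6 7
  return 0;
}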
8 changes: 8 additions & 0 deletions paddle/fluid/distributed/collective/ProcessGroupNCCL.h
@@ -102,6 +102,14 @@ class ProcessGroupNCCL : public ProcessGroup {
std::shared_ptr<ProcessGroup::Task> Recv(
std::vector<phi::DenseTensor>& tensors, int src_rank) override;

std::shared_ptr<ProcessGroup::Task> Send_Partial(phi::DenseTensor& tensors,
int dst_rank, int offset,
int length) override;

std::shared_ptr<ProcessGroup::Task> Recv_Partial(phi::DenseTensor& tensors,
int src_rank, int offset,
int length) override;

std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) override;
6 changes: 3 additions & 3 deletions paddle/fluid/distributed/fleet_executor/dist_model.cc
@@ -546,9 +546,9 @@ bool DistModel::Run(const std::vector<DistModelTensor> &input_data,

DistModelTimer timer;
timer.tic();
double feed_elapse;
double fleet_exe_elapse;
double fetch_elapse;
double feed_elapse = 0;
double fleet_exe_elapse = 0;
double fetch_elapse = 0;

if (!FeedData(input_data, scope_.get())) {
LOG(ERROR) << "DistModel failed at feeding data.";
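Note: the three elapsed-time doubles are zero-initialized here presumably because they are only assigned inside conditional branches; with -Wno-error=maybe-uninitialized dropped from COMMON_FLAGS (see cmake/flags.cmake above), GCC's maybe-uninitialized diagnostic would otherwise fail the build. A minimal sketch of the problem and the fix (plain C++; the function names are illustrative):

#include <iostream>

// Only conditionally assigned: GCC's -Wmaybe-uninitialized flags the read,
// and with warnings-as-errors that fails the build.
double ElapsedBad(bool timing) {
  double elapse;               // uninitialized
  if (timing) elapse = 1.0;
  return elapse;               // may read an indeterminate value
}

// The fix applied in dist_model.cc: zero-initialize up front.
double ElapsedGood(bool timing) {
  double elapse = 0;
  if (timing) elapse = 1.0;
  return elapse;
}

int main() {
  std::cout << ElapsedGood(false) << "\n";  // prints: 0
  return 0;
}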
2 changes: 1 addition & 1 deletion paddle/fluid/distributed/ps/service/brpc_ps_client.cc
@@ -261,7 +261,7 @@ int DownpourBrpcClosure::check_response(size_t request_idx, int cmd_id) {
}

int DownpourBrpcClosure::check_save_response(size_t request_idx, int cmd_id) {
uint32_t feasign_size = 0;
int32_t feasign_size = 0;
if (_cntls[request_idx]->Failed()) {
LOG(ERROR) << "resquest cmd_id:" << cmd_id << " failed, "
"err:"
5 changes: 0 additions & 5 deletions paddle/fluid/distributed/ps/service/brpc_ps_server.cc
@@ -301,11 +301,6 @@ int32_t BrpcPsService::PullDense(Table *table, const PsRequestMessage &request,
}
CostTimer timer("pserver_server_pull_dense");
uint32_t num = *(const uint32_t *)request.params(0).c_str();
if (num < 0) {
set_response_code(response, -1,
"PsRequestMessage.datas[0] is invalid, num must >= 0");
return 0;
}

auto res_data = butil::get_object<std::vector<float>>();
res_data->resize(num * table->ValueAccesor()->GetAccessorInfo().select_size /
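Note: the deleted branch was dead code: num is a uint32_t, so num < 0 is always false. GCC/Clang diagnose this with -Wtype-limits, which plausibly explains why this commit also drops -Wno-error=type-limits from DISTRIBUTE_COMPILE_FLAGS above. A minimal demonstration (plain C++):

#include <cstdint>
#include <iostream>

int main() {
  uint32_t num = static_cast<uint32_t>(-1);  // wraps to 4294967295
  // Comparing an unsigned value with 0 via '<' is always false;
  // -Wtype-limits warns that this expression can never be true.
  if (num < 0) {
    std::cout << "unreachable\n";
  } else {
    std::cout << "num = " << num << " (unsigned, never < 0)\n";
  }
  return 0;
}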
7 changes: 4 additions & 3 deletions paddle/fluid/distributed/test/ctr_accessor_test.cc
@@ -196,9 +196,10 @@ TEST(downpour_feature_value_accessor_test, test_update) {
ptr[idx + j] = embedx_w[j];
}
idx += 8;
for (auto j = 0u; j < 0; ++j) {
ptr[idx + j] = embedx_g2sum[j];
}
// NaiveSGD has no embedx_g2sum
// for (auto j = 0u; j < 0; ++j) {
// ptr[idx + j] = embedx_g2sum[j];
// }
}
};
struct DownpourSparsePushValueTest {
54 changes: 31 additions & 23 deletions paddle/fluid/eager/accumulation/accumulation_node.cc
@@ -28,33 +28,40 @@
namespace egr {

static void CopyOrAddTensor(paddle::experimental::Tensor* tensor,
const paddle::experimental::Tensor& t) {
if (!tensor->defined() || !tensor->initialized()) {
// Simply copy tensor->impl
const paddle::experimental::Tensor& t,
bool is_fake_empty) {
if (is_fake_empty) {
*tensor = t;
} else {
// Accumulation
if (LIKELY(t.is_dense_tensor())) {
if (LIKELY(tensor->is_dense_tensor())) {
paddle::imperative::TensorAdd<paddle::experimental::Tensor>(t, tensor);
if (!tensor->defined() || !tensor->initialized()) {
// Simply copy tensor->impl
*tensor = t;
} else {
// Accumulation
if (LIKELY(t.is_dense_tensor())) {
if (LIKELY(tensor->is_dense_tensor())) {
paddle::imperative::TensorAdd<paddle::experimental::Tensor>(t,
tensor);
} else {
// TODO(jiabin): Support Other TensorBase later
// TODO(zhanlve): Replace SelectedRowsAddTensor with
// add_dygraph_function once it's supported
paddle::experimental::Tensor new_buffer(
std::make_shared<phi::DenseTensor>(), "tmp_accumulator");
paddle::imperative::SelectedRowsAddTensor(*tensor, t, &new_buffer);
tensor->set_impl(new_buffer.impl());
}
} else {
// TODO(jiabin): Support Other TensorBase later
// TODO(zhanlve): Replace SelectedRowsAddTensor with
// add_dygraph_function once it's supported
paddle::experimental::Tensor new_buffer(
std::make_shared<phi::DenseTensor>(), "tmp_accumulator");
paddle::imperative::SelectedRowsAddTensor(*tensor, t, &new_buffer);
tensor->set_impl(new_buffer.impl());
}
} else {
// TODO(jiabin): Support Other TensorBase later
// TODO(zhanlve): Replace SelectedRowsAddTensor with add_dygraph_function
// once it's supported
if (tensor->is_dense_tensor()) {
paddle::imperative::SelectedRowsAddToTensor(t, tensor);
} else {
*tensor = std::move(*paddle::imperative::SelectedRowsMerge<
paddle::experimental::Tensor>(t, *tensor));
// add_dygraph_function
// once it's supported
if (tensor->is_dense_tensor()) {
paddle::imperative::SelectedRowsAddToTensor(t, tensor);
} else {
*tensor = std::move(*paddle::imperative::SelectedRowsMerge<
paddle::experimental::Tensor>(t, *tensor));
}
}
}
}
@@ -91,7 +98,8 @@ GradNodeAccumulation::operator()(

if (!weak_grad_.expired() && !is_new_grad) {
auto grad = weak_grad_.lock();
CopyOrAddTensor(grad.get(), grad_out);
CopyOrAddTensor(grad.get(), grad_out, is_fake_empty_);
is_fake_empty_ = false;
}

// Apply Reduce Hooks
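Note: after this change CopyOrAddTensor implements a three-way policy: a fake-empty or undefined/uninitialized destination takes a plain copy, while an initialized one accumulates (dense + dense via TensorAdd, with SelectedRows fallbacks otherwise). A hedged standalone sketch of the copy-vs-accumulate core (plain C++; GradSlot is an illustrative stand-in, not a Paddle type):

#include <cstddef>
#include <iostream>
#include <optional>
#include <vector>

// Empty (or "fake empty", i.e. cleared but still allocated) -> copy;
// otherwise accumulate element-wise, as GradNodeAccumulation does.
struct GradSlot {
  std::optional<std::vector<float>> value;
  bool is_fake_empty = false;

  void CopyOrAdd(const std::vector<float>& t) {
    if (is_fake_empty || !value.has_value()) {
      value = t;  // first (or post-clear) write: plain copy
      is_fake_empty = false;
    } else {
      for (std::size_t i = 0; i < value->size(); ++i)
        (*value)[i] += t[i];  // later writes: accumulate
    }
  }
};

int main() {
  GradSlot slot;
  slot.CopyOrAdd({1.f, 2.f});  // copy
  slot.CopyOrAdd({3.f, 4.f});  // accumulate
  std::cout << (*slot.value)[0] << ' ' << (*slot.value)[1] << '\n';  // 4 6
  return 0;
}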
8 changes: 5 additions & 3 deletions paddle/fluid/eager/accumulation/accumulation_node.h
@@ -64,14 +64,16 @@ class GradNodeAccumulation : public GradNodeBase {
new GradNodeAccumulation(nullptr));
}

void SetFakeEmpty(bool is_fake_empty) { is_fake_empty_ = is_fake_empty; }

private:
// TODO(Jiabin): remove this when we make our clear gradient really cleared;
bool is_fake_empty_ = {false};
std::weak_ptr<paddle::experimental::Tensor> weak_grad_;

std::vector<std::shared_ptr<TensorVoidHook>> reduce_hooks_;
std::function<paddle::experimental::Tensor(
const paddle::experimental::Tensor&)>
retain_grad_hook_;

std::vector<std::shared_ptr<TensorVoidHook>> reduce_hooks_;
};

} // namespace egr
