Merge branch 'develop' into triplet_margin_distance_loss
yangguohao committed Jun 1, 2022
2 parents b0ae52a + 8162270 commit 2a4ed81
Showing 1,219 changed files with 50,381 additions and 12,806 deletions.
5 changes: 5 additions & 0 deletions .gitignore
@@ -66,3 +66,8 @@ paddle/infrt/tests/lit.cfg.py
paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launchers.cc
paddle/fluid/pybind/eager_final_state_op_function_impl.h
paddle/fluid/pybind/tmp_eager_final_state_op_function_impl.h

# these files (directories) are generated before build system generation
paddle/fluid/operators/generated_op.cc
paddle/phi/ops/compat/generated_sig.cc
python/paddle/utils/code_gen/parsed_apis/
10 changes: 8 additions & 2 deletions CMakeLists.txt
@@ -60,6 +60,7 @@ option(WITH_ONNXRUNTIME "Compile PaddlePaddle with ONNXRUNTIME"
# Note(zhouwei): It uses the options above, so put it here
include(init)
include(generic) # simplify cmake module
+ include(experimental) # experimental build options

if (WITH_GPU AND WITH_XPU)
message(FATAL_ERROR "Error when compile GPU and XPU at the same time")
@@ -256,8 +257,8 @@ option(WITH_CUSTOM_DEVICE "Compile with custom device support" OFF)
option(WITH_ARM_BRPC "Support Brpc in Arm" OFF)

if(WITH_RECORD_BUILDTIME)
- set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${CMAKE_CURRENT_SOURCE_DIR}/tools/get_build_time.sh")
- set_property(GLOBAL PROPERTY RULE_LAUNCH_LINK "${CMAKE_CURRENT_SOURCE_DIR}/tools/get_build_time.sh")
+ set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${CMAKE_CURRENT_SOURCE_DIR}/tools/get_build_time.sh ${CMAKE_CURRENT_BINARY_DIR}")
+ set_property(GLOBAL PROPERTY RULE_LAUNCH_LINK "${CMAKE_CURRENT_SOURCE_DIR}/tools/get_build_time.sh ${CMAKE_CURRENT_BINARY_DIR}")
else()
include(ccache) # set ccache for compilation; ccache cannot be used when WITH_RECORD_BUILDTIME=ON
endif()
@@ -395,6 +396,11 @@ if(WITH_DISTRIBUTE)
MESSAGE(WARNING "Disable WITH_PSCORE when compiling with NPU. Force WITH_PSCORE=OFF.")
set(WITH_PSCORE OFF CACHE BOOL "Disable WITH_PSCORE when compiling with NPU" FORCE)
endif()
+ if(WITH_ROCM AND HIP_VERSION LESS_EQUAL 40020496)
+   # TODO(qili93): third-party rocksdb throws Illegal instruction with HIP version 40020496
+   MESSAGE(WARNING "Disable WITH_PSCORE when HIP_VERSION is less than or equal to 40020496. Force WITH_PSCORE=OFF.")
+   set(WITH_PSCORE OFF CACHE BOOL "Disable WITH_PSCORE when HIP_VERSION is less than or equal to 40020496" FORCE)
+ endif()
endif()

include(third_party) # download, build, install third_party; contains 20+ dependencies
2 changes: 1 addition & 1 deletion README.md
@@ -20,7 +20,7 @@ PaddlePaddle is originated from industrial practices with dedication and commitm

## Installation

- ### Latest PaddlePaddle Release: [v2.2](https://github.com/PaddlePaddle/Paddle/tree/release/2.2)
+ ### Latest PaddlePaddle Release: [v2.3](https://github.com/PaddlePaddle/Paddle/tree/release/2.3)

Our vision is to enable deep learning for everyone via PaddlePaddle.
Please refer to our [release announcement](https://github.com/PaddlePaddle/Paddle/releases) to track the latest features of PaddlePaddle.
10 changes: 6 additions & 4 deletions cmake/cblas.cmake
@@ -52,6 +52,7 @@ if(NOT DEFINED CBLAS_PROVIDER)
set(OPENBLAS_INCLUDE_SEARCH_PATHS
${OPENBLAS_ROOT}/include
/usr/include
+ /usr/include/lapacke
/usr/include/openblas
/usr/local/opt/openblas/include)
set(OPENBLAS_LIB_SEARCH_PATHS
@@ -65,15 +66,17 @@ if(NOT DEFINED CBLAS_PROVIDER)
PATHS ${OPENBLAS_INCLUDE_SEARCH_PATHS} NO_DEFAULT_PATH)
find_path(OPENBLAS_LAPACKE_INC_DIR NAMES lapacke.h
PATHS ${OPENBLAS_INCLUDE_SEARCH_PATHS})
+ find_path(OPENBLAS_CONFIG_INC_DIR NAMES openblas_config.h
+   PATHS ${OPENBLAS_INCLUDE_SEARCH_PATHS})
find_library(OPENBLAS_LIB NAMES openblas
PATHS ${OPENBLAS_LIB_SEARCH_PATHS})

- if(OPENBLAS_LAPACKE_INC_DIR AND OPENBLAS_INC_DIR AND OPENBLAS_LIB)
-   file(READ "${OPENBLAS_INC_DIR}/openblas_config.h" config_file)
+ if(OPENBLAS_LAPACKE_INC_DIR AND OPENBLAS_INC_DIR AND OPENBLAS_CONFIG_INC_DIR AND OPENBLAS_LIB)
+   file(READ "${OPENBLAS_CONFIG_INC_DIR}/openblas_config.h" config_file)
string(REGEX MATCH "OpenBLAS ([0-9]+\.[0-9]+\.[0-9]+)" tmp ${config_file})
string(REGEX MATCH "([0-9]+\.[0-9]+\.[0-9]+)" ver ${tmp})

- if (${ver} VERSION_GREATER_EQUAL "0.3.7")
+ if (${ver} VERSION_GREATER_EQUAL "0.3.5")
set(CBLAS_PROVIDER OPENBLAS)
set(CBLAS_INC_DIR ${OPENBLAS_INC_DIR} ${OPENBLAS_LAPACKE_INC_DIR})
set(CBLAS_LIBRARIES ${OPENBLAS_LIB})
@@ -138,4 +141,3 @@ if(${CBLAS_PROVIDER} STREQUAL REFERENCE_CBLAS)
elseif(NOT ${CBLAS_PROVIDER} STREQUAL MKLML)
target_link_libraries(cblas ${CBLAS_LIBRARIES})
endif()
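
The version probe above keys on the literal "OpenBLAS x.y.z" string embedded in openblas_config.h. A minimal C++ sketch of the same extraction, for illustration only (the sample line mimics the usual OPENBLAS_VERSION define; nothing is read from disk):

#include <iostream>
#include <regex>
#include <string>

// Sketch of the CMake probe above: match "OpenBLAS x.y.z" and keep the
// numeric part, which the build then compares against 0.3.5.
int main() {
  const std::string config_file =
      "#define OPENBLAS_VERSION \" OpenBLAS 0.3.5 \"";
  std::smatch m;
  if (std::regex_search(config_file, m,
                        std::regex(R"(OpenBLAS (\d+\.\d+\.\d+))"))) {
    std::cout << "detected OpenBLAS " << m[1] << "\n";  // prints 0.3.5
  }
  return 0;
}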

17 changes: 17 additions & 0 deletions cmake/experimental.cmake
@@ -0,0 +1,17 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# this file contains experimental build options

include(experiments/cuda_module_loading_lazy)
40 changes: 40 additions & 0 deletions cmake/experiments/cuda_module_loading_lazy.cmake
@@ -0,0 +1,40 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file contains experimental build options for lazy CUDA module loading.
# CUDA module lazy loading is supported natively by CUDA 11.6+;
# this experimental option makes Paddle support lazy loading before CUDA 11.6.

option(EXP_CUDA_MODULE_LOADING_LAZY "enable lazy CUDA module loading" OFF)
if (${EXP_CUDA_MODULE_LOADING_LAZY})
if (NOT ${ON_INFER} OR NOT ${LINUX})
message("EXP_CUDA_MODULE_LOADING_LAZY only works with ON_INFER=ON on Linux platforms")
return()
endif ()
if (NOT ${CUDA_FOUND})
message("EXP_CUDA_MODULE_LOADING_LAZY only works with CUDA")
return()
endif ()
if (${CUDA_VERSION} VERSION_GREATER_EQUAL "11.6")
message("cuda 11.6+ already support lazy module loading")
return()
endif ()

message("for cuda before 11.6, libcudart.so must be used for the lazy module loading trick to work, instead of libcudart_static.a")
set(CUDA_USE_STATIC_CUDA_RUNTIME OFF CACHE BOOL "" FORCE)
set(CMAKE_CUDA_FLAGS "--cudart shared")
enable_language(CUDA)
set(CUDA_NVCC_EXECUTABLE "${CMAKE_SOURCE_DIR}/tools/nvcc_lazy" CACHE FILEPATH "" FORCE)
set(CMAKE_CUDA_COMPILER "${CMAKE_SOURCE_DIR}/tools/nvcc_lazy" CACHE FILEPATH "" FORCE)
endif()
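
For comparison, CUDA 11.7+ exposes lazy loading directly through the CUDA_MODULE_LOADING environment variable, with no compiler wrapper needed. A minimal CUDA C++ sketch, assuming a Linux CUDA build environment (older toolkits simply ignore the variable, which is what the nvcc_lazy trick above works around):

#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

// noop exists only so there is a module to load lazily.
__global__ void noop() {}

int main() {
  // Must be set before the CUDA runtime initializes; honored on 11.7+.
  setenv("CUDA_MODULE_LOADING", "LAZY", /*overwrite=*/1);
  noop<<<1, 1>>>();  // the module containing noop is loaded on first use
  cudaError_t err = cudaDeviceSynchronize();
  std::printf("kernel launch: %s\n", cudaGetErrorString(err));
  return err == cudaSuccess ? 0 : 1;
}
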
5 changes: 3 additions & 2 deletions cmake/external/ascend.cmake
@@ -90,9 +90,10 @@ endif()
if (WITH_ASCEND_CL)
macro(find_ascend_toolkit_version ascend_toolkit_version_info)
file(READ ${ascend_toolkit_version_info} ASCEND_TOOLKIT_VERSION_CONTENTS)
- string(REGEX MATCH "version=([0-9]+\.[0-9]+\.(RC)?[0-9]+\.[a-z]*[0-9]*)" ASCEND_TOOLKIT_VERSION "${ASCEND_TOOLKIT_VERSION_CONTENTS}")
- string(REGEX REPLACE "version=([0-9]+\.[0-9]+\.(RC)?[0-9]+\.[a-z]*[0-9]*)" "\\1" ASCEND_TOOLKIT_VERSION "${ASCEND_TOOLKIT_VERSION}")
+ string(REGEX MATCH "version=([0-9]+\.[0-9]+\.(RC)?[0-9][.a-z0-9]*)" ASCEND_TOOLKIT_VERSION "${ASCEND_TOOLKIT_VERSION_CONTENTS}")
+ string(REGEX REPLACE "version=([0-9]+\.[0-9]+\.(RC)?[0-9][.a-z0-9]*)" "\\1" ASCEND_TOOLKIT_VERSION "${ASCEND_TOOLKIT_VERSION}")
string(REGEX REPLACE "[A-Z]|[a-z|\.]" "" CANN_VERSION ${ASCEND_TOOLKIT_VERSION})
+ STRING(SUBSTRING "${CANN_VERSION}000" 0 6 CANN_VERSION)
add_definitions("-DCANN_VERSION_CODE=${CANN_VERSION}")
if(NOT ASCEND_TOOLKIT_VERSION)
set(ASCEND_TOOLKIT_VERSION "???")
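The string operations above reduce the toolkit version to a six-digit CANN_VERSION code: strip letters and dots, right-pad with zeros, truncate to six digits. A hedged C++ re-implementation for illustration (the helper name is invented):

#include <cctype>
#include <iostream>
#include <string>

// Mirrors the CMake logic above: keep only the digits of the toolkit
// version, then right-pad with '0' and truncate to six characters.
// "5.0.4" -> "504000", the value op_library later compares CANN_VERSION
// against when filtering NPU operator sources.
std::string cann_version_code(const std::string& toolkit_version) {
  std::string digits;
  for (char c : toolkit_version) {
    if (std::isdigit(static_cast<unsigned char>(c))) digits += c;
  }
  return (digits + "000000").substr(0, 6);
}

int main() {
  std::cout << cann_version_code("5.0.4") << "\n";             // 504000
  std::cout << cann_version_code("5.1.RC1.alpha001") << "\n";  // 511001
  return 0;
}
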
2 changes: 1 addition & 1 deletion cmake/external/mkldnn.cmake
@@ -19,7 +19,7 @@ SET(MKLDNN_PREFIX_DIR ${THIRD_PARTY_PATH}/mkldnn)
SET(MKLDNN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/mkldnn)
SET(MKLDNN_INC_DIR "${MKLDNN_INSTALL_DIR}/include" CACHE PATH "mkldnn include directory." FORCE)
SET(MKLDNN_REPOSITORY ${GIT_URL}/oneapi-src/oneDNN.git)
- SET(MKLDNN_TAG 9a35435c18722ff17a48fb60bceac42bfdf78754)
+ SET(MKLDNN_TAG 9b186765dded79066e0cd9c17eb70b680b76fb8e)


# Introduce variables:
4 changes: 2 additions & 2 deletions cmake/external/xpu.cmake
@@ -9,15 +9,15 @@ SET(XPU_RT_LIB_NAME "libxpurt.so")

if(NOT DEFINED XPU_BASE_URL)
SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
- SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220425")
+ SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220520")
else()
SET(XPU_BASE_URL "${XPU_BASE_URL}")
endif()

# ubuntu and centos: use output by XDNN API team
if(NOT DEFINED XPU_XDNN_BASE_URL)
SET(XPU_XDNN_BASE_URL_WITHOUT_DATE "https://klx-sdk-release-public.su.bcebos.com/xdnn/dev")
- SET(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220425")
+ SET(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220520")
else()
SET(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL}")
endif()
2 changes: 0 additions & 2 deletions cmake/flags.cmake
@@ -142,12 +142,10 @@ set(COMMON_FLAGS
-Wno-unused-function
-Wno-error=literal-suffix
-Wno-error=unused-local-typedefs
- -Wno-error=parentheses-equality # Warnings in pybind11
-Wno-error=ignored-attributes # Warnings in Eigen, gcc 6.3
-Wno-error=terminate # Warning in PADDLE_ENFORCE
-Wno-error=int-in-bool-context # Warning in Eigen gcc 7.2
-Wimplicit-fallthrough=0 # Warning in tinyformat.h
- -Wno-error=maybe-uninitialized # Warning in boost gcc 7.2
${fsanitize}
)

27 changes: 27 additions & 0 deletions cmake/hip.cmake
@@ -18,6 +18,33 @@ include_directories(${ROCM_PATH}/include)
message(STATUS "HIP version: ${HIP_VERSION}")
message(STATUS "HIP_CLANG_PATH: ${HIP_CLANG_PATH}")

macro(find_hip_version hip_header_file)
file(READ ${hip_header_file} HIP_VERSION_FILE_CONTENTS)

string(REGEX MATCH "define HIP_VERSION_MAJOR +([0-9]+)" HIP_MAJOR_VERSION
"${HIP_VERSION_FILE_CONTENTS}")
string(REGEX REPLACE "define HIP_VERSION_MAJOR +([0-9]+)" "\\1"
HIP_MAJOR_VERSION "${HIP_MAJOR_VERSION}")
string(REGEX MATCH "define HIP_VERSION_MINOR +([0-9]+)" HIP_MINOR_VERSION
"${HIP_VERSION_FILE_CONTENTS}")
string(REGEX REPLACE "define HIP_VERSION_MINOR +([0-9]+)" "\\1"
HIP_MINOR_VERSION "${HIP_MINOR_VERSION}")
string(REGEX MATCH "define HIP_VERSION_PATCH +([0-9]+)" HIP_PATCH_VERSION
"${HIP_VERSION_FILE_CONTENTS}")
string(REGEX REPLACE "define HIP_VERSION_PATCH +([0-9]+)" "\\1"
HIP_PATCH_VERSION "${HIP_PATCH_VERSION}")

if(NOT HIP_MAJOR_VERSION)
set(HIP_VERSION "???")
message(WARNING "Cannot find HIP version in ${HIP_PATH}/include/hip/hip_version.h")
else()
math(EXPR HIP_VERSION "${HIP_MAJOR_VERSION} * 10000000 + ${HIP_MINOR_VERSION} * 100000 + ${HIP_PATCH_VERSION}")
message(STATUS "Current HIP header is ${HIP_PATH}/include/hip/hip_version.h "
"Current HIP version is v${HIP_MAJOR_VERSION}.${HIP_MINOR_VERSION}.${HIP_PATCH_VERSION}. ")
endif()
endmacro()
find_hip_version(${HIP_PATH}/include/hip/hip_version.h)
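
The math(EXPR) above packs HIP x.y.z into a single integer as major*10000000 + minor*100000 + patch, so HIP 4.0.20496 becomes 40020496, the exact threshold the WITH_PSCORE guard in the top-level CMakeLists.txt tests against. The same packing in C++, for illustration:

#include <iostream>

// Mirrors the math(EXPR ...) above: pack HIP x.y.z into one integer.
constexpr long hip_version_code(long major, long minor, long patch) {
  return major * 10000000 + minor * 100000 + patch;
}

int main() {
  // HIP 4.0.20496 -> 40020496, the WITH_PSCORE cutoff seen earlier.
  std::cout << hip_version_code(4, 0, 20496) << "\n";
  return 0;
}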

macro(find_package_and_include PACKAGE_NAME)
find_package("${PACKAGE_NAME}" REQUIRED)
include_directories("${ROCM_PATH}/${PACKAGE_NAME}/include")
2 changes: 1 addition & 1 deletion cmake/inference_lib.cmake
@@ -416,7 +416,7 @@ function(version version_file)
endif()
if(WITH_ROCM)
file(APPEND ${version_file}
"HIP version: ${HIP_VERSION}\n"
"HIP version: v${HIP_MAJOR_VERSION}.${HIP_MINOR_VERSION}\n"
"MIOpen version: v${MIOPEN_MAJOR_VERSION}.${MIOPEN_MINOR_VERSION}\n")
endif()
if(WITH_ASCEND_CL)
7 changes: 7 additions & 0 deletions cmake/operators.cmake
@@ -261,6 +261,13 @@ function(op_library TARGET)
elseif (WITH_XPU_KP AND ${xpu_kp_cc_srcs_len} GREATER 0)
xpu_library(${TARGET} SRCS ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs} ${xpu_kp_cc_srcs} DEPS ${op_library_DEPS} ${op_common_deps})
else()
# Deal with CANN version control when registering NPU operators before the build
if (WITH_ASCEND_CL)
if (CANN_VERSION LESS 504000)
list(REMOVE_ITEM npu_cc_srcs "multinomial_op_npu.cc")
list(REMOVE_ITEM npu_cc_srcs "take_along_axis_op_npu.cc")
endif()
endif()
# Unity Build relies on global option `WITH_UNITY_BUILD` and local option `UNITY`.
if(WITH_UNITY_BUILD AND op_library_UNITY)
# Combine the cc source files.
2 changes: 1 addition & 1 deletion cmake/third_party.cmake
@@ -357,7 +357,7 @@ if (WITH_PSCORE)
include(external/libmct) # download, build, install libmct
list(APPEND third_party_deps extern_libmct)

- include(external/rocksdb) # download, build, install libmct
+ include(external/rocksdb) # download, build, install rocksdb
list(APPEND third_party_deps extern_rocksdb)
endif()

4 changes: 4 additions & 0 deletions cmake/xpu_kp.cmake
@@ -16,6 +16,10 @@ if(NOT WITH_XPU_KP)
return()
endif()

set(LINK_FLAGS "-Wl,--allow-multiple-definition")
set(CMAKE_EXE_LINKER_FLAGS "${LINK_FLAGS}")
set(CMAKE_SHARED_LINKER_FLAGS "${LINK_FLAGS}")

if(NOT XPU_TOOLCHAIN)
set(XPU_TOOLCHAIN /workspace/output/XTDK-ubuntu_x86_64)
get_filename_component(XPU_TOOLCHAIN ${XPU_TOOLCHAIN} REALPATH)
2 changes: 1 addition & 1 deletion paddle/fluid/distributed/CMakeLists.txt
@@ -26,7 +26,7 @@ add_custom_command(TARGET ps_framework_proto POST_BUILD
COMMAND mv the_one_ps.pb.h ps.pb.h
COMMAND mv the_one_ps.pb.cc ps.pb.cc)

- set(DISTRIBUTE_COMPILE_FLAGS "-Wno-error=unused-value -Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor -Wno-error=sign-compare -Wno-error=unused-variable -Wno-error=return-type -Wno-error=unused-but-set-variable -Wno-error=type-limits -Wno-error=unknown-pragmas -Wno-error=parentheses -Wno-error=unused-result")
+ set(DISTRIBUTE_COMPILE_FLAGS "-Wno-error=unused-value -Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor -Wno-error=sign-compare -Wno-error=unused-variable -Wno-error=return-type -Wno-error=unused-but-set-variable -Wno-error=unknown-pragmas -Wno-error=parentheses -Wno-error=unused-result")

if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
set(DISTRIBUTE_COMPILE_FLAGS
13 changes: 13 additions & 0 deletions paddle/fluid/distributed/collective/ProcessGroup.h
@@ -113,6 +113,19 @@ class ProcessGroup {
"ProcessGroup%s does not support receive", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> Send_Partial(phi::DenseTensor&,
int, int,
int) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support send", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> Recv_Partial(
phi::DenseTensor& tensors, int, int, int) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support receive", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&) { // NOLINT
14 changes: 1 addition & 13 deletions paddle/fluid/distributed/collective/ProcessGroupHeter.cc
@@ -255,12 +255,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Broadcast(

std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Send(
std::vector<phi::DenseTensor>& in_tensors, int peer) {
- #if defined(PADDLE_WITH_NCCL)
-   PADDLE_ENFORCE_EQ(
-       CheckTensorsInCudaPlace(in_tensors), true,
-       platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
- #endif

PADDLE_ENFORCE_EQ(
in_tensors.size(), 1,
platform::errors::PreconditionNotMet(
@@ -299,12 +293,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Send(

std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Recv(
std::vector<phi::DenseTensor>& out_tensors, int peer) {
- #if defined(PADDLE_WITH_NCCL)
-   PADDLE_ENFORCE_EQ(
-       CheckTensorsInCudaPlace(out_tensors), true,
-       platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
- #endif

PADDLE_ENFORCE_EQ(
out_tensors.size(), 1,
platform::errors::PreconditionNotMet(
@@ -343,7 +331,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Recv(
end = std::chrono::high_resolution_clock::now();
diff = end - start;
VLOG(2) << "Time to copy tensor of dims(" << cpu_tensor.dims()
<< ") from gpu to cpu for recv " << std::setw(9)
<< ") from cpu to gpu for recv " << std::setw(9)
<< " is: " << diff.count() << " s" << std::endl;
return CreateTask(rank_, CommType::RECV, out_tensors);
}
47 changes: 47 additions & 0 deletions paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
@@ -428,6 +428,53 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
return task;
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send_Partial(
phi::DenseTensor& tensors, int dst_rank, int offset, int length) {
// CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));

phi::DenseTensor flatten_tensor;
flatten_tensor.ShareDataWith(tensors).Resize({tensors.numel()});

phi::DenseTensor shared_input = flatten_tensor.Slice(offset, offset + length);

std::vector<phi::DenseTensor> shared_tensors;
shared_tensors.push_back(shared_input);

auto task = PointToPoint(shared_tensors,
[&](phi::DenseTensor& input, ncclComm_t comm,
const gpuStream_t& stream, int dst_rank) {
return platform::dynload::ncclSend(
input.data(), input.numel(),
platform::ToNCCLDataType(input.dtype()),
dst_rank, comm, stream);
},
dst_rank, CommType::SEND);
return task;
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv_Partial(
phi::DenseTensor& tensors, int src_rank, int offset, int length) {
// phi::DenseTensor shared_input = tensors.Slice(offset, offset+length);

phi::DenseTensor flatten_tensor;
flatten_tensor.ShareDataWith(tensors).Resize({tensors.numel()});
phi::DenseTensor shared_input = flatten_tensor.Slice(offset, offset + length);

std::vector<phi::DenseTensor> shared_tensors;
shared_tensors.push_back(shared_input);

auto task = PointToPoint(shared_tensors,
[&](phi::DenseTensor& output, ncclComm_t comm,
const gpuStream_t& stream, int src_rank) {
return platform::dynload::ncclRecv(
output.data(), output.numel(),
platform::ToNCCLDataType(output.dtype()),
src_rank, comm, stream);
},
src_rank, CommType::RECV);
return task;
}
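
Both helpers follow the same pattern: flatten the tensor to one dimension, slice [offset, offset + length), and pass the slice through the existing point-to-point path. A hypothetical usage sketch, assuming the caller provides an initialized ProcessGroupNCCL and a GPU DenseTensor (names and setup are illustrative, not a tested call):

// Hypothetical helper showing the call pattern of the API added above.
// Assumes the usual ProcessGroupNCCL / DenseTensor headers are included.
void ExchangeShard(paddle::distributed::ProcessGroupNCCL& pg,
                   phi::DenseTensor& tensor, int rank, int world_size) {
  int length = static_cast<int>(tensor.numel() / world_size);  // shard size
  int offset = 0;  // both ends address the same range of the flattened tensor
  if (rank == 0) {
    pg.Send_Partial(tensor, /*dst_rank=*/1, offset, length);
  } else if (rank == 1) {
    pg.Recv_Partial(tensor, /*src_rank=*/0, offset, length);
  }
}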

std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) {
