Enable clang-tidy coverage on torch/csrc/distributed/c10d/* #125102

Draft · wants to merge 3 commits into base: main
5 changes: 2 additions & 3 deletions torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp
@@ -751,7 +751,7 @@ class UvClient : public UvTcpSocket {
     if (!stream.read_key(key))
       return false;

-    auto data = store->get(key);
+    const auto& data = store->get(key);
     StreamWriter sw(iptr());
     sw.write_vector(data);
     sw.send();
@@ -883,8 +883,7 @@ class UvClient : public UvTcpSocket {
       return false;
     }

-    auto data = store->get(key);
-    sw.write_vector(data);
+    sw.write_vector(store->get(key));
   }
   sw.send();

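The `const auto&` change above is the usual fix for clang-tidy's `performance-unnecessary-copy-initialization`. A minimal sketch of the pattern, assuming (as the fix implies) that `get()` returns a reference to the stored value; the `Store` class here is illustrative, not the real TCPStore backend:

```cpp
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

// Illustrative store whose get() returns a reference to the stored value.
class Store {
 public:
  const std::vector<uint8_t>& get(const std::string& key) {
    return data_[key];
  }

 private:
  std::unordered_map<std::string, std::vector<uint8_t>> data_;
};

void read_value(Store& store) {
  // Flagged: copies the whole vector even though it is only read.
  auto copied = store.get("key");
  // Fixed: binds directly to the stored vector, no copy.
  const auto& borrowed = store.get("key");
  (void)copied;
  (void)borrowed;
}
```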
2 changes: 1 addition & 1 deletion torch/csrc/distributed/c10d/init.cpp
@@ -2031,7 +2031,7 @@ communication mechanism.
           self->registerOnCompletionHook(
               [hookWrapper = ::c10d::PythonOnCompletionHook(std::move(
                    hook))](std::shared_ptr<::c10d::WorkInfo> workInfo) {
-                hookWrapper(std::move(workInfo));
+                hookWrapper(workInfo);
               });
         },
         py::arg("hook"),
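Dropping the `std::move` above matches clang-tidy's `performance-move-const-arg`, which fires when the result of `std::move` only ever binds to a const reference. A sketch under that assumption — the const-ref signature of the wrapper is inferred here, not taken from the PR:

```cpp
#include <memory>
#include <utility>

struct WorkInfo {};

// Assumed shape of the wrapper: its call operator takes the pointer by
// const reference, so an rvalue argument is never actually moved from.
struct HookWrapper {
  void operator()(const std::shared_ptr<WorkInfo>& info) const {
    (void)info;
  }
};

void run_hook(const HookWrapper& hookWrapper,
              std::shared_ptr<WorkInfo> workInfo) {
  // Before: hookWrapper(std::move(workInfo));
  // The rvalue binds to a const reference, so nothing is moved and the
  // std::move only obscures the code.
  hookWrapper(workInfo);
}
```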
31 changes: 9 additions & 22 deletions torch/csrc/distributed/c10d/quantization/quantization.cpp
@@ -2,10 +2,7 @@
 #include <torch/csrc/distributed/c10d/quantization/quantization_utils.h>
 #include <torch/library.h>

-namespace torch {
-namespace distributed {
-namespace c10d {
-namespace quantization {
+namespace torch::distributed::c10d::quantization {

 // TODO: The kernels are copied from fbgemm_gpu, we should dedup them later

@@ -31,11 +28,9 @@ static void BFloat16QuantizedToFloat_ref(
     const size_t nrows,
     const size_t ncols,
     float* const output) {
-  const int32_t output_columns = ncols;
-
   for (const auto row : c10::irange(nrows)) {
     const at::BFloat16* input_row = input + row * ncols;
-    float* output_row = output + row * output_columns;
+    float* output_row = output + row * ncols;

     for (const auto col : c10::irange(ncols)) {
       uint32_t val_fp32 = static_cast<uint32_t>(
@@ -52,11 +47,9 @@ at::Tensor _float_to_bfloat16_cpu(const at::Tensor& input) {
   TENSOR_NDIM_EQUALS(input, 2);

   const auto input_sizes = input.sizes();
-  const int32_t nrows = input_sizes[0];
-  const int32_t ncols = input_sizes[1];
-  const int32_t output_columns = ncols;
-  auto output =
-      at::empty({nrows, output_columns}, input.options().dtype(at::kHalf));
+  const auto nrows = input_sizes[0];
+  const auto ncols = input_sizes[1];
+  auto output = at::empty({nrows, ncols}, input.options().dtype(at::kHalf));

   FloatToBFloat16Quantized_ref(
       input.const_data_ptr<float>(),
@@ -73,13 +66,10 @@ at::Tensor _bfloat16_to_float_cpu(const at::Tensor& input) {
   TENSOR_NDIM_EQUALS(input, 2);

   const auto input_sizes = input.sizes();
-  const int32_t nrows = input_sizes[0];
-  const int32_t ncols = input_sizes[1];
-  const int32_t output_columns = ncols;
+  const auto nrows = input_sizes[0];
+  const auto ncols = input_sizes[1];

-  auto output = at::empty(
-      {nrows, output_columns}, // 4 = sizeof(float)
-      input.options().dtype(at::kFloat)); //
+  auto output = at::empty({nrows, ncols}, input.options().dtype(at::kFloat));
   BFloat16QuantizedToFloat_ref(
       reinterpret_cast<const at::BFloat16*>(input.const_data_ptr<at::Half>()),
       nrows,
@@ -99,7 +89,4 @@ TORCH_LIBRARY_IMPL(quantization, CPU, m) {
   m.impl("_FloatToBfloat16Quantized", _float_to_bfloat16_cpu);
 }

-} // namespace quantization
-} // namespace c10d
-} // namespace distributed
-} // namespace torch
+} // namespace torch::distributed::c10d::quantization
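The namespace hunks in this file (and in the headers below) all apply `modernize-concat-nested-namespaces`, which rewrites pre-C++17 nesting into a single declaration. The before/after, reduced to its essentials:

```cpp
// Before (pre-C++17): one declaration and one closing brace per level.
namespace torch { namespace distributed { namespace c10d { namespace quantization {
// ... declarations ...
}}}}

// After (C++17): a single nested-namespace declaration.
namespace torch::distributed::c10d::quantization {
// ... declarations ...
} // namespace torch::distributed::c10d::quantization
```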
10 changes: 2 additions & 8 deletions torch/csrc/distributed/c10d/quantization/quantization.h
@@ -8,15 +8,9 @@
 #include <ATen/ATen.h>
 #include <vector>

-namespace torch {
-namespace distributed {
-namespace c10d {
-namespace quantization {
+namespace torch::distributed::c10d::quantization {

 at::Tensor _float_to_bfloat16_cpu(const at::Tensor& input);
 at::Tensor _bfloat16_to_float_cpu(const at::Tensor& input);

-} // namespace quantization
-} // namespace c10d
-} // namespace distributed
-} // namespace torch
+} // namespace torch::distributed::c10d::quantization
78 changes: 37 additions & 41 deletions torch/csrc/distributed/c10d/quantization/quantization_gpu.cu
@@ -9,16 +9,16 @@
 // FP32 -> BF16 kernel
 __global__ void _float_to_bfloat16_cuda_kernel(
     const float* __restrict__ input,
-    const int nrows,
-    const int ncols,
+    const size_t nrows,
+    const size_t ncols,
     uint16_t* __restrict__ output) {
-  const int row_incre = blockDim.y * gridDim.y;
-  const int col_incre = blockDim.x * gridDim.x;
-  for (int row = blockIdx.y * blockDim.y + threadIdx.y; row < nrows;
+  const auto row_incre = blockDim.y * gridDim.y;
+  const auto col_incre = blockDim.x * gridDim.x;
+  for (auto row = blockIdx.y * blockDim.y + threadIdx.y; row < nrows;
        row += row_incre) {
     const float* input_row = input + row * ncols;
     uint16_t* output_row = output + row * ncols;
-    for (int col = blockIdx.x * blockDim.x + threadIdx.x; col < ncols;
+    for (auto col = blockIdx.x * blockDim.x + threadIdx.x; col < ncols;
          col += col_incre) {
       // Add 2^15 and right shift 16 to do round-nearest
       output_row[col] =
@@ -31,14 +31,14 @@ __global__ void _float_to_bfloat16_cuda_kernel(
 // BF16 -> FP32 kernel
 __global__ void _bfloat16_to_float_cuda_kernel(
     const uint16_t* __restrict__ input,
-    const int nrows,
-    const int ncols,
+    const size_t nrows,
+    const size_t ncols,
     float* __restrict__ output) {
-  const int row_incre = blockDim.y * gridDim.y;
-  const int col_incre = blockDim.x * gridDim.x;
-  for (int row = blockIdx.y * blockDim.y + threadIdx.y; row < nrows;
+  const auto row_incre = blockDim.y * gridDim.y;
+  const auto col_incre = blockDim.x * gridDim.x;
+  for (auto row = blockIdx.y * blockDim.y + threadIdx.y; row < nrows;
        row += row_incre) {
-    for (int col = blockIdx.x * blockDim.x + threadIdx.x; col < ncols;
+    for (auto col = blockIdx.x * blockDim.x + threadIdx.x; col < ncols;
          col += col_incre) {
       const uint16_t* input_row = input + row * ncols;
       float* output_row = output + row * ncols;
@@ -50,10 +50,7 @@ __global__ void _bfloat16_to_float_cuda_kernel(
   }
 }

-namespace torch {
-namespace distributed {
-namespace c10d {
-namespace quantization {
+namespace torch::distributed::c10d::quantization {

 at::Tensor _float_to_bfloat16_cuda(const at::Tensor& input) {
   TENSOR_ON_CUDA_GPU(input);
@@ -63,27 +60,28 @@ at::Tensor _float_to_bfloat16_cuda(const at::Tensor& input) {
   at::cuda::OptionalCUDAGuard device_guard;
   device_guard.set_index(input.get_device());

-  const int nrows = input.size(0);
-  const int ncols = input.size(1);
-  const int output_columns = ncols;
+  const auto nrows = input.size(0);
+  const auto ncols = input.size(1);
+  const size_t output_columns = ncols;

   auto output = at::empty(
-      {nrows, output_columns},
+      {nrows, ncols},
 #if HAS_NCCL_BF16_DATATYPE
       input.options().dtype(at::kBFloat16));
 #else
       input.options().dtype(at::kHalf));
 #endif

-  if (nrows == 0 || output_columns == 0) {
+  if (nrows == 0 || ncols == 0) {
     return output;
   }

-  constexpr int threads_per_block = 256;
-  const int blockDim_x = std::min(output_columns, threads_per_block);
+  constexpr size_t threads_per_block = 256;
+  const auto blockDim_x = std::min(output_columns, threads_per_block);
   dim3 blockDim(blockDim_x, threads_per_block / blockDim_x);
-  const int gridDim_x = (output_columns + blockDim.x - 1) / blockDim.x;
-  const int gridDim_y = std::min((nrows + blockDim.y - 1) / blockDim.y, 65535u);
+  const auto gridDim_x = (output_columns + blockDim.x - 1) / blockDim.x;
+  const auto gridDim_y =
+      std::min<size_t>((nrows + blockDim.y - 1) / blockDim.y, 65535u);
   dim3 gridDim(gridDim_x, gridDim_y);

   _float_to_bfloat16_cuda_kernel<<<
@@ -113,24 +111,25 @@ at::Tensor _bfloat16_to_float_cuda(const at::Tensor& input) {
   at::cuda::OptionalCUDAGuard device_guard;
   device_guard.set_index(input.get_device());

-  const int nrows = input.size(0);
-  const int ncols = input.size(1);
-  const int output_columns = ncols;
+  const auto nrows = input.size(0);
+  const auto ncols = input.size(1);
+  const size_t output_columns = ncols;

   auto output = at::empty(
-      {nrows, output_columns}, // 4 = sizeof(float)
+      {nrows, ncols}, // 4 = sizeof(float)
       input.options().dtype(at::kFloat)); // at::kBytes for uint8_t

-  if (nrows == 0 || output_columns == 0) {
+  if (nrows == 0 || ncols == 0) {
     return output;
   }

-  constexpr int threads_per_block = 256;
+  constexpr size_t threads_per_block = 256;

-  const int blockDim_x = std::min(output_columns, threads_per_block);
+  const auto blockDim_x = std::min(output_columns, threads_per_block);
   dim3 blockDim(blockDim_x, threads_per_block / blockDim_x);
-  const int gridDim_x = (output_columns + blockDim.x - 1) / blockDim.x;
-  const int gridDim_y = std::min((nrows + blockDim.y - 1) / blockDim.y, 65535u);
+  const auto gridDim_x = (output_columns + blockDim.x - 1) / blockDim.x;
+  const auto gridDim_y =
+      std::min<size_t>((nrows + blockDim.y - 1) / blockDim.y, 65535u);
   dim3 gridDim(gridDim_x, gridDim_y);

   _bfloat16_to_float_cuda_kernel<<<
@@ -152,14 +151,11 @@ at::Tensor _bfloat16_to_float_cuda(const at::Tensor& input) {
 }

 #define DISPATCH_TO_CUDA(name, function) \
-  m.impl(name, torch::dispatch(c10::DispatchKey::CUDA, TORCH_FN(function)))
+  m.impl(name, torch::dispatch(c10::DispatchKey::CUDA, TORCH_FN(function)))

 TORCH_LIBRARY_IMPL(quantization, CUDA, m) {
-  DISPATCH_TO_CUDA("_Bfloat16QuantizedToFloat", _bfloat16_to_float_cuda);
-  DISPATCH_TO_CUDA("_FloatToBfloat16Quantized", _float_to_bfloat16_cuda);
+  DISPATCH_TO_CUDA("_Bfloat16QuantizedToFloat", _bfloat16_to_float_cuda);
+  DISPATCH_TO_CUDA("_FloatToBfloat16Quantized", _float_to_bfloat16_cuda);
 }

-} // namespace quantization
-} // namespace c10d
-} // namespace distributed
-} // namespace torch
+} // namespace torch::distributed::c10d::quantization
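The launch math in these hunks is unchanged apart from the move to unsigned types. A host-side C++ sketch of what it computes (`Dim3` stands in for CUDA's `dim3`; 65535 is the hardware limit on `gridDim.y`):

```cpp
#include <algorithm>
#include <cstddef>

struct Dim3 {
  unsigned x, y;
};

// Pick grid dimensions for an nrows x ncols launch, 256 threads per block.
Dim3 grid_for(size_t nrows, size_t ncols) {
  constexpr size_t threads_per_block = 256;
  const size_t block_x = std::min(ncols, threads_per_block);
  const size_t block_y = threads_per_block / block_x;
  // (n + b - 1) / b is integer ceil(n / b): enough blocks to cover n items.
  const size_t grid_x = (ncols + block_x - 1) / block_x;
  // gridDim.y may not exceed 65535; the grid-stride loops in the kernels
  // above pick up any rows beyond that cap.
  const size_t grid_y =
      std::min<size_t>((nrows + block_y - 1) / block_y, 65535);
  return {static_cast<unsigned>(grid_x), static_cast<unsigned>(grid_y)};
}
```

Note the edge case: with `ncols == 0` this arithmetic divides by zero, which is why both functions return early when either dimension is zero.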
10 changes: 2 additions & 8 deletions torch/csrc/distributed/c10d/quantization/quantization_gpu.h
@@ -8,15 +8,9 @@
 #include <ATen/ATen.h>
 #include <vector>

-namespace torch {
-namespace distributed {
-namespace c10d {
-namespace quantization {
+namespace torch::distributed::c10d::quantization {

 at::Tensor _float_to_bfloat16_cuda(const at::Tensor& input);
 at::Tensor _bfloat16_to_float_cuda(const at::Tensor& input);

-} // namespace quantization
-} // namespace c10d
-} // namespace distributed
-} // namespace torch
+} // namespace torch::distributed::c10d::quantization
3 changes: 1 addition & 2 deletions torch/csrc/distributed/c10d/sequence_num.cpp
@@ -1,11 +1,10 @@
 #include <ATen/ThreadLocalState.h>
-#include <c10/util/Optional.h>
 #include <torch/csrc/distributed/c10d/sequence_num.hpp>

 #include <c10/util/Logging.h>

 namespace c10d {
-SequenceNum::SequenceNum() : num_(c10::nullopt) {}
+SequenceNum::SequenceNum() = default;

 SequenceNum::SequenceNum(const uint64_t num) : num_(num) {}
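`SequenceNum() = default;` is the `modernize-use-equals-default` fix: an optional member already value-initializes to empty, so the explicit `num_(c10::nullopt)` initializer was redundant (which is also why the `c10/util/Optional.h` include could be dropped here). A sketch with `std::optional` standing in for `c10::optional`:

```cpp
#include <cstdint>
#include <optional>

class SequenceNum {
 public:
  // Equivalent to ": num_(std::nullopt)" — the optional starts empty.
  SequenceNum() = default;
  explicit SequenceNum(uint64_t num) : num_(num) {}

  bool isSet() const {
    return num_.has_value();
  }

 private:
  std::optional<uint64_t> num_;
};
```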
2 changes: 1 addition & 1 deletion torch/csrc/distributed/c10d/socket.cpp
@@ -670,7 +670,7 @@ class SocketConnectOp {

   static const std::chrono::seconds delay_duration_;

-  enum class ConnectResult { Success, Error, Retry };
+  enum class ConnectResult : uint8_t { Success, Error, Retry };

  public:
   SocketConnectOp(
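Giving the enum a fixed underlying type is clang-tidy's `performance-enum-size` suggestion: without one, a scoped enum's underlying type defaults to `int` (typically 4 bytes), while these three enumerators fit in a single byte. A self-contained illustration:

```cpp
#include <cstdint>

enum class ConnectResultOld { Success, Error, Retry };
enum class ConnectResultNew : uint8_t { Success, Error, Retry };

static_assert(sizeof(ConnectResultOld) == sizeof(int),
              "scoped enums default to int's size");
static_assert(sizeof(ConnectResultNew) == 1, "fits in one byte");
```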