Use BFloat16 in distributed quantization when supported by NCCL #125113

Status: Closed · wants to merge 2 commits
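
The gating relies on a HAS_NCCL_BF16_DATATYPE macro that is defined elsewhere in the PyTorch tree and not shown in this diff. As a rough sketch of the idea only (an assumption, not the repository's exact definition), such a feature flag would typically be derived from the NCCL version in nccl.h, since bfloat16 support (ncclBfloat16) was only added in NCCL 2.10:

// Sketch only: a build-time feature flag derived from the NCCL version.
// The real definition in the PyTorch tree may differ.
#include <nccl.h>  // provides NCCL_MAJOR / NCCL_MINOR

#if defined(NCCL_MAJOR) && \
    (NCCL_MAJOR > 2 || (NCCL_MAJOR == 2 && NCCL_MINOR >= 10))
#define HAS_NCCL_BF16_DATATYPE 1  // ncclBfloat16 is available
#else
#define HAS_NCCL_BF16_DATATYPE 0  // fall back to Half
#endif
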
26 changes: 16 additions & 10 deletions torch/csrc/distributed/c10d/quantization/quantization_gpu.cu
@@ -69,15 +69,16 @@ at::Tensor _float_to_bfloat16_cuda(const at::Tensor& input) {

auto output = at::empty(
{nrows, output_columns},
input.options().dtype(at::kHalf)); // at::kHalf
#if HAS_NCCL_BF16_DATATYPE
input.options().dtype(at::kBFloat16));
Collaborator:
FYI, I don't think you need to do this one in the preprocessor; you should be able to do it like:

input.options().dtype(HAS_NCCL_BF16_DATATYPE ? at::kBFloat16 : at::kHalf));

Collaborator (author):
HAS_NCCL_BF16_DATATYPE is a macro and I think it's better to format code like this so that it is easy to identify and remove the old branch in the future.
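
For illustration, a minimal, self-contained sketch of the two styles under discussion (the alloc_output_* helper names are hypothetical, not from the PR), assuming HAS_NCCL_BF16_DATATYPE is always defined to 0 or 1:

#include <ATen/ATen.h>

#ifndef HAS_NCCL_BF16_DATATYPE
#define HAS_NCCL_BF16_DATATYPE 0  // assumption for this sketch only
#endif

// Reviewer's style: pick the dtype with a ternary on the macro's value.
at::Tensor alloc_output_ternary(
    int64_t nrows, int64_t output_columns, const at::Tensor& input) {
  return at::empty(
      {nrows, output_columns},
      input.options().dtype(HAS_NCCL_BF16_DATATYPE ? at::kBFloat16 : at::kHalf));
}

// Author's style: keep the fallback branch textually separate so it is easy
// to find and delete once BFloat16 support can be assumed everywhere.
at::Tensor alloc_output_preprocessor(
    int64_t nrows, int64_t output_columns, const at::Tensor& input) {
#if HAS_NCCL_BF16_DATATYPE
  return at::empty({nrows, output_columns}, input.options().dtype(at::kBFloat16));
#else
  return at::empty({nrows, output_columns}, input.options().dtype(at::kHalf));
#endif
}

With the macro fixed at build time, both forms allocate the same tensor; the difference is only how easy the fallback branch is to spot and remove later.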

#else
input.options().dtype(at::kHalf));
#endif

if (nrows == 0 || output_columns == 0) {
return output;
}

// TODO: replace Half by BFloat16, after BFloat16 is supported by Nvidia
// NCCL input.options().dtype(at::kBFloat16)); // at::kBFloat16

constexpr int threads_per_block = 256;
const int blockDim_x = std::min(output_columns, threads_per_block);
dim3 blockDim(blockDim_x, threads_per_block / blockDim_x);
@@ -93,10 +94,13 @@ at::Tensor _float_to_bfloat16_cuda(const at::Tensor& input) {
input.const_data_ptr<float>(),
nrows,
ncols,
// TODO: replace Half by BFloat16, after BFloat16 is supported by Nvidia
// NCCL
reinterpret_cast<uint16_t*>(output.mutable_data_ptr<at::Half>()));
//C10_CUDA_KERNEL_LAUNCH_CHECK();
#if HAS_NCCL_BF16_DATATYPE
reinterpret_cast<uint16_t*>(output.mutable_data_ptr<at::BFloat16>())
#else
reinterpret_cast<uint16_t*>(output.mutable_data_ptr<at::Half>())
#endif
);
C10_CUDA_KERNEL_LAUNCH_CHECK();
Contributor:
what does the C10_CUDA_KERNEL_LAUNCH_CHECK function do? What's the purpose of uncommenting it?
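
For context, a kernel-launch check of this kind typically calls cudaGetLastError() immediately after the <<<...>>> launch so that a failed launch (bad grid/block configuration, too much shared memory, etc.) raises an error instead of being silently ignored. A rough standalone equivalent, as a sketch only and not necessarily PyTorch's exact definition:

#include <cuda_runtime.h>
#include <stdexcept>
#include <string>

// Sketch of what such a launch check does; C10_CUDA_KERNEL_LAUNCH_CHECK is
// PyTorch's own version of this pattern.
#define KERNEL_LAUNCH_CHECK_SKETCH()                                  \
  do {                                                                \
    cudaError_t err__ = cudaGetLastError();                           \
    if (err__ != cudaSuccess) {                                       \
      throw std::runtime_error(                                       \
          std::string("CUDA kernel launch failed: ") +                \
          cudaGetErrorString(err__));                                 \
    }                                                                 \
  } while (0)

// usage: my_kernel<<<grid, block>>>(args...); KERNEL_LAUNCH_CHECK_SKETCH();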


return output;
}
@@ -134,9 +138,11 @@ at::Tensor _bfloat16_to_float_cuda(const at::Tensor& input) {
blockDim,
0,
at::cuda::getCurrentCUDAStream()>>>(
// TODO: replace Half by BFloat16, after BFloat16 is supported by Nvidia
// NCCL
#if HAS_NCCL_BF16_DATATYPE
reinterpret_cast<const uint16_t*>(input.const_data_ptr<at::BFloat16>()),
#else
reinterpret_cast<const uint16_t*>(input.const_data_ptr<at::Half>()),
#endif
nrows,
ncols,
output.mutable_data_ptr<float>());