Use BFloat16 in distributed quantization when supported by NCCL (pytorch#125113)

This PR enables BFloat16 in torch/csrc/distributed/c10d/quantization/quantization_gpu.cu when NCCL supports the BFloat16 datatype (guarded by HAS_NCCL_BF16_DATATYPE), falling back to Half otherwise.

Pull Request resolved: pytorch#125113
Approved by: https://github.com/kwen2501
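For context, HAS_NCCL_BF16_DATATYPE is the compile-time guard this change keys on. Below is a minimal sketch of how such a guard can be derived, assuming it reduces to an NCCL version check (the ncclBfloat16 datatype first shipped in NCCL 2.10); the actual definition lives in PyTorch's NCCL utility headers and may differ:

// Sketch only, not the real header contents: gate BFloat16 support on the
// NCCL version, since ncclBfloat16 was added in NCCL 2.10.
#if defined(NCCL_MAJOR) && \
    ((NCCL_MAJOR > 2) || (NCCL_MAJOR == 2 && NCCL_MINOR >= 10))
#define HAS_NCCL_BF16_DATATYPE 1
#else
#define HAS_NCCL_BF16_DATATYPE 0
#endif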
cyyever authored and petrex committed May 3, 2024
1 parent c733745 commit 950d0ce
Showing 1 changed file with 16 additions and 10 deletions.
torch/csrc/distributed/c10d/quantization/quantization_gpu.cu
@@ -69,15 +69,16 @@ at::Tensor _float_to_bfloat16_cuda(const at::Tensor& input) {
 
   auto output = at::empty(
       {nrows, output_columns},
-      input.options().dtype(at::kHalf)); // at::kHalf
+#if HAS_NCCL_BF16_DATATYPE
+      input.options().dtype(at::kBFloat16));
+#else
+      input.options().dtype(at::kHalf));
+#endif
 
   if (nrows == 0 || output_columns == 0) {
     return output;
   }
 
-  // TODO: replace Half by BFloat16, after BFloat16 is supported by Nvidia
-  // NCCL input.options().dtype(at::kBFloat16)); // at::kBFloat16
-
   constexpr int threads_per_block = 256;
   const int blockDim_x = std::min(output_columns, threads_per_block);
   dim3 blockDim(blockDim_x, threads_per_block / blockDim_x);
@@ -93,10 +94,13 @@ at::Tensor _float_to_bfloat16_cuda(const at::Tensor& input) {
       input.const_data_ptr<float>(),
       nrows,
       ncols,
-      // TODO: replace Half by BFloat16, after BFloat16 is supported by Nvidia
-      // NCCL
-      reinterpret_cast<uint16_t*>(output.mutable_data_ptr<at::Half>()));
-  //C10_CUDA_KERNEL_LAUNCH_CHECK();
+#if HAS_NCCL_BF16_DATATYPE
+      reinterpret_cast<uint16_t*>(output.mutable_data_ptr<at::BFloat16>())
+#else
+      reinterpret_cast<uint16_t*>(output.mutable_data_ptr<at::Half>())
+#endif
+  );
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
   return output;
 }
 
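The kernel body itself is collapsed out of this diff. As a reference for what the float-to-bfloat16 step amounts to, here is a sketch of the standard device-side conversion, assuming round-to-nearest-even on the upper 16 bits of the IEEE-754 representation; the kernel in this file may handle rounding and NaNs differently:

// Sketch only: convert one float to a bfloat16 bit pattern.
__device__ inline uint16_t float_to_bfloat16_bits(float value) {
  uint32_t bits = __float_as_uint(value);               // raw IEEE-754 bits
  uint32_t rounding_bias = 0x7FFF + ((bits >> 16) & 1); // ties to even
  return static_cast<uint16_t>((bits + rounding_bias) >> 16);
}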
@@ -134,9 +138,11 @@ at::Tensor _bfloat16_to_float_cuda(const at::Tensor& input) {
       blockDim,
       0,
       at::cuda::getCurrentCUDAStream()>>>(
-      // TODO: replace Half by BFloat16, after BFloat16 is supported by Nvidia
-      // NCCL
+#if HAS_NCCL_BF16_DATATYPE
+      reinterpret_cast<const uint16_t*>(input.const_data_ptr<at::BFloat16>()),
+#else
       reinterpret_cast<const uint16_t*>(input.const_data_ptr<at::Half>()),
+#endif
       nrows,
       ncols,
       output.mutable_data_ptr<float>());
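For the reverse direction, widening bfloat16 back to float needs no rounding at all, which is why _bfloat16_to_float_cuda can treat its input as raw 16-bit words. A minimal sketch of the underlying technique (again an illustration, not the kernel actually used here):

// Sketch only: bfloat16 is the top 16 bits of a float, so widening is exact.
__device__ inline float bfloat16_bits_to_float(uint16_t bits) {
  return __uint_as_float(static_cast<uint32_t>(bits) << 16);
}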