Skip to content

Commit

Permalink
premul_sum: Check numel of factor tensor
Browse files Browse the repository at this point in the history
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
  • Loading branch information
crcrpar committed Nov 15, 2022
1 parent 7e0a70d commit 2dcff40
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion torch/csrc/distributed/c10d/Types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <ATen/core/ivalue.h>
#include <ATen/core/Tensor.h>

#include <c10/macros/Macros.h>
#include <c10/util/intrusive_ptr.h>

namespace c10d {
Expand All @@ -23,7 +24,9 @@ struct NCCLPreMulSumSupplement : _SupplementBase {
double double_factor{0.0};
at::Tensor tensor_factor;
NCCLPreMulSumSupplement(double f) : double_factor{f} {}
NCCLPreMulSumSupplement(at::Tensor t) : tensor_factor{std::move(t)} {}
NCCLPreMulSumSupplement(at::Tensor t) : tensor_factor{std::move(t)} {
TORCH_CHECK_EQ(t.numel(), 1);
}
};

// Other ReduceOps that need different supplementary data can also
Expand Down

0 comments on commit 2dcff40

Please sign in to comment.