Skip to content

Commit

Permalink
fix normal with empty std (#66524)
Browse files Browse the repository at this point in the history
  • Loading branch information
ngimel committed Oct 14, 2021
1 parent 9509e8a commit c3ea586
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
2 changes: 1 addition & 1 deletion aten/src/ATen/native/DistributionTemplates.h
Expand Up @@ -238,7 +238,7 @@ template<template<typename> class normal_kernel, typename RNG>
Tensor& normal_out_impl(Tensor& output, const Tensor& mean, const Tensor& std, c10::optional<Generator> gen) {
TORCH_CHECK(!std.is_complex(), "normal expects standard deviation to be non-complex");
TORCH_CHECK(
std.min().ge(0).item<bool>(),
std.numel() == 0 || std.min().ge(0).item<bool>(),
"normal expects all elements of std >= 0.0");
bool is_deprecated_th_impl = resize_output_for_normal(output, mean, std);
normal_impl_<normal_kernel, RNG>(output, 0, 1, gen);
Expand Down
4 changes: 4 additions & 0 deletions test/test_tensor_creation_ops.py
Expand Up @@ -3258,6 +3258,10 @@ def helper(self, device, dtype, ptype, t_transform, std_transform):
self.assertEqual(t_transform(r[:, :50]).std(), std_transform(4), atol=0.3, rtol=0)
self.assertEqual(t_transform(r[:, 50:]).std(), std_transform(1), atol=0.2, rtol=0)

# test empty mean/std
out = torch.normal(mean=torch.empty((0, 2)), std=torch.empty((0, 1)))
self.assertEqual(out.size(), torch.Size([0, 2]))

r.fill_(42)
r = torch.normal(2, 3, (100, 100), dtype=dtype, device=device)
self.assertEqual(r.dtype, dtype)
Expand Down

0 comments on commit c3ea586

Please sign in to comment.