Skip to content

Commit

Permalink
[ONNX] Add test case for onnx::Max scalar type (#88751)
Browse files Browse the repository at this point in the history
Referenced by minimum cases
Pull Request resolved: #88751
Approved by: https://github.com/wschin, https://github.com/BowenBao
  • Loading branch information
titaiwangms authored and pytorchmergebot committed Nov 11, 2022
1 parent 396c3b1 commit b843f4d
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -8728,6 +8728,28 @@ def forward(self, x, y):
y = torch.full_like(x, True)
self.run_test(MinimumModel(), (x, y))

@skipIfUnsupportedMinOpsetVersion(12)
def test_maximum_dtypes(self):
class MaximumModel(torch.nn.Module):
def forward(self, x, y):
return torch.maximum(x, y)

x = torch.randn((5, 5), dtype=torch.float16)
y = torch.randn((5, 5), dtype=torch.float)
self.run_test(MaximumModel(), (x, y))

x = torch.randn((5, 5), dtype=torch.float16)
y = torch.randint(10, (5, 5), dtype=torch.int16)
self.run_test(MaximumModel(), (x, y))

x = torch.randint(10, (5, 5), dtype=torch.int16)
y = torch.randint(10, (5, 5), dtype=torch.int32)
self.run_test(MaximumModel(), (x, y))

x = torch.randint(10, (5, 5), dtype=torch.int)
y = torch.full_like(x, True)
self.run_test(MaximumModel(), (x, y))

@skipIfUnsupportedMinOpsetVersion(9)
def test_any(self):
class M(torch.nn.Module):
Expand Down

0 comments on commit b843f4d

Please sign in to comment.