Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Implement matrix_exp Batching Rule #125086

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 11 additions & 0 deletions aten/src/ATen/functorch/BatchRulesLinearAlgebra.cpp
Expand Up @@ -36,6 +36,15 @@ Tensor vdot_decomp(const Tensor& A, const Tensor& B) {
return at::dot(A.is_complex() ? A.conj() : A, B);
}


// Batch rule for aten::matrix_exp under vmap.
// matrix_exp already handles arbitrary leading batch dimensions, so it is
// enough to move the vmap batch dim to the front and call straight through;
// the result carries its batch dimension at index 0.
oneOutput matrix_exp_batch_rule(const Tensor& self, c10::optional<int64_t> self_bdim) {
  // The op works on the trailing two dims, so the logical (non-batch) rank
  // must be at least 2.
  TORCH_CHECK(rankWithoutBatchDim(self, self_bdim) >= 2,
      "torch.matrix_exp: The input tensor must have at least 2 dimensions.");

  const auto batched = moveBatchDimToFront(self, self_bdim);
  return std::make_tuple(at::matrix_exp(batched), 0);
}

// NB: I wrote this like this because we *might* want its for a future matmul
// batch rule that isn't decomposed...
// "tv" = tensor @ vector
Expand Down Expand Up @@ -593,6 +602,8 @@ LINALG_CHECK_MATRIX_UNARY_FOUR_OUT(_linalg_slogdet, linalg.slogdet);
LINALG_CHECK_MATRIX_UNARY_THREE_OUT(_linalg_svd, linalg.svd);
// NOLINTEND(*array*)

// Register the matrix_exp batch rule for the FuncTorchBatched dispatch key.
// NOTE(review): batch rules must be registered through VMAP_SUPPORT, which
// wraps the rule so the vmap plumbing unpacks/repacks batched tensors; passing
// the raw rule to m.impl would not match the operator's schema
// (the rule takes an extra optional<int64_t> bdim and returns a tuple).
// Ideally this line lives inside the existing TORCH_LIBRARY_IMPL block below
// rather than in a second block for the same key.
TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
  VMAP_SUPPORT(matrix_exp, matrix_exp_batch_rule);
}

TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
VMAP_SUPPORT(bmm, bmm_batch_rule);
m.impl("addmv", addmv_decomp);
Expand Down
21 changes: 21 additions & 0 deletions aten/src/ATen/test/basic.cpp
Expand Up @@ -365,6 +365,27 @@ void test(DeprecatedTypeProperties& type) {
TestIntArrayRefExpansion(type);
}

// Checks that a single batched matrix_exp call on a (10, 3, 3) stack matches
// matrix_exp applied to each 3x3 slice individually.
void TestMatrixExpBatching(DeprecatedTypeProperties& type) {
  // A stack of 10 independent 3x3 matrices.
  auto matrices = at::randn({10, 3, 3}, type.options());

  // Reference: exponentiate each slice on its own.
  std::vector<Tensor> expected_results;
  expected_results.reserve(matrices.size(0));
  // at::Tensor is not a range (no begin()/end()), so index explicitly
  // instead of using a range-for over the tensor.
  for (int64_t i = 0; i < matrices.size(0); ++i) {
    expected_results.push_back(at::matrix_exp(matrices[i]));
  }
  auto expected = at::stack(expected_results);

  // Batched path: one call on the whole stack.
  Tensor actual = at::matrix_exp(matrices);

  // Compare actual results to expected results (rtol = 1e-5, atol = 1e-8).
  ASSERT_TRUE(actual.allclose(expected, 1e-5, 1e-8));
}

// Batched matrix_exp on CPU double tensors.
TEST(MatrixExpTest, BatchedMatrixExp) {
  // Fixed seed so the random inputs are reproducible across runs.
  manual_seed(42);
  // CPU(...) yields a DeprecatedTypeProperties&; pass it straight through.
  TestMatrixExpBatching(CPU(at::kDouble));
}

TEST(BasicTest, BasicTestCPU) {
manual_seed(123);

Expand Down
37 changes: 37 additions & 0 deletions test/test_nn.py
Expand Up @@ -52,6 +52,7 @@
from torch.testing._internal.common_cuda import tf32_on_and_off, tf32_is_not_fp32, tf32_off, tf32_on
from torch.types import _TensorOrTensors
from torch.testing._internal.common_mkldnn import bf32_on_and_off
from torch.nn.functional import deterministic_max_pool3d

AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32()

Expand Down Expand Up @@ -7165,6 +7166,42 @@ def test_preserves_memory_format(self):
nhwc_padded = torch.constant_pad_nd(nhwc_tensor, [1, 2], 0.5)
self.assertTrue(nhwc_padded.is_contiguous(memory_format=torch.channels_last))

class TestDeterministicMaxPool3D(unittest.TestCase):
    """Tests for the optionally-deterministic max_pool3d wrapper."""

    def setUp(self):
        # Rebuild a fixed pooling configuration before every test.
        self.input_tensor = torch.randn(1, 1, 4, 4, 4, dtype=torch.double)
        self.kernel_size = 2
        self.stride = 2
        self.padding = 0
        self.dilation = 1
        self.ceil_mode = False

    def _pool(self, deterministic):
        # Run the op under test with the stored configuration.
        return deterministic_max_pool3d(
            self.input_tensor, self.kernel_size, self.stride, self.padding,
            self.dilation, self.ceil_mode, deterministic=deterministic
        )

    def test_basic_functionality(self):
        # With deterministic=False the wrapper must agree with plain max_pool3d.
        reference = torch.nn.functional.max_pool3d(
            self.input_tensor, self.kernel_size, self.stride, self.padding,
            self.dilation, self.ceil_mode
        )
        self.assertTrue(torch.allclose(self._pool(False), reference),
                        "The deterministic function does not match expected output.")

    def test_deterministic_output(self):
        # deterministic=True must yield identical results on repeated runs.
        first = self._pool(True)
        second = self._pool(True)
        self.assertTrue(torch.allclose(first, second),
                        "Deterministic outputs are not identical on repeated runs.")


class TestAddRelu(TestCase):
def test_add_relu(self):
Expand Down
30 changes: 30 additions & 0 deletions torch/nn/functional.py
Expand Up @@ -856,6 +856,36 @@ def max_pool3d_with_indices(
stride = torch.jit.annotate(List[int], [])
return torch._C._nn.max_pool3d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode)

def deterministic_max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, deterministic=False):
    """
    A version of the max_pool3d operation that can optionally be made deterministic.

    Args:
        input (Tensor): The input tensor.
        kernel_size (int or tuple): Size of the pooling window.
        stride (int or tuple, optional): Stride of the pooling window. Default: None (same as kernel_size).
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0.
        dilation (int or tuple, optional): A parameter that controls the stride of elements in the window. Default: 1.
        ceil_mode (bool, optional): Whether to use ceil instead of floor to compute the output shape. Default: False.
        deterministic (bool, optional): Whether to enforce deterministic behavior. Default: False.

    Returns:
        Tensor: The pooled tensor.
    """
    if not deterministic:
        return torch.nn.functional.max_pool3d(
            input, kernel_size, stride, padding, dilation, ceil_mode
        )

    # Snapshot the caller's global determinism setting so this call does not
    # clobber it (previously the flag was unconditionally reset to False,
    # discarding any mode the caller had enabled, and was never restored if
    # max_pool3d raised).
    previous = torch.are_deterministic_algorithms_enabled()
    torch.use_deterministic_algorithms(True)
    try:
        return torch.nn.functional.max_pool3d(
            input, kernel_size, stride, padding, dilation, ceil_mode
        )
    finally:
        # Restore the previous global setting even on error.
        torch.use_deterministic_algorithms(previous)

def _max_pool3d(
input: Tensor, kernel_size: BroadcastingList3[int],
Expand Down