Skip to content

Commit

Permalink
Enable some PyTorch core tests with inductor (pytorch#87490)
Browse files Browse the repository at this point in the history
Summary:
1) Graph break on torch.random.set_rng_state since it blocks running
inductor core tests;
2) Add several inductor-specific skips;
3) Enable several core tests for inductor CI;

cc @jansel @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305
Pull Request resolved: pytorch#87490
Approved by: https://github.com/eellison
  • Loading branch information
desertfire authored and kulinseth committed Dec 9, 2022
1 parent 7455967 commit c65a40d
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 10 deletions.
11 changes: 4 additions & 7 deletions .jenkins/pytorch/test.sh
Expand Up @@ -251,13 +251,10 @@ test_dynamo_shard() {


test_inductor() {
echo "TODO: enable inductor unit tests"
# time python test/run_test.py --core --exclude test_autograd --continue-through-error --verbose

# PYTORCH_TEST_WITH_DYNAMO and PYTORCH_TEST_WITH_INDUCTOR are only needed for PyTorch tests not written with
# using dynamo/inductor. For dynamo/inductor unit tests, specifying them will trigger an error like
# "Detected two calls to `torchdynamo.optimize(...)` with a different backend compiler arguments."
# PYTORCH_TEST_WITH_DYNAMO=0 PYTORCH_TEST_WITH_INDUCTOR=0 pytest test/inductor
python test/test_modules.py --verbose
# TODO: investigate "RuntimeError: CUDA driver API confirmed a leak"
# seen in test_ops_gradients.py
# pytest test/test_ops_gradients.py --verbose -k "not _complex and not test_inplace_grad_acos_cuda_float64"
}

test_inductor_huggingface_shard() {
Expand Down
2 changes: 2 additions & 0 deletions test/dynamo/test_repros.py
Expand Up @@ -1016,6 +1016,8 @@ def test_create_rand_mask_from_inputs(self):
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(cnt.op_count, 8)

# TODO: make set_rng_state work with FakeTensor/aot_autograd
@patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
def test_rng_state(self):
def fn():
state = torch.get_rng_state()
Expand Down
6 changes: 5 additions & 1 deletion test/test_modules.py
Expand Up @@ -11,7 +11,8 @@
instantiate_device_type_tests, onlyCUDA, toleranceOverride, tol, skipMeta)
from torch.testing._internal.common_modules import module_db, modules, TrainEvalMode
from torch.testing._internal.common_utils import (
TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck, gradgradcheck, skipIfMps)
TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck,
gradgradcheck, skipIfMps, skipIfTorchInductor)
from unittest.mock import patch, call


Expand Down Expand Up @@ -326,6 +327,7 @@ def inner_zero_grad(obj):

@skipIfMps
@modules(module_db)
@skipIfTorchInductor("to be fixed")
def test_non_contiguous_tensors(self, device, dtype, module_info, training):
# Check modules work with non-contiguous tensors

Expand Down Expand Up @@ -489,6 +491,7 @@ def test_gradgrad(self, device, dtype, module_info, training):
@toleranceOverride({torch.float32: tol(5e-2, 0),
torch.float64: tol(4e-4, 0)})
@modules(module_db)
@skipIfTorchInductor("to be fixed")
def test_cpu_gpu_parity(self, device, dtype, module_info, training):
# TODO: RNN / GRU / LSTM don't support backwards on eval mode for cuDNN; skip this in a
# nicer way for eval mode only.
Expand Down Expand Up @@ -579,6 +582,7 @@ def check_backward(cpu_output, gpu_output):

@skipIfMps
@modules(module_db)
@skipIfTorchInductor("to be fixed")
def test_memory_format(self, device, dtype, module_info, training):
is_sm86 = device.startswith("cuda") and torch.cuda.get_device_capability(0) == (8, 6)
# TODO tighten it to a specific module
Expand Down
6 changes: 6 additions & 0 deletions test/test_ops.py
Expand Up @@ -36,6 +36,7 @@
first_sample,
parametrize,
skipIfSlowGradcheckEnv,
skipIfTorchInductor,
slowTest,
)
from torch.testing._internal.common_methods_invocations import (
Expand Down Expand Up @@ -209,6 +210,7 @@ def to_cpu(arg):
@unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
@onlyNativeDeviceTypes
@ops(python_ref_db)
@skipIfTorchInductor("Takes too long for inductor")
def test_python_ref_meta(self, device, dtype, op):
with FakeTensorMode() as mode:
pass
Expand Down Expand Up @@ -374,6 +376,7 @@ def _distance(a, b):
@unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
@onlyNativeDeviceTypes
@ops(python_ref_db)
@skipIfTorchInductor("Takes too long for inductor")
def test_python_ref(self, device, dtype, op):
# In this test, primTorch refs call into the refs namespace
# For example, a ref with torch.foo in it will call refs.foo instead
Expand All @@ -386,6 +389,7 @@ def test_python_ref(self, device, dtype, op):
@unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
@onlyNativeDeviceTypes
@ops(python_ref_db)
@skipIfTorchInductor("Takes too long for inductor")
def test_python_ref_torch_fallback(self, device, dtype, op):
# In this test, refs call into the torch namespace (after the initial invocation)
# For example, a ref with torch.foo in it will call torch.foo instead of refs.foo
Expand All @@ -397,6 +401,7 @@ def test_python_ref_torch_fallback(self, device, dtype, op):
@skipCUDAIfRocm
@ops(python_ref_db)
@parametrize('executor', ['aten', 'nvfuser'])
@skipIfTorchInductor("Takes too long for inductor")
def test_python_ref_executor(self, device, dtype, op, executor):
# TODO: Not all dtypes are supported with nvfuser
from torch._prims_common import _torch_dtype_to_nvfuser_dtype_map
Expand Down Expand Up @@ -457,6 +462,7 @@ def test_errors(self, device, op):
@skipMeta
@onlyNativeDeviceTypes
@ops([op for op in python_ref_db if op.error_inputs_func is not None], dtypes=OpDTypes.none)
@skipIfTorchInductor("Takes too long for inductor")
def test_python_ref_errors(self, device, op):
mode = FakeTensorMode()
with mode:
Expand Down
6 changes: 4 additions & 2 deletions test/test_ops_gradients.py
Expand Up @@ -4,8 +4,9 @@
from itertools import chain
import torch

from torch.testing._internal.common_utils import \
(TestCase, is_iterable_of_tensors, run_tests, gradcheck, gradgradcheck, is_slow_gradcheck_env)
from torch.testing._internal.common_utils import (
TestCase, is_iterable_of_tensors, run_tests, gradcheck, gradgradcheck, is_slow_gradcheck_env,
skipIfTorchInductor)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, ops, OpDTypes)
Expand Down Expand Up @@ -253,6 +254,7 @@ def test_forward_mode_AD(self, device, dtype, op):
self._forward_grad_helper(device, dtype, op, op.get_op(), is_inplace=False)

@_gradcheck_ops(op_db)
@skipIfTorchInductor("to be fixed")
def test_inplace_forward_mode_AD(self, device, dtype, op):
self._skip_helper(op, device, dtype)

Expand Down
3 changes: 3 additions & 0 deletions torch/_dynamo/variables/torch.py
Expand Up @@ -320,6 +320,9 @@ def get_state_from_generator():
assert isinstance(args[0], TensorVariable)

if config.fake_tensor_propagation:
unimplemented(
"TODO: make torch.random.set_rng_state work with FakeTensor/aot_autograd"
)
# In fake tensor case, this state doesn't matter, but
# it needs to be valid to not segfault. Pull a real tensor out.
# The value won't matter since we are running with fake tensors anyway, so rng doesn't matter.
Expand Down

0 comments on commit c65a40d

Please sign in to comment.