Refactor and Fix Some Problems on Autocast
1. Fix incorrect test cases
2. Refactor the tests

ghstack-source-id: 9f3685af5dd509cbcd3c2dfd057de4cad9395f47
Pull Request resolved: #125118
FFFrog committed Apr 28, 2024
1 parent 94b328e commit 4e6cc03
Showing 4 changed files with 36 additions and 30 deletions.
14 changes: 10 additions & 4 deletions aten/src/ATen/autocast_mode.cpp
@@ -321,19 +321,22 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
KERNEL_CUDA(__VA_ARGS__, lower_precision_fp)

AT_FORALL_LOWER_PRECISION_FP(_KERNEL_CUDA_LOW_PRECISION_FP)
#undef _KERNEL_CUDA_LOW_PRECISION_FP
KERNEL_CUDA(cudnn_convolution, lower_precision_fp)
KERNEL_CUDA(cudnn_convolution_transpose, lower_precision_fp)

// fp32
#define _KERNEL_CUDA_FP32(...) KERNEL_CUDA(__VA_ARGS__, fp32)

AT_FORALL_FP32(_KERNEL_CUDA_FP32)
#undef _KERNEL_CUDA_FP32

// fp32_set_opt_dtype
#define _KERNEL_CUDA_FP32_SET_OPT_DTYPE(...) \
KERNEL_CUDA(__VA_ARGS__, fp32_set_opt_dtype)

AT_FORALL_FP32_SET_OPT_DTYPE(_KERNEL_CUDA_FP32_SET_OPT_DTYPE)
#undef _KERNEL_CUDA_FP32_SET_OPT_DTYPE
// commenting these out because they accept an explicit (not-optional) dtype, and we shouldn't try to flip that even
// when autocasting.
// KERNEL_CUDA(norm, ScalarOpt_dtype, fp32_set_opt_dtype)
@@ -350,9 +353,9 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
#define _KERNEL_CUDA_PROMOTE(...) KERNEL_CUDA(__VA_ARGS__, promote)

AT_FORALL_PROMOTE(_KERNEL_CUDA_PROMOTE)
#undef _KERNEL_CUDA_PROMOTE

m.impl(TORCH_SELECTIVE_NAME("aten::binary_cross_entropy"),
TORCH_FN((&at::autocast::binary_cross_entropy_banned)));
KERNEL_FN("binary_cross_entropy", &at::autocast::binary_cross_entropy_banned)
}

TORCH_LIBRARY_IMPL(_, AutocastCPU, m) {
@@ -507,17 +510,20 @@ TORCH_LIBRARY_IMPL(aten, AutocastXPU, m) {
KERNEL_XPU(__VA_ARGS__, lower_precision_fp)

AT_FORALL_LOWER_PRECISION_FP(_KERNEL_XPU_LOW_PRECISION_FP)
#undef _KERNEL_XPU_LOW_PRECISION_FP

// fp32
#define _KERNEL_XPU_FP32(...) KERNEL_XPU(__VA_ARGS__, fp32)

AT_FORALL_FP32(_KERNEL_XPU_FP32)
#undef _KERNEL_XPU_FP32

// fp32_set_opt_dtype
#define _KERNEL_XPU_FP32_SET_OPT_DTYPE(...) \
KERNEL_XPU(__VA_ARGS__, fp32_set_opt_dtype)

AT_FORALL_FP32_SET_OPT_DTYPE(_KERNEL_XPU_FP32_SET_OPT_DTYPE)
#undef _KERNEL_XPU_FP32_SET_OPT_DTYPE

// fp32_append_dtype
// The fp32_append_dtype wrapper overrides implicit promotion behavior.
@@ -529,9 +535,9 @@ TORCH_LIBRARY_IMPL(aten, AutocastXPU, m) {
#define _KERNEL_XPU_PROMOTE(...) KERNEL_XPU(__VA_ARGS__, promote)

AT_FORALL_PROMOTE(_KERNEL_XPU_PROMOTE)
#undef _KERNEL_XPU_PROMOTE

m.impl(TORCH_SELECTIVE_NAME("aten::binary_cross_entropy"),
TORCH_FN((&at::autocast::binary_cross_entropy_banned)));
KERNEL_FN("binary_cross_entropy", &at::autocast::binary_cross_entropy_banned)
}

} // namespace
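
The registration hunks above all follow the same pattern: define a short-lived helper macro, expand it over an AT_FORALL_* op list, then (with this commit) #undef it so the helper cannot leak into later registration blocks. A minimal self-contained sketch of that pattern with toy names — FORALL_TOY_OPS and the printf "registration" are placeholders, not the real KERNEL_CUDA / AT_FORALL_* machinery:

```cpp
#include <cstdio>

// Toy stand-in for an AT_FORALL_* op list: applies the given macro to each op.
#define FORALL_TOY_OPS(_) _(add) _(mul) _(matmul)

void registerLowerPrecisionOps() {
  // Short-lived helper macro, analogous to _KERNEL_CUDA_LOW_PRECISION_FP.
#define _REGISTER_LOW_PRECISION_FP(op) \
  std::printf("register aten::%s -> lower_precision_fp\n", #op);
  FORALL_TOY_OPS(_REGISTER_LOW_PRECISION_FP)
#undef _REGISTER_LOW_PRECISION_FP  // scope the helper to this block, as the commit does
}

int main() {
  registerLowerPrecisionOps();
  return 0;
}
```
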
3 changes: 3 additions & 0 deletions aten/src/ATen/autocast_mode.h
@@ -623,6 +623,9 @@ copy pasted in from VariableTypeEverything.cpp with appropriate substitutions.
#define _KERNEL_OVERLOAD_NARG(...) \
C10_EXPAND_MSVC_WORKAROUND(_KERNEL_OVERLOAD_NARG_IMPL(__VA_ARGS__, 2, 1))

#define KERNEL_FN(OP, Function) \
m.impl(TORCH_SELECTIVE_NAME("aten::" OP), TORCH_FN(Function));

// Common cases where registration signature matches redispatch signature
// (that's why SIGNATURE is repeated in the WrapFunction instantiation)
#define KERNEL1(DISPATCHKEY, OP, POLICY) \
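
The KERNEL_FN helper added above is purely textual, so its effect can be illustrated with stubs. A minimal sketch of the expansion, where MockModule, the stubbed TORCH_SELECTIVE_NAME/TORCH_FN, and banned_kernel are placeholders rather than the real dispatcher types:

```cpp
#include <cstdio>

// Stubs standing in for the real macros so the textual expansion is visible;
// the real ones wrap the name for selective build and the function pointer
// for compile-time dispatch.
#define TORCH_SELECTIVE_NAME(name) name
#define TORCH_FN(fn) fn

struct MockModule {
  void impl(const char* op, void (*fn)()) {
    std::printf("registered %s\n", op);  // prints "registered aten::binary_cross_entropy"
    (void)fn;
  }
};

// Same definition as the one added in autocast_mode.h above.
#define KERNEL_FN(OP, Function) \
  m.impl(TORCH_SELECTIVE_NAME("aten::" OP), TORCH_FN(Function));

void banned_kernel() {}  // stand-in for at::autocast::binary_cross_entropy_banned

int main() {
  MockModule m;
  // Expands to: m.impl("aten::" "binary_cross_entropy", &banned_kernel);
  // adjacent string literals concatenate to "aten::binary_cross_entropy".
  KERNEL_FN("binary_cross_entropy", &banned_kernel)
  return 0;
}
```
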
43 changes: 19 additions & 24 deletions c10/test/core/DispatchKeySet_test.cpp
@@ -9,6 +9,20 @@

using namespace c10;

static bool isRealDispatchKey(DispatchKey k) {
if (k == DispatchKey::EndOfFunctionalityKeys ||
k == DispatchKey::StartOfDenseBackends ||
k == DispatchKey::StartOfQuantizedBackends ||
k == DispatchKey::StartOfSparseBackends ||
k == DispatchKey::StartOfSparseCsrBackends ||
k == DispatchKey::StartOfNestedTensorBackends ||
k == DispatchKey::StartOfAutogradFunctionalityBackends) {
return false;
}

return true;
}

// This test exists not to be comprehensive, but to more clearly show
// what the semantics of DispatchKeySet are.
TEST(DispatchKeySet, ShowSemantics) {
@@ -179,10 +193,7 @@ TEST(DispatchKeySet, SingletonPerBackendFunctionalityKeys) {
i++) {
auto tid = static_cast<DispatchKey>(i);
// Skip these because they aren't real keys.
if (tid == DispatchKey::StartOfDenseBackends ||
tid == DispatchKey::StartOfSparseBackends ||
tid == DispatchKey::StartOfQuantizedBackends ||
tid == DispatchKey::StartOfAutogradFunctionalityBackends) {
if (!isRealDispatchKey(tid)) {
continue;
}
DispatchKeySet sing(tid);
@@ -221,20 +232,9 @@ TEST(DispatchKeySet, DoubletonPerBackend) {
auto tid2 = static_cast<DispatchKey>(j);

// Skip these because they aren't real keys.
if (tid1 == DispatchKey::StartOfDenseBackends ||
tid1 == DispatchKey::StartOfSparseBackends ||
tid1 == DispatchKey::StartOfSparseCsrBackends ||
tid1 == DispatchKey::StartOfQuantizedBackends ||
tid1 == DispatchKey::StartOfNestedTensorBackends ||
tid1 == DispatchKey::StartOfAutogradFunctionalityBackends)
continue;
if (tid2 == DispatchKey::StartOfDenseBackends ||
tid2 == DispatchKey::StartOfSparseBackends ||
tid2 == DispatchKey::StartOfSparseCsrBackends ||
tid2 == DispatchKey::StartOfQuantizedBackends ||
tid2 == DispatchKey::StartOfNestedTensorBackends ||
tid2 == DispatchKey::StartOfAutogradFunctionalityBackends)
if (!isRealDispatchKey(tid1) || !isRealDispatchKey(tid2)) {
continue;
}

auto backend1 = toBackendComponent(tid1);
auto backend2 = toBackendComponent(tid2);
@@ -421,14 +421,9 @@ TEST(DispatchKeySet, TestFunctionalityDispatchKeyToString) {
auto k = static_cast<DispatchKey>(i);
// These synthetic keys never actually get used and don't need
// to be printed
if (k == DispatchKey::EndOfFunctionalityKeys ||
k == DispatchKey::StartOfDenseBackends ||
k == DispatchKey::StartOfQuantizedBackends ||
k == DispatchKey::StartOfSparseBackends ||
k == DispatchKey::StartOfSparseCsrBackends ||
k == DispatchKey::StartOfNestedTensorBackends ||
k == DispatchKey::StartOfAutogradFunctionalityBackends)
if (!isRealDispatchKey(k)) {
continue;
}
auto res = std::string(toString(k));
ASSERT_TRUE(res.find("Unknown") == std::string::npos)
<< i << " (before is " << toString(static_cast<DispatchKey>(i - 1))
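
A minimal sketch of the loop shape the three tests above share after the refactor, assuming the file's static isRealDispatchKey helper is in scope; the counting body stands in for the per-test assertions:

```cpp
#include <c10/core/DispatchKeySet.h>

#include <cstdint>

using namespace c10;

// Walk every functionality key, skipping the synthetic range markers
// (StartOf*Backends, EndOfFunctionalityKeys) via the shared helper instead of
// repeating the long || chain in every test.
int countRealFunctionalityKeys() {
  int real = 0;
  for (uint8_t i = 0;
       i <= static_cast<uint8_t>(DispatchKey::EndOfFunctionalityKeys);
       i++) {
    auto k = static_cast<DispatchKey>(i);
    if (!isRealDispatchKey(k)) {
      continue;  // placeholder keys are never registered or printed
    }
    real++;  // a test body would construct DispatchKeySet(k) and assert here
  }
  return real;
}
```
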
6 changes: 4 additions & 2 deletions test/test_ops.py
@@ -2350,10 +2350,10 @@ def test_refs_are_in_decomp_table(self, op):
"narrow", # Fails only for one overload with DataDependentOutputException (hence skip).
)

fake_autocast_device_skips = defaultdict(dict)

# TODO: investigate/fix
fake_autocast_device_skips = defaultdict(set)
fake_autocast_device_skips["cpu"] = {"linalg.pinv"}
fake_autocast_device_skips["cuda"] = {"linalg.pinv", "pinverse"}


dynamic_output_op_tests = (
@@ -2594,6 +2594,8 @@ def test_fake(self, device, dtype, op):

@ops(op_db, dtypes=OpDTypes.any_one)
def test_fake_autocast(self, device, dtype, op):
# remove the index from the device, first
device = device.split(":")[0]
if op.name in fake_autocast_device_skips[device]:
self.skipTest("Skip failing test")
context = (
