Skip to content

Commit

Permalink
Refactor and Fix Some Problems on Autocast
Browse files Browse the repository at this point in the history
1. Fix the wrong test cases
2. Refactor tests

ghstack-source-id: 9d3687b064a8f17112e5ad2cb394761e9f8c4273
Pull Request resolved: #125118
  • Loading branch information
FFFrog committed Apr 28, 2024
1 parent ce503c1 commit 73df5bd
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 28 deletions.
14 changes: 10 additions & 4 deletions aten/src/ATen/autocast_mode.cpp
Expand Up @@ -321,19 +321,22 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
KERNEL_CUDA(__VA_ARGS__, lower_precision_fp)

AT_FORALL_LOWER_PRECISION_FP(_KERNEL_CUDA_LOW_PRECISION_FP)
#undef _KERNEL_CUDA_LOW_PRECISION_FP
KERNEL_CUDA(cudnn_convolution, lower_precision_fp)
KERNEL_CUDA(cudnn_convolution_transpose, lower_precision_fp)

// fp32
#define _KERNEL_CUDA_FP32(...) KERNEL_CUDA(__VA_ARGS__, fp32)

AT_FORALL_FP32(_KERNEL_CUDA_FP32)
#undef _KERNEL_CUDA_FP32

// fp32_set_opt_dtype
#define _KERNEL_CUDA_FP32_SET_OPT_DTYPE(...) \
KERNEL_CUDA(__VA_ARGS__, fp32_set_opt_dtype)

AT_FORALL_FP32_SET_OPT_DTYPE(_KERNEL_CUDA_FP32_SET_OPT_DTYPE)
#undef _KERNEL_CUDA_FP32_SET_OPT_DTYPE
// commenting these out because they accept an explicit (not-optional) dtype, and we shouldn't try to flip that even
// when autocasting.
// KERNEL_CUDA(norm, ScalarOpt_dtype, fp32_set_opt_dtype)
Expand All @@ -350,9 +353,9 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
#define _KERNEL_CUDA_PROMOTE(...) KERNEL_CUDA(__VA_ARGS__, promote)

AT_FORALL_PROMOTE(_KERNEL_CUDA_PROMOTE)
#undef _KERNEL_CUDA_PROMOTE

m.impl(TORCH_SELECTIVE_NAME("aten::binary_cross_entropy"),
TORCH_FN((&at::autocast::binary_cross_entropy_banned)));
KERNEL_FN("binary_cross_entropy", &at::autocast::binary_cross_entropy_banned)
}

TORCH_LIBRARY_IMPL(_, AutocastCPU, m) {
Expand Down Expand Up @@ -507,17 +510,20 @@ TORCH_LIBRARY_IMPL(aten, AutocastXPU, m) {
KERNEL_XPU(__VA_ARGS__, lower_precision_fp)

AT_FORALL_LOWER_PRECISION_FP(_KERNEL_XPU_LOW_PRECISION_FP)
#undef _KERNEL_XPU_LOW_PRECISION_FP

// fp32
#define _KERNEL_XPU_FP32(...) KERNEL_XPU(__VA_ARGS__, fp32)

AT_FORALL_FP32(_KERNEL_XPU_FP32)
#undef _KERNEL_XPU_FP32

// fp32_set_opt_dtype
#define _KERNEL_XPU_FP32_SET_OPT_DTYPE(...) \
KERNEL_XPU(__VA_ARGS__, fp32_set_opt_dtype)

AT_FORALL_FP32_SET_OPT_DTYPE(_KERNEL_XPU_FP32_SET_OPT_DTYPE)
#undef _KERNEL_XPU_FP32_SET_OPT_DTYPE

// fp32_append_dtype
// The fp32_append_dtype wrapper overrides implicit promotion behavior.
Expand All @@ -529,9 +535,9 @@ TORCH_LIBRARY_IMPL(aten, AutocastXPU, m) {
#define _KERNEL_XPU_PROMOTE(...) KERNEL_XPU(__VA_ARGS__, promote)

AT_FORALL_PROMOTE(_KERNEL_XPU_PROMOTE)
#undef _KERNEL_XPU_PROMOTE

m.impl(TORCH_SELECTIVE_NAME("aten::binary_cross_entropy"),
TORCH_FN((&at::autocast::binary_cross_entropy_banned)));
KERNEL_FN("binary_cross_entropy", &at::autocast::binary_cross_entropy_banned)
}

} // namespace
Expand Down
3 changes: 3 additions & 0 deletions aten/src/ATen/autocast_mode.h
Expand Up @@ -623,6 +623,9 @@ copy pasted in from VariableTypeEverything.cpp with appropriate substitutions.
#define _KERNEL_OVERLOAD_NARG(...) \
C10_EXPAND_MSVC_WORKAROUND(_KERNEL_OVERLOAD_NARG_IMPL(__VA_ARGS__, 2, 1))

#define KERNEL_FN(OP, Function) \
m.impl(TORCH_SELECTIVE_NAME("aten::" OP), TORCH_FN(Function));

// Common cases where registration signature matches redispatch signature
// (that's why SIGNATURE is repeated in the WrapFunction instantiation)
#define KERNEL1(DISPATCHKEY, OP, POLICY) \
Expand Down
43 changes: 19 additions & 24 deletions c10/test/core/DispatchKeySet_test.cpp
Expand Up @@ -9,6 +9,20 @@

using namespace c10;

static bool isRealDispatchKey(DispatchKey k) {
if (k == DispatchKey::EndOfFunctionalityKeys ||
k == DispatchKey::StartOfDenseBackends ||
k == DispatchKey::StartOfQuantizedBackends ||
k == DispatchKey::StartOfSparseBackends ||
k == DispatchKey::StartOfSparseCsrBackends ||
k == DispatchKey::StartOfNestedTensorBackends ||
k == DispatchKey::StartOfAutogradFunctionalityBackends) {
return false;
}

return true;
}

// This test exists not to be comprehensive, but to more clearly show
// what the semantics of DispatchKeySet are.
TEST(DispatchKeySet, ShowSemantics) {
Expand Down Expand Up @@ -179,10 +193,7 @@ TEST(DispatchKeySet, SingletonPerBackendFunctionalityKeys) {
i++) {
auto tid = static_cast<DispatchKey>(i);
// Skip these because they aren't real keys.
if (tid == DispatchKey::StartOfDenseBackends ||
tid == DispatchKey::StartOfSparseBackends ||
tid == DispatchKey::StartOfQuantizedBackends ||
tid == DispatchKey::StartOfAutogradFunctionalityBackends) {
if (!isRealDispatchKey(tid)) {
continue;
}
DispatchKeySet sing(tid);
Expand Down Expand Up @@ -221,20 +232,9 @@ TEST(DispatchKeySet, DoubletonPerBackend) {
auto tid2 = static_cast<DispatchKey>(j);

// Skip these because they aren't real keys.
if (tid1 == DispatchKey::StartOfDenseBackends ||
tid1 == DispatchKey::StartOfSparseBackends ||
tid1 == DispatchKey::StartOfSparseCsrBackends ||
tid1 == DispatchKey::StartOfQuantizedBackends ||
tid1 == DispatchKey::StartOfNestedTensorBackends ||
tid1 == DispatchKey::StartOfAutogradFunctionalityBackends)
continue;
if (tid2 == DispatchKey::StartOfDenseBackends ||
tid2 == DispatchKey::StartOfSparseBackends ||
tid2 == DispatchKey::StartOfSparseCsrBackends ||
tid2 == DispatchKey::StartOfQuantizedBackends ||
tid2 == DispatchKey::StartOfNestedTensorBackends ||
tid2 == DispatchKey::StartOfAutogradFunctionalityBackends)
if (!isRealDispatchKey(tid1) || !isRealDispatchKey(tid2)) {
continue;
}

auto backend1 = toBackendComponent(tid1);
auto backend2 = toBackendComponent(tid2);
Expand Down Expand Up @@ -421,14 +421,9 @@ TEST(DispatchKeySet, TestFunctionalityDispatchKeyToString) {
auto k = static_cast<DispatchKey>(i);
// These synthetic keys never actually get used and don't need
// to be printed
if (k == DispatchKey::EndOfFunctionalityKeys ||
k == DispatchKey::StartOfDenseBackends ||
k == DispatchKey::StartOfQuantizedBackends ||
k == DispatchKey::StartOfSparseBackends ||
k == DispatchKey::StartOfSparseCsrBackends ||
k == DispatchKey::StartOfNestedTensorBackends ||
k == DispatchKey::StartOfAutogradFunctionalityBackends)
if (!isRealDispatchKey(k)) {
continue;
}
auto res = std::string(toString(k));
ASSERT_TRUE(res.find("Unknown") == std::string::npos)
<< i << " (before is " << toString(static_cast<DispatchKey>(i - 1))
Expand Down
2 changes: 2 additions & 0 deletions test/test_ops.py
Expand Up @@ -2594,6 +2594,8 @@ def test_fake(self, device, dtype, op):

@ops(op_db, dtypes=OpDTypes.any_one)
def test_fake_autocast(self, device, dtype, op):
# remove the index from the device, first
device = device.split(":")[0]
if op.name in fake_autocast_device_skips[device]:
self.skipTest("Skip failing test")
context = (
Expand Down

0 comments on commit 73df5bd

Please sign in to comment.