Refactor and Fix Some Problems on Autocast #125118

Open
wants to merge 3 commits into
base: gh/fffrog/9/base
Changes from all commits
14 changes: 10 additions & 4 deletions aten/src/ATen/autocast_mode.cpp
@@ -321,19 +321,22 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
KERNEL_CUDA(__VA_ARGS__, lower_precision_fp)

AT_FORALL_LOWER_PRECISION_FP(_KERNEL_CUDA_LOW_PRECISION_FP)
#undef _KERNEL_CUDA_LOW_PRECISION_FP
KERNEL_CUDA(cudnn_convolution, lower_precision_fp)
KERNEL_CUDA(cudnn_convolution_transpose, lower_precision_fp)

// fp32
#define _KERNEL_CUDA_FP32(...) KERNEL_CUDA(__VA_ARGS__, fp32)

AT_FORALL_FP32(_KERNEL_CUDA_FP32)
#undef _KERNEL_CUDA_FP32

// fp32_set_opt_dtype
#define _KERNEL_CUDA_FP32_SET_OPT_DTYPE(...) \
KERNEL_CUDA(__VA_ARGS__, fp32_set_opt_dtype)

AT_FORALL_FP32_SET_OPT_DTYPE(_KERNEL_CUDA_FP32_SET_OPT_DTYPE)
#undef _KERNEL_CUDA_FP32_SET_OPT_DTYPE
// commenting these out because they accept an explicit (not-optional) dtype, and we shouldn't try to flip that even
// when autocasting.
// KERNEL_CUDA(norm, ScalarOpt_dtype, fp32_set_opt_dtype)
@@ -350,9 +353,9 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
#define _KERNEL_CUDA_PROMOTE(...) KERNEL_CUDA(__VA_ARGS__, promote)

AT_FORALL_PROMOTE(_KERNEL_CUDA_PROMOTE)
#undef _KERNEL_CUDA_PROMOTE

m.impl(TORCH_SELECTIVE_NAME("aten::binary_cross_entropy"),
TORCH_FN((&at::autocast::binary_cross_entropy_banned)));
KERNEL_FN("binary_cross_entropy", &at::autocast::binary_cross_entropy_banned)
}

TORCH_LIBRARY_IMPL(_, AutocastCPU, m) {
@@ -507,17 +510,20 @@ TORCH_LIBRARY_IMPL(aten, AutocastXPU, m) {
KERNEL_XPU(__VA_ARGS__, lower_precision_fp)

AT_FORALL_LOWER_PRECISION_FP(_KERNEL_XPU_LOW_PRECISION_FP)
#undef _KERNEL_XPU_LOW_PRECISION_FP

// fp32
#define _KERNEL_XPU_FP32(...) KERNEL_XPU(__VA_ARGS__, fp32)

AT_FORALL_FP32(_KERNEL_XPU_FP32)
#undef _KERNEL_XPU_FP32

// fp32_set_opt_dtype
#define _KERNEL_XPU_FP32_SET_OPT_DTYPE(...) \
KERNEL_XPU(__VA_ARGS__, fp32_set_opt_dtype)

AT_FORALL_FP32_SET_OPT_DTYPE(_KERNEL_XPU_FP32_SET_OPT_DTYPE)
#undef _KERNEL_XPU_FP32_SET_OPT_DTYPE

// fp32_append_dtype
// The fp32_append_dtype wrapper overrides implicit promotion behavior.
@@ -529,9 +535,9 @@ TORCH_LIBRARY_IMPL(aten, AutocastXPU, m) {
#define _KERNEL_XPU_PROMOTE(...) KERNEL_XPU(__VA_ARGS__, promote)

AT_FORALL_PROMOTE(_KERNEL_XPU_PROMOTE)
#undef _KERNEL_XPU_PROMOTE

m.impl(TORCH_SELECTIVE_NAME("aten::binary_cross_entropy"),
TORCH_FN((&at::autocast::binary_cross_entropy_banned)));
KERNEL_FN("binary_cross_entropy", &at::autocast::binary_cross_entropy_banned)
}

} // namespace
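For readers skimming the diff: each AT_FORALL_* list above is an X-macro, and the PR pairs every short-lived _KERNEL_CUDA_* / _KERNEL_XPU_* helper with an #undef right after the expansion so the helper name cannot leak into later registrations. The following is an editor's minimal, self-contained sketch of that shape, using made-up names (FORALL_LOWER_PRECISION_OPS, register_op, Policy) rather than PyTorch's real registration API:

// Editor's sketch of the X-macro registration pattern; all names are hypothetical.
#include <iostream>
#include <string>

enum class Policy { LowerPrecisionFp, Fp32, Promote };

void register_op(const std::string& name, Policy /*policy*/) {
  std::cout << "registered " << name << "\n";
}

// The op list is written once, as a macro that applies another macro to each op.
#define FORALL_LOWER_PRECISION_OPS(_) \
  _(conv2d)                           \
  _(matmul)                           \
  _(linear)

int main() {
  // Each call site defines a short-lived helper, expands the list with it,
  // and #undef's it immediately -- the same shape as the
  // _KERNEL_CUDA_LOW_PRECISION_FP / #undef pairs in the diff above.
#define _REGISTER_LOW_PRECISION_FP(op) register_op(#op, Policy::LowerPrecisionFp);
  FORALL_LOWER_PRECISION_OPS(_REGISTER_LOW_PRECISION_FP)
#undef _REGISTER_LOW_PRECISION_FP
  return 0;
}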
3 changes: 3 additions & 0 deletions aten/src/ATen/autocast_mode.h
@@ -623,6 +623,9 @@ copy pasted in from VariableTypeEverything.cpp with appropriate substitutions.
#define _KERNEL_OVERLOAD_NARG(...) \
C10_EXPAND_MSVC_WORKAROUND(_KERNEL_OVERLOAD_NARG_IMPL(__VA_ARGS__, 2, 1))

#define KERNEL_FN(OP, Function) \
m.impl(TORCH_SELECTIVE_NAME("aten::" OP), TORCH_FN(Function));

// Common cases where registration signature matches redispatch signature
// (that's why SIGNATURE is repeated in the WrapFunction instantiation)
#define KERNEL1(DISPATCHKEY, OP, POLICY) \
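Two pieces of machinery meet in this header: _KERNEL_OVERLOAD_NARG counts whether it was given one or two arguments so the right KERNELn overload can be picked, and the new KERNEL_FN macro is a thin wrapper around the m.impl(TORCH_SELECTIVE_NAME(...), TORCH_FN(...)) call used for banned ops such as binary_cross_entropy. A minimal editor's sketch of the argument-counting trick, with made-up names (NARG, DESCRIBE1/DESCRIBE2), looks like this:

// Editor's sketch of arity-based macro dispatch; names are hypothetical.
#include <cstdio>

#define NARG_IMPL(_1, _2, N, ...) N
#define NARG(...) NARG_IMPL(__VA_ARGS__, 2, 1, 0)  // trailing 0 keeps ... non-empty

#define DESCRIBE1(op) std::printf("op %s, default overload\n", #op)
#define DESCRIBE2(op, overload) std::printf("op %s, overload %s\n", #op, #overload)

// One level of indirection so NARG(...) is expanded before the token paste.
#define DESCRIBE_CHOOSE(N) DESCRIBE##N
#define DESCRIBE_EXPAND(N) DESCRIBE_CHOOSE(N)
#define DESCRIBE(...) DESCRIBE_EXPAND(NARG(__VA_ARGS__))(__VA_ARGS__)

int main() {
  DESCRIBE(add);          // one argument  -> DESCRIBE1(add)
  DESCRIBE(add, Tensor);  // two arguments -> DESCRIBE2(add, Tensor)
  return 0;
}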
43 changes: 19 additions & 24 deletions c10/test/core/DispatchKeySet_test.cpp
@@ -9,6 +9,20 @@

using namespace c10;

static bool isRealDispatchKey(DispatchKey k) {
if (k == DispatchKey::EndOfFunctionalityKeys ||
k == DispatchKey::StartOfDenseBackends ||
k == DispatchKey::StartOfQuantizedBackends ||
k == DispatchKey::StartOfSparseBackends ||
k == DispatchKey::StartOfSparseCsrBackends ||
k == DispatchKey::StartOfNestedTensorBackends ||
k == DispatchKey::StartOfAutogradFunctionalityBackends) {
return false;
}

return true;
}

// This test exists not to be comprehensive, but to more clearly show
// what the semantics of DispatchKeySet are.
TEST(DispatchKeySet, ShowSemantics) {
@@ -179,10 +193,7 @@ TEST(DispatchKeySet, SingletonPerBackendFunctionalityKeys) {
i++) {
auto tid = static_cast<DispatchKey>(i);
// Skip these because they aren't real keys.
if (tid == DispatchKey::StartOfDenseBackends ||
tid == DispatchKey::StartOfSparseBackends ||
tid == DispatchKey::StartOfQuantizedBackends ||
tid == DispatchKey::StartOfAutogradFunctionalityBackends) {
if (!isRealDispatchKey(tid)) {
continue;
}
DispatchKeySet sing(tid);
@@ -221,20 +232,9 @@ TEST(DispatchKeySet, DoubletonPerBackend) {
auto tid2 = static_cast<DispatchKey>(j);

// Skip these because they aren't real keys.
if (tid1 == DispatchKey::StartOfDenseBackends ||
tid1 == DispatchKey::StartOfSparseBackends ||
tid1 == DispatchKey::StartOfSparseCsrBackends ||
tid1 == DispatchKey::StartOfQuantizedBackends ||
tid1 == DispatchKey::StartOfNestedTensorBackends ||
tid1 == DispatchKey::StartOfAutogradFunctionalityBackends)
continue;
if (tid2 == DispatchKey::StartOfDenseBackends ||
tid2 == DispatchKey::StartOfSparseBackends ||
tid2 == DispatchKey::StartOfSparseCsrBackends ||
tid2 == DispatchKey::StartOfQuantizedBackends ||
tid2 == DispatchKey::StartOfNestedTensorBackends ||
tid2 == DispatchKey::StartOfAutogradFunctionalityBackends)
if (!isRealDispatchKey(tid1) || !isRealDispatchKey(tid2)) {
continue;
}

auto backend1 = toBackendComponent(tid1);
auto backend2 = toBackendComponent(tid2);
@@ -421,14 +421,9 @@ TEST(DispatchKeySet, TestFunctionalityDispatchKeyToString) {
auto k = static_cast<DispatchKey>(i);
// These synthetic keys never actually get used and don't need
// to be printed
if (k == DispatchKey::EndOfFunctionalityKeys ||
k == DispatchKey::StartOfDenseBackends ||
k == DispatchKey::StartOfQuantizedBackends ||
k == DispatchKey::StartOfSparseBackends ||
k == DispatchKey::StartOfSparseCsrBackends ||
k == DispatchKey::StartOfNestedTensorBackends ||
k == DispatchKey::StartOfAutogradFunctionalityBackends)
if (!isRealDispatchKey(k)) {
continue;
}
auto res = std::string(toString(k));
ASSERT_TRUE(res.find("Unknown") == std::string::npos)
<< i << " (before is " << toString(static_cast<DispatchKey>(i - 1))
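The refactor above collapses three copies of the same multi-line sentinel check into a single isRealDispatchKey() predicate. As an editor's illustration of that pattern -- iterate an enum range and skip its sentinel values through a helper -- here is a self-contained sketch with a hypothetical Key enum standing in for c10::DispatchKey:

// Editor's sketch: skip sentinel enum values via a predicate helper.
#include <cassert>
#include <cstdint>

enum class Key : std::uint8_t {
  Undefined,
  Dense,
  Sparse,
  StartOfBackendSentinels,  // sentinel, not a real key
  CUDA,
  XPU,
  EndOfKeys                 // sentinel, not a real key
};

static bool isRealKey(Key k) {
  return k != Key::StartOfBackendSentinels && k != Key::EndOfKeys;
}

int main() {
  int real = 0;
  for (std::uint8_t i = static_cast<std::uint8_t>(Key::Undefined) + 1;
       i <= static_cast<std::uint8_t>(Key::EndOfKeys); ++i) {
    auto k = static_cast<Key>(i);
    if (!isRealKey(k)) {
      continue;  // same shape as the !isRealDispatchKey(tid) checks above
    }
    ++real;
  }
  assert(real == 4);  // Dense, Sparse, CUDA, XPU
  return 0;
}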
6 changes: 4 additions & 2 deletions test/test_ops.py
@@ -2350,10 +2350,10 @@ def test_refs_are_in_decomp_table(self, op):
"narrow", # Fails only for one overload with DataDependentOutputException (hence skip).
)

fake_autocast_device_skips = defaultdict(dict)

# TODO: investigate/fix
fake_autocast_device_skips = defaultdict(set)
fake_autocast_device_skips["cpu"] = {"linalg.pinv"}
fake_autocast_device_skips["cuda"] = {"linalg.pinv", "pinverse"}


dynamic_output_op_tests = (
@@ -2594,6 +2594,8 @@ def test_fake(self, device, dtype, op):

@ops(op_db, dtypes=OpDTypes.any_one)
def test_fake_autocast(self, device, dtype, op):
# strip the device index (e.g. "cuda:0" -> "cuda") before the skip-list lookup
device = device.split(":")[0]
if op.name in fake_autocast_device_skips[device]:
self.skipTest("Skip failing test")
context = (