Move autocast op list to autocast_mode.h to make sure other backends can reuse it. (#125114)

This PR refactors the op list added in #124051 so that other backends can reuse it.
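
These op lists are X-macros: each `_(op[, overload])` entry expands through whatever per-op macro the consumer passes in, so a backend can generate its own per-op registrations (or any other per-op table) without copying the list. Below is a minimal standalone sketch of the mechanism; OP_LIST and MAKE_ENTRY are illustrative names for this sketch only, not names from this PR.

```cpp
// Standalone illustration of the X-macro pattern used by the autocast op
// lists. OP_LIST and MAKE_ENTRY are made-up names for this sketch only.
#include <iostream>

// Entries may carry just an op name, or an op name plus an overload name,
// as in AT_FORALL_LOWER_PRECISION_FP / AT_FORALL_FP32 below.
#define OP_LIST(_)      \
  _(conv2d)             \
  _(matmul)             \
  _(pow, Tensor_Scalar)

// Variadic consumer: __VA_ARGS__ absorbs the optional overload name.
// (Common compilers accept an empty __VA_ARGS__; C++20 makes it standard.)
#define MAKE_ENTRY(op, ...) \
  std::cout << "register aten::" #op " {" #__VA_ARGS__ "}\n";

int main() {
  OP_LIST(MAKE_ENTRY) // expands to one statement per entry in the list
  return 0;
}
```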

Pull Request resolved: #125114
Approved by: https://github.com/albanD
PHLens authored and pytorchmergebot committed May 6, 2024
1 parent 2a42c40 commit ad9a27f
Showing 2 changed files with 155 additions and 153 deletions.
153 changes: 0 additions & 153 deletions aten/src/ATen/autocast_mode.cpp
@@ -158,159 +158,6 @@ namespace {
Explicit registration for out-of-place ops
*****************************************/

#define AT_FORALL_LOWER_PRECISION_FP(_) \
_(_convolution, deprecated) \
_(_convolution) \
_(conv1d) \
_(conv2d) \
_(conv3d) \
_(conv_tbc) \
_(conv_transpose1d) \
_(conv_transpose2d, input) \
_(conv_transpose3d, input) \
_(convolution) \
_(prelu) \
_(addmm) \
_(addmv) \
_(addr) \
_(matmul) \
_(einsum) \
_(mm) \
_(mv) \
_(linalg_vecdot) \
_(linear) \
_(addbmm) \
_(baddbmm) \
_(bmm) \
_(chain_matmul) \
_(linalg_multi_dot) \
_(_thnn_fused_lstm_cell) \
_(_thnn_fused_gru_cell) \
_(lstm_cell) \
_(gru_cell) \
_(rnn_tanh_cell) \
_(rnn_relu_cell) \
_(_scaled_dot_product_flash_attention) \
_(scaled_dot_product_attention)

#define AT_FORALL_FP32(_) \
_(acos) \
_(asin) \
_(cosh) \
_(erfinv) \
_(exp) \
_(expm1) \
_(log) \
_(log10) \
_(log2) \
_(log1p) \
_(reciprocal) \
_(rsqrt) \
_(sinh) \
_(tan) \
_(pow, Tensor_Scalar) \
_(pow, Tensor_Tensor) \
_(pow, Scalar) \
_(softplus) \
_(layer_norm) \
_(native_layer_norm) \
_(group_norm) \
_(frobenius_norm, dim) \
_(nuclear_norm) \
_(nuclear_norm, dim) \
_(cosine_similarity) \
_(poisson_nll_loss) \
_(cosine_embedding_loss) \
_(nll_loss) \
_(nll_loss2d) \
_(hinge_embedding_loss) \
_(kl_div) \
_(l1_loss) \
_(smooth_l1_loss) \
_(huber_loss) \
_(mse_loss) \
_(margin_ranking_loss) \
_(multilabel_margin_loss) \
_(soft_margin_loss) \
_(triplet_margin_loss) \
_(multi_margin_loss) \
_(binary_cross_entropy_with_logits) \
_(dist) \
_(pdist) \
_(cdist) \
_(renorm) \
_(logsumexp) \
_(upsample_nearest1d) \
_(_upsample_nearest_exact1d) \
_(upsample_nearest2d) \
_(_upsample_nearest_exact2d) \
_(upsample_nearest3d) \
_(_upsample_nearest_exact3d) \
_(upsample_linear1d) \
_(upsample_bilinear2d) \
_(_upsample_bilinear2d_aa) \
_(upsample_trilinear3d) \
_(upsample_bicubic2d) \
_(_upsample_bicubic2d_aa)

#define AT_FORALL_FP32_SET_OPT_DTYPE(_) \
_(prod) \
_(prod, dim_int) \
_(prod, dim_Dimname) \
_(softmax, int) \
_(softmax, Dimname) \
_(log_softmax, int) \
_(log_softmax, Dimname) \
_(cumprod) \
_(cumprod, dimname) \
_(cumsum) \
_(cumsum, dimname) \
_(linalg_vector_norm) \
_(linalg_matrix_norm) \
_(linalg_matrix_norm, str_ord) \
_(sum) \
_(sum, dim_IntList) \
_(sum, dim_DimnameList)

#define AT_FORALL_DIFFERENT_REDISPATCH_SIGNATURE(_) \
_(ADD_NS(norm), \
"norm.Scalar", \
Tensor(const Tensor&, const Scalar&), \
Tensor(const Tensor&, const c10::optional<Scalar>&, ScalarType), \
fp32_append_dtype) \
_(ADD_NS(norm), \
"norm.ScalarOpt_dim", \
Tensor(const Tensor&, const c10::optional<Scalar>&, IntArrayRef, bool), \
Tensor( \
const Tensor&, \
const c10::optional<Scalar>&, \
IntArrayRef, \
bool, \
ScalarType), \
fp32_append_dtype) \
_(ADD_NS(norm), \
"norm.names_ScalarOpt_dim", \
Tensor(const Tensor&, const c10::optional<Scalar>&, DimnameList, bool), \
Tensor( \
const Tensor&, \
const c10::optional<Scalar>&, \
DimnameList, \
bool, \
ScalarType), \
fp32_append_dtype)

#define AT_FORALL_PROMOTE(_) \
_(addcdiv) \
_(addcmul) \
_(atan2) \
_(bilinear) \
_(cross) \
_(dot) \
_(grid_sampler) \
_(index_put) \
_(tensordot) \
_(scatter_add)

TORCH_LIBRARY_IMPL(_, Autocast, m) {
m.fallback(torch::CppFunction::makeFallthrough());
}
155 changes: 155 additions & 0 deletions aten/src/ATen/autocast_mode.h
@@ -744,3 +744,158 @@ copy pasted in from VariableTypeEverything.cpp with appropriate substitutions.
REGISTER_SIGNATURE, \
REDISPATCH_SIGNATURE, \
POLICY)

// Op lists for different policies.
// Defined here so that other backends can reuse them.
#define AT_FORALL_LOWER_PRECISION_FP(_) \
_(_convolution, deprecated) \
_(_convolution) \
_(conv1d) \
_(conv2d) \
_(conv3d) \
_(conv_tbc) \
_(conv_transpose1d) \
_(conv_transpose2d, input) \
_(conv_transpose3d, input) \
_(convolution) \
_(prelu) \
_(addmm) \
_(addmv) \
_(addr) \
_(matmul) \
_(einsum) \
_(mm) \
_(mv) \
_(linalg_vecdot) \
_(linear) \
_(addbmm) \
_(baddbmm) \
_(bmm) \
_(chain_matmul) \
_(linalg_multi_dot) \
_(_thnn_fused_lstm_cell) \
_(_thnn_fused_gru_cell) \
_(lstm_cell) \
_(gru_cell) \
_(rnn_tanh_cell) \
_(rnn_relu_cell) \
_(_scaled_dot_product_flash_attention) \
_(scaled_dot_product_attention)

#define AT_FORALL_FP32(_) \
_(acos) \
_(asin) \
_(cosh) \
_(erfinv) \
_(exp) \
_(expm1) \
_(log) \
_(log10) \
_(log2) \
_(log1p) \
_(reciprocal) \
_(rsqrt) \
_(sinh) \
_(tan) \
_(pow, Tensor_Scalar) \
_(pow, Tensor_Tensor) \
_(pow, Scalar) \
_(softplus) \
_(layer_norm) \
_(native_layer_norm) \
_(group_norm) \
_(frobenius_norm, dim) \
_(nuclear_norm) \
_(nuclear_norm, dim) \
_(cosine_similarity) \
_(poisson_nll_loss) \
_(cosine_embedding_loss) \
_(nll_loss) \
_(nll_loss2d) \
_(hinge_embedding_loss) \
_(kl_div) \
_(l1_loss) \
_(smooth_l1_loss) \
_(huber_loss) \
_(mse_loss) \
_(margin_ranking_loss) \
_(multilabel_margin_loss) \
_(soft_margin_loss) \
_(triplet_margin_loss) \
_(multi_margin_loss) \
_(binary_cross_entropy_with_logits) \
_(dist) \
_(pdist) \
_(cdist) \
_(renorm) \
_(logsumexp) \
_(upsample_nearest1d) \
_(_upsample_nearest_exact1d) \
_(upsample_nearest2d) \
_(_upsample_nearest_exact2d) \
_(upsample_nearest3d) \
_(_upsample_nearest_exact3d) \
_(upsample_linear1d) \
_(upsample_bilinear2d) \
_(_upsample_bilinear2d_aa) \
_(upsample_trilinear3d) \
_(upsample_bicubic2d) \
_(_upsample_bicubic2d_aa)

#define AT_FORALL_FP32_SET_OPT_DTYPE(_) \
_(prod) \
_(prod, dim_int) \
_(prod, dim_Dimname) \
_(softmax, int) \
_(softmax, Dimname) \
_(log_softmax, int) \
_(log_softmax, Dimname) \
_(cumprod) \
_(cumprod, dimname) \
_(cumsum) \
_(cumsum, dimname) \
_(linalg_vector_norm) \
_(linalg_matrix_norm) \
_(linalg_matrix_norm, str_ord) \
_(sum) \
_(sum, dim_IntList) \
_(sum, dim_DimnameList)

#define AT_FORALL_DIFFERENT_REDISPATCH_SIGNATURE(_) \
_(ADD_NS(norm), \
"norm.Scalar", \
Tensor(const Tensor&, const Scalar&), \
Tensor(const Tensor&, const c10::optional<Scalar>&, ScalarType), \
fp32_append_dtype) \
_(ADD_NS(norm), \
"norm.ScalarOpt_dim", \
Tensor(const Tensor&, const c10::optional<Scalar>&, IntArrayRef, bool), \
Tensor( \
const Tensor&, \
const c10::optional<Scalar>&, \
IntArrayRef, \
bool, \
ScalarType), \
fp32_append_dtype) \
_(ADD_NS(norm), \
"norm.names_ScalarOpt_dim", \
Tensor(const Tensor&, const c10::optional<Scalar>&, DimnameList, bool), \
Tensor( \
const Tensor&, \
const c10::optional<Scalar>&, \
DimnameList, \
bool, \
ScalarType), \
fp32_append_dtype)

#define AT_FORALL_PROMOTE(_) \
_(addcdiv) \
_(addcmul) \
_(atan2) \
_(bilinear) \
_(cross) \
_(dot) \
_(grid_sampler) \
_(index_put) \
_(tensordot) \
_(scatter_add)
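
For reference, the in-tree autocast registrations consume these lists by passing in a per-op macro that forwards to the KERNEL helpers defined earlier in autocast_mode.h. A rough sketch of how an out-of-tree backend could do the same is below; the KERNEL_PRIVATEUSEONE helper, the AutocastPrivateUse1 dispatch key, and the argument order are assumptions drawn from the existing autocast machinery, not contents of this diff.

```cpp
// Hedged sketch (not part of this PR): a backend reusing the shared op list
// to register its lower-precision autocast kernels. Assumes a
// KERNEL_PRIVATEUSEONE(op[, overload], cast_policy) helper and the
// AutocastPrivateUse1 dispatch key behave as in the existing autocast code.
#include <ATen/autocast_mode.h>
#include <torch/library.h>

#define MY_BACKEND_LOWER_PRECISION_FP(...) \
  KERNEL_PRIVATEUSEONE(__VA_ARGS__, lower_precision_fp)

TORCH_LIBRARY_IMPL(aten, AutocastPrivateUse1, m) {
  // One registration per op in the shared list.
  AT_FORALL_LOWER_PRECISION_FP(MY_BACKEND_LOWER_PRECISION_FP)
}
```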
