Skip to content

Commit

Permalink
[dtensor] fix CI and update op db (pytorch#571)
Browse files Browse the repository at this point in the history
* [dtensor] fix CI and update op db

* add xfails

* add zeros
  • Loading branch information
wanchaol authored and wz337 committed Oct 24, 2022
1 parent 508a091 commit 498c773
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 23 deletions.
58 changes: 37 additions & 21 deletions spmd/testing/dtensor_lagging_op_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torch.testing._internal.common_methods_invocations import op_db, OpInfo

# Generated from test/gen_dtensor_op_db.py via
# python spmd/test/gen_dtensor_lagging_op_db.py > spmd/test/dtensor_lagging_op_db.py
# python spmd/testing/gen_dtensor_lagging_op_db.py > spmd/testing/dtensor_lagging_op_db.py
#
# This approach is copied from functorch:
# People add new OpInfos to PyTorch all the time.
Expand All @@ -28,25 +28,6 @@
("__rpow__", ""),
("__rsub__", ""),
("__rxor__", ""),
("masked.amax", ""),
("masked.amin", ""),
("masked.argmax", ""),
("masked.argmin", ""),
("masked.cumprod", ""),
("masked.cumsum", ""),
("masked.log_softmax", ""),
("masked.logaddexp", ""),
("masked.logsumexp", ""),
("masked.mean", ""),
("masked.median", ""),
("masked.norm", ""),
("masked.normalize", ""),
("masked.prod", ""),
("masked.softmax", ""),
("masked.softmin", ""),
("masked.std", ""),
("masked.sum", ""),
("masked.var", ""),
("abs", ""),
("acos", ""),
("acosh", ""),
Expand Down Expand Up @@ -101,7 +82,9 @@
("cartesian_prod", ""),
("cat", ""),
("cdist", ""),
("cdouble", ""),
("ceil", ""),
("cfloat", ""),
("chalf", ""),
("char", ""),
("cholesky", ""),
Expand Down Expand Up @@ -293,7 +276,7 @@
("log1p", ""),
("log2", ""),
("log_softmax", ""),
("log_softmax", "dtype"),
("log_softmax", "with_dtype"),
("logaddexp", ""),
("logaddexp2", ""),
("logcumsumexp", ""),
Expand All @@ -312,6 +295,25 @@
("lu_unpack", ""),
("mH", ""),
("mT", ""),
("masked.amax", ""),
("masked.amin", ""),
("masked.argmax", ""),
("masked.argmin", ""),
("masked.cumprod", ""),
("masked.cumsum", ""),
("masked.log_softmax", ""),
("masked.logaddexp", ""),
("masked.logsumexp", ""),
("masked.mean", ""),
("masked.median", ""),
("masked.norm", ""),
("masked.normalize", ""),
("masked.prod", ""),
("masked.softmax", ""),
("masked.softmin", ""),
("masked.std", ""),
("masked.sum", ""),
("masked.var", ""),
("masked_fill", ""),
("masked_scatter", ""),
("masked_select", ""),
Expand Down Expand Up @@ -345,6 +347,8 @@
("nanquantile", ""),
("nansum", ""),
("narrow", ""),
("narrow_copy", ""),
("native_batch_norm", ""),
("native_layer_norm", ""),
("ne", ""),
("neg", ""),
Expand All @@ -354,6 +358,7 @@
("new_ones", ""),
("new_zeros", ""),
("nextafter", ""),
("nn.functional._scaled_dot_product_attention", ""),
("nn.functional.adaptive_avg_pool1d", ""),
("nn.functional.adaptive_avg_pool2d", ""),
("nn.functional.adaptive_avg_pool3d", ""),
Expand Down Expand Up @@ -468,6 +473,7 @@
("norm", "nuc"),
("normal", ""),
("normal", "number_mean"),
("ones", ""),
("ones_like", ""),
("ormqr", ""),
("outer", ""),
Expand All @@ -488,7 +494,9 @@
("quantile", ""),
("rad2deg", ""),
("rand_like", ""),
("randint", ""),
("randint_like", ""),
("randn", ""),
("randn_like", ""),
("ravel", ""),
("real", ""),
Expand Down Expand Up @@ -527,10 +535,14 @@
("short", ""),
("sigmoid", ""),
("sign", ""),
("signal.windows.cosine", ""),
("signal.windows.exponential", ""),
("signal.windows.gaussian", ""),
("signbit", ""),
("sin", ""),
("sinc", ""),
("sinh", ""),
("slice", ""),
("slice_scatter", ""),
("softmax", ""),
("softmax", "with_dtype"),
Expand Down Expand Up @@ -595,6 +607,7 @@
("tensor_split", ""),
("tensordot", ""),
("tile", ""),
("to", ""),
("to_sparse", ""),
("topk", ""),
("trace", ""),
Expand All @@ -611,6 +624,8 @@
("unbind", ""),
("unflatten", ""),
("unfold", ""),
("unfold_copy", ""),
("uniform", ""),
("unique", ""),
("unique_consecutive", ""),
("unsqueeze", ""),
Expand All @@ -626,6 +641,7 @@
("where", ""),
("xlogy", ""),
("zero_", ""),
("zeros", ""),
("zeros_like", ""),
}

Expand Down
2 changes: 1 addition & 1 deletion spmd/testing/gen_dtensor_lagging_op_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def deindent(code: str) -> str:
from typing import List
from torch.testing._internal.common_methods_invocations import op_db, OpInfo
# Generated from test/gen_dtensor_op_db.py via
# python spmd/test/gen_dtensor_lagging_op_db.py > spmd/test/dtensor_lagging_op_db.py
# python spmd/testing/gen_dtensor_lagging_op_db.py > spmd/testing/dtensor_lagging_op_db.py
#
# This approach is copied from functorch:
# People add new OpInfos to PyTorch all the time.
Expand Down
14 changes: 13 additions & 1 deletion test/spmd/tensor/test_dtensor_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def wrapped(fn):
xfail("linalg.vector_norm"),
xfail("linspace"),
xfail("log_softmax"),
xfail("log_softmax", "dtype"),
xfail("log_softmax", "with_dtype"),
xfail("logcumsumexp"),
xfail("logdet"),
xfail("logical_not"),
Expand Down Expand Up @@ -314,7 +314,9 @@ def wrapped(fn):
xfail("nanmedian"),
xfail("nanquantile"),
xfail("nansum"),
xfail("native_batch_norm"),
xfail("native_layer_norm"),
xfail("narrow_copy"),
xfail("ne"),
xfail("new_empty"),
xfail("new_empty_strided"),
Expand Down Expand Up @@ -419,6 +421,7 @@ def wrapped(fn):
xfail("normal"),
xfail("normal", "number_mean"),
xfail("ormqr"),
xfail("ones"),
xfail("pca_lowrank"),
xfail("pinverse"),
xfail("polar"),
Expand All @@ -428,6 +431,8 @@ def wrapped(fn):
xfail("rad2deg"),
xfail("rand_like"),
xfail("randint_like"),
xfail("randint"),
xfail("randn"),
xfail("randn_like"),
xfail("renorm"),
xfail("repeat_interleave"),
Expand Down Expand Up @@ -478,6 +483,9 @@ def wrapped(fn):
xfail("split"),
xfail("split", "list_args"),
xfail("split_with_sizes"),
xfail("signal.windows.cosine"),
xfail("signal.windows.exponential"),
xfail("signal.windows.gaussian"),
xfail("squeeze"),
xfail("stack"),
xfail("std"),
Expand All @@ -500,6 +508,8 @@ def wrapped(fn):
xfail("triu"),
xfail("unbind"),
xfail("unfold"),
xfail("unfold_copy"),
xfail("uniform"),
xfail("unflatten"),
xfail("unique_consecutive"),
xfail("unique"),
Expand All @@ -508,6 +518,7 @@ def wrapped(fn):
xfail("vdot"),
xfail("view_as_complex"),
xfail("vstack"),
xfail("zeros"),
# ops inside this might even fail without dtensor
# tests, as we rescale op db common test size factor (i.e. L, M, S)
# which triggered the original function run failures with input
Expand All @@ -518,6 +529,7 @@ def wrapped(fn):
skip("__rmatmul__"),
skip("meshgrid", "list_of_tensors"),
skip("meshgrid", "variadic_tensors"),
skip("nn.functional._scaled_dot_product_attention"),
skip("nn.functional.softmin"),
skip("nn.functional.embedding"),
skip("nn.functional.embedding_bag"),
Expand Down

0 comments on commit 498c773

Please sign in to comment.