
Allow linalg.lstsq to use SVD to compute the result for rank-deficient matrices. #125110

Closed: wants to merge 237 commits

Changes from 8 commits

Commits (237)
7372645
Add logic for lstsq to be able to use the SVD driver as a backend for…
ZelboK Apr 27, 2024
99e7cfb
Formatting.
ZelboK Apr 28, 2024
e0fec86
run lintrunner -a
ZelboK Apr 28, 2024
bb20952
Update aten/src/ATen/native/BatchLinearAlgebra.cpp
ZelboK Apr 28, 2024
b6d6086
Address comments. Clean up use of zeros and utilize higher level func…
ZelboK Apr 28, 2024
755e7d9
Add test to ensure it will not throw an exception
ZelboK Apr 28, 2024
6e8b3fd
Formatting.
ZelboK Apr 28, 2024
c71e504
Add conditional to ensure only CUDA goes through SVD code path as fal…
ZelboK Apr 28, 2024
d5b0174
Merge branch 'main' into feat-improve-driver-linalg-lstq
ZelboK Apr 29, 2024
da81459
Add tests and fix logic accordingly so behavior is as expected.
ZelboK Apr 29, 2024
c856b9e
Set rank for svd workflow.
ZelboK Apr 29, 2024
de502bc
Update aten/src/ATen/native/BatchLinearAlgebra.cpp
ZelboK Apr 30, 2024
3006f30
Remove conditional around raw_residuals and unnecessary rank.fill
ZelboK May 1, 2024
428f02a
Merge branch 'main' into feat-improve-driver-linalg-lstq
ZelboK May 13, 2024
4eab1c3
Merge branch 'main' into feat-improve-driver-linalg-lstq
ZelboK May 14, 2024
fec9793
Clean up rank so that it uses Spinv and revert change in init.py
ZelboK May 14, 2024
489afbe
lint
ZelboK May 14, 2024
93d2573
[export] handle aliased/unused params for unflattening (#125758)
pianpwk May 14, 2024
4f024c8
Enable epilogue fusion benchmarking internally (#125455)
eellison May 13, 2024
a2e8b90
Fanatically correct real tensor cloning for propagate_real_tensors (#…
ezyang May 14, 2024
10b10f2
[reland][dynamo][disable] Move disable impl to its own __call__ metho…
anijain2305 May 14, 2024
f209865
[easy][dynamo] Use disable_dynamo for torch.manual_seed (#126192)
anijain2305 May 14, 2024
a95b7e9
Revert "[inductor][cpp] GEMM template (infra and fp32) (#124021)"
pytorchmergebot May 14, 2024
50b88b0
Revert "[CUDA] [CI] Add cu124 docker images (#125944)"
pytorchmergebot May 14, 2024
37f84cb
Remove use of USE_C10D (#126120)
briancoutinho May 15, 2024
00b9974
[torch/distributed] Bugfix: wait for all child procs to exit before c…
kiukchung May 15, 2024
1dfe2d1
Allow for trailing 'a' in sm_arch (#126185)
drisspg May 15, 2024
ed27236
[pipelining] Add manual pipeline stage (#126123)
H-Huang May 14, 2024
636ea1c
Refactor make_fx to better support hop subgraph tracing (#125267)
ydwu4 May 14, 2024
a745003
Support trace_subgraph in _MakefxTracer (#125363)
ydwu4 May 14, 2024
976f0f2
[Dynamo] Supports torch._C._is_any_autocast_enabled (#126196)
yanboliang May 15, 2024
b3f0fce
Set dtype when copying empty tensor (#126124)
huydhn May 15, 2024
aa17484
[BE] Abstract out strings to top of file (#125640)
c-p-i-o May 14, 2024
685b207
[Inductor] Flex attention supports dynamic shape (#125994)
yanboliang May 15, 2024
b959b4f
Add missing type uint16, uint32, and uint64 to TensorHash in LTC. (#1…
vanbasten23 May 15, 2024
86d560a
Add some type annotations to python stream and event classes (#126171)
cyyever May 15, 2024
074173b
Support third-party devices emit a range for each autograd operator (…
1274085042 May 15, 2024
d0688dd
[Inductor] Skip test_nll_loss_backward for intel GPU. (#126157)
etaf May 14, 2024
a749763
use statically known instead of suppress guard for ddp stride propaga…
eellison May 14, 2024
11aea9e
Update CUDA out of memory mesage with private pool info (#124673)
isuruf May 15, 2024
32fdb75
Adjust number of repeats when using --warm-start-latency benchmark fl…
masnesral May 10, 2024
38e2661
[benchmarking] Suppress csv creation on cold-start phase of --warm-st…
masnesral May 10, 2024
bc9f57b
Add a few "warm start" smoketest runs to CI (#125955)
masnesral May 10, 2024
ce7a832
[audio hash update] update the pinned audio hash (#126248)
pytorchupdatebot May 15, 2024
3512895
Add force_disable_caches to the docs (#126184)
oulgen May 14, 2024
b743f89
[inductor][cpp] GEMM template (infra and fp32) (#124021)
jgong5 May 15, 2024
170380e
[CUDA] [CI] Add cu124 docker images (#125944)
nWEIdia May 15, 2024
f33cc7a
Don't assert about pending when we are peeking (#126239)
ezyang May 15, 2024
8dc8ae9
[AOTI][torchgen] Update NativeFunctionsGroup mapping (#125962)
desertfire May 11, 2024
5bc525c
[AOTI][torchgen] Add a few more fallback ops (#126013)
desertfire May 13, 2024
4e3dfb0
[Memory Snapshot] Add recordAnnotations to capture record_function an…
aaronenyeshi May 15, 2024
68c29aa
Enable UFMT on `test/test_fake_tensor.py`, `test/test_flop_counter.py…
shink May 15, 2024
cd60801
[Inductor] Generalize new introduced device-bias code. (#126261)
etaf May 15, 2024
a48463e
[export] Cover more cases to copy tensor conversions. (#125628)
zhxchen17 May 15, 2024
ff266cd
Revert "[Memory Snapshot] Add recordAnnotations to capture record_fun…
pytorchmergebot May 15, 2024
e49ccce
[CI] 3 procs non cuda (#125932)
clee2000 May 15, 2024
0df5ed0
Forward fix lint after #125747 (#126295)
clee2000 May 15, 2024
12f2960
Faster int8 quantized (#125704)
malfet May 15, 2024
9b24e7f
[DTensor] Turn on foreach implementation of optimizer for DTensor by …
wz337 May 15, 2024
22c50a3
[Dynamo] SizeVariable supports hasattr (#126222)
yanboliang May 15, 2024
35117bf
CMake: Improve check and report of Magma (#117858)
Flamefire May 15, 2024
3ca1ae4
[onnx.export] Avoid linear loop over symbol_dim_map (#123029)
gustavla May 15, 2024
39b2795
[easy] Remove aot_config from pre_compile returns, rename fw_metadata…
jamesjwu May 14, 2024
b3b9f72
Reland '[Inductor] GEMM shape padding improvements (#118522)' (#125773)
eellison May 14, 2024
0ce75f9
Skip padding cost of fusible/planable inputs (#125780)
eellison May 14, 2024
0f2db1c
Forward fix failures for torch.export switch to predispatch (#126081)
tugsbayasgalan May 15, 2024
6b733b2
Beef up error message for pending assert failure (#126212)
ezyang May 15, 2024
1480537
Enable UFMT format on test/test_utils.py (#125996)
hippocookie May 15, 2024
adf9cc7
Fix aarch64 debug build with GCC (#126290)
malfet May 15, 2024
147ba73
Fix public binding to actually traverse modules (#126103)
albanD May 15, 2024
9397380
[FSDP] Fixed docs for inter/intra node PG helpers (#126288)
awgu May 15, 2024
921a824
Revert "Fix aarch64 debug build with GCC (#126290)"
pytorchmergebot May 15, 2024
0658670
Parametrize test_dim_reduction (#126292)
ezyang May 15, 2024
b5e6220
[DCP] overwrites existing checkpoint by default (#125877)
LucasLLC May 15, 2024
a0a6bbc
Fix public api allowlist logical merge conflict (#126321)
albanD May 15, 2024
910f26f
2 rocm shards on trunk.yml (#125933)
clee2000 May 15, 2024
5b4dea2
[FSDP2] allow meta tensors during loading state dict and cpu offloadi…
weifengpy May 15, 2024
3a3f8a9
[dynamo] Detect monkeypatching on nn module forward method (#126203)
anijain2305 May 15, 2024
569ee1e
[onnx.export] Avoid unnecessary copy of debug_names (#123026)
gustavla May 15, 2024
6243a43
Warn SDPA users about dropout behavior (#126294)
jbschlosser May 15, 2024
9e2b899
Improve Storage copy_ size mismatch error message (#126280)
ezyang May 15, 2024
0d9def0
[CI] Add AMP models in inductor cpu smoketest for performance (#125830)
zxd1997066 May 15, 2024
eb5e9ed
Remove Caffe2 python code (#126035)
cyyever May 15, 2024
05eff35
Enable UFMT on `test/test_datapipe.py` (#124994)
shink May 15, 2024
75add2f
Remove expected failure in `test_eager_transforms.py` (#125883)
eqy May 15, 2024
079d3f5
[optim] Fix: wrong ASGD implementation (#125440)
david20571015 May 15, 2024
0e22566
Fix triton codegen main do_bench_gpu import error (#126213)
adelesun May 15, 2024
8cc3b81
[dynamo] graph break on const dict KeyError (#125882)
williamwen42 May 15, 2024
972f76f
[dynamo] graph break on issubclass call with non-const args (#125943)
williamwen42 May 15, 2024
2524635
[dynamo] fix https://github.com/pytorch/pytorch/issues/93624 (#125945)
williamwen42 May 15, 2024
d3d25a3
[dynamo][inline-inbuilt-nn-modules] Bug fix - Only unspecialized nn m…
anijain2305 May 15, 2024
592bc1f
[FSDP2] support fully_shard(model_on_meta, cpu_offload) (#126305)
weifengpy May 15, 2024
3fe0c6d
Add VariableTracker.debug_repr (#126299)
ezyang May 15, 2024
421b23d
Also remove compile_time_strobelight_meta frame when generating stack…
ezyang May 15, 2024
db73a01
Make propagate_real_tensor more safe (#126281)
ezyang May 15, 2024
fed9d93
Switched from parameter in can_cast to from_. (#126030)
tringwald May 16, 2024
8f9fa47
[easy][dynamo][inline-inbuilt-nn-modules] Change test to check for pa…
anijain2305 May 15, 2024
7a4c6b9
[Export] Allow ExportedProgram to take empty decomp table (#126142)
StellarrZ May 16, 2024
6944593
[optim] add fused_adagrad support for CPU device (#124905)
zhuhaozhe May 15, 2024
7754cc1
[Inductor][Flex-attention] Make num_head support dynamic (#126342)
yanboliang May 16, 2024
009b5b6
[dynamo][inline-inbuilt-nn-modules] Change test to not depend on id o…
anijain2305 May 15, 2024
ae2fdc8
[dynamo][inline-inbuilt-nn-modules] Add and update test_modules.py fo…
anijain2305 May 15, 2024
675c49f
[inductor] [FX graph cache] Ignore unbacked symints in guards express…
masnesral May 15, 2024
930e757
Revert "Switched from parameter in can_cast to from_. (#126030)"
pytorchmergebot May 16, 2024
88643f1
[inductor][cpp] epilogue support for gemm template (#126019)
jgong5 May 14, 2024
4417b4c
[TEST][Dynamo] fix test_deviceguard.py (#126240)
Aidyn-A May 16, 2024
f30d086
Revert "Remove deprecated _aminmax operator (#125995)"
pytorchmergebot May 16, 2024
b8c08b6
[dynamo][nn module guards] Use TENSOR_MATCH, and not ID_MATCH, for nu…
anijain2305 May 15, 2024
479f3f9
[DeviceMesh] Fix hash and eq not match (#123572)
wz337 May 16, 2024
e974908
[inductor][cpp] bf16/fp16 gemm template computed with fp32 w/o epilog…
jgong5 May 14, 2024
a4250cc
Initial implementation of AdaRound (#126153)
kwanghoon-meta May 16, 2024
195d01c
Revert "[optim] Fix: wrong ASGD implementation (#125440)"
pytorchmergebot May 16, 2024
4397921
Revert "Initial implementation of AdaRound (#126153)"
pytorchmergebot May 16, 2024
22db67f
Add Lowering for FlexAttention Backwards (#125515)
drisspg May 16, 2024
8dced59
[dynamo] Delete extra testing of cpp guard manager (#126343)
anijain2305 May 15, 2024
c73f90c
fix the device type for with_comms decorator (#125798)
wanchaol May 15, 2024
9f09eae
Add mode to MemoryDep to track atomic accumulates (#123223)
isuruf May 15, 2024
2ba6d37
[c10d] Add an option for NAN check on every collective (#125726)
shuqiangzhang May 15, 2024
8989a88
Generate runtime asserts when propagate real tensor is used (#126287)
ezyang May 15, 2024
1473472
[ez] fix exported diff mismatch (#126357)
izaitsevfb May 16, 2024
64fb6ed
[Add sliding window attention bias] (#126061)
lvaleriu May 16, 2024
7dab5f7
Fix lint failures coming from #126035 (#126378)
huydhn May 16, 2024
fb2c753
[1/N] Non-Tensor: Scalar Support: Enable aot compile to support aten …
EikanWang May 15, 2024
45d93f9
[Doc] Add deprecated autocast comments for doc (#126062)
guangyey May 15, 2024
75289f2
Revert "Fix lint failures coming from #126035 (#126378)"
pytorchmergebot May 16, 2024
dd2f8d1
Revert "Add Lowering for FlexAttention Backwards (#125515)"
pytorchmergebot May 16, 2024
adc0551
Fix lint failures coming from #126035 (#126378)
huydhn May 16, 2024
64efc14
[Traceable FSDP2] Add all_gather_into_tensor out variant (#126334)
yf225 May 16, 2024
0ddafc0
Fix broken link of scikit-learn (#120972)
yuanx749 May 16, 2024
8288174
[Reopen] Upgrade submodule oneDNN to v3.4.2 (#126137)
Xia-Weiwen May 16, 2024
5dd875a
[FSDP2] Supported `set_all_reduce_gradients=False` for HSDP (#126166)
awgu May 14, 2024
cebb5df
Fix aarch64 debug build with GCC (#126290)
malfet May 16, 2024
60fb3ef
Add distributed/_tensor/test_attention to ROCM_BLOCKLIST (#126336)
jithunnair-amd May 16, 2024
9df7bda
[ROCm] amax hipblaslt integration (#125921)
alugorey May 16, 2024
19dfbce
Add 2nd shard to ROCm trunk workflow for core distributed UTs (#121716)
jithunnair-amd May 16, 2024
7e7392b
[AOTI][torchgen] Support at::Generator via C shim (#126181)
desertfire May 15, 2024
27b7381
[AOTI] Refactor some fallback op util functions (#126182)
desertfire May 15, 2024
d27e21d
[AOTI] Support InplaceBernoulliFallback in the ABI-compatible codegen…
desertfire May 15, 2024
272b119
[AOTI][refactor] Add aoti_torch_item as a util function (#126352)
desertfire May 16, 2024
667af78
[BE][FSDP] Change the logging level to info (#126362)
fegin May 16, 2024
08e5a7e
[BE][FSDP] Remove unnecessary warnings (#126365)
fegin May 16, 2024
2a34465
[onnx.export] Cache SetGraphInputTypeReliable (#124912)
gustavla May 16, 2024
45a699a
Remove redundant serialization code (#126249)
jiashenC May 16, 2024
b24a9e3
[Dynamo] Support SET_UPDATE (#126243)
yanboliang May 16, 2024
a2e563d
xpu: implement xpu serialization (#125530)
dvrogozh May 16, 2024
4c93c7a
Don't install inplace_methods on MockHandler, not needed (#126398)
ezyang May 16, 2024
f1897d4
Make 'pytest test/inductor/test_memory_planning.py' work (#126397)
ezyang May 16, 2024
f4daf9e
Switched from parameter in can_cast to from_. (#126030)
tringwald May 16, 2024
74ad455
[Traceable FSDP2] Use DTensor.from_local() in _from_local_no_grad whe…
yf225 May 16, 2024
255ae5d
Fix strict default value in StateDictOptions (#125998)
shink May 16, 2024
1367209
Print export warning only once in capture_pre_autograd (#126403)
tarun292 May 16, 2024
80798a7
[compiled autograd] Fix LoggingTensor flaky test (#126144)
xmfan May 16, 2024
b2efbae
[inductor] Clear cache on ctx manager exit (#126146)
xmfan May 16, 2024
b29fd1f
[compiled autograd] clear compiled_autograd_verbose once test is done…
xmfan May 16, 2024
19e7924
add 3.12 inductor CI tests (#126218)
williamwen42 May 15, 2024
cd76785
Eliminate some C++11 checks (#126308)
r-barnes May 16, 2024
2b7ac1e
Add prefix option to CapabilityBasedPartitioner (#126382)
hongyang-zhao May 16, 2024
1948225
Import MKL via //third-party/mkl targets (#126371)
MatzeB May 16, 2024
3bbd7fa
[c10d] add pg_name and pg_desc to logger (#126409)
shuqiangzhang May 16, 2024
ac162de
Use object identity for deepcopy memo (#126126)
davidberard98 May 15, 2024
fa207b5
Revert "[inductor][cpp] bf16/fp16 gemm template computed with fp32 w/…
pytorchmergebot May 17, 2024
8f51cf7
Revert "[inductor][cpp] epilogue support for gemm template (#126019)"
pytorchmergebot May 17, 2024
2a6c92a
Revert "[inductor][cpp] GEMM template (infra and fp32) (#124021)"
pytorchmergebot May 17, 2024
7cea4a5
Add Lowering for FlexAttention Backwards (#125515)
drisspg May 17, 2024
814dbc7
Fix documentation for register_fake_class (#126422)
ydwu4 May 16, 2024
65f4d4f
[export] Delete predispatch tests (#126459)
angelayi May 17, 2024
b6d8201
[DeviceMesh] Supported N groups in `from_group` (#126258)
awgu May 16, 2024
b9da19d
[easy] Fix typing for `map_location` docs in torch.load (#125473)
mikaylagawarecki May 16, 2024
22b4b22
[doc] expose torch.Tensor.xpu API to doc (#126383)
guangyey May 16, 2024
6fc8524
Add symbolic_shape_specialization structured trace (#126450)
ezyang May 16, 2024
54ce306
Make inductor scheduler graph extension configurable (#125578)
AlexDenisov May 17, 2024
0f31e61
[FSDP2][Test] Fix _test_clip_grad_norm (#126457)
wz337 May 17, 2024
2cbbe21
dont pad 0 dim mm inputs (#126475)
eellison May 16, 2024
a05c0fa
c10d: add Collectives abstraction (#125978)
d4l3k May 17, 2024
ae7ee03
Add dist_pp shortcut to TORCH_LOGS (#126322)
wconstab May 17, 2024
c61bdbf
[dtensor] refactor view ops to use OpStrategy (#126011)
tianyu-l May 16, 2024
b1770bd
[XPU] call empty_cache for dynamo tests (#126377)
Stonepia May 17, 2024
c271827
Refactor partitioner and clean it up (#126318)
Chillee May 16, 2024
99190da
[DTensor] Turn on foreach implementation for clip_grad_norm_ for DTen…
wz337 May 17, 2024
8221d3d
Fix cummax and cummin lowering for empty case (#126461)
isuruf May 16, 2024
747bdea
[Quant][Inductor] Enable lowering of qlinear-binary(-unary) fusion fo…
Xia-Weiwen May 17, 2024
f55c0cc
variable search spaces for gemm autotuning (#126220)
nmacchioni May 17, 2024
ae3c9ca
save the reciprocal of weights for welford_reduce (#125148)
CaoE May 17, 2024
9882241
[Submodule] Remove zstd dependency (#126485)
cyyever May 17, 2024
7263893
Update ops handler documentation some more (#126480)
ezyang May 17, 2024
9a47caa
[FSDP2] Fixed 2D clip grad norm test (#126497)
awgu May 17, 2024
be7b65a
Default to env variable instead of config value for precompile parall…
eellison May 16, 2024
3f1ccfd
Delete refactored function, move changes over (#126407)
jamesjwu May 16, 2024
e1a0676
[optim] Fix: wrong ASGD implementation (#126375)
david20571015 May 17, 2024
0be8b0f
Early return in _recursive_build if obj is a Tensor (#125639)
guilhermeleobas May 16, 2024
bd10ff6
Remove removed ruff rule TRY200 (#126256)
ringohoffman May 17, 2024
e24f7b3
[Perf] Vectorize more dtype for int4mm (#126512)
malfet May 17, 2024
bb5e037
[inductor] fix unbacked case in pointwise + reduction vertical fusion…
ColinPeppler May 16, 2024
45a8ba4
Workflow for uploading additional test stats on workflow dispatch (#1…
clee2000 May 17, 2024
d0d2d0b
Allow tensor subclasses and add `torch.serialization.add_safe_globals…
mikaylagawarecki May 17, 2024
39f5adb
Enable FX graph cache for huggingface and timm benchmarks (#126205)
masnesral May 16, 2024
218756f
[quant][pt2e] Allow multi users without output observers (#126487)
andrewor14 May 17, 2024
45a3349
Add coms metadata to execution trace (ET) (#126317)
briancoutinho May 17, 2024
2f044a8
Revert "Remove redundant serialization code (#126249)"
pytorchmergebot May 17, 2024
02bf7e2
Revert "Fix aarch64 debug build with GCC (#126290)"
pytorchmergebot May 17, 2024
b2aff20
Initial implementation of AdaRound (#126153)
kwanghoon-meta May 17, 2024
782792b
[distributed] Add cpp-httplib to pytorch (#126470)
PaliC May 17, 2024
5182e2e
[BE][Ez]: Use NotADirectoryError in tensorboard writer (#126534)
Skylion007 May 17, 2024
c81bf77
Revert "[FSDP2] Fixed 2D clip grad norm test (#126497)"
pytorchmergebot May 17, 2024
6372770
[ROCm] enable faster_load_save for Fused_SGD (#125456)
petrex May 17, 2024
04c3751
Experimental prototype for converting torch.jit.trace modules to expo…
tugsbayasgalan May 16, 2024
a1245dd
Disable vulkan test batch_norm_invalid_inputs (#126571)
clee2000 May 17, 2024
68a6cdd
[AOTI] config target platform (#126306)
manuelcandales May 17, 2024
a6235d0
Fix issue of lowering nn.linear ops with kwargs (#126331)
yihanhemeta May 17, 2024
6e4ed6c
[inductor] Load python modules using importlib (#126454)
amjames May 16, 2024
edbd215
[dynamo] Sourceless builder - ordered dict and re.pattern (#126468)
anijain2305 May 17, 2024
6708519
Added error checks for invalid inputs on thnn_conv2d (#121906)
Martim03 May 17, 2024
38a85b2
Fix aarch64 debug build with GCC (#126290)
malfet May 17, 2024
8ab08f9
Remove dist_ prefix from TORCH_LOGS shortcuts (#126499)
wconstab May 17, 2024
bd786d8
Tool for scouting exportability in one shot (#126471)
SherlockNoMad May 18, 2024
4de26b7
[torch-distributed] Make log directory creation idempotent (#126496)
ktsiam May 18, 2024
fbf8018
[AOTI] Flag to include aoti sources when building lite interpreter (#…
manuelcandales May 18, 2024
492ef49
[Pipelining] Fix 1f1b schedule (#126419)
wconstab May 17, 2024
b6caa15
[C10D] Add __repr__ to P2POp class (#126538)
wconstab May 17, 2024
d288e44
gitmodules: switch cpp-httplib to https (#126580)
d4l3k May 18, 2024
68ff312
[pipelining] Follow improvements in export.unflatten (#126217)
kwen2501 May 14, 2024
743df86
[Submodule] Remove third-party CUB (#126540)
cyyever May 18, 2024
deb6f3f
[halide-backend] Refactor codegen/triton.py into codegen/simd.py (#12…
jansel May 17, 2024
8a7f719
Faster(?) FP16 gemv kernel (#126297)
swolchok May 15, 2024
b51e6dd
[2/N] Non-Tensor: Scalar Support: Add scalar to the cache for eager-t…
EikanWang May 17, 2024
23b6ebd
Map float8 types to uint8 for allgather (#126556)
drisspg May 18, 2024
b4a2288
[Traceable FSDP2] Change from register_multi_grad_hook to per-tensor …
yf225 May 18, 2024
b10f3dd
[Dynamo] Treat integers stored on nn.Modules as dynamic (#126466)
yanboliang May 18, 2024
d82bbb0
Refactor variables / function names related to non-strict export (#12…
jiashenC May 18, 2024
6ed6142
Updated test_torch.py to use new OptimizerInfo infrastructure (#125538)
gambiTarun May 18, 2024
0e59bd4
Forward fix the failed new test from D57474327 (#126596)
huydhn May 18, 2024
367a0c5
Cached required_fw_nodes creation (#126613)
Chillee May 18, 2024
0ac2cec
Revert "[Dynamo] Treat integers stored on nn.Modules as dynamic (#126…
pytorchmergebot May 19, 2024
197ebc5
Remove unnecessary implementations from MockHandler (#126511)
ezyang May 17, 2024
2d65795
UFMT torch.utils._sympy.functions (#126553)
ezyang May 19, 2024
0d1108c
Update hf_BirdBird periodic-dynamo-benchmarks results (#126414)
xmfan May 17, 2024
454d0d4
Replace torch.library.impl_abstract with torch.library.register_fake …
cyyever May 19, 2024
25 changes: 20 additions & 5 deletions aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -117,6 +117,8 @@
 #include <ATen/ops/triu.h>
 #include <ATen/ops/vdot.h>
 #include <ATen/ops/zeros.h>
+#include <ATen/ops/matmul.h>
+#include <ATen/ops/narrow.h>
 #endif
 
 // First the required LAPACK implementations are registered here.
@@ -1556,7 +1558,7 @@ void _linalg_check_errors(
         ": The algorithm failed to converge because the input matrix is ill-conditioned or has too many repeated eigenvalues (error code: ", info, ").");
   } else if (api_name.find("lstsq") != api_name.npos) {
     TORCH_CHECK_LINALG(false, api_name, batch_str,
-        ": The least squares solution could not be computed because the input matrix does not have full rank (error code: ", info, ").");
+        ": The least squares solution could not be computed because the input matrix does not have full rank (error code: ", info, "). Specify the 'gelss' driver to compute the solution via SVD for rank-deficient inputs.");
   } else if (api_name.find("lu_factor") != api_name.npos) {
     TORCH_CHECK(false, api_name, batch_str,
         ": U[", info, ",", info, "] is zero and using it on lu_solve would result in a division by zero. "
@@ -3427,8 +3429,21 @@ static void linalg_lstsq_out_info(
   auto input_working_copy = copyBatchedColumnMajor(input);
 
   // now the actual call that computes the result in-place (apply_lstsq)
-  lstsq_stub(input.device().type(), input_working_copy, solution, rank, singular_values, infos, rcond, driver);
-
+  if (driver == "gelss" && input.device() != at::kCPU) {
+    auto [U, S, Vh] = at::_linalg_svd(input, false, true, "gesvd");
+    auto S_pinv = S.reciprocal();
+    auto s1 = at::narrow(S, /*dim=*/-1, /*start=*/0, /*length=*/1); // singular values are sorted in descending order
+    S_pinv.masked_fill_(S < rcond * s1, 0);
+    auto uhOther = at::matmul(U.adjoint(), other);
+    if (S_pinv.dim() != uhOther.dim()) {
+      S_pinv = S_pinv.unsqueeze(-1);
+    }
+    auto S_pinv_other = S_pinv * uhOther;
+    solution = at::matmul(Vh.adjoint(), S_pinv_other);
+  }
+  else {
+    lstsq_stub(input.device().type(), input_working_copy, solution, rank, singular_values, infos, rcond, driver);
+  }
   // residuals are available only if m > n and drivers other than gelsy used
   if (m > n && driver != "gelsy") {
     // if the driver is gelss or gelsd then the residuals are available only if rank == n
@@ -3490,8 +3505,8 @@ static std::string get_default_lstsq_driver(c10::optional<c10::string_view> driv
     );
   } else { // else if (input.is_cuda())
     TORCH_CHECK(
-        driver_str == "gels",
-        "torch.linalg.lstsq: `driver` other than `gels` is not supported on CUDA"
+        (driver_str == "gelss" || driver_str == "gels"),
+        "torch.linalg.lstsq: `driver` other than `gels` or `gelss` is not supported on CUDA"
     );
   }
 } else {
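For orientation, here is a rough Python sketch of the computation the new `gelss` branch above performs: apply the SVD-based pseudoinverse of A to B, zeroing singular values below rcond times the largest one, which yields the minimum-norm least-squares solution even for rank-deficient A. The helper name and the sample system are illustrative only, not part of this PR.

    import torch

    def lstsq_via_svd(A, B, rcond):
        # Thin SVD: A = U @ diag(S) @ Vh
        U, S, Vh = torch.linalg.svd(A, full_matrices=False)
        # Reciprocal of singular values, zeroing those below rcond * sigma_max;
        # this is what makes the solve well-defined for rank-deficient A.
        s1 = S.narrow(-1, 0, 1)  # singular values are sorted in descending order
        S_pinv = S.reciprocal().masked_fill(S < rcond * s1, 0)
        # Minimum-norm solution: X = V @ diag(S_pinv) @ U^H @ B
        UhB = U.mH @ B
        if S_pinv.dim() != UhB.dim():
            S_pinv = S_pinv.unsqueeze(-1)
        return Vh.mH @ (S_pinv * UhB)

    A = torch.tensor([[1., 2.], [2., 4.], [3., 6.]], dtype=torch.float64)  # rank 1
    B = torch.tensor([[1.], [2.], [3.]], dtype=torch.float64)
    print(lstsq_via_svd(A, B, rcond=1e-10))  # approx. [[0.2], [0.4]]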
10 changes: 9 additions & 1 deletion test/test_linalg.py
@@ -445,7 +445,15 @@ def complement_device(device):
         b = torch.rand(2, 2, 2, dtype=dtype, device=device)
 
         if device != 'cpu':
-            with self.assertRaisesRegex(RuntimeError, '`driver` other than `gels` is not supported on CUDA'):
+            try:
+                result = torch.linalg.lstsq(a, b, driver='gelss')
+                self.assertIsNotNone(result)
+            except Exception as e:
+                self.fail(f"Unexpected error occurred: {e}")
+            with self.assertRaisesRegex(
+                    RuntimeError,
+                    'torch.linalg.lstsq: `driver` other than `gels` or `gelss` is not supported on CUDA'
+            ):
                 torch.linalg.lstsq(a, b, driver='fictitious_driver')
         # if on cpu
         else:
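A sketch of the behavior the updated test asserts, usable as a quick manual check on a CUDA build with this PR applied (without the patch, the `'gelss'` call itself raises):

    import torch

    if torch.cuda.is_available():
        a = torch.rand(2, 2, 2, device='cuda')
        b = torch.rand(2, 2, 2, device='cuda')
        # With this PR, 'gelss' is accepted on CUDA and routes through the SVD path.
        sol = torch.linalg.lstsq(a, b, driver='gelss')
        print(sol.solution.shape)  # torch.Size([2, 2, 2])
        # An unknown driver still fails fast with the updated message.
        try:
            torch.linalg.lstsq(a, b, driver='fictitious_driver')
        except RuntimeError as e:
            print(e)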
4 changes: 2 additions & 2 deletions torch/linalg/__init__.py
@@ -1032,7 +1032,7 @@
   - `'gelsd'` (tridiagonal reduction and SVD)
   - But if you run into memory issues: `'gelss'` (full SVD).
 
-For CUDA input, the only valid driver is `'gels'`, which assumes that :attr:`A` is full-rank.
+For CUDA inputs, two drivers are available: `'gels'`, which assumes that :attr:`A` is full-rank, and `'gelss'` (full SVD), which supports rank-deficient :attr:`A`.
 
 See also the `full description of these drivers`_
 
@@ -1080,7 +1080,7 @@
 
 Keyword args:
     driver (str, optional): name of the LAPACK/MAGMA method to be used.
-        If `None`, `'gelsy'` is used for CPU inputs and `'gels'` for CUDA inputs.
+        If `None`, `'gelsy'` is used for CPU inputs and `'gels'` for CUDA inputs; `'gelss'` may also be specified explicitly for CUDA inputs.
         Default: `None`.
 
 Returns:
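To make the documented defaults concrete, a small illustrative snippet (the `'gelss'` call runs on CPU today; on CUDA it requires this PR):

    import torch

    A = torch.randn(5, 3, dtype=torch.float64)
    B = torch.randn(5, 2, dtype=torch.float64)
    # driver=None resolves to 'gelsy' on CPU (and 'gels' on CUDA).
    X_default = torch.linalg.lstsq(A, B).solution
    # 'gelss' computes the solution via a full SVD and tolerates rank-deficient A.
    X_svd = torch.linalg.lstsq(A, B, driver='gelss').solution
    print(torch.allclose(X_default, X_svd))  # True for this full-rank A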