Speed up _extract_graph_with_inputs_outputs #125937
[ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125937
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (2 Unrelated Failures)
As of commit 2cb95da with merge base 946b96f ():
BROKEN TRUNK - The following jobs failed but were present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes.
ghstack-source-id: 6ab8f1b7e94e977ad9189f4785813a70accf8ee5
Pull Request resolved: #125937
torch/_functorch/partitioners.py
@@ -88,6 +88,9 @@ def _extract_graph_with_inputs_outputs(joint_graph, inputs, outputs):
new_graph = fx.Graph()
env = {}

# Ensure we can ask about our input nodes quickly.
inputs = inputs if isinstance(inputs, set) else set(inputs)
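To illustrate the normalization in this hunk outside of FX, here is a toy sketch: integers stand in for graph nodes, and `count_non_inputs` / `count_non_inputs_list` are hypothetical helpers, not functions from `partitioners.py`. Timings will vary by machine; only the relative gap matters.

```python
import timeit


def count_non_inputs(all_nodes, inputs):
    """Hypothetical stand-in for the loop over joint_graph.nodes."""
    # Same normalization as the diff: convert once, then each membership
    # test is an average O(1) hash lookup instead of a linear list scan.
    inputs = inputs if isinstance(inputs, set) else set(inputs)
    return sum(1 for node in all_nodes if node not in inputs)


def count_non_inputs_list(all_nodes, inputs):
    """Same loop without the normalization (hypothetical, for comparison)."""
    # `node not in inputs` scans the list until a match, or fully on a miss.
    return sum(1 for node in all_nodes if node not in inputs)


nodes = list(range(5_000))
input_nodes = list(range(0, 5_000, 2))  # a plain list of every other node

set_time = timeit.timeit(lambda: count_non_inputs(nodes, input_nodes), number=3)
list_time = timeit.timeit(lambda: count_non_inputs_list(nodes, input_nodes), number=3)
print(f"set: {set_time:.4f}s  list: {list_time:.4f}s")
print(count_non_inputs(nodes, input_nodes))  # 2500
```

The one-time conversion costs O(len(inputs)); with N loop iterations, the list version instead does O(N * len(inputs)) work.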
Since we are iterating over this right after and want a deterministic iteration order, should we instead use `dict.fromkeys` and update whatever is making `inputs` a set in the first place?
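The reviewer's concern is iteration order. FX nodes are assumed here to be hashed by identity, so iterating a `set` of them can yield a different order from run to run, whereas a `dict` preserves insertion order. A minimal sketch with a hypothetical `FakeNode` class standing in for `torch.fx.Node`:

```python
class FakeNode:
    """Hypothetical stand-in for torch.fx.Node, using default identity hashing."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name


a, b, c = FakeNode("a"), FakeNode("b"), FakeNode("c")
inputs_list = [c, a, b, a]  # duplicates and a caller-chosen order

# A set deduplicates and has O(1) membership, but its iteration order depends
# on object hashes (derived from memory addresses here), so it can vary.
as_set = set(inputs_list)

# dict.fromkeys also deduplicates and has O(1) membership, and iteration
# follows first-insertion order (guaranteed for dicts since Python 3.7).
as_ordered = dict.fromkeys(inputs_list)

print(list(as_ordered))  # [c, a, b]
print(b in as_ordered)   # True
```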
Lol - I actually meant to put this line after the iteration (and in my internal version I did; I just copied the change over to the OSS version poorly). But I also realized that it's probably correct to do `node in env` instead.
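A toy sketch of the `node in env` idea, assuming that every input node is mapped into `env` before the loop over the graph's nodes begins (an assumption about the surrounding code, which this page does not show). The function `extract` and the string "nodes" are hypothetical:

```python
def extract(graph_nodes, inputs):
    """Toy version of the copy loop; nodes are plain strings (hypothetical)."""
    env = {}
    copied = []
    # Inputs are mapped first, in the caller's order, like new placeholders.
    for node in inputs:
        env[node] = f"placeholder:{node}"
    for node in graph_nodes:
        # Each graph node is visited once, so anything already in env at this
        # point was put there by the input loop: an O(1) "is it an input?" test.
        if node in env:
            continue
        env[node] = f"copy:{node}"
        copied.append(node)
    return env, copied


env, copied = extract(["x", "y", "mul", "add"], inputs=["y", "x"])
print(copied)     # ['mul', 'add']
print(list(env))  # ['y', 'x', 'mul', 'add']
```

Because an `env` hit can only come from the input pre-pass, the dict doubles as the input set, needs no extra conversion, and keeps the caller's input order.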
_extract_graph_with_inputs_outputs() does membership testing on the input nodes, but that collection is often a list, so each test is O(n). Ensure it's a set before looping over all the nodes.

This change speeds up the internal repro (D57090987) by about 18%:

before:
```
708.88user 15.86system 12:16.19elapsed 98%CPU (0avgtext+0avgdata 12898628maxresident)k
0inputs+91968outputs (3major+3532970minor)pagefaults 0swaps
```

after:
```
583.39user 15.98system 10:10.11elapsed 98%CPU (0avgtext+0avgdata 12895108maxresident)k
0inputs+87488outputs (4major+3374582minor)pagefaults 0swaps
```

[ghstack-poisoned]
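The "about 18%" can be checked against the before/after output quoted in the commit message, which looks like GNU `time`'s default format. A small sketch (the helper `parse_elapsed` is hypothetical) that computes the reduction for user CPU time and wall-clock time:

```python
def parse_elapsed(text):
    """Convert an [h:]mm:ss.ss elapsed field to seconds."""
    seconds = 0.0
    for part in text.split(":"):
        seconds = seconds * 60 + float(part)
    return seconds


before = {"user": 708.88, "elapsed": parse_elapsed("12:16.19")}
after = {"user": 583.39, "elapsed": parse_elapsed("10:10.11")}

for key in ("user", "elapsed"):
    saved = 1 - after[key] / before[key]
    print(f"{key}: {before[key]:.2f}s -> {after[key]:.2f}s ({saved:.1%} less)")
# user: 708.88s -> 583.39s (17.7% less)
# elapsed: 736.19s -> 610.11s (17.1% less)
```

The user-time reduction rounds to the quoted 18%; wall-clock time drops by about 17%.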
ghstack-source-id: 6787796ed5033d8bf22730703fbd547575b1a366
Pull Request resolved: #125937
ghstack-source-id: e8b93cf6f7278eea2e0d04c4c627c2190f508ae0
Pull Request resolved: #125937
Awesome!
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
Nice! I think this is compilation with the cache, right? Because the cold start was taking >6000 sec for me.
That's right - that timing is with the cache. I imagine it will help the cold start time less.
Pull Request resolved: pytorch#125937
Approved by: https://github.com/oulgen, https://github.com/anijain2305