Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ops): add accessor APIs tree_flatten_with_accessor and PyTreeSpec.accessors #108

Merged
merged 59 commits into from
May 10, 2024

Conversation

XuehaiPan
Copy link
Member

@XuehaiPan XuehaiPan commented Nov 29, 2023

Description

Describe your changes in detail.

This PR adds a new method accessors to PyTreeSpec objects.

The path to a leaf node is a tuple of entries, where an entry can be a raw index or dictionary key.

The newly added accessor will associate the corresponding node type to the entry. An accessor to a leaf node is a tuple of triples (entry, type, kind). For example:

In [1]: from collections import *

In [2]: MyTuple = namedtuple('MyTuple', ['x', 'y', 'z'])

In [3]: tree = OrderedDict([('a', (1, 2)), ('b', {'c': [3, 4], 'd': 5}), ('e', MyTuple(6, 7, 8))])

In [4]: import optree

In [5]: treespec = optree.tree_structure(tree)

In [6]: treespec
Out[6]: PyTreeSpec(OrderedDict([('a', (*, *)), ('b', {'c': [*, *], 'd': *}), ('e', MyTuple(x=*, y=*, z=*))]))

In [7]: treespec.paths()
Out[7]: [('a', 0), ('a', 1), ('b', 'c', 0), ('b', 'c', 1), ('b', 'd'), ('e', 0), ('e', 1), ('e', 2)]

In [8]: treespec.accessors()
Out[8]: 
[
    PyTreeAccessor(*['a'][0], (MappingEntry(key='a', type=<class 'collections.OrderedDict'>), SequenceEntry(index=0, type=<class 'tuple'>))),
    PyTreeAccessor(*['a'][1], (MappingEntry(key='a', type=<class 'collections.OrderedDict'>), SequenceEntry(index=1, type=<class 'tuple'>))),
    PyTreeAccessor(*['b']['c'][0], (MappingEntry(key='b', type=<class 'collections.OrderedDict'>), MappingEntry(key='c', type=<class 'dict'>), SequenceEntry(index=0, type=<class 'list'>))),
    PyTreeAccessor(*['b']['c'][1], (MappingEntry(key='b', type=<class 'collections.OrderedDict'>), MappingEntry(key='c', type=<class 'dict'>), SequenceEntry(index=1, type=<class 'list'>))),
    PyTreeAccessor(*['b']['d'], (MappingEntry(key='b', type=<class 'collections.OrderedDict'>), MappingEntry(key='d', type=<class 'dict'>))),
    PyTreeAccessor(*['e'].x, (MappingEntry(key='e', type=<class 'collections.OrderedDict'>), NamedTupleEntry(field=0, type=<class '__main__.MyTuple'>))),
    PyTreeAccessor(*['e'].y, (MappingEntry(key='e', type=<class 'collections.OrderedDict'>), NamedTupleEntry(field=1, type=<class '__main__.MyTuple'>))),
    PyTreeAccessor(*['e'].z, (MappingEntry(key='e', type=<class 'collections.OrderedDict'>), NamedTupleEntry(field=2, type=<class '__main__.MyTuple'>)))
]

In [9]: treespec.paths()[0]
Out[9]: ('a', 0)

In [10]: treespec.accessors()[0]
Out[10]:
PyTreeAccessor(*['a'][0], (MappingEntry(key='a', type=<class 'collections.OrderedDict'>), SequenceEntry(index=0, type=<class 'tuple'>)))

Motivation and Context

Why is this change required? What problem does it solve?
If it fixes an open issue, please link to the issue here.
You can use the syntax close #15213 if this solves the issue #15213

Types of changes

What types of changes does your code introduce? Put an x in all the boxes that apply:

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds core functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide. (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly. (required for a bug fix or a new feature)
  • I have updated the documentation accordingly.
  • I have reformatted the code using make format. (required)
  • I have checked the code using make lint. (required)
  • I have ensured make test pass. (required)

@XuehaiPan XuehaiPan added enhancement New feature or request cxx Something related to the CXX source code py Something related to the Python source code labels Nov 29, 2023
@XuehaiPan XuehaiPan self-assigned this Nov 29, 2023
Copy link

codecov bot commented Nov 29, 2023

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 100.00%. Comparing base (10842aa) to head (b089dad).

Additional details and impacted files
@@            Coverage Diff             @@
##              main      #108    +/-   ##
==========================================
  Coverage   100.00%   100.00%            
==========================================
  Files            9        10     +1     
  Lines          787       986   +199     
==========================================
+ Hits           787       986   +199     
Flag Coverage Δ
unittests 100.00% <100.00%> (ø)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@stroxler
Copy link

The API shape here looks pretty good to me, I think it meets the needs discussed in
pytorch/pytorch#113378

@XuehaiPan XuehaiPan changed the title feat(ops): add typed paths APIs tree_flatten_with_typed_path and PyTreeSpec.typed_paths feat(ops): add typed path APIs tree_flatten_with_typed_path and PyTreeSpec.typed_paths Nov 29, 2023
@XuehaiPan XuehaiPan force-pushed the typed-path branch 2 times, most recently from 572388b to f814452 Compare November 30, 2023 08:36
@XuehaiPan XuehaiPan force-pushed the typed-path branch 2 times, most recently from e019c1e to 9ec497c Compare December 28, 2023 17:52
@XuehaiPan XuehaiPan force-pushed the typed-path branch 4 times, most recently from d021c01 to 0dc9239 Compare January 7, 2024 20:46
@XuehaiPan XuehaiPan force-pushed the typed-path branch 2 times, most recently from c8002e7 to 487c26c Compare January 17, 2024 06:56
@XuehaiPan XuehaiPan marked this pull request as draft March 9, 2024 16:00
@XuehaiPan XuehaiPan force-pushed the typed-path branch 4 times, most recently from 91edc06 to 1d03768 Compare April 2, 2024 13:40
@XuehaiPan XuehaiPan changed the title feat(ops): add typed path APIs tree_flatten_with_typed_path and PyTreeSpec.typed_paths feat(ops): add accessor APIs tree_flatten_with_accessor and PyTreeSpec.accessors Apr 2, 2024
@XuehaiPan XuehaiPan marked this pull request as ready for review May 2, 2024 12:53
@XuehaiPan XuehaiPan requested a review from JieRen98 May 2, 2024 12:54
@XuehaiPan XuehaiPan merged commit 7a20f35 into metaopt:main May 10, 2024
44 checks passed
@XuehaiPan XuehaiPan deleted the typed-path branch May 10, 2024 16:28
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cxx Something related to the CXX source code enhancement New feature or request py Something related to the Python source code
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants