Skip to content

Commit

Permalink
feat(ops): add accessor APIs tree_flatten_with_accessor and `PyTree…
Browse files Browse the repository at this point in the history
…Spec.accessors` (#108)
  • Loading branch information
XuehaiPan committed May 10, 2024
1 parent 10842aa commit 7a20f35
Show file tree
Hide file tree
Showing 32 changed files with 4,482 additions and 267 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ repos:
- id: debug-statements
- id: double-quote-string-fixer
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v18.1.4
rev: v18.1.5
hooks:
- id: clang-format
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.2
rev: v0.4.4
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Add accessor APIs `tree_flatten_with_accessor` and `PyTreeSpec.accessors` by [@XuehaiPan](https://github.com/XuehaiPan) in [#108](https://github.com/metaopt/optree/pull/108).
- Add submodule `optree.functools` by [@XuehaiPan](https://github.com/XuehaiPan) in [#134](https://github.com/metaopt/optree/pull/134).

### Changed
Expand All @@ -27,6 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Removed

- Deprecate key path APIs by [@XuehaiPan](https://github.com/XuehaiPan) in [#108](https://github.com/metaopt/optree/pull/108).
- Deprecate `optree.Partial` and replace with `optree.functools.partial` by [@XuehaiPan](https://github.com/XuehaiPan) in [#134](https://github.com/metaopt/optree/pull/134).

------
Expand Down
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ pytest: pytest-install
$(PYTHON) -m pytest --version
cd tests && $(PYTHON) -c 'import $(PROJECT_PATH)' && \
$(PYTHON) -c 'import $(PROJECT_PATH)._C; print(f"GLIBCXX_USE_CXX11_ABI={$(PROJECT_PATH)._C.GLIBCXX_USE_CXX11_ABI}")' && \
$(PYTHON) -m pytest --verbose --color=yes --durations=0 \
$(PYTHON) -m pytest --verbose --color=yes \
--cov="$(PROJECT_PATH)" --cov-config=.coveragerc --cov-report=xml --cov-report=term-missing \
$(PYTESTOPTS) .

Expand Down Expand Up @@ -185,7 +185,7 @@ clang-format: clang-format-install

clang-tidy: clang-tidy-install cmake-configure
clang-tidy --version
clang-tidy -p=cmake-build-debug $(CXX_FILES)
clang-tidy --extra-arg="-v" -p=cmake-build-debug $(CXX_FILES)

# Documentation

Expand Down
44 changes: 44 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,15 @@ optree.register_pytree_node(
# Register a Python type into a namespace
import torch

class Torch2NumpyEntry(optree.PyTreeEntry):
def __call__(self, obj):
assert self.entry == 0
return obj.cpu().detach().numpy()

def codegen(self, node=''):
assert self.entry == 0
return f'{node}.cpu().detach().numpy()'

optree.register_pytree_node(
torch.Tensor,
# (tensor) -> (children, metadata)
Expand All @@ -161,6 +170,7 @@ optree.register_pytree_node(
),
# (metadata, children) -> tensor
unflatten_func=lambda metadata, children: torch.tensor(children[0], **metadata),
path_entry_type=Torch2NumpyEntry,
namespace='torch2numpy',
)
```
Expand Down Expand Up @@ -192,6 +202,13 @@ optree.register_pytree_node(
>>> optree.tree_paths(tree, namespace='torch2numpy')
[('bias', 0), ('weight', 0)]

# Custom path entry type defines the pytree access behavior
>>> optree.tree_accessors(tree, namespace='torch2numpy')
[
PyTreeAccessor(*['bias'].cpu().detach().numpy(), (MappingEntry(key='bias', type=<class 'dict'>), Torch2NumpyEntry(entry=0, type=<class 'torch.Tensor'>))),
PyTreeAccessor(*['weight'].cpu().detach().numpy(), (MappingEntry(key='weight', type=<class 'dict'>), Torch2NumpyEntry(entry=0, type=<class 'torch.Tensor'>)))
]

# Unflatten back to a copy of the original object
>>> optree.tree_unflatten(treespec, leaves)
{'weight': tensor([[1., 1.]], device='cuda:0'), 'bias': tensor([0., 0.])}
Expand All @@ -204,6 +221,8 @@ from collections import UserDict

@optree.register_pytree_node_class(namespace='mydict')
class MyDict(UserDict):
TREE_PATH_ENTRY_TYPE = optree.MappingEntry # used by accessor APIs

def tree_flatten(self): # -> (children, metadata, entries)
reversed_keys = sorted(self.keys(), reverse=True)
return (
Expand Down Expand Up @@ -238,6 +257,21 @@ class MyDict(UserDict):
namespace='mydict'
)
)
>>> optree.tree_flatten_with_accessor(tree, namespace='mydict')
(
[
PyTreeAccessor(*['c']['f'], (MappingEntry(key='c', type=<class 'MyDict'>), MappingEntry(key='f', type=<class 'MyDict'>))),
PyTreeAccessor(*['c']['d'], (MappingEntry(key='c', type=<class 'MyDict'>), MappingEntry(key='d', type=<class 'MyDict'>))),
PyTreeAccessor(*['b'], (MappingEntry(key='b', type=<class 'MyDict'>),)),
PyTreeAccessor(*['a'][0], (MappingEntry(key='a', type=<class 'MyDict'>), SequenceEntry(index=0, type=<class 'tuple'>))),
PyTreeAccessor(*['a'][1], (MappingEntry(key='a', type=<class 'MyDict'>), SequenceEntry(index=1, type=<class 'tuple'>)))
],
[6, 5, 4, 2, 3],
PyTreeSpec(
CustomTreeNode(MyDict[['c', 'b', 'a']], [CustomTreeNode(MyDict[['f', 'd']], [*, *]), *, (*, *)]),
namespace='mydict'
)
)
```

#### Notes about the PyTree Type Registry
Expand All @@ -262,6 +296,8 @@ There are several key attributes of the pytree type registry:

@optree.register_pytree_node_class(namespace='mydict')
class MyDict(UserDict):
TREE_PATH_ENTRY_TYPE = optree.MappingEntry # used by accessor APIs

def __init_subclass__(cls): # define this in the base class
super().__init_subclass__()
# Register a subclass to namespace 'mydict'
Expand Down Expand Up @@ -295,6 +331,14 @@ There are several key attributes of the pytree type registry:
namespace='mydict'
)
)
>>> optree.tree_accessors(tree, namespace='mydict')
[
PyTreeAccessor(*['c']['f'], (MappingEntry(key='c', type=<class 'MyDict'>), MappingEntry(key='f', type=<class 'MyAnotherDict'>))),
PyTreeAccessor(*['c']['d'], (MappingEntry(key='c', type=<class 'MyDict'>), MappingEntry(key='d', type=<class 'MyAnotherDict'>))),
PyTreeAccessor(*['b'], (MappingEntry(key='b', type=<class 'MyDict'>),)),
PyTreeAccessor(*['a'][0], (MappingEntry(key='a', type=<class 'MyDict'>), SequenceEntry(index=0, type=<class 'tuple'>))),
PyTreeAccessor(*['a'][1], (MappingEntry(key='a', type=<class 'MyDict'>), SequenceEntry(index=1, type=<class 'tuple'>)))
]
```

5. **Be careful about the potential infinite recursion of the custom flatten function.** The returned `children` from the custom flatten function are considered subtrees. They will be further flattened recursively. The `children` can have the same type as the current node. Users must design their termination condition carefully.
Expand Down
29 changes: 15 additions & 14 deletions conda-recipe.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ channels:
- conda-forge

dependencies:
- python = 3.10
- python = 3.11
- pip

# Dependency
Expand All @@ -36,15 +36,15 @@ dependencies:
- cmake >= 3.11
- make
- cxx-compiler
- pybind11
- pybind11 >= 2.11.1

# Benchmark
- pytorch::pytorch >= 2.0, < 2.1.0a0
- pytorch::pytorch >= 2.0, < 2.4.0a0
- pytorch::torchvision
- pytorch::pytorch-mutex = *=*cpu*
- jax >= 0.4.6, < 0.5.0a0
- jaxlib >= 0.4, < 0.5.0a0 =*cpu*
- dm-tree
- conda-forge::jax >= 0.4.6, < 0.5.0a0
- conda-forge::jaxlib >= 0.4, < 0.5.0a0 =*cpu*
- conda-forge::dm-tree
- pandas
- tabulate
- termcolor
Expand All @@ -65,10 +65,10 @@ dependencies:
- pytest
- pytest-cov
- pytest-xdist
- isort >= 5.11.0
- conda-forge::black >= 22.6.0
- pylint >= 2.15.0
- mypy >= 0.990
- isort
- conda-forge::black
- pylint
- mypy
- flake8
- flake8-bugbear
- flake8-comprehensions
Expand All @@ -79,7 +79,8 @@ dependencies:
- doc8
- pydocstyle
- xdoctest
- clang-format >= 14
- clang-tools >= 14 # clang-tidy
- cpplint
- pre-commit
- conda-forge::clang-format >= 14
- conda-forge::clang-tools >= 14 # clang-tidy
- conda-forge::cpplint
- conda-forge::pre-commit
- conda-forge::identify
4 changes: 2 additions & 2 deletions docs/conda-recipe.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ channels:
- conda-forge

dependencies:
- python = 3.10
- python = 3.11
- pip

# Dependency
Expand All @@ -35,7 +35,7 @@ dependencies:
- cmake >= 3.11
- make
- cxx-compiler
- pybind11
- pybind11 >= 2.11.1

# Documentation
- sphinx >= 5.2.1
Expand Down
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ sphinxcontrib-bibtex
sphinx-autodoc-typehints >= 1.19.2
docutils

jax[cpu] >= 0.3
jax[cpu] >= 0.4.6
numpy
torch
1 change: 1 addition & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ API References
:undoc-members:
:show-inheritance:
:noindex:
:exclude-members: register_keypaths, AttributeKeyPathEntry, GetitemKeyPathEntry
14 changes: 14 additions & 0 deletions docs/source/ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,53 +25,65 @@ Tree Manipulation Functions

tree_flatten
tree_flatten_with_path
tree_flatten_with_accessor
tree_unflatten
tree_iter
tree_leaves
tree_structure
tree_paths
tree_accessors
tree_is_leaf
all_leaves
tree_map
tree_map_
tree_map_with_path
tree_map_with_path_
tree_map_with_accessor
tree_map_with_accessor_
tree_replace_nones
tree_transpose
tree_transpose_map
tree_transpose_map_with_path
tree_transpose_map_with_accessor
tree_broadcast_prefix
broadcast_prefix
tree_broadcast_common
broadcast_common
tree_broadcast_map
tree_broadcast_map_with_path
tree_broadcast_map_with_accessor
tree_flatten_one_level
prefix_errors

.. autofunction:: tree_flatten
.. autofunction:: tree_flatten_with_path
.. autofunction:: tree_flatten_with_accessor
.. autofunction:: tree_unflatten
.. autofunction:: tree_iter
.. autofunction:: tree_leaves
.. autofunction:: tree_structure
.. autofunction:: tree_paths
.. autofunction:: tree_accessors
.. autofunction:: tree_is_leaf
.. autofunction:: all_leaves
.. autofunction:: tree_map
.. autofunction:: tree_map_
.. autofunction:: tree_map_with_path
.. autofunction:: tree_map_with_path_
.. autofunction:: tree_map_with_accessor
.. autofunction:: tree_map_with_accessor_
.. autofunction:: tree_replace_nones
.. autofunction:: tree_transpose
.. autofunction:: tree_transpose_map
.. autofunction:: tree_transpose_map_with_path
.. autofunction:: tree_transpose_map_with_accessor
.. autofunction:: tree_broadcast_prefix
.. autofunction:: broadcast_prefix
.. autofunction:: tree_broadcast_common
.. autofunction:: broadcast_common
.. autofunction:: tree_broadcast_map
.. autofunction:: tree_broadcast_map_with_path
.. autofunction:: tree_broadcast_map_with_accessor
.. autofunction:: tree_flatten_one_level
.. autofunction:: prefix_errors

Expand Down Expand Up @@ -104,6 +116,7 @@ PyTreeSpec Functions
.. autosummary::

treespec_paths
treespec_accessors
treespec_entries
treespec_entry
treespec_children
Expand All @@ -125,6 +138,7 @@ PyTreeSpec Functions
treespec_from_collection

.. autofunction:: treespec_paths
.. autofunction:: treespec_accessors
.. autofunction:: treespec_entries
.. autofunction:: treespec_entry
.. autofunction:: treespec_children
Expand Down
6 changes: 0 additions & 6 deletions docs/source/registry.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,7 @@ PyTree Node Registry
register_pytree_node
register_pytree_node_class
unregister_pytree_node
register_keypaths
AttributeKeyPathEntry
GetitemKeyPathEntry

.. autofunction:: register_pytree_node
.. autofunction:: register_pytree_node_class
.. autofunction:: unregister_pytree_node
.. autofunction:: register_keypaths
.. autofunction:: AttributeKeyPathEntry
.. autofunction:: GetitemKeyPathEntry
8 changes: 8 additions & 0 deletions docs/source/spelling_wordlist.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
PyTree
pytree
pytrees
PyTreeKind
treespec
treespecs
subtree
Expand Down Expand Up @@ -66,6 +67,7 @@ CPython
CPython's
sortable
OrderedDict
ordereddict
defaultdict
unsubclassable
sys
Expand All @@ -81,3 +83,9 @@ setattr
delattr
typecheck
subclassed
dataclasses
subpath
accessor
accessors
codegen
subtype
4 changes: 4 additions & 0 deletions include/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ class PyTreeTypeRegistry {
py::function flatten_func{};
// A function with signature: (metadata, iterable) -> object
py::function unflatten_func{};
// The Python type object for the path entry class.
py::object path_entry_type{};
};

using RegistrationPtr = std::shared_ptr<const Registration>;
Expand All @@ -82,6 +84,7 @@ class PyTreeTypeRegistry {
static void Register(const py::object &cls,
const py::function &flatten_func,
const py::function &unflatten_func,
const py::object &path_entry_type,
const std::string &registry_namespace = "");

static void Unregister(const py::object &cls, const std::string &registry_namespace = "");
Expand All @@ -104,6 +107,7 @@ class PyTreeTypeRegistry {
static void RegisterImpl(const py::object &cls,
const py::function &flatten_func,
const py::function &unflatten_func,
const py::object &path_entry_type,
const std::string &registry_namespace);

template <bool NoneIsLeaf>
Expand Down

0 comments on commit 7a20f35

Please sign in to comment.