Skip to content

Commit

Permalink
[nnx] Object refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed May 13, 2024
1 parent a2a6242 commit 4786c5f
Show file tree
Hide file tree
Showing 9 changed files with 236 additions and 216 deletions.
2 changes: 1 addition & 1 deletion flax/experimental/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from .nnx.filterlib import All as All
from .nnx.filterlib import Not as Not
from .nnx.graph import GraphDef as GraphDef
from .nnx.graph import GraphNode as GraphNode
from .nnx.object import Object as Object
from .nnx.helpers import Dict as Dict
from .nnx.helpers import List as List
from .nnx.helpers import Sequential as Sequential
Expand Down
211 changes: 14 additions & 197 deletions flax/experimental/nnx/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,16 @@

import dataclasses
import enum
import threading
import typing as tp
from abc import ABCMeta
from copy import deepcopy


import jax
import numpy as np
import typing_extensions as tpe

from flax.experimental.nnx.nnx import (
errors,
filterlib,
reprlib,
tracers,
)
from flax.experimental.nnx.nnx.proxy_caller import (
ApplyCaller,
Expand All @@ -44,12 +39,12 @@
is_state_leaf,
)
from flax.experimental.nnx.nnx.variables import Variable, VariableState
from flax.typing import PathParts, Key
from flax.typing import Key, PathParts

A = tp.TypeVar('A')
B = tp.TypeVar('B')
C = tp.TypeVar('C')
G = tp.TypeVar('G', bound='GraphNode')

HA = tp.TypeVar('HA', bound=tp.Hashable)
HB = tp.TypeVar('HB', bound=tp.Hashable)

Expand All @@ -75,14 +70,6 @@ def is_node_leaf(x: tp.Any) -> tpe.TypeGuard[NodeLeaf]:
return isinstance(x, (Variable, np.ndarray, jax.Array))


@dataclasses.dataclass
class GraphUtilsContext(threading.local):
seen_modules_repr: set[int] | None = None


CONTEXT = GraphUtilsContext()


class _HashById(tp.Hashable, tp.Generic[A]):
"""A wrapper around a value that uses its id for hashing and equality.
This is used by RefMap to explicitly use object id as the hash for the keys.
Expand Down Expand Up @@ -854,14 +841,12 @@ def __eq__(self, other):
return isinstance(other, UpdateContext)

@tp.overload
def split(self, graph_node: A, /) -> tuple[GraphDef[A], State]:
...
def split(self, graph_node: A, /) -> tuple[GraphDef[A], State]: ...

@tp.overload
def split(
self, graph_node: A, first: filterlib.Filter, /
) -> tuple[GraphDef[A], State]:
...
) -> tuple[GraphDef[A], State]: ...

@tp.overload
def split(
Expand All @@ -871,8 +856,7 @@ def split(
second: filterlib.Filter,
/,
*filters: filterlib.Filter,
) -> tuple[GraphDef[A], State, tpe.Unpack[tuple[State, ...]]]:
...
) -> tuple[GraphDef[A], State, tpe.Unpack[tuple[State, ...]]]: ...

def split(
self, node: A, *filters: filterlib.Filter
Expand Down Expand Up @@ -934,17 +918,15 @@ def update(


@tp.overload
def split(graph_node: A, /) -> tuple[GraphDef[A], State]:
...
def split(graph_node: A, /) -> tuple[GraphDef[A], State]: ...


@tp.overload
def split(
graph_node: A,
first: filterlib.Filter,
/,
) -> tuple[GraphDef[A], State]:
...
) -> tuple[GraphDef[A], State]: ...


@tp.overload
Expand All @@ -954,8 +936,7 @@ def split(
second: filterlib.Filter,
/,
*filters: filterlib.Filter,
) -> tuple[GraphDef[A], State, tpe.Unpack[tuple[State, ...]]]:
...
) -> tuple[GraphDef[A], State, tpe.Unpack[tuple[State, ...]]]: ...


def split(
Expand Down Expand Up @@ -995,13 +976,11 @@ def update(node, state: State, /, *states: State) -> None:


@tp.overload
def state(node, /) -> State:
...
def state(node, /) -> State: ...


@tp.overload
def state(node, first: filterlib.Filter, /) -> State:
...
def state(node, first: filterlib.Filter, /) -> State: ...


@tp.overload
Expand All @@ -1011,8 +990,7 @@ def state(
second: filterlib.Filter,
/,
*filters: filterlib.Filter,
) -> tuple[State, ...]:
...
) -> tuple[State, ...]: ...


def state(
Expand Down Expand Up @@ -1042,8 +1020,7 @@ def pop(
node,
filter: filterlib.Filter,
/,
) -> State:
...
) -> State: ...


@tp.overload
Expand All @@ -1053,8 +1030,7 @@ def pop(
filter2: filterlib.Filter,
/,
*filters: filterlib.Filter,
) -> tuple[State, ...]:
...
) -> tuple[State, ...]: ...


def pop(node, *filters: filterlib.Filter) -> tp.Union[State, tuple[State, ...]]:
Expand Down Expand Up @@ -1173,169 +1149,10 @@ def _maybe_insert(x):
)


# ---------------------------------------------------------
# GraphNode
# ---------------------------------------------------------


class ModuleState(reprlib.Representable):
__slots__ = ('_trace_state', '_initializing')

def __init__(self, initializing: bool = False):
self._trace_state = tracers.TraceState()
self._initializing = initializing

@property
def trace_state(self) -> tracers.TraceState:
return self._trace_state

@property
def initializing(self) -> bool:
return self._initializing

def __nnx_repr__(self):
yield reprlib.Object(type(self))
yield reprlib.Attr('trace_state', self._trace_state)


class GraphNodeMeta(ABCMeta):
if not tp.TYPE_CHECKING:

def __call__(cls, *args: Any, **kwargs: Any) -> Any:
return _graph_node_meta_call(cls, *args, **kwargs)


def _graph_node_meta_call(cls: tp.Type[G], *args, **kwargs) -> G:
node = cls.__new__(cls, *args, **kwargs)
vars(node)['_graph_node__state'] = ModuleState()
node.__init__(*args, **kwargs) # type: ignore[misc]

return node

@dataclasses.dataclass(frozen=True, repr=False)
class Array:
shape: tp.Tuple[int, ...]
dtype: tp.Any

def __repr__(self):
return f'Array(shape={self.shape}, dtype={self.dtype.name})'


class GraphNode(reprlib.Representable, metaclass=GraphNodeMeta):
if tp.TYPE_CHECKING:
_graph_node__state: ModuleState

def __init_subclass__(cls) -> None:
super().__init_subclass__()

register_graph_node_type(
type=cls,
flatten=cls._graph_node_flatten,
set_key=cls._graph_node_set_key,
pop_key=cls._graph_node_pop_key,
create_empty=cls._graph_node_create_empty,
clear=cls._graph_node_clear,
)

if not tp.TYPE_CHECKING:

def __setattr__(self, name: str, value: Any) -> None:
self._setattr(name, value)

def _setattr(self, name: str, value: tp.Any) -> None:
self.check_valid_context(
f"Cannot mutate '{type(self).__name__}' from different trace level"
)
object.__setattr__(self, name, value)

def check_valid_context(self, error_msg: str) -> None:
if not self._graph_node__state.trace_state.is_valid():
raise errors.TraceContextError(error_msg)

def __deepcopy__(self: G, memo=None) -> G:
graphdef, state = split(self)
graphdef = deepcopy(graphdef)
state = deepcopy(state)
return merge(graphdef, state)

def __nnx_repr__(self):
if CONTEXT.seen_modules_repr is None:
CONTEXT.seen_modules_repr = set()
clear_seen = True
else:
clear_seen = False

if id(self) in CONTEXT.seen_modules_repr:
yield reprlib.Object(type=type(self), empty_repr='...')
return

yield reprlib.Object(type=type(self))
CONTEXT.seen_modules_repr.add(id(self))

try:
for name, value in vars(self).items():
if name.startswith('_'):
continue

def to_shape_dtype(value):
if isinstance(value, Variable):
return value.replace(
raw_value=jax.tree.map(to_shape_dtype, value.raw_value)
)
elif isinstance(value, (np.ndarray, jax.Array)):
return Array(value.shape, value.dtype)
return value

value = jax.tree.map(to_shape_dtype, value)
yield reprlib.Attr(name, repr(value))
finally:
if clear_seen:
CONTEXT.seen_modules_repr = None

# Graph Definition
def _graph_node_flatten(self):
nodes = sorted(
(key, value)
for key, value in vars(self).items()
if key != '_graph_node__state'
)
return nodes, type(self)

def _graph_node_set_key(self, key: Key, value: tp.Any):
if not isinstance(key, str):
raise KeyError(f'Invalid key: {key!r}')
elif (
hasattr(self, key)
and isinstance(variable := getattr(self, key), Variable)
and isinstance(value, VariableState)
):
variable.copy_from_state(value)
else:
setattr(self, key, value)

def _graph_node_pop_key(self, key: Key):
if not isinstance(key, str):
raise KeyError(f'Invalid key: {key!r}')
return vars(self).pop(key)

@staticmethod
def _graph_node_create_empty(node_type: tp.Type[G]) -> G:
node = object.__new__(node_type)
vars(node).update(_graph_node__state=ModuleState())
return node

def _graph_node_clear(self, cls: tp.Type[G]):
module_state = self._graph_node__state
module_vars = vars(self)
module_vars.clear()
module_vars['_graph_node__state'] = module_state


# ---------------------------------------------------------
# Pytree
# ---------------------------------------------------------
class PytreeType:
...
class PytreeType: ...


def is_pytree_node(x: tp.Any) -> bool:
Expand Down
4 changes: 2 additions & 2 deletions flax/experimental/nnx/nnx/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __setattr__(self, key, value):
super().__setattr__(key, value)

def __iter__(self) -> tp.Iterator[str]:
return (k for k in vars(self) if k != '_graph_node__state')
return (k for k in vars(self) if k != '_object__state')

def __len__(self) -> int:
return len(vars(self))
Expand Down Expand Up @@ -108,7 +108,7 @@ def _graph_node_flatten(self):
nodes: list[tuple[Key, tp.Any]] = sorted(
(int(key), value)
for key, value in vars(self).items()
if key not in ('_graph_node__state', '_length')
if key not in ('_object__state', '_length')
)
nodes.append(('_length', self._length))
return nodes, type(self)
Expand Down

0 comments on commit 4786c5f

Please sign in to comment.