Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[sparse] validate BCOO on instantiation #13619

Merged
merged 1 commit into from Dec 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion jax/experimental/sparse/_base.py
Expand Up @@ -62,8 +62,9 @@ def tree_flatten(self):
...

@classmethod
@abc.abstractmethod
def tree_unflatten(cls, aux_data, children):
return cls(children, **aux_data)
...

@abc.abstractmethod
def transpose(self, axes=None):
Expand Down
14 changes: 12 additions & 2 deletions jax/experimental/sparse/bcoo.py
Expand Up @@ -31,7 +31,7 @@
from jax.config import config
from jax.experimental.sparse._base import JAXSparse
from jax.experimental.sparse.util import (
_broadcasting_vmap, _count_stored_elements, _safe_asarray,
_broadcasting_vmap, _count_stored_elements,
_dot_general_validated_shape, CuSparseEfficiencyWarning,
SparseEfficiencyError, SparseEfficiencyWarning)
from jax.interpreters import batching
Expand Down Expand Up @@ -2386,10 +2386,11 @@ def __init__(self, args: Tuple[Array, Array], *, shape: Sequence[int],
indices_sorted: bool = False, unique_indices: bool = False):
# JAX transforms will sometimes instantiate pytrees with null values, so we
# must catch that in the initialization of inputs.
self.data, self.indices = _safe_asarray(args) # type: ignore[assignment]
self.data, self.indices = map(jnp.asarray, args)
self.indices_sorted = indices_sorted
self.unique_indices = unique_indices
super().__init__(args, shape=tuple(shape))
_validate_bcoo(self.data, self.indices, self.shape)

def __repr__(self):
name = self.__class__.__name__
Expand Down Expand Up @@ -2582,6 +2583,15 @@ def transpose(self, axes: Optional[Sequence[int]] = None) -> BCOO:
def tree_flatten(self):
return (self.data, self.indices), self._info._asdict()

@classmethod
def tree_unflatten(cls, aux_data, children):
obj = object.__new__(cls)
obj.data, obj.indices = children
if aux_data.keys() != {'shape', 'indices_sorted', 'unique_indices'}:
raise ValueError(f"BCOO.tree_unflatten: invalid {aux_data=}")
obj.__dict__.update(**aux_data)
return obj


# vmappable handlers
def _bcoo_to_elt(cont, _, val, axis):
Expand Down
14 changes: 12 additions & 2 deletions jax/experimental/sparse/bcsr.py
Expand Up @@ -25,7 +25,7 @@
from jax import tree_util
from jax.experimental.sparse._base import JAXSparse
from jax.experimental.sparse import bcoo
from jax.experimental.sparse.util import _broadcasting_vmap, _count_stored_elements, _csr_to_coo, _safe_asarray
from jax.experimental.sparse.util import _broadcasting_vmap, _count_stored_elements, _csr_to_coo
import jax.numpy as jnp
from jax.util import split_list, safe_zip
from jax.interpreters import batching
Expand Down Expand Up @@ -320,8 +320,9 @@ def _sparse_shape(self):
def __init__(self, args, *, shape):
# JAX transforms will sometimes instantiate pytrees with null values, so we
# must catch that in the initialization of inputs.
self.data, self.indices, self.indptr = _safe_asarray(args)
self.data, self.indices, self.indptr = map(jnp.asarray, args)
super().__init__(args, shape=shape)
_validate_bcsr(self.data, self.indices, self.indptr, self.shape)

def __repr__(self):
name = self.__class__.__name__
Expand All @@ -348,6 +349,15 @@ def transpose(self, *args, **kwargs):
def tree_flatten(self):
return (self.data, self.indices, self.indptr), {'shape': self.shape}

@classmethod
def tree_unflatten(cls, aux_data, children):
obj = object.__new__(cls)
obj.data, obj.indices, obj.indptr = children
if aux_data.keys() != {'shape'}:
raise ValueError(f"BCSR.tree_unflatten: invalid {aux_data=}")
obj.__dict__.update(**aux_data)
return obj

@classmethod
def _empty(cls, shape, *, dtype=None, index_dtype='int32', n_dense=0,
n_batch=0, nse=0):
Expand Down
15 changes: 13 additions & 2 deletions jax/experimental/sparse/coo.py
Expand Up @@ -27,7 +27,7 @@
from jax.interpreters import ad
from jax.interpreters import mlir
from jax.experimental.sparse._base import JAXSparse
from jax.experimental.sparse.util import _coo_extract, _safe_asarray, CuSparseEfficiencyWarning
from jax.experimental.sparse.util import _coo_extract, CuSparseEfficiencyWarning
from jax import tree_util
from jax._src.lax.lax import _const
from jax._src.lib.mlir.dialects import mhlo
Expand Down Expand Up @@ -69,7 +69,7 @@ class COO(JAXSparse):

def __init__(self, args: Tuple[Array, Array, Array], *, shape: Shape,
rows_sorted: bool = False, cols_sorted: bool = False):
self.data, self.row, self.col = _safe_asarray(args) # type: ignore[assignment]
self.data, self.row, self.col = map(jnp.asarray, args)
self._rows_sorted = rows_sorted
self._cols_sorted = cols_sorted
super().__init__(args, shape=shape)
Expand Down Expand Up @@ -135,6 +135,17 @@ def transpose(self, axes: Optional[Tuple[int, ...]] = None) -> COO:
def tree_flatten(self) -> Tuple[Tuple[Array, Array, Array], Dict[str, Any]]:
return (self.data, self.row, self.col), self._info._asdict()

@classmethod
def tree_unflatten(cls, aux_data, children):
obj = object.__new__(cls)
obj.data, obj.row, obj.col = children
if aux_data.keys() != {'shape', 'rows_sorted', 'cols_sorted'}:
raise ValueError(f"COO.tree_unflatten: invalid {aux_data=}")
obj.shape = aux_data['shape']
obj._rows_sorted = aux_data['rows_sorted']
obj._cols_sorted = aux_data['cols_sorted']
return obj

def __matmul__(self, other: ArrayLike) -> Array:
if isinstance(other, JAXSparse):
raise NotImplementedError("matmul between two sparse objects.")
Expand Down
24 changes: 21 additions & 3 deletions jax/experimental/sparse/csr.py
Expand Up @@ -26,7 +26,7 @@
from jax.interpreters import mlir
from jax.experimental.sparse._base import JAXSparse
from jax.experimental.sparse.coo import _coo_matmat, _coo_matvec, _coo_todense, COOInfo
from jax.experimental.sparse.util import _csr_to_coo, _csr_extract, _safe_asarray, CuSparseEfficiencyWarning
from jax.experimental.sparse.util import _csr_to_coo, _csr_extract, CuSparseEfficiencyWarning
from jax import lax
from jax import tree_util
from jax._src.lax.lax import _const
Expand All @@ -51,7 +51,7 @@ class CSR(JAXSparse):
dtype = property(lambda self: self.data.dtype)

def __init__(self, args, *, shape):
self.data, self.indices, self.indptr = _safe_asarray(args)
self.data, self.indices, self.indptr = map(jnp.asarray, args)
super().__init__(args, shape=shape)

@classmethod
Expand Down Expand Up @@ -116,6 +116,15 @@ def __matmul__(self, other):
def tree_flatten(self):
return (self.data, self.indices, self.indptr), {"shape": self.shape}

@classmethod
def tree_unflatten(cls, aux_data, children):
obj = object.__new__(cls)
obj.data, obj.indices, obj.indptr = children
if aux_data.keys() != {'shape'}:
raise ValueError(f"CSR.tree_unflatten: invalid {aux_data=}")
obj.__dict__.update(**aux_data)
return obj


@tree_util.register_pytree_node_class
class CSC(JAXSparse):
Expand All @@ -128,7 +137,7 @@ class CSC(JAXSparse):
dtype = property(lambda self: self.data.dtype)

def __init__(self, args, *, shape):
self.data, self.indices, self.indptr = _safe_asarray(args)
self.data, self.indices, self.indptr = map(jnp.asarray, args)
super().__init__(args, shape=shape)

@classmethod
Expand Down Expand Up @@ -174,6 +183,15 @@ def __matmul__(self, other):
def tree_flatten(self):
return (self.data, self.indices, self.indptr), {"shape": self.shape}

@classmethod
def tree_unflatten(cls, aux_data, children):
obj = object.__new__(cls)
obj.data, obj.indices, obj.indptr = children
if aux_data.keys() != {'shape'}:
raise ValueError(f"CSC.tree_unflatten: invalid {aux_data=}")
obj.__dict__.update(**aux_data)
return obj


#--------------------------------------------------------------------
# csr_todense
Expand Down
5 changes: 0 additions & 5 deletions jax/experimental/sparse/util.py
Expand Up @@ -107,11 +107,6 @@ def _is_aval(*args: Any) -> bool:
def _is_arginfo(*args: Any) -> bool:
return all(isinstance(arg, stages.ArgInfo) for arg in args)

def _safe_asarray(args: Sequence[Any]) -> Iterable[Union[np.ndarray, Array]]:
if _is_pytree_placeholder(*args) or _is_aval(*args) or _is_arginfo(*args):
return args
return map(_asarray_or_float0, args)

def _dot_general_validated_shape(
lhs_shape: Tuple[int, ...], rhs_shape: Tuple[int, ...],
dimension_numbers: DotDimensionNumbers) -> Tuple[int, ...]:
Expand Down
3 changes: 2 additions & 1 deletion tests/sparse_test.py
Expand Up @@ -692,7 +692,8 @@ def test_repr(self):
y = sparse.BCOO.fromdense(jnp.arange(6, dtype='float32').reshape(2, 3), n_batch=1, n_dense=1)
self.assertEqual(repr(y), "BCOO(float32[2, 3], nse=1, n_batch=1, n_dense=1)")

M_invalid = sparse.BCOO(([], []), shape=(100,))
M_invalid = sparse.BCOO.fromdense(jnp.arange(6, dtype='float32').reshape(2, 3))
M_invalid.indices = jnp.array([])
self.assertEqual(repr(M_invalid), "BCOO(<invalid>)")

@jit
Expand Down