Move with_sharding_constraint out of experimental into jax.lax namespace.

PiperOrigin-RevId: 494635809
yashk2810 authored and jax authors committed Dec 12, 2022
1 parent 94590e2 commit 13c34f9
Showing 3 changed files with 11 additions and 7 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -24,6 +24,8 @@ Remember to align the itemized text with the first line of an item within a list
are `jax.sharding.PartitionSpec` and `jax.sharding.Mesh`.
`jax.experimental.maps.Mesh` and `jax.experimental.PartitionSpec` are
deprecated and will be removed in 3 months.
+* `with_sharding_constraint`'s new public endpoint is
+  `jax.lax.with_sharding_constraint`.
* If using ABSL flags together with `jax.config`, the ABSL flag values are no
longer read or written after the JAX configuration options are initially
populated from the ABSL flags. This change improves performance of reading
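
For reference, a minimal migration sketch for the endpoint named in the changelog entry above. The mesh construction, array shape, and function are illustrative rather than taken from this commit, `NamedSharding` is imported from `jax.sharding` as exposed in later releases (the tests in this commit still import it from `jax._src.sharding`), and the sharded dimension is assumed to divide evenly across the available devices:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Old (experimental) import, still available at this commit:
#   from jax.experimental.pjit import with_sharding_constraint
# New public endpoint introduced by this commit:
from jax.lax import with_sharding_constraint

# One-dimensional mesh over all available devices.
mesh = Mesh(np.array(jax.devices()), ('x',))

@jax.jit
def f(x):
  # Ask the partitioner to keep this intermediate sharded along the 'x' axis.
  return with_sharding_constraint(x * 2, NamedSharding(mesh, P('x')))

y = f(jnp.arange(8.0))  # assumes the device count divides the array length
```
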
2 changes: 2 additions & 0 deletions jax/lax/__init__.py
@@ -363,3 +363,5 @@
)
from jax._src.ad_util import stop_gradient_p as stop_gradient_p
from jax.lax import linalg as linalg

+from jax.experimental.pjit import with_sharding_constraint
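
Since the new name is a plain re-export, the two endpoints resolve to the same function object at this commit. A minimal check (later releases may no longer expose the experimental alias):

```python
import jax
from jax.experimental import pjit as pjit_lib

# True at this commit: jax.lax simply re-exports the experimental function.
assert jax.lax.with_sharding_constraint is pjit_lib.with_sharding_constraint
```
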
14 changes: 7 additions & 7 deletions tests/pjit_test.py
@@ -35,6 +35,7 @@
from jax import stages
from jax.errors import JAXTypeError
from jax import lax
+from jax.lax import with_sharding_constraint
from jax import prng
from jax.sharding import PartitionSpec as P
from jax.experimental import maps
@@ -45,8 +46,7 @@
from jax._src import array
from jax._src.sharding import NamedSharding, Sharding, OpShardingSharding
import jax.experimental.pjit as pjit_lib
-from jax.experimental.pjit import (pjit, pjit_p, with_sharding_constraint,
-FROM_GDA, AUTO)
+from jax.experimental.pjit import (pjit, pjit_p, FROM_GDA, AUTO)
from jax.interpreters import pxla
from jax.interpreters import mlir
from jax._src.lib import xla_client as xc, xla_bridge, xla_extension_version
@@ -454,7 +454,7 @@ def f(x):
def testShardingConstraintPyTree(self):
@partial(pjit, in_axis_resources=None, out_axis_resources=None)
def f(x):
-x = with_sharding_constraint(x, [P('x', 'y'), P('y', 'x')])
+x = jax.lax.with_sharding_constraint(x, [P('x', 'y'), P('y', 'x')])
x = x.copy()
x[0]["a"] *= 2
return x
@@ -2432,7 +2432,7 @@ def test_with_sharding_constraint_jit(self):
@partial(jax.jit, static_argnums=(0, 1))
def sharded_zeros(shape, pspec):
out = jnp.zeros(shape, jnp.bfloat16)
-return pjit_lib.with_sharding_constraint(out, NamedSharding(mesh, pspec))
+return jax.lax.with_sharding_constraint(out, NamedSharding(mesh, pspec))

out = sharded_zeros((4096, 3072), P('x', 'y'))
out_s = NamedSharding(mesh, P('x', 'y'))
@@ -2447,7 +2447,7 @@ def test_with_sharding_constraint_pjit(self):
@partial(pjit, static_argnums=(0, 1))
def sharded_zeros(shape, pspec):
out = jnp.zeros(shape, jnp.bfloat16)
-return pjit_lib.with_sharding_constraint(out, NamedSharding(mesh, pspec))
+return jax.lax.with_sharding_constraint(out, NamedSharding(mesh, pspec))

out = sharded_zeros((4096, 3072), P('x', 'y'))
out_s = NamedSharding(mesh, P('x', 'y'))
@@ -2461,7 +2461,7 @@ def test_jit_with_sharding_constraint_committed_inp_error(self):

@jax.jit
def sharded_inp(inp):
-return pjit_lib.with_sharding_constraint(
+return jax.lax.with_sharding_constraint(
inp, NamedSharding(mesh, P('x', 'y')))

committed_inp = jax.device_put(jnp.zeros((8, 2), jnp.bfloat16), jax.devices()[0])
@@ -2477,7 +2477,7 @@ def test_jit_device_with_sharding_constraint_error(self):
@partial(jax.jit, static_argnums=(0, 1), device=jax.devices()[0])
def sharded_zeros(shape, pspec):
out = jnp.zeros(shape, jnp.bfloat16)
-return pjit_lib.with_sharding_constraint(out, NamedSharding(mesh, pspec))
+return jax.lax.with_sharding_constraint(out, NamedSharding(mesh, pspec))

with self.assertRaisesRegex(
ValueError,
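
A self-contained sketch of the pattern these tests now exercise. The mesh shape and array size are illustrative (the tests build an 8-device mesh with test utilities), `NamedSharding` is imported from `jax.sharding` rather than `jax._src.sharding` as in the test file, and the sharded dimension is assumed to divide evenly across the available devices:

```python
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative two-axis mesh: all available devices along 'x', one along 'y'.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, ('x', 'y'))

@partial(jax.jit, static_argnums=(0, 1))
def sharded_zeros(shape, pspec):
  out = jnp.zeros(shape, jnp.bfloat16)
  # Constrain the output through the new public endpoint.
  return jax.lax.with_sharding_constraint(out, NamedSharding(mesh, pspec))

out = sharded_zeros((4096, 3072), P('x', 'y'))
# The test asserts out.sharding == NamedSharding(mesh, P('x', 'y')) on its
# 8-device mesh; on a single-device host the reported sharding may differ.
print(out.sharding)
```
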
