Skip to content

Commit

Permalink
Sharding custom calls cannot be hoisted individually in `WhileLoopInv…
Browse files Browse the repository at this point in the history
…ariantCodeMotion`.

A sharding custom call annotates sharding to its operand. If we move a single sharding custom call out of the while body, the sharding annotation may not take effect as expected.

Taking the follow while body as an example,
```
body {
  p_body = (f32[2], f32[2], s32[]) parameter(0)
  gte.0 = f32[2] get-tuple-element(p_body), index=0
  gte.1 = f32[2] get-tuple-element(p_body), index=1
  sharding.0 = f32[2] custom-call(gte.0), custom_call_target="Sharding", sharding={devices=[2]<=[2]}
  sharding.1 = f32[2] custom-call(gte.1), custom_call_target="Sharding", sharding={replicated}
  add.0 = f32[2] add(sharding.0, sharding.1)
  gte.2 = s32[] get-tuple-element(p_body), index=2
  const = s32[] constant(1)
  add.1 = s32[] add(gte.2, const)
  ROOT root = (f32[2], f32[2], s32[]) tuple(gte.0, add.0, add.1)
}
```

Before this cl, `WhileLoopInvariantCodeMotion` moves `sharding.0` out of the body and keeps `sharding.1` in the body. With this cl, `WhileLoopInvariantCodeMotion` does not modify this body.

PiperOrigin-RevId: 629637435
  • Loading branch information
ZixuanJiang authored and tensorflower-gardener committed May 1, 2024
1 parent 839aa17 commit 5c90e79
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 4 deletions.
14 changes: 11 additions & 3 deletions third_party/xla/xla/service/while_loop_invariant_code_motion.cc
Expand Up @@ -15,6 +15,11 @@ limitations under the License.

#include "xla/service/while_loop_invariant_code_motion.h"

#include <cstdint>
#include <iterator>
#include <string>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
Expand All @@ -32,7 +37,6 @@ limitations under the License.
#include "xla/service/while_util.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/statusor.h"
#include "xla/util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"
Expand Down Expand Up @@ -112,10 +116,14 @@ static void CreateLoopInvariantCopy(
}

// Returns true if `instruction` is worth hoisting only if it lets us hoist some
// instruction using it. The rationale is that hoisting these instructions will
// prevent simplification and fusion in the while body.
// instruction using it. The rationale is that hoisting these instructions will
// prevent simplification, fusion, and sharding annotation in the while body.
bool WhileLoopInvariantCodeMotion::NotWorthHoistingIndividually(
const HloInstruction& instruction) {
if (instruction.IsCustomCall("Sharding")) {
return true;
}

switch (instruction.opcode()) {
default:
return false;
Expand Down
Expand Up @@ -639,7 +639,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, NoHoistInflating) {
EXPECT_FALSE(simplified_loop);
}

TEST_F(WhileLoopInvariantCodeMotionTest, DoesNotHoistShardingCustomCalls) {
TEST_F(WhileLoopInvariantCodeMotionTest, DoesNotHoistSPMDFullToShardShape) {
auto m = CreateNewVerifiedModule();
auto array_s32 = ShapeUtil::MakeShape(S32, {4});
Shape while_shape =
Expand Down Expand Up @@ -690,5 +690,43 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DoesNotHoistShardingCustomCalls) {
EXPECT_FALSE(simplified_loop);
}

TEST_F(WhileLoopInvariantCodeMotionTest, DoesNotHoistShardingCustomCalls) {
const char* const kHloModule = R"(
HloModule ModuleWithWhile
body {
p_body = (f32[2], f32[2], s32[]) parameter(0)
gte.0 = f32[2] get-tuple-element(p_body), index=0
gte.1 = f32[2] get-tuple-element(p_body), index=1
sharding.0 = f32[2] custom-call(gte.0), custom_call_target="Sharding", sharding={devices=[2]<=[2]}
sharding.1 = f32[2] custom-call(gte.1), custom_call_target="Sharding", sharding={replicated}
add.0 = f32[2] add(sharding.0, sharding.1)
gte.2 = s32[] get-tuple-element(p_body), index=2
const = s32[] constant(1)
add.1 = s32[] add(gte.2, const)
ROOT root = (f32[2], f32[2], s32[]) tuple(gte.0, add.0, add.1)
}
condition {
p_cond = (f32[2], f32[2], s32[]) parameter(0)
gte = s32[] get-tuple-element(p_cond), index=2
const = s32[] constant(5)
ROOT result = pred[] compare(gte, const), direction=LT
}
ENTRY entry {
param.0 = f32[2] parameter(0)
param.1 = s32[] parameter(1)
while_init = (f32[2], f32[2], s32[]) tuple(param.0, param.0, param.1)
ROOT while = (f32[2], f32[2], s32[]) while(while_init), condition=condition, body=body
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(kHloModule));

TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
WhileLoopInvariantCodeMotion{}.Run(module.get()));
EXPECT_FALSE(simplified_loop);
}

} // namespace
} // namespace xla

0 comments on commit 5c90e79

Please sign in to comment.