Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Sharding custom calls cannot be hoisted individually in WhileLoopInvariantCodeMotion. #66574

Merged
merged 1 commit into from May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 11 additions & 3 deletions third_party/xla/xla/service/while_loop_invariant_code_motion.cc
Expand Up @@ -15,6 +15,11 @@ limitations under the License.

#include "xla/service/while_loop_invariant_code_motion.h"

#include <cstdint>
#include <iterator>
#include <string>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
Expand All @@ -32,7 +37,6 @@ limitations under the License.
#include "xla/service/while_util.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/statusor.h"
#include "xla/util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"
Expand Down Expand Up @@ -112,10 +116,14 @@ static void CreateLoopInvariantCopy(
}

// Returns true if `instruction` is worth hoisting only if it lets us hoist some
// instruction using it. The rationale is that hoisting these instructions will
// prevent simplification and fusion in the while body.
// instruction using it. The rationale is that hoisting these instructions will
// prevent simplification, fusion, and sharding annotation in the while body.
bool WhileLoopInvariantCodeMotion::NotWorthHoistingIndividually(
const HloInstruction& instruction) {
if (instruction.IsCustomCall("Sharding")) {
return true;
}

switch (instruction.opcode()) {
default:
return false;
Expand Down
Expand Up @@ -639,7 +639,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, NoHoistInflating) {
EXPECT_FALSE(simplified_loop);
}

TEST_F(WhileLoopInvariantCodeMotionTest, DoesNotHoistShardingCustomCalls) {
TEST_F(WhileLoopInvariantCodeMotionTest, DoesNotHoistSPMDFullToShardShape) {
auto m = CreateNewVerifiedModule();
auto array_s32 = ShapeUtil::MakeShape(S32, {4});
Shape while_shape =
Expand Down Expand Up @@ -690,5 +690,43 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DoesNotHoistShardingCustomCalls) {
EXPECT_FALSE(simplified_loop);
}

// Checks that running WhileLoopInvariantCodeMotion over a while loop whose
// body feeds its operands through "Sharding" custom calls reports no change.
//
// In `body`, the root tuple passes `gte.0` through unchanged at index 0, so
// `sharding.0` consumes only that pass-through element. `sharding.1` consumes
// `gte.1`, whose tuple slot the root overwrites with `add.0`. The only user of
// `sharding.0` is `add.0`, which also depends on `sharding.1`.
TEST_F(WhileLoopInvariantCodeMotionTest, DoesNotHoistShardingCustomCalls) {
  const char* const kWhileWithShardingHlo = R"(
HloModule ModuleWithWhile

body {
p_body = (f32[2], f32[2], s32[]) parameter(0)
gte.0 = f32[2] get-tuple-element(p_body), index=0
gte.1 = f32[2] get-tuple-element(p_body), index=1
sharding.0 = f32[2] custom-call(gte.0), custom_call_target="Sharding", sharding={devices=[2]<=[2]}
sharding.1 = f32[2] custom-call(gte.1), custom_call_target="Sharding", sharding={replicated}
add.0 = f32[2] add(sharding.0, sharding.1)
gte.2 = s32[] get-tuple-element(p_body), index=2
const = s32[] constant(1)
add.1 = s32[] add(gte.2, const)
ROOT root = (f32[2], f32[2], s32[]) tuple(gte.0, add.0, add.1)
}

condition {
p_cond = (f32[2], f32[2], s32[]) parameter(0)
gte = s32[] get-tuple-element(p_cond), index=2
const = s32[] constant(5)
ROOT result = pred[] compare(gte, const), direction=LT
}

ENTRY entry {
param.0 = f32[2] parameter(0)
param.1 = s32[] parameter(1)
while_init = (f32[2], f32[2], s32[]) tuple(param.0, param.0, param.1)
ROOT while = (f32[2], f32[2], s32[]) while(while_init), condition=condition, body=body
})";

  // Parse (and verify) the HLO text into a module the pass can run on.
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> parsed_module,
      ParseAndReturnVerifiedModule(kWhileWithShardingHlo));

  // Run the pass with its default configuration and capture whether it
  // reports having modified the module.
  WhileLoopInvariantCodeMotion licm_pass;
  TF_ASSERT_OK_AND_ASSIGN(bool module_changed,
                          licm_pass.Run(parsed_module.get()));

  // The pass is expected to leave the loop untouched.
  EXPECT_FALSE(module_changed);
}

} // namespace
} // namespace xla