Skip to content

Commit

Permalink
Automated Code Change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 628553906
  • Loading branch information
tensorflower-gardener committed May 1, 2024
1 parent 839aa17 commit edecbd3
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 6 deletions.
Expand Up @@ -30,9 +30,7 @@ namespace internal {

namespace {

using llvm::DenseSet;
using mlir::Operation;
using mlir::TypeID;
using mlir::WalkResult;

#define GEN_PASS_DEF_INPUTLOWERINGMETRICSPASS
Expand Down
14 changes: 11 additions & 3 deletions third_party/xla/xla/service/while_loop_invariant_code_motion.cc
Expand Up @@ -15,6 +15,11 @@ limitations under the License.

#include "xla/service/while_loop_invariant_code_motion.h"

#include <cstdint>
#include <iterator>
#include <string>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
Expand All @@ -32,7 +37,6 @@ limitations under the License.
#include "xla/service/while_util.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/statusor.h"
#include "xla/util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"
Expand Down Expand Up @@ -112,10 +116,14 @@ static void CreateLoopInvariantCopy(
}

// Returns true if `instruction` is worth hoisting only if it lets us hoist some
// instruction using it. The rationale is that hoisting these instructions will
// prevent simplification and fusion in the while body.
// instruction using it. The rationale is that hoisting these instructions will
// prevent simplification, fusion, and sharding annotation in the while body.
bool WhileLoopInvariantCodeMotion::NotWorthHoistingIndividually(
const HloInstruction& instruction) {
if (instruction.IsCustomCall("Sharding")) {
return true;
}

switch (instruction.opcode()) {
default:
return false;
Expand Down
Expand Up @@ -639,7 +639,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, NoHoistInflating) {
EXPECT_FALSE(simplified_loop);
}

TEST_F(WhileLoopInvariantCodeMotionTest, DoesNotHoistShardingCustomCalls) {
TEST_F(WhileLoopInvariantCodeMotionTest, DoesNotHoistSPMDFullToShardShape) {
auto m = CreateNewVerifiedModule();
auto array_s32 = ShapeUtil::MakeShape(S32, {4});
Shape while_shape =
Expand Down Expand Up @@ -690,5 +690,43 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DoesNotHoistShardingCustomCalls) {
EXPECT_FALSE(simplified_loop);
}

TEST_F(WhileLoopInvariantCodeMotionTest, DoesNotHoistShardingCustomCalls) {
const char* const kHloModule = R"(
HloModule ModuleWithWhile
body {
p_body = (f32[2], f32[2], s32[]) parameter(0)
gte.0 = f32[2] get-tuple-element(p_body), index=0
gte.1 = f32[2] get-tuple-element(p_body), index=1
sharding.0 = f32[2] custom-call(gte.0), custom_call_target="Sharding", sharding={devices=[2]<=[2]}
sharding.1 = f32[2] custom-call(gte.1), custom_call_target="Sharding", sharding={replicated}
add.0 = f32[2] add(sharding.0, sharding.1)
gte.2 = s32[] get-tuple-element(p_body), index=2
const = s32[] constant(1)
add.1 = s32[] add(gte.2, const)
ROOT root = (f32[2], f32[2], s32[]) tuple(gte.0, add.0, add.1)
}
condition {
p_cond = (f32[2], f32[2], s32[]) parameter(0)
gte = s32[] get-tuple-element(p_cond), index=2
const = s32[] constant(5)
ROOT result = pred[] compare(gte, const), direction=LT
}
ENTRY entry {
param.0 = f32[2] parameter(0)
param.1 = s32[] parameter(1)
while_init = (f32[2], f32[2], s32[]) tuple(param.0, param.0, param.1)
ROOT while = (f32[2], f32[2], s32[]) while(while_init), condition=condition, body=body
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(kHloModule));

TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
WhileLoopInvariantCodeMotion{}.Run(module.get()));
EXPECT_FALSE(simplified_loop);
}

} // namespace
} // namespace xla

0 comments on commit edecbd3

Please sign in to comment.