Skip to content

Commit

Permalink
Add unbounded dynamism test for AllToAllOp.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 623939802
  • Loading branch information
ghpvnist authored and tensorflower-gardener committed May 7, 2024
1 parent 783c383 commit 78ffcaf
Show file tree
Hide file tree
Showing 7 changed files with 275 additions and 20 deletions.
71 changes: 67 additions & 4 deletions third_party/xla/xla/client/xla_builder.cc
Expand Up @@ -15,6 +15,7 @@ limitations under the License.

#include "xla/client/xla_builder.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
Expand Down Expand Up @@ -3800,15 +3801,55 @@ XlaOp XlaBuilder::AllToAllArray(
return all_to_all;
}
DimensionVector sizes;
const bool is_unbounded = operand_shape->is_unbounded_dynamic();
std::vector<XlaOp> dynamic_sizes;
auto GetR1DimensionSizeOrConstant = [&](XlaOp operand,
int64_t dimension) -> XlaOp {
if (operand_shape->is_unbounded_dynamic_dimension(dimension)) {
return Reshape(GetDimensionSize(operand, dimension), {1});
}
return ConstantR1<int32_t>(
this, {static_cast<int32_t>(operand_shape->dimensions(dimension))});
};
XlaOp r1_split_count =
ConstantR1<int32_t>(this, {static_cast<int32_t>(split_count)});
for (int64_t i = 0; i < operand_shape->rank(); ++i) {
if (i != split_dimension) {
sizes.push_back(operand_shape->dimensions(i));
if (is_unbounded) {
dynamic_sizes.push_back(GetR1DimensionSizeOrConstant(operand, i));
}
continue;
}
sizes.push_back(split_count);
sizes.push_back(operand_shape->dimensions(i) / split_count);
sizes.push_back(operand_shape->is_unbounded_dynamic_dimension(i)
? Shape::kUnboundedSize
: operand_shape->dimensions(i) / split_count);

if (is_unbounded) {
dynamic_sizes.push_back(r1_split_count);
dynamic_sizes.push_back(
operand_shape->is_unbounded_dynamic_dimension(i)
? Div(GetR1DimensionSizeOrConstant(operand, i), r1_split_count)
: ConstantR1<int32_t>(this,
{static_cast<int32_t>(sizes.back())}));
}
}

if (is_unbounded) {
std::vector<bool> dynamic_dimensions;
std::transform(
sizes.begin(), sizes.end(), std::back_inserter(dynamic_dimensions),
[](int64_t size) { return size == Shape::kUnboundedSize; });
TF_ASSIGN_OR_RETURN(
const Shape shape,
ShapeUtil::MakeValidatedShape(all_to_all_shape.element_type(), sizes,
dynamic_dimensions));
all_to_all =
MhloDynamicReshape(all_to_all, ConcatInDim(dynamic_sizes, 0), shape);
} else {
all_to_all = Reshape(all_to_all, sizes);
}
all_to_all = Reshape(all_to_all, sizes);

std::vector<int64_t> permutation;
const auto rank = operand_shape->rank();
Expand All @@ -3821,6 +3862,21 @@ XlaOp XlaBuilder::AllToAllArray(
permutation.push_back(dim_after_reshape);
}
all_to_all = Transpose(all_to_all, permutation);

if (is_unbounded) {
std::vector<XlaOp> new_dimensions;
for (int64_t i = 0; i < operand_shape->rank(); ++i) {
new_dimensions.push_back(GetR1DimensionSizeOrConstant(operand, i));
}
new_dimensions[split_dimension] =
Div(new_dimensions[split_dimension], r1_split_count);
new_dimensions[concat_dimension] =
Mul(new_dimensions[concat_dimension], r1_split_count);

return MhloDynamicReshape(all_to_all, ConcatInDim(new_dimensions, 0),
all_to_all_shape);
}

return Reshape(all_to_all_shape, all_to_all);
});
}
Expand Down Expand Up @@ -3876,6 +3932,13 @@ XlaOp XlaBuilder::AllToAllTuple(
const std::optional<ChannelHandle>& channel_id) {
return ReportErrorOrReturn([&]() -> absl::StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
if (operand_shape->is_unbounded_dynamic() ||
split_dimension == Shape::kUnboundedSize ||
concat_dimension == Shape::kUnboundedSize ||
split_count == Shape::kUnboundedSize) {
return InvalidArgument(
"AllToAllTuple does not support unbounded dynamic shapes");
}

// The HloInstruction for AllToAll currently only handles the data
// communication: it accepts N already split parts and scatters them to N
Expand All @@ -3901,14 +3964,14 @@ XlaOp XlaBuilder::AllToAllTuple(
}

// Handle data communication.
XlaOp alltoall =
XlaOp all_to_all =
this->AllToAllTuple(slices, replica_groups, layout, channel_id);

// Concat the N received parts.
std::vector<XlaOp> received;
received.reserve(split_count);
for (int i = 0; i < split_count; i++) {
received.push_back(this->GetTupleElement(alltoall, i));
received.push_back(this->GetTupleElement(all_to_all, i));
}
return this->ConcatInDim(received, concat_dimension);
});
Expand Down
5 changes: 4 additions & 1 deletion third_party/xla/xla/client/xla_builder.h
Expand Up @@ -2566,7 +2566,10 @@ XlaOp ReduceScatter(
const std::optional<Layout>& layout = std::nullopt,
std::optional<bool> use_global_device_ids = std::nullopt);

// Enqueues an operation that do an Alltoall of the operand cross cores.
// Enqueues an operation that does an AllToAll of the operand across cores.
// This involves AllToAll, followed by Reshape, Transpose, and another Reshape
// to get proper codegen. See implementation for additional details.
//
// An optional `layout` can be specified to force the layout of the instruction.
// This is used to guarantee the same layout for a group of AllToAll ops
// compiled separately.
Expand Down
124 changes: 124 additions & 0 deletions third_party/xla/xla/client/xla_builder_test.cc
Expand Up @@ -1990,6 +1990,130 @@ TEST(XlaBuilderTest, UnboundedAllReduce) {
GmockMatch(m::Op().WithShapeEqualTo(&expected)));
}

TEST(XlaBuilderTest, UnboundedAllToAllDynamicSplitDimension) {
  XlaBuilder b(TestName());
  // Split dimension 0 is unbounded; concat dimension 1 is static, so the
  // result's dim 1 becomes 15 * split_count = 45 while dim 0 stays unbounded.
  TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 15]"));
  TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 45]"));
  AllToAll(/*operand=*/Parameter(&b, 0, operand, "operand"),
           /*split_dimension=*/0,
           /*concat_dimension=*/1,
           /*split_count=*/3,
           /*replica_groups=*/{});
  TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr<HloModule> module,
                          BuildHloModule(b));
  // NOTE: removed a leftover `std::cout << module->ToString()` debug print;
  // tests should not write to stdout.
  EXPECT_THAT(GetRoot(*module),
              GmockMatch(m::Op().WithShapeEqualTo(&expected)));
}

TEST(XlaBuilderTest, UnboundedAllToAllDynamicConcatDimension) {
  XlaBuilder b(TestName());
  // Split dimension 1 is static (15 / split_count = 5); concat dimension 0 is
  // unbounded and remains unbounded in the result.
  TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 15]"));
  TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 5]"));
  AllToAll(/*operand=*/Parameter(&b, 0, operand, "operand"),
           /*split_dimension=*/1,
           /*concat_dimension=*/0,
           /*split_count=*/3,
           /*replica_groups=*/{});
  TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr<HloModule> module,
                          BuildHloModule(b));
  // NOTE: removed a leftover `std::cout << module->ToString()` debug print;
  // tests should not write to stdout.
  EXPECT_THAT(GetRoot(*module),
              GmockMatch(m::Op().WithShapeEqualTo(&expected)));
}

TEST(XlaBuilderTest, UnboundedAllToAllDynamicSplitAndConcatDimensionEqual) {
  XlaBuilder b(TestName());
  // Splitting and concatenating along the same unbounded dimension leaves the
  // shape unchanged.
  TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 15]"));
  TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 15]"));
  AllToAll(/*operand=*/Parameter(&b, 0, operand, "operand"),
           /*split_dimension=*/0,
           /*concat_dimension=*/0,
           /*split_count=*/3,
           /*replica_groups=*/{});
  TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr<HloModule> module,
                          BuildHloModule(b));
  // NOTE: removed a leftover `std::cout << module->ToString()` debug print;
  // tests should not write to stdout.
  EXPECT_THAT(GetRoot(*module),
              GmockMatch(m::Op().WithShapeEqualTo(&expected)));
}

TEST(XlaBuilderTest, UnboundedAllToAllFullyDynamic) {
  XlaBuilder b(TestName());
  // Both dimensions unbounded: the result is fully unbounded as well.
  TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, ?]"));
  TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, ?]"));
  AllToAll(/*operand=*/Parameter(&b, 0, operand, "operand"),
           /*split_dimension=*/0,
           /*concat_dimension=*/1,
           /*split_count=*/3,
           /*replica_groups=*/{});
  TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr<HloModule> module,
                          BuildHloModule(b));
  // NOTE: removed a leftover `std::cout << module->ToString()` debug print;
  // tests should not write to stdout.
  EXPECT_THAT(GetRoot(*module),
              GmockMatch(m::Op().WithShapeEqualTo(&expected)));
}

TEST(XlaBuilderTest, UnboundedAllToAllTupleVariadicUnsupported) {
  XlaBuilder b(TestName());
  // The variadic AllToAllTuple overload must reject unbounded dynamic shapes.
  TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 15]{1,0}"));
  const XlaOp operand0 = Parameter(&b, 0, operand, "operand0");
  const XlaOp operand1 = Parameter(&b, 1, operand, "operand1");
  b.ReportErrorOrReturn(AllToAllTuple(/*operands=*/{operand0, operand1},
                                      /*replica_groups=*/{}));
  EXPECT_THAT(
      BuildHloModule(b),
      StatusIs(_,
               HasSubstr(
                   "AllToAllTuple does not support unbounded dynamic shapes")));
}

TEST(XlaBuilderTest, UnboundedAllToAllTupleUnsupported) {
  XlaBuilder b(TestName());
  // The single-operand AllToAllTuple overload must reject unbounded dynamic
  // shapes.
  TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 15]{1,0}"));
  const XlaOp param = Parameter(&b, 0, operand, "operand");
  b.ReportErrorOrReturn(AllToAllTuple(/*operand=*/param,
                                      /*split_dimension=*/0,
                                      /*concat_dimension=*/1,
                                      /*split_count=*/3,
                                      /*replica_groups=*/{}));
  EXPECT_THAT(
      BuildHloModule(b),
      StatusIs(_,
               HasSubstr(
                   "AllToAllTuple does not support unbounded dynamic shapes")));
}

TEST(XlaBuilderTest, BoundedAllToAllTupleUnsupported) {
  XlaBuilder b(TestName());
  // Bounded dynamic shapes (<=15) are also rejected; the error comes from
  // AllToAll shape inference.
  TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[3, <=15]{1,0}"));
  const XlaOp param = Parameter(&b, 0, operand, "operand");
  b.ReportErrorOrReturn(AllToAllTuple(/*operand=*/param,
                                      /*split_dimension=*/0,
                                      /*concat_dimension=*/1,
                                      /*split_count=*/3,
                                      /*replica_groups=*/{}));
  EXPECT_THAT(
      BuildHloModule(b),
      StatusIs(_,
               HasSubstr("AllToAll does not support bounded dynamic shapes")));
}

TEST(XlaBuilderTest, BoundedAllToAllUnsupported) {
  XlaBuilder b(TestName());
  TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[3, <=15]{1,0}"));
  // BUG FIX: this test previously called AllToAllTuple, making it an exact
  // duplicate of BoundedAllToAllTupleUnsupported and leaving the array
  // AllToAll path untested. Exercise the array AllToAll, whose shape
  // inference rejects bounded dynamic shapes.
  AllToAll(/*operand=*/Parameter(&b, 0, operand, "operand"),
           /*split_dimension=*/0,
           /*concat_dimension=*/1,
           /*split_count=*/3,
           /*replica_groups=*/{});
  EXPECT_THAT(
      BuildHloModule(b),
      StatusIs(_,
               HasSubstr("AllToAll does not support bounded dynamic shapes")));
}

TEST(XlaBuilderTest, UnboundedAnd) {
XlaBuilder b(TestName());
TF_ASSERT_OK_AND_ASSIGN(const Shape lhs,
Expand Down
30 changes: 24 additions & 6 deletions third_party/xla/xla/service/shape_inference.cc
Expand Up @@ -2492,6 +2492,8 @@ ShapeInference::InferScalarBroadcastShape(absl::Span<const Shape> shapes) {
const Shape& shape, int64_t split_dimension, int64_t concat_dimension,
int64_t split_count) {
TF_RET_CHECK(split_count > 0);
TF_RET_CHECK(!shape.is_bounded_dynamic())
<< "AllToAll does not support bounded dynamic shapes";
if (split_dimension >= shape.rank() || split_dimension < 0) {
return InvalidArgument(
"AllToAll split_dimension %d is out-of-bounds in shape %s.",
Expand All @@ -2502,25 +2504,41 @@ ShapeInference::InferScalarBroadcastShape(absl::Span<const Shape> shapes) {
"AllToAll concat_dimension %d is out-of-bounds in shape %s.",
concat_dimension, ShapeUtil::HumanString(shape));
}
if (shape.dimensions(split_dimension) % split_count != 0) {
int64_t split_dimension_size = shape.dimensions(split_dimension);
if (!IsUnboundedDynamicSize(split_dimension_size) &&
split_dimension_size % split_count != 0) {
return InvalidArgument(
"AllToAll split dimension size %d must be dividable by split_count "
"%d.",
shape.dimensions(split_dimension), split_count);
split_dimension_size, split_count);
}
std::vector<int64_t> new_dimensions(shape.dimensions().begin(),
shape.dimensions().end());
new_dimensions[split_dimension] /= split_count;
new_dimensions[concat_dimension] *= split_count;
return ShapeUtil::MakeShape(shape.element_type(), new_dimensions);
new_dimensions[split_dimension] =
IsUnboundedDynamicSize(new_dimensions[split_dimension])
? Shape::kUnboundedSize
: new_dimensions[split_dimension] / split_count;
new_dimensions[concat_dimension] =
IsUnboundedDynamicSize(new_dimensions[concat_dimension])
? Shape::kUnboundedSize
: new_dimensions[concat_dimension] * split_count;

const std::vector<bool> dynamic_dimensions(shape.dynamic_dimensions().begin(),
shape.dynamic_dimensions().end());
return ShapeUtil::MakeShape(shape.element_type(), new_dimensions,
dynamic_dimensions);
}

/* static */ absl::StatusOr<Shape> ShapeInference::InferAllToAllTupleShape(
absl::Span<const Shape* const> operand_shapes) {
// An Alltoall HLO instruction receives N operands (with the same shape) and
// An AllToAll HLO instruction receives N operands (with the same shape) and
// returns a tuple that contains N array shapes.
TF_RET_CHECK(!operand_shapes.empty());
for (int i = 0; i < operand_shapes.size(); i++) {
if (operand_shapes[i]->is_unbounded_dynamic()) {
return InvalidArgument(
"AllToAllTuple does not support unbounded dynamic shapes");
}
if (!Shape::Equal().IgnoreMemorySpaceInLayout()(*operand_shapes[0],
*operand_shapes[i])) {
return InvalidArgument(
Expand Down
26 changes: 26 additions & 0 deletions third_party/xla/xla/service/shape_inference_test.cc
Expand Up @@ -4056,6 +4056,32 @@ TEST_F(ShapeInferenceTest, UnboundedAllReduce) {
<< " expected: " << ShapeUtil::HumanString(expected);
}

TEST_F(ShapeInferenceTest, UnboundedAllToAll) {
  // Splitting and concatenating along the same unbounded dimension must leave
  // the inferred shape identical to the operand shape.
  TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]"));
  TF_ASSERT_OK_AND_ASSIGN(const Shape expected, ParseShape("f32[?, 10]"));
  TF_ASSERT_OK_AND_ASSIGN(const Shape inferred_shape,
                          ShapeInference::InferAllToAllShape(
                              /*shape=*/operand, /*split_dimension=*/0,
                              /*concat_dimension=*/0, /*split_count=*/3));
  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, expected))
      << "inferred: " << ShapeUtil::HumanString(inferred_shape)
      << " expected: " << ShapeUtil::HumanString(expected);
}

TEST_F(ShapeInferenceTest, UnboundedAllToAllTupleUnsupported) {
  // AllToAllTuple rejects unbounded dynamic operands outright, so no expected
  // result shape is needed. (Removed an unused `expected` local that parsed a
  // tuple shape and was never referenced.)
  TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]"));
  const absl::StatusOr<Shape> inferred_shape =
      ShapeInference::InferAllToAllTupleShape(
          /*operand_shapes=*/{&operand, &operand});
  EXPECT_THAT(
      inferred_shape.status().message(),
      HasSubstr("AllToAllTuple does not support unbounded dynamic shapes"));
}

TEST_P(UnboundedLogicalOpShapeInferenceTest, UnboundedAnd) {
TF_ASSERT_OK_AND_ASSIGN(const Shape lhs, ParseShape(GetParam().lhs));
TF_ASSERT_OK_AND_ASSIGN(const Shape rhs, ParseShape(GetParam().rhs));
Expand Down
7 changes: 6 additions & 1 deletion third_party/xla/xla/tools/multihost_hlo_runner/BUILD
Expand Up @@ -117,7 +117,10 @@ xla_test(
"notap",
],
},
backends = ["gpu"],
backends = [
"cpu",
"gpu",
],
data = [
"data/sharded_16_devices.hlo",
"data/sharded_2_devices.hlo",
Expand All @@ -127,6 +130,8 @@ xla_test(
tags = ["nomac"],
deps = [
":functional_hlo_runner",
"//xla:statusor",
"//xla/pjrt:pjrt_client",
"//xla/tests:filecheck",
"@com_google_googletest//:gtest_main",
"@local_tsl//tsl/lib/core:status_test_util",
Expand Down

0 comments on commit 78ffcaf

Please sign in to comment.