Use ShapeUtil::HumanString instead of calling Shape::ToString directly in `xla_builder.cc`.

PiperOrigin-RevId: 625081373
ghpvnist authored and tensorflower-gardener committed May 6, 2024
1 parent c8aba81 commit f61ff93
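The pattern applied throughout xla_builder.cc is small and mechanical; a minimal sketch follows. The helper function and message text are illustrative only (just the ToString-to-HumanString substitution mirrors the actual hunks), and the include paths assume the standalone XLA source layout.

#include "absl/status/status.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/util.h"  // xla::InvalidArgument

namespace xla {

// Illustrative helper, not part of the commit: error paths that previously
// formatted shapes with Shape::ToString() now use ShapeUtil::HumanString().
absl::Status CheckElementTypesMatch(const Shape& operand_shape,
                                    const Shape& output_shape) {
  if (operand_shape.element_type() != output_shape.element_type()) {
    // Before: ..., operand_shape.ToString(), output_shape.ToString());
    return InvalidArgument(
        "Element type of operand %s and output %s must match",
        ShapeUtil::HumanString(operand_shape),
        ShapeUtil::HumanString(output_shape));
  }
  return absl::OkStatus();
}

}  // namespace xla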
Showing 7 changed files with 1,100 additions and 36 deletions.
472 changes: 472 additions & 0 deletions third_party/triton/temporary/pipelining.patch

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions third_party/triton/temporary/series.bzl
@@ -6,4 +6,5 @@ internal patch during the next triton integration process.
"""

temporary_patch_list = [
"//third_party/triton/temporary:pipelining.patch",
]
472 changes: 472 additions & 0 deletions third_party/xla/third_party/triton/temporary/pipelining.patch

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions third_party/xla/third_party/triton/temporary/series.bzl
@@ -6,4 +6,5 @@ internal patch during the next triton integration process.
"""

temporary_patch_list = [
"//third_party/triton/temporary:pipelining.patch",
]
78 changes: 61 additions & 17 deletions third_party/xla/xla/client/xla_builder.cc
@@ -869,7 +869,43 @@ absl::StatusOr<XlaComputation> XlaBuilder::Build(
return OkStatus();
}

XlaOp XlaBuilder::DynamicBroadcastInDim(
XlaOp XlaBuilder::MhloDynamicReshape(XlaOp operand, XlaOp output_shape,
const Shape& shape) {
return ReportErrorOrReturn([&]() -> absl::StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
if (operand_shape->element_type() != shape.element_type()) {
return InvalidArgument(
"Element type of operand %s and output %s must match",
ShapeUtil::HumanString(*operand_shape),
ShapeUtil::HumanString(shape));
}
if (operand_shape->is_static() && shape.is_static() &&
ShapeUtil::ElementsIn(*operand_shape) != ShapeUtil::ElementsIn(shape)) {
return InvalidArgument(
"MhloDynamicReshape has mismatched element counts: from=%d (%s) "
"to=%d (%s)",
ShapeUtil::ElementsIn(*operand_shape),
ShapeUtil::HumanString(*operand_shape), ShapeUtil::ElementsIn(shape),
ShapeUtil::HumanString(shape));
}
TF_ASSIGN_OR_RETURN(const Shape* output_shape_shape,
GetShapePtr(output_shape));
if (output_shape_shape->dimensions(0) != shape.rank()) {
return InvalidArgument(
"output_shape dimension size=%d (%s) and rank of shape=%d (%s) must "
"match",
output_shape_shape->dimensions(0),
ShapeUtil::HumanString(*output_shape_shape), shape.rank(),
ShapeUtil::HumanString(shape));
}
return xla::CustomCall(operand.builder(), "mhlo.dynamic_reshape",
/*operands=*/{operand, output_shape},
/*shape=*/shape,
/*opaque=*/"");
});
};

XlaOp XlaBuilder::MhloDynamicBroadcastInDim(
const XlaOp operand, const XlaOp output_dimensions,
absl::Span<const int64_t> broadcast_dimensions, const Shape& output_shape) {
return ReportErrorOrReturn([&]() -> absl::StatusOr<XlaOp> {
@@ -879,7 +915,7 @@ XlaOp XlaBuilder::DynamicBroadcastInDim(

if (!output_dimensions_shape->IsInteger()) {
return InvalidArgument("output_dimensions must be an integer type %s",
output_dimensions_shape->ToString());
ShapeUtil::HumanString(*output_dimensions_shape));
}

if (output_dimensions_shape->rank() != 1) {
@@ -954,8 +990,8 @@ absl::StatusOr<XlaOp> XlaBuilder::InDimBroadcast(
TF_RET_CHECK(operand_shape->is_bounded_dynamic_dimension(
it - broadcast_dimensions.begin()) ==
shape.is_bounded_dynamic_dimension(i))
<< " i: " << i << ", shape: " << shape.ToString()
<< ", operand_shape: " << operand_shape->ToString();
<< " i: " << i << ", shape: " << ShapeUtil::HumanString(shape)
<< ", operand_shape: " << ShapeUtil::HumanString(*operand_shape);
} else {
// Non-broadcast dimensions must be static.
TF_RET_CHECK(shape.is_static_dimension(i));
@@ -1084,7 +1120,7 @@ absl::StatusOr<std::vector<XlaOp>> ExtractDimensionSizesAndPadOnesToLeft(
// Broadcast `scalar` to `output_shape` with all shapes static at runtime. If a
// dimension of `output_shape` is dynamic, get the dimension size of the dynamic
// dimension from `output` and reshape them to `tensor<1xi32>`. This is used as
// one of the inputs to DynamicBroadcastInDim.
// one of the inputs to MhloDynamicBroadcastInDim.
absl::StatusOr<XlaOp> BroadcastScalarToOutputShapeWithUnbounded(
XlaBuilder* builder, XlaOp scalar, XlaOp output,
const Shape& output_shape) {
@@ -1100,7 +1136,7 @@ absl::StatusOr<XlaOp> BroadcastScalarToOutputShapeWithUnbounded(
/*values=*/{static_cast<int32_t>(output_shape.dimensions(i))})
: Reshape(GetDimensionSize(output, i), {1});
}
return DynamicBroadcastInDim(
return MhloDynamicBroadcastInDim(
scalar, /*output_dimensions=*/ConcatInDim(builder, output_sizes, 0), {},
output_shape);
}
@@ -1117,8 +1153,8 @@ absl::StatusOr<XlaOp> DegenerateBroadcastWithUnbounded(
std::iota(broadcast_dimensions.begin(), broadcast_dimensions.end(),
output_shape.rank() - operand_shape->rank());

return DynamicBroadcastInDim(operand, output_dimensions, broadcast_dimensions,
output_shape);
return MhloDynamicBroadcastInDim(operand, output_dimensions,
broadcast_dimensions, output_shape);
}

// Helper struct to store the result of `BroadcastToOutputShapeWithUnbounded`.
@@ -1387,7 +1423,7 @@ XlaOp XlaBuilder::Iota(const Shape& shape, int64_t iota_dimension) {
if (!shape.is_static()) {
return InvalidArgument(
"The output of iota must not have dynamic dimensions: %s",
shape.ToString());
ShapeUtil::HumanString(shape));
}
HloInstructionProto instr;
*instr.mutable_shape() = shape.ToProto();
@@ -1479,7 +1515,7 @@ XlaOp XlaBuilder::BroadcastInDim(
operand_shape->element_type(), out_dim_size));
TF_RET_CHECK(!output_shape.is_unbounded_dynamic())
<< "BroadcastInDim output must shape be static or bounded dynamic "
<< output_shape.ToString();
<< ShapeUtil::HumanString(output_shape);
int64_t broadcast_rank = broadcast_dimensions.size();
if (operand_shape->rank() != broadcast_rank) {
return InvalidArgument(
@@ -3164,13 +3200,14 @@ XlaOp XlaBuilder::AllReduceImpl(XlaOp operand,
if (layout) {
if (!LayoutUtil::HasLayout(*layout)) {
return InvalidArgument("shape_with_layout must have the layout set: %s",
layout->ToString());
ShapeUtil::HumanString(*layout));
}
if (!ShapeUtil::Compatible(*layout, *operand_shape)) {
return InvalidArgument(
"Provided shape_with_layout must be compatible with the "
"operand shape: %s vs %s",
layout->ToString(), operand_shape->ToString());
ShapeUtil::HumanString(*layout),
ShapeUtil::HumanString(*operand_shape));
}
instr.set_constrain_layout(true);
if (operand_shape->IsTuple() && !inferred_shape.IsTuple()) {
@@ -3812,7 +3849,8 @@ XlaOp XlaBuilder::AllToAllTuple(
return InvalidArgument(
"Provided layout must be compatible with the operands' shape. "
"The layout is %s, but operand %d has shape %s.",
layout->ToString(), i, shape.tuple_shapes(i).ToString());
layout->ToString(), i,
ShapeUtil::HumanString(shape.tuple_shapes(i)));
}
*(shape.mutable_tuple_shapes(i)->mutable_layout()) = *layout;
}
@@ -4726,10 +4764,16 @@ XlaOp BroadcastInDim(const XlaOp operand,
broadcast_dimensions);
}

XlaOp DynamicBroadcastInDim(const XlaOp operand, const XlaOp output_dimensions,
absl::Span<const int64_t> broadcast_dimensions,
const Shape& output_shape) {
return operand.builder()->DynamicBroadcastInDim(
XlaOp MhloDynamicReshape(const XlaOp operand, const XlaOp output_shape,
const Shape& shape) {
return operand.builder()->MhloDynamicReshape(operand, output_shape, shape);
}

XlaOp MhloDynamicBroadcastInDim(const XlaOp operand,
const XlaOp output_dimensions,
absl::Span<const int64_t> broadcast_dimensions,
const Shape& output_shape) {
return operand.builder()->MhloDynamicBroadcastInDim(
operand, output_dimensions, broadcast_dimensions, output_shape);
}

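The new MhloDynamicReshape builder method above rejects mismatched element types, mismatched element counts (when both shapes are static), and a sizes tensor whose length differs from the result rank. Below is a minimal sketch of a call that satisfies those checks; all names and literal values are illustrative, not taken from the commit.

#include "xla/client/xla_builder.h"
#include "xla/shape_util.h"

// Illustrative only: reshape f32[6] to f32[2,3] via the exported
// mhlo.dynamic_reshape custom call, with the target sizes supplied as a
// rank-1 runtime tensor of length 2 (== rank of the result shape).
xla::XlaOp DynamicReshapeSketch(xla::XlaBuilder* builder) {
  xla::XlaOp operand =
      xla::ConstantR1<float>(builder, {0.f, 1.f, 2.f, 3.f, 4.f, 5.f});
  xla::XlaOp output_sizes = xla::ConstantR1<int32_t>(builder, {2, 3});
  xla::Shape result_shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 3});
  return xla::MhloDynamicReshape(operand, output_sizes, result_shape);
}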
26 changes: 19 additions & 7 deletions third_party/xla/xla/client/xla_builder.h
@@ -524,9 +524,10 @@ class XlaBuilder {
// op from the XlaBuilder. This is only intended for export to MHLO or
// StableHLO, and cannot be compiled. Only static output_dimensions are
// allowed, and broadcast_dimensions is verified.
XlaOp DynamicBroadcastInDim(XlaOp operand, XlaOp output_dimensions,
absl::Span<const int64_t> broadcast_dimensions,
const Shape& output_shape);
XlaOp MhloDynamicBroadcastInDim(
XlaOp operand, XlaOp output_dimensions,
absl::Span<const int64_t> broadcast_dimensions,
const Shape& output_shape);

XlaOp Pad(XlaOp operand, XlaOp padding_value,
const PaddingConfig& padding_config);
@@ -551,6 +552,9 @@
absl::Span<const int64_t> new_size_bounds,
const std::vector<bool>& dims_are_dynamic);

XlaOp MhloDynamicReshape(XlaOp operand, XlaOp output_shape,
const Shape& shape);

XlaOp Collapse(XlaOp operand, absl::Span<const int64_t> dimensions);

XlaOp Slice(XlaOp operand, absl::Span<const int64_t> start_indices,
@@ -1212,7 +1216,7 @@
absl::Span<const int64_t> out_dim_size,
absl::Span<const int64_t> broadcast_dimensions);

friend XlaOp DynamicBroadcastInDim(
friend XlaOp MhloDynamicBroadcastInDim(
XlaOp operand, XlaOp output_dimensions,
absl::Span<const int64_t> broadcast_dimensions,
const Shape& output_shape);
@@ -1236,6 +1240,9 @@
absl::Span<const int64_t> new_size_bounds,
const std::vector<bool>& dims_are_dynamic);

friend XlaOp MhloDynamicReshape(XlaOp operand, XlaOp output_shape,
const Shape& shape);

friend XlaOp ReshapeWithInferredDimension(XlaOp operand,
absl::Span<const int64_t> new_sizes,
int64_t inferred_dimension);
@@ -1918,9 +1925,9 @@ XlaOp BroadcastInDim(XlaOp operand, absl::Span<const int64_t> out_dim_size,
// StableHLO, and cannot be compiled. See
// https://www.tensorflow.org/mlir/hlo_ops#mhlodynamic_broadcast_in_dim_mhlodynamicbroadcastindimop.
// for the op semantics.
XlaOp DynamicBroadcastInDim(XlaOp operand, XlaOp output_dimensions,
absl::Span<const int64_t> broadcast_dimensions,
const Shape& output_shape);
XlaOp MhloDynamicBroadcastInDim(XlaOp operand, XlaOp output_dimensions,
absl::Span<const int64_t> broadcast_dimensions,
const Shape& output_shape);

// Copies the input operand to the output. This operation is for internal
// purpose and is only used by the compiler for optimization purposes or to
@@ -1966,6 +1973,11 @@ XlaOp DynamicReshape(XlaOp operand, absl::Span<const XlaOp> dim_sizes,
absl::Span<const int64_t> new_size_bounds,
const std::vector<bool>& dims_are_dynamic);

// This is an experimental API for creating the mhlo.dynamic_reshape op from the
// XlaBuilder. This is only intended for export to MHLO or StableHLO, and cannot
// be compiled.
XlaOp MhloDynamicReshape(XlaOp operand, XlaOp output_shape, const Shape& shape);

// Enqueues an operation onto the computation that collapses the operand,
// from first to last dimension (C order), then reshapes it to the given
// dimension sizes. Conceptually, this is a limited form of "shape casting".
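Like mhlo.dynamic_reshape, the renamed MhloDynamicBroadcastInDim free function declared above is only intended for MHLO/StableHLO export and cannot be compiled. A minimal usage sketch under the same caveat (all names and values are illustrative, not from the commit):

#include "xla/client/xla_builder.h"
#include "xla/shape_util.h"

// Illustrative only: broadcast a scalar to a 2x3 result, passing the output
// dimensions as a rank-1 integer tensor. A scalar operand has rank 0, so
// broadcast_dimensions is empty, mirroring how the builder itself calls this
// op in BroadcastScalarToOutputShapeWithUnbounded.
xla::XlaOp BroadcastScalarSketch(xla::XlaBuilder* builder) {
  xla::XlaOp scalar = xla::ConstantR0<float>(builder, 1.0f);
  xla::XlaOp output_dimensions = xla::ConstantR1<int32_t>(builder, {2, 3});
  xla::Shape output_shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 3});
  return xla::MhloDynamicBroadcastInDim(scalar, output_dimensions,
                                        /*broadcast_dimensions=*/{},
                                        output_shape);
}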
