Skip to content

Commit

Permalink
Automated Code Change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 609296056
  • Loading branch information
tensorflower-gardener committed Feb 22, 2024
1 parent 5039c3a commit 83870a0
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 91 deletions.
121 changes: 63 additions & 58 deletions third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc
Expand Up @@ -96,9 +96,10 @@ namespace {
using primitive_util::NativeTypeOf;

template <typename OperandT>
StatusOr<Literal> Compare(const Shape& shape, Comparison comparison,
LiteralSlice lhs_literal, LiteralSlice rhs_literal) {
auto populate = [&](auto compare_op) -> StatusOr<Literal> {
absl::StatusOr<Literal> Compare(const Shape& shape, Comparison comparison,
LiteralSlice lhs_literal,
LiteralSlice rhs_literal) {
auto populate = [&](auto compare_op) -> absl::StatusOr<Literal> {
Literal result(shape);
TF_RETURN_IF_ERROR(result.PopulateParallel<bool>(
[&](absl::Span<const int64_t> multi_index, int /*thread_id*/) {
Expand Down Expand Up @@ -147,7 +148,7 @@ StatusOr<Literal> Compare(const Shape& shape, Comparison comparison,
std::optional<bool> GetInstructionStaticValueAsBool(
const HloInstruction* instruction) {
HloEvaluator evaluator;
StatusOr<Literal> static_value = evaluator.Evaluate(
absl::StatusOr<Literal> static_value = evaluator.Evaluate(
instruction, /*recursively_evaluate_nonconstant_operands=*/true);
if (static_value.ok()) {
return static_value->GetFirstElement<bool>();
Expand Down Expand Up @@ -251,7 +252,7 @@ struct DynamicOrStaticInteger {
std::optional<DynamicOrStaticInteger> GetInstructionValueAsInteger(
const HloInstruction* instruction) {
HloEvaluator evaluator;
StatusOr<Literal> static_value = evaluator.Evaluate(
absl::StatusOr<Literal> static_value = evaluator.Evaluate(
instruction, /*recursively_evaluate_nonconstant_operands=*/true);
if (static_value.ok()) {
if (instruction->shape().element_type() == PrimitiveType::PRED) {
Expand Down Expand Up @@ -859,7 +860,7 @@ HloEvaluator::HloEvaluator(int64_t max_loop_iterations)
});
}

StatusOr<Literal> HloEvaluator::Evaluate(
absl::StatusOr<Literal> HloEvaluator::Evaluate(
const HloComputation& computation,
absl::Span<const Literal* const> arg_literals) {
CHECK(computation.parent() != nullptr);
Expand Down Expand Up @@ -920,7 +921,7 @@ StatusOr<Literal> HloEvaluator::Evaluate(
return result.Clone();
}

StatusOr<Literal> HloEvaluator::Evaluate(
absl::StatusOr<Literal> HloEvaluator::Evaluate(
const HloInstruction* instruction,
bool recursively_evaluate_nonconstant_operands) {
arg_literals_.clear();
Expand Down Expand Up @@ -955,7 +956,7 @@ bool HloEvaluator::TryEvaluate(const HloInstruction* instruction,
return true;
}

StatusOr<Literal> HloEvaluator::EvaluateWithSubstitutions(
absl::StatusOr<Literal> HloEvaluator::EvaluateWithSubstitutions(
const HloInstruction* instruction,
const absl::flat_hash_map<const HloInstruction*, const Literal*>&
substitutions) {
Expand Down Expand Up @@ -983,7 +984,7 @@ StatusOr<Literal> HloEvaluator::EvaluateWithSubstitutions(
return result;
}

StatusOr<Literal> HloEvaluator::EvaluateElementwiseBinaryOp(
absl::StatusOr<Literal> HloEvaluator::EvaluateElementwiseBinaryOp(
HloOpcode opcode, const Literal& lhs, const Literal& rhs) {
std::unique_ptr<HloInstruction> lhs_instr =
HloInstruction::CreateConstant(lhs.Clone());
Expand All @@ -998,7 +999,7 @@ StatusOr<Literal> HloEvaluator::EvaluateElementwiseBinaryOp(
return result;
}

StatusOr<Literal> HloEvaluator::EvaluateElementwiseTernaryOp(
absl::StatusOr<Literal> HloEvaluator::EvaluateElementwiseTernaryOp(
HloOpcode opcode, const Literal& lhs, const Literal& rhs,
const Literal& ehs) {
std::unique_ptr<HloInstruction> lhs_instr =
Expand All @@ -1016,7 +1017,7 @@ StatusOr<Literal> HloEvaluator::EvaluateElementwiseTernaryOp(
return Evaluate(cloned_instruction.get());
}

StatusOr<Literal> HloEvaluator::EvaluateElementwiseCompareOp(
absl::StatusOr<Literal> HloEvaluator::EvaluateElementwiseCompareOp(
ComparisonDirection direction, const Literal& lhs, const Literal& rhs) {
std::unique_ptr<HloInstruction> lhs_instr =
HloInstruction::CreateConstant(lhs.Clone());
Expand All @@ -1032,7 +1033,7 @@ StatusOr<Literal> HloEvaluator::EvaluateElementwiseCompareOp(
return result;
}

StatusOr<Literal> HloEvaluator::EvaluateElementwiseUnaryOp(
absl::StatusOr<Literal> HloEvaluator::EvaluateElementwiseUnaryOp(
HloOpcode opcode, const Literal& operand) {
std::unique_ptr<HloInstruction> operand_instr =
HloInstruction::CreateConstant(operand.Clone());
Expand All @@ -1046,7 +1047,7 @@ StatusOr<Literal> HloEvaluator::EvaluateElementwiseUnaryOp(
return result;
}

StatusOr<Literal> HloEvaluator::EvaluateDotOp(
absl::StatusOr<Literal> HloEvaluator::EvaluateDotOp(
const DotDimensionNumbers& dim_numbers,
const PrecisionConfig& precision_config, const Literal& lhs,
const Literal& rhs) {
Expand Down Expand Up @@ -1189,7 +1190,7 @@ Status HloEvaluator::EvaluateInternal(
}
if (!tuple_points_to_analysis_cache_) {
HloModule* module = instruction->GetModule();
StatusOr<std::unique_ptr<TuplePointsToAnalysis>>
absl::StatusOr<std::unique_ptr<TuplePointsToAnalysis>>
tuple_points_to_analysis = TuplePointsToAnalysis::Run(module);
if (tuple_points_to_analysis.ok()) {
tuple_points_to_analysis_cache_ =
Expand Down Expand Up @@ -2347,7 +2348,7 @@ class OutputBatchIndexToInputIndex {
// same storage for all invocations.
//
// This returns a Span into memory owned by the class.
StatusOr<absl::Span<const int64_t>> operator()(
absl::StatusOr<absl::Span<const int64_t>> operator()(
absl::Span<const int64_t> output_index) {
PropagateOutputIndexGatherDimsToIndexVectorIndex(output_index);
TF_RETURN_IF_ERROR(FetchIndexVector());
Expand Down Expand Up @@ -2467,7 +2468,7 @@ class OutputOffsetIndexToInputIndex {
// result (input_index_), mutating it in place.
//
// This returns a Span into memory owned by the class.
StatusOr<absl::Span<const int64_t>> operator()(
absl::StatusOr<absl::Span<const int64_t>> operator()(
absl::Span<const int64_t> output_index) {
PropagateOutputIndexWindowDimsToInputIndex(output_index);
return absl::Span<const int64_t>(input_index_);
Expand Down Expand Up @@ -2507,9 +2508,9 @@ class OutputOffsetIndexToInputIndex {
// Reshapes the gather indices input to have a trailing degenerate `1` dimension
// if necessary. Hands over the ownership of the newly created literal (if
// there is one) to `reshaped_start_indices`.
static StatusOr<std::reference_wrapper<const Literal>> ReshapedGatherIndices(
int64_t index_vector_dim, const Literal& start_indices,
Literal* reshaped_start_indices) {
static absl::StatusOr<std::reference_wrapper<const Literal>>
ReshapedGatherIndices(int64_t index_vector_dim, const Literal& start_indices,
Literal* reshaped_start_indices) {
if (start_indices.shape().dimensions_size() != index_vector_dim) {
return std::cref(start_indices);
}
Expand Down Expand Up @@ -2574,7 +2575,8 @@ Status HloEvaluator::HandleGather(const HloInstruction* gather) {
auto gather_inner_loop_body =
[&](absl::Span<const int64_t> output_window_index,
absl::Span<const int64_t> input_gather_index,
absl::Span<const int64_t> output_gather_index) -> StatusOr<bool> {
absl::Span<const int64_t> output_gather_index)
-> absl::StatusOr<bool> {
TF_ASSIGN_OR_RETURN(
absl::Span<const int64_t> input_window_index,
output_offset_index_to_input_index(output_window_index));
Expand Down Expand Up @@ -2608,7 +2610,8 @@ Status HloEvaluator::HandleGather(const HloInstruction* gather) {
};

auto gather_outer_loop_body =
[&](absl::Span<const int64_t> output_gather_index) -> StatusOr<bool> {
[&](absl::Span<const int64_t> output_gather_index)
-> absl::StatusOr<bool> {
TF_ASSIGN_OR_RETURN(absl::Span<const int64_t> input_gather_index,
output_batch_index_to_input_index(output_gather_index));
TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
Expand All @@ -2628,7 +2631,7 @@ namespace {
// Reshapes the scatter indices input to have a trailing degenerate `1`
// dimension if necessary. Hands over the ownership of the newly created
// literal (if there is one) to `reshaped_indices`.
StatusOr<std::reference_wrapper<const Literal>> ReshapedScatterIndices(
absl::StatusOr<std::reference_wrapper<const Literal>> ReshapedScatterIndices(
int64_t index_vector_dim, const Literal& indices,
Literal* reshaped_indices) {
if (indices.shape().dimensions_size() != index_vector_dim) {
Expand Down Expand Up @@ -2750,7 +2753,7 @@ class UpdateScatterIndexToInputIndex {
// same storage for all invocations.
//
// This returns a Span into memory owned by the class.
StatusOr<absl::Span<const int64_t>> operator()(
absl::StatusOr<absl::Span<const int64_t>> operator()(
absl::Span<const int64_t> update_index) {
PropagateUpdateIndexScatterDimsToIndexVectorIndex(update_index);
TF_RETURN_IF_ERROR(FetchIndexVector());
Expand Down Expand Up @@ -2873,7 +2876,7 @@ class UpdateWindowIndexToInputIndex {
// result (input_index_), mutating it in place.
//
// This returns a Span into memory owned by the class.
StatusOr<absl::Span<const int64_t>> operator()(
absl::StatusOr<absl::Span<const int64_t>> operator()(
absl::Span<const int64_t> update_index) {
PropagateUpdateIndexWindowDimsToInputIndex(update_index);
return absl::Span<const int64_t>(input_index_);
Expand Down Expand Up @@ -2966,7 +2969,8 @@ Status HloEvaluator::HandleScatter(const HloInstruction* hlo) {
auto scatter_inner_loop_body =
[&](absl::Span<const int64_t> update_window_index,
absl::Span<const int64_t> input_scatter_index,
absl::Span<const int64_t> update_scatter_index) -> StatusOr<bool> {
absl::Span<const int64_t> update_scatter_index)
-> absl::StatusOr<bool> {
TF_ASSIGN_OR_RETURN(
absl::Span<const int64_t> input_window_index,
update_window_index_to_input_index(update_window_index));
Expand Down Expand Up @@ -3018,7 +3022,8 @@ Status HloEvaluator::HandleScatter(const HloInstruction* hlo) {
};

auto scatter_outer_loop_body =
[&](absl::Span<const int64_t> update_scatter_index) -> StatusOr<bool> {
[&](absl::Span<const int64_t> update_scatter_index)
-> absl::StatusOr<bool> {
TF_ASSIGN_OR_RETURN(
absl::Span<const int64_t> input_scatter_index,
update_scatter_index_to_input_index(update_scatter_index));
Expand Down Expand Up @@ -3416,10 +3421,10 @@ Status HloEvaluator::HandleSelect(const HloInstruction* select) {

namespace {

StatusOr<Literal> CreateScalarLiteral(int64_t value,
PrimitiveType element_type) {
absl::StatusOr<Literal> CreateScalarLiteral(int64_t value,
PrimitiveType element_type) {
return primitive_util::PrimitiveTypeSwitch<StatusOr<Literal>>(
[&](auto primitive_type_constant) -> StatusOr<Literal> {
[&](auto primitive_type_constant) -> absl::StatusOr<Literal> {
if constexpr (primitive_util::IsIntegralType(primitive_type_constant)) {
return LiteralUtil::CreateR0(
static_cast<NativeTypeOf<primitive_type_constant>>(value));
Expand All @@ -3432,7 +3437,7 @@ StatusOr<Literal> CreateScalarLiteral(int64_t value,
// Parses the while loop if it matches one of the known patterns. Returns the
// value of the loop induction variable after the loop execution if the loop is
// static.
StatusOr<Literal> TryParseAndEvaluateWhileInductionVar(
absl::StatusOr<Literal> TryParseAndEvaluateWhileInductionVar(
const HloInstruction* while_hlo) {
std::optional<ParsedWhileLoop> parsed_while_loop =
PatternMatchParseWhileLoop(while_hlo);
Expand Down Expand Up @@ -3507,7 +3512,7 @@ Status HloEvaluator::HandleWhile(const HloInstruction* while_hlo) {
dynamic_dimension_inference_);
while (keep_going) {
if (max_loop_iterations_ >= 0 && iteration_count++ > max_loop_iterations_) {
StatusOr<Literal> result =
absl::StatusOr<Literal> result =
TryParseAndEvaluateWhileInductionVar(while_hlo);
if (result.ok()) {
lcv = std::move(result).value();
Expand Down Expand Up @@ -3546,11 +3551,11 @@ Literal ExtractLiteralFromIndexPositions(const Literal& from,
return LiteralUtil::CreateR1<NativeT>(values);
}

StatusOr<Literal> ExtractFromIndexPositions(const Literal& from,
absl::Span<int64_t const> indices) {
absl::StatusOr<Literal> ExtractFromIndexPositions(
const Literal& from, absl::Span<int64_t const> indices) {
PrimitiveType type = from.shape().element_type();
return primitive_util::PrimitiveTypeSwitch<StatusOr<Literal>>(
[&](auto primitive_type_constant) -> StatusOr<Literal> {
[&](auto primitive_type_constant) -> absl::StatusOr<Literal> {
if constexpr (primitive_util::IsArrayType(primitive_type_constant)) {
return ExtractLiteralFromIndexPositions<
NativeTypeOf<primitive_type_constant>>(from, indices);
Expand Down Expand Up @@ -3609,9 +3614,9 @@ void IterateThroughWindow(
}

template <typename Fp, typename Uint, typename ResultT>
StatusOr<Literal> StochasticConvertOp(const Literal& operand_literal,
const Literal& random_literal,
const Shape& result_shape) {
absl::StatusOr<Literal> StochasticConvertOp(const Literal& operand_literal,
const Literal& random_literal,
const Shape& result_shape) {
std::function<ResultT(Fp, Uint)> stochastic_convert_op =
[](Fp operand, Uint random) -> ResultT {
bool is_negative = static_cast<bool>(Eigen::numext::signbit(operand));
Expand Down Expand Up @@ -3673,9 +3678,9 @@ StatusOr<Literal> StochasticConvertOp(const Literal& operand_literal,
// Converts from primitive types to native types.
template <PrimitiveType operand_type, PrimitiveType random_type,
PrimitiveType result_type>
StatusOr<Literal> StochasticConvertOp(const Literal& operand_literal,
const Literal& random_literal,
const Shape& result_shape) {
absl::StatusOr<Literal> StochasticConvertOp(const Literal& operand_literal,
const Literal& random_literal,
const Shape& result_shape) {
return StochasticConvertOp<
typename primitive_util::PrimitiveTypeToNative<operand_type>::type,
typename primitive_util::PrimitiveTypeToNative<random_type>::type,
Expand All @@ -3685,11 +3690,11 @@ StatusOr<Literal> StochasticConvertOp(const Literal& operand_literal,

// Evaluates all possible paths of converting to different integers.
template <PrimitiveType operand_type, PrimitiveType random_type>
StatusOr<Literal> StochasticConvertOp(const Literal& operand_literal,
const Literal& random_literal,
const Shape& result_shape) {
absl::StatusOr<Literal> StochasticConvertOp(const Literal& operand_literal,
const Literal& random_literal,
const Shape& result_shape) {
return primitive_util::PrimitiveTypeSwitch<StatusOr<Literal>>(
[&](auto primitive_type_constant) -> StatusOr<Literal> {
[&](auto primitive_type_constant) -> absl::StatusOr<Literal> {
if constexpr (primitive_util::IsSignedIntegralType(
primitive_type_constant)) {
return StochasticConvertOp<operand_type, random_type,
Expand All @@ -3706,11 +3711,11 @@ StatusOr<Literal> StochasticConvertOp(const Literal& operand_literal,
result_shape.element_type());
}

StatusOr<Literal> StochasticConvertOp(const Literal& operand_literal,
const Literal& random_literal,
const Shape& result_shape) {
absl::StatusOr<Literal> StochasticConvertOp(const Literal& operand_literal,
const Literal& random_literal,
const Shape& result_shape) {
return primitive_util::PrimitiveTypeSwitch<StatusOr<Literal>>(
[&](auto primitive_type_constant) -> StatusOr<Literal> {
[&](auto primitive_type_constant) -> absl::StatusOr<Literal> {
if constexpr (primitive_util::IsFloatingPointType(
primitive_type_constant)) {
return StochasticConvertOp<
Expand Down Expand Up @@ -3925,9 +3930,9 @@ Status HloEvaluator::HandleSort(const HloInstruction* sort) {
<< " accessing increment of size " << increment.size();
increment[sort_dim] = sort_dim_elements;

auto comparator = [sort](absl::Span<const Literal> literals_to_sort,
int64_t a, int64_t b,
HloEvaluator* embedded_evaluator) -> StatusOr<bool> {
auto comparator =
[sort](absl::Span<const Literal> literals_to_sort, int64_t a, int64_t b,
HloEvaluator* embedded_evaluator) -> absl::StatusOr<bool> {
absl::InlinedVector<Literal, 8> literals;
literals.reserve(2 * sort->operand_count());
for (int64_t i = 0; i < sort->operand_count(); ++i) {
Expand All @@ -3948,10 +3953,10 @@ Status HloEvaluator::HandleSort(const HloInstruction* sort) {
embedded_evaluator->ResetVisitStates();
return computed_result.Get<bool>({});
};
auto less_than = [&comparator](
absl::Span<const Literal> literals_to_sort, int64_t a,
int64_t b,
HloEvaluator* embedded_evaluator) -> StatusOr<bool> {
auto less_than =
[&comparator](absl::Span<const Literal> literals_to_sort, int64_t a,
int64_t b,
HloEvaluator* embedded_evaluator) -> absl::StatusOr<bool> {
TF_ASSIGN_OR_RETURN(bool a_is_smaller,
comparator(literals_to_sort, a, b, embedded_evaluator));
#ifndef NDEBUG
Expand Down Expand Up @@ -4101,7 +4106,7 @@ Status HloEvaluator::HandleSort(const HloInstruction* sort) {
// Iterate through each dimension except 'sort_dim'.
TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
key_shape, zero_base, key_shape.dimensions(), increment,
[&](absl::Span<const int64_t> indices) -> StatusOr<bool> {
[&](absl::Span<const int64_t> indices) -> absl::StatusOr<bool> {
// Extract a slice from each operand literal that corresponds to
// exactly the row in dimension 'sort_dim'.
std::vector<int64_t> limit_indices(indices.begin(), indices.end());
Expand Down Expand Up @@ -4186,7 +4191,7 @@ static bool IsScalarAdd(HloComputation* computation) {
// the user-provided computation on the accumulator and the output element
// (until the reduction is completed, the output element is also used as
// an accumulator).
static StatusOr<bool> PerformReductionStep(
static absl::StatusOr<bool> PerformReductionStep(
bool is_tuple, absl::Span<const int64_t> input_index,
absl::Span<const int64_t> output_index,
absl::Span<const Literal* const> input_args, absl::Span<Literal> results,
Expand Down Expand Up @@ -4236,7 +4241,7 @@ static StatusOr<bool> PerformReductionStep(
return true;
}

static StatusOr<bool> GenerateReduceOutputElement(
static absl::StatusOr<bool> GenerateReduceOutputElement(
bool is_tuple, absl::Span<const int64_t> output_index,

absl::Span<const Literal* const> init_values,
Expand Down

0 comments on commit 83870a0

Please sign in to comment.