Skip to content

Commit

Permalink
[xla][gpu] Implement pipelined-p2p-rewriter.
Browse files Browse the repository at this point in the history
This pass rewrite pipelined point-to-point communication by rotating the
SendDone and RecvDone operations in a while-body to the beginning of the next iteration.
The SendDone and RecvDone operations for the last iteration are moved to the
while-op calling computation, after the while-op.

Add the pass to the GPU post-scheduler pipeline.

This is another approach to achieve the code pattern to pipeline two Send-Recv
chains decomposed from a collective-permute with a source-target pair cycle for
performance. The pipelined Send-Recv pattern puts SendDone and RecvDone before
Send and Recv in the while-body, and if we generate such code pattern too early
in the GPU compilation pipeline, copy-insertion may generate copies of Send
causing Send and SendDone with different buffers and thus correctness problem.

PiperOrigin-RevId: 621317739
  • Loading branch information
bixia1 authored and tensorflower-gardener committed May 2, 2024
1 parent 62baf6a commit 7b0010b
Show file tree
Hide file tree
Showing 9 changed files with 1,621 additions and 9 deletions.
45 changes: 45 additions & 0 deletions third_party/xla/xla/service/gpu/BUILD
Expand Up @@ -3399,6 +3399,7 @@ cc_library(
]),
deps = if_gpu_is_configured([
":gpu_p2p_pipeliner",
":pipelined_p2p_rewriter",
":collective_permute_cycle_decomposer",
":address_computation_fusion_rewriter",
":algorithm_checker",
Expand Down Expand Up @@ -3648,6 +3649,7 @@ xla_test(
"//xla/service:pattern_matcher_gmock",
"//xla/service:xla_debug_info_manager",
"//xla/service/gpu:autotuner_util",
"//xla/tests:filecheck",
"//xla/tests:hlo_test_base",
"//xla/tests:xla_internal_test_main",
"@com_google_absl//absl/log",
Expand Down Expand Up @@ -3945,12 +3947,15 @@ cc_library(
"gpu_algebraic_simplifier.h",
],
deps = [
":triton_support",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/service:algebraic_simplifier",
"//xla/service:hlo_pass",
"//xla/stream_executor:device_description",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
],
Expand All @@ -3963,6 +3968,7 @@ xla_cc_test(
":gpu_algebraic_simplifier",
"//xla/hlo/ir:hlo",
"//xla/service:algebraic_simplifier",
"//xla/stream_executor:device_description",
"//xla/tests:hlo_test_base",
"//xla/tests:xla_internal_test_main",
"@com_google_googletest//:gtest",
Expand Down Expand Up @@ -6138,3 +6144,42 @@ xla_test(
"@com_google_googletest//:gtest_main",
],
)

cc_library(
name = "pipelined_p2p_rewriter",
srcs = ["pipelined_p2p_rewriter.cc"],
hdrs = ["pipelined_p2p_rewriter.h"],
deps = [
"//xla:shape_util",
"//xla:status",
"//xla:util",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_query",
"//xla/service:collective_ops_utils",
"//xla/service:hlo_pass",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:statusor",
],
)

xla_cc_test(
name = "pipelined_p2p_rewriter_test",
srcs = ["pipelined_p2p_rewriter_test.cc"],
deps = [
":pipelined_p2p_rewriter",
"//xla/hlo/ir:hlo",
"//xla/tests:filecheck",
"//xla/tests:hlo_test_base",
"@com_google_absl//absl/strings:string_view",
"@com_google_googletest//:gtest",
"@local_tsl//tsl/platform:statusor",
"@local_tsl//tsl/platform:test_main",
],
)
53 changes: 52 additions & 1 deletion third_party/xla/xla/service/gpu/gpu_algebraic_simplifier.cc
Expand Up @@ -15,13 +15,58 @@ limitations under the License.

#include "xla/service/gpu/gpu_algebraic_simplifier.h"

#include <variant>

#include "absl/log/check.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/service/gpu/triton_support.h"
#include "xla/stream_executor/device_description.h"
#include "xla/xla_data.pb.h"

namespace xla::gpu {

bool IsDotSupportedByGemmFusion(const HloInstruction* dot,
se::GpuComputeCapability compute_capability) {
auto supported_output_type = [&](const PrimitiveType t) {
auto cuda_compute_capability =
std::get_if<se::CudaComputeCapability>(&compute_capability);
auto rocm_compute_capability =
std::get_if<se::RocmComputeCapability>(&compute_capability);

CHECK(cuda_compute_capability || rocm_compute_capability);

switch (t) {
case F16:
case F32:
return true;
case BF16:
if (cuda_compute_capability) {
return true;
}
if (rocm_compute_capability) {
return rocm_compute_capability->has_bf16_dtype_support();
}
return false;
default:
return false;
}
};

if (!supported_output_type(dot->shape().element_type())) {
return false;
}

if (!IsTritonSupportedDataType(dot->operand(0)->shape().element_type(),
compute_capability) ||
!IsTritonSupportedDataType(dot->operand(1)->shape().element_type(),
compute_capability)) {
return false;
}
return true;
}

bool GpuAlgebraicSimplifierVisitor::ShouldStrengthReduceDotToReduce(
const HloInstruction* hlo) {
if (!options_.enable_dot_strength_reduction()) {
Expand All @@ -44,7 +89,13 @@ bool GpuAlgebraicSimplifierVisitor::ShouldStrengthReduceDotToReduce(
rhs->shape().rank());
// Strength-reduce vector-vector dots since they are not supported by
// GemmFusion.
return lhs_is_vector && rhs_is_vector;
if (lhs_is_vector && rhs_is_vector) {
return true;
}

// If GemmFusion cannot handle this dot, we should strength-reduce it so that
// it can be handled by the fusion pipeline.
return !IsDotSupportedByGemmFusion(dot, compute_capability_);
}

} // namespace xla::gpu
21 changes: 17 additions & 4 deletions third_party/xla/xla/service/gpu/gpu_algebraic_simplifier.h
Expand Up @@ -16,12 +16,15 @@ limitations under the License.
#ifndef XLA_SERVICE_GPU_GPU_ALGEBRAIC_SIMPLIFIER_H_
#define XLA_SERVICE_GPU_GPU_ALGEBRAIC_SIMPLIFIER_H_

#include <utility>

#include "absl/container/flat_hash_set.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/algebraic_simplifier.h"
#include "xla/service/hlo_pass_interface.h"
#include "xla/stream_executor/device_description.h"
#include "xla/util.h"

namespace xla::gpu {
Expand All @@ -30,16 +33,23 @@ class GpuAlgebraicSimplifierVisitor : public AlgebraicSimplifierVisitor {
public:
explicit GpuAlgebraicSimplifierVisitor(
const AlgebraicSimplifierOptions& options,
se::GpuComputeCapability compute_capability,
AlgebraicSimplifier* simplifier)
: AlgebraicSimplifierVisitor(options, simplifier) {}
: AlgebraicSimplifierVisitor(options, simplifier),
compute_capability_(std::move(compute_capability)) {}

bool ShouldStrengthReduceDotToReduce(const HloInstruction* hlo) override;

private:
se::GpuComputeCapability compute_capability_;
};

class GpuAlgebraicSimplifier : public AlgebraicSimplifier {
public:
explicit GpuAlgebraicSimplifier(const AlgebraicSimplifierOptions& options)
: AlgebraicSimplifier(options) {}
explicit GpuAlgebraicSimplifier(const AlgebraicSimplifierOptions& options,
se::GpuComputeCapability compute_capability)
: AlgebraicSimplifier(options),
compute_capability_(std::move(compute_capability)) {}

using HloPassInterface::Run;
absl::StatusOr<bool> Run(HloModule* module,
Expand All @@ -48,7 +58,7 @@ class GpuAlgebraicSimplifier : public AlgebraicSimplifier {
XLA_VLOG_LINES(
2, "GpuAlgebraicSimplifier::Run(), before:\n" + module->ToString());
bool changed = false;
GpuAlgebraicSimplifierVisitor visitor(options_, this);
GpuAlgebraicSimplifierVisitor visitor(options_, compute_capability_, this);
for (auto* comp : module->MakeNonfusionComputations(execution_threads)) {
if (visitor.Run(comp, options_, this)) {
changed = true;
Expand All @@ -58,6 +68,9 @@ class GpuAlgebraicSimplifier : public AlgebraicSimplifier {
2, "GpuAlgebraicSimplifier::Run(), after:\n" + module->ToString());
return changed;
}

private:
se::GpuComputeCapability compute_capability_;
};

} // namespace xla::gpu
Expand Down
33 changes: 29 additions & 4 deletions third_party/xla/xla/service/gpu/gpu_algebraic_simplifier_test.cc
Expand Up @@ -20,6 +20,7 @@ limitations under the License.
#include <gtest/gtest.h>
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/algebraic_simplifier.h"
#include "xla/stream_executor/device_description.h"
#include "xla/tests/hlo_test_base.h"
#include "tsl/platform/statusor.h"

Expand All @@ -43,8 +44,9 @@ ENTRY entry {
const HloInstruction* dot = module->entry_computation()->root_instruction();
AlgebraicSimplifierOptions options;
options.set_enable_dot_strength_reduction(true);
GpuAlgebraicSimplifier simplifier(options);
GpuAlgebraicSimplifierVisitor visitor(options, &simplifier);
se::CudaComputeCapability ampere(8, 0);
GpuAlgebraicSimplifier simplifier(options, ampere);
GpuAlgebraicSimplifierVisitor visitor(options, ampere, &simplifier);
EXPECT_TRUE(visitor.ShouldStrengthReduceDotToReduce(dot));
}

Expand All @@ -63,10 +65,33 @@ ENTRY entry {
const HloInstruction* dot = module->entry_computation()->root_instruction();
AlgebraicSimplifierOptions options;
options.set_enable_dot_strength_reduction(true);
GpuAlgebraicSimplifier simplifier(options);
GpuAlgebraicSimplifierVisitor visitor(options, &simplifier);
se::CudaComputeCapability ampere(8, 0);
GpuAlgebraicSimplifier simplifier(options, ampere);
GpuAlgebraicSimplifierVisitor visitor(options, ampere, &simplifier);
EXPECT_FALSE(visitor.ShouldStrengthReduceDotToReduce(dot));
}

TEST_F(GpuAlgebraicSimplifierTest,
DotWithTypeUnsupportedByGemmFusionShouldBeStrengthReduced) {
const std::string& hlo_string = R"(
HloModule m
ENTRY entry {
p0 = c64[32, 5, 7] parameter(0)
p1 = c64[32, 5] parameter(1)
ROOT dot = c64[32,7] dot(p0, p1), lhs_batch_dims={0},
lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
const HloInstruction* dot = module->entry_computation()->root_instruction();
AlgebraicSimplifierOptions options;
options.set_enable_dot_strength_reduction(true);
se::CudaComputeCapability ampere(8, 0);
GpuAlgebraicSimplifier simplifier(options, ampere);
GpuAlgebraicSimplifierVisitor visitor(options, ampere, &simplifier);
EXPECT_TRUE(visitor.ShouldStrengthReduceDotToReduce(dot));
}

} // namespace
} // namespace xla::gpu
7 changes: 7 additions & 0 deletions third_party/xla/xla/service/gpu/gpu_compiler.cc
Expand Up @@ -157,6 +157,7 @@ limitations under the License.
#include "xla/service/gpu/model/gpu_cost_model_stats_collection.h"
#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
#include "xla/service/gpu/move_copy_to_users.h"
#include "xla/service/gpu/pipelined_p2p_rewriter.h"
#include "xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.h"
#include "xla/service/gpu/reduction_degenerate_dim_remover.h"
#include "xla/service/gpu/reduction_dimension_grouper.h"
Expand Down Expand Up @@ -2199,6 +2200,12 @@ absl::Status GpuCompiler::RunPostSchedulingPipelines(
{
HloPassPipeline pipeline("post-scheduling-passes");

if (module->config()
.debug_options()
.xla_gpu_enable_pipelined_collectives() ||
module->config().debug_options().xla_gpu_enable_pipelined_p2p()) {
pipeline.AddPass<PipelinedP2PRewriter>();
}
HloPredicate is_nop =
HloPredicateIsOp<HloOpcode::kParameter, HloOpcode::kConstant,
HloOpcode::kBitcast, HloOpcode::kGetTupleElement>;
Expand Down

0 comments on commit 7b0010b

Please sign in to comment.