Skip to content

Commit

Permalink
[xla:gpu] Include xla_gpu_filter_kernels_spilling_registers_on_autotu…
Browse files Browse the repository at this point in the history
…ning in the CompilationCacheKey

This fixes a failure in TritonAutotunerTest.DoNotFilterOutAutotuningKernelSpillingRegisters
uncovered by 9a986d6.

PiperOrigin-RevId: 609309769
  • Loading branch information
superbobry authored and tensorflower-gardener committed Feb 22, 2024
1 parent 44b338f commit 4042539
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 5 deletions.
5 changes: 4 additions & 1 deletion third_party/xla/xla/service/gpu/nvptx_compiler.cc
Expand Up @@ -638,10 +638,13 @@ NVPTXCompiler::CompileGpuAsmOrGetCachedResult(
tsl::profiler::TraceMeLevel::kInfo);
CompilationCacheValue* cache_value = nullptr;
bool inserted = [&] {
auto flags = CompilationCacheFlags{
hlo_module_config.debug_options()
.xla_gpu_filter_kernels_spilling_registers_on_autotuning()};
absl::MutexLock lock(&mutex_);
auto [iter, inserted] = compilation_cache_.emplace(
std::piecewise_construct,
std::forward_as_tuple(ptx, cc.major, cc.minor, relocatable),
std::forward_as_tuple(ptx, cc.major, cc.minor, relocatable, flags),
std::forward_as_tuple());
// Do not move this assignment outside of the critical section. There is
// a TOCTOU if `compilation_cache_` is rehashed before the iterator is used.
Expand Down
31 changes: 27 additions & 4 deletions third_party/xla/xla/service/gpu/nvptx_compiler.h
Expand Up @@ -102,6 +102,22 @@ class NVPTXCompiler : public GpuCompiler {
const HloModuleConfig& hlo_module_config, absl::string_view module_name,
bool relocatable, const CompileOptions& options);

// Flags that take part in the compilation cache key in addition to the PTX
// text and target version. Two otherwise identical keys that differ in any
// flag here are treated as distinct cache entries.
struct CompilationCacheFlags {
  // Hashing hook: folds the flag into the hash state `h` and returns the
  // updated state.
  template <typename H>
  friend H AbslHashValue(H h, const CompilationCacheFlags& flags) {
    return H::combine(std::move(h),
                      flags.filter_kernels_spilling_registers_on_autotuning);
  }

  // Two flag sets are equal when every flag matches.
  friend bool operator==(const CompilationCacheFlags& a,
                         const CompilationCacheFlags& b) {
    return a.filter_kernels_spilling_registers_on_autotuning ==
           b.filter_kernels_spilling_registers_on_autotuning;
  }

  // Defaults to false so that a default-initialized instance never exposes an
  // indeterminate value to the hash or equality functions above. Aggregate
  // initialization (`CompilationCacheFlags{value}`) still works.
  bool filter_kernels_spilling_registers_on_autotuning = false;
};

// The compilation_cache_ map is a cache from {ptx string, cc_major, cc_minor,
// relocatable, flags}
// -> cubin so we don't recompile the same ptx twice. This is important for
// some interactive workflows. (We also cache at the HLO level, but sometimes
Expand All @@ -116,26 +132,33 @@ class NVPTXCompiler : public GpuCompiler {
// and leave compilation up to the driver.
struct CompilationCacheKey {
// `ptx` is taken by value and moved into the member, so the caller's
// argument is copied or moved once at the call site and never copied again
// here.
CompilationCacheKey(std::string ptx, int cc_major, int cc_minor,
bool relocatable)
bool relocatable, CompilationCacheFlags flags)
: ptx(std::move(ptx)),
cc_major(cc_major),
cc_minor(cc_minor),
relocatable(relocatable) {}
relocatable(relocatable),
flags(std::move(flags)) {}

// Hashing hook: folds the key's fields into the hash state `h`. The set of
// fields hashed here should match the set compared in operator== below, so
// that keys which compare equal also hash equally.
template <typename H>
friend H AbslHashValue(H h, const CompilationCacheKey& key) {
return H::combine(std::move(h), key.ptx, key.cc_major, key.cc_minor,
key.relocatable);
key.relocatable, key.flags);
}

// Field-wise equality. The integer fields are checked before the `ptx`
// string, so a mismatch in them short-circuits the string comparison.
friend bool operator==(const CompilationCacheKey& a,
const CompilationCacheKey& b) {
return a.cc_major == b.cc_major && a.cc_minor == b.cc_minor &&
a.ptx == b.ptx && a.relocatable == b.relocatable;
a.ptx == b.ptx && a.relocatable == b.relocatable &&
a.flags == b.flags;
}

// Fields that together identify one cache entry.
std::string ptx;
int cc_major;
int cc_minor;
bool relocatable;
CompilationCacheFlags flags;
};

struct CompilationCacheValue {
bool compilation_done = false;
absl::StatusOr<std::vector<uint8_t>> maybe_cubin;
Expand Down

0 comments on commit 4042539

Please sign in to comment.