Skip to content

Commit

Permalink
Automated Code Change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 609329369
  • Loading branch information
tensorflower-gardener committed Feb 22, 2024
1 parent a57309f commit e717e71
Show file tree
Hide file tree
Showing 10 changed files with 33 additions and 30 deletions.
11 changes: 6 additions & 5 deletions third_party/xla/xla/backends/interpreter/compiler.cc
Expand Up @@ -61,7 +61,7 @@ namespace {

// Handles custom_call ops during evaluation by routing them through the global
// CPU registry used by other CPU-based backends.
StatusOr<Literal> HandleEvaluatorCustomCall(
absl::StatusOr<Literal> HandleEvaluatorCustomCall(
const HloInstruction* custom_call, absl::Span<const Literal*> operands) {
// Find the target C function in the global registry.
auto* registry = CustomCallTargetRegistry::Global();
Expand Down Expand Up @@ -110,15 +110,15 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) {
return pipeline.Run(hlo_module).status();
}

StatusOr<std::unique_ptr<HloModule>> InterpreterCompiler::RunHloPasses(
absl::StatusOr<std::unique_ptr<HloModule>> InterpreterCompiler::RunHloPasses(
std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* /*stream_exec*/,
const CompileOptions& /*options*/) {
VLOG(1) << "Run hlo passes on graph " << hlo_module->name();
TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get()));
return std::move(hlo_module);
}

StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
absl::StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec,
const CompileOptions& /*options*/) {
TF_RET_CHECK(stream_exec != nullptr);
Expand Down Expand Up @@ -147,7 +147,8 @@ StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
return std::move(executable);
}

StatusOr<std::vector<std::unique_ptr<Executable>>> InterpreterCompiler::Compile(
absl::StatusOr<std::vector<std::unique_ptr<Executable>>>
InterpreterCompiler::Compile(
std::unique_ptr<HloModuleGroup> module_group,
std::vector<std::vector<se::StreamExecutor*>> stream_exec,
const CompileOptions& options) {
Expand All @@ -171,7 +172,7 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> InterpreterCompiler::Compile(
return std::move(ret);
}

StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
absl::StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
InterpreterCompiler::CompileAheadOfTime(
std::unique_ptr<HloModuleGroup> module_group,
const AotCompilationOptions& aot_options) {
Expand Down
8 changes: 4 additions & 4 deletions third_party/xla/xla/backends/interpreter/compiler.h
Expand Up @@ -42,18 +42,18 @@ class InterpreterCompiler : public Compiler {
InterpreterCompiler() {}
~InterpreterCompiler() override {}

StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
absl::StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec,
const CompileOptions& options) override;
StatusOr<std::unique_ptr<Executable>> RunBackend(
absl::StatusOr<std::unique_ptr<Executable>> RunBackend(
std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec,
const CompileOptions& options) override;
StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
absl::StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
std::unique_ptr<HloModuleGroup> module_group,
std::vector<std::vector<se::StreamExecutor*>> stream_exec,
const CompileOptions& options) override;

StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
absl::StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
const AotCompilationOptions& aot_options) override;

Expand Down
2 changes: 1 addition & 1 deletion third_party/xla/xla/backends/interpreter/executable.cc
Expand Up @@ -51,7 +51,7 @@ InterpreterExecutable::InterpreterExecutable(
}
}

StatusOr<Literal> InterpreterExecutable::Evaluate(
absl::StatusOr<Literal> InterpreterExecutable::Evaluate(
const ServiceExecutableRunOptions* run_options,
const HloComputation& computation, absl::Span<const Literal> arg_literals) {
// Execute the graph using the HloEvaluator.
Expand Down
7 changes: 4 additions & 3 deletions third_party/xla/xla/backends/interpreter/executable.h
Expand Up @@ -48,9 +48,10 @@ class InterpreterExecutable : public InterpreterExecutableBase {
static int64_t ShapeSizeBytes(const Shape& shape);

protected:
StatusOr<Literal> Evaluate(const ServiceExecutableRunOptions* run_options,
const HloComputation& computation,
absl::Span<const Literal> arg_literals) override
absl::StatusOr<Literal> Evaluate(
const ServiceExecutableRunOptions* run_options,
const HloComputation& computation,
absl::Span<const Literal> arg_literals) override
ABSL_LOCKS_EXCLUDED(evaluator_lock_);

// The interpreter interprets executables with an HloEvaluator.
Expand Down
4 changes: 2 additions & 2 deletions third_party/xla/xla/backends/interpreter/executable_base.cc
Expand Up @@ -38,7 +38,7 @@ InterpreterExecutableBase::InterpreterExecutableBase(
: Executable(std::move(hlo_module), /*hlo_profile_printer_data=*/nullptr,
/*hlo_profile_index_map=*/nullptr) {}

StatusOr<ExecutionOutput> InterpreterExecutableBase::ExecuteAsyncOnStream(
absl::StatusOr<ExecutionOutput> InterpreterExecutableBase::ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
std::vector<ExecutionInput> arguments,
HloExecutionProfile* hlo_execution_profile) {
Expand Down Expand Up @@ -150,7 +150,7 @@ StatusOr<ExecutionOutput> InterpreterExecutableBase::ExecuteAsyncOnStream(
return std::move(result);
}

StatusOr<ExecutionOutput>
absl::StatusOr<ExecutionOutput>
InterpreterExecutableBase::AllocateOutputMemoryWithInputReuse(
const Shape& shape, const HloInputOutputAliasConfig& alias_config,
se::DeviceMemoryAllocator* allocator,
Expand Down
6 changes: 3 additions & 3 deletions third_party/xla/xla/backends/interpreter/executable_base.h
Expand Up @@ -37,19 +37,19 @@ class InterpreterExecutableBase : public Executable {
public:
explicit InterpreterExecutableBase(std::unique_ptr<HloModule> hlo_module);

StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
absl::StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
std::vector<ExecutionInput> arguments,
HloExecutionProfile* hlo_execution_profile) override;

protected:
virtual StatusOr<Literal> Evaluate(
virtual absl::StatusOr<Literal> Evaluate(
const ServiceExecutableRunOptions* run_options,
const HloComputation& computation,
absl::Span<const Literal> arg_literals) = 0;

private:
StatusOr<ExecutionOutput> AllocateOutputMemoryWithInputReuse(
absl::StatusOr<ExecutionOutput> AllocateOutputMemoryWithInputReuse(
const Shape& shape, const HloInputOutputAliasConfig& alias_config,
se::DeviceMemoryAllocator* allocator,
std::vector<ExecutionInput>* arguments, stream_executor::Stream* stream);
Expand Down
2 changes: 1 addition & 1 deletion third_party/xla/xla/backends/interpreter/executor.cc
Expand Up @@ -110,7 +110,7 @@ absl::Status XlaInterpreterExecutor::BlockHostUntilDone(Stream *stream) {
return AsExecutorStream(stream)->BlockUntilDone();
}

tsl::StatusOr<std::unique_ptr<DeviceDescription>>
absl::StatusOr<std::unique_ptr<DeviceDescription>>
XlaInterpreterExecutor::CreateDeviceDescription(int device_ordinal) {
internal::DeviceDescriptionBuilder builder;

Expand Down
7 changes: 4 additions & 3 deletions third_party/xla/xla/backends/interpreter/executor.h
Expand Up @@ -150,12 +150,12 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface {
return false;
}

tsl::StatusOr<std::unique_ptr<DeviceDescription>> CreateDeviceDescription()
absl::StatusOr<std::unique_ptr<DeviceDescription>> CreateDeviceDescription()
const override {
return CreateDeviceDescription(0);
}

static tsl::StatusOr<std::unique_ptr<DeviceDescription>>
static absl::StatusOr<std::unique_ptr<DeviceDescription>>
CreateDeviceDescription(int device_ordinal);

absl::Status EnablePeerAccessTo(StreamExecutorInterface *other) override {
Expand Down Expand Up @@ -184,7 +184,8 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface {

DeviceMemoryBase AllocateSingleOutput(const xla::Shape &shape);

tsl::StatusOr<DeviceMemoryBase> AllocateOutputBuffer(const xla::Shape &shape);
absl::StatusOr<DeviceMemoryBase> AllocateOutputBuffer(
const xla::Shape &shape);
};

} // namespace interpreter
Expand Down
8 changes: 4 additions & 4 deletions third_party/xla/xla/backends/interpreter/platform.cc
Expand Up @@ -41,26 +41,26 @@ int XlaInterpreterPlatform::VisibleDeviceCount() const { return 1; }

const std::string& XlaInterpreterPlatform::Name() const { return name_; }

tsl::StatusOr<std::unique_ptr<DeviceDescription>>
absl::StatusOr<std::unique_ptr<DeviceDescription>>
XlaInterpreterPlatform::DescriptionForDevice(int ordinal) const {
return XlaInterpreterExecutor::CreateDeviceDescription(ordinal);
}

tsl::StatusOr<StreamExecutor*> XlaInterpreterPlatform::ExecutorForDevice(
absl::StatusOr<StreamExecutor*> XlaInterpreterPlatform::ExecutorForDevice(
int ordinal) {
StreamExecutorConfig config;
config.ordinal = ordinal;
config.device_options = DeviceOptions::Default();
return GetExecutor(config);
}

tsl::StatusOr<StreamExecutor*> XlaInterpreterPlatform::GetExecutor(
absl::StatusOr<StreamExecutor*> XlaInterpreterPlatform::GetExecutor(
const StreamExecutorConfig& config) {
return executor_cache_.GetOrCreate(
config, [&]() { return GetUncachedExecutor(config); });
}

tsl::StatusOr<std::unique_ptr<StreamExecutor>>
absl::StatusOr<std::unique_ptr<StreamExecutor>>
XlaInterpreterPlatform::GetUncachedExecutor(
const StreamExecutorConfig& config) {
auto executor = std::make_unique<StreamExecutor>(
Expand Down
8 changes: 4 additions & 4 deletions third_party/xla/xla/backends/interpreter/platform.h
Expand Up @@ -39,15 +39,15 @@ class XlaInterpreterPlatform : public Platform {

const std::string& Name() const override;

tsl::StatusOr<std::unique_ptr<DeviceDescription>> DescriptionForDevice(
absl::StatusOr<std::unique_ptr<DeviceDescription>> DescriptionForDevice(
int ordinal) const override;

tsl::StatusOr<StreamExecutor*> ExecutorForDevice(int ordinal) override;
absl::StatusOr<StreamExecutor*> ExecutorForDevice(int ordinal) override;

tsl::StatusOr<StreamExecutor*> GetExecutor(
absl::StatusOr<StreamExecutor*> GetExecutor(
const StreamExecutorConfig& config) override;

tsl::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
absl::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
const StreamExecutorConfig& config) override;

private:
Expand Down

0 comments on commit e717e71

Please sign in to comment.