Skip to content

Commit

Permalink
Make LazyGraphExecutor extensible (pytorch#87218)
Browse files Browse the repository at this point in the history
Add `LazyGraphExecutor` to backend interface so that it is extensible by a vendor backend.

I've made some preliminary methods virtual. Not sure if we want to make all methods in `LazyGraphExecutor` virtual.

Pull Request resolved: pytorch#87218
Approved by: https://github.com/wconstab, https://github.com/alanwaketan
  • Loading branch information
antoniojkim authored and sgrigory committed Oct 28, 2022
1 parent 122d8f1 commit a2f855a
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .github/ci_commit_pins/xla.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
e1f5a49664b904e3ec1ddb9095ca75b6bbb5c10d
eff277e81fcfdeccba71e75ff40b6e2f3e29e27b
5 changes: 0 additions & 5 deletions torch/csrc/lazy/backend/backend_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,6 @@ const BackendImplInterface* getBackend() {
return interface;
}

// default implementation
// NOTE(review): this default is removed by this commit — the logic moves to
// LazyGraphExecutor::ShouldSyncTensor so vendor backends can override it there.
bool BackendImplInterface::ShouldSyncTensor(const LazyTensorPtr tensor) const {
  // A tensor is synced only when its IR op is not ltc_not_supported;
  // presumably that op marks values the backend cannot lower — TODO confirm.
  return tensor->GetIrValue()->op() != ltc_not_supported;
}

BackendRegistrar::BackendRegistrar(
const BackendImplInterface* backend_impl_interface) {
backend_impl_registry.store(backend_impl_interface);
Expand Down
3 changes: 1 addition & 2 deletions torch/csrc/lazy/backend/backend_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <torch/csrc/lazy/backend/backend_data.h>
#include <torch/csrc/lazy/backend/backend_device.h>
#include <torch/csrc/lazy/backend/lowering_context.h>
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
#include <torch/csrc/lazy/core/shape.h>
#include <torch/csrc/lazy/core/tensor.h>
#include <atomic>
Expand Down Expand Up @@ -41,8 +42,6 @@ class TORCH_API BackendImplInterface {

virtual const IrBuilder* GetIrBuilder() const = 0;

virtual bool ShouldSyncTensor(const LazyTensorPtr tensor) const;

/**
* Data Transfer
* */
Expand Down
13 changes: 11 additions & 2 deletions torch/csrc/lazy/core/lazy_graph_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -390,10 +390,15 @@ bool TensorsHaveIR(const std::vector<LazyTensorPtr>& tensors) {
return false;
}

std::atomic<LazyGraphExecutor*> lazy_graph_executor_registry;
} // namespace

// Install the process-wide LazyGraphExecutor instance returned by Get().
// Allows a vendor backend to substitute its own executor subclass at
// backend-initialization time.
void LazyGraphExecutor::Register(LazyGraphExecutor* executor) {
  lazy_graph_executor_registry.store(executor);
}

// Return the registered executor. Fails loudly (TORCH_CHECK) if no backend
// has called Register() yet, instead of silently creating a default one.
// NOTE(review): the stale pre-registry line
//   static LazyGraphExecutor* executor = new LazyGraphExecutor();
// was dropped here — it redeclared `executor` (compile error) and bypassed
// the registry that this commit introduces.
LazyGraphExecutor* LazyGraphExecutor::Get() {
  auto* executor = lazy_graph_executor_registry.load();
  TORCH_CHECK(executor, "Lazy graph executor not registered.");
  return executor;
}

Expand Down Expand Up @@ -604,6 +609,10 @@ void LazyGraphExecutor::Async::Wait() {
}
}

// Decide whether `tensor` should participate in a sync: only tensors whose
// current IR op is something other than ltc_not_supported are synced.
// Virtual so a vendor backend's executor subclass can refine the policy.
bool LazyGraphExecutor::ShouldSyncTensor(const LazyTensorPtr tensor) const {
  const auto op = tensor->GetIrValue()->op();
  return op != ltc_not_supported;
}

LazyGraphExecutor::SyncTensorCollection LazyGraphExecutor::CollectSyncTensors(
const std::vector<LazyTensorPtr>& tensors,
const SyncTensorsConfig& config) {
Expand Down Expand Up @@ -635,7 +644,7 @@ LazyGraphExecutor::SyncTensorCollection LazyGraphExecutor::CollectSyncTensors(
tensors[i]->CurrentDataHandle() == nullptr) {
Value ir_value = tensors[i]->CurrentIrValue();
if (ir_value) {
if (getBackend()->ShouldSyncTensor(tensors[i])) {
if (ShouldSyncTensor(tensors[i])) {
// Add only tensors which need to be synced.
coll.hash = HashCombine(coll.hash, ir_value.hash());
coll.indices.push_back(i);
Expand Down
14 changes: 12 additions & 2 deletions torch/csrc/lazy/core/lazy_graph_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,18 @@ class TORCH_API LazyGraphExecutor {
bool read_only = false;
};

// Register a lazy graph executor instance that can be retrieved using Get()
static void Register(LazyGraphExecutor*);
static LazyGraphExecutor* Get();

void RegisterTensor(std::shared_ptr<LazyTensor::Data> data);
void UnregisterTensor(LazyTensor::Data* data);
virtual ~LazyGraphExecutor() = default;

  // Override these methods to perform custom tensor registration and
  // unregistration.
  // Note: It is vital that the parent implementations are also called,
  // in order for the tensors to show up in the live tensor list.
virtual void RegisterTensor(std::shared_ptr<LazyTensor::Data> data);
virtual void UnregisterTensor(LazyTensor::Data* data);

// Seed for random generator
Value GetRngSeed(const BackendDevice& device);
Expand Down Expand Up @@ -181,6 +189,8 @@ class TORCH_API LazyGraphExecutor {
std::vector<BackendDataPtr> tensors_data;
};

virtual bool ShouldSyncTensor(const LazyTensorPtr tensor) const;

SyncTensorCollection CollectSyncTensors(
const std::vector<LazyTensorPtr>& tensors,
const SyncTensorsConfig& config);
Expand Down
4 changes: 4 additions & 0 deletions torch/csrc/lazy/ts_backend/ts_backend_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <ATen/Functions.h>
#include <torch/csrc/lazy/backend/backend_device.h>
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
#include <torch/csrc/lazy/generated/LazyNativeFunctions.h>
#include <torch/csrc/lazy/ts_backend/config.h>
#include <torch/csrc/lazy/ts_backend/ir_builder.h>
Expand Down Expand Up @@ -273,6 +274,9 @@ void InitTorchScriptBackend() {
register_ts_ltc_eager_fallback();
static std::unique_ptr<BackendRegistrar> s_registrar;
s_registrar = std::make_unique<BackendRegistrar>(GetTSBackendImpl());

static LazyGraphExecutor* executor = new LazyGraphExecutor();
LazyGraphExecutor::Register(executor);
}

} // namespace lazy
Expand Down

0 comments on commit a2f855a

Please sign in to comment.