Skip to content

Commit

Permalink
importer: fix usage after PyTorch update (#1555)
Browse files Browse the repository at this point in the history
Unless requested otherwise, PyTorch no longer installs most of the
header files under the caffe2 directory (see
pytorch/pytorch#87986).  This breaks our
importer code since we need to use the `MakeGuard()` function to execute
statements in the event of exceptions.

To fix this issue, this patch implements a rudimentary version of
PyTorch's ScopeGuard, where once the class variable goes out of scope,
it executes a predefined method.
  • Loading branch information
ashay committed Nov 4, 2022
1 parent fedf8c0 commit d99b2dd
Showing 1 changed file with 17 additions and 5 deletions.
Expand Up @@ -22,7 +22,6 @@
#include "torch-mlir-c/TorchTypes.h"

#include "ATen/native/quantized/PackedParams.h"
#include "caffe2/core/scope_guard.h"

using namespace torch_mlir;

Expand Down Expand Up @@ -153,6 +152,22 @@ class IValueImporter {
};
} // namespace

// RAII pattern to insert an operation before going out of scope.
class InserterGuard {
private:
MlirBlock _importBlock;
MlirOperation _nnModule;

public:
InserterGuard(MlirBlock importBlock, MlirOperation nnModule)
: _importBlock(importBlock), _nnModule(nnModule) {}

~InserterGuard() {
mlirBlockInsertOwnedOperationBefore(
_importBlock, mlirBlockGetTerminator(_importBlock), _nnModule);
}
};

MlirValue IValueImporter::importModule(torch::jit::Module currentModule) {
// TODO: Can we do better?
MlirLocation loc = mlirLocationUnknownGet(context);
Expand All @@ -177,10 +192,7 @@ MlirValue IValueImporter::importModule(torch::jit::Module currentModule) {
MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0);
mlirRegionAppendOwnedBlock(nnModuleRegion, mlirBlockCreate(0, nullptr, nullptr));
MlirBlock nnModuleBody = mlirRegionGetFirstBlock(nnModuleRegion);
auto inserter = caffe2::MakeGuard([&]() {
mlirBlockInsertOwnedOperationBefore(
importBlock, mlirBlockGetTerminator(importBlock), nnModule);
});
InserterGuard inserterGuard(importBlock, nnModule);

if (!rootModuleName.has_value()) {
rootModuleName = moduleTypeName;
Expand Down

0 comments on commit d99b2dd

Please sign in to comment.