Skip to content

Commit

Permalink
Skip graph export for newly added TF MLIR functions if backend_compil…
Browse files Browse the repository at this point in the history
…er is not specified.

All the needed function defs are already in the function library if no special backend_compiler is used. This reduces the memory usage.

PiperOrigin-RevId: 628740418
  • Loading branch information
cky9301 authored and tensorflower-gardener committed Apr 28, 2024
1 parent 32520d4 commit a4136e7
Showing 1 changed file with 30 additions and 28 deletions.
58 changes: 30 additions & 28 deletions tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.cc
Expand Up @@ -59,37 +59,39 @@ absl::StatusOr<mlrt::bc::Buffer> ConvertTfMlirToBytecode(
mlrt::bc::Buffer bytecode_buffer;
TF_RETURN_IF_ERROR(ConvertTfMlirToRuntimeExecutable(
options, module,
[&bytecode_buffer, &fallback_state, &model_context, module_with_op_keys](
mlir::PassManager& pm, mlir::ModuleOp module,
const TfrtPipelineOptions& options) {
if (auto* flib_def = model_context.function_library_definition()) {
// Copy the module before exporting as exporting to graph will
// transform the MLIR to TFG dialect.
mlir::OwningOpRef<mlir::ModuleOp> copy(module.clone());
TF_RETURN_IF_ERROR(
ExportFunctionDefs(*copy, [flib_def](FunctionDef function_def) {
VLOG(1) << absl::StrCat(
"Exporting MLIR function as function_def: ",
// clang-tidy off
function_def.DebugString()
// clang-tidy on
);
[&bytecode_buffer, &fallback_state, &model_context,
backend_compiler = options.backend_compiler,
module_with_op_keys](mlir::PassManager& pm, mlir::ModuleOp module,
const TfrtPipelineOptions& options) {
if (backend_compiler) {
if (auto* flib_def = model_context.function_library_definition()) {
// Copy the module before exporting as exporting to graph will
// transform the MLIR to TFG dialect.
mlir::OwningOpRef<mlir::ModuleOp> copy(module.clone());
TF_RETURN_IF_ERROR(
ExportFunctionDefs(*copy, [flib_def](FunctionDef function_def) {
VLOG(1) << absl::StrCat(
"Exporting MLIR function as function_def: ",
// NOLINTNEXTLINE
function_def.DebugString());

// The TF MLIR compiler may change the function name. Then we
// need to retrieve the original name from the
// _original_func_name attribute.
auto iter = function_def.attr().find("_original_func_name");
if (iter != function_def.attr().end()) {
function_def.mutable_signature()->set_name(iter->second.s());
}
// The TF MLIR compiler may change the function name. Then we
// need to retrieve the original name from the
// _original_func_name attribute.
auto iter = function_def.attr().find("_original_func_name");
if (iter != function_def.attr().end()) {
function_def.mutable_signature()->set_name(
iter->second.s());
}

const auto& name = function_def.signature().name();
if (flib_def->Contains(name)) {
TF_RETURN_IF_ERROR(flib_def->RemoveFunction(name));
}
const auto& name = function_def.signature().name();
if (flib_def->Contains(name)) {
TF_RETURN_IF_ERROR(flib_def->RemoveFunction(name));
}

return flib_def->AddFunctionDef(function_def);
}));
return flib_def->AddFunctionDef(function_def);
}));
}
}

mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
Expand Down

0 comments on commit a4136e7

Please sign in to comment.