Skip to content

Commit

Permalink
Rename mlir_import_options to saved_model_import_options.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 619370979
  • Loading branch information
rocketas authored and tensorflower-gardener committed May 8, 2024
1 parent 4382439 commit 57984a0
Show file tree
Hide file tree
Showing 18 changed files with 51 additions and 49 deletions.
2 changes: 1 addition & 1 deletion tensorflow/BUILD
Expand Up @@ -1389,7 +1389,7 @@ tf_cc_shared_library(
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
"//tensorflow/compiler/mlir/tensorflow:error_util",
"//tensorflow/compiler/mlir/tensorflow:export_graphdef",
"//tensorflow/compiler/mlir/tensorflow:mlir_import_options",
"//tensorflow/compiler/mlir/tensorflow:saved_model_import_options",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
"@local_xla//xla/service:computation_placer",
Expand Down
1 change: 1 addition & 0 deletions tensorflow/compiler/mlir/BUILD
Expand Up @@ -229,6 +229,7 @@ tf_cc_binary(
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
"//tensorflow/compiler/mlir/tensorflow:translate_registration",
"//tensorflow/compiler/mlir/tensorflow:translate_tf_dialect_op",
"//tensorflow/compiler/mlir/tensorflow/translate:saved_model_import_options",
"//tensorflow/core:lib",
"//tensorflow/core:tensorflow",
"@com_google_absl//absl/strings",
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/compiler/mlir/lite/BUILD
Expand Up @@ -1445,8 +1445,8 @@ cc_library(
"//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:error_util",
"//tensorflow/compiler/mlir/tensorflow:mlir_import_options",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/compiler/mlir/tensorflow:saved_model_import_options",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib",
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
"//tensorflow/compiler/mlir/tensorflow/transforms:tf_dialect_passes",
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc
Expand Up @@ -68,8 +68,8 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/saved_model_import_options.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/core/framework/op.h"
Expand Down Expand Up @@ -551,7 +551,7 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ImportSavedModel(
if (!module_or.status().ok()) return module_or.status();
return std::move(module_or).value();
} else if (saved_model_version == 1) {
MLIRImportOptions options;
SavedModelImportOptions options;
options.upgrade_legacy = specs.upgrade_legacy;
options.unconditionally_use_set_output_shapes = true;
options.lift_variables = enable_variable_lifting;
Expand Down
1 change: 1 addition & 0 deletions tensorflow/compiler/mlir/python/BUILD
Expand Up @@ -46,6 +46,7 @@ cc_library(
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:tfe_context_internal",
"@local_xla//xla/mlir/framework/transforms:passes",
"//tensorflow/compiler/mlir/tensorflow/translate:saved_model_import_options",
"@local_xla//xla/mlir_hlo:all_passes",
"//tensorflow/compiler/mlir/lite:flatbuffer_import",
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
Expand Down
5 changes: 3 additions & 2 deletions tensorflow/compiler/mlir/python/mlir.cc
Expand Up @@ -55,6 +55,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/saved_model_import_options.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
Expand Down Expand Up @@ -275,7 +276,7 @@ std::string ExperimentalConvertSavedModelV1ToMlirLite(
mlir::func::registerAllExtensions(registry);
mlir::MLIRContext context(registry);

tensorflow::MLIRImportOptions import_options;
tensorflow::SavedModelImportOptions import_options;
import_options.upgrade_legacy = upgrade_legacy;
auto module_or = SavedModelSignatureDefsToMlirImportLite(
saved_model_path, tag_set, absl::Span<std::string>(exported_names),
Expand Down Expand Up @@ -312,7 +313,7 @@ std::string ExperimentalConvertSavedModelV1ToMlir(
mlir::DialectRegistry registry;
mlir::func::registerAllExtensions(registry);
mlir::MLIRContext context(registry);
tensorflow::MLIRImportOptions import_options;
tensorflow::SavedModelImportOptions import_options;
import_options.upgrade_legacy = upgrade_legacy;
import_options.lift_variables = lift_variables;
import_options.include_variables_in_initializers =
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD
Expand Up @@ -217,7 +217,7 @@ cc_library(
"//tensorflow/cc/saved_model:reader",
"//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc",
"//tensorflow/compiler/mlir/quantization/tensorflow:quantize_preprocess",
"//tensorflow/compiler/mlir/tensorflow:mlir_import_options",
"//tensorflow/compiler/mlir/tensorflow:saved_model_import_options",
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
"//tensorflow/core/protobuf:for_core_protos_cc",
"@com_google_absl//absl/algorithm:container",
Expand Down Expand Up @@ -417,7 +417,7 @@ cc_library(
"//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes",
"//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib",
"//tensorflow/compiler/mlir/quantization/tensorflow/python:unfreeze_constants",
"//tensorflow/compiler/mlir/tensorflow:mlir_import_options",
"//tensorflow/compiler/mlir/tensorflow:saved_model_import_options",
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:protos_all_cc",
Expand Down
Expand Up @@ -39,7 +39,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/types.h"
#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/saved_model_import_options.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tsl/platform/errors.h"
Expand All @@ -48,8 +48,8 @@ limitations under the License.
namespace mlir::quant::stablehlo {

using ::stablehlo::quantization::QuantizationConfig;
using ::tensorflow::MLIRImportOptions;
using ::tensorflow::SavedModelBundle;
using ::tensorflow::SavedModelImportOptions;
using ::tensorflow::SavedModelSignatureDefsToMlirImport;
using ::tensorflow::quantization::PreprocessAndFreezeGraph;

Expand All @@ -58,7 +58,7 @@ absl::StatusOr<ImportedMlirModuleOp> SavedModelToMlirModuleOp(
const std::unordered_set<std::string>& tags,
const std::vector<std::string>& signature_keys,
MLIRContext& ctx ABSL_ATTRIBUTE_LIFETIME_BOUND) {
MLIRImportOptions import_options;
SavedModelImportOptions import_options;
import_options.upgrade_legacy = true;
import_options.lift_variables = false;
import_options.include_variables_in_initializers = true;
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD
Expand Up @@ -56,8 +56,8 @@ cc_library(
"//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:custom_aggregator_op", # Required for CustomAggregator op registration.
"//tensorflow/compiler/mlir/quantization/tensorflow/cc:convert_asset_args",
"//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes",
"//tensorflow/compiler/mlir/quantization/tensorflow/debugging:dump_tensor_op", # Required for DumpTensor op registration.
"//tensorflow/compiler/mlir/tensorflow:mlir_import_options",
"//tensorflow/compiler/mlir/quantization/tensorflow/debugging:dump_tensor_op",
"//tensorflow/compiler/mlir/tensorflow:saved_model_import_options",
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
"//tensorflow/compiler/mlir/tensorflow/transforms:tf_dialect_passes",
"//tensorflow/core:protos_all_cc",
Expand Down
Expand Up @@ -56,7 +56,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/saved_model_import_options.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
Expand Down Expand Up @@ -98,7 +98,7 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ImportAndPreprocessSavedModel(
const bool deserialize_xla_call_module,
absl::flat_hash_map<std::string, std::string> &function_aliases) {
// Convert the SavedModelBundle to an MLIR module.
MLIRImportOptions import_options;
SavedModelImportOptions import_options;
import_options.upgrade_legacy = true;
import_options.lift_variables = false;
import_options.include_variables_in_initializers = true;
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/compiler/mlir/tensorflow/BUILD
Expand Up @@ -1668,7 +1668,7 @@ aliased_targets = [
"mlir_roundtrip_pass",
"mlir_roundtrip_pass_registration",
"mlir_roundtrip_flags",
"mlir_import_options",
"saved_model_import_options",
"translate_lib",
"translate_cl_options",
"translate_registration",
Expand Down
8 changes: 4 additions & 4 deletions tensorflow/compiler/mlir/tensorflow/translate/BUILD
Expand Up @@ -34,7 +34,7 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils",
"//tensorflow/compiler/mlir/tensorflow:error_util",
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
"//tensorflow/compiler/mlir/tensorflow:mlir_import_options",
"//tensorflow/compiler/mlir/tensorflow:saved_model_import_options",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/compiler/mlir/tensorflow:translate_utils",
Expand Down Expand Up @@ -175,8 +175,8 @@ cc_library(
)

cc_library(
name = "mlir_import_options",
hdrs = ["mlir_import_options.h"],
name = "saved_model_import_options",
hdrs = ["saved_model_import_options.h"],
visibility = ["//visibility:public"],
)

Expand All @@ -194,7 +194,7 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:error_util",
"//tensorflow/compiler/mlir/tensorflow:import_utils",
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
"//tensorflow/compiler/mlir/tensorflow:mlir_import_options",
"//tensorflow/compiler/mlir/tensorflow:saved_model_import_options",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:lib_proto_parsing",
Expand Down
18 changes: 9 additions & 9 deletions tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
Expand Up @@ -84,8 +84,8 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/transforms/mark_initialized_variables.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/saved_model_import_options.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
Expand Down Expand Up @@ -2791,7 +2791,7 @@ class SavedModelObjectGraphImporter : public ImporterBase {
// Module.
static absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> Convert(
SavedModelV2Bundle* saved_model, absl::Span<std::string> exported_names,
mlir::MLIRContext* context, MLIRImportOptions options);
mlir::MLIRContext* context, SavedModelImportOptions options);

private:
explicit SavedModelObjectGraphImporter(
Expand Down Expand Up @@ -3345,7 +3345,7 @@ Status CreateSavedModelIR(
const ObjectNames& object_names, mlir::ModuleOp module,
const SavedObjectGraph& object_graph,
const std::unordered_map<std::string, std::string>& tf_name_to_mlir_name,
SavedModelV2Bundle* saved_model, MLIRImportOptions import_options) {
SavedModelV2Bundle* saved_model, SavedModelImportOptions import_options) {
mlir::OpBuilder builder(module.getBodyRegion());
mlir::SymbolTable symbol_table(module);

Expand Down Expand Up @@ -3562,7 +3562,7 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
SavedModelObjectGraphImporter::Convert(SavedModelV2Bundle* saved_model,
absl::Span<std::string> exported_names,
mlir::MLIRContext* context,
MLIRImportOptions import_options) {
SavedModelImportOptions import_options) {
LoadImporterDialects(*context);
GraphDebugInfo dummy_debug_info;
const GraphDebugInfo& debug_info =
Expand Down Expand Up @@ -3641,7 +3641,7 @@ SavedModelObjectGraphImporter::Convert(SavedModelV2Bundle* saved_model,
class SimpleSavedModelMLIRImportInput : public SavedModelMLIRImportInput {
public:
static absl::StatusOr<SimpleSavedModelMLIRImportInput> Create(
const MLIRImportOptions& import_options,
const SavedModelImportOptions& import_options,
const MetaGraphDef* meta_graph_def, const GraphDebugInfo& debug_info) {
DCHECK(meta_graph_def);
GraphDef graph_def(meta_graph_def->graph_def());
Expand Down Expand Up @@ -4183,7 +4183,7 @@ class SavedModelSignatureDefImporter {
static absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> Convert(
const SavedModelBundle& bundle,
std::optional<absl::Span<const std::string>> exported_names,
mlir::MLIRContext* context, tensorflow::MLIRImportOptions options) {
mlir::MLIRContext* context, tensorflow::SavedModelImportOptions options) {
// debug_info might not be loaded with loader_lite.
GraphDebugInfo debug_info;
if (bundle.debug_info != nullptr) debug_info = *bundle.debug_info;
Expand Down Expand Up @@ -4348,14 +4348,14 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertFunctionToMlir(

absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertSavedModelToMlir(
SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
absl::Span<std::string> exported_names, MLIRImportOptions options) {
absl::Span<std::string> exported_names, SavedModelImportOptions options) {
return SavedModelObjectGraphImporter::Convert(saved_model, exported_names,
context, options);
}

absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertSavedModelV1ToMlir(
const SavedModelBundle& saved_model, absl::Span<std::string> exported_names,
mlir::MLIRContext* context, MLIRImportOptions options) {
mlir::MLIRContext* context, SavedModelImportOptions options) {
std::optional<absl::Span<const std::string>> optional_exported_names;
// TODO(b/187062560): Change ConvertSavedModelV1ToMlir() to take an optional
// `exported_names` so that it can be configured to import only restore/init
Expand All @@ -4368,7 +4368,7 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertSavedModelV1ToMlir(
absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertSavedModelV1ToMlirLite(
const MetaGraphDef& meta_graph_def, const GraphDebugInfo& debug_info,
std::optional<absl::Span<const std::string>> exported_names,
mlir::MLIRContext* context, MLIRImportOptions options) {
mlir::MLIRContext* context, SavedModelImportOptions options) {
TF_ASSIGN_OR_RETURN(auto input, SimpleSavedModelMLIRImportInput::Create(
options, &meta_graph_def, debug_info));
return ConvertSavedModelV1ToMlirLite(
Expand Down
9 changes: 5 additions & 4 deletions tensorflow/compiler/mlir/tensorflow/translate/import_model.h
Expand Up @@ -26,8 +26,8 @@ limitations under the License.
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/cc/saved_model/bundle_v2.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/saved_model_import_options.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_debug_info.pb.h"
Expand Down Expand Up @@ -61,13 +61,14 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertFunctionToMlir(
// with tf_executor dialect.
absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertSavedModelToMlir(
SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
absl::Span<std::string> exported_names, MLIRImportOptions options = {});
absl::Span<std::string> exported_names,
SavedModelImportOptions options = {});

// Given a V1 SavedModel, returns a MLIR module containing the functions,
// expressed with tf_executor dialect.
absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertSavedModelV1ToMlir(
const SavedModelBundle& saved_model, absl::Span<std::string> exported_names,
mlir::MLIRContext* context, MLIRImportOptions options = {});
mlir::MLIRContext* context, SavedModelImportOptions options = {});

// Given a V1 SavedModel, returns a MLIR module containing the functions,
// expressed with tf_executor dialect. It does not require a session to be
Expand All @@ -82,7 +83,7 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertSavedModelV1ToMlir(
absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertSavedModelV1ToMlirLite(
const MetaGraphDef& meta_graph_def, const GraphDebugInfo& debug_info,
std::optional<absl::Span<const std::string>> exported_names,
mlir::MLIRContext* context, MLIRImportOptions options);
mlir::MLIRContext* context, SavedModelImportOptions options);

// SavedModelMLIRImportInput is an adapter class for users to inject custom
// graph transformation logic on Tensorflow graphs before importing to MLIR. It
Expand Down
Expand Up @@ -13,16 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_IMPORT_OPTIONS_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_IMPORT_OPTIONS_H_
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_SAVED_MODEL_IMPORT_OPTIONS_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_SAVED_MODEL_IMPORT_OPTIONS_H_

namespace tensorflow {

// TODO(jpienaar): This file and class are confusingly named. This seems to be
// a SavedModel only import options file that exposes a subset of the
// GraphImportConfig options, but the naming would make one think it is more
// general.
struct MLIRImportOptions {
struct SavedModelImportOptions {
// If true, functionalize the input graph before importing it into MLIR.
bool upgrade_legacy = false;

Expand Down Expand Up @@ -53,4 +49,4 @@ struct MLIRImportOptions {

} // namespace tensorflow

#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_IMPORT_OPTIONS_H_
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_SAVED_MODEL_IMPORT_OPTIONS_H_

0 comments on commit 57984a0

Please sign in to comment.