diff options
author | Yu-Cheng Ling <ycling@google.com> | 2018-09-05 06:01:51 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-05 06:06:23 -0700 |
commit | 858f4672e25825bc5e091a79fd4234f1968a278d (patch) | |
tree | c2f41b48ff253d104e37b4a8ef52614f239669f2 /tensorflow/contrib/lite/toco | |
parent | f15e8613aa42f7f2b1439c652a465438553df219 (diff) |
Minimum change for generating Eager ops with Toco.
PiperOrigin-RevId: 211621189
Diffstat (limited to 'tensorflow/contrib/lite/toco')
-rw-r--r-- | tensorflow/contrib/lite/toco/args.h | 4 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/import_tensorflow.cc | 10 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/import_tensorflow.h | 5 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/tflite/export.cc | 52 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/tflite/export.h | 51 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/tflite/export_test.cc | 9 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/tflite/operator.cc | 39 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/tflite/operator.h | 8 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/toco_cmdline_flags.cc | 18 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/toco_flags.proto | 15 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/toco_tooling.cc | 24 |
11 files changed, 183 insertions, 52 deletions
diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h index 84f71dc7a7..f14dbc258b 100644 --- a/tensorflow/contrib/lite/toco/args.h +++ b/tensorflow/contrib/lite/toco/args.h @@ -247,6 +247,10 @@ struct ParsedTocoFlags { Arg<bool> allow_nudging_weights_to_use_fast_gemm_kernel = Arg<bool>(false); Arg<int64> dedupe_array_min_size_bytes = Arg<int64>(64); Arg<bool> split_tflite_lstm_inputs = Arg<bool>(true); + // WARNING: Experimental interface, subject to change + Arg<bool> allow_eager_ops = Arg<bool>(false); + // WARNING: Experimental interface, subject to change + Arg<bool> force_eager_ops = Arg<bool>(false); }; } // namespace toco diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index cb6da21039..9bc23c4b3c 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -2061,8 +2061,14 @@ std::unique_ptr<Model> ImportTensorFlowGraphDef( } Model* model = new Model; - const internal::ConverterMapType& converter_map = - internal::GetTensorFlowNodeConverterMap(); + internal::ConverterMapType converter_map; + + // This is used for the TFLite "Full Eager Mode" conversion. All the ops are + // imported as `TensorFlowUnsupportedOperator`, and later all these ops are + // converted to TFLite Eager ops. + if (!tf_import_flags.import_all_ops_as_unsupported) { + converter_map = internal::GetTensorFlowNodeConverterMap(); + } for (auto node : inlined_graph.node()) { StripZeroOutputIndexFromInputs(&node); diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.h b/tensorflow/contrib/lite/toco/import_tensorflow.h index 2177872334..7db23f2d44 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.h +++ b/tensorflow/contrib/lite/toco/import_tensorflow.h @@ -27,6 +27,11 @@ struct TensorFlowImportFlags { // If true, control dependencies will be dropped immediately // during the import of the TensorFlow GraphDef. bool drop_control_dependency = false; + + // Do not recognize any op and import all ops as + // `TensorFlowUnsupportedOperator`. This is used to populated with the + // `force_eager_ops` flag. + bool import_all_ops_as_unsupported = false; }; std::unique_ptr<Model> ImportTensorFlowGraphDef( diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc index c79469f59b..fee10b1dff 100644 --- a/tensorflow/contrib/lite/toco/tflite/export.cc +++ b/tensorflow/contrib/lite/toco/tflite/export.cc @@ -49,12 +49,21 @@ namespace { details::OperatorKey GetOperatorKey( const ::toco::Operator& op, - const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) { + const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type, + bool allow_eager_ops) { string custom_code; if (op.type == OperatorType::kUnsupported) { const TensorFlowUnsupportedOperator& unsupported_op = static_cast<const TensorFlowUnsupportedOperator&>(op); - custom_code = unsupported_op.tensorflow_op; + + // TODO(b/113715895): When `allow_eager_ops` is on, for now there's no way + // to populate a regular custom op. We need to find a way to fix this. + if (allow_eager_ops) { + custom_code = string(::tflite::kEagerCustomCodePrefix) + + unsupported_op.tensorflow_op; + } else { + custom_code = unsupported_op.tensorflow_op; + } } int version = 1; if (ops_by_type.count(op.type) != 0) { @@ -91,11 +100,12 @@ void LoadTensorsMap(const Model& model, TensorsMap* tensors_map) { void LoadOperatorsMap( const Model& model, OperatorsMap* operators_map, - const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) { + const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type, + bool allow_eager_ops) { // First find a list of unique operator types. std::set<OperatorKey> keys; for (const auto& op : model.operators) { - keys.insert(GetOperatorKey(*op, ops_by_type)); + keys.insert(GetOperatorKey(*op, ops_by_type, allow_eager_ops)); } // Now assign indices to them and fill in the map. int index = 0; @@ -189,7 +199,7 @@ Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes( const Model& model, const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type, const details::OperatorsMap& operators_map, FlatBufferBuilder* builder, - std::set<string>* error_summary) { + std::set<string>* error_summary, const ExportParams& params) { // Map from operator name to TF Lite enum value, for all builtins. std::map<string, BuiltinOperator> builtin_ops; for (int i = BuiltinOperator_MIN; i <= BuiltinOperator_MAX; ++i) { @@ -205,7 +215,8 @@ Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes( std::map<int, Offset<OperatorCode>> ordered_opcodes; for (const auto& op : model.operators) { - const details::OperatorKey operator_key = GetOperatorKey(*op, ops_by_type); + const details::OperatorKey operator_key = + GetOperatorKey(*op, ops_by_type, params.allow_eager_ops); int op_index = operators_map.at(operator_key); int op_version = operator_key.version; @@ -252,7 +263,7 @@ Offset<Vector<Offset<Operator>>> ExportOperators( const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type, const details::OperatorsMap& operators_map, const details::TensorsMap& tensors_map, FlatBufferBuilder* builder, - std::set<int32_t>* variable_tensor_indices) { + std::set<int32_t>* variable_tensor_indices, const ExportParams& params) { variable_tensor_indices->clear(); // The operators are in execution order, so we just follow tf.mini order. @@ -269,7 +280,8 @@ Offset<Vector<Offset<Operator>>> ExportOperators( outputs.push_back(tensors_map.at(output)); } - int op_index = operators_map.at(GetOperatorKey(*op, ops_by_type)); + int op_index = operators_map.at( + GetOperatorKey(*op, ops_by_type, params.allow_eager_ops)); auto tflite_op_it = ops_by_type.find(op->type); BaseOperator* tflite_op = tflite_op_it == ops_by_type.end() @@ -320,16 +332,15 @@ Offset<Vector<Offset<Buffer>>> ExportBuffers( return builder->CreateVector(buffer_vector); } -void Export(const Model& model, bool allow_custom_ops, bool quantize_weights, - string* output_file_contents) { - const auto ops_by_type = BuildOperatorByTypeMap(); - Export(model, allow_custom_ops, quantize_weights, output_file_contents, - ops_by_type); +void Export(const Model& model, string* output_file_contents, + const ExportParams& params) { + const auto ops_by_type = BuildOperatorByTypeMap(params.allow_eager_ops); + Export(model, output_file_contents, params, ops_by_type); } void Export( - const Model& model, bool allow_custom_ops, bool quantize_weights, - string* output_file_contents, + const Model& model, string* output_file_contents, + const ExportParams& params, const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) { flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240); @@ -337,7 +348,8 @@ void Export( details::LoadTensorsMap(model, &tensors_map); details::OperatorsMap operators_map; - details::LoadOperatorsMap(model, &operators_map, ops_by_type); + details::LoadOperatorsMap(model, &operators_map, ops_by_type, + params.allow_eager_ops); std::vector<const Array*> buffers_to_write; Array empty_array; @@ -345,7 +357,7 @@ void Export( std::set<string> error_summary; auto op_codes = ExportOperatorCodes(model, ops_by_type, operators_map, - &builder, &error_summary); + &builder, &error_summary, params); for (const auto& op : model.operators) { if (op->type == OperatorType::kFakeQuant) { @@ -355,7 +367,7 @@ void Export( "for --std_values and --mean_values."; } } - if (!allow_custom_ops && !error_summary.empty()) { + if (!params.allow_custom_ops && !error_summary.empty()) { // Remove ExpandDims and ReorderAxes from unimplemented list unless they // compose the list. Both ops are removed during graph transformations. // However, if an op is unimplemented earlier in the model, the graph @@ -383,7 +395,7 @@ void Export( std::set<int32_t> variable_tensor_indices; auto ops = ExportOperators(model, ops_by_type, operators_map, tensors_map, - &builder, &variable_tensor_indices); + &builder, &variable_tensor_indices, params); auto tensors = ExportTensors(model, tensors_map, &builder, &buffers_to_write, variable_tensor_indices); @@ -402,7 +414,7 @@ void Export( builder.CreateVector(subgraphs), description, buffers); ::tflite::FinishModelBuffer(builder, new_model_location); - if (quantize_weights) { + if (params.quantize_weights) { // Call the quantize_weights tool. LOG(INFO) << "Quantizing TFLite model after conversion to flatbuffer. " "dump_graphviz will only output the model before this " diff --git a/tensorflow/contrib/lite/toco/tflite/export.h b/tensorflow/contrib/lite/toco/tflite/export.h index 915d5dd3d6..b070a38768 100644 --- a/tensorflow/contrib/lite/toco/tflite/export.h +++ b/tensorflow/contrib/lite/toco/tflite/export.h @@ -23,22 +23,54 @@ namespace toco { namespace tflite { +// The parameters for exporting a TFLite model. +struct ExportParams { + bool allow_custom_ops = false; + bool allow_eager_ops = false; + bool quantize_weights = false; +}; + // Transform the given tf.mini model into a TF Lite flatbuffer and deposit the // result in the given string. -void Export(const Model& model, bool allow_custom_ops, bool quantize_weights, - string* output_file_contents); +void Export(const Model& model, string* output_file_contents, + const ExportParams& params); + +// Export API with custom TFLite operator mapping. +void Export( + const Model& model, string* output_file_contents, + const ExportParams& params, + const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type); -// This if backward-compatibility. +// This is for backward-compatibility. // TODO(ycling): Remove the deprecated entry functions. -inline void Export(const Model& model, string* output_file_contents) { - Export(model, true, false, output_file_contents); +inline void Export(const Model& model, bool allow_custom_ops, + bool quantize_weights, string* output_file_contents) { + ExportParams params; + params.allow_custom_ops = allow_custom_ops; + params.quantize_weights = quantize_weights; + Export(model, output_file_contents, params); } -// Export API with custom TFLite operator mapping. -void Export( +// This is for backward-compatibility. +// TODO(ycling): Remove the deprecated entry functions. +inline void Export( const Model& model, bool allow_custom_ops, bool quantize_weights, string* output_file_contents, - const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type); + const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) { + ExportParams params; + params.allow_custom_ops = allow_custom_ops; + params.quantize_weights = quantize_weights; + Export(model, output_file_contents, params, ops_by_type); +} + +// This is for backward-compatibility. +// TODO(ycling): Remove the deprecated entry functions. +inline void Export(const Model& model, string* output_file_contents) { + ExportParams params; + params.allow_custom_ops = true; + Export(model, output_file_contents, params); + Export(model, true, false, output_file_contents); +} namespace details { @@ -88,7 +120,8 @@ using OperatorsMap = std::unordered_map<OperatorKey, int, OperatorKey::Hash>; void LoadTensorsMap(const Model& model, TensorsMap* tensors_map); void LoadOperatorsMap( const Model& model, OperatorsMap* operators_map, - const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type); + const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type, + bool allow_eager_ops); } // namespace details } // namespace tflite diff --git a/tensorflow/contrib/lite/toco/tflite/export_test.cc b/tensorflow/contrib/lite/toco/tflite/export_test.cc index 4994ea30de..8d4d197c46 100644 --- a/tensorflow/contrib/lite/toco/tflite/export_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/export_test.cc @@ -105,7 +105,8 @@ TEST_F(ExportTest, LoadOperatorsMap) { details::OperatorsMap operators; const auto ops_by_type = BuildOperatorByTypeMap(); - details::LoadOperatorsMap(input_model_, &operators, ops_by_type); + // TODO(ycling): Add a test for allow_eager_ops. + details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false); EXPECT_EQ(0, operators[details::OperatorKey(OperatorType::kAdd, "", 1)]); EXPECT_EQ(1, operators[details::OperatorKey(OperatorType::kConv, "", 1)]); EXPECT_EQ(2, operators[details::OperatorKey(OperatorType::kSub, "", 1)]); @@ -253,7 +254,7 @@ TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV1) { details::OperatorsMap operators; const auto ops_by_type = BuildFakeOperatorByTypeMap(); - details::LoadOperatorsMap(input_model_, &operators, ops_by_type); + details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false); EXPECT_EQ(1, operators.size()); EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 1))); @@ -264,7 +265,7 @@ TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV2) { details::OperatorsMap operators; const auto ops_by_type = BuildFakeOperatorByTypeMap(); - details::LoadOperatorsMap(input_model_, &operators, ops_by_type); + details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false); EXPECT_EQ(1, operators.size()); EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 2))); @@ -276,7 +277,7 @@ TEST_F(VersionedOpExportTest, LoadOperatorsMapWithBothVersions) { details::OperatorsMap operators; const auto ops_by_type = BuildFakeOperatorByTypeMap(); - details::LoadOperatorsMap(input_model_, &operators, ops_by_type); + details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false); EXPECT_EQ(2, operators.size()); EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 1))); diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index a314c8d53a..eb0f7c443a 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -1149,7 +1149,9 @@ class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions, class TensorFlowUnsupported : public BaseOperator { public: - using BaseOperator::BaseOperator; + TensorFlowUnsupported(const string& name, OperatorType type, + bool allow_eager_ops) + : BaseOperator(name, type), allow_eager_ops_(allow_eager_ops) {} Options Serialize(const Operator& op, flatbuffers::FlatBufferBuilder* builder) const override { @@ -1165,6 +1167,9 @@ class TensorFlowUnsupported : public BaseOperator { std::unique_ptr<Operator> Deserialize( const BuiltinOptions* builtin_options, const CustomOptions* custom_options) const override { + // Deserializing Eager ops doesn't work now. + // TODO(ycling): Revisit and decide if we should fix the flow for importing + // TFLite models with Eager ops. auto op = absl::make_unique<TensorFlowUnsupportedOperator>(); if (custom_options) { auto flexbuffer_map = @@ -1185,6 +1190,16 @@ class TensorFlowUnsupported : public BaseOperator { return std::unique_ptr<flexbuffers::Builder>(); } + if (allow_eager_ops_) { + fbb->Vector([&]() { + fbb->String(node_def.op()); + fbb->String(op.tensorflow_node_def); + }); + fbb->Finish(); + LOG(INFO) << "Writing eager op: " << node_def.op(); + return std::unique_ptr<flexbuffers::Builder>(fbb.release()); + } + bool has_valid_attr = false; size_t map_start = fbb->StartMap(); for (const auto& pair : node_def.attr()) { @@ -1285,11 +1300,15 @@ class TensorFlowUnsupported : public BaseOperator { // custom ops. return 1; } + + private: + const bool allow_eager_ops_; }; namespace { // Build a vector containing all the known operators. -std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() { +std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList( + bool allow_eager_ops = false) { std::vector<std::unique_ptr<BaseOperator>> ops; using tensorflow::MakeUnique; // Builtin Operators. @@ -1400,8 +1419,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() { MakeUnique<DepthToSpace>("DEPTH_TO_SPACE", OperatorType::kDepthToSpace)); ops.push_back(MakeUnique<CTCBeamSearchDecoder>( "CTC_BEAM_SEARCH_DECODER", OperatorType::kCTCBeamSearchDecoder)); - ops.push_back(MakeUnique<TensorFlowUnsupported>("TENSORFLOW_UNSUPPORTED", - OperatorType::kUnsupported)); + ops.push_back(MakeUnique<TensorFlowUnsupported>( + "TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported, allow_eager_ops)); // There operators are supported by Toco, but not by TF Lite, and has no // attributes. @@ -1474,10 +1493,12 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() { } } // namespace -std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap() { +std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap( + bool allow_eager_ops) { std::map<OperatorType, std::unique_ptr<BaseOperator>> result; - std::vector<std::unique_ptr<BaseOperator>> ops = BuildOperatorList(); + std::vector<std::unique_ptr<BaseOperator>> ops = + BuildOperatorList(allow_eager_ops); for (auto& op : ops) { result[op->type()] = std::move(op); } @@ -1485,10 +1506,12 @@ std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap() { return result; } -std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap() { +std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap( + bool allow_eager_ops) { std::map<string, std::unique_ptr<BaseOperator>> result; - std::vector<std::unique_ptr<BaseOperator>> ops = BuildOperatorList(); + std::vector<std::unique_ptr<BaseOperator>> ops = + BuildOperatorList(allow_eager_ops); for (auto& op : ops) { result[op->name()] = std::move(op); } diff --git a/tensorflow/contrib/lite/toco/tflite/operator.h b/tensorflow/contrib/lite/toco/tflite/operator.h index d9ea23edf2..702fb28ea6 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.h +++ b/tensorflow/contrib/lite/toco/tflite/operator.h @@ -26,11 +26,15 @@ namespace tflite { class BaseOperator; // Return a map contained all know TF Lite Operators, keyed by their names. -std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(); +// TODO(ycling): The pattern to propagate parameters (e.g. allow_eager_ops) +// is ugly here. Consider refactoring. +std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap( + bool allow_eager_ops = false); // Return a map contained all know TF Lite Operators, keyed by the type of // their tf.mini counterparts. -std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap(); +std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap( + bool allow_eager_ops = false); // These are the flatbuffer types for custom and builtin options. using CustomOptions = flatbuffers::Vector<uint8_t>; diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc index f83a290195..b6aebc0470 100644 --- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc +++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc @@ -165,7 +165,13 @@ bool ParseTocoFlagsFromCommandLineFlags( parsed_flags.post_training_quantize.default_value(), "Boolean indicating whether to quantize the weights of the " "converted float model. Model size will be reduced and there will " - "be latency improvements (at the cost of accuracy).")}; + "be latency improvements (at the cost of accuracy)."), + // WARNING: Experimental interface, subject to change + Flag("allow_eager_ops", parsed_flags.allow_eager_ops.bind(), + parsed_flags.allow_eager_ops.default_value(), ""), + // WARNING: Experimental interface, subject to change + Flag("force_eager_ops", parsed_flags.force_eager_ops.bind(), + parsed_flags.force_eager_ops.default_value(), "")}; bool asked_for_help = *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help")); if (asked_for_help) { @@ -260,6 +266,16 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags, READ_TOCO_FLAG(split_tflite_lstm_inputs, FlagRequirement::kNone); READ_TOCO_FLAG(quantize_weights, FlagRequirement::kNone); READ_TOCO_FLAG(post_training_quantize, FlagRequirement::kNone); + READ_TOCO_FLAG(allow_eager_ops, FlagRequirement::kNone); + READ_TOCO_FLAG(force_eager_ops, FlagRequirement::kNone); + + if (parsed_toco_flags.force_eager_ops.value() && + !parsed_toco_flags.allow_eager_ops.value()) { + // TODO(ycling): Consider to enforce `allow_eager_ops` when + // `force_eager_ops` is true. + LOG(WARNING) << "--force_eager_ops should always be used with " + "--allow_eager_ops."; + } // Deprecated flag handling. if (parsed_toco_flags.input_type.specified()) { diff --git a/tensorflow/contrib/lite/toco/toco_flags.proto b/tensorflow/contrib/lite/toco/toco_flags.proto index c1dd621429..53d60fed05 100644 --- a/tensorflow/contrib/lite/toco/toco_flags.proto +++ b/tensorflow/contrib/lite/toco/toco_flags.proto @@ -37,7 +37,7 @@ enum FileFormat { // of as properties of models, instead describing how models are to be // processed in the context of the present tooling job. // -// Next ID to use: 27. +// Next ID to use: 29. message TocoFlags { // Input file format optional FileFormat input_format = 1; @@ -189,4 +189,17 @@ message TocoFlags { // model. Model size will be reduced and there will be latency improvements // (at the cost of accuracy). optional bool post_training_quantize = 26 [default = false]; + + // When enabled, unsupported ops will be converted to TFLite Eager ops. + // TODO(ycling): Consider to rename the following 2 flags and don't call it + // "Eager". + // `allow_eager_ops` should always be used with `allow_custom_ops`. + // WARNING: Experimental interface, subject to change + optional bool allow_eager_ops = 27 [default = false]; + + // When enabled, all TensorFlow ops will be converted to TFLite Eager + // ops directly. This will force `allow_eager_ops` to true. + // `force_eager_ops` should always be used with `allow_eager_ops`. + // WARNING: Experimental interface, subject to change + optional bool force_eager_ops = 28 [default = false]; } diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc index 7db7acb44d..a7c17156b1 100644 --- a/tensorflow/contrib/lite/toco/toco_tooling.cc +++ b/tensorflow/contrib/lite/toco/toco_tooling.cc @@ -197,6 +197,10 @@ std::unique_ptr<Model> Import(const TocoFlags& toco_flags, toco_flags.has_drop_control_dependency() ? toco_flags.drop_control_dependency() : (toco_flags.output_format() != TENSORFLOW_GRAPHDEF); + + tf_import_flags.import_all_ops_as_unsupported = + toco_flags.force_eager_ops(); + model = ImportTensorFlowGraphDef(model_flags, tf_import_flags, input_file_contents); break; @@ -397,11 +401,21 @@ void Export(const TocoFlags& toco_flags, const Model& model, case TENSORFLOW_GRAPHDEF: ExportTensorFlowGraphDef(model, output_file_contents); break; - case TFLITE: - toco::tflite::Export(model, allow_custom_ops, - toco_flags.post_training_quantize(), - output_file_contents); - break; + case TFLITE: { + toco::tflite::ExportParams params; + + // Always allow custom ops when eager ops are allowed. + if (toco_flags.force_eager_ops() || toco_flags.allow_eager_ops()) { + params.allow_eager_ops = true; + params.allow_custom_ops = true; + } else if (allow_custom_ops) { + params.allow_custom_ops = true; + } + + params.quantize_weights = toco_flags.post_training_quantize(); + + toco::tflite::Export(model, output_file_contents, params); + } break; case GRAPHVIZ_DOT: DumpGraphviz(model, output_file_contents); break; |