author    Yu-Cheng Ling <ycling@google.com>  2018-09-05 06:01:51 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-09-05 06:06:23 -0700
commit 858f4672e25825bc5e091a79fd4234f1968a278d (patch)
tree   c2f41b48ff253d104e37b4a8ef52614f239669f2
parent f15e8613aa42f7f2b1439c652a465438553df219 (diff)
Minimum change for generating Eager ops with Toco.
PiperOrigin-RevId: 211621189
 tensorflow/contrib/lite/toco/args.h                |  4
 tensorflow/contrib/lite/toco/import_tensorflow.cc  | 10
 tensorflow/contrib/lite/toco/import_tensorflow.h   |  5
 tensorflow/contrib/lite/toco/tflite/export.cc      | 52
 tensorflow/contrib/lite/toco/tflite/export.h       | 51
 tensorflow/contrib/lite/toco/tflite/export_test.cc |  9
 tensorflow/contrib/lite/toco/tflite/operator.cc    | 39
 tensorflow/contrib/lite/toco/tflite/operator.h     |  8
 tensorflow/contrib/lite/toco/toco_cmdline_flags.cc | 18
 tensorflow/contrib/lite/toco/toco_flags.proto      | 15
 tensorflow/contrib/lite/toco/toco_tooling.cc       | 24
 11 files changed, 183 insertions(+), 52 deletions(-)
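Taken together, the hunks below thread two new experimental flags through the converter: the importer can treat every op as unsupported, the exporter can emit unsupported ops as Eager custom ops, and the command-line/proto plumbing exposes both switches. As a minimal sketch of driving this from the TocoFlags proto (the set_* accessors are assumed to be the standard protobuf-generated setters, not part of this diff):

    // Sketch only; setter names follow protobuf codegen for toco_flags.proto.
    toco::TocoFlags toco_flags;
    toco_flags.set_allow_eager_ops(true);  // export unsupported ops as Eager custom ops
    toco_flags.set_force_eager_ops(true);  // also import *all* ops as unsupported

With force_eager_ops set, Import() below enables import_all_ops_as_unsupported, and Export() turns on both allow_eager_ops and allow_custom_ops.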
diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h
index 84f71dc7a7..f14dbc258b 100644
--- a/tensorflow/contrib/lite/toco/args.h
+++ b/tensorflow/contrib/lite/toco/args.h
@@ -247,6 +247,10 @@ struct ParsedTocoFlags {
Arg<bool> allow_nudging_weights_to_use_fast_gemm_kernel = Arg<bool>(false);
Arg<int64> dedupe_array_min_size_bytes = Arg<int64>(64);
Arg<bool> split_tflite_lstm_inputs = Arg<bool>(true);
+ // WARNING: Experimental interface, subject to change
+ Arg<bool> allow_eager_ops = Arg<bool>(false);
+ // WARNING: Experimental interface, subject to change
+ Arg<bool> force_eager_ops = Arg<bool>(false);
};
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index cb6da21039..9bc23c4b3c 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -2061,8 +2061,14 @@ std::unique_ptr<Model> ImportTensorFlowGraphDef(
}
Model* model = new Model;
- const internal::ConverterMapType& converter_map =
- internal::GetTensorFlowNodeConverterMap();
+ internal::ConverterMapType converter_map;
+
+ // This is used for the TFLite "Full Eager Mode" conversion. All the ops are
+ // imported as `TensorFlowUnsupportedOperator`, and later all these ops are
+ // converted to TFLite Eager ops.
+ if (!tf_import_flags.import_all_ops_as_unsupported) {
+ converter_map = internal::GetTensorFlowNodeConverterMap();
+ }
for (auto node : inlined_graph.node()) {
StripZeroOutputIndexFromInputs(&node);
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.h b/tensorflow/contrib/lite/toco/import_tensorflow.h
index 2177872334..7db23f2d44 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.h
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.h
@@ -27,6 +27,11 @@ struct TensorFlowImportFlags {
// If true, control dependencies will be dropped immediately
// during the import of the TensorFlow GraphDef.
bool drop_control_dependency = false;
+
+ // Do not recognize any op and import all ops as
+ // `TensorFlowUnsupportedOperator`. This is populated by the
+ // `force_eager_ops` flag.
+ bool import_all_ops_as_unsupported = false;
};
std::unique_ptr<Model> ImportTensorFlowGraphDef(
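As a hedged sketch of the new import flag in use (argument names mirror the toco_tooling.cc call near the end of this change):

    TensorFlowImportFlags tf_import_flags;
    tf_import_flags.import_all_ops_as_unsupported = true;  // "Full Eager Mode"
    // Every node now bypasses GetTensorFlowNodeConverterMap() and is imported
    // as a TensorFlowUnsupportedOperator.
    std::unique_ptr<Model> model = ImportTensorFlowGraphDef(
        model_flags, tf_import_flags, input_file_contents);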
diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc
index c79469f59b..fee10b1dff 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export.cc
@@ -49,12 +49,21 @@ namespace {
details::OperatorKey GetOperatorKey(
const ::toco::Operator& op,
- const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
+ const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
+ bool allow_eager_ops) {
string custom_code;
if (op.type == OperatorType::kUnsupported) {
const TensorFlowUnsupportedOperator& unsupported_op =
static_cast<const TensorFlowUnsupportedOperator&>(op);
- custom_code = unsupported_op.tensorflow_op;
+
+ // TODO(b/113715895): When `allow_eager_ops` is on, there is currently no
+ // way to populate a regular custom op. We need to find a way to fix this.
+ if (allow_eager_ops) {
+ custom_code = string(::tflite::kEagerCustomCodePrefix) +
+ unsupported_op.tensorflow_op;
+ } else {
+ custom_code = unsupported_op.tensorflow_op;
+ }
}
int version = 1;
if (ops_by_type.count(op.type) != 0) {
@@ -91,11 +100,12 @@ void LoadTensorsMap(const Model& model, TensorsMap* tensors_map) {
void LoadOperatorsMap(
const Model& model, OperatorsMap* operators_map,
- const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
+ const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
+ bool allow_eager_ops) {
// First find a list of unique operator types.
std::set<OperatorKey> keys;
for (const auto& op : model.operators) {
- keys.insert(GetOperatorKey(*op, ops_by_type));
+ keys.insert(GetOperatorKey(*op, ops_by_type, allow_eager_ops));
}
// Now assign indices to them and fill in the map.
int index = 0;
@@ -189,7 +199,7 @@ Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes(
const Model& model,
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
const details::OperatorsMap& operators_map, FlatBufferBuilder* builder,
- std::set<string>* error_summary) {
+ std::set<string>* error_summary, const ExportParams& params) {
// Map from operator name to TF Lite enum value, for all builtins.
std::map<string, BuiltinOperator> builtin_ops;
for (int i = BuiltinOperator_MIN; i <= BuiltinOperator_MAX; ++i) {
@@ -205,7 +215,8 @@ Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes(
std::map<int, Offset<OperatorCode>> ordered_opcodes;
for (const auto& op : model.operators) {
- const details::OperatorKey operator_key = GetOperatorKey(*op, ops_by_type);
+ const details::OperatorKey operator_key =
+ GetOperatorKey(*op, ops_by_type, params.allow_eager_ops);
int op_index = operators_map.at(operator_key);
int op_version = operator_key.version;
@@ -252,7 +263,7 @@ Offset<Vector<Offset<Operator>>> ExportOperators(
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
const details::OperatorsMap& operators_map,
const details::TensorsMap& tensors_map, FlatBufferBuilder* builder,
- std::set<int32_t>* variable_tensor_indices) {
+ std::set<int32_t>* variable_tensor_indices, const ExportParams& params) {
variable_tensor_indices->clear();
// The operators are in execution order, so we just follow tf.mini order.
@@ -269,7 +280,8 @@ Offset<Vector<Offset<Operator>>> ExportOperators(
outputs.push_back(tensors_map.at(output));
}
- int op_index = operators_map.at(GetOperatorKey(*op, ops_by_type));
+ int op_index = operators_map.at(
+ GetOperatorKey(*op, ops_by_type, params.allow_eager_ops));
auto tflite_op_it = ops_by_type.find(op->type);
BaseOperator* tflite_op = tflite_op_it == ops_by_type.end()
@@ -320,16 +332,15 @@ Offset<Vector<Offset<Buffer>>> ExportBuffers(
return builder->CreateVector(buffer_vector);
}
-void Export(const Model& model, bool allow_custom_ops, bool quantize_weights,
- string* output_file_contents) {
- const auto ops_by_type = BuildOperatorByTypeMap();
- Export(model, allow_custom_ops, quantize_weights, output_file_contents,
- ops_by_type);
+void Export(const Model& model, string* output_file_contents,
+ const ExportParams& params) {
+ const auto ops_by_type = BuildOperatorByTypeMap(params.allow_eager_ops);
+ Export(model, output_file_contents, params, ops_by_type);
}
void Export(
- const Model& model, bool allow_custom_ops, bool quantize_weights,
- string* output_file_contents,
+ const Model& model, string* output_file_contents,
+ const ExportParams& params,
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240);
@@ -337,7 +348,8 @@ void Export(
details::LoadTensorsMap(model, &tensors_map);
details::OperatorsMap operators_map;
- details::LoadOperatorsMap(model, &operators_map, ops_by_type);
+ details::LoadOperatorsMap(model, &operators_map, ops_by_type,
+ params.allow_eager_ops);
std::vector<const Array*> buffers_to_write;
Array empty_array;
@@ -345,7 +357,7 @@ void Export(
std::set<string> error_summary;
auto op_codes = ExportOperatorCodes(model, ops_by_type, operators_map,
- &builder, &error_summary);
+ &builder, &error_summary, params);
for (const auto& op : model.operators) {
if (op->type == OperatorType::kFakeQuant) {
@@ -355,7 +367,7 @@ void Export(
"for --std_values and --mean_values.";
}
}
- if (!allow_custom_ops && !error_summary.empty()) {
+ if (!params.allow_custom_ops && !error_summary.empty()) {
// Remove ExpandDims and ReorderAxes from unimplemented list unless they
// compose the list. Both ops are removed during graph transformations.
// However, if an op is unimplemented earlier in the model, the graph
@@ -383,7 +395,7 @@ void Export(
std::set<int32_t> variable_tensor_indices;
auto ops = ExportOperators(model, ops_by_type, operators_map, tensors_map,
- &builder, &variable_tensor_indices);
+ &builder, &variable_tensor_indices, params);
auto tensors = ExportTensors(model, tensors_map, &builder, &buffers_to_write,
variable_tensor_indices);
@@ -402,7 +414,7 @@ void Export(
builder.CreateVector(subgraphs), description, buffers);
::tflite::FinishModelBuffer(builder, new_model_location);
- if (quantize_weights) {
+ if (params.quantize_weights) {
// Call the quantize_weights tool.
LOG(INFO) << "Quantizing TFLite model after conversion to flatbuffer. "
"dump_graphviz will only output the model before this "
diff --git a/tensorflow/contrib/lite/toco/tflite/export.h b/tensorflow/contrib/lite/toco/tflite/export.h
index 915d5dd3d6..b070a38768 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.h
+++ b/tensorflow/contrib/lite/toco/tflite/export.h
@@ -23,22 +23,54 @@ namespace toco {
namespace tflite {
+// The parameters for exporting a TFLite model.
+struct ExportParams {
+ bool allow_custom_ops = false;
+ bool allow_eager_ops = false;
+ bool quantize_weights = false;
+};
+
// Transform the given tf.mini model into a TF Lite flatbuffer and deposit the
// result in the given string.
-void Export(const Model& model, bool allow_custom_ops, bool quantize_weights,
- string* output_file_contents);
+void Export(const Model& model, string* output_file_contents,
+ const ExportParams& params);
+
+// Export API with custom TFLite operator mapping.
+void Export(
+ const Model& model, string* output_file_contents,
+ const ExportParams& params,
+ const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type);
-// This if backward-compatibility.
+// This is for backward-compatibility.
// TODO(ycling): Remove the deprecated entry functions.
-inline void Export(const Model& model, string* output_file_contents) {
- Export(model, true, false, output_file_contents);
+inline void Export(const Model& model, bool allow_custom_ops,
+ bool quantize_weights, string* output_file_contents) {
+ ExportParams params;
+ params.allow_custom_ops = allow_custom_ops;
+ params.quantize_weights = quantize_weights;
+ Export(model, output_file_contents, params);
}
-// Export API with custom TFLite operator mapping.
-void Export(
+// This is for backward-compatibility.
+// TODO(ycling): Remove the deprecated entry functions.
+inline void Export(
const Model& model, bool allow_custom_ops, bool quantize_weights,
string* output_file_contents,
- const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type);
+ const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
+ ExportParams params;
+ params.allow_custom_ops = allow_custom_ops;
+ params.quantize_weights = quantize_weights;
+ Export(model, output_file_contents, params, ops_by_type);
+}
+
+// This is for backward-compatibility.
+// TODO(ycling): Remove the deprecated entry functions.
+inline void Export(const Model& model, string* output_file_contents) {
+ ExportParams params;
+ params.allow_custom_ops = true;
+ Export(model, output_file_contents, params);
+}
namespace details {
@@ -88,7 +120,8 @@ using OperatorsMap = std::unordered_map<OperatorKey, int, OperatorKey::Hash>;
void LoadTensorsMap(const Model& model, TensorsMap* tensors_map);
void LoadOperatorsMap(
const Model& model, OperatorsMap* operators_map,
- const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type);
+ const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
+ bool allow_eager_ops);
} // namespace details
} // namespace tflite
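Migrating a caller to the new entry point is mechanical; a minimal sketch mirroring the toco_tooling.cc hunk later in this change:

    toco::tflite::ExportParams params;
    params.allow_custom_ops = true;
    params.allow_eager_ops = true;  // experimental; see toco_flags.proto
    params.quantize_weights = false;
    string output_file_contents;
    toco::tflite::Export(model, &output_file_contents, params);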
diff --git a/tensorflow/contrib/lite/toco/tflite/export_test.cc b/tensorflow/contrib/lite/toco/tflite/export_test.cc
index 4994ea30de..8d4d197c46 100644
--- a/tensorflow/contrib/lite/toco/tflite/export_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export_test.cc
@@ -105,7 +105,8 @@ TEST_F(ExportTest, LoadOperatorsMap) {
details::OperatorsMap operators;
const auto ops_by_type = BuildOperatorByTypeMap();
- details::LoadOperatorsMap(input_model_, &operators, ops_by_type);
+ // TODO(ycling): Add a test for allow_eager_ops.
+ details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);
EXPECT_EQ(0, operators[details::OperatorKey(OperatorType::kAdd, "", 1)]);
EXPECT_EQ(1, operators[details::OperatorKey(OperatorType::kConv, "", 1)]);
EXPECT_EQ(2, operators[details::OperatorKey(OperatorType::kSub, "", 1)]);
@@ -253,7 +254,7 @@ TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV1) {
details::OperatorsMap operators;
const auto ops_by_type = BuildFakeOperatorByTypeMap();
- details::LoadOperatorsMap(input_model_, &operators, ops_by_type);
+ details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);
EXPECT_EQ(1, operators.size());
EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 1)));
@@ -264,7 +265,7 @@ TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV2) {
details::OperatorsMap operators;
const auto ops_by_type = BuildFakeOperatorByTypeMap();
- details::LoadOperatorsMap(input_model_, &operators, ops_by_type);
+ details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);
EXPECT_EQ(1, operators.size());
EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 2)));
@@ -276,7 +277,7 @@ TEST_F(VersionedOpExportTest, LoadOperatorsMapWithBothVersions) {
details::OperatorsMap operators;
const auto ops_by_type = BuildFakeOperatorByTypeMap();
- details::LoadOperatorsMap(input_model_, &operators, ops_by_type);
+ details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);
EXPECT_EQ(2, operators.size());
EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 1)));
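The TODO above leaves the allow_eager_ops path uncovered; a hypothetical test body in the same fixture (the test name and the expectation comment are invented here, not part of this diff):

    TEST_F(ExportTest, LoadOperatorsMapWithEagerOps) {
      details::OperatorsMap operators;
      const auto ops_by_type = BuildOperatorByTypeMap(/*allow_eager_ops=*/true);
      details::LoadOperatorsMap(input_model_, &operators, ops_by_type,
                                /*allow_eager_ops=*/true);
      // Unsupported ops should be keyed under the Eager-prefixed custom code.
    }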
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index a314c8d53a..eb0f7c443a 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -1149,7 +1149,9 @@ class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions,
class TensorFlowUnsupported : public BaseOperator {
public:
- using BaseOperator::BaseOperator;
+ TensorFlowUnsupported(const string& name, OperatorType type,
+ bool allow_eager_ops)
+ : BaseOperator(name, type), allow_eager_ops_(allow_eager_ops) {}
Options Serialize(const Operator& op,
flatbuffers::FlatBufferBuilder* builder) const override {
@@ -1165,6 +1167,9 @@ class TensorFlowUnsupported : public BaseOperator {
std::unique_ptr<Operator> Deserialize(
const BuiltinOptions* builtin_options,
const CustomOptions* custom_options) const override {
+ // Deserializing Eager ops is not supported yet.
+ // TODO(ycling): Revisit and decide if we should fix the flow for importing
+ // TFLite models with Eager ops.
auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
if (custom_options) {
auto flexbuffer_map =
@@ -1185,6 +1190,16 @@ class TensorFlowUnsupported : public BaseOperator {
return std::unique_ptr<flexbuffers::Builder>();
}
+ if (allow_eager_ops_) {
+ fbb->Vector([&]() {
+ fbb->String(node_def.op());
+ fbb->String(op.tensorflow_node_def);
+ });
+ fbb->Finish();
+ LOG(INFO) << "Writing eager op: " << node_def.op();
+ return std::unique_ptr<flexbuffers::Builder>(fbb.release());
+ }
+
bool has_valid_attr = false;
size_t map_start = fbb->StartMap();
for (const auto& pair : node_def.attr()) {
@@ -1285,11 +1300,15 @@ class TensorFlowUnsupported : public BaseOperator {
// custom ops.
return 1;
}
+
+ private:
+ const bool allow_eager_ops_;
};
namespace {
// Build a vector containing all the known operators.
-std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
+std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
+ bool allow_eager_ops = false) {
std::vector<std::unique_ptr<BaseOperator>> ops;
using tensorflow::MakeUnique;
// Builtin Operators.
@@ -1400,8 +1419,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
MakeUnique<DepthToSpace>("DEPTH_TO_SPACE", OperatorType::kDepthToSpace));
ops.push_back(MakeUnique<CTCBeamSearchDecoder>(
"CTC_BEAM_SEARCH_DECODER", OperatorType::kCTCBeamSearchDecoder));
- ops.push_back(MakeUnique<TensorFlowUnsupported>("TENSORFLOW_UNSUPPORTED",
- OperatorType::kUnsupported));
+ ops.push_back(MakeUnique<TensorFlowUnsupported>(
+ "TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported, allow_eager_ops));
// These operators are supported by Toco, but not by TF Lite, and have no
// attributes.
@@ -1474,10 +1493,12 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
}
} // namespace
-std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap() {
+std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap(
+ bool allow_eager_ops) {
std::map<OperatorType, std::unique_ptr<BaseOperator>> result;
- std::vector<std::unique_ptr<BaseOperator>> ops = BuildOperatorList();
+ std::vector<std::unique_ptr<BaseOperator>> ops =
+ BuildOperatorList(allow_eager_ops);
for (auto& op : ops) {
result[op->type()] = std::move(op);
}
@@ -1485,10 +1506,12 @@ std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap() {
return result;
}
-std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap() {
+std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(
+ bool allow_eager_ops) {
std::map<string, std::unique_ptr<BaseOperator>> result;
- std::vector<std::unique_ptr<BaseOperator>> ops = BuildOperatorList();
+ std::vector<std::unique_ptr<BaseOperator>> ops =
+ BuildOperatorList(allow_eager_ops);
for (auto& op : ops) {
result[op->name()] = std::move(op);
}
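For reference, the custom-options payload written in the allow_eager_ops_ branch above is a two-element flexbuffer vector: the TF op name followed by the serialized NodeDef. A standalone sketch of the same layout ("Range" and serialized_node_def are stand-ins):

    auto fbb = absl::make_unique<flexbuffers::Builder>();
    fbb->Vector([&]() {
      fbb->String("Range");              // node_def.op(), the TF op name
      fbb->String(serialized_node_def);  // op.tensorflow_node_def
    });
    fbb->Finish();
    // The finished buffer becomes the TFLite custom options for this op.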
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.h b/tensorflow/contrib/lite/toco/tflite/operator.h
index d9ea23edf2..702fb28ea6 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.h
+++ b/tensorflow/contrib/lite/toco/tflite/operator.h
@@ -26,11 +26,15 @@ namespace tflite {
class BaseOperator;
// Return a map containing all known TF Lite Operators, keyed by their names.
-std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap();
+// TODO(ycling): The pattern to propagate parameters (e.g. allow_eager_ops)
+// is ugly here. Consider refactoring.
+std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(
+ bool allow_eager_ops = false);
// Return a map containing all known TF Lite Operators, keyed by the type of
// their tf.mini counterparts.
-std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap();
+std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap(
+ bool allow_eager_ops = false);
// These are the flatbuffer types for custom and builtin options.
using CustomOptions = flatbuffers::Vector<uint8_t>;
diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
index f83a290195..b6aebc0470 100644
--- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
+++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
@@ -165,7 +165,13 @@ bool ParseTocoFlagsFromCommandLineFlags(
parsed_flags.post_training_quantize.default_value(),
"Boolean indicating whether to quantize the weights of the "
"converted float model. Model size will be reduced and there will "
- "be latency improvements (at the cost of accuracy).")};
+ "be latency improvements (at the cost of accuracy)."),
+ // WARNING: Experimental interface, subject to change
+ Flag("allow_eager_ops", parsed_flags.allow_eager_ops.bind(),
+ parsed_flags.allow_eager_ops.default_value(), ""),
+ // WARNING: Experimental interface, subject to change
+ Flag("force_eager_ops", parsed_flags.force_eager_ops.bind(),
+ parsed_flags.force_eager_ops.default_value(), "")};
bool asked_for_help =
*argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
if (asked_for_help) {
@@ -260,6 +266,16 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
READ_TOCO_FLAG(split_tflite_lstm_inputs, FlagRequirement::kNone);
READ_TOCO_FLAG(quantize_weights, FlagRequirement::kNone);
READ_TOCO_FLAG(post_training_quantize, FlagRequirement::kNone);
+ READ_TOCO_FLAG(allow_eager_ops, FlagRequirement::kNone);
+ READ_TOCO_FLAG(force_eager_ops, FlagRequirement::kNone);
+
+ if (parsed_toco_flags.force_eager_ops.value() &&
+ !parsed_toco_flags.allow_eager_ops.value()) {
+ // TODO(ycling): Consider enforcing `allow_eager_ops` when
+ // `force_eager_ops` is true.
+ LOG(WARNING) << "--force_eager_ops should always be used with "
+ "--allow_eager_ops.";
+ }
// Deprecated flag handling.
if (parsed_toco_flags.input_type.specified()) {
diff --git a/tensorflow/contrib/lite/toco/toco_flags.proto b/tensorflow/contrib/lite/toco/toco_flags.proto
index c1dd621429..53d60fed05 100644
--- a/tensorflow/contrib/lite/toco/toco_flags.proto
+++ b/tensorflow/contrib/lite/toco/toco_flags.proto
@@ -37,7 +37,7 @@ enum FileFormat {
// of as properties of models, instead describing how models are to be
// processed in the context of the present tooling job.
//
-// Next ID to use: 27.
+// Next ID to use: 29.
message TocoFlags {
// Input file format
optional FileFormat input_format = 1;
@@ -189,4 +189,17 @@ message TocoFlags {
// model. Model size will be reduced and there will be latency improvements
// (at the cost of accuracy).
optional bool post_training_quantize = 26 [default = false];
+
+ // When enabled, unsupported ops will be converted to TFLite Eager ops.
+ // TODO(ycling): Consider renaming the following 2 flags and not calling them
+ // "Eager".
+ // `allow_eager_ops` should always be used with `allow_custom_ops`.
+ // WARNING: Experimental interface, subject to change
+ optional bool allow_eager_ops = 27 [default = false];
+
+ // When enabled, all TensorFlow ops will be converted to TFLite Eager
+ // ops directly. This will force `allow_eager_ops` to true.
+ // `force_eager_ops` should always be used with `allow_eager_ops`.
+ // WARNING: Experimental interface, subject to change
+ optional bool force_eager_ops = 28 [default = false];
}
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index 7db7acb44d..a7c17156b1 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -197,6 +197,10 @@ std::unique_ptr<Model> Import(const TocoFlags& toco_flags,
toco_flags.has_drop_control_dependency()
? toco_flags.drop_control_dependency()
: (toco_flags.output_format() != TENSORFLOW_GRAPHDEF);
+
+ tf_import_flags.import_all_ops_as_unsupported =
+ toco_flags.force_eager_ops();
+
model = ImportTensorFlowGraphDef(model_flags, tf_import_flags,
input_file_contents);
break;
@@ -397,11 +401,21 @@ void Export(const TocoFlags& toco_flags, const Model& model,
case TENSORFLOW_GRAPHDEF:
ExportTensorFlowGraphDef(model, output_file_contents);
break;
- case TFLITE:
- toco::tflite::Export(model, allow_custom_ops,
- toco_flags.post_training_quantize(),
- output_file_contents);
- break;
+ case TFLITE: {
+ toco::tflite::ExportParams params;
+
+ // Always allow custom ops when eager ops are allowed.
+ if (toco_flags.force_eager_ops() || toco_flags.allow_eager_ops()) {
+ params.allow_eager_ops = true;
+ params.allow_custom_ops = true;
+ } else if (allow_custom_ops) {
+ params.allow_custom_ops = true;
+ }
+
+ params.quantize_weights = toco_flags.post_training_quantize();
+
+ toco::tflite::Export(model, output_file_contents, params);
+ } break;
case GRAPHVIZ_DOT:
DumpGraphviz(model, output_file_contents);
break;