diff options
author | Yu-Cheng Ling <ycling@google.com> | 2018-10-05 13:33:38 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-05 13:43:48 -0700 |
commit | 4d69a79b1ebd0c2180959c1047fbc9db106701e1 (patch) | |
tree | b18202af4984d906ec860e670daec7836931ee77 /tensorflow/contrib/lite/toco | |
parent | 0c37dcc02f54395d2bde3cc5850574c8f98f1b46 (diff) |
Handle Range & BatchMatMul in partial Flex mode
PiperOrigin-RevId: 215957535
Diffstat (limited to 'tensorflow/contrib/lite/toco')
-rw-r--r-- | tensorflow/contrib/lite/toco/import_tensorflow.cc | 37 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/model.h | 9 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/tflite/export.cc | 83 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/tflite/export_test.cc | 34 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/tflite/operator.cc | 32 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/tflite/operator.h | 6 |
6 files changed, 155 insertions, 46 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index 5eaf6e27fc..133ef79a34 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -477,6 +477,30 @@ string CreateConstArray(Model* model, string const& name, return array_name; } +// Retain TensorFlow NodeDef in Toco Operator. +// +// If an op is supported by Toco but not supported by TFLite, TFLite exporter +// will use the retained NodeDef to populate a Flex op when Flex mode is +// enabled. +// +// This can't be easily applied to all operations, because a TensorFlow node +// may become multiple Toco operators. Thus we need to call this function in +// operator conversion functions one by one whenever feasible. +// +// This may cause problems if a graph transformation rule changes parameters +// of the node. When calling this function, please check if any existing +// graph transformation rule will change an existing operator with the same +// type. +// +// This provides a route to handle Toco-supported & TFLite-unsupported ops +// in Flex mode. However it's not a solid solution. Eventually we should +// get rid of this. +// TODO(b/117327937): Implement all Toco-supported ops in TFLite, and remove +// this function. +void RetainTensorFlowNodeDef(const NodeDef& node, Operator* op) { + node.SerializeToString(&op->tensorflow_node_def); +} + tensorflow::Status ConvertConstOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -990,6 +1014,10 @@ tensorflow::Status ConvertBatchMatMulOperator( auto* batch_matmul = new BatchMatMulOperator; batch_matmul->inputs = {node.input(0), node.input(1)}; batch_matmul->outputs = {node.name()}; + + // For Flex mode. Please read the comments of the function. 
+ RetainTensorFlowNodeDef(node, batch_matmul); + model->operators.emplace_back(batch_matmul); return tensorflow::Status::OK(); } @@ -1081,7 +1109,10 @@ tensorflow::Status ConvertUnsupportedOperator( auto* op = new TensorFlowUnsupportedOperator; op->tensorflow_op = node.op(); - node.SerializeToString(&op->tensorflow_node_def); + + // For Flex mode. Please read the comments of the function. + RetainTensorFlowNodeDef(node, op); + model->operators.emplace_back(op); // Parse inputs. @@ -1605,6 +1636,10 @@ tensorflow::Status ConvertRangeOperator( op->inputs.push_back(node.input(1)); op->inputs.push_back(node.input(2)); op->outputs.push_back(node.name()); + + // For Flex mode. Please read the comments of the function. + RetainTensorFlowNodeDef(node, op); + model->operators.emplace_back(op); return tensorflow::Status::OK(); } diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index 6e207fdf54..61f1f095e9 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -376,6 +376,13 @@ struct Operator { // looks unused. bool unresolved_outputs = false; + // A serialized tensorflow::NodeDef string. + // The field is filled only when importing from TensorFlow. + // It's guaranteed to be filled for `TensorFlowUnsupportedOperator`. + // It's not guaranteed to be filled for other ops. Ops created by graph + // transformations won't have TensorFlow NodeDef. + string tensorflow_node_def; + protected: // Constructor used by subclasses for specific OperatorType's. explicit Operator(OperatorType t) @@ -1535,8 +1542,6 @@ struct TensorFlowUnsupportedOperator : Operator { // The original TF operation type. Used for diagnostic purposes. string tensorflow_op; - // A serialized tensorflow::NodeDef string. - string tensorflow_node_def; // A boolean indicating if the unsupported op should be treated as quantized. 
bool quantized = false; // A boolean indicating if the unsupported op output should allow float values diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc index f6f76e48a4..3b34cd6285 100644 --- a/tensorflow/contrib/lite/toco/tflite/export.cc +++ b/tensorflow/contrib/lite/toco/tflite/export.cc @@ -95,11 +95,13 @@ OperatorKey GetOperatorKey( const ::toco::Operator& op, const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type, bool allow_flex_ops) { + // Get the op name (by Toco definition). string name = HelpfulOperatorTypeName(op); - const auto& builtin_ops = GetBuiltinOpsMap(); bool is_builtin = false; OperatorKey key; + + const auto& builtin_ops = GetBuiltinOpsMap(); if (ops_by_type.count(op.type) != 0) { key.version = ops_by_type.at(op.type)->GetVersion(op); name = ops_by_type.at(op.type)->name(); @@ -110,37 +112,46 @@ OperatorKey GetOperatorKey( // For TFLite supported builtin ops, find out its BuiltinOperator enum used // in FlatBuffer. key.type = builtin_ops.at(name); - } else { - key.type = BuiltinOperator_CUSTOM; - - key.is_custom_op = true; - if (op.type == OperatorType::kUnsupported) { - const TensorFlowUnsupportedOperator& unsupported_op = - static_cast<const TensorFlowUnsupportedOperator&>(op); - const auto tensorflow_op = unsupported_op.tensorflow_op; - - // TODO(b/113715895): When `allow_flex_ops` is on, for now there's no way - // to populate a regular custom op. We need to find a way to fix this. - if (allow_flex_ops) { - // Memorize the original TensorFlow op name. - key.flex_tensorflow_op = tensorflow_op; - // Prefix the custom code of the flex op. - key.custom_code = - string(::tflite::kFlexCustomCodePrefix) + tensorflow_op; - key.is_flex_op = true; - - if (IsControlFlowOp(tensorflow_op)) { - key.is_unsupported_flex_op = true; - } - } else { - key.custom_code = tensorflow_op; - } + return key; + } + + // The logic below is all for custom ops. 
+ key.is_custom_op = true; + key.type = BuiltinOperator_CUSTOM; + + if (op.type == OperatorType::kUnsupported) { + const TensorFlowUnsupportedOperator& unsupported_op = + static_cast<const TensorFlowUnsupportedOperator&>(op); + const auto tensorflow_op = unsupported_op.tensorflow_op; + + // TODO(b/113715895): When `allow_flex_ops` is on, for now there's no way + // to populate a regular custom op. We need to find a way to fix this. + if (allow_flex_ops) { + key.is_flex_op = true; + key.flex_tensorflow_op = tensorflow_op; + key.custom_code = + string(::tflite::kFlexCustomCodePrefix) + key.flex_tensorflow_op; + } else { - // For Toco-supported/TFLite-unsupported ops, currently we produce a - // custom op. This gives developers a chance to implement custom ops. - // TODO(b/116800229): Also produce Toco-supported/TFLite-unsupported ops - // as Flex ops when Flex mode is enabled. - key.custom_code = name; + key.custom_code = tensorflow_op; + } + } else if (allow_flex_ops && !op.tensorflow_node_def.empty()) { + // For Toco-supported/TFLite-unsupported ops, if the TensorFlow NodeDef + // is retained in the Toco Operator, we produce a Flex op if Flex mode + // is enabled. + key.is_flex_op = true; + key.flex_tensorflow_op = name; + key.custom_code = + string(::tflite::kFlexCustomCodePrefix) + key.flex_tensorflow_op; + } else { + // If Flex is disabled or the original TensorFlow NodeDef isn't available, + // we produce a custom op. This gives developers a chance to implement + // custom ops. 
+ key.custom_code = name; + } + + if (key.is_flex_op) { + if (IsControlFlowOp(key.flex_tensorflow_op)) { + key.is_unsupported_flex_op = true; } } return key; @@ -323,8 +334,9 @@ Offset<Vector<Offset<Operator>>> ExportOperators( outputs.push_back(tensors_map.at(output)); } - int op_index = operators_map.at( - details::GetOperatorKey(*op, ops_by_type, params.allow_flex_ops)); + const auto key = + details::GetOperatorKey(*op, ops_by_type, params.allow_flex_ops); + int op_index = operators_map.at(key); auto tflite_op_it = ops_by_type.find(op->type); BaseOperator* tflite_op = tflite_op_it == ops_by_type.end() @@ -349,6 +361,11 @@ Offset<Vector<Offset<Operator>>> ExportOperators( variable_tensor_indices->insert(variable_tensor_index); } } + } else if (key.is_flex_op && !op->tensorflow_node_def.empty()) { + auto fbb = WriteFlexOpOptions(op->tensorflow_node_def); + if (fbb) { + options = Options::Custom(builder->CreateVector(fbb->GetBuffer())); + } } // The only supported CustomOptionFormat is FLEXBUFFERS now. op_vector.push_back(CreateOperator( diff --git a/tensorflow/contrib/lite/toco/tflite/export_test.cc b/tensorflow/contrib/lite/toco/tflite/export_test.cc index d48ab78285..eda1aa78a3 100644 --- a/tensorflow/contrib/lite/toco/tflite/export_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/export_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/contrib/lite/toco/tflite/builtin_operator.h" #include "tensorflow/contrib/lite/toco/tflite/operator.h" #include "tensorflow/contrib/lite/toco/tflite/types.h" +#include "tensorflow/core/framework/node_def.pb.h" namespace toco { namespace tflite { @@ -382,6 +383,39 @@ TEST(OperatorKeyTest, TestFlexWithControlFlowOp) { EXPECT_TRUE(key.is_unsupported_flex_op); } +TEST(OperatorKeyTest, TestFlexWithPartiallySupportedOps) { + // Test Toco-supported/TFLite-unsupported operators. + // TODO(ycling): The test will be broken if Range is implemented in TFLite. 
+ // Find a more robust way to test the fallback logic. + auto op = absl::make_unique<RangeOperator>(); + + const auto ops_by_type = BuildOperatorByTypeMap(); + + { + // If NodeDef isn't retained in the Toco op, a regular custom op + // will be exported. + const auto key = details::GetOperatorKey(*op, ops_by_type, true); + EXPECT_EQ(key.type, ::tflite::BuiltinOperator_CUSTOM); + EXPECT_EQ(key.custom_code, "Range"); + EXPECT_EQ(key.version, 1); + EXPECT_FALSE(key.is_flex_op); + } + + ::tensorflow::NodeDef node_def; + node_def.set_name("Range"); + node_def.set_op("Range"); + node_def.SerializeToString(&op->tensorflow_node_def); + + { + // If NodeDef is retained in the Toco op, a Flex op will be exported. + const auto key = details::GetOperatorKey(*op, ops_by_type, true); + EXPECT_EQ(key.type, ::tflite::BuiltinOperator_CUSTOM); + EXPECT_EQ(key.custom_code, "FlexRange"); + EXPECT_EQ(key.version, 1); + EXPECT_TRUE(key.is_flex_op); + } +} + // TODO(ahentz): tests for tensors, inputs, outputs, opcodes and operators. 
} // namespace diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index 9addbb81e7..ed37535fe0 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -1157,6 +1157,25 @@ class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions, int GetVersion(const Operator& op) const override { return 1; } }; +std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions( + const string& tensorflow_node_def) { + auto fbb = absl::make_unique<flexbuffers::Builder>(); + + ::tensorflow::NodeDef node_def; + if (!node_def.ParseFromString(tensorflow_node_def)) { + LOG(ERROR) << "Failed to parse TensorFlow NodeDef"; + return {}; + } + + fbb->Vector([&]() { + fbb->String(node_def.op()); + fbb->String(tensorflow_node_def); + }); + fbb->Finish(); + LOG(INFO) << "Writing flex op: " << node_def.op(); + return std::unique_ptr<flexbuffers::Builder>(fbb.release()); +} + class TensorFlowUnsupported : public BaseOperator { public: TensorFlowUnsupported(const string& name, OperatorType type, @@ -1192,6 +1211,9 @@ class TensorFlowUnsupported : public BaseOperator { std::unique_ptr<flexbuffers::Builder> WriteOptions( const TensorFlowUnsupportedOperator& op) const { + if (allow_flex_ops_) { + return WriteFlexOpOptions(op.tensorflow_node_def); + } auto fbb = absl::make_unique<flexbuffers::Builder>(); ::tensorflow::NodeDef node_def; @@ -1200,16 +1222,6 @@ class TensorFlowUnsupported : public BaseOperator { return std::unique_ptr<flexbuffers::Builder>(); } - if (allow_flex_ops_) { - fbb->Vector([&]() { - fbb->String(node_def.op()); - fbb->String(op.tensorflow_node_def); - }); - fbb->Finish(); - LOG(INFO) << "Writing flex op: " << node_def.op(); - return std::unique_ptr<flexbuffers::Builder>(fbb.release()); - } - bool has_valid_attr = false; size_t map_start = fbb->StartMap(); for (const auto& pair : node_def.attr()) { diff --git 
a/tensorflow/contrib/lite/toco/tflite/operator.h b/tensorflow/contrib/lite/toco/tflite/operator.h index 13d9f6c49a..6e4e0a16d1 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.h +++ b/tensorflow/contrib/lite/toco/tflite/operator.h @@ -16,6 +16,7 @@ limitations under the License. #define TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_OPERATOR_H_ #include "flatbuffers/flatbuffers.h" +#include "flatbuffers/flexbuffers.h" #include "tensorflow/contrib/lite/schema/schema_generated.h" #include "tensorflow/contrib/lite/toco/model.h" @@ -36,6 +37,11 @@ std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap( std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap( bool allow_flex_ops = false); +// Write the custom option FlexBuffer with a serialized TensorFlow NodeDef +// for a Flex op. +std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions( + const string& tensorflow_node_def); + // These are the flatbuffer types for custom and builtin options. using CustomOptions = flatbuffers::Vector<uint8_t>; using BuiltinOptions = void; |