aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib
diff options
context:
space:
mode:
authorGravatar Yu-Cheng Ling <ycling@google.com>2018-10-05 13:33:38 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-05 13:43:48 -0700
commit4d69a79b1ebd0c2180959c1047fbc9db106701e1 (patch)
treeb18202af4984d906ec860e670daec7836931ee77 /tensorflow/contrib
parent0c37dcc02f54395d2bde3cc5850574c8f98f1b46 (diff)
Handle Range & BatchMatMul in partial Flex mode
PiperOrigin-RevId: 215957535
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc37
-rw-r--r--tensorflow/contrib/lite/toco/model.h9
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export.cc83
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export_test.cc34
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.cc32
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.h6
6 files changed, 155 insertions, 46 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 5eaf6e27fc..133ef79a34 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -477,6 +477,30 @@ string CreateConstArray(Model* model, string const& name,
return array_name;
}
+// Retain TensorFlow NodeDef in Toco Operator.
+//
+// If an op is supported by Toco but not supported by TFLite, TFLite exporter
+// will use the retained NodeDef to populate a Flex op when Flex mode is
+// enabled.
+//
+// This can't be easily applied to all operations, because a TensorFlow node
+// may become multiple Toco operators. Thus we need to call this function in
+// operator conversion functions one by one whenever feasible.
+//
+// This may cause problems if a graph transformation rule changes parameters
+// of the node. When calling this function, please check if any existing
+// graph transformation rule will change an existing operator with the same
+// type.
+//
+// This provides a route to handle Toco-supported & TFLite-unsupported ops
+// in Flex mode. However it's not a solid solution. Eventually we should
+// get rid of this.
+// TODO(b/117327937): Implement all Toco-supported ops in TFLite, and remove
+// this function.
+void RetainTensorFlowNodeDef(const NodeDef& node, Operator* op) {
+ node.SerializeToString(&op->tensorflow_node_def);
+}
+
tensorflow::Status ConvertConstOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
@@ -990,6 +1014,10 @@ tensorflow::Status ConvertBatchMatMulOperator(
auto* batch_matmul = new BatchMatMulOperator;
batch_matmul->inputs = {node.input(0), node.input(1)};
batch_matmul->outputs = {node.name()};
+
+ // For Flex mode. Please read the comments of the function.
+ RetainTensorFlowNodeDef(node, batch_matmul);
+
model->operators.emplace_back(batch_matmul);
return tensorflow::Status::OK();
}
@@ -1081,7 +1109,10 @@ tensorflow::Status ConvertUnsupportedOperator(
auto* op = new TensorFlowUnsupportedOperator;
op->tensorflow_op = node.op();
- node.SerializeToString(&op->tensorflow_node_def);
+
+ // For Flex mode. Please read the comments of the function.
+ RetainTensorFlowNodeDef(node, op);
+
model->operators.emplace_back(op);
// Parse inputs.
@@ -1605,6 +1636,10 @@ tensorflow::Status ConvertRangeOperator(
op->inputs.push_back(node.input(1));
op->inputs.push_back(node.input(2));
op->outputs.push_back(node.name());
+
+ // For Flex mode. Please read the comments of the function.
+ RetainTensorFlowNodeDef(node, op);
+
model->operators.emplace_back(op);
return tensorflow::Status::OK();
}
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 6e207fdf54..61f1f095e9 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -376,6 +376,13 @@ struct Operator {
// looks unused.
bool unresolved_outputs = false;
+ // A serialized tensorflow::NodeDef string.
+ // The field is filled only when importing from TensorFlow.
+ // It's guaranteed to be filled for `TensorFlowUnsupportedOperator`.
+ // It's not guaranteed to be filled for other ops. Ops created by graph
+ // transformations won't have TensorFlow NodeDef.
+ string tensorflow_node_def;
+
protected:
// Constructor used by subclasses for specific OperatorType's.
explicit Operator(OperatorType t)
@@ -1535,8 +1542,6 @@ struct TensorFlowUnsupportedOperator : Operator {
// The original TF operation type. Used for diagnostic purposes.
string tensorflow_op;
- // A serialized tensorflow::NodeDef string.
- string tensorflow_node_def;
// A boolean indicating if the unsupported op should be treated as quantized.
bool quantized = false;
// A boolean indicating if the unsupported op output should allow float values
diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc
index f6f76e48a4..3b34cd6285 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export.cc
@@ -95,11 +95,13 @@ OperatorKey GetOperatorKey(
const ::toco::Operator& op,
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
bool allow_flex_ops) {
+ // Get the op name (by Toco definition).
string name = HelpfulOperatorTypeName(op);
- const auto& builtin_ops = GetBuiltinOpsMap();
bool is_builtin = false;
OperatorKey key;
+
+ const auto& builtin_ops = GetBuiltinOpsMap();
if (ops_by_type.count(op.type) != 0) {
key.version = ops_by_type.at(op.type)->GetVersion(op);
name = ops_by_type.at(op.type)->name();
@@ -110,37 +112,46 @@ OperatorKey GetOperatorKey(
// For TFLite supported builtin ops, find out its BuiltinOperator enum used
// in FlatBuffer.
key.type = builtin_ops.at(name);
- } else {
- key.type = BuiltinOperator_CUSTOM;
-
- key.is_custom_op = true;
- if (op.type == OperatorType::kUnsupported) {
- const TensorFlowUnsupportedOperator& unsupported_op =
- static_cast<const TensorFlowUnsupportedOperator&>(op);
- const auto tensorflow_op = unsupported_op.tensorflow_op;
-
- // TODO(b/113715895): When `allow_flex_ops` is on, for now there's no way
- // to populate a regular custom op. We need to find a way to fix this.
- if (allow_flex_ops) {
- // Memorize the original TensorFlow op name.
- key.flex_tensorflow_op = tensorflow_op;
- // Prefix the custom code of the flex op.
- key.custom_code =
- string(::tflite::kFlexCustomCodePrefix) + tensorflow_op;
- key.is_flex_op = true;
-
- if (IsControlFlowOp(tensorflow_op)) {
- key.is_unsupported_flex_op = true;
- }
- } else {
- key.custom_code = tensorflow_op;
- }
+ return key;
+ }
+
+ // The logic below is all for custom ops.
+ key.is_custom_op = true;
+ key.type = BuiltinOperator_CUSTOM;
+
+ if (op.type == OperatorType::kUnsupported) {
+ const TensorFlowUnsupportedOperator& unsupported_op =
+ static_cast<const TensorFlowUnsupportedOperator&>(op);
+ const auto tensorflow_op = unsupported_op.tensorflow_op;
+
+ // TODO(b/113715895): When `allow_flex_ops` is on, for now there's no way
+ // to populate a regular custom op. We need to find a way to fix this.
+ if (allow_flex_ops) {
+ key.is_flex_op = true;
+ key.flex_tensorflow_op = tensorflow_op;
+ key.custom_code =
+ string(::tflite::kFlexCustomCodePrefix) + key.flex_tensorflow_op;
} else {
- // For Toco-supported/TFLite-unsupported ops, currently we produce a
- // custom op. This gives developers a chance to implement custom ops.
- // TODO(b/116800229): Also produce Toco-supported/TFLite-unsupported ops
- // as Flex ops when Flex mode is enabled.
- key.custom_code = name;
+ key.custom_code = tensorflow_op;
+ }
+ } else if (allow_flex_ops && !op.tensorflow_node_def.empty()) {
+ // For Toco-supported/TFLite-unsupported ops, if the TensorFlow NodeDef
+ // is retained in the Toco Operator, we produce a Flex op if Flex mode
+ // is enabled.
+ key.is_flex_op = true;
+ key.flex_tensorflow_op = name;
+ key.custom_code =
+ string(::tflite::kFlexCustomCodePrefix) + key.flex_tensorflow_op;
+ } else {
+ // If Flex is disabled or the original TensorFlow NodeDef isn't available,
+ // we produce a custom op. This gives developers a chance to implemenr
+ // custom ops.
+ key.custom_code = name;
+ }
+
+ if (key.is_flex_op) {
+ if (IsControlFlowOp(key.flex_tensorflow_op)) {
+ key.is_unsupported_flex_op = true;
}
}
return key;
@@ -323,8 +334,9 @@ Offset<Vector<Offset<Operator>>> ExportOperators(
outputs.push_back(tensors_map.at(output));
}
- int op_index = operators_map.at(
- details::GetOperatorKey(*op, ops_by_type, params.allow_flex_ops));
+ const auto key =
+ details::GetOperatorKey(*op, ops_by_type, params.allow_flex_ops);
+ int op_index = operators_map.at(key);
auto tflite_op_it = ops_by_type.find(op->type);
BaseOperator* tflite_op = tflite_op_it == ops_by_type.end()
@@ -349,6 +361,11 @@ Offset<Vector<Offset<Operator>>> ExportOperators(
variable_tensor_indices->insert(variable_tensor_index);
}
}
+ } else if (key.is_flex_op && !op->tensorflow_node_def.empty()) {
+ auto fbb = WriteFlexOpOptions(op->tensorflow_node_def);
+ if (fbb) {
+ options = Options::Custom(builder->CreateVector(fbb->GetBuffer()));
+ }
}
// The only supported CustomOptionFormat is FLEXBUFFERS now.
op_vector.push_back(CreateOperator(
diff --git a/tensorflow/contrib/lite/toco/tflite/export_test.cc b/tensorflow/contrib/lite/toco/tflite/export_test.cc
index d48ab78285..eda1aa78a3 100644
--- a/tensorflow/contrib/lite/toco/tflite/export_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/toco/tflite/builtin_operator.h"
#include "tensorflow/contrib/lite/toco/tflite/operator.h"
#include "tensorflow/contrib/lite/toco/tflite/types.h"
+#include "tensorflow/core/framework/node_def.pb.h"
namespace toco {
namespace tflite {
@@ -382,6 +383,39 @@ TEST(OperatorKeyTest, TestFlexWithControlFlowOp) {
EXPECT_TRUE(key.is_unsupported_flex_op);
}
+TEST(OperatorKeyTest, TestFlexWithPartiallySupportedOps) {
+ // Test Toco-supported/TFLite-unsupported operators.
+ // TODO(ycling): The test will be broken if Range is implemented in TFLite.
+ // Find a more robust way to test the fallback logic.
+ auto op = absl::make_unique<RangeOperator>();
+
+ const auto ops_by_type = BuildOperatorByTypeMap();
+
+ {
+ // If NodeDef isn't retained in the Toco op, a regular custom op
+ // will be exported.
+ const auto key = details::GetOperatorKey(*op, ops_by_type, true);
+ EXPECT_EQ(key.type, ::tflite::BuiltinOperator_CUSTOM);
+ EXPECT_EQ(key.custom_code, "Range");
+ EXPECT_EQ(key.version, 1);
+ EXPECT_FALSE(key.is_flex_op);
+ }
+
+ ::tensorflow::NodeDef node_def;
+ node_def.set_name("Range");
+ node_def.set_op("Range");
+ node_def.SerializeToString(&op->tensorflow_node_def);
+
+ {
+ // If NodeDef is retained in the Toco op, a Flex op will be exported.
+ const auto key = details::GetOperatorKey(*op, ops_by_type, true);
+ EXPECT_EQ(key.type, ::tflite::BuiltinOperator_CUSTOM);
+ EXPECT_EQ(key.custom_code, "FlexRange");
+ EXPECT_EQ(key.version, 1);
+ EXPECT_TRUE(key.is_flex_op);
+ }
+}
+
// TODO(ahentz): tests for tensors, inputs, outputs, opcodes and operators.
} // namespace
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 9addbb81e7..ed37535fe0 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -1157,6 +1157,25 @@ class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions,
int GetVersion(const Operator& op) const override { return 1; }
};
+std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions(
+ const string& tensorflow_node_def) {
+ auto fbb = absl::make_unique<flexbuffers::Builder>();
+
+ ::tensorflow::NodeDef node_def;
+ if (!node_def.ParseFromString(tensorflow_node_def)) {
+ LOG(ERROR) << "Failed to parse TensorFlow NodeDef";
+ return {};
+ }
+
+ fbb->Vector([&]() {
+ fbb->String(node_def.op());
+ fbb->String(tensorflow_node_def);
+ });
+ fbb->Finish();
+ LOG(INFO) << "Writing flex op: " << node_def.op();
+ return std::unique_ptr<flexbuffers::Builder>(fbb.release());
+}
+
class TensorFlowUnsupported : public BaseOperator {
public:
TensorFlowUnsupported(const string& name, OperatorType type,
@@ -1192,6 +1211,9 @@ class TensorFlowUnsupported : public BaseOperator {
std::unique_ptr<flexbuffers::Builder> WriteOptions(
const TensorFlowUnsupportedOperator& op) const {
+ if (allow_flex_ops_) {
+ return WriteFlexOpOptions(op.tensorflow_node_def);
+ }
auto fbb = absl::make_unique<flexbuffers::Builder>();
::tensorflow::NodeDef node_def;
@@ -1200,16 +1222,6 @@ class TensorFlowUnsupported : public BaseOperator {
return std::unique_ptr<flexbuffers::Builder>();
}
- if (allow_flex_ops_) {
- fbb->Vector([&]() {
- fbb->String(node_def.op());
- fbb->String(op.tensorflow_node_def);
- });
- fbb->Finish();
- LOG(INFO) << "Writing flex op: " << node_def.op();
- return std::unique_ptr<flexbuffers::Builder>(fbb.release());
- }
-
bool has_valid_attr = false;
size_t map_start = fbb->StartMap();
for (const auto& pair : node_def.attr()) {
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.h b/tensorflow/contrib/lite/toco/tflite/operator.h
index 13d9f6c49a..6e4e0a16d1 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.h
+++ b/tensorflow/contrib/lite/toco/tflite/operator.h
@@ -16,6 +16,7 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_OPERATOR_H_
#include "flatbuffers/flatbuffers.h"
+#include "flatbuffers/flexbuffers.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"
#include "tensorflow/contrib/lite/toco/model.h"
@@ -36,6 +37,11 @@ std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(
std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap(
bool allow_flex_ops = false);
+// Write the custom option FlexBuffer with a serialized TensorFlow NodeDef
+// for a Flex op.
+std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions(
+ const string& tensorflow_node_def);
+
// These are the flatbuffer types for custom and builtin options.
using CustomOptions = flatbuffers::Vector<uint8_t>;
using BuiltinOptions = void;