aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Jared Duke <jdduke@google.com>2018-09-13 16:58:34 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-13 17:02:49 -0700
commit4b42a284683416ab6159f32c903321af9dc9a591 (patch)
tree9ed2dab1ec07a6713538bda1a6a47759d3055521
parent4137d84a3b41638d4048e45ab579662c18a06df5 (diff)
Reland "Add basic type propagation for unsupported ops in TFLite conversion"
The original CL was rolled back due to op registration conflicts in the pip. Resolve the issue by only including core:ops in the toco binary itself, not in intermediate libraries. PiperOrigin-RevId: 212902838
-rw-r--r--tensorflow/contrib/lite/toco/BUILD6
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc20
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow_test.cc75
3 files changed, 95 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index bea90f1ce8..96b88b60fc 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -347,6 +347,7 @@ tf_cc_test(
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
+ "//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc",
"@com_google_googletest//:gtest_main",
],
@@ -407,8 +408,11 @@ tf_cc_binary(
":toco_port",
":toco_tooling",
":types_proto_cc",
- "//tensorflow/core:lib",
"@com_google_absl//absl/strings",
+ "//tensorflow/core:lib",
+ # We cannot embed the core:ops dependency directly into :toco_tooling as
+ # it can conflict with downstream deps when toco is used as a library.
+ "//tensorflow/core:ops",
],
)
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 9bc23c4b3c..efc1007925 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -58,6 +58,7 @@ using tensorflow::DT_STRING;
using tensorflow::DT_UINT8;
using tensorflow::GraphDef;
using tensorflow::NodeDef;
+using tensorflow::OpRegistry;
using tensorflow::TensorProto;
using tensorflow::TensorShapeProto;
@@ -1079,6 +1080,25 @@ tensorflow::Status ConvertUnsupportedOperator(
} else if (HasAttr(node, "Tout")) {
const auto& output_type = GetDataTypeAttr(node, "Tout");
op->output_data_types.push_back(ConvertDataType(output_type));
+ } else {
+ const tensorflow::OpDef* op_def = nullptr;
+ if (OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
+ for (const auto& output_arg : op_def->output_arg()) {
+ if (HasAttr(node, output_arg.type_attr())) {
+ op->output_data_types.push_back(
+ ConvertDataType(GetDataTypeAttr(node, output_arg.type_attr())));
+ } else {
+ LOG(INFO) << "Op node missing output type attribute: " << node.name();
+ op->output_data_types.clear();
+ break;
+ }
+ }
+ }
+ if (op->output_data_types.empty()) {
+ // TODO(b/113613439): Figure out how to propagate types for custom ops
+ // that have no OpDef.
+ LOG(INFO) << "Unable to determine output type for op: " << node.op();
+ }
}
if (HasAttr(node, kAttrOutputShapes)) {
const auto& output_shapes = GetListAttr(node, kAttrOutputShapes);
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
index a00e136dd6..da248826a7 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
@@ -49,6 +49,17 @@ Status ImportTensorFlowNode(const NodeDef&, const TensorFlowImportFlags&,
namespace {
+Status ImportNode(const NodeDef& node, Model* model) {
+ const auto converter = internal::GetTensorFlowNodeConverterMap();
+ return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), model,
+ converter);
+}
+
+Status ImportNode(const NodeDef& node) {
+ Model model;
+ return ImportNode(node, &model);
+}
+
class ShapeImportTest : public ::testing::TestWithParam<tensorflow::DataType> {
protected:
ShapeImportTest() {}
@@ -109,12 +120,24 @@ class ShapeImportTest : public ::testing::TestWithParam<tensorflow::DataType> {
SetAttrValue(t, &value_attr);
(*node->mutable_attr())["value"] = value_attr;
}
+};
+
+class TypeImportTest : public ::testing::TestWithParam<
+ std::pair<tensorflow::DataType, ArrayDataType>> {
+ protected:
+ TypeImportTest() {}
+
+ void BuildUnaryNode(const std::string& op_name, tensorflow::DataType dtype,
+ NodeDef* node) {
+ node->set_op(op_name);
+ node->set_name("Node1");
+
+ node->add_input();
+ node->set_input(0, "Node0");
- Status ImportNode(const NodeDef& node) {
- Model model;
- const auto converter = internal::GetTensorFlowNodeConverterMap();
- return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), &model,
- converter);
+ AttrValue dtype_attr;
+ SetAttrValue(dtype, &dtype_attr);
+ (*node->mutable_attr())["T"] = dtype_attr;
}
};
@@ -167,5 +190,47 @@ TEST_P(ShapeImportTest, ValidShapeButZeroElements) {
INSTANTIATE_TEST_CASE_P(ValidShapeButZeroElements, ShapeImportTest,
::testing::ValuesIn(TestTypes()));
+std::vector<std::pair<tensorflow::DataType, ArrayDataType>> UnaryTestTypes() {
+ return {{DT_FLOAT, ArrayDataType::kFloat},
+ {DT_INT32, ArrayDataType::kInt32},
+ {DT_INT64, ArrayDataType::kInt64}};
+}
+
+TEST_P(TypeImportTest, BasicTypeInference) {
+ NodeDef node;
+ BuildUnaryNode("Atan", GetParam().first, &node);
+
+ Model model;
+ EXPECT_TRUE(ImportNode(node, &model).ok());
+
+ ASSERT_THAT(model.operators.size(), ::testing::Ge(1));
+ ASSERT_EQ(model.operators[0]->type, OperatorType::kUnsupported);
+ const TensorFlowUnsupportedOperator* op =
+ static_cast<const TensorFlowUnsupportedOperator*>(
+ model.operators[0].get());
+ ASSERT_THAT(op->output_data_types, ::testing::ElementsAre(GetParam().second));
+}
+INSTANTIATE_TEST_CASE_P(BasicTypeInference, TypeImportTest,
+ ::testing::ValuesIn(UnaryTestTypes()));
+
+TEST(ImportTest, FailedTypeInference) {
+ // Create a unary op with no Type ("T") annotation.
+ NodeDef node;
+ node.set_op("Atan");
+ node.set_name("Node1");
+ node.add_input();
+ node.set_input(0, "Node0");
+
+ Model model;
+ EXPECT_TRUE(ImportNode(node, &model).ok());
+
+ ASSERT_THAT(model.operators.size(), ::testing::Ge(1));
+ ASSERT_EQ(model.operators[0]->type, OperatorType::kUnsupported);
+ const TensorFlowUnsupportedOperator* op =
+ static_cast<const TensorFlowUnsupportedOperator*>(
+ model.operators[0].get());
+ ASSERT_TRUE(op->output_data_types.empty());
+}
+
} // namespace
} // namespace toco