aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Jared Duke <jdduke@google.com>2018-09-12 09:23:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-12 09:28:12 -0700
commit6b507a6de855a6f988100904229b7f46a5652b88 (patch)
treecbb0c14a47f2da3dd0add9211f03641965b181f4
parent9e78991b5c380b7fba0444685e5c6ef40e3c5b26 (diff)
Add basic type propagation for unsupported ops in TFLite conversion
PiperOrigin-RevId: 212651704
-rw-r--r--tensorflow/contrib/lite/toco/BUILD1
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc18
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow_test.cc75
3 files changed, 89 insertions, 5 deletions
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index bea90f1ce8..72c71b2841 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -331,6 +331,7 @@ cc_library(
"//tensorflow/core:core_cpu_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
+ "//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc",
] + select({
# Placeholder for internal darwin rule.
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 9bc23c4b3c..eb36b3411d 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -58,6 +58,7 @@ using tensorflow::DT_STRING;
using tensorflow::DT_UINT8;
using tensorflow::GraphDef;
using tensorflow::NodeDef;
+using tensorflow::OpRegistry;
using tensorflow::TensorProto;
using tensorflow::TensorShapeProto;
@@ -1079,6 +1080,23 @@ tensorflow::Status ConvertUnsupportedOperator(
} else if (HasAttr(node, "Tout")) {
const auto& output_type = GetDataTypeAttr(node, "Tout");
op->output_data_types.push_back(ConvertDataType(output_type));
+ } else {
+ const tensorflow::OpDef* op_def = nullptr;
+ if (OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
+ for (const auto& output_arg : op_def->output_arg()) {
+ if (HasAttr(node, output_arg.type_attr())) {
+ op->output_data_types.push_back(
+ ConvertDataType(GetDataTypeAttr(node, output_arg.type_attr())));
+ } else {
+ LOG(INFO) << "Op node missing output type attribute: " << node.name();
+ }
+ }
+ }
+ if (op->output_data_types.empty()) {
+ // TODO(b/113613439): Figure out how to propagate types for custom ops
+ // that have no OpDef.
+ LOG(INFO) << "Unable to determine output type for op: " << node.op();
+ }
}
if (HasAttr(node, kAttrOutputShapes)) {
const auto& output_shapes = GetListAttr(node, kAttrOutputShapes);
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
index a00e136dd6..da248826a7 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
@@ -49,6 +49,17 @@ Status ImportTensorFlowNode(const NodeDef&, const TensorFlowImportFlags&,
namespace {
+Status ImportNode(const NodeDef& node, Model* model) {
+ const auto converter = internal::GetTensorFlowNodeConverterMap();
+ return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), model,
+ converter);
+}
+
+Status ImportNode(const NodeDef& node) {
+ Model model;
+ return ImportNode(node, &model);
+}
+
class ShapeImportTest : public ::testing::TestWithParam<tensorflow::DataType> {
protected:
ShapeImportTest() {}
@@ -109,12 +120,24 @@ class ShapeImportTest : public ::testing::TestWithParam<tensorflow::DataType> {
SetAttrValue(t, &value_attr);
(*node->mutable_attr())["value"] = value_attr;
}
+};
+
+class TypeImportTest : public ::testing::TestWithParam<
+ std::pair<tensorflow::DataType, ArrayDataType>> {
+ protected:
+ TypeImportTest() {}
+
+ void BuildUnaryNode(const std::string& op_name, tensorflow::DataType dtype,
+ NodeDef* node) {
+ node->set_op(op_name);
+ node->set_name("Node1");
+
+ node->add_input();
+ node->set_input(0, "Node0");
- Status ImportNode(const NodeDef& node) {
- Model model;
- const auto converter = internal::GetTensorFlowNodeConverterMap();
- return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), &model,
- converter);
+ AttrValue dtype_attr;
+ SetAttrValue(dtype, &dtype_attr);
+ (*node->mutable_attr())["T"] = dtype_attr;
}
};
@@ -167,5 +190,47 @@ TEST_P(ShapeImportTest, ValidShapeButZeroElements) {
INSTANTIATE_TEST_CASE_P(ValidShapeButZeroElements, ShapeImportTest,
::testing::ValuesIn(TestTypes()));
+std::vector<std::pair<tensorflow::DataType, ArrayDataType>> UnaryTestTypes() {
+ return {{DT_FLOAT, ArrayDataType::kFloat},
+ {DT_INT32, ArrayDataType::kInt32},
+ {DT_INT64, ArrayDataType::kInt64}};
+}
+
+TEST_P(TypeImportTest, BasicTypeInference) {
+ NodeDef node;
+ BuildUnaryNode("Atan", GetParam().first, &node);
+
+ Model model;
+ EXPECT_TRUE(ImportNode(node, &model).ok());
+
+ ASSERT_THAT(model.operators.size(), ::testing::Ge(1));
+ ASSERT_EQ(model.operators[0]->type, OperatorType::kUnsupported);
+ const TensorFlowUnsupportedOperator* op =
+ static_cast<const TensorFlowUnsupportedOperator*>(
+ model.operators[0].get());
+ ASSERT_THAT(op->output_data_types, ::testing::ElementsAre(GetParam().second));
+}
+INSTANTIATE_TEST_CASE_P(BasicTypeInference, TypeImportTest,
+ ::testing::ValuesIn(UnaryTestTypes()));
+
+TEST(ImportTest, FailedTypeInference) {
+ // Create a unary op with no Type ("T") annotation.
+ NodeDef node;
+ node.set_op("Atan");
+ node.set_name("Node1");
+ node.add_input();
+ node.set_input(0, "Node0");
+
+ Model model;
+ EXPECT_TRUE(ImportNode(node, &model).ok());
+
+ ASSERT_THAT(model.operators.size(), ::testing::Ge(1));
+ ASSERT_EQ(model.operators[0]->type, OperatorType::kUnsupported);
+ const TensorFlowUnsupportedOperator* op =
+ static_cast<const TensorFlowUnsupportedOperator*>(
+ model.operators[0].get());
+ ASSERT_TRUE(op->output_data_types.empty());
+}
+
} // namespace
} // namespace toco