aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
diff options
context:
space:
mode:
authorGravatar Jared Duke <jdduke@google.com>2018-09-13 16:58:34 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-13 17:02:49 -0700
commit4b42a284683416ab6159f32c903321af9dc9a591 (patch)
tree9ed2dab1ec07a6713538bda1a6a47759d3055521 /tensorflow/contrib/lite/toco/import_tensorflow_test.cc
parent4137d84a3b41638d4048e45ab579662c18a06df5 (diff)
Reland "Add basic type propagation for unsupported ops in TFLite conversion"
The original CL was rolled back due to op registration conflicts in the pip. Resolve the issue by only including core:ops in the toco binary itself, not in intermediate libraries. PiperOrigin-RevId: 212902838
Diffstat (limited to 'tensorflow/contrib/lite/toco/import_tensorflow_test.cc')
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow_test.cc75
1 files changed, 70 insertions, 5 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
index a00e136dd6..da248826a7 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
@@ -49,6 +49,17 @@ Status ImportTensorFlowNode(const NodeDef&, const TensorFlowImportFlags&,
namespace {
+Status ImportNode(const NodeDef& node, Model* model) {
+ const auto converter = internal::GetTensorFlowNodeConverterMap();
+ return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), model,
+ converter);
+}
+
+Status ImportNode(const NodeDef& node) {
+ Model model;
+ return ImportNode(node, &model);
+}
+
class ShapeImportTest : public ::testing::TestWithParam<tensorflow::DataType> {
protected:
ShapeImportTest() {}
@@ -109,12 +120,24 @@ class ShapeImportTest : public ::testing::TestWithParam<tensorflow::DataType> {
SetAttrValue(t, &value_attr);
(*node->mutable_attr())["value"] = value_attr;
}
+};
+
+class TypeImportTest : public ::testing::TestWithParam<
+ std::pair<tensorflow::DataType, ArrayDataType>> {
+ protected:
+ TypeImportTest() {}
+
+ void BuildUnaryNode(const std::string& op_name, tensorflow::DataType dtype,
+ NodeDef* node) {
+ node->set_op(op_name);
+ node->set_name("Node1");
+
+ node->add_input();
+ node->set_input(0, "Node0");
- Status ImportNode(const NodeDef& node) {
- Model model;
- const auto converter = internal::GetTensorFlowNodeConverterMap();
- return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), &model,
- converter);
+ AttrValue dtype_attr;
+ SetAttrValue(dtype, &dtype_attr);
+ (*node->mutable_attr())["T"] = dtype_attr;
}
};
@@ -167,5 +190,47 @@ TEST_P(ShapeImportTest, ValidShapeButZeroElements) {
INSTANTIATE_TEST_CASE_P(ValidShapeButZeroElements, ShapeImportTest,
::testing::ValuesIn(TestTypes()));
+std::vector<std::pair<tensorflow::DataType, ArrayDataType>> UnaryTestTypes() {
+ return {{DT_FLOAT, ArrayDataType::kFloat},
+ {DT_INT32, ArrayDataType::kInt32},
+ {DT_INT64, ArrayDataType::kInt64}};
+}
+
+TEST_P(TypeImportTest, BasicTypeInference) {
+ NodeDef node;
+ BuildUnaryNode("Atan", GetParam().first, &node);
+
+ Model model;
+ EXPECT_TRUE(ImportNode(node, &model).ok());
+
+ ASSERT_THAT(model.operators.size(), ::testing::Ge(1));
+ ASSERT_EQ(model.operators[0]->type, OperatorType::kUnsupported);
+ const TensorFlowUnsupportedOperator* op =
+ static_cast<const TensorFlowUnsupportedOperator*>(
+ model.operators[0].get());
+ ASSERT_THAT(op->output_data_types, ::testing::ElementsAre(GetParam().second));
+}
+INSTANTIATE_TEST_CASE_P(BasicTypeInference, TypeImportTest,
+ ::testing::ValuesIn(UnaryTestTypes()));
+
+TEST(ImportTest, FailedTypeInference) {
+ // Create a unary op with no Type ("T") annotation.
+ NodeDef node;
+ node.set_op("Atan");
+ node.set_name("Node1");
+ node.add_input();
+ node.set_input(0, "Node0");
+
+ Model model;
+ EXPECT_TRUE(ImportNode(node, &model).ok());
+
+ ASSERT_THAT(model.operators.size(), ::testing::Ge(1));
+ ASSERT_EQ(model.operators[0]->type, OperatorType::kUnsupported);
+ const TensorFlowUnsupportedOperator* op =
+ static_cast<const TensorFlowUnsupportedOperator*>(
+ model.operators[0].get());
+ ASSERT_TRUE(op->output_data_types.empty());
+}
+
} // namespace
} // namespace toco