diff options
author | 2018-09-12 09:23:42 -0700 | |
---|---|---|
committer | 2018-09-12 09:28:12 -0700 | |
commit | 6b507a6de855a6f988100904229b7f46a5652b88 (patch) | |
tree | cbb0c14a47f2da3dd0add9211f03641965b181f4 /tensorflow/contrib/lite/toco/import_tensorflow_test.cc | |
parent | 9e78991b5c380b7fba0444685e5c6ef40e3c5b26 (diff) |
Add basic type propagation for unsupported ops in TFLite conversion
PiperOrigin-RevId: 212651704
Diffstat (limited to 'tensorflow/contrib/lite/toco/import_tensorflow_test.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/import_tensorflow_test.cc | 75 |
1 files changed, 70 insertions, 5 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc index a00e136dd6..da248826a7 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc @@ -49,6 +49,17 @@ Status ImportTensorFlowNode(const NodeDef&, const TensorFlowImportFlags&, namespace { +Status ImportNode(const NodeDef& node, Model* model) { + const auto converter = internal::GetTensorFlowNodeConverterMap(); + return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), model, + converter); +} + +Status ImportNode(const NodeDef& node) { + Model model; + return ImportNode(node, &model); +} + class ShapeImportTest : public ::testing::TestWithParam<tensorflow::DataType> { protected: ShapeImportTest() {} @@ -109,12 +120,24 @@ class ShapeImportTest : public ::testing::TestWithParam<tensorflow::DataType> { SetAttrValue(t, &value_attr); (*node->mutable_attr())["value"] = value_attr; } +}; + +class TypeImportTest : public ::testing::TestWithParam< + std::pair<tensorflow::DataType, ArrayDataType>> { + protected: + TypeImportTest() {} + + void BuildUnaryNode(const std::string& op_name, tensorflow::DataType dtype, + NodeDef* node) { + node->set_op(op_name); + node->set_name("Node1"); + + node->add_input(); + node->set_input(0, "Node0"); - Status ImportNode(const NodeDef& node) { - Model model; - const auto converter = internal::GetTensorFlowNodeConverterMap(); - return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), &model, - converter); + AttrValue dtype_attr; + SetAttrValue(dtype, &dtype_attr); + (*node->mutable_attr())["T"] = dtype_attr; } }; @@ -167,5 +190,47 @@ TEST_P(ShapeImportTest, ValidShapeButZeroElements) { INSTANTIATE_TEST_CASE_P(ValidShapeButZeroElements, ShapeImportTest, ::testing::ValuesIn(TestTypes())); +std::vector<std::pair<tensorflow::DataType, ArrayDataType>> UnaryTestTypes() { + return {{DT_FLOAT, ArrayDataType::kFloat}, + {DT_INT32, ArrayDataType::kInt32}, + {DT_INT64, ArrayDataType::kInt64}}; +} + +TEST_P(TypeImportTest, BasicTypeInference) { + NodeDef node; + BuildUnaryNode("Atan", GetParam().first, &node); + + Model model; + EXPECT_TRUE(ImportNode(node, &model).ok()); + + ASSERT_THAT(model.operators.size(), ::testing::Ge(1)); + ASSERT_EQ(model.operators[0]->type, OperatorType::kUnsupported); + const TensorFlowUnsupportedOperator* op = + static_cast<const TensorFlowUnsupportedOperator*>( + model.operators[0].get()); + ASSERT_THAT(op->output_data_types, ::testing::ElementsAre(GetParam().second)); +} +INSTANTIATE_TEST_CASE_P(BasicTypeInference, TypeImportTest, + ::testing::ValuesIn(UnaryTestTypes())); + +TEST(ImportTest, FailedTypeInference) { + // Create a unary op with no Type ("T") annotation. + NodeDef node; + node.set_op("Atan"); + node.set_name("Node1"); + node.add_input(); + node.set_input(0, "Node0"); + + Model model; + EXPECT_TRUE(ImportNode(node, &model).ok()); + + ASSERT_THAT(model.operators.size(), ::testing::Ge(1)); + ASSERT_EQ(model.operators[0]->type, OperatorType::kUnsupported); + const TensorFlowUnsupportedOperator* op = + static_cast<const TensorFlowUnsupportedOperator*>( + model.operators[0].get()); + ASSERT_TRUE(op->output_data_types.empty()); +} + } // namespace } // namespace toco |