diff options
author | 2018-09-13 15:34:43 -0700 | |
---|---|---|
committer | 2018-09-13 15:39:27 -0700 | |
commit | e59ddcca727340a8b45694a28cd9f52352607e63 (patch) | |
tree | 333e4da24dae2071c8946afd2c8a9def74f68889 /tensorflow/contrib/lite/toco/import_tensorflow_test.cc | |
parent | ea52ecd836098e0b1d37325cf1b91133f908547e (diff) |
Automated rollback of commit 6b507a6de855a6f988100904229b7f46a5652b88
PiperOrigin-RevId: 212890622
Diffstat (limited to 'tensorflow/contrib/lite/toco/import_tensorflow_test.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/import_tensorflow_test.cc | 75 |
1 files changed, 5 insertions, 70 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc index da248826a7..a00e136dd6 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc @@ -49,17 +49,6 @@ Status ImportTensorFlowNode(const NodeDef&, const TensorFlowImportFlags&, namespace { -Status ImportNode(const NodeDef& node, Model* model) { - const auto converter = internal::GetTensorFlowNodeConverterMap(); - return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), model, - converter); -} - -Status ImportNode(const NodeDef& node) { - Model model; - return ImportNode(node, &model); -} - class ShapeImportTest : public ::testing::TestWithParam<tensorflow::DataType> { protected: ShapeImportTest() {} @@ -120,24 +109,12 @@ class ShapeImportTest : public ::testing::TestWithParam<tensorflow::DataType> { SetAttrValue(t, &value_attr); (*node->mutable_attr())["value"] = value_attr; } -}; - -class TypeImportTest : public ::testing::TestWithParam< - std::pair<tensorflow::DataType, ArrayDataType>> { - protected: - TypeImportTest() {} - - void BuildUnaryNode(const std::string& op_name, tensorflow::DataType dtype, - NodeDef* node) { - node->set_op(op_name); - node->set_name("Node1"); - - node->add_input(); - node->set_input(0, "Node0"); - AttrValue dtype_attr; - SetAttrValue(dtype, &dtype_attr); - (*node->mutable_attr())["T"] = dtype_attr; + Status ImportNode(const NodeDef& node) { + Model model; + const auto converter = internal::GetTensorFlowNodeConverterMap(); + return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), &model, + converter); } }; @@ -190,47 +167,5 @@ TEST_P(ShapeImportTest, ValidShapeButZeroElements) { INSTANTIATE_TEST_CASE_P(ValidShapeButZeroElements, ShapeImportTest, ::testing::ValuesIn(TestTypes())); -std::vector<std::pair<tensorflow::DataType, ArrayDataType>> UnaryTestTypes() { - return {{DT_FLOAT, ArrayDataType::kFloat}, - {DT_INT32, ArrayDataType::kInt32}, - {DT_INT64, ArrayDataType::kInt64}}; -} - -TEST_P(TypeImportTest, BasicTypeInference) { - NodeDef node; - BuildUnaryNode("Atan", GetParam().first, &node); - - Model model; - EXPECT_TRUE(ImportNode(node, &model).ok()); - - ASSERT_THAT(model.operators.size(), ::testing::Ge(1)); - ASSERT_EQ(model.operators[0]->type, OperatorType::kUnsupported); - const TensorFlowUnsupportedOperator* op = - static_cast<const TensorFlowUnsupportedOperator*>( - model.operators[0].get()); - ASSERT_THAT(op->output_data_types, ::testing::ElementsAre(GetParam().second)); -} -INSTANTIATE_TEST_CASE_P(BasicTypeInference, TypeImportTest, - ::testing::ValuesIn(UnaryTestTypes())); - -TEST(ImportTest, FailedTypeInference) { - // Create a unary op with no Type ("T") annotation. - NodeDef node; - node.set_op("Atan"); - node.set_name("Node1"); - node.add_input(); - node.set_input(0, "Node0"); - - Model model; - EXPECT_TRUE(ImportNode(node, &model).ok()); - - ASSERT_THAT(model.operators.size(), ::testing::Ge(1)); - ASSERT_EQ(model.operators[0]->type, OperatorType::kUnsupported); - const TensorFlowUnsupportedOperator* op = - static_cast<const TensorFlowUnsupportedOperator*>( - model.operators[0].get()); - ASSERT_TRUE(op->output_data_types.empty()); -} - } // namespace } // namespace toco |