From 39f50af5634b8a4d2132b57bad2152308a0fd41c Mon Sep 17 00:00:00 2001 From: Jared Duke Date: Fri, 14 Sep 2018 11:42:02 -0700 Subject: Improve output parsing for unsupported ops PiperOrigin-RevId: 213017532 --- .../contrib/lite/toco/import_tensorflow_test.cc | 52 ++++++++++++++++++++++ 1 file changed, 52 insertions(+) (limited to 'tensorflow/contrib/lite/toco/import_tensorflow_test.cc') diff --git a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc index da248826a7..8a236d4444 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc @@ -60,6 +60,28 @@ Status ImportNode(const NodeDef& node) { return ImportNode(node, &model); } +NodeDef BuildNode( + const std::string& op, + const std::vector>& output_shapes) { + NodeDef node; + node.set_op(op); + node.set_name("Node1"); + node.add_input(); + node.set_input(0, "Node0"); + + AttrValue::ListValue* shapes = + (*node.mutable_attr())["_output_shapes"].mutable_list(); + for (const auto& output_shape : output_shapes) { + tensorflow::TensorShapeProto* shape = shapes->add_shape(); + for (int64_t output_shape_dim : output_shape) { + auto shape_dim = shape->add_dim(); + shape_dim->set_size(output_shape_dim); + } + } + + return node; +} + class ShapeImportTest : public ::testing::TestWithParam { protected: ShapeImportTest() {} @@ -232,5 +254,35 @@ TEST(ImportTest, FailedTypeInference) { ASSERT_TRUE(op->output_data_types.empty()); } +TEST(ImportTest, UnsupportedOpWithOutputShapes) { + // Create an unsupported op with output shapes. + Model model; + EXPECT_TRUE(ImportNode(BuildNode("Atan", {{1, 2}, {2, 3}}), &model).ok()); + ASSERT_THAT(model.operators.size(), ::testing::Ge(1)); + ASSERT_EQ(model.operators[0]->type, OperatorType::kUnsupported); + const TensorFlowUnsupportedOperator* op = + static_cast( + model.operators[0].get()); + + // The output shapes should be imported. + ASSERT_EQ(op->output_shapes.size(), 2); + ASSERT_THAT(op->output_shapes[0].dims(), ::testing::ElementsAre(1, 2)); + ASSERT_THAT(op->output_shapes[1].dims(), ::testing::ElementsAre(2, 3)); +} + +TEST(ImportTest, UnsupportedOpWithWildcardOutputShapes) { + // Create an unsupported op with wildcard output shapes. + Model model; + EXPECT_TRUE(ImportNode(BuildNode("Atan", {{-1, 2}}), &model).ok()); + ASSERT_THAT(model.operators.size(), ::testing::Ge(1)); + ASSERT_EQ(model.operators[0]->type, OperatorType::kUnsupported); + const TensorFlowUnsupportedOperator* op = + static_cast( + model.operators[0].get()); + + // Wildcard shapes aren't yet supported. + ASSERT_TRUE(op->output_shapes.empty()); +} + } // namespace } // namespace toco -- cgit v1.2.3