aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
diff options
context:
space:
mode:
authorGravatar Jared Duke <jdduke@google.com>2018-09-14 11:42:02 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-14 11:45:56 -0700
commit39f50af5634b8a4d2132b57bad2152308a0fd41c (patch)
tree5a5d0b0a9722067b702995dc84a1c4d8156d36a4 /tensorflow/contrib/lite/toco/import_tensorflow_test.cc
parentc20a7b81d79d30db9e990309ddb419bcb48120cc (diff)
Improve output parsing for unsupported ops
PiperOrigin-RevId: 213017532
Diffstat (limited to 'tensorflow/contrib/lite/toco/import_tensorflow_test.cc')
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow_test.cc52
1 files changed, 52 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
index da248826a7..8a236d4444 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
@@ -60,6 +60,28 @@ Status ImportNode(const NodeDef& node) {
return ImportNode(node, &model);
}
+NodeDef BuildNode(
+ const std::string& op,
+ const std::vector<std::initializer_list<int>>& output_shapes) {
+ NodeDef node;
+ node.set_op(op);
+ node.set_name("Node1");
+ node.add_input();
+ node.set_input(0, "Node0");
+
+ AttrValue::ListValue* shapes =
+ (*node.mutable_attr())["_output_shapes"].mutable_list();
+ for (const auto& output_shape : output_shapes) {
+ tensorflow::TensorShapeProto* shape = shapes->add_shape();
+ for (int64_t output_shape_dim : output_shape) {
+ auto shape_dim = shape->add_dim();
+ shape_dim->set_size(output_shape_dim);
+ }
+ }
+
+ return node;
+}
+
class ShapeImportTest : public ::testing::TestWithParam<tensorflow::DataType> {
protected:
ShapeImportTest() {}
@@ -232,5 +254,35 @@ TEST(ImportTest, FailedTypeInference) {
ASSERT_TRUE(op->output_data_types.empty());
}
+TEST(ImportTest, UnsupportedOpWithOutputShapes) {
+ // Create an unsupported op with output shapes.
+ Model model;
+ EXPECT_TRUE(ImportNode(BuildNode("Atan", {{1, 2}, {2, 3}}), &model).ok());
+ ASSERT_THAT(model.operators.size(), ::testing::Ge(1));
+ ASSERT_EQ(model.operators[0]->type, OperatorType::kUnsupported);
+ const TensorFlowUnsupportedOperator* op =
+ static_cast<const TensorFlowUnsupportedOperator*>(
+ model.operators[0].get());
+
+ // The output shapes should be imported.
+ ASSERT_EQ(op->output_shapes.size(), 2);
+ ASSERT_THAT(op->output_shapes[0].dims(), ::testing::ElementsAre(1, 2));
+ ASSERT_THAT(op->output_shapes[1].dims(), ::testing::ElementsAre(2, 3));
+}
+
+TEST(ImportTest, UnsupportedOpWithWildcardOutputShapes) {
+ // Create an unsupported op with wildcard output shapes.
+ Model model;
+ EXPECT_TRUE(ImportNode(BuildNode("Atan", {{-1, 2}}), &model).ok());
+ ASSERT_THAT(model.operators.size(), ::testing::Ge(1));
+ ASSERT_EQ(model.operators[0]->type, OperatorType::kUnsupported);
+ const TensorFlowUnsupportedOperator* op =
+ static_cast<const TensorFlowUnsupportedOperator*>(
+ model.operators[0].get());
+
+ // Wildcard shapes aren't yet supported.
+ ASSERT_TRUE(op->output_shapes.empty());
+}
+
} // namespace
} // namespace toco