aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/shape_inference_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-07-20 20:13:38 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-07-20 21:18:14 -0700
commitd7e7b7c1dd9af1566945cf5ca78e80b40913a27b (patch)
treeffe4e8e2fb3e76f582714bc2874ee8af9b51d50e /tensorflow/core/framework/shape_inference_test.cc
parente40a34e3749aeccb2c6ce45dfdd3596a5233bb97 (diff)
Add C++ shape inference functions for several image ops.
Change InferenceContext::ReplaceDim to support negative indexing. Change: 128022331
Diffstat (limited to 'tensorflow/core/framework/shape_inference_test.cc')
-rw-r--r--tensorflow/core/framework/shape_inference_test.cc17
1 files changed, 17 insertions, 0 deletions
diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc
index 323992ac5d..ce175de561 100644
--- a/tensorflow/core/framework/shape_inference_test.cc
+++ b/tensorflow/core/framework/shape_inference_test.cc
@@ -481,6 +481,19 @@ TEST(ShapeInferenceTest, ReplaceDim) {
EXPECT_EQ("[1,3,3]", c.DebugString(replaced));
EXPECT_TRUE(c.ReplaceDim(unknown, 0, c.Dim(in, 1), &replaced).ok());
EXPECT_EQ("?", c.DebugString(replaced));
+
+ // Negative indexing.
+ EXPECT_TRUE(c.ReplaceDim(in, -1, c.Dim(in, 1), &replaced).ok());
+ EXPECT_EQ("[1,2,2]", c.DebugString(replaced));
+ EXPECT_TRUE(c.ReplaceDim(unknown, -1, c.Dim(in, 1), &replaced).ok());
+ EXPECT_EQ("?", c.DebugString(replaced));
+
+ // out of range indexing.
+ EXPECT_FALSE(c.ReplaceDim(in, 3, c.Dim(in, 1), &replaced).ok());
+ EXPECT_TRUE(replaced == nullptr);
+ replaced = in;
+ EXPECT_FALSE(c.ReplaceDim(in, -4, c.Dim(in, 1), &replaced).ok());
+ EXPECT_TRUE(replaced == nullptr);
}
TEST(ShapeInferenceTest, MakeShape) {
@@ -501,6 +514,10 @@ TEST(ShapeInferenceTest, MakeShape) {
auto s2 = c.MakeShape(dims);
EXPECT_TRUE(s != s2); // different pointers
EXPECT_TRUE(c.Dim(s2, 0) == c.Dim(in0, rank - 1));
+
+ auto s3 = c.MakeShape({1, 2, dims[2]});
+ EXPECT_TRUE(s != s3); // different pointers
+ EXPECT_EQ("[1,2,3]", c.DebugString(s3));
}
TEST(ShapeInferenceTest, UnknownShape) {