aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/common_shape_fns_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-10-25 18:25:38 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-25 19:35:26 -0700
commit8cb8645322f5a738c8ce7bbfd6ebcbaac3e3ba02 (patch)
tree26a7119636e38baf6d6321a784d0ad84a5cf5581 /tensorflow/core/framework/common_shape_fns_test.cc
parentc1208d19b78ff47e32ec31d24ce32d1d4054f264 (diff)
In ShapeRefiner, add support for the C++ equivalent of the Python
constant_value_as_tensor functions. This follows the same lazy-evaluation as getting constant tensors. Add validation in InferenceContext::MakeShapeFromShapeTensor for invalid values in the input tensor. Change: 137231472
Diffstat (limited to 'tensorflow/core/framework/common_shape_fns_test.cc')
-rw-r--r--tensorflow/core/framework/common_shape_fns_test.cc53
1 files changed, 27 insertions, 26 deletions
diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc
index a4efc04467..7196bc8304 100644
--- a/tensorflow/core/framework/common_shape_fns_test.cc
+++ b/tensorflow/core/framework/common_shape_fns_test.cc
@@ -56,7 +56,7 @@ TEST(CommonShapeFnsTest, NoOutputShapeTest) {
.Input({{"data", 0, DT_FLOAT}})
.Finalize(&def));
- InferenceContext c(&def, op_def, {S({}), S({10})}, {});
+ InferenceContext c(&def, op_def, {S({}), S({10})}, {}, {});
TF_EXPECT_OK(NoOutputs(&c));
EXPECT_EQ(0, c.num_outputs());
}
@@ -74,14 +74,14 @@ TEST(CommonShapeFnsTest, ScalarShapeTest) {
NodeDefBuilder("test", "L2Loss").Input("t", 0, DT_FLOAT).Finalize(&def));
{
- InferenceContext c(&def, op_def, {S({})}, {});
+ InferenceContext c(&def, op_def, {S({})}, {}, {});
TF_EXPECT_OK(ScalarShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(0, c.Rank(output));
}
{
- InferenceContext c(&def, op_def, {S({1, 23, 4, 4, 2})}, {});
+ InferenceContext c(&def, op_def, {S({1, 23, 4, 4, 2})}, {}, {});
TF_EXPECT_OK(ScalarShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(0, c.Rank(output));
@@ -108,7 +108,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
.Finalize(&def));
{
- InferenceContext c(&def, op_def, {S({2, 3}), S({3, 4})}, {});
+ InferenceContext c(&def, op_def, {S({2, 3}), S({3, 4})}, {}, {});
TF_EXPECT_OK(MatMulShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@@ -117,7 +117,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Unknown inner dimension for one
- InferenceContext c(&def, op_def, {S({2, -1}), S({3, 4})}, {});
+ InferenceContext c(&def, op_def, {S({2, -1}), S({3, 4})}, {}, {});
TF_EXPECT_OK(MatMulShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@@ -126,7 +126,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Invalid rank.
- InferenceContext c(&def, op_def, {S({2}), S({3, 4})}, {});
+ InferenceContext c(&def, op_def, {S({2}), S({3, 4})}, {}, {});
auto s = MatMulShape(&c);
EXPECT_FALSE(s.ok());
EXPECT_TRUE(
@@ -136,7 +136,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Unknown outer dimension
- InferenceContext c(&def, op_def, {S({2, 3}), S({3, -1})}, {});
+ InferenceContext c(&def, op_def, {S({2, 3}), S({3, -1})}, {}, {});
TF_EXPECT_OK(MatMulShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@@ -145,7 +145,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Inner shapes not compatible
- InferenceContext c(&def, op_def, {S({2, 5}), S({3, 4})}, {});
+ InferenceContext c(&def, op_def, {S({2, 5}), S({3, 4})}, {}, {});
auto s = MatMulShape(&c);
EXPECT_FALSE(s.ok());
EXPECT_TRUE(
@@ -156,7 +156,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Inner shapes not compatible
- InferenceContext c(&def, op_def, {S({2, 5, 3}), S({3, 5, 4})}, {});
+ InferenceContext c(&def, op_def, {S({2, 5, 3}), S({3, 5, 4})}, {}, {});
auto s = MatMulShape(&c);
EXPECT_FALSE(s.ok());
EXPECT_TRUE(
@@ -174,7 +174,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
.Attr("type", DT_FLOAT)
.Finalize(&def));
- InferenceContext c(&def, op_def, {S({3, 2}), S({3, 4})}, {});
+ InferenceContext c(&def, op_def, {S({3, 2}), S({3, 4})}, {}, {});
auto s = MatMulShape(&c);
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@@ -191,7 +191,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
.Attr("type", DT_FLOAT)
.Finalize(&def));
- InferenceContext c(&def, op_def, {S({2, 3}), S({4, 3})}, {});
+ InferenceContext c(&def, op_def, {S({2, 3}), S({4, 3})}, {}, {});
auto s = MatMulShape(&c);
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@@ -215,7 +215,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Finalize(&def));
{
- InferenceContext c(&def, op_def, {S({2, 10}), S({10})}, {});
+ InferenceContext c(&def, op_def, {S({2, 10}), S({10})}, {}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@@ -224,7 +224,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
{
// Unknown ranks.
- InferenceContext c(&def, op_def, {Unknown(), Unknown()}, {});
+ InferenceContext c(&def, op_def, {Unknown(), Unknown()}, {}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_FALSE(c.RankKnown(output));
@@ -232,7 +232,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
{
// Rank > 2
- InferenceContext c(&def, op_def, {S({4, 3, 4, 2, 15}), S({15})}, {});
+ InferenceContext c(&def, op_def, {S({4, 3, 4, 2, 15}), S({15})}, {}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ("[4,3,4,2,15]", c.DebugString(output));
@@ -245,7 +245,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Input("b", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
- InferenceContext c(&def, op_def, {S({2, 3, 4, 5}), S({3})}, {});
+ InferenceContext c(&def, op_def, {S({2, 3, 4, 5}), S({3})}, {}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ("[2,3,4,5]", c.DebugString(output));
@@ -258,7 +258,8 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Input("b", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
- InferenceContext c(&def, op_def, {S({8, 6, 4, 2, 3, 4, 5}), S({3})}, {});
+ InferenceContext c(&def, op_def, {S({8, 6, 4, 2, 3, 4, 5}), S({3})}, {},
+ {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ("[8,6,4,2,3,4,5]", c.DebugString(output));
@@ -271,7 +272,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Input("b", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
- InferenceContext c(&def, op_def, {S({10, 11, 12}), S({10})}, {});
+ InferenceContext c(&def, op_def, {S({10, 11, 12}), S({10})}, {}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ("[10,11,12]", c.DebugString(output));
@@ -279,7 +280,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
{
// Input rank not high enough
- InferenceContext c(&def, op_def, {S({3}), S({3})}, {});
+ InferenceContext c(&def, op_def, {S({3}), S({3})}, {}, {});
EXPECT_FALSE(BiasAddShape(&c).ok());
}
@@ -291,7 +292,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Attr("data_format", "NCHW")
.Finalize(&def));
// NCHW format
- InferenceContext c(&def, op_def, {S({2, 3}), S({3})}, {});
+ InferenceContext c(&def, op_def, {S({2, 3}), S({3})}, {}, {});
EXPECT_FALSE(BiasAddShape(&c).ok());
}
}
@@ -310,7 +311,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Finalize(&def));
{
- InferenceContext c(&def, op_def, {S({2, 10})}, {});
+ InferenceContext c(&def, op_def, {S({2, 10})}, {}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
@@ -318,7 +319,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
{
// Rank > 2
- InferenceContext c(&def, op_def, {S({5, 7, 2, 10})}, {});
+ InferenceContext c(&def, op_def, {S({5, 7, 2, 10})}, {}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
@@ -330,7 +331,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Input("a", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
- InferenceContext c(&def, op_def, {S({2, 3, 4, 5})}, {});
+ InferenceContext c(&def, op_def, {S({2, 3, 4, 5})}, {}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(3, c.Value(c.Dim(output, 0)));
@@ -342,7 +343,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Input("a", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
- InferenceContext c(&def, op_def, {S({8, 6, 4, 2, 3, 4, 5})}, {});
+ InferenceContext c(&def, op_def, {S({8, 6, 4, 2, 3, 4, 5})}, {}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(3, c.Value(c.Dim(output, 0)));
@@ -354,7 +355,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Input("a", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
- InferenceContext c(&def, op_def, {S({10, 11, 12})}, {});
+ InferenceContext c(&def, op_def, {S({10, 11, 12})}, {}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
@@ -362,7 +363,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
{
// Input rank not high enough
- InferenceContext c(&def, op_def, {S({3})}, {});
+ InferenceContext c(&def, op_def, {S({3})}, {}, {});
EXPECT_FALSE(BiasAddGradShape(&c).ok());
}
@@ -373,7 +374,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Attr("data_format", "NCHW")
.Finalize(&def));
// NCHW format
- InferenceContext c(&def, op_def, {S({2, 3})}, {});
+ InferenceContext c(&def, op_def, {S({2, 3})}, {}, {});
EXPECT_FALSE(BiasAddGradShape(&c).ok());
}
}