aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/shape_inference_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/shape_inference_test.cc')
-rw-r--r--tensorflow/core/framework/shape_inference_test.cc95
1 files changed, 94 insertions, 1 deletions
diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc
index 80a8639c02..9f363d50b3 100644
--- a/tensorflow/core/framework/shape_inference_test.cc
+++ b/tensorflow/core/framework/shape_inference_test.cc
@@ -123,7 +123,8 @@ TEST_F(ShapeInferenceTest, Run) {
NodeDef def;
def.set_name("foo");
def.set_op("foo_op");
- InferenceContext c(&def, MakeOpDef(3, 2), {S({1})}, {}, {}, {}, {});
+ InferenceContext c(&def, MakeOpDef(1, 2), {S({1})}, {}, {}, {}, {});
+ TF_ASSERT_OK(c.construction_status());
{
auto fn = [](InferenceContext* c) {
@@ -152,6 +153,98 @@ TEST_F(ShapeInferenceTest, Run) {
}
}
+// Tests different context data added when Run returns error.
+TEST_F(ShapeInferenceTest, AttachContext) {
+ NodeDef def;
+ def.set_name("foo");
+ def.set_op("foo_op");
+ // Error when no constant tensors were requested.
+ {
+ InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, 3})}, {}, {}, {}, {});
+ TF_ASSERT_OK(c.construction_status());
+ auto fn = [](InferenceContext* c) {
+ ShapeHandle h;
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h));
+ c->set_output(0, c->input(0));
+ return Status::OK();
+ };
+ EXPECT_EQ(
+ "Invalid argument: Shape must be at most rank 0 but is rank 3 for "
+ "'foo' (op: 'foo_op') with input shapes: [1,2,3].",
+ c.Run(fn).ToString());
+ }
+
+ // Error when a constant tensor value was requested.
+ {
+ Tensor input_t =
+ ::tensorflow::test::AsTensor<float>({1.1, 2.2, 3.3, 4.4, 5.5});
+ InferenceContext c(&def, MakeOpDef(2, 2), {S({1, 2, 3}), S({4, 5})},
+ {nullptr, &input_t}, {}, {}, {});
+ TF_ASSERT_OK(c.construction_status());
+ auto fn = [](InferenceContext* c) {
+ c->input_tensor(0); // get this one, but it's null - won't be in error.
+ c->input_tensor(1); // get this one, will now be in error.
+ ShapeHandle h;
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h));
+ c->set_output(0, c->input(0));
+ return Status::OK();
+ };
+ EXPECT_EQ(
+ "Invalid argument: Shape must be at most rank 0 but is rank 3 for "
+ "'foo' (op: 'foo_op') with input shapes: [1,2,3], [4,5] and with "
+ "computed input tensors: input[1] = <1.1 2.2 3.3 4.4 5.5>.",
+ c.Run(fn).ToString());
+ }
+
+ // Error when a constant tensor value as shape was requested, but no partial
+ // shapes provided.
+ {
+ Tensor input_t = ::tensorflow::test::AsTensor<int32>({1, 2, 3, 4, 5});
+ InferenceContext c(&def, MakeOpDef(2, 2), {S({3}), S({4})},
+ {nullptr, &input_t}, {}, {}, {});
+ TF_ASSERT_OK(c.construction_status());
+ auto fn = [](InferenceContext* c) {
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
+ ShapeHandle h;
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h));
+ c->set_output(0, c->input(0));
+ return Status::OK();
+ };
+ EXPECT_EQ(
+ "Invalid argument: Shape must be at most rank 0 but is rank 1 for "
+ "'foo' (op: 'foo_op') with input shapes: [3], [4] and with computed "
+ "input tensors: input[1] = <1 2 3 4 5>.",
+ c.Run(fn).ToString());
+ }
+
+ // Error when a constant tensor value as shape was requested, and a partial
+ // shape was provided.
+ {
+ Tensor input_t = ::tensorflow::test::AsTensor<int32>({1, 2, 3, 4, 5});
+ InferenceContext c(&def, MakeOpDef(2, 2), {S({3}), S({4})},
+ {nullptr, &input_t}, {S({10, -1, 5}), Unknown()}, {},
+ {});
+ TF_ASSERT_OK(c.construction_status());
+ auto fn = [](InferenceContext* c) {
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
+ ShapeHandle h;
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h));
+ c->set_output(0, c->input(0));
+ return Status::OK();
+ };
+ EXPECT_EQ(
+ "Invalid argument: Shape must be at most rank 0 but is rank 1 for "
+ "'foo' (op: 'foo_op') with input shapes: [3], [4] and with computed "
+ "input tensors: input[1] = <1 2 3 4 5> and with input tensors computed "
+ "as partial shapes: input[0] = [10,?,5].",
+ c.Run(fn).ToString());
+ }
+}
+
TEST_F(ShapeInferenceTest, RankAndDimInspection) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 2), {Unknown(), S({1, -1, 3}), S({})},