diff options
Diffstat (limited to 'tensorflow/core/framework/shape_inference_test.cc')
-rw-r--r-- | tensorflow/core/framework/shape_inference_test.cc | 95 |
1 files changed, 94 insertions, 1 deletions
diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc index 80a8639c02..9f363d50b3 100644 --- a/tensorflow/core/framework/shape_inference_test.cc +++ b/tensorflow/core/framework/shape_inference_test.cc @@ -123,7 +123,8 @@ TEST_F(ShapeInferenceTest, Run) { NodeDef def; def.set_name("foo"); def.set_op("foo_op"); - InferenceContext c(&def, MakeOpDef(3, 2), {S({1})}, {}, {}, {}, {}); + InferenceContext c(&def, MakeOpDef(1, 2), {S({1})}, {}, {}, {}, {}); + TF_ASSERT_OK(c.construction_status()); { auto fn = [](InferenceContext* c) { @@ -152,6 +153,98 @@ TEST_F(ShapeInferenceTest, Run) { } } +// Tests different context data added when Run returns error. +TEST_F(ShapeInferenceTest, AttachContext) { + NodeDef def; + def.set_name("foo"); + def.set_op("foo_op"); + // Error when no constant tensors were requested. + { + InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, 3})}, {}, {}, {}, {}); + TF_ASSERT_OK(c.construction_status()); + auto fn = [](InferenceContext* c) { + ShapeHandle h; + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h)); + c->set_output(0, c->input(0)); + return Status::OK(); + }; + EXPECT_EQ( + "Invalid argument: Shape must be at most rank 0 but is rank 3 for " + "'foo' (op: 'foo_op') with input shapes: [1,2,3].", + c.Run(fn).ToString()); + } + + // Error when a constant tensor value was requested. + { + Tensor input_t = + ::tensorflow::test::AsTensor<float>({1.1, 2.2, 3.3, 4.4, 5.5}); + InferenceContext c(&def, MakeOpDef(2, 2), {S({1, 2, 3}), S({4, 5})}, + {nullptr, &input_t}, {}, {}, {}); + TF_ASSERT_OK(c.construction_status()); + auto fn = [](InferenceContext* c) { + c->input_tensor(0); // get this one, but it's null - won't be in error. + c->input_tensor(1); // get this one, will now be in error. + ShapeHandle h; + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h)); + c->set_output(0, c->input(0)); + return Status::OK(); + }; + EXPECT_EQ( + "Invalid argument: Shape must be at most rank 0 but is rank 3 for " + "'foo' (op: 'foo_op') with input shapes: [1,2,3], [4,5] and with " + "computed input tensors: input[1] = <1.1 2.2 3.3 4.4 5.5>.", + c.Run(fn).ToString()); + } + + // Error when a constant tensor value as shape was requested, but no partial + // shapes provided. + { + Tensor input_t = ::tensorflow::test::AsTensor<int32>({1, 2, 3, 4, 5}); + InferenceContext c(&def, MakeOpDef(2, 2), {S({3}), S({4})}, + {nullptr, &input_t}, {}, {}, {}); + TF_ASSERT_OK(c.construction_status()); + auto fn = [](InferenceContext* c) { + ShapeHandle s; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s)); + ShapeHandle h; + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h)); + c->set_output(0, c->input(0)); + return Status::OK(); + }; + EXPECT_EQ( + "Invalid argument: Shape must be at most rank 0 but is rank 1 for " + "'foo' (op: 'foo_op') with input shapes: [3], [4] and with computed " + "input tensors: input[1] = <1 2 3 4 5>.", + c.Run(fn).ToString()); + } + + // Error when a constant tensor value as shape was requested, and a partial + // shape was provided. + { + Tensor input_t = ::tensorflow::test::AsTensor<int32>({1, 2, 3, 4, 5}); + InferenceContext c(&def, MakeOpDef(2, 2), {S({3}), S({4})}, + {nullptr, &input_t}, {S({10, -1, 5}), Unknown()}, {}, + {}); + TF_ASSERT_OK(c.construction_status()); + auto fn = [](InferenceContext* c) { + ShapeHandle s; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s)); + ShapeHandle h; + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h)); + c->set_output(0, c->input(0)); + return Status::OK(); + }; + EXPECT_EQ( + "Invalid argument: Shape must be at most rank 0 but is rank 1 for " + "'foo' (op: 'foo_op') with input shapes: [3], [4] and with computed " + "input tensors: input[1] = <1 2 3 4 5> and with input tensors computed " + "as partial shapes: input[0] = [10,?,5].", + c.Run(fn).ToString()); + } +} + TEST_F(ShapeInferenceTest, RankAndDimInspection) { NodeDef def; InferenceContext c(&def, MakeOpDef(3, 2), {Unknown(), S({1, -1, 3}), S({})}, |