diff options
Diffstat (limited to 'tensorflow/core/framework/shape_inference_test.cc')
-rw-r--r-- | tensorflow/core/framework/shape_inference_test.cc | 61 |
1 files changed, 47 insertions, 14 deletions
diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc index e4ca7645b2..e52d1c5a2d 100644 --- a/tensorflow/core/framework/shape_inference_test.cc +++ b/tensorflow/core/framework/shape_inference_test.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_def_builder.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/platform/test.h" @@ -21,7 +23,8 @@ namespace tensorflow { namespace shape_inference { TEST(ShapeInferenceTest, RankAndDimInspection) { - InferenceContext c({"?", "[1,?,3]", "[]"}, 2 /* num_outputs */); + NodeDef def; + InferenceContext c(&def, {"?", "[1,?,3]", "[]"}, 2 /* num_outputs */); EXPECT_EQ(3, c.num_inputs()); EXPECT_EQ(2, c.num_outputs()); @@ -54,7 +57,8 @@ TEST(ShapeInferenceTest, RankAndDimInspection) { } TEST(ShapeInferenceTest, WithRank) { - InferenceContext c({"?", "[1,?,3]"}, 2 /* num_outputs */); + NodeDef def; + InferenceContext c(&def, {"?", "[1,?,3]"}, 2 /* num_outputs */); auto in0 = c.input(0); auto in1 = c.input(1); @@ -91,7 +95,8 @@ TEST(ShapeInferenceTest, WithRank) { } TEST(ShapeInferenceTest, WithRankAtLeast) { - InferenceContext c({"?", "[1,?,3]"}, 2 /* num_outputs */); + NodeDef def; + InferenceContext c(&def, {"?", "[1,?,3]"}, 2 /* num_outputs */); auto in0 = c.input(0); auto in1 = c.input(1); @@ -125,7 +130,8 @@ TEST(ShapeInferenceTest, WithRankAtLeast) { } TEST(ShapeInferenceTest, WithValue) { - InferenceContext c({"[1,?]"}, 2 /* num_outputs */); + NodeDef def; + InferenceContext c(&def, {"[1,?]"}, 2 /* num_outputs */); auto d0 = c.Dim(c.input(0), 0); auto d1 = c.Dim(c.input(0), 1); @@ -163,7 +169,8 @@ TEST(ShapeInferenceTest, WithValue) { } TEST(ShapeInferenceTest, MergeDim) { - InferenceContext c({"[2,?,2,1,?]"}, 2 /* num_outputs */); + NodeDef def; + InferenceContext c(&def, {"[2,?,2,1,?]"}, 2 /* num_outputs */); auto d2 = c.Dim(c.input(0), 0); auto d_unknown = c.Dim(c.input(0), 1); @@ -202,7 +209,9 @@ TEST(ShapeInferenceTest, MergeDim) { } TEST(ShapeInferenceTest, MergeShape) { - InferenceContext c({"?", "[1,2]", "[?,2]", "[1,?]", "[1,3]", "?", "[1]"}, + NodeDef def; + InferenceContext c(&def, + {"?", "[1,2]", "[?,2]", "[1,?]", "[1,3]", "?", "[1]"}, 2 /* num_outputs */); auto s_unknown = c.input(0); @@ -260,7 +269,8 @@ TEST(ShapeInferenceTest, MergeShape) { } TEST(ShapeInferenceTest, Subshape) { - InferenceContext c({"[1,2,3,?,5]", "?"}, 2 /* num_outputs */); + NodeDef def; + InferenceContext c(&def, {"[1,2,3,?,5]", "?"}, 2 /* num_outputs */); const Shape* unknown = c.input(1); const Shape* out; @@ -297,7 +307,8 @@ TEST(ShapeInferenceTest, Subshape) { } TEST(ShapeInferenceTest, Concatenate) { - InferenceContext c({"[1,?,3]", "[4,5]", "?"}, 2 /* num_outputs */); + NodeDef def; + InferenceContext c(&def, {"[1,?,3]", "[4,5]", "?"}, 2 /* num_outputs */); auto in0 = c.input(0); auto in1 = c.input(1); @@ -322,7 +333,8 @@ TEST(ShapeInferenceTest, Concatenate) { } TEST(ShapeInferenceTest, CreateShape) { - InferenceContext c({"[1,2,3,?,5]"}, 2 /* num_outputs */); + NodeDef def; + InferenceContext c(&def, {"[1,2,3,?,5]"}, 2 /* num_outputs */); std::vector<const Dimension*> dims; auto in0 = c.input(0); @@ -341,7 +353,8 @@ TEST(ShapeInferenceTest, CreateShape) { } TEST(ShapeInferenceTest, CreateUnknownShape) { - InferenceContext c({}, 2 /* num_outputs */); + NodeDef def; + InferenceContext c(&def, {}, 2 /* num_outputs */); auto u0 = c.CreateUnknownShape(); auto u1 = c.CreateUnknownShape(); @@ -352,7 +365,8 @@ TEST(ShapeInferenceTest, CreateUnknownShape) { TEST(ShapeInferenceTest, CreateShapeFromShapeTensor) { auto create = [](Tensor* t) { - InferenceContext c({"?"}, 0 /* num_outputs */, {t}); + NodeDef def; + InferenceContext c(&def, {"?"}, 0 /* num_outputs */, {t}); const Shape* out; Status s = c.CreateShapeFromShapeTensor(0, &out); if (s.ok()) { @@ -386,7 +400,8 @@ TEST(ShapeInferenceTest, CreateShapeFromShapeTensor) { } TEST(ShapeInferenceTest, CreateDim) { - InferenceContext c({}, 2 /* num_outputs */); + NodeDef def; + InferenceContext c(&def, {}, 2 /* num_outputs */); auto* d0 = c.CreateDim(1); auto* d1 = c.CreateDim(1); @@ -398,7 +413,8 @@ TEST(ShapeInferenceTest, CreateDim) { } TEST(ShapeInferenceTest, CreateUnknownDim) { - InferenceContext c({}, 2 /* num_outputs */); + NodeDef def; + InferenceContext c(&def, {}, 2 /* num_outputs */); auto* d0 = c.CreateUnknownDim(); auto* d1 = c.CreateUnknownDim(); @@ -410,12 +426,29 @@ TEST(ShapeInferenceTest, CreateUnknownDim) { TEST(ShapeInferenceTest, InputTensors) { const Tensor t1 = tensorflow::test::AsTensor<float>({10}); const Tensor t2 = tensorflow::test::AsTensor<float>({20, 30}); - InferenceContext c({"[1]", "[2]", "[3]"}, 2 /* num_outputs */, {&t1, &t2}); + NodeDef def; + InferenceContext c(&def, {"[1]", "[2]", "[3]"}, 2 /* num_outputs */, + {&t1, &t2}); EXPECT_TRUE(c.input_tensor(0) == &t1); EXPECT_TRUE(c.input_tensor(1) == &t2); EXPECT_TRUE(c.input_tensor(2) == nullptr); } +TEST(ShapeInferenceTest, GetAttr) { + OpRegistrationData op_reg_data; + CHECK(OpDefBuilder("dummy").Attr("foo:string").Finalize(&op_reg_data).ok()); + NodeDef def; + CHECK(NodeDefBuilder("dummy", &op_reg_data.op_def) + .Attr("foo", "bar") + .Finalize(&def) + .ok()); + + InferenceContext c(&def, {}, 2 /* num_outputs */); + string value; + EXPECT_TRUE(c.GetAttr("foo", &value).ok()); + EXPECT_EQ("bar", value); +} + } // namespace shape_inference } // namespace tensorflow |