aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/shape_inference_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/shape_inference_test.cc')
-rw-r--r--tensorflow/core/framework/shape_inference_test.cc61
1 files changed, 47 insertions, 14 deletions
diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc
index e4ca7645b2..e52d1c5a2d 100644
--- a/tensorflow/core/framework/shape_inference_test.cc
+++ b/tensorflow/core/framework/shape_inference_test.cc
@@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/platform/test.h"
@@ -21,7 +23,8 @@ namespace tensorflow {
namespace shape_inference {
TEST(ShapeInferenceTest, RankAndDimInspection) {
- InferenceContext c({"?", "[1,?,3]", "[]"}, 2 /* num_outputs */);
+ NodeDef def;
+ InferenceContext c(&def, {"?", "[1,?,3]", "[]"}, 2 /* num_outputs */);
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(2, c.num_outputs());
@@ -54,7 +57,8 @@ TEST(ShapeInferenceTest, RankAndDimInspection) {
}
TEST(ShapeInferenceTest, WithRank) {
- InferenceContext c({"?", "[1,?,3]"}, 2 /* num_outputs */);
+ NodeDef def;
+ InferenceContext c(&def, {"?", "[1,?,3]"}, 2 /* num_outputs */);
auto in0 = c.input(0);
auto in1 = c.input(1);
@@ -91,7 +95,8 @@ TEST(ShapeInferenceTest, WithRank) {
}
TEST(ShapeInferenceTest, WithRankAtLeast) {
- InferenceContext c({"?", "[1,?,3]"}, 2 /* num_outputs */);
+ NodeDef def;
+ InferenceContext c(&def, {"?", "[1,?,3]"}, 2 /* num_outputs */);
auto in0 = c.input(0);
auto in1 = c.input(1);
@@ -125,7 +130,8 @@ TEST(ShapeInferenceTest, WithRankAtLeast) {
}
TEST(ShapeInferenceTest, WithValue) {
- InferenceContext c({"[1,?]"}, 2 /* num_outputs */);
+ NodeDef def;
+ InferenceContext c(&def, {"[1,?]"}, 2 /* num_outputs */);
auto d0 = c.Dim(c.input(0), 0);
auto d1 = c.Dim(c.input(0), 1);
@@ -163,7 +169,8 @@ TEST(ShapeInferenceTest, WithValue) {
}
TEST(ShapeInferenceTest, MergeDim) {
- InferenceContext c({"[2,?,2,1,?]"}, 2 /* num_outputs */);
+ NodeDef def;
+ InferenceContext c(&def, {"[2,?,2,1,?]"}, 2 /* num_outputs */);
auto d2 = c.Dim(c.input(0), 0);
auto d_unknown = c.Dim(c.input(0), 1);
@@ -202,7 +209,9 @@ TEST(ShapeInferenceTest, MergeDim) {
}
TEST(ShapeInferenceTest, MergeShape) {
- InferenceContext c({"?", "[1,2]", "[?,2]", "[1,?]", "[1,3]", "?", "[1]"},
+ NodeDef def;
+ InferenceContext c(&def,
+ {"?", "[1,2]", "[?,2]", "[1,?]", "[1,3]", "?", "[1]"},
2 /* num_outputs */);
auto s_unknown = c.input(0);
@@ -260,7 +269,8 @@ TEST(ShapeInferenceTest, MergeShape) {
}
TEST(ShapeInferenceTest, Subshape) {
- InferenceContext c({"[1,2,3,?,5]", "?"}, 2 /* num_outputs */);
+ NodeDef def;
+ InferenceContext c(&def, {"[1,2,3,?,5]", "?"}, 2 /* num_outputs */);
const Shape* unknown = c.input(1);
const Shape* out;
@@ -297,7 +307,8 @@ TEST(ShapeInferenceTest, Subshape) {
}
TEST(ShapeInferenceTest, Concatenate) {
- InferenceContext c({"[1,?,3]", "[4,5]", "?"}, 2 /* num_outputs */);
+ NodeDef def;
+ InferenceContext c(&def, {"[1,?,3]", "[4,5]", "?"}, 2 /* num_outputs */);
auto in0 = c.input(0);
auto in1 = c.input(1);
@@ -322,7 +333,8 @@ TEST(ShapeInferenceTest, Concatenate) {
}
TEST(ShapeInferenceTest, CreateShape) {
- InferenceContext c({"[1,2,3,?,5]"}, 2 /* num_outputs */);
+ NodeDef def;
+ InferenceContext c(&def, {"[1,2,3,?,5]"}, 2 /* num_outputs */);
std::vector<const Dimension*> dims;
auto in0 = c.input(0);
@@ -341,7 +353,8 @@ TEST(ShapeInferenceTest, CreateShape) {
}
TEST(ShapeInferenceTest, CreateUnknownShape) {
- InferenceContext c({}, 2 /* num_outputs */);
+ NodeDef def;
+ InferenceContext c(&def, {}, 2 /* num_outputs */);
auto u0 = c.CreateUnknownShape();
auto u1 = c.CreateUnknownShape();
@@ -352,7 +365,8 @@ TEST(ShapeInferenceTest, CreateUnknownShape) {
TEST(ShapeInferenceTest, CreateShapeFromShapeTensor) {
auto create = [](Tensor* t) {
- InferenceContext c({"?"}, 0 /* num_outputs */, {t});
+ NodeDef def;
+ InferenceContext c(&def, {"?"}, 0 /* num_outputs */, {t});
const Shape* out;
Status s = c.CreateShapeFromShapeTensor(0, &out);
if (s.ok()) {
@@ -386,7 +400,8 @@ TEST(ShapeInferenceTest, CreateShapeFromShapeTensor) {
}
TEST(ShapeInferenceTest, CreateDim) {
- InferenceContext c({}, 2 /* num_outputs */);
+ NodeDef def;
+ InferenceContext c(&def, {}, 2 /* num_outputs */);
auto* d0 = c.CreateDim(1);
auto* d1 = c.CreateDim(1);
@@ -398,7 +413,8 @@ TEST(ShapeInferenceTest, CreateDim) {
}
TEST(ShapeInferenceTest, CreateUnknownDim) {
- InferenceContext c({}, 2 /* num_outputs */);
+ NodeDef def;
+ InferenceContext c(&def, {}, 2 /* num_outputs */);
auto* d0 = c.CreateUnknownDim();
auto* d1 = c.CreateUnknownDim();
@@ -410,12 +426,29 @@ TEST(ShapeInferenceTest, CreateUnknownDim) {
TEST(ShapeInferenceTest, InputTensors) {
const Tensor t1 = tensorflow::test::AsTensor<float>({10});
const Tensor t2 = tensorflow::test::AsTensor<float>({20, 30});
- InferenceContext c({"[1]", "[2]", "[3]"}, 2 /* num_outputs */, {&t1, &t2});
+ NodeDef def;
+ InferenceContext c(&def, {"[1]", "[2]", "[3]"}, 2 /* num_outputs */,
+ {&t1, &t2});
EXPECT_TRUE(c.input_tensor(0) == &t1);
EXPECT_TRUE(c.input_tensor(1) == &t2);
EXPECT_TRUE(c.input_tensor(2) == nullptr);
}
+TEST(ShapeInferenceTest, GetAttr) {
+ OpRegistrationData op_reg_data;
+ CHECK(OpDefBuilder("dummy").Attr("foo:string").Finalize(&op_reg_data).ok());
+ NodeDef def;
+ CHECK(NodeDefBuilder("dummy", &op_reg_data.op_def)
+ .Attr("foo", "bar")
+ .Finalize(&def)
+ .ok());
+
+ InferenceContext c(&def, {}, 2 /* num_outputs */);
+ string value;
+ EXPECT_TRUE(c.GetAttr("foo", &value).ok());
+ EXPECT_EQ("bar", value);
+}
+
} // namespace shape_inference
} // namespace tensorflow