diff options
author | 2017-03-03 14:14:16 -0800 | |
---|---|---|
committer | 2017-03-03 14:26:26 -0800 | |
commit | b177e3720721dea593f1f15ba731ab351e87d298 (patch) | |
tree | 6a58e524addc6b2518df22fee094b9a712576910 /tensorflow | |
parent | 1dc89c1ab1bee51ae40f97994ef81ac3d6b1391c (diff) |
Add the graphdef version to InferenceContext and to ShapeRefiner::AddNode.
Use this to allow loading reductions saved with older graphdefs.
Change GraphConstructor to not increase the version when importing, but instead take the min of all versions.
Change: 149152437
Diffstat (limited to 'tensorflow')
25 files changed, 344 insertions, 174 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 496ec8dc86..02aba54e43 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -730,7 +730,7 @@ extern "C" { struct TF_Graph { TF_Graph() : graph(OpRegistry::Global()), - refiner(graph.op_registry()), + refiner(graph.versions().producer(), graph.op_registry()), num_sessions(0), delete_requested(false), parent(nullptr), diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc index c5629dbd6d..571c6e1e57 100644 --- a/tensorflow/cc/framework/scope.cc +++ b/tensorflow/cc/framework/scope.cc @@ -118,7 +118,8 @@ Scope::Impl::Impl(Graph* graph, Status* status, NameMap* name_map, Scope Scope::NewRootScope() { Graph* graph = new Graph(OpRegistry::Global()); - ShapeRefiner* refiner = new ShapeRefiner(graph->op_registry()); + ShapeRefiner* refiner = + new ShapeRefiner(graph->versions().producer(), graph->op_registry()); return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner)); } diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc index 52c9ea182f..7288ecb143 100644 --- a/tensorflow/core/common_runtime/shape_refiner.cc +++ b/tensorflow/core/common_runtime/shape_refiner.cc @@ -31,8 +31,9 @@ using shape_inference::DimensionHandle; using shape_inference::InferenceContext; using shape_inference::ShapeHandle; -ShapeRefiner::ShapeRefiner(const OpRegistryInterface* ops) - : ops_registry_(ops) {} +ShapeRefiner::ShapeRefiner(int graph_def_version, + const OpRegistryInterface* ops) + : graph_def_version_(graph_def_version), ops_registry_(ops) {} Status ShapeRefiner::AddNode(const Node* node) { // For each 'input' of this node, fetch the corresponding shape @@ -85,9 +86,10 @@ Status ShapeRefiner::AddNode(const Node* node) { std::vector<ShapeHandle> input_tensors_as_shapes; // Create the inference context for this node with the existing input shapes. - std::unique_ptr<InferenceContext> c(new InferenceContext( - &node->def(), node->op_def(), input_shapes, input_tensors, - input_tensors_as_shapes, input_handle_shapes, input_handle_dtypes)); + std::unique_ptr<InferenceContext> c( + new InferenceContext(graph_def_version_, &node->def(), node->op_def(), + input_shapes, input_tensors, input_tensors_as_shapes, + input_handle_shapes, input_handle_dtypes)); if (!c->construction_status().ok()) { return c->construction_status(); } diff --git a/tensorflow/core/common_runtime/shape_refiner.h b/tensorflow/core/common_runtime/shape_refiner.h index 43466727d9..b8d69fc05b 100644 --- a/tensorflow/core/common_runtime/shape_refiner.h +++ b/tensorflow/core/common_runtime/shape_refiner.h @@ -31,7 +31,7 @@ namespace tensorflow { // construction time. class ShapeRefiner { public: - explicit ShapeRefiner(const OpRegistryInterface* ops); + ShapeRefiner(int graph_def_version, const OpRegistryInterface* ops); // Performs validation of 'node' and runs 'node's shape function, // storing its shape outputs. @@ -98,7 +98,8 @@ class ShapeRefiner { const Node* node, int dst_idx, shape_inference::ShapeHandle* result); - const OpRegistryInterface* ops_registry_ = nullptr; + const int graph_def_version_; + const OpRegistryInterface* const ops_registry_; // Stores a map from a node to its InferenceContext. // diff --git a/tensorflow/core/common_runtime/shape_refiner_test.cc b/tensorflow/core/common_runtime/shape_refiner_test.cc index 54a00ac9ff..05274ff311 100644 --- a/tensorflow/core/common_runtime/shape_refiner_test.cc +++ b/tensorflow/core/common_runtime/shape_refiner_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/version.h" namespace tensorflow { namespace { @@ -38,14 +39,14 @@ TEST(ShapeRefinerTest, Constant) { // and that its shape is correct. Scope root = Scope::NewRootScope(); auto c = ops::Const(root, 42.0f); - ShapeRefiner m(OpRegistry::Global()); + ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); TF_ASSERT_OK(m.AddNode(c.node())); EXPECT_SHAPE("[]", m, c, 0); } TEST(ShapeRefinerTest, MatMul) { - ShapeRefiner m(OpRegistry::Global()); + ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); Scope root = Scope::NewRootScope(); auto a = ops::Const(root, {{1.0f}, {2.0f}}); @@ -62,7 +63,7 @@ TEST(ShapeRefinerTest, MatMul) { } TEST(ShapeRefinerTest, InvalidOrder) { - ShapeRefiner m(OpRegistry::Global()); + ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); Scope root = Scope::NewRootScope(); auto a = ops::Const(root, {{1.0f}, {2.0f}}); auto b = ops::Const(root, {{1.0f, 2.0f}}); @@ -77,7 +78,7 @@ TEST(ShapeRefinerTest, InvalidOrder) { } TEST(ShapeRefinerTest, BadShapes) { - ShapeRefiner m(OpRegistry::Global()); + ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); Scope root = Scope::NewRootScope(); auto a = ops::Const(root, {{1.0f}, {2.0f}}); auto b = ops::Const(root, {{1.0f}, {2.0f}}); @@ -94,7 +95,7 @@ TEST(ShapeRefinerTest, BadShapes) { } TEST(ShapeRefinerTest, SetShape) { - ShapeRefiner m(OpRegistry::Global()); + ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); Scope root = Scope::NewRootScope(); auto a = ops::Placeholder(root, DT_FLOAT); @@ -136,7 +137,7 @@ TEST(ShapeRefinerTest, PropagateConstants) { auto dim = ops::Variable(root, {}, DT_INT32); auto am = ops::ArgMax(root, input, dim); - ShapeRefiner m(OpRegistry::Global()); + ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); TF_ASSERT_OK(m.AddNode(input.node())); TF_ASSERT_OK(m.AddNode(dim.node())); TF_ASSERT_OK(m.AddNode(am.node())); @@ -153,7 +154,7 @@ TEST(ShapeRefinerTest, PropagateConstants) { auto dim = ops::Const(root, 1); auto am = ops::ArgMax(root, input, dim); - ShapeRefiner m(OpRegistry::Global()); + ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); TF_ASSERT_OK(m.AddNode(input.node())); TF_ASSERT_OK(m.AddNode(dim.node())); TF_ASSERT_OK(m.AddNode(am.node())); @@ -169,7 +170,7 @@ TEST(ShapeRefinerTest, PropagateConstants) { auto dim = ops::Const(root, 0); auto am = ops::ArgMax(root, input, dim); - ShapeRefiner m(OpRegistry::Global()); + ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); TF_ASSERT_OK(m.AddNode(input.node())); TF_ASSERT_OK(m.AddNode(dim.node())); TF_ASSERT_OK(m.AddNode(am.node())); @@ -199,7 +200,7 @@ REGISTER_OP("TestOp") } // namespace TEST(ShapeRefinerTest, InputTensorDependencies) { - ShapeRefiner m(OpRegistry::Global()); + ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); Graph graph(OpRegistry::Global()); Node* node; @@ -260,7 +261,7 @@ TEST(ShapeRefinerTest, PropagateShape) { .Input(shape.node()) .Finalize(root.graph(), &shape_data)); - ShapeRefiner m(OpRegistry::Global()); + ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); TF_ASSERT_OK(m.AddNode(input.node())); TF_ASSERT_OK(m.AddNode(shape.node())); TF_ASSERT_OK(m.AddNode(shape_data)); @@ -281,7 +282,7 @@ TEST(ShapeRefinerTest, PropagateSize) { .Input(size.node()) .Finalize(root.graph(), &shape_data)); - ShapeRefiner m(OpRegistry::Global()); + ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); TF_ASSERT_OK(m.AddNode(input.node())); TF_ASSERT_OK(m.AddNode(size.node())); TF_ASSERT_OK(m.AddNode(shape_data)); @@ -302,7 +303,7 @@ TEST(ShapeRefinerTest, PropagateRank) { .Input(rank.node()) .Finalize(root.graph(), &shape_data)); - ShapeRefiner m(OpRegistry::Global()); + ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); TF_ASSERT_OK(m.AddNode(input.node())); TF_ASSERT_OK(m.AddNode(rank.node())); TF_ASSERT_OK(m.AddNode(shape_data)); @@ -323,7 +324,7 @@ TEST(ShapeRefinerTest, PropagateRange) { .Input(range.node()) .Finalize(root.graph(), &shape_data)); - ShapeRefiner m(OpRegistry::Global()); + ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); TF_ASSERT_OK(m.AddNode(begin.node())); TF_ASSERT_OK(m.AddNode(limit.node())); TF_ASSERT_OK(m.AddNode(delta.node())); @@ -346,7 +347,7 @@ TEST(ShapeRefinerTest, ConstantValueTwoInputsToSameNode) { .Input(range.node()) .Finalize(root.graph(), &shape_data)); - ShapeRefiner m(OpRegistry::Global()); + ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); TF_ASSERT_OK(m.AddNode(begin_and_delta.node())); TF_ASSERT_OK(m.AddNode(limit.node())); TF_ASSERT_OK(m.AddNode(range.node())); @@ -381,7 +382,7 @@ TEST(ShapeRefinerTest, ConstantValueVisitNodeTwice) { .Input(range.node()) .Finalize(root.graph(), &shape_data)); - ShapeRefiner m(OpRegistry::Global()); + ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); TF_ASSERT_OK(m.AddNode(begin.node())); TF_ASSERT_OK(m.AddNode(limit.node())); TF_ASSERT_OK(m.AddNode(delta.node())); @@ -477,7 +478,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_EmptyVector) { .Input(input) .Finalize(root.graph(), &result)); - ShapeRefiner m(OpRegistry::Global()); + ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); TF_ASSERT_OK(m.AddNode(input)); TF_ASSERT_OK(m.AddNode(result)); @@ -498,7 +499,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_Shape) { .Input(shape.node()) .Finalize(root.graph(), &result)); - ShapeRefiner m(OpRegistry::Global()); + ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); TF_ASSERT_OK(m.AddNode(input)); TF_ASSERT_OK(m.AddNode(shape.node())); TF_ASSERT_OK(m.AddNode(result)); @@ -533,7 +534,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_PackInt32) { .Input(pack.node()) .Finalize(root.graph(), &result)); - ShapeRefiner m(OpRegistry::Global()); + ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); for (auto input : inputs) { TF_ASSERT_OK(m.AddNode(input.node())); } @@ -565,7 +566,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_PackInt64) { .Input(pack.node()) .Finalize(root.graph(), &result)); - ShapeRefiner m(OpRegistry::Global()); + ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); for (const auto& input : inputs) { TF_ASSERT_OK(m.AddNode(input.node())); } @@ -591,7 +592,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_PackUnknownDim) { .Input(pack.node()) .Finalize(root.graph(), &result)); - ShapeRefiner m(OpRegistry::Global()); + ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); for (const auto& input : inputs) { TF_ASSERT_OK(m.AddNode(input.node())); } @@ -618,7 +619,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_PackInvalidInput) { .Input(pack.node()) .Finalize(root.graph(), &result)); - ShapeRefiner m(OpRegistry::Global()); + ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); for (const auto& input : inputs) { TF_ASSERT_OK(m.AddNode(input.node())); } @@ -650,7 +651,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_Concat) { .Input(concat.node()) .Finalize(g, &result)); - ShapeRefiner m(OpRegistry::Global()); + ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); TF_ASSERT_OK(m.AddNode(partial_1)); TF_ASSERT_OK(m.AddNode(partial_2)); for (const auto& o : concat_inputs) { @@ -692,7 +693,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_ConcatWithUnknown) { .Input(concat.node()) .Finalize(g, &result)); - ShapeRefiner m(OpRegistry::Global()); + ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); TF_ASSERT_OK(m.AddNode(partial_1)); TF_ASSERT_OK(m.AddNode(partial_2)); TF_ASSERT_OK(m.AddNode(unknown)); @@ -734,7 +735,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_ConcatInvalidDimValue) { .Input(concat.node()) .Finalize(g, &result)); - ShapeRefiner m(OpRegistry::Global()); + ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); TF_ASSERT_OK(m.AddNode(partial_1)); TF_ASSERT_OK(m.AddNode(partial_2)); for (const auto& o : concat_inputs) { diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index 9d5d212ddd..ede0452f14 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -590,7 +590,13 @@ Status ReductionShape(InferenceContext* c) { ShapeHandle input = c->input(0); ShapeHandle indices; - TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &indices)); + // Older versions of TensorFlow accidentally allowed higher rank tensors like + // [[1,2]] or [[1],[2]] to represent axis=[1,2]. + if (c->graph_def_version() < 21) { + indices = c->input(1); + } else { + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &indices)); + } bool keep_dims; TF_RETURN_IF_ERROR(c->GetAttr("keep_dims", &keep_dims)); diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc index 89acf1202c..2d9e96e6bc 100644 --- a/tensorflow/core/framework/common_shape_fns_test.cc +++ b/tensorflow/core/framework/common_shape_fns_test.cc @@ -69,7 +69,8 @@ TEST(CommonShapeFnsTest, NoOutputShapeTest) { .Input({{"data", 0, DT_FLOAT}}) .Finalize(&def)); - InferenceContext c(&def, op_def, {S({}), S({10})}, {}, {}, {}, {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({}), S({10})}, {}, + {}, {}, {}); TF_EXPECT_OK(NoOutputs(&c)); EXPECT_EQ(0, c.num_outputs()); } @@ -87,14 +88,16 @@ TEST(CommonShapeFnsTest, ScalarShapeTest) { NodeDefBuilder("test", "L2Loss").Input("t", 0, DT_FLOAT).Finalize(&def)); { - InferenceContext c(&def, op_def, {S({})}, {}, {}, {}, {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({})}, {}, {}, {}, + {}); TF_EXPECT_OK(ScalarShape(&c)); ShapeHandle output = c.output(0); EXPECT_EQ(0, c.Rank(output)); } { - InferenceContext c(&def, op_def, {S({1, 23, 4, 4, 2})}, {}, {}, {}, {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + {S({1, 23, 4, 4, 2})}, {}, {}, {}, {}); TF_EXPECT_OK(ScalarShape(&c)); ShapeHandle output = c.output(0); EXPECT_EQ(0, c.Rank(output)); @@ -121,7 +124,8 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) { .Finalize(&def)); { - InferenceContext c(&def, op_def, {S({2, 3}), S({3, 4})}, {}, {}, {}, {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + {S({2, 3}), S({3, 4})}, {}, {}, {}, {}); TF_EXPECT_OK(MatMulShape(&c)); ShapeHandle output = c.output(0); EXPECT_EQ(2, c.Value(c.Dim(output, 0))); @@ -130,7 +134,8 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) { { // Unknown inner dimension for one - InferenceContext c(&def, op_def, {S({2, -1}), S({3, 4})}, {}, {}, {}, {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + {S({2, -1}), S({3, 4})}, {}, {}, {}, {}); TF_EXPECT_OK(MatMulShape(&c)); ShapeHandle output = c.output(0); EXPECT_EQ(2, c.Value(c.Dim(output, 0))); @@ -139,7 +144,8 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) { { // Invalid rank. - InferenceContext c(&def, op_def, {S({2}), S({3, 4})}, {}, {}, {}, {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2}), S({3, 4})}, + {}, {}, {}, {}); auto s = MatMulShape(&c); EXPECT_FALSE(s.ok()); EXPECT_TRUE( @@ -149,7 +155,8 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) { { // Unknown outer dimension - InferenceContext c(&def, op_def, {S({2, 3}), S({3, -1})}, {}, {}, {}, {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + {S({2, 3}), S({3, -1})}, {}, {}, {}, {}); TF_EXPECT_OK(MatMulShape(&c)); ShapeHandle output = c.output(0); EXPECT_EQ(2, c.Value(c.Dim(output, 0))); @@ -158,7 +165,8 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) { { // Inner shapes not compatible - InferenceContext c(&def, op_def, {S({2, 5}), S({3, 4})}, {}, {}, {}, {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + {S({2, 5}), S({3, 4})}, {}, {}, {}, {}); auto s = MatMulShape(&c); EXPECT_FALSE(s.ok()); EXPECT_TRUE( @@ -169,8 +177,8 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) { { // Inner shapes not compatible - InferenceContext c(&def, op_def, {S({2, 5, 3}), S({3, 5, 4})}, {}, {}, {}, - {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + {S({2, 5, 3}), S({3, 5, 4})}, {}, {}, {}, {}); auto s = MatMulShape(&c); EXPECT_FALSE(s.ok()); EXPECT_TRUE( @@ -188,7 +196,8 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) { .Attr("type", DT_FLOAT) .Finalize(&def)); - InferenceContext c(&def, op_def, {S({3, 2}), S({3, 4})}, {}, {}, {}, {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + {S({3, 2}), S({3, 4})}, {}, {}, {}, {}); auto s = MatMulShape(&c); ShapeHandle output = c.output(0); EXPECT_EQ(2, c.Value(c.Dim(output, 0))); @@ -205,7 +214,8 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) { .Attr("type", DT_FLOAT) .Finalize(&def)); - InferenceContext c(&def, op_def, {S({2, 3}), S({4, 3})}, {}, {}, {}, {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + {S({2, 3}), S({4, 3})}, {}, {}, {}, {}); auto s = MatMulShape(&c); ShapeHandle output = c.output(0); EXPECT_EQ(2, c.Value(c.Dim(output, 0))); @@ -229,7 +239,8 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) { .Finalize(&def)); { - InferenceContext c(&def, op_def, {S({2, 10}), S({10})}, {}, {}, {}, {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + {S({2, 10}), S({10})}, {}, {}, {}, {}); TF_EXPECT_OK(BiasAddShape(&c)); ShapeHandle output = c.output(0); EXPECT_EQ(2, c.Value(c.Dim(output, 0))); @@ -238,7 +249,8 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) { { // Unknown ranks. - InferenceContext c(&def, op_def, {Unknown(), Unknown()}, {}, {}, {}, {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + {Unknown(), Unknown()}, {}, {}, {}, {}); TF_EXPECT_OK(BiasAddShape(&c)); ShapeHandle output = c.output(0); EXPECT_FALSE(c.RankKnown(output)); @@ -246,8 +258,8 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) { { // Rank > 2 - InferenceContext c(&def, op_def, {S({4, 3, 4, 2, 15}), S({15})}, {}, {}, {}, - {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + {S({4, 3, 4, 2, 15}), S({15})}, {}, {}, {}, {}); TF_EXPECT_OK(BiasAddShape(&c)); ShapeHandle output = c.output(0); EXPECT_EQ("[4,3,4,2,15]", c.DebugString(output)); @@ -260,7 +272,8 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) { .Input("b", 0, DT_FLOAT) .Attr("data_format", "NCHW") .Finalize(&def)); - InferenceContext c(&def, op_def, {S({2, 3, 4, 5}), S({3})}, {}, {}, {}, {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + {S({2, 3, 4, 5}), S({3})}, {}, {}, {}, {}); TF_EXPECT_OK(BiasAddShape(&c)); ShapeHandle output = c.output(0); EXPECT_EQ("[2,3,4,5]", c.DebugString(output)); @@ -273,8 +286,8 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) { .Input("b", 0, DT_FLOAT) .Attr("data_format", "NCHW") .Finalize(&def)); - InferenceContext c(&def, op_def, {S({8, 6, 4, 2, 3, 4, 5}), S({3})}, {}, {}, - {}, {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + {S({8, 6, 4, 2, 3, 4, 5}), S({3})}, {}, {}, {}, {}); TF_EXPECT_OK(BiasAddShape(&c)); ShapeHandle output = c.output(0); EXPECT_EQ("[8,6,4,2,3,4,5]", c.DebugString(output)); @@ -287,8 +300,8 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) { .Input("b", 0, DT_FLOAT) .Attr("data_format", "NCHW") .Finalize(&def)); - InferenceContext c(&def, op_def, {S({10, 11, 12}), S({10})}, {}, {}, {}, - {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + {S({10, 11, 12}), S({10})}, {}, {}, {}, {}); TF_EXPECT_OK(BiasAddShape(&c)); ShapeHandle output = c.output(0); EXPECT_EQ("[10,11,12]", c.DebugString(output)); @@ -296,7 +309,8 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) { { // Input rank not high enough - InferenceContext c(&def, op_def, {S({3}), S({3})}, {}, {}, {}, {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({3}), S({3})}, {}, + {}, {}, {}); EXPECT_FALSE(BiasAddShape(&c).ok()); } @@ -308,7 +322,8 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) { .Attr("data_format", "NCHW") .Finalize(&def)); // NCHW format - InferenceContext c(&def, op_def, {S({2, 3}), S({3})}, {}, {}, {}, {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 3}), S({3})}, + {}, {}, {}, {}); EXPECT_FALSE(BiasAddShape(&c).ok()); } } @@ -327,7 +342,8 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) { .Finalize(&def)); { - InferenceContext c(&def, op_def, {S({2, 10})}, {}, {}, {}, {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 10})}, {}, {}, + {}, {}); TF_EXPECT_OK(BiasAddGradShape(&c)); ShapeHandle output = c.output(0); EXPECT_EQ(10, c.Value(c.Dim(output, 0))); @@ -335,7 +351,8 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) { { // Rank > 2 - InferenceContext c(&def, op_def, {S({5, 7, 2, 10})}, {}, {}, {}, {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({5, 7, 2, 10})}, + {}, {}, {}, {}); TF_EXPECT_OK(BiasAddGradShape(&c)); ShapeHandle output = c.output(0); EXPECT_EQ(10, c.Value(c.Dim(output, 0))); @@ -347,7 +364,8 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) { .Input("a", 0, DT_FLOAT) .Attr("data_format", "NCHW") .Finalize(&def)); - InferenceContext c(&def, op_def, {S({2, 3, 4, 5})}, {}, {}, {}, {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 3, 4, 5})}, + {}, {}, {}, {}); TF_EXPECT_OK(BiasAddGradShape(&c)); ShapeHandle output = c.output(0); EXPECT_EQ(3, c.Value(c.Dim(output, 0))); @@ -359,8 +377,8 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) { .Input("a", 0, DT_FLOAT) .Attr("data_format", "NCHW") .Finalize(&def)); - InferenceContext c(&def, op_def, {S({8, 6, 4, 2, 3, 4, 5})}, {}, {}, {}, - {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, + {S({8, 6, 4, 2, 3, 4, 5})}, {}, {}, {}, {}); TF_EXPECT_OK(BiasAddGradShape(&c)); ShapeHandle output = c.output(0); EXPECT_EQ(3, c.Value(c.Dim(output, 0))); @@ -372,7 +390,8 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) { .Input("a", 0, DT_FLOAT) .Attr("data_format", "NCHW") .Finalize(&def)); - InferenceContext c(&def, op_def, {S({10, 11, 12})}, {}, {}, {}, {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({10, 11, 12})}, + {}, {}, {}, {}); TF_EXPECT_OK(BiasAddGradShape(&c)); ShapeHandle output = c.output(0); EXPECT_EQ(10, c.Value(c.Dim(output, 0))); @@ -380,7 +399,8 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) { { // Input rank not high enough - InferenceContext c(&def, op_def, {S({3})}, {}, {}, {}, {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({3})}, {}, {}, {}, + {}); EXPECT_FALSE(BiasAddGradShape(&c).ok()); } @@ -391,7 +411,8 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) { .Attr("data_format", "NCHW") .Finalize(&def)); // NCHW format - InferenceContext c(&def, op_def, {S({2, 3})}, {}, {}, {}, {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 3})}, {}, {}, + {}, {}); EXPECT_FALSE(BiasAddGradShape(&c).ok()); } } @@ -781,12 +802,24 @@ TEST(CommonShapeFnsTest, Reduce_ShapeFn) { op.input_tensors[1] = nullptr; INFER_OK(op, "[?,?,?];?", "[?,?,?]"); INFER_OK(op, "[?,?,?];[2]", "[?,?,?]"); + + // Reduction indices with too many dimensions. + INFER_ERROR("must be at most rank 1 but is rank 2", op, "[?,?,?];[?,?]"); + // With older graph-def version, this is allowed. + op.graph_def_version = 20; + INFER_OK(op, "[?,?,?];[?,?]", "[?,?,?]"); + // And when the tensor is specified, it's still allowed. + op.input_tensors[1] = &indices; + indices = test::AsTensor<int32>({-1, -2}, TensorShape({2, 1})); + INFER_OK(op, "[2,4,5];[2,1]", "[d0_0, 1, 1]"); + indices = test::AsTensor<int32>({-1, -2}, TensorShape({1, 2})); + INFER_OK(op, "[2,4,5];[1,2]", "[d0_0, 1, 1]"); } TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapes) { NodeDef def; - InferenceContext c(&def, MakeOpDef(3, 1), {Unknown(), Unknown(), Unknown()}, - {}, {}, {}, {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1), + {Unknown(), Unknown(), Unknown()}, {}, {}, {}, {}); EXPECT_EQ(3, c.num_inputs()); EXPECT_EQ(1, c.num_outputs()); @@ -798,8 +831,8 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapes) { TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownDims) { NodeDef def; - InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, -1}), S({-1}), S({-1})}, {}, - {}, {}, {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1), + {S({-1, -1}), S({-1}), S({-1})}, {}, {}, {}, {}); EXPECT_EQ(3, c.num_inputs()); EXPECT_EQ(1, c.num_outputs()); @@ -811,8 +844,8 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownDims) { TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidIndicesRank) { NodeDef def; - InferenceContext c(&def, MakeOpDef(3, 1), {S({-1}), S({-1}), S({-1})}, {}, {}, - {}, {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1), + {S({-1}), S({-1}), S({-1})}, {}, {}, {}, {}); EXPECT_EQ(3, c.num_inputs()); EXPECT_EQ(1, c.num_outputs()); @@ -825,8 +858,8 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidIndicesRank) { TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidNumElements) { NodeDef def; - InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({4}), S({3})}, {}, {}, - {}, {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1), + {S({5, 3}), S({4}), S({3})}, {}, {}, {}, {}); EXPECT_EQ(3, c.num_inputs()); EXPECT_EQ(1, c.num_outputs()); @@ -839,8 +872,8 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidNumElements) { TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidRank) { NodeDef def; - InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({4})}, {}, {}, - {}, {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1), + {S({5, 3}), S({5}), S({4})}, {}, {}, {}, {}); EXPECT_EQ(3, c.num_inputs()); EXPECT_EQ(1, c.num_outputs()); @@ -853,8 +886,8 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidRank) { TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumIndexElements) { NodeDef def; - InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, 3}), S({5}), S({3})}, {}, - {}, {}, {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1), + {S({-1, 3}), S({5}), S({3})}, {}, {}, {}, {}); EXPECT_EQ(3, c.num_inputs()); EXPECT_EQ(1, c.num_outputs()); @@ -866,8 +899,8 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumIndexElements) { TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumValueElements) { NodeDef def; - InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({-1}), S({3})}, {}, - {}, {}, {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1), + {S({5, 3}), S({-1}), S({3})}, {}, {}, {}, {}); EXPECT_EQ(3, c.num_inputs()); EXPECT_EQ(1, c.num_outputs()); @@ -879,8 +912,8 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumValueElements) { TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownIndexRank) { NodeDef def; - InferenceContext c(&def, MakeOpDef(3, 1), {S({5, -1}), S({5}), S({3})}, {}, - {}, {}, {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1), + {S({5, -1}), S({5}), S({3})}, {}, {}, {}, {}); EXPECT_EQ(3, c.num_inputs()); EXPECT_EQ(1, c.num_outputs()); @@ -892,8 +925,8 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownIndexRank) { TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapeRank) { NodeDef def; - InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({-1})}, {}, - {}, {}, {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1), + {S({5, 3}), S({5}), S({-1})}, {}, {}, {}, {}); EXPECT_EQ(3, c.num_inputs()); EXPECT_EQ(1, c.num_outputs()); @@ -905,8 +938,8 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapeRank) { TEST(CommonShapeFnsTest, ValidateSparseTensor) { NodeDef def; - InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({3})}, {}, {}, - {}, {}); + InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1), + {S({5, 3}), S({5}), S({3})}, {}, {}, {}, {}); EXPECT_EQ(3, c.num_inputs()); EXPECT_EQ(1, c.num_outputs()); diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index 00f2c3407a..cbfa9bd20c 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -29,13 +29,14 @@ constexpr int32 InferenceContext::kUnknownRank; constexpr int64 InferenceContext::kUnknownDim; InferenceContext::InferenceContext( - const NodeDef* node_def, const OpDef& op_def, + int graph_def_version, const NodeDef* node_def, const OpDef& op_def, const std::vector<TensorShapeProto>& input_shapes, const std::vector<const Tensor*>& input_tensors, const std::vector<TensorShapeProto>& input_tensors_as_shapes, const std::vector<TensorShapeProto>& input_handle_shapes, const std::vector<DataType>& input_handle_dtypes) - : node_def_(*CHECK_NOTNULL(node_def)) { + : graph_def_version_(graph_def_version), + node_def_(*CHECK_NOTNULL(node_def)) { std::vector<ShapeHandle> input_tensors_as_shape_handles; for (const TensorShapeProto& p : input_tensors_as_shapes) { ShapeHandle shape; @@ -68,13 +69,14 @@ InferenceContext::InferenceContext( } InferenceContext::InferenceContext( - const NodeDef* node_def, const OpDef& op_def, + int graph_def_version, const NodeDef* node_def, const OpDef& op_def, const std::vector<ShapeHandle>& input_shapes, const std::vector<const Tensor*>& input_tensors, const std::vector<ShapeHandle>& input_tensors_as_shapes, const std::vector<ShapeHandle>& input_handle_shapes, const std::vector<DataType>& input_handle_dtypes) - : node_def_(*CHECK_NOTNULL(node_def)) { + : graph_def_version_(graph_def_version), + node_def_(*CHECK_NOTNULL(node_def)) { PreInputInit(op_def, input_tensors, input_tensors_as_shapes); if (!construction_status_.ok()) return; inputs_ = input_shapes; diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index fd4e25c728..dba8d30302 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -144,7 +144,8 @@ class InferenceContext { // Values of <input_tensors_as_shapes> do not need to outlive the context. // // REQUIRES: <node_def> is not NULL, and must outlive the InferenceContext. - InferenceContext(const NodeDef* node_def, const OpDef& op_def, + InferenceContext(int graph_def_version, const NodeDef* node_def, + const OpDef& op_def, const std::vector<ShapeHandle>& input_shapes, const std::vector<const Tensor*>& input_tensors, const std::vector<ShapeHandle>& input_tensors_as_shapes, @@ -161,7 +162,8 @@ class InferenceContext { // Values of <input_tensors_as_shapes> do not need to outlive the context. // // REQUIRES: <node_def> is not NULL, and must outlive the InferenceContext. - InferenceContext(const NodeDef* node_def, const OpDef& op_def, + InferenceContext(int graph_def_version, const NodeDef* node_def, + const OpDef& op_def, const std::vector<TensorShapeProto>& input_shapes, const std::vector<const Tensor*>& input_tensors, const std::vector<TensorShapeProto>& input_tensors_as_shapes, @@ -436,6 +438,8 @@ class InferenceContext { Status MakeShapeFromTensor(const Tensor* t, ShapeHandle tensor_shape, ShapeHandle* out); + int graph_def_version() const { return graph_def_version_; } + private: // Creates and stores shapes for use in InferenceContext. class ShapeManager { @@ -508,6 +512,7 @@ class InferenceContext { std::vector<ShapeHandle> output_handle_shape_; std::vector<DataType> output_handle_dtype_; + const int graph_def_version_; const NodeDef& node_def_; NameRangeMap input_name_map_; NameRangeMap output_name_map_; diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc index 9f363d50b3..9fc068aebe 100644 --- a/tensorflow/core/framework/shape_inference_test.cc +++ b/tensorflow/core/framework/shape_inference_test.cc @@ -61,6 +61,8 @@ class ShapeInferenceTest : public ::testing::Test { bool SameHandle(ShapeHandle a, ShapeHandle b) { return a.SameHandle(b); } bool IsSet(DimensionHandle d) { return d.IsSet(); } bool IsSet(ShapeHandle s) { return s.IsSet(); } + + static const int kVersion = 0; // used for graph-def version. }; TEST_F(ShapeInferenceTest, InputOutputByName) { @@ -71,8 +73,8 @@ TEST_F(ShapeInferenceTest, InputOutputByName) { .Attr("N", 3) .Input(FakeInput(DT_FLOAT)) .Finalize(&def); - InferenceContext c(&def, op_def, {S({1, 5}), S({2, 5}), S({1, 3})}, {}, {}, - {}, {}); + InferenceContext c(kVersion, &def, op_def, {S({1, 5}), S({2, 5}), S({1, 3})}, + {}, {}, {}, {}); EXPECT_EQ("5", c.DebugString(c.NumElements(c.input(0)))); EXPECT_EQ("10", c.DebugString(c.NumElements(c.input(1)))); @@ -108,7 +110,8 @@ static OpDef MakeOpDef(int num_inputs, int num_outputs) { TEST_F(ShapeInferenceTest, DimensionOrConstant) { NodeDef def; - InferenceContext c(&def, MakeOpDef(1, 1), {Unknown()}, {}, {}, {}, {}); + InferenceContext c(kVersion, &def, MakeOpDef(1, 1), {Unknown()}, {}, {}, {}, + {}); EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(InferenceContext::kUnknownDim)); EXPECT_EQ(1, c.Value(1)); @@ -123,7 +126,7 @@ TEST_F(ShapeInferenceTest, Run) { NodeDef def; def.set_name("foo"); def.set_op("foo_op"); - InferenceContext c(&def, MakeOpDef(1, 2), {S({1})}, {}, {}, {}, {}); + InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1})}, {}, {}, {}, {}); TF_ASSERT_OK(c.construction_status()); { @@ -160,7 +163,8 @@ TEST_F(ShapeInferenceTest, AttachContext) { def.set_op("foo_op"); // Error when no constant tensors were requested. { - InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, 3})}, {}, {}, {}, {}); + InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, 3})}, {}, {}, + {}, {}); TF_ASSERT_OK(c.construction_status()); auto fn = [](InferenceContext* c) { ShapeHandle h; @@ -178,8 +182,9 @@ TEST_F(ShapeInferenceTest, AttachContext) { { Tensor input_t = ::tensorflow::test::AsTensor<float>({1.1, 2.2, 3.3, 4.4, 5.5}); - InferenceContext c(&def, MakeOpDef(2, 2), {S({1, 2, 3}), S({4, 5})}, - {nullptr, &input_t}, {}, {}, {}); + InferenceContext c(kVersion, &def, MakeOpDef(2, 2), + {S({1, 2, 3}), S({4, 5})}, {nullptr, &input_t}, {}, {}, + {}); TF_ASSERT_OK(c.construction_status()); auto fn = [](InferenceContext* c) { c->input_tensor(0); // get this one, but it's null - won't be in error. @@ -200,7 +205,7 @@ TEST_F(ShapeInferenceTest, AttachContext) { // shapes provided. { Tensor input_t = ::tensorflow::test::AsTensor<int32>({1, 2, 3, 4, 5}); - InferenceContext c(&def, MakeOpDef(2, 2), {S({3}), S({4})}, + InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({3}), S({4})}, {nullptr, &input_t}, {}, {}, {}); TF_ASSERT_OK(c.construction_status()); auto fn = [](InferenceContext* c) { @@ -223,7 +228,7 @@ TEST_F(ShapeInferenceTest, AttachContext) { // shape was provided. { Tensor input_t = ::tensorflow::test::AsTensor<int32>({1, 2, 3, 4, 5}); - InferenceContext c(&def, MakeOpDef(2, 2), {S({3}), S({4})}, + InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({3}), S({4})}, {nullptr, &input_t}, {S({10, -1, 5}), Unknown()}, {}, {}); TF_ASSERT_OK(c.construction_status()); @@ -247,8 +252,8 @@ TEST_F(ShapeInferenceTest, AttachContext) { TEST_F(ShapeInferenceTest, RankAndDimInspection) { NodeDef def; - InferenceContext c(&def, MakeOpDef(3, 2), {Unknown(), S({1, -1, 3}), S({})}, - {}, {}, {}, {}); + InferenceContext c(kVersion, &def, MakeOpDef(3, 2), + {Unknown(), S({1, -1, 3}), S({})}, {}, {}, {}, {}); EXPECT_EQ(3, c.num_inputs()); EXPECT_EQ(2, c.num_outputs()); @@ -288,7 +293,7 @@ TEST_F(ShapeInferenceTest, RankAndDimInspection) { TEST_F(ShapeInferenceTest, NumElements) { NodeDef def; - InferenceContext c(&def, MakeOpDef(3, 2), + InferenceContext c(kVersion, &def, MakeOpDef(3, 2), {Unknown(), S({1, -1, 3}), S({5, 4, 3, 2})}, {}, {}, {}, {}); @@ -303,8 +308,8 @@ TEST_F(ShapeInferenceTest, NumElements) { TEST_F(ShapeInferenceTest, WithRank) { NodeDef def; - InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {}, - {}, {}); + InferenceContext c(kVersion, &def, MakeOpDef(2, 2), + {Unknown(), S({1, -1, 3})}, {}, {}, {}, {}); auto in0 = c.input(0); auto in1 = c.input(1); @@ -342,8 +347,8 @@ TEST_F(ShapeInferenceTest, WithRank) { TEST_F(ShapeInferenceTest, WithRankAtMost) { NodeDef def; - InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {}, - {}, {}); + InferenceContext c(kVersion, &def, MakeOpDef(2, 2), + {Unknown(), S({1, -1, 3})}, {}, {}, {}, {}); auto in0 = c.input(0); auto in1 = c.input(1); @@ -381,8 +386,8 @@ TEST_F(ShapeInferenceTest, WithRankAtMost) { TEST_F(ShapeInferenceTest, WithRankAtLeast) { NodeDef def; - InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {}, - {}, {}); + InferenceContext c(kVersion, &def, MakeOpDef(2, 2), + {Unknown(), S({1, -1, 3})}, {}, {}, {}, {}); auto in0 = c.input(0); auto in1 = c.input(1); @@ -420,7 +425,8 @@ TEST_F(ShapeInferenceTest, WithRankAtLeast) { TEST_F(ShapeInferenceTest, WithValue) { NodeDef def; - InferenceContext c(&def, MakeOpDef(1, 2), {S({1, -1})}, {}, {}, {}, {}); + InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, -1})}, {}, {}, {}, + {}); auto d0 = c.Dim(c.input(0), 0); auto d1 = c.Dim(c.input(0), 1); @@ -461,8 +467,8 @@ TEST_F(ShapeInferenceTest, WithValue) { TEST_F(ShapeInferenceTest, MergeDim) { NodeDef def; - InferenceContext c(&def, MakeOpDef(1, 2), {S({2, -1, 2, 1, -1})}, {}, {}, {}, - {}); + InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({2, -1, 2, 1, -1})}, + {}, {}, {}, {}); auto d2 = c.Dim(c.input(0), 0); auto d_unknown = c.Dim(c.input(0), 1); @@ -508,7 +514,7 @@ TEST_F(ShapeInferenceTest, MergeDim) { TEST_F(ShapeInferenceTest, MergeShape) { NodeDef def; - InferenceContext c(&def, MakeOpDef(7, 2), + InferenceContext c(kVersion, &def, MakeOpDef(7, 2), {Unknown(), S({1, 2}), S({-1, 2}), S({1, -1}), S({1, 3}), Unknown(), S({1})}, {}, {}, {}, {}); @@ -578,7 +584,7 @@ TEST_F(ShapeInferenceTest, MergeShape) { TEST_F(ShapeInferenceTest, MergePrefix) { NodeDef def; - InferenceContext c(&def, MakeOpDef(4, 2), + InferenceContext c(kVersion, &def, MakeOpDef(4, 2), { Unknown(), S({-1, 2}), S({1, -1, 3}), S({2, 4}), }, @@ -634,8 +640,8 @@ TEST_F(ShapeInferenceTest, MergePrefix) { TEST_F(ShapeInferenceTest, Subshape) { NodeDef def; - InferenceContext c(&def, MakeOpDef(2, 2), {S({1, 2, 3, -1, 5}), Unknown()}, - {}, {}, {}, {}); + InferenceContext c(kVersion, &def, MakeOpDef(2, 2), + {S({1, 2, 3, -1, 5}), Unknown()}, {}, {}, {}, {}); ShapeHandle unknown = c.input(1); ShapeHandle out; @@ -709,7 +715,7 @@ TEST_F(ShapeInferenceTest, Subshape) { TEST_F(ShapeInferenceTest, Concatenate) { NodeDef def; - InferenceContext c(&def, MakeOpDef(3, 2), + InferenceContext c(kVersion, &def, MakeOpDef(3, 2), {S({1, -1, 3}), S({4, 5}), Unknown()}, {}, {}, {}, {}); auto in0 = c.input(0); @@ -736,8 +742,8 @@ TEST_F(ShapeInferenceTest, Concatenate) { TEST_F(ShapeInferenceTest, ReplaceDim) { NodeDef def; - InferenceContext c(&def, MakeOpDef(2, 0), {S({1, 2, 3}), Unknown()}, {}, {}, - {}, {}); + InferenceContext c(kVersion, &def, MakeOpDef(2, 0), {S({1, 2, 3}), Unknown()}, + {}, {}, {}, {}); auto in = c.input(0); auto unknown = c.input(1); @@ -768,8 +774,8 @@ TEST_F(ShapeInferenceTest, ReplaceDim) { TEST_F(ShapeInferenceTest, MakeShape) { NodeDef def; - InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, 3, -1, 5})}, {}, {}, {}, - {}); + InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, 3, -1, 5})}, {}, + {}, {}, {}); std::vector<DimensionHandle> dims; auto in0 = c.input(0); @@ -794,7 +800,7 @@ TEST_F(ShapeInferenceTest, MakeShape) { TEST_F(ShapeInferenceTest, UnknownShape) { NodeDef def; std::vector<ShapeHandle> empty; - InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {}); + InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {}); auto u0 = c.UnknownShape(); auto u1 = c.UnknownShape(); @@ -806,7 +812,7 @@ TEST_F(ShapeInferenceTest, UnknownShape) { TEST_F(ShapeInferenceTest, Scalar) { NodeDef def; std::vector<ShapeHandle> empty; - InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {}); + InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {}); auto s0 = c.Scalar(); EXPECT_EQ("[]", c.DebugString(s0)); @@ -817,7 +823,7 @@ TEST_F(ShapeInferenceTest, Scalar) { TEST_F(ShapeInferenceTest, Vector) { NodeDef def; std::vector<ShapeHandle> empty; - InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {}); + InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {}); auto s0 = c.Vector(1); EXPECT_EQ("[1]", c.DebugString(s0)); @@ -833,7 +839,7 @@ TEST_F(ShapeInferenceTest, Vector) { TEST_F(ShapeInferenceTest, Matrix) { NodeDef def; std::vector<ShapeHandle> empty; - InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {}); + InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {}); auto s0 = c.Matrix(1, 2); EXPECT_EQ("[1,2]", c.DebugString(s0)); @@ -855,7 +861,8 @@ TEST_F(ShapeInferenceTest, Matrix) { TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) { auto create = [&](Tensor* t) { NodeDef def; - InferenceContext c(&def, MakeOpDef(1, 0), {Unknown()}, {t}, {}, {}, {}); + InferenceContext c(kVersion, &def, MakeOpDef(1, 0), {Unknown()}, {t}, {}, + {}, {}); ShapeHandle out; Status s = c.MakeShapeFromShapeTensor(0, &out); if (s.ok()) { @@ -907,8 +914,8 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) { // Test when the input shape is wrong. { NodeDef def; - InferenceContext c(&def, MakeOpDef(1, 0), {S({1, -1})}, {nullptr}, {}, {}, - {}); + InferenceContext c(kVersion, &def, MakeOpDef(1, 0), {S({1, -1})}, {nullptr}, + {}, {}, {}); ShapeHandle out; EXPECT_EQ("Shape must be rank 1 but is rank 2", c.MakeShapeFromShapeTensor(0, &out).error_message()); @@ -918,7 +925,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) { TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) { NodeDef def; std::vector<ShapeHandle> empty; - InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {}); + InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {}); TensorShapeProto proto; // With a set unknown rank. @@ -954,7 +961,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) { TEST_F(ShapeInferenceTest, MakeDim) { NodeDef def; std::vector<ShapeHandle> empty; - InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {}); + InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {}); auto d0 = c.MakeDim(1); auto d1 = c.MakeDim(1); @@ -968,7 +975,7 @@ TEST_F(ShapeInferenceTest, MakeDim) { TEST_F(ShapeInferenceTest, UnknownDim) { NodeDef def; std::vector<ShapeHandle> empty; - InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {}); + InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {}); auto d0 = c.UnknownDim(); auto d1 = c.UnknownDim(); @@ -980,7 +987,7 @@ TEST_F(ShapeInferenceTest, UnknownDim) { TEST_F(ShapeInferenceTest, UnknownShapeOfRank) { NodeDef def; std::vector<ShapeHandle> empty; - InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {}); + InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {}); auto unknown_shape_of_rank_3 = c.UnknownShapeOfRank(3); EXPECT_EQ("[?,?,?]", c.DebugString(unknown_shape_of_rank_3)); @@ -993,7 +1000,7 @@ TEST_F(ShapeInferenceTest, InputTensors) { const Tensor t1 = tensorflow::test::AsTensor<float>({10}); const Tensor t2 = tensorflow::test::AsTensor<float>({20, 30}); NodeDef def; - InferenceContext c(&def, MakeOpDef(3, 2), {S({1}), S({2}), S({3})}, + InferenceContext c(kVersion, &def, MakeOpDef(3, 2), {S({1}), S({2}), S({3})}, {&t1, &t2}, {}, {}, {}); EXPECT_TRUE(c.input_tensor(0) == &t1); @@ -1005,8 +1012,8 @@ TEST_F(ShapeInferenceTest, MakeDimForScalarInput) { Tensor t1 = tensorflow::test::AsScalar<int32>(20); Tensor t2 = tensorflow::test::AsScalar<int32>(-1); NodeDef def; - InferenceContext c(&def, MakeOpDef(2, 2), {S({}), S({})}, {&t1, &t2}, {}, {}, - {}); + InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({}), S({})}, + {&t1, &t2}, {}, {}, {}); DimensionHandle d; EXPECT_TRUE(c.MakeDimForScalarInput(0, &d).ok()); @@ -1037,7 +1044,7 @@ TEST_F(ShapeInferenceTest, GetAttr) { .ok()); std::vector<ShapeHandle> empty; - InferenceContext c(&def, op_reg_data.op_def, empty, {}, {}, {}, {}); + InferenceContext c(kVersion, &def, op_reg_data.op_def, empty, {}, {}, {}, {}); string value; EXPECT_TRUE(c.GetAttr("foo", &value).ok()); EXPECT_EQ("bar", value); @@ -1045,8 +1052,8 @@ TEST_F(ShapeInferenceTest, GetAttr) { TEST_F(ShapeInferenceTest, Divide) { NodeDef def; - InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 1, 2, 0})}, {}, {}, {}, - {}); + InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 1, 2, 0})}, {}, + {}, {}, {}); auto s = c.input(0); auto d_6 = c.Dim(s, 0); @@ -1108,7 +1115,8 @@ TEST_F(ShapeInferenceTest, Divide) { TEST_F(ShapeInferenceTest, Add) { NodeDef def; - InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0})}, {}, {}, {}, {}); + InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0})}, {}, {}, + {}, {}); auto s = c.input(0); auto d_6 = c.Dim(s, 0); @@ -1159,7 +1167,8 @@ TEST_F(ShapeInferenceTest, Add) { TEST_F(ShapeInferenceTest, Subtract) { NodeDef def; - InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0, 5})}, {}, {}, {}, {}); + InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0, 5})}, {}, + {}, {}, {}); auto s = c.input(0); auto d_6 = c.Dim(s, 0); @@ -1208,7 +1217,8 @@ TEST_F(ShapeInferenceTest, Subtract) { TEST_F(ShapeInferenceTest, Multiply) { NodeDef def; - InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0, 1})}, {}, {}, {}, {}); + InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0, 1})}, {}, + {}, {}, {}); auto s = c.input(0); auto d_6 = c.Dim(s, 0); @@ -1261,7 +1271,7 @@ TEST_F(ShapeInferenceTest, Multiply) { TEST_F(ShapeInferenceTest, FullyDefined) { NodeDef def; std::vector<ShapeHandle> empty; - InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {}); + InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {}); // No rank or missing dimension information should return false. EXPECT_FALSE(c.FullyDefined(c.UnknownShape())); @@ -1274,7 +1284,8 @@ TEST_F(ShapeInferenceTest, FullyDefined) { TEST_F(ShapeInferenceTest, Min) { NodeDef def; - InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, -1, 0})}, {}, {}, {}, {}); + InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, -1, 0})}, {}, + {}, {}, {}); auto s = c.input(0); auto d_1 = c.Dim(s, 0); @@ -1322,7 +1333,8 @@ TEST_F(ShapeInferenceTest, Min) { TEST_F(ShapeInferenceTest, Max) { NodeDef def; - InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, -1})}, {}, {}, {}, {}); + InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, -1})}, {}, {}, + {}, {}); auto s = c.input(0); auto d_1 = c.Dim(s, 0); diff --git a/tensorflow/core/framework/shape_inference_testutil.cc b/tensorflow/core/framework/shape_inference_testutil.cc index a225824f82..85e085af99 100644 --- a/tensorflow/core/framework/shape_inference_testutil.cc +++ b/tensorflow/core/framework/shape_inference_testutil.cc @@ -43,8 +43,9 @@ Status ShapeInferenceTestutil::InferShapes(ShapeInferenceTestOp op, in_shapes.push_back(shape); } - shape_inference::InferenceContext c(&op.node_def, op_reg_data->op_def, - in_shapes, op.input_tensors, {}, {}, {}); + shape_inference::InferenceContext c(op.graph_def_version, &op.node_def, + op_reg_data->op_def, in_shapes, + op.input_tensors, {}, {}, {}); TF_RETURN_IF_ERROR(c.construction_status()); if (op_reg_data->shape_inference_fn == nullptr) { return errors::InvalidArgument( diff --git a/tensorflow/core/framework/shape_inference_testutil.h b/tensorflow/core/framework/shape_inference_testutil.h index 64067464fb..996281e70e 100644 --- a/tensorflow/core/framework/shape_inference_testutil.h +++ b/tensorflow/core/framework/shape_inference_testutil.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/public/version.h" // Contains utilities for writing tests for shape inference functions. @@ -34,6 +35,7 @@ struct ShapeInferenceTestOp { string name; NodeDef node_def; std::vector<const Tensor*> input_tensors; + int graph_def_version = TF_GRAPH_DEF_VERSION; }; namespace shape_inference { diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index c68ac37fa8..a83cf26723 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -745,8 +745,8 @@ Status GraphConstructor::UpdateVersionDef() { return Status::OK(); } VersionDef versions = g_->versions(); - // This new graph is being "produced" by the binary invoking ImportGraphDef. - versions.set_producer(TF_GRAPH_DEF_VERSION); + versions.set_producer( + std::min(versions.producer(), gdef_->versions().producer())); versions.set_min_consumer( std::max(versions.min_consumer(), gdef_->versions().min_consumer())); if (gdef_->versions().bad_consumers_size() > 0) { @@ -820,14 +820,14 @@ Status GraphConstructor::MakeEdge(Node* src, int output_index, Node* dst, Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts, const GraphDef& gdef, Graph* g) { - ShapeRefiner refiner(g->op_registry()); + ShapeRefiner refiner(gdef.versions().producer(), g->op_registry()); return GraphConstructor::Construct(opts, &gdef, g, &refiner, nullptr); } Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef, Graph* g, ShapeRefiner* refiner, std::vector<std::pair<Node*, int>>* return_tensors) { - ShapeRefiner default_refiner(g->op_registry()); + ShapeRefiner default_refiner(gdef.versions().producer(), g->op_registry()); if (refiner == nullptr) { refiner = &default_refiner; } diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc index e20d89485d..02f614dad2 100644 --- a/tensorflow/core/graph/graph_constructor_test.cc +++ b/tensorflow/core/graph/graph_constructor_test.cc @@ -203,6 +203,15 @@ REGISTER_OP("TestOneInputOneOutput") REGISTER_OP("TestDefaultAttr") .Attr("default_int: int=31415") .SetShapeFn(shape_inference::NoOutputs); +REGISTER_OP("RequiresCurrentGraphVersion") + .Output("version: int32") + .SetIsStateful() + .SetShapeFn([](shape_inference::InferenceContext* c) { + if (c->graph_def_version() != TF_GRAPH_DEF_VERSION) { + return errors::InvalidArgument("Wrong graph version for shape"); + } + return shape_inference::ScalarShape(c); + }); TEST_F(GraphConstructorTest, InvalidNodeName) { auto expect_invalid_name = [this](const char* name) { @@ -1052,7 +1061,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ShapeWhitelist) { } TEST_F(GraphConstructorTest, ImportGraphDef_InputMap) { - ShapeRefiner refiner(graph_.op_registry()); + ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry()); // Populate graph with node we'll use in input map ExpectOK("node { name: 'input' op: 'TestInput' }", ImportGraphDefOptions(), @@ -1092,7 +1101,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMap) { } TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithPrefix) { - ShapeRefiner refiner(graph_.op_registry()); + ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry()); // Populate graph with node we'll use in input map ExpectOK( @@ -1155,7 +1164,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithPrefix) { } TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithControlEdges) { - ShapeRefiner refiner(graph_.op_registry()); + ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry()); // Populate graph with node we'll use in input map ExpectOK("node { name: 'W1' op: 'TestParams' }", ImportGraphDefOptions(), @@ -1219,7 +1228,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithControlEdges) { } TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithBadControlEdge) { - ShapeRefiner refiner(graph_.op_registry()); + ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry()); // Populate graph with node we'll use in input map ExpectOK("node { name: 'W1' op: 'TestParams' }", ImportGraphDefOptions(), @@ -1251,7 +1260,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithBadControlEdge) { } TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithInvalidNodeIndex) { - ShapeRefiner refiner(graph_.op_registry()); + ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry()); // Populate graph with node we'll use in input map ExpectOK("node { name: 'input1' op: 'TestInput' }", ImportGraphDefOptions(), @@ -1272,7 +1281,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithInvalidNodeIndex) { } TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithMissingEntries) { - ShapeRefiner refiner(graph_.op_registry()); + ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry()); // Populate graph with node we'll use in input map ExpectOK("node { name: 'W1' op: 'TestParams' }", ImportGraphDefOptions(), @@ -1293,7 +1302,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithMissingEntries) { } TEST_F(GraphConstructorTest, ImportGraphDef_InputMapDuplicateNodeNames) { - ShapeRefiner refiner(graph_.op_registry()); + ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry()); // Add two nodes with the same name to graph Node* node; @@ -1318,7 +1327,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapDuplicateNodeNames) { } TEST_F(GraphConstructorTest, ImportGraphDef_ReturnTensors) { - ShapeRefiner refiner(graph_.op_registry()); + ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry()); ImportGraphDefOptions opts; opts.return_tensors.push_back({"input", 1}); @@ -1634,7 +1643,7 @@ versions { } TEST_F(GraphConstructorTest, ImportGraphDef_ControlDeps) { - ShapeRefiner refiner(graph_.op_registry()); + ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry()); // Populate graph with nodes we'll use in control deps and input map ExpectOK( @@ -1701,7 +1710,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ControlDeps) { } TEST_F(GraphConstructorTest, ImportGraphDef_ControlDepsWithCycle) { - ShapeRefiner refiner(graph_.op_registry()); + ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry()); // Populate graph with nodes we'll use in control deps and input map ExpectOK( @@ -1939,5 +1948,56 @@ TEST_F(GraphConstructorTest, CopyGraph) { EXPECT_EQ(dst.versions().bad_consumers(0), bad); } +// Confirms that graph def version in the graph reaches the shape inference +// function. +TEST_F(GraphConstructorTest, GraphDefVersionUsedForShapeInference) { + string gdef_ascii = strings::StrCat(R"EOF( + node{ name:"A" op:"RequiresCurrentGraphVersion" } + versions { producer: )EOF", + TF_GRAPH_DEF_VERSION - 1, "}"); + ImportGraphDefOptions opts; + ExpectError(gdef_ascii, opts, {"Wrong graph version for shape"}); + gdef_ascii = strings::StrCat(R"EOF( + node{ name:"A" op:"RequiresCurrentGraphVersion" } + versions { producer: )EOF", + TF_GRAPH_DEF_VERSION, "}"); + ExpectOK(gdef_ascii, opts); +} + +TEST_F(GraphConstructorTest, GraphDefVersionMergingDuringImport) { + ImportGraphDefOptions opts; + ExpectOK( + "versions { producer: 15 min_consumer: 5 bad_consumers: 2 bad_consumers: " + "3 " + "}", + opts); + EXPECT_EQ(15, graph_.versions().producer()); + EXPECT_EQ(5, graph_.versions().min_consumer()); + ASSERT_EQ(2, graph_.versions().bad_consumers_size()); + EXPECT_EQ(2, graph_.versions().bad_consumers(0)); + EXPECT_EQ(3, graph_.versions().bad_consumers(1)); + + ExpectOK( + "versions { producer: 10 min_consumer: 8 bad_consumers: 1 bad_consumers: " + "3 " + "}", + opts); + EXPECT_EQ(10, graph_.versions().producer()); + EXPECT_EQ(8, graph_.versions().min_consumer()); + ASSERT_EQ(3, graph_.versions().bad_consumers_size()); + EXPECT_EQ(1, graph_.versions().bad_consumers(0)); + EXPECT_EQ(2, graph_.versions().bad_consumers(1)); + EXPECT_EQ(3, graph_.versions().bad_consumers(2)); + + // This one is a no-op. + ExpectOK("versions { producer: 20 min_consumer: 7 }", opts); + EXPECT_EQ(10, graph_.versions().producer()); + EXPECT_EQ(8, graph_.versions().min_consumer()); + ASSERT_EQ(3, graph_.versions().bad_consumers_size()); + EXPECT_EQ(1, graph_.versions().bad_consumers(0)); + EXPECT_EQ(2, graph_.versions().bad_consumers(1)); + EXPECT_EQ(3, graph_.versions().bad_consumers(2)); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/core/kernels/hexagon/graph_transferer.cc b/tensorflow/core/kernels/hexagon/graph_transferer.cc index 94fa4257ae..c87aa82534 100644 --- a/tensorflow/core/kernels/hexagon/graph_transferer.cc +++ b/tensorflow/core/kernels/hexagon/graph_transferer.cc @@ -70,7 +70,7 @@ Status GraphTransferer::LoadGraphFromProto( const OutputTensorMap& output_tensor_map) { ImportGraphDefOptions opts; Graph graph(OpRegistry::Global()); - ShapeRefiner shape_refiner(graph.op_registry()); + ShapeRefiner shape_refiner(graph.versions().producer(), graph.op_registry()); VLOG(1) << "Start import graph"; Status status = ImportGraphDef(opts, graph_def, &graph, &shape_refiner); if (!status.ok()) { diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc index 2f12afc9c7..2ed8db4a3f 100644 --- a/tensorflow/core/ops/array_ops_test.cc +++ b/tensorflow/core/ops/array_ops_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/version.h" namespace tensorflow { @@ -161,9 +162,9 @@ TEST(ArrayOpsTest, Identity_ShapeFnHandles) { // Check that handle dtypes are preserved. const OpRegistrationData* op_reg_data; TF_ASSERT_OK(OpRegistry::Global()->LookUp(op.name, &op_reg_data)); - shape_inference::InferenceContext c(&op.node_def, op_reg_data->op_def, - {TensorShapeProto()}, {}, {}, {}, - {DT_BOOL}); + shape_inference::InferenceContext c(TF_GRAPH_DEF_VERSION, &op.node_def, + op_reg_data->op_def, {TensorShapeProto()}, + {}, {}, {}, {DT_BOOL}); TF_ASSERT_OK(c.construction_status()); ASSERT_TRUE(op_reg_data->shape_inference_fn != nullptr); TF_ASSERT_OK(c.Run(op_reg_data->shape_inference_fn)); diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc index 84264f13dc..8881857b29 100644 --- a/tensorflow/core/ops/math_ops_test.cc +++ b/tensorflow/core/ops/math_ops_test.cc @@ -229,7 +229,7 @@ TEST(MathOpsTest, Select_ShapeFn) { ASSERT_TRUE(op_reg_data->shape_inference_fn != nullptr); shape_inference::InferenceContext c( - &op.node_def, op_reg_data->op_def, + TF_GRAPH_DEF_VERSION, &op.node_def, op_reg_data->op_def, {TensorShapeProto(), TensorShapeProto(), TensorShapeProto()}, {}, {}, {TensorShapeProto(), i0, i1}, {}); TF_ASSERT_OK(c.construction_status()); @@ -242,7 +242,7 @@ TEST(MathOpsTest, Select_ShapeFn) { i1.add_dim()->set_size(2); i1.add_dim()->set_size(2); shape_inference::InferenceContext c2( - &op.node_def, op_reg_data->op_def, + TF_GRAPH_DEF_VERSION, &op.node_def, op_reg_data->op_def, {TensorShapeProto(), TensorShapeProto(), TensorShapeProto()}, {}, {}, {TensorShapeProto(), i0, i2}, {}); TF_ASSERT_OK(c.construction_status()); diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 81d49684a8..f0859ed23f 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -79,6 +79,9 @@ limitations under the License. // used for tf.split, ReverseV2 is now used by tf.reverse, ConcatV2 is // now used by tf.concat. Graphs use flooring // division and mod semantics. TensorArrayV3. (12dec2016) +// Also considered the version for when it is required for reduction +// ops' indices to be scalar or vector, and not higher rank. +// Some earlier graph def versions allowed this. // 21. Dropped FunctionDef.Node support, switched to node_def introduced // in version 12. (11jan2017) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 6224bd8489..70a66e7a72 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -768,6 +768,7 @@ py_test( ":nn_grad", ":nn_ops", ":random_ops", + ":test_ops", ":variables", "//tensorflow/core:protos_all_py", "//third_party/py/numpy", diff --git a/tensorflow/python/framework/common_shapes.py b/tensorflow/python/framework/common_shapes.py index 487387cd83..5ec73afa99 100644 --- a/tensorflow/python/framework/common_shapes.py +++ b/tensorflow/python/framework/common_shapes.py @@ -635,6 +635,7 @@ def _call_cpp_shape_fn_impl( input_tensors_as_shapes_needed, debug_python_shape_fn, require_shape_fn): """Core implementaton of call_cpp_shape_fn.""" + graph_def_version = op.graph.graph_def_versions.producer node_def_str = op.node_def.SerializeToString() def tensor_to_inference_result(t): @@ -666,8 +667,8 @@ def _call_cpp_shape_fn_impl( try: with errors.raise_exception_on_not_ok_status() as status: output = pywrap_tensorflow.RunCppShapeInference( - node_def_str, input_shapes, input_tensors, input_tensors_as_shapes, - status) + graph_def_version, node_def_str, input_shapes, input_tensors, + input_tensors_as_shapes, status) except errors.InvalidArgumentError as err: if err.message.startswith("No shape inference function exists for op"): missing_shape_fn = True diff --git a/tensorflow/python/framework/cpp_shape_inference.cc b/tensorflow/python/framework/cpp_shape_inference.cc index cc08e3b705..e1fab4fc2d 100644 --- a/tensorflow/python/framework/cpp_shape_inference.cc +++ b/tensorflow/python/framework/cpp_shape_inference.cc @@ -47,7 +47,7 @@ void ProtoFromShapeHandle(tensorflow::shape_inference::ShapeHandle s, } Status RunCppShapeInferenceImpl( - const string& serialized_node_def, + int graph_def_version, const string& serialized_node_def, const std::vector<string>& input_serialized_shapes, const std::vector<PyObject*>& input_constant_tensor_values, const std::vector<string>& input_constant_tensor_as_shape_values, @@ -115,8 +115,9 @@ Status RunCppShapeInferenceImpl( // Run shape inference. tensorflow::shape_inference::InferenceContext c( - &node, op_reg_data->op_def, input_shapes, input_tensors, - input_tensor_as_shapes_protos, input_handle_shapes, input_handle_dtypes); + graph_def_version, &node, op_reg_data->op_def, input_shapes, + input_tensors, input_tensor_as_shapes_protos, input_handle_shapes, + input_handle_dtypes); TF_RETURN_IF_ERROR(c.construction_status()); TF_RETURN_IF_ERROR(c.Run(op_reg_data->shape_inference_fn)); @@ -151,7 +152,7 @@ Status RunCppShapeInferenceImpl( } // namespace std::vector<string> RunCppShapeInference( - const string& serialized_node_def, + int graph_def_version, const string& serialized_node_def, const std::vector<string>& input_serialized_shapes, PyObject* input_constant_tensor_values, const std::vector<string>& input_constant_tensor_as_shape_values, @@ -171,7 +172,7 @@ std::vector<string> RunCppShapeInference( std::vector<string> output; string input_tensors_needed_out; tensorflow::Status status = RunCppShapeInferenceImpl( - serialized_node_def, input_serialized_shapes, + graph_def_version, serialized_node_def, input_serialized_shapes, input_constant_tensor_values_v, input_constant_tensor_as_shape_values, &output, &input_tensors_needed_out); diff --git a/tensorflow/python/framework/cpp_shape_inference.h b/tensorflow/python/framework/cpp_shape_inference.h index 79b37aa6b4..afca7277c7 100644 --- a/tensorflow/python/framework/cpp_shape_inference.h +++ b/tensorflow/python/framework/cpp_shape_inference.h @@ -42,7 +42,7 @@ namespace swig { // This is temporary code to be used during the migration // from python shape inference functions to C++ shape inference functions. std::vector<string> RunCppShapeInference( - const string& serialized_node_def, + int graph_def_version, const string& serialized_node_def, const std::vector<string>& input_serialized_shapes, PyObject* input_constant_tensor_values, const std::vector<string>& input_constant_tensor_as_shape_values, diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py index c82bf16bb2..5e4d5bbecc 100644 --- a/tensorflow/python/framework/importer_test.py +++ b/tensorflow/python/framework/importer_test.py @@ -32,6 +32,7 @@ from tensorflow.python.framework import importer from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_ops # pylint: disable=unused-import from tensorflow.python.framework import versions from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients_impl @@ -845,6 +846,24 @@ class ImportGraphDefTest(test.TestCase): with self.assertRaisesRegexp(Exception, pat): sess.run(x) + def testVersionAppliesToOpConstruction(self): + """These tests rely on shape fns in test_ops.cc.""" + with ops.Graph().as_default(): + importer.import_graph_def( + self._MakeGraphDef( + "node { name: 'A' op: 'RequiresOlderGraphVersion' }", + producer=versions.GRAPH_DEF_VERSION - 1), + return_elements=["A"]) + + with ops.Graph().as_default(): + with self.assertRaisesWithPredicateMatch(ValueError, + "Wrong graph version.*"): + importer.import_graph_def( + self._MakeGraphDef( + "node { name: 'A' op: 'RequiresOlderGraphVersion' }", + producer=versions.GRAPH_DEF_VERSION), + return_elements=["A"]) + def testDefaultAttrsAdded(self): with ops.Graph().as_default(): a = importer.import_graph_def( diff --git a/tensorflow/python/framework/test_ops.cc b/tensorflow/python/framework/test_ops.cc index c19094847d..19f07fb754 100644 --- a/tensorflow/python/framework/test_ops.cc +++ b/tensorflow/python/framework/test_ops.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/framework/resource_handle.pb.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/public/version.h" namespace tensorflow { @@ -31,6 +32,16 @@ REGISTER_OP("GraphDefVersion") .SetIsStateful() .SetShapeFn(shape_inference::ScalarShape); +REGISTER_OP("RequiresOlderGraphVersion") + .Output("version: int32") + .SetIsStateful() + .SetShapeFn([](shape_inference::InferenceContext* c) { + if (c->graph_def_version() != TF_GRAPH_DEF_VERSION - 1) { + return errors::InvalidArgument("Wrong graph version for shape"); + } + return shape_inference::ScalarShape(c); + }); + REGISTER_OP("Old") .SetShapeFn(shape_inference::UnknownShape) .Deprecated(8, "For reasons"); diff --git a/tensorflow/python/kernel_tests/reduction_ops_test.py b/tensorflow/python/kernel_tests/reduction_ops_test.py index 0da5a2ecc5..316c23609c 100644 --- a/tensorflow/python/kernel_tests/reduction_ops_test.py +++ b/tensorflow/python/kernel_tests/reduction_ops_test.py @@ -241,6 +241,13 @@ class SumReductionTest(test.TestCase): c_unknown_indices, unknown_indices, keep_dims=True) self.assertEqual(2, s_unknown_indices_keep.get_shape().ndims) + def testWrongShapeForReductionIndices(self): + reduction_axes = [[1], [2]] + c_unknown = array_ops.placeholder(dtypes.float32) + with self.assertRaisesWithPredicateMatch(ValueError, + ".*must be at most rank 1.*"): + math_ops.reduce_sum(c_unknown, reduction_axes) + # Int64?? def _compareGradient(self, shape, sum_shape, reduction_axes): |