aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-03-03 14:14:16 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-03 14:26:26 -0800
commitb177e3720721dea593f1f15ba731ab351e87d298 (patch)
tree6a58e524addc6b2518df22fee094b9a712576910 /tensorflow
parent1dc89c1ab1bee51ae40f97994ef81ac3d6b1391c (diff)
Add the graphdef version to InferenceContext and to ShapeRefiner::AddNode.
Use this to allow loading reductions saved with older graphdefs. Change GraphConstructor to not increase the version when importing, but instead take the min of all versions. Change: 149152437
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/c/c_api.cc2
-rw-r--r--tensorflow/cc/framework/scope.cc3
-rw-r--r--tensorflow/core/common_runtime/shape_refiner.cc12
-rw-r--r--tensorflow/core/common_runtime/shape_refiner.h5
-rw-r--r--tensorflow/core/common_runtime/shape_refiner_test.cc49
-rw-r--r--tensorflow/core/framework/common_shape_fns.cc8
-rw-r--r--tensorflow/core/framework/common_shape_fns_test.cc135
-rw-r--r--tensorflow/core/framework/shape_inference.cc10
-rw-r--r--tensorflow/core/framework/shape_inference.h9
-rw-r--r--tensorflow/core/framework/shape_inference_test.cc118
-rw-r--r--tensorflow/core/framework/shape_inference_testutil.cc5
-rw-r--r--tensorflow/core/framework/shape_inference_testutil.h2
-rw-r--r--tensorflow/core/graph/graph_constructor.cc8
-rw-r--r--tensorflow/core/graph/graph_constructor_test.cc80
-rw-r--r--tensorflow/core/kernels/hexagon/graph_transferer.cc2
-rw-r--r--tensorflow/core/ops/array_ops_test.cc7
-rw-r--r--tensorflow/core/ops/math_ops_test.cc4
-rw-r--r--tensorflow/core/public/version.h3
-rw-r--r--tensorflow/python/BUILD1
-rw-r--r--tensorflow/python/framework/common_shapes.py5
-rw-r--r--tensorflow/python/framework/cpp_shape_inference.cc11
-rw-r--r--tensorflow/python/framework/cpp_shape_inference.h2
-rw-r--r--tensorflow/python/framework/importer_test.py19
-rw-r--r--tensorflow/python/framework/test_ops.cc11
-rw-r--r--tensorflow/python/kernel_tests/reduction_ops_test.py7
25 files changed, 344 insertions, 174 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 496ec8dc86..02aba54e43 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -730,7 +730,7 @@ extern "C" {
struct TF_Graph {
TF_Graph()
: graph(OpRegistry::Global()),
- refiner(graph.op_registry()),
+ refiner(graph.versions().producer(), graph.op_registry()),
num_sessions(0),
delete_requested(false),
parent(nullptr),
diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc
index c5629dbd6d..571c6e1e57 100644
--- a/tensorflow/cc/framework/scope.cc
+++ b/tensorflow/cc/framework/scope.cc
@@ -118,7 +118,8 @@ Scope::Impl::Impl(Graph* graph, Status* status, NameMap* name_map,
Scope Scope::NewRootScope() {
Graph* graph = new Graph(OpRegistry::Global());
- ShapeRefiner* refiner = new ShapeRefiner(graph->op_registry());
+ ShapeRefiner* refiner =
+ new ShapeRefiner(graph->versions().producer(), graph->op_registry());
return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner));
}
diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc
index 52c9ea182f..7288ecb143 100644
--- a/tensorflow/core/common_runtime/shape_refiner.cc
+++ b/tensorflow/core/common_runtime/shape_refiner.cc
@@ -31,8 +31,9 @@ using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
-ShapeRefiner::ShapeRefiner(const OpRegistryInterface* ops)
- : ops_registry_(ops) {}
+ShapeRefiner::ShapeRefiner(int graph_def_version,
+ const OpRegistryInterface* ops)
+ : graph_def_version_(graph_def_version), ops_registry_(ops) {}
Status ShapeRefiner::AddNode(const Node* node) {
// For each 'input' of this node, fetch the corresponding shape
@@ -85,9 +86,10 @@ Status ShapeRefiner::AddNode(const Node* node) {
std::vector<ShapeHandle> input_tensors_as_shapes;
// Create the inference context for this node with the existing input shapes.
- std::unique_ptr<InferenceContext> c(new InferenceContext(
- &node->def(), node->op_def(), input_shapes, input_tensors,
- input_tensors_as_shapes, input_handle_shapes, input_handle_dtypes));
+ std::unique_ptr<InferenceContext> c(
+ new InferenceContext(graph_def_version_, &node->def(), node->op_def(),
+ input_shapes, input_tensors, input_tensors_as_shapes,
+ input_handle_shapes, input_handle_dtypes));
if (!c->construction_status().ok()) {
return c->construction_status();
}
diff --git a/tensorflow/core/common_runtime/shape_refiner.h b/tensorflow/core/common_runtime/shape_refiner.h
index 43466727d9..b8d69fc05b 100644
--- a/tensorflow/core/common_runtime/shape_refiner.h
+++ b/tensorflow/core/common_runtime/shape_refiner.h
@@ -31,7 +31,7 @@ namespace tensorflow {
// construction time.
class ShapeRefiner {
public:
- explicit ShapeRefiner(const OpRegistryInterface* ops);
+ ShapeRefiner(int graph_def_version, const OpRegistryInterface* ops);
// Performs validation of 'node' and runs 'node's shape function,
// storing its shape outputs.
@@ -98,7 +98,8 @@ class ShapeRefiner {
const Node* node, int dst_idx,
shape_inference::ShapeHandle* result);
- const OpRegistryInterface* ops_registry_ = nullptr;
+ const int graph_def_version_;
+ const OpRegistryInterface* const ops_registry_;
// Stores a map from a node to its InferenceContext.
//
diff --git a/tensorflow/core/common_runtime/shape_refiner_test.cc b/tensorflow/core/common_runtime/shape_refiner_test.cc
index 54a00ac9ff..05274ff311 100644
--- a/tensorflow/core/common_runtime/shape_refiner_test.cc
+++ b/tensorflow/core/common_runtime/shape_refiner_test.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/version.h"
namespace tensorflow {
namespace {
@@ -38,14 +39,14 @@ TEST(ShapeRefinerTest, Constant) {
// and that its shape is correct.
Scope root = Scope::NewRootScope();
auto c = ops::Const(root, 42.0f);
- ShapeRefiner m(OpRegistry::Global());
+ ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
TF_ASSERT_OK(m.AddNode(c.node()));
EXPECT_SHAPE("[]", m, c, 0);
}
TEST(ShapeRefinerTest, MatMul) {
- ShapeRefiner m(OpRegistry::Global());
+ ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
Scope root = Scope::NewRootScope();
auto a = ops::Const(root, {{1.0f}, {2.0f}});
@@ -62,7 +63,7 @@ TEST(ShapeRefinerTest, MatMul) {
}
TEST(ShapeRefinerTest, InvalidOrder) {
- ShapeRefiner m(OpRegistry::Global());
+ ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
Scope root = Scope::NewRootScope();
auto a = ops::Const(root, {{1.0f}, {2.0f}});
auto b = ops::Const(root, {{1.0f, 2.0f}});
@@ -77,7 +78,7 @@ TEST(ShapeRefinerTest, InvalidOrder) {
}
TEST(ShapeRefinerTest, BadShapes) {
- ShapeRefiner m(OpRegistry::Global());
+ ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
Scope root = Scope::NewRootScope();
auto a = ops::Const(root, {{1.0f}, {2.0f}});
auto b = ops::Const(root, {{1.0f}, {2.0f}});
@@ -94,7 +95,7 @@ TEST(ShapeRefinerTest, BadShapes) {
}
TEST(ShapeRefinerTest, SetShape) {
- ShapeRefiner m(OpRegistry::Global());
+ ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
Scope root = Scope::NewRootScope();
auto a = ops::Placeholder(root, DT_FLOAT);
@@ -136,7 +137,7 @@ TEST(ShapeRefinerTest, PropagateConstants) {
auto dim = ops::Variable(root, {}, DT_INT32);
auto am = ops::ArgMax(root, input, dim);
- ShapeRefiner m(OpRegistry::Global());
+ ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
TF_ASSERT_OK(m.AddNode(input.node()));
TF_ASSERT_OK(m.AddNode(dim.node()));
TF_ASSERT_OK(m.AddNode(am.node()));
@@ -153,7 +154,7 @@ TEST(ShapeRefinerTest, PropagateConstants) {
auto dim = ops::Const(root, 1);
auto am = ops::ArgMax(root, input, dim);
- ShapeRefiner m(OpRegistry::Global());
+ ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
TF_ASSERT_OK(m.AddNode(input.node()));
TF_ASSERT_OK(m.AddNode(dim.node()));
TF_ASSERT_OK(m.AddNode(am.node()));
@@ -169,7 +170,7 @@ TEST(ShapeRefinerTest, PropagateConstants) {
auto dim = ops::Const(root, 0);
auto am = ops::ArgMax(root, input, dim);
- ShapeRefiner m(OpRegistry::Global());
+ ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
TF_ASSERT_OK(m.AddNode(input.node()));
TF_ASSERT_OK(m.AddNode(dim.node()));
TF_ASSERT_OK(m.AddNode(am.node()));
@@ -199,7 +200,7 @@ REGISTER_OP("TestOp")
} // namespace
TEST(ShapeRefinerTest, InputTensorDependencies) {
- ShapeRefiner m(OpRegistry::Global());
+ ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
Graph graph(OpRegistry::Global());
Node* node;
@@ -260,7 +261,7 @@ TEST(ShapeRefinerTest, PropagateShape) {
.Input(shape.node())
.Finalize(root.graph(), &shape_data));
- ShapeRefiner m(OpRegistry::Global());
+ ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
TF_ASSERT_OK(m.AddNode(input.node()));
TF_ASSERT_OK(m.AddNode(shape.node()));
TF_ASSERT_OK(m.AddNode(shape_data));
@@ -281,7 +282,7 @@ TEST(ShapeRefinerTest, PropagateSize) {
.Input(size.node())
.Finalize(root.graph(), &shape_data));
- ShapeRefiner m(OpRegistry::Global());
+ ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
TF_ASSERT_OK(m.AddNode(input.node()));
TF_ASSERT_OK(m.AddNode(size.node()));
TF_ASSERT_OK(m.AddNode(shape_data));
@@ -302,7 +303,7 @@ TEST(ShapeRefinerTest, PropagateRank) {
.Input(rank.node())
.Finalize(root.graph(), &shape_data));
- ShapeRefiner m(OpRegistry::Global());
+ ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
TF_ASSERT_OK(m.AddNode(input.node()));
TF_ASSERT_OK(m.AddNode(rank.node()));
TF_ASSERT_OK(m.AddNode(shape_data));
@@ -323,7 +324,7 @@ TEST(ShapeRefinerTest, PropagateRange) {
.Input(range.node())
.Finalize(root.graph(), &shape_data));
- ShapeRefiner m(OpRegistry::Global());
+ ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
TF_ASSERT_OK(m.AddNode(begin.node()));
TF_ASSERT_OK(m.AddNode(limit.node()));
TF_ASSERT_OK(m.AddNode(delta.node()));
@@ -346,7 +347,7 @@ TEST(ShapeRefinerTest, ConstantValueTwoInputsToSameNode) {
.Input(range.node())
.Finalize(root.graph(), &shape_data));
- ShapeRefiner m(OpRegistry::Global());
+ ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
TF_ASSERT_OK(m.AddNode(begin_and_delta.node()));
TF_ASSERT_OK(m.AddNode(limit.node()));
TF_ASSERT_OK(m.AddNode(range.node()));
@@ -381,7 +382,7 @@ TEST(ShapeRefinerTest, ConstantValueVisitNodeTwice) {
.Input(range.node())
.Finalize(root.graph(), &shape_data));
- ShapeRefiner m(OpRegistry::Global());
+ ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
TF_ASSERT_OK(m.AddNode(begin.node()));
TF_ASSERT_OK(m.AddNode(limit.node()));
TF_ASSERT_OK(m.AddNode(delta.node()));
@@ -477,7 +478,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_EmptyVector) {
.Input(input)
.Finalize(root.graph(), &result));
- ShapeRefiner m(OpRegistry::Global());
+ ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
TF_ASSERT_OK(m.AddNode(input));
TF_ASSERT_OK(m.AddNode(result));
@@ -498,7 +499,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_Shape) {
.Input(shape.node())
.Finalize(root.graph(), &result));
- ShapeRefiner m(OpRegistry::Global());
+ ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
TF_ASSERT_OK(m.AddNode(input));
TF_ASSERT_OK(m.AddNode(shape.node()));
TF_ASSERT_OK(m.AddNode(result));
@@ -533,7 +534,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_PackInt32) {
.Input(pack.node())
.Finalize(root.graph(), &result));
- ShapeRefiner m(OpRegistry::Global());
+ ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
for (auto input : inputs) {
TF_ASSERT_OK(m.AddNode(input.node()));
}
@@ -565,7 +566,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_PackInt64) {
.Input(pack.node())
.Finalize(root.graph(), &result));
- ShapeRefiner m(OpRegistry::Global());
+ ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
for (const auto& input : inputs) {
TF_ASSERT_OK(m.AddNode(input.node()));
}
@@ -591,7 +592,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_PackUnknownDim) {
.Input(pack.node())
.Finalize(root.graph(), &result));
- ShapeRefiner m(OpRegistry::Global());
+ ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
for (const auto& input : inputs) {
TF_ASSERT_OK(m.AddNode(input.node()));
}
@@ -618,7 +619,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_PackInvalidInput) {
.Input(pack.node())
.Finalize(root.graph(), &result));
- ShapeRefiner m(OpRegistry::Global());
+ ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
for (const auto& input : inputs) {
TF_ASSERT_OK(m.AddNode(input.node()));
}
@@ -650,7 +651,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_Concat) {
.Input(concat.node())
.Finalize(g, &result));
- ShapeRefiner m(OpRegistry::Global());
+ ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
TF_ASSERT_OK(m.AddNode(partial_1));
TF_ASSERT_OK(m.AddNode(partial_2));
for (const auto& o : concat_inputs) {
@@ -692,7 +693,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_ConcatWithUnknown) {
.Input(concat.node())
.Finalize(g, &result));
- ShapeRefiner m(OpRegistry::Global());
+ ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
TF_ASSERT_OK(m.AddNode(partial_1));
TF_ASSERT_OK(m.AddNode(partial_2));
TF_ASSERT_OK(m.AddNode(unknown));
@@ -734,7 +735,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_ConcatInvalidDimValue) {
.Input(concat.node())
.Finalize(g, &result));
- ShapeRefiner m(OpRegistry::Global());
+ ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
TF_ASSERT_OK(m.AddNode(partial_1));
TF_ASSERT_OK(m.AddNode(partial_2));
for (const auto& o : concat_inputs) {
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index 9d5d212ddd..ede0452f14 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -590,7 +590,13 @@ Status ReductionShape(InferenceContext* c) {
ShapeHandle input = c->input(0);
ShapeHandle indices;
- TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &indices));
+ // Older versions of TensorFlow accidentally allowed higher rank tensors like
+ // [[1,2]] or [[1],[2]] to represent axis=[1,2].
+ if (c->graph_def_version() < 21) {
+ indices = c->input(1);
+ } else {
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &indices));
+ }
bool keep_dims;
TF_RETURN_IF_ERROR(c->GetAttr("keep_dims", &keep_dims));
diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc
index 89acf1202c..2d9e96e6bc 100644
--- a/tensorflow/core/framework/common_shape_fns_test.cc
+++ b/tensorflow/core/framework/common_shape_fns_test.cc
@@ -69,7 +69,8 @@ TEST(CommonShapeFnsTest, NoOutputShapeTest) {
.Input({{"data", 0, DT_FLOAT}})
.Finalize(&def));
- InferenceContext c(&def, op_def, {S({}), S({10})}, {}, {}, {}, {});
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({}), S({10})}, {},
+ {}, {}, {});
TF_EXPECT_OK(NoOutputs(&c));
EXPECT_EQ(0, c.num_outputs());
}
@@ -87,14 +88,16 @@ TEST(CommonShapeFnsTest, ScalarShapeTest) {
NodeDefBuilder("test", "L2Loss").Input("t", 0, DT_FLOAT).Finalize(&def));
{
- InferenceContext c(&def, op_def, {S({})}, {}, {}, {}, {});
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({})}, {}, {}, {},
+ {});
TF_EXPECT_OK(ScalarShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(0, c.Rank(output));
}
{
- InferenceContext c(&def, op_def, {S({1, 23, 4, 4, 2})}, {}, {}, {}, {});
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
+ {S({1, 23, 4, 4, 2})}, {}, {}, {}, {});
TF_EXPECT_OK(ScalarShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(0, c.Rank(output));
@@ -121,7 +124,8 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
.Finalize(&def));
{
- InferenceContext c(&def, op_def, {S({2, 3}), S({3, 4})}, {}, {}, {}, {});
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
+ {S({2, 3}), S({3, 4})}, {}, {}, {}, {});
TF_EXPECT_OK(MatMulShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@@ -130,7 +134,8 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Unknown inner dimension for one
- InferenceContext c(&def, op_def, {S({2, -1}), S({3, 4})}, {}, {}, {}, {});
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
+ {S({2, -1}), S({3, 4})}, {}, {}, {}, {});
TF_EXPECT_OK(MatMulShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@@ -139,7 +144,8 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Invalid rank.
- InferenceContext c(&def, op_def, {S({2}), S({3, 4})}, {}, {}, {}, {});
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2}), S({3, 4})},
+ {}, {}, {}, {});
auto s = MatMulShape(&c);
EXPECT_FALSE(s.ok());
EXPECT_TRUE(
@@ -149,7 +155,8 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Unknown outer dimension
- InferenceContext c(&def, op_def, {S({2, 3}), S({3, -1})}, {}, {}, {}, {});
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
+ {S({2, 3}), S({3, -1})}, {}, {}, {}, {});
TF_EXPECT_OK(MatMulShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@@ -158,7 +165,8 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Inner shapes not compatible
- InferenceContext c(&def, op_def, {S({2, 5}), S({3, 4})}, {}, {}, {}, {});
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
+ {S({2, 5}), S({3, 4})}, {}, {}, {}, {});
auto s = MatMulShape(&c);
EXPECT_FALSE(s.ok());
EXPECT_TRUE(
@@ -169,8 +177,8 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Inner shapes not compatible
- InferenceContext c(&def, op_def, {S({2, 5, 3}), S({3, 5, 4})}, {}, {}, {},
- {});
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
+ {S({2, 5, 3}), S({3, 5, 4})}, {}, {}, {}, {});
auto s = MatMulShape(&c);
EXPECT_FALSE(s.ok());
EXPECT_TRUE(
@@ -188,7 +196,8 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
.Attr("type", DT_FLOAT)
.Finalize(&def));
- InferenceContext c(&def, op_def, {S({3, 2}), S({3, 4})}, {}, {}, {}, {});
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
+ {S({3, 2}), S({3, 4})}, {}, {}, {}, {});
auto s = MatMulShape(&c);
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@@ -205,7 +214,8 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
.Attr("type", DT_FLOAT)
.Finalize(&def));
- InferenceContext c(&def, op_def, {S({2, 3}), S({4, 3})}, {}, {}, {}, {});
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
+ {S({2, 3}), S({4, 3})}, {}, {}, {}, {});
auto s = MatMulShape(&c);
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@@ -229,7 +239,8 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Finalize(&def));
{
- InferenceContext c(&def, op_def, {S({2, 10}), S({10})}, {}, {}, {}, {});
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
+ {S({2, 10}), S({10})}, {}, {}, {}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@@ -238,7 +249,8 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
{
// Unknown ranks.
- InferenceContext c(&def, op_def, {Unknown(), Unknown()}, {}, {}, {}, {});
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
+ {Unknown(), Unknown()}, {}, {}, {}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_FALSE(c.RankKnown(output));
@@ -246,8 +258,8 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
{
// Rank > 2
- InferenceContext c(&def, op_def, {S({4, 3, 4, 2, 15}), S({15})}, {}, {}, {},
- {});
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
+ {S({4, 3, 4, 2, 15}), S({15})}, {}, {}, {}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ("[4,3,4,2,15]", c.DebugString(output));
@@ -260,7 +272,8 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Input("b", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
- InferenceContext c(&def, op_def, {S({2, 3, 4, 5}), S({3})}, {}, {}, {}, {});
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
+ {S({2, 3, 4, 5}), S({3})}, {}, {}, {}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ("[2,3,4,5]", c.DebugString(output));
@@ -273,8 +286,8 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Input("b", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
- InferenceContext c(&def, op_def, {S({8, 6, 4, 2, 3, 4, 5}), S({3})}, {}, {},
- {}, {});
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
+ {S({8, 6, 4, 2, 3, 4, 5}), S({3})}, {}, {}, {}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ("[8,6,4,2,3,4,5]", c.DebugString(output));
@@ -287,8 +300,8 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Input("b", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
- InferenceContext c(&def, op_def, {S({10, 11, 12}), S({10})}, {}, {}, {},
- {});
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
+ {S({10, 11, 12}), S({10})}, {}, {}, {}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ("[10,11,12]", c.DebugString(output));
@@ -296,7 +309,8 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
{
// Input rank not high enough
- InferenceContext c(&def, op_def, {S({3}), S({3})}, {}, {}, {}, {});
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({3}), S({3})}, {},
+ {}, {}, {});
EXPECT_FALSE(BiasAddShape(&c).ok());
}
@@ -308,7 +322,8 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Attr("data_format", "NCHW")
.Finalize(&def));
// NCHW format
- InferenceContext c(&def, op_def, {S({2, 3}), S({3})}, {}, {}, {}, {});
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 3}), S({3})},
+ {}, {}, {}, {});
EXPECT_FALSE(BiasAddShape(&c).ok());
}
}
@@ -327,7 +342,8 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Finalize(&def));
{
- InferenceContext c(&def, op_def, {S({2, 10})}, {}, {}, {}, {});
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 10})}, {}, {},
+ {}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
@@ -335,7 +351,8 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
{
// Rank > 2
- InferenceContext c(&def, op_def, {S({5, 7, 2, 10})}, {}, {}, {}, {});
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({5, 7, 2, 10})},
+ {}, {}, {}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
@@ -347,7 +364,8 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Input("a", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
- InferenceContext c(&def, op_def, {S({2, 3, 4, 5})}, {}, {}, {}, {});
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 3, 4, 5})},
+ {}, {}, {}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(3, c.Value(c.Dim(output, 0)));
@@ -359,8 +377,8 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Input("a", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
- InferenceContext c(&def, op_def, {S({8, 6, 4, 2, 3, 4, 5})}, {}, {}, {},
- {});
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
+ {S({8, 6, 4, 2, 3, 4, 5})}, {}, {}, {}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(3, c.Value(c.Dim(output, 0)));
@@ -372,7 +390,8 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Input("a", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
- InferenceContext c(&def, op_def, {S({10, 11, 12})}, {}, {}, {}, {});
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({10, 11, 12})},
+ {}, {}, {}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
@@ -380,7 +399,8 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
{
// Input rank not high enough
- InferenceContext c(&def, op_def, {S({3})}, {}, {}, {}, {});
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({3})}, {}, {}, {},
+ {});
EXPECT_FALSE(BiasAddGradShape(&c).ok());
}
@@ -391,7 +411,8 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Attr("data_format", "NCHW")
.Finalize(&def));
// NCHW format
- InferenceContext c(&def, op_def, {S({2, 3})}, {}, {}, {}, {});
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 3})}, {}, {},
+ {}, {});
EXPECT_FALSE(BiasAddGradShape(&c).ok());
}
}
@@ -781,12 +802,24 @@ TEST(CommonShapeFnsTest, Reduce_ShapeFn) {
op.input_tensors[1] = nullptr;
INFER_OK(op, "[?,?,?];?", "[?,?,?]");
INFER_OK(op, "[?,?,?];[2]", "[?,?,?]");
+
+ // Reduction indices with too many dimensions.
+ INFER_ERROR("must be at most rank 1 but is rank 2", op, "[?,?,?];[?,?]");
+ // With older graph-def version, this is allowed.
+ op.graph_def_version = 20;
+ INFER_OK(op, "[?,?,?];[?,?]", "[?,?,?]");
+ // And when the tensor is specified, it's still allowed.
+ op.input_tensors[1] = &indices;
+ indices = test::AsTensor<int32>({-1, -2}, TensorShape({2, 1}));
+ INFER_OK(op, "[2,4,5];[2,1]", "[d0_0, 1, 1]");
+ indices = test::AsTensor<int32>({-1, -2}, TensorShape({1, 2}));
+ INFER_OK(op, "[2,4,5];[1,2]", "[d0_0, 1, 1]");
}
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapes) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(3, 1), {Unknown(), Unknown(), Unknown()},
- {}, {}, {}, {});
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
+ {Unknown(), Unknown(), Unknown()}, {}, {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@@ -798,8 +831,8 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapes) {
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownDims) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, -1}), S({-1}), S({-1})}, {},
- {}, {}, {});
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
+ {S({-1, -1}), S({-1}), S({-1})}, {}, {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@@ -811,8 +844,8 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownDims) {
TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidIndicesRank) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(3, 1), {S({-1}), S({-1}), S({-1})}, {}, {},
- {}, {});
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
+ {S({-1}), S({-1}), S({-1})}, {}, {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@@ -825,8 +858,8 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidIndicesRank) {
TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidNumElements) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({4}), S({3})}, {}, {},
- {}, {});
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
+ {S({5, 3}), S({4}), S({3})}, {}, {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@@ -839,8 +872,8 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidNumElements) {
TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidRank) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({4})}, {}, {},
- {}, {});
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
+ {S({5, 3}), S({5}), S({4})}, {}, {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@@ -853,8 +886,8 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidRank) {
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumIndexElements) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, 3}), S({5}), S({3})}, {},
- {}, {}, {});
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
+ {S({-1, 3}), S({5}), S({3})}, {}, {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@@ -866,8 +899,8 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumIndexElements) {
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumValueElements) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({-1}), S({3})}, {},
- {}, {}, {});
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
+ {S({5, 3}), S({-1}), S({3})}, {}, {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@@ -879,8 +912,8 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumValueElements) {
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownIndexRank) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(3, 1), {S({5, -1}), S({5}), S({3})}, {},
- {}, {}, {});
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
+ {S({5, -1}), S({5}), S({3})}, {}, {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@@ -892,8 +925,8 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownIndexRank) {
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapeRank) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({-1})}, {},
- {}, {}, {});
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
+ {S({5, 3}), S({5}), S({-1})}, {}, {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@@ -905,8 +938,8 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapeRank) {
TEST(CommonShapeFnsTest, ValidateSparseTensor) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({3})}, {}, {},
- {}, {});
+ InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
+ {S({5, 3}), S({5}), S({3})}, {}, {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index 00f2c3407a..cbfa9bd20c 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -29,13 +29,14 @@ constexpr int32 InferenceContext::kUnknownRank;
constexpr int64 InferenceContext::kUnknownDim;
InferenceContext::InferenceContext(
- const NodeDef* node_def, const OpDef& op_def,
+ int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
const std::vector<TensorShapeProto>& input_shapes,
const std::vector<const Tensor*>& input_tensors,
const std::vector<TensorShapeProto>& input_tensors_as_shapes,
const std::vector<TensorShapeProto>& input_handle_shapes,
const std::vector<DataType>& input_handle_dtypes)
- : node_def_(*CHECK_NOTNULL(node_def)) {
+ : graph_def_version_(graph_def_version),
+ node_def_(*CHECK_NOTNULL(node_def)) {
std::vector<ShapeHandle> input_tensors_as_shape_handles;
for (const TensorShapeProto& p : input_tensors_as_shapes) {
ShapeHandle shape;
@@ -68,13 +69,14 @@ InferenceContext::InferenceContext(
}
InferenceContext::InferenceContext(
- const NodeDef* node_def, const OpDef& op_def,
+ int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
const std::vector<ShapeHandle>& input_shapes,
const std::vector<const Tensor*>& input_tensors,
const std::vector<ShapeHandle>& input_tensors_as_shapes,
const std::vector<ShapeHandle>& input_handle_shapes,
const std::vector<DataType>& input_handle_dtypes)
- : node_def_(*CHECK_NOTNULL(node_def)) {
+ : graph_def_version_(graph_def_version),
+ node_def_(*CHECK_NOTNULL(node_def)) {
PreInputInit(op_def, input_tensors, input_tensors_as_shapes);
if (!construction_status_.ok()) return;
inputs_ = input_shapes;
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index fd4e25c728..dba8d30302 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -144,7 +144,8 @@ class InferenceContext {
// Values of <input_tensors_as_shapes> do not need to outlive the context.
//
// REQUIRES: <node_def> is not NULL, and must outlive the InferenceContext.
- InferenceContext(const NodeDef* node_def, const OpDef& op_def,
+ InferenceContext(int graph_def_version, const NodeDef* node_def,
+ const OpDef& op_def,
const std::vector<ShapeHandle>& input_shapes,
const std::vector<const Tensor*>& input_tensors,
const std::vector<ShapeHandle>& input_tensors_as_shapes,
@@ -161,7 +162,8 @@ class InferenceContext {
// Values of <input_tensors_as_shapes> do not need to outlive the context.
//
// REQUIRES: <node_def> is not NULL, and must outlive the InferenceContext.
- InferenceContext(const NodeDef* node_def, const OpDef& op_def,
+ InferenceContext(int graph_def_version, const NodeDef* node_def,
+ const OpDef& op_def,
const std::vector<TensorShapeProto>& input_shapes,
const std::vector<const Tensor*>& input_tensors,
const std::vector<TensorShapeProto>& input_tensors_as_shapes,
@@ -436,6 +438,8 @@ class InferenceContext {
Status MakeShapeFromTensor(const Tensor* t, ShapeHandle tensor_shape,
ShapeHandle* out);
+ int graph_def_version() const { return graph_def_version_; }
+
private:
// Creates and stores shapes for use in InferenceContext.
class ShapeManager {
@@ -508,6 +512,7 @@ class InferenceContext {
std::vector<ShapeHandle> output_handle_shape_;
std::vector<DataType> output_handle_dtype_;
+ const int graph_def_version_;
const NodeDef& node_def_;
NameRangeMap input_name_map_;
NameRangeMap output_name_map_;
diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc
index 9f363d50b3..9fc068aebe 100644
--- a/tensorflow/core/framework/shape_inference_test.cc
+++ b/tensorflow/core/framework/shape_inference_test.cc
@@ -61,6 +61,8 @@ class ShapeInferenceTest : public ::testing::Test {
bool SameHandle(ShapeHandle a, ShapeHandle b) { return a.SameHandle(b); }
bool IsSet(DimensionHandle d) { return d.IsSet(); }
bool IsSet(ShapeHandle s) { return s.IsSet(); }
+
+ static const int kVersion = 0; // used for graph-def version.
};
TEST_F(ShapeInferenceTest, InputOutputByName) {
@@ -71,8 +73,8 @@ TEST_F(ShapeInferenceTest, InputOutputByName) {
.Attr("N", 3)
.Input(FakeInput(DT_FLOAT))
.Finalize(&def);
- InferenceContext c(&def, op_def, {S({1, 5}), S({2, 5}), S({1, 3})}, {}, {},
- {}, {});
+ InferenceContext c(kVersion, &def, op_def, {S({1, 5}), S({2, 5}), S({1, 3})},
+ {}, {}, {}, {});
EXPECT_EQ("5", c.DebugString(c.NumElements(c.input(0))));
EXPECT_EQ("10", c.DebugString(c.NumElements(c.input(1))));
@@ -108,7 +110,8 @@ static OpDef MakeOpDef(int num_inputs, int num_outputs) {
TEST_F(ShapeInferenceTest, DimensionOrConstant) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(1, 1), {Unknown()}, {}, {}, {}, {});
+ InferenceContext c(kVersion, &def, MakeOpDef(1, 1), {Unknown()}, {}, {}, {},
+ {});
EXPECT_EQ(InferenceContext::kUnknownDim,
c.Value(InferenceContext::kUnknownDim));
EXPECT_EQ(1, c.Value(1));
@@ -123,7 +126,7 @@ TEST_F(ShapeInferenceTest, Run) {
NodeDef def;
def.set_name("foo");
def.set_op("foo_op");
- InferenceContext c(&def, MakeOpDef(1, 2), {S({1})}, {}, {}, {}, {});
+ InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1})}, {}, {}, {}, {});
TF_ASSERT_OK(c.construction_status());
{
@@ -160,7 +163,8 @@ TEST_F(ShapeInferenceTest, AttachContext) {
def.set_op("foo_op");
// Error when no constant tensors were requested.
{
- InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, 3})}, {}, {}, {}, {});
+ InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, 3})}, {}, {},
+ {}, {});
TF_ASSERT_OK(c.construction_status());
auto fn = [](InferenceContext* c) {
ShapeHandle h;
@@ -178,8 +182,9 @@ TEST_F(ShapeInferenceTest, AttachContext) {
{
Tensor input_t =
::tensorflow::test::AsTensor<float>({1.1, 2.2, 3.3, 4.4, 5.5});
- InferenceContext c(&def, MakeOpDef(2, 2), {S({1, 2, 3}), S({4, 5})},
- {nullptr, &input_t}, {}, {}, {});
+ InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
+ {S({1, 2, 3}), S({4, 5})}, {nullptr, &input_t}, {}, {},
+ {});
TF_ASSERT_OK(c.construction_status());
auto fn = [](InferenceContext* c) {
c->input_tensor(0); // get this one, but it's null - won't be in error.
@@ -200,7 +205,7 @@ TEST_F(ShapeInferenceTest, AttachContext) {
// shapes provided.
{
Tensor input_t = ::tensorflow::test::AsTensor<int32>({1, 2, 3, 4, 5});
- InferenceContext c(&def, MakeOpDef(2, 2), {S({3}), S({4})},
+ InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({3}), S({4})},
{nullptr, &input_t}, {}, {}, {});
TF_ASSERT_OK(c.construction_status());
auto fn = [](InferenceContext* c) {
@@ -223,7 +228,7 @@ TEST_F(ShapeInferenceTest, AttachContext) {
// shape was provided.
{
Tensor input_t = ::tensorflow::test::AsTensor<int32>({1, 2, 3, 4, 5});
- InferenceContext c(&def, MakeOpDef(2, 2), {S({3}), S({4})},
+ InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({3}), S({4})},
{nullptr, &input_t}, {S({10, -1, 5}), Unknown()}, {},
{});
TF_ASSERT_OK(c.construction_status());
@@ -247,8 +252,8 @@ TEST_F(ShapeInferenceTest, AttachContext) {
TEST_F(ShapeInferenceTest, RankAndDimInspection) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(3, 2), {Unknown(), S({1, -1, 3}), S({})},
- {}, {}, {}, {});
+ InferenceContext c(kVersion, &def, MakeOpDef(3, 2),
+ {Unknown(), S({1, -1, 3}), S({})}, {}, {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(2, c.num_outputs());
@@ -288,7 +293,7 @@ TEST_F(ShapeInferenceTest, RankAndDimInspection) {
TEST_F(ShapeInferenceTest, NumElements) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(3, 2),
+ InferenceContext c(kVersion, &def, MakeOpDef(3, 2),
{Unknown(), S({1, -1, 3}), S({5, 4, 3, 2})}, {}, {}, {},
{});
@@ -303,8 +308,8 @@ TEST_F(ShapeInferenceTest, NumElements) {
TEST_F(ShapeInferenceTest, WithRank) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {},
- {}, {});
+ InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
+ {Unknown(), S({1, -1, 3})}, {}, {}, {}, {});
auto in0 = c.input(0);
auto in1 = c.input(1);
@@ -342,8 +347,8 @@ TEST_F(ShapeInferenceTest, WithRank) {
TEST_F(ShapeInferenceTest, WithRankAtMost) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {},
- {}, {});
+ InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
+ {Unknown(), S({1, -1, 3})}, {}, {}, {}, {});
auto in0 = c.input(0);
auto in1 = c.input(1);
@@ -381,8 +386,8 @@ TEST_F(ShapeInferenceTest, WithRankAtMost) {
TEST_F(ShapeInferenceTest, WithRankAtLeast) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {},
- {}, {});
+ InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
+ {Unknown(), S({1, -1, 3})}, {}, {}, {}, {});
auto in0 = c.input(0);
auto in1 = c.input(1);
@@ -420,7 +425,8 @@ TEST_F(ShapeInferenceTest, WithRankAtLeast) {
TEST_F(ShapeInferenceTest, WithValue) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(1, 2), {S({1, -1})}, {}, {}, {}, {});
+ InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, -1})}, {}, {}, {},
+ {});
auto d0 = c.Dim(c.input(0), 0);
auto d1 = c.Dim(c.input(0), 1);
@@ -461,8 +467,8 @@ TEST_F(ShapeInferenceTest, WithValue) {
TEST_F(ShapeInferenceTest, MergeDim) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(1, 2), {S({2, -1, 2, 1, -1})}, {}, {}, {},
- {});
+ InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({2, -1, 2, 1, -1})},
+ {}, {}, {}, {});
auto d2 = c.Dim(c.input(0), 0);
auto d_unknown = c.Dim(c.input(0), 1);
@@ -508,7 +514,7 @@ TEST_F(ShapeInferenceTest, MergeDim) {
TEST_F(ShapeInferenceTest, MergeShape) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(7, 2),
+ InferenceContext c(kVersion, &def, MakeOpDef(7, 2),
{Unknown(), S({1, 2}), S({-1, 2}), S({1, -1}), S({1, 3}),
Unknown(), S({1})},
{}, {}, {}, {});
@@ -578,7 +584,7 @@ TEST_F(ShapeInferenceTest, MergeShape) {
TEST_F(ShapeInferenceTest, MergePrefix) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(4, 2),
+ InferenceContext c(kVersion, &def, MakeOpDef(4, 2),
{
Unknown(), S({-1, 2}), S({1, -1, 3}), S({2, 4}),
},
@@ -634,8 +640,8 @@ TEST_F(ShapeInferenceTest, MergePrefix) {
TEST_F(ShapeInferenceTest, Subshape) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(2, 2), {S({1, 2, 3, -1, 5}), Unknown()},
- {}, {}, {}, {});
+ InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
+ {S({1, 2, 3, -1, 5}), Unknown()}, {}, {}, {}, {});
ShapeHandle unknown = c.input(1);
ShapeHandle out;
@@ -709,7 +715,7 @@ TEST_F(ShapeInferenceTest, Subshape) {
TEST_F(ShapeInferenceTest, Concatenate) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(3, 2),
+ InferenceContext c(kVersion, &def, MakeOpDef(3, 2),
{S({1, -1, 3}), S({4, 5}), Unknown()}, {}, {}, {}, {});
auto in0 = c.input(0);
@@ -736,8 +742,8 @@ TEST_F(ShapeInferenceTest, Concatenate) {
TEST_F(ShapeInferenceTest, ReplaceDim) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(2, 0), {S({1, 2, 3}), Unknown()}, {}, {},
- {}, {});
+ InferenceContext c(kVersion, &def, MakeOpDef(2, 0), {S({1, 2, 3}), Unknown()},
+ {}, {}, {}, {});
auto in = c.input(0);
auto unknown = c.input(1);
@@ -768,8 +774,8 @@ TEST_F(ShapeInferenceTest, ReplaceDim) {
TEST_F(ShapeInferenceTest, MakeShape) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, 3, -1, 5})}, {}, {}, {},
- {});
+ InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, 3, -1, 5})}, {},
+ {}, {}, {});
std::vector<DimensionHandle> dims;
auto in0 = c.input(0);
@@ -794,7 +800,7 @@ TEST_F(ShapeInferenceTest, MakeShape) {
TEST_F(ShapeInferenceTest, UnknownShape) {
NodeDef def;
std::vector<ShapeHandle> empty;
- InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
+ InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
auto u0 = c.UnknownShape();
auto u1 = c.UnknownShape();
@@ -806,7 +812,7 @@ TEST_F(ShapeInferenceTest, UnknownShape) {
TEST_F(ShapeInferenceTest, Scalar) {
NodeDef def;
std::vector<ShapeHandle> empty;
- InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
+ InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
auto s0 = c.Scalar();
EXPECT_EQ("[]", c.DebugString(s0));
@@ -817,7 +823,7 @@ TEST_F(ShapeInferenceTest, Scalar) {
TEST_F(ShapeInferenceTest, Vector) {
NodeDef def;
std::vector<ShapeHandle> empty;
- InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
+ InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
auto s0 = c.Vector(1);
EXPECT_EQ("[1]", c.DebugString(s0));
@@ -833,7 +839,7 @@ TEST_F(ShapeInferenceTest, Vector) {
TEST_F(ShapeInferenceTest, Matrix) {
NodeDef def;
std::vector<ShapeHandle> empty;
- InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
+ InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
auto s0 = c.Matrix(1, 2);
EXPECT_EQ("[1,2]", c.DebugString(s0));
@@ -855,7 +861,8 @@ TEST_F(ShapeInferenceTest, Matrix) {
TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
auto create = [&](Tensor* t) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(1, 0), {Unknown()}, {t}, {}, {}, {});
+ InferenceContext c(kVersion, &def, MakeOpDef(1, 0), {Unknown()}, {t}, {},
+ {}, {});
ShapeHandle out;
Status s = c.MakeShapeFromShapeTensor(0, &out);
if (s.ok()) {
@@ -907,8 +914,8 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
// Test when the input shape is wrong.
{
NodeDef def;
- InferenceContext c(&def, MakeOpDef(1, 0), {S({1, -1})}, {nullptr}, {}, {},
- {});
+ InferenceContext c(kVersion, &def, MakeOpDef(1, 0), {S({1, -1})}, {nullptr},
+ {}, {}, {});
ShapeHandle out;
EXPECT_EQ("Shape must be rank 1 but is rank 2",
c.MakeShapeFromShapeTensor(0, &out).error_message());
@@ -918,7 +925,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) {
NodeDef def;
std::vector<ShapeHandle> empty;
- InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
+ InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
TensorShapeProto proto;
// With a set unknown rank.
@@ -954,7 +961,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) {
TEST_F(ShapeInferenceTest, MakeDim) {
NodeDef def;
std::vector<ShapeHandle> empty;
- InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
+ InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
auto d0 = c.MakeDim(1);
auto d1 = c.MakeDim(1);
@@ -968,7 +975,7 @@ TEST_F(ShapeInferenceTest, MakeDim) {
TEST_F(ShapeInferenceTest, UnknownDim) {
NodeDef def;
std::vector<ShapeHandle> empty;
- InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
+ InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
auto d0 = c.UnknownDim();
auto d1 = c.UnknownDim();
@@ -980,7 +987,7 @@ TEST_F(ShapeInferenceTest, UnknownDim) {
TEST_F(ShapeInferenceTest, UnknownShapeOfRank) {
NodeDef def;
std::vector<ShapeHandle> empty;
- InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
+ InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
auto unknown_shape_of_rank_3 = c.UnknownShapeOfRank(3);
EXPECT_EQ("[?,?,?]", c.DebugString(unknown_shape_of_rank_3));
@@ -993,7 +1000,7 @@ TEST_F(ShapeInferenceTest, InputTensors) {
const Tensor t1 = tensorflow::test::AsTensor<float>({10});
const Tensor t2 = tensorflow::test::AsTensor<float>({20, 30});
NodeDef def;
- InferenceContext c(&def, MakeOpDef(3, 2), {S({1}), S({2}), S({3})},
+ InferenceContext c(kVersion, &def, MakeOpDef(3, 2), {S({1}), S({2}), S({3})},
{&t1, &t2}, {}, {}, {});
EXPECT_TRUE(c.input_tensor(0) == &t1);
@@ -1005,8 +1012,8 @@ TEST_F(ShapeInferenceTest, MakeDimForScalarInput) {
Tensor t1 = tensorflow::test::AsScalar<int32>(20);
Tensor t2 = tensorflow::test::AsScalar<int32>(-1);
NodeDef def;
- InferenceContext c(&def, MakeOpDef(2, 2), {S({}), S({})}, {&t1, &t2}, {}, {},
- {});
+ InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({}), S({})},
+ {&t1, &t2}, {}, {}, {});
DimensionHandle d;
EXPECT_TRUE(c.MakeDimForScalarInput(0, &d).ok());
@@ -1037,7 +1044,7 @@ TEST_F(ShapeInferenceTest, GetAttr) {
.ok());
std::vector<ShapeHandle> empty;
- InferenceContext c(&def, op_reg_data.op_def, empty, {}, {}, {}, {});
+ InferenceContext c(kVersion, &def, op_reg_data.op_def, empty, {}, {}, {}, {});
string value;
EXPECT_TRUE(c.GetAttr("foo", &value).ok());
EXPECT_EQ("bar", value);
@@ -1045,8 +1052,8 @@ TEST_F(ShapeInferenceTest, GetAttr) {
TEST_F(ShapeInferenceTest, Divide) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 1, 2, 0})}, {}, {}, {},
- {});
+ InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 1, 2, 0})}, {},
+ {}, {}, {});
auto s = c.input(0);
auto d_6 = c.Dim(s, 0);
@@ -1108,7 +1115,8 @@ TEST_F(ShapeInferenceTest, Divide) {
TEST_F(ShapeInferenceTest, Add) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0})}, {}, {}, {}, {});
+ InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0})}, {}, {},
+ {}, {});
auto s = c.input(0);
auto d_6 = c.Dim(s, 0);
@@ -1159,7 +1167,8 @@ TEST_F(ShapeInferenceTest, Add) {
TEST_F(ShapeInferenceTest, Subtract) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0, 5})}, {}, {}, {}, {});
+ InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0, 5})}, {},
+ {}, {}, {});
auto s = c.input(0);
auto d_6 = c.Dim(s, 0);
@@ -1208,7 +1217,8 @@ TEST_F(ShapeInferenceTest, Subtract) {
TEST_F(ShapeInferenceTest, Multiply) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0, 1})}, {}, {}, {}, {});
+ InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0, 1})}, {},
+ {}, {}, {});
auto s = c.input(0);
auto d_6 = c.Dim(s, 0);
@@ -1261,7 +1271,7 @@ TEST_F(ShapeInferenceTest, Multiply) {
TEST_F(ShapeInferenceTest, FullyDefined) {
NodeDef def;
std::vector<ShapeHandle> empty;
- InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
+ InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
// No rank or missing dimension information should return false.
EXPECT_FALSE(c.FullyDefined(c.UnknownShape()));
@@ -1274,7 +1284,8 @@ TEST_F(ShapeInferenceTest, FullyDefined) {
TEST_F(ShapeInferenceTest, Min) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, -1, 0})}, {}, {}, {}, {});
+ InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, -1, 0})}, {},
+ {}, {}, {});
auto s = c.input(0);
auto d_1 = c.Dim(s, 0);
@@ -1322,7 +1333,8 @@ TEST_F(ShapeInferenceTest, Min) {
TEST_F(ShapeInferenceTest, Max) {
NodeDef def;
- InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, -1})}, {}, {}, {}, {});
+ InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, -1})}, {}, {},
+ {}, {});
auto s = c.input(0);
auto d_1 = c.Dim(s, 0);
diff --git a/tensorflow/core/framework/shape_inference_testutil.cc b/tensorflow/core/framework/shape_inference_testutil.cc
index a225824f82..85e085af99 100644
--- a/tensorflow/core/framework/shape_inference_testutil.cc
+++ b/tensorflow/core/framework/shape_inference_testutil.cc
@@ -43,8 +43,9 @@ Status ShapeInferenceTestutil::InferShapes(ShapeInferenceTestOp op,
in_shapes.push_back(shape);
}
- shape_inference::InferenceContext c(&op.node_def, op_reg_data->op_def,
- in_shapes, op.input_tensors, {}, {}, {});
+ shape_inference::InferenceContext c(op.graph_def_version, &op.node_def,
+ op_reg_data->op_def, in_shapes,
+ op.input_tensors, {}, {}, {});
TF_RETURN_IF_ERROR(c.construction_status());
if (op_reg_data->shape_inference_fn == nullptr) {
return errors::InvalidArgument(
diff --git a/tensorflow/core/framework/shape_inference_testutil.h b/tensorflow/core/framework/shape_inference_testutil.h
index 64067464fb..996281e70e 100644
--- a/tensorflow/core/framework/shape_inference_testutil.h
+++ b/tensorflow/core/framework/shape_inference_testutil.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/public/version.h"
// Contains utilities for writing tests for shape inference functions.
@@ -34,6 +35,7 @@ struct ShapeInferenceTestOp {
string name;
NodeDef node_def;
std::vector<const Tensor*> input_tensors;
+ int graph_def_version = TF_GRAPH_DEF_VERSION;
};
namespace shape_inference {
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index c68ac37fa8..a83cf26723 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -745,8 +745,8 @@ Status GraphConstructor::UpdateVersionDef() {
return Status::OK();
}
VersionDef versions = g_->versions();
- // This new graph is being "produced" by the binary invoking ImportGraphDef.
- versions.set_producer(TF_GRAPH_DEF_VERSION);
+ versions.set_producer(
+ std::min(versions.producer(), gdef_->versions().producer()));
versions.set_min_consumer(
std::max(versions.min_consumer(), gdef_->versions().min_consumer()));
if (gdef_->versions().bad_consumers_size() > 0) {
@@ -820,14 +820,14 @@ Status GraphConstructor::MakeEdge(Node* src, int output_index, Node* dst,
Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
const GraphDef& gdef, Graph* g) {
- ShapeRefiner refiner(g->op_registry());
+ ShapeRefiner refiner(gdef.versions().producer(), g->op_registry());
return GraphConstructor::Construct(opts, &gdef, g, &refiner, nullptr);
}
Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
Graph* g, ShapeRefiner* refiner,
std::vector<std::pair<Node*, int>>* return_tensors) {
- ShapeRefiner default_refiner(g->op_registry());
+ ShapeRefiner default_refiner(gdef.versions().producer(), g->op_registry());
if (refiner == nullptr) {
refiner = &default_refiner;
}
diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc
index e20d89485d..02f614dad2 100644
--- a/tensorflow/core/graph/graph_constructor_test.cc
+++ b/tensorflow/core/graph/graph_constructor_test.cc
@@ -203,6 +203,15 @@ REGISTER_OP("TestOneInputOneOutput")
REGISTER_OP("TestDefaultAttr")
.Attr("default_int: int=31415")
.SetShapeFn(shape_inference::NoOutputs);
+REGISTER_OP("RequiresCurrentGraphVersion")
+ .Output("version: int32")
+ .SetIsStateful()
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ if (c->graph_def_version() != TF_GRAPH_DEF_VERSION) {
+ return errors::InvalidArgument("Wrong graph version for shape");
+ }
+ return shape_inference::ScalarShape(c);
+ });
TEST_F(GraphConstructorTest, InvalidNodeName) {
auto expect_invalid_name = [this](const char* name) {
@@ -1052,7 +1061,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ShapeWhitelist) {
}
TEST_F(GraphConstructorTest, ImportGraphDef_InputMap) {
- ShapeRefiner refiner(graph_.op_registry());
+ ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry());
// Populate graph with node we'll use in input map
ExpectOK("node { name: 'input' op: 'TestInput' }", ImportGraphDefOptions(),
@@ -1092,7 +1101,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMap) {
}
TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithPrefix) {
- ShapeRefiner refiner(graph_.op_registry());
+ ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry());
// Populate graph with node we'll use in input map
ExpectOK(
@@ -1155,7 +1164,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithPrefix) {
}
TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithControlEdges) {
- ShapeRefiner refiner(graph_.op_registry());
+ ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry());
// Populate graph with node we'll use in input map
ExpectOK("node { name: 'W1' op: 'TestParams' }", ImportGraphDefOptions(),
@@ -1219,7 +1228,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithControlEdges) {
}
TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithBadControlEdge) {
- ShapeRefiner refiner(graph_.op_registry());
+ ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry());
// Populate graph with node we'll use in input map
ExpectOK("node { name: 'W1' op: 'TestParams' }", ImportGraphDefOptions(),
@@ -1251,7 +1260,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithBadControlEdge) {
}
TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithInvalidNodeIndex) {
- ShapeRefiner refiner(graph_.op_registry());
+ ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry());
// Populate graph with node we'll use in input map
ExpectOK("node { name: 'input1' op: 'TestInput' }", ImportGraphDefOptions(),
@@ -1272,7 +1281,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithInvalidNodeIndex) {
}
TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithMissingEntries) {
- ShapeRefiner refiner(graph_.op_registry());
+ ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry());
// Populate graph with node we'll use in input map
ExpectOK("node { name: 'W1' op: 'TestParams' }", ImportGraphDefOptions(),
@@ -1293,7 +1302,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithMissingEntries) {
}
TEST_F(GraphConstructorTest, ImportGraphDef_InputMapDuplicateNodeNames) {
- ShapeRefiner refiner(graph_.op_registry());
+ ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry());
// Add two nodes with the same name to graph
Node* node;
@@ -1318,7 +1327,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapDuplicateNodeNames) {
}
TEST_F(GraphConstructorTest, ImportGraphDef_ReturnTensors) {
- ShapeRefiner refiner(graph_.op_registry());
+ ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry());
ImportGraphDefOptions opts;
opts.return_tensors.push_back({"input", 1});
@@ -1634,7 +1643,7 @@ versions {
}
TEST_F(GraphConstructorTest, ImportGraphDef_ControlDeps) {
- ShapeRefiner refiner(graph_.op_registry());
+ ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry());
// Populate graph with nodes we'll use in control deps and input map
ExpectOK(
@@ -1701,7 +1710,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ControlDeps) {
}
TEST_F(GraphConstructorTest, ImportGraphDef_ControlDepsWithCycle) {
- ShapeRefiner refiner(graph_.op_registry());
+ ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry());
// Populate graph with nodes we'll use in control deps and input map
ExpectOK(
@@ -1939,5 +1948,56 @@ TEST_F(GraphConstructorTest, CopyGraph) {
EXPECT_EQ(dst.versions().bad_consumers(0), bad);
}
+// Confirms that graph def version in the graph reaches the shape inference
+// function.
+TEST_F(GraphConstructorTest, GraphDefVersionUsedForShapeInference) {
+ string gdef_ascii = strings::StrCat(R"EOF(
+ node{ name:"A" op:"RequiresCurrentGraphVersion" }
+ versions { producer: )EOF",
+ TF_GRAPH_DEF_VERSION - 1, "}");
+ ImportGraphDefOptions opts;
+ ExpectError(gdef_ascii, opts, {"Wrong graph version for shape"});
+ gdef_ascii = strings::StrCat(R"EOF(
+ node{ name:"A" op:"RequiresCurrentGraphVersion" }
+ versions { producer: )EOF",
+ TF_GRAPH_DEF_VERSION, "}");
+ ExpectOK(gdef_ascii, opts);
+}
+
+TEST_F(GraphConstructorTest, GraphDefVersionMergingDuringImport) {
+ ImportGraphDefOptions opts;
+ ExpectOK(
+ "versions { producer: 15 min_consumer: 5 bad_consumers: 2 bad_consumers: "
+ "3 "
+ "}",
+ opts);
+ EXPECT_EQ(15, graph_.versions().producer());
+ EXPECT_EQ(5, graph_.versions().min_consumer());
+ ASSERT_EQ(2, graph_.versions().bad_consumers_size());
+ EXPECT_EQ(2, graph_.versions().bad_consumers(0));
+ EXPECT_EQ(3, graph_.versions().bad_consumers(1));
+
+ ExpectOK(
+ "versions { producer: 10 min_consumer: 8 bad_consumers: 1 bad_consumers: "
+ "3 "
+ "}",
+ opts);
+ EXPECT_EQ(10, graph_.versions().producer());
+ EXPECT_EQ(8, graph_.versions().min_consumer());
+ ASSERT_EQ(3, graph_.versions().bad_consumers_size());
+ EXPECT_EQ(1, graph_.versions().bad_consumers(0));
+ EXPECT_EQ(2, graph_.versions().bad_consumers(1));
+ EXPECT_EQ(3, graph_.versions().bad_consumers(2));
+
+ // This one is a no-op.
+ ExpectOK("versions { producer: 20 min_consumer: 7 }", opts);
+ EXPECT_EQ(10, graph_.versions().producer());
+ EXPECT_EQ(8, graph_.versions().min_consumer());
+ ASSERT_EQ(3, graph_.versions().bad_consumers_size());
+ EXPECT_EQ(1, graph_.versions().bad_consumers(0));
+ EXPECT_EQ(2, graph_.versions().bad_consumers(1));
+ EXPECT_EQ(3, graph_.versions().bad_consumers(2));
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/hexagon/graph_transferer.cc b/tensorflow/core/kernels/hexagon/graph_transferer.cc
index 94fa4257ae..c87aa82534 100644
--- a/tensorflow/core/kernels/hexagon/graph_transferer.cc
+++ b/tensorflow/core/kernels/hexagon/graph_transferer.cc
@@ -70,7 +70,7 @@ Status GraphTransferer::LoadGraphFromProto(
const OutputTensorMap& output_tensor_map) {
ImportGraphDefOptions opts;
Graph graph(OpRegistry::Global());
- ShapeRefiner shape_refiner(graph.op_registry());
+ ShapeRefiner shape_refiner(graph.versions().producer(), graph.op_registry());
VLOG(1) << "Start import graph";
Status status = ImportGraphDef(opts, graph_def, &graph, &shape_refiner);
if (!status.ok()) {
diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc
index 2f12afc9c7..2ed8db4a3f 100644
--- a/tensorflow/core/ops/array_ops_test.cc
+++ b/tensorflow/core/ops/array_ops_test.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/version.h"
namespace tensorflow {
@@ -161,9 +162,9 @@ TEST(ArrayOpsTest, Identity_ShapeFnHandles) {
// Check that handle dtypes are preserved.
const OpRegistrationData* op_reg_data;
TF_ASSERT_OK(OpRegistry::Global()->LookUp(op.name, &op_reg_data));
- shape_inference::InferenceContext c(&op.node_def, op_reg_data->op_def,
- {TensorShapeProto()}, {}, {}, {},
- {DT_BOOL});
+ shape_inference::InferenceContext c(TF_GRAPH_DEF_VERSION, &op.node_def,
+ op_reg_data->op_def, {TensorShapeProto()},
+ {}, {}, {}, {DT_BOOL});
TF_ASSERT_OK(c.construction_status());
ASSERT_TRUE(op_reg_data->shape_inference_fn != nullptr);
TF_ASSERT_OK(c.Run(op_reg_data->shape_inference_fn));
diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc
index 84264f13dc..8881857b29 100644
--- a/tensorflow/core/ops/math_ops_test.cc
+++ b/tensorflow/core/ops/math_ops_test.cc
@@ -229,7 +229,7 @@ TEST(MathOpsTest, Select_ShapeFn) {
ASSERT_TRUE(op_reg_data->shape_inference_fn != nullptr);
shape_inference::InferenceContext c(
- &op.node_def, op_reg_data->op_def,
+ TF_GRAPH_DEF_VERSION, &op.node_def, op_reg_data->op_def,
{TensorShapeProto(), TensorShapeProto(), TensorShapeProto()}, {}, {},
{TensorShapeProto(), i0, i1}, {});
TF_ASSERT_OK(c.construction_status());
@@ -242,7 +242,7 @@ TEST(MathOpsTest, Select_ShapeFn) {
i1.add_dim()->set_size(2);
i1.add_dim()->set_size(2);
shape_inference::InferenceContext c2(
- &op.node_def, op_reg_data->op_def,
+ TF_GRAPH_DEF_VERSION, &op.node_def, op_reg_data->op_def,
{TensorShapeProto(), TensorShapeProto(), TensorShapeProto()}, {}, {},
{TensorShapeProto(), i0, i2}, {});
TF_ASSERT_OK(c.construction_status());
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index 81d49684a8..f0859ed23f 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -79,6 +79,9 @@ limitations under the License.
// used for tf.split, ReverseV2 is now used by tf.reverse, ConcatV2 is
// now used by tf.concat. Graphs use flooring
// division and mod semantics. TensorArrayV3. (12dec2016)
+// Also considered the version for when it is required for reduction
+// ops' indices to be scalar or vector, and not higher rank.
+// Some earlier graph def versions allowed this.
// 21. Dropped FunctionDef.Node support, switched to node_def introduced
// in version 12. (11jan2017)
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 6224bd8489..70a66e7a72 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -768,6 +768,7 @@ py_test(
":nn_grad",
":nn_ops",
":random_ops",
+ ":test_ops",
":variables",
"//tensorflow/core:protos_all_py",
"//third_party/py/numpy",
diff --git a/tensorflow/python/framework/common_shapes.py b/tensorflow/python/framework/common_shapes.py
index 487387cd83..5ec73afa99 100644
--- a/tensorflow/python/framework/common_shapes.py
+++ b/tensorflow/python/framework/common_shapes.py
@@ -635,6 +635,7 @@ def _call_cpp_shape_fn_impl(
input_tensors_as_shapes_needed,
debug_python_shape_fn, require_shape_fn):
"""Core implementaton of call_cpp_shape_fn."""
+ graph_def_version = op.graph.graph_def_versions.producer
node_def_str = op.node_def.SerializeToString()
def tensor_to_inference_result(t):
@@ -666,8 +667,8 @@ def _call_cpp_shape_fn_impl(
try:
with errors.raise_exception_on_not_ok_status() as status:
output = pywrap_tensorflow.RunCppShapeInference(
- node_def_str, input_shapes, input_tensors, input_tensors_as_shapes,
- status)
+ graph_def_version, node_def_str, input_shapes, input_tensors,
+ input_tensors_as_shapes, status)
except errors.InvalidArgumentError as err:
if err.message.startswith("No shape inference function exists for op"):
missing_shape_fn = True
diff --git a/tensorflow/python/framework/cpp_shape_inference.cc b/tensorflow/python/framework/cpp_shape_inference.cc
index cc08e3b705..e1fab4fc2d 100644
--- a/tensorflow/python/framework/cpp_shape_inference.cc
+++ b/tensorflow/python/framework/cpp_shape_inference.cc
@@ -47,7 +47,7 @@ void ProtoFromShapeHandle(tensorflow::shape_inference::ShapeHandle s,
}
Status RunCppShapeInferenceImpl(
- const string& serialized_node_def,
+ int graph_def_version, const string& serialized_node_def,
const std::vector<string>& input_serialized_shapes,
const std::vector<PyObject*>& input_constant_tensor_values,
const std::vector<string>& input_constant_tensor_as_shape_values,
@@ -115,8 +115,9 @@ Status RunCppShapeInferenceImpl(
// Run shape inference.
tensorflow::shape_inference::InferenceContext c(
- &node, op_reg_data->op_def, input_shapes, input_tensors,
- input_tensor_as_shapes_protos, input_handle_shapes, input_handle_dtypes);
+ graph_def_version, &node, op_reg_data->op_def, input_shapes,
+ input_tensors, input_tensor_as_shapes_protos, input_handle_shapes,
+ input_handle_dtypes);
TF_RETURN_IF_ERROR(c.construction_status());
TF_RETURN_IF_ERROR(c.Run(op_reg_data->shape_inference_fn));
@@ -151,7 +152,7 @@ Status RunCppShapeInferenceImpl(
} // namespace
std::vector<string> RunCppShapeInference(
- const string& serialized_node_def,
+ int graph_def_version, const string& serialized_node_def,
const std::vector<string>& input_serialized_shapes,
PyObject* input_constant_tensor_values,
const std::vector<string>& input_constant_tensor_as_shape_values,
@@ -171,7 +172,7 @@ std::vector<string> RunCppShapeInference(
std::vector<string> output;
string input_tensors_needed_out;
tensorflow::Status status = RunCppShapeInferenceImpl(
- serialized_node_def, input_serialized_shapes,
+ graph_def_version, serialized_node_def, input_serialized_shapes,
input_constant_tensor_values_v, input_constant_tensor_as_shape_values,
&output, &input_tensors_needed_out);
diff --git a/tensorflow/python/framework/cpp_shape_inference.h b/tensorflow/python/framework/cpp_shape_inference.h
index 79b37aa6b4..afca7277c7 100644
--- a/tensorflow/python/framework/cpp_shape_inference.h
+++ b/tensorflow/python/framework/cpp_shape_inference.h
@@ -42,7 +42,7 @@ namespace swig {
// This is temporary code to be used during the migration
// from python shape inference functions to C++ shape inference functions.
std::vector<string> RunCppShapeInference(
- const string& serialized_node_def,
+ int graph_def_version, const string& serialized_node_def,
const std::vector<string>& input_serialized_shapes,
PyObject* input_constant_tensor_values,
const std::vector<string>& input_constant_tensor_as_shape_values,
diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py
index c82bf16bb2..5e4d5bbecc 100644
--- a/tensorflow/python/framework/importer_test.py
+++ b/tensorflow/python/framework/importer_test.py
@@ -32,6 +32,7 @@ from tensorflow.python.framework import importer
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_ops # pylint: disable=unused-import
from tensorflow.python.framework import versions
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
@@ -845,6 +846,24 @@ class ImportGraphDefTest(test.TestCase):
with self.assertRaisesRegexp(Exception, pat):
sess.run(x)
+ def testVersionAppliesToOpConstruction(self):
+ """These tests rely on shape fns in test_ops.cc."""
+ with ops.Graph().as_default():
+ importer.import_graph_def(
+ self._MakeGraphDef(
+ "node { name: 'A' op: 'RequiresOlderGraphVersion' }",
+ producer=versions.GRAPH_DEF_VERSION - 1),
+ return_elements=["A"])
+
+ with ops.Graph().as_default():
+ with self.assertRaisesWithPredicateMatch(ValueError,
+ "Wrong graph version.*"):
+ importer.import_graph_def(
+ self._MakeGraphDef(
+ "node { name: 'A' op: 'RequiresOlderGraphVersion' }",
+ producer=versions.GRAPH_DEF_VERSION),
+ return_elements=["A"])
+
def testDefaultAttrsAdded(self):
with ops.Graph().as_default():
a = importer.import_graph_def(
diff --git a/tensorflow/python/framework/test_ops.cc b/tensorflow/python/framework/test_ops.cc
index c19094847d..19f07fb754 100644
--- a/tensorflow/python/framework/test_ops.cc
+++ b/tensorflow/python/framework/test_ops.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/resource_handle.pb.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/public/version.h"
namespace tensorflow {
@@ -31,6 +32,16 @@ REGISTER_OP("GraphDefVersion")
.SetIsStateful()
.SetShapeFn(shape_inference::ScalarShape);
+REGISTER_OP("RequiresOlderGraphVersion")
+ .Output("version: int32")
+ .SetIsStateful()
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ if (c->graph_def_version() != TF_GRAPH_DEF_VERSION - 1) {
+ return errors::InvalidArgument("Wrong graph version for shape");
+ }
+ return shape_inference::ScalarShape(c);
+ });
+
REGISTER_OP("Old")
.SetShapeFn(shape_inference::UnknownShape)
.Deprecated(8, "For reasons");
diff --git a/tensorflow/python/kernel_tests/reduction_ops_test.py b/tensorflow/python/kernel_tests/reduction_ops_test.py
index 0da5a2ecc5..316c23609c 100644
--- a/tensorflow/python/kernel_tests/reduction_ops_test.py
+++ b/tensorflow/python/kernel_tests/reduction_ops_test.py
@@ -241,6 +241,13 @@ class SumReductionTest(test.TestCase):
c_unknown_indices, unknown_indices, keep_dims=True)
self.assertEqual(2, s_unknown_indices_keep.get_shape().ndims)
+ def testWrongShapeForReductionIndices(self):
+ reduction_axes = [[1], [2]]
+ c_unknown = array_ops.placeholder(dtypes.float32)
+ with self.assertRaisesWithPredicateMatch(ValueError,
+ ".*must be at most rank 1.*"):
+ math_ops.reduce_sum(c_unknown, reduction_axes)
+
# Int64??
def _compareGradient(self, shape, sum_shape, reduction_axes):