diff options
-rw-r--r-- | tensorflow/core/ops/array_ops.cc | 9
-rw-r--r-- | tensorflow/core/ops/state_ops.cc | 19
-rw-r--r-- | tensorflow/core/ops/state_ops_test.cc | 26
-rw-r--r-- | tensorflow/python/kernel_tests/variable_scope_test.py | 12
-rw-r--r-- | tensorflow/python/ops/state_ops.py | 3
5 files changed, 57 insertions(+), 12 deletions(-)
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index b1b553ec8c..6e076a092e 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -2478,11 +2478,10 @@ REGISTER_OP("Placeholder") PartialTensorShape shape; TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); - // Placeholder has a legacy bug where we cannot tell - // the difference between a scalar shape attribute and - // 'unknown shape'. So if the shape is a scalar, we return - // an unknown shape. - if (shape.dims() == 0) { + // Placeholder has legacy behavior where we cannot tell the difference + // between a scalar shape attribute and 'unknown shape'. So if the shape + // is a scalar, we return an unknown shape. + if (shape.dims() <= 0) { return shape_inference::UnknownShape(c); } diff --git a/tensorflow/core/ops/state_ops.cc b/tensorflow/core/ops/state_ops.cc index 629a280cc8..b9ac8b16ff 100644 --- a/tensorflow/core/ops/state_ops.cc +++ b/tensorflow/core/ops/state_ops.cc @@ -28,7 +28,24 @@ REGISTER_OP("Variable") .Attr("container: string = ''") .Attr("shared_name: string = ''") .SetIsStateful() - .SetShapeFn(shape_inference::UnknownShape) + .SetShapeFn([](InferenceContext* c) { + PartialTensorShape shape; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); + + // Variable has legacy behavior where we cannot tell the difference + // between a scalar shape attribute and 'unknown shape'. So if the shape + // is a scalar, we return an unknown shape. + if (shape.dims() <= 0) { + return shape_inference::UnknownShape(c); + } + + TensorShapeProto shape_proto; + shape.AsProto(&shape_proto); + ShapeHandle out; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(shape_proto, &out)); + c->set_output(0, out); + return Status::OK(); + }) .Doc(R"doc( Holds state in the form of a tensor that persists across steps. 
diff --git a/tensorflow/core/ops/state_ops_test.cc b/tensorflow/core/ops/state_ops_test.cc index 586de77edc..4c1ec67e9c 100644 --- a/tensorflow/core/ops/state_ops_test.cc +++ b/tensorflow/core/ops/state_ops_test.cc @@ -71,4 +71,30 @@ TEST(StateOpsTest, TemporaryVariable_ShapeFn) { INFER_OK(op, "", "[1,2,3]"); } +TEST(StateOpsTest, Variable_ShapeFn) { + ShapeInferenceTestOp op("Variable"); + TensorShapeProto shape_proto; + + // Unknown rank. + PartialTensorShape().AsProto(&shape_proto); + TF_ASSERT_OK(NodeDefBuilder("test", "Variable") + .Attr("shape", shape_proto) + .Finalize(&op.node_def)); + INFER_OK(op, "", "?"); + + // For historical reasons an empty TensorShapeProto can be either an unknown + // rank or a scalar, so the shape function conservatively says "unknown" + shape_proto.Clear(); + TF_ASSERT_OK(NodeDefBuilder("test", "Variable") + .Attr("shape", shape_proto) + .Finalize(&op.node_def)); + INFER_OK(op, "", "?"); + + // Specified shape. + TensorShape({1, 2, 3}).AsProto(&shape_proto); + TF_ASSERT_OK(NodeDefBuilder("test", "Variable") + .Attr("shape", shape_proto) + .Finalize(&op.node_def)); + INFER_OK(op, "", "[1,2,3]"); +} } // end namespace tensorflow diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py index 1e2db3e565..5010b79a6a 100644 --- a/tensorflow/python/kernel_tests/variable_scope_test.py +++ b/tensorflow/python/kernel_tests/variable_scope_test.py @@ -637,19 +637,19 @@ class VariableScopeTest(tf.test.TestCase): def testGetVarWithDevice(self): g = tf.Graph() - varname_shape = [] + varname_type = [] def device_func(op): if op.type == "Variable": - varname_shape.append((op.name, tf.TensorShape(op.get_attr("shape")))) + varname_type.append((op.name, op.get_attr("dtype"))) return "/gpu:0" with g.as_default(): with tf.device(device_func): - _ = tf.get_variable("x", (100, 200)) # init fn - _ = tf.get_variable("y", initializer=numpy.arange(73)) # init constant - 
self.assertEqual(varname_shape[0], ("x", tf.TensorShape([100, 200]))) - self.assertEqual(varname_shape[1], ("y", tf.TensorShape([73]))) + _ = tf.get_variable("x", (100, 200)) + _ = tf.get_variable("y", dtype=tf.int64, initializer=numpy.arange(73)) + self.assertEqual(varname_type[0], ("x", tf.float32)) + self.assertEqual(varname_type[1], ("y", tf.int64)) def axis0_into1_partitioner(shape=None, **unused_kwargs): diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py index f869301873..636acc3e2a 100644 --- a/tensorflow/python/ops/state_ops.py +++ b/tensorflow/python/ops/state_ops.py @@ -116,6 +116,7 @@ from __future__ import print_function from tensorflow.python.framework import common_shapes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import gen_state_ops # go/tf-wildcard-import # pylint: disable=wildcard-import @@ -146,6 +147,8 @@ def variable_op(shape, dtype, name="Variable", set_shape=True, container="", Returns: A variable tensor. """ + if not set_shape: + shape = tensor_shape.unknown_shape() ret = gen_state_ops._variable(shape=shape, dtype=dtype, name=name, container=container, shared_name=shared_name) # TODO(mrry): Move this to where it is used, so we can get rid of this op |