aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/ops/array_ops.cc9
-rw-r--r--tensorflow/core/ops/state_ops.cc19
-rw-r--r--tensorflow/core/ops/state_ops_test.cc26
-rw-r--r--tensorflow/python/kernel_tests/variable_scope_test.py12
-rw-r--r--tensorflow/python/ops/state_ops.py3
5 files changed, 57 insertions, 12 deletions
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index b1b553ec8c..6e076a092e 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -2478,11 +2478,10 @@ REGISTER_OP("Placeholder")
PartialTensorShape shape;
TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
- // Placeholder has a legacy bug where we cannot tell
- // the difference between a scalar shape attribute and
- // 'unknown shape'. So if the shape is a scalar, we return
- // an unknown shape.
- if (shape.dims() == 0) {
+ // Placeholder has legacy behavior where we cannot tell the difference
+ // between a scalar shape attribute and 'unknown shape'. So if the shape
+ // is a scalar, we return an unknown shape.
+ if (shape.dims() <= 0) {
return shape_inference::UnknownShape(c);
}
diff --git a/tensorflow/core/ops/state_ops.cc b/tensorflow/core/ops/state_ops.cc
index 629a280cc8..b9ac8b16ff 100644
--- a/tensorflow/core/ops/state_ops.cc
+++ b/tensorflow/core/ops/state_ops.cc
@@ -28,7 +28,24 @@ REGISTER_OP("Variable")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.SetIsStateful()
- .SetShapeFn(shape_inference::UnknownShape)
+ .SetShapeFn([](InferenceContext* c) {
+ PartialTensorShape shape;
+ TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
+
+ // Variable has legacy behavior where we cannot tell the difference
+ // between a scalar shape attribute and 'unknown shape'. So if the shape
+ // is a scalar, we return an unknown shape.
+ if (shape.dims() <= 0) {
+ return shape_inference::UnknownShape(c);
+ }
+
+ TensorShapeProto shape_proto;
+ shape.AsProto(&shape_proto);
+ ShapeHandle out;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(shape_proto, &out));
+ c->set_output(0, out);
+ return Status::OK();
+ })
.Doc(R"doc(
Holds state in the form of a tensor that persists across steps.
diff --git a/tensorflow/core/ops/state_ops_test.cc b/tensorflow/core/ops/state_ops_test.cc
index 586de77edc..4c1ec67e9c 100644
--- a/tensorflow/core/ops/state_ops_test.cc
+++ b/tensorflow/core/ops/state_ops_test.cc
@@ -71,4 +71,30 @@ TEST(StateOpsTest, TemporaryVariable_ShapeFn) {
INFER_OK(op, "", "[1,2,3]");
}
+TEST(StateOpsTest, Variable_ShapeFn) {
+ ShapeInferenceTestOp op("Variable");
+ TensorShapeProto shape_proto;
+
+ // Unknown rank.
+ PartialTensorShape().AsProto(&shape_proto);
+ TF_ASSERT_OK(NodeDefBuilder("test", "Variable")
+ .Attr("shape", shape_proto)
+ .Finalize(&op.node_def));
+ INFER_OK(op, "", "?");
+
+ // For historical reasons an empty TensorShapeProto can be either an unknown
+ // rank or a scalar, so the shape function conservatively says "unknown".
+ shape_proto.Clear();
+ TF_ASSERT_OK(NodeDefBuilder("test", "Variable")
+ .Attr("shape", shape_proto)
+ .Finalize(&op.node_def));
+ INFER_OK(op, "", "?");
+
+ // Specified shape.
+ TensorShape({1, 2, 3}).AsProto(&shape_proto);
+ TF_ASSERT_OK(NodeDefBuilder("test", "Variable")
+ .Attr("shape", shape_proto)
+ .Finalize(&op.node_def));
+ INFER_OK(op, "", "[1,2,3]");
+}
} // end namespace tensorflow
diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py
index 1e2db3e565..5010b79a6a 100644
--- a/tensorflow/python/kernel_tests/variable_scope_test.py
+++ b/tensorflow/python/kernel_tests/variable_scope_test.py
@@ -637,19 +637,19 @@ class VariableScopeTest(tf.test.TestCase):
def testGetVarWithDevice(self):
g = tf.Graph()
- varname_shape = []
+ varname_type = []
def device_func(op):
if op.type == "Variable":
- varname_shape.append((op.name, tf.TensorShape(op.get_attr("shape"))))
+ varname_type.append((op.name, op.get_attr("dtype")))
return "/gpu:0"
with g.as_default():
with tf.device(device_func):
- _ = tf.get_variable("x", (100, 200)) # init fn
- _ = tf.get_variable("y", initializer=numpy.arange(73)) # init constant
- self.assertEqual(varname_shape[0], ("x", tf.TensorShape([100, 200])))
- self.assertEqual(varname_shape[1], ("y", tf.TensorShape([73])))
+ _ = tf.get_variable("x", (100, 200))
+ _ = tf.get_variable("y", dtype=tf.int64, initializer=numpy.arange(73))
+ self.assertEqual(varname_type[0], ("x", tf.float32))
+ self.assertEqual(varname_type[1], ("y", tf.int64))
def axis0_into1_partitioner(shape=None, **unused_kwargs):
diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py
index f869301873..636acc3e2a 100644
--- a/tensorflow/python/ops/state_ops.py
+++ b/tensorflow/python/ops/state_ops.py
@@ -116,6 +116,7 @@ from __future__ import print_function
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_state_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
@@ -146,6 +147,8 @@ def variable_op(shape, dtype, name="Variable", set_shape=True, container="",
Returns:
A variable tensor.
"""
+ if not set_shape:
+ shape = tensor_shape.unknown_shape()
ret = gen_state_ops._variable(shape=shape, dtype=dtype, name=name,
container=container, shared_name=shared_name)
# TODO(mrry): Move this to where it is used, so we can get rid of this op