aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/state_ops.cc
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2016-10-25 10:19:49 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-25 11:31:53 -0700
commit6f259612b6763444df59f8229e14b5127ec40b75 (patch)
tree4833b6b2dbac245740cc1bb3b5b1689ec7109f9e /tensorflow/core/ops/state_ops.cc
parentad3d63f66e359cf4246a55dfa85f5d6d4cb43101 (diff)
Fix shape function for the Variable op.
Well, almost. The shape function cannot distinguish between an unknown ranked variable and a scalar variable. This is for historical reasons, as the Variable operator (like the Placeholder operator) has been in existence since a time when an empty TensorShapeProto used to imply an unknown rank. There is a plan to "fix" this by introducing and transitioning to PlaceholderV2 and VariableV2 operations (see https://github.com/tensorflow/tensorflow/commit/5e176998d92a64d78df57e9fb78582e5e7e4ebb6) But until that happens, make the shape function for the Variable op work for all but scalars. An explanation of the changes to python: - state_ops.py: This is done to keep the in-memory representation of the graph in python (which resulted in an unknown shape when set_shape=False) is consistent with what any C++ code thinks about the shape of the Variable op - variable_scope_test.py: When an initializer is provided to tf.get_variable, it ultimately calls state_ops.variable with set_shape=False. Because of the change to state_ops.py, the shape attribute wouldn't end up with [73]. My understanding is that the use of the shape attribute wasn't the intention of the test, so switching to another attribute maintains the intention. Fixes #5106 Change: 137179498
Diffstat (limited to 'tensorflow/core/ops/state_ops.cc')
-rw-r--r--tensorflow/core/ops/state_ops.cc19
1 files changed, 18 insertions, 1 deletions
diff --git a/tensorflow/core/ops/state_ops.cc b/tensorflow/core/ops/state_ops.cc
index 629a280cc8..b9ac8b16ff 100644
--- a/tensorflow/core/ops/state_ops.cc
+++ b/tensorflow/core/ops/state_ops.cc
@@ -28,7 +28,24 @@ REGISTER_OP("Variable")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.SetIsStateful()
- .SetShapeFn(shape_inference::UnknownShape)
+ .SetShapeFn([](InferenceContext* c) {
+ PartialTensorShape shape;
+ TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
+
+ // Variable has legacy behavior where we cannot tell the difference
+ // between a scalar shape attribute and 'unknown shape'. So if the shape
+ // is a scalar, we return an unknown shape.
+ if (shape.dims() <= 0) {
+ return shape_inference::UnknownShape(c);
+ }
+
+ TensorShapeProto shape_proto;
+ shape.AsProto(&shape_proto);
+ ShapeHandle out;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(shape_proto, &out));
+ c->set_output(0, out);
+ return Status::OK();
+ })
.Doc(R"doc(
Holds state in the form of a tensor that persists across steps.