diff options
author | 2016-10-25 10:19:49 -0800 | |
---|---|---|
committer | 2016-10-25 11:31:53 -0700 | |
commit | 6f259612b6763444df59f8229e14b5127ec40b75 (patch) | |
tree | 4833b6b2dbac245740cc1bb3b5b1689ec7109f9e /tensorflow/core/ops/state_ops.cc | |
parent | ad3d63f66e359cf4246a55dfa85f5d6d4cb43101 (diff) |
Fix shape function for the Variable op.
Well, almost. The shape function cannot distinguish
between an unknown ranked variable and a scalar variable.
This is for historical reasons, as the Variable operator
(like the Placeholder operator) has been in existence
since a time when an empty TensorShapeProto used to
imply an unknown rank.
There is a plan to "fix" this by introducing and transitioning
to PlaceholderV2 and VariableV2 operations (see https://github.com/tensorflow/tensorflow/commit/5e176998d92a64d78df57e9fb78582e5e7e4ebb6)
But until that happens, make the shape function for the
Variable op work for all but scalars.
An explanation of the changes to python:
- state_ops.py: This is done to keep the in-memory representation of the graph
in python (which resulted in an unknown shape when set_shape=False)
is consistent with what any C++ code thinks about the shape of the Variable op
- variable_scope_test.py: When an initializer is provided to tf.get_variable,
it ultimately calls state_ops.variable with set_shape=False. Because of the
change to state_ops.py, the shape attribute wouldn't end up with [73].
My understanding is that the use of the shape attribute wasn't the intention
of the test, so switching to another attribute maintains the intention.
Fixes #5106
Change: 137179498
Diffstat (limited to 'tensorflow/core/ops/state_ops.cc')
-rw-r--r-- | tensorflow/core/ops/state_ops.cc | 19 |
1 files changed, 18 insertions, 1 deletions
diff --git a/tensorflow/core/ops/state_ops.cc b/tensorflow/core/ops/state_ops.cc index 629a280cc8..b9ac8b16ff 100644 --- a/tensorflow/core/ops/state_ops.cc +++ b/tensorflow/core/ops/state_ops.cc @@ -28,7 +28,24 @@ REGISTER_OP("Variable") .Attr("container: string = ''") .Attr("shared_name: string = ''") .SetIsStateful() - .SetShapeFn(shape_inference::UnknownShape) + .SetShapeFn([](InferenceContext* c) { + PartialTensorShape shape; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); + + // Variable has legacy behavior where we cannot tell the difference + // between a scalar shape attribute and 'unknown shape'. So if the shape + // is a scalar, we return an unknown shape. + if (shape.dims() <= 0) { + return shape_inference::UnknownShape(c); + } + + TensorShapeProto shape_proto; + shape.AsProto(&shape_proto); + ShapeHandle out; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(shape_proto, &out)); + c->set_output(0, out); + return Status::OK(); + }) .Doc(R"doc( Holds state in the form of a tensor that persists across steps. |