aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2016-10-27 13:47:18 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-27 15:06:51 -0700
commitb0f12ee6afcfaa92feae8ce998bbecdbf4f30fec (patch)
tree7020be640d34ccbae5d005ba1e1e8a45288e7ec5
parentcf2eb4cd9a5f8fe0cc7e9ccaf7d2e32c808af703 (diff)
Set the shape attr of a variable whenever we have enough information to do so.
Change: 137444094
-rw-r--r--tensorflow/python/ops/variables.py17
1 files changed, 15 insertions, 2 deletions
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index c57d8c6ee0..05f780ccaa 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.core.framework import variable_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
@@ -280,6 +281,7 @@ class Variable(object):
"or set. Got %s of type %s" % (collections, type(collections)))
if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
+ expected_shape = tensor_shape.as_shape(expected_shape)
with ops.control_dependencies(None):
with ops.name_scope(name, "Variable", [] if init_from_fn else
[initial_value]) as name:
@@ -287,6 +289,13 @@ class Variable(object):
# Get the initial value from a callable function. The real shape of the
# variable will be set later, since under the init_from_fn case, the
# shape won't be known until after the function is invoked.
+ #
+ # NOTE: The current Variable OpKernel does not support
+ # partially defined shapes, so we only set the shape if it is
+ # fully defined. For historical reasons, we use the scalar
+ # shape (`[]`) to represent an unknown or partially known
+ # shape. A future version of the Variable ops will remove this
+ # limitation.
def full_shape_to_list(shape):
"""Returns shape as a list if shape is fully defined."""
if shape and shape.is_fully_defined():
@@ -302,8 +311,10 @@ class Variable(object):
if init_from_fn:
expected_shape_list = full_shape_to_list(expected_shape)
+ set_shape = validate_shape and expected_shape.is_fully_defined()
self._variable = state_ops.variable_op(
- expected_shape_list, dtype.base_dtype, set_shape=False, name=name)
+ expected_shape_list, dtype.base_dtype, set_shape=set_shape,
+ name=name)
with ops.colocate_with(self._variable.op):
with ops.name_scope("Initializer"):
# Colocate the tensors created by the initial_value() function
@@ -317,12 +328,14 @@ class Variable(object):
self._initial_value = ops.convert_to_tensor(
initial_value, name="initial_value", dtype=dtype)
assert_expected_shape()
+ set_shape = (validate_shape
+ and self._initial_value.get_shape().is_fully_defined())
# In this case, the variable op can't be created until after the
# initial_value has been converted to a Tensor with a known type.
self._variable = state_ops.variable_op(
full_shape_to_list(self._initial_value.get_shape()),
self._initial_value.dtype.base_dtype,
- set_shape=False,
+ set_shape=set_shape,
name=name)
# Manually overrides the variable's shape with the initial value's.