diff options
author | 2016-10-27 13:47:18 -0800 | |
---|---|---|
committer | 2016-10-27 15:06:51 -0700 | |
commit | b0f12ee6afcfaa92feae8ce998bbecdbf4f30fec (patch) | |
tree | 7020be640d34ccbae5d005ba1e1e8a45288e7ec5 | |
parent | cf2eb4cd9a5f8fe0cc7e9ccaf7d2e32c808af703 (diff) |
Set the shape attr of a variable whenever we have enough information to do so.
Change: 137444094
-rw-r--r-- | tensorflow/python/ops/variables.py | 17 |
1 files changed, 15 insertions, 2 deletions
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index c57d8c6ee0..05f780ccaa 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.core.framework import variable_pb2 from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops @@ -280,6 +281,7 @@ class Variable(object): "or set. Got %s of type %s" % (collections, type(collections))) if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections: collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES] + expected_shape = tensor_shape.as_shape(expected_shape) with ops.control_dependencies(None): with ops.name_scope(name, "Variable", [] if init_from_fn else [initial_value]) as name: @@ -287,6 +289,13 @@ class Variable(object): # Get the initial value from a callable function. The real shape of the # variable will be set later, since under the init_from_fn case, the # shape won't be known until after the function is invoked. + # + # NOTE: The current Variable OpKernel does not support + # partially defined shapes, so we only set the shape if it is + # fully defined. For historical reasons, we use the scalar + # shape (`[]`) to represent an unknown or partially known + # shape. A future version of the Variable ops will remove this + # limitation. def full_shape_to_list(shape): """Returns shape as a list if shape is fully defined.""" if shape and shape.is_fully_defined(): @@ -302,8 +311,10 @@ class Variable(object): if init_from_fn: expected_shape_list = full_shape_to_list(expected_shape) + set_shape = validate_shape and expected_shape.is_fully_defined() self._variable = state_ops.variable_op( - expected_shape_list, dtype.base_dtype, set_shape=False, name=name) + expected_shape_list, dtype.base_dtype, set_shape=set_shape, + name=name) with ops.colocate_with(self._variable.op): with ops.name_scope("Initializer"): # Colocate the tensors created by the initial_value() function @@ -317,12 +328,14 @@ class Variable(object): self._initial_value = ops.convert_to_tensor( initial_value, name="initial_value", dtype=dtype) assert_expected_shape() + set_shape = (validate_shape + and self._initial_value.get_shape().is_fully_defined()) # In this case, the variable op can't be created until after the # initial_value has been converted to a Tensor with a known type. self._variable = state_ops.variable_op( full_shape_to_list(self._initial_value.get_shape()), self._initial_value.dtype.base_dtype, - set_shape=False, + set_shape=set_shape, name=name) # Manually overrides the variable's shape with the initial value's. |