diff options
Diffstat (limited to 'tensorflow/python/ops/variables.py')
-rw-r--r-- | tensorflow/python/ops/variables.py | 27 |
1 files changed, 18 insertions, 9 deletions
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index c700a8a924..f5b7ad6632 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -23,6 +23,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gen_state_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.util.deprecation import deprecated @@ -316,9 +317,14 @@ class Variable(object): if init_from_fn: expected_shape_list = full_shape_to_list(expected_shape) set_shape = validate_shape and expected_shape.is_fully_defined() - self._variable = state_ops.variable_op( - expected_shape_list, dtype.base_dtype, set_shape=set_shape, - name=name) + self._variable = gen_state_ops._variable( + shape=expected_shape_list, + dtype=dtype.base_dtype, + name=name, + container="", + shared_name="") + if set_shape: + self._variable.set_shape(expected_shape_list) with ops.colocate_with(self._variable.op): with ops.name_scope("Initializer"): # Colocate the tensors created by the initial_value() function @@ -336,12 +342,15 @@ class Variable(object): and self._initial_value.get_shape().is_fully_defined()) # In this case, the variable op can't be created until after the # initial_value has been converted to a Tensor with a known type. - self._variable = state_ops.variable_op( - full_shape_to_list(self._initial_value.get_shape()), - self._initial_value.dtype.base_dtype, - set_shape=set_shape, - name=name) - + self._variable = gen_state_ops._variable( + shape=full_shape_to_list(self._initial_value.get_shape()), + dtype=self._initial_value.dtype.base_dtype, + name=name, + container="", + shared_name="") + if set_shape: + self._variable.set_shape( + full_shape_to_list(self._initial_value.get_shape())) # Manually overrides the variable's shape with the initial value's. if validate_shape: initial_value_shape = self._initial_value.get_shape() |