aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/variables.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/variables.py')
-rw-r--r--tensorflow/python/ops/variables.py27
1 files changed, 18 insertions, 9 deletions
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index c700a8a924..f5b7ad6632 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -23,6 +23,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.util.deprecation import deprecated
@@ -316,9 +317,14 @@ class Variable(object):
if init_from_fn:
expected_shape_list = full_shape_to_list(expected_shape)
set_shape = validate_shape and expected_shape.is_fully_defined()
- self._variable = state_ops.variable_op(
- expected_shape_list, dtype.base_dtype, set_shape=set_shape,
- name=name)
+ self._variable = gen_state_ops._variable(
+ shape=expected_shape_list,
+ dtype=dtype.base_dtype,
+ name=name,
+ container="",
+ shared_name="")
+ if set_shape:
+ self._variable.set_shape(expected_shape_list)
with ops.colocate_with(self._variable.op):
with ops.name_scope("Initializer"):
# Colocate the tensors created by the initial_value() function
@@ -336,12 +342,15 @@ class Variable(object):
and self._initial_value.get_shape().is_fully_defined())
# In this case, the variable op can't be created until after the
# initial_value has been converted to a Tensor with a known type.
- self._variable = state_ops.variable_op(
- full_shape_to_list(self._initial_value.get_shape()),
- self._initial_value.dtype.base_dtype,
- set_shape=set_shape,
- name=name)
-
+ self._variable = gen_state_ops._variable(
+ shape=full_shape_to_list(self._initial_value.get_shape()),
+ dtype=self._initial_value.dtype.base_dtype,
+ name=name,
+ container="",
+ shared_name="")
+ if set_shape:
+ self._variable.set_shape(
+ full_shape_to_list(self._initial_value.get_shape()))
# Manually overrides the variable's shape with the initial value's.
if validate_shape:
initial_value_shape = self._initial_value.get_shape()