aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/training_util.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-12-11 17:03:54 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-11 17:07:24 -0800
commitaaf2eb05502e1a0e37f30017d79bb08a9a534711 (patch)
treee7231ad5c0a131eaca200b8e3048aa7f1c8734ec /tensorflow/python/training/training_util.py
parentc4a242f6d24378d722131b0cddf7d8700fb65f5a (diff)
Automated g4 rollback of changelist 178634559
PiperOrigin-RevId: 178695724
Diffstat (limited to 'tensorflow/python/training/training_util.py')
-rw-r--r--tensorflow/python/training/training_util.py7
1 files changed, 2 insertions, 5 deletions
diff --git a/tensorflow/python/training/training_util.py b/tensorflow/python/training/training_util.py
index 2a42ff2003..89a9e12932 100644
--- a/tensorflow/python/training/training_util.py
+++ b/tensorflow/python/training/training_util.py
@@ -23,7 +23,6 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import ops
-from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
@@ -222,6 +221,7 @@ def _get_or_create_global_step_read(graph=None):
global_step_tensor = get_global_step(graph)
if global_step_tensor is None:
return None
+ # add 'zero' so that it will create a copy of variable as Tensor.
with graph.as_default() as g, g.name_scope(None):
with g.name_scope(global_step_tensor.op.name + '/'):
# using initialized_value to ensure that global_step is initialized before
@@ -229,10 +229,7 @@ def _get_or_create_global_step_read(graph=None):
# under global_step_read_tensor dependency.
global_step_value = global_step_tensor.initialized_value() if isinstance(
global_step_tensor, variables.Variable) else global_step_tensor
- # pylint: disable=protected-access
- # We use the snapshot kernel to make sure a copy is made of this tensor.
- global_step_read_tensor = gen_array_ops._snapshot(global_step_value)
- # pylint: enable=protected-access
+ global_step_read_tensor = global_step_value + 0
ops.add_to_collection(GLOBAL_STEP_READ_KEY, global_step_read_tensor)
return _get_global_step_read(graph)