diff options
author | 2017-12-11 17:03:54 -0800 | |
---|---|---|
committer | 2017-12-11 17:07:24 -0800 | |
commit | aaf2eb05502e1a0e37f30017d79bb08a9a534711 (patch) | |
tree | e7231ad5c0a131eaca200b8e3048aa7f1c8734ec /tensorflow/python/training/training_util.py | |
parent | c4a242f6d24378d722131b0cddf7d8700fb65f5a (diff) |
Automated g4 rollback of changelist 178634559
PiperOrigin-RevId: 178695724
Diffstat (limited to 'tensorflow/python/training/training_util.py')
-rw-r--r-- | tensorflow/python/training/training_util.py | 7 |
1 files changed, 2 insertions, 5 deletions
diff --git a/tensorflow/python/training/training_util.py b/tensorflow/python/training/training_util.py index 2a42ff2003..89a9e12932 100644 --- a/tensorflow/python/training/training_util.py +++ b/tensorflow/python/training/training_util.py @@ -23,7 +23,6 @@ from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import graph_io from tensorflow.python.framework import ops -from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops @@ -222,6 +221,7 @@ def _get_or_create_global_step_read(graph=None): global_step_tensor = get_global_step(graph) if global_step_tensor is None: return None + # add 'zero' so that it will create a copy of variable as Tensor. with graph.as_default() as g, g.name_scope(None): with g.name_scope(global_step_tensor.op.name + '/'): # using initialized_value to ensure that global_step is initialized before @@ -229,10 +229,7 @@ def _get_or_create_global_step_read(graph=None): # under global_step_read_tensor dependency. global_step_value = global_step_tensor.initialized_value() if isinstance( global_step_tensor, variables.Variable) else global_step_tensor - # pylint: disable=protected-access - # We use the snapshot kernel to make sure a copy is made of this tensor. - global_step_read_tensor = gen_array_ops._snapshot(global_step_value) - # pylint: enable=protected-access + global_step_read_tensor = global_step_value + 0 ops.add_to_collection(GLOBAL_STEP_READ_KEY, global_step_read_tensor) return _get_global_step_read(graph) |