diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-12-11 10:15:59 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-11 10:19:49 -0800 |
commit | 865bef39bcd563bac6216807bdd4dfa06647adf1 (patch) | |
tree | 83988f8e7666a697ea68ec1bf3feab7a373cbd57 /tensorflow/python/training/training_util.py | |
parent | d423542a78257aa32966d6fc26915874803bc166 (diff) |
Use the Snapshot kernel to force a copy of global step instead of the ugly "x + 0" hack.
PiperOrigin-RevId: 178634559
Diffstat (limited to 'tensorflow/python/training/training_util.py')
-rw-r--r-- | tensorflow/python/training/training_util.py | 7 |
1 files changed, 5 insertions, 2 deletions
diff --git a/tensorflow/python/training/training_util.py b/tensorflow/python/training/training_util.py index 89a9e12932..2a42ff2003 100644 --- a/tensorflow/python/training/training_util.py +++ b/tensorflow/python/training/training_util.py @@ -23,6 +23,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import graph_io from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops @@ -221,7 +222,6 @@ def _get_or_create_global_step_read(graph=None): global_step_tensor = get_global_step(graph) if global_step_tensor is None: return None - # add 'zero' so that it will create a copy of variable as Tensor. with graph.as_default() as g, g.name_scope(None): with g.name_scope(global_step_tensor.op.name + '/'): # using initialized_value to ensure that global_step is initialized before @@ -229,7 +229,10 @@ def _get_or_create_global_step_read(graph=None): # under global_step_read_tensor dependency. global_step_value = global_step_tensor.initialized_value() if isinstance( global_step_tensor, variables.Variable) else global_step_tensor - global_step_read_tensor = global_step_value + 0 + # pylint: disable=protected-access + # We use the snapshot kernel to make sure a copy is made of this tensor. + global_step_read_tensor = gen_array_ops._snapshot(global_step_value) + # pylint: enable=protected-access ops.add_to_collection(GLOBAL_STEP_READ_KEY, global_step_read_tensor) return _get_global_step_read(graph) |