about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/python
diff options
context:
space:
mode:
author: [Gravatar] A. Unique TensorFlower <gardener@tensorflow.org>, 2018-10-05 15:08:18 -0700
committer: [Gravatar] TensorFlower Gardener <gardener@tensorflow.org>, 2018-10-05 15:13:31 -0700
commit: c966b5eed60a570f2121cb84ddb4ece84c413719 (patch)
tree: c83bd5adb11106cb6034ecc1ed11d989a0e2afdd /tensorflow/python
parent: 07921022ddc68aacbf210acc62545a90e3091fb1 (diff)
Add DistributionStrategy support to moving average APIs.
Fixes #21405.
PiperOrigin-RevId: 215973401
Diffstat (limited to 'tensorflow/python')
-rw-r--r-- tensorflow/python/training/moving_averages.py | 49
1 files changed, 30 insertions, 19 deletions
diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py
index 041266da3e..89bfcaf4ad 100644
--- a/tensorflow/python/training/moving_averages.py
+++ b/tensorflow/python/training/moving_averages.py
@@ -25,6 +25,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import slot_creator
from tensorflow.python.util.tf_export import tf_export
@@ -36,9 +37,8 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
The moving average of 'variable' updated with 'value' is:
variable * decay + value * (1 - decay)
- The returned Operation sets 'variable' to the newly computed moving average.
-
- The new value of 'variable' can be set with the 'AssignSub' op as:
+ The returned Operation sets 'variable' to the newly computed moving average,
+ by performing this subtraction:
variable -= (1 - decay) * (variable - value)
Since variables that are initialized to a `0` value will be `0` biased,
@@ -50,7 +50,7 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
The names of the debias shadow variables, by default, include both the scope
they were created in and the scope of the variables they debias. They are also
- given a uniqifying-suffix.
+ given a uniquifying-suffix.
E.g.:
@@ -58,8 +58,8 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
with tf.variable_scope('scope1'):
with tf.variable_scope('scope2'):
var = tf.get_variable('foo')
- tf.assign_moving_average(var, 0.0, 1.0)
- tf.assign_moving_average(var, 0.0, 0.9)
+ update_1 = tf.assign_moving_average(var, 0.0, 1.0)
+ update_2 = tf.assign_moving_average(var, 0.0, 0.9)
# var.name: 'scope1/scope2/foo'
# shadow var names: 'scope1/scope2/scope1/scope2/foo/biased'
@@ -76,20 +76,33 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
name: Optional name of the returned operation.
Returns:
- A reference to the input 'variable' tensor with the newly computed
- moving average.
+ A tensor which if evaluated will compute and return the new moving average.
"""
+ def update_fn(v, value, decay=decay):
+ decay = ops.convert_to_tensor(1.0 - decay, name="decay")
+ if decay.dtype != v.dtype.base_dtype:
+ decay = math_ops.cast(decay, v.dtype.base_dtype)
+ if zero_debias:
+ update_delta = _zero_debias(v, value, decay)
+ else:
+ update_delta = (v - value) * decay
+ return state_ops.assign_sub(v, update_delta, name=scope)
+
with ops.name_scope(name, "AssignMovingAvg",
[variable, value, decay]) as scope:
- with ops.colocate_with(variable):
- decay = ops.convert_to_tensor(1.0 - decay, name="decay")
- if decay.dtype != variable.dtype.base_dtype:
- decay = math_ops.cast(decay, variable.dtype.base_dtype)
- if zero_debias:
- update_delta = _zero_debias(variable, value, decay)
- else:
- update_delta = (variable - value) * decay
- return state_ops.assign_sub(variable, update_delta, name=scope)
+ tower_context = distribution_strategy_context.get_tower_context()
+ if tower_context:
+ # In a tower context, we update variable using the mean of value across
+ # towers.
+ def merge_fn(strategy, v, value):
+ value = strategy.reduce(
+ variable_scope.VariableAggregation.MEAN, value, v)
+ return strategy.update(v, update_fn, value)
+
+ return tower_context.merge_call(merge_fn, variable, value)
+ else:
+ strategy = distribution_strategy_context.get_cross_tower_context()
+ return strategy.update(variable, update_fn, value)
def weighted_moving_average(value,
@@ -379,8 +392,6 @@ class ExponentialMovingAverage(object):
Raises:
TypeError: If the arguments are not an allowed type.
- ValueError: If the moving average of one of the variables is already
- being computed.
"""
# TODO(touts): op_scope
if var_list is None: