diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-10-05 15:08:18 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-05 15:13:31 -0700 |
commit | c966b5eed60a570f2121cb84ddb4ece84c413719 (patch) | |
tree | c83bd5adb11106cb6034ecc1ed11d989a0e2afdd /tensorflow/python | |
parent | 07921022ddc68aacbf210acc62545a90e3091fb1 (diff) |
Add DistributionStrategy support to moving average APIs.
Fixes #21405.
PiperOrigin-RevId: 215973401
Diffstat (limited to 'tensorflow/python')
-rw-r--r-- | tensorflow/python/training/moving_averages.py | 49 |
1 files changed, 30 insertions, 19 deletions
diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py index 041266da3e..89bfcaf4ad 100644 --- a/tensorflow/python/training/moving_averages.py +++ b/tensorflow/python/training/moving_averages.py @@ -25,6 +25,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import slot_creator from tensorflow.python.util.tf_export import tf_export @@ -36,9 +37,8 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None): The moving average of 'variable' updated with 'value' is: variable * decay + value * (1 - decay) - The returned Operation sets 'variable' to the newly computed moving average. - - The new value of 'variable' can be set with the 'AssignSub' op as: + The returned Operation sets 'variable' to the newly computed moving average, + by performing this subtraction: variable -= (1 - decay) * (variable - value) Since variables that are initialized to a `0` value will be `0` biased, @@ -50,7 +50,7 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None): The names of the debias shadow variables, by default, include both the scope they were created in and the scope of the variables they debias. They are also - given a uniqifying-suffix. + given a uniquifying-suffix. E.g.: @@ -58,8 +58,8 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None): with tf.variable_scope('scope1'): with tf.variable_scope('scope2'): var = tf.get_variable('foo') - tf.assign_moving_average(var, 0.0, 1.0) - tf.assign_moving_average(var, 0.0, 0.9) + update_1 = tf.assign_moving_average(var, 0.0, 1.0) + update_2 = tf.assign_moving_average(var, 0.0, 0.9) # var.name: 'scope1/scope2/foo' # shadow var names: 'scope1/scope2/scope1/scope2/foo/biased' @@ -76,20 +76,33 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None): name: Optional name of the returned operation. Returns: - A reference to the input 'variable' tensor with the newly computed - moving average. + A tensor which if evaluated will compute and return the new moving average. """ + def update_fn(v, value, decay=decay): + decay = ops.convert_to_tensor(1.0 - decay, name="decay") + if decay.dtype != v.dtype.base_dtype: + decay = math_ops.cast(decay, v.dtype.base_dtype) + if zero_debias: + update_delta = _zero_debias(v, value, decay) + else: + update_delta = (v - value) * decay + return state_ops.assign_sub(v, update_delta, name=scope) + with ops.name_scope(name, "AssignMovingAvg", [variable, value, decay]) as scope: - with ops.colocate_with(variable): - decay = ops.convert_to_tensor(1.0 - decay, name="decay") - if decay.dtype != variable.dtype.base_dtype: - decay = math_ops.cast(decay, variable.dtype.base_dtype) - if zero_debias: - update_delta = _zero_debias(variable, value, decay) - else: - update_delta = (variable - value) * decay - return state_ops.assign_sub(variable, update_delta, name=scope) + tower_context = distribution_strategy_context.get_tower_context() + if tower_context: + # In a tower context, we update variable using the mean of value across + # towers. + def merge_fn(strategy, v, value): + value = strategy.reduce( + variable_scope.VariableAggregation.MEAN, value, v) + return strategy.update(v, update_fn, value) + + return tower_context.merge_call(merge_fn, variable, value) + else: + strategy = distribution_strategy_context.get_cross_tower_context() + return strategy.update(variable, update_fn, value) def weighted_moving_average(value, @@ -379,8 +392,6 @@ class ExponentialMovingAverage(object): Raises: TypeError: If the arguments are not an allowed type. - ValueError: If the moving average of one of the variables is already - being computed. """ # TODO(touts): op_scope if var_list is None: |