diff options
author | 2018-10-05 15:08:18 -0700 | |
---|---|---|
committer | 2018-10-05 15:13:31 -0700 | |
commit | c966b5eed60a570f2121cb84ddb4ece84c413719 (patch) | |
tree | c83bd5adb11106cb6034ecc1ed11d989a0e2afdd | |
parent | 07921022ddc68aacbf210acc62545a90e3091fb1 (diff) |
Add DistributionStrategy support to moving average APIs.
Fixes #21405.
PiperOrigin-RevId: 215973401
-rw-r--r-- | tensorflow/contrib/distribute/python/BUILD | 18 | ||||
-rw-r--r-- | tensorflow/contrib/distribute/python/moving_averages_test.py | 141 | ||||
-rw-r--r-- | tensorflow/python/training/moving_averages.py | 49 |
3 files changed, 189 insertions, 19 deletions
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index 8267612236..76d5b59ce1 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -412,6 +412,24 @@ cuda_py_test( ) cuda_py_test( + name = "moving_averages_test", + srcs = ["moving_averages_test.py"], + additional_deps = [ + ":combinations", + "@absl_py//absl/testing:parameterized", + "//tensorflow/python/eager:test", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:training", + "//tensorflow/python:variables", + ], + tags = [ + "no_pip", + ], +) + +cuda_py_test( name = "optimizer_v2_test", srcs = ["optimizer_v2_test.py"], additional_deps = [ diff --git a/tensorflow/contrib/distribute/python/moving_averages_test.py b/tensorflow/contrib/distribute/python/moving_averages_test.py new file mode 100644 index 0000000000..119352ad91 --- /dev/null +++ b/tensorflow/contrib/distribute/python/moving_averages_test.py @@ -0,0 +1,141 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for training.moving_averages when using a DistributionStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables +from tensorflow.python.training import moving_averages + + +all_combinations = combinations.combine( + distribution=[combinations.default_strategy, + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu], + mode=["graph"]) + + +class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase): + + @combinations.generate(all_combinations) + def testTowerModeWithoutZeroDebias(self, distribution): + tower_id = [0] + + def tower_fn(): + var = variables.Variable([10.0, 11.0]) + val = constant_op.constant([1.0 + tower_id[0], 2.0 - tower_id[0]]) + tower_id[0] += 1 + decay = 0.25 + assign = moving_averages.assign_moving_average( + var, val, decay, zero_debias=False) + return var, assign + + with distribution.scope(), self.cached_session() as sess: + var, assign = distribution.call_for_each_tower(tower_fn) + variables.global_variables_initializer().run() + self.assertAllClose([10.0, 11.0], var.eval()) + sess.run(distribution.unwrap(assign)) + # Mean of val across calls to tower_fn(). + average_val = [1.0 + 0.5 * (tower_id[0] - 1), + 2.0 - 0.5 * (tower_id[0] - 1)] + val_weight = 1.0 - 0.25 + self.assertAllClose( + [10.0 * 0.25 + average_val[0] * val_weight, + 11.0 * 0.25 + average_val[1] * val_weight], + var.eval()) + + @combinations.generate(all_combinations) + def testTowerMode(self, distribution): + tower_id = [0] + + def tower_fn(): + var = variables.Variable([0.0, 0.0]) + val = constant_op.constant([1.0 + tower_id[0], 2.0 - tower_id[0]]) + tower_id[0] += 1 + decay = 0.25 + assign = moving_averages.assign_moving_average(var, val, decay) + return var, assign.op + + with distribution.scope(), self.cached_session() as sess: + var, assign_op = distribution.call_for_each_tower(tower_fn) + variables.global_variables_initializer().run() + self.assertAllClose([0.0, 0.0], var.eval()) + sess.run(distribution.unwrap(assign_op)) + # Mean of val across calls to tower_fn(). + average_val = [1.0 + 0.5 * (tower_id[0] - 1), + 2.0 - 0.5 * (tower_id[0] - 1)] + self.assertAllClose(average_val, var.eval()) + + @combinations.generate(all_combinations) + def testCrossTowerWithoutZeroDebias(self, distribution): + with distribution.scope(), self.cached_session() as sess: + var = variables.Variable([10.0, 11.0]) + val = constant_op.constant([1.0, 2.0]) + decay = 0.25 + # NOTE(josh11b): We currently generate an error if val is a PerDevice value. + assign = moving_averages.assign_moving_average( + var, val, decay, zero_debias=False) + + variables.global_variables_initializer().run() + self.assertAllClose([10.0, 11.0], var.eval()) + sess.run(assign) + average_val = [1.0, 2.0] + val_weight = 1.0 - 0.25 + self.assertAllClose( + [10.0 * 0.25 + average_val[0] * val_weight, + 11.0 * 0.25 + average_val[1] * val_weight], + var.eval()) + # Also try assign.op. + sess.run(assign.op) + orig_weight = 0.25 * 0.25 + val_weight = 1.0 - orig_weight + self.assertAllClose( + [10.0 * orig_weight + average_val[0] * val_weight, + 11.0 * orig_weight + average_val[1] * val_weight], + var.eval()) + + @combinations.generate(all_combinations) + def testCrossTower(self, distribution): + with distribution.scope(), self.cached_session() as sess: + var = variables.Variable([0.0, 0.0]) + val = array_ops.placeholder(dtypes.float32) + decay = 0.25 + # NOTE(josh11b): We currently generate an error if val is a PerDevice value. + assign = moving_averages.assign_moving_average(var, val, decay) + + variables.global_variables_initializer().run() + self.assertAllClose([0.0, 0.0], var.eval()) + sess.run(assign, feed_dict={val: [1.0, 2.0]}) + self.assertAllClose([1.0, 2.0], var.eval()) + + # Also try assign.op. + sess.run(assign.op, feed_dict={val: [10.0, 0.0]}) + self.assertAllClose( + [(1.0 * 0.25 + 10.0) / (1.0 * 0.25 + 1.0), + (2.0 * 0.25 + 0.0) / (1.0 * 0.25 + 1.0)], + var.eval()) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py index 041266da3e..89bfcaf4ad 100644 --- a/tensorflow/python/training/moving_averages.py +++ b/tensorflow/python/training/moving_averages.py @@ -25,6 +25,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import slot_creator from tensorflow.python.util.tf_export import tf_export @@ -36,9 +37,8 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None): The moving average of 'variable' updated with 'value' is: variable * decay + value * (1 - decay) - The returned Operation sets 'variable' to the newly computed moving average. - - The new value of 'variable' can be set with the 'AssignSub' op as: + The returned Operation sets 'variable' to the newly computed moving average, + by performing this subtraction: variable -= (1 - decay) * (variable - value) Since variables that are initialized to a `0` value will be `0` biased, @@ -50,7 +50,7 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None): The names of the debias shadow variables, by default, include both the scope they were created in and the scope of the variables they debias. They are also - given a uniqifying-suffix. + given a uniquifying-suffix. E.g.: @@ -58,8 +58,8 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None): with tf.variable_scope('scope1'): with tf.variable_scope('scope2'): var = tf.get_variable('foo') - tf.assign_moving_average(var, 0.0, 1.0) - tf.assign_moving_average(var, 0.0, 0.9) + update_1 = tf.assign_moving_average(var, 0.0, 1.0) + update_2 = tf.assign_moving_average(var, 0.0, 0.9) # var.name: 'scope1/scope2/foo' # shadow var names: 'scope1/scope2/scope1/scope2/foo/biased' @@ -76,20 +76,33 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None): name: Optional name of the returned operation. Returns: - A reference to the input 'variable' tensor with the newly computed - moving average. + A tensor which if evaluated will compute and return the new moving average. """ + def update_fn(v, value, decay=decay): + decay = ops.convert_to_tensor(1.0 - decay, name="decay") + if decay.dtype != v.dtype.base_dtype: + decay = math_ops.cast(decay, v.dtype.base_dtype) + if zero_debias: + update_delta = _zero_debias(v, value, decay) + else: + update_delta = (v - value) * decay + return state_ops.assign_sub(v, update_delta, name=scope) + with ops.name_scope(name, "AssignMovingAvg", [variable, value, decay]) as scope: - with ops.colocate_with(variable): - decay = ops.convert_to_tensor(1.0 - decay, name="decay") - if decay.dtype != variable.dtype.base_dtype: - decay = math_ops.cast(decay, variable.dtype.base_dtype) - if zero_debias: - update_delta = _zero_debias(variable, value, decay) - else: - update_delta = (variable - value) * decay - return state_ops.assign_sub(variable, update_delta, name=scope) + tower_context = distribution_strategy_context.get_tower_context() + if tower_context: + # In a tower context, we update variable using the mean of value across + # towers. + def merge_fn(strategy, v, value): + value = strategy.reduce( + variable_scope.VariableAggregation.MEAN, value, v) + return strategy.update(v, update_fn, value) + + return tower_context.merge_call(merge_fn, variable, value) + else: + strategy = distribution_strategy_context.get_cross_tower_context() + return strategy.update(variable, update_fn, value) def weighted_moving_average(value, @@ -379,8 +392,6 @@ class ExponentialMovingAverage(object): Raises: TypeError: If the arguments are not an allowed type. - ValueError: If the moving average of one of the variables is already - being computed. """ # TODO(touts): op_scope if var_list is None: |