aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-05 15:08:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-05 15:13:31 -0700
commitc966b5eed60a570f2121cb84ddb4ece84c413719 (patch)
treec83bd5adb11106cb6034ecc1ed11d989a0e2afdd
parent07921022ddc68aacbf210acc62545a90e3091fb1 (diff)
Add DistributionStrategy support to moving average APIs.
Fixes #21405. PiperOrigin-RevId: 215973401
-rw-r--r--tensorflow/contrib/distribute/python/BUILD18
-rw-r--r--tensorflow/contrib/distribute/python/moving_averages_test.py141
-rw-r--r--tensorflow/python/training/moving_averages.py49
3 files changed, 189 insertions, 19 deletions
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index 8267612236..76d5b59ce1 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -412,6 +412,24 @@ cuda_py_test(
)
cuda_py_test(
+ name = "moving_averages_test",
+ srcs = ["moving_averages_test.py"],
+ additional_deps = [
+ ":combinations",
+ "@absl_py//absl/testing:parameterized",
+ "//tensorflow/python/eager:test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variables",
+ ],
+ tags = [
+ "no_pip",
+ ],
+)
+
+cuda_py_test(
name = "optimizer_v2_test",
srcs = ["optimizer_v2_test.py"],
additional_deps = [
diff --git a/tensorflow/contrib/distribute/python/moving_averages_test.py b/tensorflow/contrib/distribute/python/moving_averages_test.py
new file mode 100644
index 0000000000..119352ad91
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/moving_averages_test.py
@@ -0,0 +1,141 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for training.moving_averages when using a DistributionStrategy."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.contrib.distribute.python import combinations
+from tensorflow.python.eager import test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.training import moving_averages
+
+
+all_combinations = combinations.combine(
+ distribution=[combinations.default_strategy,
+ combinations.one_device_strategy,
+ combinations.mirrored_strategy_with_gpu_and_cpu],
+ mode=["graph"])
+
+
+class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase):
+
+ @combinations.generate(all_combinations)
+ def testTowerModeWithoutZeroDebias(self, distribution):
+ tower_id = [0]
+
+ def tower_fn():
+ var = variables.Variable([10.0, 11.0])
+ val = constant_op.constant([1.0 + tower_id[0], 2.0 - tower_id[0]])
+ tower_id[0] += 1
+ decay = 0.25
+ assign = moving_averages.assign_moving_average(
+ var, val, decay, zero_debias=False)
+ return var, assign
+
+ with distribution.scope(), self.cached_session() as sess:
+ var, assign = distribution.call_for_each_tower(tower_fn)
+ variables.global_variables_initializer().run()
+ self.assertAllClose([10.0, 11.0], var.eval())
+ sess.run(distribution.unwrap(assign))
+ # Mean of val across calls to tower_fn().
+ average_val = [1.0 + 0.5 * (tower_id[0] - 1),
+ 2.0 - 0.5 * (tower_id[0] - 1)]
+ val_weight = 1.0 - 0.25
+ self.assertAllClose(
+ [10.0 * 0.25 + average_val[0] * val_weight,
+ 11.0 * 0.25 + average_val[1] * val_weight],
+ var.eval())
+
+ @combinations.generate(all_combinations)
+ def testTowerMode(self, distribution):
+ tower_id = [0]
+
+ def tower_fn():
+ var = variables.Variable([0.0, 0.0])
+ val = constant_op.constant([1.0 + tower_id[0], 2.0 - tower_id[0]])
+ tower_id[0] += 1
+ decay = 0.25
+ assign = moving_averages.assign_moving_average(var, val, decay)
+ return var, assign.op
+
+ with distribution.scope(), self.cached_session() as sess:
+ var, assign_op = distribution.call_for_each_tower(tower_fn)
+ variables.global_variables_initializer().run()
+ self.assertAllClose([0.0, 0.0], var.eval())
+ sess.run(distribution.unwrap(assign_op))
+ # Mean of val across calls to tower_fn().
+ average_val = [1.0 + 0.5 * (tower_id[0] - 1),
+ 2.0 - 0.5 * (tower_id[0] - 1)]
+ self.assertAllClose(average_val, var.eval())
+
+ @combinations.generate(all_combinations)
+ def testCrossTowerWithoutZeroDebias(self, distribution):
+ with distribution.scope(), self.cached_session() as sess:
+ var = variables.Variable([10.0, 11.0])
+ val = constant_op.constant([1.0, 2.0])
+ decay = 0.25
+ # NOTE(josh11b): We currently generate an error if val is a PerDevice value.
+ assign = moving_averages.assign_moving_average(
+ var, val, decay, zero_debias=False)
+
+ variables.global_variables_initializer().run()
+ self.assertAllClose([10.0, 11.0], var.eval())
+ sess.run(assign)
+ average_val = [1.0, 2.0]
+ val_weight = 1.0 - 0.25
+ self.assertAllClose(
+ [10.0 * 0.25 + average_val[0] * val_weight,
+ 11.0 * 0.25 + average_val[1] * val_weight],
+ var.eval())
+ # Also try assign.op.
+ sess.run(assign.op)
+ orig_weight = 0.25 * 0.25
+ val_weight = 1.0 - orig_weight
+ self.assertAllClose(
+ [10.0 * orig_weight + average_val[0] * val_weight,
+ 11.0 * orig_weight + average_val[1] * val_weight],
+ var.eval())
+
+ @combinations.generate(all_combinations)
+ def testCrossTower(self, distribution):
+ with distribution.scope(), self.cached_session() as sess:
+ var = variables.Variable([0.0, 0.0])
+ val = array_ops.placeholder(dtypes.float32)
+ decay = 0.25
+ # NOTE(josh11b): We currently generate an error if val is a PerDevice value.
+ assign = moving_averages.assign_moving_average(var, val, decay)
+
+ variables.global_variables_initializer().run()
+ self.assertAllClose([0.0, 0.0], var.eval())
+ sess.run(assign, feed_dict={val: [1.0, 2.0]})
+ self.assertAllClose([1.0, 2.0], var.eval())
+
+ # Also try assign.op.
+ sess.run(assign.op, feed_dict={val: [10.0, 0.0]})
+ self.assertAllClose(
+ [(1.0 * 0.25 + 10.0) / (1.0 * 0.25 + 1.0),
+ (2.0 * 0.25 + 0.0) / (1.0 * 0.25 + 1.0)],
+ var.eval())
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py
index 041266da3e..89bfcaf4ad 100644
--- a/tensorflow/python/training/moving_averages.py
+++ b/tensorflow/python/training/moving_averages.py
@@ -25,6 +25,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import slot_creator
from tensorflow.python.util.tf_export import tf_export
@@ -36,9 +37,8 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
The moving average of 'variable' updated with 'value' is:
variable * decay + value * (1 - decay)
- The returned Operation sets 'variable' to the newly computed moving average.
-
- The new value of 'variable' can be set with the 'AssignSub' op as:
+ The returned Operation sets 'variable' to the newly computed moving average,
+ by performing this subtraction:
variable -= (1 - decay) * (variable - value)
Since variables that are initialized to a `0` value will be `0` biased,
@@ -50,7 +50,7 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
The names of the debias shadow variables, by default, include both the scope
they were created in and the scope of the variables they debias. They are also
- given a uniqifying-suffix.
+ given a uniquifying-suffix.
E.g.:
@@ -58,8 +58,8 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
with tf.variable_scope('scope1'):
with tf.variable_scope('scope2'):
var = tf.get_variable('foo')
- tf.assign_moving_average(var, 0.0, 1.0)
- tf.assign_moving_average(var, 0.0, 0.9)
+ update_1 = tf.assign_moving_average(var, 0.0, 1.0)
+ update_2 = tf.assign_moving_average(var, 0.0, 0.9)
# var.name: 'scope1/scope2/foo'
# shadow var names: 'scope1/scope2/scope1/scope2/foo/biased'
@@ -76,20 +76,33 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
name: Optional name of the returned operation.
Returns:
- A reference to the input 'variable' tensor with the newly computed
- moving average.
+ A tensor which if evaluated will compute and return the new moving average.
"""
+ def update_fn(v, value, decay=decay):
+ decay = ops.convert_to_tensor(1.0 - decay, name="decay")
+ if decay.dtype != v.dtype.base_dtype:
+ decay = math_ops.cast(decay, v.dtype.base_dtype)
+ if zero_debias:
+ update_delta = _zero_debias(v, value, decay)
+ else:
+ update_delta = (v - value) * decay
+ return state_ops.assign_sub(v, update_delta, name=scope)
+
with ops.name_scope(name, "AssignMovingAvg",
[variable, value, decay]) as scope:
- with ops.colocate_with(variable):
- decay = ops.convert_to_tensor(1.0 - decay, name="decay")
- if decay.dtype != variable.dtype.base_dtype:
- decay = math_ops.cast(decay, variable.dtype.base_dtype)
- if zero_debias:
- update_delta = _zero_debias(variable, value, decay)
- else:
- update_delta = (variable - value) * decay
- return state_ops.assign_sub(variable, update_delta, name=scope)
+ tower_context = distribution_strategy_context.get_tower_context()
+ if tower_context:
+ # In a tower context, we update variable using the mean of value across
+ # towers.
+ def merge_fn(strategy, v, value):
+ value = strategy.reduce(
+ variable_scope.VariableAggregation.MEAN, value, v)
+ return strategy.update(v, update_fn, value)
+
+ return tower_context.merge_call(merge_fn, variable, value)
+ else:
+ strategy = distribution_strategy_context.get_cross_tower_context()
+ return strategy.update(variable, update_fn, value)
def weighted_moving_average(value,
@@ -379,8 +392,6 @@ class ExponentialMovingAverage(object):
Raises:
TypeError: If the arguments are not an allowed type.
- ValueError: If the moving average of one of the variables is already
- being computed.
"""
# TODO(touts): op_scope
if var_list is None: