aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-05 15:08:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-05 15:13:31 -0700
commitc966b5eed60a570f2121cb84ddb4ece84c413719 (patch)
treec83bd5adb11106cb6034ecc1ed11d989a0e2afdd /tensorflow/contrib
parent07921022ddc68aacbf210acc62545a90e3091fb1 (diff)
Add DistributionStrategy support to moving average APIs.
Fixes #21405. PiperOrigin-RevId: 215973401
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r--tensorflow/contrib/distribute/python/BUILD18
-rw-r--r--tensorflow/contrib/distribute/python/moving_averages_test.py141
2 files changed, 159 insertions, 0 deletions
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index 8267612236..76d5b59ce1 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -412,6 +412,24 @@ cuda_py_test(
)
cuda_py_test(
+ name = "moving_averages_test",
+ srcs = ["moving_averages_test.py"],
+ additional_deps = [
+ ":combinations",
+ "@absl_py//absl/testing:parameterized",
+ "//tensorflow/python/eager:test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variables",
+ ],
+ tags = [
+ "no_pip",
+ ],
+)
+
+cuda_py_test(
name = "optimizer_v2_test",
srcs = ["optimizer_v2_test.py"],
additional_deps = [
diff --git a/tensorflow/contrib/distribute/python/moving_averages_test.py b/tensorflow/contrib/distribute/python/moving_averages_test.py
new file mode 100644
index 0000000000..119352ad91
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/moving_averages_test.py
@@ -0,0 +1,141 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for training.moving_averages when using a DistributionStrategy."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.contrib.distribute.python import combinations
+from tensorflow.python.eager import test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.training import moving_averages
+
+
+all_combinations = combinations.combine(
+ distribution=[combinations.default_strategy,
+ combinations.one_device_strategy,
+ combinations.mirrored_strategy_with_gpu_and_cpu],
+ mode=["graph"])
+
+
+class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase):
+
+ @combinations.generate(all_combinations)
+ def testTowerModeWithoutZeroDebias(self, distribution):
+ tower_id = [0]
+
+ def tower_fn():
+ var = variables.Variable([10.0, 11.0])
+ val = constant_op.constant([1.0 + tower_id[0], 2.0 - tower_id[0]])
+ tower_id[0] += 1
+ decay = 0.25
+ assign = moving_averages.assign_moving_average(
+ var, val, decay, zero_debias=False)
+ return var, assign
+
+ with distribution.scope(), self.cached_session() as sess:
+ var, assign = distribution.call_for_each_tower(tower_fn)
+ variables.global_variables_initializer().run()
+ self.assertAllClose([10.0, 11.0], var.eval())
+ sess.run(distribution.unwrap(assign))
+ # Mean of val across calls to tower_fn().
+ average_val = [1.0 + 0.5 * (tower_id[0] - 1),
+ 2.0 - 0.5 * (tower_id[0] - 1)]
+ val_weight = 1.0 - 0.25
+ self.assertAllClose(
+ [10.0 * 0.25 + average_val[0] * val_weight,
+ 11.0 * 0.25 + average_val[1] * val_weight],
+ var.eval())
+
+ @combinations.generate(all_combinations)
+ def testTowerMode(self, distribution):
+ tower_id = [0]
+
+ def tower_fn():
+ var = variables.Variable([0.0, 0.0])
+ val = constant_op.constant([1.0 + tower_id[0], 2.0 - tower_id[0]])
+ tower_id[0] += 1
+ decay = 0.25
+ assign = moving_averages.assign_moving_average(var, val, decay)
+ return var, assign.op
+
+ with distribution.scope(), self.cached_session() as sess:
+ var, assign_op = distribution.call_for_each_tower(tower_fn)
+ variables.global_variables_initializer().run()
+ self.assertAllClose([0.0, 0.0], var.eval())
+ sess.run(distribution.unwrap(assign_op))
+ # Mean of val across calls to tower_fn().
+ average_val = [1.0 + 0.5 * (tower_id[0] - 1),
+ 2.0 - 0.5 * (tower_id[0] - 1)]
+ self.assertAllClose(average_val, var.eval())
+
+ @combinations.generate(all_combinations)
+ def testCrossTowerWithoutZeroDebias(self, distribution):
+ with distribution.scope(), self.cached_session() as sess:
+ var = variables.Variable([10.0, 11.0])
+ val = constant_op.constant([1.0, 2.0])
+ decay = 0.25
+ # NOTE(josh11b): We currently generate an error if val is a PerDevice value.
+ assign = moving_averages.assign_moving_average(
+ var, val, decay, zero_debias=False)
+
+ variables.global_variables_initializer().run()
+ self.assertAllClose([10.0, 11.0], var.eval())
+ sess.run(assign)
+ average_val = [1.0, 2.0]
+ val_weight = 1.0 - 0.25
+ self.assertAllClose(
+ [10.0 * 0.25 + average_val[0] * val_weight,
+ 11.0 * 0.25 + average_val[1] * val_weight],
+ var.eval())
+ # Also try assign.op.
+ sess.run(assign.op)
+ orig_weight = 0.25 * 0.25
+ val_weight = 1.0 - orig_weight
+ self.assertAllClose(
+ [10.0 * orig_weight + average_val[0] * val_weight,
+ 11.0 * orig_weight + average_val[1] * val_weight],
+ var.eval())
+
+ @combinations.generate(all_combinations)
+ def testCrossTower(self, distribution):
+ with distribution.scope(), self.cached_session() as sess:
+ var = variables.Variable([0.0, 0.0])
+ val = array_ops.placeholder(dtypes.float32)
+ decay = 0.25
+ # NOTE(josh11b): We currently generate an error if val is a PerDevice value.
+ assign = moving_averages.assign_moving_average(var, val, decay)
+
+ variables.global_variables_initializer().run()
+ self.assertAllClose([0.0, 0.0], var.eval())
+ sess.run(assign, feed_dict={val: [1.0, 2.0]})
+ self.assertAllClose([1.0, 2.0], var.eval())
+
+ # Also try assign.op.
+ sess.run(assign.op, feed_dict={val: [10.0, 0.0]})
+ self.assertAllClose(
+ [(1.0 * 0.25 + 10.0) / (1.0 * 0.25 + 1.0),
+ (2.0 * 0.25 + 0.0) / (1.0 * 0.25 + 1.0)],
+ var.eval())
+
+
+if __name__ == "__main__":
+ test.main()