Diffstat (limited to 'tensorflow/contrib/distribute/python/minimize_loss_test.py')
-rw-r--r--  tensorflow/contrib/distribute/python/minimize_loss_test.py  |  115 +++++++++++++++++++++----------
1 file changed, 80 insertions(+), 35 deletions(-)
diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py
index e134fe34e1..d2054715f1 100644
--- a/tensorflow/contrib/distribute/python/minimize_loss_test.py
+++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py
@@ -44,13 +44,16 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
combinations.distributions_and_v1_optimizers(),
combinations.combine(mode=["graph"], use_callable_loss=[True, False])
+ combinations.combine(mode=["eager"], use_callable_loss=[True]),
- combinations.combine(is_tpu=[False])) +
- combinations.combine(
- distribution=[combinations.tpu_strategy],
- optimizer_fn=[combinations.adam_optimizer_v1_fn],
- mode=["graph"],
- use_callable_loss=[False],
- is_tpu=[True]))
+ combinations.combine(is_tpu=[False])) + combinations.combine(
+ distribution=[combinations.tpu_strategy],
+ optimizer_fn=[
+ combinations.adam_optimizer_v1_fn,
+ # TODO(isaprykin): Make Adam v2 work with while_loops
+ # and TPUs.
+ ],
+ mode=["graph"],
+ use_callable_loss=[False],
+ is_tpu=[True]))
def testTrainNetwork(self, distribution, optimizer_fn, use_callable_loss,
is_tpu):
with distribution.scope():
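Aside: the parameter grid in the decorator above is assembled from three
operations: combinations.combine (one parameter dict per element of the
cross product of the given option lists), combinations.times (cross
product of already-combined lists, merging their dicts), and plain list
concatenation with +, which is how the TPU-only Adam case is appended to
the CPU/GPU grid. A minimal sketch of those semantics, using hypothetical
re-implementations for illustration only (not the module's actual code):

    import itertools

    def combine(**kwargs):
        # One dict per element of the cross product of the option lists.
        keys = sorted(kwargs)
        return [dict(zip(keys, values))
                for values in itertools.product(*(kwargs[k] for k in keys))]

    def times(*combined):
        # Cross product of several combined lists, merging their dicts.
        result = []
        for dicts in itertools.product(*combined):
            merged = {}
            for d in dicts:
                merged.update(d)
            result.append(merged)
        return result

    grid = times(combine(mode=["graph"], use_callable_loss=[True, False]),
                 combine(is_tpu=[False])) + combine(
                     distribution=["tpu_strategy"], mode=["graph"],
                     use_callable_loss=[False], is_tpu=[True])
    assert len(grid) == 3  # two graph-mode CPU/GPU cases plus one TPU case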
@@ -101,7 +104,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
distribution=[combinations.tpu_strategy],
optimizer_fn=[
combinations.adam_optimizer_v1_fn,
- combinations.gradient_descent_optimizer_v1_fn
+ combinations.gradient_descent_optimizer_v1_fn,
+ combinations.gradient_descent_optimizer_v2_fn,
],
mode=["graph"],
is_tpu=[True]))
@@ -171,13 +175,28 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
set(created_variables))
@combinations.generate(
- combinations.times(combinations.distributions_and_v1_optimizers(),
- combinations.combine(
- mode=["graph", "eager"],
- momentum=[0.8, 0.9, 0.99],
- renorm=[False, True])))
+ combinations.times(
+ combinations.combine(momentum=[0.8, 0.9, 0.99], renorm=[False, True]),
+ combinations.times(
+ combinations.distributions_and_v1_optimizers(),
+ combinations.combine(
+ mode=["graph", "eager"],
+ is_tpu=[False],
+ # TODO(isaprykin): Allow False here. Currently subsequent
+ # towers will re-execute UPDATE_OPS of previous towers.
+ update_ops_in_cross_tower_mode=[True])) +
+ combinations.combine(
+ distribution=[combinations.tpu_strategy_single_iteration],
+ optimizer_fn=[
+ combinations.gradient_descent_optimizer_v1_fn,
+ combinations.gradient_descent_optimizer_v2_fn
+ ],
+ mode=["graph"],
+ is_tpu=[True],
+ update_ops_in_cross_tower_mode=[False])))
def testTrainNetworkWithBatchNorm(self, distribution, optimizer_fn, momentum,
- renorm):
+ renorm, is_tpu,
+ update_ops_in_cross_tower_mode):
"""Verifies that moving mean updates are reduced across towers."""
with distribution.scope():
num_towers = len(distribution.worker_devices)
@@ -185,7 +204,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
optimizer_fn,
batch_per_epoch=num_towers,
momentum=momentum,
- renorm=renorm)
+ renorm=renorm,
+ update_ops_in_tower_mode=not update_ops_in_cross_tower_mode)
# Disable prefetching since that makes the specific input on each device
# to be non deterministic, and this test relies on specific input being
@@ -196,16 +216,18 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
dataset_fn).make_one_shot_iterator()
def run_step():
- return control_flow_ops.group(
- distribution.unwrap(
- distribution.call_for_each_tower(
- model_fn,
- iterator.get_next(),
- run_concurrently=batchnorm.built)) +
- ops.get_collection(ops.GraphKeys.UPDATE_OPS))
+ fetches = distribution.unwrap(
+ distribution.call_for_each_tower(
+ model_fn, iterator.get_next(),
+ run_concurrently=batchnorm.built))
+ if update_ops_in_cross_tower_mode:
+ fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS)
+ return control_flow_ops.group(fetches)
if not context.executing_eagerly():
with self.test_session() as sess:
+ if is_tpu:
+ sess.run(tpu.initialize_system())
run_step = sess.make_callable(run_step())
self.evaluate(variables_lib.global_variables_initializer())
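Aside: the graph-mode branch above brackets the run with
tpu.initialize_system() and tpu.shutdown_system() (the shutdown half
appears at the end of the test, below). A standalone sketch of that
bracketing, assuming the contrib-era tpu module from TF 1.x:

    import tensorflow as tf
    from tensorflow.contrib import tpu

    with tf.Session() as sess:
        # Bring the TPU system up before any TPU computation runs.
        sess.run(tpu.initialize_system())
        # ... build run_step(), wrap it with sess.make_callable(), and
        # drive the training loop here ...
        # Tear the TPU system down once the test is done with the device.
        sess.run(tpu.shutdown_system())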
@@ -229,22 +251,40 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
expected_moving_mean - averaged_batch_mean(i)) * (1.0 - momentum))
self.assertNear(expected_moving_means[i], moving_means[i], 0.0001)
+ if is_tpu:
+ with self.test_session() as sess:
+ sess.run(tpu.shutdown_system())
+
@combinations.generate(
combinations.times(
combinations.combine(
- distribution=[combinations.one_device_strategy,
- combinations.mirrored_strategy_with_gpu_and_cpu,
- combinations.mirrored_strategy_with_two_gpus],
- optimizer_fn=[combinations.gradient_descent_optimizer_v1_fn,
- combinations.gradient_descent_optimizer_v2_fn],
- loss_reduction=[losses_impl.Reduction.SUM,
- losses_impl.Reduction.MEAN,
- losses_impl.Reduction.SUM_OVER_BATCH_SIZE,
- losses_impl.Reduction.SUM_OVER_NONZERO_WEIGHTS]),
- combinations.combine(mode=["graph"], use_callable_loss=[True, False])
- + combinations.combine(mode=["eager"], use_callable_loss=[True])))
+ optimizer_fn=[
+ combinations.gradient_descent_optimizer_v1_fn,
+ combinations.gradient_descent_optimizer_v2_fn
+ ],
+ loss_reduction=[
+ losses_impl.Reduction.SUM, losses_impl.Reduction.MEAN,
+ losses_impl.Reduction.SUM_OVER_BATCH_SIZE,
+ losses_impl.Reduction.SUM_OVER_NONZERO_WEIGHTS
+ ]),
+ combinations.times(
+ combinations.combine(
+ distribution=[
+ combinations.one_device_strategy,
+ combinations.mirrored_strategy_with_gpu_and_cpu,
+ combinations.mirrored_strategy_with_two_gpus
+ ],
+ is_tpu=[False]),
+ combinations.combine(
+ mode=["graph"], use_callable_loss=[True, False]) +
+ combinations.combine(mode=["eager"], use_callable_loss=[True])) +
+ combinations.combine(
+ distribution=[combinations.tpu_strategy_single_iteration],
+ is_tpu=[True],
+ mode=["graph"],
+ use_callable_loss=[True, False])))
def testMeanVsSum(self, distribution, optimizer_fn, loss_reduction,
- use_callable_loss):
+ use_callable_loss, is_tpu):
with distribution.scope():
all_vars = []
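Aside: testMeanVsSum sweeps the four losses_impl.Reduction modes, which
differ only in the denominator applied to the summed weighted per-example
losses. A small worked example in plain numpy (values are illustrative):

    import numpy as np

    per_example_loss = np.array([1.0, 2.0, 3.0])
    weights = np.ones_like(per_example_loss)
    weighted_sum = (per_example_loss * weights).sum()

    loss_sum = weighted_sum                                  # SUM -> 6.0
    loss_mean = weighted_sum / weights.sum()                 # MEAN -> 2.0
    loss_batch = weighted_sum / per_example_loss.size        # SUM_OVER_BATCH_SIZE -> 2.0
    loss_nonzero = weighted_sum / np.count_nonzero(weights)  # SUM_OVER_NONZERO_WEIGHTS -> 2.0

With unit weights the three mean-like reductions coincide; they diverge
once weights are non-uniform or zero for some examples. The practical
difference the test checks is that a gradient computed under SUM scales
with the number of examples, while the mean-like reductions do not.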
@@ -280,12 +320,13 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
if not context.executing_eagerly():
with self.test_session() as sess:
+ if is_tpu:
+ sess.run(tpu.initialize_system())
run_step = sess.make_callable(run_step())
self.evaluate(variables_lib.global_variables_initializer())
run_step()
- self.assertEqual(distribution.num_towers, len(all_vars))
v = all_vars[0]
self.assertTrue(all([v is vi for vi in all_vars[1:]]))
weight = numpy.squeeze(self.evaluate(distribution.fetch(v)))
@@ -312,6 +353,10 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
# One of the mean loss reductions.
self.assertNear(weight, 2 + 10.6, 0.0001)
+ if is_tpu:
+ with self.test_session() as sess:
+ sess.run(tpu.shutdown_system())
+
if __name__ == "__main__":
test.main()
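Aside: for reference, the shape of the rewritten run_step in the batch
norm test above, as a standalone sketch (distribution, model_fn, and
iterator are stand-ins for the objects the test constructs; the method
names follow the contrib DistributionStrategy API used in the diff):

    import tensorflow as tf

    def make_run_step(distribution, model_fn, iterator,
                      update_ops_in_cross_tower_mode):
        # Run the model once per tower and collect the per-tower fetches.
        fetches = list(distribution.unwrap(
            distribution.call_for_each_tower(model_fn, iterator.get_next())))
        if update_ops_in_cross_tower_mode:
            # Append the UPDATE_OPS collected by batch norm so the moving
            # average updates execute exactly once per step, rather than
            # being re-executed by every subsequent tower.
            fetches += tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        return tf.group(*fetches)

In tower mode the same updates are attached inside each tower's step
instead, which is what the new update_ops_in_tower_mode argument toggles.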