aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/distribute_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/distribute_test.py')
-rw-r--r--tensorflow/python/training/distribute_test.py39
1 files changed, 36 insertions, 3 deletions
diff --git a/tensorflow/python/training/distribute_test.py b/tensorflow/python/training/distribute_test.py
index 0a4f19c31f..694145ede7 100644
--- a/tensorflow/python/training/distribute_test.py
+++ b/tensorflow/python/training/distribute_test.py
@@ -29,6 +29,14 @@ class _TestTowerContext(distribute.TowerContext):
return kwargs["test_arg"]
+def _get_test_variable(name, synchronization, aggregation):
+ return {
+ "name": name,
+ "synchronization": synchronization,
+ "aggregation": aggregation
+ }
+
+
class _TestStrategy(distribute.DistributionStrategy):
def _call_for_each_tower(self, fn, *args, **kwargs):
@@ -36,7 +44,8 @@ class _TestStrategy(distribute.DistributionStrategy):
return fn(*args, **kwargs)
def _create_variable(self, next_creator, *args, **kwargs):
- return kwargs["name"]
+ return _get_test_variable(kwargs["name"], kwargs["synchronization"],
+ kwargs["aggregation"])
def _assert_in_default_state(t):
@@ -61,7 +70,11 @@ class TestStrategyTest(test.TestCase):
self.assertTrue(distribute.has_distribution_strategy())
self.assertIs(dist, distribute.get_distribution_strategy())
self.assertEqual("foo", tower_context.merge_call(None, test_arg="foo"))
- self.assertEqual("bar", variable_scope.variable(1.0, name="bar"))
+ expected_value = _get_test_variable(
+ "bar", variable_scope.VariableSynchronization.AUTO,
+ variable_scope.VariableAggregation.NONE)
+ self.assertDictEqual(expected_value,
+ variable_scope.variable(1.0, name="bar"))
with self.assertRaises(RuntimeError):
dist.call_for_each_tower(run_fn)
@@ -77,7 +90,27 @@ class TestStrategyTest(test.TestCase):
self.assertIs(dist, distribute.get_cross_tower_context())
self.assertTrue(distribute.has_distribution_strategy())
self.assertIs(dist, distribute.get_distribution_strategy())
- self.assertEqual("baz", variable_scope.variable(1.0, name="baz"))
+ expected_value = _get_test_variable(
+ "baz", variable_scope.VariableSynchronization.AUTO,
+ variable_scope.VariableAggregation.NONE)
+ self.assertDictEqual(expected_value,
+ variable_scope.variable(1.0, name="baz"))
+ _assert_in_default_state(self)
+
+ def testSettingSynchronizationAndAggregation(self):
+ _assert_in_default_state(self)
+ dist = _TestStrategy()
+ with dist.scope():
+ expected_value = _get_test_variable(
+ "baz", variable_scope.VariableSynchronization.ON_WRITE,
+ variable_scope.VariableAggregation.MEAN)
+ self.assertDictEqual(
+ expected_value,
+ variable_scope.variable(
+ 1.0,
+ name="baz",
+ synchronization=variable_scope.VariableSynchronization.ON_WRITE,
+ aggregation=variable_scope.VariableAggregation.MEAN))
_assert_in_default_state(self)