aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/strategy_test_lib.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/distribute/python/strategy_test_lib.py')
-rw-r--r--tensorflow/contrib/distribute/python/strategy_test_lib.py9
1 files changed, 6 insertions, 3 deletions
diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py
index d2fe8b3b1e..baed0ebaae 100644
--- a/tensorflow/contrib/distribute/python/strategy_test_lib.py
+++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py
@@ -26,6 +26,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.layers import core
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import optimizer
@@ -110,7 +111,8 @@ class DistributionTestBase(test.TestCase):
before_list.append(fetched)
# control_dependencies irrelevant but harmless in eager execution
with ops.control_dependencies([fetched]):
- g = d.reduce("sum", g, destinations=v)
+ g = d.reduce(
+ variable_scope.VariableAggregation.SUM, g, destinations=v)
with ops.control_dependencies(d.unwrap(d.update(v, update, g))):
after_list.append(d.read_var(v))
return before_list, after_list
@@ -162,7 +164,8 @@ class DistributionTestBase(test.TestCase):
fetched = d.read_var(v)
before_list.append(fetched)
with ops.control_dependencies([fetched]):
- g = d.reduce("sum", g, destinations=v)
+ g = d.reduce(
+ variable_scope.VariableAggregation.SUM, g, destinations=v)
with ops.control_dependencies(d.unwrap(d.update(v, update, g))):
after_list.append(d.read_var(v))
return before_list, after_list
@@ -184,7 +187,7 @@ class DistributionTestBase(test.TestCase):
with d.scope():
map_in = [constant_op.constant(i) for i in range(10)]
map_out = d.map(map_in, lambda x, y: x * y, 2)
- observed = d.reduce("sum", map_out)
+ observed = d.reduce(variable_scope.VariableAggregation.SUM, map_out)
expected = 90 # 2 * (0 + 1 + ... + 9)
self.assertEqual(expected, observed.numpy())