diff options
Diffstat (limited to 'tensorflow/contrib/distribute/python/strategy_test_lib.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/strategy_test_lib.py | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py index d2fe8b3b1e..baed0ebaae 100644 --- a/tensorflow/contrib/distribute/python/strategy_test_lib.py +++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.layers import core from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import optimizer @@ -110,7 +111,8 @@ class DistributionTestBase(test.TestCase): before_list.append(fetched) # control_dependencies irrelevant but harmless in eager execution with ops.control_dependencies([fetched]): - g = d.reduce("sum", g, destinations=v) + g = d.reduce( + variable_scope.VariableAggregation.SUM, g, destinations=v) with ops.control_dependencies(d.unwrap(d.update(v, update, g))): after_list.append(d.read_var(v)) return before_list, after_list @@ -162,7 +164,8 @@ class DistributionTestBase(test.TestCase): fetched = d.read_var(v) before_list.append(fetched) with ops.control_dependencies([fetched]): - g = d.reduce("sum", g, destinations=v) + g = d.reduce( + variable_scope.VariableAggregation.SUM, g, destinations=v) with ops.control_dependencies(d.unwrap(d.update(v, update, g))): after_list.append(d.read_var(v)) return before_list, after_list @@ -184,7 +187,7 @@ class DistributionTestBase(test.TestCase): with d.scope(): map_in = [constant_op.constant(i) for i in range(10)] map_out = d.map(map_in, lambda x, y: x * y, 2) - observed = d.reduce("sum", map_out) + observed = d.reduce(variable_scope.VariableAggregation.SUM, map_out) expected = 90 # 2 * (0 + 1 + ... + 9) self.assertEqual(expected, observed.numpy()) |