aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/distribute/python/cross_tower_ops_test.py')
-rw-r--r--tensorflow/contrib/distribute/python/cross_tower_ops_test.py99
1 files changed, 56 insertions, 43 deletions
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
index 3508c9d599..490371477a 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
@@ -26,12 +26,12 @@ import numpy as np
from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
from tensorflow.contrib.distribute.python import cross_tower_utils
+from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.contrib.distribute.python import values as value_lib
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import test
-from tensorflow.python.estimator import run_config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -40,9 +40,17 @@ from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.training import device_util
-def _make_per_device(values, devices):
+def _make_per_device(values, devices, regroup=False):
devices = cross_tower_ops_lib.get_devices_from(devices)
assert len(values) == len(devices)
+
+ # We simulate the result of regroup called on PerDevice which strips the
+ # PerDevice wrapper if it has only one value.
+ if len(values) == 1 and regroup:
+ with ops.device(devices[0]):
+ placed_v = array_ops.identity(values[0])
+ return placed_v
+
index = {}
for d, v in zip(devices, values):
with ops.device(d):
@@ -127,7 +135,7 @@ class CrossTowerOpsTestBase(test.TestCase, parameterized.TestCase):
destination_list = devices
all_destinations = [
- None, destination_mirrored, destination_different, destination_str,
+ destination_mirrored, destination_different, destination_str,
destination_list
]
@@ -138,24 +146,24 @@ class CrossTowerOpsTestBase(test.TestCase, parameterized.TestCase):
vs.VariableAggregation.MEAN,
per_device,
destinations=destinations),
- _fake_mirrored(mean, destinations or per_device))
+ _fake_mirrored(mean, destinations))
self._assert_values_equal(
cross_tower_ops.reduce(
vs.VariableAggregation.MEAN,
per_device_2,
destinations=destinations),
- _fake_mirrored(mean_2, destinations or per_device))
+ _fake_mirrored(mean_2, destinations))
self._assert_values_equal(
cross_tower_ops.reduce(
vs.VariableAggregation.SUM, per_device,
destinations=destinations),
- _fake_mirrored(mean * len(devices), destinations or per_device))
+ _fake_mirrored(mean * len(devices), destinations))
self._assert_values_equal(
cross_tower_ops.reduce(
vs.VariableAggregation.SUM,
per_device_2,
destinations=destinations),
- _fake_mirrored(mean_2 * len(devices), destinations or per_device))
+ _fake_mirrored(mean_2 * len(devices), destinations))
# test batch_reduce()
for d1, d2 in itertools.product(all_destinations, all_destinations):
@@ -163,25 +171,22 @@ class CrossTowerOpsTestBase(test.TestCase, parameterized.TestCase):
cross_tower_ops.batch_reduce(vs.VariableAggregation.MEAN,
[(per_device, d1), (per_device_2, d2)]),
[
- _fake_mirrored(mean, d1 or per_device),
- _fake_mirrored(mean_2, d2 or per_device_2)
+ _fake_mirrored(mean, d1),
+ _fake_mirrored(mean_2, d2)
])
self._assert_values_equal(
cross_tower_ops.batch_reduce(vs.VariableAggregation.SUM,
[(per_device, d1), (per_device_2, d2)]),
[
- _fake_mirrored(mean * len(devices), d1 or per_device),
- _fake_mirrored(mean_2 * len(devices), d2 or per_device_2)
+ _fake_mirrored(mean * len(devices), d1),
+ _fake_mirrored(mean_2 * len(devices), d2)
])
# test broadcast()
for destinations in all_destinations:
- if destinations is None:
- continue
- else:
- self._assert_values_equal(
- cross_tower_ops.broadcast(constant_op.constant(1.), destinations),
- _fake_mirrored(1., destinations))
+ self._assert_values_equal(
+ cross_tower_ops.broadcast(constant_op.constant(1.), destinations),
+ _fake_mirrored(1., destinations))
class SingleWorkerCrossTowerOpsTest(CrossTowerOpsTestBase):
@@ -368,14 +373,27 @@ class MultiWorkerCrossTowerOpsTest(multi_worker_test_base.MultiWorkerTestBase,
("xring", 2, -1)], 0, 0, 0)),
],
distribution=[
- combinations.multi_worker_strategy_with_cpu,
- combinations.multi_worker_strategy_with_one_gpu,
- combinations.multi_worker_strategy_with_two_gpus
+ combinations.NamedDistribution(
+ "MirroredCPU",
+ lambda: mirrored_strategy.MirroredStrategy(num_gpus=0),
+ required_gpus=0),
+ combinations.NamedDistribution(
+ "Mirrored1GPU",
+ lambda: mirrored_strategy.MirroredStrategy(num_gpus=1),
+ required_gpus=1),
+ combinations.NamedDistribution(
+ "Mirrored2GPUs",
+ lambda: mirrored_strategy.MirroredStrategy(num_gpus=2),
+ required_gpus=2),
],
mode=["graph"])
@combinations.generate(multi_worker_allreduce_combinations)
def testReductionAndBroadcast(self, cross_tower_ops, distribution):
+ distribution.configure(cluster_spec={
+ "worker":
+ ["/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1"]
+ })
with distribution.scope():
self._testReductionAndBroadcast(cross_tower_ops, distribution)
@@ -388,13 +406,8 @@ class MultiWorkerCollectiveAllReduceTest(
@classmethod
def setUpClass(cls):
"""Create a local cluster with 2 workers."""
- cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster(
+ cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
num_workers=3, num_ps=0)
- cls._cluster_spec = {
- run_config.TaskType.WORKER: [
- "fake_worker_0", "fake_worker_1", "fake_worker_2"
- ]
- }
def setUp(self):
super(MultiWorkerCollectiveAllReduceTest, self).setUp()
@@ -428,7 +441,8 @@ class MultiWorkerCollectiveAllReduceTest(
]
else:
devices = ["/job:%s/task:%d" % (task_type, task_id)]
- return collective_all_reduce_ops, devices, self._workers[task_id].target
+ return (collective_all_reduce_ops, devices,
+ "grpc://" + self._cluster_spec[task_type][task_id])
def _assert_values_equal(self, left, right, sess):
if isinstance(left, list):
@@ -455,7 +469,8 @@ class MultiWorkerCollectiveAllReduceTest(
num_workers = 1
worker_device = None
else:
- num_workers = len(self._workers)
+ num_workers = len(self._cluster_spec.get("chief", [])) + len(
+ self._cluster_spec.get("worker", []))
worker_device = "/job:%s/task:%d" % (task_type, task_id)
with ops.Graph().as_default(), \
ops.device(worker_device), \
@@ -463,7 +478,7 @@ class MultiWorkerCollectiveAllReduceTest(
# Collective ops doesn't support scalar tensors, so we have to construct
# 1-d tensors.
values = [constant_op.constant([float(d)]) for d in range(len(devices))]
- per_device = _make_per_device(values, devices)
+ per_device = _make_per_device(values, devices, regroup=True)
mean = np.array([(len(devices) - 1.) / 2.])
values_2 = [constant_op.constant([d + 1.0]) for d in range(len(devices))]
@@ -476,7 +491,7 @@ class MultiWorkerCollectiveAllReduceTest(
destination_list = devices
all_destinations = [
- destination_different, None, destination_mirrored, destination_str,
+ destination_different, destination_mirrored, destination_str,
destination_list
]
@@ -487,27 +502,27 @@ class MultiWorkerCollectiveAllReduceTest(
vs.VariableAggregation.MEAN,
per_device,
destinations=destinations),
- _fake_mirrored(mean, destinations or per_device), sess)
+ _fake_mirrored(mean, destinations), sess)
self._assert_values_equal(
collective_all_reduce.reduce(
vs.VariableAggregation.MEAN,
per_device_2,
destinations=destinations),
- _fake_mirrored(mean_2, destinations or per_device), sess)
+ _fake_mirrored(mean_2, destinations), sess)
self._assert_values_equal(
collective_all_reduce.reduce(
vs.VariableAggregation.SUM,
per_device,
destinations=destinations),
- _fake_mirrored(mean * len(devices) * num_workers, destinations or
- per_device), sess)
+ _fake_mirrored(mean * len(devices) * num_workers, destinations),
+ sess)
self._assert_values_equal(
collective_all_reduce.reduce(
vs.VariableAggregation.SUM,
per_device_2,
destinations=destinations),
- _fake_mirrored(mean_2 * len(devices) * num_workers, destinations or
- per_device), sess)
+ _fake_mirrored(mean_2 * len(devices) * num_workers, destinations),
+ sess)
# test batch_reduce()
for d1, d2 in itertools.product(all_destinations, all_destinations):
@@ -516,24 +531,22 @@ class MultiWorkerCollectiveAllReduceTest(
[(per_device, d1),
(per_device_2, d2)]),
[
- _fake_mirrored(mean, d1 or per_device),
- _fake_mirrored(mean_2, d2 or per_device_2)
+ _fake_mirrored(mean, d1),
+ _fake_mirrored(mean_2, d2)
], sess)
self._assert_values_equal(
collective_all_reduce.batch_reduce(vs.VariableAggregation.SUM,
[(per_device, d1),
(per_device_2, d2)]),
[
- _fake_mirrored(mean * len(devices) * num_workers, d1 or
- per_device),
- _fake_mirrored(mean_2 * len(devices) * num_workers, d2 or
- per_device_2)
+ _fake_mirrored(mean * len(devices) * num_workers, d1),
+ _fake_mirrored(mean_2 * len(devices) * num_workers, d2)
], sess)
return True
@combinations.generate(
- combinations.combine(mode=["graph"], num_gpus=[0, 1, 2]))
+ combinations.combine(mode=["graph"], num_gpus=[0, 1, 2], required_gpus=1))
def testReductionDistributed(self, num_gpus):
if context.num_gpus() < num_gpus:
return