diff options
author | 2018-08-16 21:31:39 -0700 | |
---|---|---|
committer | 2018-08-16 21:35:13 -0700 | |
commit | 439d8c4809139b163853fe87e8c5cdaba5d832eb (patch) | |
tree | b599479c27353cd3c5a537c3fddc388d03fec2f9 | |
parent | a1606d5e0f667fddd7f3f5705bda3aee5b3c2554 (diff) |
Merge MultiWorkerMirroredStrategy into MirroredStrategy
PiperOrigin-RevId: 209099475
8 files changed, 169 insertions, 247 deletions
diff --git a/tensorflow/contrib/distribute/BUILD b/tensorflow/contrib/distribute/BUILD index d3628d480d..c16f1d6035 100644 --- a/tensorflow/contrib/distribute/BUILD +++ b/tensorflow/contrib/distribute/BUILD @@ -29,7 +29,6 @@ py_library( "//tensorflow/contrib/distribute/python:cross_tower_ops", "//tensorflow/contrib/distribute/python:mirrored_strategy", "//tensorflow/contrib/distribute/python:monitor", - "//tensorflow/contrib/distribute/python:multi_worker_strategy", "//tensorflow/contrib/distribute/python:one_device_strategy", "//tensorflow/contrib/distribute/python:parameter_server_strategy", "//tensorflow/contrib/distribute/python:step_fn", diff --git a/tensorflow/contrib/distribute/__init__.py b/tensorflow/contrib/distribute/__init__.py index 2c93ce92ce..588a4f2898 100644 --- a/tensorflow/contrib/distribute/__init__.py +++ b/tensorflow/contrib/distribute/__init__.py @@ -23,7 +23,6 @@ from tensorflow.contrib.distribute.python.collective_all_reduce_strategy import from tensorflow.contrib.distribute.python.cross_tower_ops import * from tensorflow.contrib.distribute.python.mirrored_strategy import MirroredStrategy from tensorflow.contrib.distribute.python.monitor import Monitor -from tensorflow.contrib.distribute.python.multi_worker_strategy import MultiWorkerMirroredStrategy from tensorflow.contrib.distribute.python.one_device_strategy import OneDeviceStrategy from tensorflow.contrib.distribute.python.parameter_server_strategy import ParameterServerStrategy from tensorflow.contrib.distribute.python.step_fn import * @@ -40,7 +39,6 @@ _allowed_symbols = [ 'CrossTowerOps', 'DistributionStrategy', 'MirroredStrategy', - 'MultiWorkerMirroredStrategy', 'Monitor', 'OneDeviceStrategy', 'ParameterServerStrategy', diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index ae50d4e3fc..59efd17746 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -72,31 +72,21 @@ py_library( ":cross_tower_ops", ":shared_variable_creator", ":values", + "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:device", "//tensorflow/python:device_util", "//tensorflow/python:distribute", "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", "//tensorflow/python:pywrap_tensorflow", "//tensorflow/python:training", + "//tensorflow/python:util", "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", "//tensorflow/python/eager:context", "//tensorflow/python/eager:tape", - "@six_archive//:six", - ], -) - -py_library( - name = "multi_worker_strategy", - srcs = ["multi_worker_strategy.py"], - visibility = ["//tensorflow:internal"], - deps = [ - ":mirrored_strategy", - ":values", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:training", - "//tensorflow/python:util", ], ) @@ -185,7 +175,6 @@ py_library( ], deps = [ ":mirrored_strategy", - ":multi_worker_strategy", ":one_device_strategy", ":tpu_strategy", "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip", @@ -220,9 +209,13 @@ py_test( ], deps = [ ":mirrored_strategy", + ":multi_worker_test_base", ":strategy_test_lib", + "//tensorflow/python:constant_op", "//tensorflow/python:distribute", + "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:training", "//tensorflow/python:variable_scope", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py index aeec9c44d7..2fbadfe0f5 100644 --- a/tensorflow/contrib/distribute/python/combinations.py +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -48,7 +48,6 @@ import six from tensorflow.contrib.cluster_resolver import TPUClusterResolver from tensorflow.contrib.distribute.python import mirrored_strategy as mirrored_lib -from tensorflow.contrib.distribute.python import multi_worker_strategy from tensorflow.contrib.distribute.python import one_device_strategy as one_device_lib from tensorflow.contrib.distribute.python import tpu_strategy as tpu_lib from tensorflow.contrib.optimizer_v2 import adam as adam_v2 @@ -344,31 +343,31 @@ mirrored_strategy_with_two_gpus = NamedDistribution( multi_worker_strategy_with_cpu = NamedDistribution( "MultiWorkerCPU", - lambda: multi_worker_strategy.MultiWorkerMirroredStrategy( - cluster={ + lambda: mirrored_lib.MirroredStrategy( + cluster_spec={ "worker": [ "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1" ] }, - num_gpus_per_worker=0), 0) + num_gpus=0), 0) multi_worker_strategy_with_one_gpu = NamedDistribution( "MultiWorker1GPU", - lambda: multi_worker_strategy.MultiWorkerMirroredStrategy( - cluster={ + lambda: mirrored_lib.MirroredStrategy( + cluster_spec={ "worker": [ "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1" ] }, - num_gpus_per_worker=1), 1) + num_gpus=1), 1) multi_worker_strategy_with_two_gpus = NamedDistribution( "MultiWorker2GPUs", - lambda: multi_worker_strategy.MultiWorkerMirroredStrategy( - cluster={ + lambda: mirrored_lib.MirroredStrategy( + cluster_spec={ "worker": [ "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1" ] }, - num_gpus_per_worker=2), 2) + num_gpus=2), 2) adam_optimizer_v1_fn = NamedObject( "AdamV1", lambda: adam.AdamOptimizer(0.2, epsilon=1)) diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index edd5c6d17a..6981449a4c 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -19,11 +19,13 @@ from __future__ import division from __future__ import print_function import contextlib +from functools import partial import threading from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib from tensorflow.contrib.distribute.python import shared_variable_creator from tensorflow.contrib.distribute.python import values +from tensorflow.core.protobuf import cluster_pb2 from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import context from tensorflow.python.eager import tape @@ -37,6 +39,7 @@ from tensorflow.python.ops import variables as variables_lib from tensorflow.python.training import coordinator from tensorflow.python.training import device_util from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import server_lib from tensorflow.python.util import nest @@ -291,24 +294,112 @@ def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs): class MirroredStrategy(distribute_lib.DistributionStrategy): - """Mirrors vars to distribute across multiple devices on a single machine. + """Mirrors vars to distribute across multiple devices and machines. + + This strategy uses one tower per device and sync replication for its multi-GPU + version. + + When `cluster_spec` is given, it turns into the mulit-worker version that + works on multiple workers with in-graph replication. + + There are several important concepts for distributed TensorFlow, e.g. + `client`, `job`, 'task', `cluster`, `in-graph replication` and + 'synchronous training' and they have already been defined in the + [TensorFlow's documentation](https://www.tensorflow.org/deploy/distributed). + The distribution strategy inherits these concepts as well and in addition to + that we also clarify several more concepts: + * **In-graph replication**: the `client` creates a single `tf.Graph` that + specifies tasks for devices on all workers. The `client` then creates a + client session which will talk to the `master` service of a `worker`. Then + the `master` will partition the graph and distribute the work to all + participating workers. + * **Worker**: A `worker` is a TensorFlow `task` that usually maps to one + physical machine. We will have multiple `worker`s with different `task` + index. They all do similar things except for one worker checkpointing model + variables, writing summaries, etc. in addition to its ordinary work. + + The multi-worker version of this class maps one tower to one device on a + worker. It mirrors all model variables on all towers. For example, if you have + two `worker`s and each `worker` has 4 GPUs, it will create 8 copies of the + model variables on these 8 GPUs. Then like in MirroredStrategy, each tower + performs their computation with their own copy of variables unless in + cross-tower model where variable or tensor reduction happens. - This strategy uses one tower per device and sync replication. + Args: + devices: a list of device strings. + num_gpus: number of GPUs. For local training, either specify `devices` or + `num_gpus`. In distributed training, this must be specified as number of + GPUs on each worker. + cluster_spec: if this is set, it turns into the multi-worker version and + `devices` must not be set but `num_gpus` must be set. + cross_tower_ops: optional, a descedant of `CrossTowerOps`. If this is not + set, the `configure` method will try to find the best one. + prefetch_on_device: optional boolean to specify whether to prefetch input + data to devices. """ def __init__(self, devices=None, num_gpus=None, + cluster_spec=None, cross_tower_ops=None, prefetch_on_device=None): super(MirroredStrategy, self).__init__() - # Convert `num_gpus` into `devices`, shouldn't specify both. - if devices is None: + + if cluster_spec: + if devices is not None: + raise ValueError("Specifying devices when `cluster_spec` is also given " + "is not supported in MirroredStrategy.") + + # TODO(yuefengz): use the utility method to normalize cluster_spec. + if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)): + cluster_spec = server_lib.ClusterSpec(cluster_spec) + elif not isinstance(cluster_spec, server_lib.ClusterSpec): + raise ValueError( + "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a " + "`tf.train.ClusterDef` object") + self._cluster_spec = cluster_spec + + self._workers = [] + for job in sorted(cluster_spec.jobs): + for task in range(cluster_spec.num_tasks(job)): + self._workers.append("/job:%s/task:%d" % (job, task)) + if num_gpus is None: - num_gpus = context.num_gpus() - devices = ["/device:GPU:%d" % d for d in range(num_gpus)] - elif num_gpus is not None: - raise ValueError("Must only specify one of `devices` and `num_gpus`.") + raise ValueError("`num_gpus` is required if `cluster_spec` is given.") + self._num_gpus = num_gpus + if num_gpus > 0: + self._worker_device_map = { + worker: [ + device_util.canonicalize(worker + "/device:GPU:%d" % gpu) + for gpu in range(num_gpus) + ] for worker in self._workers + } + else: + self._worker_device_map = { + worker: [device_util.canonicalize(worker, "/device:CPU:0")] + for worker in self._workers + } + devices = nest.flatten(self._worker_device_map) + + # Setting `_default_device` will add a device scope in the + # distribution.scope. We set the default device to the first worker. When + # users specify device under distribution.scope by + # with tf.device("/cpu:0"): + # ... + # their ops will end up on the cpu device of its first worker, e.g. + # "/job:worker/task:0/device:CPU:0". Note this is not used in tower mode. + self._default_device = self._workers[0] + else: + self._cluster_spec = None + # Convert `num_gpus` into `devices`, shouldn't specify both. + if devices is None: + if num_gpus is None: + num_gpus = context.num_gpus() + devices = ["/device:GPU:%d" % d for d in range(num_gpus)] + elif num_gpus is not None: + raise ValueError("Must only specify one of `devices` and `num_gpus`.") + # TODO(yuefengz): consider setting the default device. assert devices, "Must specify at least one device." assert len(set(devices)) == len(devices), ( @@ -320,7 +411,6 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): {d: i for i, d in enumerate(devices)}) self._cross_tower_ops = cross_tower_ops self._prefetch_on_device = prefetch_on_device - # TODO(yuefengz): consider setting the default device. def _create_variable(self, next_creator, *args, **kwargs): """Create a mirrored variable. See `DistributionStrategy.scope`.""" @@ -357,9 +447,14 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): **kwargs) def distribute_dataset(self, dataset_fn): - return values.PerDeviceDataset( - self._call_dataset_fn(dataset_fn), self._devices, - self._prefetch_on_device) + if self._cluster_spec: + return values.MultiWorkerDataset( + partial(self._call_dataset_fn, dataset_fn), self._worker_device_map, + self._prefetch_on_device) + else: + return values.PerDeviceDataset( + self._call_dataset_fn(dataset_fn), self._devices, + self._prefetch_on_device) # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. def _run_steps_on_dataset(self, fn, iterator, iterations, @@ -444,10 +539,19 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): # in addition to PerDevice data. return values.PerDevice({k: values.MapOutput(v) for k, v in index.items()}) - def configure(self, session_config=None): + def configure(self, + session_config=None, + cluster_spec=None, + task_type=None, + task_id=None): + del cluster_spec, task_type, task_id if self._cross_tower_ops is None: - self._cross_tower_ops = cross_tower_ops_lib.choose_the_best( - self._devices, session_config=session_config) + if self._cluster_spec: + self._cross_tower_ops = cross_tower_ops_lib.MultiWorkerAllReduce( + self._workers, self._num_gpus) + else: + self._cross_tower_ops = cross_tower_ops_lib.choose_the_best( + self._devices, session_config=session_config) def _get_cross_tower_ops(self): if self._cross_tower_ops is None: diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py index 5db2fff239..55d59adc07 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py @@ -19,12 +19,16 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.contrib.distribute.python import strategy_test_lib from tensorflow.python.eager import context from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import variable_scope from tensorflow.python.training import distribution_strategy_context +from tensorflow.python.training import server_lib class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase): @@ -86,5 +90,33 @@ class VariableCreatorStackTest(test.TestCase): self.assertEquals(expected, result) +class MultiWorkerMirroredStrategyTest( + multi_worker_test_base.MultiWorkerTestBase, + strategy_test_lib.DistributionTestBase): + + def _get_distribution_strategy(self): + return mirrored_strategy.MirroredStrategy( + cluster_spec=server_lib.ClusterSpec({ + 'worker': ['/job:worker/task:0', '/job:worker/task:1'] + }), + num_gpus=context.num_gpus()) + + def testMinimizeLossGraph(self): + self._test_minimize_loss_graph(self._get_distribution_strategy()) + + def testDeviceScope(self): + """Test the device scope of multi-worker MirroredStrategy.""" + with context.graph_mode(): + strategy = mirrored_strategy.MirroredStrategy( + cluster_spec={'worker': ['/job:worker/task:0', '/job:worker/task:1']}, + num_gpus=context.num_gpus()) + with strategy.scope(): + a = constant_op.constant(1.) + with ops.device('/cpu:0'): + b = constant_op.constant(1.) + self.assertEqual(a.device, '/job:worker/task:0') + self.assertEqual(b.device, '/job:worker/task:0/device:CPU:0') + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/multi_worker_strategy.py b/tensorflow/contrib/distribute/python/multi_worker_strategy.py deleted file mode 100644 index cbfe5df61d..0000000000 --- a/tensorflow/contrib/distribute/python/multi_worker_strategy.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Classes implementing a mirrored DistributionStrategy for multiple workers.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from functools import partial - -from tensorflow.contrib.distribute.python import values -from tensorflow.contrib.distribute.python.mirrored_strategy import MirroredStrategy -from tensorflow.core.protobuf import cluster_pb2 -from tensorflow.python.training import device_util -from tensorflow.python.training import server_lib -from tensorflow.python.util import nest - - -# TODO(yuefengz): support between-graph replication. -# TODO(yuefengz): merge this class into its base class. -# TODO(yuefengz): in some cases, we probably want to use configure method to -# configure this class. -# TODO(yuefengz): MirroredStrategy.worker_devices may be confusing after the -# class is introduced. -class MultiWorkerMirroredStrategy(MirroredStrategy): - """Mirrored strategy that works on multiple workers with in-graph replication. - - There are several important concepts for distributed TensorFlow, e.g. - `client`, `job`, 'task', `cluster`, `in-graph replication` and - 'synchronous training' and they have already been defined in the - [TensorFlow's documentation](https://www.tensorflow.org/deploy/distributed). - The distribution strategy inherits these concepts as well and in addition to - that we also clarify several more concepts: - * **In-graph replication**: the `client` creates a single `tf.Graph` that - specifies tasks for devices on all workers. The `client` then creates a - client session which will talk to the `master` service of a `worker`. Then - the `master` will partition the graph and distribute the work to all - participating workers. - * **Worker**: A `worker` is a TensorFlow `task` that usually maps to one - physical machine. We will have multiple `worker`s with different `task` - index. They all do similar things except for one worker checkpointing model - variables, writing summaries, etc. in addition to its ordinary work. - - This class maps one tower to one device on a worker. It mirrors all model - variables on all towers. For example, if you have two `worker`s and each - `worker` has 4 GPUs, it will create 8 copies of the model variables on these 8 - GPUs. Then like in MirroredStrategy, each tower performs their computation - with their own copy of variables unless in cross-tower model where variable or - tensor reduction happens. - """ - - def __init__(self, - num_gpus_per_worker=1, - worker_job_name=None, - num_workers=None, - cluster=None, - cross_tower_ops=None, - prefetch_on_device=None): - """Initialize the strategy object. - - Args: - num_gpus_per_worker: number of GPUs per work. If it is zero, the local - CPU will be used. - worker_job_name: the job name for `worker`, typically just 'worker'. - num_workers: the number of workers. If it is 0, it regenerates to - single-worker MirroredStrategy. - cluster: a `tf.train.ClusterSpec` object or a dict that can be used to - construct a `tf.train.ClusterSpec` object or a `tf.train.ClusterDef` - proto buffer. It is an alternative way to initialize this object. - cross_tower_ops: the cross tower ops to use. If None, a default one will - be used. If configure method is called, a best one for the configuration - will be chosen. - prefetch_on_device: a boolean to specify whether to prefetech input to - each worker's devices. - - Raises: - ValueError: if got an unexpected `cluster`. - """ - if cluster is None: - self._workers = [ - '/job:%s/task:%d' % (worker_job_name, task_index) - for task_index in range(num_workers) - ] - else: - if isinstance(cluster, (dict, cluster_pb2.ClusterDef)): - cluster_spec = server_lib.ClusterSpec(cluster) - elif isinstance(cluster, server_lib.ClusterSpec): - cluster_spec = cluster - else: - raise ValueError( - "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a " - '`tf.train.ClusterDef` object') - - self._workers = [] - for job in sorted(cluster_spec.jobs): - for task in range(cluster_spec.num_tasks(job)): - self._workers.append('/job:%s/task:%d' % (job, task)) - - self._num_gpus_per_worker = num_gpus_per_worker - if num_gpus_per_worker > 0: - self._worker_device_map = { - worker: [ - device_util.canonicalize(worker + '/device:GPU:%d' % gpu) - for gpu in range(num_gpus_per_worker) - ] for worker in self._workers - } - else: - self._worker_device_map = { - worker: [device_util.canonicalize(worker, '/device:CPU:0')] - for worker in self._workers - } - self._devices = nest.flatten(self._worker_device_map) - - super(MultiWorkerMirroredStrategy, self).__init__( - devices=self._devices, prefetch_on_device=prefetch_on_device) - - # Setting `_default_device` will add a device scope in the - # distribution.scope. We set the default device to the first worker. When - # users specify device under distribution.scope by - # with tf.device("/cpu:0"): - # ... - # their ops will end up on the cpu device of its first worker, e.g. - # "/job:worker/task:0/device:CPU:0". Note this is not used in tower mode. - self._default_device = self._workers[0] - - def distribute_dataset(self, dataset_fn): - return values.MultiWorkerDataset( - partial(self._call_dataset_fn, dataset_fn), self._worker_device_map, - self._prefetch_on_device) diff --git a/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py b/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py deleted file mode 100644 index 09c859b32a..0000000000 --- a/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for MultiWorkerMirroredStrategy.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.distribute.python import multi_worker_strategy -from tensorflow.contrib.distribute.python import multi_worker_test_base -from tensorflow.contrib.distribute.python import strategy_test_lib -from tensorflow.python.eager import context -from tensorflow.python.eager import test -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops -from tensorflow.python.training import server_lib - - -class MultiWorkerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, - strategy_test_lib.DistributionTestBase): - - def _get_distribution_strategy(self): - return multi_worker_strategy.MultiWorkerMirroredStrategy( - cluster=server_lib.ClusterSpec({ - 'worker': ['/job:worker/task:0', '/job:worker/task:1'] - }), - num_gpus_per_worker=context.num_gpus()) - - def testMinimizeLossGraph(self): - self._test_minimize_loss_graph(self._get_distribution_strategy()) - - -class DeviceScopeTest(test.TestCase): - """Test the device scope of MultiWorkerMirroredStrategy.""" - - def testDeviceScope(self): - with context.graph_mode(): - strategy = multi_worker_strategy.MultiWorkerMirroredStrategy( - cluster={'worker': ['/job:worker/task:0', '/job:worker/task:1']}, - num_gpus_per_worker=context.num_gpus()) - with strategy.scope(): - a = constant_op.constant(1.) - with ops.device('/cpu:0'): - b = constant_op.constant(1.) - self.assertEqual(a.device, '/job:worker/task:0') - self.assertEqual(b.device, '/job:worker/task:0/device:CPU:0') - - -if __name__ == '__main__': - test.main() |