diff options
author | 2018-08-16 12:25:17 -0700 | |
---|---|---|
committer | 2018-08-16 12:28:46 -0700 | |
commit | 1326f33515dac82fbdd7ef502d1df5e96986fc12 (patch) | |
tree | 0a363b47cf4c6f68ecef108430ced4966b888860 /tensorflow/python/distribute | |
parent | 5360d7368713fa3d4e1aedc682ba8bbd3362deba (diff) |
Use distribution strategy to configure distribute coordinator.
Add session_creator and a couple of properties to the worker context, which are then used to configure monitored sessions.
PiperOrigin-RevId: 209026599
Diffstat (limited to 'tensorflow/python/distribute')
4 files changed, 400 insertions, 92 deletions
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 68d8b8d13b..16fbe3f4b5 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -41,3 +41,12 @@ py_test( "//tensorflow/python:variables", ], ) + +py_library( + name = "distribute_coordinator_context", + srcs = [ + "distribute_coordinator_context.py", + ], + srcs_version = "PY2AND3", + deps = [], +) diff --git a/tensorflow/python/distribute/distribute_coordinator.py b/tensorflow/python/distribute/distribute_coordinator.py index fc9ca4ac4a..eb081b65fc 100644 --- a/tensorflow/python/distribute/distribute_coordinator.py +++ b/tensorflow/python/distribute/distribute_coordinator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""A unified and split coordinator for distributed TensorFlow.""" +"""A component for running distributed TensorFlow.""" from __future__ import absolute_import from __future__ import division @@ -24,6 +24,8 @@ import os import threading from tensorflow.core.protobuf import cluster_pb2 +from tensorflow.python.distribute import distribute_coordinator_context +from tensorflow.python.training import monitored_session from tensorflow.python.training import server_lib @@ -43,23 +45,12 @@ class CoordinatorMode(object): # client and connects to remote servers for training. Each remote server can # use the distribute coordinator binary with task_type set correctly which # will then turn into standard servers. - SPLIT_CLIENT = 0 + STANDALONE_CLIENT = "standalone_client" # The distribute coordinator runs on each worker. It will run a standard # server on each worker and optionally run the `worker_fn` that is configured # to talk to its standard server. 
- INDEPENDENT_WORKER = 1 - - -_worker_context = threading.local() - - -def get_current_worker_context(): - """Returns the current task context.""" - try: - return _worker_context.current - except AttributeError: - return None + INDEPENDENT_WORKER = "independent_worker" class _Barrier(object): @@ -113,14 +104,17 @@ class _WorkerContext(object): """ def __init__(self, + strategy, cluster_spec, task_type, task_id, + session_config=None, rpc_layer="grpc", worker_barrier=None): """Initialize the worker context object. Args: + strategy: a `DistributionStrategy` object. cluster_spec: a ClusterSpec object. It can be empty or None in the local training case. task_type: a string indicating the role of the corresponding task, such as @@ -128,14 +122,17 @@ class _WorkerContext(object): replicated training. task_id: an integer indicating id of the corresponding task. It can be None if it is local training or in-graph replicated training. + session_config: an optional @{tf.ConfigProto} object. rpc_layer: optional string specifying the RPC protocol for communication with worker masters. If None or empty, hosts in the `cluster_spec` will be used directly. worker_barrier: optional, the barrier object for worker synchronization. 
""" + self._strategy = strategy self._cluster_spec = cluster_spec self._task_type = task_type self._task_id = task_id + self._session_config = session_config self._worker_barrier = worker_barrier self._rpc_layer = rpc_layer self._master_target = self._get_master_target() @@ -143,26 +140,31 @@ class _WorkerContext(object): self._is_chief_node = self._is_chief() def _debug_message(self): - return "[cluster_spec: %r, task_type: %r, task_id: %r]" % ( - self._cluster_spec, self.task_type, self.task_id) + if self._cluster_spec: + return "[cluster_spec: %r, task_type: %r, task_id: %r]" % ( + self._cluster_spec, self.task_type, self.task_id) + else: + return "[local]" def __enter__(self): - old_context = get_current_worker_context() + old_context = distribute_coordinator_context.get_current_worker_context() if old_context: raise ValueError( "You cannot run distribute coordinator in a `worker_fn`.\t" + self._debug_message()) - _worker_context.current = self + # pylint: disable=protected-access + distribute_coordinator_context._worker_context.current = self def __exit__(self, unused_exception_type, unused_exception_value, unused_traceback): - _worker_context.current = None + # pylint: disable=protected-access + distribute_coordinator_context._worker_context.current = None def _get_master_target(self): """Return the master target for a task.""" # If cluster_spec is None or empty, we use local master. if not self._cluster_spec: - return "local" + return "" # If task_type is None, then it is in-graph replicated training. In this # case we use the chief or first worker's master target. @@ -207,6 +209,47 @@ class _WorkerContext(object): self._debug_message()) self._worker_barrier.wait() + def session_creator(self, + scaffold=None, + config=None, + checkpoint_dir=None, + checkpoint_filename_with_path=None, + max_wait_secs=7200): + """Returns a session creator. + + The returned session creator will be configured with the correct master + target and session configs. 
It will also run either init ops or ready ops + by querying the `strategy` object when `create_session` is called on it. + + Args: + scaffold: A `Scaffold` used for gathering or building supportive ops. If + not specified a default one is created. It's used to finalize the graph. + config: `ConfigProto` proto used to configure the session. + checkpoint_dir: A string. Optional path to a directory where to restore + variables. + checkpoint_filename_with_path: Full file name path to the checkpoint file. + Only one of `checkpoint_dir` or `checkpoint_filename_with_path` can be + specified. + max_wait_secs: Maximum time to wait for the session to become available. + + Returns: + a descendant of SessionCreator. + """ + # TODO(yuefengz): merge session config. + if self._strategy.should_init: + return monitored_session.ChiefSessionCreator( + scaffold, + master=self.master_target, + config=config or self._session_config, + checkpoint_dir=checkpoint_dir, + checkpoint_filename_with_path=checkpoint_filename_with_path) + else: + return monitored_session.WorkerSessionCreator( + scaffold, + master=self.master_target, + config=config or self._session_config, + max_wait_secs=max_wait_secs) + @property def has_barrier(self): """Whether the barrier is set or not.""" @@ -247,21 +290,38 @@ class _WorkerContext(object): """Returns number of workers in the cluster, including chief.""" return self._num_workers + @property + def should_checkpoint(self): + """Whether to save checkpoint.""" + return self._strategy.should_checkpoint + + @property + def should_save_summary(self): + """Whether to save summaries.""" + return self._strategy.should_save_summary + def _run_single_worker(worker_fn, + strategy, cluster_spec, task_type, task_id, - rpc_layer, + session_config, + rpc_layer="", worker_barrier=None): """Runs a single worker by calling `worker_fn` under context.""" - with _WorkerContext( + strategy = copy.deepcopy(strategy) + strategy.configure(session_config, cluster_spec, task_type, 
task_id) + context = _WorkerContext( + strategy, cluster_spec, task_type, task_id, + session_config=session_config, rpc_layer=rpc_layer, - worker_barrier=worker_barrier): - worker_fn() + worker_barrier=worker_barrier) + with context: + worker_fn(strategy) def _run_std_server(cluster_spec=None, @@ -280,13 +340,15 @@ def _run_std_server(cluster_spec=None, return server -def _run_between_graph_client(worker_fn, cluster_spec, rpc_layer): +def _run_between_graph_client(worker_fn, strategy, cluster_spec, session_config, + rpc_layer): """Runs a standalone client for between-graph replication.""" eval_thread = None if _TaskType.EVALUATOR in cluster_spec.jobs: eval_thread = threading.Thread( target=_run_single_worker, - args=(worker_fn, cluster_spec, _TaskType.EVALUATOR, 0), + args=(worker_fn, strategy, cluster_spec, _TaskType.EVALUATOR, 0, + session_config), kwargs={ "rpc_layer": rpc_layer, }) @@ -298,7 +360,8 @@ def _run_between_graph_client(worker_fn, cluster_spec, rpc_layer): for task_id in range(len(cluster_spec.as_dict().get(task_type, []))): t = threading.Thread( target=_run_single_worker, - args=(worker_fn, cluster_spec, task_type, task_id), + args=(worker_fn, strategy, cluster_spec, task_type, task_id, + session_config), kwargs={ "rpc_layer": rpc_layer, "worker_barrier": worker_barrier @@ -315,43 +378,53 @@ def _run_between_graph_client(worker_fn, cluster_spec, rpc_layer): eval_thread.join() -def _run_in_graph_client(worker_fn, cluster_spec, rpc_layer): +def _run_in_graph_client(worker_fn, strategy, cluster_spec, session_config, + rpc_layer): """Runs a standalone client for in-graph replication.""" eval_thread = None if _TaskType.EVALUATOR in cluster_spec.jobs: eval_thread = threading.Thread( target=_run_single_worker, - args=(worker_fn, cluster_spec, _TaskType.EVALUATOR, 0), + args=(worker_fn, strategy, cluster_spec, _TaskType.EVALUATOR, 0, + session_config), kwargs={ "rpc_layer": rpc_layer, }) eval_thread.start() - _run_single_worker(worker_fn, cluster_spec, 
None, None, rpc_layer) + _run_single_worker( + worker_fn, + strategy, + cluster_spec, + None, + None, + session_config, + rpc_layer=rpc_layer) if eval_thread: eval_thread.join() - -# TODO(yuefengz): propagate cluster_spec in the SPLIT_CLIENT mode. +# TODO(yuefengz): propagate cluster_spec in the STANDALONE_CLIENT mode. # TODO(yuefengz): we may need a smart way to figure out whether the current task # is the special task when we support cluster_spec propagation. def run_distribute_coordinator(worker_fn, - mode=CoordinatorMode.SPLIT_CLIENT, + strategy, + mode=CoordinatorMode.STANDALONE_CLIENT, cluster_spec=None, task_type=None, task_id=None, - between_graph=False, + session_config=None, rpc_layer="grpc"): """Runs the coordinator for distributed TensorFlow. This function runs a split coordinator for distributed TensorFlow in its - default mode, i.e the SPLIT_CLIENT mode. Given a `cluster_spec` specifying - server addresses and their roles in a cluster, this coordinator will figure - out how to set them up, give the underlying function the right targets for - master sessions via a scope object and coordinate their training. The cluster - consisting of standard servers needs to be brought up either with the standard - server binary or with a binary running distribute coordinator with `task_type` - set to non-client type which will then turn into standard servers. + default mode, i.e the STANDALONE_CLIENT mode. Given a `cluster_spec` + specifying server addresses and their roles in a cluster, this coordinator + will figure out how to set them up, give the underlying function the right + targets for master sessions via a scope object and coordinate their training. + The cluster consisting of standard servers needs to be brought up either with + the standard server binary or with a binary running distribute coordinator + with `task_type` set to non-client type which will then turn into standard + servers. 
In addition to be the distribute coordinator, this is also the source of configurations for each job in the distributed training. As there are multiple @@ -370,6 +443,14 @@ def run_distribute_coordinator(worker_fn, `worker_fn` depending whether it is between-graph training or in-graph replicated training. + The `strategy` object is expected to be a DistributionStrategy object which + has implemented methods needed by distributed coordinator such as + `configure(session_config, cluster_spec, task_type, task_id)` which configures + the strategy object for a specific task and `should_init` property which + instructs the distribute coordinator whether to run init ops for a task. The + distribute coordinator will make a copy of the `strategy` object, call its + `configure` method and pass it to `worker_fn` as an argument. + The `worker_fn` defines the training logic and is called under a its own worker context which can be accessed to via `get_current_worker_context`. A worker context provides access to configurations for each task, e.g. the @@ -413,16 +494,20 @@ def run_distribute_coordinator(worker_fn, evaluation. Args: - worker_fn: the function to be called and given the access to a coordinator - context object. + worker_fn: the function to be called. The function should accept a + `strategy` object and will be given access to a context object via a + context manager scope. + strategy: a DistributionStrategy object which specifying whether it should + run between-graph replicated training or not, whether to run init ops, + etc. This object will also be configured given `session_config`, + `cluster_spc`, `task_type` and `task_id`. mode: in which mode this distribute coordinator runs. cluster_spec: a dict, ClusterDef or ClusterSpec specifying servers and roles in a cluster. If not set or empty, fall back to local training. task_type: the current task type, optional if this is a client. task_id: the current task id, optional if this is a client. 
- between_graph: a boolean. It is only useful when `cluster_spec` is set and - not empty. If true, it will use between-graph replicated training; - otherwise it will use in-graph replicated training. + session_config: an optional @{tf.ConfigProto} object which will be passed + to `strategy`'s `configure` method and used to create a session. rpc_layer: optional string, the protocol for RPC, e.g. "grpc". Raises: @@ -448,15 +533,18 @@ def run_distribute_coordinator(worker_fn, if not cluster_spec: # `mode` is ignored in the local case. - _run_single_worker(worker_fn, None, None, None, rpc_layer) - elif mode == CoordinatorMode.SPLIT_CLIENT: + _run_single_worker(worker_fn, strategy, None, None, None, session_config, + rpc_layer) + elif mode == CoordinatorMode.STANDALONE_CLIENT: # The client must know the cluster but servers in the cluster don't have to # know the client. if task_type in [_TaskType.CLIENT, None]: - if between_graph: - _run_between_graph_client(worker_fn, cluster_spec, rpc_layer) + if strategy.between_graph: + _run_between_graph_client(worker_fn, strategy, cluster_spec, + session_config, rpc_layer) else: - _run_in_graph_client(worker_fn, cluster_spec, rpc_layer) + _run_in_graph_client(worker_fn, strategy, cluster_spec, session_config, + rpc_layer) else: # If not a client job, run the standard server. server = _run_std_server( @@ -471,19 +559,21 @@ def run_distribute_coordinator(worker_fn, cluster_spec=cluster_spec, task_type=task_type, task_id=task_id) if task_type in [_TaskType.CHIEF, _TaskType.WORKER]: - if between_graph: + if strategy.between_graph: # All jobs run `worker_fn` if between-graph. - _run_single_worker(worker_fn, cluster_spec, task_type, task_id, - rpc_layer) + _run_single_worker(worker_fn, strategy, cluster_spec, task_type, + task_id, session_config, rpc_layer) else: # Only one node runs `worker_fn` if in-graph. 
- context = _WorkerContext(cluster_spec, task_type, task_id, rpc_layer) + context = _WorkerContext(strategy, cluster_spec, task_type, task_id) if context.is_chief: - _run_single_worker(worker_fn, cluster_spec, None, None, rpc_layer) + _run_single_worker(worker_fn, strategy, cluster_spec, None, None, + session_config, rpc_layer) else: server.join() elif task_type == _TaskType.EVALUATOR: - _run_single_worker(worker_fn, cluster_spec, task_type, task_id, rpc_layer) + _run_single_worker(worker_fn, strategy, cluster_spec, task_type, task_id, + session_config, rpc_layer) else: if task_type != _TaskType.PS: raise ValueError("Unexpected task_type: %r" % task_type) diff --git a/tensorflow/python/distribute/distribute_coordinator_context.py b/tensorflow/python/distribute/distribute_coordinator_context.py new file mode 100644 index 0000000000..dee65ce883 --- /dev/null +++ b/tensorflow/python/distribute/distribute_coordinator_context.py @@ -0,0 +1,31 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""The context retrieval method for distribute coordinator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import threading + +_worker_context = threading.local() + + +def get_current_worker_context(): + """Returns the current task context.""" + try: + return _worker_context.current + except AttributeError: + return None diff --git a/tensorflow/python/distribute/distribute_coordinator_test.py b/tensorflow/python/distribute/distribute_coordinator_test.py index 319c29ba2f..97c6bdd15a 100644 --- a/tensorflow/python/distribute/distribute_coordinator_test.py +++ b/tensorflow/python/distribute/distribute_coordinator_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for distribute coordinator.""" +"""Tests for Distribute Coordinator.""" from __future__ import absolute_import from __future__ import division @@ -37,6 +37,7 @@ except ImportError as _error: from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.distribute import distribute_coordinator +from tensorflow.python.distribute import distribute_coordinator_context from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import control_flow_ops @@ -44,17 +45,17 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test +from tensorflow.python.training import monitored_session + CHIEF = distribute_coordinator._TaskType.CHIEF WORKER = distribute_coordinator._TaskType.WORKER PS = distribute_coordinator._TaskType.PS EVALUATOR = 
distribute_coordinator._TaskType.EVALUATOR -SPLIT_CLIENT = distribute_coordinator.CoordinatorMode.SPLIT_CLIENT +STANDALONE_CLIENT = distribute_coordinator.CoordinatorMode.STANDALONE_CLIENT INDEPENDENT_WORKER = distribute_coordinator.CoordinatorMode.INDEPENDENT_WORKER -RUN_STD_SERVER_METHOD = "tensorflow.python.distribute.distribute_coordinator._run_std_server" - NUM_WORKERS = 3 NUM_PS = 2 @@ -74,6 +75,57 @@ def _strip_protocol(target): return target +class MockStrategy(object): + + def __init__(self, + between_graph=False, + should_init=None, + should_checkpoint=None, + should_save_summary=None): + self._between_graph = between_graph + self._should_init = should_init + self._should_checkpoint = should_checkpoint + self._should_save_summary = should_save_summary + + @property + def between_graph(self): + return self._between_graph + + def configure(self, + session_options=None, + cluster_spec=None, + task_type=None, + task_id=None): + del session_options, cluster_spec, task_type + if self._should_init is None: + if task_id == 0: + self._should_init = True + else: + self._should_init = False + if self._should_checkpoint is None: + if task_id == 0: + self._should_checkpoint = True + else: + self._should_checkpoint = False + if self._should_save_summary is None: + if task_id == 0: + self._should_save_summary = True + else: + self._should_save_summary = False + + @property + def should_init(self): + return self._should_init + + @property + def should_checkpoint(self): + return self._should_checkpoint + + @property + def should_save_summary(self): + return self._should_save_summary + + class MockServer(object): def __init__(self): @@ -108,6 +160,7 @@ class DistributeCoordinatorTestBase(test.TestCase): self._result_correct = 0 self._lock = threading.Lock() self._worker_context = {} + self._strategy_property = {} self._std_servers = {} self._barrier = distribute_coordinator._Barrier(NUM_WORKERS) @@ -142,8 +195,8 @@ class DistributeCoordinatorTestBase(test.TestCase): 
cluster_spec[EVALUATOR] = ["localhost:%s" % portpicker.pick_unused_port()] return cluster_spec - def _in_graph_worker_fn(self): - context = distribute_coordinator.get_current_worker_context() + def _in_graph_worker_fn(self, strategy): + context = distribute_coordinator_context.get_current_worker_context() self.assertTrue(context is not None) with self._test_session(target=context.master_target) as sess: xs = [] @@ -164,22 +217,23 @@ class DistributeCoordinatorTestBase(test.TestCase): if result_value == expected: self._result_correct += 1 - def _run_coordinator_in_thread(self, worker_fn, **kwargs): + def _run_coordinator_in_thread(self, worker_fn, strategy, **kwargs): t = threading.Thread( target=distribute_coordinator.run_distribute_coordinator, - args=(worker_fn,), + args=(worker_fn, strategy), kwargs=kwargs) t.start() return t - def _run_multiple_coordinator_in_threads(self, worker_fn, cluster_spec, - **kwargs): + def _run_multiple_coordinator_in_threads(self, worker_fn, strategy, + cluster_spec, **kwargs): threads = {} for task_type in cluster_spec.keys(): threads[task_type] = [] for task_id in range(len(cluster_spec[task_type])): t = self._run_coordinator_in_thread( worker_fn, + strategy, cluster_spec=cluster_spec, task_type=task_type, task_id=task_id, @@ -187,8 +241,8 @@ class DistributeCoordinatorTestBase(test.TestCase): threads[task_type].append(t) return threads - def _between_graph_worker_fn(self): - context = distribute_coordinator.get_current_worker_context() + def _between_graph_worker_fn(self, strategy): + context = distribute_coordinator_context.get_current_worker_context() self.assertTrue(context is not None) with self._test_session(target=context.master_target) as sess: with ops.device("/job:ps/task:0"): @@ -234,14 +288,50 @@ class DistributeCoordinatorTestBase(test.TestCase): with self._lock: self._result_correct += 1 - def _dump_worker_context(self): + def _between_graph_with_monitored_session(self, strategy): + context = 
distribute_coordinator_context.get_current_worker_context() + self.assertTrue(context is not None) + with ops.device("/job:ps/task:0"): + # TODO(yuefengz): investigate why not using resource variable will make + # the test flaky. + x = variable_scope.get_variable("x", initializer=10.0, use_resource=True) + with ops.device("/job:ps/task:1"): + y = variable_scope.get_variable("y", initializer=20.0, use_resource=True) + + x_add = x.assign_add(2.0) + y_sub = y.assign_sub(2.0) + train_op = control_flow_ops.group([x_add, y_sub]) + + # The monitored session will run init or ready ops. + with monitored_session.MonitoredSession() as sess: + sess.run(train_op) + + # Synchronize workers after one step to make sure they all have finished + # training. + if context.has_barrier: + context.wait_for_other_workers() + else: + self._barrier.wait() + + x_val, y_val = sess.run([x, y]) + + self.assertEqual(x_val, 16.0) + self.assertEqual(y_val, 14.0) + if x_val == 16.0 and y_val == 14.0: + with self._lock: + self._result_correct += 1 + + def _dump_worker_context(self, strategy): """Dumps the propoerties of each worker context. It dumps the context properties to a dict mapping from task_type to a list of tuples of master_target, num_workers, is_chief and distribute_mode, where the list is indexed by the task_id. + + Args: + strategy: a `DistributionStrategy` object. 
""" - context = distribute_coordinator.get_current_worker_context() + context = distribute_coordinator_context.get_current_worker_context() self.assertTrue(context is not None) task_type = str(context.task_type) task_id = context.task_id or 0 @@ -255,6 +345,25 @@ class DistributeCoordinatorTestBase(test.TestCase): context.is_chief, context.distributed_mode) + def _dump_strategy_property(self, strategy): + context = distribute_coordinator_context.get_current_worker_context() + self.assertTrue(context is not None) + + self.assertEqual(context._strategy.should_init, strategy.should_init) + self.assertEqual(context.should_checkpoint, strategy.should_checkpoint) + self.assertEqual(context.should_save_summary, strategy.should_save_summary) + + task_type = str(context.task_type) + task_id = context.task_id or 0 + with self._lock: + if task_type not in self._strategy_property: + self._strategy_property[task_type] = [] + while len(self._strategy_property[task_type]) <= task_id: + self._strategy_property[task_type].append(None) + self._strategy_property[task_type][task_id] = ( + context._strategy.should_init, context.should_checkpoint, + context.should_save_summary) + def _run_mock_std_server(self, session_config=None, cluster_spec=None, @@ -274,22 +383,32 @@ class DistributeCoordinatorTestBase(test.TestCase): return server -class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase): +class DistributeCoordinatorTestStandaloneMode(DistributeCoordinatorTestBase): - def testInGraphSplitMode(self): - """Test it runs in-graph replication in split client mode.""" + def testInGraphStandaloneMode(self): + """Test it runs in-graph replication in standalone client mode.""" distribute_coordinator.run_distribute_coordinator( self._in_graph_worker_fn, - cluster_spec=self._cluster_spec, - between_graph=False) + MockStrategy(between_graph=False), + cluster_spec=self._cluster_spec) self.assertEqual(self._result_correct, 1) def testBetweenGraph(self): - """Test it runs 
between-graph replication in split client mode.""" + """Test it runs between-graph replication in standalone client mode.""" distribute_coordinator.run_distribute_coordinator( self._between_graph_worker_fn, - cluster_spec=self._cluster_spec, - between_graph=True) + MockStrategy(between_graph=True), + cluster_spec=self._cluster_spec) + + # Each finished worker will increment self._result_correct. + self.assertEqual(self._result_correct, NUM_WORKERS) + + def testBetweenGraphWithMonitoredSession(self): + """Test monitored session in standalone client mode.""" + distribute_coordinator.run_distribute_coordinator( + self._between_graph_with_monitored_session, + MockStrategy(between_graph=True), + cluster_spec=self._cluster_spec) # Each finished worker will increment self._result_correct. self.assertEqual(self._result_correct, NUM_WORKERS) @@ -298,8 +417,8 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase): # Dumps the task contexts to the self._worker_context dict. distribute_coordinator.run_distribute_coordinator( self._dump_worker_context, - cluster_spec=self._cluster_spec, - between_graph=True) + MockStrategy(between_graph=True), + cluster_spec=self._cluster_spec) # There is only one type of task and there three such tasks. self.assertEqual(len(self._worker_context), 1) @@ -318,12 +437,30 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase): self._worker_context[WORKER][2], (_bytes_to_str(self._workers[2].target), NUM_WORKERS, False, True)) + def testBetweenGraphStrategyProperties(self): + # Dumps properties of the strategy objects. + distribute_coordinator.run_distribute_coordinator( + self._dump_strategy_property, + MockStrategy(between_graph=True, should_init=True), + cluster_spec=self._cluster_spec) + + # There is only one type of task and there three such tasks. 
+ self.assertEqual(len(self._strategy_property), 1) + self.assertTrue(WORKER in self._strategy_property) + self.assertEqual(len(self._strategy_property[WORKER]), NUM_WORKERS) + + # Check whether each task has the right properties of should_init, + # should_checkpoint and should_save_summary. + self.assertEqual(self._strategy_property[WORKER][0], (True, True, True)) + self.assertEqual(self._strategy_property[WORKER][1], (True, False, False)) + self.assertEqual(self._strategy_property[WORKER][2], (True, False, False)) + def testInGraphContext(self): # Dumps the task contexts to the self._worker_context dict. distribute_coordinator.run_distribute_coordinator( self._dump_worker_context, - cluster_spec=self._cluster_spec, - between_graph=False) + MockStrategy(between_graph=False), + cluster_spec=self._cluster_spec) # There is only a "None" task in the dumped task context. self.assertEqual(len(self._worker_context), 1) @@ -339,7 +476,9 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase): def testLocalContext(self): # Dumps the task contexts to the self._worker_context dict. distribute_coordinator.run_distribute_coordinator( - self._dump_worker_context, cluster_spec=None, between_graph=True) + self._dump_worker_context, + MockStrategy(between_graph=False), + cluster_spec=None) # There is only a "None" task. self.assertEqual(len(self._worker_context), 1) @@ -348,7 +487,7 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase): # Check whether each task has the right master_target, num_workers, is_chief # and distributed_mode. - self.assertEqual(self._worker_context["None"][0], ("local", 0, True, False)) + self.assertEqual(self._worker_context["None"][0], ("", 0, True, False)) def testBetweenGraphContextWithChief(self): # Adds a chief node, so there are NUM_WORKERS + 1 workers in total. 
@@ -358,8 +497,8 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase): # Dumps the task contexts to the self._worker_context dict. distribute_coordinator.run_distribute_coordinator( self._dump_worker_context, + MockStrategy(between_graph=True), cluster_spec=cluster_spec, - between_graph=True, rpc_layer="grpc") # There are one CHIEF and three workers. @@ -391,8 +530,8 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase): # Dumps the task contexts to the self._worker_context dict. distribute_coordinator.run_distribute_coordinator( self._dump_worker_context, + MockStrategy(between_graph=False), cluster_spec=cluster_spec, - between_graph=False, rpc_layer=None) # There are one "None" task and one EVALUATOR task. @@ -417,8 +556,8 @@ class DistributeCoordinatorTestInpendentWorkerMode( cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS) threads = self._run_multiple_coordinator_in_threads( self._in_graph_worker_fn, + MockStrategy(between_graph=False), cluster_spec, - between_graph=False, mode=INDEPENDENT_WORKER) threads[WORKER][0].join() self.assertEqual(self._result_correct, 1) @@ -428,8 +567,22 @@ class DistributeCoordinatorTestInpendentWorkerMode( num_workers=NUM_WORKERS, num_ps=NUM_PS) threads = self._run_multiple_coordinator_in_threads( self._between_graph_worker_fn, + MockStrategy(between_graph=True), + cluster_spec, + mode=INDEPENDENT_WORKER) + for task_id in range(NUM_WORKERS): + threads[WORKER][task_id].join() + + # Each finished worker will increment self._result_correct. 
+ self.assertEqual(self._result_correct, NUM_WORKERS) + + def testBetweenGraphWithMonitoredSession(self): + cluster_spec = self._create_cluster_spec( + num_workers=NUM_WORKERS, num_ps=NUM_PS) + threads = self._run_multiple_coordinator_in_threads( + self._between_graph_with_monitored_session, + MockStrategy(between_graph=True), cluster_spec, - between_graph=True, mode=INDEPENDENT_WORKER) for task_id in range(NUM_WORKERS): threads[WORKER][task_id].join() @@ -444,9 +597,9 @@ class DistributeCoordinatorTestInpendentWorkerMode( self._run_mock_std_server): threads = self._run_multiple_coordinator_in_threads( self._dump_worker_context, + MockStrategy(between_graph=True), cluster_spec, mode=INDEPENDENT_WORKER, - between_graph=True, rpc_layer=None) for task_id in range(NUM_WORKERS): threads[WORKER][task_id].join() @@ -476,6 +629,31 @@ class DistributeCoordinatorTestInpendentWorkerMode( self.assertFalse(self._std_servers[WORKER][1].joined) self.assertFalse(self._std_servers[WORKER][2].joined) + def testBetweenGraphStrategyProperties(self): + cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS) + # Dumps properties of the strategy objects. + with test.mock.patch.object(distribute_coordinator, "_run_std_server", + self._run_mock_std_server): + threads = self._run_multiple_coordinator_in_threads( + self._dump_strategy_property, + MockStrategy(between_graph=True, should_init=True), + cluster_spec, + mode=INDEPENDENT_WORKER, + rpc_layer=None) + for task_id in range(NUM_WORKERS): + threads[WORKER][task_id].join() + + # There is only one type of task and there three such tasks. + self.assertEqual(len(self._strategy_property), 1) + self.assertTrue(WORKER in self._strategy_property) + self.assertEqual(len(self._strategy_property[WORKER]), NUM_WORKERS) + + # Check whether each task has the right properties of should_init, + # should_checkpoint and should_save_summary. 
+ self.assertEqual(self._strategy_property[WORKER][0], (True, True, True)) + self.assertEqual(self._strategy_property[WORKER][1], (True, False, False)) + self.assertEqual(self._strategy_property[WORKER][2], (True, False, False)) + def testInGraphContext(self): cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS) # Dumps the task contexts and std server arguments. @@ -483,9 +661,9 @@ class DistributeCoordinatorTestInpendentWorkerMode( self._run_mock_std_server): threads = self._run_multiple_coordinator_in_threads( self._dump_worker_context, + MockStrategy(between_graph=False), cluster_spec, mode=INDEPENDENT_WORKER, - between_graph=False, rpc_layer=None) for task_id in range(NUM_WORKERS): threads[WORKER][task_id].join() @@ -519,9 +697,9 @@ class DistributeCoordinatorTestInpendentWorkerMode( self._run_mock_std_server): threads = self._run_multiple_coordinator_in_threads( self._dump_worker_context, + MockStrategy(between_graph=False), cluster_spec, mode=INDEPENDENT_WORKER, - between_graph=False, rpc_layer=None) for task_id in range(NUM_WORKERS): threads[WORKER][task_id].join() |